cgcardona / muse public
muse.py python
374 lines 11.4 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """MuseClient — commit, branch, merge, graph via MUSE HTTP API."""
2
3 from __future__ import annotations
4
5 import json
6 import logging
7 import uuid
8 from pathlib import Path
9 from typing import Any
10
11 import httpx
12
13 from tourdeforce.config import TDFConfig
14 from tourdeforce.models import (
15 Component, EventType, Severity, TraceContext, sha256_payload,
16 )
17 from tourdeforce.collectors.events import EventCollector
18 from tourdeforce.collectors.metrics import MetricsCollector
19
# Module-level logger named after this module. Nothing else in this file
# currently logs through it.
logger = logging.getLogger(__name__)
21
22
class MuseClient:
    """Talks to the MUSE VCS HTTP API for commits, merges, and graph exports.

    All requests go through one ``httpx.AsyncClient`` bound to
    ``config.muse_base_url`` with ``config.auth_headers``. Call :meth:`close`,
    or use the client as ``async with MuseClient(...) as muse:``, to release it.
    """

    def __init__(
        self,
        config: TDFConfig,
        event_collector: EventCollector,
        metrics: MetricsCollector,
        payload_dir: Path,
        muse_dir: Path,
    ) -> None:
        """Create the HTTP client and ensure *muse_dir* exists.

        Args:
            config: Supplies the MUSE base URL, auth headers and project id.
            event_collector: Receives the events emitted before each commit,
                merge and checkout request.
            metrics: Provides the ``timer`` context manager wrapped around requests.
            payload_dir: Accepted for interface compatibility; not used by this class.
            muse_dir: Directory where :meth:`get_log` writes ``graph.json``;
                created (with parents) if it does not exist.
        """
        self._config = config
        self._events = event_collector
        self._metrics = metrics
        self._muse_dir = muse_dir
        self._muse_dir.mkdir(parents=True, exist_ok=True)
        self._client = httpx.AsyncClient(
            base_url=config.muse_base_url,
            timeout=30.0,
            headers=config.auth_headers,
        )

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        await self._client.aclose()

    async def __aenter__(self) -> MuseClient:
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        await self.close()

    @staticmethod
    def _api_error(operation: str, resp: httpx.Response) -> MuseError:
        """Build the standard ``MuseError`` for a failed *operation* (body truncated to 500 chars)."""
        return MuseError(f"MUSE {operation} failed: {resp.status_code} — {resp.text[:500]}")

    @staticmethod
    def _json_object(resp: httpx.Response) -> dict[str, Any]:
        """Best-effort decode of *resp* as a JSON object.

        Returns ``{}`` when the body is not valid JSON or is not a JSON object.
        This keeps error-path handling from failing with a decode error.
        """
        try:
            body = resp.json()
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return {}
        return body if isinstance(body, dict) else {}

    async def save_variation(
        self,
        run_id: str,
        trace: TraceContext,
        *,
        variation_id: str | None = None,
        intent: str = "compose",
        phrases: list[dict] | None = None,
        affected_tracks: list[str] | None = None,
        affected_regions: list[str] | None = None,
        parent_variation_id: str | None = None,
        parent2_variation_id: str | None = None,
        conversation_id: str = "default",
    ) -> str:
        """Persist a variation into MUSE history via ``POST /variations``.

        Runs inside a ``muse_commit`` span and emits a MUSE_COMMIT event
        before the request is sent.

        Returns:
            The variation id: *variation_id* if given, otherwise a fresh UUID4.

        Raises:
            MuseError: If MUSE responds with anything other than HTTP 200.
        """
        trace.new_span("muse_commit")
        try:
            vid = variation_id or str(uuid.uuid4())

            payload = {
                "project_id": self._config.muse_project_id,
                "variation_id": vid,
                "intent": intent,
                "conversation_id": conversation_id,
                "parent_variation_id": parent_variation_id,
                "parent2_variation_id": parent2_variation_id,
                "phrases": phrases or [],
                "affected_tracks": affected_tracks or [],
                "affected_regions": affected_regions or [],
            }

            await self._events.emit(
                run_id=run_id,
                scenario="muse_commit",
                component=Component.MUSE,
                event_type=EventType.MUSE_COMMIT,
                trace=trace,
                data={"variation_id": vid, "parent": parent_variation_id, "intent": intent},
            )

            async with self._metrics.timer("muse_save_variation", run_id):
                resp = await self._client.post("/variations", json=payload)

            if resp.status_code != 200:
                raise self._api_error("save_variation", resp)
        finally:
            # Close the span on every path, including exceptions from emit/post.
            trace.end_span()
        return vid

    async def set_head(
        self,
        run_id: str,
        trace: TraceContext,
        variation_id: str,
    ) -> None:
        """Point MUSE's HEAD at *variation_id* via ``POST /head``.

        Unlike the other operations, this opens no span and emits no event.
        *run_id* and *trace* are currently unused.

        Raises:
            MuseError: If MUSE responds with anything other than HTTP 200.
        """
        # NOTE(review): every other request sends project_id; this one does not.
        # Confirm that /head does not need it.
        resp = await self._client.post("/head", json={"variation_id": variation_id})
        if resp.status_code != 200:
            raise self._api_error("set_head", resp)

    async def merge(
        self,
        run_id: str,
        trace: TraceContext,
        left_id: str,
        right_id: str,
        *,
        force: bool = True,
        conversation_id: str = "default",
    ) -> MergeResult:
        """Three-way merge of two variations via ``POST /merge``.

        Returns:
            A successful ``MergeResult`` on HTTP 200. On HTTP 409 it returns
            ``success=False`` with any conflicts found under ``detail.conflicts``.

        Raises:
            MuseError: For any other status code.
        """
        trace.new_span("muse_merge")
        try:
            await self._events.emit(
                run_id=run_id,
                scenario="muse_merge",
                component=Component.MUSE,
                event_type=EventType.MUSE_MERGE,
                trace=trace,
                data={"left": left_id, "right": right_id},
            )

            payload = {
                "project_id": self._config.muse_project_id,
                "left_id": left_id,
                "right_id": right_id,
                "conversation_id": conversation_id,
                "force": force,
            }

            async with self._metrics.timer("muse_merge", run_id):
                resp = await self._client.post("/merge", json=payload)

            # Inspect the status before decoding. Error responses (e.g. from a
            # proxy) may not be JSON and must surface as MuseError.
            if resp.status_code == 409:
                detail = self._json_object(resp).get("detail")
                conflicts = detail.get("conflicts", []) if isinstance(detail, dict) else []
                return MergeResult(
                    success=False,
                    merge_variation_id="",
                    conflicts=conflicts,
                    status_code=409,
                )

            if resp.status_code != 200:
                raise self._api_error("merge", resp)

            body = resp.json()
            return MergeResult(
                success=True,
                merge_variation_id=body.get("merge_variation_id", ""),
                executed=body.get("executed", 0),
                status_code=200,
            )
        finally:
            trace.end_span()

    async def get_log(
        self,
        run_id: str,
        trace: TraceContext,
    ) -> dict[str, Any]:
        """Fetch the commit DAG for the configured project via ``GET /log``.

        The decoded JSON is also written, pretty-printed, to
        ``<muse_dir>/graph.json``. *run_id* and *trace* are currently unused.

        Raises:
            MuseError: If MUSE responds with anything other than HTTP 200.
        """
        resp = await self._client.get("/log", params={"project_id": self._config.muse_project_id})
        if resp.status_code != 200:
            raise self._api_error("get_log", resp)
        graph = resp.json()

        # Persist graph
        graph_file = self._muse_dir / "graph.json"
        graph_file.write_text(json.dumps(graph, indent=2), encoding="utf-8")

        return graph

    async def checkout(
        self,
        run_id: str,
        trace: TraceContext,
        target_variation_id: str,
        *,
        force: bool = True,
        conversation_id: str = "default",
    ) -> CheckoutResult:
        """Checkout to a specific variation via ``POST /checkout``.

        Returns a CheckoutResult with success/blocked status and drift details.
        Non-force checkouts may return 409 if the working tree has drift; that
        case yields ``blocked=True`` rather than raising.

        Raises:
            MuseError: On HTTP 404 (unknown variation) or any other non-200,
                non-409 status.
        """
        trace.new_span("muse_checkout")
        try:
            payload = {
                "project_id": self._config.muse_project_id,
                "target_variation_id": target_variation_id,
                "conversation_id": conversation_id,
                "force": force,
            }

            # NOTE(review): checkout reuses the MUSE_COMMIT event type, and the
            # "operation" tag distinguishes it. Confirm there is no dedicated
            # checkout event type.
            await self._events.emit(
                run_id=run_id,
                scenario="muse_checkout",
                component=Component.MUSE,
                event_type=EventType.MUSE_COMMIT,
                trace=trace,
                tags={"operation": "checkout"},
                data={"target": target_variation_id, "force": force},
            )

            async with self._metrics.timer("muse_checkout", run_id, tags={"force": str(force)}):
                resp = await self._client.post("/checkout", json=payload)

            if resp.status_code == 409:
                body = self._json_object(resp)
                # Drift info may be nested under "detail" or sit at the top level.
                detail = body.get("detail", body)
                if not isinstance(detail, dict):
                    detail = {}
                return CheckoutResult(
                    success=False,
                    blocked=True,
                    target=target_variation_id,
                    drift_severity=detail.get("severity", "unknown"),
                    drift_total_changes=detail.get("total_changes", 0),
                    status_code=409,
                )

            if resp.status_code == 404:
                raise MuseError(f"Variation {target_variation_id} not found (404)")

            if resp.status_code != 200:
                raise self._api_error("checkout", resp)

            body = resp.json()
            return CheckoutResult(
                success=True,
                blocked=False,
                target=target_variation_id,
                head_moved=body.get("head_moved", False),
                executed=body.get("executed", 0),
                failed=body.get("failed", 0),
                plan_hash=body.get("plan_hash", ""),
                status_code=200,
            )
        finally:
            trace.end_span()

    async def save_conflict_branch(
        self,
        run_id: str,
        trace: TraceContext,
        *,
        variation_id: str,
        parent_variation_id: str,
        intent: str,
        target_region: str,
        target_track: str,
        notes: list[dict[str, Any]],
        conversation_id: str = "default",
    ) -> str:
        """Save a variation with explicit note changes — used to create deliberate conflicts.

        Both conflict branches must add notes at the same (pitch, start_beat) position
        but with different content to trigger MUSE's merge conflict detection.

        Each entry in *notes* must provide ``pitch`` and ``start_beat`` keys.
        Those keys are used to build a deterministic ``note_id``. All notes go
        into one phrase spanning beats 0.0–8.0 of *target_region*.

        Returns:
            The variation id returned by :meth:`save_variation`.
        """
        prefix = variation_id[:8]
        note_changes = [
            {
                "note_id": f"nc-{prefix}-{target_region}-p{n['pitch']}b{n['start_beat']}",
                "change_type": "added",
                "before": None,
                "after": n,
            }
            for n in notes
        ]

        phrases = [{
            "phrase_id": f"ph-{prefix}-{target_region}",
            "track_id": target_track,
            "region_id": target_region,
            "start_beat": 0.0,
            "end_beat": 8.0,
            "label": f"{intent} ({target_region})",
            "note_changes": note_changes,
            "cc_events": [],
            "pitch_bends": [],
            "aftertouch": [],
        }]

        return await self.save_variation(
            run_id=run_id,
            trace=trace,
            variation_id=variation_id,
            intent=intent,
            phrases=phrases,
            affected_tracks=[target_track],
            affected_regions=[target_region],
            parent_variation_id=parent_variation_id,
            conversation_id=conversation_id,
        )
304
305
class CheckoutResult:
    """Structured checkout result.

    A blocked checkout (``blocked=True``, typically HTTP 409) carries drift
    details. A successful one carries the execution counters and plan hash
    reported by MUSE.
    """

    def __init__(
        self,
        success: bool,
        blocked: bool = False,
        target: str = "",
        head_moved: bool = False,
        executed: int = 0,
        failed: int = 0,
        plan_hash: str = "",
        drift_severity: str = "",
        drift_total_changes: int = 0,
        status_code: int = 200,
    ) -> None:
        self.success = success
        self.blocked = blocked
        self.target = target
        self.head_moved = head_moved
        self.executed = executed
        self.failed = failed
        self.plan_hash = plan_hash
        self.drift_severity = drift_severity
        self.drift_total_changes = drift_total_changes
        self.status_code = status_code

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-friendly summary of the checkout.

        ``plan_hash`` is abbreviated to its first 12 characters, and
        ``status_code`` is not included. ``failed`` is included alongside
        ``executed`` so partially failed checkouts are visible in reports.
        """
        return {
            "success": self.success,
            "blocked": self.blocked,
            "target": self.target,
            "head_moved": self.head_moved,
            "executed": self.executed,
            "failed": self.failed,
            "plan_hash": self.plan_hash[:12] if self.plan_hash else "",
            "drift_severity": self.drift_severity,
            "drift_total_changes": self.drift_total_changes,
        }
344
345
class MergeResult:
    """Outcome of a MUSE merge request.

    When MUSE rejects the merge, ``success`` is False and ``conflicts`` holds
    whatever conflict records were reported. On success,
    ``merge_variation_id`` names the resulting merge variation.
    """

    def __init__(
        self,
        success: bool,
        merge_variation_id: str = "",
        conflicts: list[dict] | None = None,
        executed: int = 0,
        status_code: int = 200,
    ) -> None:
        self.success = success
        self.merge_variation_id = merge_variation_id
        # A missing or empty conflict list is normalised to a fresh empty list.
        self.conflicts = conflicts if conflicts else []
        self.executed = executed
        self.status_code = status_code

    def to_dict(self) -> dict[str, Any]:
        """Serialise to a plain dict; ``status_code`` is deliberately omitted."""
        summary: dict[str, Any] = dict(
            success=self.success,
            merge_variation_id=self.merge_variation_id,
            conflict_count=len(self.conflicts),
        )
        summary["conflicts"] = self.conflicts
        summary["executed"] = self.executed
        return summary
371
372
class MuseError(Exception):
    """Raised when a MUSE HTTP API call returns a failing or unexpected status."""