muse.py
python
| 1 | """MuseClient — commit, branch, merge, graph via MUSE HTTP API.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | import json |
| 6 | import logging |
| 7 | import uuid |
| 8 | from pathlib import Path |
| 9 | from typing import Any |
| 10 | |
| 11 | import httpx |
| 12 | |
| 13 | from tourdeforce.config import TDFConfig |
| 14 | from tourdeforce.models import ( |
| 15 | Component, EventType, Severity, TraceContext, sha256_payload, |
| 16 | ) |
| 17 | from tourdeforce.collectors.events import EventCollector |
| 18 | from tourdeforce.collectors.metrics import MetricsCollector |
| 19 | |
# Module-scoped logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)
| 21 | |
| 22 | |
class MuseClient:
    """Talks to the MUSE VCS HTTP API for commits, merges, and graph exports.

    Traced operations (``save_variation``, ``merge``, ``checkout``) open a span
    on the supplied ``TraceContext`` and close it in a ``finally`` block, so the
    span is ended even when event emission, the HTTP call, or response
    decoding raises.
    """

    def __init__(
        self,
        config: TDFConfig,
        event_collector: EventCollector,
        metrics: MetricsCollector,
        payload_dir: Path,
        muse_dir: Path,
    ) -> None:
        """Create the client and its underlying ``httpx.AsyncClient``.

        Args:
            config: Supplies ``muse_base_url``, ``muse_project_id`` and
                ``auth_headers``.
            event_collector: Receives a MUSE event for each commit/merge/checkout.
            metrics: Provides the async ``timer`` wrapped around each HTTP call.
            payload_dir: Stored as ``self._payload_dir``; not used by the
                methods of this class.
            muse_dir: Directory that ``get_log`` writes ``graph.json`` into;
                created (with parents) if it does not exist.
        """
        self._config = config
        self._events = event_collector
        self._metrics = metrics
        # Previously accepted and silently dropped; keep it reachable.
        self._payload_dir = payload_dir
        self._muse_dir = muse_dir
        self._muse_dir.mkdir(parents=True, exist_ok=True)
        self._client = httpx.AsyncClient(
            base_url=config.muse_base_url,
            timeout=30.0,
            headers=config.auth_headers,
        )

    async def close(self) -> None:
        """Close the underlying ``httpx.AsyncClient``."""
        await self._client.aclose()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _failure(operation: str, resp: httpx.Response) -> MuseError:
        """Build the standard MuseError for an unexpected HTTP status."""
        return MuseError(f"MUSE {operation} failed: {resp.status_code} — {resp.text[:500]}")

    @staticmethod
    def _json_body(operation: str, resp: httpx.Response) -> Any:
        """Decode a success response body, turning invalid JSON into MuseError."""
        try:
            return resp.json()
        except ValueError as exc:  # json.JSONDecodeError is a ValueError subclass
            raise MuseError(
                f"MUSE {operation} returned invalid JSON: {resp.status_code} — {resp.text[:500]}"
            ) from exc

    @staticmethod
    def _json_or_empty(resp: httpx.Response) -> dict[str, Any]:
        """Best-effort decode of an error body: the JSON object, or ``{}``.

        Used for 409 responses, where a missing or non-object body should
        degrade to default values rather than crash.
        """
        try:
            body = resp.json()
        except ValueError:
            return {}
        return body if isinstance(body, dict) else {}

    # --------------------------------------------------------------- operations

    async def save_variation(
        self,
        run_id: str,
        trace: TraceContext,
        *,
        variation_id: str | None = None,
        intent: str = "compose",
        phrases: list[dict] | None = None,
        affected_tracks: list[str] | None = None,
        affected_regions: list[str] | None = None,
        parent_variation_id: str | None = None,
        parent2_variation_id: str | None = None,
        conversation_id: str = "default",
    ) -> str:
        """Persist a variation into MUSE history via ``POST /variations``.

        A MUSE_COMMIT event is emitted before the request is sent.

        Returns:
            The variation id used: ``variation_id`` if truthy, otherwise a new UUID4.

        Raises:
            MuseError: If the server responds with anything other than HTTP 200.
        """
        trace.new_span("muse_commit")
        try:
            vid = variation_id or str(uuid.uuid4())

            payload = {
                "project_id": self._config.muse_project_id,
                "variation_id": vid,
                "intent": intent,
                "conversation_id": conversation_id,
                "parent_variation_id": parent_variation_id,
                "parent2_variation_id": parent2_variation_id,
                "phrases": phrases or [],
                "affected_tracks": affected_tracks or [],
                "affected_regions": affected_regions or [],
            }

            await self._events.emit(
                run_id=run_id,
                scenario="muse_commit",
                component=Component.MUSE,
                event_type=EventType.MUSE_COMMIT,
                trace=trace,
                data={"variation_id": vid, "parent": parent_variation_id, "intent": intent},
            )

            async with self._metrics.timer("muse_save_variation", run_id):
                resp = await self._client.post("/variations", json=payload)

            if resp.status_code != 200:
                raise self._failure("save_variation", resp)
            return vid
        finally:
            trace.end_span()

    async def set_head(
        self,
        run_id: str,
        trace: TraceContext,
        variation_id: str,
    ) -> None:
        """Set the HEAD pointer to ``variation_id`` via ``POST /head``.

        ``run_id`` and ``trace`` are currently unused.
        NOTE(review): unlike the other requests, this payload carries no
        ``project_id`` — confirm that the MUSE ``/head`` endpoint does not need one.

        Raises:
            MuseError: If the server responds with anything other than HTTP 200.
        """
        resp = await self._client.post("/head", json={"variation_id": variation_id})
        if resp.status_code != 200:
            raise self._failure("set_head", resp)

    async def merge(
        self,
        run_id: str,
        trace: TraceContext,
        left_id: str,
        right_id: str,
        *,
        force: bool = True,
        conversation_id: str = "default",
    ) -> MergeResult:
        """Three-way merge of two variations via ``POST /merge``.

        Returns:
            ``MergeResult(success=False, conflicts=..., status_code=409)`` when
            the server reports conflicts (HTTP 409). Otherwise a successful
            ``MergeResult`` carrying ``merge_variation_id`` and ``executed``
            from the response body.

        Raises:
            MuseError: For any other non-200 status, or a 200 with invalid JSON.
        """
        trace.new_span("muse_merge")
        try:
            await self._events.emit(
                run_id=run_id,
                scenario="muse_merge",
                component=Component.MUSE,
                event_type=EventType.MUSE_MERGE,
                trace=trace,
                data={"left": left_id, "right": right_id},
            )

            payload = {
                "project_id": self._config.muse_project_id,
                "left_id": left_id,
                "right_id": right_id,
                "conversation_id": conversation_id,
                "force": force,
            }

            async with self._metrics.timer("muse_merge", run_id):
                resp = await self._client.post("/merge", json=payload)

            if resp.status_code == 409:
                # Conflicts are returned as data, not raised. Tolerate a
                # missing or non-object "detail" instead of crashing on .get().
                detail = self._json_or_empty(resp).get("detail")
                conflicts = detail.get("conflicts", []) if isinstance(detail, dict) else []
                return MergeResult(
                    success=False,
                    merge_variation_id="",
                    conflicts=conflicts,
                    status_code=409,
                )

            # Check the status before decoding, so non-JSON error pages become MuseError.
            if resp.status_code != 200:
                raise self._failure("merge", resp)

            body = self._json_body("merge", resp)
            return MergeResult(
                success=True,
                merge_variation_id=body.get("merge_variation_id", ""),
                executed=body.get("executed", 0),
                status_code=200,
            )
        finally:
            trace.end_span()

    async def get_log(
        self,
        run_id: str,
        trace: TraceContext,
    ) -> dict[str, Any]:
        """Fetch the commit DAG via ``GET /log`` and persist it.

        The decoded JSON is also written (indented) to ``<muse_dir>/graph.json``.

        Raises:
            MuseError: On a non-200 status or a response that is not valid JSON.
        """
        resp = await self._client.get("/log", params={"project_id": self._config.muse_project_id})
        if resp.status_code != 200:
            raise self._failure("get_log", resp)
        graph = self._json_body("get_log", resp)

        graph_file = self._muse_dir / "graph.json"
        graph_file.write_text(json.dumps(graph, indent=2), encoding="utf-8")

        return graph

    async def checkout(
        self,
        run_id: str,
        trace: TraceContext,
        target_variation_id: str,
        *,
        force: bool = True,
        conversation_id: str = "default",
    ) -> CheckoutResult:
        """Check out a specific variation via ``POST /checkout``.

        Non-force checkouts may get HTTP 409 if the working tree has drift.
        That case is returned as a blocked ``CheckoutResult`` carrying the
        drift details, not raised.

        Raises:
            MuseError: On 404 (unknown variation), any other non-200 status,
                or a 200 whose body is not valid JSON.
        """
        trace.new_span("muse_checkout")
        try:
            payload = {
                "project_id": self._config.muse_project_id,
                "target_variation_id": target_variation_id,
                "conversation_id": conversation_id,
                "force": force,
            }

            # NOTE(review): reuses EventType.MUSE_COMMIT with an operation tag;
            # no checkout-specific event type is visible here.
            await self._events.emit(
                run_id=run_id,
                scenario="muse_checkout",
                component=Component.MUSE,
                event_type=EventType.MUSE_COMMIT,
                trace=trace,
                tags={"operation": "checkout"},
                data={"target": target_variation_id, "force": force},
            )

            async with self._metrics.timer("muse_checkout", run_id, tags={"force": str(force)}):
                resp = await self._client.post("/checkout", json=payload)

            if resp.status_code == 409:
                body = self._json_or_empty(resp)
                # Drift info may be nested under "detail" or sit at the top level.
                detail = body.get("detail", body)
                if not isinstance(detail, dict):
                    detail = {}
                return CheckoutResult(
                    success=False,
                    blocked=True,
                    target=target_variation_id,
                    drift_severity=detail.get("severity", "unknown"),
                    drift_total_changes=detail.get("total_changes", 0),
                    status_code=409,
                )

            if resp.status_code == 404:
                raise MuseError(f"Variation {target_variation_id} not found (404)")

            if resp.status_code != 200:
                raise self._failure("checkout", resp)

            body = self._json_body("checkout", resp)
            return CheckoutResult(
                success=True,
                blocked=False,
                target=target_variation_id,
                head_moved=body.get("head_moved", False),
                executed=body.get("executed", 0),
                failed=body.get("failed", 0),
                plan_hash=body.get("plan_hash", ""),
                status_code=200,
            )
        finally:
            trace.end_span()

    async def save_conflict_branch(
        self,
        run_id: str,
        trace: TraceContext,
        *,
        variation_id: str,
        parent_variation_id: str,
        intent: str,
        target_region: str,
        target_track: str,
        notes: list[dict[str, Any]],
        conversation_id: str = "default",
    ) -> str:
        """Save a variation with explicit note changes — used to create deliberate conflicts.

        Both conflict branches must add notes at the same (pitch, start_beat)
        position but with different content to trigger MUSE's merge conflict
        detection. Each note dict must therefore contain ``pitch`` and
        ``start_beat``, which are used to build its ``note_id``. All notes go
        into a single phrase spanning beats 0.0–8.0 of ``target_region``.

        Returns:
            The variation id returned by ``save_variation``.
        """
        note_changes = [
            {
                "note_id": f"nc-{variation_id[:8]}-{target_region}-p{n['pitch']}b{n['start_beat']}",
                "change_type": "added",
                "before": None,
                "after": n,
            }
            for n in notes
        ]

        phrases = [{
            "phrase_id": f"ph-{variation_id[:8]}-{target_region}",
            "track_id": target_track,
            "region_id": target_region,
            "start_beat": 0.0,
            "end_beat": 8.0,
            "label": f"{intent} ({target_region})",
            "note_changes": note_changes,
            "cc_events": [],
            "pitch_bends": [],
            "aftertouch": [],
        }]

        return await self.save_variation(
            run_id=run_id,
            trace=trace,
            variation_id=variation_id,
            intent=intent,
            phrases=phrases,
            affected_tracks=[target_track],
            affected_regions=[target_region],
            parent_variation_id=parent_variation_id,
            conversation_id=conversation_id,
        )
| 304 | |
| 305 | |
class CheckoutResult:
    """Structured checkout result.

    As built by ``MuseClient.checkout``, ``blocked=True`` (with
    ``success=False``) means MUSE answered HTTP 409 and the ``drift_*`` fields
    describe the reported drift. On success, ``head_moved``, ``executed``,
    ``failed`` and ``plan_hash`` mirror the server response.
    """

    def __init__(
        self,
        success: bool,
        blocked: bool = False,
        target: str = "",
        head_moved: bool = False,
        executed: int = 0,
        failed: int = 0,
        plan_hash: str = "",
        drift_severity: str = "",
        drift_total_changes: int = 0,
        status_code: int = 200,
    ) -> None:
        self.success = success
        self.blocked = blocked
        self.target = target
        self.head_moved = head_moved
        self.executed = executed
        self.failed = failed
        self.plan_hash = plan_hash
        self.drift_severity = drift_severity
        self.drift_total_changes = drift_total_changes
        self.status_code = status_code

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-friendly summary.

        ``plan_hash`` is shortened to its first 12 characters and
        ``status_code`` is omitted. ``failed`` is included, so a checkout that
        "succeeded" with failed steps is visible in reports.
        """
        return {
            "success": self.success,
            "blocked": self.blocked,
            "target": self.target,
            "head_moved": self.head_moved,
            "executed": self.executed,
            "failed": self.failed,
            "plan_hash": self.plan_hash[:12] if self.plan_hash else "",
            "drift_severity": self.drift_severity,
            "drift_total_changes": self.drift_total_changes,
        }
| 344 | |
| 345 | |
| 346 | class MergeResult: |
| 347 | """Structured merge result.""" |
| 348 | |
| 349 | def __init__( |
| 350 | self, |
| 351 | success: bool, |
| 352 | merge_variation_id: str = "", |
| 353 | conflicts: list[dict] | None = None, |
| 354 | executed: int = 0, |
| 355 | status_code: int = 200, |
| 356 | ) -> None: |
| 357 | self.success = success |
| 358 | self.merge_variation_id = merge_variation_id |
| 359 | self.conflicts = conflicts or [] |
| 360 | self.executed = executed |
| 361 | self.status_code = status_code |
| 362 | |
| 363 | def to_dict(self) -> dict[str, Any]: |
| 364 | return { |
| 365 | "success": self.success, |
| 366 | "merge_variation_id": self.merge_variation_id, |
| 367 | "conflict_count": len(self.conflicts), |
| 368 | "conflicts": self.conflicts, |
| 369 | "executed": self.executed, |
| 370 | } |
| 371 | |
| 372 | |
class MuseError(Exception):
    """Raised by ``MuseClient`` when a MUSE HTTP API call fails."""