muse.py
python
| 1 | """Muse VCS routes — commit graph, checkout, merge, HEAD management. |
| 2 | |
| 3 | Production endpoints that expose Muse's version-control primitives to |
| 4 | the Stori DAW. These are the HTTP surface for the history engine built |
| 5 | in Phases 5–13. |
| 6 | |
| 7 | Endpoint summary: |
| 8 | POST /muse/variations — persist a variation directly |
| 9 | POST /muse/head — set HEAD pointer |
| 10 | GET /muse/log — commit DAG (MuseLogGraph) |
| 11 | POST /muse/checkout — checkout to a variation (time travel) |
| 12 | POST /muse/merge — three-way merge of two variations |
| 13 | """ |
| 14 | |
| 15 | from __future__ import annotations |
| 16 | |
| 17 | import logging |
| 18 | |
| 19 | from fastapi import APIRouter, Depends, HTTPException |
| 20 | from pydantic import BaseModel, Field |
| 21 | from sqlalchemy.ext.asyncio import AsyncSession |
| 22 | |
| 23 | from maestro.contracts.json_types import ( |
| 24 | AftertouchDict, |
| 25 | CCEventDict, |
| 26 | JSONValue, |
| 27 | PitchBendDict, |
| 28 | RegionMetadataWire, |
| 29 | jfloat, |
| 30 | jint, |
| 31 | ) |
| 32 | from maestro.contracts.pydantic_types import PydanticJson, wrap_dict, unwrap_dict |
| 33 | from maestro.auth.dependencies import require_valid_token |
| 34 | from maestro.core.state_store import get_or_create_store |
| 35 | from maestro.core.tracing import create_trace_context |
| 36 | from maestro.db import get_db |
| 37 | from maestro.models.variation import ( |
| 38 | ChangeType, |
| 39 | MidiNoteSnapshot, |
| 40 | NoteChange as DomainNoteChange, |
| 41 | Phrase as DomainPhrase, |
| 42 | Variation as DomainVariation, |
| 43 | ) |
| 44 | from maestro.services import muse_repository |
| 45 | from maestro.services.muse_history_controller import ( |
| 46 | CheckoutBlockedError, |
| 47 | MergeConflictError, |
| 48 | checkout_to_variation, |
| 49 | merge_variations, |
| 50 | ) |
| 51 | from maestro.services.muse_log_graph import MuseLogGraphResponse, build_muse_log_graph |
| 52 | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router for every endpoint declared in this file: paths are prefixed with
# /muse and the routes are tagged "muse".
router = APIRouter(prefix="/muse", tags=["muse"])
| 56 | |
| 57 | |
| 58 | def _parse_change_type(raw: str) -> ChangeType: |
| 59 | """Narrow a wire-format string to the ChangeType literal.""" |
| 60 | if raw == "added": |
| 61 | return "added" |
| 62 | if raw == "removed": |
| 63 | return "removed" |
| 64 | return "modified" |
| 65 | |
| 66 | |
| 67 | # ── Request models ──────────────────────────────────────────────────────── |
| 68 | |
| 69 | |
class SaveVariationRequest(BaseModel):
    """Request body for ``POST /muse/variations``.

    Carries a complete variation payload. ``phrases`` is a list of loosely
    typed JSON objects; the route handler narrows each one into domain
    objects before the variation is persisted.
    """

    # Identity of the variation and the project it belongs to.
    project_id: str
    variation_id: str
    intent: str
    conversation_id: str = Field(default="default")

    # Lineage: both parents are optional and forwarded to the repository
    # unchanged (presumably the second parent is set for merge commits —
    # confirm against muse_repository).
    parent_variation_id: str | None = Field(default=None)
    parent2_variation_id: str | None = Field(default=None)

    # Payload.
    phrases: list[dict[str, PydanticJson]] = Field(default_factory=list)
    affected_tracks: list[str] = Field(default_factory=list)
    affected_regions: list[str] = Field(default_factory=list)
    beat_range: tuple[float, float] = Field(default=(0.0, 8.0))
| 81 | |
| 82 | |
class SetHeadRequest(BaseModel):
    """Request body for ``POST /muse/head``."""

    # Required: the variation that HEAD should point at.
    variation_id: str = Field(...)
| 85 | |
| 86 | |
class CheckoutRequest(BaseModel):
    """Request body for ``POST /muse/checkout``.

    ``conversation_id`` and ``project_id`` select the state store the
    checkout runs against; ``force`` is forwarded to the history controller.
    """

    project_id: str
    target_variation_id: str
    conversation_id: str = Field(default="default")
    # When left False, the endpoint answers 409 if the checkout is blocked.
    force: bool = Field(default=False)
| 92 | |
| 93 | |
class MergeRequest(BaseModel):
    """Request body for ``POST /muse/merge``.

    ``left_id`` and ``right_id`` name the two variations to merge;
    ``conversation_id`` and ``project_id`` select the state store the merged
    state is applied to. ``force`` is forwarded to the history controller.
    """

    project_id: str
    left_id: str
    right_id: str
    conversation_id: str = Field(default="default")
    force: bool = Field(default=False)
| 100 | |
| 101 | |
| 102 | # ── Response models ─────────────────────────────────────────────────────── |
| 103 | |
| 104 | |
class SaveVariationResponse(BaseModel):
    """Acknowledgement that a variation was written to Muse history.

    Sent by ``POST /muse/variations`` once the variation record has been
    saved and the transaction committed.

    Attributes:
        variation_id: UUID of the saved variation. It mirrors the ID from the
            request, letting callers correlate the response without keeping
            the request body around.
    """

    variation_id: str = Field(
        ...,
        description="UUID of the variation that was saved.",
    )
| 120 | |
| 121 | |
class SetHeadResponse(BaseModel):
    """Acknowledgement that the HEAD pointer was moved.

    Sent by ``POST /muse/head`` once the HEAD record has been updated and the
    transaction committed.

    Attributes:
        head: UUID of the variation that is now HEAD; mirrors the ID from the
            request.
    """

    head: str = Field(
        ...,
        description="UUID of the variation that is now HEAD.",
    )
| 136 | |
| 137 | |
class CheckoutExecutionStats(BaseModel):
    """Outcome counters and event trace for one plan-execution pass.

    Embedded in both ``CheckoutResponse`` and ``MergeResponse``: each of those
    operations finishes by running a checkout plan against the ``StateStore``.

    Attributes:
        executed: How many tool-call steps ran successfully in this pass.
        failed: How many tool-call steps failed in this pass. Anything above
            zero means the checkout was only partially applied, so the DAW
            state may be inconsistent.
        plan_hash: SHA-256 content hash (hex string) of the serialised
            checkout plan. Equal hashes imply equal execution plans, which
            makes the value usable for deduplication and idempotency checks.
        events: SSE event payloads emitted during execution, in order. Each
            entry is a raw ``dict[str, PydanticJson]`` in the wire format of
            the matching ``MaestroEvent`` subclass, so callers can replay or
            inspect the trace without running the checkout again.
    """

    executed: int = Field(
        ...,
        description="Number of tool-call steps executed successfully during this checkout pass.",
    )
    failed: int = Field(
        ...,
        description=(
            "Number of tool-call steps that failed. Non-zero indicates "
            "a partial checkout — DAW state may be inconsistent."
        ),
    )
    plan_hash: str = Field(
        ...,
        description=(
            "SHA-256 content hash of the serialised checkout plan "
            "(hex string). Identical hashes guarantee identical execution plans."
        ),
    )
    events: list[dict[str, PydanticJson]] = Field(
        ...,
        description=(
            "Ordered list of SSE event payloads emitted during execution. Each element "
            "is a raw dict matching the wire format of a MaestroEvent subclass."
        ),
    )
| 181 | |
| 182 | |
class CheckoutResponse(BaseModel):
    """Summary of a completed checkout — the musical analogue of ``git checkout``.

    Sent by ``POST /muse/checkout`` once the target variation has been
    reconstructed, its checkout plan run against ``StateStore``, and HEAD
    moved. When the working tree is dirty and ``force`` is not set, the
    endpoint answers 409 instead of returning this model.

    Attributes:
        project_id: UUID of the project the checkout ran on.
        from_variation_id: UUID of the variation that was HEAD beforehand, or
            ``None`` when the project had no HEAD yet (first checkout).
        to_variation_id: UUID of the variation that is HEAD afterwards.
        execution: Counters and event trace for the plan-execution pass (see
            ``CheckoutExecutionStats``).
        head_moved: ``True`` when HEAD was updated to ``to_variation_id``.
            ``False`` would signal an unexpected no-op (for example, already
            at the target); in practice the endpoint raises on failure
            instead of returning ``False``.
    """

    project_id: str = Field(
        ...,
        description="UUID of the project on which the checkout was performed.",
    )
    from_variation_id: str | None = Field(
        ...,
        description=(
            "UUID of the variation that was HEAD before checkout, or None "
            "if the project had no HEAD (first checkout)."
        ),
    )
    to_variation_id: str = Field(
        ...,
        description="UUID of the variation that is now HEAD after checkout.",
    )
    execution: CheckoutExecutionStats = Field(
        ...,
        description="Plan-execution statistics and event trace for this checkout pass.",
    )
    head_moved: bool = Field(
        ...,
        description="True if the HEAD pointer was successfully updated to to_variation_id.",
    )
| 222 | |
| 223 | |
class MergeResponse(BaseModel):
    """Summary of a completed three-way merge — the musical analogue of ``git merge``.

    Sent by ``POST /muse/merge`` once the merge base has been computed, the
    three-way diff applied, the merged state checked out through plan
    execution, and a two-parent merge commit created. When the merge has
    unresolvable conflicts the endpoint answers 409 instead.

    Attributes:
        project_id: UUID of the project the merge ran on.
        merge_variation_id: UUID of the new merge commit, whose parents are
            ``left_id`` and ``right_id``.
        left_id: UUID of the left (first) variation given to the merge.
        right_id: UUID of the right (second) variation given to the merge.
        execution: Counters and event trace for the checkout pass that
            applied the merged state (see ``CheckoutExecutionStats``).
        head_moved: ``True`` when HEAD was moved to ``merge_variation_id``
            after the merge commit was created.
    """

    project_id: str = Field(
        ...,
        description="UUID of the project on which the merge was performed.",
    )
    merge_variation_id: str = Field(
        ...,
        description="UUID of the new merge commit with two parents: left_id and right_id.",
    )
    left_id: str = Field(
        ...,
        description="UUID of the left (first) variation passed to the merge.",
    )
    right_id: str = Field(
        ...,
        description="UUID of the right (second) variation passed to the merge.",
    )
    execution: CheckoutExecutionStats = Field(
        ...,
        description=(
            "Plan-execution statistics and event trace for the "
            "checkout pass that applied the merged state."
        ),
    )
    head_moved: bool = Field(
        ...,
        description="True if HEAD was moved to merge_variation_id after the merge commit was created.",
    )
| 267 | |
| 268 | |
| 269 | # ── POST /muse/variations ──────────────────────────────────────────────── |
| 270 | |
| 271 | |
def _json_list(value: JSONValue) -> list[JSONValue]:
    """Return ``value`` when it is a JSON array, otherwise an empty list."""
    return value if isinstance(value, list) else []


def _beat_or(value: JSONValue, default: float) -> float:
    """Coerce a numeric JSON value to ``float``; anything else yields ``default``."""
    return float(value) if isinstance(value, (int, float)) else default


def _snapshot_or_none(value: JSONValue) -> MidiNoteSnapshot | None:
    """Validate a JSON object as a ``MidiNoteSnapshot``; non-objects become ``None``."""
    return MidiNoteSnapshot.model_validate(value) if isinstance(value, dict) else None


def _parse_note_changes(raw: JSONValue) -> list[DomainNoteChange]:
    """Build domain note changes from a JSON array, skipping non-object entries."""
    changes: list[DomainNoteChange] = []
    for item in _json_list(raw):
        if not isinstance(item, dict):
            continue
        before_raw = item.get("before")
        after_raw = item.get("after")
        changes.append(
            DomainNoteChange(
                note_id=str(item.get("note_id", "")),
                change_type=_parse_change_type(str(item.get("change_type", ""))),
                before=_snapshot_or_none(before_raw),
                after=_snapshot_or_none(after_raw),
            )
        )
    return changes


def _parse_cc_events(raw: JSONValue) -> list[CCEventDict]:
    """Build CC events from a JSON array, skipping non-object entries."""
    return [
        CCEventDict(
            cc=jint(ev.get("cc", 0)),
            beat=jfloat(ev.get("beat", 0.0)),
            value=jint(ev.get("value", 0)),
        )
        for ev in _json_list(raw)
        if isinstance(ev, dict)
    ]


def _parse_pitch_bends(raw: JSONValue) -> list[PitchBendDict]:
    """Build pitch-bend events from a JSON array, skipping non-object entries."""
    return [
        PitchBendDict(beat=jfloat(ev.get("beat", 0.0)), value=jint(ev.get("value", 0)))
        for ev in _json_list(raw)
        if isinstance(ev, dict)
    ]


def _parse_aftertouch(raw: JSONValue) -> list[AftertouchDict]:
    """Build aftertouch events from a JSON array, skipping non-object entries."""
    events: list[AftertouchDict] = []
    for item in _json_list(raw):
        if not isinstance(item, dict):
            continue
        event: AftertouchDict = {
            "beat": jfloat(item.get("beat", 0.0)),
            "value": jint(item.get("value", 0)),
        }
        # "pitch" is only copied when the wire payload carries the key.
        if "pitch" in item:
            event["pitch"] = jint(item["pitch"])
        events.append(event)
    return events


def _parse_phrase(phrase: dict[str, JSONValue]) -> DomainPhrase:
    """Narrow one wire-format phrase object into a ``DomainPhrase``.

    Missing or wrongly typed collections become empty lists; missing or
    non-numeric beats fall back to 0.0 (start) and 8.0 (end).
    """
    note_changes = _parse_note_changes(phrase.get("note_changes"))
    cc_events = _parse_cc_events(phrase.get("cc_events"))
    pitch_bends = _parse_pitch_bends(phrase.get("pitch_bends"))
    aftertouch = _parse_aftertouch(phrase.get("aftertouch"))
    tags = [tag for tag in _json_list(phrase.get("tags")) if isinstance(tag, str)]
    return DomainPhrase(
        phrase_id=str(phrase.get("phrase_id", "")),
        track_id=str(phrase.get("track_id", "")),
        region_id=str(phrase.get("region_id", "")),
        start_beat=_beat_or(phrase.get("start_beat"), 0.0),
        end_beat=_beat_or(phrase.get("end_beat"), 8.0),
        label=str(phrase.get("label", "Muse")),
        note_changes=note_changes,
        cc_events=cc_events,
        pitch_bends=pitch_bends,
        aftertouch=aftertouch,
        tags=tags,
    )


@router.post("/variations", dependencies=[Depends(require_valid_token)])
async def save_variation(
    req: SaveVariationRequest,
    db: AsyncSession = Depends(get_db),
) -> SaveVariationResponse:
    """Persist a variation directly into Muse history.

    Accepts a complete variation payload (phrases, note changes, controller
    changes), narrows it into domain objects, saves it through
    ``muse_repository.save_variation`` with status ``"committed"``, and
    commits the transaction.

    Args:
        req: The variation payload.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        ``SaveVariationResponse`` echoing ``req.variation_id``.

    Raises:
        HTTPException: 422 when a phrase fails pydantic validation while being
            narrowed (for example a malformed ``before``/``after`` snapshot).
            Without this translation the error escapes the handler and is
            reported as a 500 although the client sent the bad data.
    """
    # Local import: only this handler needs the exception type.
    from pydantic import ValidationError

    try:
        domain_phrases = [_parse_phrase(unwrap_dict(raw)) for raw in req.phrases]
    except ValidationError as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc

    variation = DomainVariation(
        variation_id=req.variation_id,
        intent=req.intent,
        ai_explanation=None,
        affected_tracks=req.affected_tracks,
        affected_regions=req.affected_regions,
        beat_range=req.beat_range,
        phrases=domain_phrases,
    )

    # One entry per region; when several phrases share a region_id the last
    # phrase wins, as dict assignment in phrase order would.
    region_metadata: dict[str, RegionMetadataWire] = {
        dp.region_id: {
            "startBeat": dp.start_beat,
            "durationBeats": dp.end_beat - dp.start_beat,
            "name": dp.region_id,
        }
        for dp in domain_phrases
    }

    await muse_repository.save_variation(
        db,
        variation,
        project_id=req.project_id,
        base_state_id="muse",
        conversation_id=req.conversation_id,
        region_metadata=region_metadata,
        status="committed",
        parent_variation_id=req.parent_variation_id,
        parent2_variation_id=req.parent2_variation_id,
    )
    await db.commit()

    logger.info("✅ Variation saved via route: %s", req.variation_id[:8])
    return SaveVariationResponse(variation_id=req.variation_id)
| 373 | |
| 374 | |
| 375 | # ── POST /muse/head ────────────────────────────────────────────────────── |
| 376 | |
| 377 | |
@router.post("/head", dependencies=[Depends(require_valid_token)])
async def set_head(
    req: SetHeadRequest,
    db: AsyncSession = Depends(get_db),
) -> SetHeadResponse:
    """Move the HEAD pointer to the variation named in the request.

    Delegates to ``muse_repository.set_head``, commits the transaction, and
    echoes the variation ID back.

    NOTE(review): the request carries no project ID, so the project is
    presumably resolved from the variation inside the repository — confirm.
    """
    head_id = req.variation_id
    await muse_repository.set_head(db, head_id)
    await db.commit()
    return SetHeadResponse(head=head_id)
| 387 | |
| 388 | |
| 389 | # ── GET /muse/log ──────────────────────────────────────────────────────── |
| 390 | |
| 391 | |
@router.get("/log", dependencies=[Depends(require_valid_token)])
async def get_log(
    project_id: str,
    db: AsyncSession = Depends(get_db),
) -> MuseLogGraphResponse:
    """Serve a project's commit DAG as a ``MuseLogGraphResponse``.

    Builds the log graph for ``project_id`` and returns its response form.
    """
    return (await build_muse_log_graph(db, project_id)).to_response()
| 400 | |
| 401 | |
| 402 | # ── POST /muse/checkout ────────────────────────────────────────────────── |
| 403 | |
| 404 | |
@router.post("/checkout", dependencies=[Depends(require_valid_token)])
async def checkout(
    req: CheckoutRequest,
    db: AsyncSession = Depends(get_db),
) -> CheckoutResponse:
    """Checkout to a specific variation — musical ``git checkout``.

    Reconstructs the target state, generates a checkout plan, executes
    it against StateStore, and moves HEAD.

    Args:
        req: Project, target variation, conversation, and ``force`` flag.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        ``CheckoutResponse`` summarising the executed plan and HEAD movement.

    Raises:
        HTTPException: 409 when the controller raises ``CheckoutBlockedError``
            (the working tree has uncommitted drift and ``force`` is not
            set); 404 when the controller raises ``ValueError`` (presumably
            an unknown variation — confirm against the controller).
    """
    store = get_or_create_store(req.conversation_id, req.project_id)
    trace = create_trace_context()

    # Only the controller call is guarded. pydantic's ValidationError is a
    # ValueError subclass, so guarding the commit or the response construction
    # as well would misreport a server-side failure as a 404.
    try:
        summary = await checkout_to_variation(
            session=db,
            project_id=req.project_id,
            target_variation_id=req.target_variation_id,
            store=store,
            trace=trace,
            force=req.force,
        )
    except CheckoutBlockedError as exc:
        raise HTTPException(
            status_code=409,
            detail={
                "error": "checkout_blocked",
                "severity": exc.severity.value,
                "total_changes": exc.total_changes,
            },
        ) from exc
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc

    await db.commit()

    stats = summary.execution
    return CheckoutResponse(
        project_id=summary.project_id,
        from_variation_id=summary.from_variation_id,
        to_variation_id=summary.to_variation_id,
        execution=CheckoutExecutionStats(
            executed=stats.executed,
            failed=stats.failed,
            plan_hash=stats.plan_hash,
            events=[wrap_dict(event) for event in stats.events],
        ),
        head_moved=summary.head_moved,
    )
| 451 | |
| 452 | |
| 453 | # ── POST /muse/merge ───────────────────────────────────────────────────── |
| 454 | |
| 455 | |
@router.post("/merge", dependencies=[Depends(require_valid_token)])
async def merge(
    req: MergeRequest,
    db: AsyncSession = Depends(get_db),
) -> MergeResponse:
    """Three-way merge of two variations — musical ``git merge``.

    Computes the merge base, builds a three-way diff, and if
    conflict-free, applies the merged state via checkout execution.
    Creates a merge commit with two parents.

    Args:
        req: Project, the two variations to merge, conversation, and
            ``force`` flag.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        ``MergeResponse`` summarising the merge commit and executed plan.

    Raises:
        HTTPException: 409 with per-region conflict details when the
            controller raises ``MergeConflictError``.
    """
    store = get_or_create_store(req.conversation_id, req.project_id)
    trace = create_trace_context()

    # NOTE(review): unlike the checkout endpoint, neither CheckoutBlockedError
    # nor ValueError is translated here, so either would surface as a 500.
    # Confirm whether merge_variations can raise them.
    try:
        summary = await merge_variations(
            session=db,
            project_id=req.project_id,
            left_id=req.left_id,
            right_id=req.right_id,
            store=store,
            trace=trace,
            force=req.force,
        )
    except MergeConflictError as exc:
        raise HTTPException(
            status_code=409,
            detail={
                "error": "merge_conflict",
                "conflicts": [
                    {
                        "region_id": conflict.region_id,
                        "type": conflict.type,
                        "description": conflict.description,
                    }
                    for conflict in exc.conflicts
                ],
            },
        ) from exc

    await db.commit()

    stats = summary.execution
    return MergeResponse(
        project_id=summary.project_id,
        merge_variation_id=summary.merge_variation_id,
        left_id=summary.left_id,
        right_id=summary.right_id,
        execution=CheckoutExecutionStats(
            executed=stats.executed,
            failed=stats.failed,
            plan_hash=stats.plan_hash,
            events=[wrap_dict(event) for event in stats.events],
        ),
        head_moved=summary.head_moved,
    )