cgcardona / muse public
muse.py python
503 lines 18.8 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """Muse VCS routes — commit graph, checkout, merge, HEAD management.
2
3 Production endpoints that expose Muse's version-control primitives to
4 the Stori DAW. These are the HTTP surface for the history engine built
5 in Phases 5–13.
6
7 Endpoint summary:
8 POST /muse/variations — persist a variation directly
9 POST /muse/head — set HEAD pointer
10 GET /muse/log — commit DAG (MuseLogGraph)
11 POST /muse/checkout — checkout to a variation (time travel)
12 POST /muse/merge — three-way merge of two variations
13 """
14
15 from __future__ import annotations
16
17 import logging
18
19 from fastapi import APIRouter, Depends, HTTPException
20 from pydantic import BaseModel, Field
21 from sqlalchemy.ext.asyncio import AsyncSession
22
23 from maestro.contracts.json_types import (
24 AftertouchDict,
25 CCEventDict,
26 JSONValue,
27 PitchBendDict,
28 RegionMetadataWire,
29 jfloat,
30 jint,
31 )
32 from maestro.contracts.pydantic_types import PydanticJson, wrap_dict, unwrap_dict
33 from maestro.auth.dependencies import require_valid_token
34 from maestro.core.state_store import get_or_create_store
35 from maestro.core.tracing import create_trace_context
36 from maestro.db import get_db
37 from maestro.models.variation import (
38 ChangeType,
39 MidiNoteSnapshot,
40 NoteChange as DomainNoteChange,
41 Phrase as DomainPhrase,
42 Variation as DomainVariation,
43 )
44 from maestro.services import muse_repository
45 from maestro.services.muse_history_controller import (
46 CheckoutBlockedError,
47 MergeConflictError,
48 checkout_to_variation,
49 merge_variations,
50 )
51 from maestro.services.muse_log_graph import MuseLogGraphResponse, build_muse_log_graph
52
# Module logger; save_variation uses it to record each successfully saved variation.
logger = logging.getLogger(__name__)

# Every route in this module is registered on this router, so all paths are
# served under the "/muse" prefix and grouped under the "muse" OpenAPI tag.
router = APIRouter(prefix="/muse", tags=["muse"])
56
57
58 def _parse_change_type(raw: str) -> ChangeType:
59 """Narrow a wire-format string to the ChangeType literal."""
60 if raw == "added":
61 return "added"
62 if raw == "removed":
63 return "removed"
64 return "modified"
65
66
67 # ── Request models ────────────────────────────────────────────────────────
68
69
class SaveVariationRequest(BaseModel):
    """Body of ``POST /muse/variations`` — a complete variation to persist.

    Attributes:
        project_id: Project the variation is saved under.
        variation_id: ID of the variation; echoed back in
            ``SaveVariationResponse``.
        intent: Free-text intent recorded on the variation.
        conversation_id: Conversation the variation is recorded under.
        parent_variation_id: First parent in the commit graph, or ``None``.
        parent2_variation_id: Second parent (merge commits have two
            parents), or ``None``.
        phrases: Raw phrase payloads; see ``save_variation`` for the keys it
            reads from each phrase.
        affected_tracks: Track IDs touched by the variation.
        affected_regions: Region IDs touched by the variation.
        beat_range: ``(start, end)`` beat range covered by the variation.
    """

    project_id: str = Field(
        description="ID of the project the variation is saved under."
    )
    variation_id: str = Field(
        description="ID of the variation to persist; echoed back in the response."
    )
    intent: str = Field(
        description="Free-text intent recorded on the variation."
    )
    conversation_id: str = Field(
        default="default",
        description="Conversation the variation is recorded under.",
    )
    parent_variation_id: str | None = Field(
        default=None,
        description="First parent variation in the commit graph, or null.",
    )
    parent2_variation_id: str | None = Field(
        default=None,
        description="Second parent variation (merge commits have two parents), or null.",
    )
    phrases: list[dict[str, PydanticJson]] = Field(
        default_factory=list,
        description="Raw phrase payloads (note changes, CC events, pitch bends, aftertouch, tags).",
    )
    affected_tracks: list[str] = Field(
        default_factory=list,
        description="Track IDs touched by the variation.",
    )
    affected_regions: list[str] = Field(
        default_factory=list,
        description="Region IDs touched by the variation.",
    )
    beat_range: tuple[float, float] = Field(
        default=(0.0, 8.0),
        description="(start, end) beat range covered by the variation.",
    )
81
82
class SetHeadRequest(BaseModel):
    """Body of ``POST /muse/head``.

    Attributes:
        variation_id: ID of the variation that should become HEAD.
    """

    variation_id: str = Field(
        description="ID of the variation that should become HEAD."
    )
85
86
class CheckoutRequest(BaseModel):
    """Body of ``POST /muse/checkout``.

    Attributes:
        project_id: Project whose HEAD is moved.
        target_variation_id: Variation to check out.
        conversation_id: Together with ``project_id``, selects the
            ``StateStore`` the checkout plan is executed against.
        force: When ``True``, check out even if the working tree has
            uncommitted drift; otherwise such a checkout is rejected with 409.
    """

    project_id: str = Field(
        description="ID of the project whose HEAD is moved."
    )
    target_variation_id: str = Field(
        description="ID of the variation to check out."
    )
    conversation_id: str = Field(
        default="default",
        description="Conversation whose StateStore the checkout plan is executed against.",
    )
    force: bool = Field(
        default=False,
        description="Check out even if the working tree has uncommitted drift.",
    )
92
93
class MergeRequest(BaseModel):
    """Body of ``POST /muse/merge``.

    Attributes:
        project_id: Project in which the merge is performed.
        left_id: Left (first) variation to merge.
        right_id: Right (second) variation to merge.
        conversation_id: Together with ``project_id``, selects the
            ``StateStore`` the merged state is checked out into.
        force: Passed through to ``merge_variations``. NOTE(review):
            presumably overrides the dirty-working-tree guard as it does
            for checkout — confirm in the controller.
    """

    project_id: str = Field(
        description="ID of the project in which the merge is performed."
    )
    left_id: str = Field(
        description="ID of the left (first) variation to merge."
    )
    right_id: str = Field(
        description="ID of the right (second) variation to merge."
    )
    conversation_id: str = Field(
        default="default",
        description="Conversation whose StateStore the merged state is checked out into.",
    )
    force: bool = Field(
        default=False,
        description="Force flag passed through to the merge controller.",
    )
100
101
102 # ── Response models ───────────────────────────────────────────────────────
103
104
class SaveVariationResponse(BaseModel):
    """Confirmation that a variation was persisted to Muse history.

    Returned by ``POST /muse/variations`` after the variation record has been
    written to the database and the transaction committed.

    Attributes:
        variation_id: UUID of the variation that was saved. Echoes back the
            ID supplied in the request so the caller can correlate the response
            without re-reading the request body.
    """

    # The route copies ``req.variation_id`` here unchanged; it neither
    # generates nor reformats the ID.
    variation_id: str = Field(
        description="UUID of the variation that was saved."
    )
120
121
class SetHeadResponse(BaseModel):
    """Confirmation that the HEAD pointer was moved.

    Returned by ``POST /muse/head`` after the HEAD record has been updated and
    the transaction committed.

    Attributes:
        head: UUID of the variation that is now HEAD. Echoes back the ID
            supplied in the request.
    """

    # The route copies ``SetHeadRequest.variation_id`` here unchanged.
    head: str = Field(
        description="UUID of the variation that is now HEAD."
    )
136
137
class CheckoutExecutionStats(BaseModel):
    """Execution statistics for a single plan-execution pass.

    Shared by both ``CheckoutResponse`` and ``MergeResponse`` because both
    operations run a checkout plan against the ``StateStore`` at the end.
    The routes fill every field from the ``execution`` object of the
    controller's summary.

    Attributes:
        executed: Number of tool-call steps that were executed successfully
            during this checkout pass.
        failed: Number of tool-call steps that failed during this checkout
            pass. A non-zero value indicates a partial checkout — the DAW
            state may be inconsistent.
        plan_hash: SHA-256 content hash of the serialised checkout plan (hex
            string). Identical hashes guarantee identical execution plans;
            useful for deduplication and idempotency checks.
        events: Ordered list of SSE event payloads that were emitted during
            execution. Each element is a JSON object
            (``dict[str, PydanticJson]``; the routes wrap each raw event with
            ``wrap_dict``) matching the wire format of the corresponding
            ``MaestroEvent`` subclass. Included so callers can replay or
            inspect the execution trace without re-running the checkout.
    """

    executed: int = Field(
        description="Number of tool-call steps executed successfully during this checkout pass."
    )
    failed: int = Field(
        description=(
            "Number of tool-call steps that failed. "
            "Non-zero indicates a partial checkout — DAW state may be inconsistent."
        )
    )
    plan_hash: str = Field(
        description=(
            "SHA-256 content hash of the serialised checkout plan (hex string). "
            "Identical hashes guarantee identical execution plans."
        )
    )
    events: list[dict[str, PydanticJson]] = Field(
        description=(
            "Ordered list of SSE event payloads emitted during execution. "
            "Each element is a raw dict matching the wire format of a MaestroEvent subclass."
        )
    )
181
182
class CheckoutResponse(BaseModel):
    """Full summary of a checkout operation — the musical equivalent of ``git checkout``.

    Returned by ``POST /muse/checkout`` after the target variation has been
    reconstructed, its checkout plan executed against ``StateStore``, and HEAD
    moved. Returns 409 instead if the working tree is dirty and ``force`` is
    not set.

    Attributes:
        project_id: UUID of the project on which the checkout was performed.
        from_variation_id: UUID of the variation that was HEAD before checkout,
            or ``None`` if the project had no HEAD (first checkout).
        to_variation_id: UUID of the variation that is now HEAD after checkout.
        execution: Plan-execution statistics and event trace for this checkout
            pass (see ``CheckoutExecutionStats``).
        head_moved: Copied from the controller summary's ``head_moved`` flag;
            ``True`` when HEAD was updated to ``to_variation_id``.
            NOTE(review): the conditions under which the controller reports
            ``False`` are not visible in this module — confirm before relying
            on them.
    """

    project_id: str = Field(
        description="UUID of the project on which the checkout was performed."
    )
    # Required but nullable: there is no default, so the value must always be
    # supplied, and ``None`` is a legal value.
    from_variation_id: str | None = Field(
        description=(
            "UUID of the variation that was HEAD before checkout, "
            "or None if the project had no HEAD (first checkout)."
        )
    )
    to_variation_id: str = Field(
        description="UUID of the variation that is now HEAD after checkout."
    )
    execution: CheckoutExecutionStats = Field(
        description="Plan-execution statistics and event trace for this checkout pass."
    )
    head_moved: bool = Field(
        description="True if the HEAD pointer was successfully updated to to_variation_id."
    )
222
223
class MergeResponse(BaseModel):
    """Full summary of a three-way merge — the musical equivalent of ``git merge``.

    Returned by ``POST /muse/merge`` after the merge base is computed, the
    three-way diff is applied, the merged state is checked out via plan
    execution, and a merge commit with two parents is created. Returns 409
    instead if the merge has unresolvable conflicts.

    All fields are copied from the summary returned by ``merge_variations``.

    Attributes:
        project_id: UUID of the project on which the merge was performed.
        merge_variation_id: UUID of the new merge commit (two parents:
            ``left_id`` and ``right_id``).
        left_id: UUID of the left (first) variation passed to the merge.
        right_id: UUID of the right (second) variation passed to the merge.
        execution: Plan-execution statistics and event trace for the checkout
            pass that applied the merged state (see ``CheckoutExecutionStats``).
        head_moved: ``True`` if HEAD was moved to ``merge_variation_id`` after
            the merge commit was created.
    """

    project_id: str = Field(
        description="UUID of the project on which the merge was performed."
    )
    merge_variation_id: str = Field(
        description=(
            "UUID of the new merge commit with two parents: left_id and right_id."
        )
    )
    left_id: str = Field(
        description="UUID of the left (first) variation passed to the merge."
    )
    right_id: str = Field(
        description="UUID of the right (second) variation passed to the merge."
    )
    execution: CheckoutExecutionStats = Field(
        description=(
            "Plan-execution statistics and event trace for the checkout pass "
            "that applied the merged state."
        )
    )
    head_moved: bool = Field(
        description="True if HEAD was moved to merge_variation_id after the merge commit was created."
    )
267
268
269 # ── POST /muse/variations ────────────────────────────────────────────────
270
271
def _json_list(value: JSONValue) -> list[JSONValue]:
    """Return *value* if it is a JSON array, otherwise an empty list."""
    return value if isinstance(value, list) else []


def _as_float(value: JSONValue, default: float) -> float:
    """Return a JSON number as ``float``; any non-numeric value yields *default*."""
    return float(value) if isinstance(value, (int, float)) else default


def _parse_note_snapshot(
    raw: JSONValue, *, side: str, note_id: str
) -> MidiNoteSnapshot | None:
    """Validate one optional ``before``/``after`` snapshot of a note change.

    Non-object values (including a missing key) mean "no snapshot" and yield
    ``None``.

    Raises:
        HTTPException: 422 when *raw* is an object that ``MidiNoteSnapshot``
            rejects. Without this, the validation error would escape the route
            as an unhandled exception (HTTP 500).
    """
    if not isinstance(raw, dict):
        return None
    try:
        return MidiNoteSnapshot.model_validate(raw)
    except ValueError as exc:  # pydantic's ValidationError subclasses ValueError
        raise HTTPException(status_code=422, detail={
            "error": "invalid_note_snapshot",
            "note_id": note_id,
            "field": side,
            "message": str(exc),
        }) from exc


def _parse_note_changes(raw: JSONValue) -> list[DomainNoteChange]:
    """Build domain note changes from a phrase's ``note_changes`` array.

    Non-object entries are skipped; change types are narrowed with
    ``_parse_change_type``.
    """
    changes: list[DomainNoteChange] = []
    for nc in _json_list(raw):
        if not isinstance(nc, dict):
            continue
        note_id = str(nc.get("note_id", ""))
        changes.append(DomainNoteChange(
            note_id=note_id,
            change_type=_parse_change_type(str(nc.get("change_type", ""))),
            before=_parse_note_snapshot(nc.get("before"), side="before", note_id=note_id),
            after=_parse_note_snapshot(nc.get("after"), side="after", note_id=note_id),
        ))
    return changes


def _parse_cc_events(raw: JSONValue) -> list[CCEventDict]:
    """Build CC events from a phrase's ``cc_events`` array (non-objects skipped)."""
    return [
        CCEventDict(cc=jint(e.get("cc", 0)), beat=jfloat(e.get("beat", 0.0)), value=jint(e.get("value", 0)))
        for e in _json_list(raw)
        if isinstance(e, dict)
    ]


def _parse_pitch_bends(raw: JSONValue) -> list[PitchBendDict]:
    """Build pitch-bend events from a phrase's ``pitch_bends`` array (non-objects skipped)."""
    return [
        PitchBendDict(beat=jfloat(e.get("beat", 0.0)), value=jint(e.get("value", 0)))
        for e in _json_list(raw)
        if isinstance(e, dict)
    ]


def _parse_aftertouch(raw: JSONValue) -> list[AftertouchDict]:
    """Build aftertouch events from a phrase's ``aftertouch`` array (non-objects skipped).

    ``pitch`` is copied only when the source event carries that key.
    """
    events: list[AftertouchDict] = []
    for at_raw in _json_list(raw):
        if not isinstance(at_raw, dict):
            continue
        at_ev: AftertouchDict = {
            "beat": jfloat(at_raw.get("beat", 0.0)),
            "value": jint(at_raw.get("value", 0)),
        }
        # Optional key. NOTE(review): presumably present only for per-note
        # (polyphonic) aftertouch; confirm against the AftertouchDict contract.
        if "pitch" in at_raw:
            at_ev["pitch"] = jint(at_raw["pitch"])
        events.append(at_ev)
    return events


def _parse_phrase(p: dict[str, JSONValue]) -> DomainPhrase:
    """Convert one raw phrase payload into a ``DomainPhrase``.

    Parsing is lenient:
    - Missing ID keys become ``""``, and a missing ``label`` becomes ``"Muse"``.
    - ``start_beat``/``end_beat`` fall back to 0.0/8.0 when missing or
      non-numeric.
    - Non-list collections are treated as empty, and non-object event entries
      and non-string tags are skipped.
    """
    note_changes = _parse_note_changes(p.get("note_changes", []))
    cc_events = _parse_cc_events(p.get("cc_events", []))
    pitch_bends = _parse_pitch_bends(p.get("pitch_bends", []))
    aftertouch = _parse_aftertouch(p.get("aftertouch", []))
    tags = [t for t in _json_list(p.get("tags", [])) if isinstance(t, str)]
    return DomainPhrase(
        phrase_id=str(p.get("phrase_id", "")),
        track_id=str(p.get("track_id", "")),
        region_id=str(p.get("region_id", "")),
        start_beat=_as_float(p.get("start_beat", 0.0), 0.0),
        end_beat=_as_float(p.get("end_beat", 8.0), 8.0),
        label=str(p.get("label", "Muse")),
        note_changes=note_changes,
        cc_events=cc_events,
        pitch_bends=pitch_bends,
        aftertouch=aftertouch,
        tags=tags,
    )


def _build_region_metadata(phrases: list[DomainPhrase]) -> dict[str, RegionMetadataWire]:
    """Derive per-region wire metadata (start, duration, name) from parsed phrases.

    NOTE(review): when several phrases share a ``region_id`` the last one wins,
    so the region's extent reflects only that phrase. Confirm this is intended.
    """
    return {
        dp.region_id: {
            "startBeat": dp.start_beat,
            "durationBeats": dp.end_beat - dp.start_beat,
            "name": dp.region_id,
        }
        for dp in phrases
    }


@router.post("/variations", dependencies=[Depends(require_valid_token)])
async def save_variation(
    req: SaveVariationRequest,
    db: AsyncSession = Depends(get_db),
) -> SaveVariationResponse:
    """Persist a variation directly into Muse history.

    Accepts a complete variation payload (phrases with note changes and
    controller events) and writes it to the variations table with status
    ``"committed"``. The transaction is then committed.

    Phrase payloads are parsed leniently (see ``_parse_phrase``). The one hard
    failure is a ``before``/``after`` note snapshot that is an object but is
    not a valid ``MidiNoteSnapshot``.

    Raises:
        HTTPException: 422 (``invalid_note_snapshot``) for such a snapshot.
    """
    # unwrap_dict yields the plain dict[str, JSONValue] form of each phrase.
    domain_phrases = [_parse_phrase(unwrap_dict(p_raw)) for p_raw in req.phrases]

    variation = DomainVariation(
        variation_id=req.variation_id,
        intent=req.intent,
        ai_explanation=None,
        affected_tracks=req.affected_tracks,
        affected_regions=req.affected_regions,
        beat_range=req.beat_range,
        phrases=domain_phrases,
    )

    await muse_repository.save_variation(
        db,
        variation,
        project_id=req.project_id,
        base_state_id="muse",
        conversation_id=req.conversation_id,
        region_metadata=_build_region_metadata(domain_phrases),
        status="committed",
        parent_variation_id=req.parent_variation_id,
        parent2_variation_id=req.parent2_variation_id,
    )
    await db.commit()

    logger.info("✅ Variation saved via route: %s", req.variation_id[:8])
    return SaveVariationResponse(variation_id=req.variation_id)
373
374
375 # ── POST /muse/head ──────────────────────────────────────────────────────
376
377
@router.post("/head", dependencies=[Depends(require_valid_token)])
async def set_head(
    req: SetHeadRequest,
    db: AsyncSession = Depends(get_db),
) -> SetHeadResponse:
    """Move HEAD to the requested variation, commit, and echo the new HEAD.

    The request carries only ``variation_id``; resolving which project's HEAD
    moves is left to ``muse_repository.set_head``.
    """
    new_head = req.variation_id
    await muse_repository.set_head(db, new_head)
    await db.commit()
    return SetHeadResponse(head=new_head)
387
388
389 # ── GET /muse/log ────────────────────────────────────────────────────────
390
391
@router.get("/log", dependencies=[Depends(require_valid_token)])
async def get_log(
    project_id: str,
    db: AsyncSession = Depends(get_db),
) -> MuseLogGraphResponse:
    """Build the commit DAG for *project_id* and return it as ``MuseLogGraphResponse``.

    ``project_id`` is a plain ``str`` parameter that is not part of the route
    path, so FastAPI reads it from the query string.
    """
    return (await build_muse_log_graph(db, project_id)).to_response()
400
401
402 # ── POST /muse/checkout ──────────────────────────────────────────────────
403
404
@router.post("/checkout", dependencies=[Depends(require_valid_token)])
async def checkout(
    req: CheckoutRequest,
    db: AsyncSession = Depends(get_db),
) -> CheckoutResponse:
    """Checkout to a specific variation — musical ``git checkout``.

    Reconstructs the target state, generates a checkout plan, executes it
    against the conversation's ``StateStore``, and moves HEAD. The
    transaction is then committed.

    Raises:
        HTTPException: 409 (``checkout_blocked``) if the working tree has
            uncommitted drift and ``force`` is not set.
        HTTPException: 404 if ``checkout_to_variation`` raises ``ValueError``
            (presumably an unknown target variation — confirm in the
            controller).
    """
    store = get_or_create_store(req.conversation_id, req.project_id)
    trace = create_trace_context()

    # Only the controller call sits inside the try. Building CheckoutResponse
    # can raise pydantic's ValidationError, which subclasses ValueError and
    # would otherwise be misreported as a 404.
    try:
        summary = await checkout_to_variation(
            session=db,
            project_id=req.project_id,
            target_variation_id=req.target_variation_id,
            store=store,
            trace=trace,
            force=req.force,
        )
    except CheckoutBlockedError as e:
        raise HTTPException(status_code=409, detail={
            "error": "checkout_blocked",
            "severity": e.severity.value,
            "total_changes": e.total_changes,
        }) from e
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e

    await db.commit()
    return CheckoutResponse(
        project_id=summary.project_id,
        from_variation_id=summary.from_variation_id,
        to_variation_id=summary.to_variation_id,
        execution=CheckoutExecutionStats(
            executed=summary.execution.executed,
            failed=summary.execution.failed,
            plan_hash=summary.execution.plan_hash,
            events=[wrap_dict(e) for e in summary.execution.events],
        ),
        head_moved=summary.head_moved,
    )
451
452
453 # ── POST /muse/merge ─────────────────────────────────────────────────────
454
455
@router.post("/merge", dependencies=[Depends(require_valid_token)])
async def merge(
    req: MergeRequest,
    db: AsyncSession = Depends(get_db),
) -> MergeResponse:
    """Three-way merge of two variations — musical ``git merge``.

    Computes the merge base, builds a three-way diff, and if conflict-free,
    applies the merged state via checkout execution. Creates a merge commit
    with two parents, then commits the transaction.

    Raises:
        HTTPException: 409 (``merge_conflict``) if the merge cannot
            auto-resolve. The detail lists each conflict's ``region_id``,
            ``type`` and ``description``.

    NOTE(review): unlike ``checkout``, a ``ValueError`` from
    ``merge_variations`` is not mapped to 404 here and surfaces as a server
    error. Confirm whether that is intended.
    """
    store = get_or_create_store(req.conversation_id, req.project_id)
    trace = create_trace_context()

    # Keep only the controller call inside the try so that failures in the
    # commit or in response construction are never mistaken for merge conflicts.
    try:
        summary = await merge_variations(
            session=db,
            project_id=req.project_id,
            left_id=req.left_id,
            right_id=req.right_id,
            store=store,
            trace=trace,
            force=req.force,
        )
    except MergeConflictError as e:
        raise HTTPException(status_code=409, detail={
            "error": "merge_conflict",
            "conflicts": [
                {"region_id": c.region_id, "type": c.type, "description": c.description}
                for c in e.conflicts
            ],
        }) from e

    await db.commit()
    return MergeResponse(
        project_id=summary.project_id,
        merge_variation_id=summary.merge_variation_id,
        left_id=summary.left_id,
        right_id=summary.right_id,
        execution=CheckoutExecutionStats(
            executed=summary.execution.executed,
            failed=summary.execution.failed,
            plan_hash=summary.execution.plan_hash,
            events=[wrap_dict(e) for e in summary.execution.events],
        ),
        head_moved=summary.head_moved,
    )