muse_checkout_executor.py
python
| 1 | """Muse Checkout Executor — apply a CheckoutPlan to StateStore. |
| 2 | |
| 3 | Dispatches each tool call in the plan through the existing StateStore |
| 4 | mutation methods, producing SSE-compatible events that the DAW |
| 5 | processes identically to normal editing execution. |
| 6 | |
| 7 | Boundary rules: |
| 8 | - Must NOT import LLM handlers or maestro_* modules. |
| 9 | - Must NOT import VariationService. |
| 10 | - Must NOT import muse_replay internals. |
| 11 | - May import tool_names, state_store, tracing, muse_checkout, muse_drift. |
| 12 | """ |
| 13 | |
| 14 | from __future__ import annotations |
| 15 | |
| 16 | import logging |
| 17 | import uuid |
| 18 | from dataclasses import dataclass, field |
| 19 | |
| 20 | from maestro.contracts.json_types import AftertouchDict, CCEventDict, JSONValue, NoteDict, PitchBendDict, is_note_dict, jfloat, jint |
| 21 | from maestro.core.state_store import StateStore, Transaction |
| 22 | from maestro.core.tools import ToolName |
| 23 | from maestro.core.tracing import TraceContext, trace_span |
| 24 | from maestro.services.muse_checkout import CheckoutPlan |
| 25 | |
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
| 27 | |
| 28 | |
@dataclass(frozen=True)
class CheckoutExecutionResult:
    """Frozen outcome of running one ``CheckoutPlan`` against a store.

    By the time this object is built, ``execute_checkout_plan`` has already
    committed (no failures) or rolled back (any failure) its own transaction.
    Callers look at ``success`` / ``is_noop`` to decide what to report, e.g.
    whether to emit the Muse SSE ``checkout`` event.

    Attributes:
        project_id: Project the checkout targeted.
        target_variation_id: UUID of the variation being restored.
        executed: How many tool calls dispatched without raising.
        failed: How many tool calls raised during dispatch.
        plan_hash: Result of ``CheckoutPlan.plan_hash()``, used to correlate
            idempotency logs.
        events: SSE-compatible event dicts gathered during execution, in the
            order they were produced; empty when nothing was collected.
    """

    project_id: str
    target_variation_id: str
    executed: int
    failed: int
    plan_hash: str
    events: tuple[dict[str, JSONValue], ...] = ()

    @property
    def success(self) -> bool:
        """Whether at least one call ran and none of them failed."""
        if self.failed != 0:
            return False
        return self.executed > 0

    @property
    def is_noop(self) -> bool:
        """Whether nothing was executed and nothing failed (empty plan)."""
        return (self.executed, self.failed) == (0, 0)
| 62 | |
| 63 | |
def execute_checkout_plan(
    *,
    checkout_plan: CheckoutPlan,
    store: StateStore,
    trace: TraceContext,
    emit_sse: bool = True,
) -> CheckoutExecutionResult:
    """Execute a CheckoutPlan by dispatching its tool calls to a StateStore.

    Tool calls are dispatched in exactly the order they appear in
    ``checkout_plan.tool_calls``, all inside one store transaction. This
    function does not reorder them, so any required ordering (for example
    clearing notes before adding notes) must already be encoded in the plan.

    A tool call that raises does not stop the loop. It is counted in
    ``failed``, logged with its traceback and, when ``emit_sse`` is set,
    reported as a ``toolError`` event. After the loop, the transaction is
    committed when no call failed and rolled back otherwise. If anything
    escapes (including ``BaseException`` such as cancellation), the
    transaction is rolled back and the exception is re-raised. Either way the
    transaction is finalized at most once: a committed transaction is never
    rolled back, and a rollback that itself raised is not retried.

    Note: ``toolCall`` events are collected for every call that dispatched
    successfully, even if a later failure causes a rollback. Check
    ``result.success`` before forwarding them.

    Args:
        checkout_plan: The plan to execute (pure data).
        store: StateStore providing ``begin_transaction``, ``commit``,
            ``rollback`` and the mutation methods used by the dispatcher.
        trace: TraceContext for logging and spans.
        emit_sse: When True, collect SSE-compatible events in the result.

    Returns:
        A ``CheckoutExecutionResult``. For a no-op plan it has
        ``executed == failed == 0`` and no events, and no transaction is
        opened.

    Raises:
        Exceptions from ``store.begin_transaction`` / ``commit`` /
        ``rollback``, from tracing, or from malformed plan entries (e.g.
        ``KeyError`` when a call lacks ``"tool"`` or ``"arguments"``). These
        propagate after the transaction has been rolled back.
    """
    if checkout_plan.is_noop:
        logger.info("✅ Checkout is no-op — nothing to execute")
        return CheckoutExecutionResult(
            project_id=checkout_plan.project_id,
            target_variation_id=checkout_plan.target_variation_id,
            executed=0,
            failed=0,
            plan_hash=checkout_plan.plan_hash(),
        )

    executed = 0
    failed = 0
    events: list[dict[str, JSONValue]] = []

    txn = store.begin_transaction(
        f"checkout:{checkout_plan.target_variation_id[:8]}",
    )
    # Becomes True once commit succeeded or a rollback was attempted. The
    # ``finally`` clause then never rolls back a committed transaction and
    # never retries a rollback that itself raised.
    txn_finalized = False

    try:
        with trace_span(trace, "checkout_execution", {
            "target": checkout_plan.target_variation_id,
            "call_count": len(checkout_plan.tool_calls),
        }):
            for call in checkout_plan.tool_calls:
                tool = call["tool"]
                args = call["arguments"]
                call_id = str(uuid.uuid4())

                try:
                    with trace_span(trace, f"checkout_tool:{tool}"):
                        _dispatch_tool(tool, args, store, txn)
                except Exception as e:
                    # Record and keep going so every failure in the plan is
                    # reported; the whole transaction is rolled back below.
                    failed += 1
                    logger.exception("❌ Checkout tool failed: %s — %s", tool, e)
                    if emit_sse:
                        events.append({
                            "type": "toolError",
                            "id": call_id,
                            "tool": tool,
                            "error": str(e),
                        })
                else:
                    if emit_sse:
                        events.append({
                            "type": "toolCall",
                            "id": call_id,
                            "tool": tool,
                            "params": args,
                        })
                    executed += 1

        if failed == 0:
            store.commit(txn)
            txn_finalized = True
            logger.info(
                "✅ Checkout executed: %d calls, target=%s",
                executed, checkout_plan.target_variation_id[:8],
            )
        else:
            # Mark first: if rollback raises, it must not be attempted again.
            txn_finalized = True
            store.rollback(txn)
            logger.warning(
                "⚠️ Checkout rolled back: %d executed, %d failed",
                executed, failed,
            )
    finally:
        # Reached without finalization only when something escaped the body
        # (including BaseException); roll back so no transaction is left open.
        if not txn_finalized:
            store.rollback(txn)

    return CheckoutExecutionResult(
        project_id=checkout_plan.project_id,
        target_variation_id=checkout_plan.target_variation_id,
        executed=executed,
        failed=failed,
        plan_hash=checkout_plan.plan_hash(),
        events=tuple(events),
    )
| 163 | |
| 164 | |
def _make_cc_event(cc_num: int, e: dict[str, JSONValue]) -> CCEventDict:
    """Build a ``CCEventDict`` for controller ``cc_num`` from one raw entry.

    ``beat`` is coerced with ``jfloat`` and ``value`` with ``jint``. A missing
    key reaches those helpers as ``None`` (via ``dict.get``).
    """
    raw_beat = e.get("beat")
    raw_value = e.get("value")
    event: CCEventDict = {
        "cc": cc_num,
        "beat": jfloat(raw_beat),
        "value": jint(raw_value),
    }
    return event
| 167 | |
| 168 | |
def _make_pb_event(e: dict[str, JSONValue]) -> PitchBendDict:
    """Build a ``PitchBendDict`` from one raw pitch-bend entry.

    ``beat`` is coerced with ``jfloat`` and ``value`` with ``jint``. A missing
    key reaches those helpers as ``None`` (via ``dict.get``).
    """
    raw_beat, raw_value = e.get("beat"), e.get("value")
    bend: PitchBendDict = {"beat": jfloat(raw_beat), "value": jint(raw_value)}
    return bend
| 171 | |
| 172 | |
def _make_at_event(e: dict[str, JSONValue]) -> AftertouchDict:
    """Build an ``AftertouchDict`` from one raw aftertouch entry.

    ``beat`` is coerced with ``jfloat`` and ``value`` with ``jint``. A missing
    key reaches those helpers as ``None`` (via ``dict.get``).
    """
    beat = jfloat(e.get("beat"))
    value = jint(e.get("value"))
    touch: AftertouchDict = {"beat": beat, "value": value}
    return touch
| 175 | |
| 176 | |
def _dict_entries(args: dict[str, JSONValue], key: str) -> list[dict[str, JSONValue]]:
    """Return the dict elements of the list stored at ``args[key]``.

    A missing key or a non-list value yields ``[]``, and non-dict elements are
    skipped. This is the lenient parsing applied to every controller payload.
    """
    raw = args.get(key, [])
    if not isinstance(raw, list):
        return []
    return [entry for entry in raw if isinstance(entry, dict)]


def _dispatch_tool(
    tool: str,
    args: dict[str, JSONValue],
    store: StateStore,
    txn: Transaction,
) -> None:
    """Apply a single checkout tool call to ``store``.

    Supported tools and their effect:
      - clear notes: removes every note currently in the region.
      - add notes: adds the entries of ``args["notes"]`` that pass
        ``is_note_dict``; other entries are dropped.
      - add MIDI CC / pitch bend / aftertouch: converts the dict entries of
        ``args["events"]`` and adds them to the region.
    Nothing is written when the filtered payload is empty.

    NOTE(review): only the note operations receive ``transaction=txn``. The
    controller calls (``add_cc``, ``add_pitch_bends``, ``add_aftertouch``)
    are made without it, so whether a checkout rollback undoes them depends
    on StateStore internals not visible here. Confirm they are rollback-safe.

    Raises:
        ValueError: if ``tool`` is not a supported checkout tool, or if
            ``args["regionId"]`` is missing, not a string, or empty.
    """
    supported = (
        ToolName.CLEAR_NOTES.value,
        ToolName.ADD_NOTES.value,
        ToolName.ADD_MIDI_CC.value,
        ToolName.ADD_PITCH_BEND.value,
        ToolName.ADD_AFTERTOUCH.value,
    )
    if tool not in supported:
        raise ValueError(f"Unsupported checkout tool: {tool}")

    region_id = args.get("regionId")
    if not isinstance(region_id, str) or not region_id:
        # Fail loudly instead of mutating a region keyed by an empty id. The
        # executor counts the call as failed and rolls the checkout back.
        raise ValueError(f"Checkout tool {tool} requires a non-empty string regionId")

    if tool == ToolName.CLEAR_NOTES.value:
        current = store.get_region_notes(region_id)
        if current:
            store.remove_notes(region_id, current, transaction=txn)

    elif tool == ToolName.ADD_NOTES.value:
        raw_notes = args.get("notes", [])
        notes: list[NoteDict] = (
            [n for n in raw_notes if is_note_dict(n)]
            if isinstance(raw_notes, list) else []
        )
        if notes:
            store.add_notes(region_id, notes, transaction=txn)

    elif tool == ToolName.ADD_MIDI_CC.value:
        raw_cc = args.get("cc", 0)
        cc_num = int(raw_cc) if isinstance(raw_cc, (int, float)) else 0
        cc_events: list[CCEventDict] = [
            _make_cc_event(cc_num, e) for e in _dict_entries(args, "events")
        ]
        if cc_events:
            store.add_cc(region_id, cc_events)

    elif tool == ToolName.ADD_PITCH_BEND.value:
        pb_events: list[PitchBendDict] = [
            _make_pb_event(e) for e in _dict_entries(args, "events")
        ]
        if pb_events:
            store.add_pitch_bends(region_id, pb_events)

    elif tool == ToolName.ADD_AFTERTOUCH.value:
        at_events: list[AftertouchDict] = [
            _make_at_event(e) for e in _dict_entries(args, "events")
        ]
        if at_events:
            store.add_aftertouch(region_id, at_events)

    else:
        # Unreachable while ``supported`` and the branches above agree; kept
        # as a guard in case a tool is added to one but not the other.
        raise ValueError(f"Unsupported checkout tool: {tool}")
| 227 | else: |
| 228 | raise ValueError(f"Unsupported checkout tool: {tool}") |