cgcardona / muse public
muse_checkout_executor.py python
228 lines 7.9 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """Muse Checkout Executor — apply a CheckoutPlan to StateStore.
2
3 Dispatches each tool call in the plan through the existing StateStore
4 mutation methods, producing SSE-compatible events that the DAW
5 processes identically to normal editing execution.
6
7 Boundary rules:
8 - Must NOT import LLM handlers or maestro_* modules.
9 - Must NOT import VariationService.
10 - Must NOT import muse_replay internals.
11 - May import tool_names, state_store, tracing, muse_checkout, muse_drift.
12 """
13
14 from __future__ import annotations
15
16 import logging
17 import uuid
18 from dataclasses import dataclass, field
19
20 from maestro.contracts.json_types import AftertouchDict, CCEventDict, JSONValue, NoteDict, PitchBendDict, is_note_dict, jfloat, jint
21 from maestro.core.state_store import StateStore, Transaction
22 from maestro.core.tools import ToolName
23 from maestro.core.tracing import TraceContext, trace_span
24 from maestro.services.muse_checkout import CheckoutPlan
25
# Module-level logger, named after this module so handlers can target it.
logger: logging.Logger = logging.getLogger(__name__)
27
28
@dataclass(frozen=True)
class CheckoutExecutionResult:
    """Frozen outcome of running a ``CheckoutPlan`` through ``execute_checkout_plan``.

    By the time ``execute_checkout_plan`` returns one of these, it has already
    committed (no failures) or rolled back (any failure) the store transaction.
    Callers use ``success`` / ``is_noop`` to decide what to do next, e.g.
    whether to emit the Muse SSE ``checkout`` event.

    Attributes:
        project_id: Project the plan targeted (``checkout_plan.project_id``).
        target_variation_id: Variation ID being checked out.
        executed: How many tool calls completed without raising.
        failed: How many tool calls raised.
        plan_hash: Value returned by ``checkout_plan.plan_hash()``; handy for
            correlating log lines that belong to the same plan.
        events: ``toolCall`` / ``toolError`` event dicts gathered during
            execution, to be forwarded to the Muse SSE stream. As produced by
            ``execute_checkout_plan`` this is empty when SSE collection was
            disabled or the plan was a no-op.
    """

    project_id: str
    target_variation_id: str
    executed: int
    failed: int
    plan_hash: str
    events: tuple[dict[str, JSONValue], ...] = ()

    @property
    def success(self) -> bool:
        """Whether at least one tool call ran and none of them failed."""
        return self.executed > 0 and self.failed == 0

    @property
    def is_noop(self) -> bool:
        """Whether nothing ran and nothing failed (no tool calls were needed)."""
        return not (self.executed or self.failed)
58 @property
59 def is_noop(self) -> bool:
60 """``True`` when the plan had no tool calls (working tree already matched target)."""
61 return self.executed == 0 and self.failed == 0
62
63
def execute_checkout_plan(
    *,
    checkout_plan: CheckoutPlan,
    store: StateStore,
    trace: TraceContext,
    emit_sse: bool = True,
) -> CheckoutExecutionResult:
    """Execute a CheckoutPlan by dispatching its tool calls to a StateStore.

    Tool calls are applied one at a time, in the order they appear in
    ``checkout_plan.tool_calls``, all inside a single store transaction.
    The planner is expected to order them ``stori_clear_notes`` →
    ``stori_add_notes`` → controllers; this function neither reorders nor
    validates that sequence.

    A failing tool call does not stop execution. It is counted in ``failed``,
    logged, and (when ``emit_sse``) reported as a ``toolError`` event, and
    the remaining calls still run. Once every call has been attempted, the
    transaction is committed if nothing failed and rolled back otherwise.
    NOTE: ``toolCall`` events for calls that succeeded stay in ``events``
    even when the transaction is rolled back, so check ``success`` before
    forwarding them.

    Args:
        checkout_plan: The plan to execute (pure data).
        store: StateStore that receives the mutations; only its public
            methods are called.
        trace: TraceContext used for the execution span and per-tool spans.
        emit_sse: When True, collect a ``toolCall`` / ``toolError`` event
            dict for every tool call into ``CheckoutExecutionResult.events``.

    Returns:
        A ``CheckoutExecutionResult`` with executed/failed counts, the plan
        hash and the collected events. A no-op plan returns immediately
        without opening a transaction.

    Raises:
        Exception: Anything raised outside the per-call handler (for example
            by ``trace_span``, ``store.commit`` or ``store.rollback``) is
            re-raised. If the transaction had not yet been committed or
            rolled back, it is rolled back first. Exceptions from individual
            tool calls are never raised; they are counted in ``failed``.
    """
    if checkout_plan.is_noop:
        logger.info("✅ Checkout is no-op — nothing to execute")
        return CheckoutExecutionResult(
            project_id=checkout_plan.project_id,
            target_variation_id=checkout_plan.target_variation_id,
            executed=0,
            failed=0,
            plan_hash=checkout_plan.plan_hash(),
        )

    executed = 0
    failed = 0
    events: list[dict[str, JSONValue]] = []

    txn = store.begin_transaction(
        f"checkout:{checkout_plan.target_variation_id[:8]}",
    )
    # Set once commit/rollback has been issued. The outer handler must then
    # neither roll back a committed transaction nor retry a rollback.
    txn_closed = False

    try:
        with trace_span(trace, "checkout_execution", {
            "target": checkout_plan.target_variation_id,
            "call_count": len(checkout_plan.tool_calls),
        }):
            for call in checkout_plan.tool_calls:
                tool = call["tool"]
                args = call["arguments"]
                call_id = str(uuid.uuid4())

                try:
                    with trace_span(trace, f"checkout_tool:{tool}"):
                        _dispatch_tool(tool, args, store, txn)
                except Exception as e:
                    failed += 1
                    # logger.exception keeps the traceback, which str(e) alone loses.
                    logger.exception("❌ Checkout tool failed: %s — %s", tool, e)
                    if emit_sse:
                        events.append({
                            "type": "toolError",
                            "id": call_id,
                            "tool": tool,
                            "error": str(e),
                        })
                else:
                    if emit_sse:
                        events.append({
                            "type": "toolCall",
                            "id": call_id,
                            "tool": tool,
                            "params": args,
                        })
                    executed += 1

            if failed == 0:
                store.commit(txn)
                # Only mark closed after a successful commit: a failed commit
                # should still be rolled back by the outer handler.
                txn_closed = True
                logger.info(
                    "✅ Checkout executed: %d calls, target=%s",
                    executed, checkout_plan.target_variation_id[:8],
                )
            else:
                # Mark closed before rolling back so a rollback that raises
                # is not attempted a second time by the outer handler.
                txn_closed = True
                store.rollback(txn)
                logger.warning(
                    "⚠️ Checkout rolled back: %d executed, %d failed",
                    executed, failed,
                )

    except Exception:
        if not txn_closed:
            store.rollback(txn)
        raise

    return CheckoutExecutionResult(
        project_id=checkout_plan.project_id,
        target_variation_id=checkout_plan.target_variation_id,
        executed=executed,
        failed=failed,
        plan_hash=checkout_plan.plan_hash(),
        events=tuple(events),
    )
163
164
def _make_cc_event(cc_num: int, e: dict[str, JSONValue]) -> CCEventDict:
    """Convert a raw event dict into a ``CCEventDict`` for controller ``cc_num``.

    ``beat`` and ``value`` are read with ``dict.get`` and coerced through
    ``jfloat`` / ``jint``. How a missing key (``None``) is coerced is up to
    those helpers.
    """
    beat = jfloat(e.get("beat"))
    value = jint(e.get("value"))
    return {"cc": cc_num, "beat": beat, "value": value}
167
168
def _make_pb_event(e: dict[str, JSONValue]) -> PitchBendDict:
    """Convert a raw event dict into a ``PitchBendDict`` with ``beat`` and ``value``.

    Both fields are fetched with ``dict.get`` and coerced via ``jfloat`` /
    ``jint`` respectively.
    """
    beat = jfloat(e.get("beat"))
    value = jint(e.get("value"))
    return {"beat": beat, "value": value}
171
172
def _make_at_event(e: dict[str, JSONValue]) -> AftertouchDict:
    """Convert a raw event dict into an ``AftertouchDict`` with ``beat`` and ``value``.

    Both fields are fetched with ``dict.get`` and coerced via ``jfloat`` /
    ``jint`` respectively.
    """
    beat = jfloat(e.get("beat"))
    value = jint(e.get("value"))
    return {"beat": beat, "value": value}
175
176
def _dispatch_tool(
    tool: str,
    args: dict[str, JSONValue],
    store: StateStore,
    txn: Transaction,
) -> None:
    """Apply one checkout tool call to ``store``.

    ``regionId`` is read from ``args`` for every tool; a missing or
    non-string value becomes ``""``. Malformed payload entries are dropped
    silently: non-list payloads, non-dict controller events, and notes
    rejected by ``is_note_dict``. Each store write is skipped when nothing
    is left to write.

    Note writes (``remove_notes`` / ``add_notes``) are given
    ``transaction=txn``; the CC, pitch-bend and aftertouch writes are not.
    NOTE(review): if those StateStore methods accept a transaction, the
    controller writes may survive a rollback of ``txn`` — confirm.

    Raises:
        ValueError: If ``tool`` is not one of the supported checkout tools.
    """

    def _dict_events(key: str) -> list[dict[str, JSONValue]]:
        # Keep only the dict entries of a list payload; anything else yields [].
        payload = args.get(key, [])
        if not isinstance(payload, list):
            return []
        return [item for item in payload if isinstance(item, dict)]

    raw_region = args.get("regionId", "")
    region_id = raw_region if isinstance(raw_region, str) else ""

    if tool == ToolName.CLEAR_NOTES.value:
        existing = store.get_region_notes(region_id)
        if existing:
            store.remove_notes(region_id, existing, transaction=txn)
        return

    if tool == ToolName.ADD_NOTES.value:
        raw_notes = args.get("notes", [])
        valid_notes: list[NoteDict] = []
        if isinstance(raw_notes, list):
            valid_notes = [n for n in raw_notes if is_note_dict(n)]
        if valid_notes:
            store.add_notes(region_id, valid_notes, transaction=txn)
        return

    if tool == ToolName.ADD_MIDI_CC.value:
        raw_cc = args.get("cc", 0)
        controller = int(raw_cc) if isinstance(raw_cc, (int, float)) else 0
        cc_events: list[CCEventDict] = [
            _make_cc_event(controller, ev) for ev in _dict_events("events")
        ]
        if cc_events:
            store.add_cc(region_id, cc_events)
        return

    if tool == ToolName.ADD_PITCH_BEND.value:
        pb_events: list[PitchBendDict] = [
            _make_pb_event(ev) for ev in _dict_events("events")
        ]
        if pb_events:
            store.add_pitch_bends(region_id, pb_events)
        return

    if tool == ToolName.ADD_AFTERTOUCH.value:
        at_events: list[AftertouchDict] = [
            _make_at_event(ev) for ev in _dict_events("events")
        ]
        if at_events:
            store.add_aftertouch(region_id, at_events)
        return

    raise ValueError(f"Unsupported checkout tool: {tool}")