cgcardona / muse public
muse_checkout.py python
327 lines 10.7 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """Muse Checkout Engine — translate ReplayPlan into DAW tool calls.
2
3 Converts a target snapshot (from replay/reconstruction) into a
4 deterministic, ordered stream of tool calls that would reconstruct
5 the target musical state from the current working state.
6
7 Pure data translator — does NOT execute tool calls.
8
9 Boundary rules:
10 - Must NOT import StateStore, EntityRegistry, or get_or_create_store.
11 - Must NOT import executor modules or app.core.executor.*.
12 - Must NOT import LLM handlers or maestro_* modules.
13 - May import muse_replay (HeadSnapshot), muse_drift (fingerprinting).
- May import ToolName enum from maestro.core.tools.
15 """
16
17 from __future__ import annotations
18
19 import hashlib
20 import json
21 import logging
22 from dataclasses import dataclass, field
23 from typing_extensions import TypedDict
24
25 from maestro.contracts.json_types import (
26 AftertouchDict,
27 CCEventDict,
28 JSONValue,
29 NoteDict,
30 PitchBendDict,
31 RegionAftertouchMap,
32 RegionCCMap,
33 RegionNotesMap,
34 RegionPitchBendMap,
35 json_list,
36 )
37 from maestro.core.tools import ToolName
38 from maestro.services.muse_drift import _fingerprint, _combined_fingerprint
39 from maestro.services.variation.note_matching import (
40 match_notes,
41 match_cc_events,
42 match_pitch_bends,
43 match_aftertouch,
44 )
45
logger = logging.getLogger(__name__)

# Per-region note-change count (added + removed + modified) at or above which
# _build_region_note_calls stops emitting incremental additions and instead
# clears the region and re-adds every target note.
REGION_RESET_THRESHOLD: int = 20
49
50
class CheckoutToolCall(TypedDict):
    """One tool invocation emitted by the checkout planner.

    Mirrors the shape of ``ToolCall`` but is declared as a plain ``TypedDict``,
    so at runtime every instance is an ordinary ``dict``. That keeps calls easy
    to serialise (``CheckoutPlan.plan_hash`` feeds them to ``json.dumps``) and
    lets plans keep them in tuples. Note that the dicts themselves are not
    hashable.

    Keys:
        tool: The string value of a ``ToolName`` member
            (for example ``"stori_add_notes"``), not the enum member itself.
        arguments: Keyword arguments handed unchanged to the executor.
    """

    tool: str
    arguments: dict[str, JSONValue]
64
65
@dataclass(frozen=True)
class CheckoutPlan:
    """Deterministic recipe for moving the working tree to a target variation.

    Instances are produced by ``build_checkout_plan`` (a pure diff of the
    working state against the target variation) and consumed by
    ``execute_checkout_plan`` in ``muse_checkout_executor``. The dataclass is
    frozen, so fields cannot be reassigned. ``fingerprint_target`` is still a
    plain (mutable) ``dict``, however.

    Attributes:
        project_id: Identifier of the project being checked out.
        target_variation_id: UUID of the variation to restore.
        tool_calls: ``CheckoutToolCall`` entries, in execution order, that
            bring the working tree to the target state.
        regions_reset: UUIDs of regions rebuilt via clear + re-add because
            notes were removed/modified or the change count reached
            ``REGION_RESET_THRESHOLD``.
        fingerprint_target: ``{region_id: fingerprint}`` expected once the
            calls have run. Used to confirm the checkout landed.
    """

    project_id: str
    target_variation_id: str
    tool_calls: tuple[CheckoutToolCall, ...]
    regions_reset: tuple[str, ...]
    fingerprint_target: dict[str, str]

    @property
    def is_noop(self) -> bool:
        """Whether the plan is empty, i.e. the working tree already matches."""
        return not self.tool_calls

    def plan_hash(self) -> str:
        """Return a 32-hex-character SHA-256 digest identifying this plan.

        The digest covers the project, the target variation, the tool calls
        and the reset regions. ``fingerprint_target`` is not included. Keys
        are sorted before serialisation so equal plans hash identically, and
        non-JSON values fall back to ``str()``.
        """
        payload = {
            "calls": list(self.tool_calls),
            "project_id": self.project_id,
            "resets": list(self.regions_reset),
            "target": self.target_variation_id,
        }
        canonical = json.dumps(payload, sort_keys=True, default=str)
        digest = hashlib.sha256(canonical.encode())
        return digest.hexdigest()[:32]
111
112
def _make_tool_call(tool: ToolName, arguments: dict[str, JSONValue]) -> CheckoutToolCall:
    """Wrap a ``ToolName`` member and its arguments as a ``CheckoutToolCall``.

    The enum member is flattened to its ``.value`` string. ``arguments`` is
    stored by reference, not copied.
    """
    tool_name = tool.value
    call: CheckoutToolCall = {"tool": tool_name, "arguments": arguments}
    return call
116
117
def _note_payload(note: NoteDict) -> JSONValue:
    """Translate a stored note into the camelCase dict used in ``ADD_NOTES`` calls."""
    return {
        "pitch": note.get("pitch", 60),
        "startBeat": note.get("start_beat", 0.0),
        "durationBeats": note.get("duration_beats", 0.5),
        "velocity": note.get("velocity", 100),
    }


def _build_region_note_calls(
    region_id: str,
    target_notes: list[NoteDict],
    working_notes: list[NoteDict],
) -> tuple[list[CheckoutToolCall], bool]:
    """Emit the note tool calls that move one region from working to target.

    Returns ``(tool_calls, was_reset)``.

    * No differences: ``([], False)``.
    * Any removed or modified note, or a total change count of at least
      ``REGION_RESET_THRESHOLD``: a region reset (``CLEAR_NOTES``, then
      ``ADD_NOTES`` with every target note if there are any) and ``True``.
      A reset is used because there is no per-note removal call.
    * Otherwise (pure additions below the threshold): a single ``ADD_NOTES``
      with just the added notes (or nothing if none carry a proposed note)
      and ``False``.
    """
    matches = match_notes(working_notes, target_notes)

    additions = [m for m in matches if m.is_added]
    removal_count = sum(1 for m in matches if m.is_removed)
    modification_count = sum(1 for m in matches if m.is_modified)
    change_count = len(additions) + removal_count + modification_count

    if change_count == 0:
        return [], False

    must_reset = (
        removal_count > 0
        or modification_count > 0
        or change_count >= REGION_RESET_THRESHOLD
    )
    if must_reset:
        reset_calls = [_make_tool_call(ToolName.CLEAR_NOTES, {"regionId": region_id})]
        if target_notes:
            full_payload: list[JSONValue] = [_note_payload(n) for n in target_notes]
            reset_calls.append(_make_tool_call(
                ToolName.ADD_NOTES,
                {"regionId": region_id, "notes": full_payload},
            ))
        return reset_calls, True

    # Only additions remain here: append them without touching existing notes.
    new_notes: list[JSONValue] = [
        _note_payload(m.proposed_note)
        for m in additions
        if m.proposed_note is not None
    ]
    if not new_notes:
        return [], False
    return [_make_tool_call(
        ToolName.ADD_NOTES,
        {"regionId": region_id, "notes": new_notes},
    )], False
179
180
def _build_cc_calls(
    region_id: str,
    target_cc: list[CCEventDict],
    working_cc: list[CCEventDict],
) -> list[CheckoutToolCall]:
    """Emit ``ADD_MIDI_CC`` calls that add or update CC events toward the target.

    Events that ``match_cc_events`` classifies as added or modified are
    grouped by controller number. One call is emitted per controller, in
    ascending controller order. A non-numeric ``cc`` value falls back to
    controller 0.

    Events present only in the working state are NOT emitted, because no call
    produced here deletes a CC event. Rather than silently returning a plan
    whose result still differs from the target, that case is logged as a
    warning.

    Returns:
        A possibly empty list of ``ADD_MIDI_CC`` tool calls.
    """
    matches = match_cc_events(working_cc, target_cc)

    # NOTE(review): assumes CC matches expose ``is_removed`` like note matches
    # do; getattr keeps planning from crashing if they do not.
    stale = sum(1 for m in matches if getattr(m, "is_removed", False))
    if stale:
        logger.warning(
            "⚠️ Checkout: %d CC event(s) in region %s are absent from the target "
            "but this plan emits no call to remove them",
            stale, region_id,
        )

    needed = [m for m in matches if m.is_added or m.is_modified]
    if not needed:
        return []

    by_cc: dict[int, list[CCEventDict]] = {}
    for m in needed:
        ev = m.proposed_event
        if ev is None:
            continue
        raw_cc = ev.get("cc", 0)
        cc_num = int(raw_cc) if isinstance(raw_cc, (int, float, str)) else 0
        by_cc.setdefault(cc_num, []).append(CCEventDict(
            cc=cc_num,
            # ``or`` maps None/falsy values to the default.
            beat=float(ev.get("beat", 0.0) or 0.0),
            value=int(ev.get("value", 0) or 0),
        ))

    calls: list[CheckoutToolCall] = []
    for cc_num in sorted(by_cc):
        cc_events_json: list[JSONValue] = json_list(by_cc[cc_num])
        calls.append(_make_tool_call(
            ToolName.ADD_MIDI_CC,
            {"regionId": region_id, "cc": cc_num, "events": cc_events_json},
        ))
    return calls
212
213
def _build_pb_calls(
    region_id: str,
    target_pb: list[PitchBendDict],
    working_pb: list[PitchBendDict],
) -> list[CheckoutToolCall]:
    """Emit at most one ``ADD_PITCH_BEND`` call moving pitch bend toward the target.

    Events that ``match_pitch_bends`` classifies as added or modified are sent
    together as ``{"beat", "value"}`` entries in a single call.

    Events present only in the working state are NOT emitted, because no call
    produced here deletes a pitch-bend event. That case is logged as a
    warning instead of being dropped silently.

    Returns:
        ``[]`` when nothing needs adding, otherwise a one-element list.
    """
    matches = match_pitch_bends(working_pb, target_pb)

    # NOTE(review): assumes pitch-bend matches expose ``is_removed`` like note
    # matches do; getattr keeps planning from crashing if they do not.
    stale = sum(1 for m in matches if getattr(m, "is_removed", False))
    if stale:
        logger.warning(
            "⚠️ Checkout: %d pitch-bend event(s) in region %s are absent from the "
            "target but this plan emits no call to remove them",
            stale, region_id,
        )

    needed = [m for m in matches if m.is_added or m.is_modified]
    if not needed:
        return []

    pb_events: list[JSONValue] = [
        {"beat": m.proposed_event.get("beat", 0.0), "value": m.proposed_event.get("value", 0)}
        for m in needed
        if m.proposed_event is not None
    ]
    return [_make_tool_call(
        ToolName.ADD_PITCH_BEND,
        {"regionId": region_id, "events": pb_events},
    )]
233
234
def _build_at_calls(
    region_id: str,
    target_at: list[AftertouchDict],
    working_at: list[AftertouchDict],
) -> list[CheckoutToolCall]:
    """Emit at most one ``ADD_AFTERTOUCH`` call moving aftertouch toward the target.

    Events that ``match_aftertouch`` classifies as added or modified are sent
    together as ``{"beat", "value"}`` entries. ``"pitch"`` is included only
    when the event carries an integer pitch. Booleans are rejected even though
    ``bool`` subclasses ``int``.

    Events present only in the working state are NOT emitted, because no call
    produced here deletes an aftertouch event. That case is logged as a
    warning instead of being dropped silently.

    Returns:
        ``[]`` when nothing needs adding, otherwise a one-element list.
    """
    matches = match_aftertouch(working_at, target_at)

    # NOTE(review): assumes aftertouch matches expose ``is_removed`` like note
    # matches do; getattr keeps planning from crashing if they do not.
    stale = sum(1 for m in matches if getattr(m, "is_removed", False))
    if stale:
        logger.warning(
            "⚠️ Checkout: %d aftertouch event(s) in region %s are absent from the "
            "target but this plan emits no call to remove them",
            stale, region_id,
        )

    needed = [m for m in matches if m.is_added or m.is_modified]
    if not needed:
        return []

    at_events: list[JSONValue] = []
    for m in needed:
        ev = m.proposed_event
        if ev is None:
            continue
        at_entry: dict[str, JSONValue] = {
            "beat": ev.get("beat", 0.0),
            "value": ev.get("value", 0),
        }
        pitch_val = ev.get("pitch")
        # Per-note (polyphonic) aftertouch carries a pitch. Without one the
        # entry is sent pitch-less. bool is excluded explicitly.
        if isinstance(pitch_val, int) and not isinstance(pitch_val, bool):
            at_entry["pitch"] = pitch_val
        at_events.append(at_entry)
    return [_make_tool_call(
        ToolName.ADD_AFTERTOUCH,
        {"regionId": region_id, "events": at_events},
    )]
261
262
def build_checkout_plan(
    *,
    project_id: str,
    target_variation_id: str,
    target_notes: RegionNotesMap,
    target_cc: RegionCCMap,
    target_pb: RegionPitchBendMap,
    target_at: RegionAftertouchMap,
    working_notes: RegionNotesMap,
    working_cc: RegionCCMap,
    working_pb: RegionPitchBendMap,
    working_at: RegionAftertouchMap,
    track_regions: dict[str, str],
) -> CheckoutPlan:
    """Build a checkout plan that transforms working state → target state.

    Every region id present in any of the eight target/working maps is
    visited in sorted order. For each region, calls are appended in this
    order:

    1. Note calls from ``_build_region_note_calls``. This is either a reset
       (``stori_clear_notes`` followed by ``stori_add_notes``) or a plain
       ``stori_add_notes`` for pure additions.
    2. ``stori_add_midi_cc`` (one call per controller number).
    3. ``stori_add_pitch_bend``.
    4. ``stori_add_aftertouch``.

    The ordering holds *within* each region, not across the whole plan. For
    example, region B's clear can come after region A's automation calls.

    Args:
        project_id: Copied onto the returned plan.
        target_variation_id: Copied onto the returned plan.
        target_notes, target_cc, target_pb, target_at: Desired per-region
            notes / CC / pitch bend / aftertouch, keyed by region id.
        working_notes, working_cc, working_pb, working_at: Current
            per-region state, keyed the same way.
        track_regions: Accepted for interface compatibility but currently
            not read by this function.

    Returns:
        A ``CheckoutPlan``. Its ``fingerprint_target`` has an entry for every
        visited region. Regions missing from the target maps are fingerprinted
        as empty lists.

    Does not touch the StateStore or execute any tool. The only side effect
    is a single info-level log line.
    """
    all_rids = sorted(
        set(target_notes) | set(target_cc) | set(target_pb) | set(target_at)
        | set(working_notes) | set(working_cc) | set(working_pb) | set(working_at)
    )

    tool_calls: list[CheckoutToolCall] = []
    regions_reset: list[str] = []
    fingerprint_target: dict[str, str] = {}

    for rid in all_rids:
        t_notes = target_notes.get(rid, [])
        w_notes = working_notes.get(rid, [])
        t_cc = target_cc.get(rid, [])
        w_cc = working_cc.get(rid, [])
        t_pb = target_pb.get(rid, [])
        w_pb = working_pb.get(rid, [])
        t_at = target_at.get(rid, [])
        w_at = working_at.get(rid, [])

        # Expected post-checkout fingerprint is derived from the target only.
        fingerprint_target[rid] = _combined_fingerprint(t_notes, t_cc, t_pb, t_at)

        note_calls, was_reset = _build_region_note_calls(rid, t_notes, w_notes)
        if was_reset:
            regions_reset.append(rid)
        tool_calls.extend(note_calls)

        tool_calls.extend(_build_cc_calls(rid, t_cc, w_cc))
        tool_calls.extend(_build_pb_calls(rid, t_pb, w_pb))
        tool_calls.extend(_build_at_calls(rid, t_at, w_at))

    logger.info(
        "✅ Checkout plan: %d tool calls, %d region resets, %d regions",
        len(tool_calls), len(regions_reset), len(all_rids),
    )

    return CheckoutPlan(
        project_id=project_id,
        target_variation_id=target_variation_id,
        tool_calls=tuple(tool_calls),
        regions_reset=tuple(regions_reset),
        fingerprint_target=fingerprint_target,
    )