cgcardona / muse public
muse_replay.py python
267 lines 8.3 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """Muse Replay Engine — deterministic history reconstruction from persisted data.
2
3 Builds replay plans by walking the variation lineage graph. A ReplayPlan
4 describes the ordered sequence of variations and phrases needed to
5 reconstruct musical state at any point in history.
6
7 Also provides HEAD snapshot reconstruction for drift detection.
8
9 Boundary rules:
10 - Must NOT import StateStore, EntityRegistry, or get_or_create_store.
11 - Must NOT import executor modules.
12 - Must NOT import LLM handlers or maestro_* modules.
13 - May import muse_repository (for lineage queries and domain loading).
14 - May import domain models from maestro.models.variation.
15 """
16
17 from __future__ import annotations
18
19 import logging
20 from dataclasses import dataclass, field
21
22 from sqlalchemy.ext.asyncio import AsyncSession
23
24 from maestro.contracts.json_types import (
25 NoteDict,
26 RegionAftertouchMap,
27 RegionCCMap,
28 RegionNotesMap,
29 RegionPitchBendMap,
30 )
31 from maestro.services import muse_repository
32 from maestro.services.muse_repository import HistoryNode
33
34 logger = logging.getLogger(__name__)
35
36
37
@dataclass(frozen=True)
class RegionUpdate:
    """Immutable record of one region that a replay step touches.

    Attributes:
        region_id: Identifier of the affected region.
        track_id: Identifier of the track associated with the region.
        start_beat: Beat position where the affected span begins.
        end_beat: Beat position where the affected span ends.
    """

    # Field order is part of the positional constructor signature.
    region_id: str
    track_id: str
    start_beat: float
    end_beat: float
46
47
@dataclass(frozen=True)
class ReplayPlan:
    """Pure-data description of how to rebuild state from root to a target.

    A plan carries only identifiers and region metadata, so it can be
    consumed without touching StateStore or the executor.

    Attributes:
        ordered_variation_ids: Variation ids in replay order.
        ordered_phrase_ids: Phrase ids in replay order.
        region_updates: Regions affected by the replay.
        lineage: History nodes the plan was derived from; defaults to a
            fresh empty list for each instance.
    """

    ordered_variation_ids: list[str]
    ordered_phrase_ids: list[str]
    region_updates: list[RegionUpdate]
    # default_factory gives every plan its own list instead of a shared one.
    lineage: list[HistoryNode] = field(default_factory=list)
60
61
async def build_replay_plan(
    session: AsyncSession,
    project_id: str,
    target_variation_id: str,
) -> ReplayPlan | None:
    """Assemble the replay plan that leads to ``target_variation_id``.

    The lineage returned by ``muse_repository.get_lineage`` is walked in
    order. Every variation id on that path is recorded; each variation is
    then loaded and its phrases contribute phrase ids plus one
    ``RegionUpdate`` per distinct region.

    Args:
        session: Async database session handed to the repository calls.
        project_id: Accepted for interface compatibility; this function
            does not currently read it.
        target_variation_id: Variation whose lineage should be replayed.

    Returns:
        The populated ``ReplayPlan``, or ``None`` when the repository
        returns an empty lineage for the target.
    """
    history = await muse_repository.get_lineage(session, target_variation_id)
    if not history:
        return None

    variation_ids: list[str] = []
    phrase_ids: list[str] = []
    # Insertion-ordered map: the first phrase seen for a region defines its
    # RegionUpdate, and later phrases on the same region are ignored here.
    first_update_by_region: dict[str, RegionUpdate] = {}

    for entry in history:
        current_id = entry.variation_id
        variation_ids.append(current_id)

        loaded = await muse_repository.load_variation(session, current_id)
        if loaded is None:
            # The id stays in the plan, but there are no phrases to collect.
            continue

        for phrase in loaded.phrases:
            phrase_ids.append(phrase.phrase_id)
            region = phrase.region_id
            if region in first_update_by_region:
                continue
            first_update_by_region[region] = RegionUpdate(
                region_id=region,
                track_id=phrase.track_id,
                start_beat=phrase.start_beat,
                end_beat=phrase.end_beat,
            )

    updates = list(first_update_by_region.values())

    logger.info(
        "✅ Replay plan: %d variations, %d phrases, %d regions",
        len(variation_ids),
        len(phrase_ids),
        len(updates),
    )

    return ReplayPlan(
        ordered_variation_ids=variation_ids,
        ordered_phrase_ids=phrase_ids,
        region_updates=updates,
        lineage=history,
    )
116
117
118 # ── HEAD Snapshot Reconstruction (Phase 6) ────────────────────────────────
119
120
@dataclass(frozen=True)
class HeadSnapshot:
    """Musical state rebuilt from the persisted phrases of a variation lineage.

    Holds the notes that Muse committed (added or modified) together with
    all controller data (CC, pitch bends, aftertouch) found on persisted
    phrases. Notes that were already present in a region and that Muse
    never changed are not part of the snapshot.

    Attributes:
        variation_id: Variation the snapshot was reconstructed for.
        notes: Note data grouped per region.
        cc: CC events grouped per region.
        pitch_bends: Pitch-bend events grouped per region.
        aftertouch: Aftertouch events grouped per region.
        track_regions: Maps a region id to its track id.
        region_start_beats: Maps a region id to a start beat.
    """

    variation_id: str
    notes: RegionNotesMap
    cc: RegionCCMap
    pitch_bends: RegionPitchBendMap
    aftertouch: RegionAftertouchMap
    track_regions: dict[str, str]
    region_start_beats: dict[str, float]
139
140
async def reconstruct_head_snapshot(
    session: AsyncSession,
    project_id: str,
) -> HeadSnapshot | None:
    """Reconstruct a snapshot from the HEAD variation's persisted phrases.

    Resolves the project's HEAD through ``muse_repository.get_head`` and
    then reuses ``reconstruct_variation_snapshot`` for the lineage walk, so
    HEAD reconstruction and arbitrary-variation reconstruction always apply
    the same rules. For each NoteChange along the lineage:

    - ``added``: the ``after`` note is included in the result.
    - ``modified``: the ``after`` note is included.
    - ``removed``: no note is added (Muse removed it).

    This is a *partial* reconstruction — it only reflects notes that Muse
    created or modified. Notes that existed before Muse involvement are
    not tracked.

    Args:
        session: Async database session handed to the repository calls.
        project_id: Project whose HEAD variation should be reconstructed.

    Returns:
        The reconstructed ``HeadSnapshot``, or ``None`` when the project has
        no HEAD or the HEAD variation has an empty lineage.
    """
    head = await muse_repository.get_head(session, project_id)
    if head is None:
        return None

    # Single source of truth for the lineage walk: delegating avoids keeping
    # two hand-synchronised copies of the accumulation loop in this module.
    snapshot = await reconstruct_variation_snapshot(session, head.variation_id)
    if snapshot is None:
        return None

    logger.info(
        "✅ HEAD snapshot reconstructed: %d regions, %d notes, %d cc, %d pb, %d at",
        len(snapshot.notes),
        sum(len(n) for n in snapshot.notes.values()),
        sum(len(e) for e in snapshot.cc.values()),
        sum(len(e) for e in snapshot.pitch_bends.values()),
        sum(len(e) for e in snapshot.aftertouch.values()),
    )

    return snapshot
214
215
async def reconstruct_variation_snapshot(
    session: AsyncSession,
    variation_id: str,
) -> HeadSnapshot | None:
    """Rebuild the snapshot as of an arbitrary variation.

    Walks the lineage returned by ``muse_repository.get_lineage`` for
    ``variation_id`` and accumulates, per region, the ``after`` note of
    every ``added`` / ``modified`` note change plus all CC, pitch-bend and
    aftertouch events found on the persisted phrases.

    NOTE(review): entries are only ever appended. A ``removed`` change, or
    a later ``modified`` change, does not delete a note collected from an
    earlier variation. Confirm that consumers expect this.

    Args:
        session: Async database session handed to the repository calls.
        variation_id: Variation to reconstruct; it need not be HEAD.

    Returns:
        A ``HeadSnapshot`` tagged with ``variation_id``, or ``None`` when
        the repository returns an empty lineage for it.
    """
    history = await muse_repository.get_lineage(session, variation_id)
    if not history:
        return None

    notes_by_region: RegionNotesMap = {}
    cc_by_region: RegionCCMap = {}
    bends_by_region: RegionPitchBendMap = {}
    touch_by_region: RegionAftertouchMap = {}
    track_of_region: dict[str, str] = {}
    start_of_region: dict[str, float] = {}

    for entry in history:
        loaded = await muse_repository.load_variation(session, entry.variation_id)
        if loaded is None:
            # Nothing to collect for this step of the lineage.
            continue

        for phrase in loaded.phrases:
            region = phrase.region_id
            # Later phrases overwrite the track and start beat recorded earlier.
            track_of_region[region] = phrase.track_id
            start_of_region[region] = phrase.start_beat

            # The region is registered even when no note change qualifies.
            notes_by_region.setdefault(region, []).extend(
                change.after.to_note_dict()
                for change in phrase.note_changes
                if change.change_type in ("added", "modified") and change.after
            )

            cc_by_region.setdefault(region, []).extend(phrase.cc_events)
            bends_by_region.setdefault(region, []).extend(phrase.pitch_bends)
            touch_by_region.setdefault(region, []).extend(phrase.aftertouch)

    return HeadSnapshot(
        variation_id=variation_id,
        notes=notes_by_region,
        cc=cc_by_region,
        pitch_bends=bends_by_region,
        aftertouch=touch_by_region,
        track_regions=track_of_region,
        region_start_beats=start_of_region,
    )