muse_replay.py
python
| 1 | """Muse Replay Engine — deterministic history reconstruction from persisted data. |
| 2 | |
| 3 | Builds replay plans by walking the variation lineage graph. A ReplayPlan |
| 4 | describes the ordered sequence of variations and phrases needed to |
| 5 | reconstruct musical state at any point in history. |
| 6 | |
| 7 | Also provides HEAD snapshot reconstruction for drift detection. |
| 8 | |
| 9 | Boundary rules: |
| 10 | - Must NOT import StateStore, EntityRegistry, or get_or_create_store. |
| 11 | - Must NOT import executor modules. |
| 12 | - Must NOT import LLM handlers or maestro_* modules. |
| 13 | - May import muse_repository (for lineage queries and domain loading). |
| 14 | - May import domain models from maestro.models.variation. |
| 15 | """ |
| 16 | |
| 17 | from __future__ import annotations |
| 18 | |
| 19 | import logging |
| 20 | from dataclasses import dataclass, field |
| 21 | |
| 22 | from sqlalchemy.ext.asyncio import AsyncSession |
| 23 | |
| 24 | from maestro.contracts.json_types import ( |
| 25 | NoteDict, |
| 26 | RegionAftertouchMap, |
| 27 | RegionCCMap, |
| 28 | RegionNotesMap, |
| 29 | RegionPitchBendMap, |
| 30 | ) |
| 31 | from maestro.services import muse_repository |
| 32 | from maestro.services.muse_repository import HistoryNode |
| 33 | |
| 34 | logger = logging.getLogger(__name__) |
| 35 | |
| 36 | |
| 37 | |
@dataclass(eq=True, frozen=True)
class RegionUpdate:
    """Immutable record of one region touched during a replay step.

    Identifies the region, the track that owns it, and the beat span
    taken from a phrase that was written into it.
    """

    # Identifier of the region the phrase targets.
    region_id: str
    # Identifier of the track the region belongs to.
    track_id: str
    # First beat of the span (copied from the phrase).
    start_beat: float
    # Last beat of the span (copied from the phrase).
    end_beat: float
| 46 | |
| 47 | |
@dataclass(eq=True, frozen=True)
class ReplayPlan:
    """Pure-data recipe for rebuilding musical state up to a target variation.

    Holds the variation ids in root-to-target order, the phrase ids in the
    order they are applied, the distinct regions those phrases touch, and
    the lineage nodes the plan was derived from. It references neither
    StateStore nor executor objects.

    The instance itself is frozen, but its list fields are ordinary
    (mutable) lists.
    """

    ordered_variation_ids: list[str]
    ordered_phrase_ids: list[str]
    region_updates: list[RegionUpdate]
    # Each instance gets its own fresh empty list when none is supplied.
    lineage: list[HistoryNode] = field(default_factory=list)
| 60 | |
| 61 | |
async def build_replay_plan(
    session: AsyncSession,
    project_id: str,
    target_variation_id: str,
) -> ReplayPlan | None:
    """Assemble the ordered replay recipe for ``target_variation_id``.

    The lineage (root to target) is fetched from the repository. Each
    variation on it is then loaded in turn, and its phrases are appended
    in order. Every region is reported exactly once, using the bounds of
    the first phrase that touches it. A variation that fails to load
    still contributes its id to the plan, but no phrases.

    ``project_id`` is currently unused.

    Returns None when the repository yields no lineage (the target
    variation does not exist).
    """
    lineage = await muse_repository.get_lineage(session, target_variation_id)
    if not lineage:
        return None

    variation_ids: list[str] = []
    phrase_ids: list[str] = []
    # Insertion-ordered dict: the first phrase seen for a region fixes both
    # its position in the output and the bounds recorded for it.
    first_touch: dict[str, RegionUpdate] = {}

    for node in lineage:
        variation_ids.append(node.variation_id)
        loaded = await muse_repository.load_variation(
            session, node.variation_id,
        )
        if loaded is None:
            continue

        for phrase in loaded.phrases:
            phrase_ids.append(phrase.phrase_id)
            if phrase.region_id in first_touch:
                continue
            first_touch[phrase.region_id] = RegionUpdate(
                region_id=phrase.region_id,
                track_id=phrase.track_id,
                start_beat=phrase.start_beat,
                end_beat=phrase.end_beat,
            )

    updates = list(first_touch.values())

    logger.info(
        "✅ Replay plan: %d variations, %d phrases, %d regions",
        len(variation_ids),
        len(phrase_ids),
        len(updates),
    )

    return ReplayPlan(
        ordered_variation_ids=variation_ids,
        ordered_phrase_ids=phrase_ids,
        region_updates=updates,
        lineage=lineage,
    )
| 116 | |
| 117 | |
| 118 | # ── HEAD Snapshot Reconstruction (Phase 6) ──────────────────────────────── |
| 119 | |
| 120 | |
@dataclass(eq=True, frozen=True)
class HeadSnapshot:
    """Partial musical state rebuilt from a variation's persisted data.

    Only notes that Muse itself committed (added or modified) appear here.
    Notes that were already in a region before Muse touched it, and were
    left unchanged, are absent. Controller data (CC, pitch bends,
    aftertouch) comes from the persisted phrases. Every map is keyed by
    region id.
    """

    # Variation the snapshot was reconstructed at.
    variation_id: str
    notes: RegionNotesMap
    cc: RegionCCMap
    pitch_bends: RegionPitchBendMap
    aftertouch: RegionAftertouchMap
    # region id -> id of the track owning that region.
    track_regions: dict[str, str]
    # region id -> start beat of that region.
    region_start_beats: dict[str, float]
| 139 | |
| 140 | |
def _discard_first_match(region_notes: list[NoteDict], note: NoteDict) -> None:
    """Remove the first entry of *region_notes* equal to *note*, if present."""
    try:
        region_notes.remove(note)
    except ValueError:
        # The note was never recorded (e.g. it predates Muse involvement),
        # so there is nothing to retire.
        pass


async def _fold_lineage_snapshot(
    session: AsyncSession,
    lineage: list[HistoryNode],
    variation_id: str,
) -> HeadSnapshot:
    """Replay every variation in *lineage* (root first) into one snapshot.

    Note changes are applied cumulatively, so later variations override
    earlier ones:

    - ``added``: the ``after`` note is appended.
    - ``modified``: the ``before`` note is dropped if an earlier variation
      recorded it, then the ``after`` note is appended.
    - ``removed``: the ``before`` note is dropped if an earlier variation
      recorded it.

    Notes that existed before Muse touched a region are never recorded, so
    a ``before`` with no matching entry is ignored.

    Controller events (CC, pitch bends, aftertouch) are concatenated in
    lineage order. For each region, the track id and start beat come from
    the last phrase that touched it. Variations that fail to load are
    skipped.
    """
    notes: RegionNotesMap = {}
    cc: RegionCCMap = {}
    pitch_bends: RegionPitchBendMap = {}
    aftertouch: RegionAftertouchMap = {}
    track_regions: dict[str, str] = {}
    region_start_beats: dict[str, float] = {}

    for node in lineage:
        variation = await muse_repository.load_variation(
            session, node.variation_id,
        )
        if variation is None:
            continue

        for phrase in variation.phrases:
            rid = phrase.region_id
            track_regions[rid] = phrase.track_id
            region_start_beats[rid] = phrase.start_beat

            region_notes = notes.setdefault(rid, [])
            for nc in phrase.note_changes:
                # Retire the superseded note first. Otherwise a modification
                # would duplicate (not replace) a note added by an earlier
                # variation, and a removal would leave it in place.
                # NOTE(review): assumes NoteChange exposes ``before`` with
                # ``to_note_dict()``, mirroring ``after`` — confirm against
                # maestro.models.variation.
                if nc.change_type in ("modified", "removed") and nc.before:
                    _discard_first_match(region_notes, nc.before.to_note_dict())
                if nc.change_type in ("added", "modified") and nc.after:
                    region_notes.append(nc.after.to_note_dict())

            cc.setdefault(rid, []).extend(phrase.cc_events)
            pitch_bends.setdefault(rid, []).extend(phrase.pitch_bends)
            aftertouch.setdefault(rid, []).extend(phrase.aftertouch)

    return HeadSnapshot(
        variation_id=variation_id,
        notes=notes,
        cc=cc,
        pitch_bends=pitch_bends,
        aftertouch=aftertouch,
        track_regions=track_regions,
        region_start_beats=region_start_beats,
    )


async def reconstruct_head_snapshot(
    session: AsyncSession,
    project_id: str,
) -> HeadSnapshot | None:
    """Reconstruct a snapshot from the HEAD variation's persisted phrases.

    Walks the full lineage from root to HEAD and folds each variation's
    note changes into the cumulative note state of every Muse-touched
    region:

    - ``added``: the ``after`` note is included.
    - ``modified``: the ``after`` note replaces the ``before`` note when an
      earlier variation recorded it; otherwise it is simply included.
    - ``removed``: the ``before`` note is dropped when an earlier variation
      recorded it; nothing is added.

    This is a *partial* reconstruction. It only reflects notes that Muse
    created or modified; notes that existed before Muse involvement are
    not tracked.

    Returns None if no HEAD exists for the project, or if HEAD has no
    lineage.
    """
    head = await muse_repository.get_head(session, project_id)
    if head is None:
        return None

    lineage = await muse_repository.get_lineage(session, head.variation_id)
    if not lineage:
        return None

    snapshot = await _fold_lineage_snapshot(session, lineage, head.variation_id)

    logger.info(
        "✅ HEAD snapshot reconstructed: %d regions, %d notes, %d cc, %d pb, %d at",
        len(snapshot.notes),
        sum(len(n) for n in snapshot.notes.values()),
        sum(len(e) for e in snapshot.cc.values()),
        sum(len(e) for e in snapshot.pitch_bends.values()),
        sum(len(e) for e in snapshot.aftertouch.values()),
    )

    return snapshot


async def reconstruct_variation_snapshot(
    session: AsyncSession,
    variation_id: str,
) -> HeadSnapshot | None:
    """Reconstruct a snapshot at any variation (not necessarily HEAD).

    Uses the same cumulative lineage fold as ``reconstruct_head_snapshot``,
    but targets ``variation_id`` instead of the project's current HEAD.

    Returns None if the variation does not exist (its lineage is empty).
    """
    lineage = await muse_repository.get_lineage(session, variation_id)
    if not lineage:
        return None

    return await _fold_lineage_snapshot(session, lineage, variation_id)