cgcardona / muse public
muse_drift.py python
391 lines 13.7 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """Muse Drift Detection Engine — 'git status' for music.
2
3 Compares a HEAD snapshot (from persisted variation history) against
4 a working snapshot (live StateStore capture) to produce a deterministic
5 DriftReport describing what changed since the last commit.
6
7 Diffs notes AND controller data (CC, pitch bends, aftertouch).
8
9 Pure data — no side effects, no mutations.
10
11 Boundary rules:
12 - Must NOT import StateStore, EntityRegistry, or get_or_create_store.
13 - Must NOT import executor modules or app.core.executor.*.
14 - Must NOT import LLM handlers or maestro_* modules.
15 - May import note_matching from VariationService (pure diff logic).
16 """
17
18 from __future__ import annotations
19
20 import hashlib
21 import json
22 import logging
23 from collections.abc import Mapping, Sequence
24 from dataclasses import dataclass, field
25 from enum import Enum
26 from typing import Literal
27
28 from typing_extensions import TypedDict
29
30 from maestro.contracts.json_types import (
31 AftertouchDict,
32 CCEventDict,
33 NoteDict,
34 PitchBendDict,
35 RegionAftertouchMap,
36 RegionCCMap,
37 RegionMetadataDB,
38 RegionNotesMap,
39 RegionPitchBendMap,
40 )
41 from maestro.services.variation.note_matching import (
42 match_notes,
43 match_cc_events,
44 match_pitch_bends,
45 match_aftertouch,
46 )
47
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Upper bound on how many entries are collected into a region's
# ``sample_changes`` when a drift report is computed.
MAX_SAMPLE_CHANGES = 5
51
52
class SampleChange(TypedDict, total=False):
    """One note-level difference, kept as a readable example of a region's drift.

    Every key is optional at the type level (``total=False``). In practice
    ``type`` is always set; ``"added"`` and ``"removed"`` samples carry
    ``note``, while ``"modified"`` samples carry ``before`` and ``after``.
    """

    # Which kind of change this sample illustrates.
    type: Literal["added", "removed", "modified"]
    # The note that was added or removed.
    note: NoteDict | None
    # A modified note as it was before the change.
    before: NoteDict | None
    # A modified note as it is after the change.
    after: NoteDict | None
63
64
class DriftSeverity(str, Enum):
    """Degree to which the working tree differs from HEAD.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    # No changed, added, or deleted regions.
    CLEAN = "clean"
    # At least one region changed, was added, or was deleted.
    DIRTY = "dirty"
    # Declared for callers; not assigned anywhere in this module.
    DIVERGED = "diverged"
71
72
@dataclass(frozen=True)
class RegionDriftSummary:
    """Immutable change counts for one region, covering notes and controllers.

    Each data type (notes, CC, pitch bends, aftertouch) has its own
    added / removed / modified counters, all defaulting to zero.
    """

    region_id: str
    track_id: str

    # Note changes.
    added: int = 0
    removed: int = 0
    modified: int = 0

    # CC event changes.
    cc_added: int = 0
    cc_removed: int = 0
    cc_modified: int = 0

    # Pitch-bend event changes.
    pb_added: int = 0
    pb_removed: int = 0
    pb_modified: int = 0

    # Aftertouch event changes.
    at_added: int = 0
    at_removed: int = 0
    at_modified: int = 0

    # Example note changes for display; empty by default.
    sample_changes: tuple[SampleChange, ...] = ()
    # Content fingerprints of the region on each side; "" when not supplied.
    head_fingerprint: str = ""
    working_fingerprint: str = ""

    @property
    def is_clean(self) -> bool:
        """Return ``True`` only if every one of the twelve counters is zero."""
        counters = (
            self.added, self.removed, self.modified,
            self.cc_added, self.cc_removed, self.cc_modified,
            self.pb_added, self.pb_removed, self.pb_modified,
            self.at_added, self.at_removed, self.at_modified,
        )
        return all(counter == 0 for counter in counters)
109
110
@dataclass(frozen=True)
class DriftReport:
    """Immutable description of how the working tree differs from HEAD.

    Tracks which regions changed, appeared, or disappeared, and holds a
    per-region summary spanning notes, CC, pitch bends, and aftertouch.
    """

    project_id: str
    head_variation_id: str
    severity: DriftSeverity
    is_clean: bool
    # Region ids grouped by what happened to them.
    changed_regions: tuple[str, ...] = ()
    added_regions: tuple[str, ...] = ()
    deleted_regions: tuple[str, ...] = ()
    # Per-region counts keyed by region id.
    region_summaries: dict[str, RegionDriftSummary] = field(default_factory=dict)

    @property
    def total_changes(self) -> int:
        """Return the grand total of every counter over all region summaries."""
        total = 0
        for summary in self.region_summaries.values():
            note_changes = summary.added + summary.removed + summary.modified
            cc_changes = summary.cc_added + summary.cc_removed + summary.cc_modified
            pb_changes = summary.pb_added + summary.pb_removed + summary.pb_modified
            at_changes = summary.at_added + summary.at_removed + summary.at_modified
            total += note_changes + cc_changes + pb_changes + at_changes
        return total

    def requires_user_action(self) -> bool:
        """Return ``True`` for any severity other than ``CLEAN`` (commit should be blocked)."""
        return self.severity != DriftSeverity.CLEAN
141
142
@dataclass(frozen=True)
class CommitConflictPayload:
    """Compact conflict summary intended for 409 response bodies.

    Built from a ``DriftReport`` but omits ``sample_changes`` and the full
    ``region_summaries`` so the payload stays small.
    """

    project_id: str
    head_variation_id: str
    # String value of the report's severity.
    severity: str
    changed_regions: tuple[str, ...]
    added_regions: tuple[str, ...]
    deleted_regions: tuple[str, ...]
    total_changes: int
    # region id -> (head_fingerprint, working_fingerprint), non-clean regions only.
    fingerprint_delta: dict[str, tuple[str, str]]

    @classmethod
    def from_drift_report(cls, report: DriftReport) -> "CommitConflictPayload":
        """Build the compact payload from a full ``DriftReport``.

        Scalar and tuple fields are copied from the report, and ``severity``
        is reduced to its string value. ``fingerprint_delta`` contains one
        ``(head_fingerprint, working_fingerprint)`` pair for each region
        whose summary is not clean.
        """
        dirty_fingerprints: dict[str, tuple[str, str]] = {
            region_id: (summary.head_fingerprint, summary.working_fingerprint)
            for region_id, summary in report.region_summaries.items()
            if not summary.is_clean
        }
        return cls(
            project_id=report.project_id,
            head_variation_id=report.head_variation_id,
            severity=report.severity.value,
            changed_regions=report.changed_regions,
            added_regions=report.added_regions,
            deleted_regions=report.deleted_regions,
            total_changes=report.total_changes,
            fingerprint_delta=dirty_fingerprints,
        )
183
184
185 def _fingerprint(events: Sequence[Mapping[str, object]]) -> str:
186 """Stable hash of a note or event list for cache-friendly comparison."""
187 canonical = sorted(
188 events,
189 key=lambda e: (
190 e.get("pitch", 0),
191 e.get("cc", 0),
192 e.get("start_beat", e.get("beat", 0.0)),
193 e.get("value", 0),
194 ),
195 )
196 raw = json.dumps(canonical, sort_keys=True, default=str)
197 return hashlib.sha256(raw.encode()).hexdigest()[:16]
198
199
def _combined_fingerprint(
    notes: Sequence[Mapping[str, object]],
    cc: Sequence[Mapping[str, object]],
    pb: Sequence[Mapping[str, object]],
    at: Sequence[Mapping[str, object]],
) -> str:
    """Return a single 16-hex-character fingerprint for all of a region's data.

    Each event list (notes, CC, pitch bends, aftertouch) is fingerprinted
    separately. The four digests are placed in a JSON object under the keys
    ``n``, ``c``, ``p`` and ``a``, and that object is hashed with SHA-256.
    """
    per_kind = {
        "n": _fingerprint(notes),
        "c": _fingerprint(cc),
        "p": _fingerprint(pb),
        "a": _fingerprint(at),
    }
    payload = json.dumps(per_kind, sort_keys=True)
    digest = hashlib.sha256(payload.encode())
    return digest.hexdigest()[:16]
214
215
def compute_drift_report(
    *,
    project_id: str,
    head_variation_id: str,
    head_snapshot_notes: RegionNotesMap,
    working_snapshot_notes: RegionNotesMap,
    track_regions: dict[str, str],
    head_cc: RegionCCMap | None = None,
    working_cc: RegionCCMap | None = None,
    head_pb: RegionPitchBendMap | None = None,
    working_pb: RegionPitchBendMap | None = None,
    head_at: RegionAftertouchMap | None = None,
    working_at: RegionAftertouchMap | None = None,
    region_metadata: dict[str, RegionMetadataDB] | None = None,
) -> DriftReport:
    """Diff the HEAD snapshot against the working snapshot, region by region.

    A region id is known to a side if it is a key in any of that side's four
    maps (notes, CC, pitch bends, aftertouch). Regions known to both sides are
    compared by combined fingerprint first; the matching functions are run
    only when the fingerprints differ. Regions known to one side only are
    reported as added or deleted, with every event counted as added or
    removed respectively.

    Args:
        project_id: Project identifier, copied into the report.
        head_variation_id: Identifier of the HEAD variation, copied into the report.
        head_snapshot_notes: Notes per region at HEAD.
        working_snapshot_notes: Notes per region in the working tree.
        track_regions: Region id to track id; missing regions get ``"unknown"``.
        head_cc / working_cc: CC events per region (``None`` is treated as empty).
        head_pb / working_pb: Pitch-bend events per region (``None`` is treated as empty).
        head_at / working_at: Aftertouch events per region (``None`` is treated as empty).
        region_metadata: Accepted for interface compatibility; not read here.

    Returns:
        A ``DriftReport`` whose severity is ``CLEAN`` when no region changed,
        was added, or was deleted, and ``DIRTY`` otherwise. Summaries are
        inserted for common regions (sorted), then added, then deleted regions.
    """
    head_cc_map = head_cc or {}
    work_cc_map = working_cc or {}
    head_pb_map = head_pb or {}
    work_pb_map = working_pb or {}
    head_at_map = head_at or {}
    work_at_map = working_at or {}

    head_ids = set(head_snapshot_notes).union(head_cc_map, head_pb_map, head_at_map)
    work_ids = set(working_snapshot_notes).union(work_cc_map, work_pb_map, work_at_map)

    added_regions = sorted(work_ids - head_ids)
    deleted_regions = sorted(head_ids - work_ids)

    changed_regions: list[str] = []
    region_summaries: dict[str, RegionDriftSummary] = {}

    # Regions present on both sides: compare content.
    for region_id in sorted(head_ids & work_ids):
        owner_track = track_regions.get(region_id, "unknown")

        head_notes = head_snapshot_notes.get(region_id, [])
        work_notes = working_snapshot_notes.get(region_id, [])
        head_cc_events = head_cc_map.get(region_id, [])
        work_cc_events = work_cc_map.get(region_id, [])
        head_bends = head_pb_map.get(region_id, [])
        work_bends = work_pb_map.get(region_id, [])
        head_touch = head_at_map.get(region_id, [])
        work_touch = work_at_map.get(region_id, [])

        fp_head = _combined_fingerprint(head_notes, head_cc_events, head_bends, head_touch)
        fp_work = _combined_fingerprint(work_notes, work_cc_events, work_bends, work_touch)

        if fp_head == fp_work:
            # Identical content: record a zero-count summary and skip matching.
            region_summaries[region_id] = RegionDriftSummary(
                region_id=region_id,
                track_id=owner_track,
                head_fingerprint=fp_head,
                working_fingerprint=fp_work,
            )
            continue

        note_matches = match_notes(head_notes, work_notes)
        cc_matches = match_cc_events(head_cc_events, work_cc_events)
        bend_matches = match_pitch_bends(head_bends, work_bends)
        touch_matches = match_aftertouch(head_touch, work_touch)

        notes_added = sum(1 for match in note_matches if match.is_added)
        notes_removed = sum(1 for match in note_matches if match.is_removed)
        notes_modified = sum(1 for match in note_matches if match.is_modified)

        cc_added = sum(1 for match in cc_matches if match.is_added)
        cc_removed = sum(1 for match in cc_matches if match.is_removed)
        cc_modified = sum(1 for match in cc_matches if match.is_modified)

        pb_added = sum(1 for match in bend_matches if match.is_added)
        pb_removed = sum(1 for match in bend_matches if match.is_removed)
        pb_modified = sum(1 for match in bend_matches if match.is_modified)

        at_added = sum(1 for match in touch_matches if match.is_added)
        at_removed = sum(1 for match in touch_matches if match.is_removed)
        at_modified = sum(1 for match in touch_matches if match.is_modified)

        # Counts are non-negative, so "any non-zero" equals "total above zero".
        has_changes = any((
            notes_added, notes_removed, notes_modified,
            cc_added, cc_removed, cc_modified,
            pb_added, pb_removed, pb_modified,
            at_added, at_removed, at_modified,
        ))

        # Only note matches contribute samples, up to MAX_SAMPLE_CHANGES.
        samples: list[SampleChange] = []
        for match in note_matches:
            if len(samples) >= MAX_SAMPLE_CHANGES:
                break
            if match.is_added:
                samples.append(SampleChange(type="added", note=match.proposed_note))
            elif match.is_removed:
                samples.append(SampleChange(type="removed", note=match.base_note))
            elif match.is_modified:
                samples.append(
                    SampleChange(
                        type="modified",
                        before=match.base_note,
                        after=match.proposed_note,
                    )
                )

        if has_changes:
            changed_regions.append(region_id)

        region_summaries[region_id] = RegionDriftSummary(
            region_id=region_id,
            track_id=owner_track,
            added=notes_added,
            removed=notes_removed,
            modified=notes_modified,
            cc_added=cc_added,
            cc_removed=cc_removed,
            cc_modified=cc_modified,
            pb_added=pb_added,
            pb_removed=pb_removed,
            pb_modified=pb_modified,
            at_added=at_added,
            at_removed=at_removed,
            at_modified=at_modified,
            sample_changes=tuple(samples),
            head_fingerprint=fp_head,
            working_fingerprint=fp_work,
        )

    # Regions that exist only in the working tree: everything counts as added.
    for region_id in added_regions:
        work_notes = working_snapshot_notes.get(region_id, [])
        work_cc_events = work_cc_map.get(region_id, [])
        work_bends = work_pb_map.get(region_id, [])
        work_touch = work_at_map.get(region_id, [])
        region_summaries[region_id] = RegionDriftSummary(
            region_id=region_id,
            track_id=track_regions.get(region_id, "unknown"),
            added=len(work_notes),
            cc_added=len(work_cc_events),
            pb_added=len(work_bends),
            at_added=len(work_touch),
            working_fingerprint=_combined_fingerprint(
                work_notes, work_cc_events, work_bends, work_touch
            ),
        )

    # Regions that exist only at HEAD: everything counts as removed.
    for region_id in deleted_regions:
        head_notes = head_snapshot_notes.get(region_id, [])
        head_cc_events = head_cc_map.get(region_id, [])
        head_bends = head_pb_map.get(region_id, [])
        head_touch = head_at_map.get(region_id, [])
        region_summaries[region_id] = RegionDriftSummary(
            region_id=region_id,
            track_id=track_regions.get(region_id, "unknown"),
            removed=len(head_notes),
            cc_removed=len(head_cc_events),
            pb_removed=len(head_bends),
            at_removed=len(head_touch),
            head_fingerprint=_combined_fingerprint(
                head_notes, head_cc_events, head_bends, head_touch
            ),
        )

    is_clean = not (changed_regions or added_regions or deleted_regions)
    if is_clean:
        severity = DriftSeverity.CLEAN
    else:
        severity = DriftSeverity.DIRTY

    logger.info(
        "✅ Drift report: %s (%d changed, %d added, %d deleted regions)",
        severity.value,
        len(changed_regions),
        len(added_regions),
        len(deleted_regions),
    )

    return DriftReport(
        project_id=project_id,
        head_variation_id=head_variation_id,
        severity=severity,
        is_clean=is_clean,
        changed_regions=tuple(changed_regions),
        added_regions=tuple(added_regions),
        deleted_regions=tuple(deleted_regions),
        region_summaries=region_summaries,
    )