muse_drift.py
python
| 1 | """Muse Drift Detection Engine — 'git status' for music. |
| 2 | |
| 3 | Compares a HEAD snapshot (from persisted variation history) against |
| 4 | a working snapshot (live StateStore capture) to produce a deterministic |
| 5 | DriftReport describing what changed since the last commit. |
| 6 | |
| 7 | Diffs notes AND controller data (CC, pitch bends, aftertouch). |
| 8 | |
| 9 | Pure data — no side effects, no mutations. |
| 10 | |
| 11 | Boundary rules: |
| 12 | - Must NOT import StateStore, EntityRegistry, or get_or_create_store. |
| 13 | - Must NOT import executor modules or app.core.executor.*. |
| 14 | - Must NOT import LLM handlers or maestro_* modules. |
| 15 | - May import note_matching from VariationService (pure diff logic). |
| 16 | """ |
| 17 | |
| 18 | from __future__ import annotations |
| 19 | |
| 20 | import hashlib |
| 21 | import json |
| 22 | import logging |
| 23 | from collections.abc import Mapping, Sequence |
| 24 | from dataclasses import dataclass, field |
| 25 | from enum import Enum |
| 26 | from typing import Literal |
| 27 | |
| 28 | from typing_extensions import TypedDict |
| 29 | |
| 30 | from maestro.contracts.json_types import ( |
| 31 | AftertouchDict, |
| 32 | CCEventDict, |
| 33 | NoteDict, |
| 34 | PitchBendDict, |
| 35 | RegionAftertouchMap, |
| 36 | RegionCCMap, |
| 37 | RegionMetadataDB, |
| 38 | RegionNotesMap, |
| 39 | RegionPitchBendMap, |
| 40 | ) |
| 41 | from maestro.services.variation.note_matching import ( |
| 42 | match_notes, |
| 43 | match_cc_events, |
| 44 | match_pitch_bends, |
| 45 | match_aftertouch, |
| 46 | ) |
| 47 | |
logger = logging.getLogger(__name__)

# Upper bound on the number of note-change samples collected per region
# when building a RegionDriftSummary in compute_drift_report.
MAX_SAMPLE_CHANGES = 5
| 51 | |
| 52 | |
class SampleChange(TypedDict, total=False):
    """One sampled note-level difference, kept for human-readable drift output.

    Declared with ``total=False``, so no key is required by the type itself.
    ``type`` names the kind of change. Samples of kind "added" or "removed"
    carry the affected note under ``note``; samples of kind "modified" carry
    the two sides of the change under ``before`` and ``after``.
    """

    # Kind of change this sample describes.
    type: Literal["added", "removed", "modified"]
    # The note that was added or removed.
    note: NoteDict | None
    # The note as it was before a modification.
    before: NoteDict | None
    # The note as it is after a modification.
    after: NoteDict | None
| 63 | |
| 64 | |
class DriftSeverity(str, Enum):
    """Degree of divergence between the working tree and HEAD.

    Members are also ``str`` instances, so each one compares equal to its
    raw string value.
    """

    CLEAN = "clean"
    DIRTY = "dirty"
    DIVERGED = "diverged"
| 71 | |
| 72 | |
@dataclass(frozen=True)
class RegionDriftSummary:
    """Immutable change tally for a single region.

    Holds added/removed/modified counters for notes and for each controller
    kind (CC, pitch bend, aftertouch), a tuple of sampled note changes, and
    the HEAD / working fingerprints of the region.
    """

    region_id: str
    track_id: str

    # Note counters
    added: int = 0
    removed: int = 0
    modified: int = 0

    # CC counters
    cc_added: int = 0
    cc_removed: int = 0
    cc_modified: int = 0

    # Pitch-bend counters
    pb_added: int = 0
    pb_removed: int = 0
    pb_modified: int = 0

    # Aftertouch counters
    at_added: int = 0
    at_removed: int = 0
    at_modified: int = 0

    sample_changes: tuple[SampleChange, ...] = ()
    head_fingerprint: str = ""
    working_fingerprint: str = ""

    @property
    def is_clean(self) -> bool:
        """Return ``True`` only when every note and controller counter equals zero."""
        counters = (
            self.added, self.removed, self.modified,
            self.cc_added, self.cc_removed, self.cc_modified,
            self.pb_added, self.pb_removed, self.pb_modified,
            self.at_added, self.at_removed, self.at_modified,
        )
        return all(count == 0 for count in counters)
| 109 | |
| 110 | |
@dataclass(frozen=True)
class DriftReport:
    """Immutable result of comparing the working tree against HEAD.

    Carries the overall severity, the region ids that changed / were added /
    were deleted, and one ``RegionDriftSummary`` per region covering notes
    and controller data (CC, pitch bends, aftertouch).
    """

    project_id: str
    head_variation_id: str
    severity: DriftSeverity
    is_clean: bool
    changed_regions: tuple[str, ...] = ()
    added_regions: tuple[str, ...] = ()
    deleted_regions: tuple[str, ...] = ()
    region_summaries: dict[str, RegionDriftSummary] = field(default_factory=dict)

    @property
    def total_changes(self) -> int:
        """Return the grand total of note and controller changes over all regions."""
        total = 0
        for summary in self.region_summaries.values():
            total += summary.added + summary.removed + summary.modified
            total += summary.cc_added + summary.cc_removed + summary.cc_modified
            total += summary.pb_added + summary.pb_removed + summary.pb_modified
            total += summary.at_added + summary.at_removed + summary.at_modified
        return total

    def requires_user_action(self) -> bool:
        """Return ``True`` when the severity is anything other than CLEAN (commit should be blocked)."""
        return self.severity != DriftSeverity.CLEAN
| 141 | |
| 142 | |
@dataclass(frozen=True)
class CommitConflictPayload:
    """Compact conflict summary intended for 409 responses.

    Built from a ``DriftReport`` but without ``sample_changes`` or the full
    ``region_summaries``, so the payload stays small.
    """

    project_id: str
    head_variation_id: str
    severity: str
    changed_regions: tuple[str, ...]
    added_regions: tuple[str, ...]
    deleted_regions: tuple[str, ...]
    total_changes: int
    fingerprint_delta: dict[str, tuple[str, str]]

    @classmethod
    def from_drift_report(cls, report: DriftReport) -> "CommitConflictPayload":
        """Build the compact payload from a full ``DriftReport``.

        ``fingerprint_delta`` contains one entry per region whose summary is
        not clean, mapping the region id to
        ``(head_fingerprint, working_fingerprint)``. Clean regions are left
        out. ``severity`` is stored as the enum's string value; the remaining
        fields are copied from the report unchanged.
        """
        dirty_fingerprints: dict[str, tuple[str, str]] = {
            region_id: (summary.head_fingerprint, summary.working_fingerprint)
            for region_id, summary in report.region_summaries.items()
            if not summary.is_clean
        }
        return cls(
            project_id=report.project_id,
            head_variation_id=report.head_variation_id,
            severity=report.severity.value,
            changed_regions=report.changed_regions,
            added_regions=report.added_regions,
            deleted_regions=report.deleted_regions,
            total_changes=report.total_changes,
            fingerprint_delta=dirty_fingerprints,
        )
| 183 | |
| 184 | |
| 185 | def _fingerprint(events: Sequence[Mapping[str, object]]) -> str: |
| 186 | """Stable hash of a note or event list for cache-friendly comparison.""" |
| 187 | canonical = sorted( |
| 188 | events, |
| 189 | key=lambda e: ( |
| 190 | e.get("pitch", 0), |
| 191 | e.get("cc", 0), |
| 192 | e.get("start_beat", e.get("beat", 0.0)), |
| 193 | e.get("value", 0), |
| 194 | ), |
| 195 | ) |
| 196 | raw = json.dumps(canonical, sort_keys=True, default=str) |
| 197 | return hashlib.sha256(raw.encode()).hexdigest()[:16] |
| 198 | |
| 199 | |
def _combined_fingerprint(
    notes: Sequence[Mapping[str, object]],
    cc: Sequence[Mapping[str, object]],
    pb: Sequence[Mapping[str, object]],
    at: Sequence[Mapping[str, object]],
) -> str:
    """Fold a region's per-kind fingerprints into a single 16-hex-character digest.

    Each event list (notes, CC, pitch bends, aftertouch) is fingerprinted on
    its own; the four results are placed in a small JSON object with sorted
    keys, which is then hashed with SHA-256 and truncated.
    """
    per_kind = {
        "n": _fingerprint(notes),
        "c": _fingerprint(cc),
        "p": _fingerprint(pb),
        "a": _fingerprint(at),
    }
    payload = json.dumps(per_kind, sort_keys=True)
    digest = hashlib.sha256(payload.encode())
    return digest.hexdigest()[:16]
| 214 | |
| 215 | |
def compute_drift_report(
    *,
    project_id: str,
    head_variation_id: str,
    head_snapshot_notes: RegionNotesMap,
    working_snapshot_notes: RegionNotesMap,
    track_regions: dict[str, str],
    head_cc: RegionCCMap | None = None,
    working_cc: RegionCCMap | None = None,
    head_pb: RegionPitchBendMap | None = None,
    working_pb: RegionPitchBendMap | None = None,
    head_at: RegionAftertouchMap | None = None,
    working_at: RegionAftertouchMap | None = None,
    region_metadata: dict[str, RegionMetadataDB] | None = None,
) -> DriftReport:
    """Diff the HEAD snapshot against the working snapshot, notes and controllers alike.

    Performs no database or StateStore access; the only side effect in this
    function is one info-level log line. Matching is delegated to the
    note-matching helpers imported by this module.

    A region id counts as present on a side when it is a key in that side's
    note map or any of its controller maps. Regions present on both sides are
    diffed; regions present on only one side are reported as added or deleted,
    with every event on that side counted as added or removed.

    Args:
        project_id: Project identifier, copied into the report.
        head_variation_id: The HEAD variation being compared against.
        head_snapshot_notes: Notes per region from HEAD.
        working_snapshot_notes: Notes per region from the working tree.
        track_regions: Mapping of region_id to track_id; regions missing from
            it get the track id ``"unknown"``.
        head_cc / working_cc: CC events per region (``None`` is treated as empty).
        head_pb / working_pb: Pitch bend events per region (``None`` is treated as empty).
        head_at / working_at: Aftertouch events per region (``None`` is treated as empty).
        region_metadata: Optional region metadata; accepted but not read here.

    Returns:
        A ``DriftReport`` whose severity is CLEAN when no region changed, was
        added, or was deleted, and DIRTY otherwise.
    """
    # Missing controller maps behave like empty ones.
    cc_head, cc_work = head_cc or {}, working_cc or {}
    pb_head, pb_work = head_pb or {}, working_pb or {}
    at_head, at_work = head_at or {}, working_at or {}

    head_ids = set(head_snapshot_notes).union(cc_head, pb_head, at_head)
    work_ids = set(working_snapshot_notes).union(cc_work, pb_work, at_work)

    added_regions = sorted(work_ids - head_ids)
    deleted_regions = sorted(head_ids - work_ids)

    changed_regions: list[str] = []
    region_summaries: dict[str, RegionDriftSummary] = {}

    def _tally(matches) -> tuple[int, int, int]:
        """Count matches flagged as added, removed, and modified (each flag checked independently)."""
        n_added = n_removed = n_modified = 0
        for match in matches:
            if match.is_added:
                n_added += 1
            if match.is_removed:
                n_removed += 1
            if match.is_modified:
                n_modified += 1
        return n_added, n_removed, n_modified

    # ── Regions on both sides: diff notes + controllers ───────────────
    for rid in sorted(head_ids & work_ids):
        track_id = track_regions.get(rid, "unknown")

        head_side = (
            head_snapshot_notes.get(rid, []),
            cc_head.get(rid, []),
            pb_head.get(rid, []),
            at_head.get(rid, []),
        )
        work_side = (
            working_snapshot_notes.get(rid, []),
            cc_work.get(rid, []),
            pb_work.get(rid, []),
            at_work.get(rid, []),
        )

        head_fp = _combined_fingerprint(*head_side)
        working_fp = _combined_fingerprint(*work_side)

        # Identical fingerprints: record an all-zero summary and skip matching.
        if head_fp == working_fp:
            region_summaries[rid] = RegionDriftSummary(
                region_id=rid,
                track_id=track_id,
                head_fingerprint=head_fp,
                working_fingerprint=working_fp,
            )
            continue

        h_notes, h_cc, h_pb, h_at = head_side
        w_notes, w_cc, w_pb, w_at = work_side

        note_matches = match_notes(h_notes, w_notes)
        note_counts = _tally(note_matches)
        cc_counts = _tally(match_cc_events(h_cc, w_cc))
        pb_counts = _tally(match_pitch_bends(h_pb, w_pb))
        at_counts = _tally(match_aftertouch(h_at, w_at))

        # Samples come from note matches only, capped at MAX_SAMPLE_CHANGES.
        samples: list[SampleChange] = []
        for match in note_matches:
            if len(samples) >= MAX_SAMPLE_CHANGES:
                break
            if match.is_added:
                samples.append(SampleChange(type="added", note=match.proposed_note))
            elif match.is_removed:
                samples.append(SampleChange(type="removed", note=match.base_note))
            elif match.is_modified:
                samples.append(
                    SampleChange(
                        type="modified",
                        before=match.base_note,
                        after=match.proposed_note,
                    )
                )

        # Fingerprints may differ while the matchers report nothing; only
        # regions with at least one counted change are listed as changed.
        if sum(note_counts + cc_counts + pb_counts + at_counts) > 0:
            changed_regions.append(rid)

        n_adds, n_rems, n_mods = note_counts
        cc_adds, cc_rems, cc_mods = cc_counts
        pb_adds, pb_rems, pb_mods = pb_counts
        at_adds, at_rems, at_mods = at_counts
        region_summaries[rid] = RegionDriftSummary(
            region_id=rid,
            track_id=track_id,
            added=n_adds,
            removed=n_rems,
            modified=n_mods,
            cc_added=cc_adds,
            cc_removed=cc_rems,
            cc_modified=cc_mods,
            pb_added=pb_adds,
            pb_removed=pb_rems,
            pb_modified=pb_mods,
            at_added=at_adds,
            at_removed=at_rems,
            at_modified=at_mods,
            sample_changes=tuple(samples),
            head_fingerprint=head_fp,
            working_fingerprint=working_fp,
        )

    # ── Regions only in the working tree: everything counts as added ──
    for rid in added_regions:
        new_notes = working_snapshot_notes.get(rid, [])
        new_cc = cc_work.get(rid, [])
        new_pb = pb_work.get(rid, [])
        new_at = at_work.get(rid, [])
        region_summaries[rid] = RegionDriftSummary(
            region_id=rid,
            track_id=track_regions.get(rid, "unknown"),
            added=len(new_notes),
            cc_added=len(new_cc),
            pb_added=len(new_pb),
            at_added=len(new_at),
            working_fingerprint=_combined_fingerprint(new_notes, new_cc, new_pb, new_at),
        )

    # ── Regions only in HEAD: everything counts as removed ────────────
    for rid in deleted_regions:
        old_notes = head_snapshot_notes.get(rid, [])
        old_cc = cc_head.get(rid, [])
        old_pb = pb_head.get(rid, [])
        old_at = at_head.get(rid, [])
        region_summaries[rid] = RegionDriftSummary(
            region_id=rid,
            track_id=track_regions.get(rid, "unknown"),
            removed=len(old_notes),
            cc_removed=len(old_cc),
            pb_removed=len(old_pb),
            at_removed=len(old_at),
            head_fingerprint=_combined_fingerprint(old_notes, old_cc, old_pb, old_at),
        )

    is_clean = not (changed_regions or added_regions or deleted_regions)
    severity = DriftSeverity.CLEAN if is_clean else DriftSeverity.DIRTY

    logger.info(
        "✅ Drift report: %s (%d changed, %d added, %d deleted regions)",
        severity.value,
        len(changed_regions),
        len(added_regions),
        len(deleted_regions),
    )

    return DriftReport(
        project_id=project_id,
        head_variation_id=head_variation_id,
        severity=severity,
        is_clean=is_clean,
        changed_regions=tuple(changed_regions),
        added_regions=tuple(added_regions),
        deleted_regions=tuple(deleted_regions),
        region_summaries=region_summaries,
    )