cgcardona / muse public
test_muse_drift_controllers.py python
524 lines 18.3 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """Tests for controller drift detection (CC, pitch bends, aftertouch).
2
3 Verifies:
4 - Clean controller state detection.
5 - CC event drift (add / remove / modify).
6 - Pitch bend drift.
7 - Aftertouch drift.
8 - HEAD snapshot reconstruction fidelity for controllers.
9 - Controller matching boundary isolation.
10 """
11 from __future__ import annotations
12
13 import ast
14 import uuid
15 from collections.abc import AsyncGenerator
16 import pytest
17 from maestro.contracts.json_types import AftertouchDict, CCEventDict, NoteDict, PitchBendDict
18 from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
19
20 from maestro.db.database import Base
21 from maestro.db import muse_models # noqa: F401 — register tables
22 from maestro.models.variation import (
23 MidiNoteSnapshot,
24 NoteChange,
25 Phrase,
26 Variation,
27 )
28 from maestro.services import muse_repository
29 from maestro.services.muse_drift import (
30 DriftSeverity,
31 compute_drift_report,
32 )
33 from maestro.services.muse_replay import reconstruct_head_snapshot
34 from maestro.services.variation.note_matching import (
35 EventMatch,
36 match_cc_events,
37 match_pitch_bends,
38 match_aftertouch,
39 )
40
41
42 # ── Fixtures ──────────────────────────────────────────────────────────────
43
44
45 @pytest.fixture
46 async def async_session() -> AsyncGenerator[AsyncSession, None]:
47 engine = create_async_engine("sqlite+aiosqlite:///:memory:")
48 async with engine.begin() as conn:
49 await conn.run_sync(Base.metadata.create_all)
50 Session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
51 async with Session() as session:
52 yield session
53 await engine.dispose()
54
55
56 # ── Helpers ───────────────────────────────────────────────────────────────
57
58
59 def _cc(cc_num: int, beat: float, value: int) -> CCEventDict:
60 return {"cc": cc_num, "beat": beat, "value": value}
61
62
63 def _pb(beat: float, value: int) -> PitchBendDict:
64 return {"beat": beat, "value": value}
65
66
67 def _at(beat: float, value: int, pitch: int | None = None) -> AftertouchDict:
68 d: AftertouchDict = {"beat": beat, "value": value}
69 if pitch is not None:
70 d["pitch"] = pitch
71 return d
72
73
74 def _note(pitch: int, start: float) -> NoteDict:
75 return {"pitch": pitch, "start_beat": start, "duration_beats": 1.0, "velocity": 100, "channel": 0}
76
77
78 def _make_variation_with_controllers(
79 notes: list[NoteDict],
80 cc_events: list[CCEventDict] | None = None,
81 pitch_bends: list[PitchBendDict] | None = None,
82 aftertouch: list[AftertouchDict] | None = None,
83 region_id: str = "region-1",
84 track_id: str = "track-1",
85 ) -> Variation:
86 vid = str(uuid.uuid4())
87 pid = str(uuid.uuid4())
88 return Variation(
89 variation_id=vid,
90 intent="controller test",
91 ai_explanation="test",
92 affected_tracks=[track_id],
93 affected_regions=[region_id],
94 beat_range=(0.0, 8.0),
95 phrases=[
96 Phrase(
97 phrase_id=pid,
98 track_id=track_id,
99 region_id=region_id,
100 start_beat=0.0,
101 end_beat=8.0,
102 label="Test Phrase",
103 note_changes=[
104 NoteChange(
105 note_id=str(uuid.uuid4()),
106 change_type="added",
107 after=MidiNoteSnapshot.from_note_dict(n),
108 )
109 for n in notes
110 ],
111 cc_events=cc_events or [],
112 pitch_bends=pitch_bends or [],
113 aftertouch=aftertouch or [],
114 ),
115 ],
116 )
117
118
119 # ---------------------------------------------------------------------------
120 # 6.1 — Clean controller state
121 # ---------------------------------------------------------------------------
122
123
124 class TestCleanControllerState:
125
126 def test_identical_cc_is_clean(self) -> None:
127
128 cc = [_cc(64, 0.0, 127)]
129 report = compute_drift_report(
130 project_id="p", head_variation_id="v",
131 head_snapshot_notes={"r1": [_note(60, 0.0)]},
132 working_snapshot_notes={"r1": [_note(60, 0.0)]},
133 track_regions={"r1": "t1"},
134 head_cc={"r1": cc}, working_cc={"r1": cc},
135 )
136 assert report.is_clean is True
137 assert report.severity == DriftSeverity.CLEAN
138 assert report.region_summaries["r1"].cc_added == 0
139
140 def test_identical_all_controllers_is_clean(self) -> None:
141
142 cc = [_cc(64, 0.0, 127)]
143 pb = [_pb(1.0, 4096)]
144 at = [_at(2.0, 80)]
145 report = compute_drift_report(
146 project_id="p", head_variation_id="v",
147 head_snapshot_notes={"r1": [_note(60, 0.0)]},
148 working_snapshot_notes={"r1": [_note(60, 0.0)]},
149 track_regions={"r1": "t1"},
150 head_cc={"r1": cc}, working_cc={"r1": cc},
151 head_pb={"r1": pb}, working_pb={"r1": pb},
152 head_at={"r1": at}, working_at={"r1": at},
153 )
154 assert report.is_clean is True
155 assert report.total_changes == 0
156
157 @pytest.mark.anyio
158 async def test_reconstructed_head_clean(self, async_session: AsyncSession) -> None:
159
160 """Persist with CC, reconstruct HEAD, drift against identical data → CLEAN."""
161 notes = [_note(60, 0.0)]
162 var = _make_variation_with_controllers(notes, cc_events=[_cc(64, 0.0, 127), _cc(1, 2.0, 64)])
163
164 await muse_repository.save_variation(
165 async_session, var,
166 project_id="proj-cc", base_state_id="s1", conversation_id="c",
167 region_metadata={},
168 )
169 await muse_repository.set_head(async_session, var.variation_id, commit_state_id="s2")
170 await async_session.commit()
171
172 snap = await reconstruct_head_snapshot(async_session, "proj-cc")
173 assert snap is not None
174
175 report = compute_drift_report(
176 project_id="proj-cc",
177 head_variation_id=snap.variation_id,
178 head_snapshot_notes=snap.notes,
179 working_snapshot_notes=snap.notes,
180 track_regions=snap.track_regions,
181 head_cc=snap.cc, working_cc=snap.cc,
182 head_pb=snap.pitch_bends, working_pb=snap.pitch_bends,
183 head_at=snap.aftertouch, working_at=snap.aftertouch,
184 )
185 assert report.is_clean is True
186
187
188 # ---------------------------------------------------------------------------
189 # 6.2 — Sustain pedal drift (CC64)
190 # ---------------------------------------------------------------------------
191
192
193 class TestSustainPedalDrift:
194
195 def test_cc_added_in_working(self) -> None:
196
197 report = compute_drift_report(
198 project_id="p", head_variation_id="v",
199 head_snapshot_notes={"r1": [_note(60, 0.0)]},
200 working_snapshot_notes={"r1": [_note(60, 0.0)]},
201 track_regions={"r1": "t1"},
202 head_cc={"r1": []},
203 working_cc={"r1": [_cc(64, 0.0, 127)]},
204 )
205 assert report.is_clean is False
206 s = report.region_summaries["r1"]
207 assert s.cc_added == 1
208 assert s.cc_removed == 0
209
210 def test_cc_removed_from_working(self) -> None:
211
212 report = compute_drift_report(
213 project_id="p", head_variation_id="v",
214 head_snapshot_notes={"r1": [_note(60, 0.0)]},
215 working_snapshot_notes={"r1": [_note(60, 0.0)]},
216 track_regions={"r1": "t1"},
217 head_cc={"r1": [_cc(64, 0.0, 127)]},
218 working_cc={"r1": []},
219 )
220 assert report.is_clean is False
221 s = report.region_summaries["r1"]
222 assert s.cc_removed == 1
223
224 def test_cc_value_modified(self) -> None:
225
226 report = compute_drift_report(
227 project_id="p", head_variation_id="v",
228 head_snapshot_notes={"r1": [_note(60, 0.0)]},
229 working_snapshot_notes={"r1": [_note(60, 0.0)]},
230 track_regions={"r1": "t1"},
231 head_cc={"r1": [_cc(64, 0.0, 127)]},
232 working_cc={"r1": [_cc(64, 0.0, 0)]},
233 )
234 assert report.is_clean is False
235 s = report.region_summaries["r1"]
236 assert s.cc_modified == 1
237
238
239 # ---------------------------------------------------------------------------
240 # 6.3 — Pitch bend modification
241 # ---------------------------------------------------------------------------
242
243
244 class TestPitchBendDrift:
245
246 def test_pb_same_beat_different_value(self) -> None:
247
248 report = compute_drift_report(
249 project_id="p", head_variation_id="v",
250 head_snapshot_notes={"r1": [_note(60, 0.0)]},
251 working_snapshot_notes={"r1": [_note(60, 0.0)]},
252 track_regions={"r1": "t1"},
253 head_pb={"r1": [_pb(1.0, 4096)]},
254 working_pb={"r1": [_pb(1.0, 8192)]},
255 )
256 assert report.is_clean is False
257 s = report.region_summaries["r1"]
258 assert s.pb_modified == 1
259
260 def test_pb_added(self) -> None:
261
262 report = compute_drift_report(
263 project_id="p", head_variation_id="v",
264 head_snapshot_notes={"r1": [_note(60, 0.0)]},
265 working_snapshot_notes={"r1": [_note(60, 0.0)]},
266 track_regions={"r1": "t1"},
267 head_pb={"r1": []},
268 working_pb={"r1": [_pb(1.0, 4096)]},
269 )
270 s = report.region_summaries["r1"]
271 assert s.pb_added == 1
272
273 def test_pb_removed(self) -> None:
274
275 report = compute_drift_report(
276 project_id="p", head_variation_id="v",
277 head_snapshot_notes={"r1": [_note(60, 0.0)]},
278 working_snapshot_notes={"r1": [_note(60, 0.0)]},
279 track_regions={"r1": "t1"},
280 head_pb={"r1": [_pb(1.0, 4096)]},
281 working_pb={"r1": []},
282 )
283 s = report.region_summaries["r1"]
284 assert s.pb_removed == 1
285
286
287 # ---------------------------------------------------------------------------
288 # 6.4 — Aftertouch add / delete
289 # ---------------------------------------------------------------------------
290
291
292 class TestAftertouchDrift:
293
294 def test_at_added(self) -> None:
295
296 report = compute_drift_report(
297 project_id="p", head_variation_id="v",
298 head_snapshot_notes={"r1": [_note(60, 0.0)]},
299 working_snapshot_notes={"r1": [_note(60, 0.0)]},
300 track_regions={"r1": "t1"},
301 head_at={"r1": []},
302 working_at={"r1": [_at(2.0, 80)]},
303 )
304 s = report.region_summaries["r1"]
305 assert s.at_added == 1
306
307 def test_at_removed(self) -> None:
308
309 report = compute_drift_report(
310 project_id="p", head_variation_id="v",
311 head_snapshot_notes={"r1": [_note(60, 0.0)]},
312 working_snapshot_notes={"r1": [_note(60, 0.0)]},
313 track_regions={"r1": "t1"},
314 head_at={"r1": [_at(2.0, 80)]},
315 working_at={"r1": []},
316 )
317 s = report.region_summaries["r1"]
318 assert s.at_removed == 1
319
320 def test_at_modified_value(self) -> None:
321
322 report = compute_drift_report(
323 project_id="p", head_variation_id="v",
324 head_snapshot_notes={"r1": [_note(60, 0.0)]},
325 working_snapshot_notes={"r1": [_note(60, 0.0)]},
326 track_regions={"r1": "t1"},
327 head_at={"r1": [_at(2.0, 80)]},
328 working_at={"r1": [_at(2.0, 40)]},
329 )
330 s = report.region_summaries["r1"]
331 assert s.at_modified == 1
332
333 def test_poly_aftertouch_pitch_discriminated(self) -> None:
334
335 """Poly aftertouch on different pitches → add + remove, not modify."""
336 report = compute_drift_report(
337 project_id="p", head_variation_id="v",
338 head_snapshot_notes={"r1": [_note(60, 0.0)]},
339 working_snapshot_notes={"r1": [_note(60, 0.0)]},
340 track_regions={"r1": "t1"},
341 head_at={"r1": [_at(2.0, 80, pitch=60)]},
342 working_at={"r1": [_at(2.0, 80, pitch=72)]},
343 )
344 s = report.region_summaries["r1"]
345 assert s.at_added == 1
346 assert s.at_removed == 1
347 assert s.at_modified == 0
348
349
350 # ---------------------------------------------------------------------------
351 # 6.5 — Replay fidelity (controllers roundtrip through DB)
352 # ---------------------------------------------------------------------------
353
354
355 class TestReplayFidelity:
356
357 @pytest.mark.anyio
358 async def test_cc_roundtrip(self, async_session: AsyncSession) -> None:
359
360 notes = [_note(60, 0.0)]
361 var = _make_variation_with_controllers(notes, cc_events=[_cc(64, 0.0, 127), _cc(1, 4.0, 64)])
362
363 await muse_repository.save_variation(
364 async_session, var,
365 project_id="proj-rt", base_state_id="s1", conversation_id="c",
366 region_metadata={},
367 )
368 await muse_repository.set_head(async_session, var.variation_id, commit_state_id="s2")
369 await async_session.commit()
370
371 snap = await reconstruct_head_snapshot(async_session, "proj-rt")
372 assert snap is not None
373 assert len(snap.cc.get("region-1", [])) == 2
374 cc_vals = sorted(e["cc"] for e in snap.cc["region-1"])
375 assert cc_vals == [1, 64]
376
377 @pytest.mark.anyio
378 async def test_pb_roundtrip(self, async_session: AsyncSession) -> None:
379
380 notes = [_note(60, 0.0)]
381 var = _make_variation_with_controllers(notes, pitch_bends=[_pb(1.0, 4096), _pb(3.0, 8191)])
382
383 await muse_repository.save_variation(
384 async_session, var,
385 project_id="proj-pb-rt", base_state_id="s1", conversation_id="c",
386 region_metadata={},
387 )
388 await muse_repository.set_head(async_session, var.variation_id, commit_state_id="s2")
389 await async_session.commit()
390
391 snap = await reconstruct_head_snapshot(async_session, "proj-pb-rt")
392 assert snap is not None
393 assert len(snap.pitch_bends.get("region-1", [])) == 2
394
395 @pytest.mark.anyio
396 async def test_at_roundtrip(self, async_session: AsyncSession) -> None:
397
398 notes = [_note(60, 0.0)]
399 var = _make_variation_with_controllers(notes, aftertouch=[_at(2.0, 80), _at(4.0, 100, pitch=60)])
400
401 await muse_repository.save_variation(
402 async_session, var,
403 project_id="proj-at-rt", base_state_id="s1", conversation_id="c",
404 region_metadata={},
405 )
406 await muse_repository.set_head(async_session, var.variation_id, commit_state_id="s2")
407 await async_session.commit()
408
409 snap = await reconstruct_head_snapshot(async_session, "proj-at-rt")
410 assert snap is not None
411 assert len(snap.aftertouch.get("region-1", [])) == 2
412
413 @pytest.mark.anyio
414 async def test_mixed_controllers_roundtrip(self, async_session: AsyncSession) -> None:
415
416 """All three controller types persist and reconstruct correctly."""
417 notes = [_note(60, 0.0)]
418 var = _make_variation_with_controllers(
419 notes,
420 cc_events=[_cc(64, 0.0, 127)],
421 pitch_bends=[_pb(1.0, 4096)],
422 aftertouch=[_at(2.0, 80)],
423 )
424
425 await muse_repository.save_variation(
426 async_session, var,
427 project_id="proj-mix", base_state_id="s1", conversation_id="c",
428 region_metadata={},
429 )
430 await muse_repository.set_head(async_session, var.variation_id, commit_state_id="s2")
431 await async_session.commit()
432
433 snap = await reconstruct_head_snapshot(async_session, "proj-mix")
434 assert snap is not None
435 assert len(snap.cc.get("region-1", [])) == 1
436 assert len(snap.pitch_bends.get("region-1", [])) == 1
437 assert len(snap.aftertouch.get("region-1", [])) == 1
438
439
440 # ---------------------------------------------------------------------------
441 # Controller matching unit tests
442 # ---------------------------------------------------------------------------
443
444
445 class TestEventMatching:
446
447 def test_cc_match_same_cc_and_beat(self) -> None:
448
449 matches = match_cc_events(
450 [{"cc": 64, "beat": 0.0, "value": 127}],
451 [{"cc": 64, "beat": 0.0, "value": 127}],
452 )
453 assert len(matches) == 1
454 assert matches[0].is_unchanged
455
456 def test_cc_no_match_different_cc_number(self) -> None:
457
458 matches = match_cc_events(
459 [{"cc": 64, "beat": 0.0, "value": 127}],
460 [{"cc": 1, "beat": 0.0, "value": 127}],
461 )
462 added = [m for m in matches if m.is_added]
463 removed = [m for m in matches if m.is_removed]
464 assert len(added) == 1
465 assert len(removed) == 1
466
467 def test_pb_match_same_beat(self) -> None:
468
469 matches = match_pitch_bends(
470 [{"beat": 1.0, "value": 4096}],
471 [{"beat": 1.0, "value": 4096}],
472 )
473 assert len(matches) == 1
474 assert matches[0].is_unchanged
475
476 def test_pb_modified_different_value(self) -> None:
477
478 matches = match_pitch_bends(
479 [{"beat": 1.0, "value": 4096}],
480 [{"beat": 1.0, "value": 8192}],
481 )
482 assert len(matches) == 1
483 assert matches[0].is_modified
484
485 def test_at_match_with_pitch(self) -> None:
486
487 matches = match_aftertouch(
488 [{"beat": 2.0, "value": 80, "pitch": 60}],
489 [{"beat": 2.0, "value": 80, "pitch": 60}],
490 )
491 assert len(matches) == 1
492 assert matches[0].is_unchanged
493
494 def test_at_no_match_different_pitch(self) -> None:
495
496 matches = match_aftertouch(
497 [{"beat": 2.0, "value": 80, "pitch": 60}],
498 [{"beat": 2.0, "value": 80, "pitch": 72}],
499 )
500 added = [m for m in matches if m.is_added]
501 removed = [m for m in matches if m.is_removed]
502 assert len(added) == 1
503 assert len(removed) == 1
504
505
506 # ---------------------------------------------------------------------------
507 # Boundary seal
508 # ---------------------------------------------------------------------------
509
510
511 class TestControllerMatchingBoundary:
512
513 def test_note_matching_no_forbidden_imports(self) -> None:
514
515 from pathlib import Path
516 filepath = Path(__file__).resolve().parent.parent / "maestro" / "services" / "variation" / "note_matching.py"
517 tree = ast.parse(filepath.read_text())
518 forbidden = {"state_store", "executor", "maestro_handlers", "maestro_editing", "maestro_composing"}
519 for node in ast.walk(tree):
520 if isinstance(node, ast.ImportFrom) and node.module:
521 for fb in forbidden:
522 assert fb not in node.module, (
523 f"note_matching imports forbidden module: {node.module}"
524 )