cgcardona / muse public
test_muse_drift.py python
611 lines 20.9 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """Tests for the Muse Drift Detection Engine.
2
3 Verifies:
4 - Clean working tree detection (identical notes → is_clean=True).
5 - Dirty working tree detection (added/removed/modified notes).
6 - Added/deleted region detection.
7 - HEAD snapshot reconstruction from persisted data.
8 - Fingerprint stability for caching.
9 - muse_drift boundary rules.
10 """
11 from __future__ import annotations
12
13 import ast
14 import uuid
15
16 import pytest
17 from collections.abc import AsyncGenerator
18 from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
19
20 from maestro.db.database import Base
21 from maestro.db import muse_models # noqa: F401 — register tables
22 from maestro.models.variation import (
23 MidiNoteSnapshot,
24 NoteChange,
25 Phrase,
26 Variation,
27 )
28 from maestro.services import muse_repository
29
30 from maestro.contracts.json_types import NoteDict
31
32 from maestro.services.muse_drift import (
33 DriftReport,
34 DriftSeverity,
35 RegionDriftSummary,
36 compute_drift_report,
37 _fingerprint,
38 )
39 from maestro.services.muse_replay import reconstruct_head_snapshot, HeadSnapshot
40
41
42 # ── Fixtures ──────────────────────────────────────────────────────────────
43
44
@pytest.fixture
async def async_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield an AsyncSession bound to a fresh in-memory SQLite database.

    Every table registered on ``Base.metadata`` is created before the session
    is handed to the test.

    Yields:
        An open ``AsyncSession`` created with ``expire_on_commit=False``.
    """
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        Session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
        async with Session() as session:
            yield session
    finally:
        # Dispose even when schema creation or the session context raises, or
        # when the generator is closed early, so the engine is never leaked.
        await engine.dispose()
55
56
57 # ── Helpers ───────────────────────────────────────────────────────────────
58
59
def _note(pitch: int, start: float, dur: float = 1.0, vel: int = 100) -> NoteDict:
    """Return a note dict on channel 0 built from pitch, start, duration and velocity."""
    note: NoteDict = {
        "pitch": pitch,
        "start_beat": start,
        "duration_beats": dur,
        "velocity": vel,
        "channel": 0,
    }
    return note
69
70
def _make_variation_with_notes(
    notes: list[NoteDict],
    region_id: str = "region-1",
    track_id: str = "track-1",
    intent: str = "test",
) -> Variation:
    """Wrap *notes* in a single-phrase Variation, each note as an 'added' change.

    The variation, the phrase and every note change receive a fresh UUID4 id.
    Both the variation's beat range and the phrase span beats 0.0 to 8.0.
    """
    variation_id = str(uuid.uuid4())
    phrase_id = str(uuid.uuid4())

    added_changes: list[NoteChange] = []
    for note in notes:
        added_changes.append(
            NoteChange(
                note_id=str(uuid.uuid4()),
                change_type="added",
                after=MidiNoteSnapshot.from_note_dict(note),
            )
        )

    only_phrase = Phrase(
        phrase_id=phrase_id,
        track_id=track_id,
        region_id=region_id,
        start_beat=0.0,
        end_beat=8.0,
        label="Test Phrase",
        note_changes=added_changes,
    )
    return Variation(
        variation_id=variation_id,
        intent=intent,
        ai_explanation="test explanation",
        affected_tracks=[track_id],
        affected_regions=[region_id],
        beat_range=(0.0, 8.0),
        phrases=[only_phrase],
    )
107
108
109 # ---------------------------------------------------------------------------
110 # 6.1 — Clean Working Tree
111 # ---------------------------------------------------------------------------
112
113
class TestCleanWorkingTree:
    """6.1 — snapshots with identical contents must be reported as clean."""

    @staticmethod
    def _drift(
        head: dict[str, list[NoteDict]],
        working: dict[str, list[NoteDict]],
        regions: dict[str, str],
    ) -> DriftReport:
        """Run compute_drift_report with the project/variation ids shared by these tests."""
        return compute_drift_report(
            project_id="proj-1",
            head_variation_id="v-1",
            head_snapshot_notes=head,
            working_snapshot_notes=working,
            track_regions=regions,
        )

    def test_identical_notes_is_clean(self) -> None:
        """HEAD and working have the same notes → CLEAN."""
        shared = [_note(60, 0.0), _note(64, 1.0), _note(67, 2.0)]

        drift = self._drift({"r1": shared}, {"r1": shared}, {"r1": "t1"})

        assert drift.is_clean is True
        assert drift.severity == DriftSeverity.CLEAN
        assert len(drift.changed_regions) == 0
        assert len(drift.added_regions) == 0
        assert len(drift.deleted_regions) == 0
        assert drift.total_changes == 0

    def test_empty_regions_is_clean(self) -> None:
        """Both snapshots have a region with no notes → CLEAN."""
        drift = self._drift({"r1": []}, {"r1": []}, {"r1": "t1"})

        assert drift.is_clean is True

    def test_no_regions_is_clean(self) -> None:
        """Both snapshots have zero regions → CLEAN."""
        drift = self._drift({}, {}, {})

        assert drift.is_clean is True
157
158
159 # ---------------------------------------------------------------------------
160 # 6.2 — Dirty Working Tree (Notes Added)
161 # ---------------------------------------------------------------------------
162
163
class TestDirtyNotesAdded:
    """6.2 — notes present only in the working snapshot are counted as added."""

    @staticmethod
    def _drift_r1(head: list[NoteDict], working: list[NoteDict]) -> DriftReport:
        """Compare *head* and *working* as the notes of region ``r1`` on track ``t1``."""
        return compute_drift_report(
            project_id="p1",
            head_variation_id="v1",
            head_snapshot_notes={"r1": head},
            working_snapshot_notes={"r1": working},
            track_regions={"r1": "t1"},
        )

    def test_extra_note_in_working(self) -> None:
        """Working has an extra note → DIRTY with 1 add."""
        drift = self._drift_r1(
            [_note(60, 0.0)],
            [_note(60, 0.0), _note(72, 4.0)],
        )

        assert drift.is_clean is False
        assert drift.severity == DriftSeverity.DIRTY
        assert "r1" in drift.changed_regions

        r1 = drift.region_summaries["r1"]
        assert r1.added == 1
        assert r1.removed == 0
        assert r1.modified == 0

    def test_multiple_added_notes(self) -> None:
        """Working has several extra notes → correct add count."""
        drift = self._drift_r1(
            [_note(60, 0.0)],
            [_note(60, 0.0), _note(64, 1.0), _note(67, 2.0), _note(72, 3.0)],
        )

        r1 = drift.region_summaries["r1"]
        assert r1.added == 3
200
201
202 # ---------------------------------------------------------------------------
203 # 6.3 — Dirty Working Tree (Note Modified)
204 # ---------------------------------------------------------------------------
205
206
class TestDirtyNotesModified:
    """6.3 — changed or missing notes are counted as modified / removed."""

    @staticmethod
    def _drift_r1(head: list[NoteDict], working: list[NoteDict]) -> DriftReport:
        """Compare *head* and *working* as the notes of region ``r1`` on track ``t1``."""
        return compute_drift_report(
            project_id="p1",
            head_variation_id="v1",
            head_snapshot_notes={"r1": head},
            working_snapshot_notes={"r1": working},
            track_regions={"r1": "t1"},
        )

    def test_velocity_change_detected(self) -> None:
        """Same pitch/time, different velocity → modified."""
        drift = self._drift_r1(
            [_note(60, 0.0, vel=100)],
            [_note(60, 0.0, vel=50)],
        )

        assert drift.is_clean is False
        r1 = drift.region_summaries["r1"]
        assert r1.modified == 1
        assert r1.added == 0
        assert r1.removed == 0

    def test_duration_change_detected(self) -> None:
        """Same pitch/time, different duration → modified."""
        drift = self._drift_r1(
            [_note(60, 0.0, dur=1.0)],
            [_note(60, 0.0, dur=4.0)],
        )

        r1 = drift.region_summaries["r1"]
        assert r1.modified == 1

    def test_note_removed_from_working(self) -> None:
        """Head has note, working doesn't → removed."""
        drift = self._drift_r1(
            [_note(60, 0.0), _note(64, 1.0)],
            [_note(60, 0.0)],
        )

        r1 = drift.region_summaries["r1"]
        assert r1.removed == 1
256
257
258 # ---------------------------------------------------------------------------
259 # 6.4 — Added / Deleted Region Detection
260 # ---------------------------------------------------------------------------
261
262
class TestRegionDrift:
    """6.4 — regions present on only one side are reported as added / deleted."""

    @staticmethod
    def _drift(
        head: dict[str, list[NoteDict]],
        working: dict[str, list[NoteDict]],
    ) -> DriftReport:
        """Compare two region→notes maps using the r1→t1, r2→t2 track mapping."""
        return compute_drift_report(
            project_id="p1",
            head_variation_id="v1",
            head_snapshot_notes=head,
            working_snapshot_notes=working,
            track_regions={"r1": "t1", "r2": "t2"},
        )

    def test_added_region(self) -> None:
        """Region exists in working but not head → added."""
        drift = self._drift(
            {"r1": [_note(60, 0.0)]},
            {"r1": [_note(60, 0.0)], "r2": [_note(72, 0.0)]},
        )

        assert "r2" in drift.added_regions
        assert drift.is_clean is False
        r2 = drift.region_summaries["r2"]
        assert r2.added == 1

    def test_deleted_region(self) -> None:
        """Region exists in head but not working → deleted."""
        drift = self._drift(
            {"r1": [_note(60, 0.0)], "r2": [_note(72, 0.0)]},
            {"r1": [_note(60, 0.0)]},
        )

        assert "r2" in drift.deleted_regions
        assert drift.is_clean is False
        r2 = drift.region_summaries["r2"]
        assert r2.removed == 1

    def test_both_added_and_deleted(self) -> None:
        """One region added, another deleted → both detected."""
        drift = self._drift(
            {"r1": [_note(60, 0.0)]},
            {"r2": [_note(72, 0.0)]},
        )

        assert "r1" in drift.deleted_regions
        assert "r2" in drift.added_regions
        assert drift.is_clean is False
308
309
310 # ---------------------------------------------------------------------------
311 # Fingerprint tests
312 # ---------------------------------------------------------------------------
313
314
class TestFingerprint:
    """Fingerprint equality / inequality checks for note lists."""

    def test_identical_notes_same_fingerprint(self) -> None:
        """Fingerprinting the same list twice gives equal results."""
        notes = [_note(60, 0.0), _note(64, 1.0)]
        first = _fingerprint(notes)
        second = _fingerprint(notes)
        assert first == second

    def test_order_independent(self) -> None:
        """Fingerprint should be stable regardless of note order."""
        forward = [_note(60, 0.0), _note(64, 1.0)]
        backward = [_note(64, 1.0), _note(60, 0.0)]
        assert _fingerprint(forward) == _fingerprint(backward)

    def test_different_notes_different_fingerprint(self) -> None:
        """Lists that differ in pitch must not share a fingerprint."""
        low = _fingerprint([_note(60, 0.0)])
        high = _fingerprint([_note(72, 0.0)])
        assert low != high
334
335
336 # ---------------------------------------------------------------------------
337 # Sample changes capping
338 # ---------------------------------------------------------------------------
339
340
class TestSampleChanges:
    """Checks on the per-region ``sample_changes`` list of a drift report."""

    @staticmethod
    def _r1_summary(head: list[NoteDict], working: list[NoteDict]) -> RegionDriftSummary:
        """Diff *head* vs *working* as region ``r1`` and return that region's summary."""
        drift = compute_drift_report(
            project_id="p1",
            head_variation_id="v1",
            head_snapshot_notes={"r1": head},
            working_snapshot_notes={"r1": working},
            track_regions={"r1": "t1"},
        )
        return drift.region_summaries["r1"]

    def test_sample_changes_capped(self) -> None:
        """Sample changes should not exceed MAX_SAMPLE_CHANGES."""
        empty_head: list[NoteDict] = []
        twenty_notes: list[NoteDict] = []
        for i in range(20):
            twenty_notes.append(_note(i, float(i)))

        r1 = self._r1_summary(empty_head, twenty_notes)

        assert len(r1.sample_changes) <= 5
        assert r1.added == 20

    def test_sample_changes_include_type(self) -> None:
        """Each sample change should have a 'type' key."""
        r1 = self._r1_summary(
            [_note(60, 0.0)],
            [_note(60, 0.0, vel=50), _note(72, 4.0)],
        )

        for change in r1.sample_changes:
            assert "type" in change
            assert change["type"] in ("added", "removed", "modified")
375
376
377 # ---------------------------------------------------------------------------
378 # DriftReport properties
379 # ---------------------------------------------------------------------------
380
381
class TestDriftReportProperties:
    """Checks on aggregate attributes exposed by the drift report."""

    def test_total_changes_sums_all_regions(self) -> None:
        """total_changes adds up the changes of every region."""
        head_notes = {
            "r1": [_note(60, 0.0)],
            "r2": [_note(64, 0.0)],
        }
        working_notes = {
            "r1": [_note(60, 0.0), _note(72, 4.0)],
            "r2": [],
        }

        drift = compute_drift_report(
            project_id="p1",
            head_variation_id="v1",
            head_snapshot_notes=head_notes,
            working_snapshot_notes=working_notes,
            track_regions={"r1": "t1", "r2": "t2"},
        )

        # One note added in r1 plus one note removed from r2.
        assert drift.total_changes == 2

    def test_no_legacy_flags(self) -> None:
        """notes_only and partial_reconstruction flags have been removed."""
        drift = compute_drift_report(
            project_id="p1",
            head_variation_id="v1",
            head_snapshot_notes={},
            working_snapshot_notes={},
            track_regions={},
        )

        assert not hasattr(drift, "notes_only")
        assert not hasattr(drift, "partial_reconstruction")
413
414
415 # ---------------------------------------------------------------------------
416 # HEAD Snapshot Reconstruction (requires DB)
417 # ---------------------------------------------------------------------------
418
419
class TestReconstructHeadSnapshot:
    """HEAD snapshot reconstruction against the in-memory database."""

    @pytest.mark.anyio
    async def test_no_head_returns_none(self, async_session: AsyncSession) -> None:
        """A project with nothing persisted yields no snapshot."""
        missing = await reconstruct_head_snapshot(async_session, "nonexistent-project")
        assert missing is None

    @pytest.mark.anyio
    async def test_single_variation_head(self, async_session: AsyncSession) -> None:
        """Persist a variation, set HEAD, reconstruct snapshot."""
        variation = _make_variation_with_notes(
            [_note(60, 0.0), _note(64, 1.0), _note(67, 2.0)]
        )

        await muse_repository.save_variation(
            async_session,
            variation,
            project_id="proj-h",
            base_state_id="s1",
            conversation_id="c",
            region_metadata={},
        )
        await muse_repository.set_head(
            async_session,
            variation.variation_id,
            commit_state_id="s2",
        )
        await async_session.commit()

        head: HeadSnapshot | None = await reconstruct_head_snapshot(async_session, "proj-h")
        assert head is not None
        assert head.variation_id == variation.variation_id
        assert not hasattr(head, "partial")
        assert "region-1" in head.notes
        assert len(head.notes["region-1"]) == 3
        assert head.track_regions["region-1"] == "track-1"

    @pytest.mark.anyio
    async def test_lineage_accumulates_notes(self, async_session: AsyncSession) -> None:
        """Two variations in lineage → snapshot has notes from both."""
        parent = _make_variation_with_notes([_note(60, 0.0)], region_id="region-1")
        child = _make_variation_with_notes(
            [_note(72, 4.0)],
            region_id="region-2",
            track_id="track-2",
        )

        await muse_repository.save_variation(
            async_session,
            parent,
            project_id="proj-lin",
            base_state_id="s1",
            conversation_id="c",
            region_metadata={},
        )
        await muse_repository.save_variation(
            async_session,
            child,
            project_id="proj-lin",
            base_state_id="s1",
            conversation_id="c",
            region_metadata={},
            parent_variation_id=parent.variation_id,
        )
        await muse_repository.set_head(
            async_session,
            child.variation_id,
            commit_state_id="s2",
        )
        await async_session.commit()

        head: HeadSnapshot | None = await reconstruct_head_snapshot(async_session, "proj-lin")
        assert head is not None
        assert "region-1" in head.notes
        assert "region-2" in head.notes
        assert len(head.notes["region-1"]) == 1
        assert len(head.notes["region-2"]) == 1
484
485
486 # ---------------------------------------------------------------------------
487 # End-to-end: reconstruct HEAD + compute drift
488 # ---------------------------------------------------------------------------
489
490
class TestEndToEndDrift:
    """End-to-end: persist a variation, reconstruct HEAD, then compute drift."""

    @staticmethod
    async def _head_snapshot_for(
        session: AsyncSession,
        notes: list[NoteDict],
        project_id: str,
    ) -> HeadSnapshot:
        """Save *notes* as a variation, make it HEAD, commit, and reconstruct it."""
        variation = _make_variation_with_notes(notes)
        await muse_repository.save_variation(
            session,
            variation,
            project_id=project_id,
            base_state_id="s1",
            conversation_id="c",
            region_metadata={},
        )
        await muse_repository.set_head(
            session,
            variation.variation_id,
            commit_state_id="s2",
        )
        await session.commit()

        head = await reconstruct_head_snapshot(session, project_id)
        assert head is not None
        return head

    @pytest.mark.anyio
    async def test_clean_after_commit(self, async_session: AsyncSession) -> None:
        """Persist, set HEAD, reconstruct, compare with identical working → CLEAN."""
        head = await self._head_snapshot_for(
            async_session,
            [_note(60, 0.0), _note(64, 1.0)],
            "proj-e2e",
        )

        drift = compute_drift_report(
            project_id="proj-e2e",
            head_variation_id=head.variation_id,
            head_snapshot_notes=head.notes,
            working_snapshot_notes=head.notes,
            track_regions=head.track_regions,
        )

        assert drift.is_clean is True
        assert drift.severity == DriftSeverity.CLEAN

    @pytest.mark.anyio
    async def test_dirty_after_user_edit(self, async_session: AsyncSession) -> None:
        """HEAD has notes, working has different notes → DIRTY."""
        head = await self._head_snapshot_for(
            async_session,
            [_note(60, 0.0)],
            "proj-dirty",
        )
        edited = {"region-1": [_note(60, 0.0), _note(72, 4.0)]}

        drift = compute_drift_report(
            project_id="proj-dirty",
            head_variation_id=head.variation_id,
            head_snapshot_notes=head.notes,
            working_snapshot_notes=edited,
            track_regions=head.track_regions,
        )

        assert drift.is_clean is False
        assert drift.severity == DriftSeverity.DIRTY
        assert "region-1" in drift.changed_regions
554
555
556 # ---------------------------------------------------------------------------
557 # Boundary seal tests
558 # ---------------------------------------------------------------------------
559
560
561 class TestMuseDriftBoundary:
562
563 def test_no_state_store_or_executor_import(self) -> None:
564
565 """muse_drift must not import StateStore, executor, or LLM handlers."""
566 import importlib
567 spec = importlib.util.find_spec("maestro.services.muse_drift")
568 assert spec is not None and spec.origin is not None
569
570 with open(spec.origin) as f:
571 source = f.read()
572
573 tree = ast.parse(source)
574 forbidden_modules = {"state_store", "executor", "maestro_handlers", "maestro_editing", "maestro_composing"}
575 forbidden_names = {"StateStore", "get_or_create_store", "EntityRegistry"}
576
577 for node in ast.walk(tree):
578 if isinstance(node, (ast.Import, ast.ImportFrom)):
579 module = getattr(node, "module", "") or ""
580 for fb in forbidden_modules:
581 assert fb not in module, (
582 f"muse_drift imports forbidden module: {module}"
583 )
584 if hasattr(node, "names"):
585 for alias in node.names:
586 assert alias.name not in forbidden_names, (
587 f"muse_drift imports forbidden name: {alias.name}"
588 )
589
590 def test_no_get_or_create_store_call(self) -> None:
591
592 """muse_drift must not call get_or_create_store (AST-level check)."""
593 import importlib
594 spec = importlib.util.find_spec("maestro.services.muse_drift")
595 assert spec is not None and spec.origin is not None
596
597 with open(spec.origin) as f:
598 source = f.read()
599
600 tree = ast.parse(source)
601 for node in ast.walk(tree):
602 if isinstance(node, ast.Call):
603 func = node.func
604 name = ""
605 if isinstance(func, ast.Name):
606 name = func.id
607 elif isinstance(func, ast.Attribute):
608 name = func.attr
609 assert name != "get_or_create_store", (
610 "muse_drift calls get_or_create_store"
611 )