cgcardona / muse public
test_muse_persistence.py python
664 lines 23.0 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """Tests for Muse persistent variation storage, lineage, and replay.
2
3 Verifies:
4 - Variation → DB → domain model roundtrip fidelity.
5 - Commit-from-DB produces identical results to commit-from-memory.
6 - Lineage graph formation and HEAD tracking.
7 - Replay plan construction and determinism.
8 - muse_repository and muse_replay module boundary rules.
9 """
10 from __future__ import annotations
11
12 import ast
13 import uuid
14
15 import pytest
16 from collections.abc import AsyncGenerator
17 from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
18
19 from maestro.contracts.json_types import RegionMetadataWire
20 from maestro.db.database import Base
21 from maestro.db import muse_models # noqa: F401 — register tables
22 from maestro.models.variation import (
23 MidiNoteSnapshot,
24 NoteChange,
25 Phrase,
26 Variation,
27 )
28 from maestro.services import muse_repository, muse_replay
29 from maestro.services.muse_repository import HistoryNode
30 from maestro.services.muse_replay import ReplayPlan
31
32
@pytest.fixture
async def async_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` bound to a fresh in-memory SQLite database.

    Every table registered on ``Base.metadata`` is created before the session
    is handed to the test.  The engine is disposed in a ``finally`` clause so
    its connections are released even when schema creation or session setup
    fails, not only on the normal teardown path.
    """
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        # expire_on_commit=False keeps loaded objects readable after commit()
        # without triggering an implicit (async-unsafe) attribute refresh.
        session_factory = async_sessionmaker(
            engine, class_=AsyncSession, expire_on_commit=False,
        )
        async with session_factory() as session:
            yield session
    finally:
        await engine.dispose()
43
44
def _make_variation() -> Variation:
    """Return a realistic two-phrase ``Variation`` covering every change type.

    * Phrase A (``track-1``/``region-1``, beats 0-4, tag ``intro``) holds one
      added note (pitch 60), one removed note (pitch 64) and a single CC event
      (cc 64 at beat 0).
    * Phrase B (``track-2``/``region-2``, beats 4-8, tag ``verse``) holds one
      modified note (pitch 67, lengthened from 2 to 3 beats, velocity 90 -> 110).

    Variation, phrase and note IDs are all freshly generated UUID4 strings.
    """
    variation_id = str(uuid.uuid4())
    phrase_a_id = str(uuid.uuid4())
    phrase_b_id = str(uuid.uuid4())

    def snap(pitch: int, start: float, length: float, velocity: int) -> MidiNoteSnapshot:
        # Every note in this fixture lives on channel 0.
        return MidiNoteSnapshot(
            pitch=pitch,
            start_beat=start,
            duration_beats=length,
            velocity=velocity,
            channel=0,
        )

    added = NoteChange(
        note_id=str(uuid.uuid4()),
        change_type="added",
        before=None,
        after=snap(60, 0.0, 1.0, 100),
    )
    removed = NoteChange(
        note_id=str(uuid.uuid4()),
        change_type="removed",
        before=snap(64, 2.0, 0.5, 80),
        after=None,
    )
    modified = NoteChange(
        note_id=str(uuid.uuid4()),
        change_type="modified",
        before=snap(67, 4.0, 2.0, 90),
        after=snap(67, 4.0, 3.0, 110),
    )

    phrases = [
        Phrase(
            phrase_id=phrase_a_id,
            track_id="track-1",
            region_id="region-1",
            start_beat=0.0,
            end_beat=4.0,
            label="Phrase A",
            note_changes=[added, removed],
            cc_events=[{"cc": 64, "beat": 0.0, "value": 127}],
            explanation="first phrase",
            tags=["intro"],
        ),
        Phrase(
            phrase_id=phrase_b_id,
            track_id="track-2",
            region_id="region-2",
            start_beat=4.0,
            end_beat=8.0,
            label="Phrase B",
            note_changes=[modified],
            cc_events=[],
            explanation="second phrase",
            tags=["verse"],
        ),
    ]

    return Variation(
        variation_id=variation_id,
        intent="test composition",
        ai_explanation="test explanation",
        affected_tracks=["track-1", "track-2"],
        affected_regions=["region-1", "region-2"],
        beat_range=(0.0, 8.0),
        phrases=phrases,
    )
112
113
114 # ---------------------------------------------------------------------------
115 # 3.1 — Variation roundtrip
116 # ---------------------------------------------------------------------------
117
118
def _assert_snapshot_matches(
    loaded: MidiNoteSnapshot | None,
    original: MidiNoteSnapshot | None,
) -> None:
    """Assert a reloaded note snapshot matches the original field by field.

    Both sides must be ``None`` together.  Otherwise pitch, timing, velocity
    and channel must all survive the DB roundtrip unchanged.
    """
    if original is None:
        assert loaded is None
        return
    assert loaded is not None
    assert loaded.pitch == original.pitch
    assert loaded.start_beat == original.start_beat
    assert loaded.duration_beats == original.duration_beats
    assert loaded.velocity == original.velocity
    # Channel is compared too, so a dropped or defaulted channel column is caught.
    assert loaded.channel == original.channel


@pytest.mark.anyio
async def test_variation_roundtrip(async_session: AsyncSession) -> None:
    """Persist a variation, reload it, and assert every field survives.

    Covers the variation-level fields, each phrase's metadata and controller
    data, and the before/after snapshots of every note change.
    """
    original = _make_variation()
    region_metadata: dict[str, RegionMetadataWire] = {
        "region-1": {"startBeat": 0, "durationBeats": 16, "name": "Intro Region"},
        "region-2": {"startBeat": 16, "durationBeats": 16, "name": "Verse Region"},
    }

    await muse_repository.save_variation(
        async_session,
        original,
        project_id="proj-1",
        base_state_id="state-42",
        conversation_id="conv-1",
        region_metadata=region_metadata,
    )
    await async_session.commit()

    loaded = await muse_repository.load_variation(async_session, original.variation_id)
    assert loaded is not None

    assert loaded.variation_id == original.variation_id
    assert loaded.intent == original.intent
    assert loaded.ai_explanation == original.ai_explanation
    assert loaded.affected_tracks == original.affected_tracks
    assert loaded.affected_regions == original.affected_regions
    assert loaded.beat_range == original.beat_range
    assert len(loaded.phrases) == len(original.phrases)

    for orig_p, load_p in zip(original.phrases, loaded.phrases):
        assert load_p.phrase_id == orig_p.phrase_id
        assert load_p.track_id == orig_p.track_id
        assert load_p.region_id == orig_p.region_id
        assert load_p.start_beat == orig_p.start_beat
        assert load_p.end_beat == orig_p.end_beat
        assert load_p.label == orig_p.label
        assert load_p.explanation == orig_p.explanation
        assert load_p.tags == orig_p.tags
        assert load_p.cc_events == orig_p.cc_events
        assert load_p.pitch_bends == orig_p.pitch_bends
        assert load_p.aftertouch == orig_p.aftertouch
        assert len(load_p.note_changes) == len(orig_p.note_changes)

        for orig_nc, load_nc in zip(orig_p.note_changes, load_p.note_changes):
            assert load_nc.change_type == orig_nc.change_type
            _assert_snapshot_matches(load_nc.before, orig_nc.before)
            _assert_snapshot_matches(load_nc.after, orig_nc.after)
182
183
@pytest.mark.anyio
async def test_variation_status_lifecycle(async_session: AsyncSession) -> None:
    """A freshly saved variation is ``ready``; ``mark_committed`` makes it ``committed``."""
    variation = _make_variation()
    vid = variation.variation_id

    await muse_repository.save_variation(
        async_session,
        variation,
        project_id="p",
        base_state_id="s",
        conversation_id="c",
        region_metadata={},
    )
    await async_session.commit()

    assert await muse_repository.get_status(async_session, vid) == "ready"

    await muse_repository.mark_committed(async_session, vid)
    await async_session.commit()

    assert await muse_repository.get_status(async_session, vid) == "committed"
204
205
@pytest.mark.anyio
async def test_variation_discard(async_session: AsyncSession) -> None:
    """``mark_discarded`` moves a saved variation into the ``discarded`` status."""
    doomed = _make_variation()

    await muse_repository.save_variation(
        async_session,
        doomed,
        project_id="p",
        base_state_id="s",
        conversation_id="c",
        region_metadata={},
    )
    await async_session.commit()

    await muse_repository.mark_discarded(async_session, doomed.variation_id)
    await async_session.commit()

    final_status = await muse_repository.get_status(async_session, doomed.variation_id)
    assert final_status == "discarded"
223
224
@pytest.mark.anyio
async def test_load_nonexistent_returns_none(async_session: AsyncSession) -> None:
    """``load_variation`` yields ``None`` for an ID that was never saved."""
    unknown_id = "nonexistent-id"
    missing = await muse_repository.load_variation(async_session, unknown_id)
    assert missing is None
231
232
@pytest.mark.anyio
async def test_region_metadata_roundtrip(async_session: AsyncSession) -> None:
    """Region metadata stored on phrases is retrievable for every region.

    Input uses the camelCase wire keys (``startBeat``/``durationBeats``), and
    ``get_region_metadata`` is read back with snake_case keys.  Both regions
    are checked, and region-2 deliberately has a different duration (8 beats),
    so a per-region mix-up cannot slip through.
    """
    var = _make_variation()
    region_metadata: dict[str, RegionMetadataWire] = {
        "region-1": {"startBeat": 0, "durationBeats": 16, "name": "Intro"},
        "region-2": {"startBeat": 16, "durationBeats": 8, "name": "Verse"},
    }
    await muse_repository.save_variation(
        async_session, var,
        project_id="p", base_state_id="s", conversation_id="c",
        region_metadata=region_metadata,
    )
    await async_session.commit()

    loaded_meta = await muse_repository.get_region_metadata(
        async_session, var.variation_id,
    )

    expected = {
        "region-1": {"name": "Intro", "start_beat": 0, "duration_beats": 16},
        "region-2": {"name": "Verse", "start_beat": 16, "duration_beats": 8},
    }
    for region_id, fields in expected.items():
        assert region_id in loaded_meta, f"missing metadata for {region_id}"
        for key, value in fields.items():
            assert loaded_meta[region_id][key] == value, (
                f"{region_id}.{key}: {loaded_meta[region_id][key]!r} != {value!r}"
            )
256
257
@pytest.mark.anyio
async def test_phrase_ids_in_order(async_session: AsyncSession) -> None:
    """``get_phrase_ids`` reports phrase IDs in the variation's own phrase order."""
    variation = _make_variation()

    await muse_repository.save_variation(
        async_session,
        variation,
        project_id="p",
        base_state_id="s",
        conversation_id="c",
        region_metadata={},
    )
    await async_session.commit()

    stored_ids = await muse_repository.get_phrase_ids(async_session, variation.variation_id)
    expected_ids = [phrase.phrase_id for phrase in variation.phrases]

    assert len(stored_ids) == 2
    assert stored_ids == expected_ids
273
274
275 # ---------------------------------------------------------------------------
276 # 3.2 — Commit replay safety
277 # ---------------------------------------------------------------------------
278
279
@pytest.mark.anyio
async def test_commit_replay_from_db(async_session: AsyncSession) -> None:
    """Recover everything a commit needs from the database alone.

    Simulates the in-memory copy being lost. After persisting, the test
    reloads the variation, its base state ID and its ordered phrase IDs, then
    checks that every note change matches the original exactly.
    """
    source_variation = _make_variation()
    variation_id = source_variation.variation_id
    regions: dict[str, RegionMetadataWire] = {
        "region-1": {"startBeat": 0, "durationBeats": 16, "name": "R1"},
        "region-2": {"startBeat": 16, "durationBeats": 16, "name": "R2"},
    }

    await muse_repository.save_variation(
        async_session,
        source_variation,
        project_id="proj-1",
        base_state_id="state-42",
        conversation_id="c",
        region_metadata=regions,
    )
    await async_session.commit()

    restored = await muse_repository.load_variation(async_session, variation_id)
    assert restored is not None

    base_state = await muse_repository.get_base_state_id(async_session, variation_id)
    assert base_state == "state-42"

    stored_phrase_ids = await muse_repository.get_phrase_ids(async_session, variation_id)
    assert stored_phrase_ids == [phrase.phrase_id for phrase in source_variation.phrases]

    assert len(restored.phrases) == len(source_variation.phrases)
    for want_phrase, got_phrase in zip(source_variation.phrases, restored.phrases):
        assert got_phrase.phrase_id == want_phrase.phrase_id
        assert len(got_phrase.note_changes) == len(want_phrase.note_changes)
        for want, got in zip(want_phrase.note_changes, got_phrase.note_changes):
            assert got.change_type == want.change_type
            assert got.before == want.before
            assert got.after == want.after
318
319
320 # ---------------------------------------------------------------------------
321 # Phase 5 — Lineage graph tests
322 # ---------------------------------------------------------------------------
323
324
def _make_child_variation(parent_id: str, intent: str = "child") -> Variation:
    """Return a minimal one-phrase ``Variation`` for lineage and HEAD tests.

    The variation touches ``track-1``/``region-1`` over beats 0-4 and contains
    a single added note (pitch 60).  ``parent_id`` only documents intent at the
    call site; it is not stored on the returned object.  Tests that need real
    lineage must pass ``parent_variation_id`` to
    ``muse_repository.save_variation``.
    """
    variation_id = str(uuid.uuid4())
    phrase_id = str(uuid.uuid4())

    first_note = MidiNoteSnapshot(
        pitch=60,
        start_beat=0.0,
        duration_beats=1.0,
        velocity=100,
        channel=0,
    )
    only_phrase = Phrase(
        phrase_id=phrase_id,
        track_id="track-1",
        region_id="region-1",
        start_beat=0.0,
        end_beat=4.0,
        label=f"Phrase for {intent}",
        note_changes=[
            NoteChange(
                note_id=str(uuid.uuid4()),
                change_type="added",
                after=first_note,
            ),
        ],
    )

    return Variation(
        variation_id=variation_id,
        intent=intent,
        ai_explanation=f"explanation for {intent}",
        affected_tracks=["track-1"],
        affected_regions=["region-1"],
        beat_range=(0.0, 4.0),
        phrases=[only_phrase],
    )
358
359
@pytest.mark.anyio
async def test_set_and_get_head(async_session: AsyncSession) -> None:
    """A project has no HEAD until ``set_head`` names one; ``get_head`` then returns it."""
    project = "proj-head"
    variation = _make_variation()

    await muse_repository.save_variation(
        async_session,
        variation,
        project_id=project,
        base_state_id="s1",
        conversation_id="c",
        region_metadata={},
    )
    await async_session.commit()

    assert await muse_repository.get_head(async_session, project) is None

    await muse_repository.set_head(
        async_session, variation.variation_id, commit_state_id="state-99",
    )
    await async_session.commit()

    current = await muse_repository.get_head(async_session, project)
    assert current is not None
    assert current.variation_id == variation.variation_id
    assert current.commit_state_id == "state-99"
384
385
@pytest.mark.anyio
async def test_set_head_clears_previous(async_session: AsyncSession) -> None:
    """After successive ``set_head`` calls, ``get_head`` reports the most recent one."""
    first = _make_variation()
    second = _make_child_variation(first.variation_id, "second")

    for candidate in (first, second):
        await muse_repository.save_variation(
            async_session,
            candidate,
            project_id="proj-swap",
            base_state_id="s1",
            conversation_id="c",
            region_metadata={},
        )
    await async_session.commit()

    # Promote each variation in turn and confirm HEAD follows it.
    for expected_head in (first, second):
        await muse_repository.set_head(async_session, expected_head.variation_id)
        await async_session.commit()

        head = await muse_repository.get_head(async_session, "proj-swap")
        assert head is not None
        assert head.variation_id == expected_head.variation_id
414
415
@pytest.mark.anyio
async def test_move_head(async_session: AsyncSession) -> None:
    """``move_head`` repoints an existing HEAD at another variation of the project."""
    project = "proj-move"
    start = _make_variation()
    target = _make_child_variation(start.variation_id, "b")

    for variation in (start, target):
        await muse_repository.save_variation(
            async_session,
            variation,
            project_id=project,
            base_state_id="s1",
            conversation_id="c",
            region_metadata={},
        )
    await async_session.commit()

    await muse_repository.set_head(async_session, start.variation_id)
    await async_session.commit()

    await muse_repository.move_head(async_session, project, target.variation_id)
    await async_session.commit()

    moved = await muse_repository.get_head(async_session, project)
    assert moved is not None
    assert moved.variation_id == target.variation_id
440
441
@pytest.mark.anyio
async def test_get_children(async_session: AsyncSession) -> None:
    """``get_children`` lists the variations saved with this parent."""
    parent = _make_variation()
    kids = [
        _make_child_variation(parent.variation_id, label)
        for label in ("child-a", "child-b")
    ]

    await muse_repository.save_variation(
        async_session,
        parent,
        project_id="proj-c",
        base_state_id="s1",
        conversation_id="c",
        region_metadata={},
    )
    for kid in kids:
        await muse_repository.save_variation(
            async_session,
            kid,
            project_id="proj-c",
            base_state_id="s1",
            conversation_id="c",
            region_metadata={},
            parent_variation_id=parent.variation_id,
        )
    await async_session.commit()

    found = await muse_repository.get_children(async_session, parent.variation_id)
    assert len(found) == 2

    found_ids = {node.variation_id for node in found}
    assert {kid.variation_id for kid in kids} <= found_ids
469
470
@pytest.mark.anyio
async def test_get_lineage(async_session: AsyncSession) -> None:
    """``get_lineage`` walks from a leaf back to its root and returns the path root-first."""
    root = _make_variation()
    mid = _make_child_variation(root.variation_id, "mid")
    leaf = _make_child_variation(mid.variation_id, "leaf")

    async def save(variation: Variation, **lineage: str) -> None:
        # Only forward parent_variation_id when the caller supplies it.
        await muse_repository.save_variation(
            async_session,
            variation,
            project_id="proj-l",
            base_state_id="s1",
            conversation_id="c",
            region_metadata={},
            **lineage,
        )

    await save(root)
    await save(mid, parent_variation_id=root.variation_id)
    await save(leaf, parent_variation_id=mid.variation_id)
    await async_session.commit()

    lineage = await muse_repository.get_lineage(async_session, leaf.variation_id)
    assert [node.variation_id for node in lineage] == [
        step.variation_id for step in (root, mid, leaf)
    ]
503
504
505 # ---------------------------------------------------------------------------
506 # Phase 5 — Replay plan tests
507 # ---------------------------------------------------------------------------
508
509
@pytest.mark.anyio
async def test_replay_plan_linear(async_session: AsyncSession) -> None:
    """A parent -> child chain replays parent first, with both variations' phrases."""
    parent = _make_variation()
    child = _make_child_variation(parent.variation_id, "child-b")

    async def persist(variation: Variation, **lineage: str) -> None:
        await muse_repository.save_variation(
            async_session,
            variation,
            project_id="proj-rp",
            base_state_id="s1",
            conversation_id="c",
            region_metadata={},
            **lineage,
        )

    await persist(parent)
    await persist(child, parent_variation_id=parent.variation_id)
    await async_session.commit()

    plan = await muse_replay.build_replay_plan(
        async_session, "proj-rp", child.variation_id,
    )
    assert plan is not None
    assert plan.ordered_variation_ids == [parent.variation_id, child.variation_id]

    total_phrases = sum(len(v.phrases) for v in (parent, child))
    assert len(plan.ordered_phrase_ids) == total_phrases
    assert len(plan.region_updates) > 0
537
538
@pytest.mark.anyio
async def test_replay_plan_single_variation(async_session: AsyncSession) -> None:
    """A parentless variation replays on its own, with exactly its own phrases."""
    solo = _make_variation()

    await muse_repository.save_variation(
        async_session,
        solo,
        project_id="proj-single",
        base_state_id="s1",
        conversation_id="c",
        region_metadata={},
    )
    await async_session.commit()

    plan = await muse_replay.build_replay_plan(
        async_session, "proj-single", solo.variation_id,
    )
    assert plan is not None
    assert plan.ordered_variation_ids == [solo.variation_id]
    assert len(plan.ordered_phrase_ids) == len(solo.phrases)
557
558
@pytest.mark.anyio
async def test_replay_plan_nonexistent_returns_none(async_session: AsyncSession) -> None:
    """``build_replay_plan`` yields ``None`` when the target variation is unknown."""
    missing_plan = await muse_replay.build_replay_plan(
        async_session,
        "proj-x",
        "nonexistent",
    )
    assert missing_plan is None
567
568
@pytest.mark.anyio
async def test_replay_preserves_phrase_ordering(async_session: AsyncSession) -> None:
    """Restart safety: a plan built from the session lists phrases in lineage order.

    The ancestor's phrases come first, in their own order, followed by the
    descendant's.
    """
    ancestor = _make_variation()
    descendant = _make_child_variation(ancestor.variation_id, "after-restart")

    async def store(variation: Variation, **lineage: str) -> None:
        await muse_repository.save_variation(
            async_session,
            variation,
            project_id="proj-restart",
            base_state_id="s1",
            conversation_id="c",
            region_metadata={},
            **lineage,
        )

    await store(ancestor)
    await store(descendant, parent_variation_id=ancestor.variation_id)
    await async_session.commit()

    plan = await muse_replay.build_replay_plan(
        async_session, "proj-restart", descendant.variation_id,
    )
    assert plan is not None

    expected_phrases = [
        phrase.phrase_id
        for variation in (ancestor, descendant)
        for phrase in variation.phrases
    ]
    assert plan.ordered_phrase_ids == expected_phrases
600
601
602 # ---------------------------------------------------------------------------
603 # Boundary check — muse_repository must not import StateStore
604 # ---------------------------------------------------------------------------
605
606
def test_muse_repository_boundary() -> None:
    """muse_repository must not import StateStore or executor modules.

    The module source is parsed statically, without importing it, and every
    ``import`` / ``from ... import`` statement is inspected.  Both the dotted
    module path and each imported name are checked.  This catches
    ``import x.state_store`` and ``from x import executor`` as well as
    ``from x.state_store import ...``.
    """
    # ``import importlib`` alone does not guarantee the ``util`` submodule is loaded.
    import importlib.util

    spec = importlib.util.find_spec("maestro.services.muse_repository")
    assert spec is not None and spec.origin is not None

    # Python source files are UTF-8 by default (PEP 3120); don't rely on the locale.
    with open(spec.origin, encoding="utf-8") as f:
        source = f.read()

    tree = ast.parse(source, filename=spec.origin)
    forbidden = {"StateStore", "get_or_create_store", "EntityRegistry"}

    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            # ``import a.b.c``: the dotted path lives on each alias; there is no ``.module``.
            module_paths = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom):
            # ``from a.b import c``: ``c`` may itself be a submodule, so check ``a.b.c`` too.
            base = node.module or ""
            module_paths = [base] + [
                f"{base}.{alias.name}" if base else alias.name
                for alias in node.names
            ]
        else:
            continue

        for path in module_paths:
            assert "state_store" not in path, (
                f"muse_repository imports state_store: {path}"
            )
            assert "executor" not in path, (
                f"muse_repository imports executor: {path}"
            )
        for alias in node.names:
            assert alias.name not in forbidden, (
                f"muse_repository imports forbidden name: {alias.name}"
            )
633
634
635 # ---------------------------------------------------------------------------
636 # Boundary check — muse_replay must be pure data (Phase 5)
637 # ---------------------------------------------------------------------------
638
639
def test_muse_replay_boundary() -> None:
    """muse_replay must not import StateStore, executor, or LLM handlers.

    The module source is parsed statically, without importing it, and every
    ``import`` / ``from ... import`` statement is inspected.  Both the dotted
    module path and each imported name are checked against the forbidden
    sets, so plain ``import x.executor`` and ``from x import maestro_handlers``
    are caught too.
    """
    # ``import importlib`` alone does not guarantee the ``util`` submodule is loaded.
    import importlib.util

    spec = importlib.util.find_spec("maestro.services.muse_replay")
    assert spec is not None and spec.origin is not None

    # Python source files are UTF-8 by default (PEP 3120); don't rely on the locale.
    with open(spec.origin, encoding="utf-8") as f:
        source = f.read()

    tree = ast.parse(source, filename=spec.origin)
    forbidden_modules = {"state_store", "executor", "maestro_handlers", "maestro_editing", "maestro_composing"}
    forbidden_names = {"StateStore", "get_or_create_store", "EntityRegistry"}

    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            # ``import a.b.c``: the dotted path lives on each alias; there is no ``.module``.
            module_paths = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom):
            # ``from a.b import c``: ``c`` may itself be a submodule, so check ``a.b.c`` too.
            base = node.module or ""
            module_paths = [base] + [
                f"{base}.{alias.name}" if base else alias.name
                for alias in node.names
            ]
        else:
            continue

        for path in module_paths:
            for fb in forbidden_modules:
                assert fb not in path, (
                    f"muse_replay imports forbidden module: {path}"
                )
        for alias in node.names:
            assert alias.name not in forbidden_names, (
                f"muse_replay imports forbidden name: {alias.name}"
            )