cgcardona / muse public
test_muse_log_graph.py python
325 lines 11.1 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """Tests for Muse Log Graph Serialization (Phase 13).
2
3 Verifies:
4 - Linear history ordering.
5 - Branch + merge DAG with parent2 and HEAD.
6 - Deterministic JSON output.
7 - Boundary seal (AST).
8 """
9 from __future__ import annotations
10
11 import ast
12 import json
13 import uuid
14 from pathlib import Path
15 import pytest
16 from maestro.contracts.json_types import NoteDict
17 from collections.abc import AsyncGenerator
18 from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
19
20 from maestro.db.database import Base
21 from maestro.db import muse_models # noqa: F401
22 from maestro.models.variation import (
23 MidiNoteSnapshot,
24 NoteChange,
25 Phrase,
26 Variation,
27 )
28 from maestro.services import muse_repository
29 from maestro.services.muse_log_graph import (
30 MuseLogGraph,
31 MuseLogNode,
32 build_muse_log_graph,
33 )
34
35
36 # ── Fixtures ──────────────────────────────────────────────────────────────
37
38
39 @pytest.fixture
40 async def async_session() -> AsyncGenerator[AsyncSession, None]:
41 engine = create_async_engine("sqlite+aiosqlite:///:memory:")
42 async with engine.begin() as conn:
43 await conn.run_sync(Base.metadata.create_all)
44 Session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
45 async with Session() as session:
46 yield session
47 await engine.dispose()
48
49
50 # ── Helpers ───────────────────────────────────────────────────────────────
51
52
53 def _note(pitch: int, start: float) -> NoteDict:
54
55 return {"pitch": pitch, "start_beat": start, "duration_beats": 1.0, "velocity": 100, "channel": 0}
56
57
58 def _make_variation(
59 notes: list[NoteDict],
60 region_id: str = "region-1",
61 track_id: str = "track-1",
62 intent: str = "test",
63 ) -> Variation:
64 vid = str(uuid.uuid4())
65 pid = str(uuid.uuid4())
66 return Variation(
67 variation_id=vid,
68 intent=intent,
69 ai_explanation=intent,
70 affected_tracks=[track_id],
71 affected_regions=[region_id],
72 beat_range=(0.0, 8.0),
73 phrases=[
74 Phrase(
75 phrase_id=pid,
76 track_id=track_id,
77 region_id=region_id,
78 start_beat=0.0,
79 end_beat=8.0,
80 label=intent,
81 note_changes=[
82 NoteChange(
83 note_id=str(uuid.uuid4()),
84 change_type="added",
85 after=MidiNoteSnapshot.from_note_dict(n),
86 )
87 for n in notes
88 ],
89 ),
90 ],
91 )
92
93
94 async def _save(
95 session: AsyncSession,
96 var: Variation,
97 project_id: str,
98 parent: str | None = None,
99 parent2: str | None = None,
100 is_head: bool = False,
101 ) -> str:
102 await muse_repository.save_variation(
103 session, var,
104 project_id=project_id,
105 base_state_id="s1",
106 conversation_id="c",
107 region_metadata={},
108 parent_variation_id=parent,
109 parent2_variation_id=parent2,
110 )
111 if is_head:
112 await muse_repository.set_head(session, var.variation_id)
113 return var.variation_id
114
115
116 # ---------------------------------------------------------------------------
117 # 6.1 — Linear History
118 # ---------------------------------------------------------------------------
119
120
121 class TestLinearHistory:
122
123 @pytest.mark.anyio
124 async def test_linear_order_preserved(self, async_session: AsyncSession) -> None:
125
126 """C0 -> C1 -> C2 — nodes must appear in that order."""
127 c0 = _make_variation([_note(60, 0.0)], intent="init")
128 c0_id = await _save(async_session, c0, "proj-lin")
129
130 c1 = _make_variation([_note(64, 2.0)], intent="add chord")
131 c1_id = await _save(async_session, c1, "proj-lin", parent=c0_id)
132
133 c2 = _make_variation([_note(67, 4.0)], intent="add melody")
134 c2_id = await _save(async_session, c2, "proj-lin", parent=c1_id, is_head=True)
135
136 await async_session.commit()
137
138 graph = await build_muse_log_graph(async_session, "proj-lin")
139
140 assert len(graph.nodes) == 3
141 ids = [n.variation_id for n in graph.nodes]
142 assert ids == [c0_id, c1_id, c2_id]
143
144 @pytest.mark.anyio
145 async def test_head_detected(self, async_session: AsyncSession) -> None:
146
147 c0 = _make_variation([_note(60, 0.0)])
148 c0_id = await _save(async_session, c0, "proj-head", is_head=True)
149 await async_session.commit()
150
151 graph = await build_muse_log_graph(async_session, "proj-head")
152 assert graph.head == c0_id
153 assert graph.nodes[0].is_head
154
155 @pytest.mark.anyio
156 async def test_empty_project(self, async_session: AsyncSession) -> None:
157
158 graph = await build_muse_log_graph(async_session, "proj-empty")
159 assert graph.head is None
160 assert len(graph.nodes) == 0
161
162 @pytest.mark.anyio
163 async def test_parent_field_set(self, async_session: AsyncSession) -> None:
164
165 c0 = _make_variation([_note(60, 0.0)])
166 c0_id = await _save(async_session, c0, "proj-par")
167 c1 = _make_variation([_note(64, 2.0)])
168 c1_id = await _save(async_session, c1, "proj-par", parent=c0_id)
169 await async_session.commit()
170
171 graph = await build_muse_log_graph(async_session, "proj-par")
172 assert graph.nodes[0].parent is None
173 assert graph.nodes[1].parent == c0_id
174
175
176 # ---------------------------------------------------------------------------
177 # 6.2 — Branch + Merge Graph
178 # ---------------------------------------------------------------------------
179
180
181 class TestBranchMergeGraph:
182
183 @pytest.mark.anyio
184 async def test_merge_parent2_serialized(self, async_session: AsyncSession) -> None:
185
186 """
187 C0
188 ├── C1 (bass)
189 ├── C2 (piano)
190 └── C3 merge(C1,C2) ← HEAD
191 """
192 c0 = _make_variation([_note(60, 0.0)], intent="root")
193 c0_id = await _save(async_session, c0, "proj-merge")
194
195 c1 = _make_variation([_note(36, 0.0)], region_id="r-bass", intent="add bass")
196 c1_id = await _save(async_session, c1, "proj-merge", parent=c0_id)
197
198 c2 = _make_variation([_note(72, 0.0)], region_id="r-piano", intent="add piano")
199 c2_id = await _save(async_session, c2, "proj-merge", parent=c0_id)
200
201 c3 = _make_variation([], intent="merge")
202 c3_id = await _save(
203 async_session, c3, "proj-merge",
204 parent=c1_id, parent2=c2_id, is_head=True,
205 )
206 await async_session.commit()
207
208 graph = await build_muse_log_graph(async_session, "proj-merge")
209
210 assert len(graph.nodes) == 4
211 assert graph.head == c3_id
212
213 merge_node = [n for n in graph.nodes if n.variation_id == c3_id][0]
214 assert merge_node.parent == c1_id
215 assert merge_node.parent2 == c2_id
216 assert merge_node.is_head
217
218 c0_idx = next(i for i, n in enumerate(graph.nodes) if n.variation_id == c0_id)
219 c1_idx = next(i for i, n in enumerate(graph.nodes) if n.variation_id == c1_id)
220 c2_idx = next(i for i, n in enumerate(graph.nodes) if n.variation_id == c2_id)
221 c3_idx = next(i for i, n in enumerate(graph.nodes) if n.variation_id == c3_id)
222 assert c0_idx < c1_idx
223 assert c0_idx < c2_idx
224 assert c1_idx < c3_idx
225 assert c2_idx < c3_idx
226
227 @pytest.mark.anyio
228 async def test_regions_extracted(self, async_session: AsyncSession) -> None:
229
230 c0 = _make_variation([_note(60, 0.0)], region_id="r-drums", intent="drums")
231 await _save(async_session, c0, "proj-reg")
232 await async_session.commit()
233
234 graph = await build_muse_log_graph(async_session, "proj-reg")
235 assert graph.nodes[0].affected_regions == ("r-drums",)
236
237
238 # ---------------------------------------------------------------------------
239 # 6.3 — Determinism
240 # ---------------------------------------------------------------------------
241
242
243 class TestDeterminism:
244
245 @pytest.mark.anyio
246 async def test_repeated_calls_identical_json(self, async_session: AsyncSession) -> None:
247
248 c0 = _make_variation([_note(60, 0.0)], intent="root")
249 c0_id = await _save(async_session, c0, "proj-det")
250
251 c1 = _make_variation([_note(64, 2.0)], intent="branch-a")
252 await _save(async_session, c1, "proj-det", parent=c0_id)
253
254 c2 = _make_variation([_note(67, 4.0)], intent="branch-b")
255 await _save(async_session, c2, "proj-det", parent=c0_id)
256
257 await async_session.commit()
258
259 g1 = await build_muse_log_graph(async_session, "proj-det")
260 g2 = await build_muse_log_graph(async_session, "proj-det")
261
262 j1 = json.dumps(g1.to_response().model_dump(), sort_keys=True)
263 j2 = json.dumps(g2.to_response().model_dump(), sort_keys=True)
264 assert j1 == j2
265
266 @pytest.mark.anyio
267 async def test_to_dict_field_names(self, async_session: AsyncSession) -> None:
268
269 c0 = _make_variation([_note(60, 0.0)], intent="init")
270 await _save(async_session, c0, "proj-fields", is_head=True)
271 await async_session.commit()
272
273 graph = await build_muse_log_graph(async_session, "proj-fields")
274 d = graph.to_response().model_dump()
275
276 assert "projectId" in d
277 assert "head" in d
278 assert "nodes" in d
279
280 nodes_list = d["nodes"]
281 assert isinstance(nodes_list, list)
282 node = nodes_list[0]
283 assert isinstance(node, dict)
284 assert "id" in node
285 assert "parent" in node
286 assert "parent2" in node
287 assert "isHead" in node
288 assert "timestamp" in node
289 assert "intent" in node
290 assert "regions" in node
291
292
293 # ---------------------------------------------------------------------------
294 # 6.4 — Boundary Seal
295 # ---------------------------------------------------------------------------
296
297
298 class TestLogGraphBoundary:
299
300 def test_no_state_store_import(self) -> None:
301
302 filepath = Path(__file__).resolve().parent.parent / "maestro" / "services" / "muse_log_graph.py"
303 tree = ast.parse(filepath.read_text())
304 forbidden = {
305 "state_store", "executor", "maestro_handlers", "maestro_editing",
306 "muse_drift", "muse_merge", "muse_checkout", "muse_replay",
307 }
308 for node in ast.walk(tree):
309 if isinstance(node, ast.ImportFrom) and node.module:
310 for fb in forbidden:
311 assert fb not in node.module, (
312 f"muse_log_graph imports forbidden module: {node.module}"
313 )
314
315 def test_no_forbidden_names(self) -> None:
316
317 filepath = Path(__file__).resolve().parent.parent / "maestro" / "services" / "muse_log_graph.py"
318 tree = ast.parse(filepath.read_text())
319 forbidden_names = {"StateStore", "get_or_create_store", "VariationService"}
320 for node in ast.walk(tree):
321 if isinstance(node, (ast.Import, ast.ImportFrom)):
322 for alias in node.names:
323 assert alias.name not in forbidden_names, (
324 f"muse_log_graph imports forbidden name: {alias.name}"
325 )