cgcardona / muse public
test_merge_engine.py python
366 lines 11.8 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
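"""Unit tests for the Muse CLI merge engine.

Covers the pure helpers (``diff_snapshots``, ``detect_conflicts``,
``apply_merge``), the merge-state file helpers (``read_merge_state`` /
``write_merge_state``), and ``find_merge_base``.

All async tests use ``@pytest.mark.anyio``. Pure-function tests are
synchronous and exercise the filesystem-free merge logic in isolation.
Merge-state tests write under pytest's ``tmp_path``.
``find_merge_base`` tests use the ``muse_cli_db_session`` fixture (the
in-memory SQLite session from ``conftest.py``).
"""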
7 from __future__ import annotations
8
9 import json
10 import pathlib
11 import uuid
12
13 import pytest
14 import pytest_asyncio
15 from sqlalchemy.ext.asyncio import AsyncSession
16
17 from maestro.muse_cli.merge_engine import (
18 MergeState,
19 apply_merge,
20 detect_conflicts,
21 diff_snapshots,
22 find_merge_base,
23 read_merge_state,
24 write_merge_state,
25 )
26 from maestro.muse_cli.models import MuseCliCommit
27 from maestro.muse_cli.snapshot import compute_snapshot_id
28
29
30 # ---------------------------------------------------------------------------
31 # Helpers
32 # ---------------------------------------------------------------------------
33
34
def _make_commit(
    *,
    parent: str | None = None,
    parent2: str | None = None,
    branch: str = "main",
) -> MuseCliCommit:
    """Return an unsaved ``MuseCliCommit`` that points at the empty-manifest snapshot.

    The commit id is the 32 hex characters of a fresh UUID4, right-padded
    with ``"0"`` to 64 characters. Nothing is added to a session here.

    Args:
        parent: Stored as ``parent_commit_id``.
        parent2: Stored as ``parent2_commit_id``.
        branch: Stored as ``branch``.
    """
    import datetime

    # ``.hex`` is already the dash-free form, so only the padding is needed.
    random_id = uuid.uuid4().hex.ljust(64, "0")
    empty_manifest: dict[str, str] = {}
    return MuseCliCommit(
        commit_id=random_id,
        repo_id="test-repo",
        branch=branch,
        parent_commit_id=parent,
        parent2_commit_id=parent2,
        snapshot_id=compute_snapshot_id(empty_manifest),
        message="test commit",
        author="",
        committed_at=datetime.datetime.now(datetime.timezone.utc),
    )
59
60
61 # ---------------------------------------------------------------------------
62 # diff_snapshots — pure function tests
63 # ---------------------------------------------------------------------------
64
65
def test_diff_snapshots_empty_base_all_added() -> None:
    """Against an empty base, every path present in ``other`` is reported."""
    other = {"a.mid": "aaa", "b.mid": "bbb"}
    assert diff_snapshots({}, other) == {"a.mid", "b.mid"}
70
71
def test_diff_snapshots_deleted_paths() -> None:
    """A path present in base but absent from other is reported as changed."""
    base = {"a.mid": "aaa", "b.mid": "bbb"}
    other = {"a.mid": "aaa"}
    assert diff_snapshots(base, other) == {"b.mid"}
76
77
def test_diff_snapshots_modified_paths() -> None:
    """A path whose object id differs between the two manifests is reported."""
    before, after = {"a.mid": "old"}, {"a.mid": "new"}
    assert diff_snapshots(before, after) == {"a.mid"}
82
83
def test_diff_snapshots_unchanged_paths_excluded() -> None:
    """Two manifests with identical entries produce an empty change set."""
    manifest = {"a.mid": "same"}
    # A separate copy is passed so the two arguments are distinct dict objects.
    assert diff_snapshots(manifest, dict(manifest)) == set()
88
89
def test_diff_snapshots_mixed() -> None:
    """One call covering an unchanged, a modified, a deleted, and an added path."""
    base = {"a.mid": "aaa", "b.mid": "bbb", "c.mid": "ccc"}
    other = {"a.mid": "aaa", "b.mid": "BBB", "d.mid": "ddd"}
    # b.mid: object id differs; c.mid: only in base; d.mid: only in other.
    expected = {"b.mid", "c.mid", "d.mid"}
    assert diff_snapshots(base, other) == expected
96
97
98 # ---------------------------------------------------------------------------
99 # detect_conflicts — pure function tests
100 # ---------------------------------------------------------------------------
101
102
def test_detect_conflicts_disjoint_no_conflicts() -> None:
    """Change sets with no path in common yield no conflicts."""
    left, right = {"a.mid"}, {"b.mid"}
    conflicts = detect_conflicts(left, right)
    assert conflicts == set()
106
107
def test_detect_conflicts_same_path_is_conflict() -> None:
    """Only the path present in both change sets is reported as a conflict."""
    left = {"beat.mid", "x.mid"}
    right = {"beat.mid", "y.mid"}
    conflicts = detect_conflicts(left, right)
    assert conflicts == {"beat.mid"}
111
112
def test_detect_conflicts_empty_inputs() -> None:
    """Two empty change sets yield an empty conflict set."""
    conflicts = detect_conflicts(set(), set())
    assert conflicts == set()
115
116
117 # ---------------------------------------------------------------------------
118 # apply_merge — pure function tests
119 # ---------------------------------------------------------------------------
120
121
def test_apply_merge_takes_ours_only_change() -> None:
    """When only ours changed a path, the merged manifest carries ours' object id."""
    merged = apply_merge(
        {"a.mid": "base"},  # base
        {"a.mid": "ours"},  # ours
        {"a.mid": "base"},  # theirs
        {"a.mid"},  # ours_changed
        set(),  # theirs_changed
        set(),  # conflict paths
    )
    assert merged["a.mid"] == "ours"
131
132
def test_apply_merge_takes_theirs_only_change() -> None:
    """When only theirs changed a path, the merged manifest carries theirs' object id."""
    base = {"a.mid": "base"}
    ours = {"a.mid": "base"}
    theirs = {"a.mid": "theirs"}
    ours_changed: set[str] = set()
    theirs_changed = {"a.mid"}
    conflict_paths: set[str] = set()

    merged = apply_merge(
        base, ours, theirs, ours_changed, theirs_changed, conflict_paths
    )

    assert merged["a.mid"] == "theirs"
140
141
def test_apply_merge_deleted_on_ours() -> None:
    """A path that ours dropped from its manifest is absent from the merged result."""
    base = {"a.mid": "base", "b.mid": "base"}
    theirs = dict(base)
    ours = {"b.mid": "base"}  # ours no longer lists a.mid
    merged = apply_merge(base, ours, theirs, {"a.mid"}, set(), set())
    assert "a.mid" not in merged
150
151
def test_apply_merge_conflict_paths_not_applied() -> None:
    """A path listed as conflicting keeps its base object id in the merged manifest."""
    path = "x.mid"
    merged = apply_merge(
        {path: "base"},  # base
        {path: "ours"},  # ours
        {path: "theirs"},  # theirs
        {path},  # ours_changed
        {path},  # theirs_changed
        {path},  # conflict_paths
    )
    # Neither side's version is applied to the conflicting path.
    assert merged[path] == "base"
163
164
def test_apply_merge_both_sides_add_different_files() -> None:
    """Each side adds a different path on an empty base; the merge holds both."""
    ours, theirs = {"a.mid": "aaa"}, {"b.mid": "bbb"}
    ours_changed, theirs_changed = {"a.mid"}, {"b.mid"}
    merged = apply_merge({}, ours, theirs, ours_changed, theirs_changed, set())
    assert merged == {"a.mid": "aaa", "b.mid": "bbb"}
172
173
174 # ---------------------------------------------------------------------------
175 # read_merge_state / write_merge_state — filesystem tests
176 # ---------------------------------------------------------------------------
177
178
def test_read_merge_state_no_file_returns_none(tmp_path: pathlib.Path) -> None:
    """``read_merge_state`` returns ``None`` when ``.muse`` exists but is empty."""
    muse_dir = tmp_path / ".muse"
    muse_dir.mkdir()
    state = read_merge_state(tmp_path)
    assert state is None
182
183
def test_write_and_read_merge_state_roundtrip(tmp_path: pathlib.Path) -> None:
    """Values passed to ``write_merge_state`` come back from ``read_merge_state``."""
    repo_root = tmp_path
    (repo_root / ".muse").mkdir()

    write_merge_state(
        repo_root,
        base_commit="base000",
        ours_commit="ours111",
        theirs_commit="theirs222",
        conflict_paths=["beat.mid", "lead.mp3"],
        other_branch="feature/x",
    )

    loaded = read_merge_state(repo_root)
    assert loaded is not None

    # conflict_paths is compared after sorting, so its stored order is not asserted.
    observed = (
        loaded.base_commit,
        loaded.ours_commit,
        loaded.theirs_commit,
        sorted(loaded.conflict_paths),
        loaded.other_branch,
    )
    assert observed == (
        "base000",
        "ours111",
        "theirs222",
        ["beat.mid", "lead.mp3"],
        "feature/x",
    )
201
202
def test_read_merge_state_invalid_json_returns_none(tmp_path: pathlib.Path) -> None:
    """A ``MERGE_STATE.json`` that is not valid JSON makes ``read_merge_state`` return ``None``."""
    state_file = tmp_path / ".muse" / "MERGE_STATE.json"
    state_file.parent.mkdir()
    state_file.write_text("not-valid-json{{")
    assert read_merge_state(tmp_path) is None
208
209
210 # ---------------------------------------------------------------------------
211 # find_merge_base — async tests (require DB session)
212 # ---------------------------------------------------------------------------
213
214
@pytest.mark.anyio
async def test_find_merge_base_lca(muse_cli_db_session: AsyncSession) -> None:
    """Two children of the same commit have that commit as their merge base.

    Graph::

        base ← A
          ↑
          B

    Expected: ``find_merge_base(A, B) == base``.
    """
    import datetime

    from maestro.muse_cli.models import MuseCliSnapshot

    db = muse_cli_db_session
    empty_snapshot_id = compute_snapshot_id({})

    # Persist an empty snapshot so FK constraints pass.
    db.add(MuseCliSnapshot(snapshot_id=empty_snapshot_id, manifest={}))
    await db.flush()

    created_at = datetime.datetime.now(datetime.timezone.utc)

    def _node(prefix: str, *, parent: str | None = None) -> MuseCliCommit:
        """Build a commit whose id is ``prefix`` followed by sixty ``"0"`` characters."""
        return MuseCliCommit(
            commit_id=prefix + "0" * 60,
            repo_id="test",
            branch="main",
            parent_commit_id=parent,
            parent2_commit_id=None,
            snapshot_id=empty_snapshot_id,
            message="msg",
            author="",
            committed_at=created_at,
        )

    root = _node("base")
    child_a = _node("aaaa", parent=root.commit_id)
    child_b = _node("bbbb", parent=root.commit_id)
    db.add_all([root, child_a, child_b])
    await db.flush()

    merge_base = await find_merge_base(db, child_a.commit_id, child_b.commit_id)
    assert merge_base == root.commit_id
258
259
@pytest.mark.anyio
async def test_find_merge_base_same_commit(muse_cli_db_session: AsyncSession) -> None:
    """LCA of a commit with itself is the commit itself."""
    import datetime

    from maestro.muse_cli.models import MuseCliSnapshot

    session = muse_cli_db_session
    snapshot_id = compute_snapshot_id({})

    # Flush the empty snapshot on its own before adding the commit that
    # references it, matching the other find_merge_base tests in this module.
    # NOTE(review): without this, both rows go out in one flush and the insert
    # order presumably depends on the ORM knowing the dependency — confirm
    # whether the models declare a relationship().
    session.add(MuseCliSnapshot(snapshot_id=snapshot_id, manifest={}))
    await session.flush()

    now = datetime.datetime.now(datetime.timezone.utc)
    c = MuseCliCommit(
        commit_id="cccc" + "0" * 60,
        repo_id="test",
        branch="main",
        parent_commit_id=None,
        parent2_commit_id=None,
        snapshot_id=snapshot_id,
        message="x",
        author="",
        committed_at=now,
    )
    session.add(c)
    await session.flush()

    # The same commit id is passed for both sides.
    lca = await find_merge_base(session, c.commit_id, c.commit_id)
    assert lca == c.commit_id
287
288
@pytest.mark.anyio
async def test_find_merge_base_linear_returns_ancestor(
    muse_cli_db_session: AsyncSession,
) -> None:
    """With B a direct child of A, the merge base of A and B is A."""
    import datetime

    from maestro.muse_cli.models import MuseCliSnapshot

    db = muse_cli_db_session
    empty_snapshot_id = compute_snapshot_id({})
    db.add(MuseCliSnapshot(snapshot_id=empty_snapshot_id, manifest={}))
    await db.flush()

    created_at = datetime.datetime.now(datetime.timezone.utc)

    def _commit(cid: str, message: str, parent: str | None) -> MuseCliCommit:
        """Build a single-parent (or root) commit on ``main`` in repo ``r``."""
        return MuseCliCommit(
            commit_id=cid,
            repo_id="r",
            branch="main",
            parent_commit_id=parent,
            parent2_commit_id=None,
            snapshot_id=empty_snapshot_id,
            message=message,
            author="",
            committed_at=created_at,
        )

    ancestor = _commit("aaaa" + "1" * 60, "a", None)
    descendant = _commit("bbbb" + "1" * 60, "b", ancestor.commit_id)
    db.add_all([ancestor, descendant])
    await db.flush()

    merge_base = await find_merge_base(db, ancestor.commit_id, descendant.commit_id)
    assert merge_base == ancestor.commit_id
330
331
@pytest.mark.anyio
async def test_find_merge_base_disjoint_returns_none(
    muse_cli_db_session: AsyncSession,
) -> None:
    """Two parentless commits share no ancestor, so the merge base is ``None``."""
    import datetime

    from maestro.muse_cli.models import MuseCliSnapshot

    db = muse_cli_db_session
    empty_snapshot_id = compute_snapshot_id({})
    db.add(MuseCliSnapshot(snapshot_id=empty_snapshot_id, manifest={}))
    await db.flush()

    created_at = datetime.datetime.now(datetime.timezone.utc)

    # Both commits are roots: neither parent field is set.
    first_root, second_root = (
        MuseCliCommit(
            commit_id=prefix + "0" * 60,
            repo_id="r",
            branch="main",
            parent_commit_id=None,
            parent2_commit_id=None,
            snapshot_id=empty_snapshot_id,
            message="x",
            author="",
            committed_at=created_at,
        )
        for prefix in ("1111", "2222")
    )
    db.add_all([first_root, second_root])
    await db.flush()

    merge_base = await find_merge_base(db, first_root.commit_id, second_root.commit_id)
    assert merge_base is None