cgcardona / muse public
test_muse_restore.py python
636 lines 21.6 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """Tests for ``muse restore`` — surgical file-level restore from a snapshot.
2
3 Verifies:
4 - test_muse_restore_from_head — default restore from HEAD
5 - test_muse_restore_staged_equivalent — --staged behaves like --worktree (current model)
6 - test_muse_restore_source_commit — --source <commit> extracts from historical snapshot
7 - test_muse_restore_multiple_paths — multiple paths restored in one call
8 - test_muse_restore_muse_work_prefix_stripped — "muse-work/" prefix is normalised away
9 - test_muse_restore_errors_on_missing_path — PathNotInSnapshotError when path absent
10 - test_muse_restore_source_missing_path — PathNotInSnapshotError on historical commit
11 - test_muse_restore_missing_object_store — MissingObjectError when blob absent
12 - test_muse_restore_ref_not_found — unknown source ref exits USER_ERROR
13 - test_muse_restore_no_commits — branch with no commits exits USER_ERROR
14 - test_muse_restore_result_fields — RestoreResult fields are correct
15 - test_muse_restore_result_frozen — RestoreResult is immutable
16 - test_boundary_no_forbidden_imports — AST boundary seal
17 - test_restore_service_has_future_import — from __future__ import annotations present
18 - test_restore_command_has_future_import — CLI command has future import
19 """
20 from __future__ import annotations
21
22 import ast
23 import datetime
24 import json
25 import pathlib
26 import uuid
27 from collections.abc import AsyncGenerator
28 from typing import Any
29
30 import pytest
31 from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
32
33 from maestro.db.database import Base
34 from maestro.muse_cli import models as cli_models # noqa: F401 — register tables
35 from maestro.muse_cli.errors import ExitCode
36 from maestro.muse_cli.models import MuseCliCommit, MuseCliObject, MuseCliSnapshot
37 from maestro.muse_cli.object_store import write_object
38 from maestro.services.muse_reset import MissingObjectError
39 from maestro.services.muse_restore import (
40 PathNotInSnapshotError,
41 RestoreResult,
42 perform_restore,
43 )
44
45
46 # ---------------------------------------------------------------------------
47 # Fixtures
48 # ---------------------------------------------------------------------------
49
50
@pytest.fixture
async def async_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` bound to a fresh in-memory SQLite database.

    Every table registered on ``Base.metadata`` is created before the session
    is handed to the test.  The engine is disposed in a ``finally`` block so it
    is released even if table creation or closing the session raises.
    """
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        session_factory = async_sessionmaker(
            engine, class_=AsyncSession, expire_on_commit=False
        )
        async with session_factory() as session:
            yield session
    finally:
        await engine.dispose()
61
62
@pytest.fixture
def repo_id() -> str:
    """Return a freshly generated UUID4 in its canonical string form."""
    fresh_id = uuid.uuid4()
    return str(fresh_id)
66
67
@pytest.fixture
def repo_root(tmp_path: pathlib.Path, repo_id: str) -> pathlib.Path:
    """Lay out a bare ``.muse`` directory under *tmp_path* and return *tmp_path*.

    Creates ``HEAD`` containing ``refs/heads/main``, an empty ``main`` ref
    file, and a ``repo.json`` holding *repo_id*.
    """
    dot_muse = tmp_path / ".muse"
    heads_dir = dot_muse / "refs" / "heads"

    dot_muse.mkdir()
    (dot_muse / "HEAD").write_text("refs/heads/main")

    heads_dir.mkdir(parents=True)
    # An empty ref file means the branch has no commits yet
    (heads_dir / "main").write_text("")

    repo_config = json.dumps({"repo_id": repo_id})
    (dot_muse / "repo.json").write_text(repo_config)
    return tmp_path
78
79
def _sha(prefix: str, length: int = 64) -> str:
    """Repeat *prefix* and cut the result to exactly *length* characters.

    Gives tests a deterministic stand-in for a digest; the output is only
    hexadecimal if *prefix* itself is.
    """
    # One repetition more than the floor quotient always covers *length*
    copies = length // len(prefix) + 1
    padded = prefix * copies
    return padded[:length]
83
84
async def _add_commit(
    session: AsyncSession,
    *,
    repo_id: str,
    branch: str = "main",
    message: str = "commit",
    manifest: dict[str, str] | None = None,
    parent_commit_id: str | None = None,
    committed_at: datetime.datetime | None = None,
) -> MuseCliCommit:
    """Insert a commit and its snapshot into *session* and return the commit.

    Args:
        session: Session the rows are added to; it is flushed, not committed.
        repo_id: Repository ID recorded on the commit.
        branch: Branch name recorded on the commit.
        message: Commit message.
        manifest: Mapping of path to object ID stored on the snapshot.  ``None``
            selects a single-file default; an explicit empty dict is kept and
            produces an empty snapshot.
        parent_commit_id: Parent commit ID, or ``None`` for no parent.
        committed_at: Commit timestamp; defaults to the current UTC time.

    Returns:
        The ``MuseCliCommit`` row after it has been flushed.
    """
    snapshot_id = _sha(uuid.uuid4().hex)
    commit_id = _sha(uuid.uuid4().hex)
    # ``is None`` rather than ``or``: a caller passing ``{}`` gets an empty
    # snapshot instead of silently receiving the default manifest.
    file_manifest: dict[str, str] = (
        {"track.mid": _sha("ab")} if manifest is None else manifest
    )

    # Several paths may point at the same blob; look each object ID up once.
    for object_id in dict.fromkeys(file_manifest.values()):
        if await session.get(MuseCliObject, object_id) is None:
            session.add(MuseCliObject(object_id=object_id, size_bytes=10))

    session.add(MuseCliSnapshot(snapshot_id=snapshot_id, manifest=file_manifest))
    # Flush objects and snapshot before adding the commit that carries snapshot_id
    await session.flush()

    if committed_at is None:
        committed_at = datetime.datetime.now(datetime.timezone.utc)
    commit = MuseCliCommit(
        commit_id=commit_id,
        repo_id=repo_id,
        branch=branch,
        parent_commit_id=parent_commit_id,
        snapshot_id=snapshot_id,
        message=message,
        author="",
        committed_at=committed_at,
    )
    session.add(commit)
    await session.flush()
    return commit
122
123
def _write_ref(root: pathlib.Path, branch: str, commit_id: str) -> None:
    """Point ``.muse/refs/heads/<branch>`` under *root* at *commit_id*.

    Missing parent directories are created first, and any existing ref file
    is overwritten.
    """
    heads_dir = root / ".muse" / "refs" / "heads"
    target = heads_dir / branch
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(commit_id)
129
130
def _seed_object_store(root: pathlib.Path, object_id: str, content: bytes) -> None:
    """Store *content* under *object_id* by delegating to ``write_object``.

    Thin wrapper so tests seed blobs through the ``object_store`` function
    imported at the top of this file instead of writing files by hand.
    Whatever ``write_object`` returns is discarded.
    """
    write_object(root, object_id, content)
134
135
136 # ---------------------------------------------------------------------------
137 # Core restore tests
138 # ---------------------------------------------------------------------------
139
140
class TestRestoreFromHead:
    """Restores that take their content from the branch's HEAD commit."""

    @pytest.mark.anyio
    async def test_muse_restore_from_head(
        self,
        async_session: AsyncSession,
        repo_id: str,
        repo_root: pathlib.Path,
    ) -> None:
        """With no source ref, the working file is replaced by HEAD's blob."""
        rel_path = "bass/bassline.mid"
        blob_id = "aa" * 32
        expected = b"MIDI bass take 7"
        _seed_object_store(repo_root, blob_id, expected)

        head = await _add_commit(
            async_session,
            repo_id=repo_id,
            message="take 7",
            manifest={rel_path: blob_id},
        )
        _write_ref(repo_root, "main", head.commit_id)

        # Plant outdated bytes in the working tree for the restore to replace
        target = repo_root / "muse-work" / "bass" / "bassline.mid"
        target.parent.mkdir(parents=True)
        target.write_bytes(b"stale content")

        outcome = await perform_restore(
            root=repo_root,
            session=async_session,
            paths=[rel_path],
            source_ref=None,
            staged=False,
        )

        assert outcome.source_commit_id == head.commit_id
        assert outcome.paths_restored == [rel_path]
        assert outcome.staged is False
        assert target.read_bytes() == expected

    @pytest.mark.anyio
    async def test_muse_restore_staged_equivalent(
        self,
        async_session: AsyncSession,
        repo_id: str,
        repo_root: pathlib.Path,
    ) -> None:
        """``staged=True`` still rewrites the working file and is echoed in the result."""
        rel_path = "drums/kick.mid"
        blob_id = "bb" * 32
        expected = b"drums take 3"
        _seed_object_store(repo_root, blob_id, expected)

        head = await _add_commit(
            async_session,
            repo_id=repo_id,
            message="take 3",
            manifest={rel_path: blob_id},
        )
        _write_ref(repo_root, "main", head.commit_id)

        target = repo_root / "muse-work" / "drums" / "kick.mid"
        target.parent.mkdir(parents=True)
        target.write_bytes(b"wrong version")

        outcome = await perform_restore(
            root=repo_root,
            session=async_session,
            paths=[rel_path],
            source_ref=None,
            staged=True,
        )

        assert outcome.staged is True
        assert outcome.paths_restored == [rel_path]
        assert target.read_bytes() == expected
216
217
class TestRestoreSourceCommit:
    """Restores that name an explicit source commit instead of using HEAD."""

    @pytest.mark.anyio
    async def test_muse_restore_source_commit(
        self,
        async_session: AsyncSession,
        repo_id: str,
        repo_root: pathlib.Path,
    ) -> None:
        """A full commit ID as source pulls the file from that older snapshot."""
        rel_path = "bass/bassline.mid"
        utc = datetime.timezone.utc

        old_blob, old_bytes = "cc" * 32, b"bass take 3"
        _seed_object_store(repo_root, old_blob, old_bytes)
        new_blob, new_bytes = "dd" * 32, b"bass take 7"
        _seed_object_store(repo_root, new_blob, new_bytes)

        older = await _add_commit(
            async_session,
            repo_id=repo_id,
            message="take 3",
            manifest={rel_path: old_blob},
            committed_at=datetime.datetime(2024, 1, 1, tzinfo=utc),
        )
        newer = await _add_commit(
            async_session,
            repo_id=repo_id,
            message="take 7",
            manifest={rel_path: new_blob},
            parent_commit_id=older.commit_id,
            committed_at=datetime.datetime(2024, 1, 2, tzinfo=utc),
        )
        _write_ref(repo_root, "main", newer.commit_id)

        # The working tree starts out matching HEAD (the newer take)
        target = repo_root / "muse-work" / "bass" / "bassline.mid"
        target.parent.mkdir(parents=True)
        target.write_bytes(new_bytes)

        # Ask for the older commit's version while HEAD stays on the newer one
        outcome = await perform_restore(
            root=repo_root,
            session=async_session,
            paths=[rel_path],
            source_ref=older.commit_id,
            staged=False,
        )

        assert outcome.source_commit_id == older.commit_id
        assert outcome.paths_restored == [rel_path]
        assert target.read_bytes() == old_bytes

    @pytest.mark.anyio
    async def test_muse_restore_source_abbreviated_sha(
        self,
        async_session: AsyncSession,
        repo_id: str,
        repo_root: pathlib.Path,
    ) -> None:
        """The first ten characters of a commit ID resolve to the full commit."""
        blob_id = "ee" * 32
        expected = b"abbreviated"
        _seed_object_store(repo_root, blob_id, expected)

        head = await _add_commit(
            async_session,
            repo_id=repo_id,
            message="v1",
            manifest={"track.mid": blob_id},
        )
        _write_ref(repo_root, "main", head.commit_id)

        work_tree = repo_root / "muse-work"
        work_tree.mkdir()
        target = work_tree / "track.mid"
        target.write_bytes(b"old")

        short_ref = head.commit_id[:10]
        outcome = await perform_restore(
            root=repo_root,
            session=async_session,
            paths=["track.mid"],
            source_ref=short_ref,
            staged=False,
        )

        assert outcome.source_commit_id == head.commit_id
        assert target.read_bytes() == expected
309
310
class TestRestoreMultiplePaths:
    """Restores covering several paths, and normalisation of the given paths."""

    @pytest.mark.anyio
    async def test_muse_restore_multiple_paths(
        self,
        async_session: AsyncSession,
        repo_id: str,
        repo_root: pathlib.Path,
    ) -> None:
        """Two paths passed in a single call are both rewritten from HEAD."""
        # file name -> (object ID, committed bytes)
        tracks = {
            "bass.mid": ("11" * 32, b"bass v1"),
            "drums.mid": ("22" * 32, b"drums v1"),
        }
        for blob_id, data in tracks.values():
            _seed_object_store(repo_root, blob_id, data)

        head = await _add_commit(
            async_session,
            repo_id=repo_id,
            message="v1",
            manifest={name: blob_id for name, (blob_id, _) in tracks.items()},
        )
        _write_ref(repo_root, "main", head.commit_id)

        work_tree = repo_root / "muse-work"
        work_tree.mkdir()
        for name in tracks:
            (work_tree / name).write_bytes(b"stale")

        outcome = await perform_restore(
            root=repo_root,
            session=async_session,
            paths=list(tracks),
            source_ref=None,
            staged=False,
        )

        assert sorted(outcome.paths_restored) == ["bass.mid", "drums.mid"]
        for name, (_, data) in tracks.items():
            assert (work_tree / name).read_bytes() == data

    @pytest.mark.anyio
    async def test_muse_restore_muse_work_prefix_stripped(
        self,
        async_session: AsyncSession,
        repo_id: str,
        repo_root: pathlib.Path,
    ) -> None:
        """A path given as ``muse-work/<file>`` is reported and restored as ``<file>``."""
        blob_id = "33" * 32
        expected = b"prefix test"
        _seed_object_store(repo_root, blob_id, expected)

        head = await _add_commit(
            async_session,
            repo_id=repo_id,
            manifest={"lead.mid": blob_id},
        )
        _write_ref(repo_root, "main", head.commit_id)

        work_tree = repo_root / "muse-work"
        work_tree.mkdir()
        target = work_tree / "lead.mid"
        target.write_bytes(b"old")

        outcome = await perform_restore(
            root=repo_root,
            session=async_session,
            paths=["muse-work/lead.mid"],  # caller includes the work-tree prefix
            source_ref=None,
            staged=False,
        )

        assert outcome.paths_restored == ["lead.mid"]
        assert target.read_bytes() == expected
384
385
386 # ---------------------------------------------------------------------------
387 # Error handling
388 # ---------------------------------------------------------------------------
389
390
class TestRestoreErrors:
    """Failure modes of ``perform_restore``: raised errors and exit codes."""

    @pytest.mark.anyio
    async def test_muse_restore_errors_on_missing_path(
        self,
        async_session: AsyncSession,
        repo_id: str,
        repo_root: pathlib.Path,
    ) -> None:
        """A path missing from HEAD's manifest raises ``PathNotInSnapshotError``."""
        blob_id = "44" * 32
        _seed_object_store(repo_root, blob_id, b"only track")

        head = await _add_commit(
            async_session,
            repo_id=repo_id,
            manifest={"track.mid": blob_id},
        )
        _write_ref(repo_root, "main", head.commit_id)

        with pytest.raises(PathNotInSnapshotError) as exc_info:
            await perform_restore(
                root=repo_root,
                session=async_session,
                paths=["nonexistent.mid"],
                source_ref=None,
                staged=False,
            )

        err = exc_info.value
        assert "nonexistent.mid" in str(err)
        assert err.rel_path == "nonexistent.mid"
        assert err.source_commit_id == head.commit_id

    @pytest.mark.anyio
    async def test_muse_restore_source_missing_path(
        self,
        async_session: AsyncSession,
        repo_id: str,
        repo_root: pathlib.Path,
    ) -> None:
        """A path that exists only at HEAD is rejected when an older source lacks it."""
        utc = datetime.timezone.utc

        old_blob = "55" * 32
        _seed_object_store(repo_root, old_blob, b"old track")
        older = await _add_commit(
            async_session,
            repo_id=repo_id,
            message="old",
            manifest={"old_track.mid": old_blob},
            committed_at=datetime.datetime(2024, 1, 1, tzinfo=utc),
        )

        new_blob = "66" * 32
        _seed_object_store(repo_root, new_blob, b"new track")
        newer = await _add_commit(
            async_session,
            repo_id=repo_id,
            message="new",
            manifest={"new_track.mid": new_blob},
            parent_commit_id=older.commit_id,
            committed_at=datetime.datetime(2024, 1, 2, tzinfo=utc),
        )
        _write_ref(repo_root, "main", newer.commit_id)

        with pytest.raises(PathNotInSnapshotError) as exc_info:
            await perform_restore(
                root=repo_root,
                session=async_session,
                paths=["new_track.mid"],  # present at HEAD, absent from the source
                source_ref=older.commit_id,
                staged=False,
            )

        assert "new_track.mid" in str(exc_info.value)

    @pytest.mark.anyio
    async def test_muse_restore_missing_object_store(
        self,
        async_session: AsyncSession,
        repo_id: str,
        repo_root: pathlib.Path,
    ) -> None:
        """A manifest entry whose blob was never stored raises ``MissingObjectError``."""
        # Deliberately no _seed_object_store call for this ID
        absent_blob = "77" * 32

        head = await _add_commit(
            async_session,
            repo_id=repo_id,
            manifest={"lead.mid": absent_blob},
        )
        _write_ref(repo_root, "main", head.commit_id)

        with pytest.raises(MissingObjectError) as exc_info:
            await perform_restore(
                root=repo_root,
                session=async_session,
                paths=["lead.mid"],
                source_ref=None,
                staged=False,
            )

        short_id = absent_blob[:8]
        assert short_id in str(exc_info.value)

    @pytest.mark.anyio
    async def test_muse_restore_ref_not_found(
        self,
        async_session: AsyncSession,
        repo_id: str,
        repo_root: pathlib.Path,
    ) -> None:
        """A source ref matching no commit raises ``typer.Exit`` with USER_ERROR."""
        import typer

        blob_id = "88" * 32
        _seed_object_store(repo_root, blob_id, b"content")
        head = await _add_commit(
            async_session,
            repo_id=repo_id,
            manifest={"track.mid": blob_id},
        )
        _write_ref(repo_root, "main", head.commit_id)

        with pytest.raises(typer.Exit) as exc_info:
            await perform_restore(
                root=repo_root,
                session=async_session,
                paths=["track.mid"],
                source_ref="deadbeef1234",
                staged=False,
            )

        exit_code = exc_info.value.exit_code
        assert exit_code == ExitCode.USER_ERROR

    @pytest.mark.anyio
    async def test_muse_restore_no_commits(
        self,
        async_session: AsyncSession,
        repo_id: str,
        repo_root: pathlib.Path,
    ) -> None:
        """Restoring while the branch ref is empty raises ``typer.Exit`` with USER_ERROR."""
        import typer

        # The repo_root fixture leaves refs/heads/main empty, so no commit exists
        with pytest.raises(typer.Exit) as exc_info:
            await perform_restore(
                root=repo_root,
                session=async_session,
                paths=["track.mid"],
                source_ref=None,
                staged=False,
            )

        exit_code = exc_info.value.exit_code
        assert exit_code == ExitCode.USER_ERROR
548
549
550 # ---------------------------------------------------------------------------
551 # RestoreResult type
552 # ---------------------------------------------------------------------------
553
554
class TestRestoreResult:
    """Construction, defaults, and immutability of ``RestoreResult``."""

    def test_muse_restore_result_fields(self) -> None:
        """Values passed to the constructor are readable back unchanged."""
        r = RestoreResult(
            source_commit_id="a" * 64,
            paths_restored=["bass.mid", "drums.mid"],
            staged=True,
        )
        assert r.source_commit_id == "a" * 64
        assert r.paths_restored == ["bass.mid", "drums.mid"]
        assert r.staged is True

    def test_muse_restore_result_frozen(self) -> None:
        """Assigning to a field is rejected and leaves the value unchanged."""
        r = RestoreResult(source_commit_id="b" * 64)
        # A frozen dataclass raises dataclasses.FrozenInstanceError, a subclass
        # of AttributeError.  Catching bare Exception would let any unrelated
        # failure (e.g. a typo in this test) pass as "immutable".
        with pytest.raises(AttributeError):
            r.staged = True  # type: ignore[misc]
        assert r.staged is False

    def test_muse_restore_result_defaults(self) -> None:
        """Omitted fields default to an empty path list and ``staged=False``."""
        r = RestoreResult(source_commit_id="c" * 64)
        assert r.paths_restored == []
        assert r.staged is False
579
580
581 # ---------------------------------------------------------------------------
582 # PathNotInSnapshotError type
583 # ---------------------------------------------------------------------------
584
585
class TestPathNotInSnapshotError:
    """String rendering of ``PathNotInSnapshotError``."""

    def test_error_message_contains_path_and_commit(self) -> None:
        """The message mentions the missing path and the start of the commit ID."""
        missing_path = "bass/bassline.mid"
        commit_id = "abcd1234" * 8
        rendered = str(PathNotInSnapshotError(missing_path, commit_id))
        assert missing_path in rendered
        assert "abcd1234" in rendered
594
595
596 # ---------------------------------------------------------------------------
597 # Boundary seal — AST checks
598 # ---------------------------------------------------------------------------
599
600
class TestBoundarySeals:
    """Static (AST-level) checks on the restore service and CLI command sources."""

    def _parse(self, rel_path: str) -> ast.Module:
        """Parse the file at *rel_path*, resolved against this file's grandparent directory."""
        root = pathlib.Path(__file__).resolve().parent.parent
        # Read as UTF-8 explicitly instead of relying on the locale encoding
        return ast.parse((root / rel_path).read_text(encoding="utf-8"))

    @staticmethod
    def _imported_modules(tree: ast.Module) -> list[str]:
        """Return every module path named by an ``import`` or ``from`` statement in *tree*.

        Relative ``from . import x`` statements have no module path and are skipped.
        """
        modules: list[str] = []
        for node in ast.walk(tree):
            if isinstance(node, ast.ImportFrom) and node.module:
                modules.append(node.module)
            elif isinstance(node, ast.Import):
                modules.extend(alias.name for alias in node.names)
        return modules

    @staticmethod
    def _has_future_annotations(tree: ast.Module) -> bool:
        """Return True if *tree* has a top-level ``from __future__ import annotations``."""
        # Iterate tree.body, not ast.walk: walk's ordering is unspecified, and
        # __future__ imports are only valid at module top level anyway.
        return any(
            isinstance(node, ast.ImportFrom)
            and node.module == "__future__"
            and any(alias.name == "annotations" for alias in node.names)
            for node in tree.body
        )

    def test_boundary_no_forbidden_imports(self) -> None:
        """muse_restore service must not import executor, state_store, mcp, or maestro_handlers."""
        tree = self._parse("maestro/services/muse_restore.py")
        forbidden = ("state_store", "executor", "maestro_handlers", "mcp")
        # Covers both `from pkg.mod import x` and plain `import pkg.mod`
        for module in self._imported_modules(tree):
            for fb in forbidden:
                assert fb not in module, (
                    f"muse_restore imports forbidden module: {module}"
                )

    def test_restore_service_has_future_import(self) -> None:
        """muse_restore.py contains 'from __future__ import annotations'."""
        tree = self._parse("maestro/services/muse_restore.py")
        assert self._has_future_annotations(tree)

    def test_restore_command_has_future_import(self) -> None:
        """restore.py CLI command contains 'from __future__ import annotations'."""
        tree = self._parse("maestro/muse_cli/commands/restore.py")
        assert self._has_future_annotations(tree)