test_core_store.py
python
| 1 | """Tests for muse.core.store — file-based commit and snapshot storage.""" |
| 2 | from __future__ import annotations |
| 3 | |
| 4 | import datetime |
| 5 | import json |
| 6 | import pathlib |
| 7 | |
| 8 | import pytest |
| 9 | |
| 10 | from muse.core.store import ( |
| 11 | CommitDict, |
| 12 | CommitRecord, |
| 13 | SnapshotRecord, |
| 14 | TagRecord, |
| 15 | find_commits_by_prefix, |
| 16 | get_all_commits, |
| 17 | get_all_tags, |
| 18 | get_commits_for_branch, |
| 19 | get_head_commit_id, |
| 20 | get_head_snapshot_id, |
| 21 | get_head_snapshot_manifest, |
| 22 | get_tags_for_commit, |
| 23 | read_commit, |
| 24 | read_snapshot, |
| 25 | update_commit_metadata, |
| 26 | write_commit, |
| 27 | write_snapshot, |
| 28 | write_tag, |
| 29 | ) |
| 30 | |
| 31 | |
@pytest.fixture
def repo(tmp_path: pathlib.Path) -> pathlib.Path:
    """Lay out a bare-bones ``.muse/`` repository under *tmp_path* and return *tmp_path*.

    Created entries:
    - empty ``commits/``, ``snapshots/`` and ``refs/heads/`` directories;
    - ``repo.json`` holding ``{"repo_id": "test-repo"}``;
    - ``HEAD`` containing ``refs/heads/main`` plus a trailing newline;
    - an empty ``refs/heads/main`` ref file.
    """
    dot_muse = tmp_path / ".muse"
    heads_dir = dot_muse / "refs" / "heads"
    for subdir in (dot_muse / "commits", dot_muse / "snapshots", heads_dir):
        subdir.mkdir(parents=True)
    (dot_muse / "repo.json").write_text(json.dumps({"repo_id": "test-repo"}))
    (dot_muse / "HEAD").write_text("refs/heads/main\n")
    (heads_dir / "main").write_text("")
    return tmp_path
| 43 | |
| 44 | |
def _make_commit(
    root: pathlib.Path,
    commit_id: str,
    snapshot_id: str,
    message: str,
    parent: str | None = None,
) -> CommitRecord:
    """Build a ``main``-branch commit for ``test-repo``, write it under *root*, and return it.

    ``committed_at`` is stamped with the current UTC time, and *parent* is
    stored as ``parent_commit_id`` (``None`` for a root commit).
    """
    record = CommitRecord(
        commit_id=commit_id,
        repo_id="test-repo",
        branch="main",
        snapshot_id=snapshot_id,
        message=message,
        committed_at=datetime.datetime.now(datetime.timezone.utc),
        parent_commit_id=parent,
    )
    write_commit(root, record)
    return record
| 57 | |
| 58 | |
def _make_snapshot(root: pathlib.Path, snapshot_id: str, manifest: dict[str, str]) -> SnapshotRecord:
    """Write a snapshot with the given *manifest* under *root* and return the record.

    In these tests *manifest* maps file paths to hash strings.
    """
    snapshot = SnapshotRecord(snapshot_id=snapshot_id, manifest=manifest)
    write_snapshot(root, snapshot)
    return snapshot
| 63 | |
| 64 | |
class TestFormatVersion:
    """Schema-evolution checks for ``CommitRecord.format_version``.

    Newly created records carry version 4, and serialised records that omit
    the field are read back as version 1.
    """

    def test_new_commit_has_format_version_4(self, repo: pathlib.Path) -> None:
        created = _make_commit(repo, "abc123", "snap1", "msg")
        assert created.format_version == 4

    def test_format_version_round_trips_through_json(self, repo: pathlib.Path) -> None:
        _make_commit(repo, "abc123", "snap1", "msg")
        reloaded = read_commit(repo, "abc123")
        assert reloaded is not None
        assert reloaded.format_version == 4

    def test_format_version_in_serialised_dict(self) -> None:
        """``to_dict`` emits the ``format_version`` key with the current version."""
        now = datetime.datetime.now(datetime.timezone.utc)
        serialised = CommitRecord(
            commit_id="x",
            repo_id="r",
            branch="main",
            snapshot_id="s",
            message="m",
            committed_at=now,
        ).to_dict()
        assert "format_version" in serialised
        assert serialised["format_version"] == 4

    def test_missing_format_version_defaults_to_1(self) -> None:
        """Existing JSON without format_version field deserialises as version 1."""
        legacy = CommitDict(
            commit_id="abc",
            repo_id="r",
            branch="main",
            snapshot_id="s",
            message="old record",
            committed_at="2025-01-01T00:00:00+00:00",
        )
        assert CommitRecord.from_dict(legacy).format_version == 1

    def test_explicit_format_version_preserved(self) -> None:
        """An explicit ``format_version`` in the input dict survives ``from_dict``."""
        versioned = CommitDict(
            commit_id="abc",
            repo_id="r",
            branch="main",
            snapshot_id="s",
            message="versioned record",
            committed_at="2025-01-01T00:00:00+00:00",
            format_version=2,
        )
        assert CommitRecord.from_dict(versioned).format_version == 2

    def test_format_version_field_is_integer(self, repo: pathlib.Path) -> None:
        _make_commit(repo, "abc123", "snap1", "msg")
        reloaded = read_commit(repo, "abc123")
        assert reloaded is not None
        assert isinstance(reloaded.format_version, int)
| 122 | |
| 123 | |
class TestWriteReadCommit:
    """Round-trip behaviour of ``write_commit`` / ``read_commit``."""

    def test_roundtrip(self, repo: pathlib.Path) -> None:
        written = _make_commit(repo, "abc123", "snap1", "Initial commit")
        loaded = read_commit(repo, "abc123")
        assert loaded is not None
        # Compare against the record that was actually written (previously the
        # returned record was bound and then ignored), so a field dropped or
        # mangled on the way through storage is caught.
        assert loaded.commit_id == written.commit_id == "abc123"
        assert loaded.message == written.message == "Initial commit"
        assert loaded.repo_id == written.repo_id == "test-repo"
        assert loaded.snapshot_id == written.snapshot_id == "snap1"

    def test_read_missing_returns_none(self, repo: pathlib.Path) -> None:
        assert read_commit(repo, "nonexistent") is None

    def test_idempotent_write(self, repo: pathlib.Path) -> None:
        """A second write with an existing commit id must not replace the stored record."""
        _make_commit(repo, "abc123", "snap1", "First")
        _make_commit(repo, "abc123", "snap1", "Second")  # Should not overwrite
        loaded = read_commit(repo, "abc123")
        assert loaded is not None
        assert loaded.message == "First"

    def test_metadata_preserved(self, repo: pathlib.Path) -> None:
        """Arbitrary ``metadata`` entries survive a write/read round trip."""
        c = CommitRecord(
            commit_id="abc123",
            repo_id="test-repo",
            branch="main",
            snapshot_id="snap1",
            message="With metadata",
            committed_at=datetime.datetime.now(datetime.timezone.utc),
            metadata={"section": "chorus", "emotion": "joyful"},
        )
        write_commit(repo, c)
        loaded = read_commit(repo, "abc123")
        assert loaded is not None
        assert loaded.metadata["section"] == "chorus"
        assert loaded.metadata["emotion"] == "joyful"
| 158 | |
| 159 | |
class TestUpdateCommitMetadata:
    """``update_commit_metadata`` sets one metadata key on a stored commit and reports success."""

    def test_set_key(self, repo: pathlib.Path) -> None:
        _make_commit(repo, "abc123", "snap1", "msg")
        updated = update_commit_metadata(repo, "abc123", "tempo_bpm", 120.0)
        assert updated is True
        reloaded = read_commit(repo, "abc123")
        assert reloaded is not None
        assert reloaded.metadata["tempo_bpm"] == 120.0

    def test_missing_commit_returns_false(self, repo: pathlib.Path) -> None:
        outcome = update_commit_metadata(repo, "missing", "k", "v")
        assert outcome is False
| 171 | |
| 172 | |
class TestWriteReadSnapshot:
    """Round-trip behaviour of ``write_snapshot`` / ``read_snapshot``."""

    def test_roundtrip(self, repo: pathlib.Path) -> None:
        written = _make_snapshot(repo, "snap1", {"tracks/drums.mid": "deadbeef"})
        loaded = read_snapshot(repo, "snap1")
        assert loaded is not None
        # Check the loaded record against what was written instead of leaving
        # the returned record unused.
        assert loaded.snapshot_id == written.snapshot_id
        assert loaded.manifest == written.manifest == {"tracks/drums.mid": "deadbeef"}

    def test_read_missing_returns_none(self, repo: pathlib.Path) -> None:
        assert read_snapshot(repo, "nonexistent") is None
| 182 | |
| 183 | |
class TestHeadQueries:
    """HEAD-resolution helpers, driven by the contents of the ``refs/heads/main`` file."""

    @staticmethod
    def _point_main_at(root: pathlib.Path, commit_id: str) -> None:
        """Overwrite the ``main`` ref file so that it names *commit_id*."""
        (root / ".muse" / "refs" / "heads" / "main").write_text(commit_id)

    def test_get_head_commit_id_empty_branch(self, repo: pathlib.Path) -> None:
        # The fixture leaves the main ref file empty.
        assert get_head_commit_id(repo, "main") is None

    def test_get_head_commit_id(self, repo: pathlib.Path) -> None:
        self._point_main_at(repo, "abc123")
        assert get_head_commit_id(repo, "main") == "abc123"

    def test_get_head_snapshot_id(self, repo: pathlib.Path) -> None:
        _make_commit(repo, "abc123", "snap1", "msg")
        _make_snapshot(repo, "snap1", {"f.mid": "hash1"})
        self._point_main_at(repo, "abc123")
        assert get_head_snapshot_id(repo, "test-repo", "main") == "snap1"

    def test_get_head_snapshot_manifest(self, repo: pathlib.Path) -> None:
        _make_commit(repo, "abc123", "snap1", "msg")
        _make_snapshot(repo, "snap1", {"f.mid": "hash1"})
        self._point_main_at(repo, "abc123")
        head_manifest = get_head_snapshot_manifest(repo, "test-repo", "main")
        assert head_manifest == {"f.mid": "hash1"}
| 204 | |
| 205 | |
class TestGetCommitsForBranch:
    """``get_commits_for_branch`` lists a branch's commits starting from its tip."""

    def test_chain(self, repo: pathlib.Path) -> None:
        # (commit_id, snapshot_id, message, parent) from oldest to newest.
        lineage = [
            ("root", "snap0", "Root", None),
            ("child", "snap1", "Child", "root"),
            ("grandchild", "snap2", "Grandchild", "child"),
        ]
        for cid, sid, msg, parent_id in lineage:
            _make_commit(repo, cid, sid, msg, parent=parent_id)
        (repo / ".muse" / "refs" / "heads" / "main").write_text("grandchild")

        history = get_commits_for_branch(repo, "test-repo", "main")
        assert [entry.commit_id for entry in history] == ["grandchild", "child", "root"]

    def test_empty_branch(self, repo: pathlib.Path) -> None:
        history = get_commits_for_branch(repo, "test-repo", "main")
        assert history == []
| 218 | |
| 219 | |
class TestFindByPrefix:
    """``find_commits_by_prefix`` resolves abbreviated commit ids."""

    def test_finds_match(self, repo: pathlib.Path) -> None:
        _make_commit(repo, "abcdef1234", "snap1", "msg")
        matches = find_commits_by_prefix(repo, "abcdef")
        assert len(matches) == 1
        (only_match,) = matches
        assert only_match.commit_id == "abcdef1234"

    def test_no_match(self, repo: pathlib.Path) -> None:
        matches = find_commits_by_prefix(repo, "zzz")
        assert matches == []
| 229 | |
| 230 | |
class TestTags:
    """Writing tags with ``write_tag`` and querying them per commit or per repo."""

    def test_write_and_read(self, repo: pathlib.Path) -> None:
        _make_commit(repo, "abc123", "snap1", "msg")
        joyful = TagRecord(
            tag_id="tag1",
            repo_id="test-repo",
            commit_id="abc123",
            tag="emotion:joyful",
        )
        write_tag(repo, joyful)
        found = get_tags_for_commit(repo, "test-repo", "abc123")
        assert len(found) == 1
        assert found[0].tag == "emotion:joyful"

    def test_get_all_tags(self, repo: pathlib.Path) -> None:
        _make_commit(repo, "abc123", "snap1", "msg")
        for tag_id, label in (("t1", "stage:rough-mix"), ("t2", "key:Am")):
            write_tag(repo, TagRecord(tag_id=tag_id, repo_id="test-repo", commit_id="abc123", tag=label))
        every_tag = get_all_tags(repo, "test-repo")
        assert len(every_tag) == 2