cgcardona / muse public
test_recall.py python
491 lines 14.4 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """Tests for ``muse recall``.
2
3 All async tests call ``_recall_async`` directly with an in-memory SQLite
4 session and a ``tmp_path`` repo root — no real Postgres or running process
5 required. Commits are seeded via ``_commit_async`` so the two commands
6 are exercised as an integrated pair.
7
8 Covers:
9 - Keyword match returns top-N results sorted by score (highest first).
10 - ``--limit`` restricts result count.
11 - ``--threshold`` filters low-scoring commits.
12 - ``--since`` / ``--until`` date filters.
13 - ``--json`` emits valid JSON with the expected schema.
14 - Query with zero matches returns empty result set (not an error).
15 - CLI invocation outside a repo exits with code 2.
16 - ``--since`` / ``--until`` with bad date strings exits with code 1.
17 """
18 from __future__ import annotations
19
20 import json
21 import os
22 import pathlib
23 import uuid
24
25 import pytest
26 from sqlalchemy.ext.asyncio import AsyncSession
27
28 from maestro.muse_cli.commands.commit import _commit_async
29 from maestro.muse_cli.commands.recall import _recall_async, _score, _tokenize
30 from maestro.muse_cli.errors import ExitCode
31
32
33 # ---------------------------------------------------------------------------
34 # Repo / workdir helpers (mirror test_log.py pattern)
35 # ---------------------------------------------------------------------------
36
37
38 def _init_muse_repo(root: pathlib.Path, repo_id: str | None = None) -> str:
39 rid = repo_id or str(uuid.uuid4())
40 muse = root / ".muse"
41 (muse / "refs" / "heads").mkdir(parents=True)
42 (muse / "repo.json").write_text(
43 json.dumps({"repo_id": rid, "schema_version": "1"})
44 )
45 (muse / "HEAD").write_text("refs/heads/main")
46 (muse / "refs" / "heads" / "main").write_text("")
47 return rid
48
49
50 def _write_workdir(root: pathlib.Path, files: dict[str, bytes]) -> None:
51 workdir = root / "muse-work"
52 workdir.mkdir(exist_ok=True)
53 for name, content in files.items():
54 (workdir / name).write_bytes(content)
55
56
async def _make_commits(
    root: pathlib.Path,
    session: AsyncSession,
    messages: list[str],
    file_seed: int = 0,
) -> list[str]:
    """Seed one commit per entry in *messages* and return their ids in order.

    Before each commit a new ``track_<n>.mid`` file is written to the work
    directory, where ``n`` counts up from *file_seed*, so every commit sees
    different content.
    """
    created: list[str] = []
    for index, message in enumerate(messages, start=file_seed):
        payload = f"MIDI-{index}".encode()
        _write_workdir(root, {f"track_{index}.mid": payload})
        created.append(
            await _commit_async(message=message, root=root, session=session)
        )
    return created
70
71
72 # ---------------------------------------------------------------------------
73 # Unit tests for scoring helpers
74 # ---------------------------------------------------------------------------
75
76
class TestScoringHelpers:
    """DB-free unit checks for the ``_tokenize`` and ``_score`` helpers."""

    def test_tokenize_splits_on_whitespace(self) -> None:
        """Space-separated words come back as individual tokens."""
        expected = {"dark", "jazz", "bassline"}
        assert _tokenize("dark jazz bassline") == expected

    def test_tokenize_is_lowercase(self) -> None:
        """Mixed-case input is folded to lowercase tokens."""
        expected = {"dark", "jazz"}
        assert _tokenize("DARK Jazz") == expected

    def test_tokenize_ignores_punctuation(self) -> None:
        """Commas, exclamation marks and hyphens do not end up in tokens."""
        tokens = _tokenize("boom, bap! drum-fill")
        for word in ("boom", "bap", "drum", "fill"):
            assert word in tokens

    def test_score_full_match_returns_one(self) -> None:
        """Every query token present in the message yields exactly 1.0."""
        query_tokens = _tokenize("jazz bassline")
        assert _score(query_tokens, "this is a jazz bassline") == 1.0

    def test_score_no_match_returns_zero(self) -> None:
        """No shared tokens yields exactly 0.0."""
        query_tokens = _tokenize("jazz bassline")
        assert _score(query_tokens, "rock guitar solo") == 0.0

    def test_score_partial_match(self) -> None:
        """One of three query tokens matching scores one third."""
        result = _score(_tokenize("jazz drum fill"), "a cool jazz moment")
        assert 0.0 < result < 1.0
        assert result == pytest.approx(1 / 3, rel=1e-6)

    def test_score_empty_query_returns_zero(self) -> None:
        """An empty token set scores 0.0 against any message."""
        assert _score(set(), "anything") == 0.0
112
113
114 # ---------------------------------------------------------------------------
115 # test_recall_returns_top_n_by_score
116 # ---------------------------------------------------------------------------
117
118
@pytest.mark.anyio
async def test_recall_returns_top_n_by_score(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """The commit sharing the most query keywords is ranked first.

    Of the four seeded commits only "boom bap jazz fusion groove" contains
    all three query words, so it must lead with a score of 1.0; the full
    score list must be in descending order.
    """
    _init_muse_repo(tmp_path)
    seeded_messages = [
        "boom bap drum pattern",
        "jazz piano chord voicing",
        "boom bap jazz fusion groove",
        "classical string quartet",
    ]
    await _make_commits(tmp_path, muse_cli_db_session, seeded_messages)

    results = await _recall_async(
        root=tmp_path,
        session=muse_cli_db_session,
        query="boom bap jazz",
        limit=5,
        threshold=0.0,
        branch=None,
        since=None,
        until=None,
        as_json=False,
    )

    assert len(results) > 0
    best = results[0]
    assert best["message"] == "boom bap jazz fusion groove"
    assert best["score"] == pytest.approx(1.0)

    # Results must already arrive sorted from highest to lowest score.
    observed = [hit["score"] for hit in results]
    assert observed == sorted(observed, reverse=True)
156
157
158 # ---------------------------------------------------------------------------
159 # test_recall_limit_restricts_result_count
160 # ---------------------------------------------------------------------------
161
162
@pytest.mark.anyio
async def test_recall_limit_restricts_result_count(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """With four matching commits, ``limit=2`` yields exactly two results."""
    _init_muse_repo(tmp_path)
    jazz_messages = [
        "jazz piano solo",
        "jazz drum groove",
        "jazz bass walk",
        "jazz chord progression",
    ]
    await _make_commits(tmp_path, muse_cli_db_session, jazz_messages)

    hits = await _recall_async(
        root=tmp_path,
        session=muse_cli_db_session,
        query="jazz",
        limit=2,
        threshold=0.0,
        branch=None,
        since=None,
        until=None,
        as_json=False,
    )

    assert len(hits) == 2
195
196
197 # ---------------------------------------------------------------------------
198 # test_recall_threshold_filters_low_scores
199 # ---------------------------------------------------------------------------
200
201
@pytest.mark.anyio
async def test_recall_threshold_filters_low_scores(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """Only commits scoring at or above ``threshold`` are returned.

    The "jazz" query is run with ``threshold=0.6`` against one commit that
    mentions jazz and one that does not; only the former may appear.
    """
    _init_muse_repo(tmp_path)
    seeded_messages = ["jazz piano solo", "classical strings"]
    await _make_commits(tmp_path, muse_cli_db_session, seeded_messages)

    results = await _recall_async(
        root=tmp_path,
        session=muse_cli_db_session,
        query="jazz",
        limit=5,
        threshold=0.6,
        branch=None,
        since=None,
        until=None,
        as_json=False,
    )

    assert all(hit["score"] >= 0.6 for hit in results)
    returned_messages = [hit["message"] for hit in results]
    assert "jazz piano solo" in returned_messages
    assert "classical strings" not in returned_messages
235
236
237 # ---------------------------------------------------------------------------
238 # test_recall_no_matches_returns_empty
239 # ---------------------------------------------------------------------------
240
241
@pytest.mark.anyio
async def test_recall_no_matches_returns_empty(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """A query with no hits returns ``[]`` and prints a "not found" notice."""
    _init_muse_repo(tmp_path)
    await _make_commits(tmp_path, muse_cli_db_session, ["rock guitar riff"])

    hits = await _recall_async(
        root=tmp_path,
        session=muse_cli_db_session,
        query="jazz bassline",
        limit=5,
        threshold=0.6,
        branch=None,
        since=None,
        until=None,
        as_json=False,
    )

    assert hits == []
    captured = capsys.readouterr()
    assert "No matching commits found" in captured.out
267
268
269 # ---------------------------------------------------------------------------
270 # test_recall_json_output_valid_schema
271 # ---------------------------------------------------------------------------
272
273
@pytest.mark.anyio
async def test_recall_json_output_valid_schema(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """With ``as_json=True`` stdout is a JSON list of result objects.

    The first entry must carry the ``rank``, ``score``, ``commit_id``,
    ``date``, ``branch`` and ``message`` keys, with rank 1, a float score
    and branch ``"main"``.
    """
    _init_muse_repo(tmp_path)
    await _make_commits(tmp_path, muse_cli_db_session, ["jazz drum groove"])
    # Drop whatever the commit step printed so only recall output is parsed.
    capsys.readouterr()

    await _recall_async(
        root=tmp_path,
        session=muse_cli_db_session,
        query="jazz drum",
        limit=5,
        threshold=0.0,
        branch=None,
        since=None,
        until=None,
        as_json=True,
    )

    payload = json.loads(capsys.readouterr().out)
    assert isinstance(payload, list)
    assert len(payload) >= 1

    first = payload[0]
    for field in ("rank", "score", "commit_id", "date", "branch", "message"):
        assert field in first
    assert first["rank"] == 1
    assert isinstance(first["score"], float)
    assert first["branch"] == "main"
312
313
314 # ---------------------------------------------------------------------------
315 # test_recall_rank_field_is_sequential
316 # ---------------------------------------------------------------------------
317
318
@pytest.mark.anyio
async def test_recall_rank_field_is_sequential(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """``rank`` starts at 1 and increments by one per result.

    Three commits that all mention "jazz" are seeded and recalled with
    ``threshold=0.0``; the ``rank`` values of the returned entries must be
    ``1, 2, ..., len(results)``.
    """
    _init_muse_repo(tmp_path)
    await _make_commits(
        tmp_path,
        muse_cli_db_session,
        ["jazz piano", "jazz drums", "jazz bass"],
    )

    results = await _recall_async(
        root=tmp_path,
        session=muse_cli_db_session,
        query="jazz",
        limit=5,
        threshold=0.0,
        branch=None,
        since=None,
        until=None,
        as_json=False,
    )

    # Guard against a vacuous pass: with no results the comparison below
    # degenerates to ``[] == []`` and would succeed without checking any rank.
    assert results, "expected at least one match for the seeded jazz commits"
    assert [r["rank"] for r in results] == list(range(1, len(results) + 1))
346
347
348 # ---------------------------------------------------------------------------
349 # test_recall_since_filter_excludes_older_commits
350 # ---------------------------------------------------------------------------
351
352
@pytest.mark.anyio
async def test_recall_since_filter_excludes_older_commits(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """A ``since`` bound of 2099-01-01 (UTC) filters out the seeded commit."""
    from datetime import datetime, timezone

    _init_muse_repo(tmp_path)
    await _make_commits(tmp_path, muse_cli_db_session, ["jazz rhythm section"])

    far_future = datetime(2099, 1, 1, tzinfo=timezone.utc)
    hits = await _recall_async(
        root=tmp_path,
        session=muse_cli_db_session,
        query="jazz",
        limit=5,
        threshold=0.0,
        branch=None,
        since=far_future,
        until=None,
        as_json=False,
    )

    assert hits == []
380
381
382 # ---------------------------------------------------------------------------
383 # test_recall_until_filter_excludes_newer_commits
384 # ---------------------------------------------------------------------------
385
386
@pytest.mark.anyio
async def test_recall_until_filter_excludes_newer_commits(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """An ``until`` bound of 2000-01-01 (UTC) filters out the seeded commit."""
    from datetime import datetime, timezone

    _init_muse_repo(tmp_path)
    await _make_commits(tmp_path, muse_cli_db_session, ["jazz rhythm section"])

    long_ago = datetime(2000, 1, 1, tzinfo=timezone.utc)
    hits = await _recall_async(
        root=tmp_path,
        session=muse_cli_db_session,
        query="jazz",
        limit=5,
        threshold=0.0,
        branch=None,
        since=None,
        until=long_ago,
        as_json=False,
    )

    assert hits == []
414
415
416 # ---------------------------------------------------------------------------
417 # test_recall_outside_repo_exits_2
418 # ---------------------------------------------------------------------------
419
420
def test_recall_outside_repo_exits_2(tmp_path: pathlib.Path) -> None:
    """Running ``muse recall`` from a directory without ``.muse/`` fails.

    The CLI is invoked with *tmp_path* as the working directory and must
    exit with ``ExitCode.REPO_NOT_FOUND``.
    """
    from typer.testing import CliRunner
    from maestro.muse_cli.app import cli

    cli_runner = CliRunner()
    original_cwd = os.getcwd()
    try:
        os.chdir(tmp_path)
        outcome = cli_runner.invoke(
            cli,
            ["recall", "jazz"],
            catch_exceptions=False,
        )
    finally:
        # Always restore the working directory so later tests are unaffected.
        os.chdir(original_cwd)

    assert outcome.exit_code == ExitCode.REPO_NOT_FOUND
435
436
437 # ---------------------------------------------------------------------------
438 # test_recall_bad_since_date_exits_1
439 # ---------------------------------------------------------------------------
440
441
def test_recall_bad_since_date_exits_1() -> None:
    """A malformed ``--since`` value makes the command exit with a user error.

    The command function is called directly with ``since="not-a-date"`` and
    must raise ``typer.Exit`` carrying ``ExitCode.USER_ERROR``. No repo is
    set up; the test relies on date validation happening before repo
    discovery.
    """
    import typer
    from maestro.muse_cli.commands.recall import recall as recall_cmd

    call_args = dict(
        ctx=None,
        query="jazz",
        limit=5,
        threshold=0.6,
        branch=None,
        since="not-a-date",
        until=None,
        as_json=False,
    )
    with pytest.raises(typer.Exit) as raised:
        recall_cmd(**call_args)  # type: ignore[arg-type]

    assert raised.value.exit_code == ExitCode.USER_ERROR
463
464
465
466 # ---------------------------------------------------------------------------
467 # test_recall_bad_until_date_exits_1
468 # ---------------------------------------------------------------------------
469
470
def test_recall_bad_until_date_exits_1() -> None:
    """A malformed ``--until`` value makes the command exit with a user error.

    The command function is called directly with ``until="2026/01/01"``
    (slashes instead of dashes) and must raise ``typer.Exit`` carrying
    ``ExitCode.USER_ERROR``. No repo is set up; the test relies on date
    validation happening before repo discovery.
    """
    import typer
    from maestro.muse_cli.commands.recall import recall as recall_cmd

    call_args = dict(
        ctx=None,
        query="jazz",
        limit=5,
        threshold=0.6,
        branch=None,
        since=None,
        until="2026/01/01",
        as_json=False,
    )
    with pytest.raises(typer.Exit) as raised:
        recall_cmd(**call_args)  # type: ignore[arg-type]

    assert raised.value.exit_code == ExitCode.USER_ERROR