test_recall.py
python
| 1 | """Tests for ``muse recall``. |
| 2 | |
| 3 | All async tests call ``_recall_async`` directly with an in-memory SQLite |
| 4 | session and a ``tmp_path`` repo root — no real Postgres or running process |
| 5 | required. Commits are seeded via ``_commit_async`` so the two commands |
| 6 | are exercised as an integrated pair. |
| 7 | |
| 8 | Covers: |
| 9 | - Keyword match returns top-N results sorted by score (highest first). |
| 10 | - ``--limit`` restricts result count. |
| 11 | - ``--threshold`` filters low-scoring commits. |
| 12 | - ``--since`` / ``--until`` date filters. |
| 13 | - ``--json`` emits valid JSON with the expected schema. |
| 14 | - Query with zero matches returns empty result set (not an error). |
| 15 | - CLI invocation outside a repo exits with code 2. |
| 16 | - ``--since`` / ``--until`` with bad date strings exits with code 1. |
| 17 | """ |
| 18 | from __future__ import annotations |
| 19 | |
| 20 | import json |
| 21 | import os |
| 22 | import pathlib |
| 23 | import uuid |
| 24 | |
| 25 | import pytest |
| 26 | from sqlalchemy.ext.asyncio import AsyncSession |
| 27 | |
| 28 | from maestro.muse_cli.commands.commit import _commit_async |
| 29 | from maestro.muse_cli.commands.recall import _recall_async, _score, _tokenize |
| 30 | from maestro.muse_cli.errors import ExitCode |
| 31 | |
| 32 | |
| 33 | # --------------------------------------------------------------------------- |
| 34 | # Repo / workdir helpers (mirror test_log.py pattern) |
| 35 | # --------------------------------------------------------------------------- |
| 36 | |
| 37 | |
def _init_muse_repo(root: pathlib.Path, repo_id: str | None = None) -> str:
    """Create a minimal ``.muse`` layout under *root* and return its repo id.

    The layout consists of ``repo.json`` (repo id plus schema version
    ``"1"``), a ``HEAD`` file containing ``refs/heads/main`` and an empty
    ``refs/heads/main`` ref file.

    Args:
        root: Directory that receives the ``.muse`` folder.
        repo_id: Identifier to record. When omitted (or falsy) a random
            UUID4 string is generated.

    Returns:
        The repo id that was written to ``repo.json``.

    Raises:
        FileExistsError: If ``.muse/refs/heads`` already exists under *root*.
    """
    rid = repo_id or str(uuid.uuid4())
    muse = root / ".muse"
    heads = muse / "refs" / "heads"
    heads.mkdir(parents=True)

    # Pin the encoding so the fixture files never depend on the host locale.
    (muse / "repo.json").write_text(
        json.dumps({"repo_id": rid, "schema_version": "1"}),
        encoding="utf-8",
    )
    (muse / "HEAD").write_text("refs/heads/main", encoding="utf-8")
    (heads / "main").write_text("", encoding="utf-8")
    return rid
| 48 | |
| 49 | |
def _write_workdir(root: pathlib.Path, files: dict[str, bytes]) -> None:
    """Write *files* into ``root / "muse-work"``, creating directories as needed.

    Args:
        root: Repo root; the ``muse-work`` directory is created beneath it
            (including any missing parents) if it does not exist yet.
        files: Mapping of relative file name to raw content. Names may
            contain sub-directories (e.g. ``"stems/kick.mid"``); existing
            files with the same name are overwritten.
    """
    workdir = root / "muse-work"
    workdir.mkdir(parents=True, exist_ok=True)
    for name, content in files.items():
        target = workdir / name
        # Nested names need their parent directory before the write.
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_bytes(content)
| 55 | |
| 56 | |
async def _make_commits(
    root: pathlib.Path,
    session: AsyncSession,
    messages: list[str],
    file_seed: int = 0,
) -> list[str]:
    """Seed one commit per entry in *messages* and return their ids in order.

    Before each commit a file ``track_<n>.mid`` containing ``MIDI-<n>`` is
    written to the work directory, where ``n`` counts up from *file_seed*,
    so every commit sees fresh file content.
    """
    created: list[str] = []
    for index, message in enumerate(messages, start=file_seed):
        payload = f"MIDI-{index}".encode()
        _write_workdir(root, {f"track_{index}.mid": payload})
        commit_id = await _commit_async(message=message, root=root, session=session)
        created.append(commit_id)
    return created
| 70 | |
| 71 | |
| 72 | # --------------------------------------------------------------------------- |
| 73 | # Unit tests for scoring helpers |
| 74 | # --------------------------------------------------------------------------- |
| 75 | |
| 76 | |
class TestScoringHelpers:
    """Exercise ``_tokenize`` and ``_score`` directly; no database involved."""

    def test_tokenize_splits_on_whitespace(self) -> None:
        expected = {"dark", "jazz", "bassline"}
        assert _tokenize("dark jazz bassline") == expected

    def test_tokenize_is_lowercase(self) -> None:
        expected = {"dark", "jazz"}
        assert _tokenize("DARK Jazz") == expected

    def test_tokenize_ignores_punctuation(self) -> None:
        # Commas, exclamation marks and hyphens must all act as separators.
        tokens = _tokenize("boom, bap! drum-fill")
        for word in ("boom", "bap", "drum", "fill"):
            assert word in tokens

    def test_score_full_match_returns_one(self) -> None:
        query_tokens = _tokenize("jazz bassline")
        assert _score(query_tokens, "this is a jazz bassline") == 1.0

    def test_score_no_match_returns_zero(self) -> None:
        query_tokens = _tokenize("jazz bassline")
        assert _score(query_tokens, "rock guitar solo") == 0.0

    def test_score_partial_match(self) -> None:
        # One of the three query tokens ("jazz") appears in the message.
        query_tokens = _tokenize("jazz drum fill")
        result = _score(query_tokens, "a cool jazz moment")
        assert 0.0 < result < 1.0
        assert result == pytest.approx(1 / 3, rel=1e-6)

    def test_score_empty_query_returns_zero(self) -> None:
        assert _score(set(), "anything") == 0.0
| 112 | |
| 113 | |
| 114 | # --------------------------------------------------------------------------- |
| 115 | # test_recall_returns_top_n_by_score |
| 116 | # --------------------------------------------------------------------------- |
| 117 | |
| 118 | |
@pytest.mark.anyio
async def test_recall_returns_top_n_by_score(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """The commit sharing the most query keywords comes first; scores descend."""
    _init_muse_repo(tmp_path)
    seeded_messages = [
        "boom bap drum pattern",
        "jazz piano chord voicing",
        "boom bap jazz fusion groove",
        "classical string quartet",
    ]
    await _make_commits(tmp_path, muse_cli_db_session, seeded_messages)

    hits = await _recall_async(
        root=tmp_path,
        session=muse_cli_db_session,
        query="boom bap jazz",
        limit=5,
        threshold=0.0,
        branch=None,
        since=None,
        until=None,
        as_json=False,
    )

    assert len(hits) > 0
    # Only this message contains all three query words.
    best = hits[0]
    assert best["message"] == "boom bap jazz fusion groove"
    assert best["score"] == pytest.approx(1.0)

    ranked_scores = [hit["score"] for hit in hits]
    assert ranked_scores == sorted(ranked_scores, reverse=True)
| 156 | |
| 157 | |
| 158 | # --------------------------------------------------------------------------- |
| 159 | # test_recall_limit_restricts_result_count |
| 160 | # --------------------------------------------------------------------------- |
| 161 | |
| 162 | |
@pytest.mark.anyio
async def test_recall_limit_restricts_result_count(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """Four messages contain ``jazz``; ``limit=2`` must yield exactly two results."""
    _init_muse_repo(tmp_path)
    jazz_messages = [
        "jazz piano solo",
        "jazz drum groove",
        "jazz bass walk",
        "jazz chord progression",
    ]
    await _make_commits(tmp_path, muse_cli_db_session, jazz_messages)

    max_results = 2
    hits = await _recall_async(
        root=tmp_path,
        session=muse_cli_db_session,
        query="jazz",
        limit=max_results,
        threshold=0.0,
        branch=None,
        since=None,
        until=None,
        as_json=False,
    )

    assert len(hits) == max_results
| 195 | |
| 196 | |
| 197 | # --------------------------------------------------------------------------- |
| 198 | # test_recall_threshold_filters_low_scores |
| 199 | # --------------------------------------------------------------------------- |
| 200 | |
| 201 | |
@pytest.mark.anyio
async def test_recall_threshold_filters_low_scores(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """Results scoring below the requested threshold must be dropped."""
    _init_muse_repo(tmp_path)
    seeded_messages = ["jazz piano solo", "classical strings"]
    await _make_commits(tmp_path, muse_cli_db_session, seeded_messages)

    min_score = 0.6
    hits = await _recall_async(
        root=tmp_path,
        session=muse_cli_db_session,
        query="jazz",
        limit=5,
        threshold=min_score,
        branch=None,
        since=None,
        until=None,
        as_json=False,
    )

    assert all(hit["score"] >= min_score for hit in hits)
    # The jazz commit survives the cut; the unrelated one does not.
    returned_messages = [hit["message"] for hit in hits]
    assert "jazz piano solo" in returned_messages
    assert "classical strings" not in returned_messages
| 235 | |
| 236 | |
| 237 | # --------------------------------------------------------------------------- |
| 238 | # test_recall_no_matches_returns_empty |
| 239 | # --------------------------------------------------------------------------- |
| 240 | |
| 241 | |
@pytest.mark.anyio
async def test_recall_no_matches_returns_empty(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """An unmatched query yields ``[]`` and a friendly notice, not a failure."""
    _init_muse_repo(tmp_path)
    await _make_commits(tmp_path, muse_cli_db_session, ["rock guitar riff"])

    hits = await _recall_async(
        root=tmp_path,
        session=muse_cli_db_session,
        query="jazz bassline",
        limit=5,
        threshold=0.6,
        branch=None,
        since=None,
        until=None,
        as_json=False,
    )

    assert hits == []
    captured = capsys.readouterr()
    assert "No matching commits found" in captured.out
| 267 | |
| 268 | |
| 269 | # --------------------------------------------------------------------------- |
| 270 | # test_recall_json_output_valid_schema |
| 271 | # --------------------------------------------------------------------------- |
| 272 | |
| 273 | |
@pytest.mark.anyio
async def test_recall_json_output_valid_schema(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """With ``as_json=True`` stdout is a JSON list whose entries carry the schema."""
    _init_muse_repo(tmp_path)
    await _make_commits(tmp_path, muse_cli_db_session, ["jazz drum groove"])
    # Drop whatever the commit step printed so only recall output is parsed.
    capsys.readouterr()

    await _recall_async(
        root=tmp_path,
        session=muse_cli_db_session,
        query="jazz drum",
        limit=5,
        threshold=0.0,
        branch=None,
        since=None,
        until=None,
        as_json=True,
    )

    payload = json.loads(capsys.readouterr().out)
    assert isinstance(payload, list)
    assert len(payload) >= 1

    first = payload[0]
    for field in ("rank", "score", "commit_id", "date", "branch", "message"):
        assert field in first
    assert first["rank"] == 1
    assert isinstance(first["score"], float)
    assert first["branch"] == "main"
| 312 | |
| 313 | |
| 314 | # --------------------------------------------------------------------------- |
| 315 | # test_recall_rank_field_is_sequential |
| 316 | # --------------------------------------------------------------------------- |
| 317 | |
| 318 | |
@pytest.mark.anyio
async def test_recall_rank_field_is_sequential(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """The ``rank`` values of the results are 1, 2, 3, ... without gaps."""
    _init_muse_repo(tmp_path)
    seeded_messages = ["jazz piano", "jazz drums", "jazz bass"]
    await _make_commits(tmp_path, muse_cli_db_session, seeded_messages)

    hits = await _recall_async(
        root=tmp_path,
        session=muse_cli_db_session,
        query="jazz",
        limit=5,
        threshold=0.0,
        branch=None,
        since=None,
        until=None,
        as_json=False,
    )

    ranks = [hit["rank"] for hit in hits]
    expected_ranks = list(range(1, len(hits) + 1))
    assert ranks == expected_ranks
| 346 | |
| 347 | |
| 348 | # --------------------------------------------------------------------------- |
| 349 | # test_recall_since_filter_excludes_older_commits |
| 350 | # --------------------------------------------------------------------------- |
| 351 | |
| 352 | |
@pytest.mark.anyio
async def test_recall_since_filter_excludes_older_commits(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """A ``since`` bound far in the future leaves no commit in the result."""
    from datetime import datetime, timezone

    _init_muse_repo(tmp_path)
    await _make_commits(tmp_path, muse_cli_db_session, ["jazz rhythm section"])

    # The freshly seeded commit is necessarily older than this bound.
    far_future = datetime(2099, 1, 1, tzinfo=timezone.utc)
    hits = await _recall_async(
        root=tmp_path,
        session=muse_cli_db_session,
        query="jazz",
        limit=5,
        threshold=0.0,
        branch=None,
        since=far_future,
        until=None,
        as_json=False,
    )

    assert hits == []
| 380 | |
| 381 | |
| 382 | # --------------------------------------------------------------------------- |
| 383 | # test_recall_until_filter_excludes_newer_commits |
| 384 | # --------------------------------------------------------------------------- |
| 385 | |
| 386 | |
@pytest.mark.anyio
async def test_recall_until_filter_excludes_newer_commits(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """An ``until`` bound well in the past leaves no commit in the result."""
    from datetime import datetime, timezone

    _init_muse_repo(tmp_path)
    await _make_commits(tmp_path, muse_cli_db_session, ["jazz rhythm section"])

    # The freshly seeded commit is necessarily newer than this bound.
    long_ago = datetime(2000, 1, 1, tzinfo=timezone.utc)
    hits = await _recall_async(
        root=tmp_path,
        session=muse_cli_db_session,
        query="jazz",
        limit=5,
        threshold=0.0,
        branch=None,
        since=None,
        until=long_ago,
        as_json=False,
    )

    assert hits == []
| 414 | |
| 415 | |
| 416 | # --------------------------------------------------------------------------- |
| 417 | # test_recall_outside_repo_exits_2 |
| 418 | # --------------------------------------------------------------------------- |
| 419 | |
| 420 | |
def test_recall_outside_repo_exits_2(tmp_path: pathlib.Path) -> None:
    """Running ``muse recall`` where no ``.muse/`` exists reports REPO_NOT_FOUND."""
    from typer.testing import CliRunner
    from maestro.muse_cli.app import cli

    original_cwd = os.getcwd()
    cli_runner = CliRunner()
    try:
        # tmp_path is a bare directory: no repo has been initialised in it.
        os.chdir(tmp_path)
        outcome = cli_runner.invoke(cli, ["recall", "jazz"], catch_exceptions=False)
    finally:
        # Always restore the working directory for the remaining tests.
        os.chdir(original_cwd)

    assert outcome.exit_code == ExitCode.REPO_NOT_FOUND
| 435 | |
| 436 | |
| 437 | # --------------------------------------------------------------------------- |
| 438 | # test_recall_bad_since_date_exits_1 |
| 439 | # --------------------------------------------------------------------------- |
| 440 | |
| 441 | |
def test_recall_bad_since_date_exits_1() -> None:
    """A ``--since`` value that is not YYYY-MM-DD makes the command exit with code 1.

    Date validation occurs before repo discovery so no repo setup is needed.
    """
    import typer
    from maestro.muse_cli.commands.recall import recall as recall_cmd

    malformed_since = "not-a-date"
    with pytest.raises(typer.Exit) as raised:
        recall_cmd(
            ctx=None,  # type: ignore[arg-type]
            query="jazz",
            limit=5,
            threshold=0.6,
            branch=None,
            since=malformed_since,
            until=None,
            as_json=False,
        )

    assert raised.value.exit_code == ExitCode.USER_ERROR
| 463 | |
| 464 | |
| 465 | |
| 466 | # --------------------------------------------------------------------------- |
| 467 | # test_recall_bad_until_date_exits_1 |
| 468 | # --------------------------------------------------------------------------- |
| 469 | |
| 470 | |
def test_recall_bad_until_date_exits_1() -> None:
    """An ``--until`` value that is not YYYY-MM-DD makes the command exit with code 1.

    Date validation occurs before repo discovery so no repo setup is needed.
    """
    import typer
    from maestro.muse_cli.commands.recall import recall as recall_cmd

    # Slashes instead of dashes: a plausible date in the wrong format.
    malformed_until = "2026/01/01"
    with pytest.raises(typer.Exit) as raised:
        recall_cmd(
            ctx=None,  # type: ignore[arg-type]
            query="jazz",
            limit=5,
            threshold=0.6,
            branch=None,
            since=None,
            until=malformed_until,
            as_json=False,
        )

    assert raised.value.exit_code == ExitCode.USER_ERROR