cgcardona / muse public
test_ask.py python
451 lines 13.1 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """Tests for ``muse ask``.
2
3 All async tests call ``_ask_async`` directly with an in-memory SQLite
4 session and a ``tmp_path`` repo root — no real Postgres or running
5 process required. Commits are seeded via ``_commit_async`` so the
6 commands are tested as an integrated pair.
7 """
8 from __future__ import annotations
9
10 import json
11 import os
12 import pathlib
13 import uuid
14 from datetime import date, datetime, timezone
15
16 import pytest
17 from sqlalchemy.ext.asyncio import AsyncSession
18
19 from maestro.muse_cli.commands.ask import AnswerResult, _ask_async, _keywords
20 from maestro.muse_cli.commands.commit import _commit_async
21 from maestro.muse_cli.errors import ExitCode
22
23
24 # ---------------------------------------------------------------------------
25 # Helpers
26 # ---------------------------------------------------------------------------
27
28
29 def _init_muse_repo(root: pathlib.Path, repo_id: str | None = None) -> str:
30 rid = repo_id or str(uuid.uuid4())
31 muse = root / ".muse"
32 (muse / "refs" / "heads").mkdir(parents=True)
33 (muse / "repo.json").write_text(
34 json.dumps({"repo_id": rid, "schema_version": "1"})
35 )
36 (muse / "HEAD").write_text("refs/heads/main")
37 (muse / "refs" / "heads" / "main").write_text("")
38 return rid
39
40
def _write_workdir(root: pathlib.Path, files: dict[str, bytes]) -> None:
    """Write *files* into ``<root>/muse-work/``, creating the directory if needed.

    Args:
        root: Repo root; must already exist.
        files: Mapping of path (relative to ``muse-work/``) to raw file content.
            Names may contain sub-directories (e.g. ``"stems/kick.mid"``);
            missing parent directories are created. Existing files are
            overwritten.
    """
    workdir = root / "muse-work"
    workdir.mkdir(exist_ok=True)
    for name, content in files.items():
        target = workdir / name
        # For flat names the parent is ``workdir`` itself, so this is a no-op;
        # for nested names it creates the intermediate directories.
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_bytes(content)
46
47
async def _make_commits(
    root: pathlib.Path,
    session: AsyncSession,
    messages: list[str],
    file_seed: int = 0,
) -> list[str]:
    """Create one commit per message, writing a fresh file before each commit.

    Args:
        root: Repo root passed to ``_write_workdir`` and ``_commit_async``.
        session: Database session handed to ``_commit_async``.
        messages: Commit messages, one commit is made for each, in order.
        file_seed: Offset added to the file index so separate calls can use
            distinct file names and contents.

    Returns:
        The values returned by ``_commit_async``, in commit order.
    """
    created: list[str] = []
    for index, message in enumerate(messages, start=file_seed):
        # A new file name and payload per commit keeps each workdir state unique.
        payload = f"MIDI-{index}".encode()
        _write_workdir(root, {f"track_{index}.mid": payload})
        created.append(
            await _commit_async(message=message, root=root, session=session)
        )
    return created
61
62
63 # ---------------------------------------------------------------------------
64 # _keywords unit tests
65 # ---------------------------------------------------------------------------
66
67
def test_keywords_extracts_meaningful_tokens() -> None:
    """Content words survive keyword extraction while stop-words are dropped."""
    extracted = _keywords("what tempo changes did I make last week?")

    # Words that carry meaning must be kept.
    for kept in ("tempo", "changes"):
        assert kept in extracted

    # Stop-words must be filtered out.
    for dropped in ("what", "did", "last"):
        assert dropped not in extracted
77
78
def test_keywords_empty_string_returns_empty() -> None:
    """An empty question produces an empty keyword list."""
    extracted = _keywords("")
    assert extracted == []
82
83
def test_keywords_all_stopwords_returns_empty() -> None:
    """A question consisting solely of stop-words produces no keywords."""
    assert _keywords("what is the") == []
88
89
def test_keywords_deduplicates_not_applied_but_all_present() -> None:
    """Every meaningful token is present even when the input repeats a word.

    Only membership is asserted; ordering and the handling of the repeated
    word are not checked here.
    """
    extracted = _keywords("boom bap groove boom")
    for expected in ("boom", "bap", "groove"):
        assert expected in extracted
96
97
98 # ---------------------------------------------------------------------------
99 # _ask_async — basic matching
100 # ---------------------------------------------------------------------------
101
102
@pytest.mark.anyio
async def test_ask_matches_keyword_in_commit_message(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """Only commits whose message contains the query are returned by ``muse ask``."""
    _init_muse_repo(tmp_path)
    seeded_messages = ["boom bap take 1", "jazz piano intro", "boom bap take 2"]
    await _make_commits(tmp_path, muse_cli_db_session, seeded_messages)

    answer = await _ask_async(
        question="boom bap",
        root=tmp_path,
        session=muse_cli_db_session,
        branch=None,
        since=None,
        until=None,
        cite=False,
    )

    # All three commits are scanned, but only the two "boom bap" takes match.
    assert answer.total_searched == 3
    assert len(answer.matches) == 2
    matched_messages = [hit.message for hit in answer.matches]
    assert all("boom bap" in text for text in matched_messages)
131
132
@pytest.mark.anyio
async def test_ask_no_matches_returns_empty_list(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
) -> None:
    """When no commit message matches the query, ``matches`` is empty."""
    _init_muse_repo(tmp_path)
    await _make_commits(tmp_path, muse_cli_db_session, ["ambient drone take 1"])

    answer = await _ask_async(
        question="hip hop",
        root=tmp_path,
        session=muse_cli_db_session,
        branch=None,
        since=None,
        until=None,
        cite=False,
    )

    # The lone commit is still counted as searched even though it did not match.
    assert answer.total_searched == 1
    assert answer.matches == []
154
155
@pytest.mark.anyio
async def test_ask_empty_repo_returns_zero_searched(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
) -> None:
    """A freshly initialised repo with no commits reports zero commits searched."""
    _init_muse_repo(tmp_path)

    answer = await _ask_async(
        question="anything",
        root=tmp_path,
        session=muse_cli_db_session,
        branch=None,
        since=None,
        until=None,
        cite=False,
    )

    assert answer.total_searched == 0
    assert answer.matches == []
176
177
178 # ---------------------------------------------------------------------------
179 # _ask_async — --branch filter
180 # ---------------------------------------------------------------------------
181
182
@pytest.mark.anyio
async def test_ask_branch_filter_restricts_search(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
) -> None:
    """``--branch`` limits the search to commits on the named branch."""
    _init_muse_repo(tmp_path)
    # HEAD points at main, so this commit is made on main.
    await _make_commits(tmp_path, muse_cli_db_session, ["groove session main"])

    answer = await _ask_async(
        question="groove",
        root=tmp_path,
        session=muse_cli_db_session,
        branch="other-branch",
        since=None,
        until=None,
        cite=False,
    )

    # other-branch has no commits, so nothing is searched or matched.
    assert answer.total_searched == 0
    assert answer.matches == []
206
207
@pytest.mark.anyio
async def test_ask_branch_filter_returns_matching_branch(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
) -> None:
    """Filtering by the branch that holds the commit still finds it."""
    _init_muse_repo(tmp_path)
    await _make_commits(tmp_path, muse_cli_db_session, ["groove session on main"])

    answer = await _ask_async(
        question="groove",
        root=tmp_path,
        session=muse_cli_db_session,
        branch="main",
        since=None,
        until=None,
        cite=False,
    )

    assert answer.total_searched == 1
    assert len(answer.matches) == 1
229
230
231 # ---------------------------------------------------------------------------
232 # _ask_async — --since / --until filters
233 # ---------------------------------------------------------------------------
234
235
@pytest.mark.anyio
async def test_ask_since_filter_excludes_older_commits(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
) -> None:
    """``--since`` excludes commits made before the given date.

    One commit is created during the test, then the search runs with a
    ``since`` date far in the future, so that commit falls before the cutoff
    and nothing is searched.
    """
    _init_muse_repo(tmp_path)
    # Seed one commit using the high-level helper (today's date).
    await _make_commits(tmp_path, muse_cli_db_session, ["new session today"])

    # A cutoff far in the future places every existing commit before it.
    future_date = date(2099, 1, 1)
    result = await _ask_async(
        question="session",
        root=tmp_path,
        session=muse_cli_db_session,
        branch=None,
        since=future_date,
        until=None,
        cite=False,
    )

    assert result.total_searched == 0
261
262
@pytest.mark.anyio
async def test_ask_until_filter_excludes_newer_commits(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
) -> None:
    """``--until`` leaves out commits made after the given date."""
    _init_muse_repo(tmp_path)
    await _make_commits(tmp_path, muse_cli_db_session, ["new session today"])

    # A cutoff in the past means the commit just created lies after it.
    cutoff = date(2000, 1, 1)
    answer = await _ask_async(
        question="session",
        root=tmp_path,
        session=muse_cli_db_session,
        branch=None,
        since=None,
        until=cutoff,
        cite=False,
    )

    assert answer.total_searched == 0
285
286
287 # ---------------------------------------------------------------------------
288 # AnswerResult rendering
289 # ---------------------------------------------------------------------------
290
291
@pytest.mark.anyio
async def test_ask_plain_output_contains_expected_text(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """The plain-text rendering includes the header, the note and the matched message."""
    _init_muse_repo(tmp_path)
    await _make_commits(tmp_path, muse_cli_db_session, ["piano loop session"])

    answer = await _ask_async(
        question="piano loop",
        root=tmp_path,
        session=muse_cli_db_session,
        branch=None,
        since=None,
        until=None,
        cite=False,
    )

    rendered = answer.to_plain()
    expected_fragments = (
        "Based on Muse history",
        "commits searched",
        "Note: Full LLM-powered answer generation",
        "piano loop session",
    )
    for fragment in expected_fragments:
        assert fragment in rendered
317
318
@pytest.mark.anyio
async def test_ask_cite_flag_shows_full_commit_id(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
) -> None:
    """With ``--cite`` the plain-text answer contains the complete commit ID."""
    _init_muse_repo(tmp_path)
    commit_ids = await _make_commits(
        tmp_path, muse_cli_db_session, ["synth pad session"]
    )

    answer = await _ask_async(
        question="synth pad",
        root=tmp_path,
        session=muse_cli_db_session,
        branch=None,
        since=None,
        until=None,
        cite=True,
    )

    # The untruncated ID returned by the commit must appear verbatim.
    full_id = commit_ids[0]
    assert full_id in answer.to_plain()
340
341
@pytest.mark.anyio
async def test_ask_no_cite_shows_short_commit_id(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
) -> None:
    """Without ``--cite`` only the 8-character commit ID prefix is shown."""
    _init_muse_repo(tmp_path)
    commit_ids = await _make_commits(
        tmp_path, muse_cli_db_session, ["drum pattern session"]
    )

    answer = await _ask_async(
        question="drum pattern",
        root=tmp_path,
        session=muse_cli_db_session,
        branch=None,
        since=None,
        until=None,
        cite=False,
    )

    rendered = answer.to_plain()
    short_id, remainder = commit_ids[0][:8], commit_ids[0][8:]
    assert short_id in rendered
    # The rest of the ID must be absent, so the full ID cannot appear either.
    assert remainder not in rendered
365
366
@pytest.mark.anyio
async def test_ask_json_output_is_valid_json(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
) -> None:
    """The JSON rendering parses cleanly and carries the expected keys and values."""
    _init_muse_repo(tmp_path)
    await _make_commits(tmp_path, muse_cli_db_session, ["bass groove take 3"])

    answer = await _ask_async(
        question="bass groove",
        root=tmp_path,
        session=muse_cli_db_session,
        branch=None,
        since=None,
        until=None,
        cite=False,
    )

    # json.loads raises if the output is not valid JSON.
    decoded = json.loads(answer.to_json())

    for key in ("question", "total_searched", "matches", "note"):
        assert key in decoded

    assert decoded["question"] == "bass groove"
    assert decoded["total_searched"] == 1
    assert len(decoded["matches"]) == 1
    assert decoded["matches"][0]["message"] == "bass groove take 3"
395
396
@pytest.mark.anyio
async def test_ask_json_cite_flag_includes_full_id(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
) -> None:
    """With ``--json --cite`` each match's ``commit_id`` is the full commit ID."""
    _init_muse_repo(tmp_path)
    commit_ids = await _make_commits(tmp_path, muse_cli_db_session, ["keys session"])

    answer = await _ask_async(
        question="keys",
        root=tmp_path,
        session=muse_cli_db_session,
        branch=None,
        since=None,
        until=None,
        cite=True,
    )

    decoded = json.loads(answer.to_json())
    first_match = decoded["matches"][0]
    assert first_match["commit_id"] == commit_ids[0]
418
419
420 # ---------------------------------------------------------------------------
421 # CLI interface (CliRunner)
422 # ---------------------------------------------------------------------------
423
424
def test_ask_outside_repo_exits_2(tmp_path: pathlib.Path) -> None:
    """``muse ask`` run where no ``.muse/`` exists exits with ``ExitCode.REPO_NOT_FOUND``."""
    from typer.testing import CliRunner
    from maestro.muse_cli.app import cli

    cli_runner = CliRunner()
    original_cwd = os.getcwd()
    try:
        # tmp_path was never initialised as a Muse repo.
        os.chdir(tmp_path)
        outcome = cli_runner.invoke(cli, ["ask", "anything"], catch_exceptions=False)
    finally:
        # Always restore the working directory so later tests are unaffected.
        os.chdir(original_cwd)

    assert outcome.exit_code == ExitCode.REPO_NOT_FOUND
439
440
def test_ask_plain_text_no_matches_message(tmp_path: pathlib.Path) -> None:
    """With zero matches the plain rendering shows the placeholder and the searched count."""
    empty_answer = AnswerResult(
        question="anything",
        total_searched=5,
        matches=[],
        cite=False,
    )
    rendered = empty_answer.to_plain()
    for fragment in ("(no matching commits)", "5 commits searched"):
        assert fragment in rendered