cgcardona / muse public
test_muse_groove_check.py python
559 lines 18.3 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """Tests for ``muse groove-check`` — CLI interface, flag parsing, and stub output format.
2
3 All CLI-level tests use ``typer.testing.CliRunner`` against the full ``muse``
4 app so that argument parsing, flag handling, and exit codes are exercised end-to-end.
5
6 Async core tests call ``_groove_check_async`` directly with an in-memory SQLite
7 session (the stub does not query the DB; the session satisfies the signature
8 contract only).
9 """
10 from __future__ import annotations
11
12 import json
13 import os
14 import pathlib
15 import uuid
16 from collections.abc import AsyncGenerator
17
18 import pytest
19 import pytest_asyncio
20 from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
21 from sqlalchemy.pool import StaticPool
22 from typer.testing import CliRunner
23
24 from maestro.db.database import Base
25 import maestro.muse_cli.models # noqa: F401 — registers MuseCli* with Base.metadata
26 from maestro.muse_cli.app import cli
27 from maestro.muse_cli.commands.groove_check import (
28 _groove_check_async,
29 _render_json,
30 _render_table,
31 )
32 from maestro.muse_cli.errors import ExitCode
33 from maestro.services.muse_groove_check import (
34 DEFAULT_THRESHOLD,
35 CommitGrooveMetrics,
36 GrooveCheckResult,
37 GrooveStatus,
38 build_stub_entries,
39 classify_status,
40 compute_groove_check,
41 )
42
# One shared Typer CliRunner; every CLI-level test below invokes the full ``muse`` app through it.
runner: CliRunner = CliRunner()
44
45
46 # ---------------------------------------------------------------------------
47 # Fixtures
48 # ---------------------------------------------------------------------------
49
50
def _init_muse_repo(root: pathlib.Path, branch: str = "main") -> str:
    """Lay down a bare-bones ``.muse/`` directory under *root* and return its repo ID.

    ``HEAD`` points at ``refs/heads/<branch>``, whose ref file is created empty,
    so the branch exists but has no commit yet.
    """
    repo_id = str(uuid.uuid4())
    dot_muse = root / ".muse"
    heads_dir = dot_muse / "refs" / "heads"
    heads_dir.mkdir(parents=True)
    metadata = {"repo_id": repo_id, "schema_version": "1"}
    (dot_muse / "repo.json").write_text(json.dumps(metadata))
    (dot_muse / "HEAD").write_text(f"refs/heads/{branch}")
    (heads_dir / branch).write_text("")
    return repo_id
60
61
def _commit_ref(root: pathlib.Path, branch: str = "main") -> None:
    """Point *branch* at a fixed, fake 40-character commit ID so HEAD resolves to a commit."""
    ref_path = root / ".muse" / "refs" / "heads" / branch
    ref_path.write_text("a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2")
66
67
@pytest_asyncio.fixture
async def db_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield a session bound to a fresh in-memory SQLite database.

    The stub groove-check never queries the database; the session only
    satisfies ``_groove_check_async``'s signature. ``StaticPool`` reuses a
    single connection, so the ``:memory:`` schema created here is the one the
    session sees.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        factory = async_sessionmaker(bind=engine, expire_on_commit=False)
        async with factory() as session:
            yield session
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.drop_all)
    finally:
        # Release the pooled connection even if schema setup or teardown
        # raises; previously a failure in create_all/drop_all skipped this.
        await engine.dispose()
84
85
86 # ---------------------------------------------------------------------------
87 # Unit — classify_status
88 # ---------------------------------------------------------------------------
89
90
def test_classify_status_ok_below_threshold() -> None:
    """A drift delta under the threshold classifies as OK."""
    status = classify_status(0.05, 0.1)
    assert status == GrooveStatus.OK
94
95
def test_classify_status_ok_at_threshold() -> None:
    """The threshold is inclusive: a delta equal to it is still OK."""
    delta = threshold = 0.1
    assert classify_status(delta, threshold) == GrooveStatus.OK
99
100
def test_classify_status_warn_between_threshold_and_double() -> None:
    """A delta above the threshold but within twice the threshold is a WARN."""
    status = classify_status(0.15, 0.1)
    assert status == GrooveStatus.WARN
104
105
def test_classify_status_fail_above_double_threshold() -> None:
    """A delta beyond twice the threshold is a FAIL."""
    status = classify_status(0.25, 0.1)
    assert status == GrooveStatus.FAIL
109
110
def test_classify_status_zero_delta_ok() -> None:
    """A zero drift delta (the first commit's value) is OK for any positive threshold.

    Several thresholds are checked so the "regardless of threshold" claim is
    actually exercised rather than inferred from a single value.
    """
    for threshold in (0.01, 0.1, DEFAULT_THRESHOLD, 1.0):
        assert classify_status(0.0, threshold) == GrooveStatus.OK, f"threshold={threshold}"
114
115
116 # ---------------------------------------------------------------------------
117 # Unit — build_stub_entries
118 # ---------------------------------------------------------------------------
119
120
def test_build_stub_entries_returns_correct_limit() -> None:
    """The number of stub entries equals the requested ``limit``."""
    requested = 3
    entries = build_stub_entries(threshold=0.1, track=None, section=None, limit=requested)
    assert len(entries) == requested
125
126
def test_build_stub_entries_first_entry_has_zero_delta() -> None:
    """The oldest entry in the window has nothing to drift from, so its delta is 0.0."""
    oldest, *_ = build_stub_entries(threshold=0.1, track=None, section=None, limit=5)
    assert oldest.drift_delta == 0.0
131
132
def test_build_stub_entries_track_stored_in_metadata() -> None:
    """Every stub entry records the ``track`` filter it was built with."""
    entries = build_stub_entries(threshold=0.1, track="drums", section=None, limit=3)
    # Guard against a vacuous pass: an empty list would satisfy the loop below.
    assert entries, "expected at least one stub entry"
    for entry in entries:
        assert entry.track == "drums", f"{entry.commit}: unexpected track {entry.track!r}"
138
139
def test_build_stub_entries_section_stored_in_metadata() -> None:
    """Every stub entry records the ``section`` filter it was built with."""
    entries = build_stub_entries(threshold=0.1, track=None, section="verse", limit=3)
    # Guard against a vacuous pass: an empty list would satisfy the loop below.
    assert entries, "expected at least one stub entry"
    for entry in entries:
        assert entry.section == "verse", f"{entry.commit}: unexpected section {entry.section!r}"
145
146
def test_build_stub_entries_status_matches_classification() -> None:
    """Each entry's status is consistent with classify_status(drift_delta, threshold)."""
    threshold = 0.05
    entries = build_stub_entries(threshold=threshold, track=None, section=None, limit=7)
    # Guard against a vacuous pass: an empty list would satisfy the loop below.
    assert entries, "expected at least one stub entry"
    for e in entries:
        expected = classify_status(e.drift_delta, threshold)
        assert e.status == expected, f"{e.commit}: status mismatch"
154
155
def test_build_stub_entries_groove_scores_positive() -> None:
    """All groove_score values are non-negative."""
    entries = build_stub_entries(threshold=0.1, track=None, section=None, limit=7)
    # Guard against a vacuous pass: an empty list would satisfy the loop below.
    assert entries, "expected at least one stub entry"
    for e in entries:
        assert e.groove_score >= 0.0, f"{e.commit}: negative groove_score"
161
162
163 # ---------------------------------------------------------------------------
164 # Unit — compute_groove_check
165 # ---------------------------------------------------------------------------
166
167
def test_compute_groove_check_returns_result() -> None:
    """compute_groove_check produces a GrooveCheckResult with a non-empty entry list."""
    outcome = compute_groove_check(commit_range="HEAD~5..HEAD")
    assert isinstance(outcome, GrooveCheckResult)
    entry_count = len(outcome.entries)
    assert entry_count > 0
173
174
def test_compute_groove_check_stores_range() -> None:
    """The requested commit range is echoed back on the result."""
    requested_range = "abc123..def456"
    assert compute_groove_check(commit_range=requested_range).commit_range == requested_range
179
180
def test_compute_groove_check_stores_threshold() -> None:
    """The requested threshold is echoed back on the result."""
    outcome = compute_groove_check(threshold=0.05, commit_range="HEAD~5..HEAD")
    assert outcome.threshold == 0.05
185
186
def test_compute_groove_check_flagged_count_consistent() -> None:
    """flagged_commits counts exactly the entries whose status is not OK."""
    result = compute_groove_check(commit_range="HEAD~10..HEAD", threshold=0.01)
    non_ok = [entry for entry in result.entries if entry.status != GrooveStatus.OK]
    assert result.flagged_commits == len(non_ok)
192
193
def test_compute_groove_check_worst_commit_has_max_delta() -> None:
    """worst_commit names an entry whose drift_delta is the window's maximum.

    Deltas are compared rather than commit IDs: ``max()`` returns the first of
    several tied entries, and an implementation that picks a different tied
    entry is still correct.
    """
    result = compute_groove_check(commit_range="HEAD~10..HEAD")
    if result.worst_commit:
        max_delta = max(e.drift_delta for e in result.entries)
        matches = [e for e in result.entries if e.commit == result.worst_commit]
        assert matches, f"worst_commit {result.worst_commit!r} is not among the entries"
        assert matches[0].drift_delta == max_delta
200
201
def test_compute_groove_check_tight_threshold_flags_more() -> None:
    """A tighter threshold flags at least as many commits as a looser one.

    The assertion is ``>=``, not ``>``: both thresholds may flag the same
    number of commits over this window.
    """
    loose = compute_groove_check(commit_range="HEAD~10..HEAD", threshold=0.5)
    tight = compute_groove_check(commit_range="HEAD~10..HEAD", threshold=0.01)
    assert tight.flagged_commits >= loose.flagged_commits
207
208
209 # ---------------------------------------------------------------------------
210 # Unit — renderers
211 # ---------------------------------------------------------------------------
212
213
def test_render_table_outputs_header(capsys: pytest.CaptureFixture[str]) -> None:
    """The table header shows the range, the threshold and every column title."""
    result = compute_groove_check(commit_range="HEAD~5..HEAD", threshold=0.1)
    _render_table(result)
    rendered = capsys.readouterr().out
    expected_fragments = (
        "Groove-check",
        "HEAD~5..HEAD",
        "0.1 beats",
        "Commit",
        "Groove Score",
        "Drift",
        "Status",
    )
    for fragment in expected_fragments:
        assert fragment in rendered, f"missing {fragment!r} in table output"
226
227
def test_render_table_shows_all_commits(capsys: pytest.CaptureFixture[str]) -> None:
    """_render_table emits one row per entry, so every commit ID appears in the output."""
    result = compute_groove_check(commit_range="HEAD~5..HEAD")
    # Guard against a vacuous pass: with no entries the loop below checks nothing.
    assert result.entries, "expected at least one entry to render"
    _render_table(result)
    out = capsys.readouterr().out
    for entry in result.entries:
        assert entry.commit in out, f"commit {entry.commit!r} missing from table"
235
236
def test_render_table_shows_flagged_summary(capsys: pytest.CaptureFixture[str]) -> None:
    """The rendered table ends with a 'Flagged:' summary line."""
    _render_table(compute_groove_check(commit_range="HEAD~5..HEAD"))
    assert "Flagged:" in capsys.readouterr().out
243
244
def test_render_json_is_valid(capsys: pytest.CaptureFixture[str]) -> None:
    """_render_json prints a parseable JSON document with the summary keys and an entries list."""
    result = compute_groove_check(commit_range="HEAD~5..HEAD", threshold=0.1)
    _render_json(result)
    document = json.loads(capsys.readouterr().out)

    assert document["commit_range"] == "HEAD~5..HEAD"
    assert document["threshold"] == 0.1
    for key in ("total_commits", "flagged_commits", "worst_commit"):
        assert key in document
    assert isinstance(document["entries"], list)
257
258
def test_render_json_entries_have_required_fields(capsys: pytest.CaptureFixture[str]) -> None:
    """Each JSON entry contains all required per-commit fields."""
    result = compute_groove_check(commit_range="HEAD~5..HEAD")
    _render_json(result)
    payload = json.loads(capsys.readouterr().out)
    required = {"commit", "groove_score", "drift_delta", "status", "track", "section", "midi_files"}
    # Guard against a vacuous pass: an empty entries list skips the loop entirely.
    assert payload["entries"], "expected at least one entry"
    for entry in payload["entries"]:
        missing = required - set(entry)
        assert not missing, f"Missing fields {sorted(missing)} in entry: {entry}"
268
269
def test_render_json_status_values_valid(capsys: pytest.CaptureFixture[str]) -> None:
    """All JSON status values are valid GrooveStatus members."""
    valid = {s.value for s in GrooveStatus}
    result = compute_groove_check(commit_range="HEAD~5..HEAD", threshold=0.01)
    _render_json(result)
    payload = json.loads(capsys.readouterr().out)
    # Guard against a vacuous pass: an empty entries list skips the loop entirely.
    assert payload["entries"], "expected at least one entry"
    for entry in payload["entries"]:
        assert entry["status"] in valid, f"Unknown status: {entry['status']}"
279
280
281 # ---------------------------------------------------------------------------
282 # Async core — _groove_check_async
283 # ---------------------------------------------------------------------------
284
285
@pytest.mark.anyio
async def test_groove_check_async_default_output(
    tmp_path: pathlib.Path,
    db_session: AsyncSession,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """Without filters, the async core renders the text table plus its summary line."""
    _init_muse_repo(tmp_path)
    _commit_ref(tmp_path)

    await _groove_check_async(
        root=tmp_path,
        session=db_session,
        commit_range=None,
        track=None,
        section=None,
        threshold=DEFAULT_THRESHOLD,
        as_json=False,
    )

    captured = capsys.readouterr().out
    for marker in ("Groove-check", "Flagged:"):
        assert marker in captured
309
310
@pytest.mark.anyio
async def test_groove_check_async_json_mode(
    tmp_path: pathlib.Path,
    db_session: AsyncSession,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """With ``as_json=True`` the async core prints JSON holding a non-empty entries list."""
    _init_muse_repo(tmp_path)
    _commit_ref(tmp_path)

    await _groove_check_async(
        root=tmp_path,
        session=db_session,
        commit_range=None,
        track=None,
        section=None,
        threshold=DEFAULT_THRESHOLD,
        as_json=True,
    )

    document = json.loads(capsys.readouterr().out)
    assert "entries" in document
    assert len(document["entries"]) > 0
335
336
@pytest.mark.anyio
async def test_groove_check_async_explicit_range(
    tmp_path: pathlib.Path,
    db_session: AsyncSession,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """A caller-supplied commit range is printed in the table header."""
    explicit_range = "HEAD~3..HEAD"
    _init_muse_repo(tmp_path)
    _commit_ref(tmp_path)

    await _groove_check_async(
        root=tmp_path,
        session=db_session,
        commit_range=explicit_range,
        track=None,
        section=None,
        threshold=DEFAULT_THRESHOLD,
        as_json=False,
    )

    assert explicit_range in capsys.readouterr().out
359
360
@pytest.mark.anyio
async def test_groove_check_async_track_filter_stored(
    tmp_path: pathlib.Path,
    db_session: AsyncSession,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """--track is propagated to every result entry."""
    _init_muse_repo(tmp_path)
    _commit_ref(tmp_path)

    await _groove_check_async(
        root=tmp_path,
        session=db_session,
        commit_range=None,
        track="drums",
        section=None,
        threshold=DEFAULT_THRESHOLD,
        as_json=True,
    )

    payload = json.loads(capsys.readouterr().out)
    # Guard against a vacuous pass: an empty entries list skips the loop entirely.
    assert payload["entries"], "expected at least one entry"
    for entry in payload["entries"]:
        assert entry["track"] == "drums"
385
386
@pytest.mark.anyio
async def test_groove_check_async_section_filter_stored(
    tmp_path: pathlib.Path,
    db_session: AsyncSession,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """--section is propagated to every result entry."""
    _init_muse_repo(tmp_path)
    _commit_ref(tmp_path)

    await _groove_check_async(
        root=tmp_path,
        session=db_session,
        commit_range=None,
        track=None,
        section="verse",
        threshold=DEFAULT_THRESHOLD,
        as_json=True,
    )

    payload = json.loads(capsys.readouterr().out)
    # Guard against a vacuous pass: an empty entries list skips the loop entirely.
    assert payload["entries"], "expected at least one entry"
    for entry in payload["entries"]:
        assert entry["section"] == "verse"
411
412
@pytest.mark.anyio
async def test_groove_check_async_custom_threshold(
    tmp_path: pathlib.Path,
    db_session: AsyncSession,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """A non-default threshold is echoed in the JSON output."""
    custom_threshold = 0.05
    _init_muse_repo(tmp_path)
    _commit_ref(tmp_path)

    await _groove_check_async(
        root=tmp_path,
        session=db_session,
        commit_range=None,
        track=None,
        section=None,
        threshold=custom_threshold,
        as_json=True,
    )

    document = json.loads(capsys.readouterr().out)
    assert document["threshold"] == custom_threshold
436
437
@pytest.mark.anyio
async def test_groove_check_async_invalid_threshold_exits(
    tmp_path: pathlib.Path,
    db_session: AsyncSession,
) -> None:
    """threshold ≤ 0 (zero or negative) exits with USER_ERROR."""
    import typer

    _init_muse_repo(tmp_path)
    _commit_ref(tmp_path)

    # Cover both the boundary value and a negative one, matching the "≤ 0"
    # contract in the docstring (previously only 0.0 was exercised).
    for bad_threshold in (0.0, -0.1):
        with pytest.raises(typer.Exit) as exc_info:
            await _groove_check_async(
                root=tmp_path,
                session=db_session,
                commit_range=None,
                track=None,
                section=None,
                threshold=bad_threshold,
                as_json=False,
            )
        assert exc_info.value.exit_code == int(ExitCode.USER_ERROR), f"threshold={bad_threshold}"
460
461
@pytest.mark.anyio
async def test_groove_check_async_no_commits_exits_success(
    tmp_path: pathlib.Path,
    db_session: AsyncSession,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """An empty branch with no explicit range exits 0 and says there are no commits yet."""
    import typer

    # The branch ref is deliberately left empty: _commit_ref is not called.
    _init_muse_repo(tmp_path)

    with pytest.raises(typer.Exit) as exc_info:
        await _groove_check_async(
            root=tmp_path,
            session=db_session,
            commit_range=None,
            track=None,
            section=None,
            threshold=DEFAULT_THRESHOLD,
            as_json=False,
        )

    exit_code = exc_info.value.exit_code
    assert exit_code == int(ExitCode.SUCCESS)
    assert "No commits yet" in capsys.readouterr().out
487
488
489 # ---------------------------------------------------------------------------
490 # Regression — test_groove_check_outputs_table_with_drift_status
491 # ---------------------------------------------------------------------------
492
493
@pytest.mark.anyio
async def test_groove_check_outputs_table_with_drift_status(
    tmp_path: pathlib.Path,
    db_session: AsyncSession,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """Regression: groove-check always outputs a table with commit/drift/status columns.

    This is the primary acceptance-criteria test. The table must show commit
    refs, groove score, drift delta, and a status column containing
    OK/WARN/FAIL values.
    """
    _init_muse_repo(tmp_path)
    _commit_ref(tmp_path)

    await _groove_check_async(
        root=tmp_path,
        session=db_session,
        commit_range="HEAD~6..HEAD",
        track=None,
        section=None,
        threshold=DEFAULT_THRESHOLD,
        as_json=False,
    )

    table = capsys.readouterr().out
    for heading in ("Groove-check", "Commit", "Groove Score", "Drift", "Status"):
        assert heading in table, f"missing {heading!r} in table output"
    # At least one OK/WARN/FAIL label must be rendered.
    assert any(label in table for label in ("OK", "WARN", "FAIL"))
527
528
529 # ---------------------------------------------------------------------------
530 # CLI integration — CliRunner
531 # ---------------------------------------------------------------------------
532
533
def test_cli_groove_check_outside_repo_exits_2(tmp_path: pathlib.Path) -> None:
    """Running ``muse groove-check`` outside any Muse repository exits with REPO_NOT_FOUND."""
    original_cwd = os.getcwd()
    try:
        # tmp_path has no .muse/ directory, so repository discovery must fail.
        os.chdir(tmp_path)
        result = runner.invoke(cli, ["groove-check"], catch_exceptions=False)
    finally:
        os.chdir(original_cwd)

    assert result.exit_code == int(ExitCode.REPO_NOT_FOUND)
    output = result.output.lower()
    assert "not a muse repository" in output
545
546
def test_cli_groove_check_help_lists_flags(tmp_path: pathlib.Path) -> None:
    """``muse groove-check --help`` advertises every documented option."""
    # tmp_path is requested by the signature but not used by this test.
    help_result = runner.invoke(cli, ["groove-check", "--help"])
    assert help_result.exit_code == 0
    documented = ("--track", "--section", "--threshold", "--json")
    missing = [flag for flag in documented if flag not in help_result.output]
    assert not missing, f"Flags not found in help output: {missing}"
553
554
def test_cli_groove_check_appears_in_muse_help() -> None:
    """The top-level ``muse --help`` output lists the groove-check subcommand."""
    help_result = runner.invoke(cli, ["--help"])
    assert help_result.exit_code == 0
    assert "groove-check" in help_result.output