cgcardona / muse public
test_swing.py python
557 lines 16.3 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """Tests for ``muse swing`` — swing factor analysis and annotation.
2
3 Covers:
4 - swing_label thresholds (Straight / Light / Medium / Hard)
5 - _swing_detect_async returns correct schema
6 - _swing_history_async returns a list with correct entries
7 - _swing_compare_async returns head/compare/delta structure
8 - Output formatters for text and JSON modes
9 - CLI flag parsing via CliRunner (--set, --detect, --track, --compare, --history, --json)
10 - --set out-of-range exits 1
11 - Outside-repo invocation exits 2
12 """
13 from __future__ import annotations
14
15 import json
16 import os
17 import pathlib
18 import uuid
19
20 import pytest
21 from sqlalchemy.ext.asyncio import AsyncSession
22 from typer.testing import CliRunner
23
24 from maestro.muse_cli.commands.swing import (
25 FACTOR_MAX,
26 FACTOR_MIN,
27 LIGHT_MAX,
28 MEDIUM_MAX,
29 STRAIGHT_MAX,
30 SwingCompareResult,
31 SwingDetectResult,
32 _format_compare,
33 _format_detect,
34 _format_history,
35 _swing_compare_async,
36 _swing_detect_async,
37 _swing_history_async,
38 swing_label,
39 )
40 from maestro.muse_cli.errors import ExitCode
41
42
43 # ---------------------------------------------------------------------------
44 # Helpers
45 # ---------------------------------------------------------------------------
46
47
48 def _init_muse_repo(root: pathlib.Path) -> str:
49 """Create a minimal .muse/ layout and return the repo_id."""
50 rid = str(uuid.uuid4())
51 muse = root / ".muse"
52 (muse / "refs" / "heads").mkdir(parents=True)
53 (muse / "repo.json").write_text(
54 json.dumps({"repo_id": rid, "schema_version": "1"})
55 )
56 (muse / "HEAD").write_text("refs/heads/main")
57 (muse / "refs" / "heads" / "main").write_text("abc1234")
58 return rid
59
60
61 # ---------------------------------------------------------------------------
62 # swing_label — threshold tests
63 # ---------------------------------------------------------------------------
64
65
def test_swing_label_straight_below_threshold() -> None:
    """Values strictly under STRAIGHT_MAX are labelled 'Straight'."""
    for straight_factor in (FACTOR_MIN, STRAIGHT_MAX - 0.001):
        assert swing_label(straight_factor) == "Straight"
70
71
def test_swing_label_light_range() -> None:
    """The half-open band [STRAIGHT_MAX, LIGHT_MAX) is labelled 'Light'."""
    midpoint = (STRAIGHT_MAX + LIGHT_MAX) / 2
    samples = (STRAIGHT_MAX, midpoint, LIGHT_MAX - 0.001)
    for value in samples:
        assert swing_label(value) == "Light"
77
78
def test_swing_label_medium_range() -> None:
    """The half-open band [LIGHT_MAX, MEDIUM_MAX) is labelled 'Medium'."""
    midpoint = (LIGHT_MAX + MEDIUM_MAX) / 2
    samples = (LIGHT_MAX, midpoint, MEDIUM_MAX - 0.001)
    for value in samples:
        assert swing_label(value) == "Medium"
84
85
def test_swing_label_hard_at_and_above_medium_max() -> None:
    """MEDIUM_MAX and everything above it (up to FACTOR_MAX) is 'Hard'."""
    for hard_factor in (MEDIUM_MAX, FACTOR_MAX):
        assert swing_label(hard_factor) == "Hard"
90
91
def test_swing_label_boundary_straight_exact() -> None:
    """The STRAIGHT_MAX edge itself falls into 'Light' (Straight is open above)."""
    label_at_edge = swing_label(STRAIGHT_MAX)
    assert label_at_edge == "Light"
95
96
def test_swing_label_boundary_light_exact() -> None:
    """The LIGHT_MAX edge itself falls into 'Medium' (Light is open above)."""
    label_at_edge = swing_label(LIGHT_MAX)
    assert label_at_edge == "Medium"
100
101
def test_swing_label_boundary_medium_exact() -> None:
    """The MEDIUM_MAX edge itself falls into 'Hard' (Medium is open above)."""
    label_at_edge = swing_label(MEDIUM_MAX)
    assert label_at_edge == "Hard"
105
106
107 # ---------------------------------------------------------------------------
108 # _swing_detect_async
109 # ---------------------------------------------------------------------------
110
111
@pytest.mark.anyio
async def test_swing_detect_returns_correct_schema(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
) -> None:
    """Detection output exposes every key the formatters rely on."""
    _init_muse_repo(tmp_path)
    detected = await _swing_detect_async(
        root=tmp_path, session=muse_cli_db_session, commit=None, track=None
    )
    for key in ("factor", "label", "commit", "branch", "track", "source"):
        assert key in detected
131
132
@pytest.mark.anyio
async def test_swing_detect_factor_in_valid_range(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
) -> None:
    """The detected swing factor lies inside [FACTOR_MIN, FACTOR_MAX]."""
    _init_muse_repo(tmp_path)
    detected = await _swing_detect_async(
        root=tmp_path, session=muse_cli_db_session, commit=None, track=None
    )
    assert FACTOR_MIN <= detected["factor"] <= FACTOR_MAX
148
149
@pytest.mark.anyio
async def test_swing_detect_label_matches_factor(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
) -> None:
    """The reported label agrees with ``swing_label`` applied to the factor."""
    _init_muse_repo(tmp_path)
    detected = await _swing_detect_async(
        root=tmp_path, session=muse_cli_db_session, commit=None, track=None
    )
    expected_label = swing_label(detected["factor"])
    assert detected["label"] == expected_label
165
166
@pytest.mark.anyio
async def test_swing_detect_explicit_commit_reflected(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
) -> None:
    """An explicitly requested commit SHA is echoed back in the result."""
    _init_muse_repo(tmp_path)
    requested_sha = "deadbeef"
    detected = await _swing_detect_async(
        root=tmp_path, session=muse_cli_db_session, commit=requested_sha, track=None
    )
    assert detected["commit"] == requested_sha
181
182
@pytest.mark.anyio
async def test_swing_detect_track_reflected(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
) -> None:
    """A track filter passed in is echoed back in the result."""
    _init_muse_repo(tmp_path)
    requested_track = "bass"
    detected = await _swing_detect_async(
        root=tmp_path, session=muse_cli_db_session, commit=None, track=requested_track
    )
    assert detected["track"] == requested_track
197
198
@pytest.mark.anyio
async def test_swing_detect_no_track_defaults_to_all(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
) -> None:
    """Omitting the track filter reports the track as 'all'."""
    _init_muse_repo(tmp_path)
    detected = await _swing_detect_async(
        root=tmp_path, session=muse_cli_db_session, commit=None, track=None
    )
    assert detected["track"] == "all"
213
214
215 # ---------------------------------------------------------------------------
216 # _swing_history_async
217 # ---------------------------------------------------------------------------
218
219
@pytest.mark.anyio
async def test_swing_history_returns_list(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
) -> None:
    """History lookup yields a list containing at least one entry."""
    _init_muse_repo(tmp_path)
    history = await _swing_history_async(
        root=tmp_path, session=muse_cli_db_session, track=None
    )
    assert isinstance(history, list)
    assert history
232
233
@pytest.mark.anyio
async def test_swing_history_entries_have_correct_keys(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
) -> None:
    """Each entry in the history list has the expected keys.

    The list is first required to be non-empty. Without that guard, an empty
    result would make the per-entry key checks pass without checking anything.
    """
    _init_muse_repo(tmp_path)
    entries = await _swing_history_async(
        root=tmp_path, session=muse_cli_db_session, track=None
    )
    # Guard against a vacuous pass when no entries come back.
    assert entries, "expected at least one swing history entry"
    for entry in entries:
        for key in ("factor", "label", "commit"):
            assert key in entry
248
249
250 # ---------------------------------------------------------------------------
251 # _swing_compare_async
252 # ---------------------------------------------------------------------------
253
254
@pytest.mark.anyio
async def test_swing_compare_returns_head_compare_delta(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
) -> None:
    """Comparison output carries 'head', 'compare' and 'delta' entries."""
    _init_muse_repo(tmp_path)
    comparison = await _swing_compare_async(
        root=tmp_path,
        session=muse_cli_db_session,
        compare_commit="abc123",
        track=None,
    )
    for key in ("head", "compare", "delta"):
        assert key in comparison
271
272
@pytest.mark.anyio
async def test_swing_compare_delta_is_numeric(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
) -> None:
    """The delta field is a finite float (a float that is not NaN or +/-inf)."""
    import math

    _init_muse_repo(tmp_path)
    result = await _swing_compare_async(
        root=tmp_path,
        session=muse_cli_db_session,
        compare_commit="abc123",
        track=None,
    )
    delta = result["delta"]
    assert isinstance(delta, float)
    # isinstance() alone accepts nan/inf; the contract is a *finite* float.
    assert math.isfinite(delta)
288
289
290 # ---------------------------------------------------------------------------
291 # Output formatters
292 # ---------------------------------------------------------------------------
293
294
def _make_detect_result(
    factor: float = 0.55,
    label: str = "Light",
    commit: str = "abc1234",
    branch: str = "main",
    track: str = "all",
    source: str = "stub",
) -> SwingDetectResult:
    """Build a ``SwingDetectResult`` fixture; every field is overridable.

    Defaults describe a 0.55 'Light' swing on ``main`` / ``abc1234`` for all
    tracks, sourced from the stub detector.
    """
    result = SwingDetectResult(
        factor=factor, label=label, commit=commit,
        branch=branch, track=track, source=source,
    )
    return result
311
312
def test_format_detect_text_contains_factor_and_label() -> None:
    """Plain-text detect output mentions both the factor and its label."""
    rendered = _format_detect(_make_detect_result(), as_json=False)
    for fragment in ("0.55", "Light"):
        assert fragment in rendered
318
319
def test_format_detect_json_is_valid() -> None:
    """JSON detect output round-trips through ``json.loads``."""
    payload = json.loads(_format_detect(_make_detect_result(), as_json=True))
    assert payload["factor"] == 0.55
    assert payload["label"] == "Light"
326
327
def test_format_history_text_contains_commit_and_label() -> None:
    """Plain-text history output lists each entry's commit SHA and label."""
    history = [_make_detect_result(commit="deadbeef")]
    rendered = _format_history(history, as_json=False)
    for fragment in ("deadbeef", "Light"):
        assert fragment in rendered
334
335
def test_format_history_empty_list_shows_placeholder() -> None:
    """With no entries, text mode prints a friendly placeholder instead."""
    rendered = _format_history([], as_json=False)
    assert "no swing history" in rendered
340
341
def test_format_history_json_is_valid() -> None:
    """JSON history output parses to a list of entry objects."""
    history = [_make_detect_result(commit="deadbeef")]
    payload = json.loads(_format_history(history, as_json=True))
    assert isinstance(payload, list)
    first_entry = payload[0]
    assert first_entry["label"] == "Light"
349
350
def test_format_compare_text_shows_delta_sign() -> None:
    """A positive delta is rendered with an explicit '+' in text output."""
    head = _make_detect_result(factor=0.57)
    baseline = _make_detect_result(factor=0.55)
    comparison = SwingCompareResult(head=head, compare=baseline, delta=0.02)
    rendered = _format_compare(comparison, as_json=False)
    assert "+0.02" in rendered
360
361
def test_format_compare_negative_delta_no_plus_sign() -> None:
    """A negative delta keeps its '-' and never gains a stray '+' prefix."""
    head = _make_detect_result(factor=0.55)
    baseline = _make_detect_result(factor=0.57)
    comparison = SwingCompareResult(head=head, compare=baseline, delta=-0.02)
    rendered = _format_compare(comparison, as_json=False)
    assert "-0.02" in rendered
    assert "+-0.02" not in rendered
372
373
def test_format_compare_json_is_valid() -> None:
    """JSON compare output parses and preserves the delta value."""
    head = _make_detect_result(factor=0.57)
    baseline = _make_detect_result(factor=0.55)
    comparison = SwingCompareResult(head=head, compare=baseline, delta=0.02)
    payload = json.loads(_format_compare(comparison, as_json=True))
    assert "delta" in payload
    assert payload["delta"] == 0.02
385
386
387 # ---------------------------------------------------------------------------
388 # CLI flag parsing via CliRunner
389 # ---------------------------------------------------------------------------
390
391
def _make_repo(root: pathlib.Path) -> pathlib.Path:
    """Initialise a Muse repository at *root* and hand *root* back.

    Convenience wrapper over ``_init_muse_repo`` that discards the repo id.
    """
    _init_muse_repo(root)
    return root
396
397
def test_swing_cli_outside_repo_exits_2(tmp_path: pathlib.Path) -> None:
    """Running ``muse swing`` where no ``.muse/`` exists yields exit code 2."""
    from maestro.muse_cli.app import cli

    runner = CliRunner()
    saved_cwd = os.getcwd()
    os.chdir(tmp_path)
    try:
        outcome = runner.invoke(cli, ["swing"], catch_exceptions=False)
    finally:
        # Restore the working directory even if the CLI raises.
        os.chdir(saved_cwd)

    assert outcome.exit_code == int(ExitCode.REPO_NOT_FOUND)
411
412
def test_swing_cli_set_valid_annotates(tmp_path: pathlib.Path) -> None:
    """``muse swing --set 0.6`` exits cleanly and echoes the new factor."""
    from maestro.muse_cli.app import cli

    _init_muse_repo(tmp_path)
    runner = CliRunner()
    argv = ["swing", "--set", "0.6"]
    saved_cwd = os.getcwd()
    os.chdir(tmp_path)
    try:
        outcome = runner.invoke(cli, argv, catch_exceptions=False)
    finally:
        # Restore the working directory even if the CLI raises.
        os.chdir(saved_cwd)

    assert outcome.exit_code == 0
    assert "0.6" in outcome.output
428
429
def test_swing_cli_set_out_of_range_exits_1(tmp_path: pathlib.Path) -> None:
    """``muse swing --set 0.9`` (above the allowed range) is a user error (exit 1)."""
    from maestro.muse_cli.app import cli

    _init_muse_repo(tmp_path)
    runner = CliRunner()
    argv = ["swing", "--set", "0.9"]
    saved_cwd = os.getcwd()
    os.chdir(tmp_path)
    try:
        outcome = runner.invoke(cli, argv, catch_exceptions=False)
    finally:
        # Restore the working directory even if the CLI raises.
        os.chdir(saved_cwd)

    assert outcome.exit_code == int(ExitCode.USER_ERROR)
444
445
def test_swing_cli_set_below_minimum_exits_1(tmp_path: pathlib.Path) -> None:
    """``muse swing --set 0.3`` (below the allowed range) is a user error (exit 1)."""
    from maestro.muse_cli.app import cli

    _init_muse_repo(tmp_path)
    runner = CliRunner()
    argv = ["swing", "--set", "0.3"]
    saved_cwd = os.getcwd()
    os.chdir(tmp_path)
    try:
        outcome = runner.invoke(cli, argv, catch_exceptions=False)
    finally:
        # Restore the working directory even if the CLI raises.
        os.chdir(saved_cwd)

    assert outcome.exit_code == int(ExitCode.USER_ERROR)
460
461
def test_swing_cli_json_flag_emits_valid_json(tmp_path: pathlib.Path) -> None:
    """``muse swing --json`` prints a parseable object with factor and label."""
    from maestro.muse_cli.app import cli

    _init_muse_repo(tmp_path)
    runner = CliRunner()
    saved_cwd = os.getcwd()
    os.chdir(tmp_path)
    try:
        outcome = runner.invoke(cli, ["swing", "--json"], catch_exceptions=False)
    finally:
        # Restore the working directory even if the CLI raises.
        os.chdir(saved_cwd)

    assert outcome.exit_code == 0
    payload = json.loads(outcome.output)
    for key in ("factor", "label"):
        assert key in payload
479
480
def test_swing_cli_history_flag_succeeds(tmp_path: pathlib.Path) -> None:
    """``muse swing --history`` exits 0 and prints something non-blank."""
    from maestro.muse_cli.app import cli

    _init_muse_repo(tmp_path)
    runner = CliRunner()
    saved_cwd = os.getcwd()
    os.chdir(tmp_path)
    try:
        outcome = runner.invoke(cli, ["swing", "--history"], catch_exceptions=False)
    finally:
        # Restore the working directory even if the CLI raises.
        os.chdir(saved_cwd)

    assert outcome.exit_code == 0
    assert outcome.output.strip() != ""
496
497
def test_swing_cli_compare_flag_succeeds(tmp_path: pathlib.Path) -> None:
    """``muse swing --compare abc123`` exits 0 and shows HEAD, Compare and Delta."""
    from maestro.muse_cli.app import cli

    _init_muse_repo(tmp_path)
    runner = CliRunner()
    argv = ["swing", "--compare", "abc123"]
    saved_cwd = os.getcwd()
    os.chdir(tmp_path)
    try:
        outcome = runner.invoke(cli, argv, catch_exceptions=False)
    finally:
        # Restore the working directory even if the CLI raises.
        os.chdir(saved_cwd)

    assert outcome.exit_code == 0
    for heading in ("HEAD", "Compare", "Delta"):
        assert heading in outcome.output
518
519
def test_swing_cli_track_flag_reflected_in_set_output(tmp_path: pathlib.Path) -> None:
    """``muse swing --set 0.6 --track bass`` mentions the track in its output."""
    from maestro.muse_cli.app import cli

    _init_muse_repo(tmp_path)
    runner = CliRunner()
    argv = ["swing", "--set", "0.6", "--track", "bass"]
    saved_cwd = os.getcwd()
    os.chdir(tmp_path)
    try:
        outcome = runner.invoke(cli, argv, catch_exceptions=False)
    finally:
        # Restore the working directory even if the CLI raises.
        os.chdir(saved_cwd)

    assert outcome.exit_code == 0
    assert "bass" in outcome.output
537
538
def test_swing_cli_set_json_combined(tmp_path: pathlib.Path) -> None:
    """``muse swing --set 0.6 --json`` prints JSON carrying the annotated factor."""
    from maestro.muse_cli.app import cli

    _init_muse_repo(tmp_path)
    runner = CliRunner()
    argv = ["swing", "--set", "0.6", "--json"]
    saved_cwd = os.getcwd()
    os.chdir(tmp_path)
    try:
        outcome = runner.invoke(cli, argv, catch_exceptions=False)
    finally:
        # Restore the working directory even if the CLI raises.
        os.chdir(saved_cwd)

    assert outcome.exit_code == 0
    payload = json.loads(outcome.output)
    assert payload["factor"] == 0.6
    assert payload["label"] == swing_label(0.6)