cgcardona / muse public
test_muse_motif.py python
395 lines 13.4 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """Tests for the Muse Motif Engine (muse_motif.py) and CLI commands (motif.py).
2
3 Covers:
4 - Pure fingerprint helpers: pitches_to_intervals, invert_intervals,
5 retrograde_intervals, detect_transformation, contour_label, parse_pitch_string.
6 - Async service functions: find_motifs, track_motif, diff_motifs, list_motifs.
7 - CLI command rendering helpers: _format_find, _format_track, _format_diff,
8 _format_list.
9
10 All async tests use @pytest.mark.anyio. No live DB or external API calls.
11 """
12 from __future__ import annotations
13
14 import json
15
16 import pytest
17
18 from maestro.muse_cli.commands.motif import (
19 _format_diff,
20 _format_find,
21 _format_list,
22 _format_track,
23 )
24 from maestro.services.muse_motif import (
25 IntervalSequence,
26 MotifTransformation,
27 contour_label,
28 detect_transformation,
29 diff_motifs,
30 find_motifs,
31 invert_intervals,
32 list_motifs,
33 parse_pitch_string,
34 pitches_to_intervals,
35 retrograde_intervals,
36 track_motif,
37 )
38
39
40 # ---------------------------------------------------------------------------
41 # Pure helper tests
42 # ---------------------------------------------------------------------------
43
44
class TestPitchesToIntervals:
    """pitches_to_intervals turns absolute MIDI pitches into successive deltas."""

    def test_basic_ascending(self) -> None:
        assert pitches_to_intervals((60, 62, 64, 67)) == (2, 2, 3)

    def test_basic_descending(self) -> None:
        assert pitches_to_intervals((67, 65, 62, 60)) == (-2, -3, -2)

    def test_single_note_returns_empty(self) -> None:
        # A lone note has no neighbour, so there is no interval to report.
        assert pitches_to_intervals((60,)) == ()

    def test_empty_returns_empty(self) -> None:
        assert pitches_to_intervals(()) == ()

    def test_two_notes_returns_one_interval(self) -> None:
        assert pitches_to_intervals((60, 62)) == (2,)

    def test_transposition_invariant(self) -> None:
        """Motif at C and motif at G should produce identical fingerprints."""
        on_c = (60, 62, 64)
        # Shift every pitch up a perfect fifth: (67, 69, 71).
        on_g = tuple(pitch + 7 for pitch in on_c)
        assert pitches_to_intervals(on_c) == pitches_to_intervals(on_g)
70
71
class TestInvertIntervals:
    """invert_intervals mirrors a contour by flipping the sign of every step."""

    def test_inversion_negates_all(self) -> None:
        steps: IntervalSequence = (2, 2, -1, 2)
        mirrored = invert_intervals(steps)
        assert mirrored == (-2, -2, 1, -2)

    def test_inversion_of_empty(self) -> None:
        assert invert_intervals(()) == ()

    def test_double_inversion_identity(self) -> None:
        # Negating twice must return the original sequence unchanged.
        steps: IntervalSequence = (3, -2, 1)
        round_trip = invert_intervals(invert_intervals(steps))
        assert round_trip == steps
83
84
class TestRetrogradeIntervals:
    """retrograde_intervals: interval fingerprint of the motif played backwards."""

    def test_retrograde_reverses_and_negates(self) -> None:
        """Retrograde of interval sequence = negation of reversed sequence."""
        intervals: IntervalSequence = (2, 2, -1)
        assert retrograde_intervals(intervals) == (1, -2, -2)

    def test_retrograde_of_symmetric(self) -> None:
        """A sequence whose reversal equals its negation is retrograde-invariant.

        (1, -1) reversed is (-1, 1), which is exactly its negation, so the
        retrograde (reverse, then negate) maps it back onto itself.
        """
        intervals: IntervalSequence = (1, -1)
        assert retrograde_intervals(intervals) == intervals

    def test_double_retrograde_identity(self) -> None:
        intervals: IntervalSequence = (2, -3, 1)
        assert retrograde_intervals(retrograde_intervals(intervals)) == intervals
100
101
class TestDetectTransformation:
    """detect_transformation classifies how a candidate relates to a query."""

    def test_exact_match(self) -> None:
        seq: IntervalSequence = (2, 2, -1, 2)
        assert detect_transformation(seq, seq) == MotifTransformation.EXACT

    def test_inversion_detected(self) -> None:
        seq: IntervalSequence = (2, 2, -1, 2)
        flipped = invert_intervals(seq)
        outcome = detect_transformation(seq, flipped)
        assert outcome == MotifTransformation.INVERSION

    def test_retrograde_detected(self) -> None:
        seq: IntervalSequence = (2, 2, -1, 2)
        backwards = retrograde_intervals(seq)
        outcome = detect_transformation(seq, backwards)
        assert outcome == MotifTransformation.RETROGRADE

    def test_retro_inv_detected(self) -> None:
        seq: IntervalSequence = (2, 2, -1, 2)
        # Retrograde first, then invert the result.
        backwards_flipped = invert_intervals(retrograde_intervals(seq))
        outcome = detect_transformation(seq, backwards_flipped)
        assert outcome == MotifTransformation.RETRO_INV

    def test_unrelated_returns_none(self) -> None:
        seq: IntervalSequence = (2, 2, -1, 2)
        other: IntervalSequence = (5, -3, 7)
        assert detect_transformation(seq, other) is None
126
127
class TestContourLabel:
    """contour_label names the overall melodic shape of an interval sequence."""

    def test_ascending_step(self) -> None:
        steps = (1, 2, 1)
        assert contour_label(steps) == "ascending-step"

    def test_descending_step(self) -> None:
        steps = (-1, -2, -1)
        assert contour_label(steps) == "descending-step"

    def test_ascending_leap(self) -> None:
        leaps = (4, 5, 3)
        assert contour_label(leaps) == "ascending-leap"

    def test_descending_leap(self) -> None:
        leaps = (-4, -5, -3)
        assert contour_label(leaps) == "descending-leap"

    def test_arch_shape(self) -> None:
        # Rises, then falls.
        shape = (3, 2, -2, -3)
        assert contour_label(shape) == "arch"

    def test_valley_shape(self) -> None:
        # Falls, then rises.
        shape = (-3, -2, 2, 3)
        assert contour_label(shape) == "valley"

    def test_static_zero_intervals(self) -> None:
        repeated_note = (0, 0, 0)
        assert contour_label(repeated_note) == "static"

    def test_empty_is_static(self) -> None:
        assert contour_label(()) == "static"

    def test_oscillating(self) -> None:
        zigzag = (2, -2, 2, -2)
        assert contour_label(zigzag) == "oscillating"
155
156
class TestParsePitchString:
    """parse_pitch_string accepts MIDI numbers or note names (C maps to 60 here)."""

    def test_midi_numbers(self) -> None:
        parsed = parse_pitch_string("60 62 64 67")
        assert parsed == (60, 62, 64, 67)

    def test_note_names_c_major(self) -> None:
        assert parse_pitch_string("C D E G") == (60, 62, 64, 67)

    def test_note_names_with_sharp(self) -> None:
        assert parse_pitch_string("C C# D") == (60, 61, 62)

    def test_mixed_case_note_names(self) -> None:
        # Lower-case names must resolve exactly like their upper-case forms.
        assert parse_pitch_string("c d e g") == (60, 62, 64, 67)

    def test_invalid_token_raises(self) -> None:
        with pytest.raises(ValueError, match="Cannot parse pitch token"):
            parse_pitch_string("C D XQ")

    def test_out_of_range_midi_raises(self) -> None:
        # 200 lies outside the MIDI note range.
        with pytest.raises(ValueError):
            parse_pitch_string("200")

    def test_single_note(self) -> None:
        parsed = parse_pitch_string("60")
        assert parsed == (60,)
183
184
185 # ---------------------------------------------------------------------------
186 # Async service tests
187 # ---------------------------------------------------------------------------
188
189
@pytest.mark.anyio
async def test_motif_find_detects_recurring_pattern() -> None:
    """find_motifs returns a result with at least one motif group above min-length."""
    found = await find_motifs(commit_id="abc12345", branch="main", min_length=3)
    # The request parameters are echoed back on the result.
    assert (found.commit_id, found.branch, found.min_length) == ("abc12345", "main", 3)
    assert found.total_found > 0
    assert len(found.motifs) == found.total_found
203
204
@pytest.mark.anyio
async def test_motif_find_min_length_filter() -> None:
    """Increasing min-length reduces (or equals) the number of detected motifs."""
    loose = await find_motifs(commit_id="deadbeef", branch="main", min_length=2)
    strict = await find_motifs(commit_id="deadbeef", branch="main", min_length=8)
    assert strict.total_found <= loose.total_found
219
220
@pytest.mark.anyio
async def test_motif_find_track_filter_respected() -> None:
    """find_motifs with a track filter propagates the track name to occurrences.

    The non-empty check prevents vacuous success: with no motif groups the
    per-occurrence loop would never run and the test would pass trivially.
    """
    result = await find_motifs(
        commit_id="abc12345",
        branch="main",
        min_length=3,
        track="bass",
    )
    # NOTE(review): assumes the stub still yields motifs when a track filter
    # is supplied (it does for the same commit without one) — confirm.
    assert result.motifs, "expected at least one motif group to inspect"
    for group in result.motifs:
        for occ in group.occurrences:
            assert occ.track == "bass"
233
234
@pytest.mark.anyio
async def test_motif_find_source_is_stub() -> None:
    """Stub implementation always returns source='stub'."""
    outcome = await find_motifs(
        commit_id="abc12345",
        branch="main",
        min_length=3,
    )
    assert outcome.source == "stub"
240
241
@pytest.mark.anyio
async def test_motif_find_motifs_sorted_by_count_descending() -> None:
    """Motif groups are sorted by occurrence count, highest first."""
    found = await find_motifs(commit_id="abc12345", branch="main", min_length=3)
    tallies = [group.count for group in found.motifs]
    # Each count must be >= its successor (non-increasing order).
    assert all(prev >= nxt for prev, nxt in zip(tallies, tallies[1:]))
248
249
@pytest.mark.anyio
async def test_motif_track_finds_transpositions() -> None:
    """track_motif parses the pattern and returns a MotifTrackResult."""
    tracked = await track_motif(pattern="C D E G", commit_ids=["abc12345"])
    assert tracked.pattern == "C D E G"
    # C D E G -> 60 62 64 67 -> steps of 2, 2 and 3 semitones.
    assert tracked.fingerprint == (2, 2, 3)
    assert tracked.total_commits_scanned == 1
    assert len(tracked.occurrences) == 1
258
259
@pytest.mark.anyio
async def test_motif_track_empty_commit_list() -> None:
    """track_motif with no commits returns an empty occurrence list."""
    no_commits: list[str] = []
    tracked = await track_motif(pattern="60 62 64", commit_ids=no_commits)
    assert tracked.total_commits_scanned == 0
    assert len(tracked.occurrences) == 0
266
267
@pytest.mark.anyio
async def test_motif_track_multiple_commits() -> None:
    """track_motif scans all provided commit IDs and reports one occurrence each."""
    commit_ids = ["aaa11111", "bbb22222", "ccc33333"]
    result = await track_motif(pattern="C D E", commit_ids=commit_ids)
    assert result.total_commits_scanned == len(commit_ids)
    assert len(result.occurrences) == len(commit_ids)
    # Each ID is exactly 8 characters long, so it is compared as-is.
    found_ids = {occ.commit_id for occ in result.occurrences}
    assert found_ids == set(commit_ids)
277
278
@pytest.mark.anyio
async def test_motif_track_invalid_pattern_raises() -> None:
    """track_motif raises ValueError for an unparseable pattern."""
    bad_pattern = "C D XYZ"
    with pytest.raises(ValueError):
        await track_motif(pattern=bad_pattern, commit_ids=["abc"])
284
285
@pytest.mark.anyio
async def test_motif_diff_identifies_inversion() -> None:
    """diff_motifs returns a MotifDiffResult with a valid transformation.

    Despite the test name, this only checks that the reported transformation
    is some MotifTransformation member, not specifically INVERSION.
    """
    result = await diff_motifs(commit_a_id="aaa11111", commit_b_id="bbb22222")
    # Both IDs are exactly 8 characters long, so they are compared as-is.
    assert result.commit_a.commit_id == "aaa11111"
    assert result.commit_b.commit_id == "bbb22222"
    assert result.transformation in list(MotifTransformation)
    assert result.description
294
295
@pytest.mark.anyio
async def test_motif_diff_source_is_stub() -> None:
    """diff_motifs tags its result with source='stub'."""
    diff = await diff_motifs(commit_a_id="aaa", commit_b_id="bbb")
    assert diff.source == "stub"
300
301
@pytest.mark.anyio
async def test_motif_list_returns_named_motifs() -> None:
    """list_motifs returns a MotifListResult with named motif entries."""
    listing = await list_motifs(muse_dir_path="/tmp/fake-muse")
    assert len(listing.motifs) > 0
    assert "main-theme" in {entry.name for entry in listing.motifs}
309
310
@pytest.mark.anyio
async def test_motif_list_source_is_stub() -> None:
    """list_motifs tags its result with source='stub'."""
    listing = await list_motifs(muse_dir_path="/tmp/fake-muse")
    assert listing.source == "stub"
315
316
@pytest.mark.anyio
async def test_motif_list_fingerprints_are_non_empty() -> None:
    """Every motif returned by list_motifs carries a non-empty fingerprint."""
    result = await list_motifs(muse_dir_path="/tmp/fake-muse")
    # Without this guard an empty motif list would make the loop pass vacuously.
    assert result.motifs, "expected list_motifs to return at least one motif"
    for motif in result.motifs:
        assert len(motif.fingerprint) > 0, motif.name
322
323
324 # ---------------------------------------------------------------------------
325 # Formatter / rendering tests
326 # ---------------------------------------------------------------------------
327
328
@pytest.mark.anyio
async def test_format_find_text_output() -> None:
    """_format_find produces non-empty tabular text."""
    found = await find_motifs(commit_id="abc12345", branch="main", min_length=3)
    rendered = _format_find(found, as_json=False)
    assert "Recurring motifs" in rendered
    # At least one contour label should appear in the rendered table.
    assert any(label in rendered for label in ("ascending", "descending", "arch"))
336
337
@pytest.mark.anyio
async def test_format_find_json_output_valid() -> None:
    """_format_find with as_json=True produces parseable JSON."""
    found = await find_motifs(commit_id="abc12345", branch="main", min_length=3)
    payload = json.loads(_format_find(found, as_json=True))
    assert "motifs" in payload
    assert payload["total_found"] == found.total_found
346
347
@pytest.mark.anyio
async def test_format_track_text_output() -> None:
    """_format_track's text mode includes the 'Tracking motif' heading."""
    tracked = await track_motif(pattern="C D E G", commit_ids=["abc12345"])
    assert "Tracking motif" in _format_track(tracked, as_json=False)
353
354
@pytest.mark.anyio
async def test_format_track_json_output_valid() -> None:
    """_format_track with as_json=True emits JSON carrying fingerprint and occurrences."""
    tracked = await track_motif(pattern="C D E G", commit_ids=["abc12345"])
    payload = json.loads(_format_track(tracked, as_json=True))
    for key in ("fingerprint", "occurrences"):
        assert key in payload
362
363
@pytest.mark.anyio
async def test_format_diff_text_output() -> None:
    """_format_diff's text mode shows the diff heading and the transformation."""
    diff = await diff_motifs(commit_a_id="aaa11111", commit_b_id="bbb22222")
    rendered = _format_diff(diff, as_json=False)
    for fragment in ("Motif diff", "Transformation"):
        assert fragment in rendered
370
371
@pytest.mark.anyio
async def test_format_diff_json_output_valid() -> None:
    """_format_diff with as_json=True emits JSON with transformation and both commits."""
    diff = await diff_motifs(commit_a_id="aaa11111", commit_b_id="bbb22222")
    payload = json.loads(_format_diff(diff, as_json=True))
    for key in ("transformation", "commit_a", "commit_b"):
        assert key in payload
380
381
@pytest.mark.anyio
async def test_format_list_text_output() -> None:
    """_format_list's text mode mentions the 'main-theme' motif by name."""
    listing = await list_motifs(muse_dir_path="/tmp/fake-muse")
    rendered = _format_list(listing, as_json=False)
    assert "main-theme" in rendered
387
388
@pytest.mark.anyio
async def test_format_list_json_output_valid() -> None:
    """_format_list with as_json=True emits JSON holding a non-empty motifs array."""
    listing = await list_motifs(muse_dir_path="/tmp/fake-muse")
    payload = json.loads(_format_list(listing, as_json=True))
    assert "motifs" in payload
    assert len(payload["motifs"]) > 0