cgcardona / muse public
test_music_invariants.py python
194 lines 7.4 KB
6d8ca4ac feat: god-tier MIDI dimension expansion + full supercharge architecture Gabriel Cardona <gabriel@tellurstori.com> 1d ago
"""Tests for muse.plugins.music._invariants — individual rule checks and the rule-file loader."""
2 from __future__ import annotations
3
4 import pathlib
5
6 import pytest
7
8 from muse.plugins.music._invariants import (
9 InvariantRule,
10 check_key_consistency,
11 check_max_polyphony,
12 check_no_parallel_fifths,
13 check_pitch_range,
14 load_invariant_rules,
15 )
16 from muse.plugins.music._query import NoteInfo
17 from muse.plugins.music.midi_diff import NoteKey
18
19
def _note(pitch: int, start_tick: int = 0, duration_ticks: int = 480,
          velocity: int = 80, channel: int = 0) -> NoteInfo:
    """Build a ``NoteInfo`` test fixture at 480 ticks per beat.

    Only *pitch* is required. By default the note starts at tick 0, lasts
    480 ticks (one beat at this resolution), has velocity 80 and is on
    channel 0.
    """
    key = NoteKey(
        pitch=pitch,
        velocity=velocity,
        start_tick=start_tick,
        duration_ticks=duration_ticks,
        channel=channel,
    )
    return NoteInfo.from_note_key(key, ticks_per_beat=480)
32
33
34 # ---------------------------------------------------------------------------
35 # check_max_polyphony
36 # ---------------------------------------------------------------------------
37
38
class TestCheckMaxPolyphony:
    """Tests for check_max_polyphony: too many overlapping notes is a violation."""

    def test_no_violation_when_polyphony_ok(self) -> None:
        # Three sequential notes: far below the limit of 4 even if notes that
        # merely touch at a boundary tick are counted as overlapping.
        notes = [_note(60, 0), _note(64, 480), _note(67, 960)]
        violations = check_max_polyphony(notes, "track.mid", "poly", "warning", max_simultaneous=4)
        assert violations == []

    def test_exactly_max_simultaneous_is_allowed(self) -> None:
        # The limit is a maximum: exactly 4 simultaneous notes must pass.
        notes = [_note(60 + i, 0, 480) for i in range(4)]
        violations = check_max_polyphony(notes, "track.mid", "poly", "warning", max_simultaneous=4)
        assert violations == []

    def test_violation_when_too_many_simultaneous(self) -> None:
        # Five notes all sounding together from tick 0 exceed the limit of 4.
        notes = [_note(60 + i, 0, 480) for i in range(5)]
        violations = check_max_polyphony(notes, "track.mid", "poly", "error", max_simultaneous=4)
        assert len(violations) == 1
        assert violations[0]["severity"] == "error"
        assert violations[0]["rule_name"] == "poly"

    def test_violation_mentions_peak_count(self) -> None:
        notes = [_note(60 + i, 0, 480) for i in range(6)]
        violations = check_max_polyphony(notes, "track.mid", "poly", "warning", max_simultaneous=4)
        # Guard before indexing so a missing violation fails as an assertion,
        # not as an IndexError that hides what went wrong.
        assert violations
        assert "6" in violations[0]["description"]

    def test_empty_notes_produces_no_violation(self) -> None:
        violations = check_max_polyphony([], "track.mid", "poly", "warning")
        assert violations == []

    def test_non_overlapping_notes_ok(self) -> None:
        # Each note starts one beat after the previous one ends.
        notes = [_note(60, start_tick=i * 960, duration_ticks=480) for i in range(10)]
        violations = check_max_polyphony(notes, "track.mid", "poly", "warning", max_simultaneous=4)
        assert violations == []
67
68
69 # ---------------------------------------------------------------------------
70 # check_pitch_range
71 # ---------------------------------------------------------------------------
72
73
class TestCheckPitchRange:
    """Tests for check_pitch_range: one violation per note outside min_pitch..max_pitch."""

    def test_all_in_range_produces_no_violation(self) -> None:
        notes = [_note(60), _note(72), _note(84)]
        violations = check_pitch_range(notes, "track.mid", "range", "warning",
                                       min_pitch=48, max_pitch=96)
        assert violations == []

    def test_boundary_pitches_are_in_range(self) -> None:
        # The limits themselves are allowed: pitches equal to min/max pass.
        notes = [_note(48), _note(96)]
        violations = check_pitch_range(notes, "track.mid", "range", "warning",
                                       min_pitch=48, max_pitch=96)
        assert violations == []

    def test_too_low_produces_violation(self) -> None:
        notes = [_note(36)]  # below min=48
        violations = check_pitch_range(notes, "track.mid", "range", "error",
                                       min_pitch=48, max_pitch=96)
        assert len(violations) == 1
        assert "36" in violations[0]["description"]
        assert violations[0]["severity"] == "error"

    def test_too_high_produces_violation(self) -> None:
        notes = [_note(100)]  # above max=96
        violations = check_pitch_range(notes, "track.mid", "range", "warning",
                                       min_pitch=48, max_pitch=96)
        assert len(violations) == 1
        # Mirror the too-low test: the offending pitch and severity are reported.
        assert "100" in violations[0]["description"]
        assert violations[0]["severity"] == "warning"

    def test_multiple_out_of_range_produces_multiple_violations(self) -> None:
        notes = [_note(30), _note(110), _note(60)]
        violations = check_pitch_range(notes, "t.mid", "r", "info",
                                       min_pitch=48, max_pitch=96)
        assert len(violations) == 2

    def test_empty_notes_produces_no_violation(self) -> None:
        violations = check_pitch_range([], "track.mid", "range", "warning",
                                       min_pitch=48, max_pitch=96)
        assert violations == []
100
101
102 # ---------------------------------------------------------------------------
103 # check_key_consistency
104 # ---------------------------------------------------------------------------
105
106
class TestCheckKeyConsistency:
    """Tests for check_key_consistency on diatonic input and on an empty track."""

    def test_cmajor_notes_no_violation(self) -> None:
        # Four passes over the C major scale C4..B4 (MIDI 60-71), all diatonic.
        scale = (60, 62, 64, 65, 67, 69, 71)
        notes = [_note(pitch) for _ in range(4) for pitch in scale]
        result = check_key_consistency(notes, "t.mid", "key", "info", threshold=0.2)
        assert result == []

    def test_empty_notes_produces_no_violation(self) -> None:
        assert check_key_consistency([], "t.mid", "key", "warning") == []
118
119
120 # ---------------------------------------------------------------------------
121 # check_no_parallel_fifths
122 # ---------------------------------------------------------------------------
123
124
class TestCheckNoParallelFifths:
    """Tests for check_no_parallel_fifths on two-dyad progressions one bar apart."""

    @staticmethod
    def _two_bar_dyads(
        first: tuple[int, int], second: tuple[int, int]
    ) -> list[NoteInfo]:
        """Return *first* as a dyad at tick 0 and *second* four beats later.

        Each note lasts one beat. The beat length here must match the 480
        ticks per beat that ``_note`` passes to ``NoteInfo``.
        """
        beat = 480
        bar = 4 * beat
        low_a, high_a = first
        low_b, high_b = second
        return [
            _note(low_a, start_tick=0, duration_ticks=beat),
            _note(high_a, start_tick=0, duration_ticks=beat),
            _note(low_b, start_tick=bar, duration_ticks=beat),
            _note(high_b, start_tick=bar, duration_ticks=beat),
        ]

    def test_no_violation_without_parallel_fifths(self) -> None:
        # C4+G4 (a fifth) followed by D4+E4 (a major second): no second fifth.
        notes = self._two_bar_dyads((60, 67), (62, 64))
        assert check_no_parallel_fifths(notes, "t.mid", "fifths", "warning") == []

    def test_parallel_fifths_detected(self) -> None:
        # C4+G4 then D4+A4: both voices rise a whole step, fifth to fifth.
        notes = self._two_bar_dyads((60, 67), (62, 69))
        found = check_no_parallel_fifths(notes, "t.mid", "fifths", "warning")
        assert len(found) >= 1
        assert found[0]["rule_name"] == "fifths"

    def test_not_enough_notes_produces_no_violation(self) -> None:
        lone = [_note(60)]
        assert check_no_parallel_fifths(lone, "t.mid", "fifths", "warning") == []
159
160
161 # ---------------------------------------------------------------------------
162 # load_invariant_rules
163 # ---------------------------------------------------------------------------
164
165
class TestLoadInvariantRules:
    """Tests for load_invariant_rules: built-in defaults and TOML rule files."""

    def test_default_rules_returned_when_no_file(self) -> None:
        rules = load_invariant_rules(None)
        assert len(rules) >= 1
        rule_types = {r["rule_type"] for r in rules}
        assert "max_polyphony" in rule_types

    def test_missing_file_returns_defaults(self, tmp_path: pathlib.Path) -> None:
        rules = load_invariant_rules(tmp_path / "nonexistent.toml")
        assert rules
        # A missing file must fall back to exactly the built-in defaults,
        # not merely to some non-empty rule list.
        assert rules == load_invariant_rules(None)

    def test_toml_file_parsed_correctly(self, tmp_path: pathlib.Path) -> None:
        toml_content = """
[[rule]]
name = "test_rule"
severity = "error"
scope = "track"
rule_type = "max_polyphony"

[rule.params]
max_simultaneous = 4
"""
        rules_file = tmp_path / "invariants.toml"
        # Explicit encoding: write_text otherwise uses the locale's default.
        rules_file.write_text(toml_content, encoding="utf-8")
        rules: list[InvariantRule] = load_invariant_rules(rules_file)
        assert len(rules) == 1
        assert rules[0]["name"] == "test_rule"
        assert rules[0]["severity"] == "error"
        assert rules[0].get("params", {}).get("max_simultaneous") == 4