cgcardona / muse public
test_music_query.py python
217 lines 7.2 KB
e6786943 feat: upgrade to Python 3.14, drop from __future__ import annotations Gabriel Cardona <cgcardona@gmail.com> 1d ago
"""Tests for muse.plugins.midi._midi_query — tokenizer, parser, evaluator."""
2
3 import datetime
4
5 import pytest
6
7 from muse.core.store import CommitRecord
8 from muse.plugins.midi._midi_query import (
9 AndNode,
10 EqNode,
11 NotNode,
12 OrNode,
13 QueryContext,
14 evaluate_node,
15 parse_query,
16 )
17 from muse.plugins.midi._query import NoteInfo
18 from muse.plugins.midi.midi_diff import NoteKey
19
20
def _make_commit(
    agent_id: str = "",
    author: str = "human",
    model_id: str = "",
    toolchain_id: str = "",
) -> CommitRecord:
    """Return a placeholder CommitRecord on ``main`` carrying the given provenance.

    The commit, repo and snapshot ids and the message are fixed placeholder
    values; ``committed_at`` is the current time in UTC.  Only the
    provenance fields (author, agent, model, toolchain) vary per call.
    """
    placeholder_id = "deadbeef" * 8
    timestamp = datetime.datetime.now(tz=datetime.UTC)
    return CommitRecord(
        commit_id=placeholder_id,
        repo_id="repo-test",
        branch="main",
        snapshot_id="snap123",
        message="test",
        author=author,
        committed_at=timestamp,
        agent_id=agent_id,
        model_id=model_id,
        toolchain_id=toolchain_id,
    )
39
40
def _make_note(pitch: int = 60, velocity: int = 80, channel: int = 0) -> NoteInfo:
    """Return a one-beat note (480 ticks at 480 ticks per beat) starting at tick 0."""
    key = NoteKey(
        pitch=pitch,
        velocity=velocity,
        start_tick=0,
        duration_ticks=480,
        channel=channel,
    )
    return NoteInfo.from_note_key(key, ticks_per_beat=480)
52
53
def _make_ctx(
    notes: list[NoteInfo] | None = None,
    bar: int = 1,
    track: str = "piano.mid",
    chord: str = "Cmaj",
    commit: CommitRecord | None = None,
) -> QueryContext:
    """Build a QueryContext for evaluator tests.

    Args:
        notes: Notes in the bar.  ``None`` (the default) supplies a single
            default note from ``_make_note()``.  An explicit empty list is
            passed through unchanged, so tests can exercise an empty bar.
        bar: Bar number matched by ``bar`` queries.
        track: Track path matched by ``track`` queries.
        chord: Chord label read by ``harmony.*`` queries.
        commit: Commit metadata.  ``None`` builds one via ``_make_commit()``.

    Returns:
        A QueryContext whose ``ticks_per_beat`` is 480, matching ``_make_note``.
    """
    # Use explicit ``is None`` checks rather than ``or``: ``notes or [...]``
    # would silently replace a deliberately empty list with a default note.
    if notes is None:
        notes = [_make_note()]
    if commit is None:
        commit = _make_commit()
    return QueryContext(
        commit=commit,
        track=track,
        bar=bar,
        notes=notes,
        chord=chord,
        ticks_per_beat=480,
    )
69
70
71 # ---------------------------------------------------------------------------
72 # Tokenizer / parser
73 # ---------------------------------------------------------------------------
74
75
class TestParser:
    """parse_query: query strings become EqNode / AndNode / OrNode / NotNode trees.

    EqNode is used for every comparison operator (including ``>``), not just ``==``.
    """

    def test_simple_eq_parses(self) -> None:
        node = parse_query("bar == 4")
        assert isinstance(node, EqNode)
        assert node.field == "bar"
        assert node.op == "=="
        assert node.value == 4

    def test_and_produces_and_node(self) -> None:
        node = parse_query("bar == 1 and note.pitch > 60")
        assert isinstance(node, AndNode)
        # The left operand must be the first comparison itself, not a nested combinator.
        assert isinstance(node.left, EqNode)
        assert node.left.field == "bar"

    def test_or_produces_or_node(self) -> None:
        node = parse_query("bar == 1 or bar == 2")
        assert isinstance(node, OrNode)

    def test_not_produces_not_node(self) -> None:
        node = parse_query("not bar == 4")
        assert isinstance(node, NotNode)

    def test_parentheses_group_correctly(self) -> None:
        node = parse_query("(bar == 1 or bar == 2) and note.pitch > 60")
        assert isinstance(node, AndNode)
        assert isinstance(node.left, OrNode)

    def test_string_value_parses(self) -> None:
        node = parse_query("note.pitch_class == 'C'")
        assert isinstance(node, EqNode)
        assert node.value == "C"

    def test_double_quoted_string_parses(self) -> None:
        node = parse_query('track == "piano.mid"')
        assert isinstance(node, EqNode)
        assert node.value == "piano.mid"

    def test_float_value_parses(self) -> None:
        node = parse_query("note.duration > 0.5")
        assert isinstance(node, EqNode)
        assert isinstance(node.value, float)
        # Pin the parsed value as well: a type check alone would accept a
        # mis-tokenized literal such as 5.0 or 0.0.
        assert node.value == 0.5

    def test_invalid_query_raises_value_error(self) -> None:
        with pytest.raises(ValueError):
            parse_query("bar !!! 4")

    def test_incomplete_query_raises_value_error(self) -> None:
        with pytest.raises(ValueError):
            parse_query("bar ==")
123
124
125 # ---------------------------------------------------------------------------
126 # Evaluator — field resolution and comparison
127 # ---------------------------------------------------------------------------
128
129
class TestEvaluator:
    """evaluate_node: field resolution, comparison operators and boolean combinators."""

    def test_bar_eq_match(self) -> None:
        assert evaluate_node(parse_query("bar == 4"), _make_ctx(bar=4))

    def test_bar_eq_no_match(self) -> None:
        assert not evaluate_node(parse_query("bar == 4"), _make_ctx(bar=3))

    def test_note_pitch_gt(self) -> None:
        ctx = _make_ctx(notes=[_make_note(pitch=65)])
        assert evaluate_node(parse_query("note.pitch > 60"), ctx)

    def test_note_pitch_gt_false(self) -> None:
        ctx = _make_ctx(notes=[_make_note(pitch=55)])
        assert not evaluate_node(parse_query("note.pitch > 60"), ctx)

    def test_note_pitch_gt_is_strict(self) -> None:
        # Boundary: an equal pitch must not satisfy ">" (catches a ">=" slip).
        ctx = _make_ctx(notes=[_make_note(pitch=60)])
        assert not evaluate_node(parse_query("note.pitch > 60"), ctx)

    def test_note_velocity_lte(self) -> None:
        query = parse_query("note.velocity <= 80")
        # Equality satisfies "<=".
        assert evaluate_node(query, _make_ctx(notes=[_make_note(velocity=80)]))
        # A smaller value must also match; without this, "<=" implemented as
        # "==" would pass the equality check above.
        assert evaluate_node(query, _make_ctx(notes=[_make_note(velocity=79)]))
        # A larger value must not match.
        assert not evaluate_node(query, _make_ctx(notes=[_make_note(velocity=81)]))

    def test_note_pitch_class_match(self) -> None:
        # Middle C = pitch 60 = C
        ctx = _make_ctx(notes=[_make_note(pitch=60)])
        assert evaluate_node(parse_query("note.pitch_class == 'C'"), ctx)

    def test_track_match(self) -> None:
        ctx = _make_ctx(track="strings.mid")
        assert evaluate_node(parse_query("track == 'strings.mid'"), ctx)

    def test_chord_match(self) -> None:
        ctx = _make_ctx(chord="Fmin")
        assert evaluate_node(parse_query("harmony.chord == 'Fmin'"), ctx)

    def test_author_match(self) -> None:
        commit = _make_commit(author="alice")
        ctx = _make_ctx(commit=commit)
        assert evaluate_node(parse_query("author == 'alice'"), ctx)

    def test_agent_id_match(self) -> None:
        commit = _make_commit(agent_id="counterpoint-bot")
        ctx = _make_ctx(commit=commit)
        assert evaluate_node(parse_query("agent_id == 'counterpoint-bot'"), ctx)

    def test_and_both_must_match(self) -> None:
        ctx = _make_ctx(notes=[_make_note(pitch=65)], bar=4)
        assert evaluate_node(parse_query("note.pitch > 60 and bar == 4"), ctx)
        assert not evaluate_node(parse_query("note.pitch > 60 and bar == 5"), ctx)

    def test_or_one_must_match(self) -> None:
        ctx = _make_ctx(bar=2)
        assert evaluate_node(parse_query("bar == 1 or bar == 2"), ctx)
        assert not evaluate_node(parse_query("bar == 1 or bar == 3"), ctx)

    def test_not_negates(self) -> None:
        ctx = _make_ctx(bar=4)
        assert not evaluate_node(parse_query("not bar == 4"), ctx)
        assert evaluate_node(parse_query("not bar == 5"), ctx)

    def test_multiple_notes_any_match(self) -> None:
        # If any note in the bar matches, the predicate matches.
        ctx = _make_ctx(notes=[_make_note(pitch=55), _make_note(pitch=65)])
        assert evaluate_node(parse_query("note.pitch > 60"), ctx)

    def test_unknown_field_returns_false(self) -> None:
        ctx = _make_ctx()
        assert not evaluate_node(parse_query("nonexistent == 'x'"), ctx)

    def test_harmony_quality_min(self) -> None:
        ctx = _make_ctx(chord="Amin")
        assert evaluate_node(parse_query("harmony.quality == 'min'"), ctx)

    def test_harmony_quality_dim7(self) -> None:
        ctx = _make_ctx(chord="Bdim7")
        assert evaluate_node(parse_query("harmony.quality == 'dim7'"), ctx)
203
204
205 # ---------------------------------------------------------------------------
206 # Note channel field
207 # ---------------------------------------------------------------------------
208
209
class TestNoteChannel:
    """Resolution of the ``note.channel`` field."""

    def test_channel_eq(self) -> None:
        ctx = _make_ctx(notes=[_make_note(channel=2)])
        assert evaluate_node(parse_query("note.channel == 2"), ctx)

    def test_channel_neq(self) -> None:
        """A note on a different channel does not satisfy an equality query."""
        ctx = _make_ctx(notes=[_make_note(channel=3)])
        assert not evaluate_node(parse_query("note.channel == 2"), ctx)
        # Positive control on the same context: an unknown field also evaluates
        # to False, so the negative assertion alone would pass even if
        # note.channel were never resolved.
        assert evaluate_node(parse_query("note.channel == 3"), ctx)
214
215 def test_channel_neq(self) -> None:
216 ctx = _make_ctx(notes=[_make_note(channel=3)])
217 assert not evaluate_node(parse_query("note.channel == 2"), ctx)