cgcardona / muse public
test_music_query.py python
218 lines 7.2 KB
9ee9c39c refactor: rename music→midi domain, strip all 5-dim backward compat Gabriel Cardona <gabriel@tellurstori.com> 1d ago
"""Tests for muse.plugins.midi._midi_query — tokenizer, parser, evaluator."""
2 from __future__ import annotations
3
4 import datetime
5
6 import pytest
7
8 from muse.core.store import CommitRecord
9 from muse.plugins.midi._midi_query import (
10 AndNode,
11 EqNode,
12 NotNode,
13 OrNode,
14 QueryContext,
15 evaluate_node,
16 parse_query,
17 )
18 from muse.plugins.midi._query import NoteInfo
19 from muse.plugins.midi.midi_diff import NoteKey
20
21
def _make_commit(
    agent_id: str = "",
    author: str = "human",
    model_id: str = "",
    toolchain_id: str = "",
) -> CommitRecord:
    """Return a placeholder CommitRecord on branch ``main`` for query tests.

    Only the provenance fields (``author``, ``agent_id``, ``model_id``,
    ``toolchain_id``) are configurable.  Every other field is a fixed dummy
    value, and ``committed_at`` is the current time in UTC.
    """
    now_utc = datetime.datetime.now(datetime.timezone.utc)
    placeholder_id = "deadbeef" * 8
    return CommitRecord(
        author=author,
        agent_id=agent_id,
        model_id=model_id,
        toolchain_id=toolchain_id,
        commit_id=placeholder_id,
        repo_id="repo-test",
        branch="main",
        snapshot_id="snap123",
        message="test",
        committed_at=now_utc,
    )
40
41
def _make_note(pitch: int = 60, velocity: int = 80, channel: int = 0) -> NoteInfo:
    """Return a one-beat NoteInfo at tick 0, using 480 ticks per beat."""
    resolution = 480
    key = NoteKey(
        pitch=pitch,
        velocity=velocity,
        start_tick=0,
        duration_ticks=resolution,
        channel=channel,
    )
    return NoteInfo.from_note_key(key, ticks_per_beat=resolution)
53
54
def _make_ctx(
    notes: list[NoteInfo] | None = None,
    bar: int = 1,
    track: str = "piano.mid",
    chord: str = "Cmaj",
    commit: CommitRecord | None = None,
) -> QueryContext:
    """Build a QueryContext (480 ticks per beat) for evaluator tests.

    Args:
        notes: Notes in the bar.  ``None`` substitutes a single default note
            (pitch 60, velocity 80, channel 0).  An explicit empty list is
            passed through unchanged, so tests can model a bar with no notes.
        bar: Bar number exposed to the ``bar`` field.
        track: Track path exposed to the ``track`` field.
        chord: Chord symbol exposed to the ``harmony.*`` fields.
        commit: Commit exposed to provenance fields.  ``None`` uses a fresh
            ``_make_commit()``.

    Returns:
        The populated QueryContext.
    """
    # Use explicit ``is None`` checks rather than ``or``: ``[] or default``
    # would silently replace an intentionally empty notes list.
    return QueryContext(
        commit=commit if commit is not None else _make_commit(),
        track=track,
        bar=bar,
        notes=notes if notes is not None else [_make_note()],
        chord=chord,
        ticks_per_beat=480,
    )
70
71
72 # ---------------------------------------------------------------------------
73 # Tokenizer / parser
74 # ---------------------------------------------------------------------------
75
76
class TestParser:
    """parse_query turns query strings into comparison / boolean AST nodes."""

    def test_simple_eq_parses(self) -> None:
        node = parse_query("bar == 4")
        assert isinstance(node, EqNode)
        assert node.field == "bar"
        assert node.op == "=="
        assert node.value == 4

    def test_and_produces_and_node(self) -> None:
        node = parse_query("bar == 1 and note.pitch > 60")
        assert isinstance(node, AndNode)

    def test_or_produces_or_node(self) -> None:
        node = parse_query("bar == 1 or bar == 2")
        assert isinstance(node, OrNode)

    def test_not_produces_not_node(self) -> None:
        node = parse_query("not bar == 4")
        assert isinstance(node, NotNode)

    def test_parentheses_group_correctly(self) -> None:
        node = parse_query("(bar == 1 or bar == 2) and note.pitch > 60")
        assert isinstance(node, AndNode)
        assert isinstance(node.left, OrNode)

    def test_string_value_parses(self) -> None:
        node = parse_query("note.pitch_class == 'C'")
        assert isinstance(node, EqNode)
        assert node.field == "note.pitch_class"
        assert node.value == "C"

    def test_double_quoted_string_parses(self) -> None:
        node = parse_query('track == "piano.mid"')
        assert isinstance(node, EqNode)
        assert node.field == "track"
        assert node.value == "piano.mid"

    def test_float_value_parses(self) -> None:
        # Pin field, operator and value. An isinstance check alone would
        # accept any float, e.g. a mis-tokenized 0.0 or 5.0.
        node = parse_query("note.duration > 0.5")
        assert isinstance(node, EqNode)
        assert node.field == "note.duration"
        assert node.op == ">"
        assert isinstance(node.value, float)
        assert node.value == 0.5

    def test_invalid_query_raises_value_error(self) -> None:
        with pytest.raises(ValueError):
            parse_query("bar !!! 4")

    def test_incomplete_query_raises_value_error(self) -> None:
        with pytest.raises(ValueError):
            parse_query("bar ==")
120
121 def test_incomplete_query_raises_value_error(self) -> None:
122 with pytest.raises(ValueError):
123 parse_query("bar ==")
124
125
126 # ---------------------------------------------------------------------------
127 # Evaluator — field resolution and comparison
128 # ---------------------------------------------------------------------------
129
130
class TestEvaluator:
    """evaluate_node resolves fields against a QueryContext and compares them.

    Each positive-match test also checks a non-matching query. Otherwise an
    evaluator that always returned True would still pass.
    """

    def test_bar_eq_match(self) -> None:
        assert evaluate_node(parse_query("bar == 4"), _make_ctx(bar=4))

    def test_bar_eq_no_match(self) -> None:
        assert not evaluate_node(parse_query("bar == 4"), _make_ctx(bar=3))

    def test_note_pitch_gt(self) -> None:
        ctx = _make_ctx(notes=[_make_note(pitch=65)])
        assert evaluate_node(parse_query("note.pitch > 60"), ctx)

    def test_note_pitch_gt_false(self) -> None:
        ctx = _make_ctx(notes=[_make_note(pitch=55)])
        assert not evaluate_node(parse_query("note.pitch > 60"), ctx)

    def test_note_velocity_lte(self) -> None:
        ctx = _make_ctx(notes=[_make_note(velocity=80)])
        # Inclusive at the boundary, false just below it.
        assert evaluate_node(parse_query("note.velocity <= 80"), ctx)
        assert not evaluate_node(parse_query("note.velocity <= 79"), ctx)

    def test_note_pitch_class_match(self) -> None:
        # Middle C = pitch 60 = C
        ctx = _make_ctx(notes=[_make_note(pitch=60)])
        assert evaluate_node(parse_query("note.pitch_class == 'C'"), ctx)
        assert not evaluate_node(parse_query("note.pitch_class == 'D'"), ctx)

    def test_track_match(self) -> None:
        ctx = _make_ctx(track="strings.mid")
        assert evaluate_node(parse_query("track == 'strings.mid'"), ctx)
        assert not evaluate_node(parse_query("track == 'piano.mid'"), ctx)

    def test_chord_match(self) -> None:
        ctx = _make_ctx(chord="Fmin")
        assert evaluate_node(parse_query("harmony.chord == 'Fmin'"), ctx)
        assert not evaluate_node(parse_query("harmony.chord == 'Cmaj'"), ctx)

    def test_author_match(self) -> None:
        commit = _make_commit(author="alice")
        ctx = _make_ctx(commit=commit)
        assert evaluate_node(parse_query("author == 'alice'"), ctx)
        assert not evaluate_node(parse_query("author == 'bob'"), ctx)

    def test_agent_id_match(self) -> None:
        commit = _make_commit(agent_id="counterpoint-bot")
        ctx = _make_ctx(commit=commit)
        assert evaluate_node(parse_query("agent_id == 'counterpoint-bot'"), ctx)
        assert not evaluate_node(parse_query("agent_id == 'other-bot'"), ctx)

    def test_and_both_must_match(self) -> None:
        ctx = _make_ctx(notes=[_make_note(pitch=65)], bar=4)
        assert evaluate_node(parse_query("note.pitch > 60 and bar == 4"), ctx)
        assert not evaluate_node(parse_query("note.pitch > 60 and bar == 5"), ctx)

    def test_or_one_must_match(self) -> None:
        ctx = _make_ctx(bar=2)
        assert evaluate_node(parse_query("bar == 1 or bar == 2"), ctx)
        assert not evaluate_node(parse_query("bar == 1 or bar == 3"), ctx)

    def test_not_negates(self) -> None:
        ctx = _make_ctx(bar=4)
        assert not evaluate_node(parse_query("not bar == 4"), ctx)
        assert evaluate_node(parse_query("not bar == 5"), ctx)

    def test_multiple_notes_any_match(self) -> None:
        # If any note in the bar matches, the predicate matches...
        ctx = _make_ctx(notes=[_make_note(pitch=55), _make_note(pitch=65)])
        assert evaluate_node(parse_query("note.pitch > 60"), ctx)
        # ...and if no note matches, it does not.
        assert not evaluate_node(parse_query("note.pitch > 70"), ctx)

    def test_unknown_field_returns_false(self) -> None:
        ctx = _make_ctx()
        assert not evaluate_node(parse_query("nonexistent == 'x'"), ctx)

    def test_harmony_quality_min(self) -> None:
        ctx = _make_ctx(chord="Amin")
        assert evaluate_node(parse_query("harmony.quality == 'min'"), ctx)
        assert not evaluate_node(parse_query("harmony.quality == 'maj'"), ctx)

    def test_harmony_quality_dim7(self) -> None:
        ctx = _make_ctx(chord="Bdim7")
        assert evaluate_node(parse_query("harmony.quality == 'dim7'"), ctx)
        assert not evaluate_node(parse_query("harmony.quality == 'min'"), ctx)
204
205
206 # ---------------------------------------------------------------------------
207 # Note channel field
208 # ---------------------------------------------------------------------------
209
210
class TestNoteChannel:
    """The ``note.channel`` field compares against each note's channel."""

    def test_channel_eq(self) -> None:
        query = parse_query("note.channel == 2")
        assert evaluate_node(query, _make_ctx(notes=[_make_note(channel=2)]))

    def test_channel_neq(self) -> None:
        query = parse_query("note.channel == 2")
        assert not evaluate_node(query, _make_ctx(notes=[_make_note(channel=3)]))