cgcardona / muse public
test_music_midi_merge.py python
465 lines 18.7 KB
0e0cbf44 feat: .museattributes + multidimensional MIDI merge (#11) Gabriel Cardona <cgcardona@gmail.com> 3d ago
1 """Tests for muse/plugins/music/midi_merge.py — dimension-aware MIDI merge."""
2 from __future__ import annotations
3
4 import hashlib
5 import io
6 import pathlib
7
8 import mido
9 import pytest
10
11 from muse.core.attributes import AttributeRule
12 from muse.plugins.music.midi_merge import (
13 INTERNAL_DIMS,
14 DimensionSlice,
15 MidiDimensions,
16 _classify_event,
17 _hash_events,
18 dimension_conflict_detail,
19 extract_dimensions,
20 merge_midi_dimensions,
21 )
22
23
24 # ---------------------------------------------------------------------------
25 # MIDI builder helpers
26 # ---------------------------------------------------------------------------
27
28
def _make_midi(
    *,
    notes: list[tuple[int, int, int]] | None = None,
    pitchwheel: list[tuple[int, int]] | None = None,
    control_change: list[tuple[int, int, int]] | None = None,
    tempo: int = 500_000,
    ticks_per_beat: int = 480,
) -> bytes:
    """Serialize a small single-track (type-0) MIDI file to bytes.

    Every event is first placed on an absolute-tick timeline.  The timeline is
    ordered by ``(tick, message type name)`` and then written to the track
    using delta times.

    Args:
        notes: ``(abs_tick, note, velocity)`` triples.  Each produces a
            ``note_on`` at ``abs_tick`` and a ``note_off`` (velocity 0) for
            the same note at ``abs_tick + 120``.
        pitchwheel: ``(abs_tick, pitch)`` pairs, one ``pitchwheel`` each.
        control_change: ``(abs_tick, control, value)`` triples, one
            ``control_change`` each.
        tempo: ``set_tempo`` value in microseconds per beat, emitted at tick 0
            (the default 500_000 corresponds to 120 BPM).
        ticks_per_beat: Resolution written into the file header.

    Returns:
        The encoded MIDI file as raw bytes.
    """
    timeline: list[tuple[int, mido.Message]] = [
        (0, mido.MetaMessage("set_tempo", tempo=tempo, time=0)),
    ]
    for start, pitch_number, strength in notes or []:
        timeline.append(
            (start, mido.Message("note_on", note=pitch_number, velocity=strength, time=0))
        )
        # Every note gets a fixed 120-tick duration.
        timeline.append(
            (start + 120, mido.Message("note_off", note=pitch_number, velocity=0, time=0))
        )
    timeline.extend(
        (start, mido.Message("pitchwheel", pitch=bend, time=0))
        for start, bend in pitchwheel or []
    )
    timeline.extend(
        (start, mido.Message("control_change", control=number, value=amount, time=0))
        for start, number, amount in control_change or []
    )

    # The type name breaks ties between events on the same tick.  The sort is
    # stable, so same-tick/same-type events keep their input order.
    ordered = sorted(timeline, key=lambda entry: (entry[0], entry[1].type))

    track = mido.MidiTrack()
    last_tick = 0
    for tick, message in ordered:
        track.append(message.copy(time=tick - last_tick))
        last_tick = tick
    track.append(mido.MetaMessage("end_of_track", time=0))

    midi_file = mido.MidiFile(type=0, ticks_per_beat=ticks_per_beat)
    midi_file.tracks.append(track)

    sink = io.BytesIO()
    midi_file.save(file=sink)
    return sink.getvalue()
77
78
def _midi_bytes_to_notes(midi_bytes: bytes) -> set[int]:
    """Collect the distinct note numbers sounded by ``note_on`` events.

    A ``note_on`` with velocity 0 is skipped, since it carries no audible onset.
    """
    parsed = mido.MidiFile(file=io.BytesIO(midi_bytes))
    return {
        message.note
        for track in parsed.tracks
        for message in track
        if message.type == "note_on" and message.velocity > 0
    }
88
89
def _midi_bytes_to_pitchwheels(midi_bytes: bytes) -> list[int]:
    """List every pitchwheel value, in track order then message order."""
    parsed = mido.MidiFile(file=io.BytesIO(midi_bytes))
    return [
        message.pitch
        for track in parsed.tracks
        for message in track
        if message.type == "pitchwheel"
    ]
99
100
def _midi_bytes_to_ccs(midi_bytes: bytes) -> list[tuple[int, int]]:
    """List ``(control, value)`` for each control_change, in file order."""
    parsed = mido.MidiFile(file=io.BytesIO(midi_bytes))
    return [
        (message.control, message.value)
        for track in parsed.tracks
        for message in track
        if message.type == "control_change"
    ]
110
111
112 # ---------------------------------------------------------------------------
113 # _classify_event
114 # ---------------------------------------------------------------------------
115
116
class TestClassifyEvent:
    """``_classify_event`` maps each MIDI message type onto one internal bucket."""

    def test_note_on(self) -> None:
        message = mido.Message("note_on", note=60)
        assert _classify_event(message) == "notes"

    def test_note_off(self) -> None:
        message = mido.Message("note_off", note=60)
        assert _classify_event(message) == "notes"

    def test_pitchwheel(self) -> None:
        bucket = _classify_event(mido.Message("pitchwheel", pitch=100))
        assert bucket == "harmonic"

    def test_control_change(self) -> None:
        volume = mido.Message("control_change", control=7, value=100)
        assert _classify_event(volume) == "dynamic"

    def test_set_tempo(self) -> None:
        tempo_change = mido.MetaMessage("set_tempo", tempo=500_000)
        assert _classify_event(tempo_change) == "structural"

    def test_time_signature(self) -> None:
        four_four = mido.MetaMessage(
            "time_signature",
            numerator=4,
            denominator=4,
            clocks_per_click=24,
            notated_32nd_notes_per_beat=8,
        )
        assert _classify_event(four_four) == "structural"

    def test_end_of_track(self) -> None:
        terminator = mido.MetaMessage("end_of_track")
        assert _classify_event(terminator) == "structural"
140
141
142 # ---------------------------------------------------------------------------
143 # extract_dimensions
144 # ---------------------------------------------------------------------------
145
146
class TestExtractDimensions:
    """``extract_dimensions`` splits a MIDI file into per-dimension slices."""

    def test_empty_midi_has_all_dims(self) -> None:
        buckets = extract_dimensions(_make_midi()).slices
        assert set(buckets.keys()) == set(INTERNAL_DIMS)

    def test_notes_in_notes_bucket(self) -> None:
        dims = extract_dimensions(_make_midi(notes=[(0, 60, 80), (480, 64, 80)]))
        note_on_count = sum(
            1 for _, message in dims.slices["notes"].events if message.type == "note_on"
        )
        assert note_on_count == 2

    def test_pitchwheel_in_harmonic(self) -> None:
        source = _make_midi(pitchwheel=[(100, 500), (200, -500)])
        harmonic = extract_dimensions(source).slices["harmonic"]
        assert len(harmonic.events) == 2

    def test_cc_in_dynamic(self) -> None:
        source = _make_midi(control_change=[(0, 7, 100)])
        dynamic = extract_dimensions(source).slices["dynamic"]
        assert len(dynamic.events) == 1

    def test_tempo_in_structural(self) -> None:
        structural = extract_dimensions(_make_midi(tempo=600_000)).slices["structural"]
        assert any(message.type == "set_tempo" for _, message in structural.events)

    def test_content_hash_is_deterministic(self) -> None:
        source = _make_midi(notes=[(0, 60, 80)])
        hashes = [extract_dimensions(source).slices["notes"].content_hash for _ in range(2)]
        assert hashes[0] == hashes[1]

    def test_different_notes_give_different_hash(self) -> None:
        hash_low, hash_high = (
            extract_dimensions(_make_midi(notes=[(0, pitch, 80)])).slices["notes"].content_hash
            for pitch in (60, 62)
        )
        assert hash_low != hash_high

    def test_same_notes_same_pitchwheel_same_hash(self) -> None:
        first, second = (
            extract_dimensions(_make_midi(notes=[(0, 60, 80)], pitchwheel=[(50, 200)]))
            for _ in range(2)
        )
        for dim in ("notes", "harmonic"):
            assert first.slices[dim].content_hash == second.slices[dim].content_hash

    def test_ticks_per_beat_preserved(self) -> None:
        dims = extract_dimensions(_make_midi(ticks_per_beat=960))
        assert dims.ticks_per_beat == 960

    def test_invalid_bytes_raises(self) -> None:
        garbage = b"not a midi file"
        with pytest.raises(ValueError, match="Failed to parse"):
            extract_dimensions(garbage)

    def test_get_via_user_alias(self) -> None:
        dims = extract_dimensions(_make_midi(notes=[(0, 60, 80)]))
        # User-facing names "melodic" and "rhythmic" both resolve to the
        # internal "notes" bucket; "harmonic" resolves to itself.
        expected_bucket = {"melodic": "notes", "rhythmic": "notes", "harmonic": "harmonic"}
        for alias, bucket in expected_bucket.items():
            assert dims.get(alias).name == bucket
213
214
215 # ---------------------------------------------------------------------------
216 # dimension_conflict_detail
217 # ---------------------------------------------------------------------------
218
219
class TestDimensionConflictDetail:
    """``dimension_conflict_detail`` reports which side(s) changed each dimension."""

    def _dims_from(
        self,
        notes: list[tuple[int, int, int]] | None = None,
        pitchwheel: list[tuple[int, int]] | None = None,
        control_change: list[tuple[int, int, int]] | None = None,
        tempo: int = 500_000,
    ) -> MidiDimensions:
        """Build a MIDI file from the given events and extract its dimensions."""
        midi_bytes = _make_midi(
            notes=notes,
            pitchwheel=pitchwheel,
            control_change=control_change,
            tempo=tempo,
        )
        return extract_dimensions(midi_bytes)

    def test_unchanged_when_all_same(self) -> None:
        shared = self._dims_from(notes=[(0, 60, 80)])
        detail = dimension_conflict_detail(shared, shared, shared)
        changed = [status for status in detail.values() if status != "unchanged"]
        assert changed == []

    def test_left_only_change(self) -> None:
        detail = dimension_conflict_detail(
            self._dims_from(),
            self._dims_from(notes=[(0, 60, 80)]),
            self._dims_from(),
        )
        assert detail["notes"] == "left_only"
        assert detail["harmonic"] == "unchanged"

    def test_right_only_change(self) -> None:
        detail = dimension_conflict_detail(
            self._dims_from(),
            self._dims_from(),
            self._dims_from(pitchwheel=[(0, 100)]),
        )
        assert detail["harmonic"] == "right_only"

    def test_both_sides_change(self) -> None:
        detail = dimension_conflict_detail(
            self._dims_from(),
            self._dims_from(notes=[(0, 60, 80)]),
            self._dims_from(notes=[(0, 64, 80)]),
        )
        assert detail["notes"] == "both"

    def test_independent_dimension_changes(self) -> None:
        detail = dimension_conflict_detail(
            self._dims_from(),
            self._dims_from(notes=[(0, 60, 80)]),  # left edits notes
            self._dims_from(pitchwheel=[(0, 200)]),  # right edits harmonic
        )
        expected = {"notes": "left_only", "harmonic": "right_only", "dynamic": "unchanged"}
        for dim, status in expected.items():
            assert detail[dim] == status
267
268
269 # ---------------------------------------------------------------------------
270 # merge_midi_dimensions
271 # ---------------------------------------------------------------------------
272
273
class TestMergeMidiDimensions:
    """End-to-end tests for ``merge_midi_dimensions``.

    Each test builds base/left/right MIDI files that differ in selected
    dimensions, merges them, and then inspects the merged bytes and/or the
    per-dimension report.
    """

    def _midi(
        self,
        notes: list[tuple[int, int, int]] | None = None,
        pitchwheel: list[tuple[int, int]] | None = None,
        control_change: list[tuple[int, int, int]] | None = None,
        tempo: int = 500_000,
        ticks_per_beat: int = 480,
    ) -> bytes:
        """Forward to ``_make_midi`` with the same keyword arguments."""
        return _make_midi(
            notes=notes, pitchwheel=pitchwheel, control_change=control_change,
            tempo=tempo, ticks_per_beat=ticks_per_beat,
        )

    def _rules(self, *rules: tuple[str, str, str]) -> list[AttributeRule]:
        """Turn ``(path_pattern, dimension, strategy)`` triples into rules.

        The fourth ``AttributeRule`` argument is the 1-based position of each
        triple (presumably a source line number in .museattributes — confirm
        against ``AttributeRule``).
        """
        return [AttributeRule(p, d, s, i + 1) for i, (p, d, s) in enumerate(rules)]

    # --- Clean auto-merge: independent dimension changes ------------------

    def test_independent_dims_auto_merge(self) -> None:
        """Left changed notes, right changed harmonic → clean merge."""
        base = self._midi()
        left = self._midi(notes=[(0, 60, 80)])
        right = self._midi(pitchwheel=[(0, 500)])
        result = merge_midi_dimensions(base, left, right, [], "song.mid")
        assert result is not None
        merged_bytes, _ = result
        assert _midi_bytes_to_notes(merged_bytes) == {60}
        assert _midi_bytes_to_pitchwheels(merged_bytes) == [500]

    def test_one_side_changed_notes(self) -> None:
        """Only left changed notes → take left automatically."""
        base = self._midi()
        left = self._midi(notes=[(0, 64, 80)])
        right = self._midi()
        result = merge_midi_dimensions(base, left, right, [], "song.mid")
        assert result is not None
        merged_bytes, _ = result
        assert _midi_bytes_to_notes(merged_bytes) == {64}

    def test_unchanged_notes_kept(self) -> None:
        """No changes on either side → preserve base."""
        base = self._midi(notes=[(0, 60, 80)])
        result = merge_midi_dimensions(base, base, base, [], "song.mid")
        assert result is not None
        merged_bytes, _ = result
        assert _midi_bytes_to_notes(merged_bytes) == {60}

    # --- File-level and dimension-level strategy override -----------------

    def test_ours_rule_on_notes_conflict(self) -> None:
        """Both sides changed notes, 'ours' rule → take left notes."""
        base = self._midi()
        left = self._midi(notes=[(0, 60, 80)])
        right = self._midi(notes=[(0, 64, 80)])
        rules = self._rules(("*", "melodic", "ours"))
        result = merge_midi_dimensions(base, left, right, rules, "song.mid")
        assert result is not None
        merged_bytes, report = result
        assert _midi_bytes_to_notes(merged_bytes) == {60}
        # Check the recorded winner itself.  Merely finding the text "notes"
        # in str(report) would pass whatever strategy was reported.
        assert "ours" in report["notes"]

    def test_theirs_rule_on_notes_conflict(self) -> None:
        """Both sides changed notes, 'theirs' rule → take right notes."""
        base = self._midi()
        left = self._midi(notes=[(0, 60, 80)])
        right = self._midi(notes=[(0, 64, 80)])
        rules = self._rules(("*", "rhythmic", "theirs"))  # rhythmic maps to notes
        result = merge_midi_dimensions(base, left, right, rules, "song.mid")
        assert result is not None
        merged_bytes, _ = result
        assert _midi_bytes_to_notes(merged_bytes) == {64}

    def test_theirs_harmonic_ours_notes(self) -> None:
        """Notes conflict (both sides) resolved by 'ours'; harmonic taken from right.

        Harmonic changed only on the right, and a 'theirs' harmonic rule is
        also present.
        """
        base = self._midi()
        left = self._midi(notes=[(0, 60, 80)])
        right = self._midi(notes=[(0, 64, 80)], pitchwheel=[(100, 300)])
        rules = self._rules(("*", "harmonic", "theirs"), ("*", "melodic", "ours"))
        result = merge_midi_dimensions(base, left, right, rules, "song.mid")
        assert result is not None
        merged_bytes, _ = result
        assert _midi_bytes_to_notes(merged_bytes) == {60}  # ours melodic
        assert _midi_bytes_to_pitchwheels(merged_bytes) == [300]  # theirs harmonic

    def test_wildcard_file_strategy_resolves_all_dims(self) -> None:
        """'* * ours' resolves every dimension to ours."""
        base = self._midi()
        left = self._midi(notes=[(0, 60, 80)], pitchwheel=[(0, 200)])
        right = self._midi(notes=[(0, 64, 80)], pitchwheel=[(0, -200)])
        rules = self._rules(("*", "*", "ours"))
        result = merge_midi_dimensions(base, left, right, rules, "song.mid")
        assert result is not None
        merged_bytes, _ = result
        assert _midi_bytes_to_notes(merged_bytes) == {60}
        # Exact match: a union of both sides' harmonic events must fail here.
        assert _midi_bytes_to_pitchwheels(merged_bytes) == [200]

    def test_no_resolvable_strategy_returns_none(self) -> None:
        """Both sides changed notes, no matching rule → None (file-level conflict)."""
        base = self._midi()
        left = self._midi(notes=[(0, 60, 80)])
        right = self._midi(notes=[(0, 64, 80)])
        result = merge_midi_dimensions(base, left, right, [], "song.mid")
        assert result is None

    def test_manual_strategy_returns_none(self) -> None:
        """manual strategy → cannot auto-resolve → None."""
        base = self._midi()
        left = self._midi(notes=[(0, 60, 80)])
        right = self._midi(notes=[(0, 64, 80)])
        rules = self._rules(("*", "melodic", "manual"))
        result = merge_midi_dimensions(base, left, right, rules, "song.mid")
        assert result is None

    # --- Report content ---------------------------------------------------

    def test_report_shows_winner(self) -> None:
        base = self._midi()
        left = self._midi(notes=[(0, 60, 80)])
        right = self._midi(pitchwheel=[(0, 100)])
        result = merge_midi_dimensions(base, left, right, [], "song.mid")
        assert result is not None
        _, report = result
        assert report["notes"] == "left"
        assert report["harmonic"] == "right"

    def test_report_shows_ours_theirs_labels(self) -> None:
        base = self._midi()
        left = self._midi(notes=[(0, 60, 80)])
        right = self._midi(notes=[(0, 64, 80)])
        rules = self._rules(("*", "melodic", "ours"))
        result = merge_midi_dimensions(base, left, right, rules, "song.mid")
        assert result is not None
        _, report = result
        assert "ours" in report["notes"]

    # --- Output is valid MIDI ---------------------------------------------

    def test_merged_bytes_parseable(self) -> None:
        base = self._midi()
        left = self._midi(notes=[(0, 60, 80)])
        right = self._midi(pitchwheel=[(0, 100)])
        result = merge_midi_dimensions(base, left, right, [], "song.mid")
        assert result is not None
        merged_bytes, _ = result
        # Should be parseable without raising
        parsed = mido.MidiFile(file=io.BytesIO(merged_bytes))
        assert parsed.ticks_per_beat == 480

    def test_merged_bytes_preserve_ticks_per_beat(self) -> None:
        base = self._midi(ticks_per_beat=960)
        left = self._midi(notes=[(0, 60, 80)], ticks_per_beat=960)
        right = self._midi(pitchwheel=[(0, 100)], ticks_per_beat=960)
        result = merge_midi_dimensions(base, left, right, [], "song.mid")
        assert result is not None
        merged_bytes, _ = result
        parsed = mido.MidiFile(file=io.BytesIO(merged_bytes))
        assert parsed.ticks_per_beat == 960

    # --- Path-pattern matching in rules -----------------------------------

    def test_path_specific_rule_respected(self) -> None:
        """Rule 'keys/* harmonic theirs' only applies to keys/ paths."""
        base = self._midi()
        left = self._midi(pitchwheel=[(0, 200)])
        right = self._midi(pitchwheel=[(0, -200)])
        rules = self._rules(("keys/*", "harmonic", "theirs"))

        # keys/piano.mid → rule applies
        result_keys = merge_midi_dimensions(base, left, right, rules, "keys/piano.mid")
        assert result_keys is not None
        merged_keys, _ = result_keys
        assert _midi_bytes_to_pitchwheels(merged_keys) == [-200]  # theirs

        # drums/kick.mid → rule does not apply → unresolved
        result_drums = merge_midi_dimensions(base, left, right, rules, "drums/kick.mid")
        assert result_drums is None

    # --- CC events --------------------------------------------------------

    def test_dynamic_dimension_merge(self) -> None:
        base = self._midi()
        left = self._midi(control_change=[(0, 7, 100)])  # volume up
        right = self._midi(control_change=[(0, 10, 64)])  # pan center
        # Both changed dynamic — need a rule
        rules = self._rules(("*", "dynamic", "ours"))
        result = merge_midi_dimensions(base, left, right, rules, "song.mid")
        assert result is not None
        merged_bytes, _ = result
        ccs = _midi_bytes_to_ccs(merged_bytes)
        assert (7, 100) in ccs  # ours dynamic
        assert (10, 64) not in ccs  # theirs dynamic excluded