cgcardona / muse public
test_import.py python
470 lines 15.1 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """Tests for ``muse import`` — MIDI and MusicXML import pipeline.
2
3 All async tests use ``@pytest.mark.anyio``.
4 The ``muse_cli_db_session`` fixture (in tests/muse_cli/conftest.py) provides
5 an isolated in-memory SQLite session; no real Postgres instance is required.
6
7 Test MIDI fixtures are synthesised from raw bytes (via ``struct``) at test time, so
8 no binary files need to be committed and ``mido`` is not needed to build them.
9 """
10 from __future__ import annotations
11
12 import json
13 import pathlib
14 import struct
15 import uuid
16
17 import pytest
18 from sqlalchemy.ext.asyncio import AsyncSession
19 from sqlalchemy.future import select
20
21 from maestro.muse_cli.commands.import_cmd import _import_async
22 from maestro.muse_cli.errors import ExitCode
23 from maestro.muse_cli.midi_parser import (
24 MuseImportData,
25 NoteEvent,
26 analyze_import,
27 apply_track_map,
28 parse_file,
29 parse_midi_file,
30 parse_musicxml_file,
31 parse_track_map_arg,
32 )
33 from maestro.muse_cli.models import MuseCliCommit
34
35
36 # ---------------------------------------------------------------------------
37 # Helpers
38 # ---------------------------------------------------------------------------
39
40
def _init_muse_repo(root: pathlib.Path, repo_id: str | None = None) -> str:
    """Lay out a bare-bones ``.muse/`` directory under *root* and return its repo id.

    The layout is meant to be compatible with ``_commit_async``. The
    following files are created:

    * ``.muse/repo.json`` -- JSON object with ``repo_id`` and ``schema_version``
    * ``.muse/HEAD`` -- the text ``refs/heads/main``
    * ``.muse/refs/heads/main`` -- an empty file

    Args:
        root: Directory that will contain ``.muse/``.
        repo_id: Identifier to record. A falsy value (``None`` or ``""``)
            is replaced by a freshly generated UUID4 string.

    Returns:
        The repo id that was written to ``repo.json``.

    Raises:
        FileExistsError: If ``.muse/refs/heads`` already exists under *root*.
    """
    resolved_id = repo_id if repo_id else str(uuid.uuid4())

    muse_dir = root / ".muse"
    heads_dir = muse_dir / "refs" / "heads"
    heads_dir.mkdir(parents=True)

    repo_meta = {"repo_id": resolved_id, "schema_version": "1"}
    (muse_dir / "repo.json").write_text(json.dumps(repo_meta))
    (muse_dir / "HEAD").write_text("refs/heads/main")
    # The branch ref is left empty.
    (heads_dir / "main").write_text("")

    return resolved_id
50
51
def _make_minimal_midi(path: pathlib.Path) -> None:
    """Write a tiny, well-formed Type-0 Standard MIDI File to *path*.

    The single track holds a tempo event (120 BPM), one C4 note-on on
    channel 0, the matching note-off 240 ticks later, and an end-of-track
    marker. The file is assembled from raw bytes so that ``mido`` is not
    needed just to create the fixture.

    Args:
        path: Destination file; it is created or overwritten.
    """
    # Header chunk: "MThd", chunk length 6, format 0, one track, 480 ticks per beat.
    header_chunk = b"MThd" + struct.pack(">IHHH", 6, 0, 1, 480)

    # Each entry is <delta-time><event bytes>.
    events = [
        b"\x00\xFF\x51\x03\x07\xA1\x20",  # delta 0: set_tempo, 0x07A120 = 500000 µs per beat (120 BPM)
        b"\x00\x90\x3C\x64",  # delta 0: note_on, channel 0, pitch 60 (C4), velocity 100
        b"\x81\x70\x80\x3C\x00",  # delta 240 (variable-length 0x81 0x70): note_off, channel 0, pitch 60
        b"\x00\xFF\x2F\x00",  # delta 0: end_of_track
    ]
    track_body = b"".join(events)

    # Track chunk: "MTrk" followed by the big-endian byte length of the event data.
    track_chunk = b"".join((b"MTrk", struct.pack(">I", len(track_body)), track_body))

    path.write_bytes(header_chunk + track_chunk)
74
75
def _make_minimal_musicxml(path: pathlib.Path) -> None:
    """Write a minimal valid MusicXML file with one part and two notes.

    The score has a single part named ``Piano`` containing one 4/4 measure
    at tempo 120 with two quarter notes (C4, then E4).

    Args:
        path: Destination file; it is created or overwritten.
    """
    document = """<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE score-partwise PUBLIC
"-//Recordare//DTD MusicXML 3.1 Partwise//EN"
"http://www.musicxml.org/dtds/partwise.dtd">
<score-partwise version="3.1">
<part-list>
<score-part id="P1">
<part-name>Piano</part-name>
</score-part>
</part-list>
<part id="P1">
<measure number="1">
<attributes>
<divisions>1</divisions>
<key><fifths>0</fifths></key>
<time><beats>4</beats><beat-type>4</beat-type></time>
</attributes>
<direction placement="above">
<direction-type><metronome><beat-unit>quarter</beat-unit><per-minute>120</per-minute></metronome></direction-type>
<sound tempo="120"/>
</direction>
<note>
<pitch><step>C</step><octave>4</octave></pitch>
<duration>1</duration>
<type>quarter</type>
</note>
<note>
<pitch><step>E</step><octave>4</octave></pitch>
<duration>1</duration>
<type>quarter</type>
</note>
</measure>
</part>
</score-partwise>
"""
    # The XML declaration promises UTF-8, so write UTF-8 explicitly instead of
    # relying on the platform's locale encoding.
    path.write_text(document, encoding="utf-8")
114
115
116 # ---------------------------------------------------------------------------
117 # midi_parser unit tests
118 # ---------------------------------------------------------------------------
119
120
def test_parse_midi_file_returns_note_data(tmp_path: pathlib.Path) -> None:
    """A valid MIDI file parses into note data carrying the fixture's timing info.

    Checks the reported format, that at least one note was extracted, the
    480 ticks-per-beat resolution, and a tempo within 1 BPM of 120.
    """
    midi_path = tmp_path / "song.mid"
    _make_minimal_midi(midi_path)

    parsed = parse_midi_file(midi_path)

    assert parsed.format == "midi"
    assert len(parsed.notes) >= 1
    assert parsed.ticks_per_beat == 480
    assert parsed.tempo_bpm == pytest.approx(120.0, abs=1.0)
130
131
def test_parse_musicxml_creates_commit(tmp_path: pathlib.Path) -> None:
    """parse_musicxml_file yields notes, tempo and the part name for a valid score."""
    # NOTE(review): despite the test name, no commit is created here; only the
    # parser is exercised.
    score_path = tmp_path / "song.musicxml"
    _make_minimal_musicxml(score_path)

    parsed = parse_musicxml_file(score_path)

    assert parsed.format == "musicxml"
    assert len(parsed.notes) >= 1
    assert parsed.tempo_bpm == pytest.approx(120.0, abs=1.0)
    assert "Piano" in parsed.tracks
141
142
def test_parse_file_dispatches_by_extension(tmp_path: pathlib.Path) -> None:
    """``parse_file`` picks the parser from the file extension.

    A ``.mid`` fixture must come back as ``"midi"`` and a ``.musicxml``
    fixture as ``"musicxml"``.
    """
    # (file name, fixture writer, expected ``format`` value)
    cases = (
        ("x.mid", _make_minimal_midi, "midi"),
        ("x.musicxml", _make_minimal_musicxml, "musicxml"),
    )
    for filename, write_fixture, expected_format in cases:
        fixture_path = tmp_path / filename
        write_fixture(fixture_path)
        assert parse_file(fixture_path).format == expected_format
154
155
def test_import_unsupported_extension_raises_error(tmp_path: pathlib.Path) -> None:
    """A file with an unsupported extension makes ``parse_file`` raise ValueError."""
    mp3_path = tmp_path / "song.mp3"
    mp3_path.write_bytes(b"not midi")

    with pytest.raises(ValueError, match="Unsupported file extension"):
        parse_file(mp3_path)
162
163
def test_import_malformed_midi_raises_clear_error(tmp_path: pathlib.Path) -> None:
    """Regression test: garbage in a ``.mid`` file surfaces as a RuntimeError.

    The error message must contain ``Cannot parse MIDI file``.
    """
    garbage_path = tmp_path / "bad.mid"
    garbage_path.write_bytes(b"not a midi file at all")

    with pytest.raises(RuntimeError, match="Cannot parse MIDI file"):
        parse_midi_file(garbage_path)
170
171
def test_import_track_map_assigns_named_tracks(tmp_path: pathlib.Path) -> None:
    """apply_track_map renames channel_name fields per the provided mapping.

    The MIDI fixture's only note is sent on channel 0 (status byte 0x90),
    so the ``ch0`` entry of the map is the one expected to apply.
    """
    mid = tmp_path / "song.mid"
    _make_minimal_midi(mid)
    data = parse_midi_file(mid)

    remapped = apply_track_map(data.notes, {"ch0": "bass", "ch1": "piano"})
    ch0_notes = [n for n in remapped if n.channel == 0]

    # all() is True for an empty list, so first require that there is
    # something to check; otherwise the rename assertion could pass without
    # examining a single note.
    assert ch0_notes, "expected at least one channel-0 note after remapping"
    assert all(n.channel_name == "bass" for n in ch0_notes)
181
182
def test_apply_track_map_bare_channel_key(tmp_path: pathlib.Path) -> None:
    """A bare channel number (``"0"`` rather than ``"ch0"``) is accepted as a map key."""
    original = NoteEvent(
        pitch=60,
        velocity=80,
        start_tick=0,
        duration_ticks=100,
        channel=0,
        channel_name="ch0",
    )

    result = apply_track_map([original], {"0": "bass"})

    assert result[0].channel_name == "bass"
188
189
def test_apply_track_map_does_not_mutate_original() -> None:
    """The NoteEvent passed to apply_track_map keeps its original channel_name."""
    source_note = NoteEvent(
        pitch=60,
        velocity=80,
        start_tick=0,
        duration_ticks=100,
        channel=0,
        channel_name="ch0",
    )

    # The return value is deliberately ignored; only the input is inspected.
    apply_track_map([source_note], {"ch0": "bass"})

    assert source_note.channel_name == "ch0"
195
196
def test_parse_track_map_arg_valid() -> None:
    """A comma-separated list of KEY=VALUE pairs parses into the matching dict."""
    expected = {"ch0": "bass", "ch1": "piano", "ch9": "drums"}

    parsed = parse_track_map_arg("ch0=bass,ch1=piano,ch9=drums")

    assert parsed == expected
201
202
def test_parse_track_map_arg_invalid_raises() -> None:
    """An entry without ``=`` makes parse_track_map_arg raise ValueError.

    The error message must mention ``KEY=VALUE``.
    """
    malformed = "ch0=bass,nodivider"

    with pytest.raises(ValueError, match="KEY=VALUE"):
        parse_track_map_arg(malformed)
207
208
def test_analyze_import_returns_string(tmp_path: pathlib.Path) -> None:
    """analyze_import output mentions the harmonic, rhythmic and dynamic sections."""
    midi_path = tmp_path / "song.mid"
    _make_minimal_midi(midi_path)

    report = analyze_import(parse_midi_file(midi_path))

    for heading in ("Harmonic", "Rhythmic", "Dynamic"):
        assert heading in report
218
219
def test_analyze_import_empty_notes() -> None:
    """With no notes at all, analyze_import reports ``no notes found``."""
    # The source path is only a value on the data object; this test never
    # creates a file there.
    empty_import = MuseImportData(
        source_path=pathlib.Path("/tmp/empty.mid"),
        format="midi",
        ticks_per_beat=480,
        tempo_bpm=120.0,
        notes=[],
        tracks=[],
        raw_meta={},
    )

    assert "no notes found" in analyze_import(empty_import)
233
234
def test_musicxml_part_name_becomes_track(tmp_path: pathlib.Path) -> None:
    """MusicXML <part-name> elements are used as channel_name values.

    The fixture is saved with a plain ``.xml`` extension and declares a
    single part named ``Piano``.
    """
    xml = tmp_path / "song.xml"
    _make_minimal_musicxml(xml)
    data = parse_musicxml_file(xml)
    assert "Piano" in data.tracks

    # all() is True for an empty iterable, so make sure notes were parsed
    # before checking their names.
    assert len(data.notes) >= 1, "expected the fixture's notes to be parsed"
    # NOTE(review): the filter assumes the first part is reported on channel 0;
    # confirm against the parser, otherwise this check can still match nothing.
    assert all(n.channel_name == "Piano" for n in data.notes if n.channel == 0)
242
243
def test_parse_musicxml_malformed_raises(tmp_path: pathlib.Path) -> None:
    """Text that is not well-formed XML surfaces as a RuntimeError.

    The error message must contain ``Cannot parse MusicXML file``.
    """
    broken_path = tmp_path / "bad.xml"
    broken_path.write_text("not xml at all <unclosed")

    with pytest.raises(RuntimeError, match="Cannot parse MusicXML file"):
        parse_musicxml_file(broken_path)
250
251
252 # ---------------------------------------------------------------------------
253 # _import_async integration tests
254 # ---------------------------------------------------------------------------
255
256
@pytest.mark.anyio
async def test_import_midi_creates_commit_with_note_data(
    tmp_path: pathlib.Path, muse_cli_db_session: AsyncSession
) -> None:
    """Importing a MIDI file stores a commit row, copies the file and writes metadata."""
    _init_muse_repo(tmp_path)
    source_midi = tmp_path / "session.mid"
    _make_minimal_midi(source_midi)
    expected_message = "Import original session MIDI"

    commit_id = await _import_async(
        file_path=source_midi,
        root=tmp_path,
        session=muse_cli_db_session,
        message=expected_message,
    )
    assert commit_id is not None

    # The commit must be retrievable by the returned id and carry the message.
    query = select(MuseCliCommit).where(MuseCliCommit.commit_id == commit_id)
    commit_row = (await muse_cli_db_session.execute(query)).scalar_one_or_none()
    assert commit_row is not None
    assert commit_row.message == expected_message

    imports_dir = tmp_path / "muse-work" / "imports"

    # The source file is copied into muse-work/imports/.
    assert (imports_dir / "session.mid").exists()

    # A metadata JSON file is written next to the copy.
    sidecar = imports_dir / "session.mid.meta.json"
    assert sidecar.exists()
    metadata = json.loads(sidecar.read_text())
    assert metadata["format"] == "midi"
    assert metadata["note_count"] >= 1
291
292
@pytest.mark.anyio
async def test_import_default_message_is_import_filename(
    tmp_path: pathlib.Path, muse_cli_db_session: AsyncSession
) -> None:
    """Without an explicit message the commit message is ``Import <filename>``."""
    _init_muse_repo(tmp_path)
    source_midi = tmp_path / "groove.mid"
    _make_minimal_midi(source_midi)

    # No ``message`` argument is passed, so the default must be used.
    commit_id = await _import_async(
        file_path=source_midi,
        root=tmp_path,
        session=muse_cli_db_session,
    )
    assert commit_id is not None

    query = select(MuseCliCommit).where(MuseCliCommit.commit_id == commit_id)
    commit_row = (await muse_cli_db_session.execute(query)).scalar_one()
    assert commit_row.message == "Import groove.mid"
314
315
@pytest.mark.anyio
async def test_import_track_map_recorded_in_metadata(
    tmp_path: pathlib.Path, muse_cli_db_session: AsyncSession
) -> None:
    """The track map given to the import is stored unchanged in ``.meta.json``."""
    _init_muse_repo(tmp_path)
    source_midi = tmp_path / "band.mid"
    _make_minimal_midi(source_midi)

    await _import_async(
        file_path=source_midi,
        root=tmp_path,
        session=muse_cli_db_session,
        track_map={"ch0": "bass", "ch1": "piano", "ch9": "drums"},
    )

    sidecar = tmp_path / "muse-work" / "imports" / "band.mid.meta.json"
    metadata = json.loads(sidecar.read_text())
    assert metadata["track_map"] == {"ch0": "bass", "ch1": "piano", "ch9": "drums"}
336
337
@pytest.mark.anyio
async def test_import_dry_run_no_commit_created(
    tmp_path: pathlib.Path, muse_cli_db_session: AsyncSession
) -> None:
    """A dry run returns None, copies nothing and leaves the commit table empty."""
    _init_muse_repo(tmp_path)
    source_midi = tmp_path / "check.mid"
    _make_minimal_midi(source_midi)

    outcome = await _import_async(
        file_path=source_midi,
        root=tmp_path,
        session=muse_cli_db_session,
        dry_run=True,
    )
    assert outcome is None

    # The file must not have been copied into muse-work/imports/.
    copied_path = tmp_path / "muse-work" / "imports" / "check.mid"
    assert not copied_path.exists()

    # The database must not contain any commit rows.
    query_result = await muse_cli_db_session.execute(select(MuseCliCommit))
    assert query_result.scalars().all() == []
363
364
@pytest.mark.anyio
async def test_import_musicxml_creates_commit(
    tmp_path: pathlib.Path, muse_cli_db_session: AsyncSession
) -> None:
    """Importing a ``.musicxml`` file returns a commit id and writes musicxml metadata."""
    _init_muse_repo(tmp_path)
    score_path = tmp_path / "score.musicxml"
    _make_minimal_musicxml(score_path)

    commit_id = await _import_async(
        file_path=score_path,
        root=tmp_path,
        session=muse_cli_db_session,
        message="Import MusicXML score",
    )
    assert commit_id is not None

    sidecar = tmp_path / "muse-work" / "imports" / "score.musicxml.meta.json"
    metadata = json.loads(sidecar.read_text())
    assert metadata["format"] == "musicxml"
386
387
@pytest.mark.anyio
async def test_import_analyze_runs_context_analysis(
    tmp_path: pathlib.Path,
    muse_cli_db_session: AsyncSession,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """With ``analyze=True`` the import prints harmonic, rhythmic and dynamic analysis."""
    _init_muse_repo(tmp_path)
    source_midi = tmp_path / "song.mid"
    _make_minimal_midi(source_midi)

    await _import_async(
        file_path=source_midi,
        root=tmp_path,
        session=muse_cli_db_session,
        analyze=True,
    )

    # Only standard output is inspected.
    stdout_text = capsys.readouterr().out
    for heading in ("Harmonic", "Rhythmic", "Dynamic"):
        assert heading in stdout_text
410
411
@pytest.mark.anyio
async def test_import_missing_file_exits_user_error(
    tmp_path: pathlib.Path, muse_cli_db_session: AsyncSession
) -> None:
    """A path that does not exist makes the import exit with USER_ERROR."""
    import typer

    _init_muse_repo(tmp_path)
    # The path is built but the file is never created.
    absent_path = tmp_path / "ghost.mid"

    with pytest.raises(typer.Exit) as raised:
        await _import_async(
            file_path=absent_path,
            root=tmp_path,
            session=muse_cli_db_session,
        )

    # Checked after the ``with`` block, once the exception has been captured.
    assert raised.value.exit_code == ExitCode.USER_ERROR
429
430
@pytest.mark.anyio
async def test_import_unsupported_extension_exits_user_error(
    tmp_path: pathlib.Path, muse_cli_db_session: AsyncSession
) -> None:
    """A file with an unsupported extension makes the import exit with USER_ERROR."""
    import typer

    _init_muse_repo(tmp_path)
    mp3_path = tmp_path / "song.mp3"
    mp3_path.write_bytes(b"not midi")

    with pytest.raises(typer.Exit) as raised:
        await _import_async(
            file_path=mp3_path,
            root=tmp_path,
            session=muse_cli_db_session,
        )

    # Checked after the ``with`` block, once the exception has been captured.
    assert raised.value.exit_code == ExitCode.USER_ERROR
449
450
@pytest.mark.anyio
async def test_import_section_recorded_in_metadata(
    tmp_path: pathlib.Path, muse_cli_db_session: AsyncSession
) -> None:
    """The section name given to the import is stored in ``.meta.json``."""
    _init_muse_repo(tmp_path)
    source_midi = tmp_path / "intro.mid"
    _make_minimal_midi(source_midi)

    await _import_async(
        file_path=source_midi,
        root=tmp_path,
        session=muse_cli_db_session,
        section="verse",
    )

    sidecar = tmp_path / "muse-work" / "imports" / "intro.mid.meta.json"
    metadata = json.loads(sidecar.read_text())
    assert metadata["section"] == "verse"