midi_parser.py
python
| 1 | """MIDI and MusicXML parsing for ``muse import``. |
| 2 | |
| 3 | Converts standard music file formats into Muse's internal note representation: |
| 4 | a list of :class:`NoteEvent` objects and a :class:`MuseImportData` container. |
| 5 | |
| 6 | Supported formats |
| 7 | ----------------- |
| 8 | - ``.mid`` / ``.midi`` — Standard MIDI File via ``mido`` |
| 9 | - ``.xml`` / ``.musicxml`` — MusicXML via Python's built-in ``xml.etree.ElementTree`` |
| 10 | |
| 11 | Named result types registered in ``docs/reference/type_contracts.md``: |
| 12 | - ``MuseImportData`` |
| 13 | - ``NoteEvent`` |
| 14 | """ |
| 15 | from __future__ import annotations |
| 16 | |
import collections
import dataclasses
import logging
import pathlib
import xml.etree.ElementTree as ET
from typing import Any
| 22 | |
| 23 | logger = logging.getLogger(__name__) |
| 24 | |
| 25 | #: File extensions accepted by this module. |
| 26 | SUPPORTED_MIDI_EXTENSIONS = {".mid", ".midi"} |
| 27 | SUPPORTED_XML_EXTENSIONS = {".xml", ".musicxml"} |
| 28 | SUPPORTED_EXTENSIONS = SUPPORTED_MIDI_EXTENSIONS | SUPPORTED_XML_EXTENSIONS |
| 29 | |
| 30 | |
@dataclasses.dataclass
class NoteEvent:
    """A single sounding note extracted from an imported file."""

    pitch: int  # MIDI note number; clamped to 0-127 by the MusicXML parser
    velocity: int  # MIDI velocity; MusicXML notes get a fixed 80
    start_tick: int  # absolute onset in ticks from the start of the file
    duration_ticks: int  # note length in ticks; parsers clamp to a minimum of 1
    channel: int  # MIDI channel (MIDI import) or part index (MusicXML import)
    channel_name: str  # display name, e.g. "ch0" or a MusicXML part name
| 41 | |
| 42 | |
@dataclasses.dataclass
class MuseImportData:
    """All data extracted from a single imported music file."""

    source_path: pathlib.Path  # path of the file that was imported
    format: str  # "midi" or "musicxml"
    ticks_per_beat: int  # tick resolution (MIDI header value, or 480 for MusicXML)
    tempo_bpm: float  # tempo in BPM; parsers default to 120 when the file has none
    notes: list[NoteEvent]  # all extracted notes, in parse order
    tracks: list[str]  # unique channel/part names, in first-seen order
    raw_meta: dict[str, Any]  # format-specific extras (e.g. num_tracks / num_parts)
| 54 | |
| 55 | |
def parse_file(path: pathlib.Path) -> MuseImportData:
    """Dispatch to the correct parser based on file extension.

    Raises:
        ValueError: When the extension is not in :data:`SUPPORTED_EXTENSIONS`.
        FileNotFoundError: When the file does not exist.
        RuntimeError: When the file is malformed.
    """
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")

    # Route the lowered suffix to the matching format-specific parser.
    suffix = path.suffix.lower()
    for extensions, parser in (
        (SUPPORTED_MIDI_EXTENSIONS, parse_midi_file),
        (SUPPORTED_XML_EXTENSIONS, parse_musicxml_file),
    ):
        if suffix in extensions:
            return parser(path)

    supported = ", ".join(sorted(SUPPORTED_EXTENSIONS))
    raise ValueError(
        f"Unsupported file extension '{path.suffix}'. Supported: {supported}"
    )
| 75 | |
| 76 | |
def parse_midi_file(path: pathlib.Path) -> MuseImportData:
    """Parse a Standard MIDI File into a :class:`MuseImportData`.

    Uses ``mido``. Note-on with velocity=0 is treated as note-off.

    Args:
        path: Location of the ``.mid``/``.midi`` file to read.

    Returns:
        A populated :class:`MuseImportData` with ``format="midi"``.

    Raises:
        RuntimeError: When ``mido`` is unavailable or cannot read the file.
    """
    try:
        import mido
    except ImportError as exc:
        # Fix: chain the original ImportError (PEP 3134) so the traceback
        # shows *why* the dependency lookup failed instead of hiding it.
        raise RuntimeError(
            "mido is required for MIDI import. "
            "It is pre-installed in the Maestro Docker image."
        ) from exc

    try:
        mid = mido.MidiFile(str(path))
    except Exception as exc:
        raise RuntimeError(f"Cannot parse MIDI file '{path}': {exc}") from exc

    ticks_per_beat = int(mid.ticks_per_beat)
    tempo_us: int = 500_000  # MIDI default (120 BPM) until a set_tempo arrives
    notes: list[NoteEvent] = []
    # (channel, pitch) -> (start_tick, velocity) for notes currently sounding.
    # The dict is shared across all tracks.
    active: dict[tuple[int, int], tuple[int, int]] = {}

    for track in mid.tracks:
        abs_tick = 0  # msg.time is a delta; accumulate to absolute ticks per track
        for msg in track:
            abs_tick += msg.time
            if msg.type == "set_tempo":
                tempo_us = msg.tempo  # last tempo event seen wins
            elif msg.type == "note_on" and msg.velocity > 0:
                active[(msg.channel, msg.note)] = (abs_tick, msg.velocity)
            elif msg.type == "note_off" or (
                msg.type == "note_on" and msg.velocity == 0
            ):
                key = (msg.channel, msg.note)
                if key in active:
                    start, vel = active.pop(key)
                    notes.append(
                        NoteEvent(
                            pitch=msg.note,
                            velocity=vel,
                            start_tick=start,
                            duration_ticks=max(abs_tick - start, 1),
                            channel=msg.channel,
                            channel_name=f"ch{msg.channel}",
                        )
                    )

    # Notes never closed by a matching note-off — keep them with duration 1.
    for (ch, pitch), (start, vel) in active.items():
        notes.append(
            NoteEvent(
                pitch=pitch,
                velocity=vel,
                start_tick=start,
                duration_ticks=1,
                channel=ch,
                channel_name=f"ch{ch}",
            )
        )

    tempo_bpm = 60_000_000 / tempo_us  # microseconds-per-beat -> BPM
    tracks = _unique_ordered([n.channel_name for n in notes])

    logger.debug(
        "✅ Parsed MIDI %s: %d notes, %d tracks, %.1f BPM",
        path.name, len(notes), len(tracks), tempo_bpm,
    )
    return MuseImportData(
        source_path=path,
        format="midi",
        ticks_per_beat=ticks_per_beat,
        tempo_bpm=tempo_bpm,
        notes=notes,
        tracks=tracks,
        raw_meta={"num_tracks": len(mid.tracks)},
    )
| 158 | |
| 159 | |
def parse_musicxml_file(path: pathlib.Path) -> MuseImportData:
    """Parse a MusicXML ``<score-partwise>`` file into a :class:`MuseImportData`.

    Each ``<part>`` becomes one channel. Notes get a fixed velocity of 80
    (MusicXML dynamics are not interpreted here). ``<chord/>`` notes share
    the onset of the preceding non-chord note, per the MusicXML spec.

    Raises:
        RuntimeError: When the XML is invalid or not a recognisable MusicXML document.
    """
    try:
        tree = ET.parse(str(path))
    except ET.ParseError as exc:
        raise RuntimeError(f"Cannot parse MusicXML file '{path}': {exc}") from exc

    root = tree.getroot()

    # Strip XML namespace prefix, e.g. {http://www.musicxml.org/…}element → element
    ns = ""
    if root.tag.startswith("{"):
        ns = root.tag[: root.tag.index("}") + 1]

    def t(name: str) -> str:
        return f"{ns}{name}"

    if root.tag not in (t("score-partwise"), "score-partwise"):
        raise RuntimeError(
            f"Unrecognised MusicXML root element '{root.tag}'. "
            "Expected <score-partwise>."
        )

    # First valid <sound tempo="..."> wins; otherwise fall back to 120 BPM.
    tempo_bpm = 120.0
    for direction in root.iter(t("direction")):
        sound = direction.find(t("sound"))
        if sound is not None:
            raw = sound.get("tempo")
            if raw is not None:
                try:
                    tempo_bpm = float(raw)
                    break
                except ValueError:
                    pass

    ticks_per_beat = 480  # internal default for MusicXML

    part_names: list[str] = []
    for pn in root.iter(t("part-name")):
        name = (pn.text or "").strip()
        part_names.append(name or f"Part {len(part_names) + 1}")

    _STEP_SEMITONE: dict[str, int] = {
        "C": 0, "D": 2, "E": 4, "F": 5, "G": 7, "A": 9, "B": 11,
    }

    notes: list[NoteEvent] = []
    parts = root.findall(t("part"))

    for ch_idx, part_el in enumerate(parts):
        channel_name = part_names[ch_idx] if ch_idx < len(part_names) else f"ch{ch_idx}"
        abs_tick = 0
        divisions = 1  # MusicXML <divisions>: duration units per quarter note

    for measure_el in part_el.findall(t("measure")):
            attrs = measure_el.find(t("attributes"))
            if attrs is not None:
                div_el = attrs.find(t("divisions"))
                if div_el is not None and div_el.text:
                    try:
                        divisions = int(div_el.text)
                    except ValueError:
                        pass

            measure_tick = abs_tick
            # Onset of the most recent non-chord note. Bug fix: chord members
            # (<chord/>) must reuse this onset so they sound simultaneously;
            # previously they were placed *after* the note they belong to
            # because measure_tick had already been advanced.
            prev_note_start = measure_tick

            for note_el in measure_el.findall(t("note")):
                dur_el = note_el.find(t("duration"))
                dur_xml = int(dur_el.text) if dur_el is not None and dur_el.text else 0
                dur_ticks = int(dur_xml * ticks_per_beat / max(divisions, 1))

                if note_el.find(t("rest")) is not None:
                    measure_tick += dur_ticks
                    continue

                pitch_el = note_el.find(t("pitch"))
                if pitch_el is None:
                    measure_tick += dur_ticks
                    continue

                step_el = pitch_el.find(t("step"))
                oct_el = pitch_el.find(t("octave"))
                alt_el = pitch_el.find(t("alter"))

                step = (step_el.text or "C").strip() if step_el is not None else "C"
                octave = int(oct_el.text or "4") if oct_el is not None else 4
                alter = int(float(alt_el.text or "0")) if alt_el is not None else 0

                # MIDI pitch mapping: C-1 = 0, so (octave + 1) * 12 + semitone + alter.
                pitch = max(0, min(127, (octave + 1) * 12 + _STEP_SEMITONE.get(step, 0) + alter))
                if note_el.find(t("chord")) is not None:
                    # Chord member: same onset as the previous note; does not
                    # advance the measure cursor.
                    note_start = prev_note_start
                else:
                    note_start = measure_tick
                    prev_note_start = note_start
                    measure_tick += dur_ticks

                notes.append(
                    NoteEvent(
                        pitch=pitch,
                        velocity=80,  # fixed default; MusicXML dynamics not read
                        start_tick=note_start,
                        duration_ticks=max(dur_ticks, 1),
                        channel=ch_idx,
                        channel_name=channel_name,
                    )
                )

            abs_tick = measure_tick

    tracks = _unique_ordered([n.channel_name for n in notes])
    logger.debug(
        "✅ Parsed MusicXML %s: %d notes, %d parts, %.1f BPM",
        path.name, len(notes), len(parts), tempo_bpm,
    )
    return MuseImportData(
        source_path=path,
        format="musicxml",
        ticks_per_beat=ticks_per_beat,
        tempo_bpm=tempo_bpm,
        notes=notes,
        tracks=tracks,
        raw_meta={"num_parts": len(parts), "part_names": part_names},
    )
| 285 | |
| 286 | |
def apply_track_map(notes: list[NoteEvent], track_map: dict[str, str]) -> list[NoteEvent]:
    """Return notes with ``channel_name`` fields remapped per *track_map*.

    Keys may be ``"ch<N>"`` or bare channel number strings.
    Notes for unmapped channels are returned unchanged.
    """
    # Resolve string keys ("ch3" or "3") to integer channel numbers first.
    channel_to_name: dict[int, str] = {}
    for raw_key, new_name in track_map.items():
        stripped = raw_key.strip()
        digits = stripped[2:] if stripped.startswith("ch") else stripped
        try:
            channel_to_name[int(digits)] = new_name
        except ValueError:
            logger.warning("⚠️ Ignoring invalid track-map key %r", raw_key)

    # Replace only mapped notes; unmapped ones pass through untouched.
    return [
        dataclasses.replace(note, channel_name=channel_to_name[note.channel])
        if note.channel in channel_to_name
        else note
        for note in notes
    ]
| 309 | |
| 310 | |
def parse_track_map_arg(raw: str) -> dict[str, str]:
    """Parse ``"ch0=bass,ch1=piano"`` into ``{"ch0": "bass", "ch1": "piano"}``.

    Raises:
        ValueError: When any pair is not in ``KEY=VALUE`` format.
    """
    mapping: dict[str, str] = {}
    for chunk in raw.split(","):
        entry = chunk.strip()
        if not entry:
            # Tolerate empty segments (trailing commas, doubled commas).
            continue
        if "=" not in entry:
            raise ValueError(
                f"Invalid track-map entry {entry!r}. Expected KEY=VALUE (e.g. ch0=bass)."
            )
        key, _sep, value = entry.partition("=")
        mapping[key.strip()] = value.strip()
    return mapping
| 329 | |
| 330 | |
def analyze_import(data: MuseImportData) -> str:
    """Return a multi-line analysis of *data* covering harmonic, rhythmic, and dynamic dimensions.

    Args:
        data: Result of :func:`parse_file` (or either format-specific parser).

    Returns:
        An indented, human-readable report, or a single placeholder line
        when *data* contains no notes.
    """
    notes = data.notes
    if not notes:
        return "  (no notes found — file may be empty or contain only meta events)"

    _NOTE_NAMES = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]
    pitches = [n.pitch for n in notes]
    # Counter replaces the hand-rolled frequency dict; most_common() is a
    # stable sort by count, so tie ordering matches the original code.
    pitch_counts = collections.Counter(pitches)
    top_str = ", ".join(
        f"{_NOTE_NAMES[p % 12]}{p // 12 - 1}({count}x)"
        for p, count in pitch_counts.most_common(5)
    )
    pitch_min, pitch_max = min(pitches), max(pitches)

    total = len(notes)
    # Total span = latest note end; guard divisors against zero.
    span_ticks = max(n.start_tick + n.duration_ticks for n in notes)
    span_beats = span_ticks / max(data.ticks_per_beat, 1)
    density = total / max(span_beats, 0.001)

    velocities = [n.velocity for n in notes]
    avg_vel = sum(velocities) / len(velocities)
    vel_min, vel_max = min(velocities), max(velocities)

    def _band(v: float) -> str:
        # Rough dynamic-marking buckets keyed off MIDI velocity.
        if v < 40:
            return "pp (very soft)"
        if v < 65:
            return "p (soft)"
        if v < 85:
            return "mp/mf (medium)"
        if v < 105:
            return "f (loud)"
        return "ff (very loud)"

    track_summary = ", ".join(data.tracks) if data.tracks else "(none)"
    return "\n".join([
        f"  Format: {data.format}",
        f"  Tempo: {data.tempo_bpm:.1f} BPM",
        f"  Tracks: {track_summary}",
        "",
        "  ── Harmonic ──────────────────────────────────",
        f"  Pitch range: {_NOTE_NAMES[pitch_min % 12]}{pitch_min // 12 - 1}"
        f"–{_NOTE_NAMES[pitch_max % 12]}{pitch_max // 12 - 1}",
        f"  Top pitches: {top_str}",
        "",
        "  ── Rhythmic ──────────────────────────────────",
        f"  Notes: {total}",
        f"  Span: {span_beats:.1f} beats",
        f"  Density: {density:.1f} notes/beat",
        "",
        "  ── Dynamic ───────────────────────────────────",
        f"  Velocity: avg={avg_vel:.0f}, min={vel_min}, max={vel_max}",
        f"  Character: {_band(avg_vel)}",
    ])
| 384 | |
| 385 | |
| 386 | def _unique_ordered(items: list[str]) -> list[str]: |
| 387 | seen: set[str] = set() |
| 388 | result: list[str] = [] |
| 389 | for item in items: |
| 390 | if item not in seen: |
| 391 | seen.add(item) |
| 392 | result.append(item) |
| 393 | return result |