cgcardona / muse public
_query.py python
304 lines 10.0 KB
59a915a4 refactor: repo root is the working tree — remove state/ subdirectory Gabriel Cardona <gabriel@tellurstori.com> 6h ago
1 """Shared music-domain query helpers for the Muse CLI.
2
3 Provides the low-level primitives that music-domain commands share:
4 note extraction from the object store, bar-level grouping, chord detection,
5 and commit-graph walking specific to MIDI tracks.
6
7 Nothing here belongs in the public ``MidiPlugin`` API. These are CLI-layer
8 helpers — thin adapters over ``midi_diff.extract_notes`` and the core store.
9 """
10
11 from __future__ import annotations
12
13 import logging
14 import pathlib
15 from typing import NamedTuple
16
17 from muse.core.object_store import read_object
18 from muse.core.store import CommitRecord, read_commit, get_commit_snapshot_manifest
19 from muse.plugins.midi.midi_diff import NoteKey, _pitch_name, extract_notes # noqa: PLC2701
20
21 logger = logging.getLogger(__name__)
22
23 # ---------------------------------------------------------------------------
24 # Pitch / music-theory constants
25 # ---------------------------------------------------------------------------
26
# Pitch-class names indexed by semitone offset from C (0 = "C" ... 11 = "B").
# Only sharp spellings are listed; there are no flat equivalents here.
_PITCH_CLASSES = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]

# Chord templates: frozenset of pitch-class offsets (root = 0).
# Each entry is ``(quality_suffix, offsets)``.  ``detect_chord`` iterates this
# list in order, so the ordering of entries takes part in its tie-breaking.
_CHORD_TEMPLATES: list[tuple[str, frozenset[int]]] = [
    # Triads.
    ("maj", frozenset({0, 4, 7})),
    ("min", frozenset({0, 3, 7})),
    ("dim", frozenset({0, 3, 6})),
    ("aug", frozenset({0, 4, 8})),
    ("sus2", frozenset({0, 2, 7})),
    ("sus4", frozenset({0, 5, 7})),
    # Seventh chords.
    ("dom7", frozenset({0, 4, 7, 10})),
    ("maj7", frozenset({0, 4, 7, 11})),
    ("min7", frozenset({0, 3, 7, 10})),
    ("dim7", frozenset({0, 3, 6, 9})),
    # Dyad.
    ("5", frozenset({0, 7})),  # power chord
]
43
44 # ---------------------------------------------------------------------------
45 # NoteInfo — enriched note for display
46 # ---------------------------------------------------------------------------
47
48
class NoteInfo(NamedTuple):
    """Display-oriented note: the ``NoteKey`` fields plus the file's resolution.

    All derived timing properties clamp ``ticks_per_beat`` to at least 1, so a
    zero or negative resolution never triggers a division by zero.  Bar
    arithmetic assumes 4/4 time (four beats per bar).
    """

    pitch: int
    velocity: int
    start_tick: int
    duration_ticks: int
    channel: int
    ticks_per_beat: int

    @property
    def pitch_name(self) -> str:
        """Name of ``pitch`` as rendered by ``midi_diff._pitch_name``."""
        return _pitch_name(self.pitch)

    @property
    def beat(self) -> float:
        """Onset position measured in beats from tick 0."""
        resolution = max(self.ticks_per_beat, 1)
        return self.start_tick / resolution

    @property
    def beat_duration(self) -> float:
        """Note length measured in beats."""
        resolution = max(self.ticks_per_beat, 1)
        return self.duration_ticks / resolution

    @property
    def bar(self) -> int:
        """1-indexed bar number (assumes 4/4 time)."""
        ticks_per_bar = 4 * max(self.ticks_per_beat, 1)
        return int(self.start_tick // ticks_per_bar) + 1

    @property
    def beat_in_bar(self) -> float:
        """Beat position within the bar (1-indexed, assumes 4/4 time)."""
        resolution = max(self.ticks_per_beat, 1)
        # Tick at which the note's own bar begins.
        bar_start = (self.bar - 1) * 4 * resolution
        offset_ticks = self.start_tick - bar_start
        return offset_ticks / resolution + 1

    @property
    def pitch_class(self) -> int:
        """Pitch folded into one octave: 0 (C) through 11 (B)."""
        return self.pitch % 12

    @property
    def pitch_class_name(self) -> str:
        """Octave-less note name looked up in ``_PITCH_CLASSES``."""
        return _PITCH_CLASSES[self.pitch_class]

    @classmethod
    def from_note_key(cls, note: NoteKey, ticks_per_beat: int) -> "NoteInfo":
        """Build a ``NoteInfo`` from a ``NoteKey`` mapping and a resolution.

        Args:
            note: Mapping providing ``pitch``, ``velocity``, ``start_tick``,
                ``duration_ticks`` and ``channel``.
            ticks_per_beat: Resolution to attach to the note.
        """
        return cls(
            note["pitch"],
            note["velocity"],
            note["start_tick"],
            note["duration_ticks"],
            note["channel"],
            ticks_per_beat,
        )
101
102
103 # ---------------------------------------------------------------------------
104 # Track loading from the object store
105 # ---------------------------------------------------------------------------
106
107
def load_track(
    root: pathlib.Path,
    commit_id: str,
    track_path: str,
) -> tuple[list[NoteInfo], int] | None:
    """Read the notes of *track_path* as stored in the snapshot of *commit_id*.

    Args:
        root: Repository root.
        commit_id: ID of the commit whose snapshot is consulted.
        track_path: Workspace-relative path to the ``.mid`` file.

    Returns:
        ``(notes, ticks_per_beat)`` on success.  ``None`` when the snapshot
        manifest has no entry for *track_path*, when the object store returns
        ``None`` for the object, or when ``extract_notes`` raises
        ``ValueError`` (logged at debug level).
    """
    snapshot = get_commit_snapshot_manifest(root, commit_id)
    # A missing or empty manifest simply means the track is not present.
    object_id = snapshot.get(track_path) if snapshot else None
    if object_id is None:
        return None

    blob = read_object(root, object_id)
    if blob is None:
        return None

    try:
        note_keys, resolution = extract_notes(blob)
    except ValueError as exc:
        logger.debug("Cannot parse MIDI %r from commit %s: %s", track_path, commit_id[:8], exc)
        return None

    infos = [NoteInfo.from_note_key(key, resolution) for key in note_keys]
    return infos, resolution
138
139
def load_track_from_workdir(
    root: pathlib.Path,
    track_path: str,
) -> tuple[list[NoteInfo], int] | None:
    """Load notes for *track_path* from the live working tree.

    The file is read from ``root / track_path``: the repository root itself
    is the working tree.

    Args:
        root: Repository root.
        track_path: Workspace-relative path to the ``.mid`` file.

    Returns:
        ``(notes, ticks_per_beat)`` on success, ``None`` when the file is
        missing, cannot be read (e.g. it is a directory or access is denied),
        or ``extract_notes`` rejects its contents with ``ValueError``.
    """
    work_path = root / track_path
    # Read directly instead of checking ``exists()`` first: this avoids a
    # check-then-use race and turns every read failure into the documented
    # ``None`` result rather than an uncaught ``OSError``.
    try:
        raw = work_path.read_bytes()
    except FileNotFoundError:
        # An absent track is an expected outcome; stay silent about it.
        return None
    except OSError as exc:
        logger.debug("Cannot read MIDI %r from workdir: %s", track_path, exc)
        return None

    try:
        keys, tpb = extract_notes(raw)
    except ValueError as exc:
        logger.debug("Cannot parse MIDI %r from workdir: %s", track_path, exc)
        return None

    notes = [NoteInfo.from_note_key(k, tpb) for k in keys]
    return notes, tpb
164
165
166 # ---------------------------------------------------------------------------
167 # Musical analysis helpers
168 # ---------------------------------------------------------------------------
169
170
def notes_by_bar(notes: list[NoteInfo]) -> dict[int, list[NoteInfo]]:
    """Bucket *notes* by their 1-indexed ``bar`` (which assumes 4/4 time).

    Notes are first ordered by ``(start_tick, pitch)``, so each bucket is in
    that order and the dictionary's keys appear in order of first onset.
    """
    ordered = sorted(notes, key=lambda item: (item.start_tick, item.pitch))
    grouped: dict[int, list[NoteInfo]] = {}
    for item in ordered:
        bar_number = item.bar
        if bar_number not in grouped:
            grouped[bar_number] = []
        grouped[bar_number].append(item)
    return grouped
177
178
def detect_chord(pitch_classes: frozenset[int]) -> str:
    """Name the chord that best fits a set of pitch classes.

    Every chromatic root is tried against every entry of
    ``_CHORD_TEMPLATES``.  A candidate's score is the number of template
    tones present in the input once it is transposed to that root.  A
    candidate replaces the current best when it scores strictly higher, or
    when it ties the best score *and* its template is matched completely;
    among equally good complete matches the one visited last therefore wins.

    Returns:
        ``"RootQuality"`` such as ``"Cmaj"`` or ``"Fmin7"``, or ``"??"`` when
        fewer than two distinct pitch classes are given.
    """
    if len(pitch_classes) < 2:
        return "??"

    winner = "??"
    top_hits = 0
    for tonic, tonic_name in enumerate(_PITCH_CLASSES):
        # Express the input relative to the candidate root.
        shifted = frozenset((pc - tonic) % 12 for pc in pitch_classes)
        for suffix, shape in _CHORD_TEMPLATES:
            hits = len(shifted & shape)
            if hits < top_hits:
                continue
            if hits == top_hits and hits != len(shape):
                # A tie only displaces the leader on a complete template match.
                continue
            top_hits = hits
            winner = f"{tonic_name}{suffix}"
    return winner
202
203
def _standardize_profile(profile: list[float]) -> list[float]:
    """Return *profile* shifted to zero mean and scaled to unit Euclidean norm."""
    mean = sum(profile) / len(profile)
    centered = [weight - mean for weight in profile]
    norm = sum(value * value for value in centered) ** 0.5
    return [value / norm for value in centered]


def key_signature_guess(notes: list[NoteInfo]) -> str:
    """Guess the key signature from pitch class frequencies.

    Uses the Krumhansl-Schmuckler key-finding approach: the pitch-class
    histogram of *notes* is compared against a major and a minor key profile
    rotated to each of the 12 possible tonics, and the best-scoring
    ``(tonic, mode)`` pair wins.

    The profiles are standardised (zero mean, unit norm) before scoring.
    With that, the dot product ranks candidates exactly as the Pearson
    correlation does, because the histogram's own mean and spread are the
    same for all 24 candidates.  Scoring against the raw profiles would
    favour minor keys, since the minor profile has the larger mean.

    Args:
        notes: Notes to analyse; only ``pitch_class`` of each note is read.

    Returns:
        A string like ``"G major"`` or ``"D minor"``, or ``"unknown"`` when
        *notes* is empty.  On equal scores the candidate visited first wins
        (tonics ascending from C, major before minor).
    """
    if not notes:
        return "unknown"

    # Build pitch class histogram.
    histogram = [0] * 12
    for note in notes:
        histogram[note.pitch_class] += 1

    # Krumhansl-Schmuckler major and minor profiles, index = semitones above
    # the tonic, standardised so the two modes are scored on a common scale.
    profiles = (
        (
            "major",
            _standardize_profile([
                6.35, 2.23, 3.48, 2.33, 4.38, 4.09,
                2.52, 5.19, 2.39, 3.66, 2.29, 2.88,
            ]),
        ),
        (
            "minor",
            _standardize_profile([
                6.33, 2.68, 3.52, 5.38, 2.60, 3.53,
                2.54, 4.75, 3.98, 2.69, 3.34, 3.17,
            ]),
        ),
    )

    best_key = ""
    best_score = float("-inf")
    for root in range(12):
        for mode, profile in profiles:
            # Rotate the profile to this root.  Raw counts suffice: dividing
            # by the note total would scale every candidate equally.
            score = sum(
                histogram[(root + i) % 12] * profile[i] for i in range(12)
            )
            if score > best_score:
                best_score = score
                best_key = f"{_PITCH_CLASSES[root]} {mode}"

    return best_key
246
247
248 # ---------------------------------------------------------------------------
249 # Commit-graph walking (music-domain specific)
250 # ---------------------------------------------------------------------------
251
252
def walk_commits_for_track(
    root: pathlib.Path,
    start_commit_id: str,
    track_path: str,
    max_commits: int = 10_000,
) -> list[tuple[CommitRecord, dict[str, str] | None]]:
    """Follow ``parent_commit_id`` links from *start_commit_id*, newest first.

    Each visited commit is paired with its snapshot manifest; an absent or
    empty manifest is reported as ``None``.  *track_path* is not used for
    filtering here: every commit on the chain is returned so callers can
    decide which ones matter for the track.

    The walk ends when the chain runs out, a commit ID repeats, a commit
    cannot be read, or *max_commits* entries have been collected.
    """
    history: list[tuple[CommitRecord, dict[str, str] | None]] = []
    visited: set[str] = set()
    cursor: str | None = start_commit_id

    while True:
        if not cursor or cursor in visited or len(history) >= max_commits:
            break
        visited.add(cursor)

        record = read_commit(root, cursor)
        if record is None:
            break

        snapshot = get_commit_snapshot_manifest(root, record.commit_id)
        history.append((record, snapshot if snapshot else None))
        cursor = record.parent_commit_id

    return history
278
279
280 # ---------------------------------------------------------------------------
281 # MIDI reconstruction helper (for transpose / mix)
282 # ---------------------------------------------------------------------------
283
284
def notes_to_midi_bytes(notes: list[NoteInfo], ticks_per_beat: int) -> bytes:
    """Serialise ``NoteInfo`` objects back into MIDI file bytes.

    Each note is converted to a ``NoteKey`` (dropping its own
    ``ticks_per_beat``) and the list is handed, in input order, to
    :func:`~muse.plugins.midi.midi_diff.reconstruct_midi` together with the
    *ticks_per_beat* argument.  The output format is whatever that function
    emits — described upstream as a Type-0 single-track file with one
    note_on / note_off pair per note.
    """
    from muse.plugins.midi.midi_diff import NoteKey, reconstruct_midi

    keys = []
    for info in notes:
        keys.append(
            NoteKey(
                pitch=info.pitch,
                velocity=info.velocity,
                start_tick=info.start_tick,
                duration_ticks=info.duration_ticks,
                channel=info.channel,
            )
        )
    return reconstruct_midi(keys, ticks_per_beat=ticks_per_beat)