"""Tests for muse.plugins.midi.entity — NoteEntity, EntityIndex, assign/diff.""" import pathlib import pytest from muse.plugins.midi.entity import ( EntityIndex, EntityIndexEntry, NoteEntity, assign_entity_ids, build_entity_index, diff_with_entity_ids, read_entity_index, write_entity_index, ) from muse.plugins.midi.midi_diff import NoteKey, _note_content_id def _key(pitch: int = 60, velocity: int = 80, start_tick: int = 0, duration_ticks: int = 480, channel: int = 0) -> NoteKey: return NoteKey(pitch=pitch, velocity=velocity, start_tick=start_tick, duration_ticks=duration_ticks, channel=channel) # --------------------------------------------------------------------------- # assign_entity_ids — first commit (no prior index) # --------------------------------------------------------------------------- class TestAssignEntityIdsFirstCommit: def test_all_notes_get_new_entity_ids(self) -> None: notes = [_key(60), _key(62), _key(64)] entities = assign_entity_ids(notes, None, commit_id="c001", op_id="op1") assert len(entities) == 3 eids = {e["entity_id"] for e in entities} assert len(eids) == 3 # all distinct def test_entity_has_correct_pitch_fields(self) -> None: note = _key(pitch=67, velocity=100) entities = assign_entity_ids([note], None, commit_id="c001", op_id="op1") assert entities[0]["pitch"] == 67 assert entities[0]["velocity"] == 100 def test_origin_commit_id_set(self) -> None: note = _key(60) entities = assign_entity_ids([note], None, commit_id="commit-abc", op_id="op1") assert entities[0]["origin_commit_id"] == "commit-abc" # --------------------------------------------------------------------------- # assign_entity_ids — with prior index (exact match) # --------------------------------------------------------------------------- class TestAssignEntityIdsExactMatch: def _make_index(self, notes: list[NoteKey], commit_id: str) -> EntityIndex: entities = assign_entity_ids(notes, None, commit_id=commit_id, op_id="op0") return build_entity_index(entities, "track.mid", commit_id) def test_same_notes_get_same_entity_ids(self) -> None: notes = [_key(60), _key(62), _key(64)] prior = self._make_index(notes, "c001") prior_eids = set(prior["entities"].keys()) entities = assign_entity_ids(notes, prior, commit_id="c002", op_id="op2") new_eids = {e["entity_id"] for e in entities} assert new_eids == prior_eids def test_added_note_gets_new_entity_id(self) -> None: notes = [_key(60), _key(62)] prior = self._make_index(notes, "c001") prior_eids = set(prior["entities"].keys()) new_notes = [_key(60), _key(62), _key(64)] entities = assign_entity_ids(new_notes, prior, commit_id="c002", op_id="op2") new_eids = {e["entity_id"] for e in entities} # The two original notes should retain their IDs. assert prior_eids.issubset(new_eids) # The new note gets a fresh ID. assert len(new_eids - prior_eids) == 1 # --------------------------------------------------------------------------- # diff_with_entity_ids # --------------------------------------------------------------------------- class TestDiffWithEntityIds: def _entities_from(self, notes: list[NoteKey]) -> list[NoteEntity]: return assign_entity_ids(notes, None, commit_id="c001", op_id="op1") def test_no_change_produces_no_ops(self) -> None: notes = [_key(60), _key(62), _key(64)] base = self._entities_from(notes) target = assign_entity_ids( notes, build_entity_index(base, "track.mid", "c001"), commit_id="c002", op_id="op2", ) ops = diff_with_entity_ids(base, target, 480) assert ops == [] def test_added_note_produces_insert(self) -> None: base = self._entities_from([_key(60)]) target_notes = [_key(60), _key(64)] target = assign_entity_ids( target_notes, build_entity_index(base, "track.mid", "c001"), commit_id="c002", op_id="op2", ) ops = diff_with_entity_ids(base, target, 480) op_types = [o["op"] for o in ops] assert "insert" in op_types assert "delete" not in op_types def test_removed_note_produces_delete(self) -> None: base = self._entities_from([_key(60), _key(64)]) target_notes = [_key(60)] target = assign_entity_ids( target_notes, build_entity_index(base, "track.mid", "c001"), commit_id="c002", op_id="op2", ) ops = diff_with_entity_ids(base, target, 480) op_types = [o["op"] for o in ops] assert "delete" in op_types assert "insert" not in op_types def test_velocity_change_produces_mutate(self) -> None: base_note = _key(pitch=60, velocity=80) base = self._entities_from([base_note]) prior_index = build_entity_index(base, "track.mid", "c001") # Change velocity only. changed = _key(pitch=60, velocity=100) target = assign_entity_ids( [changed], prior_index, commit_id="c002", op_id="op2", mutation_threshold_ticks=20, mutation_threshold_velocity=30, ) ops = diff_with_entity_ids(base, target, 480) op_types = [o["op"] for o in ops] # May produce mutate or insert/delete depending on match heuristic. # Accept either — the key test is no crash and some op is emitted. assert len(ops) > 0 # --------------------------------------------------------------------------- # EntityIndex I/O # --------------------------------------------------------------------------- class TestEntityIndexIO: def test_write_and_read_roundtrip(self, tmp_path: pathlib.Path) -> None: notes = [_key(60), _key(62), _key(64)] entities = assign_entity_ids(notes, None, commit_id="c001", op_id="op1") index = build_entity_index(entities, "track.mid", "c001") write_entity_index(tmp_path, "c001", "track.mid", index) recovered = read_entity_index(tmp_path, "c001", "track.mid") assert recovered is not None assert set(recovered["entities"].keys()) == set(index["entities"].keys()) def test_read_missing_returns_none(self, tmp_path: pathlib.Path) -> None: result = read_entity_index(tmp_path, "nonexistent", "track.mid") assert result is None def test_index_has_all_entities(self, tmp_path: pathlib.Path) -> None: notes = [_key(60 + i) for i in range(5)] entities = assign_entity_ids(notes, None, commit_id="c001", op_id="op1") index = build_entity_index(entities, "track.mid", "c001") assert len(index["entities"]) == 5 def test_index_content_ids_match_notes(self) -> None: note = _key(60, 80) entities = assign_entity_ids([note], None, commit_id="c001", op_id="op1") index = build_entity_index(entities, "track.mid", "c001") eid = list(index["entities"].keys())[0] expected_cid = _note_content_id(note) assert index["entities"][eid]["content_id"] == expected_cid