muse_fixtures.py
python
| 1 | """Deterministic fixtures for Muse E2E harness. |
| 2 | |
| 3 | Provides fixed IDs, snapshot builders, and variation payload constructors |
| 4 | so the full VCS lifecycle can be exercised with stable, predictable data. |
| 5 | """ |
| 6 | |
| 7 | from __future__ import annotations |
| 8 | |
| 9 | from typing_extensions import NotRequired, TypedDict |
| 10 | |
| 11 | from maestro.contracts.json_types import CCEventDict, NoteDict |
| 12 | |
# ── Fixed IDs ─────────────────────────────────────────────────────────────
# Every identifier is a hard-coded literal so repeated harness runs work
# with stable, predictable data.

# Project / conversation scope shared by every variation payload.
PROJECT_ID = "proj_muse_e2e"
CONVO_ID = "convo_muse_e2e"

# Region IDs.
R_KEYS, R_BASS, R_DRUMS = "r_keys", "r_bass", "r_drums"

# Track IDs that own the regions above.
T_KEYS, T_BASS, T_DRUMS = "t_keys", "t_bass", "t_drums"

# Commit (variation) IDs: UUID-shaped, with the commit number as the
# character right after the leading "c" so they are easy to tell apart.
C0 = "c0000000-0000-0000-0000-000000000000"
C1 = "c1000000-0000-0000-0000-000000000000"
C2 = "c2000000-0000-0000-0000-000000000000"
C3 = "c3000000-0000-0000-0000-000000000000"
# C4 = merge commit — ID assigned by merge_variations at runtime
C5 = "c5000000-0000-0000-0000-000000000000"
C6 = "c6000000-0000-0000-0000-000000000000"

# Region → owning track lookup consulted by _track_for().
_REGION_TRACK_MAP: dict[str, str] = dict(
    zip((R_KEYS, R_BASS, R_DRUMS), (T_KEYS, T_BASS, T_DRUMS)),
)
| 39 | |
| 40 | |
# ── Fixture entities ───────────────────────────────────────────────────────


class MuseNoteChange(TypedDict, total=False):
    """One note add/remove record in a region diff.

    Exactly one of ``before`` / ``after`` is ``None``:
    - added   → before=None, after=<new note>
    - removed → before=<old note>, after=None

    Declared with ``total=False``, so type checkers treat every key as
    optional, although ``make_variation_payload`` always sets all four.
    """

    # make_variation_payload builds "nc-<variation_id[:8]>-<region_id>-p<pitch>b<start_beat>"
    note_id: str
    change_type: str  # "added" | "removed"
    # Note as it was in the base snapshot; None for additions.
    before: NoteDict | None
    # Note as proposed; None for removals.
    after: NoteDict | None
| 56 | |
| 57 | |
class MusePhrase(TypedDict):
    """A phrase (one region's contribution) in a Muse variation payload.

    Declared without ``total=False``, so every key below is required.
    """

    # make_variation_payload uses "ph-<variation_id[:8]>-<region_id>".
    phrase_id: str
    track_id: str  # track that owns region_id
    region_id: str
    # Phrase span in beats; make_variation_payload always uses 0.0 .. 8.0.
    start_beat: float
    end_beat: float
    # Human-readable; make_variation_payload uses "<intent> (<region_id>)".
    label: str
    note_changes: list[MuseNoteChange]  # per-note add/remove diff for this region
    cc_events: list[CCEventDict]  # controller events for this region (may be empty)
    pitch_bends: list[dict[str, object]]  # always [] in this module's fixtures
    aftertouch: list[dict[str, object]]  # always [] in this module's fixtures
| 71 | |
| 72 | |
class MuseVariationPayload(TypedDict, total=False):
    """POST /muse/variations request body built by the fixture helpers.

    ``total=False`` makes every key optional for type checkers, although
    ``make_variation_payload`` populates all of them.
    """

    project_id: str  # PROJECT_ID in all fixture payloads
    variation_id: str
    intent: str  # free-text description; also embedded in each phrase label
    conversation_id: str  # CONVO_ID in all fixture payloads
    parent_variation_id: str | None  # first parent, or None
    # Presumably the second parent of a merge variation; None otherwise.
    parent2_variation_id: str | None
    affected_tracks: list[str]  # track IDs derived from affected_regions
    affected_regions: list[str]  # sorted region IDs that received a phrase
    phrases: list[MusePhrase]  # one per affected region
    beat_range: list[float]  # [start, end]; fixtures always send [0.0, 8.0]
| 86 | |
| 87 | |
# ── Helpers ────────────────────────────────────────────────────────────────


def _track_for(region_id: str) -> str:
    """Return the track ID that owns ``region_id``.

    Known fixture regions are resolved through ``_REGION_TRACK_MAP``. For
    any other ID, a leading ``"r_"`` prefix is swapped for ``"t_"``
    (``"r_lead"`` → ``"t_lead"``). IDs without that prefix are returned
    unchanged.
    """
    track_id = _REGION_TRACK_MAP.get(region_id)
    if track_id is not None:
        return track_id
    # Rewrite only the prefix: str.replace would also mangle interior
    # occurrences (e.g. "r_guitar_1" -> "t_guitat_1").
    if region_id.startswith("r_"):
        return "t_" + region_id[2:]
    return region_id
| 93 | |
| 94 | |
# ── Snapshot builders ─────────────────────────────────────────────────────


def snapshot_empty() -> dict[str, list[NoteDict]]:
    """Return a brand-new snapshot that contains no regions at all."""
    empty: dict[str, list[NoteDict]] = dict()
    return empty
| 100 | |
| 101 | |
def snapshot_keys_v1() -> dict[str, list[NoteDict]]:
    """C major arpeggio — 4 one-beat notes in r_keys (MIDI 60, 64, 67, 72)."""
    # (pitch, start_beat, velocity); every note lasts exactly one beat.
    arpeggio = (
        (60, 0.0, 100),
        (64, 1.0, 90),
        (67, 2.0, 80),
        (72, 3.0, 100),
    )
    notes: list[NoteDict] = [
        {"pitch": pitch, "start_beat": beat, "duration_beats": 1.0, "velocity": vel}
        for pitch, beat, vel in arpeggio
    ]
    return {R_KEYS: notes}
| 112 | |
| 113 | |
def snapshot_bass_v1() -> dict[str, list[NoteDict]]:
    """Simple root-fifth bass line in r_bass: MIDI 36, then 43 (a fifth up)."""
    root, fifth = 36, 43
    line: list[NoteDict] = [
        {"pitch": root, "start_beat": 0.0, "duration_beats": 2.0, "velocity": 110},
        {"pitch": fifth, "start_beat": 2.0, "duration_beats": 2.0, "velocity": 105},
    ]
    return {R_BASS: line}
| 122 | |
| 123 | |
def snapshot_drums_v1() -> dict[str, list[NoteDict]]:
    """Kick-snare-hat pattern in r_drums (pitch 36 kick, 38 snare, 42 hat)."""
    # (pitch, start_beat, duration_beats, velocity), kept in the same order:
    # kick, snare, then two hat ticks.
    hits = (
        (36, 0.0, 0.5, 120),
        (38, 1.0, 0.5, 100),
        (42, 0.0, 0.25, 80),
        (42, 0.5, 0.25, 80),
    )
    return {
        R_DRUMS: [
            {"pitch": p, "start_beat": s, "duration_beats": d, "velocity": v}
            for p, s, d, v in hits
        ],
    }
| 134 | |
| 135 | |
def snapshot_keys_v2_with_cc() -> dict[str, list[NoteDict]]:
    """Keys v1 plus one extra note (pitch=48, beat=4) — conflict branch A.

    Despite the name, only notes are returned. The matching CC64 data for
    branch A is presumably ``cc_sustain_branch_a()``; confirm in the harness.
    """
    extra: NoteDict = {"pitch": 48, "start_beat": 4.0, "duration_beats": 1.0, "velocity": 95}
    return {R_KEYS: [*snapshot_keys_v1()[R_KEYS], extra]}
| 141 | |
| 142 | |
def snapshot_keys_v3_conflict() -> dict[str, list[NoteDict]]:
    """Keys v1 plus a conflicting note at pitch=48, beat=4 — conflict branch B.

    The extra note has the same ``(pitch, start_beat)`` as the one added in
    ``snapshot_keys_v2_with_cc``. Both its duration (2.0 vs 1.0) and its
    velocity (60 vs 95) differ. Diffed against v1, each branch therefore
    reports an ``"added"`` note at the same key, so the merge engine detects
    a conflicting addition.
    """
    notes = list(snapshot_keys_v1()[R_KEYS])
    notes.append({"pitch": 48, "start_beat": 4.0, "duration_beats": 2.0, "velocity": 60})
    return {R_KEYS: notes}
| 152 | |
| 153 | |
def cc_sustain_branch_a() -> dict[str, list[CCEventDict]]:
    """CC64 sustain events on r_keys for conflict branch A.

    Value 127 at beat 0.0, then 0 at beat 3.0.
    """
    sustain_cc = 64
    points = ((0.0, 127), (3.0, 0))  # (beat, value)
    return {
        R_KEYS: [CCEventDict(cc=sustain_cc, beat=beat, value=value) for beat, value in points],
    }
| 162 | |
| 163 | |
def cc_sustain_branch_b() -> dict[str, list[CCEventDict]]:
    """CC64 sustain events on r_keys for conflict branch B.

    The values differ from branch A: 64 at beat 0.0, then 0 at beat 2.0.
    """
    sustain_cc = 64
    points = ((0.0, 64), (2.0, 0))  # (beat, value)
    return {
        R_KEYS: [CCEventDict(cc=sustain_cc, beat=beat, value=value) for beat, value in points],
    }
| 172 | |
| 173 | |
# ── Variation payload builder ─────────────────────────────────────────────


def _note_key(n: NoteDict) -> tuple[int, float]:
    """Return the ``(pitch, start_beat)`` identity used to match notes in diffs.

    Missing fields fall back to pitch ``0`` and beat ``0.0``.
    """
    pitch = n.get("pitch", 0)
    start_beat = n.get("start_beat", 0.0)
    return pitch, start_beat
| 179 | |
| 180 | |
def _diff_region_notes(
    id_prefix: str,
    region_id: str,
    base: list[NoteDict],
    proposed: list[NoteDict],
) -> list[MuseNoteChange]:
    """Diff one region's notes, matching them by ``_note_key``.

    Returns the ``"added"`` changes (in ``proposed`` order) followed by the
    ``"removed"`` changes (in ``base`` order).
    """
    base_keys = {_note_key(n) for n in base}
    proposed_keys = {_note_key(n) for n in proposed}

    def change_id(key: tuple[int, float]) -> str:
        pitch, start_beat = key
        return f"nc-{id_prefix}-{region_id}-p{pitch}b{start_beat}"

    changes: list[MuseNoteChange] = []
    for note in proposed:
        key = _note_key(note)
        if key not in base_keys:
            changes.append(MuseNoteChange(
                note_id=change_id(key),
                change_type="added",
                before=None,
                after=note,
            ))
    for note in base:
        key = _note_key(note)
        if key not in proposed_keys:
            changes.append(MuseNoteChange(
                note_id=change_id(key),
                change_type="removed",
                before=note,
                after=None,
            ))
    return changes


def make_variation_payload(
    variation_id: str,
    intent: str,
    base_notes: dict[str, list[NoteDict]],
    proposed_notes: dict[str, list[NoteDict]],
    *,
    parent_variation_id: str | None = None,
    parent2_variation_id: str | None = None,
    cc_events: dict[str, list[CCEventDict]] | None = None,
) -> MuseVariationPayload:
    """Build a POST /muse/variations request body with proper NoteChange diffs.

    Every region appearing in ``base_notes``, ``proposed_notes`` or
    ``cc_events`` gets exactly one phrase, in sorted region-ID order. Notes
    are matched between base and proposed by ``(pitch, start_beat)``:

    - A proposed note with no base match is reported as ``"added"``.
    - A base note with no proposed match is reported as ``"removed"``.
    - A note whose key exists on both sides counts as unchanged, even if its
      velocity or duration differ.

    Args:
        variation_id: ID of the variation. Its first 8 characters prefix the
            generated phrase and note-change IDs.
        intent: Free-text intent, also used in each phrase label.
        base_notes: Region ID → notes before the variation.
        proposed_notes: Region ID → notes after the variation.
        parent_variation_id: Optional first parent variation ID.
        parent2_variation_id: Optional second parent variation ID.
        cc_events: Optional region ID → CC events attached to that region's
            phrase.

    Returns:
        The request body. Every phrase spans beats 0.0–8.0 and
        ``beat_range`` is ``[0.0, 8.0]``.
    """
    cc_by_region = cc_events or {}
    # CC-only regions must get a phrase too; otherwise their events would be
    # silently dropped from the payload.
    all_regions = sorted(set(base_notes) | set(proposed_notes) | set(cc_by_region))
    id_prefix = variation_id[:8]
    track_ids = {rid: _track_for(rid) for rid in all_regions}

    phrases: list[MusePhrase] = []
    for rid in all_regions:
        phrases.append(MusePhrase(
            phrase_id=f"ph-{id_prefix}-{rid}",
            track_id=track_ids[rid],
            region_id=rid,
            start_beat=0.0,
            end_beat=8.0,
            label=f"{intent} ({rid})",
            note_changes=_diff_region_notes(
                id_prefix, rid, base_notes.get(rid, []), proposed_notes.get(rid, []),
            ),
            # Copy so the payload never aliases the caller's CC lists.
            cc_events=list(cc_by_region.get(rid, [])),
            pitch_bends=[],
            aftertouch=[],
        ))

    return MuseVariationPayload(
        project_id=PROJECT_ID,
        variation_id=variation_id,
        intent=intent,
        conversation_id=CONVO_ID,
        parent_variation_id=parent_variation_id,
        parent2_variation_id=parent2_variation_id,
        # Several regions may share a track; list each track once, in region order.
        affected_tracks=list(dict.fromkeys(track_ids.values())),
        affected_regions=list(all_regions),
        phrases=phrases,
        beat_range=[0.0, 8.0],
    )