cgcardona / muse public
muse_merge.py python
503 lines 16.6 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """Muse Merge Engine — three-way merge for musical variations.
2
3 Produces a ``MergeResult`` by comparing base, left, and right snapshots.
4 Auto-merges non-conflicting changes; reports conflicts when both sides
5 modify the same note or controller event.
6
When a merge produces conflicts and a repository path is supplied,
:func:`build_merge_checkout_plan` records the conflict shape and attempts to
auto-apply any cached rerere resolution (best effort), so that repeated
identical conflicts can be resolved without user intervention.
10
11 Boundary rules:
12 - Must NOT import StateStore, executor, MCP tools, or handlers.
- May import muse_repository, muse_replay, muse_checkout, muse_merge_base, muse_rerere, note_matching.
14 """
15
16 from __future__ import annotations
17
18 import logging
19 from collections.abc import Callable, Sequence
20 from dataclasses import dataclass
21 from pathlib import Path
22 from typing import Literal, TypeVar
23
24 from sqlalchemy.ext.asyncio import AsyncSession
25
26 from maestro.contracts.json_types import (
27 AftertouchDict,
28 CCEventDict,
29 NoteDict,
30 PitchBendDict,
31 RegionAftertouchMap,
32 RegionCCMap,
33 RegionNotesMap,
34 RegionPitchBendMap,
35 )
36 from maestro.services.muse_checkout import CheckoutPlan, build_checkout_plan
37 from maestro.services.muse_merge_base import find_merge_base
38 from maestro.services.muse_replay import HeadSnapshot, reconstruct_variation_snapshot
39 from maestro.services.variation.note_matching import (
40 EventMatch,
41 NoteMatch,
42 match_aftertouch,
43 match_cc_events,
44 match_notes,
45 match_pitch_bends,
46 )
47
logger = logging.getLogger(name=__name__)

# Constrained TypeVar that mirrors the one declared in note_matching. It lets
# _merge_event_layer carry the concrete controller-event type
# (CCEventDict, PitchBendDict or AftertouchDict) from input to output
# without needing @overload declarations or casts.
_EV = TypeVar("_EV", CCEventDict, PitchBendDict, AftertouchDict)
54
55
56 # ---------------------------------------------------------------------------
57 # Data types
58 # ---------------------------------------------------------------------------
59
60
@dataclass(frozen=True, eq=True)
class MergeConflict:
    """One conflict between left and right that the merge cannot resolve."""

    # Region the conflict belongs to ("*" is used for merge-wide failures).
    region_id: str
    # Layer that conflicted: notes, CC, pitch-bend or aftertouch events.
    type: Literal["note", "cc", "pb", "at"]
    # Human-readable explanation of the conflict.
    description: str
68
69
@dataclass(frozen=True, eq=True)
class MergeResult:
    """What a three-way merge produced.

    build_merge_result returns one of two shapes: conflicts with no merged
    snapshot, or a merged snapshot with an empty conflict tuple.
    """

    # True when at least one conflict was detected.
    has_conflicts: bool
    # Every detected conflict, in detection order; empty on success.
    conflicts: tuple[MergeConflict, ...]
    # The merged state, or None when the merge had conflicts.
    merged_snapshot: HeadSnapshot | None
77
78
@dataclass(frozen=True, eq=True)
class ThreeWaySnapshot:
    """The three reconstructed states a three-way merge compares."""

    # Common ancestor of the two sides.
    base: HeadSnapshot
    # One side of the merge.
    left: HeadSnapshot
    # The other side of the merge.
    right: HeadSnapshot
86
87
@dataclass(frozen=True, eq=True)
class MergeCheckoutPlan:
    """Outcome of planning a merge: a checkout plan, or the conflicts blocking it."""

    # True when the merge cannot proceed (conflicts or a merge-wide failure).
    is_conflict: bool
    # Conflicts blocking the merge; empty when is_conflict is False.
    conflicts: tuple[MergeConflict, ...]
    # Plan that applies the merged state to the working tree; None on conflict.
    checkout_plan: CheckoutPlan | None
95
96
97 # ---------------------------------------------------------------------------
98 # Three-way snapshot construction
99 # ---------------------------------------------------------------------------
100
101
async def build_three_way_snapshots(
    session: AsyncSession,
    base_id: str,
    left_id: str,
    right_id: str,
) -> ThreeWaySnapshot | None:
    """Load the base, left and right snapshots needed for a three-way merge.

    The variations are reconstructed one after another in the order base,
    left, right; all three are attempted even if an earlier one fails.

    Returns:
        A ThreeWaySnapshot bundling the three states, or None if any of the
        variations cannot be reconstructed.
    """
    # Sequential awaits: every reconstruction uses the same session object.
    reconstructed = [
        await reconstruct_variation_snapshot(session, variation_id)
        for variation_id in (base_id, left_id, right_id)
    ]
    base_snap, left_snap, right_snap = reconstructed

    if base_snap is None or left_snap is None or right_snap is None:
        return None
    return ThreeWaySnapshot(base=base_snap, left=left_snap, right=right_snap)
120
121
122 # ---------------------------------------------------------------------------
123 # Three-way merge engine
124 # ---------------------------------------------------------------------------
125
126
def build_merge_result(
    *,
    base: HeadSnapshot,
    left: HeadSnapshot,
    right: HeadSnapshot,
) -> MergeResult:
    """Three-way merge the musical state of *left* and *right* against *base*.

    Every region id found in any layer (notes, CC, pitch bends, aftertouch)
    of any of the three snapshots is merged layer by layer, in that order:

    - a change on only one side is taken from that side;
    - an untouched layer keeps the base content;
    - changes on both sides go through per-note / per-event conflict checks.

    ``track_regions`` and ``region_start_beats`` are combined so that, for a
    key present in several snapshots, right wins over left, which wins over
    base.

    Returns:
        A MergeResult listing every conflict (and no snapshot) when any layer
        conflicted; otherwise the merged HeadSnapshot and no conflicts.
    """
    region_ids: set[str] = set()
    for snap in (base, left, right):
        for layer in (snap.notes, snap.cc, snap.pitch_bends, snap.aftertouch):
            region_ids.update(layer.keys())
    all_regions = sorted(region_ids)

    # Dict unpacking keeps the last value per key: right > left > base.
    merged_track_regions: dict[str, str] = {
        **base.track_regions,
        **left.track_regions,
        **right.track_regions,
    }
    merged_region_starts: dict[str, float] = {
        **base.region_start_beats,
        **left.region_start_beats,
        **right.region_start_beats,
    }

    conflicts: list[MergeConflict] = []
    merged_notes: RegionNotesMap = {}
    merged_cc: RegionCCMap = {}
    merged_pb: RegionPitchBendMap = {}
    merged_at: RegionAftertouchMap = {}

    for rid in all_regions:
        merged_notes[rid], found = _merge_note_layer(
            base.notes.get(rid, []),
            left.notes.get(rid, []),
            right.notes.get(rid, []),
            rid,
        )
        conflicts.extend(found)

        merged_cc[rid], found = _merge_event_layer(
            base.cc.get(rid, []),
            left.cc.get(rid, []),
            right.cc.get(rid, []),
            rid,
            "cc",
            match_cc_events,
        )
        conflicts.extend(found)

        merged_pb[rid], found = _merge_event_layer(
            base.pitch_bends.get(rid, []),
            left.pitch_bends.get(rid, []),
            right.pitch_bends.get(rid, []),
            rid,
            "pb",
            match_pitch_bends,
        )
        conflicts.extend(found)

        merged_at[rid], found = _merge_event_layer(
            base.aftertouch.get(rid, []),
            left.aftertouch.get(rid, []),
            right.aftertouch.get(rid, []),
            rid,
            "at",
            match_aftertouch,
        )
        conflicts.extend(found)

    if conflicts:
        return MergeResult(
            has_conflicts=True,
            conflicts=tuple(conflicts),
            merged_snapshot=None,
        )

    return MergeResult(
        has_conflicts=False,
        conflicts=(),
        merged_snapshot=HeadSnapshot(
            variation_id=f"merge:{left.variation_id[:8]}+{right.variation_id[:8]}",
            notes=merged_notes,
            cc=merged_cc,
            pitch_bends=merged_pb,
            aftertouch=merged_at,
            track_regions=merged_track_regions,
            region_start_beats=merged_region_starts,
        ),
    )
233
234
235 # ---------------------------------------------------------------------------
236 # Merge checkout plan builder
237 # ---------------------------------------------------------------------------
238
239
def _run_rerere_hook(repo_path: Path, conflicts: tuple[MergeConflict, ...]) -> None:
    """Record *conflicts* with rerere and attempt an automatic resolution.

    Best effort only: any exception, including a failure to import the rerere
    module, is logged with its traceback and swallowed. The caller must
    always receive its conflict report.
    """
    try:
        from maestro.services.muse_rerere import (
            ConflictDict,
            apply_rerere,
            record_conflict,
        )

        conflict_dicts = [
            ConflictDict(
                region_id=c.region_id,
                type=c.type,
                description=c.description,
            )
            for c in conflicts
        ]
        record_conflict(repo_path, conflict_dicts)
        applied, _resolution = apply_rerere(repo_path, conflict_dicts)
        # NOTE(review): the resolution returned by apply_rerere is not used
        # here, so the merge is still reported as conflicting even when this
        # logs success. Confirm callers consume the resolution elsewhere.
        if applied:
            logger.info(
                "✅ muse rerere: resolved %d conflict(s) using rerere.",
                applied,
            )
    # Deliberate best-effort boundary: rerere must never break merging.
    except Exception as rerere_exc:  # noqa: BLE001
        # exc_info keeps the traceback so hook failures remain diagnosable.
        logger.warning(
            "⚠️ muse rerere hook failed (non-fatal): %s",
            rerere_exc,
            exc_info=True,
        )


async def build_merge_checkout_plan(
    session: AsyncSession,
    project_id: str,
    left_id: str,
    right_id: str,
    *,
    working_notes: RegionNotesMap | None = None,
    working_cc: RegionCCMap | None = None,
    working_pb: RegionPitchBendMap | None = None,
    working_at: RegionAftertouchMap | None = None,
    repo_path: Path | None = None,
) -> MergeCheckoutPlan:
    """Build a complete merge plan: merge-base → three-way diff → checkout plan.

    Steps:
      1. Find the merge base of *left_id* and *right_id*.
      2. Reconstruct base/left/right snapshots and three-way merge them.
      3. On conflicts, run the best-effort rerere hook (only when *repo_path*
         is given) and return the conflicts without a checkout plan.
      4. Otherwise build a CheckoutPlan that would move the working tree to
         the merged state.

    A missing merge base or an unreconstructable snapshot is reported as a
    single synthetic MergeConflict with ``region_id="*"``.

    Args:
        session: Database session used for merge-base lookup and replay.
        project_id: Project the checkout plan is built for.
        left_id: Variation id of one side of the merge.
        right_id: Variation id of the other side of the merge.
        working_notes, working_cc, working_pb, working_at: Current working
            tree state per layer; ``None`` is treated as empty.
        repo_path: Repository root enabling the rerere hook; ``None`` skips it.

    Returns:
        A MergeCheckoutPlan holding either the conflicts (``is_conflict=True``)
        or the checkout plan for the merged state.
    """
    base_id = await find_merge_base(session, left_id, right_id)
    if base_id is None:
        return MergeCheckoutPlan(
            is_conflict=True,
            conflicts=(MergeConflict(
                region_id="*",
                type="note",
                description="No common ancestor found between the two variations",
            ),),
            checkout_plan=None,
        )

    snapshots = await build_three_way_snapshots(session, base_id, left_id, right_id)
    if snapshots is None:
        return MergeCheckoutPlan(
            is_conflict=True,
            conflicts=(MergeConflict(
                region_id="*",
                type="note",
                description="Cannot reconstruct snapshot for one or more variations",
            ),),
            checkout_plan=None,
        )

    result = build_merge_result(
        base=snapshots.base, left=snapshots.left, right=snapshots.right,
    )

    if result.has_conflicts:
        if repo_path is not None:
            _run_rerere_hook(repo_path, result.conflicts)
        return MergeCheckoutPlan(
            is_conflict=True,
            conflicts=result.conflicts,
            checkout_plan=None,
        )

    merged = result.merged_snapshot
    # build_merge_result always sets merged_snapshot when has_conflicts is False.
    assert merged is not None

    plan = build_checkout_plan(
        project_id=project_id,
        target_variation_id=merged.variation_id,
        target_notes=merged.notes,
        target_cc=merged.cc,
        target_pb=merged.pitch_bends,
        target_at=merged.aftertouch,
        working_notes=working_notes or {},
        working_cc=working_cc or {},
        working_pb=working_pb or {},
        working_at=working_at or {},
        track_regions=merged.track_regions,
    )

    return MergeCheckoutPlan(
        is_conflict=False,
        conflicts=(),
        checkout_plan=plan,
    )
346
347
348 # ---------------------------------------------------------------------------
349 # Per-layer merge helpers (private)
350 # ---------------------------------------------------------------------------
351
352
def _merge_note_layer(
    base: list[NoteDict],
    left: list[NoteDict],
    right: list[NoteDict],
    region_id: str,
) -> tuple[list[NoteDict], list[MergeConflict]]:
    """Three-way merge for notes in a single region.

    Each base note is classified with ``match_notes`` against both sides:

    * both sides modified it to different results, or one side removed it
      while the other modified it -> conflict;
    * both sides made the identical modification -> that version, once
      (convergent edits are not a conflict);
    * removed by either side (the other untouched or also removed) -> dropped;
    * modified by exactly one side -> that side's version;
    * otherwise -> the base note, unchanged.

    Notes added by either side are appended afterwards (left first). A note
    added identically by both sides is kept only once. Overlapping additions
    with different content are reported by ``_check_addition_overlaps``; when
    any are found, no additions are appended.

    Returns:
        ``(merged_notes, conflicts)``. ``merged_notes`` is only a valid merge
        result when ``conflicts`` is empty.
    """
    left_matches = match_notes(base, left)
    right_matches = match_notes(base, right)

    # Index each side's matches by the position of the base note they match.
    left_by_base: dict[int, NoteMatch] = {
        m.base_index: m for m in left_matches if m.base_index is not None
    }
    right_by_base: dict[int, NoteMatch] = {
        m.base_index: m for m in right_matches if m.base_index is not None
    }

    conflicts: list[MergeConflict] = []
    merged: list[NoteDict] = []

    for bi, base_note in enumerate(base):
        lm = left_by_base.get(bi)
        rm = right_by_base.get(bi)

        l_removed = lm is not None and lm.is_removed
        l_modified = lm is not None and lm.is_modified
        r_removed = rm is not None and rm.is_removed
        r_modified = rm is not None and rm.is_modified

        if l_modified and r_modified:
            if (
                lm is not None
                and rm is not None
                and lm.proposed_note is not None
                and lm.proposed_note == rm.proposed_note
            ):
                # Convergent edit: both sides changed the note the same way.
                merged.append(lm.proposed_note)
            else:
                conflicts.append(MergeConflict(
                    region_id=region_id, type="note",
                    description=f"Both sides modified note at pitch={base_note.get('pitch')} beat={base_note.get('start_beat')}",
                ))
        elif (l_removed and r_modified) or (r_removed and l_modified):
            conflicts.append(MergeConflict(
                region_id=region_id, type="note",
                description=f"One side removed, other modified note at pitch={base_note.get('pitch')} beat={base_note.get('start_beat')}",
            ))
        elif l_removed or r_removed:
            pass
        elif l_modified and lm is not None and lm.proposed_note is not None:
            merged.append(lm.proposed_note)
        elif r_modified and rm is not None and rm.proposed_note is not None:
            merged.append(rm.proposed_note)
        else:
            merged.append(base_note)

    left_additions = [m.proposed_note for m in left_matches if m.is_added and m.proposed_note is not None]
    right_additions = [m.proposed_note for m in right_matches if m.is_added and m.proposed_note is not None]

    addition_conflicts = _check_addition_overlaps(left_additions, right_additions, region_id, "note")
    conflicts.extend(addition_conflicts)

    if not addition_conflicts:
        merged.extend(left_additions)
        # A note added identically on both sides must not be duplicated:
        # each left addition absorbs at most one equal right addition.
        unmatched_left = list(left_additions)
        for note in right_additions:
            if note in unmatched_left:
                unmatched_left.remove(note)
            else:
                merged.append(note)

    return merged, conflicts
415
416
def _merge_event_layer(
    base: list[_EV],
    left: list[_EV],
    right: list[_EV],
    region_id: str,
    event_type: Literal["cc", "pb", "at"],
    match_fn: Callable[[list[_EV], list[_EV]], list[EventMatch[_EV]]],
) -> tuple[list[_EV], list[MergeConflict]]:
    """Three-way merge for a controller event layer in a single region.

    Each base event is classified with *match_fn* against both sides:

    * both sides modified it to different results, or one side removed it
      while the other modified it -> conflict of type *event_type*;
    * both sides made the identical modification -> that version, once;
    * removed by either side (the other untouched or also removed) -> dropped;
    * modified by exactly one side -> that side's version;
    * otherwise -> the base event, unchanged.

    Events added by either side are appended afterwards (left first); an
    event added identically by both sides is kept only once.

    Returns:
        ``(merged_events, conflicts)``. ``merged_events`` is only a valid merge
        result when ``conflicts`` is empty.
    """
    left_matches: list[EventMatch[_EV]] = match_fn(base, left)
    right_matches: list[EventMatch[_EV]] = match_fn(base, right)

    # Index matches by the identity of their base event (first match wins),
    # avoiding a linear scan of all matches for every base event. id() is
    # safe here: the base and match lists keep every referenced event alive.
    left_by_base: dict[int, EventMatch[_EV]] = {}
    for m in left_matches:
        if m.base_event is not None:
            left_by_base.setdefault(id(m.base_event), m)
    right_by_base: dict[int, EventMatch[_EV]] = {}
    for m in right_matches:
        if m.base_event is not None:
            right_by_base.setdefault(id(m.base_event), m)

    conflicts: list[MergeConflict] = []
    merged: list[_EV] = []

    for base_ev in base:
        lm = left_by_base.get(id(base_ev))
        rm = right_by_base.get(id(base_ev))

        l_removed = lm is not None and lm.is_removed
        l_modified = lm is not None and lm.is_modified
        r_removed = rm is not None and rm.is_removed
        r_modified = rm is not None and rm.is_modified

        if l_modified and r_modified:
            if (
                lm is not None
                and rm is not None
                and lm.proposed_event is not None
                and lm.proposed_event == rm.proposed_event
            ):
                # Convergent edit: both sides changed the event the same way.
                merged.append(lm.proposed_event)
            else:
                conflicts.append(MergeConflict(
                    region_id=region_id, type=event_type,
                    description=f"Both sides modified {event_type} event at beat={base_ev.get('beat')}",
                ))
        elif (l_removed and r_modified) or (r_removed and l_modified):
            conflicts.append(MergeConflict(
                region_id=region_id, type=event_type,
                description=f"One side removed, other modified {event_type} event at beat={base_ev.get('beat')}",
            ))
        elif l_removed or r_removed:
            pass
        elif l_modified and lm is not None and lm.proposed_event is not None:
            merged.append(lm.proposed_event)
        elif r_modified and rm is not None and rm.proposed_event is not None:
            merged.append(rm.proposed_event)
        else:
            merged.append(base_ev)

    left_additions = [m.proposed_event for m in left_matches if m.is_added and m.proposed_event is not None]
    right_additions = [m.proposed_event for m in right_matches if m.is_added and m.proposed_event is not None]

    merged.extend(left_additions)
    # An event added identically on both sides must not be duplicated:
    # each left addition absorbs at most one equal right addition.
    unmatched_left = list(left_additions)
    for event in right_additions:
        if event in unmatched_left:
            unmatched_left.remove(event)
        else:
            merged.append(event)

    return merged, conflicts
468
469
def _find_event_match_for_base(
    matches: list[EventMatch[_EV]],
    base_event: _EV,
) -> EventMatch[_EV] | None:
    """Return the first match whose ``base_event`` is *base_event* itself.

    Matching is by object identity (``is``), not equality, so two equal but
    distinct base events are never confused. Returns None when nothing matches.
    """
    return next((m for m in matches if m.base_event is base_event), None)
479
480
def _check_addition_overlaps(
    left_adds: list[NoteDict],
    right_adds: list[NoteDict],
    region_id: str,
    conflict_type: Literal["note", "cc", "pb", "at"],
) -> list[MergeConflict]:
    """Report additions from both sides that match by position but differ.

    Every (left, right) pair that ``_notes_match`` considers the same note yet
    is not equal yields one conflict, in left-major order. Identical additions
    produce no conflict. Returns an empty list when either side added nothing.
    """
    if not (left_adds and right_adds):
        return []

    from maestro.services.variation.note_matching import _notes_match

    return [
        MergeConflict(
            region_id=region_id,
            type=conflict_type,
            description=f"Both sides added conflicting {conflict_type} at pitch={la.get('pitch')} beat={la.get('start_beat')}",
        )
        for la in left_adds
        for ra in right_adds
        if _notes_match(la, ra) and la != ra
    ]