muse_merge.py
python
| 1 | """Muse Merge Engine — three-way merge for musical variations. |
| 2 | |
| 3 | Produces a ``MergeResult`` by comparing base, left, and right snapshots. |
| 4 | Auto-merges non-conflicting changes; reports conflicts when both sides |
| 5 | modify the same note or controller event. |
| 6 | |
| 7 | When a merge has conflicts and a repo path is given, |
| 8 | :func:`build_merge_checkout_plan` records them with rerere and attempts a |
| 9 | best-effort cached resolution; the conflict report is still returned. |
| 10 | |
| 11 | Boundary rules: |
| 12 | - Must NOT import StateStore, executor, MCP tools, or handlers. |
| 13 | - May import muse_repository, muse_replay, muse_checkout, note_matching. |
| 14 | """ |
| 15 | |
| 16 | from __future__ import annotations |
| 17 | |
| 18 | import logging |
| 19 | from collections.abc import Callable, Sequence |
| 20 | from dataclasses import dataclass |
| 21 | from pathlib import Path |
| 22 | from typing import Literal, TypeVar |
| 23 | |
| 24 | from sqlalchemy.ext.asyncio import AsyncSession |
| 25 | |
| 26 | from maestro.contracts.json_types import ( |
| 27 | AftertouchDict, |
| 28 | CCEventDict, |
| 29 | NoteDict, |
| 30 | PitchBendDict, |
| 31 | RegionAftertouchMap, |
| 32 | RegionCCMap, |
| 33 | RegionNotesMap, |
| 34 | RegionPitchBendMap, |
| 35 | ) |
| 36 | from maestro.services.muse_checkout import CheckoutPlan, build_checkout_plan |
| 37 | from maestro.services.muse_merge_base import find_merge_base |
| 38 | from maestro.services.muse_replay import HeadSnapshot, reconstruct_variation_snapshot |
| 39 | from maestro.services.variation.note_matching import ( |
| 40 | EventMatch, |
| 41 | NoteMatch, |
| 42 | match_aftertouch, |
| 43 | match_cc_events, |
| 44 | match_notes, |
| 45 | match_pitch_bends, |
| 46 | ) |
| 47 | |
logger: logging.Logger = logging.getLogger(__name__)

# _merge_event_layer is generic over the concrete controller-event dict
# (CCEventDict, PitchBendDict or AftertouchDict). This constrained TypeVar
# mirrors the one declared in note_matching so the exact type flows from the
# inputs to the merged output without overloads or casts.
_EV = TypeVar("_EV", CCEventDict, PitchBendDict, AftertouchDict)
| 54 | |
| 55 | |
| 56 | # --------------------------------------------------------------------------- |
| 57 | # Data types |
| 58 | # --------------------------------------------------------------------------- |
| 59 | |
| 60 | |
@dataclass(eq=True, frozen=True)
class MergeConflict:
    """One conflict between left and right that could not be auto-merged.

    Instances are immutable (``frozen=True``) and compare by value.
    """

    # Region the conflict belongs to; "*" is used for merge-wide failures
    # such as a missing common ancestor.
    region_id: str
    # Layer the conflict was found in: notes, CC, pitch bend or aftertouch.
    type: Literal["note", "cc", "pb", "at"]
    # Human-readable explanation of the conflict.
    description: str
| 68 | |
| 69 | |
@dataclass(eq=True, frozen=True)
class MergeResult:
    """Immutable outcome of a three-way merge.

    :func:`build_merge_result` produces one of two shapes: conflicts present
    (``has_conflicts`` true, ``merged_snapshot`` is ``None``) or conflict-free
    (``conflicts`` empty, ``merged_snapshot`` set).
    """

    # True when at least one conflict was detected.
    has_conflicts: bool
    # Every conflict found; empty for a clean merge.
    conflicts: tuple[MergeConflict, ...]
    # The merged state, or None when the merge has conflicts.
    merged_snapshot: HeadSnapshot | None
| 77 | |
| 78 | |
@dataclass(eq=True, frozen=True)
class ThreeWaySnapshot:
    """The three reconstructed snapshots compared by a three-way merge."""

    # Snapshot at the common ancestor of left and right.
    base: HeadSnapshot
    # Snapshot on the left side of the merge.
    left: HeadSnapshot
    # Snapshot on the right side of the merge.
    right: HeadSnapshot
| 86 | |
| 87 | |
@dataclass(eq=True, frozen=True)
class MergeCheckoutPlan:
    """Immutable result of merge planning: a checkout plan or the conflicts."""

    # True when the merge could not be completed; checkout_plan is then None.
    is_conflict: bool
    # Conflicts that blocked the merge; empty on success.
    conflicts: tuple[MergeConflict, ...]
    # Plan applying the merged state to the working tree, when conflict-free.
    checkout_plan: CheckoutPlan | None
| 95 | |
| 96 | |
| 97 | # --------------------------------------------------------------------------- |
| 98 | # Three-way snapshot construction |
| 99 | # --------------------------------------------------------------------------- |
| 100 | |
| 101 | |
async def build_three_way_snapshots(
    session: AsyncSession,
    base_id: str,
    left_id: str,
    right_id: str,
) -> ThreeWaySnapshot | None:
    """Load the base, left and right snapshots needed for a three-way merge.

    The three variations are reconstructed one after another on the same
    session, in base -> left -> right order; all three are always attempted.

    Returns:
        A :class:`ThreeWaySnapshot`, or ``None`` when at least one of the
        three variations could not be reconstructed.
    """
    loaded = [
        await reconstruct_variation_snapshot(session, variation_id)
        for variation_id in (base_id, left_id, right_id)
    ]
    base_snap, left_snap, right_snap = loaded

    if base_snap is not None and left_snap is not None and right_snap is not None:
        return ThreeWaySnapshot(base=base_snap, left=left_snap, right=right_snap)
    return None
| 120 | |
| 121 | |
| 122 | # --------------------------------------------------------------------------- |
| 123 | # Three-way merge engine |
| 124 | # --------------------------------------------------------------------------- |
| 125 | |
| 126 | |
def build_merge_result(
    *,
    base: HeadSnapshot,
    left: HeadSnapshot,
    right: HeadSnapshot,
) -> MergeResult:
    """Perform a three-way merge of musical state.

    Every region id that appears in any layer (notes, CC, pitch bend,
    aftertouch) of any of the three snapshots is merged layer by layer,
    comparing left and right against the common base:

    - Only one side changed -> take that side.
    - Neither changed -> keep base.
    - Both changed -> per-note/event conflict detection.

    Region metadata (``track_regions`` and ``region_start_beats``) is merged
    key by key. A value changed on only one side wins over the unchanged
    value on the other side. When both sides changed it, the right-hand value
    is kept and no conflict is reported.

    Args:
        base: Snapshot at the common ancestor.
        left: Snapshot on the left side.
        right: Snapshot on the right side.

    Returns:
        A :class:`MergeResult` carrying the merged snapshot when no conflicts
        were found, otherwise the conflicts and no snapshot.
    """
    all_regions = sorted(
        set(base.notes).union(
            left.notes, right.notes,
            base.cc, left.cc, right.cc,
            base.pitch_bends, left.pitch_bends, right.pitch_bends,
            base.aftertouch, left.aftertouch, right.aftertouch,
        )
    )

    conflicts: list[MergeConflict] = []
    merged_notes: RegionNotesMap = {}
    merged_cc: RegionCCMap = {}
    merged_pb: RegionPitchBendMap = {}
    merged_at: RegionAftertouchMap = {}

    for rid in all_regions:
        notes_result, note_conflicts = _merge_note_layer(
            base.notes.get(rid, []),
            left.notes.get(rid, []),
            right.notes.get(rid, []),
            rid,
        )
        merged_notes[rid] = notes_result
        conflicts.extend(note_conflicts)

        cc_result, cc_conflicts = _merge_event_layer(
            base.cc.get(rid, []),
            left.cc.get(rid, []),
            right.cc.get(rid, []),
            rid, "cc", match_cc_events,
        )
        merged_cc[rid] = cc_result
        conflicts.extend(cc_conflicts)

        pb_result, pb_conflicts = _merge_event_layer(
            base.pitch_bends.get(rid, []),
            left.pitch_bends.get(rid, []),
            right.pitch_bends.get(rid, []),
            rid, "pb", match_pitch_bends,
        )
        merged_pb[rid] = pb_result
        conflicts.extend(pb_conflicts)

        at_result, at_conflicts = _merge_event_layer(
            base.aftertouch.get(rid, []),
            left.aftertouch.get(rid, []),
            right.aftertouch.get(rid, []),
            rid, "at", match_aftertouch,
        )
        merged_at[rid] = at_result
        conflicts.extend(at_conflicts)

    if conflicts:
        return MergeResult(
            has_conflicts=True,
            conflicts=tuple(conflicts),
            merged_snapshot=None,
        )

    merged = HeadSnapshot(
        variation_id=f"merge:{left.variation_id[:8]}+{right.variation_id[:8]}",
        notes=merged_notes,
        cc=merged_cc,
        pitch_bends=merged_pb,
        aftertouch=merged_at,
        track_regions=_merge_region_metadata(
            base.track_regions, left.track_regions, right.track_regions,
        ),
        region_start_beats=_merge_region_metadata(
            base.region_start_beats, left.region_start_beats, right.region_start_beats,
        ),
    )
    return MergeResult(
        has_conflicts=False,
        conflicts=(),
        merged_snapshot=merged,
    )


_MetaValue = TypeVar("_MetaValue")


def _merge_region_metadata(
    base: dict[str, _MetaValue],
    left: dict[str, _MetaValue],
    right: dict[str, _MetaValue],
) -> dict[str, _MetaValue]:
    """Three-way merge of a per-region metadata map (region id -> value).

    Keys present in any map are kept (union). Key order follows base, then
    keys new in left, then keys new in right. For a key present in all three
    maps, a left-only change is taken; otherwise the right value is used.
    """
    # Start from the union with right taking precedence: it already covers
    # right-only changes, keys missing from a side, and both-sides-changed.
    merged = {**base, **left, **right}
    for key, base_value in base.items():
        if key in left and key in right:
            left_value = left[key]
            # Only the left side changed this entry: don't let the unchanged
            # right value silently revert it.
            if right[key] == base_value and left_value != base_value:
                merged[key] = left_value
    return merged
| 233 | |
| 234 | |
| 235 | # --------------------------------------------------------------------------- |
| 236 | # Merge checkout plan builder |
| 237 | # --------------------------------------------------------------------------- |
| 238 | |
| 239 | |
async def build_merge_checkout_plan(
    session: AsyncSession,
    project_id: str,
    left_id: str,
    right_id: str,
    *,
    working_notes: RegionNotesMap | None = None,
    working_cc: RegionCCMap | None = None,
    working_pb: RegionPitchBendMap | None = None,
    working_at: RegionAftertouchMap | None = None,
    repo_path: Path | None = None,
) -> MergeCheckoutPlan:
    """Build a complete merge plan: merge-base -> three-way diff -> checkout plan.

    Steps:
    1. Find the merge base of ``left_id`` and ``right_id``.
    2. Reconstruct the base/left/right snapshots.
    3. Run :func:`build_merge_result`.
    4. If conflict-free, build a :class:`CheckoutPlan` that would apply the
       merged state to the working tree.

    Args:
        session: Database session used for the merge-base lookup and snapshot
            reconstruction.
        project_id: Project the checkout plan is built for.
        left_id: Variation id of the left side.
        right_id: Variation id of the right side.
        working_notes, working_cc, working_pb, working_at: Current working
            state per region. ``None`` is treated as an empty map.
        repo_path: When given and the merge has conflicts, the conflicts are
            recorded with rerere and a cached resolution is attempted. This is
            best-effort; failures are only logged.

    Returns:
        A :class:`MergeCheckoutPlan`. On conflict, ``is_conflict`` is true and
        ``checkout_plan`` is ``None``. A missing common ancestor or an
        unreconstructable snapshot is reported as a single ``region_id="*"``
        conflict.

    Raises:
        RuntimeError: If a conflict-free merge result lacks its snapshot
            (internal invariant violation).
    """
    base_id = await find_merge_base(session, left_id, right_id)
    if base_id is None:
        return _merge_failure_plan("No common ancestor found between the two variations")

    snapshots = await build_three_way_snapshots(session, base_id, left_id, right_id)
    if snapshots is None:
        return _merge_failure_plan("Cannot reconstruct snapshot for one or more variations")

    result = build_merge_result(
        base=snapshots.base, left=snapshots.left, right=snapshots.right,
    )

    if result.has_conflicts:
        if repo_path is not None:
            _run_rerere_hook(repo_path, result.conflicts)
        # NOTE(review): any rerere resolution is only logged by the hook; the
        # caller still receives the original conflicts. Confirm this is intended.
        return MergeCheckoutPlan(
            is_conflict=True,
            conflicts=result.conflicts,
            checkout_plan=None,
        )

    merged = result.merged_snapshot
    if merged is None:
        # Explicit check (unlike ``assert``) still holds under ``python -O``.
        raise RuntimeError("conflict-free merge result is missing its merged snapshot")

    plan = build_checkout_plan(
        project_id=project_id,
        target_variation_id=merged.variation_id,
        target_notes=merged.notes,
        target_cc=merged.cc,
        target_pb=merged.pitch_bends,
        target_at=merged.aftertouch,
        working_notes=working_notes or {},
        working_cc=working_cc or {},
        working_pb=working_pb or {},
        working_at=working_at or {},
        track_regions=merged.track_regions,
    )

    return MergeCheckoutPlan(
        is_conflict=False,
        conflicts=(),
        checkout_plan=plan,
    )


def _merge_failure_plan(description: str) -> MergeCheckoutPlan:
    """Wrap a merge-wide failure as a single conflict on the ``"*"`` region."""
    # NOTE(review): "note" is reused as the conflict type because
    # MergeConflict.type has no dedicated value for merge-wide failures.
    return MergeCheckoutPlan(
        is_conflict=True,
        conflicts=(MergeConflict(
            region_id="*",
            type="note",
            description=description,
        ),),
        checkout_plan=None,
    )


def _run_rerere_hook(repo_path: Path, conflicts: tuple[MergeConflict, ...]) -> None:
    """Record *conflicts* with rerere and attempt a cached resolution.

    Best-effort: rerere failures must never prevent the caller from receiving
    the conflict report, so every exception is logged and swallowed.
    """
    try:
        # Imported inside the try so an import failure is also non-fatal.
        from maestro.services.muse_rerere import (
            ConflictDict,
            apply_rerere,
            record_conflict,
        )

        conflict_dicts = [
            ConflictDict(
                region_id=c.region_id,
                type=c.type,
                description=c.description,
            )
            for c in conflicts
        ]
        record_conflict(repo_path, conflict_dicts)
        applied, _resolution = apply_rerere(repo_path, conflict_dicts)
        if applied:
            logger.info(
                "✅ muse rerere: resolved %d conflict(s) using rerere.",
                applied,
            )
    except Exception as rerere_exc:  # noqa: BLE001
        logger.warning(
            "⚠️ muse rerere hook failed (non-fatal): %s", rerere_exc, exc_info=True,
        )
| 346 | |
| 347 | |
| 348 | # --------------------------------------------------------------------------- |
| 349 | # Per-layer merge helpers (private) |
| 350 | # --------------------------------------------------------------------------- |
| 351 | |
| 352 | |
def _merge_note_layer(
    base: list[NoteDict],
    left: list[NoteDict],
    right: list[NoteDict],
    region_id: str,
) -> tuple[list[NoteDict], list[MergeConflict]]:
    """Three-way merge for notes in a single region.

    Each base note is resolved against its left and right matches:

    - modified on both sides -> conflict, unless both sides produced the
      identical note, in which case it is taken once;
    - removed on one side and modified on the other -> conflict;
    - removed on either side (other side untouched) -> dropped;
    - modified on one side -> that side's version;
    - otherwise -> the base note.

    Notes added on either side are appended after the surviving base notes,
    left additions first. Differing additions at the same note slot are
    reported by :func:`_check_addition_overlaps`; if any are found, no
    additions are merged. An addition made identically on both sides is
    appended once.

    Returns:
        ``(merged_notes, conflicts)``. Notes involved in a conflict are not
        included in ``merged_notes``.
    """
    left_matches = match_notes(base, left)
    right_matches = match_notes(base, right)

    # Index matches by the base note they refer to; additions have no base index.
    left_by_base: dict[int, NoteMatch] = {
        m.base_index: m for m in left_matches if m.base_index is not None
    }
    right_by_base: dict[int, NoteMatch] = {
        m.base_index: m for m in right_matches if m.base_index is not None
    }

    conflicts: list[MergeConflict] = []
    merged: list[NoteDict] = []

    for bi, base_note in enumerate(base):
        lm = left_by_base.get(bi)
        rm = right_by_base.get(bi)

        l_removed = lm is not None and lm.is_removed
        l_modified = lm is not None and lm.is_modified
        r_removed = rm is not None and rm.is_removed
        r_modified = rm is not None and rm.is_modified

        if l_modified and r_modified:
            if (
                lm is not None
                and rm is not None
                and lm.proposed_note is not None
                and lm.proposed_note == rm.proposed_note
            ):
                # Convergent edit: both sides changed the note the same way.
                merged.append(lm.proposed_note)
            else:
                conflicts.append(MergeConflict(
                    region_id=region_id, type="note",
                    description=f"Both sides modified note at pitch={base_note.get('pitch')} beat={base_note.get('start_beat')}",
                ))
        elif (l_removed and r_modified) or (r_removed and l_modified):
            conflicts.append(MergeConflict(
                region_id=region_id, type="note",
                description=f"One side removed, other modified note at pitch={base_note.get('pitch')} beat={base_note.get('start_beat')}",
            ))
        elif l_removed or r_removed:
            pass
        elif l_modified and lm is not None and lm.proposed_note is not None:
            merged.append(lm.proposed_note)
        elif r_modified and rm is not None and rm.proposed_note is not None:
            merged.append(rm.proposed_note)
        else:
            merged.append(base_note)

    left_additions = [m.proposed_note for m in left_matches if m.is_added and m.proposed_note is not None]
    right_additions = [m.proposed_note for m in right_matches if m.is_added and m.proposed_note is not None]

    addition_conflicts = _check_addition_overlaps(left_additions, right_additions, region_id, "note")
    conflicts.extend(addition_conflicts)

    if not addition_conflicts:
        merged.extend(left_additions)
        # A note added identically on both sides must appear once, not twice.
        # Each right addition cancels at most one identical left addition, so
        # notes deliberately repeated on one side are still preserved.
        unpaired_left = list(left_additions)
        for note in right_additions:
            if note in unpaired_left:
                unpaired_left.remove(note)
            else:
                merged.append(note)

    return merged, conflicts
| 415 | |
| 416 | |
def _merge_event_layer(
    base: list[_EV],
    left: list[_EV],
    right: list[_EV],
    region_id: str,
    event_type: Literal["cc", "pb", "at"],
    match_fn: Callable[[list[_EV], list[_EV]], list[EventMatch[_EV]]],
) -> tuple[list[_EV], list[MergeConflict]]:
    """Three-way merge for a controller event layer in a single region.

    Args:
        base: Events of this layer in the region at the common ancestor.
        left: Events on the left side.
        right: Events on the right side.
        region_id: Region being merged (copied into conflicts).
        event_type: Layer tag, used as the conflict type and in messages.
        match_fn: Matcher pairing base events with one side's events
            (e.g. ``match_cc_events``).

    Each base event is resolved against its left and right matches:

    - modified on both sides -> conflict, unless both sides produced the
      identical event, in which case it is taken once;
    - removed on one side and modified on the other -> conflict;
    - removed on either side (other side untouched) -> dropped;
    - modified on one side -> that side's version;
    - otherwise -> the base event.

    Added events are appended afterwards, left first. An event added
    identically on both sides is appended once.

    Returns:
        ``(merged_events, conflicts)``.
    """
    left_matches: list[EventMatch[_EV]] = match_fn(base, left)
    right_matches: list[EventMatch[_EV]] = match_fn(base, right)

    # Index matches by the identity of the base event they refer to. This
    # keeps the ``is`` semantics and first-match-wins of a linear search
    # while avoiding an O(len(base) * len(matches)) scan. The base events
    # stay alive in ``base``, so their ids are stable for this call.
    left_by_base: dict[int, EventMatch[_EV]] = {}
    right_by_base: dict[int, EventMatch[_EV]] = {}
    for matches, index in ((left_matches, left_by_base), (right_matches, right_by_base)):
        for m in matches:
            if m.base_event is not None:
                index.setdefault(id(m.base_event), m)

    conflicts: list[MergeConflict] = []
    merged: list[_EV] = []

    for base_ev in base:
        lm = left_by_base.get(id(base_ev))
        rm = right_by_base.get(id(base_ev))

        l_removed = lm is not None and lm.is_removed
        l_modified = lm is not None and lm.is_modified
        r_removed = rm is not None and rm.is_removed
        r_modified = rm is not None and rm.is_modified

        if l_modified and r_modified:
            if (
                lm is not None
                and rm is not None
                and lm.proposed_event is not None
                and lm.proposed_event == rm.proposed_event
            ):
                # Convergent edit: both sides changed the event the same way.
                merged.append(lm.proposed_event)
            else:
                conflicts.append(MergeConflict(
                    region_id=region_id, type=event_type,
                    description=f"Both sides modified {event_type} event at beat={base_ev.get('beat')}",
                ))
        elif (l_removed and r_modified) or (r_removed and l_modified):
            conflicts.append(MergeConflict(
                region_id=region_id, type=event_type,
                description=f"One side removed, other modified {event_type} event at beat={base_ev.get('beat')}",
            ))
        elif l_removed or r_removed:
            pass
        elif l_modified and lm is not None and lm.proposed_event is not None:
            merged.append(lm.proposed_event)
        elif r_modified and rm is not None and rm.proposed_event is not None:
            merged.append(rm.proposed_event)
        else:
            merged.append(base_ev)

    left_added = [m.proposed_event for m in left_matches if m.is_added and m.proposed_event is not None]
    merged.extend(left_added)
    # An event added identically on both sides is kept once; each right-hand
    # addition cancels at most one identical left-hand addition.
    # NOTE(review): differing additions at the same position are not
    # reported as conflicts in this layer; both are kept.
    unpaired_left = list(left_added)
    for m in right_matches:
        if m.is_added and m.proposed_event is not None:
            if m.proposed_event in unpaired_left:
                unpaired_left.remove(m.proposed_event)
            else:
                merged.append(m.proposed_event)

    return merged, conflicts
| 468 | |
| 469 | |
| 470 | def _find_event_match_for_base( |
| 471 | matches: list[EventMatch[_EV]], |
| 472 | base_event: _EV, |
| 473 | ) -> EventMatch[_EV] | None: |
| 474 | """Find the EventMatch that corresponds to a specific base event.""" |
| 475 | for m in matches: |
| 476 | if m.base_event is base_event: |
| 477 | return m |
| 478 | return None |
| 479 | |
| 480 | |
| 481 | def _check_addition_overlaps( |
| 482 | left_adds: list[NoteDict], |
| 483 | right_adds: list[NoteDict], |
| 484 | region_id: str, |
| 485 | conflict_type: Literal["note", "cc", "pb", "at"], |
| 486 | ) -> list[MergeConflict]: |
| 487 | """Detect conflicting additions (same position, different content).""" |
| 488 | if not left_adds or not right_adds: |
| 489 | return [] |
| 490 | |
| 491 | from maestro.services.variation.note_matching import _notes_match |
| 492 | |
| 493 | conflicts: list[MergeConflict] = [] |
| 494 | for la in left_adds: |
| 495 | for ra in right_adds: |
| 496 | if _notes_match(la, ra): |
| 497 | if la != ra: |
| 498 | conflicts.append(MergeConflict( |
| 499 | region_id=region_id, |
| 500 | type=conflict_type, |
| 501 | description=f"Both sides added conflicting {conflict_type} at pitch={la.get('pitch')} beat={la.get('start_beat')}", |
| 502 | )) |
| 503 | return conflicts |