muse_history_controller.py
python
| 1 | """Muse History Controller — orchestrates checkout, merge, time travel. |
| 2 | |
| 3 | Coordinates HEAD movement, snapshot reconstruction, checkout plan |
| 4 | generation, drift safety, merge execution, and commit creation. |
| 5 | |
| 6 | This is the internal entry point for undo/redo/merge — no route yet. |
| 7 | """ |
| 8 | |
| 9 | from __future__ import annotations |
| 10 | |
| 11 | import logging |
| 12 | import uuid |
| 13 | from dataclasses import dataclass |
| 14 | from typing import TYPE_CHECKING |
| 15 | |
| 16 | from maestro.contracts.json_types import ( |
| 17 | AftertouchDict, |
| 18 | CCEventDict, |
| 19 | NoteDict, |
| 20 | PitchBendDict, |
| 21 | RegionAftertouchMap, |
| 22 | RegionCCMap, |
| 23 | RegionNotesMap, |
| 24 | RegionPitchBendMap, |
| 25 | ) |
| 26 | |
| 27 | from sqlalchemy.ext.asyncio import AsyncSession |
| 28 | |
| 29 | from maestro.core.tracing import TraceContext |
| 30 | from maestro.models.variation import Variation as DomainVariation |
| 31 | from maestro.services import muse_repository |
| 32 | from maestro.services.muse_checkout import build_checkout_plan |
| 33 | from maestro.services.muse_checkout_executor import ( |
| 34 | CheckoutExecutionResult, |
| 35 | execute_checkout_plan, |
| 36 | ) |
| 37 | from maestro.services.muse_drift import DriftSeverity, compute_drift_report |
| 38 | from maestro.services.muse_merge import ( |
| 39 | MergeConflict, |
| 40 | build_merge_checkout_plan, |
| 41 | ) |
| 42 | if TYPE_CHECKING: |
| 43 | from maestro.core.state_store import StateStore |
| 44 | |
| 45 | from maestro.services.muse_replay import ( |
| 46 | HeadSnapshot, |
| 47 | reconstruct_head_snapshot, |
| 48 | reconstruct_variation_snapshot, |
| 49 | ) |
| 50 | |
| 51 | logger = logging.getLogger(__name__) |
| 52 | |
| 53 | |
class CheckoutBlockedError(Exception):
    """Raised when checkout is blocked by a dirty working tree."""

    def __init__(self, severity: DriftSeverity, total_changes: int) -> None:
        # Keep structured fields so callers can render a prompt (e.g. an
        # "uncommitted changes" dialog) without parsing the message text.
        self.severity = severity
        self.total_changes = total_changes
        message = (
            f"Working tree is {severity.value} ({total_changes} changes). "
            f"Commit or discard changes, or use force=True."
        )
        super().__init__(message)
| 64 | |
| 65 | |
@dataclass(frozen=True)
class CheckoutSummary:
    """Full summary of a checkout operation.

    Immutable result object returned by ``checkout_to_variation``.
    """

    # Project the checkout ran against.
    project_id: str
    # Variation HEAD pointed at before the checkout; None when no HEAD existed.
    from_variation_id: str | None
    # Variation that was checked out.
    to_variation_id: str
    # Per-tool-call outcome from the checkout executor.
    execution: CheckoutExecutionResult
    # True only when execution had zero failures and HEAD was advanced.
    head_moved: bool
| 75 | |
| 76 | |
async def checkout_to_variation(
    *,
    session: AsyncSession,
    project_id: str,
    target_variation_id: str,
    store: StateStore,
    trace: TraceContext,
    force: bool = False,
    emit_sse: bool = True,
) -> CheckoutSummary:
    """Check out to a specific variation — the musical equivalent of ``git checkout``.

    Orchestration flow:
        1. Load current HEAD.
        2. Reconstruct target snapshot.
        3. If not force: run drift detection → block if dirty.
        4. Build checkout plan (target vs working).
        5. Execute checkout plan (mutate StateStore).
        6. Move HEAD pointer.

    Args:
        session: Async DB session.
        project_id: Project identifier.
        target_variation_id: Variation to check out to.
        store: StateStore instance.
        trace: Trace context.
        force: Bypass drift safety check.
        emit_sse: Emit SSE events during execution.

    Returns:
        CheckoutSummary with the execution result and whether HEAD moved.

    Raises:
        CheckoutBlockedError: Working tree is dirty and force=False.
        ValueError: Target variation not found.
    """
    current_head = await muse_repository.get_head(session, project_id)
    from_variation_id = current_head.variation_id if current_head else None

    target_snap = await reconstruct_variation_snapshot(session, target_variation_id)
    if target_snap is None:
        raise ValueError(f"Cannot reconstruct snapshot for variation {target_variation_id}")

    # Reconstruct the HEAD snapshot first (it awaits the DB). The working-tree
    # capture below happens with no intervening await, so drift detection and
    # plan building observe the exact same store state.
    head_snap: HeadSnapshot | None = None
    if not force and current_head is not None:
        head_snap = await reconstruct_head_snapshot(session, project_id)

    # Capture the working tree ONCE for both drift detection and plan
    # building. (Previously captured twice; the capture helpers are pure
    # reads and no await separated the two captures, so one is equivalent.)
    working_notes = _capture_working_notes(store)
    working_cc = _capture_working_cc(store)
    working_pb = _capture_working_pb(store)
    working_at = _capture_working_at(store)

    # ── Drift safety ────────────────────────────────────────────────
    if head_snap is not None:
        drift = compute_drift_report(
            project_id=project_id,
            head_variation_id=head_snap.variation_id,
            head_snapshot_notes=head_snap.notes,
            working_snapshot_notes=working_notes,
            track_regions=head_snap.track_regions,
            head_cc=head_snap.cc,
            working_cc=working_cc,
            head_pb=head_snap.pitch_bends,
            working_pb=working_pb,
            head_at=head_snap.aftertouch,
            working_at=working_at,
        )
        if drift.requires_user_action():
            raise CheckoutBlockedError(drift.severity, drift.total_changes)

    # ── Build checkout plan ─────────────────────────────────────────
    plan = build_checkout_plan(
        project_id=project_id,
        target_variation_id=target_variation_id,
        target_notes=target_snap.notes,
        target_cc=target_snap.cc,
        target_pb=target_snap.pitch_bends,
        target_at=target_snap.aftertouch,
        working_notes=working_notes,
        working_cc=working_cc,
        working_pb=working_pb,
        working_at=working_at,
        track_regions=target_snap.track_regions,
    )

    # ── Execute ─────────────────────────────────────────────────────
    result = execute_checkout_plan(
        checkout_plan=plan,
        store=store,
        trace=trace,
        emit_sse=emit_sse,
    )

    # ── Move HEAD ───────────────────────────────────────────────────
    # HEAD only advances on a fully clean execution; a partial application
    # leaves HEAD untouched so the user can retry or repair.
    head_moved = False
    if result.failed == 0:
        await muse_repository.move_head(session, project_id, target_variation_id)
        head_moved = True
        logger.info(
            "✅ Checkout complete: %s → %s (%d tool calls)",
            (from_variation_id or "none")[:8],
            target_variation_id[:8],
            result.executed,
        )
    else:
        logger.warning(
            "⚠️ Checkout execution had failures — HEAD not moved (%d/%d failed)",
            result.failed,
            result.executed + result.failed,
        )

    return CheckoutSummary(
        project_id=project_id,
        from_variation_id=from_variation_id,
        to_variation_id=target_variation_id,
        execution=result,
        head_moved=head_moved,
    )
| 194 | |
| 195 | |
class MergeConflictError(Exception):
    """Raised when a merge has unresolvable conflicts."""

    def __init__(self, conflicts: tuple[MergeConflict, ...]) -> None:
        # Expose the raw conflict tuple so callers can surface each
        # conflict individually (e.g. in a resolution UI).
        self.conflicts = conflicts
        affected_regions = {conflict.region_id for conflict in conflicts}
        summary = (
            f"Merge has {len(conflicts)} conflict(s) in {len(affected_regions)} region(s). "
            f"Resolve conflicts or use force=True."
        )
        super().__init__(summary)
| 206 | |
| 207 | |
@dataclass(frozen=True)
class MergeSummary:
    """Full summary of a merge operation.

    Immutable result object returned by ``merge_variations``.
    """

    # Project the merge ran against.
    project_id: str
    # "Ours" side of the merge (typically current HEAD).
    left_id: str
    # "Theirs" side — the variation merged in.
    right_id: str
    # Id generated for the merge commit. NOTE(review): this id is present
    # even when head_moved is False, in which case no variation was saved.
    merge_variation_id: str
    # Per-tool-call outcome from the checkout executor.
    execution: CheckoutExecutionResult
    # True only when execution had zero failures and the merge commit was
    # created and HEAD advanced.
    head_moved: bool
| 218 | |
| 219 | |
async def merge_variations(
    *,
    session: AsyncSession,
    project_id: str,
    left_id: str,
    right_id: str,
    store: StateStore,
    trace: TraceContext,
    force: bool = False,
    emit_sse: bool = True,
) -> MergeSummary:
    """Merge two variations — the musical equivalent of ``git merge``.

    Builds a three-way merge plan from the current working tree, applies it
    to the store, and — when every tool call succeeds — records a merge
    commit with two parents and advances HEAD to it.

    Args:
        session: Async DB session.
        project_id: Project identifier.
        left_id: "Ours" — typically the current HEAD.
        right_id: "Theirs" — the branch being merged in.
        store: StateStore instance.
        trace: Trace context.
        force: Bypass conflict check (takes left for conflicts).
        emit_sse: Emit SSE events during execution.

    Raises:
        MergeConflictError: Merge has conflicts and force=False.
    """
    captured_notes = _capture_working_notes(store)
    captured_cc = _capture_working_cc(store)
    captured_pb = _capture_working_pb(store)
    captured_at = _capture_working_at(store)

    plan = await build_merge_checkout_plan(
        session,
        project_id,
        left_id,
        right_id,
        working_notes=captured_notes,
        working_cc=captured_cc,
        working_pb=captured_pb,
        working_at=captured_at,
    )

    # Conflicts block the merge unless forced; a missing checkout plan means
    # the merge produced no applicable result at all.
    if (plan.is_conflict and not force) or plan.checkout_plan is None:
        raise MergeConflictError(plan.conflicts)

    outcome = execute_checkout_plan(
        checkout_plan=plan.checkout_plan,
        store=store,
        trace=trace,
        emit_sse=emit_sse,
    )

    new_variation_id = str(uuid.uuid4())

    if outcome.failed != 0:
        # Partial application: do not record a commit or move HEAD.
        logger.warning(
            "⚠️ Merge execution had failures — commit not created (%d/%d failed)",
            outcome.failed,
            outcome.executed + outcome.failed,
        )
        return MergeSummary(
            project_id=project_id,
            left_id=left_id,
            right_id=right_id,
            merge_variation_id=new_variation_id,
            execution=outcome,
            head_moved=False,
        )

    # Clean execution — persist the two-parent merge commit and advance HEAD.
    merge_commit = DomainVariation(
        variation_id=new_variation_id,
        intent="merge",
        ai_explanation=f"Merge of {left_id[:8]} and {right_id[:8]}",
        affected_tracks=[],
        affected_regions=[],
        beat_range=(0.0, 0.0),
        phrases=[],
    )
    await muse_repository.save_variation(
        session,
        merge_commit,
        project_id=project_id,
        base_state_id="merge",
        conversation_id="merge",
        region_metadata={},
        status="committed",
        parent_variation_id=left_id,
        parent2_variation_id=right_id,
    )
    await muse_repository.move_head(session, project_id, new_variation_id)
    logger.info(
        "✅ Merge complete: %s + %s → %s",
        left_id[:8],
        right_id[:8],
        new_variation_id[:8],
    )
    return MergeSummary(
        project_id=project_id,
        left_id=left_id,
        right_id=right_id,
        merge_variation_id=new_variation_id,
        execution=outcome,
        head_moved=True,
    )
| 324 | |
| 325 | |
def _capture_working_notes(store: StateStore) -> RegionNotesMap:
    """Copy every non-empty note list out of the store, keyed by region id."""
    # Guard clause: tolerate stores without the private attribute.
    if not hasattr(store, "_region_notes"):
        return {}
    # Shallow-copy each list so callers hold an independent snapshot.
    return {rid: list(notes) for rid, notes in store._region_notes.items() if notes}
| 334 | |
| 335 | |
def _capture_working_cc(store: StateStore) -> RegionCCMap:
    """Copy every non-empty CC event list out of the store, keyed by region id."""
    # Guard clause: tolerate stores without the private attribute.
    if not hasattr(store, "_region_cc"):
        return {}
    # Shallow-copy each list so callers hold an independent snapshot.
    return {rid: list(events) for rid, events in store._region_cc.items() if events}
| 344 | |
| 345 | |
def _capture_working_pb(store: StateStore) -> RegionPitchBendMap:
    """Copy every non-empty pitch-bend list out of the store, keyed by region id."""
    # Guard clause: tolerate stores without the private attribute.
    if not hasattr(store, "_region_pitch_bends"):
        return {}
    # Shallow-copy each list so callers hold an independent snapshot.
    return {rid: list(events) for rid, events in store._region_pitch_bends.items() if events}
| 354 | |
| 355 | |
def _capture_working_at(store: StateStore) -> RegionAftertouchMap:
    """Copy every non-empty aftertouch list out of the store, keyed by region id."""
    # Guard clause: tolerate stores without the private attribute.
    if not hasattr(store, "_region_aftertouch"):
        return {}
    # Shallow-copy each list so callers hold an independent snapshot.
    return {rid: list(events) for rid, events in store._region_aftertouch.items() if events}