cgcardona / muse public
muse_history_controller.py python
363 lines 11.7 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """Muse History Controller — orchestrates checkout, merge, time travel.
2
3 Coordinates HEAD movement, snapshot reconstruction, checkout plan
4 generation, drift safety, merge execution, and commit creation.
5
6 This is the internal entry point for undo/redo/merge — no route yet.
7 """
8
9 from __future__ import annotations
10
11 import logging
12 import uuid
13 from dataclasses import dataclass
14 from typing import TYPE_CHECKING
15
16 from maestro.contracts.json_types import (
17 AftertouchDict,
18 CCEventDict,
19 NoteDict,
20 PitchBendDict,
21 RegionAftertouchMap,
22 RegionCCMap,
23 RegionNotesMap,
24 RegionPitchBendMap,
25 )
26
27 from sqlalchemy.ext.asyncio import AsyncSession
28
29 from maestro.core.tracing import TraceContext
30 from maestro.models.variation import Variation as DomainVariation
31 from maestro.services import muse_repository
32 from maestro.services.muse_checkout import build_checkout_plan
33 from maestro.services.muse_checkout_executor import (
34 CheckoutExecutionResult,
35 execute_checkout_plan,
36 )
37 from maestro.services.muse_drift import DriftSeverity, compute_drift_report
38 from maestro.services.muse_merge import (
39 MergeConflict,
40 build_merge_checkout_plan,
41 )
42 if TYPE_CHECKING:
43 from maestro.core.state_store import StateStore
44
45 from maestro.services.muse_replay import (
46 HeadSnapshot,
47 reconstruct_head_snapshot,
48 reconstruct_variation_snapshot,
49 )
50
# Module logger: checkout_to_variation and merge_variations report successful
# completion (info) and execution failures (warning) through it.
logger = logging.getLogger(__name__)
52
53
class CheckoutBlockedError(Exception):
    """Signal that a checkout was refused because the working tree is dirty.

    Attributes:
        severity: Drift severity reported by the drift check; its ``value``
            is interpolated into the message.
        total_changes: Number of changes the drift check counted.
    """

    def __init__(self, severity: DriftSeverity, total_changes: int) -> None:
        self.severity = severity
        self.total_changes = total_changes
        message = (
            f"Working tree is {severity.value} ({total_changes} changes). "
            "Commit or discard changes, or use force=True."
        )
        super().__init__(message)
64
65
@dataclass(frozen=True)
class CheckoutSummary:
    """Full summary of a checkout operation.

    Returned by ``checkout_to_variation``. Instances are immutable
    (``frozen=True``).
    """

    # Project the checkout ran against.
    project_id: str
    # Variation HEAD pointed at before the checkout; None when there was no HEAD.
    from_variation_id: str | None
    # Variation that was checked out.
    to_variation_id: str
    # Outcome of executing the checkout plan against the StateStore.
    execution: CheckoutExecutionResult
    # True only when execution reported zero failures and HEAD was moved.
    head_moved: bool
75
76
async def checkout_to_variation(
    *,
    session: AsyncSession,
    project_id: str,
    target_variation_id: str,
    store: StateStore,
    trace: TraceContext,
    force: bool = False,
    emit_sse: bool = True,
) -> CheckoutSummary:
    """Check out to a specific variation — the musical equivalent of ``git checkout``.

    Orchestration flow:
        1. Load current HEAD.
        2. Reconstruct target snapshot.
        3. If not force and a HEAD exists: reconstruct the HEAD snapshot.
        4. Capture the working state from the StateStore (once).
        5. If a HEAD snapshot is available: run drift detection → block if dirty.
        6. Build checkout plan (target vs working).
        7. Execute checkout plan (mutate StateStore).
        8. Move HEAD pointer, only if no tool call failed.

    Args:
        session: Async DB session.
        project_id: Project identifier.
        target_variation_id: Variation to check out to.
        store: StateStore instance.
        trace: Trace context.
        force: Bypass drift safety check.
        emit_sse: Emit SSE events during execution.

    Returns:
        A CheckoutSummary. ``head_moved`` is False when execution reported
        any failed tool call; HEAD then still points at the previous
        variation, although the store may already have been partially
        modified by the calls that did execute.

    Raises:
        CheckoutBlockedError: Working tree is dirty and force=False.
        ValueError: Target variation not found.
    """
    current_head = await muse_repository.get_head(session, project_id)
    from_variation_id = current_head.variation_id if current_head else None

    target_snap = await reconstruct_variation_snapshot(session, target_variation_id)
    if target_snap is None:
        raise ValueError(f"Cannot reconstruct snapshot for variation {target_variation_id}")

    # Only reconstruct HEAD when the drift check is going to run. This await
    # happens before the working state is captured so the capture is not
    # taken from a store that could change while we are suspended.
    head_snap: HeadSnapshot | None = None
    if not force and current_head is not None:
        head_snap = await reconstruct_head_snapshot(session, project_id)
        if head_snap is None:
            # Drift safety cannot run without a HEAD snapshot; make that visible.
            logger.warning(
                "⚠️ HEAD snapshot for project %s could not be reconstructed — skipping drift check",
                project_id,
            )

    # Capture the working state once. No await separates the drift check
    # from plan building, so both see exactly the same maps.
    working_notes = _capture_working_notes(store)
    working_cc = _capture_working_cc(store)
    working_pb = _capture_working_pb(store)
    working_at = _capture_working_at(store)

    # ── Drift safety ──────────────────────────────────────────────
    if head_snap is not None:
        drift = compute_drift_report(
            project_id=project_id,
            head_variation_id=head_snap.variation_id,
            head_snapshot_notes=head_snap.notes,
            working_snapshot_notes=working_notes,
            track_regions=head_snap.track_regions,
            head_cc=head_snap.cc,
            working_cc=working_cc,
            head_pb=head_snap.pitch_bends,
            working_pb=working_pb,
            head_at=head_snap.aftertouch,
            working_at=working_at,
        )
        if drift.requires_user_action():
            raise CheckoutBlockedError(drift.severity, drift.total_changes)

    # ── Build checkout plan ───────────────────────────────────────
    plan = build_checkout_plan(
        project_id=project_id,
        target_variation_id=target_variation_id,
        target_notes=target_snap.notes,
        target_cc=target_snap.cc,
        target_pb=target_snap.pitch_bends,
        target_at=target_snap.aftertouch,
        working_notes=working_notes,
        working_cc=working_cc,
        working_pb=working_pb,
        working_at=working_at,
        track_regions=target_snap.track_regions,
    )

    # ── Execute ───────────────────────────────────────────────────
    result = execute_checkout_plan(
        checkout_plan=plan,
        store=store,
        trace=trace,
        emit_sse=emit_sse,
    )

    # ── Move HEAD ─────────────────────────────────────────────────
    # HEAD only advances when every planned tool call succeeded.
    head_moved = False
    if result.failed == 0:
        await muse_repository.move_head(session, project_id, target_variation_id)
        head_moved = True
        logger.info(
            "✅ Checkout complete: %s → %s (%d tool calls)",
            (from_variation_id or "none")[:8],
            target_variation_id[:8],
            result.executed,
        )
    else:
        logger.warning(
            "⚠️ Checkout execution had failures — HEAD not moved (%d/%d failed)",
            result.failed, result.executed + result.failed,
        )

    return CheckoutSummary(
        project_id=project_id,
        from_variation_id=from_variation_id,
        to_variation_id=target_variation_id,
        execution=result,
        head_moved=head_moved,
    )
194
195
class MergeConflictError(Exception):
    """Signal that a merge stopped because of conflicts that were not forced through.

    Attributes:
        conflicts: The conflicts as given by the caller; the message reports
            how many there are and how many distinct ``region_id`` values
            they touch.
    """

    def __init__(self, conflicts: tuple[MergeConflict, ...]) -> None:
        self.conflicts = conflicts
        distinct_regions = {conflict.region_id for conflict in conflicts}
        conflict_count = len(conflicts)
        region_count = len(distinct_regions)
        super().__init__(
            f"Merge has {conflict_count} conflict(s) in {region_count} region(s). "
            "Resolve conflicts or use force=True."
        )
206
207
@dataclass(frozen=True)
class MergeSummary:
    """Full summary of a merge operation.

    Returned by ``merge_variations``. Instances are immutable
    (``frozen=True``).
    """

    # Project the merge ran against.
    project_id: str
    # "Ours" side; saved as the first parent of the merge variation.
    left_id: str
    # "Theirs" side; saved as the second parent of the merge variation.
    right_id: str
    # Freshly generated UUID for the merge variation. NOTE: populated even when
    # execution failed and nothing was saved — check head_moved before using it.
    merge_variation_id: str
    # Outcome of executing the merge checkout plan against the StateStore.
    execution: CheckoutExecutionResult
    # True only when execution reported zero failures, the merge variation was
    # saved, and HEAD was moved to it.
    head_moved: bool
218
219
async def merge_variations(
    *,
    session: AsyncSession,
    project_id: str,
    left_id: str,
    right_id: str,
    store: StateStore,
    trace: TraceContext,
    force: bool = False,
    emit_sse: bool = True,
) -> MergeSummary:
    """Merge two variations — the musical equivalent of ``git merge``.

    Steps:
        1. Capture the working state (notes, CC, pitch bends, aftertouch)
           from ``store``.
        2. Ask ``build_merge_checkout_plan`` for the three-way merge result
           and the checkout plan that applies it to the working state.
        3. Refuse with MergeConflictError when there are conflicts and
           ``force`` is False, or when no checkout plan was produced.
        4. Execute the checkout plan against ``store``.
        5. If no tool call failed: save a committed merge variation whose
           parents are ``left_id`` and ``right_id``, then move HEAD to it.

    Args:
        session: Async DB session.
        project_id: Project identifier.
        left_id: "Ours" — typically the current HEAD.
        right_id: "Theirs" — the branch being merged in.
        store: StateStore instance.
        trace: Trace context.
        force: Bypass conflict check (takes left for conflicts).
        emit_sse: Emit SSE events during execution.

    Returns:
        A MergeSummary. ``merge_variation_id`` is always a fresh UUID, but a
        variation with that id is only saved when ``head_moved`` is True.

    Raises:
        MergeConflictError: Merge has conflicts and force=False, or the merge
            plan carries no checkout plan.
    """
    # Keyword arguments are evaluated left to right, so the four captures run
    # in the same order (notes, CC, pitch bends, aftertouch) as separate lines would.
    plan_bundle = await build_merge_checkout_plan(
        session,
        project_id,
        left_id,
        right_id,
        working_notes=_capture_working_notes(store),
        working_cc=_capture_working_cc(store),
        working_pb=_capture_working_pb(store),
        working_at=_capture_working_at(store),
    )

    blocked_by_conflicts = plan_bundle.is_conflict and not force
    if blocked_by_conflicts or plan_bundle.checkout_plan is None:
        raise MergeConflictError(plan_bundle.conflicts)

    execution = execute_checkout_plan(
        checkout_plan=plan_bundle.checkout_plan,
        store=store,
        trace=trace,
        emit_sse=emit_sse,
    )

    merge_commit_id = str(uuid.uuid4())
    committed = execution.failed == 0

    if not committed:
        logger.warning(
            "⚠️ Merge execution had failures — commit not created (%d/%d failed)",
            execution.failed,
            execution.executed + execution.failed,
        )
    else:
        # Record the merge as a two-parent variation, then advance HEAD.
        await muse_repository.save_variation(
            session,
            DomainVariation(
                variation_id=merge_commit_id,
                intent="merge",
                ai_explanation=f"Merge of {left_id[:8]} and {right_id[:8]}",
                affected_tracks=[],
                affected_regions=[],
                beat_range=(0.0, 0.0),
                phrases=[],
            ),
            project_id=project_id,
            base_state_id="merge",
            conversation_id="merge",
            region_metadata={},
            status="committed",
            parent_variation_id=left_id,
            parent2_variation_id=right_id,
        )
        await muse_repository.move_head(session, project_id, merge_commit_id)
        logger.info(
            "✅ Merge complete: %s + %s → %s",
            left_id[:8],
            right_id[:8],
            merge_commit_id[:8],
        )

    return MergeSummary(
        project_id=project_id,
        left_id=left_id,
        right_id=right_id,
        merge_variation_id=merge_commit_id,
        execution=execution,
        head_moved=committed,
    )
324
325
326 def _capture_working_notes(store: StateStore) -> RegionNotesMap:
327 """Extract notes from all regions in the store."""
328 result: RegionNotesMap = {}
329 if hasattr(store, "_region_notes"):
330 for rid, notes in store._region_notes.items():
331 if notes:
332 result[rid] = list(notes)
333 return result
334
335
336 def _capture_working_cc(store: StateStore) -> RegionCCMap:
337 """Extract CC events from all regions in the store."""
338 result: RegionCCMap = {}
339 if hasattr(store, "_region_cc"):
340 for rid, events in store._region_cc.items():
341 if events:
342 result[rid] = list(events)
343 return result
344
345
346 def _capture_working_pb(store: StateStore) -> RegionPitchBendMap:
347 """Extract pitch bend events from all regions in the store."""
348 result: RegionPitchBendMap = {}
349 if hasattr(store, "_region_pitch_bends"):
350 for rid, events in store._region_pitch_bends.items():
351 if events:
352 result[rid] = list(events)
353 return result
354
355
356 def _capture_working_at(store: StateStore) -> RegionAftertouchMap:
357 """Extract aftertouch events from all regions in the store."""
358 result: RegionAftertouchMap = {}
359 if hasattr(store, "_region_aftertouch"):
360 for rid, events in store._region_aftertouch.items():
361 if events:
362 result[rid] = list(events)
363 return result