cgcardona / muse public
muse_repository.py python
514 lines 16.5 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """Muse persistence adapter — single point of DB access for variation history.
2
3 This module is the ONLY place that touches the variations/phrases/note_changes
4 tables. Orchestration, executor, and VariationService must never import it
5 or depend on it structurally — they produce/consume domain models
6 (maestro.models.variation) and this module handles the storage translation.
7
8 Boundary rules:
9 - Must NOT import StateStore, EntityRegistry, or get_or_create_store.
10 - Must NOT import VariationService or executor modules.
11 - May import domain models from maestro.models.variation.
12 - May import ORM models from maestro.db.muse_models.
13 """
14
15 from __future__ import annotations
16
17 import logging
18 import uuid
19 from dataclasses import dataclass
20 from datetime import datetime
21 from sqlalchemy import select, update
22 from sqlalchemy.ext.asyncio import AsyncSession
23 from sqlalchemy.orm import selectinload
24
25 from maestro.contracts.json_types import (
26 AftertouchDict,
27 CCEventDict,
28 PitchBendDict,
29 RegionMetadataDB,
30 RegionMetadataWire,
31 )
32 from maestro.db import muse_models as db
33 from maestro.models.variation import (
34 ChangeType,
35 Variation as DomainVariation,
36 Phrase as DomainPhrase,
37 NoteChange as DomainNoteChange,
38 MidiNoteSnapshot,
39 )
40
41 logger = logging.getLogger(__name__)
42
43 def _validate_change_type(raw: str) -> ChangeType:
44 """Narrow a DB string to the ChangeType literal, raising on bad data."""
45 if raw == "added":
46 return "added"
47 if raw == "removed":
48 return "removed"
49 if raw == "modified":
50 return "modified"
51 raise ValueError(f"Invalid change_type in DB: {raw!r}")
52
53
54 def _parse_cc_event(raw: CCEventDict) -> CCEventDict:
55 """Coerce a DB-deserialized CCEventDict to correct Python types.
56
57 SQLAlchemy deserialises JSON columns as plain Python dicts; the value types
58 are whatever json.loads produced (int, float, str). The defensive casts
59 here handle cases where the stored value doesn't match the expected type.
60 """
61 def _to_int(v: object) -> int:
62 return int(v) if isinstance(v, (int, float, str)) else 0
63
64 def _to_float(v: object) -> float:
65 return float(v) if isinstance(v, (int, float, str)) else 0.0
66
67 return CCEventDict(
68 cc=_to_int(raw.get("cc", 0)),
69 beat=_to_float(raw.get("beat", 0)),
70 value=_to_int(raw.get("value", 0)),
71 )
72
73
74 def _parse_pitch_bend(raw: PitchBendDict) -> PitchBendDict:
75 """Coerce a DB-deserialized PitchBendDict to correct Python types."""
76 raw_beat = raw.get("beat", 0)
77 raw_value = raw.get("value", 0)
78 return PitchBendDict(
79 beat=float(raw_beat) if isinstance(raw_beat, (int, float, str)) else 0.0,
80 value=int(raw_value) if isinstance(raw_value, (int, float, str)) else 0,
81 )
82
83
84 def _parse_aftertouch(raw: AftertouchDict) -> AftertouchDict:
85 """Coerce a DB-deserialized AftertouchDict to correct Python types."""
86 raw_beat = raw.get("beat", 0)
87 raw_value = raw.get("value", 0)
88 ev: AftertouchDict = {
89 "beat": float(raw_beat) if isinstance(raw_beat, (int, float, str)) else 0.0,
90 "value": int(raw_value) if isinstance(raw_value, (int, float, str)) else 0,
91 }
92 if "pitch" in raw:
93 raw_pitch = raw["pitch"]
94 ev["pitch"] = int(raw_pitch) if isinstance(raw_pitch, (int, float, str)) else 0
95 return ev
96
97
@dataclass(frozen=True)
class HistoryNode:
    """Lightweight lineage node — used by replay engine to traverse history."""

    # Primary key of the variation row this node represents.
    variation_id: str
    # First parent in the lineage chain; None for a root variation.
    parent_variation_id: str | None
    # State snapshot id recorded at commit time (written by set_head when a
    # commit_state_id is supplied); None when no snapshot was recorded.
    commit_state_id: str | None
    # Row creation timestamp from the variations table.
    created_at: datetime
106
107
@dataclass(frozen=True)
class VariationSummary:
    """Lightweight variation metadata for log graph serialization."""

    # Primary key of the variation row.
    variation_id: str
    # First parent in the lineage graph; None for a root variation.
    parent_variation_id: str | None
    # Second parent — presumably set when a variation merges two lineages;
    # None otherwise (TODO confirm merge semantics with callers).
    parent2_variation_id: str | None
    # True when this variation is the project's current HEAD pointer.
    is_head: bool
    # Row creation timestamp; get_variations_for_project orders by this.
    created_at: datetime
    # Intent string stored on the variation row.
    intent: str
    # Sorted, de-duplicated region IDs gathered from the variation's phrases.
    affected_regions: tuple[str, ...]
119
120
async def save_variation(
    session: AsyncSession,
    variation: DomainVariation,
    *,
    project_id: str,
    base_state_id: str,
    conversation_id: str,
    region_metadata: dict[str, RegionMetadataWire],
    status: str = "ready",
    parent_variation_id: str | None = None,
    parent2_variation_id: str | None = None,
) -> None:
    """Persist a domain Variation and all its phrases/note_changes to Postgres."""
    beat_start, beat_end = variation.beat_range
    session.add(db.Variation(
        variation_id=variation.variation_id,
        project_id=project_id,
        base_state_id=base_state_id,
        conversation_id=conversation_id,
        intent=variation.intent,
        explanation=variation.ai_explanation,
        status=status,
        affected_tracks=variation.affected_tracks,
        affected_regions=variation.affected_regions,
        beat_range_start=beat_start,
        beat_range_end=beat_end,
        parent_variation_id=parent_variation_id,
        parent2_variation_id=parent2_variation_id,
    ))

    # Phrases are numbered from 1 so sequence order survives the round-trip.
    for sequence, phrase in enumerate(variation.phrases, start=1):
        meta = region_metadata.get(phrase.region_id, {})
        session.add(db.Phrase(
            phrase_id=phrase.phrase_id,
            variation_id=variation.variation_id,
            sequence=sequence,
            track_id=phrase.track_id,
            region_id=phrase.region_id,
            start_beat=phrase.start_beat,
            end_beat=phrase.end_beat,
            label=phrase.label,
            tags=phrase.tags or [],
            explanation=phrase.explanation,
            cc_events=phrase.cc_events,
            pitch_bends=phrase.pitch_bends,
            aftertouch=phrase.aftertouch,
            region_start_beat=meta.get("startBeat"),
            region_duration_beats=meta.get("durationBeats"),
            region_name=meta.get("name"),
        ))

        # Note changes get a fresh storage id; before/after snapshots are
        # serialized to JSON columns (None when absent).
        for change in phrase.note_changes:
            session.add(db.NoteChange(
                id=str(uuid.uuid4()),
                phrase_id=phrase.phrase_id,
                change_type=change.change_type,
                before_json=change.before.model_dump() if change.before else None,
                after_json=change.after.model_dump() if change.after else None,
            ))

    await session.flush()
    logger.info(
        "✅ Variation persisted: %s (%d phrases)",
        variation.variation_id[:8],
        len(variation.phrases),
    )
189
190
async def load_variation(
    session: AsyncSession,
    variation_id: str,
) -> DomainVariation | None:
    """Load a persisted variation and reconstruct the domain model.

    Phrases (and their note changes) are eagerly loaded and rebuilt in
    sequence order. Returns None if the variation_id does not exist in the DB.
    """
    stmt = (
        select(db.Variation)
        .options(
            selectinload(db.Variation.phrases).selectinload(db.Phrase.note_changes)
        )
        .where(db.Variation.variation_id == variation_id)
    )
    result = await session.execute(stmt)
    row = result.scalar_one_or_none()
    if row is None:
        return None

    phrases: list[DomainPhrase] = []
    for p in sorted(row.phrases, key=lambda p: p.sequence):
        # NOTE(review): note_id here is the storage row id — save_variation
        # generates a fresh uuid4 for NoteChange.id and does not persist the
        # domain note_id, so it does not round-trip. Confirm callers don't
        # rely on the original note_id after a load.
        note_changes = [
            DomainNoteChange(
                note_id=nc.id,
                change_type=_validate_change_type(nc.change_type),
                before=MidiNoteSnapshot.model_validate(nc.before_json) if nc.before_json else None,
                after=MidiNoteSnapshot.model_validate(nc.after_json) if nc.after_json else None,
            )
            for nc in p.note_changes
        ]
        phrases.append(DomainPhrase(
            phrase_id=p.phrase_id,
            track_id=p.track_id,
            region_id=p.region_id,
            start_beat=p.start_beat,
            end_beat=p.end_beat,
            label=p.label,
            note_changes=note_changes,
            cc_events=[_parse_cc_event(ev) for ev in (p.cc_events or [])],
            pitch_bends=[_parse_pitch_bend(ev) for ev in (p.pitch_bends or [])],
            aftertouch=[_parse_aftertouch(ev) for ev in (p.aftertouch or [])],
            explanation=p.explanation,
            tags=p.tags or [],
        ))

    # Prefer the beat range persisted by save_variation so the domain model
    # round-trips exactly (the stored range may be wider than the phrase
    # spans). Recompute from phrases only for rows where the columns are NULL.
    start = row.beat_range_start
    end = row.beat_range_end
    if start is None or end is None:
        start = min((p.start_beat for p in phrases), default=0.0)
        end = max((p.end_beat for p in phrases), default=0.0)

    return DomainVariation(
        variation_id=row.variation_id,
        intent=row.intent,
        ai_explanation=row.explanation,
        affected_tracks=row.affected_tracks or [],
        affected_regions=row.affected_regions or [],
        beat_range=(start, end),
        phrases=phrases,
    )
249
250
async def get_status(
    session: AsyncSession,
    variation_id: str,
) -> str | None:
    """Return the current status string, or None if not found."""
    result = await session.execute(
        select(db.Variation.status).where(db.Variation.variation_id == variation_id)
    )
    return result.scalar_one_or_none()
261
262
async def get_base_state_id(
    session: AsyncSession,
    variation_id: str,
) -> str | None:
    """Return the base_state_id for a variation, or None if not found."""
    result = await session.execute(
        select(db.Variation.base_state_id).where(
            db.Variation.variation_id == variation_id
        )
    )
    return result.scalar_one_or_none()
273
274
async def get_phrase_ids(
    session: AsyncSession,
    variation_id: str,
) -> list[str]:
    """Return phrase IDs for a variation in sequence order."""
    query = (
        select(db.Phrase.phrase_id)
        .where(db.Phrase.variation_id == variation_id)
        .order_by(db.Phrase.sequence)
    )
    rows = await session.execute(query)
    return list(rows.scalars())
287
288
async def get_region_metadata(
    session: AsyncSession,
    variation_id: str,
) -> dict[str, RegionMetadataDB]:
    """Return region metadata keyed by region_id from persisted phrases."""
    query = select(
        db.Phrase.region_id,
        db.Phrase.region_start_beat,
        db.Phrase.region_duration_beats,
        db.Phrase.region_name,
    ).where(db.Phrase.variation_id == variation_id)
    result = await session.execute(query)

    meta: dict[str, RegionMetadataDB] = {}
    # First row per region wins; later rows for the same region are skipped.
    for region_id, start_beat, duration_beats, name in result:
        if region_id in meta:
            continue
        meta[region_id] = {
            "start_beat": start_beat,
            "duration_beats": duration_beats,
            "name": name,
        }
    return meta
314
315
async def mark_committed(session: AsyncSession, variation_id: str) -> None:
    """Transition a variation to COMMITTED status."""
    await session.execute(
        update(db.Variation)
        .where(db.Variation.variation_id == variation_id)
        .values(status="committed")
    )
    logger.info("Variation %s marked committed", variation_id[:8])
325
326
async def mark_discarded(session: AsyncSession, variation_id: str) -> None:
    """Transition a variation to DISCARDED status."""
    await session.execute(
        update(db.Variation)
        .where(db.Variation.variation_id == variation_id)
        .values(status="discarded")
    )
    logger.info("Variation %s marked discarded", variation_id[:8])
336
337
338 # ── Lineage / History Graph (Phase 5) ────────────────────────────────────
339
340
async def get_head(session: AsyncSession, project_id: str) -> HistoryNode | None:
    """Return the current HEAD variation for a project, or None."""
    result = await session.execute(
        select(db.Variation).where(
            db.Variation.project_id == project_id,
            db.Variation.is_head.is_(True),
        )
    )
    head = result.scalar_one_or_none()
    if head is None:
        return None
    return HistoryNode(
        variation_id=head.variation_id,
        parent_variation_id=head.parent_variation_id,
        commit_state_id=head.commit_state_id,
        created_at=head.created_at,
    )
357
358
async def set_head(
    session: AsyncSession,
    variation_id: str,
    *,
    commit_state_id: str | None = None,
) -> None:
    """Mark a variation as HEAD for its project, clearing any previous HEAD.

    Only call this when a variation is committed — HEAD tracks the latest
    committed point in the project timeline.
    """
    # Resolve which project this variation belongs to.
    lookup = await session.execute(
        select(db.Variation.project_id).where(
            db.Variation.variation_id == variation_id
        )
    )
    project_id = lookup.scalar_one_or_none()
    if project_id is None:
        logger.warning("⚠️ set_head: variation %s not found", variation_id[:8])
        return

    # Drop the HEAD flag from whatever currently holds it in this project.
    await session.execute(
        update(db.Variation)
        .where(db.Variation.project_id == project_id, db.Variation.is_head.is_(True))
        .values(is_head=False)
    )

    # Promote the target; record the commit snapshot id only when supplied.
    new_values: dict[str, object] = {"is_head": True}
    if commit_state_id is not None:
        new_values["commit_state_id"] = commit_state_id
    await session.execute(
        update(db.Variation)
        .where(db.Variation.variation_id == variation_id)
        .values(**new_values)
    )
    logger.info("✅ HEAD set: %s (project %s)", variation_id[:8], project_id[:8])
399
400
async def move_head(
    session: AsyncSession,
    project_id: str,
    variation_id: str,
) -> None:
    """Move HEAD pointer without mutating StateStore.

    This is a soft undo/redo primitive. Future endpoints will combine this
    with a replay plan to reconstruct the target state.
    """
    # Drop the flag from the project's current HEAD(s)...
    await session.execute(
        update(db.Variation)
        .where(db.Variation.project_id == project_id, db.Variation.is_head.is_(True))
        .values(is_head=False)
    )
    # ...then raise it on the target variation.
    await session.execute(
        update(db.Variation)
        .where(db.Variation.variation_id == variation_id)
        .values(is_head=True)
    )
    logger.info("✅ HEAD moved to %s (project %s)", variation_id[:8], project_id[:8])
427
428
async def get_children(
    session: AsyncSession,
    variation_id: str,
) -> list[HistoryNode]:
    """Return child HistoryNodes (variations whose parent is variation_id)."""
    result = await session.execute(
        select(db.Variation)
        .where(db.Variation.parent_variation_id == variation_id)
        .order_by(db.Variation.created_at)
    )
    children: list[HistoryNode] = []
    for child in result.scalars().all():
        children.append(HistoryNode(
            variation_id=child.variation_id,
            parent_variation_id=child.parent_variation_id,
            commit_state_id=child.commit_state_id,
            created_at=child.created_at,
        ))
    return children
449
450
async def get_lineage(
    session: AsyncSession,
    variation_id: str,
) -> list[HistoryNode]:
    """Walk parent_variation_id chain from variation_id to root.

    Returns nodes in root-first order: [root, ..., target].
    """
    # Collected target-first while walking upward, flipped before returning.
    walked: list[HistoryNode] = []
    cursor: str | None = variation_id

    while cursor is not None:
        lookup = await session.execute(
            select(db.Variation).where(db.Variation.variation_id == cursor)
        )
        node = lookup.scalar_one_or_none()
        if node is None:
            # Unknown id or dangling parent reference: stop the walk here.
            break
        walked.append(HistoryNode(
            variation_id=node.variation_id,
            parent_variation_id=node.parent_variation_id,
            commit_state_id=node.commit_state_id,
            created_at=node.created_at,
        ))
        cursor = node.parent_variation_id

    return walked[::-1]
478
479
480 # ── Bulk queries (Phase 13) ───────────────────────────────────────────────
481
482
async def get_variations_for_project(
    session: AsyncSession,
    project_id: str,
) -> list[VariationSummary]:
    """Fetch all variations for a project in a single query.

    Eagerly loads phrases to extract affected region IDs.
    Returned in creation order (earliest first).
    """
    result = await session.execute(
        select(db.Variation)
        .options(selectinload(db.Variation.phrases))
        .where(db.Variation.project_id == project_id)
        .order_by(db.Variation.created_at)
    )
    return [
        VariationSummary(
            variation_id=v.variation_id,
            parent_variation_id=v.parent_variation_id,
            parent2_variation_id=v.parent2_variation_id,
            is_head=v.is_head,
            created_at=v.created_at,
            intent=v.intent,
            # De-duplicate region IDs across phrases; sort for stable output.
            affected_regions=tuple(sorted({p.region_id for p in v.phrases})),
        )
        for v in result.scalars().all()
    ]