cgcardona / muse public
db.py python
412 lines 14.6 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """Async database helpers for the Muse CLI commit pipeline.
2
3 Provides:
4 - ``open_session()`` — async context manager that opens and commits a
5 standalone AsyncSession (for use in the CLI, outside FastAPI DI).
6 - CRUD helpers called by ``commands/commit.py``, ``commands/meter.py``,
7 and ``commands/read_tree.py``.
8
9 The session factory created by ``open_session()`` reads DATABASE_URL
10 from ``maestro.config.settings`` — the same env var used by the main
11 FastAPI app. Inside Docker all containers have this set; outside Docker
12 users need to export it before running ``muse commit``.
13 """
14 from __future__ import annotations
15
16 import contextlib
17 import logging
18 from collections.abc import AsyncGenerator
19
20 from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
21 from sqlalchemy.future import select
22
23 from maestro.config import settings
24 from maestro.muse_cli.models import MuseCliCommit, MuseCliObject, MuseCliSnapshot
25
26 logger = logging.getLogger(__name__)
27
28
@contextlib.asynccontextmanager
async def open_session(url: str | None = None) -> AsyncGenerator[AsyncSession, None]:
    """Yield a standalone ``AsyncSession`` for CLI commands.

    The session is committed when the ``async with`` body finishes cleanly.
    If the body (or the commit itself) raises an ``Exception``, the session
    is rolled back and the exception is re-raised. The engine created for
    this session is disposed on the way out in every case.

    *url* overrides the connection string; when it is falsy,
    ``settings.database_url`` (the ``DATABASE_URL`` env var) is used.

    Raises ``RuntimeError`` when no database URL is available.
    """
    target_url = url if url else settings.database_url
    if not target_url:
        raise RuntimeError(
            "DATABASE_URL is not set. "
            "Run inside Docker or export DATABASE_URL before calling muse commit."
        )

    engine = create_async_engine(target_url, echo=False)
    make_session = async_sessionmaker(bind=engine, expire_on_commit=False)
    try:
        async with make_session() as session:
            try:
                yield session
                # The commit is deliberately inside the ``try`` so that a
                # failing commit is rolled back like any other error.
                await session.commit()
            except Exception:
                await session.rollback()
                raise
    finally:
        # Release pooled connections so the CLI process can exit promptly.
        await engine.dispose()
57
58
async def upsert_object(session: AsyncSession, object_id: str, size_bytes: int) -> None:
    """Stage a ``MuseCliObject`` row for *object_id* unless one already exists.

    The row is looked up by primary key first; an existing row is left
    untouched (objects are content-addressed), otherwise a new row carrying
    *size_bytes* is added to *session*. Nothing is flushed or committed here.
    """
    if await session.get(MuseCliObject, object_id) is not None:
        logger.debug("⚠️ Object %s already exists — skipped", object_id[:8])
        return

    session.add(MuseCliObject(object_id=object_id, size_bytes=size_bytes))
    logger.debug("✅ New object %s (%d bytes)", object_id[:8], size_bytes)
67
68
async def upsert_snapshot(
    session: AsyncSession, manifest: dict[str, str], snapshot_id: str
) -> MuseCliSnapshot:
    """Return the ``MuseCliSnapshot`` for *snapshot_id*, creating it if needed.

    When a row with this primary key already exists it is returned as-is and
    *manifest* is ignored. Otherwise a new row holding *manifest* is added to
    *session* (not flushed) and returned.
    """
    snapshot = await session.get(MuseCliSnapshot, snapshot_id)
    if snapshot is None:
        snapshot = MuseCliSnapshot(snapshot_id=snapshot_id, manifest=manifest)
        session.add(snapshot)
        logger.debug("✅ New snapshot %s (%d files)", snapshot_id[:8], len(manifest))
    else:
        logger.debug("⚠️ Snapshot %s already exists — skipped", snapshot_id[:8])
    return snapshot
81
82
async def insert_commit(session: AsyncSession, commit: MuseCliCommit) -> None:
    """Stage *commit* for insertion in *session*.

    No duplicate check is performed: inserting the same ``commit_id`` twice
    is a programming error. This function only calls ``session.add``, so the
    resulting ``IntegrityError`` would surface when the session is flushed or
    committed rather than here.
    """
    session.add(commit)
    short_id = commit.commit_id[:8]
    logger.debug("✅ New commit %s branch=%r", short_id, commit.branch)
91
92
async def get_head_snapshot_id(
    session: AsyncSession, repo_id: str, branch: str
) -> str | None:
    """Return the ``snapshot_id`` of the newest commit on *branch*, or ``None``.

    "Newest" means the highest ``committed_at`` among commits matching both
    *repo_id* and *branch*. ``None`` is returned when there is no such commit.
    """
    head_query = (
        select(MuseCliCommit.snapshot_id)
        .where(MuseCliCommit.repo_id == repo_id, MuseCliCommit.branch == branch)
        .order_by(MuseCliCommit.committed_at.desc())
        .limit(1)
    )
    outcome = await session.execute(head_query)
    return outcome.scalar_one_or_none()
105
106
async def get_commit_snapshot_manifest(
    session: AsyncSession, commit_id: str
) -> dict[str, str] | None:
    """Return a copy of the manifest of the snapshot attached to *commit_id*.

    The :class:`MuseCliCommit` row is loaded by primary key, then the
    :class:`MuseCliSnapshot` it points at. If either row is missing a warning
    is logged and ``None`` is returned (not expected in a consistent DB).
    """
    commit = await session.get(MuseCliCommit, commit_id)
    if commit is None:
        logger.warning("⚠️ Commit %s not found in DB", commit_id[:8])
        return None

    snapshot_id = commit.snapshot_id
    snapshot = await session.get(MuseCliSnapshot, snapshot_id)
    if snapshot is not None:
        # Hand back a fresh dict so callers cannot mutate the ORM attribute.
        return dict(snapshot.manifest)

    logger.warning(
        "⚠️ Snapshot %s referenced by commit %s not found in DB",
        snapshot_id[:8],
        commit_id[:8],
    )
    return None
129
130
async def resolve_commit_ref(
    session: AsyncSession,
    repo_id: str,
    branch: str,
    ref: str | None,
) -> MuseCliCommit | None:
    """Resolve a commit reference to a ``MuseCliCommit`` row.

    *ref* may be:

    - ``None`` / ``"HEAD"`` (case-insensitive) — returns the commit with the
      latest ``committed_at`` on *branch* in *repo_id*.
    - A full commit ID — looked up by primary key.
    - An abbreviated commit ID — matched as a literal prefix of ``commit_id``
      among the commits of *repo_id*.

    Returns ``None`` when no matching commit is found, including when *ref*
    is an empty string (an empty prefix would otherwise match every commit).
    """
    if ref is None or ref.upper() == "HEAD":
        result = await session.execute(
            select(MuseCliCommit)
            .where(MuseCliCommit.repo_id == repo_id, MuseCliCommit.branch == branch)
            .order_by(MuseCliCommit.committed_at.desc())
            .limit(1)
        )
        return result.scalar_one_or_none()

    if not ref:
        # An empty ref identifies nothing; do not let it degrade into a
        # match-everything prefix scan that returns an arbitrary commit.
        return None

    # Try exact match first.
    # NOTE(review): the primary-key lookup is not filtered by repo_id, unlike
    # the prefix scan below — confirm that is intended.
    commit = await session.get(MuseCliCommit, ref)
    if commit is not None:
        return commit

    # Abbreviated SHA prefix match (scan required — acceptable for CLI use).
    # ``autoescape=True`` makes "%" and "_" in *ref* match literally instead
    # of acting as LIKE wildcards.
    # NOTE(review): when several commits share the prefix, the first row the
    # database returns wins; there is no ordering or ambiguity check.
    result = await session.execute(
        select(MuseCliCommit).where(
            MuseCliCommit.repo_id == repo_id,
            MuseCliCommit.commit_id.startswith(ref, autoescape=True),
        )
    )
    return result.scalars().first()
168
169
async def set_commit_tempo_bpm(
    session: AsyncSession,
    commit_id: str,
    bpm: float,
) -> MuseCliCommit | None:
    """Record *bpm* under the ``"tempo_bpm"`` key of a commit's metadata.

    The existing ``commit_metadata`` dict (if any) is copied and the key is
    set on the copy, so other annotations are preserved. The copy is then
    assigned back to the row. Returns the updated ``MuseCliCommit``, or
    ``None`` when *commit_id* does not exist.
    """
    commit = await session.get(MuseCliCommit, commit_id)
    if commit is not None:
        merged: dict[str, object] = dict(commit.commit_metadata or {})
        merged.update(tempo_bpm=bpm)
        # Assign a new dict object rather than mutating the stored one.
        commit.commit_metadata = merged
        session.add(commit)
        logger.debug("✅ Set tempo %.2f BPM on commit %s", bpm, commit_id[:8])
    return commit
189
190
async def get_commits_for_branch(
    session: AsyncSession,
    repo_id: str,
    branch: str,
) -> list[MuseCliCommit]:
    """List every commit on *branch* of *repo_id*, ordered newest first.

    ``muse push`` uses this to gather commits since the last known remote
    head; because the newest commit comes first, callers can slice from the
    front of the list to obtain that delta.
    """
    branch_query = (
        select(MuseCliCommit)
        .where(MuseCliCommit.repo_id == repo_id, MuseCliCommit.branch == branch)
        .order_by(MuseCliCommit.committed_at.desc())
    )
    rows = (await session.execute(branch_query)).scalars()
    return list(rows.all())
208
209
async def get_all_object_ids(
    session: AsyncSession,
    repo_id: str,
) -> list[str]:
    """Return all object IDs referenced by any snapshot in this repo.

    Used by ``muse pull`` to tell the Hub which objects we already have so
    the Hub only sends the missing ones.

    A snapshot belongs to the repo when at least one commit with *repo_id*
    points at it. The result is the sorted, de-duplicated union of the values
    of those snapshots' manifests; it is empty when the repo has no commits.
    """
    # Sub-select of the snapshot IDs used by this repo's commits. Feeding it
    # to IN (...) loads every relevant snapshot in one round-trip instead of
    # issuing one lookup per commit, and de-duplicates snapshots shared by
    # several commits.
    repo_snapshot_ids = select(MuseCliCommit.snapshot_id).where(
        MuseCliCommit.repo_id == repo_id
    )
    result = await session.execute(
        select(MuseCliSnapshot).where(
            MuseCliSnapshot.snapshot_id.in_(repo_snapshot_ids)
        )
    )

    object_ids: set[str] = set()
    for snapshot in result.scalars():
        # Stub snapshots may carry an empty (or null) manifest — skip those.
        if snapshot.manifest:
            object_ids.update(snapshot.manifest.values())

    return sorted(object_ids)
238
239
async def store_pulled_commit(
    session: AsyncSession,
    commit_data: dict[str, object],
) -> bool:
    """Persist a commit received from the Hub into the local database.

    Idempotent — silently skips if the commit already exists. Returns
    ``True`` if the row was newly staged for insertion, ``False`` if it
    already existed or if *commit_data* carries no ``commit_id``.

    The *commit_data* dict must contain the keys defined in
    :class:`~maestro.muse_cli.hub_client.PullCommitPayload`. The keys read
    here are ``commit_id``, ``snapshot_id``, ``repo_id``, ``branch``,
    ``message``, ``author``, ``committed_at``, ``parent_commit_id`` and
    ``metadata``.

    ``committed_at`` is parsed as an ISO-8601 string; a trailing ``Z`` is
    accepted as UTC. If it cannot be parsed, a warning is logged and the
    current UTC time is stored instead.

    If the referenced snapshot row does not exist yet, a stub row with an
    empty manifest is created and flushed before the commit is added.
    """
    import datetime

    commit_id = str(commit_data.get("commit_id", ""))
    if not commit_id:
        logger.warning("⚠️ store_pulled_commit: missing commit_id — skipping")
        return False

    existing = await session.get(MuseCliCommit, commit_id)
    if existing is not None:
        logger.debug("⚠️ Pulled commit %s already exists — skipped", commit_id[:8])
        return False

    snapshot_id = str(commit_data.get("snapshot_id", ""))
    branch = str(commit_data.get("branch", ""))
    message = str(commit_data.get("message", ""))
    author = str(commit_data.get("author", ""))
    committed_at_raw = str(commit_data.get("committed_at", ""))
    parent_commit_id_raw = commit_data.get("parent_commit_id")
    parent_commit_id: str | None = (
        str(parent_commit_id_raw) if parent_commit_id_raw is not None else None
    )
    metadata_raw = commit_data.get("metadata")
    commit_metadata: dict[str, object] | None = (
        dict(metadata_raw) if isinstance(metadata_raw, dict) else None
    )

    # ``datetime.fromisoformat`` only understands a trailing "Z" from Python
    # 3.11 onwards. Rewrite it as an explicit UTC offset so such timestamps
    # are not silently replaced by "now" on older interpreters.
    iso_text = committed_at_raw
    if iso_text.endswith(("Z", "z")):
        iso_text = iso_text[:-1] + "+00:00"
    try:
        committed_at = datetime.datetime.fromisoformat(iso_text)
    except ValueError:
        # Best-effort fallback, but make the substitution visible.
        logger.warning(
            "⚠️ store_pulled_commit: cannot parse committed_at %r for commit %s"
            " — using current UTC time",
            committed_at_raw,
            commit_id[:8],
        )
        committed_at = datetime.datetime.now(datetime.timezone.utc)

    # Ensure the snapshot row exists (as a stub if not present — objects are
    # content-addressed so the manifest may arrive separately or be empty for
    # Hub-side storage).
    # NOTE(review): a payload without ``snapshot_id`` yields a stub whose ID
    # is the empty string — confirm whether that should be rejected instead.
    existing_snap = await session.get(MuseCliSnapshot, snapshot_id)
    if existing_snap is None:
        session.add(MuseCliSnapshot(snapshot_id=snapshot_id, manifest={}))
        # Flush so the stub is written before the commit that references it.
        await session.flush()

    new_commit = MuseCliCommit(
        commit_id=commit_id,
        repo_id=str(commit_data.get("repo_id", "")),
        branch=branch,
        parent_commit_id=parent_commit_id,
        snapshot_id=snapshot_id,
        message=message,
        author=author,
        committed_at=committed_at,
        commit_metadata=commit_metadata,
    )
    session.add(new_commit)
    logger.debug("✅ Stored pulled commit %s branch=%r", commit_id[:8], branch)
    return True
308
309
async def store_pulled_object(
    session: AsyncSession,
    object_data: dict[str, object],
) -> bool:
    """Persist an object descriptor received from the Hub.

    Reads ``object_id`` and ``size_bytes`` from *object_data*. A missing or
    empty ``object_id`` is logged and skipped. A ``size_bytes`` that is not
    an ``int`` or ``float`` is stored as ``0``.

    Idempotent: returns ``True`` when a new row was added to *session*, and
    ``False`` when the object already existed or had no ID.
    """
    object_id = str(object_data.get("object_id", ""))
    if not object_id:
        logger.warning("⚠️ store_pulled_object: missing object_id — skipping")
        return False

    size_raw = object_data.get("size_bytes", 0)
    size_bytes = 0
    if isinstance(size_raw, (int, float)):
        size_bytes = int(size_raw)

    if await session.get(MuseCliObject, object_id) is None:
        session.add(MuseCliObject(object_id=object_id, size_bytes=size_bytes))
        logger.debug("✅ Stored pulled object %s (%d bytes)", object_id[:8], size_bytes)
        return True

    logger.debug("⚠️ Pulled object %s already exists — skipped", object_id[:8])
    return False
335
336
async def find_commits_by_prefix(
    session: AsyncSession,
    prefix: str,
) -> list[MuseCliCommit]:
    """Return all commits whose ``commit_id`` starts with *prefix*.

    Used by commands that accept a short commit ID (e.g. ``muse export``,
    ``muse open``, ``muse play``) to resolve a user-supplied prefix to a
    full commit record before DB lookups.

    *prefix* is matched literally: ``%`` and ``_`` are escaped, so they do
    not act as SQL ``LIKE`` wildcards. The search is not restricted to a
    repo or branch, and an empty *prefix* matches every commit.
    """
    result = await session.execute(
        select(MuseCliCommit).where(
            # autoescape keeps user-supplied "%" / "_" from widening the match.
            MuseCliCommit.commit_id.startswith(prefix, autoescape=True)
        )
    )
    return list(result.scalars().all())
351
352
async def get_head_snapshot_manifest(
    session: AsyncSession, repo_id: str, branch: str
) -> dict[str, str] | None:
    """Return a copy of the manifest of the newest commit on *branch*.

    The head ``snapshot_id`` is resolved via :func:`get_head_snapshot_id`
    and the matching :class:`MuseCliSnapshot` row is loaded by primary key.
    ``None`` is returned when the branch has no commits, or (with a warning)
    when the snapshot row is missing — which should not happen in a
    consistent database.
    """
    head_snapshot_id = await get_head_snapshot_id(session, repo_id, branch)
    if head_snapshot_id is None:
        return None

    snapshot = await session.get(MuseCliSnapshot, head_snapshot_id)
    if snapshot is not None:
        return dict(snapshot.manifest)

    logger.warning(
        "⚠️ Snapshot %s referenced by HEAD not found in DB", head_snapshot_id[:8]
    )
    return None
371
372
async def get_commit_extra_metadata(
    session: AsyncSession, commit_id: str
) -> dict[str, object] | None:
    """Return a copy of the ``commit_metadata`` blob stored on *commit_id*.

    ``None`` is returned when the commit does not exist, and also when its
    metadata is unset or empty (the value is falsy).
    """
    commit = await session.get(MuseCliCommit, commit_id)
    if commit is None or not commit.commit_metadata:
        return None
    return dict(commit.commit_metadata)
385
386
async def set_commit_extra_metadata_key(
    session: AsyncSession,
    commit_id: str,
    key: str,
    value: object,
) -> bool:
    """Store *value* under *key* in the ``commit_metadata`` of *commit_id*.

    The current metadata (or an empty dict when none is stored) is copied,
    *key* is set on the copy, and the copy is assigned back to the row, so
    other keys are preserved. Returns ``True`` on success and ``False``
    (after logging a warning) when the commit does not exist.

    Nothing is committed here; the caller owns the transaction
    (``open_session()`` commits on clean exit).
    """
    from sqlalchemy.orm.attributes import flag_modified

    commit = await session.get(MuseCliCommit, commit_id)
    if commit is None:
        logger.warning("⚠️ Commit %s not found — cannot set metadata", commit_id[:8])
        return False

    merged: dict[str, object] = {}
    if commit.commit_metadata:
        merged = dict(commit.commit_metadata)
    merged[key] = value
    commit.commit_metadata = merged

    # Explicitly mark the JSON column dirty so the change is flushed.
    flag_modified(commit, "commit_metadata")
    logger.debug("✅ Set %s=%r on commit %s", key, value, commit_id[:8])
    return True