db.py
python
| 1 | """Async database helpers for the Muse CLI commit pipeline. |
| 2 | |
| 3 | Provides: |
| 4 | - ``open_session()`` — async context manager that opens and commits a |
| 5 | standalone AsyncSession (for use in the CLI, outside FastAPI DI). |
| 6 | - CRUD helpers called by ``commands/commit.py``, ``commands/meter.py``, |
| 7 | and ``commands/read_tree.py``. |
| 8 | |
| 9 | The session factory created by ``open_session()`` reads DATABASE_URL |
| 10 | from ``maestro.config.settings`` — the same env var used by the main |
| 11 | FastAPI app. Inside Docker all containers have this set; outside Docker |
| 12 | users need to export it before running ``muse commit``. |
| 13 | """ |
| 14 | from __future__ import annotations |
| 15 | |
| 16 | import contextlib |
| 17 | import logging |
| 18 | from collections.abc import AsyncGenerator |
| 19 | |
| 20 | from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine |
| 21 | from sqlalchemy.future import select |
| 22 | |
| 23 | from maestro.config import settings |
| 24 | from maestro.muse_cli.models import MuseCliCommit, MuseCliObject, MuseCliSnapshot |
| 25 | |
| 26 | logger = logging.getLogger(__name__) |
| 27 | |
| 28 | |
@contextlib.asynccontextmanager
async def open_session(url: str | None = None) -> AsyncGenerator[AsyncSession, None]:
    """Yield a standalone async DB session for CLI commands (outside FastAPI DI).

    The session is committed when the ``async with`` body finishes cleanly.
    If the body, or the commit itself, raises an ``Exception``, the session is
    rolled back and the exception propagates. The engine created here is
    disposed in every case, so the CLI process does not keep connections open.

    ``url`` overrides ``settings.database_url`` (the ``DATABASE_URL`` env
    var); pass an explicit URL in tests.

    Raises:
        RuntimeError: when neither ``url`` nor ``settings.database_url``
            provides a database URL.
    """
    resolved_url = url if url else settings.database_url
    if not resolved_url:
        raise RuntimeError(
            "DATABASE_URL is not set. "
            "Run inside Docker or export DATABASE_URL before calling muse commit."
        )

    cli_engine = create_async_engine(resolved_url, echo=False)
    make_session = async_sessionmaker(bind=cli_engine, expire_on_commit=False)
    try:
        async with make_session() as db:
            # The commit lives inside the try so that a failing commit is
            # rolled back too, not just failures in the caller's body.
            try:
                yield db
                await db.commit()
            except Exception:
                await db.rollback()
                raise
    finally:
        await cli_engine.dispose()
| 57 | |
| 58 | |
async def upsert_object(session: AsyncSession, object_id: str, size_bytes: int) -> None:
    """Stage a :class:`MuseCliObject` row unless *object_id* is already stored.

    Objects are content-addressed, so an existing row with the same ID is
    left untouched and only a debug message is logged.
    """
    if await session.get(MuseCliObject, object_id) is not None:
        logger.debug("⚠️ Object %s already exists — skipped", object_id[:8])
        return
    session.add(MuseCliObject(object_id=object_id, size_bytes=size_bytes))
    logger.debug("✅ New object %s (%d bytes)", object_id[:8], size_bytes)
| 67 | |
| 68 | |
async def upsert_snapshot(
    session: AsyncSession, manifest: dict[str, str], snapshot_id: str
) -> MuseCliSnapshot:
    """Return the snapshot row for *snapshot_id*, staging a new one if absent.

    When a row already exists it is returned as-is; the supplied *manifest*
    is NOT compared or merged. Otherwise a new :class:`MuseCliSnapshot` is
    added to *session* (not yet flushed) and returned.
    """
    snapshot = await session.get(MuseCliSnapshot, snapshot_id)
    if snapshot is None:
        snapshot = MuseCliSnapshot(snapshot_id=snapshot_id, manifest=manifest)
        session.add(snapshot)
        logger.debug("✅ New snapshot %s (%d files)", snapshot_id[:8], len(manifest))
    else:
        logger.debug("⚠️ Snapshot %s already exists — skipped", snapshot_id[:8])
    return snapshot
| 81 | |
| 82 | |
async def insert_commit(session: AsyncSession, commit: MuseCliCommit) -> None:
    """Stage *commit* for insertion in *session*.

    Unlike the ``upsert_*`` helpers, this performs no existence check.
    Inserting a second commit with an already-stored ``commit_id`` is a
    programming error. The resulting failure (expected to be an
    ``IntegrityError``) surfaces when the session flushes or commits, not
    when this function is called.
    """
    session.add(commit)
    short_id = commit.commit_id[:8]
    logger.debug("✅ New commit %s branch=%r", short_id, commit.branch)
| 91 | |
| 92 | |
async def get_head_snapshot_id(
    session: AsyncSession, repo_id: str, branch: str
) -> str | None:
    """Return the ``snapshot_id`` of the newest commit on *branch* in *repo_id*.

    "Newest" means the greatest ``committed_at``. Returns ``None`` when the
    branch has no commits.
    """
    newest_commit_snapshot = (
        select(MuseCliCommit.snapshot_id)
        .where(MuseCliCommit.repo_id == repo_id, MuseCliCommit.branch == branch)
        .order_by(MuseCliCommit.committed_at.desc())
        .limit(1)
    )
    return (await session.execute(newest_commit_snapshot)).scalar_one_or_none()
| 105 | |
| 106 | |
async def get_commit_snapshot_manifest(
    session: AsyncSession, commit_id: str
) -> dict[str, str] | None:
    """Return a copy of the manifest of the snapshot attached to *commit_id*.

    The commit is loaded by primary key, then its snapshot by
    ``commit.snapshot_id``. If either row is missing, a warning is logged
    and ``None`` is returned; that should not happen in a consistent DB.
    """
    commit = await session.get(MuseCliCommit, commit_id)
    if commit is None:
        logger.warning("⚠️ Commit %s not found in DB", commit_id[:8])
        return None

    snap_id = commit.snapshot_id
    snapshot = await session.get(MuseCliSnapshot, snap_id)
    if snapshot is not None:
        return dict(snapshot.manifest)

    logger.warning(
        "⚠️ Snapshot %s referenced by commit %s not found in DB",
        snap_id[:8],
        commit_id[:8],
    )
    return None
| 129 | |
| 130 | |
async def resolve_commit_ref(
    session: AsyncSession,
    repo_id: str,
    branch: str,
    ref: str | None,
) -> MuseCliCommit | None:
    """Resolve a commit reference to a ``MuseCliCommit`` row within *repo_id*.

    *ref* may be:

    - ``None`` / ``"HEAD"`` (case-insensitive) — the most recent commit on
      *branch*.
    - A full commit SHA — exact primary-key match.
    - An abbreviated commit SHA — prefix match.

    Every lookup is scoped to *repo_id*: a commit stored under a different
    repository is never returned. When an abbreviated SHA matches several
    commits, the most recently committed one is returned, so the result is
    deterministic. An empty *ref* matches nothing.

    Returns ``None`` when no matching commit is found.
    """
    if ref is None or ref.upper() == "HEAD":
        result = await session.execute(
            select(MuseCliCommit)
            .where(MuseCliCommit.repo_id == repo_id, MuseCliCommit.branch == branch)
            .order_by(MuseCliCommit.committed_at.desc())
            .limit(1)
        )
        return result.scalar_one_or_none()

    if not ref:
        # An empty prefix would otherwise match every commit in the repo.
        return None

    # Exact match first (primary-key lookup), but only within this repo so
    # the exact and prefix paths agree on scope.
    commit = await session.get(MuseCliCommit, ref)
    if commit is not None and commit.repo_id == repo_id:
        return commit

    # Abbreviated SHA prefix match (scan required — acceptable for CLI use).
    result = await session.execute(
        select(MuseCliCommit)
        .where(
            MuseCliCommit.repo_id == repo_id,
            # autoescape: '%' and '_' in user input are matched literally
            # instead of acting as LIKE wildcards.
            MuseCliCommit.commit_id.startswith(ref, autoescape=True),
        )
        .order_by(MuseCliCommit.committed_at.desc())
        .limit(1)
    )
    return result.scalars().first()
| 168 | |
| 169 | |
async def set_commit_tempo_bpm(
    session: AsyncSession,
    commit_id: str,
    bpm: float,
) -> MuseCliCommit | None:
    """Record *bpm* under ``"tempo_bpm"`` in the commit's ``commit_metadata`` blob.

    The existing metadata is copied into a new dict before the key is set,
    so other annotations are kept. The copy is assigned back to the
    attribute, which makes the change visible to the session. Returns the
    updated ``MuseCliCommit``, or ``None`` when *commit_id* does not exist.
    """
    commit = await session.get(MuseCliCommit, commit_id)
    if commit is None:
        return None

    merged: dict[str, object] = dict(commit.commit_metadata or {})
    merged.update(tempo_bpm=bpm)
    commit.commit_metadata = merged
    session.add(commit)
    logger.debug("✅ Set tempo %.2f BPM on commit %s", bpm, commit_id[:8])
    return commit
| 189 | |
| 190 | |
async def get_commits_for_branch(
    session: AsyncSession,
    repo_id: str,
    branch: str,
) -> list[MuseCliCommit]:
    """List every commit on *branch* of *repo_id*, ordered newest ``committed_at`` first.

    Used by ``muse push`` to collect commits since the last known remote
    head. Because the list is newest-first, callers can slice from the
    front to get the delta since a known commit.
    """
    branch_history = (
        select(MuseCliCommit)
        .where(MuseCliCommit.repo_id == repo_id, MuseCliCommit.branch == branch)
        .order_by(MuseCliCommit.committed_at.desc())
    )
    rows = await session.execute(branch_history)
    return [*rows.scalars().all()]
| 208 | |
| 209 | |
async def get_all_object_ids(
    session: AsyncSession,
    repo_id: str,
) -> list[str]:
    """Return the sorted, de-duplicated object IDs referenced by this repo's snapshots.

    Used by ``muse pull`` to tell the Hub which objects we already have, so
    the Hub only sends the missing ones.

    A snapshot belongs to the repo when at least one commit with the given
    *repo_id* points at it. All such manifests are fetched in one query via
    ``snapshot_id IN (SELECT snapshot_id FROM commits WHERE repo_id = ...)``.
    This avoids issuing one ``session.get`` per commit row. The manifest
    values (object IDs) are then unioned. Snapshots with an empty or null
    manifest contribute nothing. Returns ``[]`` when the repo has no commits.
    """
    repo_snapshot_ids = select(MuseCliCommit.snapshot_id).where(
        MuseCliCommit.repo_id == repo_id
    )
    result = await session.execute(
        select(MuseCliSnapshot.manifest).where(
            MuseCliSnapshot.snapshot_id.in_(repo_snapshot_ids)
        )
    )

    object_ids: set[str] = set()
    for manifest in result.scalars():
        if manifest:
            object_ids.update(manifest.values())
    return sorted(object_ids)
| 238 | |
| 239 | |
def _str_field(data: dict[str, object], key: str) -> str:
    """Return ``data[key]`` as a string, or ``""`` when the key is missing or ``None``."""
    # A plain ``str(data.get(key, ""))`` would turn a JSON ``null`` into the
    # literal string "None", which then passes truthiness checks.
    value = data.get(key)
    return "" if value is None else str(value)


def _parse_pulled_committed_at(raw: str) -> datetime.datetime:
    """Parse a Hub ``committed_at`` ISO-8601 string, falling back to ``now(UTC)``.

    A trailing ``Z`` UTC designator is rewritten to ``+00:00``, because
    :meth:`datetime.datetime.fromisoformat` only accepts ``Z`` from Python
    3.11 onwards. Surrounding whitespace is ignored. Unparseable or empty
    input is logged, then replaced by the current UTC time.
    """
    import datetime

    text = raw.strip()
    if text.endswith("Z"):
        text = text[:-1] + "+00:00"
    try:
        return datetime.datetime.fromisoformat(text)
    except ValueError:
        logger.warning("⚠️ Unparseable committed_at %r — using current time", raw)
        return datetime.datetime.now(datetime.timezone.utc)


async def store_pulled_commit(
    session: AsyncSession,
    commit_data: dict[str, object],
) -> bool:
    """Persist a commit received from the Hub into local Postgres.

    Idempotent — silently skips if the commit already exists. Returns
    ``True`` if the row was newly inserted, and ``False`` if it already
    existed or the payload lacks a ``commit_id`` or ``snapshot_id``. Missing
    or ``null`` string fields are stored as ``""`` rather than ``"None"``.

    If the referenced snapshot row does not exist yet, an empty-manifest stub
    is inserted and flushed first. Objects are content-addressed, so the real
    manifest may arrive separately, or stay empty for Hub-side storage.

    The *commit_data* dict must contain the keys defined in
    :class:`~maestro.muse_cli.hub_client.PullCommitPayload`.
    """
    commit_id = _str_field(commit_data, "commit_id")
    if not commit_id:
        logger.warning("⚠️ store_pulled_commit: missing commit_id — skipping")
        return False

    existing = await session.get(MuseCliCommit, commit_id)
    if existing is not None:
        logger.debug("⚠️ Pulled commit %s already exists — skipped", commit_id[:8])
        return False

    snapshot_id = _str_field(commit_data, "snapshot_id")
    if not snapshot_id:
        # Without this guard, a stub snapshot with ID "" would be created.
        logger.warning(
            "⚠️ store_pulled_commit: commit %s has no snapshot_id — skipping",
            commit_id[:8],
        )
        return False

    branch = _str_field(commit_data, "branch")
    parent_commit_id_raw = commit_data.get("parent_commit_id")
    parent_commit_id: str | None = (
        str(parent_commit_id_raw) if parent_commit_id_raw is not None else None
    )
    metadata_raw = commit_data.get("metadata")
    commit_metadata: dict[str, object] | None = (
        dict(metadata_raw) if isinstance(metadata_raw, dict) else None
    )
    committed_at = _parse_pulled_committed_at(_str_field(commit_data, "committed_at"))

    if await session.get(MuseCliSnapshot, snapshot_id) is None:
        session.add(MuseCliSnapshot(snapshot_id=snapshot_id, manifest={}))
        await session.flush()

    session.add(
        MuseCliCommit(
            commit_id=commit_id,
            repo_id=_str_field(commit_data, "repo_id"),
            branch=branch,
            parent_commit_id=parent_commit_id,
            snapshot_id=snapshot_id,
            message=_str_field(commit_data, "message"),
            author=_str_field(commit_data, "author"),
            committed_at=committed_at,
            commit_metadata=commit_metadata,
        )
    )
    logger.debug("✅ Stored pulled commit %s branch=%r", commit_id[:8], branch)
    return True
| 308 | |
| 309 | |
async def store_pulled_object(
    session: AsyncSession,
    object_data: dict[str, object],
) -> bool:
    """Persist an object descriptor received from the Hub into local Postgres.

    Idempotent — silently skips if the object already exists. Returns
    ``True`` if the row was newly inserted, and ``False`` if it already
    existed or *object_data* has no usable ``object_id``.

    A missing or JSON ``null`` ``object_id`` is treated as absent instead of
    being stringified to ``"None"``. ``size_bytes`` is used when it is an int
    or a finite float (truncated to ``int``). Anything else, including a
    missing value, a non-number, or NaN/infinity, is recorded as ``0``.
    """
    import math

    object_id_raw = object_data.get("object_id")
    object_id = "" if object_id_raw is None else str(object_id_raw)
    if not object_id:
        logger.warning("⚠️ store_pulled_object: missing object_id — skipping")
        return False

    size_raw = object_data.get("size_bytes", 0)
    if isinstance(size_raw, int) or (
        # int() on NaN/inf would raise; json.loads can produce both.
        isinstance(size_raw, float) and math.isfinite(size_raw)
    ):
        size_bytes = int(size_raw)
    else:
        size_bytes = 0

    existing = await session.get(MuseCliObject, object_id)
    if existing is not None:
        logger.debug("⚠️ Pulled object %s already exists — skipped", object_id[:8])
        return False

    session.add(MuseCliObject(object_id=object_id, size_bytes=size_bytes))
    logger.debug("✅ Stored pulled object %s (%d bytes)", object_id[:8], size_bytes)
    return True
| 335 | |
| 336 | |
async def find_commits_by_prefix(
    session: AsyncSession,
    prefix: str,
) -> list[MuseCliCommit]:
    """Return all commits whose ``commit_id`` starts with *prefix*, newest first.

    Used by commands that accept a short commit ID (e.g. ``muse export``,
    ``muse open``, ``muse play``) to resolve a user-supplied prefix to a
    full commit record before DB lookups.

    ``%`` and ``_`` in *prefix* are matched literally (LIKE auto-escaping),
    so user input cannot widen the match with SQL wildcards. An empty
    *prefix* returns ``[]`` instead of every commit in the database.
    Results are ordered by ``committed_at`` descending, so the order is
    deterministic.
    """
    if not prefix:
        return []
    result = await session.execute(
        select(MuseCliCommit)
        .where(MuseCliCommit.commit_id.startswith(prefix, autoescape=True))
        .order_by(MuseCliCommit.committed_at.desc())
    )
    return list(result.scalars().all())
| 351 | |
| 352 | |
async def get_head_snapshot_manifest(
    session: AsyncSession, repo_id: str, branch: str
) -> dict[str, str] | None:
    """Return a copy of the manifest of the newest commit on *branch*, or ``None``.

    The HEAD snapshot ID comes from :func:`get_head_snapshot_id`, and the
    snapshot is then loaded by primary key. ``None`` is returned when the
    branch has no commits. It is also returned, with a warning, when the
    referenced snapshot row is missing, which should not happen in a
    consistent database.
    """
    head_snapshot_id = await get_head_snapshot_id(session, repo_id, branch)
    if head_snapshot_id is None:
        return None

    head_snapshot = await session.get(MuseCliSnapshot, head_snapshot_id)
    if head_snapshot is not None:
        return dict(head_snapshot.manifest)

    logger.warning("⚠️ Snapshot %s referenced by HEAD not found in DB", head_snapshot_id[:8])
    return None
| 371 | |
| 372 | |
async def get_commit_extra_metadata(
    session: AsyncSession, commit_id: str
) -> dict[str, object] | None:
    """Return a shallow copy of the ``commit_metadata`` blob for *commit_id*.

    ``None`` is returned when the commit does not exist, or when its
    metadata is empty or null (the column is nullable).
    """
    commit = await session.get(MuseCliCommit, commit_id)
    metadata = None if commit is None else commit.commit_metadata
    if not metadata:
        return None
    return dict(metadata)
| 385 | |
| 386 | |
async def set_commit_extra_metadata_key(
    session: AsyncSession,
    commit_id: str,
    key: str,
    value: object,
) -> bool:
    """Store ``key = value`` in the ``commit_metadata`` blob of *commit_id*.

    The existing metadata is copied into a new dict (an empty one if the
    column is empty or null), the key is set, and the result is assigned
    back. Returns ``True`` on success, and ``False`` with a warning when
    *commit_id* is not found.

    Nothing is committed here. The caller's session must be committed;
    ``open_session()`` does this on clean exit.
    """
    commit = await session.get(MuseCliCommit, commit_id)
    if commit is None:
        logger.warning("⚠️ Commit %s not found — cannot set metadata", commit_id[:8])
        return False

    current = commit.commit_metadata
    updated: dict[str, object] = {} if not current else dict(current)
    updated.update({key: value})
    commit.commit_metadata = updated

    # Explicitly mark the JSON column dirty so SQLAlchemy is guaranteed to
    # include it in the next flush.
    from sqlalchemy.orm.attributes import flag_modified
    flag_modified(commit, "commit_metadata")
    logger.debug("✅ Set %s=%r on commit %s", key, value, commit_id[:8])
    return True