cgcardona / muse public
merge_engine.py python
380 lines 12.8 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """Muse VCS merge engine — fast-forward and 3-way path-level merge.
2
3 Public API
4 ----------
5 Pure functions (no I/O):
6
7 - :func:`diff_snapshots` — paths that changed between two snapshot manifests.
8 - :func:`detect_conflicts` — paths changed on *both* branches since the base.
9 - :func:`apply_merge` — build merged manifest for a conflict-free 3-way merge.
10
11 Async helpers (require a DB session):
12
13 - :func:`find_merge_base` — lowest common ancestor (LCA) of two commits.
14
15 Filesystem helpers:
16
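- :func:`read_merge_state` — detect and load an in-progress merge.
- :func:`write_merge_state` — persist conflict state before exiting.
- :func:`clear_merge_state` — remove the state file once the merge is finished.
- :func:`apply_resolution` — restore one stored object into ``muse-work/``.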
19
20 ``MERGE_STATE.json`` schema
21 ---------------------------
22
23 .. code-block:: json
24
25 {
26 "base_commit": "abc123...",
27 "ours_commit": "def456...",
28 "theirs_commit": "789abc...",
29 "conflict_paths": ["beat.mid", "lead.mp3"],
30 "other_branch": "feature/experiment"
31 }
32
33 ``other_branch`` is optional; all other fields are required when conflicts exist.
34 """
35 from __future__ import annotations
36
37 import json
38 import logging
39 import pathlib
40 from collections import deque
41 from dataclasses import dataclass, field
42 from typing import TYPE_CHECKING
43
44 if TYPE_CHECKING:
45 from sqlalchemy.ext.asyncio import AsyncSession
46
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# File name of the merge-state marker; the helpers below look for it under ``<root>/.muse/``.
_MERGE_STATE_FILENAME = "MERGE_STATE.json"
50
51
52 # ---------------------------------------------------------------------------
53 # MergeState dataclass
54 # ---------------------------------------------------------------------------
55
56
@dataclass(frozen=True)
class MergeState:
    """Snapshot of a merge that stopped with conflicts still unresolved.

    Instances are frozen: fields cannot be reassigned after construction.
    Every field has a default, so a state can be built from partial data.

    Attributes:
        conflict_paths: POSIX-style relative paths of the files still in conflict.
        base_commit: ID of the merge base (common ancestor), if known.
        ours_commit: ID of the commit HEAD pointed at when the merge started.
        theirs_commit: ID of the commit being merged in.
        other_branch: Name of the branch being merged in, when it was recorded.
    """

    # ``default_factory`` gives every instance its own list object.
    conflict_paths: list[str] = field(default_factory=list)
    base_commit: str | None = None
    ours_commit: str | None = None
    theirs_commit: str | None = None
    other_branch: str | None = None
74
75
76 # ---------------------------------------------------------------------------
77 # Filesystem helpers
78 # ---------------------------------------------------------------------------
79
80
def read_merge_state(root: pathlib.Path) -> MergeState | None:
    """Return :class:`MergeState` if a merge is in progress, otherwise ``None``.

    Reads ``.muse/MERGE_STATE.json`` from *root*. ``None`` is returned when the
    file does not exist (no in-progress merge) and, after logging a warning,
    when the file cannot be read, is not valid UTF-8/JSON, or does not contain
    a JSON object at the top level.

    Fields that are missing *or* explicitly ``null`` in the file map to
    ``None``; any other value is converted with :func:`str`. A
    ``conflict_paths`` value that is not a list is treated as empty.

    Args:
        root: The repository root directory (the directory containing ``.muse/``).

    Returns:
        A :class:`MergeState` instance describing the in-progress merge, or
        ``None`` if no merge is in progress or the state file is unusable.
    """
    merge_state_path = root / ".muse" / _MERGE_STATE_FILENAME
    if not merge_state_path.exists():
        return None

    try:
        # Explicit encoding: JSON is UTF-8 by specification, and relying on the
        # locale default makes the result platform-dependent.
        data = json.loads(merge_state_path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as exc:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        logger.warning("⚠️ Failed to read %s: %s", _MERGE_STATE_FILENAME, exc)
        return None

    if not isinstance(data, dict):
        # Valid JSON, but e.g. a list or a bare string: there are no fields to read.
        logger.warning(
            "⚠️ Failed to read %s: %s",
            _MERGE_STATE_FILENAME,
            "top-level JSON value is not an object",
        )
        return None

    raw_conflicts = data.get("conflict_paths", [])
    conflict_paths: list[str] = (
        [str(c) for c in raw_conflicts] if isinstance(raw_conflicts, list) else []
    )

    def _str_or_none(key: str) -> str | None:
        # A JSON ``null`` must stay ``None`` rather than become the string "None".
        value = data.get(key)
        return None if value is None else str(value)

    return MergeState(
        conflict_paths=conflict_paths,
        base_commit=_str_or_none("base_commit"),
        ours_commit=_str_or_none("ours_commit"),
        theirs_commit=_str_or_none("theirs_commit"),
        other_branch=_str_or_none("other_branch"),
    )
119
120
def write_merge_state(
    root: pathlib.Path,
    *,
    base_commit: str,
    ours_commit: str,
    theirs_commit: str,
    conflict_paths: list[str],
    other_branch: str | None = None,
) -> None:
    """Persist the in-progress merge as ``.muse/MERGE_STATE.json`` under *root*.

    The file is a JSON object (2-space indent) holding the three commit IDs
    and the conflict paths in sorted order. ``other_branch`` is only written
    when a value is supplied.

    Args:
        root: Repository root (directory containing ``.muse/``).
        base_commit: Commit ID of the merge base (LCA).
        ours_commit: Commit ID of HEAD at merge time.
        theirs_commit: Commit ID of the branch being merged in.
        conflict_paths: POSIX paths that still have unresolved conflicts.
        other_branch: Human-readable name of the branch being merged in.
    """
    payload: dict[str, object] = dict(
        base_commit=base_commit,
        ours_commit=ours_commit,
        theirs_commit=theirs_commit,
        conflict_paths=sorted(conflict_paths),
    )
    if other_branch is not None:
        payload["other_branch"] = other_branch

    serialized = json.dumps(payload, indent=2)
    state_file = root.joinpath(".muse", _MERGE_STATE_FILENAME)
    state_file.write_text(serialized)

    conflict_count = len(conflict_paths)
    logger.info("✅ Wrote MERGE_STATE.json with %d conflict(s)", conflict_count)
151
152
def clear_merge_state(root: pathlib.Path) -> None:
    """Remove ``.muse/MERGE_STATE.json`` after a successful merge or resolution.

    Does nothing when the file is already absent. The debug message is logged
    only when a file was actually removed.

    Args:
        root: Repository root (directory containing ``.muse/``).
    """
    merge_state_path = root / ".muse" / _MERGE_STATE_FILENAME
    try:
        # Attempt the removal directly instead of checking ``exists()`` first:
        # the file could vanish between the check and the unlink.
        merge_state_path.unlink()
    except FileNotFoundError:
        return
    logger.debug("✅ Cleared MERGE_STATE.json")
159
160
def apply_resolution(
    root: pathlib.Path,
    rel_path: str,
    object_id: str,
) -> None:
    """Write the stored object *object_id* to ``muse-work/<rel_path>`` under *root*.

    The object's bytes are fetched with ``read_object`` and written to the
    destination, creating missing parent directories first. Callers therefore
    do not need to know how the object store is laid out on disk.

    Args:
        root: Repository root (directory containing ``.muse/``).
        rel_path: POSIX path relative to ``muse-work/``.
        object_id: sha256 hex digest of the desired object content.

    Raises:
        FileNotFoundError: If ``read_object`` returns ``None`` for *object_id*,
            i.e. the object is not available in the local store. The caller
            is expected to turn this into a user-friendly error.
    """
    from maestro.muse_cli.object_store import read_object

    payload = read_object(root, object_id)
    if payload is None:
        message = f"Object {object_id[:8]} for '{rel_path}' not found in local store."
        raise FileNotFoundError(message)

    # NOTE(review): rel_path is joined as-is; an absolute path or ``..`` segments
    # would land outside ``muse-work/`` — confirm callers validate it.
    target = root.joinpath("muse-work", rel_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_bytes(payload)
    logger.debug("✅ Restored '%s' from object %s", rel_path, object_id[:8])
193
194
def is_conflict_resolved(merge_state: MergeState, rel_path: str) -> bool:
    """Tell whether *rel_path* is absent from the conflicts of *merge_state*.

    A path counts as resolved once it is no longer listed in
    ``merge_state.conflict_paths``. Checking before marking a path resolved
    lets the caller detect a double-resolve attempt.

    Args:
        merge_state: The current in-progress merge state.
        rel_path: POSIX path to check (relative to ``muse-work/``).

    Returns:
        ``True`` when the path is not (or no longer) in conflict, ``False``
        while it is still listed as conflicting.
    """
    still_conflicting = rel_path in merge_state.conflict_paths
    return not still_conflicting
209
210
211 # ---------------------------------------------------------------------------
212 # Pure merge functions (no I/O, no DB)
213 # ---------------------------------------------------------------------------
214
215
def diff_snapshots(
    base_manifest: dict[str, str],
    other_manifest: dict[str, str],
) -> set[str]:
    """Compute which paths differ between two snapshot manifests.

    A path is reported when it is:

    - **added** — only *other_manifest* has it.
    - **deleted** — only *base_manifest* has it.
    - **modified** — both have it, but the ``object_id`` values differ.

    Args:
        base_manifest: ``{path: object_id}`` for the common ancestor snapshot.
        other_manifest: ``{path: object_id}`` for the branch snapshot.

    Returns:
        Set of POSIX paths that changed; empty when the manifests are equal.
    """
    # Paths present on exactly one side are the additions and deletions.
    differing: set[str] = set(base_manifest.keys() ^ other_manifest.keys())

    # Paths present on both sides count only when their object IDs differ.
    for path, base_object_id in base_manifest.items():
        if path in other_manifest and base_object_id != other_manifest[path]:
            differing.add(path)

    return differing
244
245
def detect_conflicts(
    ours_changed: set[str],
    theirs_changed: set[str],
) -> set[str]:
    """Compute the paths that both branches changed since the merge base.

    The result is the intersection of the two change sets. Only path names
    are compared; file contents are not inspected here. What to do with the
    conflicts (record them in ``MERGE_STATE.json`` and stop, or pick one
    side) is left to the caller.

    Args:
        ours_changed: Paths changed on the current branch since the base.
        theirs_changed: Paths changed on the target branch since the base.

    Returns:
        Set of conflicting POSIX paths; empty when the change sets are disjoint.
    """
    changed_on_both_sides = ours_changed & theirs_changed
    return changed_on_both_sides
264
265
def apply_merge(
    base_manifest: dict[str, str],
    ours_manifest: dict[str, str],
    theirs_manifest: dict[str, str],
    ours_changed: set[str],
    theirs_changed: set[str],
    conflict_paths: set[str],
) -> dict[str, str]:
    """Assemble the merged snapshot manifest of a 3-way merge.

    The result starts as a copy of *base_manifest*; the inputs are not
    modified. Each side's non-conflicting changes are then layered on top,
    ours first and theirs second:

    - a changed path still present in that side's manifest takes that side's
      ``object_id``;
    - a changed path missing from that side's manifest (a deletion) is removed.

    Paths listed in *conflict_paths* are skipped on both sides, so they keep
    whatever entry the base had (or stay absent if the base lacked them).
    If a path is in both change sets but not in *conflict_paths*, the theirs
    version wins because it is applied last.

    Args:
        base_manifest: ``{path: object_id}`` for the common ancestor.
        ours_manifest: ``{path: object_id}`` for the current branch HEAD.
        theirs_manifest: ``{path: object_id}`` for the target branch HEAD.
        ours_changed: Paths changed on the current branch since base.
        theirs_changed: Paths changed on the target branch since base.
        conflict_paths: Paths with conflicts (empty for a clean merge).

    Returns:
        Merged ``{path: object_id}`` manifest.
    """
    result: dict[str, str] = {}
    result.update(base_manifest)

    # Ours is applied before theirs; the order only matters for overlapping,
    # non-conflict paths.
    sides = (
        (ours_manifest, ours_changed),
        (theirs_manifest, theirs_changed),
    )
    for side_manifest, side_changed in sides:
        for changed_path in side_changed - conflict_paths:
            if changed_path not in side_manifest:
                # The side deleted this path.
                result.pop(changed_path, None)
                continue
            result[changed_path] = side_manifest[changed_path]

    return result
310
311
312 # ---------------------------------------------------------------------------
313 # Async merge helpers (require a DB session)
314 # ---------------------------------------------------------------------------
315
316
async def find_merge_base(
    session: AsyncSession,
    commit_id_a: str,
    commit_id_b: str,
) -> str | None:
    """Find a Lowest Common Ancestor (LCA) of two commits.

    Both ancestor graphs are walked breadth-first (each start commit counts
    as its own ancestor). Among the commits reachable from both, every commit
    that is itself reachable from another common ancestor is discarded; what
    remains are the *lowest* common ancestors. Merely returning the first
    shared commit met while walking from *commit_id_b* is not enough: a merge
    commit whose second parent links straight to an old shared commit would
    yield that old commit even when a newer shared commit exists further down
    the first-parent chain.

    When several lowest common ancestors exist (criss-cross merges), the one
    met first in the breadth-first walk from *commit_id_b* is returned, so
    the result is deterministic.

    Supports merge commits with two parents (``parent_commit_id`` and
    ``parent2_commit_id``). A commit ID that ``session.get`` cannot find is
    treated as having no parents.

    Args:
        session: An open async DB session.
        commit_id_a: First commit ID (e.g., current branch HEAD).
        commit_id_b: Second commit ID (e.g., target branch HEAD).

    Returns:
        The LCA commit ID, or ``None`` if the commits share no common ancestor
        (disjoint histories).
    """
    from maestro.muse_cli.models import MuseCliCommit

    # commit ID -> parent IDs, filled lazily so each commit is fetched at most once.
    parents_of: dict[str, tuple[str, ...]] = {}

    async def _parents(cid: str) -> tuple[str, ...]:
        """Return the parent IDs recorded for *cid* (empty if the commit is unknown)."""
        if cid not in parents_of:
            commit: MuseCliCommit | None = await session.get(MuseCliCommit, cid)
            if commit is None:
                parents_of[cid] = ()
            else:
                parents_of[cid] = tuple(
                    parent_id
                    for parent_id in (
                        commit.parent_commit_id,
                        commit.parent2_commit_id,
                    )
                    if parent_id
                )
        return parents_of[cid]

    async def _ancestors_bfs(start: str) -> list[str]:
        """Return *start* and every commit reachable from it, in BFS order."""
        order: list[str] = []
        seen: set[str] = set()
        queue: deque[str] = deque([start])
        while queue:
            cid = queue.popleft()
            if cid in seen:
                continue
            seen.add(cid)
            order.append(cid)
            queue.extend(await _parents(cid))
        return order

    a_ancestors = set(await _ancestors_bfs(commit_id_a))
    b_order = await _ancestors_bfs(commit_id_b)

    # Common ancestors, kept in the order the walk from B met them.
    common = [cid for cid in b_order if cid in a_ancestors]
    if not common:
        return None

    # Mark every commit that is a proper ancestor of some common ancestor.
    # All of these were already visited above, so no further DB access is needed.
    dominated: set[str] = set()
    for candidate in common:
        if candidate in dominated:
            # Its ancestors were already marked by the walk that marked it.
            continue
        stack = list(parents_of.get(candidate, ()))
        while stack:
            cid = stack.pop()
            if cid in dominated:
                continue
            dominated.add(cid)
            stack.extend(parents_of.get(cid, ()))

    for candidate in common:
        if candidate not in dominated:
            return candidate

    # Only reachable if the stored parent links form a cycle; fall back to the
    # first shared commit met from B.
    return common[0]