cgcardona / muse public
merge_engine.py python
280 lines 9.0 KB
77f04a8f feat: eliminate all Any/object/ignore — strict TypedDicts at every boundary Gabriel Cardona <gabriel@tellurstori.com> 3d ago
1 """Muse VCS merge engine — fast-forward and 3-way path-level merge.
2
3 Public API
4 ----------
5 Pure functions (no I/O):
6
7 - :func:`diff_snapshots` — paths that changed between two snapshot manifests.
8 - :func:`detect_conflicts` — paths changed on *both* branches since the base.
9 - :func:`apply_merge` — build merged manifest for a conflict-free 3-way merge.
10
11 File-based helpers:
12
13 - :func:`find_merge_base` — lowest common ancestor (LCA) of two commits.
14 - :func:`read_merge_state` — detect and load an in-progress merge.
15 - :func:`write_merge_state` — persist conflict state before exiting.
16 - :func:`clear_merge_state` — remove MERGE_STATE.json after resolution.
17 - :func:`apply_resolution` — restore a specific object version to muse-work/.
18
19 ``MERGE_STATE.json`` schema
20 ---------------------------
21
22 .. code-block:: json
23
24 {
25 "base_commit": "abc123...",
26 "ours_commit": "def456...",
27 "theirs_commit": "789abc...",
28 "conflict_paths": ["beat.mid", "lead.mp3"],
29 "other_branch": "feature/experiment"
30 }
31
32 ``other_branch`` is optional; all other fields are required when conflicts exist.
33 """
34 from __future__ import annotations
35
36 import json
37 import logging
38 import pathlib
39 from collections import deque
40 from dataclasses import dataclass, field
41 from typing import TypedDict
42
43 logger = logging.getLogger(__name__)
44
45 _MERGE_STATE_FILENAME = "MERGE_STATE.json"
46
47
48 # ---------------------------------------------------------------------------
49 # Wire-format TypedDict
50 # ---------------------------------------------------------------------------
51
52
class _MergeStateRequiredKeys(TypedDict):
    """Keys that every persisted merge state carries."""

    base_commit: str
    ours_commit: str
    theirs_commit: str
    conflict_paths: list[str]


class MergeStatePayload(_MergeStateRequiredKeys, total=False):
    """JSON-serialisable form of an in-progress merge state.

    The four commit/conflict keys are inherited from a ``total=True`` base
    and are therefore required by type checkers; only ``other_branch`` is
    declared under ``total=False`` and may be omitted. At runtime this is
    still a plain ``dict``, so serialisation is unaffected.
    """

    # Optional: name of the branch being merged in, when the caller knows it.
    other_branch: str
61
62
63 # ---------------------------------------------------------------------------
64 # MergeState dataclass
65 # ---------------------------------------------------------------------------
66
67
@dataclass(frozen=True)
class MergeState:
    """Immutable record of a merge that is still waiting on conflict resolution.

    Attributes:
        conflict_paths: Paths recorded as conflicting; empty when none given.
        base_commit: Recorded merge-base commit ID, or ``None`` if absent.
        ours_commit: Recorded commit ID for "our" side, or ``None`` if absent.
        theirs_commit: Recorded commit ID for "their" side, or ``None`` if absent.
        other_branch: Recorded name of the other branch, or ``None`` if absent.
    """

    # default_factory gives each instance its own list rather than a shared one.
    conflict_paths: list[str] = field(default_factory=list)

    # Every remaining field is optional and defaults to "not recorded".
    base_commit: str | None = None
    ours_commit: str | None = None
    theirs_commit: str | None = None
    other_branch: str | None = None
77
78
79 # ---------------------------------------------------------------------------
80 # Filesystem helpers
81 # ---------------------------------------------------------------------------
82
83
def read_merge_state(root: pathlib.Path) -> MergeState | None:
    """Return :class:`MergeState` if a merge is in progress, otherwise ``None``.

    Args:
        root: Repository root; the state file is looked up at
            ``<root>/.muse/MERGE_STATE.json``.

    Returns:
        A :class:`MergeState` built from the file, or ``None`` when the file
        does not exist, cannot be read or decoded, is not valid JSON, or does
        not contain a JSON object at the top level. Every failure other than
        a missing file is logged as a warning.
    """
    merge_state_path = root / ".muse" / _MERGE_STATE_FILENAME
    if not merge_state_path.exists():
        return None
    try:
        # JSON text is UTF-8; decoding errors are handled like any other
        # unreadable-file condition instead of escaping to the caller.
        data = json.loads(merge_state_path.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError, OSError) as exc:
        logger.warning("⚠️ Failed to read %s: %s", _MERGE_STATE_FILENAME, exc)
        return None

    # json.loads may legitimately return a list, string, number, etc.; only a
    # JSON object supports the key lookups below.
    if not isinstance(data, dict):
        logger.warning(
            "⚠️ Failed to read %s: expected a JSON object, got %s",
            _MERGE_STATE_FILENAME,
            type(data).__name__,
        )
        return None

    raw_conflicts = data.get("conflict_paths", [])
    # A non-list value is treated as "no conflicts recorded".
    conflict_paths: list[str] = (
        [str(c) for c in raw_conflicts] if isinstance(raw_conflicts, list) else []
    )

    def _str_or_none(key: str) -> str | None:
        """Return the value under *key* coerced to ``str``, or ``None`` if unset."""
        val = data.get(key)
        return str(val) if val is not None else None

    return MergeState(
        conflict_paths=conflict_paths,
        base_commit=_str_or_none("base_commit"),
        ours_commit=_str_or_none("ours_commit"),
        theirs_commit=_str_or_none("theirs_commit"),
        other_branch=_str_or_none("other_branch"),
    )
111
112
def write_merge_state(
    root: pathlib.Path,
    *,
    base_commit: str,
    ours_commit: str,
    theirs_commit: str,
    conflict_paths: list[str],
    other_branch: str | None = None,
) -> None:
    """Persist the conflicted-merge marker at ``.muse/MERGE_STATE.json``.

    Args:
        root: Repository root containing the ``.muse`` directory.
        base_commit: Commit ID stored under ``"base_commit"``.
        ours_commit: Commit ID stored under ``"ours_commit"``.
        theirs_commit: Commit ID stored under ``"theirs_commit"``.
        conflict_paths: Conflicting paths; written in sorted order.
        other_branch: Stored under ``"other_branch"`` only when not ``None``.
    """
    state_file = root / ".muse" / _MERGE_STATE_FILENAME

    ordered_conflicts = sorted(conflict_paths)
    payload: MergeStatePayload = {
        "base_commit": base_commit,
        "ours_commit": ours_commit,
        "theirs_commit": theirs_commit,
        "conflict_paths": ordered_conflicts,
    }
    # The key is left out entirely (not written as null) when no branch is given.
    if other_branch is not None:
        payload["other_branch"] = other_branch

    serialised = json.dumps(payload, indent=2)
    state_file.write_text(serialised)

    conflict_count = len(conflict_paths)
    logger.info("✅ Wrote MERGE_STATE.json with %d conflict(s)", conflict_count)
134
135
def clear_merge_state(root: pathlib.Path) -> None:
    """Delete ``.muse/MERGE_STATE.json`` under *root* if it is present.

    Does nothing (and logs nothing) when the file does not exist.
    """
    state_file = root / ".muse" / _MERGE_STATE_FILENAME
    if not state_file.exists():
        return
    state_file.unlink()
    logger.debug("✅ Cleared MERGE_STATE.json")
142
143
def apply_resolution(
    root: pathlib.Path,
    rel_path: str,
    object_id: str,
) -> None:
    """Write the stored object *object_id* to ``muse-work/<rel_path>`` under *root*.

    Args:
        root: Repository root.
        rel_path: Destination path, joined onto ``<root>/muse-work``.
        object_id: Identifier passed to the object store's ``read_object``.

    Raises:
        FileNotFoundError: If ``read_object`` returns ``None`` for *object_id*.
    """
    # Imported lazily, exactly as in the original; presumably to avoid an
    # import cycle — TODO confirm.
    from muse.core.object_store import read_object

    blob = read_object(root, object_id)
    short_id = object_id[:8]  # abbreviated ID used only in messages
    if blob is None:
        raise FileNotFoundError(
            f"Object {short_id} for '{rel_path}' not found in local store."
        )

    # NOTE(review): rel_path is joined as-is; an absolute or ``..`` path would
    # land outside muse-work/ — confirm callers only pass manifest-relative paths.
    target = root / "muse-work" / rel_path
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_bytes(blob)
    logger.debug("✅ Restored '%s' from object %s", rel_path, short_id)
161
162
def is_conflict_resolved(merge_state: MergeState, rel_path: str) -> bool:
    """Tell whether *rel_path* is absent from the conflict list of *merge_state*.

    Returns:
        ``False`` while *rel_path* is still listed in
        ``merge_state.conflict_paths``; ``True`` otherwise.
    """
    outstanding = merge_state.conflict_paths
    if rel_path in outstanding:
        return False
    return True
166
167
168 # ---------------------------------------------------------------------------
169 # Pure merge functions (no I/O)
170 # ---------------------------------------------------------------------------
171
172
def diff_snapshots(
    base_manifest: dict[str, str],
    other_manifest: dict[str, str],
) -> set[str]:
    """Compute which paths are not identical across the two manifests.

    A path is reported when it exists in only one manifest (added or
    deleted) or exists in both with different values (modified).
    """
    base_keys = base_manifest.keys()
    other_keys = other_manifest.keys()

    # Paths present on exactly one side: additions and deletions together.
    changed: set[str] = set(base_keys ^ other_keys)

    # Paths present on both sides whose mapped value differs.
    changed.update(
        path
        for path in base_keys & other_keys
        if base_manifest[path] != other_manifest[path]
    )
    return changed
185
186
def detect_conflicts(
    ours_changed: set[str],
    theirs_changed: set[str],
) -> set[str]:
    """Compute the overlap of the two changed-path sets.

    Returns:
        Every path that appears in both *ours_changed* and *theirs_changed*.
    """
    overlapping = ours_changed.intersection(theirs_changed)
    return overlapping
193
194
def apply_merge(
    base_manifest: dict[str, str],
    ours_manifest: dict[str, str],
    theirs_manifest: dict[str, str],
    ours_changed: set[str],
    theirs_changed: set[str],
    conflict_paths: set[str],
) -> dict[str, str]:
    """Produce a merged manifest by replaying each side's changes onto the base.

    Starting from a copy of *base_manifest*, every changed path that is not
    in *conflict_paths* takes that side's value, or is removed when the side
    no longer has it. Ours is replayed first and theirs second. Paths in
    *conflict_paths* keep whatever the base had. *base_manifest* itself is
    not modified.
    """
    result: dict[str, str] = {**base_manifest}

    # Same replay logic for both sides; tuple order fixes ours-before-theirs.
    sides = (
        (ours_changed, ours_manifest),
        (theirs_changed, theirs_manifest),
    )
    for changed_paths, side_manifest in sides:
        for path in changed_paths - conflict_paths:
            if path not in side_manifest:
                # Changed but absent on this side means it was deleted there.
                result.pop(path, None)
                continue
            result[path] = side_manifest[path]
    return result
216
217
218 # ---------------------------------------------------------------------------
219 # File-based merge base finder
220 # ---------------------------------------------------------------------------
221
222
def find_merge_base(
    repo_root: pathlib.Path,
    commit_id_a: str,
    commit_id_b: str,
) -> str | None:
    """Locate a common ancestor of two commits via breadth-first search.

    First gathers every commit reachable from *commit_id_a* (the commit
    itself included). Then explores breadth-first from *commit_id_b* and
    returns the first commit it dequeues that belongs to that set.

    Args:
        repo_root: The repository root directory, passed to ``read_commit``.
        commit_id_a: First commit ID (e.g., current branch HEAD).
        commit_id_b: Second commit ID (e.g., target branch HEAD).

    Returns:
        The ID of the common ancestor found, or ``None`` when the walk from
        *commit_id_b* is exhausted without meeting any ancestor of
        *commit_id_a*.
    """
    from muse.core.store import read_commit

    def _parents_of(commit_id: str) -> list[str]:
        """Return the non-empty parent IDs of *commit_id*, first parent first."""
        record = read_commit(repo_root, commit_id)
        if record is None:
            # A commit that cannot be loaded contributes no further ancestry.
            return []
        candidates = (record.parent_commit_id, record.parent2_commit_id)
        return [parent for parent in candidates if parent]

    def _all_ancestors(start: str) -> set[str]:
        """Return *start* plus every commit reachable through parent links."""
        seen: set[str] = set()
        pending: deque[str] = deque([start])
        while pending:
            current = pending.popleft()
            if current not in seen:
                seen.add(current)
                pending.extend(_parents_of(current))
        return seen

    reachable_from_a = _all_ancestors(commit_id_a)

    explored: set[str] = set()
    frontier: deque[str] = deque([commit_id_b])
    while frontier:
        candidate = frontier.popleft()
        if candidate in explored:
            continue
        explored.add(candidate)
        # Membership is tested before loading the commit, so a match returns
        # without reading it.
        if candidate in reachable_from_a:
            return candidate
        frontier.extend(_parents_of(candidate))

    return None