cgcardona / muse public
test_code_invariants.py python
289 lines 11.1 KB
8d5137ed fix(security): full surface hardening — validation, path containment, p… Gabriel Cardona <cgcardona@gmail.com> 8h ago
1 """Tests for the code-domain invariants engine."""
2
3 import pathlib
4 import tempfile
5
6 import pytest
7
8 from muse.core.invariants import InvariantChecker
9 from muse.plugins.code._invariants import (
10 CodeChecker,
11 CodeInvariantRule,
12 check_max_complexity,
13 check_no_circular_imports,
14 check_no_dead_exports,
15 check_test_coverage_floor,
16 load_invariant_rules,
17 run_invariants,
18 )
19
20
21 # ---------------------------------------------------------------------------
22 # Helpers
23 # ---------------------------------------------------------------------------
24
25
def _make_repo(tmp_path: pathlib.Path) -> pathlib.Path:
    """Lay down a bare-bones ``.muse/`` repository under *tmp_path*.

    Creates ``repo.json`` (repo id ``"test"``), a ``HEAD`` file pointing at
    ``refs/heads/main``, and empty ``commits``, ``snapshots``, ``refs/heads``
    and ``objects`` directories.  Returns *tmp_path* unchanged so callers can
    use it as the repository root.
    """
    dot_muse = tmp_path / ".muse"
    dot_muse.mkdir()
    (dot_muse / "repo.json").write_text('{"repo_id":"test"}')
    (dot_muse / "HEAD").write_text("refs/heads/main")
    # .muse/ was just created empty, so parents=True only matters for refs/heads.
    for rel in ("commits", "snapshots", "refs/heads", "objects"):
        (dot_muse / rel).mkdir(parents=True)
    return tmp_path
37
38
def _write_object(root: pathlib.Path, content: bytes) -> str:
    """Store *content* in the repo object store and return its SHA-256 hex id.

    The object is written to ``.muse/objects/<first two hex chars>/<rest>``.
    Missing fan-out directories are created, and an existing object file is
    overwritten.
    """
    import hashlib

    digest = hashlib.sha256(content).hexdigest()
    fanout, remainder = digest[:2], digest[2:]
    target = root.joinpath(".muse", "objects", fanout, remainder)
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_bytes(content)
    return digest
46
47
48 # ---------------------------------------------------------------------------
49 # _estimate_complexity (via check_max_complexity)
50 # ---------------------------------------------------------------------------
51
52
class TestMaxComplexity:
    """Exercise ``check_max_complexity`` (and, through it, ``_estimate_complexity``)."""

    def test_simple_function_no_violation(self) -> None:
        # A single straight-line function stays under a generous threshold.
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            src = b"def simple():\n    return 1\n"
            h = _write_object(root, src)
            manifest = {"mod.py": h}
            violations = check_max_complexity(manifest, root, "test", "error", threshold=10)
            assert violations == []

    def test_complex_function_triggers_violation(self) -> None:
        # Seven sequential ``if`` branches: well over threshold 5, assuming the
        # estimator adds at least one per branch.  The snippet must be valid,
        # properly indented Python so it can actually be analysed.
        src = b"""
def complex():
    if True:
        pass
    if True:
        pass
    if True:
        pass
    if True:
        pass
    if True:
        pass
    if True:
        pass
    if True:
        pass
    return 1
"""
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            h = _write_object(root, src)
            manifest = {"mod.py": h}
            violations = check_max_complexity(manifest, root, "gate", "error", threshold=5)
            assert len(violations) >= 1
            # The violation carries the rule name it was run under.
            assert violations[0]["rule_name"] == "gate"
            assert "complexity" in violations[0]["description"].lower()

    def test_non_python_file_skipped(self) -> None:
        # Even at threshold=1, a non-.py path in the manifest yields nothing.
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            src = b"def hello() { return 1; }"
            h = _write_object(root, src)
            manifest = {"mod.js": h}
            violations = check_max_complexity(manifest, root, "c", "error", threshold=1)
            assert violations == []
100
101
102 # ---------------------------------------------------------------------------
103 # check_no_circular_imports
104 # ---------------------------------------------------------------------------
105
106
class TestNoCircularImports:
    """Exercise ``check_no_circular_imports`` on small in-repo import graphs."""

    def test_no_cycle_returns_empty(self) -> None:
        # a -> b, and b imports nothing: the graph is acyclic.
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            a = b"import b\n"
            b_src = b"x = 1\n"
            ha = _write_object(root, a)
            hb = _write_object(root, b_src)
            manifest = {"a.py": ha, "b.py": hb}
            violations = check_no_circular_imports(manifest, root, "no_cycles", "error")
            assert violations == []

    def test_cycle_detected(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            # a imports b, b imports a → cycle
            a = b"import b\n"
            b_src = b"import a\n"
            ha = _write_object(root, a)
            hb = _write_object(root, b_src)
            manifest = {"a.py": ha, "b.py": hb}
            violations = check_no_circular_imports(manifest, root, "no_cycles", "error")
            assert len(violations) >= 1
            assert "cycle" in violations[0]["description"].lower()

    def test_three_file_cycle_detected(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            # a -> b -> c -> a: the cycle is only visible transitively.
            a = b"import b\n"
            b_src = b"import c\n"
            c_src = b"import a\n"
            ha = _write_object(root, a)
            hb = _write_object(root, b_src)
            hc = _write_object(root, c_src)
            manifest = {"a.py": ha, "b.py": hb, "c.py": hc}
            violations = check_no_circular_imports(manifest, root, "cycles", "error")
            assert len(violations) >= 1
            # Same contract as the two-file case: the violation reports a cycle.
            assert "cycle" in violations[0]["description"].lower()
144
145
146 # ---------------------------------------------------------------------------
147 # check_no_dead_exports
148 # ---------------------------------------------------------------------------
149
150
class TestNoDeadExports:
    """Exercise ``check_no_dead_exports`` (rule ``"dead"``, severity ``"warning"``)."""

    @staticmethod
    def _dead_export_addresses(files: dict[str, bytes]) -> list[str]:
        """Store *files* in a scratch repo, run the checker, return flagged addresses."""
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            manifest = {path: _write_object(root, blob) for path, blob in files.items()}
            found = check_no_dead_exports(manifest, root, "dead", "warning")
            return [violation["address"] for violation in found]

    def test_used_function_not_reported(self) -> None:
        # main.py imports my_func from lib.py, so it is not a dead export.
        flagged = self._dead_export_addresses({
            "lib.py": b"def my_func():\n    return 1\n",
            "main.py": b"from lib import my_func\n",
        })
        assert "lib.py::my_func" not in flagged

    def test_unused_function_reported(self) -> None:
        # Nothing in the manifest imports orphan_fn.
        flagged = self._dead_export_addresses({
            "lib.py": b"def orphan_fn():\n    return 1\n",
            "other.py": b"x = 1\n",
        })
        assert "lib.py::orphan_fn" in flagged

    def test_private_function_exempt(self) -> None:
        # Underscore-prefixed functions are never reported, imported or not.
        flagged = self._dead_export_addresses({"lib.py": b"def _private():\n    return 1\n"})
        assert not any("_private" in address for address in flagged)

    def test_test_file_exempt(self) -> None:
        # Files named test_*.py contribute no violations at all.
        flagged = self._dead_export_addresses(
            {"test_stuff.py": b"def test_something():\n    assert True\n"}
        )
        assert flagged == []
195
196
197 # ---------------------------------------------------------------------------
198 # check_test_coverage_floor
199 # ---------------------------------------------------------------------------
200
201
class TestTestCoverageFloor:
    """Exercise ``check_test_coverage_floor`` with ``min_ratio=0.5``."""

    def test_well_covered_code_no_violation(self) -> None:
        # One production function alongside a matching test function.
        files = {
            "src.py": b"def foo():\n    return 1\n",
            "test_src.py": b"def test_foo():\n    assert True\n",
        }
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            manifest = {name: _write_object(root, data) for name, data in files.items()}
            result = check_test_coverage_floor(manifest, root, "coverage", "warning", min_ratio=0.5)
        assert result == []

    def test_uncovered_code_violates(self) -> None:
        # Three functions and no tests: exactly one floor violation expected.
        files = {"src.py": b"def foo():\n    pass\ndef bar():\n    pass\ndef baz():\n    pass\n"}
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            manifest = {name: _write_object(root, data) for name, data in files.items()}
            result = check_test_coverage_floor(manifest, root, "coverage", "warning", min_ratio=0.5)
        assert len(result) == 1
        assert "coverage floor" in result[0]["description"].lower()

    def test_no_functions_no_violation(self) -> None:
        # A module with no functions at all must not trip the floor.
        files = {"config.py": b"X = 1\n"}
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            manifest = {name: _write_object(root, data) for name, data in files.items()}
            result = check_test_coverage_floor(manifest, root, "coverage", "warning", min_ratio=0.5)
        assert result == []
232
233
234 # ---------------------------------------------------------------------------
235 # load_invariant_rules
236 # ---------------------------------------------------------------------------
237
238
class TestLoadInvariantRules:
    """Exercise ``load_invariant_rules``: defaults without a file, TOML rules with one."""

    def test_no_file_returns_defaults(self) -> None:
        # A path that cannot exist yields the built-in default rule set.
        rules = load_invariant_rules(pathlib.Path("/no/such/file.toml"))
        assert len(rules) >= 1
        rule_types = {r["rule_type"] for r in rules}
        assert "max_complexity" in rule_types

    def test_toml_file_loaded(self) -> None:
        toml = "[[rule]]\nname='r1'\nseverity='error'\nscope='function'\nrule_type='max_complexity'\n"
        # A scratch directory (as elsewhere in this file) cleans up the file
        # automatically; no delete=False / manual unlink dance needed.
        with tempfile.TemporaryDirectory() as tmp:
            path = pathlib.Path(tmp) / "invariants.toml"
            path.write_text(toml, encoding="utf-8")
            rules = load_invariant_rules(path)
        # The defaults already contain a max_complexity rule, so matching on
        # rule_type alone cannot distinguish "file loaded" from "fell back to
        # defaults".  The name "r1" can only have come from the file.
        assert any(r["name"] == "r1" and r["rule_type"] == "max_complexity" for r in rules)
257
258
259 # ---------------------------------------------------------------------------
260 # CodeChecker (protocol)
261 # ---------------------------------------------------------------------------
262
263
class TestCodeChecker:
    """Exercise ``CodeChecker`` as an implementation of the core ``InvariantChecker`` protocol."""

    def test_satisfies_invariant_checker_protocol(self) -> None:
        checker = CodeChecker()
        assert isinstance(checker, InvariantChecker)

    def test_check_returns_base_report(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            # Record a single commit whose snapshot has an empty manifest, then
            # check that the report echoes the commit id and domain and carries
            # a violations list.
            # NOTE(review): only the type of "violations" is asserted here; an
            # empty manifest would presumably yield [] — tighten if intended.
            from muse.core.store import CommitRecord, SnapshotRecord, write_commit, write_snapshot
            import datetime
            snap = SnapshotRecord(snapshot_id="s" * 64, manifest={})
            write_snapshot(root, snap)
            commit = CommitRecord(
                commit_id="abc123",
                repo_id="test",
                branch="main",
                snapshot_id="s" * 64,
                message="init",
                committed_at=datetime.datetime.now(datetime.timezone.utc),
            )
            write_commit(root, commit)
            report = CodeChecker().check(root, "abc123")
            assert report["commit_id"] == "abc123"
            assert report["domain"] == "code"
            assert isinstance(report["violations"], list)