Skip to content

Commit 46e1fe6

Browse files
author
mnbplus
committed
test: add 31 tests for RAG engine and rule plugin manager
1 parent dcd3298 commit 46e1fe6

File tree

2 files changed

+378
-0
lines changed

2 files changed

+378
-0
lines changed

tests/test_rag.py

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
"""Tests for pyaegis.rag -- local RAG engine."""
2+
from __future__ import annotations
3+
4+
import textwrap
5+
from pathlib import Path
6+
7+
from pyaegis.rag import CodeRAG, _chunk_file, _cosine_similarity
8+
9+
10+
# ---------------------------------------------------------------------------
11+
# Helpers
12+
# ---------------------------------------------------------------------------
13+
14+
15+
def _write_py(tmp_path: Path, name: str, src: str) -> Path:
    """Write dedented *src* to ``tmp_path / name`` as UTF-8 and return that path."""
    target = tmp_path / name
    dedented = textwrap.dedent(src)
    target.write_text(dedented, encoding="utf-8")
    return target
19+
20+
21+
# ---------------------------------------------------------------------------
22+
# _chunk_file
23+
# ---------------------------------------------------------------------------
24+
25+
26+
def test_chunk_file_functions(tmp_path):
    """Two top-level functions should each yield a ``function`` chunk."""
    source = (
        "def foo(x):\n"
        " # Foo does stuff.\n"
        " return x + 1\n"
        "\n"
        "def bar():\n"
        " pass\n"
    )
    module_path = _write_py(tmp_path, "sample.py", source)
    found = _chunk_file(str(module_path))

    chunk_names = [chunk.name for chunk in found]
    for expected in ("foo", "bar"):
        assert expected in chunk_names

    # Every chunk must be a function tied to the file, with a non-empty id.
    for chunk in found:
        assert chunk.kind == "function"
        assert chunk.file_path == str(module_path)
        assert chunk.chunk_id
47+
48+
49+
def test_chunk_file_class(tmp_path):
    """A module containing a class should produce at least one ``class`` chunk."""
    source = (
        "class MyClass:\n"
        " # A class.\n"
        " def method(self):\n"
        " pass\n"
    )
    module_path = _write_py(tmp_path, "cls.py", source)
    found = _chunk_file(str(module_path))
    seen_kinds = set()
    for chunk in found:
        seen_kinds.add(chunk.kind)
    assert "class" in seen_kinds
63+
64+
65+
def test_chunk_file_nonexistent():
    """Chunking a path that does not exist is expected to return an empty list."""
    missing_path = "/does/not/exist.py"
    result = _chunk_file(missing_path)
    assert result == []
68+
69+
70+
def test_chunk_file_syntax_error(tmp_path):
    """Chunking a file with invalid syntax should still hand back a list."""
    broken_source = "def foo(:\n pass\n"
    broken_path = _write_py(tmp_path, "bad.py", broken_source)
    result = _chunk_file(str(broken_path))
    # Only the return type is pinned here, not the contents.
    assert isinstance(result, list)
74+
75+
76+
# ---------------------------------------------------------------------------
77+
# _cosine_similarity
78+
# ---------------------------------------------------------------------------
79+
80+
81+
def test_cosine_similarity_identical():
    """A vector compared with itself should score 1.0 (within 1e-6)."""
    unit_x = [1.0, 0.0, 0.0]
    score = _cosine_similarity(unit_x, unit_x)
    assert -1e-6 < score - 1.0 < 1e-6
84+
85+
86+
def test_cosine_similarity_orthogonal():
    """Perpendicular vectors should score 0.0 (within 1e-6)."""
    x_axis, y_axis = [1.0, 0.0], [0.0, 1.0]
    score = _cosine_similarity(x_axis, y_axis)
    assert -1e-6 < score < 1e-6
90+
91+
92+
def test_cosine_similarity_zero_vector():
    """An all-zero vector compared with itself is expected to score exactly 0.0."""
    zeros = [0.0, 0.0, 0.0]
    score = _cosine_similarity(zeros, zeros)
    assert score == 0.0
95+
96+
97+
# ---------------------------------------------------------------------------
98+
# CodeRAG -- basic index + search
99+
# ---------------------------------------------------------------------------
100+
101+
102+
def test_rag_index_and_search(tmp_path):
    """Index a two-function module and check a text query surfaces the auth function.

    Asserts that ``index_file`` reports at least two indexed items and that
    ``authenticate_user`` is among the top-3 search results.
    """
    f = _write_py(
        tmp_path,
        "auth.py",
        (
            "def authenticate_user(username, password):\n"
            " # Check user credentials against the database.\n"
            " return check_db(username, password)\n"
            "\n"
            "def logout(session):\n"
            " # Invalidate the user session.\n"
            " session.clear()\n"
        ),
    )
    rag = CodeRAG(db_path=":memory:")
    try:
        n = rag.index_file(str(f))
        assert n >= 2

        results = rag.search("user authentication credentials", top_k=3)
        assert len(results) > 0
        names = [r.chunk.name for r in results]
        assert "authenticate_user" in names
    finally:
        # Run close() even when an assertion above fails.
        rag.close()
125+
126+
127+
def test_rag_index_directory(tmp_path):
    """Index a directory tree and check the reported file/chunk counts and stats.

    Two files are written at the top level and one in a ``sub`` directory;
    the assertions only require that at least two files and two chunks are
    reported.
    """
    _write_py(tmp_path, "a.py", "def alpha(): pass\n")
    _write_py(tmp_path, "b.py", "def beta(): pass\n")
    sub = tmp_path / "sub"
    sub.mkdir()
    _write_py(sub, "c.py", "def gamma(): pass\n")

    rag = CodeRAG(db_path=":memory:")
    try:
        files, chunks = rag.index_directory(str(tmp_path))
        assert files >= 2
        assert chunks >= 2
        stats = rag.stats()
        assert stats["indexed_files"] >= 2
    finally:
        # Run close() even when an assertion above fails.
        rag.close()
141+
142+
143+
def test_rag_no_reindex_unchanged(tmp_path):
    """Indexing the same unchanged file twice should report 0 on the second call."""
    f = _write_py(tmp_path, "stable.py", "def stable(): pass\n")
    db = tmp_path / "rag.sqlite"
    rag = CodeRAG(db_path=str(db))
    try:
        n1 = rag.index_file(str(f))
        n2 = rag.index_file(str(f))
        assert n1 > 0
        assert n2 == 0
    finally:
        # The database lives on disk under tmp_path; close it even on
        # failure so the open handle does not outlive the test.
        rag.close()
152+
153+
154+
def test_rag_force_reindex(tmp_path):
    """``force=True`` should re-index an already indexed file (positive count)."""
    f = _write_py(tmp_path, "x.py", "def x(): pass\n")
    db = tmp_path / "rag.sqlite"
    rag = CodeRAG(db_path=str(db))
    try:
        rag.index_file(str(f))
        n = rag.index_file(str(f), force=True)
        assert n > 0
    finally:
        # The database lives on disk under tmp_path; close it even on
        # failure so the open handle does not outlive the test.
        rag.close()
162+
163+
164+
def test_rag_kind_filter(tmp_path):
    """Searching with ``kind_filter="function"`` must only return function chunks.

    NOTE(review): the loop passes vacuously when the search returns nothing;
    this test does not require a non-empty result.
    """
    f = _write_py(
        tmp_path,
        "mixed.py",
        ("class Foo:\n" " pass\n" "\n" "def bar():\n" " pass\n"),
    )
    rag = CodeRAG(db_path=":memory:")
    try:
        rag.index_file(str(f))
        results = rag.search("foo bar", top_k=10, kind_filter="function")
        for r in results:
            assert r.chunk.kind == "function"
    finally:
        # Run close() even when an assertion above fails.
        rag.close()
176+
177+
178+
def test_rag_build_context(tmp_path):
    """The context built from search results should mention the indexed function."""
    f = _write_py(tmp_path, "ctx.py", "def hello(): pass\n")
    rag = CodeRAG(db_path=":memory:")
    try:
        rag.index_file(str(f))
        results = rag.search("hello", top_k=1)
        ctx = rag.build_context(results, max_chars=5000)
        assert "hello" in ctx
    finally:
        # Run close() even when an assertion above fails.
        rag.close()
186+
187+
188+
def test_rag_stats_empty():
    """A fresh in-memory engine should report zero chunks and zero indexed files."""
    rag = CodeRAG(db_path=":memory:")
    try:
        s = rag.stats()
        assert s["total_chunks"] == 0
        assert s["indexed_files"] == 0
    finally:
        # Run close() even when an assertion above fails.
        rag.close()
194+
195+
196+
def test_rag_context_manager(tmp_path):
    """``CodeRAG`` should be usable in a ``with`` block for indexing and searching."""
    module_path = _write_py(tmp_path, "cm.py", "def cm(): pass\n")
    with CodeRAG(db_path=":memory:") as engine:
        engine.index_file(str(module_path))
        hits = engine.search("cm")
        assert isinstance(hits, list)

tests/test_rule_plugins.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
"""Tests for pyaegis.rule_plugins — community rule pack manager."""
2+
from __future__ import annotations
3+
4+
from pathlib import Path
5+
6+
import pytest
7+
import yaml
8+
9+
from pyaegis.rule_plugins import RulePluginManager
10+
11+
12+
# ---------------------------------------------------------------------------
13+
# Fixtures
14+
# ---------------------------------------------------------------------------
15+
16+
17+
@pytest.fixture
def mgr(tmp_path):
    """Provide a ``RulePluginManager`` whose rules directory is ``tmp_path / "rules"``."""
    rules_root = tmp_path / "rules"
    return RulePluginManager(rules_dir=str(rules_root))
20+
21+
22+
@pytest.fixture
def sample_pack(tmp_path):
    """Create ``sample_pack.yml`` under *tmp_path* and return its path as a string."""
    pack_body = dict(
        inputs=["environ.get", "getenv"],
        sinks=["ldap.search"],
        sanitizers=["escape"],
        conditional_sinks=[],
        source_decorators=[],
    )
    pack_file = tmp_path / "sample_pack.yml"
    serialized = yaml.safe_dump(pack_body)
    pack_file.write_text(serialized, encoding="utf-8")
    return str(pack_file)
35+
36+
37+
# ---------------------------------------------------------------------------
38+
# Install
39+
# ---------------------------------------------------------------------------
40+
41+
42+
def test_install_local(mgr, sample_pack):
    """Installing a local pack under an explicit name registers exactly that pack."""
    returned_name = mgr.install(sample_pack, name="test-pack")
    assert returned_name == "test-pack"

    packs = mgr.list_installed()
    assert len(packs) == 1
    first = packs[0]
    assert first["name"] == "test-pack"
48+
49+
50+
def test_install_auto_name(mgr, sample_pack):
    """Without an explicit name, the pack name is taken from the file name."""
    inferred = mgr.install(sample_pack)
    assert inferred == "sample_pack"
53+
54+
55+
def test_install_duplicate_raises(mgr, sample_pack):
    """Installing the same name twice must raise ``ValueError`` ("already installed")."""
    mgr.install(sample_pack, name="dup")
    expect_duplicate_error = pytest.raises(ValueError, match="already installed")
    with expect_duplicate_error:
        mgr.install(sample_pack, name="dup")
59+
60+
61+
def test_install_duplicate_force(mgr, sample_pack):
    """``force=True`` re-installs an existing name without adding a second entry."""
    mgr.install(sample_pack, name="dup")
    reinstalled = mgr.install(sample_pack, name="dup", force=True)
    assert reinstalled == "dup"
    installed_count = len(mgr.list_installed())
    assert installed_count == 1
66+
67+
68+
def test_install_invalid_yaml(mgr, tmp_path):
    """A file that is not parseable YAML must raise ``ValueError`` ("Invalid YAML")."""
    broken_pack = tmp_path / "bad.yml"
    broken_pack.write_text("[unclosed bracket\n", encoding="utf-8")
    expect_invalid = pytest.raises(ValueError, match="Invalid YAML")
    with expect_invalid:
        mgr.install(str(broken_pack), name="bad")
73+
74+
75+
def test_install_non_mapping_yaml(mgr, tmp_path):
    """A YAML document that is a list, not a mapping, must raise ``ValueError``."""
    list_pack = tmp_path / "list.yml"
    list_pack.write_text("- item1\n- item2\n", encoding="utf-8")
    expect_not_mapping = pytest.raises(ValueError, match="must be a YAML mapping")
    with expect_not_mapping:
        mgr.install(str(list_pack), name="list")
80+
81+
82+
# ---------------------------------------------------------------------------
83+
# Remove
84+
# ---------------------------------------------------------------------------
85+
86+
87+
def test_remove_installed(mgr, sample_pack):
    """Removing an installed pack returns ``True`` and empties the installed list."""
    mgr.install(sample_pack, name="to-remove")
    removed = mgr.remove("to-remove")
    assert removed is True
    remaining = mgr.list_installed()
    assert remaining == []
92+
93+
94+
def test_remove_nonexistent(mgr):
    """Removing a pack that was never installed returns ``False``."""
    removed = mgr.remove("ghost")
    assert removed is False
97+
98+
99+
# ---------------------------------------------------------------------------
100+
# List
101+
# ---------------------------------------------------------------------------
102+
103+
104+
def test_list_empty(mgr):
    """A fresh manager lists no installed packs."""
    installed = mgr.list_installed()
    assert installed == []
106+
107+
108+
def test_list_multiple(mgr, tmp_path, sample_pack):
    """Two packs installed under different names are both listed."""
    second_pack = tmp_path / "pack2.yml"
    second_body = yaml.safe_dump({"sinks": ["eval"]})
    second_pack.write_text(second_body, encoding="utf-8")

    mgr.install(sample_pack, name="p1")
    mgr.install(str(second_pack), name="p2")

    listed_names = set()
    for entry in mgr.list_installed():
        listed_names.add(entry["name"])
    assert listed_names == {"p1", "p2"}
115+
116+
117+
# ---------------------------------------------------------------------------
118+
# Merge
119+
# ---------------------------------------------------------------------------
120+
121+
122+
def test_merged_rules_includes_plugin(mgr, sample_pack):
    """Merged rules (built-ins excluded) contain the installed pack's inputs and sinks."""
    mgr.install(sample_pack, name="extra")
    combined = mgr.merged_rules(include_builtin=False)
    merged_inputs = combined["inputs"]
    merged_sinks = combined["sinks"]
    assert "environ.get" in merged_inputs
    assert "ldap.search" in merged_sinks
127+
128+
129+
def test_merged_rules_deduplicates(mgr, tmp_path):
    """An input listed by two packs appears only once in the merged rules."""
    pack_bodies = {
        "p1": {"inputs": ["os.environ"]},
        "p2": {"inputs": ["os.environ", "getenv"]},
    }
    # Write both pack files first, then install them in the same order.
    pack_files = {}
    for pack_name, body in pack_bodies.items():
        pack_file = tmp_path / f"{pack_name}.yml"
        pack_file.write_text(yaml.safe_dump(body), encoding="utf-8")
        pack_files[pack_name] = pack_file
    for pack_name, pack_file in pack_files.items():
        mgr.install(str(pack_file), name=pack_name)

    combined = mgr.merged_rules(include_builtin=False)
    occurrences = combined["inputs"].count("os.environ")
    assert occurrences == 1
140+
141+
142+
def test_merged_rules_path_creates_file(mgr, sample_pack):
    """``merged_rules_path`` should return the path of an existing YAML file with ``sinks``.

    The generated file is removed afterwards, whether or not the assertions
    pass.
    """
    mgr.install(sample_pack, name="mp")
    merged_file = Path(mgr.merged_rules_path(include_builtin=False))
    try:
        assert merged_file.exists()
        content = yaml.safe_load(merged_file.read_text(encoding="utf-8"))
        assert "sinks" in content
    finally:
        # Clean up even on failure. Guard with exists() so a missing file
        # does not raise here and hide the original assertion error.
        if merged_file.exists():
            merged_file.unlink()
149+
150+
151+
# ---------------------------------------------------------------------------
152+
# has_plugins
153+
# ---------------------------------------------------------------------------
154+
155+
156+
def test_has_plugins_false(mgr):
    """A manager with nothing installed reports ``has_plugins()`` as ``False``."""
    any_installed = mgr.has_plugins()
    assert any_installed is False
158+
159+
160+
def test_has_plugins_true(mgr, sample_pack):
    """After one install, ``has_plugins()`` reports ``True``."""
    mgr.install(sample_pack, name="any")
    any_installed = mgr.has_plugins()
    assert any_installed is True
163+
164+
165+
# ---------------------------------------------------------------------------
166+
# Index persistence
167+
# ---------------------------------------------------------------------------
168+
169+
170+
def test_index_persists_across_instances(tmp_path, sample_pack):
    """A pack installed by one manager is visible to a new manager on the same directory."""
    shared_dir = str(tmp_path / "rules")

    writer = RulePluginManager(rules_dir=shared_dir)
    writer.install(sample_pack, name="persist")

    # A second, independent instance pointed at the same rules directory.
    reader = RulePluginManager(rules_dir=shared_dir)
    visible_names = set()
    for entry in reader.list_installed():
        visible_names.add(entry["name"])
    assert "persist" in visible_names

0 commit comments

Comments
 (0)