Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,31 @@ def no_destructive_sql(request: ExecutionRequest) -> bool:
import sqlglot
from sqlglot import exp

# Resolve destructive-statement expression classes in a
# version-robust way. sqlglot renamed ``AlterTable`` -> ``Alter``
# and added ``Grant`` only in later releases, so no single
# installable version exposes every symbol. Referencing a missing
# attribute (e.g. ``exp.Grant``) directly would raise
# AttributeError for every query on the pinned version. Resolve the
# available aliases here and skip any that are absent.
alter_types = tuple(
cls for cls in (
getattr(exp, "Alter", None),
getattr(exp, "AlterTable", None),
)
if cls is not None
)
grant_types = tuple(
cls for cls in (getattr(exp, "Grant", None),)
if cls is not None
)
# TRUNCATE parses as ``TruncateTable`` in current sqlglot and as a
# generic ``Command`` in some versions; handle whichever exists.
truncate_types = tuple(
cls for cls in (getattr(exp, "TruncateTable", None),)
if cls is not None
)

# Parse the SQL query into AST
try:
statements = sqlglot.parse(query)
Expand All @@ -914,8 +939,15 @@ def no_destructive_sql(request: ExecutionRequest) -> bool:
if isinstance(statement, exp.Drop):
return False

# Check for TRUNCATE statements
if isinstance(statement, exp.Command) and statement.this.upper() == "TRUNCATE":
# Check for TRUNCATE statements (dedicated TruncateTable node
# in current sqlglot; a Command node in older versions)
if truncate_types and isinstance(statement, truncate_types):
return False
if (
isinstance(statement, exp.Command)
and statement.this
and statement.this.upper() == "TRUNCATE"
):
return False

# Check for DELETE without WHERE clause
Expand All @@ -928,12 +960,15 @@ def no_destructive_sql(request: ExecutionRequest) -> bool:
if statement.find(exp.Where) is None:
return False

# Check for ALTER statements
if isinstance(statement, exp.AlterTable):
# Check for ALTER statements (exp.Alter / exp.AlterTable
# depending on the installed sqlglot version)
if alter_types and isinstance(statement, alter_types):
return False

# Check for GRANT / REVOKE statements
if isinstance(statement, exp.Grant):
# Check for GRANT statements when sqlglot parses them as a
# dedicated ``Grant`` node. Older versions lack this class and
# parse GRANT as a Command, which is handled below.
if grant_types and isinstance(statement, grant_types):
return False

# Check for MERGE statements (can do INSERT/UPDATE/DELETE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,21 @@ def test_multiple_statements_checked(self, sql_policy):
request = self.make_sql_request("SELECT 1; DROP TABLE users;")
assert sql_policy.validator(request) is False

@pytest.mark.skipif(not SQLGLOT_AVAILABLE, reason="Requires sqlglot for AST parsing")
def test_no_attributeerror_on_installed_sqlglot(self, sql_policy):
"""Regression: the validator must not reference sqlglot expression
symbols that are absent on the installed version.

sqlglot renamed ``AlterTable`` -> ``Alter`` and added ``Grant`` in
later releases, so referencing both ``exp.AlterTable`` and ``exp.Grant``
in one function raised AttributeError for every query on any single
pinned version. This guards the version-robust getattr resolution.
"""
# A plain SELECT exercises every isinstance branch without matching
# any of them, so it fails on the missing-symbol AttributeError.
request = self.make_sql_request("SELECT * FROM users")
assert sql_policy.validator(request) is True


class TestSQLPolicyFallback:
"""Test the fallback SQL check when sqlglot is not available."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,15 @@ def get_decision(self, request_id: str) -> EscalationRequest | None:
def approve(self, request_id: str, approver: str = "") -> bool:
with self._lock:
req = self._requests.get(request_id)
if req is None or req.decision != EscalationDecision.PENDING:
if req is None:
return False
if not approver.strip():
return False
if any(a == approver for a, _, _ in req.votes):
return False
req.votes.append((approver, "ALLOW", datetime.now(timezone.utc)))
if req.decision != EscalationDecision.PENDING:
return True
req.decision = EscalationDecision.ALLOW
req.resolved_by = approver
req.resolved_at = datetime.now(timezone.utc)
Expand All @@ -179,8 +186,15 @@ def approve(self, request_id: str, approver: str = "") -> bool:
def deny(self, request_id: str, approver: str = "") -> bool:
with self._lock:
req = self._requests.get(request_id)
if req is None or req.decision != EscalationDecision.PENDING:
if req is None:
return False
if not approver.strip():
return False
if any(a == approver for a, _, _ in req.votes):
return False
req.votes.append((approver, "DENY", datetime.now(timezone.utc)))
if req.decision != EscalationDecision.PENDING:
return True
req.decision = EscalationDecision.DENY
req.resolved_by = approver
req.resolved_at = datetime.now(timezone.utc)
Expand Down
116 changes: 110 additions & 6 deletions agent-governance-python/agent-os/tests/test_escalation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
EscalationRequest,
EscalationResult,
InMemoryApprovalQueue,
QuorumConfig,
)


Expand Down Expand Up @@ -74,24 +75,58 @@ def test_deny(self):
retrieved = queue.get_decision(req.request_id)
assert retrieved.decision == EscalationDecision.DENY

def test_double_approve_fails(self):
def test_double_approve_same_approver_rejected(self):
queue = InMemoryApprovalQueue()
req = EscalationRequest(agent_id="a1", action="x", reason="r")
queue.submit(req)
assert queue.approve(req.request_id) is True
assert queue.approve(req.request_id) is False # Already resolved
assert queue.approve(req.request_id, approver="admin") is True
assert queue.approve(req.request_id, approver="admin") is False # duplicate vote

def test_second_approver_vote_recorded(self):
queue = InMemoryApprovalQueue()
req = EscalationRequest(agent_id="a1", action="x", reason="r")
queue.submit(req)
assert queue.approve(req.request_id, approver="admin-a") is True
assert queue.approve(req.request_id, approver="admin-b") is True
retrieved = queue.get_decision(req.request_id)
assert len(retrieved.votes) == 2
approvers = [a for a, _, _ in retrieved.votes]
assert "admin-a" in approvers
assert "admin-b" in approvers

def test_votes_recorded_on_approve(self):
queue = InMemoryApprovalQueue()
req = EscalationRequest(agent_id="a1", action="x", reason="r")
queue.submit(req)
queue.approve(req.request_id, approver="reviewer-1")
retrieved = queue.get_decision(req.request_id)
assert len(retrieved.votes) == 1
approver, verdict, _ = retrieved.votes[0]
assert approver == "reviewer-1"
assert verdict == "ALLOW"

def test_votes_recorded_on_deny(self):
queue = InMemoryApprovalQueue()
req = EscalationRequest(agent_id="a1", action="x", reason="r")
queue.submit(req)
queue.deny(req.request_id, approver="sec-team")
retrieved = queue.get_decision(req.request_id)
assert len(retrieved.votes) == 1
approver, verdict, _ = retrieved.votes[0]
assert approver == "sec-team"
assert verdict == "DENY"

def test_approve_nonexistent(self):
queue = InMemoryApprovalQueue()
assert queue.approve("nonexistent") is False
assert queue.approve("nonexistent", approver="admin") is False

def test_list_pending(self):
queue = InMemoryApprovalQueue()
r1 = EscalationRequest(agent_id="a1", action="x", reason="r")
r2 = EscalationRequest(agent_id="a2", action="y", reason="s")
queue.submit(r1)
queue.submit(r2)
queue.approve(r1.request_id)
queue.approve(r1.request_id, approver="admin")
pending = queue.list_pending()
assert len(pending) == 1
assert pending[0].request_id == r2.request_id
Expand Down Expand Up @@ -133,7 +168,7 @@ def test_resolve_with_approval(self):

def approve():
time.sleep(0.1)
queue.approve(request.request_id)
queue.approve(request.request_id, approver="admin")

t = threading.Thread(target=approve)
t.start()
Expand Down Expand Up @@ -246,3 +281,72 @@ def test_custom_fields(self):
)
assert req.agent_id == "a1"
assert req.action == "deploy"


class TestQuorumResolution:
def test_quorum_met_resolves_allow(self):
queue = InMemoryApprovalQueue()
handler = EscalationHandler(
backend=queue,
timeout_seconds=5,
quorum=QuorumConfig(required_approvals=1, total_approvers=1),
)
request = handler.escalate("agent-1", "action", "reason")

def approve():
time.sleep(0.05)
queue.approve(request.request_id, approver="reviewer-1")

t = threading.Thread(target=approve)
t.start()
decision = handler.resolve(request.request_id)
t.join()
assert decision == EscalationDecision.ALLOW

def test_quorum_not_met_times_out(self):
queue = InMemoryApprovalQueue()
handler = EscalationHandler(
backend=queue,
timeout_seconds=0.2,
default_action=DefaultTimeoutAction.DENY,
quorum=QuorumConfig(required_approvals=2, total_approvers=3),
)
request = handler.escalate("agent-1", "action", "reason")
queue.approve(request.request_id, approver="reviewer-1")
decision = handler.resolve(request.request_id)
assert decision == EscalationDecision.DENY

def test_duplicate_approver_does_not_satisfy_quorum(self):
queue = InMemoryApprovalQueue()
handler = EscalationHandler(
backend=queue,
timeout_seconds=0.2,
default_action=DefaultTimeoutAction.DENY,
quorum=QuorumConfig(required_approvals=2, total_approvers=3),
)
request = handler.escalate("agent-1", "action", "reason")
queue.approve(request.request_id, approver="reviewer-1")
result = queue.approve(request.request_id, approver="reviewer-1")
assert result is False
retrieved = queue.get_decision(request.request_id)
assert len(retrieved.votes) == 1

def test_empty_approver_rejected_on_approve(self):
queue = InMemoryApprovalQueue()
req = EscalationRequest(agent_id="a1", action="x", reason="r")
queue.submit(req)
assert queue.approve(req.request_id) is False
assert queue.approve(req.request_id, approver="") is False
assert queue.approve(req.request_id, approver=" ") is False
retrieved = queue.get_decision(req.request_id)
assert len(retrieved.votes) == 0

def test_empty_approver_rejected_on_deny(self):
queue = InMemoryApprovalQueue()
req = EscalationRequest(agent_id="a1", action="x", reason="r")
queue.submit(req)
assert queue.deny(req.request_id) is False
assert queue.deny(req.request_id, approver="") is False
assert queue.deny(req.request_id, approver=" ") is False
retrieved = queue.get_decision(req.request_id)
assert len(retrieved.votes) == 0
Original file line number Diff line number Diff line change
Expand Up @@ -457,9 +457,6 @@ def test_single_approval_insufficient_for_quorum(self):
request = handler.escalate("agent-1", "deploy", "needs review")
# One approval — not enough for quorum of 2
queue.approve(request.request_id, approver="admin1")
# Manually add vote tracking
req = queue.get_decision(request.request_id)
req.votes.append(("admin1", "ALLOW", req.resolved_at))
decision = handler.resolve(request.request_id)
# With only 1 vote and quorum=2, should timeout-deny
assert decision == EscalationDecision.DENY
Expand All @@ -474,9 +471,7 @@ def test_quorum_met_with_enough_approvals(self):
)
request = handler.escalate("agent-1", "deploy", "needs review")
queue.approve(request.request_id, approver="admin1")
req = queue.get_decision(request.request_id)
req.votes.append(("admin1", "ALLOW", req.resolved_at))
req.votes.append(("admin2", "ALLOW", req.resolved_at))
queue.approve(request.request_id, approver="admin2")
decision = handler.resolve(request.request_id)
assert decision == EscalationDecision.ALLOW

Expand All @@ -490,8 +485,6 @@ def test_quorum_deny_on_single_denial(self):
)
request = handler.escalate("agent-1", "deploy", "needs review")
queue.deny(request.request_id, approver="sec-team")
req = queue.get_decision(request.request_id)
req.votes.append(("sec-team", "DENY", req.resolved_at))
decision = handler.resolve(request.request_id)
assert decision == EscalationDecision.DENY

Expand Down
Loading