Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/dippy/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,20 @@
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Literal, Optional, Protocol
from typing import TYPE_CHECKING, Literal, Optional, Protocol

if TYPE_CHECKING:
from dippy.core.config import Config


@dataclass(frozen=True)
class HandlerContext:
    """Context passed to handlers.

    The dataclass is frozen, so fields cannot be reassigned after
    construction. Only ``tokens`` is required; the remaining fields have
    defaults so a context can be built from the token list alone.
    """

    # Command words for the invocation being classified.
    tokens: list[str]

    # Optional configuration made available to the handler. The annotation is
    # quoted because ``Config`` is imported only under ``TYPE_CHECKING``; an
    # unquoted ``Config | None`` would be evaluated when the class is created
    # and fail with NameError unless lazy annotations are enabled.
    config: "Config | None" = None

    word_has_expansions: tuple[bool, ...] = ()
    """Per-token flag: True if the original word contained bash expansions ($VAR, $(cmd), etc.)."""
    # NOTE(review): presumably index-aligned with ``tokens``; verify against
    # the caller that builds this tuple. An empty tuple (the default) means no
    # expansion information was supplied.


@dataclass(frozen=True)
Expand Down
83 changes: 70 additions & 13 deletions src/dippy/cli/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,9 +460,20 @@ class SafetyAnalyzer(ast.NodeVisitor):
are allowed. Anything unknown is flagged.
"""

def __init__(self, allow_print: bool = True):
def __init__(
    self,
    allow_print: bool = True,
    extra_safe_modules: frozenset[str] = frozenset(),
    extra_deny_modules: frozenset[str] = frozenset(),
):
    """Set up an analyzer with optional caller-supplied module lists.

    Args:
        allow_print: Stored on the instance as ``allow_print``.
        extra_safe_modules: Module names added to the built-in safe set.
        extra_deny_modules: Module names added to the built-in dangerous set.

    ``extra_safe_modules`` is subtracted from the combined deny set last, so
    an exact name listed there is removed from the deny set whether it came
    from the built-in dangerous list or from ``extra_deny_modules``. Only
    exact names are removed; submodules must be allowed separately.
    """
    # Deny set: built-in plus extra denials, minus exact explicit allowances.
    denied = DANGEROUS_MODULES | extra_deny_modules
    self.deny_modules = denied - extra_safe_modules
    # Safe set: built-in safe modules widened by the extras.
    self.safe_modules = SAFE_MODULES | extra_safe_modules
    self.allow_print = allow_print
    self.violations: list[Violation] = []

def _add(self, node: ast.AST, kind: str, detail: str) -> None:
self.violations.append(
Expand All @@ -476,9 +487,9 @@ def visit_Import(self, node: ast.Import) -> None:
module = alias.name
root = module.split(".")[0]

if module in DANGEROUS_MODULES or root in DANGEROUS_MODULES:
if module in self.deny_modules or root in self.deny_modules:
self._add(node, "import", f"dangerous module: {module}")
elif module not in SAFE_MODULES and root not in SAFE_MODULES:
elif module not in self.safe_modules and root not in self.safe_modules:
self._add(node, "import", f"unknown module: {module}")

self.generic_visit(node)
Expand All @@ -491,9 +502,9 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
module = node.module
root = module.split(".")[0]

if module in DANGEROUS_MODULES or root in DANGEROUS_MODULES:
if module in self.deny_modules or root in self.deny_modules:
self._add(node, "import", f"dangerous module: {module}")
elif module not in SAFE_MODULES and root not in SAFE_MODULES:
elif module not in self.safe_modules and root not in self.safe_modules:
self._add(node, "import", f"unknown module: {module}")

self.generic_visit(node)
Expand Down Expand Up @@ -616,7 +627,12 @@ def visit_Try(self, node: ast.Try) -> None:
self.generic_visit(node)


def analyze_python_source(source: str, allow_print: bool = True) -> list[Violation]:
def analyze_python_source(
source: str,
allow_print: bool = True,
extra_safe_modules: frozenset[str] = frozenset(),
extra_deny_modules: frozenset[str] = frozenset(),
) -> list[Violation]:
"""
Analyze Python source code for safety violations.

Expand All @@ -627,12 +643,20 @@ def analyze_python_source(source: str, allow_print: bool = True) -> list[Violati
except SyntaxError as e:
return [Violation(e.lineno or 0, e.offset or 0, "syntax", str(e))]

analyzer = SafetyAnalyzer(allow_print=allow_print)
analyzer = SafetyAnalyzer(
allow_print=allow_print,
extra_safe_modules=extra_safe_modules,
extra_deny_modules=extra_deny_modules,
)
analyzer.visit(tree)
return analyzer.violations


def analyze_python_file(path: Path) -> tuple[bool, str]:
def analyze_python_file(
path: Path,
extra_safe_modules: frozenset[str] = frozenset(),
extra_deny_modules: frozenset[str] = frozenset(),
) -> tuple[bool, str]:
"""
Analyze a Python file for safety.

Expand Down Expand Up @@ -662,7 +686,11 @@ def analyze_python_file(path: Path) -> tuple[bool, str]:
except (OSError, UnicodeDecodeError) as e:
return False, f"cannot read file: {e}"

violations = analyze_python_source(source)
violations = analyze_python_source(
source,
extra_safe_modules=extra_safe_modules,
extra_deny_modules=extra_deny_modules,
)

if violations:
# Return first violation as reason
Expand Down Expand Up @@ -767,16 +795,22 @@ def classify(ctx: HandlerContext) -> Classification:

Auto-approves:
- Version/help flags
- -c inline code that passes static analysis (no bash expansions)
- Scripts that pass static analysis (no I/O, no dangerous imports)

Requires confirmation:
- -c (inline code)
- -c inline code that fails analysis or contains bash expansions
- -m (module execution)
- Scripts that fail analysis or can't be read
- Interactive mode
"""
tokens = ctx.tokens
cwd = Path.cwd()
config = ctx.config

# Build extra module sets from config
extra_safe = frozenset(config.python_allow_modules) if config else frozenset()
extra_deny = frozenset(config.python_deny_modules) if config else frozenset()

desc = get_description(tokens)

Expand All @@ -789,9 +823,30 @@ def classify(ctx: HandlerContext) -> Classification:
if token in SAFE_FLAGS:
return Classification("allow", description=desc)

# Check for -c (inline code) - too hard to analyze reliably
# Check for -c (inline code) - analyze if possible
if "-c" in tokens:
return Classification("ask", description=desc)
idx = tokens.index("-c")
if idx + 1 >= len(tokens):
return Classification("ask", description=desc)
# If the -c argument contains bash expansions ($VAR, $(cmd), etc.),
# we can't reliably analyze it since bash modifies the code at runtime.
code_token_idx = idx + 1
if (
ctx.word_has_expansions
and code_token_idx < len(ctx.word_has_expansions)
and ctx.word_has_expansions[code_token_idx]
):
return Classification("ask", description=f"{desc} (bash expansion)")
code = tokens[code_token_idx]
if not code.strip():
return Classification("ask", description=desc)
violations = analyze_python_source(
code, extra_safe_modules=extra_safe, extra_deny_modules=extra_deny
)
if not violations:
return Classification("allow", description=f"{desc} (analyzed)")
v = violations[0]
return Classification("ask", description=f"{desc}: {v.kind}: {v.detail}")

# Check for -m (module) - could run arbitrary code
if "-m" in tokens:
Expand All @@ -818,7 +873,9 @@ def classify(ctx: HandlerContext) -> Classification:
return Classification("ask", description=desc)

# Try to analyze the script
is_safe, reason = analyze_python_file(script_path)
is_safe, reason = analyze_python_file(
script_path, extra_safe_modules=extra_safe, extra_deny_modules=extra_deny
)

if is_safe:
return Classification("allow", description=f"{desc} (analyzed)")
Expand Down
23 changes: 19 additions & 4 deletions src/dippy/core/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ def _analyze_command(

# Get base command for injection check
words = [_get_word_value(w) for w in node.words]
# Track which words contain bash expansions (param, cmdsub, procsub)
word_has_expansions = tuple(bool(getattr(w, "parts", [])) for w in node.words)
# Skip env var assignments to find base command
base_idx = 0
while (
Expand Down Expand Up @@ -287,7 +289,9 @@ def _analyze_command(
and position > base_idx
):
handler = get_handler(base)
outer_result = handler.classify(HandlerContext(words[base_idx:]))
outer_result = handler.classify(
HandlerContext(words[base_idx:], config=config)
)
if outer_result.action != "allow":
inner_cmd = _get_word_value(word).strip("$()")
return Decision("ask", f"cmdsub injection risk: {inner_cmd}")
Expand Down Expand Up @@ -319,7 +323,9 @@ def _analyze_command(
decisions.append(Decision("allow", "conditional test"))
return _combine(decisions)

cmd_decision = _analyze_simple_command(words, config, cwd, remote=remote)
cmd_decision = _analyze_simple_command(
words, config, cwd, remote=remote, word_has_expansions=word_has_expansions
)
decisions.append(cmd_decision)

return _combine(decisions)
Expand Down Expand Up @@ -382,7 +388,12 @@ def _analyze_redirects(


def _analyze_simple_command(
words: list[str], config: Config, cwd: Path, *, remote: bool = False
words: list[str],
config: Config,
cwd: Path,
*,
remote: bool = False,
word_has_expansions: tuple[bool, ...] = (),
) -> Decision:
"""Analyze a simple command (list of words)."""
if not words:
Expand Down Expand Up @@ -448,7 +459,11 @@ def _analyze_simple_command(
# 5. CLI-specific handlers
handler = get_handler(base)
if handler:
result = handler.classify(HandlerContext(tokens))
result = handler.classify(
HandlerContext(
tokens, config=config, word_has_expansions=word_has_expansions
)
)
desc = result.description or get_description(tokens, base)
# Check handler-provided redirect targets against config (skip in remote mode)
if result.redirect_targets and not remote:
Expand Down
45 changes: 45 additions & 0 deletions src/dippy/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,30 @@
from dataclasses import dataclass, field, replace
from pathlib import Path

# Valid Python module path: dotted identifiers (e.g. "numpy", "http.server")
_MODULE_RE = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)*$")


def _parse_module_name(rest: str) -> str:
"""Parse and validate a Python module name from a directive argument.

Strips inline comments (# ...) and validates the module name.
Raises ValueError if the name is missing, has extra words, or is invalid.
"""
# Strip inline comments
if "#" in rest:
rest = rest[: rest.index("#")].rstrip()
if not rest:
raise ValueError("requires a module name")
parts = rest.split()
if len(parts) != 1:
raise ValueError(f"requires exactly one module name, got: {rest!r}")
mod = parts[0]
if not _MODULE_RE.match(mod):
raise ValueError(f"invalid Python module name: {mod!r}")
return mod


# Cache home directory at module load - fails fast if HOME is unset
_HOME = Path.home()

Expand Down Expand Up @@ -62,6 +86,12 @@ class Config:
aliases: dict[str, str] = field(default_factory=dict)
"""Command aliases mapping source to target (e.g., ~/bin/gh -> gh)."""

python_allow_modules: list[str] = field(default_factory=list)
"""Extra modules to treat as safe for Python static analysis."""

python_deny_modules: list[str] = field(default_factory=list)
"""Extra modules to treat as dangerous for Python static analysis."""

default: str = "ask" # 'allow' | 'ask'
log: Path | None = None # None = no logging
log_full: bool = False # log full command (requires log path)
Expand Down Expand Up @@ -122,6 +152,9 @@ def _merge_configs(base: Config, overlay: Config) -> Config:
after_mcp_rules=base.after_mcp_rules + overlay.after_mcp_rules,
# Aliases: overlay wins for conflicting keys
aliases={**base.aliases, **overlay.aliases},
# Python module lists accumulate
python_allow_modules=base.python_allow_modules + overlay.python_allow_modules,
python_deny_modules=base.python_deny_modules + overlay.python_deny_modules,
# Settings: overlay wins if set
default=overlay.default if overlay.default != "ask" else base.default,
log=overlay.log if overlay.log is not None else base.log,
Expand Down Expand Up @@ -209,6 +242,8 @@ def parse_config(text: str, source: str | None = None) -> Config:
mcp_rules: list[Rule] = []
after_mcp_rules: list[Rule] = []
aliases: dict[str, str] = {}
python_allow_modules: list[str] = []
python_deny_modules: list[str] = []
settings: dict[str, bool | int | str | Path] = {}
prefix = f"{source}: " if source else ""

Expand Down Expand Up @@ -321,6 +356,14 @@ def parse_config(text: str, source: str | None = None) -> Config:
)
aliases[expanded_source] = alias_target

elif directive == "python-allow-module":
mod = _parse_module_name(rest)
python_allow_modules.append(mod)

elif directive == "python-deny-module":
mod = _parse_module_name(rest)
python_deny_modules.append(mod)

elif directive == "set":
_apply_setting(settings, rest)

Expand All @@ -337,6 +380,8 @@ def parse_config(text: str, source: str | None = None) -> Config:
mcp_rules=mcp_rules,
after_mcp_rules=after_mcp_rules,
aliases=aliases,
python_allow_modules=python_allow_modules,
python_deny_modules=python_deny_modules,
default=settings.get("default", "ask"),
log=settings.get("log"),
log_full=settings.get("log_full", False),
Expand Down
Loading
Loading