From 857b37718e36a90deb079d3f70f0494322aa3e67 Mon Sep 17 00:00:00 2001 From: Amine Date: Fri, 11 Jul 2025 20:19:51 +0100 Subject: [PATCH 1/2] =?UTF-8?q?feat:=20add=20AI=20module=20for=20LLM=20int?= =?UTF-8?q?eraction=20and=20a=20=20heuristic=20for=20checking=20code?= =?UTF-8?q?=E2=80=93docstring=20consistency?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Amine --- src/macaron/ai.py | 175 ++++++++++++++++++ src/macaron/config/defaults.ini | 14 ++ .../pypi_heuristics/heuristics.py | 3 + .../sourcecode/matching_docstrings.py | 101 ++++++++++ .../slsa_analyzer/build_tool/gradle.py | 4 +- src/macaron/slsa_analyzer/build_tool/maven.py | 4 +- src/macaron/slsa_analyzer/build_tool/pip.py | 4 +- .../slsa_analyzer/build_tool/poetry.py | 4 +- .../checks/detect_malicious_metadata_check.py | 7 + .../pypi/test_matching_docstrings.py | 103 +++++++++++ 10 files changed, 411 insertions(+), 8 deletions(-) create mode 100644 src/macaron/ai.py create mode 100644 src/macaron/malware_analyzer/pypi_heuristics/sourcecode/matching_docstrings.py create mode 100644 tests/malware_analyzer/pypi/test_matching_docstrings.py diff --git a/src/macaron/ai.py b/src/macaron/ai.py new file mode 100644 index 000000000..eb48ba08b --- /dev/null +++ b/src/macaron/ai.py @@ -0,0 +1,175 @@ +# Copyright (c) 2024 - 2025, Oracle and/or its affiliates. All rights reserved. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/. + +"""This module provides a client for interacting with a Large Language Model (LLM).""" + +import json +import logging +import re +from typing import Any, TypeVar + +from pydantic import BaseModel, ValidationError + +from macaron.config.defaults import defaults +from macaron.errors import ConfigurationError, HeuristicAnalyzerValueError +from macaron.util import send_post_http_raw + +logger: logging.Logger = logging.getLogger(__name__) + +T = TypeVar("T", bound=BaseModel) + + +class AIClient: + """A client for interacting with a Large Language Model.""" + + def __init__(self, system_prompt: str): + """ + Initialize the AI client. + + The LLM configuration (enabled, API key, endpoint, model) is read from defaults. + """ + self.enabled, self.api_endpoint, self.api_key, self.model, self.context_window = self._load_defaults() + self.system_prompt = system_prompt.strip() or "You are a helpful AI assistant." 
+ logger.info("AI client is %s.", "enabled" if self.enabled else "disabled") + + def _load_defaults(self) -> tuple[bool, str, str, str, int]: + """Load the LLM configuration from the defaults.""" + section_name = "llm" + enabled, api_key, api_endpoint, model, context_window = False, "", "", "", 10000 + + if defaults.has_section(section_name): + section = defaults[section_name] + enabled = section.get("enabled", "False").strip().lower() == "true" + api_key = section.get("api_key", "").strip() + api_endpoint = section.get("api_endpoint", "").strip() + model = section.get("model", "").strip() + context_window = section.getint("context_window", 10000) + + if enabled: + if not api_key: + raise ConfigurationError("API key for the AI client is not configured.") + if not api_endpoint: + raise ConfigurationError("API endpoint for the AI client is not configured.") + if not model: + raise ConfigurationError("Model for the AI client is not configured.") + + return enabled, api_endpoint, api_key, model, context_window + + def _validate_response(self, response_text: str, response_model: type[T]) -> T: + """ + Validate and parse the response from the LLM. + + If raw JSON parsing fails, attempts to extract a JSON object from text. + + Parameters + ---------- + response_text: str + The response text from the LLM. + response_model: Type[T] + The Pydantic model to validate the response against. + + Returns + ------- + bool + The validated Pydantic model instance. + + Raises + ------ + HeuristicAnalyzerValueError + If there is an error in parsing or validating the response. + """ + try: + data = json.loads(response_text) + except json.JSONDecodeError: + logger.debug("Full JSON parse failed; trying to extract JSON from text.") + # If the response is not a valid JSON, try to extract a JSON object from the text. + match = re.search(r"\{.*\}", response_text, re.DOTALL) + if not match: + raise HeuristicAnalyzerValueError("No JSON object found in the LLM response.") from match + try: + data = json.loads(match.group(0)) + except json.JSONDecodeError as e: + logger.error("Failed to parse extracted JSON: %s", e) + raise HeuristicAnalyzerValueError("Invalid JSON extracted from response.") from e + + try: + return response_model.model_validate(data) + except ValidationError as e: + logger.error("Validation failed against response model: %s", e) + raise HeuristicAnalyzerValueError("Response JSON validation failed.") from e + + def invoke( + self, + user_prompt: str, + temperature: float = 0.2, + max_tokens: int = 4000, + structured_output: type[T] | None = None, + timeout: int = 30, + ) -> Any: + """ + Invoke the LLM and optionally validate its response. + + Parameters + ---------- + user_prompt: str + The user prompt to send to the LLM. + temperature: float + The temperature for the LLM response. + max_tokens: int + The maximum number of tokens for the LLM response. + structured_output: Optional[Type[T]] + The Pydantic model to validate the response against. If provided, the response will be parsed and validated. + timeout: int + The timeout for the HTTP request in seconds. + + Returns + ------- + Optional[T | str] + The validated Pydantic model instance if `structured_output` is provided, + or the raw string response if not. + + Raises + ------ + HeuristicAnalyzerValueError + If there is an error in parsing or validating the response. + """ + if not self.enabled: + raise ConfigurationError("AI client is not enabled. 
Please check your configuration.") + + if len(user_prompt.split()) > self.context_window: + logger.warning( + "User prompt exceeds context window (%s words). " + "Truncating the prompt to fit within the context window.", + self.context_window, + ) + user_prompt = " ".join(user_prompt.split()[: self.context_window]) + + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} + payload = { + "model": self.model, + "messages": [{"role": "system", "content": self.system_prompt}, {"role": "user", "content": user_prompt}], + "temperature": temperature, + "max_tokens": max_tokens, + } + + try: + response = send_post_http_raw(url=self.api_endpoint, json_data=payload, headers=headers, timeout=timeout) + if not response: + raise HeuristicAnalyzerValueError("No response received from the LLM.") + response_json = response.json() + usage = response_json.get("usage", {}) + + if usage: + usage_str = ", ".join(f"{key} = {value}" for key, value in usage.items()) + logger.info("LLM call token usage: %s", usage_str) + + message_content = response_json["choices"][0]["message"]["content"] + + if not structured_output: + logger.debug("Returning raw message content (no structured output requested).") + return message_content + return self._validate_response(message_content, structured_output) + + except Exception as e: + logger.error("Error during LLM invocation: %s", e) + raise HeuristicAnalyzerValueError(f"Failed to get or validate LLM response: {e}") from e diff --git a/src/macaron/config/defaults.ini b/src/macaron/config/defaults.ini index 0c31aaca7..fd7762065 100644 --- a/src/macaron/config/defaults.ini +++ b/src/macaron/config/defaults.ini @@ -635,3 +635,17 @@ custom_semgrep_rules_path = # .yaml prefix. Note, this will be ignored if a path to custom semgrep rules is not provided. This list may not contain # duplicated elements, meaning that ruleset names must be unique. disabled_custom_rulesets = + +[llm] +# The LLM configuration for Macaron. +# If enabled, the LLM will be used to analyze the results and provide insights. +enabled = +# The API key for the LLM service. +api_key = +# The API endpoint for the LLM service. +api_endpoint = +# The model to use for the LLM service. +model = +# The context window size for the LLM service. +# This is the maximum number of tokens that the LLM can process in a single request. +context_window = 10000 diff --git a/src/macaron/malware_analyzer/pypi_heuristics/heuristics.py b/src/macaron/malware_analyzer/pypi_heuristics/heuristics.py index eebce5764..d0b21f5ab 100644 --- a/src/macaron/malware_analyzer/pypi_heuristics/heuristics.py +++ b/src/macaron/malware_analyzer/pypi_heuristics/heuristics.py @@ -43,6 +43,9 @@ class Heuristics(str, Enum): #: Indicates that the package source code contains suspicious code patterns. SUSPICIOUS_PATTERNS = "suspicious_patterns" + #: Indicates that the package contains some code that doesn't match the docstrings. + MATCHING_DOCSTRINGS = "matching_docstrings" + class HeuristicResult(str, Enum): """Result type indicating the outcome of a heuristic.""" diff --git a/src/macaron/malware_analyzer/pypi_heuristics/sourcecode/matching_docstrings.py b/src/macaron/malware_analyzer/pypi_heuristics/sourcecode/matching_docstrings.py new file mode 100644 index 000000000..ca9cafbe3 --- /dev/null +++ b/src/macaron/malware_analyzer/pypi_heuristics/sourcecode/matching_docstrings.py @@ -0,0 +1,101 @@ +# Copyright (c) 2024 - 2025, Oracle and/or its affiliates. All rights reserved. 
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.
+
+"""This analyzer checks the consistency of code with its docstrings."""
+
+import logging
+import time
+from typing import Literal
+
+from pydantic import BaseModel, Field
+
+from macaron.ai import AIClient
+from macaron.json_tools import JsonType
+from macaron.malware_analyzer.pypi_heuristics.base_analyzer import BaseHeuristicAnalyzer
+from macaron.malware_analyzer.pypi_heuristics.heuristics import HeuristicResult, Heuristics
+from macaron.slsa_analyzer.package_registry.pypi_registry import PyPIPackageJsonAsset
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+class Result(BaseModel):
+    """The result after analysing the code with its docstrings."""
+
+    decision: Literal["consistent", "inconsistent"] = Field(
+        description="""The final decision after analysing the code with its docstrings.
+        It can be either 'consistent' or 'inconsistent'."""
+    )
+    reason: str = Field(
+        description="The reason for the decision made. It should be a short sentence explaining the decision."
+    )
+    inconsistent_code_part: str | None = Field(
+        default=None,
+        description="""The specific part of the code that is inconsistent with the docstring.
+        None if the decision is 'consistent'.""",
+    )
+
+
+class MatchingDocstringsAnalyzer(BaseHeuristicAnalyzer):
+    """Check whether the docstrings and the code components are consistent."""
+
+    SYSTEM_PROMPT = """
+    You are a code expert who can detect inconsistencies between code and the docstrings that describe its components.
+    You will be given a Python code file. Your task is to determine whether the code is consistent with the docstrings.
+    Wrap the output in `json` tags.
+    Your response must be a JSON object matching this schema:
+    {
+        "decision": "'consistent' or 'inconsistent'",
+        "reason": "A short explanation.",
+        "inconsistent_code_part": "The inconsistent code, or null."
+    }
+
+    /no_think
+    """
+
+    REQUEST_INTERVAL = 0.5
+
+    def __init__(self) -> None:
+        super().__init__(
+            name="matching_docstrings_analyzer",
+            heuristic=Heuristics.MATCHING_DOCSTRINGS,
+            depends_on=None,
+        )
+        self.client = AIClient(system_prompt=self.SYSTEM_PROMPT.strip())
+
+    def analyze(self, pypi_package_json: PyPIPackageJsonAsset) -> tuple[HeuristicResult, dict[str, JsonType]]:
+        """Analyze the package.
+
+        Parameters
+        ----------
+        pypi_package_json: PyPIPackageJsonAsset
+            The PyPI package JSON asset object.
+
+        Returns
+        -------
+        tuple[HeuristicResult, dict[str, JsonType]]:
+            The result and related information collected during the analysis.
+        """
+        if not self.client.enabled:
+            logger.warning("AI client is not enabled, skipping the matching docstrings analysis.")
+            return HeuristicResult.SKIP, {}
+
+        download_result = pypi_package_json.download_sourcecode()
+        if not download_result:
+            logger.warning("No source code found for the package, skipping the matching docstrings analysis.")
+            return HeuristicResult.SKIP, {}
+
+        for file, content in pypi_package_json.iter_sourcecode():
+            if file.endswith(".py"):
+                time.sleep(self.REQUEST_INTERVAL)  # Respect the request interval to avoid rate limiting.
+ code_str = content.decode("utf-8", "ignore") + analysis_result = self.client.invoke( + user_prompt=code_str, + structured_output=Result, + ) + if analysis_result and analysis_result.decision == "inconsistent": + return HeuristicResult.FAIL, { + "file": file, + "reason": analysis_result.reason, + "inconsistent part": analysis_result.inconsistent_code_part or "", + } + return HeuristicResult.PASS, {} diff --git a/src/macaron/slsa_analyzer/build_tool/gradle.py b/src/macaron/slsa_analyzer/build_tool/gradle.py index 2cc491934..607e98579 100644 --- a/src/macaron/slsa_analyzer/build_tool/gradle.py +++ b/src/macaron/slsa_analyzer/build_tool/gradle.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 - 2024, Oracle and/or its affiliates. All rights reserved. +# Copyright (c) 2022 - 2025, Oracle and/or its affiliates. All rights reserved. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/. """This module contains the Gradle class which inherits BaseBuildTool. @@ -122,7 +122,7 @@ def get_dep_analyzer(self) -> CycloneDxGradle: raise DependencyAnalyzerError("No default dependency analyzer is found.") if not DependencyAnalyzer.tool_valid(defaults.get("dependency.resolver", "dep_tool_gradle")): raise DependencyAnalyzerError( - f"Dependency analyzer {defaults.get('dependency.resolver','dep_tool_gradle')} is not valid.", + f"Dependency analyzer {defaults.get('dependency.resolver', 'dep_tool_gradle')} is not valid.", ) tool_name, tool_version = tuple( diff --git a/src/macaron/slsa_analyzer/build_tool/maven.py b/src/macaron/slsa_analyzer/build_tool/maven.py index 69323ad9c..e6c11c13e 100644 --- a/src/macaron/slsa_analyzer/build_tool/maven.py +++ b/src/macaron/slsa_analyzer/build_tool/maven.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 - 2024, Oracle and/or its affiliates. All rights reserved. +# Copyright (c) 2022 - 2025, Oracle and/or its affiliates. All rights reserved. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/. """This module contains the Maven class which inherits BaseBuildTool. @@ -116,7 +116,7 @@ def get_dep_analyzer(self) -> CycloneDxMaven: raise DependencyAnalyzerError("No default dependency analyzer is found.") if not DependencyAnalyzer.tool_valid(defaults.get("dependency.resolver", "dep_tool_maven")): raise DependencyAnalyzerError( - f"Dependency analyzer {defaults.get('dependency.resolver','dep_tool_maven')} is not valid.", + f"Dependency analyzer {defaults.get('dependency.resolver', 'dep_tool_maven')} is not valid.", ) tool_name, tool_version = tuple( diff --git a/src/macaron/slsa_analyzer/build_tool/pip.py b/src/macaron/slsa_analyzer/build_tool/pip.py index 5abf0c0ba..c0e970ab9 100644 --- a/src/macaron/slsa_analyzer/build_tool/pip.py +++ b/src/macaron/slsa_analyzer/build_tool/pip.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 - 2024, Oracle and/or its affiliates. All rights reserved. +# Copyright (c) 2023 - 2025, Oracle and/or its affiliates. All rights reserved. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/. """This module contains the Pip class which inherits BaseBuildTool. 
@@ -88,7 +88,7 @@ def get_dep_analyzer(self) -> DependencyAnalyzer:
         tool_name = "cyclonedx_py"
         if not DependencyAnalyzer.tool_valid(f"{tool_name}:{cyclonedx_version}"):
             raise DependencyAnalyzerError(
-                f"Dependency analyzer {defaults.get('dependency.resolver','dep_tool_gradle')} is not valid.",
+                f"Dependency analyzer {defaults.get('dependency.resolver', 'dep_tool_gradle')} is not valid.",
             )
         return CycloneDxPython(
             resources_path=global_config.resources_path,
diff --git a/src/macaron/slsa_analyzer/build_tool/poetry.py b/src/macaron/slsa_analyzer/build_tool/poetry.py
index eeb54216b..54e3899f1 100644
--- a/src/macaron/slsa_analyzer/build_tool/poetry.py
+++ b/src/macaron/slsa_analyzer/build_tool/poetry.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023 - 2024, Oracle and/or its affiliates. All rights reserved.
+# Copyright (c) 2023 - 2025, Oracle and/or its affiliates. All rights reserved.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.
 
 """This module contains the Poetry class which inherits BaseBuildTool.
@@ -126,7 +126,7 @@ def get_dep_analyzer(self) -> DependencyAnalyzer:
         tool_name = "cyclonedx_py"
         if not DependencyAnalyzer.tool_valid(f"{tool_name}:{cyclonedx_version}"):
             raise DependencyAnalyzerError(
-                f"Dependency analyzer {defaults.get('dependency.resolver','dep_tool_gradle')} is not valid.",
+                f"Dependency analyzer {defaults.get('dependency.resolver', 'dep_tool_gradle')} is not valid.",
             )
         return CycloneDxPython(
             resources_path=global_config.resources_path,
diff --git a/src/macaron/slsa_analyzer/checks/detect_malicious_metadata_check.py b/src/macaron/slsa_analyzer/checks/detect_malicious_metadata_check.py
index 8514a458d..5c87d6f27 100644
--- a/src/macaron/slsa_analyzer/checks/detect_malicious_metadata_check.py
+++ b/src/macaron/slsa_analyzer/checks/detect_malicious_metadata_check.py
@@ -26,6 +26,7 @@
 from macaron.malware_analyzer.pypi_heuristics.metadata.typosquatting_presence import TyposquattingPresenceAnalyzer
 from macaron.malware_analyzer.pypi_heuristics.metadata.unchanged_release import UnchangedReleaseAnalyzer
 from macaron.malware_analyzer.pypi_heuristics.metadata.wheel_absence import WheelAbsenceAnalyzer
+from macaron.malware_analyzer.pypi_heuristics.sourcecode.matching_docstrings import MatchingDocstringsAnalyzer
 from macaron.malware_analyzer.pypi_heuristics.sourcecode.pypi_sourcecode_analyzer import PyPISourcecodeAnalyzer
 from macaron.malware_analyzer.pypi_heuristics.sourcecode.suspicious_setup import SuspiciousSetupAnalyzer
 from macaron.slsa_analyzer.analyze_context import AnalyzeContext
@@ -358,6 +359,7 @@ def run_check(self, ctx: AnalyzeContext) -> CheckResultData:
         WheelAbsenceAnalyzer,
         AnomalousVersionAnalyzer,
         TyposquattingPresenceAnalyzer,
+        MatchingDocstringsAnalyzer,
     ]
 
     # name used to query the result of all problog rules, so it can be accessed outside the model.
@@ -425,6 +427,10 @@ def run_check(self, ctx: AnalyzeContext) -> CheckResultData:
             failed({Heuristics.ONE_RELEASE.value}),
             failed({Heuristics.ANOMALOUS_VERSION.value}).
 
+        % Package code is inconsistent with its docstrings.
+        {Confidence.MEDIUM.value}::trigger(malware_medium_confidence_3) :-
+            quickUndetailed, forceSetup, failed({Heuristics.MATCHING_DOCSTRINGS.value}).
+
         % ----- Evaluation -----
 
         % Aggregate result
 
         {problog_result_access} :- trigger(malware_high_confidence_1).
         {problog_result_access} :- trigger(malware_high_confidence_2).
         {problog_result_access} :- trigger(malware_high_confidence_3).
{problog_result_access} :- trigger(malware_high_confidence_4). + {problog_result_access} :- trigger(malware_medium_confidence_3). {problog_result_access} :- trigger(malware_medium_confidence_2). {problog_result_access} :- trigger(malware_medium_confidence_1). query({problog_result_access}). diff --git a/tests/malware_analyzer/pypi/test_matching_docstrings.py b/tests/malware_analyzer/pypi/test_matching_docstrings.py new file mode 100644 index 000000000..c427fa6f9 --- /dev/null +++ b/tests/malware_analyzer/pypi/test_matching_docstrings.py @@ -0,0 +1,103 @@ +# Copyright (c) 2024 - 2025, Oracle and/or its affiliates. All rights reserved. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/. + +"""Tests for the MatchingDocstringsAnalyzer heuristic.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from macaron.malware_analyzer.pypi_heuristics.heuristics import HeuristicResult +from macaron.malware_analyzer.pypi_heuristics.sourcecode.matching_docstrings import MatchingDocstringsAnalyzer, Result + + +@pytest.fixture(name="analyzer") +def analyzer_() -> MatchingDocstringsAnalyzer: + """Pytest fixture to create a MatchingDocstringsAnalyzer instance.""" + return MatchingDocstringsAnalyzer() + + +@pytest.fixture(autouse=True) +def skip_if_client_disabled(analyzer: MatchingDocstringsAnalyzer) -> None: + """ + Automatically skip tests in this file if the AI client is disabled. + """ + if not analyzer.client.enabled: + pytest.skip("AI client disabled - skipping test") + + +def test_analyze_consistent_docstrings_pass(analyzer: MatchingDocstringsAnalyzer, pypi_package_json: MagicMock) -> None: + """Test the analyzer passes when docstrings are consistent with the code.""" + pypi_package_json.download_sourcecode.return_value = True + pypi_package_json.iter_sourcecode.return_value = [("test.py", b"def func():\n '''docstring'''\n pass")] + + mock_result = Result(decision="consistent", reason="The code is consistent with the docstring.") + + with patch.object(analyzer.client, "invoke", return_value=mock_result) as mock_invoke: + result, info = analyzer.analyze(pypi_package_json) + assert result == HeuristicResult.PASS + assert not info + mock_invoke.assert_called_once() + + +def test_analyze_inconsistent_docstrings_fail( + analyzer: MatchingDocstringsAnalyzer, pypi_package_json: MagicMock +) -> None: + """Test the analyzer fails when docstrings are inconsistent with the code.""" + pypi_package_json.download_sourcecode.return_value = True + pypi_package_json.iter_sourcecode.return_value = [ + ("test.py", b"def func():\n '''docstring'''\n print('hello')") + ] + + mock_result = Result( + decision="inconsistent", + reason="The docstring does not mention the print statement.", + inconsistent_code_part="print('hello')", + ) + + with patch.object(analyzer.client, "invoke", return_value=mock_result): + result, info = analyzer.analyze(pypi_package_json) + assert result == HeuristicResult.FAIL + assert info["file"] == "test.py" + assert info["reason"] == "The docstring does not mention the print statement." 
+ assert info["inconsistent part"] == "print('hello')" + + +def test_analyze_ai_client_disabled_skip(analyzer: MatchingDocstringsAnalyzer, pypi_package_json: MagicMock) -> None: + """Test the analyzer skips when the AI client is disabled.""" + with patch.object(analyzer.client, "enabled", False): + result, info = analyzer.analyze(pypi_package_json) + assert result == HeuristicResult.SKIP + assert not info + + +def test_analyze_no_source_code_skip(analyzer: MatchingDocstringsAnalyzer, pypi_package_json: MagicMock) -> None: + """Test the analyzer skips if the source code cannot be downloaded.""" + pypi_package_json.download_sourcecode.return_value = False + with patch.object(analyzer.client, "invoke") as mock_invoke: + result, info = analyzer.analyze(pypi_package_json) + assert result == HeuristicResult.SKIP + assert not info + mock_invoke.assert_not_called() + + +def test_analyze_no_python_files_pass(analyzer: MatchingDocstringsAnalyzer, pypi_package_json: MagicMock) -> None: + """Test the analyzer passes if there are no Python files in the source code.""" + pypi_package_json.download_sourcecode.return_value = True + pypi_package_json.iter_sourcecode.return_value = [("README.md", b"This is a test package.")] + with patch.object(analyzer.client, "invoke") as mock_invoke: + result, info = analyzer.analyze(pypi_package_json) + assert result == HeuristicResult.PASS + assert not info + mock_invoke.assert_not_called() + + +def test_analyze_llm_invocation_error_pass(analyzer: MatchingDocstringsAnalyzer, pypi_package_json: MagicMock) -> None: + """Test the analyzer passes if the LLM invocation returns None (e.g., on API error).""" + pypi_package_json.download_sourcecode.return_value = True + pypi_package_json.iter_sourcecode.return_value = [("test.py", b"def func():\n pass")] + + with patch.object(analyzer.client, "invoke", return_value=None): + result, info = analyzer.analyze(pypi_package_json) + assert result == HeuristicResult.PASS + assert not info From 71e5504c97db464293b4f234b1d0513b37316a5d Mon Sep 17 00:00:00 2001 From: Amine Date: Thu, 24 Jul 2025 10:52:16 +0100 Subject: [PATCH 2/2] feat(ai): improve robustness of AI client Signed-off-by: Amine --- pyproject.toml | 2 + src/macaron/ai.py | 175 ------------------ src/macaron/ai/README.md | 50 +++++ src/macaron/ai/__init__.py | 2 + src/macaron/ai/ai_client.py | 53 ++++++ src/macaron/ai/ai_factory.py | 70 +++++++ src/macaron/ai/ai_tools.py | 53 ++++++ src/macaron/ai/openai_client.py | 100 ++++++++++ src/macaron/config/defaults.ini | 6 +- .../sourcecode/matching_docstrings.py | 13 +- .../pypi/test_matching_docstrings.py | 10 +- 11 files changed, 345 insertions(+), 189 deletions(-) delete mode 100644 src/macaron/ai.py create mode 100644 src/macaron/ai/README.md create mode 100644 src/macaron/ai/__init__.py create mode 100644 src/macaron/ai/ai_client.py create mode 100644 src/macaron/ai/ai_factory.py create mode 100644 src/macaron/ai/ai_tools.py create mode 100644 src/macaron/ai/openai_client.py diff --git a/pyproject.toml b/pyproject.toml index 74705364b..08acad323 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,8 @@ dependencies = [ "problog >= 2.2.6,<3.0.0", "cryptography >=44.0.0,<45.0.0", "semgrep == 1.113.0", + "pydantic >= 2.11.5,<2.12.0", + "gradio_client == 1.4.3", ] keywords = [] # https://pypi.org/classifiers/ diff --git a/src/macaron/ai.py b/src/macaron/ai.py deleted file mode 100644 index eb48ba08b..000000000 --- a/src/macaron/ai.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright (c) 2024 - 2025, Oracle and/or its 
affiliates. All rights reserved. -# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/. - -"""This module provides a client for interacting with a Large Language Model (LLM).""" - -import json -import logging -import re -from typing import Any, TypeVar - -from pydantic import BaseModel, ValidationError - -from macaron.config.defaults import defaults -from macaron.errors import ConfigurationError, HeuristicAnalyzerValueError -from macaron.util import send_post_http_raw - -logger: logging.Logger = logging.getLogger(__name__) - -T = TypeVar("T", bound=BaseModel) - - -class AIClient: - """A client for interacting with a Large Language Model.""" - - def __init__(self, system_prompt: str): - """ - Initialize the AI client. - - The LLM configuration (enabled, API key, endpoint, model) is read from defaults. - """ - self.enabled, self.api_endpoint, self.api_key, self.model, self.context_window = self._load_defaults() - self.system_prompt = system_prompt.strip() or "You are a helpful AI assistant." - logger.info("AI client is %s.", "enabled" if self.enabled else "disabled") - - def _load_defaults(self) -> tuple[bool, str, str, str, int]: - """Load the LLM configuration from the defaults.""" - section_name = "llm" - enabled, api_key, api_endpoint, model, context_window = False, "", "", "", 10000 - - if defaults.has_section(section_name): - section = defaults[section_name] - enabled = section.get("enabled", "False").strip().lower() == "true" - api_key = section.get("api_key", "").strip() - api_endpoint = section.get("api_endpoint", "").strip() - model = section.get("model", "").strip() - context_window = section.getint("context_window", 10000) - - if enabled: - if not api_key: - raise ConfigurationError("API key for the AI client is not configured.") - if not api_endpoint: - raise ConfigurationError("API endpoint for the AI client is not configured.") - if not model: - raise ConfigurationError("Model for the AI client is not configured.") - - return enabled, api_endpoint, api_key, model, context_window - - def _validate_response(self, response_text: str, response_model: type[T]) -> T: - """ - Validate and parse the response from the LLM. - - If raw JSON parsing fails, attempts to extract a JSON object from text. - - Parameters - ---------- - response_text: str - The response text from the LLM. - response_model: Type[T] - The Pydantic model to validate the response against. - - Returns - ------- - bool - The validated Pydantic model instance. - - Raises - ------ - HeuristicAnalyzerValueError - If there is an error in parsing or validating the response. - """ - try: - data = json.loads(response_text) - except json.JSONDecodeError: - logger.debug("Full JSON parse failed; trying to extract JSON from text.") - # If the response is not a valid JSON, try to extract a JSON object from the text. 
- match = re.search(r"\{.*\}", response_text, re.DOTALL) - if not match: - raise HeuristicAnalyzerValueError("No JSON object found in the LLM response.") from match - try: - data = json.loads(match.group(0)) - except json.JSONDecodeError as e: - logger.error("Failed to parse extracted JSON: %s", e) - raise HeuristicAnalyzerValueError("Invalid JSON extracted from response.") from e - - try: - return response_model.model_validate(data) - except ValidationError as e: - logger.error("Validation failed against response model: %s", e) - raise HeuristicAnalyzerValueError("Response JSON validation failed.") from e - - def invoke( - self, - user_prompt: str, - temperature: float = 0.2, - max_tokens: int = 4000, - structured_output: type[T] | None = None, - timeout: int = 30, - ) -> Any: - """ - Invoke the LLM and optionally validate its response. - - Parameters - ---------- - user_prompt: str - The user prompt to send to the LLM. - temperature: float - The temperature for the LLM response. - max_tokens: int - The maximum number of tokens for the LLM response. - structured_output: Optional[Type[T]] - The Pydantic model to validate the response against. If provided, the response will be parsed and validated. - timeout: int - The timeout for the HTTP request in seconds. - - Returns - ------- - Optional[T | str] - The validated Pydantic model instance if `structured_output` is provided, - or the raw string response if not. - - Raises - ------ - HeuristicAnalyzerValueError - If there is an error in parsing or validating the response. - """ - if not self.enabled: - raise ConfigurationError("AI client is not enabled. Please check your configuration.") - - if len(user_prompt.split()) > self.context_window: - logger.warning( - "User prompt exceeds context window (%s words). " - "Truncating the prompt to fit within the context window.", - self.context_window, - ) - user_prompt = " ".join(user_prompt.split()[: self.context_window]) - - headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} - payload = { - "model": self.model, - "messages": [{"role": "system", "content": self.system_prompt}, {"role": "user", "content": user_prompt}], - "temperature": temperature, - "max_tokens": max_tokens, - } - - try: - response = send_post_http_raw(url=self.api_endpoint, json_data=payload, headers=headers, timeout=timeout) - if not response: - raise HeuristicAnalyzerValueError("No response received from the LLM.") - response_json = response.json() - usage = response_json.get("usage", {}) - - if usage: - usage_str = ", ".join(f"{key} = {value}" for key, value in usage.items()) - logger.info("LLM call token usage: %s", usage_str) - - message_content = response_json["choices"][0]["message"]["content"] - - if not structured_output: - logger.debug("Returning raw message content (no structured output requested).") - return message_content - return self._validate_response(message_content, structured_output) - - except Exception as e: - logger.error("Error during LLM invocation: %s", e) - raise HeuristicAnalyzerValueError(f"Failed to get or validate LLM response: {e}") from e diff --git a/src/macaron/ai/README.md b/src/macaron/ai/README.md new file mode 100644 index 000000000..28ddf4757 --- /dev/null +++ b/src/macaron/ai/README.md @@ -0,0 +1,50 @@ +# Macaron AI Module + +This module provides the foundation for interacting with Large Language Models (LLMs) in a provider-agnostic way. 
It includes an abstract client definition, provider-specific client implementations, a client factory, and utility functions for processing responses.
+
+## Module Components
+
+- **ai_client.py**
+  Defines the abstract [`AIClient`](./ai_client.py) class. This class handles the initialization of the LLM configuration from the defaults and serves as the base for all specific AI client implementations.
+
+- **openai_client.py**
+  Implements the [`OpenAiClient`](./openai_client.py) class, a concrete subclass of [`AIClient`](./ai_client.py). This client interacts with OpenAI-compatible APIs by sending requests over HTTP and processing the responses. It also validates and structures responses using the tools provided.
+
+- **ai_factory.py**
+  Contains the [`AIClientFactory`](./ai_factory.py) class, which is responsible for reading the provider configuration from the defaults and creating the correct AI client instance.
+
+- **ai_tools.py**
+  Offers utility functions such as `structure_response` to assist with parsing and validating the JSON response returned by an LLM. These functions ensure that responses conform to a given Pydantic model for easier downstream processing.
+
+## Usage
+
+1. **Configuration:**
+   The module reads the LLM configuration from the application defaults (using the `defaults` module). Make sure that the `llm` section in your configuration includes valid settings for `enabled`, `provider`, `api_key`, `api_endpoint`, `model`, and `context_window`.
+
+2. **Creating a Client:**
+   Use the [`AIClientFactory`](./ai_factory.py) to create an AI client instance. The factory checks the configured provider and returns a client (e.g., an instance of [`OpenAiClient`](./openai_client.py)) that can be used to invoke the LLM, or `None` if the provider is not supported.
+
+   Example:
+   ```py
+   from macaron.ai.ai_factory import AIClientFactory
+
+   factory = AIClientFactory()
+   client = factory.create_client(system_prompt="You are a helpful assistant.")
+   if client:
+       print(client.invoke("Hello, how can you assist me?"))
+   ```
+
+3. **Response Processing:**
+   When a structured response is required, pass a Pydantic model class to the `invoke` method. The [`ai_tools.py`](./ai_tools.py) module takes care of parsing and validating the response to ensure it meets the expected structure.
+
+## Logging and Error Handling
+
+- The module uses Python's logging framework to report important events, such as token usage and warnings when prompts exceed the allowed context window.
+- Configuration errors (e.g., a missing API key or endpoint) are reported by raising descriptive exceptions such as [`ConfigurationError`](../errors.py).
+
+## Extensibility
+
+The design of the AI module is provider-agnostic. To add support for additional LLM providers:
+- Implement a new client by subclassing [`AIClient`](./ai_client.py).
+- Add the new client to the [`PROVIDER_MAPPING`](./ai_factory.py).
+- Update the configuration defaults accordingly.
diff --git a/src/macaron/ai/__init__.py b/src/macaron/ai/__init__.py
new file mode 100644
index 000000000..8e17a3508
--- /dev/null
+++ b/src/macaron/ai/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) 2025 - 2025, Oracle and/or its affiliates. All rights reserved.
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.
diff --git a/src/macaron/ai/ai_client.py b/src/macaron/ai/ai_client.py
new file mode 100644
index 000000000..35733e5d8
--- /dev/null
+++ b/src/macaron/ai/ai_client.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2025 - 2025, Oracle and/or its affiliates.
All rights reserved.
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.
+
+"""This module defines the abstract AIClient class for implementing AI clients."""
+
+import logging
+from abc import ABC, abstractmethod
+from typing import Any, TypeVar
+
+from pydantic import BaseModel
+
+T = TypeVar("T", bound=BaseModel)
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+class AIClient(ABC):
+    """This abstract class is used to implement AI clients."""
+
+    def __init__(self, system_prompt: str, defaults: dict) -> None:
+        """
+        Initialize the AI client.
+
+        The LLM configuration is read from defaults.
+        """
+        self.system_prompt = system_prompt
+        self.defaults = defaults
+
+    @abstractmethod
+    def invoke(
+        self,
+        user_prompt: str,
+        temperature: float = 0.2,
+        structured_output: type[T] | None = None,
+    ) -> Any:
+        """
+        Invoke the LLM and optionally validate its response.
+
+        Parameters
+        ----------
+        user_prompt: str
+            The user prompt to send to the LLM.
+        temperature: float
+            The temperature for the LLM response.
+        structured_output: Optional[Type[T]]
+            The Pydantic model to validate the response against. If provided, the response will be parsed and validated.
+
+        Returns
+        -------
+        Optional[T | str]
+            The validated Pydantic model instance if `structured_output` is provided,
+            or the raw string response if not.
+        """
diff --git a/src/macaron/ai/ai_factory.py b/src/macaron/ai/ai_factory.py
new file mode 100644
index 000000000..9462ebf86
--- /dev/null
+++ b/src/macaron/ai/ai_factory.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2025 - 2025, Oracle and/or its affiliates. All rights reserved.
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.
+
+"""This module defines the AIClientFactory class for creating AI clients based on provider configuration."""
+
+import logging
+
+from macaron.ai.ai_client import AIClient
+from macaron.ai.openai_client import OpenAiClient
+from macaron.config.defaults import defaults
+from macaron.errors import ConfigurationError
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+class AIClientFactory:
+    """Factory to create AI clients based on provider configuration."""
+
+    PROVIDER_MAPPING: dict[str, type[AIClient]] = {"openai": OpenAiClient}
+
+    def __init__(self) -> None:
+        """
+        Initialize the AI client factory.
+
+        The LLM configuration is read from defaults.
+        """
+        self.defaults = self._load_defaults()
+
+    def _load_defaults(self) -> dict:
+        section_name = "llm"
+        default_values = {
+            "enabled": False,
+            "provider": "",
+            "api_key": "",
+            "api_endpoint": "",
+            "model": "",
+            "context_window": 10000,
+        }
+
+        if defaults.has_section(section_name):
+            section = defaults[section_name]
+            default_values["enabled"] = section.getboolean("enabled", default_values["enabled"])
+            default_values["api_key"] = str(section.get("api_key", default_values["api_key"])).strip()
+            default_values["api_endpoint"] = (
+                str(section.get("api_endpoint", default_values["api_endpoint"])).strip()
+            )
+            default_values["model"] = str(section.get("model", default_values["model"])).strip()
+            default_values["provider"] = str(section.get("provider", default_values["provider"])).strip().lower()
+            default_values["context_window"] = section.getint("context_window", 10000)
+
+        if default_values["enabled"]:
+            for key, value in default_values.items():
+                if not value:
+                    raise ConfigurationError(
+                        f"AI client configuration '{key}' is required but not set in the defaults."
+                    )
+
+        return default_values
+
+    def create_client(self, system_prompt: str) -> AIClient | None:
+        """Create an AI client based on the configured provider."""
+        client_class = self.PROVIDER_MAPPING.get(self.defaults["provider"])
+        if client_class is None:
+            logger.error("Provider '%s' is not supported.", self.defaults["provider"])
+            return None
+        return client_class(system_prompt, self.defaults)
+
+    def list_available_providers(self) -> list[str]:
+        """List all registered providers."""
+        return list(self.PROVIDER_MAPPING.keys())
diff --git a/src/macaron/ai/ai_tools.py b/src/macaron/ai/ai_tools.py
new file mode 100644
index 000000000..e476376f9
--- /dev/null
+++ b/src/macaron/ai/ai_tools.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2025 - 2025, Oracle and/or its affiliates. All rights reserved.
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.
+
+"""This module provides utility functions for Large Language Models (LLMs)."""
+import json
+import logging
+import re
+from typing import TypeVar
+
+from pydantic import BaseModel, ValidationError
+
+T = TypeVar("T", bound=BaseModel)
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+def structure_response(response_text: str, response_model: type[T]) -> T | None:
+    """
+    Structure and parse the response from the LLM.
+
+    If raw JSON parsing fails, attempts to extract a JSON object from text.
+
+    Parameters
+    ----------
+    response_text: str
+        The response text from the LLM.
+    response_model: Type[T]
+        The Pydantic model to structure the response against.
+
+    Returns
+    -------
+    T | None
+        The structured Pydantic model instance, or None if parsing or validation fails.
+    """
+    try:
+        data = json.loads(response_text)
+    except json.JSONDecodeError:
+        logger.debug("Full JSON parse failed; trying to extract JSON from text.")
+        # If the response is not valid JSON, try to extract a JSON object from the text.
+        match = re.search(r"\{.*\}", response_text, re.DOTALL)
+        if not match:
+            return None
+        try:
+            data = json.loads(match.group(0))
+        except json.JSONDecodeError as e:
+            logger.debug("Failed to parse extracted JSON: %s", e)
+            return None
+
+    try:
+        return response_model.model_validate(data)
+    except ValidationError as e:
+        logger.debug("Validation failed against response model: %s", e)
+        return None
diff --git a/src/macaron/ai/openai_client.py b/src/macaron/ai/openai_client.py
new file mode 100644
index 000000000..cd856745c
--- /dev/null
+++ b/src/macaron/ai/openai_client.py
@@ -0,0 +1,100 @@
+# Copyright (c) 2025 - 2025, Oracle and/or its affiliates. All rights reserved.
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.
+
+"""This module provides a client for interacting with a Large Language Model (LLM) through an OpenAI-like API."""
+
+import logging
+from typing import Any, TypeVar
+
+from pydantic import BaseModel
+
+from macaron.ai.ai_client import AIClient
+from macaron.ai.ai_tools import structure_response
+from macaron.errors import ConfigurationError, HeuristicAnalyzerValueError
+from macaron.util import send_post_http_raw
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+T = TypeVar("T", bound=BaseModel)
+
+
+class OpenAiClient(AIClient):
+    """A client for interacting with a Large Language Model behind an OpenAI-compatible API."""
+
+    def invoke(
+        self,
+        user_prompt: str,
+        temperature: float = 0.2,
+        structured_output: type[T] | None = None,
+        max_tokens: int = 4000,
+        timeout: int = 30,
+    ) -> Any:
+        """
+        Invoke the LLM and optionally validate its response.
+
+        Parameters
+        ----------
+        user_prompt: str
+            The user prompt to send to the LLM.
+        temperature: float
+            The temperature for the LLM response.
+        structured_output: Optional[Type[T]]
+            The Pydantic model to validate the response against. If provided, the response will be parsed and validated.
+        max_tokens: int
+            The maximum number of tokens for the LLM response.
+        timeout: int
+            The timeout for the HTTP request in seconds.
+
+        Returns
+        -------
+        Optional[T | str]
+            The validated Pydantic model instance if `structured_output` is provided,
+            or the raw string response if not.
+
+        Raises
+        ------
+        HeuristicAnalyzerValueError
+            If the request fails or no response is received from the LLM.
+        """
+        if not self.defaults["enabled"]:
+            raise ConfigurationError("AI client is not enabled. Please check your configuration.")
+
+        if len(user_prompt.split()) > self.defaults["context_window"]:
+            logger.warning(
+                "User prompt exceeds context window (%s words). "
+                "Truncating the prompt to fit within the context window.",
+                self.defaults["context_window"],
+            )
+            user_prompt = " ".join(user_prompt.split()[: self.defaults["context_window"]])
+
+        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.defaults['api_key']}"}
+        payload = {
+            "model": self.defaults["model"],
+            "messages": [{"role": "system", "content": self.system_prompt}, {"role": "user", "content": user_prompt}],
+            "temperature": temperature,
+            "max_tokens": max_tokens,
+        }
+
+        try:
+            response = send_post_http_raw(
+                url=self.defaults["api_endpoint"], json_data=payload, headers=headers, timeout=timeout
+            )
+            if not response:
+                raise HeuristicAnalyzerValueError("No response received from the LLM.")
+            response_json = response.json()
+            usage = response_json.get("usage", {})
+
+            if usage:
+                usage_str = ", ".join(f"{key} = {value}" for key, value in usage.items())
+                logger.info("LLM call token usage: %s", usage_str)
+
+            message_content = response_json["choices"][0]["message"]["content"]
+
+            if not structured_output:
+                logger.debug("Returning raw message content (no structured output requested).")
+                return message_content
+            return structure_response(message_content, structured_output)
+
+        except Exception as e:
+            logger.error("Error during LLM invocation: %s", e)
+            raise HeuristicAnalyzerValueError(f"Failed to get or validate LLM response: {e}") from e
diff --git a/src/macaron/config/defaults.ini b/src/macaron/config/defaults.ini
index fd7762065..48113a4ae 100644
--- a/src/macaron/config/defaults.ini
+++ b/src/macaron/config/defaults.ini
@@ -639,7 +639,11 @@ disabled_custom_rulesets =
 [llm]
 # The LLM configuration for Macaron.
 # If enabled, the LLM will be used to analyze the results and provide insights.
-enabled =
+enabled = False
+# The provider for the LLM service.
+# Supported providers:
+# - openai: OpenAI's GPT models.
+provider =
 # The API key for the LLM service.
 api_key =
 # The API endpoint for the LLM service.
diff --git a/src/macaron/malware_analyzer/pypi_heuristics/sourcecode/matching_docstrings.py b/src/macaron/malware_analyzer/pypi_heuristics/sourcecode/matching_docstrings.py
index ca9cafbe3..bd5a864da 100644
--- a/src/macaron/malware_analyzer/pypi_heuristics/sourcecode/matching_docstrings.py
+++ b/src/macaron/malware_analyzer/pypi_heuristics/sourcecode/matching_docstrings.py
@@ -9,7 +9,7 @@
 
 from pydantic import BaseModel, Field
 
-from macaron.ai import AIClient
+from macaron.ai.ai_factory import AIClientFactory
 from macaron.json_tools import JsonType
 from macaron.malware_analyzer.pypi_heuristics.base_analyzer import BaseHeuristicAnalyzer
 from macaron.malware_analyzer.pypi_heuristics.heuristics import HeuristicResult, Heuristics
@@ -60,7 +60,13 @@ def __init__(self) -> None:
             heuristic=Heuristics.MATCHING_DOCSTRINGS,
             depends_on=None,
         )
-        self.client = AIClient(system_prompt=self.SYSTEM_PROMPT.strip())
+        factory = AIClientFactory()
+        client = None
+
+        if factory.defaults["enabled"]:
+            client = factory.create_client(self.SYSTEM_PROMPT.strip())
+
+        self.client = client
 
     def analyze(self, pypi_package_json: PyPIPackageJsonAsset) -> tuple[HeuristicResult, dict[str, JsonType]]:
         """Analyze the package.
@@ -75,8 +81,7 @@ def analyze(self, pypi_package_json: PyPIPackageJsonAsset) -> tuple[HeuristicRes
         tuple[HeuristicResult, dict[str, JsonType]]:
             The result and related information collected during the analysis.
""" - if not self.client.enabled: - logger.warning("AI client is not enabled, skipping the matching docstrings analysis.") + if not self.client: return HeuristicResult.SKIP, {} download_result = pypi_package_json.download_sourcecode() diff --git a/tests/malware_analyzer/pypi/test_matching_docstrings.py b/tests/malware_analyzer/pypi/test_matching_docstrings.py index c427fa6f9..f051bf76c 100644 --- a/tests/malware_analyzer/pypi/test_matching_docstrings.py +++ b/tests/malware_analyzer/pypi/test_matching_docstrings.py @@ -22,7 +22,7 @@ def skip_if_client_disabled(analyzer: MatchingDocstringsAnalyzer) -> None: """ Automatically skip tests in this file if the AI client is disabled. """ - if not analyzer.client.enabled: + if not analyzer.client: pytest.skip("AI client disabled - skipping test") @@ -63,14 +63,6 @@ def test_analyze_inconsistent_docstrings_fail( assert info["inconsistent part"] == "print('hello')" -def test_analyze_ai_client_disabled_skip(analyzer: MatchingDocstringsAnalyzer, pypi_package_json: MagicMock) -> None: - """Test the analyzer skips when the AI client is disabled.""" - with patch.object(analyzer.client, "enabled", False): - result, info = analyzer.analyze(pypi_package_json) - assert result == HeuristicResult.SKIP - assert not info - - def test_analyze_no_source_code_skip(analyzer: MatchingDocstringsAnalyzer, pypi_package_json: MagicMock) -> None: """Test the analyzer skips if the source code cannot be downloaded.""" pypi_package_json.download_sourcecode.return_value = False