diff --git a/README.md b/README.md index b0ed8ea..c24ed4e 100644 --- a/README.md +++ b/README.md @@ -140,6 +140,7 @@ if __name__ == "__main__": It is recommended to install the SDK in a virtual environment to avoid dependency conflicts. **Windows:** + ```powershell # Create virtual environment python -m venv venv @@ -152,6 +153,7 @@ python -m venv venv ``` **Linux/macOS:** + ```bash # Create virtual environment python -m venv venv @@ -163,11 +165,13 @@ source venv/bin/activate ### Install via pip **Windows:** + ```powershell pip install agentarts-sdk ``` **Linux/macOS:** + ```bash pip install agentarts-sdk ``` @@ -188,6 +192,7 @@ pip install agentarts-sdk[all] ### Install from Source **Windows:** + ```powershell git clone https://github.com/huaweicloud/agentarts-sdk-python.git cd agentarts-sdk-python @@ -201,6 +206,7 @@ pip install -e ".[dev]" ``` **Linux/macOS:** + ```bash git clone https://github.com/huaweicloud/agentarts-sdk-python.git cd agentarts-sdk-python @@ -218,18 +224,21 @@ pip install -e ".[dev]" Set environment variables for Huawei Cloud authentication: **Windows (PowerShell):** + ```powershell $env:HUAWEICLOUD_SDK_AK = "your-access-key" $env:HUAWEICLOUD_SDK_SK = "your-secret-key" ``` **Windows (Command Prompt):** + ```cmd set HUAWEICLOUD_SDK_AK=your-access-key set HUAWEICLOUD_SDK_SK=your-secret-key ``` **Linux/macOS:** + ```bash export HUAWEICLOUD_SDK_AK="your-access-key" export HUAWEICLOUD_SDK_SK="your-secret-key" @@ -251,6 +260,7 @@ agentarts init -n my_agent -t langgraph ``` This creates: + ``` my_agent/ ├── agent.py # Agent implementation @@ -307,16 +317,16 @@ agentarts destroy ## CLI Commands Reference -| Command | Description | -|---------|-------------| -| `agentarts init` | Initialize a new agent project | -| `agentarts dev` | Start local development server | -| `agentarts config` | Configure SDK settings (alias: `configure`) | -| `agentarts deploy` | Deploy agent to Huawei Cloud (alias: `launch`) | -| `agentarts invoke` | Invoke deployed agent | -| `agentarts status` | Check deployment status | -| `agentarts destroy` | Remove deployed agent | -| `agentarts mcp-gateway` | Manage MCP gateways | +| Command | Description | +| ----------------------- | ---------------------------------------------- | +| `agentarts init` | Initialize a new agent project | +| `agentarts dev` | Start local development server | +| `agentarts config` | Configure SDK settings (alias: `configure`) | +| `agentarts deploy` | Deploy agent to Huawei Cloud (alias: `launch`) | +| `agentarts invoke` | Invoke deployed agent | +| `agentarts status` | Check deployment status | +| `agentarts destroy` | Remove deployed agent | +| `agentarts mcp-gateway` | Manage MCP gateways | ## Limitations & Requirements @@ -329,20 +339,22 @@ agentarts destroy When using optional framework dependencies, ensure the following minimum versions: -| Framework | Minimum Version | Install Command | -|-----------|-----------------|-----------------| -| LangGraph | 1.0.0 | `pip install agentarts-sdk[langgraph]` | -| LangChain | 0.1.0 | `pip install agentarts-sdk[langchain]` | -| langchain-core | 0.1.0 | Included with langgraph/langchain | +| Framework | Minimum Version | Install Command | +| -------------- | --------------- | -------------------------------------- | +| LangGraph | 1.0.0 | `pip install agentarts-sdk[langgraph]` | +| LangChain | 0.1.0 | `pip install agentarts-sdk[langchain]` | +| langchain-core | 0.1.0 | Included with langgraph/langchain | > **Note:** LangGraph 1.0+ introduces a new Checkpoint format with required fields (`step`, `pending_sends`, `parents`). The SDK's integration module is compatible with LangGraph 1.0 and above. ### Docker Docker is required for: + - Building and deploying agents with `agentarts deploy` (alias: `launch`) **Install Docker:** + - [Docker Desktop for Windows](https://www.docker.com/products/docker-desktop) - [Docker Engine for Linux](https://docs.docker.com/engine/install/) @@ -387,5 +399,6 @@ Contributions are welcome! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for det ## Support - **Issues:** [GitHub Issues](https://github.com/huaweicloud/agentarts-sdk-python/issues) -- **Documentation:** https://docs.huaweicloud.com/agentarts -- **Email:** agentarts@huawei.com +- **Documentation:** +- **Email:** + diff --git a/docs/cn/sdk_user_guide/environment_variables.md b/docs/cn/sdk_user_guide/environment_variables.md index c285372..7ec3e6c 100644 --- a/docs/cn/sdk_user_guide/environment_variables.md +++ b/docs/cn/sdk_user_guide/environment_variables.md @@ -210,6 +210,45 @@ export HUAWEICLOUD_SDK_AGENTIDENTITY_ENDPOINT="https://agent-identity.cn-north-4 ## 其他配置 +### SDK 日志级别 + +用于控制 SDK 的日志输出级别: + +| 环境变量 | 说明 | 默认值 | +|----------|------|--------| +| `AGENTARTS_LOG_LEVEL` | SDK 日志级别 | `INFO` | + +**可选值:** `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL` + +**配置示例:** + +```bash +# 开启详细日志(调试模式) +export AGENTARTS_LOG_LEVEL="DEBUG" + +# 只显示警告及以上级别的日志 +export AGENTARTS_LOG_LEVEL="WARNING" +``` + +**代码配置方式:** + +```python +from agentarts.sdk.utils.logging import setup_logging + +# 设置日志级别 +setup_logging(level="DEBUG") + +# 或使用环境变量(自动读取 AGENTARTS_LOG_LEVEL) +setup_logging() + +# 自定义日志输出到文件 +import logging +setup_logging( + level="DEBUG", + handler=logging.FileHandler("/var/log/agentarts/app.log") +) +``` + ### Python 基础镜像 用于指定 Agent 部署时的 Python 基础镜像: @@ -224,6 +263,31 @@ export HUAWEICLOUD_SDK_AGENTIDENTITY_ENDPOINT="https://agent-identity.cn-north-4 export PYTHON_BASE_IMAGE="python:3.11-slim" ``` +### 运行时监听地址 + +用于指定 Agent 运行时的监听地址: + +| 环境变量 | 说明 | 默认值 | +|----------|------|--------| +| `AGENTARTS_BIND_IP` | 运行时监听地址 | Docker: 自动获取 eth0 IP,本地: `127.0.0.1` | + +**默认行为:** + +| 环境 | 默认监听地址 | +|------|-------------| +| Docker 容器 | 自动获取 eth0 接口 IP(通过 `ip addr show eth0`) | +| 本地开发 | `127.0.0.1` | + +**配置示例:** + +```bash +# 手动指定监听地址 +export AGENTARTS_BIND_IP="0.0.0.0" + +# 或在代码中指定 +app.run(host="0.0.0.0", port=8080) +``` + --- ## 环境变量优先级说明 diff --git a/docs/cn/toolkit_user_guide/invoke.md b/docs/cn/toolkit_user_guide/invoke.md index b9d3f03..cfd6a03 100644 --- a/docs/cn/toolkit_user_guide/invoke.md +++ b/docs/cn/toolkit_user_guide/invoke.md @@ -17,6 +17,7 @@ | `--session` | `-s` | 会话 ID(用于有状态 Agent) | 自动生成 UUID | | `--bearer-token` | `-bt` | Bearer 认证令牌 | 无 | | `--timeout` | - | 请求超时时间(秒) | `900` | +| `--user-id` | `-u` | 用户 ID(用于 OAuth2 出站凭据) | 无 | ## 调用模式 @@ -118,13 +119,24 @@ agentarts invoke '{"message": "你好"}' --agent my-agent --bearer-token "your-t agentarts invoke '{"message": "你好"}' --agent my-agent --endpoint custom-endpoint ``` -### 示例 9: 设置超时时间 +### 示例 10: 设置超时时间 ```bash agentarts invoke '{"message": "你好"}' --agent my-agent --timeout 60 ``` -### 示例 10: 复杂数据调用 +### 示例 11: 使用用户 ID(OAuth2 出站凭据) + +```bash +agentarts invoke '{"message": "你好"}' --agent my-agent --user-id "user-123" +``` + +或使用简写: +```bash +agentarts invoke '{"message": "你好"}' -a my-agent -u "user-123" +``` + +### 示例 12: 复杂数据调用 ```bash agentarts invoke '{ diff --git a/docs/cn/toolkit_user_guide/status.md b/docs/cn/toolkit_user_guide/status.md index 138891c..589fee8 100644 --- a/docs/cn/toolkit_user_guide/status.md +++ b/docs/cn/toolkit_user_guide/status.md @@ -15,6 +15,7 @@ | `--bearer-token` | `-bt` | Bearer 认证令牌 | 无 | | `--endpoint` | `-e` | 指定端点名称 | 无 | | `--session` | `-s` | 会话 ID(用于有状态 Agent) | 无 | +| `--user-id` | `-u` | 用户 ID(用于 OAuth2 出站凭据) | 无 | ## 健康状态说明 @@ -144,6 +145,17 @@ agentarts status \ --bearer-token "your-token" ``` +### 示例 9: 使用用户 ID(OAuth2 出站凭据) + +```bash +agentarts status --agent my-agent --user-id "user-123" +``` + +或使用简写: +```bash +agentarts status -a my-agent -u "user-123" +``` + ## 检查模式 ### Cloud 模式(默认) diff --git a/pyproject.toml b/pyproject.toml index a3b6d42..dfc7812 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "uv_build" [project] name = "agentarts-sdk" -version = "0.1.0" +version = "0.1.1" description = "Huawei Cloud AgentArts SDK - Build, deploy and manage AI agents with cloud capabilities" readme = "README.md" license = {file = "LICENSE"} @@ -45,8 +45,8 @@ dependencies = [ "huaweicloudsdkswr>=3.1.0", "starlette>=0.46.2", "uvicorn[standard]>=0.24.0", - "pydantic>=2.10.0", - "pydantic-settings>=2.6.0", + "pydantic>=2.0.0,<3.0.0", + "pydantic-settings>=2.0.0,<3.0.0", "python-dotenv>=1.0.0", "pyyaml>=6.0", "requests>=2.31.0", diff --git a/src/agentarts/sdk/__init__.py b/src/agentarts/sdk/__init__.py index 93879c1..8950936 100644 --- a/src/agentarts/sdk/__init__.py +++ b/src/agentarts/sdk/__init__.py @@ -28,6 +28,10 @@ warnings.filterwarnings("ignore", message="Unverified HTTPS request") warnings.filterwarnings("ignore", category=urllib3.exceptions.InsecureRequestWarning) +from agentarts.sdk.utils.logging import setup_logging + +setup_logging() + from agentarts import __author__, __version__ from agentarts.sdk.identity import ( IdentityClient, diff --git a/src/agentarts/sdk/integration/langgraph/saver.py b/src/agentarts/sdk/integration/langgraph/saver.py index 3aed23f..dc12058 100644 --- a/src/agentarts/sdk/integration/langgraph/saver.py +++ b/src/agentarts/sdk/integration/langgraph/saver.py @@ -99,6 +99,10 @@ class AgentArtsMemorySessionSaver(BaseCheckpointSaver): falls back to HUAWEICLOUD_SDK_MEMORY_API_KEY environment variable) max_messages: Maximum number of messages to retrieve per query, default 10 serde: Serializer/deserializer for checkpoints (default: JsonPlusSerializer) + verify_ssl: SSL verification setting (default: True). Can be: + - True: Verify SSL certificates using system CA bundle + - False: Skip SSL verification (not recommended for production) + - str: Path to custom CA certificate file """ def __init__( @@ -108,6 +112,7 @@ def __init__( api_key: str | None = None, max_messages: int = 10, serde: JsonPlusSerializer | None = None, + verify_ssl: bool | str = True, ) -> None: if not LANGGRAPH_AVAILABLE: msg = ( @@ -123,9 +128,11 @@ def __init__( self._region = region or get_region() self._api_key = api_key self._max_messages = max_messages + self._verify_ssl = verify_ssl self._client = MemoryClient( region_name=self._region, - api_key=api_key + api_key=api_key, + verify_ssl=verify_ssl ) @property diff --git a/src/agentarts/sdk/mcpgateway/mcp_gateway_client.py b/src/agentarts/sdk/mcpgateway/mcp_gateway_client.py index 4cdb452..84473ea 100644 --- a/src/agentarts/sdk/mcpgateway/mcp_gateway_client.py +++ b/src/agentarts/sdk/mcpgateway/mcp_gateway_client.py @@ -21,13 +21,11 @@ class MCPGatewayClient(BaseHTTPClient): """ def __init__(self, config: RequestConfig | None = None): - # If config is None or base_url is not set, use control plane endpoint if config is None or (config.base_url is None or config.base_url == ""): from agentarts.sdk.service.http_client import RequestConfig if config is None: config = RequestConfig() config.base_url = f"{get_control_plane_endpoint()}/v1/core" - config.verify_ssl = False super().__init__(config, open_ak_sk=True) def create_mcp_gateway( diff --git a/src/agentarts/sdk/memory/client.py b/src/agentarts/sdk/memory/client.py index efe6baf..e6e6fde 100644 --- a/src/agentarts/sdk/memory/client.py +++ b/src/agentarts/sdk/memory/client.py @@ -83,7 +83,7 @@ def __init__( self, region_name: str | None = None, api_key: str | None = None, - verify_ssl: bool = False, + verify_ssl: bool | str = True, ): """ Initialize Memory Client. @@ -100,7 +100,10 @@ def __init__( region_name: Huawei Cloud region name, auto-detected from environment if not provided api_key: API Key for data plane authentication (optional, falls back to HUAWEICLOUD_SDK_MEMORY_API_KEY environment variable) - verify_ssl: Whether to verify SSL certificates (default: False) + verify_ssl: SSL verification setting. + - True: Verify SSL certificates using system CA bundle (default) + - False: Skip SSL verification (not recommended for production) + - str: Path to custom CA certificate file Environment Variables: HUAWEICLOUD_SDK_AK: Access Key, required for control plane API diff --git a/src/agentarts/sdk/memory/inner/controlplane.py b/src/agentarts/sdk/memory/inner/controlplane.py index 92c8c1b..8f87e3b 100644 --- a/src/agentarts/sdk/memory/inner/controlplane.py +++ b/src/agentarts/sdk/memory/inner/controlplane.py @@ -25,14 +25,17 @@ class _ControlPlane: def __init__( self, region_name: str | None = None, - verify_ssl: bool = False, + verify_ssl: bool | str = True, ): """ Initialize control plane. Args: region_name: Huawei Cloud region name, auto-detected from environment if not provided - verify_ssl: Whether to verify SSL certificates (default: False) + verify_ssl: SSL verification setting. + - True: Verify SSL certificates using system CA bundle (default) + - False: Skip SSL verification (not recommended for production) + - str: Path to custom CA certificate file """ self.client = MemoryHttpService( region_name=region_name, @@ -183,3 +186,12 @@ def delete_space(self, space_id: str) -> None: self.client.delete_space(space_id) logger.info(f"Space deleted: {space_id}") + + def close(self) -> None: + """Close the control plane and release resources. + + This method closes the underlying HTTP client session. + """ + if hasattr(self, "client") and self.client is not None: + self.client.close() + logger.info("ControlPlane closed") diff --git a/src/agentarts/sdk/memory/inner/dataplane.py b/src/agentarts/sdk/memory/inner/dataplane.py index 37bfebf..9fca20c 100644 --- a/src/agentarts/sdk/memory/inner/dataplane.py +++ b/src/agentarts/sdk/memory/inner/dataplane.py @@ -53,7 +53,7 @@ def __init__( self, region_name: str | None = None, api_key: str | None = None, - verify_ssl: bool = False, + verify_ssl: bool | str = True, ): """ Initialize data plane. @@ -61,7 +61,10 @@ def __init__( Args: region_name: Huawei Cloud region name, auto-detected from environment if not provided api_key: API Key for data plane authentication (optional, falls back to environment variable) - verify_ssl: Whether to verify SSL certificates (default: False) + verify_ssl: SSL verification setting (default: True). Can be: + - True: Verify SSL certificates using system CA bundle + - False: Skip SSL verification (not recommended for production) + - str: Path to custom CA certificate file """ self.client = MemoryHttpService( region_name=region_name, @@ -284,3 +287,12 @@ def delete_memory(self, space_id: str, memory_id: str) -> None: """ logger.info(f"Deleting memory: {memory_id}") self.client.delete_memory(space_id, memory_id) + + def close(self) -> None: + """Close the data plane and release resources. + + This method closes the underlying HTTP client session. + """ + if hasattr(self, "client") and self.client is not None: + self.client.close() + logger.info("DataPlane closed") diff --git a/src/agentarts/sdk/runtime/app.py b/src/agentarts/sdk/runtime/app.py index 5dc6050..c12a723 100644 --- a/src/agentarts/sdk/runtime/app.py +++ b/src/agentarts/sdk/runtime/app.py @@ -19,14 +19,18 @@ from __future__ import annotations import asyncio +import contextvars import inspect import json import logging import os +import socket +import subprocess import threading import time import uuid from concurrent.futures import ThreadPoolExecutor +from pathlib import Path from typing import ( TYPE_CHECKING, Any, @@ -273,6 +277,17 @@ def _task_context(self, handler: Callable) -> bool: except Exception: return False + def _ping_task_context(self, handler: Callable) -> bool: + """ + Return ``True`` if *handler* accepts a ``context`` as its + first positional parameter. + """ + try: + params = list(inspect.signature(handler).parameters.keys()) + return len(params) >= 1 and params[0] == "context" + except Exception: + return False + # ------------------------------------------------------------------ # Invocation endpoint # ------------------------------------------------------------------ @@ -357,6 +372,8 @@ async def _handle_invocation(self, request: Request) -> Response: status_code=500, content={"error": type(exc).__name__, "message": str(exc)}, ) + finally: + AgentArtsRuntimeContext.clear() async def _invoke_handler( self, @@ -383,8 +400,9 @@ async def _invoke_handler( if asyncio.iscoroutinefunction(handler): return await handler(*args) loop = asyncio.get_event_loop() + ctx = contextvars.copy_context() return await loop.run_in_executor( - self._invocation_executor, handler, *args + self._invocation_executor, lambda: ctx.run(handler, *args) ) except Exception as exc: handler_name = getattr(handler, "__name__", "unknown") @@ -489,7 +507,8 @@ async def _handle_ping(self, request: Request) -> Response: is returned. """ try: - status = self.get_current_ping_status() + request_context = self._build_request_context(request) + status = self.get_current_ping_status(request_context) self.logger.debug(f"Ping request - status: {status}") return JSONResponse( content={ @@ -506,10 +525,13 @@ async def _handle_ping(self, request: Request) -> Response: }, ) - def get_current_ping_status(self) -> PingStatus: + def get_current_ping_status(self, request_context: RequestContext | None = None) -> PingStatus: """ Get the current status of the AgentArts runtime. + Args: + request_context: Optional request context to pass to the ping handler. + Returns: PingStatus: The current health status of the runtime. """ @@ -520,7 +542,9 @@ def get_current_ping_status(self) -> PingStatus: if self._ping_handler is not None: try: - result = self._ping_handler() + task_context = self._ping_task_context(self._ping_handler) + args = (request_context,) if task_context and request_context else () + result = self._ping_handler(*args) current_status = PingStatus(result) if isinstance(result, str) else result except Exception as exc: self.logger.warning("Custom Ping handler failed: %s: %s", type(exc).__name__, exc) @@ -607,16 +631,78 @@ def handler(payload): handler.run(host="0.0.0.0", port=8080) Args: - host: Bind address. Defaults to ``"0.0.0.0"``. - port: Bind port. Defaults to ``8080``. + host: Bind address. Defaults to eth0 IP in Docker/Kubernetes environment, + or ``"127.0.0.1"`` in local development. + Can be overridden via ``AGENTARTS_BIND_IP`` environment variable. + port: Bind port. Defaults to ``8080``. **kwargs: Additional keyword arguments forwarded to ``uvicorn.run`` (e.g. ``workers``, ``log_level``). """ import uvicorn + def _is_docker_environment() -> bool: + if os.path.exists("/.dockerenv"): + return True + if os.getenv("DOCKER_CONTAINER"): + return True + if os.getenv("KUBERNETES_SERVICE_HOST"): + return True + try: + cgroup = Path("/proc/1/cgroup").read_text() + if "docker" in cgroup or "kubepods" in cgroup or "containerd" in cgroup: + return True + except Exception: + pass + try: + mountinfo = Path("/proc/self/mountinfo").read_text() + if "docker" in mountinfo or "containers" in mountinfo: + return True + except Exception: + pass + return False + + def _get_eth0_ip() -> str | None: + try: + result = subprocess.run( + "ip addr show eth0 | grep -oP 'inet \\K[\\d.]+'", + shell=True, + capture_output=True, + text=True, + timeout=5, + ) + ip = result.stdout.strip() + if ip and result.returncode == 0: + self.logger.debug("Detected eth0 IP via ip command: %s", ip) + return ip + except Exception as e: + self.logger.debug("Failed to get IP via ip command: %s", e) + try: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.settimeout(2) + s.connect(("8.8.8.8", 80)) + ip = s.getsockname()[0] + s.close() + if ip and ip != "0.0.0.0": + self.logger.debug("Detected eth0 IP via socket: %s", ip) + return ip + except Exception as e: + self.logger.debug("Failed to get IP via socket: %s", e) + try: + ip = socket.gethostbyname(socket.gethostname()) + if ip and ip != "127.0.0.1": + self.logger.debug("Detected eth0 IP via hostname: %s", ip) + return ip + except Exception as e: + self.logger.debug("Failed to get IP via hostname: %s", e) + return None + if host is None: - if os.path.exists("/.dockerenv") or os.getenv("DOCKER_CONTAINER"): - host = "0.0.0.0" + env_bind_ip = os.getenv("AGENTARTS_BIND_IP") + if env_bind_ip: + host = env_bind_ip + elif _is_docker_environment(): + eth0_ip = _get_eth0_ip() + host = eth0_ip if eth0_ip else "0.0.0.0" else: host = "127.0.0.1" diff --git a/src/agentarts/sdk/runtime/model.py b/src/agentarts/sdk/runtime/model.py index 4d9d54a..c70cfbd 100644 --- a/src/agentarts/sdk/runtime/model.py +++ b/src/agentarts/sdk/runtime/model.py @@ -10,7 +10,7 @@ SESSION_HEADER = "x-hw-agentarts-session-id" ACCESS_TOKEN_HEADER = "X-HW-AgentGateway-Workload-Access-Token" -USER_ID_HEADER = "X-Hw-AgentArts-Runtime-User-Id" +USER_ID_HEADER = "X-HW-AgentGateway-User-Id" CUSTOM_HEADER_PREFIX = "X-Hw-AgentArts-Runtime-Custom-" diff --git a/src/agentarts/sdk/service/http_client.py b/src/agentarts/sdk/service/http_client.py index 69228f0..4274234 100644 --- a/src/agentarts/sdk/service/http_client.py +++ b/src/agentarts/sdk/service/http_client.py @@ -46,12 +46,22 @@ class SignMode(Enum): @dataclass class RequestConfig: - """Configuration for HTTP requests.""" + """Configuration for HTTP requests. + + Attributes: + base_url: Base URL for API requests. + timeout: Request timeout in seconds. + headers: Default headers for all requests. + verify_ssl: SSL verification setting. + - True: Verify SSL certificates using system CA bundle (default) + - False: Skip SSL verification (not recommended for production) + - str: Path to custom CA certificate file + """ base_url: str = "" timeout: float = 30.0 headers: dict[str, str] = field(default_factory=dict) - verify_ssl: bool = True + verify_ssl: bool | str = True @dataclass diff --git a/src/agentarts/sdk/service/iam_client.py b/src/agentarts/sdk/service/iam_client.py index dd1c001..1d86076 100644 --- a/src/agentarts/sdk/service/iam_client.py +++ b/src/agentarts/sdk/service/iam_client.py @@ -10,15 +10,24 @@ class IAMClient: IAM Client for making API calls to IAM service. Uses huaweicloudsdkiam.v5.IamClient to make API calls. + + Args: + verify_ssl: SSL verification setting (default: True). Can be: + - True: Verify SSL certificates using system CA bundle + - False: Skip SSL verification (not recommended for production) + - str: Path to custom CA certificate file """ - def __init__(self): + def __init__(self, verify_ssl: bool | str = True): """ Initialize IAM client. All configuration will be loaded from environment variables via constant module. + + Args: + verify_ssl: SSL verification setting (default: True) """ - # Do not execute any code here + self._verify_ssl = verify_ssl def _get_iam_client(self): """ @@ -38,9 +47,12 @@ def _get_iam_client(self): # Create credentials credentials = create_credential() - # Create HTTP config with ignore_ssl_verification=True + # Create HTTP config http_config = HttpConfig.get_default_config() - http_config.ignore_ssl_verification = True + if isinstance(self._verify_ssl, str): + http_config.ssl_ca_cert = self._verify_ssl + else: + http_config.ignore_ssl_verification = not self._verify_ssl # Create region object final_region = Region(id=get_region(), endpoint=get_iam_endpoint()) diff --git a/src/agentarts/sdk/service/memory_service.py b/src/agentarts/sdk/service/memory_service.py index ef23306..5e1d6ec 100644 --- a/src/agentarts/sdk/service/memory_service.py +++ b/src/agentarts/sdk/service/memory_service.py @@ -205,7 +205,7 @@ def __init__( endpoint_type: str = "control", timeout: int = 30, api_key: str | None = None, - verify_ssl: bool = False, + verify_ssl: bool | str = True, enable_signing: bool | None = None, ): """Initialize Memory HTTP service with region and authentication strategy. @@ -216,7 +216,10 @@ def __init__( endpoint_type: "control" for control plane, "data" for data plane timeout: Request timeout in seconds api_key: API Key for data plane authentication (optional, falls back to environment variable) - verify_ssl: Whether to verify SSL certificates (default: False) + verify_ssl: SSL verification setting. + - True: Verify SSL certificates using system CA bundle (default) + - False: Skip SSL verification (not recommended for production) + - str: Path to custom CA certificate file enable_signing: Whether to enable request signing. If None, automatically enabled for control plane and disabled for data plane. Set to True/False to explicitly control signing behavior. @@ -688,3 +691,13 @@ def endpoint_type(self) -> str: def enable_signing(self) -> bool: """Get whether signing is enabled.""" return self._enable_signing + + def close(self) -> None: + """Close the HTTP session and release resources. + + This method should be called when the service is no longer needed + to properly release the underlying requests.Session resources. + """ + if hasattr(self, "session") and self.session is not None: + self.session.close() + logger.info("MemoryHttpService session closed") diff --git a/src/agentarts/sdk/service/runtime_client.py b/src/agentarts/sdk/service/runtime_client.py index 312292d..9147516 100644 --- a/src/agentarts/sdk/service/runtime_client.py +++ b/src/agentarts/sdk/service/runtime_client.py @@ -67,7 +67,10 @@ class RuntimeClient: access_token: Bearer token for API authentication. Can also be set later via :meth:`set_auth_token`. timeout: Default request timeout in seconds. - verify_ssl: Whether to verify SSL certificates. + verify_ssl: Whether to verify SSL certificates. Can be: + - True: Verify SSL certificates using system CA bundle (default) + - False: Skip SSL verification (not recommended for production) + - str: Path to custom CA certificate file sign_mode: Signature mode for data plane requests (SDK_HMAC_SHA256 or V11_HMAC_SHA256). region_id: Region ID for V11 signature mode. """ @@ -78,7 +81,7 @@ def __init__( data_endpoint: str | None = None, access_token: str | None = None, timeout: float = 30.0, - verify_ssl: bool = True, + verify_ssl: bool | str = True, sign_mode: SignMode = SignMode.SDK_HMAC_SHA256, region_id: str = "", ) -> None: @@ -618,6 +621,7 @@ def invoke_agent( bearer_token: str | None = None, endpoint: str | None = None, timeout: int = 900, + user_id: str | None = None, **extra: Any, ) -> dict[str, Any] | Iterator[str]: """ @@ -636,13 +640,15 @@ def invoke_agent( endpoint: Optional endpoint name, appended as a query parameter ``?endpoint=xxx``. timeout: Request timeout in seconds. + user_id: Optional user ID for OAuth2 outbound credentials, + passed as the ``USER_ID_HEADER`` header. **extra: Additional fields merged into the request. Returns: A ``dict`` for JSON responses, or an ``Iterator[str]`` for SSE streaming responses. """ - from agentarts.sdk.runtime.model import SESSION_HEADER + from agentarts.sdk.runtime.model import SESSION_HEADER, USER_ID_HEADER path = f"/agent/{agent_name}/invocations" params: dict[str, Any] = {} @@ -655,6 +661,8 @@ def invoke_agent( } if bearer_token: headers["Authorization"] = f"Bearer {bearer_token}" + if user_id: + headers[USER_ID_HEADER] = user_id result = self._data( "POST", @@ -674,6 +682,7 @@ def ping_agent( endpoint: str | None = None, session_id: str | None = None, timeout: int = 900, + user_id: str | None = None, ) -> dict[str, Any] | Iterator[str]: """ Health-check an agent on the data plane. @@ -689,18 +698,22 @@ def ping_agent( session_id: Session identifier for stateful agents, passed as the ``SESSION_HEADER`` header. timeout: Request timeout in seconds. + user_id: Optional user ID for OAuth2 outbound credentials, + passed as the ``USER_ID_HEADER`` header. Returns: A ``dict`` with at least a ``status`` field (e.g. ``"Healthy"``), or an ``Iterator[str]`` for SSE streaming responses. """ - from agentarts.sdk.runtime.model import SESSION_HEADER + from agentarts.sdk.runtime.model import SESSION_HEADER, USER_ID_HEADER headers: dict[str, str] = {} if bearer_token: headers["Authorization"] = f"Bearer {bearer_token}" if session_id: headers[SESSION_HEADER] = session_id + if user_id: + headers[USER_ID_HEADER] = user_id params: dict[str, Any] = {} if endpoint: @@ -749,6 +762,7 @@ def invoke_agent( bearer_token: str | None = None, endpoint: str | None = None, timeout: int | None = None, + user_id: str | None = None, ) -> dict[str, Any] | Iterator[str]: """ Invoke a local agent. @@ -759,12 +773,14 @@ def invoke_agent( bearer_token: Optional bearer token for ``Authorization`` header. endpoint: Optional endpoint name. timeout: Request timeout in seconds. + user_id: Optional user ID for OAuth2 outbound credentials, + passed as the ``USER_ID_HEADER`` header. Returns: A ``dict`` for JSON responses, or an ``Iterator[str]`` for SSE streaming responses. """ - from agentarts.sdk.runtime.model import SESSION_HEADER + from agentarts.sdk.runtime.model import SESSION_HEADER, USER_ID_HEADER path = "/invocations" params: dict[str, Any] = {} @@ -776,6 +792,8 @@ def invoke_agent( headers[SESSION_HEADER] = session_id if bearer_token: headers["Authorization"] = f"Bearer {bearer_token}" + if user_id: + headers[USER_ID_HEADER] = user_id request_timeout = timeout or self._config.timeout @@ -827,6 +845,7 @@ def ping_agent( endpoint: str | None = None, session_id: str | None = None, timeout: int | None = None, + user_id: str | None = None, ) -> dict[str, Any]: """ Health-check a local agent. @@ -836,11 +855,13 @@ def ping_agent( endpoint: Optional endpoint name. session_id: Session identifier for stateful agents. timeout: Request timeout in seconds. + user_id: Optional user ID for OAuth2 outbound credentials, + passed as the ``USER_ID_HEADER`` header. Returns: A ``dict`` with a ``status`` field indicating health status. """ - from agentarts.sdk.runtime.model import SESSION_HEADER + from agentarts.sdk.runtime.model import SESSION_HEADER, USER_ID_HEADER path = "/ping" @@ -849,6 +870,8 @@ def ping_agent( headers["Authorization"] = f"Bearer {bearer_token}" if session_id: headers[SESSION_HEADER] = session_id + if user_id: + headers[USER_ID_HEADER] = user_id params: dict[str, Any] = {} if endpoint: diff --git a/src/agentarts/sdk/service/swr_client.py b/src/agentarts/sdk/service/swr_client.py index 750bf52..d5e6f26 100644 --- a/src/agentarts/sdk/service/swr_client.py +++ b/src/agentarts/sdk/service/swr_client.py @@ -45,17 +45,23 @@ class SWRClient: region: Huawei Cloud region (e.g., "cn-north-4"). endpoint: Override SWR endpoint URL. If ``None``, the URL is derived from the region. + verify_ssl: SSL verification setting (default: True). Can be: + - True: Verify SSL certificates using system CA bundle + - False: Skip SSL verification (not recommended for production) + - str: Path to custom CA certificate file """ def __init__( self, region: str, endpoint: str | None = None, + verify_ssl: bool | str = True, ) -> None: self._region = region self._endpoint = endpoint or get_swr_endpoint(region) self._swr_registry = f"swr.{region}.myhuaweicloud.com" + self._verify_ssl = verify_ssl self._client = None self._credentials = None @@ -90,7 +96,10 @@ def _get_client(self): credentials = self._get_credentials() http_config = HttpConfig.get_default_config() - http_config.ignore_ssl_verification = True + if isinstance(self._verify_ssl, str): + http_config.ssl_ca_cert = self._verify_ssl + else: + http_config.ignore_ssl_verification = not self._verify_ssl try: swr_region = SwrRegion.value_of(self._region) diff --git a/src/agentarts/sdk/service/tools_http.py b/src/agentarts/sdk/service/tools_http.py index e377c7f..f270dc2 100644 --- a/src/agentarts/sdk/service/tools_http.py +++ b/src/agentarts/sdk/service/tools_http.py @@ -21,8 +21,13 @@ def __init__(self, status_code: int, error_msg: str): class ControlToolsHttpClient(BaseHTTPClient): - def __init__(self, region_name: str, endpoint_url: str): - request_config = RequestConfig(base_url=endpoint_url, verify_ssl=False) + def __init__( + self, + region_name: str, + endpoint_url: str, + verify_ssl: bool | str = True, + ): + request_config = RequestConfig(base_url=endpoint_url, verify_ssl=verify_ssl) super().__init__(request_config, open_ak_sk=True) self.region_name = region_name @@ -84,7 +89,7 @@ def delete_code_interpreter(self, code_interpreter_id: str): class DataToolsHttpClient(BaseHTTPClient): - def __init__(self, region_name: str, endpoint_url: str, auth_type: str = "API_KEY"): + def __init__(self, region_name: str, endpoint_url: str, auth_type: str = "API_KEY", verify_ssl: bool | str = True): """Initialize the data tools HTTP client. Args: @@ -94,13 +99,13 @@ def __init__(self, region_name: str, endpoint_url: str, auth_type: str = "API_KE """ if auth_type == "IAM": super().__init__( - RequestConfig(base_url=endpoint_url, verify_ssl=False), + RequestConfig(base_url=endpoint_url, verify_ssl=verify_ssl), open_ak_sk=True, sign_mode=SignMode.V11_HMAC_SHA256, region_id=region_name, ) else: - super().__init__(RequestConfig(base_url=endpoint_url, verify_ssl=False)) + super().__init__(RequestConfig(base_url=endpoint_url, verify_ssl=verify_ssl)) self.region_name = region_name @property diff --git a/src/agentarts/sdk/utils/logging.py b/src/agentarts/sdk/utils/logging.py new file mode 100644 index 0000000..02bc781 --- /dev/null +++ b/src/agentarts/sdk/utils/logging.py @@ -0,0 +1,85 @@ +"""SDK logging configuration utilities.""" + +import logging +import os +from typing import Literal + +LogLevel = Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] + +ENV_LOG_LEVEL = "AGENTARTS_LOG_LEVEL" + +DEFAULT_LOG_LEVEL: LogLevel = "INFO" + + +def get_log_level() -> LogLevel: + """Get log level from environment variable or default. + + Returns: + Log level string (DEBUG, INFO, WARNING, ERROR, CRITICAL). + """ + level = os.getenv(ENV_LOG_LEVEL, DEFAULT_LOG_LEVEL).upper() + valid_levels: tuple[LogLevel, ...] = ( + "DEBUG", + "INFO", + "WARNING", + "ERROR", + "CRITICAL", + ) + if level not in valid_levels: + level = DEFAULT_LOG_LEVEL + return level + + +def setup_logging( + level: LogLevel | None = None, + format: str | None = None, + handler: logging.Handler | None = None, +) -> None: + """Configure SDK logging. + + Args: + level: Log level. If None, use environment variable AGENTARTS_LOG_LEVEL. + format: Log format string. + handler: Custom handler. If None, use StreamHandler. + + Example: + >>> setup_logging(level="DEBUG") + >>> setup_logging() # Use AGENTARTS_LOG_LEVEL env var + + # Or via environment variable: + >>> import os + >>> os.environ["AGENTARTS_LOG_LEVEL"] = "DEBUG" + """ + if level is None: + level = get_log_level() + + if format is None: + format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + + sdk_logger = logging.getLogger("agentarts") + sdk_logger.setLevel(level) + + if not sdk_logger.handlers: + if handler is None: + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter(format)) + sdk_logger.addHandler(handler) + + logging.getLogger("urllib3").setLevel(logging.WARNING) + logging.getLogger("huaweicloudsdkcore").setLevel(logging.WARNING) + + +def get_logger(name: str) -> logging.Logger: + """Get a logger with the SDK namespace. + + Args: + name: Logger name (will be prefixed with 'agentarts.'). + + Returns: + Configured logger instance. + + Example: + >>> logger = get_logger("runtime.app") + >>> logger.info("Application started") + """ + return logging.getLogger(f"agentarts.{name}") diff --git a/src/agentarts/sdk/utils/metadata.py b/src/agentarts/sdk/utils/metadata.py index 5d4ecff..585c4a8 100644 --- a/src/agentarts/sdk/utils/metadata.py +++ b/src/agentarts/sdk/utils/metadata.py @@ -4,6 +4,7 @@ import json import logging +import os from functools import wraps import requests @@ -15,6 +16,8 @@ ProfileCredentialProvider, ) +from .constant import ENV_HUAWEICLOUD_SDK_PROJECT_ID + def create_credential(): """ @@ -91,11 +94,15 @@ def get_credentials(self) -> BasicCredentials: if resp.status_code < 300: metadata = json.loads(resp.text) self.logger.info(f"Get metadata credentials with expired time: {metadata['expires_at']}") - return BasicCredentials() \ + + credentials = BasicCredentials() \ .with_ak(metadata["access"]) \ .with_sk(metadata["secret"]) \ .with_security_token(metadata["securitytoken"]) + project_id = os.getenv(ENV_HUAWEICLOUD_SDK_PROJECT_ID) + return credentials.with_project_id(project_id) if project_id else credentials + self.logger.warning(f"Get metadata credentials failed with status: {resp.status_code}") except requests.exceptions.RequestException as e: self.logger.warning(f"Failed to connect to metadata service: {e}") diff --git a/src/agentarts/toolkit/cli/memory/commands.py b/src/agentarts/toolkit/cli/memory/commands.py index 4ca8b98..eae6950 100644 --- a/src/agentarts/toolkit/cli/memory/commands.py +++ b/src/agentarts/toolkit/cli/memory/commands.py @@ -45,6 +45,7 @@ def create_space_cmd( subnet_id: str | None = typer.Option(None, "--subnet-id", help="Private subnet ID (requires vpc-id)"), region: str | None = typer.Option(None, "--region", "-r", help="Region name (default: cn-north-4)"), output: str = typer.Option("table", "--output", "-o", help="Output format: table, json"), + skip_ssl_verification: bool = typer.Option(False, "--skip-ssl-verification", "-k", help="Skip SSL certificate verification"), ): """Create a Memory Space. @@ -107,6 +108,7 @@ def create_space_cmd( private_vpc_id=vpc_id.strip() if vpc_id else None, private_subnet_id=subnet_id.strip() if subnet_id else None, region=region, + skip_ssl_verification=skip_ssl_verification, ) if not result.success: @@ -130,6 +132,7 @@ def get_space_cmd( space_id: str = typer.Argument(..., help="Space ID"), region: str | None = typer.Option(None, "--region", "-r", help="Region name (default: cn-north-4)"), output: str = typer.Option("table", "--output", "-o", help="Output format: table, json"), + skip_ssl_verification: bool = typer.Option(False, "--skip-ssl-verification", help="Skip SSL certificate verification"), ): """Get Space details. @@ -144,6 +147,7 @@ def get_space_cmd( result = get_space( space_id=space_id, region=region, + skip_ssl_verification=skip_ssl_verification, ) if not result.success: @@ -174,6 +178,7 @@ def list_spaces_cmd( offset: int = typer.Option(0, "--offset", help="Offset for pagination"), region: str | None = typer.Option(None, "--region", "-r", help="Region name (default: cn-north-4)"), output: str = typer.Option("table", "--output", "-o", help="Output format: table, json"), + skip_ssl_verification: bool = typer.Option(False, "--skip-ssl-verification", help="Skip SSL certificate verification"), ): """List Spaces. @@ -193,6 +198,7 @@ def list_spaces_cmd( limit=limit, offset=offset, region=region, + skip_ssl_verification=skip_ssl_verification, ) if not result.success: @@ -247,6 +253,7 @@ def update_space_cmd( subnet_id: str | None = typer.Option(None, "--subnet-id", help="Private subnet ID (requires vpc-id)"), region: str | None = typer.Option(None, "--region", "-r", help="Region name (default: cn-north-4)"), output: str = typer.Option("table", "--output", "-o", help="Output format: table, json"), + skip_ssl_verification: bool = typer.Option(False, "--skip-ssl-verification", help="Skip SSL certificate verification"), ): """Update a Space. @@ -301,6 +308,7 @@ def update_space_cmd( private_vpc_id=vpc_id.strip() if vpc_id else None, private_subnet_id=subnet_id.strip() if subnet_id else None, region=region, + skip_ssl_verification=skip_ssl_verification, ) if not result.success: @@ -347,6 +355,7 @@ def delete_space_cmd( space_id: str = typer.Argument(..., help="Space ID"), region: str | None = typer.Option(None, "--region", "-r", help="Region name (default: cn-north-4)"), force: bool = typer.Option(False, "--force", "-f", help="Force deletion without confirmation"), + skip_ssl_verification: bool = typer.Option(False, "--skip-ssl-verification", help="Skip SSL certificate verification"), ): """Delete a Space. @@ -368,6 +377,7 @@ def delete_space_cmd( result = delete_space( space_id=space_id, region=region, + skip_ssl_verification=skip_ssl_verification, ) if not result.success: @@ -382,6 +392,7 @@ def space_status_cmd( space_id: str = typer.Argument(..., help="Space ID to check status"), region: str | None = typer.Option(None, "--region", "-r", help="Region name (default: cn-north-4)"), output: str = typer.Option("table", "--output", "-o", help="Output format: table, json"), + skip_ssl_verification: bool = typer.Option(False, "--skip-ssl-verification", help="Skip SSL certificate verification"), ): """Check the status of a Memory Space. @@ -402,6 +413,7 @@ def space_status_cmd( result = get_space( space_id=space_id, region=region, + skip_ssl_verification=skip_ssl_verification, ) if not result.success: diff --git a/src/agentarts/toolkit/cli/runtime/config.py b/src/agentarts/toolkit/cli/runtime/config.py index 96b3cd5..ae436cc 100644 --- a/src/agentarts/toolkit/cli/runtime/config.py +++ b/src/agentarts/toolkit/cli/runtime/config.py @@ -45,8 +45,8 @@ def main( Examples: agentarts config - agentarts config --name my-agent --entrypoint app:main - agentarts config -n my-agent -e app:main --dependency-file requirements.txt --swr-org my-org --swr-repo my-repo + agentarts config --name myagent --entrypoint app:main + agentarts config -n myagent -e app:main --dependency-file requirements.txt --swr-org my-org --swr-repo my-repo """ if ctx.invoked_subcommand is not None: return @@ -226,7 +226,7 @@ def get( Examples: agentarts config get agentarts config get base.region - agentarts config get base.region --agent my-agent + agentarts config get base.region --agent myagent """ if key is None: success = config_op.print_agent_detail(agent) @@ -262,7 +262,7 @@ def remove( Remove an agent configuration. Examples: - agentarts config remove my-agent + agentarts config remove myagent """ success = config_op.remove_agent(name) if not success: diff --git a/src/agentarts/toolkit/cli/runtime/deploy.py b/src/agentarts/toolkit/cli/runtime/deploy.py index 8989603..9e14345 100644 --- a/src/agentarts/toolkit/cli/runtime/deploy.py +++ b/src/agentarts/toolkit/cli/runtime/deploy.py @@ -44,6 +44,10 @@ def deploy( str | None, typer.Option("--description", "-d", help="Agent description (overrides config)"), ] = None, + skip_ssl_verification: Annotated[ + bool, + typer.Option("--skip-ssl-verification", "-k", help="Skip SSL certificate verification"), + ] = False, ): """ Deploy agent to Huawei Cloud or run locally. @@ -71,7 +75,7 @@ def deploy( Examples: agentarts deploy - agentarts deploy --agent my-agent + agentarts deploy --agent myagent agentarts deploy --mode local --local-port 8080 agentarts deploy --mode cloud --tag v1.0.0 agentarts deploy --swr-org my-org --swr-repo my-repo @@ -94,6 +98,7 @@ def deploy( swr_org=swr_org, swr_repo=swr_repo, description=description, + skip_ssl_verification=skip_ssl_verification, ) if not success: diff --git a/src/agentarts/toolkit/cli/runtime/destroy.py b/src/agentarts/toolkit/cli/runtime/destroy.py index bf2cc39..02c28a1 100644 --- a/src/agentarts/toolkit/cli/runtime/destroy.py +++ b/src/agentarts/toolkit/cli/runtime/destroy.py @@ -20,6 +20,10 @@ def destroy( bool, typer.Option("--yes", "-y", help="Skip confirmation prompt"), ] = False, + skip_ssl_verification: Annotated[ + bool, + typer.Option("--skip-ssl-verification", "-k", help="Skip SSL certificate verification"), + ] = False, ): """ Destroy agent from Huawei Cloud. @@ -28,8 +32,8 @@ def destroy( Examples: agentarts destroy - agentarts destroy --agent my-agent - agentarts destroy --agent my-agent --region cn-southwest-2 + agentarts destroy --agent myagent + agentarts destroy --agent myagent --region cn-southwest-2 agentarts destroy --yes # Skip confirmation """ from rich.console import Console as RichConsole @@ -47,6 +51,7 @@ def destroy( success = destroy_agent( agent_name=agent, region=region, + skip_ssl_verification=skip_ssl_verification, ) if not success: diff --git a/src/agentarts/toolkit/cli/runtime/init.py b/src/agentarts/toolkit/cli/runtime/init.py index ef0fa35..d816913 100644 --- a/src/agentarts/toolkit/cli/runtime/init.py +++ b/src/agentarts/toolkit/cli/runtime/init.py @@ -53,7 +53,7 @@ def prompt_for_template() -> TemplateType: def prompt_for_name() -> str: """Prompt user to enter project name""" - return Prompt.ask("\n[bold]Enter project name[/bold]", default="my_agent") + return Prompt.ask("\n[bold]Enter project name[/bold]", default="myagent") def prompt_for_region() -> str: @@ -103,10 +103,10 @@ def init( Examples: agentarts init - agentarts init -n my_agent - agentarts init -n my_agent -t langgraph - agentarts init -n my_agent -t langchain -r cn-southwest-2 - agentarts init -n my_agent --swr-org my-org --swr-repo my-repo + agentarts init -n myagent + agentarts init -n myagent -t langgraph + agentarts init -n myagent -t langchain -r cn-southwest-2 + agentarts init -n myagent --swr-org my-org --swr-repo my-repo """ if name is None: name = prompt_for_name() diff --git a/src/agentarts/toolkit/cli/runtime/invoke.py b/src/agentarts/toolkit/cli/runtime/invoke.py index 2dbf5b6..70bfc8f 100644 --- a/src/agentarts/toolkit/cli/runtime/invoke.py +++ b/src/agentarts/toolkit/cli/runtime/invoke.py @@ -47,6 +47,14 @@ def status( str | None, typer.Option("--bearer-token", "-bt", help="Bearer token for authentication"), ] = None, + skip_ssl_verification: Annotated[ + bool, + typer.Option("--skip-ssl-verification", "-k", help="Skip SSL certificate verification"), + ] = False, + user_id: Annotated[ + str | None, + typer.Option("--user-id", "-u", help="User ID for OAuth2 outbound credentials"), + ] = None, ): """ Check agent health status. @@ -57,11 +65,12 @@ def status( Examples: agentarts status - agentarts status --agent my-agent + agentarts status --agent myagent agentarts status --mode local --port 8080 agentarts status --endpoint custom-endpoint agentarts status --session my-session-123 agentarts status --bearer-token my-token + agentarts status --user-id my-user-id agentarts status -bt my-token """ status_mode = InvokeMode.CLOUD @@ -79,6 +88,8 @@ def status( endpoint=endpoint, session_id=session_id, bearer_token=bearer_token, + skip_ssl_verification=skip_ssl_verification, + user_id=user_id, ) if not success: @@ -88,7 +99,7 @@ def status( def invoke( payload: Annotated[ str, - typer.Argument(help="JSON payload to send to the agent (e.g., '{\"input\": \"hello\"}')"), + typer.Argument(help="JSON payload to send to the agent (e.g., '{\"message\": \"hello\"}')"), ], agent: Annotated[ str | None, @@ -126,6 +137,14 @@ def invoke( int, typer.Option("--timeout", help="Request timeout in seconds (default: 900)"), ] = 900, + skip_ssl_verification: Annotated[ + bool, + typer.Option("--skip-ssl-verification", help="Skip SSL certificate verification"), + ] = False, + user_id: Annotated[ + str | None, + typer.Option("--user-id", "-u", help="User ID for OAuth2 outbound credentials"), + ] = None, ): """ Invoke agent with JSON payload. @@ -138,9 +157,10 @@ def invoke( Examples: agentarts invoke '{"message": "hello"}' - agentarts invoke '{"message": "hello"}' --agent my-agent + agentarts invoke '{"message": "hello"}' --agent myagent agentarts invoke '{"message": "hello"}' --mode local --port 8080 agentarts invoke '{"message": "test"}' --session my-session-123 + agentarts invoke '{"message": "test"}' --user-id my-user-id """ invoke_mode = InvokeMode.CLOUD if mode.lower() == "local": @@ -159,6 +179,8 @@ def invoke( session_id=session_id, bearer_token=bearer_token, timeout=timeout, + skip_ssl_verification=skip_ssl_verification, + user_id=user_id, ) if not success: diff --git a/src/agentarts/toolkit/operations/memory/space.py b/src/agentarts/toolkit/operations/memory/space.py index 9e75e28..5c9959a 100644 --- a/src/agentarts/toolkit/operations/memory/space.py +++ b/src/agentarts/toolkit/operations/memory/space.py @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) -def _get_client(region: str | None = None) -> MemoryClient: +def _get_client(region: str | None = None, verify_ssl: bool | str = True) -> MemoryClient: """Get MemoryClient instance. Uses Huawei Cloud SDK Core credential provider chain (AK/SK). @@ -24,11 +24,12 @@ def _get_client(region: str | None = None) -> MemoryClient: Args: region: Region name (optional, defaults to cn-north-4) + verify_ssl: SSL verification setting (default: True) Returns: MemoryClient instance """ - kwargs = {} + kwargs = {"verify_ssl": verify_ssl} if region: kwargs["region_name"] = region return MemoryClient(**kwargs) @@ -71,6 +72,7 @@ def create_space( private_vpc_id: str | None = None, private_subnet_id: str | None = None, region: str | None = None, + skip_ssl_verification: bool = False, **kwargs, ) -> SpaceResult: """Create a Memory Space. @@ -93,13 +95,15 @@ def create_space( private_vpc_id: Private VPC ID (requires private_subnet_id) private_subnet_id: Private subnet ID (requires private_vpc_id) region: Region name (optional, defaults to cn-north-4) + skip_ssl_verification: Skip SSL certificate verification (default: False) **kwargs: Additional parameters (ignored, for backward compatibility) Returns: SpaceResult with space_id and space details """ try: - client = _get_client(region=region) + verify_ssl = not skip_ssl_verification + client = _get_client(region=region, verify_ssl=verify_ssl) # 使用关键字参数调用 create_space(适配新 API) space = client.create_space( @@ -137,18 +141,21 @@ def create_space( def get_space( space_id: str, region: str | None = None, + skip_ssl_verification: bool = False, ) -> SpaceResult: """Get Space details. Args: space_id: Space ID region: Region name (optional, defaults to cn-north-4) + skip_ssl_verification: Skip SSL certificate verification (default: False) Returns: SpaceResult with space details """ try: - client = _get_client(region=region) + verify_ssl = not skip_ssl_verification + client = _get_client(region=region, verify_ssl=verify_ssl) space = client.get_space(space_id) space_dict = _space_info_to_dict(space) @@ -171,6 +178,7 @@ def list_spaces( limit: int = 20, offset: int = 0, region: str | None = None, + skip_ssl_verification: bool = False, ) -> SpaceListResult: """List Spaces. @@ -178,12 +186,14 @@ def list_spaces( limit: Maximum number of spaces to return (default: 20) offset: Offset for pagination (default: 0) region: Region name (optional, defaults to cn-north-4) + skip_ssl_verification: Skip SSL certificate verification (default: False) Returns: SpaceListResult with list of spaces """ try: - client = _get_client(region=region) + verify_ssl = not skip_ssl_verification + client = _get_client(region=region, verify_ssl=verify_ssl) result = client.list_spaces(limit=limit, offset=offset) # 转换为字典列表以保持向后兼容 @@ -216,6 +226,7 @@ def update_space( memory_strategies_builtin: list[str] | None = None, tags: list[dict[str, str]] | None = None, region: str | None = None, + skip_ssl_verification: bool = False, **kwargs, ) -> SpaceResult: """Update a Space. @@ -232,13 +243,15 @@ def update_space( memory_strategies_builtin: Built-in memory strategies (optional list) tags: Tags for the space (optional list of key-value dicts) region: Region name (optional, defaults to cn-north-4) + skip_ssl_verification: Skip SSL certificate verification (default: False) **kwargs: Additional update parameters (ignored, for backward compatibility) Returns: SpaceResult with updated space details """ try: - client = _get_client(region=region) + verify_ssl = not skip_ssl_verification + client = _get_client(region=region, verify_ssl=verify_ssl) # 使用关键字参数调用 update_space(适配新 API) space = client.update_space( @@ -274,18 +287,21 @@ def update_space( def delete_space( space_id: str, region: str | None = None, + skip_ssl_verification: bool = False, ) -> SpaceResult: """Delete a Space. Args: space_id: Space ID region: Region name (optional, defaults to cn-north-4) + skip_ssl_verification: Skip SSL certificate verification (default: False) Returns: SpaceResult indicating success or failure """ try: - client = _get_client(region=region) + verify_ssl = not skip_ssl_verification + client = _get_client(region=region, verify_ssl=verify_ssl) client.delete_space(space_id) logger.info(f"Space deleted successfully: {space_id}") diff --git a/src/agentarts/toolkit/operations/runtime/deploy.py b/src/agentarts/toolkit/operations/runtime/deploy.py index 594d9d5..369c73e 100644 --- a/src/agentarts/toolkit/operations/runtime/deploy.py +++ b/src/agentarts/toolkit/operations/runtime/deploy.py @@ -46,6 +46,7 @@ def create_agentarts_runtime( agent_config: Any | None = None, port: int | None = None, description: str | None = None, + verify_ssl: bool | str = True, ) -> dict | None: """ Create or update AgentArts runtime using RuntimeClient. @@ -68,7 +69,7 @@ def create_agentarts_runtime( endpoint = get_control_plane_endpoint(region) - client = RuntimeClient(control_endpoint=endpoint, verify_ssl=False) + client = RuntimeClient(control_endpoint=endpoint, verify_ssl=verify_ssl) artifact_source_config = None invoke_config = {} @@ -158,6 +159,7 @@ def deploy_project( swr_org: str | None = None, swr_repo: str | None = None, description: str | None = None, + skip_ssl_verification: bool = False, ) -> bool: """ Deploy project. @@ -241,8 +243,10 @@ def deploy_project( border_style="cyan", )) + verify_ssl = not skip_ssl_verification + try: - swr_client = SWRClient(region=region) + swr_client = SWRClient(region=region, verify_ssl=verify_ssl) if agent_config.swr_config.organization_auto_create: org_result = swr_client.create_or_get_organization(final_swr_org) @@ -323,6 +327,7 @@ def deploy_project( agent_config=agent_config, port=port, description=description, + verify_ssl=verify_ssl, ) if agent is None: diff --git a/src/agentarts/toolkit/operations/runtime/destroy.py b/src/agentarts/toolkit/operations/runtime/destroy.py index 485ef42..b42df63 100644 --- a/src/agentarts/toolkit/operations/runtime/destroy.py +++ b/src/agentarts/toolkit/operations/runtime/destroy.py @@ -13,6 +13,7 @@ def destroy_agent( agent_name: str | None = None, region: str | None = None, + skip_ssl_verification: bool = False, ) -> bool: """ Destroy agent from Huawei Cloud. @@ -20,6 +21,7 @@ def destroy_agent( Args: agent_name: Agent name to destroy region: Huawei Cloud region + skip_ssl_verification: Skip SSL certificate verification (default: False) Returns: True if destroyed successfully, False otherwise @@ -45,8 +47,9 @@ def destroy_agent( from agentarts.sdk.service import RuntimeClient from agentarts.sdk.utils.constant import get_control_plane_endpoint + verify_ssl = not skip_ssl_verification control_endpoint = get_control_plane_endpoint(actual_region) - client = RuntimeClient(control_endpoint=control_endpoint, verify_ssl=False) + client = RuntimeClient(control_endpoint=control_endpoint, verify_ssl=verify_ssl) result = client.delete_agent_by_name(agent_name=agent_name) diff --git a/src/agentarts/toolkit/operations/runtime/invoke.py b/src/agentarts/toolkit/operations/runtime/invoke.py index 9937449..fe141d8 100644 --- a/src/agentarts/toolkit/operations/runtime/invoke.py +++ b/src/agentarts/toolkit/operations/runtime/invoke.py @@ -82,6 +82,7 @@ def _get_data_endpoint( agent_name: str, region: str, agent_id: str | None = None, + verify_ssl: bool | str = True, ) -> str | None: """ Get data plane endpoint for the agent. @@ -101,7 +102,7 @@ def _get_data_endpoint( if not data_endpoint: control_endpoint = get_control_plane_endpoint(region) - control_client = RuntimeClient(control_endpoint=control_endpoint, verify_ssl=False) + control_client = RuntimeClient(control_endpoint=control_endpoint, verify_ssl=verify_ssl) if agent_id: agent_detail = control_client.find_agent_by_id(agent_id) @@ -137,6 +138,77 @@ class InvokeMode(str, Enum): CLOUD = "cloud" +def _normalize_json_payload(payload: str) -> str: + """ + Normalize JSON payload to handle Windows PowerShell quote stripping. + + On Windows PowerShell, double quotes inside single-quoted strings may be + stripped when passed to subprocess, causing '{"message":"hello"}' to become + '{message:hello}'. This function attempts to restore proper JSON formatting. + + Args: + payload: Raw payload string received from CLI + + Returns: + Normalized JSON string + """ + if not payload: + return payload + + payload = payload.strip() + + try: + json.loads(payload) + return payload + except json.JSONDecodeError: + pass + + if payload.startswith("'") and payload.endswith("'"): + payload = payload[1:-1] + + if '\\"' in payload: + payload = payload.replace('\\"', '"') + try: + json.loads(payload) + return payload + except json.JSONDecodeError: + pass + + if payload.startswith("{") and payload.endswith("}"): + inner = payload[1:-1].strip() + if not inner: + return "{}" + + result_parts = [] + parts = inner.split(",") + for part in parts: + part = part.strip() + if ":" in part: + key_val = part.split(":", 1) + key = key_val[0].strip() + val = key_val[1].strip() if len(key_val) > 1 else "" + + if not key.startswith('"') and not key.startswith("'"): + key = f'"{key}"' + + if val: + if not val.startswith('"') and not val.startswith("'") and not val.startswith("[") and not val.startswith("{") and not val.isdigit() and val.lower() not in ("true", "false", "null"): + val = f'"{val}"' + + result_parts.append(f"{key}:{val}") + else: + result_parts.append(part) + + reconstructed = "{" + ",".join(result_parts) + "}" + try: + json.loads(reconstructed) + return reconstructed + except json.JSONDecodeError: + pass + + return payload + + def invoke_agent( payload: str, agent_name: str | None = None, @@ -147,6 +219,8 @@ def invoke_agent( session_id: str | None = None, bearer_token: str | None = None, timeout: int = 900, + skip_ssl_verification: bool = False, + user_id: str | None = None, ) -> bool: """ Invoke agent locally or on cloud. @@ -161,12 +235,15 @@ def invoke_agent( session_id: Session ID for stateful agents bearer_token: Optional bearer token timeout: Request timeout in seconds + skip_ssl_verification: Skip SSL certificate verification + user_id: Optional user ID for OAuth2 outbound credentials Returns: True if successful, False otherwise """ + normalized_payload = _normalize_json_payload(payload) try: - json.loads(payload) + json.loads(normalized_payload) except json.JSONDecodeError: echo_error("Payload must be valid JSON") return False @@ -182,11 +259,12 @@ def invoke_agent( echo_info("Invoke Request", f"[cyan]Mode:[/cyan] [yellow]Local[/yellow]\n[cyan]Endpoint:[/cyan] [white]localhost:{local_port}[/white]") result = client.invoke_agent( - payload=payload, + payload=normalized_payload, session_id=session_id, bearer_token=actual_bearer_token, endpoint=endpoint, timeout=timeout, + user_id=user_id, ) else: agent_name, region, agent_id, auth_type = _resolve_agent_info(agent_name, region) @@ -198,8 +276,9 @@ def invoke_agent( actual_region = region or get_region() actual_session_id = session_id or str(uuid.uuid4()) + verify_ssl = not skip_ssl_verification - data_endpoint = _get_data_endpoint(agent_name, actual_region, agent_id) + data_endpoint = _get_data_endpoint(agent_name, actual_region, agent_id, verify_ssl) if not data_endpoint: echo_error(f"No data plane endpoint configured and could not get access_endpoint from agent [yellow]{agent_name} {actual_region}[/yellow]") @@ -218,7 +297,7 @@ def invoke_agent( client = RuntimeClient( data_endpoint=data_endpoint, - verify_ssl=False, + verify_ssl=verify_ssl, sign_mode=sign_mode, region_id=actual_region, ) @@ -226,10 +305,11 @@ def invoke_agent( result = client.invoke_agent( agent_name=agent_name, session_id=actual_session_id, - payload=payload, + payload=normalized_payload, bearer_token=actual_bearer_token, endpoint=endpoint, timeout=timeout, + user_id=user_id, ) if isinstance(result, dict): @@ -263,6 +343,8 @@ def status_agent( endpoint: str | None = None, session_id: str | None = None, bearer_token: str | None = None, + skip_ssl_verification: bool = False, + user_id: str | None = None, ) -> bool: """ Check agent health status. @@ -275,6 +357,8 @@ def status_agent( endpoint: Optional endpoint name session_id: Session ID for stateful agents (auto-generated if None) bearer_token: Optional bearer token + skip_ssl_verification: Skip SSL certificate verification + user_id: Optional user ID for OAuth2 outbound credentials Returns: True if healthy, False otherwise @@ -294,6 +378,7 @@ def status_agent( bearer_token=actual_bearer_token, endpoint=endpoint, session_id=actual_session_id, + user_id=user_id, ) status = result.get("status", "Unknown") @@ -309,8 +394,9 @@ def status_agent( return False actual_region = region or get_region() + verify_ssl = not skip_ssl_verification - data_endpoint = _get_data_endpoint(agent_name, actual_region, agent_id) + data_endpoint = _get_data_endpoint(agent_name, actual_region, agent_id, verify_ssl) if not data_endpoint: echo_error(f"No data plane endpoint configured and could not get access_endpoint from agent {agent_name}") @@ -330,7 +416,7 @@ def status_agent( client = RuntimeClient( data_endpoint=data_endpoint, - verify_ssl=False, + verify_ssl=verify_ssl, sign_mode=sign_mode, region_id=actual_region, ) @@ -340,6 +426,7 @@ def status_agent( bearer_token=actual_bearer_token, endpoint=endpoint, session_id=actual_session_id, + user_id=user_id, ) if isinstance(result, dict): diff --git a/src/agentarts/toolkit/utils/templates/docker/__init__.py b/src/agentarts/toolkit/utils/templates/docker/__init__.py index 265e714..4ee85a5 100644 --- a/src/agentarts/toolkit/utils/templates/docker/__init__.py +++ b/src/agentarts/toolkit/utils/templates/docker/__init__.py @@ -40,10 +40,10 @@ def render_dockerfile( """ template = get_dockerfile_template() + env_lines = ["# Set container environment marker", "ENV DOCKER_CONTAINER=true"] if region: - env_section = f"# Set Huawei Cloud region\nENV HUAWEICLOUD_SDK_REGION={region}" - else: - env_section = "# No region specified" + env_lines.append(f"ENV HUAWEICLOUD_SDK_REGION={region}") + env_section = "\n".join(env_lines) user_section = f"""# Create non-root user for security RUN groupadd -g {group_id} {user_name} && \\ @@ -59,11 +59,11 @@ def render_dockerfile( chown_app_section = f"RUN chown -R {user_name}:{user_name} /app" - if entrypoint and ":" in entrypoint: - module, app_target = entrypoint.split(":") - cmd_section = f'CMD ["uvicorn", "{module}:{app_target}", "--host", "0.0.0.0", "--port", "{port}"]' + if entrypoint: + module = entrypoint.split(":")[0] if ":" in entrypoint else entrypoint + cmd_section = f'CMD ["python", "-m", "{module}"]' else: - cmd_section = 'CMD ["python", "-m", "agentarts.server", "--config", "agentarts.yaml"]' + cmd_section = f'CMD ["python", "-m", "agent", "--host", "0.0.0.0", "--port", "{port}"]' content = template.format( base_image=base_image, diff --git a/tests/unit/sdk/memory/test_memory.py b/tests/unit/sdk/memory/test_memory.py index f3cae47..a6b5bd6 100644 --- a/tests/unit/sdk/memory/test_memory.py +++ b/tests/unit/sdk/memory/test_memory.py @@ -1,5 +1,6 @@ """Tests for memory module""" +from unittest.mock import MagicMock, patch def test_memory_client_import(): @@ -31,3 +32,183 @@ def test_memory_types_import(): assert SpaceUpdateRequest is not None assert SessionCreateRequest is not None assert MessageRequest is not None + + +class TestMemoryClientClose: + """Tests for MemoryClient close functionality.""" + + def test_memory_client_close_without_control_plane(self): + """Test MemoryClient close when control plane is not initialized.""" + from agentarts.sdk import MemoryClient + + client = MemoryClient(region_name="cn-north-4", api_key="test-api-key") + + assert client._control_plane is None + assert client._data_plane is not None + + client.close() + + assert client._control_plane is None + + def test_memory_client_close_with_control_plane(self): + """Test MemoryClient close when control plane is initialized.""" + from agentarts.sdk import MemoryClient + from agentarts.sdk.memory.inner.controlplane import _ControlPlane + + client = MemoryClient(region_name="cn-north-4", api_key="test-api-key") + + mock_control_plane = MagicMock(spec=_ControlPlane) + client._control_plane = mock_control_plane + + client.close() + + mock_control_plane.close.assert_called_once() + + def test_memory_client_context_manager(self): + """Test MemoryClient works as context manager.""" + from agentarts.sdk import MemoryClient + + with MemoryClient(region_name="cn-north-4", api_key="test-api-key") as client: + assert client is not None + + def test_memory_client_context_manager_calls_close(self): + """Test context manager calls close on exit.""" + from agentarts.sdk import MemoryClient + + client = MemoryClient(region_name="cn-north-4", api_key="test-api-key") + + with patch.object(client, "close") as mock_close: + with client: + pass + mock_close.assert_called_once() + + +class TestDataPlaneClose: + """Tests for _DataPlane close functionality.""" + + def test_dataplane_close(self): + """Test _DataPlane close method.""" + from agentarts.sdk.memory.inner.dataplane import _DataPlane + + dataplane = _DataPlane(region_name="cn-north-4", api_key="test-api-key") + + assert dataplane.client is not None + + dataplane.close() + + def test_dataplane_close_calls_client_close(self): + """Test _DataPlane close calls underlying client close.""" + from agentarts.sdk.memory.inner.dataplane import _DataPlane + + dataplane = _DataPlane(region_name="cn-north-4", api_key="test-api-key") + + with patch.object(dataplane.client, "close") as mock_close: + dataplane.close() + mock_close.assert_called_once() + + +class TestControlPlaneClose: + """Tests for _ControlPlane close functionality.""" + + def test_controlplane_close(self): + """Test _ControlPlane close method.""" + from agentarts.sdk.memory.inner.controlplane import _ControlPlane + + with patch("agentarts.sdk.service.memory_service.ControlPlaneAuthenticationStrategy.setup_credentials"): + controlplane = _ControlPlane(region_name="cn-north-4") + + assert controlplane.client is not None + + controlplane.close() + + def test_controlplane_close_calls_client_close(self): + """Test _ControlPlane close calls underlying client close.""" + from agentarts.sdk.memory.inner.controlplane import _ControlPlane + + with patch("agentarts.sdk.service.memory_service.ControlPlaneAuthenticationStrategy.setup_credentials"): + controlplane = _ControlPlane(region_name="cn-north-4") + + with patch.object(controlplane.client, "close") as mock_close: + controlplane.close() + mock_close.assert_called_once() + + +class TestMemoryHttpServiceClose: + """Tests for MemoryHttpService close functionality.""" + + def test_memory_http_service_close(self): + """Test MemoryHttpService close method.""" + from agentarts.sdk.service.memory_service import MemoryHttpService + + service = MemoryHttpService( + region_name="cn-north-4", + endpoint_type="data", + api_key="test-api-key" + ) + + assert service.session is not None + + service.close() + + def test_memory_http_service_close_calls_session_close(self): + """Test MemoryHttpService close calls session close.""" + from agentarts.sdk.service.memory_service import MemoryHttpService + + service = MemoryHttpService( + region_name="cn-north-4", + endpoint_type="data", + api_key="test-api-key" + ) + + with patch.object(service.session, "close") as mock_close: + service.close() + mock_close.assert_called_once() + + def test_memory_http_service_close_control_plane(self): + """Test MemoryHttpService close for control plane.""" + from agentarts.sdk.service.memory_service import MemoryHttpService + + with patch("agentarts.sdk.service.memory_service.ControlPlaneAuthenticationStrategy.setup_credentials"): + service = MemoryHttpService( + region_name="cn-north-4", + endpoint_type="control" + ) + + assert service.session is not None + + service.close() + + def test_memory_http_service_close_safe_when_no_session(self): + """Test MemoryHttpService close is safe when session is None.""" + from agentarts.sdk.service.memory_service import MemoryHttpService + + service = MemoryHttpService( + region_name="cn-north-4", + endpoint_type="data", + api_key="test-api-key" + ) + + service.session = None + + service.close() + + +class TestCloseMethodChaining: + """Tests for close method chaining from MemoryClient to session.""" + + def test_close_method_chain(self): + """Test that close calls propagate from client to session.""" + from agentarts.sdk import MemoryClient + from agentarts.sdk.memory.inner.controlplane import _ControlPlane + + client = MemoryClient(region_name="cn-north-4", api_key="test-api-key") + + mock_control_plane = MagicMock(spec=_ControlPlane) + client._control_plane = mock_control_plane + + with patch.object(client._data_plane.client.session, "close") as mock_session_close: + client.close() + + mock_control_plane.close.assert_called_once() + + mock_session_close.assert_called_once() diff --git a/tests/unit/sdk/runtime/test_app.py b/tests/unit/sdk/runtime/test_app.py index 9247241..b49464a 100644 --- a/tests/unit/sdk/runtime/test_app.py +++ b/tests/unit/sdk/runtime/test_app.py @@ -228,7 +228,7 @@ def test_build_request_context_sets_runtime_context(self): mock_request.headers = { "X-HW-AgentGateway-Workload-Access-Token": "workload-token", "x-hw-agentarts-session-id": "session-abc", - "X-Hw-AgentArts-Runtime-User-Id": "user-xyz", + "X-HW-AgentGateway-User-Id": "user-xyz", } app._build_request_context(mock_request) @@ -271,6 +271,46 @@ def handler(payload, other): assert app._task_context(handler) is False +class TestPingTaskContext: + """Tests for _ping_task_context method.""" + + def test_ping_task_context_with_context_param(self): + """Test _ping_task_context returns True for handler with context as first param.""" + app = AgentArtsRuntimeApp() + + def handler(context): + pass + + assert app._ping_task_context(handler) is True + + def test_ping_task_context_without_context_param(self): + """Test _ping_task_context returns False for handler without context param.""" + app = AgentArtsRuntimeApp() + + def handler(): + pass + + assert app._ping_task_context(handler) is False + + def test_ping_task_context_wrong_param_name(self): + """Test _ping_task_context returns False when first param is not 'context'.""" + app = AgentArtsRuntimeApp() + + def handler(other): + pass + + assert app._ping_task_context(handler) is False + + def test_ping_task_context_with_multiple_params(self): + """Test _ping_task_context returns True when first param is 'context'.""" + app = AgentArtsRuntimeApp() + + def handler(context, extra): + pass + + assert app._ping_task_context(handler) is True + + class TestSerialization: """Tests for serialization methods.""" @@ -523,6 +563,84 @@ async def test_handle_ping_with_running_tasks(self): body = json.loads(response.body) assert body["status"] == "HealthyBusy" + @pytest.mark.asyncio + async def test_handle_ping_with_context_param_handler(self): + """Test _handle_ping passes context to handler when context param is present.""" + app = AgentArtsRuntimeApp() + + received_context = None + + @app.ping + def ping_with_context(context): + nonlocal received_context + received_context = context + return PingStatus.HEALTHY + + mock_request = MagicMock(spec=Request) + mock_request.headers = { + "x-hw-agentarts-session-id": "test-session-123", + "X-Request-Id": "req-456", + } + + response = await app._handle_ping(mock_request) + + assert response.status_code == 200 + assert received_context is not None + assert received_context.session_id == "test-session-123" + assert received_context.request_id == "req-456" + assert received_context.request == mock_request + + @pytest.mark.asyncio + async def test_handle_ping_without_context_param_handler(self): + """Test _handle_ping does not pass context when handler has no context param.""" + app = AgentArtsRuntimeApp() + + handler_called = False + + @app.ping + def ping_no_context(): + nonlocal handler_called + handler_called = True + return PingStatus.HEALTHY + + mock_request = MagicMock(spec=Request) + mock_request.headers = {} + + response = await app._handle_ping(mock_request) + + assert response.status_code == 200 + assert handler_called is True + + @pytest.mark.asyncio + async def test_handle_ping_context_with_headers(self): + """Test _handle_ping builds request_context with headers.""" + app = AgentArtsRuntimeApp() + + received_session_id = None + + @app.ping + def ping_check_session(context): + nonlocal received_session_id + received_session_id = context.session_id + return PingStatus.HEALTHY + + mock_request = MagicMock(spec=Request) + mock_request.headers = { + "X-HW-AgentGateway-Workload-Access-Token": "workload-token", + "x-hw-agentarts-session-id": "session-abc", + "X-HW-AgentGateway-User-Id": "user-xyz", + } + + response = await app._handle_ping(mock_request) + + assert response.status_code == 200 + assert received_session_id == "session-abc" + assert AgentArtsRuntimeContext.get_workload_access_token() == "workload-token" + assert AgentArtsRuntimeContext.get_session_id() == "session-abc" + assert AgentArtsRuntimeContext.get_user_id() == "user-xyz" + + AgentArtsRuntimeContext.clear() + class TestGetCurrentPingStatus: """Tests for get_current_ping_status method.""" @@ -553,6 +671,93 @@ def test_get_current_ping_status_forced(self): assert status == PingStatus.UNHEALTHY + def test_get_current_ping_status_with_context_param_handler(self): + """Test get_current_ping_status passes context to handler with context param.""" + app = AgentArtsRuntimeApp() + + received_context = None + + @app.ping + def ping_with_context(context): + nonlocal received_context + received_context = context + return PingStatus.HEALTHY + + request_context = RequestContext( + session_id="test-session", + request_id="req-123", + request=None + ) + + status = app.get_current_ping_status(request_context) + + assert status == PingStatus.HEALTHY + assert received_context == request_context + + def test_get_current_ping_status_without_context_param_handler(self): + """Test get_current_ping_status does not pass context when handler has no context param.""" + app = AgentArtsRuntimeApp() + + handler_called = False + + @app.ping + def ping_no_context(): + nonlocal handler_called + handler_called = True + return PingStatus.HEALTHY + + request_context = RequestContext( + session_id="test-session", + request_id="req-123", + request=None + ) + + status = app.get_current_ping_status(request_context) + + assert status == PingStatus.HEALTHY + assert handler_called is True + + def test_get_current_ping_status_no_context_passed_when_none(self): + """Test get_current_ping_status does not pass None context.""" + app = AgentArtsRuntimeApp() + + received_context = None + + @app.ping + def ping_with_context(context): + nonlocal received_context + received_context = context + return PingStatus.HEALTHY + + status = app.get_current_ping_status(None) + + assert status == PingStatus.HEALTHY + assert received_context is None + + def test_get_current_ping_status_custom_handler_uses_context(self): + """Test custom ping handler can use context to determine status.""" + app = AgentArtsRuntimeApp() + + @app.ping + def ping_check_session(context): + if context and context.session_id == "healthy-session": + return PingStatus.HEALTHY + return PingStatus.UNHEALTHY + + healthy_context = RequestContext( + session_id="healthy-session", + request_id="req-1", + request=None + ) + unhealthy_context = RequestContext( + session_id="unhealthy-session", + request_id="req-2", + request=None + ) + + assert app.get_current_ping_status(healthy_context) == PingStatus.HEALTHY + assert app.get_current_ping_status(unhealthy_context) == PingStatus.UNHEALTHY + class TestForcePingStatus: """Tests for force_ping_status method.""" @@ -654,7 +859,20 @@ def test_run_docker_environment(self): """Test run uses 0.0.0.0 in Docker environment.""" app = AgentArtsRuntimeApp() - with patch("os.path.exists", return_value=True), patch("uvicorn.run") as mock_run: + with ( + patch("os.path.exists", return_value=True), + patch("agentarts.sdk.runtime.app.subprocess.run") as mock_subprocess, + patch("agentarts.sdk.runtime.app.socket.socket") as mock_socket, + patch("agentarts.sdk.runtime.app.socket.gethostbyname", return_value="127.0.0.1"), + patch("agentarts.sdk.runtime.app.socket.gethostname", return_value="localhost"), + patch("uvicorn.run") as mock_run, + ): + mock_subprocess.return_value.stdout = "" + mock_subprocess.return_value.returncode = 1 + mock_socket_instance = MagicMock() + mock_socket_instance.connect.side_effect = Exception("mocked failure") + mock_socket.return_value = mock_socket_instance + app.run() call_kwargs = mock_run.call_args[1] assert call_kwargs["host"] == "0.0.0.0" @@ -715,3 +933,106 @@ def failing_handler(payload): with pytest.raises(RuntimeError, match="Handler failed"): await app._invoke_handler(failing_handler, context, False, {}) + + +class TestInvokeHandlerContextPropagation: + """Tests for contextvars propagation in _invoke_handler.""" + + @pytest.mark.asyncio + async def test_sync_handler_gets_workload_access_token(self): + """Sync handler should be able to read workload_access_token from AgentArtsRuntimeContext.""" + app = AgentArtsRuntimeApp() + + def sync_handler(payload): + token = AgentArtsRuntimeContext.get_workload_access_token() + return {"token": token} + + AgentArtsRuntimeContext.set_workload_access_token("test-workload-token-123") + try: + context = RequestContext(session_id="test", request_id="req-1", request=None) + result = await app._invoke_handler(sync_handler, context, False, {}) + assert result["token"] == "test-workload-token-123" + finally: + AgentArtsRuntimeContext.set_workload_access_token(None) + + @pytest.mark.asyncio + async def test_async_handler_gets_workload_access_token(self): + """Async handler should be able to read workload_access_token from AgentArtsRuntimeContext.""" + app = AgentArtsRuntimeApp() + + async def async_handler(payload): + token = AgentArtsRuntimeContext.get_workload_access_token() + return {"token": token} + + AgentArtsRuntimeContext.set_workload_access_token("test-workload-token-456") + try: + context = RequestContext(session_id="test", request_id="req-1", request=None) + result = await app._invoke_handler(async_handler, context, False, {}) + assert result["token"] == "test-workload-token-456" + finally: + AgentArtsRuntimeContext.set_workload_access_token(None) + + @pytest.mark.asyncio + async def test_sync_handler_gets_session_id(self): + """Sync handler should be able to read session_id from AgentArtsRuntimeContext.""" + app = AgentArtsRuntimeApp() + + def sync_handler(payload): + session_id = AgentArtsRuntimeContext.get_session_id() + return {"session_id": session_id} + + AgentArtsRuntimeContext.set_session_id("session-abc") + try: + context = RequestContext(session_id="test", request_id="req-1", request=None) + result = await app._invoke_handler(sync_handler, context, False, {}) + assert result["session_id"] == "session-abc" + finally: + AgentArtsRuntimeContext.set_session_id(None) + + @pytest.mark.asyncio + async def test_sync_handler_gets_user_id(self): + """Sync handler should be able to read user_id from AgentArtsRuntimeContext.""" + app = AgentArtsRuntimeApp() + + def sync_handler(payload): + user_id = AgentArtsRuntimeContext.get_user_id() + return {"user_id": user_id} + + AgentArtsRuntimeContext.set_user_id("user-xyz") + try: + context = RequestContext(session_id="test", request_id="req-1", request=None) + result = await app._invoke_handler(sync_handler, context, False, {}) + assert result["user_id"] == "user-xyz" + finally: + AgentArtsRuntimeContext.set_user_id(None) + + @pytest.mark.asyncio + async def test_concurrent_sync_handlers_isolated_context(self): + """Concurrent sync handlers should have isolated contextvars.""" + import asyncio + + app = AgentArtsRuntimeApp() + results = {} + + def make_handler(key, token_value): + def handler(payload): + local_token = AgentArtsRuntimeContext.get_workload_access_token() + results[key] = local_token + return {"key": key, "token": local_token} + return handler + + async def run_with_context(key, token_value): + AgentArtsRuntimeContext.set_workload_access_token(token_value) + try: + context = RequestContext(session_id="test", request_id="req-1", request=None) + return await app._invoke_handler(make_handler(key, token_value), context, False, {}) + finally: + AgentArtsRuntimeContext.set_workload_access_token(None) + + await asyncio.gather( + run_with_context("handler_a", "token-A"), + run_with_context("handler_b", "token-B"), + ) + + assert results["handler_a"] == "token-A" + assert results["handler_b"] == "token-B" diff --git a/tests/unit/sdk/runtime/test_context.py b/tests/unit/sdk/runtime/test_context.py index eacd8c3..2fdb2ee 100644 --- a/tests/unit/sdk/runtime/test_context.py +++ b/tests/unit/sdk/runtime/test_context.py @@ -47,6 +47,7 @@ def test_oauth2_custom_state_context(): def test_request_id_context(): + AgentArtsRuntimeContext.clear() assert AgentArtsRuntimeContext.get_request_id() is None AgentArtsRuntimeContext.set_request_id("req-123") assert AgentArtsRuntimeContext.get_request_id() == "req-123" @@ -55,6 +56,7 @@ def test_request_id_context(): def test_session_id_context(): + AgentArtsRuntimeContext.clear() assert AgentArtsRuntimeContext.get_session_id() is None AgentArtsRuntimeContext.set_session_id("sess-123") assert AgentArtsRuntimeContext.get_session_id() == "sess-123" diff --git a/tests/unit/sdk/runtime/test_context_propagation.py b/tests/unit/sdk/runtime/test_context_propagation.py new file mode 100644 index 0000000..9d575bd --- /dev/null +++ b/tests/unit/sdk/runtime/test_context_propagation.py @@ -0,0 +1,530 @@ +""" +Tests for contextvars propagation in AgentArtsRuntimeApp with require_access_token decorator. + +This test file verifies that workload_access_token is correctly propagated when: +1. Request comes through _handle_invocation +2. Handler uses require_access_token decorator +3. Handler is sync or async function +""" + +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from starlette.requests import Request + +from agentarts.sdk.runtime.app import AgentArtsRuntimeApp +from agentarts.sdk.runtime.context import AgentArtsRuntimeContext, RequestContext + + +class TestContextPropagationWithRequireAccessToken: + """Tests for context propagation when handler uses require_access_token.""" + + @pytest.mark.asyncio + async def test_sync_handler_with_require_access_token_gets_token_from_context( + self, mock_identity_client_for_app + ): + """ + Test that sync handler using require_access_token can get workload_access_token + from context set by _handle_invocation. + + This is the key test case for the reported issue. + """ + from agentarts.sdk.identity.auth import require_access_token + + app = AgentArtsRuntimeApp() + + @app.entrypoint + @require_access_token(provider_name="test-provider", auth_flow="M2M") + def sync_handler(payload, access_token=None): + return {"received_token": access_token, "payload": payload} + + mock_identity_client_for_app.get_resource_oauth2_token.return_value = "oauth2-result-token" + + mock_request = MagicMock(spec=Request) + mock_request.json = AsyncMock(return_value={"input": "test"}) + mock_request.headers = { + "X-HW-AgentGateway-Workload-Access-Token": "workload-token-from-header", + } + + AgentArtsRuntimeContext.clear() + try: + response = await app._handle_invocation(mock_request) + + assert response.status_code == 200 + body = json.loads(response.body) + assert body["received_token"] == "oauth2-result-token" + assert body["payload"] == {"input": "test"} + finally: + AgentArtsRuntimeContext.clear() + + @pytest.mark.asyncio + async def test_async_handler_with_require_access_token_gets_token_from_context( + self, mock_identity_client_for_app + ): + """ + Test that async handler using require_access_token can get workload_access_token + from context set by _handle_invocation. + """ + from agentarts.sdk.identity.auth import require_access_token + + app = AgentArtsRuntimeApp() + + @app.entrypoint + @require_access_token(provider_name="test-provider", auth_flow="M2M") + async def async_handler(payload, access_token=None): + return {"received_token": access_token, "payload": payload} + + mock_identity_client_for_app.get_resource_oauth2_token.return_value = "oauth2-result-token" + + mock_request = MagicMock(spec=Request) + mock_request.json = AsyncMock(return_value={"input": "async-test"}) + mock_request.headers = { + "X-HW-AgentGateway-Workload-Access-Token": "workload-token-from-header", + } + + AgentArtsRuntimeContext.clear() + try: + response = await app._handle_invocation(mock_request) + + assert response.status_code == 200 + body = json.loads(response.body) + assert body["received_token"] == "oauth2-result-token" + finally: + AgentArtsRuntimeContext.clear() + + @pytest.mark.asyncio + async def test_sync_handler_checks_workload_token_in_context(self): + """ + Direct test: sync handler should see workload_access_token set by _build_request_context. + + This test verifies the basic context propagation without the require_access_token decorator. + """ + app = AgentArtsRuntimeApp() + + captured_token = None + + @app.entrypoint + def sync_handler(payload): + nonlocal captured_token + captured_token = AgentArtsRuntimeContext.get_workload_access_token() + return {"captured_token": captured_token} + + mock_request = MagicMock(spec=Request) + mock_request.json = AsyncMock(return_value={"input": "test"}) + mock_request.headers = { + "X-HW-AgentGateway-Workload-Access-Token": "direct-token-test", + } + + AgentArtsRuntimeContext.clear() + try: + response = await app._handle_invocation(mock_request) + + assert response.status_code == 200 + assert captured_token == "direct-token-test" + finally: + AgentArtsRuntimeContext.clear() + + @pytest.mark.asyncio + async def test_require_access_token_wrapper_reads_context_directly(self, mock_identity_client_for_app): + """ + Test that require_access_token's _get_workload_access_token function + can read token from AgentArtsRuntimeContext set by _build_request_context. + """ + from agentarts.sdk.identity.auth import _get_workload_access_token + + AgentArtsRuntimeContext.clear() + AgentArtsRuntimeContext.set_workload_access_token("context-token-direct") + + try: + token = _get_workload_access_token(mock_identity_client_for_app) + assert token == "context-token-direct" + finally: + AgentArtsRuntimeContext.clear() + + @pytest.mark.asyncio + async def test_invoke_handler_copies_context_to_thread(self): + """ + Test that _invoke_handler correctly copies contextvars to thread pool. + + This verifies the copy_context() + run_in_executor pattern. + """ + import contextvars + + app = AgentArtsRuntimeApp() + + test_var = contextvars.ContextVar("test_var_for_propagation", default=None) + test_var.set("value-set-in-async") + + captured_value = None + + def sync_handler(payload): + nonlocal captured_value + captured_value = test_var.get() + return {"captured": captured_value} + + context = RequestContext(session_id="test", request_id="req-1", request=None) + result = await app._invoke_handler(sync_handler, context, False, {}) + + assert captured_value == "value-set-in-async" + assert result["captured"] == "value-set-in-async" + + @pytest.mark.asyncio + async def test_run_async_in_sync_context_preserves_context(self): + """ + Test that run_async_in_sync_context preserves contextvars when called from + a thread that has copied context. + """ + import contextvars + + from agentarts.sdk.runtime.context import run_async_in_sync_context + + test_var = contextvars.ContextVar("test_var_run_async", default=None) + + async def async_coro(): + return test_var.get() + + def sync_func_with_context(): + test_var.set("value-in-sync-context") + ctx = contextvars.copy_context() + result = ctx.run(run_async_in_sync_context, async_coro()) + return result + + test_var.set("value-in-async-context") + ctx = contextvars.copy_context() + result = ctx.run(sync_func_with_context) + + assert result == "value-in-sync-context" + + @pytest.mark.asyncio + async def test_consecutive_requests_context_isolation(self): + """ + Test that consecutive requests have isolated context. + + This verifies that: + 1. First request sets workload_access_token + 2. After request completes, context should be cleared + 3. Second request without token header should NOT see first request's token + + This is the key test for the reported bug: + 'Workload Access Token is invalid or expired' + """ + app = AgentArtsRuntimeApp() + + captured_tokens = [] + + @app.entrypoint + def sync_handler(payload): + captured_tokens.append(AgentArtsRuntimeContext.get_workload_access_token()) + return {"status": "ok"} + + AgentArtsRuntimeContext.clear() + + mock_request1 = MagicMock(spec=Request) + mock_request1.json = AsyncMock(return_value={"input": "first"}) + mock_request1.headers = { + "X-HW-AgentGateway-Workload-Access-Token": "first-request-token", + } + + response1 = await app._handle_invocation(mock_request1) + assert response1.status_code == 200 + assert captured_tokens[-1] == "first-request-token" + + mock_request2 = MagicMock(spec=Request) + mock_request2.json = AsyncMock(return_value={"input": "second"}) + mock_request2.headers = {} + + response2 = await app._handle_invocation(mock_request2) + assert response2.status_code == 200 + + assert captured_tokens[-1] is None, ( + f"Second request should NOT have workload_access_token from first request. " + f"Got: {captured_tokens[-1]}" + ) + + AgentArtsRuntimeContext.clear() + + @pytest.mark.asyncio + async def test_consecutive_requests_different_tokens(self): + """ + Test that consecutive requests with different tokens work correctly. + """ + app = AgentArtsRuntimeApp() + + captured_tokens = [] + + @app.entrypoint + def sync_handler(payload): + captured_tokens.append(AgentArtsRuntimeContext.get_workload_access_token()) + return {"status": "ok"} + + AgentArtsRuntimeContext.clear() + + mock_request1 = MagicMock(spec=Request) + mock_request1.json = AsyncMock(return_value={"input": "first"}) + mock_request1.headers = { + "X-HW-AgentGateway-Workload-Access-Token": "token-A", + } + + response1 = await app._handle_invocation(mock_request1) + assert response1.status_code == 200 + + mock_request2 = MagicMock(spec=Request) + mock_request2.json = AsyncMock(return_value={"input": "second"}) + mock_request2.headers = { + "X-HW-AgentGateway-Workload-Access-Token": "token-B", + } + + response2 = await app._handle_invocation(mock_request2) + assert response2.status_code == 200 + + assert captured_tokens == ["token-A", "token-B"], ( + f"Each request should have its own token. Got: {captured_tokens}" + ) + + AgentArtsRuntimeContext.clear() + + @pytest.mark.asyncio + async def test_require_api_key_context_isolation( + self, mock_identity_client_for_app + ): + """ + Test that require_api_key benefits from context clearing. + + Verifies: token from first request does not leak to second request. + """ + + app = AgentArtsRuntimeApp() + + captured_workload_tokens = [] + + @app.entrypoint + def sync_handler(payload): + captured_workload_tokens.append( + AgentArtsRuntimeContext.get_workload_access_token() + ) + return {"status": "ok"} + + AgentArtsRuntimeContext.clear() + + mock_request1 = MagicMock(spec=Request) + mock_request1.json = AsyncMock(return_value={"input": "first"}) + mock_request1.headers = { + "X-HW-AgentGateway-Workload-Access-Token": "token-A", + } + + response1 = await app._handle_invocation(mock_request1) + assert response1.status_code == 200 + assert captured_workload_tokens[-1] == "token-A" + + mock_request2 = MagicMock(spec=Request) + mock_request2.json = AsyncMock(return_value={"input": "second"}) + mock_request2.headers = {} + + response2 = await app._handle_invocation(mock_request2) + assert response2.status_code == 200 + assert captured_workload_tokens[-1] is None, ( + f"Second request should NOT see token from first request. Got: {captured_workload_tokens[-1]}" + ) + + AgentArtsRuntimeContext.clear() + + @pytest.mark.asyncio + async def test_require_sts_token_context_isolation( + self, mock_identity_client_for_app + ): + """ + Test that require_sts_token benefits from context clearing. + + Verifies: token from first request does not leak to second request. + """ + + app = AgentArtsRuntimeApp() + + captured_workload_tokens = [] + + @app.entrypoint + def sync_handler(payload): + captured_workload_tokens.append( + AgentArtsRuntimeContext.get_workload_access_token() + ) + return {"status": "ok"} + + AgentArtsRuntimeContext.clear() + + mock_request1 = MagicMock(spec=Request) + mock_request1.json = AsyncMock(return_value={"input": "first"}) + mock_request1.headers = { + "X-HW-AgentGateway-Workload-Access-Token": "token-A", + } + + response1 = await app._handle_invocation(mock_request1) + assert response1.status_code == 200 + assert captured_workload_tokens[-1] == "token-A" + + mock_request2 = MagicMock(spec=Request) + mock_request2.json = AsyncMock(return_value={"input": "second"}) + mock_request2.headers = {} + + response2 = await app._handle_invocation(mock_request2) + assert response2.status_code == 200 + assert captured_workload_tokens[-1] is None, ( + f"Second request should NOT see token from first request. Got: {captured_workload_tokens[-1]}" + ) + + AgentArtsRuntimeContext.clear() + + @pytest.mark.asyncio + async def test_concurrent_requests_have_isolated_context(self): + """ + Test that concurrent requests have isolated context. + + This verifies that clearing context in one request does NOT affect + other concurrent requests running in parallel. + """ + import asyncio + + app = AgentArtsRuntimeApp() + + results = {} + + @app.entrypoint + def sync_handler(payload): + token = AgentArtsRuntimeContext.get_workload_access_token() + results[payload["id"]] = token + return {"id": payload["id"], "token": token} + + AgentArtsRuntimeContext.clear() + + def make_request(request_id: str, token_value: str | None): + mock_request = MagicMock(spec=Request) + mock_request.json = AsyncMock(return_value={"id": request_id}) + if token_value: + mock_request.headers = { + "X-HW-AgentGateway-Workload-Access-Token": token_value, + } + else: + mock_request.headers = {} + return mock_request + + await asyncio.gather( + app._handle_invocation(make_request("req-A", "token-A")), + app._handle_invocation(make_request("req-B", "token-B")), + app._handle_invocation(make_request("req-C", None)), + ) + + assert results["req-A"] == "token-A" + assert results["req-B"] == "token-B" + assert results["req-C"] is None + + AgentArtsRuntimeContext.clear() + + @pytest.mark.asyncio + async def test_decorator_reads_same_token_as_invocation_header( + self, mock_identity_client_for_app + ): + """ + Test that decorator's _get_workload_access_token reads the EXACT same token + that was set by _build_request_context from the invocation header. + + This is the core verification: + 1. Header: X-HW-AgentGateway-Workload-Access-Token = "workload-token-XYZ" + 2. _build_request_context sets context to "workload-token-XYZ" + 3. require_access_token decorator calls _get_workload_access_token + 4. _get_workload_access_token should return "workload-token-XYZ" + 5. get_resource_oauth2_token is called with workload_access_token="workload-token-XYZ" + """ + from agentarts.sdk.identity.auth import require_access_token + + app = AgentArtsRuntimeApp() + + invocation_token = "workload-token-from-invocation-header" + + @app.entrypoint + @require_access_token(provider_name="test-provider", auth_flow="M2M") + def sync_handler(payload, access_token=None): + return {"received_oauth_token": access_token} + + mock_identity_client_for_app.get_resource_oauth2_token.return_value = "oauth2-result-token" + + mock_request = MagicMock(spec=Request) + mock_request.json = AsyncMock(return_value={"input": "test"}) + mock_request.headers = { + "X-HW-AgentGateway-Workload-Access-Token": invocation_token, + } + + AgentArtsRuntimeContext.clear() + try: + response = await app._handle_invocation(mock_request) + + assert response.status_code == 200 + + mock_identity_client_for_app.get_resource_oauth2_token.assert_called_once() + + call_kwargs = mock_identity_client_for_app.get_resource_oauth2_token.call_args.kwargs + + assert call_kwargs["workload_access_token"] == invocation_token, ( + f"Decorator should use the SAME token from invocation header. " + f"Expected: {invocation_token}, Got: {call_kwargs.get('workload_access_token')}" + ) + + body = json.loads(response.body) + assert body["received_oauth_token"] == "oauth2-result-token" + finally: + AgentArtsRuntimeContext.clear() + + @pytest.mark.asyncio + async def test_async_decorator_reads_same_token_as_invocation_header( + self, mock_identity_client_for_app + ): + """ + Test that async handler's decorator also reads the same token. + """ + from agentarts.sdk.identity.auth import require_access_token + + app = AgentArtsRuntimeApp() + + invocation_token = "async-workload-token-from-header" + + @app.entrypoint + @require_access_token(provider_name="test-provider", auth_flow="M2M") + async def async_handler(payload, access_token=None): + return {"received_oauth_token": access_token} + + mock_identity_client_for_app.get_resource_oauth2_token.return_value = "oauth2-result-token" + + mock_request = MagicMock(spec=Request) + mock_request.json = AsyncMock(return_value={"input": "async-test"}) + mock_request.headers = { + "X-HW-AgentGateway-Workload-Access-Token": invocation_token, + } + + AgentArtsRuntimeContext.clear() + try: + response = await app._handle_invocation(mock_request) + + assert response.status_code == 200 + + call_kwargs = mock_identity_client_for_app.get_resource_oauth2_token.call_args.kwargs + + assert call_kwargs["workload_access_token"] == invocation_token, ( + f"Async decorator should use the SAME token from invocation header. " + f"Expected: {invocation_token}, Got: {call_kwargs.get('workload_access_token')}" + ) + finally: + AgentArtsRuntimeContext.clear() + + +@pytest.fixture +def mock_identity_client_for_app(): + """Fixture to mock IdentityClient for app tests.""" + from agentarts.sdk.identity import auth + + with patch.object(auth, "IdentityClient") as MockClass: + mock_instance = MockClass.return_value + mock_instance.get_resource_oauth2_token = AsyncMock() + mock_instance.get_resource_api_key = MagicMock(return_value="mock-api-key") + mock_instance.get_resource_sts_token = MagicMock(return_value={}) + mock_instance.create_workload_identity = MagicMock() + mock_instance.create_workload_access_token = MagicMock(return_value="mock-workload-token") + yield mock_instance diff --git a/tests/unit/sdk/runtime/test_context_without_clear.py b/tests/unit/sdk/runtime/test_context_without_clear.py new file mode 100644 index 0000000..e7d213a --- /dev/null +++ b/tests/unit/sdk/runtime/test_context_without_clear.py @@ -0,0 +1,159 @@ +""" +Demonstration test showing what happens WITHOUT context clearing. + +This test simulates the original bug scenario where: +1. Request 1 sets workload_access_token +2. Request ends WITHOUT clearing context +3. Request 2 sees Request 1's token (LEAKED!) +""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest +from starlette.requests import Request + +from agentarts.sdk.runtime.app import AgentArtsRuntimeApp +from agentarts.sdk.runtime.context import AgentArtsRuntimeContext + + +class TestContextWithoutClearing: + """ + These tests demonstrate the bug when context is NOT cleared. + + They bypass the normal _handle_invocation flow and directly test + the scenario where context leaks between requests. + """ + + @pytest.mark.asyncio + async def test_token_leaks_without_clear(self): + """ + Demonstrate: Without clear(), token from Request 1 leaks to Request 2. + + This is the root cause of the bug you reported: + 'Workload Access Token is invalid or expired' + """ + app = AgentArtsRuntimeApp() + + captured_tokens = [] + + @app.entrypoint + def sync_handler(payload): + captured_tokens.append(AgentArtsRuntimeContext.get_workload_access_token()) + return {"status": "ok"} + + AgentArtsRuntimeContext.clear() + + mock_request1 = MagicMock(spec=Request) + mock_request1.json = AsyncMock(return_value={"input": "first"}) + mock_request1.headers = { + "X-HW-AgentGateway-Workload-Access-Token": "expired-token-from-request-1", + } + + request_context1 = app._build_request_context(mock_request1) + result1 = await app._invoke_handler(sync_handler, request_context1, False, {"input": "first"}) + + assert captured_tokens[-1] == "expired-token-from-request-1" + + AgentArtsRuntimeContext.set_workload_access_token("expired-token-from-request-1") + + mock_request2 = MagicMock(spec=Request) + mock_request2.json = AsyncMock(return_value={"input": "second"}) + mock_request2.headers = {} + + request_context2 = app._build_request_context(mock_request2) + result2 = await app._invoke_handler(sync_handler, request_context2, False, {"input": "second"}) + + assert captured_tokens[-1] == "expired-token-from-request-1", ( + "BUG: Request 2 sees token from Request 1! " + f"Expected None or empty header, but got: {captured_tokens[-1]}" + ) + + AgentArtsRuntimeContext.clear() + + @pytest.mark.asyncio + async def test_real_world_scenario_without_clear(self): + """ + Real-world scenario simulation: + + 1. User A makes request with valid token + 2. Token expires (simulated) + 3. User B makes request (no token header) + 4. User B's request FAILS because it uses User A's expired token + """ + app = AgentArtsRuntimeApp() + + @app.entrypoint + def sync_handler(payload): + token = AgentArtsRuntimeContext.get_workload_access_token() + return {"user": payload.get("user"), "token_used": token} + + AgentArtsRuntimeContext.clear() + + mock_request_user_a = MagicMock(spec=Request) + mock_request_user_a.json = AsyncMock(return_value={"user": "User-A"}) + mock_request_user_a.headers = { + "X-HW-AgentGateway-Workload-Access-Token": "user-a-token-WILL-EXPIRE", + } + + request_context_a = app._build_request_context(mock_request_user_a) + result_a = await app._invoke_handler(sync_handler, request_context_a, False, {"user": "User-A"}) + + assert result_a["token_used"] == "user-a-token-WILL-EXPIRE" + + AgentArtsRuntimeContext.set_workload_access_token("user-a-token-WILL-EXPIRE") + + mock_request_user_b = MagicMock(spec=Request) + mock_request_user_b.json = AsyncMock(return_value={"user": "User-B"}) + mock_request_user_b.headers = {} + + request_context_b = app._build_request_context(mock_request_user_b) + result_b = await app._invoke_handler(sync_handler, request_context_b, False, {"user": "User-B"}) + + assert result_b["token_used"] == "user-a-token-WILL-EXPIRE", ( + f"BUG: User-B is using User-A's token! " + f"User-B should have no token, but got: {result_b['token_used']}" + ) + + AgentArtsRuntimeContext.clear() + + @pytest.mark.asyncio + async def test_with_clear_fixes_the_problem(self): + """ + Demonstrate: With clear(), the problem is fixed. + """ + app = AgentArtsRuntimeApp() + + captured_tokens = [] + + @app.entrypoint + def sync_handler(payload): + captured_tokens.append(AgentArtsRuntimeContext.get_workload_access_token()) + return {"status": "ok"} + + AgentArtsRuntimeContext.clear() + + mock_request1 = MagicMock(spec=Request) + mock_request1.json = AsyncMock(return_value={"input": "first"}) + mock_request1.headers = { + "X-HW-AgentGateway-Workload-Access-Token": "token-A", + } + + request_context1 = app._build_request_context(mock_request1) + result1 = await app._invoke_handler(sync_handler, request_context1, False, {"input": "first"}) + + assert captured_tokens[-1] == "token-A" + + AgentArtsRuntimeContext.clear() + + mock_request2 = MagicMock(spec=Request) + mock_request2.json = AsyncMock(return_value={"input": "second"}) + mock_request2.headers = {} + + request_context2 = app._build_request_context(mock_request2) + result2 = await app._invoke_handler(sync_handler, request_context2, False, {"input": "second"}) + + assert captured_tokens[-1] is None, ( + f"FIXED: Request 2 should NOT see token from Request 1. Got: {captured_tokens[-1]}" + ) + + AgentArtsRuntimeContext.clear() diff --git a/tests/unit/sdk/utils/test_logging.py b/tests/unit/sdk/utils/test_logging.py new file mode 100644 index 0000000..518d68d --- /dev/null +++ b/tests/unit/sdk/utils/test_logging.py @@ -0,0 +1,167 @@ +"""Tests for SDK logging configuration utilities.""" + +import logging +from pathlib import Path + +from agentarts.sdk.utils.logging import ( + DEFAULT_LOG_LEVEL, + ENV_LOG_LEVEL, + get_log_level, + get_logger, + setup_logging, +) + + +class TestGetLogLevel: + """Tests for get_log_level function.""" + + def test_returns_default_when_env_not_set(self, monkeypatch): + """Returns default level when environment variable is not set.""" + monkeypatch.delenv(ENV_LOG_LEVEL, raising=False) + assert get_log_level() == DEFAULT_LOG_LEVEL + + def test_returns_env_value_when_set(self, monkeypatch): + """Returns environment variable value when set.""" + monkeypatch.setenv(ENV_LOG_LEVEL, "DEBUG") + assert get_log_level() == "DEBUG" + + def test_returns_env_value_uppercase(self, monkeypatch): + """Returns uppercase version of environment variable.""" + monkeypatch.setenv(ENV_LOG_LEVEL, "debug") + assert get_log_level() == "DEBUG" + + def test_returns_default_for_invalid_level(self, monkeypatch): + """Returns default for invalid level value.""" + monkeypatch.setenv(ENV_LOG_LEVEL, "INVALID") + assert get_log_level() == DEFAULT_LOG_LEVEL + + def test_supports_all_valid_levels(self, monkeypatch): + """Supports all valid log levels.""" + valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] + for level in valid_levels: + monkeypatch.setenv(ENV_LOG_LEVEL, level) + assert get_log_level() == level + + +class TestSetupLogging: + """Tests for setup_logging function.""" + + def test_sets_sdk_logger_level(self): + """Sets the SDK logger level correctly.""" + sdk_logger = logging.getLogger("agentarts") + sdk_logger.handlers.clear() + setup_logging(level="DEBUG") + assert sdk_logger.level == logging.DEBUG + + def test_adds_stream_handler_by_default(self): + """Adds StreamHandler by default when no handler provided.""" + sdk_logger = logging.getLogger("agentarts") + sdk_logger.handlers.clear() + setup_logging(level="INFO") + assert len(sdk_logger.handlers) == 1 + assert isinstance(sdk_logger.handlers[0], logging.StreamHandler) + + def test_does_not_add_duplicate_handlers(self): + """Does not add duplicate handlers when called multiple times.""" + sdk_logger = logging.getLogger("agentarts") + sdk_logger.handlers.clear() + setup_logging(level="INFO") + setup_logging(level="DEBUG") + assert len(sdk_logger.handlers) == 1 + + def test_uses_custom_handler(self): + """Uses custom handler when provided.""" + sdk_logger = logging.getLogger("agentarts") + sdk_logger.handlers.clear() + custom_handler = logging.FileHandler("test.log", mode="w") + setup_logging(level="INFO", handler=custom_handler) + assert len(sdk_logger.handlers) == 1 + assert isinstance(sdk_logger.handlers[0], logging.FileHandler) + custom_handler.close() + Path("test.log").unlink() + + def test_uses_env_level_when_none_provided(self, monkeypatch): + """Uses environment variable level when None provided.""" + monkeypatch.setenv(ENV_LOG_LEVEL, "WARNING") + sdk_logger = logging.getLogger("agentarts") + sdk_logger.handlers.clear() + setup_logging() + assert sdk_logger.level == logging.WARNING + + def test_suppresses_urllib3_warnings(self): + """Suppresses urllib3 logger to WARNING level.""" + setup_logging(level="DEBUG") + urllib3_logger = logging.getLogger("urllib3") + assert urllib3_logger.level == logging.WARNING + + def test_suppresses_huaweicloudsdkcore_warnings(self): + """Suppresses huaweicloudsdkcore logger to WARNING level.""" + setup_logging(level="DEBUG") + sdk_logger = logging.getLogger("huaweicloudsdkcore") + assert sdk_logger.level == logging.WARNING + + +class TestGetLogger: + """Tests for get_logger function.""" + + def test_returns_logger_with_sdk_namespace(self): + """Returns logger with agentarts namespace.""" + logger = get_logger("test") + assert logger.name == "agentarts.test" + + def test_returns_logger_with_nested_namespace(self): + """Returns logger with nested namespace.""" + logger = get_logger("runtime.app") + assert logger.name == "agentarts.runtime.app" + + def test_returns_same_logger_for_same_name(self): + """Returns same logger instance for same name.""" + logger1 = get_logger("test") + logger2 = get_logger("test") + assert logger1 is logger2 + + +class TestLoggingIntegration: + """Integration tests for logging functionality.""" + + def test_logger_inherits_sdk_level(self, monkeypatch): + """Child logger inherits SDK logger level.""" + monkeypatch.delenv(ENV_LOG_LEVEL, raising=False) + sdk_logger = logging.getLogger("agentarts") + sdk_logger.handlers.clear() + setup_logging(level="DEBUG") + + child_logger = get_logger("runtime.app") + assert child_logger.getEffectiveLevel() == logging.DEBUG + + def test_debug_messages_visible_at_debug_level(self, caplog): + """Debug messages are visible at DEBUG level.""" + sdk_logger = logging.getLogger("agentarts") + sdk_logger.handlers.clear() + setup_logging(level="DEBUG") + + logger = get_logger("test") + with caplog.at_level(logging.DEBUG, logger="agentarts.test"): + logger.debug("Debug message") + assert "Debug message" in caplog.text + + def test_debug_messages_hidden_at_info_level(self, caplog): + """Debug messages are hidden at INFO level.""" + sdk_logger = logging.getLogger("agentarts") + sdk_logger.handlers.clear() + setup_logging(level="INFO") + + logger = get_logger("test") + logger.debug("Debug message") + assert "Debug message" not in caplog.text + + def test_info_messages_visible_at_info_level(self, caplog): + """Info messages are visible at INFO level.""" + sdk_logger = logging.getLogger("agentarts") + sdk_logger.handlers.clear() + setup_logging(level="INFO") + + logger = get_logger("test") + with caplog.at_level(logging.INFO, logger="agentarts.test"): + logger.info("Info message") + assert "Info message" in caplog.text diff --git a/tests/unit/toolkit/operations/runtime/test_config.py b/tests/unit/toolkit/operations/runtime/test_config.py index 2f54cf9..def0f2d 100644 --- a/tests/unit/toolkit/operations/runtime/test_config.py +++ b/tests/unit/toolkit/operations/runtime/test_config.py @@ -176,12 +176,12 @@ def test_uses_default_swr_repo_as_agent_prefix(self, tmp_path, monkeypatch): monkeypatch.chdir(tmp_path) add_agent( - name="my-agent", + name="myagent", entrypoint="agent:app", ) config = load_config() - agent = config.get_agent("my-agent") + agent = config.get_agent("myagent") assert agent.swr_config.repository is None def test_sets_as_default_when_first_agent(self, tmp_path, monkeypatch): diff --git a/tests/unit/toolkit/operations/runtime/test_init.py b/tests/unit/toolkit/operations/runtime/test_init.py index c45a2ba..79bab37 100644 --- a/tests/unit/toolkit/operations/runtime/test_init.py +++ b/tests/unit/toolkit/operations/runtime/test_init.py @@ -109,7 +109,7 @@ def test_dockerfile_contains_required_sections(self, tmp_path): assert "FROM python:3.10-slim" in content assert "WORKDIR /app" in content assert "EXPOSE 8080" in content - assert "uvicorn" in content + assert "agent" in content class TestCreateConfigFile: @@ -139,12 +139,12 @@ def test_config_uses_default_swr_repo(self, tmp_path): """Config uses agent_{name} as default SWR repo.""" create_config_file( project_path=tmp_path, - name="my-agent", + name="myagent", template="basic", ) content = (tmp_path / ".agentarts_config.yaml").read_text() - assert "agent_my-agent" in content + assert "agent_myagent" in content def test_config_includes_region(self, tmp_path): """Config includes specified region."""