diff --git a/docs/cn/sdk_user_guide/tools_user_guide.md b/docs/cn/sdk_user_guide/tools_user_guide.md index a980a7d..aecb4cd 100644 --- a/docs/cn/sdk_user_guide/tools_user_guide.md +++ b/docs/cn/sdk_user_guide/tools_user_guide.md @@ -10,7 +10,7 @@ ## 环境要求 - Python 3.10+ -- 已安装 AgentArts SDK:`pip install agentarts` +- 已安装 AgentArts SDK:`pip install agentarts-sdk` ## 认证配置 @@ -32,9 +32,9 @@ export HUAWEICLOUD_SDK_SK="your-secret-key" 2. 进入"我的凭证"页面 3. 在"访问密钥"标签页创建或查看 AK/SK -### API Key 配置 +### 会话认证方式 -CodeInterpreter 会话认证需要 API Key,可通过以下方式配置: +CodeInterpreter 会话使用 API Key 进行认证,可通过以下方式配置: **方式一:环境变量配置(推荐)** @@ -48,6 +48,7 @@ export HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY="your-api-key" # 在 start_session 中传递 client.start_session( code_interpreter_name="my-code-interpreter", + session_name="my-session", api_key="your-api-key" ) @@ -56,6 +57,8 @@ with code_session("cn-southwest-2", "my-code-interpreter", api_key="your-api-key client.execute_code("print('Hello')") ``` +**注意**:如果代码解释器创建时使用了 `auth_type="IAM"`,则会话方法将自动使用 IAM 认证,无需传递 `api_key` 参数。 + ### 数据面端点配置 CodeInterpreter 数据面端点可以通过以下方式配置(按优先级排序): @@ -84,7 +87,7 @@ CodeInterpreter 数据面端点可以通过以下方式配置(按优先级排 `code_session` 是一个上下文管理器,自动管理会话的启动和停止,推荐用于大多数场景: ```python -from agentarts.sdk.tools.code_interpreter.code_interpreter_client import code_session +from agentarts.sdk.tools import code_session # 使用环境变量中的 API Key with code_session("cn-southwest-2", "my-code-interpreter-name") as client: @@ -111,7 +114,7 @@ with code_session("cn-southwest-2", "my-code-interpreter-name", api_key="your-ap 如果需要更细粒度的控制,可以手动管理会话: ```python -from agentarts.sdk.tools.code_interpreter.code_interpreter_client import CodeInterpreter +from agentarts.sdk.tools import CodeInterpreter # 创建客户端 client = CodeInterpreter(region="cn-southwest-2") @@ -127,7 +130,7 @@ session_id = client.start_session( result = client.execute_code("print('Hello')") # 停止会话 -client.stop_session(api_key="your-api-key") # 可选 +client.stop_session() # 可选 ``` --- @@ -138,7 +141,7 @@ CodeInterpreter ### 初始化 ```python -CodeInterpreter(region: Optional[str] = None, data_endpoint: Optional[str] = None) +CodeInterpreter(region: Optional[str] = None, data_endpoint: Optional[str] = None, auth_type: str = "API_KEY") ``` **参数说明**: @@ -147,11 +150,12 @@ CodeInterpreter(region: Optional[str] = None, data_endpoint: Optional[str] = Non |------|------|------|--------|------| | region | str | 否 | 从环境变量获取 | 华为云区域名称 | | data_endpoint | str | 否 | 从环境变量获取 | 数据面端点,优先从环境变量 AGENTARTS_CODEINTERPRETER_DATA_ENDPOINT 读取 | +| auth_type | str | 否 | "API_KEY" | 认证类型,支持 "API_KEY" 或 "IAM" | **使用示例**: ```python -# 使用环境变量配置 +# 使用环境变量配置(API Key 认证) client = CodeInterpreter(region="cn-southwest-2") # 通过参数指定数据面端点 @@ -160,6 +164,11 @@ client = CodeInterpreter( data_endpoint="https://your-custom-endpoint.com" ) +# 使用 IAM 认证 +client = CodeInterpreter( + region="cn-southwest-2", + auth_type="IAM" +) ``` ### 功能特性 @@ -179,9 +188,9 @@ client = CodeInterpreter( | 参数名 | 类型 | 描述 | | --- | --- | --- | |name |str| **Required** 代码解释器的名称,必须符合特定的命名规则| -|api_key_name |str| **Required** API Key 的名称| +|auth_type |str| 认证类型,支持 "API_KEY" 或 "IAM",默认 "API_KEY"| +|api_key_name |str| API Key 的名称,当 auth_type 为 "API_KEY" 时必填| |description |str| 代码解释器的描述信息, default: `None`| -|auth_type |str| 认证类型,例如 "API_KEY", default: `None`| |execution_agency_name |str| 执行机构的名称, default: `None`| |observability |Dict| 可观测性配置,例如日志和监控设置, default: `None`| |network_config |Dict| 网络配置,例如 VPC 和安全组设置, default: `None`| @@ -204,11 +213,19 @@ client = CodeInterpreter( **样例** ```python +# 使用 API Key 认证创建代码解释器 code_interpreter = client.create_code_interpreter( name="my-code-interpreter", + auth_type="API_KEY", api_key_name="my-api-key-name" ) code_interpreter_id = code_interpreter["id"] + +# 使用 IAM 认证创建代码解释器 +code_interpreter = client.create_code_interpreter( + name="my-code-interpreter", + auth_type="IAM" +) ``` #### 2. 查询代码解释器列表 @@ -325,7 +342,7 @@ client.delete_code_interpreter( | --- | --- | --- | |code_interpreter_name |str| **Required** 代码解释器的名称,用于识别和管理会话,名称唯一| |session_name |str| **Required** 会话名称| -|api_key |str| 认证使用的API Key,如果不提供则从环境变量HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY中获取, default: `None`| +|api_key |str| 认证使用的API Key,如果代码解释器创建时使用 API_KEY 认证,则需要提供此参数;如果不提供则从环境变量HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY中获取;如果代码解释器创建时使用 IAM 认证,则无需提供此参数, default: `None`| |session_timeout |int| 会话超时时间,单位为秒,默认15分钟,最小值为60秒,最大值为86400秒(24小时), default: `900`| **返回值** @@ -333,18 +350,24 @@ session_id (str): 会话ID **样例** ```python -# 使用环境变量中的 API Key +# 使用环境变量中的 API Key(代码解释器创建时使用 API_KEY 认证) session_id = client.start_session( code_interpreter_name="my-code-interpreter-name", session_name="my-session-name" ) -# 传入 API Key +# 传入 API Key(代码解释器创建时使用 API_KEY 认证) session_id = client.start_session( code_interpreter_name="my-code-interpreter-name", session_name="my-session-name", api_key="your-api-key" ) + +# 代码解释器创建时使用 IAM 认证,无需传入 api_key +session_id = client.start_session( + code_interpreter_name="my-code-interpreter-name", + session_name="my-session-name" +) ``` #### 7. 获取代码解释器会话详情 @@ -355,7 +378,7 @@ session_id = client.start_session( | --- | --- | --- | |code_interpreter_name |str| **Required** 代码解释器的名称,用于识别和管理会话,名称唯一| |session_id |str| 会话ID,默认使用当前会话ID, default: `None`| -|api_key |str| 认证使用的API Key,如果不提供则从环境变量HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY中获取, default: `None`| +|api_key |str| 认证使用的API Key,如果代码解释器创建时使用 API_KEY 认证,则需要提供此参数;如果不提供则从环境变量HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY中获取;如果代码解释器创建时使用 IAM 认证,则无需提供此参数, default: `None`| **返回值** 包含会话详情的字典: @@ -367,6 +390,12 @@ session_id = client.start_session( **样例** ```python +# 使用环境变量中的 API Key(代码解释器创建时使用 API_KEY 认证) +session_info = client.get_session( + code_interpreter_name="my-code-interpreter-name" +) + +# 代码解释器创建时使用 IAM 认证,无需传入 api_key session_info = client.get_session( code_interpreter_name="my-code-interpreter-name" ) @@ -378,17 +407,21 @@ session_info = client.get_session( **参数** | 参数名 | 类型 | 描述 | | --- | --- | --- | -|api_key |str| 认证使用的API Key,如果不提供则从环境变量HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY中获取, default: `None`| +|api_key |str| 认证使用的API Key,如果代码解释器创建时使用 API_KEY 认证,则需要提供此参数;如果不提供则从环境变量HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY中获取;如果代码解释器创建时使用 IAM 认证,则无需提供此参数, default: `None`| **返回值** bool: 没有活跃会话时返回True,否则返回False **样例** ```python +# 使用环境变量中的 API Key(代码解释器创建时使用 API_KEY 认证) client.stop_session() -# 传入 API Key +# 传入 API Key(代码解释器创建时使用 API_KEY 认证) client.stop_session(api_key="your-api-key") + +# 代码解释器创建时使用 IAM 认证,无需传入 api_key +client.stop_session() ``` #### 9. 调用代码解释器会话 @@ -399,13 +432,24 @@ client.stop_session(api_key="your-api-key") | --- | --- | --- | |operate_type |str| **Required** 调用方法名,"execute_code"或"execute_command"等| |arguments |Dict| **Required** 调用参数,根据operate_type不同而不同| -|api_key |str| 认证使用的API Key,如果不提供则从环境变量HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY中获取, default: `None`| +|api_key |str| 认证使用的API Key,如果代码解释器创建时使用 API_KEY 认证,则需要提供此参数;如果不提供则从环境变量HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY中获取;如果代码解释器创建时使用 IAM 认证,则无需提供此参数, default: `None`| **返回值** result[Dict]: 包含调用结果的字典 **样例** ```python +# 使用环境变量中的 API Key(代码解释器创建时使用 API_KEY 认证) +result = client.invoke( + operate_type="execute_code", + arguments={ + "clear_context": False, + "code": "print('Hello, World!')", + "language": "python" + } +) + +# 代码解释器创建时使用 IAM 认证,无需传入 api_key result = client.invoke( operate_type="execute_code", arguments={ @@ -581,7 +625,7 @@ result = client.clear_context() client.execute_code("print(x)") ``` -### 上下文管理器 +###上下文管理器 **方法名** `code_session` **参数** @@ -589,15 +633,20 @@ client.execute_code("print(x)") | --- | --- | --- | |region |str| **Required** region名称,如"cn-southwest-2"| |code_interpreter_name |str| **Required** 代码解释器名称| -|api_key |str| 认证使用的API Key,如果不提供则从环境变量HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY中获取, default: `None`| +|auth_type |str| 认证类型,支持 "API_KEY" 或 "IAM",默认 "API_KEY"| +|api_key |str| 认证使用的API Key,如果代码解释器创建时使用 API_KEY 认证,则需要提供此参数;如果不提供则从环境变量HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY中获取;如果代码解释器创建时使用 IAM 认证,则无需提供此参数, default: `None`| **样例** ```python -# 使用环境变量中的 API Key +# 使用环境变量中的 API Key(代码解释器创建时使用 API_KEY 认证) with code_session("cn-southwest-2", "my-code-interpreter-name") as client: client.execute_code("print('Hello, World!')") -# 传入 API Key +# 传入 API Key(代码解释器创建时使用 API_KEY 认证) with code_session("cn-southwest-2", "my-code-interpreter-name", api_key="your-api-key") as client: client.execute_code("print('Hello, World!')") + +# 使用 IAM 认证 +with code_session("cn-southwest-2", "my-code-interpreter-name", auth_type="IAM") as client: + client.execute_code("print('Hello, World!')") ``` diff --git a/examples/agent_tools/README.md b/examples/agent_tools/README.md index d832e4f..378d886 100644 --- a/examples/agent_tools/README.md +++ b/examples/agent_tools/README.md @@ -24,44 +24,56 @@ pip install -r requirements.txt ``` ```python -import os import json -from typing import TypedDict, Union +import os +from typing import TypedDict -from langgraph.graph import StateGraph, END -from langchain_openai import ChatOpenAI -from langchain_core.messages import HumanMessage, SystemMessage, AIMessage +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from langchain_core.tools import tool - +from langchain_openai import ChatOpenAI +from langgraph.graph import END, StateGraph from langgraph.prebuilt import ToolNode -from agentarts.sdk.tools import code_session + from agentarts.sdk import AgentArtsRuntimeApp +from agentarts.sdk.tools import code_session ``` ### 2. 定义系统提示词 定义Agent的行为和能力 ```python app = AgentArtsRuntimeApp() -SYSTEM_PROMPT = """你是一个通过代码执行验证所有答案的优秀AI助手 +SYSTEM_PROMPT = """你是一个优秀的AI助手,擅长通过代码执行验证答案的正确性。 验证原则: -1. 当需要代码,算法或者计算来验证时,你需要编写代码来验证它们。 -2. 使用execute_python_tool工具来测试数学计算,算法和逻辑 -3. 返回答案前,使用测试脚本来验证你的理解 +1. 当需要精确计算、数值验证或算法验证时,必须编写代码来验证结果 +2. 使用execute_python_tool工具执行代码进行验证 +3. 返回答案前,使用测试脚本来验证你的理解和计算 4. 只能通过实际的代码执行展示工作过程 5. 如果存在不确定的情况,详细说明限制条件并尽可能做验证 -方法: -- 如果问题涉及编程,通过代码实现 -- 如果要求你计算,编写程序计算并显示具体代码 -- 如果需要实现算法,你还要编写测试用例来进行确认 -- 记录验证的过程展示给用户 - -工具: -- execute_python_tool: 执行Python代码并返回结果 - -响应格式:execute_python_tool, 包括: -- content: 内容对象的数组,每个对象包含type和text/data""" +需要代码验证的场景 +- 数学计算:包括算术运算、代数计算、概率统计、数列求和、几何计算等 +- 算法验证:需要验证算法正确性,实现逻辑或性能测试时 +- 数据处理:对数据进行统计分析、排序、查找等操作时 +- 任何需要精确结果的问题,当口算或者估算无法保证准确性时 + +强制要求: +- 你必须使用execute_python_tool工具来执行python代码 +- 涉及计算的问题,编写程序计算并显示代码和结果 +- 每次给出最终答案前,至少执行一次验证代码 +- 如果工具调用失败,明确告知用户 +- 将代码执行结果作为答案的重要依据 + +可用工具: +- execute_python_tool(code: str, description: str): 在沙箱环境中执行Python代码并返回结果 + * code: 要执行的Python代码 + * description: 对代码的描述,用于上下文理解 + +响应格式要求: +- 优先展示验证代码和执行结果 +- 清晰说明每一步的计算逻辑 +- 最终答案必须基于代码执行结果 +""" ``` ### 3. 定义代码执行工具 @@ -73,19 +85,24 @@ def execute_python_tool(code: str, description: str) -> str | None: if description: code = f"# {description}\n{code}" - + print(f"\n Generated Code: {code}") - with code_session("your_region", "your_code_interpreter_name") as code_client: + # 需配置环境 HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY + api_key = os.environ.get( + "HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY", "" + ) # 配置环境变量后,api_key无需在代码中传递亦可正常工作 + with code_session("your_region", "your_code_interpreter_name", api_key) as code_client: response = code_client.invoke( operate_type="execute_code", + api_key=api_key, arguments={ "code": code, "language": "python", "clear_context": False, - } + }, ) - + return json.dumps(response["result"]) ``` @@ -104,33 +121,38 @@ llm = ChatOpenAI( # 创建工具列表 tools = [execute_python_tool] # 工具绑定Agent -llm.bind_tools(tools) +llm = llm.bind_tools(tools) # 定义graph状态 class AgentState(TypedDict): - messages: list[Union[HumanMessage, SystemMessage, AIMessage]] + messages: list[HumanMessage | SystemMessage | AIMessage] + def call_model(state: AgentState): """调用模型并返回响应""" - if not state["messages"] or all(not isinstance(msg, SystemMessage) for msg in state["messages"]): + if not state["messages"] or all( + not isinstance(msg, SystemMessage) for msg in state["messages"] + ): messages = [SystemMessage(content=SYSTEM_PROMPT)] + state["messages"] else: messages = state["messages"] - + response = llm.invoke(messages) return {"messages": [response]} + def should_continue(state): """判断是否继续使用工具""" last_message = state["messages"][-1] # 如果包含工具调用,则继续执行 - if last_message.tool_calls: + if isinstance(last_message, AIMessage) and last_message.tool_calls: return "tools" # 否则结束 return END + # 创建LangGraph工作流 workflow = StateGraph(AgentState) @@ -141,36 +163,28 @@ workflow.add_node("tools", ToolNode(tools)) workflow.set_entry_point("agent") # 添加边 -workflow.add_conditional_edges( - "agent", - should_continue, - { - "tools": "tools", - "__end__": "__end__" - } -) +workflow.add_conditional_edges("agent", should_continue, {"tools": "tools", "__end__": "__end__"}) workflow.add_edge("tools", "agent") agent = workflow.compile() ``` ## 5. 定义问题 ```python -query = "告诉我1到100之间最大的随机质数" +query = "告诉我1到100之间最大的质数" ``` ## 6. Agent执行与响应 ```python @app.entrypoint def agent_chat(): - query = "告诉我1到100之间最大的随机质数" + query = "告诉我1到100之间最大的质数" # 运行Agent - result = agent.invoke({ - "messages": [HumanMessage(content=query)] - }) + result = agent.invoke({"messages": [HumanMessage(content=query)]}) print(result["messages"][-1].content) + if __name__ == "__main__": app.run() ``` diff --git a/examples/agent_tools/integrate_tools.py b/examples/agent_tools/integrate_tools.py index 066ac70..ba2f666 100644 --- a/examples/agent_tools/integrate_tools.py +++ b/examples/agent_tools/integrate_tools.py @@ -12,26 +12,39 @@ from agentarts.sdk.tools import code_session app = AgentArtsRuntimeApp() -SYSTEM_PROMPT = """你是一个通过代码执行验证所有答案的优秀AI助手 +SYSTEM_PROMPT = """你是一个优秀的AI助手,擅长通过代码执行验证答案的正确性。 验证原则: -1. 当需要代码,算法或者计算来验证时,你需要编写代码来验证它们。 -2. 使用execute_python_tool工具来测试数学计算,算法和逻辑 -3. 返回答案前,使用测试脚本来验证你的理解 +1. 当需要精确计算、数值验证或算法验证时,必须编写代码来验证结果 +2. 使用execute_python_tool工具执行代码进行验证 +3. 返回答案前,使用测试脚本来验证你的理解和计算 4. 只能通过实际的代码执行展示工作过程 5. 如果存在不确定的情况,详细说明限制条件并尽可能做验证 -方法: -- 如果问题涉及编程,通过代码实现 -- 如果要求你计算,编写程序计算并显示具体代码 -- 如果需要实现算法,你还要编写测试用例来进行确认 -- 记录验证的过程展示给用户 +需要代码验证的场景 +- 数学计算:包括算术运算、代数计算、概率统计、数列求和、几何计算等 +- 算法验证:需要验证算法正确性,实现逻辑或性能测试时 +- 数据处理:对数据进行统计分析、排序、查找等操作时 +- 任何需要精确结果的问题,当口算或者估算无法保证准确性时 -工具: -- execute_python_tool: 执行Python代码并返回结果 +强制要求: +- 你必须使用execute_python_tool工具来执行python代码 +- 涉及计算的问题,编写程序计算并显示代码和结果 +- 每次给出最终答案前,至少执行一次验证代码 +- 如果工具调用失败,明确告知用户 +- 将代码执行结果作为答案的重要依据 + +可用工具: +- execute_python_tool(code: str, description: str): 在沙箱环境中执行Python代码并返回结果 + * code: 要执行的Python代码 + * description: 对代码的描述,用于上下文理解 + +响应格式要求: +- 优先展示验证代码和执行结果 +- 清晰说明每一步的计算逻辑 +- 最终答案必须基于代码执行结果 +""" -响应格式:execute_python_tool, 包括: -- content: 内容对象的数组,每个对象包含type和text/data""" @tool def execute_python_tool(code: str, description: str) -> str | None: @@ -42,14 +55,19 @@ def execute_python_tool(code: str, description: str) -> str | None: print(f"\n Generated Code: {code}") - with code_session("your_region", "your_code_interpreter_name") as code_client: + # 需配置环境 HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY + api_key = os.environ.get( + "HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY", "" + ) # 配置环境变量后,api_key无需在代码中传递亦可正常工作 + with code_session("your_region", "your_code_interpreter_name", api_key) as code_client: response = code_client.invoke( operate_type="execute_code", + api_key=api_key, arguments={ "code": code, "language": "python", "clear_context": False, - } + }, ) return json.dumps(response["result"]) @@ -67,15 +85,19 @@ def execute_python_tool(code: str, description: str) -> str | None: # 创建工具列表 tools = [execute_python_tool] # 工具绑定Agent -llm.bind_tools(tools) +llm = llm.bind_tools(tools) + # 定义graph状态 class AgentState(TypedDict): messages: list[HumanMessage | SystemMessage | AIMessage] + def call_model(state: AgentState): """调用模型并返回响应""" - if not state["messages"] or all(not isinstance(msg, SystemMessage) for msg in state["messages"]): + if not state["messages"] or all( + not isinstance(msg, SystemMessage) for msg in state["messages"] + ): messages = [SystemMessage(content=SYSTEM_PROMPT)] + state["messages"] else: messages = state["messages"] @@ -83,17 +105,19 @@ def call_model(state: AgentState): response = llm.invoke(messages) return {"messages": [response]} + def should_continue(state): """判断是否继续使用工具""" last_message = state["messages"][-1] # 如果包含工具调用,则继续执行 - if last_message.tool_calls: + if isinstance(last_message, AIMessage) and last_message.tool_calls: return "tools" # 否则结束 return END + # 创建LangGraph工作流 workflow = StateGraph(AgentState) @@ -104,27 +128,20 @@ def should_continue(state): workflow.set_entry_point("agent") # 添加边 -workflow.add_conditional_edges( - "agent", - should_continue, - { - "tools": "tools", - "__end__": "__end__" - } -) +workflow.add_conditional_edges("agent", should_continue, {"tools": "tools", "__end__": "__end__"}) workflow.add_edge("tools", "agent") agent = workflow.compile() + @app.entrypoint def agent_chat(): - query = "告诉我1到100之间最大的随机质数" + query = "告诉我1到100之间最大的质数" # 运行Agent - result = agent.invoke({ - "messages": [HumanMessage(content=query)] - }) + result = agent.invoke({"messages": [HumanMessage(content=query)]}) print(result["messages"][-1].content) + if __name__ == "__main__": app.run() diff --git a/src/agentarts/sdk/service/tools_http.py b/src/agentarts/sdk/service/tools_http.py index e1e59b9..e377c7f 100644 --- a/src/agentarts/sdk/service/tools_http.py +++ b/src/agentarts/sdk/service/tools_http.py @@ -2,7 +2,7 @@ from typing import Any -from .http_client import BaseHTTPClient, RequestConfig +from .http_client import BaseHTTPClient, RequestConfig, SignMode class ToolsAPIError(BaseException): @@ -19,6 +19,7 @@ def __init__(self, status_code: int, error_msg: str): self.error_msg = error_msg super().__init__(f"Tools API Error: {error_msg}") + class ControlToolsHttpClient(BaseHTTPClient): def __init__(self, region_name: str, endpoint_url: str): request_config = RequestConfig(base_url=endpoint_url, verify_ssl=False) @@ -47,7 +48,9 @@ def list_code_interpreters(self, request_params: dict) -> dict[Any, Any]: raise ToolsAPIError(response.status_code, response.error) return response.data - def update_code_interpreter(self, code_interpreter_id: str, request_params: dict) -> dict[Any, Any]: + def update_code_interpreter( + self, code_interpreter_id: str, request_params: dict + ) -> dict[Any, Any]: """PUT v1/core/code-interpreters/{code_interpreter_id} Update a code interpreter. @@ -81,40 +84,68 @@ def delete_code_interpreter(self, code_interpreter_id: str): class DataToolsHttpClient(BaseHTTPClient): - def __init__(self, region_name: str, endpoint_url: str): - super().__init__(RequestConfig(base_url=endpoint_url, verify_ssl=False)) + def __init__(self, region_name: str, endpoint_url: str, auth_type: str = "API_KEY"): + """Initialize the data tools HTTP client. + + Args: + region_name (str): The region name + endpoint_url (str): The endpoint URL for data plane API + auth_type (str, optional): Authentication type, supports "API_KEY" or "IAM". Defaults to "API_KEY" + """ + if auth_type == "IAM": + super().__init__( + RequestConfig(base_url=endpoint_url, verify_ssl=False), + open_ak_sk=True, + sign_mode=SignMode.V11_HMAC_SHA256, + region_id=region_name, + ) + else: + super().__init__(RequestConfig(base_url=endpoint_url, verify_ssl=False)) self.region_name = region_name - def start_session(self, code_interpreter_name: str, api_key: str, request_params: dict) -> dict[Any, Any]: + @property + def open_ak_sk(self) -> bool: + return self._open_ak_sk + + @open_ak_sk.setter + def open_ak_sk(self, open_ak_sk: bool): + self._open_ak_sk = open_ak_sk + + def start_session( + self, code_interpreter_name: str, request_params: dict, api_key: str | None = None + ) -> dict[Any, Any]: """PUT v1/code-interpreters/{code_interpreter_name}/sessions-start Start a code interpreter session. """ endpoint = f"/v1/code-interpreters/{code_interpreter_name}/sessions-start" - headers = { - "Authorization": f"Bearer {api_key}" - } + headers = {} + if api_key is not None: + headers = {"Authorization": f"Bearer {api_key}"} response = self.put(url=endpoint, json=request_params, headers=headers) if not response.success: raise ToolsAPIError(response.status_code, response.error) return response.data - def stop_session(self, code_interpreter_name: str, session_id: str, api_key: str) -> dict[Any, Any]: - """POST v1/code-interpreters/{code_interpreter_name}/sessions-stop + def stop_session( + self, code_interpreter_name: str, session_id: str, api_key: str | None = None + ) -> dict[Any, Any]: + """PUT v1/code-interpreters/{code_interpreter_name}/sessions-stop Stop a code interpreter session. """ endpoint = f"/v1/code-interpreters/{code_interpreter_name}/sessions-stop" - headers = { - "x-HW-Agentarts-Code-Interpreter-Session-Id": session_id, - "Authorization": f"Bearer {api_key}" - } + headers = {"x-HW-Agentarts-Code-Interpreter-Session-Id": session_id} + if api_key is not None: + headers["Authorization"] = f"Bearer {api_key}" response = self.put(url=endpoint, headers=headers) if not response.success: raise ToolsAPIError(response.status_code, response.error) return response.data - def get_session(self, code_interpreter_name: str, session_id: str, api_key: str) -> dict[Any, Any]: + def get_session( + self, code_interpreter_name: str, session_id: str, api_key: str | None = None + ) -> dict[Any, Any]: """GET v1/code-interpreters/{code_interpreter_name}/sessions-get Get code interpreter session details. @@ -122,19 +153,20 @@ def get_session(self, code_interpreter_name: str, session_id: str, api_key: str) endpoint = f"/v1/code-interpreters/{code_interpreter_name}/sessions-get" headers = { "x-HW-Agentarts-Code-Interpreter-Session-Id": session_id, - "Authorization": f"Bearer {api_key}" } + if api_key is not None: + headers["Authorization"] = f"Bearer {api_key}" response = self.get(url=endpoint, headers=headers) if not response.success: raise ToolsAPIError(response.status_code, response.error) return response.data def invoke( - self, - code_interpreter_name: str, - session_id: str, - api_key: str, - arguments: dict | None = None, + self, + code_interpreter_name: str, + session_id: str, + arguments: dict | None = None, + api_key: str | None = None, ) -> dict[Any, Any]: """POST v1/code-interpreters/{code_interpreter_name}/invoke @@ -143,8 +175,9 @@ def invoke( endpoint = f"/v1/code-interpreters/{code_interpreter_name}/invoke" headers = { "x-HW-Agentarts-Code-Interpreter-Session-Id": session_id, - "Authorization": f"Bearer {api_key}" } + if api_key is not None: + headers["Authorization"] = f"Bearer {api_key}" response = self.post(url=endpoint, headers=headers, json=arguments) if not response.success: raise ToolsAPIError(response.status_code, response.error) diff --git a/src/agentarts/sdk/tools/__init__.py b/src/agentarts/sdk/tools/__init__.py index 29b1847..82c0abd 100644 --- a/src/agentarts/sdk/tools/__init__.py +++ b/src/agentarts/sdk/tools/__init__.py @@ -6,7 +6,10 @@ - code_session: Context manager for code interpreter sessions """ -from agentarts.sdk.tools.code_interpreter import CodeInterpreter, code_session +from agentarts.sdk.tools.code_interpreter.code_interpreter_client import ( + CodeInterpreter, + code_session, +) __all__ = [ "CodeInterpreter", diff --git a/src/agentarts/sdk/tools/code_interpreter/__init__.py b/src/agentarts/sdk/tools/code_interpreter/__init__.py index 96d3664..20111bc 100644 --- a/src/agentarts/sdk/tools/code_interpreter/__init__.py +++ b/src/agentarts/sdk/tools/code_interpreter/__init__.py @@ -6,7 +6,4 @@ from .code_interpreter_client import CodeInterpreter, code_session -__all__ = [ - "CodeInterpreter", - "code_session" -] +__all__ = ["CodeInterpreter", "code_session"] diff --git a/src/agentarts/sdk/tools/code_interpreter/code_interpreter_client.py b/src/agentarts/sdk/tools/code_interpreter/code_interpreter_client.py index 57e1afe..12aeecb 100644 --- a/src/agentarts/sdk/tools/code_interpreter/code_interpreter_client.py +++ b/src/agentarts/sdk/tools/code_interpreter/code_interpreter_client.py @@ -11,6 +11,7 @@ Manages the full lifecycle of code interpreter sessions (create, stop, get, invoke) """ + import base64 import logging import os @@ -44,30 +45,30 @@ class CodeInterpreter: data_plane_client: Client for interacting with data plane API """ - def __init__(self, region: str | None, data_endpoint: str | None = None) -> None: + def __init__(self, region: str | None, data_endpoint: str | None = None, auth_type: str = "API_KEY") -> None: """Initialize the code interpreter client in the specified region. Args: region: The specified region data_endpoint: Data plane endpoint, optional. If not provided, will be retrieved from environment variable AGENTARTS_CODEINTERPRETER_DATA_ENDPOINT + auth_type: Authentication type, optional. Defaults to "API_KEY" """ region = region or get_region() # Control plane client for managing code interpreters self.control_plane_client = ControlToolsHttpClient( - region_name=region, - endpoint_url=get_control_plane_endpoint() + region_name=region, endpoint_url=get_control_plane_endpoint() ) # Data plane client for managing code interpreter sessions # Priority: environment variable > parameter > default value endpoint_url = get_code_interpreter_data_plane_endpoint(endpoint=data_endpoint) - self.data_plane_client = DataToolsHttpClient( - region_name=region, - endpoint_url=endpoint_url - ) + if auth_type == "IAM": + self.data_plane_client = DataToolsHttpClient(region_name=region, endpoint_url=endpoint_url, auth_type=auth_type) + else: + self.data_plane_client = DataToolsHttpClient(region_name=region, endpoint_url=endpoint_url) self._code_interpreter_name = None self._session_id = None @@ -111,9 +112,9 @@ def session_id(self, session_id: str) -> None: def create_code_interpreter( self, name: str, - api_key_name: str, + auth_type: str = "API_KEY", + api_key_name: str | None = None, description: str | None = None, - auth_type: str | None = None, execution_agency_name: str | None = None, observability: dict | None = None, network_config: dict | None = None, @@ -126,9 +127,9 @@ def create_code_interpreter( Args: name (str): The code interpreter name, must follow specific naming rules - api_key_name (str): The API Key name + auth_type (str): Authentication type, supported values: "API_KEY", "IAM". default "API_KEY" + api_key_name (Optional[str]): The API Key name description (Optional[str]): The code interpreter description - auth_type (Optional[str]): Authentication type, e.g., "API_KEY" execution_agency_name (Optional[str]): IAM agency name observability (Optional[Dict]): Observability configuration, e.g., logging and monitoring settings network_config (Optional[Dict]): Network configuration, e.g., VPC and security group settings @@ -152,7 +153,7 @@ def create_code_interpreter( Example: >>> code_interpreter = client.create_code_interpreter( ... name="my-code-interpreter", - ... api_key_name="my-api-key-name", + ... auth_type="API_KEY" ... ) >>> code_interpreter_id = code_interpreter["id"] """ @@ -161,16 +162,16 @@ def create_code_interpreter( if not bool(re.match(pattern, name)): msg = "Name must match the pattern, please check your code_interpreter_name." raise ValueError(msg) + if auth_type == "API_KEY" and api_key_name is None: + msg = "API_KEY auth_type requires api_key_name." + raise ValueError(msg) - request_params = { - "name": name, - "api_key_name": api_key_name, - } + request_params = {"name": name, "auth_type": auth_type} + if api_key_name: + request_params["api_key_name"] = api_key_name if description: request_params["description"] = description - if auth_type: - request_params["auth_type"] = auth_type if execution_agency_name: request_params["execution_agency_name"] = execution_agency_name if observability: @@ -182,9 +183,7 @@ def create_code_interpreter( if tags: request_params["tags"] = tags - return self.control_plane_client.create_code_interpreter( - request_params=request_params - ) + return self.control_plane_client.create_code_interpreter(request_params=request_params) def list_code_interpreters( self, @@ -234,9 +233,7 @@ def list_code_interpreters( # Remove None values request_params = {k: v for k, v in request_params.items() if v is not None} - return self.control_plane_client.list_code_interpreters( - request_params=request_params - ) + return self.control_plane_client.list_code_interpreters(request_params=request_params) def update_code_interpreter( self, @@ -279,8 +276,7 @@ def update_code_interpreter( if tags is not None: request_params["tags"] = tags return self.control_plane_client.update_code_interpreter( - code_interpreter_id=code_interpreter_id, - request_params=request_params + code_interpreter_id=code_interpreter_id, request_params=request_params ) def get_code_interpreter(self, code_interpreter_id: str) -> dict: @@ -328,16 +324,14 @@ def delete_code_interpreter(self, code_interpreter_id: str) -> None: >>> ) """ logging.info(f"Deleting code interpreter {code_interpreter_id}") - self.control_plane_client.delete_code_interpreter( - code_interpreter_id=code_interpreter_id - ) + self.control_plane_client.delete_code_interpreter(code_interpreter_id=code_interpreter_id) def start_session( self, code_interpreter_name: str, session_name: str, api_key: str | None = None, - session_timeout: int | None = DEFAULT_TIMEOUT + session_timeout: int | None = DEFAULT_TIMEOUT, ) -> str: """Start a code interpreter session. @@ -347,7 +341,8 @@ def start_session( Args: code_interpreter_name (str): The code interpreter name, used to identify and manage sessions, must be unique session_name (str): The session name - api_key (Optional[str]): API Key for authentication, if not provided will be retrieved from environment variable API_KEY + api_key (Optional[str]): API Key for authentication, use only when auth_type is "API_KEY", + if not provided will be retrieved from environment variable API_KEY session_timeout (Optional[int]): Session timeout in seconds, default 15 minutes, minimum 60 seconds, maximum 86400 seconds (24 hours) @@ -365,29 +360,40 @@ def start_session( "name": session_name, "session_timeout": session_timeout, } - if api_key is None: - api_key = os.getenv("HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY") - response = self.data_plane_client.start_session( - code_interpreter_name=code_interpreter_name, - api_key=api_key, - request_params=request_params - ) + + if self.data_plane_client.open_ak_sk: + response = self.data_plane_client.start_session( + code_interpreter_name=code_interpreter_name, request_params=request_params + ) + else: + # default use API_KEY authentication + api_key = api_key or os.getenv("HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY") + if api_key is None: + msg = "API Key is not provided and not found in environment variable." + raise ValueError(msg) + response = self.data_plane_client.start_session( + code_interpreter_name=code_interpreter_name, + request_params=request_params, + api_key=api_key, + ) + self.session_id = response["session_id"] self.code_interpreter_name = code_interpreter_name return self.session_id def get_session( - self, - code_interpreter_name: str, - session_id: str | None = None, - api_key: str | None = None + self, + code_interpreter_name: str, + session_id: str | None = None, + api_key: str | None = None, ) -> dict: """Get code interpreter session details. Args: code_interpreter_name (str): The code interpreter name, used to identify and manage sessions, must be unique session_id (Optional[str]): The session ID, defaults to current session ID - api_key (Optional[str]): API Key for authentication, if not provided will be retrieved from environment variable API_KEY + api_key (Optional[str]): API Key for authentication, use only when auth_type is "API_KEY", + if not provided will be retrieved from environment variable API_KEY Returns: Dict: Dictionary containing session details @@ -408,12 +414,19 @@ def get_session( if not code_interpreter_name or not session_id: msg = "code_interpreter_name and session_id are required" raise ValueError(msg) + + if self.data_plane_client.open_ak_sk: + return self.data_plane_client.get_session( + code_interpreter_name=code_interpreter_name, session_id=session_id + ) + + # default use API_KEY authentication + api_key = api_key or os.getenv("HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY") if api_key is None: - api_key = os.getenv("HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY") + msg = "API Key is not provided and not found in environment variable." + raise ValueError(msg) return self.data_plane_client.get_session( - code_interpreter_name=code_interpreter_name, - session_id=session_id, - api_key=api_key + code_interpreter_name=code_interpreter_name, session_id=session_id, api_key=api_key ) def stop_session(self, api_key: str | None = None) -> bool: @@ -422,7 +435,8 @@ def stop_session(self, api_key: str | None = None) -> bool: Terminates any active session and clears session state. Args: - api_key (Optional[str]): API Key for authentication, if not provided will be retrieved from environment variable API_KEY + api_key (Optional[str]): API Key for authentication, use only when auth_type is "API_KEY", + if not provided will be retrieved from environment variable API_KEY Returns: bool: Returns True when no active session, otherwise False after stopping @@ -434,24 +448,31 @@ def stop_session(self, api_key: str | None = None) -> bool: if not self.session_id or not self.code_interpreter_name: return True - if api_key is None: - api_key = os.getenv("HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY") - - self.data_plane_client.stop_session( - code_interpreter_name=self.code_interpreter_name, - session_id=self.session_id, - api_key=api_key - ) + if self.data_plane_client.open_ak_sk: + self.data_plane_client.stop_session( + code_interpreter_name=self.code_interpreter_name, session_id=self.session_id + ) + else: + # default use API_KEY authentication + api_key = api_key or os.getenv("HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY") + if api_key is None: + msg = "API Key is not provided and not found in environment variable." + raise ValueError(msg) + self.data_plane_client.stop_session( + code_interpreter_name=self.code_interpreter_name, + session_id=self.session_id, + api_key=api_key, + ) self.code_interpreter_name = None self.session_id = None return True def invoke( - self, - operate_type: str, - arguments: dict, - api_key: str | None = None + self, + operate_type: str, + arguments: dict, + api_key: str | None = None, ) -> dict[str, Any]: """Invoke a code interpreter session. @@ -460,7 +481,8 @@ def invoke( Args: operate_type (str): The operation method name, e.g., "execute_code" or "execute_command" arguments (Dict): Invocation arguments, varies based on operate_type - api_key (Optional[str]): API Key for authentication, if not provided will be retrieved from environment variable API_KEY + api_key (Optional[str]): API Key for authentication, use only when auth_type is "API_KEY", + if not provided will be retrieved from environment variable API_KEY Returns: result[Dict]: Dictionary containing the invocation result @@ -479,26 +501,31 @@ def invoke( msg = "No Code Interpreter exists, use create_code_interpreter method first" raise ValueError(msg) - request_params = { - "operate_type": operate_type, - "arguments": arguments - } + request_params = {"operate_type": operate_type, "arguments": arguments} + if self.data_plane_client.open_ak_sk: + return self.data_plane_client.invoke( + code_interpreter_name=self.code_interpreter_name, + session_id=self.session_id, + arguments=request_params, + ) + # default use API_KEY authentication + api_key = api_key or os.getenv("HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY") if api_key is None: - api_key = os.getenv("HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY") - + msg = "API Key is not provided and not found in environment variable." + raise ValueError(msg) return self.data_plane_client.invoke( code_interpreter_name=self.code_interpreter_name, session_id=self.session_id, + arguments=request_params, api_key=api_key, - arguments=request_params ) def execute_code( - self, - code: str, - language: str = "python", - clear_context: bool = False, + self, + code: str, + language: str = "python", + clear_context: bool = False, ) -> dict[str, Any]: """Execute code in the code interpreter. @@ -529,11 +556,7 @@ def execute_code( return self.invoke( operate_type="execute_code", - arguments={ - "code": code, - "language": language, - "clear_context": clear_context - } + arguments={"code": code, "language": language, "clear_context": clear_context}, ) def execute_command(self, command: str) -> dict[str, Any]: @@ -559,8 +582,7 @@ def execute_command(self, command: str) -> dict[str, Any]: raise ValueError(msg) # Check for common injection patterns - strict_block_pattrns = [ - ] + strict_block_pattrns = [] for pattern in strict_block_pattrns: if re.search(pattern, command): @@ -569,18 +591,13 @@ def execute_command(self, command: str) -> dict[str, Any]: logger.info(f"Executing command: {command}") - return self.invoke( - operate_type="execute_command", - arguments={ - "command": command - } - ) + return self.invoke(operate_type="execute_command", arguments={"command": command}) def upload_file( - self, - path: str, - content: str | bytes, - description: str = "", + self, + path: str, + content: str | bytes, + description: str = "", ) -> dict[str, Any]: """Upload a file to the code interpreter. @@ -621,17 +638,9 @@ def upload_file( else: logger.info(f"Uploading file to {path} without description") - return self.invoke( - operate_type="write_files", - arguments={ - "write_contents": [file_content] - } - ) + return self.invoke(operate_type="write_files", arguments={"write_contents": [file_content]}) - def upload_files( - self, - files: list[dict[str, str]] - ) -> dict[str, Any]: + def upload_files(self, files: list[dict[str, str]]) -> dict[str, Any]: """Upload multiple files to the code interpreter. This operation is atomic, all files will be uploaded successfully or fail together, @@ -656,7 +665,7 @@ def upload_files( """ file_contents = [] for file_spec in files: - path = file_spec.get("path") + path = file_spec.get("path", "") content = file_spec.get("content") if not path.startswith("/"): @@ -676,17 +685,9 @@ def upload_files( logger.info(f"Uploading {len(file_contents)} files") - return self.invoke( - operate_type="write_files", - arguments={ - "write_contents": file_contents - } - ) + return self.invoke(operate_type="write_files", arguments={"write_contents": file_contents}) - def download_file( - self, - path: str - ) -> str | bytes: + def download_file(self, path: str) -> str | bytes: """Download a file from the code interpreter. Args: @@ -706,31 +707,25 @@ def download_file( raise ValueError(msg) logger.info(f"Downloading file from {path}") - result = self.invoke( - operate_type="read_files", - arguments={ - "paths": [path] - } - ) + result = self.invoke(operate_type="read_files", arguments={"paths": [path]}) # Extract file content - if "stream" not in result: + if "result" not in result or "content" not in result["result"]: msg = f"Could not read file: {path}" raise FileNotFoundError(msg) - - for event in result["stream"]: - if "result" not in event: - msg = f"Could not read file: {path}" - raise FileNotFoundError(msg) - for content_item in event["result"].get("contents", []): - if content_item.get("type") != "resource": - msg = f"Could not read file: {path}" - raise FileNotFoundError(msg) - resource = content_item.get("resource", {}) - if "text" in resource: - return resource["text"] - if "blob" in resource: - raw = base64.b64decode(resource["blob"]) + result = result["result"] + + for content_item in result["content"]: + if content_item.get("type") == "text": + return content_item.get("text", "") + if content_item.get("type") == "image": + return base64.b64decode(content_item.get("data", "")) + if content_item.get("type") == "resource": + content_resource = content_item.get("resource", {}) + if content_resource.get("type") == "text": + return content_resource["text"] + if content_resource.get("type") == "blob": + raw = base64.b64decode(content_resource["blob"]) try: return raw.decode("utf-8") except ValueError: @@ -738,10 +733,7 @@ def download_file( msg = f"Could not read file: {path}" raise FileNotFoundError(msg) - def download_files( - self, - paths: list[str] - ) -> dict[str, str | bytes]: + def download_files(self, paths: list[str]) -> dict[str, str | bytes]: """Download multiple files from the code interpreter. Args: @@ -762,39 +754,39 @@ def download_files( if not path.startswith(DEFAULT_PATH): msg = f"Invalid path. Path must start with {DEFAULT_PATH}" raise ValueError(msg) - result = self.invoke( - operate_type="read_files", - arguments={ - "paths": paths - } - ) + result = self.invoke(operate_type="read_files", arguments={"paths": paths}) + + # Extract file content + if "result" not in result or "content" not in result["result"]: + msg = f"Could not read file: {path}" + raise FileNotFoundError(msg) + result = result["result"] files = {} - for event in result["stream"]: - if "result" not in event: - return files - for content_item in event["result"].get("contents", []): - if content_item.get("type") != "resource": - return files + for content_item in result["content"]: + uri = content_item.get("uri", "") + file_path = uri.replace("file://", "") + + if content_item.get("type") == "text": + files[file_path] = content_item.get("text", "") + elif content_item.get("type") == "image": + files[file_path] = base64.b64decode(content_item.get("data", "")) + elif content_item.get("type") == "resource": resource = content_item.get("resource", {}) - uri = resource.get("uri", "") - file_path = uri.replace("file://", "") + resource_uri = resource.get("uri", "") + resource_file_path = resource_uri.replace("file://", "") - if "text" in resource: - files[file_path] = resource["text"] - elif "blob" in resource: + if resource.get("type") == "text": + files[resource_file_path] = resource["text"] + elif resource.get("type") == "blob": raw = base64.b64decode(resource["blob"]) try: - files[file_path] = raw.decode("utf-8") + files[resource_file_path] = raw.decode("utf-8") except ValueError: - files[file_path] = raw + files[resource_file_path] = raw return files - def install_packages( - self, - packages: list[str], - upgrade: bool = False - ) -> dict[str, Any]: + def install_packages(self, packages: list[str], upgrade: bool = False) -> dict[str, Any]: """Install Python packages in the code interpreter. Args: @@ -830,12 +822,7 @@ def install_packages( command = f"pip install {package_str} {upgrade_flag}" logger.info(f"Installing packages: {package_str}") - return self.invoke( - operate_type="execute_command", - arguments={ - "command": command - } - ) + return self.invoke(operate_type="execute_command", arguments={"command": command}) def clear_context(self) -> dict[str, Any]: """Clear the code interpreter context. @@ -858,26 +845,26 @@ def clear_context(self) -> dict[str, Any]: logger.info("Clearing code interpreter context") return self.invoke( operate_type="execute_code", - arguments={ - "code": "# Context cleared", - "language": "python", - "clear_context": True - } + arguments={"code": "# Context cleared", "language": "python", "clear_context": True}, ) + @contextmanager def code_session( region: str, code_interpreter_name: str, - api_key: str | None = None + auth_type: str = "API_KEY", + api_key: str | None = None, ) -> Generator[CodeInterpreter, None, None]: """Code interpreter session context manager. Args: region (str): Region name, e.g., "cn-southwest-2" + auth_type (str, optional): Authentication type, default "API_KEY". + Can be "API_KEY" or "IAM" code_interpreter_name (str): Code interpreter name - api_key (Optional[str]): API Key, if not provided will be retrieved from - environment variable HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY + api_key (Optional[str]): API Key for authentication, use only when auth_type is "API_KEY", + if not provided will be retrieved from environment variable API_KEY Yields: CodeInterpreter: Code interpreter instance with session started @@ -889,14 +876,19 @@ def code_session( >>> # With API Key >>> with code_session("cn-southwest-2", "my-code-interpreter-name", api_key="your-api-key") as client: >>> client.execute_code("print('Hello, World!')") + >>> + >>> # With IAM + >>> with code_session("cn-southwest-2", "my-code-interpreter-name", auth_type="IAM") as client: + >>> client.execute_code("print('Hello, World!')") """ - client = CodeInterpreter(region=region) + client = CodeInterpreter(region=region, auth_type=auth_type) + default_session_name = "default-session-name" client.start_session( code_interpreter_name=code_interpreter_name, session_name=default_session_name, - api_key=api_key + api_key=api_key, ) try: diff --git a/tests/unit/sdk/service/test_tools_http.py b/tests/unit/sdk/service/test_tools_http.py index 4e9a10c..68264f2 100644 --- a/tests/unit/sdk/service/test_tools_http.py +++ b/tests/unit/sdk/service/test_tools_http.py @@ -9,8 +9,12 @@ class TestToolsHttpClient(unittest.TestCase): @patch("agentarts.sdk.utils.constant.ENV_HUAWEICLOUD_SDK_AK") @patch("agentarts.sdk.utils.constant.ENV_HUAWEICLOUD_SDK_SK") def setUp(self, mock_ak, mock_sk): - self.control_client = ControlToolsHttpClient(region_name="test-region", endpoint_url="https://test.com") - self.data_client = DataToolsHttpClient(region_name="test-region", endpoint_url="https://test.com") + self.control_client = ControlToolsHttpClient( + region_name="test-region", endpoint_url="https://test.com" + ) + self.data_client = DataToolsHttpClient( + region_name="test-region", endpoint_url="https://test.com" + ) @patch.object(ControlToolsHttpClient, "post") def test_create_code_interpreter(self, mock_post): @@ -28,7 +32,7 @@ def test_create_code_interpreter(self, mock_post): "observability": {}, "network_config": {}, "agent_gateway_id": "a1b2c3d4-e5f6-7890-abcd-ef12334567890", - "tags": [] + "tags": [], }, headers={}, streaming=True, @@ -46,7 +50,7 @@ def test_create_code_interpreter(self, mock_post): "observability": {}, "network_config": {}, "agent_gateway_id": "a1b2c3d4-e5f6-7890-abcd-ef12334567890", - "tags": [] + "tags": [], } ) @@ -63,8 +67,8 @@ def test_create_code_interpreter(self, mock_post): "observability": {}, "network_config": {}, "agent_gateway_id": "a1b2c3d4-e5f6-7890-abcd-ef12334567890", - "tags": [] - } + "tags": [], + }, ) @patch.object(ControlToolsHttpClient, "get") @@ -90,9 +94,9 @@ def test_list_code_interpreters(self, mock_get): }, "access_endpoint": "", "observability": {}, - "tags": [] + "tags": [], } - ] + ], }, headers={}, streaming=True, @@ -100,21 +104,12 @@ def test_list_code_interpreters(self, mock_get): ) # Act - result = self.control_client.list_code_interpreters( - { - "offset": 0, - "limit": 10 - } - ) + result = self.control_client.list_code_interpreters({"offset": 0, "limit": 10}) # Assert assert result == mock_get.return_value.data mock_get.assert_called_once_with( - url="/v1/core/code-interpreters", - params={ - "offset": 0, - "limit": 10 - } + url="/v1/core/code-interpreters", params={"offset": 0, "limit": 10} ) @patch.object(ControlToolsHttpClient, "put") @@ -125,15 +120,7 @@ def test_update_code_interpreter(self, mock_put): mock_put.return_value = RequestResult( success=True, status_code=200, - data={ - "observability": {}, - "tags": [ - { - "key": "test-tag", - "value": "test-tag-value" - } - ] - }, + data={"observability": {}, "tags": [{"key": "test-tag", "value": "test-tag-value"}]}, headers={}, streaming=True, _raw_response=None, @@ -144,28 +131,15 @@ def test_update_code_interpreter(self, mock_put): code_interpreter_id=code_interpreter_id, request_params={ "observability": {}, - "tags": [ - { - "key": "test-tag", - "value": "test-tag-value" - } - ] - } + "tags": [{"key": "test-tag", "value": "test-tag-value"}], + }, ) # Assert assert result == mock_put.return_value.data mock_put.assert_called_once_with( url=f"/v1/core/code-interpreters/{code_interpreter_id}", - json={ - "observability": {}, - "tags": [ - { - "key": "test-tag", - "value": "test-tag-value" - } - ] - } + json={"observability": {}, "tags": [{"key": "test-tag", "value": "test-tag-value"}]}, ) @patch.object(ControlToolsHttpClient, "get") @@ -192,7 +166,7 @@ def test_get_code_interpreter(self, mock_get): "tags": [], "auth_type": "API_KEY", "api_key_name": "test-api-key-name", - "network_config": {} + "network_config": {}, }, headers={}, streaming=True, @@ -200,15 +174,11 @@ def test_get_code_interpreter(self, mock_get): ) # Act - result = self.control_client.get_code_interpreter( - code_interpreter_id=code_interpreter_id - ) + result = self.control_client.get_code_interpreter(code_interpreter_id=code_interpreter_id) # Assert assert result == mock_get.return_value.data - mock_get.assert_called_once_with( - url=f"/v1/core/code-interpreters/{code_interpreter_id}" - ) + mock_get.assert_called_once_with(url=f"/v1/core/code-interpreters/{code_interpreter_id}") @patch.object(ControlToolsHttpClient, "delete") def test_delete_code_interpreter(self, mock_delete): @@ -231,9 +201,7 @@ def test_delete_code_interpreter(self, mock_delete): # Assert assert result == mock_delete.return_value.data - mock_delete.assert_called_once_with( - url=f"/v1/core/code-interpreters/{code_interpreter_id}" - ) + mock_delete.assert_called_once_with(url=f"/v1/core/code-interpreters/{code_interpreter_id}") @patch.object(DataToolsHttpClient, "put") def test_start_session(self, mock_post): @@ -247,7 +215,7 @@ def test_start_session(self, mock_post): "created_at": "2026-01-01T00:00:00Z", "name": "test-session-name", "session_id": "test-session-id", - "session_timeout": 600 + "session_timeout": 600, }, headers={}, streaming=True, @@ -255,16 +223,13 @@ def test_start_session(self, mock_post): ) code_interpreter_name = "test-code-interpreter-name" api_key = "test-api-key" - request_params = { - "name": "test-session-name", - "session_timeout": 600 - } + request_params = {"name": "test-session-name", "session_timeout": 600} # Act result = self.data_client.start_session( code_interpreter_name=code_interpreter_name, api_key=api_key, - request_params=request_params + request_params=request_params, ) # Assert @@ -272,7 +237,7 @@ def test_start_session(self, mock_post): mock_post.assert_called_once_with( url=f"/v1/code-interpreters/{code_interpreter_name}/sessions-start", json=request_params, - headers = {"Authorization": f"Bearer {api_key}"} + headers={"Authorization": f"Bearer {api_key}"}, ) @patch.object(DataToolsHttpClient, "get") @@ -287,7 +252,7 @@ def test_get_session(self, mock_get): "created_at": "2026-01-01T00:00:00Z", "name": "test-session-name", "session_id": "test-session-id", - "session_timeout": 600 + "session_timeout": 600, }, headers={}, streaming=True, @@ -299,9 +264,7 @@ def test_get_session(self, mock_get): # Act result = self.data_client.get_session( - code_interpreter_name=code_interpreter_name, - session_id=session_id, - api_key=api_key + code_interpreter_name=code_interpreter_name, session_id=session_id, api_key=api_key ) # Assert @@ -310,8 +273,8 @@ def test_get_session(self, mock_get): url=f"/v1/code-interpreters/{code_interpreter_name}/sessions-get", headers={ "x-HW-Agentarts-Code-Interpreter-Session-Id": session_id, - "Authorization": f"Bearer {api_key}" - } + "Authorization": f"Bearer {api_key}", + }, ) @patch.object(DataToolsHttpClient, "put") @@ -324,9 +287,7 @@ def test_stop_session(self, mock_put): # Act self.data_client.stop_session( - code_interpreter_name=code_interpreter_name, - session_id=session_id, - api_key=api_key + code_interpreter_name=code_interpreter_name, session_id=session_id, api_key=api_key ) # Assert @@ -334,8 +295,8 @@ def test_stop_session(self, mock_put): url=f"/v1/code-interpreters/{code_interpreter_name}/sessions-stop", headers={ "x-HW-Agentarts-Code-Interpreter-Session-Id": session_id, - "Authorization": f"Bearer {api_key}" - } + "Authorization": f"Bearer {api_key}", + }, ) @patch.object(DataToolsHttpClient, "post") @@ -362,8 +323,8 @@ def test_invoke(self, mock_post): "blob": "", "mime_type": "string", "text": "string", - "uri": "string" - } + "uri": "string", + }, } ], "is_error": False, @@ -371,8 +332,8 @@ def test_invoke(self, mock_post): "execution_time": 100, "exit_code": 0, "stderr": "string", - "stdout": "string" - } + "stdout": "string", + }, }, }, headers={}, @@ -387,8 +348,8 @@ def test_invoke(self, mock_post): "arguments": { "clear_context": False, "code": "print('hello world')", - "language": "python" - } + "language": "python", + }, } # Act @@ -396,7 +357,7 @@ def test_invoke(self, mock_post): code_interpreter_name=code_interpreter_name, session_id=session_id, api_key=api_key, - arguments=params + arguments=params, ) # Assert @@ -405,7 +366,7 @@ def test_invoke(self, mock_post): url=f"/v1/code-interpreters/{code_interpreter_name}/invoke", headers={ "x-HW-Agentarts-Code-Interpreter-Session-Id": session_id, - "Authorization": f"Bearer {api_key}" + "Authorization": f"Bearer {api_key}", }, - json=params + json=params, ) diff --git a/tests/unit/sdk/tools/test_code_interpreter_client.py b/tests/unit/sdk/tools/test_code_interpreter_client.py index 3e28da5..b7bbd3c 100644 --- a/tests/unit/sdk/tools/test_code_interpreter_client.py +++ b/tests/unit/sdk/tools/test_code_interpreter_client.py @@ -18,7 +18,9 @@ class TestCodeInterpreterClient(unittest.TestCase): @patch("agentarts.sdk.utils.constant.ENV_HUAWEICLOUD_SDK_SK") @patch("agentarts.sdk.utils.constant.get_control_plane_endpoint") @patch("agentarts.sdk.utils.constant.get_code_interpreter_data_plane_endpoint") - def setUp(self, mock_get_data_plane_endpoint, mock_get_control_plane_endpoint, mock_sk, mock_ak): + def setUp( + self, mock_get_data_plane_endpoint, mock_get_control_plane_endpoint, mock_sk, mock_ak + ): """在每个测试方法前调用""" mock_get_control_plane_endpoint.return_value = "https://control-plane.example.com" mock_get_data_plane_endpoint.return_value = "https://data-plane.example.com" @@ -39,13 +41,43 @@ def test_create_code_interpreter_with_required_params(self, mock_create_code_int "observability": {}, "network_config": {}, "agent_gateway_id": "test-agent-gateway-id", - "tags": [] + "tags": [], } # Act result = self.code_interpreter_client.create_code_interpreter( - name="test-code-interpreter-name", - api_key_name="test-api-key-name" + name="test-code-interpreter-name", api_key_name="test-api-key-name" + ) + + # Assert + assert result == mock_create_code_interpreter.return_value + mock_create_code_interpreter.assert_called_once_with( + request_params={ + "name": "test-code-interpreter-name", + "auth_type": "API_KEY", + "api_key_name": "test-api-key-name", + } + ) + + @patch.object(ControlToolsHttpClient, "create_code_interpreter") + def test_create_code_interpreter_with_api_key(self, mock_create_code_interpreter): + """测试create_code_interpreter方法,提供API_KEY认证的情况""" + # Arrange + mock_create_code_interpreter.return_value = { + "name": "test-code-interpreter-name", + "api_key_name": "test-api-key-name", + "description": "test-code-interpreter-description", + "auth_type": "API_KEY", + "execution_agency_name": "test-execution-agency-name", + "observability": {}, + "network_config": {}, + "agent_gateway_id": "test-agent-gateway-id", + "tags": [], + } + + # Act + result = self.code_interpreter_client.create_code_interpreter( + name="test-code-interpreter-name", auth_type="API_KEY", api_key_name="test-api-key-name" ) # Assert @@ -53,79 +85,81 @@ def test_create_code_interpreter_with_required_params(self, mock_create_code_int mock_create_code_interpreter.assert_called_once_with( request_params={ "name": "test-code-interpreter-name", - "api_key_name": "test-api-key-name" + "auth_type": "API_KEY", + "api_key_name": "test-api-key-name", } ) + @patch.object(ControlToolsHttpClient, "create_code_interpreter") + def test_create_code_interpreter_with_iam(self, mock_create_code_interpreter): + """测试create_code_interpreter方法,提供IAM认证的情况""" + # Arrange + mock_create_code_interpreter.return_value = { + "name": "test-code-interpreter-name", + "api_key_name": "test-api-key-name", + "description": "test-code-interpreter-description", + "auth_type": "IAM", + "execution_agency_name": "test-execution-agency-name", + "observability": {}, + "network_config": {}, + "agent_gateway_id": "test-agent-gateway-id", + "tags": [], + } + + # Act + result = self.code_interpreter_client.create_code_interpreter( + name="test-code-interpreter-name", auth_type="IAM" + ) + + # Assert + assert result == mock_create_code_interpreter.return_value + mock_create_code_interpreter.assert_called_once_with( + request_params={"name": "test-code-interpreter-name", "auth_type": "IAM"} + ) + @patch.object(ControlToolsHttpClient, "create_code_interpreter") def test_create_code_interpreter_with_all_params(self, mock_create_code_interpreter): """测试create_code_interpreter方法,提供所有参数的情况""" # Arrange mock_create_code_interpreter.return_value = { "name": "test-code-interpreter-name", + "auth_type": "API_KEY", "api_key_name": "test-api-key-name", "description": "test-code-interpreter-description", - "auth_type": "API_KEY", "execution_agency_name": "test-execution-agency-name", "observability": { "logs": { "enable": True, "group_id": "test-group-id", - "stream_id": "test-stream-id" + "stream_id": "test-stream-id", }, - "metrics": { - "enable": True, - "instance_id": "test-instance-id" - }, - "tracing": { - "enable": True, - "service_group": "test-service-group" - } - }, - "network_config": { - "network_config": "PUBLIC" + "metrics": {"enable": True, "instance_id": "test-instance-id"}, + "tracing": {"enable": True, "service_group": "test-service-group"}, }, + "network_config": {"network_config": "PUBLIC"}, "agent_gateway_id": "test-agent-gateway-id", - "tags": [ - { - "key": "test-tag", - "value": "test-tag-value" - } - ] + "tags": [{"key": "test-tag", "value": "test-tag-value"}], } # Act result = self.code_interpreter_client.create_code_interpreter( name="test-code-interpreter-name", + auth_type="API_KEY", api_key_name="test-api-key-name", description="test-code-interpreter-description", - auth_type="API_KEY", execution_agency_name="test-execution-agency-name", observability={ "logs": { "enable": True, "group_id": "test-group-id", - "stream_id": "test-stream-id" - }, - "metrics": { - "enable": True, - "instance_id": "test-instance-id" + "stream_id": "test-stream-id", }, - "tracing": { - "enable": True, - "service_group": "test-service-group" - } - }, - network_config={ - "network_config": "PUBLIC" + "metrics": {"enable": True, "instance_id": "test-instance-id"}, + "tracing": {"enable": True, "service_group": "test-service-group"}, }, + network_config={"network_config": "PUBLIC"}, agent_gateway_id="test-agent-gateway-id", - tags=[ - { - "key": "test-tag", - "value": "test-tag-value" - } - ] + tags=[{"key": "test-tag", "value": "test-tag-value"}], ) # Assert @@ -133,35 +167,22 @@ def test_create_code_interpreter_with_all_params(self, mock_create_code_interpre mock_create_code_interpreter.assert_called_once_with( request_params={ "name": "test-code-interpreter-name", + "auth_type": "API_KEY", "api_key_name": "test-api-key-name", "description": "test-code-interpreter-description", - "auth_type": "API_KEY", "execution_agency_name": "test-execution-agency-name", "observability": { "logs": { "enable": True, "group_id": "test-group-id", - "stream_id": "test-stream-id" - }, - "metrics": { - "enable": True, - "instance_id": "test-instance-id" + "stream_id": "test-stream-id", }, - "tracing": { - "enable": True, - "service_group": "test-service-group" - } - }, - "network_config": { - "network_config": "PUBLIC" + "metrics": {"enable": True, "instance_id": "test-instance-id"}, + "tracing": {"enable": True, "service_group": "test-service-group"}, }, + "network_config": {"network_config": "PUBLIC"}, "agent_gateway_id": "test-agent-gateway-id", - "tags": [ - { - "key": "test-tag", - "value": "test-tag-value" - } - ] + "tags": [{"key": "test-tag", "value": "test-tag-value"}], } ) @@ -171,12 +192,7 @@ def test_list_code_interpreters_with_default_params(self, mock_list_code_interpr # Arrange mock_list_code_interpreters.return_value = { "total_count": 1, - "items": [ - { - "name": "test-code-interpreter-name", - "api_key_name": "test-api-key-name" - } - ] + "items": [{"name": "test-code-interpreter-name", "api_key_name": "test-api-key-name"}], } # Act @@ -197,21 +213,12 @@ def test_list_code_interpreters_with_all_params(self, mock_list_code_interpreter # Arrange mock_list_code_interpreters.return_value = { "total_count": 1, - "items": [ - { - "name": "test-code-interpreter-name", - "api_key_name": "test-api-key-name" - } - ] + "items": [{"name": "test-code-interpreter-name", "api_key_name": "test-api-key-name"}], } # Act result = self.code_interpreter_client.list_code_interpreters( - name="test-name", - limit=20, - offset=10, - sort_key="created_at", - sort_dir="asc" + name="test-name", limit=20, offset=10, sort_key="created_at", sort_dir="asc" ) # Assert @@ -222,7 +229,7 @@ def test_list_code_interpreters_with_all_params(self, mock_list_code_interpreter "limit": 20, "offset": 10, "sort_key": "created_at", - "sort_dir": "asc" + "sort_dir": "asc", } ) @@ -238,31 +245,18 @@ def test_update_code_interpreter(self, mock_update_code_interpreter): "updated_at": "2026-01-01T00:00:00Z", "execution_agency_name": "test-execution-agency-name", "agent_gateway_id": "test-agent-gateway-id", - "workload_identity": { - "urn": "test-workload-urn" - }, + "workload_identity": {"urn": "test-workload-urn"}, "access_endpoint": "test-access-endpoint-url", "observability": { "logs": { "enable": True, "group_id": "test-group-id", - "stream_id": "test-stream-id" + "stream_id": "test-stream-id", }, - "metrics": { - "enable": True, - "instance_id": "test-instance-id" - }, - "tracing": { - "enable": True, - "service_group": "test-service-group" - } + "metrics": {"enable": True, "instance_id": "test-instance-id"}, + "tracing": {"enable": True, "service_group": "test-service-group"}, }, - "tags": [ - { - "key": "test-tag", - "value": "test-tag-value" - } - ] + "tags": [{"key": "test-tag", "value": "test-tag-value"}], } # Act @@ -272,23 +266,12 @@ def test_update_code_interpreter(self, mock_update_code_interpreter): "logs": { "enable": True, "group_id": "test-group-id", - "stream_id": "test-stream-id" + "stream_id": "test-stream-id", }, - "metrics": { - "enable": True, - "instance_id": "test-instance-id" - }, - "tracing": { - "enable": True, - "service_group": "test-service-group" - } + "metrics": {"enable": True, "instance_id": "test-instance-id"}, + "tracing": {"enable": True, "service_group": "test-service-group"}, }, - tags=[ - { - "key": "test-tag", - "value": "test-tag-value" - } - ] + tags=[{"key": "test-tag", "value": "test-tag-value"}], ) # Assert assert result == mock_update_code_interpreter.return_value @@ -299,24 +282,13 @@ def test_update_code_interpreter(self, mock_update_code_interpreter): "logs": { "enable": True, "group_id": "test-group-id", - "stream_id": "test-stream-id" + "stream_id": "test-stream-id", }, - "metrics": { - "enable": True, - "instance_id": "test-instance-id" - }, - "tracing": { - "enable": True, - "service_group": "test-service-group" - } + "metrics": {"enable": True, "instance_id": "test-instance-id"}, + "tracing": {"enable": True, "service_group": "test-service-group"}, }, - "tags": [ - { - "key": "test-tag", - "value": "test-tag-value" - } - ] - } + "tags": [{"key": "test-tag", "value": "test-tag-value"}], + }, ) @patch.object(ControlToolsHttpClient, "get_code_interpreter") @@ -334,37 +306,22 @@ def test_get_code_interpreter(self, mock_get_code_interpreter): "logs": { "enable": True, "group_id": "test-group-id", - "stream_id": "test-stream-id" - }, - "metrics": { - "enable": True, - "instance_id": "test-instance-id" + "stream_id": "test-stream-id", }, - "tracing": { - "enable": True, - "service_group": "test-service-group" - } - }, - "workload_identity": { - "urn": "test-workload-urn" + "metrics": {"enable": True, "instance_id": "test-instance-id"}, + "tracing": {"enable": True, "service_group": "test-service-group"}, }, + "workload_identity": {"urn": "test-workload-urn"}, "access_endpoint": "test-access-endpoint-url", "agent_gateway_id": "test-agent-gateway-id", - "tags": [ - { - "key": "test-tag", - "value": "test-tag-value" - } - ], + "tags": [{"key": "test-tag", "value": "test-tag-value"}], "auth_type": "API_KEY", "api_key_name": "test-api-key-name", "network_config": { "vpc_id": "test-vpc-id", "subnet_id": "test-subnet-id", - "security_group_id": [ - "test-security-group-id" - ] - } + "security_group_id": ["test-security-group-id"], + }, } code_interpreter_id = "test-code-interpreter-id" @@ -375,9 +332,7 @@ def test_get_code_interpreter(self, mock_get_code_interpreter): # Assert assert result == mock_get_code_interpreter.return_value - mock_get_code_interpreter.assert_called_once_with( - code_interpreter_id=code_interpreter_id - ) + mock_get_code_interpreter.assert_called_once_with(code_interpreter_id=code_interpreter_id) @patch.object(ControlToolsHttpClient, "delete_code_interpreter") def test_delete_code_interpreter(self, mock_delete_code_interpreter): @@ -406,8 +361,7 @@ def test_start_session_with_default_params(self, mock_start_session, mock_getenv # Act result = self.code_interpreter_client.start_session( - code_interpreter_name=test_code_interpreter_name, - session_name=test_session_name + code_interpreter_name=test_code_interpreter_name, session_name=test_session_name ) # Assert @@ -417,10 +371,7 @@ def test_start_session_with_default_params(self, mock_start_session, mock_getenv mock_start_session.assert_called_once_with( code_interpreter_name="test-code-interpreter-name", api_key="test-api-key", - request_params={ - "name": test_session_name, - "session_timeout": 900 # 默认值 - } + request_params={"name": test_session_name, "session_timeout": 900}, # 默认值 ) mock_getenv.assert_called_once_with("HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY") @@ -440,7 +391,7 @@ def test_start_session_with_custom_params(self, mock_start_session, mock_getenv) code_interpreter_name=test_code_interpreter_name, session_name=test_session_name, api_key=test_api_key, - session_timeout=test_session_timeout + session_timeout=test_session_timeout, ) # Assert @@ -450,14 +401,10 @@ def test_start_session_with_custom_params(self, mock_start_session, mock_getenv) mock_start_session.assert_called_once_with( code_interpreter_name="test-code-interpreter-name", api_key=test_api_key, - request_params={ - "name": test_session_name, - "session_timeout": test_session_timeout - } + request_params={"name": test_session_name, "session_timeout": test_session_timeout}, ) mock_getenv.assert_not_called() # 因为我们提供了api_key,所以不应该调用getenv - @patch.object(os, "getenv") @patch.object(DataToolsHttpClient, "get_session") def test_get_session_with_custom_params(self, mock_get_session, mock_getenv): @@ -470,14 +417,13 @@ def test_get_session_with_custom_params(self, mock_get_session, mock_getenv): "session_id": session_id, "created_at": "2023-01-01T00:00:00Z", "name": "test-session-name", - "session_timeout": 900 + "session_timeout": 900, } mock_getenv.return_value = "test-api-key" # Act response = self.code_interpreter_client.get_session( - code_interpreter_name=code_interpreter_name, - session_id=session_id + code_interpreter_name=code_interpreter_name, session_id=session_id ) # Assert @@ -485,7 +431,7 @@ def test_get_session_with_custom_params(self, mock_get_session, mock_getenv): mock_get_session.assert_called_once_with( code_interpreter_name=code_interpreter_name, session_id=session_id, - api_key="test-api-key" + api_key="test-api-key", ) mock_getenv.assert_called_once_with("HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY") @@ -497,9 +443,7 @@ def test_get_session_with_no_params(self): # Act & Assert with pytest.raises(ValueError, match=error_message): - self.code_interpreter_client.get_session( - code_interpreter_name=code_interpreter_name - ) + self.code_interpreter_client.get_session(code_interpreter_name=code_interpreter_name) @patch.object(os, "getenv") @patch.object(DataToolsHttpClient, "stop_session") @@ -517,7 +461,7 @@ def test_stop_session_with_session_exists(self, mock_stop_session, mock_getenv): mock_stop_session.assert_called_once_with( code_interpreter_name="test-code-interpreter-name", session_id="test-session-id", - api_key="test-api-key" + api_key="test-api-key", ) mock_getenv.assert_called_once_with("HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY") @@ -530,7 +474,6 @@ def test_stop_session_with_no_activate(self): assert self.code_interpreter_client.session_id is None assert self.code_interpreter_client.code_interpreter_name is None - def test_invoke_with_no_existing_session(self): """测试invoke方法,无激活会话的情况""" # Arrange @@ -538,10 +481,7 @@ def test_invoke_with_no_existing_session(self): # Act & Assert with pytest.raises(ValueError, match=error_message): - self.code_interpreter_client.invoke( - operate_type="test-method", - arguments={} - ) + self.code_interpreter_client.invoke(operate_type="test-method", arguments={}) @patch.object(os, "getenv") @patch.object(DataToolsHttpClient, "invoke") @@ -554,20 +494,14 @@ def test_invoke_with_existing_session(self, mock_invoke, mock_getenv): mock_getenv.return_value = "test-api-key" # Act - self.code_interpreter_client.invoke( - operate_type="test-method", - arguments={} - ) + self.code_interpreter_client.invoke(operate_type="test-method", arguments={}) # Assert mock_invoke.assert_called_once_with( code_interpreter_name="test-code-interpreter-name", session_id="test-session-id", api_key="test-api-key", - arguments={ - "operate_type": "test-method", - "arguments": {} - } + arguments={"operate_type": "test-method", "arguments": {}}, ) mock_getenv.assert_called_once_with("HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY") @@ -594,12 +528,8 @@ def test_execute_code_with_python(self, mock_invoke, mock_getenv): api_key="test-api-key", arguments={ "operate_type": "execute_code", - "arguments": { - "code": code, - "language": "python", - "clear_context": False - } - } + "arguments": {"code": code, "language": "python", "clear_context": False}, + }, ) mock_getenv.assert_called_once_with("HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY") @@ -615,10 +545,7 @@ def test_execute_code_with_clear_context(self, mock_invoke, mock_getenv): mock_getenv.return_value = "test-api-key" # Act - self.code_interpreter_client.execute_code( - code=code, - clear_context=True - ) + self.code_interpreter_client.execute_code(code=code, clear_context=True) # Assert mock_invoke.assert_called_once_with( @@ -627,12 +554,8 @@ def test_execute_code_with_clear_context(self, mock_invoke, mock_getenv): api_key="test-api-key", arguments={ "operate_type": "execute_code", - "arguments": { - "code": code, - "language": "python", - "clear_context": True - } - } + "arguments": {"code": code, "language": "python", "clear_context": True}, + }, ) mock_getenv.assert_called_once_with("HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY") @@ -647,10 +570,7 @@ def test_execute_code_with_invalid_language(self): # Act & Assert with pytest.raises(ValueError, match=error_message): - self.code_interpreter_client.execute_code( - code=code, - language=language - ) + self.code_interpreter_client.execute_code(code=code, language=language) @patch.object(os, "getenv") @patch.object(DataToolsHttpClient, "invoke") @@ -664,21 +584,14 @@ def test_execute_command_with_valid_command(self, mock_invoke, mock_getenv): mock_getenv.return_value = "test-api-key" # Act - self.code_interpreter_client.execute_command( - command=command - ) + self.code_interpreter_client.execute_command(command=command) # Assert mock_invoke.assert_called_once_with( code_interpreter_name="test-code-interpreter-name", session_id="test-session-id", api_key="test-api-key", - arguments={ - "operate_type": "execute_command", - "arguments": { - "command": command - } - } + arguments={"operate_type": "execute_command", "arguments": {"command": command}}, ) mock_getenv.assert_called_once_with("HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY") @@ -694,9 +607,7 @@ def test_execute_command_with_invalid_command(self, mock_getenv): # Act & Assert with pytest.raises(ValueError, match=error_message): - self.code_interpreter_client.execute_command( - command=command - ) + self.code_interpreter_client.execute_command(command=command) @patch.object(os, "getenv") @patch.object(DataToolsHttpClient, "invoke") @@ -711,9 +622,7 @@ def test_upload_file_with_text_content(self, mock_invoke, mock_getenv): text_content = "Hello, World!" description = "test file" self.code_interpreter_client.upload_file( - path=path, - content=text_content, - description=description + path=path, content=text_content, description=description ) # Assert @@ -723,15 +632,8 @@ def test_upload_file_with_text_content(self, mock_invoke, mock_getenv): api_key="test-api-key", arguments={ "operate_type": "write_files", - "arguments": { - "write_contents": [ - { - "path": path, - "text": text_content - } - ] - } - } + "arguments": {"write_contents": [{"path": path, "text": text_content}]}, + }, ) mock_getenv.assert_called_once_with("HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY") @@ -749,9 +651,7 @@ def test_upload_file_with_binary_content(self, mock_invoke, mock_getenv): encoded_content = base64.b64encode(binary_content).decode("utf-8") description = "test file" self.code_interpreter_client.upload_file( - path=path, - content=binary_content, - description=description + path=path, content=binary_content, description=description ) # Assert @@ -761,15 +661,8 @@ def test_upload_file_with_binary_content(self, mock_invoke, mock_getenv): api_key="test-api-key", arguments={ "operate_type": "write_files", - "arguments": { - "write_contents": [ - { - "path": path, - "blob": encoded_content - } - ] - } - } + "arguments": {"write_contents": [{"path": path, "blob": encoded_content}]}, + }, ) mock_getenv.assert_called_once_with("HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY") @@ -788,9 +681,7 @@ def test_upload_file_with_relative_path(self, mock_invoke, mock_getenv): # Act self.code_interpreter_client.upload_file( - path=path, - content=text_content, - description=description + path=path, content=text_content, description=description ) # Assert @@ -802,17 +693,13 @@ def test_upload_file_with_relative_path(self, mock_invoke, mock_getenv): "operate_type": "write_files", "arguments": { "write_contents": [ - { - "path": os.path.join("/home/user", path), - "text": text_content - } + {"path": os.path.join("/home/user", path), "text": text_content} ] - } - } + }, + }, ) mock_getenv.assert_called_once_with("HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY") - @patch.object(os, "getenv") @patch.object(DataToolsHttpClient, "invoke") def test_upload_files_with_mixed_content(self, mock_invoke, mock_getenv): @@ -856,18 +743,17 @@ def test_upload_files_with_mixed_content(self, mock_invoke, mock_getenv): "blob": encoded_content, }, ] - } - } + }, + }, ) mock_getenv.assert_called_once_with("HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY") - def test_upload_files_with_invalid_path(self): """测试upload_files方法,提供无效路径的情况""" # Arrange self.code_interpreter_client.code_interpreter_name = "test-code-interpreter-name" self.code_interpreter_client.session_id = "test-session-id" - error_message= "Invalid path. Path must start with /home/user" + error_message = "Invalid path. Path must start with /home/user" files = [ { "path": "/invalid/user/test.txt", @@ -883,7 +769,6 @@ def test_upload_files_with_invalid_path(self): with pytest.raises(ValueError, match=error_message): self.code_interpreter_client.upload_files(files=files) - @patch.object(os, "getenv") @patch.object(DataToolsHttpClient, "invoke") def test_download_file_with_text_content(self, mock_invoke, mock_getenv): @@ -891,21 +776,18 @@ def test_download_file_with_text_content(self, mock_invoke, mock_getenv): # Arrange text_content = "col1, col2\n1, 2\n3, 4" mock_invoke.return_value = { - "stream" : [ - { - "result": { - "contents": [ - { - "type": "resource", - "resource": { - "uri": "/home/user/data.csv", - "text": text_content - } - } - ] + "result": { + "content": [ + { + "type": "resource", + "resource": { + "type": "text", + "uri": "file:///home/user/data.csv", + "text": text_content, + }, } - } - ] + ] + } } self.code_interpreter_client.code_interpreter_name = "test-code-interpreter-name" self.code_interpreter_client.session_id = "test-session-id" @@ -920,14 +802,7 @@ def test_download_file_with_text_content(self, mock_invoke, mock_getenv): code_interpreter_name="test-code-interpreter-name", session_id="test-session-id", api_key="test-api-key", - arguments={ - "operate_type": "read_files", - "arguments": { - "paths": [ - path - ] - } - } + arguments={"operate_type": "read_files", "arguments": {"paths": [path]}}, ) assert response == text_content mock_getenv.assert_called_once_with("HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY") @@ -940,21 +815,18 @@ def test_download_file_with_binary_content(self, mock_invoke, mock_getenv): binary_content = b"\x89PNG\r\n\x1a\n" encode_content = base64.b64encode(binary_content).decode("utf-8") mock_invoke.return_value = { - "stream" : [ - { - "result": { - "contents": [ - { - "type": "resource", - "resource": { - "uri": "/home/user/image.png", - "blob": encode_content - } - } - ] + "result": { + "content": [ + { + "type": "resource", + "resource": { + "type": "blob", + "uri": "file:///home/user/image.png", + "blob": encode_content, + }, } - } - ] + ] + } } self.code_interpreter_client.code_interpreter_name = "test-code-interpreter-name" self.code_interpreter_client.session_id = "test-session-id" @@ -969,28 +841,21 @@ def test_download_file_with_binary_content(self, mock_invoke, mock_getenv): code_interpreter_name="test-code-interpreter-name", session_id="test-session-id", api_key="test-api-key", - arguments={ - "operate_type": "read_files", - "arguments": { - "paths": [ - path - ] - } - } + arguments={"operate_type": "read_files", "arguments": {"paths": [path]}}, ) assert response == binary_content mock_getenv.assert_called_once_with("HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY") + @patch.object(os, "getenv") @patch.object(DataToolsHttpClient, "invoke") - def test_download_file_with_no_found_file(self, mock_invoke): + def test_download_file_with_no_found_file(self, mock_invoke, mock_getenv): """测试download_file方法,提供不存在的文件的情况""" # Arrange - mock_invoke.return_value = { - "stream" : [] - } + mock_invoke.return_value = {} + mock_getenv.return_value = "test-api-key" self.code_interpreter_client.code_interpreter_name = "test-code-interpreter-name" self.code_interpreter_client.session_id = "test-session-id" - path = "/home/user/non-existent.csv" + path = "/home/user/no-existent.csv" error_message = f"Could not read file: {path}" # Act & Assert @@ -1015,28 +880,26 @@ def test_download_files_with_text_files(self, mock_invoke, mock_getenv): """测试download_files方法,提供文本文件的情况""" # Arrange mock_invoke.return_value = { - "stream" : [ - { - "result": { - "content": [ - { - "type": "resource", - "resource": { - "uri": "/home/user/data.csv", - "text": "col1, col2\n1, 2\n3, 4" - } - }, - { - "type": "resource", - "resource": { - "uri": "/home/user/config.json", - "text": '{"key": "value"}' - } - } - ] - } - } - ] + "result": { + "content": [ + { + "type": "resource", + "resource": { + "type": "text", + "uri": "file:///home/user/data.csv", + "text": "col1, col2\n1, 2\n3, 4", + }, + }, + { + "type": "resource", + "resource": { + "type": "text", + "uri": "file:///home/user/config.json", + "text": '{"key": "value"}', + }, + }, + ] + } } self.code_interpreter_client.code_interpreter_name = "test-code-interpreter-name" self.code_interpreter_client.session_id = "test-session-id" @@ -1044,20 +907,19 @@ def test_download_files_with_text_files(self, mock_invoke, mock_getenv): mock_getenv.return_value = "test-api-key" # Act - self.code_interpreter_client.download_files(paths) + result = self.code_interpreter_client.download_files(paths) # Assert mock_invoke.assert_called_once_with( code_interpreter_name="test-code-interpreter-name", session_id="test-session-id", api_key="test-api-key", - arguments={ - "operate_type": "read_files", - "arguments": { - "paths": paths - } - } + arguments={"operate_type": "read_files", "arguments": {"paths": paths}}, ) + assert result == { + "/home/user/data.csv": "col1, col2\n1, 2\n3, 4", + "/home/user/config.json": '{"key": "value"}', + } mock_getenv.assert_called_once_with("HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY") @patch.object(os, "getenv") @@ -1068,28 +930,26 @@ def test_download_files_with_binary_files(self, mock_invoke, mock_getenv): binary_content = b"\x89PNG\r\n\x1a\n" encode_binary = base64.b64encode(binary_content).decode("utf-8") mock_invoke.return_value = { - "stream" : [ - { - "result": { - "content": [ - { - "type": "resource", - "resource": { - "uri": "/home/user/iamge-1.png", - "blob": encode_binary - } - }, - { - "type": "resource", - "resource": { - "uri": "/home/user/image-2.png", - "blob": encode_binary - } - } - ] - } - } - ] + "result": { + "content": [ + { + "type": "resource", + "resource": { + "type": "blob", + "uri": "file:///home/user/iamge-1.png", + "blob": encode_binary, + }, + }, + { + "type": "resource", + "resource": { + "type": "blob", + "uri": "file:///home/user/image-2.png", + "blob": encode_binary, + }, + }, + ] + } } self.code_interpreter_client.code_interpreter_name = "test-code-interpreter-name" self.code_interpreter_client.session_id = "test-session-id" @@ -1097,20 +957,19 @@ def test_download_files_with_binary_files(self, mock_invoke, mock_getenv): mock_getenv.return_value = "test-api-key" # Act - self.code_interpreter_client.download_files(paths) + result = self.code_interpreter_client.download_files(paths) # Assert mock_invoke.assert_called_once_with( code_interpreter_name="test-code-interpreter-name", session_id="test-session-id", api_key="test-api-key", - arguments={ - "operate_type": "read_files", - "arguments": { - "paths": paths - } - } + arguments={"operate_type": "read_files", "arguments": {"paths": paths}}, ) + assert result == { + "/home/user/iamge-1.png": binary_content, + "/home/user/image-2.png": binary_content, + } mock_getenv.assert_called_once_with("HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY") def test_download_files_with_invalid_path(self): @@ -1125,7 +984,6 @@ def test_download_files_with_invalid_path(self): with pytest.raises(ValueError, match=error_message): self.code_interpreter_client.download_files(paths) - @patch.object(os, "getenv") @patch.object(DataToolsHttpClient, "invoke") def test_install_packages(self, mock_invoke, mock_getenv): @@ -1147,10 +1005,8 @@ def test_install_packages(self, mock_invoke, mock_getenv): api_key="test-api-key", arguments={ "operate_type": "execute_command", - "arguments": { - "command": f"pip install {' '.join(packages)} " - } - } + "arguments": {"command": f"pip install {' '.join(packages)} "}, + }, ) mock_getenv.assert_called_once_with("HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY") @@ -1175,10 +1031,8 @@ def test_install_packages_with_version(self, mock_invoke, mock_getenv): api_key="test-api-key", arguments={ "operate_type": "execute_command", - "arguments": { - "command": f"pip install {' '.join(packages)} " - } - } + "arguments": {"command": f"pip install {' '.join(packages)} "}, + }, ) mock_getenv.assert_called_once_with("HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY") @@ -1194,10 +1048,7 @@ def test_install_packages_with_upgrade(self, mock_invoke, mock_getenv): mock_getenv.return_value = "test-api-key" # Act - self.code_interpreter_client.install_packages( - packages=packages, - upgrade=True - ) + self.code_interpreter_client.install_packages(packages=packages, upgrade=True) # Assert mock_invoke.assert_called_once_with( @@ -1206,10 +1057,8 @@ def test_install_packages_with_upgrade(self, mock_invoke, mock_getenv): api_key="test-api-key", arguments={ "operate_type": "execute_command", - "arguments": { - "command": f"pip install {' '.join(packages)} --upgrade" - } - } + "arguments": {"command": f"pip install {' '.join(packages)} --upgrade"}, + }, ) mock_getenv.assert_called_once_with("HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY") @@ -1225,8 +1074,6 @@ def test_install_packages_with_invalid_package(self): with pytest.raises(ValueError, match=error_message): self.code_interpreter_client.install_packages(invalid_package) - - @patch.object(os, "getenv") @patch.object(DataToolsHttpClient, "invoke") def test_clear_context(self, mock_invoke, mock_getenv): @@ -1250,8 +1097,8 @@ def test_clear_context(self, mock_invoke, mock_getenv): "arguments": { "code": "# Context cleared", "language": "python", - "clear_context": True - } - } + "clear_context": True, + }, + }, ) mock_getenv.assert_called_once_with("HUAWEICLOUD_SDK_CODE_INTERPRETER_API_KEY")