-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclient.py
More file actions
136 lines (111 loc) · 5.1 KB
/
Copy pathclient.py
File metadata and controls
136 lines (111 loc) · 5.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""QED Math Environment Client.

Provides tool-calling style interactions with the QED Math environment
via MCP (Model Context Protocol).

Example:
    >>> with QEDMathEnv(base_url="http://localhost:8000") as env:
    ...     env.reset()
    ...     tools = env.list_tools()
    ...     print([t.name for t in tools])
    ...     result = env.call_tool("get_problem")
    ...     result = env.call_tool("submit_proof", proof="By induction...")

NOTE(review): the methods this module defines on QEDMathEnv (``reset``,
``submit_proof``, ``get_problem``, ...) are coroutines (``async def``), so
the synchronous-looking example above presumably needs ``await`` or the
``sync()`` wrapper used by ``get_state_sync`` -- confirm against
MCPToolClient before relying on it.
"""
from typing import Any, Mapping, Optional
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import Observation, State
from openenv.core.mcp_client import MCPToolClient
from models import ProblemObservation, ProofSubmissionObservation
class QEDMathEnv(MCPToolClient):
    """
    Client for the QED Math Environment.

    Inherits MCP tool-calling interface from MCPToolClient:
    - ``list_tools()``: Discover available MCP tools
    - ``call_tool(name, **kwargs)``: Call a tool by name
    - ``reset(**kwargs)``: Reset the environment

    On top of the inherited interface this class adds typed convenience
    wrappers (``submit_proof``, ``get_current_problem``,
    ``get_grading_feedback``) that call the matching MCP tool and normalize
    the raw payload into the model classes imported from ``models``.

    Example:
        >>> with QEDMathEnv(base_url="http://localhost:8000") as env:
        ...     env.reset()
        ...     result = env.call_tool("get_problem")
        ...     result = env.call_tool("submit_proof", proof="By induction...")

    NOTE(review): the methods defined here are coroutines, so the example
    above presumably needs ``await`` or the ``sync()`` wrapper -- confirm
    against MCPToolClient.
    """

    @staticmethod
    def _coerce_payload(value: Any, model: type, label: str) -> Any:
        """Convert ``value`` into an instance of ``model``.

        Accepted inputs, checked in this order:
        an existing ``model`` instance (returned unchanged), a ``Mapping``
        (expanded as keyword arguments), or any object exposing
        ``model_dump()`` (its dump is expanded as keyword arguments).

        Args:
            value: Raw payload returned by a tool call or reset.
            model: Target class to instantiate.
            label: Human-readable payload name used in the error message.

        Raises:
            TypeError: If ``value`` matches none of the accepted shapes.
        """
        if isinstance(value, model):
            return value
        if isinstance(value, Mapping):
            return model(**dict(value))
        if hasattr(value, "model_dump"):
            return model(**value.model_dump())
        raise TypeError(f"Unsupported {label} payload type: {type(value).__name__}")

    @staticmethod
    def _as_problem_observation(value: Any) -> ProblemObservation:
        """Normalize tool/reset outputs into a ProblemObservation instance.

        Raises:
            TypeError: If ``value`` is not a ProblemObservation, a Mapping,
                or an object with ``model_dump()``.
        """
        return QEDMathEnv._coerce_payload(
            value, ProblemObservation, "problem observation"
        )

    @staticmethod
    def _as_proof_submission_observation(value: Any) -> ProofSubmissionObservation:
        """Normalize tool outputs into a ProofSubmissionObservation instance.

        Raises:
            TypeError: If ``value`` is not a ProofSubmissionObservation, a
                Mapping, or an object with ``model_dump()``.
        """
        return QEDMathEnv._coerce_payload(
            value, ProofSubmissionObservation, "proof submission"
        )

    async def reset(
        self, problem_id: Optional[str] = None, **kwargs: Any
    ) -> StepResult[Observation]:
        """
        Reset the environment, optionally selecting a specific problem.

        Args:
            problem_id: Optional problem identifier to load a specific problem.
                When None, no ``problem_id`` is forwarded to the base reset.
            **kwargs: Additional reset parameters (e.g., seed).

        Returns:
            StepResult with a normalized ProblemObservation in `observation`.

        Raises:
            TypeError: If the reset payload cannot be normalized into a
                ProblemObservation.
        """
        if problem_id is not None:
            kwargs["problem_id"] = problem_id

        result = await super().reset(**kwargs)

        # The base reset may hand back either a StepResult or a bare
        # observation payload. Derive reward/done in the same branch as the
        # observation so the bare-payload path does not fail on missing
        # ``reward``/``done`` attributes.
        if isinstance(result, StepResult):
            observation = result.observation
            reward = result.reward
            done = result.done
        elif isinstance(result, Mapping):
            observation = result
            reward = result.get("reward")
            done = bool(result.get("done", False))
        else:
            observation = result
            reward = getattr(result, "reward", None)
            done = getattr(result, "done", False)

        return StepResult(
            observation=self._as_problem_observation(observation),
            reward=reward,
            done=done,
        )

    async def submit_proof(self, proof: str) -> ProofSubmissionObservation:
        """
        Submit a proof attempt for the current problem.

        Calls the ``submit_proof`` MCP tool and normalizes its payload.

        Args:
            proof: The proof text to submit for grading.

        Returns:
            ProofSubmissionObservation with score (0-7), feedback, and reward.

        Raises:
            TypeError: If the tool payload cannot be normalized.
        """
        result = await self.call_tool("submit_proof", proof=proof)
        return self._as_proof_submission_observation(result)

    async def get_current_problem(self) -> ProblemObservation:
        """
        Retrieve the current problem statement without resetting.

        Calls the ``get_problem`` MCP tool and normalizes its payload.

        Returns:
            ProblemObservation for the active problem.

        Raises:
            TypeError: If the tool payload cannot be normalized.
        """
        result = await self.call_tool("get_problem")
        return self._as_problem_observation(result)

    async def get_problem(self) -> ProblemObservation:
        """Compatibility alias for get_current_problem()."""
        return await self.get_current_problem()

    async def get_grading_feedback(self) -> dict[str, Any]:
        """
        Retrieve the grading guidelines/rubric for the current problem.

        Calls the ``get_grading_guidelines`` MCP tool.

        Returns:
            Tool payload containing grading_guidelines and problem metadata,
            as a plain dict.

        Raises:
            TypeError: If the payload is neither a Mapping nor an object
                with ``model_dump()``.
        """
        result = await self.call_tool("get_grading_guidelines")
        if isinstance(result, Mapping):
            return dict(result)
        if hasattr(result, "model_dump"):
            return result.model_dump()
        raise TypeError(f"Unsupported grading feedback payload type: {type(result).__name__}")

    async def get_state(self) -> State:
        """Return current environment state (episode_id, step_count)."""
        return await super().state()

    def get_state_sync(self) -> State:
        """Synchronous helper for code paths that do not use async/await."""
        # ``sync()`` comes from the base client; it is used here as a context
        # manager yielding an object whose ``state()`` is called without await.
        with self.sync() as client:
            return client.state()