diff --git a/src/openenv/core/client_types.py b/src/openenv/core/client_types.py index c7501c656..8934f2395 100644 --- a/src/openenv/core/client_types.py +++ b/src/openenv/core/client_types.py @@ -1,6 +1,6 @@ # Type definitions for EnvTorch from dataclasses import dataclass -from typing import Generic, Optional, TypeVar +from typing import Any, Dict, Generic, Optional, TypeVar # Generic type for observations ObsT = TypeVar("ObsT") @@ -16,8 +16,10 @@ class StepResult(Generic[ObsT]): observation: The environment's observation after the action. reward: Scalar reward for this step (optional). done: Whether the episode is finished. + metadata: Optional dictionary of additional metadata from the environment. """ observation: ObsT reward: Optional[float] = None done: bool = False + metadata: Optional[Dict[str, Any]] = None diff --git a/src/openenv/core/env_server/serialization.py b/src/openenv/core/env_server/serialization.py index fd5fb588c..67604492a 100644 --- a/src/openenv/core/env_server/serialization.py +++ b/src/openenv/core/env_server/serialization.py @@ -148,24 +148,33 @@ def serialize_observation(observation: Observation) -> Dict[str, Any]: "observation": {...}, # Observation fields "reward": float | None, "done": bool, + "metadata": dict, + } } """ - # Use Pydantic's model_dump() for serialization + # Use Pydantic's model_dump() for serialization. + # reward and done are promoted to top-level sibling keys so that + # EnvClient._parse_result() can read them directly. metadata is also + # promoted so it is not lost during serialization. obs_dict = observation.model_dump( exclude={ "reward", "done", "metadata", - } # Exclude these from observation dict + } ) - # Extract reward and done directly from the observation + # Extract reward, done, and metadata directly from the observation reward = observation.reward done = observation.done + metadata = observation.metadata # Return in EnvClient expected format - return { + result = { "observation": obs_dict, "reward": reward, "done": done, } + if metadata is not None: + result["metadata"] = metadata + return result diff --git a/src/openenv/core/env_server/types.py b/src/openenv/core/env_server/types.py index 34a198013..0ef0df08c 100644 --- a/src/openenv/core/env_server/types.py +++ b/src/openenv/core/env_server/types.py @@ -121,6 +121,9 @@ class ResetResponse(BaseModel): done: bool = Field( default=False, description="Whether episode is already done (typically False)" ) + metadata: Optional[Dict[str, Any]] = Field( + default=None, description="Additional metadata from the environment" + ) class StepRequest(BaseModel): @@ -164,6 +167,9 @@ class StepResponse(BaseModel): default=None, description="Reward signal from the action" ) done: bool = Field(default=False, description="Whether the episode has terminated") + metadata: Optional[Dict[str, Any]] = Field( + default=None, description="Additional metadata from the environment" + ) class BaseMessage(BaseModel): diff --git a/src/openenv/core/generic_client.py b/src/openenv/core/generic_client.py index 175768622..908ad64db 100644 --- a/src/openenv/core/generic_client.py +++ b/src/openenv/core/generic_client.py @@ -103,6 +103,7 @@ def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Dict[str, Any]]: observation=payload.get("observation", {}), reward=payload.get("reward"), done=payload.get("done", False), + metadata=payload.get("metadata"), ) def _parse_state(self, payload: Dict[str, Any]) -> Dict[str, Any]: diff --git a/src/openenv/core/mcp_client.py b/src/openenv/core/mcp_client.py index 1d8bd38ef..d54020f47 100644 --- a/src/openenv/core/mcp_client.py +++ b/src/openenv/core/mcp_client.py @@ -272,7 +272,7 @@ def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Observation]: tools=tools, done=payload.get("done", False), reward=payload.get("reward"), - metadata=obs_data.get("metadata", {}), + metadata=payload.get("metadata"), ) # Check if this is a CallToolObservation elif "tool_name" in obs_data: @@ -286,20 +286,21 @@ def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Observation]: error=error, done=payload.get("done", False), reward=payload.get("reward"), - metadata=obs_data.get("metadata", {}), + metadata=payload.get("metadata"), ) else: # Generic observation observation = Observation( done=payload.get("done", False), reward=payload.get("reward"), - metadata=obs_data.get("metadata", {}), + metadata=payload.get("metadata"), ) return StepResult( observation=observation, reward=payload.get("reward"), done=payload.get("done", False), + metadata=payload.get("metadata"), ) def _parse_state(self, payload: Dict[str, Any]) -> State: diff --git a/tests/core/test_serialization.py b/tests/core/test_serialization.py new file mode 100644 index 000000000..e2bab2882 --- /dev/null +++ b/tests/core/test_serialization.py @@ -0,0 +1,149 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for observation serialization, specifically metadata preservation. + +Ensures that Observation.metadata survives the serialize -> deserialize +round-trip through serialize_observation() and GenericEnvClient._parse_result(). +""" + +import pytest +from openenv.core.env_server.serialization import serialize_observation +from openenv.core.env_server.types import Observation, ResetResponse, StepResponse +from openenv.core.generic_client import GenericEnvClient + + +# --------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------- + + +class CustomObservation(Observation): + """Observation subclass with domain-specific fields.""" + + ally_tree: str = "" + task_instruction: str = "" + + +# --------------------------------------------------------------------- +# serialize_observation tests +# --------------------------------------------------------------------- + + +class TestSerializeObservation: + """Tests for serialize_observation().""" + + def test_metadata_preserved_in_serialized_output(self): + """Metadata must appear in the serialized dict, not be silently dropped.""" + obs = Observation( + done=False, + reward=0.5, + metadata={"total_nodes": 5, "task_id": 1}, + ) + result = serialize_observation(obs) + + assert "metadata" in result + assert result["metadata"]["total_nodes"] == 5 + assert result["metadata"]["task_id"] == 1 + + def test_empty_metadata_omitted(self): + """When metadata is empty, it should not clutter the response.""" + obs = Observation(done=False, reward=0.0, metadata={}) + result = serialize_observation(obs) + + assert "metadata" not in result + + def test_falsy_metadata_preserved(self): + """Falsy metadata values must not be silently dropped.""" + obs = Observation(metadata={"active": False, "count": 0, "flag": ""}) + result = serialize_observation(obs) + + assert "metadata" in result + assert result["metadata"] == {"active": False, "count": 0, "flag": ""} + + def test_reward_and_done_promoted(self): + """reward and done must be top-level siblings, not inside observation.""" + obs = Observation(done=True, reward=1.0, metadata={"k": "v"}) + result = serialize_observation(obs) + + assert result["reward"] == 1.0 + assert result["done"] is True + assert "reward" not in result["observation"] + assert "done" not in result["observation"] + + def test_metadata_not_inside_observation(self): + """metadata must be a top-level sibling, not nested inside observation.""" + obs = Observation(done=False, reward=0.0, metadata={"step": 3}) + result = serialize_observation(obs) + + assert "metadata" not in result["observation"] + assert result["metadata"]["step"] == 3 + + def test_custom_observation_fields_in_observation_dict(self): + """Subclass fields must appear inside the observation dict.""" + obs = CustomObservation( + ally_tree="[ref=btn_1 role=button]", + task_instruction="Book a ticket", + done=False, + reward=0.2, + metadata={"variant": "label_drift"}, + ) + result = serialize_observation(obs) + + assert result["observation"]["ally_tree"] == "[ref=btn_1 role=button]" + assert result["observation"]["task_instruction"] == "Book a ticket" + assert result["metadata"]["variant"] == "label_drift" + + def test_reset_metadata_preserved(self): + """ResetResponse must preserve metadata from the observation.""" + obs = Observation(metadata={"reset_key": "val"}) + serialized = serialize_observation(obs) + reset_response = ResetResponse(**serialized) + assert reset_response.metadata == {"reset_key": "val"} + + def test_step_response_metadata_preserved(self): + """StepResponse must preserve metadata from the observation.""" + obs = Observation(metadata={"step_key": "val"}) + serialized = serialize_observation(obs) + step_response = StepResponse(**serialized) + assert step_response.metadata == {"step_key": "val"} + + +# --------------------------------------------------------------------- +# Round-trip: serialize -> client parse +# --------------------------------------------------------------------- + + +class TestMetadataRoundTrip: + """End-to-end: serialize on server, parse on client.""" + + def test_generic_client_receives_metadata(self): + """GenericEnvClient._parse_result must surface metadata from payload.""" + obs = Observation( + done=False, + reward=0.42, + metadata={"total_nodes": 6, "completed": ["origin", "dest"]}, + ) + payload = serialize_observation(obs) + + client = GenericEnvClient.__new__(GenericEnvClient) + step_result = client._parse_result(payload) + + assert step_result.reward == 0.42 + assert step_result.done is False + assert step_result.metadata is not None + assert step_result.metadata["total_nodes"] == 6 + assert step_result.metadata["completed"] == ["origin", "dest"] + + def test_generic_client_handles_missing_metadata(self): + """When server sends no metadata, StepResult.metadata should be None.""" + payload = {"observation": {"text": "hello"}, "reward": 0.0, "done": False} + + client = GenericEnvClient.__new__(GenericEnvClient) + step_result = client._parse_result(payload) + + assert step_result.metadata is None