From 05a0d8c13ed3788298fa13e4640f13fc45720a5e Mon Sep 17 00:00:00 2001 From: goodmorningsaksham Date: Tue, 28 Apr 2026 03:00:45 +0530 Subject: [PATCH 1/8] fix: preserve Observation.metadata in serialize_observation() metadata was excluded from model_dump() alongside reward and done, but unlike those two it was never re-added to the response dict, silently dropped during serialization. This promotes metadata to a top-level sibling key consistent with reward/done, adds it to StepResult, and updates GenericEnvClient to pass it through. Includes 7 new tests covering preservation, empty-metadata omission, round-trip through client, and subclass observation fields. --- src/openenv/core/client_types.py | 4 +- src/openenv/core/env_server/serialization.py | 15 ++- src/openenv/core/generic_client.py | 1 + tests/core/test_serialization.py | 123 +++++++++++++++++++ 4 files changed, 138 insertions(+), 5 deletions(-) create mode 100644 tests/core/test_serialization.py diff --git a/src/openenv/core/client_types.py b/src/openenv/core/client_types.py index c7501c656..8934f2395 100644 --- a/src/openenv/core/client_types.py +++ b/src/openenv/core/client_types.py @@ -1,6 +1,6 @@ # Type definitions for EnvTorch from dataclasses import dataclass -from typing import Generic, Optional, TypeVar +from typing import Any, Dict, Generic, Optional, TypeVar # Generic type for observations ObsT = TypeVar("ObsT") @@ -16,8 +16,10 @@ class StepResult(Generic[ObsT]): observation: The environment's observation after the action. reward: Scalar reward for this step (optional). done: Whether the episode is finished. + metadata: Optional dictionary of additional metadata from the environment. 
""" observation: ObsT reward: Optional[float] = None done: bool = False + metadata: Optional[Dict[str, Any]] = None diff --git a/src/openenv/core/env_server/serialization.py b/src/openenv/core/env_server/serialization.py index fd5fb588c..2750651d0 100644 --- a/src/openenv/core/env_server/serialization.py +++ b/src/openenv/core/env_server/serialization.py @@ -150,22 +150,29 @@ def serialize_observation(observation: Observation) -> Dict[str, Any]: "done": bool, } """ - # Use Pydantic's model_dump() for serialization + # Use Pydantic's model_dump() for serialization. + # reward and done are promoted to top-level sibling keys so that + # EnvClient._parse_result() can read them directly. metadata is also + # promoted so it is not lost during serialization. obs_dict = observation.model_dump( exclude={ "reward", "done", "metadata", - } # Exclude these from observation dict + } ) - # Extract reward and done directly from the observation + # Extract reward, done, and metadata directly from the observation reward = observation.reward done = observation.done + metadata = observation.metadata # Return in EnvClient expected format - return { + result = { "observation": obs_dict, "reward": reward, "done": done, } + if metadata: + result["metadata"] = metadata + return result \ No newline at end of file diff --git a/src/openenv/core/generic_client.py b/src/openenv/core/generic_client.py index 175768622..908ad64db 100644 --- a/src/openenv/core/generic_client.py +++ b/src/openenv/core/generic_client.py @@ -103,6 +103,7 @@ def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Dict[str, Any]]: observation=payload.get("observation", {}), reward=payload.get("reward"), done=payload.get("done", False), + metadata=payload.get("metadata"), ) def _parse_state(self, payload: Dict[str, Any]) -> Dict[str, Any]: diff --git a/tests/core/test_serialization.py b/tests/core/test_serialization.py new file mode 100644 index 000000000..9a924b103 --- /dev/null +++ 
b/tests/core/test_serialization.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for observation serialization, specifically metadata preservation. + +Ensures that Observation.metadata survives the serialize -> deserialize +round-trip through serialize_observation() and GenericEnvClient._parse_result(). +""" + +import pytest +from openenv.core.env_server.serialization import serialize_observation +from openenv.core.env_server.types import Observation +from openenv.core.generic_client import GenericEnvClient + + +# --------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------- + +class CustomObservation(Observation): + """Observation subclass with domain-specific fields.""" + ally_tree: str = "" + task_instruction: str = "" + + +# --------------------------------------------------------------------- +# serialize_observation tests +# --------------------------------------------------------------------- + +class TestSerializeObservation: + """Tests for serialize_observation().""" + + def test_metadata_preserved_in_serialized_output(self): + """Metadata must appear in the serialized dict, not be silently dropped.""" + obs = Observation( + done=False, + reward=0.5, + metadata={"total_nodes": 5, "task_id": 1}, + ) + result = serialize_observation(obs) + + assert "metadata" in result + assert result["metadata"]["total_nodes"] == 5 + assert result["metadata"]["task_id"] == 1 + + def test_empty_metadata_omitted(self): + """When metadata is empty, it should not clutter the response.""" + obs = Observation(done=False, reward=0.0, metadata={}) + result = serialize_observation(obs) + + assert "metadata" not in result + + def test_reward_and_done_promoted(self): + """reward and done must be 
top-level siblings, not inside observation.""" + obs = Observation(done=True, reward=1.0, metadata={"k": "v"}) + result = serialize_observation(obs) + + assert result["reward"] == 1.0 + assert result["done"] is True + assert "reward" not in result["observation"] + assert "done" not in result["observation"] + + def test_metadata_not_inside_observation(self): + """metadata must be a top-level sibling, not nested inside observation.""" + obs = Observation(done=False, reward=0.0, metadata={"step": 3}) + result = serialize_observation(obs) + + assert "metadata" not in result["observation"] + assert result["metadata"]["step"] == 3 + + def test_custom_observation_fields_in_observation_dict(self): + """Subclass fields must appear inside the observation dict.""" + obs = CustomObservation( + ally_tree="[ref=btn_1 role=button]", + task_instruction="Book a ticket", + done=False, + reward=0.2, + metadata={"variant": "label_drift"}, + ) + result = serialize_observation(obs) + + assert result["observation"]["ally_tree"] == "[ref=btn_1 role=button]" + assert result["observation"]["task_instruction"] == "Book a ticket" + assert result["metadata"]["variant"] == "label_drift" + + +# --------------------------------------------------------------------- +# Round-trip: serialize -> client parse +# --------------------------------------------------------------------- + +class TestMetadataRoundTrip: + """End-to-end: serialize on server, parse on client.""" + + def test_generic_client_receives_metadata(self): + """GenericEnvClient._parse_result must surface metadata from payload.""" + obs = Observation( + done=False, + reward=0.42, + metadata={"total_nodes": 6, "completed": ["origin", "dest"]}, + ) + payload = serialize_observation(obs) + + client = GenericEnvClient.__new__(GenericEnvClient) + step_result = client._parse_result(payload) + + assert step_result.reward == 0.42 + assert step_result.done is False + assert step_result.metadata is not None + assert 
step_result.metadata["total_nodes"] == 6 + assert step_result.metadata["completed"] == ["origin", "dest"] + + def test_generic_client_handles_missing_metadata(self): + """When server sends no metadata, StepResult.metadata should be None.""" + payload = {"observation": {"text": "hello"}, "reward": 0.0, "done": False} + + client = GenericEnvClient.__new__(GenericEnvClient) + step_result = client._parse_result(payload) + + assert step_result.metadata is None \ No newline at end of file From 9829cbff9cc8ef4089db367d8348e1026706bbe5 Mon Sep 17 00:00:00 2001 From: goodmorningsaksham Date: Tue, 28 Apr 2026 03:53:56 +0530 Subject: [PATCH 2/8] fix: add metadata field to ResetResponse/StepResponse for extra=forbid compat --- src/openenv/core/env_server/serialization.py | 2 +- src/openenv/core/env_server/types.py | 9 +++++++-- tests/core/test_serialization.py | 6 +++++- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/openenv/core/env_server/serialization.py b/src/openenv/core/env_server/serialization.py index 2750651d0..0eed9909c 100644 --- a/src/openenv/core/env_server/serialization.py +++ b/src/openenv/core/env_server/serialization.py @@ -175,4 +175,4 @@ def serialize_observation(observation: Observation) -> Dict[str, Any]: } if metadata: result["metadata"] = metadata - return result \ No newline at end of file + return result diff --git a/src/openenv/core/env_server/types.py b/src/openenv/core/env_server/types.py index 34a198013..de5d0c48d 100644 --- a/src/openenv/core/env_server/types.py +++ b/src/openenv/core/env_server/types.py @@ -121,6 +121,9 @@ class ResetResponse(BaseModel): done: bool = Field( default=False, description="Whether episode is already done (typically False)" ) + metadata: Optional[Dict[str, Any]] = Field( + default=None, description="Additional metadata from the environment" + ) class StepRequest(BaseModel): @@ -164,8 +167,10 @@ class StepResponse(BaseModel): default=None, description="Reward signal from the action" ) done: bool = 
Field(default=False, description="Whether the episode has terminated") - - + metadata: Optional[Dict[str, Any]] = Field( + default=None, description="Additional metadata from the environment" + ) + class BaseMessage(BaseModel): """Base class for WebSocket messages with shared configuration.""" diff --git a/tests/core/test_serialization.py b/tests/core/test_serialization.py index 9a924b103..6ce9d73d8 100644 --- a/tests/core/test_serialization.py +++ b/tests/core/test_serialization.py @@ -21,8 +21,10 @@ # Helpers # --------------------------------------------------------------------- + class CustomObservation(Observation): """Observation subclass with domain-specific fields.""" + ally_tree: str = "" task_instruction: str = "" @@ -31,6 +33,7 @@ class CustomObservation(Observation): # serialize_observation tests # --------------------------------------------------------------------- + class TestSerializeObservation: """Tests for serialize_observation().""" @@ -92,6 +95,7 @@ def test_custom_observation_fields_in_observation_dict(self): # Round-trip: serialize -> client parse # --------------------------------------------------------------------- + class TestMetadataRoundTrip: """End-to-end: serialize on server, parse on client.""" @@ -120,4 +124,4 @@ def test_generic_client_handles_missing_metadata(self): client = GenericEnvClient.__new__(GenericEnvClient) step_result = client._parse_result(payload) - assert step_result.metadata is None \ No newline at end of file + assert step_result.metadata is None From 5d1976ffd4bf919393edb4adaca74b3a59c62546 Mon Sep 17 00:00:00 2001 From: goodmorningsaksham Date: Tue, 28 Apr 2026 11:47:09 +0530 Subject: [PATCH 3/8] fix: trailing whitespace in types.py diff --- src/openenv/core/env_server/types.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/openenv/core/env_server/types.py b/src/openenv/core/env_server/types.py index de5d0c48d..0ef0df08c 100644 --- a/src/openenv/core/env_server/types.py +++ 
b/src/openenv/core/env_server/types.py @@ -170,7 +170,8 @@ class StepResponse(BaseModel): metadata: Optional[Dict[str, Any]] = Field( default=None, description="Additional metadata from the environment" ) - + + class BaseMessage(BaseModel): """Base class for WebSocket messages with shared configuration.""" From c879df5d5ac02ea056843edcbfae1e6ab346a5c6 Mon Sep 17 00:00:00 2001 From: goodmorningsaksham Date: Tue, 28 Apr 2026 18:21:41 +0530 Subject: [PATCH 4/8] fix: address review comments for metadata handling --- src/openenv/core/env_server/serialization.py | 2 ++ src/openenv/core/mcp_client.py | 1 + 2 files changed, 3 insertions(+) diff --git a/src/openenv/core/env_server/serialization.py b/src/openenv/core/env_server/serialization.py index 0eed9909c..64873106e 100644 --- a/src/openenv/core/env_server/serialization.py +++ b/src/openenv/core/env_server/serialization.py @@ -148,6 +148,8 @@ def serialize_observation(observation: Observation) -> Dict[str, Any]: "observation": {...}, # Observation fields "reward": float | None, "done": bool, + "metadata": dict, + } } """ # Use Pydantic's model_dump() for serialization. 
diff --git a/src/openenv/core/mcp_client.py b/src/openenv/core/mcp_client.py index 1d8bd38ef..2d5483e2c 100644 --- a/src/openenv/core/mcp_client.py +++ b/src/openenv/core/mcp_client.py @@ -300,6 +300,7 @@ def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Observation]: observation=observation, reward=payload.get("reward"), done=payload.get("done", False), + metadata=payload.get("metadata"), ) def _parse_state(self, payload: Dict[str, Any]) -> State: From df656a5ee9be73092672e6d80ae1b907c1f93715 Mon Sep 17 00:00:00 2001 From: goodmorningsaksham Date: Tue, 28 Apr 2026 21:39:06 +0530 Subject: [PATCH 5/8] fix: preserve metadata in serialization and add reset path coverage - Use `if metadata is not None:` in serialize_observation() to avoid lossy round-trip when metadata is an empty dict - Fix mcp_client._parse_result() to read metadata from top-level payload instead of nested observation dict - Add metadata field to StepResponse and ResetResponse Pydantic models - Add test_reset_metadata_preserved to cover reset path round-trip --- src/openenv/core/env_server/serialization.py | 2 +- src/openenv/core/mcp_client.py | 6 +++--- tests/core/test_serialization.py | 9 ++++++++- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/openenv/core/env_server/serialization.py b/src/openenv/core/env_server/serialization.py index 64873106e..67604492a 100644 --- a/src/openenv/core/env_server/serialization.py +++ b/src/openenv/core/env_server/serialization.py @@ -175,6 +175,6 @@ def serialize_observation(observation: Observation) -> Dict[str, Any]: "reward": reward, "done": done, } - if metadata: + if metadata is not None: result["metadata"] = metadata return result diff --git a/src/openenv/core/mcp_client.py b/src/openenv/core/mcp_client.py index 2d5483e2c..92771d039 100644 --- a/src/openenv/core/mcp_client.py +++ b/src/openenv/core/mcp_client.py @@ -272,7 +272,7 @@ def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Observation]: tools=tools, 
done=payload.get("done", False), reward=payload.get("reward"), - metadata=obs_data.get("metadata", {}), + metadata=payload.get("metadata", {}), ) # Check if this is a CallToolObservation elif "tool_name" in obs_data: @@ -286,14 +286,14 @@ def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Observation]: error=error, done=payload.get("done", False), reward=payload.get("reward"), - metadata=obs_data.get("metadata", {}), + metadata=payload.get("metadata", {}), ) else: # Generic observation observation = Observation( done=payload.get("done", False), reward=payload.get("reward"), - metadata=obs_data.get("metadata", {}), + metadata=payload.get("metadata", {}), ) return StepResult( diff --git a/tests/core/test_serialization.py b/tests/core/test_serialization.py index 6ce9d73d8..9efdc1ea0 100644 --- a/tests/core/test_serialization.py +++ b/tests/core/test_serialization.py @@ -13,7 +13,7 @@ import pytest from openenv.core.env_server.serialization import serialize_observation -from openenv.core.env_server.types import Observation +from openenv.core.env_server.types import Observation, ResetResponse from openenv.core.generic_client import GenericEnvClient @@ -90,6 +90,13 @@ def test_custom_observation_fields_in_observation_dict(self): assert result["observation"]["task_instruction"] == "Book a ticket" assert result["metadata"]["variant"] == "label_drift" + def test_reset_metadata_preserved(self): + """ResetResponse must preserve metadata from the observation.""" + obs = Observation(metadata={"reset_key": "val"}) + serialized = serialize_observation(obs) + reset_response = ResetResponse(**serialized) + assert reset_response.metadata == {"reset_key": "val"} + # --------------------------------------------------------------------- # Round-trip: serialize -> client parse From 449c48cae547995c0f881fb001f43da681543e78 Mon Sep 17 00:00:00 2001 From: goodmorningsaksham Date: Wed, 29 Apr 2026 02:24:10 +0530 Subject: [PATCH 6/8] fix: address review comments - metadata guard, 
mcp consistency, tests --- src/openenv/core/env_server/serialization.py | 2 +- src/openenv/core/mcp_client.py | 6 +++--- tests/core/test_serialization.py | 11 +++++++++-- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/openenv/core/env_server/serialization.py b/src/openenv/core/env_server/serialization.py index 67604492a..64873106e 100644 --- a/src/openenv/core/env_server/serialization.py +++ b/src/openenv/core/env_server/serialization.py @@ -175,6 +175,6 @@ def serialize_observation(observation: Observation) -> Dict[str, Any]: "reward": reward, "done": done, } - if metadata is not None: + if metadata: result["metadata"] = metadata return result diff --git a/src/openenv/core/mcp_client.py b/src/openenv/core/mcp_client.py index 92771d039..d54020f47 100644 --- a/src/openenv/core/mcp_client.py +++ b/src/openenv/core/mcp_client.py @@ -272,7 +272,7 @@ def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Observation]: tools=tools, done=payload.get("done", False), reward=payload.get("reward"), - metadata=payload.get("metadata", {}), + metadata=payload.get("metadata"), ) # Check if this is a CallToolObservation elif "tool_name" in obs_data: @@ -286,14 +286,14 @@ def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Observation]: error=error, done=payload.get("done", False), reward=payload.get("reward"), - metadata=payload.get("metadata", {}), + metadata=payload.get("metadata"), ) else: # Generic observation observation = Observation( done=payload.get("done", False), reward=payload.get("reward"), - metadata=payload.get("metadata", {}), + metadata=payload.get("metadata"), ) return StepResult( diff --git a/tests/core/test_serialization.py b/tests/core/test_serialization.py index 9efdc1ea0..b64a55841 100644 --- a/tests/core/test_serialization.py +++ b/tests/core/test_serialization.py @@ -13,7 +13,7 @@ import pytest from openenv.core.env_server.serialization import serialize_observation -from openenv.core.env_server.types import 
Observation, ResetResponse +from openenv.core.env_server.types import Observation, ResetResponse, StepResponse from openenv.core.generic_client import GenericEnvClient @@ -97,6 +97,13 @@ def test_reset_metadata_preserved(self): reset_response = ResetResponse(**serialized) assert reset_response.metadata == {"reset_key": "val"} + def test_step_response_metadata_preserved(self): + """StepResponse must preserve metadata from the observation.""" + obs = Observation(metadata={"step_key": "val"}) + serialized = serialize_observation(obs) + step_response = StepResponse(**serialized) + assert step_response.metadata == {"step_key": "val"} + # --------------------------------------------------------------------- # Round-trip: serialize -> client parse @@ -131,4 +138,4 @@ def test_generic_client_handles_missing_metadata(self): client = GenericEnvClient.__new__(GenericEnvClient) step_result = client._parse_result(payload) - assert step_result.metadata is None + assert step_result.metadata is None \ No newline at end of file From 1d992fab031b558de1be3d2ac6d8cd25e2ec07c7 Mon Sep 17 00:00:00 2001 From: goodmorningsaksham Date: Wed, 29 Apr 2026 09:25:49 +0530 Subject: [PATCH 7/8] fix: correct metadata fallbacks in mcp_client._parse_result --- src/openenv/core/mcp_client.py | 6 +++--- tests/core/test_serialization.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/openenv/core/mcp_client.py b/src/openenv/core/mcp_client.py index d54020f47..20e0090a5 100644 --- a/src/openenv/core/mcp_client.py +++ b/src/openenv/core/mcp_client.py @@ -272,7 +272,7 @@ def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Observation]: tools=tools, done=payload.get("done", False), reward=payload.get("reward"), - metadata=payload.get("metadata"), + metadata=payload.get("metadata") or {}, ) # Check if this is a CallToolObservation elif "tool_name" in obs_data: @@ -286,14 +286,14 @@ def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Observation]: error=error, 
done=payload.get("done", False), reward=payload.get("reward"), - metadata=payload.get("metadata"), + metadata=payload.get("metadata") or {}, ) else: # Generic observation observation = Observation( done=payload.get("done", False), reward=payload.get("reward"), - metadata=payload.get("metadata"), + metadata=payload.get("metadata") or {}, ) return StepResult( diff --git a/tests/core/test_serialization.py b/tests/core/test_serialization.py index b64a55841..b4499dd68 100644 --- a/tests/core/test_serialization.py +++ b/tests/core/test_serialization.py @@ -138,4 +138,4 @@ def test_generic_client_handles_missing_metadata(self): client = GenericEnvClient.__new__(GenericEnvClient) step_result = client._parse_result(payload) - assert step_result.metadata is None \ No newline at end of file + assert step_result.metadata is None From 5e15e3875c4ca6c429dcf4b181f3d5c234db3b3c Mon Sep 17 00:00:00 2001 From: goodmorningsaksham Date: Wed, 29 Apr 2026 21:26:44 +0530 Subject: [PATCH 8/8] fix: use is not None guard for metadata, remove or {} coercion, add falsy metadata test --- src/openenv/core/env_server/serialization.py | 2 +- src/openenv/core/mcp_client.py | 6 +++--- tests/core/test_serialization.py | 8 ++++++++ 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/openenv/core/env_server/serialization.py b/src/openenv/core/env_server/serialization.py index 64873106e..67604492a 100644 --- a/src/openenv/core/env_server/serialization.py +++ b/src/openenv/core/env_server/serialization.py @@ -175,6 +175,6 @@ def serialize_observation(observation: Observation) -> Dict[str, Any]: "reward": reward, "done": done, } - if metadata: + if metadata is not None: result["metadata"] = metadata return result diff --git a/src/openenv/core/mcp_client.py b/src/openenv/core/mcp_client.py index 20e0090a5..d54020f47 100644 --- a/src/openenv/core/mcp_client.py +++ b/src/openenv/core/mcp_client.py @@ -272,7 +272,7 @@ def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Observation]: 
tools=tools, done=payload.get("done", False), reward=payload.get("reward"), - metadata=payload.get("metadata") or {}, + metadata=payload.get("metadata"), ) # Check if this is a CallToolObservation elif "tool_name" in obs_data: @@ -286,14 +286,14 @@ def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Observation]: error=error, done=payload.get("done", False), reward=payload.get("reward"), - metadata=payload.get("metadata") or {}, + metadata=payload.get("metadata"), ) else: # Generic observation observation = Observation( done=payload.get("done", False), reward=payload.get("reward"), - metadata=payload.get("metadata") or {}, + metadata=payload.get("metadata"), ) return StepResult( diff --git a/tests/core/test_serialization.py b/tests/core/test_serialization.py index b4499dd68..e2bab2882 100644 --- a/tests/core/test_serialization.py +++ b/tests/core/test_serialization.py @@ -57,6 +57,14 @@ def test_empty_metadata_omitted(self): assert "metadata" not in result + def test_falsy_metadata_preserved(self): + """Falsy metadata values must not be silently dropped.""" + obs = Observation(metadata={"active": False, "count": 0, "flag": ""}) + result = serialize_observation(obs) + + assert "metadata" in result + assert result["metadata"] == {"active": False, "count": 0, "flag": ""} + def test_reward_and_done_promoted(self): """reward and done must be top-level siblings, not inside observation.""" obs = Observation(done=True, reward=1.0, metadata={"k": "v"})