Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/openenv/core/client_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Type definitions for EnvTorch
from dataclasses import dataclass
from typing import Generic, Optional, TypeVar
from typing import Any, Dict, Generic, Optional, TypeVar

# Generic type for observations
ObsT = TypeVar("ObsT")
Expand All @@ -16,8 +16,10 @@ class StepResult(Generic[ObsT]):
observation: The environment's observation after the action.
reward: Scalar reward for this step (optional).
done: Whether the episode is finished.
metadata: Optional dictionary of additional metadata from the environment.
"""

observation: ObsT
reward: Optional[float] = None
done: bool = False
metadata: Optional[Dict[str, Any]] = None
17 changes: 13 additions & 4 deletions src/openenv/core/env_server/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,24 +148,33 @@ def serialize_observation(observation: Observation) -> Dict[str, Any]:
"observation": {...}, # Observation fields
"reward": float | None,
"done": bool,
"metadata": dict,
}
}
"""
# Use Pydantic's model_dump() for serialization
# Use Pydantic's model_dump() for serialization.
# reward and done are promoted to top-level sibling keys so that
# EnvClient._parse_result() can read them directly. metadata is also
# promoted so it is not lost during serialization.
obs_dict = observation.model_dump(
exclude={
"reward",
"done",
"metadata",
} # Exclude these from observation dict
}
)

# Extract reward and done directly from the observation
# Extract reward, done, and metadata directly from the observation
reward = observation.reward
done = observation.done
metadata = observation.metadata

# Return in EnvClient expected format
return {
result = {
"observation": obs_dict,
"reward": reward,
"done": done,
}
if metadata is not None:
result["metadata"] = metadata
return result
6 changes: 6 additions & 0 deletions src/openenv/core/env_server/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ class ResetResponse(BaseModel):
done: bool = Field(
default=False, description="Whether episode is already done (typically False)"
)
metadata: Optional[Dict[str, Any]] = Field(
default=None, description="Additional metadata from the environment"
)


class StepRequest(BaseModel):
Expand Down Expand Up @@ -164,6 +167,9 @@ class StepResponse(BaseModel):
default=None, description="Reward signal from the action"
)
done: bool = Field(default=False, description="Whether the episode has terminated")
metadata: Optional[Dict[str, Any]] = Field(
default=None, description="Additional metadata from the environment"
)


class BaseMessage(BaseModel):
Expand Down
1 change: 1 addition & 0 deletions src/openenv/core/generic_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Dict[str, Any]]:
observation=payload.get("observation", {}),
reward=payload.get("reward"),
done=payload.get("done", False),
metadata=payload.get("metadata"),
)

def _parse_state(self, payload: Dict[str, Any]) -> Dict[str, Any]:
Expand Down
7 changes: 4 additions & 3 deletions src/openenv/core/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Observation]:
tools=tools,
done=payload.get("done", False),
reward=payload.get("reward"),
metadata=obs_data.get("metadata", {}),
metadata=payload.get("metadata"),
)
# Check if this is a CallToolObservation
elif "tool_name" in obs_data:
Expand All @@ -286,20 +286,21 @@ def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Observation]:
error=error,
done=payload.get("done", False),
reward=payload.get("reward"),
metadata=obs_data.get("metadata", {}),
metadata=payload.get("metadata"),
)
else:
# Generic observation
observation = Observation(
done=payload.get("done", False),
reward=payload.get("reward"),
metadata=obs_data.get("metadata", {}),
metadata=payload.get("metadata"),
)

return StepResult(
observation=observation,
reward=payload.get("reward"),
done=payload.get("done", False),
metadata=payload.get("metadata"),
)

def _parse_state(self, payload: Dict[str, Any]) -> State:
Expand Down
149 changes: 149 additions & 0 deletions tests/core/test_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Tests for observation serialization, specifically metadata preservation.

Ensures that Observation.metadata survives the serialize -> deserialize
round-trip through serialize_observation() and GenericEnvClient._parse_result().
"""

import pytest
from openenv.core.env_server.serialization import serialize_observation
from openenv.core.env_server.types import Observation, ResetResponse, StepResponse
from openenv.core.generic_client import GenericEnvClient


# ---------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------


class CustomObservation(Observation):
"""Observation subclass with domain-specific fields."""

ally_tree: str = ""
task_instruction: str = ""


# ---------------------------------------------------------------------
# serialize_observation tests
# ---------------------------------------------------------------------


class TestSerializeObservation:
"""Tests for serialize_observation()."""

def test_metadata_preserved_in_serialized_output(self):
"""Metadata must appear in the serialized dict, not be silently dropped."""
obs = Observation(
done=False,
reward=0.5,
metadata={"total_nodes": 5, "task_id": 1},
)
result = serialize_observation(obs)

assert "metadata" in result
assert result["metadata"]["total_nodes"] == 5
assert result["metadata"]["task_id"] == 1

def test_empty_metadata_omitted(self):
"""When metadata is empty, it should not clutter the response."""
obs = Observation(done=False, reward=0.0, metadata={})
result = serialize_observation(obs)

assert "metadata" not in result

def test_falsy_metadata_preserved(self):
"""Falsy metadata values must not be silently dropped."""
obs = Observation(metadata={"active": False, "count": 0, "flag": ""})
result = serialize_observation(obs)

assert "metadata" in result
assert result["metadata"] == {"active": False, "count": 0, "flag": ""}

def test_reward_and_done_promoted(self):
"""reward and done must be top-level siblings, not inside observation."""
obs = Observation(done=True, reward=1.0, metadata={"k": "v"})
result = serialize_observation(obs)

assert result["reward"] == 1.0
assert result["done"] is True
assert "reward" not in result["observation"]
assert "done" not in result["observation"]

def test_metadata_not_inside_observation(self):
"""metadata must be a top-level sibling, not nested inside observation."""
obs = Observation(done=False, reward=0.0, metadata={"step": 3})
result = serialize_observation(obs)

assert "metadata" not in result["observation"]
assert result["metadata"]["step"] == 3

def test_custom_observation_fields_in_observation_dict(self):
"""Subclass fields must appear inside the observation dict."""
obs = CustomObservation(
ally_tree="[ref=btn_1 role=button]",
task_instruction="Book a ticket",
done=False,
reward=0.2,
metadata={"variant": "label_drift"},
)
result = serialize_observation(obs)

assert result["observation"]["ally_tree"] == "[ref=btn_1 role=button]"
assert result["observation"]["task_instruction"] == "Book a ticket"
assert result["metadata"]["variant"] == "label_drift"

def test_reset_metadata_preserved(self):
"""ResetResponse must preserve metadata from the observation."""
obs = Observation(metadata={"reset_key": "val"})
serialized = serialize_observation(obs)
reset_response = ResetResponse(**serialized)
assert reset_response.metadata == {"reset_key": "val"}

def test_step_response_metadata_preserved(self):
"""StepResponse must preserve metadata from the observation."""
obs = Observation(metadata={"step_key": "val"})
serialized = serialize_observation(obs)
step_response = StepResponse(**serialized)
assert step_response.metadata == {"step_key": "val"}


# ---------------------------------------------------------------------
# Round-trip: serialize -> client parse
# ---------------------------------------------------------------------


class TestMetadataRoundTrip:
"""End-to-end: serialize on server, parse on client."""

def test_generic_client_receives_metadata(self):
"""GenericEnvClient._parse_result must surface metadata from payload."""
obs = Observation(
done=False,
reward=0.42,
metadata={"total_nodes": 6, "completed": ["origin", "dest"]},
)
payload = serialize_observation(obs)

client = GenericEnvClient.__new__(GenericEnvClient)
step_result = client._parse_result(payload)

assert step_result.reward == 0.42
assert step_result.done is False
assert step_result.metadata is not None
assert step_result.metadata["total_nodes"] == 6
assert step_result.metadata["completed"] == ["origin", "dest"]

def test_generic_client_handles_missing_metadata(self):
"""When server sends no metadata, StepResult.metadata should be None."""
payload = {"observation": {"text": "hello"}, "reward": 0.0, "done": False}

client = GenericEnvClient.__new__(GenericEnvClient)
step_result = client._parse_result(payload)

assert step_result.metadata is None