Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 86 additions & 38 deletions src/openenv/core/env_server/gradio_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@ def _format_observation(data: Dict[str, Any]) -> str:
"""Format reset/step response for Markdown display."""
lines: List[str] = []
obs = data.get("observation", {})

if isinstance(obs, dict):
if obs.get("prompt"):
lines.append(f"**Prompt:**\n\n{_escape_md(obs['prompt'])}\n")

messages = obs.get("messages", [])
if messages:
lines.append("**Messages:**\n")
Expand All @@ -43,12 +45,15 @@ def _format_observation(data: Dict[str, Any]) -> str:
cat = _escape_md(str(msg.get("category", "")))
lines.append(f"- `[{cat}]` Player {sender}: {content}")
lines.append("")

reward = data.get("reward")
done = data.get("done")

if reward is not None:
lines.append(f"**Reward:** `{reward}`")
if done is not None:
lines.append(f"**Done:** `{done}`")

return "\n".join(lines) if lines else "*No observation data*"


Expand Down Expand Up @@ -84,49 +89,56 @@ def build_gradio_app(
action_fields: Field dicts from _extract_action_fields(action_cls).
metadata: Environment metadata for README/name.
is_chat_env: If True, single message textbox; else form from action_fields.
title: App title (overridden by metadata.name when present; see get_gradio_display_title).
quick_start_md: Optional Quick Start markdown (class names already replaced).
title: App title (overridden by metadata.name when present).
quick_start_md: Optional Quick Start markdown.

Returns:
gr.Blocks to mount with gr.mount_gradio_app(app, blocks, path="/web").
"""

readme_content = _readme_section(metadata)
display_title = get_gradio_display_title(metadata, fallback=title)

async def reset_env():
def clear_inputs() -> List[Any]:
    """Build the blank value for every action input component.

    Returns:
        ``[""]`` when ``is_chat_env`` is true (one empty string for the
        single message textbox). Otherwise one entry per item in
        ``action_fields``, in order: ``False`` for fields whose ``type``
        is ``"checkbox"`` and ``None`` for every other field.
    """
    if not is_chat_env:
        cleared: List[Any] = []
        for spec in action_fields:
            is_checkbox = spec.get("type") == "checkbox"
            # Checkboxes reset to unchecked; every other field resets to None.
            cleared.append(False if is_checkbox else None)
        return cleared

    return [""]

async def step(action_data: Dict[str, Any]):
"""Execute a single environment step and format the response."""
try:
data = await web_manager.reset_environment()
data = await web_manager.step_environment(action_data)
obs_md = _format_observation(data)

return (
obs_md,
json.dumps(data, indent=2),
"Environment reset successfully.",
"Step complete.",
)
except Exception as e:
return ("", "", f"Error: {e}")

def _step_with_action(action_data: Dict[str, Any]):
async def _run():
try:
data = await web_manager.step_environment(action_data)
obs_md = _format_observation(data)
return (
obs_md,
json.dumps(data, indent=2),
"Step complete.",
)
except Exception as e:
return ("", "", f"Error: {e}")

return _run
async def reset_env():
"""Reset the environment and return the initial observation."""
try:
data = await web_manager.reset_environment()
obs_md = _format_observation(data)

async def step_chat(message: str):
if not (message or str(message).strip()):
return ("", "", "Please enter an action message.")
action = {"message": str(message).strip()}
return await _step_with_action(action)()
return (
obs_md,
json.dumps(data, indent=2),
"Environment reset successfully.",
*clear_inputs(),
)
except Exception as e:
return ("", "", f"Error: {e}", *clear_inputs())

def get_state_sync():
"""Fetch the current environment state synchronously."""
try:
data = web_manager.get_state()
return json.dumps(data, indent=2)
Expand All @@ -135,32 +147,51 @@ def get_state_sync():

with gr.Blocks(title=display_title) as demo:
with gr.Row():

with gr.Column(scale=1, elem_classes="col-left"):
if quick_start_md:
with gr.Accordion("Quick Start", open=True):
gr.Markdown(quick_start_md)

with gr.Accordion("README", open=False):
gr.Markdown(readme_content)

with gr.Column(scale=2, elem_classes="col-right"):
obs_display = gr.Markdown(
value=("# Playground\n\nClick **Reset** to start a new episode."),
value="# Playground\n\nClick **Reset** to start a new episode."
)

with gr.Group():

if is_chat_env:
action_input = gr.Textbox(
label="Action message",
placeholder="e.g. Enter your message...",
)

step_inputs = [action_input]

async def step_chat(message: str):
    """Run one step from the chat textbox and blank the textbox afterwards.

    Args:
        message: Raw text taken from the action message input.

    Returns:
        A 4-tuple ``(observation markdown, raw JSON, status text, "")``.
        The trailing empty string is the new value for the message input.
        When ``message`` is empty or whitespace-only, no step is run and
        the tuple is ``("", "", "Please enter an action message.", "")``.
    """
    text = str(message).strip() if message else ""
    if not text:
        return ("", "", "Please enter an action message.", "")

    # ``step`` returns (observation markdown, raw JSON, status text).
    step_result = await step({"message": text})
    observation_md, response_json, status_msg = step_result

    return (observation_md, response_json, status_msg, "")

step_fn = step_chat

else:
step_inputs = []

for field in action_fields:
name = field["name"]
field_type = field.get("type", "text")
label = name.replace("_", " ").title()
placeholder = field.get("placeholder", "")

if field_type == "checkbox":
inp = gr.Checkbox(label=label)
elif field_type == "number":
Expand All @@ -183,34 +214,48 @@ def get_state_sync():
label=label,
placeholder=placeholder,
)

step_inputs.append(inp)

async def step_form(*values):
if not action_fields:
return await _step_with_action({})()
def build_action(values: List[Any]) -> Dict[str, Any]:
"""Convert UI input values into action dictionary."""
action_data = {}
for i, field in enumerate(action_fields):
if i >= len(values):
break
for val, field in zip(values, action_fields):
name = field["name"]
val = values[i]
if field.get("type") == "checkbox":
field_type = field.get("type")

if field_type == "checkbox":
action_data[name] = bool(val)
elif val is not None and val != "":
action_data[name] = val
return await _step_with_action(action_data)()

return action_data

async def step_form(*values):
    """Run one step from the form inputs and reset the form afterwards.

    Args:
        *values: Current values of the form components, passed
            positionally to ``build_action``.

    Returns:
        The three values produced by ``step`` (observation markdown,
        raw JSON, status text) followed by the values from
        ``clear_inputs()``.
    """
    step_result = await step(build_action(values))
    observation_md, response_json, status_msg = step_result

    # Append the cleared input values after the three display outputs.
    return (observation_md, response_json, status_msg, *clear_inputs())

step_fn = step_form

with gr.Row():
step_btn = gr.Button("Step", variant="primary")
reset_btn = gr.Button("Reset", variant="secondary")
state_btn = gr.Button("Get state", variant="secondary")

with gr.Row():
status = gr.Textbox(
label="Status",
interactive=False,
)

raw_json = gr.Code(
label="Raw JSON response",
language="json",
Expand All @@ -219,22 +264,25 @@ async def step_form(*values):

reset_btn.click(
fn=reset_env,
outputs=[obs_display, raw_json, status],
outputs=[obs_display, raw_json, status, *step_inputs],
)
Comment thread
Vidit-Ostwal marked this conversation as resolved.

step_btn.click(
fn=step_fn,
inputs=step_inputs,
outputs=[obs_display, raw_json, status],
outputs=[obs_display, raw_json, status, *step_inputs],
)

if is_chat_env:
action_input.submit(
fn=step_fn,
inputs=step_inputs,
outputs=[obs_display, raw_json, status],
outputs=[obs_display, raw_json, status, *step_inputs],
)

state_btn.click(
fn=get_state_sync,
outputs=[raw_json],
)

return demo
return demo