Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions rllm/engine/agent_execution_engine.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
"""Deprecated legacy execution engine.

This module is deprecated and will be removed in a future release.
Use `rllm.experimental.engine.agent_execution_workflow.AgentExecutionWorkflowEngine`
instead.
"""

import asyncio
import logging
import time
import traceback
import uuid
import warnings
from concurrent.futures import ThreadPoolExecutor

import torch
Expand Down Expand Up @@ -49,6 +57,11 @@ def __init__(
overlong_filter=False, # Filter for overlong trajectories (i.e. TRUNCATION, MAX_STEPS, TIMEOUT)
**kwargs,
):
warnings.warn(
"rllm.engine.agent_execution_engine.AgentExecutionEngine is deprecated and will be removed in a future release. Use rllm.experimental.engine.agent_execution_workflow.AgentExecutionWorkflowEngine instead.",
DeprecationWarning,
stacklevel=2,
)
if agent_args is None:
agent_args = {}
if rollout_engine_args is None:
Expand Down Expand Up @@ -158,6 +171,9 @@ async def get_model_response(self, prompt, application_id, **kwargs) -> str:
return output
elif self.engine_name == "tinker":
output = await self.rollout_engine.get_model_response(prompt, application_id=application_id, enforce_max_prompt_length=False, **sampling_params)

# print(f"Tinker output: {output.text}, prompt {prompt}, application_id {application_id}, sampling_params {sampling_params}")

return output
else:
raise NotImplementedError(f"Engine type '{self.engine_name}' not supported")
Expand Down
13 changes: 13 additions & 0 deletions rllm/engine/agent_sdk_engine.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
"""Deprecated legacy SDK engine.

This module is deprecated and will be removed in a future release.
Use `rllm.experimental.engine.sdk_workflow.SdkWorkflow` with
`rllm.experimental.engine.unified_workflow_engine.UnifiedWorkflowEngine` instead.
"""

import asyncio
import functools
import inspect
import logging
import time
import uuid
import warnings
from collections import defaultdict
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
Expand Down Expand Up @@ -56,6 +64,11 @@ def __init__(self, agent_run_func: Callable, rollout_engine: RolloutEngine, conf
tracer: Optional tracer for logging.
**kwargs: Additional arguments.
"""
warnings.warn(
"rllm.engine.agent_sdk_engine.AgentSdkEngine is deprecated and will be removed in a future release. Use rllm.experimental.engine.sdk_workflow.SdkWorkflow with rllm.experimental.engine.unified_workflow_engine.UnifiedWorkflowEngine instead.",
DeprecationWarning,
stacklevel=2,
)
self.rollout_engine = rollout_engine
self.agent_run_func = agent_run_func
self.config = config # if training
Expand Down
12 changes: 12 additions & 0 deletions rllm/engine/agent_workflow_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
"""Deprecated legacy workflow engine.

This module is deprecated and will be removed in a future release.
Use `rllm.experimental.engine.agent_workflow_engine.AgentWorkflowEngine` instead.
"""

import asyncio
import logging
import uuid
import warnings
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -36,6 +43,11 @@ def __init__(self, workflow_cls: type[Workflow], workflow_args: dict, rollout_en
episode_logger: Optional logger for saving episode data to files.
**kwargs: Additional keyword arguments.
"""
warnings.warn(
"rllm.engine.agent_workflow_engine.AgentWorkflowEngine is deprecated and will be removed in a future release. Use rllm.experimental.engine.agent_workflow_engine.AgentWorkflowEngine instead.",
DeprecationWarning,
stacklevel=2,
)
self.workflow_cls = workflow_cls
self.workflow_args = workflow_args or {}

Expand Down
3 changes: 2 additions & 1 deletion rllm/environments/frozenlake/frozenlake.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ def __init__(self, **kwargs):
is_slippery = kwargs.pop("is_slippery", False)
size = kwargs.pop("size", 8)
p = kwargs.pop("p", 0.8)
seed = kwargs.pop("seed", 42)
# Gymnasium's seeding utility requires a Python int, not numpy scalar types.
seed = int(kwargs.pop("seed", 42))
self.seed = seed
self.size = size
self.p = p
Expand Down
176 changes: 176 additions & 0 deletions rllm/experimental/engine/agent_execution_workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
"""Compatibility adapter replacing AgentExecutionEngine with workflow execution.

This module provides a light-weight migration path for scripts that currently
instantiate ``rllm.engine.agent_execution_engine.AgentExecutionEngine`` for
inference/evaluation. It preserves the familiar constructor shape while routing
execution through ``UnifiedWorkflowEngine`` + ``CumulativeWorkflow``.
"""

from __future__ import annotations

import logging
from typing import Any

from rllm.agents.agent import Trajectory
from rllm.experimental.engine.unified_workflow_engine import UnifiedWorkflowEngine
from rllm.parser import ChatTemplateParser
from rllm.workflows.cumulative_workflow import CumulativeWorkflow

logger = logging.getLogger(__name__)


class AgentExecutionWorkflowFactory:
    """Translate legacy agent/env constructor arguments into workflow config.

    Holds the pieces the old ``AgentExecutionEngine`` constructor accepted and
    re-exposes them in the (workflow class, workflow kwargs) shape that
    ``UnifiedWorkflowEngine`` consumes.
    """

    def __init__(
        self,
        agent_class,
        env_class,
        agent_args: dict | None = None,
        env_args: dict | None = None,
        max_steps: int = 5,
    ) -> None:
        self.agent_class = agent_class
        self.env_class = env_class
        # Normalize missing/empty argument dicts so downstream unpacking is safe.
        self.agent_args = agent_args if agent_args else {}
        self.env_args = env_args if env_args else {}
        self.max_steps = max_steps

    def get_workflow_cls(self):
        """Return the workflow class every legacy run is mapped onto."""
        return CumulativeWorkflow

    def get_workflow_args(self) -> dict[str, Any]:
        """Return the keyword arguments used to construct the workflow."""
        workflow_kwargs: dict[str, Any] = {
            "agent_cls": self.agent_class,
            "env_cls": self.env_class,
            "agent_args": self.agent_args,
            "env_args": self.env_args,
            "max_steps": self.max_steps,
        }
        return workflow_kwargs


class AgentExecutionWorkflowEngine:
    """Drop-in replacement engine that runs legacy rollout scripts on workflows.

    Notes:
        - Targets the ``execute_tasks(tasks)`` inference interface used by
          examples such as FrozenLake.
        - Results come back as a flat trajectory list so existing helpers
          (e.g. ``compute_pass_at_k``, which expects ``list[Trajectory]``)
          keep working unchanged.
    """

    def __init__(
        self,
        engine_name="openai",
        tokenizer=None,
        rollout_engine=None,
        chat_parser=None,
        n_parallel_agents=128,
        trajectory_timeout=None,
        gamma=0.2,
        api_retries=3,
        retry_limit=3,
        max_steps=5,
        max_response_length=8192,
        max_prompt_length=1024,
        config=None,
        agent_class=None,
        env_class=None,
        agent_args=None,
        rollout_engine_args=None,
        env_args=None,
        max_workers=64,
        enforce_max_prompt_length=False,
        overlong_filter=False,
        **kwargs,
    ):
        # Accepted purely for signature compatibility; the unified path ignores them.
        del trajectory_timeout, gamma, max_workers, enforce_max_prompt_length, overlong_filter

        self.config = config
        self.engine_name = engine_name
        self.tokenizer = tokenizer
        self.n_parallel_agents = n_parallel_agents
        self.retry_limit = retry_limit
        if config is None:
            self.disable_thinking = False
        else:
            self.disable_thinking = config.get("rllm", {}).get("disable_thinking", False)

        agent_args = agent_args if agent_args is not None else {}
        env_args = env_args if env_args is not None else {}
        rollout_engine_args = rollout_engine_args if rollout_engine_args is not None else {}

        # Legacy attribute names kept so existing example scripts can still read them.
        self.rollout_engine_args = rollout_engine_args
        self.sampling_params = kwargs.get("sampling_params", {})

        self.rollout_engine = self._build_rollout_engine(
            engine_name,
            rollout_engine,
            rollout_engine_args,
            api_retries,
            max_prompt_length,
            max_response_length,
            kwargs,
        )

        # Attach a chat parser when neither the caller nor the backend supplies one.
        if chat_parser is not None:
            self.rollout_engine.chat_parser = chat_parser
        elif getattr(self.rollout_engine, "chat_parser", None) is None:
            self.rollout_engine.chat_parser = ChatTemplateParser.get_parser(self.tokenizer, disable_thinking=self.disable_thinking)

        builder = AgentExecutionWorkflowFactory(
            agent_class=agent_class,
            env_class=env_class,
            agent_args=agent_args,
            env_args=env_args,
            max_steps=max_steps,
        )
        self.workflow_engine = UnifiedWorkflowEngine(
            workflow_cls=builder.get_workflow_cls(),
            workflow_args=builder.get_workflow_args(),
            rollout_engine=self.rollout_engine,
            config=self.config,
            n_parallel_tasks=self.n_parallel_agents,
            retry_limit=self.retry_limit,
            raise_on_error=True,
        )

    def _build_rollout_engine(self, engine_name, rollout_manager, rollout_engine_args, api_retries, max_prompt_length, max_response_length, extra_kwargs):
        """Instantiate the rollout backend selected by ``engine_name``.

        Imports are kept local to each branch so only the chosen backend's
        dependencies are pulled in.
        """
        if engine_name == "openai":
            from rllm.engine.rollout.openai_engine import OpenAIEngine

            return OpenAIEngine(
                **rollout_engine_args,
                api_retries=api_retries,
                tokenizer=self.tokenizer,
                max_prompt_length=max_prompt_length,
                max_response_length=max_response_length,
                disable_thinking=extra_kwargs.get("disable_thinking", False),
            )
        if engine_name == "verl":
            from rllm.engine.rollout.verl_engine import VerlEngine

            return VerlEngine(
                config=self.config,
                rollout_manager=rollout_manager,
                tokenizer=self.tokenizer,
                disable_thinking=extra_kwargs.get("disable_thinking", False),
            )
        if engine_name == "tinker":
            from rllm.engine.rollout.tinker_engine import TinkerEngine

            # CumulativeWorkflow reads rollout_engine.chat_parser for prompt-length
            # accounting; Tinker's default renderer mode leaves chat_parser as None,
            # so parser mode is switched on unless the caller overrode it.
            rollout_engine_args.setdefault("bypass_render_with_parser", True)
            rollout_engine_args.setdefault("disable_thinking", self.disable_thinking)
            return TinkerEngine(
                **rollout_engine_args,
            )
        raise NotImplementedError(f"Engine type '{engine_name}' not supported")

    async def execute_tasks(self, tasks: list[dict]):
        """Run tasks through the workflow engine and flatten episode output.

        Returns a flat ``list[Trajectory]`` so legacy helpers keep working;
        trajectories missing a task get the episode's task back-filled.
        """
        episodes = await self.workflow_engine.execute_tasks(tasks)
        flattened = []
        for episode in episodes:
            for trajectory in episode.trajectories:
                # Preserve old helper compatibility (expects trajectory.task / reward).
                if trajectory.task is None:
                    trajectory.task = episode.task
                flattened.append(trajectory)
        return flattened

    def shutdown(self):
        """Tear down the underlying unified workflow engine."""
        self.workflow_engine.shutdown()
68 changes: 68 additions & 0 deletions rllm/experimental/engine/agent_workflow_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Compatibility adapter for migrating from legacy AgentWorkflowEngine.

This adapter preserves the familiar workflow-engine constructor and runtime
methods while delegating execution to ``UnifiedWorkflowEngine``.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

from rllm.experimental.engine.unified_workflow_engine import UnifiedWorkflowEngine

if TYPE_CHECKING:
from rllm.agents.agent import Episode
from rllm.workflows.workflow import Workflow


class AgentWorkflowEngine:
    """Thin compatibility shim mirroring the legacy workflow-engine API.

    Existing scripts only need to change their import path to this
    experimental module; every call is forwarded untouched to a wrapped
    ``UnifiedWorkflowEngine`` instance.
    """

    def __init__(
        self,
        workflow_cls: type[Workflow],
        workflow_args: dict,
        rollout_engine,
        config=None,
        n_parallel_tasks: int = 128,
        retry_limit: int = 3,
        raise_on_error: bool = True,
        episode_logger=None,
        **kwargs,
    ):
        # All state lives on the delegate; this facade holds nothing else.
        self._engine = UnifiedWorkflowEngine(
            workflow_cls=workflow_cls,
            workflow_args=workflow_args,
            rollout_engine=rollout_engine,
            config=config,
            n_parallel_tasks=n_parallel_tasks,
            retry_limit=retry_limit,
            raise_on_error=raise_on_error,
            episode_logger=episode_logger,
            **kwargs,
        )

    @property
    def rollout_engine(self):
        """Expose the wrapped engine's rollout engine (legacy attribute access)."""
        return self._engine.rollout_engine

    def set_training_step(self, step: int, mode: str = "train", epoch: int = 0):
        """Forward training-step bookkeeping to the wrapped engine."""
        self._engine.set_training_step(step=step, mode=mode, epoch=epoch)

    async def initialize_pool(self):
        """Initialize the delegate's worker pool."""
        await self._engine.initialize_pool()

    async def execute_tasks(self, tasks: list[dict], task_ids: list[str] | None = None, **kwargs) -> list[Episode]:
        """Execute tasks via the delegate and return its episodes."""
        return await self._engine.execute_tasks(tasks=tasks, task_ids=task_ids, **kwargs)

    async def execute_tasks_verl(self, batch, **kwargs):
        """Execute a verl batch via the delegate.

        NOTE: the unified workflow path currently returns episodes rather
        than a transformed DataProto payload.
        """
        return await self._engine.execute_tasks_verl(batch, **kwargs)

    def shutdown(self):
        """Shut the wrapped engine down."""
        self._engine.shutdown()
Loading