[RLlib] Discontinue support for "hybrid" API stack (using RLModule + Learner, but still on RolloutWorker and Policy) (#46085)
sven1977 authored Sep 27, 2024
1 parent 6b44557 commit c9fa046
Showing 94 changed files with 615 additions and 1,862 deletions.
5 changes: 4 additions & 1 deletion doc/source/rllib/doc_code/catalog_guide.py
@@ -113,7 +113,10 @@ def __init__(self, *args, **kwargs):

config = (
PPOConfig()
.api_stack(enable_rl_module_and_learner=True)
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.environment("CartPole-v1")
.framework("torch")
)
20 changes: 16 additions & 4 deletions doc/source/rllib/doc_code/rlmodule_guide.py
@@ -12,7 +12,10 @@

config = (
PPOConfig()
.api_stack(enable_rl_module_and_learner=True)
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.framework("torch")
.environment("CartPole-v1")
)
@@ -80,7 +83,10 @@

config = (
BCConfigTest()
.api_stack(enable_rl_module_and_learner=True)
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.environment("CartPole-v1")
.rl_module(
model_config_dict={"fcnet_hiddens": [32, 32]},
@@ -103,7 +109,10 @@

config = (
BCConfigTest()
.api_stack(enable_rl_module_and_learner=True)
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.environment(MultiAgentCartPole, env_config={"num_agents": 2})
.rl_module(
model_config_dict={"fcnet_hiddens": [32, 32]},
@@ -406,7 +415,10 @@ def setup(self):
config = (
PPOConfig()
# Enable the new API stack (RLModule and Learner APIs).
.api_stack(enable_rl_module_and_learner=True).environment("CartPole-v1")
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
).environment("CartPole-v1")
)
env = gym.make("CartPole-v1")
# Create an RL Module that we would like to checkpoint
2 changes: 2 additions & 0 deletions doc/source/rllib/doc_code/training.py
@@ -38,6 +38,7 @@
.api_stack(
enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False
)
.framework("torch")
.environment("CartPole-v1")
.env_runners(num_env_runners=0)
.training(
@@ -112,6 +113,7 @@
.api_stack(
enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False
)
.framework("torch")
.environment("CartPole-v1")
.training(
replay_buffer_config={
30 changes: 19 additions & 11 deletions doc/source/rllib/new-api-stack-migration-guide.rst
@@ -213,10 +213,16 @@ This method isn't used on the old API stack because the old stack doesn't use Le

It allows you to specify the following; a combined configuration sketch follows the list:

1) the number of `Learner` workers through `.learners(num_learners=...)`.
1) the resources per learner; use `.learners(num_gpus_per_learner=1)` for GPU training and `.learners(num_gpus_per_learner=0)` for CPU training.
1) the custom Learner class you want to use (`example on how to do this here <https://github.com/ray-project/ray/blob/master/rllib/examples/learners/custom_loss_fn_simple.py>`__)
1) a config dict you would like to set for your custom learner: `.learners(learner_config_dict={...})`. Note that every `Learner` has access to the entire `AlgorithmConfig` object through `self.config`, but setting the `learner_config_dict` is a convenient way to avoid having to create an entirely new `AlgorithmConfig` subclass only to support a few extra settings for your custom `Learner` class.
#. the number of `Learner` workers through `.learners(num_learners=...)`.
#. the resources per learner; use `.learners(num_gpus_per_learner=1)` for GPU training
and `.learners(num_gpus_per_learner=0)` for CPU training.
#. the custom Learner class you want to use (`example on how to do this here <https://github.com/ray-project/ray/blob/master/rllib/examples/learners/custom_loss_fn_simple.py>`__)
#. a config dict you would like to set for your custom learner:
`.learners(learner_config_dict={...})`. Note that every `Learner` has access to the
entire `AlgorithmConfig` object through `self.config`, but setting the
`learner_config_dict` is a convenient way to avoid having to create an entirely new
`AlgorithmConfig` subclass only to support a few extra settings for your custom
`Learner` class.
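
For illustration, the settings above combine into a sketch like the following; the
concrete values and the `custom_loss_coeff` key are placeholders, not defaults:

.. code-block:: python

    from ray.rllib.algorithms.ppo import PPOConfig

    config = (
        PPOConfig()
        .api_stack(
            enable_rl_module_and_learner=True,
            enable_env_runner_and_connector_v2=True,
        )
        .environment("CartPole-v1")
        .learners(
            num_learners=2,  # >1 Learner workers enable DDP-style updates.
            num_gpus_per_learner=0,  # Set to 1 to train each Learner on a GPU.
            # Extra settings a custom Learner can read through `self.config`.
            learner_config_dict={"custom_loss_coeff": 0.1},
        )
    )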


AlgorithmConfig.env_runners()
@@ -380,9 +386,11 @@ and `how to write a custom LSTM-containing RL Module <https://github.com/ray-pro
There are various options for translating an existing, custom :py:class:`~ray.rllib.models.modelv2.ModelV2` from the old API stack,
to the new API stack's :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule`:

1) Move your ModelV2 code to a new, custom `RLModule` class. See :ref:`RL Modules <rlmodule-guide>` for details).
1) Use an Algorithm checkpoint or a Policy checkpoint that you have from an old API stack training run and use this checkpoint with the `new stack RL Module convenience wrapper <https://github.com/ray-project/ray/blob/master/rllib/examples/rl_modules/migrate_modelv2_to_new_api_stack_by_policy_checkpoint.py>`__.
1) Use an existing :py:class:`~ray.rllib.algorithms.algorithm_config.AlgorithmConfig` object from an old API stack training run, with the `new stack RL Module convenience wrapper <https://github.com/ray-project/ray/blob/master/rllib/examples/rl_modules/migrate_modelv2_to_new_api_stack_by_config.py>`__.
#. Move your ModelV2 code to a new, custom `RLModule` class. See :ref:`RL Modules <rlmodule-guide>` for details and the minimal skeleton sketch after this list.
#. Use an Algorithm checkpoint or a Policy checkpoint from an old API stack training
run with the `new stack RL Module convenience wrapper <https://github.com/ray-project/ray/blob/master/rllib/examples/rl_modules/migrate_modelv2_to_new_api_stack_by_policy_checkpoint.py>`__.
#. Use an existing :py:class:`~ray.rllib.algorithms.algorithm_config.AlgorithmConfig`
object from an old API stack training run, with the `new stack RL Module convenience wrapper <https://github.com/ray-project/ray/blob/master/rllib/examples/rl_modules/migrate_modelv2_to_new_api_stack_by_config.py>`__.
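
For illustration, a bare-bones custom `RLModule` skeleton for option 1 might look like
the sketch below. The class name and layer sizes are placeholders, the `self.config.*`
attribute access reflects the RLModule API at the time of this commit, and
algorithm-specific requirements (for example, PPO's value-function output) are omitted;
see :ref:`RL Modules <rlmodule-guide>` for the authoritative signatures:

.. code-block:: python

    import torch

    from ray.rllib.core.columns import Columns
    from ray.rllib.core.rl_module.torch import TorchRLModule


    class MyDiscretePolicyModule(TorchRLModule):
        def setup(self):
            # Build the network once, from the spaces handed to this module.
            in_size = self.config.observation_space.shape[0]
            num_actions = self.config.action_space.n
            self._pi = torch.nn.Sequential(
                torch.nn.Linear(in_size, 64),
                torch.nn.ReLU(),
                torch.nn.Linear(64, num_actions),
            )

        def _forward_inference(self, batch, **kwargs):
            # Return action-distribution inputs (here: logits) for the observations.
            return {Columns.ACTION_DIST_INPUTS: self._pi(batch[Columns.OBS])}

        def _forward_exploration(self, batch, **kwargs):
            return self._forward_inference(batch, **kwargs)

        def _forward_train(self, batch, **kwargs):
            return self._forward_inference(batch, **kwargs)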


Custom loss functions and policies
@@ -423,7 +431,7 @@ The :py:class:`~ray.rllib.connectors.connector_v2.ConnectorV2` documentation is
The following are some examples of how to write ConnectorV2 pieces for the
different pipelines; a usage sketch for plugging a piece into a pipeline follows the list:

1) `Observation frame-stacking <https://github.com/ray-project/ray/blob/master/rllib/examples/connectors/frame_stacking.py>`__.
1) `Add the most recent action and reward to the RL Module's input <https://github.com/ray-project/ray/blob/master/rllib/examples/connectors/prev_actions_prev_rewards.py>`__.
1) `Mean-std filtering on all observations <https://github.com/ray-project/ray/blob/master/rllib/examples/connectors/mean_std_filtering.py>`__.
1) `Flatten any complex observation space to a 1D space <https://github.com/ray-project/ray/blob/master/rllib/examples/connectors/flatten_observations_dict_space.py>`__.
#. `Observation frame-stacking <https://github.com/ray-project/ray/blob/master/rllib/examples/connectors/frame_stacking.py>`__.
#. `Add the most recent action and reward to the RL Module's input <https://github.com/ray-project/ray/blob/master/rllib/examples/connectors/prev_actions_prev_rewards.py>`__.
#. `Mean-std filtering on all observations <https://github.com/ray-project/ray/blob/master/rllib/examples/connectors/mean_std_filtering.py>`__.
#. `Flatten any complex observation space to a 1D space <https://github.com/ray-project/ray/blob/master/rllib/examples/connectors/flatten_observations_dict_space.py>`__.
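
For illustration, plugging one of these pieces into the env-to-module pipeline might
look like the sketch below; the `env_to_module_connector` argument and the import path
of `FlattenObservations` are assumed from the linked examples:

.. code-block:: python

    from ray.rllib.algorithms.ppo import PPOConfig
    from ray.rllib.connectors.env_to_module import FlattenObservations

    config = (
        PPOConfig()
        .api_stack(
            enable_rl_module_and_learner=True,
            enable_env_runner_and_connector_v2=True,
        )
        .environment("CartPole-v1")
        # The callable receives the (vector) env and returns the ConnectorV2
        # piece(s) to prepend to the env-to-module pipeline.
        .env_runners(env_to_module_connector=lambda env: FlattenObservations())
    )
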
8 changes: 0 additions & 8 deletions doc/source/rllib/package_ref/policy.rst
@@ -49,14 +49,6 @@ Base Policy classes
Making models
--------------------

Base Policy
~~~~~~~~~~~~~~~~~~~~
.. autosummary::
:nosignatures:
:toctree: doc/

~policy.Policy.make_rl_module


Torch Policy
~~~~~~~~~~~~~~~~~~~~
18 changes: 12 additions & 6 deletions doc/source/rllib/rllib-learner.rst
@@ -57,7 +57,10 @@ arguments in the :py:class:`~ray.rllib.algorithms.algorithm_config.AlgorithmConf

config = (
PPOConfig()
.api_stack(enable_rl_module_and_learner=True)
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.learners(
num_learners=0, # Set this to greater than 1 to allow for DDP style updates.
num_gpus_per_learner=0, # Set this to 1 to enable GPU training.
@@ -75,7 +78,7 @@
.. note::

This feature is in alpha. If you migrate to this algorithm, enable the feature
via `AlgorithmConfig.api_stack(enable_rl_module_and_learner=True)`.
via `AlgorithmConfig.api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)`.

The following algorithms support :py:class:`~ray.rllib.core.learner.learner.Learner` out of the box. Implement
an algorithm with a custom :py:class:`~ray.rllib.core.learner.learner.Learner` to leverage this API for other algorithms.
@@ -240,10 +243,13 @@ Updates
results = learner_group.update_from_batch(
batch=DUMMY_BATCH, async_update=True, timesteps=TIMESTEPS
)
# `results` is an already reduced dict, which is the result of
# reducing over the individual async `update_from_batch(..., async_update=True)`
# calls.
assert isinstance(results, dict), results
# `results` is a list of n items (where n is the number of async results collected).
assert isinstance(results, list), results
# Each item in that list is another list of m items (where m is the number of Learner
# workers).
assert isinstance(results[0], list), results
# Each item in the inner list is a result dict from the Learner worker.
assert isinstance(results[0][0], dict), results

When updating a :py:class:`~ray.rllib.core.learner.learner_group.LearnerGroup` you can perform blocking or async updates on batches of data.
Async updates are necessary for implementing async algorithms such as APPO/IMPALA.
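
For illustration, post-processing the nested async results described above might look
like the following sketch; the `total_loss` key is a placeholder for whatever metrics
your Learners actually report:

.. code-block:: python

    # Flatten: the outer list covers the collected async requests, each inner
    # list holds one result dict per Learner worker.
    learner_results = [
        result_dict
        for async_results in results
        for result_dict in async_results
    ]
    # Example: average a (placeholder) scalar metric across all Learner workers.
    if learner_results:
        mean_loss = sum(
            r.get("total_loss", 0.0) for r in learner_results
        ) / len(learner_results)
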
9 changes: 0 additions & 9 deletions rllib/BUILD
@@ -2950,15 +2950,6 @@ py_test(
args = ["--framework=tf2", "--config=multi-gpu-ddp"]
)

#@OldAPIStack @HybridAPIStack
py_test(
name = "examples/learners/train_w_bc_finetune_w_ppo",
main = "examples/learners/train_w_bc_finetune_w_ppo.py",
tags = ["team:rllib", "examples", "exclusive"],
size = "medium",
srcs = ["examples/learners/train_w_bc_finetune_w_ppo.py"],
)

# subdirectory: multi_agent/
# ....................................
py_test(
46 changes: 21 additions & 25 deletions rllib/algorithms/algorithm.py
@@ -857,6 +857,7 @@ def setup(self, config: AlgorithmConfig) -> None:
env_steps_sampled=self.metrics.peek(
NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0
),
rl_module_state=rl_module_state,
)

if self.offline_data:
@@ -1003,7 +1004,7 @@ def step(self) -> ResultDict:
self._remote_worker_ids_for_metrics(),
timeout_seconds=self.config.metrics_episode_collection_timeout_s,
)
results = self._compile_iteration_results_old_and_hybrid_api_stacks(
results = self._compile_iteration_results_old_api_stack(
episodes_this_iter=episodes_this_iter,
step_ctx=train_iter_ctx,
iteration_results={**train_results, **eval_results},
@@ -1708,7 +1709,7 @@ def training_step(self) -> ResultDict:
if not self.config.enable_env_runner_and_connector_v2:
raise NotImplementedError(
"The `Algorithm.training_step()` default implementation no longer "
"supports the old or hybrid API stacks! If you would like to continue "
"supports the old API stack! If you would like to continue "
"using these "
"old APIs with this default `training_step`, simply subclass "
"`Algorithm` and override its `training_step` method (copy/paste the "
@@ -2404,12 +2405,12 @@ def add_policy(
Callable[[PolicyID, Optional[SampleBatchType]], bool],
]
] = None,
add_to_learners: bool = True,
add_to_env_runners: bool = True,
add_to_eval_env_runners: bool = True,
module_spec: Optional[RLModuleSpec] = None,
# Deprecated arg.
evaluation_workers=DEPRECATED_VALUE,
add_to_learners=DEPRECATED_VALUE,
) -> Optional[Policy]:
"""Adds a new policy to this Algorithm.
Expand Down Expand Up @@ -2442,9 +2443,6 @@ def add_policy(
If None, will keep the existing setup in place. Policies,
whose IDs are not in the list (or for which the callable
returns False) will not be updated.
add_to_learners: Whether to add the new RLModule to the LearnerGroup
(with its n Learners). This setting is only valid on the hybrid-API
stack (with Learners, but w/o EnvRunners).
add_to_env_runners: Whether to add the new RLModule to the EnvRunnerGroup
(with its m EnvRunners plus the local one).
add_to_eval_env_runners: Whether to add the new RLModule to the eval
@@ -2471,6 +2469,12 @@
new="Algorithm.add_policy(add_to_eval_env_runners=...)",
error=True,
)
if add_to_learners != DEPRECATED_VALUE:
deprecation_warning(
old="Algorithm.add_policy(add_to_learners=..)",
help="Hybrid API stack no longer supported by RLlib!",
error=True,
)

validate_module_id(policy_id, error=True)

@@ -2544,11 +2548,11 @@ def remove_policy(
Callable[[PolicyID, Optional[SampleBatchType]], bool],
]
] = None,
remove_from_learners: bool = True,
remove_from_env_runners: bool = True,
remove_from_eval_env_runners: bool = True,
# Deprecated args.
evaluation_workers=DEPRECATED_VALUE,
remove_from_learners=DEPRECATED_VALUE,
) -> None:
"""Removes a policy from this Algorithm.
@@ -2564,9 +2568,6 @@
If None, will keep the existing setup in place. Policies,
whose IDs are not in the list (or for which the callable
returns False) will not be updated.
remove_from_learners: Whether to remove the Policy from the LearnerGroup
(with its n Learners). Only valid on the hybrid API stack (w/ Learners,
but w/o EnvRunners).
remove_from_env_runners: Whether to remove the Policy from the
EnvRunnerGroup (with its m EnvRunners plus the local one).
remove_from_eval_env_runners: Whether to remove the RLModule from the eval
@@ -2579,6 +2580,12 @@
error=False,
)
remove_from_eval_env_runners = evaluation_workers
if remove_from_learners != DEPRECATED_VALUE:
deprecation_warning(
old="Algorithm.remove_policy(remove_from_learners=..)",
help="Hybrid API stack no longer supported by RLlib!",
error=True,
)

def fn(worker):
worker.remove_policy(
@@ -2768,27 +2775,16 @@ def load_checkpoint(self, checkpoint_dir: str) -> None:
and self.config.enable_env_runner_and_connector_v2
):
self.restore_from_path(checkpoint_dir)

# Call the `on_checkpoint_loaded` callback.
self.callbacks.on_checkpoint_loaded(algorithm=self)
return

# Checkpoint is provided as a local directory.
# Restore from the checkpoint file or dir.
checkpoint_info = get_checkpoint_info(checkpoint_dir)
checkpoint_data = Algorithm._checkpoint_info_to_algorithm_state(checkpoint_info)
self.__setstate__(checkpoint_data)
if self.config.enable_rl_module_and_learner:
# We restore the LearnerGroup from a "learner" subdir. Note that this is not
# in line with the new Checkpointable API, but makes this case backward
# compatible. The new Checkpointable API is only strictly applied anyways
# to the new API stack.
learner_group_state_dir = os.path.join(checkpoint_dir, "learner")
self.learner_group.restore_from_path(learner_group_state_dir)
# Make also sure, all (training) EnvRunners get the just loaded weights, but
# only the inference-only ones.
self.env_runner_group.sync_weights(
from_worker_or_learner_group=self.learner_group,
inference_only=True,
)

# Call the `on_checkpoint_loaded` callback.
self.callbacks.on_checkpoint_loaded(algorithm=self)

@@ -3901,7 +3897,7 @@ def _compile_iteration_results_new_api_stack(
)

@OldAPIStack
def _compile_iteration_results_old_and_hybrid_api_stacks(
def _compile_iteration_results_old_api_stack(
self, *, episodes_this_iter, step_ctx, iteration_results
):
# Results to be returned.
16 changes: 11 additions & 5 deletions rllib/algorithms/algorithm_config.py
@@ -4396,6 +4396,16 @@ def _validate_new_api_stack_settings(self):
# `enable_rl_module_and_learner=True`.
return

# Disabled hybrid API stack. Now, both `enable_rl_module_and_learner` and
# `enable_env_runner_and_connector_v2` must be True or both False.
if not self.enable_env_runner_and_connector_v2:
raise ValueError(
"Setting `enable_rl_module_and_learner` to True and "
"`enable_env_runner_and_connector_v2` to False ('hybrid API stack'"
") is no longer supported! Set both to True (new API stack) or both "
"to False (old API stack) instead."
)
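
# Illustration only (sketch, not part of this file): after this change, exactly
# two `api_stack()` combinations pass validation:
#   config.api_stack(
#       enable_rl_module_and_learner=True,
#       enable_env_runner_and_connector_v2=True,
#   )  # new API stack
#   config.api_stack(
#       enable_rl_module_and_learner=False,
#       enable_env_runner_and_connector_v2=False,
#   )  # old API stack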

# New API stack (RLModule, Learner APIs) only works with connectors.
if not self.enable_connectors:
raise ValueError(
@@ -4415,11 +4425,7 @@ def _validate_new_api_stack_settings(self):
# new API stack AND this is a single-agent setup (multi-agent does not use
# gym.vector.Env yet and therefore the reset call is still made manually,
# allowing for the callback to be fired).
if (
self.enable_env_runner_and_connector_v2
and not self.is_multi_agent()
and self.callbacks_class is not DefaultCallbacks
):
if not self.is_multi_agent() and self.callbacks_class is not DefaultCallbacks:
default_src = inspect.getsource(DefaultCallbacks.on_episode_created)
try:
user_src = inspect.getsource(self.callbacks_class.on_episode_created)
7 changes: 4 additions & 3 deletions rllib/algorithms/bc/tests/test_bc.py
@@ -58,7 +58,7 @@ def test_bc_compilation_and_learning_from_offline_file(self):
)

num_iterations = 350
min_reward = 120.0
min_return_to_reach = 120.0

# TODO (simon): Add support for recurrent modules.
algo = config.build()
@@ -73,14 +73,15 @@
EPISODE_RETURN_MEAN
]
print(f"iter={i}, R={episode_return_mean}")
if episode_return_mean > min_reward:
if episode_return_mean > min_return_to_reach:
print("BC has learnt the task!")
learnt = True
break

if not learnt:
raise ValueError(
f"`BC` did not reach {min_reward} reward from expert offline data!"
f"`BC` did not reach {min_return_to_reach} reward from "
"expert offline data!"
)

algo.stop()