Commit db63e7b

[RLlib] RLModule: InferenceOnlyAPI. (ray-project#47572)

Signed-off-by: ujjawal-khare <[email protected]>
sven1977 authored and ujjawal-khare committed Oct 15, 2024
1 parent 881ca0f commit db63e7b
Showing 8 changed files with 60 additions and 69 deletions.
3 changes: 1 addition & 2 deletions rllib/algorithms/appo/appo_rl_module.py
@@ -34,13 +34,12 @@ def forward_target(self, batch: Dict[str, Any]) -> Dict[str, Any]:
         old_action_dist_logits = self._old_pi(old_pi_inputs_encoded)
         return {OLD_ACTION_DIST_LOGITS_KEY: old_action_dist_logits}

-    @OverrideToImplementCustomLogic_CallToSuperRecommended
     @override(PPORLModule)
     def get_non_inference_attributes(self) -> List[str]:
         # Get the NON inference-only attributes from the parent class
         # `PPOTorchRLModule`.
         ret = super().get_non_inference_attributes()
-        # Add the two (APPO) target networks to it (NOT needed in
+        # Add the two (APPO) target networks to it (also NOT needed in
         # inference-only mode).
         ret += ["_old_encoder", "_old_pi"]
         return ret
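The pattern above — call the super method first, then extend the returned list — is the intended way to compose non-inference attributes across the class hierarchy. A minimal, self-contained sketch of that pattern (`Base` and `APPOLike` are illustrative stand-ins, not RLlib classes):

```python
from typing import List


class Base:
    def get_non_inference_attributes(self) -> List[str]:
        # The parent flags its value-function head as training-only.
        return ["vf"]


class APPOLike(Base):
    def get_non_inference_attributes(self) -> List[str]:
        # Call the super method, then append the target networks, which
        # are likewise not needed for pure action inference.
        ret = super().get_non_inference_attributes()
        ret += ["_old_encoder", "_old_pi"]
        return ret


print(APPOLike().get_non_inference_attributes())
# -> ['vf', '_old_encoder', '_old_pi']
```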
4 changes: 2 additions & 2 deletions rllib/algorithms/dqn/dqn_rainbow_rl_module.py
@@ -25,7 +25,7 @@
 QF_TARGET_NEXT_PROBS = "qf_target_next_probs"


-@DeveloperAPI(stability="alpha")
+@ExperimentalAPI
 class DQNRainbowRLModule(RLModule, InferenceOnlyAPI, TargetNetworkAPI):
     @override(RLModule)
     def setup(self):
@@ -56,7 +56,7 @@ def setup(self):
         # the same encoder is used.
         # Note further, by using the base encoder the correct encoder
         # is chosen for the observation space used.
-        self.encoder = self.catalog.build_encoder(framework=self.framework)
+        self.encoder = catalog.build_encoder(framework=self.framework)

         # Build heads.
         self.af = self.catalog.build_af_head(framework=self.framework)
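The comment in this hunk notes that building through the base encoder lets the catalog pick an encoder matching the observation space. A rough standalone sketch of that dispatch idea — the `build_encoder_for` helper is hypothetical, not RLlib's `Catalog` API, and assumes channels-last image observations:

```python
import gymnasium as gym
import torch.nn as nn


def build_encoder_for(obs_space: gym.Space) -> nn.Module:
    # Image observations (H, W, C) -> a small conv stack.
    if isinstance(obs_space, gym.spaces.Box) and len(obs_space.shape) == 3:
        return nn.Sequential(nn.Conv2d(obs_space.shape[-1], 16, 4, 2), nn.ReLU())
    # Flat vector observations -> an MLP.
    elif isinstance(obs_space, gym.spaces.Box):
        return nn.Sequential(nn.Linear(obs_space.shape[0], 256), nn.ReLU())
    raise NotImplementedError(f"No default encoder for {obs_space}")


encoder = build_encoder_for(gym.spaces.Box(-1.0, 1.0, (8,)))
```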
44 changes: 36 additions & 8 deletions rllib/algorithms/ppo/ppo_rl_module.py
@@ -3,11 +3,11 @@
"""

import abc
from typing import List
from typing import List, Type

from ray.rllib.core.models.configs import RecurrentEncoderConfig
from ray.rllib.core.models.specs.specs_dict import SpecDict
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.core.rl_module.apis import InferenceOnlyAPI, ValueFunctionAPI
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.utils.annotations import (
override,
@@ -17,7 +17,7 @@


 @ExperimentalAPI
-class PPORLModule(RLModule, ValueFunctionAPI, abc.ABC):
+class PPORLModule(RLModule, InferenceOnlyAPI, ValueFunctionAPI, abc.ABC):
     def setup(self):
         # __sphinx_doc_begin__
         # If we have a stateful model, states for the critic need to be collected
@@ -36,9 +36,11 @@ def setup(self):
         self.catalog.actor_critic_encoder_config.inference_only = True

         # Build models from catalog.
-        self.encoder = self.catalog.build_actor_critic_encoder(framework=self.framework)
-        self.pi = self.catalog.build_pi_head(framework=self.framework)
-        self.vf = self.catalog.build_vf_head(framework=self.framework)
+        self.encoder = catalog.build_actor_critic_encoder(framework=self.framework)
+        self.pi = catalog.build_pi_head(framework=self.framework)
+        self.vf = catalog.build_vf_head(framework=self.framework)
+
+        self.action_dist_cls = catalog.get_action_dist_cls(framework=self.framework)
         # __sphinx_doc_end__

     @override(RLModule)
@@ -48,12 +50,38 @@ def get_initial_state(self) -> dict:
         else:
             return {}

-    @OverrideToImplementCustomLogic_CallToSuperRecommended
     @override(RLModule)
+    def input_specs_inference(self) -> SpecDict:
+        return [Columns.OBS]
+
+    @override(RLModule)
+    def output_specs_inference(self) -> SpecDict:
+        return [Columns.ACTION_DIST_INPUTS]
+
+    @override(RLModule)
+    def input_specs_exploration(self):
+        return self.input_specs_inference()
+
+    @override(RLModule)
+    def output_specs_exploration(self) -> SpecDict:
+        return self.output_specs_inference()
+
+    @override(RLModule)
+    def input_specs_train(self) -> SpecDict:
+        return self.input_specs_exploration()
+
+    @override(RLModule)
+    def output_specs_train(self) -> SpecDict:
+        return [
+            Columns.VF_PREDS,
+            Columns.ACTION_DIST_INPUTS,
+        ]
+
+    @override(InferenceOnlyAPI)
     def get_non_inference_attributes(self) -> List[str]:
         """Return attributes, which are NOT inference-only (only used for training)."""
         return ["vf"] + (
             []
-            if self.model_config.get("vf_share_layers")
+            if self.config.model_config_dict.get("vf_share_layers")
             else ["encoder.critic_encoder"]
         )
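The `vf_share_layers` branch at the end of this hunk is compact; here is a small standalone check of what it returns in both cases (a plain-function stand-in for illustration, not RLlib code):

```python
from typing import List


def non_inference_attributes(model_config_dict: dict) -> List[str]:
    # With shared layers, only the value head itself is training-only;
    # without sharing, the separate critic encoder can be dropped too.
    return ["vf"] + (
        []
        if model_config_dict.get("vf_share_layers")
        else ["encoder.critic_encoder"]
    )


print(non_inference_attributes({"vf_share_layers": True}))
# -> ['vf']
print(non_inference_attributes({"vf_share_layers": False}))
# -> ['vf', 'encoder.critic_encoder']
```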
6 changes: 4 additions & 2 deletions rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
@@ -1,9 +1,9 @@
-from typing import Any, Dict, Optional
+from typing import Any, Dict

 from ray.rllib.algorithms.ppo.ppo_rl_module import PPORLModule
 from ray.rllib.core.columns import Columns
 from ray.rllib.core.models.base import ACTOR, CRITIC, ENCODER_OUT
-from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
+from ray.rllib.core.rl_module.apis import ValueFunctionAPI
 from ray.rllib.core.rl_module.rl_module import RLModule
 from ray.rllib.core.rl_module.torch import TorchRLModule
 from ray.rllib.utils.annotations import override
@@ -14,6 +14,8 @@


 class PPOTorchRLModule(TorchRLModule, PPORLModule):
+    framework: str = "torch"
+
     @override(RLModule)
     def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
         """Default forward pass (used for inference and exploration)."""
2 changes: 1 addition & 1 deletion rllib/algorithms/sac/sac_rl_module.py
@@ -20,7 +20,7 @@
 from ray.util.annotations import DeveloperAPI


-@DeveloperAPI(stability="alpha")
+@ExperimentalAPI
 class SACRLModule(RLModule, InferenceOnlyAPI, TargetNetworkAPI):
     """`RLModule` for the Soft-Actor-Critic (SAC) algorithm.
2 changes: 2 additions & 0 deletions rllib/core/rl_module/apis/__init__.py
@@ -1,8 +1,10 @@
+from ray.rllib.core.rl_module.apis.inference_only_api import InferenceOnlyAPI
 from ray.rllib.core.rl_module.apis.target_network_api import TargetNetworkAPI
 from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI


 __all__ = [
+    "InferenceOnlyAPI",
     "TargetNetworkAPI",
     "ValueFunctionAPI",
 ]
4 changes: 2 additions & 2 deletions rllib/core/rl_module/apis/inference_only_api.py
@@ -7,9 +7,9 @@ class InferenceOnlyAPI(abc.ABC):
     Only the `get_non_inference_attributes` method needs to get implemented for
     an RLModule to have the following functionality:
-    - On EnvRunners (or when self.inference_only=True), RLlib will remove
+    - On EnvRunners (or when self.config.inference_only=True), RLlib will remove
       those parts of the model not required for action computation.
-    - An RLModule on a Learner (where `self.inference_only=False`) will
+    - An RLModule on a Learner (where `self.config.inference_only=False`) will
       return only those weights from `get_state()` that are part of its inference-only
       version, thus possibly saving network traffic/time.
     """
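For reference, a minimal sketch of what implementing this API can look like on a custom torch module: the inference path (encoder, pi) stays, while the value head is declared training-only. Class name and layer sizes are illustrative; real RLlib modules subclass `TorchRLModule` and build their models in `setup()`:

```python
from typing import List

import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 32)  # needed for inference
        self.pi = nn.Linear(32, 2)       # needed for inference
        self.vf = nn.Linear(32, 1)       # training-only

    def get_non_inference_attributes(self) -> List[str]:
        # Dotted paths are allowed for nested sub-modules,
        # e.g. "encoder.critic_encoder".
        return ["vf"]
```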
64 changes: 12 additions & 52 deletions rllib/core/rl_module/torch/torch_rl_module.py
@@ -49,10 +49,10 @@ def __init__(self, *args, **kwargs) -> None:
         nn.Module.__init__(self)
         RLModule.__init__(self, *args, **kwargs)

-        # If an inference-only class AND self.inference_only is True,
+        # If an inference-only class AND self.config.inference_only is True,
         # remove all attributes that are returned by
         # `self.get_non_inference_attributes()`.
-        if self.inference_only and isinstance(self, InferenceOnlyAPI):
+        if self.config.inference_only and isinstance(self, InferenceOnlyAPI):
             for attr in self.get_non_inference_attributes():
                 parts = attr.split(".")
                 if not hasattr(self, parts[0]):
@@ -68,6 +68,15 @@ def __init__(self, *args, **kwargs) -> None:
             if target is not None:
                 del target

+    @override(nn.Module)
+    def forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
+        """forward pass of the module.
+
+        This is aliased to forward_train because Torch DDP requires a forward method to
+        be implemented for backpropagation to work.
+        """
+        return self.forward_train(batch, **kwargs)
+
     def compile(self, compile_config: TorchCompileConfig):
         """Compile the forward methods of this module.
@@ -109,7 +118,7 @@ def get_state(
         # InferenceOnlyAPI).
         if (
             inference_only
-            and not self.inference_only
+            and not self.config.inference_only
             and isinstance(self, InferenceOnlyAPI)
         ):
             attr = self.get_non_inference_attributes()
@@ -130,55 +139,6 @@ def set_state(self, state: StateDict) -> None:
         # RLModule.
         self.load_state_dict(convert_to_torch_tensor(state), strict=False)

-    @OverrideToImplementCustomLogic
-    @override(RLModule)
-    def get_inference_action_dist_cls(self) -> Type[TorchDistribution]:
-        if self.action_dist_cls is not None:
-            return self.action_dist_cls
-        elif isinstance(self.action_space, gym.spaces.Discrete):
-            return TorchCategorical
-        elif isinstance(self.action_space, gym.spaces.Box):
-            return TorchDiagGaussian
-        else:
-            raise ValueError(
-                f"Default action distribution for action space "
-                f"{self.action_space} not supported! Either set the "
-                f"`self.action_dist_cls` property in your RLModule's `setup()` method "
-                f"to a subclass of `ray.rllib.models.torch.torch_distributions."
-                f"TorchDistribution` or - if you need different distributions for "
-                f"inference and training - override the three methods: "
-                f"`get_inference_action_dist_cls`, `get_exploration_action_dist_cls`, "
-                f"and `get_train_action_dist_cls` in your RLModule."
-            )
-
-    @OverrideToImplementCustomLogic
-    @override(RLModule)
-    def get_exploration_action_dist_cls(self) -> Type[TorchDistribution]:
-        return self.get_inference_action_dist_cls()
-
-    @OverrideToImplementCustomLogic
-    @override(RLModule)
-    def get_train_action_dist_cls(self) -> Type[TorchDistribution]:
-        return self.get_inference_action_dist_cls()
-
-    @override(nn.Module)
-    def forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
-        """DO NOT OVERRIDE!
-
-        This is aliased to `self.forward_train` because Torch DDP requires a forward
-        method to be implemented for backpropagation to work.
-
-        Instead, override:
-        `_forward()` to define a generic forward pass for all phases (exploration,
-        inference, training)
-        `_forward_inference()` to define the forward pass for action inference in
-        deployment/production (no exploration).
-        `_forward_exploration()` to define the forward pass for action inference during
-        training sample collection (w/ exploration behavior).
-        `_forward_train()` to define the forward pass prior to loss computation.
-        """
-        return self.forward_train(batch, **kwargs)


 class TorchDDPRLModule(RLModule, nn.parallel.DistributedDataParallel):
     def __init__(self, *args, **kwargs) -> None:
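The constructor hunk above walks dotted attribute paths ("a.b.c") to locate training-only sub-modules and drop them; parts of that loop are elided in this diff. A simplified standalone sketch of the pruning idea, assuming the same dotted-path convention (not the exact RLlib code):

```python
import torch.nn as nn


def prune_attribute(root: nn.Module, attr: str) -> None:
    parts = attr.split(".")
    # Walk down to the parent of the leaf attribute ...
    target = root
    for p in parts[:-1]:
        if not hasattr(target, p):
            return  # Path does not exist on this module; nothing to prune.
        target = getattr(target, p)
    # ... and remove the leaf so its parameters never ship to EnvRunners.
    if hasattr(target, parts[-1]):
        delattr(target, parts[-1])


m = nn.Module()
m.vf = nn.Linear(32, 1)
prune_attribute(m, "vf")
assert not hasattr(m, "vf")
```

Dropping the sub-module (rather than merely skipping its weights) is what shrinks the inference-only copy; `get_state(inference_only=True)` on the Learner side then filters the same attribute names out of the returned state dict, as the `get_state` hunk above shows.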
