
Commit 129f3d5

[Feature] PPO Trainer Updates
ghstack-source-id: 7d699b2 Pull-Request: #3190
1 parent 888095f commit 129f3d5

4 files changed (+151, -17 lines)

test/test_configs.py

Lines changed: 4 additions & 1 deletion
@@ -835,7 +835,10 @@ def test_tensor_dict_module_config(self):
             in_keys=["observation"],
             out_keys=["action"],
         )
-        assert cfg._target_ == "tensordict.nn.TensorDictModule"
+        assert (
+            cfg._target_
+            == "torchrl.trainers.algorithms.configs.modules._make_tensordict_module"
+        )
         assert cfg.module._target_ == "torchrl.modules.MLP"
         assert cfg.in_keys == ["observation"]
         assert cfg.out_keys == ["action"]
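
For context, a minimal sketch of what the updated assertion pins down: building a TensorDictModuleConfig now yields a _target_ that points at the factory helper rather than at tensordict.nn.TensorDictModule itself. The MLPConfig field names below are assumptions for illustration, not taken from the test.

from torchrl.trainers.algorithms.configs.modules import (
    MLPConfig,
    TensorDictModuleConfig,
)

# MLPConfig fields are illustrative; the test builds the config the same way,
# wrapping an MLP config with in_keys/out_keys.
cfg = TensorDictModuleConfig(
    module=MLPConfig(in_features=4, out_features=2),
    in_keys=["observation"],
    out_keys=["action"],
)

# Hydra's instantiate() is now routed through the factory helper.
assert cfg._target_ == (
    "torchrl.trainers.algorithms.configs.modules._make_tensordict_module"
)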

torchrl/objectives/value/advantages.py

Lines changed: 2 additions & 0 deletions
@@ -1324,6 +1324,8 @@ class GAE(ValueEstimatorBase):
         `"is_init"` of the `"next"` entry, such that trajectories are well separated both for root and `"next"` data.
     """
 
+    value_network: TensorDictModule | None
+
     def __init__(
         self,
         *,
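
The added line is a class-level type annotation: it documents that GAE.value_network may be None without changing runtime behavior. A minimal sketch to observe it:

from torchrl.objectives.value import GAE

# Class-level annotations land in __annotations__; depending on whether the
# module defers annotation evaluation, this is the type itself or the string
# "TensorDictModule | None".
print(GAE.__annotations__["value_network"])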

torchrl/trainers/algorithms/configs/modules.py

Lines changed: 43 additions & 3 deletions
@@ -222,7 +222,9 @@ class TensorDictModuleConfig(ModelConfig):
     """
 
     module: MLPConfig = MISSING
-    _target_: str = "tensordict.nn.TensorDictModule"
+    _target_: str = (
+        "torchrl.trainers.algorithms.configs.modules._make_tensordict_module"
+    )
     _partial_: bool = False
 
     def __post_init__(self) -> None:
@@ -292,6 +294,30 @@ def __post_init__(self) -> None:
         super().__post_init__()
 
 
+def _make_tensordict_module(*args, **kwargs):
+    """Helper function to create a TensorDictModule."""
+    from hydra.utils import instantiate
+    from tensordict.nn import TensorDictModule
+
+    module = kwargs.pop("module")
+    shared = kwargs.pop("shared", False)
+
+    # Instantiate the module if it's a config
+    if hasattr(module, "_target_"):
+        module = instantiate(module)
+    elif callable(module) and hasattr(module, "func"):  # partial function
+        module = module()
+
+    # Create the TensorDictModule
+    tensordict_module = TensorDictModule(module, **kwargs)
+
+    # Apply share_memory if needed
+    if shared:
+        tensordict_module = tensordict_module.share_memory()
+
+    return tensordict_module
+
+
 def _make_tanh_normal_model(*args, **kwargs):
     """Helper function to create a TanhNormal model with ProbabilisticTensorDictSequential."""
     from hydra.utils import instantiate
@@ -351,10 +377,24 @@ def _make_tanh_normal_model(*args, **kwargs):
 
 def _make_value_model(*args, **kwargs):
     """Helper function to create a ValueOperator with the given network."""
+    from hydra.utils import instantiate
+
     from torchrl.modules import ValueOperator
 
     network = kwargs.pop("network")
     shared = kwargs.pop("shared", False)
+
+    # Instantiate the network if it's a config
+    if hasattr(network, "_target_"):
+        network = instantiate(network)
+    elif callable(network) and hasattr(network, "func"):  # partial function
+        network = network()
+
+    # Create the ValueOperator
+    value_operator = ValueOperator(network, **kwargs)
+
+    # Apply share_memory if needed
     if shared:
-        network = network.share_memory()
-    return ValueOperator(network, **kwargs)
+        value_operator = value_operator.share_memory()
+
+    return value_operator
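
In normal use the new _make_tensordict_module factory is called by hydra.utils.instantiate when a TensorDictModuleConfig is resolved. Calling it directly, as in the sketch below, shows what it does with an already-built module; the nn.Linear is just a stand-in.

import torch

from torchrl.trainers.algorithms.configs.modules import _make_tensordict_module

# A plain nn.Module has neither a `_target_` attribute nor a `func` attribute,
# so the helper skips both instantiate() branches and wraps it directly.
td_module = _make_tensordict_module(
    module=torch.nn.Linear(4, 2),
    in_keys=["observation"],
    out_keys=["action"],
    shared=False,  # True would call .share_memory() on the wrapped module
)
print(td_module)

The same pattern is mirrored in _make_value_model, which now also instantiates config or partial inputs before wrapping them in a ValueOperator, and only then applies share_memory.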

torchrl/trainers/algorithms/ppo.py

Lines changed: 102 additions & 13 deletions
@@ -66,6 +66,42 @@ class PPOTrainer(Trainer):
 
     Logging can be configured via constructor parameters to enable/disable specific metrics.
 
+    Args:
+        collector (DataCollectorBase): The data collector for gathering training data.
+        total_frames (int): Total number of frames to train for.
+        frame_skip (int): Frame skip value for the environment.
+        optim_steps_per_batch (int): Number of optimization steps per batch.
+        loss_module (LossModule): The loss module for computing policy and value losses.
+        optimizer (optim.Optimizer, optional): The optimizer for training.
+        logger (Logger, optional): Logger for tracking training metrics.
+        clip_grad_norm (bool, optional): Whether to clip gradient norms. Default: True.
+        clip_norm (float, optional): Maximum gradient norm value.
+        progress_bar (bool, optional): Whether to show a progress bar. Default: True.
+        seed (int, optional): Random seed for reproducibility.
+        save_trainer_interval (int, optional): Interval for saving trainer state. Default: 10000.
+        log_interval (int, optional): Interval for logging metrics. Default: 10000.
+        save_trainer_file (str | pathlib.Path, optional): File path for saving trainer state.
+        num_epochs (int, optional): Number of epochs per batch. Default: 4.
+        replay_buffer (ReplayBuffer, optional): Replay buffer for storing data.
+        batch_size (int, optional): Batch size for optimization.
+        gamma (float, optional): Discount factor for GAE. Default: 0.9.
+        lmbda (float, optional): Lambda parameter for GAE. Default: 0.99.
+        enable_logging (bool, optional): Whether to enable logging. Default: True.
+        log_rewards (bool, optional): Whether to log rewards. Default: True.
+        log_actions (bool, optional): Whether to log actions. Default: True.
+        log_observations (bool, optional): Whether to log observations. Default: False.
+        async_collection (bool, optional): Whether to use async collection. Default: False.
+        add_gae (bool, optional): Whether to add GAE computation. Default: True.
+        gae (Callable, optional): Custom GAE module. If None and add_gae is True, a default GAE will be created.
+        weight_update_map (dict[str, str], optional): Mapping from collector destination paths (keys in
+            collector's weight_sync_schemes) to trainer source paths. Required if collector has
+            weight_sync_schemes configured. Example: {"policy": "loss_module.actor_network",
+            "replay_buffer.transforms[0]": "loss_module.critic_network"}
+        log_timings (bool, optional): If True, automatically register a LogTiming hook to log
+            timing information for all hooks to the logger (e.g., wandb, tensorboard).
+            Timing metrics will be logged with prefix "time/" (e.g., "time/hook/UpdateWeights").
+            Default is False.
+
     Examples:
         >>> # Basic usage with manual configuration
        >>> from torchrl.trainers.algorithms.ppo import PPOTrainer
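
As a quick check of the constructor surface documented above, a small introspection sketch (no trainer is built):

import inspect

from torchrl.trainers.algorithms.ppo import PPOTrainer

# The four parameters this commit adds to PPOTrainer.__init__, with defaults.
sig = inspect.signature(PPOTrainer.__init__)
for name in ("add_gae", "gae", "weight_update_map", "log_timings"):
    print(sig.parameters[name])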
@@ -111,6 +147,10 @@ def __init__(
         log_actions: bool = True,
         log_observations: bool = False,
         async_collection: bool = False,
+        add_gae: bool = True,
+        gae: Callable[[TensorDictBase], TensorDictBase] | None = None,
+        weight_update_map: dict[str, str] | None = None,
+        log_timings: bool = False,
     ) -> None:
         warnings.warn(
             "PPOTrainer is an experimental/prototype feature. The API may change in future versions. "
@@ -135,17 +175,21 @@ def __init__(
             save_trainer_file=save_trainer_file,
             num_epochs=num_epochs,
             async_collection=async_collection,
+            log_timings=log_timings,
         )
         self.replay_buffer = replay_buffer
         self.async_collection = async_collection
 
-        gae = GAE(
-            gamma=gamma,
-            lmbda=lmbda,
-            value_network=self.loss_module.critic_network,
-            average_gae=True,
-        )
-        self.register_op("pre_epoch", gae)
+        if add_gae and gae is None:
+            gae = GAE(
+                gamma=gamma,
+                lmbda=lmbda,
+                value_network=self.loss_module.critic_network,
+                average_gae=True,
+            )
+            self.register_op("pre_epoch", gae)
+        elif not add_gae and gae is not None:
+            raise ValueError("gae must not be provided if add_gae is False")
 
         if (
             not self.async_collection
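
The default hook registered here is an ordinary GAE estimator built from the trainer's gamma/lmbda and the loss module's critic; passing gae=... together with add_gae=False now raises a ValueError. A standalone sketch of an equivalent estimator, with toy shapes and a placeholder value network standing in for loss_module.critic_network:

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.objectives.value import GAE

# Placeholder critic; the trainer uses loss_module.critic_network instead.
value_net = TensorDictModule(
    torch.nn.Linear(4, 1), in_keys=["observation"], out_keys=["state_value"]
)
gae = GAE(gamma=0.9, lmbda=0.99, value_network=value_net, average_gae=True)

# One rollout of 10 steps laid out as [batch, time].
data = TensorDict(
    {
        "observation": torch.randn(1, 10, 4),
        "next": {
            "observation": torch.randn(1, 10, 4),
            "reward": torch.randn(1, 10, 1),
            "done": torch.zeros(1, 10, 1, dtype=torch.bool),
            "terminated": torch.zeros(1, 10, 1, dtype=torch.bool),
        },
    },
    batch_size=[1, 10],
)
gae(data)  # writes "advantage" and "value_target" into the tensordict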
@@ -167,16 +211,61 @@ def __init__(
             )
 
         if not self.async_collection:
+            # rb has been extended by the collector
             self.register_op("pre_epoch", rb_trainer.extend)
             self.register_op("process_optim_batch", rb_trainer.sample)
             self.register_op("post_loss", rb_trainer.update_priority)
 
-        policy_weights_getter = partial(
-            TensorDict.from_module, self.loss_module.actor_network
-        )
-        update_weights = UpdateWeights(
-            self.collector, 1, policy_weights_getter=policy_weights_getter
-        )
+        # Set up weight updates
+        # Validate weight_update_map if collector has weight_sync_schemes
+        if (
+            hasattr(self.collector, "_weight_sync_schemes")
+            and self.collector._weight_sync_schemes
+        ):
+            if weight_update_map is None:
+                raise ValueError(
+                    "Collector has weight_sync_schemes configured, but weight_update_map was not provided. "
+                    f"Please provide a mapping for all destinations: {list(self.collector._weight_sync_schemes.keys())}"
+                )
+
+            # Validate that all scheme destinations are covered in the map
+            scheme_destinations = set(self.collector._weight_sync_schemes.keys())
+            map_destinations = set(weight_update_map.keys())
+
+            if scheme_destinations != map_destinations:
+                missing = scheme_destinations - map_destinations
+                extra = map_destinations - scheme_destinations
+                error_msg = "weight_update_map does not match collector's weight_sync_schemes.\n"
+                if missing:
+                    error_msg += f" Missing destinations: {missing}\n"
+                if extra:
+                    error_msg += f" Extra destinations: {extra}\n"
+                raise ValueError(error_msg)
+
+            # Use the weight_update_map approach
+            update_weights = UpdateWeights(
+                self.collector,
+                1,
+                weight_update_map=weight_update_map,
+                trainer=self,
+            )
+        else:
+            # Fall back to legacy approach for backward compatibility
+            if weight_update_map is not None:
+                warnings.warn(
+                    "weight_update_map was provided but collector has no weight_sync_schemes. "
+                    "Ignoring weight_update_map and using legacy policy_weights_getter.",
+                    UserWarning,
+                    stacklevel=2,
+                )
+
+            policy_weights_getter = partial(
+                TensorDict.from_module, self.loss_module.actor_network
+            )
+            update_weights = UpdateWeights(
+                self.collector, 1, policy_weights_getter=policy_weights_getter
+            )
+
         self.register_op("post_steps", update_weights)
 
         # Store logging configuration
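
The weight_update_map expected here must mirror the collector's weight_sync_schemes keys exactly; any missing or extra destination raises a ValueError at construction time. A sketch of a valid map, reusing the paths from the docstring example above:

# Keys are collector destinations (the weight_sync_schemes keys); values are
# attribute paths resolved on the trainer.
weight_update_map = {
    "policy": "loss_module.actor_network",
    "replay_buffer.transforms[0]": "loss_module.critic_network",
}

When the collector has no weight_sync_schemes, a provided map is ignored with a UserWarning and the legacy policy_weights_getter path is used instead, as the else branch above shows.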
