diff --git a/test/_utils_internal.py b/test/_utils_internal.py
index 333c72a8297..5efa0592068 100644
--- a/test/_utils_internal.py
+++ b/test/_utils_internal.py
@@ -21,13 +21,7 @@
 from tensordict.nn import TensorDictModuleBase
 from torch import nn, vmap
 
-from torchrl._utils import (
-    implement_for,
-    logger,
-    logger as torchrl_logger,
-    RL_WARNINGS,
-    seed_generator,
-)
+from torchrl._utils import implement_for, logger, RL_WARNINGS, seed_generator
 from torchrl.data.utils import CloudpickleWrapper
 from torchrl.envs import MultiThreadedEnv, ObservationNorm
 from torchrl.envs.batched_envs import ParallelEnv, SerialEnv
@@ -230,7 +224,7 @@ def f_retry(*args, **kwargs):
                 return f(*args, **kwargs)
             except ExceptionToCheck as e:
                 msg = "%s, Retrying in %d seconds..." % (str(e), mdelay)
-                torchrl_logger.info(msg)
+                logger.info(msg)
                 time.sleep(mdelay)
                 mtries -= 1
             try:
diff --git a/torchrl/__init__.py b/torchrl/__init__.py
index 5ea95ae26a8..85aaf6003e7 100644
--- a/torchrl/__init__.py
+++ b/torchrl/__init__.py
@@ -62,7 +62,6 @@
     timeit,
 )
 
-torchrl_logger = logger
 
 # Filter warnings in subprocesses: True by default given the multiple optional
 # deps of the library. This can be turned on via `torchrl.filter_warnings_subprocess = False`.
@@ -121,5 +121,4 @@ def _inv(self):
     "set_auto_unwrap_transformed_env",
     "timeit",
     "logger",
-    "torchrl_logger",
 ]
diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py
index d93cef3375d..182e510145a 100644
--- a/torchrl/trainers/trainers.py
+++ b/torchrl/trainers/trainers.py
@@ -18,7 +18,7 @@
 import numpy as np
 import torch.nn
 
-from tensordict import NestedKey, pad, TensorDictBase
+from tensordict import NestedKey, pad, TensorDict, TensorDictBase
 from tensordict._tensorcollection import TensorCollection
 from tensordict.nn import TensorDictModule
 from tensordict.utils import expand_right
@@ -29,6 +29,7 @@
     KeyDependentDefaultDict,
     logger as torchrl_logger,
     RL_WARNINGS,
+    timeit,
     VERBOSE,
 )
 from torchrl.collectors import DataCollectorBase
@@ -150,6 +151,10 @@ class Trainer:
             This will only work if the replay buffer is registed within the data
             collector. If using this, the UTD ratio (Update to Data) will be logged under
             the key "utd_ratio". Default is False.
+        log_timings (bool, optional): If True, automatically register a LogTiming hook
+            that logs per-hook timing information to the trainer's logger (e.g., wandb,
+            tensorboard). Timing metrics are logged with the prefix "time/" (e.g.,
+            "time/hook/UpdateWeights"). Default is False.
     """
 
     @classmethod
@@ -186,6 +191,7 @@ def __init__(
         save_trainer_file: str | pathlib.Path | None = None,
         num_epochs: int = 1,
         async_collection: bool = False,
+        log_timings: bool = False,
     ) -> None:
 
         # objects
@@ -267,6 +273,10 @@ def __init__(
             optimizer_hook = OptimizerHook(self.optimizer)
             optimizer_hook.register(self)
 
+        if log_timings:
+            log_timing_hook = LogTiming(prefix="time", percall=True, erase=False)
+            log_timing_hook.register(self)
+
     def register_module(self, module_name: str, module: Any) -> None:
         if module_name in self._modules:
             raise RuntimeError(
@@ -274,6 +284,36 @@ def register_module(self, module_name: str, module: Any) -> None:
             )
         self._modules[module_name] = module
 
+    def _wrap_hook_with_timing(
+        self, op: Callable, hook_name: str | None = None
+    ) -> Callable:
+        """Wrap a hook with timing measurement.
+
+        Args:
+            op: The hook/operation to wrap.
+            hook_name: Optional name for the hook. If not provided, it is inferred from ``op``.
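+                For example, a hook instance with no ``__name__`` attribute that is
+                not a registered module falls back to its class name, e.g. ``"hook/UpdateWeights"``.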
+
+        Returns:
+            A wrapped version of the hook that measures execution time.
+        """
+        if hook_name is None:
+            hook_name = getattr(
+                op,
+                "__name__",
+                op.__class__.__name__ if hasattr(op, "__class__") else "unknown_hook",
+            )
+
+        def timed_hook(*args, **kwargs):
+            with timeit(f"hook/{hook_name}"):
+                return op(*args, **kwargs)
+
+        # Preserve original attributes for debugging
+        timed_hook.__wrapped__ = op
+        timed_hook.__name__ = hook_name
+        return timed_hook
+
     def _get_state(self):
         if _CKPT_BACKEND == "torchsnapshot":
             state = StateDict(
@@ -397,79 +435,89 @@ def register_op(
         op: Callable,
         **kwargs,
     ) -> None:
+        # Wrap hook with timing for performance monitoring
+        # Get hook name from registered modules if available
+        hook_name = None
+        for name, module in self._modules.items():
+            if module is op or (callable(module) and module.__call__ == op):
+                hook_name = name
+                break
+
+        timed_op = self._wrap_hook_with_timing(op, hook_name)
+
         if dest == "batch_process":
             _check_input_output_typehint(
                 op, input=TensorDictBase, output=TensorDictBase
             )
-            self._batch_process_ops.append((op, kwargs))
+            self._batch_process_ops.append((timed_op, kwargs))
         elif dest == "pre_optim_steps":
             _check_input_output_typehint(op, input=None, output=None)
-            self._pre_optim_ops.append((op, kwargs))
+            self._pre_optim_ops.append((timed_op, kwargs))
         elif dest == "process_optim_batch":
             _check_input_output_typehint(
                 op, input=TensorDictBase, output=TensorDictBase
             )
-            self._process_optim_batch_ops.append((op, kwargs))
+            self._process_optim_batch_ops.append((timed_op, kwargs))
         elif dest == "post_loss":
             _check_input_output_typehint(
                 op, input=TensorDictBase, output=TensorDictBase
             )
-            self._post_loss_ops.append((op, kwargs))
+            self._post_loss_ops.append((timed_op, kwargs))
         elif dest == "optimizer":
             _check_input_output_typehint(
                 op, input=[TensorDictBase, bool, float, int], output=TensorDictBase
             )
-            self._optimizer_ops.append((op, kwargs))
+            self._optimizer_ops.append((timed_op, kwargs))
         elif dest == "post_steps":
             _check_input_output_typehint(op, input=None, output=None)
-            self._post_steps_ops.append((op, kwargs))
+            self._post_steps_ops.append((timed_op, kwargs))
         elif dest == "post_optim":
             _check_input_output_typehint(op, input=None, output=None)
-            self._post_optim_ops.append((op, kwargs))
+            self._post_optim_ops.append((timed_op, kwargs))
         elif dest == "pre_steps_log":
             _check_input_output_typehint(
                 op, input=TensorDictBase, output=tuple[str, float]
             )
-            self._pre_steps_log_ops.append((op, kwargs))
+            self._pre_steps_log_ops.append((timed_op, kwargs))
         elif dest == "post_steps_log":
             _check_input_output_typehint(
                 op, input=TensorDictBase, output=tuple[str, float]
             )
-            self._post_steps_log_ops.append((op, kwargs))
+            self._post_steps_log_ops.append((timed_op, kwargs))
         elif dest == "post_optim_log":
             _check_input_output_typehint(
                 op, input=TensorDictBase, output=tuple[str, float]
             )
-            self._post_optim_log_ops.append((op, kwargs))
+            self._post_optim_log_ops.append((timed_op, kwargs))
         elif dest == "pre_epoch_log":
             _check_input_output_typehint(
                 op, input=TensorDictBase, output=tuple[str, float]
             )
-            self._pre_epoch_log_ops.append((op, kwargs))
+            self._pre_epoch_log_ops.append((timed_op, kwargs))
         elif dest == "post_epoch_log":
             _check_input_output_typehint(
                 op, input=TensorDictBase, output=tuple[str, float]
             )
-            self._post_epoch_log_ops.append((op, kwargs))
+            self._post_epoch_log_ops.append((timed_op, kwargs))
         elif dest == "pre_epoch":
             _check_input_output_typehint(op, input=None, output=None)
-            self._pre_epoch_ops.append((op, kwargs))
+            self._pre_epoch_ops.append((timed_op, kwargs))
         elif dest == "post_epoch":
             _check_input_output_typehint(op, input=None, output=None)
-            self._post_epoch_ops.append((op, kwargs))
+            self._post_epoch_ops.append((timed_op, kwargs))
         else:
             raise RuntimeError(
@@ -606,7 +654,7 @@ def train(self):
 
         if self.async_collection:
             self.collector.start()
-            while not self.collector.getattr_rb("write_count"):
+            while self.collector.getattr_rb("write_count") == 0:
                 time.sleep(0.1)
 
             # Create async iterator that monitors write_count progress
@@ -1083,6 +1131,83 @@ def __call__(self, *args, **kwargs):
             torch.cuda.empty_cache()
 
 
+class LogTiming(TrainerHookBase):
+    """Hook to log timing information collected by timeit context managers.
+
+    This hook extracts timing data from the global timeit registry and logs it
+    to the trainer's logger (e.g., wandb, tensorboard). It is useful for profiling
+    different parts of the training loop.
+
+    Args:
+        prefix (str, optional): Prefix to add to timing metric names.
+            Default is "time".
+        percall (bool, optional): If True, log average time per call.
+            If False, log total time. Default is True.
+        erase (bool, optional): If True, reset timing data after each log.
+            Default is False.
+
+    Examples:
+        >>> # Log timing data after each optimization step
+        >>> log_timing = LogTiming(prefix="time", percall=True)
+        >>> trainer.register_op("post_optim_log", log_timing)
+
+        >>> # Log timing data after each batch collection
+        >>> log_timing = LogTiming(prefix="time", erase=True)
+        >>> trainer.register_op("post_steps_log", log_timing)
+
+    Note:
+        This hook works with timing data collected using the `timeit` context manager.
+        For example, hooks registered with `register_op` are automatically wrapped
+        with timing measurement.
+    """
+
+    def __init__(
+        self,
+        prefix: str = "time",
+        percall: bool = True,
+        erase: bool = False,
+    ):
+        self.prefix = prefix
+        self.percall = percall
+        self.erase = erase
+
+    def __call__(self, batch: TensorDictBase | None = None) -> dict:
+        """Extract timing data and return it as a dict for logging.
+
+        Args:
+            batch: The batch (unused, but required by the hook signature).
+
+        Returns:
+            Dictionary of timing metrics in the format ``{metric_name: value}``.
+        """
+        timing_dict = timeit.todict(percall=self.percall, prefix=self.prefix)
+
+        if self.erase:
+            timeit.erase()
+
+        return timing_dict
+
+    def state_dict(self) -> dict[str, Any]:
+        """Return state dict for checkpointing."""
+        return {
+            "prefix": self.prefix,
+            "percall": self.percall,
+            "erase": self.erase,
+        }
+
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        """Load state dict from checkpoint."""
+        self.prefix = state_dict.get("prefix", "time")
+        self.percall = state_dict.get("percall", True)
+        self.erase = state_dict.get("erase", False)
+
+    def register(self, trainer: Trainer, name: str | None = None):
+        if name is None:
+            name = "log_timing"
+        trainer.register_module(name, self)
+        trainer.register_op("post_steps_log", self)
+
+
 class LogScalar(TrainerHookBase):
     """Generic scalar logger hook for any tensor values in the batch.
 
@@ -1635,6 +1760,33 @@ def __init__(
         )
 
 
+def _resolve_module(trainer: Trainer, path: str):
+    """Resolve a module from a trainer using a string path.
+
+    Args:
+        trainer (Trainer): The trainer instance to resolve from.
+        path (str): A dot-separated path to the module (e.g., "loss_module.actor_network").
+
+    Returns:
+        The resolved module.
+
+    Raises:
+        AttributeError: If the path cannot be resolved.
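+
+    Note:
+        Only dotted attribute access is supported; bracketed indices such as
+        ``transforms[0]`` may appear only in destination keys, not in source paths.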
+
+    Examples:
+        >>> module = _resolve_module(trainer, "loss_module.actor_network")
+        >>> module = _resolve_module(trainer, "collector.policy")
+    """
+    obj = trainer
+    for attr in path.split("."):
+        obj = getattr(obj, attr)
+    return obj
+
+
 class UpdateWeights(TrainerHookBase):
     """A collector weights update hook class.
 
@@ -1648,9 +1796,33 @@ class UpdateWeights(TrainerHookBase):
             must be synced.
         update_weights_interval (int): Interval (in terms of number of batches
             collected) where the sync must take place.
+        policy_weights_getter (Callable, optional): A callable that returns the policy
+            weights to sync. Used for backward compatibility. If both this and
+            weight_update_map are provided, weight_update_map takes precedence.
+        weight_update_map (dict[str, str], optional): A mapping from destination paths
+            (keys in the collector's weight_sync_schemes) to source paths on the trainer.
+            Example: {"policy": "loss_module.actor_network",
+            "replay_buffer.transforms[0]": "loss_module.critic_network"}
+        trainer (Trainer, optional): The trainer instance, required when using
+            weight_update_map to resolve source paths.
 
     Examples:
-        >>> update_weights = UpdateWeights(trainer.collector, T)
+        >>> # Legacy usage with policy_weights_getter
+        >>> update_weights = UpdateWeights(
+        ...     trainer.collector, T,
+        ...     policy_weights_getter=lambda: TensorDict.from_module(policy)
+        ... )
+        >>> trainer.register_op("post_steps", update_weights)
+
+        >>> # New usage with weight_update_map
+        >>> update_weights = UpdateWeights(
+        ...     trainer.collector, T,
+        ...     weight_update_map={
+        ...         "policy": "loss_module.actor_network",
+        ...         "replay_buffer.transforms[0]": "loss_module.critic_network"
+        ...     },
+        ...     trainer=trainer
+        ... )
         >>> trainer.register_op("post_steps", update_weights)
     """
 
@@ -1660,24 +1832,67 @@ def __init__(
         collector: DataCollectorBase,
         update_weights_interval: int,
         policy_weights_getter: Callable[[Any], Any] | None = None,
+        weight_update_map: dict[str, str] | None = None,
+        trainer: Trainer | None = None,
     ):
         self.collector = collector
         self.update_weights_interval = update_weights_interval
         self.counter = 0
         self.policy_weights_getter = policy_weights_getter
+        self.weight_update_map = weight_update_map
+        self.trainer = trainer
+
+        # Validate inputs
+        if weight_update_map is not None and trainer is None:
+            raise ValueError("trainer must be provided when using weight_update_map")
 
     def __call__(self):
         self.counter += 1
         if self.counter % self.update_weights_interval == 0:
-            weights = (
-                self.policy_weights_getter()
-                if self.policy_weights_getter is not None
-                else None
-            )
-            if weights is not None:
-                self.collector.update_policy_weights_(weights)
+            # New approach: use weight_update_map if provided
+            if self.weight_update_map is not None:
+                self._update_with_map()
+            # Legacy approach: use policy_weights_getter
+            else:
+                weights = (
+                    self.policy_weights_getter()
+                    if self.policy_weights_getter is not None
+                    else None
+                )
+                if weights is not None:
+                    self.collector.update_policy_weights_(weights)
+                else:
+                    self.collector.update_policy_weights_()
+
+    def _update_with_map(self):
+        """Update weights using the weight_update_map."""
+        from torchrl.weight_update.weight_sync_schemes import WeightStrategy
+
+        weights_dict = {}
+
+        for destination, source_path in self.weight_update_map.items():
+            # Resolve the source module from the trainer
+            source_module = _resolve_module(self.trainer, source_path)
+
+            # Get the scheme for this destination to know the extraction strategy
+            if (
+                hasattr(self.collector, "_weight_sync_schemes")
+                and self.collector._weight_sync_schemes
+                and destination in self.collector._weight_sync_schemes
+            ):
+                scheme = self.collector._weight_sync_schemes[destination]
+                strategy = WeightStrategy(extract_as=scheme.strategy)
+                weights = strategy.extract_weights(source_module)
             else:
-                self.collector.update_policy_weights_()
+                # Fallback: use TensorDict extraction if no scheme found
+                weights = TensorDict.from_module(source_module)
+
+            weights_dict[destination] = weights
+
+        # Send all weights atomically
+        self.collector.update_policy_weights_(weights_dict=weights_dict)
 
     def register(self, trainer: Trainer, name: str = "update_weights"):
         trainer.register_module(name, self)
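
Usage sketch (illustrative only, not part of the diff): how `log_timings` and
`weight_update_map` could be wired together. The `collector`, `loss_module`, and
`optimizer` objects are assumed to exist, with the collector exposing a "policy"
weight-sync scheme and the loss module an `actor_network` attribute; the remaining
arguments are standard `Trainer` constructor parameters:

    from torchrl.trainers import Trainer, UpdateWeights

    trainer = Trainer(
        collector=collector,
        total_frames=1_000_000,
        optim_steps_per_batch=1,
        loss_module=loss_module,
        optimizer=optimizer,
        log_timings=True,  # registers a LogTiming hook on "post_steps_log"
    )
    # Every hook registered via register_op is now timed; metrics are logged
    # under keys such as "time/hook/UpdateWeights".
    update_weights = UpdateWeights(
        trainer.collector,
        update_weights_interval=10,
        weight_update_map={"policy": "loss_module.actor_network"},
        trainer=trainer,
    )
    trainer.register_op("post_steps", update_weights)
    trainer.train()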