diff --git a/src/neuronx_distributed/lightning/module.py b/src/neuronx_distributed/lightning/module.py index 4c09829..0046cd6 100644 --- a/src/neuronx_distributed/lightning/module.py +++ b/src/neuronx_distributed/lightning/module.py @@ -4,6 +4,7 @@ import torch import torch_xla.core.xla_model as xm from lightning_utilities.core.apply_func import apply_to_collection +from lightning_utilities.core.rank_zero import rank_zero_warn from pytorch_lightning import LightningModule from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import ( _FxValidator,