Make fsdp strategy compatible with latest release (#226)
* Make fsdp strategy compatible with latest release

Signed-off-by: Jerome <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* ruff fixes

Signed-off-by: Jerome <[email protected]>

* Update hpu-tests.yml

---------

Signed-off-by: Jerome <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
jerome-habana and pre-commit-ci[bot] authored Aug 23, 2024
1 parent f70bb5b commit e4c377e
Showing 4 changed files with 18 additions and 24 deletions.
26 changes: 7 additions & 19 deletions src/lightning_habana/pytorch/strategies/fsdp.py
@@ -31,6 +31,7 @@
         _setup_activation_checkpointing,
     )
     from lightning.fabric.utilities.distributed import group as _group
+    from lightning.fabric.utilities.init import _has_meta_device_parameters_or_buffers
     from lightning.fabric.utilities.types import ReduceOp
     from lightning.pytorch.plugins.precision import Precision
     from lightning.pytorch.strategies.fsdp import FSDPStrategy
@@ -45,6 +46,7 @@
         _setup_activation_checkpointing,
     )
     from lightning_fabric.utilities.distributed import group as _group
+    from lightning_fabric.utilities.init import _has_meta_device_parameters_or_buffers
     from lightning_fabric.utilities.types import ReduceOp
     from pytorch_lightning.plugins.precision import Precision
     from pytorch_lightning.strategies.fsdp import FSDPStrategy
@@ -57,10 +59,7 @@
 from lightning_habana.pytorch.plugins.io_plugin import HPUCheckpointIO
 from lightning_habana.pytorch.strategies.parallel import HPUParallelStrategy, _hpu_broadcast_object_list
 from lightning_habana.utils.hpu_distributed import _sync_ddp_if_available
-from lightning_habana.utils.imports import _HABANA_FRAMEWORK_AVAILABLE, _LIGHTNING_GREATER_EQUAL_2_3_0
-
-if _HABANA_FRAMEWORK_AVAILABLE:
-    import habana_frameworks.torch.distributed.hccl as hpu_dist
+from lightning_habana.utils.imports import _LIGHTNING_GREATER_EQUAL_2_3_0
 
 if TYPE_CHECKING:
     from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision, ShardingStrategy
@@ -154,26 +153,15 @@ def precision_plugin(self, precision_plugin: Optional[HPUFSDPPrecision]) -> None
             )
         self._precision_plugin = precision_plugin
 
-    @override
-    def setup_environment(self) -> None:
-        if self._process_group_backend == "hccl":
-            # this env is used in overrides to check the backend initiated
-            _ws = self.cluster_environment.world_size()
-            _grank = self.cluster_environment.global_rank()
-            _lrank = self.cluster_environment.local_rank()
-            hpu_dist.initialize_distributed_hpu(world_size=_ws, rank=_grank, local_rank=_lrank)
-        super().setup_environment()
-
     def _setup_model(self, model: Module) -> Module:
 
         from torch.distributed.fsdp import FullyShardedDataParallel
 
         if any(isinstance(mod, FullyShardedDataParallel) for mod in model.modules()):
-            # TBD: Enable meta device check once we move to PTL>=2.3 which has HPU fsdo support
-            # if _has_meta_device_parameters_or_buffers(model):
-            #     rank_zero_warn(
-            #         "The model is already wrapped in `FSDP` but there are still parameters on the meta device."
-            #     )
+            if _has_meta_device_parameters_or_buffers(model):
+                rank_zero_warn(
+                    "The model is already wrapped in `FSDP` but there are still parameters on the meta device."
+                )
             if "auto_wrap_policy" in self.kwargs:
                 # The user has wrapped their submodules manually, don't apply the auto wrap policy.
                 rank_zero_warn(
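The previously commented-out meta-device check in `_setup_model` is now active, so an already FSDP-wrapped model with parameters or buffers still on the meta device triggers a warning. A minimal sketch of what that helper reports, assuming a Lightning release that ships `_has_meta_device_parameters_or_buffers`; the module shapes below are illustrative:

import torch
from torch import nn
from lightning.fabric.utilities.init import _has_meta_device_parameters_or_buffers

# Parameters created under the meta device have shapes but no allocated storage yet.
with torch.device("meta"):
    meta_model = nn.Linear(8, 8)

materialized_model = nn.Linear(8, 8)  # ordinary CPU-initialized weights

print(_has_meta_device_parameters_or_buffers(meta_model))          # True -> would emit the warning above
print(_has_meta_device_parameters_or_buffers(materialized_model))  # False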
5 changes: 0 additions & 5 deletions tests/test_fabric/test_fsdp.py
@@ -25,11 +25,9 @@
 
 if module_available("lightning"):
     from lightning.fabric import Fabric
-    from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
     from lightning.fabric.wrappers import _FabricOptimizer
 elif module_available("pytorch_lightning"):
     from lightning_fabric import Fabric
-    from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
     from lightning_fabric.wrappers import _FabricOptimizer
 
 from lightning_habana.fabric.accelerator import HPUAccelerator
@@ -312,9 +310,6 @@ def test_rewrap_warnings(arg_hpus):
     assert not isinstance(model._forward_module, FullyShardedDataParallel)
     assert isinstance(model._forward_module[2], FullyShardedDataParallel)
 
-    if not _TORCH_GREATER_EQUAL_2_1:
-        return
-
     with fabric.init_module(empty_init=True):
         model = torch.nn.Sequential(
             torch.nn.Linear(1, 1), torch.nn.ReLU(), wrap(torch.nn.Linear(1, 1), device_id=device_hpu)
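With the `_TORCH_GREATER_EQUAL_2_1` guard gone, the rewrap test always exercises the meta-device initialization path. A rough sketch of that pattern, assuming `fabric` is a Fabric instance already launched with the HPU accelerator and an FSDP strategy as in the test file:

import torch

# empty_init=True defers weight materialization; parameters start on the meta device
# and are expected to be materialized when the strategy shards the module.
with fabric.init_module(empty_init=True):
    model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU(), torch.nn.Linear(1, 1))

model = fabric.setup_module(model)  # illustrative: lets the FSDP strategy wrap and materialize the model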
1 change: 1 addition & 0 deletions tests/test_pytorch/test_accelerator.py
@@ -188,6 +188,7 @@ def test_accelerator_with_single_device():
     assert isinstance(trainer.accelerator, HPUAccelerator)
 
 
+@pytest.mark.standalone()
 @pytest.mark.skipif(device_count() <= 1, reason="Test requires multiple HPU devices")
 def test_accelerator_with_multiple_devices(arg_hpus):
     if arg_hpus <= 1:
10 changes: 10 additions & 0 deletions tests/test_pytorch/test_fsdp.py
@@ -184,6 +184,7 @@ def test_fsdp_strategy_sync_batchnorm(tmpdir, arg_hpus):
 
     trainer = Trainer(
         accelerator=HPUAccelerator(),
+        devices=arg_hpus,
         strategy=HPUFSDPStrategy(
             parallel_devices=[torch.device("hpu")] * arg_hpus,
             cpu_offload=config,
@@ -204,6 +205,7 @@ def test_fsdp_simple_model(strategy, arg_hpus):
 
     trainer = Trainer(
         accelerator=HPUAccelerator(),
+        devices=arg_hpus,
         strategy=HPUFSDPStrategy(
             parallel_devices=[torch.device("hpu")] * arg_hpus,
             sharding_strategy=strategy,
@@ -225,6 +227,7 @@ def test_fsdp_simple_model_activation_cp(strategy, arg_hpus):
 
     trainer = Trainer(
         accelerator=HPUAccelerator(),
+        devices=arg_hpus,
         num_sanity_val_steps=0,
         strategy=HPUFSDPStrategy(
             parallel_devices=[torch.device("hpu")] * arg_hpus,
@@ -246,6 +249,7 @@ def test_fsdp_simple_model_activation_cp_mixed_precision(strategy, arg_hpus):
 
     trainer = Trainer(
         accelerator=HPUAccelerator(),
+        devices=arg_hpus,
         num_sanity_val_steps=0,
         strategy=HPUFSDPStrategy(
             parallel_devices=[torch.device("hpu")] * arg_hpus,
@@ -261,6 +265,7 @@ def test_fsdp_simple_model_activation_cp_mixed_precision(strategy, arg_hpus):
 
 
 @pytest.mark.skipif(HPUAccelerator.auto_device_count() <= 1, reason="Test requires multiple HPU devices.")
+@pytest.mark.standalone()
 def test_fsdp_strategy_simple_model_compile(tmpdir, arg_hpus):
     """Test to ensure that sync_batchnorm works when using FSDP and HPU."""
     if arg_hpus <= 1:
@@ -273,6 +278,7 @@ def test_fsdp_strategy_simple_model_compile(tmpdir, arg_hpus):
     trainer = Trainer(
         default_root_dir=tmpdir,
         accelerator=HPUAccelerator(),
+        devices=arg_hpus,
         strategy=HPUFSDPStrategy(
             parallel_devices=[torch.device("hpu")] * arg_hpus,
             cpu_offload=config,
@@ -308,6 +314,7 @@ def training_step(self, batch, batch_idx):
     trainer = Trainer(
         default_root_dir=tmpdir,
         accelerator=HPUAccelerator(),
+        devices=arg_hpus,
         strategy=HPUFSDPStrategy(
             parallel_devices=[torch.device("hpu")] * arg_hpus,
             cpu_offload=True,
@@ -376,6 +383,7 @@ def test_fsdp_strategy_full_state_dict(tmpdir, wrap_min_params, arg_hpus):
     trainer = Trainer(
         default_root_dir=tmpdir,
         accelerator=HPUAccelerator(),
+        devices=arg_hpus,
         strategy=strategy,
         max_epochs=1,
         barebones=True,
@@ -548,6 +556,7 @@ def test_fsdp_strategy_save_optimizer_states(tmpdir, wrap_min_params, arg_hpus):
     trainer = Trainer(
         default_root_dir=tmpdir,
         accelerator=HPUAccelerator(),
+        devices=arg_hpus,
         strategy=strategy,
         max_epochs=1,
     )
@@ -633,6 +642,7 @@ def test_fsdp_strategy_load_optimizer_states(tmpdir, wrap_min_params, arg_hpus):
     trainer = Trainer(
         default_root_dir=tmpdir,
         accelerator=HPUAccelerator(),
+        devices=arg_hpus,
         strategy=strategy,
         max_epochs=1,
     )
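Every hunk above applies the same fix: the Trainer now receives an explicit `devices=arg_hpus`, kept in sync with the `parallel_devices` list handed to `HPUFSDPStrategy`. A condensed sketch of that construction; the device count and sharding strategy are illustrative, and the import paths assume the package layout used elsewhere in this repository:

import torch
from lightning.pytorch import Trainer
from lightning_habana.pytorch.accelerator import HPUAccelerator
from lightning_habana.pytorch.strategies import HPUFSDPStrategy

hpus = 2  # illustrative device count

trainer = Trainer(
    accelerator=HPUAccelerator(),
    devices=hpus,  # explicit device count, matching parallel_devices below
    strategy=HPUFSDPStrategy(
        parallel_devices=[torch.device("hpu")] * hpus,
        sharding_strategy="FULL_SHARD",  # illustrative; the tests parametrize this
    ),
    max_epochs=1,
)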
