
Commit

fix fsdp strategy pytorch
ankitgola005 committed Sep 11, 2024
1 parent e3f530a commit d3a527d
Showing 2 changed files with 1 addition and 6 deletions.
6 changes: 1 addition & 5 deletions src/lightning_habana/pytorch/strategies/fsdp.py
@@ -53,7 +53,6 @@
 else:
     raise ModuleNotFoundError("You are missing `lightning` or `pytorch-lightning` package, please install it.")

-from lightning_habana.pytorch.accelerator import HPUAccelerator
 from lightning_habana.pytorch.plugins.fsdp_precision import HPUFSDPPrecision
 from lightning_habana.pytorch.plugins.io_plugin import HPUCheckpointIO
 from lightning_habana.pytorch.strategies.parallel import HPUParallelStrategy
@@ -101,10 +100,7 @@ def __init__(
     ) -> None:
         if not _LIGHTNING_GREATER_EQUAL_2_3_0:
             raise OSError("HPUFSDPStrategy requires `lightning>=2.3.0 or pytorch-lightning >= 2.3.0`.")
-        if parallel_devices is None:
-            parallel_devices = [torch.device("hpu", torch.hpu.current_device())] * HPUAccelerator.auto_device_count()
-        elif torch.device("hpu") in parallel_devices:
-            parallel_devices = [torch.device("hpu", torch.hpu.current_device())] * len(parallel_devices)

         super().__init__(
             accelerator=accelerator,
             parallel_devices=parallel_devices,
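A minimal usage sketch (not part of this commit; anything outside the diff above is an assumption): with the device-list bootstrapping removed from HPUFSDPStrategy.__init__, a caller that still wants an explicit HPU device list can build it the same way the removed lines did and pass it in via parallel_devices, or leave device resolution to the accelerator/Trainer setup.

    import torch
    from lightning_habana.pytorch.accelerator import HPUAccelerator
    from lightning_habana.pytorch.strategies.fsdp import HPUFSDPStrategy

    # Assumes the Habana PyTorch bridge is installed so torch.hpu is available;
    # the device construction and device count mirror the lines removed in
    # this commit.
    devices = [torch.device("hpu", torch.hpu.current_device())] * HPUAccelerator.auto_device_count()
    strategy = HPUFSDPStrategy(parallel_devices=devices)
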
1 change: 0 additions & 1 deletion src/lightning_habana/pytorch/strategies/parallel.py
@@ -137,7 +137,6 @@ def setup_distributed(self) -> None:
         _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)

     def _get_process_group_backend(self) -> str:
-        assert self.root_device.type == "hpu"
         return "hccl"

     def set_world_ranks(self) -> None:
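For reference, a reconstruction of the method after this commit (taken directly from the hunk above): the backend is returned unconditionally, without asserting that the root device is an HPU.

    def _get_process_group_backend(self) -> str:
        # HCCL is the collective-communications backend used for HPU devices.
        return "hccl"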
