
Commit 9a51fe2
Correct mypy errors
Signed-off-by: Jerome <[email protected]>
jerome-habana committed Sep 3, 2024
1 parent 846e929 commit 9a51fe2
Showing 2 changed files with 10 additions and 8 deletions.
9 changes: 5 additions & 4 deletions src/lightning_habana/fabric/strategies/fsdp.py
@@ -121,6 +121,7 @@ def __init__(
             state_dict_type=state_dict_type,
             **kwargs,
         )
+        self._parallel_devices = parallel_devices
 
     def setup_environment(self) -> None:
         if self._process_group_backend == "hccl":
@@ -195,12 +196,12 @@ def setup_module(self, module: Module) -> Module:
         return module
 
     def setup(self, trainer: "pl.Trainer") -> None:
-        if self.parallel_devices is None:
-            self.parallel_devices = [
+        if self._parallel_devices is None:
+            self._parallel_devices = [
                 torch.device("hpu", torch.hpu.current_device())
             ] * HPUAccelerator.auto_device_count()
-        elif torch.device("hpu") in self.parallel_devices:
-            self.parallel_devices = [torch.device("hpu", torch.hpu.current_device())] * len(self.parallel_devices)
+        elif torch.device("hpu") in self._parallel_devices:
+            self._parallel_devices = [torch.device("hpu", torch.hpu.current_device())] * len(self._parallel_devices)
         self.model_to_device()
         super().setup(trainer)

9 changes: 5 additions & 4 deletions src/lightning_habana/pytorch/strategies/fsdp.py
@@ -121,6 +121,7 @@ def __init__(
             state_dict_type=state_dict_type,
             **kwargs,
         )
+        self._parallel_devices = parallel_devices
 
     @property
     def mixed_precision_config(self) -> Optional["MixedPrecision"]:
@@ -181,12 +182,12 @@ def _setup_model(self, model: Module) -> Module:
         return model
 
     def setup(self, trainer: "pl.Trainer") -> None:
-        if self.parallel_devices is None:
-            self.parallel_devices = [
+        if self._parallel_devices is None:
+            self._parallel_devices = [
                 torch.device("hpu", torch.hpu.current_device())
             ] * HPUAccelerator.auto_device_count()
-        elif torch.device("hpu") in self.parallel_devices:
-            self.parallel_devices = [torch.device("hpu", torch.hpu.current_device())] * len(self.parallel_devices)
+        elif torch.device("hpu") in self._parallel_devices:
+            self._parallel_devices = [torch.device("hpu", torch.hpu.current_device())] * len(self._parallel_devices)
         self.model_to_device()
         super().setup(trainer)
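
In both strategies the fix follows the same pattern: the constructor argument is stored on a private `_parallel_devices` attribute, and `setup()` reads and writes that attribute instead of the inherited `parallel_devices` property. Below is a minimal sketch of the kind of mypy error this pattern sidesteps, assuming the base class exposes `parallel_devices` as a read-only property; the class names are illustrative stand-ins, not the actual Lightning code.

from typing import List, Optional

import torch


class BaseStrategy:
    # Hypothetical stand-in for the Lightning parallel strategy base class.
    def __init__(self, parallel_devices: Optional[List[torch.device]] = None) -> None:
        self._parallel_devices = parallel_devices

    @property
    def parallel_devices(self) -> Optional[List[torch.device]]:
        return self._parallel_devices


class HPUFSDPStrategy(BaseStrategy):
    def setup(self) -> None:
        # mypy rejects assignment through the read-only property with an error like:
        #   error: Property "parallel_devices" defined in "BaseStrategy" is read-only
        # self.parallel_devices = [torch.device("hpu")]
        # Writing to the private backing attribute type-checks cleanly, which is
        # the change this commit makes:
        self._parallel_devices = [torch.device("hpu")]

If the real base class instead defines a setter with a narrower type than `Optional[List[torch.device]]`, the reported error differs, but routing assignments through the private attribute resolves it the same way.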

