Merge branch 'main' into dummy
jerome-habana authored Oct 17, 2024
2 parents f1516bb + 1b754e1 commit 9874e05
Showing 8 changed files with 60 additions and 16 deletions.
11 changes: 6 additions & 5 deletions src/lightning_habana/fabric/strategies/fsdp.py
@@ -50,7 +50,6 @@
else:
raise ModuleNotFoundError("You are missing `lightning` or `pytorch-lightning` package, please install it.")

from lightning_habana.fabric.accelerator import HPUAccelerator
from lightning_habana.fabric.plugins.fsdp_precision import HPUFSDPPrecision
from lightning_habana.fabric.strategies.parallel import HPUParallelStrategy
from lightning_habana.utils.imports import _LIGHTNING_GREATER_EQUAL_2_3_0
@@ -96,10 +95,6 @@ def __init__(
) -> None:
if not _LIGHTNING_GREATER_EQUAL_2_3_0:
raise OSError("HPUFSDPStrategy requires `lightning>=2.3.0 or pytorch-lightning >= 2.3.0`.")
if parallel_devices is None:
parallel_devices = [torch.device("hpu", torch.hpu.current_device())] * HPUAccelerator.auto_device_count()
elif torch.device("hpu") in parallel_devices:
parallel_devices = [torch.device("hpu", torch.hpu.current_device())] * len(parallel_devices)
super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
@@ -141,6 +136,12 @@ def precision(self, precision: Optional[HPUFSDPPrecision]) -> None:
raise TypeError(f"The FSDP strategy can only work with the `HPUFSDPPrecision` plugin, found {precision}")
self._precision = precision

@property
@override
def root_device(self) -> torch.device:
assert self.parallel_devices is not None
return torch.device("hpu", torch.hpu.current_device())

@override
def setup_module(self, module: Module) -> Module:
from torch.distributed.fsdp import FullyShardedDataParallel
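
For illustration, a minimal sketch of the new root_device behaviour on the Fabric side, assuming a Gaudi host with the Habana PyTorch bridge (habana_frameworks) installed; the construction below is hypothetical and only shows that the device index now comes from the runtime instead of the parallel_devices auto-fill removed in this hunk:

import torch
import habana_frameworks.torch  # noqa: F401  # registers the "hpu" device type (assumed available)

from lightning_habana.fabric.strategies.fsdp import HPUFSDPStrategy

# parallel_devices must now be supplied by the caller (or the connector); the
# auto-fill that was removed from __init__ above no longer populates it.
strategy = HPUFSDPStrategy(parallel_devices=[torch.device("hpu")] * 2)

# root_device asks the runtime for the current HPU index at call time.
assert strategy.root_device == torch.device("hpu", torch.hpu.current_device())
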
5 changes: 2 additions & 3 deletions src/lightning_habana/fabric/strategies/parallel.py
@@ -91,15 +91,14 @@ def __init__(
if not HPU_AVAILABLE:
raise ValueError("`HPUParallelStrategy` requires HPU devices to run")

self._process_group_backend: Optional[str] = "hccl"
super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
checkpoint_io=checkpoint_io,
precision=precision,
)
self._process_group_backend = "hccl"
self._process_group_backend: Optional[str] = "hccl"
self._timeout = default_pg_timeout
self._num_nodes = 1
self._start_method = "spawn" if self.strategy_name == "hpu_parallel" else None
@@ -116,11 +115,11 @@ def checkpoint_io(self, io: Optional[CheckpointIO]) -> None:
self._checkpoint_io = io # type: ignore

def setup_environment(self) -> None:
self.setup_hccl_env()
super().setup_environment()
if self.strategy_name == "hpu_parallel":
# Strategies derived from this class should handle their own distributed setups.
self.setup_distributed()
self.setup_hccl_env()

def setup_hccl_env(self) -> None:
assert self._process_group_backend == "hccl"
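
A rough usage sketch for the reordered setup on the Fabric side, assuming Gaudi hardware; the device count and launch call are illustrative rather than taken from this commit. With the backend fixed to "hccl" before the parent constructor runs, setup_environment() now calls setup_hccl_env() before the base class sets up the process group:

from lightning.fabric import Fabric
from lightning_habana.fabric.accelerator import HPUAccelerator
from lightning_habana.fabric.strategies.parallel import HPUParallelStrategy

# Illustrative launch on a Gaudi host: the HCCL environment is prepared first,
# then super().setup_environment() (and, for this class, setup_distributed()) runs.
fabric = Fabric(accelerator=HPUAccelerator(), devices=2, strategy=HPUParallelStrategy())
fabric.launch()
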
12 changes: 7 additions & 5 deletions src/lightning_habana/pytorch/strategies/fsdp.py
@@ -53,7 +53,6 @@
else:
raise ModuleNotFoundError("You are missing `lightning` or `pytorch-lightning` package, please install it.")

from lightning_habana.pytorch.accelerator import HPUAccelerator
from lightning_habana.pytorch.plugins.fsdp_precision import HPUFSDPPrecision
from lightning_habana.pytorch.plugins.io_plugin import HPUCheckpointIO
from lightning_habana.pytorch.strategies.parallel import HPUParallelStrategy
@@ -101,10 +100,7 @@ def __init__(
) -> None:
if not _LIGHTNING_GREATER_EQUAL_2_3_0:
raise OSError("HPUFSDPStrategy requires `lightning>=2.3.0 or pytorch-lightning >= 2.3.0`.")
if parallel_devices is None:
parallel_devices = [torch.device("hpu", torch.hpu.current_device())] * HPUAccelerator.auto_device_count()
elif torch.device("hpu") in parallel_devices:
parallel_devices = [torch.device("hpu", torch.hpu.current_device())] * len(parallel_devices)

super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
@@ -149,6 +145,12 @@ def precision_plugin(self, precision_plugin: Optional[HPUFSDPPrecision]) -> None
)
self._precision_plugin = precision_plugin

@property
@override
def root_device(self) -> torch.device:
assert self.parallel_devices is not None
return torch.device("hpu", torch.hpu.current_device())

def _setup_model(self, model: Module) -> Module:
from torch.distributed.fsdp import FullyShardedDataParallel

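
The Trainer-side strategy mirrors the Fabric change. A short sketch, again assuming Gaudi hardware; the import paths follow the files in this diff (the Trainer import assumes the lightning package rather than pytorch_lightning) and the devices count is illustrative:

from lightning.pytorch import Trainer
from lightning_habana.pytorch.accelerator import HPUAccelerator
from lightning_habana.pytorch.strategies.fsdp import HPUFSDPStrategy

# Sketch only: with an HPUAccelerator attached, root_device resolves to
# torch.device("hpu", torch.hpu.current_device()) via the property added above.
trainer = Trainer(
    accelerator=HPUAccelerator(),
    devices=2,
    strategy=HPUFSDPStrategy(),
    fast_dev_run=True,
)
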
9 changes: 6 additions & 3 deletions src/lightning_habana/pytorch/strategies/parallel.py
@@ -63,6 +63,7 @@
else:
raise ModuleNotFoundError("You are missing `lightning` or `pytorch-lightning` package, please install it.")

from lightning_habana.pytorch.accelerator import HPUAccelerator
from lightning_habana.pytorch.plugins.io_plugin import HPUCheckpointIO
from lightning_habana.utils.hpu_distributed import _sync_hpu_processes_if_available
from lightning_habana.utils.imports import _HABANA_FRAMEWORK_AVAILABLE
@@ -114,15 +115,18 @@ def checkpoint_io(self, io: Optional[CheckpointIO]) -> None:
self._checkpoint_io = io # type: ignore

def setup_environment(self) -> None:
self.setup_hccl_env()
super().setup_environment()
if self.__class__.__name__ == "HPUParallelStrategy":
# Strategies derived from this class should handle their own distributed setups.
self.setup_distributed()
self.setup_hccl_env()

def setup_hccl_env(self) -> None:
"""Initializes the HCCL environment for distributed training on HPU devices."""
assert self._get_process_group_backend() == "hccl"
assert isinstance(
self.accelerator, HPUAccelerator
), f"{self.__class__.__name__} requires HPUAccelerator. Found {self.accelerator}"
_ws = self.cluster_environment.world_size()
_grank = self.cluster_environment.global_rank()
_lrank = self.cluster_environment.local_rank()
@@ -137,8 +141,7 @@ def setup_distributed(self) -> None:
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)

def _get_process_group_backend(self) -> str:
assert self.root_device.type == "hpu"
return "hccl"
return self._process_group_backend

def set_world_ranks(self) -> None:
if self.cluster_environment is not None:
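
The simplified backend lookup can be checked in isolation. A minimal sketch assuming an HPU host and the default "hccl" backend (as set in the fabric strategy above); the direct call to a private method is only to illustrate the change:

from lightning_habana.pytorch.strategies.parallel import HPUParallelStrategy

# _get_process_group_backend() is now a plain read of the stored backend and no
# longer asserts that root_device is already an "hpu" device.
strategy = HPUParallelStrategy()
print(strategy._get_process_group_backend())  # expected: "hccl"
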
9 changes: 9 additions & 0 deletions tests/test_pytorch/strategies/test_deepspeed.py
@@ -945,3 +945,12 @@ def test_step(self, batch, batch_idx):
if device_count == 2:
bf16_loss = torch.tensor(1.2734)
assert torch.allclose(fp8_test_loss, bf16_loss, rtol=0.03, atol=0.02)


def test_hpu_deepspeed_strategy_device_not_hpu(tmpdir):
"""Tests hpu required with HPUDeepSpeedStrategy."""
trainer = Trainer(
default_root_dir=tmpdir, accelerator="cpu", strategy=HPUDeepSpeedStrategy(), devices=1, fast_dev_run=True
)
with pytest.raises(AssertionError, match="HPUDeepSpeedStrategy requires HPUAccelerator"):
trainer.fit(BoringModel())
11 changes: 11 additions & 0 deletions tests/test_pytorch/strategies/test_fsdp.py
@@ -173,6 +173,7 @@ def test_fsdp_custom_mixed_precision():
assert strategy.mixed_precision_config == config


@pytest.mark.xfail(run=False, reason="To be fixed. Failure post 1.17 upgrade.")
@pytest.mark.skipif(HPUAccelerator.auto_device_count() <= 1, reason="Test requires multiple HPU devices")
def test_fsdp_strategy_sync_batchnorm(tmpdir, arg_hpus):
"""Test to ensure that sync_batchnorm works when using FSDP on HPU."""
@@ -262,6 +263,7 @@ def test_fsdp_simple_model_activation_cp_mixed_precision(strategy, arg_hpus):
trainer.fit(model)


@pytest.mark.xfail(run=False, reason="To be fixed. Failure post 1.17 upgrade.")
@pytest.mark.skipif(HPUAccelerator.auto_device_count() <= 1, reason="Test requires multiple HPU devices.")
@pytest.mark.standalone()
def test_fsdp_strategy_simple_model_compile(tmpdir, arg_hpus):
@@ -795,3 +797,12 @@ def training_step(self, batch, batch_idx):
)
trainer.fit(_model)
assert expected_value.item() == _model.reduced_value.item()


def test_hpu_fsdp_strategy_device_not_hpu(tmpdir):
"""Tests hpu required with HPUDeepSpeedStrategy."""
trainer = Trainer(
default_root_dir=tmpdir, accelerator="cpu", strategy=HPUFSDPStrategy(), devices=1, fast_dev_run=True
)
with pytest.raises(AssertionError, match="HPUFSDPStrategy requires HPUAccelerator"):
trainer.fit(BoringModel())
9 changes: 9 additions & 0 deletions tests/test_pytorch/strategies/test_hpu_ddp.py
@@ -62,6 +62,15 @@ def test_hpu_ddp_strategy_init():
assert strategy._ddp_kwargs["find_unused_parameters"] == find_unused_parameters


def test_hpu_ddp_strategy_device_not_hpu(tmpdir):
"""Tests hpu required with HPUDDPStrategy."""
trainer = Trainer(
default_root_dir=tmpdir, accelerator="cpu", strategy=HPUDDPStrategy(), devices=1, fast_dev_run=True
)
with pytest.raises(AssertionError, match="HPUDDPStrategy requires HPUAccelerator"):
trainer.fit(BoringModel())


def test_hpu_ddp_custom_strategy_registry():
"""Test custom parallel strategy registry."""

10 changes: 10 additions & 0 deletions tests/test_pytorch/strategies/test_hpu_parallel.py
@@ -18,6 +18,7 @@
import torch
import torch.distributed
from lightning_utilities import module_available
from torch.multiprocessing.spawn import ProcessRaisedException

if module_available("lightning"):
from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
@@ -55,6 +56,15 @@ def test_hpu_parallel_strategy_init():
assert strategy._num_nodes == 1


def test_hpu_parallel_strategy_device_not_hpu(tmpdir):
"""Tests hpu required with HPUParallelStrategy."""
trainer = Trainer(
default_root_dir=tmpdir, accelerator="cpu", strategy=HPUParallelStrategy(), devices=1, fast_dev_run=True
)
with pytest.raises(ProcessRaisedException, match="HPUParallelStrategy requires HPUAccelerator"):
trainer.fit(BoringModel())


def test_hpu_parallel_parallel_devices():
"""Test parallel_devices set."""
devices = [torch.device("hpu")] * 2
