
Enabling multitenancy tests with FSDP
Signed-off-by: jyothi kumar sambolu <[email protected]>
jyothisambolu committed Sep 3, 2024
1 parent 4d54c13 commit e773c8e
Showing 3 changed files with 18 additions and 21 deletions.
17 changes: 6 additions & 11 deletions .azure/hpu-tests.yml
@@ -160,27 +160,23 @@ jobs:
         tests/test_pytorch/test_accelerator.py \
         tests/test_pytorch/test_compile.py \
         tests/test_pytorch/test_profiler.py
-    # work around to mitigate tenancy issue in G1 for cards 0,1,2,3
-    condition: or(eq(variables['HABANA_VISIBLE_MODULES'], '4,5'), eq(variables['HABANA_VISIBLE_MODULES'], '6,7'))
     displayName: 'Multi card(2) HPU test'
 
-  # - bash: |
-  #     bash tests/run_standalone_tests.sh --hpus 2 -f \
-  #       tests/test_pytorch/test_fsdp.py
-  #   env:
-  #     PT_HPU_LAZY_MODE: 0
-  #   displayName: 'FSDP PT Multi card(2) HPU test'
+  - bash: |
+      bash tests/run_standalone_tests.sh --hpus 2 -f \
+        tests/test_pytorch/test_fsdp.py
+    env:
+      PT_HPU_LAZY_MODE: 0
+    displayName: 'FSDP PT Multi card(2) HPU test'
 
   - bash: |
       bash tests/run_standalone_tests.sh --hpus 2 -f \
         tests/test_fabric/test_fsdp.py
     env:
       PT_HPU_LAZY_MODE: 0
-    condition: or(eq(variables['HABANA_VISIBLE_MODULES'], '4,5'), eq(variables['HABANA_VISIBLE_MODULES'], '6,7'))
     displayName: 'FSDP Fabric Multi card(2) HPU test'
 
   - bash: pip install ".[examples]"
-    condition: or(eq(variables['HABANA_VISIBLE_MODULES'], '4,5'), eq(variables['HABANA_VISIBLE_MODULES'], '6,7'))
     displayName: 'Install extra for examples'
 
   - bash: |
@@ -191,7 +187,6 @@ jobs:
       python pytorch/hpu_graphs.py -v dynamicity --mode dynamic_control_flow dynamic_ops
       PT_HPU_LAZY_MODE=0 python pytorch/language_model.py -s SHARD_GRAD_OP -d 2
     workingDirectory: examples/
-    condition: or(eq(variables['HABANA_VISIBLE_MODULES'], '4,5'), eq(variables['HABANA_VISIBLE_MODULES'], '6,7'))
     displayName: 'Testing HPU examples'
   - task: PublishTestResults@2
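Note (editorial, not part of the commit): the pipeline change above removes the HABANA_VISIBLE_MODULES conditions that pinned these steps to the 4,5 / 6,7 card pairs and re-enables the previously commented-out FSDP PyTorch test. The sketch below only illustrates how a job can inspect the Gaudi modules it has been given under that variable; the habana_frameworks.torch.hpu import path and its device_count() / current_device() helpers are assumed from the Habana PyTorch bridge and are not part of this diff.

# Illustrative sketch only (assumes the Habana PyTorch bridge is installed).
import os

import habana_frameworks.torch.hpu as hthpu

# HABANA_VISIBLE_MODULES limits the Gaudi modules a tenant can see, e.g. "4,5"
# or "6,7" for the two-card jobs referenced in the pipeline conditions above.
visible = os.environ.get("HABANA_VISIBLE_MODULES", "<unset: all modules visible>")
print(f"HABANA_VISIBLE_MODULES = {visible}")
print(f"HPUs visible to this process: {hthpu.device_count()}")
print(f"current HPU index: {hthpu.current_device()}")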
11 changes: 6 additions & 5 deletions src/lightning_habana/fabric/strategies/fsdp.py
@@ -105,11 +105,6 @@ def __init__(
         if not _LIGHTNING_GREATER_EQUAL_2_3_0:
             raise OSError("HPUFSDPStrategy requires `lightning>=2.3.0 or pytorch-lightning >= 2.3.0`.")
 
-        if parallel_devices is None:
-            parallel_devices = [torch.device("hpu", torch.hpu.current_device())] * HPUAccelerator.auto_device_count()
-        elif torch.device("hpu") in parallel_devices:
-            parallel_devices = [torch.device("hpu", torch.hpu.current_device())] * len(parallel_devices)
-
         super().__init__(
             accelerator=accelerator,
             parallel_devices=parallel_devices,
@@ -160,6 +155,12 @@ def precision(self, precision: Optional[HPUFSDPPrecision]) -> None:
             raise TypeError(f"The FSDP strategy can only work with the `HPUFSDPPrecision` plugin, found {precision}")
         self._precision = precision
 
+    @property
+    @override
+    def root_device(self) -> torch.device:
+        assert self.parallel_devices is not None
+        return torch.device("hpu", torch.hpu.current_device())
+
     @override
     def setup_module(self, module: Module) -> Module:
         from torch.distributed.fsdp import FullyShardedDataParallel
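Note (editorial): the two hunks above replace the eager device fix-up in __init__ with a lazy root_device property, so each worker resolves torch.hpu.current_device() only after its own device visibility (e.g. HABANA_VISIBLE_MODULES) is in effect. A minimal sketch of the difference, assuming the habana_frameworks import below is what makes the torch.hpu API available:

# Sketch only, not the library implementation.
import torch
import habana_frameworks.torch.hpu  # assumed to register torch.hpu.*

# Old behaviour (removed): the device list was stamped once, at construction
# time, with whatever torch.hpu.current_device() returned in that process.
def eager_devices(num_devices: int) -> list:
    return [torch.device("hpu", torch.hpu.current_device())] * num_devices

# New behaviour: the root device is read on demand, so every spawned worker
# reports the HPU it actually owns rather than the parent's device index.
def lazy_root_device() -> torch.device:
    return torch.device("hpu", torch.hpu.current_device())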
11 changes: 6 additions & 5 deletions src/lightning_habana/pytorch/strategies/fsdp.py
@@ -104,11 +104,6 @@ def __init__(
         if not _LIGHTNING_GREATER_EQUAL_2_3_0:
             raise OSError("HPUFSDPStrategy requires `lightning>=2.3.0 or pytorch-lightning >= 2.3.0`.")
 
-        if parallel_devices is None:
-            parallel_devices = [torch.device("hpu", torch.hpu.current_device())] * HPUAccelerator.auto_device_count()
-        elif torch.device("hpu") in parallel_devices:
-            parallel_devices = [torch.device("hpu", torch.hpu.current_device())] * len(parallel_devices)
-
         super().__init__(
             accelerator=accelerator,
             parallel_devices=parallel_devices,
@@ -153,6 +148,12 @@ def precision_plugin(self, precision_plugin: Optional[HPUFSDPPrecision]) -> None
             )
         self._precision_plugin = precision_plugin
 
+    @property
+    @override
+    def root_device(self) -> torch.device:
+        assert self.parallel_devices is not None
+        return torch.device("hpu", torch.hpu.current_device())
+
     def _setup_model(self, model: Module) -> Module:
 
         from torch.distributed.fsdp import FullyShardedDataParallel
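Note (editorial): a hedged usage sketch, not taken from the repository's tests, of how the strategy touched above is typically driven for a two-card FSDP run. HPUAccelerator, HPUFSDPStrategy, and the PT_HPU_LAZY_MODE=0 setting appear in this commit's diff; the import paths, toy model, and Trainer arguments are illustrative assumptions.

import os

import torch
from lightning.pytorch import LightningModule, Trainer

from lightning_habana.pytorch.accelerator import HPUAccelerator  # path assumed from repo layout
from lightning_habana.pytorch.strategies import HPUFSDPStrategy

# CI runs the FSDP tests in eager mode (PT_HPU_LAZY_MODE=0); mirrored here.
os.environ.setdefault("PT_HPU_LAZY_MODE", "0")

class ToyModel(LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

trainer = Trainer(
    accelerator=HPUAccelerator(),
    devices=2,                   # e.g. the "4,5" or "6,7" card pair used in CI
    strategy=HPUFSDPStrategy(),  # root_device is now resolved per process
    fast_dev_run=True,
)
# trainer.fit(ToyModel(), train_dataloaders=...)  # dataloader omitted for brevity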
