Commit 6a28cc8

amathewc authored and pytorchmergebot committed
Add TEST_HPU flag to set device type (#153461)
MOTIVATION

This PR includes a minor change to check for the TEST_HPU flag as well before falling back to CPU. Without this flag, some tests were falling back to CPU, causing them to fail. Please refer to this RFC as well: pytorch/rfcs#66

CHANGES

- Add the TEST_HPU flag to some of the conditions checking the environment.
- Use the DEVICE_COUNT variable instead of the torch.accelerator.device_count() API, since the latter is not supported on out-of-tree devices like Intel Gaudi.

@ankurneog, @EikanWang, @cyyever, @guangyey

Pull Request resolved: #153461
Approved by: https://github.com/EikanWang, https://github.com/cyyever, https://github.com/albanD
1 parent a54bf43 commit 6a28cc8

File tree

1 file changed: +3 -6 lines changed


torch/testing/_internal/distributed/_tensor/common_dtensor.py

Lines changed: 3 additions & 6 deletions
@@ -58,7 +58,7 @@
 NUM_DEVICES = 4

 # We use this as a proxy for "multiple GPUs exist"
-if (TEST_CUDA or TEST_XPU) and DEVICE_COUNT > 1:
+if (TEST_CUDA or TEST_XPU or TEST_HPU) and DEVICE_COUNT > 1:
     # when we actually have multiple GPUs, relax the requirement to smaller counts.
     NUM_DEVICES = min(NUM_DEVICES, DEVICE_COUNT)

@@ -339,11 +339,8 @@ def world_size(self) -> int:

     @property
     def device_type(self) -> str:
-        # if enough GPU we can use GPU, otherwise we fallback to CPU
-        if (
-            not (TEST_CUDA or TEST_XPU)
-            or torch.accelerator.device_count() < self.world_size
-        ):
+        # if enough GPU/XPU/HPU we can use those devices, otherwise we fallback to CPU
+        if not (TEST_CUDA or TEST_XPU or TEST_HPU) or DEVICE_COUNT < self.world_size:
             return "cpu"
         else:
             return DEVICE_TYPE
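The fallback rule in the patched `device_type` property can be illustrated with a standalone sketch. The flag and count values below are stand-ins chosen for illustration; in PyTorch they come from the `torch.testing._internal` helpers referenced in the diff (TEST_CUDA, TEST_XPU, TEST_HPU, DEVICE_COUNT, DEVICE_TYPE).

```python
# Stand-in environment values, assumed for this sketch: pretend we are on
# a machine with 8 HPUs (e.g. Intel Gaudi) and no CUDA/XPU devices.
TEST_CUDA = False
TEST_XPU = False
TEST_HPU = True
DEVICE_COUNT = 8
DEVICE_TYPE = "hpu"

def device_type(world_size: int) -> str:
    """Use the accelerator when enough devices exist, else fall back to CPU.

    Mirrors the patched condition: the accelerator is chosen only if at
    least one device-type flag is set AND there are enough devices for
    the requested world size.
    """
    if not (TEST_CUDA or TEST_XPU or TEST_HPU) or DEVICE_COUNT < world_size:
        return "cpu"
    return DEVICE_TYPE

print(device_type(4))   # 8 HPUs cover a world size of 4 -> "hpu"
print(device_type(16))  # only 8 devices for a world size of 16 -> "cpu"
```

Note the design point of the fix: keying the check on a plain DEVICE_COUNT variable rather than calling `torch.accelerator.device_count()` keeps the logic usable on out-of-tree backends where that API is not available.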
