
Commit b8ace6f

amathewc authored and pytorchmergebot committed
Make dtensor tests device agnostic (#155687)
## MOTIVATION

This PR is a continuation of #154840: it makes the tests device-agnostic by removing hard-coded references to any particular device. Please refer to this RFC as well: pytorch/rfcs#66

## CHANGES

1. test_convolution_ops.py:
   - Replace "cuda" with self.device_type.
2. test_random_ops.py:
   - Remove the TYPE_DEVICE variable, since device_type is already set per environment (device) by the DTensorTestBase class.
   - Replace "cuda" with self.device_type.

Pull Request resolved: #155687
Approved by: https://github.com/EikanWang, https://github.com/d4l3k
1 parent f3ec16c commit b8ace6f
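
For orientation, here is a minimal sketch of the pattern the changes below apply, assuming a test derived from DTensorTestBase; the class and helper names in the sketch are hypothetical and not part of this PR:

```python
# Sketch only, not part of this PR: the device-agnostic pattern the diff
# switches the tests to. `self.device_type` is provided by DTensorTestBase
# and reflects the test environment ("cuda", "hpu", ...), so no backend is
# hard-coded in the test body.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.testing._internal.distributed._tensor.common_dtensor import DTensorTestBase


class DeviceAgnosticExampleTest(DTensorTestBase):  # hypothetical test class
    def _mesh_and_device(self):  # hypothetical helper
        # Build the mesh from the environment-provided device type instead of
        # passing device_type="cuda".
        mesh = init_device_mesh(
            device_type=self.device_type, mesh_shape=(self.world_size,)
        )
        # Resolve the current device index through the matching backend module
        # (torch.cuda, torch.hpu, ...) instead of calling torch.cuda directly.
        device = torch.device(
            self.device_type,
            torch.get_device_module(self.device_type).current_device(),
        )
        return mesh, device
```

Because DTensorTestBase resolves self.device_type from the test environment, the same test code can drive CUDA, HPU, or other accelerators without edits.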

File tree: 2 files changed (+16, -11 lines)


test/distributed/tensor/test_convolution_ops.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -187,7 +187,7 @@ def test_depthwise_convolution(self):
     @skip_if_lt_x_gpu(2)
     def test_conv_backward_none_grad_inp(self):
         device_mesh = init_device_mesh(
-            device_type="cuda", mesh_shape=(self.world_size,)
+            device_type=self.device_type, mesh_shape=(self.world_size,)
         )
         conv = nn.Conv2d(64, 64, 3, padding=1).train()
         x = torch.randn(1, 64, 32, 32)
```
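
As a usage note (a standalone sketch, not taken from the PR), the same init_device_mesh call works unchanged for any backend once the device type is a plain variable; device_type below is a stand-in for whatever the environment selects, and a running torch.distributed process group is assumed:

```python
# Standalone sketch, not part of this PR. Assumes torch.distributed has
# already been initialized with the desired number of ranks.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

device_type = "cuda"  # stand-in; could equally be "hpu" or "xpu" on other hardware
mesh = init_device_mesh(device_type=device_type, mesh_shape=(dist.get_world_size(),))
x = torch.randn(1, 64, 32, 32, device=device_type)
dt = distribute_tensor(x, mesh, [Shard(1)])  # shard the channel dim across ranks
```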

test/distributed/tensor/test_random_ops.py

Lines changed: 15 additions & 10 deletions
```diff
@@ -24,7 +24,7 @@
 from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
-from torch.testing._internal.common_utils import run_tests, TEST_HPU
+from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
     skip_if_lt_x_gpu,
@@ -33,9 +33,6 @@
 )
 
 
-TYPE_DEVICE = "hpu" if TEST_HPU else "cuda"
-
-
 class DistTensorRandomInitTest(DTensorTestBase):
     def _run_init_op(self, init_op, *args, **kwargs):
         device_mesh = self.build_device_mesh()
@@ -55,7 +52,7 @@ def _run_init_op(self, init_op, *args, **kwargs):
             self.assertEqual(local_tensor_clone, dtensor.to_local())
         else:
             # create DTensor from Tensor
-            _tensor = torch.empty(*input_size, device=TYPE_DEVICE)
+            _tensor = torch.empty(*input_size, device=self.device_type)
             dtensor = distribute_tensor(_tensor, device_mesh, [Shard(1)])
 
             # DTensor random init
@@ -173,7 +170,9 @@ def test_tp_model_meta_init(self):
         self.assertEqual(model.weight.device, torch.device("meta"))
 
         # actual initialization
-        device = torch.device("cuda", torch.cuda.current_device())
+        device = torch.device(
+            self.device_type, torch.get_device_module(self.device_type).current_device()
+        )
         model.to_empty(device=device)
         model.reset_parameters()
         self.assertTrue(
@@ -224,7 +223,9 @@ def test_fsdp_tp_model_meta_init(self):
         self.assertEqual(model.weight.device, torch.device("meta"))
 
         # actual initialization
-        device = torch.device("cuda", torch.cuda.current_device())
+        device = torch.device(
+            self.device_type, torch.get_device_module(self.device_type).current_device()
+        )
         model.to_empty(device=device)
         model.reset_parameters()
         self.assertTrue(
@@ -266,7 +267,9 @@ def test_rng_tracker_init(self):
         # seed synchronization now does NOT happen after the first `distribute_tensor`
         # call
         dt = distribute_tensor(
-            torch.empty([self.world_size], device=TYPE_DEVICE), device_mesh, [Shard(0)]
+            torch.empty([self.world_size], device=self.device_type),
+            device_mesh,
+            [Shard(0)],
         )
         self.assertTrue(random._rng_tracker is None)
         # seed synchronization only happens after `manual_seed` or the first DTensor
@@ -366,7 +369,7 @@ def test_deterministic_dropout_1d(self):
         size = [4, 4]
 
         dtensor = distribute_tensor(
-            torch.empty(*size, device=TYPE_DEVICE), device_mesh, [Shard(1)]
+            torch.empty(*size, device=self.device_type), device_mesh, [Shard(1)]
         )
 
         # a random op call shifts the offset
@@ -571,7 +574,9 @@ def test_hsdp_tp_model_meta_init(self):
         self.assertEqual(model.weight.device, torch.device("meta"))
 
         # actual initialization
-        device = torch.device("cuda", torch.cuda.current_device())
+        device = torch.device(
+            self.device_type, torch.get_device_module(self.device_type).current_device()
+        )
         model.to_empty(device=device)
         model.reset_parameters()
         self.assertTrue(
```
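
The meta-init hunks above all route through torch.get_device_module so that current_device() can be queried without naming torch.cuda; a small illustration of what that indirection does (a sketch, not from the PR):

```python
# Sketch, not part of this PR: torch.get_device_module maps a device type
# string to the matching backend module, so device-index queries stay
# backend-neutral.
import torch

device_type = "cuda"  # stand-in for self.device_type in the tests above
mod = torch.get_device_module(device_type)  # e.g. torch.cuda, torch.hpu, torch.xpu

# Guard so the sketch also runs on machines without that accelerator.
if mod.is_available():
    device = torch.device(device_type, mod.current_device())
    print(device)  # e.g. cuda:0
```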
