Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
Signed-off-by: Kai-Hsun Chen <[email protected]>
  • Loading branch information
kevin85421 committed Sep 19, 2024
1 parent 37bdac0 commit 9df7c0c
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions python/ray/dag/tests/experimental/test_execution_schedule_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,23 +178,23 @@ def test_simulate_pp_2workers_2batches_1f1b(
):
assert len(schedule) == len(expected_schedule)
for i, operation in enumerate(schedule):
assert operation.local_idx == expected_schedule[i][0]
assert operation.exec_task_idx == expected_schedule[i][0]
assert operation.type == expected_schedule[i][1]

tensor_cpu = torch.zeros(10, 10)
tensor_cuda = tensor_cpu.to("cuda:0")
refs = compiled_dag.execute(tensor_cpu)
refs = compiled_dag.execute(tensor_cuda)

if single_fetch:
assert len(refs) == 2
for ref in refs:
tensor = ray.get(ref)
assert torch.equal(tensor, tensor_cuda)
assert torch.equal(tensor, tensor_cpu)
else:
tensors = ray.get(refs)
assert len(tensors) == 2
for tensor in tensors:
assert torch.equal(tensor, tensor_cuda)
assert torch.equal(tensor, tensor_cpu)

compiled_dag.teardown()

Expand All @@ -216,11 +216,12 @@ def test_simulate_pp_4workers_8batches_1f1b(ray_start_regular, monkeypatch):
)

tensor_cpu = torch.zeros(10, 10)
tensors = ray.get(compiled_dag.execute(tensor_cpu))
tensor_cuda = tensor_cpu.to("cuda:0")
tensors = ray.get(compiled_dag.execute(tensor_cuda))

assert len(tensors) == num_microbatches
for t in tensors:
assert torch.equal(t, tensor_cuda)
assert torch.equal(t, tensor_cpu)
compiled_dag.teardown()


Expand Down Expand Up @@ -277,17 +278,17 @@ def test_three_actors_with_nccl_1(ray_start_regular):
):
assert len(schedule) == len(expected_schedule)
for i, operation in enumerate(schedule):
assert operation.local_idx == expected_schedule[i][0]
assert operation.exec_task_idx == expected_schedule[i][0]
assert operation.type == expected_schedule[i][1]

tensor_cpu = torch.zeros(10, 10)
ref = compiled_dag.execute(tensor_cpu)
tensors = ray.get(ref)
tensor_cuda = tensor_cpu.to("cuda:0")
ref = compiled_dag.execute(tensor_cuda)
tensors = ray.get(ref)

assert len(tensors) == 2
for t in tensors:
assert torch.equal(t, tensor_cuda)
assert torch.equal(t, tensor_cpu)

compiled_dag.teardown()

Expand Down Expand Up @@ -356,23 +357,23 @@ def test_three_actors_with_nccl_2(ray_start_regular, single_fetch, monkeypatch):
):
assert len(schedule) == len(expected_schedule)
for i, operation in enumerate(schedule):
assert operation.local_idx == expected_schedule[i][0]
assert operation.exec_task_idx == expected_schedule[i][0]
assert operation.type == expected_schedule[i][1]

tensor_cpu = torch.zeros(10, 10)
tensor_cuda = tensor_cpu.to("cuda:0")
refs = compiled_dag.execute(tensor_cpu)
refs = compiled_dag.execute(tensor_cuda)

if single_fetch:
assert len(refs) == 3
for ref in refs:
tensor = ray.get(ref)
assert torch.equal(tensor, tensor_cuda)
assert torch.equal(tensor, tensor_cpu)
else:
tensors = ray.get(refs)
assert len(tensors) == 3
for tensor in tensors:
assert torch.equal(tensor, tensor_cuda)
assert torch.equal(tensor, tensor_cpu)

compiled_dag.teardown()

Expand Down

0 comments on commit 9df7c0c

Please sign in to comment.