From 9df7c0caae3857ac5f277cb42da78c8857cd64c4 Mon Sep 17 00:00:00 2001
From: Kai-Hsun Chen
Date: Thu, 19 Sep 2024 23:26:40 +0000
Subject: [PATCH] fix test

Signed-off-by: Kai-Hsun Chen
---
 .../test_execution_schedule_gpu.py | 29 ++++++++++---------
 1 file changed, 15 insertions(+), 14 deletions(-)

diff --git a/python/ray/dag/tests/experimental/test_execution_schedule_gpu.py b/python/ray/dag/tests/experimental/test_execution_schedule_gpu.py
index 1c9122427da59..743d3afd3854c 100644
--- a/python/ray/dag/tests/experimental/test_execution_schedule_gpu.py
+++ b/python/ray/dag/tests/experimental/test_execution_schedule_gpu.py
@@ -178,23 +178,23 @@ def test_simulate_pp_2workers_2batches_1f1b(
     ):
         assert len(schedule) == len(expected_schedule)
         for i, operation in enumerate(schedule):
-            assert operation.local_idx == expected_schedule[i][0]
+            assert operation.exec_task_idx == expected_schedule[i][0]
             assert operation.type == expected_schedule[i][1]
 
     tensor_cpu = torch.zeros(10, 10)
     tensor_cuda = tensor_cpu.to("cuda:0")
-    refs = compiled_dag.execute(tensor_cpu)
+    refs = compiled_dag.execute(tensor_cuda)
 
     if single_fetch:
         assert len(refs) == 2
         for ref in refs:
             tensor = ray.get(ref)
-            assert torch.equal(tensor, tensor_cuda)
+            assert torch.equal(tensor, tensor_cpu)
     else:
         tensors = ray.get(refs)
         assert len(tensors) == 2
         for tensor in tensors:
-            assert torch.equal(tensor, tensor_cuda)
+            assert torch.equal(tensor, tensor_cpu)
 
     compiled_dag.teardown()
 
@@ -216,11 +216,12 @@ def test_simulate_pp_4workers_8batches_1f1b(ray_start_regular, monkeypatch):
     )
 
     tensor_cpu = torch.zeros(10, 10)
-    tensors = ray.get(compiled_dag.execute(tensor_cpu))
     tensor_cuda = tensor_cpu.to("cuda:0")
+    tensors = ray.get(compiled_dag.execute(tensor_cuda))
+
     assert len(tensors) == num_microbatches
     for t in tensors:
-        assert torch.equal(t, tensor_cuda)
+        assert torch.equal(t, tensor_cpu)
 
     compiled_dag.teardown()
 
@@ -277,17 +278,17 @@ def test_three_actors_with_nccl_1(ray_start_regular):
     ):
         assert len(schedule) == len(expected_schedule)
         for i, operation in enumerate(schedule):
-            assert operation.local_idx == expected_schedule[i][0]
+            assert operation.exec_task_idx == expected_schedule[i][0]
             assert operation.type == expected_schedule[i][1]
 
     tensor_cpu = torch.zeros(10, 10)
-    ref = compiled_dag.execute(tensor_cpu)
-    tensors = ray.get(ref)
     tensor_cuda = tensor_cpu.to("cuda:0")
+    ref = compiled_dag.execute(tensor_cuda)
+    tensors = ray.get(ref)
 
     assert len(tensors) == 2
     for t in tensors:
-        assert torch.equal(t, tensor_cuda)
+        assert torch.equal(t, tensor_cpu)
 
     compiled_dag.teardown()
 
@@ -356,23 +357,23 @@ def test_three_actors_with_nccl_2(ray_start_regular, single_fetch, monkeypatch):
     ):
         assert len(schedule) == len(expected_schedule)
         for i, operation in enumerate(schedule):
-            assert operation.local_idx == expected_schedule[i][0]
+            assert operation.exec_task_idx == expected_schedule[i][0]
             assert operation.type == expected_schedule[i][1]
 
     tensor_cpu = torch.zeros(10, 10)
     tensor_cuda = tensor_cpu.to("cuda:0")
-    refs = compiled_dag.execute(tensor_cpu)
+    refs = compiled_dag.execute(tensor_cuda)
 
     if single_fetch:
         assert len(refs) == 3
         for ref in refs:
             tensor = ray.get(ref)
-            assert torch.equal(tensor, tensor_cuda)
+            assert torch.equal(tensor, tensor_cpu)
     else:
         tensors = ray.get(refs)
         assert len(tensors) == 3
         for tensor in tensors:
-            assert torch.equal(tensor, tensor_cuda)
+            assert torch.equal(tensor, tensor_cpu)
 
     compiled_dag.teardown()
 