Skip to content

Commit 08c2902

Browse files
authored
fix marigold ut case fail on xpu (#12350)
Signed-off-by: Yao, Matrix <[email protected]>
1 parent 7a58734 commit 08c2902

File tree

1 file changed

+21
-9
lines changed

1 file changed

+21
-9
lines changed

tests/pipelines/marigold/test_marigold_depth.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
)
3434

3535
from ...testing_utils import (
36+
Expectations,
3637
backend_empty_cache,
3738
enable_full_determinism,
3839
floats_tensor,
@@ -356,7 +357,7 @@ def test_marigold_depth_einstein_f32_cpu_G0_S1_P32_E1_B1_M1(self):
356357
match_input_resolution=True,
357358
)
358359

359-
def test_marigold_depth_einstein_f32_cuda_G0_S1_P768_E1_B1_M1(self):
360+
def test_marigold_depth_einstein_f32_accelerator_G0_S1_P768_E1_B1_M1(self):
360361
self._test_marigold_depth(
361362
is_fp16=False,
362363
device=torch_device,
@@ -369,7 +370,7 @@ def test_marigold_depth_einstein_f32_cuda_G0_S1_P768_E1_B1_M1(self):
369370
match_input_resolution=True,
370371
)
371372

372-
def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E1_B1_M1(self):
373+
def test_marigold_depth_einstein_f16_accelerator_G0_S1_P768_E1_B1_M1(self):
373374
self._test_marigold_depth(
374375
is_fp16=True,
375376
device=torch_device,
@@ -382,7 +383,7 @@ def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E1_B1_M1(self):
382383
match_input_resolution=True,
383384
)
384385

385-
def test_marigold_depth_einstein_f16_cuda_G2024_S1_P768_E1_B1_M1(self):
386+
def test_marigold_depth_einstein_f16_accelerator_G2024_S1_P768_E1_B1_M1(self):
386387
self._test_marigold_depth(
387388
is_fp16=True,
388389
device=torch_device,
@@ -395,20 +396,31 @@ def test_marigold_depth_einstein_f16_cuda_G2024_S1_P768_E1_B1_M1(self):
395396
match_input_resolution=True,
396397
)
397398

398-
def test_marigold_depth_einstein_f16_cuda_G0_S2_P768_E1_B1_M1(self):
399+
def test_marigold_depth_einstein_f16_accelerator_G0_S2_P768_E1_B1_M1(self):
400+
# fmt: off
401+
expected_slices = Expectations(
402+
{
403+
("cuda", 7): np.array([0.1085, 0.1098, 0.1110, 0.1081, 0.1085, 0.1082, 0.1085, 0.1057, 0.0996]),
404+
("xpu", 3): np.array([0.1084, 0.1096, 0.1108, 0.1080, 0.1083, 0.1080,
405+
0.1085, 0.1057, 0.0996]),
406+
}
407+
)
408+
expected_slice = expected_slices.get_expectation()
409+
# fmt: on
410+
399411
self._test_marigold_depth(
400412
is_fp16=True,
401413
device=torch_device,
402414
generator_seed=0,
403-
expected_slice=np.array([0.1085, 0.1098, 0.1110, 0.1081, 0.1085, 0.1082, 0.1085, 0.1057, 0.0996]),
415+
expected_slice=expected_slice,
404416
num_inference_steps=2,
405417
processing_resolution=768,
406418
ensemble_size=1,
407419
batch_size=1,
408420
match_input_resolution=True,
409421
)
410422

411-
def test_marigold_depth_einstein_f16_cuda_G0_S1_P512_E1_B1_M1(self):
423+
def test_marigold_depth_einstein_f16_accelerator_G0_S1_P512_E1_B1_M1(self):
412424
self._test_marigold_depth(
413425
is_fp16=True,
414426
device=torch_device,
@@ -421,7 +433,7 @@ def test_marigold_depth_einstein_f16_cuda_G0_S1_P512_E1_B1_M1(self):
421433
match_input_resolution=True,
422434
)
423435

424-
def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E3_B1_M1(self):
436+
def test_marigold_depth_einstein_f16_accelerator_G0_S1_P768_E3_B1_M1(self):
425437
self._test_marigold_depth(
426438
is_fp16=True,
427439
device=torch_device,
@@ -435,7 +447,7 @@ def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E3_B1_M1(self):
435447
match_input_resolution=True,
436448
)
437449

438-
def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E4_B2_M1(self):
450+
def test_marigold_depth_einstein_f16_accelerator_G0_S1_P768_E4_B2_M1(self):
439451
self._test_marigold_depth(
440452
is_fp16=True,
441453
device=torch_device,
@@ -449,7 +461,7 @@ def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E4_B2_M1(self):
449461
match_input_resolution=True,
450462
)
451463

452-
def test_marigold_depth_einstein_f16_cuda_G0_S1_P512_E1_B1_M0(self):
464+
def test_marigold_depth_einstein_f16_accelerator_G0_S1_P512_E1_B1_M0(self):
453465
self._test_marigold_depth(
454466
is_fp16=True,
455467
device=torch_device,

0 commit comments

Comments
 (0)