33
33
)
34
34
35
35
from ...testing_utils import (
36
+ Expectations ,
36
37
backend_empty_cache ,
37
38
enable_full_determinism ,
38
39
floats_tensor ,
@@ -356,7 +357,7 @@ def test_marigold_depth_einstein_f32_cpu_G0_S1_P32_E1_B1_M1(self):
356
357
match_input_resolution = True ,
357
358
)
358
359
359
- def test_marigold_depth_einstein_f32_cuda_G0_S1_P768_E1_B1_M1 (self ):
360
+ def test_marigold_depth_einstein_f32_accelerator_G0_S1_P768_E1_B1_M1 (self ):
360
361
self ._test_marigold_depth (
361
362
is_fp16 = False ,
362
363
device = torch_device ,
@@ -369,7 +370,7 @@ def test_marigold_depth_einstein_f32_cuda_G0_S1_P768_E1_B1_M1(self):
369
370
match_input_resolution = True ,
370
371
)
371
372
372
- def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E1_B1_M1 (self ):
373
+ def test_marigold_depth_einstein_f16_accelerator_G0_S1_P768_E1_B1_M1 (self ):
373
374
self ._test_marigold_depth (
374
375
is_fp16 = True ,
375
376
device = torch_device ,
@@ -382,7 +383,7 @@ def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E1_B1_M1(self):
382
383
match_input_resolution = True ,
383
384
)
384
385
385
- def test_marigold_depth_einstein_f16_cuda_G2024_S1_P768_E1_B1_M1 (self ):
386
+ def test_marigold_depth_einstein_f16_accelerator_G2024_S1_P768_E1_B1_M1 (self ):
386
387
self ._test_marigold_depth (
387
388
is_fp16 = True ,
388
389
device = torch_device ,
@@ -395,20 +396,31 @@ def test_marigold_depth_einstein_f16_cuda_G2024_S1_P768_E1_B1_M1(self):
395
396
match_input_resolution = True ,
396
397
)
397
398
398
- def test_marigold_depth_einstein_f16_cuda_G0_S2_P768_E1_B1_M1 (self ):
399
+ def test_marigold_depth_einstein_f16_accelerator_G0_S2_P768_E1_B1_M1 (self ):
400
+ # fmt: off
401
+ expected_slices = Expectations (
402
+ {
403
+ ("cuda" , 7 ): np .array ([0.1085 , 0.1098 , 0.1110 , 0.1081 , 0.1085 , 0.1082 , 0.1085 , 0.1057 , 0.0996 ]),
404
+ ("xpu" , 3 ): np .array ([0.1084 , 0.1096 , 0.1108 , 0.1080 , 0.1083 , 0.1080 ,
405
+ 0.1085 , 0.1057 , 0.0996 ]),
406
+ }
407
+ )
408
+ expected_slice = expected_slices .get_expectation ()
409
+ # fmt: on
410
+
399
411
self ._test_marigold_depth (
400
412
is_fp16 = True ,
401
413
device = torch_device ,
402
414
generator_seed = 0 ,
403
- expected_slice = np . array ([ 0.1085 , 0.1098 , 0.1110 , 0.1081 , 0.1085 , 0.1082 , 0.1085 , 0.1057 , 0.0996 ]) ,
415
+ expected_slice = expected_slice ,
404
416
num_inference_steps = 2 ,
405
417
processing_resolution = 768 ,
406
418
ensemble_size = 1 ,
407
419
batch_size = 1 ,
408
420
match_input_resolution = True ,
409
421
)
410
422
411
- def test_marigold_depth_einstein_f16_cuda_G0_S1_P512_E1_B1_M1 (self ):
423
+ def test_marigold_depth_einstein_f16_accelerator_G0_S1_P512_E1_B1_M1 (self ):
412
424
self ._test_marigold_depth (
413
425
is_fp16 = True ,
414
426
device = torch_device ,
@@ -421,7 +433,7 @@ def test_marigold_depth_einstein_f16_cuda_G0_S1_P512_E1_B1_M1(self):
421
433
match_input_resolution = True ,
422
434
)
423
435
424
- def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E3_B1_M1 (self ):
436
+ def test_marigold_depth_einstein_f16_accelerator_G0_S1_P768_E3_B1_M1 (self ):
425
437
self ._test_marigold_depth (
426
438
is_fp16 = True ,
427
439
device = torch_device ,
@@ -435,7 +447,7 @@ def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E3_B1_M1(self):
435
447
match_input_resolution = True ,
436
448
)
437
449
438
- def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E4_B2_M1 (self ):
450
+ def test_marigold_depth_einstein_f16_accelerator_G0_S1_P768_E4_B2_M1 (self ):
439
451
self ._test_marigold_depth (
440
452
is_fp16 = True ,
441
453
device = torch_device ,
@@ -449,7 +461,7 @@ def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E4_B2_M1(self):
449
461
match_input_resolution = True ,
450
462
)
451
463
452
- def test_marigold_depth_einstein_f16_cuda_G0_S1_P512_E1_B1_M0 (self ):
464
+ def test_marigold_depth_einstein_f16_accelerator_G0_S1_P512_E1_B1_M0 (self ):
453
465
self ._test_marigold_depth (
454
466
is_fp16 = True ,
455
467
device = torch_device ,
0 commit comments