@@ -204,6 +204,15 @@ class DummyCallableByNameSubmoduleIdx(DummyCallableBySubmodule):
204204 def __call__ (self , name , submodule , idx ):
205205 self .calls_track .append ((name , submodule , idx ))
206206 return self ._normalize_module_type (submodule ) in self .pin_targets
207+
208+ # Test for https://github.com/huggingface/diffusers/pull/12747
209+ class DummyInvalidCallable (DummyCallableBySubmodule ):
210+ """
211+ Callable group offloading pinner that uses invalid call signature
212+ """
213+ def __call__ (self , name , submodule , idx , extra ):
214+ self .calls_track .append ((name , submodule , idx , extra ))
215+ return self ._normalize_module_type (submodule ) in self .pin_targets
207216
208217
209218@require_torch_accelerator
@@ -420,7 +429,7 @@ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
420429 cumulated_absmax , 1e-5 , f"Output differences for { name } exceeded threshold: { cumulated_absmax :.5f} "
421430 )
422431
423- def test_block_level_pin_groups_stay_on_device (self ):
432+ def test_block_level_offloading_with_pin_groups_stay_on_device (self ):
424433 if torch .device (torch_device ).type not in ["cuda" , "xpu" ]:
425434 return
426435
@@ -538,3 +547,39 @@ def assert_callables_offloading_tests(param_modules,
538547 callable_by_name_submodule_idx ,
539548 header_error_msg = "pin_groups with callable(name, submodule, idx)" )
540549
550+ def test_error_raised_if_pin_groups_received_invalid_value (self ):
551+ default_parameters = {
552+ "onload_device" : torch_device ,
553+ "offload_type" : "block_level" ,
554+ "num_blocks_per_group" : 1 ,
555+ "use_stream" : True ,
556+ }
557+ model = self .get_model ()
558+ with self .assertRaisesRegex (ValueError ,
559+ "`pin_groups` must be one of `None`, 'first_last', 'all', or a callable." ):
560+ model .enable_group_offload (
561+ ** default_parameters ,
562+ pin_groups = "invalid value" ,
563+ )
564+
565+ def test_error_raised_if_pin_groups_received_invalid_callables (self ):
566+ default_parameters = {
567+ "onload_device" : torch_device ,
568+ "offload_type" : "block_level" ,
569+ "num_blocks_per_group" : 1 ,
570+ "use_stream" : True ,
571+ }
572+ model = self .get_model ()
573+ invalid_callable = DummyInvalidCallable (pin_targets = [model .blocks [0 ], model .blocks [- 1 ]])
574+ model .enable_group_offload (
575+ ** default_parameters ,
576+ pin_groups = invalid_callable ,
577+ )
578+ with self .assertRaisesRegex (TypeError ,
579+ r"missing\s+\d+\s+required\s+positional\s+argument(s)?:" ):
580+ with torch .no_grad ():
581+ model (self .input )
582+
583+
584+
585+
0 commit comments