2525from diffusers .utils import get_logger
2626from diffusers .utils .import_utils import compare_versions
2727
28+ from typing import Union
29+
2830from ..testing_utils import (
2931 backend_empty_cache ,
3032 backend_max_memory_allocated ,
@@ -147,6 +149,61 @@ def __init__(self):
147149 def post_forward (self , module , output ):
148150 self .outputs .append (output )
149151 return output
152+
153+
154+ # Test for https://github.com/huggingface/diffusers/pull/12747
class DummyCallableBySubmodule:
    """
    Callable ``pin_groups`` predicate that pins the first and last DummyBlock.

    Invoked by the group-offloading code as ``callable(submodule)``; returns
    ``True`` when the module unambiguously extracted from *submodule* is one
    of the configured pin targets.
    """

    def __init__(self, pin_targets):
        # nn.Module hashes by identity, so membership pins exactly the
        # module objects handed in here.
        self.pin_targets = set(pin_targets)
        self.calls_track = []  # records every invocation; for assertions only

    def __call__(self, submodule):
        self.calls_track.append(submodule)
        return self._normalize_module_type(submodule) in self.pin_targets

    def _normalize_module_type(self, obj):
        # The group-offloading code may pass either a bare ``torch.nn.Module``
        # or a list/tuple of modules.  Resolve to a module only when the
        # mapping is unambiguous:
        #   - a module                              -> itself
        #   - a list/tuple with exactly one module  -> that module
        #   - anything else                         -> None (never a target)
        if isinstance(obj, torch.nn.Module):
            return obj
        if not isinstance(obj, (list, tuple)):
            return None
        modules = [item for item in obj if isinstance(item, torch.nn.Module)]
        if len(modules) == 1:
            return modules[0]
        return None
185+
186+ # Test for https://github.com/huggingface/diffusers/pull/12747
class DummyCallableByNameSubmodule(DummyCallableBySubmodule):
    """
    Same pinning behaviour as ``DummyCallableBySubmodule``, but exercising the
    two-argument call signature used by the group-offloading code:
    ``callable(name, submodule)``.
    """

    def __call__(self, name, submodule):
        # Track the full (name, submodule) pair so tests can inspect exactly
        # what the offloading code passed in.
        self.calls_track.append((name, submodule))
        return self._normalize_module_type(submodule) in self.pin_targets
196+
197+ # Test for https://github.com/huggingface/diffusers/pull/12747
class DummyCallableByNameSubmoduleIdx(DummyCallableBySubmodule):
    """
    Same pinning behaviour as ``DummyCallableBySubmodule``, but exercising the
    three-argument call signature used by the group-offloading code:
    ``callable(name, submodule, idx)``.
    """

    def __call__(self, name, submodule, idx):
        # Track the full (name, submodule, idx) triple so tests can inspect
        # exactly what the offloading code passed in.
        self.calls_track.append((name, submodule, idx))
        return self._normalize_module_type(submodule) in self.pin_targets
150207
151208
152209@require_torch_accelerator
@@ -367,24 +424,30 @@ def test_block_level_pin_groups_stay_on_device(self):
367424 if torch .device (torch_device ).type not in ["cuda" , "xpu" ]:
368425 return
369426
370- def first_param_device (mod ):
371- p = next (mod .parameters (), None )
372- self .assertIsNotNone (p , f"No parameters found for module { mod } " )
373- return p .device
374-
375- def assert_all_modules_device (mods , expected_type : str , msg : str = "" ):
427+ def assert_all_modules_on_expected_device (modules ,
428+ expected_device : Union [torch .device , str ],
429+ header_error_msg : str = "" ):
430+ def first_param_device (modules ):
431+ p = next (modules .parameters (), None )
432+ self .assertIsNotNone (p , f"No parameters found for module { modules } " )
433+ return p .device
434+
435+ if isinstance (expected_device , torch .device ):
436+ expected_device = expected_device .type
437+
376438 bad = []
377- for i , m in enumerate (mods ):
439+ for i , m in enumerate (modules ):
378440 dev_type = first_param_device (m ).type
379- if dev_type != expected_type :
441+ if dev_type != expected_device :
380442 bad .append ((i , m .__class__ .__name__ , dev_type ))
381- self .assertFalse (
382- bad ,
383- (msg + "\n " if msg else "" )
384- + f"Expected all modules on { expected_type } , but found mismatches: { bad } " ,
443+ self .assertTrue (
444+ len ( bad ) == 0 ,
445+ (header_error_msg + "\n " if header_error_msg else "" )
446+ + f"Expected all modules on { expected_device } , but found mismatches: { bad } " ,
385447 )
386448
387- def get_param_modules_from_exec_order (model ):
449+ def get_param_modules_from_execution_order (model ):
450+ model .eval ()
388451 root_registry = HookRegistry .check_if_exists_or_initialize (model )
389452
390453 lazy_hook = root_registry .get_hook ("lazy_prefetch_group_offloading" )
@@ -395,51 +458,83 @@ def get_param_modules_from_exec_order(model):
395458 model (self .input )
396459
397460 mods = [m for _ , m in lazy_hook .execution_order ]
398- param_mods = [m for m in mods if next (m .parameters (), None ) is not None ]
399- self .assertGreaterEqual (
400- len (param_mods ), 2 , f"Expected >=2 param-bearing modules in execution_order, got { len (param_mods )} "
401- )
402-
403- first = param_mods [0 ]
404- last = param_mods [- 1 ]
405- middle_layers = param_mods [1 :- 1 ]
406- return first , middle_layers , last
461+ param_modules = [m for m in mods if next (m .parameters (), None ) is not None ]
462+ return param_modules
463+
464+ def assert_callables_offloading_tests (param_modules ,
465+ callables ,
466+ header_error_msg : str = "" ):
467+ pinned_modules = [m for m in param_modules if m in callables .pin_targets ]
468+ unpinned_modules = [m for m in param_modules if m not in callables .pin_targets ]
469+ self .assertTrue (len (callables .calls_track ) > 0 , f"{ header_error_msg } : callable should have been called at least once" )
470+ assert_all_modules_on_expected_device (pinned_modules , torch_device , f"{ header_error_msg } : pinned blocks should stay on device" )
471+ assert_all_modules_on_expected_device (unpinned_modules , "cpu" , f"{ header_error_msg } : unpinned blocks should be offloaded" )
472+
473+
474+ default_parameters = {
475+ "onload_device" : torch_device ,
476+ "offload_type" : "block_level" ,
477+ "num_blocks_per_group" : 1 ,
478+ "use_stream" : True ,
479+ }
480+ model_default_no_pin = self .get_model ()
481+ model_default_no_pin .enable_group_offload (
482+ ** default_parameters
483+ )
484+ param_modules = get_param_modules_from_execution_order (model_default_no_pin )
485+ assert_all_modules_on_expected_device (param_modules ,
486+ expected_device = "cpu" ,
487+ header_error_msg = "default pin_groups: expected ALL modules offloaded to CPU" )
488+
489+ model_pin_all = self .get_model ()
490+ model_pin_all .enable_group_offload (
491+ ** default_parameters ,
492+ pin_groups = "all" ,
493+ )
494+ param_modules = get_param_modules_from_execution_order (model_pin_all )
495+ assert_all_modules_on_expected_device (param_modules ,
496+ expected_device = torch_device ,
497+ header_error_msg = "pin_groups = all: expected ALL layers on accelerator device" )
407498
408- accel_type = torch .device (torch_device ).type
409499
410- model_no_pin = self .get_model ()
411- model_no_pin .enable_group_offload (
412- torch_device ,
413- offload_type = "block_level" ,
414- num_blocks_per_group = 1 ,
415- use_stream = True ,
416- )
417- model_no_pin .eval ()
418- first , middle , last = get_param_modules_from_exec_order (model_no_pin )
419-
420- self .assertEqual (first_param_device (first ).type , "cpu" )
421- self .assertEqual (first_param_device (last ).type , "cpu" )
422- assert_all_modules_device (middle , "cpu" , msg = "No-pin: expected ALL middle layers on CPU" )
423-
424- model_pin = self .get_model ()
425- model_pin .enable_group_offload (
426- torch_device ,
427- offload_type = "block_level" ,
428- num_blocks_per_group = 1 ,
429- use_stream = True ,
500+ model_pin_first_last = self .get_model ()
501+ model_pin_first_last .enable_group_offload (
502+ ** default_parameters ,
430503 pin_groups = "first_last" ,
431504 )
432- model_pin .eval ()
433- first , middle , last = get_param_modules_from_exec_order (model_pin )
434-
435- self .assertEqual (first_param_device (first ).type , accel_type )
436- self .assertEqual (first_param_device (last ).type , accel_type )
437- assert_all_modules_device (middle , "cpu" , msg = "Pin: expected ALL middle layers on CPU" )
505+ param_modules = get_param_modules_from_execution_order (model_pin_first_last )
506+ assert_all_modules_on_expected_device ([param_modules [0 ], param_modules [- 1 ]],
507+ expected_device = torch_device ,
508+ header_error_msg = "pin_groups = first_last: expected first and last layers on accelerator device" )
509+ assert_all_modules_on_expected_device (param_modules [1 :- 1 ],
510+ expected_device = "cpu" ,
511+ header_error_msg = "pin_groups = first_last: expected ALL middle layers offloaded to CPU" )
512+
513+
514+ model = self .get_model ()
515+ callable_by_submodule = DummyCallableBySubmodule (pin_targets = [model .blocks [0 ], model .blocks [- 1 ]])
516+ model .enable_group_offload (** default_parameters ,
517+ pin_groups = callable_by_submodule )
518+ param_modules = get_param_modules_from_execution_order (model )
519+ assert_callables_offloading_tests (param_modules ,
520+ callable_by_submodule ,
521+ header_error_msg = "pin_groups with callable(submodule)" )
438522
439- # Should still hold after another invocation
440- with torch .no_grad ():
441- model_pin (self .input )
523+ model = self .get_model ()
524+ callable_by_name_submodule = DummyCallableByNameSubmodule (pin_targets = [model .blocks [0 ], model .blocks [- 1 ]])
525+ model .enable_group_offload (** default_parameters ,
526+ pin_groups = callable_by_name_submodule )
527+ param_modules = get_param_modules_from_execution_order (model )
528+ assert_callables_offloading_tests (param_modules ,
529+ callable_by_name_submodule ,
530+ header_error_msg = "pin_groups with callable(name, submodule)" )
442531
443- self .assertEqual (first_param_device (first ).type , accel_type )
444- self .assertEqual (first_param_device (last ).type , accel_type )
445- assert_all_modules_device (middle , "cpu" , msg = "Pin (2nd forward): expected ALL middle layers on CPU" )
532+ model = self .get_model ()
533+ callable_by_name_submodule_idx = DummyCallableByNameSubmoduleIdx (pin_targets = [model .blocks [0 ], model .blocks [- 1 ]])
534+ model .enable_group_offload (** default_parameters ,
535+ pin_groups = callable_by_name_submodule_idx )
536+ param_modules = get_param_modules_from_execution_order (model )
537+ assert_callables_offloading_tests (param_modules ,
538+ callable_by_name_submodule_idx ,
539+ header_error_msg = "pin_groups with callable(name, submodule, idx)" )
540+
0 commit comments