 from diffusers.utils import get_logger
 from diffusers.utils.import_utils import compare_versions

-from typing import Union
+from typing import Any, Iterable, List, Optional, Sequence, Union

 from ..testing_utils import (
     backend_empty_cache,
@@ -157,16 +157,15 @@ class DummyCallableBySubmodule:
     Callable group offloading pinner that pins the first and last DummyBlock,
     called in the program as callable(submodule)
     """
-    def __init__(self, pin_targets):
+    def __init__(self, pin_targets: Iterable[torch.nn.Module]) -> None:
         self.pin_targets = set(pin_targets)
-        self.calls_track = []  # only for testing purposes
+        self.calls_track = []  # testing only

-    def __call__(self, submodule):
+    def __call__(self, submodule: torch.nn.Module) -> bool:
         self.calls_track.append(submodule)
-
         return self._normalize_module_type(submodule) in self.pin_targets

-    def _normalize_module_type(self, obj):
+    def _normalize_module_type(self, obj: Any) -> Optional[torch.nn.Module]:
         # group might be a single module, or a container of modules
         # The group-offloading code may pass either:
         # - a single `torch.nn.Module`, or
@@ -190,7 +189,7 @@ class DummyCallableByNameSubmodule(DummyCallableBySubmodule):
     Same behaviour as DummyCallableBySubmodule, only with a different call signature,
     called in the program as callable(name, submodule)
     """
-    def __call__(self, name, submodule):
+    def __call__(self, name: str, submodule: torch.nn.Module) -> bool:
         self.calls_track.append((name, submodule))
         return self._normalize_module_type(submodule) in self.pin_targets

@@ -201,7 +200,7 @@ class DummyCallableByNameSubmoduleIdx(DummyCallableBySubmodule):
     Same behaviour as DummyCallableBySubmodule, only with a different call signature,
     called in the program as callable(name, submodule, idx)
     """
-    def __call__(self, name, submodule, idx):
+    def __call__(self, name: str, submodule: torch.nn.Module, idx: int) -> bool:
         self.calls_track.append((name, submodule, idx))
         return self._normalize_module_type(submodule) in self.pin_targets

@@ -210,7 +209,7 @@ class DummyInvalidCallable(DummyCallableBySubmodule):
210209 """
211210 Callable group offloading pinner that uses invalid call signature
212211 """
-    def __call__(self, name, submodule, idx, extra):
+    def __call__(self, name: str, submodule: torch.nn.Module, idx: int, extra: Any) -> bool:
         self.calls_track.append((name, submodule, idx, extra))
         return self._normalize_module_type(submodule) in self.pin_targets

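Taken together, the three valid-signature dummies plus the invalid one suggest that the group-offloading code dispatches on the arity of the user-supplied pin callable. The snippet below is a minimal, standalone sketch of what such arity-based dispatch could look like; the helper name `_call_pin_callable` and the exact dispatch rules are assumptions for illustration, not code from this diff.

# Illustrative sketch only -- not the diffusers implementation. The helper name
# and the exact dispatch rules are assumptions made for this example.
import inspect

import torch


def _call_pin_callable(fn, name: str, submodule: torch.nn.Module, idx: int) -> bool:
    # Count the positional parameters the callable accepts (for an instance that
    # defines __call__, inspect.signature already excludes `self`).
    positional = [
        p
        for p in inspect.signature(fn).parameters.values()
        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
    ]
    if len(positional) == 1:
        return bool(fn(submodule))             # callable(submodule)
    if len(positional) == 2:
        return bool(fn(name, submodule))       # callable(name, submodule)
    if len(positional) == 3:
        return bool(fn(name, submodule, idx))  # callable(name, submodule, idx)
    # Anything else (e.g. a 4-argument signature like DummyInvalidCallable) is rejected.
    raise ValueError(f"Unsupported pin callable signature: {inspect.signature(fn)}")
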
@@ -433,10 +432,10 @@ def test_block_level_offloading_with_pin_groups_stay_on_device(self):
         if torch.device(torch_device).type not in ["cuda", "xpu"]:
             return

-        def assert_all_modules_on_expected_device(modules,
-                                                  expected_device: Union[torch.device, str],
-                                                  header_error_msg: str = ""):
-            def first_param_device(modules):
+        def assert_all_modules_on_expected_device(modules: Sequence[torch.nn.Module],
+                                                  expected_device: Union[torch.device, str],
+                                                  header_error_msg: str = "") -> None:
+            def first_param_device(modules: torch.nn.Module) -> torch.device:
                 p = next(modules.parameters(), None)
                 self.assertIsNotNone(p, f"No parameters found for module {modules}")
                 return p.device
@@ -455,7 +454,7 @@ def first_param_device(modules):
                 + f"Expected all modules on {expected_device}, but found mismatches: {bad}",
             )

-        def get_param_modules_from_execution_order(model):
+        def get_param_modules_from_execution_order(model: DummyModel) -> List[torch.nn.Module]:
             model.eval()
             root_registry = HookRegistry.check_if_exists_or_initialize(model)

@@ -470,12 +469,14 @@ def get_param_modules_from_execution_order(model):
             param_modules = [m for m in mods if next(m.parameters(), None) is not None]
             return param_modules

-        def assert_callables_offloading_tests(param_modules,
-                                              callables,
-                                              header_error_msg: str = ""):
-            pinned_modules = [m for m in param_modules if m in callables.pin_targets]
-            unpinned_modules = [m for m in param_modules if m not in callables.pin_targets]
-            self.assertTrue(len(callables.calls_track) > 0, f"{header_error_msg}: callable should have been called at least once")
+        def assert_callables_offloading_tests(
+            param_modules: Sequence[torch.nn.Module],
+            callable: Any,
+            header_error_msg: str = "",
+        ) -> None:
+            pinned_modules = [m for m in param_modules if m in callable.pin_targets]
+            unpinned_modules = [m for m in param_modules if m not in callable.pin_targets]
+            self.assertTrue(len(callable.calls_track) > 0, f"{header_error_msg}: callable should have been called at least once")
             assert_all_modules_on_expected_device(pinned_modules, torch_device, f"{header_error_msg}: pinned blocks should stay on device")
             assert_all_modules_on_expected_device(unpinned_modules, "cpu", f"{header_error_msg}: unpinned blocks should be offloaded")

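For reference, this is how one of the dummy pinners behaves in isolation. It is a hedged sketch: it assumes DummyCallableBySubmodule from this file is in scope, assumes _normalize_module_type returns a plain torch.nn.Module unchanged (its body lies outside this diff), and uses Linear layers as stand-ins for the DummyBlock instances the real test pins.

import torch

first, middle, last = (torch.nn.Linear(4, 4) for _ in range(3))
pinner = DummyCallableBySubmodule(pin_targets=[first, last])

assert pinner(first)                 # pinned -> expected to stay on the accelerator
assert not pinner(middle)            # not pinned -> eligible for offload to CPU
assert len(pinner.calls_track) == 2  # every query is recorded for the test's assertions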