Skip to content

Commit 9950689

Browse files
committed
cleaned tests coding style
1 parent be76b82 commit 9950689

File tree

1 file changed

+21
-20
lines changed

1 file changed

+21
-20
lines changed

tests/hooks/test_group_offloading.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from diffusers.utils import get_logger
2626
from diffusers.utils.import_utils import compare_versions
2727

28-
from typing import Union
28+
from typing import Any, Iterable, List, Optional, Sequence, Union
2929

3030
from ..testing_utils import (
3131
backend_empty_cache,
@@ -157,16 +157,15 @@ class DummyCallableBySubmodule:
157157
Callable group offloading pinner that pins first and last DummyBlock
158158
called in the program by callable(submodule)
159159
"""
160-
def __init__(self, pin_targets):
160+
def __init__(self, pin_targets: Iterable[torch.nn.Module]) -> None:
161161
self.pin_targets = set(pin_targets)
162-
self.calls_track = [] #only for testing purposes
162+
self.calls_track = [] # testing only
163163

164-
def __call__(self, submodule):
164+
def __call__(self, submodule: torch.nn.Module) -> bool:
165165
self.calls_track.append(submodule)
166-
167166
return self._normalize_module_type(submodule) in self.pin_targets
168167

169-
def _normalize_module_type(self, obj):
168+
def _normalize_module_type(self, obj: Any) -> Optional[torch.nn.Module]:
170169
# group might be a single module, or a container of modules
171170
# The group-offloading code may pass either:
172171
# - a single `torch.nn.Module`, or
@@ -190,7 +189,7 @@ class DummyCallableByNameSubmodule(DummyCallableBySubmodule):
190189
Same behaviour with DummyCallableBySubmodule, only with different call signature
191190
called in the program by callable(name, submodule)
192191
"""
193-
def __call__(self, name, submodule):
192+
def __call__(self, name: str, submodule: torch.nn.Module) -> bool:
194193
self.calls_track.append((name, submodule))
195194
return self._normalize_module_type(submodule) in self.pin_targets
196195

@@ -201,7 +200,7 @@ class DummyCallableByNameSubmoduleIdx(DummyCallableBySubmodule):
201200
Same behaviour with DummyCallableBySubmodule, only with different call signature
202201
Called in the program by callable(name, submodule, idx)
203202
"""
204-
def __call__(self, name, submodule, idx):
203+
def __call__(self, name: str, submodule: torch.nn.Module, idx: int) -> bool:
205204
self.calls_track.append((name, submodule, idx))
206205
return self._normalize_module_type(submodule) in self.pin_targets
207206

@@ -210,7 +209,7 @@ class DummyInvalidCallable(DummyCallableBySubmodule):
210209
"""
211210
Callable group offloading pinner that uses invalid call signature
212211
"""
213-
def __call__(self, name, submodule, idx, extra):
212+
def __call__(self, name: str, submodule: torch.nn.Module, idx: int, extra: Any) -> bool:
214213
self.calls_track.append((name, submodule, idx, extra))
215214
return self._normalize_module_type(submodule) in self.pin_targets
216215

@@ -433,10 +432,10 @@ def test_block_level_offloading_with_pin_groups_stay_on_device(self):
433432
if torch.device(torch_device).type not in ["cuda", "xpu"]:
434433
return
435434

436-
def assert_all_modules_on_expected_device(modules,
437-
expected_device: Union[torch.device, str],
438-
header_error_msg: str = ""):
439-
def first_param_device(modules):
435+
def assert_all_modules_on_expected_device(modules: Sequence[torch.nn.Module],
436+
expected_device: Union[torch.device, str],
437+
header_error_msg: str = "") -> None:
438+
def first_param_device(modules: torch.nn.Module) -> torch.device:
440439
p = next(modules.parameters(), None)
441440
self.assertIsNotNone(p, f"No parameters found for module {modules}")
442441
return p.device
@@ -455,7 +454,7 @@ def first_param_device(modules):
455454
+ f"Expected all modules on {expected_device}, but found mismatches: {bad}",
456455
)
457456

458-
def get_param_modules_from_execution_order(model):
457+
def get_param_modules_from_execution_order(model: DummyModel) -> List[torch.nn.Module]:
459458
model.eval()
460459
root_registry = HookRegistry.check_if_exists_or_initialize(model)
461460

@@ -470,12 +469,14 @@ def get_param_modules_from_execution_order(model):
470469
param_modules = [m for m in mods if next(m.parameters(), None) is not None]
471470
return param_modules
472471

473-
def assert_callables_offloading_tests(param_modules,
474-
callables,
475-
header_error_msg: str = ""):
476-
pinned_modules = [m for m in param_modules if m in callables.pin_targets]
477-
unpinned_modules = [m for m in param_modules if m not in callables.pin_targets]
478-
self.assertTrue(len(callables.calls_track) > 0, f"{header_error_msg}: callable should have been called at least once")
472+
def assert_callables_offloading_tests(
473+
param_modules: Sequence[torch.nn.Module],
474+
callable: Any,
475+
header_error_msg: str = "",
476+
) -> None:
477+
pinned_modules = [m for m in param_modules if m in callable.pin_targets]
478+
unpinned_modules = [m for m in param_modules if m not in callable.pin_targets]
479+
self.assertTrue(len(callable.calls_track) > 0, f"{header_error_msg}: callable should have been called at least once")
479480
assert_all_modules_on_expected_device(pinned_modules, torch_device, f"{header_error_msg}: pinned blocks should stay on device")
480481
assert_all_modules_on_expected_device(unpinned_modules, "cpu", f"{header_error_msg}: unpinned blocks should be offloaded")
481482

0 commit comments

Comments
 (0)