Commit 536d8e6

added callable testing with three valid formats
1 parent 7a66cda commit 536d8e6

File tree

1 file changed (+150 −55 lines)

tests/hooks/test_group_offloading.py

Lines changed: 150 additions & 55 deletions
@@ -25,6 +25,8 @@
 from diffusers.utils import get_logger
 from diffusers.utils.import_utils import compare_versions
 
+from typing import Union
+
 from ..testing_utils import (
     backend_empty_cache,
     backend_max_memory_allocated,
@@ -147,6 +149,61 @@ def __init__(self):
     def post_forward(self, module, output):
         self.outputs.append(output)
         return output
+
+
+# Test for https://github.com/huggingface/diffusers/pull/12747
+class DummyCallableBySubmodule:
+    """
+    Callable group-offloading pinner that pins the first and last DummyBlock.
+    Invoked by the group-offloading code as callable(submodule).
+    """
+
+    def __init__(self, pin_targets):
+        self.pin_targets = set(pin_targets)
+        self.calls_track = []  # only for testing purposes
+
+    def __call__(self, submodule):
+        self.calls_track.append(submodule)
+        return self._normalize_module_type(submodule) in self.pin_targets
+
+    def _normalize_module_type(self, obj):
+        # The group-offloading code may pass either:
+        # - a single `torch.nn.Module`, or
+        # - a container (list/tuple) of modules.
+        # Only return a module when the mapping is unambiguous:
+        # - if `obj` is a module -> return it
+        # - if `obj` is a list/tuple containing exactly one module -> return that module
+        # - otherwise -> return None (not considered a target candidate)
+        if isinstance(obj, torch.nn.Module):
+            return obj
+        if isinstance(obj, (list, tuple)):
+            mods = [m for m in obj if isinstance(m, torch.nn.Module)]
+            return mods[0] if len(mods) == 1 else None
+        return None
+
+
+# Test for https://github.com/huggingface/diffusers/pull/12747
+class DummyCallableByNameSubmodule(DummyCallableBySubmodule):
+    """
+    Callable group-offloading pinner that pins the first and last DummyBlock.
+    Same behaviour as DummyCallableBySubmodule, but with a different signature:
+    invoked by the group-offloading code as callable(name, submodule).
+    """
+
+    def __call__(self, name, submodule):
+        self.calls_track.append((name, submodule))
+        return self._normalize_module_type(submodule) in self.pin_targets
+
+
+# Test for https://github.com/huggingface/diffusers/pull/12747
+class DummyCallableByNameSubmoduleIdx(DummyCallableBySubmodule):
+    """
+    Callable group-offloading pinner that pins the first and last DummyBlock.
+    Same behaviour as DummyCallableBySubmodule, but with a different signature:
+    invoked by the group-offloading code as callable(name, submodule, idx).
+    """
+
+    def __call__(self, name, submodule, idx):
+        self.calls_track.append((name, submodule, idx))
+        return self._normalize_module_type(submodule) in self.pin_targets
 
 
 @require_torch_accelerator
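
The three dummy pinners above exercise the three callable formats that pin_groups accepts per this commit. As a sketch only, a user-side pinner under the callable(name, submodule) contract could look like the following; the block names are hypothetical, and returning True keeps the group on the onload device:

import torch

# Hypothetical user-side pinner, assuming the callable(name, submodule)
# contract exercised by DummyCallableByNameSubmodule above. Returning True
# keeps the group on the onload device; False lets it be offloaded.
def pin_first_and_last_block(name: str, submodule: torch.nn.Module) -> bool:
    return name in {"blocks.0", "blocks.11"}  # hypothetical block names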
@@ -367,24 +424,30 @@ def test_block_level_pin_groups_stay_on_device(self):
         if torch.device(torch_device).type not in ["cuda", "xpu"]:
             return
 
-        def first_param_device(mod):
-            p = next(mod.parameters(), None)
-            self.assertIsNotNone(p, f"No parameters found for module {mod}")
-            return p.device
-
-        def assert_all_modules_device(mods, expected_type: str, msg: str = ""):
+        def assert_all_modules_on_expected_device(
+            modules, expected_device: Union[torch.device, str], header_error_msg: str = ""
+        ):
+            def first_param_device(module):
+                p = next(module.parameters(), None)
+                self.assertIsNotNone(p, f"No parameters found for module {module}")
+                return p.device
+
+            if isinstance(expected_device, torch.device):
+                expected_device = expected_device.type
+
             bad = []
-            for i, m in enumerate(mods):
+            for i, m in enumerate(modules):
                 dev_type = first_param_device(m).type
-                if dev_type != expected_type:
+                if dev_type != expected_device:
                     bad.append((i, m.__class__.__name__, dev_type))
-            self.assertFalse(
-                bad,
-                (msg + "\n" if msg else "")
-                + f"Expected all modules on {expected_type}, but found mismatches: {bad}",
+            self.assertTrue(
+                len(bad) == 0,
+                (header_error_msg + "\n" if header_error_msg else "")
+                + f"Expected all modules on {expected_device}, but found mismatches: {bad}",
             )
 
-        def get_param_modules_from_exec_order(model):
+        def get_param_modules_from_execution_order(model):
+            model.eval()
             root_registry = HookRegistry.check_if_exists_or_initialize(model)
 
             lazy_hook = root_registry.get_hook("lazy_prefetch_group_offloading")
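
The refactored helper accepts either a torch.device or a plain string because torch.device(...).type collapses an indexed device to its bare type string, which is what it compares parameter devices against. A standalone illustration of that equivalence (not part of the diff):

import torch

# torch.device("cuda:0").type drops the index and yields the bare device
# type, so an indexed device and its type string compare equal after the
# helper's normalization step.
assert torch.device("cuda:0").type == "cuda"
assert torch.device("cpu").type == "cpu"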
@@ -395,51 +458,83 @@ def get_param_modules_from_exec_order(model):
             model(self.input)
 
             mods = [m for _, m in lazy_hook.execution_order]
-            param_mods = [m for m in mods if next(m.parameters(), None) is not None]
-            self.assertGreaterEqual(
-                len(param_mods), 2, f"Expected >=2 param-bearing modules in execution_order, got {len(param_mods)}"
-            )
-
-            first = param_mods[0]
-            last = param_mods[-1]
-            middle_layers = param_mods[1:-1]
-            return first, middle_layers, last
+            param_modules = [m for m in mods if next(m.parameters(), None) is not None]
+            return param_modules
+
+        def assert_callables_offloading_tests(param_modules, pin_callable, header_error_msg: str = ""):
+            pinned_modules = [m for m in param_modules if m in pin_callable.pin_targets]
+            unpinned_modules = [m for m in param_modules if m not in pin_callable.pin_targets]
+            self.assertTrue(
+                len(pin_callable.calls_track) > 0,
+                f"{header_error_msg}: callable should have been called at least once",
+            )
+            assert_all_modules_on_expected_device(
+                pinned_modules, torch_device, f"{header_error_msg}: pinned blocks should stay on device"
+            )
+            assert_all_modules_on_expected_device(
+                unpinned_modules, "cpu", f"{header_error_msg}: unpinned blocks should be offloaded"
+            )
+
+        default_parameters = {
+            "onload_device": torch_device,
+            "offload_type": "block_level",
+            "num_blocks_per_group": 1,
+            "use_stream": True,
+        }
+
+        model_default_no_pin = self.get_model()
+        model_default_no_pin.enable_group_offload(**default_parameters)
+        param_modules = get_param_modules_from_execution_order(model_default_no_pin)
+        assert_all_modules_on_expected_device(
+            param_modules,
+            expected_device="cpu",
+            header_error_msg="default pin_groups: expected ALL modules offloaded to CPU",
+        )
+
+        model_pin_all = self.get_model()
+        model_pin_all.enable_group_offload(**default_parameters, pin_groups="all")
+        param_modules = get_param_modules_from_execution_order(model_pin_all)
+        assert_all_modules_on_expected_device(
+            param_modules,
+            expected_device=torch_device,
+            header_error_msg="pin_groups = all: expected ALL layers on accelerator device",
+        )
 
-        accel_type = torch.device(torch_device).type
 
-        model_no_pin = self.get_model()
-        model_no_pin.enable_group_offload(
-            torch_device,
-            offload_type="block_level",
-            num_blocks_per_group=1,
-            use_stream=True,
-        )
-        model_no_pin.eval()
-        first, middle, last = get_param_modules_from_exec_order(model_no_pin)
-
-        self.assertEqual(first_param_device(first).type, "cpu")
-        self.assertEqual(first_param_device(last).type, "cpu")
-        assert_all_modules_device(middle, "cpu", msg="No-pin: expected ALL middle layers on CPU")
-
-        model_pin = self.get_model()
-        model_pin.enable_group_offload(
-            torch_device,
-            offload_type="block_level",
-            num_blocks_per_group=1,
-            use_stream=True,
+        model_pin_first_last = self.get_model()
+        model_pin_first_last.enable_group_offload(
+            **default_parameters,
             pin_groups="first_last",
         )
-        model_pin.eval()
-        first, middle, last = get_param_modules_from_exec_order(model_pin)
-
-        self.assertEqual(first_param_device(first).type, accel_type)
-        self.assertEqual(first_param_device(last).type, accel_type)
-        assert_all_modules_device(middle, "cpu", msg="Pin: expected ALL middle layers on CPU")
+        param_modules = get_param_modules_from_execution_order(model_pin_first_last)
+        assert_all_modules_on_expected_device(
+            [param_modules[0], param_modules[-1]],
+            expected_device=torch_device,
+            header_error_msg="pin_groups = first_last: expected first and last layers on accelerator device",
+        )
+        assert_all_modules_on_expected_device(
+            param_modules[1:-1],
+            expected_device="cpu",
+            header_error_msg="pin_groups = first_last: expected ALL middle layers offloaded to CPU",
+        )
+
+        model = self.get_model()
+        callable_by_submodule = DummyCallableBySubmodule(pin_targets=[model.blocks[0], model.blocks[-1]])
+        model.enable_group_offload(**default_parameters, pin_groups=callable_by_submodule)
+        param_modules = get_param_modules_from_execution_order(model)
+        assert_callables_offloading_tests(
+            param_modules, callable_by_submodule, header_error_msg="pin_groups with callable(submodule)"
+        )
 
-        # Should still hold after another invocation
-        with torch.no_grad():
-            model_pin(self.input)
+        model = self.get_model()
+        callable_by_name_submodule = DummyCallableByNameSubmodule(pin_targets=[model.blocks[0], model.blocks[-1]])
+        model.enable_group_offload(**default_parameters, pin_groups=callable_by_name_submodule)
+        param_modules = get_param_modules_from_execution_order(model)
+        assert_callables_offloading_tests(
+            param_modules, callable_by_name_submodule, header_error_msg="pin_groups with callable(name, submodule)"
+        )
 
-        self.assertEqual(first_param_device(first).type, accel_type)
-        self.assertEqual(first_param_device(last).type, accel_type)
-        assert_all_modules_device(middle, "cpu", msg="Pin (2nd forward): expected ALL middle layers on CPU")
+        model = self.get_model()
+        callable_by_name_submodule_idx = DummyCallableByNameSubmoduleIdx(pin_targets=[model.blocks[0], model.blocks[-1]])
+        model.enable_group_offload(**default_parameters, pin_groups=callable_by_name_submodule_idx)
+        param_modules = get_param_modules_from_execution_order(model)
+        assert_callables_offloading_tests(
+            param_modules, callable_by_name_submodule_idx, header_error_msg="pin_groups with callable(name, submodule, idx)"
+        )
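
The test drives three call signatures: callable(submodule), callable(name, submodule), and callable(name, submodule, idx). One plausible way for the group-offloading code to dispatch among them is to inspect the callable's arity, as in the sketch below; this is an illustrative assumption, not necessarily how diffusers implements it:

import inspect

import torch

# Hypothetical dispatch sketch: pick the call format from the number of
# parameters the pin_groups callable declares. inspect.signature() already
# excludes `self` for bound methods and callable instances.
def call_pin_callable(pin_callable, name: str, submodule: torch.nn.Module, idx: int) -> bool:
    n_params = len(inspect.signature(pin_callable).parameters)
    if n_params == 1:
        return bool(pin_callable(submodule))
    if n_params == 2:
        return bool(pin_callable(name, submodule))
    if n_params == 3:
        return bool(pin_callable(name, submodule, idx))
    raise TypeError(f"pin_groups callable must accept 1 to 3 arguments, got {n_params}")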
