[refactor] condense group offloading #11990

Merged
merged 10 commits on Aug 6, 2025

176 changes: 74 additions & 102 deletions src/diffusers/hooks/group_offloading.py
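
For orientation, here is how the entry point this refactor touches is typically called. A minimal sketch; the model class, checkpoint, and argument values are illustrative assumptions, not taken from this PR:

```python
import torch
from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks.group_offloading import apply_group_offloading

# Model choice is illustrative; any diffusers torch.nn.Module can be offloaded this way.
transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

apply_group_offloading(
    transformer,
    onload_device="cuda",        # plain device strings are accepted after this PR
    offload_device="cpu",
    offload_type="block_level",  # or "leaf_level"
    num_blocks_per_group=2,
    non_blocking=True,
)
```
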
@@ -95,7 +95,7 @@ def __init__(
self.offload_to_disk_path = offload_to_disk_path
self._is_offloaded_to_disk = False

if self.offload_to_disk_path:
if self.offload_to_disk_path is not None:
# Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
self.group_id = group_id if group_id is not None else str(id(self))
short_hash = _compute_group_hash(self.group_id)
@@ -115,6 +115,12 @@ def __init__(
else:
self.cpu_param_dict = self._init_cpu_param_dict()

self._torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
else torch.cuda
)

def _init_cpu_param_dict(self):
cpu_param_dict = {}
if self.stream is None:
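
The `_torch_accelerator_module` cached above replaces the per-call lookups that `onload_` and `_offload_to_memory` previously performed. A standalone sketch of that fallback pattern, assuming a CUDA or XPU build of PyTorch (the variable names here are not from the diff):

```python
import torch

# Prefer the device-agnostic torch.accelerator API on newer PyTorch builds,
# otherwise fall back to CUDA, mirroring the branch cached in __init__ above.
if hasattr(torch, "accelerator"):
    accel = getattr(torch, torch.accelerator.current_accelerator().type)  # e.g. torch.cuda, torch.xpu
else:
    accel = torch.cuda

transfer_stream = accel.Stream()          # side stream for async host->device copies
with accel.stream(transfer_stream):       # copies issued here can overlap with compute
    pass
compute_stream = accel.current_stream()   # default/compute stream of the current device
```
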
@@ -138,112 +144,76 @@ def _init_cpu_param_dict(self):

@contextmanager
def _pinned_memory_tensors(self):
pinned_dict = {}
try:
for param, tensor in self.cpu_param_dict.items():
if not tensor.is_pinned():
pinned_dict[param] = tensor.pin_memory()
else:
pinned_dict[param] = tensor

pinned_dict = {
param: tensor.pin_memory() if not tensor.is_pinned() else tensor
for param, tensor in self.cpu_param_dict.items()
}
yield pinned_dict

finally:
pinned_dict = None

def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
def _transfer_tensor_to_device(self, tensor, source_tensor):
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream and current_stream is not None:
tensor.data.record_stream(current_stream)
if self.record_stream:
tensor.data.record_stream(self._torch_accelerator_module.current_stream())

def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
def _process_tensors_from_modules(self, pinned_memory=None):
for group_module in self.modules:
for param in group_module.parameters():
source = pinned_memory[param] if pinned_memory else param.data
self._transfer_tensor_to_device(param, source, current_stream)
self._transfer_tensor_to_device(param, source)
for buffer in group_module.buffers():
source = pinned_memory[buffer] if pinned_memory else buffer.data
self._transfer_tensor_to_device(buffer, source, current_stream)
self._transfer_tensor_to_device(buffer, source)

for param in self.parameters:
source = pinned_memory[param] if pinned_memory else param.data
self._transfer_tensor_to_device(param, source, current_stream)
self._transfer_tensor_to_device(param, source)

for buffer in self.buffers:
source = pinned_memory[buffer] if pinned_memory else buffer.data
self._transfer_tensor_to_device(buffer, source, current_stream)
self._transfer_tensor_to_device(buffer, source)

def _onload_from_disk(self, current_stream):
def _onload_from_disk(self):
if self.stream is not None:
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")

for key, tensor_obj in self.key_to_tensor.items():
self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]

with self._pinned_memory_tensors() as pinned_memory:
for key, tensor_obj in self.key_to_tensor.items():
self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)

self.cpu_param_dict.clear()
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()

else:
onload_device = (
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
)
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]
context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None

def _onload_from_memory(self, current_stream):
if self.stream is not None:
with self._pinned_memory_tensors() as pinned_memory:
self._process_tensors_from_modules(pinned_memory, current_stream)
else:
self._process_tensors_from_modules(None, current_stream)

@torch.compiler.disable()
def onload_(self):
torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
else torch.cuda
)
context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
with context:
# Load to CPU (if using streams) or directly to target device, pin, and async copy to device
device = str(self.onload_device) if self.stream is None else "cpu"
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)

if self.offload_to_disk_path:
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()

with context:
if self.stream is not None:
# Load to CPU, pin, and async copy to device for overlapping transfer and compute
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
for key, tensor_obj in self.key_to_tensor.items():
pinned_tensor = loaded_cpu_tensors[key].pin_memory()
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
tensor_obj.data.record_stream(current_stream)
else:
# Load directly to the target device (synchronous)
onload_device = (
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
)
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]
return
for key, tensor_obj in self.key_to_tensor.items():
pinned_tensor = loaded_tensors[key].pin_memory()
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
tensor_obj.data.record_stream(current_stream)
else:
onload_device = (
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
)
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]

def _onload_from_memory(self):
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()

context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
with context:
if self.offload_to_disk_path:
self._onload_from_disk(current_stream)
if self.stream is not None:
with self._pinned_memory_tensors() as pinned_memory:
self._process_tensors_from_modules(pinned_memory)
else:
self._onload_from_memory(current_stream)
self._process_tensors_from_modules(None)

def _offload_to_disk(self):
# TODO: we can potentially optimize this code path by checking if the _all_ the desired
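
The condensed `onload_` / `_onload_from_disk` path above reduces to: wait for the previous group's transfer, load the shard to CPU, pin it, and issue an async copy on the side stream. A standalone sketch of that pattern, with a made-up file name and CUDA assumed:

```python
import safetensors.torch
import torch

side_stream = torch.cuda.Stream()

# Wait for the previous group's host->device copy before reusing the stream.
side_stream.synchronize()

with torch.cuda.stream(side_stream):
    cpu_tensors = safetensors.torch.load_file("group_ab12cd34.safetensors", device="cpu")
    device_tensors = {}
    for key, tensor in cpu_tensors.items():
        pinned = tensor.pin_memory()  # pinned host memory enables a truly async copy
        device_tensors[key] = pinned.to("cuda", non_blocking=True)
```
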
@@ -264,14 +234,10 @@ def _offload_to_disk(self):
tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)

def _offload_to_memory(self):
torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
else torch.cuda
)
if self.stream is not None:
if not self.record_stream:
torch_accelerator_module.current_stream().synchronize()
self._torch_accelerator_module.current_stream().synchronize()

for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
@@ -282,15 +248,23 @@ def _offload_to_memory(self):

else:
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=self.non_blocking)
group_module.to(self.offload_device, non_blocking=False)
for param in self.parameters:
param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
param.data = param.data.to(self.offload_device, non_blocking=False)
for buffer in self.buffers:
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)

@torch.compiler.disable()
def onload_(self):
r"""Onloads the group of parameters to the onload_device."""
if self.offload_to_disk_path is not None:
self._onload_from_disk()
else:
self._onload_from_memory()

@torch.compiler.disable()
def offload_(self):
r"""Offloads the group of modules to the offload_device."""
r"""Offloads the group of parameters to the offload_device."""
if self.offload_to_disk_path:
self._offload_to_disk()
else:
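
The `record_stream` / `synchronize` branch at the top of `_offload_to_memory` is about memory safety for tensors that were copied in on a side stream and then consumed on the compute stream. A minimal illustration of the two options, assuming CUDA (a standalone sketch, not the diffusers code):

```python
import torch

side = torch.cuda.Stream()
x_cpu = torch.randn(1 << 20, pin_memory=True)

with torch.cuda.stream(side):
    x_gpu = x_cpu.to("cuda", non_blocking=True)     # async H2D copy on the side stream

torch.cuda.current_stream().wait_stream(side)       # compute stream waits for the copy
y = x_gpu.sum()                                     # the "forward pass" consumes x_gpu

# Before the device copy is released back to the allocator, one of the two:
x_gpu.record_stream(torch.cuda.current_stream())    # record_stream=True path
torch.cuda.current_stream().synchronize()           # record_stream=False path
```
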
@@ -307,11 +281,9 @@ class GroupOffloadingHook(ModelHook):

_is_stateful = False

def __init__(
self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig
) -> None:
def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
self.group = group
self.next_group = next_group
self.next_group: Optional[ModuleGroup] = None
self.config = config

def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
@@ -459,8 +431,8 @@ def pre_forward(self, module, *args, **kwargs):

def apply_group_offloading(
module: torch.nn.Module,
onload_device: torch.device,
offload_device: torch.device = torch.device("cpu"),
onload_device: Union[str, torch.device],
offload_device: Union[str, torch.device] = torch.device("cpu"),
offload_type: Union[str, GroupOffloadingType] = "block_level",
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
@@ -546,6 +518,8 @@ def apply_group_offloading(
```
"""

onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
offload_type = GroupOffloadingType(offload_type)

stream = None
@@ -633,7 +607,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
# Apply group offloading hooks to the module groups
for i, group in enumerate(matched_module_groups):
for group_module in group.modules:
_apply_group_offloading_hook(group_module, group, None, config=config)
_apply_group_offloading_hook(group_module, group, config=config)

# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
# when the forward pass of this module is called. This is because the top-level module is not
@@ -662,9 +636,9 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
group_id=f"{module.__class__.__name__}_unmatched_group",
)
if config.stream is None:
_apply_group_offloading_hook(module, unmatched_group, None, config=config)
_apply_group_offloading_hook(module, unmatched_group, config=config)
else:
_apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)


def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
Expand Down Expand Up @@ -693,7 +667,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
onload_self=True,
group_id=name,
)
_apply_group_offloading_hook(submodule, group, None, config=config)
_apply_group_offloading_hook(submodule, group, config=config)
modules_with_group_offloading.add(name)

# Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
Expand Down Expand Up @@ -740,7 +714,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
onload_self=True,
group_id=name,
)
_apply_group_offloading_hook(parent_module, group, None, config=config)
_apply_group_offloading_hook(parent_module, group, config=config)

if config.stream is not None:
# When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
Expand All @@ -762,13 +736,12 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
onload_self=True,
group_id=_GROUP_ID_LAZY_LEAF,
)
_apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)


def _apply_group_offloading_hook(
module: torch.nn.Module,
group: ModuleGroup,
next_group: Optional[ModuleGroup] = None,
*,
config: GroupOffloadingConfig,
) -> None:
@@ -777,14 +750,13 @@ def _apply_group_offloading_hook(
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
if registry.get_hook(_GROUP_OFFLOADING) is None:
hook = GroupOffloadingHook(group, next_group, config=config)
hook = GroupOffloadingHook(group, config=config)
registry.register_hook(hook, _GROUP_OFFLOADING)


def _apply_lazy_group_offloading_hook(
module: torch.nn.Module,
group: ModuleGroup,
next_group: Optional[ModuleGroup] = None,
*,
config: GroupOffloadingConfig,
) -> None:
@@ -793,7 +765,7 @@ def _apply_lazy_group_offloading_hook(
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
if registry.get_hook(_GROUP_OFFLOADING) is None:
hook = GroupOffloadingHook(group, next_group, config=config)
hook = GroupOffloadingHook(group, config=config)
registry.register_hook(hook, _GROUP_OFFLOADING)

lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()