
Commit 9405820

Commit message: up
1 parent: c171a1d

File tree: 1 file changed (+7, -2 lines)


src/diffusers/modular_pipelines/components_manager.py

Lines changed: 7 additions & 2 deletions
```diff
@@ -164,9 +164,12 @@ def __call__(self, hooks, model_id, model, execution_device):
 
         device_type = execution_device.type
         device_module = getattr(torch, device_type, torch.cuda)
-        mem_on_device = device_module.mem_get_info(execution_device.index)[0]
-        mem_on_device = mem_on_device - self.memory_reserve_margin
+        try:
+            mem_on_device = device_module.mem_get_info(execution_device.index)[0]
+        except AttributeError:
+            raise AttributeError(f"Do not know how to obtain memory info for {str(device_module)}.")
 
+        mem_on_device = mem_on_device - self.memory_reserve_margin
         if current_module_size < mem_on_device:
             return []
 
```
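For context: `mem_get_info` returns a `(free_bytes, total_bytes)` tuple, which is why the hunk indexes `[0]`, and the `getattr(torch, device_type, torch.cuda)` lookup falls back to `torch.cuda` for unrecognized device types. Some backend modules (for example `torch.mps`) do not expose `mem_get_info` at all, which is what the new `try`/`except` turns into a descriptive error. A minimal standalone sketch of the guarded lookup; the helper name `available_device_memory` and the default margin are illustrative, not part of the commit:

```python
import torch

def available_device_memory(execution_device: torch.device, memory_reserve_margin: int = 0) -> int:
    """Free memory (in bytes) on `execution_device`, minus a safety margin.

    Hypothetical helper mirroring the guarded lookup in the hunk above;
    not part of the commit itself.
    """
    # Fall back to torch.cuda when torch has no module named after the device type.
    device_module = getattr(torch, execution_device.type, torch.cuda)
    try:
        # mem_get_info returns (free_bytes, total_bytes); only the free part is needed.
        free_bytes = device_module.mem_get_info(execution_device.index)[0]
    except AttributeError:
        raise AttributeError(f"Do not know how to obtain memory info for {str(device_module)}.")
    return free_bytes - memory_reserve_margin
```

On a CUDA machine, `available_device_memory(torch.device("cuda", 0), 100 * 1024**2)` would report free memory with a 100 MiB margin held back, matching how `memory_reserve_margin` is subtracted in the hunk.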

```diff
@@ -700,6 +703,8 @@ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = None,
         if not is_accelerate_available():
             raise ImportError("Make sure to install accelerate to use auto_cpu_offload")
 
+        # TODO: add a warning if mem_get_info isn't available on `device`.
+
         for name, component in self.components.items():
             if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
                 remove_hook_from_module(component, recurse=True)
```
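The TODO could be satisfied with a capability check before the offload hooks are installed. A sketch under stated assumptions: the helper name is hypothetical, and `logger` stands in for whatever logging setup the module actually uses:

```python
import logging

import torch

logger = logging.getLogger(__name__)  # stand-in; the real module may use its own logger

def warn_if_mem_get_info_missing(device: torch.device) -> None:
    """Warn when the device's torch backend module lacks `mem_get_info`,
    since the offload hook above relies on it to query free device memory."""
    device_module = getattr(torch, device.type, torch.cuda)
    if not hasattr(device_module, "mem_get_info"):
        logger.warning(
            "`torch.%s` does not expose `mem_get_info`; auto CPU offload will "
            "raise an AttributeError when it tries to query free memory on %s.",
            device.type,
            device,
        )
```

Emitting the warning here, at `enable_auto_cpu_offload` time, would surface the problem once up front rather than only when the hook first fires during inference.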
