25
25
is_accelerate_available ,
26
26
logging ,
27
27
)
28
+ from ..utils .torch_utils import get_device
28
29
29
30
30
31
if is_accelerate_available ():
@@ -161,7 +162,9 @@ def __call__(self, hooks, model_id, model, execution_device):
161
162
162
163
current_module_size = model .get_memory_footprint ()
163
164
164
- mem_on_device = torch .cuda .mem_get_info (execution_device .index )[0 ]
165
+ device_type = execution_device .type
166
+ device_module = getattr (torch , device_type , torch .cuda )
167
+ mem_on_device = device_module .mem_get_info (execution_device .index )[0 ]
165
168
mem_on_device = mem_on_device - self .memory_reserve_margin
166
169
if current_module_size < mem_on_device :
167
170
return []
@@ -301,7 +304,7 @@ class ComponentsManager:
301
304
cm.add("vae", vae_model, collection="sdxl")
302
305
303
306
# Enable auto offloading
304
- cm.enable_auto_cpu_offload(device="cuda" )
307
+ cm.enable_auto_cpu_offload()
305
308
306
309
# Retrieve components
307
310
unet = cm.get_one(name="unet", collection="sdxl")
@@ -490,6 +493,8 @@ def remove(self, component_id: str = None):
490
493
gc .collect ()
491
494
if torch .cuda .is_available ():
492
495
torch .cuda .empty_cache ()
496
+ if torch .xpu .is_available ():
497
+ torch .xpu .empty_cache ()
493
498
494
499
# YiYi TODO: rename to search_components for now, may remove this method
495
500
def search_components (
@@ -678,7 +683,7 @@ def matches_pattern(component_id, pattern, exact_match=False):
678
683
679
684
return get_return_dict (matches , return_dict_with_names )
680
685
681
- def enable_auto_cpu_offload (self , device : Union [str , int , torch .device ] = "cuda" , memory_reserve_margin = "3GB" ):
686
+ def enable_auto_cpu_offload (self , device : Union [str , int , torch .device ] = None , memory_reserve_margin = "3GB" ):
682
687
"""
683
688
Enable automatic CPU offloading for all components.
684
689
@@ -704,6 +709,8 @@ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda"
704
709
705
710
self .disable_auto_cpu_offload ()
706
711
offload_strategy = AutoOffloadStrategy (memory_reserve_margin = memory_reserve_margin )
712
+ if device is None :
713
+ device = get_device ()
707
714
device = torch .device (device )
708
715
if device .index is None :
709
716
device = torch .device (f"{ device .type } :{ 0 } " )
0 commit comments