Skip to content

Commit 9d3573a

Browse files
patch: save RAM by disabling fp32 text encoders
1 parent e3efd7c commit 9d3573a

1 file changed

Lines changed: 1 addition & 8 deletions

File tree

library/sdxl_train_util.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,6 @@ def _load_target_model(
104104

105105
text_encoder1 = pipe.text_encoder
106106
text_encoder2 = pipe.text_encoder_2
107-
108-
# convert to fp32 for cache text_encoders outputs
109-
if text_encoder1.dtype != torch.float32:
110-
text_encoder1 = text_encoder1.to(dtype=torch.float32)
111-
if text_encoder2.dtype != torch.float32:
112-
text_encoder2 = text_encoder2.to(dtype=torch.float32)
113-
114107
vae = pipe.vae
115108
unet = pipe.unet
116109
del pipe
@@ -119,7 +112,7 @@ def _load_target_model(
119112
state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
120113
with init_empty_weights():
121114
unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet
122-
sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype)
115+
sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device)
123116
print("U-Net converted to original U-Net")
124117

125118
logit_scale = None

0 commit comments

Comments (0)