Skip to content

Commit 4397f59

Browse files
authored
[bitsandbytes] improve dtype mismatch handling for bnb + lora. (#11270)
* improve dtype mismatch handling for bnb + lora.
* add a test
* fix and updates
* update
1 parent 0567932 commit 4397f59

File tree

1 file changed

+5
-3
lines changed
  • src/diffusers/quantizers/bitsandbytes

1 file changed

+5
-3
lines changed

src/diffusers/quantizers/bitsandbytes/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,11 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None, dtype: "torch.dtype" = None):

     if cls_name == "Params4bit":
         output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
-        logger.warning_once(
-            f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
-        )
+        msg = f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
+        if dtype:
+            msg = f"The model is going to be first dequantized in {output_tensor.dtype} and type-casted to {dtype}"
+            output_tensor = output_tensor.to(dtype)
+        logger.warning_once(msg)
         return output_tensor

     if state.SCB is None:

0 commit comments

Comments (0)