From 5801679a98d07a1a780a4a3c7c3379e13221c7da Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 10 Apr 2025 09:39:02 +0530
Subject: [PATCH 1/4] improve dtype mismatch handling for bnb + lora.

---
 src/diffusers/loaders/lora_pipeline.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index a29b77acce6e..58e2ecd44508 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -94,6 +94,11 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module):
         weight_on_cpu = True
 
     if is_bnb_4bit_quantized:
+        if module.weight.quant_state.dtype != model.dtype:
+            raise ValueError(
+                f"Model is in {model.dtype} dtype while the current module weight will be dequantized to {module.weight.quant_state.dtype} dtype. "
+                f"Please pass {module.weight.quant_state.dtype} as `torch_dtype` in `from_pretrained()`."
+            )
         module_weight = dequantize_bnb_weight(
             module.weight.cuda() if weight_on_cpu else module.weight,
             state=module.weight.quant_state,

From 8bf12e8b03d39cea94ffc7a87884c4ae4b74318b Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 10 Apr 2025 09:49:36 +0530
Subject: [PATCH 2/4] add a test

---
 tests/quantization/bnb/test_4bit.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index 096ee4c34448..c210dff41866 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -752,6 +752,20 @@ def test_lora_loading(self):
         max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
         self.assertTrue(max_diff < 1e-3, msg=f"{out_slice=} != {expected_slice=}")
 
+    def test_loading_lora_with_incorrect_dtype_raises_error(self):
+        self.tearDown()
+        model_dtype = torch.bfloat16
+        # https://huggingface.co/eramth/flux-4bit/blob/main/transformer/config.json#L23
+        actual_dtype = torch.float16
+        self.pipeline_4bit = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.bfloat16)
+        self.pipeline_4bit.enable_model_cpu_offload()
+        with self.assertRaises(ValueError) as err_context:
+            self.pipeline_4bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
+        assert (
+            f"Model is in {model_dtype} dtype while the current module weight will be dequantized to {actual_dtype} dtype."
+            in str(err_context.exception)
+        )
+
 
 @slow
 class BaseBnb4BitSerializationTests(Base4bitTests):

From 8b7ef9caed1fe06ce9dcf76b8e685d2b7af56a41 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 14 Apr 2025 16:08:44 +0530
Subject: [PATCH 3/4] fix and updates

---
 src/diffusers/loaders/lora_pipeline.py |  9 ++-------
 tests/quantization/bnb/test_4bit.py    | 14 --------------
 2 files changed, 2 insertions(+), 21 deletions(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index e1a759ee95c6..e65252f18b05 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -95,11 +95,6 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module):
         weight_on_cpu = True
 
     if is_bnb_4bit_quantized:
-        if module.weight.quant_state.dtype != model.dtype:
-            raise ValueError(
-                f"Model is in {model.dtype} dtype while the current module weight will be dequantized to {module.weight.quant_state.dtype} dtype. "
-                f"Please pass {module.weight.quant_state.dtype} as `torch_dtype` in `from_pretrained()`."
-            )
         module_weight = dequantize_bnb_weight(
             module.weight.cuda() if weight_on_cpu else module.weight,
             state=module.weight.quant_state,
@@ -2372,14 +2367,14 @@ def _maybe_expand_transformer_param_shape_or_error_(
                     # TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True.
                     with torch.device("meta"):
                         expanded_module = torch.nn.Linear(
-                            in_features, out_features, bias=bias, dtype=module_weight.dtype
+                            in_features, out_features, bias=bias, dtype=transformer.dtype
                         )
                     # Only weights are expanded and biases are not. This is because only the input dimensions
                     # are changed while the output dimensions remain the same. The shape of the weight tensor
                     # is (out_features, in_features), while the shape of bias tensor is (out_features,), which
                     # explains the reason why only weights are expanded.
                     new_weight = torch.zeros_like(
-                        expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
+                        expanded_module.weight.data, device=module_weight.device, dtype=transformer.dtype
                     )
                     slices = tuple(slice(0, dim) for dim in module_weight_shape)
                     new_weight[slices] = module_weight
diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index c210dff41866..096ee4c34448 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -752,20 +752,6 @@ def test_lora_loading(self):
         max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
         self.assertTrue(max_diff < 1e-3, msg=f"{out_slice=} != {expected_slice=}")
 
-    def test_loading_lora_with_incorrect_dtype_raises_error(self):
-        self.tearDown()
-        model_dtype = torch.bfloat16
-        # https://huggingface.co/eramth/flux-4bit/blob/main/transformer/config.json#L23
-        actual_dtype = torch.float16
-        self.pipeline_4bit = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.bfloat16)
-        self.pipeline_4bit.enable_model_cpu_offload()
-        with self.assertRaises(ValueError) as err_context:
-            self.pipeline_4bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
-        assert (
-            f"Model is in {model_dtype} dtype while the current module weight will be dequantized to {actual_dtype} dtype."
-            in str(err_context.exception)
-        )
-
 
 @slow
 class BaseBnb4BitSerializationTests(Base4bitTests):

From fccff451d7cda713039aae175f96992fd046e713 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 17 Apr 2025 11:32:57 +0530
Subject: [PATCH 4/4] update

---
 src/diffusers/loaders/lora_pipeline.py         | 4 ++--
 src/diffusers/quantizers/bitsandbytes/utils.py | 8 +++++---
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 8f92d0fb8ef0..9e02069a109b 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -2485,14 +2485,14 @@ def _maybe_expand_transformer_param_shape_or_error_(
                     # TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True.
                     with torch.device("meta"):
                         expanded_module = torch.nn.Linear(
-                            in_features, out_features, bias=bias, dtype=transformer.dtype
+                            in_features, out_features, bias=bias, dtype=module_weight.dtype
                         )
                     # Only weights are expanded and biases are not. This is because only the input dimensions
                     # are changed while the output dimensions remain the same. The shape of the weight tensor
                     # is (out_features, in_features), while the shape of bias tensor is (out_features,), which
                     # explains the reason why only weights are expanded.
                     new_weight = torch.zeros_like(
-                        expanded_module.weight.data, device=module_weight.device, dtype=transformer.dtype
+                        expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
                     )
                     slices = tuple(slice(0, dim) for dim in module_weight_shape)
                     new_weight[slices] = module_weight
diff --git a/src/diffusers/quantizers/bitsandbytes/utils.py b/src/diffusers/quantizers/bitsandbytes/utils.py
index e150281e81ce..5476b93a4cc2 100644
--- a/src/diffusers/quantizers/bitsandbytes/utils.py
+++ b/src/diffusers/quantizers/bitsandbytes/utils.py
@@ -171,9 +171,11 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None, dtype: "torc
 
     if cls_name == "Params4bit":
         output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
-        logger.warning_once(
-            f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
-        )
+        msg = f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
+        if dtype:
+            msg = f"The model is going to be first dequantized in {output_tensor.dtype} and type-casted to {dtype}"
+            output_tensor = output_tensor.to(dtype)
+        logger.warning_once(msg)
        return output_tensor
 
     if state.SCB is None:
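
Reviewer note (not part of the patch series): below is a minimal, self-contained sketch of the behavior the series converges on in PATCH 4/4 -- the dequantized tensor is cast to an explicitly requested dtype, and the expanded torch.nn.Linear stays in the checkpoint weight's dtype while only the weight (not the bias) is zero-padded. It uses plain PyTorch on CPU, with float16 tensors standing in for bnb 4-bit weights; the helper names are illustrative and are not the diffusers API.

import torch


def cast_after_dequant(dequantized: torch.Tensor, dtype=None) -> torch.Tensor:
    # Mirrors the PATCH 4/4 branch in dequantize_bnb_weight: dequantization
    # happens in the quant-state dtype first; only afterwards is the result
    # cast to the requested dtype, if one was passed.
    if dtype is not None:
        dequantized = dequantized.to(dtype)
    return dequantized


def expand_linear_weight(module_weight: torch.Tensor, in_features: int, out_features: int) -> torch.Tensor:
    # Mirrors the expansion logic: the expanded Linear is created in the
    # checkpoint weight's dtype (module_weight.dtype, as PATCH 4/4 restores)
    # and only the weight is zero-padded, because only the input dimension grows.
    with torch.device("meta"):
        expanded = torch.nn.Linear(in_features, out_features, bias=False, dtype=module_weight.dtype)
    new_weight = torch.zeros_like(expanded.weight.data, device=module_weight.device, dtype=module_weight.dtype)
    slices = tuple(slice(0, dim) for dim in module_weight.shape)
    new_weight[slices] = module_weight
    return new_weight


# float16 stand-in for a weight dequantized from bnb 4-bit storage
# (e.g. eramth/flux-4bit, whose quant_state dtype is float16).
dequantized = torch.randn(3072, 64, dtype=torch.float16)
# Cast to the pipeline dtype (bfloat16), as the new `dtype` argument allows.
casted = cast_after_dequant(dequantized, dtype=torch.bfloat16)
# Expand in_features 64 -> 128, the expansion a Flux Control LoRA needs.
expanded = expand_linear_weight(casted, in_features=128, out_features=3072)
assert expanded.shape == (3072, 128) and expanded.dtype == torch.bfloat16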