
Commit 8b7ef9c

fix and updates
1 parent 2107d4c commit 8b7ef9c

File tree

src/diffusers/loaders/lora_pipeline.py
tests/quantization/bnb/test_4bit.py

2 files changed: +2 -21 lines changed

Diff for: src/diffusers/loaders/lora_pipeline.py

+2 -7
@@ -95,11 +95,6 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module):
         weight_on_cpu = True
 
     if is_bnb_4bit_quantized:
-        if module.weight.quant_state.dtype != model.dtype:
-            raise ValueError(
-                f"Model is in {model.dtype} dtype while the current module weight will be dequantized to {module.weight.quant_state.dtype} dtype. "
-                f"Please pass {module.weight.quant_state.dtype} as `torch_dtype` in `from_pretrained()`."
-            )
         module_weight = dequantize_bnb_weight(
             module.weight.cuda() if weight_on_cpu else module.weight,
             state=module.weight.quant_state,
@@ -2372,14 +2367,14 @@ def _maybe_expand_transformer_param_shape_or_error_(
                 # TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True.
                 with torch.device("meta"):
                     expanded_module = torch.nn.Linear(
-                        in_features, out_features, bias=bias, dtype=module_weight.dtype
+                        in_features, out_features, bias=bias, dtype=transformer.dtype
                     )
                 # Only weights are expanded and biases are not. This is because only the input dimensions
                 # are changed while the output dimensions remain the same. The shape of the weight tensor
                 # is (out_features, in_features), while the shape of bias tensor is (out_features,), which
                 # explains the reason why only weights are expanded.
                 new_weight = torch.zeros_like(
-                    expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
+                    expanded_module.weight.data, device=module_weight.device, dtype=transformer.dtype
                 )
                 slices = tuple(slice(0, dim) for dim in module_weight_shape)
                 new_weight[slices] = module_weight
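
For context, a minimal sketch (not part of the commit; the toy shapes and the names transformer_dtype and module_weight are illustrative) of what the change in the second hunk does: the expanded layer and its zero-initialized weight buffer now follow the transformer's dtype rather than the stored weight's dtype, so a weight dequantized from a float16 4-bit checkpoint is simply cast into the requested dtype instead of triggering the removed ValueError.

import torch

transformer_dtype = torch.bfloat16                         # dtype requested via torch_dtype=...
module_weight = torch.randn(8, 4, dtype=torch.float16)     # e.g. dequantized from a float16 4-bit checkpoint

in_features, out_features = 8, 8                           # input dim expanded from 4 -> 8
with torch.device("meta"):
    # Allocate the expanded layer in the model's dtype, not the weight's dtype.
    expanded_module = torch.nn.Linear(in_features, out_features, bias=False, dtype=transformer_dtype)

# Zero buffer in the transformer's dtype; the old weight values are cast on assignment.
new_weight = torch.zeros_like(
    expanded_module.weight.data, device=module_weight.device, dtype=transformer_dtype
)
slices = tuple(slice(0, dim) for dim in module_weight.shape)
new_weight[slices] = module_weight
assert new_weight.dtype == torch.bfloat16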

Diff for: tests/quantization/bnb/test_4bit.py

-14
@@ -752,20 +752,6 @@ def test_lora_loading(self):
         max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
         self.assertTrue(max_diff < 1e-3, msg=f"{out_slice=} != {expected_slice=}")
 
-    def test_loading_lora_with_incorrect_dtype_raises_error(self):
-        self.tearDown()
-        model_dtype = torch.bfloat16
-        # https://huggingface.co/eramth/flux-4bit/blob/main/transformer/config.json#L23
-        actual_dtype = torch.float16
-        self.pipeline_4bit = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.bfloat16)
-        self.pipeline_4bit.enable_model_cpu_offload()
-        with self.assertRaises(ValueError) as err_context:
-            self.pipeline_4bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
-        assert (
-            f"Model is in {model_dtype} dtype while the current module weight will be dequantized to {actual_dtype} dtype."
-            in str(err_context.exception)
-        )
-
 
 @slow
 class BaseBnb4BitSerializationTests(Base4bitTests):
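
For reference, a hedged usage sketch of the scenario the removed test exercised: a 4-bit checkpoint whose quantized weights are stored in float16, loaded with torch_dtype=torch.bfloat16. With this change, loading the Control LoRA is expected to succeed instead of raising a ValueError; the expanded layers are created in the transformer's dtype. Model IDs and calls are taken from the removed test; running it requires a GPU and bitsandbytes.

import torch
from diffusers import FluxControlPipeline

# 4-bit checkpoint serialized in float16, requested in bfloat16.
pipeline = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.bfloat16)
pipeline.enable_model_cpu_offload()
# Previously raised ValueError due to the dtype mismatch; now expected to load.
pipeline.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")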
