diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index cc4d4fc1b017..3d1056c53a3e 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -428,6 +428,10 @@ def __init__(self, compute_dtype: Optional["torch.dtype"] = None): if self.compute_dtype is None: self.compute_dtype = torch.float32 + def __repr__(self): + config_dict = self.to_dict() + return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n" + @dataclass class TorchAoConfig(QuantizationConfigMixin): @@ -722,3 +726,7 @@ def post_init(self): accepted_weights = ["float8", "int8", "int4", "int2"] if self.weights_dtype not in accepted_weights: raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}") + + def __repr__(self): + config_dict = self.to_dict() + return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"