diff --git a/coremltools/converters/mil/frontend/torch/quantization_ops.py b/coremltools/converters/mil/frontend/torch/quantization_ops.py
index 761318498..b0674ef7f 100644
--- a/coremltools/converters/mil/frontend/torch/quantization_ops.py
+++ b/coremltools/converters/mil/frontend/torch/quantization_ops.py
@@ -803,7 +803,6 @@ def dequantize_affine(context, node):
         int_data.astype(quantized_np_dtype),
         zero_point,
         scale,
-        axis=-1,
         name=node.name,
     )
     context.add(output, node.name)
diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_quantization_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_quantization_ops.py
index c46179a3a..6d10a9e36 100644
--- a/coremltools/converters/mil/frontend/torch/test/test_torch_quantization_ops.py
+++ b/coremltools/converters/mil/frontend/torch/test/test_torch_quantization_ops.py
@@ -272,6 +272,57 @@ def forward(self, x):
         prog = res[1]._mil_program
         assert get_op_types_in_program(prog) == ["constexpr_blockwise_shift_scale", "linear"]
 
+    @pytest.mark.skipif(not _HAS_TORCHAO, reason=MSG_TORCHAO_NOT_FOUND)
+    @pytest.mark.parametrize(
+        "compute_unit, has_zeros, minimum_deployment_target",
+        itertools.product(compute_units, [True, False], [ct.target.IOS16, ct.target.IOS17]),
+    )
+    def test_dequantize_affine_before_ios18(self, compute_unit, has_zeros, minimum_deployment_target):
+        quant_min = -128
+        quant_max = 127
+
+        # Per-channel (per-row) quantized weight: one scale, and optionally one
+        # zero point, for each of the n output rows.
+        n = 4
+        k = 128
+        input_dtype = torch.int8
+        int_data = torch.randint(low=quant_min, high=quant_max, size=(n, k)).to(input_dtype)
+        scale = torch.rand(n, 1)
+
+        zero_point = None
+        if has_zeros:
+            zero_point = torch.randint(low=quant_min, high=quant_max, size=(n, 1)).to(input_dtype)
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("int_data", int_data)
+                self.register_buffer("scale", scale)
+                self.register_buffer("zero_point", zero_point)
+
+            def forward(self, x):
+                w = torchao_quant.dequantize_affine(
+                    self.int_data, [1, k], self.scale, self.zero_point, input_dtype, quant_min, quant_max
+                )
+                return torch.nn.functional.linear(x, w)
+
+        model = Model()
+        model = model.to(torch.device("cpu"))
+
+        input_shape = [(3, k)]
+        res = self.run_compare_torch(
+            input_shape,
+            model,
+            minimum_deployment_target=minimum_deployment_target,
+            compute_unit=compute_unit,
+            rtol=0.1,
+            frontend=TorchFrontend.TORCHEXPORT,
+        )
+        prog = res[1]._mil_program
+        # Before iOS18 the op must lower to constexpr_affine_dequantize rather
+        # than the iOS18-only constexpr_blockwise_shift_scale.
+        assert get_op_types_in_program(prog) == ["constexpr_affine_dequantize", "linear"]
+
     # TODO(rdar://108463675): refactor torch op tests later to parametrize quantized vs standard ops