diff --git a/backends/xnnpack/partition/config/quant_affine_configs.py b/backends/xnnpack/partition/config/quant_affine_configs.py
index 046402800a3..2ae84c06ad3 100644
--- a/backends/xnnpack/partition/config/quant_affine_configs.py
+++ b/backends/xnnpack/partition/config/quant_affine_configs.py
@@ -7,6 +7,7 @@
 from typing import List, Optional
 
 import torch
+import torchao.quantization.quant_primitives  # noqa
 from executorch.backends.xnnpack.partition.config.xnnpack_config import (
     ConfigPrecisionType,
     XNNPartitionerConfig,
@@ -33,33 +34,18 @@ class QuantizeAffineConfig(QDQAffineConfigs):
     target_name = "quantize_affine.default"
 
     def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
-        try:
-            import torchao.quantization.quant_primitives  # noqa
-
-            return torch.ops.torchao.quantize_affine.default
-        except:
-            return None
+        return torch.ops.torchao.quantize_affine.default
 
 
 class DeQuantizeAffineConfig(QDQAffineConfigs):
     target_name = "dequantize_affine.default"
 
     def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
-        try:
-            import torchao.quantization.quant_primitives  # noqa
-
-            return torch.ops.torchao.dequantize_affine.default
-        except:
-            return None
+        return torch.ops.torchao.dequantize_affine.default
 
 
 class ChooseQParamsAffineConfig(QDQAffineConfigs):
     target_name = "choose_qparams_affine.default"
 
     def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
-        try:
-            import torchao.quantization.quant_primitives  # noqa
-
-            return torch.ops.torchao.choose_qparams_affine.default
-        except:
-            return None
+        return torch.ops.torchao.choose_qparams_affine.default
diff --git a/backends/xnnpack/test/ops/test_linear.py b/backends/xnnpack/test/ops/test_linear.py
index 421e59c0b08..b9094f367d3 100644
--- a/backends/xnnpack/test/ops/test_linear.py
+++ b/backends/xnnpack/test/ops/test_linear.py
@@ -33,16 +33,11 @@
 from torch.export.graph_signature import ExportGraphSignature, InputKind
 
 
-try:
-    from torchao.quantization.quant_api import (
-        int8_dynamic_activation_int4_weight,
-        quantize_,
-    )
-    from torchao.utils import unwrap_tensor_subclass
-
-    torchao_installed = True
-except:
-    torchao_installed = False
+from torchao.quantization.quant_api import (
+    int8_dynamic_activation_int4_weight,
+    quantize_,
+)
+from torchao.utils import unwrap_tensor_subclass
 
 
 # Pytorch Modules Used for Testing
@@ -818,22 +813,13 @@ def test_linear_qd8_f32_per_channel_int4(self):
         self._test_qd8_per_channel_4w_linear(dtype=torch.float)
 
     # Tests for q[dp]8-f16-qb4w
-    @unittest.skipIf(
-        not torchao_installed, "Per Channel Group Quantization Required TorchAO"
-    )
     def test_linear_qd8_f16_per_token_weight_per_channel_group_int4(self):
         self._test_qd8_per_token_weight_per_channel_group_int4(dtype=torch.half)
 
     # Tests for q[dp]8-f32-qb4w
-    @unittest.skipIf(
-        not torchao_installed, "Per Channel Group Quantization Required TorchAO"
-    )
     def test_linear_qd8_f32_per_token_weight_per_channel_group_int4(self):
         self._test_qd8_per_token_weight_per_channel_group_int4(dtype=torch.float)
 
-    @unittest.skipIf(
-        not torchao_installed, "Per Channel Group Quantization Required TorchAO"
-    )
     def test_linear_qd8_per_token_groupwise_unsupported_groupsize(self):
         # groupsize must be multiple of 32
         for dtype in [torch.float, torch.half]:
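
With the try/except guards and `skipIf` decorators removed, torchao becomes a hard dependency of both the XNNPACK partitioner configs and the linear tests: an environment without torchao will now fail at module import time rather than silently returning `None` or skipping tests. For reviewers unfamiliar with the flow these tests exercise, below is a minimal sketch of the torchao per-channel-group quantization path that the `_test_qd8_per_token_weight_per_channel_group_int4` tests rely on. The `TinyLinear` module, tensor shapes, and `group_size=32` here are illustrative assumptions, not taken from the test file; only the `quantize_` / `unwrap_tensor_subclass` calls mirror the imports above.

```python
import torch
from torchao.quantization.quant_api import (
    int8_dynamic_activation_int4_weight,
    quantize_,
)
from torchao.utils import unwrap_tensor_subclass


# Hypothetical module for illustration; any nn.Module containing
# torch.nn.Linear layers works the same way.
class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(256, 128)

    def forward(self, x):
        return self.linear(x)


model = TinyLinear().eval()

# Quantize weights to 4-bit with per-channel-group scales and 8-bit
# dynamically quantized activations. group_size must be a multiple of 32
# for the XNNPACK qb4w kernels -- the constraint checked by
# test_linear_qd8_per_token_groupwise_unsupported_groupsize.
# quantize_ mutates the module in place.
quantize_(model, int8_dynamic_activation_int4_weight(group_size=32))

# torchao stores the quantized weights as tensor subclasses; unwrap them
# into plain tensors so torch.export can trace the module before it is
# delegated to XNNPACK.
model = unwrap_tensor_subclass(model)

exported = torch.export.export(model, (torch.randn(1, 256),))
```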