diff --git a/torch_tvm/custom_tvm_ops/topi/quantized_linear_int8.py b/torch_tvm/custom_tvm_ops/topi/quantized_linear_int8.py
index 124da7d..e1433c8 100644
--- a/torch_tvm/custom_tvm_ops/topi/quantized_linear_int8.py
+++ b/torch_tvm/custom_tvm_ops/topi/quantized_linear_int8.py
@@ -4,7 +4,8 @@
 from topi.util import get_const_int
 from topi.generic import nn
 from topi import tag
-from topi.x86.tensor_intrin import dot_16x1x16_int8_int8_int32
+from topi.x86.tensor_intrin import dot_16x1x16_int8_int8_int32, \
+    dot_1x4x16_int8_int8_int32_avx2
 from enum import Enum
 
 AVXType = Enum('AVXType', 'AVX2 AVX512 None')
@@ -146,6 +147,9 @@ def _schedule_quantized_mm(cfg, s, QGEMM):
         if avx_type == AVXType.AVX512:
             pc = dot_16x1x16_int8_int8_int32()
             s[QGEMM].tensorize(xi, pc)
+        elif avx_type == AVXType.AVX2:
+            pc = dot_1x4x16_int8_int8_int32_avx2()
+            s[QGEMM].tensorize(xi, pc)
         else:
             s[QGEMM].reorder(xo, y, xi)
             s[QGEMM].unroll(y)
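
Note (not part of the patch): a minimal sketch of how the new AVX2 dot-product intrinsic is typically applied in a TVM schedule, for reviewers unfamiliar with tensorize. It assumes a TVM/topi build that actually provides topi.x86.tensor_intrin.dot_1x4x16_int8_int8_int32_avx2 (the import added above); the GEMM sizes, dtypes, and tile factors are illustrative and must match whatever pattern the intrinsic declares.

# Illustrative sketch only. Assumes an old-style TVM/topi API
# (tvm.placeholder / tvm.create_schedule, matching the top-level
# `topi` imports in this file) and that dot_1x4x16_int8_int8_int32_avx2
# follows the AVX-512 sibling's pattern: 16 int32 output lanes over a
# 4-wide int8 reduction, uint8 activations x int8 weights.
import tvm
from topi.x86.tensor_intrin import dot_1x4x16_int8_int8_int32_avx2

M, N, K = 16, 64, 128  # toy GEMM sizes; N % 16 == 0, K % 4 == 0
data = tvm.placeholder((M, K), dtype='uint8', name='data')
weight = tvm.placeholder((N, K), dtype='int8', name='weight')

k = tvm.reduce_axis((0, K), name='k')
qgemm = tvm.compute(
    (M, N),
    lambda m, n: tvm.sum(
        data[m, k].astype('int32') * weight[n, k].astype('int32'), axis=k),
    name='QGEMM')

s = tvm.create_schedule(qgemm.op)
m, n = s[qgemm].op.axis
(k,) = s[qgemm].op.reduce_axis

# Expose an inner tile of 16 outputs by 4 reduction elements so that
# the loop nest matches the intrinsic's compute pattern.
no, ni = s[qgemm].split(n, factor=16)
ko, ki = s[qgemm].split(k, factor=4)
s[qgemm].reorder(m, no, ko, ni, ki)

pc = dot_1x4x16_int8_int8_int32_avx2()
s[qgemm].tensorize(ni, pc)

# Building would require an AVX2-capable target, e.g.
# tvm.build(s, [data, weight, qgemm], target='llvm -mcpu=core-avx2').
print(tvm.lower(s, [data, weight, qgemm], simple_mode=True))

The intent mirrors the existing AVX-512 branch in _schedule_quantized_mm: when a suitable dot-product intrinsic is available, tensorize replaces the inner output/reduction loops wholesale, so the generic reorder/unroll fallback only runs when no AVX intrinsic applies.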