From 9bae97c11af9cdd3d5668bd9ca7bd144ab22c953 Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Wed, 25 Sep 2019 13:53:08 -0700
Subject: [PATCH] Use tensorization for avx2 as well.

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torch_tvm/custom_tvm_ops/topi/quantized_linear_int8.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/torch_tvm/custom_tvm_ops/topi/quantized_linear_int8.py b/torch_tvm/custom_tvm_ops/topi/quantized_linear_int8.py
index 124da7d..e1433c8 100644
--- a/torch_tvm/custom_tvm_ops/topi/quantized_linear_int8.py
+++ b/torch_tvm/custom_tvm_ops/topi/quantized_linear_int8.py
@@ -4,7 +4,8 @@ from topi.util import get_const_int
 from topi.generic import nn
 from topi import tag
 
-from topi.x86.tensor_intrin import dot_16x1x16_int8_int8_int32
+from topi.x86.tensor_intrin import dot_16x1x16_int8_int8_int32, \
+    dot_1x4x16_int8_int8_int32_avx2
 
 from enum import Enum
 AVXType = Enum('AVXType', 'AVX2 AVX512 None')
@@ -146,6 +147,9 @@ def _schedule_quantized_mm(cfg, s, QGEMM):
             if avx_type == AVXType.AVX512:
                 pc = dot_16x1x16_int8_int8_int32()
                 s[QGEMM].tensorize(xi, pc)
+            elif avx_type == AVXType.AVX2:
+                pc = dot_1x4x16_int8_int8_int32_avx2()
+                s[QGEMM].tensorize(xi, pc)
             else:
                 s[QGEMM].reorder(xo, y, xi)
                 s[QGEMM].unroll(y)