diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h
index 95ecb79dc0..ce0ac804c9 100644
--- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h
+++ b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h
@@ -126,7 +126,7 @@ void pack_weights(
           bias);
 }
 
-template <int weight_nbit, int nr, int kr, int sr>
+template <int weight_nbit>
 void pack_weights_with_lut(
     // Output
     void* packed_weights,
@@ -141,10 +141,16 @@ void pack_weights_with_lut(
     // weight_zeros not packed if nullptr
     const int8_t* weight_zeros,
     // bias not packed if nullptr
-    const float* bias) {
+    const float* bias,
+    int nr,
+    int kr,
+    int sr) {
+  (void)nr; // unused
+  (void)kr; // unused
+  (void)sr; // unused
   torchao::kernels::cpu::aarch64::linear::
       channelwise_8bit_activation_groupwise_lowbit_weight::weight_packing::
-          pack_weights_with_lut<weight_nbit, nr, kr, sr>(
+          pack_weights_with_lut<weight_nbit, /*nr*/ 8, /*kr*/ 16, /*sr*/ 2>(
               packed_weights,
               n,
               k,
diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
index 671ee3f0b9..6d6101e3cf 100644
--- a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
@@ -478,7 +478,10 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_lut(
       lut.data(),
       test_case.weight_scales.data(),
       has_weight_zeros ? test_case.weight_zeros.data() : nullptr,
-      has_bias ? test_case.bias.data() : nullptr);
+      has_bias ? test_case.bias.data() : nullptr,
+      nr,
+      kr,
+      sr);
 
   std::vector<float> output(m * n);
   kernel_1x8x16_f32_neondot(
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_config.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_config.h
index 1e4a9ef670..114e97838c 100644
--- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_config.h
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_config.h
@@ -87,6 +87,23 @@ struct UKernelConfig {
       int kr,
       int sr);
 
+  // Pack weights into packed_weights buffer with int8-valued LUT
+  using pack_weights_with_lut_fn_type = void (*)(
+      void* packed_weights,
+      int n,
+      int k,
+      int group_size,
+      const int8_t* weight_qval_idxs,
+      int n_luts,
+      const int8_t* luts,
+      const float* weight_scales,
+      const int8_t* weight_zeros,
+      const float* bias,
+      int nr,
+      int kr,
+      int sr
+  );
+
   // Run matmul kernel
   using kernel_fn_type = void (*)(
       float* output,
@@ -126,6 +143,7 @@ struct UKernelConfig {
   packed_weights_size_fn_type packed_weights_size{nullptr};
   packed_weights_offset_fn_type packed_weights_offset{nullptr};
   pack_weights_fn_type pack_weights{nullptr};
+  pack_weights_with_lut_fn_type pack_weights_with_lut{nullptr};
 
   // linear_configs must be sorted in ascending m_step
   std::array linear_configs;
@@ -144,6 +162,20 @@ struct UKernelConfig {
       pack_weights_fn_type pack_weights,
       std::array linear_configs);
 
+  static UKernelConfig make_with_lut(
+      size_t preferred_alignment,
+      int n_step,
+      int nr,
+      int kr,
+      int sr,
+      int weight_nbit,
+      bool has_weight_zeros,
+      bool has_bias,
+      packed_weights_size_fn_type packed_weights_with_lut_size,
+      packed_weights_offset_fn_type packed_weights_with_lut_offset,
+      pack_weights_with_lut_fn_type pack_weights_with_lut,
+      std::array linear_configs);
+
   inline void validate() const {
     TORCHAO_CHECK(preferred_alignment >= 1, "preferred_alignment must be >= 1");
     TORCHAO_CHECK(n_step >= 1, "n_step must be >= 1");
@@ -155,7 +187,7 @@ struct UKernelConfig {
         packed_weights_size != nullptr, "packed_weights_size must be set");
     TORCHAO_CHECK(
         packed_weights_offset != nullptr, "packed_weights_offset must be set");
-    TORCHAO_CHECK(pack_weights != nullptr, "pack_weights must be set");
+    TORCHAO_CHECK(pack_weights != nullptr || pack_weights_with_lut != nullptr, "pack_weights or pack_weights_with_lut must be set");
     bool linear_configs_set = true;
     // first linear config must be set
     for (int i = 0; i < linear_configs.size(); i++) {
@@ -232,6 +264,36 @@ inline UKernelConfig UKernelConfig::make(
       packed_weights_size,
       packed_weights_offset,
       pack_weights,
+      /*pack_weights_with_lut*/nullptr,
+      std::move(linear_configs)};
+}
+
+inline UKernelConfig UKernelConfig::make_with_lut(
+    size_t preferred_alignment,
+    int n_step,
+    int nr,
+    int kr,
+    int sr,
+    int weight_nbit,
+    bool has_weight_zeros,
+    bool has_bias,
+    packed_weights_size_fn_type packed_weights_with_lut_size,
+    packed_weights_offset_fn_type packed_weights_with_lut_offset,
+    pack_weights_with_lut_fn_type pack_weights_with_lut,
+    std::array linear_configs) {
+  return UKernelConfig{
+      preferred_alignment,
+      n_step,
+      nr,
+      kr,
+      sr,
+      weight_nbit,
+      has_weight_zeros,
+      has_bias,
+      packed_weights_with_lut_size,
+      packed_weights_with_lut_offset,
+      /*pack_weights*/nullptr,
+      /*pack_weights_with_lut*/pack_weights_with_lut,
       std::move(linear_configs)};
 }
 
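With the kernel_config.h change above, a UKernelConfig carries either pack_weights or pack_weights_with_lut, and validate() only requires that at least one of the two is set. The snippet below is a minimal, self-contained sketch of that dispatch contract, not code from this patch: MiniConfig is a stand-in for UKernelConfig reduced to the two function pointers, and the lambda stands in for a real packing routine.

#include <cassert>
#include <cstdio>

// Stand-in for UKernelConfig, keeping only the two packing entry points.
struct MiniConfig {
  void (*pack_weights)(void* packed_weights) = nullptr;
  void (*pack_weights_with_lut)(void* packed_weights) = nullptr;
};

void pack(const MiniConfig& uk, void* packed_weights) {
  // Mirrors the relaxed TORCHAO_CHECK: at least one entry point must be set.
  assert(uk.pack_weights != nullptr || uk.pack_weights_with_lut != nullptr);
  if (uk.pack_weights_with_lut != nullptr) {
    uk.pack_weights_with_lut(packed_weights);  // LUT path: qval indices + int8 LUT
  } else {
    uk.pack_weights(packed_weights);           // non-LUT path: raw qvals
  }
}

int main() {
  MiniConfig uk;
  uk.pack_weights_with_lut = [](void*) { std::puts("packing with LUT"); };
  pack(uk, nullptr);
  return 0;
}

Keeping both pointers on the same struct, rather than introducing a second config type, is what lets the LUT path reuse the existing registration table, the packed_weights_size/offset hooks, and linear_configs unchanged, as the kernel_selector.h changes below show.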
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h
index ffdd62f7a7..930b93bd46 100644
--- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h
@@ -164,6 +164,70 @@ void register_ukernel_config_universal(
   }
 }
 
+template <int weight_nbit>
+void register_ukernel_config_lut(
+    UKernelConfigRegistrationTable& table,
+    PackedWeightsFormat format,
+    cpuinfo_uarch uarch) {
+  if (!cpuinfo_initialize()) {
+    throw std::runtime_error("Failed to initialize cpuinfo!");
+  }
+  check_format(
+      format,
+      torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_lut,
+      weight_nbit
+  );
+  constexpr bool has_lut = true;
+  int preferred_alignment = 16;
+
+  #if defined(TORCHAO_ENABLE_ARM_NEON_DOT)
+  namespace kernel = torchao::kernels::cpu::aarch64::linear::
+      channelwise_8bit_activation_groupwise_lowbit_weight;
+
+  if (!cpuinfo_has_arm_neon_dot()) {
+    return;
+  }
+  if (format.has_weight_zeros) {
+    return;
+  }
+  constexpr bool has_weight_zeros = false;
+  if (format.nr == 8 && format.kr == 16 && format.sr == 2) {
+    log_registration(format, "lut: kernel_1x8x16_f32_neondot");
+    constexpr int n_step = 8;
+    constexpr int nr = 8;
+    constexpr int kr = 16;
+    constexpr int sr = 2;
+    constexpr int mr = 1;
+    constexpr int m_step = 1;
+    auto uk = UKernelConfig::make_with_lut(
+        preferred_alignment,
+        n_step,
+        nr,
+        kr,
+        sr,
+        weight_nbit,
+        format.has_weight_zeros,
+        format.has_bias,
+        &kernel::packed_weights_with_lut_size,
+        &kernel::packed_weights_with_lut_offset,
+        &kernel::pack_weights_with_lut<weight_nbit>,
+        /*linear_configs*/ {});
+    uk.linear_configs[0] = UKernelConfig::linear_config_type(
+        {m_step,
+         mr,
+         &kernel::packed_activations_size,
+         &kernel::packed_activations_offset,
+         &kernel::pack_activations,
+         &kernel::kernel_1x8x16_f32_neondot<
+             weight_nbit,
+             has_weight_zeros,
+             has_lut>});
+    table.register_ukernel_config(format, uarch, std::move(uk));
+    return;
+  }
+  #endif // TORCHAO_ENABLE_ARM_NEON_DOT
+}
+
 #if defined(TORCHAO_ENABLE_KLEIDI)
 template
 UKernelConfig::linear_config_type
@@ -285,6 +349,14 @@ void register_ukernel_config(
 #endif // TORCHAO_ENABLE_KLEIDI
       break;
     }
+    case torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_lut: {
+      // LUT kernels static assert on weight_nbit <= 4
+      // This is needed to avoid compilation error
+      if constexpr (weight_nbit <= 4) {
+        register_ukernel_config_lut<weight_nbit>(table, format, uarch);
+      }
+      break;
+    }
     default:
       throw std::runtime_error(
           "No registration available for packed_weights_type=" +
@@ -377,4 +449,24 @@ PackedWeightsFormat select_packed_weights_format(
   throw std::runtime_error("No packed_weights_format was selected");
 }
 
+template <int weight_nbit>
+PackedWeightsFormat select_packed_weights_with_lut_format(
+    std::optional<std::string> target,
+    bool has_weight_zeros,
+    bool has_bias) {
+  if (!target) {
+#if defined(TORCHAO_ENABLE_ARM_NEON_DOT)
+    return PackedWeightsFormat(
+        torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_lut,
+        weight_nbit,
+        has_weight_zeros,
+        has_bias,
+        /*nr*/ 8,
+        /*kr*/ 16,
+        /*sr*/ 2);
+#endif // defined(TORCHAO_ENABLE_ARM_NEON_DOT)
+  }
+  throw std::runtime_error("No packed_weights_format was selected");
+}
+
 } // namespace torchao::ops::linear_8bit_act_xbit_weight
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp
index 6929e6e4a4..96bfe17b5a 100644
--- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp
@@ -75,6 +75,61 @@ void pack_weights_operator(
   });
 }
 
+void pack_weights_with_lut_operator(
+    const UKernelConfig& uk,
+    // Outputs
+    void* packed_weights,
+    // Inputs
+    int n,
+    int k,
+    int group_size,
+    const int8_t* weight_qval_idxs,
+    int n_luts,
+    const int8_t* luts,
+    const float* weight_scales,
+    const int8_t* weight_zeros,
+    const float* bias) {
+  int n_step = uk.n_step;
+  int nc = std::min(n, n_step);
+  int num_nc_panels = (n + nc - 1) / nc;
+
+  torchao::parallel_1d(0, num_nc_panels, [&](int64_t idx) {
+    int nc_tile_idx = idx;
+    int n_idx = nc_tile_idx * nc;
+    int nc_tile_size = std::min(nc, n - n_idx);
+
+    auto packed_weights_offset = uk.packed_weights_offset(
+        n_idx,
+        k,
+        group_size,
+        uk.weight_nbit,
+        uk.has_weight_zeros,
+        uk.has_bias,
+        uk.nr,
+        uk.kr,
+        uk.sr);
+
+    int weight_qval_idxs_offset = n_idx * k;
+    int weight_scales_and_zeros_offset = (n_idx * k / group_size);
+    uk.pack_weights_with_lut(
+        (char*)packed_weights + packed_weights_offset,
+        /*n=*/nc_tile_size,
+        k,
+        group_size,
+        weight_qval_idxs + weight_qval_idxs_offset,
+        n_luts,
+        luts,
+        weight_scales + weight_scales_and_zeros_offset,
+        (weight_zeros == nullptr)
+            ? nullptr
+            : (weight_zeros + weight_scales_and_zeros_offset),
+        (bias == nullptr) ? nullptr : (bias + n_idx),
+        uk.nr,
+        uk.kr,
+        uk.sr);
+  });
+}
+
 LinearTilingParams LinearTilingParams::from_target_tiles_per_thread(
     int m,
     int m_step,
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h
index accc5be5a1..95e1640ad9 100644
--- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h
@@ -27,6 +27,21 @@ void pack_weights_operator(
     const int8_t* weight_zeros,
     const float* bias);
 
+void pack_weights_with_lut_operator(
+    const UKernelConfig& uk,
+    // Outputs
+    void* packed_weights,
+    // Inputs
+    int n,
+    int k,
+    int group_size,
+    const int8_t* weight_qval_idxs,
+    int n_luts,
+    const int8_t* luts,
+    const float* weight_scales,
+    const int8_t* weight_zeros,
+    const float* bias);
+
 // Linear functions
 struct LinearTilingParams {
   int mc{0};
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h
index 065a5b0319..8a72cbd00a 100644
--- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h
@@ -227,4 +227,139 @@ Tensor linear_cpu(
 }
 #endif // USE_ATEN
 
+#ifdef USE_ATEN
+template <int weight_nbit>
+Tensor pack_weights_with_lut_cpu(
+    const Tensor& weight_qval_idxs,
+    const Tensor& luts,
+    const Tensor& weight_scales,
+    int64_t group_size,
+    const std::optional<Tensor>& bias,
+    const std::optional<std::string>& target) {
+  bool has_bias = bias.has_value();
+  bool has_weight_zeros = false;
+
+  TORCHAO_CHECK(
+      weight_qval_idxs.dtype() == torch::kInt8, "weight_qval_idxs must be int8");
+  TORCHAO_CHECK(weight_qval_idxs.dim() == 2, "weight_qval_idxs must be 2D");
+
+  int n = weight_qval_idxs.size(0);
+  int k = weight_qval_idxs.size(1);
+
+  TORCHAO_CHECK(
+      weight_scales.dtype() == torch::kFloat32,
+      "weight_scales must be float32");
+  TORCHAO_CHECK(weight_scales.dim() == 1, "weight_scales must be 1D");
+  TORCHAO_CHECK(group_size >= 1, "group_size must be >= 1");
+  TORCHAO_CHECK(
+      weight_scales.size(0) == ((n * k) / group_size),
+      "expected 1 scale per group");
+
+  TORCHAO_CHECK(
+      luts.dtype() == torch::kInt8, "luts must be int8");
+  TORCHAO_CHECK(luts.dim() == 2, "luts must be 2D");
+  int n_luts = luts.size(0);
+  TORCHAO_CHECK(
+      n % n_luts == 0,
+      "the number of luts must divide n");
+  int lut_channel_group_size = n / n_luts;
+  TORCHAO_CHECK(
+      luts.size(1) == (1 << weight_nbit),
+      "luts must have 1 entry per quantization level");
+
+  const float* bias_ptr = nullptr;
+  if (has_bias) {
+    TORCHAO_CHECK(
+        bias.value().dtype() == torch::kFloat32, "bias must be float32");
+    TORCHAO_CHECK(bias.value().dim() == 1, "bias must be 1D");
+    TORCHAO_CHECK(bias.value().size(0) == n, "expected 1 bias per row");
+    bias_ptr = bias.value().const_data_ptr<float>();
+  }
+
+  TORCHAO_CHECK(
+      !target.has_value(),
+      "target is not currently supported in pack_weights_with_lut_cpu"
+  );
+
+  auto packed_weights_format = torchao::ops::linear_8bit_act_xbit_weight::select_packed_weights_with_lut_format<
+      weight_nbit>(target, has_weight_zeros, has_bias);
+  TORCHAO_CHECK(packed_weights_format.nr == 8, "nr must be 8");
+  TORCHAO_CHECK(
+      lut_channel_group_size % 8 == 0,
+      "the lut_channel_group_size must be a multiple of nr (8)");
+
+  auto packed_weights_header = packed_weights_format.to_packed_weights_header();
+  auto uk = torchao::ops::linear_8bit_act_xbit_weight::select_ukernel_config<
+      weight_nbit>(packed_weights_header);
+  auto packed_weight_data_size = torchao::ops::PackedWeightsHeader::size() +
+      uk.packed_weights_size(
+          n,
+          k,
+          group_size,
+          weight_nbit,
+          has_weight_zeros,
+          has_bias,
+          uk.nr,
+          uk.kr,
+          uk.sr);
+
+  Tensor packed_weights = torch::empty(
+      {static_cast<int64_t>(packed_weight_data_size)}, torch::kInt8);
+  packed_weights_header.write(packed_weights.mutable_data_ptr<int8_t>());
+
+  torchao::ops::linear_8bit_act_xbit_weight::pack_weights_with_lut_operator(
+      uk,
+      packed_weights.mutable_data_ptr<int8_t>() +
+          torchao::ops::PackedWeightsHeader::size(),
+      n,
+      k,
+      group_size,
+      weight_qval_idxs.const_data_ptr<int8_t>(),
+      n_luts,
+      luts.const_data_ptr<int8_t>(),
+      weight_scales.const_data_ptr<float>(),
+      /*weight_zeros*/nullptr,
+      bias_ptr);
+
+  return packed_weights;
+}
+#endif // USE_ATEN
+
+#ifdef USE_ATEN
+template <int weight_nbit>
+Tensor pack_weights_with_lut_meta(
+    const Tensor& weight_qval_idxs,
+    const Tensor& luts,
+    const Tensor& weight_scales,
+    int64_t group_size,
+    const std::optional<Tensor>& bias,
+    const std::optional<std::string>& target) {
+  bool has_bias = bias.has_value();
+  bool has_weight_zeros = false;
+  int n = weight_qval_idxs.size(0);
+  int k = weight_qval_idxs.size(1);
+  auto packed_weights_format = torchao::ops::linear_8bit_act_xbit_weight::select_packed_weights_with_lut_format<
+      weight_nbit>(target, has_weight_zeros, has_bias);
+
+  auto uk = torchao::ops::linear_8bit_act_xbit_weight::select_ukernel_config<
+      weight_nbit>(packed_weights_format);
+
+  auto packed_weight_data_size = torchao::ops::PackedWeightsHeader::size() +
+      uk.packed_weights_size(
+          n,
+          k,
+          group_size,
+          weight_nbit,
+          has_weight_zeros,
+          has_bias,
+          uk.nr,
+          uk.kr,
+          uk.sr);
+
+  auto options =
+      torch::TensorOptions().device(c10::DeviceType::Meta).dtype(torch::kInt8);
+  return torch::empty({static_cast<int64_t>(packed_weight_data_size)}, options);
+}
+#endif // USE_ATEN
+
 } // namespace
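For reference, a short, self-contained sketch (not part of the patch) of the tensors the new packing op consumes, mirroring the "equivalent LUT for affine quantization" construction used in the operator test further down. weight_qval_idxs holds unsigned indices into an int8 LUT, and the LUT maps each index back to a signed qval; the concrete values, the single shared LUT (n_luts == 1), and weight_nbit = 4 are illustrative only.

#include <cstdint>
#include <vector>

int main() {
  constexpr int weight_nbit = 4;
  constexpr int lut_size = 1 << weight_nbit;      // one entry per quantization level
  constexpr int offset = 1 << (weight_nbit - 1);  // shifts signed qvals into [0, lut_size)

  // Example signed affine qvals for a few weights (n*k values in practice).
  std::vector<int8_t> weight_qvals = {-8, -3, 0, 5, 7};

  // LUT that reproduces plain affine quantization: lut[idx] == idx - offset.
  std::vector<int8_t> lut(lut_size);
  for (int i = 0; i < lut_size; i++) {
    lut[i] = static_cast<int8_t>(i - offset);
  }

  // Indices passed as weight_qval_idxs to the packing op.
  std::vector<int8_t> weight_qval_idxs(weight_qvals.size());
  for (size_t i = 0; i < weight_qvals.size(); i++) {
    weight_qval_idxs[i] = static_cast<int8_t>(weight_qvals[i] + offset);
  }

  // weight_qval_idxs ([n, k] int8), lut ([n_luts, lut_size] int8) and float32
  // weight_scales ([n*k/group_size]) match the shapes checked by
  // pack_weights_with_lut_cpu above, with n_luts == 1 here.
  return 0;
}

A non-trivial LUT (for example, one derived from a learned codebook) would replace the identity mapping while keeping the same index-tensor layout and the same registration path shown below.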
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp
index a96d322cd0..7e5799b5fd 100644
--- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp
@@ -29,13 +29,26 @@
       &linear_out_cpu<weight_nbit>)
 
 #define DEFINE_META_IMPL(weight_nbit) \
-  m.impl(                             \
-      "_pack_8bit_act_" #weight_nbit "bit0zp_weight", \
-      &pack_weights_meta<weight_nbit>); \
   m.impl(                             \
       "_pack_8bit_act_" #weight_nbit "bit_weight", \
       &pack_weights_meta<weight_nbit>)
+
+#define DEFINE_LUT_PACK_OP(weight_nbit) \
+  m.def(                               \
+      "_pack_8bit_act_" #weight_nbit   \
+      "bit_weight_with_lut(Tensor weight_qval_idxs, Tensor luts, Tensor weight_scales, int group_size, Tensor? bias, str? target) -> Tensor")
+
+#define DEFINE_LUT_PACK_CPU_IMPL(weight_nbit) \
+  m.impl(                                     \
+      "_pack_8bit_act_" #weight_nbit "bit_weight_with_lut", \
+      &pack_weights_with_lut_cpu<weight_nbit>)
+
+#define DEFINE_LUT_PACK_META_IMPL(weight_nbit) \
+  m.impl(                                      \
+      "_pack_8bit_act_" #weight_nbit "bit_weight_with_lut", \
+      &pack_weights_with_lut_meta<weight_nbit>)
+
+
 TORCH_LIBRARY_FRAGMENT(torchao, m) {
   DEFINE_OP(1);
   DEFINE_OP(2);
@@ -45,6 +58,11 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) {
   DEFINE_OP(6);
   DEFINE_OP(7);
   DEFINE_OP(8);
+
+  DEFINE_LUT_PACK_OP(1);
+  DEFINE_LUT_PACK_OP(2);
+  DEFINE_LUT_PACK_OP(3);
+  DEFINE_LUT_PACK_OP(4);
 }
 
 TORCH_LIBRARY_IMPL(torchao, CPU, m) {
@@ -56,6 +74,11 @@ TORCH_LIBRARY_IMPL(torchao, CPU, m) {
   DEFINE_CPU_IMPL(6);
   DEFINE_CPU_IMPL(7);
   DEFINE_CPU_IMPL(8);
+
+  DEFINE_LUT_PACK_CPU_IMPL(1);
+  DEFINE_LUT_PACK_CPU_IMPL(2);
+  DEFINE_LUT_PACK_CPU_IMPL(3);
+  DEFINE_LUT_PACK_CPU_IMPL(4);
 }
 
 TORCH_LIBRARY_IMPL(torchao, Meta, m) {
@@ -67,4 +90,9 @@ TORCH_LIBRARY_IMPL(torchao, Meta, m) {
   DEFINE_META_IMPL(6);
   DEFINE_META_IMPL(7);
   DEFINE_META_IMPL(8);
+
+  DEFINE_LUT_PACK_META_IMPL(1);
+  DEFINE_LUT_PACK_META_IMPL(2);
+  DEFINE_LUT_PACK_META_IMPL(3);
+  DEFINE_LUT_PACK_META_IMPL(4);
 }
diff --git a/torchao/experimental/ops/packed_weights_header.h b/torchao/experimental/ops/packed_weights_header.h
index 213ec34f7f..0869c12ef9 100644
--- a/torchao/experimental/ops/packed_weights_header.h
+++ b/torchao/experimental/ops/packed_weights_header.h
@@ -16,7 +16,8 @@ enum class PackedWeightsType : uint32_t {
   unknown = 0,
   linear_8bit_act_xbit_weight_universal = 1,
   embedding_xbit_universal = 2,
-  kleidi_ai = 3
+  kleidi_ai = 3,
+  linear_8bit_act_xbit_weight_lut = 4,
 };
 
 class PackedWeightsHeader {
diff --git a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp b/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp
index 1d4127a43e..16c38aa8d3 100644
--- a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp
+++ b/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp
@@ -25,7 +25,7 @@ const float kTolKleidiAI = 5.0e-2;
 
 using namespace torchao::ops::linear_8bit_act_xbit_weight;
 
-template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
+template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp, bool has_lut>
 UKernelConfig get_ukernel_config() {
   namespace kernel = torchao::kernels::cpu::aarch64::linear::
       channelwise_8bit_activation_groupwise_lowbit_weight;
@@ -37,7 +37,6 @@ UKernelConfig get_ukernel_config() {
   constexpr int sr = 2;
   constexpr int mr = 1;
   int m_step = 1;
-  constexpr bool has_lut = false;
 
   auto uk = UKernelConfig::make(
       preferred_alignment,
@@ -62,6 +61,13 @@ UKernelConfig get_ukernel_config() {
       &kernel::
           kernel_1x8x16_f32_neondot};
 
+  if constexpr (has_lut) {
+    uk.packed_weights_size = &kernel::packed_weights_with_lut_size;
+    uk.packed_weights_offset = &kernel::packed_weights_with_lut_offset;
+    uk.pack_weights = nullptr;
+    uk.pack_weights_with_lut = &kernel::pack_weights_with_lut<weight_nbit>;
+  }
+
   return uk;
 }
 
@@ -70,7 +76,8 @@ template <
     bool has_weight_zeros,
     bool has_bias,
     bool has_clamp,
-    bool has_kleidi = false>
+    bool has_kleidi = false,
+    bool has_lut = false>
 void test_linear_8bit_act_xbit_weight(
     int m,
     int n,
@@ -85,7 +92,8 @@ void test_linear_8bit_act_xbit_weight(
         weight_nbit,
         has_weight_zeros,
         has_bias,
-        has_clamp>();
+        has_clamp,
+        has_lut>();
   }
 
   auto test_case = torchao::
@@ -132,6 +140,31 @@ void test_linear_8bit_act_xbit_weight(
     bias_ptr = test_case.bias.data();
   }
 
+  if constexpr (has_lut) {
+    // Define equivalent LUT for affine quantization
+    constexpr int lut_size = (1 << weight_nbit);
+    std::vector<int8_t> weight_qval_idxs(test_case.weight_qvals.size());
+    std::vector<int8_t> lut(lut_size, 0);
+    constexpr int offset = (1 << (weight_nbit - 1));
+    for (int i = 0; i < test_case.weight_qvals.size(); i++) {
+      weight_qval_idxs[i] = test_case.weight_qvals[i] + offset;
+    }
+    for (int i = 0; i < lut_size; i++) {
+      lut[i] = i - offset;
+    }
+    pack_weights_with_lut_operator(
+        ukernel_config,
+        packed_weights.get(),
+        n,
+        k,
+        group_size,
+        weight_qval_idxs.data(),
+        /*n_luts*/ 1,
+        lut.data(),
+        test_case.weight_scales.data(),
+        weight_zeros_ptr,
+        bias_ptr);
+  } else {
   pack_weights_operator(
       ukernel_config,
       packed_weights.get(),
@@ -142,6 +175,7 @@ void test_linear_8bit_act_xbit_weight(
       test_case.weight_scales.data(),
       weight_zeros_ptr,
       bias_ptr);
+  }
 
   linear_operator(
       ukernel_config,
@@ -313,6 +347,74 @@ UKernelConfig get_ukernel_config_kleidi() {
 
 #endif // TORCHAO_ENABLE_KLEIDI
 
+TEST(test_linear_8bit_act_xbit_weight_lut, Standard) {
+  constexpr bool has_kleidi = false;
+  constexpr bool has_lut = true;
+  constexpr int weight_nbit = 3;
+  constexpr bool has_weight_zeros = false;
+  constexpr bool has_bias = false;
+  constexpr bool has_clamp = false;
+  test_linear_8bit_act_xbit_weight<
+      weight_nbit,
+      has_weight_zeros,
+      has_bias,
+      has_clamp,
+      has_kleidi,
+      has_lut>(
+      /*m=*/13, /*n=*/8 * 10 + 3, /*k=*/16 * 3, /*group_size=*/16);
+}
+
+TEST(test_linear_8bit_act_xbit_weight_lut, HasWeightZeros) {
+  constexpr bool has_kleidi = false;
+  constexpr bool has_lut = true;
+  constexpr int weight_nbit = 3;
+  constexpr bool has_weight_zeros = true;
+  constexpr bool has_bias = false;
+  constexpr bool has_clamp = false;
+  test_linear_8bit_act_xbit_weight<
+      weight_nbit,
+      has_weight_zeros,
+      has_bias,
+      has_clamp,
+      has_kleidi,
+      has_lut>(
+      /*m=*/13, /*n=*/8 * 10 + 3, /*k=*/16 * 3, /*group_size=*/16);
+}
+
+TEST(test_linear_8bit_act_xbit_weight_lut, HasBias) {
+  constexpr bool has_kleidi = false;
+  constexpr bool has_lut = true;
+  constexpr int weight_nbit = 3;
+  constexpr bool has_weight_zeros = false;
+  constexpr bool has_bias = true;
+  constexpr bool has_clamp = false;
+  test_linear_8bit_act_xbit_weight<
+      weight_nbit,
+      has_weight_zeros,
+      has_bias,
+      has_clamp,
+      has_kleidi,
+      has_lut>(
+      /*m=*/13, /*n=*/8 * 10 + 3, /*k=*/16 * 3, /*group_size=*/16);
+}
+
+TEST(test_linear_8bit_act_xbit_weight_lut, HasClamp) {
+  constexpr bool has_kleidi = false;
+  constexpr bool has_lut = true;
+  constexpr int weight_nbit = 3;
+  constexpr bool has_weight_zeros = false;
+  constexpr bool has_bias = false;
+  constexpr bool has_clamp = true;
+  test_linear_8bit_act_xbit_weight<
+      weight_nbit,
+      has_weight_zeros,
+      has_bias,
+      has_clamp,
+      has_kleidi,
+      has_lut>(
+      /*m=*/13, /*n=*/8 * 10 + 3, /*k=*/16 * 3, /*group_size=*/16);
+}
+
 TEST(test_linear_8bit_act_xbit_weight, Standard) {
   test_linear_8bit_act_xbit_weight<
       4 /*weight_nbit*/,