Add int8 LUT CPU ops #2026
Changes from all commits
@@ -87,6 +87,23 @@ struct UKernelConfig {
      int kr,
      int sr);

  // Pack weights into packed_weights buffer with int8-valued LUT
  using pack_weights_with_lut_fn_type = void (*)(
      void* packed_weights,
      int n,
      int k,
      int group_size,
      const int8_t* weight_qval_idxs,
      int n_luts,
      const int8_t* luts,
      const float* weight_scales,
      const int8_t* weight_zeros,
      const float* bias,
      int nr,
      int kr,
      int sr);

  // Run matmul kernel
  using kernel_fn_type = void (*)(
      float* output,
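For orientation, here is a plain reference of what the LUT scheme implied by this signature computes. This is a hedged sketch, not code from the PR; the per-row LUT layout and the helper name `dequantize_lut_reference` are assumptions.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical reference: each weight stores an index into an int8 LUT;
// groups of `group_size` weights along k share a float scale and an
// optional int8 zero point. lut_size would be 2^weight_nbit.
std::vector<float> dequantize_lut_reference(
    const int8_t* weight_qval_idxs, // n * k indices into `luts`
    const int8_t* luts,             // n_luts tables of lut_size entries
    const float* weight_scales,     // one scale per group
    const int8_t* weight_zeros,     // one zero per group, may be nullptr
    int n, int k, int group_size, int lut_size) {
  std::vector<float> w(static_cast<size_t>(n) * k);
  for (int row = 0; row < n; row++) {
    // Assumption: one LUT per output channel; the packed layout may differ.
    const int8_t* lut = luts + row * lut_size;
    for (int col = 0; col < k; col++) {
      int idx = row * k + col;
      int group = idx / group_size;
      int8_t q = lut[weight_qval_idxs[idx]];
      int8_t z = (weight_zeros != nullptr) ? weight_zeros[group] : int8_t(0);
      w[idx] = static_cast<float>(q - z) * weight_scales[group];
    }
  }
  return w;
}
```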
@@ -126,6 +143,7 @@ struct UKernelConfig {
  packed_weights_size_fn_type packed_weights_size{nullptr};
  packed_weights_offset_fn_type packed_weights_offset{nullptr};
  pack_weights_fn_type pack_weights{nullptr};
  pack_weights_with_lut_fn_type pack_weights_with_lut{nullptr};
Review comment (on lines 145 to +146): should we have a different uKernel config for LUT? or even worse a
Reply: I prefer separate ones for now
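To make the alternative in that exchange concrete, a dedicated LUT-only config could look roughly like the sketch below. This is a hypothetical illustration of the design question, not code from the PR; it reuses the nested types of `UKernelConfig` and omits everything else.

```cpp
// Hypothetical alternative raised by the review question: a separate config
// type carrying only the LUT packing entry points, instead of adding
// pack_weights_with_lut next to pack_weights on UKernelConfig.
struct UKernelConfigLut {
  size_t preferred_alignment{0};
  int n_step{0};
  int nr{0};
  int kr{0};
  int sr{0};
  int weight_nbit{0};
  bool has_weight_zeros{false};
  bool has_bias{false};
  UKernelConfig::packed_weights_size_fn_type packed_weights_with_lut_size{nullptr};
  UKernelConfig::packed_weights_offset_fn_type packed_weights_with_lut_offset{nullptr};
  UKernelConfig::pack_weights_with_lut_fn_type pack_weights_with_lut{nullptr};
  std::array<UKernelConfig::linear_config_type, UKernelConfig::kMaxLinearConfigs>
      linear_configs;
};
```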
  // linear_configs must be sorted in ascending m_step
  std::array<linear_config_type, kMaxLinearConfigs> linear_configs;
@@ -144,6 +162,20 @@ struct UKernelConfig {
      pack_weights_fn_type pack_weights,
      std::array<linear_config_type, kMaxLinearConfigs> linear_configs);

  static UKernelConfig make_with_lut(
      size_t preferred_alignment,
      int n_step,
      int nr,
      int kr,
      int sr,
      int weight_nbit,
      bool has_weight_zeros,
      bool has_bias,
      packed_weights_size_fn_type packed_weights_with_lut_size,
Review comment: function overload?
      packed_weights_offset_fn_type packed_weights_with_lut_offset,
      pack_weights_with_lut_fn_type pack_weights_with_lut,
      std::array<linear_config_type, kMaxLinearConfigs> linear_configs);

  inline void validate() const {
    TORCHAO_CHECK(preferred_alignment >= 1, "preferred_alignment must be >= 1");
    TORCHAO_CHECK(n_step >= 1, "n_step must be >= 1");
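The "function overload?" comment presumably asks whether `make_with_lut` could instead be an overload of `make`. Because `pack_weights_with_lut_fn_type` and `pack_weights_fn_type` are distinct function-pointer types, that would be viable; a hypothetical sketch of the alternative declaration (not part of the PR):

```cpp
// Hypothetical overload of make(): identical parameter list except that the
// packing entry point is a pack_weights_with_lut_fn_type, so overload
// resolution can pick it when a LUT packing function pointer is passed.
static UKernelConfig make(
    size_t preferred_alignment,
    int n_step,
    int nr,
    int kr,
    int sr,
    int weight_nbit,
    bool has_weight_zeros,
    bool has_bias,
    packed_weights_size_fn_type packed_weights_with_lut_size,
    packed_weights_offset_fn_type packed_weights_with_lut_offset,
    pack_weights_with_lut_fn_type pack_weights_with_lut,
    std::array<linear_config_type, kMaxLinearConfigs> linear_configs);
```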
@@ -155,7 +187,7 @@ struct UKernelConfig {
        packed_weights_size != nullptr, "packed_weights_size must be set");
    TORCHAO_CHECK(
        packed_weights_offset != nullptr, "packed_weights_offset must be set");
    TORCHAO_CHECK(pack_weights != nullptr, "pack_weights must be set");
    TORCHAO_CHECK(
        pack_weights != nullptr || pack_weights_with_lut != nullptr,
        "pack_weights or pack_weights_with_lut must be set");

    bool linear_configs_set = true; // first linear config must be set
    for (int i = 0; i < linear_configs.size(); i++) {
@@ -232,6 +264,36 @@ inline UKernelConfig UKernelConfig::make(
      packed_weights_size,
      packed_weights_offset,
      pack_weights,
      /*pack_weights_with_lut*/ nullptr,
      std::move(linear_configs)};
}

inline UKernelConfig UKernelConfig::make_with_lut(
    size_t preferred_alignment,
    int n_step,
    int nr,
    int kr,
    int sr,
    int weight_nbit,
    bool has_weight_zeros,
    bool has_bias,
    packed_weights_size_fn_type packed_weights_with_lut_size,
    packed_weights_offset_fn_type packed_weights_with_lut_offset,
    pack_weights_with_lut_fn_type pack_weights_with_lut,
    std::array<linear_config_type, kMaxLinearConfigs> linear_configs) {
  return UKernelConfig{
      preferred_alignment,
      n_step,
      nr,
      kr,
      sr,
      weight_nbit,
      has_weight_zeros,
      has_bias,
      packed_weights_with_lut_size,
      packed_weights_with_lut_offset,
      /*pack_weights*/ nullptr,
      /*pack_weights_with_lut*/ pack_weights_with_lut,
      std::move(linear_configs)};
}
Second file:
@@ -164,6 +164,70 @@ void register_ukernel_config_universal(
  }
}

template <int weight_nbit>
void register_ukernel_config_lut(
    UKernelConfigRegistrationTable& table,
    PackedWeightsFormat format,
    cpuinfo_uarch uarch) {
  if (!cpuinfo_initialize()) {
    throw std::runtime_error("Failed to initialize cpuinfo!");
  }
  check_format(
      format,
      torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_lut,
      weight_nbit);
  constexpr bool has_lut = true;
  int preferred_alignment = 16;
Review comment: nit: push this down closer to the kernel?
#if defined(TORCHAO_ENABLE_ARM_NEON_DOT)
  namespace kernel = torchao::kernels::cpu::aarch64::linear::
      channelwise_8bit_activation_groupwise_lowbit_weight;

  if (!cpuinfo_has_arm_neon_dot()) {
    return;
  }
  if (format.has_weight_zeros) {
    return;
  }
  constexpr bool has_weight_zeros = false;
  if (format.nr == 8 && format.kr == 16 && format.sr == 2) {
    log_registration(format, "lut: kernel_1x8x16_f32_neondot");
    constexpr int n_step = 8;
    constexpr int nr = 8;
    constexpr int kr = 16;
    constexpr int sr = 2;
    constexpr int mr = 1;
    constexpr int m_step = 1;
    auto uk = UKernelConfig::make_with_lut(
        preferred_alignment,
        n_step,
        nr,
        kr,
        sr,
        weight_nbit,
        format.has_weight_zeros,
        format.has_bias,
        &kernel::packed_weights_with_lut_size,
        &kernel::packed_weights_with_lut_offset,
        &kernel::pack_weights_with_lut<weight_nbit, nr, kr, sr>,
        /*linear_configs*/ {});
    uk.linear_configs[0] = UKernelConfig::linear_config_type(
        {m_step,
         mr,
         &kernel::packed_activations_size,
         &kernel::packed_activations_offset,
         &kernel::pack_activations<mr, kr, sr>,
         &kernel::kernel_1x8x16_f32_neondot<
             weight_nbit,
             has_weight_zeros,
             has_lut>});
    table.register_ukernel_config(format, uarch, std::move(uk));
    return;
  }
#endif // TORCHAO_ENABLE_ARM_NEON_DOT
}

#if defined(TORCHAO_ENABLE_KLEIDI)
template <typename kernel_struct>
UKernelConfig::linear_config_type
@@ -285,6 +349,14 @@ void register_ukernel_config(
#endif // TORCHAO_ENABLE_KLEIDI
      break;
    }
    case torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_lut: {
      // LUT kernels static_assert on weight_nbit <= 4; the constexpr guard
      // avoids instantiating them (and a compilation error) for larger nbit.
      if constexpr (weight_nbit <= 4) {
        register_ukernel_config_lut<weight_nbit>(table, format, uarch);
      }
      break;
    }
    default:
      throw std::runtime_error(
          "No registration available for packed_weights_type=" +
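A minimal standalone illustration of why a plain `if` would not work there; the names below are hypothetical stand-ins for the LUT kernel templates and their static_assert.

```cpp
#include <stdexcept>

template <int weight_nbit>
void register_lut_kernels() {
  // Stand-in for the static_assert inside the LUT kernel templates.
  static_assert(weight_nbit <= 4, "LUT kernels support at most 4-bit weights");
}

template <int weight_nbit>
void register_for_format(bool is_lut_format) {
  if (is_lut_format) {
    // A plain `if` would still instantiate register_lut_kernels<weight_nbit>
    // for every weight_nbit this function is compiled with, so e.g.
    // weight_nbit == 5 would fail to build even though the branch is never
    // taken at runtime. `if constexpr` discards the branch at compile time.
    if constexpr (weight_nbit <= 4) {
      register_lut_kernels<weight_nbit>();
    } else {
      throw std::runtime_error("LUT format requires weight_nbit <= 4");
    }
  }
}

// Builds even when instantiated for 5-bit weights:
template void register_for_format<5>(bool);
```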
@@ -377,4 +449,24 @@ PackedWeightsFormat select_packed_weights_format(
  throw std::runtime_error("No packed_weights_format was selected");
}

template <int weight_nbit>
PackedWeightsFormat select_packed_weights_with_lut_format(
    std::optional<std::string> target,
    bool has_weight_zeros,
    bool has_bias) {
  if (!target) {
#if defined(TORCHAO_ENABLE_ARM_NEON_DOT)
    return PackedWeightsFormat(
        torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_lut,
        weight_nbit,
        has_weight_zeros,
        has_bias,
        /*nr*/ 8,
        /*kr*/ 16,
        /*sr*/ 2);
#endif // defined(TORCHAO_ENABLE_ARM_NEON_DOT)
  }
  throw std::runtime_error("No packed_weights_format was selected");
}

} // namespace torchao::ops::linear_8bit_act_xbit_weight
Review comment: Nit: the suffix _ has a different meaning too (i.e. class private variables); what about capitalizing it?
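Taken together, the new pieces would plausibly be exercised like this at a call site. This is a hedged sketch: the function name `example_setup_lut_linear`, the local variables, and the surrounding flow are assumptions; only the two torchao functions it references come from this diff.

```cpp
#include <optional>
#include <string>

// Hypothetical call-site flow for the LUT path (assumes the NEON-dot build).
using namespace torchao::ops::linear_8bit_act_xbit_weight;

void example_setup_lut_linear() {
  constexpr int weight_nbit = 4;   // LUT kernels support <= 4-bit weights
  const bool has_weight_zeros = false;
  const bool has_bias = true;

  // 1. Pick the packed-weights format for the int8-LUT scheme.
  PackedWeightsFormat format =
      select_packed_weights_with_lut_format<weight_nbit>(
          /*target*/ std::nullopt, has_weight_zeros, has_bias);

  // 2. A UKernelConfig for this format would then be registered via
  //    register_ukernel_config_lut / UKernelConfig::make_with_lut, and its
  //    pack_weights_with_lut entry point used to pack the quantized weights
  //    before running the linear kernel. (Lookup API not shown in this diff.)
  (void)format;
}
```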