Add int8 LUT CPU ops #2026


Merged 1 commit on Apr 28, 2025
@@ -126,7 +126,7 @@ void pack_weights(
bias);
}

template <int weight_nbit, int nr, int kr, int sr>
template <int weight_nbit, int nr_, int kr_, int sr_>
Contributor:
Nit: the suffix _ has a different meaning too, i.e. private class member variables; what about capitalized?
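A minimal sketch of the capitalized alternative the reviewer suggests (hypothetical naming, not what was merged):

template <int weight_nbit, int NR, int KR, int SR> // capitalized template parameters avoid clashing with the trailing-underscore member convention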

void pack_weights_with_lut(
// Output
void* packed_weights,
@@ -141,10 +141,16 @@ void pack_weights_with_lut(
// weight_zeros not packed if nullptr
const int8_t* weight_zeros,
// bias not packed if nullptr
const float* bias) {
const float* bias,
int nr,
int kr,
int sr) {
(void)nr; // unused: packing layout is fixed by the template parameters nr_/kr_/sr_
(void)kr; // unused
(void)sr; // unused
torchao::kernels::cpu::aarch64::linear::
channelwise_8bit_activation_groupwise_lowbit_weight::weight_packing::
pack_weights_with_lut<weight_nbit, nr, kr, sr>(
pack_weights_with_lut<weight_nbit, nr_, kr_, sr_>(
packed_weights,
n,
k,
@@ -478,7 +478,10 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_lut(
lut.data(),
test_case.weight_scales.data(),
has_weight_zeros ? test_case.weight_zeros.data() : nullptr,
has_bias ? test_case.bias.data() : nullptr);
has_bias ? test_case.bias.data() : nullptr,
nr,
kr,
sr);

std::vector<float> output(m * n);
kernel_1x8x16_f32_neondot<weight_nbit, has_weight_zeros, /*has_lut*/ true>(
@@ -87,6 +87,23 @@ struct UKernelConfig {
int kr,
int sr);

// Pack weights into packed_weights buffer with int8-valued LUT
using pack_weights_with_lut_fn_type = void (*)(
void* packed_weights,
int n,
int k,
int group_size,
const int8_t* weight_qval_idxs,
int n_luts,
const int8_t* luts,
const float* weight_scales,
const int8_t* weight_zeros,
const float* bias,
int nr,
int kr,
int sr
);

// Run matmul kernel
using kernel_fn_type = void (*)(
float* output,
@@ -126,6 +143,7 @@ struct UKernelConfig {
packed_weights_size_fn_type packed_weights_size{nullptr};
packed_weights_offset_fn_type packed_weights_offset{nullptr};
pack_weights_fn_type pack_weights{nullptr};
pack_weights_with_lut_fn_type pack_weights_with_lut{nullptr};
Contributor (on lines 145 to +146):
should we have a different uKernel config for LUT? or even worse a union :)

Suggested change
pack_weights_fn_type pack_weights{nullptr};
pack_weights_with_lut_fn_type pack_weights_with_lut{nullptr};
union pack_weights_fn_t {
pack_weights_fn_type tiled{nullptr};
pack_weights_with_lut_fn_type tiled_with_lut{nullptr};
}

Contributor Author:
I prefer separate ones for now
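A side note on the union idea: C++ allows at most one non-static data member of a union to have a default member initializer, so the suggestion above would need a small adjustment. A minimal sketch of a conforming variant (hypothetical, not part of this PR):

// Hypothetical sketch only: at most one union member may carry a
// default member initializer, and the caller must track which
// member is active.
union pack_weights_fn_t {
  pack_weights_fn_type tiled{nullptr};
  pack_weights_with_lut_fn_type tiled_with_lut;
};

Keeping two separate, independently nullable pointers sidesteps that bookkeeping, at the cost of one extra field.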


// linear_configs must be sorted in ascending m_step
std::array<linear_config_type, kMaxLinearConfigs> linear_configs;
@@ -144,6 +162,20 @@ struct UKernelConfig {
pack_weights_fn_type pack_weights,
std::array<linear_config_type, kMaxLinearConfigs> linear_configs);

static UKernelConfig make_with_lut(
size_t preferred_alignment,
int n_step,
int nr,
int kr,
int sr,
int weight_nbit,
bool has_weight_zeros,
bool has_bias,
packed_weights_size_fn_type packed_weights_with_lut_size,
Contributor:
function overload?

packed_weights_offset_fn_type packed_weights_with_lut_offset,
pack_weights_with_lut_fn_type pack_weights_with_lut,
std::array<linear_config_type, kMaxLinearConfigs> linear_configs);

inline void validate() const {
TORCHAO_CHECK(preferred_alignment >= 1, "preferred_alignment must be >= 1");
TORCHAO_CHECK(n_step >= 1, "n_step must be >= 1");
@@ -155,7 +187,7 @@ struct UKernelConfig {
packed_weights_size != nullptr, "packed_weights_size must be set");
TORCHAO_CHECK(
packed_weights_offset != nullptr, "packed_weights_offset must be set");
TORCHAO_CHECK(pack_weights != nullptr, "pack_weights must be set");
TORCHAO_CHECK(pack_weights != nullptr || pack_weights_with_lut != nullptr, "pack_weights or pack_weights_with_lut must be set");

bool linear_configs_set = true; // first linear config must be set
for (int i = 0; i < linear_configs.size(); i++) {
@@ -232,6 +264,36 @@ inline UKernelConfig UKernelConfig::make(
packed_weights_size,
packed_weights_offset,
pack_weights,
/*pack_weights_with_lut*/nullptr,
std::move(linear_configs)};
}

inline UKernelConfig UKernelConfig::make_with_lut(
size_t preferred_alignment,
int n_step,
int nr,
int kr,
int sr,
int weight_nbit,
bool has_weight_zeros,
bool has_bias,
packed_weights_size_fn_type packed_weights_with_lut_size,
packed_weights_offset_fn_type packed_weights_with_lut_offset,
pack_weights_with_lut_fn_type pack_weights_with_lut,
std::array<linear_config_type, kMaxLinearConfigs> linear_configs) {
return UKernelConfig{
preferred_alignment,
n_step,
nr,
kr,
sr,
weight_nbit,
has_weight_zeros,
has_bias,
packed_weights_with_lut_size,
packed_weights_with_lut_offset,
/*pack_weights*/nullptr,
/*pack_weights_with_lut*/pack_weights_with_lut,
std::move(linear_configs)};
}

@@ -164,6 +164,70 @@ void register_ukernel_config_universal(
}
}

template <int weight_nbit>
void register_ukernel_config_lut(
UKernelConfigRegistrationTable& table,
PackedWeightsFormat format,
cpuinfo_uarch uarch) {
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
check_format(
format,
torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_lut,
weight_nbit
);
constexpr bool has_lut = true;
int preferred_alignment = 16;
Contributor:
nit: push this down closer to the kernel?


#if defined(TORCHAO_ENABLE_ARM_NEON_DOT)
namespace kernel = torchao::kernels::cpu::aarch64::linear::
channelwise_8bit_activation_groupwise_lowbit_weight;

if (!cpuinfo_has_arm_neon_dot()) {
return;
}
if (format.has_weight_zeros) {
return;
}
constexpr bool has_weight_zeros = false;
if (format.nr == 8 && format.kr == 16 && format.sr == 2) {
log_registration(format, "lut: kernel_1x8x16_f32_neondot");
constexpr int n_step = 8;
constexpr int nr = 8;
constexpr int kr = 16;
constexpr int sr = 2;
constexpr int mr = 1;
constexpr int m_step = 1;
auto uk = UKernelConfig::make_with_lut(
preferred_alignment,
n_step,
nr,
kr,
sr,
weight_nbit,
format.has_weight_zeros,
format.has_bias,
&kernel::packed_weights_with_lut_size,
&kernel::packed_weights_with_lut_offset,
&kernel::pack_weights_with_lut<weight_nbit, nr, kr, sr>,
/*linear_configs*/ {});
uk.linear_configs[0] = UKernelConfig::linear_config_type(
{m_step,
mr,
&kernel::packed_activations_size,
&kernel::packed_activations_offset,
&kernel::pack_activations<mr, kr, sr>,
&kernel::kernel_1x8x16_f32_neondot<
weight_nbit,
has_weight_zeros,
has_lut>});
table.register_ukernel_config(format, uarch, std::move(uk));
return;
}
#endif // TORCHAO_ENABLE_ARM_NEON_DOT
}

#if defined(TORCHAO_ENABLE_KLEIDI)
template <typename kernel_struct>
UKernelConfig::linear_config_type
@@ -285,6 +349,14 @@ void register_ukernel_config(
#endif // TORCHAO_ENABLE_KLEIDI
break;
}
case torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_lut: {
// LUT kernels static_assert that weight_nbit <= 4; guard with if constexpr
// so those templates are never instantiated (and the assert never fires)
// for larger bit widths
if constexpr (weight_nbit <= 4) {
register_ukernel_config_lut<weight_nbit>(table, format, uarch);
}
break;
}
default:
throw std::runtime_error(
"No registration available for packed_weights_type=" +
@@ -377,4 +449,24 @@ PackedWeightsFormat select_packed_weights_format(
throw std::runtime_error("No packed_weights_format was selected");
}

template <int weight_nbit>
PackedWeightsFormat select_packed_weights_with_lut_format(
std::optional<std::string> target,
bool has_weight_zeros,
bool has_bias) {
if (!target) {
#if defined(TORCHAO_ENABLE_ARM_NEON_DOT)
return PackedWeightsFormat(
torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_lut,
weight_nbit,
has_weight_zeros,
has_bias,
/*nr*/ 8,
/*kr*/ 16,
/*sr*/ 2);
#endif // defined(TORCHAO_ENABLE_ARM_NEON_DOT)
}
throw std::runtime_error("No packed_weights_format was selected");
}

} // namespace torchao::ops::linear_8bit_act_xbit_weight
@@ -75,6 +75,61 @@ void pack_weights_operator(
});
}

void pack_weights_with_lut_operator(
const UKernelConfig& uk,
// Outputs
void* packed_weights,
// Inputs
int n,
int k,
int group_size,
const int8_t* weight_qval_idxs,
int n_luts,
const int8_t* luts,
const float* weight_scales,
const int8_t* weight_zeros,
const float* bias) {
int n_step = uk.n_step;
int nc = std::min(n, n_step);
int num_nc_panels = (n + nc - 1) / nc;

torchao::parallel_1d(0, num_nc_panels, [&](int64_t idx) {
int nc_tile_idx = idx;
int n_idx = nc_tile_idx * nc;
int nc_tile_size = std::min(nc, n - n_idx);

auto packed_weights_offset = uk.packed_weights_offset(
n_idx,
k,
group_size,
uk.weight_nbit,
uk.has_weight_zeros,
uk.has_bias,
uk.nr,
uk.kr,
uk.sr);

int weight_qval_idxs_offset = n_idx * k;
int weight_scales_and_zeros_offset = (n_idx * k / group_size);
uk.pack_weights_with_lut(
(char*)packed_weights + packed_weights_offset,
/*n=*/nc_tile_size,
k,
group_size,
weight_qval_idxs + weight_qval_idxs_offset,
n_luts,
luts,
weight_scales + weight_scales_and_zeros_offset,
(weight_zeros == nullptr)
? nullptr
: (weight_zeros + weight_scales_and_zeros_offset),
(bias == nullptr) ? nullptr : (bias + n_idx),
uk.nr,
uk.kr,
uk.sr);
});
}

LinearTilingParams LinearTilingParams::from_target_tiles_per_thread(
int m,
int m_step,
@@ -27,6 +27,21 @@ void pack_weights_operator(
const int8_t* weight_zeros,
const float* bias);

void pack_weights_with_lut_operator(
const UKernelConfig& uk,
// Outputs
void* packed_weights,
// Inputs
int n,
int k,
int group_size,
const int8_t* weight_qval_idxs,
int n_luts,
const int8_t* luts,
const float* weight_scales,
const int8_t* weight_zeros,
const float* bias);

// Linear functions
struct LinearTilingParams {
int mc{0};