Skip to content

Commit 337a057

Browse files
authored
Add int8 LUT CPU ops (#2026)
init
1 parent 8334340 commit 337a057

File tree

10 files changed

+512
-13
lines changed

10 files changed

+512
-13
lines changed

torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ void pack_weights(
126126
bias);
127127
}
128128

129-
template <int weight_nbit, int nr, int kr, int sr>
129+
template <int weight_nbit, int nr_, int kr_, int sr_>
130130
void pack_weights_with_lut(
131131
// Output
132132
void* packed_weights,
@@ -141,10 +141,16 @@ void pack_weights_with_lut(
141141
// weight_zeros not packed if nullptr
142142
const int8_t* weight_zeros,
143143
// bias not packed if nullptr
144-
const float* bias) {
144+
const float* bias,
145+
int nr,
146+
int kr,
147+
int sr) {
148+
(void)nr; // unused
149+
(void)kr; // unused
150+
(void)sr; // unused
145151
torchao::kernels::cpu::aarch64::linear::
146152
channelwise_8bit_activation_groupwise_lowbit_weight::weight_packing::
147-
pack_weights_with_lut<weight_nbit, nr, kr, sr>(
153+
pack_weights_with_lut<weight_nbit, nr_, kr_, sr_>(
148154
packed_weights,
149155
n,
150156
k,

torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,10 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_lut(
478478
lut.data(),
479479
test_case.weight_scales.data(),
480480
has_weight_zeros ? test_case.weight_zeros.data() : nullptr,
481-
has_bias ? test_case.bias.data() : nullptr);
481+
has_bias ? test_case.bias.data() : nullptr,
482+
nr,
483+
kr,
484+
sr);
482485

483486
std::vector<float> output(m * n);
484487
kernel_1x8x16_f32_neondot<weight_nbit, has_weight_zeros, /*has_lut*/ true>(

torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_config.h

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,23 @@ struct UKernelConfig {
8787
int kr,
8888
int sr);
8989

90+
// Pack weights into packed_weights buffer with int8-valued LUT
91+
using pack_weights_with_lut_fn_type = void (*)(
92+
void* packed_weights,
93+
int n,
94+
int k,
95+
int group_size,
96+
const int8_t* weight_qval_idxs,
97+
int n_luts,
98+
const int8_t* luts,
99+
const float* weight_scales,
100+
const int8_t* weight_zeros,
101+
const float* bias,
102+
int nr,
103+
int kr,
104+
int sr
105+
);
106+
90107
// Run matmul kernel
91108
using kernel_fn_type = void (*)(
92109
float* output,
@@ -126,6 +143,7 @@ struct UKernelConfig {
126143
packed_weights_size_fn_type packed_weights_size{nullptr};
127144
packed_weights_offset_fn_type packed_weights_offset{nullptr};
128145
pack_weights_fn_type pack_weights{nullptr};
146+
pack_weights_with_lut_fn_type pack_weights_with_lut{nullptr};
129147

130148
// linear_configs must be sorted in ascending m_step
131149
std::array<linear_config_type, kMaxLinearConfigs> linear_configs;
@@ -144,6 +162,20 @@ struct UKernelConfig {
144162
pack_weights_fn_type pack_weights,
145163
std::array<linear_config_type, kMaxLinearConfigs> linear_configs);
146164

165+
static UKernelConfig make_with_lut(
166+
size_t preferred_alignment,
167+
int n_step,
168+
int nr,
169+
int kr,
170+
int sr,
171+
int weight_nbit,
172+
bool has_weight_zeros,
173+
bool has_bias,
174+
packed_weights_size_fn_type packed_weights_with_lut_size,
175+
packed_weights_offset_fn_type packed_weights_with_lut_offset,
176+
pack_weights_with_lut_fn_type pack_weights_with_lut,
177+
std::array<linear_config_type, kMaxLinearConfigs> linear_configs);
178+
147179
inline void validate() const {
148180
TORCHAO_CHECK(preferred_alignment >= 1, "preferred_alignment must be >= 1");
149181
TORCHAO_CHECK(n_step >= 1, "n_step must be >= 1");
@@ -155,7 +187,7 @@ struct UKernelConfig {
155187
packed_weights_size != nullptr, "packed_weights_size must be set");
156188
TORCHAO_CHECK(
157189
packed_weights_offset != nullptr, "packed_weights_offset must be set");
158-
TORCHAO_CHECK(pack_weights != nullptr, "pack_weights must be set");
190+
TORCHAO_CHECK(pack_weights != nullptr || pack_weights_with_lut != nullptr, "pack_weights or pack_weights_with_lut must be set");
159191

160192
bool linear_configs_set = true; // first linear config must be set
161193
for (int i = 0; i < linear_configs.size(); i++) {
@@ -232,6 +264,36 @@ inline UKernelConfig UKernelConfig::make(
232264
packed_weights_size,
233265
packed_weights_offset,
234266
pack_weights,
267+
/*pack_weights_with_lut*/nullptr,
268+
std::move(linear_configs)};
269+
}
270+
271+
inline UKernelConfig UKernelConfig::make_with_lut(
272+
size_t preferred_alignment,
273+
int n_step,
274+
int nr,
275+
int kr,
276+
int sr,
277+
int weight_nbit,
278+
bool has_weight_zeros,
279+
bool has_bias,
280+
packed_weights_size_fn_type packed_weights_with_lut_size,
281+
packed_weights_offset_fn_type packed_weights_with_lut_offset,
282+
pack_weights_with_lut_fn_type pack_weights_with_lut,
283+
std::array<linear_config_type, kMaxLinearConfigs> linear_configs) {
284+
return UKernelConfig{
285+
preferred_alignment,
286+
n_step,
287+
nr,
288+
kr,
289+
sr,
290+
weight_nbit,
291+
has_weight_zeros,
292+
has_bias,
293+
packed_weights_with_lut_size,
294+
packed_weights_with_lut_offset,
295+
/*pack_weights*/nullptr,
296+
/*pack_weights_with_lut*/pack_weights_with_lut,
235297
std::move(linear_configs)};
236298
}
237299

torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,70 @@ void register_ukernel_config_universal(
164164
}
165165
}
166166

167+
template <int weight_nbit>
168+
void register_ukernel_config_lut(
169+
UKernelConfigRegistrationTable& table,
170+
PackedWeightsFormat format,
171+
cpuinfo_uarch uarch) {
172+
if (!cpuinfo_initialize()) {
173+
throw std::runtime_error("Failed to initialize cpuinfo!");
174+
}
175+
check_format(
176+
format,
177+
torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_lut,
178+
weight_nbit
179+
);
180+
constexpr bool has_lut = true;
181+
int preferred_alignment = 16;
182+
183+
#if defined(TORCHAO_ENABLE_ARM_NEON_DOT)
184+
namespace kernel = torchao::kernels::cpu::aarch64::linear::
185+
channelwise_8bit_activation_groupwise_lowbit_weight;
186+
187+
if (cpuinfo_has_arm_neon_dot()) {
188+
return;
189+
}
190+
if (format.has_weight_zeros) {
191+
return;
192+
}
193+
constexpr bool has_weight_zeros = false;
194+
if (format.nr == 8 && format.kr == 16 && format.sr == 2) {
195+
log_registration(format, "lut: kernel_1x8x16_f32_neondot");
196+
constexpr int n_step = 8;
197+
constexpr int nr = 8;
198+
constexpr int kr = 16;
199+
constexpr int sr = 2;
200+
constexpr int mr = 1;
201+
constexpr int m_step = 1;
202+
auto uk = UKernelConfig::make_with_lut(
203+
preferred_alignment,
204+
n_step,
205+
nr,
206+
kr,
207+
sr,
208+
weight_nbit,
209+
format.has_weight_zeros,
210+
format.has_bias,
211+
&kernel::packed_weights_with_lut_size,
212+
&kernel::packed_weights_with_lut_offset,
213+
&kernel::pack_weights_with_lut<weight_nbit, nr, kr, sr>,
214+
/*linear_configs*/ {});
215+
uk.linear_configs[0] = UKernelConfig::linear_config_type(
216+
{m_step,
217+
mr,
218+
&kernel::packed_activations_size,
219+
&kernel::packed_activations_offset,
220+
&kernel::pack_activations<mr, kr, sr>,
221+
&kernel::kernel_1x8x16_f32_neondot<
222+
weight_nbit,
223+
has_weight_zeros,
224+
has_lut>});
225+
table.register_ukernel_config(format, uarch, std::move(uk));
226+
return;
227+
}
228+
#endif // TORCHAO_ENABLE_ARM_NEON_DOT
229+
}
230+
167231
#if defined(TORCHAO_ENABLE_KLEIDI)
168232
template <typename kernel_struct>
169233
UKernelConfig::linear_config_type
@@ -285,6 +349,14 @@ void register_ukernel_config(
285349
#endif // TORCHAO_ENABLE_KLEIDI
286350
break;
287351
}
352+
case torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_lut: {
353+
// LUT kernels static assert on weight_nbit <= 4
354+
// This is needed to avoid compilation error
355+
if constexpr (weight_nbit <= 4) {
356+
register_ukernel_config_lut<weight_nbit>(table, format, uarch);
357+
}
358+
break;
359+
}
288360
default:
289361
throw std::runtime_error(
290362
"No registration available for packed_weights_type=" +
@@ -377,4 +449,24 @@ PackedWeightsFormat select_packed_weights_format(
377449
throw std::runtime_error("No packed_weights_format was selected");
378450
}
379451

452+
template <int weight_nbit>
453+
PackedWeightsFormat select_packed_weights_with_lut_format(
454+
std::optional<std::string> target,
455+
bool has_weight_zeros,
456+
bool has_bias) {
457+
if (!target) {
458+
#if defined(TORCHAO_ENABLE_ARM_NEON_DOT)
459+
return PackedWeightsFormat(
460+
torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_lut,
461+
weight_nbit,
462+
has_weight_zeros,
463+
has_bias,
464+
/*nr*/ 8,
465+
/*kr*/ 16,
466+
/*sr*/ 2);
467+
#endif // defined(TORCHAO_ENABLE_ARM_NEON_DOT)
468+
}
469+
throw std::runtime_error("No packed_weights_format was selected");
470+
}
471+
380472
} // namespace torchao::ops::linear_8bit_act_xbit_weight

torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,61 @@ void pack_weights_operator(
7575
});
7676
}
7777

78+
void pack_weights_with_lut_operator(
79+
const UKernelConfig& uk,
80+
// Outputs
81+
void* packed_weights,
82+
// Inputs
83+
int n,
84+
int k,
85+
int group_size,
86+
const int8_t* weight_qval_idxs,
87+
int n_luts,
88+
const int8_t* luts,
89+
const float* weight_scales,
90+
const int8_t* weight_zeros,
91+
const float* bias) {
92+
int n_step = uk.n_step;
93+
int nc = std::min(n, n_step);
94+
int num_nc_panels = (n + nc - 1) / nc;
95+
96+
torchao::parallel_1d(0, num_nc_panels, [&](int64_t idx) {
97+
int nc_tile_idx = idx;
98+
int n_idx = nc_tile_idx * nc;
99+
int nc_tile_size = std::min(nc, n - n_idx);
100+
101+
auto packed_weights_offset = uk.packed_weights_offset(
102+
n_idx,
103+
k,
104+
group_size,
105+
uk.weight_nbit,
106+
uk.has_weight_zeros,
107+
uk.has_bias,
108+
uk.nr,
109+
uk.kr,
110+
uk.sr);
111+
112+
int weight_qval_idxs_offset = n_idx * k;
113+
int weight_scales_and_zeros_offset = (n_idx * k / group_size);
114+
uk.pack_weights_with_lut(
115+
(char*)packed_weights + packed_weights_offset,
116+
/*n=*/nc_tile_size,
117+
k,
118+
group_size,
119+
weight_qval_idxs + weight_qval_idxs_offset,
120+
n_luts,
121+
luts,
122+
weight_scales + weight_scales_and_zeros_offset,
123+
(weight_zeros == nullptr)
124+
? nullptr
125+
: (weight_zeros + weight_scales_and_zeros_offset),
126+
(bias == nullptr) ? nullptr : (bias + n_idx),
127+
uk.nr,
128+
uk.kr,
129+
uk.sr);
130+
});
131+
}
132+
78133
LinearTilingParams LinearTilingParams::from_target_tiles_per_thread(
79134
int m,
80135
int m_step,

torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,21 @@ void pack_weights_operator(
2727
const int8_t* weight_zeros,
2828
const float* bias);
2929

30+
void pack_weights_with_lut_operator(
31+
const UKernelConfig& uk,
32+
// Outputs
33+
void* packed_weights,
34+
// Inputs
35+
int n,
36+
int k,
37+
int group_size,
38+
const int8_t* weight_qval_idxs,
39+
int n_luts,
40+
const int8_t* luts,
41+
const float* weight_scales,
42+
const int8_t* weight_zeros,
43+
const float* bias);
44+
3045
// Linear functions
3146
struct LinearTilingParams {
3247
int mc{0};

0 commit comments

Comments
 (0)