Skip to content

Commit 1fbc92f

Browse files
committed
optimize karatsuba split
1 parent 3bd4587 commit 1fbc92f

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

cp-algo/math/fft.hpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -201,16 +201,17 @@ namespace cp_algo::math::fft {
201201
void mul(auto &a, auto const& b) {
202202
size_t N = size(a) + size(b) - 1;
203203
if(std::max(size(a), size(b)) > (1 << 23)) {
204+
using T = std::decay_t<decltype(a[0])>;
204205
// do karatsuba to save memory
205206
auto n = (std::max(size(a), size(b)) + 1) / 2;
206-
auto a0 = to<std::vector>(a | std::views::take(n));
207-
auto a1 = to<std::vector>(a | std::views::drop(n));
208-
auto b0 = to<std::vector>(b | std::views::take(n));
209-
auto b1 = to<std::vector>(b | std::views::drop(n));
207+
auto a0 = to<std::vector<T, big_alloc<T>>>(a | std::views::take(n));
208+
auto a1 = to<std::vector<T, big_alloc<T>>>(a | std::views::drop(n));
209+
auto b0 = to<std::vector<T, big_alloc<T>>>(b | std::views::take(n));
210+
auto b1 = to<std::vector<T, big_alloc<T>>>(b | std::views::drop(n));
210211
a0.resize(n); a1.resize(n);
211212
b0.resize(n); b1.resize(n);
212-
auto a01 = to<std::vector>(std::views::zip_transform(std::plus{}, a0, a1));
213-
auto b01 = to<std::vector>(std::views::zip_transform(std::plus{}, b0, b1));
213+
auto a01 = to<std::vector<T, big_alloc<T>>>(std::views::zip_transform(std::plus{}, a0, a1));
214+
auto b01 = to<std::vector<T, big_alloc<T>>>(std::views::zip_transform(std::plus{}, b0, b1));
214215
checkpoint("karatsuba split");
215216
mul(a0, b0);
216217
mul(a1, b1);

cp-algo/util/simd.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ namespace cp_algo {
3333
[[gnu::always_inline]] inline u64x4 montgomery_reduce(u64x4 x, u64x4 mod, u64x4 imod) {
3434
auto x_ninv = u64x4(u32x8(x) * u32x8(imod));
3535
#ifdef __AVX2__
36-
auto x_res = __m256i(x) + _mm256_mul_epu32(__m256i(x_ninv), __m256i(mod));
36+
x += u64x4(_mm256_mul_epu32(__m256i(x_ninv), __m256i(mod)));
3737
#else
38-
auto x_res = x + x_ninv * mod;
38+
x += x_ninv * mod;
3939
#endif
40-
return u64x4(x_res) >> 32;
40+
return x >> 32;
4141
}
4242

4343
[[gnu::always_inline]] inline u64x4 montgomery_mul(u64x4 x, u64x4 y, u64x4 mod, u64x4 imod) {

0 commit comments

Comments
 (0)