
Commit

Merge pull request flintlib#1576 from math-gout/neon
Work on NEON intrinsics
albinahlback authored Mar 5, 2024
2 parents 11d1edb + 2d4796f commit b471df2
Showing 1 changed file with 126 additions and 102 deletions.
228 changes: 126 additions & 102 deletions src/machine_vectors.h
@@ -1,5 +1,6 @@
/*
Copyright (C) 2022 Daniel Schultz
Copyright (C) 2023 Mathieu Gouttenoire
This file is part of FLINT.
@@ -828,7 +829,9 @@ FLINT_FORCE_INLINE V V##f(V a, V b, V c, V d) { \
}


/* vec1 **************************************************/
/* floating point stuff ******************************************************/

/* vec1d ***********************************************************/

FLINT_FORCE_INLINE vec1d vec1d_load(const double* a) {
return a[0];
@@ -951,7 +954,7 @@ FLINT_FORCE_INLINE vec1d vec1d_reduce_pm1n_to_pmhn(vec1d a, vec1d n) {
return a;
}

/* vec2 **********************************************************************/
/* vec2d ***********************************************************/

FLINT_FORCE_INLINE double vec2d_get_index(vec2d a, int i) {
return a[i];
@@ -965,7 +968,7 @@ FLINT_FORCE_INLINE vec2d vec2d_set_d2(double a0, double a1)

FLINT_FORCE_INLINE vec2d vec2d_set_d(double a)
{
return vec2d_set_d2(a, a);
return vdupq_n_f64(a);
}
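/* vdupq_n_f64 broadcasts the scalar into both double lanes with a single
 * NEON duplicate, replacing the generic two-element construction. */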

FLINT_FORCE_INLINE vec2d vec2d_load(const double* a)
@@ -1000,12 +1003,12 @@ FLINT_FORCE_INLINE void vec2d_store_aligned(double* z, vec2d a)

FLINT_FORCE_INLINE vec2d vec2d_zero(void)
{
return vec2d_set_d(0.0);
return vdupq_n_f64(0.0);
}

FLINT_FORCE_INLINE vec2d vec2d_one(void)
{
return vec2d_set_d(1.0);
return vdupq_n_f64(1.0);
}

FLINT_FORCE_INLINE vec2d vec2d_neg(vec2d a)
@@ -1057,7 +1060,7 @@ FLINT_FORCE_INLINE vec2d vec2d_mul(vec2d a, vec2d b)

FLINT_FORCE_INLINE vec2d vec2d_half(vec2d a)
{
return vec2d_mul(a, vec2d_set_d(0.5));
return vmulq_n_f64(a, 0.5);
}
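/* vmulq_n_f64 multiplies every lane by a scalar operand, so the constant 0.5
 * does not have to be materialised as a separate vector first. */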

FLINT_FORCE_INLINE vec2d vec2d_fmadd(vec2d a, vec2d b, vec2d c) {
@@ -1113,21 +1116,25 @@ FLINT_FORCE_INLINE vec2d vec2d_reduce_to_pm1n(vec2d a, vec2d n, vec2d ninv) {

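/* vbslq_f64(mask, x, y) is a bitwise select: it takes bits from x where mask
 * is set and from y elsewhere.  Since vcgtq_f64/vcltq_f64/vcgeq_f64 return
 * all-ones or all-zeros per lane, this acts as a per-lane conditional move. */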
FLINT_FORCE_INLINE vec2d vec2d_reduce_0n_to_pmhn(vec2d a, vec2d n) {
vec2d halfn = vec2d_half(n);
return vec2d_blendv(a, vec2d_sub(a, n), vec2d_cmp_gt(a, halfn));
return vbslq_f64(vcgtq_f64(a, halfn), vec2d_sub(a, n), a);
}

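/* Map a from (-n, n) to the balanced range around zero: subtract n when
 * a > n/2, add n when a < -n/2 (tested as a + n < n/2), otherwise keep a. */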
FLINT_FORCE_INLINE vec2d vec2d_reduce_pm1n_to_pmhn(vec2d a, vec2d n) {
vec2d halfn = vec2d_half(n);
vec2d t = vec2d_blendv(n, vec2d_neg(n), vec2d_cmp_lt(a, vec2d_zero()));
return vec2d_blendv(a, vec2d_sub(a, t), vec2d_cmp_gt(vec2d_abs(a), halfn));
vec2d t = vec2d_add(a, n);

vec2n condition_a = vcgtq_f64(a, halfn);
vec2n condition_t = vcltq_f64(t, halfn);

return vbslq_f64(condition_a, vec2d_sub(a, n), vbslq_f64(condition_t, t, a));
}

FLINT_FORCE_INLINE vec1d vec1d_reduce_pm1no_to_0n(vec1d a, vec1d n) {
return a >= 0 ? a : a + n;
}

FLINT_FORCE_INLINE vec2d vec2d_reduce_pm1no_to_0n(vec2d a, vec2d n) {
return vec2d_blendv(a, vec2d_add(a, n), vec2d_cmp_lt(a, vec2d_zero()));
return vbslq_f64(vcgeq_f64(a, vec2d_zero()), a, vaddq_f64(a, n));
}

FLINT_FORCE_INLINE vec1d vec1d_reduce_to_0n(vec1d a, vec1d n, vec1d ninv) {
@@ -1161,7 +1168,7 @@ DEFINE_IT(vec2d)



/* vec4 **********************************************************************/
/* vec4d ***********************************************************/

FLINT_FORCE_INLINE vec4d vec4d_set_vec2d2(vec2d a, vec2d b) {
vec4d z = {a, b}; return z;
@@ -1179,7 +1186,7 @@ FLINT_FORCE_INLINE vec4d vec4d_set_d4(double a0, double a1, double a2, double a3
FLINT_FORCE_INLINE vec4d vec4d_set_d(double a)
{
vec2d z1 = vec2d_set_d(a);
return vec4d_set_vec2d2(z1, z1);
vec4d z = {z1, z1}; return z;
}

FLINT_FORCE_INLINE void vec4d_store(double* z, vec4d a)
@@ -1301,7 +1308,11 @@ EXTEND_VEC_DEF3(vec2d, vec4d, _blendv)
EXTEND_VEC_DEF4(vec2d, vec4d, _mulmod)
EXTEND_VEC_DEF4(vec2d, vec4d, _nmulmod)

/* vec8 **********************************************************************/
/* vec8d ***********************************************************/

FLINT_FORCE_INLINE vec8d vec8d_set_vec4d2(vec4d a, vec4d b) {
vec8d z = {a, b}; return z;
}

FLINT_FORCE_INLINE double vec8d_get_index(vec8d a, int i) {
return i < 4 ? vec4d_get_index(a.e1, i) : vec4d_get_index(a.e2, i - 4);
@@ -1405,130 +1416,160 @@ DEFINE_IT(vec2d)

/* integer stuff *************************************************************/

/* vec1n ***********************************************************/

FLINT_FORCE_INLINE void vec1n_store_unaligned(ulong* z, vec1n a) {
z[0] = a;
}

FLINT_FORCE_INLINE void vec2n_store_unaligned(ulong* z, vec2n a) {
vst1q_u64(z, a);
FLINT_FORCE_INLINE vec1n vec1d_convert_limited_vec1n(vec1d a) {
return (slong)a;
}

FLINT_FORCE_INLINE void vec4n_store_unaligned(ulong* z, vec4n a) {
vec2n_store_unaligned(z+0, a.e1);
vec2n_store_unaligned(z+2, a.e2);
// (a + b) % n
FLINT_FORCE_INLINE vec1n vec1n_addmod(vec1n a, vec1n b, vec1n n) {
vec1n nmb = n - b;
return nmb > a ? a + b : a - nmb;
}
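/* Computing n - b first avoids overflow: when n - b > a the plain sum a + b is
 * below n and cannot wrap, otherwise a - (n - b) gives a + b - n without ever
 * forming a + b, which could overflow for operands close to 2^64. */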

FLINT_FORCE_INLINE vec1n vec1d_convert_limited_vec1n(vec1d a) {
return (slong)a;
}
/* vec2n ***********************************************************/

FLINT_FORCE_INLINE vec2n vec2d_convert_limited_vec2n(vec2d a) {
return vcvtnq_u64_f64(a);
FLINT_FORCE_INLINE vec2d vec2n_convert_limited_vec2d(vec2n a) {
float64x2_t t = vdupq_n_f64(0x1.0p52);
return vsubq_f64(vreinterpretq_f64_u64(vorrq_u64(a, vreinterpretq_u64_f64(t))), t);
}
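/* Bit trick: for 0 <= a < 2^52, OR-ing a into the mantissa of 2^52 yields the
 * exact double 2^52 + a, so subtracting 2^52 converts a to double without an
 * integer-to-float instruction. */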

FLINT_FORCE_INLINE vec4n vec4d_convert_limited_vec4n(vec4d a) {
vec2n z1 = vec2d_convert_limited_vec2n(a.e1);
vec2n z2 = vec2d_convert_limited_vec2n(a.e2);
vec4n z = {z1, z2}; return z;
FLINT_FORCE_INLINE void vec2n_store_unaligned(ulong* z, vec2n a) {
vst1q_u64(z, a);
}

FLINT_FORCE_INLINE vec2n vec2d_convert_limited_vec2n(vec2d a) {
return vcvtnq_u64_f64(a);
}

FLINT_FORCE_INLINE vec2n vec2n_set_n(ulong a) {
vec2n x = vdupq_n_u64(a);
return x;
}

FLINT_FORCE_INLINE vec4n vec4n_set_n(ulong a) {
vec2n x = vec2n_set_n(a);
vec4n z = {x, x};
return z;
FLINT_FORCE_INLINE vec2n vec2n_load_unaligned(const ulong* a) {
return vld1q_u64(a);
}

FLINT_FORCE_INLINE vec8n vec8n_set_n(ulong a) {
vec4n x = vec4n_set_n(a);
vec8n z = {x, x};
return z;
// Right shift by 32 bits
FLINT_FORCE_INLINE vec2n vec2n_bit_shift_right_32(vec2n a) {
return vshrq_n_u64(a, 32);
}
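/* Only a fixed 32-bit shift is provided here: vshrq_n_u64 needs a
 * compile-time immediate shift count (see the disabled generic
 * vec2n_bit_shift_right further down). */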

FLINT_FORCE_INLINE vec2n vec2n_load_unaligned(const ulong* a)
{
return vld1q_u64(a);
// AND operation
FLINT_FORCE_INLINE vec2n vec2n_bit_and(vec2n a, vec2n b) {
return vandq_u64(a, b);
}

FLINT_FORCE_INLINE vec4n vec4n_load_unaligned(const ulong* a)
{
vec2n z1 = vec2n_load_unaligned(a+0);
vec2n z2 = vec2n_load_unaligned(a+2);
vec4n z = {z1, z2}; return z;
// Addition
FLINT_FORCE_INLINE vec2n vec2n_add(vec2n a, vec2n b) {
return vaddq_u64(a, b);
}

FLINT_FORCE_INLINE vec8n vec8n_load_unaligned(const ulong* a)
{
vec4n z1 = vec4n_load_unaligned(a+0);
vec4n z2 = vec4n_load_unaligned(a+4);
vec8n z = {z1, z2}; return z;
// Subtraction
FLINT_FORCE_INLINE vec2n vec2n_sub(vec2n a, vec2n b) {
return vsubq_u64(a, b);
}

/* todo: implement using intrinsics */
// (a + b) % n
FLINT_FORCE_INLINE vec2n vec2n_addmod(vec2n a, vec2n b, vec2n n) {
vec2n nmb = vec2n_sub(n, b);
vec2n sum = vec2n_sub(a, nmb);

vec2n mask = vcgtq_u64(nmb, a);

return vec2n_add(sum, vandq_u64(n, mask));
}
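/* Illustrative example (not from this commit): with a = {1, 6}, b = {2, 5} and
 * n = {7, 7}, nmb = {5, 2} and the mask is all-ones only in lane 0 (5 > 1), so
 * the result is {1 - 5 + 7, 6 - 2} = {3, 4}, i.e. (a + b) mod n per lane. */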

FLINT_FORCE_INLINE vec1n vec1n_addmod(vec1n a, vec1n b, vec1n n)
{
return n - b > a ? a + b : a + b - n;
// (a + b) % n for n < 2^63
FLINT_FORCE_INLINE vec2n vec2n_addmod_limited(vec2n a, vec2n b, vec2n n) {
vec2n s = vec2n_add(a, b);

vec2n mask = vcgeq_u64(s, n);

return vec2n_sub(s, vandq_u64(n, mask));
}
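/* With n < 2^63 the operands a, b < n are both below 2^63, so a + b cannot
 * wrap modulo 2^64 and a single compare-and-subtract completes the reduction. */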

FLINT_FORCE_INLINE vec2n vec2n_addmod(vec2n a, vec2n b, vec2n n)
{
vec2n z = {vec1n_addmod(a[0], b[0], n[0]), vec1n_addmod(a[1], b[1], n[1])};
return z;
/* vec4n ***********************************************************/

FLINT_FORCE_INLINE vec4d vec4n_convert_limited_vec4d(vec4n a) {
vec2d z1 = vec2n_convert_limited_vec2d(a.e1);
vec2d z2 = vec2n_convert_limited_vec2d(a.e2);
vec4d z = {z1, z2}; return z;
}

EXTEND_VEC_DEF3(vec2n, vec4n, _addmod)
EXTEND_VEC_DEF3(vec4n, vec8n, _addmod)
FLINT_FORCE_INLINE void vec4n_store_unaligned(ulong* z, vec4n a) {
vec2n_store_unaligned(z+0, a.e1);
vec2n_store_unaligned(z+2, a.e2);
}

/* todo: optimized version */
FLINT_FORCE_INLINE vec8n vec8n_addmod_limited(vec8n a, vec8n b, vec8n n)
{
return vec8n_addmod(a, b, n);
FLINT_FORCE_INLINE vec4n vec4d_convert_limited_vec4n(vec4d a) {
vec2n z1 = vec2d_convert_limited_vec2n(a.e1);
vec2n z2 = vec2d_convert_limited_vec2n(a.e2);
vec4n z = {z1, z2}; return z;
}

FLINT_FORCE_INLINE vec2d vec2n_convert_limited_vec2d(vec2n a)
{
float64x2_t t = vdupq_n_f64(0x1.0p52);
return vsubq_f64(vreinterpretq_f64_u64(vorrq_u64(a, vreinterpretq_u64_f64(t))), t);
FLINT_FORCE_INLINE vec4n vec4n_set_n(ulong a) {
vec2n x = vec2n_set_n(a);
vec4n z = {x, x};
return z;
}

FLINT_FORCE_INLINE vec4d vec4n_convert_limited_vec4d(vec4n a)
{
return vec4d_set_vec2d2(vec2n_convert_limited_vec2d(a.e1),
vec2n_convert_limited_vec2d(a.e2));
FLINT_FORCE_INLINE vec4n vec4n_load_unaligned(const ulong* a) {
vec2n z1 = vec2n_load_unaligned(a+0);
vec2n z2 = vec2n_load_unaligned(a+2);
vec4n z = {z1, z2}; return z;
}

FLINT_FORCE_INLINE vec8d vec8d_set_vec4d2(vec4d a, vec4d b)
{
vec8d z = {a, b}; return z;
EXTEND_VEC_DEF1(vec2n, vec4n, _bit_shift_right_32)
EXTEND_VEC_DEF2(vec2n, vec4n, _bit_and)
EXTEND_VEC_DEF2(vec2n, vec4n, _add)
EXTEND_VEC_DEF2(vec2n, vec4n, _sub)
EXTEND_VEC_DEF3(vec2n, vec4n, _addmod)
EXTEND_VEC_DEF3(vec2n, vec4n, _addmod_limited)
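/* The EXTEND_VEC_DEF* macros (defined earlier in this header and #undef'd at
 * the end) lift each two-lane vec2n operation to vec4n by applying it to the
 * .e1 and .e2 halves; the trailing digit is the number of vector arguments. */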

/* vec8n ***********************************************************/

FLINT_FORCE_INLINE vec8d vec8n_convert_limited_vec8d(vec8n a) {
vec4d z1 = vec4n_convert_limited_vec4d(a.e1);
vec4d z2 = vec4n_convert_limited_vec4d(a.e2);
vec8d z = {z1, z2}; return z;
}

FLINT_FORCE_INLINE vec8d vec8n_convert_limited_vec8d(vec8n a)
{
return vec8d_set_vec4d2(vec4n_convert_limited_vec4d(a.e1),
vec4n_convert_limited_vec4d(a.e2));
FLINT_FORCE_INLINE vec8n vec8n_set_n(ulong a) {
vec4n x = vec4n_set_n(a);
vec8n z = {x, x};
return z;
}

FLINT_FORCE_INLINE vec2n vec2n_bit_and(vec2n a, vec2n b)
FLINT_FORCE_INLINE vec8n vec8n_load_unaligned(const ulong* a)
{
return vandq_u64(a, b);
vec4n z1 = vec4n_load_unaligned(a+0);
vec4n z2 = vec4n_load_unaligned(a+4);
vec8n z = {z1, z2}; return z;
}

EXTEND_VEC_DEF2(vec2n, vec4n, _bit_and)

EXTEND_VEC_DEF1(vec4n, vec8n, _bit_shift_right_32)
EXTEND_VEC_DEF2(vec4n, vec8n, _bit_and)
EXTEND_VEC_DEF2(vec4n, vec8n, _add)
EXTEND_VEC_DEF2(vec4n, vec8n, _sub)
EXTEND_VEC_DEF3(vec4n, vec8n, _addmod)
EXTEND_VEC_DEF3(vec4n, vec8n, _addmod_limited)


#if 0
/*
vshrq_n_u64(a, n) cannot be used because n must be a compile-time
constant, and the compiler doesn't see that n is constant
even if the function is forced inline.
And vshlq_s64(a, vdupq_n_s64(-(slong) n)) cannot be used to emulate
vshrq_n_u64(a, n) as it propagates the sign bit.
*/
FLINT_FORCE_INLINE vec2n vec2n_bit_shift_right(vec2n a, ulong n)
{
@@ -1548,23 +1589,6 @@ FLINT_FORCE_INLINE vec8n vec8n_bit_shift_right(vec8n a)
}
#endif

FLINT_FORCE_INLINE vec2n vec2n_bit_shift_right_32(vec2n a)
{
return vshrq_n_u64(a, 32);
}

FLINT_FORCE_INLINE vec4n vec4n_bit_shift_right_32(vec4n a)
{
vec4n z = {vec2n_bit_shift_right_32(a.e1), vec2n_bit_shift_right_32(a.e2)};
return z;
}

FLINT_FORCE_INLINE vec8n vec8n_bit_shift_right_32(vec8n a)
{
vec8n z = {vec4n_bit_shift_right_32(a.e1), vec4n_bit_shift_right_32(a.e2)};
return z;
}


#undef EXTEND_VEC_DEF4
#undef EXTEND_VEC_DEF3

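A minimal usage sketch, not part of the commit: the function name, loop structure and includes below are illustrative, and an AArch64/NEON FLINT build is assumed. It shows how the vec1n/vec2n addmod helpers from the diff might be driven over buffers of residues modulo n.

#include "flint.h"            /* assumed include: provides ulong and slong */
#include "machine_vectors.h"  /* the header changed in this commit */

/* Hypothetical helper: z[i] = (x[i] + y[i]) mod n, assuming x[i], y[i] < n. */
void add_mod_vectors(ulong * z, const ulong * x, const ulong * y, ulong n, slong len)
{
    vec2n vn = vec2n_set_n(n);
    slong i;

    /* Two limbs per iteration on the NEON path. */
    for (i = 0; i + 2 <= len; i += 2)
    {
        vec2n a = vec2n_load_unaligned(x + i);
        vec2n b = vec2n_load_unaligned(y + i);
        vec2n_store_unaligned(z + i, vec2n_addmod(a, b, vn));
    }

    /* Scalar tail. */
    for (; i < len; i++)
        z[i] = vec1n_addmod(x[i], y[i], n);
}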