Skip to content

Commit 9c07b46

Browse files
authored
[SYCL][libdevice] Add fp32 and fp64 division with rounding mode supported in kernel code (#11704)
This PR adds: fdiv_rd, fdiv_rn, fdiv_ru, fdiv_rz, ddiv_rd, ddiv_rn, ddiv_ru, ddiv_rz into imf libdevice which implements fp32/64 division with rounding mode supported in sycl kernel code. Signed-off-by: jinge90 <[email protected]>
1 parent 3139f5c commit 9c07b46

File tree

10 files changed

+444
-0
lines changed

10 files changed

+444
-0
lines changed

libdevice/imf_rounding_op.hpp

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,4 +627,236 @@ template <typename Ty> Ty __fp_mul(Ty x, Ty y, int rd) {
627627
Ty, (z_sig << (sizeof(Ty) * 8 - 1)) |
628628
(z_exp << (std::numeric_limits<Ty>::digits - 1)) | z_fra);
629629
}
630+
631+
template <typename UTy> static UTy fra_uint_div(UTy x, UTy y, unsigned nbits) {
632+
UTy res = 0;
633+
unsigned iters = 0;
634+
if (x == 0)
635+
return 0x0;
636+
while (iters < nbits) {
637+
res = res << 1;
638+
x = x << 1;
639+
if (x > y) {
640+
x = x - y;
641+
res = res | 0x1;
642+
} else if (x == y) {
643+
res = res | 0x1;
644+
res = res << (nbits - iters - 1);
645+
return res;
646+
} else {
647+
}
648+
iters++;
649+
}
650+
res = res | 0x1;
651+
return res;
652+
}
653+
654+
template <typename Ty> Ty __fp_div(Ty x, Ty y, int rd) {
655+
typedef typename __iml_fp_config<Ty>::utype UTy;
656+
typedef typename __iml_fp_config<Ty>::stype STy;
657+
UTy x_bit = __builtin_bit_cast(UTy, x);
658+
UTy y_bit = __builtin_bit_cast(UTy, y);
659+
UTy x_exp = (x_bit >> (std::numeric_limits<Ty>::digits - 1)) &
660+
__iml_fp_config<Ty>::exp_mask;
661+
UTy y_exp = (y_bit >> (std::numeric_limits<Ty>::digits - 1)) &
662+
__iml_fp_config<Ty>::exp_mask;
663+
UTy x_fra = x_bit & __iml_fp_config<Ty>::fra_mask;
664+
UTy y_fra = y_bit & __iml_fp_config<Ty>::fra_mask;
665+
UTy x_sig = x_bit >> ((sizeof(Ty) * 8) - 1);
666+
UTy y_sig = y_bit >> ((sizeof(Ty) * 8) - 1);
667+
UTy z_sig = x_sig ^ y_sig;
668+
UTy z_exp = 0x0, z_fra = 0x0;
669+
const UTy one_bits = 0x1;
670+
const UTy sig_off_mask = (one_bits << (sizeof(UTy) * 8 - 1)) - 1;
671+
672+
if (((x_exp == __iml_fp_config<Ty>::exp_mask) && (x_fra != 0x0)) ||
673+
((y_exp == __iml_fp_config<Ty>::exp_mask) && (y_fra != 0x0)) ||
674+
((y_bit & sig_off_mask) == 0x0)) {
675+
UTy tmp = __iml_fp_config<Ty>::nan_bits;
676+
return __builtin_bit_cast(Ty, tmp);
677+
}
678+
679+
if ((x_exp == __iml_fp_config<Ty>::exp_mask) && (x_fra == 0x0)) {
680+
if ((y_exp == __iml_fp_config<Ty>::exp_mask) && (y_fra == 0x0)) {
681+
UTy tmp = __iml_fp_config<Ty>::nan_bits;
682+
return __builtin_bit_cast(Ty, tmp);
683+
} else {
684+
UTy tmp =
685+
(z_sig << (sizeof(Ty) * 8 - 1)) | __iml_fp_config<Ty>::pos_inf_bits;
686+
return __builtin_bit_cast(Ty, tmp);
687+
}
688+
}
689+
690+
if ((x_bit & sig_off_mask) == 0x0)
691+
return __builtin_bit_cast(Ty, (z_sig << (sizeof(UTy) * 8 - 1)) | 0x0);
692+
693+
if ((y_exp == __iml_fp_config<Ty>::exp_mask) && (y_fra == 0x0))
694+
return __builtin_bit_cast(Ty, (z_sig << (sizeof(UTy) * 8 - 1)) | 0x0);
695+
696+
int sx_exp = x_exp, sy_exp = y_exp;
697+
sx_exp = (sx_exp == 0) ? (1 - __iml_fp_config<Ty>::bias)
698+
: (sx_exp - __iml_fp_config<Ty>::bias);
699+
sy_exp = (sy_exp == 0) ? (1 - __iml_fp_config<Ty>::bias)
700+
: (sy_exp - __iml_fp_config<Ty>::bias);
701+
int exp_diff = sx_exp - sy_exp;
702+
if (x_exp != 0x0)
703+
x_fra = (one_bits << (std::numeric_limits<Ty>::digits - 1)) | x_fra;
704+
if (y_exp != 0x0)
705+
y_fra = (one_bits << (std::numeric_limits<Ty>::digits - 1)) | y_fra;
706+
707+
if (x_fra >= y_fra) {
708+
// x_fra / y_fra max value for fp32 is 0xFFFFFF when x is normal
709+
// and y is subnormal, so msb_pos max value is 23
710+
UTy tmp = x_fra / y_fra;
711+
UTy fra_rem = x_fra - y_fra * tmp;
712+
int msb_pos = get_msb_pos(tmp);
713+
int tmp2 = exp_diff + msb_pos;
714+
if (tmp2 > __iml_fp_config<Ty>::bias)
715+
return __handling_fp_overflow<Ty>(z_sig, rd);
716+
717+
if (tmp2 >= (1 - __iml_fp_config<Ty>::bias)) {
718+
// Fall into normal floating point range
719+
z_exp = tmp2 + __iml_fp_config<Ty>::bias;
720+
// For fp32, starting msb_pos bits in fra comes from tmp and we need
721+
// 23 - msb_pos( + grs) more bits from fraction division.
722+
z_fra = ((one_bits << msb_pos) - 1) & tmp;
723+
z_fra = z_fra << ((std::numeric_limits<Ty>::digits - 1) - msb_pos);
724+
UTy fra_bits_quo = fra_uint_div(
725+
fra_rem, y_fra, std::numeric_limits<Ty>::digits - msb_pos + 2);
726+
z_fra = z_fra | (fra_bits_quo >> 3);
727+
int rb = __handling_rounding(z_sig, z_fra, fra_bits_quo & 0x7, rd);
728+
if (rb != 0) {
729+
z_fra++;
730+
if (z_fra > __iml_fp_config<Ty>::fra_mask) {
731+
z_exp++;
732+
if (z_exp == __iml_fp_config<Ty>::exp_mask)
733+
return __handling_fp_overflow<Ty>(z_sig, rd);
734+
}
735+
}
736+
return __builtin_bit_cast(
737+
Ty, (z_sig << (sizeof(Ty) * 8 - 1)) |
738+
(z_exp << (std::numeric_limits<Ty>::digits - 1)) | z_fra);
739+
}
740+
741+
// orignal value can be represented as (0.1xxxx.... * 2^tmp2)
742+
// which is equivalent to 0.00000...1xxxxx * 2^(-126)
743+
tmp2 = tmp2 + 1;
744+
if ((tmp2 + std::numeric_limits<Ty>::digits - 1) <=
745+
(1 - __iml_fp_config<Ty>::bias)) {
746+
bool above_half = false;
747+
if ((tmp2 + std::numeric_limits<Ty>::digits - 1) ==
748+
(1 - __iml_fp_config<Ty>::bias))
749+
above_half =
750+
!((x_fra == y_fra * tmp) && (tmp == (one_bits << msb_pos)));
751+
return __handling_fp_underflow<Ty, UTy>(z_sig, rd, above_half);
752+
} else {
753+
int rb;
754+
// Fall into subnormal floating point range. For fp32, there are -126 -
755+
// tmp2 leading zeros in final fra and we need get 23 + 126 + tmp2( + grs)
756+
// bits from fraction division.
757+
if (msb_pos >= (std::numeric_limits<Ty>::digits +
758+
__iml_fp_config<Ty>::bias + tmp2)) {
759+
unsigned fra_discard_bits = msb_pos + 3 - __iml_fp_config<Ty>::bias -
760+
std::numeric_limits<Ty>::digits - tmp2;
761+
z_fra = tmp >> fra_discard_bits;
762+
int grs_bits = (tmp >> (fra_discard_bits - 3)) & 0x7;
763+
if ((grs_bits & 0x1) == 0x0) {
764+
if ((tmp & ((0x1 << (fra_discard_bits - 3)) - 0x1)) || (fra_rem != 0))
765+
grs_bits = grs_bits | 0x1;
766+
}
767+
rb = __handling_rounding(z_sig, z_fra, grs_bits, rd);
768+
} else {
769+
// For fp32, we need to get (23 + 126 + tmp2 + 3) - (msb_pos + 1) bits
770+
// from fra division and the last bit is sticky bit.
771+
z_fra = tmp;
772+
unsigned fra_get_bits = std::numeric_limits<Ty>::digits +
773+
__iml_fp_config<Ty>::bias + tmp2 - msb_pos;
774+
z_fra = z_fra << fra_get_bits;
775+
UTy fra_bits_quo = fra_uint_div(fra_rem, y_fra, fra_get_bits);
776+
z_fra = z_fra | fra_bits_quo;
777+
int grs_bits = z_fra & 0x7;
778+
z_fra = z_fra >> 3;
779+
rb = __handling_rounding(z_sig, z_fra, grs_bits, rd);
780+
}
781+
if (rb != 0) {
782+
z_fra++;
783+
if (z_fra > __iml_fp_config<Ty>::fra_mask) {
784+
z_exp++;
785+
z_fra = 0x0;
786+
}
787+
}
788+
return __builtin_bit_cast(
789+
Ty, (z_sig << (sizeof(Ty) * 8 - 1)) |
790+
(z_exp << (std::numeric_limits<Ty>::digits - 1)) | z_fra);
791+
}
792+
} else {
793+
// x_fra < y_fra, the final result can be represented as
794+
// (2^exp_diff) * 0.000...01xxxxx
795+
unsigned lz = 0;
796+
UTy x_tmp = x_fra;
797+
x_tmp = x_tmp << 1;
798+
while (x_tmp < y_fra) {
799+
lz++;
800+
x_tmp = x_tmp << 1;
801+
}
802+
// x_fra < y_fra, the final result can be represented as
803+
// (2^exp_diff) * 0.000...01xxxxx... which is equivalent to
804+
// 2 ^ (exp_diff - lz - 1) * 1.xxxxx...
805+
int nor_exp = exp_diff - lz - 1;
806+
if (nor_exp > __iml_fp_config<Ty>::bias)
807+
return __handling_fp_overflow<Ty>(z_sig, rd);
808+
809+
if (nor_exp >= (1 - __iml_fp_config<Ty>::bias)) {
810+
z_exp = nor_exp + __iml_fp_config<Ty>::bias;
811+
x_fra = x_fra << lz;
812+
UTy fra_bits_quo =
813+
fra_uint_div(x_fra, y_fra, 3 + std::numeric_limits<Ty>::digits);
814+
z_fra = (fra_bits_quo >> 3) & __iml_fp_config<Ty>::fra_mask;
815+
int grs_bits = fra_bits_quo & 0x7;
816+
int rb = __handling_rounding(z_sig, z_fra, grs_bits, rd);
817+
if (rb != 0x0) {
818+
z_fra++;
819+
if (z_fra > __iml_fp_config<Ty>::fra_mask) {
820+
z_exp++;
821+
z_fra = 0x0;
822+
if (z_exp == __iml_fp_config<Ty>::exp_mask)
823+
return __handling_fp_overflow<Ty>(z_sig, rd);
824+
}
825+
}
826+
return __builtin_bit_cast(
827+
Ty, (z_sig << (sizeof(Ty) * 8 - 1)) |
828+
(z_exp << (std::numeric_limits<Ty>::digits - 1)) | z_fra);
829+
}
830+
831+
// Fall into subnormal range or underflow happens. For fp32,
832+
// nor_exp < -126, so (-126 - exp_diff + lz + 1) > 0 which means
833+
// (lz - exp_diff - 126) >= 0
834+
unsigned lzs = lz - __iml_fp_config<Ty>::bias - exp_diff + 1;
835+
if (lzs >= (std::numeric_limits<Ty>::digits - 1)) {
836+
bool above_half = false;
837+
if (lzs == (std::numeric_limits<Ty>::digits - 1)) {
838+
if ((x_fra << (lz + 1)) > y_fra)
839+
above_half = true;
840+
}
841+
return __handling_fp_underflow<Ty>(z_sig, rd, above_half);
842+
} else {
843+
x_fra = x_fra << lz;
844+
UTy fra_bits_quo =
845+
fra_uint_div(x_fra, y_fra, std::numeric_limits<Ty>::digits - lzs + 2);
846+
z_fra = fra_bits_quo >> 3;
847+
int grs_bits = fra_bits_quo & 0x7;
848+
int rb = __handling_rounding(z_sig, z_fra, grs_bits, rd);
849+
if (rb != 0x0) {
850+
z_fra++;
851+
if (z_fra > __iml_fp_config<Ty>::fra_mask) {
852+
z_exp++;
853+
z_fra = 0x0;
854+
}
855+
}
856+
return __builtin_bit_cast(
857+
Ty, (z_sig << (sizeof(Ty) * 8 - 1)) |
858+
(z_exp << (std::numeric_limits<Ty>::digits - 1)) | z_fra);
859+
}
860+
}
861+
}
630862
#endif

libdevice/imf_utils/fp32_round.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,24 @@ DEVICE_EXTERN_C_INLINE
6868
float __devicelib_imf_fmul_rz(float x, float y) {
6969
return __fp_mul(x, y, __IML_RTZ);
7070
}
71+
72+
DEVICE_EXTERN_C_INLINE
73+
float __devicelib_imf_fdiv_rd(float x, float y) {
74+
return __fp_div(x, y, __IML_RTN);
75+
}
76+
77+
DEVICE_EXTERN_C_INLINE
78+
float __devicelib_imf_fdiv_rn(float x, float y) {
79+
return __fp_div(x, y, __IML_RTE);
80+
}
81+
82+
DEVICE_EXTERN_C_INLINE
83+
float __devicelib_imf_fdiv_ru(float x, float y) {
84+
return __fp_div(x, y, __IML_RTP);
85+
}
86+
87+
DEVICE_EXTERN_C_INLINE
88+
float __devicelib_imf_fdiv_rz(float x, float y) {
89+
return __fp_div(x, y, __IML_RTZ);
90+
}
7191
#endif

libdevice/imf_utils/fp64_round.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,24 @@ DEVICE_EXTERN_C_INLINE
6868
double __devicelib_imf_dmul_rz(double x, double y) {
6969
return __fp_mul(x, y, __IML_RTZ);
7070
}
71+
72+
DEVICE_EXTERN_C_INLINE
73+
double __devicelib_imf_ddiv_rd(double x, double y) {
74+
return __fp_div(x, y, __IML_RTN);
75+
}
76+
77+
DEVICE_EXTERN_C_INLINE
78+
double __devicelib_imf_ddiv_rn(double x, double y) {
79+
return __fp_div(x, y, __IML_RTE);
80+
}
81+
82+
DEVICE_EXTERN_C_INLINE
83+
double __devicelib_imf_ddiv_ru(double x, double y) {
84+
return __fp_div(x, y, __IML_RTP);
85+
}
86+
87+
DEVICE_EXTERN_C_INLINE
88+
double __devicelib_imf_ddiv_rz(double x, double y) {
89+
return __fp_div(x, y, __IML_RTZ);
90+
}
7191
#endif

libdevice/imf_wrapper.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1924,4 +1924,28 @@ float __devicelib_imf_fmul_rz(float, float);
19241924

19251925
DEVICE_EXTERN_C_INLINE
19261926
float __imf_fmul_rz(float x, float y) { return __devicelib_imf_fmul_rz(x, y); }
1927+
1928+
DEVICE_EXTERN_C_INLINE
1929+
float __devicelib_imf_fdiv_rd(float, float);
1930+
1931+
DEVICE_EXTERN_C_INLINE
1932+
float __imf_fdiv_rd(float x, float y) { return __devicelib_imf_fdiv_rd(x, y); }
1933+
1934+
DEVICE_EXTERN_C_INLINE
1935+
float __devicelib_imf_fdiv_rn(float, float);
1936+
1937+
DEVICE_EXTERN_C_INLINE
1938+
float __imf_fdiv_rn(float x, float y) { return __devicelib_imf_fdiv_rn(x, y); }
1939+
1940+
DEVICE_EXTERN_C_INLINE
1941+
float __devicelib_imf_fdiv_ru(float, float);
1942+
1943+
DEVICE_EXTERN_C_INLINE
1944+
float __imf_fdiv_ru(float x, float y) { return __devicelib_imf_fdiv_ru(x, y); }
1945+
1946+
DEVICE_EXTERN_C_INLINE
1947+
float __devicelib_imf_fdiv_rz(float, float);
1948+
1949+
DEVICE_EXTERN_C_INLINE
1950+
float __imf_fdiv_rz(float x, float y) { return __devicelib_imf_fdiv_rz(x, y); }
19271951
#endif // __LIBDEVICE_IMF_ENABLED__

libdevice/imf_wrapper_fp64.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,4 +473,36 @@ DEVICE_EXTERN_C_INLINE
473473
double __imf_dmul_rz(double x, double y) {
474474
return __devicelib_imf_dmul_rz(x, y);
475475
}
476+
477+
DEVICE_EXTERN_C_INLINE
478+
double __devicelib_imf_ddiv_rd(double, double);
479+
480+
DEVICE_EXTERN_C_INLINE
481+
double __imf_ddiv_rd(double x, double y) {
482+
return __devicelib_imf_ddiv_rd(x, y);
483+
}
484+
485+
DEVICE_EXTERN_C_INLINE
486+
double __devicelib_imf_ddiv_rn(double, double);
487+
488+
DEVICE_EXTERN_C_INLINE
489+
double __imf_ddiv_rn(double x, double y) {
490+
return __devicelib_imf_ddiv_rn(x, y);
491+
}
492+
493+
DEVICE_EXTERN_C_INLINE
494+
double __devicelib_imf_ddiv_ru(double, double);
495+
496+
DEVICE_EXTERN_C_INLINE
497+
double __imf_ddiv_ru(double x, double y) {
498+
return __devicelib_imf_ddiv_ru(x, y);
499+
}
500+
501+
DEVICE_EXTERN_C_INLINE
502+
double __devicelib_imf_ddiv_rz(double, double);
503+
504+
DEVICE_EXTERN_C_INLINE
505+
double __imf_ddiv_rz(double x, double y) {
506+
return __devicelib_imf_ddiv_rz(x, y);
507+
}
476508
#endif // __LIBDEVICE_IMF_ENABLED__

llvm/tools/sycl-post-link/SYCLDeviceLibReqMask.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,10 @@ SYCLDeviceLibFuncMap SDLMap = {
235235
{"__devicelib_imf_fmul_rn", DeviceLibExt::cl_intel_devicelib_imf},
236236
{"__devicelib_imf_fmul_ru", DeviceLibExt::cl_intel_devicelib_imf},
237237
{"__devicelib_imf_fmul_rz", DeviceLibExt::cl_intel_devicelib_imf},
238+
{"__devicelib_imf_fdiv_rd", DeviceLibExt::cl_intel_devicelib_imf},
239+
{"__devicelib_imf_fdiv_rn", DeviceLibExt::cl_intel_devicelib_imf},
240+
{"__devicelib_imf_fdiv_ru", DeviceLibExt::cl_intel_devicelib_imf},
241+
{"__devicelib_imf_fdiv_rz", DeviceLibExt::cl_intel_devicelib_imf},
238242
{"__devicelib_imf_float2int_rd", DeviceLibExt::cl_intel_devicelib_imf},
239243
{"__devicelib_imf_float2int_rn", DeviceLibExt::cl_intel_devicelib_imf},
240244
{"__devicelib_imf_float2int_ru", DeviceLibExt::cl_intel_devicelib_imf},
@@ -452,6 +456,10 @@ SYCLDeviceLibFuncMap SDLMap = {
452456
{"__devicelib_imf_dmul_rn", DeviceLibExt::cl_intel_devicelib_imf_fp64},
453457
{"__devicelib_imf_dmul_ru", DeviceLibExt::cl_intel_devicelib_imf_fp64},
454458
{"__devicelib_imf_dmul_rz", DeviceLibExt::cl_intel_devicelib_imf_fp64},
459+
{"__devicelib_imf_ddiv_rd", DeviceLibExt::cl_intel_devicelib_imf_fp64},
460+
{"__devicelib_imf_ddiv_rn", DeviceLibExt::cl_intel_devicelib_imf_fp64},
461+
{"__devicelib_imf_ddiv_ru", DeviceLibExt::cl_intel_devicelib_imf_fp64},
462+
{"__devicelib_imf_ddiv_rz", DeviceLibExt::cl_intel_devicelib_imf_fp64},
455463
{"__devicelib_imf_double2float_rd",
456464
DeviceLibExt::cl_intel_devicelib_imf_fp64},
457465
{"__devicelib_imf_double2float_rn",

sycl/include/sycl/builtins.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ extern __DPCPP_SYCL_EXTERNAL float __imf_fmul_rd(float x, float y);
103103
extern __DPCPP_SYCL_EXTERNAL float __imf_fmul_rn(float x, float y);
104104
extern __DPCPP_SYCL_EXTERNAL float __imf_fmul_ru(float x, float y);
105105
extern __DPCPP_SYCL_EXTERNAL float __imf_fmul_rz(float x, float y);
106+
extern __DPCPP_SYCL_EXTERNAL float __imf_fdiv_rd(float x, float y);
107+
extern __DPCPP_SYCL_EXTERNAL float __imf_fdiv_rn(float x, float y);
108+
extern __DPCPP_SYCL_EXTERNAL float __imf_fdiv_ru(float x, float y);
109+
extern __DPCPP_SYCL_EXTERNAL float __imf_fdiv_rz(float x, float y);
106110
extern __DPCPP_SYCL_EXTERNAL int __imf_float2int_rd(float x);
107111
extern __DPCPP_SYCL_EXTERNAL int __imf_float2int_rn(float x);
108112
extern __DPCPP_SYCL_EXTERNAL int __imf_float2int_ru(float x);
@@ -328,6 +332,10 @@ extern __DPCPP_SYCL_EXTERNAL double __imf_dmul_rd(double x, double y);
328332
extern __DPCPP_SYCL_EXTERNAL double __imf_dmul_rn(double x, double y);
329333
extern __DPCPP_SYCL_EXTERNAL double __imf_dmul_ru(double x, double y);
330334
extern __DPCPP_SYCL_EXTERNAL double __imf_dmul_rz(double x, double y);
335+
extern __DPCPP_SYCL_EXTERNAL double __imf_ddiv_rd(double x, double y);
336+
extern __DPCPP_SYCL_EXTERNAL double __imf_ddiv_rn(double x, double y);
337+
extern __DPCPP_SYCL_EXTERNAL double __imf_ddiv_ru(double x, double y);
338+
extern __DPCPP_SYCL_EXTERNAL double __imf_ddiv_rz(double x, double y);
331339
extern __DPCPP_SYCL_EXTERNAL float __imf_double2float_rd(double x);
332340
extern __DPCPP_SYCL_EXTERNAL float __imf_double2float_rn(double x);
333341
extern __DPCPP_SYCL_EXTERNAL float __imf_double2float_ru(double x);

0 commit comments

Comments
 (0)