From d3a8916ff6c49206ced211055e6ac2c94f407e7a Mon Sep 17 00:00:00 2001 From: Thomas Herault Date: Fri, 4 Apr 2025 16:40:42 +0200 Subject: [PATCH] Fix missing corner cases when computing sum-square updates -- patch provided by Mathieu Faverge to import fixes done in Chameleon --- src/cores/core_zgessq.c | 26 ++----- src/cores/core_zhessq.c | 60 ++++++---------- src/cores/core_zsyssq.c | 66 ++++++------------ src/cores/core_ztrssq.c | 35 +++------- src/cores/sumsq_update.h | 141 ++++++++++++++++++++++++++++++++++++++ src/zlange_frb_cyclic.jdf | 28 ++------ src/zlanm2.jdf | 37 ++-------- src/zlansy.jdf | 41 ++--------- 8 files changed, 213 insertions(+), 221 deletions(-) create mode 100644 src/cores/sumsq_update.h diff --git a/src/cores/core_zgessq.c b/src/cores/core_zgessq.c index 49ede4f6..005ed8c0 100644 --- a/src/cores/core_zgessq.c +++ b/src/cores/core_zgessq.c @@ -15,18 +15,7 @@ #include #include #include "common.h" - -#define COMPLEX - -#define UPDATE( __nb, __value ) \ - if (__value != 0. ){ \ - if ( *scale < __value ) { \ - *sumsq = __nb + (*sumsq) * ( *scale / __value ) * ( *scale / __value ); \ - *scale = __value; \ - } else { \ - *sumsq = *sumsq + __nb * ( __value / *scale ) * ( __value / *scale ); \ - } \ - } +#include "sumsq_update.h" /***************************************************************************** * @@ -91,19 +80,16 @@ int CORE_zgessq(int M, int N, double *scale, double *sumsq) { int i, j; - double tmp; double *ptr; for(j=0; j #include #include "common.h" - -#define COMPLEX - -#define UPDATE( __nb, __value ) \ - if (__value != 0. ){ \ - if ( *scale < __value ) { \ - *sumsq = __nb + (*sumsq) * ( *scale / __value ) * ( *scale / __value ); \ - *scale = __value; \ - } else { \ - *sumsq = *sumsq + __nb * ( __value / *scale ) * ( __value / *scale ); \ - } \ - } +#include "sumsq_update.h" /***************************************************************************** * @@ -97,7 +86,6 @@ int CORE_zhessq(PLASMA_enum uplo, int N, double *scale, double *sumsq) { int i, j; - double tmp; double *ptr; if ( uplo == PlasmaUpper ) { @@ -105,23 +93,19 @@ int CORE_zhessq(PLASMA_enum uplo, int N, ptr = (double*) ( A + j * LDA ); for(i=0; i #include #include "common.h" - -#define COMPLEX - -#define UPDATE( __nb, __value ) \ - if (__value != 0. ){ \ - if ( *scale < __value ) { \ - *sumsq = __nb + (*sumsq) * ( *scale / __value ) * ( *scale / __value ); \ - *scale = __value; \ - } else { \ - *sumsq = *sumsq + __nb * ( __value / *scale ) * ( __value / *scale ); \ - } \ - } +#include "sumsq_update.h" /***************************************************************************** * @@ -97,7 +86,6 @@ int CORE_zsyssq(PLASMA_enum uplo, int N, double *scale, double *sumsq) { int i, j; - double tmp; double *ptr; if ( uplo == PlasmaUpper ) { @@ -105,25 +93,20 @@ int CORE_zsyssq(PLASMA_enum uplo, int N, ptr = (double*) ( A + j * LDA ); for(i=0; i #include #include "common.h" - -#define COMPLEX - -#define UPDATE( __nb, __value ) \ - if (__value != 0. ){ \ - if ( *scale < __value ) { \ - *sumsq = __nb + (*sumsq) * ( *scale / __value ) * ( *scale / __value ); \ - *scale = __value; \ - } else { \ - *sumsq = *sumsq + __nb * ( __value / *scale ) * ( __value / *scale ); \ - } \ - } +#include "sumsq_update.h" /***************************************************************************** * @@ -97,7 +86,7 @@ int CORE_ztrssq(PLASMA_enum uplo, PLASMA_enum diag, int M, int N, if ( diag == PlasmaUnit ){ tmp = sqrt( min(M, N) ); - UPDATE( 1., tmp ); + sumsq_update( 1, scale, sumsq, &tmp ); } if (uplo == PlasmaUpper ) { @@ -108,13 +97,11 @@ int CORE_ztrssq(PLASMA_enum uplo, PLASMA_enum diag, int M, int N, imax = min(j+1-idiag, M); for(i=0; i= 0.) { + if ( (*scaleout) < (*scalein) ) { + ratio = *scaleout / *scalein; + *sumsqout = *sumsqin + (*sumsqout) * ratio * ratio; + *scaleout = *scalein; + } else { + if ( (*scaleout) > 0. ){ + ratio = *scalein / *scaleout; + *sumsqout = *sumsqout + (*sumsqin) * ratio * ratio; + } + } + } +} +#elif defined(PRECISION_s) || defined(PRECISION_c) +static inline void +sumsq_update_2( const float *scalein, const float *sumsqin, float *scaleout, float *sumsqout ) +{ + float ratio; + if (*scaleout >= 0.) { + if ( (*scaleout) < (*scalein) ) { + ratio = *scaleout / *scalein; + *sumsqout = *sumsqin + (*sumsqout) * ratio * ratio; + *scaleout = *scalein; + } else { + if ( (*scaleout) > 0. ){ + ratio = *scalein / *scaleout; + *sumsqout = *sumsqout + (*sumsqin) * ratio * ratio; + } + } + } +} +#endif + +#endif /* _sumsq_update_h_ */ diff --git a/src/zlange_frb_cyclic.jdf b/src/zlange_frb_cyclic.jdf index 60ef816f..2e70619d 100644 --- a/src/zlange_frb_cyclic.jdf +++ b/src/zlange_frb_cyclic.jdf @@ -36,6 +36,7 @@ extern "C" %{ #include #include "dplasmajdf.h" #include "parsec/data_dist/matrix/matrix.h" +#include "cores/sumsq_update.h" %} @@ -228,12 +229,7 @@ BODY dW[1] = dA[1]; } else { - if( dW[0] < dA[0] ) { - dW[1] = dA[1] + (dW[1] * (( dW[0] / dA[0] ) * ( dW[0] / dA[0] ))); - dW[0] = dA[0]; - } else { - dW[1] = dW[1] + (dA[1] * (( dA[0] / dW[0] ) * ( dA[0] / dW[0] ))); - } + sumsq_update_2( dA, dA + 1, dW, dW + 1 ); } } else { @@ -314,12 +310,7 @@ BODY dW[0] = 0.; dW[1] = 1.; } - if( dW[0] < dA[0] ) { - dW[1] = dA[1] + (dW[1] * (( dW[0] / dA[0] ) * ( dW[0] / dA[0] ))); - dW[0] = dA[0]; - } else { - dW[1] = dW[1] + (dA[1] * (( dA[0] / dW[0] ) * ( dA[0] / dW[0] ))); - } + sumsq_update_2( dA, dA + 1, dW, dW + 1 ); } break; @@ -370,22 +361,11 @@ BODY case dplasmaFrobeniusNorm: { - double sqr; - if (m > (MT-1)) { dW[0] = 0.; dW[1] = 1.; } - if( dW[0] < dA[0] ) { - sqr = dW[0] / dA[0]; - sqr = sqr * sqr; - dW[1] = dA[1] + sqr * dW[1]; - dW[0] = dA[0]; - } else { - sqr = dA[0] / dW[0]; - sqr = sqr * sqr; - dW[1] = dW[1] + sqr * dA[1]; - } + sumsq_update_2( dA, dA + 1, dW, dW + 1 ); } break; diff --git a/src/zlanm2.jdf b/src/zlanm2.jdf index 0d363ed3..b3988cc3 100644 --- a/src/zlanm2.jdf +++ b/src/zlanm2.jdf @@ -24,6 +24,7 @@ extern "C" %{ #include #include "dplasmajdf.h" #include "parsec/data_dist/matrix/matrix.h" +#include "cores/sumsq_update.h" %} @@ -189,7 +190,6 @@ BODY { double *dA = (double*)A; double *dW = (double*)W; - double sqr; printlog("zlange STEP4(%d, %d)\n", m, n); @@ -198,16 +198,7 @@ BODY dW[1] = 1.; } if (n > 0) { - if( dW[0] < dA[0] ) { - sqr = dW[0] / dA[0]; - sqr = sqr * sqr; - dW[1] = dA[1] + sqr * dW[1]; - dW[0] = dA[0]; - } else { - sqr = dA[0] / dW[0]; - sqr = sqr * sqr; - dW[1] = dW[1] + sqr * dA[1]; - } + sumsq_update_2( dA, dA + 1, dW, dW + 1 ); } } END @@ -585,7 +576,6 @@ BODY { double *dA = (double*)A; double *dW = (double*)W; - double sqr; printlog("norm_sx_step2(%d, %d)\n", i, m); @@ -594,16 +584,7 @@ BODY dW[1] = dA[1]; } else { - if( dW[0] < dA[0] ) { - sqr = dW[0] / dA[0]; - sqr = sqr * sqr; - dW[1] = dA[1] + sqr * dW[1]; - dW[0] = dA[0]; - } else { - sqr = dA[0] / dW[0]; - sqr = sqr * sqr; - dW[1] = dW[1] + sqr * dA[1]; - } + sumsq_update_2( dA, dA + 1, dW, dW + 1 ); } } END @@ -672,7 +653,6 @@ BODY { double *dA = (double*)A; double *dW = (double*)W; - double sqr; printlog("norm_x_step3(%d, %d)\n", i, n); @@ -681,16 +661,7 @@ BODY dW[1] = dA[1]; } else { - if( dW[0] < dA[0] ) { - sqr = dW[0] / dA[0]; - sqr = sqr * sqr; - dW[1] = dA[1] + sqr * dW[1]; - dW[0] = dA[0]; - } else { - sqr = dA[0] / dW[0]; - sqr = sqr * sqr; - dW[1] = dW[1] + sqr * dA[1]; - } + sumsq_update_2( dA, dA + 1, dW, dW + 1 ); } } END diff --git a/src/zlansy.jdf b/src/zlansy.jdf index 0c88a35f..4233dfa2 100644 --- a/src/zlansy.jdf +++ b/src/zlansy.jdf @@ -25,6 +25,7 @@ extern "C" %{ #include #include "dplasmajdf.h" #include "parsec/data_dist/matrix/matrix.h" +#include "cores/sumsq_update.h" #define my_rank_of(m, n) (((parsec_data_collection_t*)(descA))->rank_of((parsec_data_collection_t*)descA, m, n)) %} @@ -259,15 +260,8 @@ BODY dW[1] = 1.; } if(n+Q < PQ) { - if ( dA[0] > 0. ){ - if( dW[0] < dA[0] ) { - dW[1] = dA[1] + (dW[1] * (( dW[0] / dA[0] ) * ( dW[0] / dA[0] ))); - dW[0] = dA[0]; - } else { - dW[1] = dW[1] + (dA[1] * (( dA[0] / dW[0] ) * ( dA[0] / dW[0] ))); - } - } - } + sumsq_update_2( dA, dA + 1, dW, dW + 1 ); + } } else { if ((hadtile+1) == 0) { @@ -317,14 +311,7 @@ BODY *dW = ( *dA > *dW ) ? *dA : *dW; } else if (ntype == dplasmaFrobeniusNorm) { - if ( dA[0] > 0. ){ - if( dW[0] < dA[0] ) { - dW[1] = dA[1] + (dW[1] * (( dW[0] / dA[0] ) * ( dW[0] / dA[0] ))); - dW[0] = dA[0]; - } else { - dW[1] = dW[1] + (dA[1] * (( dA[0] / dW[0] ) * ( dA[0] / dW[0] ))); - } - } + sumsq_update_2( dA, dA + 1, dW, dW + 1 ); } else { cblas_daxpy( tempmm, 1., dA, 1, dW, 1); @@ -376,12 +363,7 @@ BODY dW[0] = 0.; dW[1] = 1.; } - if( dW[0] < dA[0] ) { - dW[1] = dA[1] + (dW[1] * (( dW[0] / dA[0] ) * ( dW[0] / dA[0] ))); - dW[0] = dA[0]; - } else { - dW[1] = dW[1] + (dA[1] * (( dA[0] / dW[0] ) * ( dA[0] / dW[0] ))); - } + sumsq_update_2( dA, dA + 1, dW, dW + 1 ); } else { double maxval = 0; @@ -427,23 +409,12 @@ BODY double *dA = (double*)A; double *dW = (double*)W; if (ntype == dplasmaFrobeniusNorm) { - double sqr; - if ( m > (descA->mt-1)) { dW[0] = 0.; dW[1] = 1.; } if(m > 0) { - if( dW[0] < dA[0] ) { - sqr = dW[0] / dA[0]; - sqr = sqr * sqr; - dW[1] = dA[1] + sqr * dW[1]; - dW[0] = dA[0]; - } else { - sqr = dA[0] / dW[0]; - sqr = sqr * sqr; - dW[1] = dW[1] + sqr * dA[1]; - } + sumsq_update_2( dA, dA + 1, dW, dW + 1 ); } } else { if ( m > (descA->mt-1)) {