@@ -903,7 +903,7 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri
         const float * restrict vy = (const float * restrict) y;
 
         for (uint32_t i = 0; i < n; i++) {
-            rsum += vx[i] * (__fp16) vy[i];
+            rsum += (float) vx[i] * vy[i];
         }
         *s = rsum;
         return;
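
The scalar path now promotes the `__fp16` operand to `float` instead of truncating the `float` operand to `__fp16`, which loses precision and overflows for magnitudes above the fp16 maximum (~65504). A minimal host-side sketch of the difference between the two expressions (plain C, assuming a compiler with the `__fp16` extension such as clang for Hexagon; the values are made up for illustration):

```c
#include <stdio.h>

int main(void) {
    const __fp16 x = (__fp16) 1.0f;
    const float  y = 100000.0f;        // larger than the __fp16 max (~65504)

    float old_way = x * (__fp16) y;    // y is truncated to fp16 first -> inf
    float new_way = (float) x * y;     // x is promoted to fp32 -> 100000.0

    printf("old: %f  new: %f\n", old_way, new_way);
    return 0;
}
```
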
@@ -917,7 +917,7 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri
 
     // for some reason we need volatile here so that the compiler doesn't try anything funky
     volatile HVX_Vector rsum = Q6_V_vsplat_R(0);
-
+    float r_sum_scalar = 0.0f;
     uint32_t i = 0;
 
     for (i = 0; i < nv0; i++) {
@@ -926,31 +926,42 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri
         HVX_Vector x = vx[i];
         HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0
 
-        HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
-        HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
+        // NOTE: volatile is needed here to prevent a compiler optimization;
+        // it seems the compiler cannot guarantee read-after-write ordering otherwise.
+        volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
+        volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
 
         HVX_Vector sum = Q6_Vqf32_vadd_Vqf32Vqf32(hi, lo);
         rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
     }
 
     if (nv1) {
-        HVX_VectorPair yp = vy[i];
+        // HVX_VectorPair yp = vy[i];
 
-        HVX_Vector x = vx[i];
-        HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0
+        // HVX_Vector x = vx[i];
+        // HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0
 
-        if (nv1 >= 32) {
-            HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
-            rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi);
-            nv1 -= 32;
-        }
+        // if (nv1 >= 32) {
+        //     volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
+        //     rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi);
+        //     nv1 -= 32;
+        // }
 
+        // rsum = hvx_vec_qf32_reduce_sum(rsum);
+
+        // if (nv1) {
+        //     volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
+        //     HVX_Vector sum = hvx_vec_qf32_reduce_sum_n(lo, nv1);
+        //     rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
+        // }
+
+        // process the remainder with a scalar loop
         rsum = hvx_vec_qf32_reduce_sum(rsum);
+        const __fp16 * restrict sx = (const __fp16 * restrict) x;
+        const float * restrict sy = (const float * restrict) y;
 
-        if (nv1) {
-            HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
-            HVX_Vector sum = hvx_vec_qf32_reduce_sum_n(lo, nv1);
-            rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
+        for (uint32_t i = nv0 * 64; i < n; i++) {
+            r_sum_scalar += (float) sx[i] * sy[i];
         }
 
         // hvx_vec_dump_fp16("X", x);
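
The remainder handling above replaces the old half-vector tail with a plain scalar loop: the HVX loop consumes `nv0` full blocks of 64 fp16/fp32 elements, and anything left past `nv0 * 64` is accumulated into `r_sum_scalar`. A minimal portable-C sketch of that split (the real vector body uses the HVX qf32 intrinsics shown above; this only mirrors the control flow, and it assumes the `__fp16` extension as the surrounding code does):

```c
#include <stddef.h>

// Split a dot product into full 64-element blocks plus a scalar tail,
// mirroring the structure of the HVX version above.
static float dot_f16_f32_split(const __fp16 *x, const float *y, size_t n) {
    const size_t block = 64;           // elements covered per HVX iteration
    const size_t nv0   = n / block;    // number of full blocks

    float vec_part = 0.0f;             // stands in for the qf32 rsum vector
    for (size_t i = 0; i < nv0 * block; i++) {
        vec_part += (float) x[i] * y[i];
    }

    float scalar_part = 0.0f;          // r_sum_scalar in the code above
    for (size_t i = nv0 * block; i < n; i++) {
        scalar_part += (float) x[i] * y[i];
    }

    return vec_part + scalar_part;     // matches *s = ... + r_sum_scalar below
}
```

For example, with n = 200 (and nv0 = n / 64 as assumed here), the vector loop covers 192 elements and the scalar loop handles the remaining 8.
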
@@ -961,7 +972,7 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri
         rsum = hvx_vec_qf32_reduce_sum(rsum);
     }
 
-    *s = hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(rsum));
+    *s = hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(rsum)) + r_sum_scalar;
 
 #ifdef HTP_DEBUG
     {
@@ -1498,9 +1509,6 @@ static void matmul_f16_f32(struct htp_tensor * restrict src0,
     uint64_t t1, t2;
     t1 = HAP_perf_get_qtimer_count();
 
-    const size_t src0_row_size = sizeof(__fp16) * ne00;
-    const size_t src1_row_size = sizeof(float) * ne10;
-
     assert(ne12 % ne02 == 0);
     assert(ne13 % ne03 == 0);
 
@@ -1510,8 +1518,6 @@ static void matmul_f16_f32(struct htp_tensor * restrict src0,
     // This is the size of the rest of the dimensions of the result
     const uint32_t nr1 = ne1 * ne2 * ne3;
 
-    uint32_t chunk_size = 64;
-
     // distribute the thread work across the inner or outer loop based on which one is larger
     uint32_t nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
     uint32_t nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
@@ -1544,11 +1550,11 @@ static void matmul_f16_f32(struct htp_tensor * restrict src0,
     const uint32_t blck_0 = 64;
     const uint32_t blck_1 = 64;
 
-    float tmp[32];
+    __attribute__((aligned(128))) float tmp[64];
 
     for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
         for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
-            for (uint32_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1++) {
+            for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) {
                 const uint32_t i13 = (ir1 / (ne12 * ne1));
                 const uint32_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
                 const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
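
Two changes in this hunk: `tmp` grows from 32 to 64 floats so it can hold a full `blck_0 = 64` block of results (with the old size, blocks wider than 32 rows would write past the end of the array), and it is 128-byte aligned to match the HVX vector width used by `hvx_copy_fp32_ua`. The `MIN(...)` bound is equivalent to the previous double condition but makes the clamping explicit. A small sketch of the blocked traversal with a clamped tail block (the row range is hypothetical):

```c
#include <stdio.h>

#ifndef MIN
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif

int main(void) {
    enum { BLCK = 64 };
    __attribute__((aligned(128))) float tmp[BLCK];  // scratch for one full block

    const unsigned ir0_start = 0, ir0_end = 150;    // hypothetical row range

    for (unsigned iir0 = ir0_start; iir0 < ir0_end; iir0 += BLCK) {
        const unsigned block_end = MIN(iir0 + BLCK, ir0_end);
        for (unsigned ir0 = iir0; ir0 < block_end; ir0++) {
            tmp[ir0 - iir0] = (float) ir0;          // at most BLCK entries written
        }
        printf("block [%u, %u): first result %.1f\n", iir0, block_end, tmp[0]);
    }
    return 0;
}
```
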
@@ -1561,13 +1567,16 @@ static void matmul_f16_f32(struct htp_tensor * restrict src0,
                 const uint32_t i2 = i12;
                 const uint32_t i3 = i13;
 
-                const uint8_t * restrict src0_row = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03);
+                const uint8_t * restrict src0_base = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03);
                 const uint8_t * restrict src1_col =
-                    (const uint8_t *) src1->data + (i11 + i12 * ne11 + i13 * ne12 * ne11) * src1_row_size;
+                    (const uint8_t *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13);
                 float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
 
-                for (uint32_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0++) {
-                    vec_dot_f16_f32(ne00, &tmp[ir0 - iir0], src0_row + ir0 * src0_row_size, src1_col);
+                const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end);
+                for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) {
+                    // Use nb01 stride for non-contiguous src0 support
+                    const uint8_t * restrict src0_row = src0_base + ir0 * nb01;
+                    vec_dot_f16_f32(ne00, &tmp[ir0 - iir0], src0_row, src1_col);
                 }
 
                 hvx_copy_fp32_ua((uint8_t *) &dst_col[iir0], (uint8_t *) tmp, MIN(iir0 + blck_0, ir0_end) - iir0);
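
With the precomputed `src0_row_size`/`src1_row_size` gone (removed a few hunks up), all addressing now goes through the tensors' byte strides: `nb01` steps between src0 rows and `nb11`/`nb12`/`nb13` locate the src1 column, so views whose rows are not tightly packed still resolve to the right bytes. A small sketch of that stride-based addressing with a simplified tensor struct (the struct and field names here are illustrative, loosely modeled on the `nb*` convention above, not the actual `htp_tensor` layout):

```c
#include <stdint.h>
#include <stddef.h>

// Simplified tensor view with explicit byte strides per dimension,
// loosely following the nb[] convention used above: nb0 = element size,
// nb1 = row stride, nb2/nb3 = strides of the higher dimensions (all in bytes).
struct tensor_view {
    void  *data;
    size_t nb0, nb1, nb2, nb3;
};

// Address of row i1 inside slice (i2, i3). Only strides are used, so this
// works for contiguous tensors as well as transposed or padded views,
// where nb1 != nb0 * ne0 and a precomputed "row size" would be wrong.
static inline uint8_t * row_ptr(const struct tensor_view *t,
                                size_t i1, size_t i2, size_t i3) {
    return (uint8_t *) t->data + i1 * t->nb1 + i2 * t->nb2 + i3 * t->nb3;
}
```

For a fully contiguous f16 matrix, nb1 is simply sizeof(__fp16) * ne00, which is why the old fixed row size worked for contiguous inputs.
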