Skip to content

Commit e396ad0

Browse files
committed
Merge remote-tracking branch 'origin/topic/ssymm_direct_sme1' into topic/strmm_direct_sme1
2 parents ea2890d + 1926847 commit e396ad0

12 files changed

+456
-189
lines changed

common_level3.h

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,19 +59,32 @@ void sgemm_direct_alpha_beta(BLASLONG M, BLASLONG N, BLASLONG K,
5959
float beta,
6060
float * R, BLASLONG strideR);
6161

62-
void strmm_direct_LNUN(BLASLONG M, BLASLONG N, BLASLONG K,
62+
void ssymm_direct_alpha_betaLU(BLASLONG M, BLASLONG N,
63+
float alpha,
64+
float * A, BLASLONG strideA,
65+
float * B, BLASLONG strideB,
66+
float beta,
67+
float * R, BLASLONG strideR);
68+
void ssymm_direct_alpha_betaLL(BLASLONG M, BLASLONG N,
69+
float alpha,
70+
float * A, BLASLONG strideA,
71+
float * B, BLASLONG strideB,
72+
float beta,
73+
float * R, BLASLONG strideR);
74+
75+
void strmm_direct_LNUN(BLASLONG M, BLASLONG N,
6376
float alpha,
6477
float * A, BLASLONG strideA,
6578
float * B, BLASLONG strideB);
66-
void strmm_direct_LNLN(BLASLONG M, BLASLONG N, BLASLONG K,
79+
void strmm_direct_LNLN(BLASLONG M, BLASLONG N,
6780
float alpha,
6881
float * A, BLASLONG strideA,
6982
float * B, BLASLONG strideB);
70-
void strmm_direct_LTUN(BLASLONG M, BLASLONG N, BLASLONG K,
83+
void strmm_direct_LTUN(BLASLONG M, BLASLONG N,
7184
float alpha,
7285
float * A, BLASLONG strideA,
7386
float * B, BLASLONG strideB);
74-
void strmm_direct_LTLN(BLASLONG M, BLASLONG N, BLASLONG K,
87+
void strmm_direct_LTLN(BLASLONG M, BLASLONG N,
7588
float alpha,
7689
float * A, BLASLONG strideA,
7790
float * B, BLASLONG strideB);

common_param.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,12 @@ int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
257257
#ifdef ARCH_ARM64
258258
void (*sgemm_direct) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG , float *, BLASLONG , float * , BLASLONG);
259259
void (*sgemm_direct_alpha_beta) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG);
260+
void (*ssymm_direct_alpha_betaLU) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG);
261+
void (*ssymm_direct_alpha_betaLL) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG);
262+
void (*strmm_direct_LNUN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
263+
void (*strmm_direct_LNLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
264+
void (*strmm_direct_LTUN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
265+
void (*strmm_direct_LTLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
260266
#endif
261267

262268

@@ -307,12 +313,6 @@ int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
307313
int (*strsm_oltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
308314
#endif
309315
#if (BUILD_SINGLE==1)
310-
#ifdef ARCH_ARM64
311-
void (*strmm_direct_LNUN) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
312-
void (*strmm_direct_LNLN) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
313-
void (*strmm_direct_LTUN) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
314-
void (*strmm_direct_LTLN) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
315-
#endif
316316
int (*strmm_kernel_RN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG);
317317
int (*strmm_kernel_RT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG);
318318
int (*strmm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG);

common_s.h

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@
5050
#define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant
5151
#define SGEMM_DIRECT sgemm_direct
5252
#define SGEMM_DIRECT_ALPHA_BETA sgemm_direct_alpha_beta
53+
#define SSYMM_DIRECT_ALPHA_BETA_LU ssymm_direct_alpha_betaLU
54+
#define SSYMM_DIRECT_ALPHA_BETA_LL ssymm_direct_alpha_betaLL
55+
#define STRMM_DIRECT_LNUN strmm_direct_LNUN
56+
#define STRMM_DIRECT_LNLN strmm_direct_LNLN
57+
#define STRMM_DIRECT_LTUN strmm_direct_LTUN
58+
#define STRMM_DIRECT_LTLN strmm_direct_LTLN
5359

5460
#define SGEMM_ONCOPY sgemm_oncopy
5561
#define SGEMM_OTCOPY sgemm_otcopy
@@ -62,11 +68,6 @@
6268
#define SGEMM_ITCOPY sgemm_itcopy
6369
#endif
6470

65-
#define STRMM_DIRECT_LNUN strmm_direct_LNUN
66-
#define STRMM_DIRECT_LNLN strmm_direct_LNLN
67-
#define STRMM_DIRECT_LTUN strmm_direct_LTUN
68-
#define STRMM_DIRECT_LTLN strmm_direct_LTLN
69-
7071
#define STRMM_OUNUCOPY strmm_ounucopy
7172
#define STRMM_OUNNCOPY strmm_ounncopy
7273
#define STRMM_OUTUCOPY strmm_outucopy
@@ -225,20 +226,19 @@
225226
#define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant
226227
#define SGEMM_DIRECT gotoblas -> sgemm_direct
227228
#define SGEMM_DIRECT_ALPHA_BETA gotoblas -> sgemm_direct_alpha_beta
229+
#define SSYMM_DIRECT_ALPHA_BETA_LU gotoblas -> ssymm_direct_alpha_betaLU
230+
#define SSYMM_DIRECT_ALPHA_BETA_LL gotoblas -> ssymm_direct_alpha_betaLL
231+
#define STRMM_DIRECT_LNUN gotoblas -> strmm_direct_LNUN
232+
#define STRMM_DIRECT_LNLN gotoblas -> strmm_direct_LNLN
233+
#define STRMM_DIRECT_LTUN gotoblas -> strmm_direct_LTUN
234+
#define STRMM_DIRECT_LTLN gotoblas -> strmm_direct_LTLN
228235
#endif
229236

230237
#define SGEMM_ONCOPY gotoblas -> sgemm_oncopy
231238
#define SGEMM_OTCOPY gotoblas -> sgemm_otcopy
232239
#define SGEMM_INCOPY gotoblas -> sgemm_incopy
233240
#define SGEMM_ITCOPY gotoblas -> sgemm_itcopy
234241

235-
#ifdef ARCH_ARM64
236-
#define STRMM_DIRECT_LNUN gotoblas -> strmm_direct_LNUN
237-
#define STRMM_DIRECT_LNLN gotoblas -> strmm_direct_LNLN
238-
#define STRMM_DIRECT_LTUN gotoblas -> strmm_direct_LTUN
239-
#define STRMM_DIRECT_LTLN gotoblas -> strmm_direct_LTLN
240-
#endif
241-
242242
#define STRMM_OUNUCOPY gotoblas -> strmm_ounucopy
243243
#define STRMM_OUTUCOPY gotoblas -> strmm_outucopy
244244
#define STRMM_OLNUCOPY gotoblas -> strmm_olnucopy

interface/symm.c

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,24 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo,
371371
return;
372372
}
373373

374+
#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
375+
#if defined(ARCH_ARM64) && (defined(USE_SSYMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH))
376+
#if defined(DYNAMIC_ARCH)
377+
if (support_sme1())
378+
#endif
379+
if (args.m == 0 || args.n == 0) return;
380+
if (order == CblasRowMajor && m == lda && n == ldb && n == ldc)
381+
{
382+
if (Side == CblasLeft && Uplo == CblasUpper) {
383+
SSYMM_DIRECT_ALPHA_BETA_LU(m, n, alpha, a, lda, b, ldb, beta, c, ldc); return;
384+
}
385+
else if (Side == CblasLeft && Uplo == CblasLower) {
386+
SSYMM_DIRECT_ALPHA_BETA_LL(m, n, alpha, a, lda, b, ldb, beta, c, ldc); return;
387+
}
388+
}
389+
#endif
390+
#endif
391+
374392
#endif
375393

376394
if (args.m == 0 || args.n == 0) return;

interface/trsm.c

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -255,21 +255,6 @@ void CNAME(enum CBLAS_ORDER order,
255255
#endif
256256

257257
PRINT_DEBUG_CNAME;
258-
#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
259-
#if defined(ARCH_ARM64) && (defined(USE_STRMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH))
260-
#if defined(DYNAMIC_ARCH)
261-
if (support_sme1())
262-
#endif
263-
if (order == CblasRowMajor && Diag == CblasNonUnit && Side == CblasLeft) {
264-
if (Trans == CblasNoTrans) {
265-
(Uplo == CblasUpper ? STRMM_DIRECT_LNUN : STRMM_DIRECT_LNLN)(m, n, m, alpha, a, lda, b, ldb);
266-
} else if (Trans == CblasTrans) {
267-
(Uplo == CblasUpper ? STRMM_DIRECT_LTUN : STRMM_DIRECT_LTLN)(m, n, m, alpha, a, lda, b, ldb);
268-
}
269-
return;
270-
}
271-
#endif
272-
#endif
273258

274259
args.a = (void *)a;
275260
args.b = (void *)b;
@@ -370,6 +355,23 @@ void CNAME(enum CBLAS_ORDER order,
370355
return;
371356
}
372357

358+
#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
359+
#if defined(ARCH_ARM64) && (defined(USE_STRMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH))
360+
#if defined(DYNAMIC_ARCH)
361+
if (support_sme1())
362+
#endif
363+
if (args.m == 0 || args.n == 0) return;
364+
if (order == CblasRowMajor && Diag == CblasNonUnit && Side == CblasLeft) {
365+
if (Trans == CblasNoTrans) {
366+
(Uplo == CblasUpper ? STRMM_DIRECT_LNUN : STRMM_DIRECT_LNLN)(m, n, alpha, a, lda, b, ldb);
367+
} else if (Trans == CblasTrans) {
368+
(Uplo == CblasUpper ? STRMM_DIRECT_LTUN : STRMM_DIRECT_LTLN)(m, n, alpha, a, lda, b, ldb);
369+
}
370+
return;
371+
}
372+
#endif
373+
#endif
374+
373375
#endif
374376

375377
if ((args.m == 0) || (args.n == 0)) return;

kernel/CMakeLists.txt

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,10 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
245245
if (X86_64 OR ARM64)
246246
set(USE_DIRECT_SGEMM true)
247247
endif()
248+
set(USE_DIRECT_SSYMM false)
249+
if (ARM64)
250+
set(USE_DIRECT_SSYMM true)
251+
endif()
248252
if (UC_TARGET_CORE MATCHES ARMV9SME)
249253
set (HAVE_SME true)
250254
endif ()
@@ -271,6 +275,24 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
271275
endif ()
272276
endif()
273277

278+
if (USE_DIRECT_SSYMM)
279+
if (ARM64)
280+
set (SSYMMDIRECTKERNEL_ALPHA_BETA ssymm_direct_alpha_beta_arm64_sme1.c)
281+
GenerateNamedObjects("${KERNELDIR}/${SSYMMDIRECTKERNEL_ALPHA_BETA}" "" "symm_direct_alpha_betaLU" false "" "" false SINGLE)
282+
GenerateNamedObjects("${KERNELDIR}/${SSYMMDIRECTKERNEL_ALPHA_BETA}" "" "symm_direct_alpha_betaLL" false "" "" false SINGLE)
283+
endif ()
284+
endif()
285+
286+
if (USE_DIRECT_STRMM)
287+
if (ARM64)
288+
set (STRMMDIRECTKERNEL strmm_direct_arm64_sme1.c)
289+
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LNUN" false "" "" false SINGLE)
290+
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LNLN" false "" "" false SINGLE)
291+
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LTUN" false "" "" false SINGLE)
292+
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LTLN" false "" "" false SINGLE)
293+
endif ()
294+
endif ()
295+
274296
foreach (float_type SINGLE DOUBLE)
275297
string(SUBSTRING ${float_type} 0 1 float_char)
276298
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMMKERNEL}" "" "gemm_kernel" false "" "" false ${float_type})
@@ -446,16 +468,6 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
446468
set(TRMM_KERNEL "${${float_char}GEMMKERNEL}")
447469
endif ()
448470

449-
if (USE_DIRECT_STRMM)
450-
set (STRMMDIRECTKERNEL strmm_direct_arm64_sme1.c)
451-
set (STRMMDIRECTPREKERNEL strmm_direct_arm64_sme1_preprocess.c)
452-
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LNUN" false "" "" false SINGLE)
453-
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LNLN" false "" "" false SINGLE)
454-
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LTUN" false "" "" false SINGLE)
455-
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LTLN" false "" "" false SINGLE)
456-
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTPREKERNEL}" "" "trmm_direct_sme1_preprocess_UN" false "" "" false SINGLE)
457-
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTPREKERNEL}" "" "trmm_direct_sme1_preprocess_LN" false "" "" false SINGLE)
458-
endif ()
459471

460472
if (${float_type} STREQUAL "COMPLEX" OR ${float_type} STREQUAL "ZCOMPLEX")
461473

kernel/Makefile.L3

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ endif
5252
ifeq ($(ARCH), arm64)
5353
USE_TRMM = 1
5454
USE_DIRECT_SGEMM = 1
55+
USE_DIRECT_SSYMM = 1
5556
USE_DIRECT_STRMM = 1
5657
endif
5758

@@ -138,9 +139,28 @@ endif
138139
endif
139140
endif
140141

142+
ifdef USE_DIRECT_SSYMM
143+
ifndef SSYMMDIRECTKERNEL_ALPHA_BETA
144+
ifeq ($(ARCH), arm64)
145+
ifeq ($(TARGET_CORE), ARMV9SME)
146+
HAVE_SME = 1
147+
endif
148+
SSYMMDIRECTKERNEL_ALPHA_BETA = ssymm_direct_alpha_beta_arm64_sme1.c
149+
endif
150+
endif
151+
endif
152+
141153
ifdef USE_DIRECT_STRMM
154+
ifndef STRMMDIRECTKERNEL
155+
ifeq ($(ARCH), arm64)
156+
ifeq ($(TARGET_CORE), ARMV9SME)
157+
HAVE_SME = 1
158+
endif
142159
STRMMDIRECTKERNEL = strmm_direct_arm64_sme1.c
143160
endif
161+
endif
162+
endif
163+
144164

145165
ifeq ($(BUILD_BFLOAT16), 1)
146166
ifndef BGEMMKERNEL
@@ -225,6 +245,22 @@ endif
225245
endif
226246
endif
227247

248+
ifdef USE_DIRECT_SSYMM
249+
ifeq ($(ARCH), arm64)
250+
SKERNELOBJS += \
251+
ssymm_direct_alpha_betaLU$(TSUFFIX).$(SUFFIX) \
252+
ssymm_direct_alpha_betaLL$(TSUFFIX).$(SUFFIX)
253+
endif
254+
endif
255+
256+
ifdef USE_DIRECT_STRMM
257+
ifeq ($(ARCH), arm64)
258+
SKERNELOBJS += \
259+
strmm_direct_LNUN$(TSUFFIX).$(SUFFIX) strmm_direct_LNLN$(TSUFFIX).$(SUFFIX) \
260+
strmm_direct_LTUN$(TSUFFIX).$(SUFFIX) strmm_direct_LTLN$(TSUFFIX).$(SUFFIX)
261+
endif
262+
endif
263+
228264
ifneq "$(or $(BUILD_DOUBLE),$(BUILD_COMPLEX16))" ""
229265
DKERNELOBJS += \
230266
dgemm_beta$(TSUFFIX).$(SUFFIX) \
@@ -291,15 +327,6 @@ SBLASOBJS += \
291327
strsm_kernel_RN$(TSUFFIX).$(SUFFIX) strsm_kernel_RT$(TSUFFIX).$(SUFFIX)
292328
endif
293329

294-
ifdef USE_DIRECT_STRMM
295-
SBLASOBJS += \
296-
strmm_direct_LNUN$(TSUFFIX).$(SUFFIX) strmm_direct_LNLN$(TSUFFIX).$(SUFFIX) \
297-
strmm_direct_LTUN$(TSUFFIX).$(SUFFIX) strmm_direct_LTLN$(TSUFFIX).$(SUFFIX)
298-
SBLASOBJS += \
299-
strmm_direct_sme1_preprocess_UN$(TSUFFIX).$(SUFFIX) \
300-
strmm_direct_sme1_preprocess_LN$(TSUFFIX).$(SUFFIX)
301-
endif
302-
303330
ifneq "$(or $(BUILD_DOUBLE),$(BUILD_COMPLEX16))" ""
304331
DBLASOBJS += \
305332
dgemm_beta$(TSUFFIX).$(SUFFIX) \
@@ -996,6 +1023,15 @@ endif
9961023
endif
9971024
endif
9981025

1026+
ifdef USE_DIRECT_SSYMM
1027+
ifeq ($(ARCH), arm64)
1028+
$(KDIR)ssymm_direct_alpha_betaLU$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYMMDIRECTKERNEL_ALPHA_BETA)
1029+
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DLEFT -DUPPER $< -o $@
1030+
$(KDIR)ssymm_direct_alpha_betaLL$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYMMDIRECTKERNEL_ALPHA_BETA)
1031+
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DLEFT -DLOWER $< -o $@
1032+
endif
1033+
endif
1034+
9991035
ifeq ($(BUILD_BFLOAT16), 1)
10001036
$(KDIR)bgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(BGEMMKERNEL)
10011037
$(CC) $(CFLAGS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX $< -o $@
@@ -1166,6 +1202,7 @@ endif
11661202

11671203

11681204
ifdef USE_DIRECT_STRMM
1205+
ifeq ($(ARCH), arm64)
11691206
$(KDIR)strmm_direct_LNUN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL)
11701207
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UTRANSA -DUPPER $< -o $@
11711208

@@ -1177,12 +1214,7 @@ $(KDIR)strmm_direct_LTUN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL)
11771214

11781215
$(KDIR)strmm_direct_LTLN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL)
11791216
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DTRANSA -UUPPER $< -o $@
1180-
1181-
$(KDIR)strmm_direct_sme1_preprocess_UN$(TSUFFIX).$(SUFFIX) :
1182-
$(CC) $(CFLAGS) -c $(KERNELDIR)/strmm_direct_arm64_sme1_preprocess.c -UDOUBLE -UCOMPLEX -DUPPER $< -o $@
1183-
1184-
$(KDIR)strmm_direct_sme1_preprocess_LN$(TSUFFIX).$(SUFFIX) :
1185-
$(CC) $(CFLAGS) -c $(KERNELDIR)/strmm_direct_arm64_sme1_preprocess.c -UDOUBLE -UCOMPLEX -UUPPER $< -o $@
1217+
endif
11861218
endif
11871219

11881220
$(KDIR)dtrmm_kernel_LN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DTRMMKERNEL)

0 commit comments

Comments
 (0)