Skip to content

Commit 644ea07

Browse files
committed
Support for SME1 based strmm_direct kernel for cblas_strmm level 3 API
1 parent 1926847 commit 644ea07

File tree

9 files changed

+362
-0
lines changed

9 files changed

+362
-0
lines changed

common_level3.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,23 @@ void ssymm_direct_alpha_betaLL(BLASLONG M, BLASLONG N,
7272
float beta,
7373
float * R, BLASLONG strideR);
7474

75+
/*
 * Direct single-precision TRMM kernels.
 *
 * The suffix names the variant the CBLAS interface dispatches to:
 *   L   - A is applied from the left
 *   N/T - A is used as-is / transposed
 *   U/L - A is upper / lower triangular
 *   N   - non-unit diagonal
 *
 * The interface passes (m, n, alpha, a, lda, b, ldb) in that order.
 * NOTE(review): B is presumably overwritten in place with the result, since
 * there is no separate output argument -- confirm against the kernel source.
 */
void strmm_direct_LNUN(BLASLONG M, BLASLONG N, float alpha,
                       float *A, BLASLONG strideA,
                       float *B, BLASLONG strideB);
void strmm_direct_LNLN(BLASLONG M, BLASLONG N, float alpha,
                       float *A, BLASLONG strideA,
                       float *B, BLASLONG strideB);
void strmm_direct_LTUN(BLASLONG M, BLASLONG N, float alpha,
                       float *A, BLASLONG strideA,
                       float *B, BLASLONG strideB);
void strmm_direct_LTLN(BLASLONG M, BLASLONG N, float alpha,
                       float *A, BLASLONG strideA,
                       float *B, BLASLONG strideB);
91+
7592
int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
7693

7794
int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,

common_param.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,10 @@ int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
259259
void (*sgemm_direct_alpha_beta) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG);
260260
void (*ssymm_direct_alpha_betaLU) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG);
261261
void (*ssymm_direct_alpha_betaLL) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG);
262+
void (*strmm_direct_LNUN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
263+
void (*strmm_direct_LNLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
264+
void (*strmm_direct_LTUN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
265+
void (*strmm_direct_LTLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
262266
#endif
263267

264268

common_s.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@
5252
#define SGEMM_DIRECT_ALPHA_BETA sgemm_direct_alpha_beta
5353
#define SSYMM_DIRECT_ALPHA_BETA_LU ssymm_direct_alpha_betaLU
5454
#define SSYMM_DIRECT_ALPHA_BETA_LL ssymm_direct_alpha_betaLL
55+
/* Direct STRMM kernels: each macro binds straight to the symbol of the same name. */
#define	STRMM_DIRECT_LNUN		strmm_direct_LNUN
#define	STRMM_DIRECT_LNLN		strmm_direct_LNLN
#define	STRMM_DIRECT_LTUN		strmm_direct_LTUN
#define	STRMM_DIRECT_LTLN		strmm_direct_LTLN
5559

5660
#define SGEMM_ONCOPY sgemm_oncopy
5761
#define SGEMM_OTCOPY sgemm_otcopy
@@ -224,6 +228,10 @@
224228
#define SGEMM_DIRECT_ALPHA_BETA gotoblas -> sgemm_direct_alpha_beta
225229
#define SSYMM_DIRECT_ALPHA_BETA_LU gotoblas -> ssymm_direct_alpha_betaLU
226230
#define SSYMM_DIRECT_ALPHA_BETA_LL gotoblas -> ssymm_direct_alpha_betaLL
231+
/* Direct STRMM kernels: each macro resolves through the gotoblas function table. */
#define	STRMM_DIRECT_LNUN		gotoblas -> strmm_direct_LNUN
#define	STRMM_DIRECT_LNLN		gotoblas -> strmm_direct_LNLN
#define	STRMM_DIRECT_LTUN		gotoblas -> strmm_direct_LTUN
#define	STRMM_DIRECT_LTLN		gotoblas -> strmm_direct_LTLN
227235
#endif
228236

229237
#define SGEMM_ONCOPY gotoblas -> sgemm_oncopy

interface/trsm.c

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,23 @@ void CNAME(enum CBLAS_ORDER order,
355355
return;
356356
}
357357

358+
/*
 * Fast path: hand single-precision TRMM straight to the direct kernels.
 *
 * Taken only for row-major, left-side, non-unit-diagonal calls whose
 * leading dimensions equal the matrix extents (m == lda, n == ldb).
 * Everything else falls through to the generic implementation below.
 *
 * NOTE(review): the TRMM guard assumes this translation unit is also built
 * as TRSM (without TRMM defined); without the guard the TRSM build would run
 * a matrix multiply instead of a triangular solve -- confirm against the
 * interface build rules.
 */
#if defined(TRMM) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
#if defined(ARCH_ARM64) && (defined(USE_STRMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH))
#if defined(DYNAMIC_ARCH)
  /* The braces matter: the runtime SME check must cover the whole fast path,
   * not just the zero-size early return. */
  if (support_sme1())
#endif
  {
    if (args.m == 0 || args.n == 0) return;

    if (order == CblasRowMajor && Diag == CblasNonUnit && Side == CblasLeft && m == lda && n == ldb) {
      if (Trans == CblasNoTrans) {
        (Uplo == CblasUpper ? STRMM_DIRECT_LNUN : STRMM_DIRECT_LNLN)(m, n, alpha, a, lda, b, ldb);
        return;
      }
      if (Trans == CblasTrans) {
        (Uplo == CblasUpper ? STRMM_DIRECT_LTUN : STRMM_DIRECT_LTLN)(m, n, alpha, a, lda, b, ldb);
        return;
      }
      /* Any other Trans value has no direct kernel: fall through to the
       * generic path instead of returning without computing anything. */
    }
  }
#endif
#endif
374+
358375
#endif
359376

360377
if ((args.m == 0) || (args.n == 0)) return;

kernel/CMakeLists.txt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,10 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
237237
if (ZARCH OR (UC_TARGET_CORE MATCHES POWER8) OR (UC_TARGET_CORE MATCHES POWER9) OR (UC_TARGET_CORE MATCHES POWER10))
238238
set(USE_TRMM true)
239239
endif ()
240+
# Direct single-precision TRMM kernels are enabled on ARM64 only.
if (ARM64)
  set(USE_DIRECT_STRMM true)
else ()
  set(USE_DIRECT_STRMM false)
endif ()
240244
set(USE_DIRECT_SGEMM false)
241245
if (X86_64 OR ARM64)
242246
set(USE_DIRECT_SGEMM true)
@@ -279,6 +283,16 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
279283
endif ()
280284
endif()
281285

286+
# Direct single-precision TRMM kernels (ARM64).
#
# All four objects are compiled from the same source file; the variant is
# selected by preprocessor defines, matching the rules in kernel/Makefile.L3:
#   LNUN: UPPER            LNLN: (none)
#   LTUN: TRANSA + UPPER   LTLN: TRANSA
# Passing an empty define list for every variant would build four identical
# objects that differ only in symbol name.
# NOTE(review): assumes the second argument of GenerateNamedObjects is the
# ';'-separated list of compile definitions -- confirm.
if (USE_DIRECT_STRMM AND ARM64)
  set(STRMMDIRECTKERNEL strmm_direct_arm64_sme1.c)
  GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "UPPER" "trmm_direct_LNUN" false "" "" false SINGLE)
  GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LNLN" false "" "" false SINGLE)
  GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "TRANSA;UPPER" "trmm_direct_LTUN" false "" "" false SINGLE)
  GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "TRANSA" "trmm_direct_LTLN" false "" "" false SINGLE)
endif ()
295+
282296
foreach (float_type SINGLE DOUBLE)
283297
string(SUBSTRING ${float_type} 0 1 float_char)
284298
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMMKERNEL}" "" "gemm_kernel" false "" "" false ${float_type})
@@ -454,6 +468,7 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
454468
set(TRMM_KERNEL "${${float_char}GEMMKERNEL}")
455469
endif ()
456470

471+
457472
if (${float_type} STREQUAL "COMPLEX" OR ${float_type} STREQUAL "ZCOMPLEX")
458473

459474
# just enumerate all these. there is an extra define for these indicating which side is a conjugate (e.g. CN NC NN) that I don't really want to work into GenerateCombinationObjects

kernel/Makefile.L3

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ ifeq ($(ARCH), arm64)
5353
USE_TRMM = 1
5454
USE_DIRECT_SGEMM = 1
5555
USE_DIRECT_SSYMM = 1
56+
USE_DIRECT_STRMM = 1
5657
endif
5758

5859
ifeq ($(ARCH), riscv64)
@@ -149,6 +150,18 @@ endif
149150
endif
150151
endif
151152

153+
# Choose the default source file for the direct STRMM kernels, unless
# STRMMDIRECTKERNEL has already been set.  On arm64 the SME1 implementation
# is used; HAVE_SME is additionally set when the target core is ARMV9SME.
ifdef USE_DIRECT_STRMM
ifndef STRMMDIRECTKERNEL
ifeq ($(ARCH), arm64)
STRMMDIRECTKERNEL = strmm_direct_arm64_sme1.c
ifeq ($(TARGET_CORE), ARMV9SME)
HAVE_SME = 1
endif
endif
endif
endif
163+
164+
152165
ifeq ($(BUILD_BFLOAT16), 1)
153166
ifndef BGEMMKERNEL
154167
BGEMM_BETA = ../generic/gemm_beta.c
@@ -240,6 +253,14 @@ SKERNELOBJS += \
240253
endif
241254
endif
242255

256+
# Add the four direct STRMM kernel objects to the single-precision kernel list.
ifdef USE_DIRECT_STRMM
ifeq ($(ARCH), arm64)
SKERNELOBJS += strmm_direct_LNUN$(TSUFFIX).$(SUFFIX)
SKERNELOBJS += strmm_direct_LNLN$(TSUFFIX).$(SUFFIX)
SKERNELOBJS += strmm_direct_LTUN$(TSUFFIX).$(SUFFIX)
SKERNELOBJS += strmm_direct_LTLN$(TSUFFIX).$(SUFFIX)
endif
endif
263+
243264
ifneq "$(or $(BUILD_DOUBLE),$(BUILD_COMPLEX16))" ""
244265
DKERNELOBJS += \
245266
dgemm_beta$(TSUFFIX).$(SUFFIX) \
@@ -1179,6 +1200,23 @@ else
11791200
$(CC) $(CFLAGS) -c -DTRMMKERNEL -UDOUBLE -UCOMPLEX -ULEFT -DTRANSA $< -o $@
11801201
endif
11811202

1203+
1204+
# Build rules for the four direct STRMM objects.  Every object is compiled
# from the same source, $(STRMMDIRECTKERNEL); TRANSA and UPPER are defined or
# undefined per target to select the variant.
ifdef USE_DIRECT_STRMM
ifeq ($(ARCH), arm64)
# Flags shared by all four variants.
STRMM_DIRECT_COMMON = -c -UDOUBLE -UCOMPLEX

# LNUN: TRANSA undefined, UPPER defined
$(KDIR)strmm_direct_LNUN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL)
	$(CC) $(CFLAGS) $(STRMM_DIRECT_COMMON) -UTRANSA -DUPPER $< -o $@

# LNLN: TRANSA undefined, UPPER undefined
$(KDIR)strmm_direct_LNLN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL)
	$(CC) $(CFLAGS) $(STRMM_DIRECT_COMMON) -UTRANSA -UUPPER $< -o $@

# LTUN: TRANSA defined, UPPER defined
$(KDIR)strmm_direct_LTUN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL)
	$(CC) $(CFLAGS) $(STRMM_DIRECT_COMMON) -DTRANSA -DUPPER $< -o $@

# LTLN: TRANSA defined, UPPER undefined
$(KDIR)strmm_direct_LTLN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL)
	$(CC) $(CFLAGS) $(STRMM_DIRECT_COMMON) -DTRANSA -UUPPER $< -o $@
endif
endif
1219+
11821220
$(KDIR)dtrmm_kernel_LN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DTRMMKERNEL)
11831221
ifeq ($(OS), AIX)
11841222
$(CC) $(CFLAGS) -S -DTRMMKERNEL -DDOUBLE -UCOMPLEX -DLEFT -UTRANSA $< -o - > dtrmm_kernel_ln.s

0 commit comments

Comments
 (0)