Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions common_level3.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,23 @@ void ssymm_direct_alpha_betaLL(BLASLONG M, BLASLONG N,
float beta,
float * R, BLASLONG strideR);

void strmm_direct_LNUN(BLASLONG M, BLASLONG N,
float alpha,
float * A, BLASLONG strideA,
float * B, BLASLONG strideB);
void strmm_direct_LNLN(BLASLONG M, BLASLONG N,
float alpha,
float * A, BLASLONG strideA,
float * B, BLASLONG strideB);
void strmm_direct_LTUN(BLASLONG M, BLASLONG N,
float alpha,
float * A, BLASLONG strideA,
float * B, BLASLONG strideB);
void strmm_direct_LTLN(BLASLONG M, BLASLONG N,
float alpha,
float * A, BLASLONG strideA,
float * B, BLASLONG strideB);

int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);

int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
Expand Down
4 changes: 4 additions & 0 deletions common_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,10 @@ int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
void (*sgemm_direct_alpha_beta) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG);
void (*ssymm_direct_alpha_betaLU) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG);
void (*ssymm_direct_alpha_betaLL) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG);
void (*strmm_direct_LNUN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
void (*strmm_direct_LNLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
void (*strmm_direct_LTUN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
void (*strmm_direct_LTLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
#endif


Expand Down
8 changes: 8 additions & 0 deletions common_s.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@
#define SGEMM_DIRECT_ALPHA_BETA sgemm_direct_alpha_beta
#define SSYMM_DIRECT_ALPHA_BETA_LU ssymm_direct_alpha_betaLU
#define SSYMM_DIRECT_ALPHA_BETA_LL ssymm_direct_alpha_betaLL
#define STRMM_DIRECT_LNUN strmm_direct_LNUN
#define STRMM_DIRECT_LNLN strmm_direct_LNLN
#define STRMM_DIRECT_LTUN strmm_direct_LTUN
#define STRMM_DIRECT_LTLN strmm_direct_LTLN

#define SGEMM_ONCOPY sgemm_oncopy
#define SGEMM_OTCOPY sgemm_otcopy
Expand Down Expand Up @@ -224,6 +228,10 @@
#define SGEMM_DIRECT_ALPHA_BETA gotoblas -> sgemm_direct_alpha_beta
#define SSYMM_DIRECT_ALPHA_BETA_LU gotoblas -> ssymm_direct_alpha_betaLU
#define SSYMM_DIRECT_ALPHA_BETA_LL gotoblas -> ssymm_direct_alpha_betaLL
#define STRMM_DIRECT_LNUN gotoblas -> strmm_direct_LNUN
#define STRMM_DIRECT_LNLN gotoblas -> strmm_direct_LNLN
#define STRMM_DIRECT_LTUN gotoblas -> strmm_direct_LTUN
#define STRMM_DIRECT_LTLN gotoblas -> strmm_direct_LTLN
#endif

#define SGEMM_ONCOPY gotoblas -> sgemm_oncopy
Expand Down
17 changes: 17 additions & 0 deletions interface/trsm.c
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,23 @@ void CNAME(enum CBLAS_ORDER order,
return;
}

#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
#if defined(ARCH_ARM64) && (defined(USE_STRMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH))
#if defined(DYNAMIC_ARCH)
if (support_sme1())
#endif
if (args.m == 0 || args.n == 0) return;
if (order == CblasRowMajor && Diag == CblasNonUnit && Side == CblasLeft && m == lda && n == ldb) {
if (Trans == CblasNoTrans) {
(Uplo == CblasUpper ? STRMM_DIRECT_LNUN : STRMM_DIRECT_LNLN)(m, n, alpha, a, lda, b, ldb);
} else if (Trans == CblasTrans) {
(Uplo == CblasUpper ? STRMM_DIRECT_LTUN : STRMM_DIRECT_LTLN)(m, n, alpha, a, lda, b, ldb);
}
return;
}
#endif
#endif

#endif

if ((args.m == 0) || (args.n == 0)) return;
Expand Down
15 changes: 15 additions & 0 deletions kernel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,10 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
if (ZARCH OR (UC_TARGET_CORE MATCHES POWER8) OR (UC_TARGET_CORE MATCHES POWER9) OR (UC_TARGET_CORE MATCHES POWER10))
set(USE_TRMM true)
endif ()
set(USE_DIRECT_STRMM false)
if (ARM64)
set(USE_DIRECT_STRMM true)
endif()
set(USE_DIRECT_SGEMM false)
if (X86_64 OR ARM64)
set(USE_DIRECT_SGEMM true)
Expand Down Expand Up @@ -279,6 +283,16 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
endif ()
endif()

if (USE_DIRECT_STRMM)
if (ARM64)
set (STRMMDIRECTKERNEL strmm_direct_arm64_sme1.c)
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LNUN" false "" "" false SINGLE)
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LNLN" false "" "" false SINGLE)
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LTUN" false "" "" false SINGLE)
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LTLN" false "" "" false SINGLE)
endif ()
endif ()

foreach (float_type SINGLE DOUBLE)
string(SUBSTRING ${float_type} 0 1 float_char)
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMMKERNEL}" "" "gemm_kernel" false "" "" false ${float_type})
Expand Down Expand Up @@ -454,6 +468,7 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
set(TRMM_KERNEL "${${float_char}GEMMKERNEL}")
endif ()


if (${float_type} STREQUAL "COMPLEX" OR ${float_type} STREQUAL "ZCOMPLEX")

# just enumerate all these. there is an extra define for these indicating which side is a conjugate (e.g. CN NC NN) that I don't really want to work into GenerateCombinationObjects
Expand Down
38 changes: 38 additions & 0 deletions kernel/Makefile.L3
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ ifeq ($(ARCH), arm64)
USE_TRMM = 1
USE_DIRECT_SGEMM = 1
USE_DIRECT_SSYMM = 1
USE_DIRECT_STRMM = 1
endif

ifeq ($(ARCH), riscv64)
Expand Down Expand Up @@ -149,6 +150,18 @@ endif
endif
endif

ifdef USE_DIRECT_STRMM
ifndef STRMMDIRECTKERNEL
ifeq ($(ARCH), arm64)
ifeq ($(TARGET_CORE), ARMV9SME)
HAVE_SME = 1
endif
STRMMDIRECTKERNEL = strmm_direct_arm64_sme1.c
endif
endif
endif


ifeq ($(BUILD_BFLOAT16), 1)
ifndef BGEMMKERNEL
BGEMM_BETA = ../generic/gemm_beta.c
Expand Down Expand Up @@ -240,6 +253,14 @@ SKERNELOBJS += \
endif
endif

ifdef USE_DIRECT_STRMM
ifeq ($(ARCH), arm64)
SKERNELOBJS += \
strmm_direct_LNUN$(TSUFFIX).$(SUFFIX) strmm_direct_LNLN$(TSUFFIX).$(SUFFIX) \
strmm_direct_LTUN$(TSUFFIX).$(SUFFIX) strmm_direct_LTLN$(TSUFFIX).$(SUFFIX)
endif
endif

ifneq "$(or $(BUILD_DOUBLE),$(BUILD_COMPLEX16))" ""
DKERNELOBJS += \
dgemm_beta$(TSUFFIX).$(SUFFIX) \
Expand Down Expand Up @@ -1179,6 +1200,23 @@ else
$(CC) $(CFLAGS) -c -DTRMMKERNEL -UDOUBLE -UCOMPLEX -ULEFT -DTRANSA $< -o $@
endif


ifdef USE_DIRECT_STRMM
ifeq ($(ARCH), arm64)
$(KDIR)strmm_direct_LNUN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL)
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UTRANSA -DUPPER $< -o $@

$(KDIR)strmm_direct_LNLN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL)
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UTRANSA -UUPPER $< -o $@

$(KDIR)strmm_direct_LTUN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL)
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DTRANSA -DUPPER $< -o $@

$(KDIR)strmm_direct_LTLN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL)
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DTRANSA -UUPPER $< -o $@
endif
endif

$(KDIR)dtrmm_kernel_LN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DTRMMKERNEL)
ifeq ($(OS), AIX)
$(CC) $(CFLAGS) -S -DTRMMKERNEL -DDOUBLE -UCOMPLEX -DLEFT -UTRANSA $< -o - > dtrmm_kernel_ln.s
Expand Down
Loading
Loading