From 0c41739b754ca418263d9f9eb70265283cdbc98d Mon Sep 17 00:00:00 2001
From: raghavendrak <raghavendra066@gmail.com>
Date: Sat, 15 Jan 2022 08:14:57 -0600
Subject: [PATCH] fix CUDA compilation issues; placeholders for upload_mdl_init
 and download_mdl_init in init_models.cxx

---
 src/contraction/spctr_offload.cxx | 6 +++---
 src/contraction/spctr_offload.h   | 4 ++--
 src/shared/init_models.cxx        | 2 ++
 src/shared/offload.cu             | 2 +-
 4 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/src/contraction/spctr_offload.cxx b/src/contraction/spctr_offload.cxx
index 8b765ca8..a7f0323a 100644
--- a/src/contraction/spctr_offload.cxx
+++ b/src/contraction/spctr_offload.cxx
@@ -73,8 +73,8 @@ namespace CTF_int {
     return tot_time;
   }
 
-  double spctr_offload::est_time_rec(int nlyr, double nnz_frac_A, double nnz_frac_B, double nnz_frac_C){
-    return rec_ctr->est_time_rec(nlyr, nnz_frac_A, nnz_frac_B, nnz_frac_C) + est_time_fp(nlyr, nnz_frac_A, nnz_frac_B, nnz_frac_C);
+  double spctr_offload::est_time_rec(int nlyr, int nblk_A, int nblk_B, int nblk_C, double nnz_frac_A, double nnz_frac_B, double nnz_frac_C){
+    return rec_ctr->est_time_rec(nlyr, nblk_A, nblk_B, nblk_C, nnz_frac_A, nnz_frac_B, nnz_frac_C) + est_time_fp(nlyr, nnz_frac_A, nnz_frac_B, nnz_frac_C);
   }
 
   int64_t spctr_offload::spmem_fp(double nnz_frac_A, double nnz_frac_B, double nnz_frac_C){
@@ -82,7 +82,7 @@ namespace CTF_int {
   }
 
   int64_t spctr_offload::mem_rec(double nnz_frac_A, double nnz_frac_B, double nnz_frac_C) {
-    return rec_ctr->mem_rec(nnz_frac_A, nnz_frac_B, nnz_frac_C) + spmem_fp(nnz_frac_A, nnz_frac_B, nnz_frac_C);
+    return rec_ctr->spmem_rec(nnz_frac_A, nnz_frac_B, nnz_frac_C) + spmem_fp(nnz_frac_A, nnz_frac_B, nnz_frac_C);
   }
 
   void spctr_offload::run(char * A, int nblk_A, int64_t const * size_blk_A,
diff --git a/src/contraction/spctr_offload.h b/src/contraction/spctr_offload.h
index de6e59d8..bd9aa61c 100644
--- a/src/contraction/spctr_offload.h
+++ b/src/contraction/spctr_offload.h
@@ -42,7 +42,7 @@ namespace CTF_int {
          we need 
        * \return bytes needed
        */
-      int64_t spmem_fp();
+      int64_t spmem_fp(double nnz_frac_A, double nnz_frac_B, double nnz_frac_C);
 
       /**
        * \brief returns the number of bytes of buffer space we need recursively 
@@ -60,7 +60,7 @@ namespace CTF_int {
        * \brief returns the time this kernel will take including calls to rec_ctr
        * \return seconds needed for recursive contraction
        */
-      double est_time_rec(int nlyr, double nnz_frac_A, double nnz_frac_B, double nnz_frac_C);
+      double est_time_rec(int nlyr, int nblk_A, int nblk_B, int nblk_C, double nnz_frac_A, double nnz_frac_B, double nnz_frac_C);
 
       spctr * clone();
 
diff --git a/src/shared/init_models.cxx b/src/shared/init_models.cxx
index cb4cbbd7..8e82bd34 100644
--- a/src/shared/init_models.cxx
+++ b/src/shared/init_models.cxx
@@ -14,6 +14,8 @@ double seq_tsr_ctr_mdl_inr_init[] = {1.0689E-05, 9.4660E-10, 2.1921E-10};
 double seq_tsr_ctr_mdl_off_init[] = {6.2925E-05, 1.7449E-11, 1.7211E-12};
 double seq_tsr_ctr_mdl_cst_inr_init[] = {1.3863E-04, 2.0119E-10, 9.8820E-09};
 double seq_tsr_ctr_mdl_cst_off_init[] = {8.4844E-04, 5.9246E-11, 3.5247E-10};
+double upload_mdl_init[] = {8.4844E-04, 5.9246E-11, 3.5247E-10};
+double download_mdl_init[] = {8.4844E-04, 5.9246E-11, 3.5247E-10};
 double long_contig_transp_mdl_init[] = {1.5117E-04, 1.9091E-09};
 double shrt_contig_transp_mdl_init[] = {7.7643E-05, 6.4347E-12};
 double non_contig_transp_mdl_init[] = {2.6680E-05, 4.6247E-06};
diff --git a/src/shared/offload.cu b/src/shared/offload.cu
index 09a04e4b..fac6da13 100644
--- a/src/shared/offload.cu
+++ b/src/shared/offload.cu
@@ -228,7 +228,7 @@ namespace CTF_int{
     }  
   
     cublasStatus_t status = 
-      cublasDgemm(cuhandle, cuA, cuB, m, n, k, &alpha, 
+      cublasSgemm(cuhandle, cuA, cuB, m, n, k, &alpha,
                   dev_A, lda_A, 
                   dev_B, lda_B, &beta, 
                   dev_C, lda_C);