@@ -746,36 +746,74 @@ static const std::string gpu_pipeline =
746746 " func.func(convert-parallel-loops-to-gpu),"
747747 // insert-gpu-allocs pass can have client-api = opencl or vulkan args
748748 " func.func(insert-gpu-allocs{in-regions=1}),"
749- // ** imex GPU passes
750- // "drop-regions,"
751- // "canonicalize,"
752- // // "normalize-memrefs,"
753- // // "gpu-decompose-memrefs,"
754- // "func.func(lower-affine),"
755- // "gpu-kernel-outlining,"
756- // "canonicalize,"
757- // "cse,"
758- // // The following set-spirv-* passes can have client-api = opencl or
759- // vulkan
760- // // args
761- // "set-spirv-capabilities{client-api=opencl},"
762- // "gpu.module(set-spirv-abi-attrs{client-api=opencl}),"
763- // "canonicalize,"
764- // "fold-memref-alias-ops,"
765- // "imex-convert-gpu-to-spirv{enable-vc-intrinsic=1},"
766- // "spirv.module(spirv-lower-abi-attrs),"
767- // "spirv.module(spirv-update-vce),"
768- // // "func.func(llvm-request-c-wrappers),"
769- // "serialize-spirv,"
770- // "expand-strided-metadata,"
771- // "lower-affine,"
772- // "convert-gpu-to-gpux,"
773- // "convert-func-to-llvm,"
774- // "convert-math-to-llvm,"
775- // "convert-gpux-to-llvm,"
776- // "finalize-memref-to-llvm,"
777- // "reconcile-unrealized-casts";
778- // ** nv GPU passes
749+ " drop-regions,"
750+ " canonicalize,"
751+ // "normalize-memrefs,"
752+ // "gpu-decompose-memrefs,"
753+ " func.func(lower-affine),"
754+ " gpu-kernel-outlining,"
755+ " canonicalize,"
756+ " cse,"
757+ // The following set-spirv-* passes can have client-api = opencl or vulkan
758+ // args
759+ " set-spirv-capabilities{client-api=opencl},"
760+ " gpu.module(set-spirv-abi-attrs{client-api=opencl}),"
761+ " canonicalize,"
762+ " fold-memref-alias-ops,"
763+ " imex-convert-gpu-to-spirv{enable-vc-intrinsic=1},"
764+ " spirv.module(spirv-lower-abi-attrs),"
765+ " spirv.module(spirv-update-vce),"
766+ // "func.func(llvm-request-c-wrappers),"
767+ " serialize-spirv,"
768+ " expand-strided-metadata,"
769+ " lower-affine,"
770+ " convert-gpu-to-gpux,"
771+ " convert-func-to-llvm,"
772+ " convert-math-to-llvm,"
773+ " convert-gpux-to-llvm,"
774+ " finalize-memref-to-llvm,"
775+ " reconcile-unrealized-casts" ;
776+
777+ static const std::string cuda_pipeline =
778+ " add-gpu-regions,"
779+ " canonicalize,"
780+ " ndarray-dist,"
781+ " func.func(dist-coalesce),"
782+ " func.func(dist-infer-elementwise-cores),"
783+ " convert-dist-to-standard,"
784+ " canonicalize,"
785+ " overlap-comm-and-compute,"
786+ " add-comm-cache-keys,"
787+ " lower-distruntime-to-idtr,"
788+ " convert-ndarray-to-linalg,"
789+ " canonicalize,"
790+ " func.func(tosa-make-broadcastable),"
791+ " func.func(tosa-to-linalg),"
792+ " func.func(tosa-to-tensor),"
793+ " canonicalize,"
794+ " linalg-fuse-elementwise-ops,"
795+ " arith-expand,"
796+ " memref-expand,"
797+ " arith-bufferize,"
798+ " func-bufferize,"
799+ " func.func(empty-tensor-to-alloc-tensor),"
800+ " func.func(scf-bufferize),"
801+ " func.func(tensor-bufferize),"
802+ " func.func(bufferization-bufferize),"
803+ " func.func(linalg-bufferize),"
804+ " func.func(linalg-detensorize),"
805+ " func.func(tensor-bufferize),"
806+ " region-bufferize,"
807+ " canonicalize,"
808+ " func.func(finalizing-bufferize),"
809+ " imex-remove-temporaries,"
810+ " func.func(convert-linalg-to-parallel-loops),"
811+ " func.func(scf-parallel-loop-fusion),"
812+ // is add-outer-parallel-loop needed?
813+ " func.func(imex-add-outer-parallel-loop),"
814+ " func.func(gpu-map-parallel-loops),"
815+ " func.func(convert-parallel-loops-to-gpu),"
816+ " func.func(insert-gpu-allocs{in-regions=1}),"
779817 " func.func(insert-gpu-copy),"
780818 " drop-regions,"
781819 " canonicalize,"
@@ -797,7 +835,9 @@ static const std::string gpu_pipeline =
797835
// _passes holds the string returned by get_text_env for the key below
// (presumably an environment variable that overrides the pass pipeline -
// confirm against get_text_env, which is not visible here).
798836const std::string _passes (get_text_env (" SHARPY_PASSES" ));
// pass_pipeline refers to the pipeline string that is selected: _passes when
// the comparison below holds, otherwise one of the built-in pipeline strings.
// NOTE(review): the literal compared against renders as `" "` in this view;
// presumably it is the empty string in the real file - confirm.
799837static const std::string &pass_pipeline =
// Removed line: the previous selection chose only between gpu_pipeline and
// cpu_pipeline based on useGPU().
800- _passes != " " ? _passes : (useGPU () ? gpu_pipeline : cpu_pipeline);
// Added lines: when useGPU() is true, useCUDA() additionally decides between
// cuda_pipeline and gpu_pipeline; cpu_pipeline is chosen when useGPU() is
// false.
838+ _passes != " " ? _passes
839+ : (useGPU () ? (useCUDA () ? cuda_pipeline : gpu_pipeline)
840+ : cpu_pipeline);
801841
802842JIT::JIT (const std::string &libidtr)
803843 : _context (::mlir::MLIRContext::Threading::DISABLED), _pm (&_context),
@@ -849,23 +889,24 @@ JIT::JIT(const std::string &libidtr)
849889 _crunnerlib = mlirRoot + " /lib/libmlir_c_runner_utils.so" ;
850890 _runnerlib = mlirRoot + " /lib/libmlir_runner_utils.so" ;
851891 if (!std::ifstream (_crunnerlib)) {
852- throw std::runtime_error (" Cannot find libmlir_c_runner_utils.so " );
892+ throw std::runtime_error (" Cannot find lib: " + _crunnerlib );
853893 }
854894 if (!std::ifstream (_runnerlib)) {
855- throw std::runtime_error (" Cannot find libmlir_runner_utils.so " );
895+ throw std::runtime_error (" Cannot find lib: " + _runnerlib );
856896 }
857897
858898 if (useGPU ()) {
859899 auto gpuxlibstr = get_text_env (" SHARPY_GPUX_SO" );
860900 if (!gpuxlibstr.empty ()) {
861901 _gpulib = std::string (gpuxlibstr);
862902 } else {
863- // auto imexRoot = get_text_env("IMEXROOT");
864- // imexRoot = !imexRoot.empty() ? imexRoot : std::string(CMAKE_IMEX_ROOT);
865- // _gpulib = imexRoot + "/lib/liblevel-zero-runtime.so";
866- // _gpulib = imexRoot + "/lib/liblevel-zero-runtime.so";
867- // for nv gpu
868- _gpulib = mlirRoot + " /lib/libmlir_cuda_runtime.so" ;
903+ if (useCUDA ()) {
904+ _gpulib = mlirRoot + " /lib/libmlir_cuda_runtime.so" ;
905+ } else {
906+ auto imexRoot = get_text_env (" IMEXROOT" );
907+ imexRoot = !imexRoot.empty () ? imexRoot : std::string (CMAKE_IMEX_ROOT);
908+ _gpulib = imexRoot + " /lib/liblevel-zero-runtime.so" ;
909+ }
869910 if (!std::ifstream (_gpulib)) {
870911 throw std::runtime_error (" Cannot find lib: " + _gpulib);
871912 }
0 commit comments