88
99#include " DeviceCompilation.h"
1010#include " ESIMD.h"
11+ #include " JITBinaryInfo.h"
12+ #include " translation/Translation.h"
1113
1214#include < clang/Basic/DiagnosticDriver.h>
1315#include < clang/Basic/Version.h>
2224#include < clang/Frontend/Utils.h>
2325#include < clang/Tooling/CompilationDatabase.h>
2426#include < clang/Tooling/Tooling.h>
27+ #if defined(JIT_SUPPORT_PTX) || defined(JIT_SUPPORT_AMDGCN)
28+ #include < clang/Driver/Driver.h>
29+ #endif
30+ #ifdef JIT_SUPPORT_PTX
31+ #include < Driver/ToolChains/Cuda.h>
32+ #include < Driver/ToolChains/LazyDetector.h>
33+ #elif JIT_SUPPORT_AMDGCN
34+ #include < Driver/ToolChains/AMDGPU.h>
35+ #endif
2536
2637#include < llvm/IR/DiagnosticInfo.h>
2738#include < llvm/IR/DiagnosticPrinter.h>
@@ -178,7 +189,8 @@ class RTCToolActionBase : public ToolAction {
178189 assert (!hasExecuted () && " Action should only be invoked on a single file" );
179190
180191 // Create a compiler instance to handle the actual work.
181- CompilerInstance Compiler (std::move (Invocation), std::move (PCHContainerOps));
192+ CompilerInstance Compiler (std::move (Invocation),
193+ std::move (PCHContainerOps));
182194 Compiler.setFileManager (Files);
183195 // Suppress summary with number of warnings and errors being printed to
184196 // stdout.
@@ -361,10 +373,24 @@ static void setupTool(ClangTool &Tool, const std::string &DPCPPRoot,
361373 });
362374}
363375
364- Expected<std::string>
365- jit_compiler::calculateHash (InMemoryFile SourceFile,
366- View<InMemoryFile> IncludeFiles,
367- const InputArgList &UserArgList) {
// Appends the SYCL offload options that select a third-party GPU target to
// `CommandLine`.
//
// `Format` picks the target: PTX selects the nvptx64-nvidia-cuda triple and
// AMDGCN selects the amdgcn-amd-amdhsa triple. For any other value only the
// `-fsycl` entry is appended; the callers visible in this file guard the call
// so that it is only made for PTX and AMDGCN.
//
// NOTE(review): as rendered here the string literals start with a space (and
// the second translator argument is a single space). This looks like an
// artifact of how the text was extracted -- confirm against the real file.
376+ static void setGPUTarget (BinaryFormat Format,
377+ SmallVector<std::string> &CommandLine) {
// Ask the translator for the CPU/architecture string for this format. A null
// module pointer is passed here (linkDeviceLibraries passes the module being
// linked instead); the second element of the returned pair is not used.
378+ auto [CPU, _] = Translator::getTargetCPUAndFeatureAttrs (nullptr , " " , Format);
379+ CommandLine.push_back (" -fsycl" );
380+ if (Format == BinaryFormat::PTX) {
381+ CommandLine.push_back (" -fsycl-targets=nvptx64-nvidia-cuda" );
// The backend option and the architecture are pushed as two consecutive
// entries; presumably the driver forwards the second one to the device
// backend -- verify against the driver documentation.
382+ CommandLine.push_back (" -Xsycl-target-backend" );
383+ CommandLine.push_back (" --cuda-gpu-arch=" + CPU);
384+ } else if (Format == BinaryFormat::AMDGCN) {
385+ CommandLine.push_back (" -fsycl-targets=amdgcn-amd-amdhsa" );
// Unlike the PTX branch, the triple is attached to the backend option with
// `=` before the architecture entry is pushed.
386+ CommandLine.push_back (" -Xsycl-target-backend=amdgcn-amd-amdhsa" );
387+ CommandLine.push_back (" --offload-arch=" + CPU);
388+ }
389+ }
390+
391+ Expected<std::string> jit_compiler::calculateHash (
392+ InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
393+ const InputArgList &UserArgList, BinaryFormat Format) {
368394 TimeTraceScope TTS{" calculateHash" };
369395
370396 const std::string &DPCPPRoot = getDPCPPRoot ();
@@ -373,6 +399,9 @@ jit_compiler::calculateHash(InMemoryFile SourceFile,
373399 }
374400
375401 SmallVector<std::string> CommandLine;
402+ if (Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN) {
403+ setGPUTarget (Format, CommandLine);
404+ }
376405 adjustArgs (UserArgList, DPCPPRoot, CommandLine);
377406
378407 FixedCompilationDatabase DB{" ." , CommandLine};
@@ -399,11 +428,10 @@ jit_compiler::calculateHash(InMemoryFile SourceFile,
399428 return createStringError (" Calculating source hash failed" );
400429}
401430
402- Expected<ModuleUPtr>
403- jit_compiler::compileDeviceCode (InMemoryFile SourceFile,
404- View<InMemoryFile> IncludeFiles,
405- const InputArgList &UserArgList,
406- std::string &BuildLog, LLVMContext &Context) {
431+ Expected<ModuleUPtr> jit_compiler::compileDeviceCode (
432+ InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
433+ const InputArgList &UserArgList, std::string &BuildLog,
434+ LLVMContext &Context, BinaryFormat Format) {
407435 TimeTraceScope TTS{" compileDeviceCode" };
408436
409437 const std::string &DPCPPRoot = getDPCPPRoot ();
@@ -412,6 +440,9 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
412440 }
413441
414442 SmallVector<std::string> CommandLine;
443+ if (Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN) {
444+ setGPUTarget (Format, CommandLine);
445+ }
415446 adjustArgs (UserArgList, DPCPPRoot, CommandLine);
416447
417448 FixedCompilationDatabase DB{" ." , CommandLine};
@@ -430,12 +461,22 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
430461 return createStringError (BuildLog);
431462}
432463
433- // This function is a simplified copy of the device library selection process in
434- // `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V target
435- // (no AoT, no third-party GPUs , no native CPU). Keep in sync!
464+ // This function is a simplified copy of the device library selection process
465+ // in `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V or
466+ // GPU target (no AoT, no native CPU). Keep in sync!
436467static bool getDeviceLibraries (const ArgList &Args,
437468 SmallVectorImpl<std::string> &LibraryList,
438- DiagnosticsEngine &Diags) {
469+ DiagnosticsEngine &Diags, BinaryFormat Format) {
470+ // For CUDA/HIP we only need devicelib, early exit here.
471+ if (Format == BinaryFormat::PTX) {
472+ LibraryList.push_back (
473+ Args.MakeArgString (" devicelib-nvptx64-nvidia-cuda.bc" ));
474+ return false ;
475+ } else if (Format == BinaryFormat::AMDGCN) {
476+ LibraryList.push_back (Args.MakeArgString (" devicelib-amdgcn-amd-amdhsa.bc" ));
477+ return false ;
478+ }
479+
439480 struct DeviceLibOptInfo {
440481 StringRef DeviceLibName;
441482 StringRef DeviceLibOption;
@@ -540,7 +581,8 @@ static Expected<ModuleUPtr> loadBitcodeLibrary(StringRef LibPath,
540581
541582Error jit_compiler::linkDeviceLibraries (llvm::Module &Module,
542583 const InputArgList &UserArgList,
543- std::string &BuildLog) {
584+ std::string &BuildLog,
585+ BinaryFormat Format) {
544586 TimeTraceScope TTS{" linkDeviceLibraries" };
545587
546588 const std::string &DPCPPRoot = getDPCPPRoot ();
@@ -555,11 +597,29 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
555597 /* ShouldOwnClient=*/ false );
556598
557599 SmallVector<std::string> LibNames;
558- bool FoundUnknownLib = getDeviceLibraries (UserArgList, LibNames, Diags);
600+ const bool FoundUnknownLib =
601+ getDeviceLibraries (UserArgList, LibNames, Diags, Format);
559602 if (FoundUnknownLib) {
560603 return createStringError (" Could not determine list of device libraries: %s" ,
561604 BuildLog.c_str ());
562605 }
606+ const bool IsGPUTarget =
607+ Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN;
608+ if (IsGPUTarget) {
609+ // Based on the OS and the format decide on the version of libspirv.
610+ // NOTE: this will be problematic if cross-compiling between OSes.
611+ std::string Libclc{" clc/" };
612+ Libclc.append (
613+ #ifdef _WIN32
614+ " remangled-l32-signed_char.libspirv-"
615+ #else
616+ " remangled-l64-signed_char.libspirv-"
617+ #endif
618+ );
619+ Libclc.append (Format == BinaryFormat::PTX ? " nvptx64-nvidia-cuda.bc"
620+ : " amdgcn-amd-amdhsa.bc" );
621+ LibNames.push_back (Libclc);
622+ }
563623
564624 LLVMContext &Context = Module.getContext ();
565625 for (const std::string &LibName : LibNames) {
@@ -577,6 +637,57 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
577637 }
578638 }
579639
640+ // For GPU targets we need to link against vendor provided libdevice.
641+ if (IsGPUTarget) {
642+ Triple T{Module.getTargetTriple ()};
643+ Driver D{(Twine (DPCPPRoot) + " /bin/clang++" ).str (), T.getTriple (), Diags};
644+ auto [CPU, _] =
645+ Translator::getTargetCPUAndFeatureAttrs (&Module, " " , Format);
646+ // Helper lambda to link modules.
647+ auto LinkInLib = [&](const StringRef LibDevice) -> Error {
648+ ModuleUPtr LibDeviceModule;
649+ if (auto Error = loadBitcodeLibrary (LibDevice, Context)
650+ .moveInto (LibDeviceModule)) {
651+ return Error;
652+ }
653+ if (Linker::linkModules (Module, std::move (LibDeviceModule),
654+ Linker::LinkOnlyNeeded)) {
655+ return createStringError (" Unable to link libdevice: %s" ,
656+ BuildLog.c_str ());
657+ }
658+ return Error::success ();
659+ };
660+ SmallVector<std::string, 12 > LibDeviceFiles;
661+ #ifdef JIT_SUPPORT_PTX
662+ // For NVPTX we can get away with CudaInstallationDetector.
663+ LazyDetector<CudaInstallationDetector> CudaInstallation{D, T, UserArgList};
664+ auto LibDevice = CudaInstallation->getLibDeviceFile (CPU);
665+ if (LibDevice.empty ()) {
666+ return createStringError (" Unable to find Cuda libdevice" );
667+ }
668+ LibDeviceFiles.push_back (LibDevice);
669+ #elif JIT_SUPPORT_AMDGCN
670+ // AMDGPU requires entire toolchain in order to provide all common bitcode
671+ // libraries.
672+ clang::driver::toolchains::ROCMToolChain TC (D, T, UserArgList);
673+ auto CommonDeviceLibs = TC.getCommonDeviceLibNames (
674+ UserArgList, CPU, Action::OffloadKind::OFK_SYCL, false );
675+ if (CommonDeviceLibs.empty ()) {
676+ return createStringError (" Unable to find ROCm common device libraries" );
677+ }
678+ for (auto &Lib : CommonDeviceLibs) {
679+ LibDeviceFiles.push_back (Lib.Path );
680+ }
681+ #endif
682+ for (auto &LibDeviceFile : LibDeviceFiles) {
683+ auto Res = LinkInLib (LibDeviceFile);
684+ // llvm::Error converts to false on success.
685+ if (Res) {
686+ return Res;
687+ }
688+ }
689+ }
690+
580691 return Error::success ();
581692}
582693
0 commit comments