88
99#include " DeviceCompilation.h"
1010#include " ESIMD.h"
11+ #include " JITBinaryInfo.h"
12+ #include " translation/Translation.h"
1113
14+ #include < Driver/ToolChains/AMDGPU.h>
15+ #include < Driver/ToolChains/Cuda.h>
16+ #include < Driver/ToolChains/LazyDetector.h>
1217#include < clang/Basic/DiagnosticDriver.h>
1318#include < clang/Basic/Version.h>
1419#include < clang/CodeGen/CodeGenAction.h>
1520#include < clang/Driver/Compilation.h>
21+ #include < clang/Driver/Driver.h>
1622#include < clang/Driver/Options.h>
1723#include < clang/Frontend/ChainedDiagnosticConsumer.h>
1824#include < clang/Frontend/CompilerInstance.h>
@@ -178,7 +184,8 @@ class RTCToolActionBase : public ToolAction {
178184 assert (!hasExecuted () && " Action should only be invoked on a single file" );
179185
180186 // Create a compiler instance to handle the actual work.
181- CompilerInstance Compiler (std::move (Invocation), std::move (PCHContainerOps));
187+ CompilerInstance Compiler (std::move (Invocation),
188+ std::move (PCHContainerOps));
182189 Compiler.setFileManager (Files);
183190 // Suppress summary with number of warnings and errors being printed to
184191 // stdout.
@@ -312,7 +319,7 @@ class LLVMDiagnosticWrapper : public llvm::DiagnosticHandler {
312319} // anonymous namespace
313320
314321static void adjustArgs (const InputArgList &UserArgList,
315- const std::string &DPCPPRoot,
322+ const std::string &DPCPPRoot, BinaryFormat Format,
316323 SmallVectorImpl<std::string> &CommandLine) {
317324 DerivedArgList DAL{UserArgList};
318325 const auto &OptTable = getDriverOptTable ();
@@ -325,6 +332,23 @@ static void adjustArgs(const InputArgList &UserArgList,
325332 // unused argument warning.
326333 DAL.AddFlagArg (nullptr , OptTable.getOption (OPT_Qunused_arguments));
327334
335+ if (Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN) {
336+ auto [CPU, Features] =
337+ Translator::getTargetCPUAndFeatureAttrs (nullptr , " " , Format);
338+ (void )Features;
339+ if (Format == BinaryFormat::AMDGCN) {
340+ DAL.AddJoinedArg (nullptr , OptTable.getOption (OPT_fsycl_targets_EQ),
341+ " amdgcn-amd-amdhsa" );
342+ DAL.AddJoinedArg (nullptr , OptTable.getOption (OPT_Xsycl_backend_EQ),
343+ " amdgcn-amd-amdhsa" );
344+ DAL.AddJoinedArg (nullptr , OptTable.getOption (OPT_offload_arch_EQ), CPU);
345+ } else {
346+ DAL.AddJoinedArg (nullptr , OptTable.getOption (OPT_fsycl_targets_EQ),
347+ " nvptx64-nvidia-cuda" );
348+ DAL.AddFlagArg (nullptr , OptTable.getOption (OPT_Xsycl_backend));
349+ DAL.AddJoinedArg (nullptr , OptTable.getOption (OPT_cuda_gpu_arch_EQ), CPU);
350+ }
351+ }
328352 ArgStringList ASL;
329353 for_each (DAL, [&DAL, &ASL](Arg *A) { A->render (DAL, ASL); });
330354 for_each (UserArgList,
@@ -361,10 +385,9 @@ static void setupTool(ClangTool &Tool, const std::string &DPCPPRoot,
361385 });
362386}
363387
364- Expected<std::string>
365- jit_compiler::calculateHash (InMemoryFile SourceFile,
366- View<InMemoryFile> IncludeFiles,
367- const InputArgList &UserArgList) {
388+ Expected<std::string> jit_compiler::calculateHash (
389+ InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
390+ const InputArgList &UserArgList, BinaryFormat Format) {
368391 TimeTraceScope TTS{" calculateHash" };
369392
370393 const std::string &DPCPPRoot = getDPCPPRoot ();
@@ -373,7 +396,7 @@ jit_compiler::calculateHash(InMemoryFile SourceFile,
373396 }
374397
375398 SmallVector<std::string> CommandLine;
376- adjustArgs (UserArgList, DPCPPRoot, CommandLine);
399+ adjustArgs (UserArgList, DPCPPRoot, Format, CommandLine);
377400
378401 FixedCompilationDatabase DB{" ." , CommandLine};
379402 ClangTool Tool{DB, {SourceFile.Path }};
@@ -399,11 +422,10 @@ jit_compiler::calculateHash(InMemoryFile SourceFile,
399422 return createStringError (" Calculating source hash failed" );
400423}
401424
402- Expected<ModuleUPtr>
403- jit_compiler::compileDeviceCode (InMemoryFile SourceFile,
404- View<InMemoryFile> IncludeFiles,
405- const InputArgList &UserArgList,
406- std::string &BuildLog, LLVMContext &Context) {
425+ Expected<ModuleUPtr> jit_compiler::compileDeviceCode (
426+ InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
427+ const InputArgList &UserArgList, std::string &BuildLog,
428+ LLVMContext &Context, BinaryFormat Format) {
407429 TimeTraceScope TTS{" compileDeviceCode" };
408430
409431 const std::string &DPCPPRoot = getDPCPPRoot ();
@@ -412,7 +434,7 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
412434 }
413435
414436 SmallVector<std::string> CommandLine;
415- adjustArgs (UserArgList, DPCPPRoot, CommandLine);
437+ adjustArgs (UserArgList, DPCPPRoot, Format, CommandLine);
416438
417439 FixedCompilationDatabase DB{" ." , CommandLine};
418440 ClangTool Tool{DB, {SourceFile.Path }};
@@ -430,12 +452,22 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
430452 return createStringError (BuildLog);
431453}
432454
433- // This function is a simplified copy of the device library selection process in
434- // `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V target
435- // (no AoT, no third-party GPUs , no native CPU). Keep in sync!
455+ // This function is a simplified copy of the device library selection process
456+ // in `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V or
457+ // GPU target (no AoT, no native CPU). Keep in sync!
436458static bool getDeviceLibraries (const ArgList &Args,
437459 SmallVectorImpl<std::string> &LibraryList,
438- DiagnosticsEngine &Diags) {
460+ DiagnosticsEngine &Diags, BinaryFormat Format) {
461+ // For CUDA/HIP we only need devicelib, early exit here.
462+ if (Format == BinaryFormat::PTX) {
463+ LibraryList.push_back (
464+ Args.MakeArgString (" devicelib-nvptx64-nvidia-cuda.bc" ));
465+ return false ;
466+ } else if (Format == BinaryFormat::AMDGCN) {
467+ LibraryList.push_back (Args.MakeArgString (" devicelib-amdgcn-amd-amdhsa.bc" ));
468+ return false ;
469+ }
470+
439471 struct DeviceLibOptInfo {
440472 StringRef DeviceLibName;
441473 StringRef DeviceLibOption;
@@ -540,7 +572,8 @@ static Expected<ModuleUPtr> loadBitcodeLibrary(StringRef LibPath,
540572
541573Error jit_compiler::linkDeviceLibraries (llvm::Module &Module,
542574 const InputArgList &UserArgList,
543- std::string &BuildLog) {
575+ std::string &BuildLog,
576+ BinaryFormat Format) {
544577 TimeTraceScope TTS{" linkDeviceLibraries" };
545578
546579 const std::string &DPCPPRoot = getDPCPPRoot ();
@@ -555,11 +588,29 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
555588 /* ShouldOwnClient=*/ false );
556589
557590 SmallVector<std::string> LibNames;
558- bool FoundUnknownLib = getDeviceLibraries (UserArgList, LibNames, Diags);
591+ const bool FoundUnknownLib =
592+ getDeviceLibraries (UserArgList, LibNames, Diags, Format);
559593 if (FoundUnknownLib) {
560594 return createStringError (" Could not determine list of device libraries: %s" ,
561595 BuildLog.c_str ());
562596 }
597+ const bool IsCudaHIP =
598+ Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN;
599+ if (IsCudaHIP) {
600+ // Based on the OS and the format decide on the version of libspirv.
601+ // NOTE: this will be problematic if cross-compiling between OSes.
602+ std::string Libclc{" clc/" };
603+ Libclc.append (
604+ #ifdef _WIN32
605+ " remangled-l32-signed_char.libspirv-"
606+ #else
607+ " remangled-l64-signed_char.libspirv-"
608+ #endif
609+ );
610+ Libclc.append (Format == BinaryFormat::PTX ? " nvptx64-nvidia-cuda.bc"
611+ : " amdgcn-amd-amdhsa.bc" );
612+ LibNames.push_back (Libclc);
613+ }
563614
564615 LLVMContext &Context = Module.getContext ();
565616 for (const std::string &LibName : LibNames) {
@@ -577,6 +628,58 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
577628 }
578629 }
579630
631+ // For GPU targets we need to link against vendor provided libdevice.
632+ if (IsCudaHIP) {
633+ Triple T{Module.getTargetTriple ()};
634+ Driver D{(Twine (DPCPPRoot) + " /bin/clang++" ).str (), T.getTriple (), Diags};
635+ auto [CPU, Features] =
636+ Translator::getTargetCPUAndFeatureAttrs (&Module, " " , Format);
637+ (void )Features;
638+ // Helper lambda to link modules.
639+ auto LinkInLib = [&](const StringRef LibDevice) -> Error {
640+ ModuleUPtr LibDeviceModule;
641+ if (auto Error = loadBitcodeLibrary (LibDevice, Context)
642+ .moveInto (LibDeviceModule)) {
643+ return Error;
644+ }
645+ if (Linker::linkModules (Module, std::move (LibDeviceModule),
646+ Linker::LinkOnlyNeeded)) {
647+ return createStringError (" Unable to link libdevice: %s" ,
648+ BuildLog.c_str ());
649+ }
650+ return Error::success ();
651+ };
652+ SmallVector<std::string, 12 > LibDeviceFiles;
653+ if (Format == BinaryFormat::PTX) {
654+ // For NVPTX we can get away with CudaInstallationDetector.
655+ LazyDetector<CudaInstallationDetector> CudaInstallation{D, T,
656+ UserArgList};
657+ auto LibDevice = CudaInstallation->getLibDeviceFile (CPU);
658+ if (LibDevice.empty ()) {
659+ return createStringError (" Unable to find Cuda libdevice" );
660+ }
661+ LibDeviceFiles.push_back (LibDevice);
662+ } else {
663+ // AMDGPU requires entire toolchain in order to provide all common bitcode
664+ // libraries.
665+ clang::driver::toolchains::ROCMToolChain TC (D, T, UserArgList);
666+ auto CommonDeviceLibs = TC.getCommonDeviceLibNames (
667+ UserArgList, CPU, Action::OffloadKind::OFK_SYCL, false );
668+ if (CommonDeviceLibs.empty ()) {
669+ return createStringError (" Unable to find ROCm common device libraries" );
670+ }
671+ for (auto &Lib : CommonDeviceLibs) {
672+ LibDeviceFiles.push_back (Lib.Path );
673+ }
674+ }
675+ for (auto &LibDeviceFile : LibDeviceFiles) {
676+ // llvm::Error converts to false on success.
677+ if (auto Error = LinkInLib (LibDeviceFile)) {
678+ return Error;
679+ }
680+ }
681+ }
682+
580683 return Error::success ();
581684}
582685
0 commit comments