-
Notifications
You must be signed in to change notification settings - Fork 767
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SYCL free function namespace support #17585
base: sycl
Are you sure you want to change the base?
Changes from all commits
b305e94
5ff7b2c
12df05e
bd15ef7
65ef84f
117e97a
c53312a
190ac32
0858e70
32113db
2183658
5dd5894
2499963
a5edf00
2bb7c21
b3bd8ae
c755d26
64906f8
c976900
4cebf72
157b39a
e3ff53a
bdb3967
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -15,10 +15,9 @@ | |||
#include "clang/AST/QualTypeNames.h" | ||||
#include "clang/AST/RecordLayout.h" | ||||
#include "clang/AST/RecursiveASTVisitor.h" | ||||
#include "clang/AST/TemplateArgumentVisitor.h" | ||||
#include "clang/AST/Mangle.h" | ||||
#include "clang/AST/SYCLKernelInfo.h" | ||||
#include "clang/AST/StmtSYCL.h" | ||||
#include "clang/AST/TemplateArgumentVisitor.h" | ||||
#include "clang/AST/TypeOrdering.h" | ||||
#include "clang/AST/TypeVisitor.h" | ||||
#include "clang/Analysis/CallGraph.h" | ||||
|
@@ -27,7 +26,6 @@ | |||
#include "clang/Basic/Diagnostic.h" | ||||
#include "clang/Basic/TargetInfo.h" | ||||
#include "clang/Basic/Version.h" | ||||
#include "clang/AST/SYCLKernelInfo.h" | ||||
#include "clang/Sema/Attr.h" | ||||
#include "clang/Sema/Initialization.h" | ||||
#include "clang/Sema/ParsedAttr.h" | ||||
|
@@ -6425,6 +6423,120 @@ static void EmitPragmaDiagnosticPop(raw_ostream &O) { | |||
O << "\n"; | ||||
} | ||||
|
||||
template <typename BeforeFn, typename AfterFn> | ||||
static void PrintNSHelper(BeforeFn Before, AfterFn After, raw_ostream &OS, | ||||
const DeclContext *DC) { | ||||
if (DC->isTranslationUnit()) | ||||
return; | ||||
|
||||
const auto *CurDecl = cast<Decl>(DC); | ||||
// Ensure we are in the canonical version, so that we know we have the 'full' | ||||
// name of the thing. | ||||
CurDecl = CurDecl->getCanonicalDecl(); | ||||
|
||||
// We are intentionally skipping linkage decls and record decls. Namespaces | ||||
// can appear in a linkage decl, but not a record decl, so we don't have to | ||||
// worry about the names getting messed up from that. We handle record decls | ||||
// later when printing the name of the thing. | ||||
const auto *NS = dyn_cast<NamespaceDecl>(CurDecl); | ||||
if (NS) | ||||
Before(OS, NS); | ||||
|
||||
if (const DeclContext *NewDC = CurDecl->getDeclContext()) | ||||
PrintNSHelper(Before, After, OS, NewDC); | ||||
|
||||
if (NS) | ||||
After(OS, NS); | ||||
} | ||||
|
||||
static void PrintNamespaces(raw_ostream &OS, const DeclContext *DC, | ||||
bool isPrintNamesOnly = false) { | ||||
PrintNSHelper([](raw_ostream &OS, const NamespaceDecl *NS) {}, | ||||
[isPrintNamesOnly](raw_ostream &OS, const NamespaceDecl *NS) { | ||||
if (!isPrintNamesOnly) { | ||||
if (NS->isInline()) | ||||
OS << "inline "; | ||||
OS << "namespace "; | ||||
} | ||||
if (!NS->isAnonymousNamespace()) { | ||||
OS << NS->getName(); | ||||
if (isPrintNamesOnly) | ||||
OS << "::"; | ||||
else | ||||
OS << " "; | ||||
} | ||||
if (!isPrintNamesOnly) { | ||||
OS << "{\n"; | ||||
} | ||||
}, | ||||
OS, DC); | ||||
} | ||||
|
||||
static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) { | ||||
PrintNSHelper( | ||||
[](raw_ostream &OS, const NamespaceDecl *NS) { | ||||
OS << "} // "; | ||||
if (NS->isInline()) | ||||
OS << "inline "; | ||||
|
||||
OS << "namespace "; | ||||
if (!NS->isAnonymousNamespace()) | ||||
OS << NS->getName(); | ||||
|
||||
OS << '\n'; | ||||
}, | ||||
[](raw_ostream &OS, const NamespaceDecl *NS) {}, OS, DC); | ||||
} | ||||
|
||||
class FreeFunctionPrinter { | ||||
raw_ostream &O; | ||||
const PrintingPolicy &Policy; | ||||
bool NSInserted = false; | ||||
|
||||
public: | ||||
FreeFunctionPrinter(raw_ostream &O, const PrintingPolicy &Policy) | ||||
: O(O), Policy(Policy) {} | ||||
|
||||
/// Emits the function declaration of a free function. | ||||
/// \param FD The function declaration to print. | ||||
/// \param Args The arguments of the function. | ||||
void printFreeFunctionDeclaration(const FunctionDecl *FD, | ||||
const std::string &Args) { | ||||
const DeclContext *DC = FD->getDeclContext(); | ||||
if (DC) { | ||||
// if function in namespace, print namespace | ||||
if (isa<NamespaceDecl>(DC)) { | ||||
PrintNamespaces(O, FD); | ||||
// Set flag to print closing braces for namespaces and namespace in shim | ||||
// function | ||||
NSInserted = true; | ||||
} | ||||
O << FD->getReturnType().getAsString() << " "; | ||||
O << FD->getNameAsString() << "(" << Args << ");"; | ||||
if (NSInserted) { | ||||
O << "\n"; | ||||
PrintNSClosingBraces(O, FD); | ||||
} | ||||
O << "\n"; | ||||
} | ||||
} | ||||
|
||||
/// Emits free function shim function. | ||||
/// \param FD The function declaration to print. | ||||
/// \param ShimCounter The counter for the shim function. | ||||
/// \param ParmList The parameter list of the function. | ||||
void printFreeFunctionShim(const FunctionDecl *FD, const unsigned ShimCounter, | ||||
const std::string &ParmList) { | ||||
// Generate a shim function that returns the address of the free function. | ||||
O << "static constexpr auto __sycl_shim" << ShimCounter << "() {\n"; | ||||
O << " return (void (*)(" << ParmList << "))"; | ||||
|
||||
if (NSInserted) | ||||
PrintNamespaces(O, FD, /*isPrintNamesOnly=*/true); | ||||
O << FD->getIdentifier()->getName().data(); | ||||
} | ||||
}; | ||||
|
||||
void SYCLIntegrationHeader::emit(raw_ostream &O) { | ||||
O << "// This is auto-generated SYCL integration header.\n"; | ||||
O << "\n"; | ||||
|
@@ -6713,16 +6825,25 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) { | |||
if (K.SyclKernel->getLanguageLinkage() == CLanguageLinkage) | ||||
O << "extern \"C\" "; | ||||
std::string ParmList; | ||||
std::string ParmListWithNames; | ||||
bool FirstParam = true; | ||||
Policy.SuppressDefaultTemplateArgs = false; | ||||
Policy.PrintCanonicalTypes = true; | ||||
llvm::raw_string_ostream ParmListWithNamesOstream{ParmListWithNames}; | ||||
for (ParmVarDecl *Param : K.SyclKernel->parameters()) { | ||||
if (FirstParam) | ||||
FirstParam = false; | ||||
else | ||||
else { | ||||
ParmList += ", "; | ||||
ParmListWithNamesOstream << ", "; | ||||
} | ||||
Policy.SuppressTagKeyword = true; | ||||
Param->getType().print(ParmListWithNamesOstream, Policy); | ||||
Policy.SuppressTagKeyword = false; | ||||
ParmListWithNamesOstream << " " << Param->getNameAsString(); | ||||
Comment on lines
+6840
to
+6843
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you please elaborate what this particular addition is trying to achieve, why the previous code did not suffice and how does it relate to namespace printing? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In general, |
||||
ParmList += Param->getType().getCanonicalType().getAsString(Policy); | ||||
} | ||||
ParmListWithNamesOstream.flush(); | ||||
FunctionTemplateDecl *FTD = K.SyclKernel->getPrimaryTemplate(); | ||||
Policy.PrintCanonicalTypes = false; | ||||
Policy.SuppressDefinition = true; | ||||
|
@@ -6756,17 +6877,16 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) { | |||
// template arguments that match default template arguments while printing | ||||
// template-ids, even if the source code doesn't reference them. | ||||
Policy.EnforceDefaultTemplateArgs = true; | ||||
FreeFunctionPrinter FFPrinter(O, Policy); | ||||
// bool NSInserted{false}; | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||
if (FTD) { | ||||
FTD->print(O, Policy); | ||||
O << ";\n"; | ||||
} else { | ||||
K.SyclKernel->print(O, Policy); | ||||
FFPrinter.printFreeFunctionDeclaration(K.SyclKernel, ParmListWithNames); | ||||
} | ||||
O << ";\n"; | ||||
|
||||
// Generate a shim function that returns the address of the free function. | ||||
O << "static constexpr auto __sycl_shim" << ShimCounter << "() {\n"; | ||||
O << " return (void (*)(" << ParmList << "))" | ||||
<< K.SyclKernel->getIdentifier()->getName().data(); | ||||
FFPrinter.printFreeFunctionShim(K.SyclKernel, ShimCounter, ParmList); | ||||
if (FTD) { | ||||
const TemplateArgumentList *TAL = | ||||
K.SyclKernel->getTemplateSpecializationArgs(); | ||||
|
@@ -6935,61 +7055,6 @@ bool SYCLIntegrationFooter::emit(StringRef IntHeaderName) { | |||
return emit(Out); | ||||
} | ||||
|
||||
template <typename BeforeFn, typename AfterFn> | ||||
static void PrintNSHelper(BeforeFn Before, AfterFn After, raw_ostream &OS, | ||||
const DeclContext *DC) { | ||||
if (DC->isTranslationUnit()) | ||||
return; | ||||
|
||||
const auto *CurDecl = cast<Decl>(DC); | ||||
// Ensure we are in the canonical version, so that we know we have the 'full' | ||||
// name of the thing. | ||||
CurDecl = CurDecl->getCanonicalDecl(); | ||||
|
||||
// We are intentionally skipping linkage decls and record decls. Namespaces | ||||
// can appear in a linkage decl, but not a record decl, so we don't have to | ||||
// worry about the names getting messed up from that. We handle record decls | ||||
// later when printing the name of the thing. | ||||
const auto *NS = dyn_cast<NamespaceDecl>(CurDecl); | ||||
if (NS) | ||||
Before(OS, NS); | ||||
|
||||
if (const DeclContext *NewDC = CurDecl->getDeclContext()) | ||||
PrintNSHelper(Before, After, OS, NewDC); | ||||
|
||||
if (NS) | ||||
After(OS, NS); | ||||
} | ||||
|
||||
static void PrintNamespaces(raw_ostream &OS, const DeclContext *DC) { | ||||
PrintNSHelper([](raw_ostream &OS, const NamespaceDecl *NS) {}, | ||||
[](raw_ostream &OS, const NamespaceDecl *NS) { | ||||
if (NS->isInline()) | ||||
OS << "inline "; | ||||
OS << "namespace "; | ||||
if (!NS->isAnonymousNamespace()) | ||||
OS << NS->getName() << " "; | ||||
OS << "{\n"; | ||||
}, | ||||
OS, DC); | ||||
} | ||||
|
||||
static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) { | ||||
PrintNSHelper( | ||||
[](raw_ostream &OS, const NamespaceDecl *NS) { | ||||
OS << "} // "; | ||||
if (NS->isInline()) | ||||
OS << "inline "; | ||||
|
||||
OS << "namespace "; | ||||
if (!NS->isAnonymousNamespace()) | ||||
OS << NS->getName(); | ||||
|
||||
OS << '\n'; | ||||
}, | ||||
[](raw_ostream &OS, const NamespaceDecl *NS) {}, OS, DC); | ||||
} | ||||
|
||||
static std::string EmitShim(raw_ostream &OS, unsigned &ShimCounter, | ||||
const std::string &LastShim, | ||||
const NamespaceDecl *AnonNS) { | ||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -86,18 +86,20 @@ foo(Arg1<int> arg) { | |
// CHECK-NEXT: template <typename T, typename, int a, typename, typename ...TS> struct Arg; | ||
// CHECK-NEXT: } | ||
|
||
// CHECK: void ns::simple(ns::Arg<char, int, 12, ns::notatuple>); | ||
// CHECK-NEXT: static constexpr auto __sycl_shim1() { | ||
// CHECK-NEXT: return (void (*)(struct ns::Arg<char, int, 12, struct ns::notatuple>))simple; | ||
// CHECK: namespace ns { | ||
// CHECK-NEXT: void simple(ns::Arg<char, int, 12, ns::notatuple> ); | ||
// CHECK-NEXT: } // namespace ns | ||
// CHECK: static constexpr auto __sycl_shim1() { | ||
// CHECK-NEXT: return (void (*)(struct ns::Arg<char, int, 12, struct ns::notatuple>))ns::simple; | ||
// CHECK-NEXT: } | ||
|
||
// CHECK: Forward declarations of kernel and its argument types: | ||
// CHECK: namespace ns { | ||
// CHECK: namespace ns1 { | ||
// CHECK-NEXT: template <typename A> class hasDefaultArg; | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: }} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we add more FE tests? Like with various combinations of namespaces around the free function kernel declaration? With inline namespace and not. Can we also test that codegen and semantic analysis is ok for free function kernels defined in a (maybe nested) namespace? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added new e2e tests to check any possible namespaces: nested, anonymous, inline etc. Is it enough or add in these tests too? New tests do not check header directly but if something is emitted wrong, they will fail. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. SYCL compiler is complicated and has a lot of components. If we only have a e2e test and it fails suddenly (for example after a pulldown), it may take a while to identify which component now has a problem. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point, I did not see that these tests are units. Added new checks. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well they are "unit" because clang is enormous itself and has its own unit tests but in terms of SYCL compiler we can consider them as unit tests. |
||
|
||
// CHECK: void simple1(ns::Arg<ns::ns1::hasDefaultArg<ns::notatuple>, int, 12, ns::notatuple>); | ||
// CHECK: void simple1(ns::Arg<ns::ns1::hasDefaultArg<ns::notatuple>, int, 12, ns::notatuple> ); | ||
// CHECK-NEXT: static constexpr auto __sycl_shim2() { | ||
// CHECK-NEXT: return (void (*)(struct ns::Arg<class ns::ns1::hasDefaultArg<struct ns::notatuple>, int, 12, struct ns::notatuple>))simple1; | ||
// CHECK-NEXT: } | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should that be a canonical type too?
Does it work if there is a typedef/alias involved in function argument types? Can be in a default value of a NTTP of a kernel argument type, for example.
If it doesn't, we need a canonical type here.