- 
                Notifications
    You must be signed in to change notification settings 
- Fork 794
SYCL free function namespace support #17585
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b305e94
              5ff7b2c
              12df05e
              bd15ef7
              65ef84f
              117e97a
              c53312a
              190ac32
              0858e70
              32113db
              2183658
              5dd5894
              2499963
              a5edf00
              2bb7c21
              b3bd8ae
              c755d26
              64906f8
              c976900
              4cebf72
              157b39a
              e3ff53a
              bdb3967
              cae876a
              afb8b59
              09c7786
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -15,10 +15,9 @@ | |
| #include "clang/AST/QualTypeNames.h" | ||
| #include "clang/AST/RecordLayout.h" | ||
| #include "clang/AST/RecursiveASTVisitor.h" | ||
| #include "clang/AST/TemplateArgumentVisitor.h" | ||
| #include "clang/AST/Mangle.h" | ||
| #include "clang/AST/SYCLKernelInfo.h" | ||
| #include "clang/AST/StmtSYCL.h" | ||
| #include "clang/AST/TemplateArgumentVisitor.h" | ||
| #include "clang/AST/TypeOrdering.h" | ||
| #include "clang/AST/TypeVisitor.h" | ||
| #include "clang/Analysis/CallGraph.h" | ||
|  | @@ -27,7 +26,6 @@ | |
| #include "clang/Basic/Diagnostic.h" | ||
| #include "clang/Basic/TargetInfo.h" | ||
| #include "clang/Basic/Version.h" | ||
| #include "clang/AST/SYCLKernelInfo.h" | ||
| #include "clang/Sema/Attr.h" | ||
| #include "clang/Sema/Initialization.h" | ||
| #include "clang/Sema/ParsedAttr.h" | ||
|  | @@ -6425,6 +6423,120 @@ static void EmitPragmaDiagnosticPop(raw_ostream &O) { | |
| O << "\n"; | ||
| } | ||
|  | ||
| template <typename BeforeFn, typename AfterFn> | ||
| static void PrintNSHelper(BeforeFn Before, AfterFn After, raw_ostream &OS, | ||
| const DeclContext *DC) { | ||
| if (DC->isTranslationUnit()) | ||
| return; | ||
|  | ||
| const auto *CurDecl = cast<Decl>(DC); | ||
| // Ensure we are in the canonical version, so that we know we have the 'full' | ||
| // name of the thing. | ||
| CurDecl = CurDecl->getCanonicalDecl(); | ||
|  | ||
| // We are intentionally skipping linkage decls and record decls. Namespaces | ||
| // can appear in a linkage decl, but not a record decl, so we don't have to | ||
| // worry about the names getting messed up from that. We handle record decls | ||
| // later when printing the name of the thing. | ||
| const auto *NS = dyn_cast<NamespaceDecl>(CurDecl); | ||
| if (NS) | ||
| Before(OS, NS); | ||
|  | ||
| if (const DeclContext *NewDC = CurDecl->getDeclContext()) | ||
| PrintNSHelper(Before, After, OS, NewDC); | ||
|  | ||
| if (NS) | ||
| After(OS, NS); | ||
| } | ||
|  | ||
| static void PrintNamespaces(raw_ostream &OS, const DeclContext *DC, | ||
| bool isPrintNamesOnly = false) { | ||
| PrintNSHelper([](raw_ostream &OS, const NamespaceDecl *NS) {}, | ||
| [isPrintNamesOnly](raw_ostream &OS, const NamespaceDecl *NS) { | ||
| if (!isPrintNamesOnly) { | ||
| if (NS->isInline()) | ||
| OS << "inline "; | ||
| OS << "namespace "; | ||
| } | ||
| if (!NS->isAnonymousNamespace()) { | ||
| OS << NS->getName(); | ||
| if (isPrintNamesOnly) | ||
| OS << "::"; | ||
| else | ||
| OS << " "; | ||
| } | ||
| if (!isPrintNamesOnly) { | ||
| OS << "{\n"; | ||
| } | ||
| }, | ||
| OS, DC); | ||
| } | ||
|  | ||
| static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) { | ||
| PrintNSHelper( | ||
| [](raw_ostream &OS, const NamespaceDecl *NS) { | ||
| OS << "} // "; | ||
| if (NS->isInline()) | ||
| OS << "inline "; | ||
|  | ||
| OS << "namespace "; | ||
| if (!NS->isAnonymousNamespace()) | ||
| OS << NS->getName(); | ||
|  | ||
| OS << '\n'; | ||
| }, | ||
| [](raw_ostream &OS, const NamespaceDecl *NS) {}, OS, DC); | ||
| } | ||
|  | ||
| class FreeFunctionPrinter { | ||
| raw_ostream &O; | ||
| const PrintingPolicy &Policy; | ||
| bool NSInserted = false; | ||
|  | ||
| public: | ||
| FreeFunctionPrinter(raw_ostream &O, const PrintingPolicy &Policy) | ||
| : O(O), Policy(Policy) {} | ||
|  | ||
| /// Emits the function declaration of a free function. | ||
| /// \param FD The function declaration to print. | ||
| /// \param Args The arguments of the function. | ||
| void printFreeFunctionDeclaration(const FunctionDecl *FD, | ||
| const std::string &Args) { | ||
| const DeclContext *DC = FD->getDeclContext(); | ||
| if (DC) { | ||
| // if function in namespace, print namespace | ||
| if (isa<NamespaceDecl>(DC)) { | ||
| PrintNamespaces(O, FD); | ||
| // Set flag to print closing braces for namespaces and namespace in shim | ||
| // function | ||
| NSInserted = true; | ||
| } | ||
| O << FD->getReturnType().getAsString() << " "; | ||
| O << FD->getNameAsString() << "(" << Args << ");"; | ||
| if (NSInserted) { | ||
| O << "\n"; | ||
| PrintNSClosingBraces(O, FD); | ||
| } | ||
| O << "\n"; | ||
| } | ||
| } | ||
|  | ||
| /// Emits free function shim function. | ||
| /// \param FD The function declaration to print. | ||
| /// \param ShimCounter The counter for the shim function. | ||
| /// \param ParmList The parameter list of the function. | ||
| void printFreeFunctionShim(const FunctionDecl *FD, const unsigned ShimCounter, | ||
| const std::string &ParmList) { | ||
| // Generate a shim function that returns the address of the free function. | ||
| O << "static constexpr auto __sycl_shim" << ShimCounter << "() {\n"; | ||
| O << " return (void (*)(" << ParmList << "))"; | ||
|  | ||
| if (NSInserted) | ||
| PrintNamespaces(O, FD, /*isPrintNamesOnly=*/true); | ||
| O << FD->getIdentifier()->getName().data(); | ||
| } | ||
| }; | ||
|  | ||
| void SYCLIntegrationHeader::emit(raw_ostream &O) { | ||
| O << "// This is auto-generated SYCL integration header.\n"; | ||
| O << "\n"; | ||
|  | @@ -6713,16 +6825,25 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) { | |
| if (K.SyclKernel->getLanguageLinkage() == CLanguageLinkage) | ||
| O << "extern \"C\" "; | ||
| std::string ParmList; | ||
| std::string ParmListWithNames; | ||
| bool FirstParam = true; | ||
| Policy.SuppressDefaultTemplateArgs = false; | ||
| Policy.PrintCanonicalTypes = true; | ||
| llvm::raw_string_ostream ParmListWithNamesOstream{ParmListWithNames}; | ||
| for (ParmVarDecl *Param : K.SyclKernel->parameters()) { | ||
| if (FirstParam) | ||
| FirstParam = false; | ||
| else | ||
| else { | ||
| ParmList += ", "; | ||
| ParmListWithNamesOstream << ", "; | ||
| } | ||
| Policy.SuppressTagKeyword = true; | ||
| Param->getType().print(ParmListWithNamesOstream, Policy); | ||
|         
                  dklochkov-emb marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| Policy.SuppressTagKeyword = false; | ||
| ParmListWithNamesOstream << " " << Param->getNameAsString(); | ||
| 
      Comment on lines
    
      +6840
     to 
      +6843
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you please elaborate what this particular addition is trying to achieve, why the previous code did not suffice and how does it relate to namespace printing? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In general,  | ||
| ParmList += Param->getType().getCanonicalType().getAsString(Policy); | ||
| } | ||
| ParmListWithNamesOstream.flush(); | ||
| FunctionTemplateDecl *FTD = K.SyclKernel->getPrimaryTemplate(); | ||
| Policy.PrintCanonicalTypes = false; | ||
| Policy.SuppressDefinition = true; | ||
|  | @@ -6756,17 +6877,15 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) { | |
| // template arguments that match default template arguments while printing | ||
| // template-ids, even if the source code doesn't reference them. | ||
| Policy.EnforceDefaultTemplateArgs = true; | ||
| FreeFunctionPrinter FFPrinter(O, Policy); | ||
| if (FTD) { | ||
| FTD->print(O, Policy); | ||
| O << ";\n"; | ||
| } else { | ||
| K.SyclKernel->print(O, Policy); | ||
| FFPrinter.printFreeFunctionDeclaration(K.SyclKernel, ParmListWithNames); | ||
| } | ||
| O << ";\n"; | ||
|  | ||
| // Generate a shim function that returns the address of the free function. | ||
| O << "static constexpr auto __sycl_shim" << ShimCounter << "() {\n"; | ||
| O << " return (void (*)(" << ParmList << "))" | ||
| << K.SyclKernel->getIdentifier()->getName().data(); | ||
| FFPrinter.printFreeFunctionShim(K.SyclKernel, ShimCounter, ParmList); | ||
| if (FTD) { | ||
| const TemplateArgumentList *TAL = | ||
| K.SyclKernel->getTemplateSpecializationArgs(); | ||
|  | @@ -6935,61 +7054,6 @@ bool SYCLIntegrationFooter::emit(StringRef IntHeaderName) { | |
| return emit(Out); | ||
| } | ||
|  | ||
| template <typename BeforeFn, typename AfterFn> | ||
| static void PrintNSHelper(BeforeFn Before, AfterFn After, raw_ostream &OS, | ||
| const DeclContext *DC) { | ||
| if (DC->isTranslationUnit()) | ||
| return; | ||
|  | ||
| const auto *CurDecl = cast<Decl>(DC); | ||
| // Ensure we are in the canonical version, so that we know we have the 'full' | ||
| // name of the thing. | ||
| CurDecl = CurDecl->getCanonicalDecl(); | ||
|  | ||
| // We are intentionally skipping linkage decls and record decls. Namespaces | ||
| // can appear in a linkage decl, but not a record decl, so we don't have to | ||
| // worry about the names getting messed up from that. We handle record decls | ||
| // later when printing the name of the thing. | ||
| const auto *NS = dyn_cast<NamespaceDecl>(CurDecl); | ||
| if (NS) | ||
| Before(OS, NS); | ||
|  | ||
| if (const DeclContext *NewDC = CurDecl->getDeclContext()) | ||
| PrintNSHelper(Before, After, OS, NewDC); | ||
|  | ||
| if (NS) | ||
| After(OS, NS); | ||
| } | ||
|  | ||
| static void PrintNamespaces(raw_ostream &OS, const DeclContext *DC) { | ||
| PrintNSHelper([](raw_ostream &OS, const NamespaceDecl *NS) {}, | ||
| [](raw_ostream &OS, const NamespaceDecl *NS) { | ||
| if (NS->isInline()) | ||
| OS << "inline "; | ||
| OS << "namespace "; | ||
| if (!NS->isAnonymousNamespace()) | ||
| OS << NS->getName() << " "; | ||
| OS << "{\n"; | ||
| }, | ||
| OS, DC); | ||
| } | ||
|  | ||
| static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) { | ||
| PrintNSHelper( | ||
| [](raw_ostream &OS, const NamespaceDecl *NS) { | ||
| OS << "} // "; | ||
| if (NS->isInline()) | ||
| OS << "inline "; | ||
|  | ||
| OS << "namespace "; | ||
| if (!NS->isAnonymousNamespace()) | ||
| OS << NS->getName(); | ||
|  | ||
| OS << '\n'; | ||
| }, | ||
| [](raw_ostream &OS, const NamespaceDecl *NS) {}, OS, DC); | ||
| } | ||
|  | ||
| static std::string EmitShim(raw_ostream &OS, unsigned &ShimCounter, | ||
| const std::string &LastShim, | ||
| const NamespaceDecl *AnonNS) { | ||
|  | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -86,18 +86,20 @@ foo(Arg1<int> arg) { | |
| // CHECK-NEXT: template <typename T, typename, int a, typename, typename ...TS> struct Arg; | ||
| // CHECK-NEXT: } | ||
|  | ||
| // CHECK: void ns::simple(ns::Arg<char, int, 12, ns::notatuple>); | ||
| // CHECK-NEXT: static constexpr auto __sycl_shim1() { | ||
| // CHECK-NEXT: return (void (*)(struct ns::Arg<char, int, 12, struct ns::notatuple>))simple; | ||
| // CHECK: namespace ns { | ||
| // CHECK-NEXT: void simple(ns::Arg<char, int, 12, ns::notatuple> ); | ||
| // CHECK-NEXT: } // namespace ns | ||
| // CHECK: static constexpr auto __sycl_shim1() { | ||
| // CHECK-NEXT: return (void (*)(struct ns::Arg<char, int, 12, struct ns::notatuple>))ns::simple; | ||
| // CHECK-NEXT: } | ||
|  | ||
| // CHECK: Forward declarations of kernel and its argument types: | ||
| // CHECK: namespace ns { | ||
| // CHECK: namespace ns1 { | ||
| // CHECK-NEXT: template <typename A> class hasDefaultArg; | ||
| // CHECK-NEXT: } | ||
| // CHECK-NEXT: }} | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we add more FE tests? Like with various combinations of namespaces around the free function kernel declaration? With inline namespace and not. Can we also test that codegen and semantic analysis is ok for free function kernels defined in a (maybe nested) namespace? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added new e2e tests to check any possible namespaces: nested, anonymous, inline etc. Is it enough or add in these tests too? New tests do not check header directly but if something is emitted wrong, they will fail. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. SYCL compiler is complicated and has a lot of components. If we only have a e2e test and it fails suddenly (for example after a pulldown), it may take a while to identify which component now has a problem. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point, I did not see that these tests are units. Added new checks. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well they are "unit" because clang is enormous itself and has its own unit tests but in terms of SYCL compiler we can consider them as unit tests. | ||
|  | ||
| // CHECK: void simple1(ns::Arg<ns::ns1::hasDefaultArg<ns::notatuple>, int, 12, ns::notatuple>); | ||
| // CHECK: void simple1(ns::Arg<ns::ns1::hasDefaultArg<ns::notatuple>, int, 12, ns::notatuple> ); | ||
| // CHECK-NEXT: static constexpr auto __sycl_shim2() { | ||
| // CHECK-NEXT: return (void (*)(struct ns::Arg<class ns::ns1::hasDefaultArg<struct ns::notatuple>, int, 12, struct ns::notatuple>))simple1; | ||
| // CHECK-NEXT: } | ||
|  | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
clang/lib/Sema/SemaSYCL.cpp:6584:25: error: private field 'Policy' is not used [-Werror,-Wunused-private-field]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@martygrant Can you fix it? Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dklochkov-emb is best positioned to fix this issue as author of the change
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR is created to fix that
#17970
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!