Skip to content

[HLSL][RootSignature] Add metadata generation for descriptor tables #139633

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
May 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -4739,7 +4739,8 @@ def Error : InheritableAttr {
def RootSignature : Attr {
/// [RootSignature(Signature)]
let Spellings = [Microsoft<"RootSignature">];
let Args = [IdentifierArgument<"Signature">];
let Args = [IdentifierArgument<"SignatureIdent">,
DeclArgument<HLSLRootSignature, "SignatureDecl", 0, /*fake=*/1>];
let Subjects = SubjectList<[Function],
ErrorDiag, "'function'">;
let LangOpts = [HLSL];
Expand Down
21 changes: 21 additions & 0 deletions clang/lib/CodeGen/CGHLSLRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,20 @@ void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {
DXILValMD->addOperand(Val);
}

void addRootSignature(ArrayRef<llvm::hlsl::rootsig::RootElement> Elements,
llvm::Function *Fn, llvm::Module &M) {
auto &Ctx = M.getContext();

llvm::hlsl::rootsig::MetadataBuilder Builder(Ctx, Elements);
MDNode *RootSignature = Builder.BuildRootSignature();
MDNode *FnPairing =
MDNode::get(Ctx, {ValueAsMetadata::get(Fn), RootSignature});

StringRef RootSignatureValKey = "dx.rootsignatures";
auto *RootSignatureValMD = M.getOrInsertNamedMetadata(RootSignatureValKey);
RootSignatureValMD->addOperand(FnPairing);
}

} // namespace

llvm::Type *
Expand Down Expand Up @@ -423,6 +437,13 @@ void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
// FIXME: Handle codegen for return type semantics.
// See: https://github.com/llvm/llvm-project/issues/57875
B.CreateRetVoid();

// Add and identify root signature to function, if applicable
for (const Attr *Attr : FD->getAttrs()) {
if (const auto *RSAttr = dyn_cast<RootSignatureAttr>(Attr))
addRootSignature(RSAttr->getSignatureDecl()->getRootElements(), EntryFn,
M);
}
}

void CGHLSLRuntime::setHLSLFunctionAttributes(const FunctionDecl *FD,
Expand Down
9 changes: 5 additions & 4 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -959,7 +959,7 @@ void SemaHLSL::handleRootSignatureAttr(Decl *D, const ParsedAttr &AL) {

IdentifierInfo *Ident = AL.getArgAsIdent(0)->getIdentifierInfo();
if (auto *RS = D->getAttr<RootSignatureAttr>()) {
if (RS->getSignature() != Ident) {
if (RS->getSignatureIdent() != Ident) {
Diag(AL.getLoc(), diag::err_disallowed_duplicate_attribute) << RS;
return;
}
Expand All @@ -970,10 +970,11 @@ void SemaHLSL::handleRootSignatureAttr(Decl *D, const ParsedAttr &AL) {

LookupResult R(SemaRef, Ident, SourceLocation(), Sema::LookupOrdinaryName);
if (SemaRef.LookupQualifiedName(R, D->getDeclContext()))
if (isa<HLSLRootSignatureDecl>(R.getFoundDecl())) {
if (auto *SignatureDecl =
dyn_cast<HLSLRootSignatureDecl>(R.getFoundDecl())) {
// Perform validation of constructs here
D->addAttr(::new (getASTContext())
RootSignatureAttr(getASTContext(), AL, Ident));
D->addAttr(::new (getASTContext()) RootSignatureAttr(
getASTContext(), AL, Ident, SignatureDecl));
}
}

Expand Down
31 changes: 31 additions & 0 deletions clang/test/CodeGenHLSL/RootSignature.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -emit-llvm -o - %s | FileCheck %s

// CHECK: !dx.rootsignatures = !{![[#FIRST_ENTRY:]], ![[#SECOND_ENTRY:]]}

// CHECK: ![[#FIRST_ENTRY]] = !{ptr @FirstEntry, ![[#EMPTY:]]}
// CHECK: ![[#EMPTY]] = !{}

[shader("compute"), RootSignature("")]
[numthreads(1,1,1)]
void FirstEntry() {}

// CHECK: ![[#SECOND_ENTRY]] = !{ptr @SecondEntry, ![[#SECOND_RS:]]}
// CHECK: ![[#SECOND_RS]] = !{![[#TABLE:]]}
// CHECK: ![[#TABLE]] = !{!"DescriptorTable", i32 0, ![[#CBV:]], ![[#SRV:]]}
// CHECK: ![[#CBV]] = !{!"CBV", i32 1, i32 0, i32 0, i32 -1, i32 4}
// CHECK: ![[#SRV]] = !{!"SRV", i32 4, i32 42, i32 3, i32 32, i32 0}

#define SampleDescriptorTable \
"DescriptorTable( " \
" CBV(b0), " \
" SRV(t42, space = 3, offset = 32, numDescriptors = 4, flags = 0) " \
")"
[shader("compute"), RootSignature(SampleDescriptorTable)]
[numthreads(1,1,1)]
void SecondEntry() {}

// Sanity test to ensure no root is added for this function as there is only
// two entries in !dx.roosignatures
[shader("compute")]
[numthreads(1,1,1)]
void ThirdEntry() {}
45 changes: 43 additions & 2 deletions llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
#include <variant>

namespace llvm {
class LLVMContext;
class MDNode;
class Metadata;

namespace hlsl {
namespace rootsig {

Expand Down Expand Up @@ -84,7 +88,9 @@ struct RootConstants {
// Models the end of a descriptor table and stores its visibility
struct DescriptorTable {
ShaderVisibility Visibility = ShaderVisibility::All;
uint32_t NumClauses = 0; // The number of clauses in the table
// Denotes that the previous NumClauses in the RootElement array
// are the clauses in the table.
uint32_t NumClauses = 0;

void dump(raw_ostream &OS) const;
};
Expand Down Expand Up @@ -119,12 +125,47 @@ struct DescriptorTableClause {
void dump(raw_ostream &OS) const;
};

// Models RootElement : RootConstants | DescriptorTable | DescriptorTableClause
/// Models RootElement : RootFlags | RootConstants | DescriptorTable
/// | DescriptorTableClause
///
/// A Root Signature is modeled in-memory by an array of RootElements. These
/// aim to map closely to their DSL grammar reprsentation defined in the spec.
///
/// Each optional parameter has its default value defined in the struct, and,
/// each mandatory parameter does not have a default initialization.
///
/// For the variants RootFlags, RootConstants and DescriptorTableClause: each
/// data member maps directly to a parameter in the grammar.
///
/// The DescriptorTable is modelled by having its Clauses as the previous
/// RootElements in the array, and it holds a data member for the Visibility
/// parameter.
using RootElement = std::variant<RootFlags, RootConstants, DescriptorTable,
DescriptorTableClause>;

void dumpRootElements(raw_ostream &OS, ArrayRef<RootElement> Elements);

class MetadataBuilder {
public:
MetadataBuilder(llvm::LLVMContext &Ctx, ArrayRef<RootElement> Elements)
: Ctx(Ctx), Elements(Elements) {}

/// Iterates through the elements and dispatches onto the correct Build method
///
/// Accumulates the root signature and returns the Metadata node that is just
/// a list of all the elements
MDNode *BuildRootSignature();

private:
/// Define the various builders for the different metadata types
MDNode *BuildDescriptorTable(const DescriptorTable &Table);
MDNode *BuildDescriptorTableClause(const DescriptorTableClause &Clause);

llvm::LLVMContext &Ctx;
ArrayRef<RootElement> Elements;
SmallVector<Metadata *> GeneratedMetadata;
};

} // namespace rootsig
} // namespace hlsl
} // namespace llvm
Expand Down
62 changes: 62 additions & 0 deletions llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

#include "llvm/Frontend/HLSL/HLSLRootSignature.h"
#include "llvm/ADT/bit.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"

namespace llvm {
namespace hlsl {
Expand Down Expand Up @@ -160,6 +163,65 @@ void dumpRootElements(raw_ostream &OS, ArrayRef<RootElement> Elements) {
OS << "}";
}

MDNode *MetadataBuilder::BuildRootSignature() {
for (const RootElement &Element : Elements) {
MDNode *ElementMD = nullptr;
if (const auto &Clause = std::get_if<DescriptorTableClause>(&Element))
ElementMD = BuildDescriptorTableClause(*Clause);
if (const auto &Table = std::get_if<DescriptorTable>(&Element))
ElementMD = BuildDescriptorTable(*Table);

// FIXME(#126586): remove once all RootElemnt variants are handled in a
// visit or otherwise
assert(ElementMD != nullptr &&
"Constructed an unhandled root element type.");

GeneratedMetadata.push_back(ElementMD);
}

return MDNode::get(Ctx, GeneratedMetadata);
}

MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) {
IRBuilder<> Builder(Ctx);
SmallVector<Metadata *> TableOperands;
// Set the mandatory arguments
TableOperands.push_back(MDString::get(Ctx, "DescriptorTable"));
TableOperands.push_back(ConstantAsMetadata::get(
Builder.getInt32(llvm::to_underlying(Table.Visibility))));

// Remaining operands are references to the table's clauses. The in-memory
// representation of the Root Elements created from parsing will ensure that
// the previous N elements are the clauses for this table.
assert(Table.NumClauses <= GeneratedMetadata.size() &&
"Table expected all owned clauses to be generated already");
Comment on lines +196 to +197
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if that assert would really help, since it could easily be mislead by the generation of other parameters.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was added as a (hopefully) better error message rather than the error that would be generated if Generated.size() - Table.NumClauses created an undefined behaviour iterator.

Do you have a suggestion for how it could be better worded, or, a better assert?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure how complex/useful this is, we could check how many DescriptorTableClause are in TableOperands preceding the table. I think this could be done like:

uint Count = 0;
// I think this should be the Element parameter, from BuildRootSignature
auto It = Element;
while (std::holds_alternative<DescriptorTableClause>(*It)){
    Count ++;
    Element --;
}

assert(Table.NumClauses == Count);

Not sure this code is correct/useful, but that is the idea.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the idea and toyed around with it a bit. It unfortunately is a bit a unwieldly to implement. I think it introduces more complexity then it reduces

// So, add a refence to each clause to our operands
TableOperands.append(GeneratedMetadata.end() - Table.NumClauses,
GeneratedMetadata.end());
// Then, remove those clauses from the general list of Root Elements
GeneratedMetadata.pop_back_n(Table.NumClauses);

return MDNode::get(Ctx, TableOperands);
}

MDNode *MetadataBuilder::BuildDescriptorTableClause(
const DescriptorTableClause &Clause) {
IRBuilder<> Builder(Ctx);
std::string Name;
llvm::raw_string_ostream OS(Name);
OS << Clause.Type;
return MDNode::get(
Ctx, {
MDString::get(Ctx, OS.str()),
ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)),
ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)),
ConstantAsMetadata::get(Builder.getInt32(Clause.Space)),
ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)),
ConstantAsMetadata::get(
Builder.getInt32(llvm::to_underlying(Clause.Flags))),
});
}

} // namespace rootsig
} // namespace hlsl
} // namespace llvm