-
Notifications
You must be signed in to change notification settings - Fork 14.5k
[mlir][spirv] Add support for SPV_ARM_graph extension #147937
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
This patch adds support for the `SPV_ARM_graph` SPIR-V extension to MLIR’s SPIR-V dialect. The extension introduces a new `Graph` abstraction for expressing dataflow computations over full resources. The implementation includes: - A new `GraphType`, modeled similarly to `FunctionType`, for typed graph signatures. - New operations in the `spirv.arm` namespace: - `spirv.arm.Graph` - `spirv.arm.GraphEntryPoint` - `spirv.arm.GraphConstant` - `spirv.arm.GraphOutput` - Serialization and deserialization support for: - `OpGraphARM`, `OpGraphInputARM`, `OpGraphSetOutputARM`, `OpGraphEndARM` - `OpGraphEntryPointARM`, `OpGraphConstantARM`, `OpTypeGraphARM` - ABI lowering support for graph entry points via `LowerABIAttributesPass`. - Verifier and VCE updates to properly gate usage under `SPV_ARM_graph`. - Tests covering parsing, verification, ABI handling, and binary round-tripping. Graphs currently support only `SPV_ARM_tensors`, but are designed to generalize to other resource types, such as images. Spec: KhronosGroup/SPIRV-Registry#346 RFC: https://discourse.llvm.org/t/rfc-add-support-for-spv-arm-graph-extension-in-mlir-spir-v-dialect/86947 Signed-off-by: Davide Grohmann <[email protected]> Change-Id: I99aa469f2108219591544056db55bcd3f0702c7e
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-ods Author: Davide Grohmann (davidegrohmann) ChangesThis patch adds support for the The implementation includes:
Graphs currently support only Spec: KhronosGroup/SPIRV-Registry#346 Patch is 85.76 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147937.diff 27 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 910418f1706a6..ce4bb6c2e4934 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -423,6 +423,7 @@ def SPV_NV_ray_tracing_motion_blur : I32EnumAttrCase<"SPV_NV_ray_tracing_m
def SPV_NVX_multiview_per_view_attributes : I32EnumAttrCase<"SPV_NVX_multiview_per_view_attributes", 5015>;
def SPV_ARM_tensors : I32EnumAttrCase<"SPV_ARM_tensors", 6000>;
+def SPV_ARM_graph : I32EnumAttrCase<"SPV_ARM_graph", 6001>;
def SPIRV_ExtensionAttr :
SPIRV_I32EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [
@@ -447,7 +448,7 @@ def SPIRV_ExtensionAttr :
SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
SPV_EXT_mesh_shader,
- SPV_ARM_tensors,
+ SPV_ARM_tensors, SPV_ARM_graph,
SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
SPV_AMD_shader_image_load_store_lod, SPV_AMD_texture_gather_bias_lod,
@@ -1332,6 +1333,12 @@ def SPIRV_C_StorageTensorArrayNonUniformIndexingEXT : I32EnumAttrCase<"Stora
Extension<[SPV_ARM_tensors]>
];
}
+def SPIRV_C_GraphARM : I32EnumAttrCase<"GraphARM", 4191> {
+ list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_Shader, SPIRV_C_VulkanMemoryModel];
+ list<Availability> availability = [
+ Extension<[SPV_ARM_graph]>
+ ];
+}
def SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR : I32EnumAttrCase<"WorkgroupMemoryExplicitLayout8BitAccessKHR", 4429> {
list<I32EnumAttrCase> implies = [SPIRV_C_WorkgroupMemoryExplicitLayoutKHR];
list<Availability> availability = [
@@ -1545,7 +1552,7 @@ def SPIRV_CapabilityAttr :
SPIRV_C_GeometryPointSize, SPIRV_C_ImageCubeArray, SPIRV_C_ImageRect,
SPIRV_C_GeometryStreams, SPIRV_C_MultiViewport,
SPIRV_C_TensorsARM, SPIRV_C_StorageTensorArrayDynamicIndexingEXT,
- SPIRV_C_StorageTensorArrayNonUniformIndexingEXT,
+ SPIRV_C_StorageTensorArrayNonUniformIndexingEXT, SPIRV_C_GraphARM,
SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR, SPIRV_C_VariablePointers,
SPIRV_C_RayTraversalPrimitiveCullingKHR, SPIRV_C_SampleMaskOverrideCoverageNV,
SPIRV_C_GeometryShaderPassthroughNV, SPIRV_C_PerViewAttributesNV,
@@ -4245,6 +4252,7 @@ def SPIRV_AnyTensorArm : DialectType<SPIRV_Dialect, SPIRV_IsTensorArmType,
def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_AnyFloat]>;
def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
+
def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>;
def SPIRV_Composite :
AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
@@ -4551,6 +4559,13 @@ def SPIRV_OC_OpGroupNonUniformLogicalAnd : I32EnumAttrCase<"OpGroupNonUnifo
def SPIRV_OC_OpGroupNonUniformLogicalOr : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
def SPIRV_OC_OpGroupNonUniformLogicalXor : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
def SPIRV_OC_OpTypeTensorARM : I32EnumAttrCase<"OpTypeTensorARM", 4163>;
+def SPIRV_OC_OpGraphConstantARM : I32EnumAttrCase<"OpGraphConstantARM", 4181>;
+def SPIRV_OC_OpGraphEntryPointARM : I32EnumAttrCase<"OpGraphEntryPointARM", 4182>;
+def SPIRV_OC_OpGraphARM : I32EnumAttrCase<"OpGraphARM", 4183>;
+def SPIRV_OC_OpGraphInputARM : I32EnumAttrCase<"OpGraphInputARM", 4184>;
+def SPIRV_OC_OpGraphSetOutputARM : I32EnumAttrCase<"OpGraphSetOutputARM", 4185>;
+def SPIRV_OC_OpGraphEndARM : I32EnumAttrCase<"OpGraphEndARM", 4186>;
+def SPIRV_OC_OpTypeGraphARM : I32EnumAttrCase<"OpTypeGraphARM", 4190>;
def SPIRV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
def SPIRV_OC_OpGroupNonUniformRotateKHR : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
def SPIRV_OC_OpSDot : I32EnumAttrCase<"OpSDot", 4450>;
@@ -4666,6 +4681,9 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
SPIRV_OC_OpGroupNonUniformLogicalXor,
SPIRV_OC_OpTypeTensorARM,
+ SPIRV_OC_OpGraphEntryPointARM, SPIRV_OC_OpGraphARM,
+ SPIRV_OC_OpGraphInputARM, SPIRV_OC_OpGraphSetOutputARM, SPIRV_OC_OpGraphEndARM,
+ SPIRV_OC_OpTypeGraphARM, SPIRV_OC_OpGraphConstantARM,
SPIRV_OC_OpSubgroupBallotKHR,
SPIRV_OC_OpGroupNonUniformRotateKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
@@ -4836,6 +4854,11 @@ class SPIRV_NvVendorOp<string mnemonic, list<Trait> traits = []> :
SPIRV_VendorOp<mnemonic, "NV", traits> {
}
+class SPIRV_ArmVendorOp<string mnemonic, list<Trait> traits = []> :
+ SPIRV_VendorOp<mnemonic, "ARM", traits> {
+}
+
+
def SPIRV_FPFMM_None : I32BitEnumAttrCaseNone<"None">;
def SPIRV_FPFMM_NotNaN : I32BitEnumAttrCaseBit<"NotNaN", 0>;
def SPIRV_FPFMM_NotInf : I32BitEnumAttrCaseBit<"NotInf", 1>;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
new file mode 100644
index 0000000000000..38fb4b2eff414
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
@@ -0,0 +1,201 @@
+//===- SPIRVGraphOps.td - Graph extended insts spec file -----*- tablegen -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the op definition spec of Graph extension ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_GRAPH_OPS
+#define MLIR_DIALECT_SPIRV_IR_GRAPH_OPS
+
+include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/FunctionInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// SPIR-V Graph opcode specification.
+//===----------------------------------------------------------------------===//
+
+// Base class for all Graph ops.
+class SPIRV_GraphARMOp<string mnemonic, list<Trait> traits = []> :
+ SPIRV_ArmVendorOp<mnemonic, traits> {
+
+ let availability = [
+ MinVersion<SPIRV_V_1_0>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model]>,
+ Capability<[SPIRV_C_GraphARM]>
+ ];
+}
+
+def SPIRV_GraphConstantARMOp : SPIRV_GraphARMOp<"GraphConstant", [Pure]> {
+ let summary = "Declare a graph constant.";
+
+ let description = [{
+ Declare a graph constant.
+ Result Type must be an OpTypeTensorARM.
+ GraphConstantID must be a 32-bit integer literal.
+ }];
+
+ let arguments = (ins
+ I32Attr: $graph_constant_id
+ );
+
+ let results = (outs
+ SPIRV_AnyTensorArm:$output
+ );
+
+ let hasVerifier = 0;
+
+ let autogenSerialization = 0;
+
+ let assemblyFormat = [{
+ attr-dict `:` type($output)
+ }];
+}
+
+// -----
+
+def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [
+ AutomaticAllocationScope, DeclareOpInterfaceMethods<CallableOpInterface>,
+ FunctionOpInterface, InModuleScope, IsolatedFromAbove
+ ]> {
+
+ let summary = "Declare or define a SPIR-V graph";
+
+ let description = [{
+ This op declares or defines a SPIR-V graph using one region, which
+ contains one or more blocks.
+
+ Different from the SPIR-V binary format, this op is not allowed to
+ implicitly capture global values, and all external references must use
+ function arguments or symbol references. This op itself defines a symbol
+ that is unique in the enclosing module op.
+
+ This op itself takes no operands and generates no results. Its region
+ can take zero or more arguments and return zero or more values.
+
+ ```
+ spv-graph-arm-op ::= `spirv.ARM.Graph` function-signature
+ region
+ ```
+ }];
+
+ let arguments = (ins
+ TypeAttrOf<GraphType>:$function_type,
+ OptionalAttr<DictArrayAttr>:$arg_attrs,
+ OptionalAttr<DictArrayAttr>:$res_attrs,
+ OptionalAttr<BoolAttr>:$entry_point,
+ StrAttr:$sym_name
+ );
+
+ let results = (outs);
+
+ let regions = (region AnyRegion:$body);
+
+ let hasVerifier = 0;
+
+ let builders = [
+ OpBuilder<(ins "StringRef":$name, "GraphType":$type,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs, CArg<"bool", "false">:$entry_point)>];
+
+ let hasOpcode = 0;
+
+ let autogenSerialization = 0;
+
+ let extraClassDeclaration = [{
+ /// Hook for FunctionOpInterface, called after verifying that the 'type'
+ /// attribute is present and checks if it holds a function type. Ensures
+ /// getType, getNumArguments, and getNumResults can be called safely
+ LogicalResult verifyType();
+
+ /// Hook for FunctionOpInterface, called after verifying the function
+ /// type and the presence of the (potentially empty) function body.
+ /// Ensures SPIR-V specific semantics.
+ LogicalResult verifyBody();
+ }];
+}
+
+// Check that an op can only be used within the scope of a spirv.ARM.Graph op.
+def InGraphScope : PredOpTrait<
+ "op must appear in a spirv.ARM.Graph op's block",
+ CPred<"isNestedInGraphARMOpInterface($_op.getParentOp())">>;
+
+// -----
+
+def SPIRV_GraphEntryPointARMOp : SPIRV_GraphARMOp<"GraphEntryPoint", [InModuleScope]> {
+ let summary = [{
+ Declare a graph entry point and its interface.
+ }];
+
+ let description = [{
+ Graph Entry Point must be the Result <id> of an OpGraphARM instruction.
+
+ Name is a name string for the graphentry point. A module cannot have two
+ OpGraphEntryPointARM instructions with the same Name string.
+
+ Interface is a list of symbol references to `spirv.GlobalVariable`
+ operations. These declare the set of global variables from a
+ module that form the interface of this entry point. The set of
+ Interface symbols must be equal to or a superset of the
+ `spirv.GlobalVariable`s referenced by the entry point’s static call
+ tree, within the interface’s storage classes.
+
+ ```
+ entry-point-op ::= ssa-id `=` `spirv.ARM.GraphEntryPoint`
+ symbol-reference (`, ` symbol-reference)*
+ ```
+ }];
+
+ let arguments = (ins
+ FlatSymbolRefAttr:$fn,
+ SymbolRefArrayAttr:$interface
+ );
+
+ let results = (outs);
+
+ let autogenSerialization = 0;
+
+ let builders = [
+ OpBuilder<(ins "spirv::GraphARMOp":$graph, "ArrayRef<Attribute>":$interfaceVars)>];
+}
+
+// -----
+
+def SPIRV_GraphOutputsARMOp : SPIRV_GraphARMOp<"GraphOutputs", [InGraphScope, Pure,
+ Terminator]> {
+
+ let summary = "Define graph outputs.";
+
+ let description = [{
+ Values are the graph outputs values and must match the GraphOutputs Type
+ operand of the OpTypeGraphARM type of the OpGraphARM body this
+ instruction is in.
+
+ This instruction must be the last instruction in a block.
+
+ ```
+ graph-output-op ::= `spirv.ARM.GraphOutputs` ssa-use `:` type-list-no-parens
+ ```
+ }];
+
+ let arguments = (ins
+ Variadic<SPIRV_AnyTensorArm>:$value
+ );
+
+ let results = (outs);
+
+ let autogenSerialization = 0;
+
+ let hasOpcode = 0;
+
+ let assemblyFormat = "$value attr-dict `:` type($value)";
+}
+
+#endif // MLIR_DIALECT_SPIRV_IR_GRAPH_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
index 0fa1bb9d5bd01..96ef035eda37a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
@@ -32,6 +32,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVGLOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVImageOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td"
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index ad59ea63a6901..aa7d30b87db14 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -24,6 +24,7 @@ class Type;
class IntegerType;
class FloatType;
class FunctionType;
+class GraphType;
class IndexType;
class MemRefType;
class VectorType;
@@ -81,6 +82,7 @@ class Builder {
IntegerType getIntegerType(unsigned width);
IntegerType getIntegerType(unsigned width, bool isSigned);
FunctionType getFunctionType(TypeRange inputs, TypeRange results);
+ GraphType getGraphType(TypeRange inputs, TypeRange results);
TupleType getTupleType(TypeRange elementTypes);
NoneType getNoneType();
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index a0c8acea91dc5..08847dd11c685 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -403,7 +403,7 @@ def Builtin_Float128 : Builtin_CachedFloatType<"Float128", "f128"> {
// FunctionType
//===----------------------------------------------------------------------===//
-def Builtin_Function : Builtin_Type<"Function", "function"> {
+class Builtin_FunctionLike<string Name, string typeMnemonic> : Builtin_Type<Name, typeMnemonic> {
let summary = "Map from a list of inputs to a list of results";
let description = [{
Syntax:
@@ -434,6 +434,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
}]>
];
let skipDefaultBuilders = 1;
+ let storageClass = "FunctionTypeStorage";
let genStorageClass = 0;
let extraClassDeclaration = [{
/// Input types.
@@ -444,23 +445,26 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
unsigned getNumResults() const;
Type getResult(unsigned i) const { return getResults()[i]; }
- /// Returns a clone of this function type with the given argument
+ /// Returns a clone of this function-like type with the given argument
/// and result types.
- FunctionType clone(TypeRange inputs, TypeRange results) const;
+ }] # Name # "Type" # [{ clone(TypeRange inputs, TypeRange results) const;
- /// Returns a new function type with the specified arguments and results
+ /// Returns a new function-like type with the specified arguments and results
/// inserted.
- FunctionType getWithArgsAndResults(ArrayRef<unsigned> argIndices,
+ }] # Name # "Type" # [{ getWithArgsAndResults(ArrayRef<unsigned> argIndices,
TypeRange argTypes,
ArrayRef<unsigned> resultIndices,
TypeRange resultTypes);
- /// Returns a new function type without the specified arguments and results.
- FunctionType getWithoutArgsAndResults(const BitVector &argIndices,
+ /// Returns a new function-like type without the specified arguments and results.
+ }] # Name # "Type" # [{ getWithoutArgsAndResults(const BitVector &argIndices,
const BitVector &resultIndices);
}];
}
+def Builtin_Function : Builtin_FunctionLike<"Function", "function">;
+def Builtin_Graph : Builtin_FunctionLike<"Graph", "graph">;
+
//===----------------------------------------------------------------------===//
// IndexType
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 45ec1846580f2..aab1b01c5cff9 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -387,6 +387,13 @@ class OpaqueType<string dialect, string name, string summary>
def FunctionType : Type<CPred<"::llvm::isa<::mlir::FunctionType>($_self)">,
"function type", "::mlir::FunctionType">;
+// Graph Type
+
+// Any graph type.
+def GraphType : Type<CPred<"::llvm::isa<::mlir::GraphType>($_self)">,
+ "graph type", "::mlir::GraphType">;
+
+
// A container type is a type that has another type embedded within it.
class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
string descr, string cppType = "::mlir::Type"> :
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 88c7adf3dfcb3..e66d4b0ffc446 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -1019,8 +1019,14 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
return verifyRegionAttribute(op->getLoc(), argType, attribute);
}
-LogicalResult SPIRVDialect::verifyRegionResultAttribute(
- Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
- NamedAttribute attribute) {
- return op->emitError("cannot attach SPIR-V attributes to region result");
+LogicalResult SPIRVDialect::verifyRegionResultAttribute(Operation *op,
+ unsigned regionIndex,
+ unsigned resultIndex,
+ NamedAttribute attribute) {
+ auto funcOp = dyn_cast<FunctionOpInterface>(op);
+ if (!funcOp)
+ return op->emitError("cannot attach SPIR-V attributes to region result which is "
+ "not a FunctionOpInterface type");
+ return verifyRegionAttribute(
+ op->getLoc(), funcOp.getResultTypes()[resultIndex], attribute);
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
index d8dfe164458e2..2f3a28ff16173 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
@@ -31,6 +31,18 @@ static bool isNestedInFunctionOpInterface(Operation *op) {
return isNestedInFunctionOpInterface(op->getParentOp());
}
+/// Returns true if the given op is a GraphARM op or nested in a
+/// GraphARM op without a module-like op in the middle.
+static bool isNestedInGraphARMOpInterface(Operation *op) {
+ if (!op)
+ return false;
+ if (op->hasTrait<OpTrait::SymbolTable>())
+ return false;
+ if (isa<spirv::GraphARMOp>(op))
+ return true;
+ return isNestedInGraphARMOpInterface(op->getParentOp());
+}
+
/// Returns true if the given op is an module-like op that maintains a symbol
/// table.
static bool isDirectInModuleLikeOp(Operation *op) {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index eb2974d62fdd1..17cbab189588f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1084,6 +1084,236 @@ void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
state.addRegion();
}
+//===----------------------------------------------------------------------===//
+// spirv.GraphEntryPointARM
+//===----------------------------------------------------------------------===//
+
+void spirv::GraphEntryPointARMOp::build(OpBuilder &builder,
+ OperationState &state,
+ spirv::GraphARMOp graph,
+ ArrayRef<Attribute> interfaceVars) {
+ build(builder, state, SymbolRefAttr::get(graph),
+ builder.getArrayAttr(interfaceVars));
+}
+
+ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ SmallVector<Type, 0> idTypes;
+ SmallVector<Attribute, 4> interfaceVars;
+
+ FlatSymbolRefAttr fn;
+ if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
+ return failure();
+ }
+
+ if (!parser.parseOptionalComma()) {
+ // Parse the interface variables
+ if ...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
Signed-off-by: Davide Grohmann <[email protected]> Change-Id: I07c5cad1f3092994af33ebbeda84e2018e03f6b7
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wow, great progress! Would you be able to break this up into a few smaller and easier to review PRs? These can be stacked. It's fine to leave this one as a draft for the full picture.
This patch adds support for the
SPV_ARM_graph
SPIR-V extension to MLIR’s SPIR-V dialect. The extension introduces a newGraph
abstraction for expressing dataflow computations over full resources.The implementation includes:
GraphType
, modeled similarly toFunctionType
, for typed graph signatures.spirv.arm
namespace:spirv.arm.Graph
spirv.arm.GraphEntryPoint
spirv.arm.GraphConstant
spirv.arm.GraphOutput
OpGraphARM
,OpGraphInputARM
,OpGraphSetOutputARM
,OpGraphEndARM
OpGraphEntryPointARM
,OpGraphConstantARM
,OpTypeGraphARM
LowerABIAttributesPass
.SPV_ARM_graph
.Graphs currently support only
SPV_ARM_tensors
, but are designed to generalize to other resource types, such as images.Spec: KhronosGroup/SPIRV-Registry#346
RFC: https://discourse.llvm.org/t/rfc-add-support-for-spv-arm-graph-extension-in-mlir-spir-v-dialect/86947