Skip to content

[mlir][spirv] Add support for SPV_ARM_graph extension #147937

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,7 @@ def SPV_NV_ray_tracing_motion_blur : I32EnumAttrCase<"SPV_NV_ray_tracing_m
def SPV_NVX_multiview_per_view_attributes : I32EnumAttrCase<"SPV_NVX_multiview_per_view_attributes", 5015>;

def SPV_ARM_tensors : I32EnumAttrCase<"SPV_ARM_tensors", 6000>;
def SPV_ARM_graph : I32EnumAttrCase<"SPV_ARM_graph", 6001>;

def SPIRV_ExtensionAttr :
SPIRV_I32EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [
Expand All @@ -447,7 +448,7 @@ def SPIRV_ExtensionAttr :
SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
SPV_EXT_mesh_shader,
SPV_ARM_tensors,
SPV_ARM_tensors, SPV_ARM_graph,
SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
SPV_AMD_shader_image_load_store_lod, SPV_AMD_texture_gather_bias_lod,
Expand Down Expand Up @@ -1332,6 +1333,12 @@ def SPIRV_C_StorageTensorArrayNonUniformIndexingEXT : I32EnumAttrCase<"Stora
Extension<[SPV_ARM_tensors]>
];
}
def SPIRV_C_GraphARM : I32EnumAttrCase<"GraphARM", 4191> {
list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_Shader, SPIRV_C_VulkanMemoryModel];
list<Availability> availability = [
Extension<[SPV_ARM_graph]>
];
}
def SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR : I32EnumAttrCase<"WorkgroupMemoryExplicitLayout8BitAccessKHR", 4429> {
list<I32EnumAttrCase> implies = [SPIRV_C_WorkgroupMemoryExplicitLayoutKHR];
list<Availability> availability = [
Expand Down Expand Up @@ -1545,7 +1552,7 @@ def SPIRV_CapabilityAttr :
SPIRV_C_GeometryPointSize, SPIRV_C_ImageCubeArray, SPIRV_C_ImageRect,
SPIRV_C_GeometryStreams, SPIRV_C_MultiViewport,
SPIRV_C_TensorsARM, SPIRV_C_StorageTensorArrayDynamicIndexingEXT,
SPIRV_C_StorageTensorArrayNonUniformIndexingEXT,
SPIRV_C_StorageTensorArrayNonUniformIndexingEXT, SPIRV_C_GraphARM,
SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR, SPIRV_C_VariablePointers,
SPIRV_C_RayTraversalPrimitiveCullingKHR, SPIRV_C_SampleMaskOverrideCoverageNV,
SPIRV_C_GeometryShaderPassthroughNV, SPIRV_C_PerViewAttributesNV,
Expand Down Expand Up @@ -4245,6 +4252,7 @@ def SPIRV_AnyTensorArm : DialectType<SPIRV_Dialect, SPIRV_IsTensorArmType,

def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_AnyFloat]>;
def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;

def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>;
def SPIRV_Composite :
AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
Expand Down Expand Up @@ -4551,6 +4559,13 @@ def SPIRV_OC_OpGroupNonUniformLogicalAnd : I32EnumAttrCase<"OpGroupNonUnifo
def SPIRV_OC_OpGroupNonUniformLogicalOr : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
def SPIRV_OC_OpGroupNonUniformLogicalXor : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
def SPIRV_OC_OpTypeTensorARM : I32EnumAttrCase<"OpTypeTensorARM", 4163>;
def SPIRV_OC_OpGraphConstantARM : I32EnumAttrCase<"OpGraphConstantARM", 4181>;
def SPIRV_OC_OpGraphEntryPointARM : I32EnumAttrCase<"OpGraphEntryPointARM", 4182>;
def SPIRV_OC_OpGraphARM : I32EnumAttrCase<"OpGraphARM", 4183>;
def SPIRV_OC_OpGraphInputARM : I32EnumAttrCase<"OpGraphInputARM", 4184>;
def SPIRV_OC_OpGraphSetOutputARM : I32EnumAttrCase<"OpGraphSetOutputARM", 4185>;
def SPIRV_OC_OpGraphEndARM : I32EnumAttrCase<"OpGraphEndARM", 4186>;
def SPIRV_OC_OpTypeGraphARM : I32EnumAttrCase<"OpTypeGraphARM", 4190>;
def SPIRV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
def SPIRV_OC_OpGroupNonUniformRotateKHR : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
def SPIRV_OC_OpSDot : I32EnumAttrCase<"OpSDot", 4450>;
Expand Down Expand Up @@ -4666,6 +4681,9 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
SPIRV_OC_OpGroupNonUniformLogicalXor,
SPIRV_OC_OpTypeTensorARM,
SPIRV_OC_OpGraphEntryPointARM, SPIRV_OC_OpGraphARM,
SPIRV_OC_OpGraphInputARM, SPIRV_OC_OpGraphSetOutputARM, SPIRV_OC_OpGraphEndARM,
SPIRV_OC_OpTypeGraphARM, SPIRV_OC_OpGraphConstantARM,
SPIRV_OC_OpSubgroupBallotKHR,
SPIRV_OC_OpGroupNonUniformRotateKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
Expand Down Expand Up @@ -4836,6 +4854,11 @@ class SPIRV_NvVendorOp<string mnemonic, list<Trait> traits = []> :
SPIRV_VendorOp<mnemonic, "NV", traits> {
}

class SPIRV_ArmVendorOp<string mnemonic, list<Trait> traits = []> :
SPIRV_VendorOp<mnemonic, "ARM", traits> {
}


def SPIRV_FPFMM_None : I32BitEnumAttrCaseNone<"None">;
def SPIRV_FPFMM_NotNaN : I32BitEnumAttrCaseBit<"NotNaN", 0>;
def SPIRV_FPFMM_NotInf : I32BitEnumAttrCaseBit<"NotInf", 1>;
Expand Down
201 changes: 201 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
//===- SPIRVGraphOps.td - Graph extended insts spec file -----*- tablegen -*-=//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the op definition spec of Graph extension ops.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPIRV_IR_GRAPH_OPS
#define MLIR_DIALECT_SPIRV_IR_GRAPH_OPS

include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/FunctionInterfaces.td"

//===----------------------------------------------------------------------===//
// SPIR-V Graph opcode specification.
//===----------------------------------------------------------------------===//

// Base class for all Graph ops.
class SPIRV_GraphARMOp<string mnemonic, list<Trait> traits = []> :
SPIRV_ArmVendorOp<mnemonic, traits> {

let availability = [
MinVersion<SPIRV_V_1_0>,
MaxVersion<SPIRV_V_1_6>,
Extension<[SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model]>,
Capability<[SPIRV_C_GraphARM]>
];
}

def SPIRV_GraphConstantARMOp : SPIRV_GraphARMOp<"GraphConstant", [Pure]> {
let summary = "Declare a graph constant.";

let description = [{
Declare a graph constant.
Result Type must be an OpTypeTensorARM.
GraphConstantID must be a 32-bit integer literal.
}];

let arguments = (ins
I32Attr: $graph_constant_id
);

let results = (outs
SPIRV_AnyTensorArm:$output
);

let hasVerifier = 0;

let autogenSerialization = 0;

let assemblyFormat = [{
attr-dict `:` type($output)
}];
}

// -----

def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [
AutomaticAllocationScope, DeclareOpInterfaceMethods<CallableOpInterface>,
FunctionOpInterface, InModuleScope, IsolatedFromAbove
]> {

let summary = "Declare or define a SPIR-V graph";

let description = [{
This op declares or defines a SPIR-V graph using one region, which
contains one or more blocks.

Different from the SPIR-V binary format, this op is not allowed to
implicitly capture global values, and all external references must use
function arguments or symbol references. This op itself defines a symbol
that is unique in the enclosing module op.

This op itself takes no operands and generates no results. Its region
can take zero or more arguments and return zero or more values.

```
spv-graph-arm-op ::= `spirv.ARM.Graph` function-signature
region
```
}];

let arguments = (ins
TypeAttrOf<GraphType>:$function_type,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs,
OptionalAttr<BoolAttr>:$entry_point,
StrAttr:$sym_name
);

let results = (outs);

let regions = (region AnyRegion:$body);

let hasVerifier = 0;

let builders = [
OpBuilder<(ins "StringRef":$name, "GraphType":$type,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs, CArg<"bool", "false">:$entry_point)>];

let hasOpcode = 0;

let autogenSerialization = 0;

let extraClassDeclaration = [{
/// Hook for FunctionOpInterface, called after verifying that the 'type'
/// attribute is present and checks if it holds a function type. Ensures
/// getType, getNumArguments, and getNumResults can be called safely
LogicalResult verifyType();

/// Hook for FunctionOpInterface, called after verifying the function
/// type and the presence of the (potentially empty) function body.
/// Ensures SPIR-V specific semantics.
LogicalResult verifyBody();
}];
}

// Check that an op can only be used within the scope of a spirv.ARM.Graph op.
def InGraphScope : PredOpTrait<
"op must appear in a spirv.ARM.Graph op's block",
CPred<"isNestedInGraphARMOpInterface($_op.getParentOp())">>;

// -----

def SPIRV_GraphEntryPointARMOp : SPIRV_GraphARMOp<"GraphEntryPoint", [InModuleScope]> {
let summary = [{
Declare a graph entry point and its interface.
}];

let description = [{
Graph Entry Point must be the Result <id> of an OpGraphARM instruction.

Name is a name string for the graphentry point. A module cannot have two
OpGraphEntryPointARM instructions with the same Name string.

Interface is a list of symbol references to `spirv.GlobalVariable`
operations. These declare the set of global variables from a
module that form the interface of this entry point. The set of
Interface symbols must be equal to or a superset of the
`spirv.GlobalVariable`s referenced by the entry point’s static call
tree, within the interface’s storage classes.

```
entry-point-op ::= ssa-id `=` `spirv.ARM.GraphEntryPoint`
symbol-reference (`, ` symbol-reference)*
```
}];

let arguments = (ins
FlatSymbolRefAttr:$fn,
SymbolRefArrayAttr:$interface
);

let results = (outs);

let autogenSerialization = 0;

let builders = [
OpBuilder<(ins "spirv::GraphARMOp":$graph, "ArrayRef<Attribute>":$interfaceVars)>];
}

// -----

def SPIRV_GraphOutputsARMOp : SPIRV_GraphARMOp<"GraphOutputs", [InGraphScope, Pure,
Terminator]> {

let summary = "Define graph outputs.";

let description = [{
Values are the graph outputs values and must match the GraphOutputs Type
operand of the OpTypeGraphARM type of the OpGraphARM body this
instruction is in.

This instruction must be the last instruction in a block.

```
graph-output-op ::= `spirv.ARM.GraphOutputs` ssa-use `:` type-list-no-parens
```
}];

let arguments = (ins
Variadic<SPIRV_AnyTensorArm>:$value
);

let results = (outs);

let autogenSerialization = 0;

let hasOpcode = 0;

let assemblyFormat = "$value attr-dict `:` type($value)";
}

#endif // MLIR_DIALECT_SPIRV_IR_GRAPH_OPS
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVGLOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVImageOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td"
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/IR/Builders.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class Type;
class IntegerType;
class FloatType;
class FunctionType;
class GraphType;
class IndexType;
class MemRefType;
class VectorType;
Expand Down Expand Up @@ -81,6 +82,7 @@ class Builder {
IntegerType getIntegerType(unsigned width);
IntegerType getIntegerType(unsigned width, bool isSigned);
FunctionType getFunctionType(TypeRange inputs, TypeRange results);
GraphType getGraphType(TypeRange inputs, TypeRange results);
TupleType getTupleType(TypeRange elementTypes);
NoneType getNoneType();

Expand Down
18 changes: 11 additions & 7 deletions mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def Builtin_Float128 : Builtin_CachedFloatType<"Float128", "f128"> {
// FunctionType
//===----------------------------------------------------------------------===//

def Builtin_Function : Builtin_Type<"Function", "function"> {
class Builtin_FunctionLike<string Name, string typeMnemonic> : Builtin_Type<Name, typeMnemonic> {
let summary = "Map from a list of inputs to a list of results";
let description = [{
Syntax:
Expand Down Expand Up @@ -434,6 +434,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
}]>
];
let skipDefaultBuilders = 1;
let storageClass = "FunctionTypeStorage";
let genStorageClass = 0;
let extraClassDeclaration = [{
/// Input types.
Expand All @@ -444,23 +445,26 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
unsigned getNumResults() const;
Type getResult(unsigned i) const { return getResults()[i]; }

/// Returns a clone of this function type with the given argument
/// Returns a clone of this function-like type with the given argument
/// and result types.
FunctionType clone(TypeRange inputs, TypeRange results) const;
}] # Name # "Type" # [{ clone(TypeRange inputs, TypeRange results) const;

/// Returns a new function type with the specified arguments and results
/// Returns a new function-like type with the specified arguments and results
/// inserted.
FunctionType getWithArgsAndResults(ArrayRef<unsigned> argIndices,
}] # Name # "Type" # [{ getWithArgsAndResults(ArrayRef<unsigned> argIndices,
TypeRange argTypes,
ArrayRef<unsigned> resultIndices,
TypeRange resultTypes);

/// Returns a new function type without the specified arguments and results.
FunctionType getWithoutArgsAndResults(const BitVector &argIndices,
/// Returns a new function-like type without the specified arguments and results.
}] # Name # "Type" # [{ getWithoutArgsAndResults(const BitVector &argIndices,
const BitVector &resultIndices);
}];
}

def Builtin_Function : Builtin_FunctionLike<"Function", "function">;
def Builtin_Graph : Builtin_FunctionLike<"Graph", "graph">;

//===----------------------------------------------------------------------===//
// IndexType
//===----------------------------------------------------------------------===//
Expand Down
7 changes: 7 additions & 0 deletions mlir/include/mlir/IR/CommonTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,13 @@ class OpaqueType<string dialect, string name, string summary>
def FunctionType : Type<CPred<"::llvm::isa<::mlir::FunctionType>($_self)">,
"function type", "::mlir::FunctionType">;

// Graph Type

// Any graph type.
def GraphType : Type<CPred<"::llvm::isa<::mlir::GraphType>($_self)">,
"graph type", "::mlir::GraphType">;


// A container type is a type that has another type embedded within it.
class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
string descr, string cppType = "::mlir::Type"> :
Expand Down
15 changes: 11 additions & 4 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1019,8 +1019,15 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
return verifyRegionAttribute(op->getLoc(), argType, attribute);
}

LogicalResult SPIRVDialect::verifyRegionResultAttribute(
Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
NamedAttribute attribute) {
return op->emitError("cannot attach SPIR-V attributes to region result");
LogicalResult
SPIRVDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
unsigned resultIndex,
NamedAttribute attribute) {
auto funcOp = dyn_cast<FunctionOpInterface>(op);
if (!funcOp)
return op->emitError(
"cannot attach SPIR-V attributes to region result which is "
"not a FunctionOpInterface type");
return verifyRegionAttribute(op->getLoc(),
funcOp.getResultTypes()[resultIndex], attribute);
}
Loading