Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for headers #66

Open
wants to merge 2 commits into
base: mlir-enum
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions include/p4mlir/Dialect/P4HIR/P4HIR_Attrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,27 @@ def P4HIR_EnumFieldAttr : P4HIR_Attr<"EnumField", "enum.field", [TypedAttrInterf
];
}

//===----------------------------------------------------------------------===//
// ValidAttr
//===----------------------------------------------------------------------===//

def ValidityBit_Invalid: I32BitEnumAttrCaseNone<"Invalid", "invalid">;
def ValidityBit_Valid : I32BitEnumAttrCaseBit<"Valid", 0, "valid">;

def ValidityBit : I32BitEnumAttr<
"ValidityBit",
"validity of a header",
[ValidityBit_Invalid, ValidityBit_Valid]> {
let cppNamespace = "::P4::P4MLIR::P4HIR";
let genSpecializedAttr = 0;
}
def ValidityBitAttr : EnumAttr<P4HIR_Dialect, ValidityBit, "validity.bit",
[TypedAttrInterface]> {
let extraClassDeclaration = [{
mlir::Type getType() { return P4HIR::ValidBitType::get(getContext()); }
}];
}

//===----------------------------------------------------------------------===//
// ParamDirAttr
//===----------------------------------------------------------------------===//
Expand Down
18 changes: 9 additions & 9 deletions include/p4mlir/Dialect/P4HIR/P4HIR_Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@ def StructOp : P4HIR_Op<"struct",
let summary = "Create a struct from constituent parts.";
// FIXME: Better constraint type
let arguments = (ins Variadic<AnyP4Type>:$input);
let results = (outs StructType:$result);
let results = (outs StructLikeType:$result);
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
Expand All @@ -841,14 +841,14 @@ def StructExtractOp : P4HIR_Op<"struct_extract",
```
}];

let arguments = (ins StructType:$input, I32Attr:$fieldIndex);
let arguments = (ins StructLikeType:$input, I32Attr:$fieldIndex);
// FIXME: Better constraint type
let results = (outs AnyP4Type:$result);
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;

let builders = [
OpBuilder<(ins "mlir::Value":$input, "StructType::FieldInfo":$field)>,
OpBuilder<(ins "mlir::Value":$input, "P4HIR::FieldInfo":$field)>,
OpBuilder<(ins "mlir::Value":$input, "mlir::StringAttr":$fieldName)>,
OpBuilder<(ins "mlir::Value":$input, "llvm::StringRef":$fieldName), [{
build($_builder, $_state, input, $_builder.getStringAttr(fieldName));
Expand All @@ -858,8 +858,8 @@ def StructExtractOp : P4HIR_Op<"struct_extract",
let extraClassDeclaration = [{
/// Return the name attribute of the accessed field.
mlir::StringAttr getFieldNameAttr() {
StructType type = getInput().getType();
return type.getElements()[getFieldIndex()].name;
auto type = mlir::cast<P4HIR::StructLikeTypeInterface>(getInput().getType());
return type.getFields()[getFieldIndex()].name;
}

/// Return the name of the accessed field.
Expand All @@ -883,13 +883,13 @@ def StructExtractRefOp : P4HIR_Op<"struct_extract_ref",
```
}];

let arguments = (ins StructRefType:$input, I32Attr:$fieldIndex);
let arguments = (ins StructLikeRefType:$input, I32Attr:$fieldIndex);
let results = (outs ReferenceType:$result);
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;

let builders = [
OpBuilder<(ins "mlir::Value":$input, "StructType::FieldInfo":$field)>,
OpBuilder<(ins "mlir::Value":$input, "P4HIR::FieldInfo":$field)>,
OpBuilder<(ins "mlir::Value":$input, "mlir::StringAttr":$fieldName)>,
OpBuilder<(ins "mlir::Value":$input, "llvm::StringRef":$fieldName), [{
build($_builder, $_state, input, $_builder.getStringAttr(fieldName));
Expand All @@ -899,8 +899,8 @@ def StructExtractRefOp : P4HIR_Op<"struct_extract_ref",
let extraClassDeclaration = [{
/// Return the name attribute of the accessed field.
mlir::StringAttr getFieldNameAttr() {
auto type = mlir::cast<StructType>(mlir::cast<ReferenceType>(getInput().getType()).getObjectType());
return type.getElements()[getFieldIndex()].name;
auto type = mlir::cast<StructLikeTypeInterface>(mlir::cast<ReferenceType>(getInput().getType()).getObjectType());
return type.getFields()[getFieldIndex()].name;
}

/// Return the name of the accessed field.
Expand Down
8 changes: 8 additions & 0 deletions include/p4mlir/Dialect/P4HIR/P4HIR_TypeInterfaces.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
#ifndef P4MLIR_DIALECT_P4HIR_P4HIR_TYPEINTERFACES_H
#define P4MLIR_DIALECT_P4HIR_P4HIR_TYPEINTERFACES_H

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Types.h"

namespace P4::P4MLIR::P4HIR {

/// Struct defining a field. Used in structs and header
struct FieldInfo {
mlir::StringAttr name;
mlir::Type type;
};

namespace FieldIdImpl {
unsigned getMaxFieldID(::mlir::Type);

Expand Down
27 changes: 27 additions & 0 deletions include/p4mlir/Dialect/P4HIR/P4HIR_TypeInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,33 @@

include "mlir/IR/OpBase.td"

def StructLikeTypeInterface : TypeInterface<"StructLikeTypeInterface"> {
let description = [{
Common methods for struct-like types that could be viewed as a collection
of named fields
}];

let methods = [
InterfaceMethod<[{
Get the type of the field at a given index
}],
"mlir::Type", "getFieldType", (ins "mlir::StringRef":$fieldName)>,
InterfaceMethod<[{
Get the field index given the name in StringRef
}],
"std::optional<unsigned>", "getFieldIndex", (ins "mlir::StringRef":$fieldName)>,
InterfaceMethod<[{
Get the field given the name in StringRef
}],
"std::optional<FieldInfo>", "getField", (ins "mlir::StringRef":$fieldName)>,
InterfaceMethod<[{
Get all the fields.
}],
"llvm::ArrayRef<FieldInfo>", "getFields">
];
let cppNamespace = "::P4::P4MLIR::P4HIR";
}

def FieldIDTypeInterface : TypeInterface<"FieldIDTypeInterface"> {
let description = [{
Common methods for types which can be indexed by a FieldID.
Expand Down
11 changes: 0 additions & 11 deletions include/p4mlir/Dialect/P4HIR/P4HIR_Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,6 @@
#include "p4mlir/Dialect/P4HIR/P4HIR_OpsEnums.h"
#include "p4mlir/Dialect/P4HIR/P4HIR_TypeInterfaces.h"

namespace P4::P4MLIR::P4HIR {

namespace detail {
/// Struct defining a field. Used in structs.
struct FieldInfo {
mlir::StringAttr name;
mlir::Type type;
};
} // namespace detail
} // namespace P4::P4MLIR::P4HIR

#define GET_TYPEDEF_CLASSES
#include "p4mlir/Dialect/P4HIR/P4HIR_Types.h.inc"

Expand Down
195 changes: 176 additions & 19 deletions include/p4mlir/Dialect/P4HIR/P4HIR_Types.td
Original file line number Diff line number Diff line change
Expand Up @@ -223,31 +223,183 @@ def FuncType : P4HIR_Type<"Func", "func"> {
// StructType
//===----------------------------------------------------------------------===//

// A packed struct. Declares the P4HIR::StructType in C++.
def StructType : P4HIR_Type<"Struct", "struct", [
DeclareTypeInterfaceMethods<DestructurableTypeInterface>,
DeclareTypeInterfaceMethods<FieldIDTypeInterface>
class StructLikeType<string name, string typeMnemonic>
: P4HIR_Type<name, typeMnemonic, [
StructLikeTypeInterface,
DestructurableTypeInterface,
FieldIDTypeInterface
]> {
let hasCustomAssemblyFormat = 1;
let genVerifyDecl = 1;

let parameters = (
ins StringRefParameter<"struct name">:$name,
ArrayRefParameter<"P4HIR::FieldInfo", "struct fields">:$elements
);

string extra2ClassDeclaration = "";
let extraClassDeclaration = !strconcat([{
mlir::Type getFieldType(mlir::StringRef fieldName) {
for (const auto &field : getElements())
if (field.name == fieldName) return field.type;
return {};
}

std::optional<P4HIR::FieldInfo> getField(mlir::StringRef fieldName) {
for (const auto &field : getElements())
if (field.name == fieldName) return field;
return {};
}

llvm::ArrayRef<P4HIR::FieldInfo> getFields() {
return getElements();
}

void getInnerTypes(mlir::SmallVectorImpl<mlir::Type> &types) {
for (const auto &field : getElements()) types.push_back(field.type);
}

std::optional<unsigned> getFieldIndex(mlir::StringRef fieldName) {
llvm::ArrayRef<P4HIR::FieldInfo> elems = getElements();
for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx)
if (elems[idx].name == fieldName) return idx;
return {};
}

std::optional<unsigned> getFieldIndex(mlir::StringAttr fieldName) {
llvm::ArrayRef<P4HIR::FieldInfo> elems = getElements();
for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx)
if (elems[idx].name == fieldName) return idx;
return {};
}

// FieldID type interface implementation
std::pair<unsigned, llvm::SmallVector<unsigned>> getFieldIDsStruct() const {
unsigned fieldID = 0;
auto elements = getElements();
llvm::SmallVector<unsigned> fieldIDs;
fieldIDs.reserve(elements.size());
for (auto &element : elements) {
auto type = element.type;
fieldID += 1;
fieldIDs.push_back(fieldID);
// Increment the field ID for the next field by the number of subfields.
fieldID += FieldIdImpl::getMaxFieldID(type);
}
return {fieldID, fieldIDs};
}

std::pair<mlir::Type, unsigned> getSubTypeByFieldID(unsigned fieldID) const {
if (fieldID == 0) return {*this, 0};
auto [maxId, fieldIDs] = getFieldIDsStruct();
auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
auto subfieldIndex = std::distance(fieldIDs.begin(), it);
auto subfieldType = getElements()[subfieldIndex].type;
auto subfieldID = fieldID - fieldIDs[subfieldIndex];
return {subfieldType, subfieldID};
}

mlir::Type getTypeAtIndex(mlir::Attribute index) const {
auto indexAttr = mlir::dyn_cast<mlir::IntegerAttr>(index);
if (!indexAttr) return {};

return getSubTypeByFieldID(indexAttr.getInt()).first;
}

unsigned getFieldID(unsigned index) const {
auto [maxId, fieldIDs] = getFieldIDsStruct();
return fieldIDs[index];
}

unsigned getMaxFieldID() const {
unsigned fieldID = 0;
for (const auto &field : getElements()) fieldID += 1 + FieldIdImpl::getMaxFieldID(field.type);
return fieldID;
}

unsigned getIndexForFieldID(unsigned fieldID) const {
assert(!getElements().empty() && "struct must have >0 fields");
auto [maxId, fieldIDs] = getFieldIDsStruct();
auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
return std::distance(fieldIDs.begin(), it);
}

std::pair<unsigned, unsigned> getIndexAndSubfieldID(unsigned fieldID) const {
auto index = getIndexForFieldID(fieldID);
auto elementFieldID = getFieldID(index);
return {index, fieldID - elementFieldID};
}

std::pair<unsigned, bool> projectToChildFieldID(unsigned fieldID,
unsigned index) const {
auto [maxId, fieldIDs] = getFieldIDsStruct();
auto childRoot = fieldIDs[index];
auto rangeEnd = index + 1 >= getElements().size() ? maxId : (fieldIDs[index + 1] - 1);
return std::make_pair(fieldID - childRoot, fieldID >= childRoot && fieldID <= rangeEnd);
}

std::optional<llvm::DenseMap<mlir::Attribute, mlir::Type>> getSubelementIndexMap() const {
llvm::DenseMap<mlir::Attribute, mlir::Type> destructured;
for (auto [i, field] : llvm::enumerate(getElements()))
destructured.try_emplace(mlir::IntegerAttr::get(mlir::IndexType::get(getContext()), i), field.type);
return destructured;
}
}], "\n", extra2ClassDeclaration);
}

// A packed struct. Declares the P4HIR::StructType in C++.
def StructType : StructLikeType<"Struct", "struct"> {
let summary = "struct type";
let description = [{
Represents a structure of name, value pairs.
!p4hir.struct<"name", fieldName1: Type1, fieldName2: Type2>
}];
}

let hasCustomAssemblyFormat = 1;
let genVerifyDecl = 1;
//===----------------------------------------------------------------------===//
// ValidType: type used to represent validity bit in headers. We explicitly
// want it to be distinct from bool type to ensure we can always identify it
// by the type.
//===----------------------------------------------------------------------===//

let parameters = (
ins StringRefParameter<"struct name">:$name, ArrayRefParameter<
"P4HIR::StructType::FieldInfo", "struct fields">:$elements
);
def ValidBitType : P4HIR_Type<"ValidBit", "validity.bit"> {
let summary = "Valid bit type";
let description = [{
`p4hir.valid.bit` represents valid bit in headers.
}];

let hasCustomAssemblyFormat = 1;
let extraClassDeclaration = [{
using FieldInfo = P4HIR::detail::FieldInfo;
mlir::Type getFieldType(mlir::StringRef fieldName);
void getInnerTypes(mlir::SmallVectorImpl<mlir::Type>&);
std::optional<unsigned> getFieldIndex(mlir::StringRef fieldName);
std::optional<unsigned> getFieldIndex(mlir::StringAttr fieldName);
llvm::StringRef getAlias() const { return "validity_bit"; };
}];
}

//===----------------------------------------------------------------------===//
// HeaderType
//===----------------------------------------------------------------------===//

// A header. Declares the P4HIR::HeaderType in C++.
def HeaderType : StructLikeType<"Header", "header"> {
let summary = "header type";
let description = [{
Represents a structure of name, value pairs.
!p4hir.header<"name", fieldName1: Type1, fieldName2: Type2>

Special field named "__valid" of type !p4hir.validity.bit is used to
represent validity bit
}];

// We skip default builders entirely to consistently add validity bit field on fly
let skipDefaultBuilders = 1;

let builders = [
TypeBuilder<(ins "llvm::StringRef":$name,
"llvm::ArrayRef<P4HIR::FieldInfo>":$fields)>
];

// This adds more C++ stuff into parent extraClassDeclaration
let extra2ClassDeclaration = [{
static constexpr llvm::StringRef validityBit = "__valid";
}];
}

Expand Down Expand Up @@ -310,18 +462,22 @@ def SerEnumType : P4HIR_Type<"SerEnum", "ser.enum", []> {
}];
}


//===----------------------------------------------------------------------===//
// P4HIR type constraints.
//===----------------------------------------------------------------------===//

def AnyP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, StructType,
def AnyP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType,
StructType, HeaderType,
EnumType, SerEnumType,
ValidBitType,
DontcareType, ErrorType, UnknownType]> {}
def CallResultP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, VoidType]> {}
def LoadableP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, StructType,
EnumType, SerEnumType]> {}
def LoadableP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType,
StructType, HeaderType,
EnumType, SerEnumType,
ValidBitType]> {}
def AnyEnumType : AnyTypeOf<[EnumType, SerEnumType]>;
def StructLikeType : AnyTypeOf<[StructType, HeaderType]>;

/// A ref type with the specified constraints on the nested type.
class SpecificRefType<Type type> : ConfinedType<ReferenceType,
Expand All @@ -333,5 +489,6 @@ class SpecificRefType<Type type> : ConfinedType<ReferenceType,
}

def StructRefType : SpecificRefType<StructType>;
def StructLikeRefType : SpecificRefType<StructLikeType>;

#endif // P4MLIR_DIALECT_P4HIR_P4HIR_TYPES_TD
Loading