Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for enum & serializable enum types #65

Open
wants to merge 2 commits into
base: mlir-struct
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions include/p4mlir/Dialect/P4HIR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,19 @@ mlir_tablegen(P4HIR_Ops.h.inc -gen-op-decls)
mlir_tablegen(P4HIR_Ops.cpp.inc -gen-op-defs)
mlir_tablegen(P4HIR_Types.h.inc -gen-typedef-decls -typedefs-dialect=p4hir)
mlir_tablegen(P4HIR_Types.cpp.inc -gen-typedef-defs -typedefs-dialect=p4hir)
add_public_tablegen_target(P4MLIR_P4HIR_IncGen)
add_dependencies(mlir-headers P4MLIR_P4HIR_IncGen)

# Generate extra headers for custom enum and attrs.
mlir_tablegen(P4HIR_OpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(P4HIR_OpsEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(P4HIR_Attrs.h.inc -gen-attrdef-decls -attrdefs-dialect=p4hir)
mlir_tablegen(P4HIR_Attrs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=p4hir)
add_public_tablegen_target(P4MLIR_P4HIR_AttrIncGen)
add_dependencies(mlir-headers P4MLIR_P4HIR_AttrIncGen)

add_public_tablegen_target(P4MLIR_P4HIR_IncGen)
add_dependencies(mlir-headers P4MLIR_P4HIR_IncGen)
mlir_tablegen(P4HIR_OpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(P4HIR_OpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(P4MLIR_P4HIR_EnumIncGen)
add_dependencies(mlir-headers P4MLIR_P4HIR_EnumIncGen)

set(LLVM_TARGET_DEFINITIONS P4HIR_TypeInterfaces.td)
mlir_tablegen(P4HIR_TypeInterfaces.h.inc -gen-type-interface-decls)
Expand Down
1 change: 1 addition & 0 deletions include/p4mlir/Dialect/P4HIR/P4HIR.td
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define P4MLIR_DIALECT_P4HIR_P4HIR_TD

include "p4mlir/Dialect/P4HIR/P4HIR_Dialect.td"
include "p4mlir/Dialect/P4HIR/P4HIR_Attrs.td"
include "p4mlir/Dialect/P4HIR/P4HIR_Ops.td"
include "p4mlir/Dialect/P4HIR/P4HIR_Types.td"
include "p4mlir/Dialect/P4HIR/P4HIR_TypeInterfaces.td"
Expand Down
28 changes: 28 additions & 0 deletions include/p4mlir/Dialect/P4HIR/P4HIR_Attrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,34 @@ def P4HIR_AggAttr : P4HIR_Attr<"Agg", "aggregate", [TypedAttrInterface]> {

}

//===----------------------------------------------------------------------===//
// EnumFieldAttr
//===----------------------------------------------------------------------===//
// An attribute to indicate an enumeration value.
def P4HIR_EnumFieldAttr : P4HIR_Attr<"EnumField", "enum.field", [TypedAttrInterface]> {
let summary = "Enumeration field attribute";
let description = [{
This attribute represents a field of an enumeration.

Examples:
```mlir
#p4hir.enum.field<A, !p4hir.enum<"name", A, B, C>> : !p4hir.enum<"name", A, B, C>
```
}];
let parameters = (ins AttributeSelfTypeParameter<"">:$type, "::mlir::StringAttr":$field);

// Force all clients to go through custom builder so we can check
// whether the requested enum value is part of the provided enum type.
let skipDefaultBuilders = 1;
let hasCustomAssemblyFormat = 1;

let builders = [
AttrBuilderWithInferredContext<(ins "mlir::Type":$type, "mlir::StringAttr": $value)>,
AttrBuilderWithInferredContext<(ins "mlir::Type":$type, "mlir::StringRef": $value), [{
return $_get(type.getContext(), type, mlir::StringAttr::get(type.getContext(), value));
}]>
];
}

//===----------------------------------------------------------------------===//
// ParamDirAttr
Expand Down
65 changes: 64 additions & 1 deletion include/p4mlir/Dialect/P4HIR/P4HIR_Types.td
Original file line number Diff line number Diff line change
Expand Up @@ -251,14 +251,77 @@ def StructType : P4HIR_Type<"Struct", "struct", [
}];
}

//===----------------------------------------------------------------------===//
// EnumType & SerEnumType
//===----------------------------------------------------------------------===//
def EnumType : P4HIR_Type<"Enum", "enum", []> {
let summary = "enum type";
let description = [{
Represents an enumeration of values
!p4hir.enum<"name", Case1, Case2>
}];

let hasCustomAssemblyFormat = 1;

let parameters = (ins StringRefParameter<"enum name">:$name,
"mlir::ArrayAttr":$fields);

let extraClassDeclaration = [{
/// Returns true if the requested field is part of this enum
bool contains(mlir::StringRef field) { return indexOf(field).has_value(); }

/// Returns the index of the requested field, or a nullopt if the field is
/// not part of this enum.
std::optional<size_t> indexOf(mlir::StringRef field);
}];
}

def SerEnumType : P4HIR_Type<"SerEnum", "ser.enum", []> {
let summary = "serializable enum type";
let description = [{
Represents an enumeration of values backed by some integer value
!p4hir.ser.enum<"name", !p4hir.bit<32>, Case1 : 42, Case2 : 0>
}];

let hasCustomAssemblyFormat = 1;

let parameters = (ins StringRefParameter<"enum name">:$name,
"P4HIR::BitsType":$type, "mlir::DictionaryAttr":$fields);

let builders = [
TypeBuilderWithInferredContext<(ins "llvm::StringRef":$name,
"P4HIR::BitsType":$type, "mlir::DictionaryAttr":$fields), [{
return $_get(type.getContext(), name, type, fields);
}]>,
TypeBuilderWithInferredContext<(ins "llvm::StringRef":$name,
"P4HIR::BitsType":$type, "llvm::ArrayRef<mlir::NamedAttribute>":$fields), [{
return $_get(type.getContext(), name, type,
DictionaryAttr::get(type.getContext(), fields));
}]>

];

let extraClassDeclaration = [{
/// Returns true if the requested field is part of this enum
bool contains(mlir::StringRef field) { return getFields().contains(field); }

/// Returns the underlying value of the requested field. Must be BitsAttr.
mlir::Attribute valueOf(mlir::StringRef field) { return getFields().get(field); }
}];
}


//===----------------------------------------------------------------------===//
// P4HIR type constraints.
//===----------------------------------------------------------------------===//

def AnyP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, StructType,
EnumType, SerEnumType,
DontcareType, ErrorType, UnknownType]> {}
def CallResultP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, VoidType]> {}
def LoadableP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, StructType]> {}
def LoadableP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, StructType,
EnumType, SerEnumType]> {}
def AnyEnumType : AnyTypeOf<[EnumType, SerEnumType]>;

/// A ref type with the specified constraints on the nested type.
class SpecificRefType<Type type> : ConfinedType<ReferenceType,
Expand Down
30 changes: 30 additions & 0 deletions lib/Dialect/P4HIR/P4HIR_Attrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,36 @@ LogicalResult IntAttr::verify(function_ref<InFlightDiagnostic()> emitError, Type
return success();
}

Attribute EnumFieldAttr::parse(AsmParser &p, Type) {
StringRef field;
P4HIR::EnumType type;
if (p.parseLess() || p.parseKeyword(&field) || p.parseComma() ||
p.parseCustomTypeWithFallback<P4HIR::EnumType>(type) || p.parseGreater())
return {};

return EnumFieldAttr::get(type, field);
}

void EnumFieldAttr::print(AsmPrinter &p) const {
p << "<" << getField().getValue() << ", ";
p.printType(getType());
p << ">";
}

EnumFieldAttr EnumFieldAttr::get(mlir::Type type, StringAttr value) {
EnumType enumType = llvm::dyn_cast<EnumType>(type);
if (!enumType) return nullptr;

// Check whether the provided value is a member of the enum type.
if (!enumType.contains(value.getValue())) {
// emitError() << "enum value '" << value.getValue()
// << "' is not a member of enum type " << enumType;
return nullptr;
}

return Base::get(value.getContext(), type, value);
}

void P4HIRDialect::registerAttributes() {
addAttributes<
#define GET_ATTRDEF_LIST
Expand Down
38 changes: 38 additions & 0 deletions lib/Dialect/P4HIR/P4HIR_Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
return success();
}

if (mlir::isa<P4HIR::EnumFieldAttr>(attrType)) {
if (!mlir::isa<P4HIR::EnumType, P4HIR::SerEnumType>(opType))
return op->emitOpError("result type (") << opType << ") is not an enum type";

return success();
}

assert(isa<TypedAttr>(attrType) && "expected typed attribute");
return op->emitOpError("constant with type ")
<< cast<TypedAttr>(attrType).getType() << " not supported";
Expand Down Expand Up @@ -73,6 +80,17 @@ void P4HIR::ConstOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
setNameFn(getResult(), specialName.str());
} else if (auto boolCst = mlir::dyn_cast<P4HIR::BoolAttr>(getValue())) {
setNameFn(getResult(), boolCst.getValue() ? "true" : "false");
} else if (auto enumCst = mlir::dyn_cast<P4HIR::EnumFieldAttr>(getValue())) {
llvm::SmallString<32> specialNameBuffer;
llvm::raw_svector_ostream specialName(specialNameBuffer);
if (auto enumType = mlir::dyn_cast<P4HIR::EnumType>(enumCst.getType()))
specialName << enumType.getName() << '_' << enumCst.getField().getValue();
else {
specialName << mlir::cast<P4HIR::SerEnumType>(enumCst.getType()).getName() << '_'
<< enumCst.getField().getValue();
}

setNameFn(getResult(), specialName.str());
} else {
setNameFn(getResult(), "cst");
}
Expand Down Expand Up @@ -838,6 +856,16 @@ struct P4HIROpAsmDialectInterface : public OpAsmDialectInterface {
return AliasResult::OverridableAlias;
}

if (auto enumType = mlir::dyn_cast<P4HIR::EnumType>(type)) {
os << enumType.getName();
return AliasResult::OverridableAlias;
}

if (auto serEnumType = mlir::dyn_cast<P4HIR::SerEnumType>(type)) {
os << serEnumType.getName();
return AliasResult::OverridableAlias;
}

return AliasResult::NoAlias;
}

Expand All @@ -862,6 +890,16 @@ struct P4HIROpAsmDialectInterface : public OpAsmDialectInterface {
return AliasResult::FinalAlias;
}

if (auto enumFieldAttr = mlir::dyn_cast<P4HIR::EnumFieldAttr>(attr)) {
if (auto enumType = mlir::dyn_cast<P4HIR::EnumType>(enumFieldAttr.getType()))
os << enumType.getName() << "_" << enumFieldAttr.getField().getValue();
else
os << mlir::cast<P4HIR::SerEnumType>(enumFieldAttr.getType()).getName() << "_"
<< enumFieldAttr.getField().getValue();

return AliasResult::FinalAlias;
}

return AliasResult::NoAlias;
}
};
Expand Down
85 changes: 85 additions & 0 deletions lib/Dialect/P4HIR/P4HIR_Types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "p4mlir/Dialect/P4HIR/P4HIR_Attrs.h"
#include "p4mlir/Dialect/P4HIR/P4HIR_Dialect.h"
#include "p4mlir/Dialect/P4HIR/P4HIR_OpsEnums.h"

Expand Down Expand Up @@ -315,6 +317,89 @@ void StructType::getInnerTypes(SmallVectorImpl<Type> &types) {
for (const auto &field : getElements()) types.push_back(field.type);
}

Type EnumType::parse(AsmParser &p) {
std::string name;
llvm::SmallVector<Attribute> fields;
bool parsedName = false;
if (p.parseCommaSeparatedList(AsmParser::Delimiter::LessGreater, [&]() {
// First, try to parse name
if (!parsedName) {
if (p.parseKeywordOrString(&name)) return failure();
parsedName = true;
return success();
}

StringRef caseName;
if (p.parseKeyword(&caseName)) return failure();
fields.push_back(StringAttr::get(p.getContext(), name));
return success();
}))
return {};

return get(p.getContext(), name, ArrayAttr::get(p.getContext(), fields));
}

void EnumType::print(AsmPrinter &p) const {
auto fields = getFields();
p << '<';
p.printString(getName());
if (!fields.empty()) p << ", ";
llvm::interleaveComma(fields, p, [&](Attribute enumerator) {
p << mlir::cast<StringAttr>(enumerator).getValue();
});
p << ">";
}

std::optional<size_t> EnumType::indexOf(mlir::StringRef field) {
for (auto it : llvm::enumerate(getFields()))
if (mlir::cast<StringAttr>(it.value()).getValue() == field) return it.index();
return {};
}

void SerEnumType::print(AsmPrinter &p) const {
auto fields = getFields();
p << '<';
p.printString(getName());
p << ", ";
p.printType(getType());
if (!fields.empty()) p << ", ";
llvm::interleaveComma(fields, p, [&](NamedAttribute enumerator) {
p.printKeywordOrString(enumerator.getName());
p << " : ";
p.printAttribute(enumerator.getValue());
});
p << ">";
}

Type SerEnumType::parse(AsmParser &p) {
std::string name;
llvm::SmallVector<NamedAttribute> fields;
P4HIR::BitsType type;

// Parse "<name, type, " part
if (p.parseLess() || p.parseKeywordOrString(&name) || p.parseComma() ||
p.parseCustomTypeWithFallback<P4HIR::BitsType>(type) || p.parseComma())
return {};

if (p.parseCommaSeparatedList([&]() {
StringRef caseName;
P4HIR::IntAttr attr;
// Parse fields "name : #value"
if (p.parseKeyword(&caseName) || p.parseColon() ||
p.parseCustomAttributeWithFallback<P4HIR::IntAttr>(attr))
return failure();

fields.emplace_back(StringAttr::get(p.getContext(), caseName), attr);
return success();
}))
return {};

// Parse closing >
if (p.parseGreater()) return {};

return get(name, type, fields);
}

void P4HIRDialect::registerTypes() {
addTypes<
#define GET_TYPEDEF_LIST
Expand Down
11 changes: 11 additions & 0 deletions test/Dialect/P4HIR/enum.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// RUN: p4mlir-opt %s | FileCheck %s

!Suits = !p4hir.enum<"Suits", Clubs, Diamonds, Hearths, Spades>

#Suits_Clubs = #p4hir.enum.field<Clubs, !Suits> : !Suits
#Suits_Diamonds = #p4hir.enum.field<Diamonds, !Suits> : !Suits

// CHECK: module
module {
%Suits_Diamonds = p4hir.const #Suits_Diamonds
}
9 changes: 9 additions & 0 deletions test/Dialect/P4HIR/types.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@
!struct = !p4hir.struct<"struct_name", boolfield : !p4hir.bool, bitfield : !bit42>
!nested_struct = !p4hir.struct<"another_name", neststructfield : !struct, bitfield : !bit42>

!Suits = !p4hir.enum<"Suits", Clubs, Diamonds, Hearths, Spades>

#b1 = #p4hir.int<1> : !bit42
#b2 = #p4hir.int<2> : !bit42
#b3 = #p4hir.int<3> : !bit42
#b4 = #p4hir.int<4> : !bit42

!SuitsSerializable = !p4hir.ser.enum<"Suits", !bit42, Clubs : #b1, Diamonds : #b2, Hearths : #b3, Spades : #b4>

// No need to check stuff. If it parses, it's fine.
// CHECK: module
module {
Expand Down
Loading