Skip to content

Commit

Permalink
Add serializable enum
Browse files Browse the repository at this point in the history
Signed-off-by: Anton Korobeynikov <[email protected]>
  • Loading branch information
asl committed Feb 11, 2025
1 parent caf9f5f commit b07eaf9
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 21 deletions.
50 changes: 43 additions & 7 deletions include/p4mlir/Dialect/P4HIR/P4HIR_Types.td
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def StructType : P4HIR_Type<"Struct", "struct", [
}

//===----------------------------------------------------------------------===//
// EnumType
// EnumType & SerEnumType
//===----------------------------------------------------------------------===//
def EnumType : P4HIR_Type<"Enum", "enum", []> {
let summary = "enum type";
Expand All @@ -263,29 +263,65 @@ def EnumType : P4HIR_Type<"Enum", "enum", []> {

let hasCustomAssemblyFormat = 1;

let parameters = (
ins StringRefParameter<"enum name">:$name, "mlir::ArrayAttr":$fields);
let parameters = (ins StringRefParameter<"enum name">:$name,
"mlir::ArrayAttr":$fields);

let extraClassDeclaration = [{
/// Returns true if the requested field is part of this enum
bool contains(mlir::StringRef field);
bool contains(mlir::StringRef field) { return indexOf(field).has_value(); }

/// Returns the index of the requested field, or a nullopt if the field is
// not part of this enum.
/// not part of this enum.
std::optional<size_t> indexOf(mlir::StringRef field);
}];
}

def SerEnumType : P4HIR_Type<"SerEnum", "ser.enum", []> {
let summary = "serializable enum type";
let description = [{
Represents an enumeration of values backed by some integer value
!p4hir.ser.enum<"name", !p4hir.bit<32>, Case1 : 42, Case2 : 0>
}];

let hasCustomAssemblyFormat = 1;

let parameters = (ins StringRefParameter<"enum name">:$name,
"P4HIR::BitsType":$type, "mlir::DictionaryAttr":$fields);

let builders = [
TypeBuilderWithInferredContext<(ins "llvm::StringRef":$name,
"P4HIR::BitsType":$type, "mlir::DictionaryAttr":$fields), [{
return $_get(type.getContext(), name, type, fields);
}]>,
TypeBuilderWithInferredContext<(ins "llvm::StringRef":$name,
"P4HIR::BitsType":$type, "llvm::ArrayRef<mlir::NamedAttribute>":$fields), [{
return $_get(type.getContext(), name, type,
DictionaryAttr::get(type.getContext(), fields));
}]>

];

let extraClassDeclaration = [{
/// Returns true if the requested field is part of this enum
bool contains(mlir::StringRef field) { return getFields().contains(field); }

/// Returns the underlying value of the requested field. Must be BitsAttr.
mlir::Attribute valueOf(mlir::StringRef field) { return getFields().get(field); }
}];
}


//===----------------------------------------------------------------------===//
// P4HIR type constraints.
//===----------------------------------------------------------------------===//

def AnyP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, StructType,
EnumType,
EnumType, SerEnumType,
DontcareType, ErrorType, UnknownType]> {}
def CallResultP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, VoidType]> {}
def LoadableP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, StructType, EnumType]> {}
def LoadableP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, StructType,
EnumType, SerEnumType]> {}
def AnyEnumType : AnyTypeOf<[EnumType, SerEnumType]>;

/// A ref type with the specified constraints on the nested type.
class SpecificRefType<Type type> : ConfinedType<ReferenceType,
Expand Down
24 changes: 19 additions & 5 deletions lib/Dialect/P4HIR/P4HIR_Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
}

if (mlir::isa<P4HIR::EnumFieldAttr>(attrType)) {
if (!mlir::isa<P4HIR::EnumType>(opType))
if (!mlir::isa<P4HIR::EnumType, P4HIR::SerEnumType>(opType))
return op->emitOpError("result type (") << opType << ") is not an enum type";

return success();
Expand Down Expand Up @@ -83,8 +83,13 @@ void P4HIR::ConstOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
} else if (auto enumCst = mlir::dyn_cast<P4HIR::EnumFieldAttr>(getValue())) {
llvm::SmallString<32> specialNameBuffer;
llvm::raw_svector_ostream specialName(specialNameBuffer);
specialName << mlir::cast<P4HIR::EnumType>(enumCst.getType()).getName() << '_'
<< enumCst.getField().getValue();
if (auto enumType = mlir::dyn_cast<P4HIR::EnumType>(enumCst.getType()))
specialName << enumType.getName() << '_' << enumCst.getField().getValue();
else {
specialName << mlir::cast<P4HIR::SerEnumType>(enumCst.getType()).getName() << '_'
<< enumCst.getField().getValue();
}

setNameFn(getResult(), specialName.str());
} else {
setNameFn(getResult(), "cst");
Expand Down Expand Up @@ -856,6 +861,11 @@ struct P4HIROpAsmDialectInterface : public OpAsmDialectInterface {
return AliasResult::OverridableAlias;
}

if (auto serEnumType = mlir::dyn_cast<P4HIR::SerEnumType>(type)) {
os << serEnumType.getName();
return AliasResult::OverridableAlias;
}

return AliasResult::NoAlias;
}

Expand All @@ -881,8 +891,12 @@ struct P4HIROpAsmDialectInterface : public OpAsmDialectInterface {
}

if (auto enumFieldAttr = mlir::dyn_cast<P4HIR::EnumFieldAttr>(attr)) {
os << mlir::cast<P4HIR::EnumType>(enumFieldAttr.getType()).getName() << "_"
<< enumFieldAttr.getField().getValue();
if (auto enumType = mlir::dyn_cast<P4HIR::EnumType>(enumFieldAttr.getType()))
os << enumType.getName() << "_" << enumFieldAttr.getField().getValue();
else
os << mlir::cast<P4HIR::SerEnumType>(enumFieldAttr.getType()).getName() << "_"
<< enumFieldAttr.getField().getValue();

return AliasResult::FinalAlias;
}

Expand Down
48 changes: 46 additions & 2 deletions lib/Dialect/P4HIR/P4HIR_Types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "p4mlir/Dialect/P4HIR/P4HIR_Attrs.h"
#include "p4mlir/Dialect/P4HIR/P4HIR_Dialect.h"
#include "p4mlir/Dialect/P4HIR/P4HIR_OpsEnums.h"

Expand Down Expand Up @@ -347,14 +349,56 @@ void EnumType::print(AsmPrinter &p) const {
p << ">";
}

bool EnumType::contains(mlir::StringRef field) { return indexOf(field).has_value(); }

std::optional<size_t> EnumType::indexOf(mlir::StringRef field) {
for (auto it : llvm::enumerate(getFields()))
if (mlir::cast<StringAttr>(it.value()).getValue() == field) return it.index();
return {};
}

void SerEnumType::print(AsmPrinter &p) const {
auto fields = getFields();
p << '<';
p.printString(getName());
p << ", ";
p.printType(getType());
if (!fields.empty()) p << ", ";
llvm::interleaveComma(fields, p, [&](NamedAttribute enumerator) {
p.printKeywordOrString(enumerator.getName());
p << " : ";
p.printAttribute(enumerator.getValue());
});
p << ">";
}

Type SerEnumType::parse(AsmParser &p) {
std::string name;
llvm::SmallVector<NamedAttribute> fields;
P4HIR::BitsType type;

// Parse "<name, type, " part
if (p.parseLess() || p.parseKeywordOrString(&name) || p.parseComma() ||
p.parseCustomTypeWithFallback<P4HIR::BitsType>(type) || p.parseComma())
return {};

if (p.parseCommaSeparatedList([&]() {
StringRef caseName;
P4HIR::IntAttr attr;
// Parse fields "name : #value"
if (p.parseKeyword(&caseName) || p.parseColon() ||
p.parseCustomAttributeWithFallback<P4HIR::IntAttr>(attr))
return failure();

fields.emplace_back(StringAttr::get(p.getContext(), caseName), attr);
return success();
}))
return {};

// Parse closing >
if (p.parseGreater()) return {};

return get(name, type, fields);
}

void P4HIRDialect::registerTypes() {
addTypes<
#define GET_TYPEDEF_LIST
Expand Down
7 changes: 7 additions & 0 deletions test/Dialect/P4HIR/types.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@

!Suits = !p4hir.enum<"Suits", Clubs, Diamonds, Hearths, Spades>

#b1 = #p4hir.int<1> : !bit42
#b2 = #p4hir.int<2> : !bit42
#b3 = #p4hir.int<3> : !bit42
#b4 = #p4hir.int<4> : !bit42

!SuitsSerializable = !p4hir.ser.enum<"Suits", !bit42, Clubs : #b1, Diamonds : #b2, Hearths : #b3, Spades : #b4>

// No need to check stuff. If it parses, it's fine.
// CHECK: module
module {
Expand Down
35 changes: 35 additions & 0 deletions test/Translate/Ops/serenum.p4
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// RUN: p4mlir-translate --typeinference-only %s | FileCheck %s

enum bit<16> EthTypes {
IPv4 = 0x0800,
ARP = 0x0806,
RARP = 0x8035,
EtherTalk = 0x809B,
VLAN = 0x8100,
IPX = 0x8137,
IPv6 = 0x86DD
}

struct Ethernet {
bit<48> src;
bit<48> dest;
EthTypes type;
}

struct Headers {
Ethernet eth;
}

// CHECK: !EthTypes = !p4hir.ser.enum<"EthTypes", !b16i, ARP : #int2054_b16i, EtherTalk : #int-32613_b16i, IPX : #int-32457_b16i, IPv4 : #int2048_b16i, IPv6 : #int-31011_b16i, RARP : #int-32715_b16i, VLAN : #int-32512_b16i>
// CHECK: !Ethernet = !p4hir.struct<"Ethernet", src: !b48i, dest: !b48i, type: !EthTypes>
// CHECK: #EthTypes_IPv4_ = #p4hir.enum.field<IPv4, !EthTypes> : !EthTypes
// CHECK-LABEL: module

// CHECK-LABEL: p4hir.func action @test(%arg0: !p4hir.ref<!Headers>
// CHECK: p4hir.const #EthTypes_IPv4_
action test(inout Headers h) {
if (h.eth.type == EthTypes.IPv4)
h.eth.src = h.eth.dest;
else
h.eth.type = (EthTypes)(bit<16>)0;
}
33 changes: 26 additions & 7 deletions tools/p4mlir-translate/translate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
#include <algorithm>
#include <climits>

#include "ir/ir-generated.h"

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wcovered-switch-default"
#include "frontends/common/resolveReferences/resolveReferences.h"
Expand Down Expand Up @@ -136,6 +134,7 @@ class P4TypeConverter : public P4::Inspector {
bool preorder(const P4::IR::Type_Void *v) override;
bool preorder(const P4::IR::Type_Struct *s) override;
bool preorder(const P4::IR::Type_Enum *e) override;
bool preorder(const P4::IR::Type_SerEnum *se) override;

mlir::Type getType() const { return type; }
bool setType(const P4::IR::Type *type, mlir::Type mlirType);
Expand Down Expand Up @@ -484,12 +483,29 @@ bool P4TypeConverter::preorder(const P4::IR::Type_Enum *type) {
if ((this->type = converter.findType(type))) return false;

ConversionTracer trace("TypeConverting ", type);
llvm::SmallVector<mlir::Attribute, 4> fields;
llvm::SmallVector<mlir::Attribute, 4> cases;
for (const auto *field : type->members) {
fields.push_back(mlir::StringAttr::get(converter.context(), field->name.string_view()));
cases.push_back(mlir::StringAttr::get(converter.context(), field->name.string_view()));
}
auto mlirType = P4HIR::EnumType::get(converter.context(), type->name.string_view(),
mlir::ArrayAttr::get(converter.context(), fields));
mlir::ArrayAttr::get(converter.context(), cases));
return setType(type, mlirType);
}

bool P4TypeConverter::preorder(const P4::IR::Type_SerEnum *type) {
if ((this->type = converter.findType(type))) return false;

ConversionTracer trace("TypeConverting ", type);
llvm::SmallVector<mlir::NamedAttribute, 4> cases;

auto enumType = mlir::cast<P4HIR::BitsType>(convert(type->type));
for (const auto *field : type->members) {
auto value = mlir::cast<P4HIR::IntAttr>(converter.getOrCreateConstantExpr(field->value));
cases.emplace_back(mlir::StringAttr::get(converter.context(), field->name.string_view()),
value);
}

auto mlirType = P4HIR::SerEnumType::get(type->name.string_view(), enumType, cases);
return setType(type, mlirType);
}

Expand Down Expand Up @@ -1084,10 +1100,13 @@ bool P4HIRConverter::preorder(const P4::IR::MethodCallExpression *mce) {
bool P4HIRConverter::preorder(const P4::IR::Member *m) {
// This is just enum constant
if (const auto *typeNameExpr = m->expr->to<P4::IR::TypeNameExpression>()) {
auto enumType = mlir::cast<P4HIR::EnumType>(getOrCreateType(typeNameExpr->typeName));
auto type = getOrCreateType(typeNameExpr->typeName);
BUG_CHECK((mlir::isa<P4HIR::EnumType, P4HIR::SerEnumType>(type)),
"unexpected type for expression %1%", typeNameExpr);

setValue(m, builder.create<P4HIR::ConstOp>(
getLoc(builder, m),
P4HIR::EnumFieldAttr::get(enumType, m->member.name.string_view())));
P4HIR::EnumFieldAttr::get(type, m->member.name.string_view())));
return false;
}

Expand Down

0 comments on commit b07eaf9

Please sign in to comment.