Skip to content

Commit

Permalink
Fix parsing of struct_extract_ref operands. Add tests
Browse files Browse the repository at this point in the history
Signed-off-by: Anton Korobeynikov <[email protected]>
  • Loading branch information
asl committed Feb 11, 2025
1 parent 275f5fb commit e5bcc01
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 11 deletions.
10 changes: 5 additions & 5 deletions include/p4mlir/Dialect/P4HIR/P4HIR_Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,9 @@ def CallOp : P4HIR_Op<"call",
}];
}

def StructOp : P4HIR_Op<"struct", [Pure]> {
def StructOp : P4HIR_Op<"struct",
[Pure,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let summary = "Create a struct from constituent parts.";
// FIXME: Better constraint type
let arguments = (ins Variadic<AnyP4Type>:$input);
Expand All @@ -825,10 +827,9 @@ def StructOp : P4HIR_Op<"struct", [Pure]> {
let hasVerifier = 1;
}

// Extract the value of a field of a structure.
def StructExtractOp : P4HIR_Op<"struct_extract",
[Pure,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
let summary = "Extract a named field from a struct.";
let description = [{
Expand Down Expand Up @@ -868,12 +869,11 @@ def StructExtractOp : P4HIR_Op<"struct_extract",
}];
}

// Extract the value of a field of a structure.
def StructExtractRefOp : P4HIR_Op<"struct_extract_ref",
[Pure,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
let summary = "Create a reference to a struct field";
let summary = "Project from a struct reference to a reference to a named struct field";
let description = [{
```
%result = p4hir.struct_extract_ref %input["field"] : <!p4hir.struct<field: type>>
Expand Down
49 changes: 43 additions & 6 deletions lib/Dialect/P4HIR/P4HIR_Ops.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "p4mlir/Dialect/P4HIR/P4HIR_Ops.h"

#include "llvm/ADT/SmallString.h"
#include "llvm/Support/LogicalResult.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
Expand Down Expand Up @@ -642,6 +643,12 @@ LogicalResult P4HIR::StructOp::verify() {
return success();
}

void P4HIR::StructOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
llvm::SmallString<32> name("struct_");
name += getType().getName();
setNameFn(getResult(), name);
}

//===----------------------------------------------------------------------===//
// StructExtractOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -673,15 +680,45 @@ template <typename AggregateType>
static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
OpAsmParser::UnresolvedOperand operand;
StringAttr fieldName;
Type declType;
AggregateType declType;

if (parser.parseOperand(operand) || parser.parseLSquare() || parser.parseAttribute(fieldName) ||
parser.parseRSquare() || parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColon() || parser.parseCustomTypeWithFallback<AggregateType>(declType))
return failure();

auto fieldIndex = declType.getFieldIndex(fieldName);
if (!fieldIndex) {
parser.emitError(parser.getNameLoc(),
"field name '" + fieldName.getValue() + "' not found in aggregate type");
return failure();
}

auto indexAttr = IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex);
result.addAttribute("fieldIndex", indexAttr);
Type resultType = declType.getElements()[*fieldIndex].type;
result.addTypes(resultType);

if (parser.resolveOperand(operand, declType, result.operands)) return failure();
return success();
}

template <typename AggregateType>
static ParseResult parseExtractRefOp(OpAsmParser &parser, OperationState &result) {
OpAsmParser::UnresolvedOperand operand;
StringAttr fieldName;
P4HIR::ReferenceType declType;

if (parser.parseOperand(operand) || parser.parseLSquare() || parser.parseAttribute(fieldName) ||
parser.parseRSquare() || parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(declType))
parser.parseColon() || parser.parseCustomTypeWithFallback<P4HIR::ReferenceType>(declType))
return failure();
auto aggType = mlir::dyn_cast<AggregateType>(declType);
if (!aggType) return parser.emitError(parser.getNameLoc(), "invalid kind of type specified");

auto aggType = mlir::dyn_cast<AggregateType>(declType.getObjectType());
if (!aggType) {
parser.emitError(parser.getNameLoc(), "expected reference to aggregate type");
return failure();
}
auto fieldIndex = aggType.getFieldIndex(fieldName);
if (!fieldIndex) {
parser.emitError(parser.getNameLoc(),
Expand All @@ -691,7 +728,7 @@ static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {

auto indexAttr = IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex);
result.addAttribute("fieldIndex", indexAttr);
Type resultType = aggType.getElements()[*fieldIndex].type;
Type resultType = P4HIR::ReferenceType::get(aggType.getElements()[*fieldIndex].type);
result.addTypes(resultType);

if (parser.resolveOperand(operand, declType, result.operands)) return failure();
Expand Down Expand Up @@ -745,7 +782,7 @@ void P4HIR::StructExtractRefOp::getAsmResultNames(function_ref<void(Value, Strin
}

ParseResult P4HIR::StructExtractRefOp::parse(OpAsmParser &parser, OperationState &result) {
return parseExtractOp<StructType>(parser, result);
return parseExtractRefOp<StructType>(parser, result);
}

void P4HIR::StructExtractRefOp::print(OpAsmPrinter &printer) { printExtractOp(printer, *this); }
Expand Down
37 changes: 37 additions & 0 deletions test/Dialect/P4HIR/struct.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// RUN: p4mlir-opt %s | FileCheck %s

!i32i = !p4hir.int<32>
!T = !p4hir.struct<"T", t1: !i32i, t2: !i32i>
!S = !p4hir.struct<"S", s1: !T, s2: !T>
!Empty = !p4hir.struct<"Empty">
!b9i = !p4hir.bit<9>
!PortId_t = !p4hir.struct<"PortId_t", _v: !b9i>

#int10_i32i = #p4hir.int<10> : !i32i
#int20_i32i = #p4hir.int<20> : !i32i
#int1_b9i = #p4hir.int<1> : !b9i

// CHECK: module
module {
%e = p4hir.const ["e"] #p4hir.aggregate<[]> : !Empty
%t = p4hir.const ["t"] #p4hir.aggregate<[#int10_i32i, #int20_i32i]> : !T

p4hir.func action @test2(%arg0: !p4hir.ref<!PortId_t> {p4hir.dir = #p4hir<dir inout>}) {
%_v = p4hir.struct_extract_ref %arg0["_v"] : <!PortId_t>
%val = p4hir.read %arg0 : <!PortId_t>
%_v_0 = p4hir.struct_extract %val["_v"] : !PortId_t
%c1_b9i = p4hir.const #int1_b9i
%add = p4hir.binop(add, %_v_0, %c1_b9i) : !b9i
p4hir.assign %add, %_v : <!b9i>
p4hir.return
}

p4hir.func action @test() {
%vv = p4hir.variable ["vv"] : <!b9i>
%val = p4hir.read %vv : <!b9i>
%0 = p4hir.struct (%val) : !PortId_t
%p1 = p4hir.variable ["p1", init] : <!PortId_t>
p4hir.assign %0, %p1 : <!PortId_t>
p4hir.return
}
}
111 changes: 111 additions & 0 deletions test/Translate/Ops/struct.p4
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// RUN: p4mlir-translate --typeinference-only %s | FileCheck %s

struct P {
bit<32> f1;
bit<32> f2;
}

struct T {
int<32> t1;
int<32> t2;
}

struct S {
T s1;
T s2;
}

struct Empty {};

// CHECK: !Empty = !p4hir.struct<"Empty">
// CHECK: !PortId_t = !p4hir.struct<"PortId_t", _v: !b9i>
// CHECK: !T = !p4hir.struct<"T", t1: !i32i, t2: !i32i>
// CHECK: !S = !p4hir.struct<"S", s1: !T, s2: !T>
// CHECK: !metadata_t = !p4hir.struct<"metadata_t", foo: !PortId_t>

// CHECK-LABEL: module

const T t = { 32s10, 32s20 };
const S s = { { 32s15, 32s25}, t };

// CHECK: %t = p4hir.const ["t"] #p4hir.aggregate<[#int10_i32i, #int20_i32i]> : !T
// CHECK: %s = p4hir.const ["s"] #p4hir.aggregate<[#p4hir.aggregate<[#int15_i32i, #int25_i32i]> : !T, #p4hir.aggregate<[#int10_i32i, #int20_i32i]> : !T]> : !S

const int<32> x = t.t1;
const int<32> y = s.s1.t2;

const int<32> w = .t.t1;

// CHECK: %x = p4hir.const ["x"] #int10_i32i
// CHECK: %y = p4hir.const ["y"] #int25_i32i
// CHECK: %w = p4hir.const ["w"] #int10_i32i

const T tt1 = s.s1;
const Empty e = {};

// CHECK: %tt1 = p4hir.const ["tt1"] #p4hir.aggregate<[#int15_i32i, #int25_i32i]> : !T
// CHECK: %e = p4hir.const ["e"] #p4hir.aggregate<[]> : !Empty

const T t1 = { 10, 20 };
const S s1 = { { 15, 25 }, t1 };

const int<32> x1 = t1.t1;
const int<32> y1 = s1.s1.t2;

const int<32> w1 = .t1.t1;

const T t2 = s1.s1;

struct PortId_t { bit<9> _v; }

const PortId_t PSA_CPU_PORT = { _v = 9w192 };

struct metadata_t {
PortId_t foo;
}

action test2(inout PortId_t port) {
port._v = port._v + 1;
}

// CHECK-LABEL: p4hir.func action @test2(%arg0: !p4hir.ref<!PortId_t> {p4hir.dir = #p4hir<dir inout>}) {
// CHECK: %[[_V_REF:.*]] = p4hir.struct_extract_ref %arg0["_v"] : <!PortId_t>
// CHECK: %[[VAL:.*]] = p4hir.read %arg0 : <!PortId_t>
// CHECK: %[[_V_VAL:.*]] = p4hir.struct_extract %[[VAL]]["_v"] : !PortId_t
// CHECK: p4hir.assign %{{.*}}, %[[_V_REF]]
// CHECK: p4hir.return

// CHECK-LABEL: p4hir.func action @test(%arg0: !p4hir.ref<!metadata_t> {p4hir.dir = #p4hir<dir inout>}) {
// Just few important bits here
action test(inout metadata_t meta) {
bit<9> vv;

PortId_t p1 = { _v = vv };

// CHECK: %[[VV_VAR:.*]] = p4hir.variable ["vv"] : <!b9i>
// CHECK: %[[VV_VAL:.*]] = p4hir.read %[[VV_VAR]] : <!b9i>
// CHECK: %[[STRUCT:.*]] = p4hir.struct (%[[VV_VAL]]) : !PortId_t
// CHECK: %[[P_VAR:.*]] = p4hir.variable ["p1", init] : <!PortId_t>
// CHECK: p4hir.assign %[[STRUCT]], %[[P_VAR]] : <!PortId_t>

PortId_t p;
bit<9> v;
v = p._v;

v = meta.foo._v;

meta.foo._v = 1;

// CHECK: p4hir.scope {
// CHECK: p4hir.call @test2
test2(meta.foo);
// CHECK: }

// CHECK: %[[METADATA_VAL:.*]] = p4hir.read %arg0 : <!metadata_t>
// CHECK: %[[FOO:.*]] = p4hir.struct_extract %[[METADATA_VAL]]["foo"] : !metadata_t
// CHECK: %[[PSA_CPU_PORT:.*]] = p4hir.const ["PSA_CPU_PORT"] #p4hir.aggregate<[#int192_b9i]> : !PortId_t
// CHECK: %eq = p4hir.cmp(eq, %[[FOO]], %[[PSA_CPU_PORT]]) : !PortId_t, !p4hir.bool
if (meta.foo == PSA_CPU_PORT) {
meta.foo._v = meta.foo._v + 1;
}
}

0 comments on commit e5bcc01

Please sign in to comment.