Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions qwerty_mlir/include/CAPI/Qwerty.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,13 @@ MLIR_CAPI_EXPORTED bool mlirQwertyBasisVectorAttrGetHasPhase(MlirAttribute attr)
MLIR_CAPI_EXPORTED MlirAttribute mlirQwertyBasisVectorListAttrGet(
MlirContext ctx, intptr_t numVectors, MlirAttribute const *vectors);

MLIR_CAPI_EXPORTED MlirAttribute mlirQwertyBasisVectorTreeAttrGet(
MlirContext ctx, int64_t kind, bool hasTilt, double tiltDeg,
intptr_t numChildren, MlirAttribute const *children);

/// Returns true if this is a qwerty::BasisVectorTreeAttr.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAQwertyBasisVectorTree(MlirAttribute attr);

/// Returns true if this is a qwerty::BasisVectorListAttr.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAQwertyBasisVectorList(MlirAttribute attr);

Expand All @@ -115,6 +122,7 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsAQwertyBasisElem(MlirAttribute attr);
/// qwerty::BasisVectorAttrs.
MLIR_CAPI_EXPORTED MlirAttribute mlirQwertyBasisAttrGet(
MlirContext ctx, intptr_t numElems, MlirAttribute const *elems);


/// Returns true if this is a qwerty::BasisAttr.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAQwertyBasis(MlirAttribute attr);
Expand Down
34 changes: 34 additions & 0 deletions qwerty_mlir/include/Qwerty/IR/QwertyAttributes.td

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow great idea for how to approach this. I was wondering how to do it, but you answered it.

It just needs a robust verifier, which you added (thanks)

Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,40 @@ def Qwerty_BasisVectorListAttr : Qwerty_Attr<"BasisVectorList", "veclist"> {
let genVerifyDecl = 1;
}

def Qwerty_BasisVectorTreeKindAttr : I64EnumAttr<
"BasisVectorTreeKind", "Kind of a recursive basis-vector tree node",
[
I64EnumAttrCase<"ZeroVector", 0>,
I64EnumAttrCase<"OneVector", 1>,
I64EnumAttrCase<"PadVector", 2>,
I64EnumAttrCase<"TargetVector", 3>,
I64EnumAttrCase<"VectorUnit", 4>,
I64EnumAttrCase<"VectorTilt", 5>,
I64EnumAttrCase<"UniformVectorSuperpos", 6>,
I64EnumAttrCase<"VectorTensor", 7>,
]> {
let cppNamespace = "::qwerty";
}

def Qwerty_BasisVectorTreeAttr : Qwerty_Attr<"BasisVectorTree", "vectree"> {
let summary = "A recursive basis vector mirroring the AST `Vector` enum";
let description = [{
A mirror of the AST `Vector` type: one tree node per AST node.
The front-end transcribes a `Vector` structurally, with no Pauli recognition,
factoring, or flattening; all semantic analysis happens downstream on
the tree.
}];
let parameters = (ins
"BasisVectorTreeKind":$kind,
OptionalParameter<"mlir::FloatAttr">:$tilt,
ArrayRefParameter<"BasisVectorTreeAttr", "">:$children
);

let hasCustomAssemblyFormat = 1;

let genVerifyDecl = 1;
}

def Qwerty_BasisElemAttr : Qwerty_Attr<"BasisElem", "basiselem"> {
let summary = "Union of BuiltinBasisAttr, BasisVectorListAttr, and ApplyRevolveGeneratorAttr";
let description = [{
Expand Down
22 changes: 22 additions & 0 deletions qwerty_mlir/lib/CAPI/Qwerty.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,28 @@ MlirAttribute mlirQwertyBasisVectorListAttrGet(
return wrap(qwerty::BasisVectorListAttr::get(unwrap(ctx), vecs));
}

MlirAttribute mlirQwertyBasisVectorTreeAttrGet(
MlirContext ctx, int64_t kind, bool hasTilt, double tiltDeg,
intptr_t numChildren, MlirAttribute const *children) {
llvm::SmallVector<mlir::Attribute> attrs;
(void)unwrapList(static_cast<size_t>(numChildren), children, attrs);

llvm::SmallVector<qwerty::BasisVectorTreeAttr> trees;
for (mlir::Attribute attr : attrs) {
trees.push_back(llvm::cast<qwerty::BasisVectorTreeAttr>(attr));
}
mlir::FloatAttr tilt;
if (hasTilt) {
tilt = mlir::FloatAttr::get(mlir::Float64Type::get(unwrap(ctx)), tiltDeg);
}

return wrap(qwerty::BasisVectorTreeAttr::get(unwrap(ctx), static_cast<qwerty::BasisVectorTreeKind>(kind), tilt, trees));
}

bool mlirAttributeIsAQwertyBasisVectorTree(MlirAttribute attr) {
return llvm::isa<qwerty::BasisVectorTreeAttr>(unwrap(attr));
}

bool mlirAttributeIsAQwertyBasisVectorList(MlirAttribute attr) {
return llvm::isa<qwerty::BasisVectorListAttr>(unwrap(attr));
}
Expand Down
112 changes: 112 additions & 0 deletions qwerty_mlir/lib/Qwerty/IR/QwertyAttributes.cpp

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to write a test to test at least the happy path for this parsing and verification. (In the past, when I haven't created such a test, the code breaks later when testing something tangential, confusing the daylights out of me)

Here's the existing parsing test file for the Qwerty dialect that you can extend with a few new tests to hit the happy paths: https://github.com/gt-tinker/qwerty/blob/main/qwerty_mlir/tests/Qwerty/IR/parsing.mlir. Note in the first line it basically says to parse the IR, print it, and then parse it again.

The tricky part is going to be writing the attribute at all, since I don't think there are any ops right now that have this attribute. Hmmm...

The other tricky part is that FileCheck can be pretty confusing. It is type (4) of tests described in this document: https://github.com/gt-tinker/qwerty/blob/main/docs/testing.md. I can explain it in a meeting if it's helpful. I am not aware of there being a good tutorial for writing FileCheck. I just had to stare at other people's tests and the documentation to figure it out, but I can try to save you time by skipping most of that pain

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, here's how upstream tests attribute parsing: https://github.com/llvm/llvm-project/blob/9f73a97cfb1a40da42a870906e2221102a71f807/mlir/test/IR/custom-float-attr-roundtrip.mlir#L6. That's pretty cool, they have a dummy op called test.op we can use. I just checked and the test dialect seems to be included in our LLVM build

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did this in a slightly different way by just using a module declaration. Let me know if that works for you, I was able to run the run-tests.sh script and it seemed to pass.

Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,118 @@ mlir::LogicalResult BasisVectorListAttr::verify(
return mlir::success();
}

mlir::LogicalResult BasisVectorTreeAttr::verify(
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
BasisVectorTreeKind kind,
mlir::FloatAttr tilt,
llvm::ArrayRef<BasisVectorTreeAttr> children) {
bool hasTilt = (bool)tilt;
size_t n = children.size();

auto angleErr = [&](bool requiredAngle) {
return emitError() << stringifyBasisVectorTreeKind(kind)
<< (requiredAngle ? " requires an angle"
: " must not carry an angle");
};
auto arityErr = [&](const char *numChildren) {
return emitError() << stringifyBasisVectorTreeKind(kind) << " expects "
<< numChildren << " child(ren), got " << n;
};

switch (kind) {
case BasisVectorTreeKind::ZeroVector:
case BasisVectorTreeKind::OneVector:
case BasisVectorTreeKind::PadVector:
case BasisVectorTreeKind::TargetVector:
case BasisVectorTreeKind::VectorUnit:
if (hasTilt) return angleErr(false);
if (n != 0) return arityErr("no");
return mlir::success();

case BasisVectorTreeKind::VectorTilt:
if (!hasTilt) return angleErr(true);
if (n != 1) return arityErr("exactly 1");
return mlir::success();

case BasisVectorTreeKind::UniformVectorSuperpos:
if (hasTilt) return angleErr(false);
if (n != 2) return arityErr("exactly 2");
return mlir::success();

case BasisVectorTreeKind::VectorTensor:
if (hasTilt) return angleErr(false);
if (n < 2) return arityErr("at least 2");
return mlir::success();
}
return emitError() << "unknown BasisVectorTreeKind";
}

void BasisVectorTreeAttr::print(mlir::AsmPrinter &printer) const {
printer << "<" << stringifyBasisVectorTreeKind(getKind());
if (mlir::FloatAttr tilt = getTilt()) {
printer << " tilt " << tilt;
}
printer << " [";
llvm::ArrayRef<BasisVectorTreeAttr> children = getChildren();
for (size_t i = 0; i < children.size(); i++) {
if (i) {
printer << ", ";
}
printer << children[i];
}
printer << "]>";
}

mlir::Attribute BasisVectorTreeAttr::parse(mlir::AsmParser &parser, mlir::Type odsType) {
llvm::SMLoc loc = parser.getCurrentLocation();
if (parser.parseLess()) {
return {};
}

llvm::StringRef kindKeyword;
if (parser.parseKeyword(&kindKeyword)) {
return {};
}
std::optional<BasisVectorTreeKind> kind =
symbolizeBasisVectorTreeKind(kindKeyword);
if (!kind) {
parser.emitError(loc, "unknown BasisVectorTreeKind: '")
<< kindKeyword << "'";
return {};
}

mlir::FloatAttr tilt;
if (succeeded(parser.parseOptionalKeyword("tilt"))) {
if (parser.parseAttribute(tilt)) {
return {};
}
}

llvm::SmallVector<BasisVectorTreeAttr> children;
if (parser.parseLSquare()) {
return {};
}
if (failed(parser.parseOptionalRSquare())) {
do {
BasisVectorTreeAttr child;
if (parser.parseAttribute(child)) {
return {};
}
children.push_back(child);
} while (succeeded(parser.parseOptionalComma()));
if (parser.parseRSquare()) {
return {};
}
}

if (parser.parseGreater()) {
return {};
}

return parser.getChecked<BasisVectorTreeAttr>(
loc, parser.getContext(), *kind, tilt, children);
}

mlir::LogicalResult ApplyRevolveGeneratorAttr::verify(
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
BasisAttr foo,
Expand Down
21 changes: 21 additions & 0 deletions qwerty_mlir/tests/Qwerty/IR/vectree.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// RUN: qwerty-opt -split-input-file %s | qwerty-opt -split-input-file | FileCheck --strict-whitespace %s
//
// Happy-path round-trip (parse -> print -> parse -> print) of
// BasisVectorTreeAttr. If any case were malformed, parsing would fail and the round-trip would error here.


// A bare leaf atom: '0' (0 children -> empty brackets)
// CHECK: qwerty.t = #qwerty.vectree<ZeroVector []>
module attributes {qwerty.t = #qwerty.vectree<ZeroVector []>} {}

// -----

// A tensor of computational atoms: '00'
// CHECK: qwerty.t = #qwerty.vectree<VectorTensor [#qwerty.vectree<ZeroVector []>, #qwerty.vectree<ZeroVector []>]>
module attributes {qwerty.t = #qwerty.vectree<VectorTensor [#qwerty.vectree<ZeroVector []>, #qwerty.vectree<ZeroVector []>]>} {}

// -----

// The non-Pauli vector '0'@45 + '1'.
// CHECK: qwerty.t = #qwerty.vectree<UniformVectorSuperpos [#qwerty.vectree<VectorTilt tilt {{[0-9.eE+-]+}} : f64 [#qwerty.vectree<ZeroVector []>]>, #qwerty.vectree<OneVector []>]>
module attributes {qwerty.t = #qwerty.vectree<UniformVectorSuperpos [#qwerty.vectree<VectorTilt tilt 45.0 : f64 [#qwerty.vectree<ZeroVector []>]>, #qwerty.vectree<OneVector []>]>} {}