Skip to content
Permalink

Comparing changes

Choose two branches to see what’s changed or to start a new pull request. If you need to, you can also or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also . Learn more about diff comparisons here.
base repository: openxla/stablehlo
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: 8cd9444b78ccec3e42a4b21105a5a547c021e823
Choose a base ref
...
head repository: openxla/stablehlo
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: 48a1e14edc8219577fcad53de1924876f855f431
Choose a head ref
  • 9 commits
  • 52 files changed
  • 8 contributors

Commits on Jan 22, 2025

  1. Bump patch version after integrate 1.8.10 -> 1.8.11 (#2692)

    sdasgup3 authored Jan 22, 2025

    Verified

    This commit was created on GitHub.com and signed with GitHub’s verified signature.
    Copy the full SHA
    01baa23 View commit details

Commits on Jan 23, 2025

  1. Expose populateStablehloToLinalgConversionPatterns function (#2695)

    This function declaration has existed since the code was forked from
    IREE in #1817, but the
    implementation was kept private (static function within an anonymous
    namespace).
    
    I'm now trying to switch IREE from having its own implementation to
    using the upstream implementation from this project in
    iree-org/iree#19792, and I would like to access
    these patterns directly, instead of through the
    `StablehloLegalizeToLinalgPass`. With the patterns I can run conversion
    including my own sets of additional patterns, while a pass runs in
    isolation.
    
    I'm also deleting the `populateLegalizeChloPatterns`,
    `populateLegalizeControlFlowPatterns`, and
    `populateLegalizeShapeComputationPatterns` declarations, which were not
    migrated from IREE and are also dangling without implementations.
    ScottTodd authored Jan 23, 2025

    Verified

    This commit was created on GitHub.com and signed with GitHub’s verified signature.
    Copy the full SHA
    c27ba67 View commit details

Commits on Jan 27, 2025

  1. Add source link to awesome.md. (#2697)

    I couldn't find an easy link back to this source file from
    https://openxla.org/stablehlo/awesome.
    
    I _think_ this style of URL should be fine, but I might be missing some
    website detail (for example, the hosted website has support for
    translations).
    
    A site-wide solution may be available for some source files. For
    example:
    * The mkdocs-material site generator has "edit this page" and "view
    source of this page" actions:
    https://squidfunk.github.io/mkdocs-material/setup/adding-a-git-repository/#code-actions.
    * Some tensorflow examples include "View source on GitHub" buttons for
    notebooks: https://www.tensorflow.org/guide/keras/functional_api
    ScottTodd authored Jan 27, 2025

    Verified

    This commit was created on GitHub.com and signed with GitHub’s verified signature.
    Copy the full SHA
    c02033c View commit details
  2. Add IREE project link to awesome.md. (#2698)

    Happy to iterate on the specific phrasing, suggestions welcome.
    
    I wonder if some reorganization would help too, since IREE could fit
    into either the "PJRT Plugins" or "Edge Compilation" sections.
    ScottTodd authored Jan 27, 2025

    Verified

    This commit was created on GitHub.com and signed with GitHub’s verified signature.
    Copy the full SHA
    b9d4c70 View commit details
  3. Added GoMLX and gopjrt for awesome.md (#2696)

    janpfeifer authored Jan 27, 2025

    Verified

    This commit was created on GitHub.com and signed with GitHub’s verified signature.
    Copy the full SHA
    540d48f View commit details

Commits on Jan 28, 2025

  1. Add stablehlo-coreml to awesome.md (#2699)

    kasper0406 authored Jan 28, 2025

    Verified

    This commit was created on GitHub.com and signed with GitHub’s verified signature.
    Copy the full SHA
    37750d0 View commit details
  2. Bump patch version after integrate 1.8.11 -> 1.8.12 (#2700)

    ghpvnist authored Jan 28, 2025

    Verified

    This commit was created on GitHub.com and signed with GitHub’s verified signature.
    Copy the full SHA
    8993ef7 View commit details

Commits on Jan 29, 2025

  1. Add ResultAccuracy to ExpOp (#2694)

    Implementation of RFC: #2592
    
    For ExpOp.
    
    TODO: Modify spec.md
    
    ---------
    
    Co-authored-by: Rachel Han <[email protected]>
    GleasonK and hanrach9 authored Jan 29, 2025

    Verified

    This commit was created on GitHub.com and signed with GitHub’s verified signature.
    Copy the full SHA
    7c50d4e View commit details
  2. Integrate LLVM at llvm/llvm-project@aa65f93b71de (#2701)

    Co-authored-by: Rachel Han <[email protected]>
    abhigunj and hanrach9 authored Jan 29, 2025

    Verified

    This commit was created on GitHub.com and signed with GitHub’s verified signature.
    Copy the full SHA
    48a1e14 View commit details
Showing with 4,076 additions and 116 deletions.
  1. +1 −1 BUILD.bazel
  2. +2 −2 WORKSPACE.bazel
  3. +1 −1 build_tools/llvm_version.txt
  4. +17 −1 docs/awesome.md
  5. +3 −16 stablehlo/conversions/linalg/transforms/Rewriters.h
  6. +39 −38 stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp
  7. +23 −0 stablehlo/dialect/AssemblyFormat.cpp
  8. +59 −0 stablehlo/dialect/AssemblyFormat.h
  9. +17 −0 stablehlo/dialect/Base.cpp
  10. +3 −0 stablehlo/dialect/Base.h
  11. +17 −0 stablehlo/dialect/Base.td
  12. +1 −1 stablehlo/dialect/CMakeLists.txt
  13. +15 −0 stablehlo/dialect/StablehloAttrs.td
  14. +79 −7 stablehlo/dialect/StablehloBytecode.cpp
  15. +23 −0 stablehlo/dialect/StablehloEnums.td
  16. +38 −0 stablehlo/dialect/StablehloOps.cpp
  17. +19 −2 stablehlo/dialect/StablehloOps.td
  18. +29 −4 stablehlo/dialect/TypeInference.cpp
  19. +9 −0 stablehlo/dialect/TypeInference.h
  20. +3 −3 stablehlo/dialect/Version.cpp
  21. +1 −1 stablehlo/dialect/Version.h
  22. +24 −11 stablehlo/dialect/VhloAttrs.td
  23. +74 −1 stablehlo/dialect/VhloBytecode.cpp
  24. +1 −0 stablehlo/dialect/VhloDialect.td
  25. +33 −1 stablehlo/dialect/VhloEnums.td
  26. +9 −8 stablehlo/dialect/VhloOps.cpp
  27. +8 −1 stablehlo/dialect/VhloOps.td
  28. +67 −0 stablehlo/integrations/c/StablehloAttributes.cpp
  29. +37 −0 stablehlo/integrations/c/StablehloAttributes.h
  30. +44 −0 stablehlo/integrations/python/StablehloModule.cpp
  31. +21 −0 stablehlo/integrations/python/tests/stablehlo.py
  32. +6 −7 stablehlo/reference/Types.cpp
  33. +40 −0 stablehlo/tests/ops_stablehlo.mlir
  34. +63 −0 stablehlo/tests/ops_stablehlo_bounded_dynamism.mlir
  35. +5 −0 stablehlo/tests/ops_stablehlo_roundtrip.mlir
  36. +13 −0 stablehlo/tests/print_stablehlo.mlir
  37. +11 −0 stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
  38. +2,966 −0 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_9_0.mlir
  39. BIN stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_9_0.mlir.bc
  40. +31 −1 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir
  41. +26 −0 stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir
  42. +24 −0 stablehlo/tests/vhlo/vhlo_to_version_downgrade.1_8_0.mlir
  43. +22 −0 stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.1_8_0.mlir
  44. +1 −1 stablehlo/transforms/MapStablehloToVhlo.h
  45. +3 −3 stablehlo/transforms/PassUtils.h
  46. +5 −0 stablehlo/transforms/Passes.h
  47. +20 −2 stablehlo/transforms/StablehloAggressiveSimplification.cpp
  48. +6 −3 stablehlo/transforms/StablehloComplexMathExpanderPatterns.td
  49. +24 −0 stablehlo/transforms/StablehloLegalizeToVhlo.cpp
  50. +24 −0 stablehlo/transforms/VhloLegalizeToStablehlo.cpp
  51. +53 −0 stablehlo/transforms/VhloToVersion.cpp
  52. +16 −0 stablehlo/transforms/VhloToVersionPatterns.td
2 changes: 1 addition & 1 deletion BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1547,7 +1547,7 @@ gentbl_cc_library(
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "stablehlo/dialect/VhloAttrs.td",
td_file = "stablehlo/dialect/VhloEnums.td",
deps = [
":vhlo_ops_td_files",
],
4 changes: 2 additions & 2 deletions WORKSPACE.bazel
Original file line number Diff line number Diff line change
@@ -17,9 +17,9 @@ workspace(name = "stablehlo")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

LLVM_COMMIT = "e2402615a5a76d46a433dfcc1de10b38a1263c9d"
LLVM_COMMIT = "aa65f93b71dee8cacb22be1957673c8be6a3ec24"

LLVM_SHA256 = "9c22349e1d38555b2f223e49951655f60c04c0c3467e0150aaf6c9f50484cc9f"
LLVM_SHA256 = "0a6046edb6a9834d5b912ec0e705dec91d39ee1b7b2fbb5930955d83d2090ff5"

http_archive(
name = "llvm-raw",
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
e2402615a5a76d46a433dfcc1de10b38a1263c9d
aa65f93b71dee8cacb22be1957673c8be6a3ec24
18 changes: 17 additions & 1 deletion docs/awesome.md
Original file line number Diff line number Diff line change
@@ -18,7 +18,9 @@ backends. Built by industry leaders in AI modeling, software, and hardware.
**How is the community using OpenXLA?** This page consolidates links to
repositories and projects using OpenXLA to provide inspiration and code pointers!

**Have a project that uses OpenXLA?** Send us a pull request and add it to this page!
**Have a project that uses OpenXLA?** Send us a
[pull request](https://github.com/openxla/stablehlo/blob/main/docs/awesome.md)
and add it to this page!

## Frameworks

@@ -28,6 +30,11 @@ NumPy-like API for writing high-performance ML models <img align="center" src="h
to OpenXLA and StableHLO <img align="center" src="https://img.shields.io/github/stars/pytorch/xla?style=social">
- [TensorFlow](https://github.com/tensorflow/tensorflow) is a long-standing ML
framework with a large ecosystem <img align="center" src="https://img.shields.io/github/stars/tensorflow/tensorflow?style=social">
- [GoMLX](https://github.com/gomlx/gomlx) ML Framework for the Go Language
<img align="center" src="https://img.shields.io/github/stars/gomlx/gomlx?style=social">
- [gopjrt](https://github.com/gomlx/gopjrt) raw XlaBuilder+PJRT wrapper for Go:
tested on CPU, GPU and TPU.
<img align="center" src="https://img.shields.io/github/stars/gomlx/gopjrt?style=social">

## PJRT Plugins

@@ -40,6 +47,15 @@ on Google's Cloud TPUs
to deploy to mobile devices using [LiteRT](https://ai.google.dev/edge/litert)
- [AI Edge Torch](https://github.com/google-ai-edge/ai-edge-torch) exports
PyTorch models for mobile deployment via StableHLO <img align="center" src="https://img.shields.io/github/stars/google-ai-edge/ai-edge-torch?style=social">
- [IREE](https://iree.dev/) uses StableHLO as an input format to deploy across
a range of devices and accelerators
<img align="center" src="https://img.shields.io/github/stars/iree-org/iree?style=social">
- IREE also includes a
[PJRT plugin](https://github.com/iree-org/iree/tree/main/integrations/pjrt)
- [StableHLO to CoreML](https://github.com/kasper0406/stablehlo-coreml/tree/main)
converts StableHLO models to [Apple's CoreML](https://developer.apple.com/documentation/coreml/)
for deploying to Apple devices
<img align="center" src="https://img.shields.io/github/stars/kasper0406/stablehlo-coreml?style=social">

## Tooling and Visualization

19 changes: 3 additions & 16 deletions stablehlo/conversions/linalg/transforms/Rewriters.h
Original file line number Diff line number Diff line change
@@ -22,28 +22,15 @@ limitations under the License.
namespace mlir::stablehlo {

//===----------------------------------------------------------------------===//
// General StableHLO/CHLO lowering patterns.
// General StableHLO lowering patterns.
//===----------------------------------------------------------------------===//

/// Populates the patterns that convert from StableHLO to Linalg on tensors.
void populateStablehloToLinalgConversionPatterns(MLIRContext *context,
TypeConverter &typeConverter,
RewritePatternSet *patterns,
bool enablePrimitiveOps);

/// Collection of rewrite patterns for lowering of CHLO ops to StableHLO and
/// Shape ops.
void populateLegalizeChloPatterns(MLIRContext *context,
RewritePatternSet *patterns);

/// Collection of rewrite patterns for lowering of StableHLO ops to SCF control
/// flow ops.
void populateLegalizeControlFlowPatterns(MLIRContext *context,
RewritePatternSet *patterns);

/// Collection of rewrite patterns for lowering of StableHLO dim operations.
void populateLegalizeShapeComputationPatterns(MLIRContext *context,
RewritePatternSet *patterns);
bool enablePrimitiveOps,
bool enableSparseOps);

//===----------------------------------------------------------------------===//
// Fine-grained patterns used by the implementation.
Original file line number Diff line number Diff line change
@@ -2600,11 +2600,45 @@ struct SetDimensionSizeConverter final
}
};

static void populateConversionPatterns(MLIRContext *context,
TypeConverter &typeConverter,
RewritePatternSet *patterns,
bool enablePrimitiveOps,
bool enableSparseOps) {
struct StablehloLegalizeToLinalgPass
: impl::StablehloLegalizeToLinalgPassBase<StablehloLegalizeToLinalgPass> {
using StablehloLegalizeToLinalgPassBase::StablehloLegalizeToLinalgPassBase;

LogicalResult initialize(MLIRContext *context) override {
target = std::make_shared<ConversionTarget>(*context);
target->addLegalDialect<
bufferization::BufferizationDialect, arith::ArithDialect,
complex::ComplexDialect, linalg::LinalgDialect, math::MathDialect,
tensor::TensorDialect, sparse_tensor::SparseTensorDialect,
scf::SCFDialect, shape::ShapeDialect>();
target->addLegalOp<UnrealizedConversionCastOp>();

RewritePatternSet patterns_(context);
populateStablehloToLinalgConversionPatterns(
context, converter, &patterns_, enablePrimitiveOps, enableSparseOps);
patterns = std::move(patterns_);

return success();
}

void runOnOperation() override {
if (failed(applyPartialConversion(getOperation(), *target, patterns))) {
return signalPassFailure();
}
}

private:
std::shared_ptr<ConversionTarget> target;
FrozenRewritePatternSet patterns;
LinalgTypeConverter converter;
};
} // namespace

void populateStablehloToLinalgConversionPatterns(MLIRContext *context,
TypeConverter &typeConverter,
RewritePatternSet *patterns,
bool enablePrimitiveOps,
bool enableSparseOps) {
// clang-format off
patterns->add<ConcatenateConverter>(typeConverter, context,
enablePrimitiveOps);
@@ -2670,37 +2704,4 @@ static void populateConversionPatterns(MLIRContext *context,
linalg::populateEraseUnusedOperandsAndResultsPatterns(*patterns);
}

struct StablehloLegalizeToLinalgPass
: impl::StablehloLegalizeToLinalgPassBase<StablehloLegalizeToLinalgPass> {
using StablehloLegalizeToLinalgPassBase::StablehloLegalizeToLinalgPassBase;

LogicalResult initialize(MLIRContext *context) override {
target = std::make_shared<ConversionTarget>(*context);
target->addLegalDialect<
bufferization::BufferizationDialect, arith::ArithDialect,
complex::ComplexDialect, linalg::LinalgDialect, math::MathDialect,
tensor::TensorDialect, sparse_tensor::SparseTensorDialect,
scf::SCFDialect, shape::ShapeDialect>();
target->addLegalOp<UnrealizedConversionCastOp>();

RewritePatternSet patterns_(context);
populateConversionPatterns(context, converter, &patterns_,
enablePrimitiveOps, enableSparseOps);
patterns = std::move(patterns_);

return success();
}

void runOnOperation() override {
if (failed(applyPartialConversion(getOperation(), *target, patterns))) {
return signalPassFailure();
}
}

private:
std::shared_ptr<ConversionTarget> target;
FrozenRewritePatternSet patterns;
LinalgTypeConverter converter;
};
} // namespace
} // namespace mlir::stablehlo
23 changes: 23 additions & 0 deletions stablehlo/dialect/AssemblyFormat.cpp
Original file line number Diff line number Diff line change
@@ -860,6 +860,29 @@ ParseResult parseCustomCallTarget(AsmParser& parser, StringAttr& target) {
return parser.parseSymbolName(target);
}

void printResultAccuracyAttr(AsmPrinter& odsPrinter, APFloat atol, APFloat rtol,
int64_t ulps, Attribute mode) {
odsPrinter << "<";
if (!atol.isZero()) {
odsPrinter << "atol = ";
odsPrinter.printFloat(atol);
odsPrinter << ", ";
}
if (!rtol.isZero()) {
odsPrinter << "rtol = ";
odsPrinter.printFloat(rtol);
odsPrinter << ", ";
}
if (ulps != 0) {
odsPrinter << "ulps = ";
odsPrinter << ulps;
odsPrinter << ", ";
}
odsPrinter << "mode = ";
odsPrinter.printAttribute(mode);
odsPrinter << ">";
}

void printTypeExtensions(BoundedAttrInterface attr, DialectAsmPrinter& os) {
os << "bounds<";
llvm::interleaveComma(attr.getBounds(), os,
59 changes: 59 additions & 0 deletions stablehlo/dialect/AssemblyFormat.h
Original file line number Diff line number Diff line change
@@ -378,6 +378,65 @@ ParseResult parseDotDimensionNumbers(AsmParser& parser, AttrTy& target) {
return success();
}

// ResultAccuracyAttr - Custom printing and parsing for ResultAccuracyAttr.
//
// ResultAccuractAttr ::= `<` OptAtolAccuracy OptRtolAccuracy
// OptUlpAccuracy ModeAccuracy `>`
// OptAtolAccuracy ::= `atol` `=` APFloat `, ` | eps
// OptRtolAccuracy ::= `rtol` `=` APFloat `, ` | eps
// OptUlpAccuracy ::= `ulps` `=` int64_t `, ` | eps
// ModeAccuracy ::= `mode` `=` ResultAccuracyModeAttr
void printResultAccuracyAttr(AsmPrinter& odsPrinter, APFloat atol, APFloat rtol,
int64_t ulps, Attribute mode);

template <typename AttrTy, typename ModeTy>
Attribute parseResultAccuracyAttr(AsmParser& parser, Type type) {
APFloat resultAtol = APFloat::getZero(APFloat::IEEEdouble());
APFloat resultRtol = APFloat::getZero(APFloat::IEEEdouble());
int64_t resultUlps = 0;

// Parse literal '<'
if (parser.parseLess()) return {};

// OptAtolAccuracy
if (succeeded(parser.parseOptionalKeyword("atol"))) {
double value;
if (parser.parseEqual() || parser.parseFloat(value) || parser.parseComma())
return {};
resultAtol = APFloat(value);
}

// OptRtolAccuracy
if (succeeded(parser.parseOptionalKeyword("rtol"))) {
double value;
if (parser.parseEqual() || parser.parseFloat(value) || parser.parseComma())
return {};
resultRtol = APFloat(value);
}

// OptUlpAccuracy
if (succeeded(parser.parseOptionalKeyword("ulps"))) {
int64_t value;
if (parser.parseEqual() || parser.parseInteger(value) ||
parser.parseComma())
return {};
resultUlps = value;
}

// ModeAccuracy
ModeTy modeAttr;
if (parser.parseKeyword("mode") || parser.parseEqual() ||
parser.parseAttribute(modeAttr)) {
return {};
}

// Parse literal '>'
if (parser.parseGreater()) return {};
return parser.getChecked<AttrTy>(parser.getCurrentLocation(),
parser.getContext(), resultAtol, resultRtol,
resultUlps, modeAttr);
}

} // namespace hlo
} // namespace mlir

17 changes: 17 additions & 0 deletions stablehlo/dialect/Base.cpp
Original file line number Diff line number Diff line change
@@ -780,5 +780,22 @@ bool isValidQuantizedDimension(Type type) {
numScales == rankedType.getDimSize(quantDim));
}

bool hasSingleBoundedDimension(Type type) {
RankedTensorType rankedType = dyn_cast<RankedTensorType>(type);
auto boundedAttr =
dyn_cast_or_null<BoundedAttrInterface>(rankedType.getEncoding());
if (!boundedAttr) return false;

// Count if bounded attr size is not kDynamic
int64_t numBoundedDims = llvm::count_if(
boundedAttr.getBounds(),
[](int64_t bound) { return !ShapedType::isDynamic(bound); });
// Also check that there are only bounded dims and no unbounded dims.
int64_t numDynamicDims = llvm::count_if(
rankedType.getShape(),
[](int64_t bound) { return ShapedType::isDynamic(bound); });
return numBoundedDims == 1 && numDynamicDims == 1;
}

} // namespace hlo
} // namespace mlir
3 changes: 3 additions & 0 deletions stablehlo/dialect/Base.h
Original file line number Diff line number Diff line change
@@ -101,6 +101,9 @@ bool isValidStablehloQuantizedElementType(Type elementType);
// mentioned in the StableHLO specification.
bool isValidQuantizedDimension(Type type);

// Returns true if the given type has a single bounded dimension.
bool hasSingleBoundedDimension(Type type);

// TODO(zhouxin) Move type inference related methods to TypeInference.cpp

std::pair<int64_t, int64_t> inferConcatenatedDimAndBound(int64_t leftSize,
17 changes: 17 additions & 0 deletions stablehlo/dialect/Base.td
Original file line number Diff line number Diff line change
@@ -30,6 +30,20 @@ def I32RankedTensor : RankedTensorOf<[I32]>;

def UI32RankedTensor : RankedTensorOf<[UI32]>;

//===----------------------------------------------------------------------===//
// HLO type constraints.
//===----------------------------------------------------------------------===//

// Note: Bounded dynamisms is largely unspecced and this feature needs more
// thoguht as it is adopted to modern frameworks. The current support is
// designed to allow existing TF programs to be representable in StableHLO and
// is subject to change as a formal design for boudned dynamism is developed.
def HLO_HasSingleBoundedDimensionPred
: CPred<"mlir::hlo::hasSingleBoundedDimension($_self)">;

def HLO_HasStaticOrSingleBoundedShapePred
: Or<[HasStaticShapePred, HLO_HasSingleBoundedDimensionPred]>;

//===----------------------------------------------------------------------===//
// HLO type definitions.
//===----------------------------------------------------------------------===//
@@ -267,6 +281,9 @@ def HLO_StaticShapeTensor : StaticShapeTensorOf<[
def HLO_StaticShapeTensorOrPerAxisQuantizedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt],
[IsValidQuantizedDimension, HasStaticShapePred], "statically shaped tensor">;

def HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt],
[IsValidQuantizedDimension, HLO_HasStaticOrSingleBoundedShapePred], "statically shaped or single bounded dimension tensor">;

def HLO_StaticShapeTensorOrPerAxisQuantizedTensorOrToken : AnyTypeOf<[HLO_StaticShapeTensor, HLO_StaticShapeTensorOrPerAxisQuantizedTensor, HLO_Token]>;

def HLO_StaticShapeIntOrFpTensor : StaticShapeTensorOf<[HLO_Int, HLO_Float]>;
2 changes: 1 addition & 1 deletion stablehlo/dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -190,7 +190,7 @@ mlir_tablegen(VhloEnums.cpp.inc -gen-enum-defs)
set(LLVM_TARGET_DEFINITIONS VhloOps.td)
mlir_tablegen(VhloAttrs.h.inc -gen-attrdef-decls)
mlir_tablegen(VhloAttrs.cpp.inc -gen-attrdef-defs)
set(LLVM_TARGET_DEFINITIONS VhloAttrs.td)
set(LLVM_TARGET_DEFINITIONS VhloEnums.td)
mlir_tablegen(VhloAttrInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(VhloAttrInterfaces.cpp.inc -gen-attr-interface-defs)
set(LLVM_TARGET_DEFINITIONS VhloTypes.td)
15 changes: 15 additions & 0 deletions stablehlo/dialect/StablehloAttrs.td
Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@ limitations under the License.

include "mlir/IR/OpBase.td"
include "mlir/IR/TensorEncoding.td"
include "stablehlo/dialect/StablehloTypes.td"

def StableHLO_Dims : ArrayRefParameter<"int64_t", "Dimension"> {
let parser = "parseDimSizes($_parser)";
@@ -209,4 +210,18 @@ def StableHLO_ConvDimensionNumbers : AttrDef<StableHLO_Dialect, "ConvDimensionNu
let hasCustomAssemblyFormat = 1;
}

def StableHLO_ResultAccuracyAttr : AttrDef<StableHLO_Dialect, "ResultAccuracy"> {
let mnemonic = "result_accuracy";
let summary = "The requested accuracy for transcendental unary ops.";
let parameters = (ins
"APFloat":$atol,
"APFloat":$rtol,
"int64_t":$ulps,
StableHLO_ResultAccuracyModeAttr:$mode
);
let hasCustomAssemblyFormat = 1;
let genVerifyDecl = 1;
let constBuilderCall = "ResultAccuracyAttr::get($_builder.getContext(), APFloat(0.0), APFloat(0.0), 0, ResultAccuracyModeAttr::get($_builder.getContext(), $0))";
}

#endif // STABLEHLO_DIALECT_STABLEHLO_ATTRS
Loading