Skip to content

Commit

Permalink
[Stablehlo] enable Stablehlo refbackend with Interpreter
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyunqu committed Aug 10, 2024
1 parent 8358e8c commit f9efc00
Show file tree
Hide file tree
Showing 2 changed files with 223 additions and 6 deletions.
118 changes: 118 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,119 @@
}

STABLEHLO_PASS_SET = {
"AddCDivModule_basic",
"AddCMulModule_basic",
"Add_MixPModule_basic",
"Add_Module_basic",
"AvgPool1dFloatModule_basic",
"AvgPool1dIntModule_basic",
"AvgPool2dCeilModeTrueModule_basic",
"AvgPool2dFloatModule_basic",
"AvgPool2dIntModule_basic",
"AvgPool2dWithoutPadModule_basic",
"BmmFloatModule_basic",
"BmmIntModule_basic",
"BroadcastDynamicDimModule_basic",
"BroadcastToModule_basic",
"CollapseFullDynamicModule_basic",
"CollapsePartialDynamicModule_basic",
"CollapseRank1DynamicModule_basic",
"CopyModule_basic",
"CopyWithDifferentDTypesAndSizesModule_basic",
"CopyWithDifferentDTypesModule_basic",
"CopyWithDifferentSizesModule_basic",
"ElementwiseAddModule_basic",
"ElementwiseAddScalarInt8Module_basic",
"ElementwiseAndScalarModule_basic",
"ElementwiseAtan2FloatIntModule_basic",
"ElementwiseAtan2TensorFloatModule_basic",
"ElementwiseAtan2TensorIntModule_basic",
"ElementwiseAtenFloorDivideBroadcastModule_basic",
"ElementwiseAtenFloorDivideTensorNegativeModule_basic",
"ElementwiseAtenFloorDivideTensorPositiveModule_basic",
"ElementwiseBinaryModule_basic",
"ElementwiseBitwiseAndModule_basic",
"ElementwiseBitwiseAndScalarInt32Module_basic",
"ElementwiseBitwiseAndScalarInt64Module_basic",
"ElementwiseBitwiseAndScalarInt8Module_basic",
"ElementwiseBitwiseOrModule_basic",
"ElementwiseBitwiseXorModule_basic",
"ElementwiseDivTensorFloatModule_basic",
"ElementwiseDivTensorIntegerModule_basic",
"ElementwiseDivTensorRoundingModeFloorModule_basic",
"ElementwiseDivTensorRoundingModeTruncModule_basic",
"ElementwiseFlattenBroadcastModule_basic",
"ElementwiseMaxOtherIntModule_basic",
"ElementwiseMaxOtherModule_basic",
"ElementwiseMaximumIntModule_basic",
"ElementwiseMaximumModule_basic",
"ElementwiseMinOtherIntModule_basic",
"ElementwiseMinOtherModule_basic",
"ElementwiseMinimumIntModule_basic",
"ElementwiseMinimumModule_basic",
"ElementwiseMulScalarModule_int",
"ElementwiseMulTensorFloatModule_basic",
"ElementwiseMulTensorIntModule_basic",
"ElementwiseOrTensorModule_basic",
"ElementwisePowTensorBroadcastModule_basic",
"ElementwisePowTensorModule_basic",
"ElementwiseRelu6Module_basic",
"ElementwiseRemainderScalarModule_Int_basic",
"ElementwiseSignModule_basic",
"ElementwiseSubTensorInt8Module_basic",
"ElementwiseTernaryModule_basic",
"ElementwiseUnsqueezeBroadcastModule_basic",
"ElementwiseUnsqueezeNegDimsModule_basic",
"ExpandAsFloatModule_basic",
"ExpandModule_basic",
"FlipModule_basic",
"HardTanhIntModule_basic",
"HardTanhModule_basic",
"IndexPutImpl2DNoneIndexStaticModule_basic",
"IndexPutImplIndexWithNoneModule_basic",
"LogSoftmaxBackwardModule_basic",
"LogSoftmaxIntModule_basic",
"MatmulBroadcastBatchDim_basic",
"MatmulSingleDynamicBatchDim_basic",
"Matmul_3d",
"Matmul_4d",
"MaxPool1dModule_basic",
"MaxPool2dModule_basic",
"MaxPool3dLargeDatadModule_basic",
"MaxPool3dModuleRandomSimple_basic",
"MaxPool3dModule_basic",
"MseLossNoReductionModule_basic",
"MseLossSumReductionWithDifferentElemTypeModule_basic",
"OneHotModule_basic",
"PixelShuffleModuleSpatiallyDynamic_basic",
"ReduceAmaxKeepDim_basic",
"ReduceFrobeniusNormKeepDimModule_basic",
"ReduceSumDimIntListKeepDimFloatModule_basic",
"ReduceSumDimIntListKeepDimIntModule_basic",
"RsubIntModule_basic",
"ScatterSrcStaticModule_basic",
"SiluModule_basic",
"SliceCopy_Module_basic",
"SoftmaxBackwardModule_basic",
"SoftmaxIntArgTypeF64Module_basic",
"SoftmaxIntModule_basic",
"SoftmaxIntNegDimModule_basic",
"SoftmaxIntNonNoneDtypeModule_basic",
"SquareModule_basic",
"SqueezeDimModule_dynamic",
"SqueezeDimModule_negDim",
"SqueezeModule_broadcast",
"TanhBackward_basic",
"TensorsStackModule_basic",
"TensorsStackNegativeDimModule_basic",
"TensorsStackSingleElementListModule_basic",
"ToCopyModule_basic",
"ToCopyWithDTypeFalsePinMemoryModule_basic",
"ToCopyWithDTypeModule_basic",
"TypePromotionSameCategoryDifferentWidthModule_basic",
"_LogSoftmaxModuleStable_basic",
"_LogSoftmaxModule_basic",
"_SoftmaxModule_basic",
"ReduceAminmaxSingleDim_basic",
"ReduceAminmaxAllDims_basic",
"ReduceAmaxEmptyDim_basic",
Expand Down Expand Up @@ -1513,6 +1626,11 @@
"IndexPutWithNoneAndBroadcastModule_basic",
"ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic",
# stablehlo interpreter crash
"ElementwiseDivTensorUnsignedIntegerModule_basic",
"ScatterValueFloatModule_basic",
"ScatterValueIntModule_basic",
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
}

# Write the TOSA set as a "passing" set as it is very early in development
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

from torch_mlir import ir
from torch_mlir.ir import *
from torch_mlir.dialects.func import FuncOp
from torch_mlir.passmanager import *
from torch_mlir.compiler_utils import run_pipeline_with_repro_report

Expand All @@ -13,10 +15,60 @@

from .abc import StablehloBackend

from torch_mlir._mlir_libs._stablehlo import eval_module
import numpy as np

# Public API of this module: only the backend class is exported.
__all__ = [
    "LinalgOnTensorsStablehloBackend",
]

# Maps the textual form of an MLIR element type (as produced by str() on the
# type) to the numpy dtype used when converting interpreter results back to
# ndarrays.
# NOTE(review): unsigned 16/32/64-bit and bf16 element types are absent, so
# attributes with those element types will raise KeyError in the converter —
# confirm whether the interpreter path can produce them.
element_type_to_np_dtype = {
    "i1": np.bool_,
    "i8": np.int8,
    "ui8": np.uint8,
    "i16": np.int16,
    "i32": np.int32,
    "i64": np.int64,
    "f16": np.float16,
    "f32": np.float32,
    "f64": np.float64,
}


def convert_dense_elements_attr_to_numpy(attr):
    """Convert an MLIR DenseElementsAttr into a numpy ndarray.

    Specializes the attribute to its int or float subclass so that iterating
    it yields Python scalars, then rebuilds the array with the dtype mapped
    from the element type and the shape taken from the attribute's type.

    Raises:
        NotImplementedError: if the attribute is neither an int nor a float
            dense attribute.
        KeyError: if the element type is not in `element_type_to_np_dtype`.
    """
    assert isinstance(attr, ir.DenseElementsAttr)
    # The original generic DenseElementsAttr(attr) binding was dead code:
    # every path either rebinds to a specialized subclass before use or
    # raises, so only the specialized forms are constructed here.
    for dense_cls in [ir.DenseIntElementsAttr, ir.DenseFPElementsAttr]:
        if dense_cls.isinstance(attr):
            dense_attr = dense_cls(attr)
            assert ir.ShapedType.isinstance(dense_attr.type)
            dense_attr_type = ir.ShapedType(dense_attr.type)
            # Iterating the specialized attr yields scalars in row-major
            # order; reshape restores the original logical shape.
            return np.array(
                [i for i in dense_attr],
                dtype=element_type_to_np_dtype[str(dense_attr_type.element_type)],
            ).reshape(dense_attr_type.shape)
    raise NotImplementedError("unsupported attribute {}".format(attr))


class RefBackendInvoker:
    """Runs a compiled stablehlo module through the stablehlo interpreter.

    Accessing any attribute returns a callable that marshals numpy arrays
    into DenseElementsAttrs, evaluates the module with `eval_module`, and
    converts the resulting attributes back to numpy arrays.

    NOTE(review): the looked-up attribute name is not forwarded to
    `eval_module` — presumably the module has a single entry point; confirm
    against `eval_module`'s semantics.
    """

    def __init__(self, module):
        self.module = module

    def __getattr__(self, function_name: str):
        def invoke(*args):
            ctx = self.module.context
            attr_args = [ir.DenseElementsAttr.get(a, context=ctx) for a in args]
            results = [
                convert_dense_elements_attr_to_numpy(r)
                for r in eval_module(self.module, attr_args)
            ]
            # Unwrap single-result calls to match refbackend conventions.
            return results[0] if len(results) == 1 else results

        return invoke


# The pipeline of func.func passes that lower the STABLEHLO backend contract to the
# Linalg-on-Tensors backend contract accepted by RefBackend.
STABLEHLO_TO_LINALG_FUNC_PIPELINE = ",".join(
Expand All @@ -28,6 +80,39 @@
]
)

# Pipeline that strips shape constraints and legalizes shape-dialect
# computations into stablehlo ops, so the fallback module contains only
# operations the stablehlo interpreter can execute.
SHAPE_LEGALIZE_TO_STABLEHLO_PIPELINE = ",".join(
    (
        "func.func(remove-shape-constraints)",
        "canonicalize",
        "func.func(shape-legalize-to-stablehlo)",
        "canonicalize",
    )
)


def raise_if_not_supported_by_interpreter(module: Module):
    """Raise RuntimeError if `module` cannot run on the stablehlo interpreter.

    Rejected cases:
      * i1 element types in function arguments or results (the numpy
        marshalling path does not handle them), and
      * any top-level op outside the stablehlo dialect (other than
        func.return), plus stablehlo.batch_norm_inference, which the
        interpreter does not implement.

    Raises:
        RuntimeError: describing the first unsupported feature found.
    """
    for func in module.body.operations:
        assert isinstance(func, FuncOp)
        for arg in func.arguments:
            assert isinstance(arg.type, ir.ShapedType)
            if str(ir.ShapedType(arg.type).element_type) == "i1":
                # Original message was just "i1" — made actionable.
                raise RuntimeError(
                    "stablehlo interpreter doesn't support i1 function arguments"
                )
        # The terminator's operands are the function results.
        for ret in list(func.entry_block.operations)[-1].operands:
            assert isinstance(ret.type, ir.ShapedType)
            if str(ir.ShapedType(ret.type).element_type) == "i1":
                raise RuntimeError(
                    "stablehlo interpreter doesn't support i1 function results"
                )
        for op in func.entry_block.operations:
            name = op.operation.name
            if name == "func.return":
                continue
            # Combined check: ops from other dialects, plus the one
            # stablehlo op the interpreter lacks (previously two raises
            # with an identical message).
            if not name.startswith("stablehlo.") or (
                name == "stablehlo.batch_norm_inference"
            ):
                raise RuntimeError(
                    f"stablehlo interpreter doesn't support {name}"
                )


class LinalgOnTensorsStablehloBackend(StablehloBackend):
"""Main entry-point for the linalg-on-tensors based Stablehlo backend.
Expand All @@ -48,15 +133,29 @@ def compile(self, imported_module: Module):
An opaque, backend specific compiled artifact object that can be
passed to `load`.
"""
copied_module = Module.parse(imported_module.operation.get_asm(), imported_module.context)
try:
run_pipeline_with_repro_report(
imported_module,
f"builtin.module({STABLEHLO_TO_LINALG_FUNC_PIPELINE})",
"Lowering STABLEHLO backend contract to Linalg-on-Tensors backend contract",
)
result = self.refbackend.compile(imported_module)
return (result, "linalg")
except:
pass

run_pipeline_with_repro_report(
imported_module,
f"builtin.module({STABLEHLO_TO_LINALG_FUNC_PIPELINE})",
"Lowering STABLEHLO backend contract to Linalg-on-Tensors backend contract",
copied_module,
f"builtin.module({SHAPE_LEGALIZE_TO_STABLEHLO_PIPELINE})",
"Shape legalize to stablehlo",
)

return self.refbackend.compile(imported_module)
raise_if_not_supported_by_interpreter(copied_module)
return (copied_module, "stablehlo")

def load(self, module):
"""Loads a compiled artifact into the runtime."""
return self.refbackend.load(module)
if module[1] == "linalg":
return self.refbackend.load(module[0])
else:
return RefBackendInvoker(module[0])

0 comments on commit f9efc00

Please sign in to comment.