Skip to content

Commit

Permalink
[linalg] Fix bug for conversion of complex dtype
Browse files Browse the repository at this point in the history
Conversion of complex dtypes wasn't supported or checked. Added the
support and the required tests.
  • Loading branch information
pashu123 committed May 1, 2024
1 parent 0a2d21b commit 7a6627c
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 0 deletions.
21 changes: 21 additions & 0 deletions lib/Conversion/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "torch-mlir/Conversion/Utils/Utils.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
Expand Down Expand Up @@ -349,6 +350,26 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
return b.create<arith::ExtSIOp>(loc, dtype, scalar);
}

  // Destination dtype is complex. Only a complex -> complex conversion is
  // handled here; a non-complex scalar (e.g. float -> complex) is reported
  // as unsupported and control falls through to the llvm_unreachable below.
  if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(dtype)) {
    if (auto scalarComplex = dyn_cast<mlir::ComplexType>(scalarType)) {
      // Element type of the target complex (e.g. f64 for complex<f64>).
      auto dtypeElemType = dtypeComplex.getElementType();

      // Extract the real and imaginary parts of the scalar.
      // Cast them to the target element type, and create a new complex
      // value with the target complex type.
      Value realVal = b.create<complex::ReOp>(loc, scalar);
      Value imgVal = b.create<complex::ImOp>(loc, scalar);

      // Recurse to convert each real-valued component to the target
      // element type (handles widening/narrowing, e.g. f32 <-> f64).
      realVal = convertScalarToDtype(b, loc, realVal, dtypeElemType);
      imgVal = convertScalarToDtype(b, loc, imgVal, dtypeElemType);

      return b.create<complex::CreateOp>(loc, dtypeComplex, realVal, imgVal);
    }
    // NOTE(review): no return here — after emitting the diagnostic we fall
    // through to the llvm_unreachable at the end of the function.
    mlir::emitError(loc) << "unsupported scalar type for convertScalarToDtype "
                         << scalarType << "(scalar type) -> " << dtype
                         << "(dtype)";
  }

llvm_unreachable("convertScalarToDtype should handle all the types");
}

Expand Down
2 changes: 2 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,7 @@
"ElementwiseErfIntModule_basic",
"ElementwiseLogitModule_basic",
"ElementwiseMulTensorComplexModule_basic",
"ElementwiseMulTensorComplexDiffModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorUIntModule_basic",
"ElementwiseReciprocalIntModule_basic",
Expand Down Expand Up @@ -2314,6 +2315,7 @@
"ElementwiseExpm1Module_basic",
"ElementwiseFmodTensor_Int_basic",
"ElementwiseMulTensorComplexModule_basic",
"ElementwiseMulTensorComplexDiffModule_basic",
"ElementwiseOrTensorModule_basic",
"ElementwiseOrTensorStaticShapeModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
Expand Down
28 changes: 28 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1839,6 +1839,34 @@ def ElementwiseMulTensorComplexModule_basic(module, tu: TestUtils):
# ==============================================================================


# torch.complex32 is not supported by the refbackend.
class ElementwiseMulTensorComplexDiffModule(torch.nn.Module):
    """Elementwise multiply of two 1-D complex tensors whose dtypes differ
    (complex64 x complex128), exercising complex-to-complex dtype conversion."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            # Dynamic 1-D inputs with mismatched complex element types.
            ([-1], torch.complex64, True),
            ([-1], torch.complex128, True),
        ]
    )
    def forward(self, a, b):
        product = torch.mul(a, b)
        return product


@register_test_case(module_factory=lambda: ElementwiseMulTensorComplexDiffModule())
def ElementwiseMulTensorComplexDiffModule_basic(module, tu: TestUtils):
    # Small integer-valued inputs cast to two different complex dtypes so the
    # multiply must convert one operand's complex dtype to match the other.
    lhs = tu.randint(4, high=10).type(torch.complex64)
    rhs = tu.randint(4, high=10).type(torch.complex128)
    module.forward(lhs, rhs)


# ==============================================================================


class ElementwiseMishModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down

0 comments on commit 7a6627c

Please sign in to comment.