[torch] aten.select.int use the concat operand as input #3139

Draft · xinan-jiang wants to merge 4 commits into base: main
Conversation

@xinan-jiang (Contributor) commented Apr 11, 2024

I found that the reshape decomposition cannot distinguish the constant dimensions in the target shape when that shape comes from a concatenated tensor. This PR resolves the problem by making aten.select.int use the corresponding concat operand as its input.
Below is an example.
The ONNX input:

func.func @test_reshape(%arg0: !torch.vtensor<[1,?,1,16],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[?,1,16],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<const_fold_opt__1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
  %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<const_fold_opt__2> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
  %2 = torch.operator "onnx.Concat"(%0, %arg1, %1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3],si64> 
  %3 = torch.operator "onnx.Reshape"(%arg0, %2) : (!torch.vtensor<[1,?,1,16],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,1,16],f32> 
  return %3 : !torch.vtensor<[?,1,16],f32>
}

The MLIR before torch-decompose-complex-ops:

func.func @test_reshape(%arg0: !torch.vtensor<[1,?,1,16],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[?,1,16],f32> attributes {torch.assume_strict_symbolic_shapes, torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  %int2 = torch.constant.int 2
  %int-1 = torch.constant.int -1
  %int1 = torch.constant.int 1
  %int0 = torch.constant.int 0
  %0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
  %1 = torch.vtensor.literal(dense<16> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
  %2 = torch.prim.ListConstruct %0, %arg1, %1 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
  %3 = torch.aten.cat %2, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
  %4 = torch.aten.select.int %3, %int0, %int0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
  %5 = torch.aten.item %4 : !torch.vtensor<[1],si64> -> !torch.int
  %6 = torch.aten.eq.int %5, %int0 : !torch.int, !torch.int -> !torch.bool
  %7 = torch.aten.Int.bool %6 : !torch.bool -> !torch.int
  %8 = torch.aten.mul.int %7, %int1 : !torch.int, !torch.int -> !torch.int
  %9 = torch.aten.add.int %5, %8 : !torch.int, !torch.int -> !torch.int
  %10 = torch.aten.select.int %3, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
  %11 = torch.aten.item %10 : !torch.vtensor<[1],si64> -> !torch.int
  %12 = torch.aten.eq.int %11, %int0 : !torch.int, !torch.int -> !torch.bool
  %13 = torch.aten.Int.bool %12 : !torch.bool -> !torch.int
  %14 = torch.aten.mul.int %13, %int-1 : !torch.int, !torch.int -> !torch.int
  %15 = torch.aten.add.int %11, %14 : !torch.int, !torch.int -> !torch.int
  %16 = torch.aten.select.int %3, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
  %17 = torch.aten.item %16 : !torch.vtensor<[1],si64> -> !torch.int
  %18 = torch.aten.eq.int %17, %int0 : !torch.int, !torch.int -> !torch.bool
  %19 = torch.aten.Int.bool %18 : !torch.bool -> !torch.int
  %20 = torch.aten.mul.int %19, %int1 : !torch.int, !torch.int -> !torch.int
  %21 = torch.aten.add.int %17, %20 : !torch.int, !torch.int -> !torch.int
  %22 = torch.prim.ListConstruct %9, %15, %21 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %23 = torch.aten.reshape %arg0, %22 : !torch.vtensor<[1,?,1,16],f32>, !torch.list<int> -> !torch.vtensor<[?,1,16],f32>
  return %23 : !torch.vtensor<[?,1,16],f32>
}

The three aten.select.int + aten.item chains above produce opaque !torch.int values, so the reshape target list %22 loses the statically known dimensions 1 and 16 that were concatenated into %3. After the fixed torch-decompose-complex-ops, each select folds to the concat operand it reads from, so the constant dimensions survive and only the dynamic dimension from %arg1 still needs a runtime aten.item:

func.func @test_reshape(%arg0: !torch.vtensor<[1,?,1,16],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[?,1,16],f32> attributes {torch.assume_strict_symbolic_shapes, torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  %int16 = torch.constant.int 16
  %int1 = torch.constant.int 1
  %int-1 = torch.constant.int -1
  %int0 = torch.constant.int 0
  %0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
  %1 = torch.vtensor.literal(dense<16> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
  %2 = torch.prim.ListConstruct %0, %arg1, %1 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
  %3 = torch.aten.cat %2, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
  %4 = torch.aten.item %arg1 : !torch.vtensor<[1],si64> -> !torch.int
  %5 = torch.aten.eq.int %4, %int0 : !torch.int, !torch.int -> !torch.bool
  %6 = torch.aten.Int.bool %5 : !torch.bool -> !torch.int
  %7 = torch.aten.mul.int %6, %int-1 : !torch.int, !torch.int -> !torch.int
  %8 = torch.aten.add.int %4, %7 : !torch.int, !torch.int -> !torch.int
  %9 = torch.prim.ListConstruct %int1, %8, %int16 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %10 = torch.aten.view %arg0, %9 : !torch.vtensor<[1,?,1,16],f32>, !torch.list<int> -> !torch.vtensor<[?,1,16],f32>
  return %10 : !torch.vtensor<[?,1,16],f32>
}
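To make the transformation concrete, here is a minimal Python model of the fold (a sketch of the idea only, not the actual torch-mlir C++ rewrite pattern; the Const, Dynamic, and fold_select_of_cat names are hypothetical):

# A minimal Python model of the select-through-concat fold performed by the
# fixed decomposition. Names here (Const, Dynamic, fold_select_of_cat) are
# illustrative, not torch-mlir APIs.
from dataclasses import dataclass
from typing import Union

@dataclass
class Const:
    """A 1-D shape tensor whose values are known at compile time."""
    values: list[int]

@dataclass
class Dynamic:
    """A 1-D shape tensor of known length but unknown values (e.g. %arg1)."""
    name: str
    length: int

Segment = Union[Const, Dynamic]

def fold_select_of_cat(segments: list[Segment], index: int) -> Union[int, str]:
    """Resolve select(cat(segments), dim=0, index) to the operand that covers
    `index`. Returns the concrete value when that operand is a compile-time
    constant, else a symbolic marker for the still-dynamic value."""
    offset = 0
    for seg in segments:
        length = len(seg.values) if isinstance(seg, Const) else seg.length
        if offset <= index < offset + length:
            if isinstance(seg, Const):
                return seg.values[index - offset]   # statically known dim
            return f"item({seg.name}[{index - offset}])"  # stays dynamic
        offset += length
    raise IndexError(index)

# Mirrors the example: cat([1], %arg1, [16]) followed by three selects.
shape = [Const([1]), Dynamic("%arg1", 1), Const([16])]
print([fold_select_of_cat(shape, i) for i in range(3)])
# -> [1, 'item(%arg1[0])', 16]: dims 0 and 2 become torch.constant.int,
#    and only dim 1 still needs the runtime aten.item.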

@xinan-jiang marked this pull request as draft May 7, 2024 02:43