
[LINALG][MLIR] Fix the broadcast dim check for elementwise ops #2882

Open
Shukla-Gaurav wants to merge 1 commit into main from gaurav/fix_broadcast_element
Conversation

Shukla-Gaurav (Collaborator)

No description provided.

@Shukla-Gaurav Shukla-Gaurav force-pushed the gaurav/fix_broadcast_element branch from 5bb719e to b3dfea2 Compare February 7, 2024 17:29
@Shukla-Gaurav (Collaborator, Author)

@stellaraccident @rsuderman
This issue arises in the opt-125M model at

```mlir
%47 = torch.aten.add.Tensor %3, %46, %int1 : !torch.vtensor<[1,12,6,6],f32>, !torch.vtensor<[?,?,6,6],f32>, !torch.int -> !torch.vtensor<[?,12,6,6],f32>
```

While lowering this to linalg.generic, each dynamic input dimension is checked strictly for equality with the corresponding dimension of the expected broadcasted type, even though it could also be 1. For example, the second dim of %46 is checked strictly against 12, while it could also be 1. The fix above adds an OR condition that also accepts a dynamic dim equal to 1.

Let me know if this sounds good to you, and I will add a test case and update the existing test cases. Thanks!
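
For concreteness, here is a minimal sketch of the shape of the change being described; the helper name and structure are hypothetical, not the actual torch-mlir code:

```cpp
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Hypothetical helper: build the runtime legality check for one input
// dimension. `dim` is the runtime size of the input dim; `expectedDim`
// is the size implied by the broadcasted result type.
static Value buildBroadcastDimCheck(OpBuilder &b, Location loc, Value dim,
                                    Value expectedDim) {
  // Old behavior: strict equality with the expected broadcasted size.
  Value eqExpected =
      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, dim, expectedDim);
  // Proposed fix: additionally accept a runtime size of 1, treating the
  // dim as a broadcasting (expanding) dimension.
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  Value isOne =
      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, dim, one);
  return b.create<arith::OrIOp>(loc, eqExpected, isOne);
}
```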

@rsuderman changed the title from "[LINALG][MLIR] Fix the broadcast dim check for elementwise ops lowering" to "[LINALG][MLIR] Fix the broadcast dim check for elementwise ops" on Feb 7, 2024
@rsuderman (Contributor)

Could we validate with a test that the broadcast case would pass? I am a little concerned that removing the assertion will just cause runtime failures.

@kumardeepakamd (Collaborator)

I suggest creating a few operator-level tests and adding them to https://github.com/nod-ai/SHARK-TestSuite/tree/main/e2eshark

@stellaraccident (Collaborator) commented Feb 7, 2024

The problem is that I don't think it can be dynamically 1; that is illegal at the PyTorch level. (The assert is correct.)

@stellaraccident (Collaborator)

I think it is likely that the problem is leading up to this and we are not narrowing it to a static value.

You can't dynamically switch between an expanding and a non-expanding dim.
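
A sketch of why that decision has to be static (hypothetical helper using MLIR's affine-expression API): when the lowering builds a linalg.generic indexing map, each operand dim must commit to one of two affine expressions, and no affine expression can switch on a runtime size.

```cpp
#include "mlir/IR/AffineExpr.h"

using namespace mlir;

// Each operand dim maps either to the constant 0 (an expanding/broadcast
// dim) or to the matching loop index (a non-expanding dim). A `?` dim
// that may or may not be 1 has no single affine expression, so the choice
// must be made from the static type.
static AffineExpr operandDimExpr(int64_t staticSize, unsigned loopDim,
                                 MLIRContext *ctx) {
  if (staticSize == 1)
    return getAffineConstantExpr(0, ctx); // expanding: always read index 0
  return getAffineDimExpr(loopDim, ctx);  // non-expanding: follow the loop
}
```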

@stellaraccident (Collaborator)

> I think it is likely that the problem is leading up to this and we are not narrowing it to a static value.
>
> You can't dynamically switch between an expanding and a non-expanding dim.

These are abnormally small tensors and are raising red flags for me. Given that this is only 47 instructions deep, can you post the IR and function declaration up to this point?

@stellaraccident (Collaborator) left a review comment

Marking as "request changes" because the assert is correct and we need to chase this further.

@Shukla-Gaurav (Collaborator, Author)

Shukla-Gaurav commented Feb 8, 2024

> I think it is likely that the problem is leading up to this and we are not narrowing it to a static value.
>
> You can't dynamically switch between an expanding and a non-expanding dim.
>
> These are abnormally small tensors and are raising red flags for me. Given that this is only 47 instructions deep, can you post the IR and function declaration up to this point?

It's not literally the 47th instruction in the opt-125m model; we extracted a smaller IR from the opt model to reproduce the issue:

```mlir
module {
  func.func @main_graph() -> (!torch.vtensor<[?,12,6,6],f32>) attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.3.0"} {
    %199 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__1> : tensor<1x1x6x6xf32>} : () -> !torch.vtensor<[1,1,6,6],f32>
    %200 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__2> : tensor<4xsi64>} : () -> !torch.vtensor<[4],si64>
    %201 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__3> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
    %202 = torch.vtensor.literal(dense<1> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
    %203 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__4> : tensor<si64>} : () -> !torch.vtensor<[],si64>
    %204 = torch.operator "onnx.Mul"(%202, %203) : (!torch.vtensor<[4],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[4],si64>
    %205 = torch.operator "onnx.Equal"(%200, %204) : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],i1>
    %206 = torch.operator "onnx.Where"(%205, %202, %200) : (!torch.vtensor<[4],i1>, !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],si64>
    %207 = torch.operator "onnx.Expand"(%199, %206) : (!torch.vtensor<[1,1,6,6],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,6,6],f32>
    %208 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__5> : tensor<1x1x1x6xf32>} : () -> !torch.vtensor<[1,1,1,6],f32>
    %209 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__6> : tensor<4xsi64>} : () -> !torch.vtensor<[4],si64>
    %210 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__7> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
    %211 = torch.vtensor.literal(dense<1> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
    %212 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__8> : tensor<si64>} : () -> !torch.vtensor<[],si64>
    %213 = torch.operator "onnx.Mul"(%211, %212) : (!torch.vtensor<[4],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[4],si64>
    %214 = torch.operator "onnx.Equal"(%209, %213) : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],i1>
    %215 = torch.operator "onnx.Where"(%214, %211, %209) : (!torch.vtensor<[4],i1>, !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],si64>
    %216 = torch.operator "onnx.Expand"(%208, %215) : (!torch.vtensor<[1,1,1,6],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,6],f32>
    %217 = torch.operator "onnx.Cast"(%216) {torch.onnx.to = 1 : si64} : (!torch.vtensor<[?,?,?,6],f32>) -> !torch.vtensor<[?,?,?,6],f32>
    %218 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__9> : tensor<f32>} : () -> !torch.vtensor<[],f32>
    %219 = torch.operator "onnx.Sub"(%218, %217) : (!torch.vtensor<[],f32>, !torch.vtensor<[?,?,?,6],f32>) -> !torch.vtensor<[?,?,?,6],f32>
    %220 = torch.operator "onnx.Cast"(%219) {torch.onnx.to = 9 : si64} : (!torch.vtensor<[?,?,?,6],f32>) -> !torch.vtensor<[?,?,?,6],i1>
    %221 = torch.operator "onnx.Cast"(%220) {torch.onnx.to = 9 : si64} : (!torch.vtensor<[?,?,?,6],i1>) -> !torch.vtensor<[?,?,?,6],i1>
    %222 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__10> : tensor<f32>} : () -> !torch.vtensor<[],f32>
    %223 = torch.operator "onnx.Where"(%221, %222, %219) : (!torch.vtensor<[?,?,?,6],i1>, !torch.vtensor<[],f32>, !torch.vtensor<[?,?,?,6],f32>) -> !torch.vtensor<[?,?,?,6],f32>
    %224 = torch.operator "onnx.Cast"(%223) {torch.onnx.to = 1 : si64} : (!torch.vtensor<[?,?,?,6],f32>) -> !torch.vtensor<[?,?,?,6],f32>
    %225 = torch.operator "onnx.Cast"(%224) {torch.onnx.to = 9 : si64} : (!torch.vtensor<[?,?,?,6],f32>) -> !torch.vtensor<[?,?,?,6],i1>
    %226 = torch.operator "onnx.Cast"(%225) {torch.onnx.to = 9 : si64} : (!torch.vtensor<[?,?,?,6],i1>) -> !torch.vtensor<[?,?,?,6],i1>
    %227 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__11> : tensor<f32>} : () -> !torch.vtensor<[],f32>
    %228 = torch.operator "onnx.Where"(%226, %227, %207) : (!torch.vtensor<[?,?,?,6],i1>, !torch.vtensor<[],f32>, !torch.vtensor<[?,?,6,6],f32>) -> !torch.vtensor<[?,?,6,6],f32>
    %268 = torch.vtensor.literal(dense<1.0> : tensor<1x12x6x6xf32>) : !torch.vtensor<[1,12,6,6],f32>
    %269 = torch.operator "onnx.Add"(%268, %228) : (!torch.vtensor<[1,12,6,6],f32>, !torch.vtensor<[?,?,6,6],f32>) -> !torch.vtensor<[?,12,6,6],f32>
    return %269 : !torch.vtensor<[?,12,6,6],f32>
  }
}
```

@Shukla-Gaurav (Collaborator, Author)

> I think it is likely that the problem is leading up to this and we are not narrowing it to a static value.
>
> You can't dynamically switch between an expanding and a non-expanding dim.

I guess folders might be helpful here.

@newling (Collaborator) commented Feb 8, 2024

Maybe extend the logic in AtenWhereSelfOp::fold?
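
For illustration, a rough sketch of the kind of fold being suggested; the guard conditions and adaptor behavior are assumptions, not the actual torch-mlir implementation (it would live alongside the other folders in TorchOps.cpp):

```cpp
// Sketch: if the condition of aten.where.self folds to a constant splat,
// the result is statically one of the two branches, which lets later
// shape propagation narrow the dynamic dims it produces.
OpFoldResult AtenWhereSelfOp::fold(FoldAdaptor adaptor) {
  auto cond = dyn_cast_or_null<DenseElementsAttr>(adaptor.getCondition());
  if (!cond || !cond.isSplat() || !cond.getElementType().isInteger(1))
    return {};
  Value src = cond.getSplatValue<bool>() ? getSelf() : getOther();
  // Only fold when no implicit broadcast or type change is involved.
  if (src.getType() != getType())
    return {};
  return src;
}
```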

@stellaraccident (Collaborator)

That IR needs some more work... One thing that jumps out at me: you've got multiple onnx.Constant ops that are just capturing i1 0-d tensors as dense resources. Those are probably feeding into conditionals that are constraining the entire thing. I don't know why those are coming in like that, but they will block any kind of folding, even if it existed. I have a gut feeling that this is doing a lot of dumb "shape-like" work on those small tensors, probably just to be thrown out. This old ONNX and TF gunk was lossy with that sort of thing. But you've got to get the IR normalized and folded a bit to weed it out.

Those onnx.Constant ops shouldn't be there; they should be coming in as dense_elements with a single value so they can be analyzed.

I'd need to stare at the whole problem some more to have better guidance.
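
To make the dense_elements point concrete, here is a hedged sketch of a normalization that would inline small dense_resource constants so folders can see their values; the helper name and the size threshold are made up:

```cpp
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinAttributes.h"

using namespace mlir;

// Hypothetical normalization: replace a small dense_resource-backed
// constant with an inline DenseElementsAttr so constant folders can
// actually read the values.
static Attribute inlineSmallResource(DenseResourceElementsAttr attr) {
  AsmResourceBlob *blob = attr.getRawHandle().getBlob();
  if (!blob || attr.getNumElements() > 64)
    return attr; // leave large blobs as resources
  return DenseElementsAttr::getFromRawBuffer(attr.getType(), blob->getData());
}
```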

@stellaraccident (Collaborator)

There's probably one use of that whole subgraph that is conditioned on some constant or input-shape thing... The ONNX representation is really bad for that stuff, and we'll need to implement sufficient simplifications. I was there when the TFLite folks had to do the same with insane TF graphs, which basically look like this :/

@stellaraccident (Collaborator)

Before boiling the ocean on this, I'd recommend using a third-party tool to simplify the ONNX graph. Example: https://github.com/daquexian/onnx-simplifier
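
For reference, a typical invocation (assuming the pip-installed `onnxsim` entry point) is `onnxsim model.onnx model_simplified.onnx`; running it before import often folds away exactly this kind of shape-computation subgraph.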
