[Tosa backend] Support MatmulStaticBroadcast_basic for tosa #2592

Open
wants to merge 2 commits into main
Conversation

bilibiliGO283

This PR resolves issue #2581.

Test case

import torch
import torch.nn as nn
import torch_mlir


class Matmul(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, weight):
        return torch.matmul(x, weight)


def gen_mt_mlir():
    model = Matmul()
    # 4-D weight with batch dims (80, 300); 3-D input with batch dim 300.
    weight = torch.randn(80, 300, 250, 150, dtype=torch.float32)
    input = torch.ones(300, 100, 250)
    res = model(input, weight)
    print(res.shape)  # torch.Size([80, 300, 100, 150])
    module = torch_mlir.compile(model, [input, weight], output_type="TOSA")


if __name__ == '__main__':
    gen_mt_mlir()
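
For context, torch.matmul broadcasts the batch dimension of the 3-D input against the leading batch dimensions of the 4-D weight, so the expected result shape is (80, 300, 100, 150). A scaled-down sketch of the same broadcast pattern (small dims chosen only for illustration, not taken from the PR):

import torch

# Same rank/broadcast pattern as the test case, with small dims to keep it cheap.
x = torch.ones(3, 4, 5)            # stands in for the (300, 100, 250) input
w = torch.randn(2, 3, 5, 6)        # stands in for the (80, 300, 250, 150) weight
print(torch.matmul(x, w).shape)    # torch.Size([2, 3, 4, 6]), i.e. (80, 300, 100, 150) at full size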

Before bug fix

func.func @forward(%arg0: !torch.vtensor<[300,100,250],f32>, %arg1: !torch.vtensor<[80,300,250,150],f32>) -> !torch.vtensor<[80,300,100,150],f32> {
  %0 = builtin.unrealized_conversion_cast %arg1 : !torch.vtensor<[80,300,250,150],f32> to tensor<80x300x250x150xf32>
  %1 = builtin.unrealized_conversion_cast %arg0 : !torch.vtensor<[300,100,250],f32> to tensor<300x100x250xf32>
  %2 = tosa.reshape %1 {new_shape = array<i64: 1, 300, 100, 250>} : (tensor<300x100x250xf32>) -> tensor<1x300x100x250xf32>
  %3 = "tosa.const"() <{value = dense<[1, 0, 2, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
  %4 = tosa.transpose %2, %3 : (tensor<1x300x100x250xf32>, tensor<4xi32>) -> tensor<300x1x100x250xf32>
  %5 = tosa.reshape %4 {new_shape = array<i64: 300, 100, 250>} : (tensor<300x1x100x250xf32>) -> tensor<300x100x250xf32>
  %6 = "tosa.const"() <{value = dense<[1, 2, 0, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
  %7 = tosa.transpose %0, %6 : (tensor<80x300x250x150xf32>, tensor<4xi32>) -> tensor<300x250x80x150xf32>
  %8 = tosa.reshape %7 {new_shape = array<i64: 300, 250, 12000>} : (tensor<300x250x80x150xf32>) -> tensor<300x250x12000xf32>
  %9 = tosa.matmul %5, %8 : (tensor<300x100x250xf32>, tensor<300x250x12000xf32>) -> tensor<300x100x12000xf32>
  %10 = tosa.reshape %9 {new_shape = array<i64: 300, 100, 80, 150>} : (tensor<300x100x12000xf32>) -> tensor<300x100x80x150xf32>
  %11 = "tosa.const"() <{value = dense<[1, 2, 0, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
  %12 = tosa.transpose %10, %11 : (tensor<300x100x80x150xf32>, tensor<4xi32>) -> tensor<80x300x100x150xf32>
  %13 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[300,100,250],f32>, !torch.vtensor<[80,300,250,150],f32> -> !torch.vtensor<[80,300,100,150],f32>
  return %13 : !torch.vtensor<[80,300,100,150],f32>
}

After bug fix

func.func @forward(%arg0: !torch.vtensor<[300,100,250],f32>, %arg1: !torch.vtensor<[80,300,250,150],f32>) -> !torch.vtensor<[80,300,100,150],f32> {
  %0 = builtin.unrealized_conversion_cast %arg1 : !torch.vtensor<[80,300,250,150],f32> to tensor<80x300x250x150xf32>
  %1 = builtin.unrealized_conversion_cast %arg0 : !torch.vtensor<[300,100,250],f32> to tensor<300x100x250xf32>
  %2 = tosa.reshape %1 {new_shape = array<i64: 1, 300, 100, 250>} : (tensor<300x100x250xf32>) -> tensor<1x300x100x250xf32>
  %3 = "tosa.const"() <{value = dense<[1, 0, 2, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
  %4 = tosa.transpose %2, %3 : (tensor<1x300x100x250xf32>, tensor<4xi32>) -> tensor<300x1x100x250xf32>
  %5 = tosa.reshape %4 {new_shape = array<i64: 300, 100, 250>} : (tensor<300x1x100x250xf32>) -> tensor<300x100x250xf32>
  %6 = "tosa.const"() <{value = dense<[1, 2, 0, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
  %7 = tosa.transpose %0, %6 : (tensor<80x300x250x150xf32>, tensor<4xi32>) -> tensor<300x250x80x150xf32>
  %8 = tosa.reshape %7 {new_shape = array<i64: 300, 250, 12000>} : (tensor<300x250x80x150xf32>) -> tensor<300x250x12000xf32>
  %9 = tosa.matmul %5, %8 : (tensor<300x100x250xf32>, tensor<300x250x12000xf32>) -> tensor<300x100x12000xf32>
  %10 = tosa.reshape %9 {new_shape = array<i64: 300, 100, 80, 150>} : (tensor<300x100x12000xf32>) -> tensor<300x100x80x150xf32>
  %11 = "tosa.const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
  %12 = tosa.transpose %10, %11 : (tensor<300x100x80x150xf32>, tensor<4xi32>) -> tensor<80x300x100x150xf32>
  %13 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[300,100,250],f32>, !torch.vtensor<[80,300,250,150],f32> -> !torch.vtensor<[80,300,100,150],f32>
  return %13 : !torch.vtensor<[80,300,100,150],f32>
}

Result analysis

The mistake is in the permutation constant %11. Before the fix it was:

%11 = "tosa.const"() <{value = dense<[1, 2, 0, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>

After the fix it becomes:

%11 = "tosa.const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>

@dan-garvey
Collaborator

Any reason it needs to be added as xfail?

@dan-garvey self-requested a review on December 15, 2023.
@bilibiliGO283
Author

Hi @dan-garvey, I didn't add it as xfail; instead I added it to TOSA_PASS_SET:

# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
    "IscloseStaticModule_basic",
    "IscloseStaticModuleTrue_basic",
    "TileBigDimsSizeModule_basic",
    ...
    "MatmulStaticBroadcast_basic", # Added code

@dan-garvey
Collaborator

Sorry for the slow review. If you resolve the merge conflict, I'll run the CI and approve.

@bilibiliGO283
Author

Hello @dan-garvey, I've resolved the merge conflicts. Could you please run the CI?
