|
12 | 12 |
|
13 | 13 | import executorch.backends.arm.tosa.dialect # noqa: unused |
14 | 14 | from executorch.backends.arm._passes import ( |
15 | | - AddBiasPass, |
16 | 15 | AnnotateDecomposedMatmulPass, |
17 | 16 | AnnotateOutputDimOrderPass, |
18 | 17 | BroadcastArgsPass, |
|
93 | 92 | ReplaceScalarWithTensorArgPassTOSABI, |
94 | 93 | ReplaceScalarWithTensorArgPassTOSAMI, |
95 | 94 | RetraceFoldedDtypesPass, |
| 95 | + RewriteConv2dPass, |
96 | 96 | RewriteMatmulPass, |
97 | 97 | RewriteUpsamplePass, |
98 | 98 | ScalarsToAttributePass, |
@@ -207,13 +207,13 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: |
207 | 207 | self.add_pass(InsertTableOpsPass(exported_program)) |
208 | 208 | # If we have a conv2d with int16 activation split up into a convolution |
209 | 209 | # and an addition, to work-around the lack of support for int48 in torch |
210 | | - # needs to happen before AddBiasPass, but after the table ops are inserted |
| 210 | + # needs to happen before RewriteConv2dPass, but after the table ops are inserted |
211 | 211 | # to be able to validate that conv2d has right dtype arguments. |
212 | 212 | self.add_pass(DecomposeConv2dWithInt16ActivationPass()) |
213 | | - self.add_pass(RewriteUpsamplePass()) |
214 | | - self.add_pass(AddBiasPass(exported_program)) |
| 213 | + self.add_pass(RewriteConv2dPass(exported_program)) |
215 | 214 |
|
216 | 215 | self.add_pass(RewriteMatmulPass()) |
| 216 | + self.add_pass(RewriteUpsamplePass()) |
217 | 217 | self.add_pass(FuseEqualPlaceholdersPass(exported_program)) |
218 | 218 | self.add_pass(ToTosaMemoryFormatPass(exported_program)) |
219 | 219 | self.add_pass(RemoveNoopPass()) |
@@ -297,9 +297,9 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: |
297 | 297 |
|
298 | 298 | self.add_pass(FuseViewCopyTransform()) |
299 | 299 | self.add_pass(FuseConstantArgsPass(exported_program)) |
| 300 | + self.add_pass(RewriteConv2dPass(exported_program)) |
300 | 301 | self.add_pass(CastInt64BuffersToInt32Pass(exported_program)) |
301 | 302 | self.add_pass(RewriteUpsamplePass()) |
302 | | - self.add_pass(AddBiasPass(exported_program)) |
303 | 303 | self.add_pass(InsertTableOpsPass(exported_program)) |
304 | 304 | self.add_pass(RewriteMatmulPass()) |
305 | 305 | self.add_pass(FuseEqualPlaceholdersPass(exported_program)) |
|
0 commit comments