Skip to content

Commit

Permalink
Update linalg dialect ops (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
amanda849 authored Nov 27, 2023
1 parent bddedc2 commit 5a9a019
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 11 deletions.
51 changes: 42 additions & 9 deletions mlir/dialects/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
import mlir.astnodes as mast
from mlir.dialect import Dialect, DialectOp, is_op
from dataclasses import dataclass
from typing import Optional, List
from typing import Optional, List, Tuple, Union

Literal = Union[mast.StringLiteral, float, int, bool]
SsaUse = Union[mast.SsaId, Literal]

@dataclass
class LinalgBatchMatmul(DialectOp):
Expand Down Expand Up @@ -114,16 +116,28 @@ class LinalgDot(DialectOp):

@dataclass
class LinalgFill(DialectOp):
output_id: mast.SsaId
value_id: mast.SsaId
output_type: mast.Type
value_type: mast.Type
in_id: mast.SsaId
in_type: mast.Type
out_id: mast.SsaId
out_type: mast.Type
res_type: Optional[mast.Type] = None
attr: Optional[mast.Attribute] = None

_syntax_ = [("linalg.fill( {output_id.ssa_id} , {value_id.ssa_id} ) "
"{attr.attribute_value} : {output_type.type} , {value_type.type}"),
("linalg.fill( {output_id.ssa_id} , {value_id.ssa_id} ) "
" : {output_type.type} , {value_type.type}")]
_syntax_ = [("linalg.fill"
" ins( {in_id.ssa_id} : {in_type.type} )"
" outs( {out_id.ssa_id} : {out_type.type} )"
" {attr.attribute_value}"),
("linalg.fill"
" ins( {in_id.ssa_id} : {in_type.type} )"
" outs( {out_id.ssa_id} : {out_type.type} )"),
("linalg.fill"
" ins( {in_id.ssa_id} : {in_type.type} )"
" outs( {out_id.ssa_id} : {out_type.type} )"
" {attr.attribute_value} -> {res_type.type}"),
("linalg.fill"
" ins( {in_id.ssa_id} : {in_type.type} )"
" outs( {out_id.ssa_id} : {out_type.type} )"
" -> {res_type.type}")]


@dataclass
Expand Down Expand Up @@ -188,6 +202,22 @@ class LinalgRange(DialectOp):
" : {out_type.type}")]


@dataclass
class LinalgReduce(DialectOp):
inargs: List[mast.SsaId]
in_types: List[mast.Type]
outargs: List[mast.SsaId]
out_types: List[mast.Type]
dimensions: List[SsaUse]
region: mast.Region
args: List[Tuple[mast.SsaId, mast.Type]]

_syntax_ = [("linalg.reduce"
" ins( {inargs.ssa_id_list} : {in_types.type_list_no_parens} )"
" outs( {outargs.ssa_id_list} : {out_types.type_list_no_parens} )"
" dimensions = [ {dimensions.ssa_use_list} ]"
" ( {args.argument_list} ) {region.region}")]

@dataclass
class LinalgReshape(DialectOp):
src_id: mast.SsaId
Expand Down Expand Up @@ -271,6 +301,9 @@ class LinalgMatmul(DialectOp):
_syntax_ = [("linalg.matmul"
" ins( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )"
" outs( {c_id.ssa_id} : {c_type.type} )"),
("linalg.matmul"
" ins( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )"
" outs( {c_id.ssa_id} : {c_type.type} ) -> {out_type.type}"),
("linalg.matmul"
" ins( {a_id.ssa_id} , {b_id.ssa_id} : {a_type.type} , {b_type.type} )"
" init( {c_id.ssa_id} : {c_type.type} ) -> {out_type.type}")]
Expand Down
17 changes: 15 additions & 2 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ def test_dot():

def test_fill():
assert_roundtrip_equivalence("""module {
func.func @fill_view(%arg0: memref<?xf32, strided<[1], offset: ?>>, %arg1: f32) {
linalg.fill( %arg0 , %arg1 ) : memref<?xf32, strided<[1], offset: ?>> , f32
func.func @fill_view(%arg0: f32, %arg1: tensor<?x?xf32>) {
linalg.fill ins( %arg0 : f32 ) outs( %arg1 : tensor<?x?xf32> ) -> tensor<?x?xf32>
linalg.fill ins( %arg0 : f32 ) outs( %arg1 : tensor<?x?xf32> )
return
}
}""")
Expand Down Expand Up @@ -105,6 +106,17 @@ def test_indexed_generic():
return
}
}""")

def test_reduce():
assert_roundtrip_equivalence("""module {
func.func @reduce(%arg0: tensor<16x32x64xf32>, %arg1: tensor<16x64xf32>) {
%reduce = linalg.reduce ins( %arg0 : tensor<16x32x64xf32> ) outs( %arg1 : tensor<16x64xf32> ) dimensions = [ 1 ] ( %in: f32, %out: f32 ) {
%0 = arith.addf %out, %in : f32
linalg.yield %0 : f32
}
return
}
}""")


def test_view():
Expand Down Expand Up @@ -135,6 +147,7 @@ def test_matmul():
%B = view %arg0 [ %c0 ] [ %K, %N ] : memref<?xi8> to memref<?x?xf32>
%C = view %arg0 [ %c0 ] [ %M, %N ] : memref<?xi8> to memref<?x?xf32>
linalg.matmul ins( %A , %B : memref<?x?xf32> , memref<?x?xf32> ) outs( %C : memref<?x?xf32> )
linalg.matmul ins( %A , %B : memref<?x?xf32> , memref<?x?xf32> ) outs( %C : memref<?x?xf32> ) -> memref<?x?xf32>
return
}
}""")
Expand Down

0 comments on commit 5a9a019

Please sign in to comment.