@@ -39,23 +39,58 @@ def __init__(
3939 super ().__init__ (result_type , _get_op_result_or_value (target ), loc = loc , ip = ip )
4040
4141
42+ def cast (
43+ result_type : Type , target : Union [Operation , Value ], * , loc = None , ip = None
44+ ) -> OpResult :
45+ return CastOp (result_type = result_type , target = target , loc = loc , ip = ip ).result
46+
47+
4248@_ods_cext .register_operation (_Dialect , replace = True )
4349class ApplyPatternsOp (ApplyPatternsOp ):
4450 def __init__ (
4551 self ,
4652 target : Union [Operation , Value , OpView ],
53+ apply_cse : bool = False ,
54+ max_iterations : Optional [Union [IntegerAttr , int ]] = None ,
55+ max_num_rewrites : Optional [Union [IntegerAttr , int ]] = None ,
4756 * ,
4857 loc = None ,
4958 ip = None ,
5059 ):
51- super ().__init__ (target , loc = loc , ip = ip )
60+ super ().__init__ (
61+ target ,
62+ apply_cse = apply_cse ,
63+ max_iterations = max_iterations ,
64+ max_num_rewrites = max_num_rewrites ,
65+ loc = loc ,
66+ ip = ip ,
67+ )
5268 self .regions [0 ].blocks .append ()
5369
5470 @property
5571 def patterns (self ) -> Block :
5672 return self .regions [0 ].blocks [0 ]
5773
5874
75+ def apply_patterns (
76+ target : Union [Operation , Value , OpView ],
77+ apply_cse : bool = False ,
78+ max_iterations : Optional [Union [IntegerAttr , int ]] = None ,
79+ max_num_rewrites : Optional [Union [IntegerAttr , int ]] = None ,
80+ * ,
81+ loc = None ,
82+ ip = None ,
83+ ) -> ApplyPatternsOp :
84+ return ApplyPatternsOp (
85+ target = target ,
86+ apply_cse = apply_cse ,
87+ max_iterations = max_iterations ,
88+ max_num_rewrites = max_num_rewrites ,
89+ loc = loc ,
90+ ip = ip ,
91+ )
92+
93+
5994@_ods_cext .register_operation (_Dialect , replace = True )
6095class GetParentOp (GetParentOp ):
6196 def __init__ (
@@ -64,6 +99,7 @@ def __init__(
6499 target : Union [Operation , Value ],
65100 * ,
66101 isolated_from_above : bool = False ,
102+ allow_empty_results : bool = False ,
67103 op_name : Optional [str ] = None ,
68104 deduplicate : bool = False ,
69105 nth_parent : int = 1 ,
@@ -74,6 +110,7 @@ def __init__(
74110 result_type ,
75111 _get_op_result_or_value (target ),
76112 isolated_from_above = isolated_from_above ,
113+ allow_empty_results = allow_empty_results ,
77114 op_name = op_name ,
78115 deduplicate = deduplicate ,
79116 nth_parent = nth_parent ,
@@ -82,24 +119,64 @@ def __init__(
82119 )
83120
84121
122+ def get_parent_op (
123+ result_type : Type ,
124+ target : Union [Operation , Value ],
125+ * ,
126+ isolated_from_above : bool = False ,
127+ allow_empty_results : bool = False ,
128+ op_name : Optional [str ] = None ,
129+ deduplicate : bool = False ,
130+ nth_parent : int = 1 ,
131+ loc = None ,
132+ ip = None ,
133+ ) -> OpResult :
134+ return GetParentOp (
135+ result_type = result_type ,
136+ target = target ,
137+ isolated_from_above = isolated_from_above ,
138+ allow_empty_results = allow_empty_results ,
139+ op_name = op_name ,
140+ deduplicate = deduplicate ,
141+ nth_parent = nth_parent ,
142+ loc = loc ,
143+ ip = ip ,
144+ ).result
145+
146+
85147@_ods_cext .register_operation (_Dialect , replace = True )
86148class MergeHandlesOp (MergeHandlesOp ):
87149 def __init__ (
88150 self ,
89151 handles : Sequence [Union [Operation , Value ]],
90152 * ,
91153 deduplicate : bool = False ,
154+ results : Optional [Sequence [Type ]] = None ,
92155 loc = None ,
93156 ip = None ,
94157 ):
95158 super ().__init__ (
96159 [_get_op_result_or_value (h ) for h in handles ],
97160 deduplicate = deduplicate ,
161+ results = results ,
98162 loc = loc ,
99163 ip = ip ,
100164 )
101165
102166
167+ def merge_handles (
168+ handles : Sequence [Union [Operation , Value ]],
169+ * ,
170+ deduplicate : bool = False ,
171+ results : Optional [Sequence [Type ]] = None ,
172+ loc = None ,
173+ ip = None ,
174+ ) -> OpResult :
175+ return MergeHandlesOp (
176+ handles = handles , deduplicate = deduplicate , results = results , loc = loc , ip = ip
177+ ).result
178+
179+
103180@_ods_cext .register_operation (_Dialect , replace = True )
104181class ReplicateOp (ReplicateOp ):
105182 def __init__ (
@@ -119,16 +196,31 @@ def __init__(
119196 )
120197
121198
199+ def replicate (
200+ pattern : Union [Operation , Value ],
201+ handles : Sequence [Union [Operation , Value ]],
202+ * ,
203+ loc = None ,
204+ ip = None ,
205+ ) -> Union [OpResult , OpResultList , ReplicateOp ]:
206+ op = ReplicateOp (pattern = pattern , handles = handles , loc = loc , ip = ip )
207+ results = op .results
208+ return results if len (results ) > 1 else (results [0 ] if len (results ) == 1 else op )
209+
210+
122211@_ods_cext .register_operation (_Dialect , replace = True )
123212class SequenceOp (SequenceOp ):
124213 def __init__ (
125214 self ,
126- failure_propagation_mode ,
215+ failure_propagation_mode : FailurePropagationMode ,
127216 results : Sequence [Type ],
128217 target : Union [Operation , Value , Type ],
129218 extra_bindings : Optional [
130219 Union [Sequence [Value ], Sequence [Type ], Operation , OpView ]
131220 ] = None ,
221+ * ,
222+ loc = None ,
223+ ip = None ,
132224 ):
133225 root = (
134226 _get_op_result_or_value (target )
@@ -155,6 +247,8 @@ def __init__(
155247 failure_propagation_mode = failure_propagation_mode ,
156248 root = root ,
157249 extra_bindings = extra_bindings ,
250+ loc = loc ,
251+ ip = ip ,
158252 )
159253 self .regions [0 ].blocks .append (* tuple ([root_type ] + extra_binding_types ))
160254
@@ -171,16 +265,42 @@ def bodyExtraArgs(self) -> BlockArgumentList:
171265 return self .body .arguments [1 :]
172266
173267
268+ def sequence (
269+ failure_propagation_mode : FailurePropagationMode ,
270+ results : Sequence [Type ],
271+ target : Union [Operation , Value , Type ],
272+ extra_bindings : Optional [
273+ Union [Sequence [Value ], Sequence [Type ], Operation , OpView ]
274+ ] = None ,
275+ * ,
276+ loc = None ,
277+ ip = None ,
278+ ) -> Union [OpResult , OpResultList , SequenceOp ]:
279+ op = SequenceOp (
280+ results = results ,
281+ failure_propagation_mode = failure_propagation_mode ,
282+ extra_bindings = extra_bindings ,
283+ target = target ,
284+ loc = loc ,
285+ ip = ip ,
286+ )
287+ results = op .results
288+ return results if len (results ) > 1 else (results [0 ] if len (results ) == 1 else op )
289+
290+
174291@_ods_cext .register_operation (_Dialect , replace = True )
175292class NamedSequenceOp (NamedSequenceOp ):
176293 def __init__ (
177294 self ,
178- sym_name ,
295+ sym_name : Union [ str , SymbolRefAttr ] ,
179296 input_types : Sequence [Type ],
180297 result_types : Sequence [Type ],
181- sym_visibility = None ,
182- arg_attrs = None ,
183- res_attrs = None ,
298+ * ,
299+ sym_visibility : Optional [Union [str , StringAttr ]] = None ,
300+ arg_attrs : Optional [Union [Sequence [dict ], "DictArrayAttr" ]] = None ,
301+ res_attrs : Optional [Union [Sequence [dict ], "DictArrayAttr" ]] = None ,
302+ loc = None ,
303+ ip = None ,
184304 ):
185305 function_type = FunctionType .get (input_types , result_types )
186306 super ().__init__ (
@@ -205,6 +325,29 @@ def bodyExtraArgs(self) -> BlockArgumentList:
205325 return self .body .arguments [1 :]
206326
207327
328+ def named_sequence (
329+ sym_name : Union [str , SymbolRefAttr ],
330+ input_types : Sequence [Type ],
331+ result_types : Sequence [Type ],
332+ * ,
333+ sym_visibility : Optional [Union [str , StringAttr ]] = None ,
334+ arg_attrs : Optional [Union [Sequence [dict ], "DictArrayAttr" ]] = None ,
335+ res_attrs : Optional [Union [Sequence [dict ], "DictArrayAttr" ]] = None ,
336+ loc = None ,
337+ ip = None ,
338+ ) -> NamedSequenceOp :
339+ return NamedSequenceOp (
340+ sym_name = sym_name ,
341+ input_types = input_types ,
342+ result_types = result_types ,
343+ sym_visibility = sym_visibility ,
344+ arg_attrs = arg_attrs ,
345+ res_attrs = res_attrs ,
346+ loc = loc ,
347+ ip = ip ,
348+ )
349+
350+
208351@_ods_cext .register_operation (_Dialect , replace = True )
209352class YieldOp (YieldOp ):
210353 def __init__ (
@@ -219,6 +362,12 @@ def __init__(
219362 super ().__init__ (_get_op_results_or_values (operands ), loc = loc , ip = ip )
220363
221364
365+ def yield_ (
366+ operands : Optional [Union [Operation , Sequence [Value ]]] = None , * , loc = None , ip = None
367+ ) -> YieldOp :
368+ return YieldOp (operands = operands , loc = loc , ip = ip )
369+
370+
222371OptionValueTypes = Union [
223372 Sequence ["OptionValueTypes" ], Attribute , Value , Operation , OpView , str , int , bool
224373]
@@ -247,7 +396,7 @@ def __init__(
247396 def option_value_to_attr (value ):
248397 nonlocal cur_param_operand_idx
249398 if isinstance (value , (Value , Operation , OpView )):
250- dynamic_options .append (_get_op_result_or_value ( value ) )
399+ dynamic_options .append (value )
251400 cur_param_operand_idx += 1
252401 return ParamOperandAttr (cur_param_operand_idx - 1 , context )
253402 elif isinstance (value , Attribute ):
0 commit comments