Skip to content

Commit d78e0de

Browse files
authored
[MLIR][Transform][Python] Sync derived classes and their wrappers (#166871)
Updates the derived Op-classes for the main transform ops to have all the arguments, etc, from the auto-generated classes. Additionally updates and adds missing snake_case wrappers for the derived classes which shadow the snake_case wrappers of the auto-generated classes, which were hitherto exposed alongside the derived classes.
1 parent 3ee2f07 commit d78e0de

File tree

2 files changed

+298
-80
lines changed

2 files changed

+298
-80
lines changed

mlir/python/mlir/dialects/transform/__init__.py

Lines changed: 156 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,23 +39,58 @@ def __init__(
3939
super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)
4040

4141

42+
def cast(
43+
result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None
44+
) -> OpResult:
45+
return CastOp(result_type=result_type, target=target, loc=loc, ip=ip).result
46+
47+
4248
@_ods_cext.register_operation(_Dialect, replace=True)
4349
class ApplyPatternsOp(ApplyPatternsOp):
4450
def __init__(
4551
self,
4652
target: Union[Operation, Value, OpView],
53+
apply_cse: bool = False,
54+
max_iterations: Optional[Union[IntegerAttr, int]] = None,
55+
max_num_rewrites: Optional[Union[IntegerAttr, int]] = None,
4756
*,
4857
loc=None,
4958
ip=None,
5059
):
51-
super().__init__(target, loc=loc, ip=ip)
60+
super().__init__(
61+
target,
62+
apply_cse=apply_cse,
63+
max_iterations=max_iterations,
64+
max_num_rewrites=max_num_rewrites,
65+
loc=loc,
66+
ip=ip,
67+
)
5268
self.regions[0].blocks.append()
5369

5470
@property
5571
def patterns(self) -> Block:
5672
return self.regions[0].blocks[0]
5773

5874

75+
def apply_patterns(
76+
target: Union[Operation, Value, OpView],
77+
apply_cse: bool = False,
78+
max_iterations: Optional[Union[IntegerAttr, int]] = None,
79+
max_num_rewrites: Optional[Union[IntegerAttr, int]] = None,
80+
*,
81+
loc=None,
82+
ip=None,
83+
) -> ApplyPatternsOp:
84+
return ApplyPatternsOp(
85+
target=target,
86+
apply_cse=apply_cse,
87+
max_iterations=max_iterations,
88+
max_num_rewrites=max_num_rewrites,
89+
loc=loc,
90+
ip=ip,
91+
)
92+
93+
5994
@_ods_cext.register_operation(_Dialect, replace=True)
6095
class GetParentOp(GetParentOp):
6196
def __init__(
@@ -64,6 +99,7 @@ def __init__(
6499
target: Union[Operation, Value],
65100
*,
66101
isolated_from_above: bool = False,
102+
allow_empty_results: bool = False,
67103
op_name: Optional[str] = None,
68104
deduplicate: bool = False,
69105
nth_parent: int = 1,
@@ -74,6 +110,7 @@ def __init__(
74110
result_type,
75111
_get_op_result_or_value(target),
76112
isolated_from_above=isolated_from_above,
113+
allow_empty_results=allow_empty_results,
77114
op_name=op_name,
78115
deduplicate=deduplicate,
79116
nth_parent=nth_parent,
@@ -82,24 +119,64 @@ def __init__(
82119
)
83120

84121

122+
def get_parent_op(
123+
result_type: Type,
124+
target: Union[Operation, Value],
125+
*,
126+
isolated_from_above: bool = False,
127+
allow_empty_results: bool = False,
128+
op_name: Optional[str] = None,
129+
deduplicate: bool = False,
130+
nth_parent: int = 1,
131+
loc=None,
132+
ip=None,
133+
) -> OpResult:
134+
return GetParentOp(
135+
result_type=result_type,
136+
target=target,
137+
isolated_from_above=isolated_from_above,
138+
allow_empty_results=allow_empty_results,
139+
op_name=op_name,
140+
deduplicate=deduplicate,
141+
nth_parent=nth_parent,
142+
loc=loc,
143+
ip=ip,
144+
).result
145+
146+
85147
@_ods_cext.register_operation(_Dialect, replace=True)
86148
class MergeHandlesOp(MergeHandlesOp):
87149
def __init__(
88150
self,
89151
handles: Sequence[Union[Operation, Value]],
90152
*,
91153
deduplicate: bool = False,
154+
results: Optional[Sequence[Type]] = None,
92155
loc=None,
93156
ip=None,
94157
):
95158
super().__init__(
96159
[_get_op_result_or_value(h) for h in handles],
97160
deduplicate=deduplicate,
161+
results=results,
98162
loc=loc,
99163
ip=ip,
100164
)
101165

102166

167+
def merge_handles(
168+
handles: Sequence[Union[Operation, Value]],
169+
*,
170+
deduplicate: bool = False,
171+
results: Optional[Sequence[Type]] = None,
172+
loc=None,
173+
ip=None,
174+
) -> OpResult:
175+
return MergeHandlesOp(
176+
handles=handles, deduplicate=deduplicate, results=results, loc=loc, ip=ip
177+
).result
178+
179+
103180
@_ods_cext.register_operation(_Dialect, replace=True)
104181
class ReplicateOp(ReplicateOp):
105182
def __init__(
@@ -119,16 +196,31 @@ def __init__(
119196
)
120197

121198

199+
def replicate(
200+
pattern: Union[Operation, Value],
201+
handles: Sequence[Union[Operation, Value]],
202+
*,
203+
loc=None,
204+
ip=None,
205+
) -> Union[OpResult, OpResultList, ReplicateOp]:
206+
op = ReplicateOp(pattern=pattern, handles=handles, loc=loc, ip=ip)
207+
results = op.results
208+
return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
209+
210+
122211
@_ods_cext.register_operation(_Dialect, replace=True)
123212
class SequenceOp(SequenceOp):
124213
def __init__(
125214
self,
126-
failure_propagation_mode,
215+
failure_propagation_mode: FailurePropagationMode,
127216
results: Sequence[Type],
128217
target: Union[Operation, Value, Type],
129218
extra_bindings: Optional[
130219
Union[Sequence[Value], Sequence[Type], Operation, OpView]
131220
] = None,
221+
*,
222+
loc=None,
223+
ip=None,
132224
):
133225
root = (
134226
_get_op_result_or_value(target)
@@ -155,6 +247,8 @@ def __init__(
155247
failure_propagation_mode=failure_propagation_mode,
156248
root=root,
157249
extra_bindings=extra_bindings,
250+
loc=loc,
251+
ip=ip,
158252
)
159253
self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types))
160254

@@ -171,16 +265,42 @@ def bodyExtraArgs(self) -> BlockArgumentList:
171265
return self.body.arguments[1:]
172266

173267

268+
def sequence(
269+
failure_propagation_mode: FailurePropagationMode,
270+
results: Sequence[Type],
271+
target: Union[Operation, Value, Type],
272+
extra_bindings: Optional[
273+
Union[Sequence[Value], Sequence[Type], Operation, OpView]
274+
] = None,
275+
*,
276+
loc=None,
277+
ip=None,
278+
) -> Union[OpResult, OpResultList, SequenceOp]:
279+
op = SequenceOp(
280+
results=results,
281+
failure_propagation_mode=failure_propagation_mode,
282+
extra_bindings=extra_bindings,
283+
target=target,
284+
loc=loc,
285+
ip=ip,
286+
)
287+
results = op.results
288+
return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
289+
290+
174291
@_ods_cext.register_operation(_Dialect, replace=True)
175292
class NamedSequenceOp(NamedSequenceOp):
176293
def __init__(
177294
self,
178-
sym_name,
295+
sym_name: Union[str, SymbolRefAttr],
179296
input_types: Sequence[Type],
180297
result_types: Sequence[Type],
181-
sym_visibility=None,
182-
arg_attrs=None,
183-
res_attrs=None,
298+
*,
299+
sym_visibility: Optional[Union[str, StringAttr]] = None,
300+
arg_attrs: Optional[Union[Sequence[dict], "DictArrayAttr"]] = None,
301+
res_attrs: Optional[Union[Sequence[dict], "DictArrayAttr"]] = None,
302+
loc=None,
303+
ip=None,
184304
):
185305
function_type = FunctionType.get(input_types, result_types)
186306
super().__init__(
@@ -205,6 +325,29 @@ def bodyExtraArgs(self) -> BlockArgumentList:
205325
return self.body.arguments[1:]
206326

207327

328+
def named_sequence(
329+
sym_name: Union[str, SymbolRefAttr],
330+
input_types: Sequence[Type],
331+
result_types: Sequence[Type],
332+
*,
333+
sym_visibility: Optional[Union[str, StringAttr]] = None,
334+
arg_attrs: Optional[Union[Sequence[dict], "DictArrayAttr"]] = None,
335+
res_attrs: Optional[Union[Sequence[dict], "DictArrayAttr"]] = None,
336+
loc=None,
337+
ip=None,
338+
) -> NamedSequenceOp:
339+
return NamedSequenceOp(
340+
sym_name=sym_name,
341+
input_types=input_types,
342+
result_types=result_types,
343+
sym_visibility=sym_visibility,
344+
arg_attrs=arg_attrs,
345+
res_attrs=res_attrs,
346+
loc=loc,
347+
ip=ip,
348+
)
349+
350+
208351
@_ods_cext.register_operation(_Dialect, replace=True)
209352
class YieldOp(YieldOp):
210353
def __init__(
@@ -219,6 +362,12 @@ def __init__(
219362
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
220363

221364

365+
def yield_(
366+
operands: Optional[Union[Operation, Sequence[Value]]] = None, *, loc=None, ip=None
367+
) -> YieldOp:
368+
return YieldOp(operands=operands, loc=loc, ip=ip)
369+
370+
222371
OptionValueTypes = Union[
223372
Sequence["OptionValueTypes"], Attribute, Value, Operation, OpView, str, int, bool
224373
]
@@ -247,7 +396,7 @@ def __init__(
247396
def option_value_to_attr(value):
248397
nonlocal cur_param_operand_idx
249398
if isinstance(value, (Value, Operation, OpView)):
250-
dynamic_options.append(_get_op_result_or_value(value))
399+
dynamic_options.append(value)
251400
cur_param_operand_idx += 1
252401
return ParamOperandAttr(cur_param_operand_idx - 1, context)
253402
elif isinstance(value, Attribute):

0 commit comments

Comments
 (0)