[NFC] Update black version #3256

Merged: 1 commit, Apr 29, 2024
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -11,7 +11,7 @@ repos:
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/psf/black
rev: 22.10.0
rev: 24.4.2
hooks:
- id: black

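Every other change in this PR is mechanical reformatting produced by the newer black release pinned above; no functional change is intended (hence the [NFC] tag). As a rough sketch, a file can be checked against the pinned style from Python, assuming black >= 24.4 is installed locally to match the rev above (running the hook itself via pre-commit run black --all-files is the usual workflow):

import pathlib

import black  # assumption: black >= 24.4 installed, matching the pinned rev


def is_black_clean(path: str) -> bool:
    """Return True if the file already conforms to the installed black style."""
    source = pathlib.Path(path).read_text()
    # format_str returns the reformatted text; identical output means nothing would change.
    return black.format_str(source, mode=black.Mode()) == source


print(is_black_clean("build_tools/scrape_releases.py"))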
1 change: 1 addition & 0 deletions build_tools/scrape_releases.py
@@ -2,6 +2,7 @@

See https://github.com/llvm/torch-mlir/issues/1374
"""

import argparse
import json

@@ -3,6 +3,7 @@

from transformers import BertForMaskedLM


# Wrap the bert model to avoid multiple returns problem
class BertTinyWrapper(torch.nn.Module):
def __init__(self) -> None:
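The two hunks above show the blank-line rules in black 24.x that drive most of this diff: a blank line is now enforced after a module docstring, and a module-level class, together with any comment attached directly above it, is separated from the preceding code by two blank lines. A toy, self-contained sketch of the new layout (stand-in names, not the real torch-mlir modules):

"""Module docstring: black 24 now enforces a blank line after it."""

import argparse


# A comment attached to a module-level class moves with it; two blank lines now
# separate this comment + class pair from the import block above.
class Wrapper:
    def __init__(self) -> None:
        self.parser = argparse.ArgumentParser()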
6 changes: 3 additions & 3 deletions projects/pt1/python/torch_mlir/_dynamo_fx_importer.py
@@ -257,9 +257,9 @@ def __init__(self, g: torch.fx.Graph, func_name: str):
# FakeTensor's in case of a tuple return with multiple elements.
self._env: Dict[Tuple[torch.fx.Node, int], ir.Value] = {}
self._module = ir.Module.create(ir.Location.unknown())
self._module.operation.attributes[
"torch.debug_module_name"
] = ir.StringAttr.get(func_name)
self._module.operation.attributes["torch.debug_module_name"] = (
ir.StringAttr.get(func_name)
)
function_type = _extract_function_type_from_graph(g)
func = func_dialect.FuncOp(
func_name,
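The hunk above reflects another 24.x style change: when an assignment whose target is a subscript does not fit on one line, black now keeps the target intact and wraps the right-hand side in parentheses instead of splitting the subscript across lines. A minimal illustration with hypothetical stand-ins (not torch-mlir's real ir.StringAttr):

class _StringAttr:
    """Hypothetical stand-in for ir.StringAttr."""

    @staticmethod
    def get(value: str) -> str:
        return value


attributes: dict = {}

# The old style split the subscript target across lines; the new style wraps the
# right-hand side instead, keeping attributes["..."] on a single line.
attributes["torch.debug_module_name"] = (
    _StringAttr.get("a module name long enough to overflow the 88 column limit")
)
print(attributes)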
@@ -285,9 +285,9 @@ def emit_with_mutating_variants(key, **kwargs):
(ns, unqual + "_", overload if not is_functional_op else "")
),
emitter_td,
traits=["IsTrailingUnderscoreInplaceVariant"]
if not is_functional_op
else [],
traits=(
["IsTrailingUnderscoreInplaceVariant"] if not is_functional_op else []
),
)

# ==========================================================================
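Similarly, black 24.x parenthesizes a conditional expression that has to span multiple lines, as in the traits= keyword argument above. A toy sketch of the pattern (the emitter call is a stand-in; in the real torch_ods_gen.py it is the surrounding indentation that pushes the expression past the line limit):

def emit_op(key: str, traits: list) -> None:
    """Stand-in for the real emitter; just echoes its arguments."""
    print(key, traits)


is_functional_op = False
emit_op(
    "aten::add_",
    traits=(
        ["IsTrailingUnderscoreInplaceVariant"] if not is_functional_op else []
    ),
)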
@@ -46,7 +46,7 @@ def convert_onnx(model, inputs):
examples = []
input_names = []
dynamic_tensors = {}
for (index, arg) in enumerate(inputs):
for index, arg in enumerate(inputs):
shape = map(lambda d: d if d >= 0 else 1, arg.shape)
shape = tuple(shape)
examples.append(torch.zeros(size=shape, dtype=arg.dtype))
@@ -55,7 +55,7 @@ def convert_onnx(model, inputs):
input_names.append(input_name)

dynamic_dims = {}
for (dimindex, dim) in enumerate(arg.shape):
for dimindex, dim in enumerate(arg.shape):
if dim < 0:
dynamic_dims[dimindex] = "dim_{}_{}".format(index, dimindex)

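The convert_onnx hunks show a smaller cleanup in the new style: redundant parentheses around the tuple target of a for loop are removed. A trivial, runnable before/after:

inputs = ["a", "b", "c"]

# Previously written as: for (index, arg) in enumerate(inputs):
for index, arg in enumerate(inputs):
    print(index, arg)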
@@ -101,10 +101,12 @@ def __init__(self, module):
def consume_return_funcs(*args):
self.result = tuple(
[
arg
if type in elemental_type_to_ctype
else unranked_memref_to_numpy(
arg, memref_type_to_np_dtype[type]
(
arg
if type in elemental_type_to_ctype
else unranked_memref_to_numpy(
arg, memref_type_to_np_dtype[type]
)
)
for arg, type in zip(args, ret_types)
]
@@ -803,9 +803,7 @@ def forward(self, x):

@register_test_case(module_factory=lambda: QuantizedReluInt32())
def QuantizedReluInt32_basic(module, tu: TestUtils):
module.forward(
tu.randint(7, 4, low=(-(2**31)), high=(2**31 - 1)).to(torch.int32)
)
module.forward(tu.randint(7, 4, low=(-(2**31)), high=(2**31 - 1)).to(torch.int32))


# ==============================================================================
@@ -342,6 +342,7 @@ def SelectIntNegativeDimAndIndexStaticModule_basic(module, tu: TestUtils):

# ==============================================================================


# For aten.slice_scatter op, The arguments are: SliceScatter(input, src, dim=0, start=None, end=None, step=1).
# For aten.select_scatter op, The arguments are: SelectScatter(input, src, dim=0, index).
class SliceScatterModule(torch.nn.Module):
@@ -11,6 +11,7 @@

mb = ModuleBuilder()


# CHECK: module attributes {torch.debug_module_name = "TestModule"}
class TestModule(torch.nn.Module):
def __init__(self):
@@ -18,6 +18,7 @@
# `torch.Tensor` is just a pointer to a TensorImpl under the hood, and so
# naively duplicating a Tensor retains the identity of the TensorImpl.


# CHECK-LABEL: torch.class_type @__torch__.TestModule {
class TestModule(torch.nn.Module):
def __init__(self):
@@ -12,6 +12,7 @@

mb = ModuleBuilder()


# CHECK-LABEL: torch.class_type @__torch__.TestModule {
class TestModule(torch.nn.Module):
def __init__(self):
@@ -9,6 +9,7 @@

mb = ModuleBuilder()


# CHECK-LABEL: func.func @__torch__.add3
# Note that line-level debug information for parts unannotated in the Torch
# graph are ascribed to the first op that carries source information. Presently
@@ -9,6 +9,7 @@

mb = ModuleBuilder()


# CHECK-LABEL: @__torch__.f
@mb.import_function
@torch.jit.script
@@ -11,6 +11,7 @@

mb = ModuleBuilder()


# CHECK-LABEL: func.func @__torch__.optional_return(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.optional<int> {
# CHECK: %[[RET:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional<int>
1 change: 1 addition & 0 deletions projects/pt1/test/python/importer/jit_ir/node_import/if.py
@@ -13,6 +13,7 @@
# else branch and making all defined values optional, so no special handling
# is needed.


# CHECK-LABEL: @__torch__.prim_If(
# CHECK-SAME: %[[B:.*]]: !torch.bool,
# CHECK-SAME: %[[I:.*]]: !torch.int) -> !torch.int {
@@ -11,6 +11,7 @@

mb = ModuleBuilder()


# CHECK-LABEL: func.func @__torch__.prim_Loop_forlike(
# CHECK-SAME: %[[MAX_ITERATIONS:.*]]: !torch.int) -> !torch.float {
# CHECK: %[[BOOL_TRUE:.*]] = torch.constant.bool true
@@ -15,6 +15,7 @@

mb = ModuleBuilder()


# CHECK-LABEL: func.func @__torch__.prim_NumToTensor(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tensor {
# CHECK: %[[RET:.*]] = torch.prim.NumToTensor.Scalar %[[ARG]] : !torch.int -> !torch.tensor
@@ -13,6 +13,7 @@
mb = ModuleBuilder()
NT = NamedTuple("NT", [("f1", Optional[torch.Tensor]), ("f2", Optional[torch.Tensor])])


# CHECK-LABEL: func.func @__torch__.tuple(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->
@@ -9,6 +9,7 @@

mb = ModuleBuilder()


# CHECK: @__torch__.returns_bool
@mb.import_function
@torch.jit.script
@@ -9,6 +9,7 @@

mb = ModuleBuilder()


# CHECK: @__torch__.returns_none
@mb.import_function
@torch.jit.script
@@ -9,6 +9,7 @@

# RUN: %PYTHON %s


# Import TorchScript IR string as ScriptFunction.
def create_script_function(func_name, ts_ir_str, **kwargs):
cu = CompilationUnit()
3 changes: 1 addition & 2 deletions python/torch_mlir/extras/fx_importer.py
@@ -1849,8 +1849,7 @@ def _emit_operation(

# Opaque value to indicate something is empty. Used in cases where 'None'
# may have a different meaning.
class EmptyType:
...
class EmptyType: ...


Empty = EmptyType()
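The EmptyType change above comes from the compact "dummy implementation" rule in black 24.x: a class or function whose body is only an ellipsis is collapsed onto one line. A small sketch (OnnxImportError below mirrors the identical change in onnx_importer.py):

# Bodies consisting only of ... now stay on the header line; previously the ...
# was placed on its own indented line.
class EmptyType: ...


class OnnxImportError(Exception): ...


def _not_yet_implemented(x: int) -> int: ...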
31 changes: 16 additions & 15 deletions python/torch_mlir/extras/onnx_importer.py
@@ -156,8 +156,7 @@ def find_type_proto_for_name(self, name: str) -> onnx.TypeProto:
return ""


class OnnxImportError(Exception):
...
class OnnxImportError(Exception): ...


class NodeImporter:
@@ -235,22 +234,22 @@ def _populate_graph_attrs(self, container_op: Operation):
else:
default_opset_version = opset_import.version
if default_opset_version:
container_op.attributes[
"torch.onnx_meta.opset_version"
] = IntegerAttr.get(i64_type, default_opset_version)
container_op.attributes["torch.onnx_meta.opset_version"] = (
IntegerAttr.get(i64_type, default_opset_version)
)
if opset_versions:
container_op.attributes[
"torch.onnx_meta.opset_versions"
] = DictAttr.get(opset_versions)
container_op.attributes["torch.onnx_meta.opset_versions"] = (
DictAttr.get(opset_versions)
)
container_op.attributes["torch.onnx_meta.ir_version"] = IntegerAttr.get(
IntegerType.get_signed(64), m.ir_version
)
container_op.attributes["torch.onnx_meta.producer_name"] = StringAttr.get(
m.producer_name
)
container_op.attributes[
"torch.onnx_meta.producer_version"
] = StringAttr.get(m.producer_version)
container_op.attributes["torch.onnx_meta.producer_version"] = (
StringAttr.get(m.producer_version)
)

def import_all(self, func=True):
"""Imports all nodes topologically."""
@@ -658,9 +657,11 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute:
RankedTensorType.get(shape, IntegerType.get_signed(64)),
IntegerAttr.get(
IntegerType.get_signed(64),
int.from_bytes(tp.raw_data, "little", signed=True)
if tp.HasField("raw_data")
else tp.int64_data[0],
(
int.from_bytes(tp.raw_data, "little", signed=True)
if tp.HasField("raw_data")
else tp.int64_data[0]
),
),
),
# TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB
@@ -703,7 +704,7 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute:
),
onnx.TensorProto.DataType.UINT64: lambda tp: DenseElementsAttr.get(
np.asarray(tp.uint64_data, dtype=np.uint64).reshape(tp.dims), signless=False
)
),
# Intentionally unsupported: STRING
}
