diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py
index 79938ba4d..555f70fb2 100644
--- a/coremltools/converters/mil/frontend/torch/ops.py
+++ b/coremltools/converters/mil/frontend/torch/ops.py
@@ -626,23 +626,86 @@ def unflatten(context, node):
 
 def _array_construct(context, node, array_type):
     assert len(node.outputs) == 1
     inputs = _get_inputs(context, node)
-    scalar_inputs = [
-        inp
-        for inp in inputs
-        if isinstance(inp, Var) and inp.can_be_folded_to_const() and len(inp.shape) == 0
-    ]
-    if len(scalar_inputs) == len(inputs):
+    is_all_const = all(
+        isinstance(inp, Var) and inp.can_be_folded_to_const() and len(inp.shape) == 0
+        for inp in inputs
+    )
+    if is_all_const:
         # All the list items are compile-time scalar constants, so let's create
         # a new const that concatenates them.
         val = array_type([inp.val for inp in inputs])
         const = mb.const(val=val, name=node.name)
         context.add(const)
-    else:
-        # If at least one input to the construct op is non-const, collect
-        # the inputs and add them directly to the context. Ops that use this
-        # node's output will take the list directly as input.
-        context.add(array_type(inputs), node.name)
+        return
+
+    nodes = {n.name: n for n in context.torch_graph.nodes}
+    is_known_name = lambda name: name in nodes
+    # Memoize, per node, whether its subtree depends on a graph input:
+    # -1 = unvisited, 0 = pure constant, 1 = graph-input dependent.
+    inheriting_bookkeeping = {name: -1 for name in nodes}
+
+    def depends_on_graph_input(name):
+        '''
+        Return True if `name` transitively depends on a graph input, walking
+        producers all the way to the roots. A name that is not in
+        context.torch_graph.nodes has no producer node, so it must be a
+        symbolic value coming straight from a graph input.
+        '''
+        if not is_known_name(name):
+            return True
+        if inheriting_bookkeeping[name] == -1:
+            inheriting_bookkeeping[name] = int(
+                any(depends_on_graph_input(i) for i in nodes[name].inputs)
+            )
+        return inheriting_bookkeeping[name] == 1
+
+    dependent_on_graph_input = any(depends_on_graph_input(i) for i in node.inputs)
+
+    if dependent_on_graph_input:
+        to_concat = []
+        for input_name in node.inputs:
+            if not depends_on_graph_input(input_name):
+                # Constant element: wrap its value so it can be concatenated.
+                to_concat.append([context[input_name].val])
+            elif is_known_name(input_name):
+                # Symbolic element: walk producers back to the first node
+                # that consumes a graph input directly.
+                iter_node = nodes[input_name]
+                while iter_node.inputs and all(is_known_name(i) for i in iter_node.inputs):
+                    iter_node = nodes[iter_node.inputs[0]]
+
+                if context[iter_node.name].op.op_type == "gather":
+                    # tensor.shape[i] lowers to gather(shape(x), i); rebuild it
+                    # as a one-element slice of shape(x).
+                    non_const = iter_node.inputs[0]
+                    non_const_idx = context[iter_node.inputs[1]].val
+                    to_concat.append(
+                        mb.slice_by_size(
+                            x=mb.shape(x=context[non_const]),
+                            begin=[non_const_idx],
+                            size=[1],
+                        )
+                    )
+                else:
+                    # Unrecognized symbolic pattern: fall back to the plain
+                    # list construction below.
+                    to_concat = []
+                    break
+            else:
+                # The element is itself a graph input: no producer chain to
+                # pattern-match, so fall back to plain list construction.
+                to_concat = []
+                break
+
+        if len(to_concat) > 0:
+            context.add(mb.concat(values=to_concat, axis=0), node.name)
+            return
+
+    # Some input is neither a foldable const nor a recognized symbolic shape
+    # element, so collect the inputs and add them directly to the context.
+    # Ops that use this node's output will take the list directly as input.
+    context.add(array_type(inputs), node.name)
 
 
 @register_torch_op
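Note on the `_array_construct` hunk above: when a `prim::ListConstruct` mixes compile-time constants with values derived from a graph input (typically `tensor.shape[i]`), the list is now rebuilt as a MIL `concat` of const elements and one-element slices of `shape(x)`, instead of an opaque Python list. A minimal repro sketch, assuming this patch is applied; the model and the flexible-shape bounds are illustrative, not part of the patch:

    import torch
    import coremltools as ct

    class Reshaper(torch.nn.Module):
        def forward(self, x):
            # x.shape[0] typically traces to aten::size, which the frontend
            # lowers to gather(shape(x), 0), so the ListConstruct feeding
            # view() mixes a symbolic element with the constant -1.
            return x.view(x.shape[0], -1)

    traced = torch.jit.trace(Reshaper().eval(), torch.rand(2, 3, 4))
    mlmodel = ct.convert(
        traced,
        # A flexible batch dimension keeps x.shape[0] symbolic at convert time.
        inputs=[ct.TensorType(shape=(ct.RangeDim(1, 8), 3, 4))],
    )

With the patch, the target shape reaching `view` should be emitted as `concat([slice_by_size(shape(x), [0], [1]), [-1]])` rather than falling through to the plain-list branch.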
@@ -1595,6 +1658,13 @@ def pad(context, node):
         pad = pad.val.reshape((-1, 2))[::-1].reshape(-1).tolist()
         missing_dims = x.rank - (len(pad) // 2)
         pad = [0, 0] * missing_dims + pad
+    else:
+        missing_dims = (x.rank * 2 - pad.shape[0]) // 2
+        if missing_dims > 0:
+            pad = mb.concat(values=[pad, [0, 0] * missing_dims], axis=0)
+        pad = mb.reshape(x=pad, shape=[-1, 2])
+        pad = mb.reverse(x=pad, axes=[0])
+        pad = mb.reshape(x=pad, shape=[-1])  # mil.ops.defs.iOS15.pad expects a 1-D pad tensor
 
     if len(inputs) == 4:
         mode = inputs[2].val
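Note on the `pad` hunk: torch's pad spec lists (before, after) pairs starting from the last dimension, while `mil.ops.defs.iOS15.pad` expects pairs from dimension 0 onward, so the new dynamic branch appends zero pairs for the missing leading dims and then reverses the pair order, mirroring what the constant branch does with numpy. A plain-numpy sketch of the same arithmetic, with illustrative values:

    import numpy as np

    # torch pad spec for a rank-3 input, covering only the last two dims:
    # (left, right) for dim -1, then (top, bottom) for dim -2.
    x_rank = 3
    pad_torch = np.array([1, 2, 3, 4])
    missing_dims = (x_rank * 2 - pad_torch.shape[0]) // 2
    pad = np.concatenate([pad_torch, [0, 0] * missing_dims])  # zero-pad leading dims
    pad = pad.reshape(-1, 2)[::-1].reshape(-1)                # reverse per-dim pairs
    print(pad.tolist())  # [0, 0, 3, 4, 1, 2]: (before, after) per dim, dim 0 first

Appending the zeros before the reverse is equivalent to the constant branch's prepend-after-reverse, so both branches produce the same dim-0-first layout.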