import tensorflow.compat.v1 as tf

import twremat

def splice_op(op, input_map, control_inputs=None):
    """Clone `op` in the same graph, remapping its inputs and control inputs
    through `input_map` and optionally appending extra control inputs."""
    g = op.graph
    node_def = tf.NodeDef()
    node_def.CopyFrom(op.node_def)
    node_def.name = g.unique_name(op.name + '_copy')
    inputs = [input_map.get(x, x) for x in op.inputs]
    new_control_inputs = [input_map.get(x, x) for x in op.control_inputs]
    if control_inputs:
        new_control_inputs.extend(x for x in control_inputs if x is not None)
    output_types = [o.dtype for o in op.outputs]
    return tf.Operation(node_def, g, inputs=inputs, output_types=output_types,
                        op_def=op.op_def, control_inputs=new_control_inputs)

def splice_tensor(ten, new_op):
    """Return the output of `new_op` that corresponds to tensor `ten` of the
    original op."""
    return new_op.outputs[ten.value_index]

def splice(obj, input_map, control_inputs=None):
    if type(obj) is tf.Operation:
        return splice_op(obj, input_map, control_inputs=control_inputs)
    elif type(obj) is tf.Tensor:
        return splice_tensor(obj, input_map.get(obj.op, obj.op))
    elif type(obj) is tf.IndexedSlices:
        return tf.IndexedSlices(values=input_map.get(obj.values, obj.values),
                                indices=input_map.get(obj.indices, obj.indices),
                                dense_shape=input_map.get(obj.dense_shape, obj.dense_shape))
    else:
        raise AssertionError(f'Could not splice {type(obj)!r} {obj!r}')

def product(xs):
    r = 1
    for x in xs:
        r *= x
    return r

def shape_size(shape):
    """Estimate the element count of a possibly partially-unknown TensorShape."""
    if shape.rank is None:
        return 16  # arbitrary small guess when even the rank is unknown
    shape = shape.as_list()
    for i in range(len(shape)):
        if shape[i] is None and i == 0:
            shape[i] = 1     # treat an unknown leading (batch) dimension as 1
        elif shape[i] is None:
            shape[i] = 1024  # pessimistic guess for other unknown dimensions
    return product(shape)

def graph_from_dfs(deps, starts):
    """Depth-first search from `starts`, returning the dependency closure as
    an adjacency dict {node: [dependencies]}."""
    visited = set()
    frontier = list(starts)  # copy so we don't mutate the caller's list
    while frontier:
        x = frontier.pop()
        if x in visited:
            continue
        visited.add(x)
        frontier.extend(deps(x))
    return {x: list(deps(x)) for x in visited}
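
# For intuition, on a toy graph (illustrative only, not from the original):
#   graph_from_dfs(lambda n: {1: [2], 2: [3], 3: []}[n], [1])
# returns {1: [2], 2: [3], 3: []}, the closure of everything reachable from 1.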

def get_deps(obj):
    if type(obj) is tf.Operation:
        return list(obj.inputs) + list(obj.control_inputs)
    elif type(obj) is tf.Tensor:
        return [obj.op]
    elif type(obj) is tf.IndexedSlices:
        # dense_shape may be None; don't let None leak into the graph.
        return [x for x in (obj.indices, obj.values, obj.dense_shape) if x is not None]
    else:
        raise AssertionError(f'Could not get deps from {type(obj)!r} {obj!r}')

def tensor_graph(compute):
    """Dependency graph of every op and tensor reachable from `compute`."""
    return graph_from_dfs(get_deps, list(compute))

def blacklist(obj):
    """Nodes that must never be cloned (stateful or externally fed)."""
    if type(obj) is tf.Operation:
        if 'Assign' in obj.type or 'Variable' in obj.type or 'Placeholder' in obj.type:
            # TODO: Should we do special accounting for ReadVariableOp?
            # Currently we forbid cloning altogether, but it's actually ok to
            # clone this op as long as it doesn't float across an effectful
            # op (Assign). Also, we currently don't account for the memory
            # used by ReadVariableOp (is it copy-on-write?).
            # https://www.tensorflow.org/api_docs/python/tf/raw_ops/ReadVariableOp?hl=uk
            return True
    elif type(obj) is tf.Tensor:
        return blacklist(obj.op)
    return False

def estimate_cpu(op):
    # Crude compute-cost proxy: bytes moved through inputs plus outputs,
    # assuming 4-byte elements.
    return (sum(4 * shape_size(t.shape) for t in op.inputs if type(t) is tf.Tensor)
            + sum(4 * shape_size(t.shape) for t in op.outputs))

def estimate_mem(op):
    # Resident size of an op's results: output bytes, assuming 4-byte elements.
    return sum(4 * shape_size(t.shape) for t in op.outputs)

def info(op):
    """Describe a node for twremat: 'effectful' nodes are pinned in place,
    'pointer' nodes are treated as cheap aliases, and 'normal' nodes carry
    compute and memory estimates."""
    if blacklist(op):
        return {'type': 'effectful'}
    elif type(op) is tf.Operation:
        if 'Reshape' in op.type:
            return {'type': 'pointer'}
        return {'type': 'normal',
                'cpu': estimate_cpu(op),
                'mem': estimate_mem(op)}
    elif type(op) is tf.Tensor:
        return {'type': 'pointer'}
    elif type(op) is tf.IndexedSlices:
        return {'type': 'pointer'}
    else:
        raise AssertionError(repr((type(op), op)))
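
# Together with the 'deps' list attached in tf_remat below, these records are
# the node descriptions handed to twremat, e.g. (values illustrative):
#   {'type': 'normal', 'cpu': 16777216, 'mem': 4194304, 'deps': [3, 7]}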

# Helper functions to flatten and unflatten nested structures of tensors and
# ops so that tf_remat can be applied to structures without fiddly
# marshalling.
def get_ops(compute):
    output = []
    stack = [compute]
    while stack:
        top = stack.pop()
        if type(top) is dict:
            stack.extend(top.values())
        elif type(top) in (list, tuple):
            stack.extend(top)
        elif type(top) in (tf.Operation, tf.Tensor, tf.IndexedSlices):
            output.append(top)
    return output

def replace_ops(top, live):
    if type(top) in (tf.Operation, tf.Tensor, tf.IndexedSlices):
        return live[top]
    elif type(top) is dict:
        return {k: replace_ops(v, live) for (k, v) in top.items()}
    elif type(top) is list:
        return [replace_ops(v, live) for v in top]
    elif type(top) is tuple:
        return tuple(replace_ops(v, live) for v in top)
    else:
        return top
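
# Note: replace_ops inverts get_ops, in the sense that
#   replace_ops(compute, {op: op for op in get_ops(compute)})
# rebuilds `compute` unchanged for any nesting of dicts, lists, and tuples.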

def tf_remat(compute, memlimit):
    """Rewrite the graph under `compute` (a nested structure of ops and
    tensors) so peak memory stays under `memlimit` bytes, recomputing
    intermediates according to the schedule produced by twremat."""
    compute_ops = get_ops(compute)
    tf_deps = tensor_graph(compute_ops)

    # Relabel nodes with integers for the solver.
    from_op = {op: i for (i, op) in enumerate(tf_deps.keys())}
    from_node = {i: op for (op, i) in from_op.items()}

    node_info = {}
    for (op, n) in from_op.items():
        node_info[n] = info(op)
        node_info[n]['deps'] = [from_op[d] for d in tf_deps[op]]

    steps = twremat.runtwremat(node_info, memlimit, {from_op[c] for c in compute_ops})

    print('Constructing tensorflow graph...')
    live = {}
    last_op = None
    for (action, n) in steps:
        base = from_node[n]
        if action == 'compute':
            # Identity comparison: only remap inputs that were actually cloned.
            input_map = {d: live[d] for d in tf_deps[base] if live[d] is not d}
            if blacklist(base) and not input_map:
                # Pinned node whose inputs are unchanged: reuse the original.
                live[base] = base
            else:
                # A control dependency on the previous op serializes the
                # schedule so recomputations happen at the intended time.
                live[base] = splice(base, input_map, control_inputs=[last_op])
            if type(base) is tf.Operation:
                last_op = live[base]
        elif action == 'free':
            del live[base]

    return replace_ops(compute, live)
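
# A minimal smoke test (a sketch, not part of the original file: the layer
# sizes, depth, and 1 GiB memlimit are illustrative, and it assumes the
# twremat solver is available):
if __name__ == '__main__':
    tf.disable_v2_behavior()  # graph mode with id-based tensor hashing
    x = tf.ones([32, 1024])
    w = tf.fill([1024, 1024], 0.01)
    h = x
    for _ in range(8):
        h = tf.tanh(tf.matmul(h, w))
    loss = tf.reduce_sum(h)
    [loss_remat] = tf_remat([loss], memlimit=2**30)
    with tf.Session() as sess:
        print(sess.run(loss_remat))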