This repository was archived by the owner on Oct 31, 2022. It is now read-only.

Commit ffc54c7

Author: nshepperd

Add tensor rematerialization.

1 parent 2de5d1b

21 files changed: +1515, -50 lines

.gitignore

Lines changed: 3 additions & 1 deletion
@@ -2,4 +2,6 @@ __pycache__
 .mypy_cache/
 models/
 checkpoint
-samples
+samples
+dist-newstyle
+bin

README.md

Lines changed: 53 additions & 4 deletions
@@ -15,19 +15,68 @@ PYTHONPATH=src ./encode.py <file|directory|glob> /path/to/encoded.npz

Make sure `cudnn` is installed. [Some have reported](https://github.com/nshepperd/gpt-2/issues/8) that `train.py` runs without it but has worse memory usage and might OOM.

### Tensor Rematerialization

Experimental: a rematerialization rewriter based on [Efficient Rematerialization for Deep Networks](https://papers.nips.cc/paper/9653-efficient-rematerialization-for-deep-networks.pdf), which, unlike gradient checkpointing, works in tensorflow 2.0 and is able to automatically select checkpoints in arbitrary graphs. Using this I was able to finetune GPT-2 1.5B on a single graphics card using slightly less than 12G of video ram with very little slowdown.

Using this is a little involved, because the graph optimization algorithm is offloaded to an optimized Haskell program. First, go into the subdirectory `twremat` and build it by invoking:

    cabal v2-install --installdir=../bin

(You'll need to install cabal if you haven't already -- but setting up ghc and haskell compilation is beyond the scope of this README.)

Then run `train.py` as normal, enabling `--twremat` and setting `--twremat_memlimit` to an appropriate value -- this sets the amount of memory assumed to be available for computation of gradients, so it should be roughly the memory size of your graphics card minus whatever is taken up by the gpt-2 weights and any other bookkeeping variables. You may need to experiment with the memlimit until you find the largest value that doesn't OOM.

(You probably also want to use SGD as the optimizer instead of Adam, to minimize those bookkeeping variables, of which Adam uses a lot.)
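
For illustration, a possible invocation on a 12G card (the `1558M` model directory name and the `5G` memlimit below are assumptions chosen for this example, not values fixed by this change): the 1.5B float32 weights alone take roughly 1.5e9 × 4 bytes ≈ 6G, leaving on the order of 5G for gradient computation.

    PYTHONPATH=src ./train.py --dataset /path/to/encoded.npz --model_name 1558M --optimizer sgd --twremat --twremat_memlimit 5G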

### Gradient Checkpointing

https://github.com/openai/gradient-checkpointing is included to reduce the memory requirements of the model, and can be enabled by `--memory_saving_gradients`. The checkpoints are currently chosen manually (poorly) by just adding layer 10 to the 'checkpoints' collection in model.py.

Gradient checkpointing doesn't work in tensorflow v2.0 and later due to the removal of tf.contrib. You should use tensor rematerialization instead if possible.

### Validation loss

Set `--val_every` to a number of steps `N > 0`, and "validation" loss against a fixed sample of the dataset will be calculated every N steps to get a better sense of training progress. N around 200 is suggested. You can set `--val_dataset` to choose a separate validation dataset; otherwise it defaults to a sample from the train dataset (so not a real cross-validation loss!).
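
For example, to compute validation loss every 200 steps against a held-out file (the validation path here is a placeholder):

    PYTHONPATH=src ./train.py --dataset /path/to/encoded.npz --val_every 200 --val_dataset /path/to/val.npz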

### Optimizer

You can use SGD instead of Adam with `--optimizer sgd`. This also helps conserve memory when training larger models. Note: the learning rate needs to be adjusted for SGD, due to not having Adam's gradient normalization (0.0006 seems to be a good number from some experiments).
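
For example, assuming `train.py` exposes the learning rate as `--learning_rate` (that flag is not shown in this diff):

    PYTHONPATH=src ./train.py --dataset /path/to/encoded.npz --optimizer sgd --learning_rate 0.0006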

# Original README

src/tfremat.py

Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@
import random
import os
import tensorflow.compat.v1 as tf
import tempfile

import twremat

def splice_op(op, input_map, control_inputs=None):
    g = op.graph
    node_def = tf.NodeDef()
    node_def.CopyFrom(op.node_def)
    node_def.name = g.unique_name(op.name + '_copy')
    inputs = [input_map.get(x, x) for x in op.inputs]
    new_control_inputs = [input_map.get(x, x) for x in op.control_inputs]
    if control_inputs:
        new_control_inputs.extend([x for x in control_inputs if x is not None])
    # new_control_inputs = control_inputs
    output_types = [o.dtype for o in op.outputs]
    op_def = op.op_def
    return tf.Operation(node_def, g, inputs=inputs, output_types=output_types, op_def=op_def, control_inputs=new_control_inputs)

def splice_tensor(ten, new_op):
    i = ten.op.outputs.index(ten)
    return new_op.outputs[i]

def splice(obj, input_map, control_inputs=None):
    if type(obj) is tf.Operation:
        return splice_op(obj, input_map, control_inputs=control_inputs)
    elif type(obj) is tf.Tensor:
        return splice_tensor(obj, input_map.get(obj.op, obj.op))
    elif type(obj) is tf.IndexedSlices:
        return tf.IndexedSlices(values=input_map.get(obj.values, obj.values),
                                indices=input_map.get(obj.indices, obj.indices),
                                dense_shape=input_map.get(obj.dense_shape, obj.dense_shape))
    else:
        raise AssertionError(f'Could not splice {repr(type(obj))} {repr(obj)}')

def product(xs):
    r = 1
    for x in xs:
        r *= x
    return r

def shape_size(shape):
    if shape.rank is None:
        return 16
    shape = shape.as_list()
    for i in range(len(shape)):
        if shape[i] is None and i == 0:
            shape[i] = 1
        elif shape[i] is None:
            shape[i] = 1024
    return product(shape)

def graph_from_dfs(deps, starts):
    visited = set()
    frontier = starts
    while frontier:
        x = frontier.pop()
        if x in visited:
            continue
        visited.add(x)
        frontier.extend(list(deps(x)))
    return {x : list(deps(x)) for x in visited}

def get_deps(obj):
    if type(obj) is tf.Operation:
        return list(obj.inputs) + list(obj.control_inputs)
    elif type(obj) is tf.Tensor:
        return [obj.op]
    elif type(obj) is tf.IndexedSlices:
        return [obj.indices, obj.values, obj.dense_shape]
    else:
        raise AssertionError(f'Could not get deps from {repr(type(obj))} {repr(obj)}')


def tensor_graph(compute):
    return graph_from_dfs(get_deps, list(compute))

def blacklist(obj):
    if type(obj) is tf.Operation:
        if 'Assign' in obj.type or 'Variable' in obj.type or 'Placeholder' in obj.type:
            # TODO: Should we do special accounting for
            # ReadVariableOp? Currently we forbid cloning altogether,
            # but it's actually ok to clone this op as long as it
            # doesn't float across an effectful op (Assign). Also
            # currently we don't account for the memory used by
            # ReadVariableOp (is it copy-on-write?).
            # https://www.tensorflow.org/api_docs/python/tf/raw_ops/ReadVariableOp?hl=uk
            return True
    elif type(obj) is tf.Tensor:
        return blacklist(obj.op)
    return False

def estimate_cpu(op):
    return sum(4 * shape_size(t.shape) for t in op.inputs if type(t) is tf.Tensor) + sum(4 * shape_size(t.shape) for t in op.outputs)

def estimate_mem(op):
    return sum(4 * shape_size(t.shape) for t in op.outputs)

def info(op):
    if blacklist(op):
        return {'type': 'effectful'}
    elif type(op) is tf.Operation:
        if 'Reshape' in op.type:
            return {'type': 'pointer'}
        return {'type': 'normal',
                'cpu': estimate_cpu(op),
                'mem': estimate_mem(op)}
    elif type(op) is tf.Tensor:
        return {'type': 'pointer'}
    elif type(op) is tf.IndexedSlices:
        return {'type': 'pointer'}
    else:
        raise AssertionError(repr((type(op), op)))


# Helper functions to flatten and unflatten nested structures of
# tensors and ops so that tf_remat can be applied to structures
# without fiddly marshalling.
def get_ops(compute):
    output = []
    stack = [compute]
    while stack:
        top = stack.pop()
        if type(top) is dict:
            for v in top.values():
                stack.append(v)
        elif type(top) in (list, tuple):
            stack.extend(top)
        elif type(top) in (tf.Operation, tf.Tensor, tf.IndexedSlices):
            output.append(top)
    return output

def replace_ops(top, live):
    if type(top) in (tf.Operation, tf.Tensor, tf.IndexedSlices):
        return live[top]
    elif type(top) is dict:
        return {k : replace_ops(v, live) for (k, v) in top.items()}
    elif type(top) is list:
        return [replace_ops(v, live) for v in top]
    elif type(top) is tuple:
        return tuple(replace_ops(v, live) for v in top)
    else:
        return top


def tf_remat(compute, memlimit):
    compute_ops = get_ops(compute)
    tf_deps = tensor_graph(compute_ops)

    # Relabel with integers
    from_op = {op : i for (i, op) in enumerate(tf_deps.keys())}
    from_node = {i : op for (op, i) in from_op.items()}
    nodes = set(from_node.keys())
    node_deps = {n : [from_op[d] for d in tf_deps[from_node[n]]] for n in nodes}

    node_info = {}
    for n in nodes:
        node_info[n] = info(from_node[n])
        node_info[n]['deps'] = [from_op[d] for d in tf_deps[from_node[n]]]

    steps = twremat.runtwremat(node_info, memlimit, {from_op[c] for c in compute_ops})

    print('Constructing tensorflow graph...')
    live = {}
    last_op = None
    for (action, n) in steps:
        base = from_node[n]
        if action == 'compute':
            input_map = {d : live[d] for d in tf_deps[base] if live[d] != d}
            if blacklist(base) and not input_map:
                live[base] = base
            else:
                live[base] = splice(base, input_map, control_inputs=[last_op])
                if type(base) is tf.Operation:
                    last_op = live[base]
        elif action == 'free':
            del live[base]

    return replace_ops(compute, live)
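
Not part of this commit: a minimal sketch of how `tf_remat` might be applied to a graph, assuming the `twremat` binary has already been built into `bin/` as the README describes. The toy matmul chain and the `1M` memlimit are illustrative only, not taken from this diff.

```python
import tensorflow.compat.v1 as tf
import tfremat

tf.disable_eager_execution()

# Toy graph: a chain of matmuls standing in for a deep network.
x = tf.random_normal([64, 64])
h = x
for _ in range(8):
    h = tf.tanh(tf.matmul(h, tf.random_normal([64, 64])))
loss = tf.reduce_sum(h)

# Rewrite the graph so that evaluating `loss` keeps intermediate
# tensors under the given budget, recomputing activations as needed.
(loss,) = tfremat.tf_remat((loss,), memlimit='1M')

with tf.Session() as sess:
    print(sess.run(loss))
```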

src/twremat.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
from subprocess import Popen, PIPE
import random
import os
import sys
import tempfile
from tqdm import tqdm

BINDIR = os.path.join(os.path.dirname(sys.argv[0]), 'bin')
TWREMAT = os.path.join(BINDIR, 'twremat')

# Allow users to pass 'humanized' memlimit values as strings.
def parse_memlimit(memlimit):
    if memlimit[-1] == 'K':
        return int(memlimit[:-1]) * 1000
    elif memlimit[-1] == 'M':
        return int(memlimit[:-1]) * 1000000
    elif memlimit[-1] == 'G':
        return int(memlimit[:-1]) * 1000000000
    else:
        return int(memlimit)

def runtwremat(gr, memlimit, target):
    if type(memlimit) is str:
        memlimit = parse_memlimit(memlimit)

    fname = tempfile.mktemp()
    outname = tempfile.mktemp()
    with open(fname, 'w') as fp:
        print('p remat2', file=fp)
        print(f'memlimit {memlimit}', file=fp)
        for (n, info) in gr.items():
            deps = ' '.join(str(d) for d in info['deps'])
            if info['type'] == 'normal':
                cpu = info['cpu']
                mem = info['mem']
                weight = f'cpu {cpu} mem {mem}'
            elif info['type'] == 'effectful':
                weight = 'effectful'
            elif info['type'] == 'pointer':
                weight = 'pointer'
            if n in target:
                tstr = 'target'
            else:
                tstr = ''
            print(f'node {n} deps {deps} {weight} {tstr}', file=fp)
    print(' '.join([TWREMAT, fname, outname]))
    proc = Popen([TWREMAT, fname, outname])
    assert proc.wait() == 0
    out = []
    with open(outname, 'r') as fp:
        for line in fp:
            line = line.split()
            if line and line[0] == 'c':
                out.append(('compute', int(line[1])))
            elif line and line[0] == 'f':
                out.append(('free', int(line[1])))
            elif line:
                print(line)
                exit()
    return out
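
Not part of this commit: a small illustration of the interface `runtwremat` expects and returns, using a made-up three-node graph in the same format that `tfremat.tf_remat` constructs. Node ids, weights, and the example schedule are illustrative, and the `twremat` binary must already be built into `bin/`.

```python
import twremat

# Node id -> info dict, in the same shape tfremat.info() produces.
gr = {
    0: {'type': 'normal', 'cpu': 4096, 'mem': 4096, 'deps': []},
    1: {'type': 'normal', 'cpu': 8192, 'mem': 4096, 'deps': [0]},
    2: {'type': 'pointer', 'deps': [1]},
}

# Ask for a schedule that materializes node 2 within a ~10KB budget.
steps = twremat.runtwremat(gr, '10K', target={2})

# `steps` is a list of ('compute', n) / ('free', n) actions, e.g.
# [('compute', 0), ('compute', 1), ('free', 0), ('compute', 2)].
for action, n in steps:
    print(action, n)
```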
