
Commit 1ff47c6

Make the cached version work
1 parent 4c833a2 commit 1ff47c6

File tree: 2 files changed, +17 -18 lines

  dalle_pytorch/attention.py
  dalle_pytorch/transformer.py

dalle_pytorch/attention.py (+2 -11)
@@ -6,6 +6,8 @@
 import torch.nn.functional as F
 from einops import rearrange, repeat
 
+from rotary_embedding_torch import apply_rotary_emb
+
 # helpers
 
 def exists(val):
@@ -27,17 +29,6 @@ def stable_softmax(t, dim = -1, alpha = 32 ** 2):
     t = t - torch.amax(t, dim = dim, keepdim = True)
     return (t * alpha).softmax(dim = dim)
 
-def rotate_half(x):
-    d = x.shape[-1] // 2
-    return torch.cat([-x[..., d:], x[..., :d]], dim=-1)
-
-def apply_rotary_emb(freqs, t):
-    rot_dim = freqs.shape[-1]
-    assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
-    t, t_right = t[..., :rot_dim], t[..., rot_dim:]
-    t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
-    return torch.cat((t, t_right), dim = -1)
-
 def apply_pos_emb(pos_emb, qkv):
     n = qkv[0].shape[-2]
     pos_emb = pos_emb[..., :n, :]
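
The local rotate_half / apply_rotary_emb helpers are dropped in favour of the apply_rotary_emb exported by the rotary_embedding_torch package. A minimal usage sketch, assuming the library function keeps the same (freqs, t) calling convention as the removed helper, i.e. it rotates only the first rot_dim feature channels and passes the rest through; the tensor names and shapes below are illustrative, not from the repo:

import torch
from rotary_embedding_torch import apply_rotary_emb

# toy shapes for illustration only: batch 1, 2 heads, 8 positions, head dim 64,
# with rotary frequencies covering the first 32 feature channels
q     = torch.randn(1, 2, 8, 64)
freqs = torch.randn(8, 32)           # per-position rotation angles, shape (seq_len, rot_dim)

q_rot = apply_rotary_emb(freqs, q)   # rotates q[..., :32]; q[..., 32:] is passed through
assert q_rot.shape == q.shape
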

dalle_pytorch/transformer.py (+15 -7)
@@ -92,7 +92,7 @@ def __init__(self, dim, dropout = 0., mult = 4.):
             nn.Linear(dim * mult, dim)
         )
 
-    def forward(self, x):
+    def forward(self, x, cache=None, cache_key=None):
         return self.net(x)
 
 # token shift classes
@@ -104,7 +104,13 @@ def __init__(self, fn, image_size, seq_len):
         self.image_size = image_size
         self.seq_len = seq_len
 
-    def forward(self, x, **kwargs):
+    def forward(self, x, cache=None, cache_key=None, **kwargs):
+        n0 = x.shape[1]
+        if exists(cache):
+            if cache_key in cache:
+                x = torch.cat([cache[cache_key], x], dim=-2)
+            cache[cache_key] = x
+
         n = x.shape[1]
         seq_len, image_size = self.seq_len, self.image_size
         img_seq_len = image_size ** 2
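
The block added to PreShiftToken.forward above, together with the x[:, -n0:] slice at the return in the next hunk, implements a simple incremental-decoding cache: each call prepends the previously seen tokens so the token-shift logic has the full prefix to shift from, then only the freshly supplied positions are handed on to the wrapped attention/feedforward. A standalone sketch of that pattern (the cached_shift_forward name and the identity stand-in are hypothetical, not the repo's class itself):

import torch

def cached_shift_forward(fn, x, cache=None, cache_key=None):
    # number of tokens supplied on this call (typically 1 during generation)
    n0 = x.shape[1]
    if cache is not None:
        if cache_key in cache:
            # prepend everything seen so far along the sequence dimension
            x = torch.cat([cache[cache_key], x], dim=-2)
        cache[cache_key] = x            # remember the grown sequence for the next step
    # ... token-shift work on the full sequence would happen here ...
    return fn(x[:, -n0:])               # forward only the new positions

cache = {}
fn = lambda t: t                        # identity stand-in for the wrapped attn/ff
out1 = cached_shift_forward(fn, torch.randn(1, 1, 8), cache, 'preshift_attn_0')
out2 = cached_shift_forward(fn, torch.randn(1, 1, 8), cache, 'preshift_attn_0')
assert cache['preshift_attn_0'].shape[1] == 2   # the cache grew by one position
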
@@ -134,7 +140,7 @@ def forward(self, x, **kwargs):
 
         x_img = rearrange(x_img, 'b h w d -> b (h w) d')
         x = torch.cat((x_text, x_img[:, :-padding]), dim = 1)
-        return self.fn(x, **kwargs)
+        return self.fn(x[:, -n0:], cache=cache, **kwargs)
 
 # main transformer class
 
@@ -221,7 +227,8 @@ def __init__(
                 attn = CachedAs(f'attn_{ind}', attn)
 
             if shift_tokens:
-                attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff))
+                attn = CachedAs(f'preshift_attn_{ind}', PreShiftToken(attn, image_size = image_fmap_size, seq_len = seq_len))
+                ff = CachedAs(f'preshift_ff_{ind}', PreShiftToken(ff, image_size = image_fmap_size, seq_len = seq_len))
 
             layers.append(nn.ModuleList([
                 LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich = sandwich_norm)),
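
CachedAs itself is defined elsewhere in this fork and its implementation is not part of this diff. Judging only from how it is called above and from the cache / cache_key parameters consumed by PreShiftToken.forward and FeedForward.forward, a plausible minimal form would be a thin wrapper that injects its name as the cache key. This is a hedged sketch, not the repo's actual code:

import torch.nn as nn

class CachedAs(nn.Module):
    """Hypothetical reconstruction: tags a wrapped module with a fixed cache slot name."""
    def __init__(self, cache_key, fn):
        super().__init__()
        self.cache_key = cache_key      # e.g. 'preshift_attn_3' or 'preshift_ff_3'
        self.fn = fn

    def forward(self, x, *, cache=None, **kwargs):
        # pass the per-wrapper key down so the inner module knows which
        # entry of the shared cache dict belongs to it
        return self.fn(x, cache=cache, cache_key=self.cache_key, **kwargs)
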
@@ -230,8 +237,9 @@ def __init__(
 
         execute_type = ReversibleSequence if reversible else SequentialSequence
         route_attn = ((True, False),) * depth
+        route_all = ((True, True),) * depth
         attn_route_map = {'mask': route_attn, 'rotary_pos_emb': route_attn,
-                          'cache': route_attn}
+                          'cache': route_all}
 
         self.layers = execute_type(layers, args_route = attn_route_map)
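
The new route_all tuple widens the kwarg routing: previously the cache dict was delivered only to the attention half of each layer pair, but now that the feedforward blocks are also wrapped in CachedAs/PreShiftToken they need it as well. A hedged, plain-Python illustration of what the (attn, ff) boolean pairs mean (not the repo's routing code):

depth = 3
route_attn = ((True, False),) * depth    # kwarg goes to attn only
route_all  = ((True, True),)  * depth    # kwarg goes to both attn and ff
args_route = {'mask': route_attn, 'rotary_pos_emb': route_attn, 'cache': route_all}

kwargs = {'mask': 'MASK', 'rotary_pos_emb': 'ROT', 'cache': {}}
for ind in range(depth):
    attn_kwargs = {k: v for k, v in kwargs.items() if args_route[k][ind][0]}
    ff_kwargs   = {k: v for k, v in kwargs.items() if args_route[k][ind][1]}
    # attn_kwargs has mask, rotary_pos_emb and cache; ff_kwargs has only cache
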

@@ -270,9 +278,9 @@ def forward(self, x, **kwargs):
 
     def _get_static_mask(self, attn_type):
         img_seq_len = self.image_fmap_size ** 2
-        text_len = self.seq_len - img_seq_len
+        text_len = self.seq_len + 1 - img_seq_len
 
-        static_mask = torch.ones(self.seq_len, self.seq_len, dtype=torch.bool)
+        static_mask = torch.zeros(self.seq_len, self.seq_len, dtype=torch.bool)
         static_mask[:, :text_len] = True
         if attn_type == 'axial_row':
             for row in range(self.image_fmap_size):
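
A tiny numeric check of the corrected mask construction, with assumed toy sizes rather than values from the repo: the always-visible text block is now one column wider, and the base mask starts all-False instead of all-True, with the axial_row / axial_col branches presumably switching the relevant image-to-image positions back on afterwards.

import torch

image_fmap_size, seq_len = 2, 7          # assumed toy sizes for illustration
img_seq_len = image_fmap_size ** 2       # 4
text_len = seq_len + 1 - img_seq_len     # 4 (the old formula gave 3)

static_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
static_mask[:, :text_len] = True         # text positions are always visible
assert static_mask.sum().item() == seq_len * text_len
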
