
Commit 94fda36

Committed Dec 21, 2021

1 file changed: +43 −11 lines


dalle_pytorch/transformer.py

+43 −11
@@ -1,6 +1,7 @@
+from collections import deque
 from collections.abc import Iterable
 from functools import partial
-from itertools import islice, cycle, product
+from itertools import islice, cycle
 
 import torch
 from torch import nn, einsum
@@ -103,18 +104,30 @@ def __init__(self, fn, image_size, seq_len):
         self.fn = fn
         self.image_size = image_size
         self.seq_len = seq_len
+        self.img_seq_len = image_size ** 2
+        self.text_len = seq_len - self.img_seq_len + 1
 
     def forward(self, x, cache=None, cache_key=None, **kwargs):
-        n0 = x.shape[1]
-        if exists(cache):
-            if cache_key in cache:
-                x = torch.cat([cache[cache_key], x], dim=-2)
-            cache[cache_key] = x
+        seq_len, image_size, text_len = self.seq_len, self.image_size, self.text_len
+
+        if exists(cache) and cache_key in cache:
+            offset = cache['offset']
+            assert offset >= text_len, "cached inference for text is not supported"
+            q = cache[cache_key]
+            assert isinstance(q, deque) and len(q) == image_size
+
+            x_top, x_left, *x_pass = x[:, -1].chunk(4, dim=-1)
+
+            q.append((x_top, x_left))
+            x_top = q.popleft()[0]
+            x_left = q[-2][1]
+            if (offset - text_len) % image_size == 0:
+                x_left = torch.zeros_like(x_left)
+
+            x = torch.cat((x_top, x_left, *x_pass), dim=-1)
+            return self.fn(x[:, None], cache=cache, **kwargs)
 
         n = x.shape[1]
-        seq_len, image_size = self.seq_len, self.image_size
-        img_seq_len = image_size ** 2
-        text_len = seq_len - img_seq_len + 1
         padding = seq_len - n + 1
 
         # get text and image tokens
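In the cached branch above, the deque holds one (top, left) feature pair per column of the most recently processed image row: appending the newest pair and popping the oldest yields the features of the position directly above, q[-2] yields the position immediately to the left, and the left neighbour is zeroed at the start of each row. A minimal standalone sketch of that rotation, under assumed toy values for image_size, text_len and the feature dimension (these numbers are illustrative, not from the repository):

# hypothetical standalone sketch; sizes and names are illustrative assumptions
from collections import deque

import torch

image_size = 4   # width of the image token grid (assumption)
text_len = 2     # number of text positions preceding the image tokens (assumption)
dim = 8          # per-position feature dim, chunked into 4 parts as in the diff

def cached_step(x_last, q, offset):
    # x_last: features of the newest position, shape (batch, dim)
    x_top, x_left, *x_pass = x_last.chunk(4, dim=-1)
    q.append((x_top, x_left))
    # the pair appended image_size steps ago sits directly above the current position
    x_top = q.popleft()[0]
    # the second-to-last pair belongs to the position immediately to the left
    x_left = q[-2][1]
    if (offset - text_len) % image_size == 0:
        # the first column of a row has no left neighbour
        x_left = torch.zeros_like(x_left)
    return torch.cat((x_top, x_left, *x_pass), dim=-1)

# prime the deque with one full row of zero pairs, then take one cached step
q = deque(torch.zeros(1, dim).chunk(4, dim=-1)[:2] for _ in range(image_size))
out = cached_step(torch.randn(1, dim), q, offset=text_len + image_size)
print(out.shape)   # torch.Size([1, 8])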
@@ -139,8 +152,22 @@ def forward(self, x, cache=None, cache_key=None, **kwargs):
         # merge text and image sequence back together
 
         x_img = rearrange(x_img, 'b h w d -> b (h w) d')
-        x = torch.cat((x_text, x_img[:, :-padding]), dim = 1)
-        return self.fn(x[:, -n0:], cache=cache, **kwargs)
+        x_img = x_img[:, :-padding]
+        x = torch.cat((x_text, x_img), dim = 1)
+
+        if exists(cache):
+            dummy_top, dummy_left, *_ = x[:, -1].chunk(4, dim=-1)
+            dummy_top, dummy_left = torch.zeros_like(dummy_top), torch.zeros_like(dummy_left)
+
+            q = deque()
+            x_img = x_img[:, -image_size:]
+            for _ in range(image_size - x_img.shape[1]):
+                q.append((dummy_top, dummy_left))
+            for i in range(x_img.shape[1]):
+                q.append(x_img[:, i].chunk(4, dim=-1)[:2])
+            cache[cache_key] = q
+
+        return self.fn(x, cache=cache, **kwargs)
 
 # main transformer class
 
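On the non-cached pass, the hunk above primes the deque from the last image_size image positions, left-padding with zeroed dummy pairs when fewer positions have been produced so far, so that the cached branch can start rotating immediately on the next call. A hedged sketch of that priming step, with assumed toy shapes:

# hypothetical sketch of priming the row cache; tensor shapes are assumptions
from collections import deque

import torch

image_size = 4
batch, dim = 1, 8
x_img = torch.randn(batch, 3, dim)   # only three image positions produced so far

dummy_top, dummy_left, *_ = x_img[:, -1].chunk(4, dim=-1)
dummy_top, dummy_left = torch.zeros_like(dummy_top), torch.zeros_like(dummy_left)

q = deque()
last_row = x_img[:, -image_size:]
# left-pad with zero pairs so the deque always holds exactly image_size entries
for _ in range(image_size - last_row.shape[1]):
    q.append((dummy_top, dummy_left))
for i in range(last_row.shape[1]):
    q.append(last_row[:, i].chunk(4, dim=-1)[:2])

assert len(q) == image_size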
@@ -277,6 +304,11 @@ def forward(self, x, **kwargs):
         return self.layers(x, rotary_pos_emb = self.pos_emb, **kwargs)
 
     def _get_static_mask(self, attn_type):
+        # In case of attn_type = "axial_{row,col}",
+        # the sparse implementation is most efficient for training,
+        # but the full attention with a static mask is most efficient for inference
+        # since caching is implemented in this case.
+
         img_seq_len = self.image_fmap_size ** 2
         text_len = self.seq_len + 1 - img_seq_len
 
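The comment above argues for full attention combined with a static mask at inference time, since that is the path where caching works. The following is only an illustrative sketch of what such a row mask could look like, not the repository's _get_static_mask; text_len and the grid size are assumed toy values, and the result is combined with an ordinary causal mask:

# illustrative sketch, not the repository's _get_static_mask
import torch

image_fmap_size = 4
img_seq_len = image_fmap_size ** 2
text_len = 3                         # assumed number of text positions
seq_len = text_len + img_seq_len

mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
mask[:, :text_len] = True            # every position may attend to the text tokens
for i in range(img_seq_len):
    row_start = text_len + (i // image_fmap_size) * image_fmap_size
    # each image position may attend to the positions in its own row
    mask[text_len + i, row_start:row_start + image_fmap_size] = True

# combine with the usual causal mask before handing it to full attention
causal = torch.ones(seq_len, seq_len).tril().bool()
static_row_mask = mask & causal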