+from collections import deque
 from collections.abc import Iterable
 from functools import partial
-from itertools import islice, cycle, product
+from itertools import islice, cycle

 import torch
 from torch import nn, einsum
@@ -103,18 +104,30 @@ def __init__(self, fn, image_size, seq_len):
         self.fn = fn
         self.image_size = image_size
         self.seq_len = seq_len
+        self.img_seq_len = image_size ** 2
+        self.text_len = seq_len - self.img_seq_len + 1

     def forward(self, x, cache = None, cache_key = None, **kwargs):
-        n0 = x.shape[1]
-        if exists(cache):
-            if cache_key in cache:
-                x = torch.cat([cache[cache_key], x], dim = -2)
-            cache[cache_key] = x
+        seq_len, image_size, text_len = self.seq_len, self.image_size, self.text_len
+
+        if exists(cache) and cache_key in cache:
+            offset = cache['offset']
+            assert offset >= text_len, "cached inference for text is not supported"
+            q = cache[cache_key]
+            assert isinstance(q, deque) and len(q) == image_size
+
+            x_top, x_left, *x_pass = x[:, -1].chunk(4, dim = -1)
+
+            q.append((x_top, x_left))
+            x_top = q.popleft()[0]
+            x_left = q[-2][1]
+            if (offset - text_len) % image_size == 0:
+                x_left = torch.zeros_like(x_left)
+
+            x = torch.cat((x_top, x_left, *x_pass), dim = -1)
+            return self.fn(x[:, None], cache = cache, **kwargs)

         n = x.shape[1]
-        seq_len, image_size = self.seq_len, self.image_size
-        img_seq_len = image_size ** 2
-        text_len = seq_len - img_seq_len + 1
         padding = seq_len - n + 1

         # get text and image tokens
@@ -139,8 +152,22 @@ def forward(self, x, cache=None, cache_key=None, **kwargs):
         # merge text and image sequence back together

         x_img = rearrange(x_img, 'b h w d -> b (h w) d')
-        x = torch.cat((x_text, x_img[:, :-padding]), dim = 1)
-        return self.fn(x[:, -n0:], cache = cache, **kwargs)
+        x_img = x_img[:, :-padding]
+        x = torch.cat((x_text, x_img), dim = 1)
+
+        if exists(cache):
+            dummy_top, dummy_left, *_ = x[:, -1].chunk(4, dim = -1)
+            dummy_top, dummy_left = torch.zeros_like(dummy_top), torch.zeros_like(dummy_left)
+
+            q = deque()
+            x_img = x_img[:, -image_size:]
+            for _ in range(image_size - x_img.shape[1]):
+                q.append((dummy_top, dummy_left))
+            for i in range(x_img.shape[1]):
+                q.append(x_img[:, i].chunk(4, dim = -1)[:2])
+            cache[cache_key] = q
+
+        return self.fn(x, cache = cache, **kwargs)

 # main transformer class

@@ -277,6 +304,11 @@ def forward(self, x, **kwargs):
         return self.layers(x, rotary_pos_emb = self.pos_emb, **kwargs)

     def _get_static_mask(self, attn_type):
+        # In the case of attn_type = "axial_{row,col}", the sparse implementation
+        # is the most efficient for training, but full attention with a static
+        # mask is the most efficient for inference, since caching is only
+        # implemented for full attention.
+
         img_seq_len = self.image_fmap_size ** 2
         text_len = self.seq_len + 1 - img_seq_len

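Note on the cached branch in the first hunk: during incremental decoding, the deque holds the (top, left) feature chunks for the last `image_size` grid positions. After appending the current position, `q.popleft()[0]` is the chunk from the token one full row above, and `q[-2][1]` is the chunk from the token immediately to the left. A minimal, self-contained sketch of this rolling window, using a toy 3x3 grid and string placeholders instead of tensors (names here are illustrative, not from the commit):

from collections import deque

image_size = 3
q = deque()

# prime the deque with one dummy entry per column, mirroring the
# (dummy_top, dummy_left) priming done on the non-cached path
for _ in range(image_size):
    q.append((None, None))

for pos in range(image_size ** 2):       # decode the grid row-major
    q.append((f'top@{pos}', f'left@{pos}'))
    top = q.popleft()[0]                 # token one row above: pos - image_size
    left = q[-2][1]                      # token to the left: pos - 1
    if pos % image_size == 0:            # first column has no left neighbor,
        left = None                      # analogous to zeroing x_left
    print(pos, top, left)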
0 commit comments
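The diff shows only the first lines of _get_static_mask. For context, one plausible construction of such a static mask (an assumption sketched here, not the commit's actual implementation; names are illustrative) lets every position attend to the text tokens and lets image tokens additionally attend within their own grid row or column, with the causal triangle applied separately:

import torch

def axial_static_mask(attn_type, image_fmap_size, seq_len):
    img_seq_len = image_fmap_size ** 2
    text_len = seq_len + 1 - img_seq_len           # as in _get_static_mask

    mask = torch.zeros(seq_len, seq_len, dtype = torch.bool)
    mask[:, :text_len] = True                      # all positions attend to text

    # grid coordinates per position (negative idx marks text positions)
    idx = torch.arange(seq_len) - text_len
    row, col = idx // image_fmap_size, idx % image_fmap_size
    coord = row if attn_type == 'axial_row' else col

    same_line = coord[:, None] == coord[None, :]
    both_image = (idx[:, None] >= 0) & (idx[None, :] >= 0)
    mask |= same_line & both_image                 # attend within the row/col
    return mask                                    # causal mask applied downstream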