@@ -92,7 +92,7 @@ def __init__(self, dim, dropout = 0., mult = 4.):
             nn.Linear(dim * mult, dim)
         )

-    def forward(self, x):
+    def forward(self, x, cache = None, cache_key = None):
         return self.net(x)

 # token shift classes
@@ -104,7 +104,13 @@ def __init__(self, fn, image_size, seq_len):
         self.image_size = image_size
         self.seq_len = seq_len

-    def forward(self, x, **kwargs):
+    def forward(self, x, cache = None, cache_key = None, **kwargs):
+        n0 = x.shape[1]
+        if exists(cache):
+            if cache_key in cache:
+                x = torch.cat([cache[cache_key], x], dim = -2)
+            cache[cache_key] = x
+
         n = x.shape[1]
         seq_len, image_size = self.seq_len, self.image_size
         img_seq_len = image_size ** 2
@@ -134,7 +140,7 @@ def forward(self, x, **kwargs):

         x_img = rearrange(x_img, 'b h w d -> b (h w) d')
         x = torch.cat((x_text, x_img[:, :-padding]), dim = 1)
-        return self.fn(x, **kwargs)
+        return self.fn(x[:, -n0:], cache = cache, **kwargs)

 # main transformer class

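The caching added to PreShiftToken.forward above concatenates the cached prefix in front of the newly supplied tokens, recomputes the token shift over the full sequence, and then hands only the new suffix (x[:, -n0:]) to the wrapped layer. Below is a minimal, self-contained sketch of just that caching pattern, assuming a shared cache dict keyed by cache_key; CachedPrefix and the toy fn are illustrative names, not code from this repository.

    import torch
    from torch import nn

    # Simplified stand-in that mirrors only the caching pattern introduced above:
    # prepend the cached prefix, remember the full sequence, then call the wrapped
    # fn on the new tokens only. The real PreShiftToken additionally performs its
    # axial token shift on the full sequence before the final call.
    class CachedPrefix(nn.Module):
        def __init__(self, fn):
            super().__init__()
            self.fn = fn

        def forward(self, x, cache = None, cache_key = None, **kwargs):
            n0 = x.shape[1]                          # number of new tokens in this call
            if cache is not None:
                if cache_key in cache:
                    x = torch.cat([cache[cache_key], x], dim = -2)
                cache[cache_key] = x                 # full history for the next step
            # (the axial token shift over the full x would happen here)
            return self.fn(x[:, -n0:], cache = cache, **kwargs)

    cache = {}
    layer = CachedPrefix(lambda t, cache = None: t)  # toy fn standing in for attention
    tokens = torch.randn(1, 4, 8)
    a = layer(tokens[:, :3], cache = cache, cache_key = 'demo')  # prefill: 3 in, 3 out
    b = layer(tokens[:, 3:], cache = cache, cache_key = 'demo')  # decode: 1 in, 1 out
    assert cache['demo'].shape[1] == 4                           # full sequence cached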
@@ -221,7 +227,8 @@ def __init__(
             attn = CachedAs(f'attn_{ind}', attn)

             if shift_tokens:
-                attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff))
+                attn = CachedAs(f'preshift_attn_{ind}', PreShiftToken(attn, image_size = image_fmap_size, seq_len = seq_len))
+                ff = CachedAs(f'preshift_ff_{ind}', PreShiftToken(ff, image_size = image_fmap_size, seq_len = seq_len))

             layers.append(nn.ModuleList([
                 LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich = sandwich_norm)),
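CachedAs itself is not shown in this diff. Judging purely from the call sites above (CachedAs(f'preshift_attn_{ind}', PreShiftToken(...))), it presumably binds a fixed cache_key to the wrapped module so that each layer reads and writes its own slot in the shared cache. A rough sketch under that assumption:

    from torch import nn

    # Assumption: the real CachedAs is defined elsewhere in transformer.py; this
    # sketch only reflects how it is invoked in the hunk above.
    class CachedAs(nn.Module):
        def __init__(self, cache_key, fn):
            super().__init__()
            self.cache_key = cache_key
            self.fn = fn

        def forward(self, x, *, cache = None, **kwargs):
            return self.fn(x, cache = cache, cache_key = self.cache_key, **kwargs)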
@@ -230,8 +237,9 @@ def __init__(

         execute_type = ReversibleSequence if reversible else SequentialSequence
         route_attn = ((True, False),) * depth
+        route_all = ((True, True),) * depth
         attn_route_map = {'mask': route_attn, 'rotary_pos_emb': route_attn,
-                          'cache': route_attn}
+                          'cache': route_all}

         self.layers = execute_type(layers, args_route = attn_route_map)

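route_attn = ((True, False),) * depth routes a keyword argument only to the attention half of each (attention, feed-forward) pair, so under the old mapping the cache never reached the feed-forward branch; now that ff is also wrapped in a cached PreShiftToken, 'cache' is routed with route_all = ((True, True),) * depth. A simplified sketch of this routing idea (not the actual SequentialSequence implementation):

    depth = 2
    route_attn = ((True, False),) * depth   # e.g. 'mask' is only needed by attention
    route_all  = ((True, True),) * depth    # 'cache' must now also reach feed-forward
    args_route = {'mask': route_attn, 'cache': route_all}

    def kwargs_for(layer_index, branch, **kwargs):
        # branch 0 = attention, branch 1 = feed-forward; unrouted kwargs are dropped
        return {k: v for k, v in kwargs.items()
                if k in args_route and args_route[k][layer_index][branch]}

    # The feed-forward branch of layer 0 now receives the cache but not the mask:
    print(kwargs_for(0, 1, mask = 'causal', cache = {}))   # -> {'cache': {}}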
@@ -270,9 +278,9 @@ def forward(self, x, **kwargs):

     def _get_static_mask(self, attn_type):
         img_seq_len = self.image_fmap_size ** 2
-        text_len = self.seq_len - img_seq_len
+        text_len = self.seq_len + 1 - img_seq_len

-        static_mask = torch.ones(self.seq_len, self.seq_len, dtype = torch.bool)
+        static_mask = torch.zeros(self.seq_len, self.seq_len, dtype = torch.bool)
         static_mask[:, :text_len] = True
         if attn_type == 'axial_row':
             for row in range(self.image_fmap_size):
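In this last hunk the static mask now starts from all False rather than all True (on an all-True mask the visible True assignments were no-ops, so nothing was ever restricted), and text_len gains a +1 so one extra leading position is counted as text. A tiny worked example with made-up sizes, showing only the part visible in this hunk:

    import torch

    # Made-up sizes for illustration: a 2x2 image feature map and seq_len = 7.
    image_fmap_size = 2
    img_seq_len = image_fmap_size ** 2        # 4 image positions
    seq_len = 7
    text_len = seq_len + 1 - img_seq_len      # 4 leading positions treated as text

    static_mask = torch.zeros(seq_len, seq_len, dtype = torch.bool)  # start all False
    static_mask[:, :text_len] = True          # every position may attend to the text
    # The row loop continuing beyond this hunk then marks which image positions each
    # axial row (or column) is allowed to attend to.
    print(static_mask.int())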