@@ -123,14 +123,12 @@ def __init__(
         head_dim: int,
         n_rep: int,
         max_context_len: int,
-        enable_dynamic_shape: bool,
     ):
         super().__init__()
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
         self.max_context_len = max_context_len
-        self.enable_dynamic_shape = enable_dynamic_shape

     def forward(
         self,
@@ -142,21 +140,12 @@ def forward(
         seqlen,
         mask: torch.Tensor,
     ) -> torch.Tensor:
-        if self.enable_dynamic_shape:
-            start_pos = input_pos[-1].item()
-            torch._check_is_size(start_pos)
-            torch._check(start_pos < self.max_context_len)
-            seq_length = q.size(2)
-            # pyre-ignore: Incompatible parameter type [6]
-            attn_mask = mask.narrow(0, start_pos, seq_length)
-        else:
-            attn_mask = mask[None, None, input_pos]

         # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
         # can natively support GQA now. But needs enable_gqa=True
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
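The TODO kept in this hunk points at SDPA's native GQA support. A minimal sketch of what dropping the repeat_interleave calls could look like, assuming a PyTorch version whose F.scaled_dot_product_attention accepts the enable_gqa flag (2.5+); the shapes and values below are illustrative and not taken from the model:

import torch
import torch.nn.functional as F

bsz, n_heads, n_kv_heads, seqlen, head_dim = 1, 8, 2, 4, 16
q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_kv_heads, seqlen, head_dim)
v = torch.randn(bsz, n_kv_heads, seqlen, head_dim)
mask = torch.zeros(seqlen, seqlen)  # additive mask: 0.0 = attend, -inf = blocked

# What the diff does today: expand kv heads to match the query heads.
n_rep = n_heads // n_kv_heads
y_expanded = F.scaled_dot_product_attention(
    q,
    k.repeat_interleave(n_rep, dim=1),
    v.repeat_interleave(n_rep, dim=1),
    attn_mask=mask,
    dropout_p=0.0,
)

# What the TODO suggests: let SDPA broadcast the kv heads itself.
y_gqa = F.scaled_dot_product_attention(
    q, k, v, attn_mask=mask, dropout_p=0.0, enable_gqa=True
)

# Both paths compute the same attention (up to backend-dependent rounding).
assert torch.allclose(y_expanded, y_gqa, atol=1e-5)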
@@ -236,21 +225,79 @@ def __init__(
         enable_dynamic_shape: bool,
         dtype=torch.float32,
     ):
+        self.window_size = max_context_length
+        """
+        Why we want the kv cache size to be twice the context length:
+
+        Sliding window attention without a ring buffer
+        pos 0 1 2 3 4 5 6 7 8 9 10
+        0   x 0 0 0 0 0 0 0 0 0 0
+        1   x x 0 0 0 0 0 0 0 0 0
+        2   x x x 0 0 0 0 0 0 0 0
+        3   x x x x 0 0 0 0 0 0 0
+        4   0 x x x x 0 0 0 0 0 0
+        5   0 0 x x x x 0 0 0 0 0
+        6   0 0 0 x x x x 0 0 0 0
+        7   0 0 0 0 x x x x 0 0 0
+        8   0 0 0 0 0 x x x x 0 0
+        9   0 0 0 0 0 0 x x x x 0
+        10  0 0 0 0 0 0 0 x x x x
+
+        So when doing attention for pos = 5 and seq_len = 4, our attention
+        mask would be
+        5   0 0 x x x x 0 0 0 0 0
+        6   0 0 0 x x x x 0 0 0 0
+        7   0 0 0 0 x x x x 0 0 0
+        8   0 0 0 0 0 x x x x 0 0
+        Thus the token at pos = 5 is able to attend to tokens at pos 2, 3 and 4.
+        This is how training is done.
+
+        Now let's consider a ring kv cache of size 4. At pos = 5, before
+        updating the kv cache, its state is [4 1 2 3]; that is, the token at
+        pos = 0 has already been evicted. During the attention calculation at
+        pos = 5 with seq_len = 4, we update the cache and the positions stored
+        in it become [8 5 6 7]. Note that 5 can now only attend to itself,
+        not to 2, 3 and 4 as it would during training.
+        Not keeping 2, 3 and 4 in the cache therefore gives divergent behavior.
+        The worst case is when the update length equals the cache length,
+        as in our example with pos = 5 and seq_len = 4.
+        Thus we need a larger cache. How much larger? Larger by the sliding
+        window size, i.e. twice max_context_length.
+        How does that help? At pos = 5 the cache would hold
+        [0, 1, 2, 3, 4, NA, NA, NA], and after the cache update
+        [8, 1, 2, 3, 4, 5, 6, 7]. We kicked out the token at pos = 0, but the
+        current step still has access to the tokens in
+        [pos - sliding_window_size, pos].
+
+        To make sure we don't over-attend, i.e. pos = 5 does not attend to
+        pos = 1, the mask calculation has to account for the sliding window
+        size.
+        """
         super().__init__(
             max_batch_size,
-            max_context_length,
+            max_context_length * 2,
             n_heads,
             head_dim,
             enable_dynamic_shape,
             dtype,
         )
-        self.cache_positions_manager = CachePositionsManager(max_context_length)
+        self.cache_positions_manager = CachePositionsManager(self.max_context_length)
+        self.is_ring_buffer = True
+
+    def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
+        pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
+        cache_positions = self.cache_positions_manager.cache_positions
+        delta = pos_q - cache_positions
+        attn_mask = (cache_positions >= 0) & (delta >= 0) & (delta < self.window_size)
+        attn_mask = torch.where(attn_mask == True, 0, float("-inf"))  # noqa E712
+        return attn_mask

     def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D]
         seq_len = k_val.size(2)
+        assert seq_len <= self.k_cache.size(
+            2
+        ), f"Update sequence length ({seq_len}) for kv cache must not exceed the cache size ({self.k_cache.size(2)})"
         indices = self.cache_positions_manager.calculate_positions_and_update_indices(
             input_pos, seq_len
         )
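To make the masking concrete, here is a self-contained sketch of the same logic as create_causal_mask_for_ring_buffer, run against a hand-built cache state that mirrors the [8, 1, 2, 3, 4, 5, 6, 7] example from the comment. The window size, cache contents and positions are illustrative; in the real class they come from max_context_length and CachePositionsManager.

import torch

# Illustrative values: sliding window of 4, ring cache of size 8 (2x the window).
window_size = 4
# Original token position currently stored in each cache slot, mimicking the
# [8, 1, 2, 3, 4, 5, 6, 7] state described in the comment above.
cache_positions = torch.tensor([8, 1, 2, 3, 4, 5, 6, 7], dtype=torch.long)

def create_causal_mask_for_ring_buffer(start_pos, seq_len):
    # Query positions for this step: start_pos .. start_pos + seq_len - 1.
    pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
    delta = pos_q - cache_positions
    # A cache slot is visible if it is populated (>= 0), causal (delta >= 0)
    # and within the sliding window (delta < window_size).
    attn_mask = (cache_positions >= 0) & (delta >= 0) & (delta < window_size)
    return torch.where(attn_mask, 0.0, float("-inf"))

# Query at pos = 8 with seq_len = 1: it may attend to positions 5, 6, 7 and 8
# (delta in [0, 4)), but not to positions 1..4, which fall outside the window.
print(create_causal_mask_for_ring_buffer(8, 1))
# tensor([[0., -inf, -inf, -inf, -inf, 0., 0., 0.]])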
@@ -286,6 +333,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         self.attention_qkv_bias = args.attention_qkv_bias
         self.use_qk_norm = args.use_qk_norm
         self.qk_norm_before_rope = args.qk_norm_before_rope
+        self.enable_dynamic_shape = args.enable_dynamic_shape

         if self.use_qk_norm:
             q_norm_dim = self.head_dim
@@ -331,7 +379,6 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
             head_dim=self.head_dim,
             n_rep=self.n_rep,
             max_context_len=self.max_context_len,
-            enable_dynamic_shape=args.enable_dynamic_shape,
         )

     def forward(
@@ -368,8 +415,22 @@ def forward(

         if self.use_kv_cache:
             assert input_pos is not None
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[-1].item()
+                torch._check_is_size(start_pos)
+                torch._check(start_pos < self.max_context_len)
+                seq_length = q.size(2)
+                # pyre-ignore: Incompatible parameter type [6]
+                attn_mask = self.mask.narrow(0, start_pos, seq_length)
+            else:
+                # mask is always 2D
+                attn_mask = self.mask[input_pos]
             k, v = self.kv_cache.update(input_pos, k, v)
-            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
+            if getattr(self.kv_cache, "is_ring_buffer", False):
+                attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(
+                    input_pos[0].item(), seqlen
+                )
+            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)
             return self.wo(output), None

         # grouped multiquery attention: expand out keys and values
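For context on the two non-ring-buffer branches above: a small sketch of why mask.narrow(0, start_pos, seq_length) and self.mask[input_pos] pick out the same rows of the precomputed causal mask, assuming (as the narrow() call implies) that under dynamic shapes input_pos carries just the start position of the incoming chunk. The mask size and positions are illustrative.

import torch

# Illustrative: a 6-position causal mask like the precomputed self.mask
# (0.0 where attention is allowed, -inf where it is blocked).
max_context_len = 6
causal = torch.tril(torch.ones(max_context_len, max_context_len, dtype=torch.bool))
mask = torch.where(causal, 0.0, float("-inf"))

seq_length = 3  # number of new query tokens in this step

# Static path: input_pos lists every new position and indexes the mask directly.
input_pos_full = torch.tensor([2, 3, 4])
attn_mask_indexed = mask[input_pos_full]

# Dynamic-shape path: input_pos holds just the start position, and narrow()
# slices the same rows without data-dependent advanced indexing.
input_pos_start = torch.tensor([2])
start_pos = input_pos_start[-1].item()
attn_mask_narrow = mask.narrow(0, start_pos, seq_length)

assert torch.equal(attn_mask_indexed, attn_mask_narrow)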