 # from QEfficient.customop import CtxScatterFunc
 from QEfficient.utils.constants import Constants
 from transformers.cache_utils import Cache
-from transformers.modeling_outputs import ModelOutput, CausalLMOutputWithPast
+from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
 from typing import List, Optional, Tuple, Union
 
 
@@ -217,16 +217,19 @@ def sampler_forward(
 
     # Select relevant rows
     batch_index_reshaped = batch_index.view(-1)
-    past_repetition_penalty_buffer_selected = torch.index_select(past_repetition_penalty_buffer, 0, batch_index_reshaped)
-    past_presence_penalty_buffer_selected = torch.index_select(past_presence_penalty_buffer, 0, batch_index_reshaped)
+    past_repetition_penalty_buffer_selected = past_repetition_penalty_buffer[batch_index_reshaped]
+    past_presence_penalty_buffer_selected = past_presence_penalty_buffer[batch_index_reshaped]
 
     logits = logits.reshape(-1, vocab_size)  # Reshape tensor to 2D
 
     if input_ids.shape[1] > spec_length:  # Prefill phase, initialize retained states
         # TODO: Replace scatter_ with CtxScatterFunc; Replace -1 with int_max while exporting on onnx
         # past_repetition_penalty_buffer_selected = CtxScatterFunc.apply(past_repetition_penalty_buffer_selected.unsqueeze(1), input_ids, 1).squeeze(1)
+        if position_ids[0, 0] == 0:
+            past_repetition_penalty_buffer_selected = torch.mul(past_repetition_penalty_buffer_selected, 0)
+            past_presence_penalty_buffer_selected = torch.mul(past_presence_penalty_buffer_selected, 0)
         past_repetition_penalty_buffer_selected.scatter_(1, input_ids, 1)
-        past_presence_penalty_buffer_selected.scatter_(1, input_ids, 0)
+
     else:  # Decode phase, update retained states
         past_repetition_penalty_buffer_selected.scatter_(1, last_accepted_output_tokens, 1)
         past_presence_penalty_buffer_selected.scatter_(1, last_accepted_output_tokens, 1)
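For reference, a minimal standalone sketch of the prefill-time retained-state update above: scatter_ along dim 1 marks every prompt token id as seen in that sequence's penalty buffer. The shapes and values here are made up for illustration and are not part of the diff.

import torch

vocab_size = 8
input_ids = torch.tensor([[2, 5, 5, 1]])                # (batch_size=1, prompt_len=4)
buffer = torch.zeros(1, vocab_size, dtype=torch.long)   # one retained row per sequence
buffer.scatter_(1, input_ids, 1)                        # set seen token ids to 1 (duplicates are idempotent)
print(buffer)                                           # tensor([[0, 1, 1, 0, 0, 1, 0, 0]])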
@@ -236,11 +239,26 @@ def sampler_forward(
     past_repetition_penalty_buffer[batch_index_reshaped] = past_repetition_penalty_buffer_selected
     past_presence_penalty_buffer[batch_index_reshaped] = past_presence_penalty_buffer_selected
 
+    # Greedy Sampling
+    greedy_samples = torch.argmax(logits, dim=1, keepdim=True)  # (batch_size * spec_length, 1)
+    if (temperatures == 0).all() and self.return_pdfs == False:
+        return QEffCausalLMOutputWithPast(
+            loss=None,
+            logits=None,
+            probs=None,
+            next_tokens=greedy_samples.reshape(-1, spec_length, 1),  # Return sampled next tokens instead of logits
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            past_repetition_penalty_buffer=past_repetition_penalty_buffer,
+            past_presence_penalty_buffer=past_presence_penalty_buffer,
+        )
+
     # Repetition Penalty
     if (repetition_penalties != 1.).any():
         repetition_penalties = repetition_penalties.repeat(spec_length, vocab_size)  # (batch_size, 1) -> (batch_size * spec_length, vocab_size)
         past_repetition_penalty_buffer_selected = past_repetition_penalty_buffer_selected.repeat(spec_length, 1)  # (batch_size, vocab_size) -> (batch_size * spec_length, vocab_size)
-        repetition_penalties[~past_repetition_penalty_buffer_selected.bool()] = 1.0
+        repetition_penalties[past_repetition_penalty_buffer_selected == 0] = 1.0
         logits = torch.where(logits > 0, logits / repetition_penalties, logits * repetition_penalties)
 
     # Presence Penalty
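A small worked example of the repetition-penalty rule applied above: positive logits of previously seen tokens are divided by the penalty and negative ones multiplied, so a seen token always becomes less likely; the penalty is reset to 1.0 wherever the buffer says the token never appeared. Numbers are illustrative only.

import torch

logits = torch.tensor([[2.0, -1.0, 0.5]])       # one row of (batch_size * spec_length, vocab_size)
penalties = torch.tensor([[1.5, 1.5, 1.0]])     # 1.0 where the token was never generated or prompted
out = torch.where(logits > 0, logits / penalties, logits * penalties)
print(out)                                      # tensor([[ 1.3333, -1.5000,  0.5000]])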
@@ -252,12 +270,11 @@ def sampler_forward(
     # TODO: Frequency Penalty
 
     # Temperature Scaling
-    if (temperatures != 0).any():
-        temperatures = temperatures.repeat(spec_length, 1)  # (batch_size, 1) -> (batch_size * spec_length, 1)
-        logits = torch.where(temperatures != 0, logits / temperatures, logits)
-
+    temperatures = temperatures.repeat(spec_length, 1)  # (batch_size, 1) -> (batch_size * spec_length, 1)
+    logits /= temperatures
+
     # Top K
-    # TODO (Optimization): if (top_ks != -1 or top_ks != Constants.MAX_TOP_K_IDS).any(): skip
+    # TODO (Optimization): if (top_ks != -1 or top_ks != Constants.MAX_TOP_K_IDS).any() is False: skip but will need topk_values_asc and topk_indices_asc
     topk_values, topk_indices = torch.topk(logits, k=Constants.MAX_TOP_K_IDS, dim=1)  # (batch_size * spec_length, vocab_size)
     topk_values_asc = topk_values.flip(dims=[1])
     topk_indices_asc = topk_indices.flip(dims=[1])
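To illustrate the top-k step above and the top-p mask that follows it: torch.topk returns values in descending order, and flip reorders them ascending so the cumulative-sum mask can drop the low-probability tail from the left while the last (highest-probability) column is always kept. Toy sizes below (k=3, top_p=0.9) stand in for Constants.MAX_TOP_K_IDS and the per-request top_ps; this is a sketch, not the code under review.

import torch

logits = torch.tensor([[0.1, 2.0, -1.0, 1.0]])    # (1, vocab_size=4)
values, indices = torch.topk(logits, k=3, dim=1)  # descending values: [2.0, 1.0, 0.1]
values_asc = values.flip(dims=[1])                # ascending values:  [0.1, 1.0, 2.0]
indices_asc = indices.flip(dims=[1])              # matching token ids: [0, 3, 1]
probs = torch.softmax(values_asc, dim=1)
mask = torch.cumsum(probs, dim=1) <= 1 - 0.9      # True for the low-probability prefix outside the nucleus
mask[:, -1] = False                               # never mask the most likely candidate
values_asc[mask] = torch.finfo(torch.float16).min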
@@ -267,42 +284,41 @@ def sampler_forward(
     topk_values_asc[topk_mask] = torch.finfo(torch.float16).min
 
     # Top P
-    # TODO (Optimization): if (top_ps != 1.).any(): skip but will need top_probs for Min P
+    # TODO (Optimization): if (top_ps != 1.).any() is False: skip but will need top_probs for Min P
     top_probs = torch.softmax(topk_values_asc, dim=1)  # (batch_size * spec_length, Constants.MAX_TOP_K_IDS)
     topk_probs_sum = torch.cumsum(top_probs, dim=1)
     top_p_mask = topk_probs_sum <= 1 - top_ps.repeat(spec_length, 1)  # (batch_size * spec_length, Constants.MAX_TOP_K_IDS)
     top_p_mask[:, Constants.MAX_TOP_K_IDS - 1] = False
     topk_values_asc[top_p_mask] = torch.finfo(torch.float16).min
 
     # Min P
-    # TODO (Optimization): if (min_ps != 0.).any(): skip
-    scaled_min_p = torch.mul(min_ps.repeat(spec_length, 1), top_probs[:, -1:])  # (batch_size * spec_length, 1)
-    min_p_mask = top_probs < scaled_min_p  # (batch_size * spec_length, Constants.MAX_TOP_K_IDS)
-    topk_values_asc[min_p_mask] = torch.finfo(torch.float16).min
-
-    logits.fill_(torch.finfo(torch.float16).min)
-    logits = logits.scatter(1, topk_indices_asc, topk_values_asc)  # (batch_size * spec_length, vocab_size)
-
-    # Softmax
-    # TODO (Optimization): if (temperatures == 0).all(): skip and perform only greedy sampling
-    probs = torch.softmax(logits, dim=1)  # (batch_size * spec_length, vocab_size)
-
-    # Sample the next tokens
-    # TODO (Optimization): if self.return_pds: skip
-    greedy_samples = torch.argmax(probs, dim=-1, keepdim=True)  # Greedy Sampling
+    if (min_ps != 0.).any():
+        scaled_min_p = torch.mul(min_ps.repeat(spec_length, 1), top_probs[:, Constants.MAX_TOP_K_IDS - 1:])  # (batch_size * spec_length, 1)
+        min_p_mask = top_probs < scaled_min_p  # (batch_size * spec_length, Constants.MAX_TOP_K_IDS)
+        topk_values_asc[min_p_mask] = torch.finfo(torch.float16).min
+
+    probs = None
+    if self.return_pdfs:
+        # Update the logits
+        logits.fill_(torch.finfo(torch.float16).min)
+        logits = logits.scatter(1, topk_indices_asc, topk_values_asc)  # (batch_size * spec_length, vocab_size)
+        # Softmax
+        probs = torch.softmax(logits, dim=1).reshape(-1, spec_length, vocab_size)  # (batch_size, spec_length, vocab_size)
+
+    # Random Sampling
+    topk_probs_asc = torch.softmax(topk_values_asc, dim=1)  # (batch_size * spec_length, Constants.MAX_TOP_K_IDS)
     gumbel_noise = -torch.log(-torch.log(random_numbers.repeat(spec_length, 1)))  # Gumbel-Max Trick
-    y = probs + gumbel_noise
-    random_samples = torch.argmax(y, dim=-1, keepdim=True)  # Random Sampling
-    next_tokens = torch.where(temperatures == 0, greedy_samples, random_samples)  # (batch_size * spec_length, 1)
+    y = topk_probs_asc + gumbel_noise
+    random_samples_indices = torch.argmax(y, dim=1, keepdim=True)
+    random_samples = torch.gather(topk_indices_asc, 1, random_samples_indices)  # (batch_size * spec_length, 1)
 
-    # Reshape tensor back to 3D
-    probs = probs.reshape(-1, spec_length, vocab_size)
-    next_tokens = next_tokens.reshape(-1, spec_length, 1)
+    # Sample the next tokens
+    next_tokens = torch.where(temperatures == 0, greedy_samples, random_samples).reshape(-1, spec_length, 1)  # (batch_size, spec_length, 1)
 
     return QEffCausalLMOutputWithPast(
         loss=None,
         logits=None,
-        probs=probs if self.return_pdfs else None,  # Return probabilities instead of logits
+        probs=probs,
         next_tokens=next_tokens,  # Return sampled next tokens instead of logits
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
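For context on the random-sampling path: the Gumbel-Max trick draws from a categorical distribution by adding independent Gumbel(0, 1) noise to the log-probabilities and taking the argmax; the winning position among the k ascending-sorted candidates is then mapped back to a vocabulary id with torch.gather. The sketch below uses the textbook log-probability form with freshly drawn uniform noise, whereas the diff reuses the precomputed random_numbers input and adds the noise to the top-k probabilities directly.

import torch

topk_probs_asc = torch.tensor([[0.1, 0.3, 0.6]])   # (1, k) candidate probabilities, ascending
topk_indices_asc = torch.tensor([[41, 7, 512]])    # vocabulary ids of those candidates (made up)
u = torch.rand_like(topk_probs_asc)                # U(0, 1)
gumbel = -torch.log(-torch.log(u))                 # Gumbel(0, 1)
pos = torch.argmax(torch.log(topk_probs_asc) + gumbel, dim=1, keepdim=True)
next_token = torch.gather(topk_indices_asc, 1, pos)  # (1, 1) sampled vocabulary id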