
Commit d48d084

Optimizations (#2)
* Initial commit
* Reformat code
* Fix bug
* Add Gumbel-Max trick based random sampling
* Bring up to date
* Use Gumbel-Max Trick based Random Sampling as default
* Clip k to max value
* Add docstring for sampling parameters
* Fix bug
* Add support for continuous batching
* Fix ONNX error for batch_size 1 treated as a Constant
* Undo docstring deletion
* Remove device and unnecessary reshapes
* Revert batch_size to 1
* Remove vocab_size from dynamic axes
* Change condition
* Change size of each sampling parameter to (batch_size, 1)
* Reformat code
* Add optimizations
* Identify optimizations
* Fix bug
* Fix merge issue
* Optimizations: perform random sampling only on topk_values_asc; only need logits for probs when self.return_pdfs is True
* Remove where clause for temperature
* Remove boolean type casting for retain state
* Always return next_tokens
* Fix bug
* Reformat code
* Initialize retain states
* Optimize imports
* Remove torch.index_select()
* Change dtype of penalty buffers to bool

---------

Signed-off-by: quic-sanising <[email protected]>
1 parent 7dfdda4 commit d48d084
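
The headline change is making Gumbel-Max based random sampling the default. As a reminder of the trick itself, here is a minimal standalone sketch (not code from this repository; the commit applies the same idea inside sampler_forward, over the top-k slice only):

import torch

# Gumbel-Max trick: argmax(log p_i + g_i) with g_i ~ Gumbel(0, 1) selects index i
# with probability p_i, so categorical sampling reduces to a single argmax.
def gumbel_max_sample(logits: torch.Tensor) -> torch.Tensor:
    u = torch.rand_like(logits)                # u ~ Uniform(0, 1)
    gumbel_noise = -torch.log(-torch.log(u))   # g ~ Gumbel(0, 1)
    return torch.argmax(logits + gumbel_noise, dim=-1)

# Example: draws follow softmax(logits) without calling torch.multinomial.
print(gumbel_max_sample(torch.tensor([[2.0, 0.5, 0.1]])))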

3 files changed: +53 additions, −37 deletions


QEfficient/transformers/models/modeling_auto.py

Lines changed: 3 additions & 3 deletions
@@ -38,8 +38,8 @@
     CustomOpsTransform,
     KVCacheModuleMethodMapperTransform,
     KVCacheTransform,
-    SpDTransform,
     SamplerTransform,
+    SpDTransform,
     VlmKVOffloadTransform,
     VlmNoKVOffloadTransform,
 )
@@ -1483,7 +1483,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
         dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "num_logits_to_keep"}
 
         example_inputs["past_repetition_penalty_buffer"] = torch.zeros(
-            fbs if self.continuous_batching else bs, self.model.config.vocab_size, dtype=torch.int32)
+            fbs if self.continuous_batching else bs, self.model.config.vocab_size, dtype=torch.bool)
         dynamic_axes["past_repetition_penalty_buffer"] = {
             0: "full_batch_size" if self.continuous_batching else "batch_size",
         }
@@ -1493,7 +1493,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
         dynamic_axes["repetition_penalties"] = {0: "batch_size"}
 
         example_inputs["past_presence_penalty_buffer"] = torch.zeros(
-            fbs if self.continuous_batching else bs, self.model.config.vocab_size, dtype=torch.int32)
+            fbs if self.continuous_batching else bs, self.model.config.vocab_size, dtype=torch.bool)
         dynamic_axes["past_presence_penalty_buffer"] = {
             0: "full_batch_size" if self.continuous_batching else "batch_size",
         }
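
The only functional change in this file is the dtype of the export-time penalty buffers, from torch.int32 to torch.bool. These buffers are per-sequence, vocab-sized masks, so the narrower dtype cuts their size by 4x. A rough illustration (the sizes below are assumptions for the example, not values taken from this repository):

import torch

# Hypothetical sizes for illustration only; the real shapes come from the model
# config and the batch settings used at export time.
vocab_size, full_batch_size = 32_000, 8

as_int32 = torch.zeros(full_batch_size, vocab_size, dtype=torch.int32)
as_bool = torch.zeros(full_batch_size, vocab_size, dtype=torch.bool)

print(as_int32.nelement() * as_int32.element_size())  # 1,024,000 bytes
print(as_bool.nelement() * as_bool.element_size())    #   256,000 bytes (4x smaller)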

QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 1 addition & 1 deletion
@@ -246,8 +246,8 @@
     QEffWhisperModel,
     QEffWhisperPositionalEmbedding,
 )
-from QEfficient.transformers.spd.causal_lm_forward import tlm_forward
 from QEfficient.transformers.sampler.sampler import sampler_forward
+from QEfficient.transformers.spd.causal_lm_forward import tlm_forward
 
 class CustomOpsTransform(ModuleMappingTransform):
     _module_mapping = {

QEfficient/transformers/sampler/sampler.py

Lines changed: 49 additions & 33 deletions
@@ -5,7 +5,7 @@
 # from QEfficient.customop import CtxScatterFunc
 from QEfficient.utils.constants import Constants
 from transformers.cache_utils import Cache
-from transformers.modeling_outputs import ModelOutput, CausalLMOutputWithPast
+from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
 from typing import List, Optional, Tuple, Union
 
 
@@ -217,16 +217,19 @@ def sampler_forward(
 
     # Select relevant rows
     batch_index_reshaped = batch_index.view(-1)
-    past_repetition_penalty_buffer_selected = torch.index_select(past_repetition_penalty_buffer, 0, batch_index_reshaped)
-    past_presence_penalty_buffer_selected = torch.index_select(past_presence_penalty_buffer, 0, batch_index_reshaped)
+    past_repetition_penalty_buffer_selected = past_repetition_penalty_buffer[batch_index_reshaped]
+    past_presence_penalty_buffer_selected = past_presence_penalty_buffer[batch_index_reshaped]
 
     logits = logits.reshape(-1, vocab_size)  # Reshape tensor to 2D
 
     if input_ids.shape[1] > spec_length:  # Prefill phase, initialize retained states
         # TODO: Replace scatter_ with CtxScatterFunc; Replace -1 with int_max while exporting on onnx
         # past_repetition_penalty_buffer_selected = CtxScatterFunc.apply(past_repetition_penalty_buffer_selected.unsqueeze(1), input_ids, 1).squeeze(1)
+        if position_ids[0, 0] == 0:
+            past_repetition_penalty_buffer_selected = torch.mul(past_repetition_penalty_buffer_selected, 0)
+            past_presence_penalty_buffer_selected = torch.mul(past_presence_penalty_buffer_selected, 0)
         past_repetition_penalty_buffer_selected.scatter_(1, input_ids, 1)
-        past_presence_penalty_buffer_selected.scatter_(1, input_ids, 0)
+
     else:  # Decode phase, update retained states
         past_repetition_penalty_buffer_selected.scatter_(1, last_accepted_output_tokens, 1)
         past_presence_penalty_buffer_selected.scatter_(1, last_accepted_output_tokens, 1)
@@ -236,11 +239,26 @@ def sampler_forward(
     past_repetition_penalty_buffer[batch_index_reshaped] = past_repetition_penalty_buffer_selected
     past_presence_penalty_buffer[batch_index_reshaped] = past_presence_penalty_buffer_selected
 
+    # Greedy Sampling
+    greedy_samples = torch.argmax(logits, dim=1, keepdim=True)  # (batch_size * spec_length, 1)
+    if (temperatures == 0).all() and self.return_pdfs == False:
+        return QEffCausalLMOutputWithPast(
+            loss=None,
+            logits=None,
+            probs=None,
+            next_tokens=greedy_samples.reshape(-1, spec_length, 1),  # Return sampled next tokens instead of logits
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            past_repetition_penalty_buffer=past_repetition_penalty_buffer,
+            past_presence_penalty_buffer=past_presence_penalty_buffer,
+        )
+
     # Repetition Penalty
     if (repetition_penalties != 1.).any():
         repetition_penalties = repetition_penalties.repeat(spec_length, vocab_size)  # (batch_size, 1) -> (batch_size * spec_length, vocab_size)
         past_repetition_penalty_buffer_selected = past_repetition_penalty_buffer_selected.repeat(spec_length, 1)  # (batch_size, vocab_size) -> (batch_size * spec_length, vocab_size)
-        repetition_penalties[~past_repetition_penalty_buffer_selected.bool()] = 1.0
+        repetition_penalties[past_repetition_penalty_buffer_selected == 0] = 1.0
         logits = torch.where(logits > 0, logits / repetition_penalties, logits * repetition_penalties)
 
     # Presence Penalty
@@ -252,12 +270,11 @@ def sampler_forward(
     # TODO: Frequency Penalty
 
     # Temperature Scaling
-    if (temperatures != 0).any():
-        temperatures = temperatures.repeat(spec_length, 1)  # (batch_size, 1) -> (batch_size * spec_length, 1)
-        logits = torch.where(temperatures != 0, logits / temperatures, logits)
-
+    temperatures = temperatures.repeat(spec_length, 1)  # (batch_size, 1) -> (batch_size * spec_length, 1)
+    logits /= temperatures
+
     # Top K
-    # TODO (Optimization): if (top_ks != -1 or top_ks != Constants.MAX_TOP_K_IDS).any(): skip
+    # TODO (Optimization): if (top_ks != -1 or top_ks != Constants.MAX_TOP_K_IDS).any() is False: skip but will need topk_values_asc and topk_indices_asc
     topk_values, topk_indices = torch.topk(logits, k=Constants.MAX_TOP_K_IDS, dim=1)  # (batch_size * spec_length, vocab_size)
     topk_values_asc = topk_values.flip(dims=[1])
     topk_indices_asc = topk_indices.flip(dims=[1])
@@ -267,42 +284,41 @@ def sampler_forward(
     topk_values_asc[topk_mask] = torch.finfo(torch.float16).min
 
     # Top P
-    # TODO (Optimization): if (top_ps != 1.).any(): skip but will need top_probs for Min P
+    # TODO (Optimization): if (top_ps != 1.).any() is False: skip but will need top_probs for Min P
     top_probs = torch.softmax(topk_values_asc, dim=1)  # (batch_size * spec_length, Constants.MAX_TOP_K_IDS)
     topk_probs_sum = torch.cumsum(top_probs, dim=1)
     top_p_mask = topk_probs_sum <= 1 - top_ps.repeat(spec_length, 1)  # (batch_size * spec_length, Constants.MAX_TOP_K_IDS)
     top_p_mask[:, Constants.MAX_TOP_K_IDS - 1] = False
     topk_values_asc[top_p_mask] = torch.finfo(torch.float16).min
 
     # Min P
-    # TODO (Optimization): if (min_ps != 0.).any(): skip
-    scaled_min_p = torch.mul(min_ps.repeat(spec_length, 1), top_probs[:, -1:])  # (batch_size * spec_length, 1)
-    min_p_mask = top_probs < scaled_min_p  # (batch_size * spec_length, Constants.MAX_TOP_K_IDS)
-    topk_values_asc[min_p_mask] = torch.finfo(torch.float16).min
-
-    logits.fill_(torch.finfo(torch.float16).min)
-    logits = logits.scatter(1, topk_indices_asc, topk_values_asc)  # (batch_size * spec_length, vocab_size)
-
-    # Softmax
-    # TODO (Optimization): if (temperatures == 0).all(): skip and perform only greedy sampling
-    probs = torch.softmax(logits, dim=1)  # (batch_size * spec_length, vocab_size)
-
-    # Sample the next tokens
-    # TODO (Optimization): if self.return_pds: skip
-    greedy_samples = torch.argmax(probs, dim=-1, keepdim=True)  # Greedy Sampling
+    if (min_ps != 0.).any():
+        scaled_min_p = torch.mul(min_ps.repeat(spec_length, 1), top_probs[:, Constants.MAX_TOP_K_IDS - 1:])  # (batch_size * spec_length, 1)
+        min_p_mask = top_probs < scaled_min_p  # (batch_size * spec_length, Constants.MAX_TOP_K_IDS)
+        topk_values_asc[min_p_mask] = torch.finfo(torch.float16).min
+
+    probs = None
+    if self.return_pdfs:
+        # Update the logits
+        logits.fill_(torch.finfo(torch.float16).min)
+        logits = logits.scatter(1, topk_indices_asc, topk_values_asc)  # (batch_size * spec_length, vocab_size)
+        # Softmax
+        probs = torch.softmax(logits, dim=1).reshape(-1, spec_length, vocab_size)  # (batch_size, spec_length, vocab_size)
+
+    # Random Sampling
+    topk_probs_asc = torch.softmax(topk_values_asc, dim=1)  # (batch_size * spec_length, Constants.MAX_TOP_K_IDS)
     gumbel_noise = -torch.log(-torch.log(random_numbers.repeat(spec_length, 1)))  # Gumbel-Max Trick
-    y = probs + gumbel_noise
-    random_samples = torch.argmax(y, dim=-1, keepdim=True)  # Random Sampling
-    next_tokens = torch.where(temperatures == 0, greedy_samples, random_samples)  # (batch_size * spec_length, 1)
+    y = topk_probs_asc + gumbel_noise
+    random_samples_indices = torch.argmax(y, dim=1, keepdim=True)
+    random_samples = torch.gather(topk_indices_asc, 1, random_samples_indices)  # (batch_size * spec_length, 1)
 
-    # Reshape tensor back to 3D
-    probs = probs.reshape(-1, spec_length, vocab_size)
-    next_tokens = next_tokens.reshape(-1, spec_length, 1)
+    # Sample the next tokens
+    next_tokens = torch.where(temperatures == 0, greedy_samples, random_samples).reshape(-1, spec_length, 1)  # (batch_size, spec_length, 1)
 
     return QEffCausalLMOutputWithPast(
         loss=None,
         logits=None,
-        probs=probs if self.return_pdfs else None,  # Return probabilities instead of logits
+        probs=probs,
         next_tokens=next_tokens,  # Return sampled next tokens instead of logits
         past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
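
Taken together, the sampler now short-circuits to pure greedy decoding when every temperature is zero and probabilities are not requested, and otherwise filters only the top-k slice (top-p and min-p masks) before drawing tokens with the Gumbel-Max trick. Below is a simplified, single-sequence sketch of that flow; it is illustrative only, mirrors the logic above rather than reproducing it line for line, and the parameter names are just the sampling parameters from the diff applied to one row:

import torch

def sample_next_token(logits, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0):
    # Greedy short-circuit, analogous to the early return added in sampler_forward
    if temperature == 0:
        return int(torch.argmax(logits))
    scaled = logits / temperature

    # Work only on the top-k slice; nothing below needs the full vocab again
    values, indices = torch.topk(scaled, k=top_k)
    probs = torch.softmax(values, dim=-1)

    # Top-p: keep the smallest set of highest-probability tokens whose mass reaches top_p
    sorted_probs, order = probs.sort(descending=True)
    drop = torch.cumsum(sorted_probs, dim=-1) - sorted_probs > top_p
    values[order[drop]] = float("-inf")

    # Min-p: drop tokens whose probability is below min_p times the best probability
    values[probs < min_p * probs.max()] = float("-inf")

    # Gumbel-Max sampling over the surviving candidates
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(values)))
    winner = torch.argmax(values + gumbel_noise)
    return int(indices[winner])

# Example usage with made-up settings
token_id = sample_next_token(torch.randn(32_000), temperature=0.8, top_k=50, top_p=0.9, min_p=0.05)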
