
Commit 92b1c04

Cleaning and training removal done for phi3 phi gpt2 gpt whisper
Signed-off-by: Dipankar Sarkar <[email protected]>
1 parent 550b4ec commit 92b1c04

File tree: 5 files changed (+18, -70 lines)


Diff for: QEfficient/transformers/models/gpt2/modeling_gpt2.py (-1 line)

@@ -404,7 +404,6 @@ def forward(
             value_states,
             attention_mask,
             head_mask=head_mask,
-            dropout=self.attn_dropout.p if self.training else 0.0,
             **kwargs,
         )
         attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
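For the inference/export path this removal is behavior-preserving: with the module in eval mode, `self.training` is False, so the removed keyword already evaluated to 0.0. A small standalone check (the `nn.Dropout(p=0.1)` module here is illustrative, not taken from the repo):

    import torch.nn as nn

    attn_dropout = nn.Dropout(p=0.1)   # stands in for self.attn_dropout
    attn_dropout.eval()                # export/inference mode: .training is False
    dropout_p = attn_dropout.p if attn_dropout.training else 0.0
    assert dropout_p == 0.0            # the argument removed above was already 0.0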

Diff for: QEfficient/transformers/models/gptj/modeling_gptj.py (-6 lines)

@@ -202,12 +202,6 @@ def forward(
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
         if inputs_embeds is None:
             inputs_embeds = self.wte(input_ids)

Diff for: QEfficient/transformers/models/phi/modeling_phi.py (+1, -21 lines)

@@ -32,7 +32,6 @@ def eager_attention_forward(
     value: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
     scaling: float,
-    dropout: float = 0.0,
     **kwargs,
 ):
     key_states = repeat_kv(key, module.num_key_value_groups)

@@ -43,7 +42,6 @@ def eager_attention_forward(
     attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)

     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()

@@ -113,7 +111,6 @@ def forward(
             key_states,
             value_states,
             attention_mask,
-            dropout=0.0 if not self.training else self.attention_dropout,
             scaling=self.scaling,
             **kwargs,
         )

@@ -176,9 +173,7 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, position_ids, past_key_values, output_attentions
-        )
+        causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens)

         inputs_embeds = self.embed_dropout(inputs_embeds)
         hidden_states = inputs_embeds

@@ -226,20 +221,6 @@ def forward(
         )
         return output if return_dict else output.to_tuple()

-    def _update_causal_mask(
-        self,
-        attention_mask: torch.Tensor,
-        input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
-        position_ids: torch.Tensor,
-        past_key_values: Cache,
-        output_attentions: bool,
-    ):
-        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-        target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens
-        causal_mask = _create_causal_mask(position_ids=position_ids, target_length=target_length)
-        return causal_mask
-

 class QEffPhiDecoderLayer(PhiDecoderLayer):
     """

@@ -393,7 +374,6 @@ def forward(
             cache_position=cache_position,
             **kwargs,
         )
-        hidden_states = outputs[0]
         # Cast to INT32 to avoid issue while running in ONNXRT
         logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
         hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
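The replacement builds the causal mask directly from `position_ids` and the number of cached key/value positions via `_create_causal_mask` (imported from `QEfficient.transformers.modeling_attn_mask_utils`, as the phi3 diff below shows). A minimal sketch of that idea, assuming the same boolean convention as the `torch.where(attention_mask, -10000.0, ...)` line in `eager_attention_forward`; the actual helper in the repo may differ in shape and dtype details:

    import torch

    def create_causal_mask_sketch(position_ids: torch.Tensor, target_length: int) -> torch.Tensor:
        # position_ids: (batch, query_len); target_length: total key/value positions (past + current)
        kv_positions = torch.arange(target_length, device=position_ids.device)
        # True where a key position lies beyond the query position, i.e. where attention
        # must be blocked, matching the torch.where(..., -10000.0, ...) masking above.
        return kv_positions[None, None, None, :] > position_ids[:, None, :, None]

    mask = create_causal_mask_sketch(torch.tensor([[3, 4, 5]]), target_length=6)
    print(mask.shape)  # torch.Size([1, 1, 3, 6])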

Diff for: QEfficient/transformers/models/phi3/modeling_phi3.py (+3, -7 lines)

@@ -25,6 +25,8 @@
     rotate_half,
 )

+from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+

 class QEffPhi3RotaryEmbedding(Phi3RotaryEmbedding):
     """

@@ -98,7 +100,6 @@ def eager_attention_forward(
     value: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
     scaling: float,
-    dropout: float = 0.0,
     **kwargs,
 ):
     key_states = repeat_kv(key, module.num_key_value_groups)

@@ -109,7 +110,6 @@ def eager_attention_forward(
     attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)

     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()

@@ -182,7 +182,6 @@ def forward(
             key_states,
             value_states,
             attention_mask,
-            dropout=0.0 if not self.training else self.attention_dropout,
             scaling=self.scaling,
             **kwargs,
         )

@@ -243,9 +242,7 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, position_ids, past_key_values, output_attentions
-        )
+        causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens)

         hidden_states = inputs_embeds

@@ -438,7 +435,6 @@ def forward(
             **kwargs,
         )

-        hidden_states = outputs[0]
         # Cast to INT32 to avoid issue while running in ONNXRT
         logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
         hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
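In both the phi and phi3 hunks the bare `hidden_states = outputs[0]` assignment is dropped because the indexed line kept right below it already selects the single position whose logits matter. A standalone demo of that gather (batch, sequence, and hidden sizes are made up for illustration):

    import torch

    batch, seq_len, hidden = 2, 5, 8
    last_hidden = torch.randn(batch, seq_len, hidden)    # stands in for outputs[0]
    position_ids = torch.tensor([[0, 1, 2, 3, 4],
                                 [0, 1, 2, 0, 0]])       # second row padded after position 2

    logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)         # (batch, 1)
    hidden_states = last_hidden[torch.arange(batch).view(-1, 1), logit_index]  # (batch, 1, hidden)
    assert hidden_states.shape == (batch, 1, hidden)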

Diff for: QEfficient/transformers/models/whisper/modeling_whisper.py (+14, -35 lines)

@@ -5,7 +5,6 @@
 #
 # ----------------------------------------------------------------------------

-import random
 from typing import Optional, Tuple

 import torch

@@ -128,7 +127,7 @@ def forward(
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

-        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout)
         attn_output = torch.matmul(attn_weights, value_states)

         if tuple(attn_output.size()) != (bsz, self.num_heads, tgt_len, self.head_dim):

@@ -209,7 +208,7 @@ def forward(
             cache_position=cache_position,
             input_features=input_features,
         )
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout)
         hidden_states = residual + hidden_states

         # Cross-Attention Block

@@ -230,7 +229,7 @@ def forward(
             input_features=input_features,
             is_cross_attention=True,  # explicitly pass this argument, instead of figuring it out form key_value_states
         )
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout)
         hidden_states = residual + hidden_states

         # update the cached past_key_values accordingly

@@ -244,9 +243,9 @@ def forward(
         residual = hidden_states
         hidden_states = self.final_layer_norm(hidden_states)
         hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout)
         hidden_states = self.fc2(hidden_states)
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout)
         hidden_states = residual + hidden_states

         outputs = (hidden_states,)

@@ -319,7 +318,7 @@ def forward(
         embed_pos = self.embed_positions.weight

         hidden_states = inputs_embeds + embed_pos
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout)

         encoder_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None

@@ -334,33 +333,13 @@ def forward(
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            dropout_probability = random.uniform(0, 1)
-            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
-                layer_outputs = (None, None)
-            else:
-                if self.gradient_checkpointing and self.training:
-
-                    def create_custom_forward(module):
-                        def custom_forward(*inputs):
-                            return module(*inputs, output_attentions)
-
-                        return custom_forward
-
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(encoder_layer),
-                        hidden_states,
-                        None,
-                        (head_mask[idx] if head_mask is not None else None),
-                    )
-                else:
-                    layer_outputs = encoder_layer(
-                        hidden_states,
-                        None,
-                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                        output_attentions=output_attentions,
-                    )
-
-                hidden_states = layer_outputs[0]
+            layer_outputs = encoder_layer(
+                hidden_states,
+                None,
+                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                output_attentions=output_attentions,
+            )
+            hidden_states = layer_outputs[0]

             if output_attentions:
                 all_attentions = all_attentions + (layer_outputs[1],)

@@ -520,7 +499,7 @@ def forward(
         # embed positions
         positions = self.embed_positions(input_ids, past_key_values_length=position)
         hidden_states = inputs_embeds + positions
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout)

         # decoder layers
         all_hidden_states = () if output_hidden_states else None
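The large encoder hunk removes the LayerDrop and gradient-checkpointing branches, both of which are gated on `self.training`; in eval/export mode every layer runs through the plain `encoder_layer(...)` call the diff keeps. A standalone sketch of the removed gate (the `layerdrop` value is illustrative):

    import random

    training = False     # model.eval() before export
    layerdrop = 0.1      # illustrative LayerDrop probability

    dropout_probability = random.uniform(0, 1)
    skip_layer = training and (dropout_probability < layerdrop)
    assert skip_layer is False   # with training disabled, no encoder layer is ever skipped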
