@@ -5,7 +5,6 @@
 #
 # ----------------------------------------------------------------------------
 
-import random
 from typing import Optional, Tuple
 
 import torch
@@ -128,7 +127,7 @@ def forward(
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
-        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout)
         attn_output = torch.matmul(attn_weights, value_states)
 
         if tuple(attn_output.size()) != (bsz, self.num_heads, tgt_len, self.head_dim):
@@ -209,7 +208,7 @@ def forward(
             cache_position=cache_position,
             input_features=input_features,
         )
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout)
         hidden_states = residual + hidden_states
 
         # Cross-Attention Block
@@ -230,7 +229,7 @@ def forward(
             input_features=input_features,
             is_cross_attention=True,  # explicitly pass this argument, instead of figuring it out from key_value_states
         )
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout)
         hidden_states = residual + hidden_states
 
         # update the cached past_key_values accordingly
@@ -244,9 +243,9 @@ def forward(
         residual = hidden_states
         hidden_states = self.final_layer_norm(hidden_states)
         hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout)
         hidden_states = self.fc2(hidden_states)
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout)
         hidden_states = residual + hidden_states
 
         outputs = (hidden_states,)
@@ -319,7 +318,7 @@ def forward(
         embed_pos = self.embed_positions.weight
 
         hidden_states = inputs_embeds + embed_pos
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout)
 
         encoder_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
@@ -334,33 +333,13 @@ def forward(
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            dropout_probability = random.uniform(0, 1)
-            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
-                layer_outputs = (None, None)
-            else:
-                if self.gradient_checkpointing and self.training:
-
-                    def create_custom_forward(module):
-                        def custom_forward(*inputs):
-                            return module(*inputs, output_attentions)
-
-                        return custom_forward
-
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(encoder_layer),
-                        hidden_states,
-                        None,
-                        (head_mask[idx] if head_mask is not None else None),
-                    )
-                else:
-                    layer_outputs = encoder_layer(
-                        hidden_states,
-                        None,
-                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                        output_attentions=output_attentions,
-                    )
-
-            hidden_states = layer_outputs[0]
+            layer_outputs = encoder_layer(
+                hidden_states,
+                None,
+                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                output_attentions=output_attentions,
+            )
+            hidden_states = layer_outputs[0]
 
             if output_attentions:
                 all_attentions = all_attentions + (layer_outputs[1],)
@@ -520,7 +499,7 @@ def forward(
         # embed positions
         positions = self.embed_positions(input_ids, past_key_values_length=position)
         hidden_states = inputs_embeds + positions
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout)
 
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
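
For reference on the dropout calls changed throughout this diff: `torch.nn.functional.dropout` applies dropout unconditionally by default (its `training` argument defaults to `True`), whereas the removed `training=self.training` form only applies dropout while the module is in train mode. The snippet below is a standalone illustration of that flag's behavior, not part of the file above; the tensor and variable names are made up for the example.

import torch
from torch import nn

x = torch.ones(2, 4)

# Functional form with the default training=True: dropout is applied
# regardless of any module's train/eval state.
always_dropped = nn.functional.dropout(x, p=0.5)

# Gated form: with training=False (e.g. training=self.training inside a
# module that is in eval mode) the call is a no-op and returns x unchanged.
never_dropped = nn.functional.dropout(x, p=0.5, training=False)

print(always_dropped)  # some entries zeroed, the rest scaled by 1 / (1 - p)
print(never_dropped)   # identical to x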