     GemmaForCausalLM,
     GemmaModel,
     GemmaRotaryEmbedding,
-    apply_rotary_pos_emb,
     repeat_kv,
+    rotate_half,
 )
 
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
 
 
-def eager_attention_forward(
-    module: nn.Module,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    attention_mask: Optional[torch.Tensor],
-    scaling: float,
-    **kwargs,
-):
-    key_states = repeat_kv(key, module.num_key_value_groups)
-    value_states = repeat_kv(value, module.num_key_value_groups)
-
-    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
-    if attention_mask is not None:
-        attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
-
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_output = torch.matmul(attn_weights, value_states)
-    attn_output = attn_output.transpose(1, 2).contiguous()
-
-    return attn_output, attn_weights
-
-
 class QEffGemmaRotaryEmbedding(GemmaRotaryEmbedding):
     """
     Copied from GemmaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py
@@ -87,6 +64,60 @@ def forward(self, x, seq_len=None):
         )
 
 
+def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`):
+            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+            used to pass offsetted position ids when working with a KV-cache.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+
+    # Apply rotation
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    # Cast back to original dtype
+    return q_embed.to(q.dtype), k_embed.to(k.dtype)
+
+
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    **kwargs,
+):
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
 class QEffGemmaAttention(GemmaAttention):
     """
     Copied from GemmaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py
@@ -107,7 +138,6 @@ def __qeff_init__(self):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor],
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
@@ -125,8 +155,8 @@ def forward(
         kv_seq_len = key_states.shape[-2]
 
         kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -167,7 +197,6 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -200,7 +229,6 @@ def forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
-            position_embeddings=position_embeddings,
             **kwargs,
         )
         hidden_states = residual + hidden_states
@@ -277,9 +305,6 @@ def forward(
         # embed positions
         hidden_states = inputs_embeds
 
-        # create position embeddings to be shared across the decoder layers
-        position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
         # normalized
         # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
         # See https://github.com/huggingface/transformers/pull/29402
@@ -303,7 +328,6 @@ def forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
-            position_embeddings=position_embeddings,
             **kwargs,
         )
         hidden_states = layer_outputs[0]
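
Not part of the diff above: the change routes attention through qeff_apply_rotary_pos_emb, which gathers per-token cos/sin rows by position_ids before applying the rotation. The standalone sketch below illustrates that gather-and-broadcast step with made-up shapes; the locally rebuilt cos/sin cache and the rotate_half copy are illustrative assumptions (the copy only mirrors the helper the diff imports), not code from this change.

import torch

def rotate_half(x):
    # Local stand-in mirroring the `rotate_half` helper imported in the diff.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

batch, heads, seq_len, head_dim, max_pos = 1, 8, 4, 64, 128

# Illustrative cos/sin cache shaped like a rotary-embedding module's output
# ([max_pos, head_dim]), built here from the standard RoPE frequencies.
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(max_pos).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()

q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
position_ids = torch.arange(seq_len).unsqueeze(0)  # [batch, seq_len]; may carry a KV-cache offset

# Core of qeff_apply_rotary_pos_emb: gather per-token cos/sin by position_ids,
# then unsqueeze the head dim so [batch, seq_len, head_dim] broadcasts with q/k.
cos_g = cos[position_ids].unsqueeze(1)  # [batch, 1, seq_len, head_dim]
sin_g = sin[position_ids].unsqueeze(1)
q_embed = (q * cos_g) + (rotate_half(q) * sin_g)
k_embed = (k * cos_g) + (rotate_half(k) * sin_g)
print(q_embed.shape, k_embed.shape)  # torch.Size([1, 8, 4, 64]) for both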