
Commit 1323bd8: Minor fixes
Parent: 9554336
Signed-off-by: Amit Raj <[email protected]>

File tree: 6 files changed (+335, -173 lines)


QEfficient/transformers/models/falcon/modeling_falcon.py

Lines changed: 40 additions & 10 deletions
@@ -25,8 +25,8 @@
     FalconForCausalLM,
     FalconModel,
     FalconRotaryEmbedding,
-    apply_rotary_pos_emb,
     dropout_add,
+    rotate_half,
 )
 
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
@@ -68,6 +68,37 @@ def forward(self, x, seq_len=None):
         )
 
 
+def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`):
+            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+            used to pass offsetted position ids when working with a KV-cache.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+
+    # Apply rotation
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    # Cast back to original dtype
+    return q_embed.to(q.dtype), k_embed.to(k.dtype)
+
+
 class QEffFalconAttention(FalconAttention):
     """
     Copied from FalconAttention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py
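Note (not part of the diff): a minimal, self-contained sketch of how the new qeff_apply_rotary_pos_emb consumes precomputed cos/sin caches, assuming toy shapes and a locally re-declared rotate_half so the snippet runs without transformers installed. It shows the cos[position_ids] indexing followed by unsqueeze(1) broadcasting against [batch, heads, seq, head_dim] queries and keys, the shape convention the docstring above describes.

import torch

def rotate_half(x):
    # Local copy of the helper imported from transformers above: rotates half the hidden dims.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

batch, heads, seq, head_dim, max_pos = 1, 4, 8, 64, 32
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)

# Precomputed cos/sin caches of shape [max_pos, head_dim], built the usual RoPE way.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(max_pos).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()

# Offsetted positions, e.g. tokens 10..17 decoded against an existing KV cache.
position_ids = torch.arange(10, 10 + seq).unsqueeze(0)   # [batch, seq]
cos_sel = cos[position_ids].unsqueeze(1)                  # [batch, 1, seq, head_dim]
sin_sel = sin[position_ids].unsqueeze(1)                  # broadcasts over the heads dim
q_embed = (q * cos_sel) + (rotate_half(q) * sin_sel)
k_embed = (k * cos_sel) + (rotate_half(k) * sin_sel)
print(q_embed.shape, k_embed.shape)  # torch.Size([1, 4, 8, 64]) for both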
@@ -91,13 +122,13 @@ def forward(
         alibi: Optional[torch.Tensor],
         attention_mask: torch.Tensor,
         position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
         batch_index: Optional[torch.LongTensor] = None,
         layer_past: Optional[Cache] = None,
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
     ):
         fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
         num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
@@ -110,8 +141,10 @@ def forward(
         key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
         value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
 
-        cos, sin = position_embeddings
-        query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin)
+        kv_seq_len = key_layer.shape[-2]
+        kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+        cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
+        query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)
 
         if layer_past is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids}
@@ -146,13 +179,13 @@ def forward(
         alibi: Optional[torch.Tensor],
         attention_mask: torch.Tensor,
         position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
         batch_index: Optional[torch.LongTensor] = None,
         layer_past: Optional[Union[Cache, Tuple[torch.Tensor, torch.Tensor]]] = None,
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
         **kwargs,
     ):
         residual = hidden_states
@@ -165,13 +198,13 @@ def forward(
             layer_past=layer_past,
             attention_mask=attention_mask,
             position_ids=position_ids,
+            past_key_value=past_key_value,
             batch_index=batch_index,
             alibi=alibi,
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
             cache_position=cache_position,
-            position_embeddings=position_embeddings,
         )
 
         attention_output = attn_outputs[0]
@@ -274,9 +307,6 @@ def forward(
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
         hidden_states = inputs_embeds
 
-        # create position embeddings to be shared across the decoder layers
-        position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
 
@@ -289,13 +319,13 @@ def forward(
                 layer_past=past_key_values,
                 attention_mask=causal_mask,
                 position_ids=position_ids,
+                past_key_value=past_key_values,
                 batch_index=batch_index,
                 head_mask=head_mask[i],
                 use_cache=use_cache,
                 output_attentions=output_attentions,
                 alibi=alibi,
                 cache_position=cache_position,
-                position_embeddings=position_embeddings,
             )
 
             hidden_states = outputs[0]

QEfficient/transformers/models/gemma/modeling_gemma.py

Lines changed: 57 additions & 33 deletions
@@ -21,36 +21,13 @@
     GemmaForCausalLM,
     GemmaModel,
     GemmaRotaryEmbedding,
-    apply_rotary_pos_emb,
     repeat_kv,
+    rotate_half,
 )
 
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
 
 
-def eager_attention_forward(
-    module: nn.Module,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    attention_mask: Optional[torch.Tensor],
-    scaling: float,
-    **kwargs,
-):
-    key_states = repeat_kv(key, module.num_key_value_groups)
-    value_states = repeat_kv(value, module.num_key_value_groups)
-
-    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
-    if attention_mask is not None:
-        attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
-
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_output = torch.matmul(attn_weights, value_states)
-    attn_output = attn_output.transpose(1, 2).contiguous()
-
-    return attn_output, attn_weights
-
-
 class QEffGemmaRotaryEmbedding(GemmaRotaryEmbedding):
     """
     Copied from GemmaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py
@@ -87,6 +64,60 @@ def forward(self, x, seq_len=None):
         )
 
 
+def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`):
+            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+            used to pass offsetted position ids when working with a KV-cache.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+
+    # Apply rotation
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    # Cast back to original dtype
+    return q_embed.to(q.dtype), k_embed.to(k.dtype)
+
+
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    **kwargs,
+):
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
 class QEffGemmaAttention(GemmaAttention):
     """
     Copied from GemmaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py
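The relocated eager_attention_forward expects a boolean attention_mask in which True marks positions to suppress; those scores are overwritten with -10000.0 before the softmax. A standalone sketch of that convention with made-up shapes (not part of the commit; repeat_kv is skipped by assuming a single key-value group):

import torch
import torch.nn as nn

batch, heads, q_len, kv_len, head_dim = 1, 2, 4, 4, 8
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)
scaling = head_dim ** -0.5

# Boolean causal mask: True = masked (future position), matching the torch.where call above.
mask = torch.triu(torch.ones(q_len, kv_len, dtype=torch.bool), diagonal=1)[None, None]

attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
attn_weights = torch.where(mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
print(attn_output.shape)  # torch.Size([1, 4, 2, 8]) after the final transpose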
@@ -107,7 +138,6 @@ def __qeff_init__(self):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor],
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
@@ -125,8 +155,8 @@ def forward(
         kv_seq_len = key_states.shape[-2]
 
         kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -167,7 +197,6 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -200,7 +229,6 @@ def forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
-            position_embeddings=position_embeddings,
             **kwargs,
         )
         hidden_states = residual + hidden_states
@@ -277,9 +305,6 @@ def forward(
         # embed positions
         hidden_states = inputs_embeds
 
-        # create position embeddings to be shared across the decoder layers
-        position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
         # normalized
         # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
         # See https://github.com/huggingface/transformers/pull/29402
@@ -303,7 +328,6 @@ def forward(
                 output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
-                position_embeddings=position_embeddings,
                 **kwargs,
             )
             hidden_states = layer_outputs[0]

QEfficient/transformers/models/gemma2/modeling_gemma2.py

Lines changed: 34 additions & 10 deletions
@@ -22,8 +22,8 @@
     Gemma2ForCausalLM,
     Gemma2Model,
     Gemma2RotaryEmbedding,
-    apply_rotary_pos_emb,
     repeat_kv,
+    rotate_half,
 )
 
 # from transformers.utils import is_torchdynamo_compiling
@@ -66,6 +66,37 @@ def forward(self, x, seq_len=None):
         )
 
 
+def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`):
+            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+            used to pass offsetted position ids when working with a KV-cache.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+
+    # Apply rotation
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    # Cast back to original dtype
+    return q_embed.to(q.dtype), k_embed.to(k.dtype)
+
+
 def eager_attention_forward(
     module: nn.Module,
     query: torch.Tensor,
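The same qeff_apply_rotary_pos_emb helper is introduced here for Gemma2. As an illustrative check (again a sketch, not from the commit): because the cos/sin tables duplicate the frequency table across both halves of head_dim, the rotation acts as a 2-D rotation on each (x_i, x_{i+head_dim/2}) pair and therefore preserves per-head vector norms, which gives a quick way to sanity-test the helper.

import torch

def rotate_half(x):
    # Local copy of the helper so the check is self-contained.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

head_dim, seq = 64, 8
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seq).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)   # frequencies duplicated across both halves
cos, sin = emb.cos(), emb.sin()

q = torch.randn(1, 4, seq, head_dim)
q_embed = q * cos + rotate_half(q) * sin  # cos/sin broadcast over batch and heads
assert torch.allclose(q_embed.norm(dim=-1), q.norm(dim=-1), atol=1e-5)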
@@ -112,7 +143,6 @@ def __qeff_init__(self):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor],
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
@@ -130,8 +160,8 @@ def forward(
         kv_seq_len = key_states.shape[-2]
 
         kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -172,7 +202,6 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -205,7 +234,6 @@ def forward(
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
-           position_embeddings=position_embeddings,
            **kwargs,
         )
         hidden_states = self.post_attention_layernorm(hidden_states)
@@ -295,9 +323,6 @@ def forward(
         # embed positions
         hidden_states = inputs_embeds
 
-        # create position embeddings to be shared across the decoder layers
-        position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
         # normalized
         # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
         # See https://github.com/huggingface/transformers/pull/29402
@@ -321,7 +346,6 @@ def forward(
                 output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
-                position_embeddings=position_embeddings,
                 **kwargs,
             )
 
0 commit comments
