@@ -140,7 +140,7 @@ def forward_native(
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """A PyTorch-native implementation of forward()."""
        if offsets is not None:
            positions = positions + offsets
@@ -174,7 +174,7 @@ def forward_cuda(
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        from vllm import _custom_ops as ops

        # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
@@ -202,7 +202,7 @@ def forward_xpu(
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        from vllm._ipex_ops import ipex_ops as ops

        self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
@@ -232,7 +232,7 @@ def forward_hpu(
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        from habana_frameworks.torch.hpex.kernels import (
            RotaryPosEmbeddingMode, apply_rotary_pos_emb)
        if offsets is not None:
@@ -290,7 +290,7 @@ def forward_neuron(
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:

        def _apply_rotary_emb_neuron(
            x: torch.Tensor,
@@ -688,7 +688,7 @@ def forward(
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        assert key is not None
        query = query.view(*query.shape[:-1], -1, self.head_size)
        key = key.view(*key.shape[:-1], -1, self.head_size)
@@ -799,7 +799,7 @@ def forward(
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        assert key is not None
        query_rot = query[..., :self.rotary_dim]
@@ -929,7 +929,7 @@ def forward(
        self,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        assert key is not None
        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
        query_ = torch.view_as_complex(query.float().reshape(
@@ -975,7 +975,7 @@ def forward(
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        Args:
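
Context for the change, not part of the diff itself: in every one of these signatures `key` defaults to `None`, so in query-only calls the second element of the returned tuple is `None` and the old `Tuple[torch.Tensor, torch.Tensor]` annotation was inaccurate. A minimal sketch of the pattern, using a hypothetical `apply_rope` stand-in rather than the real vLLM kernels:

from typing import Optional, Tuple

import torch


def apply_rope(
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # Stand-in for the actual rotation; only the return type matters here.
    rotated_query = query * 1.0
    if key is None:
        # Query-only mode: the second tuple element is None, which is what
        # Tuple[torch.Tensor, Optional[torch.Tensor]] permits and the old
        # annotation did not.
        return rotated_query, None
    return rotated_query, key * 1.0


q = torch.randn(2, 8, 64)
q_out, k_out = apply_rope(q)  # key omitted, so k_out is None
assert k_out is None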