Commit d12c2a4

Fix return type for rot embedding
Signed-off-by: Yong Hoon Shin <[email protected]>
1 parent c245d16 commit d12c2a4
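
Note: every `forward*` variant touched by this commit takes `key: Optional[torch.Tensor] = None`, so the key element of the returned tuple can itself be `None` when no key is supplied; the old `Tuple[torch.Tensor, torch.Tensor]` annotation was too narrow. A minimal sketch of the corrected pattern (a toy stand-in for illustration, not the actual vLLM rotary embedding code):

```python
from typing import Optional, Tuple

import torch


def rotate_toy(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Toy stand-in illustrating why the second return element is Optional."""
    # Real rotary math is omitted; the point is only the type flow:
    # a None key is passed straight through, so the return annotation
    # must allow None as well.
    return query, key
```

With the old annotation, a type checker would let callers use the returned key unconditionally even on a query-only call.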

File tree

1 file changed: +9 -9 lines changed


vllm/model_executor/layers/rotary_embedding.py

Lines changed: 9 additions & 9 deletions
@@ -140,7 +140,7 @@ def forward_native(
         query: torch.Tensor,
         key: Optional[torch.Tensor] = None,
         offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         """A PyTorch-native implementation of forward()."""
         if offsets is not None:
             positions = positions + offsets
@@ -174,7 +174,7 @@ def forward_cuda(
         query: torch.Tensor,
         key: Optional[torch.Tensor] = None,
         offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         from vllm import _custom_ops as ops
 
         # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
@@ -202,7 +202,7 @@ def forward_xpu(
         query: torch.Tensor,
         key: Optional[torch.Tensor] = None,
         offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         from vllm._ipex_ops import ipex_ops as ops
 
         self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
@@ -232,7 +232,7 @@ def forward_hpu(
         query: torch.Tensor,
         key: Optional[torch.Tensor] = None,
         offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         from habana_frameworks.torch.hpex.kernels import (
             RotaryPosEmbeddingMode, apply_rotary_pos_emb)
         if offsets is not None:
@@ -290,7 +290,7 @@ def forward_neuron(
         query: torch.Tensor,
         key: Optional[torch.Tensor] = None,
         offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
 
         def _apply_rotary_emb_neuron(
             x: torch.Tensor,
@@ -688,7 +688,7 @@ def forward(
         query: torch.Tensor,
         key: Optional[torch.Tensor] = None,
         offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         assert key is not None
         query = query.view(*query.shape[:-1], -1, self.head_size)
         key = key.view(*key.shape[:-1], -1, self.head_size)
@@ -799,7 +799,7 @@ def forward(
         query: torch.Tensor,
         key: Optional[torch.Tensor] = None,
         offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward()."""
         assert key is not None
         query_rot = query[..., :self.rotary_dim]
@@ -929,7 +929,7 @@ def forward(
         self,
         query: torch.Tensor,
         key: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         assert key is not None
         self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
         query_ = torch.view_as_complex(query.float().reshape(
@@ -975,7 +975,7 @@ def forward(
         positions: torch.Tensor,
         query: torch.Tensor,
         key: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward().
 
         Args:
