diff --git a/csrc/cpu/pos_encoding.cpp b/csrc/cpu/pos_encoding.cpp
index 8a59e884d6c..74bb014cf39 100644
--- a/csrc/cpu/pos_encoding.cpp
+++ b/csrc/cpu/pos_encoding.cpp
@@ -9,7 +9,8 @@ void rotary_embedding_impl(
     scalar_t* __restrict__ query,  /// [batch_size, seq_len, num_heads,
                                    /// head_size] or [num_tokens, num_heads,
                                    /// head_size]
-    scalar_t* __restrict__ key,  // [batch_size, seq_len, num_kv_heads,
+    scalar_t* __restrict__ key,  // nullptr (optional) or
+                                 // [batch_size, seq_len, num_kv_heads,
                                  // head_size] or [num_tokens, num_kv_heads,
                                  // head_size]
     const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
@@ -85,10 +86,13 @@ void rotary_embedding_impl(
       compute_loop(token_head, cache_ptr, query);
     }
 
-    for (int i = 0; i < num_kv_heads; ++i) {
-      const int head_idx = i;
-      const int64_t token_head = token_idx * key_stride + head_idx * head_size;
-      compute_loop(token_head, cache_ptr, key);
+    if (key != nullptr) {
+      for (int i = 0; i < num_kv_heads; ++i) {
+        const int head_idx = i;
+        const int64_t token_head =
+            token_idx * key_stride + head_idx * head_size;
+        compute_loop(token_head, cache_ptr, key);
+      }
     }
   }
 }
@@ -100,7 +104,8 @@ void rotary_embedding_gptj_impl(
     scalar_t* __restrict__ query,  /// [batch_size, seq_len, num_heads,
                                    /// head_size] or [num_tokens, num_heads,
                                    /// head_size]
-    scalar_t* __restrict__ key,  // [batch_size, seq_len, num_kv_heads,
+    scalar_t* __restrict__ key,  // nullptr (optional) or
+                                 // [batch_size, seq_len, num_kv_heads,
                                  // head_size] or [num_tokens, num_kv_heads,
                                  // head_size]
     const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
@@ -138,6 +143,10 @@ void rotary_embedding_gptj_impl(
     }
   }
 
+  if (key == nullptr) {
+    return;
+  }
+
 #pragma omp parallel for collapse(2)
   for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
     for (int i = 0; i < num_kv_heads; ++i) {
@@ -168,13 +177,13 @@ void rotary_embedding_gptj_impl(
 };  // namespace
 
 void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
-                      torch::Tensor& key, int64_t head_size,
+                      std::optional<torch::Tensor> key, int64_t head_size,
                       torch::Tensor& cos_sin_cache, bool is_neox) {
   int num_tokens = positions.numel();
   int rot_dim = cos_sin_cache.size(1);
   int num_heads = query.size(-1) / head_size;
-  int num_kv_heads = key.size(-1) / head_size;
-  int64_t key_stride = key.stride(-2);
+  int num_kv_heads = key.has_value() ? key->size(-1) / head_size : num_heads;
+  int64_t key_stride = key.has_value() ? key->stride(-2) : 0;
   int64_t query_stride = query.stride(-2);
 
   VLLM_DISPATCH_FLOATING_TYPES(
@@ -183,15 +192,15 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
        if (is_neox) {
          rotary_embedding_impl(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
-              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
-              rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
-              head_size, num_tokens);
+              key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
+              cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
+              key_stride, num_heads, num_kv_heads, head_size, num_tokens);
        } else {
          rotary_embedding_gptj_impl(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
-              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
-              rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
-              head_size, num_tokens);
+              key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
+              cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
+              key_stride, num_heads, num_kv_heads, head_size, num_tokens);
        }
 
        CPU_KERNEL_GUARD_OUT(rotary_embedding_impl)
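The CPU path above only touches `key` when a tensor was actually passed. A rough PyTorch sketch of that guard, assuming a NeoX-style (half-rotated) layout and flat `[num_tokens, num_heads * head_size]` tensors; the helper names here are illustrative, not vLLM APIs:

```python
from typing import Optional, Tuple

import torch


def _rotate_neox(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                 rot_dim: int, head_size: int) -> torch.Tensor:
    # x: [num_tokens, num_heads * head_size]; cos/sin: [num_tokens, rot_dim // 2]
    num_tokens = x.shape[0]
    x = x.view(num_tokens, -1, head_size)
    x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    x1, x2 = x_rot.chunk(2, dim=-1)
    cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)
    x_rot = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    return torch.cat((x_rot, x_pass), dim=-1).view(num_tokens, -1)


def reference_rope(query: torch.Tensor, key: Optional[torch.Tensor],
                   cos: torch.Tensor, sin: torch.Tensor, rot_dim: int,
                   head_size: int) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    query = _rotate_neox(query, cos, sin, rot_dim, head_size)
    if key is not None:  # mirrors the new `if (key != nullptr)` branch
        key = _rotate_neox(key, cos, sin, rot_dim, head_size)
    return query, key
```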
diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp
index 7ae7e3386b4..84b2a8555cc 100644
--- a/csrc/cpu/torch_bindings.cpp
+++ b/csrc/cpu/torch_bindings.cpp
@@ -117,7 +117,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
   ops.def(
       "rotary_embedding(Tensor positions, Tensor! query,"
-      " Tensor! key, int head_size,"
+      " Tensor!? key, int head_size,"
       " Tensor cos_sin_cache, bool is_neox) -> ()");
   ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
 
diff --git a/csrc/ops.h b/csrc/ops.h
index 59ae0937604..8c4d19963e5 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -86,13 +86,13 @@ void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
                                       std::optional<torch::Tensor> residual);
 
 void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
-                      torch::Tensor& key, int64_t head_size,
+                      std::optional<torch::Tensor> key, int64_t head_size,
                       torch::Tensor& cos_sin_cache, bool is_neox);
 
 void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
-                              torch::Tensor& key, int64_t head_size,
-                              torch::Tensor& cos_sin_cache, bool is_neox,
-                              int64_t rot_dim,
+                              std::optional<torch::Tensor> key,
+                              int64_t head_size, torch::Tensor& cos_sin_cache,
+                              bool is_neox, int64_t rot_dim,
                               torch::Tensor& cos_sin_cache_offsets);
 
 void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
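With `std::optional<torch::Tensor>` in the declarations above, the host wrappers need placeholder values when no key is given. A small Python mirror of that fallback (hypothetical helper, not part of vLLM):

```python
from typing import Optional, Tuple

import torch


def derive_key_layout(query: torch.Tensor, key: Optional[torch.Tensor],
                      head_size: int) -> Tuple[int, int]:
    num_heads = query.size(-1) // head_size
    # When key is absent the kernels never dereference it, so any consistent
    # placeholder works; the C++ wrappers use num_kv_heads=num_heads, stride 0.
    num_kv_heads = key.size(-1) // head_size if key is not None else num_heads
    key_stride = key.stride(-2) if key is not None else 0
    return num_kv_heads, key_stride
```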
diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu
index c085d31a3e9..ef6dd1c0978 100644
--- a/csrc/pos_encoding_kernels.cu
+++ b/csrc/pos_encoding_kernels.cu
@@ -38,7 +38,8 @@ inline __device__ void apply_rotary_embedding(
     scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                    // head_size] or [num_tokens, num_heads,
                                    // head_size]
-    scalar_t* __restrict__ key,  // [batch_size, seq_len, num_kv_heads,
+    scalar_t* __restrict__ key,  // nullptr or
+                                 // [batch_size, seq_len, num_kv_heads,
                                  // head_size] or [num_tokens, num_kv_heads,
                                  // head_size]
     const scalar_t* cache_ptr, const int head_size, const int num_heads,
@@ -57,13 +58,15 @@ inline __device__ void apply_rotary_embedding(
         query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
   }
 
-  const int nk = num_kv_heads * embed_dim;
-  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
-    const int head_idx = i / embed_dim;
-    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
-    const int rot_offset = i % embed_dim;
-    apply_token_rotary_embedding<scalar_t, IS_NEOX>(
-        key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
+  if (key != nullptr) {
+    const int nk = num_kv_heads * embed_dim;
+    for (int i = threadIdx.x; i < nk; i += blockDim.x) {
+      const int head_idx = i / embed_dim;
+      const int64_t token_head = token_idx * key_stride + head_idx * head_size;
+      const int rot_offset = i % embed_dim;
+      apply_token_rotary_embedding<scalar_t, IS_NEOX>(
+          key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
+    }
   }
 }
 
@@ -74,7 +77,8 @@ __global__ void rotary_embedding_kernel(
     scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                    // head_size] or [num_tokens, num_heads,
                                    // head_size]
-    scalar_t* __restrict__ key,  // [batch_size, seq_len, num_kv_heads,
+    scalar_t* __restrict__ key,  // nullptr or
+                                 // [batch_size, seq_len, num_kv_heads,
                                  // head_size] or [num_tokens, num_kv_heads,
                                  // head_size]
     const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
@@ -98,7 +102,8 @@ __global__ void batched_rotary_embedding_kernel(
     scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                    // head_size] or [num_tokens, num_heads,
                                    // head_size]
-    scalar_t* __restrict__ key,  // [batch_size, seq_len, num_kv_heads,
+    scalar_t* __restrict__ key,  // nullptr or
+                                 // [batch_size, seq_len, num_kv_heads,
                                  // head_size] or [num_tokens, num_kv_heads,
                                  // head_size]
     const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
@@ -127,10 +132,12 @@ void rotary_embedding(
                           // [num_tokens, num_heads * head_size] or
                           // [batch_size, seq_len, num_heads, head_size] or
                           // [num_tokens, num_heads, head_size]
-    torch::Tensor& key,  // [batch_size, seq_len, num_kv_heads * head_size] or
-                         // [num_tokens, num_kv_heads * head_size] or
-                         // [batch_size, seq_len, num_heads, head_size] or
-                         // [num_tokens, num_heads, head_size]
+    std::optional<torch::Tensor> key,
+    // null or
+    // [batch_size, seq_len, num_kv_heads * head_size] or
+    // [num_tokens, num_kv_heads * head_size] or
+    // [batch_size, seq_len, num_heads, head_size] or
+    // [num_tokens, num_heads, head_size]
     int64_t head_size,
     torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
     bool is_neox) {
@@ -138,40 +145,40 @@ void rotary_embedding(
   int64_t num_tokens = positions.numel();
   int positions_ndim = positions.dim();
 
-  // Make sure num_tokens dim is consistent across positions, query, and key.
+  // Make sure num_tokens dim is consistent across positions, query, and key
   TORCH_CHECK(
       positions_ndim == 1 || positions_ndim == 2,
       "positions must have shape [num_tokens] or [batch_size, seq_len]");
   if (positions_ndim == 1) {
-    TORCH_CHECK(
-        query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
-        "query, key and positions must have the same number of tokens");
+    TORCH_CHECK(query.size(0) == positions.size(0) &&
+                    (!key.has_value() || key->size(0) == positions.size(0)),
+                "query, key and positions must have the same number of tokens");
   }
   if (positions_ndim == 2) {
     TORCH_CHECK(
         query.size(0) == positions.size(0) &&
-            key.size(0) == positions.size(0) &&
+            (!key.has_value() || key->size(0) == positions.size(0)) &&
             query.size(1) == positions.size(1) &&
-            key.size(1) == positions.size(1),
+            (!key.has_value() || key->size(1) == positions.size(1)),
         "query, key and positions must have the same batch_size and seq_len");
   }
 
   // Make sure head_size is valid for query and key
   // hidden_size = num_heads * head_size
   int query_hidden_size = query.numel() / num_tokens;
-  int key_hidden_size = key.numel() / num_tokens;
+  int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
   TORCH_CHECK(query_hidden_size % head_size == 0);
   TORCH_CHECK(key_hidden_size % head_size == 0);
 
   // Make sure query and key have consistent number of heads
   int num_heads = query_hidden_size / head_size;
-  int num_kv_heads = key_hidden_size / head_size;
+  int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
   TORCH_CHECK(num_heads % num_kv_heads == 0);
 
   int rot_dim = cos_sin_cache.size(1);
   int seq_dim_idx = positions_ndim - 1;
   int64_t query_stride = query.stride(seq_dim_idx);
-  int64_t key_stride = key.stride(seq_dim_idx);
+  int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
 
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
@@ -181,15 +188,16 @@ void rotary_embedding(
     if (is_neox) {
       vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
           positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
-          key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
-          query_stride, key_stride, num_heads, num_kv_heads, head_size);
+          key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
+          cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, key_stride,
+          num_heads, num_kv_heads, head_size);
     } else {
       vllm::rotary_embedding_kernel<scalar_t, false>
           <<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
-              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
-              rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
-              head_size);
+              key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
+              cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
+              key_stride, num_heads, num_kv_heads, head_size);
     }
   });
 }
@@ -204,10 +212,12 @@ void batched_rotary_embedding(
                           // [num_tokens, num_heads * head_size] or
                           // [batch_size, seq_len, num_heads, head_size] or
                           // [num_tokens, num_heads, head_size]
-    torch::Tensor& key,  // [batch_size, seq_len, num_kv_heads * head_size] or
-                         // [num_tokens, num_kv_heads * head_size] or
-                         // [batch_size, seq_len, num_heads, head_size] or
-                         // [num_tokens, num_heads, head_size]
+    std::optional<torch::Tensor>
+        key,  // null or
+              // [batch_size, seq_len, num_kv_heads * head_size] or
+              // [num_tokens, num_kv_heads * head_size] or
+              // [batch_size, seq_len, num_heads, head_size] or
+              // [num_tokens, num_heads, head_size]
     int64_t head_size,
     torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
     bool is_neox, int64_t rot_dim,
@@ -221,38 +231,38 @@ void batched_rotary_embedding(
       "cos_sin_cache_offsets");
 
   int positions_ndim = positions.dim();
-  // Make sure num_tokens dim is consistent across positions, query, and key.
+  // Make sure num_tokens dim is consistent across positions, query, and key
   TORCH_CHECK(
       positions_ndim == 1 || positions_ndim == 2,
       "positions must have shape [num_tokens] or [batch_size, seq_len]");
   if (positions_ndim == 1) {
-    TORCH_CHECK(
-        query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
-        "query, key and positions must have the same number of tokens");
+    TORCH_CHECK(query.size(0) == positions.size(0) &&
+                    (!key.has_value() || key->size(0) == positions.size(0)),
+                "query, key and positions must have the same number of tokens");
   }
   if (positions_ndim == 2) {
     TORCH_CHECK(
         query.size(0) == positions.size(0) &&
-            key.size(0) == positions.size(0) &&
+            (!key.has_value() || key->size(0) == positions.size(0)) &&
             query.size(1) == positions.size(1) &&
-            key.size(1) == positions.size(1),
+            (!key.has_value() || key->size(1) == positions.size(1)),
         "query, key and positions must have the same batch_size and seq_len");
   }
 
   // Make sure head_size is valid for query and key
   int query_hidden_size = query.numel() / num_tokens;
-  int key_hidden_size = key.numel() / num_tokens;
+  int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
   TORCH_CHECK(query_hidden_size % head_size == 0);
   TORCH_CHECK(key_hidden_size % head_size == 0);
 
   // Make sure query and key have concistent number of heads
   int num_heads = query_hidden_size / head_size;
-  int num_kv_heads = key_hidden_size / head_size;
+  int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
  TORCH_CHECK(num_heads % num_kv_heads == 0);
 
   int seq_dim_idx = positions_ndim - 1;
   int64_t query_stride = query.stride(seq_dim_idx);
-  int64_t key_stride = key.stride(seq_dim_idx);
+  int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
 
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
@@ -263,14 +273,16 @@ void batched_rotary_embedding(
       vllm::batched_rotary_embedding_kernel<scalar_t, true>
          <<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
-              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
+              key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
+              cos_sin_cache.data_ptr<scalar_t>(),
              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
              key_stride, num_heads, num_kv_heads, head_size);
     } else {
       vllm::batched_rotary_embedding_kernel<scalar_t, false>
          <<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
-              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
+              key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
+              cos_sin_cache.data_ptr<scalar_t>(),
              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
              key_stride, num_heads, num_kv_heads, head_size);
    }
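For reference, the two rotation layouts that these kernels are instantiated for (`IS_NEOX` true/false), written as a minimal PyTorch sketch over the rotary slice of a tensor; illustrative only, with `cos`/`sin` broadcastable against the half-width slice:

```python
import torch


def rotate_neox_style(x: torch.Tensor, cos: torch.Tensor,
                      sin: torch.Tensor) -> torch.Tensor:
    # IS_NEOX = true: the rotary dims form two contiguous halves.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)


def rotate_gptj_style(x: torch.Tensor, cos: torch.Tensor,
                      sin: torch.Tensor) -> torch.Tensor:
    # IS_NEOX = false: even/odd dims form interleaved pairs.
    x1, x2 = x[..., ::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., ::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x2 * cos + x1 * sin
    return out
```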
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index f59b42d88c6..f207a4c2fa3 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -176,7 +176,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
   ops.def(
       "rotary_embedding(Tensor positions, Tensor! query,"
-      " Tensor! key, int head_size,"
+      " Tensor!? key, int head_size,"
       " Tensor cos_sin_cache, bool is_neox) -> ()");
   ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
 
@@ -184,7 +184,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // (supports multiple loras).
   ops.def(
       "batched_rotary_embedding(Tensor positions, Tensor! query,"
-      " Tensor! key, int head_size,"
+      " Tensor!? key, int head_size,"
       " Tensor cos_sin_cache, bool is_neox,"
       " int rot_dim,"
       " Tensor cos_sin_cache_offsets) -> ()");
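`Tensor!?` marks an argument that is both mutable and optional. A toy registration with the same schema shape, under a hypothetical `mylib` namespace, sketches the dispatcher semantics (this is not how vLLM registers its extension):

```python
from typing import Optional

import torch

torch.library.define(
    "mylib::rotary_embedding",
    "(Tensor positions, Tensor! query, Tensor!? key, int head_size,"
    " Tensor cos_sin_cache, bool is_neox) -> ()")


@torch.library.impl("mylib::rotary_embedding", "CPU")
def _rotary_embedding_cpu(positions: torch.Tensor, query: torch.Tensor,
                          key: Optional[torch.Tensor], head_size: int,
                          cos_sin_cache: torch.Tensor, is_neox: bool) -> None:
    # `query` is always written in place; `key` may be None and must then be
    # left untouched, just like the CPU/CUDA kernels above.
    query.add_(0.0)  # placeholder in-place update
    if key is not None:
        key.add_(0.0)
```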
diff --git a/tests/kernels/core/test_pos_encoding.py b/tests/kernels/core/test_pos_encoding.py
index 2b7bf755ec2..d81c7487b88 100644
--- a/tests/kernels/core/test_pos_encoding.py
+++ b/tests/kernels/core/test_pos_encoding.py
@@ -21,6 +21,7 @@
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ]
+USE_KEY = [True, False]
 
 
 def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
@@ -46,6 +47,7 @@ def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("use_key", USE_KEY)
 @torch.inference_mode()
 def test_rotary_embedding(
     is_neox_style: bool,
@@ -58,6 +60,7 @@ def test_rotary_embedding(
     dtype: torch.dtype,
     seed: int,
     device: str,
+    use_key: bool,
     max_position: int = 8192,
     base: int = 10000,
 ) -> None:
@@ -74,7 +77,7 @@ def test_rotary_embedding(
     positions = torch.randint(0, max_position, (batch_size, seq_len))
     query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
     query = torch.randn(query_shape, dtype=dtype)
-    key = torch.randn_like(query)
+    key = torch.randn_like(query) if use_key else None
 
     # NOTE(woosuk): The reference implementation should be executed first
     # because the custom kernel is in-place.
@@ -85,10 +88,14 @@ def test_rotary_embedding(
                                ref_query,
                                atol=get_default_atol(out_query),
                                rtol=get_default_rtol(out_query))
-    torch.testing.assert_close(out_key,
-                               ref_key,
-                               atol=get_default_atol(out_key),
-                               rtol=get_default_rtol(out_key))
+    if use_key:
+        torch.testing.assert_close(out_key,
+                                   ref_key,
+                                   atol=get_default_atol(out_key),
+                                   rtol=get_default_rtol(out_key))
+    else:
+        assert ref_key is None and out_key is None, \
+            "expected returned key to be None"
 
 
 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@@ -101,6 +108,7 @@ def test_rotary_embedding(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("use_key", USE_KEY)
 @torch.inference_mode()
 def test_batched_rotary_embedding(
     is_neox_style: bool,
@@ -113,6 +121,7 @@ def test_batched_rotary_embedding(
     dtype: torch.dtype,
     seed: int,
     device: str,
+    use_key: bool,
     max_position: int = 8192,
     base: int = 10000,
 ) -> None:
@@ -129,7 +138,7 @@ def test_batched_rotary_embedding(
     positions = torch.randint(0, max_position, (batch_size, seq_len))
     query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
     query = torch.randn(query_shape, dtype=dtype)
-    key = torch.randn_like(query)
+    key = torch.randn_like(query) if use_key else None
 
     # NOTE(woosuk): The reference implementation should be executed first
     # because the custom kernel is in-place.
@@ -145,10 +154,14 @@ def test_batched_rotary_embedding(
                                ref_query,
                                atol=get_default_atol(out_query),
                                rtol=get_default_rtol(out_query))
-    torch.testing.assert_close(out_key,
-                               ref_key,
-                               atol=get_default_atol(out_key),
-                               rtol=get_default_rtol(out_key))
+    if use_key:
+        torch.testing.assert_close(out_key,
+                                   ref_key,
+                                   atol=get_default_atol(out_key),
+                                   rtol=get_default_rtol(out_key))
+    else:
+        assert ref_key is None and out_key is None, \
+            "expected returned key to be None"
 
 
 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@@ -160,6 +173,7 @@ def test_batched_rotary_embedding(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("use_key", USE_KEY)
 @torch.inference_mode()
 def test_batched_rotary_embedding_multi_lora(
     is_neox_style: bool,
@@ -171,6 +185,7 @@ def test_batched_rotary_embedding_multi_lora(
     dtype: torch.dtype,
     seed: int,
     device: str,
+    use_key: bool,
     max_position: int = 8192,
     base: int = 10000,
 ) -> None:
@@ -190,7 +205,7 @@ def test_batched_rotary_embedding_multi_lora(
                         seq_len,
                         num_heads * head_size,
                         dtype=dtype)
-    key = torch.randn_like(query)
+    key = torch.randn_like(query) if use_key else None
 
     offset_map = torch.tensor(
         list(
@@ -214,10 +229,14 @@ def test_batched_rotary_embedding_multi_lora(
                                ref_query,
                                atol=get_default_atol(out_query),
                                rtol=get_default_rtol(out_query))
-    torch.testing.assert_close(out_key,
-                               ref_key,
-                               atol=get_default_atol(out_key),
-                               rtol=get_default_rtol(out_key))
+    if use_key:
+        torch.testing.assert_close(out_key,
+                                   ref_key,
+                                   atol=get_default_atol(out_key),
+                                   rtol=get_default_rtol(out_key))
+    else:
+        assert ref_key is None and out_key is None, \
+            "expected returned key to be None"
 
 
 @torch.inference_mode()
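The pattern these tests follow, distilled: build the rope layer, run the native reference first (the custom kernel mutates `query`/`key` in place), then compare, expecting `None` to round-trip when no key is passed. A hedged sketch assuming `get_rope` from `vllm.model_executor.layers.rotary_embedding`:

```python
import torch

from vllm.model_executor.layers.rotary_embedding import get_rope


def check_rope(use_key: bool, device: str = "cuda") -> None:
    rope = get_rope(head_size=64, rotary_dim=64, max_position=8192, base=10000,
                    is_neox_style=True).to(device)
    positions = torch.randint(0, 8192, (2, 16), device=device)
    query = torch.randn(2, 16, 8 * 64, device=device)
    key = torch.randn_like(query) if use_key else None

    # Reference first: the custom kernel updates query/key in place.
    ref_query, ref_key = rope.forward_native(positions, query, key)
    out_query, out_key = rope(positions, query, key)

    torch.testing.assert_close(out_query, ref_query, atol=1e-2, rtol=1e-2)
    if use_key:
        torch.testing.assert_close(out_key, ref_key, atol=1e-2, rtol=1e-2)
    else:
        assert ref_key is None and out_key is None
```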
diff --git a/tests/kernels/core/test_rotary_embedding.py b/tests/kernels/core/test_rotary_embedding.py
index c497dd90edd..4e54861005f 100644
--- a/tests/kernels/core/test_rotary_embedding.py
+++ b/tests/kernels/core/test_rotary_embedding.py
@@ -15,7 +15,7 @@
 def rotary_embedding_opcheck(rot,
                              positions: torch.Tensor,
                              query: torch.Tensor,
-                             key: torch.Tensor,
+                             key: Optional[torch.Tensor] = None,
                              offsets: Optional[torch.Tensor] = None):
     cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype)
 
@@ -37,9 +37,10 @@ def rotary_embedding_opcheck(rot,
 @pytest.mark.parametrize("rotary_dim", [32])
 @pytest.mark.parametrize("head_size", [32, 108])
 @pytest.mark.parametrize("seq_len", [11, 1024])
+@pytest.mark.parametrize("use_key", [True, False])
 def test_rotary_embedding_opcheck(dist_init, device, max_position,
                                   is_neox_style, rotary_dim, head_size,
-                                  seq_len):
+                                  seq_len, use_key):
     batch_size = 1
     base = 10000
     num_heads = 7
@@ -54,7 +55,7 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
                         num_heads * head_size,
                         dtype=torch.float32,
                         device=device)
-    key = torch.randn_like(query)
+    key = torch.randn_like(query) if use_key else None
     rotary_embedding_opcheck(rot, positions, query, key)
 
     offsets = torch.zeros(batch_size * seq_len,
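For completeness, the schema/fake-tensor checks can also be driven directly against the registered op with `key=None`; a sketch assuming the compiled `_C` extension is loaded and PyTorch's `torch.library.opcheck` is available (the helper name here is hypothetical):

```python
import torch
from torch.library import opcheck


def opcheck_rotary_embedding_no_key(rot, positions: torch.Tensor,
                                    query: torch.Tensor) -> None:
    cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype)
    opcheck(torch.ops._C.rotary_embedding.default,
            (positions, query, None, rot.head_size, cos_sin_cache,
             rot.is_neox_style),
            test_utils=("test_schema", "test_faketensor"))
```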
diff --git a/tests/neuron/1_core/test_rotary_embedding.py b/tests/neuron/1_core/test_rotary_embedding.py
index c015b80bd47..da57631fcfc 100644
--- a/tests/neuron/1_core/test_rotary_embedding.py
+++ b/tests/neuron/1_core/test_rotary_embedding.py
@@ -11,14 +11,16 @@
 
 
 @pytest.mark.parametrize(
-    "max_position,is_neox_style,rotary_dim,head_size,seq_len", [
-        (16, False, 32, 32, 1024),
-        (16, False, 32, 128, 1024),
-        (16, True, 32, 32, 1024),
-        (16, True, 32, 128, 1024),
+    "max_position,is_neox_style,rotary_dim,head_size,seq_len,use_key", [
+        (16, False, 32, 32, 1024, True),
+        (16, False, 32, 128, 1024, True),
+        (16, True, 32, 32, 1024, True),
+        (16, True, 32, 128, 1024, True),
+        (16, False, 32, 128, 1024, False),
+        (16, True, 32, 128, 1024, False),
     ])
 def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim,
-                                  head_size, seq_len):
+                                  head_size, seq_len, use_key):
     import torch_xla.core.xla_model as xm
 
     device = xm.xla_device()
@@ -40,19 +42,26 @@ def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim,
                         num_heads * head_size,
                         dtype=torch.float32,
                         device="cpu")
-    key = torch.randn_like(query)
-
+    key = torch.randn_like(query) if use_key else None
     assert positions.is_cpu, \
         "reference input tensor is expected to be CPU tensor."
     ref_query, ref_key = rot.to(device="cpu").forward_native(
         positions, query, key)
     out_query, out_key = rot.to(device=device).forward_neuron(
         positions.to(device=device), query.to(device=device),
-        key.to(device=device))
-    assert out_query.is_xla and out_key.is_xla, \
-        "output tensor is expected to be XLA tensor"
+        key.to(device=device) if key is not None else None)
+    if use_key:
+        assert out_query.is_xla and out_key.is_xla, \
+            "output tensor is expected to be XLA tensor"
+        torch.testing.assert_close(out_key.cpu(),
+                                   ref_key,
+                                   atol=1e-2,
+                                   rtol=1e-2)
+    else:
+        assert out_key is None, "expected returned key to be None"
+        assert out_query.is_xla, \
+            "output tensor is expected to be XLA tensor"
     torch.testing.assert_close(out_query.cpu(),
                                ref_query,
                                atol=1e-2,
                                rtol=1e-2)
-    torch.testing.assert_close(out_key.cpu(), ref_key, atol=1e-2, rtol=1e-2)
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 44377ccb295..463cc1a8c64 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -153,34 +153,36 @@ def merge_attn_states(output: torch.Tensor,
 def rotary_embedding(
     positions: torch.Tensor,
     query: torch.Tensor,
-    key: torch.Tensor,
+    key: Optional[torch.Tensor],
     head_size: int,
     cos_sin_cache: torch.Tensor,
     is_neox: bool,
 ) -> None:
     # TODO: Remove this contiguous call when the kernel is updated to support tensor slices
     query_contiguous = query.contiguous()
-    key_contiguous = key.contiguous()
+    key_contiguous = key.contiguous() if key is not None else None
     torch.ops._C.rotary_embedding(positions, query_contiguous, key_contiguous,
                                   head_size, cos_sin_cache, is_neox)
     query.copy_(query_contiguous)
-    key.copy_(key_contiguous)
+    if key is not None:
+        key.copy_(key_contiguous)
 
 
 def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
-                             key: torch.Tensor, head_size: int,
+                             key: Optional[torch.Tensor], head_size: int,
                              cos_sin_cache: torch.Tensor, is_neox: bool,
                              rot_dim: int,
                              cos_sin_cache_offsets: torch.Tensor) -> None:
     # TODO: Remove this contiguous call when the kernel is updated to support tensor slices
     query_contiguous = query.contiguous()
-    key_contiguous = key.contiguous()
+    key_contiguous = key.contiguous() if key is not None else None
     torch.ops._C.batched_rotary_embedding(positions, query_contiguous,
                                           key_contiguous, head_size,
                                           cos_sin_cache, is_neox, rot_dim,
                                           cos_sin_cache_offsets)
     query.copy_(query_contiguous)
-    key.copy_(key_contiguous)
+    if key is not None:
+        key.copy_(key_contiguous)
 
 
 # layer norm ops
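What callers gain from the optional key: a layer that reuses another layer's KV cache (the cross-layer KV sharing case mentioned in the comment below) only needs its query rotated. A hedged usage sketch around a `RotaryEmbedding` built with `get_rope`; the `shares_kv` flag and helper name are hypothetical:

```python
from typing import Optional, Tuple

import torch


def apply_rope(rotary_emb, positions: torch.Tensor, q: torch.Tensor,
               k: Optional[torch.Tensor],
               shares_kv: bool) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    if shares_kv:
        # No key of its own to rotate: skip allocating and rotating a
        # throwaway key tensor and let the kernel ignore it.
        q, _ = rotary_emb(positions, q, None)
        return q, None
    return rotary_emb(positions, q, k)
```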
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 523250c3080..32c2a2859b4 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -138,9 +138,9 @@ def forward_native(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
-        key: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
         offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         """A PyTorch-native implementation of forward()."""
         if offsets is not None:
             positions = positions + offsets
@@ -157,22 +157,24 @@ def forward_native(
                                              self.is_neox_style)
         query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
 
-        key_shape = key.shape
-        key = key.view(num_tokens, -1, self.head_size)
-        key_rot = key[..., :self.rotary_dim]
-        key_pass = key[..., self.rotary_dim:]
-        key_rot = _apply_rotary_emb_torch(key_rot, cos, sin,
-                                          self.is_neox_style)
-        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
+        # key may be None in some cases, e.g. cross-layer KV sharing
+        if key is not None:
+            key_shape = key.shape
+            key = key.view(num_tokens, -1, self.head_size)
+            key_rot = key[..., :self.rotary_dim]
+            key_pass = key[..., self.rotary_dim:]
+            key_rot = _apply_rotary_emb_torch(key_rot, cos, sin,
+                                              self.is_neox_style)
+            key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
     def forward_cuda(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
-        key: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
         offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         from vllm import _custom_ops as ops
 
         # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
@@ -198,32 +200,39 @@ def forward_xpu(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
-        key: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
         offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         from vllm._ipex_ops import ipex_ops as ops
 
         self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
                                                    dtype=query.dtype)
         # ops.rotary_embedding()/batched_rotary_embedding()
         # are in-place operations that update the query and key tensors.
-        if offsets is not None:
-            ops.batched_rotary_embedding(positions, query, key, self.head_size,
-                                         self.cos_sin_cache,
-                                         self.is_neox_style, self.rotary_dim,
-                                         offsets)
+        if key is None:
+            # XPU kernel doesn't support key=None so fall back to native impl
+            # TODO(sarckk): add support for optional key in
+            # ipex.llm.functional.rotary_embedding_batched
+            return self.forward_native(positions, query, key, offsets)
         else:
-            ops.rotary_embedding(positions, query, key, self.head_size,
-                                 self.cos_sin_cache, self.is_neox_style)
+            if offsets is not None:
+                ops.batched_rotary_embedding(positions, query, key,
+                                             self.head_size,
+                                             self.cos_sin_cache,
+                                             self.is_neox_style,
+                                             self.rotary_dim, offsets)
+            else:
+                ops.rotary_embedding(positions, query, key, self.head_size,
+                                     self.cos_sin_cache, self.is_neox_style)
         return query, key
 
     def forward_hpu(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
-        key: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
         offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         from habana_frameworks.torch.hpex.kernels import (
             RotaryPosEmbeddingMode, apply_rotary_pos_emb)
         if offsets is not None:
@@ -265,21 +274,23 @@ def forward_hpu(
                                         rope_mode)
         query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
 
-        key_shape = key.shape
-        key = key.view(num_tokens, -1, self.head_size)
-        key_rot = key[..., :self.rotary_dim]
-        key_pass = key[..., self.rotary_dim:]
-        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
-        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
+        if key is not None:
+            key_shape = key.shape
+            key = key.view(num_tokens, -1, self.head_size)
+            key_rot = key[..., :self.rotary_dim]
+            key_pass = key[..., self.rotary_dim:]
+            key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0,
+                                           rope_mode)
+            key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
     def forward_neuron(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
-        key: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
         offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
 
         def _apply_rotary_emb_neuron(
             x: torch.Tensor,
@@ -319,14 +330,16 @@ def _apply_rotary_emb_neuron(
         query_shape = query.shape
         query = query.view(num_tokens, -1, self.head_size)
-        key_shape = key.shape
-        key = key.view(num_tokens, -1, self.head_size)
+        if key is not None:
+            key_shape = key.shape
+            key = key.view(num_tokens, -1, self.head_size)
 
         if self.rotary_dim == self.head_size:
             query = _apply_rotary_emb(query, cos, sin, self.is_neox_style)
             query = query.reshape(query_shape)
-            key = _apply_rotary_emb(key, cos, sin, self.is_neox_style)
-            key = key.reshape(key_shape)
+            if key is not None:
+                key = _apply_rotary_emb(key, cos, sin, self.is_neox_style)
+                key = key.reshape(key_shape)
         else:
             head_size = query.shape[-1]
             query_reshaped = query.view(-1, head_size)
@@ -339,14 +352,15 @@ def _apply_rotary_emb_neuron(
             query = torch.cat((query_rot, query_pass),
                               dim=-1).reshape(query_shape)
 
-            key_reshaped = key.view(-1, head_size)
-            key_pass = key_reshaped[:, self.rotary_dim:].view(
-                *key.shape[:-1], head_size - self.rotary_dim)
-            key_rot = key_reshaped[:, :self.rotary_dim].view(
-                *key.shape[:-1], self.rotary_dim)
-            key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin,
-                                               self.is_neox_style)
-            key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
+            if key is not None:
+                key_reshaped = key.view(-1, head_size)
+                key_pass = key_reshaped[:, self.rotary_dim:].view(
+                    *key.shape[:-1], head_size - self.rotary_dim)
+                key_rot = key_reshaped[:, :self.rotary_dim].view(
+                    *key.shape[:-1], self.rotary_dim)
+                key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin,
+                                                   self.is_neox_style)
+                key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
     def extra_repr(self) -> str:
@@ -672,9 +686,10 @@ def forward(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
-        key: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
         offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        assert key is not None
         query = query.view(*query.shape[:-1], -1, self.head_size)
         key = key.view(*key.shape[:-1], -1, self.head_size)
 
@@ -782,10 +797,11 @@ def forward(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
-        key: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
         offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward()."""
+        assert key is not None
         query_rot = query[..., :self.rotary_dim]
         key_rot = key[..., :self.rotary_dim]
         if self.rotary_dim < self.head_size:
@@ -912,8 +928,9 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
     def forward(
         self,
         query: torch.Tensor,
-        key: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        key: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        assert key is not None
         self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
         query_ = torch.view_as_complex(query.float().reshape(
             *query.shape[:-1], -1, 2))
@@ -957,8 +974,8 @@ def forward(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
-        key: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        key: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward().
 
         Args:
@@ -969,6 +986,7 @@ def forward(
             key: [num_tokens, num_kv_heads * head_size]
         """
         assert positions.ndim == 1 or positions.ndim == 2
+        assert key is not None
         num_tokens = positions.shape[-1]
         cos_sin = self.cos_sin_cache[positions]
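Several rope variants in this file still require a key, hence the new `assert key is not None` lines in their forwards. A hypothetical guard that surfaces that limitation as a clearer error, should a caller pass `key=None` to one of them:

```python
from typing import Optional

import torch


def require_key(key: Optional[torch.Tensor], impl_name: str) -> torch.Tensor:
    # Sketch only: the diff uses bare asserts; a dedicated check could explain
    # which implementations accept key=None and which do not.
    if key is None:
        raise ValueError(
            f"{impl_name} does not support key=None; pass a key tensor or use "
            "the base RotaryEmbedding implementation.")
    return key
```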