@@ -513,43 +513,50 @@ def functional_linear(self, weights, bias=None):
     res += bias
   return res
 
-@register_function(torch.ops.xla.dynamo_set_buffer_donor_)
-def _dynamo_set_buffer_donor(self, donor):
+try:
+  # TODO: Currently the following ops are wrapped in the
+  # try/except block because torch.ops.xla is not in the torch
+  # ops registry. Either we import torch_xla at the upper level,
+  # or modify the register_function to support this.
+  @register_function(torch.ops.xla.dynamo_set_buffer_donor_)
+  def _dynamo_set_buffer_donor(self, donor):
+    pass
+
+  @register_function(torch.ops.xla.ragged_paged_attention)
+  def _ragged_paged_attention(
+      q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
+      kv_pages: jax.Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
+      kv_lens: jax.Array,  # i32[max_num_seqs]
+      page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
+      cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
+      num_seqs: jax.Array,  # i32[1]
+      use_kernel: bool = True,
+      sm_scale: float = 1.0,
+      sliding_window: int | None = None,
+      soft_cap: float | None = None,
+      mask_value: float | None = None,
+      num_kv_pages_per_block: int | None = None,
+      num_queries_per_block: int | None = None,
+      vmem_limit_bytes: int | None = None,
+  ):
+
+    from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as ragged_paged_attention_kernel
+    return ragged_paged_attention_kernel(
+        q=q,
+        kv_pages=kv_pages,
+        kv_lens=kv_lens,
+        page_indices=page_indices,
+        cu_q_lens=cu_q_lens,
+        num_seqs=num_seqs,
+        sm_scale=sm_scale,
+        sliding_window=sliding_window,
+        soft_cap=soft_cap,
+        mask_value=mask_value,
+        num_kv_pages_per_block=num_kv_pages_per_block,
+        num_queries_per_block=num_queries_per_block,
+        vmem_limit_bytes=vmem_limit_bytes,
+    )
+except Exception as e:
   pass
 
-@register_function(torch.ops.xla.ragged_paged_attention)
-def _ragged_paged_attention(
-    q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
-    kv_pages: jax.Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
-    kv_lens: jax.Array,  # i32[max_num_seqs]
-    page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
-    cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
-    num_seqs: jax.Array,  # i32[1]
-    use_kernel: bool = True,
-    sm_scale: float = 1.0,
-    sliding_window: int | None = None,
-    soft_cap: float | None = None,
-    mask_value: float | None = None,
-    num_kv_pages_per_block: int | None = None,
-    num_queries_per_block: int | None = None,
-    vmem_limit_bytes: int | None = None,
-):
-
-  from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as ragged_paged_attention_kernel
-  return ragged_paged_attention_kernel(
-      q=q,
-      kv_pages=kv_pages,
-      kv_lens=kv_lens,
-      page_indices=page_indices,
-      cu_q_lens=cu_q_lens,
-      num_seqs=num_seqs,
-      sm_scale=sm_scale,
-      sliding_window=sliding_window,
-      soft_cap=soft_cap,
-      mask_value=mask_value,
-      num_kv_pages_per_block=num_kv_pages_per_block,
-      num_queries_per_block=num_queries_per_block,
-      vmem_limit_bytes=vmem_limit_bytes,
-    )
-
 
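The TODO in the hunk above explains the broad try/except: the torch.ops.xla custom ops only exist once torch_xla has been imported and has registered them, so referencing torch.ops.xla.dynamo_set_buffer_donor_ at module import time would otherwise raise when torch_xla is absent. Below is a minimal, illustrative sketch of a narrower guard; it is not the committed change, and it assumes the register_function decorator defined in this module is in scope. It probes for each op with getattr, so only a genuinely missing op is skipped while bugs inside the kernel wrappers would still surface.

import torch

# torch.ops creates the "xla" namespace lazily, and getattr falls back to None
# when the op itself has not been registered (e.g. torch_xla is not installed
# or not yet imported).
_xla_donor_op = getattr(torch.ops.xla, "dynamo_set_buffer_donor_", None)

if _xla_donor_op is not None:
  # Register the no-op lowering only when the op exists, mirroring the no-op
  # body in the diff above; the same pattern would apply to
  # torch.ops.xla.ragged_paged_attention.
  @register_function(_xla_donor_op)
  def _dynamo_set_buffer_donor(self, donor):
    pass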