@@ -1371,7 +1371,7 @@ def _maybe_convert_scalar_types_to_dtypes(
13711371class Work (_Work ):
13721372 def __init__ (self ) -> None :
13731373 super ().__init__ ()
1374- self .event = torch .cuda .Event ()
1374+ self .event = torch .xpu .Event ()
13751375 self .event .record ()
13761376
13771377 def wait (self , timeout : timedelta = timedelta (seconds = 0 )) -> bool :
@@ -1416,7 +1416,7 @@ def _low_contention_all_gather_meta(
14161416 group_size = c10d ._get_group_size_by_name (group_name )
14171417 return tensor .new_empty (tensor .shape [0 ] * group_size , * tensor .shape [1 :])
14181418
1419-
1419+ @ torch . library . impl ( lib , "_low_contention_all_gather" , "XPU" )
14201420@torch .library .impl (lib , "_low_contention_all_gather" , "CUDA" )
14211421def _low_contention_all_gather (
14221422 tensor : torch .Tensor ,
@@ -1449,7 +1449,7 @@ def _low_contention_all_gather(
14491449 output = tensor .new_empty (tensor .shape [0 ] * world_size , * tensor .shape [1 :])
14501450 chunks = output .chunk (world_size )
14511451
1452- _get_backend_stream ().wait_stream (torch .cuda .current_stream ())
1452+ _get_backend_stream ().wait_stream (torch .xpu .current_stream ())
14531453 with _get_backend_stream ():
14541454 if not input_is_symm_mem :
14551455 local_buf = symm_mem .get_buffer (rank , tensor .shape , tensor .dtype )
@@ -1487,7 +1487,7 @@ def _low_contention_reduce_scatter_with_symm_mem_input(
14871487 a2a_res = torch .empty_like (tensor )
14881488 chunks = a2a_res .chunk (world_size )
14891489
1490- _get_backend_stream ().wait_stream (torch .cuda .current_stream ())
1490+ _get_backend_stream ().wait_stream (torch .xpu .current_stream ())
14911491 with _get_backend_stream ():
14921492 # pull + offline reduction
14931493 symm_mem .barrier ()
@@ -1524,7 +1524,7 @@ def _low_contention_reduce_scatter_with_workspace(
15241524 assert tensor .shape [0 ] % world_size == 0
15251525 chunks = tensor .chunk (world_size )
15261526
1527- _get_backend_stream ().wait_stream (torch .cuda .current_stream ())
1527+ _get_backend_stream ().wait_stream (torch .xpu .current_stream ())
15281528 with _get_backend_stream ():
15291529 # push + offline reduction
15301530 workspace .barrier ()
@@ -1547,7 +1547,7 @@ def _low_contention_reduce_scatter_with_workspace(
15471547 torch ._C ._distributed_c10d ._register_work (ret , Work ())
15481548 return ret
15491549
1550-
1550+ @ torch . library . impl ( lib , "_low_contention_reduce_scatter" , "XPU" )
15511551@torch .library .impl (lib , "_low_contention_reduce_scatter" , "CUDA" )
15521552def _low_contention_reduce_scatter (
15531553 tensor : torch .Tensor ,
0 commit comments