Skip to content

Commit d61199a

Browse files
committed
register op
1 parent 29dd674 commit d61199a

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

torch/distributed/_symmetric_memory/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1371,7 +1371,7 @@ def _maybe_convert_scalar_types_to_dtypes(
13711371
class Work(_Work):
13721372
def __init__(self) -> None:
13731373
super().__init__()
1374-
self.event = torch.cuda.Event()
1374+
self.event = torch.xpu.Event()
13751375
self.event.record()
13761376

13771377
def wait(self, timeout: timedelta = timedelta(seconds=0)) -> bool:
@@ -1416,7 +1416,7 @@ def _low_contention_all_gather_meta(
14161416
group_size = c10d._get_group_size_by_name(group_name)
14171417
return tensor.new_empty(tensor.shape[0] * group_size, *tensor.shape[1:])
14181418

1419-
1419+
@torch.library.impl(lib, "_low_contention_all_gather", "XPU")
14201420
@torch.library.impl(lib, "_low_contention_all_gather", "CUDA")
14211421
def _low_contention_all_gather(
14221422
tensor: torch.Tensor,
@@ -1449,7 +1449,7 @@ def _low_contention_all_gather(
14491449
output = tensor.new_empty(tensor.shape[0] * world_size, *tensor.shape[1:])
14501450
chunks = output.chunk(world_size)
14511451

1452-
_get_backend_stream().wait_stream(torch.cuda.current_stream())
1452+
_get_backend_stream().wait_stream(torch.xpu.current_stream())
14531453
with _get_backend_stream():
14541454
if not input_is_symm_mem:
14551455
local_buf = symm_mem.get_buffer(rank, tensor.shape, tensor.dtype)
@@ -1487,7 +1487,7 @@ def _low_contention_reduce_scatter_with_symm_mem_input(
14871487
a2a_res = torch.empty_like(tensor)
14881488
chunks = a2a_res.chunk(world_size)
14891489

1490-
_get_backend_stream().wait_stream(torch.cuda.current_stream())
1490+
_get_backend_stream().wait_stream(torch.xpu.current_stream())
14911491
with _get_backend_stream():
14921492
# pull + offline reduction
14931493
symm_mem.barrier()
@@ -1524,7 +1524,7 @@ def _low_contention_reduce_scatter_with_workspace(
15241524
assert tensor.shape[0] % world_size == 0
15251525
chunks = tensor.chunk(world_size)
15261526

1527-
_get_backend_stream().wait_stream(torch.cuda.current_stream())
1527+
_get_backend_stream().wait_stream(torch.xpu.current_stream())
15281528
with _get_backend_stream():
15291529
# push + offline reduction
15301530
workspace.barrier()
@@ -1547,7 +1547,7 @@ def _low_contention_reduce_scatter_with_workspace(
15471547
torch._C._distributed_c10d._register_work(ret, Work())
15481548
return ret
15491549

1550-
1550+
@torch.library.impl(lib, "_low_contention_reduce_scatter", "XPU")
15511551
@torch.library.impl(lib, "_low_contention_reduce_scatter", "CUDA")
15521552
def _low_contention_reduce_scatter(
15531553
tensor: torch.Tensor,

0 commit comments

Comments
 (0)