Skip to content

Commit 577e6ea

Browse files
amirafzali authored and facebook-github-bot committed
RdmaTransport python binding (#35)
Summary: Create some initial Python bindings for RdmaTransport; test bind + connect.

Reviewed By: d4l3k

Differential Revision: D85694262
1 parent 68a9391 commit 577e6ea

File tree

4 files changed

+101
-1
lines changed

4 files changed

+101
-1
lines changed

comms/torchcomms/ncclx/CMakeLists.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# Extension: torchcomms._comms_ncclx
3-
file(GLOB TORCHCOMMS_NCCLX_SOURCES "comms/torchcomms/ncclx/*.cpp")
3+
file(GLOB TORCHCOMMS_NCCLX_SOURCES
4+
"comms/torchcomms/ncclx/*.cpp"
5+
"comms/torchcomms/transport/*.cc"
6+
)
47
file(GLOB TORCHCOMMS_CUDA_API_SOURCE "comms/torchcomms/device/CudaApi.cpp")
58

69
find_package(CUDA)
@@ -46,6 +49,7 @@ add_library(torchcomms_comms_ncclx MODULE
4649
${TORCHCOMMS_NCCLX_SOURCES}
4750
${TORCHCOMMS_CUDA_API_SOURCE}
4851
)
52+
target_compile_definitions(torchcomms_comms_ncclx PRIVATE MOCK_SCUBA_DATA CTRAN_DISABLE_TCPDM)
4953
set_target_properties(torchcomms_comms_ncclx PROPERTIES
5054
PREFIX ""
5155
OUTPUT_NAME "_comms_ncclx"
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,49 @@
11
// Copyright (c) Meta Platforms, Inc. and affiliates.
22

3+
#include <folly/io/async/EventBase.h>
4+
#include <folly/io/async/ScopedEventBaseThread.h>
35
#include <pybind11/chrono.h>
46
#include <pybind11/numpy.h>
57
#include <pybind11/pybind11.h>
68
#include <pybind11/stl.h>
79
#include <torch/csrc/utils/pybind.h>
810

911
#include "comms/torchcomms/ncclx/TorchCommNCCLX.hpp"
12+
#include "comms/torchcomms/transport/RdmaTransport.h"
1013

1114
namespace py = pybind11;
1215
using namespace torch::comms;
1316

17+
namespace {
18+
folly::ScopedEventBaseThread& getScopedEventBaseThread() {
19+
// This intentionally creates and leaks a global event base thread to be used
20+
// for all Transports on first use.
21+
static folly::ScopedEventBaseThread scopedEventBaseThread{"torchcomms_evb"};
22+
return scopedEventBaseThread;
23+
}
24+
} // namespace
25+
1426
PYBIND11_MODULE(_comms_ncclx, m) {
1527
m.doc() = "NCCLX specific python bindings for TorchComm";
1628

1729
py::class_<TorchCommNCCLX, std::shared_ptr<TorchCommNCCLX>>(
1830
m, "TorchCommNCCLX");
31+
32+
py::class_<RdmaTransport, std::shared_ptr<RdmaTransport>>(m, "RdmaTransport")
33+
// initialize a new RDMATransport using a custom init fn
34+
.def(py::init([](at::Device device) {
35+
TORCH_INTERNAL_ASSERT(device.is_cuda());
36+
int cuda_device = device.index();
37+
return std::make_shared<RdmaTransport>(
38+
cuda_device, getScopedEventBaseThread().getEventBase());
39+
}))
40+
.def_static("supported", &RdmaTransport::supported)
41+
.def("bind", [](RdmaTransport& self) { return py::bytes(self.bind()); })
42+
.def(
43+
"connect",
44+
[](RdmaTransport& self, const py::bytes& peerUrl) {
45+
std::string peerUrlStr = peerUrl.cast<std::string>();
46+
return static_cast<int>(self.connect(peerUrlStr));
47+
})
48+
.def("connected", &RdmaTransport::connected);
1949
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,13 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
3+
import torch
4+
25
class TorchCommNCCLX: ...
6+
7+
class RdmaTransport:
8+
def __init__(self, device: torch.device) -> None: ... # pyre-ignore[11]
9+
@staticmethod
10+
def supported() -> bool: ...
11+
def bind(self) -> bytes: ...
12+
def connect(self, peer_url: bytes) -> int: ...
13+
def connected(self) -> bool: ...
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#!/usr/bin/env python3
2+
# pyre-unsafe
3+
# Copyright (c) Meta Platforms, Inc. and affiliates.
4+
5+
import os
6+
import unittest
7+
8+
import torch
9+
from torchcomms._comms_ncclx import RdmaTransport
10+
11+
12+
class TransportTest(unittest.TestCase):
13+
def setUp(self):
14+
if not RdmaTransport.supported():
15+
self.skipTest("RdmaTransport is not supported on this system")
16+
17+
def test_construct(self) -> None:
18+
_ = RdmaTransport(torch.device("cuda:0"))
19+
20+
def test_bind_and_connect(self) -> None:
21+
if torch.cuda.device_count() < 2:
22+
self.skipTest(
23+
f"Test requires at least 2 CUDA devices, found {torch.cuda.device_count()}"
24+
)
25+
26+
server_device = torch.device("cuda:0")
27+
client_device = torch.device("cuda:1")
28+
29+
server_transport = RdmaTransport(server_device)
30+
client_transport = RdmaTransport(client_device)
31+
32+
server_url = server_transport.bind()
33+
client_url = client_transport.bind()
34+
35+
self.assertIsNotNone(server_url)
36+
self.assertIsNotNone(client_url)
37+
self.assertNotEqual(server_url, "")
38+
self.assertNotEqual(client_url, "")
39+
40+
server_result = server_transport.connect(client_url)
41+
client_result = client_transport.connect(server_url)
42+
43+
self.assertEqual(
44+
server_result, 0, "Server connect should return commSuccess (0)"
45+
)
46+
self.assertEqual(
47+
client_result, 0, "Client connect should return commSuccess (0)"
48+
)
49+
50+
self.assertTrue(server_transport.connected())
51+
self.assertTrue(client_transport.connected())
52+
53+
54+
if __name__ == "__main__" and os.environ["TEST_BACKEND"] == "ncclx":
55+
unittest.main()

0 commit comments

Comments
 (0)