6 changes: 5 additions & 1 deletion comms/torchcomms/ncclx/CMakeLists.txt
@@ -1,6 +1,9 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # Extension: torchcomms._comms_ncclx
-file(GLOB TORCHCOMMS_NCCLX_SOURCES "comms/torchcomms/ncclx/*.cpp")
+file(GLOB TORCHCOMMS_NCCLX_SOURCES
+  "comms/torchcomms/ncclx/*.cpp"
+  "comms/torchcomms/transport/*.cc"
+)
 file(GLOB TORCHCOMMS_CUDA_API_SOURCE "comms/torchcomms/device/CudaApi.cpp")
 
 find_package(CUDA)
@@ -46,6 +49,7 @@ add_library(torchcomms_comms_ncclx MODULE
   ${TORCHCOMMS_NCCLX_SOURCES}
   ${TORCHCOMMS_CUDA_API_SOURCE}
 )
+target_compile_definitions(torchcomms_comms_ncclx PRIVATE MOCK_SCUBA_DATA CTRAN_DISABLE_TCPDM)
 set_target_properties(torchcomms_comms_ncclx PROPERTIES
   PREFIX ""
   OUTPUT_NAME "_comms_ncclx"
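Note: the new transport/*.cc glob pulls the RdmaTransport sources into the same extension target, so the bindings added below can link against them. Judging by their names alone (an assumption; the diff itself does not say), MOCK_SCUBA_DATA presumably stubs out Meta-internal Scuba logging and CTRAN_DISABLE_TCPDM compiles out CTRAN's TCP data-mover path for this build.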
30 changes: 30 additions & 0 deletions comms/torchcomms/ncclx/TorchCommNCCLXPy.cpp
@@ -1,19 +1,49 @@
 // Copyright (c) Meta Platforms, Inc. and affiliates.
 
+#include <folly/io/async/EventBase.h>
+#include <folly/io/async/ScopedEventBaseThread.h>
 #include <pybind11/chrono.h>
 #include <pybind11/numpy.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
 #include <torch/csrc/utils/pybind.h>
 
 #include "comms/torchcomms/ncclx/TorchCommNCCLX.hpp"
+#include "comms/torchcomms/transport/RdmaTransport.h"
 
 namespace py = pybind11;
 using namespace torch::comms;
 
+namespace {
+folly::ScopedEventBaseThread& getScopedEventBaseThread() {
+  // This intentionally creates and leaks a global event base thread to be used
+  // for all Transports on first use.
+  static folly::ScopedEventBaseThread scopedEventBaseThread{"torchcomms_evb"};
+  return scopedEventBaseThread;
+}
+} // namespace
+
 PYBIND11_MODULE(_comms_ncclx, m) {
   m.doc() = "NCCLX specific python bindings for TorchComm";
 
   py::class_<TorchCommNCCLX, std::shared_ptr<TorchCommNCCLX>>(
       m, "TorchCommNCCLX");
+
+  py::class_<RdmaTransport, std::shared_ptr<RdmaTransport>>(m, "RdmaTransport")
+      // Initialize a new RdmaTransport using a custom init fn.
+      .def(py::init([](at::Device device) {
+        TORCH_INTERNAL_ASSERT(device.is_cuda());
+        int cuda_device = device.index();
+        return std::make_shared<RdmaTransport>(
+            cuda_device, getScopedEventBaseThread().getEventBase());
+      }))
+      .def_static("supported", &RdmaTransport::supported)
+      .def("bind", [](RdmaTransport& self) { return py::bytes(self.bind()); })
+      .def(
+          "connect",
+          [](RdmaTransport& self, const py::bytes& peerUrl) {
+            std::string peerUrlStr = peerUrl.cast<std::string>();
+            return static_cast<int>(self.connect(peerUrlStr));
+          })
+      .def("connected", &RdmaTransport::connected);
 }
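For orientation, here is a minimal sketch (not part of this PR) of how two ranks might use the binding above, assuming a torch.distributed.TCPStore is available as the out-of-band channel for exchanging the opaque URLs returned by bind(), and assuming the store accepts bytes values; the rendezvous helper name is hypothetical. The single-process integration test added below exercises the same calls without a store.

    import torch
    from torch.distributed import TCPStore
    from torchcomms._comms_ncclx import RdmaTransport

    def rendezvous(rank: int, store: TCPStore) -> RdmaTransport:
        # One transport per rank, pinned to that rank's CUDA device.
        transport = RdmaTransport(torch.device(f"cuda:{rank}"))
        # bind() returns an opaque URL as bytes; publish it for the peer.
        store.set(f"url_{rank}", transport.bind())
        # Fetch the peer's URL and connect; 0 maps to commSuccess.
        peer_url = store.get(f"url_{1 - rank}")
        assert transport.connect(peer_url) == 0
        return transport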
11 changes: 11 additions & 0 deletions comms/torchcomms/ncclx/_comms_ncclx.pyi
@@ -1,2 +1,13 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 
+import torch
+
+class TorchCommNCCLX: ...
+
+class RdmaTransport:
+    def __init__(self, device: torch.device) -> None: ...  # pyre-ignore[11]
+    @staticmethod
+    def supported() -> bool: ...
+    def bind(self) -> bytes: ...
+    def connect(self, peer_url: bytes) -> int: ...
+    def connected(self) -> bool: ...
55 changes: 55 additions & 0 deletions comms/torchcomms/tests/integration/py/TransportTest.py
@@ -0,0 +1,55 @@
+#!/usr/bin/env python3
+# pyre-unsafe
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import os
+import unittest
+
+import torch
+from torchcomms._comms_ncclx import RdmaTransport
+
+
+class TransportTest(unittest.TestCase):
+    def setUp(self):
+        if not RdmaTransport.supported():
+            self.skipTest("RdmaTransport is not supported on this system")
+
+    def test_construct(self) -> None:
+        _ = RdmaTransport(torch.device("cuda:0"))
+
+    def test_bind_and_connect(self) -> None:
+        if torch.cuda.device_count() < 2:
+            self.skipTest(
+                f"Test requires at least 2 CUDA devices, found {torch.cuda.device_count()}"
+            )
+
+        server_device = torch.device("cuda:0")
+        client_device = torch.device("cuda:1")
+
+        server_transport = RdmaTransport(server_device)
+        client_transport = RdmaTransport(client_device)
+
+        server_url = server_transport.bind()
+        client_url = client_transport.bind()
+
+        self.assertIsNotNone(server_url)
+        self.assertIsNotNone(client_url)
+        self.assertNotEqual(server_url, b"")
+        self.assertNotEqual(client_url, b"")
+
+        server_result = server_transport.connect(client_url)
+        client_result = client_transport.connect(server_url)
+
+        self.assertEqual(
+            server_result, 0, "Server connect should return commSuccess (0)"
+        )
+        self.assertEqual(
+            client_result, 0, "Client connect should return commSuccess (0)"
+        )
+
+        self.assertTrue(server_transport.connected())
+        self.assertTrue(client_transport.connected())
+
+
+if __name__ == "__main__" and os.environ.get("TEST_BACKEND") == "ncclx":
+    unittest.main()
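Given the TEST_BACKEND guard on the final two lines, the test is presumably invoked as something like TEST_BACKEND=ncclx python comms/torchcomms/tests/integration/py/TransportTest.py, on a host where RdmaTransport.supported() returns true and at least two CUDA devices are visible.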