From 8a6cc82c2f9c4948666888610a7fd2cddd4c6419 Mon Sep 17 00:00:00 2001 From: magniloquency <197707854+magniloquency@users.noreply.github.com> Date: Thu, 2 Oct 2025 14:39:11 -0400 Subject: [PATCH 01/23] Add C++ YMQ Tests (#263) * Add C++ YMQ Tests * Move tests to tests/cpp/ymq --------- Co-authored-by: sharpener6 <1sc2l4qi@duck.com> --- tests/CMakeLists.txt | 2 +- tests/cpp/CMakeLists.txt | 1 + tests/cpp/ymq/CMakeLists.txt | 1 + tests/cpp/ymq/common.h | 410 ++++++++++++++++++++++++++ tests/cpp/ymq/test_cc_ymq.cpp | 522 ++++++++++++++++++++++++++++++++++ 5 files changed, 935 insertions(+), 1 deletion(-) create mode 100644 tests/cpp/CMakeLists.txt create mode 100644 tests/cpp/ymq/CMakeLists.txt create mode 100644 tests/cpp/ymq/common.h create mode 100644 tests/cpp/ymq/test_cc_ymq.cpp diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index ae6a3b116..b62b86fd4 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -31,11 +31,11 @@ function(add_test_executable test_name source_file) add_test(NAME ${test_name} COMMAND ${test_name}) endfunction() - if(LINUX OR APPLE) # This directory fetches Google Test, so it must be included first. add_subdirectory(object_storage) # Add the new directory for io tests. add_subdirectory(io/ymq) + add_subdirectory(cpp) endif() diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt new file mode 100644 index 000000000..26c48e9b0 --- /dev/null +++ b/tests/cpp/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(ymq) \ No newline at end of file diff --git a/tests/cpp/ymq/CMakeLists.txt b/tests/cpp/ymq/CMakeLists.txt new file mode 100644 index 000000000..9f6abe371 --- /dev/null +++ b/tests/cpp/ymq/CMakeLists.txt @@ -0,0 +1 @@ +add_test_executable(test_cc_ymq test_cc_ymq.cpp) diff --git a/tests/cpp/ymq/common.h b/tests/cpp/ymq/common.h new file mode 100644 index 000000000..5fd9dad9f --- /dev/null +++ b/tests/cpp/ymq/common.h @@ -0,0 +1,410 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define RETURN_FAILURE_IF_FALSE(condition) \ + if (!(condition)) { \ + return TestResult::Failure; \ + } + +using namespace std::chrono_literals; + +enum class TestResult : char { Success = 1, Failure = 2 }; + +inline const char* check_localhost(const char* host) +{ + return std::strcmp(host, "localhost") == 0 ? "127.0.0.1" : host; +} + +inline std::string format_address(std::string host, uint16_t port) +{ + return std::format("tcp://{}:{}", check_localhost(host.c_str()), port); +} + +class OwnedFd { +public: + int fd; + + OwnedFd(int fd): fd(fd) {} + + // move-only + OwnedFd(const OwnedFd&) = delete; + OwnedFd& operator=(const OwnedFd&) = delete; + OwnedFd(OwnedFd&& other) noexcept: fd(other.fd) { other.fd = 0; } + OwnedFd& operator=(OwnedFd&& other) noexcept + { + if (this != &other) { + this->fd = other.fd; + other.fd = 0; + } + return *this; + } + + ~OwnedFd() + { + if (fd > 0 && close(fd) < 0) + std::println(std::cerr, "failed to close fd!"); + } + + size_t write(const void* data, size_t len) + { + auto n = ::write(this->fd, data, len); + if (n < 0) + throw std::system_error(errno, std::generic_category(), "failed to write to socket"); + + return n; + } + + void write_all(const char* data, size_t len) + { + for (size_t cursor = 0; cursor < len;) + cursor += this->write(data + cursor, len - cursor); + } + + void write_all(std::string data) { this->write_all(data.data(), data.length()); } + + void write_all(std::vector data) { this->write_all(data.data(), data.size()); } + + size_t read(void* buffer, size_t len) + { + auto n = ::read(this->fd, buffer, len); + if (n < 0) + throw std::system_error(errno, std::generic_category(), "failed to read from socket"); + return n; + } + + void read_exact(char* buffer, size_t len) + { + for (size_t cursor = 0; cursor < len;) + cursor += this->read(buffer + cursor, len - cursor); + } + + operator int() { return fd; } +}; + +class Socket: public OwnedFd { +public: + Socket(int fd): OwnedFd(fd) {} + + void connect(const char* host, uint16_t port, bool nowait = false) + { + sockaddr_in addr { + .sin_family = AF_INET, + .sin_port = htons(port), + .sin_addr = {.s_addr = inet_addr(check_localhost(host))}, + .sin_zero = {0}}; + + connect: + if (::connect(this->fd, (sockaddr*)&addr, sizeof(addr)) < 0) { + if (errno == ECONNREFUSED && !nowait) { + std::this_thread::sleep_for(300ms); + goto connect; + } + + throw std::system_error(errno, std::generic_category(), "failed to connect"); + } + } + + void bind(const char* host, int port) + { + sockaddr_in addr { + .sin_family = AF_INET, + .sin_port = htons(port), + .sin_addr = {.s_addr = inet_addr(check_localhost(host))}, + .sin_zero = {0}}; + + auto status = ::bind(this->fd, (sockaddr*)&addr, sizeof(addr)); + if (status < 0) + throw std::system_error(errno, std::generic_category(), "failed to bind"); + } + + void listen(int n = 32) + { + auto status = ::listen(this->fd, n); + if (status < 0) + throw std::system_error(errno, std::generic_category(), "failed to listen on socket"); + } + + std::pair accept(int flags = 0) + { + sockaddr_in peer_addr {}; + socklen_t len = sizeof(peer_addr); + auto fd = ::accept4(this->fd, (sockaddr*)&peer_addr, &len, flags); + if (fd < 0) + throw std::system_error(errno, std::generic_category(), "failed to accept socket"); + + return std::make_pair(Socket(fd), peer_addr); + } + + void write_message(std::string message) + { + uint64_t header = message.length(); + this->write_all((char*)&header, 8); + this->write_all(message.data(), message.length()); + } + + std::string read_message() + { + uint64_t header = 0; + this->read_exact((char*)&header, 8); + std::vector buffer(header); + this->read_exact(buffer.data(), header); + return std::string(buffer.data(), header); + } +}; + +class TcpSocket: public Socket { +public: + TcpSocket(bool nodelay = true): Socket(0) + { + this->fd = ::socket(AF_INET, SOCK_STREAM, 0); + if (this->fd < 0) + throw std::system_error(errno, std::generic_category(), "failed to create socket"); + + int on = 1; + if (nodelay && setsockopt(this->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&on, sizeof(on)) < 0) + throw std::system_error(errno, std::generic_category(), "failed to set nodelay"); + + if (setsockopt(this->fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) < 0) + throw std::system_error(errno, std::generic_category(), "failed to set reuseaddr"); + } + + void flush() + { + int on = 1; + int off = 0; + + if (setsockopt(this->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&off, sizeof(off)) < 0) + throw std::system_error(errno, std::generic_category(), "failed to set nodelay"); + + if (setsockopt(this->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&on, sizeof(on)) < 0) + throw std::system_error(errno, std::generic_category(), "failed to set nodelay"); + + if (setsockopt(this->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&off, sizeof(off)) < 0) + throw std::system_error(errno, std::generic_category(), "failed to set nodelay"); + + if (setsockopt(this->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&on, sizeof(on)) < 0) + throw std::system_error(errno, std::generic_category(), "failed to set nodelay"); + } +}; + +inline void fork_wrapper(std::function fn, int timeout_secs, OwnedFd pipe_wr) +{ + TestResult result = TestResult::Failure; + try { + result = fn(); + } catch (const std::exception& e) { + std::println(stderr, "Exception: {}", e.what()); + result = TestResult::Failure; + } catch (...) { + std::println(stderr, "Unknown exception"); + result = TestResult::Failure; + } + + pipe_wr.write_all((char*)&result, sizeof(TestResult)); +} + +// run a test +// forks and runs each of the provided closures +inline TestResult test( + int timeout_secs, std::vector> closures) +{ + std::vector> pipes {}; + std::vector pids {}; + for (size_t i = 0; i < closures.size(); i++) { + int pipe[2] = {0}; + if (pipe2(pipe, O_NONBLOCK) < 0) { + std::for_each(pipes.begin(), pipes.end(), [](const auto& pipe) { + close(pipe.first); + close(pipe.second); + }); + + throw std::system_error(errno, std::generic_category(), "failed to create pipe: "); + } + pipes.push_back(std::make_pair(pipe[0], pipe[1])); + } + + for (size_t i = 0; i < closures.size(); i++) { + auto pid = fork(); + if (pid < 0) { + std::for_each(pipes.begin(), pipes.end(), [](const auto& pipe) { + close(pipe.first); + close(pipe.second); + }); + + std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); + + throw std::system_error(errno, std::generic_category(), "failed to fork"); + } + + if (pid == 0) { + // close all pipes except our write half + for (size_t j = 0; j < pipes.size(); j++) { + if (i == j) + close(pipes[i].first); + else { + close(pipes[j].first); + close(pipes[j].second); + } + } + + fork_wrapper(closures[i], timeout_secs, pipes[i].second); + std::exit(EXIT_SUCCESS); + } + + pids.push_back(pid); + } + + // close all write halves of the pipes + for (auto pipe: pipes) + close(pipe.second); + + std::vector pfds {}; + + OwnedFd timerfd = timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK); + if (timerfd < 0) { + std::for_each(pipes.begin(), pipes.end(), [](const auto& pipe) { close(pipe.first); }); + std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); + + throw std::system_error(errno, std::generic_category(), "failed to create timerfd"); + } + + pfds.push_back({.fd = timerfd.fd, .events = POLL_IN, .revents = 0}); + for (auto pipe: pipes) + pfds.push_back({ + .fd = pipe.first, + .events = POLL_IN, + .revents = 0, + }); + + itimerspec spec { + .it_interval = + { + .tv_sec = 0, + .tv_nsec = 0, + }, + .it_value = { + .tv_sec = timeout_secs, + .tv_nsec = 0, + }}; + + if (timerfd_settime(timerfd, 0, &spec, nullptr) < 0) { + std::for_each(pipes.begin(), pipes.end(), [](const auto& pipe) { close(pipe.first); }); + std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); + + throw std::system_error(errno, std::generic_category(), "failed to set timerfd"); + } + + std::vector> results(pids.size(), std::nullopt); + + for (;;) { + auto n = poll(pfds.data(), pfds.size(), -1); + if (n < 0) { + std::for_each(pipes.begin(), pipes.end(), [](const auto& pipe) { close(pipe.first); }); + std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); + + throw std::system_error(errno, std::generic_category(), "failed to poll: "); + } + + for (auto& pfd: std::vector(pfds)) { + if (pfd.revents == 0) + continue; + + // timed out + if (pfd.fd == timerfd) { + std::println("Timed out!"); + + std::for_each(pipes.begin(), pipes.end(), [](const auto& pipe) { close(pipe.first); }); + std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); + + return TestResult::Failure; + } + + auto elem = std::find_if(pipes.begin(), pipes.end(), [fd = pfd.fd](auto pipe) { return pipe.first == fd; }); + auto idx = elem - pipes.begin(); + + TestResult result = TestResult::Failure; + char buffer = 0; + auto n = read(pfd.fd, &buffer, sizeof(TestResult)); + if (n == 0) { + std::println("failed to read from pipe: pipe closed unexpectedly"); + result = TestResult::Failure; + } else if (n < 0) { + std::println("failed to read from pipe: {}", std::strerror(errno)); + result = TestResult::Failure; + } else + result = (TestResult)buffer; + + // the subprocess should have exited + // check its exit status + int status; + if (waitpid(pids[idx], &status, 0) < 0) + std::println("failed to wait on subprocess[{}]: {}", idx, std::strerror(errno)); + + auto exit_status = WEXITSTATUS(status); + if (WIFEXITED(status) && exit_status != EXIT_SUCCESS) { + std::println("subprocess[{}] exited with status {}", idx, exit_status); + } else if (WIFSIGNALED(status)) { + std::println("subprocess[{}] killed by signal {}", idx, WTERMSIG(status)); + } else + std::println( + "subprocess[{}] completed with {}", idx, result == TestResult::Success ? "Success" : "Failure"); + + // store the result + results[idx] = result; + + // this subprocess is done, remove its pipe from the poll fds + pfds.erase(std::remove_if(pfds.begin(), pfds.end(), [&](auto p) { return p.fd == pfd.fd; }), pfds.end()); + + auto done = std::all_of(results.begin(), results.end(), [](auto result) { return result.has_value(); }); + if (done) + goto end; // justification for goto: breaks out of two levels of loop + } + } + +end: + + std::for_each(pipes.begin(), pipes.end(), [](const auto& pipe) { close(pipe.first); }); + + if (std::ranges::any_of(results, [](auto x) { return x == TestResult::Failure; })) + return TestResult::Failure; + + return TestResult::Success; +} diff --git a/tests/cpp/ymq/test_cc_ymq.cpp b/tests/cpp/ymq/test_cc_ymq.cpp new file mode 100644 index 000000000..f20321908 --- /dev/null +++ b/tests/cpp/ymq/test_cc_ymq.cpp @@ -0,0 +1,522 @@ +// this file contains the tests for the C++ interface of YMQ +// each test case is comprised of at least one client and one server, and possibly a middleman +// the clients and servers used in these tests are defined in the first part of this file +// +// the test cases are at the bottom of this file, after the clients and servers +// the documentation for each case is found on the TEST() definition + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "common.h" +#include "scaler/io/ymq/bytes.h" +#include "scaler/io/ymq/io_context.h" +#include "scaler/io/ymq/simple_interface.h" +#include "tests/cpp/ymq/common.h" + +using namespace scaler::ymq; +using namespace std::chrono_literals; + +// ━━━━━━━━━━━━━━━━━━━ +// clients and servers +// ━━━━━━━━━━━━━━━━━━━ + +TestResult basic_server_ymq(std::string host, uint16_t port) +{ + IOContext context(1); + + auto socket = syncCreateSocket(context, IOSocketType::Binder, "server"); + syncBindSocket(socket, format_address(host, port)); + auto result = syncRecvMessage(socket); + + RETURN_FAILURE_IF_FALSE(result.has_value()); + RETURN_FAILURE_IF_FALSE(result->payload.as_string() == "yi er san si wu liu"); + + context.removeIOSocket(socket); + + return TestResult::Success; +} + +TestResult basic_client_ymq(std::string host, uint16_t port) +{ + IOContext context(1); + + auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); + syncConnectSocket(socket, format_address(host, port)); + auto result = syncSendMessage(socket, {.address = Bytes("server"), .payload = Bytes("yi er san si wu liu")}); + + context.removeIOSocket(socket); + + return TestResult::Success; +} + +TestResult basic_server_raw(std::string host, uint16_t port) +{ + TcpSocket socket; + + socket.bind(host.c_str(), port); + socket.listen(); + auto [client, _] = socket.accept(); + client.write_message("server"); + auto client_identity = client.read_message(); + RETURN_FAILURE_IF_FALSE(client_identity == "client"); + auto msg = client.read_message(); + RETURN_FAILURE_IF_FALSE(msg == "yi er san si wu liu"); + + return TestResult::Success; +} + +TestResult basic_client_raw(std::string host, uint16_t port) +{ + TcpSocket socket; + + socket.connect(host.c_str(), port); + socket.write_message("client"); + auto server_identity = socket.read_message(); + RETURN_FAILURE_IF_FALSE(server_identity == "server"); + socket.write_message("yi er san si wu liu"); + + return TestResult::Success; +} + +TestResult server_receives_big_message(std::string host, uint16_t port) +{ + IOContext context(1); + + auto socket = syncCreateSocket(context, IOSocketType::Binder, "server"); + syncBindSocket(socket, format_address(host, port)); + auto result = syncRecvMessage(socket); + + RETURN_FAILURE_IF_FALSE(result.has_value()); + RETURN_FAILURE_IF_FALSE(result->payload.len() == 500'000'000); + + context.removeIOSocket(socket); + + return TestResult::Success; +} + +TestResult client_sends_big_message(std::string host, uint16_t port) +{ + TcpSocket socket; + + socket.connect(host.c_str(), port); + socket.write_message("client"); + auto remote_identity = socket.read_message(); + RETURN_FAILURE_IF_FALSE(remote_identity == "server"); + std::string msg(500'000'000, '.'); + socket.write_message(msg); + + return TestResult::Success; +} + +TestResult reconnect_server_main(std::string host, uint16_t port) +{ + IOContext context(1); + + auto socket = syncCreateSocket(context, IOSocketType::Binder, "server"); + syncBindSocket(socket, format_address(host, port)); + auto result = syncRecvMessage(socket); + + RETURN_FAILURE_IF_FALSE(result.has_value()); + RETURN_FAILURE_IF_FALSE(result->payload.as_string() == "hello!!"); + + auto error = syncSendMessage(socket, {.address = Bytes("client"), .payload = Bytes("world!!")}); + RETURN_FAILURE_IF_FALSE(!error); + + context.removeIOSocket(socket); + + return TestResult::Success; +} + +TestResult reconnect_client_main(std::string host, uint16_t port) +{ + IOContext context(1); + + auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); + syncConnectSocket(socket, format_address(host, port)); + auto result = syncSendMessage(socket, {.address = Bytes("server"), .payload = Bytes("hello!!")}); + auto msg = syncRecvMessage(socket); + RETURN_FAILURE_IF_FALSE(msg.has_value()); + RETURN_FAILURE_IF_FALSE(msg->payload.as_string() == "world!!"); + + context.removeIOSocket(socket); + + return TestResult::Success; +} + +TestResult client_simulated_slow_network(const char* host, uint16_t port) +{ + TcpSocket socket; + + socket.connect(host, port); + socket.write_message("client"); + auto remote_identity = socket.read_message(); + RETURN_FAILURE_IF_FALSE(remote_identity == "server"); + + std::string message = "yi er san si wu liu"; + uint64_t header = message.length(); + + socket.write_all((char*)&header, 4); + std::this_thread::sleep_for(2s); + socket.write_all((char*)&header + 4, 4); + std::this_thread::sleep_for(3s); + socket.write_all(message.data(), header / 2); + std::this_thread::sleep_for(2s); + socket.write_all(message.data() + header / 2, header - header / 2); + + return TestResult::Success; +} + +TestResult client_sends_incomplete_identity(const char* host, uint16_t port) +{ + // open a socket, write an incomplete identity and exit + { + TcpSocket socket; + + socket.connect(host, port); + + auto server_identity = socket.read_message(); + RETURN_FAILURE_IF_FALSE(server_identity == "server"); + + // write incomplete identity and exit + std::string identity = "client"; + uint64_t header = identity.length(); + socket.write_all((char*)&header, 8); + socket.write_all(identity.data(), identity.length() - 2); + } + + // connect again and try to send a message + { + TcpSocket socket; + socket.connect(host, port); + auto server_identity = socket.read_message(); + RETURN_FAILURE_IF_FALSE(server_identity == "server"); + socket.write_message("client"); + socket.write_message("yi er san si wu liu"); + } + + return TestResult::Success; +} + +TestResult server_receives_huge_header(const char* host, uint16_t port) +{ + IOContext context(1); + + auto socket = syncCreateSocket(context, IOSocketType::Binder, "server"); + syncBindSocket(socket, format_address(host, port)); + auto result = syncRecvMessage(socket); + + RETURN_FAILURE_IF_FALSE(result.has_value()); + RETURN_FAILURE_IF_FALSE(result->payload.as_string() == "yi er san si wu liu"); + + context.removeIOSocket(socket); + + return TestResult::Success; +} + +TestResult client_sends_huge_header(const char* host, uint16_t port) +{ + // ignore SIGPIPE so that write() returns EPIPE instead of crashing the program + signal(SIGPIPE, SIG_IGN); + + { + TcpSocket socket; + + socket.connect(host, port); + socket.write_message("client"); + auto server_identity = socket.read_message(); + RETURN_FAILURE_IF_FALSE(server_identity == "server"); + + // write the huge header + uint64_t header = std::numeric_limits::max(); + socket.write_all((char*)&header, 8); + + size_t i = 0; + for (; i < 10; i++) { + std::this_thread::sleep_for(1s); + + try { + socket.write_all("yi er san si wu liu"); + } catch (const std::system_error& e) { + if (e.code().value() == EPIPE) { + std::println("writing failed with EPIPE as expected after sending huge header, continuing.."); + break; // this is expected + } + + throw; // rethrow other errors + } + } + + if (i == 10) { + std::println("expected EPIPE after sending huge header"); + return TestResult::Failure; + } + } + + { + TcpSocket socket; + socket.connect(host, port); + socket.write_message("client"); + auto server_identity = socket.read_message(); + RETURN_FAILURE_IF_FALSE(server_identity == "server"); + socket.write_message("yi er san si wu liu"); + } + + return TestResult::Success; +} + +TestResult server_receives_empty_messages(const char* host, uint16_t port) +{ + IOContext context(1); + + auto socket = syncCreateSocket(context, IOSocketType::Binder, "server"); + syncBindSocket(socket, format_address(host, port)); + + auto result = syncRecvMessage(socket); + RETURN_FAILURE_IF_FALSE(result.has_value()); + RETURN_FAILURE_IF_FALSE(result->payload.as_string() == ""); + + auto result2 = syncRecvMessage(socket); + RETURN_FAILURE_IF_FALSE(result2.has_value()); + RETURN_FAILURE_IF_FALSE(result2->payload.as_string() == ""); + + context.removeIOSocket(socket); + + return TestResult::Success; +} + +TestResult client_sends_empty_messages(std::string host, uint16_t port) +{ + IOContext context(1); + + auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); + syncConnectSocket(socket, format_address(host, port)); + + auto error = syncSendMessage(socket, Message {.address = Bytes(), .payload = Bytes()}); + RETURN_FAILURE_IF_FALSE(!error); + + auto error2 = syncSendMessage(socket, Message {.address = Bytes(), .payload = Bytes("")}); + RETURN_FAILURE_IF_FALSE(!error2); + + context.removeIOSocket(socket); + + return TestResult::Success; +} + +TestResult pubsub_subscriber(std::string host, uint16_t port, std::string topic, int differentiator) +{ + IOContext context(1); + + auto socket = + syncCreateSocket(context, IOSocketType::Unicast, std::format("{}_subscriber_{}", topic, differentiator)); + syncConnectSocket(socket, format_address(host, port)); + auto msg = syncRecvMessage(socket); + RETURN_FAILURE_IF_FALSE(msg.has_value()); + RETURN_FAILURE_IF_FALSE(msg->payload.as_string() == "hello topic " + topic); + + context.removeIOSocket(socket); + return TestResult::Success; +} + +TestResult pubsub_publisher(std::string host, uint16_t port, std::string topic) +{ + IOContext context(1); + + auto socket = syncCreateSocket(context, IOSocketType::Multicast, "publisher"); + syncBindSocket(socket, format_address(host, port)); + + // wait a second to ensure that the subscribers are ready + std::this_thread::sleep_for(1s); + + // the topic is wrong, so no one should receive this + auto error = syncSendMessage( + socket, Message {.address = Bytes(std::format("x{}", topic)), .payload = Bytes("no one should get this")}); + RETURN_FAILURE_IF_FALSE(!error); + + // no one should receive this either + error = syncSendMessage( + socket, + Message {.address = Bytes(std::format("{}x", topic)), .payload = Bytes("no one should get this either")}); + RETURN_FAILURE_IF_FALSE(!error); + + error = syncSendMessage(socket, Message {.address = Bytes(topic), .payload = Bytes("hello topic " + topic)}); + RETURN_FAILURE_IF_FALSE(!error); + + context.removeIOSocket(socket); + return TestResult::Success; +} + +// ━━━━━━━━━━━━━ +// test cases +// ━━━━━━━━━━━━━ + +// this is a 'basic' test which sends a single message from a client to a server +// in this variant, both the client and server are implemented using YMQ +// +// this case includes a _delay_ +// this is a thread sleep that happens after the client sends the message, to delay the close() of the socket +// at the moment, if this delay is missing, YMQ will not shut down correctly +TEST(CcYmqTestSuite, TestBasicYMQClientYMQServer) +{ + auto host = "localhost"; + auto port = 2889; + + // this is the test harness, it accepts a timeout, and a list of functions to run + auto result = + test(10, {[=] { return basic_client_ymq(host, port); }, [=] { return basic_server_ymq(host, port); }}); + + // test() aggregates the results across all of the provided functions + EXPECT_EQ(result, TestResult::Success); +} + +// same as above, except YMQs protocol is directly implemented on top of a TCP socket +TEST(CcYmqTestSuite, TestBasicRawClientYMQServer) +{ + auto host = "localhost"; + auto port = 2890; + + auto result = + test(10, {[=] { return basic_client_raw(host, port); }, [=] { return basic_server_ymq(host, port); }}); + + // test() aggregates the results across all of the provided functions + EXPECT_EQ(result, TestResult::Success); +} + +TEST(CcYmqTestSuite, TestBasicRawClientRawServer) +{ + auto host = "localhost"; + auto port = 2891; + + auto result = + test(10, {[=] { return basic_client_raw(host, port); }, [=] { return basic_server_raw(host, port); }}); + + // test() aggregates the results across all of the provided functions + EXPECT_EQ(result, TestResult::Success); +} + +// this is the same as above, except that it has no delay before calling close() on the socket +TEST(CcYmqTestSuite, TestBasicRawClientRawServerNoDelay) +{ + auto host = "localhost"; + auto port = 2892; + + auto result = + test(10, {[=] { return basic_client_raw(host, port); }, [=] { return basic_server_ymq(host, port); }}); + EXPECT_EQ(result, TestResult::Success); +} + +TEST(CcYmqTestSuite, TestBasicDelayYMQClientRawServer) +{ + auto host = "localhost"; + auto port = 2893; + + auto result = + test(10, {[=] { return basic_client_ymq(host, port); }, [=] { return basic_server_raw(host, port); }}); + + // test() aggregates the results across all of the provided functions + EXPECT_EQ(result, TestResult::Success); +} + +// in this test case, the client sends a large message to the server +// YMQ should be able to handle this without issue +TEST(CcYmqTestSuite, TestClientSendBigMessageToServer) +{ + auto host = "localhost"; + auto port = 2894; + + auto result = test( + 10, + {[=] { return client_sends_big_message(host, port); }, + [=] { return server_receives_big_message(host, port); }}); + EXPECT_EQ(result, TestResult::Success); +} + +// in this test the client is sending a message to the server +// but we simulate a slow network connection by sending the message in segmented chunks +TEST(CcYmqTestSuite, TestSlowNetwork) +{ + auto host = "localhost"; + auto port = 2895; + + auto result = test( + 20, {[=] { return client_simulated_slow_network(host, port); }, [=] { return basic_server_ymq(host, port); }}); + EXPECT_EQ(result, TestResult::Success); +} + +// TODO: figure out why this test fails in ci sometimes, and re-enable +// +// in this test, a client connects to the YMQ server but only partially sends its identity and then disconnects +// then a new client connection is established, and this one sends a complete identity and message +// YMQ should be able to recover from a poorly-behaved client like this +TEST(CcYmqTestSuite, TestClientSendIncompleteIdentity) +{ + auto host = "localhost"; + auto port = 2896; + + auto result = test( + 20, + {[=] { return client_sends_incomplete_identity(host, port); }, [=] { return basic_server_ymq(host, port); }}); + EXPECT_EQ(result, TestResult::Success); +} + +// TODO: this should pass +// currently YMQ rejects the second connection, saying that the message is too large even when it isn't +// +// in this test, the client sends an unrealistically-large header +// it is important that YMQ checks the header size before allocating memory +// both for resilence against attacks and to guard against errors +TEST(CcYmqTestSuite, TestClientSendHugeHeader) +{ + auto host = "localhost"; + auto port = 2897; + + auto result = test( + 20, + {[=] { return client_sends_huge_header(host, port); }, + [=] { return server_receives_huge_header(host, port); }}); + EXPECT_EQ(result, TestResult::Success); +} + +// in this test, the client sends empty messages to the server +// there are in effect two kinds of empty messages: Bytes() and Bytes("") +// in the former case, the bytes contains a nullptr +// in the latter case, the bytes contains a zero-length allocation +// it's important that the behaviour of YMQ is known for both of these cases +TEST(CcYmqTestSuite, TestClientSendEmptyMessage) +{ + auto host = "localhost"; + auto port = 2898; + + auto result = test( + 20, + {[=] { return client_sends_empty_messages(host, port); }, + [=] { return server_receives_empty_messages(host, port); }}); + EXPECT_EQ(result, TestResult::Success); +} + +// this case tests the publish-subscribe pattern of YMQ +// we create one publisher and two subscribers with a common topic +// the publisher will send two messages to the wrong topic +// none of the subscribers should receive these +// and then the publisher will send a message to the correct topic +// both subscribers should receive this message +TEST(CcYmqTestSuite, TestPubSub) +{ + auto host = "localhost"; + auto port = 2900; + auto topic = "mytopic"; + + auto result = test( + 20, + {[=] { return pubsub_publisher(host, port, topic); }, + [=] { return pubsub_subscriber(host, port, topic, 0); }, + [=] { return pubsub_subscriber(host, port, topic, 1); }}); + EXPECT_EQ(result, TestResult::Success); +} From 6b02ea3aaae144f09c4eaaac4aafedc28c32d87f Mon Sep 17 00:00:00 2001 From: sharpener6 <1sc2l4qi@duck.com> Date: Thu, 2 Oct 2025 15:05:01 -0400 Subject: [PATCH 02/23] Fix publish with missing shell entry (#270) --- .github/actions/publish-pypi/action.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/actions/publish-pypi/action.yml b/.github/actions/publish-pypi/action.yml index 4207d119c..1c1693666 100644 --- a/.github/actions/publish-pypi/action.yml +++ b/.github/actions/publish-pypi/action.yml @@ -37,5 +37,6 @@ runs: CIBW_MANYLINUX_X86_64_IMAGE: "manylinux_2_28" - name: Publish to PyPI + shell: bash run: | twine upload -r scaler -u "$PYPI_USERNAME" -p "$PYPI_PASSWORD" dist/* From cbbb14b5315ee59af063aef97f2c8e82a3b9e117 Mon Sep 17 00:00:00 2001 From: magniloquency <197707854+magniloquency@users.noreply.github.com> Date: Thu, 2 Oct 2025 15:28:36 -0400 Subject: [PATCH 03/23] Move Object Storage Server Tests to (#269) --- tests/CMakeLists.txt | 5 ++--- tests/cpp/CMakeLists.txt | 2 ++ tests/{ => cpp}/object_storage/CMakeLists.txt | 0 tests/{ => cpp}/object_storage/test_object_manager.cpp | 0 .../{ => cpp}/object_storage/test_object_storage_server.cpp | 0 5 files changed, 4 insertions(+), 3 deletions(-) rename tests/{ => cpp}/object_storage/CMakeLists.txt (100%) rename tests/{ => cpp}/object_storage/test_object_manager.cpp (100%) rename tests/{ => cpp}/object_storage/test_object_storage_server.cpp (100%) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index b62b86fd4..f6294d570 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -32,10 +32,9 @@ function(add_test_executable test_name source_file) endfunction() if(LINUX OR APPLE) - # This directory fetches Google Test, so it must be included first. - add_subdirectory(object_storage) + # Add the directory for the C++ tests. + add_subdirectory(cpp) # Add the new directory for io tests. add_subdirectory(io/ymq) - add_subdirectory(cpp) endif() diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 26c48e9b0..6b6a872c0 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -1 +1,3 @@ +# this fetches Google Test, so it must be included first. +add_subdirectory(object_storage) add_subdirectory(ymq) \ No newline at end of file diff --git a/tests/object_storage/CMakeLists.txt b/tests/cpp/object_storage/CMakeLists.txt similarity index 100% rename from tests/object_storage/CMakeLists.txt rename to tests/cpp/object_storage/CMakeLists.txt diff --git a/tests/object_storage/test_object_manager.cpp b/tests/cpp/object_storage/test_object_manager.cpp similarity index 100% rename from tests/object_storage/test_object_manager.cpp rename to tests/cpp/object_storage/test_object_manager.cpp diff --git a/tests/object_storage/test_object_storage_server.cpp b/tests/cpp/object_storage/test_object_storage_server.cpp similarity index 100% rename from tests/object_storage/test_object_storage_server.cpp rename to tests/cpp/object_storage/test_object_storage_server.cpp From 8d43774bb7890cc2060a1c0cb4d3c93dbf1f4f8d Mon Sep 17 00:00:00 2001 From: sharpener6 <1sc2l4qi@duck.com> Date: Thu, 2 Oct 2025 16:15:01 -0400 Subject: [PATCH 04/23] Use pypa gh action to publish (#271) --- .github/actions/publish-pypi/action.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/actions/publish-pypi/action.yml b/.github/actions/publish-pypi/action.yml index 1c1693666..20e4576dc 100644 --- a/.github/actions/publish-pypi/action.yml +++ b/.github/actions/publish-pypi/action.yml @@ -37,6 +37,8 @@ runs: CIBW_MANYLINUX_X86_64_IMAGE: "manylinux_2_28" - name: Publish to PyPI - shell: bash - run: | - twine upload -r scaler -u "$PYPI_USERNAME" -p "$PYPI_PASSWORD" dist/* + uses: pypa/gh-action-pypi-publish@release/v1 + with: + user: ${{ secrets.PYPI_USERNAME }} + password: ${{ secrets.PYPI_PASSWORD }} + packages_dir: ./dist/ From f9a6b145860788f3ba2dc9c10cdbb7f5125f1ac1 Mon Sep 17 00:00:00 2001 From: sharpener6 <1sc2l4qi@duck.com> Date: Thu, 2 Oct 2025 18:19:59 -0400 Subject: [PATCH 05/23] Passing secrets to reusable actions (#272) --- .github/actions/publish-pypi/action.yml | 22 +++++++++++++++++----- .github/workflows/publish-artifact.yml | 8 +++++++- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/.github/actions/publish-pypi/action.yml b/.github/actions/publish-pypi/action.yml index 20e4576dc..c37c95671 100644 --- a/.github/actions/publish-pypi/action.yml +++ b/.github/actions/publish-pypi/action.yml @@ -5,6 +5,13 @@ inputs: os: description: "operating system" required: true + username: + description: "username" + required: true + password: + description: "password" + required: true + runs: using: "composite" @@ -37,8 +44,13 @@ runs: CIBW_MANYLINUX_X86_64_IMAGE: "manylinux_2_28" - name: Publish to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 - with: - user: ${{ secrets.PYPI_USERNAME }} - password: ${{ secrets.PYPI_PASSWORD }} - packages_dir: ./dist/ + shell: bash + run: | + twine upload -r pypi -u "${{ inputs.username }}" -p "${{ inputs.password }}" dist/* + + #- name: Publish to PyPI + # uses: pypa/gh-action-pypi-publish@release/v1 + # with: + # user: ${{ inputs.username }} + # password: ${{ inputs.password }} + # packages-dir: ./dist/ diff --git a/.github/workflows/publish-artifact.yml b/.github/workflows/publish-artifact.yml index 6367fe8d7..fe44c0c7e 100644 --- a/.github/workflows/publish-artifact.yml +++ b/.github/workflows/publish-artifact.yml @@ -3,6 +3,11 @@ name: Publish Python package to PyPI on: release: types: [ created ] + secrets: + PYPI_USERNAME: + required: true + PYPI_PASSWORD: + required: true permissions: contents: read @@ -18,7 +23,6 @@ jobs: permissions: id-token: write steps: - # - uses: ./.github/actions/checkout - name: Harden the runner (Audit all outbound calls) uses: step-security/harden-runner@0634a2670c59f64b4a01f0f96f84700a4088b9f0 # v2.12.0 with: @@ -34,3 +38,5 @@ jobs: - uses: ./.github/actions/publish-pypi with: os: ${{ runner.os }} + username: ${{ secrets.PYPI_USERNAME }} + password: ${{ secrets.PYPI_PASSWORD }} From 4e33a5a3d26f4bd73c5d00012303440d00506975 Mon Sep 17 00:00:00 2001 From: sharpener6 <1sc2l4qi@duck.com> Date: Fri, 3 Oct 2025 09:00:22 -0400 Subject: [PATCH 06/23] Remove username and password (#278) --- .github/actions/publish-pypi/action.yml | 10 +--------- .github/workflows/publish-artifact.yml | 7 ------- 2 files changed, 1 insertion(+), 16 deletions(-) diff --git a/.github/actions/publish-pypi/action.yml b/.github/actions/publish-pypi/action.yml index c37c95671..d9d5e7b93 100644 --- a/.github/actions/publish-pypi/action.yml +++ b/.github/actions/publish-pypi/action.yml @@ -5,12 +5,6 @@ inputs: os: description: "operating system" required: true - username: - description: "username" - required: true - password: - description: "password" - required: true runs: @@ -46,11 +40,9 @@ runs: - name: Publish to PyPI shell: bash run: | - twine upload -r pypi -u "${{ inputs.username }}" -p "${{ inputs.password }}" dist/* + twine upload -r pypi dist/* #- name: Publish to PyPI # uses: pypa/gh-action-pypi-publish@release/v1 # with: - # user: ${{ inputs.username }} - # password: ${{ inputs.password }} # packages-dir: ./dist/ diff --git a/.github/workflows/publish-artifact.yml b/.github/workflows/publish-artifact.yml index fe44c0c7e..ff5c83675 100644 --- a/.github/workflows/publish-artifact.yml +++ b/.github/workflows/publish-artifact.yml @@ -3,11 +3,6 @@ name: Publish Python package to PyPI on: release: types: [ created ] - secrets: - PYPI_USERNAME: - required: true - PYPI_PASSWORD: - required: true permissions: contents: read @@ -38,5 +33,3 @@ jobs: - uses: ./.github/actions/publish-pypi with: os: ${{ runner.os }} - username: ${{ secrets.PYPI_USERNAME }} - password: ${{ secrets.PYPI_PASSWORD }} From f7746e8a7f068dab986b55587fad0ef74dd64cd2 Mon Sep 17 00:00:00 2001 From: sharpener6 <1sc2l4qi@duck.com> Date: Fri, 3 Oct 2025 09:37:12 -0400 Subject: [PATCH 07/23] Switch to github action to publish to pypi (#279) --- .github/actions/publish-pypi/action.yml | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/actions/publish-pypi/action.yml b/.github/actions/publish-pypi/action.yml index d9d5e7b93..7cdcdb171 100644 --- a/.github/actions/publish-pypi/action.yml +++ b/.github/actions/publish-pypi/action.yml @@ -37,12 +37,12 @@ runs: CIBW_SKIP: "pp* cp39-*" CIBW_MANYLINUX_X86_64_IMAGE: "manylinux_2_28" - - name: Publish to PyPI - shell: bash - run: | - twine upload -r pypi dist/* - #- name: Publish to PyPI - # uses: pypa/gh-action-pypi-publish@release/v1 - # with: - # packages-dir: ./dist/ + # shell: bash + # run: | + # twine upload -r pypi dist/* + + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + packages-dir: ./dist/ From c640e111ee2ae6624e8ba59675affc701a27f945 Mon Sep 17 00:00:00 2001 From: sharpener6 <1sc2l4qi@duck.com> Date: Fri, 3 Oct 2025 13:40:05 -0400 Subject: [PATCH 08/23] Try login to github (#280) --- .github/actions/publish-pypi/action.yml | 13 +++++++++---- .github/workflows/publish-artifact.yml | 1 + 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/.github/actions/publish-pypi/action.yml b/.github/actions/publish-pypi/action.yml index 7cdcdb171..ae782cfe6 100644 --- a/.github/actions/publish-pypi/action.yml +++ b/.github/actions/publish-pypi/action.yml @@ -5,6 +5,9 @@ inputs: os: description: "operating system" required: true + github_token: + description: "github token" + required: true runs: @@ -37,10 +40,12 @@ runs: CIBW_SKIP: "pp* cp39-*" CIBW_MANYLINUX_X86_64_IMAGE: "manylinux_2_28" - #- name: Publish to PyPI - # shell: bash - # run: | - # twine upload -r pypi dist/* + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ inputs.github_token }} - name: Publish to PyPI uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.github/workflows/publish-artifact.yml b/.github/workflows/publish-artifact.yml index ff5c83675..e7b818aa1 100644 --- a/.github/workflows/publish-artifact.yml +++ b/.github/workflows/publish-artifact.yml @@ -33,3 +33,4 @@ jobs: - uses: ./.github/actions/publish-pypi with: os: ${{ runner.os }} + github_token: ${{ secrets.GITHUB_TOKEN }} From 61428e41d193105f3f05fecf7351424ed29b8c06 Mon Sep 17 00:00:00 2001 From: sharpener6 <1sc2l4qi@duck.com> Date: Fri, 3 Oct 2025 17:08:13 -0400 Subject: [PATCH 09/23] Move publish-artifact to top level (#283) - remove execinfo.h for musllinux - move all configurations from github workflow to pyproject.toml --- .github/actions/create-artifacts/action.yml | 29 +++++++++++ .github/actions/publish-pypi/action.yml | 53 --------------------- .github/workflows/publish-artifact.yml | 8 +++- pyproject.toml | 18 +++++++ scaler/io/ymq/common.h | 28 +---------- 5 files changed, 54 insertions(+), 82 deletions(-) create mode 100644 .github/actions/create-artifacts/action.yml delete mode 100644 .github/actions/publish-pypi/action.yml diff --git a/.github/actions/create-artifacts/action.yml b/.github/actions/create-artifacts/action.yml new file mode 100644 index 000000000..bc6365f7a --- /dev/null +++ b/.github/actions/create-artifacts/action.yml @@ -0,0 +1,29 @@ +name: create-artifacts +description: Build Artifacts + +inputs: + os: + description: "operating system" + required: true + + +runs: + using: "composite" + steps: + - name: Install Python Packages + if: inputs.os == 'Linux' + shell: bash + run: | + uv pip install --system --upgrade build cibuildwheel twine + rm -rf dist + + - name: Build Sdist + if: inputs.os == 'Linux' + shell: bash + run: | + python -m build --sdist + + - name: Build Wheel (Linux) + if: inputs.os == 'Linux' + shell: bash + run: python -m cibuildwheel --output-dir dist diff --git a/.github/actions/publish-pypi/action.yml b/.github/actions/publish-pypi/action.yml deleted file mode 100644 index ae782cfe6..000000000 --- a/.github/actions/publish-pypi/action.yml +++ /dev/null @@ -1,53 +0,0 @@ -name: publish-pypi -description: publish-pypi - -inputs: - os: - description: "operating system" - required: true - github_token: - description: "github token" - required: true - - -runs: - using: "composite" - steps: - - name: Install Python Packages - if: inputs.os == 'Linux' - shell: bash - run: | - uv pip install --system --upgrade build cibuildwheel twine - rm -rf dist - - - name: Build Sdist - if: inputs.os == 'Linux' - shell: bash - run: | - python -m build --sdist - - - name: Build Wheel (Linux) - if: inputs.os == 'Linux' - shell: bash - run: python -m cibuildwheel --output-dir dist - env: - CIBW_BEFORE_ALL_LINUX: | - echo "Building deps" - yum install -y sudo; - sudo ./scripts/download_install_libraries.sh capnp compile - sudo ./scripts/download_install_libraries.sh capnp install - CIBW_BUILD: "*manylinux_x86_64" - CIBW_SKIP: "pp* cp39-*" - CIBW_MANYLINUX_X86_64_IMAGE: "manylinux_2_28" - - - name: Login to GitHub Container Registry - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ inputs.github_token }} - - - name: Publish to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 - with: - packages-dir: ./dist/ diff --git a/.github/workflows/publish-artifact.yml b/.github/workflows/publish-artifact.yml index e7b818aa1..43d0bf8b4 100644 --- a/.github/workflows/publish-artifact.yml +++ b/.github/workflows/publish-artifact.yml @@ -30,7 +30,11 @@ jobs: with: os: ${{ runner.os }} - - uses: ./.github/actions/publish-pypi + - uses: ./.github/actions/create-artifacts with: os: ${{ runner.os }} - github_token: ${{ secrets.GITHUB_TOKEN }} + + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + packages-dir: ./dist/ diff --git a/pyproject.toml b/pyproject.toml index bf7432bb3..819fb9b82 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -124,3 +124,21 @@ skip-magic-trailing-comma = true [tool.isort] profile = "black" line_length = 120 + +[tool.cibuildwheel] +skip = "cp39-*" + +[tool.cibuildwheel.linux] +before-all = """ +yum install -y sudo; \ +sudo ./scripts/download_install_libraries.sh capnp compile; \ +sudo ./scripts/download_install_libraries.sh capnp install; +""" + +[[tool.cibuildwheel.overrides]] +select = "*-musllinux*" +before-all = """ +apk add sudo; \ +sudo ./scripts/download_install_libraries.sh capnp compile; \ +sudo ./scripts/download_install_libraries.sh capnp install; +""" diff --git a/scaler/io/ymq/common.h b/scaler/io/ymq/common.h index 9f0ec7c16..56617595b 100644 --- a/scaler/io/ymq/common.h +++ b/scaler/io/ymq/common.h @@ -1,38 +1,12 @@ #pragma once -// C -#ifdef __linux__ -#include -#endif // __linux__ - // C++ #include #include -#include -#include -#include -#include +#include using Errno = int; -inline void print_trace(void) -{ -#ifdef __linux__ - void* array[10]; - char** strings; - int size, i; - - size = backtrace(array, 10); - strings = backtrace_symbols(array, size); - if (strings != NULL) { - printf("Obtained %d stack frames.\n", size); - for (i = 0; i < size; i++) - printf("%s\n", strings[i]); - } - - free(strings); -#endif // __linux__ -} [[nodiscard("Memory is allocated but not used, likely causing a memory leak")]] inline uint8_t* datadup(const uint8_t* data, size_t len) noexcept From 0203eca619f792b7bcde4384a531fbbe0628fd54 Mon Sep 17 00:00:00 2001 From: sharpener6 <1sc2l4qi@duck.com> Date: Sat, 4 Oct 2025 01:36:28 -0400 Subject: [PATCH 10/23] Disable pypy and specify image version (#284) - rename download_install_libraries script to library_tool - separate compile command into download and compile, this is for cibuildwheel tool not download for multiple times when building wheel for different arch and library versions - library_tool.sh download now is cached, so ideally, we should only download once unless we change the 3rd party library version --- .../3rd-party-libraries-compile/action.yml | 30 ++++++ .../3rd-party-libraries-download/action.yml | 30 ++++++ .../3rd-party-libraries-install/action.yml | 22 +++++ .../action.yml | 6 +- .github/actions/create-artifacts/action.yml | 1 - .github/actions/setup-env/action.yml | 35 +------ .github/workflows/build-and-test.yml | 14 ++- .github/workflows/publish-artifact.yml | 4 + README.md | 90 ++++++++++------- pyproject.toml | 22 ++++- scripts/download_install_libraries.ps1 | 86 ---------------- scripts/download_install_libraries.sh | 74 -------------- scripts/library_tool.ps1 | 99 +++++++++++++++++++ scripts/library_tool.sh | 90 +++++++++++++++++ 14 files changed, 365 insertions(+), 238 deletions(-) create mode 100644 .github/actions/3rd-party-libraries-compile/action.yml create mode 100644 .github/actions/3rd-party-libraries-download/action.yml create mode 100644 .github/actions/3rd-party-libraries-install/action.yml rename .github/actions/{compile-library => compile-libraries}/action.yml (81%) delete mode 100755 scripts/download_install_libraries.ps1 delete mode 100755 scripts/download_install_libraries.sh create mode 100755 scripts/library_tool.ps1 create mode 100755 scripts/library_tool.sh diff --git a/.github/actions/3rd-party-libraries-compile/action.yml b/.github/actions/3rd-party-libraries-compile/action.yml new file mode 100644 index 000000000..663e5be5d --- /dev/null +++ b/.github/actions/3rd-party-libraries-compile/action.yml @@ -0,0 +1,30 @@ +name: "3rd Party Libraries Compile" +description: "3rd Party Libraries Compile" + +inputs: + os: + description: "operating system" + required: true + +runs: + using: "composite" + steps: + - name: Cache Library Compile + uses: actions/cache@v4 + id: compiled-libraries + with: + path: | + capnproto-* + key: compiled-libraries-${{ inputs.os }}-${{ hashFiles('scripts/library_tool.*') }} + + - name: Compile Libraries (Linux) + shell: bash + if: (inputs.os == 'Linux') && (steps.compiled-libraries.outputs.cache-hit != 'true') + run: | + ./scripts/library_tool.sh capnp compile + + - name: Compile Libraries (Windows) + shell: pwsh + if: (inputs.os == 'Windows') && (steps.compiled-libraries.outputs.cache-hit != 'true') + run: | + ./scripts/library_tool.ps1 capnp compile diff --git a/.github/actions/3rd-party-libraries-download/action.yml b/.github/actions/3rd-party-libraries-download/action.yml new file mode 100644 index 000000000..1c810e3d3 --- /dev/null +++ b/.github/actions/3rd-party-libraries-download/action.yml @@ -0,0 +1,30 @@ +name: "3rd Party Libraries Download" +description: "3rd Party Libraries Download" + +inputs: + os: + description: "operating system" + required: true + +runs: + using: "composite" + steps: + - name: Cache Library Download + uses: actions/cache@v4 + id: download-libraries + with: + path: | + downloaded + key: download-libraries-${{ inputs.os }}-${{ hashFiles('scripts/library_tool.*') }} + + - name: Download Libraries (Linux) + shell: bash + if: (inputs.os == 'Linux') && (steps.download-libraries.outputs.cache-hit != 'true') + run: | + ./scripts/library_tool.sh capnp download + + - name: Download Libraries (Windows) + shell: pwsh + if: (inputs.os == 'Windows') && (steps.download-libraries.outputs.cache-hit != 'true') + run: | + ./scripts/library_tool.ps1 capnp download diff --git a/.github/actions/3rd-party-libraries-install/action.yml b/.github/actions/3rd-party-libraries-install/action.yml new file mode 100644 index 000000000..8acada99b --- /dev/null +++ b/.github/actions/3rd-party-libraries-install/action.yml @@ -0,0 +1,22 @@ +name: "3rd Party Libraries Install" +description: "3rd Party Libraries Install" + +inputs: + os: + description: "operating system" + required: true + +runs: + using: "composite" + steps: + - name: Install Libraries (Linux) + shell: bash + if: inputs.os == 'Linux' + run: | + sudo ./scripts/library_tool.sh capnp install + + - name: Install Libraries (Windows) + shell: pwsh + if: inputs.os == 'Windows' + run: | + ./scripts/library_tool.ps1 capnp install diff --git a/.github/actions/compile-library/action.yml b/.github/actions/compile-libraries/action.yml similarity index 81% rename from .github/actions/compile-library/action.yml rename to .github/actions/compile-libraries/action.yml index 1036b1524..f253e2f80 100644 --- a/.github/actions/compile-library/action.yml +++ b/.github/actions/compile-libraries/action.yml @@ -1,5 +1,5 @@ -name: 'Build Wheel' -description: 'Build wheel' +name: 'Compile Libraries' +description: 'Compile Libraries' inputs: os: @@ -13,7 +13,7 @@ runs: if: inputs.os == 'Linux' shell: bash run: | - CXX=$(which g++-14) ./scripts/build.sh + CXX=$(which g++-14) ./scripts/build.sh - name: Build and test C++ Components (Windows) if: inputs.os == 'Windows' diff --git a/.github/actions/create-artifacts/action.yml b/.github/actions/create-artifacts/action.yml index bc6365f7a..d37441e91 100644 --- a/.github/actions/create-artifacts/action.yml +++ b/.github/actions/create-artifacts/action.yml @@ -6,7 +6,6 @@ inputs: description: "operating system" required: true - runs: using: "composite" steps: diff --git a/.github/actions/setup-env/action.yml b/.github/actions/setup-env/action.yml index 0108987d9..1ed9ff343 100644 --- a/.github/actions/setup-env/action.yml +++ b/.github/actions/setup-env/action.yml @@ -20,7 +20,7 @@ runs: run: | sudo apt update -y sudo apt install -y build-essential tzdata cmake clang curl pkg-config g++-14 - sudo chmod 755 ./scripts/download_install_libraries.sh + sudo chmod 755 ./scripts/library_tool.sh sudo chmod 755 ./scripts/build.sh - name: Install Python Base Packages @@ -28,36 +28,3 @@ runs: run: | pip install uv uv pip install --system --upgrade pip - - - name: Cache Library Install - if: inputs.os == 'Linux' - id: cache-library - uses: actions/cache@v4 - with: - path: | - capnproto-* - key: ${{ inputs.os }}-${{ hashFiles('scripts/download_install_libraries.*') }} - - - name: Download/Compile Libraries (Linux) - shell: bash - if: (inputs.os == 'Linux') && (steps.cache-library.outputs.cache-hit != 'true') - run: | - sudo ./scripts/download_install_libraries.sh capnp compile - - - name: Download/Compile Libraries (Windows) - shell: pwsh - if: (inputs.os == 'Windows') && (steps.cache-boost-windows.outputs.cache-hit != 'true') - run: | - ./scripts/download_install_libraries.ps1 capnp compile - - - name: Install Libraries (Linux) - shell: bash - if: inputs.os == 'Linux' - run: | - sudo ./scripts/download_install_libraries.sh capnp install - - - name: Install Libraries (Windows) - shell: pwsh - if: inputs.os == 'Windows' - run: | - ./scripts/download_install_libraries.ps1 capnp install diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index d25e06b70..b7a4dd574 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -30,7 +30,19 @@ jobs: with: os: ${{ runner.os }} - - uses: ./.github/actions/compile-library + - uses: ./.github/actions/3rd-party-libraries-download + with: + os: ${{ runner.os }} + + - uses: ./.github/actions/3rd-party-libraries-compile + with: + os: ${{ runner.os }} + + - uses: ./.github/actions/3rd-party-libraries-install + with: + os: ${{ runner.os }} + + - uses: ./.github/actions/compile-libraries with: os: ${{ runner.os }} diff --git a/.github/workflows/publish-artifact.yml b/.github/workflows/publish-artifact.yml index 43d0bf8b4..e5b09625d 100644 --- a/.github/workflows/publish-artifact.yml +++ b/.github/workflows/publish-artifact.yml @@ -30,6 +30,10 @@ jobs: with: os: ${{ runner.os }} + - uses: ./.github/actions/3rd-party-libraries-download + with: + os: ${{ runner.os }} + - uses: ./.github/actions/create-artifacts with: os: ${{ runner.os }} diff --git a/README.md b/README.md index 1eef78e89..9d98b2a24 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,8 @@
-**OpenGRIS Scaler provides a simple, efficient, and reliable way to perform distributed computing** using a centralized scheduler, +**OpenGRIS Scaler provides a simple, efficient, and reliable way to perform distributed computing** using a centralized +scheduler, with a stable and language-agnostic protocol for client and worker communications. ```python @@ -43,7 +44,8 @@ with Client(address="tcp://127.0.0.1:2345") as client: print(sum(results)) # 661.46 ``` -OpenGRIS Scaler is a suitable Dask replacement, offering significantly better scheduling performance for jobs with a large number +OpenGRIS Scaler is a suitable Dask replacement, offering significantly better scheduling performance for jobs with a +large number of lightweight tasks while improving on load balancing, messaging, and deadlocks. ## Features @@ -51,15 +53,16 @@ of lightweight tasks while improving on load balancing, messaging, and deadlocks - Distributed computing across **multiple cores and multiple servers** - **Python** reference implementation, with **language-agnostic messaging protocol** built on top of [Cap'n Proto](https://capnproto.org/) and [ZeroMQ](https://zeromq.org) -- **Graph** scheduling, which supports [Dask](https://www.dask.org)-like graph computing, with optional [GraphBLAS](https://graphblas.org) +- **Graph** scheduling, which supports [Dask](https://www.dask.org)-like graph computing, with + optional [GraphBLAS](https://graphblas.org) support for very large graph tasks -- **Automated load balancing**, which automatically balances load from busy workers to idle workers, ensuring uniform utilization across workers +- **Automated load balancing**, which automatically balances load from busy workers to idle workers, ensuring uniform + utilization across workers - **Automated task recovery** from worker-related hardware, OS, or network failures - Support for **nested tasks**, allowing tasks to submit new tasks - `top`-like **monitoring tools** - GUI monitoring tool - ## Installation Scaler is available on PyPI and can be installed using any compatible package manager. @@ -79,10 +82,12 @@ Scaler has 3 main components: - A **scheduler**, responsible for routing tasks to available computing resources. - An **object storage server** that stores the task data objects (task arguments and task results). -- A set of **workers** that form a _cluster_. Workers are independent computing units, each capable of executing a single task. +- A set of **workers** that form a _cluster_. Workers are independent computing units, each capable of executing a + single task. - **Clients** running inside applications, responsible for submitting tasks to the scheduler. -Please be noted that **Clients** are cross platform, supporting Windows and GNU/Linux, while other components can only be run on GNU/Linux. +Please be noted that **Clients** are cross platform, supporting Windows and GNU/Linux, while other components can only +be run on GNU/Linux. ### Start local scheduler and cluster programmatically in code @@ -136,7 +141,8 @@ $ scaler_cluster -n 4 tcp://127.0.0.1:2345 ... ``` -Multiple Scaler clusters can be connected to the same scheduler, providing distributed computation over multiple servers. +Multiple Scaler clusters can be connected to the same scheduler, providing distributed computation over multiple +servers. `-h` lists the available options for the object storage server, scheduler and the cluster executables: @@ -159,8 +165,8 @@ def square(value: int): with Client(address="tcp://127.0.0.1:2345") as client: - future = client.submit(square, 4) # submits a single task - print(future.result()) # 16 + future = client.submit(square, 4) # submits a single task + print(future.result()) # 16 ``` `Client.submit()` returns a standard Python future. @@ -221,7 +227,7 @@ def fibonacci(client: Client, n: int): with Client(address="tcp://127.0.0.1:2345") as client: future = client.submit(fibonacci, client, 8) - print(future.result()) # 21 + print(future.result()) # 21 ``` ## Task Routing and Capability Management @@ -266,7 +272,8 @@ with Client(address="tcp://127.0.0.1:2345") as client: The scheduler will route a task to a worker if `task.capabilities.is_subset(worker.capabilities)`. -Integer values specified for capabilities (e.g., `gpu=10`) are *currently* ignored by the capabilities allocation policy. +Integer values specified for capabilities (e.g., `gpu=10`) are *currently* ignored by the capabilities allocation +policy. This means that the presence of a capabilities is considered, but not its quantity. Support for capabilities tracking might be added in the future. @@ -307,6 +314,7 @@ class Message(soamapi.Message): def on_deserialize(self, stream): self.set_payload(stream.read_byte_array("b")) + class ServiceContainer(soamapi.ServiceContainer): def on_create_service(self, service_context): return @@ -330,6 +338,7 @@ class ServiceContainer(soamapi.ServiceContainer): def on_destroy_service(self): return ``` + ### Nested tasks @@ -354,7 +363,8 @@ A good heuristic for setting the base concurrency is to use the following formul base_concurrency = number_of_cores - deepest_nesting_level ``` -where `deepest_nesting_level` is the deepest nesting level a task has in your workload. For instance, if you have a workload that has +where `deepest_nesting_level` is the deepest nesting level a task has in your workload. For instance, if you have a +workload that has a base task that calls a nested task that calls another nested task, then the deepest nesting level is 2. ## Worker Adapter usage @@ -364,7 +374,8 @@ a base task that calls a nested task that calls another nested task, then the de Scaler provides a Worker Adapter webhook interface to integrate with other job schedulers or resource managers. The Worker Adapter allows external systems to request the creation and termination of Scaler workers dynamically. -Please check the OpenGRIS standard for more details on the Worker Adapter specification [here](https://github.com/finos/opengris). +Please check the OpenGRIS standard for more details on the Worker Adapter +specification [here](https://github.com/finos/opengris). ### Starting the Native Worker Adapter @@ -443,8 +454,9 @@ W|Linux|15943|a7fe8b5e+ 0.0% 30.7m 0.0% 28.3m 1000 0 0 | - **scheduler_sent** section shows count for each type of messages scheduler sent - **scheduler_received** section shows count for each type of messages scheduler received - **function_id_to_tasks** section shows task count for each function used -- **worker** section shows worker details, , you can use shortcuts to sort by columns, and the * in the column header shows -which column is being used for sorting +- **worker** section shows worker details, , you can use shortcuts to sort by columns, and the * in the column header + shows + which column is being used for sorting - `agt_cpu/agt_rss` means cpu/memory usage of worker agent - `cpu/rss` means cpu/memory usage of worker - `free` means number of free task slots for this worker @@ -473,13 +485,13 @@ We showcased Scaler at FOSDEM 2025. Check out the slides To contribute to Scaler, you might need to manually build its C++ components. These C++ components depend on the Boost and Cap'n Proto libraries. If these libraries are not available on your system, -you can use the `download_install_libraries.sh` script to download, compile, and install them (You might need `sudo`): +you can use the `library_tool.sh` script to download, compile, and install them (You might need `sudo`): ```bash -./scripts/download_install_libraries.sh boost compile -./scripts/download_install_libraries.sh boost install -./scripts/download_install_libraries.sh capnp compile -./scripts/download_install_libraries.sh capnp install +./scripts/library_tool.sh boost compile +./scripts/library_tool.sh boost install +./scripts/library_tool.sh capnp compile +./scripts/library_tool.sh capnp install ``` After installing these dependencies, use the `build.sh` script to configure, build, and install Scaler's C++ components: @@ -494,16 +506,19 @@ within the main source tree, as compiled Python modules. You can specify the com ### Building on Windows -*Building on Windows requires _Visual Studio 17 2022_*. Similar to the former section, you can use the `download_install_libraries.ps1` script to download, compile, and install them (You might need `Run as administrator`): +*Building on Windows requires _Visual Studio 17 2022_*. Similar to the former section, you can use the +`library_tool.ps1` script to download, compile, and install them (You might need `Run as administrator`): ```bash -./scripts/download_install_libraries.ps1 boost compile -./scripts/download_install_libraries.ps1 boost install -./scripts/download_install_libraries.ps1 capnp compile -./scripts/download_install_libraries.ps1 capnp install +./scripts/library_tool.ps1 boost compile +./scripts/library_tool.ps1 boost install +./scripts/library_tool.ps1 capnp compile +./scripts/library_tool.ps1 capnp install ``` -After installing these dependencies, if you are using _Visual Studio_ for developing, you may open the project folder with it, select preset `windows-x64`, and build the project. You may also run the following commands to configure, build, and install Scaler's C++ components: +After installing these dependencies, if you are using _Visual Studio_ for developing, you may open the project folder +with it, select preset `windows-x64`, and build the project. You may also run the following commands to configure, +build, and install Scaler's C++ components: ```bash cmake --preset windows-x64 @@ -511,7 +526,8 @@ cmake --build --preset windows-x64 --config (Debug|Release) cmake --install build_windows_x64 --config (Debug|Release) ``` -The output will be similar to what described in the former section. We recommend using _Visual Studio_ for developing on Windows. +The output will be similar to what described in the former section. We recommend using _Visual Studio_ for developing on +Windows. ### Building the Python wheel @@ -523,10 +539,10 @@ pip install build cibuildwheel==2.23.3 # Parametrize the cibuildwheel's container to build the Boost and Cap'n Proto dependencies. export CIBW_BEFORE_ALL=' yum install sudo -y; - sudo ./scripts/download_install_libraries.sh capnp compile - sudo ./scripts/download_install_libraries.sh capnp install - sudo ./scripts/download_install_libraries.sh boost compile - sudo ./scripts/download_install_libraries.sh boost install' + sudo ./scripts/library_tool.sh capnp compile + sudo ./scripts/library_tool.sh capnp install + sudo ./scripts/library_tool.sh boost compile + sudo ./scripts/library_tool.sh boost install' export CIBW_BUILD="*manylinux_x86_64" export CIBW_SKIP="pp*" export CIBW_MANYLINUX_X86_64_IMAGE="manylinux_2_28" @@ -549,7 +565,12 @@ We welcome you to: Please review [functional contribution guidelines](./CONTRIBUTING.md) to get started 👍. -_NOTE:_ Commits and pull requests to FINOS repositories will only be accepted from those contributors with an active, executed Individual Contributor License Agreement (ICLA) with FINOS OR contributors who are covered under an existing and active Corporate Contribution License Agreement (CCLA) executed with FINOS. Commits from individuals not covered under an ICLA or CCLA will be flagged and blocked by the ([EasyCLA](https://community.finos.org/docs/governance/Software-Projects/easycla)) tool. Please note that some CCLAs require individuals/employees to be explicitly named on the CCLA. +_NOTE:_ Commits and pull requests to FINOS repositories will only be accepted from those contributors with an active, +executed Individual Contributor License Agreement (ICLA) with FINOS OR contributors who are covered under an existing +and active Corporate Contribution License Agreement (CCLA) executed with FINOS. Commits from individuals not covered +under an ICLA or CCLA will be flagged and blocked by +the ([EasyCLA](https://community.finos.org/docs/governance/Software-Projects/easycla)) tool. Please note that some CCLAs +require individuals/employees to be explicitly named on the CCLA. *Need an ICLA? Unsure if you are covered under an existing CCLA? Email [help@finos.org](mailto:help@finos.org)* @@ -568,5 +589,6 @@ SPDX-License-Identifier: [Apache-2.0](https://spdx.org/licenses/Apache-2.0) ## Contact -If you have a query or require support with this project, [raise an issue](https://github.com/finos/opengris-scaler/issues). +If you have a query or require support with this +project, [raise an issue](https://github.com/finos/opengris-scaler/issues). Otherwise, reach out to [opensource@citi.com](mailto:opensource@citi.com). diff --git a/pyproject.toml b/pyproject.toml index 819fb9b82..5e1c0b1ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,19 +126,31 @@ profile = "black" line_length = 120 [tool.cibuildwheel] -skip = "cp39-*" +skip = "pp* cp39-*" [tool.cibuildwheel.linux] +archs = ["x86_64"] +manylinux-x86_64-image = "manylinux_2_28" before-all = """ yum install -y sudo; \ -sudo ./scripts/download_install_libraries.sh capnp compile; \ -sudo ./scripts/download_install_libraries.sh capnp install; +sudo ./scripts/library_tool.sh capnp compile; \ +sudo ./scripts/library_tool.sh capnp install; """ [[tool.cibuildwheel.overrides]] select = "*-musllinux*" before-all = """ apk add sudo; \ -sudo ./scripts/download_install_libraries.sh capnp compile; \ -sudo ./scripts/download_install_libraries.sh capnp install; +sudo ./scripts/library_tool.sh capnp compile; \ +sudo ./scripts/library_tool.sh capnp install; """ + +#[tool.cibuildwheel.macos] +#archs = ["x86_64", "arm64"] +#before-all = """ +#yum install -y sudo; \ +#sudo ./scripts/library_tool.sh capnp compile; \ +#sudo ./scripts/library_tool.sh capnp install; +#""" + + diff --git a/scripts/download_install_libraries.ps1 b/scripts/download_install_libraries.ps1 deleted file mode 100755 index 1588ac6fd..000000000 --- a/scripts/download_install_libraries.ps1 +++ /dev/null @@ -1,86 +0,0 @@ -# Constants -$BOOST_VERSION = "1.88.0" -$CAPNP_VERSION = "1.1.0" -$PREFIX = "C:\Program Files" - -# Parse optional --prefix argument from $args -foreach ($arg in $args) { - if ($arg -match "^--prefix=(.+)$") { - $PREFIX = $matches[1] - } -} - -# Get the number of cores -$NUM_CORES = [Environment]::ProcessorCount - -[Environment]::SetEnvironmentVariable("Path", - [Environment]::GetEnvironmentVariable("Path", - [EnvironmentVariableTarget]::Machine) + ";$PREFIX", - [EnvironmentVariableTarget]::Machine) - -# Main logic -if ($args.Count -lt 2) { - Write-Host "Usage: .\download_install_libraries.ps1 [boost|capnp] [compile|install] [--prefix=DIR]" - exit 1 -} - -$dependency = $args[0] -$action = $args[1] - -# Download, compile, or install Boost -if ($dependency -eq "boost") { - if ($action -eq "compile") { - $BOOST_FOLDER_NAME = "boost_" + $BOOST_VERSION -replace '\.', '_' - $BOOST_PACKAGE_NAME = "$BOOST_FOLDER_NAME.tar.gz" - $url = "https://archives.boost.org/release/$BOOST_VERSION/source/$BOOST_PACKAGE_NAME" - - # Download and extract Boost - # Necessary exe because of local dev env - curl.exe -O $url --retry 100 --retry-max-time 3600 - tar -xzf $BOOST_PACKAGE_NAME - Rename-Item -Path $BOOST_FOLDER_NAME -NewName "boost" - } - elseif ($action -eq "install") { - Copy-Item -Recurse -Path "boost\boost" -Destination "$PREFIX\include\boost" - Write-Host "Installed Boost into $PREFIX\include\boost" - } - else { - Write-Host "Argument needs to be either compile or install" - exit 1 - } -} - -# Download, compile, or install Cap'n Proto -elseif ($dependency -eq "capnp") { - if ($action -eq "compile") { - $CAPNP_FOLDER_NAME = "capnproto-c++-$CAPNP_VERSION" - $CAPNP_PACKAGE_NAME = "$CAPNP_FOLDER_NAME.tar.gz" - $url = "https://capnproto.org/$CAPNP_PACKAGE_NAME" - - # Download and extract Cap'n Proto - curl.exe -O $url --retry 100 --retry-max-time 3600 - tar -xzf $CAPNP_PACKAGE_NAME - Rename-Item -Path $CAPNP_FOLDER_NAME -NewName "capnp" - - # Configure and build with Visual Studio using CMake - Set-Location -Path "capnp" - cmake -G "Visual Studio 17 2022" -B build - cmake --build build --config Release - } - elseif ($action -eq "install") { - $CAPNP_FOLDER_NAME = "capnproto-c++-$CAPNP_VERSION" - Rename-Item -Path $CAPNP_FOLDER_NAME -NewName "capnp" - Set-Location -Path "capnp" - cmake --install build --config Release --prefix $PREFIX - Write-Host "Installed capnp into $PREFIX" - } - else { - Write-Host "Argument needs to be either compile or install" - exit 1 - } - -else { - Write-Host "Usage: .\download_install_libraries.ps1 [boost|capnp] [--prefix=DIR]" - exit 1 -} - diff --git a/scripts/download_install_libraries.sh b/scripts/download_install_libraries.sh deleted file mode 100755 index 651269108..000000000 --- a/scripts/download_install_libraries.sh +++ /dev/null @@ -1,74 +0,0 @@ -#!/bin/bash -e -# This script builds and installs the required 3rd party C++ libraries. -# -# Usage: -# ./scripts/download_install_libraries.sh [boost|capnp] [compile|install] [--prefix=PREFIX] - -# Remember: -# Update the usage string when you are add/remove dependency -# Bump version should be done through variables, not hard coded strs. - -BOOST_VERSION="1.88.0" -CAPNP_VERSION="1.1.0" - -PREFIX="/usr/local" - -# Parse the optional --prefix= argument -for arg in "$@"; do - if [[ "$arg" == --prefix=* ]]; then - PREFIX="${arg#--prefix=}" - fi -done - -if [[ "$OSTYPE" == "linux-gnu"* ]]; then - NUM_CORES=$(nproc) -elif [[ "$OSTYPE" == "darwin"* ]]; then - NUM_CORES=$(sysctl -n hw.ncpu) -else - NUM_CORES=1 -fi - -PREFIX=`readlink -f $PREFIX` -mkdir -p ${PREFIX}/include/ - -show_help() { - echo "Usage: ./download_install_libraries.sh [boost|capnp] [compile|install] [--prefix=DIR]" - exit 1 -} - -if [ "$1" == "boost" ]; then - BOOST_FOLDER_NAME="boost_$(echo $BOOST_VERSION | tr '.' '_')" - if [ "$2" == "compile" ]; then - BOOST_PACKAGE_NAME=${BOOST_FOLDER_NAME}.tar.gz - curl -O https://archives.boost.io/release/${BOOST_VERSION}/source/${BOOST_PACKAGE_NAME} --retry 100 --retry-max-time 3600 - tar -xzf ${BOOST_PACKAGE_NAME} - - elif [ "$2" == "install" ]; then - cp -r ${BOOST_FOLDER_NAME}/boost ${PREFIX}/include/. - echo "Installed Boost into ${PREFIX}/include/boost" - - else - show_help - fi -elif [ "$1" == "capnp" ]; then - CAPNP_FOLDER_NAME="capnproto-c++-$(echo $CAPNP_VERSION)" - if [ "$2" == "compile" ]; then - CAPNP_PACKAGE_NAME=${CAPNP_FOLDER_NAME}.tar.gz - curl -O https://capnproto.org/${CAPNP_PACKAGE_NAME} --retry 100 --retry-max-time 3600 - tar -xzf ${CAPNP_PACKAGE_NAME} - - cd ${CAPNP_FOLDER_NAME} - ./configure --prefix=${PREFIX} CXXFLAGS="${CXXFLAGS} -I${PREFIX}/include" LDFLAGS="${LDFLAGS} -L${PREFIX}/lib -Wl,-rpath,${PREFIX}/lib" - make -j${NUM_CORES} - - elif [ "$2" == "install" ]; then - cd ${CAPNP_FOLDER_NAME} - make install - echo "Installed capnp into ${PREFIX}" - - else - show_help - fi -else - show_help -fi diff --git a/scripts/library_tool.ps1 b/scripts/library_tool.ps1 new file mode 100755 index 000000000..d1daf9976 --- /dev/null +++ b/scripts/library_tool.ps1 @@ -0,0 +1,99 @@ +# Constants +$BOOST_VERSION = "1.88.0" +$CAPNP_VERSION = "1.1.0" + +$DOWNLOAD_DIR = ".\downloaded" +$PREFIX = "C:\Program Files" + +# Parse optional --prefix argument from $args +foreach ($arg in $args) +{ + if ($arg -match "^--prefix=(.+)$") + { + $PREFIX = $matches[1] + } +} + +# Get the number of cores +$NUM_CORES = [Environment]::ProcessorCount + +[Environment]::SetEnvironmentVariable("Path", + [Environment]::GetEnvironmentVariable("Path", + [EnvironmentVariableTarget]::Machine) + ";$PREFIX", + [EnvironmentVariableTarget]::Machine) + +# Main logic +if ($args.Count -lt 2) +{ + Write-Host "Usage: .\library_tool.ps1 [boost|capnp] [download|compile|install] [--prefix=DIR]" + exit 1 +} + +$dependency = $args[0] +$action = $args[1] + +# Download, compile, or install Boost +if ($dependency -eq "boost") +{ + $BOOST_FOLDER_NAME = "boost_" + $BOOST_VERSION -replace '\.', '_' + + if ($action -eq "download") + { + mkdir "$DOWNLOAD_DIR" -Force + $url = "https://archives.boost.org/release/$BOOST_VERSION/source/$BOOST_FOLDER_NAME.tar.gz" + curl.exe --retry 100 --retry-max-time 3600 -L $url -o "$DOWNLOAD_DIR\$BOOST_FOLDER_NAME.tar.gz" + } + elseif ($action -eq "compile") + { + tar -xzvf "$DOWNLOAD_DIR\$BOOST_FOLDER_NAME.tar.gz" -C .\ + } + elseif ($action -eq "install") + { + Copy-Item -Recurse -Path "boost\boost" -Destination "$PREFIX\include\boost" + Write-Host "Installed Boost into $PREFIX\include\boost" + } + else + { + Write-Host "Argument needs to be download or compile or install" + exit 1 + } +} + +# Download, compile, or install Cap'n Proto +elseif ($dependency -eq "capnp") +{ + $CAPNP_FOLDER_NAME = "capnproto-c++-$CAPNP_VERSION" + + if ($action -eq "download") + { + mkdir "$DOWNLOAD_DIR" -Force + $url = "https://capnproto.org/$CAPNP_FOLDER_NAME.tar.gz" + curl.exe --retry 100 --retry-max-time 3600 -L $url -o "$DOWNLOAD_DIR\$CAPNP_FOLDER_NAME.tar.gz" + } + if ($action -eq "compile") + { + Remove-Item -Path "$CAPNP_FOLDER_NAME" -Recurse -Force + tar -xzvf "$DOWNLOAD_DIR\$CAPNP_FOLDER_NAME.tar.gz" -C .\ + + # Configure and build with Visual Studio using CMake + Set-Location -Path "$CAPNP_FOLDER_NAME" + cmake -G "Visual Studio 17 2022" -B build + cmake --build build --config Release + } + elseif ($action -eq "install") + { + Set-Location -Path "$CAPNP_FOLDER_NAME" + cmake --install build --config Release --prefix $PREFIX + Write-Host "Installed capnp into $PREFIX" + } + else + { + Write-Host "Argument needs to be download or compile or install" + exit 1 + } + + else { + Write-Host "Usage: .\library_tool.ps1 [download|boost|capnp] [--prefix=DIR]" + exit 1 + } +} diff --git a/scripts/library_tool.sh b/scripts/library_tool.sh new file mode 100755 index 000000000..2fd512ed7 --- /dev/null +++ b/scripts/library_tool.sh @@ -0,0 +1,90 @@ +#!/bin/bash -e +# This script builds and installs the required 3rd party C++ libraries. +# +# Usage: +# ./scripts/library_tool.sh [boost|capnp] [compile|install] [--prefix=PREFIX] + +# Remember: +# Update the usage string when you are add/remove dependency +# Bump version should be done through variables, not hard coded strs. + +set -x + +BOOST_VERSION="1.88.0" +CAPNP_VERSION="1.1.0" + +DOWNLOAD_DIR="./downloaded" +PREFIX="/usr/local" + +# Parse the optional --prefix= argument +for arg in "$@"; do + if [[ "$arg" == --prefix=* ]]; then + PREFIX="${arg#--prefix=}" + fi +done + +if [[ "$OSTYPE" == "linux-gnu"* ]]; then + NUM_CORES=$(nproc) +elif [[ "$OSTYPE" == "darwin"* ]]; then + NUM_CORES=$(sysctl -n hw.ncpu) +else + NUM_CORES=1 +fi + +PREFIX=$(readlink -f "${PREFIX}") +mkdir -p "${PREFIX}/include/" + +show_help() { + echo "Usage: ./library_tool.sh [boost|capnp] [download|compile|install] [--prefix=DIR]" + exit 1 +} + +if [ "$1" == "boost" ]; then + BOOST_FOLDER_NAME="boost_$(echo $BOOST_VERSION | tr '.' '_')" + + if [ "$2" == "download" ]; then + mkdir -p ${DOWNLOAD_DIR} + curl --retry 100 --retry-max-time 3600 \ + -L "https://archives.boost.io/release/${BOOST_VERSION}/source/${BOOST_FOLDER_NAME}.tar.gz" \ + -o "${DOWNLOAD_DIR}/${BOOST_FOLDER_NAME}.tar.gz" + echo "Downloaded Boost to ${DOWNLOAD_DIR}/${BOOST_FOLDER_NAME}.tar.gz" + + elif [ "$2" == "compile" ]; then + tar -xzvf "${DOWNLOAD_DIR}/${BOOST_FOLDER_NAME}.tar.gz" -C "./" + + elif [ "$2" == "install" ]; then + cp -r "${BOOST_FOLDER_NAME}/boost" "${PREFIX}/include/." + echo "Installed Boost into ${PREFIX}/include/boost" + + else + show_help + fi +elif [ "$1" == "capnp" ]; then + CAPNP_FOLDER_NAME="capnproto-c++-$(echo $CAPNP_VERSION)" + + if [ "$2" == "download" ]; then + mkdir -p ${DOWNLOAD_DIR} + curl --retry 100 --retry-max-time 3600 \ + -L "https://capnproto.org/${CAPNP_FOLDER_NAME}.tar.gz" \ + -o "${DOWNLOAD_DIR}/${CAPNP_FOLDER_NAME}.tar.gz" + echo "Downloaded capnp into ${DOWNLOAD_DIR}/${CAPNP_FOLDER_NAME}.tar.gz" + + elif [ "$2" == "compile" ]; then + rm -rf "${CAPNP_FOLDER_NAME}" + tar -xzf "${DOWNLOAD_DIR}/${CAPNP_FOLDER_NAME}.tar.gz" -C "./" + + cd "${CAPNP_FOLDER_NAME}" + ./configure --prefix="${PREFIX}" CXXFLAGS="${CXXFLAGS} -I${PREFIX}/include" LDFLAGS="${LDFLAGS} -L${PREFIX}/lib -Wl,-rpath,${PREFIX}/lib" + make -j "${NUM_CORES}" + + elif [ "$2" == "install" ]; then + cd "${CAPNP_FOLDER_NAME}" + make install + echo "Installed capnp into ${PREFIX}" + + else + show_help + fi +else + show_help +fi From 7f7f481e57fb2130d4bab2733718d224dd9b75fc Mon Sep 17 00:00:00 2001 From: magniloquency <197707854+magniloquency@users.noreply.github.com> Date: Sun, 5 Oct 2025 23:32:26 -0400 Subject: [PATCH 11/23] Make Pub-Sub Test More Reliable (#285) --- scaler/io/ymq/common.h | 3 +-- tests/cpp/ymq/common.h | 5 +--- tests/cpp/ymq/test_cc_ymq.cpp | 47 +++++++++++++++++++++++++++++------ 3 files changed, 42 insertions(+), 13 deletions(-) diff --git a/scaler/io/ymq/common.h b/scaler/io/ymq/common.h index 56617595b..2fd420f6e 100644 --- a/scaler/io/ymq/common.h +++ b/scaler/io/ymq/common.h @@ -1,13 +1,12 @@ #pragma once // C++ +#include #include #include -#include using Errno = int; - [[nodiscard("Memory is allocated but not used, likely causing a memory leak")]] inline uint8_t* datadup(const uint8_t* data, size_t len) noexcept { diff --git a/tests/cpp/ymq/common.h b/tests/cpp/ymq/common.h index 5fd9dad9f..029048b27 100644 --- a/tests/cpp/ymq/common.h +++ b/tests/cpp/ymq/common.h @@ -1,6 +1,5 @@ #pragma once -#include #include #include #include @@ -29,7 +28,6 @@ #include #include #include -#include #include #include #include @@ -243,8 +241,7 @@ inline void fork_wrapper(std::function fn, int timeout_secs, Owned // run a test // forks and runs each of the provided closures -inline TestResult test( - int timeout_secs, std::vector> closures) +inline TestResult test(int timeout_secs, std::vector> closures) { std::vector> pipes {}; std::vector pids {}; diff --git a/tests/cpp/ymq/test_cc_ymq.cpp b/tests/cpp/ymq/test_cc_ymq.cpp index f20321908..580a37bfa 100644 --- a/tests/cpp/ymq/test_cc_ymq.cpp +++ b/tests/cpp/ymq/test_cc_ymq.cpp @@ -5,8 +5,11 @@ // the test cases are at the bottom of this file, after the clients and servers // the documentation for each case is found on the TEST() definition +#include #include #include +#include +#include #include #include @@ -310,13 +313,23 @@ TestResult client_sends_empty_messages(std::string host, uint16_t port) return TestResult::Success; } -TestResult pubsub_subscriber(std::string host, uint16_t port, std::string topic, int differentiator) +TestResult pubsub_subscriber(std::string host, uint16_t port, std::string topic, int differentiator, sem_t* sem) { IOContext context(1); auto socket = syncCreateSocket(context, IOSocketType::Unicast, std::format("{}_subscriber_{}", topic, differentiator)); + + std::this_thread::sleep_for(500ms); + syncConnectSocket(socket, format_address(host, port)); + + std::this_thread::sleep_for(500ms); + + if (sem_post(sem) < 0) + throw std::system_error(errno, std::generic_category(), "failed to signal semaphore"); + sem_close(sem); + auto msg = syncRecvMessage(socket); RETURN_FAILURE_IF_FALSE(msg.has_value()); RETURN_FAILURE_IF_FALSE(msg->payload.as_string() == "hello topic " + topic); @@ -325,15 +338,21 @@ TestResult pubsub_subscriber(std::string host, uint16_t port, std::string topic, return TestResult::Success; } -TestResult pubsub_publisher(std::string host, uint16_t port, std::string topic) +// topic: the identifier of the topic, must match what's passed to the subscribers +// sem: a semaphore to synchronize the publisher and subscriber processes +// n: the number of subscribers +TestResult pubsub_publisher(std::string host, uint16_t port, std::string topic, sem_t* sem, int n) { IOContext context(1); auto socket = syncCreateSocket(context, IOSocketType::Multicast, "publisher"); syncBindSocket(socket, format_address(host, port)); - // wait a second to ensure that the subscribers are ready - std::this_thread::sleep_for(1s); + // wait for the subscribers to be ready + for (int i = 0; i < n; i++) + if (sem_wait(sem) < 0) + throw std::system_error(errno, std::generic_category(), "failed to wait on semaphore"); + sem_close(sem); // the topic is wrong, so no one should receive this auto error = syncSendMessage( @@ -513,10 +532,24 @@ TEST(CcYmqTestSuite, TestPubSub) auto port = 2900; auto topic = "mytopic"; + // allocate a semaphore to synchronize the publisher and subscriber processes + sem_t* sem = + static_cast(mmap(nullptr, sizeof(sem_t), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0)); + + if (sem == MAP_FAILED) + throw std::system_error(errno, std::generic_category(), "failed to map shared memory for semaphore"); + + if (sem_init(sem, 1, 0) < 0) + throw std::system_error(errno, std::generic_category(), "failed to initialize semaphore"); + auto result = test( 20, - {[=] { return pubsub_publisher(host, port, topic); }, - [=] { return pubsub_subscriber(host, port, topic, 0); }, - [=] { return pubsub_subscriber(host, port, topic, 1); }}); + {[=] { return pubsub_publisher(host, port, topic, sem, 2); }, + [=] { return pubsub_subscriber(host, port, topic, 0, sem); }, + [=] { return pubsub_subscriber(host, port, topic, 1, sem); }}); + + sem_destroy(sem); + munmap(sem, sizeof(sem_t)); + EXPECT_EQ(result, TestResult::Success); } From ecf7be7f2250abf80c999cf42876ceb0db1ce16d Mon Sep 17 00:00:00 2001 From: e117649 Date: Mon, 6 Oct 2025 12:16:21 -0400 Subject: [PATCH 12/23] Cluster worker init script support via --preload (#234) * Add support for --preload "foo.bar:preload_function(arg1, arg2)" * Update example with preload for linter check * Update tests with preload for linter check * Log and stop on error, raising exception * Update docs with preload * Add tests for preload * Cleanup tests for preload * Try longer sleep to fix failing test on GitHub (works locally) * Separate preload tests which use their own cluster, possibly slowing test_responsiveness * Add some tests for preload spec * Address minor comments for preload Signed-off-by: eric * Bump version Signed-off-by: eric * Remove unused import Signed-off-by: eric --------- Signed-off-by: eric Co-authored-by: sharpener6 <1sc2l4qi@duck.com> --- docs/source/tutorials/configuration.rst | 39 +++++++ examples/task_capabilities.py | 1 + scaler/cluster/cluster.py | 3 + scaler/cluster/combo.py | 1 + scaler/entry_points/cluster.py | 7 ++ scaler/version.txt | 2 +- scaler/worker/agent/processor/processor.py | 12 +++ scaler/worker/agent/processor_holder.py | 2 + scaler/worker/agent/processor_manager.py | 3 + scaler/worker/preload.py | 84 +++++++++++++++ scaler/worker/worker.py | 3 + scaler/worker_adapter/native.py | 1 + tests/test_balance.py | 1 + tests/test_client.py | 118 +++++++++++++++++++++ tests/test_death_timeout.py | 1 + tests/test_graph.py | 1 + tests/utility.py | 20 ++++ 17 files changed, 298 insertions(+), 1 deletion(-) create mode 100644 scaler/worker/preload.py diff --git a/docs/source/tutorials/configuration.rst b/docs/source/tutorials/configuration.rst index 3d4d56fff..e70d8c611 100644 --- a/docs/source/tutorials/configuration.rst +++ b/docs/source/tutorials/configuration.rst @@ -80,6 +80,45 @@ For the list of available settings, use the CLI command: scaler_cluster -h +**Preload Hook** + +Workers can execute an optional initialization function before processing tasks using the ``--preload`` option. This enables workers to: + +* Set up environments on demand +* Preload data, libraries, or models +* Initialize connections or state + +The preload specification follows the format ``module.path:function(args, kwargs)`` where: + +* ``module.path`` is the Python module to import +* ``function`` is the callable to execute +* ``args`` and ``kwargs`` are literal values (strings, numbers, booleans, lists, dicts) + +.. code:: bash + + # Simple function call with no arguments + scaler_cluster tcp://127.0.0.1:8516 --preload "mypackage.init:setup" + + # Function call with arguments + scaler_cluster tcp://127.0.0.1:8516 --preload "mypackage.init:configure('production', debug=False)" + +The preload function is executed once per processor during initialization, before any tasks are processed. If the preload function fails, the error is logged and the processor will terminate. + +Example preload module (``mypackage/init.py``): + +.. code:: python + + import logging + + def setup(): + """Basic setup with no arguments""" + logging.info("Worker initialized") + + def configure(environment, debug=True): + """Setup with configuration parameters""" + logging.info(f"Configuring for {environment}, debug={debug}") + # Initialize connections, load models, etc. + **Death Timeout** Workers are spun up with a ``death_timeout_seconds``, which indicates how long the worker will stay alive without being connected to a Scheduler. The default setting is 300 seconds. This is intended for the workers to clean up if the Scheduler crashes. diff --git a/examples/task_capabilities.py b/examples/task_capabilities.py index e75468863..55f2c0fe4 100644 --- a/examples/task_capabilities.py +++ b/examples/task_capabilities.py @@ -34,6 +34,7 @@ def main(): regular_cluster = Cluster( address=base_cluster._address, storage_address=None, + preload=None, worker_io_threads=1, worker_names=["gpu_worker"], per_worker_capabilities={"gpu": -1}, diff --git a/scaler/cluster/cluster.py b/scaler/cluster/cluster.py index 58250d18d..ec717edeb 100644 --- a/scaler/cluster/cluster.py +++ b/scaler/cluster/cluster.py @@ -16,6 +16,7 @@ def __init__( self, address: ZMQConfig, storage_address: Optional[ObjectStorageConfig], + preload: Optional[str], worker_io_threads: int, worker_names: List[str], per_worker_capabilities: Dict[str, int], @@ -35,6 +36,7 @@ def __init__( self._address = address self._storage_address = storage_address + self._preload = preload self._worker_io_threads = worker_io_threads self._worker_names = worker_names self._per_worker_capabilities = per_worker_capabilities @@ -83,6 +85,7 @@ def __start_workers_and_run_forever(self): address=self._address, storage_address=self._storage_address, capabilities=self._per_worker_capabilities, + preload=self._preload, io_threads=self._worker_io_threads, task_queue_size=self._per_worker_task_queue_size, heartbeat_interval_seconds=self._heartbeat_interval_seconds, diff --git a/scaler/cluster/combo.py b/scaler/cluster/combo.py index 3c02d179a..e3b42913c 100644 --- a/scaler/cluster/combo.py +++ b/scaler/cluster/combo.py @@ -84,6 +84,7 @@ def __init__( self._cluster = Cluster( address=self._address, storage_address=self._storage_address, + preload=None, worker_io_threads=worker_io_threads, worker_names=[f"{socket.gethostname().split('.')[0]}_{i}" for i in range(n_workers)], per_worker_capabilities=per_worker_capabilities or {}, diff --git a/scaler/entry_points/cluster.py b/scaler/entry_points/cluster.py index dee1c1d3b..2d3999293 100644 --- a/scaler/entry_points/cluster.py +++ b/scaler/entry_points/cluster.py @@ -23,6 +23,12 @@ def get_args(): parser = argparse.ArgumentParser( "standalone compute cluster", formatter_class=argparse.ArgumentDefaultsHelpFormatter ) + parser.add_argument( + "--preload", + type=str, + default=None, + help='optional module init in the form "pkg.mod:func(arg1, arg2)" executed in each processor before tasks', + ) parser.add_argument( "--num-of-workers", "-n", type=int, default=DEFAULT_NUMBER_OF_WORKER, help="number of workers in cluster" ) @@ -162,6 +168,7 @@ def main(): cluster = Cluster( address=args.address, storage_address=args.object_storage_address, + preload=args.preload, worker_names=worker_names, per_worker_capabilities=args.per_worker_capabilities, per_worker_task_queue_size=args.worker_task_queue_size, diff --git a/scaler/version.txt b/scaler/version.txt index 89c881bc9..e0a6b34fb 100644 --- a/scaler/version.txt +++ b/scaler/version.txt @@ -1 +1 @@ -1.12.4 +1.12.5 diff --git a/scaler/worker/agent/processor/processor.py b/scaler/worker/agent/processor/processor.py index 33a6e01b9..eee890bae 100644 --- a/scaler/worker/agent/processor/processor.py +++ b/scaler/worker/agent/processor/processor.py @@ -25,6 +25,7 @@ from scaler.utility.zmq_config import ZMQConfig from scaler.worker.agent.processor.object_cache import ObjectCache from scaler.worker.agent.processor.streaming_buffer import StreamingBuffer +from scaler.worker.preload import execute_preload SUSPEND_SIGNAL = "SIGUSR1" # use str instead of a signal.Signal to not trigger an import error on unsupported systems. @@ -37,6 +38,7 @@ def __init__( event_loop: str, agent_address: ZMQConfig, storage_address: ObjectStorageConfig, + preload: Optional[str], resume_event: Optional[EventType], resumed_event: Optional[EventType], garbage_collect_interval_seconds: int, @@ -49,6 +51,7 @@ def __init__( self._event_loop = event_loop self._agent_address = agent_address self._storage_address = storage_address + self._preload = preload self._resume_event = resume_event self._resumed_event = resumed_event @@ -98,6 +101,15 @@ def __initialize(self): self.__register_signals() + # Execute optional preload hook if provided + if self._preload is not None: + try: + execute_preload(self._preload) + except Exception as e: + raise RuntimeError( + f"Processor[{self.pid}] initialization failed due to preload error: {self._preload}" + ) from e + def __register_signals(self): self.__register_signal("SIGTERM", self.__interrupt) diff --git a/scaler/worker/agent/processor_holder.py b/scaler/worker/agent/processor_holder.py index 72c70e6b6..c77ca4b68 100644 --- a/scaler/worker/agent/processor_holder.py +++ b/scaler/worker/agent/processor_holder.py @@ -20,6 +20,7 @@ def __init__( event_loop: str, agent_address: ZMQConfig, storage_address: ObjectStorageConfig, + preload: Optional[str], garbage_collect_interval_seconds: int, trim_memory_threshold_bytes: int, hard_suspend: bool, @@ -43,6 +44,7 @@ def __init__( event_loop=event_loop, agent_address=agent_address, storage_address=storage_address, + preload=preload, resume_event=self._resume_event, resumed_event=self._resumed_event, garbage_collect_interval_seconds=garbage_collect_interval_seconds, diff --git a/scaler/worker/agent/processor_manager.py b/scaler/worker/agent/processor_manager.py index a9d7ce444..525e4b546 100644 --- a/scaler/worker/agent/processor_manager.py +++ b/scaler/worker/agent/processor_manager.py @@ -23,6 +23,7 @@ def __init__( identity: WorkerID, event_loop: str, address_internal: ZMQConfig, + preload: Optional[str], garbage_collect_interval_seconds: int, trim_memory_threshold_bytes: int, hard_processor_suspend: bool, @@ -33,6 +34,7 @@ def __init__( self._identity = identity self._event_loop = event_loop + self._preload = preload self._garbage_collect_interval_seconds = garbage_collect_interval_seconds self._trim_memory_threshold_bytes = trim_memory_threshold_bytes @@ -298,6 +300,7 @@ def __start_new_processor(self): self._event_loop, self._address_internal, storage_address, + self._preload, self._garbage_collect_interval_seconds, self._trim_memory_threshold_bytes, self._hard_processor_suspend, diff --git a/scaler/worker/preload.py b/scaler/worker/preload.py new file mode 100644 index 000000000..7b8963c3b --- /dev/null +++ b/scaler/worker/preload.py @@ -0,0 +1,84 @@ +import ast +import importlib +import logging +import os +import traceback +from typing import Any, Dict, List, Optional, Tuple + + +class PreloadSpecError(Exception): + pass + + +def execute_preload(spec: str) -> None: + """ + Import and execute the given preload spec in current interpreter. + + Example: 'foo.bar:preload_function("a", 2)' + """ + module_path, func_name, args, kwargs = _parse_preload_spec(spec) + logging.info("preloading: %s:%s with args=%s kwargs=%s", module_path, func_name, args, kwargs) + + try: + module = importlib.import_module(module_path) + except ImportError: + if module_path.endswith(".py") and os.path.exists(module_path): + raise PreloadSpecError( + f"Failed to find module. Did you mean '{module_path.rsplit('.', 1)[0]}:{func_name}'?" + ) + raise + + try: + target = getattr(module, func_name) + except AttributeError: + logging.exception(f"Failed to find attribute {func_name!r} in {module_path!r}.") + raise PreloadSpecError(f"Failed to find attribute {func_name!r} in {module_path!r}.") + + if not callable(target): + raise PreloadSpecError("Preload target must be callable.") + + try: + if args is None: + # Simple name: call with no args + target() + else: + target(*args, **(kwargs or {})) + except TypeError as e: + raise PreloadSpecError("".join(traceback.format_exception_only(TypeError, e)).strip()) + + +def _parse_preload_spec(spec: str) -> Tuple[str, str, Optional[List[Any]], Optional[Dict[str, Any]]]: + """ + Parse 'pkg.mod:func(arg1, kw=val)' using AST. + Returns (module_path, func_name, args_or_None, kwargs_or_None). + If expression is a simple name (no args), returns args=None, kwargs=None. + """ + if ":" not in spec: + raise PreloadSpecError("preload must be in 'module.sub:func(...)' format") + + module_part, obj_expr = spec.split(":", 1) + + # Parse the right-hand side as a single expression + try: + expression = ast.parse(obj_expr, mode="eval").body + except SyntaxError: + raise PreloadSpecError(f"Failed to parse {obj_expr!r} as an attribute name or function call.") + + if isinstance(expression, ast.Name): + func_name = expression.id + args = None + kwargs = None + elif isinstance(expression, ast.Call): + # Ensure the function name is an attribute name only (no dotted path) + if not isinstance(expression.func, ast.Name): + raise PreloadSpecError(f"Function reference must be a simple name: {obj_expr!r}") + func_name = expression.func.id + try: + args = [ast.literal_eval(arg) for arg in expression.args] + kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in expression.keywords} + except ValueError: + raise PreloadSpecError(f"Failed to parse arguments as literal values: {obj_expr!r}") + else: + raise PreloadSpecError(f"Failed to parse {obj_expr!r} as an attribute name or function call.") + + return module_part, func_name, args, kwargs diff --git a/scaler/worker/worker.py b/scaler/worker/worker.py index 18092fe40..b31b9ab08 100644 --- a/scaler/worker/worker.py +++ b/scaler/worker/worker.py @@ -46,6 +46,7 @@ def __init__( name: str, address: ZMQConfig, storage_address: Optional[ObjectStorageConfig], + preload: Optional[str], capabilities: Dict[str, int], io_threads: int, task_queue_size: int, @@ -64,6 +65,7 @@ def __init__( self._name = name self._address = address self._storage_address = storage_address + self._preload = preload self._capabilities = capabilities self._io_threads = io_threads self._task_queue_size = task_queue_size @@ -136,6 +138,7 @@ def __initialize(self): identity=self._ident, event_loop=self._event_loop, address_internal=self._address_internal, + preload=self._preload, garbage_collect_interval_seconds=self._garbage_collect_interval_seconds, trim_memory_threshold_bytes=self._trim_memory_threshold_bytes, hard_processor_suspend=self._hard_processor_suspend, diff --git a/scaler/worker_adapter/native.py b/scaler/worker_adapter/native.py index 53da544ca..8dfb6040b 100644 --- a/scaler/worker_adapter/native.py +++ b/scaler/worker_adapter/native.py @@ -74,6 +74,7 @@ async def start_worker_group(self) -> WorkerGroupID: name=uuid.uuid4().hex, address=self._address, storage_address=self._storage_address, + preload=None, capabilities=self._capabilities, io_threads=self._io_threads, task_queue_size=self._task_queue_size, diff --git a/tests/test_balance.py b/tests/test_balance.py index 7dfb99707..b9d76720e 100644 --- a/tests/test_balance.py +++ b/tests/test_balance.py @@ -45,6 +45,7 @@ def test_balance(self): new_cluster = Cluster( address=combo._cluster._address, storage_address=None, + preload=None, worker_io_threads=1, worker_names=[str(i) for i in range(0, N_WORKERS - 1)], per_worker_capabilities={}, diff --git a/tests/test_client.py b/tests/test_client.py index c0daa4ebe..b9f2fbf69 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,6 +1,7 @@ import functools import os import random +import tempfile import time import unittest from concurrent.futures import CancelledError @@ -9,6 +10,7 @@ from scaler.utility.exceptions import MissingObjects, ProcessorDiedError from scaler.utility.logging.scoped_logger import ScopedLogger from scaler.utility.logging.utility import setup_logger +from scaler.worker.preload import PreloadSpecError, _parse_preload_spec, execute_preload from tests.utility import logging_test_name @@ -30,6 +32,13 @@ def raise_exception(foo: int): raise ValueError("foo cannot be 100") +def get_preloaded_value(): + """Function that retrieves value set by preload""" + from tests.utility import get_global_value + + return get_global_value() + + class TestClient(unittest.TestCase): def setUp(self) -> None: setup_logger() @@ -325,6 +334,7 @@ def test_capabilities(self): gpu_cluster = Cluster( address=base_cluster._address, storage_address=None, + preload=None, worker_io_threads=1, worker_names=["gpu_worker"], per_worker_capabilities={"gpu": -1}, @@ -345,3 +355,111 @@ def test_capabilities(self): self.assertEqual(future.result(), 3.0) gpu_cluster.terminate() + + +class TestClientPreload(unittest.TestCase): + # Separate class for preload functionality with separate cluster to avoid interfering with time-sensitive tests + + def setUp(self) -> None: + setup_logger() + logging_test_name(self) + self.combo = SchedulerClusterCombo(n_workers=0, event_loop="builtin") + + def tearDown(self) -> None: + self.combo.shutdown() + + def _create_preload_cluster(self, preload: str, logging_paths: tuple = ("/dev/stdout",)): + base_cluster = self.combo._cluster + preload_cluster = Cluster( + address=self.combo._address, + storage_address=self.combo._storage_address, + preload=preload, + worker_io_threads=base_cluster._worker_io_threads, + worker_names=["preload_worker"], + per_worker_capabilities={}, + per_worker_task_queue_size=base_cluster._per_worker_task_queue_size, + heartbeat_interval_seconds=base_cluster._heartbeat_interval_seconds, + task_timeout_seconds=base_cluster._task_timeout_seconds, + death_timeout_seconds=base_cluster._death_timeout_seconds, + garbage_collect_interval_seconds=base_cluster._garbage_collect_interval_seconds, + trim_memory_threshold_bytes=base_cluster._trim_memory_threshold_bytes, + hard_processor_suspend=base_cluster._hard_processor_suspend, + event_loop=base_cluster._event_loop, + logging_paths=logging_paths, + logging_level=base_cluster._logging_level, + logging_config_file=base_cluster._logging_config_file, + ) + return preload_cluster + + def test_preload_success(self): + preload_cluster = self._create_preload_cluster(preload="tests.utility:setup_global_value('test_preload_value')") + + try: + preload_cluster.start() + time.sleep(2) + + with Client(self.combo.get_address()) as client: + # Submit a task that should access the preloaded global value + future = client.submit(get_preloaded_value) + result = future.result() + + # Verify the preloaded value is accessible + self.assertEqual(result, "test_preload_value") + finally: + preload_cluster.terminate() + preload_cluster.join() + + def test_preload_failure(self): + # For checking if the failure was logged, Processor will create log_path-{pid} + log_file = tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".log") + log_path = log_file.name + log_dir = os.path.dirname(log_path) + log_basename = os.path.basename(log_path) + + try: + preload_cluster = self._create_preload_cluster( + preload="tests.utility:failing_preload()", logging_paths=(log_path,) + ) + + try: + preload_cluster.start() + time.sleep(10) + + # Find processor log files by looking for files with PID suffixes + processor_log_content = "" + for file in os.listdir(log_dir): + if file.startswith(log_basename + "-") and file != log_basename: + processor_log_path = os.path.join(log_dir, file) + with open(processor_log_path, "r") as f: + processor_log_content += f.read() + + # Verify that the preload failure was logged properly + self.assertIn("preloading: tests.utility:failing_preload with args", processor_log_content) + + # If we reach here without any other exceptions, the test is successful + finally: + preload_cluster.terminate() + preload_cluster.join() + finally: + # Clean up log files + try: + os.unlink(log_path) + for file in os.listdir(log_dir): + if file.startswith(log_basename + "-") and file != log_basename: + os.unlink(os.path.join(log_dir, file)) + except FileNotFoundError: + pass + + def test_parse_preload_spec_error(self): + # Test that _parse_preload_spec raises PreloadSpecError for invalid specs + with self.assertRaises(PreloadSpecError) as cm: + _parse_preload_spec("module_without_colon") + + self.assertIn("preload must be in 'module.sub:func(...)' format", str(cm.exception)) + + def test_execute_preload_error(self): + # Test that execute_preload raises PreloadSpecError for non-callable targets + with self.assertRaises(PreloadSpecError) as cm: + execute_preload("sys:version") # sys.version is a string, not callable + + self.assertIn("Preload target must be callable", str(cm.exception)) diff --git a/tests/test_death_timeout.py b/tests/test_death_timeout.py index b86615174..a4d7612e5 100644 --- a/tests/test_death_timeout.py +++ b/tests/test_death_timeout.py @@ -30,6 +30,7 @@ def test_no_scheduler(self): cluster = Cluster( address=ZMQConfig.from_string(f"tcp://127.0.0.1:{get_available_tcp_port()}"), storage_address=None, + preload=None, worker_io_threads=DEFAULT_IO_THREADS, worker_names=["a", "b"], per_worker_capabilities={}, diff --git a/tests/test_graph.py b/tests/test_graph.py index 22335d2a3..ef69ed10c 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -230,6 +230,7 @@ def test_graph_capabilities(self): gpu_cluster = Cluster( address=base_cluster._address, storage_address=None, + preload=None, worker_io_threads=1, worker_names=["gpu_worker"], per_worker_capabilities={"gpu": -1}, diff --git a/tests/utility.py b/tests/utility.py index 59a4f8c6c..00707f338 100644 --- a/tests/utility.py +++ b/tests/utility.py @@ -1,6 +1,26 @@ import logging import unittest +# Global variable to test preload functionality +PRELOAD_VALUE = None + def logging_test_name(obj: unittest.TestCase): logging.info(f"{obj.__class__.__name__}:{obj._testMethodName} ==============================================") + + +def setup_global_value(value: str = "default") -> None: + """Preload function that sets a global variable""" + global PRELOAD_VALUE + PRELOAD_VALUE = value + logging.info(f"Preload set PRELOAD_VALUE to: {value}") + + +def get_global_value(): + """Function to be called by tasks to retrieve the preloaded value""" + return PRELOAD_VALUE + + +def failing_preload(): + """Preload function that always fails""" + raise ValueError("Intentional preload failure for testing") From 1c1e5ebb2c194e11973672e2e84cfd3a0cae4521 Mon Sep 17 00:00:00 2001 From: magniloquency <197707854+magniloquency@users.noreply.github.com> Date: Mon, 6 Oct 2025 13:29:57 -0400 Subject: [PATCH 13/23] Add YMQ MITM Tests Move files Revert --- scripts/build.sh | 3 - scripts/test.sh | 15 ++ tests/CMakeLists.txt | 3 + tests/cpp/ymq/common.h | 150 ++++++++++++++- tests/cpp/ymq/py_mitm/__init__.py | 0 tests/cpp/ymq/py_mitm/main.py | 175 ++++++++++++++++++ tests/cpp/ymq/py_mitm/passthrough.py | 25 +++ .../cpp/ymq/py_mitm/randomly_drop_packets.py | 29 +++ tests/cpp/ymq/py_mitm/send_rst_to_client.py | 50 +++++ tests/cpp/ymq/py_mitm/types.py | 54 ++++++ tests/cpp/ymq/test_cc_ymq.cpp | 119 +++++++++++- 11 files changed, 610 insertions(+), 13 deletions(-) create mode 100755 scripts/test.sh create mode 100644 tests/cpp/ymq/py_mitm/__init__.py create mode 100644 tests/cpp/ymq/py_mitm/main.py create mode 100644 tests/cpp/ymq/py_mitm/passthrough.py create mode 100644 tests/cpp/ymq/py_mitm/randomly_drop_packets.py create mode 100644 tests/cpp/ymq/py_mitm/send_rst_to_client.py create mode 100644 tests/cpp/ymq/py_mitm/types.py diff --git a/scripts/build.sh b/scripts/build.sh index 7256ea367..4279683d0 100755 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -25,6 +25,3 @@ cmake --build --preset $BUILD_PRESET # Install cmake --install $BUILD_DIR - -# Tests -ctest --preset $BUILD_PRESET diff --git a/scripts/test.sh b/scripts/test.sh new file mode 100755 index 000000000..948090580 --- /dev/null +++ b/scripts/test.sh @@ -0,0 +1,15 @@ +#!/bin/bash -e +# +# This script tests the C++ components. +# +# Usage: +# ./scripts/test.sh + +OS="$(uname -s | tr '[:upper:]' '[:lower:]')" # e.g. linux or darwin +ARCH="$(uname -m)" # e.g. x86_64 or arm64 + +BUILD_DIR="build_${OS}_${ARCH}" +BUILD_PRESET="${OS}-${ARCH}" + +# Run tests +ctest --preset $BUILD_PRESET -VV diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index f6294d570..3fe975515 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -11,6 +11,8 @@ set(BUILD_GMOCK OFF CACHE BOOL "" FORCE) set(BUILD_GTEST ON CACHE BOOL "" FORCE) FetchContent_MakeAvailable(googletest) +find_package(Python3 COMPONENTS Development REQUIRED) + # This function compiles, links, and adds a C++ test executable using Google Test. # It is shared by all test subdirectories. function(add_test_executable test_name source_file) @@ -26,6 +28,7 @@ function(add_test_executable test_name source_file) CapnProto::capnp CapnProto::kj GTest::gtest_main + Python3::Python ) add_test(NAME ${test_name} COMMAND ${test_name}) diff --git a/tests/cpp/ymq/common.h b/tests/cpp/ymq/common.h index 029048b27..3ed38fb3a 100644 --- a/tests/cpp/ymq/common.h +++ b/tests/cpp/ymq/common.h @@ -1,5 +1,7 @@ #pragma once +#define PY_SSIZE_T_CLEAN +#include #include #include #include @@ -25,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -239,9 +242,50 @@ inline void fork_wrapper(std::function fn, int timeout_secs, Owned pipe_wr.write_all((char*)&result, sizeof(TestResult)); } +// this function along with `wait_for_python_ready_sigwait()` +// work together to wait on a signal from the python process +// indicating that the tuntap interface has been created, and that the mitm is ready +inline void wait_for_python_ready_sigblock() +{ + sigset_t set {}; + + if (sigemptyset(&set) < 0) + throw std::system_error(errno, std::generic_category(), "failed to create empty signal set"); + + if (sigaddset(&set, SIGUSR1) < 0) + throw std::system_error(errno, std::generic_category(), "failed to add sigusr1 to the signal set"); + + if (sigprocmask(SIG_BLOCK, &set, nullptr) < 0) + throw std::system_error(errno, std::generic_category(), "failed to mask sigusr1"); + + std::println("blocked signal..."); +} + +inline void wait_for_python_ready_sigwait(int timeout_secs) +{ + sigset_t set {}; + siginfo_t sig {}; + + if (sigemptyset(&set) < 0) + throw std::system_error(errno, std::generic_category(), "failed to create empty signal set"); + + if (sigaddset(&set, SIGUSR1) < 0) + throw std::system_error(errno, std::generic_category(), "failed to add sigusr1 to the signal set"); + + std::println("waiting for python to be ready..."); + timespec ts {.tv_sec = timeout_secs, .tv_nsec = 0}; + if (sigtimedwait(&set, &sig, &ts) < 0) + throw std::system_error(errno, std::generic_category(), "failed to wait on sigusr1"); + + sigprocmask(SIG_UNBLOCK, &set, nullptr); + std::println("signal received; python is ready"); +} + // run a test // forks and runs each of the provided closures -inline TestResult test(int timeout_secs, std::vector> closures) +// if `wait_for_python` is true, wait for SIGUSR1 after forking and executing the first closure +inline TestResult test( + int timeout_secs, std::vector> closures, bool wait_for_python = false) { std::vector> pipes {}; std::vector pids {}; @@ -259,6 +303,9 @@ inline TestResult test(int timeout_secs, std::vector } for (size_t i = 0; i < closures.size(); i++) { + if (wait_for_python && i == 0) + wait_for_python_ready_sigblock(); + auto pid = fork(); if (pid < 0) { std::for_each(pipes.begin(), pipes.end(), [](const auto& pipe) { @@ -287,6 +334,9 @@ inline TestResult test(int timeout_secs, std::vector } pids.push_back(pid); + + if (wait_for_python && i == 0) + wait_for_python_ready_sigwait(3); } // close all write halves of the pipes @@ -405,3 +455,101 @@ inline TestResult test(int timeout_secs, std::vector return TestResult::Success; } + +inline TestResult run_python(const char* path, std::vector argv = {}) +{ + // insert the pid at the start of the argv, this is important for signalling readiness + pid_t pid = getppid(); + auto pid_ws = std::to_wstring(pid); + argv.insert(argv.begin(), pid_ws.c_str()); + + PyStatus status; + PyConfig config; + PyConfig_InitPythonConfig(&config); + + status = PyConfig_SetBytesString(&config, &config.program_name, "mitm"); + if (PyStatus_Exception(status)) + goto exception; + + status = Py_InitializeFromConfig(&config); + if (PyStatus_Exception(status)) + goto exception; + PyConfig_Clear(&config); + + argv.insert(argv.begin(), L"mitm"); + PySys_SetArgv(argv.size(), (wchar_t**)argv.data()); + + { + auto file = fopen(path, "r"); + if (!file) + throw std::system_error(errno, std::generic_category(), "failed to open python file"); + + PyRun_SimpleFile(file, path); + fclose(file); + } + + if (Py_FinalizeEx() < 0) { + std::println("finalization failure"); + return TestResult::Failure; + } + + return TestResult::Success; + +exception: + PyConfig_Clear(&config); + Py_ExitStatusException(status); + + return TestResult::Failure; +} + +// change the current working directory to the project root +// this is important for finding the python mitm script +inline void chdir_to_project_root() +{ + auto cwd = std::filesystem::current_path(); + + // if pyproject.toml is in `path`, it's the project root + for (auto path = cwd; !path.empty(); path = path.parent_path()) { + if (std::filesystem::exists(path / "pyproject.toml")) { + // change to the project root + std::filesystem::current_path(path); + return; + } + } +} + +inline TestResult run_mitm( + std::string testcase, + std::string mitm_ip, + uint16_t mitm_port, + std::string remote_ip, + uint16_t remote_port, + std::vector extra_args = {}) +{ + auto cwd = std::filesystem::current_path(); + chdir_to_project_root(); + + // we build the args for the user to make calling the function more convenient + std::vector args { + testcase, mitm_ip, std::to_string(mitm_port), remote_ip, std::to_string(remote_port)}; + + for (auto arg: extra_args) + args.push_back(arg); + + // we need to convert to wide strings to pass to Python + std::vector wide_args_owned {}; + + // the strings are ascii so we can just make them into wstrings + for (const auto& str: args) + wide_args_owned.emplace_back(str.begin(), str.end()); + + std::vector wide_args {}; + for (const auto& wstr: wide_args_owned) + wide_args.push_back(wstr.c_str()); + + auto result = run_python("tests/cpp/ymq/py_mitm/main.py", wide_args); + + // change back to the original working directory + std::filesystem::current_path(cwd); + return result; +} diff --git a/tests/cpp/ymq/py_mitm/__init__.py b/tests/cpp/ymq/py_mitm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/cpp/ymq/py_mitm/main.py b/tests/cpp/ymq/py_mitm/main.py new file mode 100644 index 000000000..15e8e8349 --- /dev/null +++ b/tests/cpp/ymq/py_mitm/main.py @@ -0,0 +1,175 @@ +# flake8: noqa: E402 + +""" +This script provides a framework for running MITM test cases +""" + +import argparse +import os +import sys +import importlib +import signal +import subprocess +from tests.cpp.ymq.py_mitm.types import AbstractMITM, TCPConnection +from scapy.all import IP, TCP, TunTapInterface # type: ignore +from typing import List + +from tests.cpp.ymq.py_mitm import passthrough, randomly_drop_packets, send_rst_to_client + + +def echo_call(cmd: List[str]): + print(f"+ {' '.join(cmd)}") + subprocess.check_call(cmd) + + +def create_tuntap_interface(iface_name: str, mitm_ip: str, remote_ip: str) -> TunTapInterface: + """ + Creates a TUNTAP interface and sets brings it up and adds ips using the `ip` program + + Args: + iface_name: The name of the TUNTAP interface, usually like `tun0`, `tun1`, etc. + mitm_ip: The desired ip address of the mitm. This is the ip that clients can use to connect to the mitm + remote_ip: The ip that routes to/from the tuntap interface. + packets sent to `mitm_ip` will appear to come from `remote_ip`,\ + and conversely the tuntap interface can connect/send packets + to `remote_ip`, making it a suitable ip for binding a server + + Returns: + The TUNTAP interface + """ + iface = TunTapInterface(iface_name, mode="tun") + + try: + echo_call(["sudo", "ip", "link", "set", iface_name, "up"]) + echo_call(["sudo", "ip", "addr", "add", remote_ip, "peer", mitm_ip, "dev", iface_name]) + print(f"[+] Interface {iface_name} up with IP {mitm_ip}") + except subprocess.CalledProcessError: + print("[!] Could not bring up interface. Run as root or set manually.") + raise + + return iface + + +def main(pid: int, mitm_ip: str, mitm_port: int, remote_ip: str, server_port: int, mitm: MITMProtocol): + """ + This function serves as a framework for man in the middle implementations + A client connects to the MITM, then the MITM connects to a remote server + The MITM sits inbetween the client and the server, manipulating the packets sent depending on the test case + This function: + 1. creates a TUNTAP interface and prepares it for MITM + 2. handles connecting clients and handling connection closes + 3. delegates additional logic to a pluggable callable, `mitm` + 4. returns when both connections have terminated (via ) + + Args: + pid: this is the pid of the test process, used for signaling readiness \ + we send SIGUSR1 to this process when the mitm is ready + mitm_ip: The desired ip address of the mitm server + mitm_port: The desired port of the mitm server. \ + This is the port used to connect to the server, but the client is free to connect on any port + remote_ip: The desired remote ip for the TUNTAP interface. This is the only ip address \ + reachable by the interface and is thus the src ip for clients, and the ip that the remote server \ + must be bound to + server_port: The port that the remote server is bound to + mitm: The core logic for a MITM test case. This callable may maintain its own state and is responsible \ + for sending packets over the TUNTAP interface (if it doesn't, nothing will happen) + """ + + tuntap = create_tuntap_interface("tun0", mitm_ip, remote_ip) + + # signal the caller that the tuntap interface has been created + if pid > 0: + os.kill(pid, signal.SIGUSR1) + + # these track information about our connections + # we already know what to expect for the server connection, we are the connector + client_conn = None + + # the port that the mitm uses to connect to the server + # we increment the port for each new connection to avoid collisions + mitm_server_port = mitm_port + server_conn = TCPConnection(mitm_ip, mitm_server_port, remote_ip, server_port) + + # tracks the state of each connection + client_sent_fin_ack = False + client_closed = False + server_sent_fin_ack = False + server_closed = False + + while True: + pkt = tuntap.recv() + if not pkt.haslayer(IP) or not pkt.haslayer(TCP): + continue + ip = pkt[IP] + tcp = pkt[TCP] + + # for a received packet, the destination ip and port are our local ip and port + # and the source ip and port will be the remote ip and port + sender = TCPConnection(pkt.dst, pkt.dport, pkt.src, pkt.sport) + + pretty = f"[{tcp.flags}]{(': ' + str(bytes(tcp.payload))) if tcp.payload else ''}" + + if not mitm.proxy(tuntap, pkt, sender, client_conn, server_conn): + if sender == client_conn: + print(f"[DROPPED]: -> {pretty}") + elif sender == server_conn: + print(f"[DROPPED]: <- {pretty}") + else: + print(f"[DROPPED]: ?? {pretty}") + + continue # the segment was not proxied, so we can't update our internal state + + if sender == client_conn: + print(f"-> {pretty}") + elif sender == server_conn: + print(f"<- {pretty}") + + if tcp.flags == "S": # SYN from client + print("-> [S]") + if sender != client_conn or client_conn is None: + print(f"[*] New connection from {ip.src}:{tcp.sport} to {ip.dst}:{tcp.dport}") + client_conn = sender + + server_conn = TCPConnection(mitm_ip, mitm_server_port, remote_ip, server_port) + + # increment the port so that the next client connection (if there is one) uses a different port + mitm_server_port += 1 + + if tcp.flags == "SA": # SYN-ACK from server + if sender == server_conn: + print(f"[*] Connection to server established: {ip.src}:{tcp.sport} to {ip.dst}:{tcp.dport}") + + if tcp.flags.F and tcp.flags.A: # FIN-ACK + if sender == client_conn: + client_sent_fin_ack = True + if sender == server_conn: + server_sent_fin_ack = True + + if tcp.flags.A: # ACK + if sender == client_conn and server_sent_fin_ack: + server_closed = True + if sender == server_conn and client_sent_fin_ack: + client_closed = True + + if client_closed and server_closed: + print("[*] Both connections closed") + return + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Man in the middle test framework") + parser.add_argument("pid", type=int, help="The pid of the test process, used for signaling") + parser.add_argument("testcase", type=str, help="The MITM test case module name") + parser.add_argument("mitm_ip", type=str, help="The desired ip address of the mitm server") + parser.add_argument("mitm_port", type=int, help="The desired port of the mitm server") + parser.add_argument("remote_ip", type=str, help="The desired remote ip for the TUNTAP interface") + parser.add_argument("server_port", type=int, help="The port that the remote server is bound to") + + args, unknown = parser.parse_known_args() + + # add the script's directory to path + sys.path.append(os.path.dirname(os.path.realpath(__file__))) + + # load the module dynamically + module = importlib.import_module(args.testcase) + main(args.pid, args.mitm_ip, args.mitm_port, args.remote_ip, args.server_port, module.MITM(*unknown)) diff --git a/tests/cpp/ymq/py_mitm/passthrough.py b/tests/cpp/ymq/py_mitm/passthrough.py new file mode 100644 index 000000000..04a17e6ee --- /dev/null +++ b/tests/cpp/ymq/py_mitm/passthrough.py @@ -0,0 +1,25 @@ +""" +This MITM acts as a transparent passthrough, it simply forwards packets as they are, +minus necessary header changes to retransmit +This MITM should have no effect on the client and server, +and they should behave as if the MITM is not present +""" + +from tests.cpp.ymq.py_mitm.types import AbstractMITM, TunTapInterface, IP, TCPConnection +from typing import Optional + + +class MITM(MITMProtocol): + def proxy( + self, + tuntap: TunTapInterface, + pkt: IP, + sender: TCPConnection, + client_conn: Optional[TCPConnection], + server_conn: TCPConnection, + ) -> bool: + if sender == client_conn or client_conn is None: + tuntap.send(server_conn.rewrite(pkt)) + elif sender == server_conn: + tuntap.send(client_conn.rewrite(pkt)) + return True diff --git a/tests/cpp/ymq/py_mitm/randomly_drop_packets.py b/tests/cpp/ymq/py_mitm/randomly_drop_packets.py new file mode 100644 index 000000000..6278c93c2 --- /dev/null +++ b/tests/cpp/ymq/py_mitm/randomly_drop_packets.py @@ -0,0 +1,29 @@ +""" +This MITM drops a % of packets +""" + +import random +from tests.cpp.ymq.py_mitm.types import AbstractMITM, TunTapInterface, IP, TCPConnection +from typing import Optional + + +class MITM(MITMProtocol): + def __init__(self, drop_pcent: str): + self.drop_pcent = float(drop_pcent) + + def proxy( + self, + tuntap: TunTapInterface, + pkt: IP, + sender: TCPConnection, + client_conn: Optional[TCPConnection], + server_conn: TCPConnection, + ) -> bool: + if random.random() < self.drop_pcent: + return False + + if sender == client_conn or client_conn is None: + tuntap.send(server_conn.rewrite(pkt)) + elif sender == server_conn: + tuntap.send(client_conn.rewrite(pkt)) + return True diff --git a/tests/cpp/ymq/py_mitm/send_rst_to_client.py b/tests/cpp/ymq/py_mitm/send_rst_to_client.py new file mode 100644 index 000000000..9864b6986 --- /dev/null +++ b/tests/cpp/ymq/py_mitm/send_rst_to_client.py @@ -0,0 +1,50 @@ +""" +This MITM inserts an unexpected TCP RST +""" + +from tests.cpp.ymq.py_mitm.types import IP, TCP, AbstractMITM, TCPConnection, TunTapInterface +from typing import Optional + + +class MITM(MITMProtocol): + def __init__(self): + # count the number of psh-acks sent by the client + self.client_pshack_counter = 0 + + def proxy( + self, + tuntap: TunTapInterface, + pkt: IP, + sender: TCPConnection, + client_conn: Optional[TCPConnection], + server_conn: TCPConnection, + ) -> bool: + if sender == client_conn or client_conn is None: + if pkt[TCP].flags == "PA": + self.client_pshack_counter += 1 + + # on the second psh-ack, send a rst instead + if self.client_pshack_counter == 2: + rst_pkt = IP(src=client_conn.local_ip, dst=client_conn.remote_ip) / TCP( + sport=client_conn.local_port, dport=client_conn.remote_port, flags="R", seq=pkt[TCP].ack + ) + print(f"<- [{rst_pkt[TCP].flags}] (simulated)") + tuntap.send(rst_pkt) + return False + + tuntap.send(server_conn.rewrite(pkt)) + elif sender == server_conn: + tuntap.send(client_conn.rewrite(pkt)) + return True + + +# client -> mitm -> server +# server -> mitm -> client + +# client: 127.0.0.1:8080 +# mitm: 127.0.0.1:8081 +# server: 127.0.0.1:8081 + + +# client -> mitm == src = client.ip, sport = client.port ;; dst = mitm.ip, dport = mitm.port +# mitm -> server == src = mitm.ip, sport = mitm.port ;; dst = server.ip, dport = server.port diff --git a/tests/cpp/ymq/py_mitm/types.py b/tests/cpp/ymq/py_mitm/types.py new file mode 100644 index 000000000..fd0ce3ddf --- /dev/null +++ b/tests/cpp/ymq/py_mitm/types.py @@ -0,0 +1,54 @@ +""" +This is the common code for implementing man in the middle in Python +""" + +import dataclasses +from typing import Protocol, Optional +from scapy.all import TunTapInterface, IP, TCP # type: ignore + + +@dataclasses.dataclass +class TCPConnection: + """ + Represents a TCP connection over the TUNTAP interface + local_ip and local_port are the mitm's ip and port, and + remote_ip and remote_port are the port for the remote peer + """ + + local_ip: str + local_port: int + remote_ip: str + remote_port: int + + def rewrite(self, pkt: IP, ack: Optional[int] = None, data=None): + """ + Rewrite a TCP/IP packet as a packet originating + from (local_ip, local_port) and going to (remote_ip, remote_port) + This function is useful for taking a packet received from one connection, and redirecting it to another + + Args: + pkt: A scapy TCP/IP packet to rewrite + ack: An optional ack number to use instead of the one found in `pkt` + data: An optional payload to use instead of the one found int `pkt` + + Returns: + The rewritten packet, suitable for sending over TUNTAP + """ + tcp = pkt[TCP] + + return ( + IP(src=self.local_ip, dst=self.remote_ip) + / TCP(sport=self.local_port, dport=self.remote_port, flags=tcp.flags, seq=tcp.seq, ack=ack or tcp.ack) + / bytes(data or tcp.payload) + ) + + +class MITMProtocol(Protocol): + def proxy( + self, + tuntap: TunTapInterface, + pkt: IP, + sender: TCPConnection, + client_conn: Optional[TCPConnection], + server_conn: TCPConnection, + ) -> bool: ... diff --git a/tests/cpp/ymq/test_cc_ymq.cpp b/tests/cpp/ymq/test_cc_ymq.cpp index 580a37bfa..be8267382 100644 --- a/tests/cpp/ymq/test_cc_ymq.cpp +++ b/tests/cpp/ymq/test_cc_ymq.cpp @@ -2,6 +2,10 @@ // each test case is comprised of at least one client and one server, and possibly a middleman // the clients and servers used in these tests are defined in the first part of this file // +// the men in the middle (mitm) are implemented using Python and are found in py_mitm/ +// in that directory, `main.py` is the entrypoint and framework for all the mitm, +// and the individual mitm implementations are found in their respective files +// // the test cases are at the bottom of this file, after the clients and servers // the documentation for each case is found on the TEST() definition @@ -13,6 +17,7 @@ #include #include +#include #include #include #include @@ -128,9 +133,9 @@ TestResult reconnect_server_main(std::string host, uint16_t port) auto result = syncRecvMessage(socket); RETURN_FAILURE_IF_FALSE(result.has_value()); - RETURN_FAILURE_IF_FALSE(result->payload.as_string() == "hello!!"); + RETURN_FAILURE_IF_FALSE(result->payload.as_string() == "sync"); - auto error = syncSendMessage(socket, {.address = Bytes("client"), .payload = Bytes("world!!")}); + auto error = syncSendMessage(socket, {.address = Bytes("client"), .payload = Bytes("acknowledge")}); RETURN_FAILURE_IF_FALSE(!error); context.removeIOSocket(socket); @@ -144,14 +149,39 @@ TestResult reconnect_client_main(std::string host, uint16_t port) auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); syncConnectSocket(socket, format_address(host, port)); - auto result = syncSendMessage(socket, {.address = Bytes("server"), .payload = Bytes("hello!!")}); - auto msg = syncRecvMessage(socket); - RETURN_FAILURE_IF_FALSE(msg.has_value()); - RETURN_FAILURE_IF_FALSE(msg->payload.as_string() == "world!!"); - context.removeIOSocket(socket); + // send "sync" and wait for "acknowledge" in a loop + // the mitm will send a RST after the first "sync" + // the "sync" message will be lost, but YMQ should automatically reconnect + // therefore the next "sync" message should succeed + for (size_t i = 0; i < 10; i++) { + auto error = syncSendMessage(socket, {.address = Bytes("server"), .payload = Bytes("sync")}); + RETURN_FAILURE_IF_FALSE(!error); + + auto future = futureRecvMessage(socket); + auto result = future.wait_for(1s); + if (result == std::future_status::ready) { + auto msg = future.get(); + if (!msg.has_value()) { + std::println("message error: {}", msg.error().what()); + } + RETURN_FAILURE_IF_FALSE(msg.has_value()); + std::println("received message: {}", *msg->payload.as_string()); + RETURN_FAILURE_IF_FALSE(msg->payload.as_string() == "acknowledge"); + context.removeIOSocket(socket); + + return TestResult::Success; + } else if (result == std::future_status::timeout) { + // timeout, try again + continue; + } else { + std::println("future status error"); + return TestResult::Failure; + } + } - return TestResult::Success; + std::println("failed to reconnect after 10 attempts"); + return TestResult::Failure; } TestResult client_simulated_slow_network(const char* host, uint16_t port) @@ -387,7 +417,8 @@ TEST(CcYmqTestSuite, TestBasicYMQClientYMQServer) auto host = "localhost"; auto port = 2889; - // this is the test harness, it accepts a timeout, and a list of functions to run + // this is the test harness, it accepts a timeout, a list of functions to run, + // and an optional third argument used to coordinate the execution of python (for mitm) auto result = test(10, {[=] { return basic_client_ymq(host, port); }, [=] { return basic_server_ymq(host, port); }}); @@ -401,6 +432,8 @@ TEST(CcYmqTestSuite, TestBasicRawClientYMQServer) auto host = "localhost"; auto port = 2890; + // this is the test harness, it accepts a timeout, a list of functions to run, + // and an optional third argument used to coordinate the execution of python (for mitm) auto result = test(10, {[=] { return basic_client_raw(host, port); }, [=] { return basic_server_ymq(host, port); }}); @@ -413,6 +446,8 @@ TEST(CcYmqTestSuite, TestBasicRawClientRawServer) auto host = "localhost"; auto port = 2891; + // this is the test harness, it accepts a timeout, a list of functions to run, + // and an optional third argument used to coordinate the execution of python (for mitm) auto result = test(10, {[=] { return basic_client_raw(host, port); }, [=] { return basic_server_raw(host, port); }}); @@ -436,6 +471,8 @@ TEST(CcYmqTestSuite, TestBasicDelayYMQClientRawServer) auto host = "localhost"; auto port = 2893; + // this is the test harness, it accepts a timeout, a list of functions to run, + // and an optional third argument used to coordinate the execution of python (for mitm) auto result = test(10, {[=] { return basic_client_ymq(host, port); }, [=] { return basic_server_raw(host, port); }}); @@ -457,6 +494,70 @@ TEST(CcYmqTestSuite, TestClientSendBigMessageToServer) EXPECT_EQ(result, TestResult::Success); } +// this is the no-op/passthrough man in the middle test +// for this test case we use YMQ on both the client side and the server side +// the client connects to the mitm, and the mitm connects to the server +// when the mitm receives packets from the client, it forwards it to the server without changing it +// and similarly when it receives packets from the server, it forwards them to the client +// +// the mitm is implemented in Python. we pass the name of the test case, which corresponds to the Python filename, +// and a list of arguments, which are: mitm ip, mitm port, remote ip, remote port +// this defines the address of the mitm, and the addresses that can connect to it +// for more, see the python mitm files +TEST(CcYmqTestSuite, TestMitmPassthrough) +{ + auto mitm_ip = "192.0.2.4"; + auto mitm_port = 2323; + auto remote_ip = "192.0.2.3"; + auto remote_port = 23571; + + // the Python program must be the first and only the first function passed to test() + // we must also pass `true` as the third argument to ensure that Python is fully started + // before beginning the test + auto result = test( + 20, + {[=] { return run_mitm("passthrough", mitm_ip, mitm_port, remote_ip, remote_port); }, + [=] { return basic_client_ymq(mitm_ip, mitm_port); }, + [=] { return basic_server_ymq(remote_ip, remote_port); }}, + true); + EXPECT_EQ(result, TestResult::Success); +} + +// this test uses the mitm to test the reconnect logic of YMQ by sending RST packets +TEST(CcYmqTestSuite, DISABLED_TestMitmReconnect) +{ + auto mitm_ip = "192.0.2.4"; + auto mitm_port = 2525; + auto remote_ip = "192.0.2.3"; + auto remote_port = 23575; + + auto result = test( + 10, + {[=] { return run_mitm("send_rst_to_client", mitm_ip, mitm_port, remote_ip, remote_port); }, + [=] { return reconnect_client_main(mitm_ip, mitm_port); }, + [=] { return reconnect_server_main(remote_ip, remote_port); }}, + true); + EXPECT_EQ(result, TestResult::Success); +} + +// TODO: Make this more reliable, and re-enable it +// in this test, the mitm drops a random % of packets arriving from the client and server +TEST(CcYmqTestSuite, TestMitmRandomlyDropPackets) +{ + auto mitm_ip = "192.0.2.4"; + auto mitm_port = 2828; + auto remote_ip = "192.0.2.3"; + auto remote_port = 23591; + + auto result = test( + 60, + {[=] { return run_mitm("randomly_drop_packets", mitm_ip, mitm_port, remote_ip, remote_port, {"0.3"}); }, + [=] { return basic_client_ymq(mitm_ip, mitm_port); }, + [=] { return basic_server_ymq(remote_ip, remote_port); }}, + true); + EXPECT_EQ(result, TestResult::Success); +} + // in this test the client is sending a message to the server // but we simulate a slow network connection by sending the message in segmented chunks TEST(CcYmqTestSuite, TestSlowNetwork) From 1a7a49d688f81f65f67cee339b6ddc21e0f651d5 Mon Sep 17 00:00:00 2001 From: magniloquency <197707854+magniloquency@users.noreply.github.com> Date: Wed, 1 Oct 2025 20:46:47 -0400 Subject: [PATCH 14/23] Change protocol to abstract class --- tests/cpp/ymq/py_mitm/main.py | 2 +- tests/cpp/ymq/py_mitm/passthrough.py | 2 +- tests/cpp/ymq/py_mitm/randomly_drop_packets.py | 2 +- tests/cpp/ymq/py_mitm/send_rst_to_client.py | 2 +- tests/cpp/ymq/py_mitm/types.py | 6 ++++-- 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/cpp/ymq/py_mitm/main.py b/tests/cpp/ymq/py_mitm/main.py index 15e8e8349..546e90d67 100644 --- a/tests/cpp/ymq/py_mitm/main.py +++ b/tests/cpp/ymq/py_mitm/main.py @@ -50,7 +50,7 @@ def create_tuntap_interface(iface_name: str, mitm_ip: str, remote_ip: str) -> Tu return iface -def main(pid: int, mitm_ip: str, mitm_port: int, remote_ip: str, server_port: int, mitm: MITMProtocol): +def main(pid: int, mitm_ip: str, mitm_port: int, remote_ip: str, server_port: int, mitm: AbstractMITM): """ This function serves as a framework for man in the middle implementations A client connects to the MITM, then the MITM connects to a remote server diff --git a/tests/cpp/ymq/py_mitm/passthrough.py b/tests/cpp/ymq/py_mitm/passthrough.py index 04a17e6ee..17e099db1 100644 --- a/tests/cpp/ymq/py_mitm/passthrough.py +++ b/tests/cpp/ymq/py_mitm/passthrough.py @@ -9,7 +9,7 @@ from typing import Optional -class MITM(MITMProtocol): +class MITM(AbstractMITM): def proxy( self, tuntap: TunTapInterface, diff --git a/tests/cpp/ymq/py_mitm/randomly_drop_packets.py b/tests/cpp/ymq/py_mitm/randomly_drop_packets.py index 6278c93c2..6aa9cc7f5 100644 --- a/tests/cpp/ymq/py_mitm/randomly_drop_packets.py +++ b/tests/cpp/ymq/py_mitm/randomly_drop_packets.py @@ -7,7 +7,7 @@ from typing import Optional -class MITM(MITMProtocol): +class MITM(AbstractMITM): def __init__(self, drop_pcent: str): self.drop_pcent = float(drop_pcent) diff --git a/tests/cpp/ymq/py_mitm/send_rst_to_client.py b/tests/cpp/ymq/py_mitm/send_rst_to_client.py index 9864b6986..11608f117 100644 --- a/tests/cpp/ymq/py_mitm/send_rst_to_client.py +++ b/tests/cpp/ymq/py_mitm/send_rst_to_client.py @@ -6,7 +6,7 @@ from typing import Optional -class MITM(MITMProtocol): +class MITM(AbstractMITM): def __init__(self): # count the number of psh-acks sent by the client self.client_pshack_counter = 0 diff --git a/tests/cpp/ymq/py_mitm/types.py b/tests/cpp/ymq/py_mitm/types.py index fd0ce3ddf..03d94fbd8 100644 --- a/tests/cpp/ymq/py_mitm/types.py +++ b/tests/cpp/ymq/py_mitm/types.py @@ -2,8 +2,9 @@ This is the common code for implementing man in the middle in Python """ +from abc import ABC, abstractmethod import dataclasses -from typing import Protocol, Optional +from typing import Optional from scapy.all import TunTapInterface, IP, TCP # type: ignore @@ -43,7 +44,8 @@ def rewrite(self, pkt: IP, ack: Optional[int] = None, data=None): ) -class MITMProtocol(Protocol): +class AbstractMITM(ABC): + @abstractmethod def proxy( self, tuntap: TunTapInterface, From fcc6a9b42c65ca0dfdcec01d54b8f1d83d990388 Mon Sep 17 00:00:00 2001 From: magniloquency <197707854+magniloquency@users.noreply.github.com> Date: Wed, 1 Oct 2025 20:48:35 -0400 Subject: [PATCH 15/23] Remove usage of importlib --- tests/cpp/ymq/py_mitm/main.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/cpp/ymq/py_mitm/main.py b/tests/cpp/ymq/py_mitm/main.py index 546e90d67..7cec8ed5e 100644 --- a/tests/cpp/ymq/py_mitm/main.py +++ b/tests/cpp/ymq/py_mitm/main.py @@ -7,7 +7,6 @@ import argparse import os import sys -import importlib import signal import subprocess from tests.cpp.ymq.py_mitm.types import AbstractMITM, TCPConnection @@ -167,9 +166,14 @@ def main(pid: int, mitm_ip: str, mitm_port: int, remote_ip: str, server_port: in args, unknown = parser.parse_known_args() - # add the script's directory to path - sys.path.append(os.path.dirname(os.path.realpath(__file__))) + # TODO: use `match` in Python 3.10+ + if args.testcase == "passthrough": + module = passthrough + elif args.testcase == "randomly_drop_packets": + module = randomly_drop_packets + elif args.testcase == "send_rst_to_client": + module = send_rst_to_client + else: + raise ValueError(f"Unknown testcase: {args.testcase}") - # load the module dynamically - module = importlib.import_module(args.testcase) main(args.pid, args.mitm_ip, args.mitm_port, args.remote_ip, args.server_port, module.MITM(*unknown)) From acc8d03d84cc453777291010700f4a0f9999c201 Mon Sep 17 00:00:00 2001 From: magniloquency <197707854+magniloquency@users.noreply.github.com> Date: Mon, 6 Oct 2025 13:30:27 -0400 Subject: [PATCH 16/23] Update action Run tests with sudo --- .github/actions/compile-libraries/action.yml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/actions/compile-libraries/action.yml b/.github/actions/compile-libraries/action.yml index f253e2f80..545f48cc7 100644 --- a/.github/actions/compile-libraries/action.yml +++ b/.github/actions/compile-libraries/action.yml @@ -9,11 +9,16 @@ inputs: runs: using: "composite" steps: + - name: Install dependencies for MITM tests + shell: bash + run: uv pip install --system scapy==2.* + - name: Build and test C++ Components (Linux) if: inputs.os == 'Linux' shell: bash run: | - CXX=$(which g++-14) ./scripts/build.sh + CXX=$(which g++-14) ./scripts/build.sh + sudo ./scripts/test.sh - name: Build and test C++ Components (Windows) if: inputs.os == 'Windows' From 3d86ed31fa081607f2478dfbe0b49dd1f87974ad Mon Sep 17 00:00:00 2001 From: magniloquency <197707854+magniloquency@users.noreply.github.com> Date: Wed, 1 Oct 2025 22:24:56 -0400 Subject: [PATCH 17/23] Lint --- tests/cpp/ymq/py_mitm/main.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/cpp/ymq/py_mitm/main.py b/tests/cpp/ymq/py_mitm/main.py index 7cec8ed5e..ab20f933e 100644 --- a/tests/cpp/ymq/py_mitm/main.py +++ b/tests/cpp/ymq/py_mitm/main.py @@ -6,13 +6,13 @@ import argparse import os -import sys import signal import subprocess -from tests.cpp.ymq.py_mitm.types import AbstractMITM, TCPConnection -from scapy.all import IP, TCP, TunTapInterface # type: ignore +import types from typing import List +from scapy.all import IP, TCP, TunTapInterface # type: ignore +from tests.cpp.ymq.py_mitm.types import AbstractMITM, TCPConnection from tests.cpp.ymq.py_mitm import passthrough, randomly_drop_packets, send_rst_to_client @@ -167,6 +167,7 @@ def main(pid: int, mitm_ip: str, mitm_port: int, remote_ip: str, server_port: in args, unknown = parser.parse_known_args() # TODO: use `match` in Python 3.10+ + module: types.ModuleType if args.testcase == "passthrough": module = passthrough elif args.testcase == "randomly_drop_packets": From 25eebd5c2bc134d94a59a9513e9b77b18413e75a Mon Sep 17 00:00:00 2001 From: magniloquency <197707854+magniloquency@users.noreply.github.com> Date: Thu, 2 Oct 2025 19:54:31 -0400 Subject: [PATCH 18/23] Remove old code --- tests/cpp/ymq/py_mitm/send_rst_to_client.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/tests/cpp/ymq/py_mitm/send_rst_to_client.py b/tests/cpp/ymq/py_mitm/send_rst_to_client.py index 11608f117..c33359b43 100644 --- a/tests/cpp/ymq/py_mitm/send_rst_to_client.py +++ b/tests/cpp/ymq/py_mitm/send_rst_to_client.py @@ -36,15 +36,3 @@ def proxy( elif sender == server_conn: tuntap.send(client_conn.rewrite(pkt)) return True - - -# client -> mitm -> server -# server -> mitm -> client - -# client: 127.0.0.1:8080 -# mitm: 127.0.0.1:8081 -# server: 127.0.0.1:8081 - - -# client -> mitm == src = client.ip, sport = client.port ;; dst = mitm.ip, dport = mitm.port -# mitm -> server == src = mitm.ip, sport = mitm.port ;; dst = server.ip, dport = server.port From 92c221f4fc0270ef74a529e90b90696f5d4fb0fb Mon Sep 17 00:00:00 2001 From: magniloquency <197707854+magniloquency@users.noreply.github.com> Date: Thu, 2 Oct 2025 20:02:37 -0400 Subject: [PATCH 19/23] Fix reconnect test --- tests/cpp/ymq/test_cc_ymq.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/cpp/ymq/test_cc_ymq.cpp b/tests/cpp/ymq/test_cc_ymq.cpp index be8267382..b1df1a46f 100644 --- a/tests/cpp/ymq/test_cc_ymq.cpp +++ b/tests/cpp/ymq/test_cc_ymq.cpp @@ -150,6 +150,9 @@ TestResult reconnect_client_main(std::string host, uint16_t port) auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); syncConnectSocket(socket, format_address(host, port)); + // create the recv future in advance, this remains active between reconnects + auto future = futureRecvMessage(socket); + // send "sync" and wait for "acknowledge" in a loop // the mitm will send a RST after the first "sync" // the "sync" message will be lost, but YMQ should automatically reconnect @@ -158,15 +161,10 @@ TestResult reconnect_client_main(std::string host, uint16_t port) auto error = syncSendMessage(socket, {.address = Bytes("server"), .payload = Bytes("sync")}); RETURN_FAILURE_IF_FALSE(!error); - auto future = futureRecvMessage(socket); auto result = future.wait_for(1s); if (result == std::future_status::ready) { auto msg = future.get(); - if (!msg.has_value()) { - std::println("message error: {}", msg.error().what()); - } RETURN_FAILURE_IF_FALSE(msg.has_value()); - std::println("received message: {}", *msg->payload.as_string()); RETURN_FAILURE_IF_FALSE(msg->payload.as_string() == "acknowledge"); context.removeIOSocket(socket); @@ -524,7 +522,7 @@ TEST(CcYmqTestSuite, TestMitmPassthrough) } // this test uses the mitm to test the reconnect logic of YMQ by sending RST packets -TEST(CcYmqTestSuite, DISABLED_TestMitmReconnect) +TEST(CcYmqTestSuite, TestMitmReconnect) { auto mitm_ip = "192.0.2.4"; auto mitm_port = 2525; From fdba48445048c7077306b3bf80cc7dff25639bac Mon Sep 17 00:00:00 2001 From: magniloquency <197707854+magniloquency@users.noreply.github.com> Date: Fri, 3 Oct 2025 12:29:56 -0400 Subject: [PATCH 20/23] Update tests/cpp/ymq/py_mitm/main.py Co-authored-by: rafa-be --- tests/cpp/ymq/py_mitm/main.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/tests/cpp/ymq/py_mitm/main.py b/tests/cpp/ymq/py_mitm/main.py index ab20f933e..91d3b07de 100644 --- a/tests/cpp/ymq/py_mitm/main.py +++ b/tests/cpp/ymq/py_mitm/main.py @@ -155,10 +155,16 @@ def main(pid: int, mitm_ip: str, mitm_port: int, remote_ip: str, server_port: in return +TESTCASES = { + "passthrough": passthrough, + "randomly_drop_packets": randomly_drop_packets, + "send_rst_to_client": send_rst_to_client, +} + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Man in the middle test framework") parser.add_argument("pid", type=int, help="The pid of the test process, used for signaling") - parser.add_argument("testcase", type=str, help="The MITM test case module name") + parser.add_argument("testcase", type=str, choices=TESTCASES.keys(), help="The MITM test case module name") parser.add_argument("mitm_ip", type=str, help="The desired ip address of the mitm server") parser.add_argument("mitm_port", type=int, help="The desired port of the mitm server") parser.add_argument("remote_ip", type=str, help="The desired remote ip for the TUNTAP interface") @@ -166,15 +172,6 @@ def main(pid: int, mitm_ip: str, mitm_port: int, remote_ip: str, server_port: in args, unknown = parser.parse_known_args() - # TODO: use `match` in Python 3.10+ - module: types.ModuleType - if args.testcase == "passthrough": - module = passthrough - elif args.testcase == "randomly_drop_packets": - module = randomly_drop_packets - elif args.testcase == "send_rst_to_client": - module = send_rst_to_client - else: - raise ValueError(f"Unknown testcase: {args.testcase}") + module = TESTCASES[args.testcase] main(args.pid, args.mitm_ip, args.mitm_port, args.remote_ip, args.server_port, module.MITM(*unknown)) From cac44c8758002daa8de1f9b2e55eab6b274b1d02 Mon Sep 17 00:00:00 2001 From: magniloquency <197707854+magniloquency@users.noreply.github.com> Date: Mon, 6 Oct 2025 22:35:46 -0400 Subject: [PATCH 21/23] YMQ OSS Client (#291) * Rework YMQ * Remove unused includes * Address comments from rafa-be * Address comments, format, improve typing * Adjust GIL acquisition after feedback * Fix Python 3.8 * Add num_threads property to iocontext * fix mypy * Update SyncObjectStorageConnector to use ymq Signed-off-by: gxu * Make Linter Happy Signed-off-by: gxu * Satisfy Linter Signed-off-by: gxu * satisfy linter Signed-off-by: gxu * Let AsyncObjectStorageConnector use ymq Signed-off-by: gxu * Apply comment from mag Signed-off-by: gxu * fix lint * catch connector socket closed by remote end error * add todo --------- Signed-off-by: gxu Co-authored-by: gxu --- pyproject.toml | 1 + scaler/io/async_object_storage_connector.py | 105 ++---- scaler/io/sync_object_storage_connector.py | 121 ++----- scaler/io/ymq/CMakeLists.txt | 5 +- scaler/io/ymq/{ymq.pyi => _ymq.pyi} | 37 +-- scaler/io/ymq/pymod_ymq/async.h | 97 ------ scaler/io/ymq/pymod_ymq/bytes.h | 12 +- scaler/io/ymq/pymod_ymq/exception.h | 15 +- scaler/io/ymq/pymod_ymq/gil.h | 15 + scaler/io/ymq/pymod_ymq/io_context.h | 110 ++----- scaler/io/ymq/pymod_ymq/io_socket.h | 336 ++++++-------------- scaler/io/ymq/pymod_ymq/message.h | 3 +- scaler/io/ymq/pymod_ymq/python.h | 6 +- scaler/io/ymq/pymod_ymq/utils.h | 110 ------- scaler/io/ymq/pymod_ymq/ymq.cpp | 2 +- scaler/io/ymq/pymod_ymq/ymq.h | 147 +-------- scaler/io/ymq/ymq.py | 130 ++++++++ scaler/worker/worker.py | 8 + 18 files changed, 376 insertions(+), 884 deletions(-) rename scaler/io/ymq/{ymq.pyi => _ymq.pyi} (68%) delete mode 100644 scaler/io/ymq/pymod_ymq/async.h create mode 100644 scaler/io/ymq/pymod_ymq/gil.h delete mode 100644 scaler/io/ymq/pymod_ymq/utils.h create mode 100644 scaler/io/ymq/ymq.py diff --git a/pyproject.toml b/pyproject.toml index 5e1c0b1ac..58392afa9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "tblib", "aiohttp", "graphlib-backport; python_version < '3.9'", + "typing-extensions>=4.0; python_version < '3.10'" ] [project.optional-dependencies] diff --git a/scaler/io/async_object_storage_connector.py b/scaler/io/async_object_storage_connector.py index 940201cce..64a9c4029 100644 --- a/scaler/io/async_object_storage_connector.py +++ b/scaler/io/async_object_storage_connector.py @@ -2,11 +2,11 @@ import logging import os import socket -import struct import uuid from typing import Dict, Optional, Tuple from scaler.io.mixins import AsyncObjectStorageConnector +from scaler.io.ymq.ymq import IOSocketType, IOContext, Message, YMQException from scaler.protocol.capnp._python import _object_storage # noqa from scaler.protocol.python.object_storage import ObjectRequestHeader, ObjectResponseHeader, to_capnp_object_id from scaler.utility.exceptions import ObjectStorageException @@ -22,19 +22,18 @@ def __init__(self): self._connected_event = asyncio.Event() - self._reader: Optional[asyncio.StreamReader] = None - self._writer: Optional[asyncio.StreamWriter] = None - self._next_request_id = 0 self._pending_get_requests: Dict[ObjectID, asyncio.Future] = {} - self._identity: bytes = f"{os.getpid()}|{socket.gethostname().split('.')[0]}|{uuid.uuid4()}".encode() + self._lock = asyncio.Lock() + self._identity: str = f"{os.getpid()}|{socket.gethostname().split('.')[0]}|{uuid.uuid4()}" + self._io_context: IOContext = IOContext() + self._io_socket = self._io_context.createIOSocket_sync(self._identity, IOSocketType.Connector) def __del__(self): if not self.is_connected(): return - - self._writer.close() + self._io_socket = None async def connect(self, host: str, port: int): self._host = host @@ -42,20 +41,7 @@ async def connect(self, host: str, port: int): if self.is_connected(): raise ObjectStorageException("connector is already connected.") - - self._reader, self._writer = await asyncio.open_connection(self._host, self._port) - await self.__read_framed_message() - self.__write_framed(self._identity) - - try: - await self._writer.drain() - except ConnectionResetError: - self.__raise_connection_failure() - - # Makes sure the socket is TCP_NODELAY. It seems to be the case by default, but that's not specified in the - # asyncio's documentation and might change in the future. - self._writer.get_extra_info("socket").setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - + await self._io_socket.connect(self.address) self._connected_event.set() async def wait_until_connected(self): @@ -67,23 +53,10 @@ def is_connected(self) -> bool: async def destroy(self): if not self.is_connected(): return - - if not self._writer.is_closing: - self._writer.close() - - await self._writer.wait_closed() - - @property - def reader(self) -> Optional[asyncio.StreamReader]: - return self._reader - - @property - def writer(self) -> Optional[asyncio.StreamWriter]: - return self._writer + self._io_socket = None @property def address(self) -> str: - self.__ensure_is_connected() return f"tcp://{self._host}:{self._port}" async def routine(self): @@ -136,12 +109,9 @@ async def duplicate_object_id(self, object_id: ObjectID, new_object_id: ObjectID ) def __ensure_is_connected(self): - if self._writer is None: + if self._io_socket is None: raise ObjectStorageException("connector is not connected.") - if self._writer.is_closing(): - raise ObjectStorageException("connection is closed.") - async def __send_request( self, object_id: ObjectID, @@ -150,7 +120,6 @@ async def __send_request( payload: Optional[bytes], ): self.__ensure_is_connected() - assert self._writer is not None request_id = self._next_request_id self._next_request_id += 1 @@ -158,67 +127,59 @@ async def __send_request( header = ObjectRequestHeader.new_msg(object_id, payload_length, request_id, request_type) - self.__write_request_header(header) + try: + async with self._lock: + await self.__write_request_header(header) - if payload is not None: - self.__write_request_payload(payload) + if payload is not None: + await self.__write_request_payload(payload) - try: - await self._writer.drain() - except ConnectionResetError: + except YMQException: + self._io_socket = None self.__raise_connection_failure() - def __write_request_header(self, header: ObjectRequestHeader): - assert self._writer is not None - self.__write_framed(header.get_message().to_bytes()) + async def __write_request_header(self, header: ObjectRequestHeader): + assert self._io_socket is not None + await self._io_socket.send(Message(address=None, payload=header.get_message().to_bytes())) - def __write_request_payload(self, payload: bytes): - assert self._writer is not None - self.__write_framed(payload) + async def __write_request_payload(self, payload: bytes): + assert self._io_socket is not None + await self._io_socket.send(Message(address=None, payload=payload)) async def __receive_response(self) -> Optional[Tuple[ObjectResponseHeader, bytes]]: - assert self._reader is not None - - if self._writer.is_closing(): + if self._io_socket is None: return None try: header = await self.__read_response_header() payload = await self.__read_response_payload(header) - except asyncio.IncompleteReadError: + except YMQException: + self._io_socket = None self.__raise_connection_failure() return header, payload async def __read_response_header(self) -> ObjectResponseHeader: - assert self._reader is not None + assert self._io_socket is not None - header_data = await self.__read_framed_message() + msg = await self._io_socket.recv() + header_data = msg.payload.data assert len(header_data) == ObjectResponseHeader.MESSAGE_LENGTH with _object_storage.ObjectResponseHeader.from_bytes(header_data) as header_message: return ObjectResponseHeader(header_message) async def __read_response_payload(self, header: ObjectResponseHeader) -> bytes: - assert self._reader is not None + assert self._io_socket is not None + # assert self._reader is not None if header.payload_length > 0: - res = await self.__read_framed_message() - assert len(res) == header.payload_length - return res + res = await self._io_socket.recv() + assert len(res.payload) == header.payload_length + return res.payload.data else: return b"" - async def __read_framed_message(self) -> bytes: - length_bytes = await self._reader.readexactly(8) - (payload_length,) = struct.unpack(" 0 else bytes() - - def __write_framed(self, payload: bytes): - self._writer.write(struct.pack(" str: @@ -114,7 +109,7 @@ def duplicate_object_id(self, object_id: ObjectID, new_object_id: ObjectID) -> N self.__ensure_empty_payload(response_payload) def __ensure_is_connected(self): - if self._socket is None: + if self._io_socket is None: raise ObjectStorageException("connector is closed.") def __ensure_response_type( @@ -135,7 +130,7 @@ def __send_request( payload: Optional[bytes] = None, ): self.__ensure_is_connected() - assert self._socket is not None + assert self._io_socket is not None request_id = self._next_request_id self._next_request_id += 1 @@ -145,102 +140,46 @@ def __send_request( header_bytes = header.get_message().to_bytes() if payload is not None: - self.__send_buffers( - [struct.pack(" None: - if len(buffers) < 1: - return - - total_size = sum(len(buffer) for buffer in buffers) - - # If the message is small enough, first try to send it at once with sendmsg(). This would ensure the message can - # be transmitted within a single TCP segment. - if total_size < MAX_CHUNK_SIZE: - sent = self._socket.sendmsg(buffers) - - if sent <= 0: - self.__raise_connection_failure() - - remaining_buffers = collections.deque(buffers) - while sent > len(remaining_buffers[0]): - removed_buffer = remaining_buffers.popleft() - sent -= len(removed_buffer) - - if sent > 0: - # Truncate the first partially sent buffer - remaining_buffers[0] = memoryview(remaining_buffers[0])[sent:] - - buffers = list(remaining_buffers) - - # Send the remaining buffers sequentially - for buffer in buffers: - self.__send_buffer(buffer) - - def __send_buffer(self, buffer: bytes) -> None: - buffer_view = memoryview(buffer) + self._io_socket.send_sync(Message(address=None, payload=header_bytes)) - total_sent = 0 - while total_sent < len(buffer): - sent = self._socket.send(buffer_view[total_sent : MAX_CHUNK_SIZE + total_sent]) + def __receive_response(self): + assert self._io_socket is not None - if sent <= 0: - self.__raise_connection_failure() - - total_sent += sent - - def __receive_response(self) -> Tuple[ObjectResponseHeader, bytearray]: - assert self._socket is not None - - header = self.__read_response_header() - payload = self.__read_response_payload(header) - - return header, payload + try: + header = self.__read_response_header() + payload = self.__read_response_payload(header) + return header, payload + except YMQException: + self.__raise_connection_failure() def __read_response_header(self) -> ObjectResponseHeader: - assert self._socket is not None + assert self._io_socket is not None - header_bytearray = self.__read_framed_message() + header_bytes = self._io_socket.recv_sync().payload.data + if header_bytes is None: + self.__raise_connection_failure() # pycapnp does not like to read from a bytearray object. This look like an not-yet-resolved issue. # That's is annoying because it leads to an unnecessary copy of the header's buffer. # See https://github.com/capnproto/pycapnp/issues/153 - header_bytes = bytes(header_bytearray) + # header_bytes = bytes(header_bytearray) with _object_storage.ObjectResponseHeader.from_bytes(header_bytes) as header_message: return ObjectResponseHeader(header_message) def __read_response_payload(self, header: ObjectResponseHeader) -> bytearray: if header.payload_length > 0: - res = self.__read_framed_message() + res = self._io_socket.recv_sync().payload.data + if res is None: + self.__raise_connection_failure() assert len(res) == header.payload_length - return res + return bytearray(res) else: return bytearray() - def __read_exactly(self, length: int) -> bytearray: - buffer = bytearray(length) - - total_received = 0 - while total_received < length: - chunk_size = min(MAX_CHUNK_SIZE, length - total_received) - received = self._socket.recv_into(memoryview(buffer)[total_received:], chunk_size) - - if received <= 0: - self.__raise_connection_failure() - - total_received += received - - return buffer - - def __read_framed_message(self) -> bytearray: - length_bytes = self.__read_exactly(8) - (payload_length,) = struct.unpack(" 0 else bytearray() - @staticmethod def __raise_connection_failure(): raise ObjectStorageException("connection failure to object storage server.") diff --git a/scaler/io/ymq/CMakeLists.txt b/scaler/io/ymq/CMakeLists.txt index 63b898929..5d6cc2d1c 100644 --- a/scaler/io/ymq/CMakeLists.txt +++ b/scaler/io/ymq/CMakeLists.txt @@ -61,13 +61,12 @@ if(LINUX) find_package(Python3 COMPONENTS Development.Module REQUIRED) add_library(py_ymq SHARED - pymod_ymq/async.h pymod_ymq/bytes.h pymod_ymq/exception.h + pymod_ymq/gil.h pymod_ymq/message.h pymod_ymq/io_context.h pymod_ymq/io_socket.h - pymod_ymq/utils.h pymod_ymq/ymq.h pymod_ymq/ymq.cpp ) @@ -81,7 +80,7 @@ if(LINUX) set_target_properties(py_ymq PROPERTIES PREFIX "" - OUTPUT_NAME "ymq" + OUTPUT_NAME "_ymq" LINKER_LANGUAGE CXX ) diff --git a/scaler/io/ymq/ymq.pyi b/scaler/io/ymq/_ymq.pyi similarity index 68% rename from scaler/io/ymq/ymq.pyi rename to scaler/io/ymq/_ymq.pyi index 03229bca9..f27e9b45e 100644 --- a/scaler/io/ymq/ymq.pyi +++ b/scaler/io/ymq/_ymq.pyi @@ -1,9 +1,8 @@ # NOTE: NOT IMPLEMENTATION, TYPE INFORMATION ONLY # This file contains type stubs for the Ymq Python C Extension module import sys -from collections.abc import Awaitable from enum import IntEnum -from typing import SupportsBytes +from typing import Callable, Optional, SupportsBytes, Union if sys.version_info >= (3, 12): from collections.abc import Buffer @@ -39,46 +38,33 @@ class IOSocketType(IntEnum): Unicast = 3 Multicast = 4 -class IOContext: +class BaseIOContext: num_threads: int def __init__(self, num_threads: int = 1) -> None: ... def __repr__(self) -> str: ... - def createIOSocket(self, /, identity: str, socket_type: IOSocketType) -> Awaitable[IOSocket]: + def createIOSocket( + self, callback: Callable[[Union[BaseIOSocket, Exception]], None], identity: str, socket_type: IOSocketType + ) -> None: """Create an io socket with an identity and socket type""" - def createIOSocket_sync(self, /, identity: str, socket_type: IOSocketType) -> IOSocket: - """Create an io socket with an identity and socket type synchronously""" - -class IOSocket: +class BaseIOSocket: identity: str socket_type: IOSocketType def __repr__(self) -> str: ... - async def send(self, message: Message) -> None: + def send(self, callback: Callable[[Optional[Exception]], None], message: Message) -> None: """Send a message to one of the socket's peers""" - async def recv(self) -> Message: + def recv(self, callback: Callable[[Union[Message, Exception]], None]) -> None: """Receive a message from one of the socket's peers""" - async def bind(self, address: str) -> None: + def bind(self, callback: Callable[[Optional[Exception]], None], address: str) -> None: """Bind the socket to an address and listen for incoming connections""" - async def connect(self, address: str) -> None: + def connect(self, callback: Callable[[Optional[Exception]], None], address: str) -> None: """Connect to a remote socket""" - def send_sync(self, message: Message) -> None: - """Send a message to one of the socket's peers synchronously""" - - def recv_sync(self) -> Message: - """Receive a message from one of the socket's peers synchronously""" - - def bind_sync(self, address: str) -> None: - """Bind the socket to an address and listen for incoming connections synchronously""" - - def connect_sync(self, address: str) -> None: - """Connect to a remote socket synchronously""" - class ErrorCode(IntEnum): Uninit = 0 InvalidPortFormat = 1 @@ -108,6 +94,3 @@ class YMQException(Exception): def __init__(self, /, code: ErrorCode, message: str) -> None: ... def __repr__(self) -> str: ... def __str__(self) -> str: ... - -class YMQInterruptedException(YMQException): - def __init__(self) -> None: ... diff --git a/scaler/io/ymq/pymod_ymq/async.h b/scaler/io/ymq/pymod_ymq/async.h deleted file mode 100644 index 602097eb9..000000000 --- a/scaler/io/ymq/pymod_ymq/async.h +++ /dev/null @@ -1,97 +0,0 @@ -#pragma once - -// Python -#include "scaler/io/ymq/pymod_ymq/python.h" - -// C++ -#include - -// First-party -#include "scaler/io/ymq/pymod_ymq/ymq.h" - -// wraps an async callback that accepts a Python asyncio future -static PyObject* async_wrapper(PyObject* self, const std::function&& callback) -{ - auto state = YMQStateFromSelf(self); - if (!state) - return nullptr; - - OwnedPyObject loop = PyObject_CallMethod(*state->asyncioModule, "get_event_loop", nullptr); - if (!loop) { - PyErr_SetString(PyExc_RuntimeError, "Failed to get event loop"); - return nullptr; - } - - OwnedPyObject future = PyObject_CallMethod(*loop, "create_future", nullptr); - if (!future) { - PyErr_SetString(PyExc_RuntimeError, "Failed to create future"); - return nullptr; - } - - // create the awaitable before calling the callback - // this ensures that we create a new strong reference to the future before the callback decrefs it - auto awaitable = PyObject_CallFunction(*state->PyAwaitableType, "O", *future); - - // async - // we transfer ownership of the future to the callback - // TODO: investigate having the callback take an OwnedPyObject, and just std::move() - callback(state, future.take()); - - return awaitable; -} - -struct Awaitable { - PyObject_HEAD; - OwnedPyObject<> future; -}; - -extern "C" { - -static int Awaitable_init(Awaitable* self, PyObject* args, PyObject* kwds) -{ - PyObject* future = nullptr; - if (!PyArg_ParseTuple(args, "O", &future)) - return -1; - - new (&self->future) OwnedPyObject<>(); - self->future = OwnedPyObject<>::fromBorrowed(future); - - return 0; -} - -static PyObject* Awaitable_await(Awaitable* self) -{ - // Easy: coroutines are just iterators and we don't need anything fancy - // so we can just return the future's iterator! - return PyObject_GetIter(*self->future); -} - -static void Awaitable_dealloc(Awaitable* self) -{ - try { - self->future.~OwnedPyObject(); - } catch (...) { - PyErr_SetString(PyExc_RuntimeError, "Failed to deallocate Awaitable"); - PyErr_WriteUnraisable((PyObject*)self); - } - - auto* tp = Py_TYPE(self); - tp->tp_free(self); - Py_DECREF(tp); -} -} - -static PyType_Slot Awaitable_slots[] = { - {Py_tp_init, (void*)Awaitable_init}, - {Py_tp_dealloc, (void*)Awaitable_dealloc}, - {Py_am_await, (void*)Awaitable_await}, - {0, nullptr}, -}; - -static PyType_Spec Awaitable_spec { - .name = "ymq.Awaitable", - .basicsize = sizeof(Awaitable), - .itemsize = 0, - .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_IMMUTABLETYPE, - .slots = Awaitable_slots, -}; diff --git a/scaler/io/ymq/pymod_ymq/bytes.h b/scaler/io/ymq/pymod_ymq/bytes.h index 0941c02c9..40743d7b6 100644 --- a/scaler/io/ymq/pymod_ymq/bytes.h +++ b/scaler/io/ymq/pymod_ymq/bytes.h @@ -6,8 +6,6 @@ // First-party #include "scaler/io/ymq/bytes.h" -using namespace scaler::ymq; - struct PyBytesYMQ { PyObject_HEAD; Bytes bytes; @@ -19,9 +17,8 @@ static int PyBytesYMQ_init(PyBytesYMQ* self, PyObject* args, PyObject* kwds) { Py_buffer view {.buf = nullptr}; const char* keywords[] = {"bytes", nullptr}; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "|y*", (char**)keywords, &view)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|y*", (char**)keywords, &view)) return -1; // Error parsing arguments - } if (!view.buf) { // If no bytes were provided, initialize with an empty Bytes object @@ -94,11 +91,6 @@ static PyGetSetDef PyBytesYMQ_properties[] = { {nullptr, nullptr, nullptr, nullptr, nullptr}, // Sentinel }; -static PyBufferProcs PyBytesYMQBufferProcs = { - .bf_getbuffer = (getbufferproc)PyBytesYMQ_getbuffer, - .bf_releasebuffer = (releasebufferproc)PyBytesYMQ_releasebuffer, -}; - static PyType_Slot PyBytesYMQ_slots[] = { {Py_tp_init, (void*)PyBytesYMQ_init}, {Py_tp_dealloc, (void*)PyBytesYMQ_dealloc}, @@ -113,7 +105,7 @@ static PyType_Slot PyBytesYMQ_slots[] = { }; static PyType_Spec PyBytesYMQ_spec = { - .name = "ymq.Bytes", + .name = "_ymq.Bytes", .basicsize = sizeof(PyBytesYMQ), .itemsize = 0, .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_IMMUTABLETYPE, diff --git a/scaler/io/ymq/pymod_ymq/exception.h b/scaler/io/ymq/pymod_ymq/exception.h index 2369d3b45..fb63862a0 100644 --- a/scaler/io/ymq/pymod_ymq/exception.h +++ b/scaler/io/ymq/pymod_ymq/exception.h @@ -3,13 +3,12 @@ // Python #include "scaler/io/ymq/pymod_ymq/python.h" -// C++ -#include - // First-party -#include "scaler/io/ymq/pymod_ymq/utils.h" +#include "scaler/io/ymq/error.h" #include "scaler/io/ymq/pymod_ymq/ymq.h" +using namespace scaler::ymq; + // the order of the members in the exception args tuple const Py_ssize_t YMQException_errorCodeIndex = 0; const Py_ssize_t YMQException_messageIndex = 1; @@ -81,9 +80,9 @@ static PyType_Slot YMQException_slots[] = { }; static PyType_Spec YMQException_spec = { - "ymq.YMQException", sizeof(YMQException), 0, Py_TPFLAGS_DEFAULT, YMQException_slots}; + "_ymq.YMQException", sizeof(YMQException), 0, Py_TPFLAGS_DEFAULT, YMQException_slots}; -OwnedPyObject<> YMQException_argtupleFromCoreError(YMQState* state, const Error* error) +inline OwnedPyObject<> YMQException_argtupleFromCoreError(YMQState* state, const Error* error) { OwnedPyObject code = PyLong_FromLong(static_cast(error->_errorCode)); @@ -103,7 +102,7 @@ OwnedPyObject<> YMQException_argtupleFromCoreError(YMQState* state, const Error* return PyTuple_Pack(2, *pyCode, *message); } -void YMQException_setFromCoreError(YMQState* state, const Error* error) +inline void YMQException_setFromCoreError(YMQState* state, const Error* error) { auto tuple = YMQException_argtupleFromCoreError(state, error); if (!tuple) @@ -112,7 +111,7 @@ void YMQException_setFromCoreError(YMQState* state, const Error* error) PyErr_SetObject(*state->PyExceptionType, *tuple); } -PyObject* YMQException_createFromCoreError(YMQState* state, const Error* error) +inline PyObject* YMQException_createFromCoreError(YMQState* state, const Error* error) { auto tuple = YMQException_argtupleFromCoreError(state, error); if (!tuple) diff --git a/scaler/io/ymq/pymod_ymq/gil.h b/scaler/io/ymq/pymod_ymq/gil.h new file mode 100644 index 000000000..c28590e8d --- /dev/null +++ b/scaler/io/ymq/pymod_ymq/gil.h @@ -0,0 +1,15 @@ +#include "scaler/io/ymq/pymod_ymq/python.h" + +class AcquireGIL { +public: + AcquireGIL() : _state(PyGILState_Ensure()) {} + ~AcquireGIL() { PyGILState_Release(_state); } + + AcquireGIL(const AcquireGIL&) = delete; + AcquireGIL& operator=(const AcquireGIL&) = delete; + AcquireGIL(AcquireGIL&&) = delete; + AcquireGIL& operator=(AcquireGIL&&) = delete; + +private: + PyGILState_STATE _state; +}; diff --git a/scaler/io/ymq/pymod_ymq/io_context.h b/scaler/io/ymq/pymod_ymq/io_context.h index deb63003e..13468b388 100644 --- a/scaler/io/ymq/pymod_ymq/io_context.h +++ b/scaler/io/ymq/pymod_ymq/io_context.h @@ -4,8 +4,6 @@ #include "scaler/io/ymq/pymod_ymq/python.h" // C++ -#include -#include #include // First-party @@ -64,25 +62,26 @@ static PyObject* PyIOContext_repr(PyIOContext* self) return PyUnicode_FromFormat("", (void*)self->ioContext.get()); } -static PyObject* PyIOContext_createIOSocket_( - PyIOContext* self, - PyObject* args, - PyObject* kwargs, - std::function fn) +static PyObject* PyIOContext_numThreads_getter(PyIOContext* self, void* Py_UNUSED(closure)) { - const char* identity = nullptr; - Py_ssize_t identityLen = 0; - PyObject* pySocketType = nullptr; - const char* kwlist[] = {"identity", "pySocketType", nullptr}; - - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s#O", (char**)kwlist, &identity, &identityLen, &pySocketType)) - return nullptr; + return PyLong_FromSize_t(self->ioContext->numThreads()); +} +static PyObject* PyIOContext_createIOSocket(PyIOContext* self, PyObject* args, PyObject* kwargs) +{ YMQState* state = YMQStateFromSelf((PyObject*)self); - if (!state) return nullptr; + PyObject* callback = nullptr; + const char* identity = nullptr; + Py_ssize_t identityLen = 0; + PyObject* pySocketType = nullptr; + const char* kwlist[] = {"", "identity", "socket_type", nullptr}; + if (!PyArg_ParseTupleAndKeywords( + args, kwargs, "Os#O", (char**)kwlist, &callback, &identity, &identityLen, &pySocketType)) + return nullptr; + if (!PyObject_IsInstance(pySocketType, *state->PyIOSocketEnumType)) { PyErr_SetString(PyExc_TypeError, "Expected socket_type to be an instance of IOSocketType"); return nullptr; @@ -98,83 +97,36 @@ static PyObject* PyIOContext_createIOSocket_( } long socketTypeValue = PyLong_AsLong(*value); - if (socketTypeValue < 0 && PyErr_Occurred()) return nullptr; - IOSocketType socketType = static_cast(socketTypeValue); - + IOSocketType socketType = static_cast(socketTypeValue); OwnedPyObject ioSocket = PyObject_New(PyIOSocket, (PyTypeObject*)*state->PyIOSocketType); if (!ioSocket) return nullptr; + Py_INCREF(callback); + try { // ensure the fields are init new (&ioSocket->socket) std::shared_ptr(); new (&ioSocket->ioContext) std::shared_ptr(); ioSocket->ioContext = self->ioContext; - } catch (...) { - PyErr_SetString(PyExc_RuntimeError, "Failed to create IOSocket"); - return nullptr; - } - // move ownership of the ioSocket to the callback - return fn(ioSocket.take(), identity, socketType); -} + self->ioContext->createIOSocket( + std::string(identity, identityLen), socketType, [callback, ioSocket](auto socket) { + AcquireGIL _; -static PyObject* PyIOContext_createIOSocket(PyIOContext* self, PyObject* args, PyObject* kwargs) -{ - return PyIOContext_createIOSocket_( - self, args, kwargs, [self](auto ioSocket, Identity identity, IOSocketType socketType) { - return async_wrapper((PyObject*)self, [=](YMQState* state, auto future) { - self->ioContext->createIOSocket(identity, socketType, [=](std::shared_ptr socket) { - future_set_result(future, [=] { - ioSocket->socket = std::move(socket); - return (PyObject*)ioSocket; - }); - }); + ioSocket->socket = socket; + OwnedPyObject _result = PyObject_CallFunctionObjArgs(callback, *ioSocket, nullptr); + Py_DECREF(callback); }); - }); -} - -static PyObject* PyIOContext_createIOSocket_sync(PyIOContext* self, PyObject* args, PyObject* kwargs) -{ - auto state = YMQStateFromSelf((PyObject*)self); - if (!state) + } catch (...) { + PyErr_SetString(PyExc_RuntimeError, "Failed to create IOSocket"); return nullptr; + } - return PyIOContext_createIOSocket_( - self, args, kwargs, [self, state](auto ioSocket, Identity identity, IOSocketType socketType) { - PyThreadState* _save = PyEval_SaveThread(); - - std::shared_ptr socket {}; - try { - Waiter waiter(state->wakeupfd_rd); - - self->ioContext->createIOSocket( - identity, socketType, [waiter, &socket](std::shared_ptr s) mutable { - socket = std::move(s); - waiter.signal(); - }); - - if (waiter.wait()) - CHECK_SIGNALS; - } catch (...) { - PyEval_RestoreThread(_save); - PyErr_SetString(PyExc_RuntimeError, "Failed to create io socket synchronously"); - return (PyObject*)nullptr; - } - - PyEval_RestoreThread(_save); - - ioSocket->socket = socket; - return (PyObject*)ioSocket; - }); -} - -static PyObject* PyIOContext_numThreads_getter(PyIOContext* self, void* Py_UNUSED(closure)) -{ - return PyLong_FromSize_t(self->ioContext->numThreads()); + Py_RETURN_NONE; } } // extern "C" @@ -184,10 +136,6 @@ static PyMethodDef PyIOContext_methods[] = { (PyCFunction)PyIOContext_createIOSocket, METH_VARARGS | METH_KEYWORDS, PyDoc_STR("Create a new IOSocket")}, - {"createIOSocket_sync", - (PyCFunction)PyIOContext_createIOSocket_sync, - METH_VARARGS | METH_KEYWORDS, - PyDoc_STR("Create a new IOSocket")}, {nullptr, nullptr, 0, nullptr}, }; @@ -210,9 +158,9 @@ static PyType_Slot PyIOContext_slots[] = { }; static PyType_Spec PyIOContext_spec = { - .name = "ymq.IOContext", + .name = "_ymq.BaseIOContext", .basicsize = sizeof(PyIOContext), .itemsize = 0, - .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_IMMUTABLETYPE, + .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_IMMUTABLETYPE | Py_TPFLAGS_BASETYPE, .slots = PyIOContext_slots, }; diff --git a/scaler/io/ymq/pymod_ymq/io_socket.h b/scaler/io/ymq/pymod_ymq/io_socket.h index 251a6a461..d01f136f7 100644 --- a/scaler/io/ymq/pymod_ymq/io_socket.h +++ b/scaler/io/ymq/pymod_ymq/io_socket.h @@ -4,27 +4,25 @@ #include "scaler/io/ymq/pymod_ymq/python.h" // C++ -#include -#include #include -#include #include // C +#include #include #include #include // First-party #include "scaler/io/ymq/bytes.h" +#include "scaler/io/ymq/error.h" #include "scaler/io/ymq/io_context.h" #include "scaler/io/ymq/io_socket.h" #include "scaler/io/ymq/message.h" -#include "scaler/io/ymq/pymod_ymq/async.h" #include "scaler/io/ymq/pymod_ymq/bytes.h" #include "scaler/io/ymq/pymod_ymq/exception.h" +#include "scaler/io/ymq/pymod_ymq/gil.h" #include "scaler/io/ymq/pymod_ymq/message.h" -#include "scaler/io/ymq/pymod_ymq/utils.h" #include "scaler/io/ymq/pymod_ymq/ymq.h" using namespace scaler::ymq; @@ -54,231 +52,133 @@ static void PyIOSocket_dealloc(PyIOSocket* self) } static PyObject* PyIOSocket_send(PyIOSocket* self, PyObject* args, PyObject* kwargs) -{ - // borrowed reference - PyMessage* message = nullptr; - const char* kwlist[] = {"message", nullptr}; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O", (char**)kwlist, &message)) - return nullptr; - - auto address = message->address.is_none() ? Bytes() : std::move(message->address->bytes); - auto payload = std::move(message->payload->bytes); - - return async_wrapper((PyObject*)self, [=](YMQState* state, auto future) { - try { - self->socket->sendMessage({.address = std::move(address), .payload = std::move(payload)}, [=](auto result) { - future_set_result(future, [=] -> std::expected { - if (result) { - Py_RETURN_NONE; - } else { - return std::unexpected {YMQException_createFromCoreError(state, &result.error())}; - } - }); - }); - } catch (...) { - future_raise_exception( - future, [] { return PyErr_CreateFromString(PyExc_RuntimeError, "Failed to send message"); }); - } - }); -} - -static PyObject* PyIOSocket_send_sync(PyIOSocket* self, PyObject* args, PyObject* kwargs) { auto state = YMQStateFromSelf((PyObject*)self); if (!state) return nullptr; - // borrowed reference + PyObject* callback = nullptr; PyMessage* message = nullptr; - const char* kwlist[] = {"message", nullptr}; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O", (char**)kwlist, &message)) - return nullptr; - Bytes address = message->address.is_none() ? Bytes() : std::move(message->address->bytes); - Bytes payload = std::move(message->payload->bytes); - - PyThreadState* _save = PyEval_SaveThread(); - - std::shared_ptr> result = std::make_shared>(); - try { - Waiter waiter(state->wakeupfd_rd); - - self->socket->sendMessage({.address = std::move(address), .payload = std::move(payload)}, [=](auto r) mutable { - *result = std::move(r); - waiter.signal(); - }); + // empty str -> positional only + const char* kwlist[] = {"", "message", nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OO", (char**)kwlist, &callback, &message)) + return nullptr; - if (waiter.wait()) - CHECK_SIGNALS; - } catch (...) { - PyEval_RestoreThread(_save); - PyErr_SetString(PyExc_RuntimeError, "Failed to send synchronously"); + if (!PyObject_TypeCheck(message, (PyTypeObject*)*state->PyMessageType)) { + PyErr_SetString(PyExc_TypeError, "message must be a Message"); return nullptr; } - PyEval_RestoreThread(_save); + auto address = message->address.is_none() ? Bytes() : std::move(message->address->bytes); + auto payload = std::move(message->payload->bytes); + + Py_INCREF(callback); - if (!result) { - YMQException_setFromCoreError(state, &result->error()); + try { + self->socket->sendMessage( + {.address = std::move(address), .payload = std::move(payload)}, [callback, state](auto result) { + AcquireGIL _; + + if (result) { + OwnedPyObject result = PyObject_CallFunctionObjArgs(callback, Py_None, nullptr); + } else { + OwnedPyObject obj = YMQException_createFromCoreError(state, &result.error()); + OwnedPyObject _ = PyObject_CallFunctionObjArgs(callback, *obj, nullptr); + } + + Py_DECREF(callback); + }); + } catch (...) { + PyErr_SetString(PyExc_RuntimeError, "Failed to send message"); return nullptr; } Py_RETURN_NONE; } -static PyObject* PyIOSocket_recv(PyIOSocket* self, PyObject* args) -{ - return async_wrapper((PyObject*)self, [=](YMQState* state, auto future) { - self->socket->recvMessage([=](auto result) { - try { - future_set_result(future, [=] -> std::expected { - if (result.second._errorCode != Error::ErrorCode::Uninit) { - return std::unexpected {YMQException_createFromCoreError(state, &result.second)}; - } - - auto message = result.first; - OwnedPyObject address = (PyBytesYMQ*)PyObject_CallNoArgs(*state->PyBytesYMQType); - if (!address) - return YMQ_GetRaisedException(); - - address->bytes = std::move(message.address); - - OwnedPyObject payload = (PyBytesYMQ*)PyObject_CallNoArgs(*state->PyBytesYMQType); - if (!payload) - return YMQ_GetRaisedException(); - - payload->bytes = std::move(message.payload); - - OwnedPyObject pyMessage = - (PyMessage*)PyObject_CallFunction(*state->PyMessageType, "OO", *address, *payload); - if (!pyMessage) - return YMQ_GetRaisedException(); - - return (PyObject*)pyMessage.take(); - }); - } catch (...) { - future_raise_exception( - future, [] { return PyErr_CreateFromString(PyExc_RuntimeError, "Failed to receive message"); }); - } - }); - }); -} - -static PyObject* PyIOSocket_recv_sync(PyIOSocket* self, PyObject* args) +static PyObject* PyIOSocket_recv(PyIOSocket* self, PyObject* args, PyObject* kwargs) { auto state = YMQStateFromSelf((PyObject*)self); if (!state) return nullptr; - PyThreadState* _save = PyEval_SaveThread(); - - std::shared_ptr> result = std::make_shared>(); - try { - Waiter waiter(state->wakeupfd_rd); - - self->socket->recvMessage([=](auto r) mutable { - *result = std::move(r); - waiter.signal(); - }); - - if (waiter.wait()) - CHECK_SIGNALS; - } catch (...) { - PyEval_RestoreThread(_save); - PyErr_SetString(PyExc_RuntimeError, "Failed to recv synchronously"); + PyObject* callback = nullptr; + const char* kwlist[] = {"", nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O", (char**)kwlist, &callback)) return nullptr; - } - PyEval_RestoreThread(_save); + Py_INCREF(callback); - if (result->second._errorCode != Error::ErrorCode::Uninit) { - YMQException_setFromCoreError(state, &result->second); - return nullptr; - } - - auto message = result->first; + try { + self->socket->recvMessage([callback, state](std::pair result) { + AcquireGIL _; - OwnedPyObject address = (PyBytesYMQ*)PyObject_CallNoArgs(*state->PyBytesYMQType); - if (!address) - return nullptr; + if (result.second._errorCode != Error::ErrorCode::Uninit) { + OwnedPyObject obj = YMQException_createFromCoreError(state, &result.second); + OwnedPyObject _ = PyObject_CallFunctionObjArgs(callback, *obj, nullptr); + return; + } - address->bytes = std::move(message.address); + auto message = result.first; + OwnedPyObject address = (PyBytesYMQ*)PyObject_CallNoArgs(*state->PyBytesYMQType); + if (!address) + return completeCallbackWithRaisedException(callback); - OwnedPyObject payload = (PyBytesYMQ*)PyObject_CallNoArgs(*state->PyBytesYMQType); - if (!payload) - return nullptr; + address->bytes = std::move(message.address); - payload->bytes = std::move(message.payload); + OwnedPyObject payload = (PyBytesYMQ*)PyObject_CallNoArgs(*state->PyBytesYMQType); + if (!payload) + return completeCallbackWithRaisedException(callback); - OwnedPyObject pyMessage = - (PyMessage*)PyObject_CallFunction(*state->PyMessageType, "OO", *address, *payload); - if (!pyMessage) - return nullptr; + payload->bytes = std::move(message.payload); - return (PyObject*)pyMessage.take(); -} + OwnedPyObject pyMessage = + (PyMessage*)PyObject_CallFunction(*state->PyMessageType, "OO", *address, *payload); + if (!pyMessage) + return completeCallbackWithRaisedException(callback); -static PyObject* PyIOSocket_bind(PyIOSocket* self, PyObject* args, PyObject* kwargs) -{ - const char* address = nullptr; - Py_ssize_t addressLen = 0; - const char* kwlist[] = {"address", nullptr}; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s#", (char**)kwlist, &address, &addressLen)) + OwnedPyObject _result = PyObject_CallFunctionObjArgs(callback, *pyMessage, nullptr); + Py_DECREF(callback); + }); + } catch (...) { + PyErr_SetString(PyExc_RuntimeError, "Failed to receive message"); return nullptr; + } - return async_wrapper((PyObject*)self, [=](YMQState* state, auto future) { - try { - self->socket->bindTo(std::string(address, addressLen), [=](auto result) { - future_set_result(future, [=] -> std::expected { - if (!result) { - return std::unexpected {YMQException_createFromCoreError(state, &result.error())}; - } - - Py_RETURN_NONE; - }); - }); - } catch (...) { - future_raise_exception( - future, [] { return PyErr_CreateFromString(PyExc_RuntimeError, "Failed to bind to address"); }); - } - }); + Py_RETURN_NONE; } -static PyObject* PyIOSocket_bind_sync(PyIOSocket* self, PyObject* args, PyObject* kwargs) +static PyObject* PyIOSocket_bind(PyIOSocket* self, PyObject* args, PyObject* kwargs) { auto state = YMQStateFromSelf((PyObject*)self); if (!state) return nullptr; + PyObject* callback = nullptr; const char* address = nullptr; Py_ssize_t addressLen = 0; - const char* kwlist[] = {"address", nullptr}; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s#", (char**)kwlist, &address, &addressLen)) + const char* kwlist[] = {"", "address", nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "Os#", (char**)kwlist, &callback, &address, &addressLen)) return nullptr; - PyThreadState* _save = PyEval_SaveThread(); + Py_INCREF(callback); - auto result = std::make_shared>(); try { - Waiter waiter(state->wakeupfd_rd); + self->socket->bindTo(std::string(address, addressLen), [callback, state](auto result) { + AcquireGIL _; + + if (!result) { + OwnedPyObject exc = YMQException_createFromCoreError(state, &result.error()); + OwnedPyObject _ = PyObject_CallFunctionObjArgs(callback, *exc, nullptr); + } else { + OwnedPyObject _ = PyObject_CallFunctionObjArgs(callback, Py_None, nullptr); + } - self->socket->bindTo(std::string(address, addressLen), [=](auto r) mutable { - *result = std::move(r); - waiter.signal(); + Py_DECREF(callback); }); - - if (waiter.wait()) - CHECK_SIGNALS; } catch (...) { - PyEval_RestoreThread(_save); - PyErr_SetString(PyExc_RuntimeError, "Failed to bind synchronously"); - return nullptr; - } - - PyEval_RestoreThread(_save); - - if (!result) { - YMQException_setFromCoreError(state, &result->error()); + PyErr_SetString(PyExc_RuntimeError, "Failed to bind to address"); return nullptr; } @@ -286,66 +186,35 @@ static PyObject* PyIOSocket_bind_sync(PyIOSocket* self, PyObject* args, PyObject } static PyObject* PyIOSocket_connect(PyIOSocket* self, PyObject* args, PyObject* kwargs) -{ - const char* address = nullptr; - Py_ssize_t addressLen = 0; - const char* kwlist[] = {"address", nullptr}; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s#", (char**)kwlist, &address, &addressLen)) - return nullptr; - - return async_wrapper((PyObject*)self, [=](YMQState* state, auto future) { - try { - self->socket->connectTo(std::string(address, addressLen), [=](auto result) { - future_set_result(future, [=] -> std::expected { - if (result || result.error()._errorCode == Error::ErrorCode::InitialConnectFailedWithInProgress) { - Py_RETURN_NONE; - } else { - return std::unexpected {YMQException_createFromCoreError(state, &result.error())}; - } - }); - }); - } catch (...) { - future_raise_exception( - future, [] { return PyErr_CreateFromString(PyExc_RuntimeError, "Failed to connect to address"); }); - } - }); -} - -static PyObject* PyIOSocket_connect_sync(PyIOSocket* self, PyObject* args, PyObject* kwargs) { auto state = YMQStateFromSelf((PyObject*)self); if (!state) return nullptr; + PyObject* callback = nullptr; const char* address = nullptr; Py_ssize_t addressLen = 0; - const char* kwlist[] = {"address", nullptr}; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s#", (char**)kwlist, &address, &addressLen)) + const char* kwlist[] = {"", "address", nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "Os#", (char**)kwlist, &callback, &address, &addressLen)) return nullptr; - PyThreadState* _save = PyEval_SaveThread(); + Py_INCREF(callback); - std::shared_ptr> result = std::make_shared>(); try { - Waiter waiter(state->wakeupfd_rd); + self->socket->connectTo(std::string(address, addressLen), [callback, state](auto result) { + AcquireGIL _; + + if (result || result.error()._errorCode == Error::ErrorCode::InitialConnectFailedWithInProgress) { + OwnedPyObject _ = PyObject_CallFunctionObjArgs(callback, Py_None, nullptr); + } else { + OwnedPyObject exc = YMQException_createFromCoreError(state, &result.error()); + OwnedPyObject _ = PyObject_CallFunctionObjArgs(callback, *exc, nullptr); + } - self->socket->connectTo(std::string(address, addressLen), [=](auto r) mutable { - *result = std::move(r); - waiter.signal(); + Py_DECREF(callback); }); - - if (waiter.wait()) - CHECK_SIGNALS; } catch (...) { - PyEval_RestoreThread(_save); - PyErr_SetString(PyExc_RuntimeError, "Failed to connect synchronously"); - return nullptr; - } - - PyEval_RestoreThread(_save); - - if (!result && result->error()._errorCode != Error::ErrorCode::InitialConnectFailedWithInProgress) { - YMQException_setFromCoreError(state, &result->error()); + PyErr_SetString(PyExc_RuntimeError, "Failed to connect to address"); return nullptr; } @@ -386,7 +255,7 @@ static PyGetSetDef PyIOSocket_properties[] = { static PyMethodDef PyIOSocket_methods[] = { {"send", (PyCFunction)PyIOSocket_send, METH_VARARGS | METH_KEYWORDS, PyDoc_STR("Send data through the IOSocket")}, - {"recv", (PyCFunction)PyIOSocket_recv, METH_NOARGS, PyDoc_STR("Receive data from the IOSocket")}, + {"recv", (PyCFunction)PyIOSocket_recv, METH_VARARGS | METH_KEYWORDS, PyDoc_STR("Receive data from the IOSocket")}, {"bind", (PyCFunction)PyIOSocket_bind, METH_VARARGS | METH_KEYWORDS, @@ -395,19 +264,6 @@ static PyMethodDef PyIOSocket_methods[] = { (PyCFunction)PyIOSocket_connect, METH_VARARGS | METH_KEYWORDS, PyDoc_STR("Connect to a remote IOSocket")}, - {"send_sync", - (PyCFunction)PyIOSocket_send_sync, - METH_VARARGS | METH_KEYWORDS, - PyDoc_STR("Send data through the IOSocket synchronously")}, - {"recv_sync", (PyCFunction)PyIOSocket_recv_sync, METH_NOARGS, PyDoc_STR("Receive data from the IOSocket")}, - {"bind_sync", - (PyCFunction)PyIOSocket_bind_sync, - METH_VARARGS | METH_KEYWORDS, - PyDoc_STR("Bind to an address and listen for incoming connections")}, - {"connect_sync", - (PyCFunction)PyIOSocket_connect_sync, - METH_VARARGS | METH_KEYWORDS, - PyDoc_STR("Connect to a remote IOSocket")}, {nullptr, nullptr, 0, nullptr}, }; @@ -421,7 +277,7 @@ static PyType_Slot PyIOSocket_slots[] = { }; static PyType_Spec PyIOSocket_spec = { - .name = "ymq.IOSocket", + .name = "_ymq.BaseIOSocket", .basicsize = sizeof(PyIOSocket), .itemsize = 0, .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_IMMUTABLETYPE | Py_TPFLAGS_DISALLOW_INSTANTIATION, diff --git a/scaler/io/ymq/pymod_ymq/message.h b/scaler/io/ymq/pymod_ymq/message.h index 52da66763..d8f7df5c9 100644 --- a/scaler/io/ymq/pymod_ymq/message.h +++ b/scaler/io/ymq/pymod_ymq/message.h @@ -5,7 +5,6 @@ // First-party #include "scaler/io/ymq/pymod_ymq/bytes.h" -#include "scaler/io/ymq/pymod_ymq/utils.h" #include "scaler/io/ymq/pymod_ymq/ymq.h" struct PyMessage { @@ -92,7 +91,7 @@ static PyType_Slot PyMessage_slots[] = { }; static PyType_Spec PyMessage_spec = { - .name = "ymq.Message", + .name = "_ymq.Message", .basicsize = sizeof(PyMessage), .itemsize = 0, .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_IMMUTABLETYPE, diff --git a/scaler/io/ymq/pymod_ymq/python.h b/scaler/io/ymq/pymod_ymq/python.h index d627d92c3..4e0df4d2b 100644 --- a/scaler/io/ymq/pymod_ymq/python.h +++ b/scaler/io/ymq/pymod_ymq/python.h @@ -4,8 +4,6 @@ #include #include -#include "scaler/io/ymq/pymod_ymq/utils.h" - #if PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION == 8 static inline PyObject* Py_NewRef(PyObject* obj) { @@ -77,7 +75,7 @@ class OwnedPyObject { // steals a reference OwnedPyObject(T* ptr): _ptr(ptr) {} - OwnedPyObject(const OwnedPyObject& other) { this->_ptr = Py_XNewRef(other._ptr); } + OwnedPyObject(const OwnedPyObject& other) { this->_ptr = (T*)Py_XNewRef((PyObject*)other._ptr); } OwnedPyObject(OwnedPyObject&& other) noexcept: _ptr(other._ptr) { other._ptr = nullptr; } OwnedPyObject& operator=(const OwnedPyObject& other) { @@ -85,7 +83,7 @@ class OwnedPyObject { return *this; this->free(); - this->_ptr = Py_XNewRef(other._ptr); + this->_ptr = (T*)Py_XNewRef((PyObject*)other._ptr); return *this; } OwnedPyObject& operator=(OwnedPyObject&& other) noexcept diff --git a/scaler/io/ymq/pymod_ymq/utils.h b/scaler/io/ymq/pymod_ymq/utils.h deleted file mode 100644 index 522f819e8..000000000 --- a/scaler/io/ymq/pymod_ymq/utils.h +++ /dev/null @@ -1,110 +0,0 @@ -#pragma once - -// Python -#include - -#include "scaler/io/ymq/pymod_ymq/python.h" - -// C++ -#include - -// C -#include -#include - -#include - -// First-party -#include "scaler/io/ymq/common.h" -#include "scaler/io/ymq/pymod_ymq/ymq.h" - -class Waiter { -public: - Waiter(int wakeFd): _waiter(std::shared_ptr(new int, &destroy_efd)), _wakeFd(wakeFd) - { - auto fd = eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK); - if (fd < 0) - throw std::runtime_error("failed to create eventfd"); - - *_waiter = fd; - } - - Waiter(const Waiter& other): _waiter(other._waiter), _wakeFd(other._wakeFd) {} - Waiter(Waiter&& other) noexcept: _waiter(std::move(other._waiter)), _wakeFd(other._wakeFd) - { - other._wakeFd = -1; // invalidate the moved-from object - } - - Waiter& operator=(const Waiter& other) - { - if (this == &other) - return *this; - - this->_waiter = other._waiter; - this->_wakeFd = other._wakeFd; - return *this; - } - - Waiter& operator=(Waiter&& other) noexcept - { - if (this == &other) - return *this; - - this->_waiter = std::move(other._waiter); - this->_wakeFd = other._wakeFd; - other._wakeFd = -1; // invalidate the moved-from object - return *this; - } - - void signal() - { - if (eventfd_write(*_waiter, 1) < 0) { - std::println(stderr, "Failed to signal waiter: {}", std::strerror(errno)); - } - } - - // true -> error - // false -> ok - bool wait() - { - pollfd pfds[2] = { - { - .fd = *_waiter, - .events = POLLIN, - .revents = 0, - }, - { - .fd = _wakeFd, - .events = POLLIN, - .revents = 0, - }}; - - for (;;) { - int ready = poll(pfds, 2, -1); - if (ready < 0) { - if (errno == EINTR) - continue; - throw std::runtime_error("poll failed"); - } - - if (pfds[0].revents & POLLIN) - return false; // we got a message - - if (pfds[1].revents & POLLIN) - return true; // signal received - } - } - -private: - std::shared_ptr _waiter; - int _wakeFd; - - static void destroy_efd(int* fd) - { - if (!fd) - return; - - close(*fd); - delete fd; - } -}; diff --git a/scaler/io/ymq/pymod_ymq/ymq.cpp b/scaler/io/ymq/pymod_ymq/ymq.cpp index 8444c21a2..768110472 100644 --- a/scaler/io/ymq/pymod_ymq/ymq.cpp +++ b/scaler/io/ymq/pymod_ymq/ymq.cpp @@ -15,7 +15,7 @@ inline void ymqUnrecoverableError(scaler::ymq::Error e) std::exit(EXIT_FAILURE); } -PyMODINIT_FUNC PyInit_ymq(void) +PyMODINIT_FUNC PyInit__ymq(void) { unrecoverableErrorFunctionHookPtr = ymqUnrecoverableError; diff --git a/scaler/io/ymq/pymod_ymq/ymq.h b/scaler/io/ymq/pymod_ymq/ymq.h index 7442a7b02..59c5cef60 100644 --- a/scaler/io/ymq/pymod_ymq/ymq.h +++ b/scaler/io/ymq/pymod_ymq/ymq.h @@ -9,20 +9,14 @@ // C++ #include -#include -#include #include #include #include // First-party #include "scaler/io/ymq/error.h" -#include "scaler/io/ymq/pymod_ymq/utils.h" struct YMQState { - int wakeupfd_wr; - int wakeupfd_rd; - OwnedPyObject<> enumModule; // Reference to the enum module OwnedPyObject<> asyncioModule; // Reference to the asyncio module @@ -33,79 +27,8 @@ struct YMQState { OwnedPyObject<> PyIOSocketType; // Reference to the IOSocket type OwnedPyObject<> PyIOContextType; // Reference to the IOContext type OwnedPyObject<> PyExceptionType; // Reference to the Exception type - OwnedPyObject<> PyInterruptedExceptionType; // Reference to the YMQInterruptedException type - OwnedPyObject<> PyAwaitableType; // Reference to the Awaitable type }; -#define CHECK_SIGNALS \ - do { \ - PyEval_RestoreThread(_save); \ - if (PyErr_CheckSignals() >= 0) \ - PyErr_SetString( \ - *state->PyInterruptedExceptionType, "A synchronous YMQ operation was interrupted by a signal"); \ - return (PyObject*)nullptr; \ - } while (0); - -static bool future_do_(PyObject* future_, const std::function()>& fn) -{ - // this is an owned reference to the future created in `async_wrapper()` - OwnedPyObject future(future_); - OwnedPyObject loop = PyObject_CallMethod(*future, "get_loop", nullptr); - if (!loop) - return true; - - // if future is already done, no need to call the method - OwnedPyObject result1 = PyObject_CallMethod(*future, "done", nullptr); - if (*result1 == Py_True) - return false; - - const char* method_name = nullptr; - OwnedPyObject arg {}; - - if (auto result = fn()) { - method_name = "set_result"; - arg = *result; - } else { - method_name = "set_exception"; - arg = result.error(); - } - - OwnedPyObject method = PyObject_GetAttrString(*future, method_name); - if (!method) - return true; - - OwnedPyObject obj = PyObject_GetAttrString(*loop, "call_soon_threadsafe"); - - // auto result = PyObject_CallMethod(loop, "call_soon_threadsafe", "OO", method, fn()); - OwnedPyObject result2 = PyObject_CallFunctionObjArgs(*obj, *method, *arg, nullptr); - return !result2; -} - -// this function must be called from a C++ thread -// this function will lock the GIL, call `fn()` and use its return value to set the future's result/exception -static void future_do(PyObject* future, const std::function()>& fn) -{ - PyGILState_STATE gstate = PyGILState_Ensure(); - // begin python critical section - - auto error = future_do_(future, fn); - if (error) - PyErr_WriteUnraisable(future); - - // end python critical section - PyGILState_Release(gstate); -} - -static void future_set_result(PyObject* future, std::function()> fn) -{ - return future_do(future, fn); -} - -static void future_raise_exception(PyObject* future, std::function fn) -{ - return future_do(future, [=] { return std::unexpected {fn()}; }); -} - static YMQState* YMQStateFromSelf(PyObject* self) { // replace with PyType_GetModuleByDef(Py_TYPE(self), &YMQ_module) in a newer Python version @@ -151,8 +74,13 @@ std::expected YMQ_GetRaisedException() #endif } +void completeCallbackWithRaisedException(PyObject* callback) +{ + auto result = YMQ_GetRaisedException(); + OwnedPyObject _ =PyObject_CallFunctionObjArgs(callback, result.value_or(result.error())); +} + // First-Party -#include "scaler/io/ymq/pymod_ymq/async.h" #include "scaler/io/ymq/pymod_ymq/bytes.h" #include "scaler/io/ymq/pymod_ymq/exception.h" #include "scaler/io/ymq/pymod_ymq/io_context.h" @@ -173,22 +101,10 @@ static void YMQ_free(YMQState* state) state->PyIOSocketType.~OwnedPyObject(); state->PyIOContextType.~OwnedPyObject(); state->PyExceptionType.~OwnedPyObject(); - state->PyInterruptedExceptionType.~OwnedPyObject(); - state->PyAwaitableType.~OwnedPyObject(); } catch (...) { PyErr_SetString(PyExc_RuntimeError, "Failed to free YMQState"); PyErr_WriteUnraisable(nullptr); } - - if (close(state->wakeupfd_wr) < 0) { - PyErr_SetString(PyExc_RuntimeError, "Failed to close waitfd_wr"); - PyErr_WriteUnraisable(nullptr); - } - - if (close(state->wakeupfd_rd) < 0) { - PyErr_SetString(PyExc_RuntimeError, "Failed to close waitfd_rd"); - PyErr_WriteUnraisable(nullptr); - } } static int YMQ_createIntEnum( @@ -322,21 +238,6 @@ static int YMQ_createErrorCodeEnum(PyObject* pyModule, YMQState* state) } } -static int YMQ_createInterruptedException(PyObject* pyModule, OwnedPyObject<>* storage) -{ - *storage = PyErr_NewExceptionWithDoc( - "ymq.YMQInterruptedException", - "Raised when a synchronous method is interrupted by a signal", - PyExc_Exception, - nullptr); - - if (!*storage) - return -1; - if (PyModule_AddObjectRef(pyModule, "YMQInterruptedException", **storage) < 0) - return -1; - return 0; -} - // internal convenience function to create a type and add it to the module static int YMQ_createType( // the module object @@ -380,36 +281,12 @@ static int YMQ_createType( return 0; } -static int YMQ_setupWakeupFd(YMQState* state) -{ - int pipefd[2]; - if (pipe2(pipefd, O_NONBLOCK | O_CLOEXEC) < 0) { - PyErr_SetString(PyExc_RuntimeError, "Failed to create pipe for wakeup fd"); - return -1; - } - - state->wakeupfd_rd = pipefd[0]; - state->wakeupfd_wr = pipefd[1]; - - OwnedPyObject signalModule = PyImport_ImportModule("signal"); - if (!signalModule) - return -1; - - OwnedPyObject result = PyObject_CallMethod(*signalModule, "set_wakeup_fd", "i", state->wakeupfd_wr); - if (!result) - return -1; - return 0; -} - static int YMQ_exec(PyObject* pyModule) { auto state = (YMQState*)PyModule_GetState(pyModule); if (!state) return -1; - if (YMQ_setupWakeupFd(state) < 0) - return -1; - state->enumModule = PyImport_ImportModule("enum"); if (!state->enumModule) return -1; @@ -443,10 +320,10 @@ static int YMQ_exec(PyObject* pyModule) if (YMQ_createType(pyModule, &state->PyMessageType, &PyMessage_spec, "Message") < 0) return -1; - if (YMQ_createType(pyModule, &state->PyIOSocketType, &PyIOSocket_spec, "IOSocket") < 0) + if (YMQ_createType(pyModule, &state->PyIOSocketType, &PyIOSocket_spec, "BaseIOSocket") < 0) return -1; - if (YMQ_createType(pyModule, &state->PyIOContextType, &PyIOContext_spec, "IOContext") < 0) + if (YMQ_createType(pyModule, &state->PyIOContextType, &PyIOContext_spec, "BaseIOContext") < 0) return -1; PyObject* exceptionBases = PyTuple_Pack(1, PyExc_Exception); @@ -460,12 +337,6 @@ static int YMQ_exec(PyObject* pyModule) } Py_DECREF(exceptionBases); - if (YMQ_createInterruptedException(pyModule, &state->PyInterruptedExceptionType) < 0) - return -1; - - if (YMQ_createType(pyModule, &state->PyAwaitableType, &Awaitable_spec, "Awaitable", false) < 0) - return -1; - return 0; } @@ -476,7 +347,7 @@ static PyModuleDef_Slot YMQ_slots[] = { static PyModuleDef YMQ_module = { .m_base = PyModuleDef_HEAD_INIT, - .m_name = "ymq", + .m_name = "_ymq", .m_doc = PyDoc_STR("YMQ Python bindings"), .m_size = sizeof(YMQState), .m_slots = YMQ_slots, diff --git a/scaler/io/ymq/ymq.py b/scaler/io/ymq/ymq.py new file mode 100644 index 000000000..19e97c829 --- /dev/null +++ b/scaler/io/ymq/ymq.py @@ -0,0 +1,130 @@ +# This file wraps the interface exported by the C implementation of the module +# and provides a more ergonomic interface supporting both asynchronous and synchronous execution + +__all__ = ["IOSocket", "IOContext", "Message", "IOSocketType", "YMQException", "Bytes", "ErrorCode"] + +import asyncio +import concurrent.futures +from typing import Optional, Callable, TypeVar, Union + +try: + from typing import ParamSpec, Concatenate # type: ignore[attr-defined] +except ImportError: + from typing_extensions import ParamSpec, Concatenate # type: ignore[assignment] + +from scaler.io.ymq._ymq import BaseIOContext, BaseIOSocket, Bytes, ErrorCode, IOSocketType, Message, YMQException + + +class IOSocket: + _base: BaseIOSocket + + def __init__(self, base: BaseIOSocket) -> None: + self._base = base + + @property + def socket_type(self) -> IOSocketType: + return self._base.socket_type + + @property + def identity(self) -> str: + return self._base.identity + + async def bind(self, address: str) -> None: + """Bind the socket to an address and listen for incoming connections""" + await call_async(self._base.bind, address) + + def bind_sync(self, address: str, /, timeout: Optional[float] = None) -> None: + """Bind the socket to an address and listen for incoming connections""" + call_sync(self._base.bind, address, timeout=timeout) + + async def connect(self, address: str) -> None: + """Connect to a remote socket""" + await call_async(self._base.connect, address) + + def connect_sync(self, address: str, /, timeout: Optional[float] = None) -> None: + """Connect to a remote socket""" + call_sync(self._base.connect, address, timeout=timeout) + + async def send(self, message: Message) -> None: + """Send a message to one of the socket's peers""" + await call_async(self._base.send, message) + + def send_sync(self, message: Message, /, timeout: Optional[float] = None) -> None: + """Send a message to one of the socket's peers""" + call_sync(self._base.send, message, timeout=timeout) + + async def recv(self) -> Message: + """Receive a message from one of the socket's peers""" + return await call_async(self._base.recv) + + def recv_sync(self, /, timeout: Optional[float] = None) -> Message: + """Receive a message from one of the socket's peers""" + return call_sync(self._base.recv, timeout=timeout) + + +class IOContext: + _base: BaseIOContext + + def __init__(self, num_threads: int = 1) -> None: + self._base = BaseIOContext(num_threads) + + @property + def num_threads(self) -> int: + return self._base.num_threads + + async def createIOSocket(self, identity: str, socket_type: IOSocketType) -> IOSocket: + """Create an io socket with an identity and socket type""" + return IOSocket(await call_async(self._base.createIOSocket, identity, socket_type)) + + def createIOSocket_sync(self, identity: str, socket_type: IOSocketType) -> IOSocket: + """Create an io socket with an identity and socket type""" + return IOSocket(call_sync(self._base.createIOSocket, identity, socket_type)) + + +P = ParamSpec("P") +T = TypeVar("T") + + +async def call_async( + func: Callable[Concatenate[Callable[[Union[T, Exception]], None], P], None], # type: ignore + *args: P.args, # type: ignore + **kwargs: P.kwargs, # type: ignore +) -> T: + future = asyncio.get_event_loop().create_future() + + def callback(result: Union[T, Exception]): + if future.done(): + return + + loop = future.get_loop() + + if isinstance(result, Exception): + loop.call_soon_threadsafe(future.set_exception, result) + else: + loop.call_soon_threadsafe(future.set_result, result) + + func(callback, *args, **kwargs) + return await future + + +# about the ignore directives: mypy cannot properly handle typing extension's ParamSpec and Concatenate in python <=3.9 +# these type hints are correctly understood in Python 3.10+ +def call_sync( # type: ignore[valid-type] + func: Callable[Concatenate[Callable[[Union[T, Exception]], None], P], None], # type: ignore + *args: P.args, # type: ignore + timeout: Optional[float] = None, + **kwargs: P.kwargs, # type: ignore +) -> T: # type: ignore + future: concurrent.futures.Future = concurrent.futures.Future() + + def callback(result: Union[T, Exception]): + if future.done(): + return + + if isinstance(result, Exception): + future.set_exception(result) + else: + future.set_result(result) + + func(callback, *args, **kwargs) + return future.result(timeout) diff --git a/scaler/worker/worker.py b/scaler/worker/worker.py index b31b9ab08..7d59e4b58 100644 --- a/scaler/worker/worker.py +++ b/scaler/worker/worker.py @@ -14,6 +14,7 @@ from scaler.io.async_object_storage_connector import PyAsyncObjectStorageConnector from scaler.io.config import PROFILING_INTERVAL_SECONDS from scaler.io.mixins import AsyncBinder, AsyncConnector, AsyncObjectStorageConnector +from scaler.io.ymq import ymq from scaler.protocol.python.message import ( ClientDisconnect, DisconnectRequest, @@ -232,6 +233,13 @@ async def __get_loops(self): ) except asyncio.CancelledError: pass + + # TODO: Should the object storage connector catch this error? + except ymq.YMQException as e: + if e.code == ymq.ErrorCode.ConnectorSocketClosedByRemoteEnd: + pass + else: + logging.exception(f"{self.identity!r}: failed with unhandled exception:\n{e}") except (ClientShutdownException, TimeoutError) as e: logging.info(f"{self.identity!r}: {str(e)}") except Exception as e: From 4c02a32889550b0f25416c93b0caa545afe92f49 Mon Sep 17 00:00:00 2001 From: magniloquency <197707854+magniloquency@users.noreply.github.com> Date: Mon, 6 Oct 2025 22:57:40 -0400 Subject: [PATCH 22/23] Remove Old Test File (#292) --- scaler/io/ymq/ymq_test.py | 20 -------------------- 1 file changed, 20 deletions(-) delete mode 100644 scaler/io/ymq/ymq_test.py diff --git a/scaler/io/ymq/ymq_test.py b/scaler/io/ymq/ymq_test.py deleted file mode 100644 index 9201983c7..000000000 --- a/scaler/io/ymq/ymq_test.py +++ /dev/null @@ -1,20 +0,0 @@ -import asyncio - -import ymq - - -async def main(): - ctx = ymq.IOContext() - socket = await ctx.createIOSocket("ident", ymq.IOSocketType.Binder) - print(ctx, ";", socket) - - assert socket.identity == "ident" - assert socket.socket_type == ymq.IOSocketType.Binder - - exc = ymq.YMQException(ymq.ErrorCode.InvalidAddressFormat, "the address has an invalid format") - assert exc.code == ymq.ErrorCode.InvalidAddressFormat - assert exc.message == "the address has an invalid format" - assert exc.code.explanation() - - -asyncio.run(main()) From 3513f9e6f69e76ad78df24a2f118023ea131e5b0 Mon Sep 17 00:00:00 2001 From: magniloquency <197707854+magniloquency@users.noreply.github.com> Date: Tue, 7 Oct 2025 10:16:52 -0400 Subject: [PATCH 23/23] Move C++ tests into run-test action --- .github/actions/compile-libraries/action.yml | 7 +------ .github/actions/run-test/action.yml | 10 ++++++++++ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/.github/actions/compile-libraries/action.yml b/.github/actions/compile-libraries/action.yml index 545f48cc7..bedf3fb31 100644 --- a/.github/actions/compile-libraries/action.yml +++ b/.github/actions/compile-libraries/action.yml @@ -9,16 +9,11 @@ inputs: runs: using: "composite" steps: - - name: Install dependencies for MITM tests - shell: bash - run: uv pip install --system scapy==2.* - - - name: Build and test C++ Components (Linux) + - name: Build C++ Components (Linux) if: inputs.os == 'Linux' shell: bash run: | CXX=$(which g++-14) ./scripts/build.sh - sudo ./scripts/test.sh - name: Build and test C++ Components (Windows) if: inputs.os == 'Windows' diff --git a/.github/actions/run-test/action.yml b/.github/actions/run-test/action.yml index d01c577de..47ee04df0 100644 --- a/.github/actions/run-test/action.yml +++ b/.github/actions/run-test/action.yml @@ -9,6 +9,16 @@ inputs: runs: using: "composite" steps: + - name: Install dependencies for MITM tests + if: inputs.os == 'Linux' + shell: bash + run: uv pip install --system scapy==2.* + + - name: Run C++ Tests (Linux) + if: inputs.os == 'Linux' + shell: bash + run: sudo ./scripts/test.sh + # TODO: build wheel first, then run the test - name: Run Unittests shell: bash