Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use futex for IPC notifications #625

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions SConscript
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ msgq_objects = env.SharedObject([
'msgq/impl_zmq.cc',
'msgq/impl_msgq.cc',
'msgq/impl_fake.cc',
'msgq/futex.cc',
'msgq/msgq.cc',
])
msgq = env.Library('msgq', msgq_objects)
Expand Down
62 changes: 62 additions & 0 deletions msgq/futex.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#include "msgq/futex.h"

#include <fcntl.h>
#include <limits.h>
#include <linux/futex.h>
#include <stdio.h>
#include <sys/mman.h>
#include <syscall.h>
#include <unistd.h>

#include <cassert>
#include <stdexcept>

Futex::Futex(const std::string &path) {
auto fd = open(path.c_str(), O_RDWR | O_CREAT, 0664);
if (fd < 0) {
throw std::runtime_error("Failed to open file: " + path);
}

if (ftruncate(fd, sizeof(uint32_t)) < 0) {
close(fd);
throw std::runtime_error("Failed to truncate file: " + path);
}

int *mem = (int *)mmap(NULL, sizeof(uint32_t), PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
close(fd);
if (mem == MAP_FAILED) {
throw std::runtime_error("Failed to mmap file: " + path);
}

futex = reinterpret_cast<std::atomic<uint32_t> *>(mem);
}

Futex::~Futex() {
munmap(futex, sizeof(uint32_t));
}

void Futex::broadcast() {
// Increment the futex value to signal waiting threads
futex->fetch_add(1, std::memory_order_relaxed);

// Wake up all threads waiting on the futex
syscall(SYS_futex, futex, FUTEX_WAKE, INT_MAX, NULL, NULL, 0);
}

bool Futex::wait(uint32_t expected, int timeout_ms) {
if (futex->load(std::memory_order_relaxed) != expected) {
return true; // Already not equal, no need to wait
}

if (timeout_ms <= 0) {
return false; // Timeout immediately
}

// Perform the futex wait syscall
struct timespec ts;
ts.tv_sec = timeout_ms / 1000;
ts.tv_nsec = (timeout_ms % 1000) * 1000 * 1000;
syscall(SYS_futex, futex, FUTEX_WAIT, expected, &ts, nullptr, 0);

return futex->load(std::memory_order_relaxed) != expected;
}
18 changes: 18 additions & 0 deletions msgq/futex.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#pragma once

#include <cstdint>
#include <atomic>
#include <string>


class Futex {
public:
Futex(const std::string &path);
~Futex();
void broadcast();
bool wait(uint32_t expected, int timeout_ms);
inline uint32_t value() const { return futex->load(); }

private:
std::atomic<uint32_t> *futex = nullptr;
};
11 changes: 2 additions & 9 deletions msgq/impl_msgq.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ Message * MSGQSubSocket::receive(bool non_blocking){
}

msgq_msg_t msg;

MSGQMessage *r = NULL;

int rc = msgq_msg_recv(&msg, q);
Expand All @@ -93,21 +92,15 @@ Message * MSGQSubSocket::receive(bool non_blocking){
items[0].q = q;

int t = (timeout != -1) ? timeout : 100;

int n = msgq_poll(items, 1, t);
rc = msgq_msg_recv(&msg, q);

// The poll indicated a message was ready, but the receive failed. Try again
if (n == 1 && rc == 0){
continue;
if (msgq_poll(items, 1, t) > 0) {
rc = msgq_msg_recv(&msg, q);
}

if (timeout != -1){
break;
}
}


if (!non_blocking){
std::signal(SIGINT, prev_handler_sigint);
std::signal(SIGTERM, prev_handler_sigterm);
Expand Down
78 changes: 20 additions & 58 deletions msgq/msgq.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,21 @@
#include <cerrno>
#include <cmath>
#include <cstring>
#include <cstdint>
#include <chrono>
#include <algorithm>
#include <cstdlib>
#include <csignal>
#include <random>
#include <string>
#include <limits>

#include <poll.h>
#include <sys/ioctl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/syscall.h>
#include <fcntl.h>
#include <unistd.h>

#include <stdio.h>

#include "msgq/futex.h"
#include "msgq/msgq.h"

void sigusr2_handler(int signal) {
assert(signal == SIGUSR2);
}
Futex g_futex("/dev/shm/msgq_futex");

uint64_t msgq_get_uid(void){
std::random_device rd("/dev/urandom");
Expand Down Expand Up @@ -85,7 +75,6 @@ void msgq_wait_for_subscriber(msgq_queue_t *q){

int msgq_new_queue(msgq_queue_t * q, const char * path, size_t size){
assert(size < 0xFFFFFFFF); // Buffer must be smaller than 2^32 bytes
std::signal(SIGUSR2, sigusr2_handler);

std::string full_path = "/dev/shm/";
const char* prefix = std::getenv("OPENPILOT_PREFIX");
Expand Down Expand Up @@ -142,7 +131,6 @@ void msgq_close_queue(msgq_queue_t *q){
}
}


void msgq_init_publisher(msgq_queue_t * q) {
//std::cout << "Starting publisher" << std::endl;
uint64_t uid = msgq_get_uid();
Expand All @@ -158,15 +146,6 @@ void msgq_init_publisher(msgq_queue_t * q) {
q->write_uid_local = uid;
}

static void thread_signal(uint32_t tid) {
#ifndef SYS_tkill
// TODO: this won't work for multithreaded programs
kill(tid, SIGUSR2);
#else
syscall(SYS_tkill, tid, SIGUSR2);
#endif
}

void msgq_init_subscriber(msgq_queue_t * q) {
assert(q != NULL);
assert(q->num_readers != NULL);
Expand All @@ -185,14 +164,11 @@ void msgq_init_subscriber(msgq_queue_t * q) {

for (size_t i = 0; i < NUM_READERS; i++){
*q->read_valids[i] = false;

uint64_t old_uid = *q->read_uids[i];
*q->read_uids[i] = 0;

// Wake up reader in case they are in a poll
thread_signal(old_uid & 0xFFFFFFFF);
}

// Notify readers
g_futex.broadcast();
continue;
}

Expand Down Expand Up @@ -293,10 +269,7 @@ int msgq_msg_send(msgq_msg_t * msg, msgq_queue_t *q){
PACK64(*q->write_pointer, write_cycles, new_ptr);

// Notify readers
for (uint64_t i = 0; i < num_readers; i++){
uint64_t reader_uid = *q->read_uids[i];
thread_signal(reader_uid & 0xFFFFFFFF);
}
g_futex.broadcast();

return msg->size;
}
Expand Down Expand Up @@ -414,42 +387,31 @@ int msgq_msg_recv(msgq_msg_t * msg, msgq_queue_t * q){
goto start;
}


return msg->size;
}



int msgq_poll(msgq_pollitem_t * items, size_t nitems, int timeout){
int msgq_poll(msgq_pollitem_t * items, size_t nitems, int timeout) {
int num = 0;
int timeout_ms = (timeout == -1) ? 100 : timeout;
uint32_t current_futex_value = 0;

// Check if messages ready
for (size_t i = 0; i < nitems; i++) {
items[i].revents = msgq_msg_ready(items[i].q);
if (items[i].revents) num++;
}

int ms = (timeout == -1) ? 100 : timeout;
struct timespec ts;
ts.tv_sec = ms / 1000;
ts.tv_nsec = (ms % 1000) * 1000 * 1000;


auto start_time = std::chrono::high_resolution_clock::now();
while (num == 0) {
int ret;

ret = nanosleep(&ts, &ts);
if (g_futex.wait(current_futex_value, timeout_ms)) {
current_futex_value = g_futex.value();

// Check if messages ready
for (size_t i = 0; i < nitems; i++) {
if (items[i].revents == 0 && msgq_msg_ready(items[i].q)){
num += 1;
items[i].revents = 1;
// Check if messages ready
for (size_t i = 0; i < nitems; i++) {
items[i].revents = msgq_msg_ready(items[i].q);
if (items[i].revents) ++num;
}
}

// exit if we had a timeout and the sleep finished
if (timeout != -1 && ret == 0){
// Update the remaining timeout
auto current_time = std::chrono::high_resolution_clock::now();
timeout_ms -= std::chrono::duration_cast<std::chrono::milliseconds>(current_time - start_time).count();
start_time = current_time;
if (timeout_ms <= 0) {
break;
}
}
Expand Down
Loading