WIP: fixing deadlocks
Replaced all structures in Server with thread-safe equivalents.
eaneto committed May 7, 2024
1 parent 3140b53 commit b77213f
Showing 1 changed file with 90 additions and 56 deletions.
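The change swaps the `HashMap`/`HashSet` fields, which can only be mutated through `&mut self`, for `crossbeam_skiplist`'s `SkipMap`/`SkipSet`, which accept inserts and removals through a shared reference. That removes a layer of locking around those maps and with it one source of lock-ordering deadlocks. A minimal sketch of the pattern, not code from this commit (the `sent_length` stand-in below is illustrative only):

use std::sync::Arc;

use crossbeam_skiplist::SkipMap;

fn main() {
    // Hypothetical stand-in for Server::sent_length.
    let sent_length: Arc<SkipMap<u64, u64>> = Arc::new(SkipMap::new());
    let handles: Vec<_> = (0..4u64)
        .map(|node_id| {
            let map = Arc::clone(&sent_length);
            std::thread::spawn(move || {
                // insert and get take &self, so no RwLock is needed around the map.
                map.insert(node_id, node_id * 10);
                assert_eq!(*map.get(&node_id).unwrap().value(), node_id * 10);
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
}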
146 changes: 90 additions & 56 deletions src/raft.rs
@@ -4,7 +4,7 @@ use std::{
     sync::{atomic::AtomicU64, Arc},
 };
 
-use crossbeam_skiplist::SkipSet;
+use crossbeam_skiplist::{SkipMap, SkipSet};
 use tracing::{debug, error, info, trace};
 
 use bytes::BytesMut;
@@ -44,31 +44,31 @@ enum State {
 /// leader -> follower
 
 pub struct Server {
-    id: u64,
+    id: NodeId,
     // Need to be stored on disk
     current_term: Arc<AtomicU64>,
-    voted_for: Arc<RwLock<Option<u64>>>,
-    // TODO: Maybe this could be a skiplist
-    log: Vec<LogEntry>,
+    voted_for: Arc<RwLock<Option<NodeId>>>,
+    log: RwLock<Vec<LogEntry>>,
     // Can be stored in-memory
     state: RwLock<State>,
     // commit_index
     commit_length: AtomicU64,
     election_timeout: u16,
     current_leader: AtomicU64,
-    votes_received: HashSet<u64>,
-    votes_received_list: SkipSet<u64>,
+    votes_received: SkipSet<NodeId>,
     // sent_index
-    sent_length: HashMap<u64, u64>,
+    sent_length: SkipMap<NodeId, u64>,
     // match_index
-    acked_length: HashMap<u64, u64>,
-    nodes: RwLock<HashMap<u64, Node>>,
+    acked_length: SkipMap<NodeId, u64>,
+    nodes: RwLock<HashMap<NodeId, Node>>,
     last_heartbeat: RwLock<Option<time::Instant>>,
 }
 
+type NodeId = u64;
+
 #[derive(Clone, Debug)]
 pub struct Node {
-    pub id: u64,
+    pub id: NodeId,
     pub address: String,
 }
 
@@ -78,24 +78,19 @@ impl Server {
             id,
             current_term: Arc::new(AtomicU64::new(0)),
             voted_for: Arc::new(RwLock::new(None)),
-            log: Vec::new(),
+            log: RwLock::new(Vec::new()),
             state: RwLock::new(State::Follower),
             commit_length: AtomicU64::new(0),
             election_timeout: rand::thread_rng().gen_range(150..300),
             current_leader: AtomicU64::new(0),
-            votes_received: HashSet::new(),
-            votes_received_list: SkipSet::new(),
-            sent_length: HashMap::new(),
-            acked_length: HashMap::new(),
+            votes_received: SkipSet::new(),
+            sent_length: SkipMap::new(),
+            acked_length: SkipMap::new(),
             nodes: RwLock::new(nodes.clone()),
             last_heartbeat: RwLock::new(None),
         }
     }
 
-    pub fn election_timeout(&self) -> u16 {
-        self.election_timeout
-    }
-
     fn current_term(&self) -> u64 {
         self.current_term
             .clone()
@@ -157,6 +152,32 @@ impl Server {
             .store(value, std::sync::atomic::Ordering::SeqCst);
     }
 
+    pub fn election_timeout(&self) -> u16 {
+        self.election_timeout
+    }
+
+    fn vote_on_self(&self) {
+        self.votes_received.insert(self.id);
+    }
+
+    fn unvote(&self) {
+        self.votes_received.remove(&self.id);
+    }
+
+    fn vote_on_new_leader(&self, node: u64) {
+        self.votes_received.insert(node);
+    }
+
+    fn update_sent_length(&self, node_id: NodeId, length: u64) {
+        self.sent_length.insert(node_id, length);
+    }
+
+    fn decrement_sent_length(&self, node_id: NodeId) {
+        let sent_length_by_node = self.sent_length.get(&node_id).unwrap();
+        self.sent_length
+            .insert(node_id, sent_length_by_node.value() - 1);
+    }
+
     pub async fn last_heartbeat(&self) -> Option<time::Instant> {
         *self.last_heartbeat.read().await
     }
@@ -166,6 +187,10 @@ impl Server {
         *last_heartbeat = Some(time::Instant::now());
     }
 
+    async fn log_length(&self) -> usize {
+        self.log.read().await.len()
+    }
+
     pub async fn start_election(&mut self) {
         if self.is_leader().await {
             return;
@@ -176,10 +201,10 @@
             self.become_candidate().await;
             let mut voted_for = self.voted_for.write().await;
             *voted_for = Some(self.id);
-            self.votes_received.insert(self.id);
+            self.vote_on_self();
         }
 
-        let last_term = self.last_term();
+        let last_term = self.last_term().await;
 
         let vote_request = VoteRequest {
             node_id: self.id,
@@ -207,7 +232,7 @@
             self.become_follower().await;
             let mut voted_for = self.voted_for.write().await;
             *voted_for = None;
-            self.votes_received.remove(&self.id);
+            self.unvote();
         }
     }
 
@@ -280,17 +305,17 @@
             && vote_response.term == self.current_term()
             && vote_response.vote_in_favor
         {
-            self.votes_received.insert(vote_response.node_id);
+            self.vote_on_new_leader(vote_response.node_id);
             trace!("Received vote in favor from {}", &vote_response.node_id);
-            if self.votes_received.len() >= (self.nodes.read().await.len() - 1) / 2 {
+            if self.has_majority_of_votes().await {
                 info!("Majority of votes in favor received, becoming leader");
                 self.become_leader().await;
                 self.current_leader
                     .store(self.id, std::sync::atomic::Ordering::SeqCst);
                 // TODO: Cancel election timer
                 let nodes = self.nodes.read().await.clone();
                 for (_, node) in &nodes {
-                    self.sent_length.insert(node.id, self.log.len() as u64);
+                    self.update_sent_length(node.id, self.log.read().await.len() as u64);
                     self.acked_length.insert(node.id, 0);
                     // FIXME
                     let _ = self.replicate_log(node).await;
@@ -300,17 +325,21 @@
         }
     }
 
+    async fn has_majority_of_votes(&self) -> bool {
+        self.votes_received.len() >= (self.nodes.read().await.len() - 1) / 2
+    }
+
     pub async fn receive_vote(&mut self, vote_request: VoteRequest) -> VoteResponse {
         if vote_request.current_term > self.current_term() {
             self.update_current_term(vote_request.current_term);
             self.become_follower().await;
             let mut voted_for = self.voted_for.write().await;
             *voted_for = None;
         }
-        let last_term = self.last_term();
+        let last_term = self.last_term().await;
         let ok = (vote_request.last_term > last_term)
             || (vote_request.last_term == last_term
-                && vote_request.log_length >= self.log.len() as u64);
+                && vote_request.log_length >= self.log.read().await.len() as u64);
         let mut voted_for = self.voted_for.write().await;
         let response = if vote_request.current_term == self.current_term()
             && ok
@@ -333,9 +362,10 @@
         response
     }
 
-    fn last_term(&self) -> u64 {
-        if self.log.len() > 0 {
-            self.log[self.log.len() - 1].term
+    async fn last_term(&self) -> u64 {
+        let log = self.log.read().await;
+        if log.len() > 0 {
+            log[log.len() - 1].term
         } else {
             0
         }
@@ -345,7 +375,7 @@
 
     pub async fn broadcast_message(&mut self, message: BytesMut) {
         if self.is_leader().await {
-            self.log.push(LogEntry {
+            self.log.write().await.push(LogEntry {
                 term: self.current_term(),
                 message,
             });
@@ -371,11 +401,11 @@
     // Can only be called by the leader
     async fn replicate_log(&self, node: &Node) -> Result<LogResponse, &str> {
         let request = match self.sent_length.get(&node.id) {
-            Some(length) => {
-                let prefix_length = *length as usize;
-                let suffix = &self.log[prefix_length..];
+            Some(entry) => {
+                let prefix_length = *entry.value() as usize;
+                let suffix = &self.log.read().await[prefix_length..];
                 let prefix_term = if prefix_length > 0 {
-                    self.log[prefix_length - 1].term
+                    self.log.read().await[prefix_length - 1].term
                 } else {
                     0
                 };
@@ -489,18 +519,19 @@
 
         if log_response.term == self.current_term() && self.is_leader().await {
             if log_response.successful
-                && &log_response.ack >= self.acked_length.get(&log_response.node_id).unwrap()
+                && &log_response.ack
+                    >= self
+                        .acked_length
+                        .get(&log_response.node_id)
+                        .unwrap()
+                        .value()
             {
-                self.sent_length
-                    .insert(log_response.node_id, log_response.ack);
+                self.update_sent_length(log_response.node_id, log_response.ack);
                 self.acked_length
                     .insert(log_response.node_id, log_response.ack);
                 self.commit_log_entries().await;
-            } else if *self.sent_length.get(&log_response.node_id).unwrap() > 0 {
-                self.sent_length.insert(
-                    log_response.node_id,
-                    self.sent_length.get(&log_response.node_id).unwrap() - 1,
-                );
+            } else if *self.sent_length.get(&log_response.node_id).unwrap().value() > 0 {
+                self.decrement_sent_length(log_response.node_id);
 
                 let nodes = self.nodes.read().await;
                 let node = nodes.get(&log_response.node_id).unwrap();
@@ -526,9 +557,10 @@
                 .store(log_request.leader_id, std::sync::atomic::Ordering::SeqCst);
         }
 
-        let ok = (self.log.len() >= log_request.prefix_length)
+        let ok = (self.log_length().await >= log_request.prefix_length)
             && (log_request.prefix_length == 0
-                || self.log[log_request.prefix_length - 1].term == log_request.prefix_term);
+                || self.log.read().await[log_request.prefix_length - 1].term
+                    == log_request.prefix_term);
         if log_request.term == self.current_term() && ok {
             let ack = log_request.prefix_length + log_request.suffix.len();
             self.send_append_entries(log_request).await;
@@ -553,12 +585,12 @@
     }
 
     async fn commit_log_entries(&mut self) {
-        while self.commit_length() < self.log.len() as u64 {
+        while self.commit_length() < self.log_length().await as u64 {
             let mut acks = 0;
             let nodes = self.nodes.read().await.clone();
             for (_, node) in &nodes {
-                let acked_length = self.acked_length.get(&node.id).unwrap();
-                if *acked_length > self.commit_length() {
+                let acked_length = *self.acked_length.get(&node.id).unwrap().value();
+                if acked_length > self.commit_length() {
                     acks += 1;
                 }
             }
@@ -574,22 +606,24 @@
     }
 
     async fn send_append_entries(&mut self, log_request: LogRequest) {
-        if log_request.suffix.len() > 0 && self.log.len() > log_request.prefix_length {
+        let mut log = self.log.write().await;
+        let log_length = log.len();
+        if log_request.suffix.len() > 0 && log_length > log_request.prefix_length {
             let index = cmp::min(
-                self.log.len(),
+                log_length,
                 log_request.prefix_length + log_request.suffix.len(),
             ) - 1;
             // Log is inconsistent
-            if self.log[index].term != log_request.suffix[index - log_request.prefix_length].term {
-                self.log = self.log[..log_request.prefix_length - 1].to_vec();
+            if log[index].term != log_request.suffix[index - log_request.prefix_length].term {
+                *log = log[..log_request.prefix_length - 1].to_vec();
             }
         }
 
-        if log_request.prefix_length + log_request.suffix.len() > self.log.len() {
-            let start = self.log.len() - log_request.prefix_length;
+        if log_request.prefix_length + log_request.suffix.len() > log_length {
+            let start = log_length - log_request.prefix_length;
             let end = log_request.suffix.len() - 1;
             for i in start..end {
-                self.log.push(log_request.suffix[i].clone());
+                log.push(log_request.suffix[i].clone());
             }
         }
 
