Skip to content

Commit

Permalink
fix: Avoid racy cancellation of tasks
Browse files Browse the repository at this point in the history
Relying on the value behind `Arc::strong_count` to remain unchanged in
the Drop::drop impl is not a good idea. The docs even warn about this:

> Another thread can change the strong count at any time, including
> potentially between calling this method and acting on the result.

So solve this by making use of tokio's API for cancelling tasks.
  • Loading branch information
zeenix committed Jan 10, 2024
1 parent 1f05699 commit 76a1b86
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 78 deletions.
25 changes: 12 additions & 13 deletions extensions/warp-ipfs/src/store/document/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ use futures::{
};
use libipld::Cid;
use rust_ipfs::{Ipfs, IpfsPath};
use tokio::select;
use tokio_util::sync::{CancellationToken, DropGuard};
use warp::{crypto::DID, error::Error};

use super::identity::IdentityDocument;
Expand All @@ -36,15 +38,7 @@ enum IdentityCacheCommand {
#[derive(Debug, Clone)]
pub struct IdentityCache {
tx: Sender<IdentityCacheCommand>,
task: Arc<tokio::task::JoinHandle<()>>,
}

impl Drop for IdentityCache {
fn drop(&mut self) {
if Arc::strong_count(&self.task) == 1 && !self.task.is_finished() {
self.task.abort();
}
}
_task_cancellation: Arc<DropGuard>,
}

impl IdentityCache {
Expand All @@ -67,13 +61,18 @@ impl IdentityCache {
rx,
};

let handle = tokio::spawn(async move {
task.start().await;
let token = CancellationToken::new();
let drop_guard = token.clone().drop_guard();
tokio::spawn(async move {
select! {
_ = token.cancelled() => {}
_ = task.run() => {}
}
});

Self {
tx,
task: Arc::new(handle),
_task_cancellation: Arc::new(drop_guard),
}
}

Expand Down Expand Up @@ -146,7 +145,7 @@ struct IdentityCacheTask {
}

impl IdentityCacheTask {
pub async fn start(&mut self) {
pub async fn run(&mut self) {
// migrate old identity to new
self.migrate().await;
// repin map
Expand Down
25 changes: 12 additions & 13 deletions extensions/warp-ipfs/src/store/document/conversation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ use futures::{
};
use libipld::Cid;
use rust_ipfs::{Ipfs, IpfsPath};
use tokio::select;
use tokio_util::sync::{CancellationToken, DropGuard};
use tracing::warn;
use uuid::Uuid;
use warp::{
Expand Down Expand Up @@ -63,15 +65,7 @@ enum ConversationCommand {
#[derive(Debug, Clone)]
pub struct Conversations {
tx: mpsc::Sender<ConversationCommand>,
task: Arc<tokio::task::JoinHandle<()>>,
}

impl Drop for Conversations {
fn drop(&mut self) {
if Arc::strong_count(&self.task) == 1 && !self.task.is_finished() {
self.task.abort();
}
}
_task_cancellation: Arc<DropGuard>,
}

impl Conversations {
Expand Down Expand Up @@ -102,13 +96,18 @@ impl Conversations {
root,
};

let handle = tokio::spawn(async move {
task.start().await;
let token = CancellationToken::new();
let drop_guard = token.clone().drop_guard();
tokio::spawn(async move {
select! {
_ = token.cancelled() => {}
_ = task.run() => {}
}
});

Self {
tx,
task: Arc::new(handle),
_task_cancellation: Arc::new(drop_guard),
}
}

Expand Down Expand Up @@ -214,7 +213,7 @@ struct ConversationTask {
}

impl ConversationTask {
async fn start(&mut self) {
async fn run(&mut self) {
while let Some(command) = self.rx.next().await {
match command {
ConversationCommand::GetDocument { id, response } => {
Expand Down
25 changes: 12 additions & 13 deletions extensions/warp-ipfs/src/store/document/root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use futures::{
};
use libipld::Cid;
use rust_ipfs::{Ipfs, IpfsPath};
use tokio::select;
use tokio_util::sync::{CancellationToken, DropGuard};
use uuid::Uuid;
use warp::{crypto::DID, error::Error, multipass::identity::IdentityStatus};

Expand Down Expand Up @@ -96,15 +98,7 @@ pub enum RootDocumentCommand {
#[derive(Debug, Clone)]
pub struct RootDocumentMap {
tx: futures::channel::mpsc::Sender<RootDocumentCommand>,
task: Arc<tokio::task::JoinHandle<()>>,
}

impl Drop for RootDocumentMap {
fn drop(&mut self) {
if Arc::strong_count(&self.task) == 1 && !self.task.is_finished() {
self.task.abort();
}
}
_task_cancellation: Arc<DropGuard>,
}

impl RootDocumentMap {
Expand All @@ -128,13 +122,18 @@ impl RootDocumentMap {
rx,
};

let handle = tokio::spawn(async move {
task.start().await;
let token = CancellationToken::new();
let drop_guard = token.clone().drop_guard();
tokio::spawn(async move {
select! {
_ = token.cancelled() => {}
_ = task.run() => {}
}
});

Self {
tx,
task: Arc::new(handle),
_task_cancellation: Arc::new(drop_guard),
}
}

Expand Down Expand Up @@ -394,7 +393,7 @@ struct RootDocumentTask {
}

impl RootDocumentTask {
pub async fn start(&mut self) {
pub async fn run(&mut self) {
self.migrate().await;

while let Some(command) = self.rx.next().await {
Expand Down
25 changes: 12 additions & 13 deletions extensions/warp-ipfs/src/store/event_subscription.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use futures::{
use std::fmt::Debug;
use std::task::{Poll, Waker};
use std::{collections::VecDeque, sync::Arc};
use tokio::select;
use tokio_util::sync::{CancellationToken, DropGuard};
use warp::error::Error;

#[allow(clippy::large_enum_variant)]
Expand All @@ -24,7 +26,7 @@ enum Command<T: Clone + Debug + Send + 'static> {
#[derive(Clone, Debug)]
pub struct EventSubscription<T: Clone + Debug + Send + 'static> {
tx: Sender<Command<T>>,
task: Arc<tokio::task::JoinHandle<()>>,
_task_cancellation: Arc<DropGuard>,
}

impl<T: Clone + Debug + Send + 'static> EventSubscription<T> {
Expand All @@ -39,13 +41,18 @@ impl<T: Clone + Debug + Send + 'static> EventSubscription<T> {
rx,
};

let handle = tokio::spawn(async move {
task.start().await;
let token = CancellationToken::new();
let drop_guard = token.clone().drop_guard();
tokio::spawn(async move {
select! {
_ = token.cancelled() => {}
_ = task.run() => {}
}
});

Self {
tx,
task: Arc::new(handle),
_task_cancellation: Arc::new(drop_guard),
}
}

Expand All @@ -70,14 +77,6 @@ impl<T: Clone + Debug + Send + 'static> EventSubscription<T> {
}
}

impl<T: Clone + Debug + Send + 'static> Drop for EventSubscription<T> {
fn drop(&mut self) {
if Arc::strong_count(&self.task) == 1 && !self.task.is_finished() {
self.task.abort();
}
}
}

struct EventSubscriptionTask<T: Clone + Send + Debug + 'static> {
senders: Vec<Sender<T>>,
queue: VecDeque<T>,
Expand All @@ -86,7 +85,7 @@ struct EventSubscriptionTask<T: Clone + Send + Debug + 'static> {
}

impl<T: Clone + Send + 'static + Debug> EventSubscriptionTask<T> {
pub async fn start(&mut self) {
pub async fn run(&mut self) {
loop {
tokio::select! {
_ = futures::future::poll_fn(|cx| -> Poll<T> {
Expand Down
25 changes: 12 additions & 13 deletions tools/shuttle/src/store/identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ use futures::{
};
use libipld::Cid;
use rust_ipfs::{Ipfs, IpfsPath};
use tokio::select;
use tokio_util::sync::{CancellationToken, DropGuard};
use warp::{crypto::DID, error::Error};

use crate::identity::{document::IdentityDocument, protocol::Lookup, RequestPayload};
Expand Down Expand Up @@ -60,15 +62,7 @@ enum IdentityStorageCommand {
#[derive(Debug, Clone)]
pub struct IdentityStorage {
tx: Sender<IdentityStorageCommand>,
task: Arc<tokio::task::JoinHandle<()>>,
}

impl Drop for IdentityStorage {
fn drop(&mut self) {
if Arc::strong_count(&self.task) == 1 && !self.task.is_finished() {
self.task.abort();
}
}
_task_cancellation: Arc<DropGuard>,
}

impl IdentityStorage {
Expand All @@ -90,13 +84,18 @@ impl IdentityStorage {
rx,
};

let handle = tokio::spawn(async move {
task.start().await;
let token = CancellationToken::new();
let drop_guard = token.clone().drop_guard();
tokio::spawn(async move {
select! {
_ = token.cancelled() => {}
_ = task.run() => {}
}
});

Self {
tx,
task: Arc::new(handle),
_task_cancellation: Arc::new(drop_guard),
}
}

Expand Down Expand Up @@ -267,7 +266,7 @@ struct IdentityStorageTask {
}

impl IdentityStorageTask {
pub async fn start(&mut self) {
pub async fn run(&mut self) {
while let Some(command) = self.rx.next().await {
match command {
IdentityStorageCommand::Register { document, response } => {
Expand Down
25 changes: 12 additions & 13 deletions tools/shuttle/src/store/root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ use futures::{
use libipld::Cid;
use rust_ipfs::Ipfs;
use serde::{Deserialize, Serialize};
use tokio::select;
use tokio_util::sync::{CancellationToken, DropGuard};
use warp::error::Error;

#[derive(Default, Serialize, Deserialize, Clone, Copy, Debug)]
Expand Down Expand Up @@ -45,15 +47,7 @@ enum RootCommand {
#[derive(Debug, Clone)]
pub struct RootStorage {
tx: Sender<RootCommand>,
task: Arc<tokio::task::JoinHandle<()>>,
}

impl Drop for RootStorage {
fn drop(&mut self) {
if Arc::strong_count(&self.task) == 1 && !self.task.is_finished() {
self.task.abort();
}
}
_task_cancellation: Arc<DropGuard>,
}

impl RootStorage {
Expand Down Expand Up @@ -89,13 +83,18 @@ impl RootStorage {
rx,
};

let handle = tokio::spawn(async move {
task.start().await;
let token = CancellationToken::new();
let drop_guard = token.clone().drop_guard();
tokio::spawn(async move {
select! {
_ = token.cancelled() => {}
_ = task.run() => {}
}
});

Self {
tx,
task: Arc::new(handle),
_task_cancellation: Arc::new(drop_guard),
}
}

Expand Down Expand Up @@ -166,7 +165,7 @@ struct RootStorageTask {
}

impl RootStorageTask {
pub async fn start(&mut self) {
pub async fn run(&mut self) {
while let Some(command) = self.rx.next().await {
match command {
RootCommand::SetIdentityList { link, response } => {
Expand Down

0 comments on commit 76a1b86

Please sign in to comment.