Skip to content

Commit

Permalink
vhost_task: Handle SIGKILL by flushing work and exiting
Browse files Browse the repository at this point in the history
Instead of lingering until the device is closed, this has us handle
SIGKILL by:

1. marking the worker as killed so we no longer try to use it with
   new virtqueues and new flush operations.
2. setting the virtqueue to worker mapping so no new works are queued.
3. running all the exiting works.

Suggested-by: Edward Adam Davis <[email protected]>
Reported-and-tested-by: [email protected]
Message-Id: <[email protected]>
Signed-off-by: Mike Christie <[email protected]>
Message-Id: <[email protected]>
Signed-off-by: Michael S. Tsirkin <[email protected]>
  • Loading branch information
mikechristie authored and mstsirkin committed May 22, 2024
1 parent ba704ff commit db5247d
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 24 deletions.
54 changes: 50 additions & 4 deletions drivers/vhost/vhost.c
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ static void __vhost_worker_flush(struct vhost_worker *worker)
{
struct vhost_flush_struct flush;

if (!worker->attachment_cnt)
if (!worker->attachment_cnt || worker->killed)
return;

init_completion(&flush.wait_event);
Expand Down Expand Up @@ -388,7 +388,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
__vhost_vq_meta_reset(vq);
}

static bool vhost_worker(void *data)
static bool vhost_run_work_list(void *data)
{
struct vhost_worker *worker = data;
struct vhost_work *work, *work_next;
Expand All @@ -413,6 +413,40 @@ static bool vhost_worker(void *data)
return !!node;
}

static void vhost_worker_killed(void *data)
{
struct vhost_worker *worker = data;
struct vhost_dev *dev = worker->dev;
struct vhost_virtqueue *vq;
int i, attach_cnt = 0;

mutex_lock(&worker->mutex);
worker->killed = true;

for (i = 0; i < dev->nvqs; i++) {
vq = dev->vqs[i];

mutex_lock(&vq->mutex);
if (worker ==
rcu_dereference_check(vq->worker,
lockdep_is_held(&vq->mutex))) {
rcu_assign_pointer(vq->worker, NULL);
attach_cnt++;
}
mutex_unlock(&vq->mutex);
}

worker->attachment_cnt -= attach_cnt;
if (attach_cnt)
synchronize_rcu();
/*
* Finish vhost_worker_flush calls and any other works that snuck in
* before the synchronize_rcu.
*/
vhost_run_work_list(worker);
mutex_unlock(&worker->mutex);
}

static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq)
{
kfree(vq->indirect);
Expand Down Expand Up @@ -627,9 +661,11 @@ static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
if (!worker)
return NULL;

worker->dev = dev;
snprintf(name, sizeof(name), "vhost-%d", current->pid);

vtsk = vhost_task_create(vhost_worker, worker, name);
vtsk = vhost_task_create(vhost_run_work_list, vhost_worker_killed,
worker, name);
if (!vtsk)
goto free_worker;

Expand Down Expand Up @@ -661,6 +697,11 @@ static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq,
struct vhost_worker *old_worker;

mutex_lock(&worker->mutex);
if (worker->killed) {
mutex_unlock(&worker->mutex);
return;
}

mutex_lock(&vq->mutex);

old_worker = rcu_dereference_check(vq->worker,
Expand All @@ -681,6 +722,11 @@ static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq,
* device wide flushes which doesn't use RCU for execution.
*/
mutex_lock(&old_worker->mutex);
if (old_worker->killed) {
mutex_unlock(&old_worker->mutex);
return;
}

/*
* We don't want to call synchronize_rcu for every vq during setup
* because it will slow down VM startup. If we haven't done
Expand Down Expand Up @@ -758,7 +804,7 @@ static int vhost_free_worker(struct vhost_dev *dev,
return -ENODEV;

mutex_lock(&worker->mutex);
if (worker->attachment_cnt) {
if (worker->attachment_cnt || worker->killed) {
mutex_unlock(&worker->mutex);
return -EBUSY;
}
Expand Down
2 changes: 2 additions & 0 deletions drivers/vhost/vhost.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ struct vhost_work {

struct vhost_worker {
struct vhost_task *vtsk;
struct vhost_dev *dev;
/* Used to serialize device wide flushing with worker swapping. */
struct mutex mutex;
struct llist_head work_list;
u64 kcov_handle;
u32 id;
int attachment_cnt;
bool killed;
};

/* Poll a file (eventfd or socket) */
Expand Down
3 changes: 2 additions & 1 deletion include/linux/sched/vhost_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

struct vhost_task;

struct vhost_task *vhost_task_create(bool (*fn)(void *), void *arg,
struct vhost_task *vhost_task_create(bool (*fn)(void *),
void (*handle_kill)(void *), void *arg,
const char *name);
void vhost_task_start(struct vhost_task *vtsk);
void vhost_task_stop(struct vhost_task *vtsk);
Expand Down
53 changes: 34 additions & 19 deletions kernel/vhost_task.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,32 @@

enum vhost_task_flags {
VHOST_TASK_FLAGS_STOP,
VHOST_TASK_FLAGS_KILLED,
};

struct vhost_task {
bool (*fn)(void *data);
void (*handle_sigkill)(void *data);
void *data;
struct completion exited;
unsigned long flags;
struct task_struct *task;
/* serialize SIGKILL and vhost_task_stop calls */
struct mutex exit_mutex;
};

static int vhost_task_fn(void *data)
{
struct vhost_task *vtsk = data;
bool dead = false;

for (;;) {
bool did_work;

if (!dead && signal_pending(current)) {
if (signal_pending(current)) {
struct ksignal ksig;
/*
* Calling get_signal will block in SIGSTOP,
* or clear fatal_signal_pending, but remember
* what was set.
*
* This thread won't actually exit until all
* of the file descriptors are closed, and
* the release function is called.
*/
dead = get_signal(&ksig);
if (dead)
clear_thread_flag(TIF_SIGPENDING);

if (get_signal(&ksig))
break;
}

/* mb paired w/ vhost_task_stop */
Expand All @@ -57,7 +51,19 @@ static int vhost_task_fn(void *data)
schedule();
}

mutex_lock(&vtsk->exit_mutex);
/*
* If a vhost_task_stop and SIGKILL race, we can ignore the SIGKILL.
* When the vhost layer has called vhost_task_stop it's already stopped
* new work and flushed.
*/
if (!test_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags)) {
set_bit(VHOST_TASK_FLAGS_KILLED, &vtsk->flags);
vtsk->handle_sigkill(vtsk->data);
}
mutex_unlock(&vtsk->exit_mutex);
complete(&vtsk->exited);

do_exit(0);
}

Expand All @@ -78,12 +84,17 @@ EXPORT_SYMBOL_GPL(vhost_task_wake);
* @vtsk: vhost_task to stop
*
* vhost_task_fn ensures the worker thread exits after
* VHOST_TASK_FLAGS_SOP becomes true.
* VHOST_TASK_FLAGS_STOP becomes true.
*/
void vhost_task_stop(struct vhost_task *vtsk)
{
set_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags);
vhost_task_wake(vtsk);
mutex_lock(&vtsk->exit_mutex);
if (!test_bit(VHOST_TASK_FLAGS_KILLED, &vtsk->flags)) {
set_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags);
vhost_task_wake(vtsk);
}
mutex_unlock(&vtsk->exit_mutex);

/*
* Make sure vhost_task_fn is no longer accessing the vhost_task before
* freeing it below.
Expand All @@ -96,14 +107,16 @@ EXPORT_SYMBOL_GPL(vhost_task_stop);
/**
* vhost_task_create - create a copy of a task to be used by the kernel
* @fn: vhost worker function
* @arg: data to be passed to fn
* @handle_sigkill: vhost function to handle when we are killed
* @arg: data to be passed to fn and handled_kill
* @name: the thread's name
*
* This returns a specialized task for use by the vhost layer or NULL on
* failure. The returned task is inactive, and the caller must fire it up
* through vhost_task_start().
*/
struct vhost_task *vhost_task_create(bool (*fn)(void *), void *arg,
struct vhost_task *vhost_task_create(bool (*fn)(void *),
void (*handle_sigkill)(void *), void *arg,
const char *name)
{
struct kernel_clone_args args = {
Expand All @@ -122,8 +135,10 @@ struct vhost_task *vhost_task_create(bool (*fn)(void *), void *arg,
if (!vtsk)
return NULL;
init_completion(&vtsk->exited);
mutex_init(&vtsk->exit_mutex);
vtsk->data = arg;
vtsk->fn = fn;
vtsk->handle_sigkill = handle_sigkill;

args.fn_arg = vtsk;

Expand Down

0 comments on commit db5247d

Please sign in to comment.