1 // SPDX-License-Identifier: GPL-2.0-only 2 /* 3 * Copyright (C) 2021 Oracle Corporation 4 */ 5 #include <linux/slab.h> 6 #include <linux/completion.h> 7 #include <linux/sched/task.h> 8 #include <linux/sched/vhost_task.h> 9 #include <linux/sched/signal.h> 10 11 enum vhost_task_flags { 12 VHOST_TASK_FLAGS_STOP, 13 VHOST_TASK_FLAGS_KILLED, 14 }; 15 16 struct vhost_task { 17 bool (*fn)(void *data); 18 void (*handle_sigkill)(void *data); 19 void *data; 20 struct completion exited; 21 unsigned long flags; 22 struct task_struct *task; 23 /* serialize SIGKILL and vhost_task_stop calls */ 24 struct mutex exit_mutex; 25 }; 26 27 static int vhost_task_fn(void *data) 28 { 29 struct vhost_task *vtsk = data; 30 31 for (;;) { 32 bool did_work; 33 34 if (signal_pending(current)) { 35 struct ksignal ksig; 36 37 if (get_signal(&ksig)) 38 break; 39 } 40 41 /* mb paired w/ vhost_task_stop */ 42 set_current_state(TASK_INTERRUPTIBLE); 43 44 if (test_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags)) { 45 __set_current_state(TASK_RUNNING); 46 break; 47 } 48 49 did_work = vtsk->fn(vtsk->data); 50 if (!did_work) 51 schedule(); 52 } 53 54 mutex_lock(&vtsk->exit_mutex); 55 /* 56 * If a vhost_task_stop and SIGKILL race, we can ignore the SIGKILL. 57 * When the vhost layer has called vhost_task_stop it's already stopped 58 * new work and flushed. 59 */ 60 if (!test_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags)) { 61 set_bit(VHOST_TASK_FLAGS_KILLED, &vtsk->flags); 62 vtsk->handle_sigkill(vtsk->data); 63 } 64 mutex_unlock(&vtsk->exit_mutex); 65 complete(&vtsk->exited); 66 67 do_exit(0); 68 } 69 70 /** 71 * vhost_task_wake - wakeup the vhost_task 72 * @vtsk: vhost_task to wake 73 * 74 * wake up the vhost_task worker thread 75 */ 76 void vhost_task_wake(struct vhost_task *vtsk) 77 { 78 wake_up_process(vtsk->task); 79 } 80 EXPORT_SYMBOL_GPL(vhost_task_wake); 81 82 /** 83 * vhost_task_stop - stop a vhost_task 84 * @vtsk: vhost_task to stop 85 * 86 * vhost_task_fn ensures the worker thread exits after 87 * VHOST_TASK_FLAGS_STOP becomes true. 88 */ 89 void vhost_task_stop(struct vhost_task *vtsk) 90 { 91 mutex_lock(&vtsk->exit_mutex); 92 if (!test_bit(VHOST_TASK_FLAGS_KILLED, &vtsk->flags)) { 93 set_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags); 94 vhost_task_wake(vtsk); 95 } 96 mutex_unlock(&vtsk->exit_mutex); 97 98 /* 99 * Make sure vhost_task_fn is no longer accessing the vhost_task before 100 * freeing it below. 101 */ 102 wait_for_completion(&vtsk->exited); 103 kfree(vtsk); 104 } 105 EXPORT_SYMBOL_GPL(vhost_task_stop); 106 107 /** 108 * vhost_task_create - create a copy of a task to be used by the kernel 109 * @fn: vhost worker function 110 * @handle_sigkill: vhost function to handle when we are killed 111 * @arg: data to be passed to fn and handled_kill 112 * @name: the thread's name 113 * 114 * This returns a specialized task for use by the vhost layer or NULL on 115 * failure. The returned task is inactive, and the caller must fire it up 116 * through vhost_task_start(). 117 */ 118 struct vhost_task *vhost_task_create(bool (*fn)(void *), 119 void (*handle_sigkill)(void *), void *arg, 120 const char *name) 121 { 122 struct kernel_clone_args args = { 123 .flags = CLONE_FS | CLONE_UNTRACED | CLONE_VM | 124 CLONE_THREAD | CLONE_SIGHAND, 125 .exit_signal = 0, 126 .fn = vhost_task_fn, 127 .name = name, 128 .user_worker = 1, 129 .no_files = 1, 130 }; 131 struct vhost_task *vtsk; 132 struct task_struct *tsk; 133 134 vtsk = kzalloc(sizeof(*vtsk), GFP_KERNEL); 135 if (!vtsk) 136 return NULL; 137 init_completion(&vtsk->exited); 138 mutex_init(&vtsk->exit_mutex); 139 vtsk->data = arg; 140 vtsk->fn = fn; 141 vtsk->handle_sigkill = handle_sigkill; 142 143 args.fn_arg = vtsk; 144 145 tsk = copy_process(NULL, 0, NUMA_NO_NODE, &args); 146 if (IS_ERR(tsk)) { 147 kfree(vtsk); 148 return NULL; 149 } 150 151 vtsk->task = tsk; 152 return vtsk; 153 } 154 EXPORT_SYMBOL_GPL(vhost_task_create); 155 156 /** 157 * vhost_task_start - start a vhost_task created with vhost_task_create 158 * @vtsk: vhost_task to wake up 159 */ 160 void vhost_task_start(struct vhost_task *vtsk) 161 { 162 wake_up_new_task(vtsk->task); 163 } 164 EXPORT_SYMBOL_GPL(vhost_task_start); 165