xref: /openbmc/linux/kernel/task_work.c (revision 158e1645e07f3e9f7e4962d7a0997f5c3b98311b)
1 #include <linux/spinlock.h>
2 #include <linux/task_work.h>
3 #include <linux/tracehook.h>
4 
5 int
6 task_work_add(struct task_struct *task, struct task_work *twork, bool notify)
7 {
8 	unsigned long flags;
9 	int err = -ESRCH;
10 
11 #ifndef TIF_NOTIFY_RESUME
12 	if (notify)
13 		return -ENOTSUPP;
14 #endif
15 	/*
16 	 * We must not insert the new work if the task has already passed
17 	 * exit_task_work(). We rely on do_exit()->raw_spin_unlock_wait()
18 	 * and check PF_EXITING under pi_lock.
19 	 */
20 	raw_spin_lock_irqsave(&task->pi_lock, flags);
21 	if (likely(!(task->flags & PF_EXITING))) {
22 		struct task_work *last = task->task_works;
23 		struct task_work *first = last ? last->next : twork;
24 		twork->next = first;
25 		if (last)
26 			last->next = twork;
27 		task->task_works = twork;
28 		err = 0;
29 	}
30 	raw_spin_unlock_irqrestore(&task->pi_lock, flags);
31 
32 	/* test_and_set_bit() implies mb(), see tracehook_notify_resume(). */
33 	if (likely(!err) && notify)
34 		set_notify_resume(task);
35 	return err;
36 }
37 
38 struct task_work *
39 task_work_cancel(struct task_struct *task, task_work_func_t func)
40 {
41 	unsigned long flags;
42 	struct task_work *last, *res = NULL;
43 
44 	raw_spin_lock_irqsave(&task->pi_lock, flags);
45 	last = task->task_works;
46 	if (last) {
47 		struct task_work *q = last, *p = q->next;
48 		while (1) {
49 			if (p->func == func) {
50 				q->next = p->next;
51 				if (p == last)
52 					task->task_works = q == p ? NULL : q;
53 				res = p;
54 				break;
55 			}
56 			if (p == last)
57 				break;
58 			q = p;
59 			p = q->next;
60 		}
61 	}
62 	raw_spin_unlock_irqrestore(&task->pi_lock, flags);
63 	return res;
64 }
65 
66 void task_work_run(void)
67 {
68 	struct task_struct *task = current;
69 	struct task_work *p, *q;
70 
71 	raw_spin_lock_irq(&task->pi_lock);
72 	p = task->task_works;
73 	task->task_works = NULL;
74 	raw_spin_unlock_irq(&task->pi_lock);
75 
76 	if (unlikely(!p))
77 		return;
78 
79 	q = p->next; /* head */
80 	p->next = NULL; /* cut it */
81 	while (q) {
82 		p = q->next;
83 		q->func(q);
84 		q = p;
85 	}
86 }
87