// SPDX-License-Identifier: GPL-2.0-or-later
/*
 * Module-based API test facility for ww_mutexes
 */

#include <linux/kernel.h>

#include <linux/completion.h>
#include <linux/delay.h>
#include <linux/kthread.h>
#include <linux/module.h>
#include <linux/random.h>
#include <linux/slab.h>
#include <linux/ww_mutex.h>

static DEFINE_WD_CLASS(ww_class);
struct workqueue_struct *wq;

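/*
 * With CONFIG_DEBUG_WW_MUTEX_SLOWPATH the ww_mutex code artificially
 * injects -EDEADLK after a number of lock acquisitions per context.
 * Tests that must observe only genuine deadlocks use this variant to
 * push the injection countdown out of reach.
 */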
#ifdef CONFIG_DEBUG_WW_MUTEX_SLOWPATH
#define ww_acquire_init_noinject(a, b) do { \
		ww_acquire_init((a), (b)); \
		(a)->deadlock_inject_countdown = ~0U; \
	} while (0)
#else
#define ww_acquire_init_noinject(a, b) ww_acquire_init((a), (b))
#endif

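/*
 * Basic mutual exclusion: a worker is released while the main thread holds
 * the lock and tries to take it via ww_mutex_lock() or a ww_mutex_trylock()
 * spin, depending on the TEST_MTX_* flags.  The worker must not be able to
 * signal completion until the holder drops the lock.
 */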
struct test_mutex {
	struct work_struct work;
	struct ww_mutex mutex;
	struct completion ready, go, done;
	unsigned int flags;
};

#define TEST_MTX_SPIN BIT(0)
#define TEST_MTX_TRY BIT(1)
#define TEST_MTX_CTX BIT(2)
#define __TEST_MTX_LAST BIT(3)

static void test_mutex_work(struct work_struct *work)
{
	struct test_mutex *mtx = container_of(work, typeof(*mtx), work);

	complete(&mtx->ready);
	wait_for_completion(&mtx->go);

	if (mtx->flags & TEST_MTX_TRY) {
		while (!ww_mutex_trylock(&mtx->mutex, NULL))
			cond_resched();
	} else {
		ww_mutex_lock(&mtx->mutex, NULL);
	}
	complete(&mtx->done);
	ww_mutex_unlock(&mtx->mutex);
}

static int __test_mutex(unsigned int flags)
{
#define TIMEOUT (HZ / 16)
	struct test_mutex mtx;
	struct ww_acquire_ctx ctx;
	int ret;

	ww_mutex_init(&mtx.mutex, &ww_class);
	ww_acquire_init(&ctx, &ww_class);

	INIT_WORK_ONSTACK(&mtx.work, test_mutex_work);
	init_completion(&mtx.ready);
	init_completion(&mtx.go);
	init_completion(&mtx.done);
	mtx.flags = flags;

	schedule_work(&mtx.work);

	wait_for_completion(&mtx.ready);
	ww_mutex_lock(&mtx.mutex, (flags & TEST_MTX_CTX) ? &ctx : NULL);
	complete(&mtx.go);
	if (flags & TEST_MTX_SPIN) {
		unsigned long timeout = jiffies + TIMEOUT;

		ret = 0;
		do {
			if (completion_done(&mtx.done)) {
				ret = -EINVAL;
				break;
			}
			cond_resched();
		} while (time_before(jiffies, timeout));
	} else {
		ret = wait_for_completion_timeout(&mtx.done, TIMEOUT);
	}
	ww_mutex_unlock(&mtx.mutex);
	ww_acquire_fini(&ctx);

	if (ret) {
		pr_err("%s(flags=%x): mutual exclusion failure\n",
		       __func__, flags);
		ret = -EINVAL;
	}

	flush_work(&mtx.work);
	destroy_work_on_stack(&mtx.work);
	return ret;
#undef TIMEOUT
}

static int test_mutex(void)
{
	int ret;
	int i;

	for (i = 0; i < __TEST_MTX_LAST; i++) {
		ret = __test_mutex(i);
		if (ret)
			return ret;
	}

	return 0;
}

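/*
 * AA recursion: acquiring the same ww_mutex twice from one acquire context
 * must fail.  A second trylock has to return false, and a second
 * ww_mutex_lock() has to return -EALREADY rather than deadlock.
 */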
static int test_aa(bool trylock)
{
	struct ww_mutex mutex;
	struct ww_acquire_ctx ctx;
	int ret;
	const char *from = trylock ? "trylock" : "lock";

	ww_mutex_init(&mutex, &ww_class);
	ww_acquire_init(&ctx, &ww_class);

	if (!trylock) {
		ret = ww_mutex_lock(&mutex, &ctx);
		if (ret) {
			pr_err("%s: initial lock failed!\n", __func__);
			goto out;
		}
	} else {
		if (!ww_mutex_trylock(&mutex, &ctx)) {
			pr_err("%s: initial trylock failed!\n", __func__);
			ret = -EINVAL;
			goto out;
		}
	}

	if (ww_mutex_trylock(&mutex, NULL)) {
		pr_err("%s: trylocked itself without context from %s!\n", __func__, from);
		ww_mutex_unlock(&mutex);
		ret = -EINVAL;
		goto out;
	}

	if (ww_mutex_trylock(&mutex, &ctx)) {
		pr_err("%s: trylocked itself with context from %s!\n", __func__, from);
		ww_mutex_unlock(&mutex);
		ret = -EINVAL;
		goto out;
	}

	ret = ww_mutex_lock(&mutex, &ctx);
	if (ret != -EALREADY) {
		pr_err("%s: missed deadlock for recursing, ret=%d from %s\n",
		       __func__, ret, from);
		if (!ret)
			ww_mutex_unlock(&mutex);
		ret = -EINVAL;
		goto out;
	}

	ww_mutex_unlock(&mutex);
	ret = 0;
out:
	ww_acquire_fini(&ctx);
	return ret;
}

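/*
 * ABBA: the main thread takes a_mutex then b_mutex while the worker takes
 * them in the opposite order.  Without 'resolve', at least one side has to
 * report -EDEADLK.  With 'resolve', the losing side drops its first lock,
 * reacquires the contended one with ww_mutex_lock_slow() and retries, and
 * both sides must then succeed.
 */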
struct test_abba {
	struct work_struct work;
	struct ww_mutex a_mutex;
	struct ww_mutex b_mutex;
	struct completion a_ready;
	struct completion b_ready;
	bool resolve, trylock;
	int result;
};

static void test_abba_work(struct work_struct *work)
{
	struct test_abba *abba = container_of(work, typeof(*abba), work);
	struct ww_acquire_ctx ctx;
	int err;

	ww_acquire_init_noinject(&ctx, &ww_class);
	if (!abba->trylock)
		ww_mutex_lock(&abba->b_mutex, &ctx);
	else
		WARN_ON(!ww_mutex_trylock(&abba->b_mutex, &ctx));

	WARN_ON(READ_ONCE(abba->b_mutex.ctx) != &ctx);

	complete(&abba->b_ready);
	wait_for_completion(&abba->a_ready);

	err = ww_mutex_lock(&abba->a_mutex, &ctx);
	if (abba->resolve && err == -EDEADLK) {
		ww_mutex_unlock(&abba->b_mutex);
		ww_mutex_lock_slow(&abba->a_mutex, &ctx);
		err = ww_mutex_lock(&abba->b_mutex, &ctx);
	}

	if (!err)
		ww_mutex_unlock(&abba->a_mutex);
	ww_mutex_unlock(&abba->b_mutex);
	ww_acquire_fini(&ctx);

	abba->result = err;
}

static int test_abba(bool trylock, bool resolve)
{
	struct test_abba abba;
	struct ww_acquire_ctx ctx;
	int err, ret;

	ww_mutex_init(&abba.a_mutex, &ww_class);
	ww_mutex_init(&abba.b_mutex, &ww_class);
	INIT_WORK_ONSTACK(&abba.work, test_abba_work);
	init_completion(&abba.a_ready);
	init_completion(&abba.b_ready);
	abba.trylock = trylock;
	abba.resolve = resolve;

	schedule_work(&abba.work);

	ww_acquire_init_noinject(&ctx, &ww_class);
	if (!trylock)
		ww_mutex_lock(&abba.a_mutex, &ctx);
	else
		WARN_ON(!ww_mutex_trylock(&abba.a_mutex, &ctx));

	WARN_ON(READ_ONCE(abba.a_mutex.ctx) != &ctx);

	complete(&abba.a_ready);
	wait_for_completion(&abba.b_ready);

	err = ww_mutex_lock(&abba.b_mutex, &ctx);
	if (resolve && err == -EDEADLK) {
		ww_mutex_unlock(&abba.a_mutex);
		ww_mutex_lock_slow(&abba.b_mutex, &ctx);
		err = ww_mutex_lock(&abba.a_mutex, &ctx);
	}

	if (!err)
		ww_mutex_unlock(&abba.b_mutex);
	ww_mutex_unlock(&abba.a_mutex);
	ww_acquire_fini(&ctx);

	flush_work(&abba.work);
	destroy_work_on_stack(&abba.work);

	ret = 0;
	if (resolve) {
		if (err || abba.result) {
			pr_err("%s: failed to resolve ABBA deadlock, A err=%d, B err=%d\n",
			       __func__, err, abba.result);
			ret = -EINVAL;
		}
	} else {
		if (err != -EDEADLK && abba.result != -EDEADLK) {
			pr_err("%s: missed ABBA deadlock, A err=%d, B err=%d\n",
			       __func__, err, abba.result);
			ret = -EINVAL;
		}
	}
	return ret;
}

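/*
 * Cyclic deadlock: N workers each hold their own a_mutex and then try to take
 * the next worker's lock, closing the ring.  Every worker that sees -EDEADLK
 * releases its lock, reacquires the contended one with ww_mutex_lock_slow()
 * and retries, so all of them must complete without error.
 */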
struct test_cycle {
	struct work_struct work;
	struct ww_mutex a_mutex;
	struct ww_mutex *b_mutex;
	struct completion *a_signal;
	struct completion b_signal;
	int result;
};

static void test_cycle_work(struct work_struct *work)
{
	struct test_cycle *cycle = container_of(work, typeof(*cycle), work);
	struct ww_acquire_ctx ctx;
	int err, erra = 0;

	ww_acquire_init_noinject(&ctx, &ww_class);
	ww_mutex_lock(&cycle->a_mutex, &ctx);

	complete(cycle->a_signal);
	wait_for_completion(&cycle->b_signal);

	err = ww_mutex_lock(cycle->b_mutex, &ctx);
	if (err == -EDEADLK) {
		err = 0;
		ww_mutex_unlock(&cycle->a_mutex);
		ww_mutex_lock_slow(cycle->b_mutex, &ctx);
		erra = ww_mutex_lock(&cycle->a_mutex, &ctx);
	}

	if (!err)
		ww_mutex_unlock(cycle->b_mutex);
	if (!erra)
		ww_mutex_unlock(&cycle->a_mutex);
	ww_acquire_fini(&ctx);

	cycle->result = err ?: erra;
}

static int __test_cycle(unsigned int nthreads)
{
	struct test_cycle *cycles;
	unsigned int n, last = nthreads - 1;
	int ret;

	cycles = kmalloc_array(nthreads, sizeof(*cycles), GFP_KERNEL);
	if (!cycles)
		return -ENOMEM;

	for (n = 0; n < nthreads; n++) {
		struct test_cycle *cycle = &cycles[n];

		ww_mutex_init(&cycle->a_mutex, &ww_class);
		if (n == last)
			cycle->b_mutex = &cycles[0].a_mutex;
		else
			cycle->b_mutex = &cycles[n + 1].a_mutex;

		if (n == 0)
			cycle->a_signal = &cycles[last].b_signal;
		else
			cycle->a_signal = &cycles[n - 1].b_signal;
		init_completion(&cycle->b_signal);

		INIT_WORK(&cycle->work, test_cycle_work);
		cycle->result = 0;
	}

	for (n = 0; n < nthreads; n++)
		queue_work(wq, &cycles[n].work);

	flush_workqueue(wq);

	ret = 0;
	for (n = 0; n < nthreads; n++) {
		struct test_cycle *cycle = &cycles[n];

		if (!cycle->result)
			continue;

		pr_err("cyclic deadlock not resolved, ret[%d/%d] = %d\n",
		       n, nthreads, cycle->result);
		ret = -EINVAL;
		break;
	}

	for (n = 0; n < nthreads; n++)
		ww_mutex_destroy(&cycles[n].a_mutex);
	kfree(cycles);
	return ret;
}

static int test_cycle(unsigned int ncpus)
{
	unsigned int n;
	int ret;

	for (n = 2; n <= ncpus + 1; n++) {
		ret = __test_cycle(n);
		if (ret)
			return ret;
	}

	return 0;
}

struct stress {
	struct work_struct work;
	struct ww_mutex *locks;
	unsigned long timeout;
	int nlocks;
};

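/*
 * Return a freshly allocated random permutation of the indices 0..count-1
 * (a Fisher-Yates-style shuffle), used to pick the lock order for the
 * stress workers.  Returns NULL on allocation failure.
 */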
static int *get_random_order(int count)
{
	int *order;
	int n, r, tmp;

	order = kmalloc_array(count, sizeof(*order), GFP_KERNEL);
	if (!order)
		return order;

	for (n = 0; n < count; n++)
		order[n] = n;

	for (n = count - 1; n > 1; n--) {
		r = get_random_int() % (n + 1);
		if (r != n) {
			tmp = order[n];
			order[n] = order[r];
			order[r] = tmp;
		}
	}

	return order;
}

static void dummy_load(struct stress *stress)
{
	usleep_range(1000, 2000);
}

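/*
 * In-order stress: grab every lock in one fixed random permutation under a
 * single acquire context.  On -EDEADLK, drop everything, take the contended
 * lock with ww_mutex_lock_slow() and start over, skipping it the second time
 * around.  Repeat until the timeout expires.
 */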
static void stress_inorder_work(struct work_struct *work)
{
	struct stress *stress = container_of(work, typeof(*stress), work);
	const int nlocks = stress->nlocks;
	struct ww_mutex *locks = stress->locks;
	struct ww_acquire_ctx ctx;
	int *order;

	order = get_random_order(nlocks);
	if (!order)
		return;

	do {
		int contended = -1;
		int n, err;

		ww_acquire_init(&ctx, &ww_class);
retry:
		err = 0;
		for (n = 0; n < nlocks; n++) {
			if (n == contended)
				continue;

			err = ww_mutex_lock(&locks[order[n]], &ctx);
			if (err < 0)
				break;
		}
		if (!err)
			dummy_load(stress);

		if (contended > n)
			ww_mutex_unlock(&locks[order[contended]]);
		contended = n;
		while (n--)
			ww_mutex_unlock(&locks[order[n]]);

		if (err == -EDEADLK) {
			ww_mutex_lock_slow(&locks[order[contended]], &ctx);
			goto retry;
		}

		if (err) {
			pr_err_once("stress (%s) failed with %d\n",
				    __func__, err);
			break;
		}

		ww_acquire_fini(&ctx);
	} while (!time_after(jiffies, stress->timeout));

	kfree(order);
	kfree(stress);
}

struct reorder_lock {
	struct list_head link;
	struct ww_mutex *lock;
};

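/*
 * Reorder stress: keep the locks on a list and take them in list order under
 * one acquire context.  On -EDEADLK, release the locks already held, take the
 * contended lock with ww_mutex_lock_slow() and move it to the head of the
 * list, which effectively restarts the list walk.  Repeat until the timeout
 * expires.
 */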
static void stress_reorder_work(struct work_struct *work)
{
	struct stress *stress = container_of(work, typeof(*stress), work);
	LIST_HEAD(locks);
	struct ww_acquire_ctx ctx;
	struct reorder_lock *ll, *ln;
	int *order;
	int n, err;

	order = get_random_order(stress->nlocks);
	if (!order)
		return;

	for (n = 0; n < stress->nlocks; n++) {
		ll = kmalloc(sizeof(*ll), GFP_KERNEL);
		if (!ll)
			goto out;

		ll->lock = &stress->locks[order[n]];
		list_add(&ll->link, &locks);
	}
	kfree(order);
	order = NULL;

	do {
		ww_acquire_init(&ctx, &ww_class);

		list_for_each_entry(ll, &locks, link) {
			err = ww_mutex_lock(ll->lock, &ctx);
			if (!err)
				continue;

			ln = ll;
			list_for_each_entry_continue_reverse(ln, &locks, link)
				ww_mutex_unlock(ln->lock);

			if (err != -EDEADLK) {
				pr_err_once("stress (%s) failed with %d\n",
					    __func__, err);
				break;
			}

			ww_mutex_lock_slow(ll->lock, &ctx);
			list_move(&ll->link, &locks); /* restarts iteration */
		}

		dummy_load(stress);
		list_for_each_entry(ll, &locks, link)
			ww_mutex_unlock(ll->lock);

		ww_acquire_fini(&ctx);
	} while (!time_after(jiffies, stress->timeout));

out:
	list_for_each_entry_safe(ll, ln, &locks, link)
		kfree(ll);
	kfree(order);
	kfree(stress);
}

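/*
 * Single-lock stress: repeatedly take and release one randomly chosen lock
 * without an acquire context, mixing plain lock/unlock traffic in with the
 * context-based workers above.
 */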
static void stress_one_work(struct work_struct *work)
{
	struct stress *stress = container_of(work, typeof(*stress), work);
	const int nlocks = stress->nlocks;
	struct ww_mutex *lock = stress->locks + (get_random_int() % nlocks);
	int err;

	do {
		err = ww_mutex_lock(lock, NULL);
		if (!err) {
			dummy_load(stress);
			ww_mutex_unlock(lock);
		} else {
			pr_err_once("stress (%s) failed with %d\n",
				    __func__, err);
			break;
		}
	} while (!time_after(jiffies, stress->timeout));

	kfree(stress);
}

#define STRESS_INORDER BIT(0)
#define STRESS_REORDER BIT(1)
#define STRESS_ONE BIT(2)
#define STRESS_ALL (STRESS_INORDER | STRESS_REORDER | STRESS_ONE)

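/*
 * Spawn nthreads stress workers over a shared array of nlocks ww_mutexes,
 * rotating through the variants selected in 'flags', and wait for all of
 * them to run to their timeout.  Each worker frees its own struct stress.
 */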
static int stress(int nlocks, int nthreads, unsigned int flags)
{
	struct ww_mutex *locks;
	int n;

	locks = kmalloc_array(nlocks, sizeof(*locks), GFP_KERNEL);
	if (!locks)
		return -ENOMEM;

	for (n = 0; n < nlocks; n++)
		ww_mutex_init(&locks[n], &ww_class);

	for (n = 0; nthreads; n++) {
		struct stress *stress;
		void (*fn)(struct work_struct *work);

		fn = NULL;
		switch (n & 3) {
		case 0:
			if (flags & STRESS_INORDER)
				fn = stress_inorder_work;
			break;
		case 1:
			if (flags & STRESS_REORDER)
				fn = stress_reorder_work;
			break;
		case 2:
			if (flags & STRESS_ONE)
				fn = stress_one_work;
			break;
		}

		if (!fn)
			continue;

		stress = kmalloc(sizeof(*stress), GFP_KERNEL);
		if (!stress)
			break;

		INIT_WORK(&stress->work, fn);
		stress->locks = locks;
		stress->nlocks = nlocks;
		stress->timeout = jiffies + 2*HZ;

		queue_work(wq, &stress->work);
		nthreads--;
	}

	flush_workqueue(wq);

	for (n = 0; n < nlocks; n++)
		ww_mutex_destroy(&locks[n]);
	kfree(locks);

	return 0;
}

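/*
 * Run the whole suite on module load: basic mutual exclusion, AA recursion,
 * the four ABBA variants, cyclic deadlocks scaled with the number of CPUs,
 * and finally the stress workers.  Any failure aborts the remaining tests.
 */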
static int __init test_ww_mutex_init(void)
{
	int ncpus = num_online_cpus();
	int ret, i;

	printk(KERN_INFO "Beginning ww mutex selftests\n");

	wq = alloc_workqueue("test-ww_mutex", WQ_UNBOUND, 0);
	if (!wq)
		return -ENOMEM;

	ret = test_mutex();
	if (ret)
		return ret;

	ret = test_aa(false);
	if (ret)
		return ret;

	ret = test_aa(true);
	if (ret)
		return ret;

	for (i = 0; i < 4; i++) {
		ret = test_abba(i & 1, i & 2);
		if (ret)
			return ret;
	}

	ret = test_cycle(ncpus);
	if (ret)
		return ret;

	ret = stress(16, 2*ncpus, STRESS_INORDER);
	if (ret)
		return ret;

	ret = stress(16, 2*ncpus, STRESS_REORDER);
	if (ret)
		return ret;

	ret = stress(4095, hweight32(STRESS_ALL)*ncpus, STRESS_ALL);
	if (ret)
		return ret;

	printk(KERN_INFO "All ww mutex selftests passed\n");
	return 0;
}

static void __exit test_ww_mutex_exit(void)
{
	destroy_workqueue(wq);
}

module_init(test_ww_mutex_init);
module_exit(test_ww_mutex_exit);

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Intel Corporation");