1 // SPDX-License-Identifier: MIT
2 
3 /*
4  * Copyright © 2019 Intel Corporation
5  */
6 
7 #include <linux/delay.h>
8 #include <linux/dma-fence.h>
9 #include <linux/dma-fence-chain.h>
10 #include <linux/kernel.h>
11 #include <linux/kthread.h>
12 #include <linux/mm.h>
13 #include <linux/sched/signal.h>
14 #include <linux/slab.h>
15 #include <linux/spinlock.h>
16 #include <linux/random.h>
17 
18 #include "selftest.h"
19 
/* Number of links (4096) in the chains built by the race and wait tests. */
#define CHAIN_SZ (4 << 10)

/*
 * Slab backing every mock_fence: created and destroyed by dma_fence_chain()
 * around the subtests, allocated from in mock_fence() and returned to in
 * mock_fence_release().
 */
static struct kmem_cache *slab_fences;
23 
/*
 * Minimal fence used by these tests: a bare dma_fence plus the spinlock that
 * mock_fence() hands to dma_fence_init() as the fence lock.
 * to_mock_fence() recovers the wrapper from a pointer to its embedded base.
 */
static inline struct mock_fence {
	struct dma_fence base;
	spinlock_t lock;	/* lock passed to dma_fence_init() for @base */
} *to_mock_fence(struct dma_fence *f) {
	return container_of(f, struct mock_fence, base);
}
30 
/*
 * Shared ->get_driver_name / ->get_timeline_name hook: every mock fence
 * reports the same fixed name, so @f is never consulted.
 */
static const char *mock_name(struct dma_fence *f)
{
	const char *name = "mock";

	return name;
}
35 
36 static void mock_fence_release(struct dma_fence *f)
37 {
38 	kmem_cache_free(slab_fences, to_mock_fence(f));
39 }
40 
41 static const struct dma_fence_ops mock_ops = {
42 	.get_driver_name = mock_name,
43 	.get_timeline_name = mock_name,
44 	.release = mock_fence_release,
45 };
46 
47 static struct dma_fence *mock_fence(void)
48 {
49 	struct mock_fence *f;
50 
51 	f = kmem_cache_alloc(slab_fences, GFP_KERNEL);
52 	if (!f)
53 		return NULL;
54 
55 	spin_lock_init(&f->lock);
56 	dma_fence_init(&f->base, &mock_ops, &f->lock, 0, 0);
57 
58 	return &f->base;
59 }
60 
61 static struct dma_fence *mock_chain(struct dma_fence *prev,
62 				    struct dma_fence *fence,
63 				    u64 seqno)
64 {
65 	struct dma_fence_chain *f;
66 
67 	f = dma_fence_chain_alloc();
68 	if (!f)
69 		return NULL;
70 
71 	dma_fence_chain_init(f, dma_fence_get(prev), dma_fence_get(fence),
72 			     seqno);
73 
74 	return &f->base;
75 }
76 
77 static int sanitycheck(void *arg)
78 {
79 	struct dma_fence *f, *chain;
80 	int err = 0;
81 
82 	f = mock_fence();
83 	if (!f)
84 		return -ENOMEM;
85 
86 	chain = mock_chain(NULL, f, 1);
87 	if (chain)
88 		dma_fence_enable_sw_signaling(chain);
89 	else
90 		err = -ENOMEM;
91 
92 	dma_fence_signal(f);
93 	dma_fence_put(f);
94 
95 	dma_fence_put(chain);
96 
97 	return err;
98 }
99 
/*
 * A test chain built by fence_chains_init() and torn down by
 * fence_chains_fini().
 */
struct fence_chains {
	unsigned int chain_length;	/* number of links built */
	/*
	 * One mock fence per link; fence_chains_init() stores link i's fence at
	 * index i, but randomise_fences() may reorder this array afterwards.
	 */
	struct dma_fence **fences;
	struct dma_fence **chains;	/* chains[i]: link i, whose prev is chains[i - 1] */

	struct dma_fence *tail;		/* newest link (last entry of chains[]) */
};
107 
/*
 * seqno_fn for consecutive seqnos: link i gets seqno i + 1 (1, 2, 3, ...).
 * The sum is formed in unsigned int before widening to the return type.
 */
static uint64_t seqno_inc(unsigned int i)
{
	const unsigned int next = i + 1;

	return next;
}
112 
113 static int fence_chains_init(struct fence_chains *fc, unsigned int count,
114 			     uint64_t (*seqno_fn)(unsigned int))
115 {
116 	unsigned int i;
117 	int err = 0;
118 
119 	fc->chains = kvmalloc_array(count, sizeof(*fc->chains),
120 				    GFP_KERNEL | __GFP_ZERO);
121 	if (!fc->chains)
122 		return -ENOMEM;
123 
124 	fc->fences = kvmalloc_array(count, sizeof(*fc->fences),
125 				    GFP_KERNEL | __GFP_ZERO);
126 	if (!fc->fences) {
127 		err = -ENOMEM;
128 		goto err_chains;
129 	}
130 
131 	fc->tail = NULL;
132 	for (i = 0; i < count; i++) {
133 		fc->fences[i] = mock_fence();
134 		if (!fc->fences[i]) {
135 			err = -ENOMEM;
136 			goto unwind;
137 		}
138 
139 		fc->chains[i] = mock_chain(fc->tail,
140 					   fc->fences[i],
141 					   seqno_fn(i));
142 		if (!fc->chains[i]) {
143 			err = -ENOMEM;
144 			goto unwind;
145 		}
146 
147 		fc->tail = fc->chains[i];
148 
149 		dma_fence_enable_sw_signaling(fc->chains[i]);
150 	}
151 
152 	fc->chain_length = i;
153 	return 0;
154 
155 unwind:
156 	for (i = 0; i < count; i++) {
157 		dma_fence_put(fc->fences[i]);
158 		dma_fence_put(fc->chains[i]);
159 	}
160 	kvfree(fc->fences);
161 err_chains:
162 	kvfree(fc->chains);
163 	return err;
164 }
165 
166 static void fence_chains_fini(struct fence_chains *fc)
167 {
168 	unsigned int i;
169 
170 	for (i = 0; i < fc->chain_length; i++) {
171 		dma_fence_signal(fc->fences[i]);
172 		dma_fence_put(fc->fences[i]);
173 	}
174 	kvfree(fc->fences);
175 
176 	for (i = 0; i < fc->chain_length; i++)
177 		dma_fence_put(fc->chains[i]);
178 	kvfree(fc->chains);
179 }
180 
181 static int find_seqno(void *arg)
182 {
183 	struct fence_chains fc;
184 	struct dma_fence *fence;
185 	int err;
186 	int i;
187 
188 	err = fence_chains_init(&fc, 64, seqno_inc);
189 	if (err)
190 		return err;
191 
192 	fence = dma_fence_get(fc.tail);
193 	err = dma_fence_chain_find_seqno(&fence, 0);
194 	dma_fence_put(fence);
195 	if (err) {
196 		pr_err("Reported %d for find_seqno(0)!\n", err);
197 		goto err;
198 	}
199 
200 	for (i = 0; i < fc.chain_length; i++) {
201 		fence = dma_fence_get(fc.tail);
202 		err = dma_fence_chain_find_seqno(&fence, i + 1);
203 		dma_fence_put(fence);
204 		if (err) {
205 			pr_err("Reported %d for find_seqno(%d:%d)!\n",
206 			       err, fc.chain_length + 1, i + 1);
207 			goto err;
208 		}
209 		if (fence != fc.chains[i]) {
210 			pr_err("Incorrect fence reported by find_seqno(%d:%d)\n",
211 			       fc.chain_length + 1, i + 1);
212 			err = -EINVAL;
213 			goto err;
214 		}
215 
216 		dma_fence_get(fence);
217 		err = dma_fence_chain_find_seqno(&fence, i + 1);
218 		dma_fence_put(fence);
219 		if (err) {
220 			pr_err("Error reported for finding self\n");
221 			goto err;
222 		}
223 		if (fence != fc.chains[i]) {
224 			pr_err("Incorrect fence reported by find self\n");
225 			err = -EINVAL;
226 			goto err;
227 		}
228 
229 		dma_fence_get(fence);
230 		err = dma_fence_chain_find_seqno(&fence, i + 2);
231 		dma_fence_put(fence);
232 		if (!err) {
233 			pr_err("Error not reported for future fence: find_seqno(%d:%d)!\n",
234 			       i + 1, i + 2);
235 			err = -EINVAL;
236 			goto err;
237 		}
238 
239 		dma_fence_get(fence);
240 		err = dma_fence_chain_find_seqno(&fence, i);
241 		dma_fence_put(fence);
242 		if (err) {
243 			pr_err("Error reported for previous fence!\n");
244 			goto err;
245 		}
246 		if (i > 0 && fence != fc.chains[i - 1]) {
247 			pr_err("Incorrect fence reported by find_seqno(%d:%d)\n",
248 			       i + 1, i);
249 			err = -EINVAL;
250 			goto err;
251 		}
252 	}
253 
254 err:
255 	fence_chains_fini(&fc);
256 	return err;
257 }
258 
259 static int find_signaled(void *arg)
260 {
261 	struct fence_chains fc;
262 	struct dma_fence *fence;
263 	int err;
264 
265 	err = fence_chains_init(&fc, 2, seqno_inc);
266 	if (err)
267 		return err;
268 
269 	dma_fence_signal(fc.fences[0]);
270 
271 	fence = dma_fence_get(fc.tail);
272 	err = dma_fence_chain_find_seqno(&fence, 1);
273 	dma_fence_put(fence);
274 	if (err) {
275 		pr_err("Reported %d for find_seqno()!\n", err);
276 		goto err;
277 	}
278 
279 	if (fence && fence != fc.chains[0]) {
280 		pr_err("Incorrect chain-fence.seqno:%lld reported for completed seqno:1\n",
281 		       fence->seqno);
282 
283 		dma_fence_get(fence);
284 		err = dma_fence_chain_find_seqno(&fence, 1);
285 		dma_fence_put(fence);
286 		if (err)
287 			pr_err("Reported %d for finding self!\n", err);
288 
289 		err = -EINVAL;
290 	}
291 
292 err:
293 	fence_chains_fini(&fc);
294 	return err;
295 }
296 
297 static int find_out_of_order(void *arg)
298 {
299 	struct fence_chains fc;
300 	struct dma_fence *fence;
301 	int err;
302 
303 	err = fence_chains_init(&fc, 3, seqno_inc);
304 	if (err)
305 		return err;
306 
307 	dma_fence_signal(fc.fences[1]);
308 
309 	fence = dma_fence_get(fc.tail);
310 	err = dma_fence_chain_find_seqno(&fence, 2);
311 	dma_fence_put(fence);
312 	if (err) {
313 		pr_err("Reported %d for find_seqno()!\n", err);
314 		goto err;
315 	}
316 
317 	/*
318 	 * We signaled the middle fence (2) of the 1-2-3 chain. The behavior
319 	 * of the dma-fence-chain is to make us wait for all the fences up to
320 	 * the point we want. Since fence 1 is still not signaled, this what
321 	 * we should get as fence to wait upon (fence 2 being garbage
322 	 * collected during the traversal of the chain).
323 	 */
324 	if (fence != fc.chains[0]) {
325 		pr_err("Incorrect chain-fence.seqno:%lld reported for completed seqno:2\n",
326 		       fence ? fence->seqno : 0);
327 
328 		err = -EINVAL;
329 	}
330 
331 err:
332 	fence_chains_fini(&fc);
333 	return err;
334 }
335 
/*
 * seqno_fn that leaves a one-seqno gap before every link: link i gets seqno
 * 2 * i + 2 (2, 4, 6, ...). The arithmetic is done in unsigned int before
 * widening to the return type.
 */
static uint64_t seqno_inc2(unsigned int i)
{
	unsigned int seqno = (i + 1) * 2;

	return seqno;
}
340 
341 static int find_gap(void *arg)
342 {
343 	struct fence_chains fc;
344 	struct dma_fence *fence;
345 	int err;
346 	int i;
347 
348 	err = fence_chains_init(&fc, 64, seqno_inc2);
349 	if (err)
350 		return err;
351 
352 	for (i = 0; i < fc.chain_length; i++) {
353 		fence = dma_fence_get(fc.tail);
354 		err = dma_fence_chain_find_seqno(&fence, 2 * i + 1);
355 		dma_fence_put(fence);
356 		if (err) {
357 			pr_err("Reported %d for find_seqno(%d:%d)!\n",
358 			       err, fc.chain_length + 1, 2 * i + 1);
359 			goto err;
360 		}
361 		if (fence != fc.chains[i]) {
362 			pr_err("Incorrect fence.seqno:%lld reported by find_seqno(%d:%d)\n",
363 			       fence->seqno,
364 			       fc.chain_length + 1,
365 			       2 * i + 1);
366 			err = -EINVAL;
367 			goto err;
368 		}
369 
370 		dma_fence_get(fence);
371 		err = dma_fence_chain_find_seqno(&fence, 2 * i + 2);
372 		dma_fence_put(fence);
373 		if (err) {
374 			pr_err("Error reported for finding self\n");
375 			goto err;
376 		}
377 		if (fence != fc.chains[i]) {
378 			pr_err("Incorrect fence reported by find self\n");
379 			err = -EINVAL;
380 			goto err;
381 		}
382 	}
383 
384 err:
385 	fence_chains_fini(&fc);
386 	return err;
387 }
388 
/* Shared state for find_race() and its __find_race() worker threads. */
struct find_race {
	struct fence_chains fc;	/* chain every worker searches and signals */
	/*
	 * Worker count; each worker decrements it on exit, and the one that
	 * brings it to zero calls wake_up_var() on it.
	 */
	atomic_t children;
};
393 
394 static int __find_race(void *arg)
395 {
396 	struct find_race *data = arg;
397 	int err = 0;
398 
399 	while (!kthread_should_stop()) {
400 		struct dma_fence *fence = dma_fence_get(data->fc.tail);
401 		int seqno;
402 
403 		seqno = get_random_u32_inclusive(1, data->fc.chain_length);
404 
405 		err = dma_fence_chain_find_seqno(&fence, seqno);
406 		if (err) {
407 			pr_err("Failed to find fence seqno:%d\n",
408 			       seqno);
409 			dma_fence_put(fence);
410 			break;
411 		}
412 		if (!fence)
413 			goto signal;
414 
415 		/*
416 		 * We can only find ourselves if we are on fence we were
417 		 * looking for.
418 		 */
419 		if (fence->seqno == seqno) {
420 			err = dma_fence_chain_find_seqno(&fence, seqno);
421 			if (err) {
422 				pr_err("Reported an invalid fence for find-self:%d\n",
423 				       seqno);
424 				dma_fence_put(fence);
425 				break;
426 			}
427 		}
428 
429 		dma_fence_put(fence);
430 
431 signal:
432 		seqno = get_random_u32_below(data->fc.chain_length - 1);
433 		dma_fence_signal(data->fc.fences[seqno]);
434 		cond_resched();
435 	}
436 
437 	if (atomic_dec_and_test(&data->children))
438 		wake_up_var(&data->children);
439 	return err;
440 }
441 
442 static int find_race(void *arg)
443 {
444 	struct find_race data;
445 	int ncpus = num_online_cpus();
446 	struct task_struct **threads;
447 	unsigned long count;
448 	int err;
449 	int i;
450 
451 	err = fence_chains_init(&data.fc, CHAIN_SZ, seqno_inc);
452 	if (err)
453 		return err;
454 
455 	threads = kmalloc_array(ncpus, sizeof(*threads), GFP_KERNEL);
456 	if (!threads) {
457 		err = -ENOMEM;
458 		goto err;
459 	}
460 
461 	atomic_set(&data.children, 0);
462 	for (i = 0; i < ncpus; i++) {
463 		threads[i] = kthread_run(__find_race, &data, "dmabuf/%d", i);
464 		if (IS_ERR(threads[i])) {
465 			ncpus = i;
466 			break;
467 		}
468 		atomic_inc(&data.children);
469 		get_task_struct(threads[i]);
470 	}
471 
472 	wait_var_event_timeout(&data.children,
473 			       !atomic_read(&data.children),
474 			       5 * HZ);
475 
476 	for (i = 0; i < ncpus; i++) {
477 		int ret;
478 
479 		ret = kthread_stop(threads[i]);
480 		if (ret && !err)
481 			err = ret;
482 		put_task_struct(threads[i]);
483 	}
484 	kfree(threads);
485 
486 	count = 0;
487 	for (i = 0; i < data.fc.chain_length; i++)
488 		if (dma_fence_is_signaled(data.fc.fences[i]))
489 			count++;
490 	pr_info("Completed %lu cycles\n", count);
491 
492 err:
493 	fence_chains_fini(&data.fc);
494 	return err;
495 }
496 
497 static int signal_forward(void *arg)
498 {
499 	struct fence_chains fc;
500 	int err;
501 	int i;
502 
503 	err = fence_chains_init(&fc, 64, seqno_inc);
504 	if (err)
505 		return err;
506 
507 	for (i = 0; i < fc.chain_length; i++) {
508 		dma_fence_signal(fc.fences[i]);
509 
510 		if (!dma_fence_is_signaled(fc.chains[i])) {
511 			pr_err("chain[%d] not signaled!\n", i);
512 			err = -EINVAL;
513 			goto err;
514 		}
515 
516 		if (i + 1 < fc.chain_length &&
517 		    dma_fence_is_signaled(fc.chains[i + 1])) {
518 			pr_err("chain[%d] is signaled!\n", i);
519 			err = -EINVAL;
520 			goto err;
521 		}
522 	}
523 
524 err:
525 	fence_chains_fini(&fc);
526 	return err;
527 }
528 
529 static int signal_backward(void *arg)
530 {
531 	struct fence_chains fc;
532 	int err;
533 	int i;
534 
535 	err = fence_chains_init(&fc, 64, seqno_inc);
536 	if (err)
537 		return err;
538 
539 	for (i = fc.chain_length; i--; ) {
540 		dma_fence_signal(fc.fences[i]);
541 
542 		if (i > 0 && dma_fence_is_signaled(fc.chains[i])) {
543 			pr_err("chain[%d] is signaled!\n", i);
544 			err = -EINVAL;
545 			goto err;
546 		}
547 	}
548 
549 	for (i = 0; i < fc.chain_length; i++) {
550 		if (!dma_fence_is_signaled(fc.chains[i])) {
551 			pr_err("chain[%d] was not signaled!\n", i);
552 			err = -EINVAL;
553 			goto err;
554 		}
555 	}
556 
557 err:
558 	fence_chains_fini(&fc);
559 	return err;
560 }
561 
562 static int __wait_fence_chains(void *arg)
563 {
564 	struct fence_chains *fc = arg;
565 
566 	if (dma_fence_wait(fc->tail, false))
567 		return -EIO;
568 
569 	return 0;
570 }
571 
572 static int wait_forward(void *arg)
573 {
574 	struct fence_chains fc;
575 	struct task_struct *tsk;
576 	int err;
577 	int i;
578 
579 	err = fence_chains_init(&fc, CHAIN_SZ, seqno_inc);
580 	if (err)
581 		return err;
582 
583 	tsk = kthread_run(__wait_fence_chains, &fc, "dmabuf/wait");
584 	if (IS_ERR(tsk)) {
585 		err = PTR_ERR(tsk);
586 		goto err;
587 	}
588 	get_task_struct(tsk);
589 	yield_to(tsk, true);
590 
591 	for (i = 0; i < fc.chain_length; i++)
592 		dma_fence_signal(fc.fences[i]);
593 
594 	err = kthread_stop(tsk);
595 	put_task_struct(tsk);
596 
597 err:
598 	fence_chains_fini(&fc);
599 	return err;
600 }
601 
602 static int wait_backward(void *arg)
603 {
604 	struct fence_chains fc;
605 	struct task_struct *tsk;
606 	int err;
607 	int i;
608 
609 	err = fence_chains_init(&fc, CHAIN_SZ, seqno_inc);
610 	if (err)
611 		return err;
612 
613 	tsk = kthread_run(__wait_fence_chains, &fc, "dmabuf/wait");
614 	if (IS_ERR(tsk)) {
615 		err = PTR_ERR(tsk);
616 		goto err;
617 	}
618 	get_task_struct(tsk);
619 	yield_to(tsk, true);
620 
621 	for (i = fc.chain_length; i--; )
622 		dma_fence_signal(fc.fences[i]);
623 
624 	err = kthread_stop(tsk);
625 	put_task_struct(tsk);
626 
627 err:
628 	fence_chains_fini(&fc);
629 	return err;
630 }
631 
632 static void randomise_fences(struct fence_chains *fc)
633 {
634 	unsigned int count = fc->chain_length;
635 
636 	/* Fisher-Yates shuffle courtesy of Knuth */
637 	while (--count) {
638 		unsigned int swp;
639 
640 		swp = get_random_u32_below(count + 1);
641 		if (swp == count)
642 			continue;
643 
644 		swap(fc->fences[count], fc->fences[swp]);
645 	}
646 }
647 
648 static int wait_random(void *arg)
649 {
650 	struct fence_chains fc;
651 	struct task_struct *tsk;
652 	int err;
653 	int i;
654 
655 	err = fence_chains_init(&fc, CHAIN_SZ, seqno_inc);
656 	if (err)
657 		return err;
658 
659 	randomise_fences(&fc);
660 
661 	tsk = kthread_run(__wait_fence_chains, &fc, "dmabuf/wait");
662 	if (IS_ERR(tsk)) {
663 		err = PTR_ERR(tsk);
664 		goto err;
665 	}
666 	get_task_struct(tsk);
667 	yield_to(tsk, true);
668 
669 	for (i = 0; i < fc.chain_length; i++)
670 		dma_fence_signal(fc.fences[i]);
671 
672 	err = kthread_stop(tsk);
673 	put_task_struct(tsk);
674 
675 err:
676 	fence_chains_fini(&fc);
677 	return err;
678 }
679 
680 int dma_fence_chain(void)
681 {
682 	static const struct subtest tests[] = {
683 		SUBTEST(sanitycheck),
684 		SUBTEST(find_seqno),
685 		SUBTEST(find_signaled),
686 		SUBTEST(find_out_of_order),
687 		SUBTEST(find_gap),
688 		SUBTEST(find_race),
689 		SUBTEST(signal_forward),
690 		SUBTEST(signal_backward),
691 		SUBTEST(wait_forward),
692 		SUBTEST(wait_backward),
693 		SUBTEST(wait_random),
694 	};
695 	int ret;
696 
697 	pr_info("sizeof(dma_fence_chain)=%zu\n",
698 		sizeof(struct dma_fence_chain));
699 
700 	slab_fences = KMEM_CACHE(mock_fence,
701 				 SLAB_TYPESAFE_BY_RCU |
702 				 SLAB_HWCACHE_ALIGN);
703 	if (!slab_fences)
704 		return -ENOMEM;
705 
706 	ret = subtests(tests, NULL);
707 
708 	kmem_cache_destroy(slab_fences);
709 	return ret;
710 }
711