1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2022 Meta Platforms, Inc. and affiliates.*/
3 
#define _GNU_SOURCE
#include <stdio.h>
#include <unistd.h>
#include <sys/syscall.h>
#include <sys/types.h>
#include <test_progs.h>
#include <bpf/btf.h>
#include "rcu_read_lock.skel.h"
#include "cgroup_helpers.h"
12 
/*
 * ID of the "/rcu_read_lock" cgroup. test_rcu_read_lock() fills it in after
 * joining that cgroup; test_success() compares it with the cgroup_id value
 * found in the skeleton's .bss after the BPF programs have run.
 */
static unsigned long long cgroup_id;
14 
15 static void test_success(void)
16 {
17 	struct rcu_read_lock *skel;
18 	int err;
19 
20 	skel = rcu_read_lock__open();
21 	if (!ASSERT_OK_PTR(skel, "skel_open"))
22 		return;
23 
24 	skel->bss->target_pid = syscall(SYS_gettid);
25 
26 	bpf_program__set_autoload(skel->progs.get_cgroup_id, true);
27 	bpf_program__set_autoload(skel->progs.task_succ, true);
28 	bpf_program__set_autoload(skel->progs.no_lock, true);
29 	bpf_program__set_autoload(skel->progs.two_regions, true);
30 	bpf_program__set_autoload(skel->progs.non_sleepable_1, true);
31 	bpf_program__set_autoload(skel->progs.non_sleepable_2, true);
32 	err = rcu_read_lock__load(skel);
33 	if (!ASSERT_OK(err, "skel_load"))
34 		goto out;
35 
36 	err = rcu_read_lock__attach(skel);
37 	if (!ASSERT_OK(err, "skel_attach"))
38 		goto out;
39 
40 	syscall(SYS_getpgid);
41 
42 	ASSERT_EQ(skel->bss->task_storage_val, 2, "task_storage_val");
43 	ASSERT_EQ(skel->bss->cgroup_id, cgroup_id, "cgroup_id");
44 out:
45 	rcu_read_lock__destroy(skel);
46 }
47 
48 static void test_rcuptr_acquire(void)
49 {
50 	struct rcu_read_lock *skel;
51 	int err;
52 
53 	skel = rcu_read_lock__open();
54 	if (!ASSERT_OK_PTR(skel, "skel_open"))
55 		return;
56 
57 	skel->bss->target_pid = syscall(SYS_gettid);
58 
59 	bpf_program__set_autoload(skel->progs.task_acquire, true);
60 	err = rcu_read_lock__load(skel);
61 	if (!ASSERT_OK(err, "skel_load"))
62 		goto out;
63 
64 	err = rcu_read_lock__attach(skel);
65 	ASSERT_OK(err, "skel_attach");
66 out:
67 	rcu_read_lock__destroy(skel);
68 }
69 
70 static const char * const inproper_region_tests[] = {
71 	"miss_lock",
72 	"miss_unlock",
73 	"non_sleepable_rcu_mismatch",
74 	"inproper_sleepable_helper",
75 	"inproper_sleepable_kfunc",
76 	"nested_rcu_region",
77 };
78 
79 static void test_inproper_region(void)
80 {
81 	struct rcu_read_lock *skel;
82 	struct bpf_program *prog;
83 	int i, err;
84 
85 	for (i = 0; i < ARRAY_SIZE(inproper_region_tests); i++) {
86 		skel = rcu_read_lock__open();
87 		if (!ASSERT_OK_PTR(skel, "skel_open"))
88 			return;
89 
90 		prog = bpf_object__find_program_by_name(skel->obj, inproper_region_tests[i]);
91 		if (!ASSERT_OK_PTR(prog, "bpf_object__find_program_by_name"))
92 			goto out;
93 		bpf_program__set_autoload(prog, true);
94 		err = rcu_read_lock__load(skel);
95 		ASSERT_ERR(err, "skel_load");
96 out:
97 		rcu_read_lock__destroy(skel);
98 	}
99 }
100 
101 static const char * const rcuptr_misuse_tests[] = {
102 	"task_untrusted_non_rcuptr",
103 	"task_untrusted_rcuptr",
104 	"cross_rcu_region",
105 };
106 
107 static void test_rcuptr_misuse(void)
108 {
109 	struct rcu_read_lock *skel;
110 	struct bpf_program *prog;
111 	int i, err;
112 
113 	for (i = 0; i < ARRAY_SIZE(rcuptr_misuse_tests); i++) {
114 		skel = rcu_read_lock__open();
115 		if (!ASSERT_OK_PTR(skel, "skel_open"))
116 			return;
117 
118 		prog = bpf_object__find_program_by_name(skel->obj, rcuptr_misuse_tests[i]);
119 		if (!ASSERT_OK_PTR(prog, "bpf_object__find_program_by_name"))
120 			goto out;
121 		bpf_program__set_autoload(prog, true);
122 		err = rcu_read_lock__load(skel);
123 		ASSERT_ERR(err, "skel_load");
124 out:
125 		rcu_read_lock__destroy(skel);
126 	}
127 }
128 
129 void test_rcu_read_lock(void)
130 {
131 	struct btf *vmlinux_btf;
132 	int cgroup_fd;
133 
134 	vmlinux_btf = btf__load_vmlinux_btf();
135 	if (!ASSERT_OK_PTR(vmlinux_btf, "could not load vmlinux BTF"))
136 		return;
137 	if (btf__find_by_name_kind(vmlinux_btf, "rcu", BTF_KIND_TYPE_TAG) < 0) {
138 		test__skip();
139 		goto out;
140 	}
141 
142 	cgroup_fd = test__join_cgroup("/rcu_read_lock");
143 	if (!ASSERT_GE(cgroup_fd, 0, "join_cgroup /rcu_read_lock"))
144 		goto out;
145 
146 	cgroup_id = get_cgroup_id("/rcu_read_lock");
147 	if (test__start_subtest("success"))
148 		test_success();
149 	if (test__start_subtest("rcuptr_acquire"))
150 		test_rcuptr_acquire();
151 	if (test__start_subtest("negative_tests_inproper_region"))
152 		test_inproper_region();
153 	if (test__start_subtest("negative_tests_rcuptr_misuse"))
154 		test_rcuptr_misuse();
155 	close(cgroup_fd);
156 out:
157 	btf__free(vmlinux_btf);
158 }
159