1 /* SPDX-License-Identifier: GPL-2.0 */
2 
#include <errno.h>
#include <linux/limits.h>
#include <signal.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#include <unistd.h>
11 
12 #include "../kselftest.h"
13 #include "../pidfd/pidfd.h"
14 #include "cgroup_util.h"
15 
/*
 * Write "1" to @cgroup's cgroup.kill and then block in cg_wait_for() on the
 * descriptor that cg_prepare_for_wait() returned before the write.
 *
 * NOTE(review): the wait mechanism and its timeout (inotify, ~10 seconds)
 * live in cgroup_util and are not visible here; confirm there.
 *
 * Returns 0 on success. Returns the negative value from cg_prepare_for_wait()
 * if the wait could not be set up, or otherwise the non-zero result of
 * cg_write() or cg_wait_for(). The cgroup's resulting state is NOT checked
 * here; callers read cgroup.events themselves.
 */
static int cg_kill_wait(const char *cgroup)
{
	int wait_fd;
	int err;

	/* Arm the wait first so the event caused by the kill cannot be missed. */
	wait_fd = cg_prepare_for_wait(cgroup);
	if (wait_fd < 0)
		return wait_fd;

	err = cg_write(cgroup, "cgroup.kill", "1");
	if (!err)
		err = cg_wait_for(wait_fd);

	close(wait_fd);
	return err;
}
41 
/*
 * Body of a test child: poll getppid() once per millisecond until the parent
 * pid differs from the one seen at entry (i.e. the process was re-parented),
 * then return. The return value re-compares the current parent with the
 * original one, so after the loop it is normally 0. @cgroup and @arg are
 * unused; they only satisfy the callback signature.
 */
static int child_fn(const char *cgroup, void *arg)
{
	const int initial_ppid = getppid();

	for (;;) {
		if (getppid() != initial_ppid)
			break;
		usleep(1000);
	}

	return getppid() == initial_ppid;
}
55 
56 static int test_cgkill_simple(const char *root)
57 {
58 	pid_t pids[100];
59 	int ret = KSFT_FAIL;
60 	char *cgroup = NULL;
61 	int i;
62 
63 	cgroup = cg_name(root, "cg_test_simple");
64 	if (!cgroup)
65 		goto cleanup;
66 
67 	if (cg_create(cgroup))
68 		goto cleanup;
69 
70 	for (i = 0; i < 100; i++)
71 		pids[i] = cg_run_nowait(cgroup, child_fn, NULL);
72 
73 	if (cg_wait_for_proc_count(cgroup, 100))
74 		goto cleanup;
75 
76 	if (cg_read_strcmp(cgroup, "cgroup.events", "populated 1\n"))
77 		goto cleanup;
78 
79 	if (cg_kill_wait(cgroup))
80 		goto cleanup;
81 
82 	ret = KSFT_PASS;
83 
84 cleanup:
85 	for (i = 0; i < 100; i++)
86 		wait_for_pid(pids[i]);
87 
88 	if (ret == KSFT_PASS &&
89 	    cg_read_strcmp(cgroup, "cgroup.events", "populated 0\n"))
90 		ret = KSFT_FAIL;
91 
92 	if (cgroup)
93 		cg_destroy(cgroup);
94 	free(cgroup);
95 	return ret;
96 }
97 
98 /*
99  * The test creates the following hierarchy:
100  *       A
101  *    / / \ \
102  *   B  E  I K
103  *  /\  |
104  * C  D F
105  *      |
106  *      G
107  *      |
108  *      H
109  *
110  * with a process in C, H and 3 processes in K.
111  * Then it tries to kill the whole tree.
112  */
113 static int test_cgkill_tree(const char *root)
114 {
115 	pid_t pids[5];
116 	char *cgroup[10] = {0};
117 	int ret = KSFT_FAIL;
118 	int i;
119 
120 	cgroup[0] = cg_name(root, "cg_test_tree_A");
121 	if (!cgroup[0])
122 		goto cleanup;
123 
124 	cgroup[1] = cg_name(cgroup[0], "B");
125 	if (!cgroup[1])
126 		goto cleanup;
127 
128 	cgroup[2] = cg_name(cgroup[1], "C");
129 	if (!cgroup[2])
130 		goto cleanup;
131 
132 	cgroup[3] = cg_name(cgroup[1], "D");
133 	if (!cgroup[3])
134 		goto cleanup;
135 
136 	cgroup[4] = cg_name(cgroup[0], "E");
137 	if (!cgroup[4])
138 		goto cleanup;
139 
140 	cgroup[5] = cg_name(cgroup[4], "F");
141 	if (!cgroup[5])
142 		goto cleanup;
143 
144 	cgroup[6] = cg_name(cgroup[5], "G");
145 	if (!cgroup[6])
146 		goto cleanup;
147 
148 	cgroup[7] = cg_name(cgroup[6], "H");
149 	if (!cgroup[7])
150 		goto cleanup;
151 
152 	cgroup[8] = cg_name(cgroup[0], "I");
153 	if (!cgroup[8])
154 		goto cleanup;
155 
156 	cgroup[9] = cg_name(cgroup[0], "K");
157 	if (!cgroup[9])
158 		goto cleanup;
159 
160 	for (i = 0; i < 10; i++)
161 		if (cg_create(cgroup[i]))
162 			goto cleanup;
163 
164 	pids[0] = cg_run_nowait(cgroup[2], child_fn, NULL);
165 	pids[1] = cg_run_nowait(cgroup[7], child_fn, NULL);
166 	pids[2] = cg_run_nowait(cgroup[9], child_fn, NULL);
167 	pids[3] = cg_run_nowait(cgroup[9], child_fn, NULL);
168 	pids[4] = cg_run_nowait(cgroup[9], child_fn, NULL);
169 
170 	/*
171 	 * Wait until all child processes will enter
172 	 * corresponding cgroups.
173 	 */
174 
175 	if (cg_wait_for_proc_count(cgroup[2], 1) ||
176 	    cg_wait_for_proc_count(cgroup[7], 1) ||
177 	    cg_wait_for_proc_count(cgroup[9], 3))
178 		goto cleanup;
179 
180 	/*
181 	 * Kill A and check that we get an empty notification.
182 	 */
183 	if (cg_kill_wait(cgroup[0]))
184 		goto cleanup;
185 
186 	ret = KSFT_PASS;
187 
188 cleanup:
189 	for (i = 0; i < 5; i++)
190 		wait_for_pid(pids[i]);
191 
192 	if (ret == KSFT_PASS &&
193 	    cg_read_strcmp(cgroup[0], "cgroup.events", "populated 0\n"))
194 		ret = KSFT_FAIL;
195 
196 	for (i = 9; i >= 0 && cgroup[i]; i--) {
197 		cg_destroy(cgroup[i]);
198 		free(cgroup[i]);
199 	}
200 
201 	return ret;
202 }
203 
/*
 * Fork-bomb test child. It forks twice without checking the results, giving
 * up to four processes. Each of them then polls getppid() once per
 * millisecond and returns once its parent pid has changed (re-parenting).
 * The return value re-compares against the recorded parent, so it is
 * normally 0. @cgroup and @arg are unused.
 */
static int forkbomb_fn(const char *cgroup, void *arg)
{
	int parent_pid;

	(void)fork();
	(void)fork();

	for (parent_pid = getppid(); getppid() == parent_pid; )
		usleep(1000);

	return getppid() == parent_pid;
}
218 
219 /*
220  * The test runs a fork bomb in a cgroup and tries to kill it.
221  */
222 static int test_cgkill_forkbomb(const char *root)
223 {
224 	int ret = KSFT_FAIL;
225 	char *cgroup = NULL;
226 	pid_t pid = -ESRCH;
227 
228 	cgroup = cg_name(root, "cg_forkbomb_test");
229 	if (!cgroup)
230 		goto cleanup;
231 
232 	if (cg_create(cgroup))
233 		goto cleanup;
234 
235 	pid = cg_run_nowait(cgroup, forkbomb_fn, NULL);
236 	if (pid < 0)
237 		goto cleanup;
238 
239 	usleep(100000);
240 
241 	if (cg_kill_wait(cgroup))
242 		goto cleanup;
243 
244 	if (cg_wait_for_proc_count(cgroup, 0))
245 		goto cleanup;
246 
247 	ret = KSFT_PASS;
248 
249 cleanup:
250 	if (pid > 0)
251 		wait_for_pid(pid);
252 
253 	if (ret == KSFT_PASS &&
254 	    cg_read_strcmp(cgroup, "cgroup.events", "populated 0\n"))
255 		ret = KSFT_FAIL;
256 
257 	if (cgroup)
258 		cg_destroy(cgroup);
259 	free(cgroup);
260 	return ret;
261 }
262 
263 #define T(x) { x, #x }
264 struct cgkill_test {
265 	int (*fn)(const char *root);
266 	const char *name;
267 } tests[] = {
268 	T(test_cgkill_simple),
269 	T(test_cgkill_tree),
270 	T(test_cgkill_forkbomb),
271 };
272 #undef T
273 
274 int main(int argc, char *argv[])
275 {
276 	char root[PATH_MAX];
277 	int i, ret = EXIT_SUCCESS;
278 
279 	if (cg_find_unified_root(root, sizeof(root)))
280 		ksft_exit_skip("cgroup v2 isn't mounted\n");
281 	for (i = 0; i < ARRAY_SIZE(tests); i++) {
282 		switch (tests[i].fn(root)) {
283 		case KSFT_PASS:
284 			ksft_test_result_pass("%s\n", tests[i].name);
285 			break;
286 		case KSFT_SKIP:
287 			ksft_test_result_skip("%s\n", tests[i].name);
288 			break;
289 		default:
290 			ret = EXIT_FAILURE;
291 			ksft_test_result_fail("%s\n", tests[i].name);
292 			break;
293 		}
294 	}
295 
296 	return ret;
297 }
298