/* SPDX-License-Identifier: GPL-2.0 */
#define _GNU_SOURCE

#include <linux/limits.h>
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <sys/socket.h>
#include <sys/wait.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <netdb.h>
#include <errno.h>

#include "../kselftest.h"
#include "cgroup_util.h"

/*
 * This test creates two nested cgroups with and without enabling
 * the memory controller.
 */
static int test_memcg_subtree_control(const char *root)
{
	char *parent, *child, *parent2 = NULL, *child2 = NULL;
	int ret = KSFT_FAIL;
	char buf[PAGE_SIZE];

	/* Create two nested cgroups with the memory controller enabled */
	parent = cg_name(root, "memcg_test_0");
	child = cg_name(root, "memcg_test_0/memcg_test_1");
	if (!parent || !child)
		goto cleanup;

	if (cg_create(parent))
		goto cleanup;

	if (cg_write(parent, "cgroup.subtree_control", "+memory"))
		goto cleanup;

	if (cg_create(child))
		goto cleanup;

	if (cg_read_strstr(child, "cgroup.controllers", "memory"))
		goto cleanup;

	/* Create two nested cgroups without enabling memory controller */
	parent2 = cg_name(root, "memcg_test_1");
	child2 = cg_name(root, "memcg_test_1/memcg_test_1");
	if (!parent2 || !child2)
		goto cleanup;

	if (cg_create(parent2))
		goto cleanup;

	if (cg_create(child2))
		goto cleanup;

	if (cg_read(child2, "cgroup.controllers", buf, sizeof(buf)))
		goto cleanup;

	if (!cg_read_strstr(child2, "cgroup.controllers", "memory"))
		goto cleanup;

	ret = KSFT_PASS;

cleanup:
	if (child)
		cg_destroy(child);
	if (parent)
		cg_destroy(parent);
	free(parent);
	free(child);

	if (child2)
		cg_destroy(child2);
	if (parent2)
		cg_destroy(parent2);
	free(parent2);
	free(child2);

	return ret;
}

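/*
 * Allocate 50M of anonymous memory, touching every page so it is
 * actually charged, and check that both memory.current and the
 * "anon" counter in memory.stat reflect the allocation.
 */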
static int alloc_anon_50M_check(const char *cgroup, void *arg)
{
	size_t size = MB(50);
	char *buf, *ptr;
	long anon, current;
	int ret = -1;

	buf = malloc(size);
	if (!buf)
		return -1;

	/* Touch every page so the memory is actually charged */
	for (ptr = buf; ptr < buf + size; ptr += PAGE_SIZE)
		*ptr = 0;

	current = cg_read_long(cgroup, "memory.current");
	if (current < size)
		goto cleanup;

	if (!values_close(size, current, 3))
		goto cleanup;

	anon = cg_read_key_long(cgroup, "memory.stat", "anon ");
	if (anon < 0)
		goto cleanup;

	if (!values_close(anon, current, 3))
		goto cleanup;

	ret = 0;
cleanup:
	free(buf);
	return ret;
}

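/*
 * Create 50M of pagecache in a temporary file and check that both
 * memory.current and the "file" counter in memory.stat reflect it.
 */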
static int alloc_pagecache_50M_check(const char *cgroup, void *arg)
{
	size_t size = MB(50);
	int ret = -1;
	long current, file;
	int fd;

	fd = get_temp_fd();
	if (fd < 0)
		return -1;

	if (alloc_pagecache(fd, size))
		goto cleanup;

	current = cg_read_long(cgroup, "memory.current");
	if (current < size)
		goto cleanup;

	file = cg_read_key_long(cgroup, "memory.stat", "file ");
	if (file < 0)
		goto cleanup;

	if (!values_close(file, current, 10))
		goto cleanup;

	ret = 0;

cleanup:
	close(fd);
	return ret;
}

/*
 * This test creates a memory cgroup, allocates
 * some anonymous memory and some pagecache,
 * and checks memory.current and some memory.stat values.
 */
static int test_memcg_current(const char *root)
{
	int ret = KSFT_FAIL;
	long current;
	char *memcg;

	memcg = cg_name(root, "memcg_test");
	if (!memcg)
		goto cleanup;

	if (cg_create(memcg))
		goto cleanup;

	current = cg_read_long(memcg, "memory.current");
	if (current != 0)
		goto cleanup;

	if (cg_run(memcg, alloc_anon_50M_check, NULL))
		goto cleanup;

	if (cg_run(memcg, alloc_pagecache_50M_check, NULL))
		goto cleanup;

	ret = KSFT_PASS;

cleanup:
	cg_destroy(memcg);
	free(memcg);

	return ret;
}

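/* Create 50M of pagecache in the file referenced by the passed-in fd. */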
static int alloc_pagecache_50M(const char *cgroup, void *arg)
{
	int fd = (long)arg;

	return alloc_pagecache(fd, MB(50));
}

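/*
 * Create 50M of pagecache, then linger until the parent test process
 * exits, so that the cgroup keeps a running task.
 */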
static int alloc_pagecache_50M_noexit(const char *cgroup, void *arg)
{
	int fd = (long)arg;
	int ppid = getppid();

	if (alloc_pagecache(fd, MB(50)))
		return -1;

	while (getppid() == ppid)
		sleep(1);

	return 0;
}

/*
 * First, this test creates the following hierarchy:
 * A       memory.min = 50M,  memory.max = 200M
 * A/B     memory.min = 50M,  memory.current = 50M
 * A/B/C   memory.min = 75M,  memory.current = 50M
 * A/B/D   memory.min = 25M,  memory.current = 50M
 * A/B/E   memory.min = 500M, memory.current = 0
 * A/B/F   memory.min = 0,    memory.current = 50M
 *
 * Usages are pagecache, but the test keeps a running
 * process in every leaf cgroup.
 * Then it creates A/G and creates significant
 * memory pressure in it.
 *
 * Then it checks the actual memory usages and expects that:
 * A/B    memory.current ~= 50M
 * A/B/C  memory.current ~= 33M
 * A/B/D  memory.current ~= 17M
 * A/B/E  memory.current ~= 0
 *
 * After that it tries to allocate more than there is
 * unprotected memory in A available, and checks
 * that memory.min protects pagecache even
 * in this case.
 */
static int test_memcg_min(const char *root)
{
	int ret = KSFT_FAIL;
	char *parent[3] = {NULL};
	char *children[4] = {NULL};
	long c[4];
	int i, attempts;
	int fd;

	fd = get_temp_fd();
	if (fd < 0)
		goto cleanup;

	parent[0] = cg_name(root, "memcg_test_0");
	if (!parent[0])
		goto cleanup;

	parent[1] = cg_name(parent[0], "memcg_test_1");
	if (!parent[1])
		goto cleanup;

	parent[2] = cg_name(parent[0], "memcg_test_2");
	if (!parent[2])
		goto cleanup;

	if (cg_create(parent[0]))
		goto cleanup;

	if (cg_read_long(parent[0], "memory.min")) {
		ret = KSFT_SKIP;
		goto cleanup;
	}

	if (cg_write(parent[0], "cgroup.subtree_control", "+memory"))
		goto cleanup;

	if (cg_write(parent[0], "memory.max", "200M"))
		goto cleanup;

	if (cg_write(parent[0], "memory.swap.max", "0"))
		goto cleanup;

	if (cg_create(parent[1]))
		goto cleanup;

	if (cg_write(parent[1], "cgroup.subtree_control", "+memory"))
		goto cleanup;

	if (cg_create(parent[2]))
		goto cleanup;

	for (i = 0; i < ARRAY_SIZE(children); i++) {
		children[i] = cg_name_indexed(parent[1], "child_memcg", i);
		if (!children[i])
			goto cleanup;

		if (cg_create(children[i]))
			goto cleanup;

		/* The third child (E) is left empty */
		if (i == 2)
			continue;

		cg_run_nowait(children[i], alloc_pagecache_50M_noexit,
			      (void *)(long)fd);
	}

	if (cg_write(parent[0], "memory.min", "50M"))
		goto cleanup;
	if (cg_write(parent[1], "memory.min", "50M"))
		goto cleanup;
	if (cg_write(children[0], "memory.min", "75M"))
		goto cleanup;
	if (cg_write(children[1], "memory.min", "25M"))
		goto cleanup;
	if (cg_write(children[2], "memory.min", "500M"))
		goto cleanup;
	if (cg_write(children[3], "memory.min", "0"))
		goto cleanup;

	/* Wait a few seconds for the pagecache charges to settle */
	attempts = 0;
	while (!values_close(cg_read_long(parent[1], "memory.current"),
			     MB(150), 3)) {
		if (attempts++ > 5)
			break;
		sleep(1);
	}

	if (cg_run(parent[2], alloc_anon, (void *)MB(148)))
		goto cleanup;

	if (!values_close(cg_read_long(parent[1], "memory.current"), MB(50), 3))
		goto cleanup;

	for (i = 0; i < ARRAY_SIZE(children); i++)
		c[i] = cg_read_long(children[i], "memory.current");

	if (!values_close(c[0], MB(33), 10))
		goto cleanup;

	if (!values_close(c[1], MB(17), 10))
		goto cleanup;

	if (!values_close(c[2], 0, 1))
		goto cleanup;

	/* Allocating above the unprotected memory in A is expected to fail */
	if (!cg_run(parent[2], alloc_anon, (void *)MB(170)))
		goto cleanup;

	if (!values_close(cg_read_long(parent[1], "memory.current"), MB(50), 3))
		goto cleanup;

	ret = KSFT_PASS;

cleanup:
	for (i = ARRAY_SIZE(children) - 1; i >= 0; i--) {
		if (!children[i])
			continue;

		cg_destroy(children[i]);
		free(children[i]);
	}

	for (i = ARRAY_SIZE(parent) - 1; i >= 0; i--) {
		if (!parent[i])
			continue;

		cg_destroy(parent[i]);
		free(parent[i]);
	}
	close(fd);
	return ret;
}

/*
 * First, this test creates the following hierarchy:
 * A       memory.low = 50M,  memory.max = 200M
 * A/B     memory.low = 50M,  memory.current = 50M
 * A/B/C   memory.low = 75M,  memory.current = 50M
 * A/B/D   memory.low = 25M,  memory.current = 50M
 * A/B/E   memory.low = 500M, memory.current = 0
 * A/B/F   memory.low = 0,    memory.current = 50M
 *
 * Usages are pagecache.
 * Then it creates A/G and creates significant
 * memory pressure in it.
 *
 * Then it checks the actual memory usages and expects that:
 * A/B    memory.current ~= 50M
 * A/B/C  memory.current ~= 33M
 * A/B/D  memory.current ~= 17M
 * A/B/E  memory.current ~= 0
 *
 * After that it tries to allocate more than there is
 * unprotected memory in A available,
 * and checks low and oom events in memory.events.
 */
static int test_memcg_low(const char *root)
{
	int ret = KSFT_FAIL;
	char *parent[3] = {NULL};
	char *children[4] = {NULL};
	long low, oom;
	long c[4];
	int i;
	int fd;

	fd = get_temp_fd();
	if (fd < 0)
		goto cleanup;

	parent[0] = cg_name(root, "memcg_test_0");
	if (!parent[0])
		goto cleanup;

	parent[1] = cg_name(parent[0], "memcg_test_1");
	if (!parent[1])
		goto cleanup;

	parent[2] = cg_name(parent[0], "memcg_test_2");
	if (!parent[2])
		goto cleanup;

	if (cg_create(parent[0]))
		goto cleanup;

	if (cg_read_long(parent[0], "memory.low"))
		goto cleanup;

	if (cg_write(parent[0], "cgroup.subtree_control", "+memory"))
		goto cleanup;

	if (cg_write(parent[0], "memory.max", "200M"))
		goto cleanup;

	if (cg_write(parent[0], "memory.swap.max", "0"))
		goto cleanup;

	if (cg_create(parent[1]))
		goto cleanup;

	if (cg_write(parent[1], "cgroup.subtree_control", "+memory"))
		goto cleanup;

	if (cg_create(parent[2]))
		goto cleanup;

	for (i = 0; i < ARRAY_SIZE(children); i++) {
		children[i] = cg_name_indexed(parent[1], "child_memcg", i);
		if (!children[i])
			goto cleanup;

		if (cg_create(children[i]))
			goto cleanup;

		/* The third child (E) is left empty */
		if (i == 2)
			continue;

		if (cg_run(children[i], alloc_pagecache_50M, (void *)(long)fd))
			goto cleanup;
	}

	if (cg_write(parent[0], "memory.low", "50M"))
		goto cleanup;
	if (cg_write(parent[1], "memory.low", "50M"))
		goto cleanup;
	if (cg_write(children[0], "memory.low", "75M"))
		goto cleanup;
	if (cg_write(children[1], "memory.low", "25M"))
		goto cleanup;
	if (cg_write(children[2], "memory.low", "500M"))
		goto cleanup;
	if (cg_write(children[3], "memory.low", "0"))
		goto cleanup;

	if (cg_run(parent[2], alloc_anon, (void *)MB(148)))
		goto cleanup;

	if (!values_close(cg_read_long(parent[1], "memory.current"), MB(50), 3))
		goto cleanup;

	for (i = 0; i < ARRAY_SIZE(children); i++)
		c[i] = cg_read_long(children[i], "memory.current");

	if (!values_close(c[0], MB(33), 10))
		goto cleanup;

	if (!values_close(c[1], MB(17), 10))
		goto cleanup;

	if (!values_close(c[2], 0, 1))
		goto cleanup;

	/* memory.low must not prevent allocating new anonymous memory */
	if (cg_run(parent[2], alloc_anon, (void *)MB(166))) {
		fprintf(stderr,
			"memory.low prevents from allocating anon memory\n");
		goto cleanup;
	}

	for (i = 0; i < ARRAY_SIZE(children); i++) {
		oom = cg_read_key_long(children[i], "memory.events", "oom ");
		low = cg_read_key_long(children[i], "memory.events", "low ");

		if (oom)
			goto cleanup;
		/* C and D must have seen low events, E and F must not */
		if (i < 2 && low <= 0)
			goto cleanup;
		if (i >= 2 && low)
			goto cleanup;
	}

	ret = KSFT_PASS;

cleanup:
	for (i = ARRAY_SIZE(children) - 1; i >= 0; i--) {
		if (!children[i])
			continue;

		cg_destroy(children[i]);
		free(children[i]);
	}

	for (i = ARRAY_SIZE(parent) - 1; i >= 0; i--) {
		if (!parent[i])
			continue;

		cg_destroy(parent[i]);
		free(parent[i]);
	}
	close(fd);
	return ret;
}

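/*
 * Try to create 50M of pagecache while the cgroup is limited to 30M
 * (via memory.high or memory.max) and check that the usage ends up
 * between 29M and 30M.
 */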
static int alloc_pagecache_max_30M(const char *cgroup, void *arg)
{
	size_t size = MB(50);
	int ret = -1;
	long current;
	int fd;

	fd = get_temp_fd();
	if (fd < 0)
		return -1;

	if (alloc_pagecache(fd, size))
		goto cleanup;

	current = cg_read_long(cgroup, "memory.current");
	if (current <= MB(29) || current > MB(30))
		goto cleanup;

	ret = 0;

cleanup:
	close(fd);
	return ret;
}

/*
 * This test checks that memory.high limits the amount of
 * memory which can be consumed by either anonymous memory
 * or pagecache.
 */
static int test_memcg_high(const char *root)
{
	int ret = KSFT_FAIL;
	char *memcg;
	long high;

	memcg = cg_name(root, "memcg_test");
	if (!memcg)
		goto cleanup;

	if (cg_create(memcg))
		goto cleanup;

	if (cg_read_strcmp(memcg, "memory.high", "max\n"))
		goto cleanup;

	if (cg_write(memcg, "memory.swap.max", "0"))
		goto cleanup;

	if (cg_write(memcg, "memory.high", "30M"))
		goto cleanup;

	/* memory.high throttles and reclaims, but the allocation succeeds */
	if (cg_run(memcg, alloc_anon, (void *)MB(100)))
		goto cleanup;

	/* Usage is kept around 30M, so the 50M pagecache check must fail */
	if (!cg_run(memcg, alloc_pagecache_50M_check, NULL))
		goto cleanup;

	if (cg_run(memcg, alloc_pagecache_max_30M, NULL))
		goto cleanup;

	high = cg_read_key_long(memcg, "memory.events", "high ");
	if (high <= 0)
		goto cleanup;

	ret = KSFT_PASS;

cleanup:
	cg_destroy(memcg);
	free(memcg);

	return ret;
}

/*
 * This test checks that memory.max limits the amount of
 * memory which can be consumed by either anonymous memory
 * or pagecache.
 */
static int test_memcg_max(const char *root)
{
	int ret = KSFT_FAIL;
	char *memcg;
	long current, max;

	memcg = cg_name(root, "memcg_test");
	if (!memcg)
		goto cleanup;

	if (cg_create(memcg))
		goto cleanup;

	if (cg_read_strcmp(memcg, "memory.max", "max\n"))
		goto cleanup;

	if (cg_write(memcg, "memory.swap.max", "0"))
		goto cleanup;

	if (cg_write(memcg, "memory.max", "30M"))
		goto cleanup;

	/* Should be killed by OOM killer */
	if (!cg_run(memcg, alloc_anon, (void *)MB(100)))
		goto cleanup;

	if (cg_run(memcg, alloc_pagecache_max_30M, NULL))
		goto cleanup;

	current = cg_read_long(memcg, "memory.current");
	if (current > MB(30) || !current)
		goto cleanup;

	max = cg_read_key_long(memcg, "memory.events", "max ");
	if (max <= 0)
		goto cleanup;

	ret = KSFT_PASS;

cleanup:
	cg_destroy(memcg);
	free(memcg);

	return ret;
}

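/*
 * Allocate 50M of anonymous memory in a cgroup whose memory.max is
 * passed in arg, and check that the overflow went to swap:
 * memory.current should be close to memory.max, and
 * memory.current + memory.swap.current should be close to 50M.
 */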
static int alloc_anon_50M_check_swap(const char *cgroup, void *arg)
{
	long mem_max = (long)arg;
	size_t size = MB(50);
	char *buf, *ptr;
	long mem_current, swap_current;
	int ret = -1;

	buf = malloc(size);
	if (!buf)
		return -1;

	/* Touch every page so the memory is actually charged */
	for (ptr = buf; ptr < buf + size; ptr += PAGE_SIZE)
		*ptr = 0;

	mem_current = cg_read_long(cgroup, "memory.current");
	if (!mem_current || !values_close(mem_current, mem_max, 3))
		goto cleanup;

	swap_current = cg_read_long(cgroup, "memory.swap.current");
	if (!swap_current ||
	    !values_close(mem_current + swap_current, size, 3))
		goto cleanup;

	ret = 0;
cleanup:
	free(buf);
	return ret;
}

/*
 * This test checks that memory.swap.max limits the amount of
 * anonymous memory which can be swapped out.
 */
static int test_memcg_swap_max(const char *root)
{
	int ret = KSFT_FAIL;
	char *memcg;
	long max;

	if (!is_swap_enabled())
		return KSFT_SKIP;

	memcg = cg_name(root, "memcg_test");
	if (!memcg)
		goto cleanup;

	if (cg_create(memcg))
		goto cleanup;

	if (cg_read_long(memcg, "memory.swap.current")) {
		ret = KSFT_SKIP;
		goto cleanup;
	}

	if (cg_read_strcmp(memcg, "memory.max", "max\n"))
		goto cleanup;

	if (cg_read_strcmp(memcg, "memory.swap.max", "max\n"))
		goto cleanup;

	if (cg_write(memcg, "memory.swap.max", "30M"))
		goto cleanup;

	if (cg_write(memcg, "memory.max", "30M"))
		goto cleanup;

	/* Should be killed by OOM killer */
	if (!cg_run(memcg, alloc_anon, (void *)MB(100)))
		goto cleanup;

	if (cg_read_key_long(memcg, "memory.events", "oom ") != 1)
		goto cleanup;

	if (cg_read_key_long(memcg, "memory.events", "oom_kill ") != 1)
		goto cleanup;

	if (cg_run(memcg, alloc_anon_50M_check_swap, (void *)MB(30)))
		goto cleanup;

	max = cg_read_key_long(memcg, "memory.events", "max ");
	if (max <= 0)
		goto cleanup;

	ret = KSFT_PASS;

cleanup:
	cg_destroy(memcg);
	free(memcg);

	return ret;
}

/*
 * This test disables swapping and tries to allocate anonymous memory
 * up to OOM. Then it checks for oom and oom_kill events in
 * memory.events.
 */
static int test_memcg_oom_events(const char *root)
{
	int ret = KSFT_FAIL;
	char *memcg;

	memcg = cg_name(root, "memcg_test");
	if (!memcg)
		goto cleanup;

	if (cg_create(memcg))
		goto cleanup;

	if (cg_write(memcg, "memory.max", "30M"))
		goto cleanup;

	if (cg_write(memcg, "memory.swap.max", "0"))
		goto cleanup;

	/* Should be killed by OOM killer */
	if (!cg_run(memcg, alloc_anon, (void *)MB(100)))
		goto cleanup;

	/* The killed process must be gone, so cgroup.procs must be empty */
	if (cg_read_strcmp(memcg, "cgroup.procs", ""))
		goto cleanup;

	if (cg_read_key_long(memcg, "memory.events", "oom ") != 1)
		goto cleanup;

	if (cg_read_key_long(memcg, "memory.events", "oom_kill ") != 1)
		goto cleanup;

	ret = KSFT_PASS;

cleanup:
	cg_destroy(memcg);
	free(memcg);

	return ret;
}

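/* Arguments for the forked TCP server: port to bind and a control pipe. */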
struct tcp_server_args {
	unsigned short port;
	int ctl[2];
};

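/*
 * Run a TCP server: bind to the requested port and report the result
 * (errno on bind() failure, 0 once listening) over the control pipe,
 * then accept a single connection and keep writing to it until the
 * client disconnects.
 */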
static int tcp_server(const char *cgroup, void *arg)
{
	struct tcp_server_args *srv_args = arg;
	struct sockaddr_in6 saddr = { 0 };
	socklen_t slen = sizeof(saddr);
	int sk, client_sk, ctl_fd, yes = 1, ret = -1;

	close(srv_args->ctl[0]);
	ctl_fd = srv_args->ctl[1];

	saddr.sin6_family = AF_INET6;
	saddr.sin6_addr = in6addr_any;
	saddr.sin6_port = htons(srv_args->port);

	sk = socket(AF_INET6, SOCK_STREAM, 0);
	if (sk < 0)
		return ret;

	if (setsockopt(sk, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof(yes)) < 0)
		goto cleanup;

	if (bind(sk, (struct sockaddr *)&saddr, slen)) {
		write(ctl_fd, &errno, sizeof(errno));
		goto cleanup;
	}

	if (listen(sk, 1))
		goto cleanup;

	ret = 0;
	if (write(ctl_fd, &ret, sizeof(ret)) != sizeof(ret)) {
		ret = -1;
		goto cleanup;
	}

	client_sk = accept(sk, NULL, NULL);
	if (client_sk < 0)
		goto cleanup;

	ret = -1;
	for (;;) {
		uint8_t buf[0x100000];

		if (write(client_sk, buf, sizeof(buf)) <= 0) {
			if (errno == ECONNRESET)
				ret = 0;
			break;
		}
	}

	close(client_sk);

cleanup:
	close(sk);
	return ret;
}

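/*
 * Connect to the server on localhost:port and keep reading from the
 * socket, comparing memory.current with the "sock" counter in
 * memory.stat after each read, until they converge or the retries
 * are exhausted.
 */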
static int tcp_client(const char *cgroup, unsigned short port)
{
	const char server[] = "localhost";
	struct addrinfo *ai;
	char servport[6];
	int retries = 0x10; /* nice round number */
	int sk, ret;

	snprintf(servport, sizeof(servport), "%hu", port);
	ret = getaddrinfo(server, servport, NULL, &ai);
	if (ret)
		return ret;

	sk = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
	if (sk < 0)
		goto free_ainfo;

	ret = connect(sk, ai->ai_addr, ai->ai_addrlen);
	if (ret < 0)
		goto close_sk;

	ret = KSFT_FAIL;
	while (retries--) {
		uint8_t buf[0x100000];
		long current, sock;

		if (read(sk, buf, sizeof(buf)) <= 0)
			goto close_sk;

		current = cg_read_long(cgroup, "memory.current");
		sock = cg_read_key_long(cgroup, "memory.stat", "sock ");

		if (current < 0 || sock < 0)
			goto close_sk;

		if (current < sock)
			goto close_sk;

		if (values_close(current, sock, 10)) {
			ret = KSFT_PASS;
			break;
		}
	}

close_sk:
	close(sk);
free_ainfo:
	freeaddrinfo(ai);
	return ret;
}

/*
 * This test checks socket memory accounting.
 * The test forks a TCP server that listens on a random port between
 * 1000 and 61000. Once it gets a client connection, it starts writing
 * to its socket.
 * The TCP client interleaves reads from the socket with checks of
 * whether memory.current and the "sock" counter in memory.stat are
 * close to each other.
 */
static int test_memcg_sock(const char *root)
{
	int bind_retries = 5, ret = KSFT_FAIL, pid, err;
	unsigned short port;
	char *memcg;

	memcg = cg_name(root, "memcg_test");
	if (!memcg)
		goto cleanup;

	if (cg_create(memcg))
		goto cleanup;

	while (bind_retries--) {
		struct tcp_server_args args;

		if (pipe(args.ctl))
			goto cleanup;

		port = args.port = 1000 + rand() % 60000;

		pid = cg_run_nowait(memcg, tcp_server, &args);
		if (pid < 0)
			goto cleanup;

		close(args.ctl[1]);
		if (read(args.ctl[0], &err, sizeof(err)) != sizeof(err))
			goto cleanup;
		close(args.ctl[0]);

		if (!err)
			break;
		if (err != EADDRINUSE)
			goto cleanup;

		waitpid(pid, NULL, 0);
	}

	if (err == EADDRINUSE) {
		ret = KSFT_SKIP;
		goto cleanup;
	}

	if (tcp_client(memcg, port) != KSFT_PASS)
		goto cleanup;

	waitpid(pid, &err, 0);
	if (WEXITSTATUS(err))
		goto cleanup;

	if (cg_read_long(memcg, "memory.current") < 0)
		goto cleanup;

	/* The server is gone, so all socket memory must have been uncharged */
	if (cg_read_key_long(memcg, "memory.stat", "sock "))
		goto cleanup;

	ret = KSFT_PASS;

cleanup:
	cg_destroy(memcg);
	free(memcg);

	return ret;
}

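/* Table of all memcg tests; T() pairs each test function with its name. */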
#define T(x) { x, #x }
struct memcg_test {
	int (*fn)(const char *root);
	const char *name;
} tests[] = {
	T(test_memcg_subtree_control),
	T(test_memcg_current),
	T(test_memcg_min),
	T(test_memcg_low),
	T(test_memcg_high),
	T(test_memcg_max),
	T(test_memcg_oom_events),
	T(test_memcg_swap_max),
	T(test_memcg_sock),
};
#undef T

int main(int argc, char **argv)
{
	char root[PATH_MAX];
	int i, ret = EXIT_SUCCESS;

	if (cg_find_unified_root(root, sizeof(root)))
		ksft_exit_skip("cgroup v2 isn't mounted\n");

	/*
	 * Check that the memory controller is available:
	 * memory is listed in cgroup.controllers
	 */
	if (cg_read_strstr(root, "cgroup.controllers", "memory"))
		ksft_exit_skip("memory controller isn't available\n");

	for (i = 0; i < ARRAY_SIZE(tests); i++) {
		switch (tests[i].fn(root)) {
		case KSFT_PASS:
			ksft_test_result_pass("%s\n", tests[i].name);
			break;
		case KSFT_SKIP:
			ksft_test_result_skip("%s\n", tests[i].name);
			break;
		default:
			ret = EXIT_FAILURE;
			ksft_test_result_fail("%s\n", tests[i].name);
			break;
		}
	}

	return ret;
}