1 // SPDX-License-Identifier: GPL-2.0
2 #define _GNU_SOURCE
3 #include <errno.h>
4 #include <fcntl.h>
5 #include <limits.h>
6 #include <sched.h>
7 #include <stdarg.h>
8 #include <stdbool.h>
9 #include <stdio.h>
10 #include <stdlib.h>
11 #include <string.h>
12 #include <sys/mount.h>
13 #include <sys/stat.h>
14 #include <sys/types.h>
15 #include <sys/vfs.h>
16 #include <unistd.h>
17 
/* Fallback definitions for C library headers that lack nosymfollow support. */
#ifndef MS_NOSYMFOLLOW
# define MS_NOSYMFOLLOW 256     /* mount(2) flag: do not follow symlinks */
#endif

#ifndef ST_NOSYMFOLLOW
# define ST_NOSYMFOLLOW 0x2000  /* statfs f_flags bit: symlinks not followed */
#endif

/* TMP is the mount point of the test ramfs; DATA and LINK live inside it. */
#define TMP  "/tmp"
#define DATA TMP "/data"
#define LINK TMP "/symlink"
29 
/*
 * die() - report a fatal error and terminate the test.
 * @fmt: printf-style format string, followed by its arguments.
 *
 * The formatted message goes to stderr, then the process exits with
 * EXIT_FAILURE. This function never returns.
 */
static void die(char *fmt, ...)
{
	va_list args;

	va_start(args, fmt);
	(void)vfprintf(stderr, fmt, args);
	va_end(args);

	exit(EXIT_FAILURE);
}
39 
/*
 * vmaybe_write_file() - format a string and write it to a file with one write().
 * @enoent_ok: if true, return silently when @filename does not exist.
 * @filename:  path opened with O_WRONLY (the file is never created).
 * @fmt:       vsnprintf-style format string.
 * @ap:        arguments for @fmt.
 *
 * Every other failure is fatal via die(): a formatting error, output that
 * does not fit in the 4096-byte buffer, open(), a failed or short write(),
 * or close().
 */
static void vmaybe_write_file(bool enoent_ok, char *filename, char *fmt,
		va_list ap)
{
	char msg[4096];
	int len, fd;
	ssize_t n;

	len = vsnprintf(msg, sizeof(msg), fmt, ap);
	if (len < 0)
		die("vsnprintf failed: %s\n", strerror(errno));
	/* len is known to be non-negative here, so the cast is value-preserving. */
	if ((size_t)len >= sizeof(msg))
		die("vsnprintf output truncated\n");

	fd = open(filename, O_WRONLY);
	if (fd < 0) {
		/* Callers may mark the file as optional. */
		if (enoent_ok && errno == ENOENT)
			return;
		die("open of %s failed: %s\n", filename, strerror(errno));
	}

	n = write(fd, msg, len);
	if (n < 0)
		die("write to %s failed: %s\n", filename, strerror(errno));
	if (n != len)
		die("short write to %s\n", filename);

	if (close(fd) != 0)
		die("close of %s failed: %s\n", filename, strerror(errno));
}
75 
/*
 * maybe_write_file() - write a formatted string to @filename.
 * @filename: file to write.
 * @fmt:      printf-style format string, followed by its arguments.
 *
 * A missing file (ENOENT) is tolerated silently; any other error is fatal.
 */
static void maybe_write_file(char *filename, char *fmt, ...)
{
	va_list args;

	va_start(args, fmt);
	vmaybe_write_file(true, filename, fmt, args);
	va_end(args);
}
84 
/*
 * write_file() - write a formatted string to @filename.
 * @filename: file to write; it must already exist.
 * @fmt:      printf-style format string, followed by its arguments.
 *
 * Any error, including a missing file, is fatal.
 */
static void write_file(char *filename, char *fmt, ...)
{
	va_list args;

	va_start(args, fmt);
	vmaybe_write_file(false, filename, fmt, args);
	va_end(args);
}
93 
create_and_enter_ns(void)94 static void create_and_enter_ns(void)
95 {
96 	uid_t uid = getuid();
97 	gid_t gid = getgid();
98 
99 	if (unshare(CLONE_NEWUSER) != 0)
100 		die("unshare(CLONE_NEWUSER) failed: %s\n", strerror(errno));
101 
102 	maybe_write_file("/proc/self/setgroups", "deny");
103 	write_file("/proc/self/uid_map", "0 %d 1", uid);
104 	write_file("/proc/self/gid_map", "0 %d 1", gid);
105 
106 	if (setgid(0) != 0)
107 		die("setgid(0) failed %s\n", strerror(errno));
108 	if (setuid(0) != 0)
109 		die("setuid(0) failed %s\n", strerror(errno));
110 
111 	if (unshare(CLONE_NEWNS) != 0)
112 		die("unshare(CLONE_NEWNS) failed: %s\n", strerror(errno));
113 }
114 
setup_symlink(void)115 static void setup_symlink(void)
116 {
117 	int data, err;
118 
119 	data = creat(DATA, O_RDWR);
120 	if (data < 0)
121 		die("creat failed: %s\n", strerror(errno));
122 
123 	err = symlink(DATA, LINK);
124 	if (err < 0)
125 		die("symlink failed: %s\n", strerror(errno));
126 
127 	if (close(data) != 0)
128 		die("close of %s failed: %s\n", DATA, strerror(errno));
129 }
130 
test_link_traversal(bool nosymfollow)131 static void test_link_traversal(bool nosymfollow)
132 {
133 	int link;
134 
135 	link = open(LINK, 0, O_RDWR);
136 	if (nosymfollow) {
137 		if ((link != -1 || errno != ELOOP)) {
138 			die("link traversal unexpected result: %d, %s\n",
139 					link, strerror(errno));
140 		}
141 	} else {
142 		if (link < 0)
143 			die("link traversal failed: %s\n", strerror(errno));
144 
145 		if (close(link) != 0)
146 			die("close of link failed: %s\n", strerror(errno));
147 	}
148 }
149 
test_readlink(void)150 static void test_readlink(void)
151 {
152 	char buf[4096];
153 	ssize_t ret;
154 
155 	bzero(buf, sizeof(buf));
156 
157 	ret = readlink(LINK, buf, sizeof(buf));
158 	if (ret < 0)
159 		die("readlink failed: %s\n", strerror(errno));
160 	if (strcmp(buf, DATA) != 0)
161 		die("readlink strcmp failed: '%s' '%s'\n", buf, DATA);
162 }
163 
test_realpath(void)164 static void test_realpath(void)
165 {
166 	char *path = realpath(LINK, NULL);
167 
168 	if (!path)
169 		die("realpath failed: %s\n", strerror(errno));
170 	if (strcmp(path, DATA) != 0)
171 		die("realpath strcmp failed\n");
172 
173 	free(path);
174 }
175 
test_statfs(bool nosymfollow)176 static void test_statfs(bool nosymfollow)
177 {
178 	struct statfs buf;
179 	int ret;
180 
181 	ret = statfs(TMP, &buf);
182 	if (ret)
183 		die("statfs failed: %s\n", strerror(errno));
184 
185 	if (nosymfollow) {
186 		if ((buf.f_flags & ST_NOSYMFOLLOW) == 0)
187 			die("ST_NOSYMFOLLOW not set on %s\n", TMP);
188 	} else {
189 		if ((buf.f_flags & ST_NOSYMFOLLOW) != 0)
190 			die("ST_NOSYMFOLLOW set on %s\n", TMP);
191 	}
192 }
193 
/*
 * run_tests() - run every check against the current TMP mount.
 * @nosymfollow: whether TMP is currently mounted with MS_NOSYMFOLLOW.
 *
 * The checks run in a fixed order. Each one calls die() on failure, so the
 * first failing check ends the program.
 */
static void run_tests(bool nosymfollow)
{
	const bool expect_nosymfollow = nosymfollow;

	/* Depends on the mount mode. */
	test_link_traversal(expect_nosymfollow);

	/* Expected to pass in both mount modes. */
	test_readlink();
	test_realpath();

	/* The statfs flag must reflect the mount mode. */
	test_statfs(expect_nosymfollow);
}
201 
main(int argc,char ** argv)202 int main(int argc, char **argv)
203 {
204 	create_and_enter_ns();
205 
206 	if (mount("testing", TMP, "ramfs", 0, NULL) != 0)
207 		die("mount failed: %s\n", strerror(errno));
208 
209 	setup_symlink();
210 	run_tests(false);
211 
212 	if (mount("testing", TMP, "ramfs", MS_REMOUNT|MS_NOSYMFOLLOW, NULL) != 0)
213 		die("remount failed: %s\n", strerror(errno));
214 
215 	run_tests(true);
216 
217 	return EXIT_SUCCESS;
218 }
219