1 // SPDX-License-Identifier: GPL-2.0
2 #include <test_progs.h>
3 #include "cgroup_helpers.h"
4 
5 #include <linux/tcp.h>
6 
7 #ifndef SOL_TCP
8 #define SOL_TCP IPPROTO_TCP
9 #endif
10 
11 #define SOL_CUSTOM			0xdeadbeef
12 
13 static int getsetsockopt(void)
14 {
15 	int fd, err;
16 	union {
17 		char u8[4];
18 		__u32 u32;
19 		char cc[16]; /* TCP_CA_NAME_MAX */
20 		struct tcp_zerocopy_receive zc;
21 	} buf = {};
22 	socklen_t optlen;
23 	char *big_buf = NULL;
24 
25 	fd = socket(AF_INET, SOCK_STREAM, 0);
26 	if (fd < 0) {
27 		log_err("Failed to create socket");
28 		return -1;
29 	}
30 
31 	/* IP_TOS - BPF bypass */
32 
33 	optlen = getpagesize() * 2;
34 	big_buf = calloc(1, optlen);
35 	if (!big_buf) {
36 		log_err("Couldn't allocate two pages");
37 		goto err;
38 	}
39 
40 	*(int *)big_buf = 0x08;
41 	err = setsockopt(fd, SOL_IP, IP_TOS, big_buf, optlen);
42 	if (err) {
43 		log_err("Failed to call setsockopt(IP_TOS)");
44 		goto err;
45 	}
46 
47 	memset(big_buf, 0, optlen);
48 	optlen = 1;
49 	err = getsockopt(fd, SOL_IP, IP_TOS, big_buf, &optlen);
50 	if (err) {
51 		log_err("Failed to call getsockopt(IP_TOS)");
52 		goto err;
53 	}
54 
55 	if (*big_buf != 0x08) {
56 		log_err("Unexpected getsockopt(IP_TOS) optval 0x%x != 0x08",
57 			(int)*big_buf);
58 		goto err;
59 	}
60 
61 	/* IP_TTL - EPERM */
62 
63 	buf.u8[0] = 1;
64 	err = setsockopt(fd, SOL_IP, IP_TTL, &buf, 1);
65 	if (!err || errno != EPERM) {
66 		log_err("Unexpected success from setsockopt(IP_TTL)");
67 		goto err;
68 	}
69 
70 	/* SOL_CUSTOM - handled by BPF */
71 
72 	buf.u8[0] = 0x01;
73 	err = setsockopt(fd, SOL_CUSTOM, 0, &buf, 1);
74 	if (err) {
75 		log_err("Failed to call setsockopt");
76 		goto err;
77 	}
78 
79 	buf.u32 = 0x00;
80 	optlen = 4;
81 	err = getsockopt(fd, SOL_CUSTOM, 0, &buf, &optlen);
82 	if (err) {
83 		log_err("Failed to call getsockopt");
84 		goto err;
85 	}
86 
87 	if (optlen != 1) {
88 		log_err("Unexpected optlen %d != 1", optlen);
89 		goto err;
90 	}
91 	if (buf.u8[0] != 0x01) {
92 		log_err("Unexpected buf[0] 0x%02x != 0x01", buf.u8[0]);
93 		goto err;
94 	}
95 
96 	/* IP_FREEBIND - BPF can't access optval past PAGE_SIZE */
97 
98 	optlen = getpagesize() * 2;
99 	memset(big_buf, 0, optlen);
100 
101 	err = setsockopt(fd, SOL_IP, IP_FREEBIND, big_buf, optlen);
102 	if (err != 0) {
103 		log_err("Failed to call setsockopt, ret=%d", err);
104 		goto err;
105 	}
106 
107 	err = getsockopt(fd, SOL_IP, IP_FREEBIND, big_buf, &optlen);
108 	if (err != 0) {
109 		log_err("Failed to call getsockopt, ret=%d", err);
110 		goto err;
111 	}
112 
113 	if (optlen != 1 || *(__u8 *)big_buf != 0x55) {
114 		log_err("Unexpected IP_FREEBIND getsockopt, optlen=%d, optval=0x%x",
115 			optlen, *(__u8 *)big_buf);
116 	}
117 
118 	/* SO_SNDBUF is overwritten */
119 
120 	buf.u32 = 0x01010101;
121 	err = setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &buf, 4);
122 	if (err) {
123 		log_err("Failed to call setsockopt(SO_SNDBUF)");
124 		goto err;
125 	}
126 
127 	buf.u32 = 0x00;
128 	optlen = 4;
129 	err = getsockopt(fd, SOL_SOCKET, SO_SNDBUF, &buf, &optlen);
130 	if (err) {
131 		log_err("Failed to call getsockopt(SO_SNDBUF)");
132 		goto err;
133 	}
134 
135 	if (buf.u32 != 0x55AA*2) {
136 		log_err("Unexpected getsockopt(SO_SNDBUF) 0x%x != 0x55AA*2",
137 			buf.u32);
138 		goto err;
139 	}
140 
141 	/* TCP_CONGESTION can extend the string */
142 
143 	strcpy(buf.cc, "nv");
144 	err = setsockopt(fd, SOL_TCP, TCP_CONGESTION, &buf, strlen("nv"));
145 	if (err) {
146 		log_err("Failed to call setsockopt(TCP_CONGESTION)");
147 		goto err;
148 	}
149 
150 
151 	optlen = sizeof(buf.cc);
152 	err = getsockopt(fd, SOL_TCP, TCP_CONGESTION, &buf, &optlen);
153 	if (err) {
154 		log_err("Failed to call getsockopt(TCP_CONGESTION)");
155 		goto err;
156 	}
157 
158 	if (strcmp(buf.cc, "cubic") != 0) {
159 		log_err("Unexpected getsockopt(TCP_CONGESTION) %s != %s",
160 			buf.cc, "cubic");
161 		goto err;
162 	}
163 
164 	/* TCP_ZEROCOPY_RECEIVE triggers */
165 	memset(&buf, 0, sizeof(buf));
166 	optlen = sizeof(buf.zc);
167 	err = getsockopt(fd, SOL_TCP, TCP_ZEROCOPY_RECEIVE, &buf, &optlen);
168 	if (err) {
169 		log_err("Unexpected getsockopt(TCP_ZEROCOPY_RECEIVE) err=%d errno=%d",
170 			err, errno);
171 		goto err;
172 	}
173 
174 	memset(&buf, 0, sizeof(buf));
175 	buf.zc.address = 12345; /* rejected by BPF */
176 	optlen = sizeof(buf.zc);
177 	errno = 0;
178 	err = getsockopt(fd, SOL_TCP, TCP_ZEROCOPY_RECEIVE, &buf, &optlen);
179 	if (errno != EPERM) {
180 		log_err("Unexpected getsockopt(TCP_ZEROCOPY_RECEIVE) err=%d errno=%d",
181 			err, errno);
182 		goto err;
183 	}
184 
185 	free(big_buf);
186 	close(fd);
187 	return 0;
188 err:
189 	free(big_buf);
190 	close(fd);
191 	return -1;
192 }
193 
194 static int prog_attach(struct bpf_object *obj, int cgroup_fd, const char *title)
195 {
196 	enum bpf_attach_type attach_type;
197 	enum bpf_prog_type prog_type;
198 	struct bpf_program *prog;
199 	int err;
200 
201 	err = libbpf_prog_type_by_name(title, &prog_type, &attach_type);
202 	if (err) {
203 		log_err("Failed to deduct types for %s BPF program", title);
204 		return -1;
205 	}
206 
207 	prog = bpf_object__find_program_by_title(obj, title);
208 	if (!prog) {
209 		log_err("Failed to find %s BPF program", title);
210 		return -1;
211 	}
212 
213 	err = bpf_prog_attach(bpf_program__fd(prog), cgroup_fd,
214 			      attach_type, 0);
215 	if (err) {
216 		log_err("Failed to attach %s BPF program", title);
217 		return -1;
218 	}
219 
220 	return 0;
221 }
222 
223 static void run_test(int cgroup_fd)
224 {
225 	struct bpf_prog_load_attr attr = {
226 		.file = "./sockopt_sk.o",
227 	};
228 	struct bpf_object *obj;
229 	int ignored;
230 	int err;
231 
232 	err = bpf_prog_load_xattr(&attr, &obj, &ignored);
233 	if (CHECK_FAIL(err))
234 		return;
235 
236 	err = prog_attach(obj, cgroup_fd, "cgroup/getsockopt");
237 	if (CHECK_FAIL(err))
238 		goto close_bpf_object;
239 
240 	err = prog_attach(obj, cgroup_fd, "cgroup/setsockopt");
241 	if (CHECK_FAIL(err))
242 		goto close_bpf_object;
243 
244 	CHECK_FAIL(getsetsockopt());
245 
246 close_bpf_object:
247 	bpf_object__close(obj);
248 }
249 
/*
 * Non-static test entry point, presumably picked up by the test_progs
 * runner.  Joins the "/sockopt_sk" cgroup via test__join_cgroup() and runs
 * the sockopt_sk checks inside it.  The cgroup fd is closed afterwards.
 */
void test_sockopt_sk(void)
{
	int cg_fd = test__join_cgroup("/sockopt_sk");

	if (!CHECK_FAIL(cg_fd < 0)) {
		run_test(cg_fd);
		close(cg_fd);
	}
}
261