1 // SPDX-License-Identifier: GPL-2.0
2 #include <test_progs.h>
3 #include "cgroup_helpers.h"
4 
5 #include <linux/tcp.h>
6 #include <linux/netlink.h>
7 #include "sockopt_sk.skel.h"
8 
9 #ifndef SOL_TCP
10 #define SOL_TCP IPPROTO_TCP
11 #endif
12 
13 #define SOL_CUSTOM			0xdeadbeef
14 
15 static int getsetsockopt(void)
16 {
17 	int fd, err;
18 	union {
19 		char u8[4];
20 		__u32 u32;
21 		char cc[16]; /* TCP_CA_NAME_MAX */
22 		struct tcp_zerocopy_receive zc;
23 	} buf = {};
24 	socklen_t optlen;
25 	char *big_buf = NULL;
26 
27 	fd = socket(AF_INET, SOCK_STREAM, 0);
28 	if (fd < 0) {
29 		log_err("Failed to create socket");
30 		return -1;
31 	}
32 
33 	/* IP_TOS - BPF bypass */
34 
35 	optlen = getpagesize() * 2;
36 	big_buf = calloc(1, optlen);
37 	if (!big_buf) {
38 		log_err("Couldn't allocate two pages");
39 		goto err;
40 	}
41 
42 	*(int *)big_buf = 0x08;
43 	err = setsockopt(fd, SOL_IP, IP_TOS, big_buf, optlen);
44 	if (err) {
45 		log_err("Failed to call setsockopt(IP_TOS)");
46 		goto err;
47 	}
48 
49 	memset(big_buf, 0, optlen);
50 	optlen = 1;
51 	err = getsockopt(fd, SOL_IP, IP_TOS, big_buf, &optlen);
52 	if (err) {
53 		log_err("Failed to call getsockopt(IP_TOS)");
54 		goto err;
55 	}
56 
57 	if (*big_buf != 0x08) {
58 		log_err("Unexpected getsockopt(IP_TOS) optval 0x%x != 0x08",
59 			(int)*big_buf);
60 		goto err;
61 	}
62 
63 	/* IP_TTL - EPERM */
64 
65 	buf.u8[0] = 1;
66 	err = setsockopt(fd, SOL_IP, IP_TTL, &buf, 1);
67 	if (!err || errno != EPERM) {
68 		log_err("Unexpected success from setsockopt(IP_TTL)");
69 		goto err;
70 	}
71 
72 	/* SOL_CUSTOM - handled by BPF */
73 
74 	buf.u8[0] = 0x01;
75 	err = setsockopt(fd, SOL_CUSTOM, 0, &buf, 1);
76 	if (err) {
77 		log_err("Failed to call setsockopt");
78 		goto err;
79 	}
80 
81 	buf.u32 = 0x00;
82 	optlen = 4;
83 	err = getsockopt(fd, SOL_CUSTOM, 0, &buf, &optlen);
84 	if (err) {
85 		log_err("Failed to call getsockopt");
86 		goto err;
87 	}
88 
89 	if (optlen != 1) {
90 		log_err("Unexpected optlen %d != 1", optlen);
91 		goto err;
92 	}
93 	if (buf.u8[0] != 0x01) {
94 		log_err("Unexpected buf[0] 0x%02x != 0x01", buf.u8[0]);
95 		goto err;
96 	}
97 
98 	/* IP_FREEBIND - BPF can't access optval past PAGE_SIZE */
99 
100 	optlen = getpagesize() * 2;
101 	memset(big_buf, 0, optlen);
102 
103 	err = setsockopt(fd, SOL_IP, IP_FREEBIND, big_buf, optlen);
104 	if (err != 0) {
105 		log_err("Failed to call setsockopt, ret=%d", err);
106 		goto err;
107 	}
108 
109 	err = getsockopt(fd, SOL_IP, IP_FREEBIND, big_buf, &optlen);
110 	if (err != 0) {
111 		log_err("Failed to call getsockopt, ret=%d", err);
112 		goto err;
113 	}
114 
115 	if (optlen != 1 || *(__u8 *)big_buf != 0x55) {
116 		log_err("Unexpected IP_FREEBIND getsockopt, optlen=%d, optval=0x%x",
117 			optlen, *(__u8 *)big_buf);
118 	}
119 
120 	/* SO_SNDBUF is overwritten */
121 
122 	buf.u32 = 0x01010101;
123 	err = setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &buf, 4);
124 	if (err) {
125 		log_err("Failed to call setsockopt(SO_SNDBUF)");
126 		goto err;
127 	}
128 
129 	buf.u32 = 0x00;
130 	optlen = 4;
131 	err = getsockopt(fd, SOL_SOCKET, SO_SNDBUF, &buf, &optlen);
132 	if (err) {
133 		log_err("Failed to call getsockopt(SO_SNDBUF)");
134 		goto err;
135 	}
136 
137 	if (buf.u32 != 0x55AA*2) {
138 		log_err("Unexpected getsockopt(SO_SNDBUF) 0x%x != 0x55AA*2",
139 			buf.u32);
140 		goto err;
141 	}
142 
143 	/* TCP_CONGESTION can extend the string */
144 
145 	strcpy(buf.cc, "nv");
146 	err = setsockopt(fd, SOL_TCP, TCP_CONGESTION, &buf, strlen("nv"));
147 	if (err) {
148 		log_err("Failed to call setsockopt(TCP_CONGESTION)");
149 		goto err;
150 	}
151 
152 
153 	optlen = sizeof(buf.cc);
154 	err = getsockopt(fd, SOL_TCP, TCP_CONGESTION, &buf, &optlen);
155 	if (err) {
156 		log_err("Failed to call getsockopt(TCP_CONGESTION)");
157 		goto err;
158 	}
159 
160 	if (strcmp(buf.cc, "cubic") != 0) {
161 		log_err("Unexpected getsockopt(TCP_CONGESTION) %s != %s",
162 			buf.cc, "cubic");
163 		goto err;
164 	}
165 
166 	/* TCP_ZEROCOPY_RECEIVE triggers */
167 	memset(&buf, 0, sizeof(buf));
168 	optlen = sizeof(buf.zc);
169 	err = getsockopt(fd, SOL_TCP, TCP_ZEROCOPY_RECEIVE, &buf, &optlen);
170 	if (err) {
171 		log_err("Unexpected getsockopt(TCP_ZEROCOPY_RECEIVE) err=%d errno=%d",
172 			err, errno);
173 		goto err;
174 	}
175 
176 	memset(&buf, 0, sizeof(buf));
177 	buf.zc.address = 12345; /* Not page aligned. Rejected by tcp_zerocopy_receive() */
178 	optlen = sizeof(buf.zc);
179 	errno = 0;
180 	err = getsockopt(fd, SOL_TCP, TCP_ZEROCOPY_RECEIVE, &buf, &optlen);
181 	if (errno != EINVAL) {
182 		log_err("Unexpected getsockopt(TCP_ZEROCOPY_RECEIVE) err=%d errno=%d",
183 			err, errno);
184 		goto err;
185 	}
186 
187 	/* optval=NULL case is handled correctly */
188 
189 	close(fd);
190 	fd = socket(AF_NETLINK, SOCK_RAW, 0);
191 	if (fd < 0) {
192 		log_err("Failed to create AF_NETLINK socket");
193 		return -1;
194 	}
195 
196 	buf.u32 = 1;
197 	optlen = sizeof(__u32);
198 	err = setsockopt(fd, SOL_NETLINK, NETLINK_ADD_MEMBERSHIP, &buf, optlen);
199 	if (err) {
200 		log_err("Unexpected getsockopt(NETLINK_ADD_MEMBERSHIP) err=%d errno=%d",
201 			err, errno);
202 		goto err;
203 	}
204 
205 	optlen = 0;
206 	err = getsockopt(fd, SOL_NETLINK, NETLINK_LIST_MEMBERSHIPS, NULL, &optlen);
207 	if (err) {
208 		log_err("Unexpected getsockopt(NETLINK_LIST_MEMBERSHIPS) err=%d errno=%d",
209 			err, errno);
210 		goto err;
211 	}
212 	ASSERT_EQ(optlen, 8, "Unexpected NETLINK_LIST_MEMBERSHIPS value");
213 
214 	free(big_buf);
215 	close(fd);
216 	return 0;
217 err:
218 	free(big_buf);
219 	close(fd);
220 	return -1;
221 }
222 
223 static void run_test(int cgroup_fd)
224 {
225 	struct sockopt_sk *skel;
226 
227 	skel = sockopt_sk__open_and_load();
228 	if (!ASSERT_OK_PTR(skel, "skel_load"))
229 		goto cleanup;
230 
231 	skel->bss->page_size = getpagesize();
232 
233 	skel->links._setsockopt =
234 		bpf_program__attach_cgroup(skel->progs._setsockopt, cgroup_fd);
235 	if (!ASSERT_OK_PTR(skel->links._setsockopt, "setsockopt_link"))
236 		goto cleanup;
237 
238 	skel->links._getsockopt =
239 		bpf_program__attach_cgroup(skel->progs._getsockopt, cgroup_fd);
240 	if (!ASSERT_OK_PTR(skel->links._getsockopt, "getsockopt_link"))
241 		goto cleanup;
242 
243 	ASSERT_OK(getsetsockopt(), "getsetsockopt");
244 
245 cleanup:
246 	sockopt_sk__destroy(skel);
247 }
248 
/*
 * Test entry point: join the "/sockopt_sk" cgroup and, if that succeeds,
 * run the sockopt checks inside it and then close the cgroup fd.
 */
void test_sockopt_sk(void)
{
	int cg_fd = test__join_cgroup("/sockopt_sk");

	if (ASSERT_GE(cg_fd, 0, "join_cgroup /sockopt_sk")) {
		run_test(cg_fd);
		close(cg_fd);
	}
}
260