1 // SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2 // Copyright (c) 2020 Cloudflare
3 
4 #define _GNU_SOURCE
5 
6 #include <arpa/inet.h>
7 #include <string.h>
8 
9 #include <linux/pkt_cls.h>
10 
11 #include <test_progs.h>
12 
13 #include "progs/test_cls_redirect.h"
14 #include "test_cls_redirect.skel.h"
15 #include "test_cls_redirect_subprogs.skel.h"
16 
/*
 * Outer (encapsulation) destination used by every test.  Both values are in
 * host byte order here: they are converted with htonl()/htons() when written
 * into the skeleton's rodata (ENCAPSULATION_IP / ENCAPSULATION_PORT) before
 * load, and when encap_init() fills the outer IPv4 daddr / UDP dest.
 */
#define ENCAP_IP INADDR_LOOPBACK
#define ENCAP_PORT (1234)

/* NOTE(review): not referenced directly in this file; presumably consumed by
 * the CHECK() macro from test_progs.h — confirm before removing.
 */
static int duration = 0;
21 
/*
 * One endpoint of a connection: a port plus an IPv4 or IPv6 address.
 * fill_addr_port() copies both verbatim out of sockaddr_in/sockaddr_in6, so
 * port and address are in network byte order.  Which union member is
 * meaningful is decided by the owning struct tuple's family.
 */
struct addr_port {
	in_port_t port;
	union {
		struct in_addr in_addr;		/* valid for AF_INET */
		struct in6_addr in6_addr;	/* valid for AF_INET6 */
	};
};
29 
/*
 * Addresses of one test connection, oriented for a packet arriving at the
 * client socket (see set_up_conn()): @src is the peer (server) end and @dst
 * is the local (client) end.
 */
struct tuple {
	int family;		/* AF_INET or AF_INET6 */
	struct addr_port src;
	struct addr_port dst;
};
35 
/*
 * Create a socket of @type bound to @addr.  Stream sockets are additionally
 * put into the listening state.
 *
 * Returns the socket fd, or -1 after CHECK_FAIL() has flagged the failing
 * step (any socket created here is closed first).
 */
static int start_server(const struct sockaddr *addr, socklen_t len, int type)
{
	int sock;

	sock = socket(addr->sa_family, type, 0);
	if (CHECK_FAIL(sock == -1))
		return -1;

	if (!CHECK_FAIL(bind(sock, addr, len) == -1)) {
		/* Only connection-oriented sockets need listen(). */
		if (type != SOCK_STREAM || !CHECK_FAIL(listen(sock, 128) == -1))
			return sock;
	}

	close(sock);
	return -1;
}
52 
/*
 * Create a socket of @type and connect it to @addr.
 *
 * Returns the connected fd, or -1 after CHECK_FAIL() has flagged the failing
 * step (the socket is closed if connect() fails).
 */
static int connect_to_server(const struct sockaddr *addr, socklen_t len,
			     int type)
{
	int sock = socket(addr->sa_family, type, 0);

	if (CHECK_FAIL(sock == -1))
		return -1;

	if (!CHECK_FAIL(connect(sock, addr, len)))
		return sock;

	close(sock);
	return -1;
}
68 
69 static bool fill_addr_port(const struct sockaddr *sa, struct addr_port *ap)
70 {
71 	const struct sockaddr_in6 *in6;
72 	const struct sockaddr_in *in;
73 
74 	switch (sa->sa_family) {
75 	case AF_INET:
76 		in = (const struct sockaddr_in *)sa;
77 		ap->in_addr = in->sin_addr;
78 		ap->port = in->sin_port;
79 		return true;
80 
81 	case AF_INET6:
82 		in6 = (const struct sockaddr_in6 *)sa;
83 		ap->in6_addr = in6->sin6_addr;
84 		ap->port = in6->sin6_port;
85 		return true;
86 
87 	default:
88 		return false;
89 	}
90 }
91 
92 static bool set_up_conn(const struct sockaddr *addr, socklen_t len, int type,
93 			int *server, int *conn, struct tuple *tuple)
94 {
95 	struct sockaddr_storage ss;
96 	socklen_t slen = sizeof(ss);
97 	struct sockaddr *sa = (struct sockaddr *)&ss;
98 
99 	*server = start_server(addr, len, type);
100 	if (*server < 0)
101 		return false;
102 
103 	if (CHECK_FAIL(getsockname(*server, sa, &slen)))
104 		goto close_server;
105 
106 	*conn = connect_to_server(sa, slen, type);
107 	if (*conn < 0)
108 		goto close_server;
109 
110 	/* We want to simulate packets arriving at conn, so we have to
111 	 * swap src and dst.
112 	 */
113 	slen = sizeof(ss);
114 	if (CHECK_FAIL(getsockname(*conn, sa, &slen)))
115 		goto close_conn;
116 
117 	if (CHECK_FAIL(!fill_addr_port(sa, &tuple->dst)))
118 		goto close_conn;
119 
120 	slen = sizeof(ss);
121 	if (CHECK_FAIL(getpeername(*conn, sa, &slen)))
122 		goto close_conn;
123 
124 	if (CHECK_FAIL(!fill_addr_port(sa, &tuple->src)))
125 		goto close_conn;
126 
127 	tuple->family = ss.ss_family;
128 	return true;
129 
130 close_conn:
131 	close(*conn);
132 	*conn = -1;
133 close_server:
134 	close(*server);
135 	*server = -1;
136 	return false;
137 }
138 
139 static socklen_t prepare_addr(struct sockaddr_storage *addr, int family)
140 {
141 	struct sockaddr_in *addr4;
142 	struct sockaddr_in6 *addr6;
143 
144 	switch (family) {
145 	case AF_INET:
146 		addr4 = (struct sockaddr_in *)addr;
147 		memset(addr4, 0, sizeof(*addr4));
148 		addr4->sin_family = family;
149 		addr4->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
150 		return sizeof(*addr4);
151 	case AF_INET6:
152 		addr6 = (struct sockaddr_in6 *)addr;
153 		memset(addr6, 0, sizeof(*addr6));
154 		addr6->sin6_family = family;
155 		addr6->sin6_addr = in6addr_loopback;
156 		return sizeof(*addr6);
157 	default:
158 		fprintf(stderr, "Invalid family %d", family);
159 		return 0;
160 	}
161 }
162 
163 static bool was_decapsulated(struct bpf_prog_test_run_attr *tattr)
164 {
165 	return tattr->data_size_out < tattr->data_size_in;
166 }
167 
/*
 * Inner L4 protocol of a test packet.  The values double as the first index
 * of the per-kind server/connection/tuple arrays, so UDP and TCP must stay 0
 * and 1, with __NR_KIND as their count.
 */
enum type {
	UDP = 0,
	TCP = 1,
	__NR_KIND,	/* number of kinds; used as an array dimension */
};

/* Number of hop addresses appended after the GUE encapsulation headers. */
enum hops {
	NO_HOPS = 0,
	ONE_HOP = 1,
};

/* Flag set on the inner TCP header; UDP scenarios use NONE. */
enum flags {
	NONE = 0,
	SYN = 1,
	ACK = 2,
};

/*
 * Whether the inner 4-tuple matches a socket created by set_up_conn()
 * (KNOWN_CONN) or has its source port altered so it does not (UNKNOWN_CONN).
 */
enum conn {
	KNOWN_CONN = 0,
	UNKNOWN_CONN = 1,
};

/*
 * Expected outcome.  Both require TC_ACT_REDIRECT; ACCEPT additionally
 * expects the packet to come back decapsulated, FORWARD expects it not to.
 */
enum result {
	ACCEPT = 0,
	FORWARD = 1,
};
194 
/*
 * One test scenario.  Field order matters for any positional initializer:
 * type, result, conn, hops, flags.
 */
struct test_cfg {
	enum type type;		/* inner L4 protocol */
	enum result result;	/* expected verdict */
	enum conn conn;		/* does the inner tuple match an open socket? */
	enum hops hops;		/* hop addresses after the GUE headers */
	enum flags flags;	/* flag set on the inner TCP header */
};
202 
203 static int test_str(void *buf, size_t len, const struct test_cfg *test,
204 		    int family)
205 {
206 	const char *family_str, *type, *conn, *hops, *result, *flags;
207 
208 	family_str = "IPv4";
209 	if (family == AF_INET6)
210 		family_str = "IPv6";
211 
212 	type = "TCP";
213 	if (test->type == UDP)
214 		type = "UDP";
215 
216 	conn = "known";
217 	if (test->conn == UNKNOWN_CONN)
218 		conn = "unknown";
219 
220 	hops = "no hops";
221 	if (test->hops == ONE_HOP)
222 		hops = "one hop";
223 
224 	result = "accept";
225 	if (test->result == FORWARD)
226 		result = "forward";
227 
228 	flags = "none";
229 	if (test->flags == SYN)
230 		flags = "SYN";
231 	else if (test->flags == ACK)
232 		flags = "ACK";
233 
234 	return snprintf(buf, len, "%s %s %s %s (%s, flags: %s)", family_str,
235 			type, result, conn, hops, flags);
236 }
237 
238 static struct test_cfg tests[] = {
239 	{ TCP, ACCEPT, UNKNOWN_CONN, NO_HOPS, SYN },
240 	{ TCP, ACCEPT, UNKNOWN_CONN, NO_HOPS, ACK },
241 	{ TCP, FORWARD, UNKNOWN_CONN, ONE_HOP, ACK },
242 	{ TCP, ACCEPT, KNOWN_CONN, ONE_HOP, ACK },
243 	{ UDP, ACCEPT, UNKNOWN_CONN, NO_HOPS, NONE },
244 	{ UDP, FORWARD, UNKNOWN_CONN, ONE_HOP, NONE },
245 	{ UDP, ACCEPT, KNOWN_CONN, ONE_HOP, NONE },
246 };
247 
248 static void encap_init(encap_headers_t *encap, uint8_t hop_count, uint8_t proto)
249 {
250 	const uint8_t hlen =
251 		(sizeof(struct guehdr) / sizeof(uint32_t)) + hop_count;
252 	*encap = (encap_headers_t){
253 		.eth = { .h_proto = htons(ETH_P_IP) },
254 		.ip = {
255 			.ihl = 5,
256 			.version = 4,
257 			.ttl = IPDEFTTL,
258 			.protocol = IPPROTO_UDP,
259 			.daddr = htonl(ENCAP_IP)
260 		},
261 		.udp = {
262 			.dest = htons(ENCAP_PORT),
263 		},
264 		.gue = {
265 			.hlen = hlen,
266 			.proto_ctype = proto
267 		},
268 		.unigue = {
269 			.hop_count = hop_count
270 		},
271 	};
272 }
273 
274 static size_t build_input(const struct test_cfg *test, void *const buf,
275 			  const struct tuple *tuple)
276 {
277 	in_port_t sport = tuple->src.port;
278 	encap_headers_t encap;
279 	struct iphdr ip;
280 	struct ipv6hdr ipv6;
281 	struct tcphdr tcp;
282 	struct udphdr udp;
283 	struct in_addr next_hop;
284 	uint8_t *p = buf;
285 	int proto;
286 
287 	proto = IPPROTO_IPIP;
288 	if (tuple->family == AF_INET6)
289 		proto = IPPROTO_IPV6;
290 
291 	encap_init(&encap, test->hops == ONE_HOP ? 1 : 0, proto);
292 	p = mempcpy(p, &encap, sizeof(encap));
293 
294 	if (test->hops == ONE_HOP) {
295 		next_hop = (struct in_addr){ .s_addr = htonl(0x7f000002) };
296 		p = mempcpy(p, &next_hop, sizeof(next_hop));
297 	}
298 
299 	proto = IPPROTO_TCP;
300 	if (test->type == UDP)
301 		proto = IPPROTO_UDP;
302 
303 	switch (tuple->family) {
304 	case AF_INET:
305 		ip = (struct iphdr){
306 			.ihl = 5,
307 			.version = 4,
308 			.ttl = IPDEFTTL,
309 			.protocol = proto,
310 			.saddr = tuple->src.in_addr.s_addr,
311 			.daddr = tuple->dst.in_addr.s_addr,
312 		};
313 		p = mempcpy(p, &ip, sizeof(ip));
314 		break;
315 	case AF_INET6:
316 		ipv6 = (struct ipv6hdr){
317 			.version = 6,
318 			.hop_limit = IPDEFTTL,
319 			.nexthdr = proto,
320 			.saddr = tuple->src.in6_addr,
321 			.daddr = tuple->dst.in6_addr,
322 		};
323 		p = mempcpy(p, &ipv6, sizeof(ipv6));
324 		break;
325 	default:
326 		return 0;
327 	}
328 
329 	if (test->conn == UNKNOWN_CONN)
330 		sport--;
331 
332 	switch (test->type) {
333 	case TCP:
334 		tcp = (struct tcphdr){
335 			.source = sport,
336 			.dest = tuple->dst.port,
337 		};
338 		if (test->flags == SYN)
339 			tcp.syn = true;
340 		if (test->flags == ACK)
341 			tcp.ack = true;
342 		p = mempcpy(p, &tcp, sizeof(tcp));
343 		break;
344 	case UDP:
345 		udp = (struct udphdr){
346 			.source = sport,
347 			.dest = tuple->dst.port,
348 		};
349 		p = mempcpy(p, &udp, sizeof(udp));
350 		break;
351 	default:
352 		return 0;
353 	}
354 
355 	return (void *)p - buf;
356 }
357 
/*
 * Close the first @n descriptors of @fds, in order.  Entries <= 0 are
 * skipped: callers zero-initialise their fd arrays and mark failed slots
 * with -1, so neither value refers to a descriptor that needs closing.
 */
static void close_fds(int *fds, int n)
{
	while (n-- > 0) {
		int fd = *fds++;

		if (fd > 0)
			close(fd);
	}
}
366 
367 static void test_cls_redirect_common(struct bpf_program *prog)
368 {
369 	struct bpf_prog_test_run_attr tattr = {};
370 	int families[] = { AF_INET, AF_INET6 };
371 	struct sockaddr_storage ss;
372 	struct sockaddr *addr;
373 	socklen_t slen;
374 	int i, j, err;
375 	int servers[__NR_KIND][ARRAY_SIZE(families)] = {};
376 	int conns[__NR_KIND][ARRAY_SIZE(families)] = {};
377 	struct tuple tuples[__NR_KIND][ARRAY_SIZE(families)];
378 
379 	addr = (struct sockaddr *)&ss;
380 	for (i = 0; i < ARRAY_SIZE(families); i++) {
381 		slen = prepare_addr(&ss, families[i]);
382 		if (CHECK_FAIL(!slen))
383 			goto cleanup;
384 
385 		if (CHECK_FAIL(!set_up_conn(addr, slen, SOCK_DGRAM,
386 					    &servers[UDP][i], &conns[UDP][i],
387 					    &tuples[UDP][i])))
388 			goto cleanup;
389 
390 		if (CHECK_FAIL(!set_up_conn(addr, slen, SOCK_STREAM,
391 					    &servers[TCP][i], &conns[TCP][i],
392 					    &tuples[TCP][i])))
393 			goto cleanup;
394 	}
395 
396 	tattr.prog_fd = bpf_program__fd(prog);
397 	for (i = 0; i < ARRAY_SIZE(tests); i++) {
398 		struct test_cfg *test = &tests[i];
399 
400 		for (j = 0; j < ARRAY_SIZE(families); j++) {
401 			struct tuple *tuple = &tuples[test->type][j];
402 			char input[256];
403 			char tmp[256];
404 
405 			test_str(tmp, sizeof(tmp), test, tuple->family);
406 			if (!test__start_subtest(tmp))
407 				continue;
408 
409 			tattr.data_out = tmp;
410 			tattr.data_size_out = sizeof(tmp);
411 
412 			tattr.data_in = input;
413 			tattr.data_size_in = build_input(test, input, tuple);
414 			if (CHECK_FAIL(!tattr.data_size_in))
415 				continue;
416 
417 			err = bpf_prog_test_run_xattr(&tattr);
418 			if (CHECK_FAIL(err))
419 				continue;
420 
421 			if (tattr.retval != TC_ACT_REDIRECT) {
422 				PRINT_FAIL("expected TC_ACT_REDIRECT, got %d\n",
423 					   tattr.retval);
424 				continue;
425 			}
426 
427 			switch (test->result) {
428 			case ACCEPT:
429 				if (CHECK_FAIL(!was_decapsulated(&tattr)))
430 					continue;
431 				break;
432 			case FORWARD:
433 				if (CHECK_FAIL(was_decapsulated(&tattr)))
434 					continue;
435 				break;
436 			default:
437 				PRINT_FAIL("unknown result %d\n", test->result);
438 				continue;
439 			}
440 		}
441 	}
442 
443 cleanup:
444 	close_fds((int *)servers, sizeof(servers) / sizeof(servers[0][0]));
445 	close_fds((int *)conns, sizeof(conns) / sizeof(conns[0][0]));
446 }
447 
448 static void test_cls_redirect_inlined(void)
449 {
450 	struct test_cls_redirect *skel;
451 	int err;
452 
453 	skel = test_cls_redirect__open();
454 	if (CHECK(!skel, "skel_open", "failed\n"))
455 		return;
456 
457 	skel->rodata->ENCAPSULATION_IP = htonl(ENCAP_IP);
458 	skel->rodata->ENCAPSULATION_PORT = htons(ENCAP_PORT);
459 
460 	err = test_cls_redirect__load(skel);
461 	if (CHECK(err, "skel_load", "failed: %d\n", err))
462 		goto cleanup;
463 
464 	test_cls_redirect_common(skel->progs.cls_redirect);
465 
466 cleanup:
467 	test_cls_redirect__destroy(skel);
468 }
469 
470 static void test_cls_redirect_subprogs(void)
471 {
472 	struct test_cls_redirect_subprogs *skel;
473 	int err;
474 
475 	skel = test_cls_redirect_subprogs__open();
476 	if (CHECK(!skel, "skel_open", "failed\n"))
477 		return;
478 
479 	skel->rodata->ENCAPSULATION_IP = htonl(ENCAP_IP);
480 	skel->rodata->ENCAPSULATION_PORT = htons(ENCAP_PORT);
481 
482 	err = test_cls_redirect_subprogs__load(skel);
483 	if (CHECK(err, "skel_load", "failed: %d\n", err))
484 		goto cleanup;
485 
486 	test_cls_redirect_common(skel->progs.cls_redirect);
487 
488 cleanup:
489 	test_cls_redirect_subprogs__destroy(skel);
490 }
491 
492 void test_cls_redirect(void)
493 {
494 	if (test__start_subtest("cls_redirect_inlined"))
495 		test_cls_redirect_inlined();
496 	if (test__start_subtest("cls_redirect_subprogs"))
497 		test_cls_redirect_subprogs();
498 }
499