1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright 2013 Google Inc.
4  * Author: Willem de Bruijn (willemb@google.com)
5  *
6  * A basic test of packet socket fanout behavior.
7  *
8  * Control:
9  * - create fanout fails as expected with illegal flag combinations
10  * - join   fanout fails as expected with diverging types or flags
11  *
12  * Datapath:
13  *   Open a pair of packet sockets and a pair of INET sockets, send a known
14  *   number of packets across the two INET sockets and count the number of
15  *   packets enqueued onto the two packet sockets.
16  *
17  *   The test currently runs for
18  *   - PACKET_FANOUT_HASH
19  *   - PACKET_FANOUT_HASH with PACKET_FANOUT_FLAG_ROLLOVER
20  *   - PACKET_FANOUT_LB
21  *   - PACKET_FANOUT_CPU
22  *   - PACKET_FANOUT_ROLLOVER
23  *   - PACKET_FANOUT_CBPF
24  *   - PACKET_FANOUT_EBPF
25  *
26  * Todo:
27  * - functionality: PACKET_FANOUT_FLAG_DEFRAG
28  */
29 
30 #define _GNU_SOURCE		/* for sched_setaffinity */
31 
32 #include <arpa/inet.h>
33 #include <errno.h>
34 #include <fcntl.h>
35 #include <linux/unistd.h>	/* for __NR_bpf */
36 #include <linux/filter.h>
37 #include <linux/bpf.h>
38 #include <linux/if_packet.h>
39 #include <net/if.h>
40 #include <net/ethernet.h>
41 #include <netinet/ip.h>
42 #include <netinet/udp.h>
43 #include <poll.h>
44 #include <sched.h>
45 #include <stdint.h>
46 #include <stdio.h>
47 #include <stdlib.h>
48 #include <string.h>
49 #include <sys/mman.h>
50 #include <sys/socket.h>
51 #include <sys/stat.h>
52 #include <sys/types.h>
53 #include <unistd.h>
54 
55 #include "psock_lib.h"
56 
/* Number of frames (and of blocks) requested for each packet socket rx ring;
 * also the upper bound on what sock_fanout_read_ring() can count. */
#define RING_NUM_FRAMES			20

/* When nonzero, sock_fanout_open() configures fanout through
 * struct fanout_args and passes this value as max_num_members.
 * When zero, it uses the packed int form of PACKET_FANOUT instead. */
static uint32_t cfg_max_num_members;
60 
61 /* Open a socket in a given fanout mode.
62  * @return -1 if mode is bad, a valid socket otherwise */
63 static int sock_fanout_open(uint16_t typeflags, uint16_t group_id)
64 {
65 	struct sockaddr_ll addr = {0};
66 	struct fanout_args args;
67 	int fd, val, err;
68 
69 	fd = socket(PF_PACKET, SOCK_RAW, 0);
70 	if (fd < 0) {
71 		perror("socket packet");
72 		exit(1);
73 	}
74 
75 	pair_udp_setfilter(fd);
76 
77 	addr.sll_family = AF_PACKET;
78 	addr.sll_protocol = htons(ETH_P_IP);
79 	addr.sll_ifindex = if_nametoindex("lo");
80 	if (addr.sll_ifindex == 0) {
81 		perror("if_nametoindex");
82 		exit(1);
83 	}
84 	if (bind(fd, (void *) &addr, sizeof(addr))) {
85 		perror("bind packet");
86 		exit(1);
87 	}
88 
89 	if (cfg_max_num_members) {
90 		args.id = group_id;
91 		args.type_flags = typeflags;
92 		args.max_num_members = cfg_max_num_members;
93 		err = setsockopt(fd, SOL_PACKET, PACKET_FANOUT, &args,
94 				 sizeof(args));
95 	} else {
96 		val = (((int) typeflags) << 16) | group_id;
97 		err = setsockopt(fd, SOL_PACKET, PACKET_FANOUT, &val,
98 				 sizeof(val));
99 	}
100 	if (err) {
101 		if (close(fd)) {
102 			perror("close packet");
103 			exit(1);
104 		}
105 		return -1;
106 	}
107 
108 	return fd;
109 }
110 
/* Attach a classic BPF program as the fanout data of @fd.
 * The program loads the byte at packet offset 80 and returns it.
 * Terminates the test if the socket option cannot be set. */
static void sock_fanout_set_cbpf(int fd)
{
	struct sock_filter insns[] = {
		BPF_STMT(BPF_LD + BPF_B + BPF_ABS, 80),	/* A = pkt[80] */
		BPF_STMT(BPF_RET + BPF_A, 0),		/* return A */
	};
	struct sock_fprog fprog;

	fprog.len = sizeof(insns) / sizeof(insns[0]);
	fprog.filter = insns;

	if (setsockopt(fd, SOL_PACKET, PACKET_FANOUT_DATA, &fprog,
		       sizeof(fprog)) == 0)
		return;

	perror("fanout data cbpf");
	exit(1);
}
128 
/* Read back the fanout configuration of @fd.
 *
 * The PACKET_FANOUT option value is split into its upper 16 bits, stored in
 * @typeflags, and its lower 16 bits, stored in @group_id.
 * Terminates the test if getsockopt() fails. */
static void sock_fanout_getopts(int fd, uint16_t *typeflags, uint16_t *group_id)
{
	/* unsigned so that the shift below is well defined for any value */
	uint32_t sockopt = 0;
	socklen_t sockopt_len = sizeof(sockopt);

	if (getsockopt(fd, SOL_PACKET, PACKET_FANOUT,
		       &sockopt, &sockopt_len)) {
		perror("failed to getsockopt");
		exit(1);
	}
	*typeflags = sockopt >> 16;
	*group_id = sockopt & 0xffff;
}
142 
143 static void sock_fanout_set_ebpf(int fd)
144 {
145 	static char log_buf[65536];
146 
147 	const int len_off = __builtin_offsetof(struct __sk_buff, len);
148 	struct bpf_insn prog[] = {
149 		{ BPF_ALU64 | BPF_MOV | BPF_X,   6, 1, 0, 0 },
150 		{ BPF_LDX   | BPF_W   | BPF_MEM, 0, 6, len_off, 0 },
151 		{ BPF_JMP   | BPF_JGE | BPF_K,   0, 0, 1, DATA_LEN },
152 		{ BPF_JMP   | BPF_JA  | BPF_K,   0, 0, 4, 0 },
153 		{ BPF_LD    | BPF_B   | BPF_ABS, 0, 0, 0, 0x50 },
154 		{ BPF_JMP   | BPF_JEQ | BPF_K,   0, 0, 2, DATA_CHAR },
155 		{ BPF_JMP   | BPF_JEQ | BPF_K,   0, 0, 1, DATA_CHAR_1 },
156 		{ BPF_ALU   | BPF_MOV | BPF_K,   0, 0, 0, 0 },
157 		{ BPF_JMP   | BPF_EXIT,          0, 0, 0, 0 }
158 	};
159 	union bpf_attr attr;
160 	int pfd;
161 
162 	memset(&attr, 0, sizeof(attr));
163 	attr.prog_type = BPF_PROG_TYPE_SOCKET_FILTER;
164 	attr.insns = (unsigned long) prog;
165 	attr.insn_cnt = sizeof(prog) / sizeof(prog[0]);
166 	attr.license = (unsigned long) "GPL";
167 	attr.log_buf = (unsigned long) log_buf,
168 	attr.log_size = sizeof(log_buf),
169 	attr.log_level = 1,
170 
171 	pfd = syscall(__NR_bpf, BPF_PROG_LOAD, &attr, sizeof(attr));
172 	if (pfd < 0) {
173 		perror("bpf");
174 		fprintf(stderr, "bpf verifier:\n%s\n", log_buf);
175 		exit(1);
176 	}
177 
178 	if (setsockopt(fd, SOL_PACKET, PACKET_FANOUT_DATA, &pfd, sizeof(pfd))) {
179 		perror("fanout data ebpf");
180 		exit(1);
181 	}
182 
183 	if (close(pfd)) {
184 		perror("close ebpf");
185 		exit(1);
186 	}
187 }
188 
189 static char *sock_fanout_open_ring(int fd)
190 {
191 	struct tpacket_req req = {
192 		.tp_block_size = getpagesize(),
193 		.tp_frame_size = getpagesize(),
194 		.tp_block_nr   = RING_NUM_FRAMES,
195 		.tp_frame_nr   = RING_NUM_FRAMES,
196 	};
197 	char *ring;
198 	int val = TPACKET_V2;
199 
200 	if (setsockopt(fd, SOL_PACKET, PACKET_VERSION, (void *) &val,
201 		       sizeof(val))) {
202 		perror("packetsock ring setsockopt version");
203 		exit(1);
204 	}
205 	if (setsockopt(fd, SOL_PACKET, PACKET_RX_RING, (void *) &req,
206 		       sizeof(req))) {
207 		perror("packetsock ring setsockopt");
208 		exit(1);
209 	}
210 
211 	ring = mmap(0, req.tp_block_size * req.tp_block_nr,
212 		    PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
213 	if (ring == MAP_FAILED) {
214 		perror("packetsock ring mmap");
215 		exit(1);
216 	}
217 
218 	return ring;
219 }
220 
221 static int sock_fanout_read_ring(int fd, void *ring)
222 {
223 	struct tpacket2_hdr *header = ring;
224 	int count = 0;
225 
226 	while (count < RING_NUM_FRAMES && header->tp_status & TP_STATUS_USER) {
227 		count++;
228 		header = ring + (count * getpagesize());
229 	}
230 
231 	return count;
232 }
233 
/* Compare the queue lengths of the two packet sockets against @expect.
 *
 * The pair of observed counts may match @expect in either order.
 * Returns 0 on a match, 1 (after printing a warning) otherwise. */
static int sock_fanout_read(int fds[], char *rings[], const int expect[])
{
	const int got0 = sock_fanout_read_ring(fds[0], rings[0]);
	const int got1 = sock_fanout_read_ring(fds[1], rings[1]);
	int same_order, swapped;

	fprintf(stderr, "info: count=%d,%d, expect=%d,%d\n",
			got0, got1, expect[0], expect[1]);

	same_order = got0 == expect[0] && got1 == expect[1];
	swapped = got0 == expect[1] && got1 == expect[0];
	if (same_order || swapped)
		return 0;

	fprintf(stderr, "warning: incorrect queue lengths\n");
	return 1;
}
252 
/* Control test: requesting rollover both as fanout mode and as fanout flag
 * is an illegal combination and must be refused. */
static void test_control_single(void)
{
	const uint16_t dual_rollover = PACKET_FANOUT_ROLLOVER |
				       PACKET_FANOUT_FLAG_ROLLOVER;

	fprintf(stderr, "test: control single socket\n");

	if (sock_fanout_open(dual_rollover, 0) == -1)
		return;

	fprintf(stderr, "ERROR: opened socket with dual rollover\n");
	exit(1);
}
264 
/* Control test: joining an existing fanout group with a different mode or
 * different flags must fail, joining with identical settings must succeed. */
static void test_control_group(void)
{
	/* join attempts on group 0 that are expected to be refused */
	static const struct {
		uint16_t typeflags;
		const char *complaint;
	} mismatches[] = {
		{ PACKET_FANOUT_HASH | PACKET_FANOUT_FLAG_DEFRAG,
		  "ERROR: joined group with wrong flag defrag\n" },
		{ PACKET_FANOUT_HASH | PACKET_FANOUT_FLAG_ROLLOVER,
		  "ERROR: joined group with wrong flag ro\n" },
		{ PACKET_FANOUT_CPU,
		  "ERROR: joined group with wrong mode\n" },
	};
	int creator, joiner;
	size_t i;

	fprintf(stderr, "test: control multiple sockets\n");

	creator = sock_fanout_open(PACKET_FANOUT_HASH, 0);
	if (creator == -1) {
		fprintf(stderr, "ERROR: failed to open HASH socket\n");
		exit(1);
	}

	for (i = 0; i < sizeof(mismatches) / sizeof(mismatches[0]); i++) {
		if (sock_fanout_open(mismatches[i].typeflags, 0) == -1)
			continue;

		fputs(mismatches[i].complaint, stderr);
		exit(1);
	}

	joiner = sock_fanout_open(PACKET_FANOUT_HASH, 0);
	if (joiner == -1) {
		fprintf(stderr, "ERROR: failed to join group\n");
		exit(1);
	}

	if (close(joiner) || close(creator)) {
		fprintf(stderr, "ERROR: closing sockets\n");
		exit(1);
	}
}
301 
302 /* Test illegal max_num_members values */
303 static void test_control_group_max_num_members(void)
304 {
305 	int fds[3];
306 
307 	fprintf(stderr, "test: control multiple sockets, max_num_members\n");
308 
309 	/* expected failure on greater than PACKET_FANOUT_MAX */
310 	cfg_max_num_members = (1 << 16) + 1;
311 	if (sock_fanout_open(PACKET_FANOUT_HASH, 0) != -1) {
312 		fprintf(stderr, "ERROR: max_num_members > PACKET_FANOUT_MAX\n");
313 		exit(1);
314 	}
315 
316 	cfg_max_num_members = 256;
317 	fds[0] = sock_fanout_open(PACKET_FANOUT_HASH, 0);
318 	if (fds[0] == -1) {
319 		fprintf(stderr, "ERROR: failed open\n");
320 		exit(1);
321 	}
322 
323 	/* expected failure on joining group with different max_num_members */
324 	cfg_max_num_members = 257;
325 	if (sock_fanout_open(PACKET_FANOUT_HASH, 0) != -1) {
326 		fprintf(stderr, "ERROR: set different max_num_members\n");
327 		exit(1);
328 	}
329 
330 	/* success on joining group with same max_num_members */
331 	cfg_max_num_members = 256;
332 	fds[1] = sock_fanout_open(PACKET_FANOUT_HASH, 0);
333 	if (fds[1] == -1) {
334 		fprintf(stderr, "ERROR: failed to join group\n");
335 		exit(1);
336 	}
337 
338 	/* success on joining group with max_num_members unspecified */
339 	cfg_max_num_members = 0;
340 	fds[2] = sock_fanout_open(PACKET_FANOUT_HASH, 0);
341 	if (fds[2] == -1) {
342 		fprintf(stderr, "ERROR: failed to join group\n");
343 		exit(1);
344 	}
345 
346 	if (close(fds[2]) || close(fds[1]) || close(fds[0])) {
347 		fprintf(stderr, "ERROR: closing sockets\n");
348 		exit(1);
349 	}
350 }
351 
/* Test creation of fanout groups with unique ids
 * (PACKET_FANOUT_FLAG_UNIQUEID).
 *
 * Checks that a group created this way reports plain PACKET_FANOUT_HASH as
 * its typeflags, can be joined by id with a matching mode only, and that
 * requesting a unique id together with an explicit group id is refused. */
static void test_unique_fanout_group_ids(void)
{
	const uint16_t hash_unique = PACKET_FANOUT_HASH |
				     PACKET_FANOUT_FLAG_UNIQUEID;
	uint16_t reported_flags, id_a, id_b;
	int group_a, member_a, group_b;

	fprintf(stderr, "test: unique ids\n");

	group_a = sock_fanout_open(hash_unique, 0);
	if (group_a == -1) {
		fprintf(stderr, "ERROR: failed to create a unique id group.\n");
		exit(1);
	}

	sock_fanout_getopts(group_a, &reported_flags, &id_a);
	if (reported_flags != PACKET_FANOUT_HASH) {
		fprintf(stderr, "ERROR: unexpected typeflags %x\n",
			reported_flags);
		exit(1);
	}

	/* joining the new group by id with another mode must fail */
	if (sock_fanout_open(PACKET_FANOUT_CPU, id_a) != -1) {
		fprintf(stderr, "ERROR: joined group with wrong type.\n");
		exit(1);
	}

	/* joining it by id with the same mode must work */
	member_a = sock_fanout_open(PACKET_FANOUT_HASH, id_a);
	if (member_a == -1) {
		fprintf(stderr,
			"ERROR: failed to join previously created group.\n");
		exit(1);
	}

	group_b = sock_fanout_open(hash_unique, 0);
	if (group_b == -1) {
		fprintf(stderr,
			"ERROR: failed to create a second unique id group.\n");
		exit(1);
	}

	/* a unique id request that names a group id must be refused */
	sock_fanout_getopts(group_b, &reported_flags, &id_b);
	if (sock_fanout_open(hash_unique, id_b) != -1) {
		fprintf(stderr,
			"ERROR: specified a group id when requesting unique id\n");
		exit(1);
	}

	if (close(group_a) || close(member_a) || close(group_b)) {
		fprintf(stderr, "ERROR: closing sockets\n");
		exit(1);
	}
}
406 
407 static int test_datapath(uint16_t typeflags, int port_off,
408 			 const int expect1[], const int expect2[])
409 {
410 	const int expect0[] = { 0, 0 };
411 	char *rings[2];
412 	uint8_t type = typeflags & 0xFF;
413 	int fds[2], fds_udp[2][2], ret;
414 
415 	fprintf(stderr, "\ntest: datapath 0x%hx ports %hu,%hu\n",
416 		typeflags, (uint16_t)PORT_BASE,
417 		(uint16_t)(PORT_BASE + port_off));
418 
419 	fds[0] = sock_fanout_open(typeflags, 0);
420 	fds[1] = sock_fanout_open(typeflags, 0);
421 	if (fds[0] == -1 || fds[1] == -1) {
422 		fprintf(stderr, "ERROR: failed open\n");
423 		exit(1);
424 	}
425 	if (type == PACKET_FANOUT_CBPF)
426 		sock_fanout_set_cbpf(fds[0]);
427 	else if (type == PACKET_FANOUT_EBPF)
428 		sock_fanout_set_ebpf(fds[0]);
429 
430 	rings[0] = sock_fanout_open_ring(fds[0]);
431 	rings[1] = sock_fanout_open_ring(fds[1]);
432 	pair_udp_open(fds_udp[0], PORT_BASE);
433 	pair_udp_open(fds_udp[1], PORT_BASE + port_off);
434 	sock_fanout_read(fds, rings, expect0);
435 
436 	/* Send data, but not enough to overflow a queue */
437 	pair_udp_send(fds_udp[0], 15);
438 	pair_udp_send_char(fds_udp[1], 5, DATA_CHAR_1);
439 	ret = sock_fanout_read(fds, rings, expect1);
440 
441 	/* Send more data, overflow the queue */
442 	pair_udp_send_char(fds_udp[0], 15, DATA_CHAR_1);
443 	/* TODO: ensure consistent order between expect1 and expect2 */
444 	ret |= sock_fanout_read(fds, rings, expect2);
445 
446 	if (munmap(rings[1], RING_NUM_FRAMES * getpagesize()) ||
447 	    munmap(rings[0], RING_NUM_FRAMES * getpagesize())) {
448 		fprintf(stderr, "close rings\n");
449 		exit(1);
450 	}
451 	if (close(fds_udp[1][1]) || close(fds_udp[1][0]) ||
452 	    close(fds_udp[0][1]) || close(fds_udp[0][0]) ||
453 	    close(fds[1]) || close(fds[0])) {
454 		fprintf(stderr, "close datapath\n");
455 		exit(1);
456 	}
457 
458 	return ret;
459 }
460 
/* Restrict the calling process to CPU @cpuid.
 *
 * Returns 0 on success and 1 if sched_setaffinity() fails with EINVAL.
 * Any other failure terminates the test. */
static int set_cpuaffinity(int cpuid)
{
	cpu_set_t only_cpu;

	CPU_ZERO(&only_cpu);
	CPU_SET(cpuid, &only_cpu);

	if (sched_setaffinity(0, sizeof(only_cpu), &only_cpu) == 0)
		return 0;

	if (errno == EINVAL)
		return 1;

	fprintf(stderr, "setaffinity %d\n", cpuid);
	exit(1);
}
477 
478 int main(int argc, char **argv)
479 {
480 	const int expect_hash[2][2]	= { { 15, 5 },  { 20, 5 } };
481 	const int expect_hash_rb[2][2]	= { { 15, 5 },  { 20, 15 } };
482 	const int expect_lb[2][2]	= { { 10, 10 }, { 18, 17 } };
483 	const int expect_rb[2][2]	= { { 15, 5 },  { 20, 15 } };
484 	const int expect_cpu0[2][2]	= { { 20, 0 },  { 20, 0 } };
485 	const int expect_cpu1[2][2]	= { { 0, 20 },  { 0, 20 } };
486 	const int expect_bpf[2][2]	= { { 15, 5 },  { 15, 20 } };
487 	const int expect_uniqueid[2][2] = { { 20, 20},  { 20, 20 } };
488 	int port_off = 2, tries = 20, ret;
489 
490 	test_control_single();
491 	test_control_group();
492 	test_control_group_max_num_members();
493 	test_unique_fanout_group_ids();
494 
495 	/* PACKET_FANOUT_MAX */
496 	cfg_max_num_members = 1 << 16;
497 	/* find a set of ports that do not collide onto the same socket */
498 	ret = test_datapath(PACKET_FANOUT_HASH, port_off,
499 			    expect_hash[0], expect_hash[1]);
500 	while (ret) {
501 		fprintf(stderr, "info: trying alternate ports (%d)\n", tries);
502 		ret = test_datapath(PACKET_FANOUT_HASH, ++port_off,
503 				    expect_hash[0], expect_hash[1]);
504 		if (!--tries) {
505 			fprintf(stderr, "too many collisions\n");
506 			return 1;
507 		}
508 	}
509 
510 	ret |= test_datapath(PACKET_FANOUT_HASH | PACKET_FANOUT_FLAG_ROLLOVER,
511 			     port_off, expect_hash_rb[0], expect_hash_rb[1]);
512 	ret |= test_datapath(PACKET_FANOUT_LB,
513 			     port_off, expect_lb[0], expect_lb[1]);
514 	ret |= test_datapath(PACKET_FANOUT_ROLLOVER,
515 			     port_off, expect_rb[0], expect_rb[1]);
516 
517 	ret |= test_datapath(PACKET_FANOUT_CBPF,
518 			     port_off, expect_bpf[0], expect_bpf[1]);
519 	ret |= test_datapath(PACKET_FANOUT_EBPF,
520 			     port_off, expect_bpf[0], expect_bpf[1]);
521 
522 	set_cpuaffinity(0);
523 	ret |= test_datapath(PACKET_FANOUT_CPU, port_off,
524 			     expect_cpu0[0], expect_cpu0[1]);
525 	if (!set_cpuaffinity(1))
526 		/* TODO: test that choice alternates with previous */
527 		ret |= test_datapath(PACKET_FANOUT_CPU, port_off,
528 				     expect_cpu1[0], expect_cpu1[1]);
529 
530 	ret |= test_datapath(PACKET_FANOUT_FLAG_UNIQUEID, port_off,
531 			     expect_uniqueid[0], expect_uniqueid[1]);
532 
533 	if (ret)
534 		return 1;
535 
536 	printf("OK. All tests passed\n");
537 	return 0;
538 }
539