1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * ipsec.c - Check xfrm on veth inside a net-ns.
4  * Copyright (c) 2018 Dmitry Safonov
5  */
6 
7 #define _GNU_SOURCE
8 
9 #include <arpa/inet.h>
10 #include <asm/types.h>
11 #include <errno.h>
12 #include <fcntl.h>
13 #include <limits.h>
14 #include <linux/limits.h>
15 #include <linux/netlink.h>
16 #include <linux/random.h>
17 #include <linux/rtnetlink.h>
18 #include <linux/veth.h>
19 #include <linux/xfrm.h>
20 #include <netinet/in.h>
21 #include <net/if.h>
22 #include <sched.h>
23 #include <stdbool.h>
24 #include <stdint.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <sys/mman.h>
29 #include <sys/socket.h>
30 #include <sys/stat.h>
31 #include <sys/syscall.h>
32 #include <sys/types.h>
33 #include <sys/wait.h>
34 #include <time.h>
35 #include <unistd.h>
36 
37 #include "../kselftest.h"
38 
/* Log with a "<pid>[<line>] " prefix through kselftest's message printer. */
#define printk(fmt, ...)						\
	ksft_print_msg("%d[%u] " fmt "\n", getpid(), __LINE__, ##__VA_ARGS__)

/* Like printk(), but appends the current errno description (%m). */
#define pr_err(fmt, ...)	printk(fmt ": %m", ##__VA_ARGS__)

/* Compile-time assertion: the build fails when @condition is true. */
#define BUILD_BUG_ON(condition) ((void)sizeof(char[1 - 2*!!(condition)]))

#define IPV4_STR_SZ	16	/* xxx.xxx.xxx.xxx is longest + \0 */
#define MAX_PAYLOAD	2048
#define XFRM_ALGO_KEY_BUF_SIZE	512
#define MAX_PROCESSES	(1 << 14) /* /16 mask divided by /30 subnets */
#define INADDR_A	((in_addr_t) 0x0a000000) /* 10.0.0.0 */
#define INADDR_B	((in_addr_t) 0xc0a80000) /* 192.168.0.0 */

/* /30 mask for one veth connection */
#define PREFIX_LEN	30
/*
 * Host-part offsets of the two endpoints of the nr-th /30 subnet.
 * The argument is parenthesized so that e.g. child_ip(i + 1) expands
 * to 4*(i + 1) + 1 rather than 4*i + 1 + 1.
 */
#define child_ip(nr)	(4*(nr) + 1)
#define grchild_ip(nr)	(4*(nr) + 2)

#define VETH_FMT	"ktst-%d"
#define VETH_LEN	12
60 
/* Netns fds filled in by init_namespaces(); -1 until opened. */
static int nsfd_parent	= -1;
static int nsfd_childa	= -1;
static int nsfd_childb	= -1;
/*
 * Ping payload length used by do_ping(). NOTE(review): assigned outside
 * this view, presumably from the system page size — confirm.
 */
static long page_size;

/*
 * ksft_cnt is static in kselftest, so it isn't shared with children.
 * Test results therefore have to be sent back to the parent and counted
 * there. results_fd is the pipe carrying that feedback; results are
 * written to results_fd[1] by write_test_result().
 */
static int results_fd[2];

/* Delay between consecutive pings in do_ping(), in nanoseconds. */
const unsigned int ping_delay_nsec	= 50 * 1000 * 1000;
/* Receive timeout of the ping socket, in microseconds (udp_ping_init()). */
const unsigned int ping_timeout		= 300;
/* Number of pings sent per do_ping() run. */
const unsigned int ping_count		= 100;
/* Minimum number of successful pings for do_ping() to succeed. */
const unsigned int ping_success		= 80;
77 
/*
 * Fill @buflen bytes at @buf with pseudo-random data from rand().
 * Whole ints are written first. Any tail shorter than an int is filled
 * with the leading bytes of one extra rand() value.
 */
static void randomize_buffer(void *buf, size_t buflen)
{
	size_t tail = buflen % sizeof(int);
	size_t nwords = buflen / sizeof(int);
	int *word = (int *)buf;
	size_t i;

	if (buflen == 0)
		return;

	for (i = 0; i < nwords; i++)
		word[i] = rand();

	if (tail != 0) {
		int last = rand();

		memcpy((char *)buf + (buflen - tail), &last, tail);
	}
}
98 
99 static int unshare_open(void)
100 {
101 	const char *netns_path = "/proc/self/ns/net";
102 	int fd;
103 
104 	if (unshare(CLONE_NEWNET) != 0) {
105 		pr_err("unshare()");
106 		return -1;
107 	}
108 
109 	fd = open(netns_path, O_RDONLY);
110 	if (fd <= 0) {
111 		pr_err("open(%s)", netns_path);
112 		return -1;
113 	}
114 
115 	return fd;
116 }
117 
118 static int switch_ns(int fd)
119 {
120 	if (setns(fd, CLONE_NEWNET)) {
121 		pr_err("setns()");
122 		return -1;
123 	}
124 	return 0;
125 }
126 
127 /*
128  * Running the test inside a new parent net namespace to bother less
129  * about cleanup on error-path.
130  */
131 static int init_namespaces(void)
132 {
133 	nsfd_parent = unshare_open();
134 	if (nsfd_parent <= 0)
135 		return -1;
136 
137 	nsfd_childa = unshare_open();
138 	if (nsfd_childa <= 0)
139 		return -1;
140 
141 	if (switch_ns(nsfd_parent))
142 		return -1;
143 
144 	nsfd_childb = unshare_open();
145 	if (nsfd_childb <= 0)
146 		return -1;
147 
148 	if (switch_ns(nsfd_parent))
149 		return -1;
150 	return 0;
151 }
152 
/*
 * Lazily open a netlink socket of protocol @proto into *@sock.
 *
 * If *@sock already holds a socket (> 0), only advance the caller's
 * sequence number *@seq_nr. Otherwise create the socket and seed *@seq_nr
 * with a random value.
 * Returns 0 on success, -1 if socket() fails.
 */
static int netlink_sock(int *sock, uint32_t *seq_nr, int proto)
{
	if (*sock > 0) {
		/* Bump the caller's counter, not the local pointer copy. */
		(*seq_nr)++;
		return 0;
	}

	*sock = socket(AF_NETLINK, SOCK_RAW | SOCK_CLOEXEC, proto);
	if (*sock <= 0) {
		pr_err("socket(AF_NETLINK)");
		return -1;
	}

	randomize_buffer(seq_nr, sizeof(*seq_nr));

	return 0;
}
170 
/*
 * Return the position where the next attribute of message @nh would
 * start: the current end of the message rounded up with RTA_ALIGN().
 */
static inline struct rtattr *rtattr_hdr(struct nlmsghdr *nh)
{
	char *base = (char *)nh;
	size_t offset = RTA_ALIGN(nh->nlmsg_len);

	return (struct rtattr *)(base + offset);
}
175 
/*
 * Append one rtattr of type @rta_type carrying @size bytes of @payload to
 * the netlink message @nh, whose buffer is @req_sz bytes long. Grows
 * nh->nlmsg_len to cover the new attribute. @payload may be NULL when
 * @size is 0 (rtattr_begin() does this to open a nested attribute).
 * Returns 0 on success, -1 if the attribute does not fit in @req_sz.
 */
static int rtattr_pack(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type, const void *payload, size_t size)
{
	/*
	 * NLMSG_ALIGNTO == RTA_ALIGNTO. nlmsg_len may be unaligned after a
	 * previous attribute, so RTA_ALIGN() it to find the new start. The
	 * padding bytes rely on the request buffer having been zeroed, as
	 * the callers in this file memset() theirs.
	 */
	struct rtattr *attr = rtattr_hdr(nh);
	size_t nl_size = RTA_ALIGN(nh->nlmsg_len) + RTA_LENGTH(size);

	if (req_sz < nl_size) {
		printk("req buf is too small: %zu < %zu", req_sz, nl_size);
		return -1;
	}
	nh->nlmsg_len = nl_size;

	attr->rta_len = RTA_LENGTH(size);
	attr->rta_type = rta_type;
	/* memcpy() from a NULL source is undefined even for 0 bytes. */
	if (size)
		memcpy(RTA_DATA(attr), payload, size);

	return 0;
}
195 
/*
 * Append attribute @rta_type (with optional @payload of @size bytes) and
 * return a pointer to it. Nested attributes can then be appended and the
 * outer length fixed up with rtattr_end(). Returns NULL if it does not
 * fit in @req_sz.
 */
static struct rtattr *_rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type, const void *payload, size_t size)
{
	/* Must be taken before rtattr_pack() grows nh->nlmsg_len. */
	struct rtattr *start = rtattr_hdr(nh);

	return rtattr_pack(nh, req_sz, rta_type, payload, size) ? NULL : start;
}
206 
/*
 * Open a nested attribute of type @rta_type that has no payload of its
 * own. Close it with rtattr_end(). Returns NULL if it does not fit.
 */
static inline struct rtattr *rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type)
{
	const void *no_payload = NULL;

	return _rtattr_begin(nh, req_sz, rta_type, no_payload, 0);
}
212 
/*
 * Close a nested attribute: set @attr's length so that it spans from its
 * own header to the current end of message @nh.
 */
static inline void rtattr_end(struct nlmsghdr *nh, struct rtattr *attr)
{
	const char *msg_tail = (const char *)nh + nh->nlmsg_len;
	const char *attr_head = (const char *)attr;

	attr->rta_len = (unsigned short)(msg_tail - attr_head);
}
219 
/*
 * Append the VETH_INFO_PEER nested attribute describing the peer end of a
 * veth pair. It consists of an ifinfomsg header followed by IFLA_IFNAME
 * (@peer) and IFLA_NET_NS_FD (@ns).
 * Returns 0 on success, -1 if @req_sz is too small.
 */
static int veth_pack_peerb(struct nlmsghdr *nh, size_t req_sz,
		const char *peer, int ns)
{
	struct ifinfomsg peer_info = {
		.ifi_family	= AF_UNSPEC,
		.ifi_change	= 0xFFFFFFFF,
	};
	struct rtattr *nest;

	nest = _rtattr_begin(nh, req_sz, VETH_INFO_PEER,
			     &peer_info, sizeof(peer_info));
	if (nest == NULL)
		return -1;

	if (rtattr_pack(nh, req_sz, IFLA_IFNAME, peer, strlen(peer)) ||
	    rtattr_pack(nh, req_sz, IFLA_NET_NS_FD, &ns, sizeof(ns)))
		return -1;

	rtattr_end(nh, nest);
	return 0;
}
244 
/*
 * Receive and decode the kernel's reply to a request sent with NLM_F_ACK
 * on @sock.
 * Returns 0 on a positive ACK, or the (negative) error code carried in
 * NLMSG_ERROR on a NACK. Returns -1 if recv() fails, the reply is too
 * short, or the reply is not NLMSG_ERROR.
 */
static int netlink_check_answer(int sock)
{
	struct nlmsgerror {
		struct nlmsghdr hdr;
		int error;
		struct nlmsghdr orig_msg;
	} answer;
	ssize_t len;

	len = recv(sock, &answer, sizeof(answer), 0);
	if (len < 0) {
		pr_err("recv()");
		return -1;
	}
	/* Don't inspect fields that a short reply never filled in. */
	if ((size_t)len < sizeof(answer.hdr) + sizeof(answer.error)) {
		printk("short netlink answer: %zd bytes", len);
		return -1;
	}
	if (answer.hdr.nlmsg_type != NLMSG_ERROR) {
		printk("expected NLMSG_ERROR, got %d", (int)answer.hdr.nlmsg_type);
		return -1;
	}
	if (answer.error) {
		printk("NLMSG_ERROR: %d: %s",
			answer.error, strerror(-answer.error));
		return answer.error;
	}

	return 0;
}
267 
268 static int veth_add(int sock, uint32_t seq, const char *peera, int ns_a,
269 		const char *peerb, int ns_b)
270 {
271 	uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
272 	struct {
273 		struct nlmsghdr		nh;
274 		struct ifinfomsg	info;
275 		char			attrbuf[MAX_PAYLOAD];
276 	} req;
277 	const char veth_type[] = "veth";
278 	struct rtattr *link_info, *info_data;
279 
280 	memset(&req, 0, sizeof(req));
281 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
282 	req.nh.nlmsg_type	= RTM_NEWLINK;
283 	req.nh.nlmsg_flags	= flags;
284 	req.nh.nlmsg_seq	= seq;
285 	req.info.ifi_family	= AF_UNSPEC;
286 	req.info.ifi_change	= 0xFFFFFFFF;
287 
288 	if (rtattr_pack(&req.nh, sizeof(req), IFLA_IFNAME, peera, strlen(peera)))
289 		return -1;
290 
291 	if (rtattr_pack(&req.nh, sizeof(req), IFLA_NET_NS_FD, &ns_a, sizeof(ns_a)))
292 		return -1;
293 
294 	link_info = rtattr_begin(&req.nh, sizeof(req), IFLA_LINKINFO);
295 	if (!link_info)
296 		return -1;
297 
298 	if (rtattr_pack(&req.nh, sizeof(req), IFLA_INFO_KIND, veth_type, sizeof(veth_type)))
299 		return -1;
300 
301 	info_data = rtattr_begin(&req.nh, sizeof(req), IFLA_INFO_DATA);
302 	if (!info_data)
303 		return -1;
304 
305 	if (veth_pack_peerb(&req.nh, sizeof(req), peerb, ns_b))
306 		return -1;
307 
308 	rtattr_end(&req.nh, info_data);
309 	rtattr_end(&req.nh, link_info);
310 
311 	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
312 		pr_err("send()");
313 		return -1;
314 	}
315 	return netlink_check_answer(sock);
316 }
317 
318 static int ip4_addr_set(int sock, uint32_t seq, const char *intf,
319 		struct in_addr addr, uint8_t prefix)
320 {
321 	uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
322 	struct {
323 		struct nlmsghdr		nh;
324 		struct ifaddrmsg	info;
325 		char			attrbuf[MAX_PAYLOAD];
326 	} req;
327 
328 	memset(&req, 0, sizeof(req));
329 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
330 	req.nh.nlmsg_type	= RTM_NEWADDR;
331 	req.nh.nlmsg_flags	= flags;
332 	req.nh.nlmsg_seq	= seq;
333 	req.info.ifa_family	= AF_INET;
334 	req.info.ifa_prefixlen	= prefix;
335 	req.info.ifa_index	= if_nametoindex(intf);
336 
337 #ifdef DEBUG
338 	{
339 		char addr_str[IPV4_STR_SZ] = {};
340 
341 		strncpy(addr_str, inet_ntoa(addr), IPV4_STR_SZ - 1);
342 
343 		printk("ip addr set %s", addr_str);
344 	}
345 #endif
346 
347 	if (rtattr_pack(&req.nh, sizeof(req), IFA_LOCAL, &addr, sizeof(addr)))
348 		return -1;
349 
350 	if (rtattr_pack(&req.nh, sizeof(req), IFA_ADDRESS, &addr, sizeof(addr)))
351 		return -1;
352 
353 	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
354 		pr_err("send()");
355 		return -1;
356 	}
357 	return netlink_check_answer(sock);
358 }
359 
360 static int link_set_up(int sock, uint32_t seq, const char *intf)
361 {
362 	struct {
363 		struct nlmsghdr		nh;
364 		struct ifinfomsg	info;
365 		char			attrbuf[MAX_PAYLOAD];
366 	} req;
367 
368 	memset(&req, 0, sizeof(req));
369 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
370 	req.nh.nlmsg_type	= RTM_NEWLINK;
371 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
372 	req.nh.nlmsg_seq	= seq;
373 	req.info.ifi_family	= AF_UNSPEC;
374 	req.info.ifi_change	= 0xFFFFFFFF;
375 	req.info.ifi_index	= if_nametoindex(intf);
376 	req.info.ifi_flags	= IFF_UP;
377 	req.info.ifi_change	= IFF_UP;
378 
379 	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
380 		pr_err("send()");
381 		return -1;
382 	}
383 	return netlink_check_answer(sock);
384 }
385 
386 static int ip4_route_set(int sock, uint32_t seq, const char *intf,
387 		struct in_addr src, struct in_addr dst)
388 {
389 	struct {
390 		struct nlmsghdr	nh;
391 		struct rtmsg	rt;
392 		char		attrbuf[MAX_PAYLOAD];
393 	} req;
394 	unsigned int index = if_nametoindex(intf);
395 
396 	memset(&req, 0, sizeof(req));
397 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.rt));
398 	req.nh.nlmsg_type	= RTM_NEWROUTE;
399 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE;
400 	req.nh.nlmsg_seq	= seq;
401 	req.rt.rtm_family	= AF_INET;
402 	req.rt.rtm_dst_len	= 32;
403 	req.rt.rtm_table	= RT_TABLE_MAIN;
404 	req.rt.rtm_protocol	= RTPROT_BOOT;
405 	req.rt.rtm_scope	= RT_SCOPE_LINK;
406 	req.rt.rtm_type		= RTN_UNICAST;
407 
408 	if (rtattr_pack(&req.nh, sizeof(req), RTA_DST, &dst, sizeof(dst)))
409 		return -1;
410 
411 	if (rtattr_pack(&req.nh, sizeof(req), RTA_PREFSRC, &src, sizeof(src)))
412 		return -1;
413 
414 	if (rtattr_pack(&req.nh, sizeof(req), RTA_OIF, &index, sizeof(index)))
415 		return -1;
416 
417 	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
418 		pr_err("send()");
419 		return -1;
420 	}
421 
422 	return netlink_check_answer(sock);
423 }
424 
425 static int tunnel_set_route(int route_sock, uint32_t *route_seq, char *veth,
426 		struct in_addr tunsrc, struct in_addr tundst)
427 {
428 	if (ip4_addr_set(route_sock, (*route_seq)++, "lo",
429 			tunsrc, PREFIX_LEN)) {
430 		printk("Failed to set ipv4 addr");
431 		return -1;
432 	}
433 
434 	if (ip4_route_set(route_sock, (*route_seq)++, veth, tunsrc, tundst)) {
435 		printk("Failed to set ipv4 route");
436 		return -1;
437 	}
438 
439 	return 0;
440 }
441 
442 static int init_child(int nsfd, char *veth, unsigned int src, unsigned int dst)
443 {
444 	struct in_addr intsrc = inet_makeaddr(INADDR_B, src);
445 	struct in_addr tunsrc = inet_makeaddr(INADDR_A, src);
446 	struct in_addr tundst = inet_makeaddr(INADDR_A, dst);
447 	int route_sock = -1, ret = -1;
448 	uint32_t route_seq;
449 
450 	if (switch_ns(nsfd))
451 		return -1;
452 
453 	if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE)) {
454 		printk("Failed to open netlink route socket in child");
455 		return -1;
456 	}
457 
458 	if (ip4_addr_set(route_sock, route_seq++, veth, intsrc, PREFIX_LEN)) {
459 		printk("Failed to set ipv4 addr");
460 		goto err;
461 	}
462 
463 	if (link_set_up(route_sock, route_seq++, veth)) {
464 		printk("Failed to bring up %s", veth);
465 		goto err;
466 	}
467 
468 	if (tunnel_set_route(route_sock, &route_seq, veth, tunsrc, tundst)) {
469 		printk("Failed to add tunnel route on %s", veth);
470 		goto err;
471 	}
472 	ret = 0;
473 
474 err:
475 	close(route_sock);
476 	return ret;
477 }
478 
/* Size of each algorithm-name buffer in struct xfrm_desc. */
#define ALGO_LEN	64
/* Kind of xfrm test that a struct xfrm_desc describes. */
enum desc_type {
	CREATE_TUNNEL	= 0,
	ALLOCATE_SPI,
	MONITOR_ACQUIRE,
	EXPIRE_STATE,
	EXPIRE_POLICY,
	SPDINFO_ATTRS,
};
/*
 * Printable names for enum desc_type, in enum order. The final "" entry
 * has no matching enum value (presumably a terminator — confirm).
 */
const char *desc_name[] = {
	"create tunnel",
	"alloc spi",
	"monitor acquire",
	"expire state",
	"expire policy",
	"spdinfo attributes",
	""
};
/*
 * One xfrm test case. It is carried inside struct test_desc and copied
 * back in struct test_result. xfrm_state_pack_algo() expects:
 *  IPPROTO_AH   - only a_algo;
 *  IPPROTO_COMP - only c_algo;
 *  IPPROTO_ESP  - either a_algo + e_algo, or ae_algo alone.
 */
struct xfrm_desc {
	enum desc_type	type;
	uint8_t		proto;		/* IPPROTO_AH / IPPROTO_ESP / IPPROTO_COMP */
	char		a_algo[ALGO_LEN];	/* authentication algorithm name */
	char		e_algo[ALGO_LEN];	/* encryption algorithm name */
	char		c_algo[ALGO_LEN];	/* compression algorithm name */
	char		ae_algo[ALGO_LEN];	/* AEAD algorithm name */
	unsigned int	icv_len;	/* copied to the AEAD alg_icv_len */
	/* unsigned key_len; */
};

/* Command types exchanged over the command fd (write_msg()/read_msg()). */
enum msg_type {
	MSG_ACK		= 0,
	MSG_EXIT,
	MSG_PING,
	MSG_XFRM_PREPARE,
	MSG_XFRM_ADD,
	MSG_XFRM_DEL,
	MSG_XFRM_CLEANUP,
};

/*
 * One command message. It is always written and read whole; write_msg()
 * checks at build time that it fits in PIPE_BUF so pipe I/O is atomic.
 */
struct test_desc {
	enum msg_type type;
	union {
		struct {
			in_addr_t reply_ip;	/* sender's ping address, network order */
			unsigned int port;	/* sender's UDP port, host order */
		} ping;			/* used with MSG_PING */
		struct xfrm_desc xfrm_desc;
	} body;
};

/* Outcome of one test case, sent to the parent through results_fd. */
struct test_result {
	struct xfrm_desc desc;
	unsigned int res;	/* NOTE(review): presumably a KSFT_* code — confirm */
};
533 
534 static void write_test_result(unsigned int res, struct xfrm_desc *d)
535 {
536 	struct test_result tr = {};
537 	ssize_t ret;
538 
539 	tr.desc = *d;
540 	tr.res = res;
541 
542 	ret = write(results_fd[1], &tr, sizeof(tr));
543 	if (ret != sizeof(tr))
544 		pr_err("Failed to write the result in pipe %zd", ret);
545 }
546 
547 static void write_msg(int fd, struct test_desc *msg, bool exit_of_fail)
548 {
549 	ssize_t bytes = write(fd, msg, sizeof(*msg));
550 
551 	/* Make sure that write/read is atomic to a pipe */
552 	BUILD_BUG_ON(sizeof(struct test_desc) > PIPE_BUF);
553 
554 	if (bytes < 0) {
555 		pr_err("write()");
556 		if (exit_of_fail)
557 			exit(KSFT_FAIL);
558 	}
559 	if (bytes != sizeof(*msg)) {
560 		pr_err("sent part of the message %zd/%zu", bytes, sizeof(*msg));
561 		if (exit_of_fail)
562 			exit(KSFT_FAIL);
563 	}
564 }
565 
566 static void read_msg(int fd, struct test_desc *msg, bool exit_of_fail)
567 {
568 	ssize_t bytes = read(fd, msg, sizeof(*msg));
569 
570 	if (bytes < 0) {
571 		pr_err("read()");
572 		if (exit_of_fail)
573 			exit(KSFT_FAIL);
574 	}
575 	if (bytes != sizeof(*msg)) {
576 		pr_err("got incomplete message %zd/%zu", bytes, sizeof(*msg));
577 		if (exit_of_fail)
578 			exit(KSFT_FAIL);
579 	}
580 }
581 
/*
 * Create the two UDP sockets used for pinging:
 *  sock[0] - bound to @listen_ip on a kernel-chosen port, whose number is
 *            stored in *@server_port (host byte order). Its receives time
 *            out after @u_timeout microseconds.
 *  sock[1] - unbound socket used for sending.
 * Returns 0 on success, or -1 on failure with no socket left open.
 */
static int udp_ping_init(struct in_addr listen_ip, unsigned int u_timeout,
		unsigned int *server_port, int sock[2])
{
	struct sockaddr_in server = {
		.sin_family	= AF_INET,
		.sin_port	= 0,	/* let the kernel pick a port */
	};
	/* SO_RCVTIMEO rejects tv_usec >= 1s (EDOM), so split the timeout. */
	struct timeval t = {
		.tv_sec		= u_timeout / 1000000,
		.tv_usec	= u_timeout % 1000000,
	};
	socklen_t s_len = sizeof(server);

	sock[0] = socket(AF_INET, SOCK_DGRAM, 0);
	if (sock[0] < 0) {
		pr_err("socket()");
		return -1;
	}

	memcpy(&server.sin_addr.s_addr, &listen_ip, sizeof(struct in_addr));

	if (bind(sock[0], (struct sockaddr *)&server, s_len)) {
		pr_err("bind()");
		goto err_close_server;
	}

	if (getsockname(sock[0], (struct sockaddr *)&server, &s_len)) {
		pr_err("getsockname()");
		goto err_close_server;
	}

	*server_port = ntohs(server.sin_port);

	if (setsockopt(sock[0], SOL_SOCKET, SO_RCVTIMEO, (const char *)&t, sizeof t)) {
		pr_err("setsockopt()");
		goto err_close_server;
	}

	sock[1] = socket(AF_INET, SOCK_DGRAM, 0);
	if (sock[1] < 0) {
		pr_err("socket()");
		goto err_close_server;
	}

	return 0;

err_close_server:
	close(sock[0]);
	return -1;
}
628 
/*
 * Send @buf_len bytes of @buf from sock[1] to @dest_ip:@port (@dest_ip in
 * network byte order, @port in host order), then wait on sock[0] for the
 * echo.
 * Returns 0 when an identical payload comes back. Returns -1 on error,
 * timeout or mismatch; a plain timeout (EAGAIN) is not logged.
 */
static int udp_ping_send(int sock[2], in_addr_t dest_ip, unsigned int port,
		char *buf, size_t buf_len)
{
	struct sockaddr_in server;
	const struct sockaddr *dest_addr = (struct sockaddr *)&server;
	char sock_buf[buf_len];		/* bytes, not an array of pointers */
	ssize_t r_bytes, s_bytes;

	memset(&server, 0, sizeof(server));
	server.sin_family	= AF_INET;
	server.sin_port		= htons(port);
	server.sin_addr.s_addr	= dest_ip;

	s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
	if (s_bytes < 0) {
		pr_err("sendto()");
		return -1;
	} else if ((size_t)s_bytes != buf_len) {
		printk("send part of the message: %zd/%zu", s_bytes, buf_len);
		return -1;
	}

	r_bytes = recv(sock[0], sock_buf, buf_len, 0);
	if (r_bytes < 0) {
		if (errno != EAGAIN)
			pr_err("recv()");
		return -1;
	} else if (r_bytes == 0) { /* EOF */
		printk("EOF on reply to ping");
		return -1;
	} else if ((size_t)r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
		printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
		return -1;
	}

	return 0;
}
665 
/*
 * Echo side of a ping. Wait on sock[0] for a datagram that must equal the
 * @buf_len bytes of @buf, then send @buf from sock[1] to @dest_ip:@port
 * (@dest_ip in network byte order, @port in host order).
 * Returns 0 on success. Returns -1 on error, timeout or mismatch; a plain
 * timeout (EAGAIN) is not logged.
 */
static int udp_ping_reply(int sock[2], in_addr_t dest_ip, unsigned int port,
		char *buf, size_t buf_len)
{
	struct sockaddr_in server;
	const struct sockaddr *dest_addr = (struct sockaddr *)&server;
	char sock_buf[buf_len];		/* bytes, not an array of pointers */
	ssize_t r_bytes, s_bytes;

	memset(&server, 0, sizeof(server));
	server.sin_family	= AF_INET;
	server.sin_port		= htons(port);
	server.sin_addr.s_addr	= dest_ip;

	r_bytes = recv(sock[0], sock_buf, buf_len, 0);
	if (r_bytes < 0) {
		if (errno != EAGAIN)
			pr_err("recv()");
		return -1;
	}
	if (r_bytes == 0) { /* EOF */
		printk("EOF on reply to ping");
		return -1;
	}
	if ((size_t)r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
		printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
		return -1;
	}

	s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
	if (s_bytes < 0) {
		pr_err("sendto()");
		return -1;
	} else if ((size_t)s_bytes != buf_len) {
		printk("send part of the message: %zd/%zu", s_bytes, buf_len);
		return -1;
	}

	return 0;
}
704 
705 typedef int (*ping_f)(int sock[2], in_addr_t dest_ip, unsigned int port,
706 		char *buf, size_t buf_len);
707 static int do_ping(int cmd_fd, char *buf, size_t buf_len, struct in_addr from,
708 		bool init_side, int d_port, in_addr_t to, ping_f func)
709 {
710 	struct test_desc msg;
711 	unsigned int s_port, i, ping_succeeded = 0;
712 	int ping_sock[2];
713 	char to_str[IPV4_STR_SZ] = {}, from_str[IPV4_STR_SZ] = {};
714 
715 	if (udp_ping_init(from, ping_timeout, &s_port, ping_sock)) {
716 		printk("Failed to init ping");
717 		return -1;
718 	}
719 
720 	memset(&msg, 0, sizeof(msg));
721 	msg.type		= MSG_PING;
722 	msg.body.ping.port	= s_port;
723 	memcpy(&msg.body.ping.reply_ip, &from, sizeof(from));
724 
725 	write_msg(cmd_fd, &msg, 0);
726 	if (init_side) {
727 		/* The other end sends ip to ping */
728 		read_msg(cmd_fd, &msg, 0);
729 		if (msg.type != MSG_PING)
730 			return -1;
731 		to = msg.body.ping.reply_ip;
732 		d_port = msg.body.ping.port;
733 	}
734 
735 	for (i = 0; i < ping_count ; i++) {
736 		struct timespec sleep_time = {
737 			.tv_sec = 0,
738 			.tv_nsec = ping_delay_nsec,
739 		};
740 
741 		ping_succeeded += !func(ping_sock, to, d_port, buf, page_size);
742 		nanosleep(&sleep_time, 0);
743 	}
744 
745 	close(ping_sock[0]);
746 	close(ping_sock[1]);
747 
748 	strncpy(to_str, inet_ntoa(*(struct in_addr *)&to), IPV4_STR_SZ - 1);
749 	strncpy(from_str, inet_ntoa(from), IPV4_STR_SZ - 1);
750 
751 	if (ping_succeeded < ping_success) {
752 		printk("ping (%s) %s->%s failed %u/%u times",
753 			init_side ? "send" : "reply", from_str, to_str,
754 			ping_count - ping_succeeded, ping_count);
755 		return -1;
756 	}
757 
758 #ifdef DEBUG
759 	printk("ping (%s) %s->%s succeeded %u/%u times",
760 		init_side ? "send" : "reply", from_str, to_str,
761 		ping_succeeded, ping_count);
762 #endif
763 
764 	return 0;
765 }
766 
767 static int xfrm_fill_key(char *name, char *buf,
768 		size_t buf_len, unsigned int *key_len)
769 {
770 	/* TODO: use set/map instead */
771 	if (strncmp(name, "digest_null", ALGO_LEN) == 0)
772 		*key_len = 0;
773 	else if (strncmp(name, "ecb(cipher_null)", ALGO_LEN) == 0)
774 		*key_len = 0;
775 	else if (strncmp(name, "cbc(des)", ALGO_LEN) == 0)
776 		*key_len = 64;
777 	else if (strncmp(name, "hmac(md5)", ALGO_LEN) == 0)
778 		*key_len = 128;
779 	else if (strncmp(name, "cmac(aes)", ALGO_LEN) == 0)
780 		*key_len = 128;
781 	else if (strncmp(name, "xcbc(aes)", ALGO_LEN) == 0)
782 		*key_len = 128;
783 	else if (strncmp(name, "cbc(cast5)", ALGO_LEN) == 0)
784 		*key_len = 128;
785 	else if (strncmp(name, "cbc(serpent)", ALGO_LEN) == 0)
786 		*key_len = 128;
787 	else if (strncmp(name, "hmac(sha1)", ALGO_LEN) == 0)
788 		*key_len = 160;
789 	else if (strncmp(name, "hmac(rmd160)", ALGO_LEN) == 0)
790 		*key_len = 160;
791 	else if (strncmp(name, "cbc(des3_ede)", ALGO_LEN) == 0)
792 		*key_len = 192;
793 	else if (strncmp(name, "hmac(sha256)", ALGO_LEN) == 0)
794 		*key_len = 256;
795 	else if (strncmp(name, "cbc(aes)", ALGO_LEN) == 0)
796 		*key_len = 256;
797 	else if (strncmp(name, "cbc(camellia)", ALGO_LEN) == 0)
798 		*key_len = 256;
799 	else if (strncmp(name, "cbc(twofish)", ALGO_LEN) == 0)
800 		*key_len = 256;
801 	else if (strncmp(name, "rfc3686(ctr(aes))", ALGO_LEN) == 0)
802 		*key_len = 288;
803 	else if (strncmp(name, "hmac(sha384)", ALGO_LEN) == 0)
804 		*key_len = 384;
805 	else if (strncmp(name, "cbc(blowfish)", ALGO_LEN) == 0)
806 		*key_len = 448;
807 	else if (strncmp(name, "hmac(sha512)", ALGO_LEN) == 0)
808 		*key_len = 512;
809 	else if (strncmp(name, "rfc4106(gcm(aes))-128", ALGO_LEN) == 0)
810 		*key_len = 160;
811 	else if (strncmp(name, "rfc4543(gcm(aes))-128", ALGO_LEN) == 0)
812 		*key_len = 160;
813 	else if (strncmp(name, "rfc4309(ccm(aes))-128", ALGO_LEN) == 0)
814 		*key_len = 152;
815 	else if (strncmp(name, "rfc4106(gcm(aes))-192", ALGO_LEN) == 0)
816 		*key_len = 224;
817 	else if (strncmp(name, "rfc4543(gcm(aes))-192", ALGO_LEN) == 0)
818 		*key_len = 224;
819 	else if (strncmp(name, "rfc4309(ccm(aes))-192", ALGO_LEN) == 0)
820 		*key_len = 216;
821 	else if (strncmp(name, "rfc4106(gcm(aes))-256", ALGO_LEN) == 0)
822 		*key_len = 288;
823 	else if (strncmp(name, "rfc4543(gcm(aes))-256", ALGO_LEN) == 0)
824 		*key_len = 288;
825 	else if (strncmp(name, "rfc4309(ccm(aes))-256", ALGO_LEN) == 0)
826 		*key_len = 280;
827 	else if (strncmp(name, "rfc7539(chacha20,poly1305)-128", ALGO_LEN) == 0)
828 		*key_len = 0;
829 
830 	if (*key_len > buf_len) {
831 		printk("Can't pack a key - too big for buffer");
832 		return -1;
833 	}
834 
835 	randomize_buffer(buf, *key_len);
836 
837 	return 0;
838 }
839 
840 static int xfrm_state_pack_algo(struct nlmsghdr *nh, size_t req_sz,
841 		struct xfrm_desc *desc)
842 {
843 	struct {
844 		union {
845 			struct xfrm_algo	alg;
846 			struct xfrm_algo_aead	aead;
847 			struct xfrm_algo_auth	auth;
848 		} u;
849 		char buf[XFRM_ALGO_KEY_BUF_SIZE];
850 	} alg = {};
851 	size_t alen, elen, clen, aelen;
852 	unsigned short type;
853 
854 	alen = strlen(desc->a_algo);
855 	elen = strlen(desc->e_algo);
856 	clen = strlen(desc->c_algo);
857 	aelen = strlen(desc->ae_algo);
858 
859 	/* Verify desc */
860 	switch (desc->proto) {
861 	case IPPROTO_AH:
862 		if (!alen || elen || clen || aelen) {
863 			printk("BUG: buggy ah desc");
864 			return -1;
865 		}
866 		strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN - 1);
867 		if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
868 				sizeof(alg.buf), &alg.u.alg.alg_key_len))
869 			return -1;
870 		type = XFRMA_ALG_AUTH;
871 		break;
872 	case IPPROTO_COMP:
873 		if (!clen || elen || alen || aelen) {
874 			printk("BUG: buggy comp desc");
875 			return -1;
876 		}
877 		strncpy(alg.u.alg.alg_name, desc->c_algo, ALGO_LEN - 1);
878 		if (xfrm_fill_key(desc->c_algo, alg.u.alg.alg_key,
879 				sizeof(alg.buf), &alg.u.alg.alg_key_len))
880 			return -1;
881 		type = XFRMA_ALG_COMP;
882 		break;
883 	case IPPROTO_ESP:
884 		if (!((alen && elen) ^ aelen) || clen) {
885 			printk("BUG: buggy esp desc");
886 			return -1;
887 		}
888 		if (aelen) {
889 			alg.u.aead.alg_icv_len = desc->icv_len;
890 			strncpy(alg.u.aead.alg_name, desc->ae_algo, ALGO_LEN - 1);
891 			if (xfrm_fill_key(desc->ae_algo, alg.u.aead.alg_key,
892 						sizeof(alg.buf), &alg.u.aead.alg_key_len))
893 				return -1;
894 			type = XFRMA_ALG_AEAD;
895 		} else {
896 
897 			strncpy(alg.u.alg.alg_name, desc->e_algo, ALGO_LEN - 1);
898 			type = XFRMA_ALG_CRYPT;
899 			if (xfrm_fill_key(desc->e_algo, alg.u.alg.alg_key,
900 						sizeof(alg.buf), &alg.u.alg.alg_key_len))
901 				return -1;
902 			if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
903 				return -1;
904 
905 			strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN);
906 			type = XFRMA_ALG_AUTH;
907 			if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
908 						sizeof(alg.buf), &alg.u.alg.alg_key_len))
909 				return -1;
910 		}
911 		break;
912 	default:
913 		printk("BUG: unknown proto in desc");
914 		return -1;
915 	}
916 
917 	if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
918 		return -1;
919 
920 	return 0;
921 }
922 
/*
 * Derive an SPI from @src: its classful host part (inet_lnaof()), in
 * network byte order.
 */
static inline uint32_t gen_spi(struct in_addr src)
{
	in_addr_t host_part = inet_lnaof(src);

	return htonl(host_part);
}
927 
928 static int xfrm_state_add(int xfrm_sock, uint32_t seq, uint32_t spi,
929 		struct in_addr src, struct in_addr dst,
930 		struct xfrm_desc *desc)
931 {
932 	struct {
933 		struct nlmsghdr		nh;
934 		struct xfrm_usersa_info	info;
935 		char			attrbuf[MAX_PAYLOAD];
936 	} req;
937 
938 	memset(&req, 0, sizeof(req));
939 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
940 	req.nh.nlmsg_type	= XFRM_MSG_NEWSA;
941 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
942 	req.nh.nlmsg_seq	= seq;
943 
944 	/* Fill selector. */
945 	memcpy(&req.info.sel.daddr, &dst, sizeof(dst));
946 	memcpy(&req.info.sel.saddr, &src, sizeof(src));
947 	req.info.sel.family		= AF_INET;
948 	req.info.sel.prefixlen_d	= PREFIX_LEN;
949 	req.info.sel.prefixlen_s	= PREFIX_LEN;
950 
951 	/* Fill id */
952 	memcpy(&req.info.id.daddr, &dst, sizeof(dst));
953 	/* Note: zero-spi cannot be deleted */
954 	req.info.id.spi = spi;
955 	req.info.id.proto	= desc->proto;
956 
957 	memcpy(&req.info.saddr, &src, sizeof(src));
958 
959 	/* Fill lifteme_cfg */
960 	req.info.lft.soft_byte_limit	= XFRM_INF;
961 	req.info.lft.hard_byte_limit	= XFRM_INF;
962 	req.info.lft.soft_packet_limit	= XFRM_INF;
963 	req.info.lft.hard_packet_limit	= XFRM_INF;
964 
965 	req.info.family		= AF_INET;
966 	req.info.mode		= XFRM_MODE_TUNNEL;
967 
968 	if (xfrm_state_pack_algo(&req.nh, sizeof(req), desc))
969 		return -1;
970 
971 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
972 		pr_err("send()");
973 		return -1;
974 	}
975 
976 	return netlink_check_answer(xfrm_sock);
977 }
978 
979 static bool xfrm_usersa_found(struct xfrm_usersa_info *info, uint32_t spi,
980 		struct in_addr src, struct in_addr dst,
981 		struct xfrm_desc *desc)
982 {
983 	if (memcmp(&info->sel.daddr, &dst, sizeof(dst)))
984 		return false;
985 
986 	if (memcmp(&info->sel.saddr, &src, sizeof(src)))
987 		return false;
988 
989 	if (info->sel.family != AF_INET					||
990 			info->sel.prefixlen_d != PREFIX_LEN		||
991 			info->sel.prefixlen_s != PREFIX_LEN)
992 		return false;
993 
994 	if (info->id.spi != spi || info->id.proto != desc->proto)
995 		return false;
996 
997 	if (memcmp(&info->id.daddr, &dst, sizeof(dst)))
998 		return false;
999 
1000 	if (memcmp(&info->saddr, &src, sizeof(src)))
1001 		return false;
1002 
1003 	if (info->lft.soft_byte_limit != XFRM_INF			||
1004 			info->lft.hard_byte_limit != XFRM_INF		||
1005 			info->lft.soft_packet_limit != XFRM_INF		||
1006 			info->lft.hard_packet_limit != XFRM_INF)
1007 		return false;
1008 
1009 	if (info->family != AF_INET || info->mode != XFRM_MODE_TUNNEL)
1010 		return false;
1011 
1012 	/* XXX: check xfrm algo, see xfrm_state_pack_algo(). */
1013 
1014 	return true;
1015 }
1016 
1017 static int xfrm_state_check(int xfrm_sock, uint32_t seq, uint32_t spi,
1018 		struct in_addr src, struct in_addr dst,
1019 		struct xfrm_desc *desc)
1020 {
1021 	struct {
1022 		struct nlmsghdr		nh;
1023 		char			attrbuf[MAX_PAYLOAD];
1024 	} req;
1025 	struct {
1026 		struct nlmsghdr		nh;
1027 		union {
1028 			struct xfrm_usersa_info	info;
1029 			int error;
1030 		};
1031 		char			attrbuf[MAX_PAYLOAD];
1032 	} answer;
1033 	struct xfrm_address_filter filter = {};
1034 	bool found = false;
1035 
1036 
1037 	memset(&req, 0, sizeof(req));
1038 	req.nh.nlmsg_len	= NLMSG_LENGTH(0);
1039 	req.nh.nlmsg_type	= XFRM_MSG_GETSA;
1040 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_DUMP;
1041 	req.nh.nlmsg_seq	= seq;
1042 
1043 	/*
1044 	 * Add dump filter by source address as there may be other tunnels
1045 	 * in this netns (if tests run in parallel).
1046 	 */
1047 	filter.family = AF_INET;
1048 	filter.splen = 0x1f;	/* 0xffffffff mask see addr_match() */
1049 	memcpy(&filter.saddr, &src, sizeof(src));
1050 	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_ADDRESS_FILTER,
1051 				&filter, sizeof(filter)))
1052 		return -1;
1053 
1054 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1055 		pr_err("send()");
1056 		return -1;
1057 	}
1058 
1059 	while (1) {
1060 		if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
1061 			pr_err("recv()");
1062 			return -1;
1063 		}
1064 		if (answer.nh.nlmsg_type == NLMSG_ERROR) {
1065 			printk("NLMSG_ERROR: %d: %s",
1066 				answer.error, strerror(-answer.error));
1067 			return -1;
1068 		} else if (answer.nh.nlmsg_type == NLMSG_DONE) {
1069 			if (found)
1070 				return 0;
1071 			printk("didn't find allocated xfrm state in dump");
1072 			return -1;
1073 		} else if (answer.nh.nlmsg_type == XFRM_MSG_NEWSA) {
1074 			if (xfrm_usersa_found(&answer.info, spi, src, dst, desc))
1075 				found = true;
1076 		}
1077 	}
1078 }
1079 
/*
 * Install the SA pair @src -> @dst and @dst -> @src, both with SPI
 * gen_spi(@src), then verify both through XFRM_MSG_GETSA dumps. Each
 * netlink request takes the next value of *@seq. @tunsrc and @tundst are
 * not used here.
 * Returns 0 on success, -1 on failure.
 */
static int xfrm_set(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst,
		struct xfrm_desc *desc)
{
	uint32_t spi = gen_spi(src);
	int check;

	if (xfrm_state_add(xfrm_sock, (*seq)++, spi, src, dst, desc) ||
	    xfrm_state_add(xfrm_sock, (*seq)++, spi, dst, src, desc)) {
		printk("Failed to add xfrm state");
		return -1;
	}

	/* Check dumps for XFRM_MSG_GETSA; both checks always run. */
	check = xfrm_state_check(xfrm_sock, (*seq)++, spi, src, dst, desc);
	check |= xfrm_state_check(xfrm_sock, (*seq)++, spi, dst, src, desc);
	if (check) {
		printk("Failed to check xfrm state");
		return -1;
	}

	return 0;
}
1109 
1110 static int xfrm_policy_add(int xfrm_sock, uint32_t seq, uint32_t spi,
1111 		struct in_addr src, struct in_addr dst, uint8_t dir,
1112 		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
1113 {
1114 	struct {
1115 		struct nlmsghdr			nh;
1116 		struct xfrm_userpolicy_info	info;
1117 		char				attrbuf[MAX_PAYLOAD];
1118 	} req;
1119 	struct xfrm_user_tmpl tmpl;
1120 
1121 	memset(&req, 0, sizeof(req));
1122 	memset(&tmpl, 0, sizeof(tmpl));
1123 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
1124 	req.nh.nlmsg_type	= XFRM_MSG_NEWPOLICY;
1125 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1126 	req.nh.nlmsg_seq	= seq;
1127 
1128 	/* Fill selector. */
1129 	memcpy(&req.info.sel.daddr, &dst, sizeof(tundst));
1130 	memcpy(&req.info.sel.saddr, &src, sizeof(tunsrc));
1131 	req.info.sel.family		= AF_INET;
1132 	req.info.sel.prefixlen_d	= PREFIX_LEN;
1133 	req.info.sel.prefixlen_s	= PREFIX_LEN;
1134 
1135 	/* Fill lifteme_cfg */
1136 	req.info.lft.soft_byte_limit	= XFRM_INF;
1137 	req.info.lft.hard_byte_limit	= XFRM_INF;
1138 	req.info.lft.soft_packet_limit	= XFRM_INF;
1139 	req.info.lft.hard_packet_limit	= XFRM_INF;
1140 
1141 	req.info.dir = dir;
1142 
1143 	/* Fill tmpl */
1144 	memcpy(&tmpl.id.daddr, &dst, sizeof(dst));
1145 	/* Note: zero-spi cannot be deleted */
1146 	tmpl.id.spi = spi;
1147 	tmpl.id.proto	= proto;
1148 	tmpl.family	= AF_INET;
1149 	memcpy(&tmpl.saddr, &src, sizeof(src));
1150 	tmpl.mode	= XFRM_MODE_TUNNEL;
1151 	tmpl.aalgos = (~(uint32_t)0);
1152 	tmpl.ealgos = (~(uint32_t)0);
1153 	tmpl.calgos = (~(uint32_t)0);
1154 
1155 	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &tmpl, sizeof(tmpl)))
1156 		return -1;
1157 
1158 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1159 		pr_err("send()");
1160 		return -1;
1161 	}
1162 
1163 	return netlink_check_answer(xfrm_sock);
1164 }
1165 
/*
 * xfrm_prepare - install the policy pair for one tunnel: an outbound
 * policy for src->dst and an inbound policy for dst->src, both using the
 * SPI from gen_spi(src).  Stops at the first failure.
 *
 * Returns 0 on success, -1 on failure.
 */
static int xfrm_prepare(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
{
	int err;

	err = xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst,
			      XFRM_POLICY_OUT, tunsrc, tundst, proto);
	if (err) {
		printk("Failed to add xfrm policy");
		return -1;
	}

	err = xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), dst, src,
			      XFRM_POLICY_IN, tunsrc, tundst, proto);
	if (err) {
		printk("Failed to add xfrm policy");
		return -1;
	}

	return 0;
}
1184 
1185 static int xfrm_policy_del(int xfrm_sock, uint32_t seq,
1186 		struct in_addr src, struct in_addr dst, uint8_t dir,
1187 		struct in_addr tunsrc, struct in_addr tundst)
1188 {
1189 	struct {
1190 		struct nlmsghdr			nh;
1191 		struct xfrm_userpolicy_id	id;
1192 		char				attrbuf[MAX_PAYLOAD];
1193 	} req;
1194 
1195 	memset(&req, 0, sizeof(req));
1196 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.id));
1197 	req.nh.nlmsg_type	= XFRM_MSG_DELPOLICY;
1198 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1199 	req.nh.nlmsg_seq	= seq;
1200 
1201 	/* Fill id */
1202 	memcpy(&req.id.sel.daddr, &dst, sizeof(tundst));
1203 	memcpy(&req.id.sel.saddr, &src, sizeof(tunsrc));
1204 	req.id.sel.family		= AF_INET;
1205 	req.id.sel.prefixlen_d		= PREFIX_LEN;
1206 	req.id.sel.prefixlen_s		= PREFIX_LEN;
1207 	req.id.dir = dir;
1208 
1209 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1210 		pr_err("send()");
1211 		return -1;
1212 	}
1213 
1214 	return netlink_check_answer(xfrm_sock);
1215 }
1216 
/*
 * xfrm_cleanup - remove the policy pair installed by xfrm_prepare():
 * the outbound src->dst policy and the inbound dst->src policy.
 *
 * This is teardown, so it is best-effort: the inbound policy is removed
 * even if removing the outbound one failed.  That way a partially prepared
 * pair is torn down as far as possible.  Each attempt consumes one value
 * of *@seq.
 *
 * Returns 0 if both policies were removed, -1 otherwise.
 */
static int xfrm_cleanup(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst)
{
	int ret = 0;

	if (xfrm_policy_del(xfrm_sock, (*seq)++, src, dst,
				XFRM_POLICY_OUT, tunsrc, tundst)) {
		printk("Failed to remove xfrm policy");
		ret = -1;
	}

	if (xfrm_policy_del(xfrm_sock, (*seq)++, dst, src,
				XFRM_POLICY_IN, tunsrc, tundst)) {
		printk("Failed to remove xfrm policy");
		ret = -1;
	}

	return ret;
}
1235 
1236 static int xfrm_state_del(int xfrm_sock, uint32_t seq, uint32_t spi,
1237 		struct in_addr src, struct in_addr dst, uint8_t proto)
1238 {
1239 	struct {
1240 		struct nlmsghdr		nh;
1241 		struct xfrm_usersa_id	id;
1242 		char			attrbuf[MAX_PAYLOAD];
1243 	} req;
1244 	xfrm_address_t saddr = {};
1245 
1246 	memset(&req, 0, sizeof(req));
1247 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.id));
1248 	req.nh.nlmsg_type	= XFRM_MSG_DELSA;
1249 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1250 	req.nh.nlmsg_seq	= seq;
1251 
1252 	memcpy(&req.id.daddr, &dst, sizeof(dst));
1253 	req.id.family		= AF_INET;
1254 	req.id.proto		= proto;
1255 	/* Note: zero-spi cannot be deleted */
1256 	req.id.spi = spi;
1257 
1258 	memcpy(&saddr, &src, sizeof(src));
1259 	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SRCADDR, &saddr, sizeof(saddr)))
1260 		return -1;
1261 
1262 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1263 		pr_err("send()");
1264 		return -1;
1265 	}
1266 
1267 	return netlink_check_answer(xfrm_sock);
1268 }
1269 
/*
 * xfrm_delete - remove the SA pair installed by xfrm_set(): src->dst and
 * dst->src, both keyed by gen_spi(src).  Stops at the first failure.
 *
 * Returns 0 on success, -1 on failure.
 */
static int xfrm_delete(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
{
	int err;

	err = xfrm_state_del(xfrm_sock, (*seq)++, gen_spi(src), src, dst, proto);
	if (err) {
		printk("Failed to remove xfrm state");
		return -1;
	}

	err = xfrm_state_del(xfrm_sock, (*seq)++, gen_spi(src), dst, src, proto);
	if (err) {
		printk("Failed to remove xfrm state");
		return -1;
	}

	return 0;
}
1286 
1287 static int xfrm_state_allocspi(int xfrm_sock, uint32_t *seq,
1288 		uint32_t spi, uint8_t proto)
1289 {
1290 	struct {
1291 		struct nlmsghdr			nh;
1292 		struct xfrm_userspi_info	spi;
1293 	} req;
1294 	struct {
1295 		struct nlmsghdr			nh;
1296 		union {
1297 			struct xfrm_usersa_info	info;
1298 			int error;
1299 		};
1300 	} answer;
1301 
1302 	memset(&req, 0, sizeof(req));
1303 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.spi));
1304 	req.nh.nlmsg_type	= XFRM_MSG_ALLOCSPI;
1305 	req.nh.nlmsg_flags	= NLM_F_REQUEST;
1306 	req.nh.nlmsg_seq	= (*seq)++;
1307 
1308 	req.spi.info.family	= AF_INET;
1309 	req.spi.min		= spi;
1310 	req.spi.max		= spi;
1311 	req.spi.info.id.proto	= proto;
1312 
1313 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1314 		pr_err("send()");
1315 		return KSFT_FAIL;
1316 	}
1317 
1318 	if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
1319 		pr_err("recv()");
1320 		return KSFT_FAIL;
1321 	} else if (answer.nh.nlmsg_type == XFRM_MSG_NEWSA) {
1322 		uint32_t new_spi = htonl(answer.info.id.spi);
1323 
1324 		if (new_spi != spi) {
1325 			printk("allocated spi is different from requested: %#x != %#x",
1326 					new_spi, spi);
1327 			return KSFT_FAIL;
1328 		}
1329 		return KSFT_PASS;
1330 	} else if (answer.nh.nlmsg_type != NLMSG_ERROR) {
1331 		printk("expected NLMSG_ERROR, got %d", (int)answer.nh.nlmsg_type);
1332 		return KSFT_FAIL;
1333 	}
1334 
1335 	printk("NLMSG_ERROR: %d: %s", answer.error, strerror(-answer.error));
1336 	return (answer.error) ? KSFT_FAIL : KSFT_PASS;
1337 }
1338 
/*
 * netlink_sock_bind - open a netlink socket of protocol @proto with
 * netlink_sock(), bind it to the multicast @groups mask, and sanity-check
 * the bound address with getsockname().
 *
 * Returns 0 with *@sock open and bound.  On failure returns -1; a socket
 * opened here is closed again, but *@sock is not reset.
 */
static int netlink_sock_bind(int *sock, uint32_t *seq, int proto, uint32_t groups)
{
	struct sockaddr_nl snl = {
		.nl_family = AF_NETLINK,
		.nl_groups = groups,
	};
	socklen_t addr_len = sizeof(snl);

	if (netlink_sock(sock, seq, proto)) {
		printk("Failed to open xfrm netlink socket");
		return -1;
	}

	if (bind(*sock, (struct sockaddr *)&snl, sizeof(snl)) < 0) {
		pr_err("bind()");
		goto err_close;
	}

	if (getsockname(*sock, (struct sockaddr *)&snl, &addr_len) < 0) {
		pr_err("getsockname()");
		goto err_close;
	}

	if (addr_len != sizeof(snl)) {
		printk("Wrong address length %d", addr_len);
		goto err_close;
	}

	if (snl.nl_family != AF_NETLINK) {
		printk("Wrong address family %d", snl.nl_family);
		goto err_close;
	}

	return 0;

err_close:
	close(*sock);
	return -1;
}
1377 
/*
 * xfrm_monitor_acquire - check that an XFRM_MSG_ACQUIRE sent on @xfrm_sock
 * is ACKed and then shows up on the XFRMNLGRP_ACQUIRE multicast group.
 *
 * A listener socket is bound to XFRMNLGRP_ACQUIRE *before* the request is
 * sent, so it is subscribed when the notification is emitted.  The request
 * carries marker values in aalgos/ealgos/calgos (0xfeed/0xbaad/0xbabe) and
 * one ESP XFRMA_TMPL attribute.  The test passes if the kernel ACKs with
 * error 0 and the message read from the listener has the same markers.
 *
 * @nr is unused.  Consumes one value of *@seq.
 *
 * Returns KSFT_PASS, KSFT_FAIL, or the (negative) netlink error code when
 * the kernel rejects the request.
 */
static int xfrm_monitor_acquire(int xfrm_sock, uint32_t *seq, unsigned int nr)
{
	struct {
		struct nlmsghdr nh;
		union {
			struct xfrm_user_acquire acq;
			int error;	/* NLMSG_ERROR payload starts with the error code */
		};
		char attrbuf[MAX_PAYLOAD];
	} req;
	struct xfrm_user_tmpl xfrm_tmpl = {};
	int xfrm_listen = -1, ret = KSFT_FAIL;
	uint32_t seq_listen;

	/* Subscribe first, otherwise the multicast notification is missed. */
	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_ACQUIRE))
		return KSFT_FAIL;

	memset(&req, 0, sizeof(req));
	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.acq));
	req.nh.nlmsg_type	= XFRM_MSG_ACQUIRE;
	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
	req.nh.nlmsg_seq	= (*seq)++;

	/* Marker values, checked again on the listener side below. */
	req.acq.policy.sel.family	= AF_INET;
	req.acq.aalgos	= 0xfeed;
	req.acq.ealgos	= 0xbaad;
	req.acq.calgos	= 0xbabe;

	xfrm_tmpl.family = AF_INET;
	xfrm_tmpl.id.proto = IPPROTO_ESP;
	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &xfrm_tmpl, sizeof(xfrm_tmpl)))
		goto out_close;

	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
		pr_err("send()");
		goto out_close;
	}

	/* req is reused as the receive buffer for the ACK. */
	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
		pr_err("recv()");
		goto out_close;
	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
		goto out_close;
	}

	if (req.error) {
		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
		ret = req.error;
		goto out_close;
	}

	/*
	 * NOTE(review): the type of the message received here is not checked;
	 * it is assumed to be the XFRM_MSG_ACQUIRE notification.
	 */
	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
		pr_err("recv()");
		goto out_close;
	}

	if (req.acq.aalgos != 0xfeed || req.acq.ealgos != 0xbaad
			|| req.acq.calgos != 0xbabe) {
		printk("xfrm_user_acquire has changed  %x %x %x",
				req.acq.aalgos, req.acq.ealgos, req.acq.calgos);
		goto out_close;
	}

	ret = KSFT_PASS;
out_close:
	close(xfrm_listen);
	return ret;
}
1447 
/*
 * xfrm_expire_state - check that XFRM_MSG_EXPIRE is ACKed and broadcast on
 * the XFRMNLGRP_EXPIRE multicast group.
 *
 * Steps:
 *  1. Add an SA for the INADDR_B child_ip(nr) -> grchild_ip(nr) pair with
 *     SPI gen_spi(src).
 *  2. Bind a listener to XFRMNLGRP_EXPIRE.
 *  3. Send XFRM_MSG_EXPIRE for that SA with hard = 0xff and expect a
 *     zero-error ACK.
 *  4. Read one message from the listener and require expire.hard == 1.
 *     Presumably this checks that the kernel normalizes the flag; confirm.
 *
 * Consumes two values of *@seq on the main path.
 *
 * Returns KSFT_PASS, KSFT_FAIL, or the (negative) netlink error code when
 * the kernel rejects the expire request.
 */
static int xfrm_expire_state(int xfrm_sock, uint32_t *seq,
		unsigned int nr, struct xfrm_desc *desc)
{
	struct {
		struct nlmsghdr nh;
		union {
			struct xfrm_user_expire expire;
			int error;	/* NLMSG_ERROR payload starts with the error code */
		};
	} req;
	struct in_addr src, dst;
	int xfrm_listen = -1, ret = KSFT_FAIL;
	uint32_t seq_listen;

	src = inet_makeaddr(INADDR_B, child_ip(nr));
	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));

	if (xfrm_state_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst, desc)) {
		printk("Failed to add xfrm state");
		return KSFT_FAIL;
	}

	/*
	 * NOTE(review): if this fails, the SA added above is not removed here.
	 * The listener must be bound before the expire request is sent.
	 */
	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
		return KSFT_FAIL;

	memset(&req, 0, sizeof(req));
	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.expire));
	req.nh.nlmsg_type	= XFRM_MSG_EXPIRE;
	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
	req.nh.nlmsg_seq	= (*seq)++;

	/* Identify the SA added above. */
	memcpy(&req.expire.state.id.daddr, &dst, sizeof(dst));
	req.expire.state.id.spi		= gen_spi(src);
	req.expire.state.id.proto	= desc->proto;
	req.expire.state.family		= AF_INET;
	req.expire.hard			= 0xff;

	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
		pr_err("send()");
		goto out_close;
	}

	/* req is reused as the receive buffer for the ACK. */
	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
		pr_err("recv()");
		goto out_close;
	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
		goto out_close;
	}

	if (req.error) {
		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
		ret = req.error;
		goto out_close;
	}

	/* NOTE(review): the notification's nlmsg_type is not checked. */
	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
		pr_err("recv()");
		goto out_close;
	}

	/* Sent 0xff; the notification is required to carry exactly 1. */
	if (req.expire.hard != 0x1) {
		printk("expire.hard is not set: %x", req.expire.hard);
		goto out_close;
	}

	ret = KSFT_PASS;
out_close:
	close(xfrm_listen);
	return ret;
}
1519 
/*
 * xfrm_expire_policy - check that XFRM_MSG_POLEXPIRE is ACKed and
 * broadcast on the XFRMNLGRP_EXPIRE multicast group.
 *
 * Steps:
 *  1. Add an outbound policy for the INADDR_B child_ip(nr) ->
 *     grchild_ip(nr) pair via xfrm_policy_add().
 *  2. Bind a listener to XFRMNLGRP_EXPIRE.
 *  3. Send XFRM_MSG_POLEXPIRE with a selector mirroring the one used by
 *     xfrm_policy_add() and hard = 0xff, and expect a zero-error ACK.
 *  4. Read one message from the listener and require expire.hard == 1.
 *
 * Consumes two values of *@seq on the main path.
 *
 * Returns KSFT_PASS, KSFT_FAIL, or the (negative) netlink error code when
 * the kernel rejects the expire request.
 */
static int xfrm_expire_policy(int xfrm_sock, uint32_t *seq,
		unsigned int nr, struct xfrm_desc *desc)
{
	struct {
		struct nlmsghdr nh;
		union {
			struct xfrm_user_polexpire expire;
			int error;	/* NLMSG_ERROR payload starts with the error code */
		};
	} req;
	struct in_addr src, dst, tunsrc, tundst;
	int xfrm_listen = -1, ret = KSFT_FAIL;
	uint32_t seq_listen;

	src = inet_makeaddr(INADDR_B, child_ip(nr));
	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
	tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
	tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));

	if (xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst,
				XFRM_POLICY_OUT, tunsrc, tundst, desc->proto)) {
		printk("Failed to add xfrm policy");
		return KSFT_FAIL;
	}

	/*
	 * NOTE(review): if this fails, the policy added above is not removed
	 * here.  The listener must be bound before the request is sent.
	 */
	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
		return KSFT_FAIL;

	memset(&req, 0, sizeof(req));
	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.expire));
	req.nh.nlmsg_type	= XFRM_MSG_POLEXPIRE;
	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
	req.nh.nlmsg_seq	= (*seq)++;

	/*
	 * Fill selector (same as xfrm_policy_add()).  sizeof(tundst/tunsrc)
	 * equals sizeof(dst/src): all are struct in_addr.
	 */
	memcpy(&req.expire.pol.sel.daddr, &dst, sizeof(tundst));
	memcpy(&req.expire.pol.sel.saddr, &src, sizeof(tunsrc));
	req.expire.pol.sel.family	= AF_INET;
	req.expire.pol.sel.prefixlen_d	= PREFIX_LEN;
	req.expire.pol.sel.prefixlen_s	= PREFIX_LEN;
	req.expire.pol.dir		= XFRM_POLICY_OUT;
	req.expire.hard			= 0xff;

	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
		pr_err("send()");
		goto out_close;
	}

	/* req is reused as the receive buffer for the ACK. */
	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
		pr_err("recv()");
		goto out_close;
	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
		goto out_close;
	}

	if (req.error) {
		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
		ret = req.error;
		goto out_close;
	}

	/* NOTE(review): the notification's nlmsg_type is not checked. */
	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
		pr_err("recv()");
		goto out_close;
	}

	/* Sent 0xff; the notification is required to carry exactly 1. */
	if (req.expire.hard != 0x1) {
		printk("expire.hard is not set: %x", req.expire.hard);
		goto out_close;
	}

	ret = KSFT_PASS;
out_close:
	close(xfrm_listen);
	return ret;
}
1597 
1598 static int xfrm_spdinfo_set_thresh(int xfrm_sock, uint32_t *seq,
1599 		unsigned thresh4_l, unsigned thresh4_r,
1600 		unsigned thresh6_l, unsigned thresh6_r,
1601 		bool add_bad_attr)
1602 
1603 {
1604 	struct {
1605 		struct nlmsghdr		nh;
1606 		union {
1607 			uint32_t	unused;
1608 			int		error;
1609 		};
1610 		char			attrbuf[MAX_PAYLOAD];
1611 	} req;
1612 	struct xfrmu_spdhthresh thresh;
1613 
1614 	memset(&req, 0, sizeof(req));
1615 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.unused));
1616 	req.nh.nlmsg_type	= XFRM_MSG_NEWSPDINFO;
1617 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1618 	req.nh.nlmsg_seq	= (*seq)++;
1619 
1620 	thresh.lbits = thresh4_l;
1621 	thresh.rbits = thresh4_r;
1622 	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SPD_IPV4_HTHRESH, &thresh, sizeof(thresh)))
1623 		return -1;
1624 
1625 	thresh.lbits = thresh6_l;
1626 	thresh.rbits = thresh6_r;
1627 	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SPD_IPV6_HTHRESH, &thresh, sizeof(thresh)))
1628 		return -1;
1629 
1630 	if (add_bad_attr) {
1631 		BUILD_BUG_ON(XFRMA_IF_ID <= XFRMA_SPD_MAX + 1);
1632 		if (rtattr_pack(&req.nh, sizeof(req), XFRMA_IF_ID, NULL, 0)) {
1633 			pr_err("adding attribute failed: no space");
1634 			return -1;
1635 		}
1636 	}
1637 
1638 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1639 		pr_err("send()");
1640 		return -1;
1641 	}
1642 
1643 	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1644 		pr_err("recv()");
1645 		return -1;
1646 	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1647 		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1648 		return -1;
1649 	}
1650 
1651 	if (req.error) {
1652 		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1653 		return -1;
1654 	}
1655 
1656 	return 0;
1657 }
1658 
1659 static int xfrm_spdinfo_attrs(int xfrm_sock, uint32_t *seq)
1660 {
1661 	struct {
1662 		struct nlmsghdr			nh;
1663 		union {
1664 			uint32_t	unused;
1665 			int		error;
1666 		};
1667 		char			attrbuf[MAX_PAYLOAD];
1668 	} req;
1669 
1670 	if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 31, 120, 16, false)) {
1671 		pr_err("Can't set SPD HTHRESH");
1672 		return KSFT_FAIL;
1673 	}
1674 
1675 	memset(&req, 0, sizeof(req));
1676 
1677 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.unused));
1678 	req.nh.nlmsg_type	= XFRM_MSG_GETSPDINFO;
1679 	req.nh.nlmsg_flags	= NLM_F_REQUEST;
1680 	req.nh.nlmsg_seq	= (*seq)++;
1681 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1682 		pr_err("send()");
1683 		return KSFT_FAIL;
1684 	}
1685 
1686 	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1687 		pr_err("recv()");
1688 		return KSFT_FAIL;
1689 	} else if (req.nh.nlmsg_type == XFRM_MSG_NEWSPDINFO) {
1690 		size_t len = NLMSG_PAYLOAD(&req.nh, sizeof(req.unused));
1691 		struct rtattr *attr = (void *)req.attrbuf;
1692 		int got_thresh = 0;
1693 
1694 		for (; RTA_OK(attr, len); attr = RTA_NEXT(attr, len)) {
1695 			if (attr->rta_type == XFRMA_SPD_IPV4_HTHRESH) {
1696 				struct xfrmu_spdhthresh *t = RTA_DATA(attr);
1697 
1698 				got_thresh++;
1699 				if (t->lbits != 32 || t->rbits != 31) {
1700 					pr_err("thresh differ: %u, %u",
1701 							t->lbits, t->rbits);
1702 					return KSFT_FAIL;
1703 				}
1704 			}
1705 			if (attr->rta_type == XFRMA_SPD_IPV6_HTHRESH) {
1706 				struct xfrmu_spdhthresh *t = RTA_DATA(attr);
1707 
1708 				got_thresh++;
1709 				if (t->lbits != 120 || t->rbits != 16) {
1710 					pr_err("thresh differ: %u, %u",
1711 							t->lbits, t->rbits);
1712 					return KSFT_FAIL;
1713 				}
1714 			}
1715 		}
1716 		if (got_thresh != 2) {
1717 			pr_err("only %d thresh returned by XFRM_MSG_GETSPDINFO", got_thresh);
1718 			return KSFT_FAIL;
1719 		}
1720 	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1721 		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1722 		return KSFT_FAIL;
1723 	} else {
1724 		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1725 		return -1;
1726 	}
1727 
1728 	/* Restore the default */
1729 	if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 32, 128, 128, false)) {
1730 		pr_err("Can't restore SPD HTHRESH");
1731 		return KSFT_FAIL;
1732 	}
1733 
1734 	/*
1735 	 * At this moment xfrm uses nlmsg_parse_deprecated(), which
1736 	 * implies NL_VALIDATE_LIBERAL - ignoring attributes with
1737 	 * (type > maxtype). nla_parse_depricated_strict() would enforce
1738 	 * it. Or even stricter nla_parse().
1739 	 * Right now it's not expected to fail, but to be ignored.
1740 	 */
1741 	if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 32, 128, 128, true))
1742 		return KSFT_PASS;
1743 
1744 	return KSFT_PASS;
1745 }
1746 
/*
 * child_serv - run one CREATE_TUNNEL test case from the child side.
 *
 * Sequence:
 *  1. UDP ping over the plain INADDR_B addresses (no xfrm yet).
 *  2. Install policies (xfrm_prepare), then SAs (xfrm_set).
 *  3. UDP ping again from the INADDR_A address (tunsrc), through the xfrm
 *     tunnel.
 *  4. Remove the SAs (xfrm_delete) and policies (xfrm_cleanup).
 *
 * Before each xfrm step, the matching MSG_XFRM_* command carrying @desc is
 * written to @cmd_fd.  grand_child_serv() on the other end then performs
 * the mirror operation.  The order of these messages is therefore part of
 * the protocol between the two processes.
 *
 * Returns KSFT_PASS only if every step, including teardown, succeeded;
 * otherwise KSFT_FAIL.
 */
static int child_serv(int xfrm_sock, uint32_t *seq,
		unsigned int nr, int cmd_fd, void *buf, struct xfrm_desc *desc)
{
	struct in_addr src, dst, tunsrc, tundst;
	struct test_desc msg;
	int ret = KSFT_FAIL;

	src = inet_makeaddr(INADDR_B, child_ip(nr));
	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
	tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
	tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));

	/* UDP pinging without xfrm */
	if (do_ping(cmd_fd, buf, page_size, src, true, 0, 0, udp_ping_send)) {
		printk("ping failed before setting xfrm");
		return KSFT_FAIL;
	}

	/* Peer installs its policies first, then we install ours. */
	memset(&msg, 0, sizeof(msg));
	msg.type = MSG_XFRM_PREPARE;
	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
	write_msg(cmd_fd, &msg, 1);

	if (xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
		printk("failed to prepare xfrm");
		/* No SAs exist yet, so skip straight to policy cleanup. */
		goto cleanup;
	}

	memset(&msg, 0, sizeof(msg));
	msg.type = MSG_XFRM_ADD;
	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
	write_msg(cmd_fd, &msg, 1);
	if (xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc)) {
		printk("failed to set xfrm");
		goto delete;
	}

	/* UDP pinging with xfrm tunnel */
	if (do_ping(cmd_fd, buf, page_size, tunsrc,
				true, 0, 0, udp_ping_send)) {
		printk("ping failed for xfrm");
		goto delete;
	}

	ret = KSFT_PASS;
delete:
	/* xfrm delete */
	memset(&msg, 0, sizeof(msg));
	msg.type = MSG_XFRM_DEL;
	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
	write_msg(cmd_fd, &msg, 1);

	/* NOTE(review): "failed ping" wording is misleading; this is SA removal. */
	if (xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
		printk("failed ping to remove xfrm");
		ret = KSFT_FAIL;
	}

cleanup:
	memset(&msg, 0, sizeof(msg));
	msg.type = MSG_XFRM_CLEANUP;
	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
	write_msg(cmd_fd, &msg, 1);
	if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst)) {
		printk("failed ping to cleanup xfrm");
		ret = KSFT_FAIL;
	}
	return ret;
}
1815 
/*
 * child_f - main loop of the child test process (runs in nsfd_childa).
 *
 * Opens an xfrm netlink socket and does an MSG_ACK handshake with the peer
 * on @cmd_fd.  It then reads struct xfrm_desc records from @test_desc_fd
 * until EOF and dispatches each one on desc.type.  Each result is reported
 * with write_test_result().  After EOF it sends MSG_EXIT to the peer.
 *
 * Never returns.  It exits with KSFT_PASS after EOF, or with KSFT_FAIL on
 * setup failure, a short read, or an unknown descriptor type.
 */
static int child_f(unsigned int nr, int test_desc_fd, int cmd_fd, void *buf)
{
	struct xfrm_desc desc;
	struct test_desc msg;
	int xfrm_sock = -1;
	uint32_t seq;

	if (switch_ns(nsfd_childa))
		exit(KSFT_FAIL);

	if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
		printk("Failed to open xfrm netlink socket");
		exit(KSFT_FAIL);
	}

	/* Check that seq sock is ready, just for sure. */
	memset(&msg, 0, sizeof(msg));
	msg.type = MSG_ACK;
	write_msg(cmd_fd, &msg, 1);
	read_msg(cmd_fd, &msg, 1);
	if (msg.type != MSG_ACK) {
		printk("Ack failed");
		exit(KSFT_FAIL);
	}

	for (;;) {
		ssize_t received = read(test_desc_fd, &desc, sizeof(desc));
		int ret;

		if (received == 0) /* EOF */
			break;

		if (received != sizeof(desc)) {
			pr_err("read() returned %zd", received);
			exit(KSFT_FAIL);
		}

		switch (desc.type) {
		case CREATE_TUNNEL:
			ret = child_serv(xfrm_sock, &seq, nr,
					 cmd_fd, buf, &desc);
			break;
		case ALLOCATE_SPI:
			/* -1 becomes 0xffffffff: ask for exactly that SPI. */
			ret = xfrm_state_allocspi(xfrm_sock, &seq,
						  -1, desc.proto);
			break;
		case MONITOR_ACQUIRE:
			ret = xfrm_monitor_acquire(xfrm_sock, &seq, nr);
			break;
		case EXPIRE_STATE:
			ret = xfrm_expire_state(xfrm_sock, &seq, nr, &desc);
			break;
		case EXPIRE_POLICY:
			ret = xfrm_expire_policy(xfrm_sock, &seq, nr, &desc);
			break;
		case SPDINFO_ATTRS:
			ret = xfrm_spdinfo_attrs(xfrm_sock, &seq);
			break;
		default:
			printk("Unknown desc type %d", desc.type);
			exit(KSFT_FAIL);
		}
		write_test_result(ret, &desc);
	}

	close(xfrm_sock);

	/* Only .type is updated; the rest of msg is left over from the handshake. */
	msg.type = MSG_EXIT;
	write_msg(cmd_fd, &msg, 1);
	exit(KSFT_PASS);
}
1887 
/*
 * grand_child_serv - handle one command from the child, acting as the peer.
 *
 * Addresses are the mirror of child_serv(): this side is grchild_ip(nr),
 * the remote side is child_ip(nr), on both INADDR_B (plain) and INADDR_A
 * (tunnel).  MSG_EXIT terminates the process and MSG_ACK is echoed back.
 * MSG_PING answers a UDP ping.  The MSG_XFRM_* commands mirror the child's
 * xfrm setup and teardown.
 *
 * Failures are only logged: no status is sent back on @cmd_fd from the
 * xfrm branches.  On PREPARE/ADD/DEL failure the policies are cleaned up
 * locally via xfrm_cleanup().
 */
static void grand_child_serv(unsigned int nr, int cmd_fd, void *buf,
		struct test_desc *msg, int xfrm_sock, uint32_t *seq)
{
	struct in_addr src, dst, tunsrc, tundst;
	bool tun_reply;
	struct xfrm_desc *desc = &msg->body.xfrm_desc;

	src = inet_makeaddr(INADDR_B, grchild_ip(nr));
	dst = inet_makeaddr(INADDR_B, child_ip(nr));
	tunsrc = inet_makeaddr(INADDR_A, grchild_ip(nr));
	tundst = inet_makeaddr(INADDR_A, child_ip(nr));

	switch (msg->type) {
	case MSG_EXIT:
		exit(KSFT_PASS);
	case MSG_ACK:
		write_msg(cmd_fd, msg, 1);
		break;
	case MSG_PING:
		/*
		 * If the reply address differs from the child's plain INADDR_B
		 * address, reply from the tunnel address instead.
		 */
		tun_reply = memcmp(&dst, &msg->body.ping.reply_ip, sizeof(in_addr_t));
		/* UDP pinging without xfrm */
		/* NOTE(review): message text is also used for tunnelled replies. */
		if (do_ping(cmd_fd, buf, page_size, tun_reply ? tunsrc : src,
				false, msg->body.ping.port,
				msg->body.ping.reply_ip, udp_ping_reply)) {
			printk("ping failed before setting xfrm");
		}
		break;
	case MSG_XFRM_PREPARE:
		if (xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst,
					desc->proto)) {
			xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
			printk("failed to prepare xfrm");
		}
		break;
	case MSG_XFRM_ADD:
		if (xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc)) {
			xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
			printk("failed to set xfrm");
		}
		break;
	case MSG_XFRM_DEL:
		if (xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst,
					desc->proto)) {
			xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
			printk("failed to remove xfrm");
		}
		break;
	case MSG_XFRM_CLEANUP:
		if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst)) {
			printk("failed to cleanup xfrm");
		}
		break;
	default:
		printk("got unknown msg type %d", msg->type);
	}
}
1944 
1945 static int grand_child_f(unsigned int nr, int cmd_fd, void *buf)
1946 {
1947 	struct test_desc msg;
1948 	int xfrm_sock = -1;
1949 	uint32_t seq;
1950 
1951 	if (switch_ns(nsfd_childb))
1952 		exit(KSFT_FAIL);
1953 
1954 	if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
1955 		printk("Failed to open xfrm netlink socket");
1956 		exit(KSFT_FAIL);
1957 	}
1958 
1959 	do {
1960 		read_msg(cmd_fd, &msg, 1);
1961 		grand_child_serv(nr, cmd_fd, buf, &msg, xfrm_sock, &seq);
1962 	} while (1);
1963 
1964 	close(xfrm_sock);
1965 	exit(KSFT_FAIL);
1966 }
1967 
/*
 * start_child - set up test pair @nr and fork its two worker processes.
 *
 * init_child() is called for both nsfd_childa and nsfd_childb with the
 * local/peer address indexes swapped.  Presumably it configures the
 * @veth ends in those namespaces (TODO confirm).  Then:
 *
 *   caller --fork--> child --fork--> grandchild
 *     |                |                |
 *   returns        child_f()        grand_child_f()
 *   switch_ns()    (test_desc_fd[0],   (cmd_sock[0])
 *                   cmd_sock[1])
 *
 * The child closes the write end test_desc_fd[1] and maps one shared
 * anonymous page (filled by randomize_buffer()).  It then creates the
 * SOCK_SEQPACKET cmd_sock pair and forks.  Both workers receive that page
 * as their buffer.
 *
 * In the caller, returns switch_ns(nsfd_parent), or -1 if init_child() or
 * the first fork() fails.
 *
 * NOTE(review): after the first fork, error paths in the child and
 * grandchild return -1 to the caller instead of exiting.  Those processes
 * would then continue running the caller's code.
 */
static int start_child(unsigned int nr, char *veth, int test_desc_fd[2])
{
	int cmd_sock[2];
	void *data_map;
	pid_t child;

	if (init_child(nsfd_childa, veth, child_ip(nr), grchild_ip(nr)))
		return -1;

	if (init_child(nsfd_childb, veth, grchild_ip(nr), child_ip(nr)))
		return -1;

	child = fork();
	if (child < 0) {
		pr_err("fork()");
		return -1;
	} else if (child) {
		/* in parent - selftest */
		return switch_ns(nsfd_parent);
	}

	/* Child only reads test descriptors; drop the write end. */
	if (close(test_desc_fd[1])) {
		pr_err("close()");
		return -1;
	}

	/* child */
	/* MAP_SHARED so the buffer stays shared with the grandchild after fork. */
	data_map = mmap(0, page_size, PROT_READ | PROT_WRITE,
			MAP_SHARED | MAP_ANONYMOUS, -1, 0);
	if (data_map == MAP_FAILED) {
		pr_err("mmap()");
		return -1;
	}

	randomize_buffer(data_map, page_size);

	if (socketpair(PF_LOCAL, SOCK_SEQPACKET, 0, cmd_sock)) {
		pr_err("socketpair()");
		return -1;
	}

	child = fork();
	if (child < 0) {
		pr_err("fork()");
		return -1;
	} else if (child) {
		if (close(cmd_sock[0])) {
			pr_err("close()");
			return -1;
		}
		return child_f(nr, test_desc_fd[0], cmd_sock[1], data_map);
	}
	if (close(cmd_sock[1])) {
		pr_err("close()");
		return -1;
	}
	return grand_child_f(nr, cmd_sock[0], data_map);
}
2026 
2027 static void exit_usage(char **argv)
2028 {
2029 	printk("Usage: %s [nr_process]", argv[0]);
2030 	exit(KSFT_FAIL);
2031 }
2032 
2033 static int __write_desc(int test_desc_fd, struct xfrm_desc *desc)
2034 {
2035 	ssize_t ret;
2036 
2037 	ret = write(test_desc_fd, desc, sizeof(*desc));
2038 
2039 	if (ret == sizeof(*desc))
2040 		return 0;
2041 
2042 	pr_err("Writing test's desc failed %ld", ret);
2043 
2044 	return -1;
2045 }
2046 
2047 static int write_desc(int proto, int test_desc_fd,
2048 		char *a, char *e, char *c, char *ae)
2049 {
2050 	struct xfrm_desc desc = {};
2051 
2052 	desc.type = CREATE_TUNNEL;
2053 	desc.proto = proto;
2054 
2055 	if (a)
2056 		strncpy(desc.a_algo, a, ALGO_LEN - 1);
2057 	if (e)
2058 		strncpy(desc.e_algo, e, ALGO_LEN - 1);
2059 	if (c)
2060 		strncpy(desc.c_algo, c, ALGO_LEN - 1);
2061 	if (ae)
2062 		strncpy(desc.ae_algo, ae, ALGO_LEN - 1);
2063 
2064 	return __write_desc(test_desc_fd, &desc);
2065 }
2066 
/* Protocols for which write_test_plan() queues tests, in this order. */
int proto_list[] = { IPPROTO_AH, IPPROTO_COMP, IPPROTO_ESP };
/*
 * Authentication algorithms: each is used alone for AH and paired with
 * every e_list entry for ESP.
 */
char *ah_list[] = {
	"digest_null", "hmac(md5)", "hmac(sha1)", "hmac(sha256)",
	"hmac(sha384)", "hmac(sha512)", "hmac(rmd160)",
	"xcbc(aes)", "cmac(aes)"
};
/* Compression algorithms used for IPPROTO_COMP. */
char *comp_list[] = {
	"deflate",
#if 0
	/* No compression backend realization */
	"lzs", "lzjh"
#endif
};
/* Encryption algorithms, combined with each ah_list entry for ESP. */
char *e_list[] = {
	"ecb(cipher_null)", "cbc(des)", "cbc(des3_ede)", "cbc(cast5)",
	"cbc(blowfish)", "cbc(aes)", "cbc(serpent)", "cbc(camellia)",
	"cbc(twofish)", "rfc3686(ctr(aes))"
};
/*
 * AEAD algorithms for ESP.  All entries are currently disabled, so this is
 * an empty array (this relies on the GNU empty-initializer extension).
 */
char *ae_list[] = {
#if 0
	/* not implemented */
	"rfc4106(gcm(aes))", "rfc4309(ccm(aes))", "rfc4543(gcm(aes))",
	"rfc7539esp(chacha20,poly1305)"
#endif
};

/*
 * Number of CREATE_TUNNEL tests write_proto_plan() queues over proto_list:
 * AH (ah_list) + COMP (comp_list) + ESP (ah_list x e_list, plus ae_list).
 * Keep in sync with write_proto_plan().
 */
const unsigned int proto_plan = ARRAY_SIZE(ah_list) + ARRAY_SIZE(comp_list) \
				+ (ARRAY_SIZE(ah_list) * ARRAY_SIZE(e_list)) \
				+ ARRAY_SIZE(ae_list);
2096 
2097 static int write_proto_plan(int fd, int proto)
2098 {
2099 	unsigned int i;
2100 
2101 	switch (proto) {
2102 	case IPPROTO_AH:
2103 		for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
2104 			if (write_desc(proto, fd, ah_list[i], 0, 0, 0))
2105 				return -1;
2106 		}
2107 		break;
2108 	case IPPROTO_COMP:
2109 		for (i = 0; i < ARRAY_SIZE(comp_list); i++) {
2110 			if (write_desc(proto, fd, 0, 0, comp_list[i], 0))
2111 				return -1;
2112 		}
2113 		break;
2114 	case IPPROTO_ESP:
2115 		for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
2116 			int j;
2117 
2118 			for (j = 0; j < ARRAY_SIZE(e_list); j++) {
2119 				if (write_desc(proto, fd, ah_list[i],
2120 							e_list[j], 0, 0))
2121 					return -1;
2122 			}
2123 		}
2124 		for (i = 0; i < ARRAY_SIZE(ae_list); i++) {
2125 			if (write_desc(proto, fd, 0, 0, 0, ae_list[i]))
2126 				return -1;
2127 		}
2128 		break;
2129 	default:
2130 		printk("BUG: Specified unknown proto %d", proto);
2131 		return -1;
2132 	}
2133 
2134 	return 0;
2135 }
2136 
2137 /*
2138  * Some structures in xfrm uapi header differ in size between
2139  * 64-bit and 32-bit ABI:
2140  *
2141  *             32-bit UABI               |            64-bit UABI
2142  *  -------------------------------------|-------------------------------------
2143  *   sizeof(xfrm_usersa_info)     = 220  |  sizeof(xfrm_usersa_info)     = 224
2144  *   sizeof(xfrm_userpolicy_info) = 164  |  sizeof(xfrm_userpolicy_info) = 168
2145  *   sizeof(xfrm_userspi_info)    = 228  |  sizeof(xfrm_userspi_info)    = 232
2146  *   sizeof(xfrm_user_acquire)    = 276  |  sizeof(xfrm_user_acquire)    = 280
2147  *   sizeof(xfrm_user_expire)     = 224  |  sizeof(xfrm_user_expire)     = 232
2148  *   sizeof(xfrm_user_polexpire)  = 168  |  sizeof(xfrm_user_polexpire)  = 176
2149  *
 * Check the structures affected by the UABI difference.
 * Also, check translation for xfrm_set_spdinfo: it has its own attributes
 * which need to be correctly copied, but not translated.
2153  */
2154 const unsigned int compat_plan = 5;
2155 static int write_compat_struct_tests(int test_desc_fd)
2156 {
2157 	struct xfrm_desc desc = {};
2158 
2159 	desc.type = ALLOCATE_SPI;
2160 	desc.proto = IPPROTO_AH;
2161 	strncpy(desc.a_algo, ah_list[0], ALGO_LEN - 1);
2162 
2163 	if (__write_desc(test_desc_fd, &desc))
2164 		return -1;
2165 
2166 	desc.type = MONITOR_ACQUIRE;
2167 	if (__write_desc(test_desc_fd, &desc))
2168 		return -1;
2169 
2170 	desc.type = EXPIRE_STATE;
2171 	if (__write_desc(test_desc_fd, &desc))
2172 		return -1;
2173 
2174 	desc.type = EXPIRE_POLICY;
2175 	if (__write_desc(test_desc_fd, &desc))
2176 		return -1;
2177 
2178 	desc.type = SPDINFO_ATTRS;
2179 	if (__write_desc(test_desc_fd, &desc))
2180 		return -1;
2181 
2182 	return 0;
2183 }
2184 
2185 static int write_test_plan(int test_desc_fd)
2186 {
2187 	unsigned int i;
2188 	pid_t child;
2189 
2190 	child = fork();
2191 	if (child < 0) {
2192 		pr_err("fork()");
2193 		return -1;
2194 	}
2195 	if (child) {
2196 		if (close(test_desc_fd))
2197 			printk("close(): %m");
2198 		return 0;
2199 	}
2200 
2201 	if (write_compat_struct_tests(test_desc_fd))
2202 		exit(KSFT_FAIL);
2203 
2204 	for (i = 0; i < ARRAY_SIZE(proto_list); i++) {
2205 		if (write_proto_plan(test_desc_fd, proto_list[i]))
2206 			exit(KSFT_FAIL);
2207 	}
2208 
2209 	exit(KSFT_PASS);
2210 }
2211 
2212 static int children_cleanup(void)
2213 {
2214 	unsigned ret = KSFT_PASS;
2215 
2216 	while (1) {
2217 		int status;
2218 		pid_t p = wait(&status);
2219 
2220 		if ((p < 0) && errno == ECHILD)
2221 			break;
2222 
2223 		if (p < 0) {
2224 			pr_err("wait()");
2225 			return KSFT_FAIL;
2226 		}
2227 
2228 		if (!WIFEXITED(status)) {
2229 			ret = KSFT_FAIL;
2230 			continue;
2231 		}
2232 
2233 		if (WEXITSTATUS(status) == KSFT_FAIL)
2234 			ret = KSFT_FAIL;
2235 	}
2236 
2237 	return ret;
2238 }
2239 
2240 typedef void (*print_res)(const char *, ...);
2241 
2242 static int check_results(void)
2243 {
2244 	struct test_result tr = {};
2245 	struct xfrm_desc *d = &tr.desc;
2246 	int ret = KSFT_PASS;
2247 
2248 	while (1) {
2249 		ssize_t received = read(results_fd[0], &tr, sizeof(tr));
2250 		print_res result;
2251 
2252 		if (received == 0) /* EOF */
2253 			break;
2254 
2255 		if (received != sizeof(tr)) {
2256 			pr_err("read() returned %zd", received);
2257 			return KSFT_FAIL;
2258 		}
2259 
2260 		switch (tr.res) {
2261 		case KSFT_PASS:
2262 			result = ksft_test_result_pass;
2263 			break;
2264 		case KSFT_FAIL:
2265 		default:
2266 			result = ksft_test_result_fail;
2267 			ret = KSFT_FAIL;
2268 		}
2269 
2270 		result(" %s: [%u, '%s', '%s', '%s', '%s', %u]\n",
2271 		       desc_name[d->type], (unsigned int)d->proto, d->a_algo,
2272 		       d->e_algo, d->c_algo, d->ae_algo, d->icv_len);
2273 	}
2274 
2275 	return ret;
2276 }
2277 
2278 int main(int argc, char **argv)
2279 {
2280 	unsigned int nr_process = 1;
2281 	int route_sock = -1, ret = KSFT_SKIP;
2282 	int test_desc_fd[2];
2283 	uint32_t route_seq;
2284 	unsigned int i;
2285 
2286 	if (argc > 2)
2287 		exit_usage(argv);
2288 
2289 	if (argc > 1) {
2290 		char *endptr;
2291 
2292 		errno = 0;
2293 		nr_process = strtol(argv[1], &endptr, 10);
2294 		if ((errno == ERANGE && (nr_process == LONG_MAX || nr_process == LONG_MIN))
2295 				|| (errno != 0 && nr_process == 0)
2296 				|| (endptr == argv[1]) || (*endptr != '\0')) {
2297 			printk("Failed to parse [nr_process]");
2298 			exit_usage(argv);
2299 		}
2300 
2301 		if (nr_process > MAX_PROCESSES || !nr_process) {
2302 			printk("nr_process should be between [1; %u]",
2303 					MAX_PROCESSES);
2304 			exit_usage(argv);
2305 		}
2306 	}
2307 
2308 	srand(time(NULL));
2309 	page_size = sysconf(_SC_PAGESIZE);
2310 	if (page_size < 1)
2311 		ksft_exit_skip("sysconf(): %m\n");
2312 
2313 	if (pipe2(test_desc_fd, O_DIRECT) < 0)
2314 		ksft_exit_skip("pipe(): %m\n");
2315 
2316 	if (pipe2(results_fd, O_DIRECT) < 0)
2317 		ksft_exit_skip("pipe(): %m\n");
2318 
2319 	if (init_namespaces())
2320 		ksft_exit_skip("Failed to create namespaces\n");
2321 
2322 	if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE))
2323 		ksft_exit_skip("Failed to open netlink route socket\n");
2324 
2325 	for (i = 0; i < nr_process; i++) {
2326 		char veth[VETH_LEN];
2327 
2328 		snprintf(veth, VETH_LEN, VETH_FMT, i);
2329 
2330 		if (veth_add(route_sock, route_seq++, veth, nsfd_childa, veth, nsfd_childb)) {
2331 			close(route_sock);
2332 			ksft_exit_fail_msg("Failed to create veth device");
2333 		}
2334 
2335 		if (start_child(i, veth, test_desc_fd)) {
2336 			close(route_sock);
2337 			ksft_exit_fail_msg("Child %u failed to start", i);
2338 		}
2339 	}
2340 
2341 	if (close(route_sock) || close(test_desc_fd[0]) || close(results_fd[1]))
2342 		ksft_exit_fail_msg("close(): %m");
2343 
2344 	ksft_set_plan(proto_plan + compat_plan);
2345 
2346 	if (write_test_plan(test_desc_fd[1]))
2347 		ksft_exit_fail_msg("Failed to write test plan to pipe");
2348 
2349 	ret = check_results();
2350 
2351 	if (children_cleanup() == KSFT_FAIL)
2352 		exit(KSFT_FAIL);
2353 
2354 	exit(ret);
2355 }
2356