1 // SPDX-License-Identifier: GPL-2.0-only
2 #include <errno.h>
3 #include <stdbool.h>
4 #include <stdio.h>
5 #include <string.h>
6 #include <unistd.h>
7 
8 #include <arpa/inet.h>
9 
10 #include <linux/err.h>
11 #include <linux/in.h>
12 #include <linux/in6.h>
13 
14 #include "bpf_util.h"
15 #include "network_helpers.h"
16 
/* Printable description of the current errno; "None" when errno is 0. */
#define clean_errno() (errno == 0 ? "None" : strerror(errno))
/*
 * log_err - print a diagnostic to stderr.
 * @MSG: must be a string literal (it is pasted next to other literals to
 *       build the format string), optionally followed by printf-style args.
 *
 * The message is prefixed with "(file:line: errno: <strerror>)" and ends with
 * a newline. errno is saved before printing and restored afterwards, so the
 * caller can still inspect the errno of the call that failed. Implemented as
 * a GNU statement expression ({ ... }).
 */
#define log_err(MSG, ...) ({						\
			int __save = errno;				\
			fprintf(stderr, "(%s:%d: errno: %s) " MSG "\n", \
				__FILE__, __LINE__, clean_errno(),	\
				##__VA_ARGS__);				\
			errno = __save;					\
})
25 
/*
 * Template Ethernet + IPv4 + TCP packet for tests.
 *
 * ihl and doff are in 32-bit words, so 5 means a 20-byte header without
 * options. tot_len is MAGIC_BYTES (defined outside this file). urg_ptr = 123
 * looks like an arbitrary recognizable marker; NOTE(review): confirm against
 * the consumers of this template.
 *
 * Every field not listed here is zero: MAC/IP addresses, ports, iph.version,
 * ttl and checksums.
 */
struct ipv4_packet pkt_v4 = {
	.eth.h_proto = __bpf_constant_htons(ETH_P_IP),
	.iph.ihl = 5,
	.iph.protocol = IPPROTO_TCP,
	.iph.tot_len = __bpf_constant_htons(MAGIC_BYTES),
	.tcp.urg_ptr = 123,
	.tcp.doff = 5,
};
34 
/*
 * Template Ethernet + IPv6 + TCP packet for tests.
 *
 * payload_len is MAGIC_BYTES (defined outside this file). tcp.doff = 5 means a
 * 20-byte TCP header without options. urg_ptr = 123 mirrors pkt_v4;
 * NOTE(review): presumably a marker value, confirm with consumers.
 *
 * Every field not listed here is zero: MAC/IP addresses, ports, iph.version,
 * hop_limit and checksums.
 */
struct ipv6_packet pkt_v6 = {
	.eth.h_proto = __bpf_constant_htons(ETH_P_IPV6),
	.iph.nexthdr = IPPROTO_TCP,
	.iph.payload_len = __bpf_constant_htons(MAGIC_BYTES),
	.tcp.urg_ptr = 123,
	.tcp.doff = 5,
};
42 
/*
 * settimeo() - apply the same receive and send timeout to a socket.
 * @fd:         socket to configure.
 * @timeout_ms: timeout in milliseconds; any value <= 0 selects 3 seconds.
 *
 * Sets SO_RCVTIMEO first, then SO_SNDTIMEO, and stops at the first failure.
 *
 * Return: 0 on success, -1 if either setsockopt() fails (errno is left as
 * set by setsockopt()).
 */
static int settimeo(int fd, int timeout_ms)
{
	struct timeval tv = {};

	if (timeout_ms <= 0) {
		/* Caller did not ask for a specific timeout: use the default. */
		tv.tv_sec = 3;
	} else {
		tv.tv_sec = timeout_ms / 1000;
		tv.tv_usec = (timeout_ms % 1000) * 1000;
	}

	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) != 0) {
		log_err("Failed to set SO_RCVTIMEO");
		return -1;
	}
	if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)) != 0) {
		log_err("Failed to set SO_SNDTIMEO");
		return -1;
	}
	return 0;
}
66 
/*
 * Close @fd without clobbering errno. The error-path cleanup then still
 * reports the errno of the call that actually failed rather than close()'s.
 * GNU statement expression.
 */
#define save_errno_close(fd) ({ int __save = errno; close(fd); errno = __save; })
68 
69 int start_server(int family, int type, const char *addr_str, __u16 port,
70 		 int timeout_ms)
71 {
72 	struct sockaddr_storage addr = {};
73 	socklen_t len;
74 	int fd;
75 
76 	if (make_sockaddr(family, addr_str, port, &addr, &len))
77 		return -1;
78 
79 	fd = socket(family, type, 0);
80 	if (fd < 0) {
81 		log_err("Failed to create server socket");
82 		return -1;
83 	}
84 
85 	if (settimeo(fd, timeout_ms))
86 		goto error_close;
87 
88 	if (bind(fd, (const struct sockaddr *)&addr, len) < 0) {
89 		log_err("Failed to bind socket");
90 		goto error_close;
91 	}
92 
93 	if (type == SOCK_STREAM) {
94 		if (listen(fd, 1) < 0) {
95 			log_err("Failed to listed on socket");
96 			goto error_close;
97 		}
98 	}
99 
100 	return fd;
101 
102 error_close:
103 	save_errno_close(fd);
104 	return -1;
105 }
106 
/*
 * connect_fd_to_addr() - connect @fd to the first @addrlen bytes of @addr.
 *
 * Return: 0 when connect() succeeds, otherwise logs and returns -1 with
 * errno as left by connect().
 */
static int connect_fd_to_addr(int fd,
			      const struct sockaddr_storage *addr,
			      socklen_t addrlen)
{
	const struct sockaddr *peer = (const struct sockaddr *)addr;
	int rc = connect(fd, peer, addrlen);

	if (rc == 0)
		return 0;

	log_err("Failed to connect to server");
	return -1;
}
118 
/*
 * connect_to_fd() - open a new client socket connected to @server_fd.
 * @server_fd:  server socket. The client is created with the same socket
 *              type (SO_TYPE) and address family, and connects to the
 *              address reported by getsockname() on @server_fd.
 * @timeout_ms: SO_RCVTIMEO/SO_SNDTIMEO for the client in ms; <= 0 selects the
 *              3 s default applied by settimeo().
 *
 * Return: the connected client fd, or -1 on failure. If a step fails after
 * the client socket was created, it is closed with errno preserved.
 */
int connect_to_fd(int server_fd, int timeout_ms)
{
	struct sockaddr_storage addr;
	socklen_t addrlen, optlen;
	int fd, type;

	optlen = sizeof(type);
	if (getsockopt(server_fd, SOL_SOCKET, SO_TYPE, &type, &optlen)) {
		log_err("getsockopt(SO_TYPE)");
		return -1;
	}

	addrlen = sizeof(addr);
	if (getsockname(server_fd, (struct sockaddr *)&addr, &addrlen)) {
		log_err("Failed to get server addr");
		return -1;
	}

	/* ss_family is valid whatever the family; no sockaddr_in alias needed. */
	fd = socket(addr.ss_family, type, 0);
	if (fd < 0) {
		log_err("Failed to create client socket");
		return -1;
	}

	if (settimeo(fd, timeout_ms))
		goto error_close;

	if (connect_fd_to_addr(fd, &addr, addrlen))
		goto error_close;

	return fd;

error_close:
	save_errno_close(fd);
	return -1;
}
157 
/*
 * connect_fd_to_fd() - connect an existing client socket to a server socket.
 * @client_fd:  caller-created socket; timeouts are applied to it first.
 * @server_fd:  socket whose getsockname() address is the connect target.
 * @timeout_ms: SO_RCVTIMEO/SO_SNDTIMEO in ms; <= 0 selects a 3 s default.
 *
 * @client_fd is never closed here; on failure the caller still owns it.
 *
 * Return: 0 on success, -1 on failure.
 */
int connect_fd_to_fd(int client_fd, int server_fd, int timeout_ms)
{
	struct sockaddr_storage srv_addr;
	socklen_t srv_len = sizeof(srv_addr);

	if (settimeo(client_fd, timeout_ms))
		return -1;

	if (getsockname(server_fd, (struct sockaddr *)&srv_addr, &srv_len) != 0) {
		log_err("Failed to get server addr");
		return -1;
	}

	return connect_fd_to_addr(client_fd, &srv_addr, srv_len) ? -1 : 0;
}
176 
177 int make_sockaddr(int family, const char *addr_str, __u16 port,
178 		  struct sockaddr_storage *addr, socklen_t *len)
179 {
180 	if (family == AF_INET) {
181 		struct sockaddr_in *sin = (void *)addr;
182 
183 		sin->sin_family = AF_INET;
184 		sin->sin_port = htons(port);
185 		if (addr_str &&
186 		    inet_pton(AF_INET, addr_str, &sin->sin_addr) != 1) {
187 			log_err("inet_pton(AF_INET, %s)", addr_str);
188 			return -1;
189 		}
190 		if (len)
191 			*len = sizeof(*sin);
192 		return 0;
193 	} else if (family == AF_INET6) {
194 		struct sockaddr_in6 *sin6 = (void *)addr;
195 
196 		sin6->sin6_family = AF_INET6;
197 		sin6->sin6_port = htons(port);
198 		if (addr_str &&
199 		    inet_pton(AF_INET6, addr_str, &sin6->sin6_addr) != 1) {
200 			log_err("inet_pton(AF_INET6, %s)", addr_str);
201 			return -1;
202 		}
203 		if (len)
204 			*len = sizeof(*sin6);
205 		return 0;
206 	}
207 	return -1;
208 }
209