1 // SPDX-License-Identifier: GPL-2.0-only
2 #include <errno.h>
3 #include <stdbool.h>
4 #include <stdio.h>
5 #include <string.h>
6 #include <unistd.h>
7 
8 #include <arpa/inet.h>
9 
10 #include <linux/err.h>
11 #include <linux/in.h>
12 #include <linux/in6.h>
13 
14 #include "bpf_util.h"
15 #include "network_helpers.h"
16 
/* Text for the current errno value, or "None" when errno is 0. */
#define clean_errno() (errno ? strerror(errno) : "None")
/*
 * Print a printf-style message to stderr, prefixed with the source file,
 * line number and the description of the current errno.  errno is saved
 * before printing and restored afterwards, so the caller still sees the
 * errno of the call that failed.
 */
#define log_err(MSG, ...) ({						\
	int log_err_saved_errno_ = errno;				\
	fprintf(stderr, "(%s:%d: errno: %s) " MSG "\n",			\
		__FILE__, __LINE__, clean_errno(), ##__VA_ARGS__);	\
	errno = log_err_saved_errno_;					\
})
25 
26 struct ipv4_packet pkt_v4 = {
27 	.eth.h_proto = __bpf_constant_htons(ETH_P_IP),
28 	.iph.ihl = 5,
29 	.iph.protocol = IPPROTO_TCP,
30 	.iph.tot_len = __bpf_constant_htons(MAGIC_BYTES),
31 	.tcp.urg_ptr = 123,
32 	.tcp.doff = 5,
33 };
34 
35 struct ipv6_packet pkt_v6 = {
36 	.eth.h_proto = __bpf_constant_htons(ETH_P_IPV6),
37 	.iph.nexthdr = IPPROTO_TCP,
38 	.iph.payload_len = __bpf_constant_htons(MAGIC_BYTES),
39 	.tcp.urg_ptr = 123,
40 	.tcp.doff = 5,
41 };
42 
/*
 * Apply the same receive and send timeout (SO_RCVTIMEO, then SO_SNDTIMEO)
 * to fd.  A positive timeout_ms is used as given.  Zero or a negative
 * value selects a 3 second default.
 *
 * Returns 0 on success, or -1 (after logging) if either setsockopt() fails.
 */
static int settimeo(int fd, int timeout_ms)
{
	const bool use_default = timeout_ms <= 0;
	struct timeval tv = {
		.tv_sec = use_default ? 3 : timeout_ms / 1000,
		.tv_usec = use_default ? 0 : (timeout_ms % 1000) * 1000,
	};

	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) != 0) {
		log_err("Failed to set SO_RCVTIMEO");
		return -1;
	}

	if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)) != 0) {
		log_err("Failed to set SO_SNDTIMEO");
		return -1;
	}

	return 0;
}
66 
/* close(fd) while keeping the errno left by the operation that failed. */
#define save_errno_close(fd) ({						\
	int close_saved_errno_ = errno;					\
	close(fd);							\
	errno = close_saved_errno_;					\
})
68 
69 int start_server(int family, int type, const char *addr_str, __u16 port,
70 		 int timeout_ms)
71 {
72 	struct sockaddr_storage addr = {};
73 	socklen_t len;
74 	int fd;
75 
76 	if (make_sockaddr(family, addr_str, port, &addr, &len))
77 		return -1;
78 
79 	fd = socket(family, type, 0);
80 	if (fd < 0) {
81 		log_err("Failed to create server socket");
82 		return -1;
83 	}
84 
85 	if (settimeo(fd, timeout_ms))
86 		goto error_close;
87 
88 	if (bind(fd, (const struct sockaddr *)&addr, len) < 0) {
89 		log_err("Failed to bind socket");
90 		goto error_close;
91 	}
92 
93 	if (type == SOCK_STREAM) {
94 		if (listen(fd, 1) < 0) {
95 			log_err("Failed to listed on socket");
96 			goto error_close;
97 		}
98 	}
99 
100 	return fd;
101 
102 error_close:
103 	save_errno_close(fd);
104 	return -1;
105 }
106 
/*
 * Open a TCP client socket in server_fd's address family and send data to
 * the address server_fd is bound to, using sendto() with MSG_FASTOPEN.
 * Per tcp(7), this initiates the connection and carries the data in one
 * call.  Timeouts are applied with settimeo() before sending.
 *
 * Returns the client fd on success.  Returns -1 if any step fails or if
 * sendto() does not report exactly data_len bytes sent.  On failure after
 * the client socket was created, it is closed with errno preserved.
 */
int fastopen_connect(int server_fd, const char *data, unsigned int data_len,
		     int timeout_ms)
{
	struct sockaddr_storage addr;
	socklen_t addrlen = sizeof(addr);
	ssize_t ret;
	int fd;

	if (getsockname(server_fd, (struct sockaddr *)&addr, &addrlen)) {
		log_err("Failed to get server addr");
		return -1;
	}

	fd = socket(addr.ss_family, SOCK_STREAM, 0);
	if (fd < 0) {
		log_err("Failed to create client socket");
		return -1;
	}

	if (settimeo(fd, timeout_ms))
		goto error_close;

	ret = sendto(fd, data, data_len, MSG_FASTOPEN, (struct sockaddr *)&addr,
		     addrlen);
	/* Check for an error explicitly rather than relying on -1 being
	 * converted to UINT_MAX in a signed/unsigned comparison.
	 */
	if (ret < 0 || (size_t)ret != data_len) {
		log_err("sendto(data, %u) != %zd", data_len, ret);
		goto error_close;
	}

	return fd;

error_close:
	save_errno_close(fd);
	return -1;
}
143 
/*
 * connect() fd to the addrlen-byte socket address stored in *addr.
 * Returns 0 on success.  On failure it logs a message and returns -1.
 */
static int connect_fd_to_addr(int fd,
			      const struct sockaddr_storage *addr,
			      socklen_t addrlen)
{
	const struct sockaddr *peer = (const struct sockaddr *)addr;

	if (connect(fd, peer, addrlen) == 0)
		return 0;

	log_err("Failed to connect to server");
	return -1;
}
155 
/*
 * Create a client socket with the same address family and socket type as
 * server_fd, apply timeouts with settimeo(), and connect it to the address
 * server_fd is bound to.
 *
 * Returns the connected client fd, or -1 on failure.  If the failure
 * happens after the client socket was created, it is closed with errno
 * preserved.
 */
int connect_to_fd(int server_fd, int timeout_ms)
{
	struct sockaddr_storage addr;
	socklen_t addrlen, optlen;
	int fd, type;

	optlen = sizeof(type);
	if (getsockopt(server_fd, SOL_SOCKET, SO_TYPE, &type, &optlen)) {
		log_err("getsockopt(SO_TYPE)");
		return -1;
	}

	addrlen = sizeof(addr);
	if (getsockname(server_fd, (struct sockaddr *)&addr, &addrlen)) {
		log_err("Failed to get server addr");
		return -1;
	}

	/* Read the family from the generic storage; no sockaddr_in cast needed. */
	fd = socket(addr.ss_family, type, 0);
	if (fd < 0) {
		log_err("Failed to create client socket");
		return -1;
	}

	if (settimeo(fd, timeout_ms))
		goto error_close;

	if (connect_fd_to_addr(fd, &addr, addrlen))
		goto error_close;

	return fd;

error_close:
	save_errno_close(fd);
	return -1;
}
194 
/*
 * Apply timeouts to an existing client socket with settimeo(), then connect
 * it to the address server_fd is bound to.
 *
 * Returns 0 on success, -1 on failure.  client_fd is never closed here, so
 * the caller keeps ownership of it on every path.
 */
int connect_fd_to_fd(int client_fd, int server_fd, int timeout_ms)
{
	struct sockaddr_storage srv_addr;
	socklen_t srv_len = sizeof(srv_addr);

	if (settimeo(client_fd, timeout_ms) != 0)
		return -1;

	if (getsockname(server_fd, (struct sockaddr *)&srv_addr,
			&srv_len) != 0) {
		log_err("Failed to get server addr");
		return -1;
	}

	return connect_fd_to_addr(client_fd, &srv_addr, srv_len) ? -1 : 0;
}
213 
214 int make_sockaddr(int family, const char *addr_str, __u16 port,
215 		  struct sockaddr_storage *addr, socklen_t *len)
216 {
217 	if (family == AF_INET) {
218 		struct sockaddr_in *sin = (void *)addr;
219 
220 		sin->sin_family = AF_INET;
221 		sin->sin_port = htons(port);
222 		if (addr_str &&
223 		    inet_pton(AF_INET, addr_str, &sin->sin_addr) != 1) {
224 			log_err("inet_pton(AF_INET, %s)", addr_str);
225 			return -1;
226 		}
227 		if (len)
228 			*len = sizeof(*sin);
229 		return 0;
230 	} else if (family == AF_INET6) {
231 		struct sockaddr_in6 *sin6 = (void *)addr;
232 
233 		sin6->sin6_family = AF_INET6;
234 		sin6->sin6_port = htons(port);
235 		if (addr_str &&
236 		    inet_pton(AF_INET6, addr_str, &sin6->sin6_addr) != 1) {
237 			log_err("inet_pton(AF_INET6, %s)", addr_str);
238 			return -1;
239 		}
240 		if (len)
241 			*len = sizeof(*sin6);
242 		return 0;
243 	}
244 	return -1;
245 }
246