1 #ifndef __SOCKMAP_HELPERS__
2 #define __SOCKMAP_HELPERS__
3 
4 #include <linux/vm_sockets.h>
5 
/* Timeout, in seconds, passed by xaccept_nonblock() and xrecv_nonblock(). */
#define IO_TIMEOUT_SEC 30
/* Size of the message buffer FAIL_LIBBPF() hands to libbpf_strerror(). */
#define MAX_STRERR_LEN 256
/*
 * NOTE(review): not referenced anywhere in this header; presumably a length
 * limit for test names used by the including test files - confirm.
 */
#define MAX_TEST_NAME 80
9 
/*
 * Fallback for <linux/vm_sockets.h> versions that do not define
 * VMADDR_CID_LOCAL; init_addr_loopback_vsock() uses it as the loopback CID.
 */
#ifndef VMADDR_CID_LOCAL
#define VMADDR_CID_LOCAL 1
#endif

/* Marks a declaration as possibly unused so the compiler does not warn. */
#define __always_unused	__attribute__((__unused__))
16 
/*
 * Print a printf-style message tagged with the current function name and
 * line number via error_at_line() (status 0, so the process keeps running),
 * then mark the running test as failed with CHECK_FAIL(true).
 *
 * @errnum is handed to error_at_line() as its errnum argument.
 */
#define _FAIL(errnum, ...) \
	({ \
		error_at_line(0, (errnum), __func__, __LINE__, __VA_ARGS__); \
		CHECK_FAIL(true); \
	})
/* Fail with a message only (errnum 0). */
#define FAIL(...) _FAIL(0, __VA_ARGS__)
/* Fail with a message and the current errno value. */
#define FAIL_ERRNO(...) _FAIL(errno, __VA_ARGS__)
/*
 * Fail the test with "<msg>: <text>", where <text> is what libbpf_strerror()
 * writes for @err into a local MAX_STRERR_LEN-byte buffer.
 */
#define FAIL_LIBBPF(err, msg) \
	({ \
		char __errstr[MAX_STRERR_LEN]; \
		\
		libbpf_strerror((err), __errstr, sizeof(__errstr)); \
		FAIL("%s: %s", (msg), __errstr); \
	})
30 
31 /* Wrappers that fail the test on error and report it. */
32 
/*
 * accept_timeout() with an IO_TIMEOUT_SEC timeout. A result of -1 is
 * reported through FAIL_ERRNO("accept"). Evaluates to the call's result.
 */
#define xaccept_nonblock(fd, addr, len) \
	({ \
		int __newfd; \
		\
		__newfd = accept_timeout((fd), (addr), (len), IO_TIMEOUT_SEC); \
		if (__newfd == -1) { \
			FAIL_ERRNO("accept"); \
		} \
		__newfd; \
	})
41 
/*
 * bind() wrapper: a result of -1 is reported through FAIL_ERRNO("bind").
 * Evaluates to bind()'s return value.
 */
#define xbind(fd, addr, len) \
	({ \
		int __rc; \
		\
		__rc = bind((fd), (addr), (len)); \
		if (__rc == -1) { \
			FAIL_ERRNO("bind"); \
		} \
		__rc; \
	})
49 
/*
 * close() wrapper: a result of -1 is reported through FAIL_ERRNO("close").
 * Evaluates to close()'s return value.
 */
#define xclose(fd) \
	({ \
		int __rc; \
		\
		__rc = close((fd)); \
		if (__rc == -1) { \
			FAIL_ERRNO("close"); \
		} \
		__rc; \
	})
57 
/*
 * connect() wrapper: a result of -1 is reported through
 * FAIL_ERRNO("connect"). Evaluates to connect()'s return value.
 */
#define xconnect(fd, addr, len) \
	({ \
		int __rc; \
		\
		__rc = connect((fd), (addr), (len)); \
		if (__rc == -1) { \
			FAIL_ERRNO("connect"); \
		} \
		__rc; \
	})
65 
/*
 * getsockname() wrapper: a result of -1 is reported through
 * FAIL_ERRNO("getsockname"). Evaluates to getsockname()'s return value.
 */
#define xgetsockname(fd, addr, len) \
	({ \
		int __rc; \
		\
		__rc = getsockname((fd), (addr), (len)); \
		if (__rc == -1) { \
			FAIL_ERRNO("getsockname"); \
		} \
		__rc; \
	})
73 
/*
 * getsockopt() wrapper: a result of -1 is reported through FAIL_ERRNO(),
 * with the option name as spelled by the caller embedded in the message.
 * Evaluates to getsockopt()'s return value.
 */
#define xgetsockopt(fd, level, name, val, len) \
	({ \
		int __rc; \
		\
		__rc = getsockopt((fd), (level), (name), (val), (len)); \
		if (__rc == -1) { \
			FAIL_ERRNO("getsockopt(" #name ")"); \
		} \
		__rc; \
	})
81 
/*
 * listen() wrapper: a result of -1 is reported through FAIL_ERRNO("listen").
 * Evaluates to listen()'s return value.
 */
#define xlisten(fd, backlog) \
	({ \
		int __rc; \
		\
		__rc = listen((fd), (backlog)); \
		if (__rc == -1) { \
			FAIL_ERRNO("listen"); \
		} \
		__rc; \
	})
89 
/*
 * setsockopt() wrapper: a result of -1 is reported through FAIL_ERRNO(),
 * with the option name as spelled by the caller embedded in the message.
 * Evaluates to setsockopt()'s return value.
 */
#define xsetsockopt(fd, level, name, val, len) \
	({ \
		int __rc; \
		\
		__rc = setsockopt((fd), (level), (name), (val), (len)); \
		if (__rc == -1) { \
			FAIL_ERRNO("setsockopt(" #name ")"); \
		} \
		__rc; \
	})
97 
/*
 * send() wrapper: a result of -1 is reported through FAIL_ERRNO("send").
 * Evaluates to send()'s return value as ssize_t.
 */
#define xsend(fd, buf, len, flags) \
	({ \
		ssize_t __sent; \
		\
		__sent = send((fd), (buf), (len), (flags)); \
		if (__sent == -1) { \
			FAIL_ERRNO("send"); \
		} \
		__sent; \
	})
105 
/*
 * recv_timeout() with an IO_TIMEOUT_SEC timeout. A result of -1 is
 * reported through FAIL_ERRNO("recv"). Evaluates to the call's result,
 * held in an ssize_t.
 */
#define xrecv_nonblock(fd, buf, len, flags) \
	({ \
		ssize_t __got; \
		\
		__got = recv_timeout((fd), (buf), (len), (flags), \
				     IO_TIMEOUT_SEC); \
		if (__got == -1) { \
			FAIL_ERRNO("recv"); \
		} \
		__got; \
	})
114 
/*
 * socket() wrapper: a result of -1 is reported through FAIL_ERRNO("socket").
 * Evaluates to socket()'s return value (the new descriptor or -1).
 *
 * Arguments are parenthesized like in every other wrapper in this file.
 */
#define xsocket(family, sotype, flags) \
	({ \
		int __ret = socket((family), (sotype), (flags)); \
		if (__ret == -1) \
			FAIL_ERRNO("socket"); \
		__ret; \
	})
122 
/*
 * bpf_map_delete_elem() wrapper: a negative result is reported through
 * FAIL_ERRNO("map_delete"). Evaluates to the call's return value.
 */
#define xbpf_map_delete_elem(fd, key) \
	({ \
		int __err; \
		\
		__err = bpf_map_delete_elem((fd), (key)); \
		if (__err < 0) { \
			FAIL_ERRNO("map_delete"); \
		} \
		__err; \
	})
130 
/*
 * bpf_map_lookup_elem() wrapper: a negative result is reported through
 * FAIL_ERRNO("map_lookup"). Evaluates to the call's return value.
 */
#define xbpf_map_lookup_elem(fd, key, val) \
	({ \
		int __err; \
		\
		__err = bpf_map_lookup_elem((fd), (key), (val)); \
		if (__err < 0) { \
			FAIL_ERRNO("map_lookup"); \
		} \
		__err; \
	})
138 
/*
 * bpf_map_update_elem() wrapper: a negative result is reported through
 * FAIL_ERRNO("map_update"). Evaluates to the call's return value.
 */
#define xbpf_map_update_elem(fd, key, val, flags) \
	({ \
		int __err; \
		\
		__err = bpf_map_update_elem((fd), (key), (val), (flags)); \
		if (__err < 0) { \
			FAIL_ERRNO("map_update"); \
		} \
		__err; \
	})
146 
/*
 * bpf_prog_attach() wrapper: a negative result is reported through
 * FAIL_ERRNO(), with the attach type as spelled by the caller embedded in
 * the message. Evaluates to the call's return value.
 */
#define xbpf_prog_attach(prog, target, type, flags) \
	({ \
		int __err; \
		\
		__err = bpf_prog_attach((prog), (target), (type), (flags)); \
		if (__err < 0) { \
			FAIL_ERRNO("prog_attach(" #type ")"); \
		} \
		__err; \
	})
155 
/*
 * bpf_prog_detach2() wrapper: a negative result is reported through
 * FAIL_ERRNO(), with the attach type as spelled by the caller embedded in
 * the message. Evaluates to the call's return value.
 */
#define xbpf_prog_detach2(prog, target, type) \
	({ \
		int __err; \
		\
		__err = bpf_prog_detach2((prog), (target), (type)); \
		if (__err < 0) { \
			FAIL_ERRNO("prog_detach2(" #type ")"); \
		} \
		__err; \
	})
163 
/*
 * pthread_create() wrapper. pthread_create() returns its error number
 * instead of setting errno, so on failure the code is copied into errno
 * and reported through FAIL_ERRNO("pthread_create").
 *
 * errno is left untouched on success (it is no longer reset to 0).
 * Evaluates to pthread_create()'s return value.
 */
#define xpthread_create(thread, attr, func, arg) \
	({ \
		int __ret = pthread_create((thread), (attr), (func), (arg)); \
		if (__ret) { \
			errno = __ret; \
			FAIL_ERRNO("pthread_create"); \
		} \
		__ret; \
	})
172 
/*
 * pthread_join() wrapper. pthread_join() returns its error number instead
 * of setting errno, so on failure the code is copied into errno and
 * reported through FAIL_ERRNO("pthread_join").
 *
 * errno is left untouched on success (it is no longer reset to 0).
 * Evaluates to pthread_join()'s return value.
 */
#define xpthread_join(thread, retval) \
	({ \
		int __ret = pthread_join((thread), (retval)); \
		if (__ret) { \
			errno = __ret; \
			FAIL_ERRNO("pthread_join"); \
		} \
		__ret; \
	})
181 
/*
 * Wait up to @timeout_sec seconds for @fd to become writable, then read the
 * socket's pending error with SO_ERROR.
 *
 * Returns 0 when select() reported @fd writable and SO_ERROR is 0.
 * Returns -1 with errno set otherwise:
 *   EBADF  - @fd is outside [0, FD_SETSIZE) and cannot be put in an fd_set,
 *   ETIME  - select() timed out,
 *   the SO_ERROR value when it is non-zero,
 *   or whatever select()/getsockopt() left in errno.
 */
static inline int poll_connect(int fd, unsigned int timeout_sec)
{
	struct timeval timeout = { .tv_sec = timeout_sec };
	fd_set wfds;
	int r, eval;
	socklen_t esize = sizeof(eval);

	/* FD_SET() is undefined for descriptors outside [0, FD_SETSIZE). */
	if (fd < 0 || fd >= FD_SETSIZE) {
		errno = EBADF;
		return -1;
	}

	FD_ZERO(&wfds);
	FD_SET(fd, &wfds);

	r = select(fd + 1, NULL, &wfds, NULL, &timeout);
	if (r == 0)
		errno = ETIME;
	if (r != 1)
		return -1;

	if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &eval, &esize) < 0)
		return -1;
	if (eval != 0) {
		errno = eval;
		return -1;
	}

	return 0;
}
207 
/*
 * Wait up to @timeout_sec seconds for @fd to become readable.
 *
 * Returns 0 when select() reported @fd readable.
 * Returns -1 with errno set otherwise:
 *   EBADF - @fd is outside [0, FD_SETSIZE) and cannot be put in an fd_set,
 *   ETIME - select() timed out,
 *   or whatever select() left in errno.
 */
static inline int poll_read(int fd, unsigned int timeout_sec)
{
	struct timeval timeout = { .tv_sec = timeout_sec };
	fd_set rfds;
	int r;

	/* FD_SET() is undefined for descriptors outside [0, FD_SETSIZE). */
	if (fd < 0 || fd >= FD_SETSIZE) {
		errno = EBADF;
		return -1;
	}

	FD_ZERO(&rfds);
	FD_SET(fd, &rfds);

	r = select(fd + 1, &rfds, NULL, NULL, &timeout);
	if (r == 0)
		errno = ETIME;

	return r == 1 ? 0 : -1;
}
223 
/*
 * accept() on @fd after waiting up to @timeout_sec seconds for it to become
 * readable.
 *
 * Returns -1 when poll_read() fails (errno as set by poll_read()),
 * otherwise whatever accept() returns.
 */
static inline int accept_timeout(int fd, struct sockaddr *addr, socklen_t *len,
				 unsigned int timeout_sec)
{
	int ready = poll_read(fd, timeout_sec);

	return ready != 0 ? -1 : accept(fd, addr, len);
}
232 
/*
 * recv() on @fd after waiting up to @timeout_sec seconds for it to become
 * readable.
 *
 * Returns -1 when poll_read() fails (errno as set by poll_read()),
 * otherwise recv()'s result converted to int.
 */
static inline int recv_timeout(int fd, void *buf, size_t len, int flags,
			       unsigned int timeout_sec)
{
	int ready = poll_read(fd, timeout_sec);

	return ready != 0 ? -1 : recv(fd, buf, len, flags);
}
241 
/*
 * Fill @ss with the IPv4 loopback address and port 0, and store
 * sizeof(struct sockaddr_in) in @len. All of @ss is zeroed first.
 */
static inline void init_addr_loopback4(struct sockaddr_storage *ss,
				       socklen_t *len)
{
	struct sockaddr_in *sin = (struct sockaddr_in *)ss;

	memset(ss, 0, sizeof(*ss));

	*len = sizeof(*sin);
	sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
	sin->sin_port = 0;
	sin->sin_family = AF_INET;
}
252 
/*
 * Fill @ss with the IPv6 loopback address (in6addr_loopback) and port 0,
 * and store sizeof(struct sockaddr_in6) in @len. All of @ss is zeroed first.
 */
static inline void init_addr_loopback6(struct sockaddr_storage *ss,
				       socklen_t *len)
{
	struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)ss;

	memset(ss, 0, sizeof(*ss));

	*len = sizeof(*sin6);
	sin6->sin6_addr = in6addr_loopback;
	sin6->sin6_port = 0;
	sin6->sin6_family = AF_INET6;
}
263 
/*
 * Fill @ss with an AF_VSOCK address using CID VMADDR_CID_LOCAL and port
 * VMADDR_PORT_ANY, and store sizeof(struct sockaddr_vm) in @len. All of
 * @ss is zeroed first.
 */
static inline void init_addr_loopback_vsock(struct sockaddr_storage *ss,
					    socklen_t *len)
{
	struct sockaddr_vm *svm = (struct sockaddr_vm *)ss;

	memset(ss, 0, sizeof(*ss));

	*len = sizeof(*svm);
	svm->svm_cid = VMADDR_CID_LOCAL;
	svm->svm_port = VMADDR_PORT_ANY;
	svm->svm_family = AF_VSOCK;
}
274 
/*
 * Initialize @ss and @len with the loopback address for @family by
 * dispatching to the AF_INET, AF_INET6 or AF_VSOCK helper.
 *
 * Any other family fails the test through FAIL(). In that case @ss is
 * zeroed and @len is set to 0, so a caller that carries on never reads
 * indeterminate values.
 */
static inline void init_addr_loopback(int family, struct sockaddr_storage *ss,
				      socklen_t *len)
{
	switch (family) {
	case AF_INET:
		init_addr_loopback4(ss, len);
		return;
	case AF_INET6:
		init_addr_loopback6(ss, len);
		return;
	case AF_VSOCK:
		init_addr_loopback_vsock(ss, len);
		return;
	default:
		/* Put the outputs in a defined state before reporting. */
		memset(ss, 0, sizeof(*ss));
		*len = 0;
		FAIL("unsupported address family %d", family);
	}
}
292 
/* View a sockaddr_storage as a generic struct sockaddr pointer (plain cast). */
static inline struct sockaddr *sockaddr(struct sockaddr_storage *ss)
{
	struct sockaddr *generic = (struct sockaddr *)ss;

	return generic;
}
297 
add_to_sockmap(int sock_mapfd,int fd1,int fd2)298 static inline int add_to_sockmap(int sock_mapfd, int fd1, int fd2)
299 {
300 	u64 value;
301 	u32 key;
302 	int err;
303 
304 	key = 0;
305 	value = fd1;
306 	err = xbpf_map_update_elem(sock_mapfd, &key, &value, BPF_NOEXIST);
307 	if (err)
308 		return err;
309 
310 	key = 1;
311 	value = fd2;
312 	return xbpf_map_update_elem(sock_mapfd, &key, &value, BPF_NOEXIST);
313 }
314 
/*
 * Create a connected socket pair through the listening socket @s: a new
 * @family/@sotype socket is connected to @s's bound address and stored in
 * @c, and the peer accepted from @s is stored in @p.
 *
 * Returns 0 on success. On failure returns non-zero: the xgetsockname()
 * result when the address lookup fails, otherwise the errno value seen
 * after the failing socket/connect/accept step.
 *
 * On failure no descriptor is left open by this function and *c is -1 once
 * the socket step has been reached, so callers cannot close a stale
 * descriptor by accident. *p is only written by the accept step.
 */
static inline int create_pair(int s, int family, int sotype, int *c, int *p)
{
	struct sockaddr_storage addr;
	socklen_t len;
	int err = 0;

	len = sizeof(addr);
	err = xgetsockname(s, sockaddr(&addr), &len);
	if (err)
		return err;

	*c = xsocket(family, sotype, 0);
	if (*c < 0)
		return errno;
	err = xconnect(*c, sockaddr(&addr), len);
	if (err) {
		err = errno;
		goto close_cli0;
	}

	*p = xaccept_nonblock(s, NULL, NULL);
	if (*p < 0) {
		err = errno;
		goto close_cli0;
	}
	return err;
close_cli0:
	close(*c);
	/* Do not hand a closed descriptor back to the caller. */
	*c = -1;
	return err;
}
345 
/*
 * Create two connected pairs through the listening socket @s using
 * create_pair(): (@c0, @p0) first, then (@c1, @p1).
 *
 * Returns 0 on success, otherwise the non-zero create_pair() result.
 * If the second pair cannot be created, the first pair is closed and *c0
 * and *p0 are reset to -1 so the caller is not left holding closed
 * descriptors.
 */
static inline int create_socket_pairs(int s, int family, int sotype,
				      int *c0, int *c1, int *p0, int *p1)
{
	int err;

	err = create_pair(s, family, sotype, c0, p0);
	if (err)
		return err;

	err = create_pair(s, family, sotype, c1, p1);
	if (err) {
		close(*c0);
		close(*p0);
		*c0 = -1;
		*p0 = -1;
	}
	return err;
}
362 
enable_reuseport(int s,int progfd)363 static inline int enable_reuseport(int s, int progfd)
364 {
365 	int err, one = 1;
366 
367 	err = xsetsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
368 	if (err)
369 		return -1;
370 	err = xsetsockopt(s, SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF, &progfd,
371 			  sizeof(progfd));
372 	if (err)
373 		return -1;
374 
375 	return 0;
376 }
377 
/*
 * Create a @family/@sotype socket bound to the loopback address prepared by
 * init_addr_loopback(). When @progfd is non-negative, enable_reuseport() is
 * called on the socket before binding. Unless the SOCK_DGRAM bit is set in
 * @sotype, the socket is also put into listening state with SOMAXCONN.
 *
 * Returns the socket descriptor, or -1 when socket(), bind() or listen()
 * fails (the socket is closed in the latter two cases).
 *
 * addr/len start out zeroed so that bind() receives defined values even if
 * init_addr_loopback() does not fill them in for @family.
 */
static inline int socket_loopback_reuseport(int family, int sotype, int progfd)
{
	struct sockaddr_storage addr = { 0 };
	socklen_t len = 0;
	int err, s;

	init_addr_loopback(family, &addr, &len);

	s = xsocket(family, sotype, 0);
	if (s == -1)
		return -1;

	/*
	 * The result is not checked here: a failure has already been
	 * reported by xsetsockopt() inside enable_reuseport().
	 */
	if (progfd >= 0)
		enable_reuseport(s, progfd);

	err = xbind(s, sockaddr(&addr), len);
	if (err)
		goto close;

	/*
	 * NOTE(review): this is a bitwise test, so any @sotype value that
	 * has the SOCK_DGRAM bit set skips listen() - confirm intended.
	 */
	if (sotype & SOCK_DGRAM)
		return s;

	err = xlisten(s, SOMAXCONN);
	if (err)
		goto close;

	return s;
close:
	xclose(s);
	return -1;
}
409 
/*
 * socket_loopback_reuseport() with a program fd of -1, i.e. without the
 * reuseport setup step.
 */
static inline int socket_loopback(int family, int sotype)
{
	const int no_progfd = -1;

	return socket_loopback_reuseport(family, sotype, no_progfd);
}
414 
415 
416 #endif // __SOCKMAP_HELPERS__
417