1 #ifndef __SOCKMAP_HELPERS__
2 #define __SOCKMAP_HELPERS__
3
4 #include <linux/vm_sockets.h>
5
/* Upper bound, in seconds, for the blocking waits done by the *_timeout helpers. */
#define IO_TIMEOUT_SEC 30
/* Length of the maximum test-name buffer used by callers. */
#define MAX_TEST_NAME 80
/* Size of the scratch buffer FAIL_LIBBPF() hands to libbpf_strerror(). */
#define MAX_STRERR_LEN 256
9
/*
 * Older <linux/vm_sockets.h> releases lack VMADDR_CID_LOCAL; supply the
 * value ourselves so the vsock loopback helpers still build.
 */
#if !defined(VMADDR_CID_LOCAL)
#define VMADDR_CID_LOCAL 1
#endif
14
/*
 * Mark a variable/parameter as intentionally unused. Guarded so that a
 * definition already provided by another included header is not redefined
 * (a differing redefinition would trigger a compiler diagnostic).
 */
#ifndef __always_unused
#define __always_unused __attribute__((__unused__))
#endif
16
/*
 * Failure reporting helpers.
 *
 * _FAIL() prints the formatted message with error_at_line(), passing the
 * calling function's name where a file name would normally go, and then
 * marks the current test as failed via CHECK_FAIL(). A non-zero errnum is
 * forwarded to error_at_line() so the matching error text is printed too.
 */
#define _FAIL(errnum, ...) \
	({ \
		error_at_line(0, (errnum), __func__, __LINE__, __VA_ARGS__); \
		CHECK_FAIL(true); \
	})
/* Report a failure without any errno text. */
#define FAIL(...) _FAIL(0, __VA_ARGS__)
/* Report a failure annotated with the current value of errno. */
#define FAIL_ERRNO(...) _FAIL(errno, __VA_ARGS__)
/* Report a failure annotated with libbpf's description of error code err. */
#define FAIL_LIBBPF(err, msg) \
	({ \
		char __buf[MAX_STRERR_LEN]; \
		libbpf_strerror((err), __buf, sizeof(__buf)); \
		FAIL("%s: %s", (msg), __buf); \
	})
30
31 /* Wrappers that fail the test on error and report it. */
32
/*
 * accept() on fd after waiting at most IO_TIMEOUT_SEC for a pending
 * connection. Evaluates to the new fd, or -1 after reporting a failure.
 */
#define xaccept_nonblock(fd, addr, len) \
	({ \
		int __ret; \
		__ret = accept_timeout((fd), (addr), (len), IO_TIMEOUT_SEC); \
		if (__ret == -1) \
			FAIL_ERRNO("accept"); \
		__ret; \
	})
41
/* bind(); evaluates to its result, reporting a failure when it is -1. */
#define xbind(fd, addr, len) \
	({ \
		int __ret; \
		__ret = bind((fd), (addr), (len)); \
		if (__ret == -1) \
			FAIL_ERRNO("bind"); \
		__ret; \
	})
49
/* close(); evaluates to its result, reporting a failure when it is -1. */
#define xclose(fd) \
	({ \
		int __ret; \
		__ret = close((fd)); \
		if (__ret == -1) \
			FAIL_ERRNO("close"); \
		__ret; \
	})
57
/* connect(); evaluates to its result, reporting a failure when it is -1. */
#define xconnect(fd, addr, len) \
	({ \
		int __ret; \
		__ret = connect((fd), (addr), (len)); \
		if (__ret == -1) \
			FAIL_ERRNO("connect"); \
		__ret; \
	})
65
/* getsockname(); evaluates to its result, reporting a failure when it is -1. */
#define xgetsockname(fd, addr, len) \
	({ \
		int __ret; \
		__ret = getsockname((fd), (addr), (len)); \
		if (__ret == -1) \
			FAIL_ERRNO("getsockname"); \
		__ret; \
	})
73
/*
 * getsockopt(); evaluates to its result. On -1 the failure message names
 * the option exactly as spelled at the call site (via #name).
 */
#define xgetsockopt(fd, level, name, val, len) \
	({ \
		int __ret; \
		__ret = getsockopt((fd), (level), (name), (val), (len)); \
		if (__ret == -1) \
			FAIL_ERRNO("getsockopt(" #name ")"); \
		__ret; \
	})
81
/* listen(); evaluates to its result, reporting a failure when it is -1. */
#define xlisten(fd, backlog) \
	({ \
		int __ret; \
		__ret = listen((fd), (backlog)); \
		if (__ret == -1) \
			FAIL_ERRNO("listen"); \
		__ret; \
	})
89
/*
 * setsockopt(); evaluates to its result. On -1 the failure message names
 * the option exactly as spelled at the call site (via #name).
 */
#define xsetsockopt(fd, level, name, val, len) \
	({ \
		int __ret; \
		__ret = setsockopt((fd), (level), (name), (val), (len)); \
		if (__ret == -1) \
			FAIL_ERRNO("setsockopt(" #name ")"); \
		__ret; \
	})
97
/* send(); evaluates to the byte count, reporting a failure when it is -1. */
#define xsend(fd, buf, len, flags) \
	({ \
		ssize_t __ret; \
		__ret = send((fd), (buf), (len), (flags)); \
		if (__ret == -1) \
			FAIL_ERRNO("send"); \
		__ret; \
	})
105
/*
 * recv() after waiting at most IO_TIMEOUT_SEC for data. Evaluates to the
 * byte count, or -1 after reporting a failure.
 */
#define xrecv_nonblock(fd, buf, len, flags) \
	({ \
		ssize_t __ret; \
		__ret = recv_timeout((fd), (buf), (len), (flags), \
				     IO_TIMEOUT_SEC); \
		if (__ret == -1) \
			FAIL_ERRNO("recv"); \
		__ret; \
	})
114
/* socket(); evaluates to the new fd, or -1 after reporting a failure. */
#define xsocket(family, sotype, flags) \
	({ \
		int __ret; \
		__ret = socket((family), (sotype), (flags)); \
		if (__ret == -1) \
			FAIL_ERRNO("socket"); \
		__ret; \
	})
122
/* bpf_map_delete_elem(); any negative result is reported as a failure. */
#define xbpf_map_delete_elem(fd, key) \
	({ \
		int __ret; \
		__ret = bpf_map_delete_elem((fd), (key)); \
		if (__ret < 0) \
			FAIL_ERRNO("map_delete"); \
		__ret; \
	})
130
/* bpf_map_lookup_elem(); any negative result is reported as a failure. */
#define xbpf_map_lookup_elem(fd, key, val) \
	({ \
		int __ret; \
		__ret = bpf_map_lookup_elem((fd), (key), (val)); \
		if (__ret < 0) \
			FAIL_ERRNO("map_lookup"); \
		__ret; \
	})
138
/* bpf_map_update_elem(); any negative result is reported as a failure. */
#define xbpf_map_update_elem(fd, key, val, flags) \
	({ \
		int __ret; \
		__ret = bpf_map_update_elem((fd), (key), (val), (flags)); \
		if (__ret < 0) \
			FAIL_ERRNO("map_update"); \
		__ret; \
	})
146
/*
 * bpf_prog_attach(); any negative result is reported as a failure whose
 * message names the attach type as spelled at the call site.
 */
#define xbpf_prog_attach(prog, target, type, flags) \
	({ \
		int __ret; \
		__ret = bpf_prog_attach((prog), (target), (type), (flags)); \
		if (__ret < 0) \
			FAIL_ERRNO("prog_attach(" #type ")"); \
		__ret; \
	})
155
/*
 * bpf_prog_detach2(); any negative result is reported as a failure whose
 * message names the attach type as spelled at the call site.
 */
#define xbpf_prog_detach2(prog, target, type) \
	({ \
		int __ret; \
		__ret = bpf_prog_detach2((prog), (target), (type)); \
		if (__ret < 0) \
			FAIL_ERRNO("prog_detach2(" #type ")"); \
		__ret; \
	})
163
/*
 * pthread_create() returns an error number instead of setting errno, so the
 * result is copied into errno (unconditionally, i.e. 0 on success) before a
 * non-zero result is reported. Evaluates to pthread_create()'s result.
 */
#define xpthread_create(thread, attr, func, arg) \
	({ \
		int __ret; \
		__ret = pthread_create((thread), (attr), (func), (arg)); \
		errno = __ret; \
		if (__ret != 0) \
			FAIL_ERRNO("pthread_create"); \
		__ret; \
	})
172
/*
 * pthread_join() returns an error number instead of setting errno, so the
 * result is copied into errno (unconditionally, i.e. 0 on success) before a
 * non-zero result is reported. Evaluates to pthread_join()'s result.
 */
#define xpthread_join(thread, retval) \
	({ \
		int __ret; \
		__ret = pthread_join((thread), (retval)); \
		errno = __ret; \
		if (__ret != 0) \
			FAIL_ERRNO("pthread_join"); \
		__ret; \
	})
181
/*
 * Wait up to timeout_sec seconds for fd to become writable, then check the
 * socket's pending error (SO_ERROR) to learn how a connect() on it ended.
 *
 * Returns 0 on success. Returns -1 with errno set otherwise:
 *   EBADF   - fd is negative
 *   EINVAL  - fd does not fit in an fd_set (fd >= FD_SETSIZE)
 *   ETIME   - the timeout expired
 *   other   - from select()/getsockopt(), or the socket's SO_ERROR value
 */
static inline int poll_connect(int fd, unsigned int timeout_sec)
{
	struct timeval timeout = { .tv_sec = timeout_sec };
	socklen_t esize;
	fd_set wfds;
	int r, eval;

	/* FD_SET() on a descriptor outside [0, FD_SETSIZE) indexes past the
	 * fd_set bitmap (undefined behaviour), so reject it up front.
	 */
	if (fd < 0) {
		errno = EBADF;
		return -1;
	}
	if (fd >= FD_SETSIZE) {
		errno = EINVAL;
		return -1;
	}

	FD_ZERO(&wfds);
	FD_SET(fd, &wfds);

	r = select(fd + 1, NULL, &wfds, NULL, &timeout);
	if (r == 0)
		errno = ETIME;
	if (r != 1)
		return -1;

	/* Writability only means the connect attempt finished; SO_ERROR
	 * tells whether it actually succeeded.
	 */
	esize = sizeof(eval);
	if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &eval, &esize) < 0)
		return -1;
	if (eval != 0) {
		errno = eval;
		return -1;
	}

	return 0;
}
207
/*
 * Wait up to timeout_sec seconds for fd to become readable.
 *
 * Returns 0 when fd is readable. Returns -1 with errno set otherwise:
 *   EBADF   - fd is negative
 *   EINVAL  - fd does not fit in an fd_set (fd >= FD_SETSIZE)
 *   ETIME   - the timeout expired
 *   other   - from select()
 */
static inline int poll_read(int fd, unsigned int timeout_sec)
{
	struct timeval timeout = { .tv_sec = timeout_sec };
	fd_set rfds;
	int r;

	/* FD_SET() on a descriptor outside [0, FD_SETSIZE) indexes past the
	 * fd_set bitmap (undefined behaviour), so reject it up front.
	 */
	if (fd < 0) {
		errno = EBADF;
		return -1;
	}
	if (fd >= FD_SETSIZE) {
		errno = EINVAL;
		return -1;
	}

	FD_ZERO(&rfds);
	FD_SET(fd, &rfds);

	r = select(fd + 1, &rfds, NULL, NULL, &timeout);
	if (r == 0)
		errno = ETIME;

	return r == 1 ? 0 : -1;
}
223
/*
 * accept() on fd, but only after poll_read() reports it readable within
 * timeout_sec seconds. Returns -1 (errno from poll_read()) if the wait
 * fails, otherwise whatever accept() returns.
 */
static inline int accept_timeout(int fd, struct sockaddr *addr, socklen_t *len,
				 unsigned int timeout_sec)
{
	return poll_read(fd, timeout_sec) ? -1 : accept(fd, addr, len);
}
232
/*
 * recv() on fd, but only after poll_read() reports it readable within
 * timeout_sec seconds.
 *
 * Returns -1 (errno from poll_read()) if the wait fails; otherwise returns
 * recv()'s result unchanged. The return type is ssize_t, like recv()
 * itself, so large byte counts are not truncated to int.
 */
static inline ssize_t recv_timeout(int fd, void *buf, size_t len, int flags,
				   unsigned int timeout_sec)
{
	if (poll_read(fd, timeout_sec))
		return -1;

	return recv(fd, buf, len, flags);
}
241
/*
 * Fill *ss with the IPv4 loopback address (127.0.0.1) and port 0, zeroing
 * the rest of the storage, and store the sockaddr_in length in *len.
 */
static inline void init_addr_loopback4(struct sockaddr_storage *ss,
				       socklen_t *len)
{
	struct sockaddr_in *sin = (struct sockaddr_in *)ss;

	memset(ss, 0, sizeof(*ss));
	sin->sin_family = AF_INET;
	sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
	sin->sin_port = 0;

	*len = sizeof(*sin);
}
252
/*
 * Fill *ss with the IPv6 loopback address (::1) and port 0, zeroing the
 * rest of the storage, and store the sockaddr_in6 length in *len.
 */
static inline void init_addr_loopback6(struct sockaddr_storage *ss,
				       socklen_t *len)
{
	struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)ss;

	memset(ss, 0, sizeof(*ss));
	sin6->sin6_family = AF_INET6;
	sin6->sin6_addr = in6addr_loopback;
	sin6->sin6_port = 0;

	*len = sizeof(*sin6);
}
263
/*
 * Fill *ss with a vsock address for the local CID (VMADDR_CID_LOCAL) and
 * VMADDR_PORT_ANY, zeroing the rest of the storage, and store the
 * sockaddr_vm length in *len.
 */
static inline void init_addr_loopback_vsock(struct sockaddr_storage *ss,
					    socklen_t *len)
{
	struct sockaddr_vm *svm = (struct sockaddr_vm *)ss;

	memset(ss, 0, sizeof(*ss));
	svm->svm_family = AF_VSOCK;
	svm->svm_cid = VMADDR_CID_LOCAL;
	svm->svm_port = VMADDR_PORT_ANY;

	*len = sizeof(*svm);
}
274
/*
 * Fill *ss/*len with the loopback address of the given family (AF_INET,
 * AF_INET6 or AF_VSOCK) using an "any" port.
 *
 * For any other family a test failure is reported, *ss is zeroed and *len
 * is set to 0, so callers never pass uninitialized data to bind() or
 * connect(), and a zero *len can be used to detect the error.
 */
static inline void init_addr_loopback(int family, struct sockaddr_storage *ss,
				      socklen_t *len)
{
	switch (family) {
	case AF_INET:
		init_addr_loopback4(ss, len);
		return;
	case AF_INET6:
		init_addr_loopback6(ss, len);
		return;
	case AF_VSOCK:
		init_addr_loopback_vsock(ss, len);
		return;
	default:
		memset(ss, 0, sizeof(*ss));
		*len = 0;
		FAIL("unsupported address family %d", family);
	}
}
292
/* View a sockaddr_storage as the generic struct sockaddr socket calls take. */
static inline struct sockaddr *sockaddr(struct sockaddr_storage *ss)
{
	struct sockaddr *sa = (struct sockaddr *)ss;

	return sa;
}
297
add_to_sockmap(int sock_mapfd,int fd1,int fd2)298 static inline int add_to_sockmap(int sock_mapfd, int fd1, int fd2)
299 {
300 u64 value;
301 u32 key;
302 int err;
303
304 key = 0;
305 value = fd1;
306 err = xbpf_map_update_elem(sock_mapfd, &key, &value, BPF_NOEXIST);
307 if (err)
308 return err;
309
310 key = 1;
311 value = fd2;
312 return xbpf_map_update_elem(sock_mapfd, &key, &value, BPF_NOEXIST);
313 }
314
/*
 * Create one connected socket pair through the listening socket s: a new
 * client socket of (family, sotype) connects to the address s is bound to,
 * and the peer end is accepted from s.
 *
 * On success returns 0 with the client fd in *c and the accepted fd in *p.
 * On failure returns the errno value of the failing call (always a positive
 * errno, for every failure path), after closing any client socket it
 * created; the failure has already been reported by the x* wrapper.
 */
static inline int create_pair(int s, int family, int sotype, int *c, int *p)
{
	struct sockaddr_storage addr;
	socklen_t len = sizeof(addr);
	int err;

	err = xgetsockname(s, sockaddr(&addr), &len);
	if (err)
		return errno;

	*c = xsocket(family, sotype, 0);
	if (*c < 0)
		return errno;

	err = xconnect(*c, sockaddr(&addr), len);
	if (err) {
		err = errno;
		goto close_cli0;
	}

	*p = xaccept_nonblock(s, NULL, NULL);
	if (*p < 0) {
		err = errno;
		goto close_cli0;
	}
	return 0;

close_cli0:
	close(*c);
	return err;
}
345
/*
 * Create two connected pairs through listener s: (*c0, *p0) first, then
 * (*c1, *p1). Returns 0 on success, or create_pair()'s error. If the second
 * pair fails, the first pair is closed so nothing is left open.
 */
static inline int create_socket_pairs(int s, int family, int sotype,
				      int *c0, int *c1, int *p0, int *p1)
{
	int err = create_pair(s, family, sotype, c0, p0);

	if (!err) {
		err = create_pair(s, family, sotype, c1, p1);
		if (err) {
			close(*c0);
			close(*p0);
		}
	}

	return err;
}
362
enable_reuseport(int s,int progfd)363 static inline int enable_reuseport(int s, int progfd)
364 {
365 int err, one = 1;
366
367 err = xsetsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
368 if (err)
369 return -1;
370 err = xsetsockopt(s, SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF, &progfd,
371 sizeof(progfd));
372 if (err)
373 return -1;
374
375 return 0;
376 }
377
/*
 * Create a socket of (family, sotype) bound to the family's loopback
 * address with an "any" port. If progfd >= 0, SO_REUSEPORT is enabled and
 * progfd is attached as the reuseport BPF program before binding. Sockets
 * whose sotype has the SOCK_DGRAM bit set are returned right after bind();
 * all others are also put into listen(SOMAXCONN).
 *
 * Returns the socket fd, or -1 on any failure (already reported).
 */
static inline int socket_loopback_reuseport(int family, int sotype, int progfd)
{
	struct sockaddr_storage addr;
	/* init_addr_loopback() does not set a usable length for an
	 * unsupported family; start from 0 so that case is detectable and
	 * bind() never sees an uninitialized length.
	 */
	socklen_t len = 0;
	int err, s;

	init_addr_loopback(family, &addr, &len);
	if (!len)
		return -1; /* failure already reported by init_addr_loopback() */

	s = xsocket(family, sotype, 0);
	if (s == -1)
		return -1;

	/* Don't hand back a socket lacking the requested reuseport program;
	 * the failure itself is reported by xsetsockopt().
	 */
	if (progfd >= 0 && enable_reuseport(s, progfd))
		goto close;

	err = xbind(s, sockaddr(&addr), len);
	if (err)
		goto close;

	if (sotype & SOCK_DGRAM)
		return s;

	err = xlisten(s, SOMAXCONN);
	if (err)
		goto close;

	return s;
close:
	xclose(s);
	return -1;
}
409
/*
 * Bound (and, for non-datagram types, listening) loopback socket without
 * any reuseport program: a negative progfd tells
 * socket_loopback_reuseport() to skip that step.
 */
static inline int socket_loopback(int family, int sotype)
{
	const int no_reuseport_prog = -1;

	return socket_loopback_reuseport(family, sotype, no_reuseport_prog);
}
414
415
416 #endif // __SOCKMAP_HELPERS__
417