1 // SPDX-License-Identifier: GPL-2.0
2 
3 #define _GNU_SOURCE
4 
5 #include <arpa/inet.h>
6 #include <errno.h>
7 #include <error.h>
8 #include <netinet/if_ether.h>
9 #include <netinet/in.h>
10 #include <netinet/ip.h>
11 #include <netinet/ip6.h>
12 #include <netinet/udp.h>
13 #include <poll.h>
14 #include <sched.h>
15 #include <signal.h>
16 #include <stdbool.h>
17 #include <stdio.h>
18 #include <stdlib.h>
19 #include <string.h>
20 #include <sys/socket.h>
21 #include <sys/time.h>
22 #include <sys/types.h>
23 #include <unistd.h>
24 
/*
 * Fallback definitions for constants that the system headers may not
 * provide. Each is only defined when the headers did not already do so.
 */

/* Upper bound on packet size; sizes each row of buf[] and caps the payload. */
#ifndef ETH_MAX_MTU
#define ETH_MAX_MTU 0xFFFFU
#endif

/* SOL_UDP control message type used to request segment offload. */
#ifndef UDP_SEGMENT
#define UDP_SEGMENT		103
#endif

/* SOL_SOCKET option enabled on the socket when -z is given. */
#ifndef SO_ZEROCOPY
#define SO_ZEROCOPY	60
#endif

/* Send flag passed to the send calls when -z is given. */
#ifndef MSG_ZEROCOPY
#define MSG_ZEROCOPY	0x4000000
#endif

/* Number of payload buffers (rows of buf[]) that main() can rotate through. */
#define NUM_PKT		100
42 
/* Run-time configuration; set from the command line by parse_opts(). */
static bool	cfg_cache_trash;	/* -c: move to the next buf[] row after each message */
static int	cfg_cpu		= -1;	/* -C: CPU index handed to set_cpu(); -1 when not given */
static int	cfg_connected	= true;	/* cleared by -u: pass the destination per send instead of connect() */
static int	cfg_family	= PF_UNSPEC;	/* -4 / -6: PF_INET or PF_INET6 */
static uint16_t	cfg_mss;		/* ETH_DATA_LEN minus IP and UDP header sizes */
static int	cfg_payload_len	= (1472 * 42);	/* -s: bytes sent per message */
static int	cfg_port	= 8000;	/* -p: destination port */
static int	cfg_runtime_ms	= -1;	/* -l: runtime (given in seconds); -1 runs until SIGINT */
static bool	cfg_segment;		/* -S: send with a UDP_SEGMENT control message */
static bool	cfg_sendmmsg;		/* -m: send with sendmmsg() */
static bool	cfg_tcp;		/* -t: use a SOCK_STREAM socket */
static bool	cfg_zerocopy;		/* -z: enable SO_ZEROCOPY and send with MSG_ZEROCOPY */

/* Destination address (-D) and the sockaddr length for the chosen family. */
static socklen_t cfg_alen;
static struct sockaddr_storage cfg_dst_addr;
58 
/*
 * Stop flag: set from the SIGINT handler and polled by the send loop in
 * main(). volatile sig_atomic_t is the object type that may be written from
 * a signal handler and reliably re-read by the interrupted code.
 */
static volatile sig_atomic_t interrupted;
/* NUM_PKT payload buffers of ETH_MAX_MTU bytes each; main() fills every row with a repeating a-z pattern. */
static char buf[NUM_PKT][ETH_MAX_MTU];
61 
62 static void sigint_handler(int signum)
63 {
64 	if (signum == SIGINT)
65 		interrupted = true;
66 }
67 
/* Return the current gettimeofday() time expressed in milliseconds. */
static unsigned long gettimeofday_ms(void)
{
	struct timeval now;
	unsigned long ms;

	gettimeofday(&now, NULL);

	/* whole seconds scaled up, plus the sub-second part scaled down */
	ms = (now.tv_sec * 1000) + (now.tv_usec / 1000);

	return ms;
}
75 
/*
 * Restrict the caller's CPU affinity to the single CPU @cpu.
 *
 * Exits with status 1 if sched_setaffinity() fails; the message now
 * includes the errno text so the reason for the failure is visible.
 *
 * Returns 0 on success.
 */
static int set_cpu(int cpu)
{
	cpu_set_t mask;

	CPU_ZERO(&mask);
	CPU_SET(cpu, &mask);
	if (sched_setaffinity(0, sizeof(mask), &mask))
		error(1, errno, "setaffinity %d", cpu);

	return 0;
}
87 
88 static void setup_sockaddr(int domain, const char *str_addr, void *sockaddr)
89 {
90 	struct sockaddr_in6 *addr6 = (void *) sockaddr;
91 	struct sockaddr_in *addr4 = (void *) sockaddr;
92 
93 	switch (domain) {
94 	case PF_INET:
95 		addr4->sin_family = AF_INET;
96 		addr4->sin_port = htons(cfg_port);
97 		if (inet_pton(AF_INET, str_addr, &(addr4->sin_addr)) != 1)
98 			error(1, 0, "ipv4 parse error: %s", str_addr);
99 		break;
100 	case PF_INET6:
101 		addr6->sin6_family = AF_INET6;
102 		addr6->sin6_port = htons(cfg_port);
103 		if (inet_pton(AF_INET6, str_addr, &(addr6->sin6_addr)) != 1)
104 			error(1, 0, "ipv6 parse error: %s", str_addr);
105 		break;
106 	default:
107 		error(1, 0, "illegal domain");
108 	}
109 }
110 
/*
 * Drain the error queue of @fd.
 *
 * Reads with MSG_ERRQUEUE until recvmsg() reports EAGAIN (queue empty).
 * The msghdr carries no data or control buffers, so nothing is copied out;
 * every entry read is expected to come back flagged exactly
 * MSG_ERRQUEUE | MSG_CTRUNC. Any other flags, or any other recvmsg()
 * failure, terminates the program with status 1.
 */
static void flush_zerocopy(int fd)
{
	struct msghdr msg = {0};	/* flush */
	int ret;

	for (;;) {
		ret = recvmsg(fd, &msg, MSG_ERRQUEUE);
		if (ret == -1) {
			if (errno == EAGAIN)
				return;
			error(1, errno, "errqueue");
		}

		if (msg.msg_flags != (MSG_ERRQUEUE | MSG_CTRUNC))
			error(1, 0, "errqueue: flags 0x%x\n", msg.msg_flags);

		/* reset before the next read */
		msg.msg_flags = 0;
	}
}
127 
128 static int send_tcp(int fd, char *data)
129 {
130 	int ret, done = 0, count = 0;
131 
132 	while (done < cfg_payload_len) {
133 		ret = send(fd, data + done, cfg_payload_len - done,
134 			   cfg_zerocopy ? MSG_ZEROCOPY : 0);
135 		if (ret == -1)
136 			error(1, errno, "write");
137 
138 		done += ret;
139 		count++;
140 	}
141 
142 	return count;
143 }
144 
145 static int send_udp(int fd, char *data)
146 {
147 	int ret, total_len, len, count = 0;
148 
149 	total_len = cfg_payload_len;
150 
151 	while (total_len) {
152 		len = total_len < cfg_mss ? total_len : cfg_mss;
153 
154 		ret = sendto(fd, data, len, cfg_zerocopy ? MSG_ZEROCOPY : 0,
155 			     cfg_connected ? NULL : (void *)&cfg_dst_addr,
156 			     cfg_connected ? 0 : cfg_alen);
157 		if (ret == -1)
158 			error(1, errno, "write");
159 		if (ret != len)
160 			error(1, errno, "write: %uB != %uB\n", ret, len);
161 
162 		total_len -= len;
163 		count++;
164 	}
165 
166 	return count;
167 }
168 
169 static int send_udp_sendmmsg(int fd, char *data)
170 {
171 	const int max_nr_msg = ETH_MAX_MTU / ETH_DATA_LEN;
172 	struct mmsghdr mmsgs[max_nr_msg];
173 	struct iovec iov[max_nr_msg];
174 	unsigned int off = 0, left;
175 	int i = 0, ret;
176 
177 	memset(mmsgs, 0, sizeof(mmsgs));
178 
179 	left = cfg_payload_len;
180 	while (left) {
181 		if (i == max_nr_msg)
182 			error(1, 0, "sendmmsg: exceeds max_nr_msg");
183 
184 		iov[i].iov_base = data + off;
185 		iov[i].iov_len = cfg_mss < left ? cfg_mss : left;
186 
187 		mmsgs[i].msg_hdr.msg_iov = iov + i;
188 		mmsgs[i].msg_hdr.msg_iovlen = 1;
189 
190 		off += iov[i].iov_len;
191 		left -= iov[i].iov_len;
192 		i++;
193 	}
194 
195 	ret = sendmmsg(fd, mmsgs, i, cfg_zerocopy ? MSG_ZEROCOPY : 0);
196 	if (ret == -1)
197 		error(1, errno, "sendmmsg");
198 
199 	return ret;
200 }
201 
202 static void send_udp_segment_cmsg(struct cmsghdr *cm)
203 {
204 	uint16_t *valp;
205 
206 	cm->cmsg_level = SOL_UDP;
207 	cm->cmsg_type = UDP_SEGMENT;
208 	cm->cmsg_len = CMSG_LEN(sizeof(cfg_mss));
209 	valp = (void *)CMSG_DATA(cm);
210 	*valp = cfg_mss;
211 }
212 
213 static int send_udp_segment(int fd, char *data)
214 {
215 	char control[CMSG_SPACE(sizeof(cfg_mss))] = {0};
216 	struct msghdr msg = {0};
217 	struct iovec iov = {0};
218 	int ret;
219 
220 	iov.iov_base = data;
221 	iov.iov_len = cfg_payload_len;
222 
223 	msg.msg_iov = &iov;
224 	msg.msg_iovlen = 1;
225 
226 	msg.msg_control = control;
227 	msg.msg_controllen = sizeof(control);
228 	send_udp_segment_cmsg(CMSG_FIRSTHDR(&msg));
229 
230 	msg.msg_name = (void *)&cfg_dst_addr;
231 	msg.msg_namelen = cfg_alen;
232 
233 	ret = sendmsg(fd, &msg, cfg_zerocopy ? MSG_ZEROCOPY : 0);
234 	if (ret == -1)
235 		error(1, errno, "sendmsg");
236 	if (ret != iov.iov_len)
237 		error(1, 0, "sendmsg: %u != %lu\n", ret, iov.iov_len);
238 
239 	return 1;
240 }
241 
/* Print the option summary, naming the program as @filepath, and exit with status 1. */
static void usage(const char *filepath)
{
	error(1, 0,
	      "Usage: %s [-46cmStuz] [-C cpu] [-D dst ip] [-l secs] "
	      "[-p port] [-s sendsize]",
	      filepath);
}
247 
248 static void parse_opts(int argc, char **argv)
249 {
250 	int max_len, hdrlen;
251 	int c;
252 
253 	while ((c = getopt(argc, argv, "46cC:D:l:mp:s:Stuz")) != -1) {
254 		switch (c) {
255 		case '4':
256 			if (cfg_family != PF_UNSPEC)
257 				error(1, 0, "Pass one of -4 or -6");
258 			cfg_family = PF_INET;
259 			cfg_alen = sizeof(struct sockaddr_in);
260 			break;
261 		case '6':
262 			if (cfg_family != PF_UNSPEC)
263 				error(1, 0, "Pass one of -4 or -6");
264 			cfg_family = PF_INET6;
265 			cfg_alen = sizeof(struct sockaddr_in6);
266 			break;
267 		case 'c':
268 			cfg_cache_trash = true;
269 			break;
270 		case 'C':
271 			cfg_cpu = strtol(optarg, NULL, 0);
272 			break;
273 		case 'D':
274 			setup_sockaddr(cfg_family, optarg, &cfg_dst_addr);
275 			break;
276 		case 'l':
277 			cfg_runtime_ms = strtoul(optarg, NULL, 10) * 1000;
278 			break;
279 		case 'm':
280 			cfg_sendmmsg = true;
281 			break;
282 		case 'p':
283 			cfg_port = strtoul(optarg, NULL, 0);
284 			break;
285 		case 's':
286 			cfg_payload_len = strtoul(optarg, NULL, 0);
287 			break;
288 		case 'S':
289 			cfg_segment = true;
290 			break;
291 		case 't':
292 			cfg_tcp = true;
293 			break;
294 		case 'u':
295 			cfg_connected = false;
296 			break;
297 		case 'z':
298 			cfg_zerocopy = true;
299 			break;
300 		}
301 	}
302 
303 	if (optind != argc)
304 		usage(argv[0]);
305 
306 	if (cfg_family == PF_UNSPEC)
307 		error(1, 0, "must pass one of -4 or -6");
308 	if (cfg_tcp && !cfg_connected)
309 		error(1, 0, "connectionless tcp makes no sense");
310 	if (cfg_segment && cfg_sendmmsg)
311 		error(1, 0, "cannot combine segment offload and sendmmsg");
312 
313 	if (cfg_family == PF_INET)
314 		hdrlen = sizeof(struct iphdr) + sizeof(struct udphdr);
315 	else
316 		hdrlen = sizeof(struct ip6_hdr) + sizeof(struct udphdr);
317 
318 	cfg_mss = ETH_DATA_LEN - hdrlen;
319 	max_len = ETH_MAX_MTU - hdrlen;
320 
321 	if (cfg_payload_len > max_len)
322 		error(1, 0, "payload length %u exceeds max %u",
323 		      cfg_payload_len, max_len);
324 }
325 
/*
 * Set the path-MTU discovery mode of @fd to the "DO" setting of its
 * address family: IP_MTU_DISCOVER / IP_PMTUDISC_DO when @is_ipv4 is true,
 * IPV6_MTU_DISCOVER / IPV6_PMTUDISC_DO otherwise.
 *
 * Exits with status 1 if setsockopt() fails.
 */
static void set_pmtu_discover(int fd, bool is_ipv4)
{
	const int level = is_ipv4 ? SOL_IP : SOL_IPV6;
	const int name = is_ipv4 ? IP_MTU_DISCOVER : IPV6_MTU_DISCOVER;
	int val = is_ipv4 ? IP_PMTUDISC_DO : IPV6_PMTUDISC_DO;

	if (setsockopt(fd, level, name, &val, sizeof(val)))
		error(1, errno, "setsockopt path mtu");
}
343 
344 int main(int argc, char **argv)
345 {
346 	unsigned long num_msgs, num_sends;
347 	unsigned long tnow, treport, tstop;
348 	int fd, i, val;
349 
350 	parse_opts(argc, argv);
351 
352 	if (cfg_cpu > 0)
353 		set_cpu(cfg_cpu);
354 
355 	for (i = 0; i < sizeof(buf[0]); i++)
356 		buf[0][i] = 'a' + (i % 26);
357 	for (i = 1; i < NUM_PKT; i++)
358 		memcpy(buf[i], buf[0], sizeof(buf[0]));
359 
360 	signal(SIGINT, sigint_handler);
361 
362 	fd = socket(cfg_family, cfg_tcp ? SOCK_STREAM : SOCK_DGRAM, 0);
363 	if (fd == -1)
364 		error(1, errno, "socket");
365 
366 	if (cfg_zerocopy) {
367 		val = 1;
368 		if (setsockopt(fd, SOL_SOCKET, SO_ZEROCOPY, &val, sizeof(val)))
369 			error(1, errno, "setsockopt zerocopy");
370 	}
371 
372 	if (cfg_connected &&
373 	    connect(fd, (void *)&cfg_dst_addr, cfg_alen))
374 		error(1, errno, "connect");
375 
376 	if (cfg_segment)
377 		set_pmtu_discover(fd, cfg_family == PF_INET);
378 
379 	num_msgs = num_sends = 0;
380 	tnow = gettimeofday_ms();
381 	tstop = tnow + cfg_runtime_ms;
382 	treport = tnow + 1000;
383 
384 	i = 0;
385 	do {
386 		if (cfg_tcp)
387 			num_sends += send_tcp(fd, buf[i]);
388 		else if (cfg_segment)
389 			num_sends += send_udp_segment(fd, buf[i]);
390 		else if (cfg_sendmmsg)
391 			num_sends += send_udp_sendmmsg(fd, buf[i]);
392 		else
393 			num_sends += send_udp(fd, buf[i]);
394 		num_msgs++;
395 
396 		if (cfg_zerocopy && ((num_msgs & 0xF) == 0))
397 			flush_zerocopy(fd);
398 
399 		tnow = gettimeofday_ms();
400 		if (tnow > treport) {
401 			fprintf(stderr,
402 				"%s tx: %6lu MB/s %8lu calls/s %6lu msg/s\n",
403 				cfg_tcp ? "tcp" : "udp",
404 				(num_msgs * cfg_payload_len) >> 20,
405 				num_sends, num_msgs);
406 			num_msgs = num_sends = 0;
407 			treport = tnow + 1000;
408 		}
409 
410 		/* cold cache when writing buffer */
411 		if (cfg_cache_trash)
412 			i = ++i < NUM_PKT ? i : 0;
413 
414 	} while (!interrupted && (cfg_runtime_ms == -1 || tnow < tstop));
415 
416 	if (close(fd))
417 		error(1, errno, "close");
418 
419 	return 0;
420 }
421