1 // SPDX-License-Identifier: GPL-2.0 2 3 #define _GNU_SOURCE 4 5 #include <arpa/inet.h> 6 #include <errno.h> 7 #include <error.h> 8 #include <fcntl.h> 9 #include <limits.h> 10 #include <linux/filter.h> 11 #include <linux/bpf.h> 12 #include <linux/if_packet.h> 13 #include <linux/if_vlan.h> 14 #include <linux/virtio_net.h> 15 #include <net/if.h> 16 #include <net/ethernet.h> 17 #include <netinet/ip.h> 18 #include <netinet/udp.h> 19 #include <poll.h> 20 #include <sched.h> 21 #include <stdbool.h> 22 #include <stdint.h> 23 #include <stdio.h> 24 #include <stdlib.h> 25 #include <string.h> 26 #include <sys/mman.h> 27 #include <sys/socket.h> 28 #include <sys/stat.h> 29 #include <sys/types.h> 30 #include <unistd.h> 31 32 #include "psock_lib.h" 33 34 static bool cfg_use_bind; 35 static bool cfg_use_csum_off; 36 static bool cfg_use_csum_off_bad; 37 static bool cfg_use_dgram; 38 static bool cfg_use_gso; 39 static bool cfg_use_qdisc_bypass; 40 static bool cfg_use_vlan; 41 static bool cfg_use_vnet; 42 43 static char *cfg_ifname = "lo"; 44 static int cfg_mtu = 1500; 45 static int cfg_payload_len = DATA_LEN; 46 static int cfg_truncate_len = INT_MAX; 47 static uint16_t cfg_port = 8000; 48 49 /* test sending up to max mtu + 1 */ 50 #define TEST_SZ (sizeof(struct virtio_net_hdr) + ETH_HLEN + ETH_MAX_MTU + 1) 51 52 static char tbuf[TEST_SZ], rbuf[TEST_SZ]; 53 54 static unsigned long add_csum_hword(const uint16_t *start, int num_u16) 55 { 56 unsigned long sum = 0; 57 int i; 58 59 for (i = 0; i < num_u16; i++) 60 sum += start[i]; 61 62 return sum; 63 } 64 65 static uint16_t build_ip_csum(const uint16_t *start, int num_u16, 66 unsigned long sum) 67 { 68 sum += add_csum_hword(start, num_u16); 69 70 while (sum >> 16) 71 sum = (sum & 0xffff) + (sum >> 16); 72 73 return ~sum; 74 } 75 76 static int build_vnet_header(void *header) 77 { 78 struct virtio_net_hdr *vh = header; 79 80 vh->hdr_len = ETH_HLEN + sizeof(struct iphdr) + sizeof(struct udphdr); 81 82 if (cfg_use_csum_off) { 83 vh->flags |= VIRTIO_NET_HDR_F_NEEDS_CSUM; 84 vh->csum_start = ETH_HLEN + sizeof(struct iphdr); 85 vh->csum_offset = __builtin_offsetof(struct udphdr, check); 86 87 /* position check field exactly one byte beyond end of packet */ 88 if (cfg_use_csum_off_bad) 89 vh->csum_start += sizeof(struct udphdr) + cfg_payload_len - 90 vh->csum_offset - 1; 91 } 92 93 if (cfg_use_gso) { 94 vh->gso_type = VIRTIO_NET_HDR_GSO_UDP; 95 vh->gso_size = cfg_mtu - sizeof(struct iphdr); 96 } 97 98 return sizeof(*vh); 99 } 100 101 static int build_eth_header(void *header) 102 { 103 struct ethhdr *eth = header; 104 105 if (cfg_use_vlan) { 106 uint16_t *tag = header + ETH_HLEN; 107 108 eth->h_proto = htons(ETH_P_8021Q); 109 tag[1] = htons(ETH_P_IP); 110 return ETH_HLEN + 4; 111 } 112 113 eth->h_proto = htons(ETH_P_IP); 114 return ETH_HLEN; 115 } 116 117 static int build_ipv4_header(void *header, int payload_len) 118 { 119 struct iphdr *iph = header; 120 121 iph->ihl = 5; 122 iph->version = 4; 123 iph->ttl = 8; 124 iph->tot_len = htons(sizeof(*iph) + sizeof(struct udphdr) + payload_len); 125 iph->id = htons(1337); 126 iph->protocol = IPPROTO_UDP; 127 iph->saddr = htonl((172 << 24) | (17 << 16) | 2); 128 iph->daddr = htonl((172 << 24) | (17 << 16) | 1); 129 iph->check = build_ip_csum((void *) iph, iph->ihl << 1, 0); 130 131 return iph->ihl << 2; 132 } 133 134 static int build_udp_header(void *header, int payload_len) 135 { 136 const int alen = sizeof(uint32_t); 137 struct udphdr *udph = header; 138 int len = sizeof(*udph) + payload_len; 139 140 udph->source = htons(9); 141 udph->dest = htons(cfg_port); 142 udph->len = htons(len); 143 144 if (cfg_use_csum_off) 145 udph->check = build_ip_csum(header - (2 * alen), alen, 146 htons(IPPROTO_UDP) + udph->len); 147 else 148 udph->check = 0; 149 150 return sizeof(*udph); 151 } 152 153 static int build_packet(int payload_len) 154 { 155 int off = 0; 156 157 off += build_vnet_header(tbuf); 158 off += build_eth_header(tbuf + off); 159 off += build_ipv4_header(tbuf + off, payload_len); 160 off += build_udp_header(tbuf + off, payload_len); 161 162 if (off + payload_len > sizeof(tbuf)) 163 error(1, 0, "payload length exceeds max"); 164 165 memset(tbuf + off, DATA_CHAR, payload_len); 166 167 return off + payload_len; 168 } 169 170 static void do_bind(int fd) 171 { 172 struct sockaddr_ll laddr = {0}; 173 174 laddr.sll_family = AF_PACKET; 175 laddr.sll_protocol = htons(ETH_P_IP); 176 laddr.sll_ifindex = if_nametoindex(cfg_ifname); 177 if (!laddr.sll_ifindex) 178 error(1, errno, "if_nametoindex"); 179 180 if (bind(fd, (void *)&laddr, sizeof(laddr))) 181 error(1, errno, "bind"); 182 } 183 184 static void do_send(int fd, char *buf, int len) 185 { 186 int ret; 187 188 if (!cfg_use_vnet) { 189 buf += sizeof(struct virtio_net_hdr); 190 len -= sizeof(struct virtio_net_hdr); 191 } 192 if (cfg_use_dgram) { 193 buf += ETH_HLEN; 194 len -= ETH_HLEN; 195 } 196 197 if (cfg_use_bind) { 198 ret = write(fd, buf, len); 199 } else { 200 struct sockaddr_ll laddr = {0}; 201 202 laddr.sll_protocol = htons(ETH_P_IP); 203 laddr.sll_ifindex = if_nametoindex(cfg_ifname); 204 if (!laddr.sll_ifindex) 205 error(1, errno, "if_nametoindex"); 206 207 ret = sendto(fd, buf, len, 0, (void *)&laddr, sizeof(laddr)); 208 } 209 210 if (ret == -1) 211 error(1, errno, "write"); 212 if (ret != len) 213 error(1, 0, "write: %u %u", ret, len); 214 215 fprintf(stderr, "tx: %u\n", ret); 216 } 217 218 static int do_tx(void) 219 { 220 const int one = 1; 221 int fd, len; 222 223 fd = socket(PF_PACKET, cfg_use_dgram ? SOCK_DGRAM : SOCK_RAW, 0); 224 if (fd == -1) 225 error(1, errno, "socket t"); 226 227 if (cfg_use_bind) 228 do_bind(fd); 229 230 if (cfg_use_qdisc_bypass && 231 setsockopt(fd, SOL_PACKET, PACKET_QDISC_BYPASS, &one, sizeof(one))) 232 error(1, errno, "setsockopt qdisc bypass"); 233 234 if (cfg_use_vnet && 235 setsockopt(fd, SOL_PACKET, PACKET_VNET_HDR, &one, sizeof(one))) 236 error(1, errno, "setsockopt vnet"); 237 238 len = build_packet(cfg_payload_len); 239 240 if (cfg_truncate_len < len) 241 len = cfg_truncate_len; 242 243 do_send(fd, tbuf, len); 244 245 if (close(fd)) 246 error(1, errno, "close t"); 247 248 return len; 249 } 250 251 static int setup_rx(void) 252 { 253 struct timeval tv = { .tv_usec = 100 * 1000 }; 254 struct sockaddr_in raddr = {0}; 255 int fd; 256 257 fd = socket(PF_INET, SOCK_DGRAM, 0); 258 if (fd == -1) 259 error(1, errno, "socket r"); 260 261 if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv))) 262 error(1, errno, "setsockopt rcv timeout"); 263 264 raddr.sin_family = AF_INET; 265 raddr.sin_port = htons(cfg_port); 266 raddr.sin_addr.s_addr = htonl(INADDR_ANY); 267 268 if (bind(fd, (void *)&raddr, sizeof(raddr))) 269 error(1, errno, "bind r"); 270 271 return fd; 272 } 273 274 static void do_rx(int fd, int expected_len, char *expected) 275 { 276 int ret; 277 278 ret = recv(fd, rbuf, sizeof(rbuf), 0); 279 if (ret == -1) 280 error(1, errno, "recv"); 281 if (ret != expected_len) 282 error(1, 0, "recv: %u != %u", ret, expected_len); 283 284 if (memcmp(rbuf, expected, ret)) 285 error(1, 0, "recv: data mismatch"); 286 287 fprintf(stderr, "rx: %u\n", ret); 288 } 289 290 static int setup_sniffer(void) 291 { 292 struct timeval tv = { .tv_usec = 100 * 1000 }; 293 int fd; 294 295 fd = socket(PF_PACKET, SOCK_RAW, 0); 296 if (fd == -1) 297 error(1, errno, "socket p"); 298 299 if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv))) 300 error(1, errno, "setsockopt rcv timeout"); 301 302 pair_udp_setfilter(fd); 303 do_bind(fd); 304 305 return fd; 306 } 307 308 static void parse_opts(int argc, char **argv) 309 { 310 int c; 311 312 while ((c = getopt(argc, argv, "bcCdgl:qt:vV")) != -1) { 313 switch (c) { 314 case 'b': 315 cfg_use_bind = true; 316 break; 317 case 'c': 318 cfg_use_csum_off = true; 319 break; 320 case 'C': 321 cfg_use_csum_off_bad = true; 322 break; 323 case 'd': 324 cfg_use_dgram = true; 325 break; 326 case 'g': 327 cfg_use_gso = true; 328 break; 329 case 'l': 330 cfg_payload_len = strtoul(optarg, NULL, 0); 331 break; 332 case 'q': 333 cfg_use_qdisc_bypass = true; 334 break; 335 case 't': 336 cfg_truncate_len = strtoul(optarg, NULL, 0); 337 break; 338 case 'v': 339 cfg_use_vnet = true; 340 break; 341 case 'V': 342 cfg_use_vlan = true; 343 break; 344 default: 345 error(1, 0, "%s: parse error", argv[0]); 346 } 347 } 348 349 if (cfg_use_vlan && cfg_use_dgram) 350 error(1, 0, "option vlan (-V) conflicts with dgram (-d)"); 351 352 if (cfg_use_csum_off && !cfg_use_vnet) 353 error(1, 0, "option csum offload (-c) requires vnet (-v)"); 354 355 if (cfg_use_csum_off_bad && !cfg_use_csum_off) 356 error(1, 0, "option csum bad (-C) requires csum offload (-c)"); 357 358 if (cfg_use_gso && !cfg_use_csum_off) 359 error(1, 0, "option gso (-g) requires csum offload (-c)"); 360 } 361 362 static void run_test(void) 363 { 364 int fdr, fds, total_len; 365 366 fdr = setup_rx(); 367 fds = setup_sniffer(); 368 369 total_len = do_tx(); 370 371 /* BPF filter accepts only this length, vlan changes MAC */ 372 if (cfg_payload_len == DATA_LEN && !cfg_use_vlan) 373 do_rx(fds, total_len - sizeof(struct virtio_net_hdr), 374 tbuf + sizeof(struct virtio_net_hdr)); 375 376 do_rx(fdr, cfg_payload_len, tbuf + total_len - cfg_payload_len); 377 378 if (close(fds)) 379 error(1, errno, "close s"); 380 if (close(fdr)) 381 error(1, errno, "close r"); 382 } 383 384 int main(int argc, char **argv) 385 { 386 parse_opts(argc, argv); 387 388 if (system("ip link set dev lo mtu 1500")) 389 error(1, errno, "ip link set mtu"); 390 if (system("ip addr add dev lo 172.17.0.1/24")) 391 error(1, errno, "ip addr add"); 392 393 run_test(); 394 395 fprintf(stderr, "OK\n\n"); 396 return 0; 397 } 398