1 /* Copyright (c) 2016,2017 Facebook 2 * 3 * This program is free software; you can redistribute it and/or 4 * modify it under the terms of version 2 of the GNU General Public 5 * License as published by the Free Software Foundation. 6 */ 7 #include <stddef.h> 8 #include <string.h> 9 #include <linux/bpf.h> 10 #include <linux/if_ether.h> 11 #include <linux/if_packet.h> 12 #include <linux/ip.h> 13 #include <linux/ipv6.h> 14 #include <linux/in.h> 15 #include <linux/udp.h> 16 #include <linux/tcp.h> 17 #include <linux/pkt_cls.h> 18 #include <sys/socket.h> 19 #include <bpf/bpf_helpers.h> 20 #include <bpf/bpf_endian.h> 21 #include "test_iptunnel_common.h" 22 23 int _version SEC("version") = 1; 24 25 struct { 26 __uint(type, BPF_MAP_TYPE_PERCPU_ARRAY); 27 __uint(max_entries, 256); 28 __type(key, __u32); 29 __type(value, __u64); 30 } rxcnt SEC(".maps"); 31 32 struct { 33 __uint(type, BPF_MAP_TYPE_HASH); 34 __uint(max_entries, MAX_IPTNL_ENTRIES); 35 __type(key, struct vip); 36 __type(value, struct iptnl_info); 37 } vip2tnl SEC(".maps"); 38 39 static __always_inline void count_tx(__u32 protocol) 40 { 41 __u64 *rxcnt_count; 42 43 rxcnt_count = bpf_map_lookup_elem(&rxcnt, &protocol); 44 if (rxcnt_count) 45 *rxcnt_count += 1; 46 } 47 48 static __always_inline int get_dport(void *trans_data, void *data_end, 49 __u8 protocol) 50 { 51 struct tcphdr *th; 52 struct udphdr *uh; 53 54 switch (protocol) { 55 case IPPROTO_TCP: 56 th = (struct tcphdr *)trans_data; 57 if (th + 1 > data_end) 58 return -1; 59 return th->dest; 60 case IPPROTO_UDP: 61 uh = (struct udphdr *)trans_data; 62 if (uh + 1 > data_end) 63 return -1; 64 return uh->dest; 65 default: 66 return 0; 67 } 68 } 69 70 static __always_inline void set_ethhdr(struct ethhdr *new_eth, 71 const struct ethhdr *old_eth, 72 const struct iptnl_info *tnl, 73 __be16 h_proto) 74 { 75 memcpy(new_eth->h_source, old_eth->h_dest, sizeof(new_eth->h_source)); 76 memcpy(new_eth->h_dest, tnl->dmac, sizeof(new_eth->h_dest)); 77 new_eth->h_proto = h_proto; 78 } 79 80 static __always_inline int handle_ipv4(struct xdp_md *xdp) 81 { 82 void *data_end = (void *)(long)xdp->data_end; 83 void *data = (void *)(long)xdp->data; 84 struct iptnl_info *tnl; 85 struct ethhdr *new_eth; 86 struct ethhdr *old_eth; 87 struct iphdr *iph = data + sizeof(struct ethhdr); 88 __u16 *next_iph; 89 __u16 payload_len; 90 struct vip vip = {}; 91 int dport; 92 __u32 csum = 0; 93 int i; 94 95 if (iph + 1 > data_end) 96 return XDP_DROP; 97 98 dport = get_dport(iph + 1, data_end, iph->protocol); 99 if (dport == -1) 100 return XDP_DROP; 101 102 vip.protocol = iph->protocol; 103 vip.family = AF_INET; 104 vip.daddr.v4 = iph->daddr; 105 vip.dport = dport; 106 payload_len = bpf_ntohs(iph->tot_len); 107 108 tnl = bpf_map_lookup_elem(&vip2tnl, &vip); 109 /* It only does v4-in-v4 */ 110 if (!tnl || tnl->family != AF_INET) 111 return XDP_PASS; 112 113 if (bpf_xdp_adjust_head(xdp, 0 - (int)sizeof(struct iphdr))) 114 return XDP_DROP; 115 116 data = (void *)(long)xdp->data; 117 data_end = (void *)(long)xdp->data_end; 118 119 new_eth = data; 120 iph = data + sizeof(*new_eth); 121 old_eth = data + sizeof(*iph); 122 123 if (new_eth + 1 > data_end || 124 old_eth + 1 > data_end || 125 iph + 1 > data_end) 126 return XDP_DROP; 127 128 set_ethhdr(new_eth, old_eth, tnl, bpf_htons(ETH_P_IP)); 129 130 iph->version = 4; 131 iph->ihl = sizeof(*iph) >> 2; 132 iph->frag_off = 0; 133 iph->protocol = IPPROTO_IPIP; 134 iph->check = 0; 135 iph->tos = 0; 136 iph->tot_len = bpf_htons(payload_len + sizeof(*iph)); 137 iph->daddr = tnl->daddr.v4; 138 iph->saddr = tnl->saddr.v4; 139 iph->ttl = 8; 140 141 next_iph = (__u16 *)iph; 142 #pragma clang loop unroll(full) 143 for (i = 0; i < sizeof(*iph) >> 1; i++) 144 csum += *next_iph++; 145 146 iph->check = ~((csum & 0xffff) + (csum >> 16)); 147 148 count_tx(vip.protocol); 149 150 return XDP_TX; 151 } 152 153 static __always_inline int handle_ipv6(struct xdp_md *xdp) 154 { 155 void *data_end = (void *)(long)xdp->data_end; 156 void *data = (void *)(long)xdp->data; 157 struct iptnl_info *tnl; 158 struct ethhdr *new_eth; 159 struct ethhdr *old_eth; 160 struct ipv6hdr *ip6h = data + sizeof(struct ethhdr); 161 __u16 payload_len; 162 struct vip vip = {}; 163 int dport; 164 165 if (ip6h + 1 > data_end) 166 return XDP_DROP; 167 168 dport = get_dport(ip6h + 1, data_end, ip6h->nexthdr); 169 if (dport == -1) 170 return XDP_DROP; 171 172 vip.protocol = ip6h->nexthdr; 173 vip.family = AF_INET6; 174 memcpy(vip.daddr.v6, ip6h->daddr.s6_addr32, sizeof(vip.daddr)); 175 vip.dport = dport; 176 payload_len = ip6h->payload_len; 177 178 tnl = bpf_map_lookup_elem(&vip2tnl, &vip); 179 /* It only does v6-in-v6 */ 180 if (!tnl || tnl->family != AF_INET6) 181 return XDP_PASS; 182 183 if (bpf_xdp_adjust_head(xdp, 0 - (int)sizeof(struct ipv6hdr))) 184 return XDP_DROP; 185 186 data = (void *)(long)xdp->data; 187 data_end = (void *)(long)xdp->data_end; 188 189 new_eth = data; 190 ip6h = data + sizeof(*new_eth); 191 old_eth = data + sizeof(*ip6h); 192 193 if (new_eth + 1 > data_end || old_eth + 1 > data_end || 194 ip6h + 1 > data_end) 195 return XDP_DROP; 196 197 set_ethhdr(new_eth, old_eth, tnl, bpf_htons(ETH_P_IPV6)); 198 199 ip6h->version = 6; 200 ip6h->priority = 0; 201 memset(ip6h->flow_lbl, 0, sizeof(ip6h->flow_lbl)); 202 ip6h->payload_len = bpf_htons(bpf_ntohs(payload_len) + sizeof(*ip6h)); 203 ip6h->nexthdr = IPPROTO_IPV6; 204 ip6h->hop_limit = 8; 205 memcpy(ip6h->saddr.s6_addr32, tnl->saddr.v6, sizeof(tnl->saddr.v6)); 206 memcpy(ip6h->daddr.s6_addr32, tnl->daddr.v6, sizeof(tnl->daddr.v6)); 207 208 count_tx(vip.protocol); 209 210 return XDP_TX; 211 } 212 213 SEC("xdp_tx_iptunnel") 214 int _xdp_tx_iptunnel(struct xdp_md *xdp) 215 { 216 void *data_end = (void *)(long)xdp->data_end; 217 void *data = (void *)(long)xdp->data; 218 struct ethhdr *eth = data; 219 __u16 h_proto; 220 221 if (eth + 1 > data_end) 222 return XDP_DROP; 223 224 h_proto = eth->h_proto; 225 226 if (h_proto == bpf_htons(ETH_P_IP)) 227 return handle_ipv4(xdp); 228 else if (h_proto == bpf_htons(ETH_P_IPV6)) 229 230 return handle_ipv6(xdp); 231 else 232 return XDP_DROP; 233 } 234 235 char _license[] SEC("license") = "GPL"; 236