xref: /openbmc/qemu/tools/ebpf/rss.bpf.c (revision f5c69e7a)
1 /*
2  * eBPF RSS program
3  *
4  * Developed by Daynix Computing LTD (http://www.daynix.com)
5  *
6  * Authors:
7  *  Andrew Melnychenko <andrew@daynix.com>
8  *  Yuri Benditovich <yuri.benditovich@daynix.com>
9  *
10  * This work is licensed under the terms of the GNU GPL, version 2.  See
11  * the COPYING file in the top-level directory.
12  *
13  * Prepare:
14  * Requires llvm, clang, bpftool, linux kernel tree
15  *
16  * Build rss.bpf.skeleton.h:
17  * make -f Makefile.ebpf clean all
18  */
19 
20 #include <stddef.h>
21 #include <stdbool.h>
22 #include <linux/bpf.h>
23 
24 #include <linux/in.h>
25 #include <linux/if_ether.h>
26 #include <linux/ip.h>
27 #include <linux/ipv6.h>
28 
29 #include <linux/udp.h>
30 #include <linux/tcp.h>
31 
32 #include <bpf/bpf_helpers.h>
33 #include <bpf/bpf_endian.h>
34 #include <linux/virtio_net.h>
35 
/* Number of entries in the tap_rss_map_indirection_table map. */
#define INDIRECTION_TABLE_SIZE 128
/*
 * Size in bytes of the Toeplitz hash input buffer and of the key byte array.
 * The largest input assembled by calculate_rss_hash() is two IPv6 addresses
 * followed by two 16-bit ports.
 */
#define HASH_CALCULATION_BUFFER_SIZE 36
38 
/*
 * Steering configuration, read from entry 0 of tap_rss_map_configurations.
 *
 * The struct is packed, so there is no padding between the fields.
 * NOTE(review): presumably the layout must match whatever fills the map from
 * user space - confirm against the loader.
 */
struct rss_config_t {
    /* Non-zero: steer by RSS hash; zero: always use default_queue. */
    __u8 redirect;
    /* Not read anywhere in this file. */
    __u8 populate_hash;
    /* Bitmask tested against the VIRTIO_NET_RSS_HASH_TYPE_* flags. */
    __u32 hash_types;
    /* Modulus applied to the hash to obtain the indirection table index. */
    __u16 indirections_len;
    /* Queue returned when no hash-based queue could be selected. */
    __u16 default_queue;
} __attribute__((packed));
46 
/*
 * Toeplitz key as consumed by net_toeplitz_add(), read from entry 0 of
 * tap_rss_map_toeplitz_key.
 */
struct toeplitz_key_data_t {
    /*
     * Initial 32-bit key window; used as-is, without byte-order conversion.
     * NOTE(review): presumably pre-converted by whoever fills the map.
     */
    __u32 leftmost_32_bits;
    /* Key bytes shifted into the window, one bit per processed input bit. */
    __u8 next_byte[HASH_CALCULATION_BUFFER_SIZE];
};
51 
/*
 * Fields extracted from a packet by parse_packet() / parse_ipv6_ext() and
 * later used by calculate_rss_hash() to build the hash input.
 */
struct packet_hash_info_t {
    /* Set when the EtherType is ETH_P_IP. */
    __u8 is_ipv4;
    /* Set when the EtherType is ETH_P_IPV6. */
    __u8 is_ipv6;
    /* Set when the L4 protocol is UDP and the packet is not fragmented. */
    __u8 is_udp;
    /* Set when the L4 protocol is TCP and the packet is not fragmented. */
    __u8 is_tcp;
    /* Set when in6_ext_src was filled from a Home Address option. */
    __u8 is_ipv6_ext_src;
    /* Set when in6_ext_dst was filled from a type 2 routing header. */
    __u8 is_ipv6_ext_dst;
    /* Set for IPv4 fragments and for IPv6 packets with a fragment header. */
    __u8 is_fragmented;

    /* Copied verbatim from the TCP/UDP header (no byte-order conversion). */
    __u16 src_port;
    __u16 dst_port;

    /* IPv4 and IPv6 addresses share storage; only one family is valid. */
    union {
        struct {
            __be32 in_src;
            __be32 in_dst;
        };

        struct {
            struct in6_addr in6_src;
            struct in6_addr in6_dst;
            struct in6_addr in6_ext_src;
            struct in6_addr in6_ext_dst;
        };
    };
};
78 
/*
 * Single-entry array map holding the struct rss_config_t used by
 * tun_rss_steering_prog(); looked up with key 0.
 */
struct {
    __uint(type, BPF_MAP_TYPE_ARRAY);
    __uint(key_size, sizeof(__u32));
    __uint(value_size, sizeof(struct rss_config_t));
    __uint(max_entries, 1);
    __uint(map_flags, BPF_F_MMAPABLE);
} tap_rss_map_configurations SEC(".maps");
86 
/*
 * Single-entry array map holding the struct toeplitz_key_data_t used for
 * hashing; looked up with key 0.
 */
struct {
    __uint(type, BPF_MAP_TYPE_ARRAY);
    __uint(key_size, sizeof(__u32));
    __uint(value_size, sizeof(struct toeplitz_key_data_t));
    __uint(max_entries, 1);
    __uint(map_flags, BPF_F_MMAPABLE);
} tap_rss_map_toeplitz_key SEC(".maps");
94 
/*
 * Indirection table: INDIRECTION_TABLE_SIZE 16-bit entries. The steering
 * program indexes it with (hash % indirections_len) and returns the stored
 * value as the queue.
 */
struct {
    __uint(type, BPF_MAP_TYPE_ARRAY);
    __uint(key_size, sizeof(__u32));
    __uint(value_size, sizeof(__u16));
    __uint(max_entries, INDIRECTION_TABLE_SIZE);
    __uint(map_flags, BPF_F_MMAPABLE);
} tap_rss_map_indirection_table SEC(".maps");
102 
/*
 * Append @size bytes from @ptr to @rss_input at offset *@bytes_written and
 * advance *@bytes_written by @size.
 *
 * No bounds checking is done here; the caller must guarantee that the
 * buffer behind @rss_input has room for the chunk.
 */
static inline void net_rx_rss_add_chunk(__u8 *rss_input, size_t *bytes_written,
                                        const void *ptr, size_t size)
{
    const size_t offset = *bytes_written;

    __builtin_memcpy(rss_input + offset, ptr, size);
    *bytes_written = offset + size;
}
108 
109 static inline
net_toeplitz_add(__u32 * result,__u8 * input,__u32 len,struct toeplitz_key_data_t * key)110 void net_toeplitz_add(__u32 *result,
111                       __u8 *input,
112                       __u32 len
113         , struct toeplitz_key_data_t *key) {
114 
115     __u32 accumulator = *result;
116     __u32 leftmost_32_bits = key->leftmost_32_bits;
117     __u32 byte;
118 
119     for (byte = 0; byte < HASH_CALCULATION_BUFFER_SIZE; byte++) {
120         __u8 input_byte = input[byte];
121         __u8 key_byte = key->next_byte[byte];
122         __u8 bit;
123 
124         for (bit = 0; bit < 8; bit++) {
125             if (input_byte & (1 << 7)) {
126                 accumulator ^= leftmost_32_bits;
127             }
128 
129             leftmost_32_bits =
130                     (leftmost_32_bits << 1) | ((key_byte & (1 << 7)) >> 7);
131 
132             input_byte <<= 1;
133             key_byte <<= 1;
134         }
135     }
136 
137     *result = accumulator;
138 }
139 
140 
/*
 * Return 1 if @hdr_type is one of the next-header values that
 * parse_ipv6_ext() walks through, 0 otherwise.
 *
 * NOTE(review): IPPROTO_ICMPV6 and IPPROTO_NONE are in the list although
 * they do not look like chained extension headers - confirm intent.
 */
static inline int ip6_extension_header_type(__u8 hdr_type)
{
    return hdr_type == IPPROTO_HOPOPTS ||
           hdr_type == IPPROTO_ROUTING ||
           hdr_type == IPPROTO_FRAGMENT ||
           hdr_type == IPPROTO_ICMPV6 ||
           hdr_type == IPPROTO_NONE ||
           hdr_type == IPPROTO_DSTOPTS ||
           hdr_type == IPPROTO_MH;
}
/*
 * According to
 * https://www.iana.org/assignments/ipv6-parameters/ipv6-parameters.xhtml
 * we expect no more than 11 extension headers in an IPv6 packet; there are
 * also 27 TLV options defined for the Destination and Hop-by-Hop extensions.
 * We need to choose a reasonable maximum number of extensions/options to
 * check when looking for the ext src/dst addresses.
 */
#define IP6_EXTENSIONS_COUNT 11
#define IP6_OPTIONS_COUNT 30
166 
/*
 * Walk the chain of IPv6 extension headers.
 *
 * On entry *@l4_protocol is the next-header value from the fixed IPv6 header
 * and *@l4_offset is the offset, relative to the network header, of the
 * first header following it. On success both are updated to describe the
 * header following the last extension header that was walked.
 *
 * Side effects on @info:
 *  - a type 2 routing header (hdrlen 2, one segment left) stores its address
 *    in in6_ext_dst and sets is_ipv6_ext_dst;
 *  - a Home Address option inside a destination options header stores its
 *    address in in6_ext_src and sets is_ipv6_ext_src;
 *  - a fragment header sets is_fragmented.
 *
 * At most IP6_EXTENSIONS_COUNT headers and, per destination options header,
 * IP6_OPTIONS_COUNT options are examined.
 *
 * Returns 0 on success or the non-zero error returned by
 * bpf_skb_load_bytes_relative().
 */
static inline int parse_ipv6_ext(struct __sk_buff *skb,
        struct packet_hash_info_t *info,
        __u8 *l4_protocol, size_t *l4_offset)
{
    struct ipv6_opt_hdr hdr = {};
    unsigned int n;
    int rc;

    if (!ip6_extension_header_type(*l4_protocol)) {
        return 0;
    }

    for (n = 0; n < IP6_EXTENSIONS_COUNT; n++) {
        /* Generic part of the header: next header and length. */
        rc = bpf_skb_load_bytes_relative(skb, *l4_offset, &hdr,
                                         sizeof(hdr), BPF_HDR_START_NET);
        if (rc) {
            return rc;
        }

        switch (*l4_protocol) {
        case IPPROTO_ROUTING: {
            struct ipv6_rt_hdr rt = {};

            rc = bpf_skb_load_bytes_relative(skb, *l4_offset, &rt,
                                             sizeof(rt), BPF_HDR_START_NET);
            if (rc) {
                return rc;
            }

            /* Only a type 2 header carrying exactly one address counts. */
            if (rt.type != IPV6_SRCRT_TYPE_2 ||
                rt.hdrlen != sizeof(struct in6_addr) / 8 ||
                rt.segments_left != 1) {
                break;
            }

            rc = bpf_skb_load_bytes_relative(skb,
                     *l4_offset + offsetof(struct rt2_hdr, addr),
                     &info->in6_ext_dst, sizeof(info->in6_ext_dst),
                     BPF_HDR_START_NET);
            if (rc) {
                return rc;
            }

            info->is_ipv6_ext_dst = 1;
            break;
        }
        case IPPROTO_DSTOPTS: {
            struct ipv6_opt_t {
                __u8 type;
                __u8 length;
            } __attribute__((packed)) tlv = {};
            /* Options start right after the generic two-byte header. */
            size_t tlv_offset = sizeof(hdr);
            unsigned int k;

            for (k = 0; k < IP6_OPTIONS_COUNT; k++) {
                rc = bpf_skb_load_bytes_relative(skb,
                         *l4_offset + tlv_offset,
                         &tlv, sizeof(tlv), BPF_HDR_START_NET);
                if (rc) {
                    return rc;
                }

                if (tlv.type == IPV6_TLV_HAO) {
                    rc = bpf_skb_load_bytes_relative(skb,
                             *l4_offset + tlv_offset
                             + offsetof(struct ipv6_destopt_hao, addr),
                             &info->in6_ext_src, sizeof(info->in6_ext_src),
                             BPF_HDR_START_NET);
                    if (rc) {
                        return rc;
                    }

                    info->is_ipv6_ext_src = 1;
                    break;
                }

                /* Pad1 is a single byte; every other option is a TLV. */
                if (tlv.type == IPV6_TLV_PAD1) {
                    tlv_offset += 1;
                } else {
                    tlv_offset += tlv.length + sizeof(tlv);
                }

                if (tlv_offset + 1 >= hdr.hdrlen * 8) {
                    break;
                }
            }
            break;
        }
        case IPPROTO_FRAGMENT:
            info->is_fragmented = true;
            break;
        default:
            break;
        }

        /* Step over this header; hdrlen excludes the first 8 bytes. */
        *l4_protocol = hdr.nexthdr;
        *l4_offset += (hdr.hdrlen + 1) * 8;

        if (!ip6_extension_header_type(hdr.nexthdr)) {
            return 0;
        }
    }

    return 0;
}
263 
parse_eth_type(struct __sk_buff * skb)264 static __be16 parse_eth_type(struct __sk_buff *skb)
265 {
266     unsigned int offset = 12;
267     __be16 ret = 0;
268     int err = 0;
269 
270     err = bpf_skb_load_bytes_relative(skb, offset, &ret, sizeof(ret),
271                                 BPF_HDR_START_MAC);
272     if (err) {
273         return 0;
274     }
275 
276     switch (bpf_ntohs(ret)) {
277     case ETH_P_8021AD:
278         offset += 4;
279     case ETH_P_8021Q:
280         offset += 4;
281         err = bpf_skb_load_bytes_relative(skb, offset, &ret, sizeof(ret),
282                                     BPF_HDR_START_MAC);
283     default:
284         break;
285     }
286 
287     if (err) {
288         return 0;
289     }
290 
291     return ret;
292 }
293 
parse_packet(struct __sk_buff * skb,struct packet_hash_info_t * info)294 static inline int parse_packet(struct __sk_buff *skb,
295         struct packet_hash_info_t *info)
296 {
297     int err = 0;
298 
299     if (!info || !skb) {
300         return -1;
301     }
302 
303     size_t l4_offset = 0;
304     __u8 l4_protocol = 0;
305     __u16 l3_protocol = bpf_ntohs(parse_eth_type(skb));
306     if (l3_protocol == 0) {
307         err = -1;
308         goto error;
309     }
310 
311     if (l3_protocol == ETH_P_IP) {
312         info->is_ipv4 = 1;
313 
314         struct iphdr ip = {};
315         err = bpf_skb_load_bytes_relative(skb, 0, &ip, sizeof(ip),
316                                     BPF_HDR_START_NET);
317         if (err) {
318             goto error;
319         }
320 
321         info->in_src = ip.saddr;
322         info->in_dst = ip.daddr;
323         info->is_fragmented = !!(bpf_ntohs(ip.frag_off) & (0x2000 | 0x1fff));
324 
325         l4_protocol = ip.protocol;
326         l4_offset = ip.ihl * 4;
327     } else if (l3_protocol == ETH_P_IPV6) {
328         info->is_ipv6 = 1;
329 
330         struct ipv6hdr ip6 = {};
331         err = bpf_skb_load_bytes_relative(skb, 0, &ip6, sizeof(ip6),
332                                     BPF_HDR_START_NET);
333         if (err) {
334             goto error;
335         }
336 
337         info->in6_src = ip6.saddr;
338         info->in6_dst = ip6.daddr;
339 
340         l4_protocol = ip6.nexthdr;
341         l4_offset = sizeof(ip6);
342 
343         err = parse_ipv6_ext(skb, info, &l4_protocol, &l4_offset);
344         if (err) {
345             goto error;
346         }
347     }
348 
349     if (l4_protocol != 0 && !info->is_fragmented) {
350         if (l4_protocol == IPPROTO_TCP) {
351             info->is_tcp = 1;
352 
353             struct tcphdr tcp = {};
354             err = bpf_skb_load_bytes_relative(skb, l4_offset, &tcp, sizeof(tcp),
355                                         BPF_HDR_START_NET);
356             if (err) {
357                 goto error;
358             }
359 
360             info->src_port = tcp.source;
361             info->dst_port = tcp.dest;
362         } else if (l4_protocol == IPPROTO_UDP) { /* TODO: add udplite? */
363             info->is_udp = 1;
364 
365             struct udphdr udp = {};
366             err = bpf_skb_load_bytes_relative(skb, l4_offset, &udp, sizeof(udp),
367                                         BPF_HDR_START_NET);
368             if (err) {
369                 goto error;
370             }
371 
372             info->src_port = udp.source;
373             info->dst_port = udp.dest;
374         }
375     }
376 
377     return 0;
378 
379 error:
380     return err;
381 }
382 
/*
 * Compute the RSS hash of @skb.
 *
 * The packet is parsed, then the hash input is assembled according to the
 * hash types enabled in @config->hash_types:
 *  - TCP or UDP packets whose matching L4 hash type is enabled use
 *    source address, destination address, source port, destination port;
 *  - otherwise, if the plain IPv4/IPv6 hash type is enabled, only the two
 *    addresses are used.
 * For IPv6, an address found in an extension header replaces the one from
 * the fixed header when the corresponding *_EX hash type is enabled.
 *
 * The input is folded into *@result by net_toeplitz_add() with key @toe.
 *
 * Returns true when a hash was computed, false when parsing failed or no
 * enabled hash type applies to the packet (*@result is then untouched).
 */
static inline bool calculate_rss_hash(struct __sk_buff *skb,
                                      struct rss_config_t *config,
                                      struct toeplitz_key_data_t *toe,
                                      __u32 *result)
{
    __u8 rss_input[HASH_CALCULATION_BUFFER_SIZE] = {};
    size_t bytes_written = 0;
    struct packet_hash_info_t packet_info = {};

    if (parse_packet(skb, &packet_info)) {
        return false;
    }

    if (packet_info.is_ipv4) {
        const bool with_ports =
            (packet_info.is_tcp &&
             (config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_TCPv4)) ||
            (packet_info.is_udp &&
             (config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_UDPv4));

        if (with_ports ||
            (config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_IPv4)) {
            net_rx_rss_add_chunk(rss_input, &bytes_written,
                                 &packet_info.in_src,
                                 sizeof(packet_info.in_src));
            net_rx_rss_add_chunk(rss_input, &bytes_written,
                                 &packet_info.in_dst,
                                 sizeof(packet_info.in_dst));
        }

        if (with_ports) {
            net_rx_rss_add_chunk(rss_input, &bytes_written,
                                 &packet_info.src_port,
                                 sizeof(packet_info.src_port));
            net_rx_rss_add_chunk(rss_input, &bytes_written,
                                 &packet_info.dst_port,
                                 sizeof(packet_info.dst_port));
        }
    } else if (packet_info.is_ipv6) {
        /* *_EX flag that allows extension-header addresses to be used. */
        __u32 ex_type = 0;
        bool with_ports = false;
        bool enabled = true;

        if (packet_info.is_tcp &&
            (config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_TCPv6)) {
            ex_type = VIRTIO_NET_RSS_HASH_TYPE_TCP_EX;
            with_ports = true;
        } else if (packet_info.is_udp &&
                   (config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_UDPv6)) {
            ex_type = VIRTIO_NET_RSS_HASH_TYPE_UDP_EX;
            with_ports = true;
        } else if (config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_IPv6) {
            ex_type = VIRTIO_NET_RSS_HASH_TYPE_IP_EX;
        } else {
            enabled = false;
        }

        if (enabled) {
            if (packet_info.is_ipv6_ext_src &&
                (config->hash_types & ex_type)) {
                net_rx_rss_add_chunk(rss_input, &bytes_written,
                                     &packet_info.in6_ext_src,
                                     sizeof(packet_info.in6_ext_src));
            } else {
                net_rx_rss_add_chunk(rss_input, &bytes_written,
                                     &packet_info.in6_src,
                                     sizeof(packet_info.in6_src));
            }

            if (packet_info.is_ipv6_ext_dst &&
                (config->hash_types & ex_type)) {
                net_rx_rss_add_chunk(rss_input, &bytes_written,
                                     &packet_info.in6_ext_dst,
                                     sizeof(packet_info.in6_ext_dst));
            } else {
                net_rx_rss_add_chunk(rss_input, &bytes_written,
                                     &packet_info.in6_dst,
                                     sizeof(packet_info.in6_dst));
            }
        }

        if (with_ports) {
            net_rx_rss_add_chunk(rss_input, &bytes_written,
                                 &packet_info.src_port,
                                 sizeof(packet_info.src_port));
            net_rx_rss_add_chunk(rss_input, &bytes_written,
                                 &packet_info.dst_port,
                                 sizeof(packet_info.dst_port));
        }
    }

    /* Nothing selected for hashing: let the caller fall back. */
    if (bytes_written == 0) {
        return false;
    }

    net_toeplitz_add(result, rss_input, bytes_written, toe);

    return true;
}
536 
537 SEC("socket")
tun_rss_steering_prog(struct __sk_buff * skb)538 int tun_rss_steering_prog(struct __sk_buff *skb)
539 {
540 
541     struct rss_config_t *config;
542     struct toeplitz_key_data_t *toe;
543 
544     __u32 key = 0;
545     __u32 hash = 0;
546 
547     config = bpf_map_lookup_elem(&tap_rss_map_configurations, &key);
548     toe = bpf_map_lookup_elem(&tap_rss_map_toeplitz_key, &key);
549 
550     if (!config || !toe) {
551         return 0;
552     }
553 
554     if (config->redirect && calculate_rss_hash(skb, config, toe, &hash)) {
555         __u32 table_idx = hash % config->indirections_len;
556         __u16 *queue = 0;
557 
558         queue = bpf_map_lookup_elem(&tap_rss_map_indirection_table,
559                                     &table_idx);
560 
561         if (queue) {
562             return *queue;
563         }
564     }
565 
566     return config->default_queue;
567 }
568 
/* License string, placed in the "license" section of the object file. */
char _license[] SEC("license") = "GPL v2";
570