xref: /openbmc/qemu/tools/ebpf/rss.bpf.c (revision ef929281f1ddb1ce74f5fe39377a88e6cc8237aa)
1  /*
2   * eBPF RSS program
3   *
4   * Developed by Daynix Computing LTD (http://www.daynix.com)
5   *
6   * Authors:
7   *  Andrew Melnychenko <andrew@daynix.com>
8   *  Yuri Benditovich <yuri.benditovich@daynix.com>
9   *
10   * This work is licensed under the terms of the GNU GPL, version 2.  See
11   * the COPYING file in the top-level directory.
12   *
13   * Prepare:
14   * Requires llvm, clang, bpftool, linux kernel tree
15   *
16   * Build rss.bpf.skeleton.h:
17   * make -f Makefile.ebpf clean all
18   */
19  
20  #include <stddef.h>
21  #include <stdbool.h>
22  #include <linux/bpf.h>
23  
24  #include <linux/in.h>
25  #include <linux/if_ether.h>
26  #include <linux/ip.h>
27  #include <linux/ipv6.h>
28  
29  #include <linux/udp.h>
30  #include <linux/tcp.h>
31  
32  #include <bpf/bpf_helpers.h>
33  #include <bpf/bpf_endian.h>
34  #include <linux/virtio_net.h>
35  
/* Number of entries in the tap_rss_map_indirection_table array map. */
36  #define INDIRECTION_TABLE_SIZE 128
/*
 * Size in bytes of the Toeplitz hash input buffer and of the key byte array.
 * The largest input assembled by calculate_rss_hash() is two IPv6 addresses
 * plus two 16-bit ports (16 + 16 + 2 + 2 bytes).
 */
37  #define HASH_CALCULATION_BUFFER_SIZE 36
38  
/*
 * Steering configuration, stored as the single value of
 * tap_rss_map_configurations. The structure is packed, so hash_types sits at
 * an unaligned offset.
 */
39  struct rss_config_t {
40      __u8 redirect;          /* zero: skip hashing, always return default_queue */
41      __u8 populate_hash;     /* not read by the program in this file */
42      __u32 hash_types;       /* bitmask tested against VIRTIO_NET_RSS_HASH_TYPE_* */
43      __u16 indirections_len; /* modulus applied to the hash to pick a table slot */
44      __u16 default_queue;    /* queue returned when no hash/table entry is available */
45  } __attribute__((packed));
46  
/*
 * Toeplitz key as consumed by net_toeplitz_add(): a 32-bit starting window
 * followed by the bytes that are shifted into that window one bit at a time.
 */
47  struct toeplitz_key_data_t {
48      __u32 leftmost_32_bits;                     /* initial 32-bit key window */
49      __u8 next_byte[HASH_CALCULATION_BUFFER_SIZE]; /* key bytes fed in MSB first */
50  };
51  
/*
 * Result of parse_packet(): which protocols were recognised and the header
 * fields that may be fed into the RSS hash.
 */
52  struct packet_hash_info_t {
53      __u8 is_ipv4;          /* EtherType was ETH_P_IP */
54      __u8 is_ipv6;          /* EtherType was ETH_P_IPV6 */
55      __u8 is_udp;           /* L4 protocol is UDP (never set for fragments) */
56      __u8 is_tcp;           /* L4 protocol is TCP (never set for fragments) */
57      __u8 is_ipv6_ext_src;  /* in6_ext_src was filled from a Home Address option */
58      __u8 is_ipv6_ext_dst;  /* in6_ext_dst was filled from a type 2 routing header */
59      __u8 is_fragmented;    /* IPv4 MF/offset set, or IPv6 Fragment header seen */
60  
    /* Ports are copied from the TCP/UDP header without byte swapping. */
61      __u16 src_port;
62      __u16 dst_port;
63  
    /* IPv4 and IPv6 addresses share storage; only one family is valid. */
64      union {
65          struct {
66              __be32 in_src;
67              __be32 in_dst;
68          };
69  
70          struct {
71              struct in6_addr in6_src;
72              struct in6_addr in6_dst;
73              struct in6_addr in6_ext_src;
74              struct in6_addr in6_ext_dst;
75          };
76      };
77  };
78  
/*
 * Single-entry array map holding the struct rss_config_t read by
 * tun_rss_steering_prog() under key 0.
 * NOTE(review): BPF_F_MMAPABLE presumably lets the loader update the value
 * through a memory mapping - confirm against the user-space side.
 */
79  struct {
80      __uint(type, BPF_MAP_TYPE_ARRAY);
81      __uint(key_size, sizeof(__u32));
82      __uint(value_size, sizeof(struct rss_config_t));
83      __uint(max_entries, 1);
84      __uint(map_flags, BPF_F_MMAPABLE);
85  } tap_rss_map_configurations SEC(".maps");
86  
/*
 * Single-entry array map holding the struct toeplitz_key_data_t used for the
 * hash; tun_rss_steering_prog() looks it up under key 0.
 */
87  struct {
88      __uint(type, BPF_MAP_TYPE_ARRAY);
89      __uint(key_size, sizeof(__u32));
90      __uint(value_size, sizeof(struct toeplitz_key_data_t));
91      __uint(max_entries, 1);
92      __uint(map_flags, BPF_F_MMAPABLE);
93  } tap_rss_map_toeplitz_key SEC(".maps");
94  
/*
 * Indirection table: array of INDIRECTION_TABLE_SIZE 16-bit queue numbers,
 * indexed by (hash % indirections_len) in tun_rss_steering_prog().
 */
95  struct {
96      __uint(type, BPF_MAP_TYPE_ARRAY);
97      __uint(key_size, sizeof(__u32));
98      __uint(value_size, sizeof(__u16));
99      __uint(max_entries, INDIRECTION_TABLE_SIZE);
100      __uint(map_flags, BPF_F_MMAPABLE);
101  } tap_rss_map_indirection_table SEC(".maps");
102  
/*
 * Append @size bytes from @ptr to @rss_input at the current write position
 * and advance *@bytes_written by @size.
 *
 * No bounds checking is done here: the caller must make sure that
 * *@bytes_written + @size fits inside the buffer behind @rss_input.
 */
static inline void net_rx_rss_add_chunk(__u8 *rss_input, size_t *bytes_written,
                                        const void *ptr, size_t size) {
    const size_t offset = *bytes_written;

    __builtin_memcpy(rss_input + offset, ptr, size);
    *bytes_written = offset + size;
}
108  
109  static inline
110  void net_toeplitz_add(__u32 *result,
111                        __u8 *input,
112                        __u32 len
113          , struct toeplitz_key_data_t *key) {
114  
115      __u32 accumulator = *result;
116      __u32 leftmost_32_bits = key->leftmost_32_bits;
117      __u32 byte;
118  
119      for (byte = 0; byte < HASH_CALCULATION_BUFFER_SIZE; byte++) {
120          __u8 input_byte = input[byte];
121          __u8 key_byte = key->next_byte[byte];
122          __u8 bit;
123  
124          for (bit = 0; bit < 8; bit++) {
125              if (input_byte & (1 << 7)) {
126                  accumulator ^= leftmost_32_bits;
127              }
128  
129              leftmost_32_bits =
130                      (leftmost_32_bits << 1) | ((key_byte & (1 << 7)) >> 7);
131  
132              input_byte <<= 1;
133              key_byte <<= 1;
134          }
135      }
136  
137      *result = accumulator;
138  }
139  
140  
/*
 * Tell whether @hdr_type is one of the IPv6 "next header" values that
 * parse_ipv6_ext() keeps walking through.
 *
 * Returns 1 for HOPOPTS, ROUTING, FRAGMENT, ICMPV6, NONE, DSTOPTS and MH,
 * and 0 for every other value.
 */
static inline int ip6_extension_header_type(__u8 hdr_type)
{
    if (hdr_type == IPPROTO_HOPOPTS ||
        hdr_type == IPPROTO_ROUTING ||
        hdr_type == IPPROTO_FRAGMENT ||
        hdr_type == IPPROTO_ICMPV6 ||
        hdr_type == IPPROTO_NONE ||
        hdr_type == IPPROTO_DSTOPTS ||
        hdr_type == IPPROTO_MH) {
        return 1;
    }

    return 0;
}
156  /*
157   * According to
158   * https://www.iana.org/assignments/ipv6-parameters/ipv6-parameters.xhtml
159   * we expect that there would be no more than 11 extensions in an IPv6 header;
160   * there are also 27 TLV options for the Destination and Hop-by-hop extensions.
161   * We need to choose a reasonable maximum number of extensions/options to
162   * check when looking for the ext src/dst.
163   */
164  #define IP6_EXTENSIONS_COUNT 11
165  #define IP6_OPTIONS_COUNT 30
166  
167  static inline int parse_ipv6_ext(struct __sk_buff *skb,
168          struct packet_hash_info_t *info,
169          __u8 *l4_protocol, size_t *l4_offset)
170  {
171      int err = 0;
172  
173      if (!ip6_extension_header_type(*l4_protocol)) {
174          return 0;
175      }
176  
177      struct ipv6_opt_hdr ext_hdr = {};
178  
179      for (unsigned int i = 0; i < IP6_EXTENSIONS_COUNT; ++i) {
180  
181          err = bpf_skb_load_bytes_relative(skb, *l4_offset, &ext_hdr,
182                                      sizeof(ext_hdr), BPF_HDR_START_NET);
183          if (err) {
184              goto error;
185          }
186  
187          if (*l4_protocol == IPPROTO_ROUTING) {
188              struct ipv6_rt_hdr ext_rt = {};
189  
190              err = bpf_skb_load_bytes_relative(skb, *l4_offset, &ext_rt,
191                                          sizeof(ext_rt), BPF_HDR_START_NET);
192              if (err) {
193                  goto error;
194              }
195  
196              if ((ext_rt.type == IPV6_SRCRT_TYPE_2) &&
197                      (ext_rt.hdrlen == sizeof(struct in6_addr) / 8) &&
198                      (ext_rt.segments_left == 1)) {
199  
200                  err = bpf_skb_load_bytes_relative(skb,
201                      *l4_offset + offsetof(struct rt2_hdr, addr),
202                      &info->in6_ext_dst, sizeof(info->in6_ext_dst),
203                      BPF_HDR_START_NET);
204                  if (err) {
205                      goto error;
206                  }
207  
208                  info->is_ipv6_ext_dst = 1;
209              }
210  
211          } else if (*l4_protocol == IPPROTO_DSTOPTS) {
212              struct ipv6_opt_t {
213                  __u8 type;
214                  __u8 length;
215              } __attribute__((packed)) opt = {};
216  
217              size_t opt_offset = sizeof(ext_hdr);
218  
219              for (unsigned int j = 0; j < IP6_OPTIONS_COUNT; ++j) {
220                  err = bpf_skb_load_bytes_relative(skb, *l4_offset + opt_offset,
221                                          &opt, sizeof(opt), BPF_HDR_START_NET);
222                  if (err) {
223                      goto error;
224                  }
225  
226                  if (opt.type == IPV6_TLV_HAO) {
227                      err = bpf_skb_load_bytes_relative(skb,
228                          *l4_offset + opt_offset
229                          + offsetof(struct ipv6_destopt_hao, addr),
230                          &info->in6_ext_src, sizeof(info->in6_ext_src),
231                          BPF_HDR_START_NET);
232                      if (err) {
233                          goto error;
234                      }
235  
236                      info->is_ipv6_ext_src = 1;
237                      break;
238                  }
239  
240                  opt_offset += (opt.type == IPV6_TLV_PAD1) ?
241                                1 : opt.length + sizeof(opt);
242  
243                  if (opt_offset + 1 >= ext_hdr.hdrlen * 8) {
244                      break;
245                  }
246              }
247          } else if (*l4_protocol == IPPROTO_FRAGMENT) {
248              info->is_fragmented = true;
249          }
250  
251          *l4_protocol = ext_hdr.nexthdr;
252          *l4_offset += (ext_hdr.hdrlen + 1) * 8;
253  
254          if (!ip6_extension_header_type(ext_hdr.nexthdr)) {
255              return 0;
256          }
257      }
258  
259      return 0;
260  error:
261      return err;
262  }
263  
264  static __be16 parse_eth_type(struct __sk_buff *skb)
265  {
266      unsigned int offset = 12;
267      __be16 ret = 0;
268      int err = 0;
269  
270      err = bpf_skb_load_bytes_relative(skb, offset, &ret, sizeof(ret),
271                                  BPF_HDR_START_MAC);
272      if (err) {
273          return 0;
274      }
275  
276      switch (bpf_ntohs(ret)) {
277      case ETH_P_8021AD:
278          offset += 4;
279      case ETH_P_8021Q:
280          offset += 4;
281          err = bpf_skb_load_bytes_relative(skb, offset, &ret, sizeof(ret),
282                                      BPF_HDR_START_MAC);
283      default:
284          break;
285      }
286  
287      if (err) {
288          return 0;
289      }
290  
291      return ret;
292  }
293  
294  static inline int parse_packet(struct __sk_buff *skb,
295          struct packet_hash_info_t *info)
296  {
297      int err = 0;
298  
299      if (!info || !skb) {
300          return -1;
301      }
302  
303      size_t l4_offset = 0;
304      __u8 l4_protocol = 0;
305      __u16 l3_protocol = bpf_ntohs(parse_eth_type(skb));
306      if (l3_protocol == 0) {
307          err = -1;
308          goto error;
309      }
310  
311      if (l3_protocol == ETH_P_IP) {
312          info->is_ipv4 = 1;
313  
314          struct iphdr ip = {};
315          err = bpf_skb_load_bytes_relative(skb, 0, &ip, sizeof(ip),
316                                      BPF_HDR_START_NET);
317          if (err) {
318              goto error;
319          }
320  
321          info->in_src = ip.saddr;
322          info->in_dst = ip.daddr;
323          info->is_fragmented = !!(bpf_ntohs(ip.frag_off) & (0x2000 | 0x1fff));
324  
325          l4_protocol = ip.protocol;
326          l4_offset = ip.ihl * 4;
327      } else if (l3_protocol == ETH_P_IPV6) {
328          info->is_ipv6 = 1;
329  
330          struct ipv6hdr ip6 = {};
331          err = bpf_skb_load_bytes_relative(skb, 0, &ip6, sizeof(ip6),
332                                      BPF_HDR_START_NET);
333          if (err) {
334              goto error;
335          }
336  
337          info->in6_src = ip6.saddr;
338          info->in6_dst = ip6.daddr;
339  
340          l4_protocol = ip6.nexthdr;
341          l4_offset = sizeof(ip6);
342  
343          err = parse_ipv6_ext(skb, info, &l4_protocol, &l4_offset);
344          if (err) {
345              goto error;
346          }
347      }
348  
349      if (l4_protocol != 0 && !info->is_fragmented) {
350          if (l4_protocol == IPPROTO_TCP) {
351              info->is_tcp = 1;
352  
353              struct tcphdr tcp = {};
354              err = bpf_skb_load_bytes_relative(skb, l4_offset, &tcp, sizeof(tcp),
355                                          BPF_HDR_START_NET);
356              if (err) {
357                  goto error;
358              }
359  
360              info->src_port = tcp.source;
361              info->dst_port = tcp.dest;
362          } else if (l4_protocol == IPPROTO_UDP) { /* TODO: add udplite? */
363              info->is_udp = 1;
364  
365              struct udphdr udp = {};
366              err = bpf_skb_load_bytes_relative(skb, l4_offset, &udp, sizeof(udp),
367                                          BPF_HDR_START_NET);
368              if (err) {
369                  goto error;
370              }
371  
372              info->src_port = udp.source;
373              info->dst_port = udp.dest;
374          }
375      }
376  
377      return 0;
378  
379  error:
380      return err;
381  }
382  
383  static inline __u32 calculate_rss_hash(struct __sk_buff *skb,
384          struct rss_config_t *config, struct toeplitz_key_data_t *toe)
385  {
386      __u8 rss_input[HASH_CALCULATION_BUFFER_SIZE] = {};
387      size_t bytes_written = 0;
388      __u32 result = 0;
389      int err = 0;
390      struct packet_hash_info_t packet_info = {};
391  
392      err = parse_packet(skb, &packet_info);
393      if (err) {
394          return 0;
395      }
396  
397      if (packet_info.is_ipv4) {
398          if (packet_info.is_tcp &&
399              config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_TCPv4) {
400  
401              net_rx_rss_add_chunk(rss_input, &bytes_written,
402                                   &packet_info.in_src,
403                                   sizeof(packet_info.in_src));
404              net_rx_rss_add_chunk(rss_input, &bytes_written,
405                                   &packet_info.in_dst,
406                                   sizeof(packet_info.in_dst));
407              net_rx_rss_add_chunk(rss_input, &bytes_written,
408                                   &packet_info.src_port,
409                                   sizeof(packet_info.src_port));
410              net_rx_rss_add_chunk(rss_input, &bytes_written,
411                                   &packet_info.dst_port,
412                                   sizeof(packet_info.dst_port));
413          } else if (packet_info.is_udp &&
414                     config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_UDPv4) {
415  
416              net_rx_rss_add_chunk(rss_input, &bytes_written,
417                                   &packet_info.in_src,
418                                   sizeof(packet_info.in_src));
419              net_rx_rss_add_chunk(rss_input, &bytes_written,
420                                   &packet_info.in_dst,
421                                   sizeof(packet_info.in_dst));
422              net_rx_rss_add_chunk(rss_input, &bytes_written,
423                                   &packet_info.src_port,
424                                   sizeof(packet_info.src_port));
425              net_rx_rss_add_chunk(rss_input, &bytes_written,
426                                   &packet_info.dst_port,
427                                   sizeof(packet_info.dst_port));
428          } else if (config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_IPv4) {
429              net_rx_rss_add_chunk(rss_input, &bytes_written,
430                                   &packet_info.in_src,
431                                   sizeof(packet_info.in_src));
432              net_rx_rss_add_chunk(rss_input, &bytes_written,
433                                   &packet_info.in_dst,
434                                   sizeof(packet_info.in_dst));
435          }
436      } else if (packet_info.is_ipv6) {
437          if (packet_info.is_tcp &&
438              config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_TCPv6) {
439  
440              if (packet_info.is_ipv6_ext_src &&
441                  config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_TCP_EX) {
442  
443                  net_rx_rss_add_chunk(rss_input, &bytes_written,
444                                       &packet_info.in6_ext_src,
445                                       sizeof(packet_info.in6_ext_src));
446              } else {
447                  net_rx_rss_add_chunk(rss_input, &bytes_written,
448                                       &packet_info.in6_src,
449                                       sizeof(packet_info.in6_src));
450              }
451              if (packet_info.is_ipv6_ext_dst &&
452                  config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_TCP_EX) {
453  
454                  net_rx_rss_add_chunk(rss_input, &bytes_written,
455                                       &packet_info.in6_ext_dst,
456                                       sizeof(packet_info.in6_ext_dst));
457              } else {
458                  net_rx_rss_add_chunk(rss_input, &bytes_written,
459                                       &packet_info.in6_dst,
460                                       sizeof(packet_info.in6_dst));
461              }
462              net_rx_rss_add_chunk(rss_input, &bytes_written,
463                                   &packet_info.src_port,
464                                   sizeof(packet_info.src_port));
465              net_rx_rss_add_chunk(rss_input, &bytes_written,
466                                   &packet_info.dst_port,
467                                   sizeof(packet_info.dst_port));
468          } else if (packet_info.is_udp &&
469                     config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_UDPv6) {
470  
471              if (packet_info.is_ipv6_ext_src &&
472                 config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_UDP_EX) {
473  
474                  net_rx_rss_add_chunk(rss_input, &bytes_written,
475                                       &packet_info.in6_ext_src,
476                                       sizeof(packet_info.in6_ext_src));
477              } else {
478                  net_rx_rss_add_chunk(rss_input, &bytes_written,
479                                       &packet_info.in6_src,
480                                       sizeof(packet_info.in6_src));
481              }
482              if (packet_info.is_ipv6_ext_dst &&
483                 config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_UDP_EX) {
484  
485                  net_rx_rss_add_chunk(rss_input, &bytes_written,
486                                       &packet_info.in6_ext_dst,
487                                       sizeof(packet_info.in6_ext_dst));
488              } else {
489                  net_rx_rss_add_chunk(rss_input, &bytes_written,
490                                       &packet_info.in6_dst,
491                                       sizeof(packet_info.in6_dst));
492              }
493  
494              net_rx_rss_add_chunk(rss_input, &bytes_written,
495                                   &packet_info.src_port,
496                                   sizeof(packet_info.src_port));
497              net_rx_rss_add_chunk(rss_input, &bytes_written,
498                                   &packet_info.dst_port,
499                                   sizeof(packet_info.dst_port));
500  
501          } else if (config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_IPv6) {
502              if (packet_info.is_ipv6_ext_src &&
503                 config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_IP_EX) {
504  
505                  net_rx_rss_add_chunk(rss_input, &bytes_written,
506                                       &packet_info.in6_ext_src,
507                                       sizeof(packet_info.in6_ext_src));
508              } else {
509                  net_rx_rss_add_chunk(rss_input, &bytes_written,
510                                       &packet_info.in6_src,
511                                       sizeof(packet_info.in6_src));
512              }
513              if (packet_info.is_ipv6_ext_dst &&
514                  config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_IP_EX) {
515  
516                  net_rx_rss_add_chunk(rss_input, &bytes_written,
517                                       &packet_info.in6_ext_dst,
518                                       sizeof(packet_info.in6_ext_dst));
519              } else {
520                  net_rx_rss_add_chunk(rss_input, &bytes_written,
521                                       &packet_info.in6_dst,
522                                       sizeof(packet_info.in6_dst));
523              }
524          }
525      }
526  
527      if (bytes_written) {
528          net_toeplitz_add(&result, rss_input, bytes_written, toe);
529      }
530  
531      return result;
532  }
533  
534  SEC("socket")
535  int tun_rss_steering_prog(struct __sk_buff *skb)
536  {
537  
538      struct rss_config_t *config;
539      struct toeplitz_key_data_t *toe;
540  
541      __u32 key = 0;
542      __u32 hash = 0;
543  
544      config = bpf_map_lookup_elem(&tap_rss_map_configurations, &key);
545      toe = bpf_map_lookup_elem(&tap_rss_map_toeplitz_key, &key);
546  
547      if (config && toe) {
548          if (!config->redirect) {
549              return config->default_queue;
550          }
551  
552          hash = calculate_rss_hash(skb, config, toe);
553          if (hash) {
554              __u32 table_idx = hash % config->indirections_len;
555              __u16 *queue = 0;
556  
557              queue = bpf_map_lookup_elem(&tap_rss_map_indirection_table,
558                                          &table_idx);
559  
560              if (queue) {
561                  return *queue;
562              }
563          }
564  
565          return config->default_queue;
566      }
567  
568      return -1;
569  }
570  
/* License string emitted into the "license" section of the object file. */
571  char _license[] SEC("license") = "GPL v2";
572