1 // SPDX-License-Identifier: GPL-2.0 2 // Copyright (c) 2020 Cloudflare 3 4 #include <errno.h> 5 #include <stdbool.h> 6 #include <linux/bpf.h> 7 8 #include <bpf/bpf_helpers.h> 9 10 struct { 11 __uint(type, BPF_MAP_TYPE_SOCKMAP); 12 __uint(max_entries, 2); 13 __type(key, __u32); 14 __type(value, __u64); 15 } sock_map SEC(".maps"); 16 17 struct { 18 __uint(type, BPF_MAP_TYPE_SOCKHASH); 19 __uint(max_entries, 2); 20 __type(key, __u32); 21 __type(value, __u64); 22 } sock_hash SEC(".maps"); 23 24 struct { 25 __uint(type, BPF_MAP_TYPE_ARRAY); 26 __uint(max_entries, 2); 27 __type(key, int); 28 __type(value, unsigned int); 29 } verdict_map SEC(".maps"); 30 31 static volatile bool test_sockmap; /* toggled by user-space */ 32 33 SEC("sk_skb/stream_parser") 34 int prog_skb_parser(struct __sk_buff *skb) 35 { 36 return skb->len; 37 } 38 39 SEC("sk_skb/stream_verdict") 40 int prog_skb_verdict(struct __sk_buff *skb) 41 { 42 unsigned int *count; 43 __u32 zero = 0; 44 int verdict; 45 46 if (test_sockmap) 47 verdict = bpf_sk_redirect_map(skb, &sock_map, zero, 0); 48 else 49 verdict = bpf_sk_redirect_hash(skb, &sock_hash, &zero, 0); 50 51 count = bpf_map_lookup_elem(&verdict_map, &verdict); 52 if (count) 53 (*count)++; 54 55 return verdict; 56 } 57 58 SEC("sk_msg") 59 int prog_msg_verdict(struct sk_msg_md *msg) 60 { 61 unsigned int *count; 62 __u32 zero = 0; 63 int verdict; 64 65 if (test_sockmap) 66 verdict = bpf_msg_redirect_map(msg, &sock_map, zero, 0); 67 else 68 verdict = bpf_msg_redirect_hash(msg, &sock_hash, &zero, 0); 69 70 count = bpf_map_lookup_elem(&verdict_map, &verdict); 71 if (count) 72 (*count)++; 73 74 return verdict; 75 } 76 77 SEC("sk_reuseport") 78 int prog_reuseport(struct sk_reuseport_md *reuse) 79 { 80 unsigned int *count; 81 int err, verdict; 82 __u32 zero = 0; 83 84 if (test_sockmap) 85 err = bpf_sk_select_reuseport(reuse, &sock_map, &zero, 0); 86 else 87 err = bpf_sk_select_reuseport(reuse, &sock_hash, &zero, 0); 88 verdict = err ? SK_DROP : SK_PASS; 89 90 count = bpf_map_lookup_elem(&verdict_map, &verdict); 91 if (count) 92 (*count)++; 93 94 return verdict; 95 } 96 97 int _version SEC("version") = 1; 98 char _license[] SEC("license") = "GPL"; 99