1 // SPDX-License-Identifier: GPL-2.0
2 // Copyright (c) 2020 Cloudflare
3 
4 #include <errno.h>
5 #include <stdbool.h>
6 #include <linux/bpf.h>
7 
8 #include <bpf/bpf_helpers.h>
9 
10 struct {
11 	__uint(type, BPF_MAP_TYPE_SOCKMAP);
12 	__uint(max_entries, 2);
13 	__type(key, __u32);
14 	__type(value, __u64);
15 } sock_map SEC(".maps");
16 
17 struct {
18 	__uint(type, BPF_MAP_TYPE_SOCKHASH);
19 	__uint(max_entries, 2);
20 	__type(key, __u32);
21 	__type(value, __u64);
22 } sock_hash SEC(".maps");
23 
24 struct {
25 	__uint(type, BPF_MAP_TYPE_ARRAY);
26 	__uint(max_entries, 2);
27 	__type(key, int);
28 	__type(value, unsigned int);
29 } verdict_map SEC(".maps");
30 
31 static volatile bool test_sockmap; /* toggled by user-space */
32 
33 SEC("sk_skb/stream_parser")
34 int prog_skb_parser(struct __sk_buff *skb)
35 {
36 	return skb->len;
37 }
38 
39 SEC("sk_skb/stream_verdict")
40 int prog_skb_verdict(struct __sk_buff *skb)
41 {
42 	unsigned int *count;
43 	__u32 zero = 0;
44 	int verdict;
45 
46 	if (test_sockmap)
47 		verdict = bpf_sk_redirect_map(skb, &sock_map, zero, 0);
48 	else
49 		verdict = bpf_sk_redirect_hash(skb, &sock_hash, &zero, 0);
50 
51 	count = bpf_map_lookup_elem(&verdict_map, &verdict);
52 	if (count)
53 		(*count)++;
54 
55 	return verdict;
56 }
57 
58 SEC("sk_msg")
59 int prog_msg_verdict(struct sk_msg_md *msg)
60 {
61 	unsigned int *count;
62 	__u32 zero = 0;
63 	int verdict;
64 
65 	if (test_sockmap)
66 		verdict = bpf_msg_redirect_map(msg, &sock_map, zero, 0);
67 	else
68 		verdict = bpf_msg_redirect_hash(msg, &sock_hash, &zero, 0);
69 
70 	count = bpf_map_lookup_elem(&verdict_map, &verdict);
71 	if (count)
72 		(*count)++;
73 
74 	return verdict;
75 }
76 
77 SEC("sk_reuseport")
78 int prog_reuseport(struct sk_reuseport_md *reuse)
79 {
80 	unsigned int *count;
81 	int err, verdict;
82 	__u32 zero = 0;
83 
84 	if (test_sockmap)
85 		err = bpf_sk_select_reuseport(reuse, &sock_map, &zero, 0);
86 	else
87 		err = bpf_sk_select_reuseport(reuse, &sock_hash, &zero, 0);
88 	verdict = err ? SK_DROP : SK_PASS;
89 
90 	count = bpf_map_lookup_elem(&verdict_map, &verdict);
91 	if (count)
92 		(*count)++;
93 
94 	return verdict;
95 }
96 
97 int _version SEC("version") = 1;
98 char _license[] SEC("license") = "GPL";
99