1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2023 Isovalent */
3 #include <stdbool.h>
4 #include <linux/bpf.h>
5 #include <linux/if_ether.h>
6 #include <linux/in.h>
7 #include <linux/ip.h>
8 #include <linux/ipv6.h>
9 #include <linux/tcp.h>
10 #include <linux/udp.h>
11 #include <bpf/bpf_endian.h>
12 #include <bpf/bpf_helpers.h>
13 #include <linux/pkt_cls.h>
14 
char LICENSE[] SEC("license") = "GPL";

/* Socket cookie stored by reuse_accept() when it passes a packet; reuse_drop() resets it to 0. */
__u64 sk_cookie_seen;
/* Incremented on every run of reuse_accept() and reuse_drop(). */
__u64 reuseport_executed;
/*
 * L4 header of the most recent packet matched by maybe_assign_tcp() /
 * maybe_assign_udp(); it is copied just before bpf_sk_assign() is attempted.
 * reuse_accept() compares the incoming packet's header against this copy.
 */
union {
	struct tcphdr tcp;
	struct udphdr udp;
} headers;

/*
 * Destination port to steer, in host byte order (it is passed through
 * bpf_htons() before comparison). NOTE(review): const volatile suggests it
 * is filled in by the loader via .rodata before load — confirm in userspace.
 */
const volatile __u16 dest_port;

/* Single-slot sockmap; assign_sk() looks up key 0 to find the target socket. */
struct {
	__uint(type, BPF_MAP_TYPE_SOCKMAP);
	__uint(max_entries, 1);
	__type(key, __u32);
	__type(value, __u64);
} sk_map SEC(".maps");
32 
33 SEC("sk_reuseport")
34 int reuse_accept(struct sk_reuseport_md *ctx)
35 {
36 	reuseport_executed++;
37 
38 	if (ctx->ip_protocol == IPPROTO_TCP) {
39 		if (ctx->data + sizeof(headers.tcp) > ctx->data_end)
40 			return SK_DROP;
41 
42 		if (__builtin_memcmp(&headers.tcp, ctx->data, sizeof(headers.tcp)) != 0)
43 			return SK_DROP;
44 	} else if (ctx->ip_protocol == IPPROTO_UDP) {
45 		if (ctx->data + sizeof(headers.udp) > ctx->data_end)
46 			return SK_DROP;
47 
48 		if (__builtin_memcmp(&headers.udp, ctx->data, sizeof(headers.udp)) != 0)
49 			return SK_DROP;
50 	} else {
51 		return SK_DROP;
52 	}
53 
54 	sk_cookie_seen = bpf_get_socket_cookie(ctx->sk);
55 	return SK_PASS;
56 }
57 
58 SEC("sk_reuseport")
59 int reuse_drop(struct sk_reuseport_md *ctx)
60 {
61 	reuseport_executed++;
62 	sk_cookie_seen = 0;
63 	return SK_DROP;
64 }
65 
66 static int
67 assign_sk(struct __sk_buff *skb)
68 {
69 	int zero = 0, ret = 0;
70 	struct bpf_sock *sk;
71 
72 	sk = bpf_map_lookup_elem(&sk_map, &zero);
73 	if (!sk)
74 		return TC_ACT_SHOT;
75 	ret = bpf_sk_assign(skb, sk, 0);
76 	bpf_sk_release(sk);
77 	return ret ? TC_ACT_SHOT : TC_ACT_OK;
78 }
79 
80 static bool
81 maybe_assign_tcp(struct __sk_buff *skb, struct tcphdr *th)
82 {
83 	if (th + 1 > (void *)(long)(skb->data_end))
84 		return TC_ACT_SHOT;
85 
86 	if (!th->syn || th->ack || th->dest != bpf_htons(dest_port))
87 		return TC_ACT_OK;
88 
89 	__builtin_memcpy(&headers.tcp, th, sizeof(headers.tcp));
90 	return assign_sk(skb);
91 }
92 
93 static bool
94 maybe_assign_udp(struct __sk_buff *skb, struct udphdr *uh)
95 {
96 	if (uh + 1 > (void *)(long)(skb->data_end))
97 		return TC_ACT_SHOT;
98 
99 	if (uh->dest != bpf_htons(dest_port))
100 		return TC_ACT_OK;
101 
102 	__builtin_memcpy(&headers.udp, uh, sizeof(headers.udp));
103 	return assign_sk(skb);
104 }
105 
106 SEC("tc")
107 int tc_main(struct __sk_buff *skb)
108 {
109 	void *data_end = (void *)(long)skb->data_end;
110 	void *data = (void *)(long)skb->data;
111 	struct ethhdr *eth;
112 
113 	eth = (struct ethhdr *)(data);
114 	if (eth + 1 > data_end)
115 		return TC_ACT_SHOT;
116 
117 	if (eth->h_proto == bpf_htons(ETH_P_IP)) {
118 		struct iphdr *iph = (struct iphdr *)(data + sizeof(*eth));
119 
120 		if (iph + 1 > data_end)
121 			return TC_ACT_SHOT;
122 
123 		if (iph->protocol == IPPROTO_TCP)
124 			return maybe_assign_tcp(skb, (struct tcphdr *)(iph + 1));
125 		else if (iph->protocol == IPPROTO_UDP)
126 			return maybe_assign_udp(skb, (struct udphdr *)(iph + 1));
127 		else
128 			return TC_ACT_SHOT;
129 	} else {
130 		struct ipv6hdr *ip6h = (struct ipv6hdr *)(data + sizeof(*eth));
131 
132 		if (ip6h + 1 > data_end)
133 			return TC_ACT_SHOT;
134 
135 		if (ip6h->nexthdr == IPPROTO_TCP)
136 			return maybe_assign_tcp(skb, (struct tcphdr *)(ip6h + 1));
137 		else if (ip6h->nexthdr == IPPROTO_UDP)
138 			return maybe_assign_udp(skb, (struct udphdr *)(ip6h + 1));
139 		else
140 			return TC_ACT_SHOT;
141 	}
142 }
143