1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2021 Facebook */
3 #include <stdbool.h>
4 #include <stdint.h>
5 #include <linux/stddef.h>
6 #include <linux/if_ether.h>
7 #include <linux/in.h>
8 #include <linux/in6.h>
9 #include <linux/ip.h>
10 #include <linux/ipv6.h>
11 #include <linux/tcp.h>
12 #include <linux/udp.h>
13 #include <linux/bpf.h>
14 #include <linux/types.h>
15 #include <bpf/bpf_endian.h>
16 #include <bpf/bpf_helpers.h>
17 
/*
 * Result codes returned by the packet parsers. NO_ERR (0) means the packet
 * was parsed (or deliberately left alone); any non-zero value is an error.
 */
enum pkt_parse_err {
	NO_ERR = 0,
	BAD_IP6_HDR = 1,	/* outer IPv6 header or the UDP header after it is truncated */
	BAD_IP4GUE_HDR = 2,	/* NOTE(review): not returned by any parser visible here */
	BAD_IP6GUE_HDR = 3,	/* GUE-encapsulated inner header is truncated or malformed */
};
24 
/* Bits OR-ed into pkt_info.flags while a packet is parsed. */
enum pkt_flag {
	TUNNEL            = 1 << 0,	/* packet arrived GUE-encapsulated */
	TCP_SYN           = 1 << 1,	/* SYN was the only one of SYN/ACK/RST/FIN set */
	QUIC_INITIAL_FLAG = 1 << 2,
	TCP_ACK           = 1 << 3,	/* ACK was the only one of SYN/ACK/RST/FIN set */
	TCP_RST           = 1 << 4,	/* RST was the only one of SYN/ACK/RST/FIN set */
};
32 
/*
 * Key for v4_lpm_val_map. BPF LPM trie keys must start with a __u32 prefix
 * length, followed by the address data that the prefix applies to.
 */
struct v4_lpm_key {
	__u32 prefixlen;	/* number of leading bits of @src that are significant */
	__u32 src;		/* IPv4 address as copied from the iphdr (network byte order) */
};

/*
 * Value stored in v4_lpm_val_map. It carries a copy of the key next to the
 * match flag; filter_ipv4_lpm() only ever reads @val.
 */
struct v4_lpm_val {
	struct v4_lpm_key key;
	__u8 val;
};
42 
/*
 * Exact-match set of IPv6 addresses, queried by filter_ipv6_addr() with the
 * packet's source address. The bool value is read back as a __u8 match flag.
 */
struct {
	__uint(type, BPF_MAP_TYPE_HASH);
	__uint(max_entries, 16);
	__type(key, struct in6_addr);
	__type(value, bool);
} v6_addr_map SEC(".maps");

/*
 * Exact-match set of IPv4 addresses, queried by filter_ipv4_addr() with the
 * iphdr source address exactly as stored in the header (network byte order).
 */
struct {
	__uint(type, BPF_MAP_TYPE_HASH);
	__uint(max_entries, 16);
	__type(key, __u32);
	__type(value, bool);
} v4_addr_map SEC(".maps");

/*
 * IPv4 prefix table. filter_ipv4_lpm() always queries with prefixlen 32, so
 * the longest stored prefix covering the address is returned.
 * LPM trie maps must be created with BPF_F_NO_PREALLOC.
 */
struct {
	__uint(type, BPF_MAP_TYPE_LPM_TRIE);
	__uint(max_entries, 16);
	__uint(key_size, sizeof(struct v4_lpm_key));
	__uint(value_size, sizeof(struct v4_lpm_val));
	__uint(map_flags, BPF_F_NO_PREALLOC);
} v4_lpm_val_map SEC(".maps");

/*
 * TCP port table indexed by port number (see filter_tcp_port()).
 * NOTE(review): as an ARRAY map only indices 0..15 exist, so lookups for
 * ports >= 16 return NULL and can never match - confirm this is intended.
 */
struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, 16);
	__type(key, int);
	__type(value, __u8);
} tcp_port_map SEC(".maps");

/*
 * UDP port table indexed by port number (see filter_udp_port()).
 * NOTE(review): same 0..15 index limit as tcp_port_map.
 */
struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, 16);
	__type(key, int);
	__type(value, __u16);
} udp_port_map SEC(".maps");
78 
/* Which network header pkt_info.ip points at. */
enum ip_type {
	V4 = 1,		/* struct iphdr */
	V6 = 2,		/* struct ipv6hdr */
};
80 
/*
 * Per-packet match results collected while parsing. Fields that are not
 * filled in for a given packet keep the zero value edgewall() starts with.
 */
struct fw_match_info {
	__u8 v4_src_ip_match;		/* v4_addr_map hit for the IPv4 source */
	__u8 v6_src_ip_match;		/* v6_addr_map hit for the IPv6 source */
	__u8 v4_src_prefix_match;	/* v4_lpm_val_map hit for the IPv4 source */
	__u8 v4_dst_prefix_match;	/* v4_lpm_val_map hit for the IPv4 destination */
	__u8 tcp_dp_match;		/* tcp_port_map value for the TCP destination port */
	__u16 udp_sp_match;		/* udp_port_map value for the UDP source port */
	__u16 udp_dp_match;		/* udp_port_map value for the UDP destination port */
	bool is_tcp;			/* transport header was parsed as TCP */
	bool is_tcp_syn;		/* SYN was the only one of SYN/ACK/RST/FIN set */
};
92 
/*
 * Parsing state for one packet. When GUE encapsulation is detected, the
 * fields describe the inner packet rather than the outer IPv6 one.
 */
struct pkt_info {
	enum ip_type type;		/* selects the valid member of @ip */
	union {
		struct iphdr *ipv4;
		struct ipv6hdr *ipv6;
	} ip;				/* network header inside the packet buffer */
	int sport;			/* transport source port, host byte order */
	int dport;			/* transport destination port, host byte order */
	__u16 trans_hdr_offset;		/* bytes from packet start to the transport header */
	__u8 proto;			/* IPv4 protocol / IPv6 next-header value */
	__u8 flags;			/* OR of enum pkt_flag bits */
};
105 
106 static __always_inline struct ethhdr *parse_ethhdr(void *data, void *data_end)
107 {
108 	struct ethhdr *eth = data;
109 
110 	if (eth + 1 > data_end)
111 		return NULL;
112 
113 	return eth;
114 }
115 
116 static __always_inline __u8 filter_ipv6_addr(const struct in6_addr *ipv6addr)
117 {
118 	__u8 *leaf;
119 
120 	leaf = bpf_map_lookup_elem(&v6_addr_map, ipv6addr);
121 
122 	return leaf ? *leaf : 0;
123 }
124 
125 static __always_inline __u8 filter_ipv4_addr(const __u32 ipaddr)
126 {
127 	__u8 *leaf;
128 
129 	leaf = bpf_map_lookup_elem(&v4_addr_map, &ipaddr);
130 
131 	return leaf ? *leaf : 0;
132 }
133 
134 static __always_inline __u8 filter_ipv4_lpm(const __u32 ipaddr)
135 {
136 	struct v4_lpm_key v4_key = {};
137 	struct v4_lpm_val *lpm_val;
138 
139 	v4_key.src = ipaddr;
140 	v4_key.prefixlen = 32;
141 
142 	lpm_val = bpf_map_lookup_elem(&v4_lpm_val_map, &v4_key);
143 
144 	return lpm_val ? lpm_val->val : 0;
145 }
146 
147 
148 static __always_inline void
149 filter_src_dst_ip(struct pkt_info* info, struct fw_match_info* match_info)
150 {
151 	if (info->type == V6) {
152 		match_info->v6_src_ip_match =
153 			filter_ipv6_addr(&info->ip.ipv6->saddr);
154 	} else if (info->type == V4) {
155 		match_info->v4_src_ip_match =
156 			filter_ipv4_addr(info->ip.ipv4->saddr);
157 		match_info->v4_src_prefix_match =
158 			filter_ipv4_lpm(info->ip.ipv4->saddr);
159 		match_info->v4_dst_prefix_match =
160 			filter_ipv4_lpm(info->ip.ipv4->daddr);
161 	}
162 }
163 
164 static __always_inline void *
165 get_transport_hdr(__u16 offset, void *data, void *data_end)
166 {
167 	if (offset > 255 || data + offset > data_end)
168 		return NULL;
169 
170 	return data + offset;
171 }
172 
/*
 * True when, of the SYN/ACK/RST/FIN bits in @tcp's flag word, exactly the
 * bits in FLAG are set. Other flags (PSH, URG, ...) are ignored.
 * FLAG must be given in the same byte order as tcp_flag_word() (the
 * TCP_FLAG_* constants).
 */
static __always_inline bool tcphdr_only_contains_flag(struct tcphdr *tcp,
						      __u32 FLAG)
{
	const __u32 tracked = TCP_FLAG_ACK | TCP_FLAG_RST |
			      TCP_FLAG_SYN | TCP_FLAG_FIN;
	__u32 present = tcp_flag_word(tcp) & tracked;

	return present == FLAG;
}
179 
180 static __always_inline void set_tcp_flags(struct pkt_info *info,
181 					  struct tcphdr *tcp) {
182 	if (tcphdr_only_contains_flag(tcp, TCP_FLAG_SYN))
183 		info->flags |= TCP_SYN;
184 	else if (tcphdr_only_contains_flag(tcp, TCP_FLAG_ACK))
185 		info->flags |= TCP_ACK;
186 	else if (tcphdr_only_contains_flag(tcp, TCP_FLAG_RST))
187 		info->flags |= TCP_RST;
188 }
189 
190 static __always_inline bool
191 parse_tcp(struct pkt_info *info, void *transport_hdr, void *data_end)
192 {
193 	struct tcphdr *tcp = transport_hdr;
194 
195 	if (tcp + 1 > data_end)
196 		return false;
197 
198 	info->sport = bpf_ntohs(tcp->source);
199 	info->dport = bpf_ntohs(tcp->dest);
200 	set_tcp_flags(info, tcp);
201 
202 	return true;
203 }
204 
205 static __always_inline bool
206 parse_udp(struct pkt_info *info, void *transport_hdr, void *data_end)
207 {
208 	struct udphdr *udp = transport_hdr;
209 
210 	if (udp + 1 > data_end)
211 		return false;
212 
213 	info->sport = bpf_ntohs(udp->source);
214 	info->dport = bpf_ntohs(udp->dest);
215 
216 	return true;
217 }
218 
219 static __always_inline __u8 filter_tcp_port(int port)
220 {
221 	__u8 *leaf = bpf_map_lookup_elem(&tcp_port_map, &port);
222 
223 	return leaf ? *leaf : 0;
224 }
225 
226 static __always_inline __u16 filter_udp_port(int port)
227 {
228 	__u16 *leaf = bpf_map_lookup_elem(&udp_port_map, &port);
229 
230 	return leaf ? *leaf : 0;
231 }
232 
233 static __always_inline bool
234 filter_transport_hdr(void *transport_hdr, void *data_end,
235 		     struct pkt_info *info, struct fw_match_info *match_info)
236 {
237 	if (info->proto == IPPROTO_TCP) {
238 		if (!parse_tcp(info, transport_hdr, data_end))
239 			return false;
240 
241 		match_info->is_tcp = true;
242 		match_info->is_tcp_syn = (info->flags & TCP_SYN) > 0;
243 
244 		match_info->tcp_dp_match = filter_tcp_port(info->dport);
245 	} else if (info->proto == IPPROTO_UDP) {
246 		if (!parse_udp(info, transport_hdr, data_end))
247 			return false;
248 
249 		match_info->udp_dp_match = filter_udp_port(info->dport);
250 		match_info->udp_sp_match = filter_udp_port(info->sport);
251 	}
252 
253 	return true;
254 }
255 
256 static __always_inline __u8
257 parse_gue_v6(struct pkt_info *info, struct ipv6hdr *ip6h, void *data_end)
258 {
259 	struct udphdr *udp = (struct udphdr *)(ip6h + 1);
260 	void *encap_data = udp + 1;
261 
262 	if (udp + 1 > data_end)
263 		return BAD_IP6_HDR;
264 
265 	if (udp->dest != bpf_htons(6666))
266 		return NO_ERR;
267 
268 	info->flags |= TUNNEL;
269 
270 	if (encap_data + 1 > data_end)
271 		return BAD_IP6GUE_HDR;
272 
273 	if (*(__u8 *)encap_data & 0x30) {
274 		struct ipv6hdr *inner_ip6h = encap_data;
275 
276 		if (inner_ip6h + 1 > data_end)
277 			return BAD_IP6GUE_HDR;
278 
279 		info->type = V6;
280 		info->proto = inner_ip6h->nexthdr;
281 		info->ip.ipv6 = inner_ip6h;
282 		info->trans_hdr_offset += sizeof(struct ipv6hdr) + sizeof(struct udphdr);
283 	} else {
284 		struct iphdr *inner_ip4h = encap_data;
285 
286 		if (inner_ip4h + 1 > data_end)
287 			return BAD_IP6GUE_HDR;
288 
289 		info->type = V4;
290 		info->proto = inner_ip4h->protocol;
291 		info->ip.ipv4 = inner_ip4h;
292 		info->trans_hdr_offset += sizeof(struct iphdr) + sizeof(struct udphdr);
293 	}
294 
295 	return NO_ERR;
296 }
297 
298 static __always_inline __u8 parse_ipv6_gue(struct pkt_info *info,
299 					   void *data, void *data_end)
300 {
301 	struct ipv6hdr *ip6h = data + sizeof(struct ethhdr);
302 
303 	if (ip6h + 1 > data_end)
304 		return BAD_IP6_HDR;
305 
306 	info->proto = ip6h->nexthdr;
307 	info->ip.ipv6 = ip6h;
308 	info->type = V6;
309 	info->trans_hdr_offset = sizeof(struct ethhdr) + sizeof(struct ipv6hdr);
310 
311 	if (info->proto == IPPROTO_UDP)
312 		return parse_gue_v6(info, ip6h, data_end);
313 
314 	return NO_ERR;
315 }
316 
317 SEC("xdp")
318 int edgewall(struct xdp_md *ctx)
319 {
320 	void *data_end = (void *)(long)(ctx->data_end);
321 	void *data = (void *)(long)(ctx->data);
322 	struct fw_match_info match_info = {};
323 	struct pkt_info info = {};
324 	void *transport_hdr;
325 	struct ethhdr *eth;
326 	bool filter_res;
327 	__u32 proto;
328 
329 	eth = parse_ethhdr(data, data_end);
330 	if (!eth)
331 		return XDP_DROP;
332 
333 	proto = eth->h_proto;
334 	if (proto != bpf_htons(ETH_P_IPV6))
335 		return XDP_DROP;
336 
337 	if (parse_ipv6_gue(&info, data, data_end))
338 		return XDP_DROP;
339 
340 	if (info.proto == IPPROTO_ICMPV6)
341 		return XDP_PASS;
342 
343 	if (info.proto != IPPROTO_TCP && info.proto != IPPROTO_UDP)
344 		return XDP_DROP;
345 
346 	filter_src_dst_ip(&info, &match_info);
347 
348 	transport_hdr = get_transport_hdr(info.trans_hdr_offset, data,
349 					  data_end);
350 	if (!transport_hdr)
351 		return XDP_DROP;
352 
353 	filter_res = filter_transport_hdr(transport_hdr, data_end,
354 					  &info, &match_info);
355 	if (!filter_res)
356 		return XDP_DROP;
357 
358 	if (match_info.is_tcp && !match_info.is_tcp_syn)
359 		return XDP_PASS;
360 
361 	return XDP_DROP;
362 }
363 
/*
 * License string emitted into the "license" ELF section; the kernel checks it
 * for GPL compatibility when deciding which helpers the program may call.
 */
char LICENSE[] SEC("license") = "GPL";
365