12d7824ffSLorenz Bauer // SPDX-License-Identifier: GPL-2.0
22d7824ffSLorenz Bauer // Copyright (c) 2019 Cloudflare Ltd.
32d7824ffSLorenz Bauer // Copyright (c) 2020 Isovalent, Inc.
42d7824ffSLorenz Bauer 
52d7824ffSLorenz Bauer #include <stddef.h>
62d7824ffSLorenz Bauer #include <stdbool.h>
72d7824ffSLorenz Bauer #include <string.h>
82d7824ffSLorenz Bauer #include <linux/bpf.h>
92d7824ffSLorenz Bauer #include <linux/if_ether.h>
102d7824ffSLorenz Bauer #include <linux/in.h>
112d7824ffSLorenz Bauer #include <linux/ip.h>
122d7824ffSLorenz Bauer #include <linux/ipv6.h>
132d7824ffSLorenz Bauer #include <linux/pkt_cls.h>
142d7824ffSLorenz Bauer #include <linux/tcp.h>
152d7824ffSLorenz Bauer #include <sys/socket.h>
162d7824ffSLorenz Bauer #include <bpf/bpf_helpers.h>
172d7824ffSLorenz Bauer #include <bpf/bpf_endian.h>
18*c8ed6685SAndrii Nakryiko #include "bpf_misc.h"
192d7824ffSLorenz Bauer 
207ce878caSIlya Leoshkevich #if defined(IPROUTE2_HAVE_LIBBPF)
217ce878caSIlya Leoshkevich /* Use a new-style map definition. */
227ce878caSIlya Leoshkevich struct {
237ce878caSIlya Leoshkevich 	__uint(type, BPF_MAP_TYPE_SOCKMAP);
247ce878caSIlya Leoshkevich 	__type(key, int);
257ce878caSIlya Leoshkevich 	__type(value, __u64);
267ce878caSIlya Leoshkevich 	__uint(pinning, LIBBPF_PIN_BY_NAME);
277ce878caSIlya Leoshkevich 	__uint(max_entries, 1);
287ce878caSIlya Leoshkevich } server_map SEC(".maps");
297ce878caSIlya Leoshkevich #else
300b9ad56bSJakub Sitnicki /* Pin map under /sys/fs/bpf/tc/globals/<map name> */
310b9ad56bSJakub Sitnicki #define PIN_GLOBAL_NS 2
320b9ad56bSJakub Sitnicki 
330b9ad56bSJakub Sitnicki /* Must match struct bpf_elf_map layout from iproute2 */
340b9ad56bSJakub Sitnicki struct {
350b9ad56bSJakub Sitnicki 	__u32 type;
360b9ad56bSJakub Sitnicki 	__u32 size_key;
370b9ad56bSJakub Sitnicki 	__u32 size_value;
380b9ad56bSJakub Sitnicki 	__u32 max_elem;
390b9ad56bSJakub Sitnicki 	__u32 flags;
400b9ad56bSJakub Sitnicki 	__u32 id;
410b9ad56bSJakub Sitnicki 	__u32 pinning;
420b9ad56bSJakub Sitnicki } server_map SEC("maps") = {
430b9ad56bSJakub Sitnicki 	.type = BPF_MAP_TYPE_SOCKMAP,
440b9ad56bSJakub Sitnicki 	.size_key = sizeof(int),
450b9ad56bSJakub Sitnicki 	.size_value  = sizeof(__u64),
460b9ad56bSJakub Sitnicki 	.max_elem = 1,
470b9ad56bSJakub Sitnicki 	.pinning = PIN_GLOBAL_NS,
480b9ad56bSJakub Sitnicki };
497ce878caSIlya Leoshkevich #endif
500b9ad56bSJakub Sitnicki 
512d7824ffSLorenz Bauer char _license[] SEC("license") = "GPL";
522d7824ffSLorenz Bauer 
532d7824ffSLorenz Bauer /* Fill 'tuple' with L3 info, and attempt to find L4. On fail, return NULL. */
542d7824ffSLorenz Bauer static inline struct bpf_sock_tuple *
get_tuple(struct __sk_buff * skb,bool * ipv4,bool * tcp)558a02a170SJoe Stringer get_tuple(struct __sk_buff *skb, bool *ipv4, bool *tcp)
562d7824ffSLorenz Bauer {
572d7824ffSLorenz Bauer 	void *data_end = (void *)(long)skb->data_end;
582d7824ffSLorenz Bauer 	void *data = (void *)(long)skb->data;
592d7824ffSLorenz Bauer 	struct bpf_sock_tuple *result;
602d7824ffSLorenz Bauer 	struct ethhdr *eth;
612d7824ffSLorenz Bauer 	__u8 proto = 0;
622d7824ffSLorenz Bauer 	__u64 ihl_len;
632d7824ffSLorenz Bauer 
642d7824ffSLorenz Bauer 	eth = (struct ethhdr *)(data);
652d7824ffSLorenz Bauer 	if (eth + 1 > data_end)
662d7824ffSLorenz Bauer 		return NULL;
672d7824ffSLorenz Bauer 
682d7824ffSLorenz Bauer 	if (eth->h_proto == bpf_htons(ETH_P_IP)) {
692d7824ffSLorenz Bauer 		struct iphdr *iph = (struct iphdr *)(data + sizeof(*eth));
702d7824ffSLorenz Bauer 
712d7824ffSLorenz Bauer 		if (iph + 1 > data_end)
722d7824ffSLorenz Bauer 			return NULL;
732d7824ffSLorenz Bauer 		if (iph->ihl != 5)
742d7824ffSLorenz Bauer 			/* Options are not supported */
752d7824ffSLorenz Bauer 			return NULL;
762d7824ffSLorenz Bauer 		ihl_len = iph->ihl * 4;
772d7824ffSLorenz Bauer 		proto = iph->protocol;
782d7824ffSLorenz Bauer 		*ipv4 = true;
792d7824ffSLorenz Bauer 		result = (struct bpf_sock_tuple *)&iph->saddr;
802d7824ffSLorenz Bauer 	} else if (eth->h_proto == bpf_htons(ETH_P_IPV6)) {
812d7824ffSLorenz Bauer 		struct ipv6hdr *ip6h = (struct ipv6hdr *)(data + sizeof(*eth));
822d7824ffSLorenz Bauer 
832d7824ffSLorenz Bauer 		if (ip6h + 1 > data_end)
842d7824ffSLorenz Bauer 			return NULL;
852d7824ffSLorenz Bauer 		ihl_len = sizeof(*ip6h);
862d7824ffSLorenz Bauer 		proto = ip6h->nexthdr;
872d7824ffSLorenz Bauer 		*ipv4 = false;
882d7824ffSLorenz Bauer 		result = (struct bpf_sock_tuple *)&ip6h->saddr;
892d7824ffSLorenz Bauer 	} else {
902d7824ffSLorenz Bauer 		return (struct bpf_sock_tuple *)data;
912d7824ffSLorenz Bauer 	}
922d7824ffSLorenz Bauer 
938a02a170SJoe Stringer 	if (proto != IPPROTO_TCP && proto != IPPROTO_UDP)
942d7824ffSLorenz Bauer 		return NULL;
952d7824ffSLorenz Bauer 
968a02a170SJoe Stringer 	*tcp = (proto == IPPROTO_TCP);
97*c8ed6685SAndrii Nakryiko 	__sink(ihl_len);
982d7824ffSLorenz Bauer 	return result;
992d7824ffSLorenz Bauer }
1002d7824ffSLorenz Bauer 
1012d7824ffSLorenz Bauer static inline int
handle_udp(struct __sk_buff * skb,struct bpf_sock_tuple * tuple,bool ipv4)1028a02a170SJoe Stringer handle_udp(struct __sk_buff *skb, struct bpf_sock_tuple *tuple, bool ipv4)
1038a02a170SJoe Stringer {
1048a02a170SJoe Stringer 	struct bpf_sock *sk;
1050b9ad56bSJakub Sitnicki 	const int zero = 0;
1068a02a170SJoe Stringer 	size_t tuple_len;
1070b9ad56bSJakub Sitnicki 	__be16 dport;
1088a02a170SJoe Stringer 	int ret;
1098a02a170SJoe Stringer 
1108a02a170SJoe Stringer 	tuple_len = ipv4 ? sizeof(tuple->ipv4) : sizeof(tuple->ipv6);
1118a02a170SJoe Stringer 	if ((void *)tuple + tuple_len > (void *)(long)skb->data_end)
1128a02a170SJoe Stringer 		return TC_ACT_SHOT;
1138a02a170SJoe Stringer 
1148a02a170SJoe Stringer 	sk = bpf_sk_lookup_udp(skb, tuple, tuple_len, BPF_F_CURRENT_NETNS, 0);
1158a02a170SJoe Stringer 	if (sk)
1168a02a170SJoe Stringer 		goto assign;
1178a02a170SJoe Stringer 
1180b9ad56bSJakub Sitnicki 	dport = ipv4 ? tuple->ipv4.dport : tuple->ipv6.dport;
1190b9ad56bSJakub Sitnicki 	if (dport != bpf_htons(4321))
1208a02a170SJoe Stringer 		return TC_ACT_OK;
1218a02a170SJoe Stringer 
1220b9ad56bSJakub Sitnicki 	sk = bpf_map_lookup_elem(&server_map, &zero);
1238a02a170SJoe Stringer 	if (!sk)
1248a02a170SJoe Stringer 		return TC_ACT_SHOT;
1258a02a170SJoe Stringer 
1268a02a170SJoe Stringer assign:
1278a02a170SJoe Stringer 	ret = bpf_sk_assign(skb, sk, 0);
1288a02a170SJoe Stringer 	bpf_sk_release(sk);
1298a02a170SJoe Stringer 	return ret;
1308a02a170SJoe Stringer }
1318a02a170SJoe Stringer 
1328a02a170SJoe Stringer static inline int
handle_tcp(struct __sk_buff * skb,struct bpf_sock_tuple * tuple,bool ipv4)1332d7824ffSLorenz Bauer handle_tcp(struct __sk_buff *skb, struct bpf_sock_tuple *tuple, bool ipv4)
1342d7824ffSLorenz Bauer {
1352d7824ffSLorenz Bauer 	struct bpf_sock *sk;
1360b9ad56bSJakub Sitnicki 	const int zero = 0;
1372d7824ffSLorenz Bauer 	size_t tuple_len;
1380b9ad56bSJakub Sitnicki 	__be16 dport;
1392d7824ffSLorenz Bauer 	int ret;
1402d7824ffSLorenz Bauer 
1412d7824ffSLorenz Bauer 	tuple_len = ipv4 ? sizeof(tuple->ipv4) : sizeof(tuple->ipv6);
1422d7824ffSLorenz Bauer 	if ((void *)tuple + tuple_len > (void *)(long)skb->data_end)
1432d7824ffSLorenz Bauer 		return TC_ACT_SHOT;
1442d7824ffSLorenz Bauer 
1452d7824ffSLorenz Bauer 	sk = bpf_skc_lookup_tcp(skb, tuple, tuple_len, BPF_F_CURRENT_NETNS, 0);
1462d7824ffSLorenz Bauer 	if (sk) {
1472d7824ffSLorenz Bauer 		if (sk->state != BPF_TCP_LISTEN)
1482d7824ffSLorenz Bauer 			goto assign;
1492d7824ffSLorenz Bauer 		bpf_sk_release(sk);
1502d7824ffSLorenz Bauer 	}
1512d7824ffSLorenz Bauer 
1520b9ad56bSJakub Sitnicki 	dport = ipv4 ? tuple->ipv4.dport : tuple->ipv6.dport;
1530b9ad56bSJakub Sitnicki 	if (dport != bpf_htons(4321))
1542d7824ffSLorenz Bauer 		return TC_ACT_OK;
1552d7824ffSLorenz Bauer 
1560b9ad56bSJakub Sitnicki 	sk = bpf_map_lookup_elem(&server_map, &zero);
1572d7824ffSLorenz Bauer 	if (!sk)
1582d7824ffSLorenz Bauer 		return TC_ACT_SHOT;
1592d7824ffSLorenz Bauer 
1602d7824ffSLorenz Bauer 	if (sk->state != BPF_TCP_LISTEN) {
1612d7824ffSLorenz Bauer 		bpf_sk_release(sk);
1622d7824ffSLorenz Bauer 		return TC_ACT_SHOT;
1632d7824ffSLorenz Bauer 	}
1642d7824ffSLorenz Bauer 
1652d7824ffSLorenz Bauer assign:
1662d7824ffSLorenz Bauer 	ret = bpf_sk_assign(skb, sk, 0);
1672d7824ffSLorenz Bauer 	bpf_sk_release(sk);
1682d7824ffSLorenz Bauer 	return ret;
1692d7824ffSLorenz Bauer }
1702d7824ffSLorenz Bauer 
171c22bdd28SAndrii Nakryiko SEC("tc")
bpf_sk_assign_test(struct __sk_buff * skb)1722d7824ffSLorenz Bauer int bpf_sk_assign_test(struct __sk_buff *skb)
1732d7824ffSLorenz Bauer {
174fe4625d8SEyal Birger 	struct bpf_sock_tuple *tuple;
1752d7824ffSLorenz Bauer 	bool ipv4 = false;
1768a02a170SJoe Stringer 	bool tcp = false;
1772d7824ffSLorenz Bauer 	int ret = 0;
1782d7824ffSLorenz Bauer 
1798a02a170SJoe Stringer 	tuple = get_tuple(skb, &ipv4, &tcp);
1802d7824ffSLorenz Bauer 	if (!tuple)
1812d7824ffSLorenz Bauer 		return TC_ACT_SHOT;
1822d7824ffSLorenz Bauer 
1838a02a170SJoe Stringer 	/* Note that the verifier socket return type for bpf_skc_lookup_tcp()
1848a02a170SJoe Stringer 	 * differs from bpf_sk_lookup_udp(), so even though the C-level type is
1858a02a170SJoe Stringer 	 * the same here, if we try to share the implementations they will
1868a02a170SJoe Stringer 	 * fail to verify because we're crossing pointer types.
1878a02a170SJoe Stringer 	 */
1888a02a170SJoe Stringer 	if (tcp)
1892d7824ffSLorenz Bauer 		ret = handle_tcp(skb, tuple, ipv4);
1908a02a170SJoe Stringer 	else
1918a02a170SJoe Stringer 		ret = handle_udp(skb, tuple, ipv4);
1922d7824ffSLorenz Bauer 
1932d7824ffSLorenz Bauer 	return ret == 0 ? TC_ACT_OK : TC_ACT_SHOT;
1942d7824ffSLorenz Bauer }
195