13d5786eaSGilad Sever // SPDX-License-Identifier: GPL-2.0
23d5786eaSGilad Sever #include <linux/bpf.h>
3*ee77f3d6SYonghong Song #include <bpf/bpf_helpers.h>
4*ee77f3d6SYonghong Song #include <bpf/bpf_endian.h>
5*ee77f3d6SYonghong Song 
63d5786eaSGilad Sever #include <linux/ip.h>
73d5786eaSGilad Sever #include <linux/in.h>
83d5786eaSGilad Sever #include <linux/if_ether.h>
93d5786eaSGilad Sever #include <linux/pkt_cls.h>
103d5786eaSGilad Sever #include <stdbool.h>
113d5786eaSGilad Sever 
/* Outcome of the most recent lookup: 1 = socket found, 0 = no socket. */
int lookup_status;

/* When true only the XDP program performs lookups; otherwise only tc does. */
bool test_xdp;

/* When true, TCP lookups go through bpf_skc_lookup_tcp() instead of
 * bpf_sk_lookup_tcp().
 */
bool tcp_skc;

/* Network-namespace argument handed to every sk lookup helper call. */
#define CUR_NS BPF_F_CURRENT_NETNS
173d5786eaSGilad Sever 
socket_lookup(void * ctx,void * data_end,void * data)183d5786eaSGilad Sever static void socket_lookup(void *ctx, void *data_end, void *data)
193d5786eaSGilad Sever {
203d5786eaSGilad Sever 	struct ethhdr *eth = data;
213d5786eaSGilad Sever 	struct bpf_sock_tuple *tp;
223d5786eaSGilad Sever 	struct bpf_sock *sk;
233d5786eaSGilad Sever 	struct iphdr *iph;
243d5786eaSGilad Sever 	int tplen;
253d5786eaSGilad Sever 
263d5786eaSGilad Sever 	if (eth + 1 > data_end)
273d5786eaSGilad Sever 		return;
283d5786eaSGilad Sever 
293d5786eaSGilad Sever 	if (eth->h_proto != bpf_htons(ETH_P_IP))
303d5786eaSGilad Sever 		return;
313d5786eaSGilad Sever 
323d5786eaSGilad Sever 	iph = (struct iphdr *)(eth + 1);
333d5786eaSGilad Sever 	if (iph + 1 > data_end)
343d5786eaSGilad Sever 		return;
353d5786eaSGilad Sever 
363d5786eaSGilad Sever 	tp = (struct bpf_sock_tuple *)&iph->saddr;
373d5786eaSGilad Sever 	tplen = sizeof(tp->ipv4);
383d5786eaSGilad Sever 	if ((void *)tp + tplen > data_end)
393d5786eaSGilad Sever 		return;
403d5786eaSGilad Sever 
413d5786eaSGilad Sever 	switch (iph->protocol) {
423d5786eaSGilad Sever 	case IPPROTO_TCP:
433d5786eaSGilad Sever 		if (tcp_skc)
443d5786eaSGilad Sever 			sk = bpf_skc_lookup_tcp(ctx, tp, tplen, CUR_NS, 0);
453d5786eaSGilad Sever 		else
463d5786eaSGilad Sever 			sk = bpf_sk_lookup_tcp(ctx, tp, tplen, CUR_NS, 0);
473d5786eaSGilad Sever 		break;
483d5786eaSGilad Sever 	case IPPROTO_UDP:
493d5786eaSGilad Sever 		sk = bpf_sk_lookup_udp(ctx, tp, tplen, CUR_NS, 0);
503d5786eaSGilad Sever 		break;
513d5786eaSGilad Sever 	default:
523d5786eaSGilad Sever 		return;
533d5786eaSGilad Sever 	}
543d5786eaSGilad Sever 
553d5786eaSGilad Sever 	lookup_status = 0;
563d5786eaSGilad Sever 
573d5786eaSGilad Sever 	if (sk) {
583d5786eaSGilad Sever 		bpf_sk_release(sk);
593d5786eaSGilad Sever 		lookup_status = 1;
603d5786eaSGilad Sever 	}
613d5786eaSGilad Sever }
623d5786eaSGilad Sever 
633d5786eaSGilad Sever SEC("tc")
tc_socket_lookup(struct __sk_buff * skb)643d5786eaSGilad Sever int tc_socket_lookup(struct __sk_buff *skb)
653d5786eaSGilad Sever {
663d5786eaSGilad Sever 	void *data_end = (void *)(long)skb->data_end;
673d5786eaSGilad Sever 	void *data = (void *)(long)skb->data;
683d5786eaSGilad Sever 
693d5786eaSGilad Sever 	if (test_xdp)
703d5786eaSGilad Sever 		return TC_ACT_UNSPEC;
713d5786eaSGilad Sever 
723d5786eaSGilad Sever 	socket_lookup(skb, data_end, data);
733d5786eaSGilad Sever 	return TC_ACT_UNSPEC;
743d5786eaSGilad Sever }
753d5786eaSGilad Sever 
763d5786eaSGilad Sever SEC("xdp")
xdp_socket_lookup(struct xdp_md * xdp)773d5786eaSGilad Sever int xdp_socket_lookup(struct xdp_md *xdp)
783d5786eaSGilad Sever {
793d5786eaSGilad Sever 	void *data_end = (void *)(long)xdp->data_end;
803d5786eaSGilad Sever 	void *data = (void *)(long)xdp->data;
813d5786eaSGilad Sever 
823d5786eaSGilad Sever 	if (!test_xdp)
833d5786eaSGilad Sever 		return XDP_PASS;
843d5786eaSGilad Sever 
853d5786eaSGilad Sever 	socket_lookup(xdp, data_end, data);
863d5786eaSGilad Sever 	return XDP_PASS;
873d5786eaSGilad Sever }
883d5786eaSGilad Sever 
/* License declaration emitted into the object's "license" section. */
char _license[] SEC("license") = "GPL";
90