1*eed92afdSMartin KaFai Lau // SPDX-License-Identifier: GPL-2.0
2*eed92afdSMartin KaFai Lau /* Copyright (c) 2021 Facebook */
3*eed92afdSMartin KaFai Lau #include "bpf_iter.h"
4*eed92afdSMartin KaFai Lau #include "bpf_tracing_net.h"
5*eed92afdSMartin KaFai Lau #include <bpf/bpf_helpers.h>
6*eed92afdSMartin KaFai Lau #include <bpf/bpf_endian.h>
7*eed92afdSMartin KaFai Lau 
/*
 * bpf_tcp_sk(skc) - obtain TCP socket views of a struct sock_common pointer.
 *
 * This macro is intentionally not hygienic: it assigns to the variables
 * `tp` (struct tcp_sock *) and `sk` (struct sock *), both of which must
 * already be declared in the scope where the macro is used.
 *
 * When skc is NULL, both tp and sk end up NULL.  Otherwise tp receives the
 * result of bpf_skc_to_tcp_sock() and sk is the same pointer cast to
 * struct sock *.
 *
 * The statement expression evaluates to tp, so callers can test the macro
 * directly for NULL.
 */
#define bpf_tcp_sk(skc)	({					\
	struct sock_common *_skc = (skc);			\
	tp = _skc ? bpf_skc_to_tcp_sock(_skc) : NULL;		\
	sk = (struct sock *)tp;					\
	tp;							\
})
18*eed92afdSMartin KaFai Lau 
19*eed92afdSMartin KaFai Lau unsigned short reuse_listen_hport = 0;
20*eed92afdSMartin KaFai Lau unsigned short listen_hport = 0;
21*eed92afdSMartin KaFai Lau char cubic_cc[TCP_CA_NAME_MAX] = "bpf_cubic";
22*eed92afdSMartin KaFai Lau char dctcp_cc[TCP_CA_NAME_MAX] = "bpf_dctcp";
23*eed92afdSMartin KaFai Lau bool random_retry = false;
24*eed92afdSMartin KaFai Lau 
tcp_cc_eq(const char * a,const char * b)25*eed92afdSMartin KaFai Lau static bool tcp_cc_eq(const char *a, const char *b)
26*eed92afdSMartin KaFai Lau {
27*eed92afdSMartin KaFai Lau 	int i;
28*eed92afdSMartin KaFai Lau 
29*eed92afdSMartin KaFai Lau 	for (i = 0; i < TCP_CA_NAME_MAX; i++) {
30*eed92afdSMartin KaFai Lau 		if (a[i] != b[i])
31*eed92afdSMartin KaFai Lau 			return false;
32*eed92afdSMartin KaFai Lau 		if (!a[i])
33*eed92afdSMartin KaFai Lau 			break;
34*eed92afdSMartin KaFai Lau 	}
35*eed92afdSMartin KaFai Lau 
36*eed92afdSMartin KaFai Lau 	return true;
37*eed92afdSMartin KaFai Lau }
38*eed92afdSMartin KaFai Lau 
39*eed92afdSMartin KaFai Lau SEC("iter/tcp")
change_tcp_cc(struct bpf_iter__tcp * ctx)40*eed92afdSMartin KaFai Lau int change_tcp_cc(struct bpf_iter__tcp *ctx)
41*eed92afdSMartin KaFai Lau {
42*eed92afdSMartin KaFai Lau 	char cur_cc[TCP_CA_NAME_MAX];
43*eed92afdSMartin KaFai Lau 	struct tcp_sock *tp;
44*eed92afdSMartin KaFai Lau 	struct sock *sk;
45*eed92afdSMartin KaFai Lau 
46*eed92afdSMartin KaFai Lau 	if (!bpf_tcp_sk(ctx->sk_common))
47*eed92afdSMartin KaFai Lau 		return 0;
48*eed92afdSMartin KaFai Lau 
49*eed92afdSMartin KaFai Lau 	if (sk->sk_family != AF_INET6 ||
50*eed92afdSMartin KaFai Lau 	    (sk->sk_state != TCP_LISTEN &&
51*eed92afdSMartin KaFai Lau 	     sk->sk_state != TCP_ESTABLISHED) ||
52*eed92afdSMartin KaFai Lau 	    (sk->sk_num != reuse_listen_hport &&
53*eed92afdSMartin KaFai Lau 	     sk->sk_num != listen_hport &&
54*eed92afdSMartin KaFai Lau 	     bpf_ntohs(sk->sk_dport) != listen_hport))
55*eed92afdSMartin KaFai Lau 		return 0;
56*eed92afdSMartin KaFai Lau 
57*eed92afdSMartin KaFai Lau 	if (bpf_getsockopt(tp, SOL_TCP, TCP_CONGESTION,
58*eed92afdSMartin KaFai Lau 			   cur_cc, sizeof(cur_cc)))
59*eed92afdSMartin KaFai Lau 		return 0;
60*eed92afdSMartin KaFai Lau 
61*eed92afdSMartin KaFai Lau 	if (!tcp_cc_eq(cur_cc, cubic_cc))
62*eed92afdSMartin KaFai Lau 		return 0;
63*eed92afdSMartin KaFai Lau 
64*eed92afdSMartin KaFai Lau 	if (random_retry && bpf_get_prandom_u32() % 4 == 1)
65*eed92afdSMartin KaFai Lau 		return 1;
66*eed92afdSMartin KaFai Lau 
67*eed92afdSMartin KaFai Lau 	bpf_setsockopt(tp, SOL_TCP, TCP_CONGESTION, dctcp_cc, sizeof(dctcp_cc));
68*eed92afdSMartin KaFai Lau 	return 0;
69*eed92afdSMartin KaFai Lau }
70*eed92afdSMartin KaFai Lau 
71*eed92afdSMartin KaFai Lau char _license[] SEC("license") = "GPL";
72