xref: /openbmc/linux/net/ipv4/bpf_tcp_ca.c (revision 911b8eac)
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2019 Facebook  */
3 
4 #include <linux/types.h>
5 #include <linux/bpf_verifier.h>
6 #include <linux/bpf.h>
7 #include <linux/btf.h>
8 #include <linux/filter.h>
9 #include <net/tcp.h>
10 #include <net/bpf_sk_storage.h>
11 
12 static u32 optional_ops[] = {
13 	offsetof(struct tcp_congestion_ops, init),
14 	offsetof(struct tcp_congestion_ops, release),
15 	offsetof(struct tcp_congestion_ops, set_state),
16 	offsetof(struct tcp_congestion_ops, cwnd_event),
17 	offsetof(struct tcp_congestion_ops, in_ack_event),
18 	offsetof(struct tcp_congestion_ops, pkts_acked),
19 	offsetof(struct tcp_congestion_ops, min_tso_segs),
20 	offsetof(struct tcp_congestion_ops, sndbuf_expand),
21 	offsetof(struct tcp_congestion_ops, cong_control),
22 };
23 
24 static u32 unsupported_ops[] = {
25 	offsetof(struct tcp_congestion_ops, get_info),
26 };
27 
28 static const struct btf_type *tcp_sock_type;
29 static u32 tcp_sock_id, sock_id;
30 
31 static struct bpf_func_proto btf_sk_storage_get_proto __read_mostly;
32 static struct bpf_func_proto btf_sk_storage_delete_proto __read_mostly;
33 
34 static void convert_sk_func_proto(struct bpf_func_proto *to, const struct bpf_func_proto *from)
35 {
36 	int i;
37 
38 	*to = *from;
39 	for (i = 0; i < ARRAY_SIZE(to->arg_type); i++) {
40 		if (to->arg_type[i] == ARG_PTR_TO_SOCKET) {
41 			to->arg_type[i] = ARG_PTR_TO_BTF_ID;
42 			to->arg_btf_id[i] = &tcp_sock_id;
43 		}
44 	}
45 }
46 
47 static int bpf_tcp_ca_init(struct btf *btf)
48 {
49 	s32 type_id;
50 
51 	type_id = btf_find_by_name_kind(btf, "sock", BTF_KIND_STRUCT);
52 	if (type_id < 0)
53 		return -EINVAL;
54 	sock_id = type_id;
55 
56 	type_id = btf_find_by_name_kind(btf, "tcp_sock", BTF_KIND_STRUCT);
57 	if (type_id < 0)
58 		return -EINVAL;
59 	tcp_sock_id = type_id;
60 	tcp_sock_type = btf_type_by_id(btf, tcp_sock_id);
61 
62 	convert_sk_func_proto(&btf_sk_storage_get_proto, &bpf_sk_storage_get_proto);
63 	convert_sk_func_proto(&btf_sk_storage_delete_proto, &bpf_sk_storage_delete_proto);
64 
65 	return 0;
66 }
67 
68 static bool is_optional(u32 member_offset)
69 {
70 	unsigned int i;
71 
72 	for (i = 0; i < ARRAY_SIZE(optional_ops); i++) {
73 		if (member_offset == optional_ops[i])
74 			return true;
75 	}
76 
77 	return false;
78 }
79 
80 static bool is_unsupported(u32 member_offset)
81 {
82 	unsigned int i;
83 
84 	for (i = 0; i < ARRAY_SIZE(unsupported_ops); i++) {
85 		if (member_offset == unsupported_ops[i])
86 			return true;
87 	}
88 
89 	return false;
90 }
91 
92 extern struct btf *btf_vmlinux;
93 
94 static bool bpf_tcp_ca_is_valid_access(int off, int size,
95 				       enum bpf_access_type type,
96 				       const struct bpf_prog *prog,
97 				       struct bpf_insn_access_aux *info)
98 {
99 	if (off < 0 || off >= sizeof(__u64) * MAX_BPF_FUNC_ARGS)
100 		return false;
101 	if (type != BPF_READ)
102 		return false;
103 	if (off % size != 0)
104 		return false;
105 
106 	if (!btf_ctx_access(off, size, type, prog, info))
107 		return false;
108 
109 	if (info->reg_type == PTR_TO_BTF_ID && info->btf_id == sock_id)
110 		/* promote it to tcp_sock */
111 		info->btf_id = tcp_sock_id;
112 
113 	return true;
114 }
115 
116 static int bpf_tcp_ca_btf_struct_access(struct bpf_verifier_log *log,
117 					const struct btf_type *t, int off,
118 					int size, enum bpf_access_type atype,
119 					u32 *next_btf_id)
120 {
121 	size_t end;
122 
123 	if (atype == BPF_READ)
124 		return btf_struct_access(log, t, off, size, atype, next_btf_id);
125 
126 	if (t != tcp_sock_type) {
127 		bpf_log(log, "only read is supported\n");
128 		return -EACCES;
129 	}
130 
131 	switch (off) {
132 	case bpf_ctx_range(struct inet_connection_sock, icsk_ca_priv):
133 		end = offsetofend(struct inet_connection_sock, icsk_ca_priv);
134 		break;
135 	case offsetof(struct inet_connection_sock, icsk_ack.pending):
136 		end = offsetofend(struct inet_connection_sock,
137 				  icsk_ack.pending);
138 		break;
139 	case offsetof(struct tcp_sock, snd_cwnd):
140 		end = offsetofend(struct tcp_sock, snd_cwnd);
141 		break;
142 	case offsetof(struct tcp_sock, snd_cwnd_cnt):
143 		end = offsetofend(struct tcp_sock, snd_cwnd_cnt);
144 		break;
145 	case offsetof(struct tcp_sock, snd_ssthresh):
146 		end = offsetofend(struct tcp_sock, snd_ssthresh);
147 		break;
148 	case offsetof(struct tcp_sock, ecn_flags):
149 		end = offsetofend(struct tcp_sock, ecn_flags);
150 		break;
151 	default:
152 		bpf_log(log, "no write support to tcp_sock at off %d\n", off);
153 		return -EACCES;
154 	}
155 
156 	if (off + size > end) {
157 		bpf_log(log,
158 			"write access at off %d with size %d beyond the member of tcp_sock ended at %zu\n",
159 			off, size, end);
160 		return -EACCES;
161 	}
162 
163 	return NOT_INIT;
164 }
165 
166 BPF_CALL_2(bpf_tcp_send_ack, struct tcp_sock *, tp, u32, rcv_nxt)
167 {
168 	/* bpf_tcp_ca prog cannot have NULL tp */
169 	__tcp_send_ack((struct sock *)tp, rcv_nxt);
170 	return 0;
171 }
172 
173 static const struct bpf_func_proto bpf_tcp_send_ack_proto = {
174 	.func		= bpf_tcp_send_ack,
175 	.gpl_only	= false,
176 	/* In case we want to report error later */
177 	.ret_type	= RET_INTEGER,
178 	.arg1_type	= ARG_PTR_TO_BTF_ID,
179 	.arg1_btf_id	= &tcp_sock_id,
180 	.arg2_type	= ARG_ANYTHING,
181 };
182 
183 static const struct bpf_func_proto *
184 bpf_tcp_ca_get_func_proto(enum bpf_func_id func_id,
185 			  const struct bpf_prog *prog)
186 {
187 	switch (func_id) {
188 	case BPF_FUNC_tcp_send_ack:
189 		return &bpf_tcp_send_ack_proto;
190 	case BPF_FUNC_sk_storage_get:
191 		return &btf_sk_storage_get_proto;
192 	case BPF_FUNC_sk_storage_delete:
193 		return &btf_sk_storage_delete_proto;
194 	default:
195 		return bpf_base_func_proto(func_id);
196 	}
197 }
198 
199 static const struct bpf_verifier_ops bpf_tcp_ca_verifier_ops = {
200 	.get_func_proto		= bpf_tcp_ca_get_func_proto,
201 	.is_valid_access	= bpf_tcp_ca_is_valid_access,
202 	.btf_struct_access	= bpf_tcp_ca_btf_struct_access,
203 };
204 
205 static int bpf_tcp_ca_init_member(const struct btf_type *t,
206 				  const struct btf_member *member,
207 				  void *kdata, const void *udata)
208 {
209 	const struct tcp_congestion_ops *utcp_ca;
210 	struct tcp_congestion_ops *tcp_ca;
211 	int prog_fd;
212 	u32 moff;
213 
214 	utcp_ca = (const struct tcp_congestion_ops *)udata;
215 	tcp_ca = (struct tcp_congestion_ops *)kdata;
216 
217 	moff = btf_member_bit_offset(t, member) / 8;
218 	switch (moff) {
219 	case offsetof(struct tcp_congestion_ops, flags):
220 		if (utcp_ca->flags & ~TCP_CONG_MASK)
221 			return -EINVAL;
222 		tcp_ca->flags = utcp_ca->flags;
223 		return 1;
224 	case offsetof(struct tcp_congestion_ops, name):
225 		if (bpf_obj_name_cpy(tcp_ca->name, utcp_ca->name,
226 				     sizeof(tcp_ca->name)) <= 0)
227 			return -EINVAL;
228 		if (tcp_ca_find(utcp_ca->name))
229 			return -EEXIST;
230 		return 1;
231 	}
232 
233 	if (!btf_type_resolve_func_ptr(btf_vmlinux, member->type, NULL))
234 		return 0;
235 
236 	/* Ensure bpf_prog is provided for compulsory func ptr */
237 	prog_fd = (int)(*(unsigned long *)(udata + moff));
238 	if (!prog_fd && !is_optional(moff) && !is_unsupported(moff))
239 		return -EINVAL;
240 
241 	return 0;
242 }
243 
244 static int bpf_tcp_ca_check_member(const struct btf_type *t,
245 				   const struct btf_member *member)
246 {
247 	if (is_unsupported(btf_member_bit_offset(t, member) / 8))
248 		return -ENOTSUPP;
249 	return 0;
250 }
251 
/* struct_ops ->reg hook: hand @kdata to tcp_register_congestion_control(). */
static int bpf_tcp_ca_reg(void *kdata)
{
	struct tcp_congestion_ops *ca = kdata;

	return tcp_register_congestion_control(ca);
}

/* struct_ops ->unreg hook: hand @kdata to tcp_unregister_congestion_control(). */
static void bpf_tcp_ca_unreg(void *kdata)
{
	struct tcp_congestion_ops *ca = kdata;

	tcp_unregister_congestion_control(ca);
}
261 
262 /* Avoid sparse warning.  It is only used in bpf_struct_ops.c. */
263 extern struct bpf_struct_ops bpf_tcp_congestion_ops;
264 
265 struct bpf_struct_ops bpf_tcp_congestion_ops = {
266 	.verifier_ops = &bpf_tcp_ca_verifier_ops,
267 	.reg = bpf_tcp_ca_reg,
268 	.unreg = bpf_tcp_ca_unreg,
269 	.check_member = bpf_tcp_ca_check_member,
270 	.init_member = bpf_tcp_ca_init_member,
271 	.init = bpf_tcp_ca_init,
272 	.name = "tcp_congestion_ops",
273 };
274