xref: /openbmc/linux/net/ipv4/inet_diag.c (revision 8622a0e5)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * inet_diag.c	Module for monitoring INET transport protocols sockets.
4  *
5  * Authors:	Alexey Kuznetsov, <kuznet@ms2.inr.ac.ru>
6  */
7 
8 #include <linux/kernel.h>
9 #include <linux/module.h>
10 #include <linux/types.h>
11 #include <linux/fcntl.h>
12 #include <linux/random.h>
13 #include <linux/slab.h>
14 #include <linux/cache.h>
15 #include <linux/init.h>
16 #include <linux/time.h>
17 
18 #include <net/icmp.h>
19 #include <net/tcp.h>
20 #include <net/ipv6.h>
21 #include <net/inet_common.h>
22 #include <net/inet_connection_sock.h>
23 #include <net/inet_hashtables.h>
24 #include <net/inet_timewait_sock.h>
25 #include <net/inet6_hashtables.h>
26 #include <net/bpf_sk_storage.h>
27 #include <net/netlink.h>
28 
29 #include <linux/inet.h>
30 #include <linux/stddef.h>
31 
32 #include <linux/inet_diag.h>
33 #include <linux/sock_diag.h>
34 
/* Registered per-protocol diag handlers, indexed by IP protocol number.
 * Protected by inet_diag_table_mutex below.
 */
static const struct inet_diag_handler **inet_diag_table;

/* Filter state extracted from one socket, consumed by the bytecode
 * interpreter in inet_diag_bc_run().
 */
struct inet_diag_entry {
	const __be32 *saddr;	/* source address words (1 for IPv4, 4 for IPv6) */
	const __be32 *daddr;	/* destination address words */
	u16 sport;		/* local port, host byte order */
	u16 dport;		/* remote port, host byte order */
	u16 family;		/* sk->sk_family */
	u16 userlocks;		/* sk->sk_userlocks, 0 for non-full sockets */
	u32 ifindex;		/* sk->sk_bound_dev_if */
	u32 mark;		/* sk_mark / ir_mark, 0 when unavailable */
};

static DEFINE_MUTEX(inet_diag_table_mutex);
49 
/* Look up the diag handler for @proto, trying to autoload the protocol's
 * diag module first if none is registered yet.
 *
 * Returns the handler, or ERR_PTR(-ENOENT) if no handler exists.  In BOTH
 * cases inet_diag_table_mutex is left held; the caller must always pair
 * this with inet_diag_unlock_handler().
 */
static const struct inet_diag_handler *inet_diag_lock_handler(int proto)
{
	if (!inet_diag_table[proto])
		sock_load_diag_module(AF_INET, proto);

	mutex_lock(&inet_diag_table_mutex);
	if (!inet_diag_table[proto])
		return ERR_PTR(-ENOENT);

	return inet_diag_table[proto];
}
61 
/* Release the mutex taken by inet_diag_lock_handler().  @handler itself is
 * unused and may even be an ERR_PTR; only the global mutex matters.
 */
static void inet_diag_unlock_handler(const struct inet_diag_handler *handler)
{
	mutex_unlock(&inet_diag_table_mutex);
}
66 
/* Fill the family, ports, ifindex, cookie and addresses of @r from @sk.
 * Usable for full, timewait and request sockets alike: the fields read
 * here share offsets across those socket flavours (see twsk_build_assert()).
 */
void inet_diag_msg_common_fill(struct inet_diag_msg *r, struct sock *sk)
{
	r->idiag_family = sk->sk_family;

	r->id.idiag_sport = htons(sk->sk_num);
	r->id.idiag_dport = sk->sk_dport;
	r->id.idiag_if = sk->sk_bound_dev_if;
	sock_diag_save_cookie(sk, r->id.idiag_cookie);

#if IS_ENABLED(CONFIG_IPV6)
	if (sk->sk_family == AF_INET6) {
		*(struct in6_addr *)r->id.idiag_src = sk->sk_v6_rcv_saddr;
		*(struct in6_addr *)r->id.idiag_dst = sk->sk_v6_daddr;
	} else
#endif
	{
	/* IPv4: zero the full 16-byte slots, then place the address in
	 * the first 32-bit word.
	 */
	memset(&r->id.idiag_src, 0, sizeof(r->id.idiag_src));
	memset(&r->id.idiag_dst, 0, sizeof(r->id.idiag_dst));

	r->id.idiag_src[0] = sk->sk_rcv_saddr;
	r->id.idiag_dst[0] = sk->sk_daddr;
	}
}
EXPORT_SYMBOL_GPL(inet_diag_msg_common_fill);
91 
/* Upper-bound the netlink payload needed to dump one socket, so that
 * inet_diag_dump_one_icsk() can size the reply skb up front.
 */
static size_t inet_sk_attr_size(struct sock *sk,
				const struct inet_diag_req_v2 *req,
				bool net_admin)
{
	const struct inet_diag_handler *handler;
	size_t aux = 0;

	/* NOTE(review): inet_diag_table[] is read here without
	 * inet_diag_table_mutex - presumably safe because the caller runs
	 * under inet_diag_lock_handler(); confirm against callers.
	 */
	handler = inet_diag_table[req->sdiag_protocol];
	if (handler && handler->idiag_get_aux_size)
		aux = handler->idiag_get_aux_size(sk, net_admin);

	/* Sum of worst-case attribute sizes, plus 64 bytes of slack. */
	return	  nla_total_size(sizeof(struct tcp_info))
		+ nla_total_size(1) /* INET_DIAG_SHUTDOWN */
		+ nla_total_size(1) /* INET_DIAG_TOS */
		+ nla_total_size(1) /* INET_DIAG_TCLASS */
		+ nla_total_size(4) /* INET_DIAG_MARK */
		+ nla_total_size(4) /* INET_DIAG_CLASS_ID */
		+ nla_total_size(sizeof(struct inet_diag_meminfo))
		+ nla_total_size(sizeof(struct inet_diag_msg))
		+ nla_total_size(SK_MEMINFO_VARS * sizeof(u32))
		+ nla_total_size(TCP_CA_NAME_MAX)
		+ nla_total_size(sizeof(struct tcpvegas_info))
		+ aux
		+ 64;
}
117 
/* Emit the generic per-socket attributes (shutdown state, TOS/TCLASS,
 * v6-only flag, mark) into @skb and fill the uid/inode fields of @r.
 *
 * Returns 0 on success, 1 if the skb ran out of room (note: NOT a
 * negative errno - callers test for non-zero).
 */
int inet_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb,
			     struct inet_diag_msg *r, int ext,
			     struct user_namespace *user_ns,
			     bool net_admin)
{
	const struct inet_sock *inet = inet_sk(sk);

	if (nla_put_u8(skb, INET_DIAG_SHUTDOWN, sk->sk_shutdown))
		goto errout;

	/* IPv6 dual-stack sockets use inet->tos for IPv4 connections,
	 * hence this needs to be included regardless of socket family.
	 */
	if (ext & (1 << (INET_DIAG_TOS - 1)))
		if (nla_put_u8(skb, INET_DIAG_TOS, inet->tos) < 0)
			goto errout;

#if IS_ENABLED(CONFIG_IPV6)
	if (r->idiag_family == AF_INET6) {
		if (ext & (1 << (INET_DIAG_TCLASS - 1)))
			if (nla_put_u8(skb, INET_DIAG_TCLASS,
				       inet6_sk(sk)->tclass) < 0)
				goto errout;

		/* v6-only flag is only meaningful before/after a
		 * connection exists.
		 */
		if (((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) &&
		    nla_put_u8(skb, INET_DIAG_SKV6ONLY, ipv6_only_sock(sk)))
			goto errout;
	}
#endif

	/* The mark is privileged information. */
	if (net_admin && nla_put_u32(skb, INET_DIAG_MARK, sk->sk_mark))
		goto errout;

	/* uid is translated into the requester's user namespace. */
	r->idiag_uid = from_kuid_munged(user_ns, sock_i_uid(sk));
	r->idiag_inode = sock_i_ino(sk);

	return 0;
errout:
	return 1;
}
EXPORT_SYMBOL_GPL(inet_diag_msg_attrs_fill);
159 
/* Largest min_dump_alloc we will ever request: one kmalloc-able skb minus
 * the skb_shared_info overhead.
 */
#define MAX_DUMP_ALLOC_SIZE (KMALLOC_MAX_SIZE - SKB_DATA_ALIGN(sizeof(struct skb_shared_info)))
161 
/* Build one complete inet_diag netlink message for a full socket @sk.
 *
 * @icsk:	the connection-socket view of @sk, or NULL for protocols
 *		without one (then only idiag_get_info() is called).
 * @nlmsg_flags: e.g. NLM_F_MULTI for dump replies.
 *
 * Returns 0 on success or -EMSGSIZE if @skb could not hold the message
 * (the partial message is cancelled).
 */
int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
		      struct sk_buff *skb, struct netlink_callback *cb,
		      const struct inet_diag_req_v2 *req,
		      u16 nlmsg_flags, bool net_admin)
{
	const struct tcp_congestion_ops *ca_ops;
	const struct inet_diag_handler *handler;
	struct inet_diag_dump_data *cb_data;
	int ext = req->idiag_ext;
	struct inet_diag_msg *r;
	struct nlmsghdr  *nlh;
	struct nlattr *attr;
	void *info = NULL;

	cb_data = cb->data;
	handler = inet_diag_table[req->sdiag_protocol];
	BUG_ON(!handler);

	nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
			cb->nlh->nlmsg_type, sizeof(*r), nlmsg_flags);
	if (!nlh)
		return -EMSGSIZE;

	r = nlmsg_data(nlh);
	/* Timewait/request sockets go through their own fill routines. */
	BUG_ON(!sk_fullsock(sk));

	inet_diag_msg_common_fill(r, sk);
	r->idiag_state = sk->sk_state;
	r->idiag_timer = 0;
	r->idiag_retrans = 0;

	if (inet_diag_msg_attrs_fill(sk, skb, r, ext,
				     sk_user_ns(NETLINK_CB(cb->skb).sk),
				     net_admin))
		goto errout;

	if (ext & (1 << (INET_DIAG_MEMINFO - 1))) {
		struct inet_diag_meminfo minfo = {
			.idiag_rmem = sk_rmem_alloc_get(sk),
			.idiag_wmem = READ_ONCE(sk->sk_wmem_queued),
			.idiag_fmem = sk->sk_forward_alloc,
			.idiag_tmem = sk_wmem_alloc_get(sk),
		};

		if (nla_put(skb, INET_DIAG_MEMINFO, sizeof(minfo), &minfo) < 0)
			goto errout;
	}

	if (ext & (1 << (INET_DIAG_SKMEMINFO - 1)))
		if (sock_diag_put_meminfo(sk, skb, INET_DIAG_SKMEMINFO))
			goto errout;

	/*
	 * RAW sockets might have user-defined protocols assigned,
	 * so report the one supplied on socket creation.
	 */
	if (sk->sk_type == SOCK_RAW) {
		if (nla_put_u8(skb, INET_DIAG_PROTOCOL, sk->sk_protocol))
			goto errout;
	}

	if (!icsk) {
		handler->idiag_get_info(sk, r, NULL);
		goto out;
	}

	/* Encode which timer is pending using the classic tcp_diag codes:
	 * 1 = retransmit, 2 = keepalive, 4 = zero-window probe.
	 */
	if (icsk->icsk_pending == ICSK_TIME_RETRANS ||
	    icsk->icsk_pending == ICSK_TIME_REO_TIMEOUT ||
	    icsk->icsk_pending == ICSK_TIME_LOSS_PROBE) {
		r->idiag_timer = 1;
		r->idiag_retrans = icsk->icsk_retransmits;
		r->idiag_expires =
			jiffies_delta_to_msecs(icsk->icsk_timeout - jiffies);
	} else if (icsk->icsk_pending == ICSK_TIME_PROBE0) {
		r->idiag_timer = 4;
		r->idiag_retrans = icsk->icsk_probes_out;
		r->idiag_expires =
			jiffies_delta_to_msecs(icsk->icsk_timeout - jiffies);
	} else if (timer_pending(&sk->sk_timer)) {
		r->idiag_timer = 2;
		r->idiag_retrans = icsk->icsk_probes_out;
		r->idiag_expires =
			jiffies_delta_to_msecs(sk->sk_timer.expires - jiffies);
	} else {
		r->idiag_timer = 0;
		r->idiag_expires = 0;
	}

	/* Reserve space for the protocol info blob; it is filled in by
	 * idiag_get_info() below.
	 */
	if ((ext & (1 << (INET_DIAG_INFO - 1))) && handler->idiag_info_size) {
		attr = nla_reserve_64bit(skb, INET_DIAG_INFO,
					 handler->idiag_info_size,
					 INET_DIAG_PAD);
		if (!attr)
			goto errout;

		info = nla_data(attr);
	}

	if (ext & (1 << (INET_DIAG_CONG - 1))) {
		int err = 0;

		/* ca_ops may be switched concurrently; pin it under RCU. */
		rcu_read_lock();
		ca_ops = READ_ONCE(icsk->icsk_ca_ops);
		if (ca_ops)
			err = nla_put_string(skb, INET_DIAG_CONG, ca_ops->name);
		rcu_read_unlock();
		if (err < 0)
			goto errout;
	}

	handler->idiag_get_info(sk, r, info);

	if (ext & (1 << (INET_DIAG_INFO - 1)) && handler->idiag_get_aux)
		if (handler->idiag_get_aux(sk, net_admin, skb) < 0)
			goto errout;

	if (sk->sk_state < TCP_TIME_WAIT) {
		union tcp_cc_info info;
		size_t sz = 0;
		int attr;

		rcu_read_lock();
		ca_ops = READ_ONCE(icsk->icsk_ca_ops);
		if (ca_ops && ca_ops->get_info)
			sz = ca_ops->get_info(sk, ext, &attr, &info);
		rcu_read_unlock();
		/* get_info() picked the attribute type itself. */
		if (sz && nla_put(skb, attr, sz, &info) < 0)
			goto errout;
	}

	if (ext & (1 << (INET_DIAG_CLASS_ID - 1)) ||
	    ext & (1 << (INET_DIAG_TCLASS - 1))) {
		u32 classid = 0;

#ifdef CONFIG_SOCK_CGROUP_DATA
		classid = sock_cgroup_classid(&sk->sk_cgrp_data);
#endif
		/* Fallback to socket priority if class id isn't set.
		 * Classful qdiscs use it as direct reference to class.
		 * For cgroup2 classid is always zero.
		 */
		if (!classid)
			classid = sk->sk_priority;

		if (nla_put_u32(skb, INET_DIAG_CLASS_ID, classid))
			goto errout;
	}

	/* Keep it at the end for potential retry with a larger skb,
	 * or else do best-effort fitting, which is only done for the
	 * first_nlmsg.
	 */
	if (cb_data->bpf_stg_diag) {
		bool first_nlmsg = ((unsigned char *)nlh == skb->data);
		unsigned int prev_min_dump_alloc;
		unsigned int total_nla_size = 0;
		unsigned int msg_len;
		int err;

		msg_len = skb_tail_pointer(skb) - (unsigned char *)nlh;
		err = bpf_sk_storage_diag_put(cb_data->bpf_stg_diag, sk, skb,
					      INET_DIAG_SK_BPF_STORAGES,
					      &total_nla_size);

		if (!err)
			goto out;

		/* The storages did not fit.  Ask the core to allocate a
		 * bigger skb next time (capped at MAX_DUMP_ALLOC_SIZE).
		 */
		total_nla_size += msg_len;
		prev_min_dump_alloc = cb->min_dump_alloc;
		if (total_nla_size > prev_min_dump_alloc)
			cb->min_dump_alloc = min_t(u32, total_nla_size,
						   MAX_DUMP_ALLOC_SIZE);

		if (!first_nlmsg)
			goto errout;

		if (cb->min_dump_alloc > prev_min_dump_alloc)
			/* Retry with pskb_expand_head() with
			 * __GFP_DIRECT_RECLAIM
			 */
			goto errout;

		WARN_ON_ONCE(total_nla_size <= prev_min_dump_alloc);

		/* Send what we have for this sk
		 * and move on to the next sk in the following
		 * dump()
		 */
	}

out:
	nlmsg_end(skb, nlh);
	return 0;

errout:
	nlmsg_cancel(skb, nlh);
	return -EMSGSIZE;
}
EXPORT_SYMBOL_GPL(inet_sk_diag_fill);
361 
/* Build a diag message for a TIME_WAIT socket.  Most fields have no
 * meaning for timewait sockets and are reported as zero.
 * Returns 0 on success or -EMSGSIZE if @skb is full.
 */
static int inet_twsk_diag_fill(struct sock *sk,
			       struct sk_buff *skb,
			       struct netlink_callback *cb,
			       u16 nlmsg_flags)
{
	struct inet_timewait_sock *tw = inet_twsk(sk);
	struct inet_diag_msg *r;
	struct nlmsghdr *nlh;
	long tmo;

	nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid,
			cb->nlh->nlmsg_seq, cb->nlh->nlmsg_type,
			sizeof(*r), nlmsg_flags);
	if (!nlh)
		return -EMSGSIZE;

	r = nlmsg_data(nlh);
	BUG_ON(tw->tw_state != TCP_TIME_WAIT);

	inet_diag_msg_common_fill(r, sk);
	r->idiag_retrans      = 0;

	/* Report the substate (TIME_WAIT vs FIN_WAIT2) as the state. */
	r->idiag_state	      = tw->tw_substate;
	r->idiag_timer	      = 3;	/* 3 == timewait timer */
	tmo = tw->tw_timer.expires - jiffies;
	r->idiag_expires      = jiffies_delta_to_msecs(tmo);
	r->idiag_rqueue	      = 0;
	r->idiag_wqueue	      = 0;
	r->idiag_uid	      = 0;
	r->idiag_inode	      = 0;

	nlmsg_end(skb, nlh);
	return 0;
}
396 
/* Build a diag message for a request (TCP_NEW_SYN_RECV) socket, reported
 * to userspace as SYN_RECV.  Returns 0 or -EMSGSIZE.
 */
static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb,
			      struct netlink_callback *cb,
			      u16 nlmsg_flags, bool net_admin)
{
	struct request_sock *reqsk = inet_reqsk(sk);
	struct inet_diag_msg *r;
	struct nlmsghdr *nlh;
	long tmo;

	nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
			cb->nlh->nlmsg_type, sizeof(*r), nlmsg_flags);
	if (!nlh)
		return -EMSGSIZE;

	r = nlmsg_data(nlh);
	inet_diag_msg_common_fill(r, sk);
	r->idiag_state = TCP_SYN_RECV;
	r->idiag_timer = 1;	/* 1 == retransmit timer */
	r->idiag_retrans = reqsk->num_retrans;

	/* inet_diag_msg_common_fill() read the cookie via struct sock;
	 * make sure the request-sock layout really lines up.
	 */
	BUILD_BUG_ON(offsetof(struct inet_request_sock, ir_cookie) !=
		     offsetof(struct sock, sk_cookie));

	tmo = inet_reqsk(sk)->rsk_timer.expires - jiffies;
	r->idiag_expires = jiffies_delta_to_msecs(tmo);
	r->idiag_rqueue	= 0;
	r->idiag_wqueue	= 0;
	r->idiag_uid	= 0;
	r->idiag_inode	= 0;

	/* NOTE(review): unlike inet_sk_diag_fill(), this error path does
	 * not nlmsg_cancel() the partially built message - verify callers
	 * tolerate the truncated message left in @skb.
	 */
	if (net_admin && nla_put_u32(skb, INET_DIAG_MARK,
				     inet_rsk(reqsk)->ir_mark))
		return -EMSGSIZE;

	nlmsg_end(skb, nlh);
	return 0;
}
434 
435 static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
436 			struct netlink_callback *cb,
437 			const struct inet_diag_req_v2 *r,
438 			u16 nlmsg_flags, bool net_admin)
439 {
440 	if (sk->sk_state == TCP_TIME_WAIT)
441 		return inet_twsk_diag_fill(sk, skb, cb, nlmsg_flags);
442 
443 	if (sk->sk_state == TCP_NEW_SYN_RECV)
444 		return inet_req_diag_fill(sk, skb, cb, nlmsg_flags, net_admin);
445 
446 	return inet_sk_diag_fill(sk, inet_csk(sk), skb, cb, r, nlmsg_flags,
447 				 net_admin);
448 }
449 
/* Look up the single socket matching @req's 4-tuple (and cookie) in
 * @hashinfo.  IPv4-mapped IPv6 addresses are routed through the IPv4
 * lookup.
 *
 * Returns a reference-holding socket pointer (release with
 * sock_gen_put()), or ERR_PTR(-EINVAL) for an unknown family,
 * ERR_PTR(-ENOENT) if not found or the cookie does not match.
 */
struct sock *inet_diag_find_one_icsk(struct net *net,
				     struct inet_hashinfo *hashinfo,
				     const struct inet_diag_req_v2 *req)
{
	struct sock *sk;

	rcu_read_lock();
	if (req->sdiag_family == AF_INET)
		sk = inet_lookup(net, hashinfo, NULL, 0, req->id.idiag_dst[0],
				 req->id.idiag_dport, req->id.idiag_src[0],
				 req->id.idiag_sport, req->id.idiag_if);
#if IS_ENABLED(CONFIG_IPV6)
	else if (req->sdiag_family == AF_INET6) {
		/* v4-mapped addresses (::ffff:a.b.c.d) live in the IPv4
		 * hash; word 3 holds the IPv4 address.
		 */
		if (ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_dst) &&
		    ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_src))
			sk = inet_lookup(net, hashinfo, NULL, 0, req->id.idiag_dst[3],
					 req->id.idiag_dport, req->id.idiag_src[3],
					 req->id.idiag_sport, req->id.idiag_if);
		else
			sk = inet6_lookup(net, hashinfo, NULL, 0,
					  (struct in6_addr *)req->id.idiag_dst,
					  req->id.idiag_dport,
					  (struct in6_addr *)req->id.idiag_src,
					  req->id.idiag_sport,
					  req->id.idiag_if);
	}
#endif
	else {
		rcu_read_unlock();
		return ERR_PTR(-EINVAL);
	}
	rcu_read_unlock();
	if (!sk)
		return ERR_PTR(-ENOENT);

	/* The cookie disambiguates sockets reusing the same 4-tuple. */
	if (sock_diag_check_cookie(sk, req->id.idiag_cookie)) {
		sock_gen_put(sk);
		return ERR_PTR(-ENOENT);
	}

	return sk;
}
EXPORT_SYMBOL_GPL(inet_diag_find_one_icsk);
493 
/* Handle a "get exact socket" request: look the socket up, build one
 * diag message in a freshly sized skb and unicast it back to the
 * requester.  Returns 0 or a negative errno.
 */
int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo,
			    struct netlink_callback *cb,
			    const struct inet_diag_req_v2 *req)
{
	struct sk_buff *in_skb = cb->skb;
	bool net_admin = netlink_net_capable(in_skb, CAP_NET_ADMIN);
	struct net *net = sock_net(in_skb->sk);
	struct sk_buff *rep;
	struct sock *sk;
	int err;

	sk = inet_diag_find_one_icsk(net, hashinfo, req);
	if (IS_ERR(sk))
		return PTR_ERR(sk);

	/* Size the reply for the worst case so the fill cannot overflow. */
	rep = nlmsg_new(inet_sk_attr_size(sk, req, net_admin), GFP_KERNEL);
	if (!rep) {
		err = -ENOMEM;
		goto out;
	}

	err = sk_diag_fill(sk, rep, cb, req, 0, net_admin);
	if (err < 0) {
		/* -EMSGSIZE cannot happen: the skb was sized above. */
		WARN_ON(err == -EMSGSIZE);
		nlmsg_free(rep);
		goto out;
	}
	err = netlink_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).portid,
			      MSG_DONTWAIT);
	if (err > 0)
		err = 0;

out:
	/* sk is never NULL here (ERR_PTR was handled above); the check is
	 * purely defensive.
	 */
	if (sk)
		sock_gen_put(sk);

	return err;
}
EXPORT_SYMBOL_GPL(inet_diag_dump_one_icsk);
533 
/* Execute a non-dump command (SOCK_DIAG_BY_FAMILY single-socket get, or
 * SOCK_DESTROY) under the handler lock.  Returns 0 or a negative errno.
 */
static int inet_diag_cmd_exact(int cmd, struct sk_buff *in_skb,
			       const struct nlmsghdr *nlh,
			       const struct inet_diag_req_v2 *req)
{
	const struct inet_diag_handler *handler;
	int err;

	handler = inet_diag_lock_handler(req->sdiag_protocol);
	if (IS_ERR(handler)) {
		err = PTR_ERR(handler);
	} else if (cmd == SOCK_DIAG_BY_FAMILY) {
		/* dump_one() expects a callback with valid cb->data even
		 * though this is not a real dump.
		 */
		struct inet_diag_dump_data empty_dump_data = {};
		struct netlink_callback cb = {
			.nlh = nlh,
			.skb = in_skb,
			.data = &empty_dump_data,
		};
		err = handler->dump_one(&cb, req);
	} else if (cmd == SOCK_DESTROY && handler->destroy) {
		err = handler->destroy(in_skb, req);
	} else {
		err = -EOPNOTSUPP;
	}
	/* Always paired with inet_diag_lock_handler(), even on ERR_PTR. */
	inet_diag_unlock_handler(handler);

	return err;
}
561 
562 static int bitstring_match(const __be32 *a1, const __be32 *a2, int bits)
563 {
564 	int words = bits >> 5;
565 
566 	bits &= 0x1f;
567 
568 	if (words) {
569 		if (memcmp(a1, a2, words << 2))
570 			return 0;
571 	}
572 	if (bits) {
573 		__be32 w1, w2;
574 		__be32 mask;
575 
576 		w1 = a1[words];
577 		w2 = a2[words];
578 
579 		mask = htonl((0xffffffff) << (32 - bits));
580 
581 		if ((w1 ^ w2) & mask)
582 			return 0;
583 	}
584 
585 	return 1;
586 }
587 
/* Interpret the user-supplied filter bytecode against @entry.
 *
 * Each op carries two relative jump offsets: op->yes is taken when the
 * condition holds, op->no when it does not.  A jump past the end of the
 * program (len reaches 0) accepts the socket; jumping beyond the end
 * (len < 0) rejects it.  Returns 1 to accept, 0 to reject.
 *
 * The bytecode has already been validated by inet_diag_bc_audit(), so no
 * bounds checks are needed here.
 */
static int inet_diag_bc_run(const struct nlattr *_bc,
			    const struct inet_diag_entry *entry)
{
	const void *bc = nla_data(_bc);
	int len = nla_len(_bc);

	while (len > 0) {
		int yes = 1;
		const struct inet_diag_bc_op *op = bc;

		switch (op->code) {
		case INET_DIAG_BC_NOP:
			break;
		case INET_DIAG_BC_JMP:
			/* Unconditional jump: always take the "no" branch. */
			yes = 0;
			break;
		/* Port comparisons store the operand in op[1].no. */
		case INET_DIAG_BC_S_EQ:
			yes = entry->sport == op[1].no;
			break;
		case INET_DIAG_BC_S_GE:
			yes = entry->sport >= op[1].no;
			break;
		case INET_DIAG_BC_S_LE:
			yes = entry->sport <= op[1].no;
			break;
		case INET_DIAG_BC_D_EQ:
			yes = entry->dport == op[1].no;
			break;
		case INET_DIAG_BC_D_GE:
			yes = entry->dport >= op[1].no;
			break;
		case INET_DIAG_BC_D_LE:
			yes = entry->dport <= op[1].no;
			break;
		case INET_DIAG_BC_AUTO:
			/* Match sockets with an auto-assigned (not
			 * explicitly bound) port.
			 */
			yes = !(entry->userlocks & SOCK_BINDPORT_LOCK);
			break;
		case INET_DIAG_BC_S_COND:
		case INET_DIAG_BC_D_COND: {
			const struct inet_diag_hostcond *cond;
			const __be32 *addr;

			cond = (const struct inet_diag_hostcond *)(op + 1);
			/* Optional port filter: -1 means "any port". */
			if (cond->port != -1 &&
			    cond->port != (op->code == INET_DIAG_BC_S_COND ?
					     entry->sport : entry->dport)) {
				yes = 0;
				break;
			}

			if (op->code == INET_DIAG_BC_S_COND)
				addr = entry->saddr;
			else
				addr = entry->daddr;

			if (cond->family != AF_UNSPEC &&
			    cond->family != entry->family) {
				/* Allow an AF_INET condition to match the
				 * IPv4 part of a v4-mapped IPv6 address.
				 */
				if (entry->family == AF_INET6 &&
				    cond->family == AF_INET) {
					if (addr[0] == 0 && addr[1] == 0 &&
					    addr[2] == htonl(0xffff) &&
					    bitstring_match(addr + 3,
							    cond->addr,
							    cond->prefix_len))
						break;
				}
				yes = 0;
				break;
			}

			/* Zero prefix length matches any address. */
			if (cond->prefix_len == 0)
				break;
			if (bitstring_match(addr, cond->addr,
					    cond->prefix_len))
				break;
			yes = 0;
			break;
		}
		case INET_DIAG_BC_DEV_COND: {
			u32 ifindex;

			ifindex = *((const u32 *)(op + 1));
			if (ifindex != entry->ifindex)
				yes = 0;
			break;
		}
		case INET_DIAG_BC_MARK_COND: {
			struct inet_diag_markcond *cond;

			cond = (struct inet_diag_markcond *)(op + 1);
			if ((entry->mark & cond->mask) != cond->mark)
				yes = 0;
			break;
		}
		}

		if (yes) {
			len -= op->yes;
			bc += op->yes;
		} else {
			len -= op->no;
			bc += op->no;
		}
	}
	return len == 0;
}
694 
695 /* This helper is available for all sockets (ESTABLISH, TIMEWAIT, SYN_RECV)
696  */
/* Point entry->saddr/daddr at @sk's address words.  Works for all socket
 * flavours (ESTABLISH, TIMEWAIT, SYN_RECV) because the accessed fields
 * share offsets (see twsk_build_assert()).
 */
static void entry_fill_addrs(struct inet_diag_entry *entry,
			     const struct sock *sk)
{
#if IS_ENABLED(CONFIG_IPV6)
	if (sk->sk_family == AF_INET6) {
		entry->saddr = sk->sk_v6_rcv_saddr.s6_addr32;
		entry->daddr = sk->sk_v6_daddr.s6_addr32;
	} else
#endif
	{
		entry->saddr = &sk->sk_rcv_saddr;
		entry->daddr = &sk->sk_daddr;
	}
}
711 
/* Run the filter bytecode @bc against socket @sk.  A NULL bytecode
 * accepts everything.  Returns 1 to accept, 0 to reject.
 */
int inet_diag_bc_sk(const struct nlattr *bc, struct sock *sk)
{
	struct inet_sock *inet = inet_sk(sk);
	struct inet_diag_entry entry;

	if (!bc)
		return 1;

	entry.family = sk->sk_family;
	entry_fill_addrs(&entry, sk);
	entry.sport = inet->inet_num;
	entry.dport = ntohs(inet->inet_dport);
	entry.ifindex = sk->sk_bound_dev_if;
	/* Only full sockets have userlocks; mark depends on the flavour. */
	entry.userlocks = sk_fullsock(sk) ? sk->sk_userlocks : 0;
	if (sk_fullsock(sk))
		entry.mark = sk->sk_mark;
	else if (sk->sk_state == TCP_NEW_SYN_RECV)
		entry.mark = inet_rsk(inet_reqsk(sk))->ir_mark;
	else
		entry.mark = 0;

	return inet_diag_bc_run(bc, &entry);
}
EXPORT_SYMBOL_GPL(inet_diag_bc_sk);
736 
737 static int valid_cc(const void *bc, int len, int cc)
738 {
739 	while (len >= 0) {
740 		const struct inet_diag_bc_op *op = bc;
741 
742 		if (cc > len)
743 			return 0;
744 		if (cc == len)
745 			return 1;
746 		if (op->yes < 4 || op->yes & 3)
747 			return 0;
748 		len -= op->yes;
749 		bc  += op->yes;
750 	}
751 	return 0;
752 }
753 
754 /* data is u32 ifindex */
755 static bool valid_devcond(const struct inet_diag_bc_op *op, int len,
756 			  int *min_len)
757 {
758 	/* Check ifindex space. */
759 	*min_len += sizeof(u32);
760 	if (len < *min_len)
761 		return false;
762 
763 	return true;
764 }
765 /* Validate an inet_diag_hostcond. */
766 static bool valid_hostcond(const struct inet_diag_bc_op *op, int len,
767 			   int *min_len)
768 {
769 	struct inet_diag_hostcond *cond;
770 	int addr_len;
771 
772 	/* Check hostcond space. */
773 	*min_len += sizeof(struct inet_diag_hostcond);
774 	if (len < *min_len)
775 		return false;
776 	cond = (struct inet_diag_hostcond *)(op + 1);
777 
778 	/* Check address family and address length. */
779 	switch (cond->family) {
780 	case AF_UNSPEC:
781 		addr_len = 0;
782 		break;
783 	case AF_INET:
784 		addr_len = sizeof(struct in_addr);
785 		break;
786 	case AF_INET6:
787 		addr_len = sizeof(struct in6_addr);
788 		break;
789 	default:
790 		return false;
791 	}
792 	*min_len += addr_len;
793 	if (len < *min_len)
794 		return false;
795 
796 	/* Check prefix length (in bits) vs address length (in bytes). */
797 	if (cond->prefix_len > 8 * addr_len)
798 		return false;
799 
800 	return true;
801 }
802 
803 /* Validate a port comparison operator. */
804 static bool valid_port_comparison(const struct inet_diag_bc_op *op,
805 				  int len, int *min_len)
806 {
807 	/* Port comparisons put the port in a follow-on inet_diag_bc_op. */
808 	*min_len += sizeof(struct inet_diag_bc_op);
809 	if (len < *min_len)
810 		return false;
811 	return true;
812 }
813 
814 static bool valid_markcond(const struct inet_diag_bc_op *op, int len,
815 			   int *min_len)
816 {
817 	*min_len += sizeof(struct inet_diag_markcond);
818 	return len >= *min_len;
819 }
820 
/* Validate user-supplied filter bytecode before it is ever run: every
 * op's operand must fit, every jump must stay 4-byte aligned and land on
 * an op boundary (or exactly one past the end), and privileged ops
 * require CAP_NET_ADMIN.  Returns 0 if the program is safe, -EINVAL or
 * -EPERM otherwise.
 */
static int inet_diag_bc_audit(const struct nlattr *attr,
			      const struct sk_buff *skb)
{
	bool net_admin = netlink_net_capable(skb, CAP_NET_ADMIN);
	const void *bytecode, *bc;
	int bytecode_len, len;

	if (!attr || nla_len(attr) < sizeof(struct inet_diag_bc_op))
		return -EINVAL;

	bytecode = bc = nla_data(attr);
	len = bytecode_len = nla_len(attr);

	while (len > 0) {
		int min_len = sizeof(struct inet_diag_bc_op);
		const struct inet_diag_bc_op *op = bc;

		/* Per-op operand validation; each helper raises min_len to
		 * cover its operand so the jump checks below can use it.
		 */
		switch (op->code) {
		case INET_DIAG_BC_S_COND:
		case INET_DIAG_BC_D_COND:
			if (!valid_hostcond(bc, len, &min_len))
				return -EINVAL;
			break;
		case INET_DIAG_BC_DEV_COND:
			if (!valid_devcond(bc, len, &min_len))
				return -EINVAL;
			break;
		case INET_DIAG_BC_S_EQ:
		case INET_DIAG_BC_S_GE:
		case INET_DIAG_BC_S_LE:
		case INET_DIAG_BC_D_EQ:
		case INET_DIAG_BC_D_GE:
		case INET_DIAG_BC_D_LE:
			if (!valid_port_comparison(bc, len, &min_len))
				return -EINVAL;
			break;
		case INET_DIAG_BC_MARK_COND:
			/* Mark filtering leaks privileged information. */
			if (!net_admin)
				return -EPERM;
			if (!valid_markcond(bc, len, &min_len))
				return -EINVAL;
			break;
		case INET_DIAG_BC_AUTO:
		case INET_DIAG_BC_JMP:
		case INET_DIAG_BC_NOP:
			break;
		default:
			return -EINVAL;
		}

		if (op->code != INET_DIAG_BC_NOP) {
			/* "no" jump: within bounds, aligned, and landing
			 * on an op boundary (checked by valid_cc).
			 */
			if (op->no < min_len || op->no > len + 4 || op->no & 3)
				return -EINVAL;
			if (op->no < len &&
			    !valid_cc(bytecode, bytecode_len, len - op->no))
				return -EINVAL;
		}

		/* "yes" jump: must make forward progress and stay aligned. */
		if (op->yes < min_len || op->yes > len + 4 || op->yes & 3)
			return -EINVAL;
		bc  += op->yes;
		len -= op->yes;
	}
	return len == 0 ? 0 : -EINVAL;
}
886 
/* Compile-time proof that the fields inet_diag reads through struct sock /
 * struct inet_sock sit at identical offsets in struct inet_timewait_sock,
 * so timewait sockets may safely be passed to the common fill/filter code.
 */
static void twsk_build_assert(void)
{
	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_family) !=
		     offsetof(struct sock, sk_family));

	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_num) !=
		     offsetof(struct inet_sock, inet_num));

	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_dport) !=
		     offsetof(struct inet_sock, inet_dport));

	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_rcv_saddr) !=
		     offsetof(struct inet_sock, inet_rcv_saddr));

	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_daddr) !=
		     offsetof(struct inet_sock, inet_daddr));

#if IS_ENABLED(CONFIG_IPV6)
	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_rcv_saddr) !=
		     offsetof(struct sock, sk_v6_rcv_saddr));

	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_daddr) !=
		     offsetof(struct sock, sk_v6_daddr));
#endif
}
912 
/* Core dump loop: walk @hashinfo's listening hash and then the
 * established hash, emitting one diag message per matching socket.
 *
 * Resume state lives in cb->args: args[0] is the phase (0 = listening
 * hash, 1 = ehash), args[1] the bucket index, args[2] the position
 * within the bucket.  The function returns void; progress is saved so a
 * later ->dump() call continues where this one stopped.
 */
void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
			 struct netlink_callback *cb,
			 const struct inet_diag_req_v2 *r)
{
	bool net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN);
	struct inet_diag_dump_data *cb_data = cb->data;
	struct net *net = sock_net(skb->sk);
	u32 idiag_states = r->idiag_states;
	int i, num, s_i, s_num;
	struct nlattr *bc;
	struct sock *sk;

	bc = cb_data->inet_diag_nla_bc;
	/* Request sockets are reported under the legacy SYN_RECV state. */
	if (idiag_states & TCPF_SYN_RECV)
		idiag_states |= TCPF_NEW_SYN_RECV;
	s_i = cb->args[1];
	s_num = num = cb->args[2];

	if (cb->args[0] == 0) {
		/* Phase 0: listening sockets.  A dport filter can never
		 * match a listener, so skip the whole phase then.
		 */
		if (!(idiag_states & TCPF_LISTEN) || r->id.idiag_dport)
			goto skip_listen_ht;

		for (i = s_i; i < INET_LHTABLE_SIZE; i++) {
			struct inet_listen_hashbucket *ilb;
			struct hlist_nulls_node *node;

			num = 0;
			ilb = &hashinfo->listening_hash[i];
			spin_lock(&ilb->lock);
			sk_nulls_for_each(sk, node, &ilb->nulls_head) {
				struct inet_sock *inet = inet_sk(sk);

				if (!net_eq(sock_net(sk), net))
					continue;

				/* Skip entries already sent in a previous
				 * dump round.
				 */
				if (num < s_num) {
					num++;
					continue;
				}

				if (r->sdiag_family != AF_UNSPEC &&
				    sk->sk_family != r->sdiag_family)
					goto next_listen;

				if (r->id.idiag_sport != inet->inet_sport &&
				    r->id.idiag_sport)
					goto next_listen;

				if (!inet_diag_bc_sk(bc, sk))
					goto next_listen;

				/* skb full: save position and stop. */
				if (inet_sk_diag_fill(sk, inet_csk(sk), skb,
						      cb, r, NLM_F_MULTI,
						      net_admin) < 0) {
					spin_unlock(&ilb->lock);
					goto done;
				}

next_listen:
				++num;
			}
			spin_unlock(&ilb->lock);

			s_num = 0;
		}
skip_listen_ht:
		cb->args[0] = 1;
		s_i = num = s_num = 0;
	}

	/* Done if only listening sockets were requested. */
	if (!(idiag_states & ~TCPF_LISTEN))
		goto out;

/* Phase 1: established hash.  Candidates are collected in batches of
 * SKARR_SZ with an extra reference under the bucket lock, then filled
 * outside the lock (filling may need GFP_KERNEL / sleeping).
 */
#define SKARR_SZ 16
	for (i = s_i; i <= hashinfo->ehash_mask; i++) {
		struct inet_ehash_bucket *head = &hashinfo->ehash[i];
		spinlock_t *lock = inet_ehash_lockp(hashinfo, i);
		struct hlist_nulls_node *node;
		struct sock *sk_arr[SKARR_SZ];
		int num_arr[SKARR_SZ];
		int idx, accum, res;

		if (hlist_nulls_empty(&head->chain))
			continue;

		if (i > s_i)
			s_num = 0;

next_chunk:
		num = 0;
		accum = 0;
		spin_lock_bh(lock);
		sk_nulls_for_each(sk, node, &head->chain) {
			int state;

			if (!net_eq(sock_net(sk), net))
				continue;
			if (num < s_num)
				goto next_normal;
			/* Timewait sockets carry the state in tw_substate. */
			state = (sk->sk_state == TCP_TIME_WAIT) ?
				inet_twsk(sk)->tw_substate : sk->sk_state;
			if (!(idiag_states & (1 << state)))
				goto next_normal;
			if (r->sdiag_family != AF_UNSPEC &&
			    sk->sk_family != r->sdiag_family)
				goto next_normal;
			if (r->id.idiag_sport != htons(sk->sk_num) &&
			    r->id.idiag_sport)
				goto next_normal;
			if (r->id.idiag_dport != sk->sk_dport &&
			    r->id.idiag_dport)
				goto next_normal;
			twsk_build_assert();

			if (!inet_diag_bc_sk(bc, sk))
				goto next_normal;

			/* Socket may be dying; only keep it if we can pin
			 * a reference.
			 */
			if (!refcount_inc_not_zero(&sk->sk_refcnt))
				goto next_normal;

			num_arr[accum] = num;
			sk_arr[accum] = sk;
			if (++accum == SKARR_SZ)
				break;
next_normal:
			++num;
		}
		spin_unlock_bh(lock);
		res = 0;
		for (idx = 0; idx < accum; idx++) {
			if (res >= 0) {
				res = sk_diag_fill(sk_arr[idx], skb, cb, r,
						   NLM_F_MULTI, net_admin);
				/* On failure remember where to resume. */
				if (res < 0)
					num = num_arr[idx];
			}
			/* Drop the reference taken under the lock, even
			 * for sockets we failed to fill.
			 */
			sock_gen_put(sk_arr[idx]);
		}
		if (res < 0)
			break;
		cond_resched();
		/* Batch was full: re-walk the same bucket for the rest. */
		if (accum == SKARR_SZ) {
			s_num = num + 1;
			goto next_chunk;
		}
	}

done:
	cb->args[1] = i;
	cb->args[2] = num;
out:
	;
}
EXPORT_SYMBOL_GPL(inet_diag_dump_icsk);
1067 
/* Run one dump round through the protocol handler.  If nothing fit and
 * inet_sk_diag_fill() asked for a larger buffer (via cb->min_dump_alloc),
 * grow the skb in place and retry once per growth step.
 * Returns skb->len (bytes produced) or a negative errno.
 */
static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
			    const struct inet_diag_req_v2 *r)
{
	const struct inet_diag_handler *handler;
	u32 prev_min_dump_alloc;
	int err = 0;

again:
	prev_min_dump_alloc = cb->min_dump_alloc;
	handler = inet_diag_lock_handler(r->sdiag_protocol);
	if (!IS_ERR(handler))
		handler->dump(skb, cb, r);
	else
		err = PTR_ERR(handler);
	inet_diag_unlock_handler(handler);

	/* The skb is not large enough to fit one sk info and
	 * inet_sk_diag_fill() has requested for a larger skb.
	 */
	if (!skb->len && cb->min_dump_alloc > prev_min_dump_alloc) {
		err = pskb_expand_head(skb, 0, cb->min_dump_alloc, GFP_KERNEL);
		if (!err)
			goto again;
	}

	return err ? : skb->len;
}
1095 
/* netlink ->dump() entry for v2 requests: the inet_diag_req_v2 is the
 * netlink message payload.
 */
static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
{
	return __inet_diag_dump(skb, cb, nlmsg_data(cb->nlh));
}
1100 
/* Common ->start() for dumps: allocate the per-dump state (cb->data),
 * index the request attributes, audit any filter bytecode, and set up
 * optional BPF socket-storage dumping.  @hdrlen distinguishes the v2 and
 * legacy request layouts.  Returns 0 or a negative errno (state freed on
 * failure).
 */
static int __inet_diag_dump_start(struct netlink_callback *cb, int hdrlen)
{
	const struct nlmsghdr *nlh = cb->nlh;
	struct inet_diag_dump_data *cb_data;
	struct sk_buff *skb = cb->skb;
	struct nlattr *nla;
	int rem, err;

	cb_data = kzalloc(sizeof(*cb_data), GFP_KERNEL);
	if (!cb_data)
		return -ENOMEM;

	/* Index the attributes by type for O(1) access later. */
	nla_for_each_attr(nla, nlmsg_attrdata(nlh, hdrlen),
			  nlmsg_attrlen(nlh, hdrlen), rem) {
		int type = nla_type(nla);

		if (type < __INET_DIAG_REQ_MAX)
			cb_data->req_nlas[type] = nla;
	}

	/* Reject malformed or privileged filter bytecode up front. */
	nla = cb_data->inet_diag_nla_bc;
	if (nla) {
		err = inet_diag_bc_audit(nla, skb);
		if (err) {
			kfree(cb_data);
			return err;
		}
	}

	nla = cb_data->inet_diag_nla_bpf_stgs;
	if (nla) {
		struct bpf_sk_storage_diag *bpf_stg_diag;

		bpf_stg_diag = bpf_sk_storage_diag_alloc(nla);
		if (IS_ERR(bpf_stg_diag)) {
			kfree(cb_data);
			return PTR_ERR(bpf_stg_diag);
		}
		cb_data->bpf_stg_diag = bpf_stg_diag;
	}

	cb->data = cb_data;
	return 0;
}
1145 
/* Dump setup for new-style (struct inet_diag_req_v2) requests. */
static int inet_diag_dump_start(struct netlink_callback *cb)
{
	return __inet_diag_dump_start(cb, sizeof(struct inet_diag_req_v2));
}
1150 
/* Dump setup for legacy (struct inet_diag_req) requests. */
static int inet_diag_dump_start_compat(struct netlink_callback *cb)
{
	return __inet_diag_dump_start(cb, sizeof(struct inet_diag_req));
}
1155 
1156 static int inet_diag_dump_done(struct netlink_callback *cb)
1157 {
1158 	struct inet_diag_dump_data *cb_data = cb->data;
1159 
1160 	bpf_sk_storage_diag_free(cb_data->bpf_stg_diag);
1161 	kfree(cb->data);
1162 
1163 	return 0;
1164 }
1165 
/* Map a legacy TCPDIAG/DCCPDIAG netlink message type to its IP
 * protocol number; unknown types map to 0.
 */
static int inet_diag_type2proto(int type)
{
	if (type == TCPDIAG_GETSOCK)
		return IPPROTO_TCP;
	if (type == DCCPDIAG_GETSOCK)
		return IPPROTO_DCCP;
	return 0;
}
1177 
1178 static int inet_diag_dump_compat(struct sk_buff *skb,
1179 				 struct netlink_callback *cb)
1180 {
1181 	struct inet_diag_req *rc = nlmsg_data(cb->nlh);
1182 	struct inet_diag_req_v2 req;
1183 
1184 	req.sdiag_family = AF_UNSPEC; /* compatibility */
1185 	req.sdiag_protocol = inet_diag_type2proto(cb->nlh->nlmsg_type);
1186 	req.idiag_ext = rc->idiag_ext;
1187 	req.idiag_states = rc->idiag_states;
1188 	req.id = rc->id;
1189 
1190 	return __inet_diag_dump(skb, cb, &req);
1191 }
1192 
1193 static int inet_diag_get_exact_compat(struct sk_buff *in_skb,
1194 				      const struct nlmsghdr *nlh)
1195 {
1196 	struct inet_diag_req *rc = nlmsg_data(nlh);
1197 	struct inet_diag_req_v2 req;
1198 
1199 	req.sdiag_family = rc->idiag_family;
1200 	req.sdiag_protocol = inet_diag_type2proto(nlh->nlmsg_type);
1201 	req.idiag_ext = rc->idiag_ext;
1202 	req.idiag_states = rc->idiag_states;
1203 	req.id = rc->id;
1204 
1205 	return inet_diag_cmd_exact(SOCK_DIAG_BY_FAMILY, in_skb, nlh, &req);
1206 }
1207 
1208 static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)
1209 {
1210 	int hdrlen = sizeof(struct inet_diag_req);
1211 	struct net *net = sock_net(skb->sk);
1212 
1213 	if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX ||
1214 	    nlmsg_len(nlh) < hdrlen)
1215 		return -EINVAL;
1216 
1217 	if (nlh->nlmsg_flags & NLM_F_DUMP) {
1218 		struct netlink_dump_control c = {
1219 			.start = inet_diag_dump_start_compat,
1220 			.done = inet_diag_dump_done,
1221 			.dump = inet_diag_dump_compat,
1222 		};
1223 		return netlink_dump_start(net->diag_nlsk, skb, nlh, &c);
1224 	}
1225 
1226 	return inet_diag_get_exact_compat(skb, nlh);
1227 }
1228 
1229 static int inet_diag_handler_cmd(struct sk_buff *skb, struct nlmsghdr *h)
1230 {
1231 	int hdrlen = sizeof(struct inet_diag_req_v2);
1232 	struct net *net = sock_net(skb->sk);
1233 
1234 	if (nlmsg_len(h) < hdrlen)
1235 		return -EINVAL;
1236 
1237 	if (h->nlmsg_type == SOCK_DIAG_BY_FAMILY &&
1238 	    h->nlmsg_flags & NLM_F_DUMP) {
1239 		struct netlink_dump_control c = {
1240 			.start = inet_diag_dump_start,
1241 			.done = inet_diag_dump_done,
1242 			.dump = inet_diag_dump,
1243 		};
1244 		return netlink_dump_start(net->diag_nlsk, skb, h, &c);
1245 	}
1246 
1247 	return inet_diag_cmd_exact(h->nlmsg_type, skb, h, nlmsg_data(h));
1248 }
1249 
1250 static
1251 int inet_diag_handler_get_info(struct sk_buff *skb, struct sock *sk)
1252 {
1253 	const struct inet_diag_handler *handler;
1254 	struct nlmsghdr *nlh;
1255 	struct nlattr *attr;
1256 	struct inet_diag_msg *r;
1257 	void *info = NULL;
1258 	int err = 0;
1259 
1260 	nlh = nlmsg_put(skb, 0, 0, SOCK_DIAG_BY_FAMILY, sizeof(*r), 0);
1261 	if (!nlh)
1262 		return -ENOMEM;
1263 
1264 	r = nlmsg_data(nlh);
1265 	memset(r, 0, sizeof(*r));
1266 	inet_diag_msg_common_fill(r, sk);
1267 	if (sk->sk_type == SOCK_DGRAM || sk->sk_type == SOCK_STREAM)
1268 		r->id.idiag_sport = inet_sk(sk)->inet_sport;
1269 	r->idiag_state = sk->sk_state;
1270 
1271 	if ((err = nla_put_u8(skb, INET_DIAG_PROTOCOL, sk->sk_protocol))) {
1272 		nlmsg_cancel(skb, nlh);
1273 		return err;
1274 	}
1275 
1276 	handler = inet_diag_lock_handler(sk->sk_protocol);
1277 	if (IS_ERR(handler)) {
1278 		inet_diag_unlock_handler(handler);
1279 		nlmsg_cancel(skb, nlh);
1280 		return PTR_ERR(handler);
1281 	}
1282 
1283 	attr = handler->idiag_info_size
1284 		? nla_reserve_64bit(skb, INET_DIAG_INFO,
1285 				    handler->idiag_info_size,
1286 				    INET_DIAG_PAD)
1287 		: NULL;
1288 	if (attr)
1289 		info = nla_data(attr);
1290 
1291 	handler->idiag_get_info(sk, r, info);
1292 	inet_diag_unlock_handler(handler);
1293 
1294 	nlmsg_end(skb, nlh);
1295 	return 0;
1296 }
1297 
/* sock_diag dispatch for AF_INET; dump and destroy requests both go
 * through inet_diag_handler_cmd(), which demultiplexes on the netlink
 * message type and flags.
 */
static const struct sock_diag_handler inet_diag_handler = {
	.family = AF_INET,
	.dump = inet_diag_handler_cmd,
	.get_info = inet_diag_handler_get_info,
	.destroy = inet_diag_handler_cmd,
};
1304 
/* sock_diag dispatch for AF_INET6; shares all callbacks with the
 * AF_INET handler above.
 */
static const struct sock_diag_handler inet6_diag_handler = {
	.family = AF_INET6,
	.dump = inet_diag_handler_cmd,
	.get_info = inet_diag_handler_get_info,
	.destroy = inet_diag_handler_cmd,
};
1311 
1312 int inet_diag_register(const struct inet_diag_handler *h)
1313 {
1314 	const __u16 type = h->idiag_type;
1315 	int err = -EINVAL;
1316 
1317 	if (type >= IPPROTO_MAX)
1318 		goto out;
1319 
1320 	mutex_lock(&inet_diag_table_mutex);
1321 	err = -EEXIST;
1322 	if (!inet_diag_table[type]) {
1323 		inet_diag_table[type] = h;
1324 		err = 0;
1325 	}
1326 	mutex_unlock(&inet_diag_table_mutex);
1327 out:
1328 	return err;
1329 }
1330 EXPORT_SYMBOL_GPL(inet_diag_register);
1331 
/* Remove a protocol's diag handler; silently ignores out-of-range
 * types.  Takes the table mutex so lookups in inet_diag_lock_handler()
 * observe either the old handler or NULL, never a torn update.
 */
void inet_diag_unregister(const struct inet_diag_handler *h)
{
	const __u16 type = h->idiag_type;

	if (type >= IPPROTO_MAX)
		return;

	mutex_lock(&inet_diag_table_mutex);
	inet_diag_table[type] = NULL;
	mutex_unlock(&inet_diag_table_mutex);
}
EXPORT_SYMBOL_GPL(inet_diag_unregister);
1344 
1345 static int __init inet_diag_init(void)
1346 {
1347 	const int inet_diag_table_size = (IPPROTO_MAX *
1348 					  sizeof(struct inet_diag_handler *));
1349 	int err = -ENOMEM;
1350 
1351 	inet_diag_table = kzalloc(inet_diag_table_size, GFP_KERNEL);
1352 	if (!inet_diag_table)
1353 		goto out;
1354 
1355 	err = sock_diag_register(&inet_diag_handler);
1356 	if (err)
1357 		goto out_free_nl;
1358 
1359 	err = sock_diag_register(&inet6_diag_handler);
1360 	if (err)
1361 		goto out_free_inet;
1362 
1363 	sock_diag_register_inet_compat(inet_diag_rcv_msg_compat);
1364 out:
1365 	return err;
1366 
1367 out_free_inet:
1368 	sock_diag_unregister(&inet_diag_handler);
1369 out_free_nl:
1370 	kfree(inet_diag_table);
1371 	goto out;
1372 }
1373 
/* Module exit: unregister both family handlers and the legacy compat
 * hook, then release the handler table.
 */
static void __exit inet_diag_exit(void)
{
	sock_diag_unregister(&inet6_diag_handler);
	sock_diag_unregister(&inet_diag_handler);
	sock_diag_unregister_inet_compat(inet_diag_rcv_msg_compat);
	kfree(inet_diag_table);
}
1381 
module_init(inet_diag_init);
module_exit(inet_diag_exit);
MODULE_LICENSE("GPL");
/* Autoload this module for NETLINK_SOCK_DIAG requests targeting the
 * AF_INET (2) and AF_INET6 (10) families.
 */
MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2 /* AF_INET */);
MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 10 /* AF_INET6 */);
1387