xref: /openbmc/linux/net/ipv4/inet_diag.c (revision 82df5b73)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * inet_diag.c	Module for monitoring INET transport protocols sockets.
4  *
5  * Authors:	Alexey Kuznetsov, <kuznet@ms2.inr.ac.ru>
6  */
7 
8 #include <linux/kernel.h>
9 #include <linux/module.h>
10 #include <linux/types.h>
11 #include <linux/fcntl.h>
12 #include <linux/random.h>
13 #include <linux/slab.h>
14 #include <linux/cache.h>
15 #include <linux/init.h>
16 #include <linux/time.h>
17 
18 #include <net/icmp.h>
19 #include <net/tcp.h>
20 #include <net/ipv6.h>
21 #include <net/inet_common.h>
22 #include <net/inet_connection_sock.h>
23 #include <net/inet_hashtables.h>
24 #include <net/inet_timewait_sock.h>
25 #include <net/inet6_hashtables.h>
26 #include <net/bpf_sk_storage.h>
27 #include <net/netlink.h>
28 
29 #include <linux/inet.h>
30 #include <linux/stddef.h>
31 
32 #include <linux/inet_diag.h>
33 #include <linux/sock_diag.h>
34 
35 static const struct inet_diag_handler **inet_diag_table;
36 
37 struct inet_diag_entry {
38 	const __be32 *saddr;
39 	const __be32 *daddr;
40 	u16 sport;
41 	u16 dport;
42 	u16 family;
43 	u16 userlocks;
44 	u32 ifindex;
45 	u32 mark;
46 #ifdef CONFIG_SOCK_CGROUP_DATA
47 	u64 cgroup_id;
48 #endif
49 };
50 
51 static DEFINE_MUTEX(inet_diag_table_mutex);
52 
53 static const struct inet_diag_handler *inet_diag_lock_handler(int proto)
54 {
55 	if (!inet_diag_table[proto])
56 		sock_load_diag_module(AF_INET, proto);
57 
58 	mutex_lock(&inet_diag_table_mutex);
59 	if (!inet_diag_table[proto])
60 		return ERR_PTR(-ENOENT);
61 
62 	return inet_diag_table[proto];
63 }
64 
65 static void inet_diag_unlock_handler(const struct inet_diag_handler *handler)
66 {
67 	mutex_unlock(&inet_diag_table_mutex);
68 }
69 
70 void inet_diag_msg_common_fill(struct inet_diag_msg *r, struct sock *sk)
71 {
72 	r->idiag_family = sk->sk_family;
73 
74 	r->id.idiag_sport = htons(sk->sk_num);
75 	r->id.idiag_dport = sk->sk_dport;
76 	r->id.idiag_if = sk->sk_bound_dev_if;
77 	sock_diag_save_cookie(sk, r->id.idiag_cookie);
78 
79 #if IS_ENABLED(CONFIG_IPV6)
80 	if (sk->sk_family == AF_INET6) {
81 		*(struct in6_addr *)r->id.idiag_src = sk->sk_v6_rcv_saddr;
82 		*(struct in6_addr *)r->id.idiag_dst = sk->sk_v6_daddr;
83 	} else
84 #endif
85 	{
86 	memset(&r->id.idiag_src, 0, sizeof(r->id.idiag_src));
87 	memset(&r->id.idiag_dst, 0, sizeof(r->id.idiag_dst));
88 
89 	r->id.idiag_src[0] = sk->sk_rcv_saddr;
90 	r->id.idiag_dst[0] = sk->sk_daddr;
91 	}
92 }
93 EXPORT_SYMBOL_GPL(inet_diag_msg_common_fill);
94 
95 static size_t inet_sk_attr_size(struct sock *sk,
96 				const struct inet_diag_req_v2 *req,
97 				bool net_admin)
98 {
99 	const struct inet_diag_handler *handler;
100 	size_t aux = 0;
101 
102 	handler = inet_diag_table[req->sdiag_protocol];
103 	if (handler && handler->idiag_get_aux_size)
104 		aux = handler->idiag_get_aux_size(sk, net_admin);
105 
106 	return	  nla_total_size(sizeof(struct tcp_info))
107 		+ nla_total_size(sizeof(struct inet_diag_msg))
108 		+ inet_diag_msg_attrs_size()
109 		+ nla_total_size(sizeof(struct inet_diag_meminfo))
110 		+ nla_total_size(SK_MEMINFO_VARS * sizeof(u32))
111 		+ nla_total_size(TCP_CA_NAME_MAX)
112 		+ nla_total_size(sizeof(struct tcpvegas_info))
113 		+ aux
114 		+ 64;
115 }
116 
117 int inet_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb,
118 			     struct inet_diag_msg *r, int ext,
119 			     struct user_namespace *user_ns,
120 			     bool net_admin)
121 {
122 	const struct inet_sock *inet = inet_sk(sk);
123 
124 	if (nla_put_u8(skb, INET_DIAG_SHUTDOWN, sk->sk_shutdown))
125 		goto errout;
126 
127 	/* IPv6 dual-stack sockets use inet->tos for IPv4 connections,
128 	 * hence this needs to be included regardless of socket family.
129 	 */
130 	if (ext & (1 << (INET_DIAG_TOS - 1)))
131 		if (nla_put_u8(skb, INET_DIAG_TOS, inet->tos) < 0)
132 			goto errout;
133 
134 #if IS_ENABLED(CONFIG_IPV6)
135 	if (r->idiag_family == AF_INET6) {
136 		if (ext & (1 << (INET_DIAG_TCLASS - 1)))
137 			if (nla_put_u8(skb, INET_DIAG_TCLASS,
138 				       inet6_sk(sk)->tclass) < 0)
139 				goto errout;
140 
141 		if (((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) &&
142 		    nla_put_u8(skb, INET_DIAG_SKV6ONLY, ipv6_only_sock(sk)))
143 			goto errout;
144 	}
145 #endif
146 
147 	if (net_admin && nla_put_u32(skb, INET_DIAG_MARK, sk->sk_mark))
148 		goto errout;
149 
150 	if (ext & (1 << (INET_DIAG_CLASS_ID - 1)) ||
151 	    ext & (1 << (INET_DIAG_TCLASS - 1))) {
152 		u32 classid = 0;
153 
154 #ifdef CONFIG_SOCK_CGROUP_DATA
155 		classid = sock_cgroup_classid(&sk->sk_cgrp_data);
156 #endif
157 		/* Fallback to socket priority if class id isn't set.
158 		 * Classful qdiscs use it as direct reference to class.
159 		 * For cgroup2 classid is always zero.
160 		 */
161 		if (!classid)
162 			classid = sk->sk_priority;
163 
164 		if (nla_put_u32(skb, INET_DIAG_CLASS_ID, classid))
165 			goto errout;
166 	}
167 
168 #ifdef CONFIG_SOCK_CGROUP_DATA
169 	if (nla_put_u64_64bit(skb, INET_DIAG_CGROUP_ID,
170 			      cgroup_id(sock_cgroup_ptr(&sk->sk_cgrp_data)),
171 			      INET_DIAG_PAD))
172 		goto errout;
173 #endif
174 
175 	r->idiag_uid = from_kuid_munged(user_ns, sock_i_uid(sk));
176 	r->idiag_inode = sock_i_ino(sk);
177 
178 	return 0;
179 errout:
180 	return 1;
181 }
182 EXPORT_SYMBOL_GPL(inet_diag_msg_attrs_fill);
183 
184 #define MAX_DUMP_ALLOC_SIZE (KMALLOC_MAX_SIZE - SKB_DATA_ALIGN(sizeof(struct skb_shared_info)))
185 
186 int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
187 		      struct sk_buff *skb, struct netlink_callback *cb,
188 		      const struct inet_diag_req_v2 *req,
189 		      u16 nlmsg_flags, bool net_admin)
190 {
191 	const struct tcp_congestion_ops *ca_ops;
192 	const struct inet_diag_handler *handler;
193 	struct inet_diag_dump_data *cb_data;
194 	int ext = req->idiag_ext;
195 	struct inet_diag_msg *r;
196 	struct nlmsghdr  *nlh;
197 	struct nlattr *attr;
198 	void *info = NULL;
199 
200 	cb_data = cb->data;
201 	handler = inet_diag_table[req->sdiag_protocol];
202 	BUG_ON(!handler);
203 
204 	nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
205 			cb->nlh->nlmsg_type, sizeof(*r), nlmsg_flags);
206 	if (!nlh)
207 		return -EMSGSIZE;
208 
209 	r = nlmsg_data(nlh);
210 	BUG_ON(!sk_fullsock(sk));
211 
212 	inet_diag_msg_common_fill(r, sk);
213 	r->idiag_state = sk->sk_state;
214 	r->idiag_timer = 0;
215 	r->idiag_retrans = 0;
216 
217 	if (inet_diag_msg_attrs_fill(sk, skb, r, ext,
218 				     sk_user_ns(NETLINK_CB(cb->skb).sk),
219 				     net_admin))
220 		goto errout;
221 
222 	if (ext & (1 << (INET_DIAG_MEMINFO - 1))) {
223 		struct inet_diag_meminfo minfo = {
224 			.idiag_rmem = sk_rmem_alloc_get(sk),
225 			.idiag_wmem = READ_ONCE(sk->sk_wmem_queued),
226 			.idiag_fmem = sk->sk_forward_alloc,
227 			.idiag_tmem = sk_wmem_alloc_get(sk),
228 		};
229 
230 		if (nla_put(skb, INET_DIAG_MEMINFO, sizeof(minfo), &minfo) < 0)
231 			goto errout;
232 	}
233 
234 	if (ext & (1 << (INET_DIAG_SKMEMINFO - 1)))
235 		if (sock_diag_put_meminfo(sk, skb, INET_DIAG_SKMEMINFO))
236 			goto errout;
237 
238 	/*
239 	 * RAW sockets might have user-defined protocols assigned,
240 	 * so report the one supplied on socket creation.
241 	 */
242 	if (sk->sk_type == SOCK_RAW) {
243 		if (nla_put_u8(skb, INET_DIAG_PROTOCOL, sk->sk_protocol))
244 			goto errout;
245 	}
246 
247 	if (!icsk) {
248 		handler->idiag_get_info(sk, r, NULL);
249 		goto out;
250 	}
251 
252 	if (icsk->icsk_pending == ICSK_TIME_RETRANS ||
253 	    icsk->icsk_pending == ICSK_TIME_REO_TIMEOUT ||
254 	    icsk->icsk_pending == ICSK_TIME_LOSS_PROBE) {
255 		r->idiag_timer = 1;
256 		r->idiag_retrans = icsk->icsk_retransmits;
257 		r->idiag_expires =
258 			jiffies_delta_to_msecs(icsk->icsk_timeout - jiffies);
259 	} else if (icsk->icsk_pending == ICSK_TIME_PROBE0) {
260 		r->idiag_timer = 4;
261 		r->idiag_retrans = icsk->icsk_probes_out;
262 		r->idiag_expires =
263 			jiffies_delta_to_msecs(icsk->icsk_timeout - jiffies);
264 	} else if (timer_pending(&sk->sk_timer)) {
265 		r->idiag_timer = 2;
266 		r->idiag_retrans = icsk->icsk_probes_out;
267 		r->idiag_expires =
268 			jiffies_delta_to_msecs(sk->sk_timer.expires - jiffies);
269 	} else {
270 		r->idiag_timer = 0;
271 		r->idiag_expires = 0;
272 	}
273 
274 	if ((ext & (1 << (INET_DIAG_INFO - 1))) && handler->idiag_info_size) {
275 		attr = nla_reserve_64bit(skb, INET_DIAG_INFO,
276 					 handler->idiag_info_size,
277 					 INET_DIAG_PAD);
278 		if (!attr)
279 			goto errout;
280 
281 		info = nla_data(attr);
282 	}
283 
284 	if (ext & (1 << (INET_DIAG_CONG - 1))) {
285 		int err = 0;
286 
287 		rcu_read_lock();
288 		ca_ops = READ_ONCE(icsk->icsk_ca_ops);
289 		if (ca_ops)
290 			err = nla_put_string(skb, INET_DIAG_CONG, ca_ops->name);
291 		rcu_read_unlock();
292 		if (err < 0)
293 			goto errout;
294 	}
295 
296 	handler->idiag_get_info(sk, r, info);
297 
298 	if (ext & (1 << (INET_DIAG_INFO - 1)) && handler->idiag_get_aux)
299 		if (handler->idiag_get_aux(sk, net_admin, skb) < 0)
300 			goto errout;
301 
302 	if (sk->sk_state < TCP_TIME_WAIT) {
303 		union tcp_cc_info info;
304 		size_t sz = 0;
305 		int attr;
306 
307 		rcu_read_lock();
308 		ca_ops = READ_ONCE(icsk->icsk_ca_ops);
309 		if (ca_ops && ca_ops->get_info)
310 			sz = ca_ops->get_info(sk, ext, &attr, &info);
311 		rcu_read_unlock();
312 		if (sz && nla_put(skb, attr, sz, &info) < 0)
313 			goto errout;
314 	}
315 
316 	/* Keep it at the end for potential retry with a larger skb,
317 	 * or else do best-effort fitting, which is only done for the
318 	 * first_nlmsg.
319 	 */
320 	if (cb_data->bpf_stg_diag) {
321 		bool first_nlmsg = ((unsigned char *)nlh == skb->data);
322 		unsigned int prev_min_dump_alloc;
323 		unsigned int total_nla_size = 0;
324 		unsigned int msg_len;
325 		int err;
326 
327 		msg_len = skb_tail_pointer(skb) - (unsigned char *)nlh;
328 		err = bpf_sk_storage_diag_put(cb_data->bpf_stg_diag, sk, skb,
329 					      INET_DIAG_SK_BPF_STORAGES,
330 					      &total_nla_size);
331 
332 		if (!err)
333 			goto out;
334 
335 		total_nla_size += msg_len;
336 		prev_min_dump_alloc = cb->min_dump_alloc;
337 		if (total_nla_size > prev_min_dump_alloc)
338 			cb->min_dump_alloc = min_t(u32, total_nla_size,
339 						   MAX_DUMP_ALLOC_SIZE);
340 
341 		if (!first_nlmsg)
342 			goto errout;
343 
344 		if (cb->min_dump_alloc > prev_min_dump_alloc)
345 			/* Retry with pskb_expand_head() with
346 			 * __GFP_DIRECT_RECLAIM
347 			 */
348 			goto errout;
349 
350 		WARN_ON_ONCE(total_nla_size <= prev_min_dump_alloc);
351 
352 		/* Send what we have for this sk
353 		 * and move on to the next sk in the following
354 		 * dump()
355 		 */
356 	}
357 
358 out:
359 	nlmsg_end(skb, nlh);
360 	return 0;
361 
362 errout:
363 	nlmsg_cancel(skb, nlh);
364 	return -EMSGSIZE;
365 }
366 EXPORT_SYMBOL_GPL(inet_sk_diag_fill);
367 
368 static int inet_twsk_diag_fill(struct sock *sk,
369 			       struct sk_buff *skb,
370 			       struct netlink_callback *cb,
371 			       u16 nlmsg_flags)
372 {
373 	struct inet_timewait_sock *tw = inet_twsk(sk);
374 	struct inet_diag_msg *r;
375 	struct nlmsghdr *nlh;
376 	long tmo;
377 
378 	nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid,
379 			cb->nlh->nlmsg_seq, cb->nlh->nlmsg_type,
380 			sizeof(*r), nlmsg_flags);
381 	if (!nlh)
382 		return -EMSGSIZE;
383 
384 	r = nlmsg_data(nlh);
385 	BUG_ON(tw->tw_state != TCP_TIME_WAIT);
386 
387 	inet_diag_msg_common_fill(r, sk);
388 	r->idiag_retrans      = 0;
389 
390 	r->idiag_state	      = tw->tw_substate;
391 	r->idiag_timer	      = 3;
392 	tmo = tw->tw_timer.expires - jiffies;
393 	r->idiag_expires      = jiffies_delta_to_msecs(tmo);
394 	r->idiag_rqueue	      = 0;
395 	r->idiag_wqueue	      = 0;
396 	r->idiag_uid	      = 0;
397 	r->idiag_inode	      = 0;
398 
399 	nlmsg_end(skb, nlh);
400 	return 0;
401 }
402 
403 static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb,
404 			      struct netlink_callback *cb,
405 			      u16 nlmsg_flags, bool net_admin)
406 {
407 	struct request_sock *reqsk = inet_reqsk(sk);
408 	struct inet_diag_msg *r;
409 	struct nlmsghdr *nlh;
410 	long tmo;
411 
412 	nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
413 			cb->nlh->nlmsg_type, sizeof(*r), nlmsg_flags);
414 	if (!nlh)
415 		return -EMSGSIZE;
416 
417 	r = nlmsg_data(nlh);
418 	inet_diag_msg_common_fill(r, sk);
419 	r->idiag_state = TCP_SYN_RECV;
420 	r->idiag_timer = 1;
421 	r->idiag_retrans = reqsk->num_retrans;
422 
423 	BUILD_BUG_ON(offsetof(struct inet_request_sock, ir_cookie) !=
424 		     offsetof(struct sock, sk_cookie));
425 
426 	tmo = inet_reqsk(sk)->rsk_timer.expires - jiffies;
427 	r->idiag_expires = jiffies_delta_to_msecs(tmo);
428 	r->idiag_rqueue	= 0;
429 	r->idiag_wqueue	= 0;
430 	r->idiag_uid	= 0;
431 	r->idiag_inode	= 0;
432 
433 	if (net_admin && nla_put_u32(skb, INET_DIAG_MARK,
434 				     inet_rsk(reqsk)->ir_mark))
435 		return -EMSGSIZE;
436 
437 	nlmsg_end(skb, nlh);
438 	return 0;
439 }
440 
441 static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
442 			struct netlink_callback *cb,
443 			const struct inet_diag_req_v2 *r,
444 			u16 nlmsg_flags, bool net_admin)
445 {
446 	if (sk->sk_state == TCP_TIME_WAIT)
447 		return inet_twsk_diag_fill(sk, skb, cb, nlmsg_flags);
448 
449 	if (sk->sk_state == TCP_NEW_SYN_RECV)
450 		return inet_req_diag_fill(sk, skb, cb, nlmsg_flags, net_admin);
451 
452 	return inet_sk_diag_fill(sk, inet_csk(sk), skb, cb, r, nlmsg_flags,
453 				 net_admin);
454 }
455 
456 struct sock *inet_diag_find_one_icsk(struct net *net,
457 				     struct inet_hashinfo *hashinfo,
458 				     const struct inet_diag_req_v2 *req)
459 {
460 	struct sock *sk;
461 
462 	rcu_read_lock();
463 	if (req->sdiag_family == AF_INET)
464 		sk = inet_lookup(net, hashinfo, NULL, 0, req->id.idiag_dst[0],
465 				 req->id.idiag_dport, req->id.idiag_src[0],
466 				 req->id.idiag_sport, req->id.idiag_if);
467 #if IS_ENABLED(CONFIG_IPV6)
468 	else if (req->sdiag_family == AF_INET6) {
469 		if (ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_dst) &&
470 		    ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_src))
471 			sk = inet_lookup(net, hashinfo, NULL, 0, req->id.idiag_dst[3],
472 					 req->id.idiag_dport, req->id.idiag_src[3],
473 					 req->id.idiag_sport, req->id.idiag_if);
474 		else
475 			sk = inet6_lookup(net, hashinfo, NULL, 0,
476 					  (struct in6_addr *)req->id.idiag_dst,
477 					  req->id.idiag_dport,
478 					  (struct in6_addr *)req->id.idiag_src,
479 					  req->id.idiag_sport,
480 					  req->id.idiag_if);
481 	}
482 #endif
483 	else {
484 		rcu_read_unlock();
485 		return ERR_PTR(-EINVAL);
486 	}
487 	rcu_read_unlock();
488 	if (!sk)
489 		return ERR_PTR(-ENOENT);
490 
491 	if (sock_diag_check_cookie(sk, req->id.idiag_cookie)) {
492 		sock_gen_put(sk);
493 		return ERR_PTR(-ENOENT);
494 	}
495 
496 	return sk;
497 }
498 EXPORT_SYMBOL_GPL(inet_diag_find_one_icsk);
499 
500 int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo,
501 			    struct netlink_callback *cb,
502 			    const struct inet_diag_req_v2 *req)
503 {
504 	struct sk_buff *in_skb = cb->skb;
505 	bool net_admin = netlink_net_capable(in_skb, CAP_NET_ADMIN);
506 	struct net *net = sock_net(in_skb->sk);
507 	struct sk_buff *rep;
508 	struct sock *sk;
509 	int err;
510 
511 	sk = inet_diag_find_one_icsk(net, hashinfo, req);
512 	if (IS_ERR(sk))
513 		return PTR_ERR(sk);
514 
515 	rep = nlmsg_new(inet_sk_attr_size(sk, req, net_admin), GFP_KERNEL);
516 	if (!rep) {
517 		err = -ENOMEM;
518 		goto out;
519 	}
520 
521 	err = sk_diag_fill(sk, rep, cb, req, 0, net_admin);
522 	if (err < 0) {
523 		WARN_ON(err == -EMSGSIZE);
524 		nlmsg_free(rep);
525 		goto out;
526 	}
527 	err = netlink_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).portid,
528 			      MSG_DONTWAIT);
529 	if (err > 0)
530 		err = 0;
531 
532 out:
533 	if (sk)
534 		sock_gen_put(sk);
535 
536 	return err;
537 }
538 EXPORT_SYMBOL_GPL(inet_diag_dump_one_icsk);
539 
540 static int inet_diag_cmd_exact(int cmd, struct sk_buff *in_skb,
541 			       const struct nlmsghdr *nlh,
542 			       const struct inet_diag_req_v2 *req)
543 {
544 	const struct inet_diag_handler *handler;
545 	int err;
546 
547 	handler = inet_diag_lock_handler(req->sdiag_protocol);
548 	if (IS_ERR(handler)) {
549 		err = PTR_ERR(handler);
550 	} else if (cmd == SOCK_DIAG_BY_FAMILY) {
551 		struct inet_diag_dump_data empty_dump_data = {};
552 		struct netlink_callback cb = {
553 			.nlh = nlh,
554 			.skb = in_skb,
555 			.data = &empty_dump_data,
556 		};
557 		err = handler->dump_one(&cb, req);
558 	} else if (cmd == SOCK_DESTROY && handler->destroy) {
559 		err = handler->destroy(in_skb, req);
560 	} else {
561 		err = -EOPNOTSUPP;
562 	}
563 	inet_diag_unlock_handler(handler);
564 
565 	return err;
566 }
567 
568 static int bitstring_match(const __be32 *a1, const __be32 *a2, int bits)
569 {
570 	int words = bits >> 5;
571 
572 	bits &= 0x1f;
573 
574 	if (words) {
575 		if (memcmp(a1, a2, words << 2))
576 			return 0;
577 	}
578 	if (bits) {
579 		__be32 w1, w2;
580 		__be32 mask;
581 
582 		w1 = a1[words];
583 		w2 = a2[words];
584 
585 		mask = htonl((0xffffffff) << (32 - bits));
586 
587 		if ((w1 ^ w2) & mask)
588 			return 0;
589 	}
590 
591 	return 1;
592 }
593 
594 static int inet_diag_bc_run(const struct nlattr *_bc,
595 			    const struct inet_diag_entry *entry)
596 {
597 	const void *bc = nla_data(_bc);
598 	int len = nla_len(_bc);
599 
600 	while (len > 0) {
601 		int yes = 1;
602 		const struct inet_diag_bc_op *op = bc;
603 
604 		switch (op->code) {
605 		case INET_DIAG_BC_NOP:
606 			break;
607 		case INET_DIAG_BC_JMP:
608 			yes = 0;
609 			break;
610 		case INET_DIAG_BC_S_EQ:
611 			yes = entry->sport == op[1].no;
612 			break;
613 		case INET_DIAG_BC_S_GE:
614 			yes = entry->sport >= op[1].no;
615 			break;
616 		case INET_DIAG_BC_S_LE:
617 			yes = entry->sport <= op[1].no;
618 			break;
619 		case INET_DIAG_BC_D_EQ:
620 			yes = entry->dport == op[1].no;
621 			break;
622 		case INET_DIAG_BC_D_GE:
623 			yes = entry->dport >= op[1].no;
624 			break;
625 		case INET_DIAG_BC_D_LE:
626 			yes = entry->dport <= op[1].no;
627 			break;
628 		case INET_DIAG_BC_AUTO:
629 			yes = !(entry->userlocks & SOCK_BINDPORT_LOCK);
630 			break;
631 		case INET_DIAG_BC_S_COND:
632 		case INET_DIAG_BC_D_COND: {
633 			const struct inet_diag_hostcond *cond;
634 			const __be32 *addr;
635 
636 			cond = (const struct inet_diag_hostcond *)(op + 1);
637 			if (cond->port != -1 &&
638 			    cond->port != (op->code == INET_DIAG_BC_S_COND ?
639 					     entry->sport : entry->dport)) {
640 				yes = 0;
641 				break;
642 			}
643 
644 			if (op->code == INET_DIAG_BC_S_COND)
645 				addr = entry->saddr;
646 			else
647 				addr = entry->daddr;
648 
649 			if (cond->family != AF_UNSPEC &&
650 			    cond->family != entry->family) {
651 				if (entry->family == AF_INET6 &&
652 				    cond->family == AF_INET) {
653 					if (addr[0] == 0 && addr[1] == 0 &&
654 					    addr[2] == htonl(0xffff) &&
655 					    bitstring_match(addr + 3,
656 							    cond->addr,
657 							    cond->prefix_len))
658 						break;
659 				}
660 				yes = 0;
661 				break;
662 			}
663 
664 			if (cond->prefix_len == 0)
665 				break;
666 			if (bitstring_match(addr, cond->addr,
667 					    cond->prefix_len))
668 				break;
669 			yes = 0;
670 			break;
671 		}
672 		case INET_DIAG_BC_DEV_COND: {
673 			u32 ifindex;
674 
675 			ifindex = *((const u32 *)(op + 1));
676 			if (ifindex != entry->ifindex)
677 				yes = 0;
678 			break;
679 		}
680 		case INET_DIAG_BC_MARK_COND: {
681 			struct inet_diag_markcond *cond;
682 
683 			cond = (struct inet_diag_markcond *)(op + 1);
684 			if ((entry->mark & cond->mask) != cond->mark)
685 				yes = 0;
686 			break;
687 		}
688 #ifdef CONFIG_SOCK_CGROUP_DATA
689 		case INET_DIAG_BC_CGROUP_COND: {
690 			u64 cgroup_id;
691 
692 			cgroup_id = get_unaligned((const u64 *)(op + 1));
693 			if (cgroup_id != entry->cgroup_id)
694 				yes = 0;
695 			break;
696 		}
697 #endif
698 		}
699 
700 		if (yes) {
701 			len -= op->yes;
702 			bc += op->yes;
703 		} else {
704 			len -= op->no;
705 			bc += op->no;
706 		}
707 	}
708 	return len == 0;
709 }
710 
711 /* This helper is available for all sockets (ESTABLISH, TIMEWAIT, SYN_RECV)
712  */
713 static void entry_fill_addrs(struct inet_diag_entry *entry,
714 			     const struct sock *sk)
715 {
716 #if IS_ENABLED(CONFIG_IPV6)
717 	if (sk->sk_family == AF_INET6) {
718 		entry->saddr = sk->sk_v6_rcv_saddr.s6_addr32;
719 		entry->daddr = sk->sk_v6_daddr.s6_addr32;
720 	} else
721 #endif
722 	{
723 		entry->saddr = &sk->sk_rcv_saddr;
724 		entry->daddr = &sk->sk_daddr;
725 	}
726 }
727 
728 int inet_diag_bc_sk(const struct nlattr *bc, struct sock *sk)
729 {
730 	struct inet_sock *inet = inet_sk(sk);
731 	struct inet_diag_entry entry;
732 
733 	if (!bc)
734 		return 1;
735 
736 	entry.family = sk->sk_family;
737 	entry_fill_addrs(&entry, sk);
738 	entry.sport = inet->inet_num;
739 	entry.dport = ntohs(inet->inet_dport);
740 	entry.ifindex = sk->sk_bound_dev_if;
741 	entry.userlocks = sk_fullsock(sk) ? sk->sk_userlocks : 0;
742 	if (sk_fullsock(sk))
743 		entry.mark = sk->sk_mark;
744 	else if (sk->sk_state == TCP_NEW_SYN_RECV)
745 		entry.mark = inet_rsk(inet_reqsk(sk))->ir_mark;
746 	else
747 		entry.mark = 0;
748 #ifdef CONFIG_SOCK_CGROUP_DATA
749 	entry.cgroup_id = sk_fullsock(sk) ?
750 		cgroup_id(sock_cgroup_ptr(&sk->sk_cgrp_data)) : 0;
751 #endif
752 
753 	return inet_diag_bc_run(bc, &entry);
754 }
755 EXPORT_SYMBOL_GPL(inet_diag_bc_sk);
756 
757 static int valid_cc(const void *bc, int len, int cc)
758 {
759 	while (len >= 0) {
760 		const struct inet_diag_bc_op *op = bc;
761 
762 		if (cc > len)
763 			return 0;
764 		if (cc == len)
765 			return 1;
766 		if (op->yes < 4 || op->yes & 3)
767 			return 0;
768 		len -= op->yes;
769 		bc  += op->yes;
770 	}
771 	return 0;
772 }
773 
774 /* data is u32 ifindex */
775 static bool valid_devcond(const struct inet_diag_bc_op *op, int len,
776 			  int *min_len)
777 {
778 	/* Check ifindex space. */
779 	*min_len += sizeof(u32);
780 	if (len < *min_len)
781 		return false;
782 
783 	return true;
784 }
785 /* Validate an inet_diag_hostcond. */
786 static bool valid_hostcond(const struct inet_diag_bc_op *op, int len,
787 			   int *min_len)
788 {
789 	struct inet_diag_hostcond *cond;
790 	int addr_len;
791 
792 	/* Check hostcond space. */
793 	*min_len += sizeof(struct inet_diag_hostcond);
794 	if (len < *min_len)
795 		return false;
796 	cond = (struct inet_diag_hostcond *)(op + 1);
797 
798 	/* Check address family and address length. */
799 	switch (cond->family) {
800 	case AF_UNSPEC:
801 		addr_len = 0;
802 		break;
803 	case AF_INET:
804 		addr_len = sizeof(struct in_addr);
805 		break;
806 	case AF_INET6:
807 		addr_len = sizeof(struct in6_addr);
808 		break;
809 	default:
810 		return false;
811 	}
812 	*min_len += addr_len;
813 	if (len < *min_len)
814 		return false;
815 
816 	/* Check prefix length (in bits) vs address length (in bytes). */
817 	if (cond->prefix_len > 8 * addr_len)
818 		return false;
819 
820 	return true;
821 }
822 
823 /* Validate a port comparison operator. */
824 static bool valid_port_comparison(const struct inet_diag_bc_op *op,
825 				  int len, int *min_len)
826 {
827 	/* Port comparisons put the port in a follow-on inet_diag_bc_op. */
828 	*min_len += sizeof(struct inet_diag_bc_op);
829 	if (len < *min_len)
830 		return false;
831 	return true;
832 }
833 
834 static bool valid_markcond(const struct inet_diag_bc_op *op, int len,
835 			   int *min_len)
836 {
837 	*min_len += sizeof(struct inet_diag_markcond);
838 	return len >= *min_len;
839 }
840 
841 #ifdef CONFIG_SOCK_CGROUP_DATA
842 static bool valid_cgroupcond(const struct inet_diag_bc_op *op, int len,
843 			     int *min_len)
844 {
845 	*min_len += sizeof(u64);
846 	return len >= *min_len;
847 }
848 #endif
849 
850 static int inet_diag_bc_audit(const struct nlattr *attr,
851 			      const struct sk_buff *skb)
852 {
853 	bool net_admin = netlink_net_capable(skb, CAP_NET_ADMIN);
854 	const void *bytecode, *bc;
855 	int bytecode_len, len;
856 
857 	if (!attr || nla_len(attr) < sizeof(struct inet_diag_bc_op))
858 		return -EINVAL;
859 
860 	bytecode = bc = nla_data(attr);
861 	len = bytecode_len = nla_len(attr);
862 
863 	while (len > 0) {
864 		int min_len = sizeof(struct inet_diag_bc_op);
865 		const struct inet_diag_bc_op *op = bc;
866 
867 		switch (op->code) {
868 		case INET_DIAG_BC_S_COND:
869 		case INET_DIAG_BC_D_COND:
870 			if (!valid_hostcond(bc, len, &min_len))
871 				return -EINVAL;
872 			break;
873 		case INET_DIAG_BC_DEV_COND:
874 			if (!valid_devcond(bc, len, &min_len))
875 				return -EINVAL;
876 			break;
877 		case INET_DIAG_BC_S_EQ:
878 		case INET_DIAG_BC_S_GE:
879 		case INET_DIAG_BC_S_LE:
880 		case INET_DIAG_BC_D_EQ:
881 		case INET_DIAG_BC_D_GE:
882 		case INET_DIAG_BC_D_LE:
883 			if (!valid_port_comparison(bc, len, &min_len))
884 				return -EINVAL;
885 			break;
886 		case INET_DIAG_BC_MARK_COND:
887 			if (!net_admin)
888 				return -EPERM;
889 			if (!valid_markcond(bc, len, &min_len))
890 				return -EINVAL;
891 			break;
892 #ifdef CONFIG_SOCK_CGROUP_DATA
893 		case INET_DIAG_BC_CGROUP_COND:
894 			if (!valid_cgroupcond(bc, len, &min_len))
895 				return -EINVAL;
896 			break;
897 #endif
898 		case INET_DIAG_BC_AUTO:
899 		case INET_DIAG_BC_JMP:
900 		case INET_DIAG_BC_NOP:
901 			break;
902 		default:
903 			return -EINVAL;
904 		}
905 
906 		if (op->code != INET_DIAG_BC_NOP) {
907 			if (op->no < min_len || op->no > len + 4 || op->no & 3)
908 				return -EINVAL;
909 			if (op->no < len &&
910 			    !valid_cc(bytecode, bytecode_len, len - op->no))
911 				return -EINVAL;
912 		}
913 
914 		if (op->yes < min_len || op->yes > len + 4 || op->yes & 3)
915 			return -EINVAL;
916 		bc  += op->yes;
917 		len -= op->yes;
918 	}
919 	return len == 0 ? 0 : -EINVAL;
920 }
921 
922 static void twsk_build_assert(void)
923 {
924 	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_family) !=
925 		     offsetof(struct sock, sk_family));
926 
927 	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_num) !=
928 		     offsetof(struct inet_sock, inet_num));
929 
930 	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_dport) !=
931 		     offsetof(struct inet_sock, inet_dport));
932 
933 	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_rcv_saddr) !=
934 		     offsetof(struct inet_sock, inet_rcv_saddr));
935 
936 	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_daddr) !=
937 		     offsetof(struct inet_sock, inet_daddr));
938 
939 #if IS_ENABLED(CONFIG_IPV6)
940 	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_rcv_saddr) !=
941 		     offsetof(struct sock, sk_v6_rcv_saddr));
942 
943 	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_daddr) !=
944 		     offsetof(struct sock, sk_v6_daddr));
945 #endif
946 }
947 
948 void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
949 			 struct netlink_callback *cb,
950 			 const struct inet_diag_req_v2 *r)
951 {
952 	bool net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN);
953 	struct inet_diag_dump_data *cb_data = cb->data;
954 	struct net *net = sock_net(skb->sk);
955 	u32 idiag_states = r->idiag_states;
956 	int i, num, s_i, s_num;
957 	struct nlattr *bc;
958 	struct sock *sk;
959 
960 	bc = cb_data->inet_diag_nla_bc;
961 	if (idiag_states & TCPF_SYN_RECV)
962 		idiag_states |= TCPF_NEW_SYN_RECV;
963 	s_i = cb->args[1];
964 	s_num = num = cb->args[2];
965 
966 	if (cb->args[0] == 0) {
967 		if (!(idiag_states & TCPF_LISTEN) || r->id.idiag_dport)
968 			goto skip_listen_ht;
969 
970 		for (i = s_i; i < INET_LHTABLE_SIZE; i++) {
971 			struct inet_listen_hashbucket *ilb;
972 			struct hlist_nulls_node *node;
973 
974 			num = 0;
975 			ilb = &hashinfo->listening_hash[i];
976 			spin_lock(&ilb->lock);
977 			sk_nulls_for_each(sk, node, &ilb->nulls_head) {
978 				struct inet_sock *inet = inet_sk(sk);
979 
980 				if (!net_eq(sock_net(sk), net))
981 					continue;
982 
983 				if (num < s_num) {
984 					num++;
985 					continue;
986 				}
987 
988 				if (r->sdiag_family != AF_UNSPEC &&
989 				    sk->sk_family != r->sdiag_family)
990 					goto next_listen;
991 
992 				if (r->id.idiag_sport != inet->inet_sport &&
993 				    r->id.idiag_sport)
994 					goto next_listen;
995 
996 				if (!inet_diag_bc_sk(bc, sk))
997 					goto next_listen;
998 
999 				if (inet_sk_diag_fill(sk, inet_csk(sk), skb,
1000 						      cb, r, NLM_F_MULTI,
1001 						      net_admin) < 0) {
1002 					spin_unlock(&ilb->lock);
1003 					goto done;
1004 				}
1005 
1006 next_listen:
1007 				++num;
1008 			}
1009 			spin_unlock(&ilb->lock);
1010 
1011 			s_num = 0;
1012 		}
1013 skip_listen_ht:
1014 		cb->args[0] = 1;
1015 		s_i = num = s_num = 0;
1016 	}
1017 
1018 	if (!(idiag_states & ~TCPF_LISTEN))
1019 		goto out;
1020 
1021 #define SKARR_SZ 16
1022 	for (i = s_i; i <= hashinfo->ehash_mask; i++) {
1023 		struct inet_ehash_bucket *head = &hashinfo->ehash[i];
1024 		spinlock_t *lock = inet_ehash_lockp(hashinfo, i);
1025 		struct hlist_nulls_node *node;
1026 		struct sock *sk_arr[SKARR_SZ];
1027 		int num_arr[SKARR_SZ];
1028 		int idx, accum, res;
1029 
1030 		if (hlist_nulls_empty(&head->chain))
1031 			continue;
1032 
1033 		if (i > s_i)
1034 			s_num = 0;
1035 
1036 next_chunk:
1037 		num = 0;
1038 		accum = 0;
1039 		spin_lock_bh(lock);
1040 		sk_nulls_for_each(sk, node, &head->chain) {
1041 			int state;
1042 
1043 			if (!net_eq(sock_net(sk), net))
1044 				continue;
1045 			if (num < s_num)
1046 				goto next_normal;
1047 			state = (sk->sk_state == TCP_TIME_WAIT) ?
1048 				inet_twsk(sk)->tw_substate : sk->sk_state;
1049 			if (!(idiag_states & (1 << state)))
1050 				goto next_normal;
1051 			if (r->sdiag_family != AF_UNSPEC &&
1052 			    sk->sk_family != r->sdiag_family)
1053 				goto next_normal;
1054 			if (r->id.idiag_sport != htons(sk->sk_num) &&
1055 			    r->id.idiag_sport)
1056 				goto next_normal;
1057 			if (r->id.idiag_dport != sk->sk_dport &&
1058 			    r->id.idiag_dport)
1059 				goto next_normal;
1060 			twsk_build_assert();
1061 
1062 			if (!inet_diag_bc_sk(bc, sk))
1063 				goto next_normal;
1064 
1065 			if (!refcount_inc_not_zero(&sk->sk_refcnt))
1066 				goto next_normal;
1067 
1068 			num_arr[accum] = num;
1069 			sk_arr[accum] = sk;
1070 			if (++accum == SKARR_SZ)
1071 				break;
1072 next_normal:
1073 			++num;
1074 		}
1075 		spin_unlock_bh(lock);
1076 		res = 0;
1077 		for (idx = 0; idx < accum; idx++) {
1078 			if (res >= 0) {
1079 				res = sk_diag_fill(sk_arr[idx], skb, cb, r,
1080 						   NLM_F_MULTI, net_admin);
1081 				if (res < 0)
1082 					num = num_arr[idx];
1083 			}
1084 			sock_gen_put(sk_arr[idx]);
1085 		}
1086 		if (res < 0)
1087 			break;
1088 		cond_resched();
1089 		if (accum == SKARR_SZ) {
1090 			s_num = num + 1;
1091 			goto next_chunk;
1092 		}
1093 	}
1094 
1095 done:
1096 	cb->args[1] = i;
1097 	cb->args[2] = num;
1098 out:
1099 	;
1100 }
1101 EXPORT_SYMBOL_GPL(inet_diag_dump_icsk);
1102 
1103 static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
1104 			    const struct inet_diag_req_v2 *r)
1105 {
1106 	const struct inet_diag_handler *handler;
1107 	u32 prev_min_dump_alloc;
1108 	int err = 0;
1109 
1110 again:
1111 	prev_min_dump_alloc = cb->min_dump_alloc;
1112 	handler = inet_diag_lock_handler(r->sdiag_protocol);
1113 	if (!IS_ERR(handler))
1114 		handler->dump(skb, cb, r);
1115 	else
1116 		err = PTR_ERR(handler);
1117 	inet_diag_unlock_handler(handler);
1118 
1119 	/* The skb is not large enough to fit one sk info and
1120 	 * inet_sk_diag_fill() has requested for a larger skb.
1121 	 */
1122 	if (!skb->len && cb->min_dump_alloc > prev_min_dump_alloc) {
1123 		err = pskb_expand_head(skb, 0, cb->min_dump_alloc, GFP_KERNEL);
1124 		if (!err)
1125 			goto again;
1126 	}
1127 
1128 	return err ? : skb->len;
1129 }
1130 
1131 static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
1132 {
1133 	return __inet_diag_dump(skb, cb, nlmsg_data(cb->nlh));
1134 }
1135 
1136 static int __inet_diag_dump_start(struct netlink_callback *cb, int hdrlen)
1137 {
1138 	const struct nlmsghdr *nlh = cb->nlh;
1139 	struct inet_diag_dump_data *cb_data;
1140 	struct sk_buff *skb = cb->skb;
1141 	struct nlattr *nla;
1142 	int rem, err;
1143 
1144 	cb_data = kzalloc(sizeof(*cb_data), GFP_KERNEL);
1145 	if (!cb_data)
1146 		return -ENOMEM;
1147 
1148 	nla_for_each_attr(nla, nlmsg_attrdata(nlh, hdrlen),
1149 			  nlmsg_attrlen(nlh, hdrlen), rem) {
1150 		int type = nla_type(nla);
1151 
1152 		if (type < __INET_DIAG_REQ_MAX)
1153 			cb_data->req_nlas[type] = nla;
1154 	}
1155 
1156 	nla = cb_data->inet_diag_nla_bc;
1157 	if (nla) {
1158 		err = inet_diag_bc_audit(nla, skb);
1159 		if (err) {
1160 			kfree(cb_data);
1161 			return err;
1162 		}
1163 	}
1164 
1165 	nla = cb_data->inet_diag_nla_bpf_stgs;
1166 	if (nla) {
1167 		struct bpf_sk_storage_diag *bpf_stg_diag;
1168 
1169 		bpf_stg_diag = bpf_sk_storage_diag_alloc(nla);
1170 		if (IS_ERR(bpf_stg_diag)) {
1171 			kfree(cb_data);
1172 			return PTR_ERR(bpf_stg_diag);
1173 		}
1174 		cb_data->bpf_stg_diag = bpf_stg_diag;
1175 	}
1176 
1177 	cb->data = cb_data;
1178 	return 0;
1179 }
1180 
1181 static int inet_diag_dump_start(struct netlink_callback *cb)
1182 {
1183 	return __inet_diag_dump_start(cb, sizeof(struct inet_diag_req_v2));
1184 }
1185 
1186 static int inet_diag_dump_start_compat(struct netlink_callback *cb)
1187 {
1188 	return __inet_diag_dump_start(cb, sizeof(struct inet_diag_req));
1189 }
1190 
1191 static int inet_diag_dump_done(struct netlink_callback *cb)
1192 {
1193 	struct inet_diag_dump_data *cb_data = cb->data;
1194 
1195 	bpf_sk_storage_diag_free(cb_data->bpf_stg_diag);
1196 	kfree(cb->data);
1197 
1198 	return 0;
1199 }
1200 
1201 static int inet_diag_type2proto(int type)
1202 {
1203 	switch (type) {
1204 	case TCPDIAG_GETSOCK:
1205 		return IPPROTO_TCP;
1206 	case DCCPDIAG_GETSOCK:
1207 		return IPPROTO_DCCP;
1208 	default:
1209 		return 0;
1210 	}
1211 }
1212 
1213 static int inet_diag_dump_compat(struct sk_buff *skb,
1214 				 struct netlink_callback *cb)
1215 {
1216 	struct inet_diag_req *rc = nlmsg_data(cb->nlh);
1217 	struct inet_diag_req_v2 req;
1218 
1219 	req.sdiag_family = AF_UNSPEC; /* compatibility */
1220 	req.sdiag_protocol = inet_diag_type2proto(cb->nlh->nlmsg_type);
1221 	req.idiag_ext = rc->idiag_ext;
1222 	req.idiag_states = rc->idiag_states;
1223 	req.id = rc->id;
1224 
1225 	return __inet_diag_dump(skb, cb, &req);
1226 }
1227 
1228 static int inet_diag_get_exact_compat(struct sk_buff *in_skb,
1229 				      const struct nlmsghdr *nlh)
1230 {
1231 	struct inet_diag_req *rc = nlmsg_data(nlh);
1232 	struct inet_diag_req_v2 req;
1233 
1234 	req.sdiag_family = rc->idiag_family;
1235 	req.sdiag_protocol = inet_diag_type2proto(nlh->nlmsg_type);
1236 	req.idiag_ext = rc->idiag_ext;
1237 	req.idiag_states = rc->idiag_states;
1238 	req.id = rc->id;
1239 
1240 	return inet_diag_cmd_exact(SOCK_DIAG_BY_FAMILY, in_skb, nlh, &req);
1241 }
1242 
1243 static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)
1244 {
1245 	int hdrlen = sizeof(struct inet_diag_req);
1246 	struct net *net = sock_net(skb->sk);
1247 
1248 	if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX ||
1249 	    nlmsg_len(nlh) < hdrlen)
1250 		return -EINVAL;
1251 
1252 	if (nlh->nlmsg_flags & NLM_F_DUMP) {
1253 		struct netlink_dump_control c = {
1254 			.start = inet_diag_dump_start_compat,
1255 			.done = inet_diag_dump_done,
1256 			.dump = inet_diag_dump_compat,
1257 		};
1258 		return netlink_dump_start(net->diag_nlsk, skb, nlh, &c);
1259 	}
1260 
1261 	return inet_diag_get_exact_compat(skb, nlh);
1262 }
1263 
1264 static int inet_diag_handler_cmd(struct sk_buff *skb, struct nlmsghdr *h)
1265 {
1266 	int hdrlen = sizeof(struct inet_diag_req_v2);
1267 	struct net *net = sock_net(skb->sk);
1268 
1269 	if (nlmsg_len(h) < hdrlen)
1270 		return -EINVAL;
1271 
1272 	if (h->nlmsg_type == SOCK_DIAG_BY_FAMILY &&
1273 	    h->nlmsg_flags & NLM_F_DUMP) {
1274 		struct netlink_dump_control c = {
1275 			.start = inet_diag_dump_start,
1276 			.done = inet_diag_dump_done,
1277 			.dump = inet_diag_dump,
1278 		};
1279 		return netlink_dump_start(net->diag_nlsk, skb, h, &c);
1280 	}
1281 
1282 	return inet_diag_cmd_exact(h->nlmsg_type, skb, h, nlmsg_data(h));
1283 }
1284 
1285 static
1286 int inet_diag_handler_get_info(struct sk_buff *skb, struct sock *sk)
1287 {
1288 	const struct inet_diag_handler *handler;
1289 	struct nlmsghdr *nlh;
1290 	struct nlattr *attr;
1291 	struct inet_diag_msg *r;
1292 	void *info = NULL;
1293 	int err = 0;
1294 
1295 	nlh = nlmsg_put(skb, 0, 0, SOCK_DIAG_BY_FAMILY, sizeof(*r), 0);
1296 	if (!nlh)
1297 		return -ENOMEM;
1298 
1299 	r = nlmsg_data(nlh);
1300 	memset(r, 0, sizeof(*r));
1301 	inet_diag_msg_common_fill(r, sk);
1302 	if (sk->sk_type == SOCK_DGRAM || sk->sk_type == SOCK_STREAM)
1303 		r->id.idiag_sport = inet_sk(sk)->inet_sport;
1304 	r->idiag_state = sk->sk_state;
1305 
1306 	if ((err = nla_put_u8(skb, INET_DIAG_PROTOCOL, sk->sk_protocol))) {
1307 		nlmsg_cancel(skb, nlh);
1308 		return err;
1309 	}
1310 
1311 	handler = inet_diag_lock_handler(sk->sk_protocol);
1312 	if (IS_ERR(handler)) {
1313 		inet_diag_unlock_handler(handler);
1314 		nlmsg_cancel(skb, nlh);
1315 		return PTR_ERR(handler);
1316 	}
1317 
1318 	attr = handler->idiag_info_size
1319 		? nla_reserve_64bit(skb, INET_DIAG_INFO,
1320 				    handler->idiag_info_size,
1321 				    INET_DIAG_PAD)
1322 		: NULL;
1323 	if (attr)
1324 		info = nla_data(attr);
1325 
1326 	handler->idiag_get_info(sk, r, info);
1327 	inet_diag_unlock_handler(handler);
1328 
1329 	nlmsg_end(skb, nlh);
1330 	return 0;
1331 }
1332 
1333 static const struct sock_diag_handler inet_diag_handler = {
1334 	.family = AF_INET,
1335 	.dump = inet_diag_handler_cmd,
1336 	.get_info = inet_diag_handler_get_info,
1337 	.destroy = inet_diag_handler_cmd,
1338 };
1339 
1340 static const struct sock_diag_handler inet6_diag_handler = {
1341 	.family = AF_INET6,
1342 	.dump = inet_diag_handler_cmd,
1343 	.get_info = inet_diag_handler_get_info,
1344 	.destroy = inet_diag_handler_cmd,
1345 };
1346 
1347 int inet_diag_register(const struct inet_diag_handler *h)
1348 {
1349 	const __u16 type = h->idiag_type;
1350 	int err = -EINVAL;
1351 
1352 	if (type >= IPPROTO_MAX)
1353 		goto out;
1354 
1355 	mutex_lock(&inet_diag_table_mutex);
1356 	err = -EEXIST;
1357 	if (!inet_diag_table[type]) {
1358 		inet_diag_table[type] = h;
1359 		err = 0;
1360 	}
1361 	mutex_unlock(&inet_diag_table_mutex);
1362 out:
1363 	return err;
1364 }
1365 EXPORT_SYMBOL_GPL(inet_diag_register);
1366 
1367 void inet_diag_unregister(const struct inet_diag_handler *h)
1368 {
1369 	const __u16 type = h->idiag_type;
1370 
1371 	if (type >= IPPROTO_MAX)
1372 		return;
1373 
1374 	mutex_lock(&inet_diag_table_mutex);
1375 	inet_diag_table[type] = NULL;
1376 	mutex_unlock(&inet_diag_table_mutex);
1377 }
1378 EXPORT_SYMBOL_GPL(inet_diag_unregister);
1379 
1380 static int __init inet_diag_init(void)
1381 {
1382 	const int inet_diag_table_size = (IPPROTO_MAX *
1383 					  sizeof(struct inet_diag_handler *));
1384 	int err = -ENOMEM;
1385 
1386 	inet_diag_table = kzalloc(inet_diag_table_size, GFP_KERNEL);
1387 	if (!inet_diag_table)
1388 		goto out;
1389 
1390 	err = sock_diag_register(&inet_diag_handler);
1391 	if (err)
1392 		goto out_free_nl;
1393 
1394 	err = sock_diag_register(&inet6_diag_handler);
1395 	if (err)
1396 		goto out_free_inet;
1397 
1398 	sock_diag_register_inet_compat(inet_diag_rcv_msg_compat);
1399 out:
1400 	return err;
1401 
1402 out_free_inet:
1403 	sock_diag_unregister(&inet_diag_handler);
1404 out_free_nl:
1405 	kfree(inet_diag_table);
1406 	goto out;
1407 }
1408 
/*
 * Module exit: unregister the AF_INET6 and AF_INET sock_diag handlers and
 * the legacy compat hook, then free the handler table allocated by
 * inet_diag_init().  The table is freed only after all three registrations
 * above have been undone.
 */
static void __exit inet_diag_exit(void)
{
	sock_diag_unregister(&inet6_diag_handler);
	sock_diag_unregister(&inet_diag_handler);
	sock_diag_unregister_inet_compat(inet_diag_rcv_msg_compat);
	kfree(inet_diag_table);
}
1416 
/* Module entry/exit points and metadata.  The aliases name NETLINK_SOCK_DIAG
 * with family 2 (AF_INET) and 10 (AF_INET6), presumably so this module can
 * be loaded on demand for those families -- confirm against sock_diag.
 */
module_init(inet_diag_init);
module_exit(inet_diag_exit);
MODULE_LICENSE("GPL");
MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2 /* AF_INET */);
MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 10 /* AF_INET6 */);
1421 MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 10 /* AF_INET6 */);
1422