// SPDX-License-Identifier: GPL-2.0-or-later
/*
 * inet_diag.c	Module for monitoring INET transport protocols sockets.
 *
 * Authors:	Alexey Kuznetsov, <kuznet@ms2.inr.ac.ru>
 */

#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/types.h>
#include <linux/fcntl.h>
#include <linux/random.h>
#include <linux/slab.h>
#include <linux/cache.h>
#include <linux/init.h>
#include <linux/time.h>

#include <net/icmp.h>
#include <net/tcp.h>
#include <net/ipv6.h>
#include <net/inet_common.h>
#include <net/inet_connection_sock.h>
#include <net/inet_hashtables.h>
#include <net/inet_timewait_sock.h>
#include <net/inet6_hashtables.h>
#include <net/bpf_sk_storage.h>
#include <net/netlink.h>

#include <linux/inet.h>
#include <linux/stddef.h>

#include <linux/inet_diag.h>
#include <linux/sock_diag.h>

/* Per-protocol diag handlers, indexed by IPPROTO_* value.
 * Protected by inet_diag_table_mutex (see inet_diag_lock_handler()).
 */
static const struct inet_diag_handler **inet_diag_table;

/* Decoded per-socket match criteria used when running the userspace
 * filter bytecode (see inet_diag_bc_run() / inet_diag_bc_sk()).
 */
struct inet_diag_entry {
	const __be32 *saddr;	/* source address words (v4 or v6) */
	const __be32 *daddr;	/* destination address words (v4 or v6) */
	u16 sport;
	u16 dport;
	u16 family;
	u16 userlocks;
	u32 ifindex;
	u32 mark;
#ifdef CONFIG_SOCK_CGROUP_DATA
	u64 cgroup_id;
#endif
};

static DEFINE_MUTEX(inet_diag_table_mutex);

/* Look up the diag handler for @proto, attempting an on-demand module
 * load if it is not yet registered.  On return the table mutex is held
 * in ALL cases (even on error), so callers can unconditionally pair
 * this with inet_diag_unlock_handler().
 */
static const struct inet_diag_handler *inet_diag_lock_handler(int proto)
{
	if (proto < 0 || proto >= IPPROTO_MAX) {
		mutex_lock(&inet_diag_table_mutex);
		return ERR_PTR(-ENOENT);
	}

	/* Request the module before taking the mutex: module load may
	 * itself register a handler and must not deadlock on us.
	 */
	if (!inet_diag_table[proto])
		sock_load_diag_module(AF_INET, proto);

	mutex_lock(&inet_diag_table_mutex);
	if (!inet_diag_table[proto])
		return ERR_PTR(-ENOENT);

	return inet_diag_table[proto];
}

/* Drop the table mutex taken by inet_diag_lock_handler().  @handler is
 * unused; it exists so the call site reads as a lock/unlock pair.
 */
static void inet_diag_unlock_handler(const struct inet_diag_handler *handler)
{
	mutex_unlock(&inet_diag_table_mutex);
}

/* Fill the family/port/ifindex/cookie/address header of a diag reply
 * from @sk.  Works for full sockets, timewait and request sockets (only
 * fields shared by all three are read).
 */
void inet_diag_msg_common_fill(struct inet_diag_msg *r, struct sock *sk)
{
	r->idiag_family = sk->sk_family;

	r->id.idiag_sport = htons(sk->sk_num);
	r->id.idiag_dport = sk->sk_dport;
	r->id.idiag_if = sk->sk_bound_dev_if;
	sock_diag_save_cookie(sk, r->id.idiag_cookie);

#if IS_ENABLED(CONFIG_IPV6)
	if (sk->sk_family == AF_INET6) {
		*(struct in6_addr *)r->id.idiag_src = sk->sk_v6_rcv_saddr;
		*(struct in6_addr *)r->id.idiag_dst = sk->sk_v6_daddr;
	} else
#endif
	{
		/* IPv4: zero the full 128-bit slots, then store the
		 * 32-bit addresses in the first word.
		 */
		memset(&r->id.idiag_src, 0, sizeof(r->id.idiag_src));
		memset(&r->id.idiag_dst, 0, sizeof(r->id.idiag_dst));

		r->id.idiag_src[0] = sk->sk_rcv_saddr;
		r->id.idiag_dst[0] = sk->sk_daddr;
	}
}
EXPORT_SYMBOL_GPL(inet_diag_msg_common_fill);

/* Upper-bound estimate of the reply size for a single-socket (exact)
 * query, used to size the skb in inet_diag_dump_one_icsk().  The
 * trailing "+ 64" is slack, presumably for small fixed attributes not
 * itemized above.
 */
static size_t inet_sk_attr_size(struct sock *sk,
				const struct inet_diag_req_v2 *req,
				bool net_admin)
{
	const struct inet_diag_handler *handler;
	size_t aux = 0;

	handler = inet_diag_table[req->sdiag_protocol];
	if (handler && handler->idiag_get_aux_size)
		aux = handler->idiag_get_aux_size(sk, net_admin);

	return	  nla_total_size(sizeof(struct tcp_info))
		+ nla_total_size(sizeof(struct inet_diag_msg))
		+ inet_diag_msg_attrs_size()
		+ nla_total_size(sizeof(struct inet_diag_meminfo))
		+ nla_total_size(SK_MEMINFO_VARS * sizeof(u32))
		+ nla_total_size(TCP_CA_NAME_MAX)
		+ nla_total_size(sizeof(struct tcpvegas_info))
		+ aux
		+ 64;
}

/* Emit the common per-socket attributes (shutdown state, TOS/TCLASS,
 * mark, class id, cgroup id) and fill uid/inode in @r.
 * Returns 0 on success, 1 if the skb ran out of room.
 */
int inet_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb,
			     struct inet_diag_msg *r, int ext,
			     struct user_namespace *user_ns,
			     bool net_admin)
{
	const struct inet_sock *inet = inet_sk(sk);

	if (nla_put_u8(skb, INET_DIAG_SHUTDOWN, sk->sk_shutdown))
		goto errout;

	/* IPv6 dual-stack sockets use inet->tos for IPv4 connections,
	 * hence this needs to be included regardless of socket
	 * family.
	 */
	if (ext & (1 << (INET_DIAG_TOS - 1)))
		if (nla_put_u8(skb, INET_DIAG_TOS, inet->tos) < 0)
			goto errout;

#if IS_ENABLED(CONFIG_IPV6)
	if (r->idiag_family == AF_INET6) {
		if (ext & (1 << (INET_DIAG_TCLASS - 1)))
			if (nla_put_u8(skb, INET_DIAG_TCLASS,
				       inet6_sk(sk)->tclass) < 0)
				goto errout;

		/* v6only flag is only meaningful for unconnected states */
		if (((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) &&
		    nla_put_u8(skb, INET_DIAG_SKV6ONLY, ipv6_only_sock(sk)))
			goto errout;
	}
#endif

	/* sk_mark is only exposed to CAP_NET_ADMIN requesters */
	if (net_admin && nla_put_u32(skb, INET_DIAG_MARK, sk->sk_mark))
		goto errout;

	if (ext & (1 << (INET_DIAG_CLASS_ID - 1)) ||
	    ext & (1 << (INET_DIAG_TCLASS - 1))) {
		u32 classid = 0;

#ifdef CONFIG_SOCK_CGROUP_DATA
		classid = sock_cgroup_classid(&sk->sk_cgrp_data);
#endif
		/* Fallback to socket priority if class id isn't set.
		 * Classful qdiscs use it as direct reference to class.
		 * For cgroup2 classid is always zero.
		 */
		if (!classid)
			classid = sk->sk_priority;

		if (nla_put_u32(skb, INET_DIAG_CLASS_ID, classid))
			goto errout;
	}

#ifdef CONFIG_SOCK_CGROUP_DATA
	if (nla_put_u64_64bit(skb, INET_DIAG_CGROUP_ID,
			      cgroup_id(sock_cgroup_ptr(&sk->sk_cgrp_data)),
			      INET_DIAG_PAD))
		goto errout;
#endif

	r->idiag_uid = from_kuid_munged(user_ns, sock_i_uid(sk));
	r->idiag_inode = sock_i_ino(sk);

	return 0;
errout:
	return 1;
}
EXPORT_SYMBOL_GPL(inet_diag_msg_attrs_fill);

/* Collect the request-level attributes following the diag header into
 * @req_nlas, indexed by attribute type.  Rejects a malformed
 * INET_DIAG_REQ_PROTOCOL attribute (must carry exactly a u32).
 */
static int inet_diag_parse_attrs(const struct nlmsghdr *nlh, int hdrlen,
				 struct nlattr **req_nlas)
{
	struct nlattr *nla;
	int remaining;

	nlmsg_for_each_attr(nla, nlh, hdrlen, remaining) {
		int type = nla_type(nla);

		if (type == INET_DIAG_REQ_PROTOCOL && nla_len(nla) != sizeof(u32))
			return -EINVAL;

		if (type < __INET_DIAG_REQ_MAX)
			req_nlas[type] = nla;
	}
	return 0;
}

/* The protocol may be overridden by an INET_DIAG_REQ_PROTOCOL attribute;
 * otherwise it comes from the request header itself.
 */
static int inet_diag_get_protocol(const struct inet_diag_req_v2 *req,
				  const struct inet_diag_dump_data *data)
{
	if (data->req_nlas[INET_DIAG_REQ_PROTOCOL])
		return nla_get_u32(data->req_nlas[INET_DIAG_REQ_PROTOCOL]);
	return req->sdiag_protocol;
}

/* Largest skb data area we will ever ask for when retrying a dump */
#define MAX_DUMP_ALLOC_SIZE (KMALLOC_MAX_SIZE - SKB_DATA_ALIGN(sizeof(struct skb_shared_info)))

/* Build one complete diag netlink message for a full socket @sk.
 * @icsk is NULL for non-connection-oriented sockets, in which case only
 * the protocol handler's idiag_get_info() is invoked.
 * Returns 0 on success or -EMSGSIZE if @skb is full (message cancelled).
 */
int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
		      struct sk_buff *skb, struct netlink_callback *cb,
		      const struct inet_diag_req_v2 *req,
		      u16 nlmsg_flags, bool net_admin)
{
	const struct tcp_congestion_ops *ca_ops;
	const struct inet_diag_handler *handler;
	struct inet_diag_dump_data *cb_data;
	int ext = req->idiag_ext;
	struct inet_diag_msg *r;
	struct nlmsghdr *nlh;
	struct nlattr *attr;
	void *info = NULL;

	cb_data = cb->data;
	handler = inet_diag_table[inet_diag_get_protocol(req, cb_data)];
	BUG_ON(!handler);

	nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
			cb->nlh->nlmsg_type, sizeof(*r), nlmsg_flags);
	if (!nlh)
		return -EMSGSIZE;

	r = nlmsg_data(nlh);
	BUG_ON(!sk_fullsock(sk));

	inet_diag_msg_common_fill(r, sk);
	r->idiag_state = sk->sk_state;
	r->idiag_timer = 0;
	r->idiag_retrans = 0;

	if (inet_diag_msg_attrs_fill(sk, skb, r, ext,
				     sk_user_ns(NETLINK_CB(cb->skb).sk),
				     net_admin))
		goto errout;

	if (ext & (1 << (INET_DIAG_MEMINFO - 1))) {
		struct inet_diag_meminfo minfo = {
			.idiag_rmem = sk_rmem_alloc_get(sk),
			.idiag_wmem = READ_ONCE(sk->sk_wmem_queued),
			.idiag_fmem = sk->sk_forward_alloc,
			.idiag_tmem = sk_wmem_alloc_get(sk),
		};

		if (nla_put(skb, INET_DIAG_MEMINFO, sizeof(minfo), &minfo) < 0)
			goto errout;
	}

	if (ext & (1 << (INET_DIAG_SKMEMINFO - 1)))
		if (sock_diag_put_meminfo(sk, skb, INET_DIAG_SKMEMINFO))
			goto errout;

	/*
	 * RAW sockets might have user-defined protocols assigned,
	 * so report the one supplied on socket
	 * creation.
	 */
	if (sk->sk_type == SOCK_RAW) {
		if (nla_put_u8(skb, INET_DIAG_PROTOCOL, sk->sk_protocol))
			goto errout;
	}

	/* Not a connection-oriented socket: no timers/CC info to report */
	if (!icsk) {
		handler->idiag_get_info(sk, r, NULL);
		goto out;
	}

	/* Map the pending icsk timer onto the uapi idiag_timer codes:
	 * 1 = retransmit, 2 = keepalive (sk_timer), 4 = zero-window probe.
	 */
	if (icsk->icsk_pending == ICSK_TIME_RETRANS ||
	    icsk->icsk_pending == ICSK_TIME_REO_TIMEOUT ||
	    icsk->icsk_pending == ICSK_TIME_LOSS_PROBE) {
		r->idiag_timer = 1;
		r->idiag_retrans = icsk->icsk_retransmits;
		r->idiag_expires =
			jiffies_delta_to_msecs(icsk->icsk_timeout - jiffies);
	} else if (icsk->icsk_pending == ICSK_TIME_PROBE0) {
		r->idiag_timer = 4;
		r->idiag_retrans = icsk->icsk_probes_out;
		r->idiag_expires =
			jiffies_delta_to_msecs(icsk->icsk_timeout - jiffies);
	} else if (timer_pending(&sk->sk_timer)) {
		r->idiag_timer = 2;
		r->idiag_retrans = icsk->icsk_probes_out;
		r->idiag_expires =
			jiffies_delta_to_msecs(sk->sk_timer.expires - jiffies);
	} else {
		r->idiag_timer = 0;
		r->idiag_expires = 0;
	}

	/* Reserve room for protocol info now; it is filled in by
	 * handler->idiag_get_info() below.
	 */
	if ((ext & (1 << (INET_DIAG_INFO - 1))) && handler->idiag_info_size) {
		attr = nla_reserve_64bit(skb, INET_DIAG_INFO,
					 handler->idiag_info_size,
					 INET_DIAG_PAD);
		if (!attr)
			goto errout;

		info = nla_data(attr);
	}

	if (ext & (1 << (INET_DIAG_CONG - 1))) {
		int err = 0;

		/* icsk_ca_ops is RCU-protected */
		rcu_read_lock();
		ca_ops = READ_ONCE(icsk->icsk_ca_ops);
		if (ca_ops)
			err = nla_put_string(skb, INET_DIAG_CONG, ca_ops->name);
		rcu_read_unlock();
		if (err < 0)
			goto errout;
	}

	handler->idiag_get_info(sk, r, info);

	if (ext & (1 << (INET_DIAG_INFO - 1)) && handler->idiag_get_aux)
		if (handler->idiag_get_aux(sk, net_admin, skb) < 0)
			goto errout;

	if (sk->sk_state < TCP_TIME_WAIT) {
		union tcp_cc_info info;
		size_t sz = 0;
		int attr;

		rcu_read_lock();
		ca_ops = READ_ONCE(icsk->icsk_ca_ops);
		if (ca_ops && ca_ops->get_info)
			sz = ca_ops->get_info(sk, ext, &attr, &info);
		rcu_read_unlock();
		if (sz && nla_put(skb, attr, sz, &info) < 0)
			goto errout;
	}

	/* Keep it at the end for potential retry with a larger skb,
	 * or else do best-effort fitting, which is only done for the
	 * first_nlmsg.
	 */
	if (cb_data->bpf_stg_diag) {
		bool first_nlmsg = ((unsigned char *)nlh == skb->data);
		unsigned int prev_min_dump_alloc;
		unsigned int total_nla_size = 0;
		unsigned int msg_len;
		int err;

		msg_len = skb_tail_pointer(skb) - (unsigned char *)nlh;
		err = bpf_sk_storage_diag_put(cb_data->bpf_stg_diag, sk, skb,
					      INET_DIAG_SK_BPF_STORAGES,
					      &total_nla_size);

		if (!err)
			goto out;

		/* Did not fit: ask the dump machinery for a bigger skb
		 * next time around (capped at MAX_DUMP_ALLOC_SIZE).
		 */
		total_nla_size += msg_len;
		prev_min_dump_alloc = cb->min_dump_alloc;
		if (total_nla_size > prev_min_dump_alloc)
			cb->min_dump_alloc = min_t(u32, total_nla_size,
						   MAX_DUMP_ALLOC_SIZE);

		if (!first_nlmsg)
			goto errout;

		if (cb->min_dump_alloc > prev_min_dump_alloc)
			/* Retry with pskb_expand_head() with
			 * __GFP_DIRECT_RECLAIM
			 */
			goto errout;

		WARN_ON_ONCE(total_nla_size <= prev_min_dump_alloc);

		/* Send what we have for this sk
		 * and move on to the next sk in the following
		 * dump()
		 */
	}

out:
	nlmsg_end(skb, nlh);
	return 0;

errout:
	nlmsg_cancel(skb, nlh);
	return -EMSGSIZE;
}
EXPORT_SYMBOL_GPL(inet_sk_diag_fill);

/* Build a diag message for a TIME_WAIT socket.  Timewait sockets carry
 * far less state than full sockets, so most fields are zeroed.
 */
static int inet_twsk_diag_fill(struct sock *sk,
			       struct sk_buff *skb,
			       struct netlink_callback *cb,
			       u16 nlmsg_flags)
{
	struct inet_timewait_sock *tw = inet_twsk(sk);
	struct inet_diag_msg *r;
	struct nlmsghdr *nlh;
	long tmo;

	nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid,
			cb->nlh->nlmsg_seq, cb->nlh->nlmsg_type,
			sizeof(*r), nlmsg_flags);
	if (!nlh)
		return -EMSGSIZE;

	r = nlmsg_data(nlh);
	BUG_ON(tw->tw_state != TCP_TIME_WAIT);

	inet_diag_msg_common_fill(r, sk);
	r->idiag_retrans = 0;

	/* Report the substate (FIN_WAIT2 vs TIME_WAIT), timer code 3 and
	 * remaining timewait lifetime.
	 */
	r->idiag_state = tw->tw_substate;
	r->idiag_timer = 3;
	tmo = tw->tw_timer.expires - jiffies;
	r->idiag_expires = jiffies_delta_to_msecs(tmo);
	r->idiag_rqueue = 0;
	r->idiag_wqueue = 0;
	r->idiag_uid = 0;
	r->idiag_inode = 0;

	nlmsg_end(skb, nlh);
	return 0;
}

/* Build a diag message for a NEW_SYN_RECV request socket.  Reported to
 * userspace as TCP_SYN_RECV with the retransmit timer (code 1).
 */
static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb,
			      struct netlink_callback *cb,
			      u16 nlmsg_flags, bool net_admin)
{
	struct request_sock *reqsk = inet_reqsk(sk);
	struct inet_diag_msg *r;
	struct nlmsghdr *nlh;
	long tmo;

	nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
			cb->nlh->nlmsg_type, sizeof(*r), nlmsg_flags);
	if (!nlh)
		return -EMSGSIZE;

	r = nlmsg_data(nlh);
	inet_diag_msg_common_fill(r, sk);
	r->idiag_state = TCP_SYN_RECV;
	r->idiag_timer = 1;
	r->idiag_retrans = reqsk->num_retrans;

	/* The generic cookie helper reads sk_cookie; make sure the
	 * request-sock layout aliases it at the same offset.
	 */
	BUILD_BUG_ON(offsetof(struct inet_request_sock, ir_cookie) !=
		     offsetof(struct sock, sk_cookie));

	tmo = inet_reqsk(sk)->rsk_timer.expires - jiffies;
	r->idiag_expires = jiffies_delta_to_msecs(tmo);
	r->idiag_rqueue = 0;
	r->idiag_wqueue = 0;
	r->idiag_uid = 0;
	r->idiag_inode = 0;

	if (net_admin && nla_put_u32(skb, INET_DIAG_MARK,
				     inet_rsk(reqsk)->ir_mark))
		return -EMSGSIZE;

	nlmsg_end(skb, nlh);
	return 0;
}

/* Dispatch to the right fill routine based on the socket flavour
 * (timewait / request / full socket).
 */
static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
			struct netlink_callback *cb,
			const struct inet_diag_req_v2 *r,
			u16 nlmsg_flags, bool net_admin)
{
	if (sk->sk_state == TCP_TIME_WAIT)
		return inet_twsk_diag_fill(sk, skb, cb, nlmsg_flags);

	if (sk->sk_state == TCP_NEW_SYN_RECV)
		return inet_req_diag_fill(sk, skb, cb, nlmsg_flags, net_admin);

	return inet_sk_diag_fill(sk, inet_csk(sk), skb, cb, r, nlmsg_flags,
				 net_admin);
}

/* Look up the single socket described by @req in @hashinfo.
 * Returns a referenced socket on success, or ERR_PTR(-EINVAL) for an
 * unsupported family / ERR_PTR(-ENOENT) when no socket matches (the
 * cookie, if given, must match too).  Caller puts the ref with
 * sock_gen_put().
 */
struct sock *inet_diag_find_one_icsk(struct net *net,
				     struct inet_hashinfo *hashinfo,
				     const struct inet_diag_req_v2 *req)
{
	struct sock *sk;

	rcu_read_lock();
	if (req->sdiag_family == AF_INET)
		sk = inet_lookup(net, hashinfo, NULL, 0, req->id.idiag_dst[0],
				 req->id.idiag_dport, req->id.idiag_src[0],
				 req->id.idiag_sport, req->id.idiag_if);
#if IS_ENABLED(CONFIG_IPV6)
	else if (req->sdiag_family == AF_INET6) {
		/* v4-mapped addresses live in the IPv4 hash tables */
		if (ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_dst) &&
		    ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_src))
			sk = inet_lookup(net, hashinfo, NULL, 0, req->id.idiag_dst[3],
					 req->id.idiag_dport, req->id.idiag_src[3],
					 req->id.idiag_sport, req->id.idiag_if);
		else
			sk = inet6_lookup(net, hashinfo, NULL, 0,
					  (struct in6_addr *)req->id.idiag_dst,
					  req->id.idiag_dport,
					  (struct in6_addr *)req->id.idiag_src,
					  req->id.idiag_sport,
					  req->id.idiag_if);
	}
#endif
	else {
		rcu_read_unlock();
		return ERR_PTR(-EINVAL);
	}
	rcu_read_unlock();
	if (!sk)
		return ERR_PTR(-ENOENT);

	if (sock_diag_check_cookie(sk, req->id.idiag_cookie)) {
		sock_gen_put(sk);
		return ERR_PTR(-ENOENT);
	}

	return sk;
}
EXPORT_SYMBOL_GPL(inet_diag_find_one_icsk);

/* Handle an exact (non-dump) query: find the one matching socket, fill
 * a reply skb and unicast it back to the requester.
 */
int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo,
			    struct netlink_callback *cb,
			    const struct inet_diag_req_v2 *req)
{
	struct sk_buff *in_skb = cb->skb;
	bool net_admin = netlink_net_capable(in_skb, CAP_NET_ADMIN);
	struct net *net = sock_net(in_skb->sk);
	struct sk_buff *rep;
	struct sock *sk;
	int err;

	sk = inet_diag_find_one_icsk(net, hashinfo, req);
	if (IS_ERR(sk))
		return PTR_ERR(sk);

	rep = nlmsg_new(inet_sk_attr_size(sk, req, net_admin), GFP_KERNEL);
	if (!rep) {
		err = -ENOMEM;
		goto out;
	}

	err = sk_diag_fill(sk, rep, cb, req, 0, net_admin);
	if (err < 0) {
		/* rep was sized by inet_sk_attr_size(); overflowing it
		 * means the estimate is wrong.
		 */
		WARN_ON(err == -EMSGSIZE);
		nlmsg_free(rep);
		goto out;
	}
	err = netlink_unicast(net->diag_nlsk, rep,
			      NETLINK_CB(in_skb).portid,
			      MSG_DONTWAIT);
	if (err > 0)
		err = 0;

out:
	if (sk)
		sock_gen_put(sk);

	return err;
}
EXPORT_SYMBOL_GPL(inet_diag_dump_one_icsk);

/* Resolve the protocol handler and dispatch an exact-match command
 * (SOCK_DIAG_BY_FAMILY query or SOCK_DESTROY).  A minimal on-stack
 * netlink_callback carries the parsed request attributes to dump_one().
 */
static int inet_diag_cmd_exact(int cmd, struct sk_buff *in_skb,
			       const struct nlmsghdr *nlh,
			       int hdrlen,
			       const struct inet_diag_req_v2 *req)
{
	const struct inet_diag_handler *handler;
	struct inet_diag_dump_data dump_data;
	int err, protocol;

	memset(&dump_data, 0, sizeof(dump_data));
	err = inet_diag_parse_attrs(nlh, hdrlen, dump_data.req_nlas);
	if (err)
		return err;

	protocol = inet_diag_get_protocol(req, &dump_data);

	handler = inet_diag_lock_handler(protocol);
	if (IS_ERR(handler)) {
		err = PTR_ERR(handler);
	} else if (cmd == SOCK_DIAG_BY_FAMILY) {
		struct netlink_callback cb = {
			.nlh = nlh,
			.skb = in_skb,
			.data = &dump_data,
		};
		err = handler->dump_one(&cb, req);
	} else if (cmd == SOCK_DESTROY && handler->destroy) {
		err = handler->destroy(in_skb, req);
	} else {
		err = -EOPNOTSUPP;
	}
	inet_diag_unlock_handler(handler);

	return err;
}

/* Compare the leading @bits bits of two big-endian address arrays.
 * Returns 1 on match, 0 otherwise.
 */
static int bitstring_match(const __be32 *a1, const __be32 *a2, int bits)
{
	int words = bits >> 5;	/* whole 32-bit words to compare */

	bits &= 0x1f;		/* remaining bits in the last word */

	if (words) {
		if (memcmp(a1, a2, words << 2))
			return 0;
	}
	if (bits) {
		__be32 w1, w2;
		__be32 mask;

		w1 = a1[words];
		w2 = a2[words];

		mask = htonl((0xffffffff) << (32 - bits));

		if ((w1 ^ w2) & mask)
			return 0;
	}

	return 1;
}

/* Interpret the userspace filter bytecode in @_bc against @entry.
 * Each op evaluates to yes/no and jumps forward by op->yes or op->no;
 * the program accepts the socket iff it runs off the end exactly
 * (len == 0).  The bytecode has already been validated by
 * inet_diag_bc_audit(), so forward progress is guaranteed.
 */
static int inet_diag_bc_run(const struct nlattr *_bc,
			    const struct inet_diag_entry *entry)
{
	const void *bc = nla_data(_bc);
	int len = nla_len(_bc);

	while (len > 0) {
		int yes = 1;
		const struct inet_diag_bc_op *op = bc;

		switch (op->code) {
		case INET_DIAG_BC_NOP:
			break;
		case INET_DIAG_BC_JMP:
			yes = 0;
			break;
		/* Port comparisons read the operand from the follow-on
		 * op's 'no' field (see valid_port_comparison()).
		 */
		case INET_DIAG_BC_S_EQ:
			yes = entry->sport == op[1].no;
			break;
		case INET_DIAG_BC_S_GE:
			yes = entry->sport >= op[1].no;
			break;
		case INET_DIAG_BC_S_LE:
			yes = entry->sport <= op[1].no;
			break;
		case INET_DIAG_BC_D_EQ:
			yes = entry->dport == op[1].no;
			break;
		case INET_DIAG_BC_D_GE:
			yes = entry->dport >= op[1].no;
			break;
		case INET_DIAG_BC_D_LE:
			yes = entry->dport <= op[1].no;
			break;
		case INET_DIAG_BC_AUTO:
			yes = !(entry->userlocks & SOCK_BINDPORT_LOCK);
			break;
		case INET_DIAG_BC_S_COND:
		case INET_DIAG_BC_D_COND: {
			const struct inet_diag_hostcond *cond;
			const __be32 *addr;

			cond = (const struct inet_diag_hostcond *)(op + 1);
			if (cond->port != -1 &&
			    cond->port != (op->code == INET_DIAG_BC_S_COND ?
					   entry->sport : entry->dport)) {
				yes = 0;
				break;
			}

			if (op->code == INET_DIAG_BC_S_COND)
				addr = entry->saddr;
			else
				addr = entry->daddr;

			if (cond->family != AF_UNSPEC &&
			    cond->family != entry->family) {
				/* Allow an AF_INET condition to match a
				 * v4-mapped address on a v6 socket.
				 */
				if (entry->family == AF_INET6 &&
				    cond->family == AF_INET) {
					if (addr[0] == 0 && addr[1] == 0 &&
					    addr[2] == htonl(0xffff) &&
					    bitstring_match(addr + 3,
							    cond->addr,
							    cond->prefix_len))
						break;
				}
				yes = 0;
				break;
			}

			if (cond->prefix_len == 0)
				break;
			if (bitstring_match(addr, cond->addr,
					    cond->prefix_len))
				break;
			yes = 0;
			break;
		}
		case INET_DIAG_BC_DEV_COND: {
			u32 ifindex;

			ifindex = *((const u32 *)(op + 1));
			if (ifindex != entry->ifindex)
				yes = 0;
			break;
		}
		case INET_DIAG_BC_MARK_COND: {
			struct inet_diag_markcond *cond;

			cond = (struct inet_diag_markcond *)(op + 1);
			if ((entry->mark & cond->mask) != cond->mark)
				yes = 0;
			break;
		}
#ifdef CONFIG_SOCK_CGROUP_DATA
		case INET_DIAG_BC_CGROUP_COND: {
			u64 cgroup_id;

			/* operand may not be 8-byte aligned in the blob */
			cgroup_id = get_unaligned((const u64 *)(op + 1));
			if (cgroup_id != entry->cgroup_id)
				yes = 0;
			break;
		}
#endif
		}

		if (yes) {
			len -= op->yes;
			bc += op->yes;
		} else {
			len -= op->no;
			bc += op->no;
		}
	}
	return len == 0;
}

/* This helper is available for all sockets (ESTABLISH, TIMEWAIT, SYN_RECV)
 */
static void entry_fill_addrs(struct inet_diag_entry *entry,
			     const struct sock *sk)
{
#if IS_ENABLED(CONFIG_IPV6)
	if (sk->sk_family == AF_INET6) {
		entry->saddr = sk->sk_v6_rcv_saddr.s6_addr32;
		entry->daddr = sk->sk_v6_daddr.s6_addr32;
	} else
#endif
	{
		entry->saddr = &sk->sk_rcv_saddr;
		entry->daddr = &sk->sk_daddr;
	}
}

/* Run the (optional) filter bytecode @bc against @sk.
 * Returns 1 if the socket matches (or no filter given), 0 otherwise.
 * Fields unavailable on non-full sockets (userlocks, mark, cgroup id)
 * fall back to 0 / request-sock equivalents.
 */
int inet_diag_bc_sk(const struct nlattr *bc, struct sock *sk)
{
	struct inet_sock *inet = inet_sk(sk);
	struct inet_diag_entry entry;

	if (!bc)
		return 1;

	entry.family = sk->sk_family;
	entry_fill_addrs(&entry, sk);
	entry.sport = inet->inet_num;
	entry.dport = ntohs(inet->inet_dport);
	entry.ifindex = sk->sk_bound_dev_if;
	entry.userlocks = sk_fullsock(sk) ? sk->sk_userlocks : 0;
	if (sk_fullsock(sk))
		entry.mark = sk->sk_mark;
	else if (sk->sk_state == TCP_NEW_SYN_RECV)
		entry.mark = inet_rsk(inet_reqsk(sk))->ir_mark;
	else
		entry.mark = 0;
#ifdef CONFIG_SOCK_CGROUP_DATA
	entry.cgroup_id = sk_fullsock(sk) ?
			  cgroup_id(sock_cgroup_ptr(&sk->sk_cgrp_data)) : 0;
#endif

	return inet_diag_bc_run(bc, &entry);
}
EXPORT_SYMBOL_GPL(inet_diag_bc_sk);

/* Check that jump target offset @cc (counted back from the end of the
 * program) lands exactly on an op boundary reachable by walking 'yes'
 * offsets from the start.  Returns 1 if valid.
 */
static int valid_cc(const void *bc, int len, int cc)
{
	while (len >= 0) {
		const struct inet_diag_bc_op *op = bc;

		if (cc > len)
			return 0;
		if (cc == len)
			return 1;
		if (op->yes < 4 || op->yes & 3)
			return 0;
		len -= op->yes;
		bc += op->yes;
	}
	return 0;
}

/* data is u32 ifindex */
static bool valid_devcond(const struct inet_diag_bc_op *op, int len,
			  int *min_len)
{
	/* Check ifindex space. */
	*min_len += sizeof(u32);
	if (len < *min_len)
		return false;

	return true;
}

/* Validate an inet_diag_hostcond. */
static bool valid_hostcond(const struct inet_diag_bc_op *op, int len,
			   int *min_len)
{
	struct inet_diag_hostcond *cond;
	int addr_len;

	/* Check hostcond space. */
	*min_len += sizeof(struct inet_diag_hostcond);
	if (len < *min_len)
		return false;
	cond = (struct inet_diag_hostcond *)(op + 1);

	/* Check address family and address length. */
	switch (cond->family) {
	case AF_UNSPEC:
		addr_len = 0;
		break;
	case AF_INET:
		addr_len = sizeof(struct in_addr);
		break;
	case AF_INET6:
		addr_len = sizeof(struct in6_addr);
		break;
	default:
		return false;
	}
	*min_len += addr_len;
	if (len < *min_len)
		return false;

	/* Check prefix length (in bits) vs address length (in bytes). */
	if (cond->prefix_len > 8 * addr_len)
		return false;

	return true;
}

/* Validate a port comparison operator. */
static bool valid_port_comparison(const struct inet_diag_bc_op *op,
				  int len, int *min_len)
{
	/* Port comparisons put the port in a follow-on inet_diag_bc_op. */
	*min_len += sizeof(struct inet_diag_bc_op);
	if (len < *min_len)
		return false;
	return true;
}

/* Validate a mark condition: op must be followed by a markcond struct */
static bool valid_markcond(const struct inet_diag_bc_op *op, int len,
			   int *min_len)
{
	*min_len += sizeof(struct inet_diag_markcond);
	return len >= *min_len;
}

#ifdef CONFIG_SOCK_CGROUP_DATA
/* Validate a cgroup condition: op must be followed by a u64 cgroup id */
static bool valid_cgroupcond(const struct inet_diag_bc_op *op, int len,
			     int *min_len)
{
	*min_len += sizeof(u64);
	return len >= *min_len;
}
#endif

/* Fully validate filter bytecode supplied by userspace before any
 * execution: known opcodes only, per-op operand space, 4-byte-aligned
 * in-range jump offsets, and CAP_NET_ADMIN for mark conditions.
 * Returns 0 if the program is safe to run.
 */
static int inet_diag_bc_audit(const struct nlattr *attr,
			      const struct sk_buff *skb)
{
	bool net_admin = netlink_net_capable(skb, CAP_NET_ADMIN);
	const void *bytecode, *bc;
	int bytecode_len, len;

	if (!attr || nla_len(attr) < sizeof(struct inet_diag_bc_op))
		return -EINVAL;

	bytecode = bc = nla_data(attr);
	len = bytecode_len = nla_len(attr);

	while (len > 0) {
		int min_len = sizeof(struct inet_diag_bc_op);
		const struct inet_diag_bc_op *op = bc;

		switch (op->code) {
		case INET_DIAG_BC_S_COND:
		case INET_DIAG_BC_D_COND:
			if (!valid_hostcond(bc, len, &min_len))
				return -EINVAL;
			break;
		case INET_DIAG_BC_DEV_COND:
			if (!valid_devcond(bc, len, &min_len))
				return -EINVAL;
			break;
		case INET_DIAG_BC_S_EQ:
		case INET_DIAG_BC_S_GE:
		case INET_DIAG_BC_S_LE:
		case INET_DIAG_BC_D_EQ:
		case INET_DIAG_BC_D_GE:
		case INET_DIAG_BC_D_LE:
			if (!valid_port_comparison(bc, len, &min_len))
				return -EINVAL;
			break;
		case INET_DIAG_BC_MARK_COND:
			if (!net_admin)
				return -EPERM;
			if (!valid_markcond(bc, len, &min_len))
				return -EINVAL;
			break;
#ifdef CONFIG_SOCK_CGROUP_DATA
		case INET_DIAG_BC_CGROUP_COND:
			if (!valid_cgroupcond(bc, len, &min_len))
				return -EINVAL;
			break;
#endif
		case INET_DIAG_BC_AUTO:
		case INET_DIAG_BC_JMP:
		case INET_DIAG_BC_NOP:
			break;
		default:
			return -EINVAL;
		}

		if (op->code != INET_DIAG_BC_NOP) {
			if (op->no < min_len || op->no > len + 4 || op->no & 3)
				return -EINVAL;
			if (op->no < len &&
			    !valid_cc(bytecode, bytecode_len, len - op->no))
				return -EINVAL;
		}

		if (op->yes < min_len || op->yes > len + 4 || op->yes & 3)
			return -EINVAL;
		bc += op->yes;
		len -= op->yes;
	}
	return len == 0 ? 0 : -EINVAL;
}

/* inet_diag_dump_icsk() reads inet_sock/timewait fields through the
 * same offsets; these compile-time asserts pin that layout aliasing.
 */
static void twsk_build_assert(void)
{
	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_family) !=
		     offsetof(struct sock, sk_family));

	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_num) !=
		     offsetof(struct inet_sock, inet_num));

	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_dport) !=
		     offsetof(struct inet_sock, inet_dport));

	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_rcv_saddr) !=
		     offsetof(struct inet_sock, inet_rcv_saddr));

	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_daddr) !=
		     offsetof(struct inet_sock, inet_daddr));

#if IS_ENABLED(CONFIG_IPV6)
	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_rcv_saddr) !=
		     offsetof(struct sock, sk_v6_rcv_saddr));

	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_daddr) !=
		     offsetof(struct sock, sk_v6_daddr));
#endif
}

/* Walk the listening and established hash tables and emit one diag
 * message per matching socket.  Resumable: cb->args[0] tracks which
 * table phase we are in, cb->args[1]/args[2] the bucket and in-bucket
 * position to resume from.
 */
void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
			 struct netlink_callback *cb,
			 const struct inet_diag_req_v2 *r)
{
	bool net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN);
	struct inet_diag_dump_data *cb_data = cb->data;
	struct net *net = sock_net(skb->sk);
	u32 idiag_states = r->idiag_states;
	int i, num, s_i, s_num;
	struct nlattr *bc;
	struct sock *sk;

	bc = cb_data->inet_diag_nla_bc;
	/* Userspace asking for SYN_RECV also gets NEW_SYN_RECV reqs */
	if (idiag_states & TCPF_SYN_RECV)
		idiag_states |= TCPF_NEW_SYN_RECV;
	s_i = cb->args[1];
	s_num = num = cb->args[2];

	if (cb->args[0] == 0) {
		/* Listeners have no dport; skip phase 0 if it can't match */
		if (!(idiag_states & TCPF_LISTEN) || r->id.idiag_dport)
			goto skip_listen_ht;

		for (i = s_i; i < INET_LHTABLE_SIZE; i++) {
			struct inet_listen_hashbucket *ilb;
			struct hlist_nulls_node *node;

			num = 0;
			ilb = &hashinfo->listening_hash[i];
			spin_lock(&ilb->lock);
			sk_nulls_for_each(sk, node, &ilb->nulls_head) {
				struct inet_sock *inet = inet_sk(sk);

				if (!net_eq(sock_net(sk), net))
					continue;

				/* skip entries already sent in an
				 * earlier dump() round
				 */
				if (num < s_num) {
					num++;
					continue;
				}

				if (r->sdiag_family != AF_UNSPEC &&
				    sk->sk_family != r->sdiag_family)
					goto next_listen;

				if (r->id.idiag_sport != inet->inet_sport &&
				    r->id.idiag_sport)
					goto next_listen;

				if (!inet_diag_bc_sk(bc, sk))
					goto next_listen;

				if (inet_sk_diag_fill(sk, inet_csk(sk), skb,
						      cb, r, NLM_F_MULTI,
						      net_admin) < 0) {
					/* skb full: record position and
					 * resume here next round
					 */
					spin_unlock(&ilb->lock);
					goto done;
				}

next_listen:
				++num;
			}
			spin_unlock(&ilb->lock);

			s_num = 0;
		}
skip_listen_ht:
		cb->args[0] = 1;
		s_i = num = s_num = 0;
	}

	if (!(idiag_states & ~TCPF_LISTEN))
		goto out;

/* Batch size: sockets are collected under the bucket lock, then filled
 * outside it so inet_sk_diag_fill() can sleep/allocate.
 */
#define SKARR_SZ 16
	for (i = s_i; i <= hashinfo->ehash_mask; i++) {
		struct inet_ehash_bucket *head = &hashinfo->ehash[i];
		spinlock_t *lock = inet_ehash_lockp(hashinfo, i);
		struct hlist_nulls_node *node;
		struct sock *sk_arr[SKARR_SZ];
		int num_arr[SKARR_SZ];
		int idx, accum, res;

		if (hlist_nulls_empty(&head->chain))
			continue;

		if (i > s_i)
			s_num = 0;

next_chunk:
		num = 0;
		accum = 0;
		spin_lock_bh(lock);
		sk_nulls_for_each(sk, node, &head->chain) {
			int state;

			if (!net_eq(sock_net(sk), net))
				continue;
			if (num < s_num)
				goto next_normal;
			/* timewait sockets report their substate */
			state = (sk->sk_state == TCP_TIME_WAIT) ?
				inet_twsk(sk)->tw_substate : sk->sk_state;
			if (!(idiag_states & (1 << state)))
				goto next_normal;
			if (r->sdiag_family != AF_UNSPEC &&
			    sk->sk_family != r->sdiag_family)
				goto next_normal;
			if (r->id.idiag_sport != htons(sk->sk_num) &&
			    r->id.idiag_sport)
				goto next_normal;
			if (r->id.idiag_dport != sk->sk_dport &&
			    r->id.idiag_dport)
				goto next_normal;
			twsk_build_assert();

			if (!inet_diag_bc_sk(bc, sk))
				goto next_normal;

			/* hold a ref so the sk survives unlocking */
			if (!refcount_inc_not_zero(&sk->sk_refcnt))
				goto next_normal;

			num_arr[accum] = num;
			sk_arr[accum] = sk;
			if (++accum == SKARR_SZ)
				break;
next_normal:
			++num;
		}
		spin_unlock_bh(lock);
		res = 0;
		for (idx = 0; idx < accum; idx++) {
			if (res >= 0) {
				res = sk_diag_fill(sk_arr[idx], skb, cb, r,
						   NLM_F_MULTI, net_admin);
				if (res < 0)
					num = num_arr[idx];
			}
			/* drop refs even after a fill failure */
			sock_gen_put(sk_arr[idx]);
		}
		if (res < 0)
			break;
		cond_resched();
		if (accum == SKARR_SZ) {
			s_num = num + 1;
			goto next_chunk;
		}
	}

done:
	cb->args[1] = i;
	cb->args[2] = num;
out:
	;
}
EXPORT_SYMBOL_GPL(inet_diag_dump_icsk);

/* Run one dump round via the protocol handler, retrying once with a
 * larger skb if the first socket's message could not fit.
 */
static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
			    const struct inet_diag_req_v2 *r)
{
	struct inet_diag_dump_data *cb_data = cb->data;
	const struct inet_diag_handler *handler;
	u32 prev_min_dump_alloc;
	int protocol, err = 0;

	protocol = inet_diag_get_protocol(r, cb_data);

again:
	prev_min_dump_alloc = cb->min_dump_alloc;
	handler = inet_diag_lock_handler(protocol);
	if (!IS_ERR(handler))
		handler->dump(skb, cb, r);
	else
		err = PTR_ERR(handler);
	inet_diag_unlock_handler(handler);

	/* The skb is not large enough to fit one sk info and
	 * inet_sk_diag_fill() has requested for a larger skb.
	 */
	if (!skb->len && cb->min_dump_alloc > prev_min_dump_alloc) {
		err = pskb_expand_head(skb, 0, cb->min_dump_alloc, GFP_KERNEL);
		if (!err)
			goto again;
	}

	return err ? : skb->len;
}

/* .dump callback for v2 requests: the request struct follows the
 * netlink header directly.
 */
static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
{
	return __inet_diag_dump(skb, cb, nlmsg_data(cb->nlh));
}

/* Common .start callback: parse and audit request attributes once, and
 * stash them (plus an optional bpf_sk_storage diag context) in cb->data
 * for the duration of the dump.  Freed in inet_diag_dump_done().
 */
static int __inet_diag_dump_start(struct netlink_callback *cb, int hdrlen)
{
	const struct nlmsghdr *nlh = cb->nlh;
	struct inet_diag_dump_data *cb_data;
	struct sk_buff *skb = cb->skb;
	struct nlattr *nla;
	int err;

	cb_data = kzalloc(sizeof(*cb_data), GFP_KERNEL);
	if (!cb_data)
		return -ENOMEM;

	err = inet_diag_parse_attrs(nlh, hdrlen, cb_data->req_nlas);
	if (err) {
		kfree(cb_data);
		return err;
	}
	/* validate filter bytecode up front, before any dump round runs */
	nla = cb_data->inet_diag_nla_bc;
	if (nla) {
		err = inet_diag_bc_audit(nla, skb);
		if (err) {
			kfree(cb_data);
			return err;
		}
	}

	nla = cb_data->inet_diag_nla_bpf_stgs;
	if (nla) {
		struct bpf_sk_storage_diag *bpf_stg_diag;

		bpf_stg_diag = bpf_sk_storage_diag_alloc(nla);
		if (IS_ERR(bpf_stg_diag)) {
			kfree(cb_data);
			return PTR_ERR(bpf_stg_diag);
		}
		cb_data->bpf_stg_diag = bpf_stg_diag;
	}

	cb->data = cb_data;
	return 0;
}

static int inet_diag_dump_start(struct netlink_callback *cb)
{
	return __inet_diag_dump_start(cb, sizeof(struct inet_diag_req_v2));
}

static int inet_diag_dump_start_compat(struct netlink_callback *cb)
{
	return __inet_diag_dump_start(cb, sizeof(struct inet_diag_req));
}

/* .done callback: release the per-dump state set up in .start */
static int inet_diag_dump_done(struct netlink_callback *cb)
{
	struct inet_diag_dump_data *cb_data = cb->data;

	bpf_sk_storage_diag_free(cb_data->bpf_stg_diag);
	kfree(cb->data);

	return 0;
}

/* Map a legacy (v1) diag message type to an IPPROTO_* value;
 * 0 for unknown types.
 */
static int inet_diag_type2proto(int type)
{
	switch (type) {
	case TCPDIAG_GETSOCK:
		return IPPROTO_TCP;
	case DCCPDIAG_GETSOCK:
		return IPPROTO_DCCP;
	default:
		return 0;
	}
}

/* Translate a legacy inet_diag_req into a v2 request and dump */
static int inet_diag_dump_compat(struct sk_buff *skb,
				 struct netlink_callback *cb)
{
	struct inet_diag_req *rc = nlmsg_data(cb->nlh);
	struct inet_diag_req_v2 req;

	req.sdiag_family = AF_UNSPEC; /* compatibility */
	req.sdiag_protocol = inet_diag_type2proto(cb->nlh->nlmsg_type);
	req.idiag_ext = rc->idiag_ext;
	req.idiag_states = rc->idiag_states;
	req.id = rc->id;

	return __inet_diag_dump(skb, cb, &req);
}

/* Translate a legacy exact-match request into a v2 request */
static int inet_diag_get_exact_compat(struct sk_buff *in_skb,
				      const struct nlmsghdr *nlh)
{
	struct inet_diag_req *rc = nlmsg_data(nlh);
	struct inet_diag_req_v2 req;

	req.sdiag_family = rc->idiag_family;
	req.sdiag_protocol = inet_diag_type2proto(nlh->nlmsg_type);
	req.idiag_ext = rc->idiag_ext;
	req.idiag_states = rc->idiag_states;
	req.id = rc->id;

	return inet_diag_cmd_exact(SOCK_DIAG_BY_FAMILY, in_skb, nlh,
				   sizeof(struct inet_diag_req), &req);
}

/* Entry point for legacy (TCPDIAG/DCCPDIAG) netlink messages */
static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)
{
	int hdrlen = sizeof(struct inet_diag_req);
	struct net *net = sock_net(skb->sk);

	if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX ||
	    nlmsg_len(nlh) < hdrlen)
		return -EINVAL;

	if (nlh->nlmsg_flags & NLM_F_DUMP) {
		struct netlink_dump_control c = {
			.start = inet_diag_dump_start_compat,
			.done = inet_diag_dump_done,
			.dump = inet_diag_dump_compat,
		};
		return netlink_dump_start(net->diag_nlsk, skb, nlh, &c);
	}

	return inet_diag_get_exact_compat(skb, nlh);
}

/* Entry point for SOCK_DIAG_BY_FAMILY / SOCK_DESTROY commands
 * (definition continues beyond this chunk)
 */
static int inet_diag_handler_cmd(struct sk_buff *skb, struct nlmsghdr *h)
{
	int hdrlen = sizeof(struct inet_diag_req_v2);
	struct
net *net = sock_net(skb->sk); 1308 1309 if (nlmsg_len(h) < hdrlen) 1310 return -EINVAL; 1311 1312 if (h->nlmsg_type == SOCK_DIAG_BY_FAMILY && 1313 h->nlmsg_flags & NLM_F_DUMP) { 1314 struct netlink_dump_control c = { 1315 .start = inet_diag_dump_start, 1316 .done = inet_diag_dump_done, 1317 .dump = inet_diag_dump, 1318 }; 1319 return netlink_dump_start(net->diag_nlsk, skb, h, &c); 1320 } 1321 1322 return inet_diag_cmd_exact(h->nlmsg_type, skb, h, hdrlen, 1323 nlmsg_data(h)); 1324 } 1325 1326 static 1327 int inet_diag_handler_get_info(struct sk_buff *skb, struct sock *sk) 1328 { 1329 const struct inet_diag_handler *handler; 1330 struct nlmsghdr *nlh; 1331 struct nlattr *attr; 1332 struct inet_diag_msg *r; 1333 void *info = NULL; 1334 int err = 0; 1335 1336 nlh = nlmsg_put(skb, 0, 0, SOCK_DIAG_BY_FAMILY, sizeof(*r), 0); 1337 if (!nlh) 1338 return -ENOMEM; 1339 1340 r = nlmsg_data(nlh); 1341 memset(r, 0, sizeof(*r)); 1342 inet_diag_msg_common_fill(r, sk); 1343 if (sk->sk_type == SOCK_DGRAM || sk->sk_type == SOCK_STREAM) 1344 r->id.idiag_sport = inet_sk(sk)->inet_sport; 1345 r->idiag_state = sk->sk_state; 1346 1347 if ((err = nla_put_u8(skb, INET_DIAG_PROTOCOL, sk->sk_protocol))) { 1348 nlmsg_cancel(skb, nlh); 1349 return err; 1350 } 1351 1352 handler = inet_diag_lock_handler(sk->sk_protocol); 1353 if (IS_ERR(handler)) { 1354 inet_diag_unlock_handler(handler); 1355 nlmsg_cancel(skb, nlh); 1356 return PTR_ERR(handler); 1357 } 1358 1359 attr = handler->idiag_info_size 1360 ? 
 nla_reserve_64bit(skb, INET_DIAG_INFO,
				   handler->idiag_info_size,
				   INET_DIAG_PAD)
		: NULL;
	if (attr)
		info = nla_data(attr);

	/* Let the handler fill the reserved info area (info is NULL when
	 * nothing was reserved or the reservation failed).
	 */
	handler->idiag_get_info(sk, r, info);
	inet_diag_unlock_handler(handler);

	nlmsg_end(skb, nlh);
	return 0;
}

/* sock_diag dispatch tables: IPv4 and IPv6 share the same callbacks, and
 * inet_diag_handler_cmd serves both the dump and destroy entry points.
 */
static const struct sock_diag_handler inet_diag_handler = {
	.family = AF_INET,
	.dump = inet_diag_handler_cmd,
	.get_info = inet_diag_handler_get_info,
	.destroy = inet_diag_handler_cmd,
};

static const struct sock_diag_handler inet6_diag_handler = {
	.family = AF_INET6,
	.dump = inet_diag_handler_cmd,
	.get_info = inet_diag_handler_get_info,
	.destroy = inet_diag_handler_cmd,
};

/* Register a per-protocol diag handler in inet_diag_table.
 * Returns -EINVAL for an out-of-range protocol, -EEXIST if the slot is
 * already taken, 0 on success.
 */
int inet_diag_register(const struct inet_diag_handler *h)
{
	const __u16 type = h->idiag_type;
	int err = -EINVAL;

	if (type >= IPPROTO_MAX)
		goto out;

	mutex_lock(&inet_diag_table_mutex);
	err = -EEXIST;
	if (!inet_diag_table[type]) {
		inet_diag_table[type] = h;
		err = 0;
	}
	mutex_unlock(&inet_diag_table_mutex);
out:
	return err;
}
EXPORT_SYMBOL_GPL(inet_diag_register);

/* Remove a per-protocol diag handler; out-of-range types are ignored. */
void inet_diag_unregister(const struct inet_diag_handler *h)
{
	const __u16 type = h->idiag_type;

	if (type >= IPPROTO_MAX)
		return;

	mutex_lock(&inet_diag_table_mutex);
	inet_diag_table[type] = NULL;
	mutex_unlock(&inet_diag_table_mutex);
}
EXPORT_SYMBOL_GPL(inet_diag_unregister);

/* Module init: allocate the (zeroed) protocol handler table, then hook the
 * AF_INET, AF_INET6 and legacy-compat entry points into sock_diag.  Each
 * error label unwinds exactly the steps that succeeded before it.
 */
static int __init inet_diag_init(void)
{
	const int inet_diag_table_size = (IPPROTO_MAX *
					 sizeof(struct inet_diag_handler *));
	int err = -ENOMEM;

	inet_diag_table = kzalloc(inet_diag_table_size, GFP_KERNEL);
	if (!inet_diag_table)
		goto out;

	err = sock_diag_register(&inet_diag_handler);
	if (err)
		goto out_free_nl;

	err = sock_diag_register(&inet6_diag_handler);
	if (err)
		goto out_free_inet;

	sock_diag_register_inet_compat(inet_diag_rcv_msg_compat);
out:
	return err;

out_free_inet:
	/* IPv6 registration failed: undo the IPv4 registration too. */
	sock_diag_unregister(&inet_diag_handler);
out_free_nl:
	kfree(inet_diag_table);
	goto out;
}

/* Module exit: unhook everything registered by inet_diag_init(), in
 * reverse order, then free the handler table.
 */
static void __exit inet_diag_exit(void)
{
	sock_diag_unregister(&inet6_diag_handler);
	sock_diag_unregister(&inet_diag_handler);
	sock_diag_unregister_inet_compat(inet_diag_rcv_msg_compat);
	kfree(inet_diag_table);
}

module_init(inet_diag_init);
module_exit(inet_diag_exit);
MODULE_LICENSE("GPL");
MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2 /* AF_INET */);
MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 10 /* AF_INET6 */);