1 // SPDX-License-Identifier: GPL-2.0 2 /* 3 * Management Component Transport Protocol (MCTP) 4 * 5 * Copyright (c) 2021 Code Construct 6 * Copyright (c) 2021 Google 7 */ 8 9 #include <linux/compat.h> 10 #include <linux/if_arp.h> 11 #include <linux/net.h> 12 #include <linux/mctp.h> 13 #include <linux/module.h> 14 #include <linux/socket.h> 15 16 #include <net/mctp.h> 17 #include <net/mctpdevice.h> 18 #include <net/sock.h> 19 20 #define CREATE_TRACE_POINTS 21 #include <trace/events/mctp.h> 22 23 /* socket implementation */ 24 25 static void mctp_sk_expire_keys(struct timer_list *timer); 26 27 static int mctp_release(struct socket *sock) 28 { 29 struct sock *sk = sock->sk; 30 31 if (sk) { 32 sock->sk = NULL; 33 sk->sk_prot->close(sk, 0); 34 } 35 36 return 0; 37 } 38 39 /* Generic sockaddr checks, padding checks only so far */ 40 static bool mctp_sockaddr_is_ok(const struct sockaddr_mctp *addr) 41 { 42 return !addr->__smctp_pad0 && !addr->__smctp_pad1; 43 } 44 45 static bool mctp_sockaddr_ext_is_ok(const struct sockaddr_mctp_ext *addr) 46 { 47 return !addr->__smctp_pad0[0] && 48 !addr->__smctp_pad0[1] && 49 !addr->__smctp_pad0[2]; 50 } 51 52 static int mctp_bind(struct socket *sock, struct sockaddr *addr, int addrlen) 53 { 54 struct sock *sk = sock->sk; 55 struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); 56 struct sockaddr_mctp *smctp; 57 int rc; 58 59 if (addrlen < sizeof(*smctp)) 60 return -EINVAL; 61 62 if (addr->sa_family != AF_MCTP) 63 return -EAFNOSUPPORT; 64 65 if (!capable(CAP_NET_BIND_SERVICE)) 66 return -EACCES; 67 68 /* it's a valid sockaddr for MCTP, cast and do protocol checks */ 69 smctp = (struct sockaddr_mctp *)addr; 70 71 if (!mctp_sockaddr_is_ok(smctp)) 72 return -EINVAL; 73 74 lock_sock(sk); 75 76 /* TODO: allow rebind */ 77 if (sk_hashed(sk)) { 78 rc = -EADDRINUSE; 79 goto out_release; 80 } 81 msk->bind_net = smctp->smctp_network; 82 msk->bind_addr = smctp->smctp_addr.s_addr; 83 msk->bind_type = smctp->smctp_type & 0x7f; /* ignore the IC bit */ 84 85 rc = sk->sk_prot->hash(sk); 86 87 out_release: 88 release_sock(sk); 89 90 return rc; 91 } 92 93 static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) 94 { 95 DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name); 96 const int hlen = MCTP_HEADER_MAXLEN + sizeof(struct mctp_hdr); 97 int rc, addrlen = msg->msg_namelen; 98 struct sock *sk = sock->sk; 99 struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); 100 struct mctp_skb_cb *cb; 101 struct mctp_route *rt; 102 struct sk_buff *skb; 103 104 if (addr) { 105 const u8 tagbits = MCTP_TAG_MASK | MCTP_TAG_OWNER | 106 MCTP_TAG_PREALLOC; 107 108 if (addrlen < sizeof(struct sockaddr_mctp)) 109 return -EINVAL; 110 if (addr->smctp_family != AF_MCTP) 111 return -EINVAL; 112 if (!mctp_sockaddr_is_ok(addr)) 113 return -EINVAL; 114 if (addr->smctp_tag & ~tagbits) 115 return -EINVAL; 116 /* can't preallocate a non-owned tag */ 117 if (addr->smctp_tag & MCTP_TAG_PREALLOC && 118 !(addr->smctp_tag & MCTP_TAG_OWNER)) 119 return -EINVAL; 120 121 } else { 122 /* TODO: connect()ed sockets */ 123 return -EDESTADDRREQ; 124 } 125 126 if (!capable(CAP_NET_RAW)) 127 return -EACCES; 128 129 if (addr->smctp_network == MCTP_NET_ANY) 130 addr->smctp_network = mctp_default_net(sock_net(sk)); 131 132 skb = sock_alloc_send_skb(sk, hlen + 1 + len, 133 msg->msg_flags & MSG_DONTWAIT, &rc); 134 if (!skb) 135 return rc; 136 137 skb_reserve(skb, hlen); 138 139 /* set type as fist byte in payload */ 140 *(u8 *)skb_put(skb, 1) = addr->smctp_type; 141 142 rc = memcpy_from_msg((void *)skb_put(skb, len), msg, len); 143 if (rc < 0) 144 goto err_free; 145 146 /* set up cb */ 147 cb = __mctp_cb(skb); 148 cb->net = addr->smctp_network; 149 150 /* direct addressing */ 151 if (msk->addr_ext && addrlen >= sizeof(struct sockaddr_mctp_ext)) { 152 DECLARE_SOCKADDR(struct sockaddr_mctp_ext *, 153 extaddr, msg->msg_name); 154 155 if (!mctp_sockaddr_ext_is_ok(extaddr) || 156 extaddr->smctp_halen > sizeof(cb->haddr)) { 157 rc = -EINVAL; 158 goto err_free; 159 } 160 161 cb->ifindex = extaddr->smctp_ifindex; 162 cb->halen = extaddr->smctp_halen; 163 memcpy(cb->haddr, extaddr->smctp_haddr, cb->halen); 164 165 rt = NULL; 166 } else { 167 rt = mctp_route_lookup(sock_net(sk), addr->smctp_network, 168 addr->smctp_addr.s_addr); 169 if (!rt) { 170 rc = -EHOSTUNREACH; 171 goto err_free; 172 } 173 } 174 175 rc = mctp_local_output(sk, rt, skb, addr->smctp_addr.s_addr, 176 addr->smctp_tag); 177 178 return rc ? : len; 179 180 err_free: 181 kfree_skb(skb); 182 return rc; 183 } 184 185 static int mctp_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, 186 int flags) 187 { 188 DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name); 189 struct sock *sk = sock->sk; 190 struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); 191 struct sk_buff *skb; 192 size_t msglen; 193 u8 type; 194 int rc; 195 196 if (flags & ~(MSG_DONTWAIT | MSG_TRUNC | MSG_PEEK)) 197 return -EOPNOTSUPP; 198 199 skb = skb_recv_datagram(sk, flags, flags & MSG_DONTWAIT, &rc); 200 if (!skb) 201 return rc; 202 203 if (!skb->len) { 204 rc = 0; 205 goto out_free; 206 } 207 208 /* extract message type, remove from data */ 209 type = *((u8 *)skb->data); 210 msglen = skb->len - 1; 211 212 if (len < msglen) 213 msg->msg_flags |= MSG_TRUNC; 214 else 215 len = msglen; 216 217 rc = skb_copy_datagram_msg(skb, 1, msg, len); 218 if (rc < 0) 219 goto out_free; 220 221 sock_recv_ts_and_drops(msg, sk, skb); 222 223 if (addr) { 224 struct mctp_skb_cb *cb = mctp_cb(skb); 225 /* TODO: expand mctp_skb_cb for header fields? */ 226 struct mctp_hdr *hdr = mctp_hdr(skb); 227 228 addr = msg->msg_name; 229 addr->smctp_family = AF_MCTP; 230 addr->__smctp_pad0 = 0; 231 addr->smctp_network = cb->net; 232 addr->smctp_addr.s_addr = hdr->src; 233 addr->smctp_type = type; 234 addr->smctp_tag = hdr->flags_seq_tag & 235 (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO); 236 addr->__smctp_pad1 = 0; 237 msg->msg_namelen = sizeof(*addr); 238 239 if (msk->addr_ext) { 240 DECLARE_SOCKADDR(struct sockaddr_mctp_ext *, ae, 241 msg->msg_name); 242 msg->msg_namelen = sizeof(*ae); 243 ae->smctp_ifindex = cb->ifindex; 244 ae->smctp_halen = cb->halen; 245 memset(ae->__smctp_pad0, 0x0, sizeof(ae->__smctp_pad0)); 246 memset(ae->smctp_haddr, 0x0, sizeof(ae->smctp_haddr)); 247 memcpy(ae->smctp_haddr, cb->haddr, cb->halen); 248 } 249 } 250 251 rc = len; 252 253 if (flags & MSG_TRUNC) 254 rc = msglen; 255 256 out_free: 257 skb_free_datagram(sk, skb); 258 return rc; 259 } 260 261 /* We're done with the key; invalidate, stop reassembly, and remove from lists. 262 */ 263 static void __mctp_key_remove(struct mctp_sk_key *key, struct net *net, 264 unsigned long flags, unsigned long reason) 265 __releases(&key->lock) 266 __must_hold(&net->mctp.keys_lock) 267 { 268 struct sk_buff *skb; 269 270 trace_mctp_key_release(key, reason); 271 skb = key->reasm_head; 272 key->reasm_head = NULL; 273 key->reasm_dead = true; 274 key->valid = false; 275 mctp_dev_release_key(key->dev, key); 276 spin_unlock_irqrestore(&key->lock, flags); 277 278 hlist_del(&key->hlist); 279 hlist_del(&key->sklist); 280 281 /* unref for the lists */ 282 mctp_key_unref(key); 283 284 kfree_skb(skb); 285 } 286 287 static int mctp_setsockopt(struct socket *sock, int level, int optname, 288 sockptr_t optval, unsigned int optlen) 289 { 290 struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk); 291 int val; 292 293 if (level != SOL_MCTP) 294 return -EINVAL; 295 296 if (optname == MCTP_OPT_ADDR_EXT) { 297 if (optlen != sizeof(int)) 298 return -EINVAL; 299 if (copy_from_sockptr(&val, optval, sizeof(int))) 300 return -EFAULT; 301 msk->addr_ext = val; 302 return 0; 303 } 304 305 return -ENOPROTOOPT; 306 } 307 308 static int mctp_getsockopt(struct socket *sock, int level, int optname, 309 char __user *optval, int __user *optlen) 310 { 311 struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk); 312 int len, val; 313 314 if (level != SOL_MCTP) 315 return -EINVAL; 316 317 if (get_user(len, optlen)) 318 return -EFAULT; 319 320 if (optname == MCTP_OPT_ADDR_EXT) { 321 if (len != sizeof(int)) 322 return -EINVAL; 323 val = !!msk->addr_ext; 324 if (copy_to_user(optval, &val, len)) 325 return -EFAULT; 326 return 0; 327 } 328 329 return -EINVAL; 330 } 331 332 static int mctp_ioctl_alloctag(struct mctp_sock *msk, unsigned long arg) 333 { 334 struct net *net = sock_net(&msk->sk); 335 struct mctp_sk_key *key = NULL; 336 struct mctp_ioc_tag_ctl ctl; 337 unsigned long flags; 338 u8 tag; 339 340 if (copy_from_user(&ctl, (void __user *)arg, sizeof(ctl))) 341 return -EFAULT; 342 343 if (ctl.tag) 344 return -EINVAL; 345 346 if (ctl.flags) 347 return -EINVAL; 348 349 key = mctp_alloc_local_tag(msk, ctl.peer_addr, MCTP_ADDR_ANY, 350 true, &tag); 351 if (IS_ERR(key)) 352 return PTR_ERR(key); 353 354 ctl.tag = tag | MCTP_TAG_OWNER | MCTP_TAG_PREALLOC; 355 if (copy_to_user((void __user *)arg, &ctl, sizeof(ctl))) { 356 spin_lock_irqsave(&key->lock, flags); 357 __mctp_key_remove(key, net, flags, MCTP_TRACE_KEY_DROPPED); 358 mctp_key_unref(key); 359 return -EFAULT; 360 } 361 362 mctp_key_unref(key); 363 return 0; 364 } 365 366 static int mctp_ioctl_droptag(struct mctp_sock *msk, unsigned long arg) 367 { 368 struct net *net = sock_net(&msk->sk); 369 struct mctp_ioc_tag_ctl ctl; 370 unsigned long flags, fl2; 371 struct mctp_sk_key *key; 372 struct hlist_node *tmp; 373 int rc; 374 u8 tag; 375 376 if (copy_from_user(&ctl, (void __user *)arg, sizeof(ctl))) 377 return -EFAULT; 378 379 if (ctl.flags) 380 return -EINVAL; 381 382 /* Must be a local tag, TO set, preallocated */ 383 if ((ctl.tag & ~MCTP_TAG_MASK) != (MCTP_TAG_OWNER | MCTP_TAG_PREALLOC)) 384 return -EINVAL; 385 386 tag = ctl.tag & MCTP_TAG_MASK; 387 rc = -EINVAL; 388 389 spin_lock_irqsave(&net->mctp.keys_lock, flags); 390 hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) { 391 /* we do an irqsave here, even though we know the irq state, 392 * so we have the flags to pass to __mctp_key_remove 393 */ 394 spin_lock_irqsave(&key->lock, fl2); 395 if (key->manual_alloc && 396 ctl.peer_addr == key->peer_addr && 397 tag == key->tag) { 398 __mctp_key_remove(key, net, fl2, 399 MCTP_TRACE_KEY_DROPPED); 400 rc = 0; 401 } else { 402 spin_unlock_irqrestore(&key->lock, fl2); 403 } 404 } 405 spin_unlock_irqrestore(&net->mctp.keys_lock, flags); 406 407 return rc; 408 } 409 410 static int mctp_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) 411 { 412 struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk); 413 414 switch (cmd) { 415 case SIOCMCTPALLOCTAG: 416 return mctp_ioctl_alloctag(msk, arg); 417 case SIOCMCTPDROPTAG: 418 return mctp_ioctl_droptag(msk, arg); 419 } 420 421 return -EINVAL; 422 } 423 424 #ifdef CONFIG_COMPAT 425 static int mctp_compat_ioctl(struct socket *sock, unsigned int cmd, 426 unsigned long arg) 427 { 428 void __user *argp = compat_ptr(arg); 429 430 switch (cmd) { 431 /* These have compatible ptr layouts */ 432 case SIOCMCTPALLOCTAG: 433 case SIOCMCTPDROPTAG: 434 return mctp_ioctl(sock, cmd, (unsigned long)argp); 435 } 436 437 return -ENOIOCTLCMD; 438 } 439 #endif 440 441 static const struct proto_ops mctp_dgram_ops = { 442 .family = PF_MCTP, 443 .release = mctp_release, 444 .bind = mctp_bind, 445 .connect = sock_no_connect, 446 .socketpair = sock_no_socketpair, 447 .accept = sock_no_accept, 448 .getname = sock_no_getname, 449 .poll = datagram_poll, 450 .ioctl = mctp_ioctl, 451 .gettstamp = sock_gettstamp, 452 .listen = sock_no_listen, 453 .shutdown = sock_no_shutdown, 454 .setsockopt = mctp_setsockopt, 455 .getsockopt = mctp_getsockopt, 456 .sendmsg = mctp_sendmsg, 457 .recvmsg = mctp_recvmsg, 458 .mmap = sock_no_mmap, 459 .sendpage = sock_no_sendpage, 460 #ifdef CONFIG_COMPAT 461 .compat_ioctl = mctp_compat_ioctl, 462 #endif 463 }; 464 465 static void mctp_sk_expire_keys(struct timer_list *timer) 466 { 467 struct mctp_sock *msk = container_of(timer, struct mctp_sock, 468 key_expiry); 469 struct net *net = sock_net(&msk->sk); 470 unsigned long next_expiry, flags, fl2; 471 struct mctp_sk_key *key; 472 struct hlist_node *tmp; 473 bool next_expiry_valid = false; 474 475 spin_lock_irqsave(&net->mctp.keys_lock, flags); 476 477 hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) { 478 /* don't expire. manual_alloc is immutable, no locking 479 * required. 480 */ 481 if (key->manual_alloc) 482 continue; 483 484 spin_lock_irqsave(&key->lock, fl2); 485 if (!time_after_eq(key->expiry, jiffies)) { 486 __mctp_key_remove(key, net, fl2, 487 MCTP_TRACE_KEY_TIMEOUT); 488 continue; 489 } 490 491 if (next_expiry_valid) { 492 if (time_before(key->expiry, next_expiry)) 493 next_expiry = key->expiry; 494 } else { 495 next_expiry = key->expiry; 496 next_expiry_valid = true; 497 } 498 spin_unlock_irqrestore(&key->lock, fl2); 499 } 500 501 spin_unlock_irqrestore(&net->mctp.keys_lock, flags); 502 503 if (next_expiry_valid) 504 mod_timer(timer, next_expiry); 505 } 506 507 static int mctp_sk_init(struct sock *sk) 508 { 509 struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); 510 511 INIT_HLIST_HEAD(&msk->keys); 512 timer_setup(&msk->key_expiry, mctp_sk_expire_keys, 0); 513 return 0; 514 } 515 516 static void mctp_sk_close(struct sock *sk, long timeout) 517 { 518 struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); 519 520 del_timer_sync(&msk->key_expiry); 521 sk_common_release(sk); 522 } 523 524 static int mctp_sk_hash(struct sock *sk) 525 { 526 struct net *net = sock_net(sk); 527 528 mutex_lock(&net->mctp.bind_lock); 529 sk_add_node_rcu(sk, &net->mctp.binds); 530 mutex_unlock(&net->mctp.bind_lock); 531 532 return 0; 533 } 534 535 static void mctp_sk_unhash(struct sock *sk) 536 { 537 struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); 538 struct net *net = sock_net(sk); 539 unsigned long flags, fl2; 540 struct mctp_sk_key *key; 541 struct hlist_node *tmp; 542 543 /* remove from any type-based binds */ 544 mutex_lock(&net->mctp.bind_lock); 545 sk_del_node_init_rcu(sk); 546 mutex_unlock(&net->mctp.bind_lock); 547 548 /* remove tag allocations */ 549 spin_lock_irqsave(&net->mctp.keys_lock, flags); 550 hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) { 551 spin_lock_irqsave(&key->lock, fl2); 552 __mctp_key_remove(key, net, fl2, MCTP_TRACE_KEY_CLOSED); 553 } 554 spin_unlock_irqrestore(&net->mctp.keys_lock, flags); 555 } 556 557 static struct proto mctp_proto = { 558 .name = "MCTP", 559 .owner = THIS_MODULE, 560 .obj_size = sizeof(struct mctp_sock), 561 .init = mctp_sk_init, 562 .close = mctp_sk_close, 563 .hash = mctp_sk_hash, 564 .unhash = mctp_sk_unhash, 565 }; 566 567 static int mctp_pf_create(struct net *net, struct socket *sock, 568 int protocol, int kern) 569 { 570 const struct proto_ops *ops; 571 struct proto *proto; 572 struct sock *sk; 573 int rc; 574 575 if (protocol) 576 return -EPROTONOSUPPORT; 577 578 /* only datagram sockets are supported */ 579 if (sock->type != SOCK_DGRAM) 580 return -ESOCKTNOSUPPORT; 581 582 proto = &mctp_proto; 583 ops = &mctp_dgram_ops; 584 585 sock->state = SS_UNCONNECTED; 586 sock->ops = ops; 587 588 sk = sk_alloc(net, PF_MCTP, GFP_KERNEL, proto, kern); 589 if (!sk) 590 return -ENOMEM; 591 592 sock_init_data(sock, sk); 593 594 rc = 0; 595 if (sk->sk_prot->init) 596 rc = sk->sk_prot->init(sk); 597 598 if (rc) 599 goto err_sk_put; 600 601 return 0; 602 603 err_sk_put: 604 sock_orphan(sk); 605 sock_put(sk); 606 return rc; 607 } 608 609 static struct net_proto_family mctp_pf = { 610 .family = PF_MCTP, 611 .create = mctp_pf_create, 612 .owner = THIS_MODULE, 613 }; 614 615 static __init int mctp_init(void) 616 { 617 int rc; 618 619 /* ensure our uapi tag definitions match the header format */ 620 BUILD_BUG_ON(MCTP_TAG_OWNER != MCTP_HDR_FLAG_TO); 621 BUILD_BUG_ON(MCTP_TAG_MASK != MCTP_HDR_TAG_MASK); 622 623 pr_info("mctp: management component transport protocol core\n"); 624 625 rc = sock_register(&mctp_pf); 626 if (rc) 627 return rc; 628 629 rc = proto_register(&mctp_proto, 0); 630 if (rc) 631 goto err_unreg_sock; 632 633 rc = mctp_routes_init(); 634 if (rc) 635 goto err_unreg_proto; 636 637 rc = mctp_neigh_init(); 638 if (rc) 639 goto err_unreg_proto; 640 641 mctp_device_init(); 642 643 return 0; 644 645 err_unreg_proto: 646 proto_unregister(&mctp_proto); 647 err_unreg_sock: 648 sock_unregister(PF_MCTP); 649 650 return rc; 651 } 652 653 static __exit void mctp_exit(void) 654 { 655 mctp_device_exit(); 656 mctp_neigh_exit(); 657 mctp_routes_exit(); 658 proto_unregister(&mctp_proto); 659 sock_unregister(PF_MCTP); 660 } 661 662 subsys_initcall(mctp_init); 663 module_exit(mctp_exit); 664 665 MODULE_DESCRIPTION("MCTP core"); 666 MODULE_LICENSE("GPL v2"); 667 MODULE_AUTHOR("Jeremy Kerr <jk@codeconstruct.com.au>"); 668 669 MODULE_ALIAS_NETPROTO(PF_MCTP); 670