1 // SPDX-License-Identifier: GPL-2.0 2 /* 3 * Management Component Transport Protocol (MCTP) 4 * 5 * Copyright (c) 2021 Code Construct 6 * Copyright (c) 2021 Google 7 */ 8 9 #include <linux/if_arp.h> 10 #include <linux/net.h> 11 #include <linux/mctp.h> 12 #include <linux/module.h> 13 #include <linux/socket.h> 14 15 #include <net/mctp.h> 16 #include <net/mctpdevice.h> 17 #include <net/sock.h> 18 19 /* socket implementation */ 20 21 static int mctp_release(struct socket *sock) 22 { 23 struct sock *sk = sock->sk; 24 25 if (sk) { 26 sock->sk = NULL; 27 sk->sk_prot->close(sk, 0); 28 } 29 30 return 0; 31 } 32 33 static int mctp_bind(struct socket *sock, struct sockaddr *addr, int addrlen) 34 { 35 struct sock *sk = sock->sk; 36 struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); 37 struct sockaddr_mctp *smctp; 38 int rc; 39 40 if (addrlen < sizeof(*smctp)) 41 return -EINVAL; 42 43 if (addr->sa_family != AF_MCTP) 44 return -EAFNOSUPPORT; 45 46 if (!capable(CAP_NET_BIND_SERVICE)) 47 return -EACCES; 48 49 /* it's a valid sockaddr for MCTP, cast and do protocol checks */ 50 smctp = (struct sockaddr_mctp *)addr; 51 52 lock_sock(sk); 53 54 /* TODO: allow rebind */ 55 if (sk_hashed(sk)) { 56 rc = -EADDRINUSE; 57 goto out_release; 58 } 59 msk->bind_net = smctp->smctp_network; 60 msk->bind_addr = smctp->smctp_addr.s_addr; 61 msk->bind_type = smctp->smctp_type & 0x7f; /* ignore the IC bit */ 62 63 rc = sk->sk_prot->hash(sk); 64 65 out_release: 66 release_sock(sk); 67 68 return rc; 69 } 70 71 static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) 72 { 73 DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name); 74 const int hlen = MCTP_HEADER_MAXLEN + sizeof(struct mctp_hdr); 75 int rc, addrlen = msg->msg_namelen; 76 struct sock *sk = sock->sk; 77 struct mctp_skb_cb *cb; 78 struct mctp_route *rt; 79 struct sk_buff *skb; 80 81 if (addr) { 82 if (addrlen < sizeof(struct sockaddr_mctp)) 83 return -EINVAL; 84 if (addr->smctp_family != AF_MCTP) 85 return -EINVAL; 86 if (addr->smctp_tag & ~(MCTP_TAG_MASK | MCTP_TAG_OWNER)) 87 return -EINVAL; 88 89 } else { 90 /* TODO: connect()ed sockets */ 91 return -EDESTADDRREQ; 92 } 93 94 if (!capable(CAP_NET_RAW)) 95 return -EACCES; 96 97 if (addr->smctp_network == MCTP_NET_ANY) 98 addr->smctp_network = mctp_default_net(sock_net(sk)); 99 100 rt = mctp_route_lookup(sock_net(sk), addr->smctp_network, 101 addr->smctp_addr.s_addr); 102 if (!rt) 103 return -EHOSTUNREACH; 104 105 skb = sock_alloc_send_skb(sk, hlen + 1 + len, 106 msg->msg_flags & MSG_DONTWAIT, &rc); 107 if (!skb) 108 return rc; 109 110 skb_reserve(skb, hlen); 111 112 /* set type as fist byte in payload */ 113 *(u8 *)skb_put(skb, 1) = addr->smctp_type; 114 115 rc = memcpy_from_msg((void *)skb_put(skb, len), msg, len); 116 if (rc < 0) { 117 kfree_skb(skb); 118 return rc; 119 } 120 121 /* set up cb */ 122 cb = __mctp_cb(skb); 123 cb->net = addr->smctp_network; 124 125 rc = mctp_local_output(sk, rt, skb, addr->smctp_addr.s_addr, 126 addr->smctp_tag); 127 128 return rc ? : len; 129 } 130 131 static int mctp_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, 132 int flags) 133 { 134 DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name); 135 struct sock *sk = sock->sk; 136 struct sk_buff *skb; 137 size_t msglen; 138 u8 type; 139 int rc; 140 141 if (flags & ~(MSG_DONTWAIT | MSG_TRUNC | MSG_PEEK)) 142 return -EOPNOTSUPP; 143 144 skb = skb_recv_datagram(sk, flags, flags & MSG_DONTWAIT, &rc); 145 if (!skb) 146 return rc; 147 148 if (!skb->len) { 149 rc = 0; 150 goto out_free; 151 } 152 153 /* extract message type, remove from data */ 154 type = *((u8 *)skb->data); 155 msglen = skb->len - 1; 156 157 if (len < msglen) 158 msg->msg_flags |= MSG_TRUNC; 159 else 160 len = msglen; 161 162 rc = skb_copy_datagram_msg(skb, 1, msg, len); 163 if (rc < 0) 164 goto out_free; 165 166 sock_recv_ts_and_drops(msg, sk, skb); 167 168 if (addr) { 169 struct mctp_skb_cb *cb = mctp_cb(skb); 170 /* TODO: expand mctp_skb_cb for header fields? */ 171 struct mctp_hdr *hdr = mctp_hdr(skb); 172 173 addr = msg->msg_name; 174 addr->smctp_family = AF_MCTP; 175 addr->smctp_network = cb->net; 176 addr->smctp_addr.s_addr = hdr->src; 177 addr->smctp_type = type; 178 addr->smctp_tag = hdr->flags_seq_tag & 179 (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO); 180 msg->msg_namelen = sizeof(*addr); 181 } 182 183 rc = len; 184 185 if (flags & MSG_TRUNC) 186 rc = msglen; 187 188 out_free: 189 skb_free_datagram(sk, skb); 190 return rc; 191 } 192 193 static int mctp_setsockopt(struct socket *sock, int level, int optname, 194 sockptr_t optval, unsigned int optlen) 195 { 196 return -EINVAL; 197 } 198 199 static int mctp_getsockopt(struct socket *sock, int level, int optname, 200 char __user *optval, int __user *optlen) 201 { 202 return -EINVAL; 203 } 204 205 static const struct proto_ops mctp_dgram_ops = { 206 .family = PF_MCTP, 207 .release = mctp_release, 208 .bind = mctp_bind, 209 .connect = sock_no_connect, 210 .socketpair = sock_no_socketpair, 211 .accept = sock_no_accept, 212 .getname = sock_no_getname, 213 .poll = datagram_poll, 214 .ioctl = sock_no_ioctl, 215 .gettstamp = sock_gettstamp, 216 .listen = sock_no_listen, 217 .shutdown = sock_no_shutdown, 218 .setsockopt = mctp_setsockopt, 219 .getsockopt = mctp_getsockopt, 220 .sendmsg = mctp_sendmsg, 221 .recvmsg = mctp_recvmsg, 222 .mmap = sock_no_mmap, 223 .sendpage = sock_no_sendpage, 224 }; 225 226 static int mctp_sk_init(struct sock *sk) 227 { 228 struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); 229 230 INIT_HLIST_HEAD(&msk->keys); 231 return 0; 232 } 233 234 static void mctp_sk_close(struct sock *sk, long timeout) 235 { 236 sk_common_release(sk); 237 } 238 239 static int mctp_sk_hash(struct sock *sk) 240 { 241 struct net *net = sock_net(sk); 242 243 mutex_lock(&net->mctp.bind_lock); 244 sk_add_node_rcu(sk, &net->mctp.binds); 245 mutex_unlock(&net->mctp.bind_lock); 246 247 return 0; 248 } 249 250 static void mctp_sk_unhash(struct sock *sk) 251 { 252 struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); 253 struct net *net = sock_net(sk); 254 struct mctp_sk_key *key; 255 struct hlist_node *tmp; 256 unsigned long flags; 257 258 /* remove from any type-based binds */ 259 mutex_lock(&net->mctp.bind_lock); 260 sk_del_node_init_rcu(sk); 261 mutex_unlock(&net->mctp.bind_lock); 262 263 /* remove tag allocations */ 264 spin_lock_irqsave(&net->mctp.keys_lock, flags); 265 hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) { 266 hlist_del_rcu(&key->sklist); 267 hlist_del_rcu(&key->hlist); 268 269 spin_lock(&key->reasm_lock); 270 if (key->reasm_head) 271 kfree_skb(key->reasm_head); 272 key->reasm_head = NULL; 273 key->reasm_dead = true; 274 spin_unlock(&key->reasm_lock); 275 276 kfree_rcu(key, rcu); 277 } 278 spin_unlock_irqrestore(&net->mctp.keys_lock, flags); 279 280 synchronize_rcu(); 281 } 282 283 static struct proto mctp_proto = { 284 .name = "MCTP", 285 .owner = THIS_MODULE, 286 .obj_size = sizeof(struct mctp_sock), 287 .init = mctp_sk_init, 288 .close = mctp_sk_close, 289 .hash = mctp_sk_hash, 290 .unhash = mctp_sk_unhash, 291 }; 292 293 static int mctp_pf_create(struct net *net, struct socket *sock, 294 int protocol, int kern) 295 { 296 const struct proto_ops *ops; 297 struct proto *proto; 298 struct sock *sk; 299 int rc; 300 301 if (protocol) 302 return -EPROTONOSUPPORT; 303 304 /* only datagram sockets are supported */ 305 if (sock->type != SOCK_DGRAM) 306 return -ESOCKTNOSUPPORT; 307 308 proto = &mctp_proto; 309 ops = &mctp_dgram_ops; 310 311 sock->state = SS_UNCONNECTED; 312 sock->ops = ops; 313 314 sk = sk_alloc(net, PF_MCTP, GFP_KERNEL, proto, kern); 315 if (!sk) 316 return -ENOMEM; 317 318 sock_init_data(sock, sk); 319 320 rc = 0; 321 if (sk->sk_prot->init) 322 rc = sk->sk_prot->init(sk); 323 324 if (rc) 325 goto err_sk_put; 326 327 return 0; 328 329 err_sk_put: 330 sock_orphan(sk); 331 sock_put(sk); 332 return rc; 333 } 334 335 static struct net_proto_family mctp_pf = { 336 .family = PF_MCTP, 337 .create = mctp_pf_create, 338 .owner = THIS_MODULE, 339 }; 340 341 static __init int mctp_init(void) 342 { 343 int rc; 344 345 /* ensure our uapi tag definitions match the header format */ 346 BUILD_BUG_ON(MCTP_TAG_OWNER != MCTP_HDR_FLAG_TO); 347 BUILD_BUG_ON(MCTP_TAG_MASK != MCTP_HDR_TAG_MASK); 348 349 pr_info("mctp: management component transport protocol core\n"); 350 351 rc = sock_register(&mctp_pf); 352 if (rc) 353 return rc; 354 355 rc = proto_register(&mctp_proto, 0); 356 if (rc) 357 goto err_unreg_sock; 358 359 rc = mctp_routes_init(); 360 if (rc) 361 goto err_unreg_proto; 362 363 rc = mctp_neigh_init(); 364 if (rc) 365 goto err_unreg_proto; 366 367 mctp_device_init(); 368 369 return 0; 370 371 err_unreg_proto: 372 proto_unregister(&mctp_proto); 373 err_unreg_sock: 374 sock_unregister(PF_MCTP); 375 376 return rc; 377 } 378 379 static __exit void mctp_exit(void) 380 { 381 mctp_device_exit(); 382 mctp_neigh_exit(); 383 mctp_routes_exit(); 384 proto_unregister(&mctp_proto); 385 sock_unregister(PF_MCTP); 386 } 387 388 module_init(mctp_init); 389 module_exit(mctp_exit); 390 391 MODULE_DESCRIPTION("MCTP core"); 392 MODULE_LICENSE("GPL v2"); 393 MODULE_AUTHOR("Jeremy Kerr <jk@codeconstruct.com.au>"); 394 395 MODULE_ALIAS_NETPROTO(PF_MCTP); 396