1 /* 2 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved. 3 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved. 4 * 5 * This software is available to you under a choice of one of two 6 * licenses. You may choose to be licensed under the terms of the GNU 7 * General Public License (GPL) Version 2, available from the file 8 * COPYING in the main directory of this source tree, or the 9 * OpenIB.org BSD license below: 10 * 11 * Redistribution and use in source and binary forms, with or 12 * without modification, are permitted provided that the following 13 * conditions are met: 14 * 15 * - Redistributions of source code must retain the above 16 * copyright notice, this list of conditions and the following 17 * disclaimer. 18 * 19 * - Redistributions in binary form must reproduce the above 20 * copyright notice, this list of conditions and the following 21 * disclaimer in the documentation and/or other materials 22 * provided with the distribution. 23 * 24 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 25 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 26 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 27 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS 28 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN 29 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 30 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 31 * SOFTWARE. 32 */ 33 34 #include <linux/module.h> 35 36 #include <net/tcp.h> 37 #include <net/inet_common.h> 38 #include <linux/highmem.h> 39 #include <linux/netdevice.h> 40 #include <linux/sched/signal.h> 41 42 #include <net/tls.h> 43 44 MODULE_AUTHOR("Mellanox Technologies"); 45 MODULE_DESCRIPTION("Transport Layer Security Support"); 46 MODULE_LICENSE("Dual BSD/GPL"); 47 48 enum { 49 TLSV4, 50 TLSV6, 51 TLS_NUM_PROTS, 52 }; 53 54 enum { 55 TLS_BASE, 56 TLS_SW_TX, 57 TLS_SW_RX, 58 TLS_SW_RXTX, 59 TLS_NUM_CONFIG, 60 }; 61 62 static struct proto *saved_tcpv6_prot; 63 static DEFINE_MUTEX(tcpv6_prot_mutex); 64 static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG]; 65 static struct proto_ops tls_sw_proto_ops; 66 67 static inline void update_sk_prot(struct sock *sk, struct tls_context *ctx) 68 { 69 int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4; 70 71 sk->sk_prot = &tls_prots[ip_ver][ctx->conf]; 72 } 73 74 int wait_on_pending_writer(struct sock *sk, long *timeo) 75 { 76 int rc = 0; 77 DEFINE_WAIT_FUNC(wait, woken_wake_function); 78 79 add_wait_queue(sk_sleep(sk), &wait); 80 while (1) { 81 if (!*timeo) { 82 rc = -EAGAIN; 83 break; 84 } 85 86 if (signal_pending(current)) { 87 rc = sock_intr_errno(*timeo); 88 break; 89 } 90 91 if (sk_wait_event(sk, timeo, !sk->sk_write_pending, &wait)) 92 break; 93 } 94 remove_wait_queue(sk_sleep(sk), &wait); 95 return rc; 96 } 97 98 int tls_push_sg(struct sock *sk, 99 struct tls_context *ctx, 100 struct scatterlist *sg, 101 u16 first_offset, 102 int flags) 103 { 104 int sendpage_flags = flags | MSG_SENDPAGE_NOTLAST; 105 int ret = 0; 106 struct page *p; 107 size_t size; 108 int offset = first_offset; 109 110 size = sg->length - offset; 111 offset += sg->offset; 112 113 while (1) { 114 if (sg_is_last(sg)) 115 sendpage_flags = flags; 116 117 /* is sending application-limited? */ 118 tcp_rate_check_app_limited(sk); 119 p = sg_page(sg); 120 retry: 121 ret = do_tcp_sendpages(sk, p, offset, size, sendpage_flags); 122 123 if (ret != size) { 124 if (ret > 0) { 125 offset += ret; 126 size -= ret; 127 goto retry; 128 } 129 130 offset -= sg->offset; 131 ctx->partially_sent_offset = offset; 132 ctx->partially_sent_record = (void *)sg; 133 return ret; 134 } 135 136 put_page(p); 137 sk_mem_uncharge(sk, sg->length); 138 sg = sg_next(sg); 139 if (!sg) 140 break; 141 142 offset = sg->offset; 143 size = sg->length; 144 } 145 146 clear_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags); 147 148 return 0; 149 } 150 151 static int tls_handle_open_record(struct sock *sk, int flags) 152 { 153 struct tls_context *ctx = tls_get_ctx(sk); 154 155 if (tls_is_pending_open_record(ctx)) 156 return ctx->push_pending_record(sk, flags); 157 158 return 0; 159 } 160 161 int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg, 162 unsigned char *record_type) 163 { 164 struct cmsghdr *cmsg; 165 int rc = -EINVAL; 166 167 for_each_cmsghdr(cmsg, msg) { 168 if (!CMSG_OK(msg, cmsg)) 169 return -EINVAL; 170 if (cmsg->cmsg_level != SOL_TLS) 171 continue; 172 173 switch (cmsg->cmsg_type) { 174 case TLS_SET_RECORD_TYPE: 175 if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type))) 176 return -EINVAL; 177 178 if (msg->msg_flags & MSG_MORE) 179 return -EINVAL; 180 181 rc = tls_handle_open_record(sk, msg->msg_flags); 182 if (rc) 183 return rc; 184 185 *record_type = *(unsigned char *)CMSG_DATA(cmsg); 186 rc = 0; 187 break; 188 default: 189 return -EINVAL; 190 } 191 } 192 193 return rc; 194 } 195 196 int tls_push_pending_closed_record(struct sock *sk, struct tls_context *ctx, 197 int flags, long *timeo) 198 { 199 struct scatterlist *sg; 200 u16 offset; 201 202 if (!tls_is_partially_sent_record(ctx)) 203 return ctx->push_pending_record(sk, flags); 204 205 sg = ctx->partially_sent_record; 206 offset = ctx->partially_sent_offset; 207 208 ctx->partially_sent_record = NULL; 209 return tls_push_sg(sk, ctx, sg, offset, flags); 210 } 211 212 static void tls_write_space(struct sock *sk) 213 { 214 struct tls_context *ctx = tls_get_ctx(sk); 215 216 if (!sk->sk_write_pending && tls_is_pending_closed_record(ctx)) { 217 gfp_t sk_allocation = sk->sk_allocation; 218 int rc; 219 long timeo = 0; 220 221 sk->sk_allocation = GFP_ATOMIC; 222 rc = tls_push_pending_closed_record(sk, ctx, 223 MSG_DONTWAIT | 224 MSG_NOSIGNAL, 225 &timeo); 226 sk->sk_allocation = sk_allocation; 227 228 if (rc < 0) 229 return; 230 } 231 232 ctx->sk_write_space(sk); 233 } 234 235 static void tls_sk_proto_close(struct sock *sk, long timeout) 236 { 237 struct tls_context *ctx = tls_get_ctx(sk); 238 long timeo = sock_sndtimeo(sk, 0); 239 void (*sk_proto_close)(struct sock *sk, long timeout); 240 241 lock_sock(sk); 242 sk_proto_close = ctx->sk_proto_close; 243 244 if (ctx->conf == TLS_BASE) { 245 kfree(ctx); 246 goto skip_tx_cleanup; 247 } 248 249 if (!tls_complete_pending_work(sk, ctx, 0, &timeo)) 250 tls_handle_open_record(sk, 0); 251 252 if (ctx->partially_sent_record) { 253 struct scatterlist *sg = ctx->partially_sent_record; 254 255 while (1) { 256 put_page(sg_page(sg)); 257 sk_mem_uncharge(sk, sg->length); 258 259 if (sg_is_last(sg)) 260 break; 261 sg++; 262 } 263 } 264 265 kfree(ctx->tx.rec_seq); 266 kfree(ctx->tx.iv); 267 kfree(ctx->rx.rec_seq); 268 kfree(ctx->rx.iv); 269 270 if (ctx->conf == TLS_SW_TX || 271 ctx->conf == TLS_SW_RX || 272 ctx->conf == TLS_SW_RXTX) { 273 tls_sw_free_resources(sk); 274 } 275 276 skip_tx_cleanup: 277 release_sock(sk); 278 sk_proto_close(sk, timeout); 279 } 280 281 static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval, 282 int __user *optlen) 283 { 284 int rc = 0; 285 struct tls_context *ctx = tls_get_ctx(sk); 286 struct tls_crypto_info *crypto_info; 287 int len; 288 289 if (get_user(len, optlen)) 290 return -EFAULT; 291 292 if (!optval || (len < sizeof(*crypto_info))) { 293 rc = -EINVAL; 294 goto out; 295 } 296 297 if (!ctx) { 298 rc = -EBUSY; 299 goto out; 300 } 301 302 /* get user crypto info */ 303 crypto_info = &ctx->crypto_send; 304 305 if (!TLS_CRYPTO_INFO_READY(crypto_info)) { 306 rc = -EBUSY; 307 goto out; 308 } 309 310 if (len == sizeof(*crypto_info)) { 311 if (copy_to_user(optval, crypto_info, sizeof(*crypto_info))) 312 rc = -EFAULT; 313 goto out; 314 } 315 316 switch (crypto_info->cipher_type) { 317 case TLS_CIPHER_AES_GCM_128: { 318 struct tls12_crypto_info_aes_gcm_128 * 319 crypto_info_aes_gcm_128 = 320 container_of(crypto_info, 321 struct tls12_crypto_info_aes_gcm_128, 322 info); 323 324 if (len != sizeof(*crypto_info_aes_gcm_128)) { 325 rc = -EINVAL; 326 goto out; 327 } 328 lock_sock(sk); 329 memcpy(crypto_info_aes_gcm_128->iv, 330 ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, 331 TLS_CIPHER_AES_GCM_128_IV_SIZE); 332 memcpy(crypto_info_aes_gcm_128->rec_seq, ctx->tx.rec_seq, 333 TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE); 334 release_sock(sk); 335 if (copy_to_user(optval, 336 crypto_info_aes_gcm_128, 337 sizeof(*crypto_info_aes_gcm_128))) 338 rc = -EFAULT; 339 break; 340 } 341 default: 342 rc = -EINVAL; 343 } 344 345 out: 346 return rc; 347 } 348 349 static int do_tls_getsockopt(struct sock *sk, int optname, 350 char __user *optval, int __user *optlen) 351 { 352 int rc = 0; 353 354 switch (optname) { 355 case TLS_TX: 356 rc = do_tls_getsockopt_tx(sk, optval, optlen); 357 break; 358 default: 359 rc = -ENOPROTOOPT; 360 break; 361 } 362 return rc; 363 } 364 365 static int tls_getsockopt(struct sock *sk, int level, int optname, 366 char __user *optval, int __user *optlen) 367 { 368 struct tls_context *ctx = tls_get_ctx(sk); 369 370 if (level != SOL_TLS) 371 return ctx->getsockopt(sk, level, optname, optval, optlen); 372 373 return do_tls_getsockopt(sk, optname, optval, optlen); 374 } 375 376 static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval, 377 unsigned int optlen, int tx) 378 { 379 struct tls_crypto_info *crypto_info; 380 struct tls_context *ctx = tls_get_ctx(sk); 381 int rc = 0; 382 int conf; 383 384 if (!optval || (optlen < sizeof(*crypto_info))) { 385 rc = -EINVAL; 386 goto out; 387 } 388 389 if (tx) 390 crypto_info = &ctx->crypto_send; 391 else 392 crypto_info = &ctx->crypto_recv; 393 394 /* Currently we don't support set crypto info more than one time */ 395 if (TLS_CRYPTO_INFO_READY(crypto_info)) { 396 rc = -EBUSY; 397 goto out; 398 } 399 400 rc = copy_from_user(crypto_info, optval, sizeof(*crypto_info)); 401 if (rc) { 402 rc = -EFAULT; 403 goto err_crypto_info; 404 } 405 406 /* check version */ 407 if (crypto_info->version != TLS_1_2_VERSION) { 408 rc = -ENOTSUPP; 409 goto err_crypto_info; 410 } 411 412 switch (crypto_info->cipher_type) { 413 case TLS_CIPHER_AES_GCM_128: { 414 if (optlen != sizeof(struct tls12_crypto_info_aes_gcm_128)) { 415 rc = -EINVAL; 416 goto err_crypto_info; 417 } 418 rc = copy_from_user(crypto_info + 1, optval + sizeof(*crypto_info), 419 optlen - sizeof(*crypto_info)); 420 if (rc) { 421 rc = -EFAULT; 422 goto err_crypto_info; 423 } 424 break; 425 } 426 default: 427 rc = -EINVAL; 428 goto err_crypto_info; 429 } 430 431 /* currently SW is default, we will have ethtool in future */ 432 if (tx) { 433 rc = tls_set_sw_offload(sk, ctx, 1); 434 if (ctx->conf == TLS_SW_RX) 435 conf = TLS_SW_RXTX; 436 else 437 conf = TLS_SW_TX; 438 } else { 439 rc = tls_set_sw_offload(sk, ctx, 0); 440 if (ctx->conf == TLS_SW_TX) 441 conf = TLS_SW_RXTX; 442 else 443 conf = TLS_SW_RX; 444 } 445 446 if (rc) 447 goto err_crypto_info; 448 449 ctx->conf = conf; 450 update_sk_prot(sk, ctx); 451 if (tx) { 452 ctx->sk_write_space = sk->sk_write_space; 453 sk->sk_write_space = tls_write_space; 454 } else { 455 sk->sk_socket->ops = &tls_sw_proto_ops; 456 } 457 goto out; 458 459 err_crypto_info: 460 memset(crypto_info, 0, sizeof(*crypto_info)); 461 out: 462 return rc; 463 } 464 465 static int do_tls_setsockopt(struct sock *sk, int optname, 466 char __user *optval, unsigned int optlen) 467 { 468 int rc = 0; 469 470 switch (optname) { 471 case TLS_TX: 472 case TLS_RX: 473 lock_sock(sk); 474 rc = do_tls_setsockopt_conf(sk, optval, optlen, 475 optname == TLS_TX); 476 release_sock(sk); 477 break; 478 default: 479 rc = -ENOPROTOOPT; 480 break; 481 } 482 return rc; 483 } 484 485 static int tls_setsockopt(struct sock *sk, int level, int optname, 486 char __user *optval, unsigned int optlen) 487 { 488 struct tls_context *ctx = tls_get_ctx(sk); 489 490 if (level != SOL_TLS) 491 return ctx->setsockopt(sk, level, optname, optval, optlen); 492 493 return do_tls_setsockopt(sk, optname, optval, optlen); 494 } 495 496 static void build_protos(struct proto *prot, struct proto *base) 497 { 498 prot[TLS_BASE] = *base; 499 prot[TLS_BASE].setsockopt = tls_setsockopt; 500 prot[TLS_BASE].getsockopt = tls_getsockopt; 501 prot[TLS_BASE].close = tls_sk_proto_close; 502 503 prot[TLS_SW_TX] = prot[TLS_BASE]; 504 prot[TLS_SW_TX].sendmsg = tls_sw_sendmsg; 505 prot[TLS_SW_TX].sendpage = tls_sw_sendpage; 506 507 prot[TLS_SW_RX] = prot[TLS_BASE]; 508 prot[TLS_SW_RX].recvmsg = tls_sw_recvmsg; 509 prot[TLS_SW_RX].close = tls_sk_proto_close; 510 511 prot[TLS_SW_RXTX] = prot[TLS_SW_TX]; 512 prot[TLS_SW_RXTX].recvmsg = tls_sw_recvmsg; 513 prot[TLS_SW_RXTX].close = tls_sk_proto_close; 514 } 515 516 static int tls_init(struct sock *sk) 517 { 518 int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4; 519 struct inet_connection_sock *icsk = inet_csk(sk); 520 struct tls_context *ctx; 521 int rc = 0; 522 523 /* The TLS ulp is currently supported only for TCP sockets 524 * in ESTABLISHED state. 525 * Supporting sockets in LISTEN state will require us 526 * to modify the accept implementation to clone rather then 527 * share the ulp context. 528 */ 529 if (sk->sk_state != TCP_ESTABLISHED) 530 return -ENOTSUPP; 531 532 /* allocate tls context */ 533 ctx = kzalloc(sizeof(*ctx), GFP_KERNEL); 534 if (!ctx) { 535 rc = -ENOMEM; 536 goto out; 537 } 538 icsk->icsk_ulp_data = ctx; 539 ctx->setsockopt = sk->sk_prot->setsockopt; 540 ctx->getsockopt = sk->sk_prot->getsockopt; 541 ctx->sk_proto_close = sk->sk_prot->close; 542 543 /* Build IPv6 TLS whenever the address of tcpv6_prot changes */ 544 if (ip_ver == TLSV6 && 545 unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) { 546 mutex_lock(&tcpv6_prot_mutex); 547 if (likely(sk->sk_prot != saved_tcpv6_prot)) { 548 build_protos(tls_prots[TLSV6], sk->sk_prot); 549 smp_store_release(&saved_tcpv6_prot, sk->sk_prot); 550 } 551 mutex_unlock(&tcpv6_prot_mutex); 552 } 553 554 ctx->conf = TLS_BASE; 555 update_sk_prot(sk, ctx); 556 out: 557 return rc; 558 } 559 560 static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = { 561 .name = "tls", 562 .uid = TCP_ULP_TLS, 563 .user_visible = true, 564 .owner = THIS_MODULE, 565 .init = tls_init, 566 }; 567 568 static int __init tls_register(void) 569 { 570 build_protos(tls_prots[TLSV4], &tcp_prot); 571 572 tls_sw_proto_ops = inet_stream_ops; 573 tls_sw_proto_ops.poll = tls_sw_poll; 574 tls_sw_proto_ops.splice_read = tls_sw_splice_read; 575 576 tcp_register_ulp(&tcp_tls_ulp_ops); 577 578 return 0; 579 } 580 581 static void __exit tls_unregister(void) 582 { 583 tcp_unregister_ulp(&tcp_tls_ulp_ops); 584 } 585 586 module_init(tls_register); 587 module_exit(tls_unregister); 588