1 /* 2 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved. 3 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved. 4 * 5 * This software is available to you under a choice of one of two 6 * licenses. You may choose to be licensed under the terms of the GNU 7 * General Public License (GPL) Version 2, available from the file 8 * COPYING in the main directory of this source tree, or the 9 * OpenIB.org BSD license below: 10 * 11 * Redistribution and use in source and binary forms, with or 12 * without modification, are permitted provided that the following 13 * conditions are met: 14 * 15 * - Redistributions of source code must retain the above 16 * copyright notice, this list of conditions and the following 17 * disclaimer. 18 * 19 * - Redistributions in binary form must reproduce the above 20 * copyright notice, this list of conditions and the following 21 * disclaimer in the documentation and/or other materials 22 * provided with the distribution. 23 * 24 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 25 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 26 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 27 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS 28 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN 29 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 30 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 31 * SOFTWARE. 32 */ 33 34 #include <linux/module.h> 35 36 #include <net/tcp.h> 37 #include <net/inet_common.h> 38 #include <linux/highmem.h> 39 #include <linux/netdevice.h> 40 #include <linux/sched/signal.h> 41 42 #include <net/tls.h> 43 44 MODULE_AUTHOR("Mellanox Technologies"); 45 MODULE_DESCRIPTION("Transport Layer Security Support"); 46 MODULE_LICENSE("Dual BSD/GPL"); 47 48 static struct proto tls_base_prot; 49 static struct proto tls_sw_prot; 50 51 int wait_on_pending_writer(struct sock *sk, long *timeo) 52 { 53 int rc = 0; 54 DEFINE_WAIT_FUNC(wait, woken_wake_function); 55 56 add_wait_queue(sk_sleep(sk), &wait); 57 while (1) { 58 if (!*timeo) { 59 rc = -EAGAIN; 60 break; 61 } 62 63 if (signal_pending(current)) { 64 rc = sock_intr_errno(*timeo); 65 break; 66 } 67 68 if (sk_wait_event(sk, timeo, !sk->sk_write_pending, &wait)) 69 break; 70 } 71 remove_wait_queue(sk_sleep(sk), &wait); 72 return rc; 73 } 74 75 int tls_push_sg(struct sock *sk, 76 struct tls_context *ctx, 77 struct scatterlist *sg, 78 u16 first_offset, 79 int flags) 80 { 81 int sendpage_flags = flags | MSG_SENDPAGE_NOTLAST; 82 int ret = 0; 83 struct page *p; 84 size_t size; 85 int offset = first_offset; 86 87 size = sg->length - offset; 88 offset += sg->offset; 89 90 while (1) { 91 if (sg_is_last(sg)) 92 sendpage_flags = flags; 93 94 /* is sending application-limited? */ 95 tcp_rate_check_app_limited(sk); 96 p = sg_page(sg); 97 retry: 98 ret = do_tcp_sendpages(sk, p, offset, size, sendpage_flags); 99 100 if (ret != size) { 101 if (ret > 0) { 102 offset += ret; 103 size -= ret; 104 goto retry; 105 } 106 107 offset -= sg->offset; 108 ctx->partially_sent_offset = offset; 109 ctx->partially_sent_record = (void *)sg; 110 return ret; 111 } 112 113 put_page(p); 114 sk_mem_uncharge(sk, sg->length); 115 sg = sg_next(sg); 116 if (!sg) 117 break; 118 119 offset = sg->offset; 120 size = sg->length; 121 } 122 123 clear_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags); 124 125 return 0; 126 } 127 128 static int tls_handle_open_record(struct sock *sk, int flags) 129 { 130 struct tls_context *ctx = tls_get_ctx(sk); 131 132 if (tls_is_pending_open_record(ctx)) 133 return ctx->push_pending_record(sk, flags); 134 135 return 0; 136 } 137 138 int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg, 139 unsigned char *record_type) 140 { 141 struct cmsghdr *cmsg; 142 int rc = -EINVAL; 143 144 for_each_cmsghdr(cmsg, msg) { 145 if (!CMSG_OK(msg, cmsg)) 146 return -EINVAL; 147 if (cmsg->cmsg_level != SOL_TLS) 148 continue; 149 150 switch (cmsg->cmsg_type) { 151 case TLS_SET_RECORD_TYPE: 152 if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type))) 153 return -EINVAL; 154 155 if (msg->msg_flags & MSG_MORE) 156 return -EINVAL; 157 158 rc = tls_handle_open_record(sk, msg->msg_flags); 159 if (rc) 160 return rc; 161 162 *record_type = *(unsigned char *)CMSG_DATA(cmsg); 163 rc = 0; 164 break; 165 default: 166 return -EINVAL; 167 } 168 } 169 170 return rc; 171 } 172 173 int tls_push_pending_closed_record(struct sock *sk, struct tls_context *ctx, 174 int flags, long *timeo) 175 { 176 struct scatterlist *sg; 177 u16 offset; 178 179 if (!tls_is_partially_sent_record(ctx)) 180 return ctx->push_pending_record(sk, flags); 181 182 sg = ctx->partially_sent_record; 183 offset = ctx->partially_sent_offset; 184 185 ctx->partially_sent_record = NULL; 186 return tls_push_sg(sk, ctx, sg, offset, flags); 187 } 188 189 static void tls_write_space(struct sock *sk) 190 { 191 struct tls_context *ctx = tls_get_ctx(sk); 192 193 if (!sk->sk_write_pending && tls_is_pending_closed_record(ctx)) { 194 gfp_t sk_allocation = sk->sk_allocation; 195 int rc; 196 long timeo = 0; 197 198 sk->sk_allocation = GFP_ATOMIC; 199 rc = tls_push_pending_closed_record(sk, ctx, 200 MSG_DONTWAIT | 201 MSG_NOSIGNAL, 202 &timeo); 203 sk->sk_allocation = sk_allocation; 204 205 if (rc < 0) 206 return; 207 } 208 209 ctx->sk_write_space(sk); 210 } 211 212 static void tls_sk_proto_close(struct sock *sk, long timeout) 213 { 214 struct tls_context *ctx = tls_get_ctx(sk); 215 long timeo = sock_sndtimeo(sk, 0); 216 void (*sk_proto_close)(struct sock *sk, long timeout); 217 218 lock_sock(sk); 219 220 if (!tls_complete_pending_work(sk, ctx, 0, &timeo)) 221 tls_handle_open_record(sk, 0); 222 223 if (ctx->partially_sent_record) { 224 struct scatterlist *sg = ctx->partially_sent_record; 225 226 while (1) { 227 put_page(sg_page(sg)); 228 sk_mem_uncharge(sk, sg->length); 229 230 if (sg_is_last(sg)) 231 break; 232 sg++; 233 } 234 } 235 ctx->free_resources(sk); 236 kfree(ctx->rec_seq); 237 kfree(ctx->iv); 238 239 sk_proto_close = ctx->sk_proto_close; 240 kfree(ctx); 241 242 release_sock(sk); 243 sk_proto_close(sk, timeout); 244 } 245 246 static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval, 247 int __user *optlen) 248 { 249 int rc = 0; 250 struct tls_context *ctx = tls_get_ctx(sk); 251 struct tls_crypto_info *crypto_info; 252 int len; 253 254 if (get_user(len, optlen)) 255 return -EFAULT; 256 257 if (!optval || (len < sizeof(*crypto_info))) { 258 rc = -EINVAL; 259 goto out; 260 } 261 262 if (!ctx) { 263 rc = -EBUSY; 264 goto out; 265 } 266 267 /* get user crypto info */ 268 crypto_info = &ctx->crypto_send; 269 270 if (!TLS_CRYPTO_INFO_READY(crypto_info)) { 271 rc = -EBUSY; 272 goto out; 273 } 274 275 if (len == sizeof(*crypto_info)) { 276 if (copy_to_user(optval, crypto_info, sizeof(*crypto_info))) 277 rc = -EFAULT; 278 goto out; 279 } 280 281 switch (crypto_info->cipher_type) { 282 case TLS_CIPHER_AES_GCM_128: { 283 struct tls12_crypto_info_aes_gcm_128 * 284 crypto_info_aes_gcm_128 = 285 container_of(crypto_info, 286 struct tls12_crypto_info_aes_gcm_128, 287 info); 288 289 if (len != sizeof(*crypto_info_aes_gcm_128)) { 290 rc = -EINVAL; 291 goto out; 292 } 293 lock_sock(sk); 294 memcpy(crypto_info_aes_gcm_128->iv, ctx->iv, 295 TLS_CIPHER_AES_GCM_128_IV_SIZE); 296 release_sock(sk); 297 if (copy_to_user(optval, 298 crypto_info_aes_gcm_128, 299 sizeof(*crypto_info_aes_gcm_128))) 300 rc = -EFAULT; 301 break; 302 } 303 default: 304 rc = -EINVAL; 305 } 306 307 out: 308 return rc; 309 } 310 311 static int do_tls_getsockopt(struct sock *sk, int optname, 312 char __user *optval, int __user *optlen) 313 { 314 int rc = 0; 315 316 switch (optname) { 317 case TLS_TX: 318 rc = do_tls_getsockopt_tx(sk, optval, optlen); 319 break; 320 default: 321 rc = -ENOPROTOOPT; 322 break; 323 } 324 return rc; 325 } 326 327 static int tls_getsockopt(struct sock *sk, int level, int optname, 328 char __user *optval, int __user *optlen) 329 { 330 struct tls_context *ctx = tls_get_ctx(sk); 331 332 if (level != SOL_TLS) 333 return ctx->getsockopt(sk, level, optname, optval, optlen); 334 335 return do_tls_getsockopt(sk, optname, optval, optlen); 336 } 337 338 static int do_tls_setsockopt_tx(struct sock *sk, char __user *optval, 339 unsigned int optlen) 340 { 341 struct tls_crypto_info *crypto_info, tmp_crypto_info; 342 struct tls_context *ctx = tls_get_ctx(sk); 343 struct proto *prot = NULL; 344 int rc = 0; 345 346 if (!optval || (optlen < sizeof(*crypto_info))) { 347 rc = -EINVAL; 348 goto out; 349 } 350 351 rc = copy_from_user(&tmp_crypto_info, optval, sizeof(*crypto_info)); 352 if (rc) { 353 rc = -EFAULT; 354 goto out; 355 } 356 357 /* check version */ 358 if (tmp_crypto_info.version != TLS_1_2_VERSION) { 359 rc = -ENOTSUPP; 360 goto out; 361 } 362 363 /* get user crypto info */ 364 crypto_info = &ctx->crypto_send; 365 366 /* Currently we don't support set crypto info more than one time */ 367 if (TLS_CRYPTO_INFO_READY(crypto_info)) 368 goto out; 369 370 switch (tmp_crypto_info.cipher_type) { 371 case TLS_CIPHER_AES_GCM_128: { 372 if (optlen != sizeof(struct tls12_crypto_info_aes_gcm_128)) { 373 rc = -EINVAL; 374 goto out; 375 } 376 rc = copy_from_user( 377 crypto_info, 378 optval, 379 sizeof(struct tls12_crypto_info_aes_gcm_128)); 380 381 if (rc) { 382 rc = -EFAULT; 383 goto err_crypto_info; 384 } 385 break; 386 } 387 default: 388 rc = -EINVAL; 389 goto out; 390 } 391 392 ctx->sk_write_space = sk->sk_write_space; 393 sk->sk_write_space = tls_write_space; 394 395 ctx->sk_proto_close = sk->sk_prot->close; 396 397 /* currently SW is default, we will have ethtool in future */ 398 rc = tls_set_sw_offload(sk, ctx); 399 prot = &tls_sw_prot; 400 if (rc) 401 goto err_crypto_info; 402 403 sk->sk_prot = prot; 404 goto out; 405 406 err_crypto_info: 407 memset(crypto_info, 0, sizeof(*crypto_info)); 408 out: 409 return rc; 410 } 411 412 static int do_tls_setsockopt(struct sock *sk, int optname, 413 char __user *optval, unsigned int optlen) 414 { 415 int rc = 0; 416 417 switch (optname) { 418 case TLS_TX: 419 lock_sock(sk); 420 rc = do_tls_setsockopt_tx(sk, optval, optlen); 421 release_sock(sk); 422 break; 423 default: 424 rc = -ENOPROTOOPT; 425 break; 426 } 427 return rc; 428 } 429 430 static int tls_setsockopt(struct sock *sk, int level, int optname, 431 char __user *optval, unsigned int optlen) 432 { 433 struct tls_context *ctx = tls_get_ctx(sk); 434 435 if (level != SOL_TLS) 436 return ctx->setsockopt(sk, level, optname, optval, optlen); 437 438 return do_tls_setsockopt(sk, optname, optval, optlen); 439 } 440 441 static int tls_init(struct sock *sk) 442 { 443 struct inet_connection_sock *icsk = inet_csk(sk); 444 struct tls_context *ctx; 445 int rc = 0; 446 447 /* allocate tls context */ 448 ctx = kzalloc(sizeof(*ctx), GFP_KERNEL); 449 if (!ctx) { 450 rc = -ENOMEM; 451 goto out; 452 } 453 icsk->icsk_ulp_data = ctx; 454 ctx->setsockopt = sk->sk_prot->setsockopt; 455 ctx->getsockopt = sk->sk_prot->getsockopt; 456 sk->sk_prot = &tls_base_prot; 457 out: 458 return rc; 459 } 460 461 static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = { 462 .name = "tls", 463 .owner = THIS_MODULE, 464 .init = tls_init, 465 }; 466 467 static int __init tls_register(void) 468 { 469 tls_base_prot = tcp_prot; 470 tls_base_prot.setsockopt = tls_setsockopt; 471 tls_base_prot.getsockopt = tls_getsockopt; 472 473 tls_sw_prot = tls_base_prot; 474 tls_sw_prot.sendmsg = tls_sw_sendmsg; 475 tls_sw_prot.sendpage = tls_sw_sendpage; 476 tls_sw_prot.close = tls_sk_proto_close; 477 478 tcp_register_ulp(&tcp_tls_ulp_ops); 479 480 return 0; 481 } 482 483 static void __exit tls_unregister(void) 484 { 485 tcp_unregister_ulp(&tcp_tls_ulp_ops); 486 } 487 488 module_init(tls_register); 489 module_exit(tls_unregister); 490