// SPDX-License-Identifier: GPL-2.0-only
/* Copyright (c) 2016 Tom Herbert <tom@herbertland.com> */

#include <linux/skbuff.h>
#include <linux/workqueue.h>
#include <net/strparser.h>
#include <net/tcp.h>
#include <net/sock.h>
#include <net/tls.h>

#include "tls.h"

static struct workqueue_struct *tls_strp_wq;

static void tls_strp_abort_strp(struct tls_strparser *strp, int err)
{
	if (strp->stopped)
		return;

	strp->stopped = 1;

	/* Report an error on the lower socket */
	strp->sk->sk_err = -err;
	sk_error_report(strp->sk);
}

static void tls_strp_anchor_free(struct tls_strparser *strp)
{
	struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);

	DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1);
	shinfo->frag_list = NULL;
	consume_skb(strp->anchor);
	strp->anchor = NULL;
}

/* Create a new skb with the contents of input copied to its page frags */
static struct sk_buff *tls_strp_msg_make_copy(struct tls_strparser *strp)
{
	struct strp_msg *rxm;
	struct sk_buff *skb;
	int i, err, offset;

	skb = alloc_skb_with_frags(0, strp->stm.full_len, TLS_PAGE_ORDER,
				   &err, strp->sk->sk_allocation);
	if (!skb)
		return NULL;

	offset = strp->stm.offset;
	for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
		skb_frag_t *frag = &skb_shinfo(skb)->frags[i];

		WARN_ON_ONCE(skb_copy_bits(strp->anchor, offset,
					   skb_frag_address(frag),
					   skb_frag_size(frag)));
		offset += skb_frag_size(frag);
	}

	skb_copy_header(skb, strp->anchor);
	rxm = strp_msg(skb);
	rxm->offset = 0;
	return skb;
}

/* Steal the input skb, input msg is invalid after calling this function */
struct sk_buff *tls_strp_msg_detach(struct tls_sw_context_rx *ctx)
{
	struct tls_strparser *strp = &ctx->strp;

#ifdef CONFIG_TLS_DEVICE
	DEBUG_NET_WARN_ON_ONCE(!strp->anchor->decrypted);
#else
	/* This function turns an input into an output,
	 * that can only happen if we have offload.
	 */
	WARN_ON(1);
#endif

	if (strp->copy_mode) {
		struct sk_buff *skb;

		/* Replace anchor with an empty skb, this is a little
		 * dangerous but __tls_cur_msg() warns on empty skbs
		 * so hopefully we'll catch abuses.
		 */
		skb = alloc_skb(0, strp->sk->sk_allocation);
		if (!skb)
			return NULL;

		swap(strp->anchor, skb);
		return skb;
	}

	return tls_strp_msg_make_copy(strp);
}

/* Force the input skb to be in copy mode. The data ownership remains
 * with the input skb itself (meaning unpause will wipe it) but it can
 * be modified.
 */
int tls_strp_msg_cow(struct tls_sw_context_rx *ctx)
{
	struct tls_strparser *strp = &ctx->strp;
	struct sk_buff *skb;

	if (strp->copy_mode)
		return 0;

	skb = tls_strp_msg_make_copy(strp);
	if (!skb)
		return -ENOMEM;

	tls_strp_anchor_free(strp);
	strp->anchor = skb;

	tcp_read_done(strp->sk, strp->stm.full_len);
	strp->copy_mode = 1;

	return 0;
}
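
/* The parser operates in one of two modes. In the default (zero-copy)
 * mode the anchor skb owns no data of its own; its frag_list points at
 * the skbs still sitting on the TCP receive queue, with stm.offset and
 * stm.full_len locating the current record within that queue. In copy
 * mode (strp->copy_mode) the record is instead copied into page frags
 * owned by the anchor itself, e.g. when the queued data is too short,
 * the receive window has collapsed, or the queue holds retransmitted
 * duplicates.
 */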

/* Make a clone (in the skb sense) of the input msg to keep a reference
 * to the underlying data. The reference-holding skbs get placed on
 * @dst.
 */
int tls_strp_msg_hold(struct tls_strparser *strp, struct sk_buff_head *dst)
{
	struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);

	if (strp->copy_mode) {
		struct sk_buff *skb;

		WARN_ON_ONCE(!shinfo->nr_frags);

		/* We can't skb_clone() the anchor, it gets wiped by unpause */
		skb = alloc_skb(0, strp->sk->sk_allocation);
		if (!skb)
			return -ENOMEM;

		__skb_queue_tail(dst, strp->anchor);
		strp->anchor = skb;
	} else {
		struct sk_buff *iter, *clone;
		int chunk, len, offset;

		offset = strp->stm.offset;
		len = strp->stm.full_len;
		iter = shinfo->frag_list;

		while (len > 0) {
			if (iter->len <= offset) {
				offset -= iter->len;
				goto next;
			}

			chunk = iter->len - offset;
			offset = 0;

			clone = skb_clone(iter, strp->sk->sk_allocation);
			if (!clone)
				return -ENOMEM;
			__skb_queue_tail(dst, clone);

			len -= chunk;
next:
			iter = iter->next;
		}
	}

	return 0;
}

static void tls_strp_flush_anchor_copy(struct tls_strparser *strp)
{
	struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);
	int i;

	DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1);

	for (i = 0; i < shinfo->nr_frags; i++)
		__skb_frag_unref(&shinfo->frags[i], false);
	shinfo->nr_frags = 0;
	strp->copy_mode = 0;
}

static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
			   unsigned int offset, size_t in_len)
{
	struct tls_strparser *strp = (struct tls_strparser *)desc->arg.data;
	struct sk_buff *skb;
	skb_frag_t *frag;
	size_t len, chunk;
	int sz;

	if (strp->msg_ready)
		return 0;

	skb = strp->anchor;
	frag = &skb_shinfo(skb)->frags[skb->len / PAGE_SIZE];

	len = in_len;
	/* First make sure we got the header */
	if (!strp->stm.full_len) {
		/* Assume one page is more than enough for headers */
		chunk = min_t(size_t, len, PAGE_SIZE - skb_frag_size(frag));
		WARN_ON_ONCE(skb_copy_bits(in_skb, offset,
					   skb_frag_address(frag) +
					   skb_frag_size(frag),
					   chunk));

		sz = tls_rx_msg_size(strp, strp->anchor);
		if (sz < 0) {
			desc->error = sz;
			return 0;
		}

		/* We may have over-read, sz == 0 is guaranteed under-read */
		if (sz > 0)
			chunk = min_t(size_t, chunk, sz - skb->len);

		skb->len += chunk;
		skb->data_len += chunk;
		skb_frag_size_add(frag, chunk);
		frag++;
		len -= chunk;
		offset += chunk;

		strp->stm.full_len = sz;
		if (!strp->stm.full_len)
			goto read_done;
	}

	/* Load up more data */
	while (len && strp->stm.full_len > skb->len) {
		chunk = min_t(size_t, len, strp->stm.full_len - skb->len);
		chunk = min_t(size_t, chunk, PAGE_SIZE - skb_frag_size(frag));
		WARN_ON_ONCE(skb_copy_bits(in_skb, offset,
					   skb_frag_address(frag) +
					   skb_frag_size(frag),
					   chunk));

		skb->len += chunk;
		skb->data_len += chunk;
		skb_frag_size_add(frag, chunk);
		frag++;
		len -= chunk;
		offset += chunk;
	}

	if (strp->stm.full_len == skb->len) {
		desc->count = 0;

		strp->msg_ready = 1;
		tls_rx_msg_ready(strp);
	}

read_done:
	return in_len - len;
}

static int tls_strp_read_copyin(struct tls_strparser *strp)
{
	struct socket *sock = strp->sk->sk_socket;
	read_descriptor_t desc;

	desc.arg.data = strp;
	desc.error = 0;
	desc.count = 1; /* give more than one skb per call */

	/* sk should be locked here, so okay to do read_sock */
	sock->ops->read_sock(strp->sk, &desc, tls_strp_copyin);

	return desc.error;
}
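
/* Switch the parser to copy mode: back the anchor with freshly allocated
 * page frags (sized for a full record plus cipher overhead when the
 * record length is not yet known) and pull whatever is already queued
 * on the TCP socket into them via read_sock() / tls_strp_copyin().
 */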

static int tls_strp_read_copy(struct tls_strparser *strp, bool qshort)
{
	struct skb_shared_info *shinfo;
	struct page *page;
	int need_spc, len;

	/* If the rbuf is small or rcv window has collapsed to 0 we need
	 * to read the data out. Otherwise the connection will stall.
	 * Without pressure threshold of INT_MAX will never be ready.
	 */
	if (likely(qshort && !tcp_epollin_ready(strp->sk, INT_MAX)))
		return 0;

	shinfo = skb_shinfo(strp->anchor);
	shinfo->frag_list = NULL;

	/* If we don't know the length go max plus page for cipher overhead */
	need_spc = strp->stm.full_len ?: TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE;

	for (len = need_spc; len > 0; len -= PAGE_SIZE) {
		page = alloc_page(strp->sk->sk_allocation);
		if (!page) {
			tls_strp_flush_anchor_copy(strp);
			return -ENOMEM;
		}

		skb_fill_page_desc(strp->anchor, shinfo->nr_frags++,
				   page, 0, 0);
	}

	strp->copy_mode = 1;
	strp->stm.offset = 0;

	strp->anchor->len = 0;
	strp->anchor->data_len = 0;
	strp->anchor->truesize = round_up(need_spc, PAGE_SIZE);

	tls_strp_read_copyin(strp);

	return 0;
}

static bool tls_strp_check_no_dup(struct tls_strparser *strp)
{
	unsigned int len = strp->stm.offset + strp->stm.full_len;
	struct sk_buff *skb;
	u32 seq;

	skb = skb_shinfo(strp->anchor)->frag_list;
	seq = TCP_SKB_CB(skb)->seq;

	while (skb->len < len) {
		seq += skb->len;
		len -= skb->len;
		skb = skb->next;

		if (TCP_SKB_CB(skb)->seq != seq)
			return false;
	}

	return true;
}

static void tls_strp_load_anchor_with_queue(struct tls_strparser *strp, int len)
{
	struct tcp_sock *tp = tcp_sk(strp->sk);
	struct sk_buff *first;
	u32 offset;

	first = tcp_recv_skb(strp->sk, tp->copied_seq, &offset);
	if (WARN_ON_ONCE(!first))
		return;

	/* Bestow the state onto the anchor */
	strp->anchor->len = offset + len;
	strp->anchor->data_len = offset + len;
	strp->anchor->truesize = offset + len;

	skb_shinfo(strp->anchor)->frag_list = first;

	skb_copy_header(strp->anchor, first);
	strp->anchor->destructor = NULL;

	strp->stm.offset = offset;
}

void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh)
{
	struct strp_msg *rxm;
	struct tls_msg *tlm;

	DEBUG_NET_WARN_ON_ONCE(!strp->msg_ready);
	DEBUG_NET_WARN_ON_ONCE(!strp->stm.full_len);

	if (!strp->copy_mode && force_refresh) {
		if (WARN_ON(tcp_inq(strp->sk) < strp->stm.full_len))
			return;

		tls_strp_load_anchor_with_queue(strp, strp->stm.full_len);
	}

	rxm = strp_msg(strp->anchor);
	rxm->full_len = strp->stm.full_len;
	rxm->offset = strp->stm.offset;
	tlm = tls_msg(strp->anchor);
	tlm->control = strp->mark;
}
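
/* Read-side decision tree, roughly:
 *
 *   tls_strp_read_sock()
 *     already in copy mode              -> tls_strp_read_copyin()
 *     queued data shorter than record   -> tls_strp_read_copy(strp, true)
 *     record length still unknown       -> parse header in place, fall
 *                                          back to copy if still short
 *     queue seq numbers not contiguous  -> tls_strp_read_copy(strp, false)
 *     otherwise                         -> mark msg_ready and notify
 */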

/* Called with lock held on lower socket */
static int tls_strp_read_sock(struct tls_strparser *strp)
{
	int sz, inq;

	inq = tcp_inq(strp->sk);
	if (inq < 1)
		return 0;

	if (unlikely(strp->copy_mode))
		return tls_strp_read_copyin(strp);

	if (inq < strp->stm.full_len)
		return tls_strp_read_copy(strp, true);

	if (!strp->stm.full_len) {
		tls_strp_load_anchor_with_queue(strp, inq);

		sz = tls_rx_msg_size(strp, strp->anchor);
		if (sz < 0) {
			tls_strp_abort_strp(strp, sz);
			return sz;
		}

		strp->stm.full_len = sz;

		if (!strp->stm.full_len || inq < strp->stm.full_len)
			return tls_strp_read_copy(strp, true);
	}

	if (!tls_strp_check_no_dup(strp))
		return tls_strp_read_copy(strp, false);

	strp->msg_ready = 1;
	tls_rx_msg_ready(strp);

	return 0;
}

void tls_strp_check_rcv(struct tls_strparser *strp)
{
	if (unlikely(strp->stopped) || strp->msg_ready)
		return;

	if (tls_strp_read_sock(strp) == -ENOMEM)
		queue_work(tls_strp_wq, &strp->work);
}

/* Lower sock lock held */
void tls_strp_data_ready(struct tls_strparser *strp)
{
	/* This check is needed to synchronize with do_tls_strp_work.
	 * do_tls_strp_work acquires a process lock (lock_sock) whereas
	 * the lock held here is bh_lock_sock. The two locks can be
	 * held by different threads at the same time, but bh_lock_sock
	 * allows a thread in BH context to safely check if the process
	 * lock is held. In this case, if the lock is held, queue work.
	 */
	if (sock_owned_by_user_nocheck(strp->sk)) {
		queue_work(tls_strp_wq, &strp->work);
		return;
	}

	tls_strp_check_rcv(strp);
}

static void tls_strp_work(struct work_struct *w)
{
	struct tls_strparser *strp =
		container_of(w, struct tls_strparser, work);

	lock_sock(strp->sk);
	tls_strp_check_rcv(strp);
	release_sock(strp->sk);
}

void tls_strp_msg_done(struct tls_strparser *strp)
{
	WARN_ON(!strp->stm.full_len);

	if (likely(!strp->copy_mode))
		tcp_read_done(strp->sk, strp->stm.full_len);
	else
		tls_strp_flush_anchor_copy(strp);

	strp->msg_ready = 0;
	memset(&strp->stm, 0, sizeof(strp->stm));

	tls_strp_check_rcv(strp);
}

void tls_strp_stop(struct tls_strparser *strp)
{
	strp->stopped = 1;
}

int tls_strp_init(struct tls_strparser *strp, struct sock *sk)
{
	memset(strp, 0, sizeof(*strp));

	strp->sk = sk;

	strp->anchor = alloc_skb(0, GFP_KERNEL);
	if (!strp->anchor)
		return -ENOMEM;

	INIT_WORK(&strp->work, tls_strp_work);

	return 0;
}

/* strp must already be stopped so that tls_strp_recv will no longer be called.
 * Note that tls_strp_done is not called with the lower socket held.
 */
void tls_strp_done(struct tls_strparser *strp)
{
	WARN_ON(!strp->stopped);

	cancel_work_sync(&strp->work);
	tls_strp_anchor_free(strp);
}

int __init tls_strp_dev_init(void)
{
	tls_strp_wq = create_workqueue("tls-strp");
	if (unlikely(!tls_strp_wq))
		return -ENOMEM;

	return 0;
}

void tls_strp_dev_exit(void)
{
	destroy_workqueue(tls_strp_wq);
}