// SPDX-License-Identifier: GPL-2.0-only
/* Copyright (c) 2016 Tom Herbert <tom@herbertland.com> */

#include <linux/skbuff.h>
#include <linux/workqueue.h>
#include <net/strparser.h>
#include <net/tcp.h>
#include <net/sock.h>
#include <net/tls.h>

#include "tls.h"

static struct workqueue_struct *tls_strp_wq;

static void tls_strp_abort_strp(struct tls_strparser *strp, int err)
{
	if (strp->stopped)
		return;

	strp->stopped = 1;

	/* Report an error on the lower socket */
	strp->sk->sk_err = -err;
	sk_error_report(strp->sk);
}

static void tls_strp_anchor_free(struct tls_strparser *strp)
{
	struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);

	DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1);
	shinfo->frag_list = NULL;
	consume_skb(strp->anchor);
	strp->anchor = NULL;
}

/* Create a new skb with the contents of input copied to its page frags */
static struct sk_buff *tls_strp_msg_make_copy(struct tls_strparser *strp)
{
	struct strp_msg *rxm;
	struct sk_buff *skb;
	int i, err, offset;

	skb = alloc_skb_with_frags(0, strp->anchor->len, TLS_PAGE_ORDER,
				   &err, strp->sk->sk_allocation);
	if (!skb)
		return NULL;

	offset = strp->stm.offset;
	for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
		skb_frag_t *frag = &skb_shinfo(skb)->frags[i];

		WARN_ON_ONCE(skb_copy_bits(strp->anchor, offset,
					   skb_frag_address(frag),
					   skb_frag_size(frag)));
		offset += skb_frag_size(frag);
	}

	skb_copy_header(skb, strp->anchor);
	rxm = strp_msg(skb);
	rxm->offset = 0;
	return skb;
}

/* Steal the input skb; the input msg is invalid after calling this function */
struct sk_buff *tls_strp_msg_detach(struct tls_sw_context_rx *ctx)
{
	struct tls_strparser *strp = &ctx->strp;

#ifdef CONFIG_TLS_DEVICE
	DEBUG_NET_WARN_ON_ONCE(!strp->anchor->decrypted);
#else
	/* This function turns an input into an output;
	 * that can only happen if we have offload.
	 */
	WARN_ON(1);
#endif

	if (strp->copy_mode) {
		struct sk_buff *skb;

		/* Replace the anchor with an empty skb; this is a little
		 * dangerous, but __tls_cur_msg() warns on empty skbs
		 * so hopefully we'll catch abuses.
		 */
		skb = alloc_skb(0, strp->sk->sk_allocation);
		if (!skb)
			return NULL;

		swap(strp->anchor, skb);
		return skb;
	}

	return tls_strp_msg_make_copy(strp);
}

/* Force the input skb to be in copy mode. The data ownership remains
 * with the input skb itself (meaning unpause will wipe it) but it can
 * be modified.
 */
int tls_strp_msg_cow(struct tls_sw_context_rx *ctx)
{
	struct tls_strparser *strp = &ctx->strp;
	struct sk_buff *skb;

	if (strp->copy_mode)
		return 0;

	skb = tls_strp_msg_make_copy(strp);
	if (!skb)
		return -ENOMEM;

	tls_strp_anchor_free(strp);
	strp->anchor = skb;

	tcp_read_done(strp->sk, strp->stm.full_len);
	strp->copy_mode = 1;

	return 0;
}

/* Make a clone (in the skb sense) of the input msg to keep a reference
 * to the underlying data. The reference-holding skbs get placed on
 * @dst.
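 * Since clones share their data with the skbs they were cloned from,
 * the record payload stays alive even after the originals have been
 * released from the TCP receive queue.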
 */
int tls_strp_msg_hold(struct tls_strparser *strp, struct sk_buff_head *dst)
{
	struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);

	if (strp->copy_mode) {
		struct sk_buff *skb;

		WARN_ON_ONCE(!shinfo->nr_frags);

		/* We can't skb_clone() the anchor, it gets wiped by unpause */
		skb = alloc_skb(0, strp->sk->sk_allocation);
		if (!skb)
			return -ENOMEM;

		__skb_queue_tail(dst, strp->anchor);
		strp->anchor = skb;
	} else {
		struct sk_buff *iter, *clone;
		int chunk, len, offset;

		offset = strp->stm.offset;
		len = strp->stm.full_len;
		iter = shinfo->frag_list;

		while (len > 0) {
			if (iter->len <= offset) {
				offset -= iter->len;
				goto next;
			}

			chunk = iter->len - offset;
			offset = 0;

			clone = skb_clone(iter, strp->sk->sk_allocation);
			if (!clone)
				return -ENOMEM;
			__skb_queue_tail(dst, clone);

			len -= chunk;
next:
			iter = iter->next;
		}
	}

	return 0;
}

/* Drop the copied data: release the anchor's page frags and leave
 * copy mode.
 */
static void tls_strp_flush_anchor_copy(struct tls_strparser *strp)
{
	struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);
	int i;

	DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1);

	for (i = 0; i < shinfo->nr_frags; i++)
		__skb_frag_unref(&shinfo->frags[i], false);
	shinfo->nr_frags = 0;
	strp->copy_mode = 0;
}

/* read_descriptor actor: copy data from @in_skb into the anchor's
 * preallocated page frags, learning the record length from the header
 * along the way.
 */
static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
			   unsigned int offset, size_t in_len)
{
	struct tls_strparser *strp = (struct tls_strparser *)desc->arg.data;
	struct sk_buff *skb;
	skb_frag_t *frag;
	size_t len, chunk;
	int sz;

	if (strp->msg_ready)
		return 0;

	skb = strp->anchor;
	frag = &skb_shinfo(skb)->frags[skb->len / PAGE_SIZE];

	len = in_len;
	/* First make sure we got the header */
	if (!strp->stm.full_len) {
		/* Assume one page is more than enough for headers */
		chunk = min_t(size_t, len, PAGE_SIZE - skb_frag_size(frag));
		WARN_ON_ONCE(skb_copy_bits(in_skb, offset,
					   skb_frag_address(frag) +
					   skb_frag_size(frag),
					   chunk));

		sz = tls_rx_msg_size(strp, strp->anchor);
		if (sz < 0) {
			desc->error = sz;
			return 0;
		}

		/* We may have over-read; sz == 0 means a guaranteed under-read */
		if (sz > 0)
			chunk = min_t(size_t, chunk, sz - skb->len);

		skb->len += chunk;
		skb->data_len += chunk;
		skb_frag_size_add(frag, chunk);
		frag++;
		len -= chunk;
		offset += chunk;

		strp->stm.full_len = sz;
		if (!strp->stm.full_len)
			goto read_done;
	}

	/* Load up more data */
	while (len && strp->stm.full_len > skb->len) {
		chunk = min_t(size_t, len, strp->stm.full_len - skb->len);
		chunk = min_t(size_t, chunk, PAGE_SIZE - skb_frag_size(frag));
		WARN_ON_ONCE(skb_copy_bits(in_skb, offset,
					   skb_frag_address(frag) +
					   skb_frag_size(frag),
					   chunk));

		skb->len += chunk;
		skb->data_len += chunk;
		skb_frag_size_add(frag, chunk);
		frag++;
		len -= chunk;
		offset += chunk;
	}

	if (strp->stm.full_len == skb->len) {
		desc->count = 0;

		strp->msg_ready = 1;
		tls_rx_msg_ready(strp);
	}

read_done:
	return in_len - len;
}

static int tls_strp_read_copyin(struct tls_strparser *strp)
{
	struct socket *sock = strp->sk->sk_socket;
	read_descriptor_t desc;

	desc.arg.data = strp;
	desc.error = 0;
	desc.count = 1; /* give more than one skb per call */

	/* sk should be locked here, so okay to do read_sock.
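	 * For TCP, read_sock() walks the receive queue and feeds each
	 * skb to tls_strp_copyin(); the bytes the actor reports as
	 * consumed are freed from the TCP queue, unlike in the
	 * zero-copy path where the record stays queued until
	 * tls_strp_msg_done().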
	 */
	sock->ops->read_sock(strp->sk, &desc, tls_strp_copyin);

	return desc.error;
}

static int tls_strp_read_short(struct tls_strparser *strp)
{
	struct skb_shared_info *shinfo;
	struct page *page;
	int need_spc, len;

	/* If the rbuf is small or rcv window has collapsed to 0 we need
	 * to read the data out. Otherwise the connection will stall.
	 * With a threshold of INT_MAX, tcp_epollin_ready() only returns
	 * true under rmem pressure or a collapsed receive window.
	 */
	if (likely(!tcp_epollin_ready(strp->sk, INT_MAX)))
		return 0;

	shinfo = skb_shinfo(strp->anchor);
	shinfo->frag_list = NULL;

	/* If we don't know the length go max plus page for cipher overhead */
	need_spc = strp->stm.full_len ?: TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE;

	for (len = need_spc; len > 0; len -= PAGE_SIZE) {
		page = alloc_page(strp->sk->sk_allocation);
		if (!page) {
			tls_strp_flush_anchor_copy(strp);
			return -ENOMEM;
		}

		skb_fill_page_desc(strp->anchor, shinfo->nr_frags++,
				   page, 0, 0);
	}

	strp->copy_mode = 1;
	strp->stm.offset = 0;

	strp->anchor->len = 0;
	strp->anchor->data_len = 0;
	strp->anchor->truesize = round_up(need_spc, PAGE_SIZE);

	tls_strp_read_copyin(strp);

	return 0;
}

/* Link the TCP receive queue into the anchor's frag_list so the record
 * can be parsed in place, without copying.
 */
static void tls_strp_load_anchor_with_queue(struct tls_strparser *strp, int len)
{
	struct tcp_sock *tp = tcp_sk(strp->sk);
	struct sk_buff *first;
	u32 offset;

	first = tcp_recv_skb(strp->sk, tp->copied_seq, &offset);
	if (WARN_ON_ONCE(!first))
		return;

	/* Bestow the state onto the anchor */
	strp->anchor->len = offset + len;
	strp->anchor->data_len = offset + len;
	strp->anchor->truesize = offset + len;

	skb_shinfo(strp->anchor)->frag_list = first;

	skb_copy_header(strp->anchor, first);
	strp->anchor->destructor = NULL;

	strp->stm.offset = offset;
}

void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh)
{
	struct strp_msg *rxm;
	struct tls_msg *tlm;

	DEBUG_NET_WARN_ON_ONCE(!strp->msg_ready);
	DEBUG_NET_WARN_ON_ONCE(!strp->stm.full_len);

	if (!strp->copy_mode && force_refresh) {
		if (WARN_ON(tcp_inq(strp->sk) < strp->stm.full_len))
			return;

		tls_strp_load_anchor_with_queue(strp, strp->stm.full_len);
	}

	rxm = strp_msg(strp->anchor);
	rxm->full_len = strp->stm.full_len;
	rxm->offset = strp->stm.offset;
	tlm = tls_msg(strp->anchor);
	tlm->control = strp->mark;
}

/* Called with lock held on lower socket */
static int tls_strp_read_sock(struct tls_strparser *strp)
{
	int sz, inq;

	inq = tcp_inq(strp->sk);
	if (inq < 1)
		return 0;

	if (unlikely(strp->copy_mode))
		return tls_strp_read_copyin(strp);

	if (inq < strp->stm.full_len)
		return tls_strp_read_short(strp);

	if (!strp->stm.full_len) {
		tls_strp_load_anchor_with_queue(strp, inq);

		sz = tls_rx_msg_size(strp, strp->anchor);
		if (sz < 0) {
			tls_strp_abort_strp(strp, sz);
			return sz;
		}

		strp->stm.full_len = sz;

		if (!strp->stm.full_len || inq < strp->stm.full_len)
			return tls_strp_read_short(strp);
	}

	strp->msg_ready = 1;
	tls_rx_msg_ready(strp);

	return 0;
}

void tls_strp_check_rcv(struct tls_strparser *strp)
{
	if (unlikely(strp->stopped) || strp->msg_ready)
		return;

	if (tls_strp_read_sock(strp) == -ENOMEM)
		queue_work(tls_strp_wq, &strp->work);
}

/* Lower sock lock held */
void tls_strp_data_ready(struct tls_strparser *strp)
{
	/* This check is needed to synchronize with tls_strp_work.
	 * tls_strp_work acquires a process lock (lock_sock) whereas
	 * the lock held here is bh_lock_sock. The two locks can be
	 * held by different threads at the same time, but bh_lock_sock
	 * allows a thread in BH context to safely check if the process
	 * lock is held. In this case, if the lock is held, queue work.
	 */
	if (sock_owned_by_user_nocheck(strp->sk)) {
		queue_work(tls_strp_wq, &strp->work);
		return;
	}

	tls_strp_check_rcv(strp);
}

static void tls_strp_work(struct work_struct *w)
{
	struct tls_strparser *strp =
		container_of(w, struct tls_strparser, work);

	lock_sock(strp->sk);
	tls_strp_check_rcv(strp);
	release_sock(strp->sk);
}

void tls_strp_msg_done(struct tls_strparser *strp)
{
	WARN_ON(!strp->stm.full_len);

	if (likely(!strp->copy_mode))
		tcp_read_done(strp->sk, strp->stm.full_len);
	else
		tls_strp_flush_anchor_copy(strp);

	strp->msg_ready = 0;
	memset(&strp->stm, 0, sizeof(strp->stm));

	tls_strp_check_rcv(strp);
}

void tls_strp_stop(struct tls_strparser *strp)
{
	strp->stopped = 1;
}

int tls_strp_init(struct tls_strparser *strp, struct sock *sk)
{
	memset(strp, 0, sizeof(*strp));

	strp->sk = sk;

	strp->anchor = alloc_skb(0, GFP_KERNEL);
	if (!strp->anchor)
		return -ENOMEM;

	INIT_WORK(&strp->work, tls_strp_work);

	return 0;
}

/* strp must already be stopped so that tls_strp_data_ready will no
 * longer be called. Note that tls_strp_done is not called with the
 * lower socket held.
 */
void tls_strp_done(struct tls_strparser *strp)
{
	WARN_ON(!strp->stopped);

	cancel_work_sync(&strp->work);
	tls_strp_anchor_free(strp);
}

int __init tls_strp_dev_init(void)
{
	tls_strp_wq = create_workqueue("tls-strp");
	if (unlikely(!tls_strp_wq))
		return -ENOMEM;

	return 0;
}

void tls_strp_dev_exit(void)
{
	destroy_workqueue(tls_strp_wq);
}