// SPDX-License-Identifier: GPL-2.0
#include <net/tcp.h>
#include <net/strparser.h>
#include <net/xfrm.h>
#include <net/esp.h>
#include <net/espintcp.h>
#include <linux/skmsg.h>
#include <net/inet_common.h>
#if IS_ENABLED(CONFIG_IPV6)
#include <net/ipv6_stubs.h>
#endif

/* A message with a zero non-ESP marker carries IKE: charge it to the
 * socket's receive budget and queue it for delivery via recvmsg().
 */
static void handle_nonesp(struct espintcp_ctx *ctx, struct sk_buff *skb,
			  struct sock *sk)
{
	if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf ||
	    !sk_rmem_schedule(sk, skb, skb->truesize)) {
		XFRM_INC_STATS(sock_net(sk), LINUX_MIB_XFRMINERROR);
		kfree_skb(skb);
		return;
	}

	skb_set_owner_r(skb, sk);

	memset(skb->cb, 0, sizeof(skb->cb));
	skb_queue_tail(&ctx->ike_queue, skb);
	ctx->saved_data_ready(sk);
}

/* A non-zero first word is an ESP SPI: hand the packet to the xfrm
 * receive path as TCP-encapsulated ESP.
 */
static void handle_esp(struct sk_buff *skb, struct sock *sk)
{
	skb_reset_transport_header(skb);
	memset(skb->cb, 0, sizeof(skb->cb));

	rcu_read_lock();
	skb->dev = dev_get_by_index_rcu(sock_net(sk), skb->skb_iif);
	local_bh_disable();
#if IS_ENABLED(CONFIG_IPV6)
	if (sk->sk_family == AF_INET6)
		ipv6_stub->xfrm6_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
	else
#endif
		xfrm4_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
	local_bh_enable();
	rcu_read_unlock();
}

/* strparser rcv callback: strip the 2-byte length prefix and dispatch
 * on the non-ESP marker.
 */
static void espintcp_rcv(struct strparser *strp, struct sk_buff *skb)
{
	struct espintcp_ctx *ctx = container_of(strp, struct espintcp_ctx,
						strp);
	struct strp_msg *rxm = strp_msg(skb);
	int len = rxm->full_len - 2;
	u32 nonesp_marker;
	int err;

	/* keepalive packet? */
	if (unlikely(len == 1)) {
		u8 data;

		err = skb_copy_bits(skb, rxm->offset + 2, &data, 1);
		if (err < 0) {
			XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
			kfree_skb(skb);
			return;
		}

		if (data == 0xff) {
			kfree_skb(skb);
			return;
		}
	}

	/* drop other short messages */
	if (unlikely(len <= sizeof(nonesp_marker))) {
		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
		kfree_skb(skb);
		return;
	}

	err = skb_copy_bits(skb, rxm->offset + 2, &nonesp_marker,
			    sizeof(nonesp_marker));
	if (err < 0) {
		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
		kfree_skb(skb);
		return;
	}

	/* remove header, leave non-ESP marker/SPI */
	if (!__pskb_pull(skb, rxm->offset + 2)) {
		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINERROR);
		kfree_skb(skb);
		return;
	}

	if (pskb_trim(skb, rxm->full_len - 2) != 0) {
		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINERROR);
		kfree_skb(skb);
		return;
	}

	if (nonesp_marker == 0)
		handle_nonesp(ctx, skb, strp->sk);
	else
		handle_esp(skb, strp->sk);
}

/* strparser parse callback: return the full record length from the
 * 2-byte big-endian prefix (which includes itself), or 0 if the prefix
 * hasn't fully arrived yet.
 */
static int espintcp_parse(struct strparser *strp, struct sk_buff *skb)
{
	struct strp_msg *rxm = strp_msg(skb);
	__be16 blen;
	u16 len;
	int err;

	if (skb->len < rxm->offset + 2)
		return 0;

	err = skb_copy_bits(skb, rxm->offset, &blen, sizeof(blen));
	if (err < 0)
		return err;

	len = be16_to_cpu(blen);
	if (len < 2)
		return -EINVAL;

	return len;
}
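/*
 * Wire format handled above, per RFC 8229: each record is a 16-bit
 * big-endian length (covering the length field itself) followed by
 * either a zero 32-bit non-ESP marker and an IKE message, or an ESP
 * packet whose SPI (guaranteed non-zero) occupies those four bytes.
 * A record whose payload is the single byte 0xff is a NAT keepalive.
 */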
/* recvmsg() on an espintcp socket delivers queued IKE messages, one
 * per call, with datagram semantics (MSG_TRUNC on short reads).
 */
static int espintcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			    int nonblock, int flags, int *addr_len)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct sk_buff *skb;
	int err = 0;
	int copied;
	int off = 0;

	flags |= nonblock ? MSG_DONTWAIT : 0;

	skb = __skb_recv_datagram(sk, &ctx->ike_queue, flags, &off, &err);
	if (!skb) {
		if (err == -EAGAIN && sk->sk_shutdown & RCV_SHUTDOWN)
			return 0;
		return err;
	}

	copied = len;
	if (copied > skb->len)
		copied = skb->len;
	else if (copied < skb->len)
		msg->msg_flags |= MSG_TRUNC;

	err = skb_copy_datagram_msg(skb, 0, msg, copied);
	if (unlikely(err)) {
		kfree_skb(skb);
		return err;
	}

	if (flags & MSG_TRUNC)
		copied = skb->len;
	kfree_skb(skb);
	return copied;
}

/* Park an skb for later transmission; the queue is drained from
 * espintcp_release() once the socket lock is dropped.
 */
int espintcp_queue_out(struct sock *sk, struct sk_buff *skb)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	if (skb_queue_len(&ctx->out_queue) >= netdev_max_backlog)
		return -ENOBUFS;

	__skb_queue_tail(&ctx->out_queue, skb);

	return 0;
}
EXPORT_SYMBOL_GPL(espintcp_queue_out);

/* espintcp length field is 2B and length includes the length field's size */
#define MAX_ESPINTCP_MSG (((1 << 16) - 1) - 2)

/* Send (the rest of) an skb-backed message, looping over partial
 * writes until it is fully on the wire.
 */
static int espintcp_sendskb_locked(struct sock *sk, struct espintcp_msg *emsg,
				   int flags)
{
	do {
		int ret;

		ret = skb_send_sock_locked(sk, emsg->skb,
					   emsg->offset, emsg->len);
		if (ret < 0)
			return ret;

		emsg->len -= ret;
		emsg->offset += ret;
	} while (emsg->len > 0);

	kfree_skb(emsg->skb);
	memset(emsg, 0, sizeof(*emsg));

	return 0;
}

/* Send (the rest of) an sk_msg-backed message page by page, recording
 * progress in emsg->offset and sg.start so a failed send can resume.
 */
static int espintcp_sendskmsg_locked(struct sock *sk,
				     struct espintcp_msg *emsg, int flags)
{
	struct sk_msg *skmsg = &emsg->skmsg;
	struct scatterlist *sg;
	int done = 0;
	int ret;

	flags |= MSG_SENDPAGE_NOTLAST;
	sg = &skmsg->sg.data[skmsg->sg.start];
	do {
		size_t size = sg->length - emsg->offset;
		int offset = sg->offset + emsg->offset;
		struct page *p;

		emsg->offset = 0;

		if (sg_is_last(sg))
			flags &= ~MSG_SENDPAGE_NOTLAST;

		p = sg_page(sg);
retry:
		ret = do_tcp_sendpages(sk, p, offset, size, flags);
		if (ret < 0) {
			emsg->offset = offset - sg->offset;
			skmsg->sg.start += done;
			return ret;
		}

		if (ret != size) {
			offset += ret;
			size -= ret;
			goto retry;
		}

		done++;
		put_page(p);
		sk_mem_uncharge(sk, sg->length);
		sg = sg_next(sg);
	} while (sg);

	memset(emsg, 0, sizeof(*emsg));

	return 0;
}
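/*
 * A message that could not be sent in full is parked in ctx->partial:
 * emsg->skb is set for ESP packets coming from the xfrm stack,
 * otherwise emsg->skmsg holds IKE data copied in from sendmsg().
 */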
/* Flush the pending partial message, if any.  tx_running keeps the
 * worker and the direct send paths from running concurrently.
 */
static int espintcp_push_msgs(struct sock *sk, int flags)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;
	int err;

	if (!emsg->len)
		return 0;

	if (ctx->tx_running)
		return -EAGAIN;
	ctx->tx_running = 1;

	if (emsg->skb)
		err = espintcp_sendskb_locked(sk, emsg, flags);
	else
		err = espintcp_sendskmsg_locked(sk, emsg, flags);
	if (err == -EAGAIN) {
		ctx->tx_running = 0;
		return flags & MSG_DONTWAIT ? -EAGAIN : 0;
	}
	if (!err)
		memset(emsg, 0, sizeof(*emsg));

	ctx->tx_running = 0;

	return err;
}

/* Transmit an encapsulated ESP packet.  Fails with -ENOBUFS if an
 * earlier message is still partially queued.
 */
int espintcp_push_skb(struct sock *sk, struct sk_buff *skb)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;
	unsigned int len;
	int offset;

	if (sk->sk_state != TCP_ESTABLISHED) {
		kfree_skb(skb);
		return -ECONNRESET;
	}

	offset = skb_transport_offset(skb);
	len = skb->len - offset;

	espintcp_push_msgs(sk, 0);

	if (emsg->len) {
		kfree_skb(skb);
		return -ENOBUFS;
	}

	skb_set_owner_w(skb, sk);

	emsg->offset = offset;
	emsg->len = len;
	emsg->skb = skb;

	espintcp_push_msgs(sk, 0);

	return 0;
}
EXPORT_SYMBOL_GPL(espintcp_push_skb);

/* sendmsg() prepends the 2-byte length header to userspace IKE data.
 * Only MSG_DONTWAIT is supported.
 */
static int espintcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;
	struct iov_iter pfx_iter;
	struct kvec pfx_iov = {};
	size_t msglen = size + 2;
	char buf[2] = {0};
	int err, end;

	if (msg->msg_flags & ~MSG_DONTWAIT)
		return -EOPNOTSUPP;

	if (size > MAX_ESPINTCP_MSG)
		return -EMSGSIZE;

	if (msg->msg_controllen)
		return -EOPNOTSUPP;

	lock_sock(sk);

	err = espintcp_push_msgs(sk, msg->msg_flags & MSG_DONTWAIT);
	if (err < 0) {
		if (err != -EAGAIN || !(msg->msg_flags & MSG_DONTWAIT))
			err = -ENOBUFS;
		goto unlock;
	}

	sk_msg_init(&emsg->skmsg);
	while (1) {
		/* only -ENOMEM is possible since we don't coalesce */
		err = sk_msg_alloc(sk, &emsg->skmsg, msglen, 0);
		if (!err)
			break;

		err = sk_stream_wait_memory(sk, &timeo);
		if (err)
			goto fail;
	}

	*((__be16 *)buf) = cpu_to_be16(msglen);
	pfx_iov.iov_base = buf;
	pfx_iov.iov_len = sizeof(buf);
	iov_iter_kvec(&pfx_iter, WRITE, &pfx_iov, 1, pfx_iov.iov_len);

	err = sk_msg_memcopy_from_iter(sk, &pfx_iter, &emsg->skmsg,
				       pfx_iov.iov_len);
	if (err < 0)
		goto fail;

	err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, &emsg->skmsg, size);
	if (err < 0)
		goto fail;

	end = emsg->skmsg.sg.end;
	emsg->len = size;
	sk_msg_iter_var_prev(end);
	sg_mark_end(sk_msg_elem(&emsg->skmsg, end));

	tcp_rate_check_app_limited(sk);

	err = espintcp_push_msgs(sk, msg->msg_flags & MSG_DONTWAIT);
	/* this message could be partially sent, keep it */

	release_sock(sk);

	return size;

fail:
	sk_msg_free(sk, &emsg->skmsg);
	memset(emsg, 0, sizeof(*emsg));
unlock:
	release_sock(sk);
	return err;
}

static struct proto espintcp_prot __ro_after_init;
static struct proto_ops espintcp_ops __ro_after_init;
static struct proto espintcp6_prot;
static struct proto_ops espintcp6_ops;
static DEFINE_MUTEX(tcpv6_prot_mutex);

static void espintcp_data_ready(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	strp_data_ready(&ctx->strp);
}

static void espintcp_tx_work(struct work_struct *work)
{
	struct espintcp_ctx *ctx = container_of(work,
						struct espintcp_ctx, work);
	struct sock *sk = ctx->strp.sk;

	lock_sock(sk);
	if (!ctx->tx_running)
		espintcp_push_msgs(sk, 0);
	release_sock(sk);
}

/* TCP freed some send buffer space: retry the parked message from
 * process context via the work item.
 */
static void espintcp_write_space(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	schedule_work(&ctx->work);
	ctx->saved_write_space(sk);
}
static void espintcp_destruct(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	/* chain to the original destructor before freeing the context */
	ctx->saved_destruct(sk);
	kfree(ctx);
}

bool tcp_is_ulp_esp(struct sock *sk)
{
	return sk->sk_prot == &espintcp_prot || sk->sk_prot == &espintcp6_prot;
}
EXPORT_SYMBOL_GPL(tcp_is_ulp_esp);

static void build_protos(struct proto *espintcp_prot,
			 struct proto_ops *espintcp_ops,
			 const struct proto *orig_prot,
			 const struct proto_ops *orig_ops);

/* ULP init: wrap the socket's proto/ops, save the original callbacks
 * and start the stream parser.
 */
static int espintcp_init_sk(struct sock *sk)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct strp_callbacks cb = {
		.rcv_msg = espintcp_rcv,
		.parse_msg = espintcp_parse,
	};
	struct espintcp_ctx *ctx;
	int err;

	/* sockmap is not compatible with espintcp */
	if (sk->sk_user_data)
		return -EBUSY;

	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
	if (!ctx)
		return -ENOMEM;

	err = strp_init(&ctx->strp, sk, &cb);
	if (err)
		goto free;

	__sk_dst_reset(sk);

	strp_check_rcv(&ctx->strp);
	skb_queue_head_init(&ctx->ike_queue);
	skb_queue_head_init(&ctx->out_queue);

	if (sk->sk_family == AF_INET) {
		sk->sk_prot = &espintcp_prot;
		sk->sk_socket->ops = &espintcp_ops;
	} else {
		mutex_lock(&tcpv6_prot_mutex);
		if (!espintcp6_prot.recvmsg)
			build_protos(&espintcp6_prot, &espintcp6_ops, sk->sk_prot, sk->sk_socket->ops);
		mutex_unlock(&tcpv6_prot_mutex);

		sk->sk_prot = &espintcp6_prot;
		sk->sk_socket->ops = &espintcp6_ops;
	}
	ctx->saved_data_ready = sk->sk_data_ready;
	ctx->saved_write_space = sk->sk_write_space;
	ctx->saved_destruct = sk->sk_destruct;
	sk->sk_data_ready = espintcp_data_ready;
	sk->sk_write_space = espintcp_write_space;
	sk->sk_destruct = espintcp_destruct;
	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
	INIT_WORK(&ctx->work, espintcp_tx_work);

	/* avoid using task_frag */
	sk->sk_allocation = GFP_ATOMIC;

	return 0;

free:
	kfree(ctx);
	return err;
}

/* release_cb: push packets queued via espintcp_queue_out() while the
 * socket was owned by user context.
 */
static void espintcp_release(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct sk_buff_head queue;
	struct sk_buff *skb;

	__skb_queue_head_init(&queue);
	skb_queue_splice_init(&ctx->out_queue, &queue);

	while ((skb = __skb_dequeue(&queue)))
		espintcp_push_skb(sk, skb);

	tcp_release_cb(sk);
}

static void espintcp_close(struct sock *sk, long timeout)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;

	strp_stop(&ctx->strp);

	sk->sk_prot = &tcp_prot;
	barrier();

	cancel_work_sync(&ctx->work);
	strp_done(&ctx->strp);

	skb_queue_purge(&ctx->out_queue);
	skb_queue_purge(&ctx->ike_queue);

	if (emsg->len) {
		if (emsg->skb)
			kfree_skb(emsg->skb);
		else
			sk_msg_free(sk, &emsg->skmsg);
	}

	tcp_close(sk, timeout);
}

/* datagram_poll() doesn't know about the IKE queue, so report its
 * readability explicitly.
 */
static __poll_t espintcp_poll(struct file *file, struct socket *sock,
			      poll_table *wait)
{
	__poll_t mask = datagram_poll(file, sock, wait);
	struct sock *sk = sock->sk;
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	if (!skb_queue_empty(&ctx->ike_queue))
		mask |= EPOLLIN | EPOLLRDNORM;

	return mask;
}
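/*
 * The espintcp proto/ops start as byte-for-byte copies of the
 * underlying TCP ones; only the entry points assigned below differ,
 * so everything unrelated to encapsulation keeps stock TCP behaviour.
 */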
static void build_protos(struct proto *espintcp_prot,
			 struct proto_ops *espintcp_ops,
			 const struct proto *orig_prot,
			 const struct proto_ops *orig_ops)
{
	memcpy(espintcp_prot, orig_prot, sizeof(struct proto));
	memcpy(espintcp_ops, orig_ops, sizeof(struct proto_ops));
	espintcp_prot->sendmsg = espintcp_sendmsg;
	espintcp_prot->recvmsg = espintcp_recvmsg;
	espintcp_prot->close = espintcp_close;
	espintcp_prot->release_cb = espintcp_release;
	espintcp_ops->poll = espintcp_poll;
}

static struct tcp_ulp_ops espintcp_ulp __read_mostly = {
	.name = "espintcp",
	.owner = THIS_MODULE,
	.init = espintcp_init_sk,
};

void __init espintcp_init(void)
{
	build_protos(&espintcp_prot, &espintcp_ops, &tcp_prot, &inet_stream_ops);

	tcp_register_ulp(&espintcp_ulp);
}
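/*
 * Usage sketch (not part of this file's build, and an assumption about
 * the userspace side): the ULP registered above is expected to be
 * enabled by an IKE daemon on an established TCP connection, roughly:
 *
 *	int fd = socket(AF_INET, SOCK_STREAM, 0);
 *	connect(fd, ...);
 *	setsockopt(fd, IPPROTO_TCP, TCP_ULP, "espintcp",
 *		   sizeof("espintcp"));
 *
 * after which IKE messages flow through send()/recv() on fd, and ESP
 * packets are encapsulated by the kernel on the same stream.
 */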