// SPDX-License-Identifier: GPL-2.0
#include <net/tcp.h>
#include <net/strparser.h>
#include <net/xfrm.h>
#include <net/esp.h>
#include <net/espintcp.h>
#include <linux/skmsg.h>
#include <net/inet_common.h>
#if IS_ENABLED(CONFIG_IPV6)
#include <net/ipv6_stubs.h>
#endif

static void handle_nonesp(struct espintcp_ctx *ctx, struct sk_buff *skb,
                          struct sock *sk)
{
        if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf ||
            !sk_rmem_schedule(sk, skb, skb->truesize)) {
                kfree_skb(skb);
                return;
        }

        skb_set_owner_r(skb, sk);

        memset(skb->cb, 0, sizeof(skb->cb));
        skb_queue_tail(&ctx->ike_queue, skb);
        ctx->saved_data_ready(sk);
}

static void handle_esp(struct sk_buff *skb, struct sock *sk)
{
        skb_reset_transport_header(skb);
        memset(skb->cb, 0, sizeof(skb->cb));

        rcu_read_lock();
        skb->dev = dev_get_by_index_rcu(sock_net(sk), skb->skb_iif);
        local_bh_disable();
#if IS_ENABLED(CONFIG_IPV6)
        if (sk->sk_family == AF_INET6)
                ipv6_stub->xfrm6_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
        else
#endif
                xfrm4_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
        local_bh_enable();
        rcu_read_unlock();
}

static void espintcp_rcv(struct strparser *strp, struct sk_buff *skb)
{
        struct espintcp_ctx *ctx = container_of(strp, struct espintcp_ctx,
                                                strp);
        struct strp_msg *rxm = strp_msg(skb);
        u32 nonesp_marker;
        int err;

        err = skb_copy_bits(skb, rxm->offset + 2, &nonesp_marker,
                            sizeof(nonesp_marker));
        if (err < 0) {
                kfree_skb(skb);
                return;
        }

        /* remove header, leave non-ESP marker/SPI */
        if (!__pskb_pull(skb, rxm->offset + 2)) {
                kfree_skb(skb);
                return;
        }

        if (pskb_trim(skb, rxm->full_len - 2) != 0) {
                kfree_skb(skb);
                return;
        }

        if (nonesp_marker == 0)
                handle_nonesp(ctx, skb, strp->sk);
        else
                handle_esp(skb, strp->sk);
}

static int espintcp_parse(struct strparser *strp, struct sk_buff *skb)
{
        struct strp_msg *rxm = strp_msg(skb);
        __be16 blen;
        u16 len;
        int err;

        if (skb->len < rxm->offset + 2)
                return 0;

        err = skb_copy_bits(skb, rxm->offset, &blen, sizeof(blen));
        if (err < 0)
                return err;

        len = be16_to_cpu(blen);
        if (len < 6)
                return -EINVAL;

        return len;
}
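
/* For reference, the RFC 8229 framing implemented by espintcp_parse() and
 * espintcp_rcv() above (a sketch of the wire format, nothing here is extra
 * state kept by this file):
 *
 *      2 bytes   big-endian length, covering itself plus the payload
 *      4 bytes   0x00000000: non-ESP marker, an IKE message follows and is
 *                queued to userspace via handle_nonesp()
 *                non-zero:   ESP SPI, the packet is handed to the xfrm
 *                stack via handle_esp()
 *
 * The smallest valid frame is thus 2 + 4 bytes, hence the "len < 6" check
 * in espintcp_parse(). Receive processing strips only the 2-byte length;
 * the marker/SPI word is left in place for the consumers above.
 */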

static int espintcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
                            int nonblock, int flags, int *addr_len)
{
        struct espintcp_ctx *ctx = espintcp_getctx(sk);
        struct sk_buff *skb;
        int err = 0;
        int copied;
        int off = 0;

        flags |= nonblock ? MSG_DONTWAIT : 0;

        skb = __skb_recv_datagram(sk, &ctx->ike_queue, flags, &off, &err);
        if (!skb)
                return err;

        copied = len;
        if (copied > skb->len)
                copied = skb->len;
        else if (copied < skb->len)
                msg->msg_flags |= MSG_TRUNC;

        err = skb_copy_datagram_msg(skb, 0, msg, copied);
        if (unlikely(err)) {
                kfree_skb(skb);
                return err;
        }

        if (flags & MSG_TRUNC)
                copied = skb->len;
        kfree_skb(skb);

        return copied;
}

int espintcp_queue_out(struct sock *sk, struct sk_buff *skb)
{
        struct espintcp_ctx *ctx = espintcp_getctx(sk);

        if (skb_queue_len(&ctx->out_queue) >= netdev_max_backlog)
                return -ENOBUFS;

        __skb_queue_tail(&ctx->out_queue, skb);

        return 0;
}
EXPORT_SYMBOL_GPL(espintcp_queue_out);

/* espintcp length field is 2B and length includes the length field's size */
#define MAX_ESPINTCP_MSG (((1 << 16) - 1) - 2)

static int espintcp_sendskb_locked(struct sock *sk, struct espintcp_msg *emsg,
                                   int flags)
{
        do {
                int ret;

                ret = skb_send_sock_locked(sk, emsg->skb,
                                           emsg->offset, emsg->len);
                if (ret < 0)
                        return ret;

                emsg->len -= ret;
                emsg->offset += ret;
        } while (emsg->len > 0);

        kfree_skb(emsg->skb);
        memset(emsg, 0, sizeof(*emsg));

        return 0;
}

static int espintcp_sendskmsg_locked(struct sock *sk,
                                     struct espintcp_msg *emsg, int flags)
{
        struct sk_msg *skmsg = &emsg->skmsg;
        struct scatterlist *sg;
        int done = 0;
        int ret;

        flags |= MSG_SENDPAGE_NOTLAST;
        sg = &skmsg->sg.data[skmsg->sg.start];
        do {
                size_t size = sg->length - emsg->offset;
                int offset = sg->offset + emsg->offset;
                struct page *p;

                emsg->offset = 0;

                if (sg_is_last(sg))
                        flags &= ~MSG_SENDPAGE_NOTLAST;

                p = sg_page(sg);
retry:
                ret = do_tcp_sendpages(sk, p, offset, size, flags);
                if (ret < 0) {
                        emsg->offset = offset - sg->offset;
                        skmsg->sg.start += done;
                        return ret;
                }

                if (ret != size) {
                        offset += ret;
                        size -= ret;
                        goto retry;
                }

                done++;
                put_page(p);
                sk_mem_uncharge(sk, sg->length);
                sg = sg_next(sg);
        } while (sg);

        memset(emsg, 0, sizeof(*emsg));

        return 0;
}

static int espintcp_push_msgs(struct sock *sk)
{
        struct espintcp_ctx *ctx = espintcp_getctx(sk);
        struct espintcp_msg *emsg = &ctx->partial;
        int err;

        if (!emsg->len)
                return 0;

        if (ctx->tx_running)
                return -EAGAIN;
        ctx->tx_running = 1;

        if (emsg->skb)
                err = espintcp_sendskb_locked(sk, emsg, 0);
        else
                err = espintcp_sendskmsg_locked(sk, emsg, 0);
        if (err == -EAGAIN) {
                ctx->tx_running = 0;
                return 0;
        }
        if (!err)
                memset(emsg, 0, sizeof(*emsg));

        ctx->tx_running = 0;

        return err;
}

int espintcp_push_skb(struct sock *sk, struct sk_buff *skb)
{
        struct espintcp_ctx *ctx = espintcp_getctx(sk);
        struct espintcp_msg *emsg = &ctx->partial;
        unsigned int len;
        int offset;

        if (sk->sk_state != TCP_ESTABLISHED) {
                kfree_skb(skb);
                return -ECONNRESET;
        }

        offset = skb_transport_offset(skb);
        len = skb->len - offset;

        espintcp_push_msgs(sk);

        if (emsg->len) {
                kfree_skb(skb);
                return -ENOBUFS;
        }

        skb_set_owner_w(skb, sk);

        emsg->offset = offset;
        emsg->len = len;
        emsg->skb = skb;

        espintcp_push_msgs(sk);

        return 0;
}
EXPORT_SYMBOL_GPL(espintcp_push_skb);
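
/* A note on the TX path: it has two producers. ESP packets built by the
 * xfrm stack come in through espintcp_push_skb() above (or wait in
 * ctx->out_queue until the socket lock is released, see espintcp_release()
 * below), and IKE messages from userspace come in through
 * espintcp_sendmsg() below. Both stage a single message in ctx->partial,
 * which espintcp_push_msgs() drains; a new message is refused with
 * -ENOBUFS while a previous one is still partially sent.
 */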
0; 276 } 277 EXPORT_SYMBOL_GPL(espintcp_push_skb); 278 279 static int espintcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) 280 { 281 long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); 282 struct espintcp_ctx *ctx = espintcp_getctx(sk); 283 struct espintcp_msg *emsg = &ctx->partial; 284 struct iov_iter pfx_iter; 285 struct kvec pfx_iov = {}; 286 size_t msglen = size + 2; 287 char buf[2] = {0}; 288 int err, end; 289 290 if (msg->msg_flags) 291 return -EOPNOTSUPP; 292 293 if (size > MAX_ESPINTCP_MSG) 294 return -EMSGSIZE; 295 296 if (msg->msg_controllen) 297 return -EOPNOTSUPP; 298 299 lock_sock(sk); 300 301 err = espintcp_push_msgs(sk); 302 if (err < 0) { 303 err = -ENOBUFS; 304 goto unlock; 305 } 306 307 sk_msg_init(&emsg->skmsg); 308 while (1) { 309 /* only -ENOMEM is possible since we don't coalesce */ 310 err = sk_msg_alloc(sk, &emsg->skmsg, msglen, 0); 311 if (!err) 312 break; 313 314 err = sk_stream_wait_memory(sk, &timeo); 315 if (err) 316 goto fail; 317 } 318 319 *((__be16 *)buf) = cpu_to_be16(msglen); 320 pfx_iov.iov_base = buf; 321 pfx_iov.iov_len = sizeof(buf); 322 iov_iter_kvec(&pfx_iter, WRITE, &pfx_iov, 1, pfx_iov.iov_len); 323 324 err = sk_msg_memcopy_from_iter(sk, &pfx_iter, &emsg->skmsg, 325 pfx_iov.iov_len); 326 if (err < 0) 327 goto fail; 328 329 err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, &emsg->skmsg, size); 330 if (err < 0) 331 goto fail; 332 333 end = emsg->skmsg.sg.end; 334 emsg->len = size; 335 sk_msg_iter_var_prev(end); 336 sg_mark_end(sk_msg_elem(&emsg->skmsg, end)); 337 338 tcp_rate_check_app_limited(sk); 339 340 err = espintcp_push_msgs(sk); 341 /* this message could be partially sent, keep it */ 342 if (err < 0) 343 goto unlock; 344 release_sock(sk); 345 346 return size; 347 348 fail: 349 sk_msg_free(sk, &emsg->skmsg); 350 memset(emsg, 0, sizeof(*emsg)); 351 unlock: 352 release_sock(sk); 353 return err; 354 } 355 356 static struct proto espintcp_prot __ro_after_init; 357 static struct proto_ops espintcp_ops __ro_after_init; 358 static struct proto espintcp6_prot; 359 static struct proto_ops espintcp6_ops; 360 static DEFINE_MUTEX(tcpv6_prot_mutex); 361 362 static void espintcp_data_ready(struct sock *sk) 363 { 364 struct espintcp_ctx *ctx = espintcp_getctx(sk); 365 366 strp_data_ready(&ctx->strp); 367 } 368 369 static void espintcp_tx_work(struct work_struct *work) 370 { 371 struct espintcp_ctx *ctx = container_of(work, 372 struct espintcp_ctx, work); 373 struct sock *sk = ctx->strp.sk; 374 375 lock_sock(sk); 376 if (!ctx->tx_running) 377 espintcp_push_msgs(sk); 378 release_sock(sk); 379 } 380 381 static void espintcp_write_space(struct sock *sk) 382 { 383 struct espintcp_ctx *ctx = espintcp_getctx(sk); 384 385 schedule_work(&ctx->work); 386 ctx->saved_write_space(sk); 387 } 388 389 static void espintcp_destruct(struct sock *sk) 390 { 391 struct espintcp_ctx *ctx = espintcp_getctx(sk); 392 393 ctx->saved_destruct(sk); 394 kfree(ctx); 395 } 396 397 bool tcp_is_ulp_esp(struct sock *sk) 398 { 399 return sk->sk_prot == &espintcp_prot || sk->sk_prot == &espintcp6_prot; 400 } 401 EXPORT_SYMBOL_GPL(tcp_is_ulp_esp); 402 403 static void build_protos(struct proto *espintcp_prot, 404 struct proto_ops *espintcp_ops, 405 const struct proto *orig_prot, 406 const struct proto_ops *orig_ops); 407 static int espintcp_init_sk(struct sock *sk) 408 { 409 struct inet_connection_sock *icsk = inet_csk(sk); 410 struct strp_callbacks cb = { 411 .rcv_msg = espintcp_rcv, 412 .parse_msg = espintcp_parse, 413 }; 414 struct espintcp_ctx *ctx; 415 int err; 416 

static int espintcp_init_sk(struct sock *sk)
{
        struct inet_connection_sock *icsk = inet_csk(sk);
        struct strp_callbacks cb = {
                .rcv_msg = espintcp_rcv,
                .parse_msg = espintcp_parse,
        };
        struct espintcp_ctx *ctx;
        int err;

        /* sockmap is not compatible with espintcp */
        if (sk->sk_user_data)
                return -EBUSY;

        ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
        if (!ctx)
                return -ENOMEM;

        err = strp_init(&ctx->strp, sk, &cb);
        if (err)
                goto free;

        __sk_dst_reset(sk);

        strp_check_rcv(&ctx->strp);
        skb_queue_head_init(&ctx->ike_queue);
        skb_queue_head_init(&ctx->out_queue);

        if (sk->sk_family == AF_INET) {
                sk->sk_prot = &espintcp_prot;
                sk->sk_socket->ops = &espintcp_ops;
        } else {
                mutex_lock(&tcpv6_prot_mutex);
                if (!espintcp6_prot.recvmsg)
                        build_protos(&espintcp6_prot, &espintcp6_ops,
                                     sk->sk_prot, sk->sk_socket->ops);
                mutex_unlock(&tcpv6_prot_mutex);

                sk->sk_prot = &espintcp6_prot;
                sk->sk_socket->ops = &espintcp6_ops;
        }
        ctx->saved_data_ready = sk->sk_data_ready;
        ctx->saved_write_space = sk->sk_write_space;
        ctx->saved_destruct = sk->sk_destruct;
        sk->sk_data_ready = espintcp_data_ready;
        sk->sk_write_space = espintcp_write_space;
        sk->sk_destruct = espintcp_destruct;
        rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
        INIT_WORK(&ctx->work, espintcp_tx_work);

        /* avoid using task_frag */
        sk->sk_allocation = GFP_ATOMIC;

        return 0;

free:
        kfree(ctx);
        return err;
}

static void espintcp_release(struct sock *sk)
{
        struct espintcp_ctx *ctx = espintcp_getctx(sk);
        struct sk_buff_head queue;
        struct sk_buff *skb;

        __skb_queue_head_init(&queue);
        skb_queue_splice_init(&ctx->out_queue, &queue);

        while ((skb = __skb_dequeue(&queue)))
                espintcp_push_skb(sk, skb);

        tcp_release_cb(sk);
}

static void espintcp_close(struct sock *sk, long timeout)
{
        struct espintcp_ctx *ctx = espintcp_getctx(sk);
        struct espintcp_msg *emsg = &ctx->partial;

        strp_stop(&ctx->strp);

        sk->sk_prot = &tcp_prot;
        barrier();

        cancel_work_sync(&ctx->work);
        strp_done(&ctx->strp);

        skb_queue_purge(&ctx->out_queue);
        skb_queue_purge(&ctx->ike_queue);

        if (emsg->len) {
                if (emsg->skb)
                        kfree_skb(emsg->skb);
                else
                        sk_msg_free(sk, &emsg->skmsg);
        }

        tcp_close(sk, timeout);
}

static __poll_t espintcp_poll(struct file *file, struct socket *sock,
                              poll_table *wait)
{
        __poll_t mask = datagram_poll(file, sock, wait);
        struct sock *sk = sock->sk;
        struct espintcp_ctx *ctx = espintcp_getctx(sk);

        if (!skb_queue_empty(&ctx->ike_queue))
                mask |= EPOLLIN | EPOLLRDNORM;

        return mask;
}

static void build_protos(struct proto *espintcp_prot,
                         struct proto_ops *espintcp_ops,
                         const struct proto *orig_prot,
                         const struct proto_ops *orig_ops)
{
        memcpy(espintcp_prot, orig_prot, sizeof(struct proto));
        memcpy(espintcp_ops, orig_ops, sizeof(struct proto_ops));
        espintcp_prot->sendmsg = espintcp_sendmsg;
        espintcp_prot->recvmsg = espintcp_recvmsg;
        espintcp_prot->close = espintcp_close;
        espintcp_prot->release_cb = espintcp_release;
        espintcp_ops->poll = espintcp_poll;
}

static struct tcp_ulp_ops espintcp_ulp __read_mostly = {
        .name = "espintcp",
        .owner = THIS_MODULE,
        .init = espintcp_init_sk,
};

void __init espintcp_init(void)
{
        build_protos(&espintcp_prot, &espintcp_ops, &tcp_prot,
                     &inet_stream_ops);

        tcp_register_ulp(&espintcp_ulp);
}
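
/* Usage from the userspace side, for orientation (a sketch, not part of
 * this file): an IKE daemon connects a TCP socket (RFC 8229 uses port
 * 4500) and selects this ULP before starting the exchange, e.g.:
 *
 *      static const char ulp[] = "espintcp";
 *
 *      if (setsockopt(fd, IPPROTO_TCP, TCP_ULP, ulp, sizeof(ulp)))
 *              perror("setsockopt(TCP_ULP)");
 *
 * send()/recv() then carry IKE messages: the kernel adds and strips only
 * the 2-byte length prefix, so userspace still writes and reads the 4-byte
 * non-ESP marker itself. ESP never reaches userspace; it flows through
 * xfrm states configured with encap type TCP_ENCAP_ESPINTCP.
 */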