// SPDX-License-Identifier: GPL-2.0
#include <net/tcp.h>
#include <net/strparser.h>
#include <net/xfrm.h>
#include <net/esp.h>
#include <net/espintcp.h>
#include <linux/skmsg.h>
#include <net/inet_common.h>

static void handle_nonesp(struct espintcp_ctx *ctx, struct sk_buff *skb,
			  struct sock *sk)
{
	if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf ||
	    !sk_rmem_schedule(sk, skb, skb->truesize)) {
		kfree_skb(skb);
		return;
	}

	skb_set_owner_r(skb, sk);

	memset(skb->cb, 0, sizeof(skb->cb));
	skb_queue_tail(&ctx->ike_queue, skb);
	ctx->saved_data_ready(sk);
}

static void handle_esp(struct sk_buff *skb, struct sock *sk)
{
	skb_reset_transport_header(skb);
	memset(skb->cb, 0, sizeof(skb->cb));

	rcu_read_lock();
	skb->dev = dev_get_by_index_rcu(sock_net(sk), skb->skb_iif);
	local_bh_disable();
	xfrm4_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
	local_bh_enable();
	rcu_read_unlock();
}

static void espintcp_rcv(struct strparser *strp, struct sk_buff *skb)
{
	struct espintcp_ctx *ctx = container_of(strp, struct espintcp_ctx,
						strp);
	struct strp_msg *rxm = strp_msg(skb);
	u32 nonesp_marker;
	int err;

	err = skb_copy_bits(skb, rxm->offset + 2, &nonesp_marker,
			    sizeof(nonesp_marker));
	if (err < 0) {
		kfree_skb(skb);
		return;
	}

	/* remove header, leave non-ESP marker/SPI */
	if (!__pskb_pull(skb, rxm->offset + 2)) {
		kfree_skb(skb);
		return;
	}

	if (pskb_trim(skb, rxm->full_len - 2) != 0) {
		kfree_skb(skb);
		return;
	}

	if (nonesp_marker == 0)
		handle_nonesp(ctx, skb, strp->sk);
	else
		handle_esp(skb, strp->sk);
}

static int espintcp_parse(struct strparser *strp, struct sk_buff *skb)
{
	struct strp_msg *rxm = strp_msg(skb);
	__be16 blen;
	u16 len;
	int err;

	if (skb->len < rxm->offset + 2)
		return 0;

	err = skb_copy_bits(skb, rxm->offset, &blen, sizeof(blen));
	if (err < 0)
		return err;

	len = be16_to_cpu(blen);
	if (len < 6)
		return -EINVAL;

	return len;
}
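/*
 * For reference, the wire format espintcp_parse() and espintcp_rcv() expect
 * is the RFC 8229 framing: every message starts with a 2-byte big-endian
 * length that includes the length field itself, followed by a 4-byte zero
 * "non-ESP marker" for IKE messages or the ESP SPI otherwise (hence the
 * len < 6 check above).  A minimal userspace-style sketch of building such
 * a frame; build_espintcp_frame() is hypothetical and only illustrative:
 *
 *	#include <arpa/inet.h>
 *	#include <stdint.h>
 *	#include <string.h>
 *
 *	// payload must already start with the non-ESP marker (IKE) or SPI (ESP)
 *	static size_t build_espintcp_frame(uint8_t *out, const uint8_t *payload,
 *					   size_t plen)
 *	{
 *		uint16_t len = htons(plen + 2);	// length covers itself too
 *
 *		memcpy(out, &len, sizeof(len));	// 2-byte length prefix
 *		memcpy(out + sizeof(len), payload, plen);
 *		return plen + sizeof(len);
 *	}
 */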
static int espintcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			    int nonblock, int flags, int *addr_len)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct sk_buff *skb;
	int err = 0;
	int copied;
	int off = 0;

	flags |= nonblock ? MSG_DONTWAIT : 0;

	skb = __skb_recv_datagram(sk, &ctx->ike_queue, flags, &off, &err);
	if (!skb)
		return err;

	copied = len;
	if (copied > skb->len)
		copied = skb->len;
	else if (copied < skb->len)
		msg->msg_flags |= MSG_TRUNC;

	err = skb_copy_datagram_msg(skb, 0, msg, copied);
	if (unlikely(err)) {
		kfree_skb(skb);
		return err;
	}

	if (flags & MSG_TRUNC)
		copied = skb->len;
	kfree_skb(skb);
	return copied;
}

int espintcp_queue_out(struct sock *sk, struct sk_buff *skb)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	if (skb_queue_len(&ctx->out_queue) >= netdev_max_backlog)
		return -ENOBUFS;

	__skb_queue_tail(&ctx->out_queue, skb);

	return 0;
}
EXPORT_SYMBOL_GPL(espintcp_queue_out);

/* espintcp length field is 2B and length includes the length field's size */
#define MAX_ESPINTCP_MSG (((1 << 16) - 1) - 2)

static int espintcp_sendskb_locked(struct sock *sk, struct espintcp_msg *emsg,
				   int flags)
{
	do {
		int ret;

		ret = skb_send_sock_locked(sk, emsg->skb,
					   emsg->offset, emsg->len);
		if (ret < 0)
			return ret;

		emsg->len -= ret;
		emsg->offset += ret;
	} while (emsg->len > 0);

	kfree_skb(emsg->skb);
	memset(emsg, 0, sizeof(*emsg));

	return 0;
}

static int espintcp_sendskmsg_locked(struct sock *sk,
				     struct espintcp_msg *emsg, int flags)
{
	struct sk_msg *skmsg = &emsg->skmsg;
	struct scatterlist *sg;
	int done = 0;
	int ret;

	flags |= MSG_SENDPAGE_NOTLAST;
	sg = &skmsg->sg.data[skmsg->sg.start];
	do {
		size_t size = sg->length - emsg->offset;
		int offset = sg->offset + emsg->offset;
		struct page *p;

		emsg->offset = 0;

		if (sg_is_last(sg))
			flags &= ~MSG_SENDPAGE_NOTLAST;

		p = sg_page(sg);
retry:
		ret = do_tcp_sendpages(sk, p, offset, size, flags);
		if (ret < 0) {
			emsg->offset = offset - sg->offset;
			skmsg->sg.start += done;
			return ret;
		}

		if (ret != size) {
			offset += ret;
			size -= ret;
			goto retry;
		}

		done++;
		put_page(p);
		sk_mem_uncharge(sk, sg->length);
		sg = sg_next(sg);
	} while (sg);

	memset(emsg, 0, sizeof(*emsg));

	return 0;
}

static int espintcp_push_msgs(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;
	int err;

	if (!emsg->len)
		return 0;

	if (ctx->tx_running)
		return -EAGAIN;
	ctx->tx_running = 1;

	if (emsg->skb)
		err = espintcp_sendskb_locked(sk, emsg, 0);
	else
		err = espintcp_sendskmsg_locked(sk, emsg, 0);
	if (err == -EAGAIN) {
		ctx->tx_running = 0;
		return 0;
	}
	if (!err)
		memset(emsg, 0, sizeof(*emsg));

	ctx->tx_running = 0;

	return err;
}
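/*
 * TX path for ESP packets, as wired up in this file: the xfrm output code
 * queues skbs via espintcp_queue_out() and espintcp_release() pushes them
 * out of release_cb; IKE messages from userspace go through
 * espintcp_sendmsg() instead.  Only one message lives in ctx->partial at a
 * time, so espintcp_push_skb() below drops new skbs with -ENOBUFS while a
 * previous message hasn't fully drained.
 */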
int espintcp_push_skb(struct sock *sk, struct sk_buff *skb)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;
	unsigned int len;
	int offset;

	if (sk->sk_state != TCP_ESTABLISHED) {
		kfree_skb(skb);
		return -ECONNRESET;
	}

	offset = skb_transport_offset(skb);
	len = skb->len - offset;

	espintcp_push_msgs(sk);

	if (emsg->len) {
		kfree_skb(skb);
		return -ENOBUFS;
	}

	skb_set_owner_w(skb, sk);

	emsg->offset = offset;
	emsg->len = len;
	emsg->skb = skb;

	espintcp_push_msgs(sk);

	return 0;
}
EXPORT_SYMBOL_GPL(espintcp_push_skb);

static int espintcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;
	struct iov_iter pfx_iter;
	struct kvec pfx_iov = {};
	size_t msglen = size + 2;
	char buf[2] = {0};
	int err, end;

	if (msg->msg_flags)
		return -EOPNOTSUPP;

	if (size > MAX_ESPINTCP_MSG)
		return -EMSGSIZE;

	if (msg->msg_controllen)
		return -EOPNOTSUPP;

	lock_sock(sk);

	err = espintcp_push_msgs(sk);
	if (err < 0) {
		err = -ENOBUFS;
		goto unlock;
	}

	sk_msg_init(&emsg->skmsg);
	while (1) {
		/* only -ENOMEM is possible since we don't coalesce */
		err = sk_msg_alloc(sk, &emsg->skmsg, msglen, 0);
		if (!err)
			break;

		err = sk_stream_wait_memory(sk, &timeo);
		if (err)
			goto fail;
	}

	*((__be16 *)buf) = cpu_to_be16(msglen);
	pfx_iov.iov_base = buf;
	pfx_iov.iov_len = sizeof(buf);
	iov_iter_kvec(&pfx_iter, WRITE, &pfx_iov, 1, pfx_iov.iov_len);

	err = sk_msg_memcopy_from_iter(sk, &pfx_iter, &emsg->skmsg,
				       pfx_iov.iov_len);
	if (err < 0)
		goto fail;

	err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, &emsg->skmsg, size);
	if (err < 0)
		goto fail;

	end = emsg->skmsg.sg.end;
	emsg->len = size;
	sk_msg_iter_var_prev(end);
	sg_mark_end(sk_msg_elem(&emsg->skmsg, end));

	tcp_rate_check_app_limited(sk);

	err = espintcp_push_msgs(sk);
	/* this message could be partially sent, keep it */
	if (err < 0)
		goto unlock;
	release_sock(sk);

	return size;

fail:
	sk_msg_free(sk, &emsg->skmsg);
	memset(emsg, 0, sizeof(*emsg));
unlock:
	release_sock(sk);
	return err;
}

static struct proto espintcp_prot __ro_after_init;
static struct proto_ops espintcp_ops __ro_after_init;

static void espintcp_data_ready(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	strp_data_ready(&ctx->strp);
}

static void espintcp_tx_work(struct work_struct *work)
{
	struct espintcp_ctx *ctx = container_of(work,
						struct espintcp_ctx, work);
	struct sock *sk = ctx->strp.sk;

	lock_sock(sk);
	if (!ctx->tx_running)
		espintcp_push_msgs(sk);
	release_sock(sk);
}

static void espintcp_write_space(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	schedule_work(&ctx->work);
	ctx->saved_write_space(sk);
}

static void espintcp_destruct(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	kfree(ctx);
}

bool tcp_is_ulp_esp(struct sock *sk)
{
	return sk->sk_prot == &espintcp_prot;
}
EXPORT_SYMBOL_GPL(tcp_is_ulp_esp);
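/*
 * espintcp_init_sk() below is the ULP init hook; it runs when the IKE
 * daemon attaches the "espintcp" ULP to its established TCP socket.  A
 * minimal sketch of the assumed userspace side (not part of this file):
 *
 *	#include <netinet/in.h>
 *	#include <netinet/tcp.h>
 *	#include <sys/socket.h>
 *
 *	// fd: a connected TCP socket carrying IKE and ESP per RFC 8229
 *	if (setsockopt(fd, IPPROTO_TCP, TCP_ULP, "espintcp",
 *		       sizeof("espintcp")) < 0)
 *		perror("setsockopt(TCP_ULP)");
 */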
static int espintcp_init_sk(struct sock *sk)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct strp_callbacks cb = {
		.rcv_msg = espintcp_rcv,
		.parse_msg = espintcp_parse,
	};
	struct espintcp_ctx *ctx;
	int err;

	/* sockmap is not compatible with espintcp */
	if (sk->sk_user_data)
		return -EBUSY;

	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
	if (!ctx)
		return -ENOMEM;

	err = strp_init(&ctx->strp, sk, &cb);
	if (err)
		goto free;

	__sk_dst_reset(sk);

	strp_check_rcv(&ctx->strp);
	skb_queue_head_init(&ctx->ike_queue);
	skb_queue_head_init(&ctx->out_queue);
	sk->sk_prot = &espintcp_prot;
	sk->sk_socket->ops = &espintcp_ops;
	ctx->saved_data_ready = sk->sk_data_ready;
	ctx->saved_write_space = sk->sk_write_space;
	sk->sk_data_ready = espintcp_data_ready;
	sk->sk_write_space = espintcp_write_space;
	sk->sk_destruct = espintcp_destruct;
	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
	INIT_WORK(&ctx->work, espintcp_tx_work);

	/* avoid using task_frag */
	sk->sk_allocation = GFP_ATOMIC;

	return 0;

free:
	kfree(ctx);
	return err;
}

static void espintcp_release(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct sk_buff_head queue;
	struct sk_buff *skb;

	__skb_queue_head_init(&queue);
	skb_queue_splice_init(&ctx->out_queue, &queue);

	while ((skb = __skb_dequeue(&queue)))
		espintcp_push_skb(sk, skb);

	tcp_release_cb(sk);
}

static void espintcp_close(struct sock *sk, long timeout)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;

	strp_stop(&ctx->strp);

	sk->sk_prot = &tcp_prot;
	barrier();

	cancel_work_sync(&ctx->work);
	strp_done(&ctx->strp);

	skb_queue_purge(&ctx->out_queue);
	skb_queue_purge(&ctx->ike_queue);

	if (emsg->len) {
		if (emsg->skb)
			kfree_skb(emsg->skb);
		else
			sk_msg_free(sk, &emsg->skmsg);
	}

	tcp_close(sk, timeout);
}

static __poll_t espintcp_poll(struct file *file, struct socket *sock,
			      poll_table *wait)
{
	__poll_t mask = datagram_poll(file, sock, wait);
	struct sock *sk = sock->sk;
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	if (!skb_queue_empty(&ctx->ike_queue))
		mask |= EPOLLIN | EPOLLRDNORM;

	return mask;
}

static struct tcp_ulp_ops espintcp_ulp __read_mostly = {
	.name = "espintcp",
	.owner = THIS_MODULE,
	.init = espintcp_init_sk,
};

void __init espintcp_init(void)
{
	memcpy(&espintcp_prot, &tcp_prot, sizeof(tcp_prot));
	memcpy(&espintcp_ops, &inet_stream_ops, sizeof(inet_stream_ops));
	espintcp_prot.sendmsg = espintcp_sendmsg;
	espintcp_prot.recvmsg = espintcp_recvmsg;
	espintcp_prot.close = espintcp_close;
	espintcp_prot.release_cb = espintcp_release;
	espintcp_ops.poll = espintcp_poll;

	tcp_register_ulp(&espintcp_ulp);
}
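/*
 * Usage note (a sketch, not guaranteed syntax): with the ULP attached, the
 * IKE daemon installs xfrm states whose encapsulation template selects TCP
 * encapsulation (encap_type == TCP_ENCAP_ESPINTCP, the value handle_esp()
 * passes to xfrm4_rcv_encap() above).  With a recent iproute2 this would
 * look roughly like:
 *
 *	ip xfrm state add src 10.0.0.1 dst 10.0.0.2 proto esp spi 0x1000 \
 *		reqid 1 mode tunnel aead 'rfc4106(gcm(aes))' $KEY 128 \
 *		encap espintcp 4500 4500 0.0.0.0
 */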