1 /* 2 * Checksum updating actions 3 * 4 * Copyright (c) 2010 Gregoire Baron <baronchon@n7mm.org> 5 * 6 * This program is free software; you can redistribute it and/or modify it 7 * under the terms of the GNU General Public License as published by the Free 8 * Software Foundation; either version 2 of the License, or (at your option) 9 * any later version. 10 * 11 */ 12 13 #include <linux/types.h> 14 #include <linux/init.h> 15 #include <linux/kernel.h> 16 #include <linux/module.h> 17 #include <linux/spinlock.h> 18 19 #include <linux/netlink.h> 20 #include <net/netlink.h> 21 #include <linux/rtnetlink.h> 22 23 #include <linux/skbuff.h> 24 25 #include <net/ip.h> 26 #include <net/ipv6.h> 27 #include <net/icmp.h> 28 #include <linux/icmpv6.h> 29 #include <linux/igmp.h> 30 #include <net/tcp.h> 31 #include <net/udp.h> 32 #include <net/ip6_checksum.h> 33 #include <net/sctp/checksum.h> 34 35 #include <net/act_api.h> 36 37 #include <linux/tc_act/tc_csum.h> 38 #include <net/tc_act/tc_csum.h> 39 40 static const struct nla_policy csum_policy[TCA_CSUM_MAX + 1] = { 41 [TCA_CSUM_PARMS] = { .len = sizeof(struct tc_csum), }, 42 }; 43 44 static unsigned int csum_net_id; 45 static struct tc_action_ops act_csum_ops; 46 47 static int tcf_csum_init(struct net *net, struct nlattr *nla, 48 struct nlattr *est, struct tc_action **a, int ovr, 49 int bind) 50 { 51 struct tc_action_net *tn = net_generic(net, csum_net_id); 52 struct nlattr *tb[TCA_CSUM_MAX + 1]; 53 struct tc_csum *parm; 54 struct tcf_csum *p; 55 int ret = 0, err; 56 57 if (nla == NULL) 58 return -EINVAL; 59 60 err = nla_parse_nested(tb, TCA_CSUM_MAX, nla, csum_policy, NULL); 61 if (err < 0) 62 return err; 63 64 if (tb[TCA_CSUM_PARMS] == NULL) 65 return -EINVAL; 66 parm = nla_data(tb[TCA_CSUM_PARMS]); 67 68 if (!tcf_idr_check(tn, parm->index, a, bind)) { 69 ret = tcf_idr_create(tn, parm->index, est, a, 70 &act_csum_ops, bind, false); 71 if (ret) 72 return ret; 73 ret = ACT_P_CREATED; 74 } else { 75 if (bind)/* dont override defaults */ 76 return 0; 77 tcf_idr_release(*a, bind); 78 if (!ovr) 79 return -EEXIST; 80 } 81 82 p = to_tcf_csum(*a); 83 spin_lock_bh(&p->tcf_lock); 84 p->tcf_action = parm->action; 85 p->update_flags = parm->update_flags; 86 spin_unlock_bh(&p->tcf_lock); 87 88 if (ret == ACT_P_CREATED) 89 tcf_idr_insert(tn, *a); 90 91 return ret; 92 } 93 94 /** 95 * tcf_csum_skb_nextlayer - Get next layer pointer 96 * @skb: sk_buff to use 97 * @ihl: previous summed headers length 98 * @ipl: complete packet length 99 * @jhl: next header length 100 * 101 * Check the expected next layer availability in the specified sk_buff. 102 * Return the next layer pointer if pass, NULL otherwise. 103 */ 104 static void *tcf_csum_skb_nextlayer(struct sk_buff *skb, 105 unsigned int ihl, unsigned int ipl, 106 unsigned int jhl) 107 { 108 int ntkoff = skb_network_offset(skb); 109 int hl = ihl + jhl; 110 111 if (!pskb_may_pull(skb, ipl + ntkoff) || (ipl < hl) || 112 skb_try_make_writable(skb, hl + ntkoff)) 113 return NULL; 114 else 115 return (void *)(skb_network_header(skb) + ihl); 116 } 117 118 static int tcf_csum_ipv4_icmp(struct sk_buff *skb, unsigned int ihl, 119 unsigned int ipl) 120 { 121 struct icmphdr *icmph; 122 123 icmph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*icmph)); 124 if (icmph == NULL) 125 return 0; 126 127 icmph->checksum = 0; 128 skb->csum = csum_partial(icmph, ipl - ihl, 0); 129 icmph->checksum = csum_fold(skb->csum); 130 131 skb->ip_summed = CHECKSUM_NONE; 132 133 return 1; 134 } 135 136 static int tcf_csum_ipv4_igmp(struct sk_buff *skb, 137 unsigned int ihl, unsigned int ipl) 138 { 139 struct igmphdr *igmph; 140 141 igmph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*igmph)); 142 if (igmph == NULL) 143 return 0; 144 145 igmph->csum = 0; 146 skb->csum = csum_partial(igmph, ipl - ihl, 0); 147 igmph->csum = csum_fold(skb->csum); 148 149 skb->ip_summed = CHECKSUM_NONE; 150 151 return 1; 152 } 153 154 static int tcf_csum_ipv6_icmp(struct sk_buff *skb, unsigned int ihl, 155 unsigned int ipl) 156 { 157 struct icmp6hdr *icmp6h; 158 const struct ipv6hdr *ip6h; 159 160 icmp6h = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*icmp6h)); 161 if (icmp6h == NULL) 162 return 0; 163 164 ip6h = ipv6_hdr(skb); 165 icmp6h->icmp6_cksum = 0; 166 skb->csum = csum_partial(icmp6h, ipl - ihl, 0); 167 icmp6h->icmp6_cksum = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr, 168 ipl - ihl, IPPROTO_ICMPV6, 169 skb->csum); 170 171 skb->ip_summed = CHECKSUM_NONE; 172 173 return 1; 174 } 175 176 static int tcf_csum_ipv4_tcp(struct sk_buff *skb, unsigned int ihl, 177 unsigned int ipl) 178 { 179 struct tcphdr *tcph; 180 const struct iphdr *iph; 181 182 if (skb_is_gso(skb) && skb_shinfo(skb)->gso_type & SKB_GSO_TCPV4) 183 return 1; 184 185 tcph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*tcph)); 186 if (tcph == NULL) 187 return 0; 188 189 iph = ip_hdr(skb); 190 tcph->check = 0; 191 skb->csum = csum_partial(tcph, ipl - ihl, 0); 192 tcph->check = tcp_v4_check(ipl - ihl, 193 iph->saddr, iph->daddr, skb->csum); 194 195 skb->ip_summed = CHECKSUM_NONE; 196 197 return 1; 198 } 199 200 static int tcf_csum_ipv6_tcp(struct sk_buff *skb, unsigned int ihl, 201 unsigned int ipl) 202 { 203 struct tcphdr *tcph; 204 const struct ipv6hdr *ip6h; 205 206 if (skb_is_gso(skb) && skb_shinfo(skb)->gso_type & SKB_GSO_TCPV6) 207 return 1; 208 209 tcph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*tcph)); 210 if (tcph == NULL) 211 return 0; 212 213 ip6h = ipv6_hdr(skb); 214 tcph->check = 0; 215 skb->csum = csum_partial(tcph, ipl - ihl, 0); 216 tcph->check = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr, 217 ipl - ihl, IPPROTO_TCP, 218 skb->csum); 219 220 skb->ip_summed = CHECKSUM_NONE; 221 222 return 1; 223 } 224 225 static int tcf_csum_ipv4_udp(struct sk_buff *skb, unsigned int ihl, 226 unsigned int ipl, int udplite) 227 { 228 struct udphdr *udph; 229 const struct iphdr *iph; 230 u16 ul; 231 232 /* 233 * Support both UDP and UDPLITE checksum algorithms, Don't use 234 * udph->len to get the real length without any protocol check, 235 * UDPLITE uses udph->len for another thing, 236 * Use iph->tot_len, or just ipl. 237 */ 238 239 udph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*udph)); 240 if (udph == NULL) 241 return 0; 242 243 iph = ip_hdr(skb); 244 ul = ntohs(udph->len); 245 246 if (udplite || udph->check) { 247 248 udph->check = 0; 249 250 if (udplite) { 251 if (ul == 0) 252 skb->csum = csum_partial(udph, ipl - ihl, 0); 253 else if ((ul >= sizeof(*udph)) && (ul <= ipl - ihl)) 254 skb->csum = csum_partial(udph, ul, 0); 255 else 256 goto ignore_obscure_skb; 257 } else { 258 if (ul != ipl - ihl) 259 goto ignore_obscure_skb; 260 261 skb->csum = csum_partial(udph, ul, 0); 262 } 263 264 udph->check = csum_tcpudp_magic(iph->saddr, iph->daddr, 265 ul, iph->protocol, 266 skb->csum); 267 268 if (!udph->check) 269 udph->check = CSUM_MANGLED_0; 270 } 271 272 skb->ip_summed = CHECKSUM_NONE; 273 274 ignore_obscure_skb: 275 return 1; 276 } 277 278 static int tcf_csum_ipv6_udp(struct sk_buff *skb, unsigned int ihl, 279 unsigned int ipl, int udplite) 280 { 281 struct udphdr *udph; 282 const struct ipv6hdr *ip6h; 283 u16 ul; 284 285 /* 286 * Support both UDP and UDPLITE checksum algorithms, Don't use 287 * udph->len to get the real length without any protocol check, 288 * UDPLITE uses udph->len for another thing, 289 * Use ip6h->payload_len + sizeof(*ip6h) ... , or just ipl. 290 */ 291 292 udph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*udph)); 293 if (udph == NULL) 294 return 0; 295 296 ip6h = ipv6_hdr(skb); 297 ul = ntohs(udph->len); 298 299 udph->check = 0; 300 301 if (udplite) { 302 if (ul == 0) 303 skb->csum = csum_partial(udph, ipl - ihl, 0); 304 305 else if ((ul >= sizeof(*udph)) && (ul <= ipl - ihl)) 306 skb->csum = csum_partial(udph, ul, 0); 307 308 else 309 goto ignore_obscure_skb; 310 } else { 311 if (ul != ipl - ihl) 312 goto ignore_obscure_skb; 313 314 skb->csum = csum_partial(udph, ul, 0); 315 } 316 317 udph->check = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr, ul, 318 udplite ? IPPROTO_UDPLITE : IPPROTO_UDP, 319 skb->csum); 320 321 if (!udph->check) 322 udph->check = CSUM_MANGLED_0; 323 324 skb->ip_summed = CHECKSUM_NONE; 325 326 ignore_obscure_skb: 327 return 1; 328 } 329 330 static int tcf_csum_sctp(struct sk_buff *skb, unsigned int ihl, 331 unsigned int ipl) 332 { 333 struct sctphdr *sctph; 334 335 if (skb_is_gso(skb) && skb_shinfo(skb)->gso_type & SKB_GSO_SCTP) 336 return 1; 337 338 sctph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*sctph)); 339 if (!sctph) 340 return 0; 341 342 sctph->checksum = sctp_compute_cksum(skb, 343 skb_network_offset(skb) + ihl); 344 skb->ip_summed = CHECKSUM_NONE; 345 skb->csum_not_inet = 0; 346 347 return 1; 348 } 349 350 static int tcf_csum_ipv4(struct sk_buff *skb, u32 update_flags) 351 { 352 const struct iphdr *iph; 353 int ntkoff; 354 355 ntkoff = skb_network_offset(skb); 356 357 if (!pskb_may_pull(skb, sizeof(*iph) + ntkoff)) 358 goto fail; 359 360 iph = ip_hdr(skb); 361 362 switch (iph->frag_off & htons(IP_OFFSET) ? 0 : iph->protocol) { 363 case IPPROTO_ICMP: 364 if (update_flags & TCA_CSUM_UPDATE_FLAG_ICMP) 365 if (!tcf_csum_ipv4_icmp(skb, iph->ihl * 4, 366 ntohs(iph->tot_len))) 367 goto fail; 368 break; 369 case IPPROTO_IGMP: 370 if (update_flags & TCA_CSUM_UPDATE_FLAG_IGMP) 371 if (!tcf_csum_ipv4_igmp(skb, iph->ihl * 4, 372 ntohs(iph->tot_len))) 373 goto fail; 374 break; 375 case IPPROTO_TCP: 376 if (update_flags & TCA_CSUM_UPDATE_FLAG_TCP) 377 if (!tcf_csum_ipv4_tcp(skb, iph->ihl * 4, 378 ntohs(iph->tot_len))) 379 goto fail; 380 break; 381 case IPPROTO_UDP: 382 if (update_flags & TCA_CSUM_UPDATE_FLAG_UDP) 383 if (!tcf_csum_ipv4_udp(skb, iph->ihl * 4, 384 ntohs(iph->tot_len), 0)) 385 goto fail; 386 break; 387 case IPPROTO_UDPLITE: 388 if (update_flags & TCA_CSUM_UPDATE_FLAG_UDPLITE) 389 if (!tcf_csum_ipv4_udp(skb, iph->ihl * 4, 390 ntohs(iph->tot_len), 1)) 391 goto fail; 392 break; 393 case IPPROTO_SCTP: 394 if ((update_flags & TCA_CSUM_UPDATE_FLAG_SCTP) && 395 !tcf_csum_sctp(skb, iph->ihl * 4, ntohs(iph->tot_len))) 396 goto fail; 397 break; 398 } 399 400 if (update_flags & TCA_CSUM_UPDATE_FLAG_IPV4HDR) { 401 if (skb_try_make_writable(skb, sizeof(*iph) + ntkoff)) 402 goto fail; 403 404 ip_send_check(ip_hdr(skb)); 405 } 406 407 return 1; 408 409 fail: 410 return 0; 411 } 412 413 static int tcf_csum_ipv6_hopopts(struct ipv6_opt_hdr *ip6xh, unsigned int ixhl, 414 unsigned int *pl) 415 { 416 int off, len, optlen; 417 unsigned char *xh = (void *)ip6xh; 418 419 off = sizeof(*ip6xh); 420 len = ixhl - off; 421 422 while (len > 1) { 423 switch (xh[off]) { 424 case IPV6_TLV_PAD1: 425 optlen = 1; 426 break; 427 case IPV6_TLV_JUMBO: 428 optlen = xh[off + 1] + 2; 429 if (optlen != 6 || len < 6 || (off & 3) != 2) 430 /* wrong jumbo option length/alignment */ 431 return 0; 432 *pl = ntohl(*(__be32 *)(xh + off + 2)); 433 goto done; 434 default: 435 optlen = xh[off + 1] + 2; 436 if (optlen > len) 437 /* ignore obscure options */ 438 goto done; 439 break; 440 } 441 off += optlen; 442 len -= optlen; 443 } 444 445 done: 446 return 1; 447 } 448 449 static int tcf_csum_ipv6(struct sk_buff *skb, u32 update_flags) 450 { 451 struct ipv6hdr *ip6h; 452 struct ipv6_opt_hdr *ip6xh; 453 unsigned int hl, ixhl; 454 unsigned int pl; 455 int ntkoff; 456 u8 nexthdr; 457 458 ntkoff = skb_network_offset(skb); 459 460 hl = sizeof(*ip6h); 461 462 if (!pskb_may_pull(skb, hl + ntkoff)) 463 goto fail; 464 465 ip6h = ipv6_hdr(skb); 466 467 pl = ntohs(ip6h->payload_len); 468 nexthdr = ip6h->nexthdr; 469 470 do { 471 switch (nexthdr) { 472 case NEXTHDR_FRAGMENT: 473 goto ignore_skb; 474 case NEXTHDR_ROUTING: 475 case NEXTHDR_HOP: 476 case NEXTHDR_DEST: 477 if (!pskb_may_pull(skb, hl + sizeof(*ip6xh) + ntkoff)) 478 goto fail; 479 ip6xh = (void *)(skb_network_header(skb) + hl); 480 ixhl = ipv6_optlen(ip6xh); 481 if (!pskb_may_pull(skb, hl + ixhl + ntkoff)) 482 goto fail; 483 ip6xh = (void *)(skb_network_header(skb) + hl); 484 if ((nexthdr == NEXTHDR_HOP) && 485 !(tcf_csum_ipv6_hopopts(ip6xh, ixhl, &pl))) 486 goto fail; 487 nexthdr = ip6xh->nexthdr; 488 hl += ixhl; 489 break; 490 case IPPROTO_ICMPV6: 491 if (update_flags & TCA_CSUM_UPDATE_FLAG_ICMP) 492 if (!tcf_csum_ipv6_icmp(skb, 493 hl, pl + sizeof(*ip6h))) 494 goto fail; 495 goto done; 496 case IPPROTO_TCP: 497 if (update_flags & TCA_CSUM_UPDATE_FLAG_TCP) 498 if (!tcf_csum_ipv6_tcp(skb, 499 hl, pl + sizeof(*ip6h))) 500 goto fail; 501 goto done; 502 case IPPROTO_UDP: 503 if (update_flags & TCA_CSUM_UPDATE_FLAG_UDP) 504 if (!tcf_csum_ipv6_udp(skb, hl, 505 pl + sizeof(*ip6h), 0)) 506 goto fail; 507 goto done; 508 case IPPROTO_UDPLITE: 509 if (update_flags & TCA_CSUM_UPDATE_FLAG_UDPLITE) 510 if (!tcf_csum_ipv6_udp(skb, hl, 511 pl + sizeof(*ip6h), 1)) 512 goto fail; 513 goto done; 514 case IPPROTO_SCTP: 515 if ((update_flags & TCA_CSUM_UPDATE_FLAG_SCTP) && 516 !tcf_csum_sctp(skb, hl, pl + sizeof(*ip6h))) 517 goto fail; 518 goto done; 519 default: 520 goto ignore_skb; 521 } 522 } while (pskb_may_pull(skb, hl + 1 + ntkoff)); 523 524 done: 525 ignore_skb: 526 return 1; 527 528 fail: 529 return 0; 530 } 531 532 static int tcf_csum(struct sk_buff *skb, const struct tc_action *a, 533 struct tcf_result *res) 534 { 535 struct tcf_csum *p = to_tcf_csum(a); 536 int action; 537 u32 update_flags; 538 539 spin_lock(&p->tcf_lock); 540 tcf_lastuse_update(&p->tcf_tm); 541 bstats_update(&p->tcf_bstats, skb); 542 action = p->tcf_action; 543 update_flags = p->update_flags; 544 spin_unlock(&p->tcf_lock); 545 546 if (unlikely(action == TC_ACT_SHOT)) 547 goto drop; 548 549 switch (tc_skb_protocol(skb)) { 550 case cpu_to_be16(ETH_P_IP): 551 if (!tcf_csum_ipv4(skb, update_flags)) 552 goto drop; 553 break; 554 case cpu_to_be16(ETH_P_IPV6): 555 if (!tcf_csum_ipv6(skb, update_flags)) 556 goto drop; 557 break; 558 } 559 560 return action; 561 562 drop: 563 spin_lock(&p->tcf_lock); 564 p->tcf_qstats.drops++; 565 spin_unlock(&p->tcf_lock); 566 return TC_ACT_SHOT; 567 } 568 569 static int tcf_csum_dump(struct sk_buff *skb, struct tc_action *a, int bind, 570 int ref) 571 { 572 unsigned char *b = skb_tail_pointer(skb); 573 struct tcf_csum *p = to_tcf_csum(a); 574 struct tc_csum opt = { 575 .update_flags = p->update_flags, 576 .index = p->tcf_index, 577 .action = p->tcf_action, 578 .refcnt = p->tcf_refcnt - ref, 579 .bindcnt = p->tcf_bindcnt - bind, 580 }; 581 struct tcf_t t; 582 583 if (nla_put(skb, TCA_CSUM_PARMS, sizeof(opt), &opt)) 584 goto nla_put_failure; 585 586 tcf_tm_dump(&t, &p->tcf_tm); 587 if (nla_put_64bit(skb, TCA_CSUM_TM, sizeof(t), &t, TCA_CSUM_PAD)) 588 goto nla_put_failure; 589 590 return skb->len; 591 592 nla_put_failure: 593 nlmsg_trim(skb, b); 594 return -1; 595 } 596 597 static int tcf_csum_walker(struct net *net, struct sk_buff *skb, 598 struct netlink_callback *cb, int type, 599 const struct tc_action_ops *ops) 600 { 601 struct tc_action_net *tn = net_generic(net, csum_net_id); 602 603 return tcf_generic_walker(tn, skb, cb, type, ops); 604 } 605 606 static int tcf_csum_search(struct net *net, struct tc_action **a, u32 index) 607 { 608 struct tc_action_net *tn = net_generic(net, csum_net_id); 609 610 return tcf_idr_search(tn, a, index); 611 } 612 613 static struct tc_action_ops act_csum_ops = { 614 .kind = "csum", 615 .type = TCA_ACT_CSUM, 616 .owner = THIS_MODULE, 617 .act = tcf_csum, 618 .dump = tcf_csum_dump, 619 .init = tcf_csum_init, 620 .walk = tcf_csum_walker, 621 .lookup = tcf_csum_search, 622 .size = sizeof(struct tcf_csum), 623 }; 624 625 static __net_init int csum_init_net(struct net *net) 626 { 627 struct tc_action_net *tn = net_generic(net, csum_net_id); 628 629 return tc_action_net_init(tn, &act_csum_ops); 630 } 631 632 static void __net_exit csum_exit_net(struct net *net) 633 { 634 struct tc_action_net *tn = net_generic(net, csum_net_id); 635 636 tc_action_net_exit(tn); 637 } 638 639 static struct pernet_operations csum_net_ops = { 640 .init = csum_init_net, 641 .exit = csum_exit_net, 642 .id = &csum_net_id, 643 .size = sizeof(struct tc_action_net), 644 }; 645 646 MODULE_DESCRIPTION("Checksum updating actions"); 647 MODULE_LICENSE("GPL"); 648 649 static int __init csum_init_module(void) 650 { 651 return tcf_register_action(&act_csum_ops, &csum_net_ops); 652 } 653 654 static void __exit csum_cleanup_module(void) 655 { 656 tcf_unregister_action(&act_csum_ops, &csum_net_ops); 657 } 658 659 module_init(csum_init_module); 660 module_exit(csum_cleanup_module); 661