1 /* 2 * Checksum updating actions 3 * 4 * Copyright (c) 2010 Gregoire Baron <baronchon@n7mm.org> 5 * 6 * This program is free software; you can redistribute it and/or modify it 7 * under the terms of the GNU General Public License as published by the Free 8 * Software Foundation; either version 2 of the License, or (at your option) 9 * any later version. 10 * 11 */ 12 13 #include <linux/types.h> 14 #include <linux/init.h> 15 #include <linux/kernel.h> 16 #include <linux/module.h> 17 #include <linux/spinlock.h> 18 19 #include <linux/netlink.h> 20 #include <net/netlink.h> 21 #include <linux/rtnetlink.h> 22 23 #include <linux/skbuff.h> 24 25 #include <net/ip.h> 26 #include <net/ipv6.h> 27 #include <net/icmp.h> 28 #include <linux/icmpv6.h> 29 #include <linux/igmp.h> 30 #include <net/tcp.h> 31 #include <net/udp.h> 32 #include <net/ip6_checksum.h> 33 #include <net/sctp/checksum.h> 34 35 #include <net/act_api.h> 36 37 #include <linux/tc_act/tc_csum.h> 38 #include <net/tc_act/tc_csum.h> 39 40 #define CSUM_TAB_MASK 15 41 42 static const struct nla_policy csum_policy[TCA_CSUM_MAX + 1] = { 43 [TCA_CSUM_PARMS] = { .len = sizeof(struct tc_csum), }, 44 }; 45 46 static unsigned int csum_net_id; 47 static struct tc_action_ops act_csum_ops; 48 49 static int tcf_csum_init(struct net *net, struct nlattr *nla, 50 struct nlattr *est, struct tc_action **a, int ovr, 51 int bind) 52 { 53 struct tc_action_net *tn = net_generic(net, csum_net_id); 54 struct nlattr *tb[TCA_CSUM_MAX + 1]; 55 struct tc_csum *parm; 56 struct tcf_csum *p; 57 int ret = 0, err; 58 59 if (nla == NULL) 60 return -EINVAL; 61 62 err = nla_parse_nested(tb, TCA_CSUM_MAX, nla, csum_policy, NULL); 63 if (err < 0) 64 return err; 65 66 if (tb[TCA_CSUM_PARMS] == NULL) 67 return -EINVAL; 68 parm = nla_data(tb[TCA_CSUM_PARMS]); 69 70 if (!tcf_hash_check(tn, parm->index, a, bind)) { 71 ret = tcf_hash_create(tn, parm->index, est, a, 72 &act_csum_ops, bind, false); 73 if (ret) 74 return ret; 75 ret = ACT_P_CREATED; 76 } else { 77 if (bind)/* dont override defaults */ 78 return 0; 79 tcf_hash_release(*a, bind); 80 if (!ovr) 81 return -EEXIST; 82 } 83 84 p = to_tcf_csum(*a); 85 spin_lock_bh(&p->tcf_lock); 86 p->tcf_action = parm->action; 87 p->update_flags = parm->update_flags; 88 spin_unlock_bh(&p->tcf_lock); 89 90 if (ret == ACT_P_CREATED) 91 tcf_hash_insert(tn, *a); 92 93 return ret; 94 } 95 96 /** 97 * tcf_csum_skb_nextlayer - Get next layer pointer 98 * @skb: sk_buff to use 99 * @ihl: previous summed headers length 100 * @ipl: complete packet length 101 * @jhl: next header length 102 * 103 * Check the expected next layer availability in the specified sk_buff. 104 * Return the next layer pointer if pass, NULL otherwise. 105 */ 106 static void *tcf_csum_skb_nextlayer(struct sk_buff *skb, 107 unsigned int ihl, unsigned int ipl, 108 unsigned int jhl) 109 { 110 int ntkoff = skb_network_offset(skb); 111 int hl = ihl + jhl; 112 113 if (!pskb_may_pull(skb, ipl + ntkoff) || (ipl < hl) || 114 skb_try_make_writable(skb, hl + ntkoff)) 115 return NULL; 116 else 117 return (void *)(skb_network_header(skb) + ihl); 118 } 119 120 static int tcf_csum_ipv4_icmp(struct sk_buff *skb, unsigned int ihl, 121 unsigned int ipl) 122 { 123 struct icmphdr *icmph; 124 125 icmph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*icmph)); 126 if (icmph == NULL) 127 return 0; 128 129 icmph->checksum = 0; 130 skb->csum = csum_partial(icmph, ipl - ihl, 0); 131 icmph->checksum = csum_fold(skb->csum); 132 133 skb->ip_summed = CHECKSUM_NONE; 134 135 return 1; 136 } 137 138 static int tcf_csum_ipv4_igmp(struct sk_buff *skb, 139 unsigned int ihl, unsigned int ipl) 140 { 141 struct igmphdr *igmph; 142 143 igmph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*igmph)); 144 if (igmph == NULL) 145 return 0; 146 147 igmph->csum = 0; 148 skb->csum = csum_partial(igmph, ipl - ihl, 0); 149 igmph->csum = csum_fold(skb->csum); 150 151 skb->ip_summed = CHECKSUM_NONE; 152 153 return 1; 154 } 155 156 static int tcf_csum_ipv6_icmp(struct sk_buff *skb, unsigned int ihl, 157 unsigned int ipl) 158 { 159 struct icmp6hdr *icmp6h; 160 const struct ipv6hdr *ip6h; 161 162 icmp6h = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*icmp6h)); 163 if (icmp6h == NULL) 164 return 0; 165 166 ip6h = ipv6_hdr(skb); 167 icmp6h->icmp6_cksum = 0; 168 skb->csum = csum_partial(icmp6h, ipl - ihl, 0); 169 icmp6h->icmp6_cksum = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr, 170 ipl - ihl, IPPROTO_ICMPV6, 171 skb->csum); 172 173 skb->ip_summed = CHECKSUM_NONE; 174 175 return 1; 176 } 177 178 static int tcf_csum_ipv4_tcp(struct sk_buff *skb, unsigned int ihl, 179 unsigned int ipl) 180 { 181 struct tcphdr *tcph; 182 const struct iphdr *iph; 183 184 if (skb_is_gso(skb) && skb_shinfo(skb)->gso_type & SKB_GSO_TCPV4) 185 return 1; 186 187 tcph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*tcph)); 188 if (tcph == NULL) 189 return 0; 190 191 iph = ip_hdr(skb); 192 tcph->check = 0; 193 skb->csum = csum_partial(tcph, ipl - ihl, 0); 194 tcph->check = tcp_v4_check(ipl - ihl, 195 iph->saddr, iph->daddr, skb->csum); 196 197 skb->ip_summed = CHECKSUM_NONE; 198 199 return 1; 200 } 201 202 static int tcf_csum_ipv6_tcp(struct sk_buff *skb, unsigned int ihl, 203 unsigned int ipl) 204 { 205 struct tcphdr *tcph; 206 const struct ipv6hdr *ip6h; 207 208 if (skb_is_gso(skb) && skb_shinfo(skb)->gso_type & SKB_GSO_TCPV6) 209 return 1; 210 211 tcph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*tcph)); 212 if (tcph == NULL) 213 return 0; 214 215 ip6h = ipv6_hdr(skb); 216 tcph->check = 0; 217 skb->csum = csum_partial(tcph, ipl - ihl, 0); 218 tcph->check = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr, 219 ipl - ihl, IPPROTO_TCP, 220 skb->csum); 221 222 skb->ip_summed = CHECKSUM_NONE; 223 224 return 1; 225 } 226 227 static int tcf_csum_ipv4_udp(struct sk_buff *skb, unsigned int ihl, 228 unsigned int ipl, int udplite) 229 { 230 struct udphdr *udph; 231 const struct iphdr *iph; 232 u16 ul; 233 234 if (skb_is_gso(skb) && skb_shinfo(skb)->gso_type & SKB_GSO_UDP) 235 return 1; 236 237 /* 238 * Support both UDP and UDPLITE checksum algorithms, Don't use 239 * udph->len to get the real length without any protocol check, 240 * UDPLITE uses udph->len for another thing, 241 * Use iph->tot_len, or just ipl. 242 */ 243 244 udph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*udph)); 245 if (udph == NULL) 246 return 0; 247 248 iph = ip_hdr(skb); 249 ul = ntohs(udph->len); 250 251 if (udplite || udph->check) { 252 253 udph->check = 0; 254 255 if (udplite) { 256 if (ul == 0) 257 skb->csum = csum_partial(udph, ipl - ihl, 0); 258 else if ((ul >= sizeof(*udph)) && (ul <= ipl - ihl)) 259 skb->csum = csum_partial(udph, ul, 0); 260 else 261 goto ignore_obscure_skb; 262 } else { 263 if (ul != ipl - ihl) 264 goto ignore_obscure_skb; 265 266 skb->csum = csum_partial(udph, ul, 0); 267 } 268 269 udph->check = csum_tcpudp_magic(iph->saddr, iph->daddr, 270 ul, iph->protocol, 271 skb->csum); 272 273 if (!udph->check) 274 udph->check = CSUM_MANGLED_0; 275 } 276 277 skb->ip_summed = CHECKSUM_NONE; 278 279 ignore_obscure_skb: 280 return 1; 281 } 282 283 static int tcf_csum_ipv6_udp(struct sk_buff *skb, unsigned int ihl, 284 unsigned int ipl, int udplite) 285 { 286 struct udphdr *udph; 287 const struct ipv6hdr *ip6h; 288 u16 ul; 289 290 if (skb_is_gso(skb) && skb_shinfo(skb)->gso_type & SKB_GSO_UDP) 291 return 1; 292 293 /* 294 * Support both UDP and UDPLITE checksum algorithms, Don't use 295 * udph->len to get the real length without any protocol check, 296 * UDPLITE uses udph->len for another thing, 297 * Use ip6h->payload_len + sizeof(*ip6h) ... , or just ipl. 298 */ 299 300 udph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*udph)); 301 if (udph == NULL) 302 return 0; 303 304 ip6h = ipv6_hdr(skb); 305 ul = ntohs(udph->len); 306 307 udph->check = 0; 308 309 if (udplite) { 310 if (ul == 0) 311 skb->csum = csum_partial(udph, ipl - ihl, 0); 312 313 else if ((ul >= sizeof(*udph)) && (ul <= ipl - ihl)) 314 skb->csum = csum_partial(udph, ul, 0); 315 316 else 317 goto ignore_obscure_skb; 318 } else { 319 if (ul != ipl - ihl) 320 goto ignore_obscure_skb; 321 322 skb->csum = csum_partial(udph, ul, 0); 323 } 324 325 udph->check = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr, ul, 326 udplite ? IPPROTO_UDPLITE : IPPROTO_UDP, 327 skb->csum); 328 329 if (!udph->check) 330 udph->check = CSUM_MANGLED_0; 331 332 skb->ip_summed = CHECKSUM_NONE; 333 334 ignore_obscure_skb: 335 return 1; 336 } 337 338 static int tcf_csum_sctp(struct sk_buff *skb, unsigned int ihl, 339 unsigned int ipl) 340 { 341 struct sctphdr *sctph; 342 343 if (skb_is_gso(skb) && skb_shinfo(skb)->gso_type & SKB_GSO_SCTP) 344 return 1; 345 346 sctph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*sctph)); 347 if (!sctph) 348 return 0; 349 350 sctph->checksum = sctp_compute_cksum(skb, 351 skb_network_offset(skb) + ihl); 352 skb->ip_summed = CHECKSUM_NONE; 353 skb->csum_not_inet = 0; 354 355 return 1; 356 } 357 358 static int tcf_csum_ipv4(struct sk_buff *skb, u32 update_flags) 359 { 360 const struct iphdr *iph; 361 int ntkoff; 362 363 ntkoff = skb_network_offset(skb); 364 365 if (!pskb_may_pull(skb, sizeof(*iph) + ntkoff)) 366 goto fail; 367 368 iph = ip_hdr(skb); 369 370 switch (iph->frag_off & htons(IP_OFFSET) ? 0 : iph->protocol) { 371 case IPPROTO_ICMP: 372 if (update_flags & TCA_CSUM_UPDATE_FLAG_ICMP) 373 if (!tcf_csum_ipv4_icmp(skb, iph->ihl * 4, 374 ntohs(iph->tot_len))) 375 goto fail; 376 break; 377 case IPPROTO_IGMP: 378 if (update_flags & TCA_CSUM_UPDATE_FLAG_IGMP) 379 if (!tcf_csum_ipv4_igmp(skb, iph->ihl * 4, 380 ntohs(iph->tot_len))) 381 goto fail; 382 break; 383 case IPPROTO_TCP: 384 if (update_flags & TCA_CSUM_UPDATE_FLAG_TCP) 385 if (!tcf_csum_ipv4_tcp(skb, iph->ihl * 4, 386 ntohs(iph->tot_len))) 387 goto fail; 388 break; 389 case IPPROTO_UDP: 390 if (update_flags & TCA_CSUM_UPDATE_FLAG_UDP) 391 if (!tcf_csum_ipv4_udp(skb, iph->ihl * 4, 392 ntohs(iph->tot_len), 0)) 393 goto fail; 394 break; 395 case IPPROTO_UDPLITE: 396 if (update_flags & TCA_CSUM_UPDATE_FLAG_UDPLITE) 397 if (!tcf_csum_ipv4_udp(skb, iph->ihl * 4, 398 ntohs(iph->tot_len), 1)) 399 goto fail; 400 break; 401 case IPPROTO_SCTP: 402 if ((update_flags & TCA_CSUM_UPDATE_FLAG_SCTP) && 403 !tcf_csum_sctp(skb, iph->ihl * 4, ntohs(iph->tot_len))) 404 goto fail; 405 break; 406 } 407 408 if (update_flags & TCA_CSUM_UPDATE_FLAG_IPV4HDR) { 409 if (skb_try_make_writable(skb, sizeof(*iph) + ntkoff)) 410 goto fail; 411 412 ip_send_check(ip_hdr(skb)); 413 } 414 415 return 1; 416 417 fail: 418 return 0; 419 } 420 421 static int tcf_csum_ipv6_hopopts(struct ipv6_opt_hdr *ip6xh, unsigned int ixhl, 422 unsigned int *pl) 423 { 424 int off, len, optlen; 425 unsigned char *xh = (void *)ip6xh; 426 427 off = sizeof(*ip6xh); 428 len = ixhl - off; 429 430 while (len > 1) { 431 switch (xh[off]) { 432 case IPV6_TLV_PAD1: 433 optlen = 1; 434 break; 435 case IPV6_TLV_JUMBO: 436 optlen = xh[off + 1] + 2; 437 if (optlen != 6 || len < 6 || (off & 3) != 2) 438 /* wrong jumbo option length/alignment */ 439 return 0; 440 *pl = ntohl(*(__be32 *)(xh + off + 2)); 441 goto done; 442 default: 443 optlen = xh[off + 1] + 2; 444 if (optlen > len) 445 /* ignore obscure options */ 446 goto done; 447 break; 448 } 449 off += optlen; 450 len -= optlen; 451 } 452 453 done: 454 return 1; 455 } 456 457 static int tcf_csum_ipv6(struct sk_buff *skb, u32 update_flags) 458 { 459 struct ipv6hdr *ip6h; 460 struct ipv6_opt_hdr *ip6xh; 461 unsigned int hl, ixhl; 462 unsigned int pl; 463 int ntkoff; 464 u8 nexthdr; 465 466 ntkoff = skb_network_offset(skb); 467 468 hl = sizeof(*ip6h); 469 470 if (!pskb_may_pull(skb, hl + ntkoff)) 471 goto fail; 472 473 ip6h = ipv6_hdr(skb); 474 475 pl = ntohs(ip6h->payload_len); 476 nexthdr = ip6h->nexthdr; 477 478 do { 479 switch (nexthdr) { 480 case NEXTHDR_FRAGMENT: 481 goto ignore_skb; 482 case NEXTHDR_ROUTING: 483 case NEXTHDR_HOP: 484 case NEXTHDR_DEST: 485 if (!pskb_may_pull(skb, hl + sizeof(*ip6xh) + ntkoff)) 486 goto fail; 487 ip6xh = (void *)(skb_network_header(skb) + hl); 488 ixhl = ipv6_optlen(ip6xh); 489 if (!pskb_may_pull(skb, hl + ixhl + ntkoff)) 490 goto fail; 491 ip6xh = (void *)(skb_network_header(skb) + hl); 492 if ((nexthdr == NEXTHDR_HOP) && 493 !(tcf_csum_ipv6_hopopts(ip6xh, ixhl, &pl))) 494 goto fail; 495 nexthdr = ip6xh->nexthdr; 496 hl += ixhl; 497 break; 498 case IPPROTO_ICMPV6: 499 if (update_flags & TCA_CSUM_UPDATE_FLAG_ICMP) 500 if (!tcf_csum_ipv6_icmp(skb, 501 hl, pl + sizeof(*ip6h))) 502 goto fail; 503 goto done; 504 case IPPROTO_TCP: 505 if (update_flags & TCA_CSUM_UPDATE_FLAG_TCP) 506 if (!tcf_csum_ipv6_tcp(skb, 507 hl, pl + sizeof(*ip6h))) 508 goto fail; 509 goto done; 510 case IPPROTO_UDP: 511 if (update_flags & TCA_CSUM_UPDATE_FLAG_UDP) 512 if (!tcf_csum_ipv6_udp(skb, hl, 513 pl + sizeof(*ip6h), 0)) 514 goto fail; 515 goto done; 516 case IPPROTO_UDPLITE: 517 if (update_flags & TCA_CSUM_UPDATE_FLAG_UDPLITE) 518 if (!tcf_csum_ipv6_udp(skb, hl, 519 pl + sizeof(*ip6h), 1)) 520 goto fail; 521 goto done; 522 case IPPROTO_SCTP: 523 if ((update_flags & TCA_CSUM_UPDATE_FLAG_SCTP) && 524 !tcf_csum_sctp(skb, hl, pl + sizeof(*ip6h))) 525 goto fail; 526 goto done; 527 default: 528 goto ignore_skb; 529 } 530 } while (pskb_may_pull(skb, hl + 1 + ntkoff)); 531 532 done: 533 ignore_skb: 534 return 1; 535 536 fail: 537 return 0; 538 } 539 540 static int tcf_csum(struct sk_buff *skb, const struct tc_action *a, 541 struct tcf_result *res) 542 { 543 struct tcf_csum *p = to_tcf_csum(a); 544 int action; 545 u32 update_flags; 546 547 spin_lock(&p->tcf_lock); 548 tcf_lastuse_update(&p->tcf_tm); 549 bstats_update(&p->tcf_bstats, skb); 550 action = p->tcf_action; 551 update_flags = p->update_flags; 552 spin_unlock(&p->tcf_lock); 553 554 if (unlikely(action == TC_ACT_SHOT)) 555 goto drop; 556 557 switch (tc_skb_protocol(skb)) { 558 case cpu_to_be16(ETH_P_IP): 559 if (!tcf_csum_ipv4(skb, update_flags)) 560 goto drop; 561 break; 562 case cpu_to_be16(ETH_P_IPV6): 563 if (!tcf_csum_ipv6(skb, update_flags)) 564 goto drop; 565 break; 566 } 567 568 return action; 569 570 drop: 571 spin_lock(&p->tcf_lock); 572 p->tcf_qstats.drops++; 573 spin_unlock(&p->tcf_lock); 574 return TC_ACT_SHOT; 575 } 576 577 static int tcf_csum_dump(struct sk_buff *skb, struct tc_action *a, int bind, 578 int ref) 579 { 580 unsigned char *b = skb_tail_pointer(skb); 581 struct tcf_csum *p = to_tcf_csum(a); 582 struct tc_csum opt = { 583 .update_flags = p->update_flags, 584 .index = p->tcf_index, 585 .action = p->tcf_action, 586 .refcnt = p->tcf_refcnt - ref, 587 .bindcnt = p->tcf_bindcnt - bind, 588 }; 589 struct tcf_t t; 590 591 if (nla_put(skb, TCA_CSUM_PARMS, sizeof(opt), &opt)) 592 goto nla_put_failure; 593 594 tcf_tm_dump(&t, &p->tcf_tm); 595 if (nla_put_64bit(skb, TCA_CSUM_TM, sizeof(t), &t, TCA_CSUM_PAD)) 596 goto nla_put_failure; 597 598 return skb->len; 599 600 nla_put_failure: 601 nlmsg_trim(skb, b); 602 return -1; 603 } 604 605 static int tcf_csum_walker(struct net *net, struct sk_buff *skb, 606 struct netlink_callback *cb, int type, 607 const struct tc_action_ops *ops) 608 { 609 struct tc_action_net *tn = net_generic(net, csum_net_id); 610 611 return tcf_generic_walker(tn, skb, cb, type, ops); 612 } 613 614 static int tcf_csum_search(struct net *net, struct tc_action **a, u32 index) 615 { 616 struct tc_action_net *tn = net_generic(net, csum_net_id); 617 618 return tcf_hash_search(tn, a, index); 619 } 620 621 static struct tc_action_ops act_csum_ops = { 622 .kind = "csum", 623 .type = TCA_ACT_CSUM, 624 .owner = THIS_MODULE, 625 .act = tcf_csum, 626 .dump = tcf_csum_dump, 627 .init = tcf_csum_init, 628 .walk = tcf_csum_walker, 629 .lookup = tcf_csum_search, 630 .size = sizeof(struct tcf_csum), 631 }; 632 633 static __net_init int csum_init_net(struct net *net) 634 { 635 struct tc_action_net *tn = net_generic(net, csum_net_id); 636 637 return tc_action_net_init(tn, &act_csum_ops, CSUM_TAB_MASK); 638 } 639 640 static void __net_exit csum_exit_net(struct net *net) 641 { 642 struct tc_action_net *tn = net_generic(net, csum_net_id); 643 644 tc_action_net_exit(tn); 645 } 646 647 static struct pernet_operations csum_net_ops = { 648 .init = csum_init_net, 649 .exit = csum_exit_net, 650 .id = &csum_net_id, 651 .size = sizeof(struct tc_action_net), 652 }; 653 654 MODULE_DESCRIPTION("Checksum updating actions"); 655 MODULE_LICENSE("GPL"); 656 657 static int __init csum_init_module(void) 658 { 659 return tcf_register_action(&act_csum_ops, &csum_net_ops); 660 } 661 662 static void __exit csum_cleanup_module(void) 663 { 664 tcf_unregister_action(&act_csum_ops, &csum_net_ops); 665 } 666 667 module_init(csum_init_module); 668 module_exit(csum_cleanup_module); 669