xref: /openbmc/linux/net/sched/sch_frag.c (revision de8c12110a130337c8e7e7b8250de0580e644dee)
1 // SPDX-License-Identifier: GPL-2.0 OR Linux-OpenIB
2 #include <net/netlink.h>
3 #include <net/sch_generic.h>
4 #include <net/dst.h>
5 #include <net/ip.h>
6 #include <net/ip6_fib.h>
7 
8 struct sch_frag_data {
9 	unsigned long dst;
10 	struct qdisc_skb_cb cb;
11 	__be16 inner_protocol;
12 	u16 vlan_tci;
13 	__be16 vlan_proto;
14 	unsigned int l2_len;
15 	u8 l2_data[VLAN_ETH_HLEN];
16 	int (*xmit)(struct sk_buff *skb);
17 };
18 
19 static DEFINE_PER_CPU(struct sch_frag_data, sch_frag_data_storage);
20 
21 static int sch_frag_xmit(struct net *net, struct sock *sk, struct sk_buff *skb)
22 {
23 	struct sch_frag_data *data = this_cpu_ptr(&sch_frag_data_storage);
24 
25 	if (skb_cow_head(skb, data->l2_len) < 0) {
26 		kfree_skb(skb);
27 		return -ENOMEM;
28 	}
29 
30 	__skb_dst_copy(skb, data->dst);
31 	*qdisc_skb_cb(skb) = data->cb;
32 	skb->inner_protocol = data->inner_protocol;
33 	if (data->vlan_tci & VLAN_CFI_MASK)
34 		__vlan_hwaccel_put_tag(skb, data->vlan_proto,
35 				       data->vlan_tci & ~VLAN_CFI_MASK);
36 	else
37 		__vlan_hwaccel_clear_tag(skb);
38 
39 	/* Reconstruct the MAC header.  */
40 	skb_push(skb, data->l2_len);
41 	memcpy(skb->data, &data->l2_data, data->l2_len);
42 	skb_postpush_rcsum(skb, skb->data, data->l2_len);
43 	skb_reset_mac_header(skb);
44 
45 	return data->xmit(skb);
46 }
47 
48 static void sch_frag_prepare_frag(struct sk_buff *skb,
49 				  int (*xmit)(struct sk_buff *skb))
50 {
51 	unsigned int hlen = skb_network_offset(skb);
52 	struct sch_frag_data *data;
53 
54 	data = this_cpu_ptr(&sch_frag_data_storage);
55 	data->dst = skb->_skb_refdst;
56 	data->cb = *qdisc_skb_cb(skb);
57 	data->xmit = xmit;
58 	data->inner_protocol = skb->inner_protocol;
59 	if (skb_vlan_tag_present(skb))
60 		data->vlan_tci = skb_vlan_tag_get(skb) | VLAN_CFI_MASK;
61 	else
62 		data->vlan_tci = 0;
63 	data->vlan_proto = skb->vlan_proto;
64 	data->l2_len = hlen;
65 	memcpy(&data->l2_data, skb->data, hlen);
66 
67 	memset(IPCB(skb), 0, sizeof(struct inet_skb_parm));
68 	skb_pull(skb, hlen);
69 }
70 
71 static unsigned int
72 sch_frag_dst_get_mtu(const struct dst_entry *dst)
73 {
74 	return dst->dev->mtu;
75 }
76 
77 static struct dst_ops sch_frag_dst_ops = {
78 	.family = AF_UNSPEC,
79 	.mtu = sch_frag_dst_get_mtu,
80 };
81 
82 static int sch_fragment(struct net *net, struct sk_buff *skb,
83 			u16 mru, int (*xmit)(struct sk_buff *skb))
84 {
85 	int ret = -1;
86 
87 	if (skb_network_offset(skb) > VLAN_ETH_HLEN) {
88 		net_warn_ratelimited("L2 header too long to fragment\n");
89 		goto err;
90 	}
91 
92 	if (skb_protocol(skb, true) == htons(ETH_P_IP)) {
93 		struct dst_entry sch_frag_dst;
94 		unsigned long orig_dst;
95 
96 		sch_frag_prepare_frag(skb, xmit);
97 		dst_init(&sch_frag_dst, &sch_frag_dst_ops, NULL, 1,
98 			 DST_OBSOLETE_NONE, DST_NOCOUNT);
99 		sch_frag_dst.dev = skb->dev;
100 
101 		orig_dst = skb->_skb_refdst;
102 		skb_dst_set_noref(skb, &sch_frag_dst);
103 		IPCB(skb)->frag_max_size = mru;
104 
105 		ret = ip_do_fragment(net, skb->sk, skb, sch_frag_xmit);
106 		refdst_drop(orig_dst);
107 	} else if (skb_protocol(skb, true) == htons(ETH_P_IPV6)) {
108 		unsigned long orig_dst;
109 		struct rt6_info sch_frag_rt;
110 
111 		sch_frag_prepare_frag(skb, xmit);
112 		memset(&sch_frag_rt, 0, sizeof(sch_frag_rt));
113 		dst_init(&sch_frag_rt.dst, &sch_frag_dst_ops, NULL, 1,
114 			 DST_OBSOLETE_NONE, DST_NOCOUNT);
115 		sch_frag_rt.dst.dev = skb->dev;
116 
117 		orig_dst = skb->_skb_refdst;
118 		skb_dst_set_noref(skb, &sch_frag_rt.dst);
119 		IP6CB(skb)->frag_max_size = mru;
120 
121 		ret = ipv6_stub->ipv6_fragment(net, skb->sk, skb,
122 					       sch_frag_xmit);
123 		refdst_drop(orig_dst);
124 	} else {
125 		net_warn_ratelimited("Fail frag %s: eth=%x, MRU=%d, MTU=%d\n",
126 				     netdev_name(skb->dev),
127 				     ntohs(skb_protocol(skb, true)), mru,
128 				     skb->dev->mtu);
129 		goto err;
130 	}
131 
132 	return ret;
133 err:
134 	kfree_skb(skb);
135 	return ret;
136 }
137 
138 int sch_frag_xmit_hook(struct sk_buff *skb, int (*xmit)(struct sk_buff *skb))
139 {
140 	u16 mru = qdisc_skb_cb(skb)->mru;
141 	int err;
142 
143 	if (mru && skb->len > mru + skb->dev->hard_header_len)
144 		err = sch_fragment(dev_net(skb->dev), skb, mru, xmit);
145 	else
146 		err = xmit(skb);
147 
148 	return err;
149 }
150 EXPORT_SYMBOL_GPL(sch_frag_xmit_hook);
151