xref: /openbmc/linux/net/sched/act_nat.c (revision 9a87ffc99ec8eb8d35eed7c4f816d75f5cc9662e)
1  // SPDX-License-Identifier: GPL-2.0-or-later
2  /*
3   * Stateless NAT actions
4   *
5   * Copyright (c) 2007 Herbert Xu <herbert@gondor.apana.org.au>
6   */
7  
8  #include <linux/errno.h>
9  #include <linux/init.h>
10  #include <linux/kernel.h>
11  #include <linux/module.h>
12  #include <linux/netfilter.h>
13  #include <linux/rtnetlink.h>
14  #include <linux/skbuff.h>
15  #include <linux/slab.h>
16  #include <linux/spinlock.h>
17  #include <linux/string.h>
18  #include <linux/tc_act/tc_nat.h>
19  #include <net/act_api.h>
20  #include <net/pkt_cls.h>
21  #include <net/icmp.h>
22  #include <net/ip.h>
23  #include <net/netlink.h>
24  #include <net/tc_act/tc_nat.h>
25  #include <net/tcp.h>
26  #include <net/udp.h>
27  #include <net/tc_wrapper.h>
28  
29  static struct tc_action_ops act_nat_ops;
30  
31  static const struct nla_policy nat_policy[TCA_NAT_MAX + 1] = {
32  	[TCA_NAT_PARMS]	= { .len = sizeof(struct tc_nat) },
33  };
34  
/*
 * tcf_nat_init - create or update a stateless NAT action instance
 * @net:    network namespace the action belongs to
 * @nla:    nested TCA_NAT_* attributes from userspace
 * @est:    optional rate estimator configuration
 * @a:      in/out action handle
 * @tp:     classifier the action is attached to (for goto_chain validation)
 * @flags:  TCA_ACT_FLAGS_* (BIND/REPLACE/...)
 * @extack: extended ack for error reporting
 *
 * Parses TCA_NAT_PARMS, allocates (or looks up) the action via the IDR,
 * then installs a freshly allocated tcf_nat_parms under RCU so the
 * datapath never sees a half-updated parameter set.
 *
 * Returns ACT_P_CREATED for a new action, 0 when an existing action was
 * bound or replaced, or a negative errno on failure.
 */
static int tcf_nat_init(struct net *net, struct nlattr *nla, struct nlattr *est,
			struct tc_action **a, struct tcf_proto *tp,
			u32 flags, struct netlink_ext_ack *extack)
{
	struct tc_action_net *tn = net_generic(net, act_nat_ops.net_id);
	bool bind = flags & TCA_ACT_FLAGS_BIND;
	struct tcf_nat_parms *nparm, *oparm;
	struct nlattr *tb[TCA_NAT_MAX + 1];
	struct tcf_chain *goto_ch = NULL;
	struct tc_nat *parm;
	int ret = 0, err;
	struct tcf_nat *p;
	u32 index;

	if (nla == NULL)
		return -EINVAL;

	err = nla_parse_nested_deprecated(tb, TCA_NAT_MAX, nla, nat_policy,
					  NULL);
	if (err < 0)
		return err;

	if (tb[TCA_NAT_PARMS] == NULL)
		return -EINVAL;
	parm = nla_data(tb[TCA_NAT_PARMS]);
	index = parm->index;
	err = tcf_idr_check_alloc(tn, &index, a, bind);
	if (!err) {
		/* Index was free: create a brand-new action. */
		ret = tcf_idr_create_from_flags(tn, index, est, a, &act_nat_ops,
						bind, flags);
		if (ret) {
			tcf_idr_cleanup(tn, index);
			return ret;
		}
		ret = ACT_P_CREATED;
	} else if (err > 0) {
		/* Action already exists: bind to it, or replace if allowed. */
		if (bind)
			return 0;
		if (!(flags & TCA_ACT_FLAGS_REPLACE)) {
			tcf_idr_release(*a, bind);
			return -EEXIST;
		}
	} else {
		return err;
	}
	err = tcf_action_check_ctrlact(parm->action, tp, &goto_ch, extack);
	if (err < 0)
		goto release_idr;

	nparm = kzalloc(sizeof(*nparm), GFP_KERNEL);
	if (!nparm) {
		err = -ENOMEM;
		goto release_idr;
	}

	nparm->old_addr = parm->old_addr;
	nparm->new_addr = parm->new_addr;
	nparm->mask = parm->mask;
	nparm->flags = parm->flags;

	p = to_tcf_nat(*a);

	/* Swap in the new parameters atomically w.r.t. readers. */
	spin_lock_bh(&p->tcf_lock);
	goto_ch = tcf_action_set_ctrlact(*a, parm->action, goto_ch);
	oparm = rcu_replace_pointer(p->parms, nparm, lockdep_is_held(&p->tcf_lock));
	spin_unlock_bh(&p->tcf_lock);

	if (goto_ch)
		tcf_chain_put_by_act(goto_ch);

	/* Old parameters may still be in use by RCU readers; defer the free. */
	if (oparm)
		kfree_rcu(oparm, rcu);

	return ret;
release_idr:
	tcf_idr_release(*a, bind);
	return err;
}
113  
tcf_nat_act(struct sk_buff * skb,const struct tc_action * a,struct tcf_result * res)114  TC_INDIRECT_SCOPE int tcf_nat_act(struct sk_buff *skb,
115  				  const struct tc_action *a,
116  				  struct tcf_result *res)
117  {
118  	struct tcf_nat *p = to_tcf_nat(a);
119  	struct tcf_nat_parms *parms;
120  	struct iphdr *iph;
121  	__be32 old_addr;
122  	__be32 new_addr;
123  	__be32 mask;
124  	__be32 addr;
125  	int egress;
126  	int action;
127  	int ihl;
128  	int noff;
129  
130  	tcf_lastuse_update(&p->tcf_tm);
131  	tcf_action_update_bstats(&p->common, skb);
132  
133  	action = READ_ONCE(p->tcf_action);
134  
135  	parms = rcu_dereference_bh(p->parms);
136  	old_addr = parms->old_addr;
137  	new_addr = parms->new_addr;
138  	mask = parms->mask;
139  	egress = parms->flags & TCA_NAT_FLAG_EGRESS;
140  
141  	if (unlikely(action == TC_ACT_SHOT))
142  		goto drop;
143  
144  	noff = skb_network_offset(skb);
145  	if (!pskb_may_pull(skb, sizeof(*iph) + noff))
146  		goto drop;
147  
148  	iph = ip_hdr(skb);
149  
150  	if (egress)
151  		addr = iph->saddr;
152  	else
153  		addr = iph->daddr;
154  
155  	if (!((old_addr ^ addr) & mask)) {
156  		if (skb_try_make_writable(skb, sizeof(*iph) + noff))
157  			goto drop;
158  
159  		new_addr &= mask;
160  		new_addr |= addr & ~mask;
161  
162  		/* Rewrite IP header */
163  		iph = ip_hdr(skb);
164  		if (egress)
165  			iph->saddr = new_addr;
166  		else
167  			iph->daddr = new_addr;
168  
169  		csum_replace4(&iph->check, addr, new_addr);
170  	} else if ((iph->frag_off & htons(IP_OFFSET)) ||
171  		   iph->protocol != IPPROTO_ICMP) {
172  		goto out;
173  	}
174  
175  	ihl = iph->ihl * 4;
176  
177  	/* It would be nice to share code with stateful NAT. */
178  	switch (iph->frag_off & htons(IP_OFFSET) ? 0 : iph->protocol) {
179  	case IPPROTO_TCP:
180  	{
181  		struct tcphdr *tcph;
182  
183  		if (!pskb_may_pull(skb, ihl + sizeof(*tcph) + noff) ||
184  		    skb_try_make_writable(skb, ihl + sizeof(*tcph) + noff))
185  			goto drop;
186  
187  		tcph = (void *)(skb_network_header(skb) + ihl);
188  		inet_proto_csum_replace4(&tcph->check, skb, addr, new_addr,
189  					 true);
190  		break;
191  	}
192  	case IPPROTO_UDP:
193  	{
194  		struct udphdr *udph;
195  
196  		if (!pskb_may_pull(skb, ihl + sizeof(*udph) + noff) ||
197  		    skb_try_make_writable(skb, ihl + sizeof(*udph) + noff))
198  			goto drop;
199  
200  		udph = (void *)(skb_network_header(skb) + ihl);
201  		if (udph->check || skb->ip_summed == CHECKSUM_PARTIAL) {
202  			inet_proto_csum_replace4(&udph->check, skb, addr,
203  						 new_addr, true);
204  			if (!udph->check)
205  				udph->check = CSUM_MANGLED_0;
206  		}
207  		break;
208  	}
209  	case IPPROTO_ICMP:
210  	{
211  		struct icmphdr *icmph;
212  
213  		if (!pskb_may_pull(skb, ihl + sizeof(*icmph) + noff))
214  			goto drop;
215  
216  		icmph = (void *)(skb_network_header(skb) + ihl);
217  
218  		if (!icmp_is_err(icmph->type))
219  			break;
220  
221  		if (!pskb_may_pull(skb, ihl + sizeof(*icmph) + sizeof(*iph) +
222  					noff))
223  			goto drop;
224  
225  		icmph = (void *)(skb_network_header(skb) + ihl);
226  		iph = (void *)(icmph + 1);
227  		if (egress)
228  			addr = iph->daddr;
229  		else
230  			addr = iph->saddr;
231  
232  		if ((old_addr ^ addr) & mask)
233  			break;
234  
235  		if (skb_try_make_writable(skb, ihl + sizeof(*icmph) +
236  					  sizeof(*iph) + noff))
237  			goto drop;
238  
239  		icmph = (void *)(skb_network_header(skb) + ihl);
240  		iph = (void *)(icmph + 1);
241  
242  		new_addr &= mask;
243  		new_addr |= addr & ~mask;
244  
245  		/* XXX Fix up the inner checksums. */
246  		if (egress)
247  			iph->daddr = new_addr;
248  		else
249  			iph->saddr = new_addr;
250  
251  		inet_proto_csum_replace4(&icmph->checksum, skb, addr, new_addr,
252  					 false);
253  		break;
254  	}
255  	default:
256  		break;
257  	}
258  
259  out:
260  	return action;
261  
262  drop:
263  	tcf_action_inc_drop_qstats(&p->common);
264  	return TC_ACT_SHOT;
265  }
266  
tcf_nat_dump(struct sk_buff * skb,struct tc_action * a,int bind,int ref)267  static int tcf_nat_dump(struct sk_buff *skb, struct tc_action *a,
268  			int bind, int ref)
269  {
270  	unsigned char *b = skb_tail_pointer(skb);
271  	struct tcf_nat *p = to_tcf_nat(a);
272  	struct tc_nat opt = {
273  		.index    = p->tcf_index,
274  		.refcnt   = refcount_read(&p->tcf_refcnt) - ref,
275  		.bindcnt  = atomic_read(&p->tcf_bindcnt) - bind,
276  	};
277  	struct tcf_nat_parms *parms;
278  	struct tcf_t t;
279  
280  	spin_lock_bh(&p->tcf_lock);
281  
282  	opt.action = p->tcf_action;
283  
284  	parms = rcu_dereference_protected(p->parms, lockdep_is_held(&p->tcf_lock));
285  
286  	opt.old_addr = parms->old_addr;
287  	opt.new_addr = parms->new_addr;
288  	opt.mask = parms->mask;
289  	opt.flags = parms->flags;
290  
291  	if (nla_put(skb, TCA_NAT_PARMS, sizeof(opt), &opt))
292  		goto nla_put_failure;
293  
294  	tcf_tm_dump(&t, &p->tcf_tm);
295  	if (nla_put_64bit(skb, TCA_NAT_TM, sizeof(t), &t, TCA_NAT_PAD))
296  		goto nla_put_failure;
297  	spin_unlock_bh(&p->tcf_lock);
298  
299  	return skb->len;
300  
301  nla_put_failure:
302  	spin_unlock_bh(&p->tcf_lock);
303  	nlmsg_trim(skb, b);
304  	return -1;
305  }
306  
tcf_nat_cleanup(struct tc_action * a)307  static void tcf_nat_cleanup(struct tc_action *a)
308  {
309  	struct tcf_nat *p = to_tcf_nat(a);
310  	struct tcf_nat_parms *parms;
311  
312  	parms = rcu_dereference_protected(p->parms, 1);
313  	if (parms)
314  		kfree_rcu(parms, rcu);
315  }
316  
317  static struct tc_action_ops act_nat_ops = {
318  	.kind		=	"nat",
319  	.id		=	TCA_ID_NAT,
320  	.owner		=	THIS_MODULE,
321  	.act		=	tcf_nat_act,
322  	.dump		=	tcf_nat_dump,
323  	.init		=	tcf_nat_init,
324  	.cleanup	=	tcf_nat_cleanup,
325  	.size		=	sizeof(struct tcf_nat),
326  };
327  
nat_init_net(struct net * net)328  static __net_init int nat_init_net(struct net *net)
329  {
330  	struct tc_action_net *tn = net_generic(net, act_nat_ops.net_id);
331  
332  	return tc_action_net_init(net, tn, &act_nat_ops);
333  }
334  
nat_exit_net(struct list_head * net_list)335  static void __net_exit nat_exit_net(struct list_head *net_list)
336  {
337  	tc_action_net_exit(net_list, act_nat_ops.net_id);
338  }
339  
340  static struct pernet_operations nat_net_ops = {
341  	.init = nat_init_net,
342  	.exit_batch = nat_exit_net,
343  	.id   = &act_nat_ops.net_id,
344  	.size = sizeof(struct tc_action_net),
345  };
346  
347  MODULE_DESCRIPTION("Stateless NAT actions");
348  MODULE_LICENSE("GPL");
349  
nat_init_module(void)350  static int __init nat_init_module(void)
351  {
352  	return tcf_register_action(&act_nat_ops, &nat_net_ops);
353  }
354  
nat_cleanup_module(void)355  static void __exit nat_cleanup_module(void)
356  {
357  	tcf_unregister_action(&act_nat_ops, &nat_net_ops);
358  }
359  
360  module_init(nat_init_module);
361  module_exit(nat_cleanup_module);
362