xref: /openbmc/linux/net/mpls/af_mpls.c (revision b96fc2f3)
1 #include <linux/types.h>
2 #include <linux/skbuff.h>
3 #include <linux/socket.h>
4 #include <linux/sysctl.h>
5 #include <linux/net.h>
6 #include <linux/module.h>
7 #include <linux/if_arp.h>
8 #include <linux/ipv6.h>
9 #include <linux/mpls.h>
10 #include <linux/vmalloc.h>
11 #include <net/ip.h>
12 #include <net/dst.h>
13 #include <net/sock.h>
14 #include <net/arp.h>
15 #include <net/ip_fib.h>
16 #include <net/netevent.h>
17 #include <net/netns/generic.h>
18 #if IS_ENABLED(CONFIG_IPV6)
19 #include <net/ipv6.h>
20 #include <net/addrconf.h>
21 #endif
22 #include "internal.h"
23 
/* Sentinel meaning "no incoming label was supplied".  1 << 20 is one past
 * the largest value a 20-bit MPLS label can hold, so it never collides
 * with a real label.
 */
#define LABEL_NOT_SPECIFIED (1<<20)
/* Maximum number of labels a single route may push (size of rt_label[]). */
#define MAX_NEW_LABELS 2

/* This maximum ha length copied from the definition of struct neighbour */
#define MAX_VIA_ALEN (ALIGN(MAX_ADDR_LEN, sizeof(unsigned long)))
29 
/* What a route delivers once its last label has been popped.  The IPv4 and
 * IPv6 values equal the IP header version number, which lets mpls_egress()
 * substitute ip_hdr(skb)->version when a route leaves this MPT_UNSPEC.
 */
enum mpls_payload_type {
	MPT_UNSPEC, /* IPv4 or IPv6 */
	MPT_IPV4 = 4,
	MPT_IPV6 = 6,

	/* Other types not implemented:
	 *  - Pseudo-wire with or without control word (RFC4385)
	 *  - GAL (RFC5586)
	 */
};
40 
41 struct mpls_route { /* next hop label forwarding entry */
42 	struct net_device __rcu *rt_dev;
43 	struct rcu_head		rt_rcu;
44 	u32			rt_label[MAX_NEW_LABELS];
45 	u8			rt_protocol; /* routing protocol that set this entry */
46 	u8                      rt_payload_type;
47 	u8			rt_labels;
48 	u8			rt_via_alen;
49 	u8			rt_via_table;
50 	u8			rt_via[0];
51 };
52 
/* Bounds passed to proc_dointvec_minmax() by mpls_platform_labels(): the
 * label table may be sized anywhere from 0 to (1 << 20) - 1 entries.
 */
static int zero = 0;
static int label_limit = (1 << 20) - 1;

/* Forward declaration: builds and multicasts a route notification; the
 * definition appears further down in this file.
 */
static void rtmsg_lfib(int event, u32 label, struct mpls_route *rt,
		       struct nlmsghdr *nlh, struct net *net, u32 portid,
		       unsigned int nlm_flags);
59 
60 static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned index)
61 {
62 	struct mpls_route *rt = NULL;
63 
64 	if (index < net->mpls.platform_labels) {
65 		struct mpls_route __rcu **platform_label =
66 			rcu_dereference(net->mpls.platform_label);
67 		rt = rcu_dereference(platform_label[index]);
68 	}
69 	return rt;
70 }
71 
72 static inline struct mpls_dev *mpls_dev_get(const struct net_device *dev)
73 {
74 	return rcu_dereference_rtnl(dev->mpls_ptr);
75 }
76 
77 bool mpls_output_possible(const struct net_device *dev)
78 {
79 	return dev && (dev->flags & IFF_UP) && netif_carrier_ok(dev);
80 }
81 EXPORT_SYMBOL_GPL(mpls_output_possible);
82 
83 static unsigned int mpls_rt_header_size(const struct mpls_route *rt)
84 {
85 	/* The size of the layer 2.5 labels to be added for this route */
86 	return rt->rt_labels * sizeof(struct mpls_shim_hdr);
87 }
88 
89 unsigned int mpls_dev_mtu(const struct net_device *dev)
90 {
91 	/* The amount of data the layer 2 frame can hold */
92 	return dev->mtu;
93 }
94 EXPORT_SYMBOL_GPL(mpls_dev_mtu);
95 
96 bool mpls_pkt_too_big(const struct sk_buff *skb, unsigned int mtu)
97 {
98 	if (skb->len <= mtu)
99 		return false;
100 
101 	if (skb_is_gso(skb) && skb_gso_network_seglen(skb) <= mtu)
102 		return false;
103 
104 	return true;
105 }
106 EXPORT_SYMBOL_GPL(mpls_pkt_too_big);
107 
108 static bool mpls_egress(struct mpls_route *rt, struct sk_buff *skb,
109 			struct mpls_entry_decoded dec)
110 {
111 	enum mpls_payload_type payload_type;
112 	bool success = false;
113 
114 	/* The IPv4 code below accesses through the IPv4 header
115 	 * checksum, which is 12 bytes into the packet.
116 	 * The IPv6 code below accesses through the IPv6 hop limit
117 	 * which is 8 bytes into the packet.
118 	 *
119 	 * For all supported cases there should always be at least 12
120 	 * bytes of packet data present.  The IPv4 header is 20 bytes
121 	 * without options and the IPv6 header is always 40 bytes
122 	 * long.
123 	 */
124 	if (!pskb_may_pull(skb, 12))
125 		return false;
126 
127 	payload_type = rt->rt_payload_type;
128 	if (payload_type == MPT_UNSPEC)
129 		payload_type = ip_hdr(skb)->version;
130 
131 	switch (payload_type) {
132 	case MPT_IPV4: {
133 		struct iphdr *hdr4 = ip_hdr(skb);
134 		skb->protocol = htons(ETH_P_IP);
135 		csum_replace2(&hdr4->check,
136 			      htons(hdr4->ttl << 8),
137 			      htons(dec.ttl << 8));
138 		hdr4->ttl = dec.ttl;
139 		success = true;
140 		break;
141 	}
142 	case MPT_IPV6: {
143 		struct ipv6hdr *hdr6 = ipv6_hdr(skb);
144 		skb->protocol = htons(ETH_P_IPV6);
145 		hdr6->hop_limit = dec.ttl;
146 		success = true;
147 		break;
148 	}
149 	case MPT_UNSPEC:
150 		break;
151 	}
152 
153 	return success;
154 }
155 
/* Packet handler for ETH_P_MPLS_UC frames (hooked up via mpls_packet_type).
 * Pops the top label and looks up its route in the label table.  Then it
 * does one of two things:
 *  - pushes the route's outgoing labels, or
 *  - when the popped label was bottom of stack and the route pushes no
 *    labels, converts the packet back to IPv4/IPv6 with mpls_egress()
 *    (penultimate hop popping).
 * The result is sent to the route's next hop with neigh_xmit().
 *
 * Returns 0 once the packet has been passed to neigh_xmit() (a transmit
 * failure is only logged), or NET_RX_DROP after freeing a dropped packet.
 */
static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
			struct packet_type *pt, struct net_device *orig_dev)
{
	struct net *net = dev_net(dev);
	struct mpls_shim_hdr *hdr;
	struct mpls_route *rt;
	struct mpls_entry_decoded dec;
	struct net_device *out_dev;
	struct mpls_dev *mdev;
	unsigned int hh_len;
	unsigned int new_header_size;
	unsigned int mtu;
	int err;

	/* Careful this entire function runs inside of an rcu critical section */

	/* Only devices that have MPLS state and the per-device "input" sysctl
	 * (net/mpls/conf/<dev>/input) enabled accept labelled traffic.
	 */
	mdev = mpls_dev_get(dev);
	if (!mdev || !mdev->input_enabled)
		goto drop;

	/* Only frames addressed to this host are forwarded */
	if (skb->pkt_type != PACKET_HOST)
		goto drop;

	/* The label stack is rewritten below, so the skb must not be shared */
	if ((skb = skb_share_check(skb, GFP_ATOMIC)) == NULL)
		goto drop;

	if (!pskb_may_pull(skb, sizeof(*hdr)))
		goto drop;

	/* Read and decode the label */
	hdr = mpls_hdr(skb);
	dec = mpls_entry_decode(hdr);

	/* Pop the label */
	skb_pull(skb, sizeof(*hdr));
	skb_reset_network_header(skb);

	skb_orphan(skb);

	/* Labels beyond the table, or with no route installed, are dropped */
	rt = mpls_route_input_rcu(net, dec.label);
	if (!rt)
		goto drop;

	/* Find the output device */
	out_dev = rcu_dereference(rt->rt_dev);
	if (!mpls_output_possible(out_dev))
		goto drop;

	if (skb_warn_if_lro(skb))
		goto drop;

	skb_forward_csum(skb);

	/* Verify ttl is valid */
	if (dec.ttl <= 1)
		goto drop;
	dec.ttl -= 1;

	/* Verify the destination can hold the packet */
	/* NOTE(review): mtu - new_header_size is unsigned arithmetic and would
	 * wrap if the device MTU were smaller than the pushed label stack.
	 */
	new_header_size = mpls_rt_header_size(rt);
	mtu = mpls_dev_mtu(out_dev);
	if (mpls_pkt_too_big(skb, mtu - new_header_size))
		goto drop;

	/* Link-layer headroom is only needed if the device builds headers */
	hh_len = LL_RESERVED_SPACE(out_dev);
	if (!out_dev->header_ops)
		hh_len = 0;

	/* Ensure there is enough space for the headers in the skb */
	if (skb_cow(skb, hh_len + new_header_size))
		goto drop;

	skb->dev = out_dev;
	skb->protocol = htons(ETH_P_MPLS_UC);

	if (unlikely(!new_header_size && dec.bos)) {
		/* Penultimate hop popping */
		/* mpls_egress() overwrites skb->protocol with the IP type */
		if (!mpls_egress(rt, skb, dec))
			goto drop;
	} else {
		bool bos;
		int i;
		skb_push(skb, new_header_size);
		skb_reset_network_header(skb);
		/* Push the new labels */
		/* hdr[0] is the outermost entry.  Only the innermost entry
		 * inherits the popped label's bottom-of-stack bit, and every
		 * entry carries the decremented TTL.
		 */
		hdr = mpls_hdr(skb);
		bos = dec.bos;
		for (i = rt->rt_labels - 1; i >= 0; i--) {
			hdr[i] = mpls_entry_encode(rt->rt_label[i], dec.ttl, 0, bos);
			bos = false;
		}
	}

	err = neigh_xmit(rt->rt_via_table, out_dev, rt->rt_via, skb);
	if (err)
		net_dbg_ratelimited("%s: packet transmission failed: %d\n",
				    __func__, err);
	return 0;

drop:
	kfree_skb(skb);
	return NET_RX_DROP;
}
259 
/* Receive hook for unicast MPLS frames; mpls_init() registers it with
 * dev_add_pack() and mpls_exit() removes it.
 */
static struct packet_type mpls_packet_type __read_mostly = {
	.type = cpu_to_be16(ETH_P_MPLS_UC),
	.func = mpls_forward,
};
264 
/* Netlink attribute policy for AF_MPLS route requests.  Only RTA_DST and
 * RTA_OIF are type-checked here.  RTA_VIA and RTA_NEWDST have no entry and
 * are length-checked by hand in rtm_to_route_config() and
 * nla_get_labels().
 */
static const struct nla_policy rtm_mpls_policy[RTA_MAX+1] = {
	[RTA_DST]		= { .type = NLA_U32 },
	[RTA_OIF]		= { .type = NLA_U32 },
};
269 
/* Parsed form of an RTM_NEWROUTE / RTM_DELROUTE request.  It is filled in
 * by rtm_to_route_config() and consumed by mpls_route_add() and
 * mpls_route_del().
 */
struct mpls_route_config {
	u32			rc_protocol;	/* rtm_protocol of the request */
	u32			rc_ifindex;	/* RTA_OIF; 0 = derive device from rc_via */
	u16			rc_via_table;	/* NEIGH_*_TABLE chosen from the via family */
	u16			rc_via_alen;	/* number of bytes used in rc_via */
	u8			rc_via[MAX_VIA_ALEN];	/* RTA_VIA next-hop address */
	u32			rc_label;	/* RTA_DST label or LABEL_NOT_SPECIFIED */
	u32			rc_output_labels;	/* entries used in rc_output_label[] */
	u32			rc_output_label[MAX_NEW_LABELS];	/* RTA_NEWDST labels */
	u32			rc_nlflags;	/* nlmsg_flags of the request */
	/* rtm_to_route_config() never sets this, so it stays MPT_UNSPEC */
	enum mpls_payload_type	rc_payload_type;
	struct nl_info		rc_nlinfo;	/* requester info for notifications */
};
283 
284 static struct mpls_route *mpls_rt_alloc(size_t alen)
285 {
286 	struct mpls_route *rt;
287 
288 	rt = kzalloc(sizeof(*rt) + alen, GFP_KERNEL);
289 	if (rt)
290 		rt->rt_via_alen = alen;
291 	return rt;
292 }
293 
294 static void mpls_rt_free(struct mpls_route *rt)
295 {
296 	if (rt)
297 		kfree_rcu(rt, rt_rcu);
298 }
299 
300 static void mpls_notify_route(struct net *net, unsigned index,
301 			      struct mpls_route *old, struct mpls_route *new,
302 			      const struct nl_info *info)
303 {
304 	struct nlmsghdr *nlh = info ? info->nlh : NULL;
305 	unsigned portid = info ? info->portid : 0;
306 	int event = new ? RTM_NEWROUTE : RTM_DELROUTE;
307 	struct mpls_route *rt = new ? new : old;
308 	unsigned nlm_flags = (old && new) ? NLM_F_REPLACE : 0;
309 	/* Ignore reserved labels for now */
310 	if (rt && (index >= MPLS_LABEL_FIRST_UNRESERVED))
311 		rtmsg_lfib(event, index, rt, nlh, net, portid, nlm_flags);
312 }
313 
314 static void mpls_route_update(struct net *net, unsigned index,
315 			      struct net_device *dev, struct mpls_route *new,
316 			      const struct nl_info *info)
317 {
318 	struct mpls_route __rcu **platform_label;
319 	struct mpls_route *rt, *old = NULL;
320 
321 	ASSERT_RTNL();
322 
323 	platform_label = rtnl_dereference(net->mpls.platform_label);
324 	rt = rtnl_dereference(platform_label[index]);
325 	if (!dev || (rt && (rtnl_dereference(rt->rt_dev) == dev))) {
326 		rcu_assign_pointer(platform_label[index], new);
327 		old = rt;
328 	}
329 
330 	mpls_notify_route(net, index, old, new, info);
331 
332 	/* If we removed a route free it now */
333 	mpls_rt_free(old);
334 }
335 
336 static unsigned find_free_label(struct net *net)
337 {
338 	struct mpls_route __rcu **platform_label;
339 	size_t platform_labels;
340 	unsigned index;
341 
342 	platform_label = rtnl_dereference(net->mpls.platform_label);
343 	platform_labels = net->mpls.platform_labels;
344 	for (index = MPLS_LABEL_FIRST_UNRESERVED; index < platform_labels;
345 	     index++) {
346 		if (!rtnl_dereference(platform_label[index]))
347 			return index;
348 	}
349 	return LABEL_NOT_SPECIFIED;
350 }
351 
352 #if IS_ENABLED(CONFIG_INET)
353 static struct net_device *inet_fib_lookup_dev(struct net *net, void *addr)
354 {
355 	struct net_device *dev;
356 	struct rtable *rt;
357 	struct in_addr daddr;
358 
359 	memcpy(&daddr, addr, sizeof(struct in_addr));
360 	rt = ip_route_output(net, daddr.s_addr, 0, 0, 0);
361 	if (IS_ERR(rt))
362 		return ERR_CAST(rt);
363 
364 	dev = rt->dst.dev;
365 	dev_hold(dev);
366 
367 	ip_rt_put(rt);
368 
369 	return dev;
370 }
371 #else
372 static struct net_device *inet_fib_lookup_dev(struct net *net, void *addr)
373 {
374 	return ERR_PTR(-EAFNOSUPPORT);
375 }
376 #endif
377 
378 #if IS_ENABLED(CONFIG_IPV6)
379 static struct net_device *inet6_fib_lookup_dev(struct net *net, void *addr)
380 {
381 	struct net_device *dev;
382 	struct dst_entry *dst;
383 	struct flowi6 fl6;
384 	int err;
385 
386 	if (!ipv6_stub)
387 		return ERR_PTR(-EAFNOSUPPORT);
388 
389 	memset(&fl6, 0, sizeof(fl6));
390 	memcpy(&fl6.daddr, addr, sizeof(struct in6_addr));
391 	err = ipv6_stub->ipv6_dst_lookup(net, NULL, &dst, &fl6);
392 	if (err)
393 		return ERR_PTR(err);
394 
395 	dev = dst->dev;
396 	dev_hold(dev);
397 	dst_release(dst);
398 
399 	return dev;
400 }
401 #else
402 static struct net_device *inet6_fib_lookup_dev(struct net *net, void *addr)
403 {
404 	return ERR_PTR(-EAFNOSUPPORT);
405 }
406 #endif
407 
408 static struct net_device *find_outdev(struct net *net,
409 				      struct mpls_route_config *cfg)
410 {
411 	struct net_device *dev = NULL;
412 
413 	if (!cfg->rc_ifindex) {
414 		switch (cfg->rc_via_table) {
415 		case NEIGH_ARP_TABLE:
416 			dev = inet_fib_lookup_dev(net, cfg->rc_via);
417 			break;
418 		case NEIGH_ND_TABLE:
419 			dev = inet6_fib_lookup_dev(net, cfg->rc_via);
420 			break;
421 		case NEIGH_LINK_TABLE:
422 			break;
423 		}
424 	} else {
425 		dev = dev_get_by_index(net, cfg->rc_ifindex);
426 	}
427 
428 	if (!dev)
429 		return ERR_PTR(-ENODEV);
430 
431 	return dev;
432 }
433 
/* Install or replace the route described by @cfg (RTM_NEWROUTE).  The
 * netlink flags are honoured as follows:
 *  - NLM_F_CREATE may pick a free label when none was given;
 *  - NLM_F_EXCL, or a missing NLM_F_REPLACE, refuses to overwrite an
 *    existing route;
 *  - a missing NLM_F_CREATE refuses to add a new one;
 *  - NLM_F_APPEND is rejected.
 *
 * Returns 0 or a negative errno:
 *  - -EINVAL for a reserved or out-of-range label, too many output labels,
 *    a device without MPLS state, or a link-layer via whose length differs
 *    from the device address length;
 *  - -EOPNOTSUPP, -EEXIST, -ENOENT or -ENOMEM for the cases above;
 *  - the error returned by find_outdev().
 */
static int mpls_route_add(struct mpls_route_config *cfg)
{
	struct mpls_route __rcu **platform_label;
	struct net *net = cfg->rc_nlinfo.nl_net;
	struct net_device *dev = NULL;
	struct mpls_route *rt, *old;
	unsigned index;
	int i;
	int err = -EINVAL;

	index = cfg->rc_label;

	/* If a label was not specified during insert pick one */
	if ((index == LABEL_NOT_SPECIFIED) &&
	    (cfg->rc_nlflags & NLM_F_CREATE)) {
		index = find_free_label(net);
	}

	/* Reserved labels may not be set */
	if (index < MPLS_LABEL_FIRST_UNRESERVED)
		goto errout;

	/* The full 20 bit range may not be supported. */
	/* This also rejects LABEL_NOT_SPECIFIED, which is never < the limit */
	if (index >= net->mpls.platform_labels)
		goto errout;

	/* Ensure only a supported number of labels are present */
	if (cfg->rc_output_labels > MAX_NEW_LABELS)
		goto errout;

	/* find_outdev() returns a referenced device; every exit below drops
	 * that reference with dev_put().
	 */
	dev = find_outdev(net, cfg);
	if (IS_ERR(dev)) {
		err = PTR_ERR(dev);
		dev = NULL;
		goto errout;
	}

	/* Ensure this is a supported device */
	err = -EINVAL;
	if (!mpls_dev_get(dev))
		goto errout;

	/* A link-layer via must be exactly one device address long */
	err = -EINVAL;
	if ((cfg->rc_via_table == NEIGH_LINK_TABLE) &&
	    (dev->addr_len != cfg->rc_via_alen))
		goto errout;

	/* Append makes no sense with mpls */
	err = -EOPNOTSUPP;
	if (cfg->rc_nlflags & NLM_F_APPEND)
		goto errout;

	err = -EEXIST;
	platform_label = rtnl_dereference(net->mpls.platform_label);
	old = rtnl_dereference(platform_label[index]);
	if ((cfg->rc_nlflags & NLM_F_EXCL) && old)
		goto errout;

	err = -EEXIST;
	if (!(cfg->rc_nlflags & NLM_F_REPLACE) && old)
		goto errout;

	err = -ENOENT;
	if (!(cfg->rc_nlflags & NLM_F_CREATE) && !old)
		goto errout;

	err = -ENOMEM;
	rt = mpls_rt_alloc(cfg->rc_via_alen);
	if (!rt)
		goto errout;

	rt->rt_labels = cfg->rc_output_labels;
	for (i = 0; i < rt->rt_labels; i++)
		rt->rt_label[i] = cfg->rc_output_label[i];
	rt->rt_protocol = cfg->rc_protocol;
	RCU_INIT_POINTER(rt->rt_dev, dev);
	rt->rt_payload_type = cfg->rc_payload_type;
	rt->rt_via_table = cfg->rc_via_table;
	memcpy(rt->rt_via, cfg->rc_via, cfg->rc_via_alen);

	/* Publishes rt, notifies listeners and frees any displaced route */
	mpls_route_update(net, index, NULL, rt, &cfg->rc_nlinfo);

	/* The route keeps no device reference; mpls_ifdown() clears rt_dev
	 * when the device unregisters.
	 */
	dev_put(dev);
	return 0;

errout:
	if (dev)
		dev_put(dev);
	return err;
}
524 
525 static int mpls_route_del(struct mpls_route_config *cfg)
526 {
527 	struct net *net = cfg->rc_nlinfo.nl_net;
528 	unsigned index;
529 	int err = -EINVAL;
530 
531 	index = cfg->rc_label;
532 
533 	/* Reserved labels may not be removed */
534 	if (index < MPLS_LABEL_FIRST_UNRESERVED)
535 		goto errout;
536 
537 	/* The full 20 bit range may not be supported */
538 	if (index >= net->mpls.platform_labels)
539 		goto errout;
540 
541 	mpls_route_update(net, index, NULL, NULL, &cfg->rc_nlinfo);
542 
543 	err = 0;
544 errout:
545 	return err;
546 }
547 
/* Yields the byte offset of @field within struct mpls_dev, disguised as a
 * pointer.  mpls_dev_sysctl_register() rebases it onto a real mpls_dev.
 */
#define MPLS_PERDEV_SYSCTL_OFFSET(field)	\
	(&((struct mpls_dev *)0)->field)

/* Template for the per-device net/mpls/conf/<dev>/ sysctls.  Each device
 * gets its own copy, made by mpls_dev_sysctl_register().
 */
static const struct ctl_table mpls_dev_table[] = {
	{
		/* mpls_forward() drops traffic unless this is non-zero */
		.procname	= "input",
		.maxlen		= sizeof(int),
		.mode		= 0644,
		.proc_handler	= proc_dointvec,
		.data		= MPLS_PERDEV_SYSCTL_OFFSET(input_enabled),
	},
	{ }
};
561 
562 static int mpls_dev_sysctl_register(struct net_device *dev,
563 				    struct mpls_dev *mdev)
564 {
565 	char path[sizeof("net/mpls/conf/") + IFNAMSIZ];
566 	struct ctl_table *table;
567 	int i;
568 
569 	table = kmemdup(&mpls_dev_table, sizeof(mpls_dev_table), GFP_KERNEL);
570 	if (!table)
571 		goto out;
572 
573 	/* Table data contains only offsets relative to the base of
574 	 * the mdev at this point, so make them absolute.
575 	 */
576 	for (i = 0; i < ARRAY_SIZE(mpls_dev_table); i++)
577 		table[i].data = (char *)mdev + (uintptr_t)table[i].data;
578 
579 	snprintf(path, sizeof(path), "net/mpls/conf/%s", dev->name);
580 
581 	mdev->sysctl = register_net_sysctl(dev_net(dev), path, table);
582 	if (!mdev->sysctl)
583 		goto free;
584 
585 	return 0;
586 
587 free:
588 	kfree(table);
589 out:
590 	return -ENOBUFS;
591 }
592 
593 static void mpls_dev_sysctl_unregister(struct mpls_dev *mdev)
594 {
595 	struct ctl_table *table;
596 
597 	table = mdev->sysctl->ctl_table_arg;
598 	unregister_net_sysctl_table(mdev->sysctl);
599 	kfree(table);
600 }
601 
602 static struct mpls_dev *mpls_add_dev(struct net_device *dev)
603 {
604 	struct mpls_dev *mdev;
605 	int err = -ENOMEM;
606 
607 	ASSERT_RTNL();
608 
609 	mdev = kzalloc(sizeof(*mdev), GFP_KERNEL);
610 	if (!mdev)
611 		return ERR_PTR(err);
612 
613 	err = mpls_dev_sysctl_register(dev, mdev);
614 	if (err)
615 		goto free;
616 
617 	rcu_assign_pointer(dev->mpls_ptr, mdev);
618 
619 	return mdev;
620 
621 free:
622 	kfree(mdev);
623 	return ERR_PTR(err);
624 }
625 
626 static void mpls_ifdown(struct net_device *dev)
627 {
628 	struct mpls_route __rcu **platform_label;
629 	struct net *net = dev_net(dev);
630 	struct mpls_dev *mdev;
631 	unsigned index;
632 
633 	platform_label = rtnl_dereference(net->mpls.platform_label);
634 	for (index = 0; index < net->mpls.platform_labels; index++) {
635 		struct mpls_route *rt = rtnl_dereference(platform_label[index]);
636 		if (!rt)
637 			continue;
638 		if (rtnl_dereference(rt->rt_dev) != dev)
639 			continue;
640 		rt->rt_dev = NULL;
641 	}
642 
643 	mdev = mpls_dev_get(dev);
644 	if (!mdev)
645 		return;
646 
647 	mpls_dev_sysctl_unregister(mdev);
648 
649 	RCU_INIT_POINTER(dev->mpls_ptr, NULL);
650 
651 	kfree_rcu(mdev, rcu);
652 }
653 
654 static int mpls_dev_notify(struct notifier_block *this, unsigned long event,
655 			   void *ptr)
656 {
657 	struct net_device *dev = netdev_notifier_info_to_dev(ptr);
658 	struct mpls_dev *mdev;
659 
660 	switch(event) {
661 	case NETDEV_REGISTER:
662 		/* For now just support ethernet devices */
663 		if ((dev->type == ARPHRD_ETHER) ||
664 		    (dev->type == ARPHRD_LOOPBACK)) {
665 			mdev = mpls_add_dev(dev);
666 			if (IS_ERR(mdev))
667 				return notifier_from_errno(PTR_ERR(mdev));
668 		}
669 		break;
670 
671 	case NETDEV_UNREGISTER:
672 		mpls_ifdown(dev);
673 		break;
674 	case NETDEV_CHANGENAME:
675 		mdev = mpls_dev_get(dev);
676 		if (mdev) {
677 			int err;
678 
679 			mpls_dev_sysctl_unregister(mdev);
680 			err = mpls_dev_sysctl_register(dev, mdev);
681 			if (err)
682 				return notifier_from_errno(err);
683 		}
684 		break;
685 	}
686 	return NOTIFY_OK;
687 }
688 
/* Netdevice notifier registered in mpls_init() and removed in mpls_exit() */
static struct notifier_block mpls_dev_notifier = {
	.notifier_call = mpls_dev_notify,
};
692 
693 static int nla_put_via(struct sk_buff *skb,
694 		       u8 table, const void *addr, int alen)
695 {
696 	static const int table_to_family[NEIGH_NR_TABLES + 1] = {
697 		AF_INET, AF_INET6, AF_DECnet, AF_PACKET,
698 	};
699 	struct nlattr *nla;
700 	struct rtvia *via;
701 	int family = AF_UNSPEC;
702 
703 	nla = nla_reserve(skb, RTA_VIA, alen + 2);
704 	if (!nla)
705 		return -EMSGSIZE;
706 
707 	if (table <= NEIGH_NR_TABLES)
708 		family = table_to_family[table];
709 
710 	via = nla_data(nla);
711 	via->rtvia_family = family;
712 	memcpy(via->rtvia_addr, addr, alen);
713 	return 0;
714 }
715 
716 int nla_put_labels(struct sk_buff *skb, int attrtype,
717 		   u8 labels, const u32 label[])
718 {
719 	struct nlattr *nla;
720 	struct mpls_shim_hdr *nla_label;
721 	bool bos;
722 	int i;
723 	nla = nla_reserve(skb, attrtype, labels*4);
724 	if (!nla)
725 		return -EMSGSIZE;
726 
727 	nla_label = nla_data(nla);
728 	bos = true;
729 	for (i = labels - 1; i >= 0; i--) {
730 		nla_label[i] = mpls_entry_encode(label[i], 0, 0, bos);
731 		bos = false;
732 	}
733 
734 	return 0;
735 }
736 EXPORT_SYMBOL_GPL(nla_put_labels);
737 
738 int nla_get_labels(const struct nlattr *nla,
739 		   u32 max_labels, u32 *labels, u32 label[])
740 {
741 	unsigned len = nla_len(nla);
742 	unsigned nla_labels;
743 	struct mpls_shim_hdr *nla_label;
744 	bool bos;
745 	int i;
746 
747 	/* len needs to be an even multiple of 4 (the label size) */
748 	if (len & 3)
749 		return -EINVAL;
750 
751 	/* Limit the number of new labels allowed */
752 	nla_labels = len/4;
753 	if (nla_labels > max_labels)
754 		return -EINVAL;
755 
756 	nla_label = nla_data(nla);
757 	bos = true;
758 	for (i = nla_labels - 1; i >= 0; i--, bos = false) {
759 		struct mpls_entry_decoded dec;
760 		dec = mpls_entry_decode(nla_label + i);
761 
762 		/* Ensure the bottom of stack flag is properly set
763 		 * and ttl and tc are both clear.
764 		 */
765 		if ((dec.bos != bos) || dec.ttl || dec.tc)
766 			return -EINVAL;
767 
768 		switch (dec.label) {
769 		case MPLS_LABEL_IMPLNULL:
770 			/* RFC3032: This is a label that an LSR may
771 			 * assign and distribute, but which never
772 			 * actually appears in the encapsulation.
773 			 */
774 			return -EINVAL;
775 		}
776 
777 		label[i] = dec.label;
778 	}
779 	*labels = nla_labels;
780 	return 0;
781 }
782 EXPORT_SYMBOL_GPL(nla_get_labels);
783 
/* Translate an AF_MPLS RTM_NEWROUTE/RTM_DELROUTE request into @cfg.
 *
 * The rtmsg header must describe a plain unicast route:
 *  - main table, universe scope, no flags;
 *  - a 20-bit destination and no source or TOS.
 *
 * Only RTA_OIF, RTA_NEWDST, RTA_DST and RTA_VIA are accepted; any other
 * attribute is rejected.
 *
 * Returns 0, the nlmsg_parse() error, or -EINVAL for anything malformed.
 */
static int rtm_to_route_config(struct sk_buff *skb,  struct nlmsghdr *nlh,
			       struct mpls_route_config *cfg)
{
	struct rtmsg *rtm;
	struct nlattr *tb[RTA_MAX+1];
	int index;
	int err;

	err = nlmsg_parse(nlh, sizeof(*rtm), tb, RTA_MAX, rtm_mpls_policy);
	if (err < 0)
		goto errout;

	err = -EINVAL;
	rtm = nlmsg_data(nlh);
	/* Zeroing leaves rc_payload_type as MPT_UNSPEC and rc_ifindex as 0 */
	memset(cfg, 0, sizeof(*cfg));

	if (rtm->rtm_family != AF_MPLS)
		goto errout;
	if (rtm->rtm_dst_len != 20)
		goto errout;
	if (rtm->rtm_src_len != 0)
		goto errout;
	if (rtm->rtm_tos != 0)
		goto errout;
	if (rtm->rtm_table != RT_TABLE_MAIN)
		goto errout;
	/* Any value is acceptable for rtm_protocol */

	/* As mpls uses destination specific addresses
	 * (or source specific address in the case of multicast)
	 * all addresses have universal scope.
	 */
	if (rtm->rtm_scope != RT_SCOPE_UNIVERSE)
		goto errout;
	if (rtm->rtm_type != RTN_UNICAST)
		goto errout;
	if (rtm->rtm_flags != 0)
		goto errout;

	/* Stays LABEL_NOT_SPECIFIED unless an RTA_DST attribute is present */
	cfg->rc_label		= LABEL_NOT_SPECIFIED;
	cfg->rc_protocol	= rtm->rtm_protocol;
	cfg->rc_nlflags		= nlh->nlmsg_flags;
	cfg->rc_nlinfo.portid	= NETLINK_CB(skb).portid;
	cfg->rc_nlinfo.nlh	= nlh;
	cfg->rc_nlinfo.nl_net	= sock_net(skb->sk);

	for (index = 0; index <= RTA_MAX; index++) {
		struct nlattr *nla = tb[index];
		if (!nla)
			continue;

		switch(index) {
		case RTA_OIF:
			cfg->rc_ifindex = nla_get_u32(nla);
			break;
		case RTA_NEWDST:
			if (nla_get_labels(nla, MAX_NEW_LABELS,
					   &cfg->rc_output_labels,
					   cfg->rc_output_label))
				goto errout;
			break;
		case RTA_DST:
		{
			u32 label_count;
			if (nla_get_labels(nla, 1, &label_count,
					   &cfg->rc_label))
				goto errout;

			/* Reserved labels may not be set */
			if (cfg->rc_label < MPLS_LABEL_FIRST_UNRESERVED)
				goto errout;

			break;
		}
		case RTA_VIA:
		{
			struct rtvia *via = nla_data(nla);
			/* Address length is whatever follows the family field */
			if (nla_len(nla) < offsetof(struct rtvia, rtvia_addr))
				goto errout;
			cfg->rc_via_alen   = nla_len(nla) -
				offsetof(struct rtvia, rtvia_addr);
			if (cfg->rc_via_alen > MAX_VIA_ALEN)
				goto errout;

			/* Validate the address family */
			/* AF_PACKET lengths are checked against the device's
			 * address length later, in mpls_route_add().
			 */
			switch(via->rtvia_family) {
			case AF_PACKET:
				cfg->rc_via_table = NEIGH_LINK_TABLE;
				break;
			case AF_INET:
				cfg->rc_via_table = NEIGH_ARP_TABLE;
				if (cfg->rc_via_alen != 4)
					goto errout;
				break;
			case AF_INET6:
				cfg->rc_via_table = NEIGH_ND_TABLE;
				if (cfg->rc_via_alen != 16)
					goto errout;
				break;
			default:
				/* Unsupported address family */
				goto errout;
			}

			memcpy(cfg->rc_via, via->rtvia_addr, cfg->rc_via_alen);
			break;
		}
		default:
			/* Unsupported attribute */
			goto errout;
		}
	}

	err = 0;
errout:
	return err;
}
901 
902 static int mpls_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh)
903 {
904 	struct mpls_route_config cfg;
905 	int err;
906 
907 	err = rtm_to_route_config(skb, nlh, &cfg);
908 	if (err < 0)
909 		return err;
910 
911 	return mpls_route_del(&cfg);
912 }
913 
914 
915 static int mpls_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh)
916 {
917 	struct mpls_route_config cfg;
918 	int err;
919 
920 	err = rtm_to_route_config(skb, nlh, &cfg);
921 	if (err < 0)
922 		return err;
923 
924 	return mpls_route_add(&cfg);
925 }
926 
927 static int mpls_dump_route(struct sk_buff *skb, u32 portid, u32 seq, int event,
928 			   u32 label, struct mpls_route *rt, int flags)
929 {
930 	struct net_device *dev;
931 	struct nlmsghdr *nlh;
932 	struct rtmsg *rtm;
933 
934 	nlh = nlmsg_put(skb, portid, seq, event, sizeof(*rtm), flags);
935 	if (nlh == NULL)
936 		return -EMSGSIZE;
937 
938 	rtm = nlmsg_data(nlh);
939 	rtm->rtm_family = AF_MPLS;
940 	rtm->rtm_dst_len = 20;
941 	rtm->rtm_src_len = 0;
942 	rtm->rtm_tos = 0;
943 	rtm->rtm_table = RT_TABLE_MAIN;
944 	rtm->rtm_protocol = rt->rt_protocol;
945 	rtm->rtm_scope = RT_SCOPE_UNIVERSE;
946 	rtm->rtm_type = RTN_UNICAST;
947 	rtm->rtm_flags = 0;
948 
949 	if (rt->rt_labels &&
950 	    nla_put_labels(skb, RTA_NEWDST, rt->rt_labels, rt->rt_label))
951 		goto nla_put_failure;
952 	if (nla_put_via(skb, rt->rt_via_table, rt->rt_via, rt->rt_via_alen))
953 		goto nla_put_failure;
954 	dev = rtnl_dereference(rt->rt_dev);
955 	if (dev && nla_put_u32(skb, RTA_OIF, dev->ifindex))
956 		goto nla_put_failure;
957 	if (nla_put_labels(skb, RTA_DST, 1, &label))
958 		goto nla_put_failure;
959 
960 	nlmsg_end(skb, nlh);
961 	return 0;
962 
963 nla_put_failure:
964 	nlmsg_cancel(skb, nlh);
965 	return -EMSGSIZE;
966 }
967 
968 static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb)
969 {
970 	struct net *net = sock_net(skb->sk);
971 	struct mpls_route __rcu **platform_label;
972 	size_t platform_labels;
973 	unsigned int index;
974 
975 	ASSERT_RTNL();
976 
977 	index = cb->args[0];
978 	if (index < MPLS_LABEL_FIRST_UNRESERVED)
979 		index = MPLS_LABEL_FIRST_UNRESERVED;
980 
981 	platform_label = rtnl_dereference(net->mpls.platform_label);
982 	platform_labels = net->mpls.platform_labels;
983 	for (; index < platform_labels; index++) {
984 		struct mpls_route *rt;
985 		rt = rtnl_dereference(platform_label[index]);
986 		if (!rt)
987 			continue;
988 
989 		if (mpls_dump_route(skb, NETLINK_CB(cb->skb).portid,
990 				    cb->nlh->nlmsg_seq, RTM_NEWROUTE,
991 				    index, rt, NLM_F_MULTI) < 0)
992 			break;
993 	}
994 	cb->args[0] = index;
995 
996 	return skb->len;
997 }
998 
999 static inline size_t lfib_nlmsg_size(struct mpls_route *rt)
1000 {
1001 	size_t payload =
1002 		NLMSG_ALIGN(sizeof(struct rtmsg))
1003 		+ nla_total_size(2 + rt->rt_via_alen)	/* RTA_VIA */
1004 		+ nla_total_size(4);			/* RTA_DST */
1005 	if (rt->rt_labels)				/* RTA_NEWDST */
1006 		payload += nla_total_size(rt->rt_labels * 4);
1007 	if (rt->rt_dev)					/* RTA_OIF */
1008 		payload += nla_total_size(4);
1009 	return payload;
1010 }
1011 
1012 static void rtmsg_lfib(int event, u32 label, struct mpls_route *rt,
1013 		       struct nlmsghdr *nlh, struct net *net, u32 portid,
1014 		       unsigned int nlm_flags)
1015 {
1016 	struct sk_buff *skb;
1017 	u32 seq = nlh ? nlh->nlmsg_seq : 0;
1018 	int err = -ENOBUFS;
1019 
1020 	skb = nlmsg_new(lfib_nlmsg_size(rt), GFP_KERNEL);
1021 	if (skb == NULL)
1022 		goto errout;
1023 
1024 	err = mpls_dump_route(skb, portid, seq, event, label, rt, nlm_flags);
1025 	if (err < 0) {
1026 		/* -EMSGSIZE implies BUG in lfib_nlmsg_size */
1027 		WARN_ON(err == -EMSGSIZE);
1028 		kfree_skb(skb);
1029 		goto errout;
1030 	}
1031 	rtnl_notify(skb, net, portid, RTNLGRP_MPLS_ROUTE, nlh, GFP_KERNEL);
1032 
1033 	return;
1034 errout:
1035 	if (err < 0)
1036 		rtnl_set_sk_err(net, RTNLGRP_MPLS_ROUTE, err);
1037 }
1038 
/* Replace the namespace's label table with a zeroed one of @limit entries.
 *  - Routes that still fit are carried over; routes beyond the new end are
 *    removed (with notifications) and freed.
 *  - When the table first grows past MPLS_LABEL_IPV4NULL or
 *    MPLS_LABEL_IPV6NULL, those reserved labels get kernel routes to the
 *    loopback device with an IPv4 or IPv6 payload type respectively.
 *  - Everything is allocated before RTNL is taken.
 *  - The old table is freed only after synchronize_rcu().
 * Returns 0 or -ENOMEM.
 */
static int resize_platform_label_table(struct net *net, size_t limit)
{
	size_t size = sizeof(struct mpls_route *) * limit;
	size_t old_limit;
	size_t cp_size;
	struct mpls_route __rcu **labels = NULL, **old;
	struct mpls_route *rt0 = NULL, *rt2 = NULL;
	unsigned index;

	/* A zero-sized table is represented by labels == NULL */
	if (size) {
		/* Try kzalloc quietly first, then fall back to vzalloc */
		labels = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
		if (!labels)
			labels = vzalloc(size);

		if (!labels)
			goto nolabels;
	}

	/* In case the predefined labels need to be populated */
	if (limit > MPLS_LABEL_IPV4NULL) {
		struct net_device *lo = net->loopback_dev;
		rt0 = mpls_rt_alloc(lo->addr_len);
		if (!rt0)
			goto nort0;
		RCU_INIT_POINTER(rt0->rt_dev, lo);
		rt0->rt_protocol = RTPROT_KERNEL;
		rt0->rt_payload_type = MPT_IPV4;
		rt0->rt_via_table = NEIGH_LINK_TABLE;
		memcpy(rt0->rt_via, lo->dev_addr, lo->addr_len);
	}
	if (limit > MPLS_LABEL_IPV6NULL) {
		struct net_device *lo = net->loopback_dev;
		rt2 = mpls_rt_alloc(lo->addr_len);
		if (!rt2)
			goto nort2;
		RCU_INIT_POINTER(rt2->rt_dev, lo);
		rt2->rt_protocol = RTPROT_KERNEL;
		rt2->rt_payload_type = MPT_IPV6;
		rt2->rt_via_table = NEIGH_LINK_TABLE;
		memcpy(rt2->rt_via, lo->dev_addr, lo->addr_len);
	}

	rtnl_lock();
	/* Remember the original table */
	old = rtnl_dereference(net->mpls.platform_label);
	old_limit = net->mpls.platform_labels;

	/* Free any labels beyond the new table */
	for (index = limit; index < old_limit; index++)
		mpls_route_update(net, index, NULL, NULL, NULL);

	/* Copy over the old labels */
	/* NOTE(review): with limit == 0, labels is NULL and cp_size is 0, so
	 * this is memcpy(NULL, old, 0).
	 */
	cp_size = size;
	if (old_limit < limit)
		cp_size = old_limit * sizeof(struct mpls_route *);

	memcpy(labels, old, cp_size);

	/* If needed set the predefined labels */
	/* Only slots that did not exist in the old table are filled; the
	 * route variable is cleared so it is not freed below.
	 */
	if ((old_limit <= MPLS_LABEL_IPV6NULL) &&
	    (limit > MPLS_LABEL_IPV6NULL)) {
		RCU_INIT_POINTER(labels[MPLS_LABEL_IPV6NULL], rt2);
		rt2 = NULL;
	}

	if ((old_limit <= MPLS_LABEL_IPV4NULL) &&
	    (limit > MPLS_LABEL_IPV4NULL)) {
		RCU_INIT_POINTER(labels[MPLS_LABEL_IPV4NULL], rt0);
		rt0 = NULL;
	}

	/* Update the global pointers */
	net->mpls.platform_labels = limit;
	rcu_assign_pointer(net->mpls.platform_label, labels);

	rtnl_unlock();

	/* Free predefined routes that turned out not to be needed */
	mpls_rt_free(rt2);
	mpls_rt_free(rt0);

	/* Wait for readers of the old table before freeing it */
	if (old) {
		synchronize_rcu();
		kvfree(old);
	}
	return 0;

nort2:
	mpls_rt_free(rt0);
nort0:
	kvfree(labels);
nolabels:
	return -ENOMEM;
}
1132 
1133 static int mpls_platform_labels(struct ctl_table *table, int write,
1134 				void __user *buffer, size_t *lenp, loff_t *ppos)
1135 {
1136 	struct net *net = table->data;
1137 	int platform_labels = net->mpls.platform_labels;
1138 	int ret;
1139 	struct ctl_table tmp = {
1140 		.procname	= table->procname,
1141 		.data		= &platform_labels,
1142 		.maxlen		= sizeof(int),
1143 		.mode		= table->mode,
1144 		.extra1		= &zero,
1145 		.extra2		= &label_limit,
1146 	};
1147 
1148 	ret = proc_dointvec_minmax(&tmp, write, buffer, lenp, ppos);
1149 
1150 	if (write && ret == 0)
1151 		ret = resize_platform_label_table(net, platform_labels);
1152 
1153 	return ret;
1154 }
1155 
/* Template for the per-namespace net/mpls sysctls.  mpls_net_init() makes
 * a copy and points .data at the owning struct net.
 */
static const struct ctl_table mpls_table[] = {
	{
		.procname	= "platform_labels",
		.data		= NULL,
		.maxlen		= sizeof(int),
		.mode		= 0644,
		.proc_handler	= mpls_platform_labels,
	},
	{ }
};
1166 
1167 static int mpls_net_init(struct net *net)
1168 {
1169 	struct ctl_table *table;
1170 
1171 	net->mpls.platform_labels = 0;
1172 	net->mpls.platform_label = NULL;
1173 
1174 	table = kmemdup(mpls_table, sizeof(mpls_table), GFP_KERNEL);
1175 	if (table == NULL)
1176 		return -ENOMEM;
1177 
1178 	table[0].data = net;
1179 	net->mpls.ctl = register_net_sysctl(net, "net/mpls", table);
1180 	if (net->mpls.ctl == NULL) {
1181 		kfree(table);
1182 		return -ENOMEM;
1183 	}
1184 
1185 	return 0;
1186 }
1187 
1188 static void mpls_net_exit(struct net *net)
1189 {
1190 	struct mpls_route __rcu **platform_label;
1191 	size_t platform_labels;
1192 	struct ctl_table *table;
1193 	unsigned int index;
1194 
1195 	table = net->mpls.ctl->ctl_table_arg;
1196 	unregister_net_sysctl_table(net->mpls.ctl);
1197 	kfree(table);
1198 
1199 	/* An rcu grace period has passed since there was a device in
1200 	 * the network namespace (and thus the last in flight packet)
1201 	 * left this network namespace.  This is because
1202 	 * unregister_netdevice_many and netdev_run_todo has completed
1203 	 * for each network device that was in this network namespace.
1204 	 *
1205 	 * As such no additional rcu synchronization is necessary when
1206 	 * freeing the platform_label table.
1207 	 */
1208 	rtnl_lock();
1209 	platform_label = rtnl_dereference(net->mpls.platform_label);
1210 	platform_labels = net->mpls.platform_labels;
1211 	for (index = 0; index < platform_labels; index++) {
1212 		struct mpls_route *rt = rtnl_dereference(platform_label[index]);
1213 		RCU_INIT_POINTER(platform_label[index], NULL);
1214 		mpls_rt_free(rt);
1215 	}
1216 	rtnl_unlock();
1217 
1218 	kvfree(platform_label);
1219 }
1220 
/* Per-network-namespace init/exit hooks, registered in mpls_init() */
static struct pernet_operations mpls_net_ops = {
	.init = mpls_net_init,
	.exit = mpls_net_exit,
};
1225 
1226 static int __init mpls_init(void)
1227 {
1228 	int err;
1229 
1230 	BUILD_BUG_ON(sizeof(struct mpls_shim_hdr) != 4);
1231 
1232 	err = register_pernet_subsys(&mpls_net_ops);
1233 	if (err)
1234 		goto out;
1235 
1236 	err = register_netdevice_notifier(&mpls_dev_notifier);
1237 	if (err)
1238 		goto out_unregister_pernet;
1239 
1240 	dev_add_pack(&mpls_packet_type);
1241 
1242 	rtnl_register(PF_MPLS, RTM_NEWROUTE, mpls_rtm_newroute, NULL, NULL);
1243 	rtnl_register(PF_MPLS, RTM_DELROUTE, mpls_rtm_delroute, NULL, NULL);
1244 	rtnl_register(PF_MPLS, RTM_GETROUTE, NULL, mpls_dump_routes, NULL);
1245 	err = 0;
1246 out:
1247 	return err;
1248 
1249 out_unregister_pernet:
1250 	unregister_pernet_subsys(&mpls_net_ops);
1251 	goto out;
1252 }
1253 module_init(mpls_init);
1254 
/* Module exit: undo the registrations made by mpls_init() in reverse order. */
static void __exit mpls_exit(void)
{
	rtnl_unregister_all(PF_MPLS);
	dev_remove_pack(&mpls_packet_type);
	unregister_netdevice_notifier(&mpls_dev_notifier);
	unregister_pernet_subsys(&mpls_net_ops);
}
module_exit(mpls_exit);
1263 
/* Module metadata; the alias ties this module to the PF_MPLS family. */
MODULE_DESCRIPTION("MultiProtocol Label Switching");
MODULE_LICENSE("GPL v2");
MODULE_ALIAS_NETPROTO(PF_MPLS);
1267