// SPDX-License-Identifier: GPL-2.0
/*
 * Management Component Transport Protocol (MCTP) - routing
 * implementation.
 *
 * This is currently based on a simple routing table, with no dst cache. The
 * number of routes should stay fairly small, so the lookup cost is small.
 *
 * Copyright (c) 2021 Code Construct
 * Copyright (c) 2021 Google
 */

#include <linux/idr.h>
#include <linux/mctp.h>
#include <linux/netdevice.h>
#include <linux/rtnetlink.h>
#include <linux/skbuff.h>

#include <uapi/linux/if_arp.h>

#include <net/mctp.h>
#include <net/mctpdevice.h>
#include <net/netlink.h>
#include <net/sock.h>

#include <trace/events/mctp.h>

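/* An MCTP message may span many packets; cap the size of a reassembled
 * message, and bound the lifetime of a tag/reassembly key (6 seconds).
 */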
static const unsigned int mctp_message_maxlen = 64 * 1024;
static const unsigned long mctp_key_lifetime = 6 * CONFIG_HZ;

/* route output callbacks */
static int mctp_route_discard(struct mctp_route *route, struct sk_buff *skb)
{
	kfree_skb(skb);
	return 0;
}

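/* Find a socket bound to the (net, type, dest) of an incoming message.
 * The message type is the first byte of the payload, with the top
 * (integrity check) bit masked off.
 */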
static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb)
{
	struct mctp_skb_cb *cb = mctp_cb(skb);
	struct mctp_hdr *mh;
	struct sock *sk;
	u8 type;

	WARN_ON(!rcu_read_lock_held());

	/* TODO: look up in skb->cb? */
	mh = mctp_hdr(skb);

	if (!skb_headlen(skb))
		return NULL;

	type = (*(u8 *)skb->data) & 0x7f;

	sk_for_each_rcu(sk, &net->mctp.binds) {
		struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);

		if (msk->bind_net != MCTP_NET_ANY && msk->bind_net != cb->net)
			continue;

		if (msk->bind_type != type)
			continue;

		if (msk->bind_addr != MCTP_ADDR_ANY &&
		    msk->bind_addr != mh->dest)
			continue;

		return msk;
	}

	return NULL;
}

static bool mctp_key_match(struct mctp_sk_key *key, mctp_eid_t local,
			   mctp_eid_t peer, u8 tag)
{
	if (key->local_addr != local)
		return false;

	if (key->peer_addr != peer)
		return false;

	if (key->tag != tag)
		return false;

	return true;
}

/* returns a key (with key->lock held, and refcounted), or NULL if no such
 * key exists.
 */
static struct mctp_sk_key *mctp_lookup_key(struct net *net, struct sk_buff *skb,
					   mctp_eid_t peer,
					   unsigned long *irqflags)
	__acquires(&key->lock)
{
	struct mctp_sk_key *key, *ret;
	unsigned long flags;
	struct mctp_hdr *mh;
	u8 tag;

	mh = mctp_hdr(skb);
	tag = mh->flags_seq_tag & (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);

	ret = NULL;
	spin_lock_irqsave(&net->mctp.keys_lock, flags);

	hlist_for_each_entry(key, &net->mctp.keys, hlist) {
		if (!mctp_key_match(key, mh->dest, peer, tag))
			continue;

		spin_lock(&key->lock);
		if (key->valid) {
			refcount_inc(&key->refs);
			ret = key;
			break;
		}
		spin_unlock(&key->lock);
	}

	if (ret) {
		spin_unlock(&net->mctp.keys_lock);
		*irqflags = flags;
	} else {
		spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
	}

	return ret;
}

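/* Allocates a key with a single reference, held by the caller; the key is
 * not yet visible on any list.
 */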
static struct mctp_sk_key *mctp_key_alloc(struct mctp_sock *msk,
					  mctp_eid_t local, mctp_eid_t peer,
					  u8 tag, gfp_t gfp)
{
	struct mctp_sk_key *key;

	key = kzalloc(sizeof(*key), gfp);
	if (!key)
		return NULL;

	key->peer_addr = peer;
	key->local_addr = local;
	key->tag = tag;
	key->sk = &msk->sk;
	key->valid = true;
	spin_lock_init(&key->lock);
	refcount_set(&key->refs, 1);

	return key;
}

void mctp_key_unref(struct mctp_sk_key *key)
{
	if (refcount_dec_and_test(&key->refs))
		kfree(key);
}

static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk)
{
	struct net *net = sock_net(&msk->sk);
	struct mctp_sk_key *tmp;
	unsigned long flags;
	int rc = 0;

	spin_lock_irqsave(&net->mctp.keys_lock, flags);

	hlist_for_each_entry(tmp, &net->mctp.keys, hlist) {
		if (mctp_key_match(tmp, key->local_addr, key->peer_addr,
				   key->tag)) {
			spin_lock(&tmp->lock);
			if (tmp->valid)
				rc = -EEXIST;
			spin_unlock(&tmp->lock);
			if (rc)
				break;
		}
	}

	if (!rc) {
		refcount_inc(&key->refs);
		key->expiry = jiffies + mctp_key_lifetime;
		timer_reduce(&msk->key_expiry, key->expiry);

		hlist_add_head(&key->hlist, &net->mctp.keys);
		hlist_add_head(&key->sklist, &msk->keys);
	}

	spin_unlock_irqrestore(&net->mctp.keys_lock, flags);

	return rc;
}

/* We're done with the key; unset valid and remove from lists. There may still
 * be outstanding refs on the key though...
 */
static void __mctp_key_unlock_drop(struct mctp_sk_key *key, struct net *net,
				   unsigned long flags)
	__releases(&key->lock)
{
	struct sk_buff *skb;

	skb = key->reasm_head;
	key->reasm_head = NULL;
	key->reasm_dead = true;
	key->valid = false;
	spin_unlock_irqrestore(&key->lock, flags);

	spin_lock_irqsave(&net->mctp.keys_lock, flags);
	hlist_del(&key->hlist);
	hlist_del(&key->sklist);
	spin_unlock_irqrestore(&net->mctp.keys_lock, flags);

	/* one unref for the lists */
	mctp_key_unref(key);

	/* and one for the local reference */
	mctp_key_unref(key);

	/* kfree_skb() handles a NULL reassembly head */
	kfree_skb(skb);
}

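/* Queue a packet into a reassembly context. The first fragment becomes the
 * reassembly head; later fragments are chained onto the head's frag_list,
 * and must arrive with consecutive 2-bit sequence numbers.
 */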
static int mctp_frag_queue(struct mctp_sk_key *key, struct sk_buff *skb)
{
	struct mctp_hdr *hdr = mctp_hdr(skb);
	u8 exp_seq, this_seq;

	this_seq = (hdr->flags_seq_tag >> MCTP_HDR_SEQ_SHIFT)
		& MCTP_HDR_SEQ_MASK;

	if (!key->reasm_head) {
		key->reasm_head = skb;
		key->reasm_tailp = &(skb_shinfo(skb)->frag_list);
		key->last_seq = this_seq;
		return 0;
	}

	exp_seq = (key->last_seq + 1) & MCTP_HDR_SEQ_MASK;

	if (this_seq != exp_seq)
		return -EINVAL;

	if (key->reasm_head->len + skb->len > mctp_message_maxlen)
		return -EINVAL;

	skb->next = NULL;
	skb->sk = NULL;
	*key->reasm_tailp = skb;
	key->reasm_tailp = &skb->next;

	key->last_seq = this_seq;

	key->reasm_head->data_len += skb->len;
	key->reasm_head->len += skb->len;
	key->reasm_head->truesize += skb->truesize;

	return 0;
}

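/* Input handler for locally-destined packets. We first look for a key
 * exactly matching (local, peer, tag) - an in-progress reassembly or an
 * expected reply - and otherwise fall back to a bound socket for the
 * start of a new tag-owner message.
 */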
static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
{
	struct net *net = dev_net(skb->dev);
	struct mctp_sk_key *key;
	struct mctp_sock *msk;
	struct mctp_hdr *mh;
	unsigned long f;
	u8 tag, flags;
	int rc;

	msk = NULL;
	rc = -EINVAL;

	/* we may be receiving a locally-routed packet; drop source sk
	 * accounting
	 */
	skb_orphan(skb);

	/* ensure we have enough data for a header and a type */
	if (skb->len < sizeof(struct mctp_hdr) + 1)
		goto out;

	/* grab header, advance data ptr */
	mh = mctp_hdr(skb);
	skb_pull(skb, sizeof(struct mctp_hdr));

	if (mh->ver != 1)
		goto out;

	flags = mh->flags_seq_tag & (MCTP_HDR_FLAG_SOM | MCTP_HDR_FLAG_EOM);
	tag = mh->flags_seq_tag & (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);

	rcu_read_lock();

	/* look up the socket / reassembly context, exactly matching
	 * (src, dest, tag). On success, we hold a ref on the key, with
	 * key->lock held.
	 */
	key = mctp_lookup_key(net, skb, mh->src, &f);

	if (flags & MCTP_HDR_FLAG_SOM) {
		if (key) {
			msk = container_of(key->sk, struct mctp_sock, sk);
		} else {
			/* first response to a broadcast? do a more general
			 * key lookup to find the socket, but don't use this
			 * key for reassembly - we'll create a more specific
			 * one for future packets if required (ie, !EOM).
			 */
			key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY, &f);
			if (key) {
				msk = container_of(key->sk,
						   struct mctp_sock, sk);
				spin_unlock_irqrestore(&key->lock, f);
				mctp_key_unref(key);
				key = NULL;
			}
		}

		if (!key && !msk && (tag & MCTP_HDR_FLAG_TO))
			msk = mctp_lookup_bind(net, skb);

		if (!msk) {
			rc = -ENOENT;
			goto out_unlock;
		}

		/* single-packet message? deliver to socket, clean up any
		 * pending key.
		 */
		if (flags & MCTP_HDR_FLAG_EOM) {
			sock_queue_rcv_skb(&msk->sk, skb);
			if (key) {
				/* we've hit a pending reassembly; not much we
				 * can do but drop it
				 */
				trace_mctp_key_release(key,
						       MCTP_TRACE_KEY_REPLIED);
				__mctp_key_unlock_drop(key, net, f);
				key = NULL;
			}
			rc = 0;
			goto out_unlock;
		}

		/* broadcast response or a bind() - create a key for further
		 * packets for this message
		 */
		if (!key) {
			key = mctp_key_alloc(msk, mh->dest, mh->src,
					     tag, GFP_ATOMIC);
			if (!key) {
				rc = -ENOMEM;
				goto out_unlock;
			}

			/* we can queue without the key lock here, as the
			 * key isn't observable yet
			 */
			mctp_frag_queue(key, skb);

			/* if the key_add fails, we've raced with another
			 * SOM packet with the same src, dest and tag. There's
			 * no way to distinguish future packets, so all we
			 * can do is drop; we'll free the skb on exit from
			 * this function. Don't touch the key once it has
			 * been freed on this path.
			 */
			rc = mctp_key_add(key, msk);
			if (rc)
				kfree(key);
			else
				trace_mctp_key_acquire(key);

			/* we don't need to release key->lock on exit */
			key = NULL;

		} else {
			if (key->reasm_head || key->reasm_dead) {
				/* duplicate start? drop everything */
				trace_mctp_key_release(key,
						       MCTP_TRACE_KEY_INVALIDATED);
				__mctp_key_unlock_drop(key, net, f);
				rc = -EEXIST;
				key = NULL;
			} else {
				rc = mctp_frag_queue(key, skb);
			}
		}

	} else if (key) {
		/* this packet continues a previous message; reassemble
		 * using the message-specific key
		 */

		/* we need to be continuing an existing reassembly... */
		if (!key->reasm_head)
			rc = -EINVAL;
		else
			rc = mctp_frag_queue(key, skb);

		/* end of message? deliver to socket, and we're done with
		 * the reassembly/response key
		 */
		if (!rc && (flags & MCTP_HDR_FLAG_EOM)) {
			sock_queue_rcv_skb(key->sk, key->reasm_head);
			key->reasm_head = NULL;
			trace_mctp_key_release(key, MCTP_TRACE_KEY_REPLIED);
			__mctp_key_unlock_drop(key, net, f);
			key = NULL;
		}

	} else {
		/* not a start, no matching key */
		rc = -ENOENT;
	}

out_unlock:
	rcu_read_unlock();
	if (key) {
		spin_unlock_irqrestore(&key->lock, f);
		mctp_key_unref(key);
	}
out:
	if (rc)
		kfree_skb(skb);
	return rc;
}

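/* Effective MTU of a route: the route's configured MTU if set, otherwise
 * that of the underlying device.
 */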
static unsigned int mctp_route_mtu(struct mctp_route *rt)
{
	return rt->mtu ?: READ_ONCE(rt->dev->dev->mtu);
}

static int mctp_route_output(struct mctp_route *route, struct sk_buff *skb)
{
	struct mctp_hdr *hdr = mctp_hdr(skb);
	char daddr_buf[MAX_ADDR_LEN];
	char *daddr = NULL;
	unsigned int mtu;
	int rc;

	skb->protocol = htons(ETH_P_MCTP);

	mtu = READ_ONCE(skb->dev->mtu);
	if (skb->len > mtu) {
		kfree_skb(skb);
		return -EMSGSIZE;
	}

	/* If lookup fails let the device handle daddr==NULL */
	if (mctp_neigh_lookup(route->dev, hdr->dest, daddr_buf) == 0)
		daddr = daddr_buf;

	rc = dev_hard_header(skb, skb->dev, ntohs(skb->protocol),
			     daddr, skb->dev->dev_addr, skb->len);
	if (rc) {
		kfree_skb(skb);
		return -EHOSTUNREACH;
	}

	rc = dev_queue_xmit(skb);
	if (rc)
		rc = net_xmit_errno(rc);

	return rc;
}

/* route alloc/release */
static void mctp_route_release(struct mctp_route *rt)
{
	if (refcount_dec_and_test(&rt->refs)) {
		mctp_dev_put(rt->dev);
		kfree_rcu(rt, rcu);
	}
}

/* returns a route with the refcount at 1 */
static struct mctp_route *mctp_route_alloc(void)
{
	struct mctp_route *rt;

	rt = kzalloc(sizeof(*rt), GFP_KERNEL);
	if (!rt)
		return NULL;

	INIT_LIST_HEAD(&rt->list);
	refcount_set(&rt->refs, 1);
	rt->output = mctp_route_discard;

	return rt;
}

unsigned int mctp_default_net(struct net *net)
{
	return READ_ONCE(net->mctp.default_net);
}

int mctp_default_net_set(struct net *net, unsigned int index)
{
	if (index == 0)
		return -EINVAL;
	WRITE_ONCE(net->mctp.default_net, index);
	return 0;
}

/* tag management */
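/* Add a pre-allocated key to the net and socket lists and arm its expiry
 * timer; the caller must hold the net's keys_lock.
 */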
static void mctp_reserve_tag(struct net *net, struct mctp_sk_key *key,
			     struct mctp_sock *msk)
{
	struct netns_mctp *mns = &net->mctp;

	lockdep_assert_held(&mns->keys_lock);

	key->expiry = jiffies + mctp_key_lifetime;
	timer_reduce(&msk->key_expiry, key->expiry);

	/* we hold the net's keys_lock here, allowing updates to both the
	 * net and sk lists
	 */
	hlist_add_head_rcu(&key->hlist, &mns->keys);
	hlist_add_head_rcu(&key->sklist, &msk->keys);
	refcount_inc(&key->refs);
}

/* Allocate a locally-owned tag value for (saddr, daddr), and reserve
 * it for the socket msk
 */
static int mctp_alloc_local_tag(struct mctp_sock *msk,
				mctp_eid_t saddr, mctp_eid_t daddr, u8 *tagp)
{
	struct net *net = sock_net(&msk->sk);
	struct netns_mctp *mns = &net->mctp;
	struct mctp_sk_key *key, *tmp;
	unsigned long flags;
	int rc = -EAGAIN;
	u8 tagbits;

	/* for NULL destination EIDs, we may get a response from any peer */
	if (daddr == MCTP_ADDR_NULL)
		daddr = MCTP_ADDR_ANY;

	/* be optimistic, alloc now */
	key = mctp_key_alloc(msk, saddr, daddr, 0, GFP_KERNEL);
	if (!key)
		return -ENOMEM;

	/* 8 possible tag values */
	tagbits = 0xff;

	spin_lock_irqsave(&mns->keys_lock, flags);

	/* Walk through the existing keys, looking for potential conflicting
	 * tags. If we find a conflict, clear that bit from tagbits
	 */
	hlist_for_each_entry(tmp, &mns->keys, hlist) {
		/* We can check the lookup fields (*_addr, tag) without the
		 * lock held, they don't change over the lifetime of the key.
		 */

		/* if we don't own the tag, it can't conflict */
		if (tmp->tag & MCTP_HDR_FLAG_TO)
			continue;

		if (!((tmp->peer_addr == daddr ||
		       tmp->peer_addr == MCTP_ADDR_ANY) &&
		       tmp->local_addr == saddr))
			continue;

		spin_lock(&tmp->lock);
		/* key must still be valid. If we find a match, clear the
		 * potential tag value
		 */
		if (tmp->valid)
			tagbits &= ~(1 << tmp->tag);
		spin_unlock(&tmp->lock);

		if (!tagbits)
			break;
	}

	if (tagbits) {
		key->tag = __ffs(tagbits);
		mctp_reserve_tag(net, key, msk);
		trace_mctp_key_acquire(key);

		*tagp = key->tag;
		rc = 0;
	}

	spin_unlock_irqrestore(&mns->keys_lock, flags);

	if (!tagbits)
		kfree(key);

	return rc;
}

/* routing lookups */
static bool mctp_rt_match_eid(struct mctp_route *rt,
			      unsigned int net, mctp_eid_t eid)
{
	return READ_ONCE(rt->dev->net) == net &&
		rt->min <= eid && rt->max >= eid;
}

/* compares match, used for duplicate prevention */
static bool mctp_rt_compare_exact(struct mctp_route *rt1,
				  struct mctp_route *rt2)
{
	ASSERT_RTNL();
	return rt1->dev->net == rt2->dev->net &&
		rt1->min == rt2->min &&
		rt1->max == rt2->max;
}

struct mctp_route *mctp_route_lookup(struct net *net, unsigned int dnet,
				     mctp_eid_t daddr)
{
	struct mctp_route *tmp, *rt = NULL;

	list_for_each_entry_rcu(tmp, &net->mctp.routes, list) {
		/* TODO: add metrics */
		if (mctp_rt_match_eid(tmp, dnet, daddr)) {
			if (refcount_inc_not_zero(&tmp->refs)) {
				rt = tmp;
				break;
			}
		}
	}

	return rt;
}

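/* Find a local route for a NULL-EID packet that was physically addressed
 * to us on the given device.
 */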
static struct mctp_route *mctp_route_lookup_null(struct net *net,
						 struct net_device *dev)
{
	struct mctp_route *rt;

	list_for_each_entry_rcu(rt, &net->mctp.routes, list) {
		if (rt->dev->dev == dev && rt->type == RTN_LOCAL &&
		    refcount_inc_not_zero(&rt->refs))
			return rt;
	}

	return NULL;
}

/* sends a skb to rt and releases the route. */
int mctp_do_route(struct mctp_route *rt, struct sk_buff *skb)
{
	int rc;

	rc = rt->output(rt, skb);
	mctp_route_release(rt);
	return rc;
}

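/* Fragment a message into packets of at most mtu bytes, replicating the
 * MCTP header and setting the SOM/EOM flags and sequence numbers on each
 * fragment. Consumes the original skb and the route reference.
 *
 * As an illustration (numbers are examples only): with the MCTP baseline
 * transmission unit of 68 bytes (4-byte header + 64-byte payload), a
 * 200-byte message becomes four packets carrying 64 + 64 + 64 + 8 payload
 * bytes, with SOM on the first, EOM on the last, and seq cycling 0,1,2,3.
 */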
static int mctp_do_fragment_route(struct mctp_route *rt, struct sk_buff *skb,
				  unsigned int mtu, u8 tag)
{
	const unsigned int hlen = sizeof(struct mctp_hdr);
	struct mctp_hdr *hdr, *hdr2;
	unsigned int pos, size;
	struct sk_buff *skb2;
	int rc;
	u8 seq;

	hdr = mctp_hdr(skb);
	seq = 0;
	rc = 0;

	if (mtu < hlen + 1) {
		kfree_skb(skb);
		return -EMSGSIZE;
	}

	/* we've got the header */
	skb_pull(skb, hlen);

	for (pos = 0; pos < skb->len;) {
		/* size of message payload */
		size = min(mtu - hlen, skb->len - pos);

		skb2 = alloc_skb(MCTP_HEADER_MAXLEN + hlen + size, GFP_KERNEL);
		if (!skb2) {
			rc = -ENOMEM;
			break;
		}

		/* generic skb copy */
		skb2->protocol = skb->protocol;
		skb2->priority = skb->priority;
		skb2->dev = skb->dev;
		memcpy(skb2->cb, skb->cb, sizeof(skb2->cb));

		if (skb->sk)
			skb_set_owner_w(skb2, skb->sk);

		/* establish packet */
		skb_reserve(skb2, MCTP_HEADER_MAXLEN);
		skb_reset_network_header(skb2);
		skb_put(skb2, hlen + size);
		skb2->transport_header = skb2->network_header + hlen;

		/* copy header fields, calculate SOM/EOM flags & seq */
		hdr2 = mctp_hdr(skb2);
		hdr2->ver = hdr->ver;
		hdr2->dest = hdr->dest;
		hdr2->src = hdr->src;
		hdr2->flags_seq_tag = tag &
			(MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);

		if (pos == 0)
			hdr2->flags_seq_tag |= MCTP_HDR_FLAG_SOM;

		if (pos + size == skb->len)
			hdr2->flags_seq_tag |= MCTP_HDR_FLAG_EOM;

		hdr2->flags_seq_tag |= seq << MCTP_HDR_SEQ_SHIFT;

		/* copy message payload */
		skb_copy_bits(skb, pos, skb_transport_header(skb2), size);

		/* do route, but don't drop the rt reference */
		rc = rt->output(rt, skb2);
		if (rc)
			break;

		seq = (seq + 1) & MCTP_HDR_SEQ_MASK;
		pos += size;
	}

	mctp_route_release(rt);
	consume_skb(skb);
	return rc;
}

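/* Entry point for locally-generated messages: picks a source EID from the
 * outbound interface, allocates a tag if the caller is the tag owner, and
 * either sends directly or fragments to the route MTU.
 *
 * A minimal userspace sketch that reaches this path (illustrative values;
 * addressing types are from include/uapi/linux/mctp.h):
 *
 *	struct sockaddr_mctp addr = {
 *		.smctp_family = AF_MCTP,
 *		.smctp_network = MCTP_NET_ANY,
 *		.smctp_addr.s_addr = 9,		   (destination EID)
 *		.smctp_type = 1,		   (message type, e.g. PLDM)
 *		.smctp_tag = MCTP_TAG_OWNER,	   (kernel allocates the tag)
 *	};
 *	sendto(sd, buf, len, 0, (struct sockaddr *)&addr, sizeof(addr));
 */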
int mctp_local_output(struct sock *sk, struct mctp_route *rt,
		      struct sk_buff *skb, mctp_eid_t daddr, u8 req_tag)
{
	struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
	struct mctp_skb_cb *cb = mctp_cb(skb);
	struct mctp_hdr *hdr;
	unsigned long flags;
	unsigned int mtu;
	mctp_eid_t saddr;
	int rc;
	u8 tag;

	if (WARN_ON(!rt->dev))
		return -EINVAL;

	spin_lock_irqsave(&rt->dev->addrs_lock, flags);
	if (rt->dev->num_addrs == 0) {
		rc = -EHOSTUNREACH;
	} else {
		/* use the outbound interface's first address as our source */
		saddr = rt->dev->addrs[0];
		rc = 0;
	}
	spin_unlock_irqrestore(&rt->dev->addrs_lock, flags);

	if (rc)
		return rc;

	if (req_tag & MCTP_HDR_FLAG_TO) {
		rc = mctp_alloc_local_tag(msk, saddr, daddr, &tag);
		if (rc)
			return rc;
		tag |= MCTP_HDR_FLAG_TO;
	} else {
		tag = req_tag;
	}

	skb->protocol = htons(ETH_P_MCTP);
	skb->priority = 0;
	skb_reset_transport_header(skb);
	skb_push(skb, sizeof(struct mctp_hdr));
	skb_reset_network_header(skb);
	skb->dev = rt->dev->dev;

	/* cb->net will have been set on initial ingress */
	cb->src = saddr;

	/* set up common header fields */
	hdr = mctp_hdr(skb);
	hdr->ver = 1;
	hdr->dest = daddr;
	hdr->src = saddr;

	mtu = mctp_route_mtu(rt);

	if (skb->len + sizeof(struct mctp_hdr) <= mtu) {
		hdr->flags_seq_tag = MCTP_HDR_FLAG_SOM | MCTP_HDR_FLAG_EOM |
			tag;
		return mctp_do_route(rt, skb);
	} else {
		return mctp_do_fragment_route(rt, skb, mtu, tag);
	}
}

/* route management */
static int mctp_route_add(struct mctp_dev *mdev, mctp_eid_t daddr_start,
			  unsigned int daddr_extent, unsigned int mtu,
			  unsigned char type)
{
	int (*rtfn)(struct mctp_route *rt, struct sk_buff *skb);
	struct net *net = dev_net(mdev->dev);
	struct mctp_route *rt, *ert;

	if (!mctp_address_ok(daddr_start))
		return -EINVAL;

	if (daddr_extent > 0xff || daddr_start + daddr_extent >= 255)
		return -EINVAL;

	switch (type) {
	case RTN_LOCAL:
		rtfn = mctp_route_input;
		break;
	case RTN_UNICAST:
		rtfn = mctp_route_output;
		break;
	default:
		return -EINVAL;
	}

	rt = mctp_route_alloc();
	if (!rt)
		return -ENOMEM;

	rt->min = daddr_start;
	rt->max = daddr_start + daddr_extent;
	rt->mtu = mtu;
	rt->dev = mdev;
	mctp_dev_hold(rt->dev);
	rt->type = type;
	rt->output = rtfn;

	ASSERT_RTNL();
	/* Prevent duplicate identical routes. */
	list_for_each_entry(ert, &net->mctp.routes, list) {
		if (mctp_rt_compare_exact(rt, ert)) {
			mctp_route_release(rt);
			return -EEXIST;
		}
	}

	list_add_rcu(&rt->list, &net->mctp.routes);

	return 0;
}

static int mctp_route_remove(struct mctp_dev *mdev, mctp_eid_t daddr_start,
			     unsigned int daddr_extent)
{
	struct net *net = dev_net(mdev->dev);
	struct mctp_route *rt, *tmp;
	mctp_eid_t daddr_end;
	bool dropped;

	if (daddr_extent > 0xff || daddr_start + daddr_extent >= 255)
		return -EINVAL;

	daddr_end = daddr_start + daddr_extent;
	dropped = false;

	ASSERT_RTNL();

	list_for_each_entry_safe(rt, tmp, &net->mctp.routes, list) {
		if (rt->dev == mdev &&
		    rt->min == daddr_start && rt->max == daddr_end) {
			list_del_rcu(&rt->list);
			/* TODO: immediate RTM_DELROUTE */
			mctp_route_release(rt);
			dropped = true;
		}
	}

	return dropped ? 0 : -ENOENT;
}

int mctp_route_add_local(struct mctp_dev *mdev, mctp_eid_t addr)
{
	return mctp_route_add(mdev, addr, 0, 0, RTN_LOCAL);
}

int mctp_route_remove_local(struct mctp_dev *mdev, mctp_eid_t addr)
{
	return mctp_route_remove(mdev, addr, 0);
}

/* removes all entries for a given device */
void mctp_route_remove_dev(struct mctp_dev *mdev)
{
	struct net *net = dev_net(mdev->dev);
	struct mctp_route *rt, *tmp;

	ASSERT_RTNL();
	list_for_each_entry_safe(rt, tmp, &net->mctp.routes, list) {
		if (rt->dev == mdev) {
			list_del_rcu(&rt->list);
			/* TODO: immediate RTM_DELROUTE */
			mctp_route_release(rt);
		}
	}
}

/* Incoming packet-handling */

static int mctp_pkttype_receive(struct sk_buff *skb, struct net_device *dev,
				struct packet_type *pt,
				struct net_device *orig_dev)
{
	struct net *net = dev_net(dev);
	struct mctp_dev *mdev;
	struct mctp_skb_cb *cb;
	struct mctp_route *rt;
	struct mctp_hdr *mh;

	rcu_read_lock();
	mdev = __mctp_dev_get(dev);
	rcu_read_unlock();
	if (!mdev)
		goto err_drop;

	/* basic non-data sanity checks */
	if (!pskb_may_pull(skb, sizeof(struct mctp_hdr)))
		goto err_drop;

	skb_reset_transport_header(skb);
	skb_reset_network_header(skb);

	/* We have enough for a header; decode and route */
	mh = mctp_hdr(skb);
	if (mh->ver < MCTP_VER_MIN || mh->ver > MCTP_VER_MAX)
		goto err_drop;

	cb = __mctp_cb(skb);
	cb->net = READ_ONCE(mdev->net);

	rt = mctp_route_lookup(net, cb->net, mh->dest);

	/* NULL EID, but addressed to our physical address */
	if (!rt && mh->dest == MCTP_ADDR_NULL && skb->pkt_type == PACKET_HOST)
		rt = mctp_route_lookup_null(net, dev);

	if (!rt)
		goto err_drop;

	mctp_do_route(rt, skb);

	return NET_RX_SUCCESS;

err_drop:
	kfree_skb(skb);
	return NET_RX_DROP;
}

static struct packet_type mctp_packet_type = {
	.type = cpu_to_be16(ETH_P_MCTP),
	.func = mctp_pkttype_receive,
};

/* netlink interface */

static const struct nla_policy rta_mctp_policy[RTA_MAX + 1] = {
	[RTA_DST]		= { .type = NLA_U8 },
	[RTA_METRICS]		= { .type = NLA_NESTED },
	[RTA_OIF]		= { .type = NLA_U32 },
};

/* Common part for RTM_NEWROUTE and RTM_DELROUTE parsing.
 * tb must hold RTA_MAX+1 elements.
 */
static int mctp_route_nlparse(struct sk_buff *skb, struct nlmsghdr *nlh,
			      struct netlink_ext_ack *extack,
			      struct nlattr **tb, struct rtmsg **rtm,
			      struct mctp_dev **mdev, mctp_eid_t *daddr_start)
{
	struct net *net = sock_net(skb->sk);
	struct net_device *dev;
	unsigned int ifindex;
	int rc;

	rc = nlmsg_parse(nlh, sizeof(struct rtmsg), tb, RTA_MAX,
			 rta_mctp_policy, extack);
	if (rc < 0) {
		NL_SET_ERR_MSG(extack, "incorrect format");
		return rc;
	}

	if (!tb[RTA_DST]) {
		NL_SET_ERR_MSG(extack, "dst EID missing");
		return -EINVAL;
	}
	*daddr_start = nla_get_u8(tb[RTA_DST]);

	if (!tb[RTA_OIF]) {
		NL_SET_ERR_MSG(extack, "ifindex missing");
		return -EINVAL;
	}
	ifindex = nla_get_u32(tb[RTA_OIF]);

	*rtm = nlmsg_data(nlh);
	if ((*rtm)->rtm_family != AF_MCTP) {
		NL_SET_ERR_MSG(extack, "route family must be AF_MCTP");
		return -EINVAL;
	}

	dev = __dev_get_by_index(net, ifindex);
	if (!dev) {
		NL_SET_ERR_MSG(extack, "bad ifindex");
		return -ENODEV;
	}
	*mdev = mctp_dev_get_rtnl(dev);
	if (!*mdev)
		return -ENODEV;

	if (dev->flags & IFF_LOOPBACK) {
		NL_SET_ERR_MSG(extack, "no routes to loopback");
		return -EINVAL;
	}

	return 0;
}

static const struct nla_policy rta_metrics_policy[RTAX_MAX + 1] = {
	[RTAX_MTU]		= { .type = NLA_U32 },
};

static int mctp_newroute(struct sk_buff *skb, struct nlmsghdr *nlh,
			 struct netlink_ext_ack *extack)
{
	struct nlattr *tb[RTA_MAX + 1];
	struct nlattr *tbx[RTAX_MAX + 1];
	mctp_eid_t daddr_start;
	struct mctp_dev *mdev;
	struct rtmsg *rtm;
	unsigned int mtu;
	int rc;

	rc = mctp_route_nlparse(skb, nlh, extack, tb,
				&rtm, &mdev, &daddr_start);
	if (rc < 0)
		return rc;

	if (rtm->rtm_type != RTN_UNICAST) {
		NL_SET_ERR_MSG(extack, "rtm_type must be RTN_UNICAST");
		return -EINVAL;
	}

	mtu = 0;
	if (tb[RTA_METRICS]) {
		rc = nla_parse_nested(tbx, RTAX_MAX, tb[RTA_METRICS],
				      rta_metrics_policy, NULL);
		if (rc < 0)
			return rc;
		if (tbx[RTAX_MTU])
			mtu = nla_get_u32(tbx[RTAX_MTU]);
	}
	rc = mctp_route_add(mdev, daddr_start, rtm->rtm_dst_len, mtu,
			    rtm->rtm_type);
	return rc;
}

static int mctp_delroute(struct sk_buff *skb, struct nlmsghdr *nlh,
			 struct netlink_ext_ack *extack)
{
	struct nlattr *tb[RTA_MAX + 1];
	mctp_eid_t daddr_start;
	struct mctp_dev *mdev;
	struct rtmsg *rtm;
	int rc;

	rc = mctp_route_nlparse(skb, nlh, extack, tb,
				&rtm, &mdev, &daddr_start);
	if (rc < 0)
		return rc;

	/* we only have unicast routes */
	if (rtm->rtm_type != RTN_UNICAST)
		return -EINVAL;

	rc = mctp_route_remove(mdev, daddr_start, rtm->rtm_dst_len);
	return rc;
}

static int mctp_fill_rtinfo(struct sk_buff *skb, struct mctp_route *rt,
			    u32 portid, u32 seq, int event, unsigned int flags)
{
	struct nlmsghdr *nlh;
	struct rtmsg *hdr;
	void *metrics;

	nlh = nlmsg_put(skb, portid, seq, event, sizeof(*hdr), flags);
	if (!nlh)
		return -EMSGSIZE;

	hdr = nlmsg_data(nlh);
	hdr->rtm_family = AF_MCTP;

	/* we use the _len fields as a number of EIDs, rather than
	 * a number of bits in the address
	 */
	hdr->rtm_dst_len = rt->max - rt->min;
	hdr->rtm_src_len = 0;
	hdr->rtm_tos = 0;
	hdr->rtm_table = RT_TABLE_DEFAULT;
	hdr->rtm_protocol = RTPROT_STATIC; /* everything is user-defined */
	hdr->rtm_scope = RT_SCOPE_LINK; /* TODO: scope in mctp_route? */
	hdr->rtm_type = rt->type;

	if (nla_put_u8(skb, RTA_DST, rt->min))
		goto cancel;

	metrics = nla_nest_start_noflag(skb, RTA_METRICS);
	if (!metrics)
		goto cancel;

	if (rt->mtu) {
		if (nla_put_u32(skb, RTAX_MTU, rt->mtu))
			goto cancel;
	}

	nla_nest_end(skb, metrics);

	if (rt->dev) {
		if (nla_put_u32(skb, RTA_OIF, rt->dev->dev->ifindex))
			goto cancel;
	}

	/* TODO: conditional neighbour physaddr? */

	nlmsg_end(skb, nlh);

	return 0;

cancel:
	nlmsg_cancel(skb, nlh);
	return -EMSGSIZE;
}

static int mctp_dump_rtinfo(struct sk_buff *skb, struct netlink_callback *cb)
{
	struct net *net = sock_net(skb->sk);
	struct mctp_route *rt;
	int s_idx, idx;

	/* TODO: allow filtering on route data, possibly under
	 * cb->strict_check
	 */

	/* TODO: change to struct overlay */
	s_idx = cb->args[0];
	idx = 0;

	rcu_read_lock();
	list_for_each_entry_rcu(rt, &net->mctp.routes, list) {
		if (idx++ < s_idx)
			continue;
		if (mctp_fill_rtinfo(skb, rt,
				     NETLINK_CB(cb->skb).portid,
				     cb->nlh->nlmsg_seq,
				     RTM_NEWROUTE, NLM_F_MULTI) < 0)
			break;
	}

	rcu_read_unlock();
	cb->args[0] = idx;

	return skb->len;
}

/* net namespace implementation */
static int __net_init mctp_routes_net_init(struct net *net)
{
	struct netns_mctp *ns = &net->mctp;

	INIT_LIST_HEAD(&ns->routes);
	INIT_HLIST_HEAD(&ns->binds);
	mutex_init(&ns->bind_lock);
	INIT_HLIST_HEAD(&ns->keys);
	spin_lock_init(&ns->keys_lock);
	WARN_ON(mctp_default_net_set(net, MCTP_INITIAL_DEFAULT_NET));
	return 0;
}

static void __net_exit mctp_routes_net_exit(struct net *net)
{
	struct mctp_route *rt;

	rcu_read_lock();
	list_for_each_entry_rcu(rt, &net->mctp.routes, list)
		mctp_route_release(rt);
	rcu_read_unlock();
}

static struct pernet_operations mctp_net_ops = {
	.init = mctp_routes_net_init,
	.exit = mctp_routes_net_exit,
};

int __init mctp_routes_init(void)
{
	dev_add_pack(&mctp_packet_type);

	rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_GETROUTE,
			     NULL, mctp_dump_rtinfo, 0);
	rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_NEWROUTE,
			     mctp_newroute, NULL, 0);
	rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_DELROUTE,
			     mctp_delroute, NULL, 0);

	return register_pernet_subsys(&mctp_net_ops);
}

void __exit mctp_routes_exit(void)
{
	unregister_pernet_subsys(&mctp_net_ops);
	rtnl_unregister(PF_MCTP, RTM_DELROUTE);
	rtnl_unregister(PF_MCTP, RTM_NEWROUTE);
	rtnl_unregister(PF_MCTP, RTM_GETROUTE);
	dev_remove_pack(&mctp_packet_type);
}