xref: /openbmc/linux/drivers/net/wireguard/netlink.c (revision f17f06a0)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4  */
5 
6 #include "netlink.h"
7 #include "device.h"
8 #include "peer.h"
9 #include "socket.h"
10 #include "queueing.h"
11 #include "messages.h"
12 
13 #include <uapi/linux/wireguard.h>
14 
15 #include <linux/if.h>
16 #include <net/genetlink.h>
17 #include <net/sock.h>
18 #include <crypto/algapi.h>
19 
/* Forward declaration: the dump handler below passes &genl_family to
 * genlmsg_put() before the family is defined at the bottom of this file.
 */
static struct genl_family genl_family;
21 
22 static const struct nla_policy device_policy[WGDEVICE_A_MAX + 1] = {
23 	[WGDEVICE_A_IFINDEX]		= { .type = NLA_U32 },
24 	[WGDEVICE_A_IFNAME]		= { .type = NLA_NUL_STRING, .len = IFNAMSIZ - 1 },
25 	[WGDEVICE_A_PRIVATE_KEY]	= { .type = NLA_EXACT_LEN, .len = NOISE_PUBLIC_KEY_LEN },
26 	[WGDEVICE_A_PUBLIC_KEY]		= { .type = NLA_EXACT_LEN, .len = NOISE_PUBLIC_KEY_LEN },
27 	[WGDEVICE_A_FLAGS]		= { .type = NLA_U32 },
28 	[WGDEVICE_A_LISTEN_PORT]	= { .type = NLA_U16 },
29 	[WGDEVICE_A_FWMARK]		= { .type = NLA_U32 },
30 	[WGDEVICE_A_PEERS]		= { .type = NLA_NESTED }
31 };
32 
33 static const struct nla_policy peer_policy[WGPEER_A_MAX + 1] = {
34 	[WGPEER_A_PUBLIC_KEY]				= { .type = NLA_EXACT_LEN, .len = NOISE_PUBLIC_KEY_LEN },
35 	[WGPEER_A_PRESHARED_KEY]			= { .type = NLA_EXACT_LEN, .len = NOISE_SYMMETRIC_KEY_LEN },
36 	[WGPEER_A_FLAGS]				= { .type = NLA_U32 },
37 	[WGPEER_A_ENDPOINT]				= { .type = NLA_MIN_LEN, .len = sizeof(struct sockaddr) },
38 	[WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL]	= { .type = NLA_U16 },
39 	[WGPEER_A_LAST_HANDSHAKE_TIME]			= { .type = NLA_EXACT_LEN, .len = sizeof(struct __kernel_timespec) },
40 	[WGPEER_A_RX_BYTES]				= { .type = NLA_U64 },
41 	[WGPEER_A_TX_BYTES]				= { .type = NLA_U64 },
42 	[WGPEER_A_ALLOWEDIPS]				= { .type = NLA_NESTED },
43 	[WGPEER_A_PROTOCOL_VERSION]			= { .type = NLA_U32 }
44 };
45 
46 static const struct nla_policy allowedip_policy[WGALLOWEDIP_A_MAX + 1] = {
47 	[WGALLOWEDIP_A_FAMILY]		= { .type = NLA_U16 },
48 	[WGALLOWEDIP_A_IPADDR]		= { .type = NLA_MIN_LEN, .len = sizeof(struct in_addr) },
49 	[WGALLOWEDIP_A_CIDR_MASK]	= { .type = NLA_U8 }
50 };
51 
52 static struct wg_device *lookup_interface(struct nlattr **attrs,
53 					  struct sk_buff *skb)
54 {
55 	struct net_device *dev = NULL;
56 
57 	if (!attrs[WGDEVICE_A_IFINDEX] == !attrs[WGDEVICE_A_IFNAME])
58 		return ERR_PTR(-EBADR);
59 	if (attrs[WGDEVICE_A_IFINDEX])
60 		dev = dev_get_by_index(sock_net(skb->sk),
61 				       nla_get_u32(attrs[WGDEVICE_A_IFINDEX]));
62 	else if (attrs[WGDEVICE_A_IFNAME])
63 		dev = dev_get_by_name(sock_net(skb->sk),
64 				      nla_data(attrs[WGDEVICE_A_IFNAME]));
65 	if (!dev)
66 		return ERR_PTR(-ENODEV);
67 	if (!dev->rtnl_link_ops || !dev->rtnl_link_ops->kind ||
68 	    strcmp(dev->rtnl_link_ops->kind, KBUILD_MODNAME)) {
69 		dev_put(dev);
70 		return ERR_PTR(-EOPNOTSUPP);
71 	}
72 	return netdev_priv(dev);
73 }
74 
75 static int get_allowedips(struct sk_buff *skb, const u8 *ip, u8 cidr,
76 			  int family)
77 {
78 	struct nlattr *allowedip_nest;
79 
80 	allowedip_nest = nla_nest_start(skb, 0);
81 	if (!allowedip_nest)
82 		return -EMSGSIZE;
83 
84 	if (nla_put_u8(skb, WGALLOWEDIP_A_CIDR_MASK, cidr) ||
85 	    nla_put_u16(skb, WGALLOWEDIP_A_FAMILY, family) ||
86 	    nla_put(skb, WGALLOWEDIP_A_IPADDR, family == AF_INET6 ?
87 		    sizeof(struct in6_addr) : sizeof(struct in_addr), ip)) {
88 		nla_nest_cancel(skb, allowedip_nest);
89 		return -EMSGSIZE;
90 	}
91 
92 	nla_nest_end(skb, allowedip_nest);
93 	return 0;
94 }
95 
/* Per-dump state carried between successive WG_CMD_GET_DEVICE dump callbacks.
 * It is overlaid on the netlink callback's args storage via DUMP_CTX().
 * NOTE(review): nothing in this file checks that the struct fits within
 * cb->args - confirm against the definition of struct netlink_callback.
 */
struct dump_ctx {
	/* Device being dumped; set by the dump's start callback. */
	struct wg_device *wg;
	/* Last peer that was fully emitted, or NULL; the dump resumes after
	 * it. The dump callback takes a reference on the peer it stores here.
	 */
	struct wg_peer *next_peer;
	/* Value of the allowed-IPs seq recorded when a peer's allowed IPs
	 * started being emitted; 0 when no such listing is in progress.
	 */
	u64 allowedips_seq;
	/* Allowed-IP node at which to resume a peer whose list did not fit in
	 * the previous message, or NULL.
	 */
	struct allowedips_node *next_allowedip;
};

/* View the callback's args storage as a struct dump_ctx. */
#define DUMP_CTX(cb) ((struct dump_ctx *)(cb)->args)
104 
105 static int
106 get_peer(struct wg_peer *peer, struct sk_buff *skb, struct dump_ctx *ctx)
107 {
108 
109 	struct nlattr *allowedips_nest, *peer_nest = nla_nest_start(skb, 0);
110 	struct allowedips_node *allowedips_node = ctx->next_allowedip;
111 	bool fail;
112 
113 	if (!peer_nest)
114 		return -EMSGSIZE;
115 
116 	down_read(&peer->handshake.lock);
117 	fail = nla_put(skb, WGPEER_A_PUBLIC_KEY, NOISE_PUBLIC_KEY_LEN,
118 		       peer->handshake.remote_static);
119 	up_read(&peer->handshake.lock);
120 	if (fail)
121 		goto err;
122 
123 	if (!allowedips_node) {
124 		const struct __kernel_timespec last_handshake = {
125 			.tv_sec = peer->walltime_last_handshake.tv_sec,
126 			.tv_nsec = peer->walltime_last_handshake.tv_nsec
127 		};
128 
129 		down_read(&peer->handshake.lock);
130 		fail = nla_put(skb, WGPEER_A_PRESHARED_KEY,
131 			       NOISE_SYMMETRIC_KEY_LEN,
132 			       peer->handshake.preshared_key);
133 		up_read(&peer->handshake.lock);
134 		if (fail)
135 			goto err;
136 
137 		if (nla_put(skb, WGPEER_A_LAST_HANDSHAKE_TIME,
138 			    sizeof(last_handshake), &last_handshake) ||
139 		    nla_put_u16(skb, WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL,
140 				peer->persistent_keepalive_interval) ||
141 		    nla_put_u64_64bit(skb, WGPEER_A_TX_BYTES, peer->tx_bytes,
142 				      WGPEER_A_UNSPEC) ||
143 		    nla_put_u64_64bit(skb, WGPEER_A_RX_BYTES, peer->rx_bytes,
144 				      WGPEER_A_UNSPEC) ||
145 		    nla_put_u32(skb, WGPEER_A_PROTOCOL_VERSION, 1))
146 			goto err;
147 
148 		read_lock_bh(&peer->endpoint_lock);
149 		if (peer->endpoint.addr.sa_family == AF_INET)
150 			fail = nla_put(skb, WGPEER_A_ENDPOINT,
151 				       sizeof(peer->endpoint.addr4),
152 				       &peer->endpoint.addr4);
153 		else if (peer->endpoint.addr.sa_family == AF_INET6)
154 			fail = nla_put(skb, WGPEER_A_ENDPOINT,
155 				       sizeof(peer->endpoint.addr6),
156 				       &peer->endpoint.addr6);
157 		read_unlock_bh(&peer->endpoint_lock);
158 		if (fail)
159 			goto err;
160 		allowedips_node =
161 			list_first_entry_or_null(&peer->allowedips_list,
162 					struct allowedips_node, peer_list);
163 	}
164 	if (!allowedips_node)
165 		goto no_allowedips;
166 	if (!ctx->allowedips_seq)
167 		ctx->allowedips_seq = peer->device->peer_allowedips.seq;
168 	else if (ctx->allowedips_seq != peer->device->peer_allowedips.seq)
169 		goto no_allowedips;
170 
171 	allowedips_nest = nla_nest_start(skb, WGPEER_A_ALLOWEDIPS);
172 	if (!allowedips_nest)
173 		goto err;
174 
175 	list_for_each_entry_from(allowedips_node, &peer->allowedips_list,
176 				 peer_list) {
177 		u8 cidr, ip[16] __aligned(__alignof(u64));
178 		int family;
179 
180 		family = wg_allowedips_read_node(allowedips_node, ip, &cidr);
181 		if (get_allowedips(skb, ip, cidr, family)) {
182 			nla_nest_end(skb, allowedips_nest);
183 			nla_nest_end(skb, peer_nest);
184 			ctx->next_allowedip = allowedips_node;
185 			return -EMSGSIZE;
186 		}
187 	}
188 	nla_nest_end(skb, allowedips_nest);
189 no_allowedips:
190 	nla_nest_end(skb, peer_nest);
191 	ctx->next_allowedip = NULL;
192 	ctx->allowedips_seq = 0;
193 	return 0;
194 err:
195 	nla_nest_cancel(skb, peer_nest);
196 	return -EMSGSIZE;
197 }
198 
199 static int wg_get_device_start(struct netlink_callback *cb)
200 {
201 	struct wg_device *wg;
202 
203 	wg = lookup_interface(genl_dumpit_info(cb)->attrs, cb->skb);
204 	if (IS_ERR(wg))
205 		return PTR_ERR(wg);
206 	DUMP_CTX(cb)->wg = wg;
207 	return 0;
208 }
209 
210 static int wg_get_device_dump(struct sk_buff *skb, struct netlink_callback *cb)
211 {
212 	struct wg_peer *peer, *next_peer_cursor;
213 	struct dump_ctx *ctx = DUMP_CTX(cb);
214 	struct wg_device *wg = ctx->wg;
215 	struct nlattr *peers_nest;
216 	int ret = -EMSGSIZE;
217 	bool done = true;
218 	void *hdr;
219 
220 	rtnl_lock();
221 	mutex_lock(&wg->device_update_lock);
222 	cb->seq = wg->device_update_gen;
223 	next_peer_cursor = ctx->next_peer;
224 
225 	hdr = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
226 			  &genl_family, NLM_F_MULTI, WG_CMD_GET_DEVICE);
227 	if (!hdr)
228 		goto out;
229 	genl_dump_check_consistent(cb, hdr);
230 
231 	if (!ctx->next_peer) {
232 		if (nla_put_u16(skb, WGDEVICE_A_LISTEN_PORT,
233 				wg->incoming_port) ||
234 		    nla_put_u32(skb, WGDEVICE_A_FWMARK, wg->fwmark) ||
235 		    nla_put_u32(skb, WGDEVICE_A_IFINDEX, wg->dev->ifindex) ||
236 		    nla_put_string(skb, WGDEVICE_A_IFNAME, wg->dev->name))
237 			goto out;
238 
239 		down_read(&wg->static_identity.lock);
240 		if (wg->static_identity.has_identity) {
241 			if (nla_put(skb, WGDEVICE_A_PRIVATE_KEY,
242 				    NOISE_PUBLIC_KEY_LEN,
243 				    wg->static_identity.static_private) ||
244 			    nla_put(skb, WGDEVICE_A_PUBLIC_KEY,
245 				    NOISE_PUBLIC_KEY_LEN,
246 				    wg->static_identity.static_public)) {
247 				up_read(&wg->static_identity.lock);
248 				goto out;
249 			}
250 		}
251 		up_read(&wg->static_identity.lock);
252 	}
253 
254 	peers_nest = nla_nest_start(skb, WGDEVICE_A_PEERS);
255 	if (!peers_nest)
256 		goto out;
257 	ret = 0;
258 	/* If the last cursor was removed via list_del_init in peer_remove, then
259 	 * we just treat this the same as there being no more peers left. The
260 	 * reason is that seq_nr should indicate to userspace that this isn't a
261 	 * coherent dump anyway, so they'll try again.
262 	 */
263 	if (list_empty(&wg->peer_list) ||
264 	    (ctx->next_peer && list_empty(&ctx->next_peer->peer_list))) {
265 		nla_nest_cancel(skb, peers_nest);
266 		goto out;
267 	}
268 	lockdep_assert_held(&wg->device_update_lock);
269 	peer = list_prepare_entry(ctx->next_peer, &wg->peer_list, peer_list);
270 	list_for_each_entry_continue(peer, &wg->peer_list, peer_list) {
271 		if (get_peer(peer, skb, ctx)) {
272 			done = false;
273 			break;
274 		}
275 		next_peer_cursor = peer;
276 	}
277 	nla_nest_end(skb, peers_nest);
278 
279 out:
280 	if (!ret && !done && next_peer_cursor)
281 		wg_peer_get(next_peer_cursor);
282 	wg_peer_put(ctx->next_peer);
283 	mutex_unlock(&wg->device_update_lock);
284 	rtnl_unlock();
285 
286 	if (ret) {
287 		genlmsg_cancel(skb, hdr);
288 		return ret;
289 	}
290 	genlmsg_end(skb, hdr);
291 	if (done) {
292 		ctx->next_peer = NULL;
293 		return 0;
294 	}
295 	ctx->next_peer = next_peer_cursor;
296 	return skb->len;
297 
298 	/* At this point, we can't really deal ourselves with safely zeroing out
299 	 * the private key material after usage. This will need an additional API
300 	 * in the kernel for marking skbs as zero_on_free.
301 	 */
302 }
303 
304 static int wg_get_device_done(struct netlink_callback *cb)
305 {
306 	struct dump_ctx *ctx = DUMP_CTX(cb);
307 
308 	if (ctx->wg)
309 		dev_put(ctx->wg->dev);
310 	wg_peer_put(ctx->next_peer);
311 	return 0;
312 }
313 
314 static int set_port(struct wg_device *wg, u16 port)
315 {
316 	struct wg_peer *peer;
317 
318 	if (wg->incoming_port == port)
319 		return 0;
320 	list_for_each_entry(peer, &wg->peer_list, peer_list)
321 		wg_socket_clear_peer_endpoint_src(peer);
322 	if (!netif_running(wg->dev)) {
323 		wg->incoming_port = port;
324 		return 0;
325 	}
326 	return wg_socket_init(wg, port);
327 }
328 
329 static int set_allowedip(struct wg_peer *peer, struct nlattr **attrs)
330 {
331 	int ret = -EINVAL;
332 	u16 family;
333 	u8 cidr;
334 
335 	if (!attrs[WGALLOWEDIP_A_FAMILY] || !attrs[WGALLOWEDIP_A_IPADDR] ||
336 	    !attrs[WGALLOWEDIP_A_CIDR_MASK])
337 		return ret;
338 	family = nla_get_u16(attrs[WGALLOWEDIP_A_FAMILY]);
339 	cidr = nla_get_u8(attrs[WGALLOWEDIP_A_CIDR_MASK]);
340 
341 	if (family == AF_INET && cidr <= 32 &&
342 	    nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in_addr))
343 		ret = wg_allowedips_insert_v4(
344 			&peer->device->peer_allowedips,
345 			nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer,
346 			&peer->device->device_update_lock);
347 	else if (family == AF_INET6 && cidr <= 128 &&
348 		 nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in6_addr))
349 		ret = wg_allowedips_insert_v6(
350 			&peer->device->peer_allowedips,
351 			nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer,
352 			&peer->device->device_update_lock);
353 
354 	return ret;
355 }
356 
357 static int set_peer(struct wg_device *wg, struct nlattr **attrs)
358 {
359 	u8 *public_key = NULL, *preshared_key = NULL;
360 	struct wg_peer *peer = NULL;
361 	u32 flags = 0;
362 	int ret;
363 
364 	ret = -EINVAL;
365 	if (attrs[WGPEER_A_PUBLIC_KEY] &&
366 	    nla_len(attrs[WGPEER_A_PUBLIC_KEY]) == NOISE_PUBLIC_KEY_LEN)
367 		public_key = nla_data(attrs[WGPEER_A_PUBLIC_KEY]);
368 	else
369 		goto out;
370 	if (attrs[WGPEER_A_PRESHARED_KEY] &&
371 	    nla_len(attrs[WGPEER_A_PRESHARED_KEY]) == NOISE_SYMMETRIC_KEY_LEN)
372 		preshared_key = nla_data(attrs[WGPEER_A_PRESHARED_KEY]);
373 
374 	if (attrs[WGPEER_A_FLAGS])
375 		flags = nla_get_u32(attrs[WGPEER_A_FLAGS]);
376 	ret = -EOPNOTSUPP;
377 	if (flags & ~__WGPEER_F_ALL)
378 		goto out;
379 
380 	ret = -EPFNOSUPPORT;
381 	if (attrs[WGPEER_A_PROTOCOL_VERSION]) {
382 		if (nla_get_u32(attrs[WGPEER_A_PROTOCOL_VERSION]) != 1)
383 			goto out;
384 	}
385 
386 	peer = wg_pubkey_hashtable_lookup(wg->peer_hashtable,
387 					  nla_data(attrs[WGPEER_A_PUBLIC_KEY]));
388 	ret = 0;
389 	if (!peer) { /* Peer doesn't exist yet. Add a new one. */
390 		if (flags & (WGPEER_F_REMOVE_ME | WGPEER_F_UPDATE_ONLY))
391 			goto out;
392 
393 		/* The peer is new, so there aren't allowed IPs to remove. */
394 		flags &= ~WGPEER_F_REPLACE_ALLOWEDIPS;
395 
396 		down_read(&wg->static_identity.lock);
397 		if (wg->static_identity.has_identity &&
398 		    !memcmp(nla_data(attrs[WGPEER_A_PUBLIC_KEY]),
399 			    wg->static_identity.static_public,
400 			    NOISE_PUBLIC_KEY_LEN)) {
401 			/* We silently ignore peers that have the same public
402 			 * key as the device. The reason we do it silently is
403 			 * that we'd like for people to be able to reuse the
404 			 * same set of API calls across peers.
405 			 */
406 			up_read(&wg->static_identity.lock);
407 			ret = 0;
408 			goto out;
409 		}
410 		up_read(&wg->static_identity.lock);
411 
412 		peer = wg_peer_create(wg, public_key, preshared_key);
413 		if (IS_ERR(peer)) {
414 			/* Similar to the above, if the key is invalid, we skip
415 			 * it without fanfare, so that services don't need to
416 			 * worry about doing key validation themselves.
417 			 */
418 			ret = PTR_ERR(peer) == -EKEYREJECTED ? 0 : PTR_ERR(peer);
419 			peer = NULL;
420 			goto out;
421 		}
422 		/* Take additional reference, as though we've just been
423 		 * looked up.
424 		 */
425 		wg_peer_get(peer);
426 	}
427 
428 	if (flags & WGPEER_F_REMOVE_ME) {
429 		wg_peer_remove(peer);
430 		goto out;
431 	}
432 
433 	if (preshared_key) {
434 		down_write(&peer->handshake.lock);
435 		memcpy(&peer->handshake.preshared_key, preshared_key,
436 		       NOISE_SYMMETRIC_KEY_LEN);
437 		up_write(&peer->handshake.lock);
438 	}
439 
440 	if (attrs[WGPEER_A_ENDPOINT]) {
441 		struct sockaddr *addr = nla_data(attrs[WGPEER_A_ENDPOINT]);
442 		size_t len = nla_len(attrs[WGPEER_A_ENDPOINT]);
443 
444 		if ((len == sizeof(struct sockaddr_in) &&
445 		     addr->sa_family == AF_INET) ||
446 		    (len == sizeof(struct sockaddr_in6) &&
447 		     addr->sa_family == AF_INET6)) {
448 			struct endpoint endpoint = { { { 0 } } };
449 
450 			memcpy(&endpoint.addr, addr, len);
451 			wg_socket_set_peer_endpoint(peer, &endpoint);
452 		}
453 	}
454 
455 	if (flags & WGPEER_F_REPLACE_ALLOWEDIPS)
456 		wg_allowedips_remove_by_peer(&wg->peer_allowedips, peer,
457 					     &wg->device_update_lock);
458 
459 	if (attrs[WGPEER_A_ALLOWEDIPS]) {
460 		struct nlattr *attr, *allowedip[WGALLOWEDIP_A_MAX + 1];
461 		int rem;
462 
463 		nla_for_each_nested(attr, attrs[WGPEER_A_ALLOWEDIPS], rem) {
464 			ret = nla_parse_nested(allowedip, WGALLOWEDIP_A_MAX,
465 					       attr, allowedip_policy, NULL);
466 			if (ret < 0)
467 				goto out;
468 			ret = set_allowedip(peer, allowedip);
469 			if (ret < 0)
470 				goto out;
471 		}
472 	}
473 
474 	if (attrs[WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL]) {
475 		const u16 persistent_keepalive_interval = nla_get_u16(
476 				attrs[WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL]);
477 		const bool send_keepalive =
478 			!peer->persistent_keepalive_interval &&
479 			persistent_keepalive_interval &&
480 			netif_running(wg->dev);
481 
482 		peer->persistent_keepalive_interval = persistent_keepalive_interval;
483 		if (send_keepalive)
484 			wg_packet_send_keepalive(peer);
485 	}
486 
487 	if (netif_running(wg->dev))
488 		wg_packet_send_staged_packets(peer);
489 
490 out:
491 	wg_peer_put(peer);
492 	if (attrs[WGPEER_A_PRESHARED_KEY])
493 		memzero_explicit(nla_data(attrs[WGPEER_A_PRESHARED_KEY]),
494 				 nla_len(attrs[WGPEER_A_PRESHARED_KEY]));
495 	return ret;
496 }
497 
498 static int wg_set_device(struct sk_buff *skb, struct genl_info *info)
499 {
500 	struct wg_device *wg = lookup_interface(info->attrs, skb);
501 	u32 flags = 0;
502 	int ret;
503 
504 	if (IS_ERR(wg)) {
505 		ret = PTR_ERR(wg);
506 		goto out_nodev;
507 	}
508 
509 	rtnl_lock();
510 	mutex_lock(&wg->device_update_lock);
511 
512 	if (info->attrs[WGDEVICE_A_FLAGS])
513 		flags = nla_get_u32(info->attrs[WGDEVICE_A_FLAGS]);
514 	ret = -EOPNOTSUPP;
515 	if (flags & ~__WGDEVICE_F_ALL)
516 		goto out;
517 
518 	ret = -EPERM;
519 	if ((info->attrs[WGDEVICE_A_LISTEN_PORT] ||
520 	     info->attrs[WGDEVICE_A_FWMARK]) &&
521 	    !ns_capable(wg->creating_net->user_ns, CAP_NET_ADMIN))
522 		goto out;
523 
524 	++wg->device_update_gen;
525 
526 	if (info->attrs[WGDEVICE_A_FWMARK]) {
527 		struct wg_peer *peer;
528 
529 		wg->fwmark = nla_get_u32(info->attrs[WGDEVICE_A_FWMARK]);
530 		list_for_each_entry(peer, &wg->peer_list, peer_list)
531 			wg_socket_clear_peer_endpoint_src(peer);
532 	}
533 
534 	if (info->attrs[WGDEVICE_A_LISTEN_PORT]) {
535 		ret = set_port(wg,
536 			nla_get_u16(info->attrs[WGDEVICE_A_LISTEN_PORT]));
537 		if (ret)
538 			goto out;
539 	}
540 
541 	if (flags & WGDEVICE_F_REPLACE_PEERS)
542 		wg_peer_remove_all(wg);
543 
544 	if (info->attrs[WGDEVICE_A_PRIVATE_KEY] &&
545 	    nla_len(info->attrs[WGDEVICE_A_PRIVATE_KEY]) ==
546 		    NOISE_PUBLIC_KEY_LEN) {
547 		u8 *private_key = nla_data(info->attrs[WGDEVICE_A_PRIVATE_KEY]);
548 		u8 public_key[NOISE_PUBLIC_KEY_LEN];
549 		struct wg_peer *peer, *temp;
550 
551 		if (!crypto_memneq(wg->static_identity.static_private,
552 				   private_key, NOISE_PUBLIC_KEY_LEN))
553 			goto skip_set_private_key;
554 
555 		/* We remove before setting, to prevent race, which means doing
556 		 * two 25519-genpub ops.
557 		 */
558 		if (curve25519_generate_public(public_key, private_key)) {
559 			peer = wg_pubkey_hashtable_lookup(wg->peer_hashtable,
560 							  public_key);
561 			if (peer) {
562 				wg_peer_put(peer);
563 				wg_peer_remove(peer);
564 			}
565 		}
566 
567 		down_write(&wg->static_identity.lock);
568 		wg_noise_set_static_identity_private_key(&wg->static_identity,
569 							 private_key);
570 		list_for_each_entry_safe(peer, temp, &wg->peer_list,
571 					 peer_list) {
572 			BUG_ON(!wg_noise_precompute_static_static(peer));
573 			wg_noise_expire_current_peer_keypairs(peer);
574 		}
575 		wg_cookie_checker_precompute_device_keys(&wg->cookie_checker);
576 		up_write(&wg->static_identity.lock);
577 	}
578 skip_set_private_key:
579 
580 	if (info->attrs[WGDEVICE_A_PEERS]) {
581 		struct nlattr *attr, *peer[WGPEER_A_MAX + 1];
582 		int rem;
583 
584 		nla_for_each_nested(attr, info->attrs[WGDEVICE_A_PEERS], rem) {
585 			ret = nla_parse_nested(peer, WGPEER_A_MAX, attr,
586 					       peer_policy, NULL);
587 			if (ret < 0)
588 				goto out;
589 			ret = set_peer(wg, peer);
590 			if (ret < 0)
591 				goto out;
592 		}
593 	}
594 	ret = 0;
595 
596 out:
597 	mutex_unlock(&wg->device_update_lock);
598 	rtnl_unlock();
599 	dev_put(wg->dev);
600 out_nodev:
601 	if (info->attrs[WGDEVICE_A_PRIVATE_KEY])
602 		memzero_explicit(nla_data(info->attrs[WGDEVICE_A_PRIVATE_KEY]),
603 				 nla_len(info->attrs[WGDEVICE_A_PRIVATE_KEY]));
604 	return ret;
605 }
606 
607 static const struct genl_ops genl_ops[] = {
608 	{
609 		.cmd = WG_CMD_GET_DEVICE,
610 		.start = wg_get_device_start,
611 		.dumpit = wg_get_device_dump,
612 		.done = wg_get_device_done,
613 		.flags = GENL_UNS_ADMIN_PERM
614 	}, {
615 		.cmd = WG_CMD_SET_DEVICE,
616 		.doit = wg_set_device,
617 		.flags = GENL_UNS_ADMIN_PERM
618 	}
619 };
620 
621 static struct genl_family genl_family __ro_after_init = {
622 	.ops = genl_ops,
623 	.n_ops = ARRAY_SIZE(genl_ops),
624 	.name = WG_GENL_NAME,
625 	.version = WG_GENL_VERSION,
626 	.maxattr = WGDEVICE_A_MAX,
627 	.module = THIS_MODULE,
628 	.policy = device_policy,
629 	.netnsok = true
630 };
631 
632 int __init wg_genetlink_init(void)
633 {
634 	return genl_register_family(&genl_family);
635 }
636 
/* Unregister the generic netlink family registered by wg_genetlink_init(). */
void __exit wg_genetlink_uninit(void)
{
	genl_unregister_family(&genl_family);
}
641