xref: /openbmc/linux/drivers/net/wireguard/noise.c (revision 8622a0e5)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4  */
5 
6 #include "noise.h"
7 #include "device.h"
8 #include "peer.h"
9 #include "messages.h"
10 #include "queueing.h"
11 #include "peerlookup.h"
12 
13 #include <linux/rcupdate.h>
14 #include <linux/slab.h>
15 #include <linux/bitmap.h>
16 #include <linux/scatterlist.h>
17 #include <linux/highmem.h>
18 #include <crypto/algapi.h>
19 
20 /* This implements Noise_IKpsk2:
21  *
22  * <- s
23  * ******
24  * -> e, es, s, ss, {t}
25  * <- e, ee, se, psk, {}
26  */
27 
/* Protocol name hashed by wg_noise_init() to obtain the initial chaining key.
 * The array is sized to exactly the 37 characters of the string, so it holds
 * no terminating NUL and sizeof(handshake_name) is the string length.
 */
static const u8 handshake_name[37] = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s";
/* Identifier mixed into the initial hash by wg_noise_init(); likewise sized to
 * its 34 characters with no terminating NUL.
 */
static const u8 identifier_name[34] = "WireGuard v1 zx2c4 Jason@zx2c4.com";
/* Starting hash and chaining key for every handshake, filled in once by
 * wg_noise_init() and marked __ro_after_init.
 */
static u8 handshake_init_hash[NOISE_HASH_LEN] __ro_after_init;
static u8 handshake_init_chaining_key[NOISE_HASH_LEN] __ro_after_init;
/* Source of keypair->internal_id values (incremented in keypair_create()). */
static atomic64_t keypair_counter = ATOMIC64_INIT(0);
33 
34 void __init wg_noise_init(void)
35 {
36 	struct blake2s_state blake;
37 
38 	blake2s(handshake_init_chaining_key, handshake_name, NULL,
39 		NOISE_HASH_LEN, sizeof(handshake_name), 0);
40 	blake2s_init(&blake, NOISE_HASH_LEN);
41 	blake2s_update(&blake, handshake_init_chaining_key, NOISE_HASH_LEN);
42 	blake2s_update(&blake, identifier_name, sizeof(identifier_name));
43 	blake2s_final(&blake, handshake_init_hash);
44 }
45 
46 /* Must hold peer->handshake.static_identity->lock */
47 bool wg_noise_precompute_static_static(struct wg_peer *peer)
48 {
49 	bool ret;
50 
51 	down_write(&peer->handshake.lock);
52 	if (peer->handshake.static_identity->has_identity) {
53 		ret = curve25519(
54 			peer->handshake.precomputed_static_static,
55 			peer->handshake.static_identity->static_private,
56 			peer->handshake.remote_static);
57 	} else {
58 		u8 empty[NOISE_PUBLIC_KEY_LEN] = { 0 };
59 
60 		ret = curve25519(empty, empty, peer->handshake.remote_static);
61 		memset(peer->handshake.precomputed_static_static, 0,
62 		       NOISE_PUBLIC_KEY_LEN);
63 	}
64 	up_write(&peer->handshake.lock);
65 	return ret;
66 }
67 
68 bool wg_noise_handshake_init(struct noise_handshake *handshake,
69 			   struct noise_static_identity *static_identity,
70 			   const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN],
71 			   const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN],
72 			   struct wg_peer *peer)
73 {
74 	memset(handshake, 0, sizeof(*handshake));
75 	init_rwsem(&handshake->lock);
76 	handshake->entry.type = INDEX_HASHTABLE_HANDSHAKE;
77 	handshake->entry.peer = peer;
78 	memcpy(handshake->remote_static, peer_public_key, NOISE_PUBLIC_KEY_LEN);
79 	if (peer_preshared_key)
80 		memcpy(handshake->preshared_key, peer_preshared_key,
81 		       NOISE_SYMMETRIC_KEY_LEN);
82 	handshake->static_identity = static_identity;
83 	handshake->state = HANDSHAKE_ZEROED;
84 	return wg_noise_precompute_static_static(peer);
85 }
86 
87 static void handshake_zero(struct noise_handshake *handshake)
88 {
89 	memset(&handshake->ephemeral_private, 0, NOISE_PUBLIC_KEY_LEN);
90 	memset(&handshake->remote_ephemeral, 0, NOISE_PUBLIC_KEY_LEN);
91 	memset(&handshake->hash, 0, NOISE_HASH_LEN);
92 	memset(&handshake->chaining_key, 0, NOISE_HASH_LEN);
93 	handshake->remote_index = 0;
94 	handshake->state = HANDSHAKE_ZEROED;
95 }
96 
97 void wg_noise_handshake_clear(struct noise_handshake *handshake)
98 {
99 	wg_index_hashtable_remove(
100 			handshake->entry.peer->device->index_hashtable,
101 			&handshake->entry);
102 	down_write(&handshake->lock);
103 	handshake_zero(handshake);
104 	up_write(&handshake->lock);
105 	wg_index_hashtable_remove(
106 			handshake->entry.peer->device->index_hashtable,
107 			&handshake->entry);
108 }
109 
110 static struct noise_keypair *keypair_create(struct wg_peer *peer)
111 {
112 	struct noise_keypair *keypair = kzalloc(sizeof(*keypair), GFP_KERNEL);
113 
114 	if (unlikely(!keypair))
115 		return NULL;
116 	keypair->internal_id = atomic64_inc_return(&keypair_counter);
117 	keypair->entry.type = INDEX_HASHTABLE_KEYPAIR;
118 	keypair->entry.peer = peer;
119 	kref_init(&keypair->refcount);
120 	return keypair;
121 }
122 
123 static void keypair_free_rcu(struct rcu_head *rcu)
124 {
125 	kzfree(container_of(rcu, struct noise_keypair, rcu));
126 }
127 
128 static void keypair_free_kref(struct kref *kref)
129 {
130 	struct noise_keypair *keypair =
131 		container_of(kref, struct noise_keypair, refcount);
132 
133 	net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n",
134 			    keypair->entry.peer->device->dev->name,
135 			    keypair->internal_id,
136 			    keypair->entry.peer->internal_id);
137 	wg_index_hashtable_remove(keypair->entry.peer->device->index_hashtable,
138 				  &keypair->entry);
139 	call_rcu(&keypair->rcu, keypair_free_rcu);
140 }
141 
142 void wg_noise_keypair_put(struct noise_keypair *keypair, bool unreference_now)
143 {
144 	if (unlikely(!keypair))
145 		return;
146 	if (unlikely(unreference_now))
147 		wg_index_hashtable_remove(
148 			keypair->entry.peer->device->index_hashtable,
149 			&keypair->entry);
150 	kref_put(&keypair->refcount, keypair_free_kref);
151 }
152 
153 struct noise_keypair *wg_noise_keypair_get(struct noise_keypair *keypair)
154 {
155 	RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(),
156 		"Taking noise keypair reference without holding the RCU BH read lock");
157 	if (unlikely(!keypair || !kref_get_unless_zero(&keypair->refcount)))
158 		return NULL;
159 	return keypair;
160 }
161 
162 void wg_noise_keypairs_clear(struct noise_keypairs *keypairs)
163 {
164 	struct noise_keypair *old;
165 
166 	spin_lock_bh(&keypairs->keypair_update_lock);
167 
168 	/* We zero the next_keypair before zeroing the others, so that
169 	 * wg_noise_received_with_keypair returns early before subsequent ones
170 	 * are zeroed.
171 	 */
172 	old = rcu_dereference_protected(keypairs->next_keypair,
173 		lockdep_is_held(&keypairs->keypair_update_lock));
174 	RCU_INIT_POINTER(keypairs->next_keypair, NULL);
175 	wg_noise_keypair_put(old, true);
176 
177 	old = rcu_dereference_protected(keypairs->previous_keypair,
178 		lockdep_is_held(&keypairs->keypair_update_lock));
179 	RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
180 	wg_noise_keypair_put(old, true);
181 
182 	old = rcu_dereference_protected(keypairs->current_keypair,
183 		lockdep_is_held(&keypairs->keypair_update_lock));
184 	RCU_INIT_POINTER(keypairs->current_keypair, NULL);
185 	wg_noise_keypair_put(old, true);
186 
187 	spin_unlock_bh(&keypairs->keypair_update_lock);
188 }
189 
190 void wg_noise_expire_current_peer_keypairs(struct wg_peer *peer)
191 {
192 	struct noise_keypair *keypair;
193 
194 	wg_noise_handshake_clear(&peer->handshake);
195 	wg_noise_reset_last_sent_handshake(&peer->last_sent_handshake);
196 
197 	spin_lock_bh(&peer->keypairs.keypair_update_lock);
198 	keypair = rcu_dereference_protected(peer->keypairs.next_keypair,
199 			lockdep_is_held(&peer->keypairs.keypair_update_lock));
200 	if (keypair)
201 		keypair->sending.is_valid = false;
202 	keypair = rcu_dereference_protected(peer->keypairs.current_keypair,
203 			lockdep_is_held(&peer->keypairs.keypair_update_lock));
204 	if (keypair)
205 		keypair->sending.is_valid = false;
206 	spin_unlock_bh(&peer->keypairs.keypair_update_lock);
207 }
208 
209 static void add_new_keypair(struct noise_keypairs *keypairs,
210 			    struct noise_keypair *new_keypair)
211 {
212 	struct noise_keypair *previous_keypair, *next_keypair, *current_keypair;
213 
214 	spin_lock_bh(&keypairs->keypair_update_lock);
215 	previous_keypair = rcu_dereference_protected(keypairs->previous_keypair,
216 		lockdep_is_held(&keypairs->keypair_update_lock));
217 	next_keypair = rcu_dereference_protected(keypairs->next_keypair,
218 		lockdep_is_held(&keypairs->keypair_update_lock));
219 	current_keypair = rcu_dereference_protected(keypairs->current_keypair,
220 		lockdep_is_held(&keypairs->keypair_update_lock));
221 	if (new_keypair->i_am_the_initiator) {
222 		/* If we're the initiator, it means we've sent a handshake, and
223 		 * received a confirmation response, which means this new
224 		 * keypair can now be used.
225 		 */
226 		if (next_keypair) {
227 			/* If there already was a next keypair pending, we
228 			 * demote it to be the previous keypair, and free the
229 			 * existing current. Note that this means KCI can result
230 			 * in this transition. It would perhaps be more sound to
231 			 * always just get rid of the unused next keypair
232 			 * instead of putting it in the previous slot, but this
233 			 * might be a bit less robust. Something to think about
234 			 * for the future.
235 			 */
236 			RCU_INIT_POINTER(keypairs->next_keypair, NULL);
237 			rcu_assign_pointer(keypairs->previous_keypair,
238 					   next_keypair);
239 			wg_noise_keypair_put(current_keypair, true);
240 		} else /* If there wasn't an existing next keypair, we replace
241 			* the previous with the current one.
242 			*/
243 			rcu_assign_pointer(keypairs->previous_keypair,
244 					   current_keypair);
245 		/* At this point we can get rid of the old previous keypair, and
246 		 * set up the new keypair.
247 		 */
248 		wg_noise_keypair_put(previous_keypair, true);
249 		rcu_assign_pointer(keypairs->current_keypair, new_keypair);
250 	} else {
251 		/* If we're the responder, it means we can't use the new keypair
252 		 * until we receive confirmation via the first data packet, so
253 		 * we get rid of the existing previous one, the possibly
254 		 * existing next one, and slide in the new next one.
255 		 */
256 		rcu_assign_pointer(keypairs->next_keypair, new_keypair);
257 		wg_noise_keypair_put(next_keypair, true);
258 		RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
259 		wg_noise_keypair_put(previous_keypair, true);
260 	}
261 	spin_unlock_bh(&keypairs->keypair_update_lock);
262 }
263 
264 bool wg_noise_received_with_keypair(struct noise_keypairs *keypairs,
265 				    struct noise_keypair *received_keypair)
266 {
267 	struct noise_keypair *old_keypair;
268 	bool key_is_new;
269 
270 	/* We first check without taking the spinlock. */
271 	key_is_new = received_keypair ==
272 		     rcu_access_pointer(keypairs->next_keypair);
273 	if (likely(!key_is_new))
274 		return false;
275 
276 	spin_lock_bh(&keypairs->keypair_update_lock);
277 	/* After locking, we double check that things didn't change from
278 	 * beneath us.
279 	 */
280 	if (unlikely(received_keypair !=
281 		    rcu_dereference_protected(keypairs->next_keypair,
282 			    lockdep_is_held(&keypairs->keypair_update_lock)))) {
283 		spin_unlock_bh(&keypairs->keypair_update_lock);
284 		return false;
285 	}
286 
287 	/* When we've finally received the confirmation, we slide the next
288 	 * into the current, the current into the previous, and get rid of
289 	 * the old previous.
290 	 */
291 	old_keypair = rcu_dereference_protected(keypairs->previous_keypair,
292 		lockdep_is_held(&keypairs->keypair_update_lock));
293 	rcu_assign_pointer(keypairs->previous_keypair,
294 		rcu_dereference_protected(keypairs->current_keypair,
295 			lockdep_is_held(&keypairs->keypair_update_lock)));
296 	wg_noise_keypair_put(old_keypair, true);
297 	rcu_assign_pointer(keypairs->current_keypair, received_keypair);
298 	RCU_INIT_POINTER(keypairs->next_keypair, NULL);
299 
300 	spin_unlock_bh(&keypairs->keypair_update_lock);
301 	return true;
302 }
303 
304 /* Must hold static_identity->lock */
305 void wg_noise_set_static_identity_private_key(
306 	struct noise_static_identity *static_identity,
307 	const u8 private_key[NOISE_PUBLIC_KEY_LEN])
308 {
309 	memcpy(static_identity->static_private, private_key,
310 	       NOISE_PUBLIC_KEY_LEN);
311 	curve25519_clamp_secret(static_identity->static_private);
312 	static_identity->has_identity = curve25519_generate_public(
313 		static_identity->static_public, private_key);
314 }
315 
316 /* This is Hugo Krawczyk's HKDF:
317  *  - https://eprint.iacr.org/2010/264.pdf
318  *  - https://tools.ietf.org/html/rfc5869
319  */
320 static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data,
321 		size_t first_len, size_t second_len, size_t third_len,
322 		size_t data_len, const u8 chaining_key[NOISE_HASH_LEN])
323 {
324 	u8 output[BLAKE2S_HASH_SIZE + 1];
325 	u8 secret[BLAKE2S_HASH_SIZE];
326 
327 	WARN_ON(IS_ENABLED(DEBUG) &&
328 		(first_len > BLAKE2S_HASH_SIZE ||
329 		 second_len > BLAKE2S_HASH_SIZE ||
330 		 third_len > BLAKE2S_HASH_SIZE ||
331 		 ((second_len || second_dst || third_len || third_dst) &&
332 		  (!first_len || !first_dst)) ||
333 		 ((third_len || third_dst) && (!second_len || !second_dst))));
334 
335 	/* Extract entropy from data into secret */
336 	blake2s256_hmac(secret, data, chaining_key, data_len, NOISE_HASH_LEN);
337 
338 	if (!first_dst || !first_len)
339 		goto out;
340 
341 	/* Expand first key: key = secret, data = 0x1 */
342 	output[0] = 1;
343 	blake2s256_hmac(output, output, secret, 1, BLAKE2S_HASH_SIZE);
344 	memcpy(first_dst, output, first_len);
345 
346 	if (!second_dst || !second_len)
347 		goto out;
348 
349 	/* Expand second key: key = secret, data = first-key || 0x2 */
350 	output[BLAKE2S_HASH_SIZE] = 2;
351 	blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1,
352 			BLAKE2S_HASH_SIZE);
353 	memcpy(second_dst, output, second_len);
354 
355 	if (!third_dst || !third_len)
356 		goto out;
357 
358 	/* Expand third key: key = secret, data = second-key || 0x3 */
359 	output[BLAKE2S_HASH_SIZE] = 3;
360 	blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1,
361 			BLAKE2S_HASH_SIZE);
362 	memcpy(third_dst, output, third_len);
363 
364 out:
365 	/* Clear sensitive data from stack */
366 	memzero_explicit(secret, BLAKE2S_HASH_SIZE);
367 	memzero_explicit(output, BLAKE2S_HASH_SIZE + 1);
368 }
369 
370 static void symmetric_key_init(struct noise_symmetric_key *key)
371 {
372 	spin_lock_init(&key->counter.receive.lock);
373 	atomic64_set(&key->counter.counter, 0);
374 	memset(key->counter.receive.backtrack, 0,
375 	       sizeof(key->counter.receive.backtrack));
376 	key->birthdate = ktime_get_coarse_boottime_ns();
377 	key->is_valid = true;
378 }
379 
380 static void derive_keys(struct noise_symmetric_key *first_dst,
381 			struct noise_symmetric_key *second_dst,
382 			const u8 chaining_key[NOISE_HASH_LEN])
383 {
384 	kdf(first_dst->key, second_dst->key, NULL, NULL,
385 	    NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
386 	    chaining_key);
387 	symmetric_key_init(first_dst);
388 	symmetric_key_init(second_dst);
389 }
390 
391 static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],
392 				u8 key[NOISE_SYMMETRIC_KEY_LEN],
393 				const u8 private[NOISE_PUBLIC_KEY_LEN],
394 				const u8 public[NOISE_PUBLIC_KEY_LEN])
395 {
396 	u8 dh_calculation[NOISE_PUBLIC_KEY_LEN];
397 
398 	if (unlikely(!curve25519(dh_calculation, private, public)))
399 		return false;
400 	kdf(chaining_key, key, NULL, dh_calculation, NOISE_HASH_LEN,
401 	    NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key);
402 	memzero_explicit(dh_calculation, NOISE_PUBLIC_KEY_LEN);
403 	return true;
404 }
405 
406 static void mix_hash(u8 hash[NOISE_HASH_LEN], const u8 *src, size_t src_len)
407 {
408 	struct blake2s_state blake;
409 
410 	blake2s_init(&blake, NOISE_HASH_LEN);
411 	blake2s_update(&blake, hash, NOISE_HASH_LEN);
412 	blake2s_update(&blake, src, src_len);
413 	blake2s_final(&blake, hash);
414 }
415 
416 static void mix_psk(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN],
417 		    u8 key[NOISE_SYMMETRIC_KEY_LEN],
418 		    const u8 psk[NOISE_SYMMETRIC_KEY_LEN])
419 {
420 	u8 temp_hash[NOISE_HASH_LEN];
421 
422 	kdf(chaining_key, temp_hash, key, psk, NOISE_HASH_LEN, NOISE_HASH_LEN,
423 	    NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, chaining_key);
424 	mix_hash(hash, temp_hash, NOISE_HASH_LEN);
425 	memzero_explicit(temp_hash, NOISE_HASH_LEN);
426 }
427 
428 static void handshake_init(u8 chaining_key[NOISE_HASH_LEN],
429 			   u8 hash[NOISE_HASH_LEN],
430 			   const u8 remote_static[NOISE_PUBLIC_KEY_LEN])
431 {
432 	memcpy(hash, handshake_init_hash, NOISE_HASH_LEN);
433 	memcpy(chaining_key, handshake_init_chaining_key, NOISE_HASH_LEN);
434 	mix_hash(hash, remote_static, NOISE_PUBLIC_KEY_LEN);
435 }
436 
437 static void message_encrypt(u8 *dst_ciphertext, const u8 *src_plaintext,
438 			    size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
439 			    u8 hash[NOISE_HASH_LEN])
440 {
441 	chacha20poly1305_encrypt(dst_ciphertext, src_plaintext, src_len, hash,
442 				 NOISE_HASH_LEN,
443 				 0 /* Always zero for Noise_IK */, key);
444 	mix_hash(hash, dst_ciphertext, noise_encrypted_len(src_len));
445 }
446 
447 static bool message_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext,
448 			    size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
449 			    u8 hash[NOISE_HASH_LEN])
450 {
451 	if (!chacha20poly1305_decrypt(dst_plaintext, src_ciphertext, src_len,
452 				      hash, NOISE_HASH_LEN,
453 				      0 /* Always zero for Noise_IK */, key))
454 		return false;
455 	mix_hash(hash, src_ciphertext, src_len);
456 	return true;
457 }
458 
459 static void message_ephemeral(u8 ephemeral_dst[NOISE_PUBLIC_KEY_LEN],
460 			      const u8 ephemeral_src[NOISE_PUBLIC_KEY_LEN],
461 			      u8 chaining_key[NOISE_HASH_LEN],
462 			      u8 hash[NOISE_HASH_LEN])
463 {
464 	if (ephemeral_dst != ephemeral_src)
465 		memcpy(ephemeral_dst, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
466 	mix_hash(hash, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
467 	kdf(chaining_key, NULL, NULL, ephemeral_src, NOISE_HASH_LEN, 0, 0,
468 	    NOISE_PUBLIC_KEY_LEN, chaining_key);
469 }
470 
471 static void tai64n_now(u8 output[NOISE_TIMESTAMP_LEN])
472 {
473 	struct timespec64 now;
474 
475 	ktime_get_real_ts64(&now);
476 
477 	/* In order to prevent some sort of infoleak from precise timers, we
478 	 * round down the nanoseconds part to the closest rounded-down power of
479 	 * two to the maximum initiations per second allowed anyway by the
480 	 * implementation.
481 	 */
482 	now.tv_nsec = ALIGN_DOWN(now.tv_nsec,
483 		rounddown_pow_of_two(NSEC_PER_SEC / INITIATIONS_PER_SECOND));
484 
485 	/* https://cr.yp.to/libtai/tai64.html */
486 	*(__be64 *)output = cpu_to_be64(0x400000000000000aULL + now.tv_sec);
487 	*(__be32 *)(output + sizeof(__be64)) = cpu_to_be32(now.tv_nsec);
488 }
489 
490 bool
491 wg_noise_handshake_create_initiation(struct message_handshake_initiation *dst,
492 				     struct noise_handshake *handshake)
493 {
494 	u8 timestamp[NOISE_TIMESTAMP_LEN];
495 	u8 key[NOISE_SYMMETRIC_KEY_LEN];
496 	bool ret = false;
497 
498 	/* We need to wait for crng _before_ taking any locks, since
499 	 * curve25519_generate_secret uses get_random_bytes_wait.
500 	 */
501 	wait_for_random_bytes();
502 
503 	down_read(&handshake->static_identity->lock);
504 	down_write(&handshake->lock);
505 
506 	if (unlikely(!handshake->static_identity->has_identity))
507 		goto out;
508 
509 	dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION);
510 
511 	handshake_init(handshake->chaining_key, handshake->hash,
512 		       handshake->remote_static);
513 
514 	/* e */
515 	curve25519_generate_secret(handshake->ephemeral_private);
516 	if (!curve25519_generate_public(dst->unencrypted_ephemeral,
517 					handshake->ephemeral_private))
518 		goto out;
519 	message_ephemeral(dst->unencrypted_ephemeral,
520 			  dst->unencrypted_ephemeral, handshake->chaining_key,
521 			  handshake->hash);
522 
523 	/* es */
524 	if (!mix_dh(handshake->chaining_key, key, handshake->ephemeral_private,
525 		    handshake->remote_static))
526 		goto out;
527 
528 	/* s */
529 	message_encrypt(dst->encrypted_static,
530 			handshake->static_identity->static_public,
531 			NOISE_PUBLIC_KEY_LEN, key, handshake->hash);
532 
533 	/* ss */
534 	kdf(handshake->chaining_key, key, NULL,
535 	    handshake->precomputed_static_static, NOISE_HASH_LEN,
536 	    NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
537 	    handshake->chaining_key);
538 
539 	/* {t} */
540 	tai64n_now(timestamp);
541 	message_encrypt(dst->encrypted_timestamp, timestamp,
542 			NOISE_TIMESTAMP_LEN, key, handshake->hash);
543 
544 	dst->sender_index = wg_index_hashtable_insert(
545 		handshake->entry.peer->device->index_hashtable,
546 		&handshake->entry);
547 
548 	handshake->state = HANDSHAKE_CREATED_INITIATION;
549 	ret = true;
550 
551 out:
552 	up_write(&handshake->lock);
553 	up_read(&handshake->static_identity->lock);
554 	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
555 	return ret;
556 }
557 
558 struct wg_peer *
559 wg_noise_handshake_consume_initiation(struct message_handshake_initiation *src,
560 				      struct wg_device *wg)
561 {
562 	struct wg_peer *peer = NULL, *ret_peer = NULL;
563 	struct noise_handshake *handshake;
564 	bool replay_attack, flood_attack;
565 	u8 key[NOISE_SYMMETRIC_KEY_LEN];
566 	u8 chaining_key[NOISE_HASH_LEN];
567 	u8 hash[NOISE_HASH_LEN];
568 	u8 s[NOISE_PUBLIC_KEY_LEN];
569 	u8 e[NOISE_PUBLIC_KEY_LEN];
570 	u8 t[NOISE_TIMESTAMP_LEN];
571 	u64 initiation_consumption;
572 
573 	down_read(&wg->static_identity.lock);
574 	if (unlikely(!wg->static_identity.has_identity))
575 		goto out;
576 
577 	handshake_init(chaining_key, hash, wg->static_identity.static_public);
578 
579 	/* e */
580 	message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);
581 
582 	/* es */
583 	if (!mix_dh(chaining_key, key, wg->static_identity.static_private, e))
584 		goto out;
585 
586 	/* s */
587 	if (!message_decrypt(s, src->encrypted_static,
588 			     sizeof(src->encrypted_static), key, hash))
589 		goto out;
590 
591 	/* Lookup which peer we're actually talking to */
592 	peer = wg_pubkey_hashtable_lookup(wg->peer_hashtable, s);
593 	if (!peer)
594 		goto out;
595 	handshake = &peer->handshake;
596 
597 	/* ss */
598 	kdf(chaining_key, key, NULL, handshake->precomputed_static_static,
599 	    NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
600 	    chaining_key);
601 
602 	/* {t} */
603 	if (!message_decrypt(t, src->encrypted_timestamp,
604 			     sizeof(src->encrypted_timestamp), key, hash))
605 		goto out;
606 
607 	down_read(&handshake->lock);
608 	replay_attack = memcmp(t, handshake->latest_timestamp,
609 			       NOISE_TIMESTAMP_LEN) <= 0;
610 	flood_attack = (s64)handshake->last_initiation_consumption +
611 			       NSEC_PER_SEC / INITIATIONS_PER_SECOND >
612 		       (s64)ktime_get_coarse_boottime_ns();
613 	up_read(&handshake->lock);
614 	if (replay_attack || flood_attack)
615 		goto out;
616 
617 	/* Success! Copy everything to peer */
618 	down_write(&handshake->lock);
619 	memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
620 	if (memcmp(t, handshake->latest_timestamp, NOISE_TIMESTAMP_LEN) > 0)
621 		memcpy(handshake->latest_timestamp, t, NOISE_TIMESTAMP_LEN);
622 	memcpy(handshake->hash, hash, NOISE_HASH_LEN);
623 	memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
624 	handshake->remote_index = src->sender_index;
625 	if ((s64)(handshake->last_initiation_consumption -
626 	    (initiation_consumption = ktime_get_coarse_boottime_ns())) < 0)
627 		handshake->last_initiation_consumption = initiation_consumption;
628 	handshake->state = HANDSHAKE_CONSUMED_INITIATION;
629 	up_write(&handshake->lock);
630 	ret_peer = peer;
631 
632 out:
633 	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
634 	memzero_explicit(hash, NOISE_HASH_LEN);
635 	memzero_explicit(chaining_key, NOISE_HASH_LEN);
636 	up_read(&wg->static_identity.lock);
637 	if (!ret_peer)
638 		wg_peer_put(peer);
639 	return ret_peer;
640 }
641 
642 bool wg_noise_handshake_create_response(struct message_handshake_response *dst,
643 					struct noise_handshake *handshake)
644 {
645 	u8 key[NOISE_SYMMETRIC_KEY_LEN];
646 	bool ret = false;
647 
648 	/* We need to wait for crng _before_ taking any locks, since
649 	 * curve25519_generate_secret uses get_random_bytes_wait.
650 	 */
651 	wait_for_random_bytes();
652 
653 	down_read(&handshake->static_identity->lock);
654 	down_write(&handshake->lock);
655 
656 	if (handshake->state != HANDSHAKE_CONSUMED_INITIATION)
657 		goto out;
658 
659 	dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE);
660 	dst->receiver_index = handshake->remote_index;
661 
662 	/* e */
663 	curve25519_generate_secret(handshake->ephemeral_private);
664 	if (!curve25519_generate_public(dst->unencrypted_ephemeral,
665 					handshake->ephemeral_private))
666 		goto out;
667 	message_ephemeral(dst->unencrypted_ephemeral,
668 			  dst->unencrypted_ephemeral, handshake->chaining_key,
669 			  handshake->hash);
670 
671 	/* ee */
672 	if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
673 		    handshake->remote_ephemeral))
674 		goto out;
675 
676 	/* se */
677 	if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
678 		    handshake->remote_static))
679 		goto out;
680 
681 	/* psk */
682 	mix_psk(handshake->chaining_key, handshake->hash, key,
683 		handshake->preshared_key);
684 
685 	/* {} */
686 	message_encrypt(dst->encrypted_nothing, NULL, 0, key, handshake->hash);
687 
688 	dst->sender_index = wg_index_hashtable_insert(
689 		handshake->entry.peer->device->index_hashtable,
690 		&handshake->entry);
691 
692 	handshake->state = HANDSHAKE_CREATED_RESPONSE;
693 	ret = true;
694 
695 out:
696 	up_write(&handshake->lock);
697 	up_read(&handshake->static_identity->lock);
698 	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
699 	return ret;
700 }
701 
702 struct wg_peer *
703 wg_noise_handshake_consume_response(struct message_handshake_response *src,
704 				    struct wg_device *wg)
705 {
706 	enum noise_handshake_state state = HANDSHAKE_ZEROED;
707 	struct wg_peer *peer = NULL, *ret_peer = NULL;
708 	struct noise_handshake *handshake;
709 	u8 key[NOISE_SYMMETRIC_KEY_LEN];
710 	u8 hash[NOISE_HASH_LEN];
711 	u8 chaining_key[NOISE_HASH_LEN];
712 	u8 e[NOISE_PUBLIC_KEY_LEN];
713 	u8 ephemeral_private[NOISE_PUBLIC_KEY_LEN];
714 	u8 static_private[NOISE_PUBLIC_KEY_LEN];
715 
716 	down_read(&wg->static_identity.lock);
717 
718 	if (unlikely(!wg->static_identity.has_identity))
719 		goto out;
720 
721 	handshake = (struct noise_handshake *)wg_index_hashtable_lookup(
722 		wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE,
723 		src->receiver_index, &peer);
724 	if (unlikely(!handshake))
725 		goto out;
726 
727 	down_read(&handshake->lock);
728 	state = handshake->state;
729 	memcpy(hash, handshake->hash, NOISE_HASH_LEN);
730 	memcpy(chaining_key, handshake->chaining_key, NOISE_HASH_LEN);
731 	memcpy(ephemeral_private, handshake->ephemeral_private,
732 	       NOISE_PUBLIC_KEY_LEN);
733 	up_read(&handshake->lock);
734 
735 	if (state != HANDSHAKE_CREATED_INITIATION)
736 		goto fail;
737 
738 	/* e */
739 	message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);
740 
741 	/* ee */
742 	if (!mix_dh(chaining_key, NULL, ephemeral_private, e))
743 		goto fail;
744 
745 	/* se */
746 	if (!mix_dh(chaining_key, NULL, wg->static_identity.static_private, e))
747 		goto fail;
748 
749 	/* psk */
750 	mix_psk(chaining_key, hash, key, handshake->preshared_key);
751 
752 	/* {} */
753 	if (!message_decrypt(NULL, src->encrypted_nothing,
754 			     sizeof(src->encrypted_nothing), key, hash))
755 		goto fail;
756 
757 	/* Success! Copy everything to peer */
758 	down_write(&handshake->lock);
759 	/* It's important to check that the state is still the same, while we
760 	 * have an exclusive lock.
761 	 */
762 	if (handshake->state != state) {
763 		up_write(&handshake->lock);
764 		goto fail;
765 	}
766 	memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
767 	memcpy(handshake->hash, hash, NOISE_HASH_LEN);
768 	memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
769 	handshake->remote_index = src->sender_index;
770 	handshake->state = HANDSHAKE_CONSUMED_RESPONSE;
771 	up_write(&handshake->lock);
772 	ret_peer = peer;
773 	goto out;
774 
775 fail:
776 	wg_peer_put(peer);
777 out:
778 	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
779 	memzero_explicit(hash, NOISE_HASH_LEN);
780 	memzero_explicit(chaining_key, NOISE_HASH_LEN);
781 	memzero_explicit(ephemeral_private, NOISE_PUBLIC_KEY_LEN);
782 	memzero_explicit(static_private, NOISE_PUBLIC_KEY_LEN);
783 	up_read(&wg->static_identity.lock);
784 	return ret_peer;
785 }
786 
787 bool wg_noise_handshake_begin_session(struct noise_handshake *handshake,
788 				      struct noise_keypairs *keypairs)
789 {
790 	struct noise_keypair *new_keypair;
791 	bool ret = false;
792 
793 	down_write(&handshake->lock);
794 	if (handshake->state != HANDSHAKE_CREATED_RESPONSE &&
795 	    handshake->state != HANDSHAKE_CONSUMED_RESPONSE)
796 		goto out;
797 
798 	new_keypair = keypair_create(handshake->entry.peer);
799 	if (!new_keypair)
800 		goto out;
801 	new_keypair->i_am_the_initiator = handshake->state ==
802 					  HANDSHAKE_CONSUMED_RESPONSE;
803 	new_keypair->remote_index = handshake->remote_index;
804 
805 	if (new_keypair->i_am_the_initiator)
806 		derive_keys(&new_keypair->sending, &new_keypair->receiving,
807 			    handshake->chaining_key);
808 	else
809 		derive_keys(&new_keypair->receiving, &new_keypair->sending,
810 			    handshake->chaining_key);
811 
812 	handshake_zero(handshake);
813 	rcu_read_lock_bh();
814 	if (likely(!READ_ONCE(container_of(handshake, struct wg_peer,
815 					   handshake)->is_dead))) {
816 		add_new_keypair(keypairs, new_keypair);
817 		net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n",
818 				    handshake->entry.peer->device->dev->name,
819 				    new_keypair->internal_id,
820 				    handshake->entry.peer->internal_id);
821 		ret = wg_index_hashtable_replace(
822 			handshake->entry.peer->device->index_hashtable,
823 			&handshake->entry, &new_keypair->entry);
824 	} else {
825 		kzfree(new_keypair);
826 	}
827 	rcu_read_unlock_bh();
828 
829 out:
830 	up_write(&handshake->lock);
831 	return ret;
832 }
833