xref: /openbmc/linux/drivers/net/wireguard/noise.c (revision 15e3ae36)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4  */
5 
6 #include "noise.h"
7 #include "device.h"
8 #include "peer.h"
9 #include "messages.h"
10 #include "queueing.h"
11 #include "peerlookup.h"
12 
13 #include <linux/rcupdate.h>
14 #include <linux/slab.h>
15 #include <linux/bitmap.h>
16 #include <linux/scatterlist.h>
17 #include <linux/highmem.h>
18 #include <crypto/algapi.h>
19 
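/* This implements Noise_IKpsk2:
 *
 * <- s
 * ******
 * -> e, es, s, ss, {t}
 * <- e, ee, se, psk, {}
 *
 * The "<- s" pre-message means the initiator already knows the responder's
 * static public key before the handshake starts. The "******" line separates
 * that pre-message from the two handshake messages; the Noise specification
 * writes this separator as "...".
 */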
27 
/* Protocol name and identifier string that wg_noise_init() hashes into the
 * initial chaining key and hash. Each array is sized to exactly the string
 * length, so no NUL terminator is stored and sizeof() covers only the
 * characters. Changing the sizes to [] would add a trailing NUL and alter
 * the derived handshake constants.
 */
static const u8 handshake_name[37] = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s";
static const u8 identifier_name[34] = "WireGuard v1 zx2c4 Jason@zx2c4.com";
/* Filled once by wg_noise_init(), read-only afterwards. */
static u8 handshake_init_hash[NOISE_HASH_LEN] __ro_after_init;
static u8 handshake_init_chaining_key[NOISE_HASH_LEN] __ro_after_init;
/* Source of keypair->internal_id, which this file only uses in debug logs. */
static atomic64_t keypair_counter = ATOMIC64_INIT(0);
33 
34 void __init wg_noise_init(void)
35 {
36 	struct blake2s_state blake;
37 
38 	blake2s(handshake_init_chaining_key, handshake_name, NULL,
39 		NOISE_HASH_LEN, sizeof(handshake_name), 0);
40 	blake2s_init(&blake, NOISE_HASH_LEN);
41 	blake2s_update(&blake, handshake_init_chaining_key, NOISE_HASH_LEN);
42 	blake2s_update(&blake, identifier_name, sizeof(identifier_name));
43 	blake2s_final(&blake, handshake_init_hash);
44 }
45 
46 /* Must hold peer->handshake.static_identity->lock */
47 void wg_noise_precompute_static_static(struct wg_peer *peer)
48 {
49 	down_write(&peer->handshake.lock);
50 	if (!peer->handshake.static_identity->has_identity ||
51 	    !curve25519(peer->handshake.precomputed_static_static,
52 			peer->handshake.static_identity->static_private,
53 			peer->handshake.remote_static))
54 		memset(peer->handshake.precomputed_static_static, 0,
55 		       NOISE_PUBLIC_KEY_LEN);
56 	up_write(&peer->handshake.lock);
57 }
58 
59 void wg_noise_handshake_init(struct noise_handshake *handshake,
60 			     struct noise_static_identity *static_identity,
61 			     const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN],
62 			     const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN],
63 			     struct wg_peer *peer)
64 {
65 	memset(handshake, 0, sizeof(*handshake));
66 	init_rwsem(&handshake->lock);
67 	handshake->entry.type = INDEX_HASHTABLE_HANDSHAKE;
68 	handshake->entry.peer = peer;
69 	memcpy(handshake->remote_static, peer_public_key, NOISE_PUBLIC_KEY_LEN);
70 	if (peer_preshared_key)
71 		memcpy(handshake->preshared_key, peer_preshared_key,
72 		       NOISE_SYMMETRIC_KEY_LEN);
73 	handshake->static_identity = static_identity;
74 	handshake->state = HANDSHAKE_ZEROED;
75 	wg_noise_precompute_static_static(peer);
76 }
77 
78 static void handshake_zero(struct noise_handshake *handshake)
79 {
80 	memset(&handshake->ephemeral_private, 0, NOISE_PUBLIC_KEY_LEN);
81 	memset(&handshake->remote_ephemeral, 0, NOISE_PUBLIC_KEY_LEN);
82 	memset(&handshake->hash, 0, NOISE_HASH_LEN);
83 	memset(&handshake->chaining_key, 0, NOISE_HASH_LEN);
84 	handshake->remote_index = 0;
85 	handshake->state = HANDSHAKE_ZEROED;
86 }
87 
88 void wg_noise_handshake_clear(struct noise_handshake *handshake)
89 {
90 	wg_index_hashtable_remove(
91 			handshake->entry.peer->device->index_hashtable,
92 			&handshake->entry);
93 	down_write(&handshake->lock);
94 	handshake_zero(handshake);
95 	up_write(&handshake->lock);
96 	wg_index_hashtable_remove(
97 			handshake->entry.peer->device->index_hashtable,
98 			&handshake->entry);
99 }
100 
101 static struct noise_keypair *keypair_create(struct wg_peer *peer)
102 {
103 	struct noise_keypair *keypair = kzalloc(sizeof(*keypair), GFP_KERNEL);
104 
105 	if (unlikely(!keypair))
106 		return NULL;
107 	keypair->internal_id = atomic64_inc_return(&keypair_counter);
108 	keypair->entry.type = INDEX_HASHTABLE_KEYPAIR;
109 	keypair->entry.peer = peer;
110 	kref_init(&keypair->refcount);
111 	return keypair;
112 }
113 
114 static void keypair_free_rcu(struct rcu_head *rcu)
115 {
116 	kzfree(container_of(rcu, struct noise_keypair, rcu));
117 }
118 
119 static void keypair_free_kref(struct kref *kref)
120 {
121 	struct noise_keypair *keypair =
122 		container_of(kref, struct noise_keypair, refcount);
123 
124 	net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n",
125 			    keypair->entry.peer->device->dev->name,
126 			    keypair->internal_id,
127 			    keypair->entry.peer->internal_id);
128 	wg_index_hashtable_remove(keypair->entry.peer->device->index_hashtable,
129 				  &keypair->entry);
130 	call_rcu(&keypair->rcu, keypair_free_rcu);
131 }
132 
133 void wg_noise_keypair_put(struct noise_keypair *keypair, bool unreference_now)
134 {
135 	if (unlikely(!keypair))
136 		return;
137 	if (unlikely(unreference_now))
138 		wg_index_hashtable_remove(
139 			keypair->entry.peer->device->index_hashtable,
140 			&keypair->entry);
141 	kref_put(&keypair->refcount, keypair_free_kref);
142 }
143 
144 struct noise_keypair *wg_noise_keypair_get(struct noise_keypair *keypair)
145 {
146 	RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(),
147 		"Taking noise keypair reference without holding the RCU BH read lock");
148 	if (unlikely(!keypair || !kref_get_unless_zero(&keypair->refcount)))
149 		return NULL;
150 	return keypair;
151 }
152 
153 void wg_noise_keypairs_clear(struct noise_keypairs *keypairs)
154 {
155 	struct noise_keypair *old;
156 
157 	spin_lock_bh(&keypairs->keypair_update_lock);
158 
159 	/* We zero the next_keypair before zeroing the others, so that
160 	 * wg_noise_received_with_keypair returns early before subsequent ones
161 	 * are zeroed.
162 	 */
163 	old = rcu_dereference_protected(keypairs->next_keypair,
164 		lockdep_is_held(&keypairs->keypair_update_lock));
165 	RCU_INIT_POINTER(keypairs->next_keypair, NULL);
166 	wg_noise_keypair_put(old, true);
167 
168 	old = rcu_dereference_protected(keypairs->previous_keypair,
169 		lockdep_is_held(&keypairs->keypair_update_lock));
170 	RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
171 	wg_noise_keypair_put(old, true);
172 
173 	old = rcu_dereference_protected(keypairs->current_keypair,
174 		lockdep_is_held(&keypairs->keypair_update_lock));
175 	RCU_INIT_POINTER(keypairs->current_keypair, NULL);
176 	wg_noise_keypair_put(old, true);
177 
178 	spin_unlock_bh(&keypairs->keypair_update_lock);
179 }
180 
181 void wg_noise_expire_current_peer_keypairs(struct wg_peer *peer)
182 {
183 	struct noise_keypair *keypair;
184 
185 	wg_noise_handshake_clear(&peer->handshake);
186 	wg_noise_reset_last_sent_handshake(&peer->last_sent_handshake);
187 
188 	spin_lock_bh(&peer->keypairs.keypair_update_lock);
189 	keypair = rcu_dereference_protected(peer->keypairs.next_keypair,
190 			lockdep_is_held(&peer->keypairs.keypair_update_lock));
191 	if (keypair)
192 		keypair->sending.is_valid = false;
193 	keypair = rcu_dereference_protected(peer->keypairs.current_keypair,
194 			lockdep_is_held(&peer->keypairs.keypair_update_lock));
195 	if (keypair)
196 		keypair->sending.is_valid = false;
197 	spin_unlock_bh(&peer->keypairs.keypair_update_lock);
198 }
199 
200 static void add_new_keypair(struct noise_keypairs *keypairs,
201 			    struct noise_keypair *new_keypair)
202 {
203 	struct noise_keypair *previous_keypair, *next_keypair, *current_keypair;
204 
205 	spin_lock_bh(&keypairs->keypair_update_lock);
206 	previous_keypair = rcu_dereference_protected(keypairs->previous_keypair,
207 		lockdep_is_held(&keypairs->keypair_update_lock));
208 	next_keypair = rcu_dereference_protected(keypairs->next_keypair,
209 		lockdep_is_held(&keypairs->keypair_update_lock));
210 	current_keypair = rcu_dereference_protected(keypairs->current_keypair,
211 		lockdep_is_held(&keypairs->keypair_update_lock));
212 	if (new_keypair->i_am_the_initiator) {
213 		/* If we're the initiator, it means we've sent a handshake, and
214 		 * received a confirmation response, which means this new
215 		 * keypair can now be used.
216 		 */
217 		if (next_keypair) {
218 			/* If there already was a next keypair pending, we
219 			 * demote it to be the previous keypair, and free the
220 			 * existing current. Note that this means KCI can result
221 			 * in this transition. It would perhaps be more sound to
222 			 * always just get rid of the unused next keypair
223 			 * instead of putting it in the previous slot, but this
224 			 * might be a bit less robust. Something to think about
225 			 * for the future.
226 			 */
227 			RCU_INIT_POINTER(keypairs->next_keypair, NULL);
228 			rcu_assign_pointer(keypairs->previous_keypair,
229 					   next_keypair);
230 			wg_noise_keypair_put(current_keypair, true);
231 		} else /* If there wasn't an existing next keypair, we replace
232 			* the previous with the current one.
233 			*/
234 			rcu_assign_pointer(keypairs->previous_keypair,
235 					   current_keypair);
236 		/* At this point we can get rid of the old previous keypair, and
237 		 * set up the new keypair.
238 		 */
239 		wg_noise_keypair_put(previous_keypair, true);
240 		rcu_assign_pointer(keypairs->current_keypair, new_keypair);
241 	} else {
242 		/* If we're the responder, it means we can't use the new keypair
243 		 * until we receive confirmation via the first data packet, so
244 		 * we get rid of the existing previous one, the possibly
245 		 * existing next one, and slide in the new next one.
246 		 */
247 		rcu_assign_pointer(keypairs->next_keypair, new_keypair);
248 		wg_noise_keypair_put(next_keypair, true);
249 		RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
250 		wg_noise_keypair_put(previous_keypair, true);
251 	}
252 	spin_unlock_bh(&keypairs->keypair_update_lock);
253 }
254 
/* Report that data was received under @received_keypair. Presumably this
 * is called from the receive path once a data packet decrypts under it;
 * confirm at the call sites.
 *
 * If @received_keypair is the pending next keypair (which add_new_keypair()
 * installs on the responder side), this receipt confirms the handshake:
 *   - next is promoted to current;
 *   - current is demoted to previous;
 *   - the old previous keypair is released.
 *
 * Returns true if that promotion happened. Returns false if
 * @received_keypair was not (or is no longer) the next keypair.
 */
bool wg_noise_received_with_keypair(struct noise_keypairs *keypairs,
				    struct noise_keypair *received_keypair)
{
	struct noise_keypair *old_keypair;
	bool key_is_new;

	/* We first check without taking the spinlock. rcu_access_pointer()
	 * only fetches the pointer value for the comparison and does not
	 * dereference it. This skips the common case cheaply.
	 */
	key_is_new = received_keypair ==
		     rcu_access_pointer(keypairs->next_keypair);
	if (likely(!key_is_new))
		return false;

	spin_lock_bh(&keypairs->keypair_update_lock);
	/* After locking, we double check that things didn't change from
	 * beneath us.
	 */
	if (unlikely(received_keypair !=
		    rcu_dereference_protected(keypairs->next_keypair,
			    lockdep_is_held(&keypairs->keypair_update_lock)))) {
		spin_unlock_bh(&keypairs->keypair_update_lock);
		return false;
	}

	/* When we've finally received the confirmation, we slide the next
	 * into the current, the current into the previous, and get rid of
	 * the old previous.
	 */
	old_keypair = rcu_dereference_protected(keypairs->previous_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	rcu_assign_pointer(keypairs->previous_keypair,
		rcu_dereference_protected(keypairs->current_keypair,
			lockdep_is_held(&keypairs->keypair_update_lock)));
	wg_noise_keypair_put(old_keypair, true);
	rcu_assign_pointer(keypairs->current_keypair, received_keypair);
	RCU_INIT_POINTER(keypairs->next_keypair, NULL);

	spin_unlock_bh(&keypairs->keypair_update_lock);
	return true;
}
294 
295 /* Must hold static_identity->lock */
296 void wg_noise_set_static_identity_private_key(
297 	struct noise_static_identity *static_identity,
298 	const u8 private_key[NOISE_PUBLIC_KEY_LEN])
299 {
300 	memcpy(static_identity->static_private, private_key,
301 	       NOISE_PUBLIC_KEY_LEN);
302 	curve25519_clamp_secret(static_identity->static_private);
303 	static_identity->has_identity = curve25519_generate_public(
304 		static_identity->static_public, private_key);
305 }
306 
307 /* This is Hugo Krawczyk's HKDF:
308  *  - https://eprint.iacr.org/2010/264.pdf
309  *  - https://tools.ietf.org/html/rfc5869
310  */
311 static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data,
312 		size_t first_len, size_t second_len, size_t third_len,
313 		size_t data_len, const u8 chaining_key[NOISE_HASH_LEN])
314 {
315 	u8 output[BLAKE2S_HASH_SIZE + 1];
316 	u8 secret[BLAKE2S_HASH_SIZE];
317 
318 	WARN_ON(IS_ENABLED(DEBUG) &&
319 		(first_len > BLAKE2S_HASH_SIZE ||
320 		 second_len > BLAKE2S_HASH_SIZE ||
321 		 third_len > BLAKE2S_HASH_SIZE ||
322 		 ((second_len || second_dst || third_len || third_dst) &&
323 		  (!first_len || !first_dst)) ||
324 		 ((third_len || third_dst) && (!second_len || !second_dst))));
325 
326 	/* Extract entropy from data into secret */
327 	blake2s256_hmac(secret, data, chaining_key, data_len, NOISE_HASH_LEN);
328 
329 	if (!first_dst || !first_len)
330 		goto out;
331 
332 	/* Expand first key: key = secret, data = 0x1 */
333 	output[0] = 1;
334 	blake2s256_hmac(output, output, secret, 1, BLAKE2S_HASH_SIZE);
335 	memcpy(first_dst, output, first_len);
336 
337 	if (!second_dst || !second_len)
338 		goto out;
339 
340 	/* Expand second key: key = secret, data = first-key || 0x2 */
341 	output[BLAKE2S_HASH_SIZE] = 2;
342 	blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1,
343 			BLAKE2S_HASH_SIZE);
344 	memcpy(second_dst, output, second_len);
345 
346 	if (!third_dst || !third_len)
347 		goto out;
348 
349 	/* Expand third key: key = secret, data = second-key || 0x3 */
350 	output[BLAKE2S_HASH_SIZE] = 3;
351 	blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1,
352 			BLAKE2S_HASH_SIZE);
353 	memcpy(third_dst, output, third_len);
354 
355 out:
356 	/* Clear sensitive data from stack */
357 	memzero_explicit(secret, BLAKE2S_HASH_SIZE);
358 	memzero_explicit(output, BLAKE2S_HASH_SIZE + 1);
359 }
360 
361 static void symmetric_key_init(struct noise_symmetric_key *key)
362 {
363 	spin_lock_init(&key->counter.receive.lock);
364 	atomic64_set(&key->counter.counter, 0);
365 	memset(key->counter.receive.backtrack, 0,
366 	       sizeof(key->counter.receive.backtrack));
367 	key->birthdate = ktime_get_coarse_boottime_ns();
368 	key->is_valid = true;
369 }
370 
371 static void derive_keys(struct noise_symmetric_key *first_dst,
372 			struct noise_symmetric_key *second_dst,
373 			const u8 chaining_key[NOISE_HASH_LEN])
374 {
375 	kdf(first_dst->key, second_dst->key, NULL, NULL,
376 	    NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
377 	    chaining_key);
378 	symmetric_key_init(first_dst);
379 	symmetric_key_init(second_dst);
380 }
381 
382 static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],
383 				u8 key[NOISE_SYMMETRIC_KEY_LEN],
384 				const u8 private[NOISE_PUBLIC_KEY_LEN],
385 				const u8 public[NOISE_PUBLIC_KEY_LEN])
386 {
387 	u8 dh_calculation[NOISE_PUBLIC_KEY_LEN];
388 
389 	if (unlikely(!curve25519(dh_calculation, private, public)))
390 		return false;
391 	kdf(chaining_key, key, NULL, dh_calculation, NOISE_HASH_LEN,
392 	    NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key);
393 	memzero_explicit(dh_calculation, NOISE_PUBLIC_KEY_LEN);
394 	return true;
395 }
396 
397 static bool __must_check mix_precomputed_dh(u8 chaining_key[NOISE_HASH_LEN],
398 					    u8 key[NOISE_SYMMETRIC_KEY_LEN],
399 					    const u8 precomputed[NOISE_PUBLIC_KEY_LEN])
400 {
401 	static u8 zero_point[NOISE_PUBLIC_KEY_LEN];
402 	if (unlikely(!crypto_memneq(precomputed, zero_point, NOISE_PUBLIC_KEY_LEN)))
403 		return false;
404 	kdf(chaining_key, key, NULL, precomputed, NOISE_HASH_LEN,
405 	    NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
406 	    chaining_key);
407 	return true;
408 }
409 
410 static void mix_hash(u8 hash[NOISE_HASH_LEN], const u8 *src, size_t src_len)
411 {
412 	struct blake2s_state blake;
413 
414 	blake2s_init(&blake, NOISE_HASH_LEN);
415 	blake2s_update(&blake, hash, NOISE_HASH_LEN);
416 	blake2s_update(&blake, src, src_len);
417 	blake2s_final(&blake, hash);
418 }
419 
420 static void mix_psk(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN],
421 		    u8 key[NOISE_SYMMETRIC_KEY_LEN],
422 		    const u8 psk[NOISE_SYMMETRIC_KEY_LEN])
423 {
424 	u8 temp_hash[NOISE_HASH_LEN];
425 
426 	kdf(chaining_key, temp_hash, key, psk, NOISE_HASH_LEN, NOISE_HASH_LEN,
427 	    NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, chaining_key);
428 	mix_hash(hash, temp_hash, NOISE_HASH_LEN);
429 	memzero_explicit(temp_hash, NOISE_HASH_LEN);
430 }
431 
432 static void handshake_init(u8 chaining_key[NOISE_HASH_LEN],
433 			   u8 hash[NOISE_HASH_LEN],
434 			   const u8 remote_static[NOISE_PUBLIC_KEY_LEN])
435 {
436 	memcpy(hash, handshake_init_hash, NOISE_HASH_LEN);
437 	memcpy(chaining_key, handshake_init_chaining_key, NOISE_HASH_LEN);
438 	mix_hash(hash, remote_static, NOISE_PUBLIC_KEY_LEN);
439 }
440 
441 static void message_encrypt(u8 *dst_ciphertext, const u8 *src_plaintext,
442 			    size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
443 			    u8 hash[NOISE_HASH_LEN])
444 {
445 	chacha20poly1305_encrypt(dst_ciphertext, src_plaintext, src_len, hash,
446 				 NOISE_HASH_LEN,
447 				 0 /* Always zero for Noise_IK */, key);
448 	mix_hash(hash, dst_ciphertext, noise_encrypted_len(src_len));
449 }
450 
451 static bool message_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext,
452 			    size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
453 			    u8 hash[NOISE_HASH_LEN])
454 {
455 	if (!chacha20poly1305_decrypt(dst_plaintext, src_ciphertext, src_len,
456 				      hash, NOISE_HASH_LEN,
457 				      0 /* Always zero for Noise_IK */, key))
458 		return false;
459 	mix_hash(hash, src_ciphertext, src_len);
460 	return true;
461 }
462 
463 static void message_ephemeral(u8 ephemeral_dst[NOISE_PUBLIC_KEY_LEN],
464 			      const u8 ephemeral_src[NOISE_PUBLIC_KEY_LEN],
465 			      u8 chaining_key[NOISE_HASH_LEN],
466 			      u8 hash[NOISE_HASH_LEN])
467 {
468 	if (ephemeral_dst != ephemeral_src)
469 		memcpy(ephemeral_dst, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
470 	mix_hash(hash, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
471 	kdf(chaining_key, NULL, NULL, ephemeral_src, NOISE_HASH_LEN, 0, 0,
472 	    NOISE_PUBLIC_KEY_LEN, chaining_key);
473 }
474 
475 static void tai64n_now(u8 output[NOISE_TIMESTAMP_LEN])
476 {
477 	struct timespec64 now;
478 
479 	ktime_get_real_ts64(&now);
480 
481 	/* In order to prevent some sort of infoleak from precise timers, we
482 	 * round down the nanoseconds part to the closest rounded-down power of
483 	 * two to the maximum initiations per second allowed anyway by the
484 	 * implementation.
485 	 */
486 	now.tv_nsec = ALIGN_DOWN(now.tv_nsec,
487 		rounddown_pow_of_two(NSEC_PER_SEC / INITIATIONS_PER_SECOND));
488 
489 	/* https://cr.yp.to/libtai/tai64.html */
490 	*(__be64 *)output = cpu_to_be64(0x400000000000000aULL + now.tv_sec);
491 	*(__be32 *)(output + sizeof(__be64)) = cpu_to_be32(now.tv_nsec);
492 }
493 
/* Build a handshake initiation (-> e, es, s, ss, {t}) into @dst for the
 * peer owning @handshake.
 *
 * On success:
 *   - the chaining key, hash and ephemeral private key stay in @handshake,
 *     ready for the response;
 *   - the handshake entry is inserted into the index hashtable, and its
 *     index is sent as dst->sender_index;
 *   - the state becomes HANDSHAKE_CREATED_INITIATION.
 *
 * Returns false if we have no static identity or a DH step fails. In that
 * case @dst and the handshake's chaining key/hash may be partially written,
 * but the state is not advanced.
 *
 * Takes static_identity->lock (read), then handshake->lock (write). May
 * sleep while waiting for the RNG and the locks.
 */
bool
wg_noise_handshake_create_initiation(struct message_handshake_initiation *dst,
				     struct noise_handshake *handshake)
{
	u8 timestamp[NOISE_TIMESTAMP_LEN];
	u8 key[NOISE_SYMMETRIC_KEY_LEN];
	bool ret = false;

	/* We need to wait for crng _before_ taking any locks, since
	 * curve25519_generate_secret uses get_random_bytes_wait.
	 */
	wait_for_random_bytes();

	down_read(&handshake->static_identity->lock);
	down_write(&handshake->lock);

	if (unlikely(!handshake->static_identity->has_identity))
		goto out;

	dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION);

	/* Fresh transcript: initial constants plus the responder's static
	 * public key.
	 */
	handshake_init(handshake->chaining_key, handshake->hash,
		       handshake->remote_static);

	/* e */
	curve25519_generate_secret(handshake->ephemeral_private);
	if (!curve25519_generate_public(dst->unencrypted_ephemeral,
					handshake->ephemeral_private))
		goto out;
	/* Same buffer passed as dst and src, so no copy is made; only the
	 * hash and chaining key are updated.
	 */
	message_ephemeral(dst->unencrypted_ephemeral,
			  dst->unencrypted_ephemeral, handshake->chaining_key,
			  handshake->hash);

	/* es */
	if (!mix_dh(handshake->chaining_key, key, handshake->ephemeral_private,
		    handshake->remote_static))
		goto out;

	/* s: our static public key, encrypted under the es-derived key */
	message_encrypt(dst->encrypted_static,
			handshake->static_identity->static_public,
			NOISE_PUBLIC_KEY_LEN, key, handshake->hash);

	/* ss */
	if (!mix_precomputed_dh(handshake->chaining_key, key,
				handshake->precomputed_static_static))
		goto out;

	/* {t}: the responder rejects timestamps not newer than the last
	 * one it accepted from us (replay protection).
	 */
	tai64n_now(timestamp);
	message_encrypt(dst->encrypted_timestamp, timestamp,
			NOISE_TIMESTAMP_LEN, key, handshake->hash);

	/* The response will carry this index back as its receiver_index,
	 * which is how wg_noise_handshake_consume_response finds us.
	 */
	dst->sender_index = wg_index_hashtable_insert(
		handshake->entry.peer->device->index_hashtable,
		&handshake->entry);

	handshake->state = HANDSHAKE_CREATED_INITIATION;
	ret = true;

out:
	up_write(&handshake->lock);
	up_read(&handshake->static_identity->lock);
	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
	return ret;
}
560 
561 struct wg_peer *
562 wg_noise_handshake_consume_initiation(struct message_handshake_initiation *src,
563 				      struct wg_device *wg)
564 {
565 	struct wg_peer *peer = NULL, *ret_peer = NULL;
566 	struct noise_handshake *handshake;
567 	bool replay_attack, flood_attack;
568 	u8 key[NOISE_SYMMETRIC_KEY_LEN];
569 	u8 chaining_key[NOISE_HASH_LEN];
570 	u8 hash[NOISE_HASH_LEN];
571 	u8 s[NOISE_PUBLIC_KEY_LEN];
572 	u8 e[NOISE_PUBLIC_KEY_LEN];
573 	u8 t[NOISE_TIMESTAMP_LEN];
574 	u64 initiation_consumption;
575 
576 	down_read(&wg->static_identity.lock);
577 	if (unlikely(!wg->static_identity.has_identity))
578 		goto out;
579 
580 	handshake_init(chaining_key, hash, wg->static_identity.static_public);
581 
582 	/* e */
583 	message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);
584 
585 	/* es */
586 	if (!mix_dh(chaining_key, key, wg->static_identity.static_private, e))
587 		goto out;
588 
589 	/* s */
590 	if (!message_decrypt(s, src->encrypted_static,
591 			     sizeof(src->encrypted_static), key, hash))
592 		goto out;
593 
594 	/* Lookup which peer we're actually talking to */
595 	peer = wg_pubkey_hashtable_lookup(wg->peer_hashtable, s);
596 	if (!peer)
597 		goto out;
598 	handshake = &peer->handshake;
599 
600 	/* ss */
601 	if (!mix_precomputed_dh(chaining_key, key,
602 				handshake->precomputed_static_static))
603 	    goto out;
604 
605 	/* {t} */
606 	if (!message_decrypt(t, src->encrypted_timestamp,
607 			     sizeof(src->encrypted_timestamp), key, hash))
608 		goto out;
609 
610 	down_read(&handshake->lock);
611 	replay_attack = memcmp(t, handshake->latest_timestamp,
612 			       NOISE_TIMESTAMP_LEN) <= 0;
613 	flood_attack = (s64)handshake->last_initiation_consumption +
614 			       NSEC_PER_SEC / INITIATIONS_PER_SECOND >
615 		       (s64)ktime_get_coarse_boottime_ns();
616 	up_read(&handshake->lock);
617 	if (replay_attack || flood_attack)
618 		goto out;
619 
620 	/* Success! Copy everything to peer */
621 	down_write(&handshake->lock);
622 	memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
623 	if (memcmp(t, handshake->latest_timestamp, NOISE_TIMESTAMP_LEN) > 0)
624 		memcpy(handshake->latest_timestamp, t, NOISE_TIMESTAMP_LEN);
625 	memcpy(handshake->hash, hash, NOISE_HASH_LEN);
626 	memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
627 	handshake->remote_index = src->sender_index;
628 	if ((s64)(handshake->last_initiation_consumption -
629 	    (initiation_consumption = ktime_get_coarse_boottime_ns())) < 0)
630 		handshake->last_initiation_consumption = initiation_consumption;
631 	handshake->state = HANDSHAKE_CONSUMED_INITIATION;
632 	up_write(&handshake->lock);
633 	ret_peer = peer;
634 
635 out:
636 	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
637 	memzero_explicit(hash, NOISE_HASH_LEN);
638 	memzero_explicit(chaining_key, NOISE_HASH_LEN);
639 	up_read(&wg->static_identity.lock);
640 	if (!ret_peer)
641 		wg_peer_put(peer);
642 	return ret_peer;
643 }
644 
/* Build a handshake response (<- e, ee, se, psk, {}) into @dst.
 *
 * This continues the transcript that wg_noise_handshake_consume_initiation()
 * left in @handshake, so it requires state HANDSHAKE_CONSUMED_INITIATION.
 * On success, the handshake entry is inserted into the index hashtable (its
 * index sent as dst->sender_index) and the state becomes
 * HANDSHAKE_CREATED_RESPONSE. Returns false otherwise.
 *
 * Takes static_identity->lock (read), then handshake->lock (write). May
 * sleep while waiting for the RNG and the locks.
 */
bool wg_noise_handshake_create_response(struct message_handshake_response *dst,
					struct noise_handshake *handshake)
{
	u8 key[NOISE_SYMMETRIC_KEY_LEN];
	bool ret = false;

	/* We need to wait for crng _before_ taking any locks, since
	 * curve25519_generate_secret uses get_random_bytes_wait.
	 */
	wait_for_random_bytes();

	down_read(&handshake->static_identity->lock);
	down_write(&handshake->lock);

	if (handshake->state != HANDSHAKE_CONSUMED_INITIATION)
		goto out;

	dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE);
	/* Echo the initiator's sender_index so it can find its handshake. */
	dst->receiver_index = handshake->remote_index;

	/* e */
	curve25519_generate_secret(handshake->ephemeral_private);
	if (!curve25519_generate_public(dst->unencrypted_ephemeral,
					handshake->ephemeral_private))
		goto out;
	message_ephemeral(dst->unencrypted_ephemeral,
			  dst->unencrypted_ephemeral, handshake->chaining_key,
			  handshake->hash);

	/* ee: no key output needed, only the chaining key advances */
	if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
		    handshake->remote_ephemeral))
		goto out;

	/* se */
	if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
		    handshake->remote_static))
		goto out;

	/* psk: also yields the key for the empty payload below */
	mix_psk(handshake->chaining_key, handshake->hash, key,
		handshake->preshared_key);

	/* {}: empty payload, i.e. just an authentication tag */
	message_encrypt(dst->encrypted_nothing, NULL, 0, key, handshake->hash);

	dst->sender_index = wg_index_hashtable_insert(
		handshake->entry.peer->device->index_hashtable,
		&handshake->entry);

	handshake->state = HANDSHAKE_CREATED_RESPONSE;
	ret = true;

out:
	up_write(&handshake->lock);
	up_read(&handshake->static_identity->lock);
	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
	return ret;
}
704 
705 struct wg_peer *
706 wg_noise_handshake_consume_response(struct message_handshake_response *src,
707 				    struct wg_device *wg)
708 {
709 	enum noise_handshake_state state = HANDSHAKE_ZEROED;
710 	struct wg_peer *peer = NULL, *ret_peer = NULL;
711 	struct noise_handshake *handshake;
712 	u8 key[NOISE_SYMMETRIC_KEY_LEN];
713 	u8 hash[NOISE_HASH_LEN];
714 	u8 chaining_key[NOISE_HASH_LEN];
715 	u8 e[NOISE_PUBLIC_KEY_LEN];
716 	u8 ephemeral_private[NOISE_PUBLIC_KEY_LEN];
717 	u8 static_private[NOISE_PUBLIC_KEY_LEN];
718 
719 	down_read(&wg->static_identity.lock);
720 
721 	if (unlikely(!wg->static_identity.has_identity))
722 		goto out;
723 
724 	handshake = (struct noise_handshake *)wg_index_hashtable_lookup(
725 		wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE,
726 		src->receiver_index, &peer);
727 	if (unlikely(!handshake))
728 		goto out;
729 
730 	down_read(&handshake->lock);
731 	state = handshake->state;
732 	memcpy(hash, handshake->hash, NOISE_HASH_LEN);
733 	memcpy(chaining_key, handshake->chaining_key, NOISE_HASH_LEN);
734 	memcpy(ephemeral_private, handshake->ephemeral_private,
735 	       NOISE_PUBLIC_KEY_LEN);
736 	up_read(&handshake->lock);
737 
738 	if (state != HANDSHAKE_CREATED_INITIATION)
739 		goto fail;
740 
741 	/* e */
742 	message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);
743 
744 	/* ee */
745 	if (!mix_dh(chaining_key, NULL, ephemeral_private, e))
746 		goto fail;
747 
748 	/* se */
749 	if (!mix_dh(chaining_key, NULL, wg->static_identity.static_private, e))
750 		goto fail;
751 
752 	/* psk */
753 	mix_psk(chaining_key, hash, key, handshake->preshared_key);
754 
755 	/* {} */
756 	if (!message_decrypt(NULL, src->encrypted_nothing,
757 			     sizeof(src->encrypted_nothing), key, hash))
758 		goto fail;
759 
760 	/* Success! Copy everything to peer */
761 	down_write(&handshake->lock);
762 	/* It's important to check that the state is still the same, while we
763 	 * have an exclusive lock.
764 	 */
765 	if (handshake->state != state) {
766 		up_write(&handshake->lock);
767 		goto fail;
768 	}
769 	memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
770 	memcpy(handshake->hash, hash, NOISE_HASH_LEN);
771 	memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
772 	handshake->remote_index = src->sender_index;
773 	handshake->state = HANDSHAKE_CONSUMED_RESPONSE;
774 	up_write(&handshake->lock);
775 	ret_peer = peer;
776 	goto out;
777 
778 fail:
779 	wg_peer_put(peer);
780 out:
781 	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
782 	memzero_explicit(hash, NOISE_HASH_LEN);
783 	memzero_explicit(chaining_key, NOISE_HASH_LEN);
784 	memzero_explicit(ephemeral_private, NOISE_PUBLIC_KEY_LEN);
785 	memzero_explicit(static_private, NOISE_PUBLIC_KEY_LEN);
786 	up_read(&wg->static_identity.lock);
787 	return ret_peer;
788 }
789 
/* Turn a completed handshake into a new transport keypair.
 *
 * Valid only in two states, which also fix our role:
 *   - HANDSHAKE_CREATED_RESPONSE: we are the responder;
 *   - HANDSHAKE_CONSUMED_RESPONSE: we are the initiator.
 * The role decides which of the two derived keys is used for sending. If
 * the state is wrong or keypair allocation fails, the handshake is left
 * untouched and false is returned. Otherwise the handshake state is wiped
 * once the keys are derived.
 *
 * Unless the peer is marked is_dead, the keypair is then:
 *   - installed via add_new_keypair();
 *   - swapped in for the handshake's entry in the index hashtable.
 * In that case the return value is that of wg_index_hashtable_replace().
 * If the peer is dead, the keypair is freed without ever being published
 * and false is returned.
 *
 * Holds handshake->lock for writing throughout.
 */
bool wg_noise_handshake_begin_session(struct noise_handshake *handshake,
				      struct noise_keypairs *keypairs)
{
	struct noise_keypair *new_keypair;
	bool ret = false;

	down_write(&handshake->lock);
	if (handshake->state != HANDSHAKE_CREATED_RESPONSE &&
	    handshake->state != HANDSHAKE_CONSUMED_RESPONSE)
		goto out;

	new_keypair = keypair_create(handshake->entry.peer);
	if (!new_keypair)
		goto out;
	new_keypair->i_am_the_initiator = handshake->state ==
					  HANDSHAKE_CONSUMED_RESPONSE;
	new_keypair->remote_index = handshake->remote_index;

	/* First HKDF output: initiator's sending key, responder's receiving. */
	if (new_keypair->i_am_the_initiator)
		derive_keys(&new_keypair->sending, &new_keypair->receiving,
			    handshake->chaining_key);
	else
		derive_keys(&new_keypair->receiving, &new_keypair->sending,
			    handshake->chaining_key);

	handshake_zero(handshake);
	/* NOTE(review): the RCU BH section presumably pairs with peer removal
	 * setting is_dead and then waiting for a grace period - confirm in
	 * peer.c.
	 */
	rcu_read_lock_bh();
	if (likely(!READ_ONCE(container_of(handshake, struct wg_peer,
					   handshake)->is_dead))) {
		add_new_keypair(keypairs, new_keypair);
		net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n",
				    handshake->entry.peer->device->dev->name,
				    new_keypair->internal_id,
				    handshake->entry.peer->internal_id);
		ret = wg_index_hashtable_replace(
			handshake->entry.peer->device->index_hashtable,
			&handshake->entry, &new_keypair->entry);
	} else {
		/* Never published, so it can be freed directly. */
		kzfree(new_keypair);
	}
	rcu_read_unlock_bh();

out:
	up_write(&handshake->lock);
	return ret;
}
836