xref: /openbmc/linux/drivers/net/wireguard/noise.c (revision 4464005a12b5c79e1a364e6272ee10a83413f928)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4  */
5 
6 #include "noise.h"
7 #include "device.h"
8 #include "peer.h"
9 #include "messages.h"
10 #include "queueing.h"
11 #include "peerlookup.h"
12 
13 #include <linux/rcupdate.h>
14 #include <linux/slab.h>
15 #include <linux/bitmap.h>
16 #include <linux/scatterlist.h>
17 #include <linux/highmem.h>
18 #include <crypto/algapi.h>
19 
20 /* This implements Noise_IKpsk2:
21  *
22  * <- s
23  * ******
24  * -> e, es, s, ss, {t}
25  * <- e, ee, se, psk, {}
26  */
27 
28 static const u8 handshake_name[37] = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s";
29 static const u8 identifier_name[34] = "WireGuard v1 zx2c4 Jason@zx2c4.com";
30 static u8 handshake_init_hash[NOISE_HASH_LEN] __ro_after_init;
31 static u8 handshake_init_chaining_key[NOISE_HASH_LEN] __ro_after_init;
32 static atomic64_t keypair_counter = ATOMIC64_INIT(0);
33 
34 void __init wg_noise_init(void)
35 {
36 	struct blake2s_state blake;
37 
38 	blake2s(handshake_init_chaining_key, handshake_name, NULL,
39 		NOISE_HASH_LEN, sizeof(handshake_name), 0);
40 	blake2s_init(&blake, NOISE_HASH_LEN);
41 	blake2s_update(&blake, handshake_init_chaining_key, NOISE_HASH_LEN);
42 	blake2s_update(&blake, identifier_name, sizeof(identifier_name));
43 	blake2s_final(&blake, handshake_init_hash);
44 }
45 
46 /* Must hold peer->handshake.static_identity->lock */
47 void wg_noise_precompute_static_static(struct wg_peer *peer)
48 {
49 	down_write(&peer->handshake.lock);
50 	if (!peer->handshake.static_identity->has_identity ||
51 	    !curve25519(peer->handshake.precomputed_static_static,
52 			peer->handshake.static_identity->static_private,
53 			peer->handshake.remote_static))
54 		memset(peer->handshake.precomputed_static_static, 0,
55 		       NOISE_PUBLIC_KEY_LEN);
56 	up_write(&peer->handshake.lock);
57 }
58 
59 void wg_noise_handshake_init(struct noise_handshake *handshake,
60 			     struct noise_static_identity *static_identity,
61 			     const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN],
62 			     const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN],
63 			     struct wg_peer *peer)
64 {
65 	memset(handshake, 0, sizeof(*handshake));
66 	init_rwsem(&handshake->lock);
67 	handshake->entry.type = INDEX_HASHTABLE_HANDSHAKE;
68 	handshake->entry.peer = peer;
69 	memcpy(handshake->remote_static, peer_public_key, NOISE_PUBLIC_KEY_LEN);
70 	if (peer_preshared_key)
71 		memcpy(handshake->preshared_key, peer_preshared_key,
72 		       NOISE_SYMMETRIC_KEY_LEN);
73 	handshake->static_identity = static_identity;
74 	handshake->state = HANDSHAKE_ZEROED;
75 	wg_noise_precompute_static_static(peer);
76 }
77 
78 static void handshake_zero(struct noise_handshake *handshake)
79 {
80 	memset(&handshake->ephemeral_private, 0, NOISE_PUBLIC_KEY_LEN);
81 	memset(&handshake->remote_ephemeral, 0, NOISE_PUBLIC_KEY_LEN);
82 	memset(&handshake->hash, 0, NOISE_HASH_LEN);
83 	memset(&handshake->chaining_key, 0, NOISE_HASH_LEN);
84 	handshake->remote_index = 0;
85 	handshake->state = HANDSHAKE_ZEROED;
86 }
87 
88 void wg_noise_handshake_clear(struct noise_handshake *handshake)
89 {
90 	wg_index_hashtable_remove(
91 			handshake->entry.peer->device->index_hashtable,
92 			&handshake->entry);
93 	down_write(&handshake->lock);
94 	handshake_zero(handshake);
95 	up_write(&handshake->lock);
96 	wg_index_hashtable_remove(
97 			handshake->entry.peer->device->index_hashtable,
98 			&handshake->entry);
99 }
100 
101 static struct noise_keypair *keypair_create(struct wg_peer *peer)
102 {
103 	struct noise_keypair *keypair = kzalloc(sizeof(*keypair), GFP_KERNEL);
104 
105 	if (unlikely(!keypair))
106 		return NULL;
107 	spin_lock_init(&keypair->receiving_counter.lock);
108 	keypair->internal_id = atomic64_inc_return(&keypair_counter);
109 	keypair->entry.type = INDEX_HASHTABLE_KEYPAIR;
110 	keypair->entry.peer = peer;
111 	kref_init(&keypair->refcount);
112 	return keypair;
113 }
114 
115 static void keypair_free_rcu(struct rcu_head *rcu)
116 {
117 	kzfree(container_of(rcu, struct noise_keypair, rcu));
118 }
119 
120 static void keypair_free_kref(struct kref *kref)
121 {
122 	struct noise_keypair *keypair =
123 		container_of(kref, struct noise_keypair, refcount);
124 
125 	net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n",
126 			    keypair->entry.peer->device->dev->name,
127 			    keypair->internal_id,
128 			    keypair->entry.peer->internal_id);
129 	wg_index_hashtable_remove(keypair->entry.peer->device->index_hashtable,
130 				  &keypair->entry);
131 	call_rcu(&keypair->rcu, keypair_free_rcu);
132 }
133 
134 void wg_noise_keypair_put(struct noise_keypair *keypair, bool unreference_now)
135 {
136 	if (unlikely(!keypair))
137 		return;
138 	if (unlikely(unreference_now))
139 		wg_index_hashtable_remove(
140 			keypair->entry.peer->device->index_hashtable,
141 			&keypair->entry);
142 	kref_put(&keypair->refcount, keypair_free_kref);
143 }
144 
145 struct noise_keypair *wg_noise_keypair_get(struct noise_keypair *keypair)
146 {
147 	RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(),
148 		"Taking noise keypair reference without holding the RCU BH read lock");
149 	if (unlikely(!keypair || !kref_get_unless_zero(&keypair->refcount)))
150 		return NULL;
151 	return keypair;
152 }
153 
154 void wg_noise_keypairs_clear(struct noise_keypairs *keypairs)
155 {
156 	struct noise_keypair *old;
157 
158 	spin_lock_bh(&keypairs->keypair_update_lock);
159 
160 	/* We zero the next_keypair before zeroing the others, so that
161 	 * wg_noise_received_with_keypair returns early before subsequent ones
162 	 * are zeroed.
163 	 */
164 	old = rcu_dereference_protected(keypairs->next_keypair,
165 		lockdep_is_held(&keypairs->keypair_update_lock));
166 	RCU_INIT_POINTER(keypairs->next_keypair, NULL);
167 	wg_noise_keypair_put(old, true);
168 
169 	old = rcu_dereference_protected(keypairs->previous_keypair,
170 		lockdep_is_held(&keypairs->keypair_update_lock));
171 	RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
172 	wg_noise_keypair_put(old, true);
173 
174 	old = rcu_dereference_protected(keypairs->current_keypair,
175 		lockdep_is_held(&keypairs->keypair_update_lock));
176 	RCU_INIT_POINTER(keypairs->current_keypair, NULL);
177 	wg_noise_keypair_put(old, true);
178 
179 	spin_unlock_bh(&keypairs->keypair_update_lock);
180 }
181 
182 void wg_noise_expire_current_peer_keypairs(struct wg_peer *peer)
183 {
184 	struct noise_keypair *keypair;
185 
186 	wg_noise_handshake_clear(&peer->handshake);
187 	wg_noise_reset_last_sent_handshake(&peer->last_sent_handshake);
188 
189 	spin_lock_bh(&peer->keypairs.keypair_update_lock);
190 	keypair = rcu_dereference_protected(peer->keypairs.next_keypair,
191 			lockdep_is_held(&peer->keypairs.keypair_update_lock));
192 	if (keypair)
193 		keypair->sending.is_valid = false;
194 	keypair = rcu_dereference_protected(peer->keypairs.current_keypair,
195 			lockdep_is_held(&peer->keypairs.keypair_update_lock));
196 	if (keypair)
197 		keypair->sending.is_valid = false;
198 	spin_unlock_bh(&peer->keypairs.keypair_update_lock);
199 }
200 
201 static void add_new_keypair(struct noise_keypairs *keypairs,
202 			    struct noise_keypair *new_keypair)
203 {
204 	struct noise_keypair *previous_keypair, *next_keypair, *current_keypair;
205 
206 	spin_lock_bh(&keypairs->keypair_update_lock);
207 	previous_keypair = rcu_dereference_protected(keypairs->previous_keypair,
208 		lockdep_is_held(&keypairs->keypair_update_lock));
209 	next_keypair = rcu_dereference_protected(keypairs->next_keypair,
210 		lockdep_is_held(&keypairs->keypair_update_lock));
211 	current_keypair = rcu_dereference_protected(keypairs->current_keypair,
212 		lockdep_is_held(&keypairs->keypair_update_lock));
213 	if (new_keypair->i_am_the_initiator) {
214 		/* If we're the initiator, it means we've sent a handshake, and
215 		 * received a confirmation response, which means this new
216 		 * keypair can now be used.
217 		 */
218 		if (next_keypair) {
219 			/* If there already was a next keypair pending, we
220 			 * demote it to be the previous keypair, and free the
221 			 * existing current. Note that this means KCI can result
222 			 * in this transition. It would perhaps be more sound to
223 			 * always just get rid of the unused next keypair
224 			 * instead of putting it in the previous slot, but this
225 			 * might be a bit less robust. Something to think about
226 			 * for the future.
227 			 */
228 			RCU_INIT_POINTER(keypairs->next_keypair, NULL);
229 			rcu_assign_pointer(keypairs->previous_keypair,
230 					   next_keypair);
231 			wg_noise_keypair_put(current_keypair, true);
232 		} else /* If there wasn't an existing next keypair, we replace
233 			* the previous with the current one.
234 			*/
235 			rcu_assign_pointer(keypairs->previous_keypair,
236 					   current_keypair);
237 		/* At this point we can get rid of the old previous keypair, and
238 		 * set up the new keypair.
239 		 */
240 		wg_noise_keypair_put(previous_keypair, true);
241 		rcu_assign_pointer(keypairs->current_keypair, new_keypair);
242 	} else {
243 		/* If we're the responder, it means we can't use the new keypair
244 		 * until we receive confirmation via the first data packet, so
245 		 * we get rid of the existing previous one, the possibly
246 		 * existing next one, and slide in the new next one.
247 		 */
248 		rcu_assign_pointer(keypairs->next_keypair, new_keypair);
249 		wg_noise_keypair_put(next_keypair, true);
250 		RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
251 		wg_noise_keypair_put(previous_keypair, true);
252 	}
253 	spin_unlock_bh(&keypairs->keypair_update_lock);
254 }
255 
256 bool wg_noise_received_with_keypair(struct noise_keypairs *keypairs,
257 				    struct noise_keypair *received_keypair)
258 {
259 	struct noise_keypair *old_keypair;
260 	bool key_is_new;
261 
262 	/* We first check without taking the spinlock. */
263 	key_is_new = received_keypair ==
264 		     rcu_access_pointer(keypairs->next_keypair);
265 	if (likely(!key_is_new))
266 		return false;
267 
268 	spin_lock_bh(&keypairs->keypair_update_lock);
269 	/* After locking, we double check that things didn't change from
270 	 * beneath us.
271 	 */
272 	if (unlikely(received_keypair !=
273 		    rcu_dereference_protected(keypairs->next_keypair,
274 			    lockdep_is_held(&keypairs->keypair_update_lock)))) {
275 		spin_unlock_bh(&keypairs->keypair_update_lock);
276 		return false;
277 	}
278 
279 	/* When we've finally received the confirmation, we slide the next
280 	 * into the current, the current into the previous, and get rid of
281 	 * the old previous.
282 	 */
283 	old_keypair = rcu_dereference_protected(keypairs->previous_keypair,
284 		lockdep_is_held(&keypairs->keypair_update_lock));
285 	rcu_assign_pointer(keypairs->previous_keypair,
286 		rcu_dereference_protected(keypairs->current_keypair,
287 			lockdep_is_held(&keypairs->keypair_update_lock)));
288 	wg_noise_keypair_put(old_keypair, true);
289 	rcu_assign_pointer(keypairs->current_keypair, received_keypair);
290 	RCU_INIT_POINTER(keypairs->next_keypair, NULL);
291 
292 	spin_unlock_bh(&keypairs->keypair_update_lock);
293 	return true;
294 }
295 
296 /* Must hold static_identity->lock */
297 void wg_noise_set_static_identity_private_key(
298 	struct noise_static_identity *static_identity,
299 	const u8 private_key[NOISE_PUBLIC_KEY_LEN])
300 {
301 	memcpy(static_identity->static_private, private_key,
302 	       NOISE_PUBLIC_KEY_LEN);
303 	curve25519_clamp_secret(static_identity->static_private);
304 	static_identity->has_identity = curve25519_generate_public(
305 		static_identity->static_public, private_key);
306 }
307 
308 /* This is Hugo Krawczyk's HKDF:
309  *  - https://eprint.iacr.org/2010/264.pdf
310  *  - https://tools.ietf.org/html/rfc5869
311  */
312 static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data,
313 		size_t first_len, size_t second_len, size_t third_len,
314 		size_t data_len, const u8 chaining_key[NOISE_HASH_LEN])
315 {
316 	u8 output[BLAKE2S_HASH_SIZE + 1];
317 	u8 secret[BLAKE2S_HASH_SIZE];
318 
319 	WARN_ON(IS_ENABLED(DEBUG) &&
320 		(first_len > BLAKE2S_HASH_SIZE ||
321 		 second_len > BLAKE2S_HASH_SIZE ||
322 		 third_len > BLAKE2S_HASH_SIZE ||
323 		 ((second_len || second_dst || third_len || third_dst) &&
324 		  (!first_len || !first_dst)) ||
325 		 ((third_len || third_dst) && (!second_len || !second_dst))));
326 
327 	/* Extract entropy from data into secret */
328 	blake2s256_hmac(secret, data, chaining_key, data_len, NOISE_HASH_LEN);
329 
330 	if (!first_dst || !first_len)
331 		goto out;
332 
333 	/* Expand first key: key = secret, data = 0x1 */
334 	output[0] = 1;
335 	blake2s256_hmac(output, output, secret, 1, BLAKE2S_HASH_SIZE);
336 	memcpy(first_dst, output, first_len);
337 
338 	if (!second_dst || !second_len)
339 		goto out;
340 
341 	/* Expand second key: key = secret, data = first-key || 0x2 */
342 	output[BLAKE2S_HASH_SIZE] = 2;
343 	blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1,
344 			BLAKE2S_HASH_SIZE);
345 	memcpy(second_dst, output, second_len);
346 
347 	if (!third_dst || !third_len)
348 		goto out;
349 
350 	/* Expand third key: key = secret, data = second-key || 0x3 */
351 	output[BLAKE2S_HASH_SIZE] = 3;
352 	blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1,
353 			BLAKE2S_HASH_SIZE);
354 	memcpy(third_dst, output, third_len);
355 
356 out:
357 	/* Clear sensitive data from stack */
358 	memzero_explicit(secret, BLAKE2S_HASH_SIZE);
359 	memzero_explicit(output, BLAKE2S_HASH_SIZE + 1);
360 }
361 
362 static void derive_keys(struct noise_symmetric_key *first_dst,
363 			struct noise_symmetric_key *second_dst,
364 			const u8 chaining_key[NOISE_HASH_LEN])
365 {
366 	u64 birthdate = ktime_get_coarse_boottime_ns();
367 	kdf(first_dst->key, second_dst->key, NULL, NULL,
368 	    NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
369 	    chaining_key);
370 	first_dst->birthdate = second_dst->birthdate = birthdate;
371 	first_dst->is_valid = second_dst->is_valid = true;
372 }
373 
374 static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],
375 				u8 key[NOISE_SYMMETRIC_KEY_LEN],
376 				const u8 private[NOISE_PUBLIC_KEY_LEN],
377 				const u8 public[NOISE_PUBLIC_KEY_LEN])
378 {
379 	u8 dh_calculation[NOISE_PUBLIC_KEY_LEN];
380 
381 	if (unlikely(!curve25519(dh_calculation, private, public)))
382 		return false;
383 	kdf(chaining_key, key, NULL, dh_calculation, NOISE_HASH_LEN,
384 	    NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key);
385 	memzero_explicit(dh_calculation, NOISE_PUBLIC_KEY_LEN);
386 	return true;
387 }
388 
389 static bool __must_check mix_precomputed_dh(u8 chaining_key[NOISE_HASH_LEN],
390 					    u8 key[NOISE_SYMMETRIC_KEY_LEN],
391 					    const u8 precomputed[NOISE_PUBLIC_KEY_LEN])
392 {
393 	static u8 zero_point[NOISE_PUBLIC_KEY_LEN];
394 	if (unlikely(!crypto_memneq(precomputed, zero_point, NOISE_PUBLIC_KEY_LEN)))
395 		return false;
396 	kdf(chaining_key, key, NULL, precomputed, NOISE_HASH_LEN,
397 	    NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
398 	    chaining_key);
399 	return true;
400 }
401 
402 static void mix_hash(u8 hash[NOISE_HASH_LEN], const u8 *src, size_t src_len)
403 {
404 	struct blake2s_state blake;
405 
406 	blake2s_init(&blake, NOISE_HASH_LEN);
407 	blake2s_update(&blake, hash, NOISE_HASH_LEN);
408 	blake2s_update(&blake, src, src_len);
409 	blake2s_final(&blake, hash);
410 }
411 
412 static void mix_psk(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN],
413 		    u8 key[NOISE_SYMMETRIC_KEY_LEN],
414 		    const u8 psk[NOISE_SYMMETRIC_KEY_LEN])
415 {
416 	u8 temp_hash[NOISE_HASH_LEN];
417 
418 	kdf(chaining_key, temp_hash, key, psk, NOISE_HASH_LEN, NOISE_HASH_LEN,
419 	    NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, chaining_key);
420 	mix_hash(hash, temp_hash, NOISE_HASH_LEN);
421 	memzero_explicit(temp_hash, NOISE_HASH_LEN);
422 }
423 
424 static void handshake_init(u8 chaining_key[NOISE_HASH_LEN],
425 			   u8 hash[NOISE_HASH_LEN],
426 			   const u8 remote_static[NOISE_PUBLIC_KEY_LEN])
427 {
428 	memcpy(hash, handshake_init_hash, NOISE_HASH_LEN);
429 	memcpy(chaining_key, handshake_init_chaining_key, NOISE_HASH_LEN);
430 	mix_hash(hash, remote_static, NOISE_PUBLIC_KEY_LEN);
431 }
432 
433 static void message_encrypt(u8 *dst_ciphertext, const u8 *src_plaintext,
434 			    size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
435 			    u8 hash[NOISE_HASH_LEN])
436 {
437 	chacha20poly1305_encrypt(dst_ciphertext, src_plaintext, src_len, hash,
438 				 NOISE_HASH_LEN,
439 				 0 /* Always zero for Noise_IK */, key);
440 	mix_hash(hash, dst_ciphertext, noise_encrypted_len(src_len));
441 }
442 
443 static bool message_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext,
444 			    size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
445 			    u8 hash[NOISE_HASH_LEN])
446 {
447 	if (!chacha20poly1305_decrypt(dst_plaintext, src_ciphertext, src_len,
448 				      hash, NOISE_HASH_LEN,
449 				      0 /* Always zero for Noise_IK */, key))
450 		return false;
451 	mix_hash(hash, src_ciphertext, src_len);
452 	return true;
453 }
454 
455 static void message_ephemeral(u8 ephemeral_dst[NOISE_PUBLIC_KEY_LEN],
456 			      const u8 ephemeral_src[NOISE_PUBLIC_KEY_LEN],
457 			      u8 chaining_key[NOISE_HASH_LEN],
458 			      u8 hash[NOISE_HASH_LEN])
459 {
460 	if (ephemeral_dst != ephemeral_src)
461 		memcpy(ephemeral_dst, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
462 	mix_hash(hash, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
463 	kdf(chaining_key, NULL, NULL, ephemeral_src, NOISE_HASH_LEN, 0, 0,
464 	    NOISE_PUBLIC_KEY_LEN, chaining_key);
465 }
466 
467 static void tai64n_now(u8 output[NOISE_TIMESTAMP_LEN])
468 {
469 	struct timespec64 now;
470 
471 	ktime_get_real_ts64(&now);
472 
473 	/* In order to prevent some sort of infoleak from precise timers, we
474 	 * round down the nanoseconds part to the closest rounded-down power of
475 	 * two to the maximum initiations per second allowed anyway by the
476 	 * implementation.
477 	 */
478 	now.tv_nsec = ALIGN_DOWN(now.tv_nsec,
479 		rounddown_pow_of_two(NSEC_PER_SEC / INITIATIONS_PER_SECOND));
480 
481 	/* https://cr.yp.to/libtai/tai64.html */
482 	*(__be64 *)output = cpu_to_be64(0x400000000000000aULL + now.tv_sec);
483 	*(__be32 *)(output + sizeof(__be64)) = cpu_to_be32(now.tv_nsec);
484 }
485 
486 bool
487 wg_noise_handshake_create_initiation(struct message_handshake_initiation *dst,
488 				     struct noise_handshake *handshake)
489 {
490 	u8 timestamp[NOISE_TIMESTAMP_LEN];
491 	u8 key[NOISE_SYMMETRIC_KEY_LEN];
492 	bool ret = false;
493 
494 	/* We need to wait for crng _before_ taking any locks, since
495 	 * curve25519_generate_secret uses get_random_bytes_wait.
496 	 */
497 	wait_for_random_bytes();
498 
499 	down_read(&handshake->static_identity->lock);
500 	down_write(&handshake->lock);
501 
502 	if (unlikely(!handshake->static_identity->has_identity))
503 		goto out;
504 
505 	dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION);
506 
507 	handshake_init(handshake->chaining_key, handshake->hash,
508 		       handshake->remote_static);
509 
510 	/* e */
511 	curve25519_generate_secret(handshake->ephemeral_private);
512 	if (!curve25519_generate_public(dst->unencrypted_ephemeral,
513 					handshake->ephemeral_private))
514 		goto out;
515 	message_ephemeral(dst->unencrypted_ephemeral,
516 			  dst->unencrypted_ephemeral, handshake->chaining_key,
517 			  handshake->hash);
518 
519 	/* es */
520 	if (!mix_dh(handshake->chaining_key, key, handshake->ephemeral_private,
521 		    handshake->remote_static))
522 		goto out;
523 
524 	/* s */
525 	message_encrypt(dst->encrypted_static,
526 			handshake->static_identity->static_public,
527 			NOISE_PUBLIC_KEY_LEN, key, handshake->hash);
528 
529 	/* ss */
530 	if (!mix_precomputed_dh(handshake->chaining_key, key,
531 				handshake->precomputed_static_static))
532 		goto out;
533 
534 	/* {t} */
535 	tai64n_now(timestamp);
536 	message_encrypt(dst->encrypted_timestamp, timestamp,
537 			NOISE_TIMESTAMP_LEN, key, handshake->hash);
538 
539 	dst->sender_index = wg_index_hashtable_insert(
540 		handshake->entry.peer->device->index_hashtable,
541 		&handshake->entry);
542 
543 	handshake->state = HANDSHAKE_CREATED_INITIATION;
544 	ret = true;
545 
546 out:
547 	up_write(&handshake->lock);
548 	up_read(&handshake->static_identity->lock);
549 	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
550 	return ret;
551 }
552 
553 struct wg_peer *
554 wg_noise_handshake_consume_initiation(struct message_handshake_initiation *src,
555 				      struct wg_device *wg)
556 {
557 	struct wg_peer *peer = NULL, *ret_peer = NULL;
558 	struct noise_handshake *handshake;
559 	bool replay_attack, flood_attack;
560 	u8 key[NOISE_SYMMETRIC_KEY_LEN];
561 	u8 chaining_key[NOISE_HASH_LEN];
562 	u8 hash[NOISE_HASH_LEN];
563 	u8 s[NOISE_PUBLIC_KEY_LEN];
564 	u8 e[NOISE_PUBLIC_KEY_LEN];
565 	u8 t[NOISE_TIMESTAMP_LEN];
566 	u64 initiation_consumption;
567 
568 	down_read(&wg->static_identity.lock);
569 	if (unlikely(!wg->static_identity.has_identity))
570 		goto out;
571 
572 	handshake_init(chaining_key, hash, wg->static_identity.static_public);
573 
574 	/* e */
575 	message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);
576 
577 	/* es */
578 	if (!mix_dh(chaining_key, key, wg->static_identity.static_private, e))
579 		goto out;
580 
581 	/* s */
582 	if (!message_decrypt(s, src->encrypted_static,
583 			     sizeof(src->encrypted_static), key, hash))
584 		goto out;
585 
586 	/* Lookup which peer we're actually talking to */
587 	peer = wg_pubkey_hashtable_lookup(wg->peer_hashtable, s);
588 	if (!peer)
589 		goto out;
590 	handshake = &peer->handshake;
591 
592 	/* ss */
593 	if (!mix_precomputed_dh(chaining_key, key,
594 				handshake->precomputed_static_static))
595 	    goto out;
596 
597 	/* {t} */
598 	if (!message_decrypt(t, src->encrypted_timestamp,
599 			     sizeof(src->encrypted_timestamp), key, hash))
600 		goto out;
601 
602 	down_read(&handshake->lock);
603 	replay_attack = memcmp(t, handshake->latest_timestamp,
604 			       NOISE_TIMESTAMP_LEN) <= 0;
605 	flood_attack = (s64)handshake->last_initiation_consumption +
606 			       NSEC_PER_SEC / INITIATIONS_PER_SECOND >
607 		       (s64)ktime_get_coarse_boottime_ns();
608 	up_read(&handshake->lock);
609 	if (replay_attack || flood_attack)
610 		goto out;
611 
612 	/* Success! Copy everything to peer */
613 	down_write(&handshake->lock);
614 	memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
615 	if (memcmp(t, handshake->latest_timestamp, NOISE_TIMESTAMP_LEN) > 0)
616 		memcpy(handshake->latest_timestamp, t, NOISE_TIMESTAMP_LEN);
617 	memcpy(handshake->hash, hash, NOISE_HASH_LEN);
618 	memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
619 	handshake->remote_index = src->sender_index;
620 	if ((s64)(handshake->last_initiation_consumption -
621 	    (initiation_consumption = ktime_get_coarse_boottime_ns())) < 0)
622 		handshake->last_initiation_consumption = initiation_consumption;
623 	handshake->state = HANDSHAKE_CONSUMED_INITIATION;
624 	up_write(&handshake->lock);
625 	ret_peer = peer;
626 
627 out:
628 	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
629 	memzero_explicit(hash, NOISE_HASH_LEN);
630 	memzero_explicit(chaining_key, NOISE_HASH_LEN);
631 	up_read(&wg->static_identity.lock);
632 	if (!ret_peer)
633 		wg_peer_put(peer);
634 	return ret_peer;
635 }
636 
637 bool wg_noise_handshake_create_response(struct message_handshake_response *dst,
638 					struct noise_handshake *handshake)
639 {
640 	u8 key[NOISE_SYMMETRIC_KEY_LEN];
641 	bool ret = false;
642 
643 	/* We need to wait for crng _before_ taking any locks, since
644 	 * curve25519_generate_secret uses get_random_bytes_wait.
645 	 */
646 	wait_for_random_bytes();
647 
648 	down_read(&handshake->static_identity->lock);
649 	down_write(&handshake->lock);
650 
651 	if (handshake->state != HANDSHAKE_CONSUMED_INITIATION)
652 		goto out;
653 
654 	dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE);
655 	dst->receiver_index = handshake->remote_index;
656 
657 	/* e */
658 	curve25519_generate_secret(handshake->ephemeral_private);
659 	if (!curve25519_generate_public(dst->unencrypted_ephemeral,
660 					handshake->ephemeral_private))
661 		goto out;
662 	message_ephemeral(dst->unencrypted_ephemeral,
663 			  dst->unencrypted_ephemeral, handshake->chaining_key,
664 			  handshake->hash);
665 
666 	/* ee */
667 	if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
668 		    handshake->remote_ephemeral))
669 		goto out;
670 
671 	/* se */
672 	if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
673 		    handshake->remote_static))
674 		goto out;
675 
676 	/* psk */
677 	mix_psk(handshake->chaining_key, handshake->hash, key,
678 		handshake->preshared_key);
679 
680 	/* {} */
681 	message_encrypt(dst->encrypted_nothing, NULL, 0, key, handshake->hash);
682 
683 	dst->sender_index = wg_index_hashtable_insert(
684 		handshake->entry.peer->device->index_hashtable,
685 		&handshake->entry);
686 
687 	handshake->state = HANDSHAKE_CREATED_RESPONSE;
688 	ret = true;
689 
690 out:
691 	up_write(&handshake->lock);
692 	up_read(&handshake->static_identity->lock);
693 	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
694 	return ret;
695 }
696 
697 struct wg_peer *
698 wg_noise_handshake_consume_response(struct message_handshake_response *src,
699 				    struct wg_device *wg)
700 {
701 	enum noise_handshake_state state = HANDSHAKE_ZEROED;
702 	struct wg_peer *peer = NULL, *ret_peer = NULL;
703 	struct noise_handshake *handshake;
704 	u8 key[NOISE_SYMMETRIC_KEY_LEN];
705 	u8 hash[NOISE_HASH_LEN];
706 	u8 chaining_key[NOISE_HASH_LEN];
707 	u8 e[NOISE_PUBLIC_KEY_LEN];
708 	u8 ephemeral_private[NOISE_PUBLIC_KEY_LEN];
709 	u8 static_private[NOISE_PUBLIC_KEY_LEN];
710 	u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN];
711 
712 	down_read(&wg->static_identity.lock);
713 
714 	if (unlikely(!wg->static_identity.has_identity))
715 		goto out;
716 
717 	handshake = (struct noise_handshake *)wg_index_hashtable_lookup(
718 		wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE,
719 		src->receiver_index, &peer);
720 	if (unlikely(!handshake))
721 		goto out;
722 
723 	down_read(&handshake->lock);
724 	state = handshake->state;
725 	memcpy(hash, handshake->hash, NOISE_HASH_LEN);
726 	memcpy(chaining_key, handshake->chaining_key, NOISE_HASH_LEN);
727 	memcpy(ephemeral_private, handshake->ephemeral_private,
728 	       NOISE_PUBLIC_KEY_LEN);
729 	memcpy(preshared_key, handshake->preshared_key,
730 	       NOISE_SYMMETRIC_KEY_LEN);
731 	up_read(&handshake->lock);
732 
733 	if (state != HANDSHAKE_CREATED_INITIATION)
734 		goto fail;
735 
736 	/* e */
737 	message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);
738 
739 	/* ee */
740 	if (!mix_dh(chaining_key, NULL, ephemeral_private, e))
741 		goto fail;
742 
743 	/* se */
744 	if (!mix_dh(chaining_key, NULL, wg->static_identity.static_private, e))
745 		goto fail;
746 
747 	/* psk */
748 	mix_psk(chaining_key, hash, key, preshared_key);
749 
750 	/* {} */
751 	if (!message_decrypt(NULL, src->encrypted_nothing,
752 			     sizeof(src->encrypted_nothing), key, hash))
753 		goto fail;
754 
755 	/* Success! Copy everything to peer */
756 	down_write(&handshake->lock);
757 	/* It's important to check that the state is still the same, while we
758 	 * have an exclusive lock.
759 	 */
760 	if (handshake->state != state) {
761 		up_write(&handshake->lock);
762 		goto fail;
763 	}
764 	memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
765 	memcpy(handshake->hash, hash, NOISE_HASH_LEN);
766 	memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
767 	handshake->remote_index = src->sender_index;
768 	handshake->state = HANDSHAKE_CONSUMED_RESPONSE;
769 	up_write(&handshake->lock);
770 	ret_peer = peer;
771 	goto out;
772 
773 fail:
774 	wg_peer_put(peer);
775 out:
776 	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
777 	memzero_explicit(hash, NOISE_HASH_LEN);
778 	memzero_explicit(chaining_key, NOISE_HASH_LEN);
779 	memzero_explicit(ephemeral_private, NOISE_PUBLIC_KEY_LEN);
780 	memzero_explicit(static_private, NOISE_PUBLIC_KEY_LEN);
781 	memzero_explicit(preshared_key, NOISE_SYMMETRIC_KEY_LEN);
782 	up_read(&wg->static_identity.lock);
783 	return ret_peer;
784 }
785 
786 bool wg_noise_handshake_begin_session(struct noise_handshake *handshake,
787 				      struct noise_keypairs *keypairs)
788 {
789 	struct noise_keypair *new_keypair;
790 	bool ret = false;
791 
792 	down_write(&handshake->lock);
793 	if (handshake->state != HANDSHAKE_CREATED_RESPONSE &&
794 	    handshake->state != HANDSHAKE_CONSUMED_RESPONSE)
795 		goto out;
796 
797 	new_keypair = keypair_create(handshake->entry.peer);
798 	if (!new_keypair)
799 		goto out;
800 	new_keypair->i_am_the_initiator = handshake->state ==
801 					  HANDSHAKE_CONSUMED_RESPONSE;
802 	new_keypair->remote_index = handshake->remote_index;
803 
804 	if (new_keypair->i_am_the_initiator)
805 		derive_keys(&new_keypair->sending, &new_keypair->receiving,
806 			    handshake->chaining_key);
807 	else
808 		derive_keys(&new_keypair->receiving, &new_keypair->sending,
809 			    handshake->chaining_key);
810 
811 	handshake_zero(handshake);
812 	rcu_read_lock_bh();
813 	if (likely(!READ_ONCE(container_of(handshake, struct wg_peer,
814 					   handshake)->is_dead))) {
815 		add_new_keypair(keypairs, new_keypair);
816 		net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n",
817 				    handshake->entry.peer->device->dev->name,
818 				    new_keypair->internal_id,
819 				    handshake->entry.peer->internal_id);
820 		ret = wg_index_hashtable_replace(
821 			handshake->entry.peer->device->index_hashtable,
822 			&handshake->entry, &new_keypair->entry);
823 	} else {
824 		kzfree(new_keypair);
825 	}
826 	rcu_read_unlock_bh();
827 
828 out:
829 	up_write(&handshake->lock);
830 	return ret;
831 }
832