xref: /openbmc/linux/drivers/net/wireguard/noise.c (revision 0b26ca68)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4  */
5 
6 #include "noise.h"
7 #include "device.h"
8 #include "peer.h"
9 #include "messages.h"
10 #include "queueing.h"
11 #include "peerlookup.h"
12 
13 #include <linux/rcupdate.h>
14 #include <linux/slab.h>
15 #include <linux/bitmap.h>
16 #include <linux/scatterlist.h>
17 #include <linux/highmem.h>
18 #include <crypto/algapi.h>
19 
20 /* This implements Noise_IKpsk2:
21  *
22  * <- s
23  * ******
24  * -> e, es, s, ss, {t}
25  * <- e, ee, se, psk, {}
26  */
27 
/* Noise protocol name. The array is sized to the string length without the
 * NUL terminator, so sizeof() is exactly the number of bytes hashed by
 * wg_noise_init().
 */
static const u8 handshake_name[37] = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s";
/* Identifier string mixed into the initial hash by wg_noise_init(); likewise
 * sized without the NUL terminator.
 */
static const u8 identifier_name[34] = "WireGuard v1 zx2c4 Jason@zx2c4.com";
/* Initial hash and chaining key, filled in once by wg_noise_init() and copied
 * into each handshake by handshake_init().
 */
static u8 handshake_init_hash[NOISE_HASH_LEN] __ro_after_init;
static u8 handshake_init_chaining_key[NOISE_HASH_LEN] __ro_after_init;
/* Incremented for every keypair created; supplies keypair->internal_id. */
static atomic64_t keypair_counter = ATOMIC64_INIT(0);
33 
34 void __init wg_noise_init(void)
35 {
36 	struct blake2s_state blake;
37 
38 	blake2s(handshake_init_chaining_key, handshake_name, NULL,
39 		NOISE_HASH_LEN, sizeof(handshake_name), 0);
40 	blake2s_init(&blake, NOISE_HASH_LEN);
41 	blake2s_update(&blake, handshake_init_chaining_key, NOISE_HASH_LEN);
42 	blake2s_update(&blake, identifier_name, sizeof(identifier_name));
43 	blake2s_final(&blake, handshake_init_hash);
44 }
45 
46 /* Must hold peer->handshake.static_identity->lock */
47 void wg_noise_precompute_static_static(struct wg_peer *peer)
48 {
49 	down_write(&peer->handshake.lock);
50 	if (!peer->handshake.static_identity->has_identity ||
51 	    !curve25519(peer->handshake.precomputed_static_static,
52 			peer->handshake.static_identity->static_private,
53 			peer->handshake.remote_static))
54 		memset(peer->handshake.precomputed_static_static, 0,
55 		       NOISE_PUBLIC_KEY_LEN);
56 	up_write(&peer->handshake.lock);
57 }
58 
59 void wg_noise_handshake_init(struct noise_handshake *handshake,
60 			     struct noise_static_identity *static_identity,
61 			     const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN],
62 			     const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN],
63 			     struct wg_peer *peer)
64 {
65 	memset(handshake, 0, sizeof(*handshake));
66 	init_rwsem(&handshake->lock);
67 	handshake->entry.type = INDEX_HASHTABLE_HANDSHAKE;
68 	handshake->entry.peer = peer;
69 	memcpy(handshake->remote_static, peer_public_key, NOISE_PUBLIC_KEY_LEN);
70 	if (peer_preshared_key)
71 		memcpy(handshake->preshared_key, peer_preshared_key,
72 		       NOISE_SYMMETRIC_KEY_LEN);
73 	handshake->static_identity = static_identity;
74 	handshake->state = HANDSHAKE_ZEROED;
75 	wg_noise_precompute_static_static(peer);
76 }
77 
78 static void handshake_zero(struct noise_handshake *handshake)
79 {
80 	memset(&handshake->ephemeral_private, 0, NOISE_PUBLIC_KEY_LEN);
81 	memset(&handshake->remote_ephemeral, 0, NOISE_PUBLIC_KEY_LEN);
82 	memset(&handshake->hash, 0, NOISE_HASH_LEN);
83 	memset(&handshake->chaining_key, 0, NOISE_HASH_LEN);
84 	handshake->remote_index = 0;
85 	handshake->state = HANDSHAKE_ZEROED;
86 }
87 
88 void wg_noise_handshake_clear(struct noise_handshake *handshake)
89 {
90 	down_write(&handshake->lock);
91 	wg_index_hashtable_remove(
92 			handshake->entry.peer->device->index_hashtable,
93 			&handshake->entry);
94 	handshake_zero(handshake);
95 	up_write(&handshake->lock);
96 }
97 
98 static struct noise_keypair *keypair_create(struct wg_peer *peer)
99 {
100 	struct noise_keypair *keypair = kzalloc(sizeof(*keypair), GFP_KERNEL);
101 
102 	if (unlikely(!keypair))
103 		return NULL;
104 	spin_lock_init(&keypair->receiving_counter.lock);
105 	keypair->internal_id = atomic64_inc_return(&keypair_counter);
106 	keypair->entry.type = INDEX_HASHTABLE_KEYPAIR;
107 	keypair->entry.peer = peer;
108 	kref_init(&keypair->refcount);
109 	return keypair;
110 }
111 
112 static void keypair_free_rcu(struct rcu_head *rcu)
113 {
114 	kfree_sensitive(container_of(rcu, struct noise_keypair, rcu));
115 }
116 
117 static void keypair_free_kref(struct kref *kref)
118 {
119 	struct noise_keypair *keypair =
120 		container_of(kref, struct noise_keypair, refcount);
121 
122 	net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n",
123 			    keypair->entry.peer->device->dev->name,
124 			    keypair->internal_id,
125 			    keypair->entry.peer->internal_id);
126 	wg_index_hashtable_remove(keypair->entry.peer->device->index_hashtable,
127 				  &keypair->entry);
128 	call_rcu(&keypair->rcu, keypair_free_rcu);
129 }
130 
131 void wg_noise_keypair_put(struct noise_keypair *keypair, bool unreference_now)
132 {
133 	if (unlikely(!keypair))
134 		return;
135 	if (unlikely(unreference_now))
136 		wg_index_hashtable_remove(
137 			keypair->entry.peer->device->index_hashtable,
138 			&keypair->entry);
139 	kref_put(&keypair->refcount, keypair_free_kref);
140 }
141 
142 struct noise_keypair *wg_noise_keypair_get(struct noise_keypair *keypair)
143 {
144 	RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(),
145 		"Taking noise keypair reference without holding the RCU BH read lock");
146 	if (unlikely(!keypair || !kref_get_unless_zero(&keypair->refcount)))
147 		return NULL;
148 	return keypair;
149 }
150 
151 void wg_noise_keypairs_clear(struct noise_keypairs *keypairs)
152 {
153 	struct noise_keypair *old;
154 
155 	spin_lock_bh(&keypairs->keypair_update_lock);
156 
157 	/* We zero the next_keypair before zeroing the others, so that
158 	 * wg_noise_received_with_keypair returns early before subsequent ones
159 	 * are zeroed.
160 	 */
161 	old = rcu_dereference_protected(keypairs->next_keypair,
162 		lockdep_is_held(&keypairs->keypair_update_lock));
163 	RCU_INIT_POINTER(keypairs->next_keypair, NULL);
164 	wg_noise_keypair_put(old, true);
165 
166 	old = rcu_dereference_protected(keypairs->previous_keypair,
167 		lockdep_is_held(&keypairs->keypair_update_lock));
168 	RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
169 	wg_noise_keypair_put(old, true);
170 
171 	old = rcu_dereference_protected(keypairs->current_keypair,
172 		lockdep_is_held(&keypairs->keypair_update_lock));
173 	RCU_INIT_POINTER(keypairs->current_keypair, NULL);
174 	wg_noise_keypair_put(old, true);
175 
176 	spin_unlock_bh(&keypairs->keypair_update_lock);
177 }
178 
179 void wg_noise_expire_current_peer_keypairs(struct wg_peer *peer)
180 {
181 	struct noise_keypair *keypair;
182 
183 	wg_noise_handshake_clear(&peer->handshake);
184 	wg_noise_reset_last_sent_handshake(&peer->last_sent_handshake);
185 
186 	spin_lock_bh(&peer->keypairs.keypair_update_lock);
187 	keypair = rcu_dereference_protected(peer->keypairs.next_keypair,
188 			lockdep_is_held(&peer->keypairs.keypair_update_lock));
189 	if (keypair)
190 		keypair->sending.is_valid = false;
191 	keypair = rcu_dereference_protected(peer->keypairs.current_keypair,
192 			lockdep_is_held(&peer->keypairs.keypair_update_lock));
193 	if (keypair)
194 		keypair->sending.is_valid = false;
195 	spin_unlock_bh(&peer->keypairs.keypair_update_lock);
196 }
197 
198 static void add_new_keypair(struct noise_keypairs *keypairs,
199 			    struct noise_keypair *new_keypair)
200 {
201 	struct noise_keypair *previous_keypair, *next_keypair, *current_keypair;
202 
203 	spin_lock_bh(&keypairs->keypair_update_lock);
204 	previous_keypair = rcu_dereference_protected(keypairs->previous_keypair,
205 		lockdep_is_held(&keypairs->keypair_update_lock));
206 	next_keypair = rcu_dereference_protected(keypairs->next_keypair,
207 		lockdep_is_held(&keypairs->keypair_update_lock));
208 	current_keypair = rcu_dereference_protected(keypairs->current_keypair,
209 		lockdep_is_held(&keypairs->keypair_update_lock));
210 	if (new_keypair->i_am_the_initiator) {
211 		/* If we're the initiator, it means we've sent a handshake, and
212 		 * received a confirmation response, which means this new
213 		 * keypair can now be used.
214 		 */
215 		if (next_keypair) {
216 			/* If there already was a next keypair pending, we
217 			 * demote it to be the previous keypair, and free the
218 			 * existing current. Note that this means KCI can result
219 			 * in this transition. It would perhaps be more sound to
220 			 * always just get rid of the unused next keypair
221 			 * instead of putting it in the previous slot, but this
222 			 * might be a bit less robust. Something to think about
223 			 * for the future.
224 			 */
225 			RCU_INIT_POINTER(keypairs->next_keypair, NULL);
226 			rcu_assign_pointer(keypairs->previous_keypair,
227 					   next_keypair);
228 			wg_noise_keypair_put(current_keypair, true);
229 		} else /* If there wasn't an existing next keypair, we replace
230 			* the previous with the current one.
231 			*/
232 			rcu_assign_pointer(keypairs->previous_keypair,
233 					   current_keypair);
234 		/* At this point we can get rid of the old previous keypair, and
235 		 * set up the new keypair.
236 		 */
237 		wg_noise_keypair_put(previous_keypair, true);
238 		rcu_assign_pointer(keypairs->current_keypair, new_keypair);
239 	} else {
240 		/* If we're the responder, it means we can't use the new keypair
241 		 * until we receive confirmation via the first data packet, so
242 		 * we get rid of the existing previous one, the possibly
243 		 * existing next one, and slide in the new next one.
244 		 */
245 		rcu_assign_pointer(keypairs->next_keypair, new_keypair);
246 		wg_noise_keypair_put(next_keypair, true);
247 		RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
248 		wg_noise_keypair_put(previous_keypair, true);
249 	}
250 	spin_unlock_bh(&keypairs->keypair_update_lock);
251 }
252 
253 bool wg_noise_received_with_keypair(struct noise_keypairs *keypairs,
254 				    struct noise_keypair *received_keypair)
255 {
256 	struct noise_keypair *old_keypair;
257 	bool key_is_new;
258 
259 	/* We first check without taking the spinlock. */
260 	key_is_new = received_keypair ==
261 		     rcu_access_pointer(keypairs->next_keypair);
262 	if (likely(!key_is_new))
263 		return false;
264 
265 	spin_lock_bh(&keypairs->keypair_update_lock);
266 	/* After locking, we double check that things didn't change from
267 	 * beneath us.
268 	 */
269 	if (unlikely(received_keypair !=
270 		    rcu_dereference_protected(keypairs->next_keypair,
271 			    lockdep_is_held(&keypairs->keypair_update_lock)))) {
272 		spin_unlock_bh(&keypairs->keypair_update_lock);
273 		return false;
274 	}
275 
276 	/* When we've finally received the confirmation, we slide the next
277 	 * into the current, the current into the previous, and get rid of
278 	 * the old previous.
279 	 */
280 	old_keypair = rcu_dereference_protected(keypairs->previous_keypair,
281 		lockdep_is_held(&keypairs->keypair_update_lock));
282 	rcu_assign_pointer(keypairs->previous_keypair,
283 		rcu_dereference_protected(keypairs->current_keypair,
284 			lockdep_is_held(&keypairs->keypair_update_lock)));
285 	wg_noise_keypair_put(old_keypair, true);
286 	rcu_assign_pointer(keypairs->current_keypair, received_keypair);
287 	RCU_INIT_POINTER(keypairs->next_keypair, NULL);
288 
289 	spin_unlock_bh(&keypairs->keypair_update_lock);
290 	return true;
291 }
292 
293 /* Must hold static_identity->lock */
294 void wg_noise_set_static_identity_private_key(
295 	struct noise_static_identity *static_identity,
296 	const u8 private_key[NOISE_PUBLIC_KEY_LEN])
297 {
298 	memcpy(static_identity->static_private, private_key,
299 	       NOISE_PUBLIC_KEY_LEN);
300 	curve25519_clamp_secret(static_identity->static_private);
301 	static_identity->has_identity = curve25519_generate_public(
302 		static_identity->static_public, private_key);
303 }
304 
305 /* This is Hugo Krawczyk's HKDF:
306  *  - https://eprint.iacr.org/2010/264.pdf
307  *  - https://tools.ietf.org/html/rfc5869
308  */
309 static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data,
310 		size_t first_len, size_t second_len, size_t third_len,
311 		size_t data_len, const u8 chaining_key[NOISE_HASH_LEN])
312 {
313 	u8 output[BLAKE2S_HASH_SIZE + 1];
314 	u8 secret[BLAKE2S_HASH_SIZE];
315 
316 	WARN_ON(IS_ENABLED(DEBUG) &&
317 		(first_len > BLAKE2S_HASH_SIZE ||
318 		 second_len > BLAKE2S_HASH_SIZE ||
319 		 third_len > BLAKE2S_HASH_SIZE ||
320 		 ((second_len || second_dst || third_len || third_dst) &&
321 		  (!first_len || !first_dst)) ||
322 		 ((third_len || third_dst) && (!second_len || !second_dst))));
323 
324 	/* Extract entropy from data into secret */
325 	blake2s256_hmac(secret, data, chaining_key, data_len, NOISE_HASH_LEN);
326 
327 	if (!first_dst || !first_len)
328 		goto out;
329 
330 	/* Expand first key: key = secret, data = 0x1 */
331 	output[0] = 1;
332 	blake2s256_hmac(output, output, secret, 1, BLAKE2S_HASH_SIZE);
333 	memcpy(first_dst, output, first_len);
334 
335 	if (!second_dst || !second_len)
336 		goto out;
337 
338 	/* Expand second key: key = secret, data = first-key || 0x2 */
339 	output[BLAKE2S_HASH_SIZE] = 2;
340 	blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1,
341 			BLAKE2S_HASH_SIZE);
342 	memcpy(second_dst, output, second_len);
343 
344 	if (!third_dst || !third_len)
345 		goto out;
346 
347 	/* Expand third key: key = secret, data = second-key || 0x3 */
348 	output[BLAKE2S_HASH_SIZE] = 3;
349 	blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1,
350 			BLAKE2S_HASH_SIZE);
351 	memcpy(third_dst, output, third_len);
352 
353 out:
354 	/* Clear sensitive data from stack */
355 	memzero_explicit(secret, BLAKE2S_HASH_SIZE);
356 	memzero_explicit(output, BLAKE2S_HASH_SIZE + 1);
357 }
358 
359 static void derive_keys(struct noise_symmetric_key *first_dst,
360 			struct noise_symmetric_key *second_dst,
361 			const u8 chaining_key[NOISE_HASH_LEN])
362 {
363 	u64 birthdate = ktime_get_coarse_boottime_ns();
364 	kdf(first_dst->key, second_dst->key, NULL, NULL,
365 	    NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
366 	    chaining_key);
367 	first_dst->birthdate = second_dst->birthdate = birthdate;
368 	first_dst->is_valid = second_dst->is_valid = true;
369 }
370 
371 static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],
372 				u8 key[NOISE_SYMMETRIC_KEY_LEN],
373 				const u8 private[NOISE_PUBLIC_KEY_LEN],
374 				const u8 public[NOISE_PUBLIC_KEY_LEN])
375 {
376 	u8 dh_calculation[NOISE_PUBLIC_KEY_LEN];
377 
378 	if (unlikely(!curve25519(dh_calculation, private, public)))
379 		return false;
380 	kdf(chaining_key, key, NULL, dh_calculation, NOISE_HASH_LEN,
381 	    NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key);
382 	memzero_explicit(dh_calculation, NOISE_PUBLIC_KEY_LEN);
383 	return true;
384 }
385 
386 static bool __must_check mix_precomputed_dh(u8 chaining_key[NOISE_HASH_LEN],
387 					    u8 key[NOISE_SYMMETRIC_KEY_LEN],
388 					    const u8 precomputed[NOISE_PUBLIC_KEY_LEN])
389 {
390 	static u8 zero_point[NOISE_PUBLIC_KEY_LEN];
391 	if (unlikely(!crypto_memneq(precomputed, zero_point, NOISE_PUBLIC_KEY_LEN)))
392 		return false;
393 	kdf(chaining_key, key, NULL, precomputed, NOISE_HASH_LEN,
394 	    NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
395 	    chaining_key);
396 	return true;
397 }
398 
399 static void mix_hash(u8 hash[NOISE_HASH_LEN], const u8 *src, size_t src_len)
400 {
401 	struct blake2s_state blake;
402 
403 	blake2s_init(&blake, NOISE_HASH_LEN);
404 	blake2s_update(&blake, hash, NOISE_HASH_LEN);
405 	blake2s_update(&blake, src, src_len);
406 	blake2s_final(&blake, hash);
407 }
408 
409 static void mix_psk(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN],
410 		    u8 key[NOISE_SYMMETRIC_KEY_LEN],
411 		    const u8 psk[NOISE_SYMMETRIC_KEY_LEN])
412 {
413 	u8 temp_hash[NOISE_HASH_LEN];
414 
415 	kdf(chaining_key, temp_hash, key, psk, NOISE_HASH_LEN, NOISE_HASH_LEN,
416 	    NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, chaining_key);
417 	mix_hash(hash, temp_hash, NOISE_HASH_LEN);
418 	memzero_explicit(temp_hash, NOISE_HASH_LEN);
419 }
420 
421 static void handshake_init(u8 chaining_key[NOISE_HASH_LEN],
422 			   u8 hash[NOISE_HASH_LEN],
423 			   const u8 remote_static[NOISE_PUBLIC_KEY_LEN])
424 {
425 	memcpy(hash, handshake_init_hash, NOISE_HASH_LEN);
426 	memcpy(chaining_key, handshake_init_chaining_key, NOISE_HASH_LEN);
427 	mix_hash(hash, remote_static, NOISE_PUBLIC_KEY_LEN);
428 }
429 
430 static void message_encrypt(u8 *dst_ciphertext, const u8 *src_plaintext,
431 			    size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
432 			    u8 hash[NOISE_HASH_LEN])
433 {
434 	chacha20poly1305_encrypt(dst_ciphertext, src_plaintext, src_len, hash,
435 				 NOISE_HASH_LEN,
436 				 0 /* Always zero for Noise_IK */, key);
437 	mix_hash(hash, dst_ciphertext, noise_encrypted_len(src_len));
438 }
439 
440 static bool message_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext,
441 			    size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
442 			    u8 hash[NOISE_HASH_LEN])
443 {
444 	if (!chacha20poly1305_decrypt(dst_plaintext, src_ciphertext, src_len,
445 				      hash, NOISE_HASH_LEN,
446 				      0 /* Always zero for Noise_IK */, key))
447 		return false;
448 	mix_hash(hash, src_ciphertext, src_len);
449 	return true;
450 }
451 
452 static void message_ephemeral(u8 ephemeral_dst[NOISE_PUBLIC_KEY_LEN],
453 			      const u8 ephemeral_src[NOISE_PUBLIC_KEY_LEN],
454 			      u8 chaining_key[NOISE_HASH_LEN],
455 			      u8 hash[NOISE_HASH_LEN])
456 {
457 	if (ephemeral_dst != ephemeral_src)
458 		memcpy(ephemeral_dst, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
459 	mix_hash(hash, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
460 	kdf(chaining_key, NULL, NULL, ephemeral_src, NOISE_HASH_LEN, 0, 0,
461 	    NOISE_PUBLIC_KEY_LEN, chaining_key);
462 }
463 
464 static void tai64n_now(u8 output[NOISE_TIMESTAMP_LEN])
465 {
466 	struct timespec64 now;
467 
468 	ktime_get_real_ts64(&now);
469 
470 	/* In order to prevent some sort of infoleak from precise timers, we
471 	 * round down the nanoseconds part to the closest rounded-down power of
472 	 * two to the maximum initiations per second allowed anyway by the
473 	 * implementation.
474 	 */
475 	now.tv_nsec = ALIGN_DOWN(now.tv_nsec,
476 		rounddown_pow_of_two(NSEC_PER_SEC / INITIATIONS_PER_SECOND));
477 
478 	/* https://cr.yp.to/libtai/tai64.html */
479 	*(__be64 *)output = cpu_to_be64(0x400000000000000aULL + now.tv_sec);
480 	*(__be32 *)(output + sizeof(__be64)) = cpu_to_be32(now.tv_nsec);
481 }
482 
483 bool
484 wg_noise_handshake_create_initiation(struct message_handshake_initiation *dst,
485 				     struct noise_handshake *handshake)
486 {
487 	u8 timestamp[NOISE_TIMESTAMP_LEN];
488 	u8 key[NOISE_SYMMETRIC_KEY_LEN];
489 	bool ret = false;
490 
491 	/* We need to wait for crng _before_ taking any locks, since
492 	 * curve25519_generate_secret uses get_random_bytes_wait.
493 	 */
494 	wait_for_random_bytes();
495 
496 	down_read(&handshake->static_identity->lock);
497 	down_write(&handshake->lock);
498 
499 	if (unlikely(!handshake->static_identity->has_identity))
500 		goto out;
501 
502 	dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION);
503 
504 	handshake_init(handshake->chaining_key, handshake->hash,
505 		       handshake->remote_static);
506 
507 	/* e */
508 	curve25519_generate_secret(handshake->ephemeral_private);
509 	if (!curve25519_generate_public(dst->unencrypted_ephemeral,
510 					handshake->ephemeral_private))
511 		goto out;
512 	message_ephemeral(dst->unencrypted_ephemeral,
513 			  dst->unencrypted_ephemeral, handshake->chaining_key,
514 			  handshake->hash);
515 
516 	/* es */
517 	if (!mix_dh(handshake->chaining_key, key, handshake->ephemeral_private,
518 		    handshake->remote_static))
519 		goto out;
520 
521 	/* s */
522 	message_encrypt(dst->encrypted_static,
523 			handshake->static_identity->static_public,
524 			NOISE_PUBLIC_KEY_LEN, key, handshake->hash);
525 
526 	/* ss */
527 	if (!mix_precomputed_dh(handshake->chaining_key, key,
528 				handshake->precomputed_static_static))
529 		goto out;
530 
531 	/* {t} */
532 	tai64n_now(timestamp);
533 	message_encrypt(dst->encrypted_timestamp, timestamp,
534 			NOISE_TIMESTAMP_LEN, key, handshake->hash);
535 
536 	dst->sender_index = wg_index_hashtable_insert(
537 		handshake->entry.peer->device->index_hashtable,
538 		&handshake->entry);
539 
540 	handshake->state = HANDSHAKE_CREATED_INITIATION;
541 	ret = true;
542 
543 out:
544 	up_write(&handshake->lock);
545 	up_read(&handshake->static_identity->lock);
546 	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
547 	return ret;
548 }
549 
550 struct wg_peer *
551 wg_noise_handshake_consume_initiation(struct message_handshake_initiation *src,
552 				      struct wg_device *wg)
553 {
554 	struct wg_peer *peer = NULL, *ret_peer = NULL;
555 	struct noise_handshake *handshake;
556 	bool replay_attack, flood_attack;
557 	u8 key[NOISE_SYMMETRIC_KEY_LEN];
558 	u8 chaining_key[NOISE_HASH_LEN];
559 	u8 hash[NOISE_HASH_LEN];
560 	u8 s[NOISE_PUBLIC_KEY_LEN];
561 	u8 e[NOISE_PUBLIC_KEY_LEN];
562 	u8 t[NOISE_TIMESTAMP_LEN];
563 	u64 initiation_consumption;
564 
565 	down_read(&wg->static_identity.lock);
566 	if (unlikely(!wg->static_identity.has_identity))
567 		goto out;
568 
569 	handshake_init(chaining_key, hash, wg->static_identity.static_public);
570 
571 	/* e */
572 	message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);
573 
574 	/* es */
575 	if (!mix_dh(chaining_key, key, wg->static_identity.static_private, e))
576 		goto out;
577 
578 	/* s */
579 	if (!message_decrypt(s, src->encrypted_static,
580 			     sizeof(src->encrypted_static), key, hash))
581 		goto out;
582 
583 	/* Lookup which peer we're actually talking to */
584 	peer = wg_pubkey_hashtable_lookup(wg->peer_hashtable, s);
585 	if (!peer)
586 		goto out;
587 	handshake = &peer->handshake;
588 
589 	/* ss */
590 	if (!mix_precomputed_dh(chaining_key, key,
591 				handshake->precomputed_static_static))
592 	    goto out;
593 
594 	/* {t} */
595 	if (!message_decrypt(t, src->encrypted_timestamp,
596 			     sizeof(src->encrypted_timestamp), key, hash))
597 		goto out;
598 
599 	down_read(&handshake->lock);
600 	replay_attack = memcmp(t, handshake->latest_timestamp,
601 			       NOISE_TIMESTAMP_LEN) <= 0;
602 	flood_attack = (s64)handshake->last_initiation_consumption +
603 			       NSEC_PER_SEC / INITIATIONS_PER_SECOND >
604 		       (s64)ktime_get_coarse_boottime_ns();
605 	up_read(&handshake->lock);
606 	if (replay_attack || flood_attack)
607 		goto out;
608 
609 	/* Success! Copy everything to peer */
610 	down_write(&handshake->lock);
611 	memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
612 	if (memcmp(t, handshake->latest_timestamp, NOISE_TIMESTAMP_LEN) > 0)
613 		memcpy(handshake->latest_timestamp, t, NOISE_TIMESTAMP_LEN);
614 	memcpy(handshake->hash, hash, NOISE_HASH_LEN);
615 	memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
616 	handshake->remote_index = src->sender_index;
617 	initiation_consumption = ktime_get_coarse_boottime_ns();
618 	if ((s64)(handshake->last_initiation_consumption - initiation_consumption) < 0)
619 		handshake->last_initiation_consumption = initiation_consumption;
620 	handshake->state = HANDSHAKE_CONSUMED_INITIATION;
621 	up_write(&handshake->lock);
622 	ret_peer = peer;
623 
624 out:
625 	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
626 	memzero_explicit(hash, NOISE_HASH_LEN);
627 	memzero_explicit(chaining_key, NOISE_HASH_LEN);
628 	up_read(&wg->static_identity.lock);
629 	if (!ret_peer)
630 		wg_peer_put(peer);
631 	return ret_peer;
632 }
633 
634 bool wg_noise_handshake_create_response(struct message_handshake_response *dst,
635 					struct noise_handshake *handshake)
636 {
637 	u8 key[NOISE_SYMMETRIC_KEY_LEN];
638 	bool ret = false;
639 
640 	/* We need to wait for crng _before_ taking any locks, since
641 	 * curve25519_generate_secret uses get_random_bytes_wait.
642 	 */
643 	wait_for_random_bytes();
644 
645 	down_read(&handshake->static_identity->lock);
646 	down_write(&handshake->lock);
647 
648 	if (handshake->state != HANDSHAKE_CONSUMED_INITIATION)
649 		goto out;
650 
651 	dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE);
652 	dst->receiver_index = handshake->remote_index;
653 
654 	/* e */
655 	curve25519_generate_secret(handshake->ephemeral_private);
656 	if (!curve25519_generate_public(dst->unencrypted_ephemeral,
657 					handshake->ephemeral_private))
658 		goto out;
659 	message_ephemeral(dst->unencrypted_ephemeral,
660 			  dst->unencrypted_ephemeral, handshake->chaining_key,
661 			  handshake->hash);
662 
663 	/* ee */
664 	if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
665 		    handshake->remote_ephemeral))
666 		goto out;
667 
668 	/* se */
669 	if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
670 		    handshake->remote_static))
671 		goto out;
672 
673 	/* psk */
674 	mix_psk(handshake->chaining_key, handshake->hash, key,
675 		handshake->preshared_key);
676 
677 	/* {} */
678 	message_encrypt(dst->encrypted_nothing, NULL, 0, key, handshake->hash);
679 
680 	dst->sender_index = wg_index_hashtable_insert(
681 		handshake->entry.peer->device->index_hashtable,
682 		&handshake->entry);
683 
684 	handshake->state = HANDSHAKE_CREATED_RESPONSE;
685 	ret = true;
686 
687 out:
688 	up_write(&handshake->lock);
689 	up_read(&handshake->static_identity->lock);
690 	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
691 	return ret;
692 }
693 
694 struct wg_peer *
695 wg_noise_handshake_consume_response(struct message_handshake_response *src,
696 				    struct wg_device *wg)
697 {
698 	enum noise_handshake_state state = HANDSHAKE_ZEROED;
699 	struct wg_peer *peer = NULL, *ret_peer = NULL;
700 	struct noise_handshake *handshake;
701 	u8 key[NOISE_SYMMETRIC_KEY_LEN];
702 	u8 hash[NOISE_HASH_LEN];
703 	u8 chaining_key[NOISE_HASH_LEN];
704 	u8 e[NOISE_PUBLIC_KEY_LEN];
705 	u8 ephemeral_private[NOISE_PUBLIC_KEY_LEN];
706 	u8 static_private[NOISE_PUBLIC_KEY_LEN];
707 	u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN];
708 
709 	down_read(&wg->static_identity.lock);
710 
711 	if (unlikely(!wg->static_identity.has_identity))
712 		goto out;
713 
714 	handshake = (struct noise_handshake *)wg_index_hashtable_lookup(
715 		wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE,
716 		src->receiver_index, &peer);
717 	if (unlikely(!handshake))
718 		goto out;
719 
720 	down_read(&handshake->lock);
721 	state = handshake->state;
722 	memcpy(hash, handshake->hash, NOISE_HASH_LEN);
723 	memcpy(chaining_key, handshake->chaining_key, NOISE_HASH_LEN);
724 	memcpy(ephemeral_private, handshake->ephemeral_private,
725 	       NOISE_PUBLIC_KEY_LEN);
726 	memcpy(preshared_key, handshake->preshared_key,
727 	       NOISE_SYMMETRIC_KEY_LEN);
728 	up_read(&handshake->lock);
729 
730 	if (state != HANDSHAKE_CREATED_INITIATION)
731 		goto fail;
732 
733 	/* e */
734 	message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);
735 
736 	/* ee */
737 	if (!mix_dh(chaining_key, NULL, ephemeral_private, e))
738 		goto fail;
739 
740 	/* se */
741 	if (!mix_dh(chaining_key, NULL, wg->static_identity.static_private, e))
742 		goto fail;
743 
744 	/* psk */
745 	mix_psk(chaining_key, hash, key, preshared_key);
746 
747 	/* {} */
748 	if (!message_decrypt(NULL, src->encrypted_nothing,
749 			     sizeof(src->encrypted_nothing), key, hash))
750 		goto fail;
751 
752 	/* Success! Copy everything to peer */
753 	down_write(&handshake->lock);
754 	/* It's important to check that the state is still the same, while we
755 	 * have an exclusive lock.
756 	 */
757 	if (handshake->state != state) {
758 		up_write(&handshake->lock);
759 		goto fail;
760 	}
761 	memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
762 	memcpy(handshake->hash, hash, NOISE_HASH_LEN);
763 	memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
764 	handshake->remote_index = src->sender_index;
765 	handshake->state = HANDSHAKE_CONSUMED_RESPONSE;
766 	up_write(&handshake->lock);
767 	ret_peer = peer;
768 	goto out;
769 
770 fail:
771 	wg_peer_put(peer);
772 out:
773 	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
774 	memzero_explicit(hash, NOISE_HASH_LEN);
775 	memzero_explicit(chaining_key, NOISE_HASH_LEN);
776 	memzero_explicit(ephemeral_private, NOISE_PUBLIC_KEY_LEN);
777 	memzero_explicit(static_private, NOISE_PUBLIC_KEY_LEN);
778 	memzero_explicit(preshared_key, NOISE_SYMMETRIC_KEY_LEN);
779 	up_read(&wg->static_identity.lock);
780 	return ret_peer;
781 }
782 
783 bool wg_noise_handshake_begin_session(struct noise_handshake *handshake,
784 				      struct noise_keypairs *keypairs)
785 {
786 	struct noise_keypair *new_keypair;
787 	bool ret = false;
788 
789 	down_write(&handshake->lock);
790 	if (handshake->state != HANDSHAKE_CREATED_RESPONSE &&
791 	    handshake->state != HANDSHAKE_CONSUMED_RESPONSE)
792 		goto out;
793 
794 	new_keypair = keypair_create(handshake->entry.peer);
795 	if (!new_keypair)
796 		goto out;
797 	new_keypair->i_am_the_initiator = handshake->state ==
798 					  HANDSHAKE_CONSUMED_RESPONSE;
799 	new_keypair->remote_index = handshake->remote_index;
800 
801 	if (new_keypair->i_am_the_initiator)
802 		derive_keys(&new_keypair->sending, &new_keypair->receiving,
803 			    handshake->chaining_key);
804 	else
805 		derive_keys(&new_keypair->receiving, &new_keypair->sending,
806 			    handshake->chaining_key);
807 
808 	handshake_zero(handshake);
809 	rcu_read_lock_bh();
810 	if (likely(!READ_ONCE(container_of(handshake, struct wg_peer,
811 					   handshake)->is_dead))) {
812 		add_new_keypair(keypairs, new_keypair);
813 		net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n",
814 				    handshake->entry.peer->device->dev->name,
815 				    new_keypair->internal_id,
816 				    handshake->entry.peer->internal_id);
817 		ret = wg_index_hashtable_replace(
818 			handshake->entry.peer->device->index_hashtable,
819 			&handshake->entry, &new_keypair->entry);
820 	} else {
821 		kfree_sensitive(new_keypair);
822 	}
823 	rcu_read_unlock_bh();
824 
825 out:
826 	up_write(&handshake->lock);
827 	return ret;
828 }
829