xref: /openbmc/linux/drivers/net/wireguard/send.c (revision f97cee494dc92395a668445bcd24d34c89f4ff8c)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4  */
5 
6 #include "queueing.h"
7 #include "timers.h"
8 #include "device.h"
9 #include "peer.h"
10 #include "socket.h"
11 #include "messages.h"
12 #include "cookie.h"
13 
14 #include <linux/uio.h>
15 #include <linux/inetdevice.h>
16 #include <linux/socket.h>
17 #include <net/ip_tunnels.h>
18 #include <net/udp.h>
19 #include <net/sock.h>
20 
21 static void wg_packet_send_handshake_initiation(struct wg_peer *peer)
22 {
23 	struct message_handshake_initiation packet;
24 
25 	if (!wg_birthdate_has_expired(atomic64_read(&peer->last_sent_handshake),
26 				      REKEY_TIMEOUT))
27 		return; /* This function is rate limited. */
28 
29 	atomic64_set(&peer->last_sent_handshake, ktime_get_coarse_boottime_ns());
30 	net_dbg_ratelimited("%s: Sending handshake initiation to peer %llu (%pISpfsc)\n",
31 			    peer->device->dev->name, peer->internal_id,
32 			    &peer->endpoint.addr);
33 
34 	if (wg_noise_handshake_create_initiation(&packet, &peer->handshake)) {
35 		wg_cookie_add_mac_to_packet(&packet, sizeof(packet), peer);
36 		wg_timers_any_authenticated_packet_traversal(peer);
37 		wg_timers_any_authenticated_packet_sent(peer);
38 		atomic64_set(&peer->last_sent_handshake,
39 			     ktime_get_coarse_boottime_ns());
40 		wg_socket_send_buffer_to_peer(peer, &packet, sizeof(packet),
41 					      HANDSHAKE_DSCP);
42 		wg_timers_handshake_initiated(peer);
43 	}
44 }
45 
46 void wg_packet_handshake_send_worker(struct work_struct *work)
47 {
48 	struct wg_peer *peer = container_of(work, struct wg_peer,
49 					    transmit_handshake_work);
50 
51 	wg_packet_send_handshake_initiation(peer);
52 	wg_peer_put(peer);
53 }
54 
/* Queue a handshake initiation to @peer on the device's handshake_send_wq.
 *
 * @peer:     the peer to initiate a handshake with.
 * @is_retry: when false, peer->timer_handshake_attempts is reset to 0 before
 *            anything else happens.
 *
 * Nothing is queued if a handshake was sent less than REKEY_TIMEOUT ago, or
 * if the peer has been marked dead. Otherwise a peer reference is taken on
 * behalf of the work item; wg_packet_handshake_send_worker() releases it.
 * If the work item was already pending, that extra reference is dropped
 * here instead.
 */
void wg_packet_send_queued_handshake_initiation(struct wg_peer *peer,
						bool is_retry)
{
	if (!is_retry)
		peer->timer_handshake_attempts = 0;

	rcu_read_lock_bh();
	/* We check last_sent_handshake here in addition to the actual function
	 * we're queueing up, so that we don't queue things if not strictly
	 * necessary:
	 */
	if (!wg_birthdate_has_expired(atomic64_read(&peer->last_sent_handshake),
				      REKEY_TIMEOUT) ||
			unlikely(READ_ONCE(peer->is_dead)))
		goto out;

	wg_peer_get(peer);
	/* Queues up wg_packet_handshake_send_worker() (the handler for
	 * transmit_handshake_work), which does a wg_peer_put(peer) after
	 * sending:
	 */
	if (!queue_work(peer->device->handshake_send_wq,
			&peer->transmit_handshake_work))
		/* If the work was already queued, we want to drop the
		 * extra reference:
		 */
		wg_peer_put(peer);
out:
	rcu_read_unlock_bh();
}
84 
85 void wg_packet_send_handshake_response(struct wg_peer *peer)
86 {
87 	struct message_handshake_response packet;
88 
89 	atomic64_set(&peer->last_sent_handshake, ktime_get_coarse_boottime_ns());
90 	net_dbg_ratelimited("%s: Sending handshake response to peer %llu (%pISpfsc)\n",
91 			    peer->device->dev->name, peer->internal_id,
92 			    &peer->endpoint.addr);
93 
94 	if (wg_noise_handshake_create_response(&packet, &peer->handshake)) {
95 		wg_cookie_add_mac_to_packet(&packet, sizeof(packet), peer);
96 		if (wg_noise_handshake_begin_session(&peer->handshake,
97 						     &peer->keypairs)) {
98 			wg_timers_session_derived(peer);
99 			wg_timers_any_authenticated_packet_traversal(peer);
100 			wg_timers_any_authenticated_packet_sent(peer);
101 			atomic64_set(&peer->last_sent_handshake,
102 				     ktime_get_coarse_boottime_ns());
103 			wg_socket_send_buffer_to_peer(peer, &packet,
104 						      sizeof(packet),
105 						      HANDSHAKE_DSCP);
106 		}
107 	}
108 }
109 
110 void wg_packet_send_handshake_cookie(struct wg_device *wg,
111 				     struct sk_buff *initiating_skb,
112 				     __le32 sender_index)
113 {
114 	struct message_handshake_cookie packet;
115 
116 	net_dbg_skb_ratelimited("%s: Sending cookie response for denied handshake message for %pISpfsc\n",
117 				wg->dev->name, initiating_skb);
118 	wg_cookie_message_create(&packet, initiating_skb, sender_index,
119 				 &wg->cookie_checker);
120 	wg_socket_send_buffer_as_reply_to_skb(wg, initiating_skb, &packet,
121 					      sizeof(packet));
122 }
123 
124 static void keep_key_fresh(struct wg_peer *peer)
125 {
126 	struct noise_keypair *keypair;
127 	bool send;
128 
129 	rcu_read_lock_bh();
130 	keypair = rcu_dereference_bh(peer->keypairs.current_keypair);
131 	send = keypair && READ_ONCE(keypair->sending.is_valid) &&
132 	       (atomic64_read(&keypair->sending_counter) > REKEY_AFTER_MESSAGES ||
133 		(keypair->i_am_the_initiator &&
134 		 wg_birthdate_has_expired(keypair->sending.birthdate, REKEY_AFTER_TIME)));
135 	rcu_read_unlock_bh();
136 
137 	if (unlikely(send))
138 		wg_packet_send_queued_handshake_initiation(peer, false);
139 }
140 
141 static unsigned int calculate_skb_padding(struct sk_buff *skb)
142 {
143 	unsigned int padded_size, last_unit = skb->len;
144 
145 	if (unlikely(!PACKET_CB(skb)->mtu))
146 		return ALIGN(last_unit, MESSAGE_PADDING_MULTIPLE) - last_unit;
147 
148 	/* We do this modulo business with the MTU, just in case the networking
149 	 * layer gives us a packet that's bigger than the MTU. In that case, we
150 	 * wouldn't want the final subtraction to overflow in the case of the
151 	 * padded_size being clamped. Fortunately, that's very rarely the case,
152 	 * so we optimize for that not happening.
153 	 */
154 	if (unlikely(last_unit > PACKET_CB(skb)->mtu))
155 		last_unit %= PACKET_CB(skb)->mtu;
156 
157 	padded_size = min(PACKET_CB(skb)->mtu,
158 			  ALIGN(last_unit, MESSAGE_PADDING_MULTIPLE));
159 	return padded_size - last_unit;
160 }
161 
/* Encrypt @skb in place for transmission using @keypair's sending key.
 *
 * On success the skb contains, in order:
 *   - a cleartext struct message_data header (type MESSAGE_DATA,
 *     key_idx = keypair->remote_index, counter = PACKET_CB(skb)->nonce);
 *   - the ChaCha20-Poly1305 encryption of the original payload followed by
 *     calculate_skb_padding() zero bytes, plus room for the authentication tag.
 *
 * Returns false in any of these cases:
 *   - the tail or head of the skb could not be expanded;
 *   - it needs more scatterlist entries than sg[] holds;
 *   - finalizing a CHECKSUM_PARTIAL checksum fails;
 *   - the skb could not be mapped to a scatterlist;
 *   - encryption itself reports failure.
 */
static bool encrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair)
{
	unsigned int padding_len, plaintext_len, trailer_len;
	struct scatterlist sg[MAX_SKB_FRAGS + 8];
	struct message_data *header;
	struct sk_buff *trailer;
	int num_frags;

	/* Force hash calculation before encryption so that flow analysis is
	 * consistent over the inner packet.
	 */
	skb_get_hash(skb);

	/* Calculate lengths. */
	padding_len = calculate_skb_padding(skb);
	trailer_len = padding_len + noise_encrypted_len(0);
	plaintext_len = skb->len + padding_len;

	/* Expand data section to have room for padding and auth tag. The
	 * returned fragment count must also fit into the on-stack sg[] table.
	 */
	num_frags = skb_cow_data(skb, trailer_len, &trailer);
	if (unlikely(num_frags < 0 || num_frags > ARRAY_SIZE(sg)))
		return false;

	/* Set the padding to zeros. The padding and the auth tag only become
	 * part of the skb's length later, via pskb_put() below.
	 */
	memset(skb_tail_pointer(trailer), 0, padding_len);

	/* Expand head section to have room for our header and the network
	 * stack's headers.
	 */
	if (unlikely(skb_cow_head(skb, DATA_PACKET_HEAD_ROOM) < 0))
		return false;

	/* Finalize checksum calculation for the inner packet, if required. */
	if (unlikely(skb->ip_summed == CHECKSUM_PARTIAL &&
		     skb_checksum_help(skb)))
		return false;

	/* Only after checksumming can we safely add on the padding at the end
	 * and the header. The inner network header is recorded at the current
	 * data start, i.e. before the message_data header is pushed in front.
	 */
	skb_set_inner_network_header(skb, 0);
	header = (struct message_data *)skb_push(skb, sizeof(*header));
	header->header.type = cpu_to_le32(MESSAGE_DATA);
	header->key_idx = keypair->remote_index;
	header->counter = cpu_to_le64(PACKET_CB(skb)->nonce);
	pskb_put(skb, trailer, trailer_len);

	/* Now we can encrypt the scattergather segments. Only the region after
	 * the cleartext header is mapped: plaintext + padding + auth tag.
	 */
	sg_init_table(sg, num_frags);
	if (skb_to_sgvec(skb, sg, sizeof(struct message_data),
			 noise_encrypted_len(plaintext_len)) <= 0)
		return false;
	return chacha20poly1305_encrypt_sg_inplace(sg, plaintext_len, NULL, 0,
						   PACKET_CB(skb)->nonce,
						   keypair->sending.key);
}
220 
221 void wg_packet_send_keepalive(struct wg_peer *peer)
222 {
223 	struct sk_buff *skb;
224 
225 	if (skb_queue_empty(&peer->staged_packet_queue)) {
226 		skb = alloc_skb(DATA_PACKET_HEAD_ROOM + MESSAGE_MINIMUM_LENGTH,
227 				GFP_ATOMIC);
228 		if (unlikely(!skb))
229 			return;
230 		skb_reserve(skb, DATA_PACKET_HEAD_ROOM);
231 		skb->dev = peer->device->dev;
232 		PACKET_CB(skb)->mtu = skb->dev->mtu;
233 		skb_queue_tail(&peer->staged_packet_queue, skb);
234 		net_dbg_ratelimited("%s: Sending keepalive packet to peer %llu (%pISpfsc)\n",
235 				    peer->device->dev->name, peer->internal_id,
236 				    &peer->endpoint.addr);
237 	}
238 
239 	wg_packet_send_staged_packets(peer);
240 }
241 
242 static void wg_packet_create_data_done(struct sk_buff *first,
243 				       struct wg_peer *peer)
244 {
245 	struct sk_buff *skb, *next;
246 	bool is_keepalive, data_sent = false;
247 
248 	wg_timers_any_authenticated_packet_traversal(peer);
249 	wg_timers_any_authenticated_packet_sent(peer);
250 	skb_list_walk_safe(first, skb, next) {
251 		is_keepalive = skb->len == message_data_len(0);
252 		if (likely(!wg_socket_send_skb_to_peer(peer, skb,
253 				PACKET_CB(skb)->ds) && !is_keepalive))
254 			data_sent = true;
255 	}
256 
257 	if (likely(data_sent))
258 		wg_timers_data_sent(peer);
259 
260 	keep_key_fresh(peer);
261 }
262 
/* Transmit worker for the crypt_queue embedding @work.
 *
 * Consumes ring entries strictly in order and stops at the first batch that
 * is still PACKET_STATE_UNCRYPTED, so no batch is sent ahead of an earlier
 * one that is still being encrypted. PACKET_STATE_CRYPTED batches are
 * transmitted via wg_packet_create_data_done(); batches in any other state
 * are freed. Either way, the keypair and peer references carried by the
 * batch are dropped afterwards.
 *
 * NOTE(review): the unlocked __ptr_ring_peek()/__ptr_ring_discard_one()
 * calls assume this worker is the ring's only consumer -- confirm.
 */
void wg_packet_tx_worker(struct work_struct *work)
{
	struct crypt_queue *queue = container_of(work, struct crypt_queue,
						 work);
	struct noise_keypair *keypair;
	enum packet_state state;
	struct sk_buff *first;
	struct wg_peer *peer;

	/* Acquire read of the batch state; presumably pairs with a release
	 * store made when the encrypt side publishes the final state (not shown
	 * here), so the encrypted contents are visible before sending.
	 */
	while ((first = __ptr_ring_peek(&queue->ring)) != NULL &&
	       (state = atomic_read_acquire(&PACKET_CB(first)->state)) !=
		       PACKET_STATE_UNCRYPTED) {
		__ptr_ring_discard_one(&queue->ring);
		peer = PACKET_PEER(first);
		keypair = PACKET_CB(first)->keypair;

		if (likely(state == PACKET_STATE_CRYPTED))
			wg_packet_create_data_done(first, peer);
		else
			kfree_skb_list(first);

		/* Release the references the batch carried since it was staged. */
		wg_noise_keypair_put(keypair, false);
		wg_peer_put(peer);
		if (need_resched())
			cond_resched();
	}
}
290 
/* Encryption worker.
 *
 * Consumes batches (skb lists chained from their head skb) from the
 * crypt_queue referenced by @work's struct multicore_worker. Each skb is
 * encrypted with the keypair stored on the batch head, and the batch is then
 * handed to the owning peer's tx_queue. The batch is marked
 * PACKET_STATE_CRYPTED if every skb encrypted. It is marked
 * PACKET_STATE_DEAD as soon as one skb fails, and the remaining skbs are left
 * unencrypted.
 */
void wg_packet_encrypt_worker(struct work_struct *work)
{
	struct crypt_queue *queue = container_of(work, struct multicore_worker,
						 work)->ptr;
	struct sk_buff *first, *skb, *next;

	while ((first = ptr_ring_consume_bh(&queue->ring)) != NULL) {
		enum packet_state state = PACKET_STATE_CRYPTED;

		skb_list_walk_safe(first, skb, next) {
			if (likely(encrypt_packet(skb,
					PACKET_CB(first)->keypair))) {
				wg_reset_packet(skb, true);
			} else {
				state = PACKET_STATE_DEAD;
				break;
			}
		}
		/* Publish the batch and its final state to the peer's tx_queue. */
		wg_queue_enqueue_per_peer(&PACKET_PEER(first)->tx_queue, first,
					  state);
		if (need_resched())
			cond_resched();
	}
}
315 
/* Submit a nonce-assigned batch of packets headed by @first for encryption.
 *
 * The batch carries a keypair reference in PACKET_CB(first)->keypair and a
 * reference on PACKET_PEER(first), as set up by
 * wg_packet_send_staged_packets(). On success it is queued on both the
 * device's encrypt_queue and the peer's tx_queue.
 *
 * Error handling:
 *   - -EPIPE: the batch is queued on the peer's tx_queue as
 *     PACKET_STATE_DEAD. Its references and memory are then presumably
 *     released by wg_packet_tx_worker().
 *   - Any other error, including the peer already being dead: both
 *     references are dropped and the packets are freed here.
 */
static void wg_packet_create_data(struct sk_buff *first)
{
	struct wg_peer *peer = PACKET_PEER(first);
	struct wg_device *wg = peer->device;
	/* Non-zero default so the dead-peer path below frees the batch. */
	int ret = -EINVAL;

	rcu_read_lock_bh();
	if (unlikely(READ_ONCE(peer->is_dead)))
		goto err;

	ret = wg_queue_enqueue_per_device_and_peer(&wg->encrypt_queue,
						   &peer->tx_queue, first,
						   wg->packet_crypt_wq,
						   &wg->encrypt_queue.last_cpu);
	if (unlikely(ret == -EPIPE))
		wg_queue_enqueue_per_peer(&peer->tx_queue, first,
					  PACKET_STATE_DEAD);
err:
	rcu_read_unlock_bh();
	/* On success or -EPIPE the batch is owned by the queues now. */
	if (likely(!ret || ret == -EPIPE))
		return;
	wg_noise_keypair_put(PACKET_CB(first)->keypair, false);
	wg_peer_put(peer);
	kfree_skb_list(first);
}
341 
342 void wg_packet_purge_staged_packets(struct wg_peer *peer)
343 {
344 	spin_lock_bh(&peer->staged_packet_queue.lock);
345 	peer->device->dev->stats.tx_dropped += peer->staged_packet_queue.qlen;
346 	__skb_queue_purge(&peer->staged_packet_queue);
347 	spin_unlock_bh(&peer->staged_packet_queue.lock);
348 }
349 
/* Try to send every packet currently staged for @peer.
 *
 * The staged queue is first moved to a local list. The peer's current
 * keypair must then be present, have sending.is_valid set, and be younger
 * than REJECT_AFTER_TIME.
 *
 * Each packet is then given its outer DSCP/ECN bits and a nonce taken from
 * the keypair's sending counter. On success, the list is NULL-terminated.
 * Its head carries the keypair reference, and an extra peer reference is
 * taken. The list is then handed to wg_packet_create_data().
 *
 * On failure, the packets are orphaned and spliced back onto the staged
 * queue, and a new handshake initiation is queued. If the key itself is
 * unusable (expired, or nonces exhausted), it is first marked invalid for
 * sending.
 */
void wg_packet_send_staged_packets(struct wg_peer *peer)
{
	struct noise_keypair *keypair;
	struct sk_buff_head packets;
	struct sk_buff *skb;

	/* Steal the current queue into our local one. */
	__skb_queue_head_init(&packets);
	spin_lock_bh(&peer->staged_packet_queue.lock);
	skb_queue_splice_init(&peer->staged_packet_queue, &packets);
	spin_unlock_bh(&peer->staged_packet_queue.lock);
	if (unlikely(skb_queue_empty(&packets)))
		return;

	/* First we make sure we have a valid reference to a valid key. */
	rcu_read_lock_bh();
	keypair = wg_noise_keypair_get(
		rcu_dereference_bh(peer->keypairs.current_keypair));
	rcu_read_unlock_bh();
	if (unlikely(!keypair))
		goto out_nokey;
	if (unlikely(!READ_ONCE(keypair->sending.is_valid)))
		goto out_nokey;
	if (unlikely(wg_birthdate_has_expired(keypair->sending.birthdate,
					      REJECT_AFTER_TIME)))
		goto out_invalid;

	/* After we know we have a somewhat valid key, we now try to assign
	 * nonces to all of the packets in the queue. If we can't assign nonces
	 * for all of them, we just consider it a failure and wait for the next
	 * handshake.
	 */
	skb_queue_walk(&packets, skb) {
		/* 0 for no outer TOS: no leak. TODO: at some later point, we
		 * might consider using flowi->tos as outer instead.
		 */
		PACKET_CB(skb)->ds = ip_tunnel_ecn_encap(0, ip_hdr(skb), skb);
		/* atomic64_inc_return() yields the incremented value, so
		 * subtract 1 to use the pre-increment counter as the nonce.
		 */
		PACKET_CB(skb)->nonce =
				atomic64_inc_return(&keypair->sending_counter) - 1;
		if (unlikely(PACKET_CB(skb)->nonce >= REJECT_AFTER_MESSAGES))
			goto out_invalid;
	}

	/* Detach the packets from the local queue head as a NULL-terminated
	 * skb chain. The keypair reference moves to the chain's head, and a peer
	 * reference is taken for the batch.
	 */
	packets.prev->next = NULL;
	wg_peer_get(keypair->entry.peer);
	PACKET_CB(packets.next)->keypair = keypair;
	wg_packet_create_data(packets.next);
	return;

out_invalid:
	WRITE_ONCE(keypair->sending.is_valid, false);
out_nokey:
	/* keypair may be NULL here when coming from the !keypair check. */
	wg_noise_keypair_put(keypair, false);

	/* We orphan the packets if we're waiting on a handshake, so that they
	 * don't block a socket's pool.
	 */
	skb_queue_walk(&packets, skb)
		skb_orphan(skb);
	/* Then we put them back on the top of the queue. We're not too
	 * concerned about accidentally getting things a little out of order if
	 * packets are being added really fast, because this queue is for before
	 * packets can even be sent and it's small anyway.
	 */
	spin_lock_bh(&peer->staged_packet_queue.lock);
	skb_queue_splice(&packets, &peer->staged_packet_queue);
	spin_unlock_bh(&peer->staged_packet_queue.lock);

	/* If we're exiting because there's something wrong with the key, it
	 * means we should initiate a new handshake.
	 */
	wg_packet_send_queued_handshake_initiation(peer, false);
}
423