1e7096c13SJason A. Donenfeld // SPDX-License-Identifier: GPL-2.0
2e7096c13SJason A. Donenfeld /*
3e7096c13SJason A. Donenfeld  * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4e7096c13SJason A. Donenfeld  */
5e7096c13SJason A. Donenfeld 
6e7096c13SJason A. Donenfeld #include "ratelimiter.h"
7e7096c13SJason A. Donenfeld #include <linux/siphash.h>
8e7096c13SJason A. Donenfeld #include <linux/mm.h>
9e7096c13SJason A. Donenfeld #include <linux/slab.h>
10e7096c13SJason A. Donenfeld #include <net/ip.h>
11e7096c13SJason A. Donenfeld 
12e7096c13SJason A. Donenfeld static struct kmem_cache *entry_cache;
13e7096c13SJason A. Donenfeld static hsiphash_key_t key;
14e7096c13SJason A. Donenfeld static spinlock_t table_lock = __SPIN_LOCK_UNLOCKED("ratelimiter_table_lock");
15e7096c13SJason A. Donenfeld static DEFINE_MUTEX(init_lock);
16e7096c13SJason A. Donenfeld static u64 init_refcnt; /* Protected by init_lock, hence not atomic. */
17e7096c13SJason A. Donenfeld static atomic_t total_entries = ATOMIC_INIT(0);
18e7096c13SJason A. Donenfeld static unsigned int max_entries, table_size;
19e7096c13SJason A. Donenfeld static void wg_ratelimiter_gc_entries(struct work_struct *);
20e7096c13SJason A. Donenfeld static DECLARE_DEFERRABLE_WORK(gc_work, wg_ratelimiter_gc_entries);
21e7096c13SJason A. Donenfeld static struct hlist_head *table_v4;
22e7096c13SJason A. Donenfeld #if IS_ENABLED(CONFIG_IPV6)
23e7096c13SJason A. Donenfeld static struct hlist_head *table_v6;
24e7096c13SJason A. Donenfeld #endif
25e7096c13SJason A. Donenfeld 
26e7096c13SJason A. Donenfeld struct ratelimiter_entry {
27e7096c13SJason A. Donenfeld 	u64 last_time_ns, tokens, ip;
28e7096c13SJason A. Donenfeld 	void *net;
29e7096c13SJason A. Donenfeld 	spinlock_t lock;
30e7096c13SJason A. Donenfeld 	struct hlist_node hash;
31e7096c13SJason A. Donenfeld 	struct rcu_head rcu;
32e7096c13SJason A. Donenfeld };
33e7096c13SJason A. Donenfeld 
34e7096c13SJason A. Donenfeld enum {
35e7096c13SJason A. Donenfeld 	PACKETS_PER_SECOND = 20,
36e7096c13SJason A. Donenfeld 	PACKETS_BURSTABLE = 5,
37e7096c13SJason A. Donenfeld 	PACKET_COST = NSEC_PER_SEC / PACKETS_PER_SECOND,
38e7096c13SJason A. Donenfeld 	TOKEN_MAX = PACKET_COST * PACKETS_BURSTABLE
39e7096c13SJason A. Donenfeld };
40e7096c13SJason A. Donenfeld 
entry_free(struct rcu_head * rcu)41e7096c13SJason A. Donenfeld static void entry_free(struct rcu_head *rcu)
42e7096c13SJason A. Donenfeld {
43e7096c13SJason A. Donenfeld 	kmem_cache_free(entry_cache,
44e7096c13SJason A. Donenfeld 			container_of(rcu, struct ratelimiter_entry, rcu));
45e7096c13SJason A. Donenfeld 	atomic_dec(&total_entries);
46e7096c13SJason A. Donenfeld }
47e7096c13SJason A. Donenfeld 
entry_uninit(struct ratelimiter_entry * entry)48e7096c13SJason A. Donenfeld static void entry_uninit(struct ratelimiter_entry *entry)
49e7096c13SJason A. Donenfeld {
50e7096c13SJason A. Donenfeld 	hlist_del_rcu(&entry->hash);
51e7096c13SJason A. Donenfeld 	call_rcu(&entry->rcu, entry_free);
52e7096c13SJason A. Donenfeld }
53e7096c13SJason A. Donenfeld 
54e7096c13SJason A. Donenfeld /* Calling this function with a NULL work uninits all entries. */
wg_ratelimiter_gc_entries(struct work_struct * work)55e7096c13SJason A. Donenfeld static void wg_ratelimiter_gc_entries(struct work_struct *work)
56e7096c13SJason A. Donenfeld {
57e7096c13SJason A. Donenfeld 	const u64 now = ktime_get_coarse_boottime_ns();
58e7096c13SJason A. Donenfeld 	struct ratelimiter_entry *entry;
59e7096c13SJason A. Donenfeld 	struct hlist_node *temp;
60e7096c13SJason A. Donenfeld 	unsigned int i;
61e7096c13SJason A. Donenfeld 
62e7096c13SJason A. Donenfeld 	for (i = 0; i < table_size; ++i) {
63e7096c13SJason A. Donenfeld 		spin_lock(&table_lock);
64e7096c13SJason A. Donenfeld 		hlist_for_each_entry_safe(entry, temp, &table_v4[i], hash) {
65e7096c13SJason A. Donenfeld 			if (unlikely(!work) ||
66e7096c13SJason A. Donenfeld 			    now - entry->last_time_ns > NSEC_PER_SEC)
67e7096c13SJason A. Donenfeld 				entry_uninit(entry);
68e7096c13SJason A. Donenfeld 		}
69e7096c13SJason A. Donenfeld #if IS_ENABLED(CONFIG_IPV6)
70e7096c13SJason A. Donenfeld 		hlist_for_each_entry_safe(entry, temp, &table_v6[i], hash) {
71e7096c13SJason A. Donenfeld 			if (unlikely(!work) ||
72e7096c13SJason A. Donenfeld 			    now - entry->last_time_ns > NSEC_PER_SEC)
73e7096c13SJason A. Donenfeld 				entry_uninit(entry);
74e7096c13SJason A. Donenfeld 		}
75e7096c13SJason A. Donenfeld #endif
76e7096c13SJason A. Donenfeld 		spin_unlock(&table_lock);
77e7096c13SJason A. Donenfeld 		if (likely(work))
78e7096c13SJason A. Donenfeld 			cond_resched();
79e7096c13SJason A. Donenfeld 	}
80e7096c13SJason A. Donenfeld 	if (likely(work))
81e7096c13SJason A. Donenfeld 		queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
82e7096c13SJason A. Donenfeld }
83e7096c13SJason A. Donenfeld 
wg_ratelimiter_allow(struct sk_buff * skb,struct net * net)84e7096c13SJason A. Donenfeld bool wg_ratelimiter_allow(struct sk_buff *skb, struct net *net)
85e7096c13SJason A. Donenfeld {
86e7096c13SJason A. Donenfeld 	/* We only take the bottom half of the net pointer, so that we can hash
87e7096c13SJason A. Donenfeld 	 * 3 words in the end. This way, siphash's len param fits into the final
88e7096c13SJason A. Donenfeld 	 * u32, and we don't incur an extra round.
89e7096c13SJason A. Donenfeld 	 */
90e7096c13SJason A. Donenfeld 	const u32 net_word = (unsigned long)net;
91e7096c13SJason A. Donenfeld 	struct ratelimiter_entry *entry;
92e7096c13SJason A. Donenfeld 	struct hlist_head *bucket;
93e7096c13SJason A. Donenfeld 	u64 ip;
94e7096c13SJason A. Donenfeld 
95e7096c13SJason A. Donenfeld 	if (skb->protocol == htons(ETH_P_IP)) {
96e7096c13SJason A. Donenfeld 		ip = (u64 __force)ip_hdr(skb)->saddr;
97e7096c13SJason A. Donenfeld 		bucket = &table_v4[hsiphash_2u32(net_word, ip, &key) &
98e7096c13SJason A. Donenfeld 				   (table_size - 1)];
99e7096c13SJason A. Donenfeld 	}
100e7096c13SJason A. Donenfeld #if IS_ENABLED(CONFIG_IPV6)
101e7096c13SJason A. Donenfeld 	else if (skb->protocol == htons(ETH_P_IPV6)) {
102e7096c13SJason A. Donenfeld 		/* Only use 64 bits, so as to ratelimit the whole /64. */
103e7096c13SJason A. Donenfeld 		memcpy(&ip, &ipv6_hdr(skb)->saddr, sizeof(ip));
104e7096c13SJason A. Donenfeld 		bucket = &table_v6[hsiphash_3u32(net_word, ip >> 32, ip, &key) &
105e7096c13SJason A. Donenfeld 				   (table_size - 1)];
106e7096c13SJason A. Donenfeld 	}
107e7096c13SJason A. Donenfeld #endif
108e7096c13SJason A. Donenfeld 	else
109e7096c13SJason A. Donenfeld 		return false;
110e7096c13SJason A. Donenfeld 	rcu_read_lock();
111e7096c13SJason A. Donenfeld 	hlist_for_each_entry_rcu(entry, bucket, hash) {
112e7096c13SJason A. Donenfeld 		if (entry->net == net && entry->ip == ip) {
113e7096c13SJason A. Donenfeld 			u64 now, tokens;
114e7096c13SJason A. Donenfeld 			bool ret;
115e7096c13SJason A. Donenfeld 			/* Quasi-inspired by nft_limit.c, but this is actually a
116e7096c13SJason A. Donenfeld 			 * slightly different algorithm. Namely, we incorporate
117e7096c13SJason A. Donenfeld 			 * the burst as part of the maximum tokens, rather than
118e7096c13SJason A. Donenfeld 			 * as part of the rate.
119e7096c13SJason A. Donenfeld 			 */
120e7096c13SJason A. Donenfeld 			spin_lock(&entry->lock);
121e7096c13SJason A. Donenfeld 			now = ktime_get_coarse_boottime_ns();
122e7096c13SJason A. Donenfeld 			tokens = min_t(u64, TOKEN_MAX,
123e7096c13SJason A. Donenfeld 				       entry->tokens + now -
124e7096c13SJason A. Donenfeld 					       entry->last_time_ns);
125e7096c13SJason A. Donenfeld 			entry->last_time_ns = now;
126e7096c13SJason A. Donenfeld 			ret = tokens >= PACKET_COST;
127e7096c13SJason A. Donenfeld 			entry->tokens = ret ? tokens - PACKET_COST : tokens;
128e7096c13SJason A. Donenfeld 			spin_unlock(&entry->lock);
129e7096c13SJason A. Donenfeld 			rcu_read_unlock();
130e7096c13SJason A. Donenfeld 			return ret;
131e7096c13SJason A. Donenfeld 		}
132e7096c13SJason A. Donenfeld 	}
133e7096c13SJason A. Donenfeld 	rcu_read_unlock();
134e7096c13SJason A. Donenfeld 
135e7096c13SJason A. Donenfeld 	if (atomic_inc_return(&total_entries) > max_entries)
136e7096c13SJason A. Donenfeld 		goto err_oom;
137e7096c13SJason A. Donenfeld 
138e7096c13SJason A. Donenfeld 	entry = kmem_cache_alloc(entry_cache, GFP_KERNEL);
139e7096c13SJason A. Donenfeld 	if (unlikely(!entry))
140e7096c13SJason A. Donenfeld 		goto err_oom;
141e7096c13SJason A. Donenfeld 
142e7096c13SJason A. Donenfeld 	entry->net = net;
143e7096c13SJason A. Donenfeld 	entry->ip = ip;
144e7096c13SJason A. Donenfeld 	INIT_HLIST_NODE(&entry->hash);
145e7096c13SJason A. Donenfeld 	spin_lock_init(&entry->lock);
146e7096c13SJason A. Donenfeld 	entry->last_time_ns = ktime_get_coarse_boottime_ns();
147e7096c13SJason A. Donenfeld 	entry->tokens = TOKEN_MAX - PACKET_COST;
148e7096c13SJason A. Donenfeld 	spin_lock(&table_lock);
149e7096c13SJason A. Donenfeld 	hlist_add_head_rcu(&entry->hash, bucket);
150e7096c13SJason A. Donenfeld 	spin_unlock(&table_lock);
151e7096c13SJason A. Donenfeld 	return true;
152e7096c13SJason A. Donenfeld 
153e7096c13SJason A. Donenfeld err_oom:
154e7096c13SJason A. Donenfeld 	atomic_dec(&total_entries);
155e7096c13SJason A. Donenfeld 	return false;
156e7096c13SJason A. Donenfeld }
157e7096c13SJason A. Donenfeld 
wg_ratelimiter_init(void)158e7096c13SJason A. Donenfeld int wg_ratelimiter_init(void)
159e7096c13SJason A. Donenfeld {
160e7096c13SJason A. Donenfeld 	mutex_lock(&init_lock);
161e7096c13SJason A. Donenfeld 	if (++init_refcnt != 1)
162e7096c13SJason A. Donenfeld 		goto out;
163e7096c13SJason A. Donenfeld 
164e7096c13SJason A. Donenfeld 	entry_cache = KMEM_CACHE(ratelimiter_entry, 0);
165e7096c13SJason A. Donenfeld 	if (!entry_cache)
166e7096c13SJason A. Donenfeld 		goto err;
167e7096c13SJason A. Donenfeld 
168e7096c13SJason A. Donenfeld 	/* xt_hashlimit.c uses a slightly different algorithm for ratelimiting,
169e7096c13SJason A. Donenfeld 	 * but what it shares in common is that it uses a massive hashtable. So,
170e7096c13SJason A. Donenfeld 	 * we borrow their wisdom about good table sizes on different systems
171e7096c13SJason A. Donenfeld 	 * dependent on RAM. This calculation here comes from there.
172e7096c13SJason A. Donenfeld 	 */
173e7096c13SJason A. Donenfeld 	table_size = (totalram_pages() > (1U << 30) / PAGE_SIZE) ? 8192 :
174e7096c13SJason A. Donenfeld 		max_t(unsigned long, 16, roundup_pow_of_two(
175e7096c13SJason A. Donenfeld 			(totalram_pages() << PAGE_SHIFT) /
176e7096c13SJason A. Donenfeld 			(1U << 14) / sizeof(struct hlist_head)));
177e7096c13SJason A. Donenfeld 	max_entries = table_size * 8;
178e7096c13SJason A. Donenfeld 
179*d5f50794SGustavo A. R. Silva 	table_v4 = kvcalloc(table_size, sizeof(*table_v4), GFP_KERNEL);
180e7096c13SJason A. Donenfeld 	if (unlikely(!table_v4))
181e7096c13SJason A. Donenfeld 		goto err_kmemcache;
182e7096c13SJason A. Donenfeld 
183e7096c13SJason A. Donenfeld #if IS_ENABLED(CONFIG_IPV6)
184*d5f50794SGustavo A. R. Silva 	table_v6 = kvcalloc(table_size, sizeof(*table_v6), GFP_KERNEL);
185e7096c13SJason A. Donenfeld 	if (unlikely(!table_v6)) {
186e7096c13SJason A. Donenfeld 		kvfree(table_v4);
187e7096c13SJason A. Donenfeld 		goto err_kmemcache;
188e7096c13SJason A. Donenfeld 	}
189e7096c13SJason A. Donenfeld #endif
190e7096c13SJason A. Donenfeld 
191e7096c13SJason A. Donenfeld 	queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
192e7096c13SJason A. Donenfeld 	get_random_bytes(&key, sizeof(key));
193e7096c13SJason A. Donenfeld out:
194e7096c13SJason A. Donenfeld 	mutex_unlock(&init_lock);
195e7096c13SJason A. Donenfeld 	return 0;
196e7096c13SJason A. Donenfeld 
197e7096c13SJason A. Donenfeld err_kmemcache:
198e7096c13SJason A. Donenfeld 	kmem_cache_destroy(entry_cache);
199e7096c13SJason A. Donenfeld err:
200e7096c13SJason A. Donenfeld 	--init_refcnt;
201e7096c13SJason A. Donenfeld 	mutex_unlock(&init_lock);
202e7096c13SJason A. Donenfeld 	return -ENOMEM;
203e7096c13SJason A. Donenfeld }
204e7096c13SJason A. Donenfeld 
wg_ratelimiter_uninit(void)205e7096c13SJason A. Donenfeld void wg_ratelimiter_uninit(void)
206e7096c13SJason A. Donenfeld {
207e7096c13SJason A. Donenfeld 	mutex_lock(&init_lock);
208e7096c13SJason A. Donenfeld 	if (!init_refcnt || --init_refcnt)
209e7096c13SJason A. Donenfeld 		goto out;
210e7096c13SJason A. Donenfeld 
211e7096c13SJason A. Donenfeld 	cancel_delayed_work_sync(&gc_work);
212e7096c13SJason A. Donenfeld 	wg_ratelimiter_gc_entries(NULL);
213e7096c13SJason A. Donenfeld 	rcu_barrier();
214e7096c13SJason A. Donenfeld 	kvfree(table_v4);
215e7096c13SJason A. Donenfeld #if IS_ENABLED(CONFIG_IPV6)
216e7096c13SJason A. Donenfeld 	kvfree(table_v6);
217e7096c13SJason A. Donenfeld #endif
218e7096c13SJason A. Donenfeld 	kmem_cache_destroy(entry_cache);
219e7096c13SJason A. Donenfeld out:
220e7096c13SJason A. Donenfeld 	mutex_unlock(&init_lock);
221e7096c13SJason A. Donenfeld }
222e7096c13SJason A. Donenfeld 
223e7096c13SJason A. Donenfeld #include "selftest/ratelimiter.c"
224