1e7096c13SJason A. Donenfeld // SPDX-License-Identifier: GPL-2.0
2e7096c13SJason A. Donenfeld /*
3e7096c13SJason A. Donenfeld * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4e7096c13SJason A. Donenfeld */
5e7096c13SJason A. Donenfeld
6e7096c13SJason A. Donenfeld #include "ratelimiter.h"
7e7096c13SJason A. Donenfeld #include <linux/siphash.h>
8e7096c13SJason A. Donenfeld #include <linux/mm.h>
9e7096c13SJason A. Donenfeld #include <linux/slab.h>
10e7096c13SJason A. Donenfeld #include <net/ip.h>
11e7096c13SJason A. Donenfeld
12e7096c13SJason A. Donenfeld static struct kmem_cache *entry_cache;
13e7096c13SJason A. Donenfeld static hsiphash_key_t key;
14e7096c13SJason A. Donenfeld static spinlock_t table_lock = __SPIN_LOCK_UNLOCKED("ratelimiter_table_lock");
15e7096c13SJason A. Donenfeld static DEFINE_MUTEX(init_lock);
16e7096c13SJason A. Donenfeld static u64 init_refcnt; /* Protected by init_lock, hence not atomic. */
17e7096c13SJason A. Donenfeld static atomic_t total_entries = ATOMIC_INIT(0);
18e7096c13SJason A. Donenfeld static unsigned int max_entries, table_size;
19e7096c13SJason A. Donenfeld static void wg_ratelimiter_gc_entries(struct work_struct *);
20e7096c13SJason A. Donenfeld static DECLARE_DEFERRABLE_WORK(gc_work, wg_ratelimiter_gc_entries);
21e7096c13SJason A. Donenfeld static struct hlist_head *table_v4;
22e7096c13SJason A. Donenfeld #if IS_ENABLED(CONFIG_IPV6)
23e7096c13SJason A. Donenfeld static struct hlist_head *table_v6;
24e7096c13SJason A. Donenfeld #endif
25e7096c13SJason A. Donenfeld
26e7096c13SJason A. Donenfeld struct ratelimiter_entry {
27e7096c13SJason A. Donenfeld u64 last_time_ns, tokens, ip;
28e7096c13SJason A. Donenfeld void *net;
29e7096c13SJason A. Donenfeld spinlock_t lock;
30e7096c13SJason A. Donenfeld struct hlist_node hash;
31e7096c13SJason A. Donenfeld struct rcu_head rcu;
32e7096c13SJason A. Donenfeld };
33e7096c13SJason A. Donenfeld
34e7096c13SJason A. Donenfeld enum {
35e7096c13SJason A. Donenfeld PACKETS_PER_SECOND = 20,
36e7096c13SJason A. Donenfeld PACKETS_BURSTABLE = 5,
37e7096c13SJason A. Donenfeld PACKET_COST = NSEC_PER_SEC / PACKETS_PER_SECOND,
38e7096c13SJason A. Donenfeld TOKEN_MAX = PACKET_COST * PACKETS_BURSTABLE
39e7096c13SJason A. Donenfeld };
40e7096c13SJason A. Donenfeld
entry_free(struct rcu_head * rcu)41e7096c13SJason A. Donenfeld static void entry_free(struct rcu_head *rcu)
42e7096c13SJason A. Donenfeld {
43e7096c13SJason A. Donenfeld kmem_cache_free(entry_cache,
44e7096c13SJason A. Donenfeld container_of(rcu, struct ratelimiter_entry, rcu));
45e7096c13SJason A. Donenfeld atomic_dec(&total_entries);
46e7096c13SJason A. Donenfeld }
47e7096c13SJason A. Donenfeld
entry_uninit(struct ratelimiter_entry * entry)48e7096c13SJason A. Donenfeld static void entry_uninit(struct ratelimiter_entry *entry)
49e7096c13SJason A. Donenfeld {
50e7096c13SJason A. Donenfeld hlist_del_rcu(&entry->hash);
51e7096c13SJason A. Donenfeld call_rcu(&entry->rcu, entry_free);
52e7096c13SJason A. Donenfeld }
53e7096c13SJason A. Donenfeld
54e7096c13SJason A. Donenfeld /* Calling this function with a NULL work uninits all entries. */
wg_ratelimiter_gc_entries(struct work_struct * work)55e7096c13SJason A. Donenfeld static void wg_ratelimiter_gc_entries(struct work_struct *work)
56e7096c13SJason A. Donenfeld {
57e7096c13SJason A. Donenfeld const u64 now = ktime_get_coarse_boottime_ns();
58e7096c13SJason A. Donenfeld struct ratelimiter_entry *entry;
59e7096c13SJason A. Donenfeld struct hlist_node *temp;
60e7096c13SJason A. Donenfeld unsigned int i;
61e7096c13SJason A. Donenfeld
62e7096c13SJason A. Donenfeld for (i = 0; i < table_size; ++i) {
63e7096c13SJason A. Donenfeld spin_lock(&table_lock);
64e7096c13SJason A. Donenfeld hlist_for_each_entry_safe(entry, temp, &table_v4[i], hash) {
65e7096c13SJason A. Donenfeld if (unlikely(!work) ||
66e7096c13SJason A. Donenfeld now - entry->last_time_ns > NSEC_PER_SEC)
67e7096c13SJason A. Donenfeld entry_uninit(entry);
68e7096c13SJason A. Donenfeld }
69e7096c13SJason A. Donenfeld #if IS_ENABLED(CONFIG_IPV6)
70e7096c13SJason A. Donenfeld hlist_for_each_entry_safe(entry, temp, &table_v6[i], hash) {
71e7096c13SJason A. Donenfeld if (unlikely(!work) ||
72e7096c13SJason A. Donenfeld now - entry->last_time_ns > NSEC_PER_SEC)
73e7096c13SJason A. Donenfeld entry_uninit(entry);
74e7096c13SJason A. Donenfeld }
75e7096c13SJason A. Donenfeld #endif
76e7096c13SJason A. Donenfeld spin_unlock(&table_lock);
77e7096c13SJason A. Donenfeld if (likely(work))
78e7096c13SJason A. Donenfeld cond_resched();
79e7096c13SJason A. Donenfeld }
80e7096c13SJason A. Donenfeld if (likely(work))
81e7096c13SJason A. Donenfeld queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
82e7096c13SJason A. Donenfeld }
83e7096c13SJason A. Donenfeld
wg_ratelimiter_allow(struct sk_buff * skb,struct net * net)84e7096c13SJason A. Donenfeld bool wg_ratelimiter_allow(struct sk_buff *skb, struct net *net)
85e7096c13SJason A. Donenfeld {
86e7096c13SJason A. Donenfeld /* We only take the bottom half of the net pointer, so that we can hash
87e7096c13SJason A. Donenfeld * 3 words in the end. This way, siphash's len param fits into the final
88e7096c13SJason A. Donenfeld * u32, and we don't incur an extra round.
89e7096c13SJason A. Donenfeld */
90e7096c13SJason A. Donenfeld const u32 net_word = (unsigned long)net;
91e7096c13SJason A. Donenfeld struct ratelimiter_entry *entry;
92e7096c13SJason A. Donenfeld struct hlist_head *bucket;
93e7096c13SJason A. Donenfeld u64 ip;
94e7096c13SJason A. Donenfeld
95e7096c13SJason A. Donenfeld if (skb->protocol == htons(ETH_P_IP)) {
96e7096c13SJason A. Donenfeld ip = (u64 __force)ip_hdr(skb)->saddr;
97e7096c13SJason A. Donenfeld bucket = &table_v4[hsiphash_2u32(net_word, ip, &key) &
98e7096c13SJason A. Donenfeld (table_size - 1)];
99e7096c13SJason A. Donenfeld }
100e7096c13SJason A. Donenfeld #if IS_ENABLED(CONFIG_IPV6)
101e7096c13SJason A. Donenfeld else if (skb->protocol == htons(ETH_P_IPV6)) {
102e7096c13SJason A. Donenfeld /* Only use 64 bits, so as to ratelimit the whole /64. */
103e7096c13SJason A. Donenfeld memcpy(&ip, &ipv6_hdr(skb)->saddr, sizeof(ip));
104e7096c13SJason A. Donenfeld bucket = &table_v6[hsiphash_3u32(net_word, ip >> 32, ip, &key) &
105e7096c13SJason A. Donenfeld (table_size - 1)];
106e7096c13SJason A. Donenfeld }
107e7096c13SJason A. Donenfeld #endif
108e7096c13SJason A. Donenfeld else
109e7096c13SJason A. Donenfeld return false;
110e7096c13SJason A. Donenfeld rcu_read_lock();
111e7096c13SJason A. Donenfeld hlist_for_each_entry_rcu(entry, bucket, hash) {
112e7096c13SJason A. Donenfeld if (entry->net == net && entry->ip == ip) {
113e7096c13SJason A. Donenfeld u64 now, tokens;
114e7096c13SJason A. Donenfeld bool ret;
115e7096c13SJason A. Donenfeld /* Quasi-inspired by nft_limit.c, but this is actually a
116e7096c13SJason A. Donenfeld * slightly different algorithm. Namely, we incorporate
117e7096c13SJason A. Donenfeld * the burst as part of the maximum tokens, rather than
118e7096c13SJason A. Donenfeld * as part of the rate.
119e7096c13SJason A. Donenfeld */
120e7096c13SJason A. Donenfeld spin_lock(&entry->lock);
121e7096c13SJason A. Donenfeld now = ktime_get_coarse_boottime_ns();
122e7096c13SJason A. Donenfeld tokens = min_t(u64, TOKEN_MAX,
123e7096c13SJason A. Donenfeld entry->tokens + now -
124e7096c13SJason A. Donenfeld entry->last_time_ns);
125e7096c13SJason A. Donenfeld entry->last_time_ns = now;
126e7096c13SJason A. Donenfeld ret = tokens >= PACKET_COST;
127e7096c13SJason A. Donenfeld entry->tokens = ret ? tokens - PACKET_COST : tokens;
128e7096c13SJason A. Donenfeld spin_unlock(&entry->lock);
129e7096c13SJason A. Donenfeld rcu_read_unlock();
130e7096c13SJason A. Donenfeld return ret;
131e7096c13SJason A. Donenfeld }
132e7096c13SJason A. Donenfeld }
133e7096c13SJason A. Donenfeld rcu_read_unlock();
134e7096c13SJason A. Donenfeld
135e7096c13SJason A. Donenfeld if (atomic_inc_return(&total_entries) > max_entries)
136e7096c13SJason A. Donenfeld goto err_oom;
137e7096c13SJason A. Donenfeld
138e7096c13SJason A. Donenfeld entry = kmem_cache_alloc(entry_cache, GFP_KERNEL);
139e7096c13SJason A. Donenfeld if (unlikely(!entry))
140e7096c13SJason A. Donenfeld goto err_oom;
141e7096c13SJason A. Donenfeld
142e7096c13SJason A. Donenfeld entry->net = net;
143e7096c13SJason A. Donenfeld entry->ip = ip;
144e7096c13SJason A. Donenfeld INIT_HLIST_NODE(&entry->hash);
145e7096c13SJason A. Donenfeld spin_lock_init(&entry->lock);
146e7096c13SJason A. Donenfeld entry->last_time_ns = ktime_get_coarse_boottime_ns();
147e7096c13SJason A. Donenfeld entry->tokens = TOKEN_MAX - PACKET_COST;
148e7096c13SJason A. Donenfeld spin_lock(&table_lock);
149e7096c13SJason A. Donenfeld hlist_add_head_rcu(&entry->hash, bucket);
150e7096c13SJason A. Donenfeld spin_unlock(&table_lock);
151e7096c13SJason A. Donenfeld return true;
152e7096c13SJason A. Donenfeld
153e7096c13SJason A. Donenfeld err_oom:
154e7096c13SJason A. Donenfeld atomic_dec(&total_entries);
155e7096c13SJason A. Donenfeld return false;
156e7096c13SJason A. Donenfeld }
157e7096c13SJason A. Donenfeld
wg_ratelimiter_init(void)158e7096c13SJason A. Donenfeld int wg_ratelimiter_init(void)
159e7096c13SJason A. Donenfeld {
160e7096c13SJason A. Donenfeld mutex_lock(&init_lock);
161e7096c13SJason A. Donenfeld if (++init_refcnt != 1)
162e7096c13SJason A. Donenfeld goto out;
163e7096c13SJason A. Donenfeld
164e7096c13SJason A. Donenfeld entry_cache = KMEM_CACHE(ratelimiter_entry, 0);
165e7096c13SJason A. Donenfeld if (!entry_cache)
166e7096c13SJason A. Donenfeld goto err;
167e7096c13SJason A. Donenfeld
168e7096c13SJason A. Donenfeld /* xt_hashlimit.c uses a slightly different algorithm for ratelimiting,
169e7096c13SJason A. Donenfeld * but what it shares in common is that it uses a massive hashtable. So,
170e7096c13SJason A. Donenfeld * we borrow their wisdom about good table sizes on different systems
171e7096c13SJason A. Donenfeld * dependent on RAM. This calculation here comes from there.
172e7096c13SJason A. Donenfeld */
173e7096c13SJason A. Donenfeld table_size = (totalram_pages() > (1U << 30) / PAGE_SIZE) ? 8192 :
174e7096c13SJason A. Donenfeld max_t(unsigned long, 16, roundup_pow_of_two(
175e7096c13SJason A. Donenfeld (totalram_pages() << PAGE_SHIFT) /
176e7096c13SJason A. Donenfeld (1U << 14) / sizeof(struct hlist_head)));
177e7096c13SJason A. Donenfeld max_entries = table_size * 8;
178e7096c13SJason A. Donenfeld
179*d5f50794SGustavo A. R. Silva table_v4 = kvcalloc(table_size, sizeof(*table_v4), GFP_KERNEL);
180e7096c13SJason A. Donenfeld if (unlikely(!table_v4))
181e7096c13SJason A. Donenfeld goto err_kmemcache;
182e7096c13SJason A. Donenfeld
183e7096c13SJason A. Donenfeld #if IS_ENABLED(CONFIG_IPV6)
184*d5f50794SGustavo A. R. Silva table_v6 = kvcalloc(table_size, sizeof(*table_v6), GFP_KERNEL);
185e7096c13SJason A. Donenfeld if (unlikely(!table_v6)) {
186e7096c13SJason A. Donenfeld kvfree(table_v4);
187e7096c13SJason A. Donenfeld goto err_kmemcache;
188e7096c13SJason A. Donenfeld }
189e7096c13SJason A. Donenfeld #endif
190e7096c13SJason A. Donenfeld
191e7096c13SJason A. Donenfeld queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
192e7096c13SJason A. Donenfeld get_random_bytes(&key, sizeof(key));
193e7096c13SJason A. Donenfeld out:
194e7096c13SJason A. Donenfeld mutex_unlock(&init_lock);
195e7096c13SJason A. Donenfeld return 0;
196e7096c13SJason A. Donenfeld
197e7096c13SJason A. Donenfeld err_kmemcache:
198e7096c13SJason A. Donenfeld kmem_cache_destroy(entry_cache);
199e7096c13SJason A. Donenfeld err:
200e7096c13SJason A. Donenfeld --init_refcnt;
201e7096c13SJason A. Donenfeld mutex_unlock(&init_lock);
202e7096c13SJason A. Donenfeld return -ENOMEM;
203e7096c13SJason A. Donenfeld }
204e7096c13SJason A. Donenfeld
wg_ratelimiter_uninit(void)205e7096c13SJason A. Donenfeld void wg_ratelimiter_uninit(void)
206e7096c13SJason A. Donenfeld {
207e7096c13SJason A. Donenfeld mutex_lock(&init_lock);
208e7096c13SJason A. Donenfeld if (!init_refcnt || --init_refcnt)
209e7096c13SJason A. Donenfeld goto out;
210e7096c13SJason A. Donenfeld
211e7096c13SJason A. Donenfeld cancel_delayed_work_sync(&gc_work);
212e7096c13SJason A. Donenfeld wg_ratelimiter_gc_entries(NULL);
213e7096c13SJason A. Donenfeld rcu_barrier();
214e7096c13SJason A. Donenfeld kvfree(table_v4);
215e7096c13SJason A. Donenfeld #if IS_ENABLED(CONFIG_IPV6)
216e7096c13SJason A. Donenfeld kvfree(table_v6);
217e7096c13SJason A. Donenfeld #endif
218e7096c13SJason A. Donenfeld kmem_cache_destroy(entry_cache);
219e7096c13SJason A. Donenfeld out:
220e7096c13SJason A. Donenfeld mutex_unlock(&init_lock);
221e7096c13SJason A. Donenfeld }
222e7096c13SJason A. Donenfeld
223e7096c13SJason A. Donenfeld #include "selftest/ratelimiter.c"
224