1 // SPDX-License-Identifier: GPL-2.0 2 /* 3 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. 4 */ 5 6 #include "ratelimiter.h" 7 #include <linux/siphash.h> 8 #include <linux/mm.h> 9 #include <linux/slab.h> 10 #include <net/ip.h> 11 12 static struct kmem_cache *entry_cache; 13 static hsiphash_key_t key; 14 static spinlock_t table_lock = __SPIN_LOCK_UNLOCKED("ratelimiter_table_lock"); 15 static DEFINE_MUTEX(init_lock); 16 static u64 init_refcnt; /* Protected by init_lock, hence not atomic. */ 17 static atomic_t total_entries = ATOMIC_INIT(0); 18 static unsigned int max_entries, table_size; 19 static void wg_ratelimiter_gc_entries(struct work_struct *); 20 static DECLARE_DEFERRABLE_WORK(gc_work, wg_ratelimiter_gc_entries); 21 static struct hlist_head *table_v4; 22 #if IS_ENABLED(CONFIG_IPV6) 23 static struct hlist_head *table_v6; 24 #endif 25 26 struct ratelimiter_entry { 27 u64 last_time_ns, tokens, ip; 28 void *net; 29 spinlock_t lock; 30 struct hlist_node hash; 31 struct rcu_head rcu; 32 }; 33 34 enum { 35 PACKETS_PER_SECOND = 20, 36 PACKETS_BURSTABLE = 5, 37 PACKET_COST = NSEC_PER_SEC / PACKETS_PER_SECOND, 38 TOKEN_MAX = PACKET_COST * PACKETS_BURSTABLE 39 }; 40 41 static void entry_free(struct rcu_head *rcu) 42 { 43 kmem_cache_free(entry_cache, 44 container_of(rcu, struct ratelimiter_entry, rcu)); 45 atomic_dec(&total_entries); 46 } 47 48 static void entry_uninit(struct ratelimiter_entry *entry) 49 { 50 hlist_del_rcu(&entry->hash); 51 call_rcu(&entry->rcu, entry_free); 52 } 53 54 /* Calling this function with a NULL work uninits all entries. */ 55 static void wg_ratelimiter_gc_entries(struct work_struct *work) 56 { 57 const u64 now = ktime_get_coarse_boottime_ns(); 58 struct ratelimiter_entry *entry; 59 struct hlist_node *temp; 60 unsigned int i; 61 62 for (i = 0; i < table_size; ++i) { 63 spin_lock(&table_lock); 64 hlist_for_each_entry_safe(entry, temp, &table_v4[i], hash) { 65 if (unlikely(!work) || 66 now - entry->last_time_ns > NSEC_PER_SEC) 67 entry_uninit(entry); 68 } 69 #if IS_ENABLED(CONFIG_IPV6) 70 hlist_for_each_entry_safe(entry, temp, &table_v6[i], hash) { 71 if (unlikely(!work) || 72 now - entry->last_time_ns > NSEC_PER_SEC) 73 entry_uninit(entry); 74 } 75 #endif 76 spin_unlock(&table_lock); 77 if (likely(work)) 78 cond_resched(); 79 } 80 if (likely(work)) 81 queue_delayed_work(system_power_efficient_wq, &gc_work, HZ); 82 } 83 84 bool wg_ratelimiter_allow(struct sk_buff *skb, struct net *net) 85 { 86 /* We only take the bottom half of the net pointer, so that we can hash 87 * 3 words in the end. This way, siphash's len param fits into the final 88 * u32, and we don't incur an extra round. 89 */ 90 const u32 net_word = (unsigned long)net; 91 struct ratelimiter_entry *entry; 92 struct hlist_head *bucket; 93 u64 ip; 94 95 if (skb->protocol == htons(ETH_P_IP)) { 96 ip = (u64 __force)ip_hdr(skb)->saddr; 97 bucket = &table_v4[hsiphash_2u32(net_word, ip, &key) & 98 (table_size - 1)]; 99 } 100 #if IS_ENABLED(CONFIG_IPV6) 101 else if (skb->protocol == htons(ETH_P_IPV6)) { 102 /* Only use 64 bits, so as to ratelimit the whole /64. */ 103 memcpy(&ip, &ipv6_hdr(skb)->saddr, sizeof(ip)); 104 bucket = &table_v6[hsiphash_3u32(net_word, ip >> 32, ip, &key) & 105 (table_size - 1)]; 106 } 107 #endif 108 else 109 return false; 110 rcu_read_lock(); 111 hlist_for_each_entry_rcu(entry, bucket, hash) { 112 if (entry->net == net && entry->ip == ip) { 113 u64 now, tokens; 114 bool ret; 115 /* Quasi-inspired by nft_limit.c, but this is actually a 116 * slightly different algorithm. Namely, we incorporate 117 * the burst as part of the maximum tokens, rather than 118 * as part of the rate. 119 */ 120 spin_lock(&entry->lock); 121 now = ktime_get_coarse_boottime_ns(); 122 tokens = min_t(u64, TOKEN_MAX, 123 entry->tokens + now - 124 entry->last_time_ns); 125 entry->last_time_ns = now; 126 ret = tokens >= PACKET_COST; 127 entry->tokens = ret ? tokens - PACKET_COST : tokens; 128 spin_unlock(&entry->lock); 129 rcu_read_unlock(); 130 return ret; 131 } 132 } 133 rcu_read_unlock(); 134 135 if (atomic_inc_return(&total_entries) > max_entries) 136 goto err_oom; 137 138 entry = kmem_cache_alloc(entry_cache, GFP_KERNEL); 139 if (unlikely(!entry)) 140 goto err_oom; 141 142 entry->net = net; 143 entry->ip = ip; 144 INIT_HLIST_NODE(&entry->hash); 145 spin_lock_init(&entry->lock); 146 entry->last_time_ns = ktime_get_coarse_boottime_ns(); 147 entry->tokens = TOKEN_MAX - PACKET_COST; 148 spin_lock(&table_lock); 149 hlist_add_head_rcu(&entry->hash, bucket); 150 spin_unlock(&table_lock); 151 return true; 152 153 err_oom: 154 atomic_dec(&total_entries); 155 return false; 156 } 157 158 int wg_ratelimiter_init(void) 159 { 160 mutex_lock(&init_lock); 161 if (++init_refcnt != 1) 162 goto out; 163 164 entry_cache = KMEM_CACHE(ratelimiter_entry, 0); 165 if (!entry_cache) 166 goto err; 167 168 /* xt_hashlimit.c uses a slightly different algorithm for ratelimiting, 169 * but what it shares in common is that it uses a massive hashtable. So, 170 * we borrow their wisdom about good table sizes on different systems 171 * dependent on RAM. This calculation here comes from there. 172 */ 173 table_size = (totalram_pages() > (1U << 30) / PAGE_SIZE) ? 8192 : 174 max_t(unsigned long, 16, roundup_pow_of_two( 175 (totalram_pages() << PAGE_SHIFT) / 176 (1U << 14) / sizeof(struct hlist_head))); 177 max_entries = table_size * 8; 178 179 table_v4 = kvcalloc(table_size, sizeof(*table_v4), GFP_KERNEL); 180 if (unlikely(!table_v4)) 181 goto err_kmemcache; 182 183 #if IS_ENABLED(CONFIG_IPV6) 184 table_v6 = kvcalloc(table_size, sizeof(*table_v6), GFP_KERNEL); 185 if (unlikely(!table_v6)) { 186 kvfree(table_v4); 187 goto err_kmemcache; 188 } 189 #endif 190 191 queue_delayed_work(system_power_efficient_wq, &gc_work, HZ); 192 get_random_bytes(&key, sizeof(key)); 193 out: 194 mutex_unlock(&init_lock); 195 return 0; 196 197 err_kmemcache: 198 kmem_cache_destroy(entry_cache); 199 err: 200 --init_refcnt; 201 mutex_unlock(&init_lock); 202 return -ENOMEM; 203 } 204 205 void wg_ratelimiter_uninit(void) 206 { 207 mutex_lock(&init_lock); 208 if (!init_refcnt || --init_refcnt) 209 goto out; 210 211 cancel_delayed_work_sync(&gc_work); 212 wg_ratelimiter_gc_entries(NULL); 213 rcu_barrier(); 214 kvfree(table_v4); 215 #if IS_ENABLED(CONFIG_IPV6) 216 kvfree(table_v6); 217 #endif 218 kmem_cache_destroy(entry_cache); 219 out: 220 mutex_unlock(&init_lock); 221 } 222 223 #include "selftest/ratelimiter.c" 224