1 // SPDX-License-Identifier: GPL-2.0 2 /* 3 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. 4 */ 5 6 #include "allowedips.h" 7 #include "peer.h" 8 9 static struct kmem_cache *node_cache; 10 11 static void swap_endian(u8 *dst, const u8 *src, u8 bits) 12 { 13 if (bits == 32) { 14 *(u32 *)dst = be32_to_cpu(*(const __be32 *)src); 15 } else if (bits == 128) { 16 ((u64 *)dst)[0] = be64_to_cpu(((const __be64 *)src)[0]); 17 ((u64 *)dst)[1] = be64_to_cpu(((const __be64 *)src)[1]); 18 } 19 } 20 21 static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src, 22 u8 cidr, u8 bits) 23 { 24 node->cidr = cidr; 25 node->bit_at_a = cidr / 8U; 26 #ifdef __LITTLE_ENDIAN 27 node->bit_at_a ^= (bits / 8U - 1U) % 8U; 28 #endif 29 node->bit_at_b = 7U - (cidr % 8U); 30 node->bitlen = bits; 31 memcpy(node->bits, src, bits / 8U); 32 } 33 34 static inline u8 choose(struct allowedips_node *node, const u8 *key) 35 { 36 return (key[node->bit_at_a] >> node->bit_at_b) & 1; 37 } 38 39 static void push_rcu(struct allowedips_node **stack, 40 struct allowedips_node __rcu *p, unsigned int *len) 41 { 42 if (rcu_access_pointer(p)) { 43 WARN_ON(IS_ENABLED(DEBUG) && *len >= 128); 44 stack[(*len)++] = rcu_dereference_raw(p); 45 } 46 } 47 48 static void node_free_rcu(struct rcu_head *rcu) 49 { 50 kmem_cache_free(node_cache, container_of(rcu, struct allowedips_node, rcu)); 51 } 52 53 static void root_free_rcu(struct rcu_head *rcu) 54 { 55 struct allowedips_node *node, *stack[128] = { 56 container_of(rcu, struct allowedips_node, rcu) }; 57 unsigned int len = 1; 58 59 while (len > 0 && (node = stack[--len])) { 60 push_rcu(stack, node->bit[0], &len); 61 push_rcu(stack, node->bit[1], &len); 62 kmem_cache_free(node_cache, node); 63 } 64 } 65 66 static void root_remove_peer_lists(struct allowedips_node *root) 67 { 68 struct allowedips_node *node, *stack[128] = { root }; 69 unsigned int len = 1; 70 71 while (len > 0 && (node = stack[--len])) { 72 push_rcu(stack, node->bit[0], &len); 73 push_rcu(stack, node->bit[1], &len); 74 if (rcu_access_pointer(node->peer)) 75 list_del(&node->peer_list); 76 } 77 } 78 79 static unsigned int fls128(u64 a, u64 b) 80 { 81 return a ? fls64(a) + 64U : fls64(b); 82 } 83 84 static u8 common_bits(const struct allowedips_node *node, const u8 *key, 85 u8 bits) 86 { 87 if (bits == 32) 88 return 32U - fls(*(const u32 *)node->bits ^ *(const u32 *)key); 89 else if (bits == 128) 90 return 128U - fls128( 91 *(const u64 *)&node->bits[0] ^ *(const u64 *)&key[0], 92 *(const u64 *)&node->bits[8] ^ *(const u64 *)&key[8]); 93 return 0; 94 } 95 96 static bool prefix_matches(const struct allowedips_node *node, const u8 *key, 97 u8 bits) 98 { 99 /* This could be much faster if it actually just compared the common 100 * bits properly, by precomputing a mask bswap(~0 << (32 - cidr)), and 101 * the rest, but it turns out that common_bits is already super fast on 102 * modern processors, even taking into account the unfortunate bswap. 103 * So, we just inline it like this instead. 104 */ 105 return common_bits(node, key, bits) >= node->cidr; 106 } 107 108 static struct allowedips_node *find_node(struct allowedips_node *trie, u8 bits, 109 const u8 *key) 110 { 111 struct allowedips_node *node = trie, *found = NULL; 112 113 while (node && prefix_matches(node, key, bits)) { 114 if (rcu_access_pointer(node->peer)) 115 found = node; 116 if (node->cidr == bits) 117 break; 118 node = rcu_dereference_bh(node->bit[choose(node, key)]); 119 } 120 return found; 121 } 122 123 /* Returns a strong reference to a peer */ 124 static struct wg_peer *lookup(struct allowedips_node __rcu *root, u8 bits, 125 const void *be_ip) 126 { 127 /* Aligned so it can be passed to fls/fls64 */ 128 u8 ip[16] __aligned(__alignof(u64)); 129 struct allowedips_node *node; 130 struct wg_peer *peer = NULL; 131 132 swap_endian(ip, be_ip, bits); 133 134 rcu_read_lock_bh(); 135 retry: 136 node = find_node(rcu_dereference_bh(root), bits, ip); 137 if (node) { 138 peer = wg_peer_get_maybe_zero(rcu_dereference_bh(node->peer)); 139 if (!peer) 140 goto retry; 141 } 142 rcu_read_unlock_bh(); 143 return peer; 144 } 145 146 static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key, 147 u8 cidr, u8 bits, struct allowedips_node **rnode, 148 struct mutex *lock) 149 { 150 struct allowedips_node *node = rcu_dereference_protected(trie, lockdep_is_held(lock)); 151 struct allowedips_node *parent = NULL; 152 bool exact = false; 153 154 while (node && node->cidr <= cidr && prefix_matches(node, key, bits)) { 155 parent = node; 156 if (parent->cidr == cidr) { 157 exact = true; 158 break; 159 } 160 node = rcu_dereference_protected(parent->bit[choose(parent, key)], lockdep_is_held(lock)); 161 } 162 *rnode = parent; 163 return exact; 164 } 165 166 static inline void connect_node(struct allowedips_node **parent, u8 bit, struct allowedips_node *node) 167 { 168 node->parent_bit_packed = (unsigned long)parent | bit; 169 rcu_assign_pointer(*parent, node); 170 } 171 172 static inline void choose_and_connect_node(struct allowedips_node *parent, struct allowedips_node *node) 173 { 174 u8 bit = choose(parent, node->bits); 175 connect_node(&parent->bit[bit], bit, node); 176 } 177 178 static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, 179 u8 cidr, struct wg_peer *peer, struct mutex *lock) 180 { 181 struct allowedips_node *node, *parent, *down, *newnode; 182 183 if (unlikely(cidr > bits || !peer)) 184 return -EINVAL; 185 186 if (!rcu_access_pointer(*trie)) { 187 node = kmem_cache_zalloc(node_cache, GFP_KERNEL); 188 if (unlikely(!node)) 189 return -ENOMEM; 190 RCU_INIT_POINTER(node->peer, peer); 191 list_add_tail(&node->peer_list, &peer->allowedips_list); 192 copy_and_assign_cidr(node, key, cidr, bits); 193 connect_node(trie, 2, node); 194 return 0; 195 } 196 if (node_placement(*trie, key, cidr, bits, &node, lock)) { 197 rcu_assign_pointer(node->peer, peer); 198 list_move_tail(&node->peer_list, &peer->allowedips_list); 199 return 0; 200 } 201 202 newnode = kmem_cache_zalloc(node_cache, GFP_KERNEL); 203 if (unlikely(!newnode)) 204 return -ENOMEM; 205 RCU_INIT_POINTER(newnode->peer, peer); 206 list_add_tail(&newnode->peer_list, &peer->allowedips_list); 207 copy_and_assign_cidr(newnode, key, cidr, bits); 208 209 if (!node) { 210 down = rcu_dereference_protected(*trie, lockdep_is_held(lock)); 211 } else { 212 const u8 bit = choose(node, key); 213 down = rcu_dereference_protected(node->bit[bit], lockdep_is_held(lock)); 214 if (!down) { 215 connect_node(&node->bit[bit], bit, newnode); 216 return 0; 217 } 218 } 219 cidr = min(cidr, common_bits(down, key, bits)); 220 parent = node; 221 222 if (newnode->cidr == cidr) { 223 choose_and_connect_node(newnode, down); 224 if (!parent) 225 connect_node(trie, 2, newnode); 226 else 227 choose_and_connect_node(parent, newnode); 228 return 0; 229 } 230 231 node = kmem_cache_zalloc(node_cache, GFP_KERNEL); 232 if (unlikely(!node)) { 233 list_del(&newnode->peer_list); 234 kmem_cache_free(node_cache, newnode); 235 return -ENOMEM; 236 } 237 INIT_LIST_HEAD(&node->peer_list); 238 copy_and_assign_cidr(node, newnode->bits, cidr, bits); 239 240 choose_and_connect_node(node, down); 241 choose_and_connect_node(node, newnode); 242 if (!parent) 243 connect_node(trie, 2, node); 244 else 245 choose_and_connect_node(parent, node); 246 return 0; 247 } 248 249 void wg_allowedips_init(struct allowedips *table) 250 { 251 table->root4 = table->root6 = NULL; 252 table->seq = 1; 253 } 254 255 void wg_allowedips_free(struct allowedips *table, struct mutex *lock) 256 { 257 struct allowedips_node __rcu *old4 = table->root4, *old6 = table->root6; 258 259 ++table->seq; 260 RCU_INIT_POINTER(table->root4, NULL); 261 RCU_INIT_POINTER(table->root6, NULL); 262 if (rcu_access_pointer(old4)) { 263 struct allowedips_node *node = rcu_dereference_protected(old4, 264 lockdep_is_held(lock)); 265 266 root_remove_peer_lists(node); 267 call_rcu(&node->rcu, root_free_rcu); 268 } 269 if (rcu_access_pointer(old6)) { 270 struct allowedips_node *node = rcu_dereference_protected(old6, 271 lockdep_is_held(lock)); 272 273 root_remove_peer_lists(node); 274 call_rcu(&node->rcu, root_free_rcu); 275 } 276 } 277 278 int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip, 279 u8 cidr, struct wg_peer *peer, struct mutex *lock) 280 { 281 /* Aligned so it can be passed to fls */ 282 u8 key[4] __aligned(__alignof(u32)); 283 284 ++table->seq; 285 swap_endian(key, (const u8 *)ip, 32); 286 return add(&table->root4, 32, key, cidr, peer, lock); 287 } 288 289 int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, 290 u8 cidr, struct wg_peer *peer, struct mutex *lock) 291 { 292 /* Aligned so it can be passed to fls64 */ 293 u8 key[16] __aligned(__alignof(u64)); 294 295 ++table->seq; 296 swap_endian(key, (const u8 *)ip, 128); 297 return add(&table->root6, 128, key, cidr, peer, lock); 298 } 299 300 void wg_allowedips_remove_by_peer(struct allowedips *table, 301 struct wg_peer *peer, struct mutex *lock) 302 { 303 struct allowedips_node *node, *child, **parent_bit, *parent, *tmp; 304 bool free_parent; 305 306 if (list_empty(&peer->allowedips_list)) 307 return; 308 ++table->seq; 309 list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list) { 310 list_del_init(&node->peer_list); 311 RCU_INIT_POINTER(node->peer, NULL); 312 if (node->bit[0] && node->bit[1]) 313 continue; 314 child = rcu_dereference_protected(node->bit[!rcu_access_pointer(node->bit[0])], 315 lockdep_is_held(lock)); 316 if (child) 317 child->parent_bit_packed = node->parent_bit_packed; 318 parent_bit = (struct allowedips_node **)(node->parent_bit_packed & ~3UL); 319 *parent_bit = child; 320 parent = (void *)parent_bit - 321 offsetof(struct allowedips_node, bit[node->parent_bit_packed & 1]); 322 free_parent = !rcu_access_pointer(node->bit[0]) && 323 !rcu_access_pointer(node->bit[1]) && 324 (node->parent_bit_packed & 3) <= 1 && 325 !rcu_access_pointer(parent->peer); 326 if (free_parent) 327 child = rcu_dereference_protected( 328 parent->bit[!(node->parent_bit_packed & 1)], 329 lockdep_is_held(lock)); 330 call_rcu(&node->rcu, node_free_rcu); 331 if (!free_parent) 332 continue; 333 if (child) 334 child->parent_bit_packed = parent->parent_bit_packed; 335 *(struct allowedips_node **)(parent->parent_bit_packed & ~3UL) = child; 336 call_rcu(&parent->rcu, node_free_rcu); 337 } 338 } 339 340 int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr) 341 { 342 const unsigned int cidr_bytes = DIV_ROUND_UP(node->cidr, 8U); 343 swap_endian(ip, node->bits, node->bitlen); 344 memset(ip + cidr_bytes, 0, node->bitlen / 8U - cidr_bytes); 345 if (node->cidr) 346 ip[cidr_bytes - 1U] &= ~0U << (-node->cidr % 8U); 347 348 *cidr = node->cidr; 349 return node->bitlen == 32 ? AF_INET : AF_INET6; 350 } 351 352 /* Returns a strong reference to a peer */ 353 struct wg_peer *wg_allowedips_lookup_dst(struct allowedips *table, 354 struct sk_buff *skb) 355 { 356 if (skb->protocol == htons(ETH_P_IP)) 357 return lookup(table->root4, 32, &ip_hdr(skb)->daddr); 358 else if (skb->protocol == htons(ETH_P_IPV6)) 359 return lookup(table->root6, 128, &ipv6_hdr(skb)->daddr); 360 return NULL; 361 } 362 363 /* Returns a strong reference to a peer */ 364 struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table, 365 struct sk_buff *skb) 366 { 367 if (skb->protocol == htons(ETH_P_IP)) 368 return lookup(table->root4, 32, &ip_hdr(skb)->saddr); 369 else if (skb->protocol == htons(ETH_P_IPV6)) 370 return lookup(table->root6, 128, &ipv6_hdr(skb)->saddr); 371 return NULL; 372 } 373 374 int __init wg_allowedips_slab_init(void) 375 { 376 node_cache = KMEM_CACHE(allowedips_node, 0); 377 return node_cache ? 0 : -ENOMEM; 378 } 379 380 void wg_allowedips_slab_uninit(void) 381 { 382 rcu_barrier(); 383 kmem_cache_destroy(node_cache); 384 } 385 386 #include "selftest/allowedips.c" 387