/*
 * Copyright (c) 2008-2009 Patrick McHardy <kaber@trash.net>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as
 * published by the Free Software Foundation.
 *
 * Development of this code funded by Astaro AG (http://www.astaro.com/)
 */

#include <linux/kernel.h>
#include <linux/init.h>
#include <linux/module.h>
#include <linux/list.h>
#include <linux/rbtree.h>
#include <linux/netlink.h>
#include <linux/netfilter.h>
#include <linux/netfilter/nf_tables.h>
#include <net/netfilter/nf_tables.h>

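/*
 * Per-set private data: the rb-tree itself, a rwlock serializing writers
 * against locked readers, a seqcount that lets the datapath walk the tree
 * locklessly and detect concurrent updates, and the delayed work used for
 * garbage collection of timed out elements.
 */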
struct nft_rbtree {
	struct rb_root		root;
	rwlock_t		lock;
	seqcount_t		count;
	struct delayed_work	gc_work;
};

struct nft_rbtree_elem {
	struct rb_node		node;
	struct nft_set_ext	ext;
};

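/*
 * An element closes an interval when its extension carries the
 * NFT_SET_ELEM_INTERVAL_END flag.
 */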
static bool nft_rbtree_interval_end(const struct nft_rbtree_elem *rbe)
{
	return nft_set_ext_exists(&rbe->ext, NFT_SET_EXT_FLAGS) &&
	       (*nft_set_ext_flags(&rbe->ext) & NFT_SET_ELEM_INTERVAL_END);
}

static bool nft_rbtree_equal(const struct nft_set *set, const void *this,
			     const struct nft_rbtree_elem *interval)
{
	return memcmp(this, nft_set_ext_key(&interval->ext), set->klen) == 0;
}

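/*
 * Lockless lookup walking the tree under RCU. The caller passes in its
 * seqcount snapshot; the walk bails out with false as soon as a concurrent
 * writer is detected so that the caller can retry under the read lock.
 * For interval sets, the closest start element seen on the way down is
 * remembered and used as the match if no exact key match is found.
 */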
static bool __nft_rbtree_lookup(const struct net *net, const struct nft_set *set,
				const u32 *key, const struct nft_set_ext **ext,
				unsigned int seq)
{
	struct nft_rbtree *priv = nft_set_priv(set);
	const struct nft_rbtree_elem *rbe, *interval = NULL;
	u8 genmask = nft_genmask_cur(net);
	const struct rb_node *parent;
	const void *this;
	int d;

	parent = rcu_dereference_raw(priv->root.rb_node);
	while (parent != NULL) {
		if (read_seqcount_retry(&priv->count, seq))
			return false;

		rbe = rb_entry(parent, struct nft_rbtree_elem, node);

		this = nft_set_ext_key(&rbe->ext);
		d = memcmp(this, key, set->klen);
		if (d < 0) {
			parent = rcu_dereference_raw(parent->rb_left);
			if (interval &&
			    nft_rbtree_equal(set, this, interval) &&
			    nft_rbtree_interval_end(rbe) &&
			    !nft_rbtree_interval_end(interval))
				continue;
			interval = rbe;
		} else if (d > 0)
			parent = rcu_dereference_raw(parent->rb_right);
		else {
			if (!nft_set_elem_active(&rbe->ext, genmask)) {
				parent = rcu_dereference_raw(parent->rb_left);
				continue;
			}
			if (nft_rbtree_interval_end(rbe))
				goto out;

			*ext = &rbe->ext;
			return true;
		}
	}

	if (set->flags & NFT_SET_INTERVAL && interval != NULL &&
	    nft_set_elem_active(&interval->ext, genmask) &&
	    !nft_rbtree_interval_end(interval)) {
		*ext = &interval->ext;
		return true;
	}
out:
	return false;
}

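/*
 * Datapath lookup: try the lockless walk first and only fall back to
 * taking the read lock (and redoing the walk) if a writer raced with us.
 */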
static bool nft_rbtree_lookup(const struct net *net, const struct nft_set *set,
			      const u32 *key, const struct nft_set_ext **ext)
{
	struct nft_rbtree *priv = nft_set_priv(set);
	unsigned int seq = read_seqcount_begin(&priv->count);
	bool ret;

	ret = __nft_rbtree_lookup(net, set, key, ext, seq);
	if (ret || !read_seqcount_retry(&priv->count, seq))
		return ret;

	read_lock_bh(&priv->lock);
	seq = read_seqcount_begin(&priv->count);
	ret = __nft_rbtree_lookup(net, set, key, ext, seq);
	read_unlock_bh(&priv->lock);

	return ret;
}

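/*
 * Control plane variant of the lookup used by the ->get() operation:
 * returns the element itself and matches against the
 * NFT_SET_ELEM_INTERVAL_END flag requested by the caller.
 */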
static bool __nft_rbtree_get(const struct net *net, const struct nft_set *set,
			     const u32 *key, struct nft_rbtree_elem **elem,
			     unsigned int seq, unsigned int flags, u8 genmask)
{
	struct nft_rbtree_elem *rbe, *interval = NULL;
	struct nft_rbtree *priv = nft_set_priv(set);
	const struct rb_node *parent;
	const void *this;
	int d;

	parent = rcu_dereference_raw(priv->root.rb_node);
	while (parent != NULL) {
		if (read_seqcount_retry(&priv->count, seq))
			return false;

		rbe = rb_entry(parent, struct nft_rbtree_elem, node);

		this = nft_set_ext_key(&rbe->ext);
		d = memcmp(this, key, set->klen);
		if (d < 0) {
			parent = rcu_dereference_raw(parent->rb_left);
			if (!(flags & NFT_SET_ELEM_INTERVAL_END))
				interval = rbe;
		} else if (d > 0) {
			parent = rcu_dereference_raw(parent->rb_right);
			if (flags & NFT_SET_ELEM_INTERVAL_END)
				interval = rbe;
		} else {
			if (!nft_set_elem_active(&rbe->ext, genmask)) {
				/* Skip elements that are inactive in this
				 * generation instead of falling through to
				 * the flag check below.
				 */
				parent = rcu_dereference_raw(parent->rb_left);
				continue;
			}

			if (!nft_set_ext_exists(&rbe->ext, NFT_SET_EXT_FLAGS) ||
			    (*nft_set_ext_flags(&rbe->ext) & NFT_SET_ELEM_INTERVAL_END) ==
			    (flags & NFT_SET_ELEM_INTERVAL_END)) {
				*elem = rbe;
				return true;
			}
			return false;
		}
	}

	if (set->flags & NFT_SET_INTERVAL && interval != NULL &&
	    nft_set_elem_active(&interval->ext, genmask) &&
	    ((!nft_rbtree_interval_end(interval) &&
	      !(flags & NFT_SET_ELEM_INTERVAL_END)) ||
	     (nft_rbtree_interval_end(interval) &&
	      (flags & NFT_SET_ELEM_INTERVAL_END)))) {
		*elem = interval;
		return true;
	}

	return false;
}

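/*
 * ->get() entry point: perform a lockless walk first, retrying under the
 * read lock if a concurrent update is detected. Returns the element or
 * ERR_PTR(-ENOENT) if nothing matches.
 */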
static void *nft_rbtree_get(const struct net *net, const struct nft_set *set,
			    const struct nft_set_elem *elem, unsigned int flags)
{
	struct nft_rbtree *priv = nft_set_priv(set);
	unsigned int seq = read_seqcount_begin(&priv->count);
	struct nft_rbtree_elem *rbe = ERR_PTR(-ENOENT);
	const u32 *key = (const u32 *)&elem->key.val;
	u8 genmask = nft_genmask_cur(net);
	bool ret;

	ret = __nft_rbtree_get(net, set, key, &rbe, seq, flags, genmask);
	if (ret || !read_seqcount_retry(&priv->count, seq))
		return rbe;

	read_lock_bh(&priv->lock);
	seq = read_seqcount_begin(&priv->count);
	ret = __nft_rbtree_get(net, set, key, &rbe, seq, flags, genmask);
	if (!ret)
		rbe = ERR_PTR(-ENOENT);
	read_unlock_bh(&priv->lock);

	return rbe;
}

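/*
 * Insert @new into the tree. Elements with equal keys may coexist as long
 * as one marks an interval end and the other an interval start (adjacent
 * intervals); a true duplicate that is active in the next generation is
 * rejected with -EEXIST, reporting the clashing extension via @ext.
 */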
static int __nft_rbtree_insert(const struct net *net, const struct nft_set *set,
			       struct nft_rbtree_elem *new,
			       struct nft_set_ext **ext)
{
	struct nft_rbtree *priv = nft_set_priv(set);
	u8 genmask = nft_genmask_next(net);
	struct nft_rbtree_elem *rbe;
	struct rb_node *parent, **p;
	int d;

	parent = NULL;
	p = &priv->root.rb_node;
	while (*p != NULL) {
		parent = *p;
		rbe = rb_entry(parent, struct nft_rbtree_elem, node);
		d = memcmp(nft_set_ext_key(&rbe->ext),
			   nft_set_ext_key(&new->ext),
			   set->klen);
		if (d < 0)
			p = &parent->rb_left;
		else if (d > 0)
			p = &parent->rb_right;
		else {
			if (nft_rbtree_interval_end(rbe) &&
			    !nft_rbtree_interval_end(new)) {
				p = &parent->rb_left;
			} else if (!nft_rbtree_interval_end(rbe) &&
				   nft_rbtree_interval_end(new)) {
				p = &parent->rb_right;
			} else if (nft_set_elem_active(&rbe->ext, genmask)) {
				*ext = &rbe->ext;
				return -EEXIST;
			} else {
				p = &parent->rb_left;
			}
		}
	}
	rb_link_node_rcu(&new->node, parent, p);
	rb_insert_color(&new->node, &priv->root);
	return 0;
}

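/*
 * ->insert(): serialize against other writers and lockless readers by
 * taking the write lock and bumping the seqcount around the tree update.
 */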
static int nft_rbtree_insert(const struct net *net, const struct nft_set *set,
			     const struct nft_set_elem *elem,
			     struct nft_set_ext **ext)
{
	struct nft_rbtree *priv = nft_set_priv(set);
	struct nft_rbtree_elem *rbe = elem->priv;
	int err;

	write_lock_bh(&priv->lock);
	write_seqcount_begin(&priv->count);
	err = __nft_rbtree_insert(net, set, rbe, ext);
	write_seqcount_end(&priv->count);
	write_unlock_bh(&priv->lock);

	return err;
}

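/* Unlink an element from the tree under the writer lock and seqcount. */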
static void nft_rbtree_remove(const struct net *net,
			      const struct nft_set *set,
			      const struct nft_set_elem *elem)
{
	struct nft_rbtree *priv = nft_set_priv(set);
	struct nft_rbtree_elem *rbe = elem->priv;

	write_lock_bh(&priv->lock);
	write_seqcount_begin(&priv->count);
	rb_erase(&rbe->node, &priv->root);
	write_seqcount_end(&priv->count);
	write_unlock_bh(&priv->lock);
}

static void nft_rbtree_activate(const struct net *net,
				const struct nft_set *set,
				const struct nft_set_elem *elem)
{
	struct nft_rbtree_elem *rbe = elem->priv;

	nft_set_elem_change_active(net, set, &rbe->ext);
	nft_set_elem_clear_busy(&rbe->ext);
}

static bool nft_rbtree_flush(const struct net *net,
			     const struct nft_set *set, void *priv)
{
	struct nft_rbtree_elem *rbe = priv;

	if (!nft_set_elem_mark_busy(&rbe->ext) ||
	    !nft_is_active(net, &rbe->ext)) {
		nft_set_elem_change_active(net, set, &rbe->ext);
		return true;
	}
	return false;
}

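/*
 * Find the element matching @elem (same key and same interval end/start
 * kind) that is active in the next generation and mark it inactive, so
 * that it goes away once the transaction is committed.
 */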
static void *nft_rbtree_deactivate(const struct net *net,
				   const struct nft_set *set,
				   const struct nft_set_elem *elem)
{
	const struct nft_rbtree *priv = nft_set_priv(set);
	const struct rb_node *parent = priv->root.rb_node;
	struct nft_rbtree_elem *rbe, *this = elem->priv;
	u8 genmask = nft_genmask_next(net);
	int d;

	while (parent != NULL) {
		rbe = rb_entry(parent, struct nft_rbtree_elem, node);

		d = memcmp(nft_set_ext_key(&rbe->ext), &elem->key.val,
					   set->klen);
		if (d < 0)
			parent = parent->rb_left;
		else if (d > 0)
			parent = parent->rb_right;
		else {
			if (nft_rbtree_interval_end(rbe) &&
			    !nft_rbtree_interval_end(this)) {
				parent = parent->rb_left;
				continue;
			} else if (!nft_rbtree_interval_end(rbe) &&
				   nft_rbtree_interval_end(this)) {
				parent = parent->rb_right;
				continue;
			} else if (!nft_set_elem_active(&rbe->ext, genmask)) {
				parent = parent->rb_left;
				continue;
			}
			nft_rbtree_flush(net, set, rbe);
			return rbe;
		}
	}
	return NULL;
}

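/*
 * Walk all elements in tree order, invoking the iterator callback for
 * every element that is active in the generation the iterator asks for.
 */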
static void nft_rbtree_walk(const struct nft_ctx *ctx,
			    struct nft_set *set,
			    struct nft_set_iter *iter)
{
	struct nft_rbtree *priv = nft_set_priv(set);
	struct nft_rbtree_elem *rbe;
	struct nft_set_elem elem;
	struct rb_node *node;

	read_lock_bh(&priv->lock);
	for (node = rb_first(&priv->root); node != NULL; node = rb_next(node)) {
		rbe = rb_entry(node, struct nft_rbtree_elem, node);

		if (iter->count < iter->skip)
			goto cont;
		if (!nft_set_elem_active(&rbe->ext, iter->genmask))
			goto cont;

		elem.priv = rbe;

		iter->err = iter->fn(ctx, set, iter, &elem);
		if (iter->err < 0) {
			read_unlock_bh(&priv->lock);
			return;
		}
cont:
		iter->count++;
	}
	read_unlock_bh(&priv->lock);
}

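/*
 * Periodic garbage collection: scan the tree for expired, non-busy start
 * elements, queue them (together with their interval end element) to a GC
 * batch and unlink them from the tree while holding the writer lock.
 */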
static void nft_rbtree_gc(struct work_struct *work)
{
	struct nft_rbtree_elem *rbe, *rbe_end = NULL, *rbe_prev = NULL;
	struct nft_set_gc_batch *gcb = NULL;
	struct nft_rbtree *priv;
	struct rb_node *node;
	struct nft_set *set;

	priv = container_of(work, struct nft_rbtree, gc_work.work);
	set  = nft_set_container_of(priv);

	write_lock_bh(&priv->lock);
	write_seqcount_begin(&priv->count);
	for (node = rb_first(&priv->root); node != NULL; node = rb_next(node)) {
		rbe = rb_entry(node, struct nft_rbtree_elem, node);

		if (nft_rbtree_interval_end(rbe)) {
			rbe_end = rbe;
			continue;
		}
		if (!nft_set_elem_expired(&rbe->ext))
			continue;
		if (nft_set_elem_mark_busy(&rbe->ext))
			continue;

		if (rbe_prev) {
			rb_erase(&rbe_prev->node, &priv->root);
			rbe_prev = NULL;
		}
		gcb = nft_set_gc_batch_check(set, gcb, GFP_ATOMIC);
		if (!gcb)
			break;

		atomic_dec(&set->nelems);
		nft_set_gc_batch_add(gcb, rbe);
		rbe_prev = rbe;

		if (rbe_end) {
			atomic_dec(&set->nelems);
			nft_set_gc_batch_add(gcb, rbe_end);
			rb_erase(&rbe_end->node, &priv->root);
			rbe_end = NULL;
		}
		node = rb_next(node);
		if (!node)
			break;
	}
	if (rbe_prev)
		rb_erase(&rbe_prev->node, &priv->root);
	write_seqcount_end(&priv->count);
	write_unlock_bh(&priv->lock);

	nft_set_gc_batch_complete(gcb);

	queue_delayed_work(system_power_efficient_wq, &priv->gc_work,
			   nft_set_gc_interval(set));
}

static u64 nft_rbtree_privsize(const struct nlattr * const nla[],
			       const struct nft_set_desc *desc)
{
	return sizeof(struct nft_rbtree);
}

static int nft_rbtree_init(const struct nft_set *set,
			   const struct nft_set_desc *desc,
			   const struct nlattr * const nla[])
{
	struct nft_rbtree *priv = nft_set_priv(set);

	rwlock_init(&priv->lock);
	seqcount_init(&priv->count);
	priv->root = RB_ROOT;

	INIT_DEFERRABLE_WORK(&priv->gc_work, nft_rbtree_gc);
	if (set->flags & NFT_SET_TIMEOUT)
		queue_delayed_work(system_power_efficient_wq, &priv->gc_work,
				   nft_set_gc_interval(set));

	return 0;
}

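/*
 * Tear down the set: stop the GC worker, wait for pending RCU callbacks
 * and free all remaining elements.
 */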
static void nft_rbtree_destroy(const struct nft_set *set)
{
	struct nft_rbtree *priv = nft_set_priv(set);
	struct nft_rbtree_elem *rbe;
	struct rb_node *node;

	cancel_delayed_work_sync(&priv->gc_work);
	rcu_barrier();
	while ((node = priv->root.rb_node) != NULL) {
		rb_erase(node, &priv->root);
		rbe = rb_entry(node, struct nft_rbtree_elem, node);
		nft_set_elem_destroy(set, rbe, true);
	}
}

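/*
 * Size and performance estimate used for set backend selection:
 * O(log N) lookups at the cost of O(N) memory.
 */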
static bool nft_rbtree_estimate(const struct nft_set_desc *desc, u32 features,
				struct nft_set_estimate *est)
{
	if (desc->size)
		est->size = sizeof(struct nft_rbtree) +
			    desc->size * sizeof(struct nft_rbtree_elem);
	else
		est->size = ~0;

	est->lookup = NFT_SET_CLASS_O_LOG_N;
	est->space  = NFT_SET_CLASS_O_N;

	return true;
}

struct nft_set_type nft_set_rbtree_type __read_mostly = {
	.owner		= THIS_MODULE,
	.features	= NFT_SET_INTERVAL | NFT_SET_MAP | NFT_SET_OBJECT | NFT_SET_TIMEOUT,
	.ops		= {
		.privsize	= nft_rbtree_privsize,
		.elemsize	= offsetof(struct nft_rbtree_elem, ext),
		.estimate	= nft_rbtree_estimate,
		.init		= nft_rbtree_init,
		.destroy	= nft_rbtree_destroy,
		.insert		= nft_rbtree_insert,
		.remove		= nft_rbtree_remove,
		.deactivate	= nft_rbtree_deactivate,
		.flush		= nft_rbtree_flush,
		.activate	= nft_rbtree_activate,
		.lookup		= nft_rbtree_lookup,
		.walk		= nft_rbtree_walk,
		.get		= nft_rbtree_get,
	},
};