1 /* SPDX-License-Identifier: GPL-2.0-only */
2 /* Copyright (C) 2013 Jozsef Kadlecsik <kadlec@netfilter.org> */
3 
4 #ifndef __IP_SET_BITMAP_IP_GEN_H
5 #define __IP_SET_BITMAP_IP_GEN_H
6 
/* Generic names used throughout this file. Each one expands through
 * IPSET_TOKEN() to an identifier built from MTYPE and the given suffix,
 * so MTYPE must be defined before this file is included.
 * NOTE(review): IPSET_TOKEN is presumably a token-pasting helper - confirm
 * against its definition.
 */
#define mtype_do_test		IPSET_TOKEN(MTYPE, _do_test)
#define mtype_gc_test		IPSET_TOKEN(MTYPE, _gc_test)
#define mtype_is_filled		IPSET_TOKEN(MTYPE, _is_filled)
#define mtype_do_add		IPSET_TOKEN(MTYPE, _do_add)
#define mtype_ext_cleanup	IPSET_TOKEN(MTYPE, _ext_cleanup)
#define mtype_do_del		IPSET_TOKEN(MTYPE, _do_del)
#define mtype_do_list		IPSET_TOKEN(MTYPE, _do_list)
#define mtype_do_head		IPSET_TOKEN(MTYPE, _do_head)
#define mtype_adt_elem		IPSET_TOKEN(MTYPE, _adt_elem)
#define mtype_add_timeout	IPSET_TOKEN(MTYPE, _add_timeout)
#define mtype_gc_init		IPSET_TOKEN(MTYPE, _gc_init)
#define mtype_kadt		IPSET_TOKEN(MTYPE, _kadt)
#define mtype_uadt		IPSET_TOKEN(MTYPE, _uadt)
#define mtype_destroy		IPSET_TOKEN(MTYPE, _destroy)
#define mtype_memsize		IPSET_TOKEN(MTYPE, _memsize)
#define mtype_flush		IPSET_TOKEN(MTYPE, _flush)
#define mtype_head		IPSET_TOKEN(MTYPE, _head)
#define mtype_same_set		IPSET_TOKEN(MTYPE, _same_set)
#define mtype_elem		IPSET_TOKEN(MTYPE, _elem)
#define mtype_test		IPSET_TOKEN(MTYPE, _test)
#define mtype_add		IPSET_TOKEN(MTYPE, _add)
#define mtype_del		IPSET_TOKEN(MTYPE, _del)
#define mtype_list		IPSET_TOKEN(MTYPE, _list)
#define mtype_gc		IPSET_TOKEN(MTYPE, _gc)
/* The bare name "mtype" is MTYPE itself (used for the struct tag and for
 * the variant table at the end of the file).
 */
#define mtype			MTYPE

/* Address of the extension area of element @id: the base of the
 * extensions array plus @id times the per-element size (set)->dsize.
 */
#define get_ext(set, map, id)	((map)->extensions + ((set)->dsize * (id)))
34 
35 static void
36 mtype_gc_init(struct ip_set *set, void (*gc)(struct timer_list *t))
37 {
38 	struct mtype *map = set->data;
39 
40 	timer_setup(&map->gc, gc, 0);
41 	mod_timer(&map->gc, jiffies + IPSET_GC_PERIOD(set->timeout) * HZ);
42 }
43 
44 static void
45 mtype_ext_cleanup(struct ip_set *set)
46 {
47 	struct mtype *map = set->data;
48 	u32 id;
49 
50 	for (id = 0; id < map->elements; id++)
51 		if (test_bit(id, map->members))
52 			ip_set_ext_destroy(set, get_ext(set, map, id));
53 }
54 
55 static void
56 mtype_destroy(struct ip_set *set)
57 {
58 	struct mtype *map = set->data;
59 
60 	if (SET_WITH_TIMEOUT(set))
61 		del_timer_sync(&map->gc);
62 
63 	if (set->dsize && set->extensions & IPSET_EXT_DESTROY)
64 		mtype_ext_cleanup(set);
65 	ip_set_free(map->members);
66 	ip_set_free(map);
67 
68 	set->data = NULL;
69 }
70 
71 static void
72 mtype_flush(struct ip_set *set)
73 {
74 	struct mtype *map = set->data;
75 
76 	if (set->extensions & IPSET_EXT_DESTROY)
77 		mtype_ext_cleanup(set);
78 	bitmap_zero(map->members, map->elements);
79 	set->elements = 0;
80 	set->ext_size = 0;
81 }
82 
83 /* Calculate the actual memory size of the set data */
84 static size_t
85 mtype_memsize(const struct mtype *map, size_t dsize)
86 {
87 	return sizeof(*map) + map->memsize +
88 	       map->elements * dsize;
89 }
90 
91 static int
92 mtype_head(struct ip_set *set, struct sk_buff *skb)
93 {
94 	const struct mtype *map = set->data;
95 	struct nlattr *nested;
96 	size_t memsize = mtype_memsize(map, set->dsize) + set->ext_size;
97 
98 	nested = nla_nest_start(skb, IPSET_ATTR_DATA);
99 	if (!nested)
100 		goto nla_put_failure;
101 	if (mtype_do_head(skb, map) ||
102 	    nla_put_net32(skb, IPSET_ATTR_REFERENCES, htonl(set->ref)) ||
103 	    nla_put_net32(skb, IPSET_ATTR_MEMSIZE, htonl(memsize)) ||
104 	    nla_put_net32(skb, IPSET_ATTR_ELEMENTS, htonl(set->elements)))
105 		goto nla_put_failure;
106 	if (unlikely(ip_set_put_flags(skb, set)))
107 		goto nla_put_failure;
108 	nla_nest_end(skb, nested);
109 
110 	return 0;
111 nla_put_failure:
112 	return -EMSGSIZE;
113 }
114 
115 static int
116 mtype_test(struct ip_set *set, void *value, const struct ip_set_ext *ext,
117 	   struct ip_set_ext *mext, u32 flags)
118 {
119 	struct mtype *map = set->data;
120 	const struct mtype_adt_elem *e = value;
121 	void *x = get_ext(set, map, e->id);
122 	int ret = mtype_do_test(e, map, set->dsize);
123 
124 	if (ret <= 0)
125 		return ret;
126 	return ip_set_match_extensions(set, ext, mext, flags, x);
127 }
128 
/* Add an element to the set (ADT add callback).
 *
 * @value points to the type-specific element descriptor and @ext carries
 * the extension values to store. The type-specific mtype_do_add() is
 * called first; this wrapper then handles its failure case, stores the
 * extensions, sets the member bit and maintains set->elements.
 *
 * Returns 0 on success, or -IPSET_ERR_EXIST when mtype_do_add() reports
 * IPSET_ADD_FAILED, the entry is not an expired entry of a timeout set
 * and IPSET_FLAG_EXIST is not set in @flags.
 */
static int
mtype_add(struct ip_set *set, void *value, const struct ip_set_ext *ext,
	  struct ip_set_ext *mext, u32 flags)
{
	struct mtype *map = set->data;
	const struct mtype_adt_elem *e = value;
	void *x = get_ext(set, map, e->id);
	int ret = mtype_do_add(e, map, flags, set->dsize);

	if (ret == IPSET_ADD_FAILED) {
		if (SET_WITH_TIMEOUT(set) &&
		    ip_set_timeout_expired(ext_timeout(x, set))) {
			/* Expired entry: take it out of the count (it is
			 * counted again at the end) and carry on as a
			 * successful add.
			 */
			set->elements--;
			ret = 0;
		} else if (!(flags & IPSET_FLAG_EXIST)) {
			/* Not expired and the caller did not pass
			 * IPSET_FLAG_EXIST: make sure the member bit is set
			 * and report the conflict.
			 * NOTE(review): the set_bit() presumably restores a
			 * bit that mtype_do_add() may have touched - confirm.
			 */
			set_bit(e->id, map->members);
			return -IPSET_ERR_EXIST;
		}
		/* Element is re-added, cleanup extensions */
		ip_set_ext_destroy(set, x);
	}
	/* A positive result from mtype_do_add() also compensates for the
	 * unconditional increment at the end.
	 * NOTE(review): presumably an existing entry being overwritten -
	 * confirm in the type-specific mtype_do_add().
	 */
	if (ret > 0)
		set->elements--;

	/* The statement controlled by this if is selected at compile time:
	 * the type-specific mtype_add_timeout() when
	 * IP_SET_BITMAP_STORED_TIMEOUT is defined, plain
	 * ip_set_timeout_set() otherwise.
	 */
	if (SET_WITH_TIMEOUT(set))
#ifdef IP_SET_BITMAP_STORED_TIMEOUT
		mtype_add_timeout(ext_timeout(x, set), e, ext, set, map, ret);
#else
		ip_set_timeout_set(ext_timeout(x, set), ext->timeout);
#endif

	/* Initialize the remaining extensions the set was created with */
	if (SET_WITH_COUNTER(set))
		ip_set_init_counter(ext_counter(x, set), ext);
	if (SET_WITH_COMMENT(set))
		ip_set_init_comment(set, ext_comment(x, set), ext);
	if (SET_WITH_SKBINFO(set))
		ip_set_init_skbinfo(ext_skbinfo(x, set), ext);

	/* Activate element */
	set_bit(e->id, map->members);
	set->elements++;

	return 0;
}
173 
174 static int
175 mtype_del(struct ip_set *set, void *value, const struct ip_set_ext *ext,
176 	  struct ip_set_ext *mext, u32 flags)
177 {
178 	struct mtype *map = set->data;
179 	const struct mtype_adt_elem *e = value;
180 	void *x = get_ext(set, map, e->id);
181 
182 	if (mtype_do_del(e, map))
183 		return -IPSET_ERR_EXIST;
184 
185 	ip_set_ext_destroy(set, x);
186 	set->elements--;
187 	if (SET_WITH_TIMEOUT(set) &&
188 	    ip_set_timeout_expired(ext_timeout(x, set)))
189 		return -IPSET_ERR_EXIST;
190 
191 	return 0;
192 }
193 
194 #ifndef IP_SET_BITMAP_STORED_TIMEOUT
195 static bool
196 mtype_is_filled(const struct mtype_elem *x)
197 {
198 	return true;
199 }
200 #endif
201 
/* Dump the elements of the set into a netlink message.
 *
 * cb->args[IPSET_CB_ARG0] holds the id to start from and is advanced as
 * the ids are walked; it is reset to 0 once the whole id range has been
 * processed. Every listed element gets its own IPSET_ATTR_DATA nest
 * inside a single IPSET_ATTR_ADT nest.
 *
 * Returns:
 *   0          listing finished, or the skb filled up at an id other
 *              than the one this call started with (cb->args is left at
 *              the id that did not fit; presumably a later call resumes
 *              from there - confirm against the caller)
 *   -EMSGSIZE  the IPSET_ATTR_ADT nest could not be started, or the id
 *              this call started with did not fit into the skb
 */
static int
mtype_list(const struct ip_set *set,
	   struct sk_buff *skb, struct netlink_callback *cb)
{
	struct mtype *map = set->data;
	struct nlattr *adt, *nested;
	void *x;
	/* first: the id this call starts with, used to tell whether the
	 * failing id is the very first one handled by this call
	 */
	u32 id, first = cb->args[IPSET_CB_ARG0];
	int ret = 0;

	adt = nla_nest_start(skb, IPSET_ATTR_ADT);
	if (!adt)
		return -EMSGSIZE;
	/* Extensions may be replaced */
	rcu_read_lock();
	for (; cb->args[IPSET_CB_ARG0] < map->elements;
	     cb->args[IPSET_CB_ARG0]++) {
		cond_resched_rcu();
		id = cb->args[IPSET_CB_ARG0];
		x = get_ext(set, map, id);
		/* Skip ids that are not members and, on sets with timeout,
		 * entries that have expired (with
		 * IP_SET_BITMAP_STORED_TIMEOUT only if the element is
		 * filled as well).
		 */
		if (!test_bit(id, map->members) ||
		    (SET_WITH_TIMEOUT(set) &&
#ifdef IP_SET_BITMAP_STORED_TIMEOUT
		     mtype_is_filled(x) &&
#endif
		     ip_set_timeout_expired(ext_timeout(x, set))))
			continue;
		nested = nla_nest_start(skb, IPSET_ATTR_DATA);
		if (!nested) {
			if (id == first) {
				/* Not even the starting id fits: drop the
				 * ADT nest and fail.
				 */
				nla_nest_cancel(skb, adt);
				ret = -EMSGSIZE;
				goto out;
			}

			goto nla_put_failure;
		}
		if (mtype_do_list(skb, map, id, set->dsize))
			goto nla_put_failure;
		if (ip_set_put_extensions(skb, set, x, mtype_is_filled(x)))
			goto nla_put_failure;
		nla_nest_end(skb, nested);
	}
	nla_nest_end(skb, adt);

	/* Set listing finished */
	cb->args[IPSET_CB_ARG0] = 0;

	goto out;

nla_put_failure:
	/* Drop the partially written element, keep what was dumped before */
	nla_nest_cancel(skb, nested);
	if (unlikely(id == first)) {
		/* The starting id itself did not fit: reset the position
		 * and report the failure.
		 */
		cb->args[IPSET_CB_ARG0] = 0;
		ret = -EMSGSIZE;
	}
	nla_nest_end(skb, adt);
out:
	rcu_read_unlock();
	return ret;
}
263 
264 static void
265 mtype_gc(struct timer_list *t)
266 {
267 	struct mtype *map = from_timer(map, t, gc);
268 	struct ip_set *set = map->set;
269 	void *x;
270 	u32 id;
271 
272 	/* We run parallel with other readers (test element)
273 	 * but adding/deleting new entries is locked out
274 	 */
275 	spin_lock_bh(&set->lock);
276 	for (id = 0; id < map->elements; id++)
277 		if (mtype_gc_test(id, map, set->dsize)) {
278 			x = get_ext(set, map, id);
279 			if (ip_set_timeout_expired(ext_timeout(x, set))) {
280 				clear_bit(id, map->members);
281 				ip_set_ext_destroy(set, x);
282 				set->elements--;
283 			}
284 		}
285 	spin_unlock_bh(&set->lock);
286 
287 	map->gc.expires = jiffies + IPSET_GC_PERIOD(set->timeout) * HZ;
288 	add_timer(&map->gc);
289 }
290 
/* Callback table of this set type variant. It wires the generic
 * functions of this file into struct ip_set_type_variant; mtype_kadt,
 * mtype_uadt and mtype_same_set are not defined in this section and
 * are expected to be provided by the including file.
 */
static const struct ip_set_type_variant mtype = {
	.kadt	= mtype_kadt,
	.uadt	= mtype_uadt,
	/* Add/delete/test handlers, indexed by the IPSET_ADD, IPSET_DEL
	 * and IPSET_TEST constants
	 */
	.adt	= {
		[IPSET_ADD] = mtype_add,
		[IPSET_DEL] = mtype_del,
		[IPSET_TEST] = mtype_test,
	},
	.destroy = mtype_destroy,
	.flush	= mtype_flush,
	.head	= mtype_head,
	.list	= mtype_list,
	.same_set = mtype_same_set,
};
305 
306 #endif /* __IP_SET_BITMAP_IP_GEN_H */
307