xref: /openbmc/linux/lib/test_rhashtable.c (revision 5ff32883)
1 /*
2  * Resizable, Scalable, Concurrent Hash Table
3  *
4  * Copyright (c) 2014-2015 Thomas Graf <tgraf@suug.ch>
5  * Copyright (c) 2008-2014 Patrick McHardy <kaber@trash.net>
6  *
7  * This program is free software; you can redistribute it and/or modify
8  * it under the terms of the GNU General Public License version 2 as
9  * published by the Free Software Foundation.
10  */
11 
12 /**************************************************************************
13  * Self Test
14  **************************************************************************/
15 
16 #include <linux/init.h>
17 #include <linux/jhash.h>
18 #include <linux/kernel.h>
19 #include <linux/kthread.h>
20 #include <linux/module.h>
21 #include <linux/rcupdate.h>
22 #include <linux/rhashtable.h>
23 #include <linux/slab.h>
24 #include <linux/sched.h>
25 #include <linux/random.h>
26 #include <linux/vmalloc.h>
27 #include <linux/wait.h>
28 
/* Upper bound applied to the parm_entries module parameter. */
#define MAX_ENTRIES	1000000
/*
 * Value stored in test_obj_val.id to mark an object that is not (or is no
 * longer) in the table; the lookup checks expect such objects to be absent.
 */
#define TEST_INSERT_FAIL INT_MAX

/* Objects inserted per run; capped at MAX_ENTRIES by test_rht_init(). */
static int parm_entries = 50000;
module_param(parm_entries, int, 0);
MODULE_PARM_DESC(parm_entries, "Number of entries to add (default: 50000)");

/* Number of serial insert/lookup/delete runs averaged by test_rht_init(). */
static int runs = 4;
module_param(runs, int, 0);
MODULE_PARM_DESC(runs, "Number of test runs per variant (default: 4)");

/* 0 lets test_rht_init() derive max_size from the entry count. */
static int max_size = 0;
module_param(max_size, int, 0);
MODULE_PARM_DESC(max_size, "Maximum table size (default: calculated)");

/* Copied into test_rht_params.automatic_shrinking. */
static bool shrinking = false;
module_param(shrinking, bool, 0);
MODULE_PARM_DESC(shrinking, "Enable automatic shrinking (default: off)");

/* Copied into test_rht_params.nelem_hint. */
static int size = 8;
module_param(size, int, 0);
MODULE_PARM_DESC(size, "Initial size hint of table (default: 8)");

/* Threads for the concurrent test; 0 skips it (and the rhltable test after it). */
static int tcount = 10;
module_param(tcount, int, 0);
MODULE_PARM_DESC(tcount, "Number of threads to spawn (default: 10)");

/* When set, insert_retry() treats -ENOMEM like -EBUSY and keeps retrying. */
static bool enomem_retry = false;
module_param(enomem_retry, bool, 0);
MODULE_PARM_DESC(enomem_retry, "Retry insert even if -ENOMEM was returned (default: off)");
59 
/*
 * Key of every test object.  test_rht_params hashes the whole struct
 * (key_len = sizeof(struct test_obj_val)); test_rht_params_dup's
 * my_hashfn()/my_cmpfn() look only at id.
 */
struct test_obj_val {
	int	id;	/* key value, or TEST_INSERT_FAIL when not in the table */
	int	tid;	/* thread id (concurrent test) or index (test_insert_dup()) */
};

/* Object for the plain rhashtable tests; @node links it into the table. */
struct test_obj {
	struct test_obj_val	value;
	struct rhash_head	node;
};

/*
 * Object for the rhltable (duplicate-key) tests.  Same layout as struct
 * test_obj (key first, then the link) -- presumably what lets
 * test_rhltable() reuse test_rht_params, whose offsets come from
 * struct test_obj.
 */
struct test_obj_rhl {
	struct test_obj_val	value;
	struct rhlist_head	list_node;
};

/* Per-thread state for the concurrent test. */
struct thread_data {
	unsigned int entries;		/* number of objects owned by this thread */
	int id;				/* thread index, also used as the key's tid */
	struct task_struct *task;	/* kthread_run() result (may be an ERR_PTR) */
	struct test_obj *objs;		/* this thread's slice of the shared array */
};
81 
82 static u32 my_hashfn(const void *data, u32 len, u32 seed)
83 {
84 	const struct test_obj_rhl *obj = data;
85 
86 	return (obj->value.id % 10);
87 }
88 
89 static int my_cmpfn(struct rhashtable_compare_arg *arg, const void *obj)
90 {
91 	const struct test_obj_rhl *test_obj = obj;
92 	const struct test_obj_val *val = arg->key;
93 
94 	return test_obj->value.id - val->id;
95 }
96 
/*
 * Parameters for the plain rhashtable tests: the whole struct test_obj_val
 * is the key, hashed with jhash.  Not const because test_rht_init() and
 * test_rhashtable_max() fill in max_size (and test_rht_init() also
 * nelem_hint and automatic_shrinking) at run time.  test_rhltable() also
 * uses these params, with struct test_obj_rhl objects.
 */
static struct rhashtable_params test_rht_params = {
	.head_offset = offsetof(struct test_obj, node),
	.key_offset = offsetof(struct test_obj, value),
	.key_len = sizeof(struct test_obj_val),
	.hashfn = jhash,
};

/*
 * Parameters for the duplicate-key tests: my_hashfn()/my_cmpfn() look only
 * at value.id, so objects with equal ids are duplicates and ids that agree
 * modulo 10 share a hash value.
 *
 * NOTE(review): key lookups would presumably hash with .hashfn (jhash)
 * while objects hash with my_hashfn(); the duplicate tests only insert
 * and dump buckets, so no key lookup is done -- confirm before adding one.
 */
static struct rhashtable_params test_rht_params_dup = {
	.head_offset = offsetof(struct test_obj_rhl, list_node),
	.key_offset = offsetof(struct test_obj_rhl, value),
	.key_len = sizeof(struct test_obj_val),
	.hashfn = jhash,
	.obj_hashfn = my_hashfn,
	.obj_cmpfn = my_cmpfn,
	.nelem_hint = 128,
	.automatic_shrinking = false,
};
114 
/*
 * Start barrier for the concurrent test.  Each thread decrements
 * startup_count and waits on startup_wait; test_rht_init() waits for the
 * count to reach 0, then drops it to -1 and wakes all threads together.
 */
static atomic_t startup_count;
static DECLARE_WAIT_QUEUE_HEAD(startup_wait);
117 
118 static int insert_retry(struct rhashtable *ht, struct test_obj *obj,
119                         const struct rhashtable_params params)
120 {
121 	int err, retries = -1, enomem_retries = 0;
122 
123 	do {
124 		retries++;
125 		cond_resched();
126 		err = rhashtable_insert_fast(ht, &obj->node, params);
127 		if (err == -ENOMEM && enomem_retry) {
128 			enomem_retries++;
129 			err = -EBUSY;
130 		}
131 	} while (err == -EBUSY);
132 
133 	if (enomem_retries)
134 		pr_info(" %u insertions retried after -ENOMEM\n",
135 			enomem_retries);
136 
137 	return err ? : retries;
138 }
139 
140 static int __init test_rht_lookup(struct rhashtable *ht, struct test_obj *array,
141 				  unsigned int entries)
142 {
143 	unsigned int i;
144 
145 	for (i = 0; i < entries; i++) {
146 		struct test_obj *obj;
147 		bool expected = !(i % 2);
148 		struct test_obj_val key = {
149 			.id = i,
150 		};
151 
152 		if (array[i / 2].value.id == TEST_INSERT_FAIL)
153 			expected = false;
154 
155 		obj = rhashtable_lookup_fast(ht, &key, test_rht_params);
156 
157 		if (expected && !obj) {
158 			pr_warn("Test failed: Could not find key %u\n", key.id);
159 			return -ENOENT;
160 		} else if (!expected && obj) {
161 			pr_warn("Test failed: Unexpected entry found for key %u\n",
162 				key.id);
163 			return -EEXIST;
164 		} else if (expected && obj) {
165 			if (obj->value.id != i) {
166 				pr_warn("Test failed: Lookup value mismatch %u!=%u\n",
167 					obj->value.id, i);
168 				return -EINVAL;
169 			}
170 		}
171 
172 		cond_resched_rcu();
173 	}
174 
175 	return 0;
176 }
177 
178 static void test_bucket_stats(struct rhashtable *ht, unsigned int entries)
179 {
180 	unsigned int err, total = 0, chain_len = 0;
181 	struct rhashtable_iter hti;
182 	struct rhash_head *pos;
183 
184 	err = rhashtable_walk_init(ht, &hti, GFP_KERNEL);
185 	if (err) {
186 		pr_warn("Test failed: allocation error");
187 		return;
188 	}
189 
190 	rhashtable_walk_start(&hti);
191 
192 	while ((pos = rhashtable_walk_next(&hti))) {
193 		if (PTR_ERR(pos) == -EAGAIN) {
194 			pr_info("Info: encountered resize\n");
195 			chain_len++;
196 			continue;
197 		} else if (IS_ERR(pos)) {
198 			pr_warn("Test failed: rhashtable_walk_next() error: %ld\n",
199 				PTR_ERR(pos));
200 			break;
201 		}
202 
203 		total++;
204 	}
205 
206 	rhashtable_walk_stop(&hti);
207 	rhashtable_walk_exit(&hti);
208 
209 	pr_info("  Traversal complete: counted=%u, nelems=%u, entries=%d, table-jumps=%u\n",
210 		total, atomic_read(&ht->nelems), entries, chain_len);
211 
212 	if (total != atomic_read(&ht->nelems) || total != entries)
213 		pr_warn("Test failed: Total count mismatch ^^^");
214 }
215 
216 static s64 __init test_rhashtable(struct rhashtable *ht, struct test_obj *array,
217 				  unsigned int entries)
218 {
219 	struct test_obj *obj;
220 	int err;
221 	unsigned int i, insert_retries = 0;
222 	s64 start, end;
223 
224 	/*
225 	 * Insertion Test:
226 	 * Insert entries into table with all keys even numbers
227 	 */
228 	pr_info("  Adding %d keys\n", entries);
229 	start = ktime_get_ns();
230 	for (i = 0; i < entries; i++) {
231 		struct test_obj *obj = &array[i];
232 
233 		obj->value.id = i * 2;
234 		err = insert_retry(ht, obj, test_rht_params);
235 		if (err > 0)
236 			insert_retries += err;
237 		else if (err)
238 			return err;
239 	}
240 
241 	if (insert_retries)
242 		pr_info("  %u insertions retried due to memory pressure\n",
243 			insert_retries);
244 
245 	test_bucket_stats(ht, entries);
246 	rcu_read_lock();
247 	test_rht_lookup(ht, array, entries);
248 	rcu_read_unlock();
249 
250 	test_bucket_stats(ht, entries);
251 
252 	pr_info("  Deleting %d keys\n", entries);
253 	for (i = 0; i < entries; i++) {
254 		struct test_obj_val key = {
255 			.id = i * 2,
256 		};
257 
258 		if (array[i].value.id != TEST_INSERT_FAIL) {
259 			obj = rhashtable_lookup_fast(ht, &key, test_rht_params);
260 			BUG_ON(!obj);
261 
262 			rhashtable_remove_fast(ht, &obj->node, test_rht_params);
263 		}
264 
265 		cond_resched();
266 	}
267 
268 	end = ktime_get_ns();
269 	pr_info("  Duration of test: %lld ns\n", end - start);
270 
271 	return end - start;
272 }
273 
/*
 * Tables shared by the tests.  @ht is used by the serial runs started from
 * test_rht_init(), by test_rhashtable_max() and by the threaded test.
 * @rhlt is used by test_rhltable().
 */
static struct rhashtable ht;
static struct rhltable rhlt;
276 
277 static int __init test_rhltable(unsigned int entries)
278 {
279 	struct test_obj_rhl *rhl_test_objects;
280 	unsigned long *obj_in_table;
281 	unsigned int i, j, k;
282 	int ret, err;
283 
284 	if (entries == 0)
285 		entries = 1;
286 
287 	rhl_test_objects = vzalloc(array_size(entries,
288 					      sizeof(*rhl_test_objects)));
289 	if (!rhl_test_objects)
290 		return -ENOMEM;
291 
292 	ret = -ENOMEM;
293 	obj_in_table = vzalloc(array_size(sizeof(unsigned long),
294 					  BITS_TO_LONGS(entries)));
295 	if (!obj_in_table)
296 		goto out_free;
297 
298 	err = rhltable_init(&rhlt, &test_rht_params);
299 	if (WARN_ON(err))
300 		goto out_free;
301 
302 	k = prandom_u32();
303 	ret = 0;
304 	for (i = 0; i < entries; i++) {
305 		rhl_test_objects[i].value.id = k;
306 		err = rhltable_insert(&rhlt, &rhl_test_objects[i].list_node,
307 				      test_rht_params);
308 		if (WARN(err, "error %d on element %d\n", err, i))
309 			break;
310 		if (err == 0)
311 			set_bit(i, obj_in_table);
312 	}
313 
314 	if (err)
315 		ret = err;
316 
317 	pr_info("test %d add/delete pairs into rhlist\n", entries);
318 	for (i = 0; i < entries; i++) {
319 		struct rhlist_head *h, *pos;
320 		struct test_obj_rhl *obj;
321 		struct test_obj_val key = {
322 			.id = k,
323 		};
324 		bool found;
325 
326 		rcu_read_lock();
327 		h = rhltable_lookup(&rhlt, &key, test_rht_params);
328 		if (WARN(!h, "key not found during iteration %d of %d", i, entries)) {
329 			rcu_read_unlock();
330 			break;
331 		}
332 
333 		if (i) {
334 			j = i - 1;
335 			rhl_for_each_entry_rcu(obj, pos, h, list_node) {
336 				if (WARN(pos == &rhl_test_objects[j].list_node, "old element found, should be gone"))
337 					break;
338 			}
339 		}
340 
341 		cond_resched_rcu();
342 
343 		found = false;
344 
345 		rhl_for_each_entry_rcu(obj, pos, h, list_node) {
346 			if (pos == &rhl_test_objects[i].list_node) {
347 				found = true;
348 				break;
349 			}
350 		}
351 
352 		rcu_read_unlock();
353 
354 		if (WARN(!found, "element %d not found", i))
355 			break;
356 
357 		err = rhltable_remove(&rhlt, &rhl_test_objects[i].list_node, test_rht_params);
358 		WARN(err, "rhltable_remove: err %d for iteration %d\n", err, i);
359 		if (err == 0)
360 			clear_bit(i, obj_in_table);
361 	}
362 
363 	if (ret == 0 && err)
364 		ret = err;
365 
366 	for (i = 0; i < entries; i++) {
367 		WARN(test_bit(i, obj_in_table), "elem %d allegedly still present", i);
368 
369 		err = rhltable_insert(&rhlt, &rhl_test_objects[i].list_node,
370 				      test_rht_params);
371 		if (WARN(err, "error %d on element %d\n", err, i))
372 			break;
373 		if (err == 0)
374 			set_bit(i, obj_in_table);
375 	}
376 
377 	pr_info("test %d random rhlist add/delete operations\n", entries);
378 	for (j = 0; j < entries; j++) {
379 		u32 i = prandom_u32_max(entries);
380 		u32 prand = prandom_u32();
381 
382 		cond_resched();
383 
384 		if (prand == 0)
385 			prand = prandom_u32();
386 
387 		if (prand & 1) {
388 			prand >>= 1;
389 			continue;
390 		}
391 
392 		err = rhltable_remove(&rhlt, &rhl_test_objects[i].list_node, test_rht_params);
393 		if (test_bit(i, obj_in_table)) {
394 			clear_bit(i, obj_in_table);
395 			if (WARN(err, "cannot remove element at slot %d", i))
396 				continue;
397 		} else {
398 			if (WARN(err != -ENOENT, "removed non-existant element %d, error %d not %d",
399 			     i, err, -ENOENT))
400 				continue;
401 		}
402 
403 		if (prand & 1) {
404 			prand >>= 1;
405 			continue;
406 		}
407 
408 		err = rhltable_insert(&rhlt, &rhl_test_objects[i].list_node, test_rht_params);
409 		if (err == 0) {
410 			if (WARN(test_and_set_bit(i, obj_in_table), "succeeded to insert same object %d", i))
411 				continue;
412 		} else {
413 			if (WARN(!test_bit(i, obj_in_table), "failed to insert object %d", i))
414 				continue;
415 		}
416 
417 		if (prand & 1) {
418 			prand >>= 1;
419 			continue;
420 		}
421 
422 		i = prandom_u32_max(entries);
423 		if (test_bit(i, obj_in_table)) {
424 			err = rhltable_remove(&rhlt, &rhl_test_objects[i].list_node, test_rht_params);
425 			WARN(err, "cannot remove element at slot %d", i);
426 			if (err == 0)
427 				clear_bit(i, obj_in_table);
428 		} else {
429 			err = rhltable_insert(&rhlt, &rhl_test_objects[i].list_node, test_rht_params);
430 			WARN(err, "failed to insert object %d", i);
431 			if (err == 0)
432 				set_bit(i, obj_in_table);
433 		}
434 	}
435 
436 	for (i = 0; i < entries; i++) {
437 		cond_resched();
438 		err = rhltable_remove(&rhlt, &rhl_test_objects[i].list_node, test_rht_params);
439 		if (test_bit(i, obj_in_table)) {
440 			if (WARN(err, "cannot remove element at slot %d", i))
441 				continue;
442 		} else {
443 			if (WARN(err != -ENOENT, "removed non-existant element, error %d not %d",
444 				 err, -ENOENT))
445 			continue;
446 		}
447 	}
448 
449 	rhltable_destroy(&rhlt);
450 out_free:
451 	vfree(rhl_test_objects);
452 	vfree(obj_in_table);
453 	return ret;
454 }
455 
456 static int __init test_rhashtable_max(struct test_obj *array,
457 				      unsigned int entries)
458 {
459 	unsigned int i, insert_retries = 0;
460 	int err;
461 
462 	test_rht_params.max_size = roundup_pow_of_two(entries / 8);
463 	err = rhashtable_init(&ht, &test_rht_params);
464 	if (err)
465 		return err;
466 
467 	for (i = 0; i < ht.max_elems; i++) {
468 		struct test_obj *obj = &array[i];
469 
470 		obj->value.id = i * 2;
471 		err = insert_retry(&ht, obj, test_rht_params);
472 		if (err > 0)
473 			insert_retries += err;
474 		else if (err)
475 			return err;
476 	}
477 
478 	err = insert_retry(&ht, &array[ht.max_elems], test_rht_params);
479 	if (err == -E2BIG) {
480 		err = 0;
481 	} else {
482 		pr_info("insert element %u should have failed with %d, got %d\n",
483 				ht.max_elems, -E2BIG, err);
484 		if (err == 0)
485 			err = -1;
486 	}
487 
488 	rhashtable_destroy(&ht);
489 
490 	return err;
491 }
492 
493 static unsigned int __init print_ht(struct rhltable *rhlt)
494 {
495 	struct rhashtable *ht;
496 	const struct bucket_table *tbl;
497 	char buff[512] = "";
498 	unsigned int i, cnt = 0;
499 
500 	ht = &rhlt->ht;
501 	/* Take the mutex to avoid RCU warning */
502 	mutex_lock(&ht->mutex);
503 	tbl = rht_dereference(ht->tbl, ht);
504 	for (i = 0; i < tbl->size; i++) {
505 		struct rhash_head *pos, *next;
506 		struct test_obj_rhl *p;
507 
508 		pos = rht_dereference(tbl->buckets[i], ht);
509 		next = !rht_is_a_nulls(pos) ? rht_dereference(pos->next, ht) : NULL;
510 
511 		if (!rht_is_a_nulls(pos)) {
512 			sprintf(buff, "%s\nbucket[%d] -> ", buff, i);
513 		}
514 
515 		while (!rht_is_a_nulls(pos)) {
516 			struct rhlist_head *list = container_of(pos, struct rhlist_head, rhead);
517 			sprintf(buff, "%s[[", buff);
518 			do {
519 				pos = &list->rhead;
520 				list = rht_dereference(list->next, ht);
521 				p = rht_obj(ht, pos);
522 
523 				sprintf(buff, "%s val %d (tid=%d)%s", buff, p->value.id, p->value.tid,
524 					list? ", " : " ");
525 				cnt++;
526 			} while (list);
527 
528 			pos = next,
529 			next = !rht_is_a_nulls(pos) ?
530 				rht_dereference(pos->next, ht) : NULL;
531 
532 			sprintf(buff, "%s]]%s", buff, !rht_is_a_nulls(pos) ? " -> " : "");
533 		}
534 	}
535 	printk(KERN_ERR "\n---- ht: ----%s\n-------------\n", buff);
536 	mutex_unlock(&ht->mutex);
537 
538 	return cnt;
539 }
540 
541 static int __init test_insert_dup(struct test_obj_rhl *rhl_test_objects,
542 				  int cnt, bool slow)
543 {
544 	struct rhltable rhlt;
545 	unsigned int i, ret;
546 	const char *key;
547 	int err = 0;
548 
549 	err = rhltable_init(&rhlt, &test_rht_params_dup);
550 	if (WARN_ON(err))
551 		return err;
552 
553 	for (i = 0; i < cnt; i++) {
554 		rhl_test_objects[i].value.tid = i;
555 		key = rht_obj(&rhlt.ht, &rhl_test_objects[i].list_node.rhead);
556 		key += test_rht_params_dup.key_offset;
557 
558 		if (slow) {
559 			err = PTR_ERR(rhashtable_insert_slow(&rhlt.ht, key,
560 							     &rhl_test_objects[i].list_node.rhead));
561 			if (err == -EAGAIN)
562 				err = 0;
563 		} else
564 			err = rhltable_insert(&rhlt,
565 					      &rhl_test_objects[i].list_node,
566 					      test_rht_params_dup);
567 		if (WARN(err, "error %d on element %d/%d (%s)\n", err, i, cnt, slow? "slow" : "fast"))
568 			goto skip_print;
569 	}
570 
571 	ret = print_ht(&rhlt);
572 	WARN(ret != cnt, "missing rhltable elements (%d != %d, %s)\n", ret, cnt, slow? "slow" : "fast");
573 
574 skip_print:
575 	rhltable_destroy(&rhlt);
576 
577 	return 0;
578 }
579 
580 static int __init test_insert_duplicates_run(void)
581 {
582 	struct test_obj_rhl rhl_test_objects[3] = {};
583 
584 	pr_info("test inserting duplicates\n");
585 
586 	/* two different values that map to same bucket */
587 	rhl_test_objects[0].value.id = 1;
588 	rhl_test_objects[1].value.id = 21;
589 
590 	/* and another duplicate with same as [0] value
591 	 * which will be second on the bucket list */
592 	rhl_test_objects[2].value.id = rhl_test_objects[0].value.id;
593 
594 	test_insert_dup(rhl_test_objects, 2, false);
595 	test_insert_dup(rhl_test_objects, 3, false);
596 	test_insert_dup(rhl_test_objects, 2, true);
597 	test_insert_dup(rhl_test_objects, 3, true);
598 
599 	return 0;
600 }
601 
602 static int thread_lookup_test(struct thread_data *tdata)
603 {
604 	unsigned int entries = tdata->entries;
605 	int i, err = 0;
606 
607 	for (i = 0; i < entries; i++) {
608 		struct test_obj *obj;
609 		struct test_obj_val key = {
610 			.id = i,
611 			.tid = tdata->id,
612 		};
613 
614 		obj = rhashtable_lookup_fast(&ht, &key, test_rht_params);
615 		if (obj && (tdata->objs[i].value.id == TEST_INSERT_FAIL)) {
616 			pr_err("  found unexpected object %d-%d\n", key.tid, key.id);
617 			err++;
618 		} else if (!obj && (tdata->objs[i].value.id != TEST_INSERT_FAIL)) {
619 			pr_err("  object %d-%d not found!\n", key.tid, key.id);
620 			err++;
621 		} else if (obj && memcmp(&obj->value, &key, sizeof(key))) {
622 			pr_err("  wrong object returned (got %d-%d, expected %d-%d)\n",
623 			       obj->value.tid, obj->value.id, key.tid, key.id);
624 			err++;
625 		}
626 
627 		cond_resched();
628 	}
629 	return err;
630 }
631 
632 static int threadfunc(void *data)
633 {
634 	int i, step, err = 0, insert_retries = 0;
635 	struct thread_data *tdata = data;
636 
637 	if (atomic_dec_and_test(&startup_count))
638 		wake_up(&startup_wait);
639 	if (wait_event_interruptible(startup_wait, atomic_read(&startup_count) == -1)) {
640 		pr_err("  thread[%d]: interrupted\n", tdata->id);
641 		goto out;
642 	}
643 
644 	for (i = 0; i < tdata->entries; i++) {
645 		tdata->objs[i].value.id = i;
646 		tdata->objs[i].value.tid = tdata->id;
647 		err = insert_retry(&ht, &tdata->objs[i], test_rht_params);
648 		if (err > 0) {
649 			insert_retries += err;
650 		} else if (err) {
651 			pr_err("  thread[%d]: rhashtable_insert_fast failed\n",
652 			       tdata->id);
653 			goto out;
654 		}
655 	}
656 	if (insert_retries)
657 		pr_info("  thread[%d]: %u insertions retried due to memory pressure\n",
658 			tdata->id, insert_retries);
659 
660 	err = thread_lookup_test(tdata);
661 	if (err) {
662 		pr_err("  thread[%d]: rhashtable_lookup_test failed\n",
663 		       tdata->id);
664 		goto out;
665 	}
666 
667 	for (step = 10; step > 0; step--) {
668 		for (i = 0; i < tdata->entries; i += step) {
669 			if (tdata->objs[i].value.id == TEST_INSERT_FAIL)
670 				continue;
671 			err = rhashtable_remove_fast(&ht, &tdata->objs[i].node,
672 			                             test_rht_params);
673 			if (err) {
674 				pr_err("  thread[%d]: rhashtable_remove_fast failed\n",
675 				       tdata->id);
676 				goto out;
677 			}
678 			tdata->objs[i].value.id = TEST_INSERT_FAIL;
679 
680 			cond_resched();
681 		}
682 		err = thread_lookup_test(tdata);
683 		if (err) {
684 			pr_err("  thread[%d]: rhashtable_lookup_test (2) failed\n",
685 			       tdata->id);
686 			goto out;
687 		}
688 	}
689 out:
690 	while (!kthread_should_stop()) {
691 		set_current_state(TASK_INTERRUPTIBLE);
692 		schedule();
693 	}
694 	return err;
695 }
696 
697 static int __init test_rht_init(void)
698 {
699 	unsigned int entries;
700 	int i, err, started_threads = 0, failed_threads = 0;
701 	u64 total_time = 0;
702 	struct thread_data *tdata;
703 	struct test_obj *objs;
704 
705 	if (parm_entries < 0)
706 		parm_entries = 1;
707 
708 	entries = min(parm_entries, MAX_ENTRIES);
709 
710 	test_rht_params.automatic_shrinking = shrinking;
711 	test_rht_params.max_size = max_size ? : roundup_pow_of_two(entries);
712 	test_rht_params.nelem_hint = size;
713 
714 	objs = vzalloc(array_size(sizeof(struct test_obj),
715 				  test_rht_params.max_size + 1));
716 	if (!objs)
717 		return -ENOMEM;
718 
719 	pr_info("Running rhashtable test nelem=%d, max_size=%d, shrinking=%d\n",
720 		size, max_size, shrinking);
721 
722 	for (i = 0; i < runs; i++) {
723 		s64 time;
724 
725 		pr_info("Test %02d:\n", i);
726 		memset(objs, 0, test_rht_params.max_size * sizeof(struct test_obj));
727 
728 		err = rhashtable_init(&ht, &test_rht_params);
729 		if (err < 0) {
730 			pr_warn("Test failed: Unable to initialize hashtable: %d\n",
731 				err);
732 			continue;
733 		}
734 
735 		time = test_rhashtable(&ht, objs, entries);
736 		rhashtable_destroy(&ht);
737 		if (time < 0) {
738 			vfree(objs);
739 			pr_warn("Test failed: return code %lld\n", time);
740 			return -EINVAL;
741 		}
742 
743 		total_time += time;
744 	}
745 
746 	pr_info("test if its possible to exceed max_size %d: %s\n",
747 			test_rht_params.max_size, test_rhashtable_max(objs, entries) == 0 ?
748 			"no, ok" : "YES, failed");
749 	vfree(objs);
750 
751 	do_div(total_time, runs);
752 	pr_info("Average test time: %llu\n", total_time);
753 
754 	test_insert_duplicates_run();
755 
756 	if (!tcount)
757 		return 0;
758 
759 	pr_info("Testing concurrent rhashtable access from %d threads\n",
760 	        tcount);
761 	atomic_set(&startup_count, tcount);
762 	tdata = vzalloc(array_size(tcount, sizeof(struct thread_data)));
763 	if (!tdata)
764 		return -ENOMEM;
765 	objs  = vzalloc(array3_size(sizeof(struct test_obj), tcount, entries));
766 	if (!objs) {
767 		vfree(tdata);
768 		return -ENOMEM;
769 	}
770 
771 	test_rht_params.max_size = max_size ? :
772 	                           roundup_pow_of_two(tcount * entries);
773 	err = rhashtable_init(&ht, &test_rht_params);
774 	if (err < 0) {
775 		pr_warn("Test failed: Unable to initialize hashtable: %d\n",
776 			err);
777 		vfree(tdata);
778 		vfree(objs);
779 		return -EINVAL;
780 	}
781 	for (i = 0; i < tcount; i++) {
782 		tdata[i].id = i;
783 		tdata[i].entries = entries;
784 		tdata[i].objs = objs + i * entries;
785 		tdata[i].task = kthread_run(threadfunc, &tdata[i],
786 		                            "rhashtable_thrad[%d]", i);
787 		if (IS_ERR(tdata[i].task)) {
788 			pr_err(" kthread_run failed for thread %d\n", i);
789 			atomic_dec(&startup_count);
790 		} else {
791 			started_threads++;
792 		}
793 	}
794 	if (wait_event_interruptible(startup_wait, atomic_read(&startup_count) == 0))
795 		pr_err("  wait_event interruptible failed\n");
796 	/* count is 0 now, set it to -1 and wake up all threads together */
797 	atomic_dec(&startup_count);
798 	wake_up_all(&startup_wait);
799 	for (i = 0; i < tcount; i++) {
800 		if (IS_ERR(tdata[i].task))
801 			continue;
802 		if ((err = kthread_stop(tdata[i].task))) {
803 			pr_warn("Test failed: thread %d returned: %d\n",
804 			        i, err);
805 			failed_threads++;
806 		}
807 	}
808 	rhashtable_destroy(&ht);
809 	vfree(tdata);
810 	vfree(objs);
811 
812 	/*
813 	 * rhltable_remove is very expensive, default values can cause test
814 	 * to run for 2 minutes or more,  use a smaller number instead.
815 	 */
816 	err = test_rhltable(entries / 16);
817 	pr_info("Started %d threads, %d failed, rhltable test returns %d\n",
818 	        started_threads, failed_threads, err);
819 	return 0;
820 }
821 
/*
 * Nothing to undo on unload: the tests run entirely from test_rht_init(),
 * which destroys its tables and frees its buffers before returning.
 */
static void __exit test_rht_exit(void)
{
}

module_init(test_rht_init);
module_exit(test_rht_exit);

MODULE_LICENSE("GPL v2");
830