1 /* 2 * Resizable, Scalable, Concurrent Hash Table 3 * 4 * Copyright (c) 2014-2015 Thomas Graf <tgraf@suug.ch> 5 * Copyright (c) 2008-2014 Patrick McHardy <kaber@trash.net> 6 * 7 * This program is free software; you can redistribute it and/or modify 8 * it under the terms of the GNU General Public License version 2 as 9 * published by the Free Software Foundation. 10 */ 11 12 /************************************************************************** 13 * Self Test 14 **************************************************************************/ 15 16 #include <linux/init.h> 17 #include <linux/jhash.h> 18 #include <linux/kernel.h> 19 #include <linux/kthread.h> 20 #include <linux/module.h> 21 #include <linux/rcupdate.h> 22 #include <linux/rhashtable.h> 23 #include <linux/semaphore.h> 24 #include <linux/slab.h> 25 #include <linux/sched.h> 26 #include <linux/vmalloc.h> 27 28 #define MAX_ENTRIES 1000000 29 #define TEST_INSERT_FAIL INT_MAX 30 31 static int entries = 50000; 32 module_param(entries, int, 0); 33 MODULE_PARM_DESC(entries, "Number of entries to add (default: 50000)"); 34 35 static int runs = 4; 36 module_param(runs, int, 0); 37 MODULE_PARM_DESC(runs, "Number of test runs per variant (default: 4)"); 38 39 static int max_size = 65536; 40 module_param(max_size, int, 0); 41 MODULE_PARM_DESC(runs, "Maximum table size (default: 65536)"); 42 43 static bool shrinking = false; 44 module_param(shrinking, bool, 0); 45 MODULE_PARM_DESC(shrinking, "Enable automatic shrinking (default: off)"); 46 47 static int size = 8; 48 module_param(size, int, 0); 49 MODULE_PARM_DESC(size, "Initial size hint of table (default: 8)"); 50 51 static int tcount = 10; 52 module_param(tcount, int, 0); 53 MODULE_PARM_DESC(tcount, "Number of threads to spawn (default: 10)"); 54 55 struct test_obj { 56 int value; 57 struct rhash_head node; 58 }; 59 60 struct thread_data { 61 int id; 62 struct task_struct *task; 63 struct test_obj *objs; 64 }; 65 66 static struct test_obj array[MAX_ENTRIES]; 67 68 static struct rhashtable_params test_rht_params = { 69 .head_offset = offsetof(struct test_obj, node), 70 .key_offset = offsetof(struct test_obj, value), 71 .key_len = sizeof(int), 72 .hashfn = jhash, 73 .nulls_base = (3U << RHT_BASE_SHIFT), 74 }; 75 76 static struct semaphore prestart_sem; 77 static struct semaphore startup_sem = __SEMAPHORE_INITIALIZER(startup_sem, 0); 78 79 static int __init test_rht_lookup(struct rhashtable *ht) 80 { 81 unsigned int i; 82 83 for (i = 0; i < entries * 2; i++) { 84 struct test_obj *obj; 85 bool expected = !(i % 2); 86 u32 key = i; 87 88 if (array[i / 2].value == TEST_INSERT_FAIL) 89 expected = false; 90 91 obj = rhashtable_lookup_fast(ht, &key, test_rht_params); 92 93 if (expected && !obj) { 94 pr_warn("Test failed: Could not find key %u\n", key); 95 return -ENOENT; 96 } else if (!expected && obj) { 97 pr_warn("Test failed: Unexpected entry found for key %u\n", 98 key); 99 return -EEXIST; 100 } else if (expected && obj) { 101 if (obj->value != i) { 102 pr_warn("Test failed: Lookup value mismatch %u!=%u\n", 103 obj->value, i); 104 return -EINVAL; 105 } 106 } 107 108 cond_resched_rcu(); 109 } 110 111 return 0; 112 } 113 114 static void test_bucket_stats(struct rhashtable *ht) 115 { 116 unsigned int err, total = 0, chain_len = 0; 117 struct rhashtable_iter hti; 118 struct rhash_head *pos; 119 120 err = rhashtable_walk_init(ht, &hti); 121 if (err) { 122 pr_warn("Test failed: allocation error"); 123 return; 124 } 125 126 err = rhashtable_walk_start(&hti); 127 if (err && err != -EAGAIN) { 128 pr_warn("Test failed: iterator failed: %d\n", err); 129 return; 130 } 131 132 while ((pos = rhashtable_walk_next(&hti))) { 133 if (PTR_ERR(pos) == -EAGAIN) { 134 pr_info("Info: encountered resize\n"); 135 chain_len++; 136 continue; 137 } else if (IS_ERR(pos)) { 138 pr_warn("Test failed: rhashtable_walk_next() error: %ld\n", 139 PTR_ERR(pos)); 140 break; 141 } 142 143 total++; 144 } 145 146 rhashtable_walk_stop(&hti); 147 rhashtable_walk_exit(&hti); 148 149 pr_info(" Traversal complete: counted=%u, nelems=%u, entries=%d, table-jumps=%u\n", 150 total, atomic_read(&ht->nelems), entries, chain_len); 151 152 if (total != atomic_read(&ht->nelems) || total != entries) 153 pr_warn("Test failed: Total count mismatch ^^^"); 154 } 155 156 static s64 __init test_rhashtable(struct rhashtable *ht) 157 { 158 struct test_obj *obj; 159 int err; 160 unsigned int i, insert_fails = 0; 161 s64 start, end; 162 163 /* 164 * Insertion Test: 165 * Insert entries into table with all keys even numbers 166 */ 167 pr_info(" Adding %d keys\n", entries); 168 start = ktime_get_ns(); 169 for (i = 0; i < entries; i++) { 170 struct test_obj *obj = &array[i]; 171 172 obj->value = i * 2; 173 174 err = rhashtable_insert_fast(ht, &obj->node, test_rht_params); 175 if (err == -ENOMEM || err == -EBUSY) { 176 /* Mark failed inserts but continue */ 177 obj->value = TEST_INSERT_FAIL; 178 insert_fails++; 179 } else if (err) { 180 return err; 181 } 182 183 cond_resched(); 184 } 185 186 if (insert_fails) 187 pr_info(" %u insertions failed due to memory pressure\n", 188 insert_fails); 189 190 test_bucket_stats(ht); 191 rcu_read_lock(); 192 test_rht_lookup(ht); 193 rcu_read_unlock(); 194 195 test_bucket_stats(ht); 196 197 pr_info(" Deleting %d keys\n", entries); 198 for (i = 0; i < entries; i++) { 199 u32 key = i * 2; 200 201 if (array[i].value != TEST_INSERT_FAIL) { 202 obj = rhashtable_lookup_fast(ht, &key, test_rht_params); 203 BUG_ON(!obj); 204 205 rhashtable_remove_fast(ht, &obj->node, test_rht_params); 206 } 207 208 cond_resched(); 209 } 210 211 end = ktime_get_ns(); 212 pr_info(" Duration of test: %lld ns\n", end - start); 213 214 return end - start; 215 } 216 217 static struct rhashtable ht; 218 219 static int thread_lookup_test(struct thread_data *tdata) 220 { 221 int i, err = 0; 222 223 for (i = 0; i < entries; i++) { 224 struct test_obj *obj; 225 int key = (tdata->id << 16) | i; 226 227 obj = rhashtable_lookup_fast(&ht, &key, test_rht_params); 228 if (obj && (tdata->objs[i].value == TEST_INSERT_FAIL)) { 229 pr_err(" found unexpected object %d\n", key); 230 err++; 231 } else if (!obj && (tdata->objs[i].value != TEST_INSERT_FAIL)) { 232 pr_err(" object %d not found!\n", key); 233 err++; 234 } else if (obj && (obj->value != key)) { 235 pr_err(" wrong object returned (got %d, expected %d)\n", 236 obj->value, key); 237 err++; 238 } 239 } 240 return err; 241 } 242 243 static int threadfunc(void *data) 244 { 245 int i, step, err = 0, insert_fails = 0; 246 struct thread_data *tdata = data; 247 248 up(&prestart_sem); 249 if (down_interruptible(&startup_sem)) 250 pr_err(" thread[%d]: down_interruptible failed\n", tdata->id); 251 252 for (i = 0; i < entries; i++) { 253 tdata->objs[i].value = (tdata->id << 16) | i; 254 err = rhashtable_insert_fast(&ht, &tdata->objs[i].node, 255 test_rht_params); 256 if (err == -ENOMEM || err == -EBUSY) { 257 tdata->objs[i].value = TEST_INSERT_FAIL; 258 insert_fails++; 259 } else if (err) { 260 pr_err(" thread[%d]: rhashtable_insert_fast failed\n", 261 tdata->id); 262 goto out; 263 } 264 } 265 if (insert_fails) 266 pr_info(" thread[%d]: %d insert failures\n", 267 tdata->id, insert_fails); 268 269 err = thread_lookup_test(tdata); 270 if (err) { 271 pr_err(" thread[%d]: rhashtable_lookup_test failed\n", 272 tdata->id); 273 goto out; 274 } 275 276 for (step = 10; step > 0; step--) { 277 for (i = 0; i < entries; i += step) { 278 if (tdata->objs[i].value == TEST_INSERT_FAIL) 279 continue; 280 err = rhashtable_remove_fast(&ht, &tdata->objs[i].node, 281 test_rht_params); 282 if (err) { 283 pr_err(" thread[%d]: rhashtable_remove_fast failed\n", 284 tdata->id); 285 goto out; 286 } 287 tdata->objs[i].value = TEST_INSERT_FAIL; 288 } 289 err = thread_lookup_test(tdata); 290 if (err) { 291 pr_err(" thread[%d]: rhashtable_lookup_test (2) failed\n", 292 tdata->id); 293 goto out; 294 } 295 } 296 out: 297 while (!kthread_should_stop()) { 298 set_current_state(TASK_INTERRUPTIBLE); 299 schedule(); 300 } 301 return err; 302 } 303 304 static int __init test_rht_init(void) 305 { 306 int i, err, started_threads = 0, failed_threads = 0; 307 u64 total_time = 0; 308 struct thread_data *tdata; 309 struct test_obj *objs; 310 311 entries = min(entries, MAX_ENTRIES); 312 313 test_rht_params.automatic_shrinking = shrinking; 314 test_rht_params.max_size = max_size; 315 test_rht_params.nelem_hint = size; 316 317 pr_info("Running rhashtable test nelem=%d, max_size=%d, shrinking=%d\n", 318 size, max_size, shrinking); 319 320 for (i = 0; i < runs; i++) { 321 s64 time; 322 323 pr_info("Test %02d:\n", i); 324 memset(&array, 0, sizeof(array)); 325 err = rhashtable_init(&ht, &test_rht_params); 326 if (err < 0) { 327 pr_warn("Test failed: Unable to initialize hashtable: %d\n", 328 err); 329 continue; 330 } 331 332 time = test_rhashtable(&ht); 333 rhashtable_destroy(&ht); 334 if (time < 0) { 335 pr_warn("Test failed: return code %lld\n", time); 336 return -EINVAL; 337 } 338 339 total_time += time; 340 } 341 342 do_div(total_time, runs); 343 pr_info("Average test time: %llu\n", total_time); 344 345 if (!tcount) 346 return 0; 347 348 pr_info("Testing concurrent rhashtable access from %d threads\n", 349 tcount); 350 sema_init(&prestart_sem, 1 - tcount); 351 tdata = vzalloc(tcount * sizeof(struct thread_data)); 352 if (!tdata) 353 return -ENOMEM; 354 objs = vzalloc(tcount * entries * sizeof(struct test_obj)); 355 if (!objs) { 356 vfree(tdata); 357 return -ENOMEM; 358 } 359 360 err = rhashtable_init(&ht, &test_rht_params); 361 if (err < 0) { 362 pr_warn("Test failed: Unable to initialize hashtable: %d\n", 363 err); 364 vfree(tdata); 365 vfree(objs); 366 return -EINVAL; 367 } 368 for (i = 0; i < tcount; i++) { 369 tdata[i].id = i; 370 tdata[i].objs = objs + i * entries; 371 tdata[i].task = kthread_run(threadfunc, &tdata[i], 372 "rhashtable_thrad[%d]", i); 373 if (IS_ERR(tdata[i].task)) 374 pr_err(" kthread_run failed for thread %d\n", i); 375 else 376 started_threads++; 377 } 378 if (down_interruptible(&prestart_sem)) 379 pr_err(" down interruptible failed\n"); 380 for (i = 0; i < tcount; i++) 381 up(&startup_sem); 382 for (i = 0; i < tcount; i++) { 383 if (IS_ERR(tdata[i].task)) 384 continue; 385 if ((err = kthread_stop(tdata[i].task))) { 386 pr_warn("Test failed: thread %d returned: %d\n", 387 i, err); 388 failed_threads++; 389 } 390 } 391 pr_info("Started %d threads, %d failed\n", 392 started_threads, failed_threads); 393 rhashtable_destroy(&ht); 394 vfree(tdata); 395 vfree(objs); 396 return 0; 397 } 398 399 static void __exit test_rht_exit(void) 400 { 401 } 402 403 module_init(test_rht_init); 404 module_exit(test_rht_exit); 405 406 MODULE_LICENSE("GPL v2"); 407