1 #include <linux/module.h> 2 #include <linux/moduleparam.h> 3 #include <linux/rbtree_augmented.h> 4 #include <linux/random.h> 5 #include <linux/slab.h> 6 #include <asm/timex.h> 7 8 #define __param(type, name, init, msg) \ 9 static type name = init; \ 10 module_param(name, type, 0444); \ 11 MODULE_PARM_DESC(name, msg); 12 13 __param(int, nnodes, 100, "Number of nodes in the rb-tree"); 14 __param(int, perf_loops, 1000, "Number of iterations modifying the rb-tree"); 15 __param(int, check_loops, 100, "Number of iterations modifying and verifying the rb-tree"); 16 17 struct test_node { 18 u32 key; 19 struct rb_node rb; 20 21 /* following fields used for testing augmented rbtree functionality */ 22 u32 val; 23 u32 augmented; 24 }; 25 26 static struct rb_root_cached root = RB_ROOT_CACHED; 27 static struct test_node *nodes = NULL; 28 29 static struct rnd_state rnd; 30 31 static void insert(struct test_node *node, struct rb_root_cached *root) 32 { 33 struct rb_node **new = &root->rb_root.rb_node, *parent = NULL; 34 u32 key = node->key; 35 36 while (*new) { 37 parent = *new; 38 if (key < rb_entry(parent, struct test_node, rb)->key) 39 new = &parent->rb_left; 40 else 41 new = &parent->rb_right; 42 } 43 44 rb_link_node(&node->rb, parent, new); 45 rb_insert_color(&node->rb, &root->rb_root); 46 } 47 48 static void insert_cached(struct test_node *node, struct rb_root_cached *root) 49 { 50 struct rb_node **new = &root->rb_root.rb_node, *parent = NULL; 51 u32 key = node->key; 52 bool leftmost = true; 53 54 while (*new) { 55 parent = *new; 56 if (key < rb_entry(parent, struct test_node, rb)->key) 57 new = &parent->rb_left; 58 else { 59 new = &parent->rb_right; 60 leftmost = false; 61 } 62 } 63 64 rb_link_node(&node->rb, parent, new); 65 rb_insert_color_cached(&node->rb, root, leftmost); 66 } 67 68 static inline void erase(struct test_node *node, struct rb_root_cached *root) 69 { 70 rb_erase(&node->rb, &root->rb_root); 71 } 72 73 static inline void erase_cached(struct test_node *node, struct rb_root_cached *root) 74 { 75 rb_erase_cached(&node->rb, root); 76 } 77 78 79 static inline u32 augment_recompute(struct test_node *node) 80 { 81 u32 max = node->val, child_augmented; 82 if (node->rb.rb_left) { 83 child_augmented = rb_entry(node->rb.rb_left, struct test_node, 84 rb)->augmented; 85 if (max < child_augmented) 86 max = child_augmented; 87 } 88 if (node->rb.rb_right) { 89 child_augmented = rb_entry(node->rb.rb_right, struct test_node, 90 rb)->augmented; 91 if (max < child_augmented) 92 max = child_augmented; 93 } 94 return max; 95 } 96 97 RB_DECLARE_CALLBACKS(static, augment_callbacks, struct test_node, rb, 98 u32, augmented, augment_recompute) 99 100 static void insert_augmented(struct test_node *node, 101 struct rb_root_cached *root) 102 { 103 struct rb_node **new = &root->rb_root.rb_node, *rb_parent = NULL; 104 u32 key = node->key; 105 u32 val = node->val; 106 struct test_node *parent; 107 108 while (*new) { 109 rb_parent = *new; 110 parent = rb_entry(rb_parent, struct test_node, rb); 111 if (parent->augmented < val) 112 parent->augmented = val; 113 if (key < parent->key) 114 new = &parent->rb.rb_left; 115 else 116 new = &parent->rb.rb_right; 117 } 118 119 node->augmented = val; 120 rb_link_node(&node->rb, rb_parent, new); 121 rb_insert_augmented(&node->rb, &root->rb_root, &augment_callbacks); 122 } 123 124 static void insert_augmented_cached(struct test_node *node, 125 struct rb_root_cached *root) 126 { 127 struct rb_node **new = &root->rb_root.rb_node, *rb_parent = NULL; 128 u32 key = node->key; 129 u32 val = node->val; 130 struct test_node *parent; 131 bool leftmost = true; 132 133 while (*new) { 134 rb_parent = *new; 135 parent = rb_entry(rb_parent, struct test_node, rb); 136 if (parent->augmented < val) 137 parent->augmented = val; 138 if (key < parent->key) 139 new = &parent->rb.rb_left; 140 else { 141 new = &parent->rb.rb_right; 142 leftmost = false; 143 } 144 } 145 146 node->augmented = val; 147 rb_link_node(&node->rb, rb_parent, new); 148 rb_insert_augmented_cached(&node->rb, root, 149 leftmost, &augment_callbacks); 150 } 151 152 153 static void erase_augmented(struct test_node *node, struct rb_root_cached *root) 154 { 155 rb_erase_augmented(&node->rb, &root->rb_root, &augment_callbacks); 156 } 157 158 static void erase_augmented_cached(struct test_node *node, 159 struct rb_root_cached *root) 160 { 161 rb_erase_augmented_cached(&node->rb, root, &augment_callbacks); 162 } 163 164 static void init(void) 165 { 166 int i; 167 for (i = 0; i < nnodes; i++) { 168 nodes[i].key = prandom_u32_state(&rnd); 169 nodes[i].val = prandom_u32_state(&rnd); 170 } 171 } 172 173 static bool is_red(struct rb_node *rb) 174 { 175 return !(rb->__rb_parent_color & 1); 176 } 177 178 static int black_path_count(struct rb_node *rb) 179 { 180 int count; 181 for (count = 0; rb; rb = rb_parent(rb)) 182 count += !is_red(rb); 183 return count; 184 } 185 186 static void check_postorder_foreach(int nr_nodes) 187 { 188 struct test_node *cur, *n; 189 int count = 0; 190 rbtree_postorder_for_each_entry_safe(cur, n, &root.rb_root, rb) 191 count++; 192 193 WARN_ON_ONCE(count != nr_nodes); 194 } 195 196 static void check_postorder(int nr_nodes) 197 { 198 struct rb_node *rb; 199 int count = 0; 200 for (rb = rb_first_postorder(&root.rb_root); rb; rb = rb_next_postorder(rb)) 201 count++; 202 203 WARN_ON_ONCE(count != nr_nodes); 204 } 205 206 static void check(int nr_nodes) 207 { 208 struct rb_node *rb; 209 int count = 0, blacks = 0; 210 u32 prev_key = 0; 211 212 for (rb = rb_first(&root.rb_root); rb; rb = rb_next(rb)) { 213 struct test_node *node = rb_entry(rb, struct test_node, rb); 214 WARN_ON_ONCE(node->key < prev_key); 215 WARN_ON_ONCE(is_red(rb) && 216 (!rb_parent(rb) || is_red(rb_parent(rb)))); 217 if (!count) 218 blacks = black_path_count(rb); 219 else 220 WARN_ON_ONCE((!rb->rb_left || !rb->rb_right) && 221 blacks != black_path_count(rb)); 222 prev_key = node->key; 223 count++; 224 } 225 226 WARN_ON_ONCE(count != nr_nodes); 227 WARN_ON_ONCE(count < (1 << black_path_count(rb_last(&root.rb_root))) - 1); 228 229 check_postorder(nr_nodes); 230 check_postorder_foreach(nr_nodes); 231 } 232 233 static void check_augmented(int nr_nodes) 234 { 235 struct rb_node *rb; 236 237 check(nr_nodes); 238 for (rb = rb_first(&root.rb_root); rb; rb = rb_next(rb)) { 239 struct test_node *node = rb_entry(rb, struct test_node, rb); 240 WARN_ON_ONCE(node->augmented != augment_recompute(node)); 241 } 242 } 243 244 static int __init rbtree_test_init(void) 245 { 246 int i, j; 247 cycles_t time1, time2, time; 248 struct rb_node *node; 249 250 nodes = kmalloc_array(nnodes, sizeof(*nodes), GFP_KERNEL); 251 if (!nodes) 252 return -ENOMEM; 253 254 printk(KERN_ALERT "rbtree testing"); 255 256 prandom_seed_state(&rnd, 3141592653589793238ULL); 257 init(); 258 259 time1 = get_cycles(); 260 261 for (i = 0; i < perf_loops; i++) { 262 for (j = 0; j < nnodes; j++) 263 insert(nodes + j, &root); 264 for (j = 0; j < nnodes; j++) 265 erase(nodes + j, &root); 266 } 267 268 time2 = get_cycles(); 269 time = time2 - time1; 270 271 time = div_u64(time, perf_loops); 272 printk(" -> test 1 (latency of nnodes insert+delete): %llu cycles\n", 273 (unsigned long long)time); 274 275 time1 = get_cycles(); 276 277 for (i = 0; i < perf_loops; i++) { 278 for (j = 0; j < nnodes; j++) 279 insert_cached(nodes + j, &root); 280 for (j = 0; j < nnodes; j++) 281 erase_cached(nodes + j, &root); 282 } 283 284 time2 = get_cycles(); 285 time = time2 - time1; 286 287 time = div_u64(time, perf_loops); 288 printk(" -> test 2 (latency of nnodes cached insert+delete): %llu cycles\n", 289 (unsigned long long)time); 290 291 for (i = 0; i < nnodes; i++) 292 insert(nodes + i, &root); 293 294 time1 = get_cycles(); 295 296 for (i = 0; i < perf_loops; i++) { 297 for (node = rb_first(&root.rb_root); node; node = rb_next(node)) 298 ; 299 } 300 301 time2 = get_cycles(); 302 time = time2 - time1; 303 304 time = div_u64(time, perf_loops); 305 printk(" -> test 3 (latency of inorder traversal): %llu cycles\n", 306 (unsigned long long)time); 307 308 time1 = get_cycles(); 309 310 for (i = 0; i < perf_loops; i++) 311 node = rb_first(&root.rb_root); 312 313 time2 = get_cycles(); 314 time = time2 - time1; 315 316 time = div_u64(time, perf_loops); 317 printk(" -> test 4 (latency to fetch first node)\n"); 318 printk(" non-cached: %llu cycles\n", (unsigned long long)time); 319 320 time1 = get_cycles(); 321 322 for (i = 0; i < perf_loops; i++) 323 node = rb_first_cached(&root); 324 325 time2 = get_cycles(); 326 time = time2 - time1; 327 328 time = div_u64(time, perf_loops); 329 printk(" cached: %llu cycles\n", (unsigned long long)time); 330 331 for (i = 0; i < nnodes; i++) 332 erase(nodes + i, &root); 333 334 /* run checks */ 335 for (i = 0; i < check_loops; i++) { 336 init(); 337 for (j = 0; j < nnodes; j++) { 338 check(j); 339 insert(nodes + j, &root); 340 } 341 for (j = 0; j < nnodes; j++) { 342 check(nnodes - j); 343 erase(nodes + j, &root); 344 } 345 check(0); 346 } 347 348 printk(KERN_ALERT "augmented rbtree testing"); 349 350 init(); 351 352 time1 = get_cycles(); 353 354 for (i = 0; i < perf_loops; i++) { 355 for (j = 0; j < nnodes; j++) 356 insert_augmented(nodes + j, &root); 357 for (j = 0; j < nnodes; j++) 358 erase_augmented(nodes + j, &root); 359 } 360 361 time2 = get_cycles(); 362 time = time2 - time1; 363 364 time = div_u64(time, perf_loops); 365 printk(" -> test 1 (latency of nnodes insert+delete): %llu cycles\n", (unsigned long long)time); 366 367 time1 = get_cycles(); 368 369 for (i = 0; i < perf_loops; i++) { 370 for (j = 0; j < nnodes; j++) 371 insert_augmented_cached(nodes + j, &root); 372 for (j = 0; j < nnodes; j++) 373 erase_augmented_cached(nodes + j, &root); 374 } 375 376 time2 = get_cycles(); 377 time = time2 - time1; 378 379 time = div_u64(time, perf_loops); 380 printk(" -> test 2 (latency of nnodes cached insert+delete): %llu cycles\n", (unsigned long long)time); 381 382 for (i = 0; i < check_loops; i++) { 383 init(); 384 for (j = 0; j < nnodes; j++) { 385 check_augmented(j); 386 insert_augmented(nodes + j, &root); 387 } 388 for (j = 0; j < nnodes; j++) { 389 check_augmented(nnodes - j); 390 erase_augmented(nodes + j, &root); 391 } 392 check_augmented(0); 393 } 394 395 kfree(nodes); 396 397 return -EAGAIN; /* Fail will directly unload the module */ 398 } 399 400 static void __exit rbtree_test_exit(void) 401 { 402 printk(KERN_ALERT "test exit\n"); 403 } 404 405 module_init(rbtree_test_init) 406 module_exit(rbtree_test_exit) 407 408 MODULE_LICENSE("GPL"); 409 MODULE_AUTHOR("Michel Lespinasse"); 410 MODULE_DESCRIPTION("Red Black Tree test"); 411