1 // SPDX-License-Identifier: GPL-2.0-only 2 #include <linux/module.h> 3 #include <linux/moduleparam.h> 4 #include <linux/rbtree_augmented.h> 5 #include <linux/random.h> 6 #include <linux/slab.h> 7 #include <asm/timex.h> 8 9 #define __param(type, name, init, msg) \ 10 static type name = init; \ 11 module_param(name, type, 0444); \ 12 MODULE_PARM_DESC(name, msg); 13 14 __param(int, nnodes, 100, "Number of nodes in the rb-tree"); 15 __param(int, perf_loops, 1000, "Number of iterations modifying the rb-tree"); 16 __param(int, check_loops, 100, "Number of iterations modifying and verifying the rb-tree"); 17 18 struct test_node { 19 u32 key; 20 struct rb_node rb; 21 22 /* following fields used for testing augmented rbtree functionality */ 23 u32 val; 24 u32 augmented; 25 }; 26 27 static struct rb_root_cached root = RB_ROOT_CACHED; 28 static struct test_node *nodes = NULL; 29 30 static struct rnd_state rnd; 31 32 static void insert(struct test_node *node, struct rb_root_cached *root) 33 { 34 struct rb_node **new = &root->rb_root.rb_node, *parent = NULL; 35 u32 key = node->key; 36 37 while (*new) { 38 parent = *new; 39 if (key < rb_entry(parent, struct test_node, rb)->key) 40 new = &parent->rb_left; 41 else 42 new = &parent->rb_right; 43 } 44 45 rb_link_node(&node->rb, parent, new); 46 rb_insert_color(&node->rb, &root->rb_root); 47 } 48 49 static void insert_cached(struct test_node *node, struct rb_root_cached *root) 50 { 51 struct rb_node **new = &root->rb_root.rb_node, *parent = NULL; 52 u32 key = node->key; 53 bool leftmost = true; 54 55 while (*new) { 56 parent = *new; 57 if (key < rb_entry(parent, struct test_node, rb)->key) 58 new = &parent->rb_left; 59 else { 60 new = &parent->rb_right; 61 leftmost = false; 62 } 63 } 64 65 rb_link_node(&node->rb, parent, new); 66 rb_insert_color_cached(&node->rb, root, leftmost); 67 } 68 69 static inline void erase(struct test_node *node, struct rb_root_cached *root) 70 { 71 rb_erase(&node->rb, &root->rb_root); 72 } 73 74 static inline void erase_cached(struct test_node *node, struct rb_root_cached *root) 75 { 76 rb_erase_cached(&node->rb, root); 77 } 78 79 80 #define NODE_VAL(node) ((node)->val) 81 82 RB_DECLARE_CALLBACKS_MAX(static, augment_callbacks, 83 struct test_node, rb, u32, augmented, NODE_VAL) 84 85 static void insert_augmented(struct test_node *node, 86 struct rb_root_cached *root) 87 { 88 struct rb_node **new = &root->rb_root.rb_node, *rb_parent = NULL; 89 u32 key = node->key; 90 u32 val = node->val; 91 struct test_node *parent; 92 93 while (*new) { 94 rb_parent = *new; 95 parent = rb_entry(rb_parent, struct test_node, rb); 96 if (parent->augmented < val) 97 parent->augmented = val; 98 if (key < parent->key) 99 new = &parent->rb.rb_left; 100 else 101 new = &parent->rb.rb_right; 102 } 103 104 node->augmented = val; 105 rb_link_node(&node->rb, rb_parent, new); 106 rb_insert_augmented(&node->rb, &root->rb_root, &augment_callbacks); 107 } 108 109 static void insert_augmented_cached(struct test_node *node, 110 struct rb_root_cached *root) 111 { 112 struct rb_node **new = &root->rb_root.rb_node, *rb_parent = NULL; 113 u32 key = node->key; 114 u32 val = node->val; 115 struct test_node *parent; 116 bool leftmost = true; 117 118 while (*new) { 119 rb_parent = *new; 120 parent = rb_entry(rb_parent, struct test_node, rb); 121 if (parent->augmented < val) 122 parent->augmented = val; 123 if (key < parent->key) 124 new = &parent->rb.rb_left; 125 else { 126 new = &parent->rb.rb_right; 127 leftmost = false; 128 } 129 } 130 131 node->augmented = val; 132 rb_link_node(&node->rb, rb_parent, new); 133 rb_insert_augmented_cached(&node->rb, root, 134 leftmost, &augment_callbacks); 135 } 136 137 138 static void erase_augmented(struct test_node *node, struct rb_root_cached *root) 139 { 140 rb_erase_augmented(&node->rb, &root->rb_root, &augment_callbacks); 141 } 142 143 static void erase_augmented_cached(struct test_node *node, 144 struct rb_root_cached *root) 145 { 146 rb_erase_augmented_cached(&node->rb, root, &augment_callbacks); 147 } 148 149 static void init(void) 150 { 151 int i; 152 for (i = 0; i < nnodes; i++) { 153 nodes[i].key = prandom_u32_state(&rnd); 154 nodes[i].val = prandom_u32_state(&rnd); 155 } 156 } 157 158 static bool is_red(struct rb_node *rb) 159 { 160 return !(rb->__rb_parent_color & 1); 161 } 162 163 static int black_path_count(struct rb_node *rb) 164 { 165 int count; 166 for (count = 0; rb; rb = rb_parent(rb)) 167 count += !is_red(rb); 168 return count; 169 } 170 171 static void check_postorder_foreach(int nr_nodes) 172 { 173 struct test_node *cur, *n; 174 int count = 0; 175 rbtree_postorder_for_each_entry_safe(cur, n, &root.rb_root, rb) 176 count++; 177 178 WARN_ON_ONCE(count != nr_nodes); 179 } 180 181 static void check_postorder(int nr_nodes) 182 { 183 struct rb_node *rb; 184 int count = 0; 185 for (rb = rb_first_postorder(&root.rb_root); rb; rb = rb_next_postorder(rb)) 186 count++; 187 188 WARN_ON_ONCE(count != nr_nodes); 189 } 190 191 static void check(int nr_nodes) 192 { 193 struct rb_node *rb; 194 int count = 0, blacks = 0; 195 u32 prev_key = 0; 196 197 for (rb = rb_first(&root.rb_root); rb; rb = rb_next(rb)) { 198 struct test_node *node = rb_entry(rb, struct test_node, rb); 199 WARN_ON_ONCE(node->key < prev_key); 200 WARN_ON_ONCE(is_red(rb) && 201 (!rb_parent(rb) || is_red(rb_parent(rb)))); 202 if (!count) 203 blacks = black_path_count(rb); 204 else 205 WARN_ON_ONCE((!rb->rb_left || !rb->rb_right) && 206 blacks != black_path_count(rb)); 207 prev_key = node->key; 208 count++; 209 } 210 211 WARN_ON_ONCE(count != nr_nodes); 212 WARN_ON_ONCE(count < (1 << black_path_count(rb_last(&root.rb_root))) - 1); 213 214 check_postorder(nr_nodes); 215 check_postorder_foreach(nr_nodes); 216 } 217 218 static void check_augmented(int nr_nodes) 219 { 220 struct rb_node *rb; 221 222 check(nr_nodes); 223 for (rb = rb_first(&root.rb_root); rb; rb = rb_next(rb)) { 224 struct test_node *node = rb_entry(rb, struct test_node, rb); 225 u32 subtree, max = node->val; 226 if (node->rb.rb_left) { 227 subtree = rb_entry(node->rb.rb_left, struct test_node, 228 rb)->augmented; 229 if (max < subtree) 230 max = subtree; 231 } 232 if (node->rb.rb_right) { 233 subtree = rb_entry(node->rb.rb_right, struct test_node, 234 rb)->augmented; 235 if (max < subtree) 236 max = subtree; 237 } 238 WARN_ON_ONCE(node->augmented != max); 239 } 240 } 241 242 static int __init rbtree_test_init(void) 243 { 244 int i, j; 245 cycles_t time1, time2, time; 246 struct rb_node *node; 247 248 nodes = kmalloc_array(nnodes, sizeof(*nodes), GFP_KERNEL); 249 if (!nodes) 250 return -ENOMEM; 251 252 printk(KERN_ALERT "rbtree testing"); 253 254 prandom_seed_state(&rnd, 3141592653589793238ULL); 255 init(); 256 257 time1 = get_cycles(); 258 259 for (i = 0; i < perf_loops; i++) { 260 for (j = 0; j < nnodes; j++) 261 insert(nodes + j, &root); 262 for (j = 0; j < nnodes; j++) 263 erase(nodes + j, &root); 264 } 265 266 time2 = get_cycles(); 267 time = time2 - time1; 268 269 time = div_u64(time, perf_loops); 270 printk(" -> test 1 (latency of nnodes insert+delete): %llu cycles\n", 271 (unsigned long long)time); 272 273 time1 = get_cycles(); 274 275 for (i = 0; i < perf_loops; i++) { 276 for (j = 0; j < nnodes; j++) 277 insert_cached(nodes + j, &root); 278 for (j = 0; j < nnodes; j++) 279 erase_cached(nodes + j, &root); 280 } 281 282 time2 = get_cycles(); 283 time = time2 - time1; 284 285 time = div_u64(time, perf_loops); 286 printk(" -> test 2 (latency of nnodes cached insert+delete): %llu cycles\n", 287 (unsigned long long)time); 288 289 for (i = 0; i < nnodes; i++) 290 insert(nodes + i, &root); 291 292 time1 = get_cycles(); 293 294 for (i = 0; i < perf_loops; i++) { 295 for (node = rb_first(&root.rb_root); node; node = rb_next(node)) 296 ; 297 } 298 299 time2 = get_cycles(); 300 time = time2 - time1; 301 302 time = div_u64(time, perf_loops); 303 printk(" -> test 3 (latency of inorder traversal): %llu cycles\n", 304 (unsigned long long)time); 305 306 time1 = get_cycles(); 307 308 for (i = 0; i < perf_loops; i++) 309 node = rb_first(&root.rb_root); 310 311 time2 = get_cycles(); 312 time = time2 - time1; 313 314 time = div_u64(time, perf_loops); 315 printk(" -> test 4 (latency to fetch first node)\n"); 316 printk(" non-cached: %llu cycles\n", (unsigned long long)time); 317 318 time1 = get_cycles(); 319 320 for (i = 0; i < perf_loops; i++) 321 node = rb_first_cached(&root); 322 323 time2 = get_cycles(); 324 time = time2 - time1; 325 326 time = div_u64(time, perf_loops); 327 printk(" cached: %llu cycles\n", (unsigned long long)time); 328 329 for (i = 0; i < nnodes; i++) 330 erase(nodes + i, &root); 331 332 /* run checks */ 333 for (i = 0; i < check_loops; i++) { 334 init(); 335 for (j = 0; j < nnodes; j++) { 336 check(j); 337 insert(nodes + j, &root); 338 } 339 for (j = 0; j < nnodes; j++) { 340 check(nnodes - j); 341 erase(nodes + j, &root); 342 } 343 check(0); 344 } 345 346 printk(KERN_ALERT "augmented rbtree testing"); 347 348 init(); 349 350 time1 = get_cycles(); 351 352 for (i = 0; i < perf_loops; i++) { 353 for (j = 0; j < nnodes; j++) 354 insert_augmented(nodes + j, &root); 355 for (j = 0; j < nnodes; j++) 356 erase_augmented(nodes + j, &root); 357 } 358 359 time2 = get_cycles(); 360 time = time2 - time1; 361 362 time = div_u64(time, perf_loops); 363 printk(" -> test 1 (latency of nnodes insert+delete): %llu cycles\n", (unsigned long long)time); 364 365 time1 = get_cycles(); 366 367 for (i = 0; i < perf_loops; i++) { 368 for (j = 0; j < nnodes; j++) 369 insert_augmented_cached(nodes + j, &root); 370 for (j = 0; j < nnodes; j++) 371 erase_augmented_cached(nodes + j, &root); 372 } 373 374 time2 = get_cycles(); 375 time = time2 - time1; 376 377 time = div_u64(time, perf_loops); 378 printk(" -> test 2 (latency of nnodes cached insert+delete): %llu cycles\n", (unsigned long long)time); 379 380 for (i = 0; i < check_loops; i++) { 381 init(); 382 for (j = 0; j < nnodes; j++) { 383 check_augmented(j); 384 insert_augmented(nodes + j, &root); 385 } 386 for (j = 0; j < nnodes; j++) { 387 check_augmented(nnodes - j); 388 erase_augmented(nodes + j, &root); 389 } 390 check_augmented(0); 391 } 392 393 kfree(nodes); 394 395 return -EAGAIN; /* Fail will directly unload the module */ 396 } 397 398 static void __exit rbtree_test_exit(void) 399 { 400 printk(KERN_ALERT "test exit\n"); 401 } 402 403 module_init(rbtree_test_init) 404 module_exit(rbtree_test_exit) 405 406 MODULE_LICENSE("GPL"); 407 MODULE_AUTHOR("Michel Lespinasse"); 408 MODULE_DESCRIPTION("Red Black Tree test"); 409