1 // SPDX-License-Identifier: GPL-2.0 2 #include <asm/bug.h> 3 #include <linux/rbtree_augmented.h> 4 #include "drbd_interval.h" 5 6 /** 7 * interval_end - return end of @node 8 */ 9 static inline 10 sector_t interval_end(struct rb_node *node) 11 { 12 struct drbd_interval *this = rb_entry(node, struct drbd_interval, rb); 13 return this->end; 14 } 15 16 #define NODE_END(node) ((node)->sector + ((node)->size >> 9)) 17 18 RB_DECLARE_CALLBACKS_MAX(static, augment_callbacks, 19 struct drbd_interval, rb, sector_t, end, NODE_END); 20 21 /** 22 * drbd_insert_interval - insert a new interval into a tree 23 */ 24 bool 25 drbd_insert_interval(struct rb_root *root, struct drbd_interval *this) 26 { 27 struct rb_node **new = &root->rb_node, *parent = NULL; 28 sector_t this_end = this->sector + (this->size >> 9); 29 30 BUG_ON(!IS_ALIGNED(this->size, 512)); 31 32 while (*new) { 33 struct drbd_interval *here = 34 rb_entry(*new, struct drbd_interval, rb); 35 36 parent = *new; 37 if (here->end < this_end) 38 here->end = this_end; 39 if (this->sector < here->sector) 40 new = &(*new)->rb_left; 41 else if (this->sector > here->sector) 42 new = &(*new)->rb_right; 43 else if (this < here) 44 new = &(*new)->rb_left; 45 else if (this > here) 46 new = &(*new)->rb_right; 47 else 48 return false; 49 } 50 51 this->end = this_end; 52 rb_link_node(&this->rb, parent, new); 53 rb_insert_augmented(&this->rb, root, &augment_callbacks); 54 return true; 55 } 56 57 /** 58 * drbd_contains_interval - check if a tree contains a given interval 59 * @sector: start sector of @interval 60 * @interval: may not be a valid pointer 61 * 62 * Returns if the tree contains the node @interval with start sector @start. 63 * Does not dereference @interval until @interval is known to be a valid object 64 * in @tree. Returns %false if @interval is in the tree but with a different 65 * sector number. 66 */ 67 bool 68 drbd_contains_interval(struct rb_root *root, sector_t sector, 69 struct drbd_interval *interval) 70 { 71 struct rb_node *node = root->rb_node; 72 73 while (node) { 74 struct drbd_interval *here = 75 rb_entry(node, struct drbd_interval, rb); 76 77 if (sector < here->sector) 78 node = node->rb_left; 79 else if (sector > here->sector) 80 node = node->rb_right; 81 else if (interval < here) 82 node = node->rb_left; 83 else if (interval > here) 84 node = node->rb_right; 85 else 86 return true; 87 } 88 return false; 89 } 90 91 /** 92 * drbd_remove_interval - remove an interval from a tree 93 */ 94 void 95 drbd_remove_interval(struct rb_root *root, struct drbd_interval *this) 96 { 97 rb_erase_augmented(&this->rb, root, &augment_callbacks); 98 } 99 100 /** 101 * drbd_find_overlap - search for an interval overlapping with [sector, sector + size) 102 * @sector: start sector 103 * @size: size, aligned to 512 bytes 104 * 105 * Returns an interval overlapping with [sector, sector + size), or NULL if 106 * there is none. When there is more than one overlapping interval in the 107 * tree, the interval with the lowest start sector is returned, and all other 108 * overlapping intervals will be on the right side of the tree, reachable with 109 * rb_next(). 110 */ 111 struct drbd_interval * 112 drbd_find_overlap(struct rb_root *root, sector_t sector, unsigned int size) 113 { 114 struct rb_node *node = root->rb_node; 115 struct drbd_interval *overlap = NULL; 116 sector_t end = sector + (size >> 9); 117 118 BUG_ON(!IS_ALIGNED(size, 512)); 119 120 while (node) { 121 struct drbd_interval *here = 122 rb_entry(node, struct drbd_interval, rb); 123 124 if (node->rb_left && 125 sector < interval_end(node->rb_left)) { 126 /* Overlap if any must be on left side */ 127 node = node->rb_left; 128 } else if (here->sector < end && 129 sector < here->sector + (here->size >> 9)) { 130 overlap = here; 131 break; 132 } else if (sector >= here->sector) { 133 /* Overlap if any must be on right side */ 134 node = node->rb_right; 135 } else 136 break; 137 } 138 return overlap; 139 } 140 141 struct drbd_interval * 142 drbd_next_overlap(struct drbd_interval *i, sector_t sector, unsigned int size) 143 { 144 sector_t end = sector + (size >> 9); 145 struct rb_node *node; 146 147 for (;;) { 148 node = rb_next(&i->rb); 149 if (!node) 150 return NULL; 151 i = rb_entry(node, struct drbd_interval, rb); 152 if (i->sector >= end) 153 return NULL; 154 if (sector < i->sector + (i->size >> 9)) 155 return i; 156 } 157 } 158