1 /* 2 * Copyright (c) 2014, Cisco Systems, Inc. All rights reserved. 3 * 4 * This software is available to you under a choice of one of two 5 * licenses. You may choose to be licensed under the terms of the GNU 6 * General Public License (GPL) Version 2, available from the file 7 * COPYING in the main directory of this source tree, or the 8 * BSD license below: 9 * 10 * Redistribution and use in source and binary forms, with or 11 * without modification, are permitted provided that the following 12 * conditions are met: 13 * 14 * - Redistributions of source code must retain the above 15 * copyright notice, this list of conditions and the following 16 * disclaimer. 17 * 18 * - Redistributions in binary form must reproduce the above 19 * copyright notice, this list of conditions and the following 20 * disclaimer in the documentation and/or other materials 21 * provided with the distribution. 22 * 23 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 24 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 25 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 26 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS 27 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN 28 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 29 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 30 * SOFTWARE. 31 * 32 */ 33 34 #include <linux/init.h> 35 #include <linux/list.h> 36 #include <linux/slab.h> 37 #include <linux/list_sort.h> 38 39 #include <linux/interval_tree_generic.h> 40 #include "usnic_uiom_interval_tree.h" 41 42 #define START(node) ((node)->start) 43 #define LAST(node) ((node)->last) 44 45 #define MAKE_NODE(node, start, end, ref_cnt, flags, err, err_out) \ 46 do { \ 47 node = usnic_uiom_interval_node_alloc(start, \ 48 end, ref_cnt, flags); \ 49 if (!node) { \ 50 err = -ENOMEM; \ 51 goto err_out; \ 52 } \ 53 } while (0) 54 55 #define MARK_FOR_ADD(node, list) (list_add_tail(&node->link, list)) 56 57 #define MAKE_NODE_AND_APPEND(node, start, end, ref_cnt, flags, err, \ 58 err_out, list) \ 59 do { \ 60 MAKE_NODE(node, start, end, \ 61 ref_cnt, flags, err, \ 62 err_out); \ 63 MARK_FOR_ADD(node, list); \ 64 } while (0) 65 66 #define FLAGS_EQUAL(flags1, flags2, mask) \ 67 (((flags1) & (mask)) == ((flags2) & (mask))) 68 69 static struct usnic_uiom_interval_node* 70 usnic_uiom_interval_node_alloc(long int start, long int last, int ref_cnt, 71 int flags) 72 { 73 struct usnic_uiom_interval_node *interval = kzalloc(sizeof(*interval), 74 GFP_ATOMIC); 75 if (!interval) 76 return NULL; 77 78 interval->start = start; 79 interval->last = last; 80 interval->flags = flags; 81 interval->ref_cnt = ref_cnt; 82 83 return interval; 84 } 85 86 static int interval_cmp(void *priv, const struct list_head *a, 87 const struct list_head *b) 88 { 89 struct usnic_uiom_interval_node *node_a, *node_b; 90 91 node_a = list_entry(a, struct usnic_uiom_interval_node, link); 92 node_b = list_entry(b, struct usnic_uiom_interval_node, link); 93 94 /* long to int */ 95 if (node_a->start < node_b->start) 96 return -1; 97 else if (node_a->start > node_b->start) 98 return 1; 99 100 return 0; 101 } 102 103 static void 104 find_intervals_intersection_sorted(struct rb_root_cached *root, 105 unsigned long start, unsigned long last, 106 struct list_head *list) 107 { 108 struct usnic_uiom_interval_node *node; 109 110 INIT_LIST_HEAD(list); 111 112 for (node = usnic_uiom_interval_tree_iter_first(root, start, last); 113 node; 114 node = usnic_uiom_interval_tree_iter_next(node, start, last)) 115 list_add_tail(&node->link, list); 116 117 list_sort(NULL, list, interval_cmp); 118 } 119 120 int usnic_uiom_get_intervals_diff(unsigned long start, unsigned long last, 121 int flags, int flag_mask, 122 struct rb_root_cached *root, 123 struct list_head *diff_set) 124 { 125 struct usnic_uiom_interval_node *interval, *tmp; 126 int err = 0; 127 long int pivot = start; 128 LIST_HEAD(intersection_set); 129 130 INIT_LIST_HEAD(diff_set); 131 132 find_intervals_intersection_sorted(root, start, last, 133 &intersection_set); 134 135 list_for_each_entry(interval, &intersection_set, link) { 136 if (pivot < interval->start) { 137 MAKE_NODE_AND_APPEND(tmp, pivot, interval->start - 1, 138 1, flags, err, err_out, 139 diff_set); 140 pivot = interval->start; 141 } 142 143 /* 144 * Invariant: Set [start, pivot] is either in diff_set or root, 145 * but not in both. 146 */ 147 148 if (pivot > interval->last) { 149 continue; 150 } else if (pivot <= interval->last && 151 FLAGS_EQUAL(interval->flags, flags, 152 flag_mask)) { 153 pivot = interval->last + 1; 154 } 155 } 156 157 if (pivot <= last) 158 MAKE_NODE_AND_APPEND(tmp, pivot, last, 1, flags, err, err_out, 159 diff_set); 160 161 return 0; 162 163 err_out: 164 list_for_each_entry_safe(interval, tmp, diff_set, link) { 165 list_del(&interval->link); 166 kfree(interval); 167 } 168 169 return err; 170 } 171 172 void usnic_uiom_put_interval_set(struct list_head *intervals) 173 { 174 struct usnic_uiom_interval_node *interval, *tmp; 175 list_for_each_entry_safe(interval, tmp, intervals, link) 176 kfree(interval); 177 } 178 179 int usnic_uiom_insert_interval(struct rb_root_cached *root, unsigned long start, 180 unsigned long last, int flags) 181 { 182 struct usnic_uiom_interval_node *interval, *tmp; 183 unsigned long istart, ilast; 184 int iref_cnt, iflags; 185 unsigned long lpivot = start; 186 int err = 0; 187 LIST_HEAD(to_add); 188 LIST_HEAD(intersection_set); 189 190 find_intervals_intersection_sorted(root, start, last, 191 &intersection_set); 192 193 list_for_each_entry(interval, &intersection_set, link) { 194 /* 195 * Invariant - lpivot is the left edge of next interval to be 196 * inserted 197 */ 198 istart = interval->start; 199 ilast = interval->last; 200 iref_cnt = interval->ref_cnt; 201 iflags = interval->flags; 202 203 if (istart < lpivot) { 204 MAKE_NODE_AND_APPEND(tmp, istart, lpivot - 1, iref_cnt, 205 iflags, err, err_out, &to_add); 206 } else if (istart > lpivot) { 207 MAKE_NODE_AND_APPEND(tmp, lpivot, istart - 1, 1, flags, 208 err, err_out, &to_add); 209 lpivot = istart; 210 } else { 211 lpivot = istart; 212 } 213 214 if (ilast > last) { 215 MAKE_NODE_AND_APPEND(tmp, lpivot, last, iref_cnt + 1, 216 iflags | flags, err, err_out, 217 &to_add); 218 MAKE_NODE_AND_APPEND(tmp, last + 1, ilast, iref_cnt, 219 iflags, err, err_out, &to_add); 220 } else { 221 MAKE_NODE_AND_APPEND(tmp, lpivot, ilast, iref_cnt + 1, 222 iflags | flags, err, err_out, 223 &to_add); 224 } 225 226 lpivot = ilast + 1; 227 } 228 229 if (lpivot <= last) 230 MAKE_NODE_AND_APPEND(tmp, lpivot, last, 1, flags, err, err_out, 231 &to_add); 232 233 list_for_each_entry_safe(interval, tmp, &intersection_set, link) { 234 usnic_uiom_interval_tree_remove(interval, root); 235 kfree(interval); 236 } 237 238 list_for_each_entry(interval, &to_add, link) 239 usnic_uiom_interval_tree_insert(interval, root); 240 241 return 0; 242 243 err_out: 244 list_for_each_entry_safe(interval, tmp, &to_add, link) 245 kfree(interval); 246 247 return err; 248 } 249 250 void usnic_uiom_remove_interval(struct rb_root_cached *root, 251 unsigned long start, unsigned long last, 252 struct list_head *removed) 253 { 254 struct usnic_uiom_interval_node *interval; 255 256 for (interval = usnic_uiom_interval_tree_iter_first(root, start, last); 257 interval; 258 interval = usnic_uiom_interval_tree_iter_next(interval, 259 start, 260 last)) { 261 if (--interval->ref_cnt == 0) 262 list_add_tail(&interval->link, removed); 263 } 264 265 list_for_each_entry(interval, removed, link) 266 usnic_uiom_interval_tree_remove(interval, root); 267 } 268 269 INTERVAL_TREE_DEFINE(struct usnic_uiom_interval_node, rb, 270 unsigned long, __subtree_last, 271 START, LAST, , usnic_uiom_interval_tree) 272