/*
 * Copyright (c) 2014, Cisco Systems, Inc. All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses. You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * BSD license below:
 *
 * Redistribution and use in source and binary forms, with or
 * without modification, are permitted provided that the following
 * conditions are met:
 *
 * - Redistributions of source code must retain the above
 * copyright notice, this list of conditions and the following
 * disclaimer.
 *
 * - Redistributions in binary form must reproduce the above
 * copyright notice, this list of conditions and the following
 * disclaimer in the documentation and/or other materials
 * provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
31 * 32 */ 33 34 #include <linux/init.h> 35 #include <linux/list.h> 36 #include <linux/slab.h> 37 #include <linux/list_sort.h> 38 39 #include <linux/interval_tree_generic.h> 40 #include "usnic_uiom_interval_tree.h" 41 42 #define START(node) ((node)->start) 43 #define LAST(node) ((node)->last) 44 45 #define MAKE_NODE(node, start, end, ref_cnt, flags, err, err_out) \ 46 do { \ 47 node = usnic_uiom_interval_node_alloc(start, \ 48 end, ref_cnt, flags); \ 49 if (!node) { \ 50 err = -ENOMEM; \ 51 goto err_out; \ 52 } \ 53 } while (0) 54 55 #define MARK_FOR_ADD(node, list) (list_add_tail(&node->link, list)) 56 57 #define MAKE_NODE_AND_APPEND(node, start, end, ref_cnt, flags, err, \ 58 err_out, list) \ 59 do { \ 60 MAKE_NODE(node, start, end, \ 61 ref_cnt, flags, err, \ 62 err_out); \ 63 MARK_FOR_ADD(node, list); \ 64 } while (0) 65 66 #define FLAGS_EQUAL(flags1, flags2, mask) \ 67 (((flags1) & (mask)) == ((flags2) & (mask))) 68 69 static struct usnic_uiom_interval_node* 70 usnic_uiom_interval_node_alloc(long int start, long int last, int ref_cnt, 71 int flags) 72 { 73 struct usnic_uiom_interval_node *interval = kzalloc(sizeof(*interval), 74 GFP_ATOMIC); 75 if (!interval) 76 return NULL; 77 78 interval->start = start; 79 interval->last = last; 80 interval->flags = flags; 81 interval->ref_cnt = ref_cnt; 82 83 return interval; 84 } 85 86 static int interval_cmp(void *priv, struct list_head *a, struct list_head *b) 87 { 88 struct usnic_uiom_interval_node *node_a, *node_b; 89 90 node_a = list_entry(a, struct usnic_uiom_interval_node, link); 91 node_b = list_entry(b, struct usnic_uiom_interval_node, link); 92 93 /* long to int */ 94 if (node_a->start < node_b->start) 95 return -1; 96 else if (node_a->start > node_b->start) 97 return 1; 98 99 return 0; 100 } 101 102 static void 103 find_intervals_intersection_sorted(struct rb_root_cached *root, 104 unsigned long start, unsigned long last, 105 struct list_head *list) 106 { 107 struct usnic_uiom_interval_node *node; 108 109 
INIT_LIST_HEAD(list); 110 111 for (node = usnic_uiom_interval_tree_iter_first(root, start, last); 112 node; 113 node = usnic_uiom_interval_tree_iter_next(node, start, last)) 114 list_add_tail(&node->link, list); 115 116 list_sort(NULL, list, interval_cmp); 117 } 118 119 int usnic_uiom_get_intervals_diff(unsigned long start, unsigned long last, 120 int flags, int flag_mask, 121 struct rb_root_cached *root, 122 struct list_head *diff_set) 123 { 124 struct usnic_uiom_interval_node *interval, *tmp; 125 int err = 0; 126 long int pivot = start; 127 LIST_HEAD(intersection_set); 128 129 INIT_LIST_HEAD(diff_set); 130 131 find_intervals_intersection_sorted(root, start, last, 132 &intersection_set); 133 134 list_for_each_entry(interval, &intersection_set, link) { 135 if (pivot < interval->start) { 136 MAKE_NODE_AND_APPEND(tmp, pivot, interval->start - 1, 137 1, flags, err, err_out, 138 diff_set); 139 pivot = interval->start; 140 } 141 142 /* 143 * Invariant: Set [start, pivot] is either in diff_set or root, 144 * but not in both. 
145 */ 146 147 if (pivot > interval->last) { 148 continue; 149 } else if (pivot <= interval->last && 150 FLAGS_EQUAL(interval->flags, flags, 151 flag_mask)) { 152 pivot = interval->last + 1; 153 } 154 } 155 156 if (pivot <= last) 157 MAKE_NODE_AND_APPEND(tmp, pivot, last, 1, flags, err, err_out, 158 diff_set); 159 160 return 0; 161 162 err_out: 163 list_for_each_entry_safe(interval, tmp, diff_set, link) { 164 list_del(&interval->link); 165 kfree(interval); 166 } 167 168 return err; 169 } 170 171 void usnic_uiom_put_interval_set(struct list_head *intervals) 172 { 173 struct usnic_uiom_interval_node *interval, *tmp; 174 list_for_each_entry_safe(interval, tmp, intervals, link) 175 kfree(interval); 176 } 177 178 int usnic_uiom_insert_interval(struct rb_root_cached *root, unsigned long start, 179 unsigned long last, int flags) 180 { 181 struct usnic_uiom_interval_node *interval, *tmp; 182 unsigned long istart, ilast; 183 int iref_cnt, iflags; 184 unsigned long lpivot = start; 185 int err = 0; 186 LIST_HEAD(to_add); 187 LIST_HEAD(intersection_set); 188 189 find_intervals_intersection_sorted(root, start, last, 190 &intersection_set); 191 192 list_for_each_entry(interval, &intersection_set, link) { 193 /* 194 * Invariant - lpivot is the left edge of next interval to be 195 * inserted 196 */ 197 istart = interval->start; 198 ilast = interval->last; 199 iref_cnt = interval->ref_cnt; 200 iflags = interval->flags; 201 202 if (istart < lpivot) { 203 MAKE_NODE_AND_APPEND(tmp, istart, lpivot - 1, iref_cnt, 204 iflags, err, err_out, &to_add); 205 } else if (istart > lpivot) { 206 MAKE_NODE_AND_APPEND(tmp, lpivot, istart - 1, 1, flags, 207 err, err_out, &to_add); 208 lpivot = istart; 209 } else { 210 lpivot = istart; 211 } 212 213 if (ilast > last) { 214 MAKE_NODE_AND_APPEND(tmp, lpivot, last, iref_cnt + 1, 215 iflags | flags, err, err_out, 216 &to_add); 217 MAKE_NODE_AND_APPEND(tmp, last + 1, ilast, iref_cnt, 218 iflags, err, err_out, &to_add); 219 } else { 220 
MAKE_NODE_AND_APPEND(tmp, lpivot, ilast, iref_cnt + 1, 221 iflags | flags, err, err_out, 222 &to_add); 223 } 224 225 lpivot = ilast + 1; 226 } 227 228 if (lpivot <= last) 229 MAKE_NODE_AND_APPEND(tmp, lpivot, last, 1, flags, err, err_out, 230 &to_add); 231 232 list_for_each_entry_safe(interval, tmp, &intersection_set, link) { 233 usnic_uiom_interval_tree_remove(interval, root); 234 kfree(interval); 235 } 236 237 list_for_each_entry(interval, &to_add, link) 238 usnic_uiom_interval_tree_insert(interval, root); 239 240 return 0; 241 242 err_out: 243 list_for_each_entry_safe(interval, tmp, &to_add, link) 244 kfree(interval); 245 246 return err; 247 } 248 249 void usnic_uiom_remove_interval(struct rb_root_cached *root, 250 unsigned long start, unsigned long last, 251 struct list_head *removed) 252 { 253 struct usnic_uiom_interval_node *interval; 254 255 for (interval = usnic_uiom_interval_tree_iter_first(root, start, last); 256 interval; 257 interval = usnic_uiom_interval_tree_iter_next(interval, 258 start, 259 last)) { 260 if (--interval->ref_cnt == 0) 261 list_add_tail(&interval->link, removed); 262 } 263 264 list_for_each_entry(interval, removed, link) 265 usnic_uiom_interval_tree_remove(interval, root); 266 } 267 268 INTERVAL_TREE_DEFINE(struct usnic_uiom_interval_node, rb, 269 unsigned long, __subtree_last, 270 START, LAST, , usnic_uiom_interval_tree) 271