1 /*
2  * Copyright (c) 2014, Cisco Systems, Inc. All rights reserved.
3  *
4  * This program is free software; you may redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation; version 2 of the License.
7  *
8  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
9  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
10  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
11  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
12  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
13  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
14  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15  * SOFTWARE.
16  *
17  */
18 
19 #include <linux/init.h>
20 #include <linux/list.h>
21 #include <linux/slab.h>
22 #include <linux/list_sort.h>
23 
24 #include <linux/interval_tree_generic.h>
25 #include "usnic_uiom_interval_tree.h"
26 
/* Interval-bound accessors handed to INTERVAL_TREE_DEFINE below; the
 * tree stores inclusive [start, last] ranges.
 */
#define START(node) ((node)->start)
#define LAST(node) ((node)->last)
29 
/*
 * Allocate an interval node covering [start, end] into @node; on
 * allocation failure set @err to -ENOMEM and jump to the
 * caller-supplied @err_out label.  do { } while (0) so the macro
 * behaves as a single statement.
 */
#define MAKE_NODE(node, start, end, ref_cnt, flags, err, err_out)	\
		do {							\
			node = usnic_uiom_interval_node_alloc(start,	\
					end, ref_cnt, flags);		\
			if (!node) {					\
				err = -ENOMEM;				\
				goto err_out;				\
			}						\
		} while (0)
39 
/* Queue @node on @list; arguments parenthesized for macro hygiene. */
#define MARK_FOR_ADD(node, list) (list_add_tail(&(node)->link, (list)))
41 
/*
 * Allocate a node for [start, end] and append it to @list.  On
 * allocation failure sets @err and jumps to @err_out; nodes already
 * appended remain on @list for the caller's error path to free.
 */
#define MAKE_NODE_AND_APPEND(node, start, end, ref_cnt, flags, err,	\
				err_out, list)				\
				do {					\
					MAKE_NODE(node, start, end,	\
						ref_cnt, flags, err,	\
						err_out);		\
					MARK_FOR_ADD(node, list);	\
				} while (0)
50 
/* True when @flags1 and @flags2 agree on every bit selected by @mask. */
#define FLAGS_EQUAL(flags1, flags2, mask)				\
			((((flags1) ^ (flags2)) & (mask)) == 0)
53 
54 static struct usnic_uiom_interval_node*
55 usnic_uiom_interval_node_alloc(long int start, long int last, int ref_cnt,
56 				int flags)
57 {
58 	struct usnic_uiom_interval_node *interval = kzalloc(sizeof(*interval),
59 								GFP_ATOMIC);
60 	if (!interval)
61 		return NULL;
62 
63 	interval->start = start;
64 	interval->last = last;
65 	interval->flags = flags;
66 	interval->ref_cnt = ref_cnt;
67 
68 	return interval;
69 }
70 
71 static int interval_cmp(void *priv, struct list_head *a, struct list_head *b)
72 {
73 	struct usnic_uiom_interval_node *node_a, *node_b;
74 
75 	node_a = list_entry(a, struct usnic_uiom_interval_node, link);
76 	node_b = list_entry(b, struct usnic_uiom_interval_node, link);
77 
78 	/* long to int */
79 	if (node_a->start < node_b->start)
80 		return -1;
81 	else if (node_a->start > node_b->start)
82 		return 1;
83 
84 	return 0;
85 }
86 
87 static void
88 find_intervals_intersection_sorted(struct rb_root *root, unsigned long start,
89 					unsigned long last,
90 					struct list_head *list)
91 {
92 	struct usnic_uiom_interval_node *node;
93 
94 	INIT_LIST_HEAD(list);
95 
96 	for (node = usnic_uiom_interval_tree_iter_first(root, start, last);
97 		node;
98 		node = usnic_uiom_interval_tree_iter_next(node, start, last))
99 		list_add_tail(&node->link, list);
100 
101 	list_sort(NULL, list, interval_cmp);
102 }
103 
/*
 * Compute which parts of [start, last] are not already covered in
 * @root with flags matching @flags under @flag_mask, and build a
 * freshly allocated node (ref_cnt 1, @flags) for each such sub-range
 * on @diff_set.
 *
 * Returns 0 on success; on allocation failure returns -ENOMEM and
 * frees every node already placed on @diff_set.
 */
int usnic_uiom_get_intervals_diff(unsigned long start, unsigned long last,
					int flags, int flag_mask,
					struct rb_root *root,
					struct list_head *diff_set)
{
	struct usnic_uiom_interval_node *interval, *tmp;
	int err = 0;
	long int pivot = start;
	LIST_HEAD(intersection_set);

	INIT_LIST_HEAD(diff_set);

	/* All intervals in @root overlapping [start, last], sorted by
	 * ascending start. */
	find_intervals_intersection_sorted(root, start, last,
						&intersection_set);

	list_for_each_entry(interval, &intersection_set, link) {
		if (pivot < interval->start) {
			/* Uncovered gap before this interval goes into
			 * the diff. */
			MAKE_NODE_AND_APPEND(tmp, pivot, interval->start - 1,
						1, flags, err, err_out,
						diff_set);
			pivot = interval->start;
		}

		/*
		 * Invariant: Set [start, pivot] is either in diff_set or root,
		 * but not in both.
		 */

		if (pivot > interval->last) {
			continue;
		} else if (pivot <= interval->last &&
				FLAGS_EQUAL(interval->flags, flags,
				flag_mask)) {
			/* Covered with matching flags: advance past it. */
			pivot = interval->last + 1;
		}
		/* Otherwise the flags differ; pivot is left in place so the
		 * gap logic above (or the tail emit below) puts the region
		 * into diff_set. */
	}

	/* Tail of [start, last] beyond the final overlapping interval. */
	if (pivot <= last)
		MAKE_NODE_AND_APPEND(tmp, pivot, last, 1, flags, err, err_out,
					diff_set);

	return 0;

err_out:
	/* Allocation failed: release everything built so far. */
	list_for_each_entry_safe(interval, tmp, diff_set, link) {
		list_del(&interval->link);
		kfree(interval);
	}

	return err;
}
155 
156 void usnic_uiom_put_interval_set(struct list_head *intervals)
157 {
158 	struct usnic_uiom_interval_node *interval, *tmp;
159 	list_for_each_entry_safe(interval, tmp, intervals, link)
160 		kfree(interval);
161 }
162 
/*
 * Insert [start, last] with @flags into @root, merging with existing
 * overlapping intervals: each existing overlapping node is replaced by
 * split nodes — overlapped pieces get ref_cnt + 1 and flags OR-ed with
 * @flags, non-overlapped pieces keep their original ref_cnt/flags —
 * and uncovered gaps get fresh nodes with ref_cnt 1 and @flags.
 *
 * Returns 0 on success or -ENOMEM; all replacement nodes are staged on
 * a local list first, so @root is left unchanged on failure.
 */
int usnic_uiom_insert_interval(struct rb_root *root, unsigned long start,
				unsigned long last, int flags)
{
	struct usnic_uiom_interval_node *interval, *tmp;
	unsigned long istart, ilast;
	int iref_cnt, iflags;
	unsigned long lpivot = start;
	int err = 0;
	LIST_HEAD(to_add);
	LIST_HEAD(intersection_set);

	/* Existing intervals overlapping [start, last], sorted by start. */
	find_intervals_intersection_sorted(root, start, last,
						&intersection_set);

	list_for_each_entry(interval, &intersection_set, link) {
		/*
		 * Invariant - lpivot is the left edge of next interval to be
		 * inserted
		 */
		/* Snapshot the node's fields; it is freed further below. */
		istart = interval->start;
		ilast = interval->last;
		iref_cnt = interval->ref_cnt;
		iflags = interval->flags;

		if (istart < lpivot) {
			/* Existing interval extends left of the requested
			 * range: preserve that piece unchanged. */
			MAKE_NODE_AND_APPEND(tmp, istart, lpivot - 1, iref_cnt,
						iflags, err, err_out, &to_add);
		} else if (istart > lpivot) {
			/* Gap before the existing interval: new coverage
			 * with a single reference. */
			MAKE_NODE_AND_APPEND(tmp, lpivot, istart - 1, 1, flags,
						err, err_out, &to_add);
			lpivot = istart;
		} else {
			lpivot = istart;
		}

		if (ilast > last) {
			/* Existing interval extends right of the requested
			 * range: split at @last; only the overlapped half
			 * gets the bumped ref_cnt and merged flags. */
			MAKE_NODE_AND_APPEND(tmp, lpivot, last, iref_cnt + 1,
						iflags | flags, err, err_out,
						&to_add);
			MAKE_NODE_AND_APPEND(tmp, last + 1, ilast, iref_cnt,
						iflags, err, err_out, &to_add);
		} else {
			/* Overlapped region ends within the requested range:
			 * bump ref_cnt and merge flags for all of it. */
			MAKE_NODE_AND_APPEND(tmp, lpivot, ilast, iref_cnt + 1,
						iflags | flags, err, err_out,
						&to_add);
		}

		lpivot = ilast + 1;
	}

	/* Tail of [start, last] beyond the final overlap. */
	if (lpivot <= last)
		MAKE_NODE_AND_APPEND(tmp, lpivot, last, 1, flags, err, err_out,
					&to_add);

	/* Commit: drop the old overlapping nodes from the tree... */
	list_for_each_entry_safe(interval, tmp, &intersection_set, link) {
		usnic_uiom_interval_tree_remove(interval, root);
		kfree(interval);
	}

	/* ...and insert their staged replacements. */
	list_for_each_entry(interval, &to_add, link)
		usnic_uiom_interval_tree_insert(interval, root);

	return 0;

err_out:
	/* Allocation failed before the tree was touched: free the staged
	 * nodes and leave @root as it was. */
	list_for_each_entry_safe(interval, tmp, &to_add, link)
		kfree(interval);

	return err;
}
233 
234 void usnic_uiom_remove_interval(struct rb_root *root, unsigned long start,
235 				unsigned long last, struct list_head *removed)
236 {
237 	struct usnic_uiom_interval_node *interval;
238 
239 	for (interval = usnic_uiom_interval_tree_iter_first(root, start, last);
240 			interval;
241 			interval = usnic_uiom_interval_tree_iter_next(interval,
242 									start,
243 									last)) {
244 		if (--interval->ref_cnt == 0)
245 			list_add_tail(&interval->link, removed);
246 	}
247 
248 	list_for_each_entry(interval, removed, link)
249 		usnic_uiom_interval_tree_remove(interval, root);
250 }
251 
/*
 * Instantiate the generic interval tree for usnic_uiom_interval_node,
 * generating the usnic_uiom_interval_tree_{insert,remove,iter_first,
 * iter_next} helpers used above.  Keys are unsigned long inclusive
 * [start, last] bounds via the START/LAST accessors.
 */
INTERVAL_TREE_DEFINE(struct usnic_uiom_interval_node, rb,
			unsigned long, __subtree_last,
			START, LAST, , usnic_uiom_interval_tree)
255