1 /*
2  * Copyright (c) 2014, Cisco Systems, Inc. All rights reserved.
3  *
4  * This software is available to you under a choice of one of two
5  * licenses.  You may choose to be licensed under the terms of the GNU
6  * General Public License (GPL) Version 2, available from the file
7  * COPYING in the main directory of this source tree, or the
8  * BSD license below:
9  *
10  *     Redistribution and use in source and binary forms, with or
11  *     without modification, are permitted provided that the following
12  *     conditions are met:
13  *
14  *      - Redistributions of source code must retain the above
15  *        copyright notice, this list of conditions and the following
16  *        disclaimer.
17  *
18  *      - Redistributions in binary form must reproduce the above
19  *        copyright notice, this list of conditions and the following
20  *        disclaimer in the documentation and/or other materials
21  *        provided with the distribution.
22  *
23  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
24  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
25  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
26  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
27  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
28  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
29  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30  * SOFTWARE.
31  *
32  */
33 
34 #include <linux/init.h>
35 #include <linux/list.h>
36 #include <linux/slab.h>
37 #include <linux/list_sort.h>
38 
39 #include <linux/interval_tree_generic.h>
40 #include "usnic_uiom_interval_tree.h"
41 
/* Node bound accessors consumed by INTERVAL_TREE_DEFINE() at the end of the file. */
#define START(node) ((node)->start)
#define LAST(node) ((node)->last)

/*
 * MAKE_NODE - allocate a node covering [start, end] and store it in @node.
 * On allocation failure @err is set to -ENOMEM and control jumps to the
 * label @err_out (which, being a label, cannot be parenthesised).
 * All other arguments are parenthesised so expression arguments are safe.
 */
#define MAKE_NODE(node, start, end, ref_cnt, flags, err, err_out)	\
	do {								\
		(node) = usnic_uiom_interval_node_alloc((start), (end),	\
							(ref_cnt), (flags)); \
		if (!(node)) {						\
			(err) = -ENOMEM;				\
			goto err_out;					\
		}							\
	} while (0)

/* Queue @node at the tail of @list (a struct list_head *) via its ->link. */
#define MARK_FOR_ADD(node, list) (list_add_tail(&(node)->link, (list)))

/* MAKE_NODE() followed by MARK_FOR_ADD(); same error contract as MAKE_NODE(). */
#define MAKE_NODE_AND_APPEND(node, start, end, ref_cnt, flags, err,	\
				err_out, list)				\
	do {								\
		MAKE_NODE(node, start, end, ref_cnt, flags, err,	\
			  err_out);					\
		MARK_FOR_ADD(node, list);				\
	} while (0)

/* True when @flags1 and @flags2 agree on every bit that is set in @mask. */
#define FLAGS_EQUAL(flags1, flags2, mask)				\
	(((flags1) & (mask)) == ((flags2) & (mask)))
68 
69 static struct usnic_uiom_interval_node*
70 usnic_uiom_interval_node_alloc(long int start, long int last, int ref_cnt,
71 				int flags)
72 {
73 	struct usnic_uiom_interval_node *interval = kzalloc(sizeof(*interval),
74 								GFP_ATOMIC);
75 	if (!interval)
76 		return NULL;
77 
78 	interval->start = start;
79 	interval->last = last;
80 	interval->flags = flags;
81 	interval->ref_cnt = ref_cnt;
82 
83 	return interval;
84 }
85 
86 static int interval_cmp(void *priv, const struct list_head *a,
87 			const struct list_head *b)
88 {
89 	struct usnic_uiom_interval_node *node_a, *node_b;
90 
91 	node_a = list_entry(a, struct usnic_uiom_interval_node, link);
92 	node_b = list_entry(b, struct usnic_uiom_interval_node, link);
93 
94 	/* long to int */
95 	if (node_a->start < node_b->start)
96 		return -1;
97 	else if (node_a->start > node_b->start)
98 		return 1;
99 
100 	return 0;
101 }
102 
103 static void
104 find_intervals_intersection_sorted(struct rb_root_cached *root,
105 				   unsigned long start, unsigned long last,
106 				   struct list_head *list)
107 {
108 	struct usnic_uiom_interval_node *node;
109 
110 	INIT_LIST_HEAD(list);
111 
112 	for (node = usnic_uiom_interval_tree_iter_first(root, start, last);
113 		node;
114 		node = usnic_uiom_interval_tree_iter_next(node, start, last))
115 		list_add_tail(&node->link, list);
116 
117 	list_sort(NULL, list, interval_cmp);
118 }
119 
120 int usnic_uiom_get_intervals_diff(unsigned long start, unsigned long last,
121 					int flags, int flag_mask,
122 					struct rb_root_cached *root,
123 					struct list_head *diff_set)
124 {
125 	struct usnic_uiom_interval_node *interval, *tmp;
126 	int err = 0;
127 	long int pivot = start;
128 	LIST_HEAD(intersection_set);
129 
130 	INIT_LIST_HEAD(diff_set);
131 
132 	find_intervals_intersection_sorted(root, start, last,
133 						&intersection_set);
134 
135 	list_for_each_entry(interval, &intersection_set, link) {
136 		if (pivot < interval->start) {
137 			MAKE_NODE_AND_APPEND(tmp, pivot, interval->start - 1,
138 						1, flags, err, err_out,
139 						diff_set);
140 			pivot = interval->start;
141 		}
142 
143 		/*
144 		 * Invariant: Set [start, pivot] is either in diff_set or root,
145 		 * but not in both.
146 		 */
147 
148 		if (pivot > interval->last) {
149 			continue;
150 		} else if (pivot <= interval->last &&
151 				FLAGS_EQUAL(interval->flags, flags,
152 				flag_mask)) {
153 			pivot = interval->last + 1;
154 		}
155 	}
156 
157 	if (pivot <= last)
158 		MAKE_NODE_AND_APPEND(tmp, pivot, last, 1, flags, err, err_out,
159 					diff_set);
160 
161 	return 0;
162 
163 err_out:
164 	list_for_each_entry_safe(interval, tmp, diff_set, link) {
165 		list_del(&interval->link);
166 		kfree(interval);
167 	}
168 
169 	return err;
170 }
171 
172 void usnic_uiom_put_interval_set(struct list_head *intervals)
173 {
174 	struct usnic_uiom_interval_node *interval, *tmp;
175 	list_for_each_entry_safe(interval, tmp, intervals, link)
176 		kfree(interval);
177 }
178 
/*
 * usnic_uiom_insert_interval - add the range [start, last] with @flags to @root
 *
 * Existing intervals that overlap [start, last] are replaced as follows:
 *  - the overlapped part is re-created with ref_cnt + 1 and flags OR'ed
 *    with @flags;
 *  - any part lying before @start or after @last is re-created with the
 *    original ref_cnt and flags.
 * Gaps inside [start, last] that no existing interval covers become new
 * nodes with ref_cnt 1 and @flags.
 *
 * All replacement nodes are allocated first, onto the local to_add list.
 * Only after every allocation has succeeded are the old overlapping nodes
 * removed from @root and freed, and the new ones inserted. On -ENOMEM the
 * pending nodes are freed and @root's contents are left unchanged.
 *
 * NOTE(review): if an overlapping interval ends at ULONG_MAX,
 * "lpivot = ilast + 1" wraps to 0, and the final "lpivot <= last" test would
 * then append a [0, last] node. This is presumably unreachable for real
 * address ranges; confirm.
 *
 * Return: 0 on success, -ENOMEM if a node allocation failed.
 */
int usnic_uiom_insert_interval(struct rb_root_cached *root, unsigned long start,
				unsigned long last, int flags)
{
	struct usnic_uiom_interval_node *interval, *tmp;
	unsigned long istart, ilast;
	int iref_cnt, iflags;
	unsigned long lpivot = start;
	int err = 0;
	LIST_HEAD(to_add);
	LIST_HEAD(intersection_set);

	find_intervals_intersection_sorted(root, start, last,
						&intersection_set);

	list_for_each_entry(interval, &intersection_set, link) {
		/*
		 * Invariant - lpivot is the left edge of next interval to be
		 * inserted
		 */
		istart = interval->start;
		ilast = interval->last;
		iref_cnt = interval->ref_cnt;
		iflags = interval->flags;

		if (istart < lpivot) {
			/* Part of the old interval left of lpivot keeps its
			 * original ref_cnt and flags. */
			MAKE_NODE_AND_APPEND(tmp, istart, lpivot - 1, iref_cnt,
						iflags, err, err_out, &to_add);
		} else if (istart > lpivot) {
			/* Gap before this interval: brand-new range. */
			MAKE_NODE_AND_APPEND(tmp, lpivot, istart - 1, 1, flags,
						err, err_out, &to_add);
			lpivot = istart;
		} else {
			lpivot = istart;
		}

		if (ilast > last) {
			/* Old interval extends past @last: split it into the
			 * overlapped part (ref_cnt + 1, merged flags) and an
			 * untouched right part. */
			MAKE_NODE_AND_APPEND(tmp, lpivot, last, iref_cnt + 1,
						iflags | flags, err, err_out,
						&to_add);
			MAKE_NODE_AND_APPEND(tmp, last + 1, ilast, iref_cnt,
						iflags, err, err_out, &to_add);
		} else {
			/* Overlap runs to the end of the old interval. */
			MAKE_NODE_AND_APPEND(tmp, lpivot, ilast, iref_cnt + 1,
						iflags | flags, err, err_out,
						&to_add);
		}

		lpivot = ilast + 1;
	}

	/* Trailing part of [start, last] past every overlapping interval. */
	if (lpivot <= last)
		MAKE_NODE_AND_APPEND(tmp, lpivot, last, 1, flags, err, err_out,
					&to_add);

	/* Every allocation succeeded: now swap old nodes for new ones. */
	list_for_each_entry_safe(interval, tmp, &intersection_set, link) {
		usnic_uiom_interval_tree_remove(interval, root);
		kfree(interval);
	}

	list_for_each_entry(interval, &to_add, link)
		usnic_uiom_interval_tree_insert(interval, root);

	return 0;

err_out:
	/* Only the not-yet-inserted nodes are freed; @root was not touched. */
	list_for_each_entry_safe(interval, tmp, &to_add, link)
		kfree(interval);

	return err;
}
249 
250 void usnic_uiom_remove_interval(struct rb_root_cached *root,
251 				unsigned long start, unsigned long last,
252 				struct list_head *removed)
253 {
254 	struct usnic_uiom_interval_node *interval;
255 
256 	for (interval = usnic_uiom_interval_tree_iter_first(root, start, last);
257 			interval;
258 			interval = usnic_uiom_interval_tree_iter_next(interval,
259 									start,
260 									last)) {
261 		if (--interval->ref_cnt == 0)
262 			list_add_tail(&interval->link, removed);
263 	}
264 
265 	list_for_each_entry(interval, removed, link)
266 		usnic_uiom_interval_tree_remove(interval, root);
267 }
268 
/*
 * Instantiate the generic interval tree for struct usnic_uiom_interval_node.
 * Nodes are linked through ->rb and keyed on the inclusive range
 * [START(node), LAST(node)], and ->__subtree_last holds the augmented
 * subtree maximum. The ITSTATIC argument is left empty, so the generated
 * usnic_uiom_interval_tree_* functions (insert, remove, iter_first,
 * iter_next, ...) have external linkage.
 */
INTERVAL_TREE_DEFINE(struct usnic_uiom_interval_node, rb,
			unsigned long, __subtree_last,
			START, LAST, , usnic_uiom_interval_tree)
272