1 /*
2  * Copyright (c) 2014, Cisco Systems, Inc. All rights reserved.
3  *
4  * This software is available to you under a choice of one of two
5  * licenses.  You may choose to be licensed under the terms of the GNU
6  * General Public License (GPL) Version 2, available from the file
7  * COPYING in the main directory of this source tree, or the
8  * BSD license below:
9  *
10  *     Redistribution and use in source and binary forms, with or
11  *     without modification, are permitted provided that the following
12  *     conditions are met:
13  *
14  *      - Redistributions of source code must retain the above
15  *        copyright notice, this list of conditions and the following
16  *        disclaimer.
17  *
18  *      - Redistributions in binary form must reproduce the above
19  *        copyright notice, this list of conditions and the following
20  *        disclaimer in the documentation and/or other materials
21  *        provided with the distribution.
22  *
23  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
24  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
25  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
26  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
27  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
28  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
29  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30  * SOFTWARE.
31  *
32  */
33 
34 #include <linux/init.h>
35 #include <linux/list.h>
36 #include <linux/slab.h>
37 #include <linux/list_sort.h>
38 
39 #include <linux/interval_tree_generic.h>
40 #include "usnic_uiom_interval_tree.h"
41 
/*
 * Accessors for the first and last value covered by an interval node.
 * They are passed to INTERVAL_TREE_DEFINE() at the bottom of this file.
 */
#define START(node) ((node)->start)
#define LAST(node) ((node)->last)
44 
/*
 * Allocate an interval node for [start, end] with the given reference count
 * and flags, and store it in @node.
 *
 * If the allocation fails, @err is set to -ENOMEM and control jumps to the
 * @err_out label, so the macro can only be used inside a function that
 * defines that label.
 */
#define MAKE_NODE(node, start, end, ref_cnt, flags, err, err_out)	\
	do {								\
		(node) = usnic_uiom_interval_node_alloc((start), (end),	\
						(ref_cnt), (flags));	\
		if (!(node)) {						\
			(err) = -ENOMEM;				\
			goto err_out;					\
		}							\
	} while (0)
54 
/*
 * Queue @node at the tail of @list through its 'link' member.
 * Arguments are parenthesised so that any pointer expression may be passed.
 */
#define MARK_FOR_ADD(node, list) (list_add_tail(&(node)->link, (list)))
56 
/*
 * Allocate a node for [start, end] via MAKE_NODE() and, once that has
 * succeeded, append it to the tail of @list via MARK_FOR_ADD().
 *
 * An allocation failure leaves through the @err_out label with @err set,
 * which is the behaviour of MAKE_NODE(); nothing is appended in that case.
 */
#define MAKE_NODE_AND_APPEND(node, start, end, ref_cnt, flags, err,	\
				err_out, list)				\
	do {								\
		MAKE_NODE(node, start, end, ref_cnt, flags, err,	\
				err_out);				\
		MARK_FOR_ADD(node, list);				\
	} while (0)
65 
/*
 * True (1) when @flags1 and @flags2 agree on every bit selected by @mask,
 * false (0) otherwise. Bits outside @mask are ignored.
 */
#define FLAGS_EQUAL(flags1, flags2, mask)				\
			((((flags1) ^ (flags2)) & (mask)) == 0)
68 
69 static struct usnic_uiom_interval_node*
70 usnic_uiom_interval_node_alloc(long int start, long int last, int ref_cnt,
71 				int flags)
72 {
73 	struct usnic_uiom_interval_node *interval = kzalloc(sizeof(*interval),
74 								GFP_ATOMIC);
75 	if (!interval)
76 		return NULL;
77 
78 	interval->start = start;
79 	interval->last = last;
80 	interval->flags = flags;
81 	interval->ref_cnt = ref_cnt;
82 
83 	return interval;
84 }
85 
86 static int interval_cmp(void *priv, struct list_head *a, struct list_head *b)
87 {
88 	struct usnic_uiom_interval_node *node_a, *node_b;
89 
90 	node_a = list_entry(a, struct usnic_uiom_interval_node, link);
91 	node_b = list_entry(b, struct usnic_uiom_interval_node, link);
92 
93 	/* long to int */
94 	if (node_a->start < node_b->start)
95 		return -1;
96 	else if (node_a->start > node_b->start)
97 		return 1;
98 
99 	return 0;
100 }
101 
102 static void
103 find_intervals_intersection_sorted(struct rb_root *root, unsigned long start,
104 					unsigned long last,
105 					struct list_head *list)
106 {
107 	struct usnic_uiom_interval_node *node;
108 
109 	INIT_LIST_HEAD(list);
110 
111 	for (node = usnic_uiom_interval_tree_iter_first(root, start, last);
112 		node;
113 		node = usnic_uiom_interval_tree_iter_next(node, start, last))
114 		list_add_tail(&node->link, list);
115 
116 	list_sort(NULL, list, interval_cmp);
117 }
118 
119 int usnic_uiom_get_intervals_diff(unsigned long start, unsigned long last,
120 					int flags, int flag_mask,
121 					struct rb_root *root,
122 					struct list_head *diff_set)
123 {
124 	struct usnic_uiom_interval_node *interval, *tmp;
125 	int err = 0;
126 	long int pivot = start;
127 	LIST_HEAD(intersection_set);
128 
129 	INIT_LIST_HEAD(diff_set);
130 
131 	find_intervals_intersection_sorted(root, start, last,
132 						&intersection_set);
133 
134 	list_for_each_entry(interval, &intersection_set, link) {
135 		if (pivot < interval->start) {
136 			MAKE_NODE_AND_APPEND(tmp, pivot, interval->start - 1,
137 						1, flags, err, err_out,
138 						diff_set);
139 			pivot = interval->start;
140 		}
141 
142 		/*
143 		 * Invariant: Set [start, pivot] is either in diff_set or root,
144 		 * but not in both.
145 		 */
146 
147 		if (pivot > interval->last) {
148 			continue;
149 		} else if (pivot <= interval->last &&
150 				FLAGS_EQUAL(interval->flags, flags,
151 				flag_mask)) {
152 			pivot = interval->last + 1;
153 		}
154 	}
155 
156 	if (pivot <= last)
157 		MAKE_NODE_AND_APPEND(tmp, pivot, last, 1, flags, err, err_out,
158 					diff_set);
159 
160 	return 0;
161 
162 err_out:
163 	list_for_each_entry_safe(interval, tmp, diff_set, link) {
164 		list_del(&interval->link);
165 		kfree(interval);
166 	}
167 
168 	return err;
169 }
170 
171 void usnic_uiom_put_interval_set(struct list_head *intervals)
172 {
173 	struct usnic_uiom_interval_node *interval, *tmp;
174 	list_for_each_entry_safe(interval, tmp, intervals, link)
175 		kfree(interval);
176 }
177 
178 int usnic_uiom_insert_interval(struct rb_root *root, unsigned long start,
179 				unsigned long last, int flags)
180 {
181 	struct usnic_uiom_interval_node *interval, *tmp;
182 	unsigned long istart, ilast;
183 	int iref_cnt, iflags;
184 	unsigned long lpivot = start;
185 	int err = 0;
186 	LIST_HEAD(to_add);
187 	LIST_HEAD(intersection_set);
188 
189 	find_intervals_intersection_sorted(root, start, last,
190 						&intersection_set);
191 
192 	list_for_each_entry(interval, &intersection_set, link) {
193 		/*
194 		 * Invariant - lpivot is the left edge of next interval to be
195 		 * inserted
196 		 */
197 		istart = interval->start;
198 		ilast = interval->last;
199 		iref_cnt = interval->ref_cnt;
200 		iflags = interval->flags;
201 
202 		if (istart < lpivot) {
203 			MAKE_NODE_AND_APPEND(tmp, istart, lpivot - 1, iref_cnt,
204 						iflags, err, err_out, &to_add);
205 		} else if (istart > lpivot) {
206 			MAKE_NODE_AND_APPEND(tmp, lpivot, istart - 1, 1, flags,
207 						err, err_out, &to_add);
208 			lpivot = istart;
209 		} else {
210 			lpivot = istart;
211 		}
212 
213 		if (ilast > last) {
214 			MAKE_NODE_AND_APPEND(tmp, lpivot, last, iref_cnt + 1,
215 						iflags | flags, err, err_out,
216 						&to_add);
217 			MAKE_NODE_AND_APPEND(tmp, last + 1, ilast, iref_cnt,
218 						iflags, err, err_out, &to_add);
219 		} else {
220 			MAKE_NODE_AND_APPEND(tmp, lpivot, ilast, iref_cnt + 1,
221 						iflags | flags, err, err_out,
222 						&to_add);
223 		}
224 
225 		lpivot = ilast + 1;
226 	}
227 
228 	if (lpivot <= last)
229 		MAKE_NODE_AND_APPEND(tmp, lpivot, last, 1, flags, err, err_out,
230 					&to_add);
231 
232 	list_for_each_entry_safe(interval, tmp, &intersection_set, link) {
233 		usnic_uiom_interval_tree_remove(interval, root);
234 		kfree(interval);
235 	}
236 
237 	list_for_each_entry(interval, &to_add, link)
238 		usnic_uiom_interval_tree_insert(interval, root);
239 
240 	return 0;
241 
242 err_out:
243 	list_for_each_entry_safe(interval, tmp, &to_add, link)
244 		kfree(interval);
245 
246 	return err;
247 }
248 
249 void usnic_uiom_remove_interval(struct rb_root *root, unsigned long start,
250 				unsigned long last, struct list_head *removed)
251 {
252 	struct usnic_uiom_interval_node *interval;
253 
254 	for (interval = usnic_uiom_interval_tree_iter_first(root, start, last);
255 			interval;
256 			interval = usnic_uiom_interval_tree_iter_next(interval,
257 									start,
258 									last)) {
259 		if (--interval->ref_cnt == 0)
260 			list_add_tail(&interval->link, removed);
261 	}
262 
263 	list_for_each_entry(interval, removed, link)
264 		usnic_uiom_interval_tree_remove(interval, root);
265 }
266 
/*
 * Instantiate the generic interval tree (linux/interval_tree_generic.h) for
 * struct usnic_uiom_interval_node: nodes are linked through their 'rb'
 * member, keyed by unsigned long bounds read with START()/LAST(), and use
 * '__subtree_last' for the tree's bookkeeping. The generated functions carry
 * the usnic_uiom_interval_tree prefix (the _insert, _remove, _iter_first and
 * _iter_next helpers called in this file). The empty argument is the
 * storage-class slot, left blank so the functions are not static.
 */
INTERVAL_TREE_DEFINE(struct usnic_uiom_interval_node, rb,
			unsigned long, __subtree_last,
			START, LAST, , usnic_uiom_interval_tree)
270