1 #include <linux/init.h>
2 #include <linux/list.h>
3 #include <linux/slab.h>
4 #include <linux/list_sort.h>
5 
6 #include <linux/interval_tree_generic.h>
7 #include "usnic_uiom_interval_tree.h"
8 
/* Endpoint accessors used by the INTERVAL_TREE_DEFINE() instantiation in this file. */
#define START(node) ((node)->start)
#define LAST(node) ((node)->last)

/*
 * MAKE_NODE - allocate an interval node [start, end] into @node.
 * On allocation failure sets @err to -ENOMEM and jumps to label @err_out,
 * so it may only be used inside a function that defines that label.
 */
#define MAKE_NODE(node, start, end, ref_cnt, flags, err, err_out)	\
	do {								\
		(node) = usnic_uiom_interval_node_alloc(start, end,	\
							ref_cnt, flags);\
		if (!(node)) {						\
			(err) = -ENOMEM;				\
			goto err_out;					\
		}							\
	} while (0)

/* Append @node to the tail of @list through its ->link member. */
#define MARK_FOR_ADD(node, list) (list_add_tail(&(node)->link, (list)))

/* MAKE_NODE() followed by MARK_FOR_ADD(); same error-path contract. */
#define MAKE_NODE_AND_APPEND(node, start, end, ref_cnt, flags, err,	\
				err_out, list)				\
	do {								\
		MAKE_NODE(node, start, end, ref_cnt, flags, err,	\
			  err_out);					\
		MARK_FOR_ADD(node, list);				\
	} while (0)

/* True when @flags1 and @flags2 agree on every bit that is set in @mask. */
#define FLAGS_EQUAL(flags1, flags2, mask)				\
	(((flags1) & (mask)) == ((flags2) & (mask)))
35 
36 static struct usnic_uiom_interval_node*
37 usnic_uiom_interval_node_alloc(long int start, long int last, int ref_cnt,
38 				int flags)
39 {
40 	struct usnic_uiom_interval_node *interval = kzalloc(sizeof(*interval),
41 								GFP_ATOMIC);
42 	if (!interval)
43 		return NULL;
44 
45 	interval->start = start;
46 	interval->last = last;
47 	interval->flags = flags;
48 	interval->ref_cnt = ref_cnt;
49 
50 	return interval;
51 }
52 
53 static int interval_cmp(void *priv, struct list_head *a, struct list_head *b)
54 {
55 	struct usnic_uiom_interval_node *node_a, *node_b;
56 
57 	node_a = list_entry(a, struct usnic_uiom_interval_node, link);
58 	node_b = list_entry(b, struct usnic_uiom_interval_node, link);
59 
60 	/* long to int */
61 	if (node_a->start < node_b->start)
62 		return -1;
63 	else if (node_a->start > node_b->start)
64 		return 1;
65 
66 	return 0;
67 }
68 
69 static void
70 find_intervals_intersection_sorted(struct rb_root *root, unsigned long start,
71 					unsigned long last,
72 					struct list_head *list)
73 {
74 	struct usnic_uiom_interval_node *node;
75 
76 	INIT_LIST_HEAD(list);
77 
78 	for (node = usnic_uiom_interval_tree_iter_first(root, start, last);
79 		node;
80 		node = usnic_uiom_interval_tree_iter_next(node, start, last))
81 		list_add_tail(&node->link, list);
82 
83 	list_sort(NULL, list, interval_cmp);
84 }
85 
86 int usnic_uiom_get_intervals_diff(unsigned long start, unsigned long last,
87 					int flags, int flag_mask,
88 					struct rb_root *root,
89 					struct list_head *diff_set)
90 {
91 	struct usnic_uiom_interval_node *interval, *tmp;
92 	int err = 0;
93 	long int pivot = start;
94 	LIST_HEAD(intersection_set);
95 
96 	INIT_LIST_HEAD(diff_set);
97 
98 	find_intervals_intersection_sorted(root, start, last,
99 						&intersection_set);
100 
101 	list_for_each_entry(interval, &intersection_set, link) {
102 		if (pivot < interval->start) {
103 			MAKE_NODE_AND_APPEND(tmp, pivot, interval->start - 1,
104 						1, flags, err, err_out,
105 						diff_set);
106 			pivot = interval->start;
107 		}
108 
109 		/*
110 		 * Invariant: Set [start, pivot] is either in diff_set or root,
111 		 * but not in both.
112 		 */
113 
114 		if (pivot > interval->last) {
115 			continue;
116 		} else if (pivot <= interval->last &&
117 				FLAGS_EQUAL(interval->flags, flags,
118 				flag_mask)) {
119 			pivot = interval->last + 1;
120 		}
121 	}
122 
123 	if (pivot <= last)
124 		MAKE_NODE_AND_APPEND(tmp, pivot, last, 1, flags, err, err_out,
125 					diff_set);
126 
127 	return 0;
128 
129 err_out:
130 	list_for_each_entry_safe(interval, tmp, diff_set, link) {
131 		list_del(&interval->link);
132 		kfree(interval);
133 	}
134 
135 	return err;
136 }
137 
138 void usnic_uiom_put_interval_set(struct list_head *intervals)
139 {
140 	struct usnic_uiom_interval_node *interval, *tmp;
141 	list_for_each_entry_safe(interval, tmp, intervals, link)
142 		kfree(interval);
143 }
144 
145 int usnic_uiom_insert_interval(struct rb_root *root, unsigned long start,
146 				unsigned long last, int flags)
147 {
148 	struct usnic_uiom_interval_node *interval, *tmp;
149 	unsigned long istart, ilast;
150 	int iref_cnt, iflags;
151 	unsigned long lpivot = start;
152 	int err = 0;
153 	LIST_HEAD(to_add);
154 	LIST_HEAD(intersection_set);
155 
156 	find_intervals_intersection_sorted(root, start, last,
157 						&intersection_set);
158 
159 	list_for_each_entry(interval, &intersection_set, link) {
160 		/*
161 		 * Invariant - lpivot is the left edge of next interval to be
162 		 * inserted
163 		 */
164 		istart = interval->start;
165 		ilast = interval->last;
166 		iref_cnt = interval->ref_cnt;
167 		iflags = interval->flags;
168 
169 		if (istart < lpivot) {
170 			MAKE_NODE_AND_APPEND(tmp, istart, lpivot - 1, iref_cnt,
171 						iflags, err, err_out, &to_add);
172 		} else if (istart > lpivot) {
173 			MAKE_NODE_AND_APPEND(tmp, lpivot, istart - 1, 1, flags,
174 						err, err_out, &to_add);
175 			lpivot = istart;
176 		} else {
177 			lpivot = istart;
178 		}
179 
180 		if (ilast > last) {
181 			MAKE_NODE_AND_APPEND(tmp, lpivot, last, iref_cnt + 1,
182 						iflags | flags, err, err_out,
183 						&to_add);
184 			MAKE_NODE_AND_APPEND(tmp, last + 1, ilast, iref_cnt,
185 						iflags, err, err_out, &to_add);
186 		} else {
187 			MAKE_NODE_AND_APPEND(tmp, lpivot, ilast, iref_cnt + 1,
188 						iflags | flags, err, err_out,
189 						&to_add);
190 		}
191 
192 		lpivot = ilast + 1;
193 	}
194 
195 	if (lpivot <= last)
196 		MAKE_NODE_AND_APPEND(tmp, lpivot, last, 1, flags, err, err_out,
197 					&to_add);
198 
199 	list_for_each_entry_safe(interval, tmp, &intersection_set, link) {
200 		usnic_uiom_interval_tree_remove(interval, root);
201 		kfree(interval);
202 	}
203 
204 	list_for_each_entry(interval, &to_add, link)
205 		usnic_uiom_interval_tree_insert(interval, root);
206 
207 	return 0;
208 
209 err_out:
210 	list_for_each_entry_safe(interval, tmp, &to_add, link)
211 		kfree(interval);
212 
213 	return err;
214 }
215 
216 void usnic_uiom_remove_interval(struct rb_root *root, unsigned long start,
217 				unsigned long last, struct list_head *removed)
218 {
219 	struct usnic_uiom_interval_node *interval;
220 
221 	for (interval = usnic_uiom_interval_tree_iter_first(root, start, last);
222 			interval;
223 			interval = usnic_uiom_interval_tree_iter_next(interval,
224 									start,
225 									last)) {
226 		if (--interval->ref_cnt == 0)
227 			list_add_tail(&interval->link, removed);
228 	}
229 
230 	list_for_each_entry(interval, removed, link)
231 		usnic_uiom_interval_tree_remove(interval, root);
232 }
233 
234 INTERVAL_TREE_DEFINE(struct usnic_uiom_interval_node, rb,
235 			unsigned long, __subtree_last,
236 			START, LAST, , usnic_uiom_interval_tree)
237