xref: /openbmc/linux/lib/test_objagg.c (revision 4f2c0a4acffbec01079c28f839422e64ddeff004)
1  // SPDX-License-Identifier: BSD-3-Clause OR GPL-2.0
2  /* Copyright (c) 2018 Mellanox Technologies. All rights reserved */
3  
4  #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
5  
6  #include <linux/kernel.h>
7  #include <linux/module.h>
8  #include <linux/slab.h>
9  #include <linux/random.h>
10  #include <linux/objagg.h>
11  
/* Object key handed to objagg: a key is identified solely by its id. */
struct tokey {
	unsigned int id;
};

/* Number of key ids the test bookkeeping can track (ids 0..NUM_KEYS-1). */
#define NUM_KEYS 32
17  
key_id_index(unsigned int key_id)18  static int key_id_index(unsigned int key_id)
19  {
20  	if (key_id >= NUM_KEYS) {
21  		WARN_ON(1);
22  		return 0;
23  	}
24  	return key_id;
25  }
26  
/* Size of the random payload carried by each root object. */
#define BUF_LEN 128

/*
 * Test-side bookkeeping. It is passed as the priv pointer to objagg_create(),
 * so the objagg callbacks below receive it as @priv.
 */
struct world {
	/* Live roots, maintained by root_create()/root_destroy(). */
	unsigned int root_count;
	/* Live deltas, maintained by delta_create()/delta_destroy(). */
	unsigned int delta_count;
	/* Payload copied into the next root created by root_create(). */
	char next_root_buf[BUF_LEN];
	/* Object returned for each key id while the test holds references. */
	struct objagg_obj *objagg_objs[NUM_KEYS];
	/* Number of references the test holds on each key id. */
	unsigned int key_refs[NUM_KEYS];
};

/*
 * Root object private data. The key must remain the first member:
 * obj_to_key_id() reads the root private pointer as a struct tokey.
 */
struct root {
	struct tokey key;
	char buf[BUF_LEN];
};

/* Delta object private data: offset of the key id from its root's id. */
struct delta {
	unsigned int key_id_diff;
};
45  
world_obj_get(struct world * world,struct objagg * objagg,unsigned int key_id)46  static struct objagg_obj *world_obj_get(struct world *world,
47  					struct objagg *objagg,
48  					unsigned int key_id)
49  {
50  	struct objagg_obj *objagg_obj;
51  	struct tokey key;
52  	int err;
53  
54  	key.id = key_id;
55  	objagg_obj = objagg_obj_get(objagg, &key);
56  	if (IS_ERR(objagg_obj)) {
57  		pr_err("Key %u: Failed to get object.\n", key_id);
58  		return objagg_obj;
59  	}
60  	if (!world->key_refs[key_id_index(key_id)]) {
61  		world->objagg_objs[key_id_index(key_id)] = objagg_obj;
62  	} else if (world->objagg_objs[key_id_index(key_id)] != objagg_obj) {
63  		pr_err("Key %u: God another object for the same key.\n",
64  		       key_id);
65  		err = -EINVAL;
66  		goto err_key_id_check;
67  	}
68  	world->key_refs[key_id_index(key_id)]++;
69  	return objagg_obj;
70  
71  err_key_id_check:
72  	objagg_obj_put(objagg, objagg_obj);
73  	return ERR_PTR(err);
74  }
75  
world_obj_put(struct world * world,struct objagg * objagg,unsigned int key_id)76  static void world_obj_put(struct world *world, struct objagg *objagg,
77  			  unsigned int key_id)
78  {
79  	struct objagg_obj *objagg_obj;
80  
81  	if (!world->key_refs[key_id_index(key_id)])
82  		return;
83  	objagg_obj = world->objagg_objs[key_id_index(key_id)];
84  	objagg_obj_put(objagg, objagg_obj);
85  	world->key_refs[key_id_index(key_id)]--;
86  }
87  
88  #define MAX_KEY_ID_DIFF 5
89  
delta_check(void * priv,const void * parent_obj,const void * obj)90  static bool delta_check(void *priv, const void *parent_obj, const void *obj)
91  {
92  	const struct tokey *parent_key = parent_obj;
93  	const struct tokey *key = obj;
94  	int diff = key->id - parent_key->id;
95  
96  	return diff >= 0 && diff <= MAX_KEY_ID_DIFF;
97  }
98  
delta_create(void * priv,void * parent_obj,void * obj)99  static void *delta_create(void *priv, void *parent_obj, void *obj)
100  {
101  	struct tokey *parent_key = parent_obj;
102  	struct world *world = priv;
103  	struct tokey *key = obj;
104  	int diff = key->id - parent_key->id;
105  	struct delta *delta;
106  
107  	if (!delta_check(priv, parent_obj, obj))
108  		return ERR_PTR(-EINVAL);
109  
110  	delta = kzalloc(sizeof(*delta), GFP_KERNEL);
111  	if (!delta)
112  		return ERR_PTR(-ENOMEM);
113  	delta->key_id_diff = diff;
114  	world->delta_count++;
115  	return delta;
116  }
117  
delta_destroy(void * priv,void * delta_priv)118  static void delta_destroy(void *priv, void *delta_priv)
119  {
120  	struct delta *delta = delta_priv;
121  	struct world *world = priv;
122  
123  	world->delta_count--;
124  	kfree(delta);
125  }
126  
root_create(void * priv,void * obj,unsigned int id)127  static void *root_create(void *priv, void *obj, unsigned int id)
128  {
129  	struct world *world = priv;
130  	struct tokey *key = obj;
131  	struct root *root;
132  
133  	root = kzalloc(sizeof(*root), GFP_KERNEL);
134  	if (!root)
135  		return ERR_PTR(-ENOMEM);
136  	memcpy(&root->key, key, sizeof(root->key));
137  	memcpy(root->buf, world->next_root_buf, sizeof(root->buf));
138  	world->root_count++;
139  	return root;
140  }
141  
root_destroy(void * priv,void * root_priv)142  static void root_destroy(void *priv, void *root_priv)
143  {
144  	struct root *root = root_priv;
145  	struct world *world = priv;
146  
147  	world->root_count--;
148  	kfree(root);
149  }
150  
/*
 * test_nodelta_obj_get - get the object for @key_id from the no-delta
 * aggregator and check how its root was handled.
 * @world: test bookkeeping
 * @objagg: aggregator under test (created with nodelta_ops)
 * @key_id: key to get
 * @should_create_root: true if this get must create a new root, false if
 *	it must reuse an existing one
 *
 * Before an expected root creation, next_root_buf is filled with random
 * bytes. root_create() copies them, so the new root's payload can be
 * verified afterwards.
 *
 * Return: 0 on success, with the reference kept. Negative errno on failure,
 * after the reference taken here has been released again.
 */
static int test_nodelta_obj_get(struct world *world, struct objagg *objagg,
				unsigned int key_id, bool should_create_root)
{
	unsigned int orig_root_count = world->root_count;
	struct objagg_obj *objagg_obj;
	const struct root *root;
	int err;

	if (should_create_root)
		get_random_bytes(world->next_root_buf,
				 sizeof(world->next_root_buf));

	objagg_obj = world_obj_get(world, objagg, key_id);
	if (IS_ERR(objagg_obj)) {
		pr_err("Key %u: Failed to get object.\n", key_id);
		return PTR_ERR(objagg_obj);
	}
	if (should_create_root) {
		if (world->root_count != orig_root_count + 1) {
			pr_err("Key %u: Root was not created\n", key_id);
			err = -EINVAL;
			goto err_check_root_count;
		}
	} else {
		if (world->root_count != orig_root_count) {
			pr_err("Key %u: Root was incorrectly created\n",
			       key_id);
			err = -EINVAL;
			goto err_check_root_count;
		}
	}
	root = objagg_obj_root_priv(objagg_obj);
	if (root->key.id != key_id) {
		pr_err("Key %u: Root has unexpected key id\n", key_id);
		err = -EINVAL;
		goto err_check_key_id;
	}
	if (should_create_root &&
	    memcmp(world->next_root_buf, root->buf, sizeof(root->buf))) {
		pr_err("Key %u: Buffer does not match the expected content\n",
		       key_id);
		err = -EINVAL;
		goto err_check_buf;
	}
	return 0;

err_check_buf:
err_check_key_id:
err_check_root_count:
	/*
	 * The reference was taken through world_obj_get(), so release it
	 * through world_obj_put() too. A bare objagg_obj_put() would leave
	 * world->key_refs counting a reference that no longer exists.
	 */
	world_obj_put(world, objagg, key_id);
	return err;
}
203  
/*
 * test_nodelta_obj_put - drop one reference on @key_id and check the root
 * count.
 * @should_destroy_root: true if this was the last reference and the root
 *	must be gone afterwards; false if the root must survive
 *
 * Return: 0 if the root count changed as expected, -EINVAL otherwise.
 */
static int test_nodelta_obj_put(struct world *world, struct objagg *objagg,
				unsigned int key_id, bool should_destroy_root)
{
	unsigned int orig_root_count = world->root_count;

	world_obj_put(world, objagg, key_id);

	if (should_destroy_root && world->root_count != orig_root_count - 1) {
		pr_err("Key %u: Root was not destroyed\n", key_id);
		return -EINVAL;
	}
	if (!should_destroy_root && world->root_count != orig_root_count) {
		pr_err("Key %u: Root was incorrectly destroyed\n",
		       key_id);
		return -EINVAL;
	}
	return 0;
}
225  
check_stats_zero(struct objagg * objagg)226  static int check_stats_zero(struct objagg *objagg)
227  {
228  	const struct objagg_stats *stats;
229  	int err = 0;
230  
231  	stats = objagg_stats_get(objagg);
232  	if (IS_ERR(stats))
233  		return PTR_ERR(stats);
234  
235  	if (stats->stats_info_count != 0) {
236  		pr_err("Stats: Object count is not zero while it should be\n");
237  		err = -EINVAL;
238  	}
239  
240  	objagg_stats_put(stats);
241  	return err;
242  }
243  
check_stats_nodelta(struct objagg * objagg)244  static int check_stats_nodelta(struct objagg *objagg)
245  {
246  	const struct objagg_stats *stats;
247  	int i;
248  	int err;
249  
250  	stats = objagg_stats_get(objagg);
251  	if (IS_ERR(stats))
252  		return PTR_ERR(stats);
253  
254  	if (stats->stats_info_count != NUM_KEYS) {
255  		pr_err("Stats: Unexpected object count (%u expected, %u returned)\n",
256  		       NUM_KEYS, stats->stats_info_count);
257  		err = -EINVAL;
258  		goto stats_put;
259  	}
260  
261  	for (i = 0; i < stats->stats_info_count; i++) {
262  		if (stats->stats_info[i].stats.user_count != 2) {
263  			pr_err("Stats: incorrect user count\n");
264  			err = -EINVAL;
265  			goto stats_put;
266  		}
267  		if (stats->stats_info[i].stats.delta_user_count != 2) {
268  			pr_err("Stats: incorrect delta user count\n");
269  			err = -EINVAL;
270  			goto stats_put;
271  		}
272  	}
273  	err = 0;
274  
275  stats_put:
276  	objagg_stats_put(stats);
277  	return err;
278  }
279  
delta_check_dummy(void * priv,const void * parent_obj,const void * obj)280  static bool delta_check_dummy(void *priv, const void *parent_obj,
281  			      const void *obj)
282  {
283  	return false;
284  }
285  
delta_create_dummy(void * priv,void * parent_obj,void * obj)286  static void *delta_create_dummy(void *priv, void *parent_obj, void *obj)
287  {
288  	return ERR_PTR(-EOPNOTSUPP);
289  }
290  
delta_destroy_dummy(void * priv,void * delta_priv)291  static void delta_destroy_dummy(void *priv, void *delta_priv)
292  {
293  }
294  
295  static const struct objagg_ops nodelta_ops = {
296  	.obj_size = sizeof(struct tokey),
297  	.delta_check = delta_check_dummy,
298  	.delta_create = delta_create_dummy,
299  	.delta_destroy = delta_destroy_dummy,
300  	.root_create = root_create,
301  	.root_destroy = root_destroy,
302  };
303  
/*
 * test_nodelta - exercise an aggregator whose ops never allow deltas.
 *
 * Every key gets two references. The first get round must create one root
 * per key, and the second must reuse them. After a stats check, references
 * are dropped in reverse key order: the first put round must keep the
 * roots and the second must destroy them. Stats must be empty before the
 * first get and after the last put.
 *
 * Return: 0 on success, negative errno on failure.
 */
static int test_nodelta(void)
{
	struct world world = {};
	struct objagg *objagg;
	int i;
	int err;

	objagg = objagg_create(&nodelta_ops, NULL, &world);
	if (IS_ERR(objagg))
		return PTR_ERR(objagg);

	err = check_stats_zero(objagg);
	if (err)
		goto err_stats_first_zero;

	/* First round of gets, the root objects should be created */
	for (i = 0; i < NUM_KEYS; i++) {
		err = test_nodelta_obj_get(&world, objagg, i, true);
		if (err)
			goto err_obj_first_get;
	}

	/* Do the second round of gets, all roots are already created,
	 * make sure that no new root is created
	 */
	for (i = 0; i < NUM_KEYS; i++) {
		err = test_nodelta_obj_get(&world, objagg, i, false);
		if (err)
			goto err_obj_second_get;
	}

	err = check_stats_nodelta(objagg);
	if (err)
		goto err_stats_nodelta;

	/* Drop the second references; roots must survive. */
	for (i = NUM_KEYS - 1; i >= 0; i--) {
		err = test_nodelta_obj_put(&world, objagg, i, false);
		if (err)
			goto err_obj_first_put;
	}
	/* Drop the last references; roots must be destroyed. */
	for (i = NUM_KEYS - 1; i >= 0; i--) {
		err = test_nodelta_obj_put(&world, objagg, i, true);
		if (err)
			goto err_obj_second_put;
	}

	err = check_stats_zero(objagg);
	if (err)
		goto err_stats_second_zero;

	objagg_destroy(objagg);
	return 0;

err_stats_nodelta:
err_obj_first_put:
err_obj_second_get:
	/*
	 * Here keys below i hold one more reference than the keys at or above
	 * i. Drop that extra reference first, then fall through with
	 * i = NUM_KEYS to drop one more reference on every key.
	 */
	for (i--; i >= 0; i--)
		world_obj_put(&world, objagg, i);

	i = NUM_KEYS;
err_obj_first_get:
err_obj_second_put:
	/* Here keys below i hold exactly one reference each; drop it. */
	for (i--; i >= 0; i--)
		world_obj_put(&world, objagg, i);
err_stats_first_zero:
err_stats_second_zero:
	objagg_destroy(objagg);
	return err;
}
373  
374  static const struct objagg_ops delta_ops = {
375  	.obj_size = sizeof(struct tokey),
376  	.delta_check = delta_check,
377  	.delta_create = delta_create,
378  	.delta_destroy = delta_destroy,
379  	.root_create = root_create,
380  	.root_destroy = root_destroy,
381  };
382  
/* Operation a scripted step performs on its key. */
enum action {
	ACTION_GET,
	ACTION_PUT,
};

/* Expected change of world->delta_count caused by a step. */
enum expect_delta {
	EXPECT_DELTA_SAME,
	EXPECT_DELTA_INC,
	EXPECT_DELTA_DEC,
};

/* Expected change of world->root_count caused by a step. */
enum expect_root {
	EXPECT_ROOT_SAME,
	EXPECT_ROOT_INC,
	EXPECT_ROOT_DEC,
};

/* Expected stats of a single object. */
struct expect_stats_info {
	struct objagg_obj_stats stats;
	bool is_root;
	unsigned int key_id;
};

/* Expected stats snapshot: the first info_count entries of info[]. */
struct expect_stats {
	unsigned int info_count;
	struct expect_stats_info info[NUM_KEYS];
};

/*
 * One scripted step of the delta test. action_items[] initializes these
 * positionally, so the member order here is significant.
 */
struct action_item {
	unsigned int key_id;
	enum action action;
	enum expect_delta expect_delta;
	enum expect_root expect_root;
	struct expect_stats expect_stats;
};

/* Build a struct expect_stats initializer with @count entries. */
#define EXPECT_STATS(count, ...)		\
{						\
	.info_count = count,			\
	.info = { __VA_ARGS__ }			\
}

/*
 * Expected stats of a root object. The initializer is positional: it follows
 * the member order of struct expect_stats_info and assumes struct
 * objagg_obj_stats begins with user_count, delta_user_count
 * (NOTE(review): confirm against include/linux/objagg.h).
 */
#define ROOT(key_id, user_count, delta_user_count)	\
	{{user_count, delta_user_count}, true, key_id}

/* Expected stats of a delta object: delta user count equals user count. */
#define DELTA(key_id, user_count)			\
	{{user_count, user_count}, false, key_id}
430  
/*
 * Scripted scenario for test_delta(). Each step performs its action on
 * key_id. It then checks the expected change of the root and delta counts,
 * and the full stats snapshot. EXPECT_STATS() entries are compared by
 * position with the objagg_stats_get() output, falling back to neighbouring
 * entries that have identical counters (see __check_expect_stats()).
 *
 * The comment after each step shows the resulting state. "r:" lists the
 * references held on root objects. "d:" lists the references held on delta
 * objects as <key>^<root key>. A key can only become a delta of a root
 * whose id is at most MAX_KEY_ID_DIFF lower.
 */
static const struct action_item action_items[] = {
	{
		1, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
		EXPECT_STATS(1, ROOT(1, 1, 1)),
	},	/* r: 1			d: */
	{
		7, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
		EXPECT_STATS(2, ROOT(1, 1, 1), ROOT(7, 1, 1)),
	},	/* r: 1, 7		d: */
	{
		3, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
		EXPECT_STATS(3, ROOT(1, 1, 2), ROOT(7, 1, 1),
				DELTA(3, 1)),
	},	/* r: 1, 7		d: 3^1 */
	{
		5, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
		EXPECT_STATS(4, ROOT(1, 1, 3), ROOT(7, 1, 1),
				DELTA(3, 1), DELTA(5, 1)),
	},	/* r: 1, 7		d: 3^1, 5^1 */
	{
		3, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(4, ROOT(1, 1, 4), ROOT(7, 1, 1),
				DELTA(3, 2), DELTA(5, 1)),
	},	/* r: 1, 7		d: 3^1, 3^1, 5^1 */
	{
		1, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(4, ROOT(1, 2, 5), ROOT(7, 1, 1),
				DELTA(3, 2), DELTA(5, 1)),
	},	/* r: 1, 1, 7		d: 3^1, 3^1, 5^1 */
	{
		30, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
		EXPECT_STATS(5, ROOT(1, 2, 5), ROOT(7, 1, 1), ROOT(30, 1, 1),
				DELTA(3, 2), DELTA(5, 1)),
	},	/* r: 1, 1, 7, 30	d: 3^1, 3^1, 5^1 */
	{
		8, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
		EXPECT_STATS(6, ROOT(1, 2, 5), ROOT(7, 1, 2), ROOT(30, 1, 1),
				DELTA(3, 2), DELTA(5, 1), DELTA(8, 1)),
	},	/* r: 1, 1, 7, 30	d: 3^1, 3^1, 5^1, 8^7 */
	{
		8, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(6, ROOT(1, 2, 5), ROOT(7, 1, 3), ROOT(30, 1, 1),
				DELTA(3, 2), DELTA(8, 2), DELTA(5, 1)),
	},	/* r: 1, 1, 7, 30	d: 3^1, 3^1, 5^1, 8^7, 8^7 */
	{
		3, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(6, ROOT(1, 2, 4), ROOT(7, 1, 3), ROOT(30, 1, 1),
				DELTA(8, 2), DELTA(3, 1), DELTA(5, 1)),
	},	/* r: 1, 1, 7, 30	d: 3^1, 5^1, 8^7, 8^7 */
	{
		3, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(1, 2, 3), ROOT(7, 1, 3), ROOT(30, 1, 1),
				DELTA(8, 2), DELTA(5, 1)),
	},	/* r: 1, 1, 7, 30	d: 5^1, 8^7, 8^7 */
	{
		1, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(1, 1, 2), ROOT(30, 1, 1),
				DELTA(8, 2), DELTA(5, 1)),
	},	/* r: 1, 7, 30		d: 5^1, 8^7, 8^7 */
	{
		/* Root 1 has no users left but survives: 5^1 still uses it. */
		1, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(30, 1, 1), ROOT(1, 0, 1),
				DELTA(8, 2), DELTA(5, 1)),
	},	/* r: 7, 30		d: 5^1, 8^7, 8^7 */
	{
		5, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_DEC,
		EXPECT_STATS(3, ROOT(7, 1, 3), ROOT(30, 1, 1),
				DELTA(8, 2)),
	},	/* r: 7, 30		d: 8^7, 8^7 */
	{
		5, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
		EXPECT_STATS(4, ROOT(7, 1, 3), ROOT(30, 1, 1), ROOT(5, 1, 1),
				DELTA(8, 2)),
	},	/* r: 7, 30, 5		d: 8^7, 8^7 */
	{
		6, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(5, 1, 2), ROOT(30, 1, 1),
				DELTA(8, 2), DELTA(6, 1)),
	},	/* r: 7, 30, 5		d: 8^7, 8^7, 6^5 */
	{
		8, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 4), ROOT(5, 1, 2), ROOT(30, 1, 1),
				DELTA(8, 3), DELTA(6, 1)),
	},	/* r: 7, 30, 5		d: 8^7, 8^7, 8^7, 6^5 */
	{
		8, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(5, 1, 2), ROOT(30, 1, 1),
				DELTA(8, 2), DELTA(6, 1)),
	},	/* r: 7, 30, 5		d: 8^7, 8^7, 6^5 */
	{
		8, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 2), ROOT(5, 1, 2), ROOT(30, 1, 1),
				DELTA(8, 1), DELTA(6, 1)),
	},	/* r: 7, 30, 5		d: 8^7, 6^5 */
	{
		8, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_SAME,
		EXPECT_STATS(4, ROOT(5, 1, 2), ROOT(7, 1, 1), ROOT(30, 1, 1),
				DELTA(6, 1)),
	},	/* r: 7, 30, 5		d: 6^5 */
	{
		/* Key 8 now attaches to root 5 instead of root 7. */
		8, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(5, 1, 3), ROOT(7, 1, 1), ROOT(30, 1, 1),
				DELTA(6, 1), DELTA(8, 1)),
	},	/* r: 7, 30, 5		d: 6^5, 8^5 */
	{
		7, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_DEC,
		EXPECT_STATS(4, ROOT(5, 1, 3), ROOT(30, 1, 1),
				DELTA(6, 1), DELTA(8, 1)),
	},	/* r: 30, 5		d: 6^5, 8^5 */
	{
		30, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_DEC,
		EXPECT_STATS(3, ROOT(5, 1, 3),
				DELTA(6, 1), DELTA(8, 1)),
	},	/* r: 5			d: 6^5, 8^5 */
	{
		5, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(3, ROOT(5, 0, 2),
				DELTA(6, 1), DELTA(8, 1)),
	},	/* r:			d: 6^5, 8^5 */
	{
		6, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_SAME,
		EXPECT_STATS(2, ROOT(5, 0, 1),
				DELTA(8, 1)),
	},	/* r:			d: 8^5 */
	{
		8, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_DEC,
		EXPECT_STATS(0, ),
	},	/* r:			d: */
};
560  
check_expect(struct world * world,const struct action_item * action_item,unsigned int orig_delta_count,unsigned int orig_root_count)561  static int check_expect(struct world *world,
562  			const struct action_item *action_item,
563  			unsigned int orig_delta_count,
564  			unsigned int orig_root_count)
565  {
566  	unsigned int key_id = action_item->key_id;
567  
568  	switch (action_item->expect_delta) {
569  	case EXPECT_DELTA_SAME:
570  		if (orig_delta_count != world->delta_count) {
571  			pr_err("Key %u: Delta count changed while expected to remain the same.\n",
572  			       key_id);
573  			return -EINVAL;
574  		}
575  		break;
576  	case EXPECT_DELTA_INC:
577  		if (WARN_ON(action_item->action == ACTION_PUT))
578  			return -EINVAL;
579  		if (orig_delta_count + 1 != world->delta_count) {
580  			pr_err("Key %u: Delta count was not incremented.\n",
581  			       key_id);
582  			return -EINVAL;
583  		}
584  		break;
585  	case EXPECT_DELTA_DEC:
586  		if (WARN_ON(action_item->action == ACTION_GET))
587  			return -EINVAL;
588  		if (orig_delta_count - 1 != world->delta_count) {
589  			pr_err("Key %u: Delta count was not decremented.\n",
590  			       key_id);
591  			return -EINVAL;
592  		}
593  		break;
594  	}
595  
596  	switch (action_item->expect_root) {
597  	case EXPECT_ROOT_SAME:
598  		if (orig_root_count != world->root_count) {
599  			pr_err("Key %u: Root count changed while expected to remain the same.\n",
600  			       key_id);
601  			return -EINVAL;
602  		}
603  		break;
604  	case EXPECT_ROOT_INC:
605  		if (WARN_ON(action_item->action == ACTION_PUT))
606  			return -EINVAL;
607  		if (orig_root_count + 1 != world->root_count) {
608  			pr_err("Key %u: Root count was not incremented.\n",
609  			       key_id);
610  			return -EINVAL;
611  		}
612  		break;
613  	case EXPECT_ROOT_DEC:
614  		if (WARN_ON(action_item->action == ACTION_GET))
615  			return -EINVAL;
616  		if (orig_root_count - 1 != world->root_count) {
617  			pr_err("Key %u: Root count was not decremented.\n",
618  			       key_id);
619  			return -EINVAL;
620  		}
621  	}
622  
623  	return 0;
624  }
625  
obj_to_key_id(struct objagg_obj * objagg_obj)626  static unsigned int obj_to_key_id(struct objagg_obj *objagg_obj)
627  {
628  	const struct tokey *root_key;
629  	const struct delta *delta;
630  	unsigned int key_id;
631  
632  	root_key = objagg_obj_root_priv(objagg_obj);
633  	key_id = root_key->id;
634  	delta = objagg_obj_delta_priv(objagg_obj);
635  	if (delta)
636  		key_id += delta->key_id_diff;
637  	return key_id;
638  }
639  
640  static int
check_expect_stats_nums(const struct objagg_obj_stats_info * stats_info,const struct expect_stats_info * expect_stats_info,const char ** errmsg)641  check_expect_stats_nums(const struct objagg_obj_stats_info *stats_info,
642  			const struct expect_stats_info *expect_stats_info,
643  			const char **errmsg)
644  {
645  	if (stats_info->is_root != expect_stats_info->is_root) {
646  		if (errmsg)
647  			*errmsg = "Incorrect root/delta indication";
648  		return -EINVAL;
649  	}
650  	if (stats_info->stats.user_count !=
651  	    expect_stats_info->stats.user_count) {
652  		if (errmsg)
653  			*errmsg = "Incorrect user count";
654  		return -EINVAL;
655  	}
656  	if (stats_info->stats.delta_user_count !=
657  	    expect_stats_info->stats.delta_user_count) {
658  		if (errmsg)
659  			*errmsg = "Incorrect delta user count";
660  		return -EINVAL;
661  	}
662  	return 0;
663  }
664  
665  static int
check_expect_stats_key_id(const struct objagg_obj_stats_info * stats_info,const struct expect_stats_info * expect_stats_info,const char ** errmsg)666  check_expect_stats_key_id(const struct objagg_obj_stats_info *stats_info,
667  			  const struct expect_stats_info *expect_stats_info,
668  			  const char **errmsg)
669  {
670  	if (obj_to_key_id(stats_info->objagg_obj) !=
671  	    expect_stats_info->key_id) {
672  		if (errmsg)
673  			*errmsg = "incorrect key id";
674  		return -EINVAL;
675  	}
676  	return 0;
677  }
678  
check_expect_stats_neigh(const struct objagg_stats * stats,const struct expect_stats * expect_stats,int pos)679  static int check_expect_stats_neigh(const struct objagg_stats *stats,
680  				    const struct expect_stats *expect_stats,
681  				    int pos)
682  {
683  	int i;
684  	int err;
685  
686  	for (i = pos - 1; i >= 0; i--) {
687  		err = check_expect_stats_nums(&stats->stats_info[i],
688  					      &expect_stats->info[pos], NULL);
689  		if (err)
690  			break;
691  		err = check_expect_stats_key_id(&stats->stats_info[i],
692  						&expect_stats->info[pos], NULL);
693  		if (!err)
694  			return 0;
695  	}
696  	for (i = pos + 1; i < stats->stats_info_count; i++) {
697  		err = check_expect_stats_nums(&stats->stats_info[i],
698  					      &expect_stats->info[pos], NULL);
699  		if (err)
700  			break;
701  		err = check_expect_stats_key_id(&stats->stats_info[i],
702  						&expect_stats->info[pos], NULL);
703  		if (!err)
704  			return 0;
705  	}
706  	return -EINVAL;
707  }
708  
__check_expect_stats(const struct objagg_stats * stats,const struct expect_stats * expect_stats,const char ** errmsg)709  static int __check_expect_stats(const struct objagg_stats *stats,
710  				const struct expect_stats *expect_stats,
711  				const char **errmsg)
712  {
713  	int i;
714  	int err;
715  
716  	if (stats->stats_info_count != expect_stats->info_count) {
717  		*errmsg = "Unexpected object count";
718  		return -EINVAL;
719  	}
720  
721  	for (i = 0; i < stats->stats_info_count; i++) {
722  		err = check_expect_stats_nums(&stats->stats_info[i],
723  					      &expect_stats->info[i], errmsg);
724  		if (err)
725  			return err;
726  		err = check_expect_stats_key_id(&stats->stats_info[i],
727  						&expect_stats->info[i], errmsg);
728  		if (err) {
729  			/* It is possible that one of the neighbor stats with
730  			 * same numbers have the correct key id, so check it
731  			 */
732  			err = check_expect_stats_neigh(stats, expect_stats, i);
733  			if (err)
734  				return err;
735  		}
736  	}
737  	return 0;
738  }
739  
/*
 * check_expect_stats - take a stats snapshot of @objagg and compare it with
 * @expect_stats.
 * @errmsg: on failure, set to a description of the problem
 *
 * Return: 0 on match, negative errno otherwise.
 */
static int check_expect_stats(struct objagg *objagg,
			      const struct expect_stats *expect_stats,
			      const char **errmsg)
{
	const struct objagg_stats *stats = objagg_stats_get(objagg);
	int err;

	if (!IS_ERR(stats)) {
		err = __check_expect_stats(stats, expect_stats, errmsg);
		objagg_stats_put(stats);
		return err;
	}

	*errmsg = "objagg_stats_get() failed.";
	return PTR_ERR(stats);
}
756  
/*
 * test_delta_action_item - run one scripted step of the delta test.
 * @world: test bookkeeping
 * @objagg: aggregator under test
 * @action_item: the step to run
 * @inverse: when true, perform the opposite action (PUT for GET and vice
 *	versa) without any checks; used to undo a step
 *
 * After a forward step, the root/delta count changes and the full stats
 * snapshot are verified. If verification fails, the step is undone before
 * the error is returned.
 *
 * Return: 0 on success, negative errno on failure.
 */
static int test_delta_action_item(struct world *world,
				  struct objagg *objagg,
				  const struct action_item *action_item,
				  bool inverse)
{
	unsigned int orig_delta_count = world->delta_count;
	unsigned int orig_root_count = world->root_count;
	unsigned int key_id = action_item->key_id;
	enum action action = action_item->action;
	const char *errmsg;
	int err;

	if (inverse)
		action = (action == ACTION_GET) ? ACTION_PUT : ACTION_GET;

	if (action == ACTION_GET) {
		struct objagg_obj *objagg_obj;

		objagg_obj = world_obj_get(world, objagg, key_id);
		if (IS_ERR(objagg_obj))
			return PTR_ERR(objagg_obj);
	} else if (action == ACTION_PUT) {
		world_obj_put(world, objagg, key_id);
	}

	/* Undo steps are not verified. */
	if (inverse)
		return 0;

	err = check_expect(world, action_item,
			   orig_delta_count, orig_root_count);
	if (!err) {
		err = check_expect_stats(objagg, &action_item->expect_stats,
					 &errmsg);
		if (err)
			pr_err("Key %u: Stats: %s\n", action_item->key_id,
			       errmsg);
	}

	/* Roll back a failed forward step so the caller can unwind cleanly. */
	if (err)
		test_delta_action_item(world, objagg, action_item, true);
	return err;
}
806  
test_delta(void)807  static int test_delta(void)
808  {
809  	struct world world = {};
810  	struct objagg *objagg;
811  	int i;
812  	int err;
813  
814  	objagg = objagg_create(&delta_ops, NULL, &world);
815  	if (IS_ERR(objagg))
816  		return PTR_ERR(objagg);
817  
818  	for (i = 0; i < ARRAY_SIZE(action_items); i++) {
819  		err = test_delta_action_item(&world, objagg,
820  					     &action_items[i], false);
821  		if (err)
822  			goto err_do_action_item;
823  	}
824  
825  	objagg_destroy(objagg);
826  	return 0;
827  
828  err_do_action_item:
829  	for (i--; i >= 0; i--)
830  		test_delta_action_item(&world, objagg, &action_items[i], true);
831  
832  	objagg_destroy(objagg);
833  	return err;
834  }
835  
/*
 * Hints test scenario. Take references on key_ids[] in order on a plain
 * aggregator and check its stats. Then compute hints and check their stats.
 * Finally, replay the same keys on a second aggregator created with those
 * hints, whose stats must match the hints stats.
 */
struct hints_case {
	const unsigned int *key_ids;
	size_t key_ids_count;
	/* Expected stats of the aggregator built without hints. */
	struct expect_stats expect_stats;
	/* Expected hints stats, also expected from the hinted aggregator. */
	struct expect_stats expect_stats_hints;
};

/* Keys to get, in order; repeated keys take additional references. */
static const unsigned int hints_case_key_ids[] = {
	1, 7, 3, 5, 3, 1, 30, 8, 8, 5, 6, 8,
};

static const struct hints_case hints_case = {
	.key_ids = hints_case_key_ids,
	.key_ids_count = ARRAY_SIZE(hints_case_key_ids),
	.expect_stats =
		EXPECT_STATS(7, ROOT(1, 2, 7), ROOT(7, 1, 4), ROOT(30, 1, 1),
				DELTA(8, 3), DELTA(3, 2),
				DELTA(5, 2), DELTA(6, 1)),
	.expect_stats_hints =
		EXPECT_STATS(7, ROOT(3, 2, 9), ROOT(1, 2, 2), ROOT(30, 1, 1),
				DELTA(8, 3), DELTA(5, 2),
				DELTA(6, 1), DELTA(7, 1)),
};
859  
__pr_debug_stats(const struct objagg_stats * stats)860  static void __pr_debug_stats(const struct objagg_stats *stats)
861  {
862  	int i;
863  
864  	for (i = 0; i < stats->stats_info_count; i++)
865  		pr_debug("Stat index %d key %u: u %d, d %d, %s\n", i,
866  			 obj_to_key_id(stats->stats_info[i].objagg_obj),
867  			 stats->stats_info[i].stats.user_count,
868  			 stats->stats_info[i].stats.delta_user_count,
869  			 stats->stats_info[i].is_root ? "root" : "noroot");
870  }
871  
/*
 * pr_debug_stats - dump the current stats of @objagg.
 * Best-effort debug aid: if no snapshot can be taken, nothing is printed.
 */
static void pr_debug_stats(struct objagg *objagg)
{
	const struct objagg_stats *stats = objagg_stats_get(objagg);

	if (!IS_ERR(stats)) {
		__pr_debug_stats(stats);
		objagg_stats_put(stats);
	}
}
882  
/*
 * pr_debug_hints_stats - dump the stats of @objagg_hints.
 * Best-effort debug aid: if no snapshot can be taken, nothing is printed.
 */
static void pr_debug_hints_stats(struct objagg_hints *objagg_hints)
{
	const struct objagg_stats *stats = objagg_hints_stats_get(objagg_hints);

	if (!IS_ERR(stats)) {
		__pr_debug_stats(stats);
		objagg_stats_put(stats);
	}
}
893  
/*
 * check_expect_hints_stats - take a stats snapshot of @objagg_hints and
 * compare it with @expect_stats.
 * @errmsg: on failure, set to a description of the problem
 *
 * Return: 0 on match, negative errno otherwise.
 */
static int check_expect_hints_stats(struct objagg_hints *objagg_hints,
				    const struct expect_stats *expect_stats,
				    const char **errmsg)
{
	const struct objagg_stats *stats;
	int err;

	stats = objagg_hints_stats_get(objagg_hints);
	if (IS_ERR(stats)) {
		/*
		 * Callers print *errmsg on any error, so it must be set on
		 * this path too (mirrors check_expect_stats()).
		 */
		*errmsg = "objagg_hints_stats_get() failed.";
		return PTR_ERR(stats);
	}
	err = __check_expect_stats(stats, expect_stats, errmsg);
	objagg_stats_put(stats);
	return err;
}
908  
test_hints_case(const struct hints_case * hints_case)909  static int test_hints_case(const struct hints_case *hints_case)
910  {
911  	struct objagg_obj *objagg_obj;
912  	struct objagg_hints *hints;
913  	struct world world2 = {};
914  	struct world world = {};
915  	struct objagg *objagg2;
916  	struct objagg *objagg;
917  	const char *errmsg;
918  	int i;
919  	int err;
920  
921  	objagg = objagg_create(&delta_ops, NULL, &world);
922  	if (IS_ERR(objagg))
923  		return PTR_ERR(objagg);
924  
925  	for (i = 0; i < hints_case->key_ids_count; i++) {
926  		objagg_obj = world_obj_get(&world, objagg,
927  					   hints_case->key_ids[i]);
928  		if (IS_ERR(objagg_obj)) {
929  			err = PTR_ERR(objagg_obj);
930  			goto err_world_obj_get;
931  		}
932  	}
933  
934  	pr_debug_stats(objagg);
935  	err = check_expect_stats(objagg, &hints_case->expect_stats, &errmsg);
936  	if (err) {
937  		pr_err("Stats: %s\n", errmsg);
938  		goto err_check_expect_stats;
939  	}
940  
941  	hints = objagg_hints_get(objagg, OBJAGG_OPT_ALGO_SIMPLE_GREEDY);
942  	if (IS_ERR(hints)) {
943  		err = PTR_ERR(hints);
944  		goto err_hints_get;
945  	}
946  
947  	pr_debug_hints_stats(hints);
948  	err = check_expect_hints_stats(hints, &hints_case->expect_stats_hints,
949  				       &errmsg);
950  	if (err) {
951  		pr_err("Hints stats: %s\n", errmsg);
952  		goto err_check_expect_hints_stats;
953  	}
954  
955  	objagg2 = objagg_create(&delta_ops, hints, &world2);
956  	if (IS_ERR(objagg2))
957  		return PTR_ERR(objagg2);
958  
959  	for (i = 0; i < hints_case->key_ids_count; i++) {
960  		objagg_obj = world_obj_get(&world2, objagg2,
961  					   hints_case->key_ids[i]);
962  		if (IS_ERR(objagg_obj)) {
963  			err = PTR_ERR(objagg_obj);
964  			goto err_world2_obj_get;
965  		}
966  	}
967  
968  	pr_debug_stats(objagg2);
969  	err = check_expect_stats(objagg2, &hints_case->expect_stats_hints,
970  				 &errmsg);
971  	if (err) {
972  		pr_err("Stats2: %s\n", errmsg);
973  		goto err_check_expect_stats2;
974  	}
975  
976  	err = 0;
977  
978  err_check_expect_stats2:
979  err_world2_obj_get:
980  	for (i--; i >= 0; i--)
981  		world_obj_put(&world2, objagg, hints_case->key_ids[i]);
982  	i = hints_case->key_ids_count;
983  	objagg_destroy(objagg2);
984  err_check_expect_hints_stats:
985  	objagg_hints_put(hints);
986  err_hints_get:
987  err_check_expect_stats:
988  err_world_obj_get:
989  	for (i--; i >= 0; i--)
990  		world_obj_put(&world, objagg, hints_case->key_ids[i]);
991  
992  	objagg_destroy(objagg);
993  	return err;
994  }
test_hints(void)995  static int test_hints(void)
996  {
997  	return test_hints_case(&hints_case);
998  }
999  
test_objagg_init(void)1000  static int __init test_objagg_init(void)
1001  {
1002  	int err;
1003  
1004  	err = test_nodelta();
1005  	if (err)
1006  		return err;
1007  	err = test_delta();
1008  	if (err)
1009  		return err;
1010  	return test_hints();
1011  }
1012  
test_objagg_exit(void)1013  static void __exit test_objagg_exit(void)
1014  {
1015  }
1016  
1017  module_init(test_objagg_init);
1018  module_exit(test_objagg_exit);
1019  MODULE_LICENSE("Dual BSD/GPL");
1020  MODULE_AUTHOR("Jiri Pirko <jiri@mellanox.com>");
1021  MODULE_DESCRIPTION("Test module for objagg");
1022