1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (c) 2023 Oracle and/or its affiliates.
4  *
5  * KUnit test of the handshake upcall mechanism.
6  */
7 
#include <kunit/test.h>
#include <kunit/visibility.h>

#include <linux/file.h>
#include <linux/kernel.h>

#include <net/sock.h>
#include <net/genetlink.h>
#include <net/netns/generic.h>

#include <uapi/linux/handshake.h>
#include "handshake.h"
19 
20 MODULE_IMPORT_NS(EXPORTED_FOR_KUNIT_TESTING);
21 
/*
 * Stub "accept" callback for the test handshake_proto instances.
 * It ignores all of its arguments and always reports success (0).
 */
static int
test_accept_func(struct handshake_req *req, struct genl_info *info, int fd)
{
	return 0;
}
27 
/*
 * Stub "done" callback for the test handshake_proto instances.
 * It intentionally does nothing.
 */
static void
test_done_func(struct handshake_req *req, unsigned int status,
	       struct genl_info *info)
{
}
32 
33 struct handshake_req_alloc_test_param {
34 	const char			*desc;
35 	struct handshake_proto		*proto;
36 	gfp_t				gfp;
37 	bool				expect_success;
38 };
39 
40 static struct handshake_proto handshake_req_alloc_proto_2 = {
41 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_NONE,
42 };
43 
44 static struct handshake_proto handshake_req_alloc_proto_3 = {
45 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_MAX,
46 };
47 
48 static struct handshake_proto handshake_req_alloc_proto_4 = {
49 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_TLSHD,
50 };
51 
52 static struct handshake_proto handshake_req_alloc_proto_5 = {
53 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_TLSHD,
54 	.hp_accept		= test_accept_func,
55 };
56 
57 static struct handshake_proto handshake_req_alloc_proto_6 = {
58 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_TLSHD,
59 	.hp_privsize		= UINT_MAX,
60 	.hp_accept		= test_accept_func,
61 	.hp_done		= test_done_func,
62 };
63 
64 static struct handshake_proto handshake_req_alloc_proto_good = {
65 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_TLSHD,
66 	.hp_accept		= test_accept_func,
67 	.hp_done		= test_done_func,
68 };
69 
70 static const
71 struct handshake_req_alloc_test_param handshake_req_alloc_params[] = {
72 	{
73 		.desc			= "handshake_req_alloc NULL proto",
74 		.proto			= NULL,
75 		.gfp			= GFP_KERNEL,
76 		.expect_success		= false,
77 	},
78 	{
79 		.desc			= "handshake_req_alloc CLASS_NONE",
80 		.proto			= &handshake_req_alloc_proto_2,
81 		.gfp			= GFP_KERNEL,
82 		.expect_success		= false,
83 	},
84 	{
85 		.desc			= "handshake_req_alloc CLASS_MAX",
86 		.proto			= &handshake_req_alloc_proto_3,
87 		.gfp			= GFP_KERNEL,
88 		.expect_success		= false,
89 	},
90 	{
91 		.desc			= "handshake_req_alloc no callbacks",
92 		.proto			= &handshake_req_alloc_proto_4,
93 		.gfp			= GFP_KERNEL,
94 		.expect_success		= false,
95 	},
96 	{
97 		.desc			= "handshake_req_alloc no done callback",
98 		.proto			= &handshake_req_alloc_proto_5,
99 		.gfp			= GFP_KERNEL,
100 		.expect_success		= false,
101 	},
102 	{
103 		.desc			= "handshake_req_alloc excessive privsize",
104 		.proto			= &handshake_req_alloc_proto_6,
105 		.gfp			= GFP_KERNEL,
106 		.expect_success		= false,
107 	},
108 	{
109 		.desc			= "handshake_req_alloc all good",
110 		.proto			= &handshake_req_alloc_proto_good,
111 		.gfp			= GFP_KERNEL,
112 		.expect_success		= true,
113 	},
114 };
115 
116 static void
117 handshake_req_alloc_get_desc(const struct handshake_req_alloc_test_param *param,
118 			     char *desc)
119 {
120 	strscpy(desc, param->desc, KUNIT_PARAM_DESC_SIZE);
121 }
122 
123 /* Creates the function handshake_req_alloc_gen_params */
124 KUNIT_ARRAY_PARAM(handshake_req_alloc, handshake_req_alloc_params,
125 		  handshake_req_alloc_get_desc);
126 
127 static void handshake_req_alloc_case(struct kunit *test)
128 {
129 	const struct handshake_req_alloc_test_param *param = test->param_value;
130 	struct handshake_req *result;
131 
132 	/* Arrange */
133 
134 	/* Act */
135 	result = handshake_req_alloc(param->proto, param->gfp);
136 
137 	/* Assert */
138 	if (param->expect_success)
139 		KUNIT_EXPECT_NOT_NULL(test, result);
140 	else
141 		KUNIT_EXPECT_NULL(test, result);
142 
143 	kfree(result);
144 }
145 
146 static void handshake_req_submit_test1(struct kunit *test)
147 {
148 	struct socket *sock;
149 	int err, result;
150 
151 	/* Arrange */
152 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
153 			    &sock, 1);
154 	KUNIT_ASSERT_EQ(test, err, 0);
155 
156 	/* Act */
157 	result = handshake_req_submit(sock, NULL, GFP_KERNEL);
158 
159 	/* Assert */
160 	KUNIT_EXPECT_EQ(test, result, -EINVAL);
161 
162 	sock_release(sock);
163 }
164 
165 static void handshake_req_submit_test2(struct kunit *test)
166 {
167 	struct handshake_req *req;
168 	int result;
169 
170 	/* Arrange */
171 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
172 	KUNIT_ASSERT_NOT_NULL(test, req);
173 
174 	/* Act */
175 	result = handshake_req_submit(NULL, req, GFP_KERNEL);
176 
177 	/* Assert */
178 	KUNIT_EXPECT_EQ(test, result, -EINVAL);
179 
180 	/* handshake_req_submit() destroys @req on error */
181 }
182 
183 static void handshake_req_submit_test3(struct kunit *test)
184 {
185 	struct handshake_req *req;
186 	struct socket *sock;
187 	int err, result;
188 
189 	/* Arrange */
190 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
191 	KUNIT_ASSERT_NOT_NULL(test, req);
192 
193 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
194 			    &sock, 1);
195 	KUNIT_ASSERT_EQ(test, err, 0);
196 	sock->file = NULL;
197 
198 	/* Act */
199 	result = handshake_req_submit(sock, req, GFP_KERNEL);
200 
201 	/* Assert */
202 	KUNIT_EXPECT_EQ(test, result, -EINVAL);
203 
204 	/* handshake_req_submit() destroys @req on error */
205 	sock_release(sock);
206 }
207 
208 static void handshake_req_submit_test4(struct kunit *test)
209 {
210 	struct handshake_req *req, *result;
211 	struct socket *sock;
212 	int err;
213 
214 	/* Arrange */
215 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
216 	KUNIT_ASSERT_NOT_NULL(test, req);
217 
218 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
219 			    &sock, 1);
220 	KUNIT_ASSERT_EQ(test, err, 0);
221 	sock->file = sock_alloc_file(sock, O_NONBLOCK, NULL);
222 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, sock->file);
223 	KUNIT_ASSERT_NOT_NULL(test, sock->sk);
224 
225 	err = handshake_req_submit(sock, req, GFP_KERNEL);
226 	KUNIT_ASSERT_EQ(test, err, 0);
227 
228 	/* Act */
229 	result = handshake_req_hash_lookup(sock->sk);
230 
231 	/* Assert */
232 	KUNIT_EXPECT_NOT_NULL(test, result);
233 	KUNIT_EXPECT_PTR_EQ(test, req, result);
234 
235 	handshake_req_cancel(sock->sk);
236 	sock_release(sock);
237 }
238 
239 static void handshake_req_submit_test5(struct kunit *test)
240 {
241 	struct handshake_req *req;
242 	struct handshake_net *hn;
243 	struct socket *sock;
244 	struct net *net;
245 	int saved, err;
246 
247 	/* Arrange */
248 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
249 	KUNIT_ASSERT_NOT_NULL(test, req);
250 
251 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
252 			    &sock, 1);
253 	KUNIT_ASSERT_EQ(test, err, 0);
254 	sock->file = sock_alloc_file(sock, O_NONBLOCK, NULL);
255 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, sock->file);
256 	KUNIT_ASSERT_NOT_NULL(test, sock->sk);
257 
258 	net = sock_net(sock->sk);
259 	hn = handshake_pernet(net);
260 	KUNIT_ASSERT_NOT_NULL(test, hn);
261 
262 	saved = hn->hn_pending;
263 	hn->hn_pending = hn->hn_pending_max + 1;
264 
265 	/* Act */
266 	err = handshake_req_submit(sock, req, GFP_KERNEL);
267 
268 	/* Assert */
269 	KUNIT_EXPECT_EQ(test, err, -EAGAIN);
270 
271 	sock_release(sock);
272 	hn->hn_pending = saved;
273 }
274 
275 static void handshake_req_submit_test6(struct kunit *test)
276 {
277 	struct handshake_req *req1, *req2;
278 	struct socket *sock;
279 	int err;
280 
281 	/* Arrange */
282 	req1 = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
283 	KUNIT_ASSERT_NOT_NULL(test, req1);
284 	req2 = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
285 	KUNIT_ASSERT_NOT_NULL(test, req2);
286 
287 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
288 			    &sock, 1);
289 	KUNIT_ASSERT_EQ(test, err, 0);
290 	sock->file = sock_alloc_file(sock, O_NONBLOCK, NULL);
291 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, sock->file);
292 	KUNIT_ASSERT_NOT_NULL(test, sock->sk);
293 
294 	/* Act */
295 	err = handshake_req_submit(sock, req1, GFP_KERNEL);
296 	KUNIT_ASSERT_EQ(test, err, 0);
297 	err = handshake_req_submit(sock, req2, GFP_KERNEL);
298 
299 	/* Assert */
300 	KUNIT_EXPECT_EQ(test, err, -EBUSY);
301 
302 	handshake_req_cancel(sock->sk);
303 	sock_release(sock);
304 }
305 
306 static void handshake_req_cancel_test1(struct kunit *test)
307 {
308 	struct handshake_req *req;
309 	struct socket *sock;
310 	bool result;
311 	int err;
312 
313 	/* Arrange */
314 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
315 	KUNIT_ASSERT_NOT_NULL(test, req);
316 
317 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
318 			    &sock, 1);
319 	KUNIT_ASSERT_EQ(test, err, 0);
320 
321 	sock->file = sock_alloc_file(sock, O_NONBLOCK, NULL);
322 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, sock->file);
323 
324 	err = handshake_req_submit(sock, req, GFP_KERNEL);
325 	KUNIT_ASSERT_EQ(test, err, 0);
326 
327 	/* NB: handshake_req hasn't been accepted */
328 
329 	/* Act */
330 	result = handshake_req_cancel(sock->sk);
331 
332 	/* Assert */
333 	KUNIT_EXPECT_TRUE(test, result);
334 
335 	sock_release(sock);
336 }
337 
338 static void handshake_req_cancel_test2(struct kunit *test)
339 {
340 	struct handshake_req *req, *next;
341 	struct handshake_net *hn;
342 	struct socket *sock;
343 	struct net *net;
344 	bool result;
345 	int err;
346 
347 	/* Arrange */
348 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
349 	KUNIT_ASSERT_NOT_NULL(test, req);
350 
351 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
352 			    &sock, 1);
353 	KUNIT_ASSERT_EQ(test, err, 0);
354 
355 	sock->file = sock_alloc_file(sock, O_NONBLOCK, NULL);
356 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, sock->file);
357 
358 	err = handshake_req_submit(sock, req, GFP_KERNEL);
359 	KUNIT_ASSERT_EQ(test, err, 0);
360 
361 	net = sock_net(sock->sk);
362 	hn = handshake_pernet(net);
363 	KUNIT_ASSERT_NOT_NULL(test, hn);
364 
365 	/* Pretend to accept this request */
366 	next = handshake_req_next(hn, HANDSHAKE_HANDLER_CLASS_TLSHD);
367 	KUNIT_ASSERT_PTR_EQ(test, req, next);
368 
369 	/* Act */
370 	result = handshake_req_cancel(sock->sk);
371 
372 	/* Assert */
373 	KUNIT_EXPECT_TRUE(test, result);
374 
375 	sock_release(sock);
376 }
377 
378 static void handshake_req_cancel_test3(struct kunit *test)
379 {
380 	struct handshake_req *req, *next;
381 	struct handshake_net *hn;
382 	struct socket *sock;
383 	struct net *net;
384 	bool result;
385 	int err;
386 
387 	/* Arrange */
388 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
389 	KUNIT_ASSERT_NOT_NULL(test, req);
390 
391 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
392 			    &sock, 1);
393 	KUNIT_ASSERT_EQ(test, err, 0);
394 
395 	sock->file = sock_alloc_file(sock, O_NONBLOCK, NULL);
396 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, sock->file);
397 
398 	err = handshake_req_submit(sock, req, GFP_KERNEL);
399 	KUNIT_ASSERT_EQ(test, err, 0);
400 
401 	net = sock_net(sock->sk);
402 	hn = handshake_pernet(net);
403 	KUNIT_ASSERT_NOT_NULL(test, hn);
404 
405 	/* Pretend to accept this request */
406 	next = handshake_req_next(hn, HANDSHAKE_HANDLER_CLASS_TLSHD);
407 	KUNIT_ASSERT_PTR_EQ(test, req, next);
408 
409 	/* Pretend to complete this request */
410 	handshake_complete(next, -ETIMEDOUT, NULL);
411 
412 	/* Act */
413 	result = handshake_req_cancel(sock->sk);
414 
415 	/* Assert */
416 	KUNIT_EXPECT_FALSE(test, result);
417 
418 	sock_release(sock);
419 }
420 
421 static struct handshake_req *handshake_req_destroy_test;
422 
423 static void test_destroy_func(struct handshake_req *req)
424 {
425 	handshake_req_destroy_test = req;
426 }
427 
428 static struct handshake_proto handshake_req_alloc_proto_destroy = {
429 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_TLSHD,
430 	.hp_accept		= test_accept_func,
431 	.hp_done		= test_done_func,
432 	.hp_destroy		= test_destroy_func,
433 };
434 
435 static void handshake_req_destroy_test1(struct kunit *test)
436 {
437 	struct handshake_req *req;
438 	struct socket *sock;
439 	int err;
440 
441 	/* Arrange */
442 	handshake_req_destroy_test = NULL;
443 
444 	req = handshake_req_alloc(&handshake_req_alloc_proto_destroy, GFP_KERNEL);
445 	KUNIT_ASSERT_NOT_NULL(test, req);
446 
447 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
448 			    &sock, 1);
449 	KUNIT_ASSERT_EQ(test, err, 0);
450 
451 	sock->file = sock_alloc_file(sock, O_NONBLOCK, NULL);
452 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, sock->file);
453 
454 	err = handshake_req_submit(sock, req, GFP_KERNEL);
455 	KUNIT_ASSERT_EQ(test, err, 0);
456 
457 	handshake_req_cancel(sock->sk);
458 
459 	/* Act */
460 	sock_release(sock);
461 
462 	/* Assert */
463 	KUNIT_EXPECT_PTR_EQ(test, handshake_req_destroy_test, req);
464 }
465 
466 static struct kunit_case handshake_api_test_cases[] = {
467 	{
468 		.name			= "req_alloc API fuzzing",
469 		.run_case		= handshake_req_alloc_case,
470 		.generate_params	= handshake_req_alloc_gen_params,
471 	},
472 	{
473 		.name			= "req_submit NULL req arg",
474 		.run_case		= handshake_req_submit_test1,
475 	},
476 	{
477 		.name			= "req_submit NULL sock arg",
478 		.run_case		= handshake_req_submit_test2,
479 	},
480 	{
481 		.name			= "req_submit NULL sock->file",
482 		.run_case		= handshake_req_submit_test3,
483 	},
484 	{
485 		.name			= "req_lookup works",
486 		.run_case		= handshake_req_submit_test4,
487 	},
488 	{
489 		.name			= "req_submit max pending",
490 		.run_case		= handshake_req_submit_test5,
491 	},
492 	{
493 		.name			= "req_submit multiple",
494 		.run_case		= handshake_req_submit_test6,
495 	},
496 	{
497 		.name			= "req_cancel before accept",
498 		.run_case		= handshake_req_cancel_test1,
499 	},
500 	{
501 		.name			= "req_cancel after accept",
502 		.run_case		= handshake_req_cancel_test2,
503 	},
504 	{
505 		.name			= "req_cancel after done",
506 		.run_case		= handshake_req_cancel_test3,
507 	},
508 	{
509 		.name			= "req_destroy works",
510 		.run_case		= handshake_req_destroy_test1,
511 	},
512 	{}
513 };
514 
515 static struct kunit_suite handshake_api_suite = {
516        .name                   = "Handshake API tests",
517        .test_cases             = handshake_api_test_cases,
518 };
519 
520 kunit_test_suites(&handshake_api_suite);
521 
522 MODULE_DESCRIPTION("Test handshake upcall API functions");
523 MODULE_LICENSE("GPL");
524