1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (c) 2023 Oracle and/or its affiliates.
4  *
5  * KUnit test of the handshake upcall mechanism.
6  */
7 
8 #include <kunit/test.h>
9 #include <kunit/visibility.h>
10 
11 #include <linux/kernel.h>
12 
13 #include <net/sock.h>
14 #include <net/genetlink.h>
15 #include <net/netns/generic.h>
16 
17 #include <uapi/linux/handshake.h>
18 #include "handshake.h"
19 
20 MODULE_IMPORT_NS(EXPORTED_FOR_KUNIT_TESTING);
21 
/*
 * Stub hp_accept callback. Every argument is ignored and 0 is always
 * returned; it exists only so that the protos used by these tests carry
 * a non-NULL hp_accept.
 */
static int test_accept_func(struct handshake_req *req, struct genl_info *info,
			    int fd)
{
	(void)req;
	(void)info;
	(void)fd;

	return 0;
}
27 
/*
 * Stub hp_done callback. The request, completion status and netlink
 * info are all discarded; it exists only so that the protos used by
 * these tests carry a non-NULL hp_done.
 */
static void test_done_func(struct handshake_req *req, unsigned int status,
			   struct genl_info *info)
{
	(void)req;
	(void)status;
	(void)info;
}
32 
33 struct handshake_req_alloc_test_param {
34 	const char			*desc;
35 	struct handshake_proto		*proto;
36 	gfp_t				gfp;
37 	bool				expect_success;
38 };
39 
40 static struct handshake_proto handshake_req_alloc_proto_2 = {
41 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_NONE,
42 };
43 
44 static struct handshake_proto handshake_req_alloc_proto_3 = {
45 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_MAX,
46 };
47 
48 static struct handshake_proto handshake_req_alloc_proto_4 = {
49 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_TLSHD,
50 };
51 
52 static struct handshake_proto handshake_req_alloc_proto_5 = {
53 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_TLSHD,
54 	.hp_accept		= test_accept_func,
55 };
56 
57 static struct handshake_proto handshake_req_alloc_proto_6 = {
58 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_TLSHD,
59 	.hp_privsize		= UINT_MAX,
60 	.hp_accept		= test_accept_func,
61 	.hp_done		= test_done_func,
62 };
63 
64 static struct handshake_proto handshake_req_alloc_proto_good = {
65 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_TLSHD,
66 	.hp_accept		= test_accept_func,
67 	.hp_done		= test_done_func,
68 };
69 
70 static const
71 struct handshake_req_alloc_test_param handshake_req_alloc_params[] = {
72 	{
73 		.desc			= "handshake_req_alloc NULL proto",
74 		.proto			= NULL,
75 		.gfp			= GFP_KERNEL,
76 		.expect_success		= false,
77 	},
78 	{
79 		.desc			= "handshake_req_alloc CLASS_NONE",
80 		.proto			= &handshake_req_alloc_proto_2,
81 		.gfp			= GFP_KERNEL,
82 		.expect_success		= false,
83 	},
84 	{
85 		.desc			= "handshake_req_alloc CLASS_MAX",
86 		.proto			= &handshake_req_alloc_proto_3,
87 		.gfp			= GFP_KERNEL,
88 		.expect_success		= false,
89 	},
90 	{
91 		.desc			= "handshake_req_alloc no callbacks",
92 		.proto			= &handshake_req_alloc_proto_4,
93 		.gfp			= GFP_KERNEL,
94 		.expect_success		= false,
95 	},
96 	{
97 		.desc			= "handshake_req_alloc no done callback",
98 		.proto			= &handshake_req_alloc_proto_5,
99 		.gfp			= GFP_KERNEL,
100 		.expect_success		= false,
101 	},
102 	{
103 		.desc			= "handshake_req_alloc excessive privsize",
104 		.proto			= &handshake_req_alloc_proto_6,
105 		.gfp			= GFP_KERNEL | __GFP_NOWARN,
106 		.expect_success		= false,
107 	},
108 	{
109 		.desc			= "handshake_req_alloc all good",
110 		.proto			= &handshake_req_alloc_proto_good,
111 		.gfp			= GFP_KERNEL,
112 		.expect_success		= true,
113 	},
114 };
115 
116 static void
117 handshake_req_alloc_get_desc(const struct handshake_req_alloc_test_param *param,
118 			     char *desc)
119 {
120 	strscpy(desc, param->desc, KUNIT_PARAM_DESC_SIZE);
121 }
122 
123 /* Creates the function handshake_req_alloc_gen_params */
124 KUNIT_ARRAY_PARAM(handshake_req_alloc, handshake_req_alloc_params,
125 		  handshake_req_alloc_get_desc);
126 
127 static void handshake_req_alloc_case(struct kunit *test)
128 {
129 	const struct handshake_req_alloc_test_param *param = test->param_value;
130 	struct handshake_req *result;
131 
132 	/* Arrange */
133 
134 	/* Act */
135 	result = handshake_req_alloc(param->proto, param->gfp);
136 
137 	/* Assert */
138 	if (param->expect_success)
139 		KUNIT_EXPECT_NOT_NULL(test, result);
140 	else
141 		KUNIT_EXPECT_NULL(test, result);
142 
143 	kfree(result);
144 }
145 
146 static void handshake_req_submit_test1(struct kunit *test)
147 {
148 	struct socket *sock;
149 	int err, result;
150 
151 	/* Arrange */
152 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
153 			    &sock, 1);
154 	KUNIT_ASSERT_EQ(test, err, 0);
155 
156 	/* Act */
157 	result = handshake_req_submit(sock, NULL, GFP_KERNEL);
158 
159 	/* Assert */
160 	KUNIT_EXPECT_EQ(test, result, -EINVAL);
161 
162 	sock_release(sock);
163 }
164 
165 static void handshake_req_submit_test2(struct kunit *test)
166 {
167 	struct handshake_req *req;
168 	int result;
169 
170 	/* Arrange */
171 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
172 	KUNIT_ASSERT_NOT_NULL(test, req);
173 
174 	/* Act */
175 	result = handshake_req_submit(NULL, req, GFP_KERNEL);
176 
177 	/* Assert */
178 	KUNIT_EXPECT_EQ(test, result, -EINVAL);
179 
180 	/* handshake_req_submit() destroys @req on error */
181 }
182 
183 static void handshake_req_submit_test3(struct kunit *test)
184 {
185 	struct handshake_req *req;
186 	struct socket *sock;
187 	int err, result;
188 
189 	/* Arrange */
190 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
191 	KUNIT_ASSERT_NOT_NULL(test, req);
192 
193 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
194 			    &sock, 1);
195 	KUNIT_ASSERT_EQ(test, err, 0);
196 	sock->file = NULL;
197 
198 	/* Act */
199 	result = handshake_req_submit(sock, req, GFP_KERNEL);
200 
201 	/* Assert */
202 	KUNIT_EXPECT_EQ(test, result, -EINVAL);
203 
204 	/* handshake_req_submit() destroys @req on error */
205 	sock_release(sock);
206 }
207 
208 static void handshake_req_submit_test4(struct kunit *test)
209 {
210 	struct handshake_req *req, *result;
211 	struct socket *sock;
212 	struct file *filp;
213 	int err;
214 
215 	/* Arrange */
216 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
217 	KUNIT_ASSERT_NOT_NULL(test, req);
218 
219 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
220 			    &sock, 1);
221 	KUNIT_ASSERT_EQ(test, err, 0);
222 	filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
223 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
224 	KUNIT_ASSERT_NOT_NULL(test, sock->sk);
225 	sock->file = filp;
226 
227 	err = handshake_req_submit(sock, req, GFP_KERNEL);
228 	KUNIT_ASSERT_EQ(test, err, 0);
229 
230 	/* Act */
231 	result = handshake_req_hash_lookup(sock->sk);
232 
233 	/* Assert */
234 	KUNIT_EXPECT_NOT_NULL(test, result);
235 	KUNIT_EXPECT_PTR_EQ(test, req, result);
236 
237 	handshake_req_cancel(sock->sk);
238 	fput(filp);
239 }
240 
241 static void handshake_req_submit_test5(struct kunit *test)
242 {
243 	struct handshake_req *req;
244 	struct handshake_net *hn;
245 	struct socket *sock;
246 	struct file *filp;
247 	struct net *net;
248 	int saved, err;
249 
250 	/* Arrange */
251 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
252 	KUNIT_ASSERT_NOT_NULL(test, req);
253 
254 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
255 			    &sock, 1);
256 	KUNIT_ASSERT_EQ(test, err, 0);
257 	filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
258 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
259 	KUNIT_ASSERT_NOT_NULL(test, sock->sk);
260 	sock->file = filp;
261 
262 	net = sock_net(sock->sk);
263 	hn = handshake_pernet(net);
264 	KUNIT_ASSERT_NOT_NULL(test, hn);
265 
266 	saved = hn->hn_pending;
267 	hn->hn_pending = hn->hn_pending_max + 1;
268 
269 	/* Act */
270 	err = handshake_req_submit(sock, req, GFP_KERNEL);
271 
272 	/* Assert */
273 	KUNIT_EXPECT_EQ(test, err, -EAGAIN);
274 
275 	fput(filp);
276 	hn->hn_pending = saved;
277 }
278 
279 static void handshake_req_submit_test6(struct kunit *test)
280 {
281 	struct handshake_req *req1, *req2;
282 	struct socket *sock;
283 	struct file *filp;
284 	int err;
285 
286 	/* Arrange */
287 	req1 = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
288 	KUNIT_ASSERT_NOT_NULL(test, req1);
289 	req2 = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
290 	KUNIT_ASSERT_NOT_NULL(test, req2);
291 
292 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
293 			    &sock, 1);
294 	KUNIT_ASSERT_EQ(test, err, 0);
295 	filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
296 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
297 	KUNIT_ASSERT_NOT_NULL(test, sock->sk);
298 	sock->file = filp;
299 
300 	/* Act */
301 	err = handshake_req_submit(sock, req1, GFP_KERNEL);
302 	KUNIT_ASSERT_EQ(test, err, 0);
303 	err = handshake_req_submit(sock, req2, GFP_KERNEL);
304 
305 	/* Assert */
306 	KUNIT_EXPECT_EQ(test, err, -EBUSY);
307 
308 	handshake_req_cancel(sock->sk);
309 	fput(filp);
310 }
311 
312 static void handshake_req_cancel_test1(struct kunit *test)
313 {
314 	struct handshake_req *req;
315 	struct socket *sock;
316 	struct file *filp;
317 	bool result;
318 	int err;
319 
320 	/* Arrange */
321 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
322 	KUNIT_ASSERT_NOT_NULL(test, req);
323 
324 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
325 			    &sock, 1);
326 	KUNIT_ASSERT_EQ(test, err, 0);
327 
328 	filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
329 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
330 	sock->file = filp;
331 
332 	err = handshake_req_submit(sock, req, GFP_KERNEL);
333 	KUNIT_ASSERT_EQ(test, err, 0);
334 
335 	/* NB: handshake_req hasn't been accepted */
336 
337 	/* Act */
338 	result = handshake_req_cancel(sock->sk);
339 
340 	/* Assert */
341 	KUNIT_EXPECT_TRUE(test, result);
342 
343 	fput(filp);
344 }
345 
346 static void handshake_req_cancel_test2(struct kunit *test)
347 {
348 	struct handshake_req *req, *next;
349 	struct handshake_net *hn;
350 	struct socket *sock;
351 	struct file *filp;
352 	struct net *net;
353 	bool result;
354 	int err;
355 
356 	/* Arrange */
357 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
358 	KUNIT_ASSERT_NOT_NULL(test, req);
359 
360 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
361 			    &sock, 1);
362 	KUNIT_ASSERT_EQ(test, err, 0);
363 
364 	filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
365 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
366 	sock->file = filp;
367 
368 	err = handshake_req_submit(sock, req, GFP_KERNEL);
369 	KUNIT_ASSERT_EQ(test, err, 0);
370 
371 	net = sock_net(sock->sk);
372 	hn = handshake_pernet(net);
373 	KUNIT_ASSERT_NOT_NULL(test, hn);
374 
375 	/* Pretend to accept this request */
376 	next = handshake_req_next(hn, HANDSHAKE_HANDLER_CLASS_TLSHD);
377 	KUNIT_ASSERT_PTR_EQ(test, req, next);
378 
379 	/* Act */
380 	result = handshake_req_cancel(sock->sk);
381 
382 	/* Assert */
383 	KUNIT_EXPECT_TRUE(test, result);
384 
385 	fput(filp);
386 }
387 
388 static void handshake_req_cancel_test3(struct kunit *test)
389 {
390 	struct handshake_req *req, *next;
391 	struct handshake_net *hn;
392 	struct socket *sock;
393 	struct file *filp;
394 	struct net *net;
395 	bool result;
396 	int err;
397 
398 	/* Arrange */
399 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
400 	KUNIT_ASSERT_NOT_NULL(test, req);
401 
402 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
403 			    &sock, 1);
404 	KUNIT_ASSERT_EQ(test, err, 0);
405 
406 	filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
407 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
408 	sock->file = filp;
409 
410 	err = handshake_req_submit(sock, req, GFP_KERNEL);
411 	KUNIT_ASSERT_EQ(test, err, 0);
412 
413 	net = sock_net(sock->sk);
414 	hn = handshake_pernet(net);
415 	KUNIT_ASSERT_NOT_NULL(test, hn);
416 
417 	/* Pretend to accept this request */
418 	next = handshake_req_next(hn, HANDSHAKE_HANDLER_CLASS_TLSHD);
419 	KUNIT_ASSERT_PTR_EQ(test, req, next);
420 
421 	/* Pretend to complete this request */
422 	handshake_complete(next, -ETIMEDOUT, NULL);
423 
424 	/* Act */
425 	result = handshake_req_cancel(sock->sk);
426 
427 	/* Assert */
428 	KUNIT_EXPECT_FALSE(test, result);
429 
430 	fput(filp);
431 }
432 
433 static struct handshake_req *handshake_req_destroy_test;
434 
435 static void test_destroy_func(struct handshake_req *req)
436 {
437 	handshake_req_destroy_test = req;
438 }
439 
440 static struct handshake_proto handshake_req_alloc_proto_destroy = {
441 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_TLSHD,
442 	.hp_accept		= test_accept_func,
443 	.hp_done		= test_done_func,
444 	.hp_destroy		= test_destroy_func,
445 };
446 
447 static void handshake_req_destroy_test1(struct kunit *test)
448 {
449 	struct handshake_req *req;
450 	struct socket *sock;
451 	struct file *filp;
452 	int err;
453 
454 	/* Arrange */
455 	handshake_req_destroy_test = NULL;
456 
457 	req = handshake_req_alloc(&handshake_req_alloc_proto_destroy, GFP_KERNEL);
458 	KUNIT_ASSERT_NOT_NULL(test, req);
459 
460 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
461 			    &sock, 1);
462 	KUNIT_ASSERT_EQ(test, err, 0);
463 
464 	filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
465 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
466 	sock->file = filp;
467 
468 	err = handshake_req_submit(sock, req, GFP_KERNEL);
469 	KUNIT_ASSERT_EQ(test, err, 0);
470 
471 	handshake_req_cancel(sock->sk);
472 
473 	/* Act */
474 	/* Ensure the close/release/put process has run to
475 	 * completion before checking the result.
476 	 */
477 	__fput_sync(filp);
478 
479 	/* Assert */
480 	KUNIT_EXPECT_PTR_EQ(test, handshake_req_destroy_test, req);
481 }
482 
483 static struct kunit_case handshake_api_test_cases[] = {
484 	{
485 		.name			= "req_alloc API fuzzing",
486 		.run_case		= handshake_req_alloc_case,
487 		.generate_params	= handshake_req_alloc_gen_params,
488 	},
489 	{
490 		.name			= "req_submit NULL req arg",
491 		.run_case		= handshake_req_submit_test1,
492 	},
493 	{
494 		.name			= "req_submit NULL sock arg",
495 		.run_case		= handshake_req_submit_test2,
496 	},
497 	{
498 		.name			= "req_submit NULL sock->file",
499 		.run_case		= handshake_req_submit_test3,
500 	},
501 	{
502 		.name			= "req_lookup works",
503 		.run_case		= handshake_req_submit_test4,
504 	},
505 	{
506 		.name			= "req_submit max pending",
507 		.run_case		= handshake_req_submit_test5,
508 	},
509 	{
510 		.name			= "req_submit multiple",
511 		.run_case		= handshake_req_submit_test6,
512 	},
513 	{
514 		.name			= "req_cancel before accept",
515 		.run_case		= handshake_req_cancel_test1,
516 	},
517 	{
518 		.name			= "req_cancel after accept",
519 		.run_case		= handshake_req_cancel_test2,
520 	},
521 	{
522 		.name			= "req_cancel after done",
523 		.run_case		= handshake_req_cancel_test3,
524 	},
525 	{
526 		.name			= "req_destroy works",
527 		.run_case		= handshake_req_destroy_test1,
528 	},
529 	{}
530 };
531 
532 static struct kunit_suite handshake_api_suite = {
533        .name                   = "Handshake API tests",
534        .test_cases             = handshake_api_test_cases,
535 };
536 
537 kunit_test_suites(&handshake_api_suite);
538 
539 MODULE_DESCRIPTION("Test handshake upcall API functions");
540 MODULE_LICENSE("GPL");
541