188232ec1SChuck Lever // SPDX-License-Identifier: GPL-2.0
288232ec1SChuck Lever /*
388232ec1SChuck Lever  * Copyright (c) 2023 Oracle and/or its affiliates.
488232ec1SChuck Lever  *
588232ec1SChuck Lever  * KUnit test of the handshake upcall mechanism.
688232ec1SChuck Lever  */
788232ec1SChuck Lever 
888232ec1SChuck Lever #include <kunit/test.h>
988232ec1SChuck Lever #include <kunit/visibility.h>
1088232ec1SChuck Lever 
1188232ec1SChuck Lever #include <linux/kernel.h>
1288232ec1SChuck Lever 
1388232ec1SChuck Lever #include <net/sock.h>
1488232ec1SChuck Lever #include <net/genetlink.h>
1588232ec1SChuck Lever #include <net/netns/generic.h>
1688232ec1SChuck Lever 
1788232ec1SChuck Lever #include <uapi/linux/handshake.h>
1888232ec1SChuck Lever #include "handshake.h"
1988232ec1SChuck Lever 
2088232ec1SChuck Lever MODULE_IMPORT_NS(EXPORTED_FOR_KUNIT_TESTING);
2188232ec1SChuck Lever 
/* Stub ->hp_accept callback: always reports success. */
static int
test_accept_func(struct handshake_req *req, struct genl_info *info, int fd)
{
	/* Nothing to hand off; pretend the handler took the request. */
	return 0;
}
2788232ec1SChuck Lever 
/* Stub ->hp_done callback: the completion status is ignored. */
static void
test_done_func(struct handshake_req *req, unsigned int status,
	       struct genl_info *info)
{
	/* intentionally empty */
}
3288232ec1SChuck Lever 
3388232ec1SChuck Lever struct handshake_req_alloc_test_param {
3488232ec1SChuck Lever 	const char			*desc;
3588232ec1SChuck Lever 	struct handshake_proto		*proto;
3688232ec1SChuck Lever 	gfp_t				gfp;
3788232ec1SChuck Lever 	bool				expect_success;
3888232ec1SChuck Lever };
3988232ec1SChuck Lever 
4088232ec1SChuck Lever static struct handshake_proto handshake_req_alloc_proto_2 = {
4188232ec1SChuck Lever 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_NONE,
4288232ec1SChuck Lever };
4388232ec1SChuck Lever 
4488232ec1SChuck Lever static struct handshake_proto handshake_req_alloc_proto_3 = {
4588232ec1SChuck Lever 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_MAX,
4688232ec1SChuck Lever };
4788232ec1SChuck Lever 
4888232ec1SChuck Lever static struct handshake_proto handshake_req_alloc_proto_4 = {
4988232ec1SChuck Lever 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_TLSHD,
5088232ec1SChuck Lever };
5188232ec1SChuck Lever 
5288232ec1SChuck Lever static struct handshake_proto handshake_req_alloc_proto_5 = {
5388232ec1SChuck Lever 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_TLSHD,
5488232ec1SChuck Lever 	.hp_accept		= test_accept_func,
5588232ec1SChuck Lever };
5688232ec1SChuck Lever 
5788232ec1SChuck Lever static struct handshake_proto handshake_req_alloc_proto_6 = {
5888232ec1SChuck Lever 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_TLSHD,
5988232ec1SChuck Lever 	.hp_privsize		= UINT_MAX,
6088232ec1SChuck Lever 	.hp_accept		= test_accept_func,
6188232ec1SChuck Lever 	.hp_done		= test_done_func,
6288232ec1SChuck Lever };
6388232ec1SChuck Lever 
6488232ec1SChuck Lever static struct handshake_proto handshake_req_alloc_proto_good = {
6588232ec1SChuck Lever 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_TLSHD,
6688232ec1SChuck Lever 	.hp_accept		= test_accept_func,
6788232ec1SChuck Lever 	.hp_done		= test_done_func,
6888232ec1SChuck Lever };
6988232ec1SChuck Lever 
7088232ec1SChuck Lever static const
7188232ec1SChuck Lever struct handshake_req_alloc_test_param handshake_req_alloc_params[] = {
7288232ec1SChuck Lever 	{
7388232ec1SChuck Lever 		.desc			= "handshake_req_alloc NULL proto",
7488232ec1SChuck Lever 		.proto			= NULL,
7588232ec1SChuck Lever 		.gfp			= GFP_KERNEL,
7688232ec1SChuck Lever 		.expect_success		= false,
7788232ec1SChuck Lever 	},
7888232ec1SChuck Lever 	{
7988232ec1SChuck Lever 		.desc			= "handshake_req_alloc CLASS_NONE",
8088232ec1SChuck Lever 		.proto			= &handshake_req_alloc_proto_2,
8188232ec1SChuck Lever 		.gfp			= GFP_KERNEL,
8288232ec1SChuck Lever 		.expect_success		= false,
8388232ec1SChuck Lever 	},
8488232ec1SChuck Lever 	{
8588232ec1SChuck Lever 		.desc			= "handshake_req_alloc CLASS_MAX",
8688232ec1SChuck Lever 		.proto			= &handshake_req_alloc_proto_3,
8788232ec1SChuck Lever 		.gfp			= GFP_KERNEL,
8888232ec1SChuck Lever 		.expect_success		= false,
8988232ec1SChuck Lever 	},
9088232ec1SChuck Lever 	{
9188232ec1SChuck Lever 		.desc			= "handshake_req_alloc no callbacks",
9288232ec1SChuck Lever 		.proto			= &handshake_req_alloc_proto_4,
9388232ec1SChuck Lever 		.gfp			= GFP_KERNEL,
9488232ec1SChuck Lever 		.expect_success		= false,
9588232ec1SChuck Lever 	},
9688232ec1SChuck Lever 	{
9788232ec1SChuck Lever 		.desc			= "handshake_req_alloc no done callback",
9888232ec1SChuck Lever 		.proto			= &handshake_req_alloc_proto_5,
9988232ec1SChuck Lever 		.gfp			= GFP_KERNEL,
10088232ec1SChuck Lever 		.expect_success		= false,
10188232ec1SChuck Lever 	},
10288232ec1SChuck Lever 	{
10388232ec1SChuck Lever 		.desc			= "handshake_req_alloc excessive privsize",
10488232ec1SChuck Lever 		.proto			= &handshake_req_alloc_proto_6,
105b21c7ba6SChuck Lever 		.gfp			= GFP_KERNEL | __GFP_NOWARN,
10688232ec1SChuck Lever 		.expect_success		= false,
10788232ec1SChuck Lever 	},
10888232ec1SChuck Lever 	{
10988232ec1SChuck Lever 		.desc			= "handshake_req_alloc all good",
11088232ec1SChuck Lever 		.proto			= &handshake_req_alloc_proto_good,
11188232ec1SChuck Lever 		.gfp			= GFP_KERNEL,
11288232ec1SChuck Lever 		.expect_success		= true,
11388232ec1SChuck Lever 	},
11488232ec1SChuck Lever };
11588232ec1SChuck Lever 
11688232ec1SChuck Lever static void
handshake_req_alloc_get_desc(const struct handshake_req_alloc_test_param * param,char * desc)11788232ec1SChuck Lever handshake_req_alloc_get_desc(const struct handshake_req_alloc_test_param *param,
11888232ec1SChuck Lever 			     char *desc)
11988232ec1SChuck Lever {
12088232ec1SChuck Lever 	strscpy(desc, param->desc, KUNIT_PARAM_DESC_SIZE);
12188232ec1SChuck Lever }
12288232ec1SChuck Lever 
12388232ec1SChuck Lever /* Creates the function handshake_req_alloc_gen_params */
12488232ec1SChuck Lever KUNIT_ARRAY_PARAM(handshake_req_alloc, handshake_req_alloc_params,
12588232ec1SChuck Lever 		  handshake_req_alloc_get_desc);
12688232ec1SChuck Lever 
handshake_req_alloc_case(struct kunit * test)12788232ec1SChuck Lever static void handshake_req_alloc_case(struct kunit *test)
12888232ec1SChuck Lever {
12988232ec1SChuck Lever 	const struct handshake_req_alloc_test_param *param = test->param_value;
13088232ec1SChuck Lever 	struct handshake_req *result;
13188232ec1SChuck Lever 
13288232ec1SChuck Lever 	/* Arrange */
13388232ec1SChuck Lever 
13488232ec1SChuck Lever 	/* Act */
13588232ec1SChuck Lever 	result = handshake_req_alloc(param->proto, param->gfp);
13688232ec1SChuck Lever 
13788232ec1SChuck Lever 	/* Assert */
13888232ec1SChuck Lever 	if (param->expect_success)
13988232ec1SChuck Lever 		KUNIT_EXPECT_NOT_NULL(test, result);
14088232ec1SChuck Lever 	else
14188232ec1SChuck Lever 		KUNIT_EXPECT_NULL(test, result);
14288232ec1SChuck Lever 
14388232ec1SChuck Lever 	kfree(result);
14488232ec1SChuck Lever }
14588232ec1SChuck Lever 
handshake_req_submit_test1(struct kunit * test)14688232ec1SChuck Lever static void handshake_req_submit_test1(struct kunit *test)
14788232ec1SChuck Lever {
14888232ec1SChuck Lever 	struct socket *sock;
14988232ec1SChuck Lever 	int err, result;
15088232ec1SChuck Lever 
15188232ec1SChuck Lever 	/* Arrange */
15288232ec1SChuck Lever 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
15388232ec1SChuck Lever 			    &sock, 1);
15488232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
15588232ec1SChuck Lever 
15688232ec1SChuck Lever 	/* Act */
15788232ec1SChuck Lever 	result = handshake_req_submit(sock, NULL, GFP_KERNEL);
15888232ec1SChuck Lever 
15988232ec1SChuck Lever 	/* Assert */
16088232ec1SChuck Lever 	KUNIT_EXPECT_EQ(test, result, -EINVAL);
16188232ec1SChuck Lever 
16288232ec1SChuck Lever 	sock_release(sock);
16388232ec1SChuck Lever }
16488232ec1SChuck Lever 
handshake_req_submit_test2(struct kunit * test)16588232ec1SChuck Lever static void handshake_req_submit_test2(struct kunit *test)
16688232ec1SChuck Lever {
16788232ec1SChuck Lever 	struct handshake_req *req;
16888232ec1SChuck Lever 	int result;
16988232ec1SChuck Lever 
17088232ec1SChuck Lever 	/* Arrange */
17188232ec1SChuck Lever 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
17288232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, req);
17388232ec1SChuck Lever 
17488232ec1SChuck Lever 	/* Act */
17588232ec1SChuck Lever 	result = handshake_req_submit(NULL, req, GFP_KERNEL);
17688232ec1SChuck Lever 
17788232ec1SChuck Lever 	/* Assert */
17888232ec1SChuck Lever 	KUNIT_EXPECT_EQ(test, result, -EINVAL);
17988232ec1SChuck Lever 
18088232ec1SChuck Lever 	/* handshake_req_submit() destroys @req on error */
18188232ec1SChuck Lever }
18288232ec1SChuck Lever 
handshake_req_submit_test3(struct kunit * test)18388232ec1SChuck Lever static void handshake_req_submit_test3(struct kunit *test)
18488232ec1SChuck Lever {
18588232ec1SChuck Lever 	struct handshake_req *req;
18688232ec1SChuck Lever 	struct socket *sock;
18788232ec1SChuck Lever 	int err, result;
18888232ec1SChuck Lever 
18988232ec1SChuck Lever 	/* Arrange */
19088232ec1SChuck Lever 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
19188232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, req);
19288232ec1SChuck Lever 
19388232ec1SChuck Lever 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
19488232ec1SChuck Lever 			    &sock, 1);
19588232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
19688232ec1SChuck Lever 	sock->file = NULL;
19788232ec1SChuck Lever 
19888232ec1SChuck Lever 	/* Act */
19988232ec1SChuck Lever 	result = handshake_req_submit(sock, req, GFP_KERNEL);
20088232ec1SChuck Lever 
20188232ec1SChuck Lever 	/* Assert */
20288232ec1SChuck Lever 	KUNIT_EXPECT_EQ(test, result, -EINVAL);
20388232ec1SChuck Lever 
20488232ec1SChuck Lever 	/* handshake_req_submit() destroys @req on error */
20588232ec1SChuck Lever 	sock_release(sock);
20688232ec1SChuck Lever }
20788232ec1SChuck Lever 
handshake_req_submit_test4(struct kunit * test)20888232ec1SChuck Lever static void handshake_req_submit_test4(struct kunit *test)
20988232ec1SChuck Lever {
21088232ec1SChuck Lever 	struct handshake_req *req, *result;
21188232ec1SChuck Lever 	struct socket *sock;
21218c40a1cSChuck Lever 	struct file *filp;
21388232ec1SChuck Lever 	int err;
21488232ec1SChuck Lever 
21588232ec1SChuck Lever 	/* Arrange */
21688232ec1SChuck Lever 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
21788232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, req);
21888232ec1SChuck Lever 
21988232ec1SChuck Lever 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
22088232ec1SChuck Lever 			    &sock, 1);
22188232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
22218c40a1cSChuck Lever 	filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
22318c40a1cSChuck Lever 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
22488232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, sock->sk);
22518c40a1cSChuck Lever 	sock->file = filp;
22688232ec1SChuck Lever 
22788232ec1SChuck Lever 	err = handshake_req_submit(sock, req, GFP_KERNEL);
22888232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
22988232ec1SChuck Lever 
23088232ec1SChuck Lever 	/* Act */
23188232ec1SChuck Lever 	result = handshake_req_hash_lookup(sock->sk);
23288232ec1SChuck Lever 
23388232ec1SChuck Lever 	/* Assert */
23488232ec1SChuck Lever 	KUNIT_EXPECT_NOT_NULL(test, result);
23588232ec1SChuck Lever 	KUNIT_EXPECT_PTR_EQ(test, req, result);
23688232ec1SChuck Lever 
23788232ec1SChuck Lever 	handshake_req_cancel(sock->sk);
2384a0f07d7SJinjie Ruan 	fput(filp);
23988232ec1SChuck Lever }
24088232ec1SChuck Lever 
handshake_req_submit_test5(struct kunit * test)24188232ec1SChuck Lever static void handshake_req_submit_test5(struct kunit *test)
24288232ec1SChuck Lever {
24388232ec1SChuck Lever 	struct handshake_req *req;
24488232ec1SChuck Lever 	struct handshake_net *hn;
24588232ec1SChuck Lever 	struct socket *sock;
24618c40a1cSChuck Lever 	struct file *filp;
24788232ec1SChuck Lever 	struct net *net;
24888232ec1SChuck Lever 	int saved, err;
24988232ec1SChuck Lever 
25088232ec1SChuck Lever 	/* Arrange */
25188232ec1SChuck Lever 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
25288232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, req);
25388232ec1SChuck Lever 
25488232ec1SChuck Lever 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
25588232ec1SChuck Lever 			    &sock, 1);
25688232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
25718c40a1cSChuck Lever 	filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
25818c40a1cSChuck Lever 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
25988232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, sock->sk);
26018c40a1cSChuck Lever 	sock->file = filp;
26188232ec1SChuck Lever 
26288232ec1SChuck Lever 	net = sock_net(sock->sk);
26388232ec1SChuck Lever 	hn = handshake_pernet(net);
26488232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, hn);
26588232ec1SChuck Lever 
26688232ec1SChuck Lever 	saved = hn->hn_pending;
26788232ec1SChuck Lever 	hn->hn_pending = hn->hn_pending_max + 1;
26888232ec1SChuck Lever 
26988232ec1SChuck Lever 	/* Act */
27088232ec1SChuck Lever 	err = handshake_req_submit(sock, req, GFP_KERNEL);
27188232ec1SChuck Lever 
27288232ec1SChuck Lever 	/* Assert */
27388232ec1SChuck Lever 	KUNIT_EXPECT_EQ(test, err, -EAGAIN);
27488232ec1SChuck Lever 
2754a0f07d7SJinjie Ruan 	fput(filp);
27688232ec1SChuck Lever 	hn->hn_pending = saved;
27788232ec1SChuck Lever }
27888232ec1SChuck Lever 
handshake_req_submit_test6(struct kunit * test)27988232ec1SChuck Lever static void handshake_req_submit_test6(struct kunit *test)
28088232ec1SChuck Lever {
28188232ec1SChuck Lever 	struct handshake_req *req1, *req2;
28288232ec1SChuck Lever 	struct socket *sock;
28318c40a1cSChuck Lever 	struct file *filp;
28488232ec1SChuck Lever 	int err;
28588232ec1SChuck Lever 
28688232ec1SChuck Lever 	/* Arrange */
28788232ec1SChuck Lever 	req1 = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
28888232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, req1);
28988232ec1SChuck Lever 	req2 = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
29088232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, req2);
29188232ec1SChuck Lever 
29288232ec1SChuck Lever 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
29388232ec1SChuck Lever 			    &sock, 1);
29488232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
29518c40a1cSChuck Lever 	filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
29618c40a1cSChuck Lever 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
29788232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, sock->sk);
29818c40a1cSChuck Lever 	sock->file = filp;
29988232ec1SChuck Lever 
30088232ec1SChuck Lever 	/* Act */
30188232ec1SChuck Lever 	err = handshake_req_submit(sock, req1, GFP_KERNEL);
30288232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
30388232ec1SChuck Lever 	err = handshake_req_submit(sock, req2, GFP_KERNEL);
30488232ec1SChuck Lever 
30588232ec1SChuck Lever 	/* Assert */
30688232ec1SChuck Lever 	KUNIT_EXPECT_EQ(test, err, -EBUSY);
30788232ec1SChuck Lever 
30888232ec1SChuck Lever 	handshake_req_cancel(sock->sk);
3094a0f07d7SJinjie Ruan 	fput(filp);
31088232ec1SChuck Lever }
31188232ec1SChuck Lever 
handshake_req_cancel_test1(struct kunit * test)31288232ec1SChuck Lever static void handshake_req_cancel_test1(struct kunit *test)
31388232ec1SChuck Lever {
31488232ec1SChuck Lever 	struct handshake_req *req;
31588232ec1SChuck Lever 	struct socket *sock;
31618c40a1cSChuck Lever 	struct file *filp;
31788232ec1SChuck Lever 	bool result;
31888232ec1SChuck Lever 	int err;
31988232ec1SChuck Lever 
32088232ec1SChuck Lever 	/* Arrange */
32188232ec1SChuck Lever 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
32288232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, req);
32388232ec1SChuck Lever 
32488232ec1SChuck Lever 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
32588232ec1SChuck Lever 			    &sock, 1);
32688232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
32788232ec1SChuck Lever 
32818c40a1cSChuck Lever 	filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
32918c40a1cSChuck Lever 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
33018c40a1cSChuck Lever 	sock->file = filp;
33188232ec1SChuck Lever 
33288232ec1SChuck Lever 	err = handshake_req_submit(sock, req, GFP_KERNEL);
33388232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
33488232ec1SChuck Lever 
33588232ec1SChuck Lever 	/* NB: handshake_req hasn't been accepted */
33688232ec1SChuck Lever 
33788232ec1SChuck Lever 	/* Act */
33888232ec1SChuck Lever 	result = handshake_req_cancel(sock->sk);
33988232ec1SChuck Lever 
34088232ec1SChuck Lever 	/* Assert */
34188232ec1SChuck Lever 	KUNIT_EXPECT_TRUE(test, result);
34288232ec1SChuck Lever 
3434a0f07d7SJinjie Ruan 	fput(filp);
34488232ec1SChuck Lever }
34588232ec1SChuck Lever 
handshake_req_cancel_test2(struct kunit * test)34688232ec1SChuck Lever static void handshake_req_cancel_test2(struct kunit *test)
34788232ec1SChuck Lever {
34888232ec1SChuck Lever 	struct handshake_req *req, *next;
34988232ec1SChuck Lever 	struct handshake_net *hn;
35088232ec1SChuck Lever 	struct socket *sock;
35118c40a1cSChuck Lever 	struct file *filp;
35288232ec1SChuck Lever 	struct net *net;
35388232ec1SChuck Lever 	bool result;
35488232ec1SChuck Lever 	int err;
35588232ec1SChuck Lever 
35688232ec1SChuck Lever 	/* Arrange */
35788232ec1SChuck Lever 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
35888232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, req);
35988232ec1SChuck Lever 
36088232ec1SChuck Lever 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
36188232ec1SChuck Lever 			    &sock, 1);
36288232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
36388232ec1SChuck Lever 
36418c40a1cSChuck Lever 	filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
36518c40a1cSChuck Lever 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
36618c40a1cSChuck Lever 	sock->file = filp;
36788232ec1SChuck Lever 
36888232ec1SChuck Lever 	err = handshake_req_submit(sock, req, GFP_KERNEL);
36988232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
37088232ec1SChuck Lever 
37188232ec1SChuck Lever 	net = sock_net(sock->sk);
37288232ec1SChuck Lever 	hn = handshake_pernet(net);
37388232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, hn);
37488232ec1SChuck Lever 
37588232ec1SChuck Lever 	/* Pretend to accept this request */
37688232ec1SChuck Lever 	next = handshake_req_next(hn, HANDSHAKE_HANDLER_CLASS_TLSHD);
37788232ec1SChuck Lever 	KUNIT_ASSERT_PTR_EQ(test, req, next);
37888232ec1SChuck Lever 
37988232ec1SChuck Lever 	/* Act */
38088232ec1SChuck Lever 	result = handshake_req_cancel(sock->sk);
38188232ec1SChuck Lever 
38288232ec1SChuck Lever 	/* Assert */
38388232ec1SChuck Lever 	KUNIT_EXPECT_TRUE(test, result);
38488232ec1SChuck Lever 
3854a0f07d7SJinjie Ruan 	fput(filp);
38688232ec1SChuck Lever }
38788232ec1SChuck Lever 
handshake_req_cancel_test3(struct kunit * test)38888232ec1SChuck Lever static void handshake_req_cancel_test3(struct kunit *test)
38988232ec1SChuck Lever {
39088232ec1SChuck Lever 	struct handshake_req *req, *next;
39188232ec1SChuck Lever 	struct handshake_net *hn;
39288232ec1SChuck Lever 	struct socket *sock;
39318c40a1cSChuck Lever 	struct file *filp;
39488232ec1SChuck Lever 	struct net *net;
39588232ec1SChuck Lever 	bool result;
39688232ec1SChuck Lever 	int err;
39788232ec1SChuck Lever 
39888232ec1SChuck Lever 	/* Arrange */
39988232ec1SChuck Lever 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
40088232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, req);
40188232ec1SChuck Lever 
40288232ec1SChuck Lever 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
40388232ec1SChuck Lever 			    &sock, 1);
40488232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
40588232ec1SChuck Lever 
40618c40a1cSChuck Lever 	filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
40718c40a1cSChuck Lever 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
40818c40a1cSChuck Lever 	sock->file = filp;
40988232ec1SChuck Lever 
41088232ec1SChuck Lever 	err = handshake_req_submit(sock, req, GFP_KERNEL);
41188232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
41288232ec1SChuck Lever 
41388232ec1SChuck Lever 	net = sock_net(sock->sk);
41488232ec1SChuck Lever 	hn = handshake_pernet(net);
41588232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, hn);
41688232ec1SChuck Lever 
41788232ec1SChuck Lever 	/* Pretend to accept this request */
41888232ec1SChuck Lever 	next = handshake_req_next(hn, HANDSHAKE_HANDLER_CLASS_TLSHD);
41988232ec1SChuck Lever 	KUNIT_ASSERT_PTR_EQ(test, req, next);
42088232ec1SChuck Lever 
42188232ec1SChuck Lever 	/* Pretend to complete this request */
42288232ec1SChuck Lever 	handshake_complete(next, -ETIMEDOUT, NULL);
42388232ec1SChuck Lever 
42488232ec1SChuck Lever 	/* Act */
42588232ec1SChuck Lever 	result = handshake_req_cancel(sock->sk);
42688232ec1SChuck Lever 
42788232ec1SChuck Lever 	/* Assert */
42888232ec1SChuck Lever 	KUNIT_EXPECT_FALSE(test, result);
42988232ec1SChuck Lever 
4304a0f07d7SJinjie Ruan 	fput(filp);
43188232ec1SChuck Lever }
43288232ec1SChuck Lever 
43388232ec1SChuck Lever static struct handshake_req *handshake_req_destroy_test;
43488232ec1SChuck Lever 
test_destroy_func(struct handshake_req * req)43588232ec1SChuck Lever static void test_destroy_func(struct handshake_req *req)
43688232ec1SChuck Lever {
43788232ec1SChuck Lever 	handshake_req_destroy_test = req;
43888232ec1SChuck Lever }
43988232ec1SChuck Lever 
44088232ec1SChuck Lever static struct handshake_proto handshake_req_alloc_proto_destroy = {
44188232ec1SChuck Lever 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_TLSHD,
44288232ec1SChuck Lever 	.hp_accept		= test_accept_func,
44388232ec1SChuck Lever 	.hp_done		= test_done_func,
44488232ec1SChuck Lever 	.hp_destroy		= test_destroy_func,
44588232ec1SChuck Lever };
44688232ec1SChuck Lever 
handshake_req_destroy_test1(struct kunit * test)44788232ec1SChuck Lever static void handshake_req_destroy_test1(struct kunit *test)
44888232ec1SChuck Lever {
44988232ec1SChuck Lever 	struct handshake_req *req;
45088232ec1SChuck Lever 	struct socket *sock;
45118c40a1cSChuck Lever 	struct file *filp;
45288232ec1SChuck Lever 	int err;
45388232ec1SChuck Lever 
45488232ec1SChuck Lever 	/* Arrange */
45588232ec1SChuck Lever 	handshake_req_destroy_test = NULL;
45688232ec1SChuck Lever 
45788232ec1SChuck Lever 	req = handshake_req_alloc(&handshake_req_alloc_proto_destroy, GFP_KERNEL);
45888232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, req);
45988232ec1SChuck Lever 
46088232ec1SChuck Lever 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
46188232ec1SChuck Lever 			    &sock, 1);
46288232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
46388232ec1SChuck Lever 
46418c40a1cSChuck Lever 	filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
46518c40a1cSChuck Lever 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, filp);
46618c40a1cSChuck Lever 	sock->file = filp;
46788232ec1SChuck Lever 
46888232ec1SChuck Lever 	err = handshake_req_submit(sock, req, GFP_KERNEL);
46988232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
47088232ec1SChuck Lever 
47188232ec1SChuck Lever 	handshake_req_cancel(sock->sk);
47288232ec1SChuck Lever 
47388232ec1SChuck Lever 	/* Act */
474*d74226e0SChuck Lever 	/* Ensure the close/release/put process has run to
475*d74226e0SChuck Lever 	 * completion before checking the result.
476*d74226e0SChuck Lever 	 */
477*d74226e0SChuck Lever 	__fput_sync(filp);
47888232ec1SChuck Lever 
47988232ec1SChuck Lever 	/* Assert */
48088232ec1SChuck Lever 	KUNIT_EXPECT_PTR_EQ(test, handshake_req_destroy_test, req);
48188232ec1SChuck Lever }
48288232ec1SChuck Lever 
48388232ec1SChuck Lever static struct kunit_case handshake_api_test_cases[] = {
48488232ec1SChuck Lever 	{
48588232ec1SChuck Lever 		.name			= "req_alloc API fuzzing",
48688232ec1SChuck Lever 		.run_case		= handshake_req_alloc_case,
48788232ec1SChuck Lever 		.generate_params	= handshake_req_alloc_gen_params,
48888232ec1SChuck Lever 	},
48988232ec1SChuck Lever 	{
49088232ec1SChuck Lever 		.name			= "req_submit NULL req arg",
49188232ec1SChuck Lever 		.run_case		= handshake_req_submit_test1,
49288232ec1SChuck Lever 	},
49388232ec1SChuck Lever 	{
49488232ec1SChuck Lever 		.name			= "req_submit NULL sock arg",
49588232ec1SChuck Lever 		.run_case		= handshake_req_submit_test2,
49688232ec1SChuck Lever 	},
49788232ec1SChuck Lever 	{
49888232ec1SChuck Lever 		.name			= "req_submit NULL sock->file",
49988232ec1SChuck Lever 		.run_case		= handshake_req_submit_test3,
50088232ec1SChuck Lever 	},
50188232ec1SChuck Lever 	{
50288232ec1SChuck Lever 		.name			= "req_lookup works",
50388232ec1SChuck Lever 		.run_case		= handshake_req_submit_test4,
50488232ec1SChuck Lever 	},
50588232ec1SChuck Lever 	{
50688232ec1SChuck Lever 		.name			= "req_submit max pending",
50788232ec1SChuck Lever 		.run_case		= handshake_req_submit_test5,
50888232ec1SChuck Lever 	},
50988232ec1SChuck Lever 	{
51088232ec1SChuck Lever 		.name			= "req_submit multiple",
51188232ec1SChuck Lever 		.run_case		= handshake_req_submit_test6,
51288232ec1SChuck Lever 	},
51388232ec1SChuck Lever 	{
51488232ec1SChuck Lever 		.name			= "req_cancel before accept",
51588232ec1SChuck Lever 		.run_case		= handshake_req_cancel_test1,
51688232ec1SChuck Lever 	},
51788232ec1SChuck Lever 	{
51888232ec1SChuck Lever 		.name			= "req_cancel after accept",
51988232ec1SChuck Lever 		.run_case		= handshake_req_cancel_test2,
52088232ec1SChuck Lever 	},
52188232ec1SChuck Lever 	{
52288232ec1SChuck Lever 		.name			= "req_cancel after done",
52388232ec1SChuck Lever 		.run_case		= handshake_req_cancel_test3,
52488232ec1SChuck Lever 	},
52588232ec1SChuck Lever 	{
52688232ec1SChuck Lever 		.name			= "req_destroy works",
52788232ec1SChuck Lever 		.run_case		= handshake_req_destroy_test1,
52888232ec1SChuck Lever 	},
52988232ec1SChuck Lever 	{}
53088232ec1SChuck Lever };
53188232ec1SChuck Lever 
53288232ec1SChuck Lever static struct kunit_suite handshake_api_suite = {
53388232ec1SChuck Lever        .name                   = "Handshake API tests",
53488232ec1SChuck Lever        .test_cases             = handshake_api_test_cases,
53588232ec1SChuck Lever };
53688232ec1SChuck Lever 
53788232ec1SChuck Lever kunit_test_suites(&handshake_api_suite);
53888232ec1SChuck Lever 
53988232ec1SChuck Lever MODULE_DESCRIPTION("Test handshake upcall API functions");
54088232ec1SChuck Lever MODULE_LICENSE("GPL");
541