1*88232ec1SChuck Lever // SPDX-License-Identifier: GPL-2.0
2*88232ec1SChuck Lever /*
3*88232ec1SChuck Lever  * Copyright (c) 2023 Oracle and/or its affiliates.
4*88232ec1SChuck Lever  *
5*88232ec1SChuck Lever  * KUnit test of the handshake upcall mechanism.
6*88232ec1SChuck Lever  */
7*88232ec1SChuck Lever 
#include <kunit/test.h>
#include <kunit/visibility.h>

#include <linux/file.h>
#include <linux/kernel.h>

#include <net/sock.h>
#include <net/genetlink.h>
#include <net/netns/generic.h>

#include <uapi/linux/handshake.h>
#include "handshake.h"
19*88232ec1SChuck Lever 
20*88232ec1SChuck Lever MODULE_IMPORT_NS(EXPORTED_FOR_KUNIT_TESTING);
21*88232ec1SChuck Lever 
/*
 * Stub ->hp_accept callback for the test protocols below. It ignores
 * @req, @info and @fd and always reports success.
 */
static int test_accept_func(struct handshake_req *req, struct genl_info *info,
			    int fd)
{
	/* Nothing to hand over in a unit test */
	return 0;
}
27*88232ec1SChuck Lever 
/*
 * Stub ->hp_done callback for the test protocols below. It ignores
 * every argument.
 */
static void test_done_func(struct handshake_req *req, unsigned int status,
			   struct genl_info *info)
{
	/* Intentionally empty */
}
32*88232ec1SChuck Lever 
33*88232ec1SChuck Lever struct handshake_req_alloc_test_param {
34*88232ec1SChuck Lever 	const char			*desc;
35*88232ec1SChuck Lever 	struct handshake_proto		*proto;
36*88232ec1SChuck Lever 	gfp_t				gfp;
37*88232ec1SChuck Lever 	bool				expect_success;
38*88232ec1SChuck Lever };
39*88232ec1SChuck Lever 
40*88232ec1SChuck Lever static struct handshake_proto handshake_req_alloc_proto_2 = {
41*88232ec1SChuck Lever 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_NONE,
42*88232ec1SChuck Lever };
43*88232ec1SChuck Lever 
44*88232ec1SChuck Lever static struct handshake_proto handshake_req_alloc_proto_3 = {
45*88232ec1SChuck Lever 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_MAX,
46*88232ec1SChuck Lever };
47*88232ec1SChuck Lever 
48*88232ec1SChuck Lever static struct handshake_proto handshake_req_alloc_proto_4 = {
49*88232ec1SChuck Lever 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_TLSHD,
50*88232ec1SChuck Lever };
51*88232ec1SChuck Lever 
52*88232ec1SChuck Lever static struct handshake_proto handshake_req_alloc_proto_5 = {
53*88232ec1SChuck Lever 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_TLSHD,
54*88232ec1SChuck Lever 	.hp_accept		= test_accept_func,
55*88232ec1SChuck Lever };
56*88232ec1SChuck Lever 
57*88232ec1SChuck Lever static struct handshake_proto handshake_req_alloc_proto_6 = {
58*88232ec1SChuck Lever 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_TLSHD,
59*88232ec1SChuck Lever 	.hp_privsize		= UINT_MAX,
60*88232ec1SChuck Lever 	.hp_accept		= test_accept_func,
61*88232ec1SChuck Lever 	.hp_done		= test_done_func,
62*88232ec1SChuck Lever };
63*88232ec1SChuck Lever 
64*88232ec1SChuck Lever static struct handshake_proto handshake_req_alloc_proto_good = {
65*88232ec1SChuck Lever 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_TLSHD,
66*88232ec1SChuck Lever 	.hp_accept		= test_accept_func,
67*88232ec1SChuck Lever 	.hp_done		= test_done_func,
68*88232ec1SChuck Lever };
69*88232ec1SChuck Lever 
70*88232ec1SChuck Lever static const
71*88232ec1SChuck Lever struct handshake_req_alloc_test_param handshake_req_alloc_params[] = {
72*88232ec1SChuck Lever 	{
73*88232ec1SChuck Lever 		.desc			= "handshake_req_alloc NULL proto",
74*88232ec1SChuck Lever 		.proto			= NULL,
75*88232ec1SChuck Lever 		.gfp			= GFP_KERNEL,
76*88232ec1SChuck Lever 		.expect_success		= false,
77*88232ec1SChuck Lever 	},
78*88232ec1SChuck Lever 	{
79*88232ec1SChuck Lever 		.desc			= "handshake_req_alloc CLASS_NONE",
80*88232ec1SChuck Lever 		.proto			= &handshake_req_alloc_proto_2,
81*88232ec1SChuck Lever 		.gfp			= GFP_KERNEL,
82*88232ec1SChuck Lever 		.expect_success		= false,
83*88232ec1SChuck Lever 	},
84*88232ec1SChuck Lever 	{
85*88232ec1SChuck Lever 		.desc			= "handshake_req_alloc CLASS_MAX",
86*88232ec1SChuck Lever 		.proto			= &handshake_req_alloc_proto_3,
87*88232ec1SChuck Lever 		.gfp			= GFP_KERNEL,
88*88232ec1SChuck Lever 		.expect_success		= false,
89*88232ec1SChuck Lever 	},
90*88232ec1SChuck Lever 	{
91*88232ec1SChuck Lever 		.desc			= "handshake_req_alloc no callbacks",
92*88232ec1SChuck Lever 		.proto			= &handshake_req_alloc_proto_4,
93*88232ec1SChuck Lever 		.gfp			= GFP_KERNEL,
94*88232ec1SChuck Lever 		.expect_success		= false,
95*88232ec1SChuck Lever 	},
96*88232ec1SChuck Lever 	{
97*88232ec1SChuck Lever 		.desc			= "handshake_req_alloc no done callback",
98*88232ec1SChuck Lever 		.proto			= &handshake_req_alloc_proto_5,
99*88232ec1SChuck Lever 		.gfp			= GFP_KERNEL,
100*88232ec1SChuck Lever 		.expect_success		= false,
101*88232ec1SChuck Lever 	},
102*88232ec1SChuck Lever 	{
103*88232ec1SChuck Lever 		.desc			= "handshake_req_alloc excessive privsize",
104*88232ec1SChuck Lever 		.proto			= &handshake_req_alloc_proto_6,
105*88232ec1SChuck Lever 		.gfp			= GFP_KERNEL,
106*88232ec1SChuck Lever 		.expect_success		= false,
107*88232ec1SChuck Lever 	},
108*88232ec1SChuck Lever 	{
109*88232ec1SChuck Lever 		.desc			= "handshake_req_alloc all good",
110*88232ec1SChuck Lever 		.proto			= &handshake_req_alloc_proto_good,
111*88232ec1SChuck Lever 		.gfp			= GFP_KERNEL,
112*88232ec1SChuck Lever 		.expect_success		= true,
113*88232ec1SChuck Lever 	},
114*88232ec1SChuck Lever };
115*88232ec1SChuck Lever 
116*88232ec1SChuck Lever static void
117*88232ec1SChuck Lever handshake_req_alloc_get_desc(const struct handshake_req_alloc_test_param *param,
118*88232ec1SChuck Lever 			     char *desc)
119*88232ec1SChuck Lever {
120*88232ec1SChuck Lever 	strscpy(desc, param->desc, KUNIT_PARAM_DESC_SIZE);
121*88232ec1SChuck Lever }
122*88232ec1SChuck Lever 
123*88232ec1SChuck Lever /* Creates the function handshake_req_alloc_gen_params */
124*88232ec1SChuck Lever KUNIT_ARRAY_PARAM(handshake_req_alloc, handshake_req_alloc_params,
125*88232ec1SChuck Lever 		  handshake_req_alloc_get_desc);
126*88232ec1SChuck Lever 
127*88232ec1SChuck Lever static void handshake_req_alloc_case(struct kunit *test)
128*88232ec1SChuck Lever {
129*88232ec1SChuck Lever 	const struct handshake_req_alloc_test_param *param = test->param_value;
130*88232ec1SChuck Lever 	struct handshake_req *result;
131*88232ec1SChuck Lever 
132*88232ec1SChuck Lever 	/* Arrange */
133*88232ec1SChuck Lever 
134*88232ec1SChuck Lever 	/* Act */
135*88232ec1SChuck Lever 	result = handshake_req_alloc(param->proto, param->gfp);
136*88232ec1SChuck Lever 
137*88232ec1SChuck Lever 	/* Assert */
138*88232ec1SChuck Lever 	if (param->expect_success)
139*88232ec1SChuck Lever 		KUNIT_EXPECT_NOT_NULL(test, result);
140*88232ec1SChuck Lever 	else
141*88232ec1SChuck Lever 		KUNIT_EXPECT_NULL(test, result);
142*88232ec1SChuck Lever 
143*88232ec1SChuck Lever 	kfree(result);
144*88232ec1SChuck Lever }
145*88232ec1SChuck Lever 
146*88232ec1SChuck Lever static void handshake_req_submit_test1(struct kunit *test)
147*88232ec1SChuck Lever {
148*88232ec1SChuck Lever 	struct socket *sock;
149*88232ec1SChuck Lever 	int err, result;
150*88232ec1SChuck Lever 
151*88232ec1SChuck Lever 	/* Arrange */
152*88232ec1SChuck Lever 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
153*88232ec1SChuck Lever 			    &sock, 1);
154*88232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
155*88232ec1SChuck Lever 
156*88232ec1SChuck Lever 	/* Act */
157*88232ec1SChuck Lever 	result = handshake_req_submit(sock, NULL, GFP_KERNEL);
158*88232ec1SChuck Lever 
159*88232ec1SChuck Lever 	/* Assert */
160*88232ec1SChuck Lever 	KUNIT_EXPECT_EQ(test, result, -EINVAL);
161*88232ec1SChuck Lever 
162*88232ec1SChuck Lever 	sock_release(sock);
163*88232ec1SChuck Lever }
164*88232ec1SChuck Lever 
165*88232ec1SChuck Lever static void handshake_req_submit_test2(struct kunit *test)
166*88232ec1SChuck Lever {
167*88232ec1SChuck Lever 	struct handshake_req *req;
168*88232ec1SChuck Lever 	int result;
169*88232ec1SChuck Lever 
170*88232ec1SChuck Lever 	/* Arrange */
171*88232ec1SChuck Lever 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
172*88232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, req);
173*88232ec1SChuck Lever 
174*88232ec1SChuck Lever 	/* Act */
175*88232ec1SChuck Lever 	result = handshake_req_submit(NULL, req, GFP_KERNEL);
176*88232ec1SChuck Lever 
177*88232ec1SChuck Lever 	/* Assert */
178*88232ec1SChuck Lever 	KUNIT_EXPECT_EQ(test, result, -EINVAL);
179*88232ec1SChuck Lever 
180*88232ec1SChuck Lever 	/* handshake_req_submit() destroys @req on error */
181*88232ec1SChuck Lever }
182*88232ec1SChuck Lever 
183*88232ec1SChuck Lever static void handshake_req_submit_test3(struct kunit *test)
184*88232ec1SChuck Lever {
185*88232ec1SChuck Lever 	struct handshake_req *req;
186*88232ec1SChuck Lever 	struct socket *sock;
187*88232ec1SChuck Lever 	int err, result;
188*88232ec1SChuck Lever 
189*88232ec1SChuck Lever 	/* Arrange */
190*88232ec1SChuck Lever 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
191*88232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, req);
192*88232ec1SChuck Lever 
193*88232ec1SChuck Lever 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
194*88232ec1SChuck Lever 			    &sock, 1);
195*88232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
196*88232ec1SChuck Lever 	sock->file = NULL;
197*88232ec1SChuck Lever 
198*88232ec1SChuck Lever 	/* Act */
199*88232ec1SChuck Lever 	result = handshake_req_submit(sock, req, GFP_KERNEL);
200*88232ec1SChuck Lever 
201*88232ec1SChuck Lever 	/* Assert */
202*88232ec1SChuck Lever 	KUNIT_EXPECT_EQ(test, result, -EINVAL);
203*88232ec1SChuck Lever 
204*88232ec1SChuck Lever 	/* handshake_req_submit() destroys @req on error */
205*88232ec1SChuck Lever 	sock_release(sock);
206*88232ec1SChuck Lever }
207*88232ec1SChuck Lever 
208*88232ec1SChuck Lever static void handshake_req_submit_test4(struct kunit *test)
209*88232ec1SChuck Lever {
210*88232ec1SChuck Lever 	struct handshake_req *req, *result;
211*88232ec1SChuck Lever 	struct socket *sock;
212*88232ec1SChuck Lever 	int err;
213*88232ec1SChuck Lever 
214*88232ec1SChuck Lever 	/* Arrange */
215*88232ec1SChuck Lever 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
216*88232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, req);
217*88232ec1SChuck Lever 
218*88232ec1SChuck Lever 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
219*88232ec1SChuck Lever 			    &sock, 1);
220*88232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
221*88232ec1SChuck Lever 	sock->file = sock_alloc_file(sock, O_NONBLOCK, NULL);
222*88232ec1SChuck Lever 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, sock->file);
223*88232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, sock->sk);
224*88232ec1SChuck Lever 
225*88232ec1SChuck Lever 	err = handshake_req_submit(sock, req, GFP_KERNEL);
226*88232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
227*88232ec1SChuck Lever 
228*88232ec1SChuck Lever 	/* Act */
229*88232ec1SChuck Lever 	result = handshake_req_hash_lookup(sock->sk);
230*88232ec1SChuck Lever 
231*88232ec1SChuck Lever 	/* Assert */
232*88232ec1SChuck Lever 	KUNIT_EXPECT_NOT_NULL(test, result);
233*88232ec1SChuck Lever 	KUNIT_EXPECT_PTR_EQ(test, req, result);
234*88232ec1SChuck Lever 
235*88232ec1SChuck Lever 	handshake_req_cancel(sock->sk);
236*88232ec1SChuck Lever 	sock_release(sock);
237*88232ec1SChuck Lever }
238*88232ec1SChuck Lever 
239*88232ec1SChuck Lever static void handshake_req_submit_test5(struct kunit *test)
240*88232ec1SChuck Lever {
241*88232ec1SChuck Lever 	struct handshake_req *req;
242*88232ec1SChuck Lever 	struct handshake_net *hn;
243*88232ec1SChuck Lever 	struct socket *sock;
244*88232ec1SChuck Lever 	struct net *net;
245*88232ec1SChuck Lever 	int saved, err;
246*88232ec1SChuck Lever 
247*88232ec1SChuck Lever 	/* Arrange */
248*88232ec1SChuck Lever 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
249*88232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, req);
250*88232ec1SChuck Lever 
251*88232ec1SChuck Lever 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
252*88232ec1SChuck Lever 			    &sock, 1);
253*88232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
254*88232ec1SChuck Lever 	sock->file = sock_alloc_file(sock, O_NONBLOCK, NULL);
255*88232ec1SChuck Lever 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, sock->file);
256*88232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, sock->sk);
257*88232ec1SChuck Lever 
258*88232ec1SChuck Lever 	net = sock_net(sock->sk);
259*88232ec1SChuck Lever 	hn = handshake_pernet(net);
260*88232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, hn);
261*88232ec1SChuck Lever 
262*88232ec1SChuck Lever 	saved = hn->hn_pending;
263*88232ec1SChuck Lever 	hn->hn_pending = hn->hn_pending_max + 1;
264*88232ec1SChuck Lever 
265*88232ec1SChuck Lever 	/* Act */
266*88232ec1SChuck Lever 	err = handshake_req_submit(sock, req, GFP_KERNEL);
267*88232ec1SChuck Lever 
268*88232ec1SChuck Lever 	/* Assert */
269*88232ec1SChuck Lever 	KUNIT_EXPECT_EQ(test, err, -EAGAIN);
270*88232ec1SChuck Lever 
271*88232ec1SChuck Lever 	sock_release(sock);
272*88232ec1SChuck Lever 	hn->hn_pending = saved;
273*88232ec1SChuck Lever }
274*88232ec1SChuck Lever 
275*88232ec1SChuck Lever static void handshake_req_submit_test6(struct kunit *test)
276*88232ec1SChuck Lever {
277*88232ec1SChuck Lever 	struct handshake_req *req1, *req2;
278*88232ec1SChuck Lever 	struct socket *sock;
279*88232ec1SChuck Lever 	int err;
280*88232ec1SChuck Lever 
281*88232ec1SChuck Lever 	/* Arrange */
282*88232ec1SChuck Lever 	req1 = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
283*88232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, req1);
284*88232ec1SChuck Lever 	req2 = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
285*88232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, req2);
286*88232ec1SChuck Lever 
287*88232ec1SChuck Lever 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
288*88232ec1SChuck Lever 			    &sock, 1);
289*88232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
290*88232ec1SChuck Lever 	sock->file = sock_alloc_file(sock, O_NONBLOCK, NULL);
291*88232ec1SChuck Lever 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, sock->file);
292*88232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, sock->sk);
293*88232ec1SChuck Lever 
294*88232ec1SChuck Lever 	/* Act */
295*88232ec1SChuck Lever 	err = handshake_req_submit(sock, req1, GFP_KERNEL);
296*88232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
297*88232ec1SChuck Lever 	err = handshake_req_submit(sock, req2, GFP_KERNEL);
298*88232ec1SChuck Lever 
299*88232ec1SChuck Lever 	/* Assert */
300*88232ec1SChuck Lever 	KUNIT_EXPECT_EQ(test, err, -EBUSY);
301*88232ec1SChuck Lever 
302*88232ec1SChuck Lever 	handshake_req_cancel(sock->sk);
303*88232ec1SChuck Lever 	sock_release(sock);
304*88232ec1SChuck Lever }
305*88232ec1SChuck Lever 
306*88232ec1SChuck Lever static void handshake_req_cancel_test1(struct kunit *test)
307*88232ec1SChuck Lever {
308*88232ec1SChuck Lever 	struct handshake_req *req;
309*88232ec1SChuck Lever 	struct socket *sock;
310*88232ec1SChuck Lever 	bool result;
311*88232ec1SChuck Lever 	int err;
312*88232ec1SChuck Lever 
313*88232ec1SChuck Lever 	/* Arrange */
314*88232ec1SChuck Lever 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
315*88232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, req);
316*88232ec1SChuck Lever 
317*88232ec1SChuck Lever 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
318*88232ec1SChuck Lever 			    &sock, 1);
319*88232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
320*88232ec1SChuck Lever 
321*88232ec1SChuck Lever 	sock->file = sock_alloc_file(sock, O_NONBLOCK, NULL);
322*88232ec1SChuck Lever 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, sock->file);
323*88232ec1SChuck Lever 
324*88232ec1SChuck Lever 	err = handshake_req_submit(sock, req, GFP_KERNEL);
325*88232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
326*88232ec1SChuck Lever 
327*88232ec1SChuck Lever 	/* NB: handshake_req hasn't been accepted */
328*88232ec1SChuck Lever 
329*88232ec1SChuck Lever 	/* Act */
330*88232ec1SChuck Lever 	result = handshake_req_cancel(sock->sk);
331*88232ec1SChuck Lever 
332*88232ec1SChuck Lever 	/* Assert */
333*88232ec1SChuck Lever 	KUNIT_EXPECT_TRUE(test, result);
334*88232ec1SChuck Lever 
335*88232ec1SChuck Lever 	sock_release(sock);
336*88232ec1SChuck Lever }
337*88232ec1SChuck Lever 
338*88232ec1SChuck Lever static void handshake_req_cancel_test2(struct kunit *test)
339*88232ec1SChuck Lever {
340*88232ec1SChuck Lever 	struct handshake_req *req, *next;
341*88232ec1SChuck Lever 	struct handshake_net *hn;
342*88232ec1SChuck Lever 	struct socket *sock;
343*88232ec1SChuck Lever 	struct net *net;
344*88232ec1SChuck Lever 	bool result;
345*88232ec1SChuck Lever 	int err;
346*88232ec1SChuck Lever 
347*88232ec1SChuck Lever 	/* Arrange */
348*88232ec1SChuck Lever 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
349*88232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, req);
350*88232ec1SChuck Lever 
351*88232ec1SChuck Lever 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
352*88232ec1SChuck Lever 			    &sock, 1);
353*88232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
354*88232ec1SChuck Lever 
355*88232ec1SChuck Lever 	sock->file = sock_alloc_file(sock, O_NONBLOCK, NULL);
356*88232ec1SChuck Lever 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, sock->file);
357*88232ec1SChuck Lever 
358*88232ec1SChuck Lever 	err = handshake_req_submit(sock, req, GFP_KERNEL);
359*88232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
360*88232ec1SChuck Lever 
361*88232ec1SChuck Lever 	net = sock_net(sock->sk);
362*88232ec1SChuck Lever 	hn = handshake_pernet(net);
363*88232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, hn);
364*88232ec1SChuck Lever 
365*88232ec1SChuck Lever 	/* Pretend to accept this request */
366*88232ec1SChuck Lever 	next = handshake_req_next(hn, HANDSHAKE_HANDLER_CLASS_TLSHD);
367*88232ec1SChuck Lever 	KUNIT_ASSERT_PTR_EQ(test, req, next);
368*88232ec1SChuck Lever 
369*88232ec1SChuck Lever 	/* Act */
370*88232ec1SChuck Lever 	result = handshake_req_cancel(sock->sk);
371*88232ec1SChuck Lever 
372*88232ec1SChuck Lever 	/* Assert */
373*88232ec1SChuck Lever 	KUNIT_EXPECT_TRUE(test, result);
374*88232ec1SChuck Lever 
375*88232ec1SChuck Lever 	sock_release(sock);
376*88232ec1SChuck Lever }
377*88232ec1SChuck Lever 
378*88232ec1SChuck Lever static void handshake_req_cancel_test3(struct kunit *test)
379*88232ec1SChuck Lever {
380*88232ec1SChuck Lever 	struct handshake_req *req, *next;
381*88232ec1SChuck Lever 	struct handshake_net *hn;
382*88232ec1SChuck Lever 	struct socket *sock;
383*88232ec1SChuck Lever 	struct net *net;
384*88232ec1SChuck Lever 	bool result;
385*88232ec1SChuck Lever 	int err;
386*88232ec1SChuck Lever 
387*88232ec1SChuck Lever 	/* Arrange */
388*88232ec1SChuck Lever 	req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
389*88232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, req);
390*88232ec1SChuck Lever 
391*88232ec1SChuck Lever 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
392*88232ec1SChuck Lever 			    &sock, 1);
393*88232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
394*88232ec1SChuck Lever 
395*88232ec1SChuck Lever 	sock->file = sock_alloc_file(sock, O_NONBLOCK, NULL);
396*88232ec1SChuck Lever 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, sock->file);
397*88232ec1SChuck Lever 
398*88232ec1SChuck Lever 	err = handshake_req_submit(sock, req, GFP_KERNEL);
399*88232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
400*88232ec1SChuck Lever 
401*88232ec1SChuck Lever 	net = sock_net(sock->sk);
402*88232ec1SChuck Lever 	hn = handshake_pernet(net);
403*88232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, hn);
404*88232ec1SChuck Lever 
405*88232ec1SChuck Lever 	/* Pretend to accept this request */
406*88232ec1SChuck Lever 	next = handshake_req_next(hn, HANDSHAKE_HANDLER_CLASS_TLSHD);
407*88232ec1SChuck Lever 	KUNIT_ASSERT_PTR_EQ(test, req, next);
408*88232ec1SChuck Lever 
409*88232ec1SChuck Lever 	/* Pretend to complete this request */
410*88232ec1SChuck Lever 	handshake_complete(next, -ETIMEDOUT, NULL);
411*88232ec1SChuck Lever 
412*88232ec1SChuck Lever 	/* Act */
413*88232ec1SChuck Lever 	result = handshake_req_cancel(sock->sk);
414*88232ec1SChuck Lever 
415*88232ec1SChuck Lever 	/* Assert */
416*88232ec1SChuck Lever 	KUNIT_EXPECT_FALSE(test, result);
417*88232ec1SChuck Lever 
418*88232ec1SChuck Lever 	sock_release(sock);
419*88232ec1SChuck Lever }
420*88232ec1SChuck Lever 
421*88232ec1SChuck Lever static struct handshake_req *handshake_req_destroy_test;
422*88232ec1SChuck Lever 
423*88232ec1SChuck Lever static void test_destroy_func(struct handshake_req *req)
424*88232ec1SChuck Lever {
425*88232ec1SChuck Lever 	handshake_req_destroy_test = req;
426*88232ec1SChuck Lever }
427*88232ec1SChuck Lever 
428*88232ec1SChuck Lever static struct handshake_proto handshake_req_alloc_proto_destroy = {
429*88232ec1SChuck Lever 	.hp_handler_class	= HANDSHAKE_HANDLER_CLASS_TLSHD,
430*88232ec1SChuck Lever 	.hp_accept		= test_accept_func,
431*88232ec1SChuck Lever 	.hp_done		= test_done_func,
432*88232ec1SChuck Lever 	.hp_destroy		= test_destroy_func,
433*88232ec1SChuck Lever };
434*88232ec1SChuck Lever 
435*88232ec1SChuck Lever static void handshake_req_destroy_test1(struct kunit *test)
436*88232ec1SChuck Lever {
437*88232ec1SChuck Lever 	struct handshake_req *req;
438*88232ec1SChuck Lever 	struct socket *sock;
439*88232ec1SChuck Lever 	int err;
440*88232ec1SChuck Lever 
441*88232ec1SChuck Lever 	/* Arrange */
442*88232ec1SChuck Lever 	handshake_req_destroy_test = NULL;
443*88232ec1SChuck Lever 
444*88232ec1SChuck Lever 	req = handshake_req_alloc(&handshake_req_alloc_proto_destroy, GFP_KERNEL);
445*88232ec1SChuck Lever 	KUNIT_ASSERT_NOT_NULL(test, req);
446*88232ec1SChuck Lever 
447*88232ec1SChuck Lever 	err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
448*88232ec1SChuck Lever 			    &sock, 1);
449*88232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
450*88232ec1SChuck Lever 
451*88232ec1SChuck Lever 	sock->file = sock_alloc_file(sock, O_NONBLOCK, NULL);
452*88232ec1SChuck Lever 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, sock->file);
453*88232ec1SChuck Lever 
454*88232ec1SChuck Lever 	err = handshake_req_submit(sock, req, GFP_KERNEL);
455*88232ec1SChuck Lever 	KUNIT_ASSERT_EQ(test, err, 0);
456*88232ec1SChuck Lever 
457*88232ec1SChuck Lever 	handshake_req_cancel(sock->sk);
458*88232ec1SChuck Lever 
459*88232ec1SChuck Lever 	/* Act */
460*88232ec1SChuck Lever 	sock_release(sock);
461*88232ec1SChuck Lever 
462*88232ec1SChuck Lever 	/* Assert */
463*88232ec1SChuck Lever 	KUNIT_EXPECT_PTR_EQ(test, handshake_req_destroy_test, req);
464*88232ec1SChuck Lever }
465*88232ec1SChuck Lever 
466*88232ec1SChuck Lever static struct kunit_case handshake_api_test_cases[] = {
467*88232ec1SChuck Lever 	{
468*88232ec1SChuck Lever 		.name			= "req_alloc API fuzzing",
469*88232ec1SChuck Lever 		.run_case		= handshake_req_alloc_case,
470*88232ec1SChuck Lever 		.generate_params	= handshake_req_alloc_gen_params,
471*88232ec1SChuck Lever 	},
472*88232ec1SChuck Lever 	{
473*88232ec1SChuck Lever 		.name			= "req_submit NULL req arg",
474*88232ec1SChuck Lever 		.run_case		= handshake_req_submit_test1,
475*88232ec1SChuck Lever 	},
476*88232ec1SChuck Lever 	{
477*88232ec1SChuck Lever 		.name			= "req_submit NULL sock arg",
478*88232ec1SChuck Lever 		.run_case		= handshake_req_submit_test2,
479*88232ec1SChuck Lever 	},
480*88232ec1SChuck Lever 	{
481*88232ec1SChuck Lever 		.name			= "req_submit NULL sock->file",
482*88232ec1SChuck Lever 		.run_case		= handshake_req_submit_test3,
483*88232ec1SChuck Lever 	},
484*88232ec1SChuck Lever 	{
485*88232ec1SChuck Lever 		.name			= "req_lookup works",
486*88232ec1SChuck Lever 		.run_case		= handshake_req_submit_test4,
487*88232ec1SChuck Lever 	},
488*88232ec1SChuck Lever 	{
489*88232ec1SChuck Lever 		.name			= "req_submit max pending",
490*88232ec1SChuck Lever 		.run_case		= handshake_req_submit_test5,
491*88232ec1SChuck Lever 	},
492*88232ec1SChuck Lever 	{
493*88232ec1SChuck Lever 		.name			= "req_submit multiple",
494*88232ec1SChuck Lever 		.run_case		= handshake_req_submit_test6,
495*88232ec1SChuck Lever 	},
496*88232ec1SChuck Lever 	{
497*88232ec1SChuck Lever 		.name			= "req_cancel before accept",
498*88232ec1SChuck Lever 		.run_case		= handshake_req_cancel_test1,
499*88232ec1SChuck Lever 	},
500*88232ec1SChuck Lever 	{
501*88232ec1SChuck Lever 		.name			= "req_cancel after accept",
502*88232ec1SChuck Lever 		.run_case		= handshake_req_cancel_test2,
503*88232ec1SChuck Lever 	},
504*88232ec1SChuck Lever 	{
505*88232ec1SChuck Lever 		.name			= "req_cancel after done",
506*88232ec1SChuck Lever 		.run_case		= handshake_req_cancel_test3,
507*88232ec1SChuck Lever 	},
508*88232ec1SChuck Lever 	{
509*88232ec1SChuck Lever 		.name			= "req_destroy works",
510*88232ec1SChuck Lever 		.run_case		= handshake_req_destroy_test1,
511*88232ec1SChuck Lever 	},
512*88232ec1SChuck Lever 	{}
513*88232ec1SChuck Lever };
514*88232ec1SChuck Lever 
515*88232ec1SChuck Lever static struct kunit_suite handshake_api_suite = {
516*88232ec1SChuck Lever        .name                   = "Handshake API tests",
517*88232ec1SChuck Lever        .test_cases             = handshake_api_test_cases,
518*88232ec1SChuck Lever };
519*88232ec1SChuck Lever 
520*88232ec1SChuck Lever kunit_test_suites(&handshake_api_suite);
521*88232ec1SChuck Lever 
522*88232ec1SChuck Lever MODULE_DESCRIPTION("Test handshake upcall API functions");
523*88232ec1SChuck Lever MODULE_LICENSE("GPL");
524