1 // SPDX-License-Identifier: GPL-2.0 2 #include <kunit/test.h> 3 4 #include "protocol.h" 5 6 static struct mptcp_subflow_request_sock *build_req_sock(struct kunit *test) 7 { 8 struct mptcp_subflow_request_sock *req; 9 10 req = kunit_kzalloc(test, sizeof(struct mptcp_subflow_request_sock), 11 GFP_USER); 12 KUNIT_EXPECT_NOT_ERR_OR_NULL(test, req); 13 mptcp_token_init_request((struct request_sock *)req); 14 return req; 15 } 16 17 static void mptcp_token_test_req_basic(struct kunit *test) 18 { 19 struct mptcp_subflow_request_sock *req = build_req_sock(test); 20 struct mptcp_sock *null_msk = NULL; 21 22 KUNIT_ASSERT_EQ(test, 0, 23 mptcp_token_new_request((struct request_sock *)req)); 24 KUNIT_EXPECT_NE(test, 0, (int)req->token); 25 KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(req->token)); 26 27 /* cleanup */ 28 mptcp_token_destroy_request((struct request_sock *)req); 29 } 30 31 static struct inet_connection_sock *build_icsk(struct kunit *test) 32 { 33 struct inet_connection_sock *icsk; 34 35 icsk = kunit_kzalloc(test, sizeof(struct inet_connection_sock), 36 GFP_USER); 37 KUNIT_EXPECT_NOT_ERR_OR_NULL(test, icsk); 38 return icsk; 39 } 40 41 static struct mptcp_subflow_context *build_ctx(struct kunit *test) 42 { 43 struct mptcp_subflow_context *ctx; 44 45 ctx = kunit_kzalloc(test, sizeof(struct mptcp_subflow_context), 46 GFP_USER); 47 KUNIT_EXPECT_NOT_ERR_OR_NULL(test, ctx); 48 return ctx; 49 } 50 51 static struct mptcp_sock *build_msk(struct kunit *test) 52 { 53 struct mptcp_sock *msk; 54 55 msk = kunit_kzalloc(test, sizeof(struct mptcp_sock), GFP_USER); 56 KUNIT_EXPECT_NOT_ERR_OR_NULL(test, msk); 57 refcount_set(&((struct sock *)msk)->sk_refcnt, 1); 58 return msk; 59 } 60 61 static void mptcp_token_test_msk_basic(struct kunit *test) 62 { 63 struct inet_connection_sock *icsk = build_icsk(test); 64 struct mptcp_subflow_context *ctx = build_ctx(test); 65 struct mptcp_sock *msk = build_msk(test); 66 struct mptcp_sock *null_msk = NULL; 67 struct sock *sk; 68 69 rcu_assign_pointer(icsk->icsk_ulp_data, ctx); 70 ctx->conn = (struct sock *)msk; 71 sk = (struct sock *)msk; 72 73 KUNIT_ASSERT_EQ(test, 0, 74 mptcp_token_new_connect((struct sock *)icsk)); 75 KUNIT_EXPECT_NE(test, 0, (int)ctx->token); 76 KUNIT_EXPECT_EQ(test, ctx->token, msk->token); 77 KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(ctx->token)); 78 KUNIT_EXPECT_EQ(test, 2, (int)refcount_read(&sk->sk_refcnt)); 79 80 mptcp_token_destroy(msk); 81 KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(ctx->token)); 82 } 83 84 static void mptcp_token_test_accept(struct kunit *test) 85 { 86 struct mptcp_subflow_request_sock *req = build_req_sock(test); 87 struct mptcp_sock *msk = build_msk(test); 88 89 KUNIT_ASSERT_EQ(test, 0, 90 mptcp_token_new_request((struct request_sock *)req)); 91 msk->token = req->token; 92 mptcp_token_accept(req, msk); 93 KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(msk->token)); 94 95 /* this is now a no-op */ 96 mptcp_token_destroy_request((struct request_sock *)req); 97 KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(msk->token)); 98 99 /* cleanup */ 100 mptcp_token_destroy(msk); 101 } 102 103 static void mptcp_token_test_destroyed(struct kunit *test) 104 { 105 struct mptcp_subflow_request_sock *req = build_req_sock(test); 106 struct mptcp_sock *msk = build_msk(test); 107 struct mptcp_sock *null_msk = NULL; 108 struct sock *sk; 109 110 sk = (struct sock *)msk; 111 112 KUNIT_ASSERT_EQ(test, 0, 113 mptcp_token_new_request((struct request_sock *)req)); 114 msk->token = req->token; 115 mptcp_token_accept(req, msk); 116 117 /* simulate race on removal */ 118 refcount_set(&sk->sk_refcnt, 0); 119 KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(msk->token)); 120 121 /* cleanup */ 122 mptcp_token_destroy(msk); 123 } 124 125 static struct kunit_case mptcp_token_test_cases[] = { 126 KUNIT_CASE(mptcp_token_test_req_basic), 127 KUNIT_CASE(mptcp_token_test_msk_basic), 128 KUNIT_CASE(mptcp_token_test_accept), 129 KUNIT_CASE(mptcp_token_test_destroyed), 130 {} 131 }; 132 133 static struct kunit_suite mptcp_token_suite = { 134 .name = "mptcp-token", 135 .test_cases = mptcp_token_test_cases, 136 }; 137 138 kunit_test_suite(mptcp_token_suite); 139 140 MODULE_LICENSE("GPL"); 141