1 // SPDX-License-Identifier: GPL-2.0 2 /* Multipath TCP token management 3 * Copyright (c) 2017 - 2019, Intel Corporation. 4 * 5 * Note: This code is based on mptcp_ctrl.c from multipath-tcp.org, 6 * authored by: 7 * 8 * Sébastien Barré <sebastien.barre@uclouvain.be> 9 * Christoph Paasch <christoph.paasch@uclouvain.be> 10 * Jaakko Korkeaniemi <jaakko.korkeaniemi@aalto.fi> 11 * Gregory Detal <gregory.detal@uclouvain.be> 12 * Fabien Duchêne <fabien.duchene@uclouvain.be> 13 * Andreas Seelinger <Andreas.Seelinger@rwth-aachen.de> 14 * Lavkesh Lahngir <lavkesh51@gmail.com> 15 * Andreas Ripke <ripke@neclab.eu> 16 * Vlad Dogaru <vlad.dogaru@intel.com> 17 * Octavian Purdila <octavian.purdila@intel.com> 18 * John Ronan <jronan@tssg.org> 19 * Catalin Nicutar <catalin.nicutar@gmail.com> 20 * Brandon Heller <brandonh@stanford.edu> 21 */ 22 23 #define pr_fmt(fmt) "MPTCP: " fmt 24 25 #include <linux/kernel.h> 26 #include <linux/module.h> 27 #include <linux/memblock.h> 28 #include <linux/ip.h> 29 #include <linux/tcp.h> 30 #include <net/sock.h> 31 #include <net/inet_common.h> 32 #include <net/protocol.h> 33 #include <net/mptcp.h> 34 #include "protocol.h" 35 36 #define TOKEN_MAX_RETRIES 4 37 #define TOKEN_MAX_CHAIN_LEN 4 38 39 struct token_bucket { 40 spinlock_t lock; 41 int chain_len; 42 struct hlist_nulls_head req_chain; 43 struct hlist_nulls_head msk_chain; 44 }; 45 46 static struct token_bucket *token_hash __read_mostly; 47 static unsigned int token_mask __read_mostly; 48 49 static struct token_bucket *token_bucket(u32 token) 50 { 51 return &token_hash[token & token_mask]; 52 } 53 54 /* called with bucket lock held */ 55 static struct mptcp_subflow_request_sock * 56 __token_lookup_req(struct token_bucket *t, u32 token) 57 { 58 struct mptcp_subflow_request_sock *req; 59 struct hlist_nulls_node *pos; 60 61 hlist_nulls_for_each_entry_rcu(req, pos, &t->req_chain, token_node) 62 if (req->token == token) 63 return req; 64 return NULL; 65 } 66 67 /* called with bucket lock held */ 68 static struct mptcp_sock * 69 __token_lookup_msk(struct token_bucket *t, u32 token) 70 { 71 struct hlist_nulls_node *pos; 72 struct sock *sk; 73 74 sk_nulls_for_each_rcu(sk, pos, &t->msk_chain) 75 if (mptcp_sk(sk)->token == token) 76 return mptcp_sk(sk); 77 return NULL; 78 } 79 80 static bool __token_bucket_busy(struct token_bucket *t, u32 token) 81 { 82 return !token || t->chain_len >= TOKEN_MAX_CHAIN_LEN || 83 __token_lookup_req(t, token) || __token_lookup_msk(t, token); 84 } 85 86 static void mptcp_crypto_key_gen_sha(u64 *key, u32 *token, u64 *idsn) 87 { 88 /* we might consider a faster version that computes the key as a 89 * hash of some information available in the MPTCP socket. Use 90 * random data at the moment, as it's probably the safest option 91 * in case multiple sockets are opened in different namespaces at 92 * the same time. 93 */ 94 get_random_bytes(key, sizeof(u64)); 95 mptcp_crypto_key_sha(*key, token, idsn); 96 } 97 98 /** 99 * mptcp_token_new_request - create new key/idsn/token for subflow_request 100 * @req: the request socket 101 * 102 * This function is called when a new mptcp connection is coming in. 103 * 104 * It creates a unique token to identify the new mptcp connection, 105 * a secret local key and the initial data sequence number (idsn). 106 * 107 * Returns 0 on success. 108 */ 109 int mptcp_token_new_request(struct request_sock *req) 110 { 111 struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req); 112 int retries = TOKEN_MAX_RETRIES; 113 struct token_bucket *bucket; 114 u32 token; 115 116 again: 117 mptcp_crypto_key_gen_sha(&subflow_req->local_key, 118 &subflow_req->token, 119 &subflow_req->idsn); 120 pr_debug("req=%p local_key=%llu, token=%u, idsn=%llu\n", 121 req, subflow_req->local_key, subflow_req->token, 122 subflow_req->idsn); 123 124 token = subflow_req->token; 125 bucket = token_bucket(token); 126 spin_lock_bh(&bucket->lock); 127 if (__token_bucket_busy(bucket, token)) { 128 spin_unlock_bh(&bucket->lock); 129 if (!--retries) 130 return -EBUSY; 131 goto again; 132 } 133 134 hlist_nulls_add_head_rcu(&subflow_req->token_node, &bucket->req_chain); 135 bucket->chain_len++; 136 spin_unlock_bh(&bucket->lock); 137 return 0; 138 } 139 140 /** 141 * mptcp_token_new_connect - create new key/idsn/token for subflow 142 * @sk: the socket that will initiate a connection 143 * 144 * This function is called when a new outgoing mptcp connection is 145 * initiated. 146 * 147 * It creates a unique token to identify the new mptcp connection, 148 * a secret local key and the initial data sequence number (idsn). 149 * 150 * On success, the mptcp connection can be found again using 151 * the computed token at a later time, this is needed to process 152 * join requests. 153 * 154 * returns 0 on success. 155 */ 156 int mptcp_token_new_connect(struct sock *sk) 157 { 158 struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); 159 struct mptcp_sock *msk = mptcp_sk(subflow->conn); 160 int retries = TOKEN_MAX_RETRIES; 161 struct token_bucket *bucket; 162 163 pr_debug("ssk=%p, local_key=%llu, token=%u, idsn=%llu\n", 164 sk, subflow->local_key, subflow->token, subflow->idsn); 165 166 again: 167 mptcp_crypto_key_gen_sha(&subflow->local_key, &subflow->token, 168 &subflow->idsn); 169 170 bucket = token_bucket(subflow->token); 171 spin_lock_bh(&bucket->lock); 172 if (__token_bucket_busy(bucket, subflow->token)) { 173 spin_unlock_bh(&bucket->lock); 174 if (!--retries) 175 return -EBUSY; 176 goto again; 177 } 178 179 WRITE_ONCE(msk->token, subflow->token); 180 __sk_nulls_add_node_rcu((struct sock *)msk, &bucket->msk_chain); 181 bucket->chain_len++; 182 spin_unlock_bh(&bucket->lock); 183 return 0; 184 } 185 186 /** 187 * mptcp_token_accept - replace a req sk with full sock in token hash 188 * @req: the request socket to be removed 189 * @msk: the just cloned socket linked to the new connection 190 * 191 * Called when a SYN packet creates a new logical connection, i.e. 192 * is not a join request. 193 */ 194 void mptcp_token_accept(struct mptcp_subflow_request_sock *req, 195 struct mptcp_sock *msk) 196 { 197 struct mptcp_subflow_request_sock *pos; 198 struct token_bucket *bucket; 199 200 bucket = token_bucket(req->token); 201 spin_lock_bh(&bucket->lock); 202 203 /* pedantic lookup check for the moved token */ 204 pos = __token_lookup_req(bucket, req->token); 205 if (!WARN_ON_ONCE(pos != req)) 206 hlist_nulls_del_init_rcu(&req->token_node); 207 __sk_nulls_add_node_rcu((struct sock *)msk, &bucket->msk_chain); 208 spin_unlock_bh(&bucket->lock); 209 } 210 211 /** 212 * mptcp_token_get_sock - retrieve mptcp connection sock using its token 213 * @token: token of the mptcp connection to retrieve 214 * 215 * This function returns the mptcp connection structure with the given token. 216 * A reference count on the mptcp socket returned is taken. 217 * 218 * returns NULL if no connection with the given token value exists. 219 */ 220 struct mptcp_sock *mptcp_token_get_sock(u32 token) 221 { 222 struct hlist_nulls_node *pos; 223 struct token_bucket *bucket; 224 struct mptcp_sock *msk; 225 struct sock *sk; 226 227 rcu_read_lock(); 228 bucket = token_bucket(token); 229 230 again: 231 sk_nulls_for_each_rcu(sk, pos, &bucket->msk_chain) { 232 msk = mptcp_sk(sk); 233 if (READ_ONCE(msk->token) != token) 234 continue; 235 if (!refcount_inc_not_zero(&sk->sk_refcnt)) 236 goto not_found; 237 if (READ_ONCE(msk->token) != token) { 238 sock_put(sk); 239 goto again; 240 } 241 goto found; 242 } 243 if (get_nulls_value(pos) != (token & token_mask)) 244 goto again; 245 246 not_found: 247 msk = NULL; 248 249 found: 250 rcu_read_unlock(); 251 return msk; 252 } 253 EXPORT_SYMBOL_GPL(mptcp_token_get_sock); 254 255 /** 256 * mptcp_token_iter_next - iterate over the token container from given pos 257 * @net: namespace to be iterated 258 * @s_slot: start slot number 259 * @s_num: start number inside the given lock 260 * 261 * This function returns the first mptcp connection structure found inside the 262 * token container starting from the specified position, or NULL. 263 * 264 * On successful iteration, the iterator is move to the next position and the 265 * the acquires a reference to the returned socket. 266 */ 267 struct mptcp_sock *mptcp_token_iter_next(const struct net *net, long *s_slot, 268 long *s_num) 269 { 270 struct mptcp_sock *ret = NULL; 271 struct hlist_nulls_node *pos; 272 int slot, num; 273 274 for (slot = *s_slot; slot <= token_mask; *s_num = 0, slot++) { 275 struct token_bucket *bucket = &token_hash[slot]; 276 struct sock *sk; 277 278 num = 0; 279 280 if (hlist_nulls_empty(&bucket->msk_chain)) 281 continue; 282 283 rcu_read_lock(); 284 sk_nulls_for_each_rcu(sk, pos, &bucket->msk_chain) { 285 ++num; 286 if (!net_eq(sock_net(sk), net)) 287 continue; 288 289 if (num <= *s_num) 290 continue; 291 292 if (!refcount_inc_not_zero(&sk->sk_refcnt)) 293 continue; 294 295 if (!net_eq(sock_net(sk), net)) { 296 sock_put(sk); 297 continue; 298 } 299 300 ret = mptcp_sk(sk); 301 rcu_read_unlock(); 302 goto out; 303 } 304 rcu_read_unlock(); 305 } 306 307 out: 308 *s_slot = slot; 309 *s_num = num; 310 return ret; 311 } 312 EXPORT_SYMBOL_GPL(mptcp_token_iter_next); 313 314 /** 315 * mptcp_token_destroy_request - remove mptcp connection/token 316 * @req: mptcp request socket dropping the token 317 * 318 * Remove the token associated to @req. 319 */ 320 void mptcp_token_destroy_request(struct request_sock *req) 321 { 322 struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req); 323 struct mptcp_subflow_request_sock *pos; 324 struct token_bucket *bucket; 325 326 if (hlist_nulls_unhashed(&subflow_req->token_node)) 327 return; 328 329 bucket = token_bucket(subflow_req->token); 330 spin_lock_bh(&bucket->lock); 331 pos = __token_lookup_req(bucket, subflow_req->token); 332 if (!WARN_ON_ONCE(pos != subflow_req)) { 333 hlist_nulls_del_init_rcu(&pos->token_node); 334 bucket->chain_len--; 335 } 336 spin_unlock_bh(&bucket->lock); 337 } 338 339 /** 340 * mptcp_token_destroy - remove mptcp connection/token 341 * @msk: mptcp connection dropping the token 342 * 343 * Remove the token associated to @msk 344 */ 345 void mptcp_token_destroy(struct mptcp_sock *msk) 346 { 347 struct token_bucket *bucket; 348 struct mptcp_sock *pos; 349 350 if (sk_unhashed((struct sock *)msk)) 351 return; 352 353 bucket = token_bucket(msk->token); 354 spin_lock_bh(&bucket->lock); 355 pos = __token_lookup_msk(bucket, msk->token); 356 if (!WARN_ON_ONCE(pos != msk)) { 357 __sk_nulls_del_node_init_rcu((struct sock *)pos); 358 bucket->chain_len--; 359 } 360 spin_unlock_bh(&bucket->lock); 361 } 362 363 void __init mptcp_token_init(void) 364 { 365 int i; 366 367 token_hash = alloc_large_system_hash("MPTCP token", 368 sizeof(struct token_bucket), 369 0, 370 20,/* one slot per 1MB of memory */ 371 HASH_ZERO, 372 NULL, 373 &token_mask, 374 0, 375 64 * 1024); 376 for (i = 0; i < token_mask + 1; ++i) { 377 INIT_HLIST_NULLS_HEAD(&token_hash[i].req_chain, i); 378 INIT_HLIST_NULLS_HEAD(&token_hash[i].msk_chain, i); 379 spin_lock_init(&token_hash[i].lock); 380 } 381 } 382 383 #if IS_MODULE(CONFIG_MPTCP_KUNIT_TESTS) 384 EXPORT_SYMBOL_GPL(mptcp_token_new_request); 385 EXPORT_SYMBOL_GPL(mptcp_token_new_connect); 386 EXPORT_SYMBOL_GPL(mptcp_token_accept); 387 EXPORT_SYMBOL_GPL(mptcp_token_destroy_request); 388 EXPORT_SYMBOL_GPL(mptcp_token_destroy); 389 #endif 390