1 // SPDX-License-Identifier: GPL-2.0 2 /* Multipath TCP 3 * 4 * Copyright (c) 2022, Intel Corporation. 5 */ 6 7 #include "protocol.h" 8 9 void mptcp_free_local_addr_list(struct mptcp_sock *msk) 10 { 11 struct mptcp_pm_addr_entry *entry, *tmp; 12 struct sock *sk = (struct sock *)msk; 13 LIST_HEAD(free_list); 14 15 if (!mptcp_pm_is_userspace(msk)) 16 return; 17 18 spin_lock_bh(&msk->pm.lock); 19 list_splice_init(&msk->pm.userspace_pm_local_addr_list, &free_list); 20 spin_unlock_bh(&msk->pm.lock); 21 22 list_for_each_entry_safe(entry, tmp, &free_list, list) { 23 sock_kfree_s(sk, entry, sizeof(*entry)); 24 } 25 } 26 27 int mptcp_userspace_pm_append_new_local_addr(struct mptcp_sock *msk, 28 struct mptcp_pm_addr_entry *entry) 29 { 30 DECLARE_BITMAP(id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1); 31 struct mptcp_pm_addr_entry *match = NULL; 32 struct sock *sk = (struct sock *)msk; 33 struct mptcp_pm_addr_entry *e; 34 bool addr_match = false; 35 bool id_match = false; 36 int ret = -EINVAL; 37 38 bitmap_zero(id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1); 39 40 spin_lock_bh(&msk->pm.lock); 41 list_for_each_entry(e, &msk->pm.userspace_pm_local_addr_list, list) { 42 addr_match = mptcp_addresses_equal(&e->addr, &entry->addr, true); 43 if (addr_match && entry->addr.id == 0) 44 entry->addr.id = e->addr.id; 45 id_match = (e->addr.id == entry->addr.id); 46 if (addr_match && id_match) { 47 match = e; 48 break; 49 } else if (addr_match || id_match) { 50 break; 51 } 52 __set_bit(e->addr.id, id_bitmap); 53 } 54 55 if (!match && !addr_match && !id_match) { 56 /* Memory for the entry is allocated from the 57 * sock option buffer. 58 */ 59 e = sock_kmalloc(sk, sizeof(*e), GFP_ATOMIC); 60 if (!e) { 61 spin_unlock_bh(&msk->pm.lock); 62 return -ENOMEM; 63 } 64 65 *e = *entry; 66 if (!e->addr.id) 67 e->addr.id = find_next_zero_bit(id_bitmap, 68 MPTCP_PM_MAX_ADDR_ID + 1, 69 1); 70 list_add_tail_rcu(&e->list, &msk->pm.userspace_pm_local_addr_list); 71 ret = e->addr.id; 72 } else if (match) { 73 ret = entry->addr.id; 74 } 75 76 spin_unlock_bh(&msk->pm.lock); 77 return ret; 78 } 79 80 int mptcp_userspace_pm_get_flags_and_ifindex_by_id(struct mptcp_sock *msk, 81 unsigned int id, 82 u8 *flags, int *ifindex) 83 { 84 struct mptcp_pm_addr_entry *entry, *match = NULL; 85 86 *flags = 0; 87 *ifindex = 0; 88 89 spin_lock_bh(&msk->pm.lock); 90 list_for_each_entry(entry, &msk->pm.userspace_pm_local_addr_list, list) { 91 if (id == entry->addr.id) { 92 match = entry; 93 break; 94 } 95 } 96 spin_unlock_bh(&msk->pm.lock); 97 if (match) { 98 *flags = match->flags; 99 *ifindex = match->ifindex; 100 } 101 102 return 0; 103 } 104 105 int mptcp_userspace_pm_get_local_id(struct mptcp_sock *msk, 106 struct mptcp_addr_info *skc) 107 { 108 struct mptcp_pm_addr_entry new_entry; 109 __be16 msk_sport = ((struct inet_sock *) 110 inet_sk((struct sock *)msk))->inet_sport; 111 112 memset(&new_entry, 0, sizeof(struct mptcp_pm_addr_entry)); 113 new_entry.addr = *skc; 114 new_entry.addr.id = 0; 115 new_entry.flags = MPTCP_PM_ADDR_FLAG_IMPLICIT; 116 117 if (new_entry.addr.port == msk_sport) 118 new_entry.addr.port = 0; 119 120 return mptcp_userspace_pm_append_new_local_addr(msk, &new_entry); 121 } 122 123 int mptcp_nl_cmd_announce(struct sk_buff *skb, struct genl_info *info) 124 { 125 struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN]; 126 struct nlattr *addr = info->attrs[MPTCP_PM_ATTR_ADDR]; 127 struct mptcp_pm_addr_entry addr_val; 128 struct mptcp_sock *msk; 129 int err = -EINVAL; 130 u32 token_val; 131 132 if (!addr || !token) { 133 GENL_SET_ERR_MSG(info, "missing required inputs"); 134 return err; 135 } 136 137 token_val = nla_get_u32(token); 138 139 msk = mptcp_token_get_sock(sock_net(skb->sk), token_val); 140 if (!msk) { 141 NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token"); 142 return err; 143 } 144 145 if (!mptcp_pm_is_userspace(msk)) { 146 GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected"); 147 goto announce_err; 148 } 149 150 err = mptcp_pm_parse_entry(addr, info, true, &addr_val); 151 if (err < 0) { 152 GENL_SET_ERR_MSG(info, "error parsing local address"); 153 goto announce_err; 154 } 155 156 if (addr_val.addr.id == 0 || !(addr_val.flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) { 157 GENL_SET_ERR_MSG(info, "invalid addr id or flags"); 158 goto announce_err; 159 } 160 161 err = mptcp_userspace_pm_append_new_local_addr(msk, &addr_val); 162 if (err < 0) { 163 GENL_SET_ERR_MSG(info, "did not match address and id"); 164 goto announce_err; 165 } 166 167 lock_sock((struct sock *)msk); 168 spin_lock_bh(&msk->pm.lock); 169 170 if (mptcp_pm_alloc_anno_list(msk, &addr_val)) { 171 mptcp_pm_announce_addr(msk, &addr_val.addr, false); 172 mptcp_pm_nl_addr_send_ack(msk); 173 } 174 175 spin_unlock_bh(&msk->pm.lock); 176 release_sock((struct sock *)msk); 177 178 err = 0; 179 announce_err: 180 sock_put((struct sock *)msk); 181 return err; 182 } 183 184 int mptcp_nl_cmd_remove(struct sk_buff *skb, struct genl_info *info) 185 { 186 struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN]; 187 struct nlattr *id = info->attrs[MPTCP_PM_ATTR_LOC_ID]; 188 struct mptcp_pm_addr_entry *match = NULL; 189 struct mptcp_pm_addr_entry *entry; 190 struct mptcp_sock *msk; 191 LIST_HEAD(free_list); 192 int err = -EINVAL; 193 u32 token_val; 194 u8 id_val; 195 196 if (!id || !token) { 197 GENL_SET_ERR_MSG(info, "missing required inputs"); 198 return err; 199 } 200 201 id_val = nla_get_u8(id); 202 token_val = nla_get_u32(token); 203 204 msk = mptcp_token_get_sock(sock_net(skb->sk), token_val); 205 if (!msk) { 206 NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token"); 207 return err; 208 } 209 210 if (!mptcp_pm_is_userspace(msk)) { 211 GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected"); 212 goto remove_err; 213 } 214 215 lock_sock((struct sock *)msk); 216 217 list_for_each_entry(entry, &msk->pm.userspace_pm_local_addr_list, list) { 218 if (entry->addr.id == id_val) { 219 match = entry; 220 break; 221 } 222 } 223 224 if (!match) { 225 GENL_SET_ERR_MSG(info, "address with specified id not found"); 226 release_sock((struct sock *)msk); 227 goto remove_err; 228 } 229 230 list_move(&match->list, &free_list); 231 232 mptcp_pm_remove_addrs_and_subflows(msk, &free_list); 233 234 release_sock((struct sock *)msk); 235 236 list_for_each_entry_safe(match, entry, &free_list, list) { 237 sock_kfree_s((struct sock *)msk, match, sizeof(*match)); 238 } 239 240 err = 0; 241 remove_err: 242 sock_put((struct sock *)msk); 243 return err; 244 } 245 246 int mptcp_nl_cmd_sf_create(struct sk_buff *skb, struct genl_info *info) 247 { 248 struct nlattr *raddr = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE]; 249 struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN]; 250 struct nlattr *laddr = info->attrs[MPTCP_PM_ATTR_ADDR]; 251 struct mptcp_addr_info addr_r; 252 struct mptcp_addr_info addr_l; 253 struct mptcp_sock *msk; 254 int err = -EINVAL; 255 struct sock *sk; 256 u32 token_val; 257 258 if (!laddr || !raddr || !token) { 259 GENL_SET_ERR_MSG(info, "missing required inputs"); 260 return err; 261 } 262 263 token_val = nla_get_u32(token); 264 265 msk = mptcp_token_get_sock(genl_info_net(info), token_val); 266 if (!msk) { 267 NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token"); 268 return err; 269 } 270 271 if (!mptcp_pm_is_userspace(msk)) { 272 GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected"); 273 goto create_err; 274 } 275 276 err = mptcp_pm_parse_addr(laddr, info, &addr_l); 277 if (err < 0) { 278 NL_SET_ERR_MSG_ATTR(info->extack, laddr, "error parsing local addr"); 279 goto create_err; 280 } 281 282 if (addr_l.id == 0) { 283 NL_SET_ERR_MSG_ATTR(info->extack, laddr, "missing local addr id"); 284 goto create_err; 285 } 286 287 err = mptcp_pm_parse_addr(raddr, info, &addr_r); 288 if (err < 0) { 289 NL_SET_ERR_MSG_ATTR(info->extack, raddr, "error parsing remote addr"); 290 goto create_err; 291 } 292 293 sk = &msk->sk.icsk_inet.sk; 294 lock_sock(sk); 295 296 err = __mptcp_subflow_connect(sk, &addr_l, &addr_r); 297 298 release_sock(sk); 299 300 create_err: 301 sock_put((struct sock *)msk); 302 return err; 303 } 304 305 static struct sock *mptcp_nl_find_ssk(struct mptcp_sock *msk, 306 const struct mptcp_addr_info *local, 307 const struct mptcp_addr_info *remote) 308 { 309 struct sock *sk = &msk->sk.icsk_inet.sk; 310 struct mptcp_subflow_context *subflow; 311 struct sock *found = NULL; 312 313 if (local->family != remote->family) 314 return NULL; 315 316 lock_sock(sk); 317 318 mptcp_for_each_subflow(msk, subflow) { 319 const struct inet_sock *issk; 320 struct sock *ssk; 321 322 ssk = mptcp_subflow_tcp_sock(subflow); 323 324 if (local->family != ssk->sk_family) 325 continue; 326 327 issk = inet_sk(ssk); 328 329 switch (ssk->sk_family) { 330 case AF_INET: 331 if (issk->inet_saddr != local->addr.s_addr || 332 issk->inet_daddr != remote->addr.s_addr) 333 continue; 334 break; 335 #if IS_ENABLED(CONFIG_MPTCP_IPV6) 336 case AF_INET6: { 337 const struct ipv6_pinfo *pinfo = inet6_sk(ssk); 338 339 if (!ipv6_addr_equal(&local->addr6, &pinfo->saddr) || 340 !ipv6_addr_equal(&remote->addr6, &ssk->sk_v6_daddr)) 341 continue; 342 break; 343 } 344 #endif 345 default: 346 continue; 347 } 348 349 if (issk->inet_sport == local->port && 350 issk->inet_dport == remote->port) { 351 found = ssk; 352 goto found; 353 } 354 } 355 356 found: 357 release_sock(sk); 358 359 return found; 360 } 361 362 int mptcp_nl_cmd_sf_destroy(struct sk_buff *skb, struct genl_info *info) 363 { 364 struct nlattr *raddr = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE]; 365 struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN]; 366 struct nlattr *laddr = info->attrs[MPTCP_PM_ATTR_ADDR]; 367 struct mptcp_addr_info addr_l; 368 struct mptcp_addr_info addr_r; 369 struct mptcp_sock *msk; 370 struct sock *sk, *ssk; 371 int err = -EINVAL; 372 u32 token_val; 373 374 if (!laddr || !raddr || !token) { 375 GENL_SET_ERR_MSG(info, "missing required inputs"); 376 return err; 377 } 378 379 token_val = nla_get_u32(token); 380 381 msk = mptcp_token_get_sock(genl_info_net(info), token_val); 382 if (!msk) { 383 NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token"); 384 return err; 385 } 386 387 if (!mptcp_pm_is_userspace(msk)) { 388 GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected"); 389 goto destroy_err; 390 } 391 392 err = mptcp_pm_parse_addr(laddr, info, &addr_l); 393 if (err < 0) { 394 NL_SET_ERR_MSG_ATTR(info->extack, laddr, "error parsing local addr"); 395 goto destroy_err; 396 } 397 398 err = mptcp_pm_parse_addr(raddr, info, &addr_r); 399 if (err < 0) { 400 NL_SET_ERR_MSG_ATTR(info->extack, raddr, "error parsing remote addr"); 401 goto destroy_err; 402 } 403 404 if (addr_l.family != addr_r.family) { 405 GENL_SET_ERR_MSG(info, "address families do not match"); 406 goto destroy_err; 407 } 408 409 if (!addr_l.port || !addr_r.port) { 410 GENL_SET_ERR_MSG(info, "missing local or remote port"); 411 goto destroy_err; 412 } 413 414 sk = &msk->sk.icsk_inet.sk; 415 ssk = mptcp_nl_find_ssk(msk, &addr_l, &addr_r); 416 if (ssk) { 417 struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk); 418 419 mptcp_subflow_shutdown(sk, ssk, RCV_SHUTDOWN | SEND_SHUTDOWN); 420 mptcp_close_ssk(sk, ssk, subflow); 421 err = 0; 422 } else { 423 err = -ESRCH; 424 } 425 426 destroy_err: 427 sock_put((struct sock *)msk); 428 return err; 429 } 430