1 // SPDX-License-Identifier: GPL-2.0-or-later 2 /* 3 * Copyright (C) 2016 Namjae Jeon <linkinjeon@kernel.org> 4 * Copyright (C) 2018 Samsung Electronics Co., Ltd. 5 */ 6 7 #include <linux/freezer.h> 8 9 #include "smb_common.h" 10 #include "server.h" 11 #include "auth.h" 12 #include "connection.h" 13 #include "transport_tcp.h" 14 15 #define IFACE_STATE_DOWN BIT(0) 16 #define IFACE_STATE_CONFIGURED BIT(1) 17 18 static atomic_t active_num_conn; 19 20 struct interface { 21 struct task_struct *ksmbd_kthread; 22 struct socket *ksmbd_socket; 23 struct list_head entry; 24 char *name; 25 struct mutex sock_release_lock; 26 int state; 27 }; 28 29 static LIST_HEAD(iface_list); 30 31 static int bind_additional_ifaces; 32 33 struct tcp_transport { 34 struct ksmbd_transport transport; 35 struct socket *sock; 36 struct kvec *iov; 37 unsigned int nr_iov; 38 }; 39 40 static struct ksmbd_transport_ops ksmbd_tcp_transport_ops; 41 42 static void tcp_stop_kthread(struct task_struct *kthread); 43 static struct interface *alloc_iface(char *ifname); 44 45 #define KSMBD_TRANS(t) (&(t)->transport) 46 #define TCP_TRANS(t) ((struct tcp_transport *)container_of(t, \ 47 struct tcp_transport, transport)) 48 49 static inline void ksmbd_tcp_nodelay(struct socket *sock) 50 { 51 tcp_sock_set_nodelay(sock->sk); 52 } 53 54 static inline void ksmbd_tcp_reuseaddr(struct socket *sock) 55 { 56 sock_set_reuseaddr(sock->sk); 57 } 58 59 static inline void ksmbd_tcp_rcv_timeout(struct socket *sock, s64 secs) 60 { 61 lock_sock(sock->sk); 62 if (secs && secs < MAX_SCHEDULE_TIMEOUT / HZ - 1) 63 sock->sk->sk_rcvtimeo = secs * HZ; 64 else 65 sock->sk->sk_rcvtimeo = MAX_SCHEDULE_TIMEOUT; 66 release_sock(sock->sk); 67 } 68 69 static inline void ksmbd_tcp_snd_timeout(struct socket *sock, s64 secs) 70 { 71 sock_set_sndtimeo(sock->sk, secs); 72 } 73 74 static struct tcp_transport *alloc_transport(struct socket *client_sk) 75 { 76 struct tcp_transport *t; 77 struct ksmbd_conn *conn; 78 79 t = kzalloc(sizeof(*t), GFP_KERNEL); 80 if (!t) 81 return NULL; 82 t->sock = client_sk; 83 84 conn = ksmbd_conn_alloc(); 85 if (!conn) { 86 kfree(t); 87 return NULL; 88 } 89 90 conn->transport = KSMBD_TRANS(t); 91 KSMBD_TRANS(t)->conn = conn; 92 KSMBD_TRANS(t)->ops = &ksmbd_tcp_transport_ops; 93 return t; 94 } 95 96 static void free_transport(struct tcp_transport *t) 97 { 98 kernel_sock_shutdown(t->sock, SHUT_RDWR); 99 sock_release(t->sock); 100 t->sock = NULL; 101 102 ksmbd_conn_free(KSMBD_TRANS(t)->conn); 103 kfree(t->iov); 104 kfree(t); 105 } 106 107 /** 108 * kvec_array_init() - initialize a IO vector segment 109 * @new: IO vector to be initialized 110 * @iov: base IO vector 111 * @nr_segs: number of segments in base iov 112 * @bytes: total iovec length so far for read 113 * 114 * Return: Number of IO segments 115 */ 116 static unsigned int kvec_array_init(struct kvec *new, struct kvec *iov, 117 unsigned int nr_segs, size_t bytes) 118 { 119 size_t base = 0; 120 121 while (bytes || !iov->iov_len) { 122 int copy = min(bytes, iov->iov_len); 123 124 bytes -= copy; 125 base += copy; 126 if (iov->iov_len == base) { 127 iov++; 128 nr_segs--; 129 base = 0; 130 } 131 } 132 133 memcpy(new, iov, sizeof(*iov) * nr_segs); 134 new->iov_base += base; 135 new->iov_len -= base; 136 return nr_segs; 137 } 138 139 /** 140 * get_conn_iovec() - get connection iovec for reading from socket 141 * @t: TCP transport instance 142 * @nr_segs: number of segments in iov 143 * 144 * Return: return existing or newly allocate iovec 145 */ 146 static struct kvec *get_conn_iovec(struct tcp_transport *t, unsigned int nr_segs) 147 { 148 struct kvec *new_iov; 149 150 if (t->iov && nr_segs <= t->nr_iov) 151 return t->iov; 152 153 /* not big enough -- allocate a new one and release the old */ 154 new_iov = kmalloc_array(nr_segs, sizeof(*new_iov), GFP_KERNEL); 155 if (new_iov) { 156 kfree(t->iov); 157 t->iov = new_iov; 158 t->nr_iov = nr_segs; 159 } 160 return new_iov; 161 } 162 163 static unsigned short ksmbd_tcp_get_port(const struct sockaddr *sa) 164 { 165 switch (sa->sa_family) { 166 case AF_INET: 167 return ntohs(((struct sockaddr_in *)sa)->sin_port); 168 case AF_INET6: 169 return ntohs(((struct sockaddr_in6 *)sa)->sin6_port); 170 } 171 return 0; 172 } 173 174 /** 175 * ksmbd_tcp_new_connection() - create a new tcp session on mount 176 * @client_sk: socket associated with new connection 177 * 178 * whenever a new connection is requested, create a conn thread 179 * (session thread) to handle new incoming smb requests from the connection 180 * 181 * Return: 0 on success, otherwise error 182 */ 183 static int ksmbd_tcp_new_connection(struct socket *client_sk) 184 { 185 struct sockaddr *csin; 186 int rc = 0; 187 struct tcp_transport *t; 188 189 t = alloc_transport(client_sk); 190 if (!t) { 191 sock_release(client_sk); 192 return -ENOMEM; 193 } 194 195 csin = KSMBD_TCP_PEER_SOCKADDR(KSMBD_TRANS(t)->conn); 196 if (kernel_getpeername(client_sk, csin) < 0) { 197 pr_err("client ip resolution failed\n"); 198 rc = -EINVAL; 199 goto out_error; 200 } 201 202 KSMBD_TRANS(t)->handler = kthread_run(ksmbd_conn_handler_loop, 203 KSMBD_TRANS(t)->conn, 204 "ksmbd:%u", 205 ksmbd_tcp_get_port(csin)); 206 if (IS_ERR(KSMBD_TRANS(t)->handler)) { 207 pr_err("cannot start conn thread\n"); 208 rc = PTR_ERR(KSMBD_TRANS(t)->handler); 209 free_transport(t); 210 } 211 return rc; 212 213 out_error: 214 free_transport(t); 215 return rc; 216 } 217 218 /** 219 * ksmbd_kthread_fn() - listen to new SMB connections and callback server 220 * @p: arguments to forker thread 221 * 222 * Return: 0 on success, error number otherwise 223 */ 224 static int ksmbd_kthread_fn(void *p) 225 { 226 struct socket *client_sk = NULL; 227 struct interface *iface = (struct interface *)p; 228 int ret; 229 230 while (!kthread_should_stop()) { 231 mutex_lock(&iface->sock_release_lock); 232 if (!iface->ksmbd_socket) { 233 mutex_unlock(&iface->sock_release_lock); 234 break; 235 } 236 ret = kernel_accept(iface->ksmbd_socket, &client_sk, 237 SOCK_NONBLOCK); 238 mutex_unlock(&iface->sock_release_lock); 239 if (ret) { 240 if (ret == -EAGAIN) 241 /* check for new connections every 100 msecs */ 242 schedule_timeout_interruptible(HZ / 10); 243 continue; 244 } 245 246 if (server_conf.max_connections && 247 atomic_inc_return(&active_num_conn) >= server_conf.max_connections) { 248 pr_info_ratelimited("Limit the maximum number of connections(%u)\n", 249 atomic_read(&active_num_conn)); 250 atomic_dec(&active_num_conn); 251 sock_release(client_sk); 252 continue; 253 } 254 255 ksmbd_debug(CONN, "connect success: accepted new connection\n"); 256 client_sk->sk->sk_rcvtimeo = KSMBD_TCP_RECV_TIMEOUT; 257 client_sk->sk->sk_sndtimeo = KSMBD_TCP_SEND_TIMEOUT; 258 259 ksmbd_tcp_new_connection(client_sk); 260 } 261 262 ksmbd_debug(CONN, "releasing socket\n"); 263 return 0; 264 } 265 266 /** 267 * ksmbd_tcp_run_kthread() - start forker thread 268 * @iface: pointer to struct interface 269 * 270 * start forker thread(ksmbd/0) at module init time to listen 271 * on port 445 for new SMB connection requests. It creates per connection 272 * server threads(ksmbd/x) 273 * 274 * Return: 0 on success or error number 275 */ 276 static int ksmbd_tcp_run_kthread(struct interface *iface) 277 { 278 int rc; 279 struct task_struct *kthread; 280 281 kthread = kthread_run(ksmbd_kthread_fn, (void *)iface, "ksmbd-%s", 282 iface->name); 283 if (IS_ERR(kthread)) { 284 rc = PTR_ERR(kthread); 285 return rc; 286 } 287 iface->ksmbd_kthread = kthread; 288 289 return 0; 290 } 291 292 /** 293 * ksmbd_tcp_readv() - read data from socket in given iovec 294 * @t: TCP transport instance 295 * @iov_orig: base IO vector 296 * @nr_segs: number of segments in base iov 297 * @to_read: number of bytes to read from socket 298 * @max_retries: maximum retry count 299 * 300 * Return: on success return number of bytes read from socket, 301 * otherwise return error number 302 */ 303 static int ksmbd_tcp_readv(struct tcp_transport *t, struct kvec *iov_orig, 304 unsigned int nr_segs, unsigned int to_read, 305 int max_retries) 306 { 307 int length = 0; 308 int total_read; 309 unsigned int segs; 310 struct msghdr ksmbd_msg; 311 struct kvec *iov; 312 struct ksmbd_conn *conn = KSMBD_TRANS(t)->conn; 313 314 iov = get_conn_iovec(t, nr_segs); 315 if (!iov) 316 return -ENOMEM; 317 318 ksmbd_msg.msg_control = NULL; 319 ksmbd_msg.msg_controllen = 0; 320 321 for (total_read = 0; to_read; total_read += length, to_read -= length) { 322 try_to_freeze(); 323 324 if (!ksmbd_conn_alive(conn)) { 325 total_read = -ESHUTDOWN; 326 break; 327 } 328 segs = kvec_array_init(iov, iov_orig, nr_segs, total_read); 329 330 length = kernel_recvmsg(t->sock, &ksmbd_msg, 331 iov, segs, to_read, 0); 332 333 if (length == -EINTR) { 334 total_read = -ESHUTDOWN; 335 break; 336 } else if (ksmbd_conn_need_reconnect(conn)) { 337 total_read = -EAGAIN; 338 break; 339 } else if (length == -ERESTARTSYS || length == -EAGAIN) { 340 /* 341 * If max_retries is negative, Allow unlimited 342 * retries to keep connection with inactive sessions. 343 */ 344 if (max_retries == 0) { 345 total_read = length; 346 break; 347 } else if (max_retries > 0) { 348 max_retries--; 349 } 350 351 usleep_range(1000, 2000); 352 length = 0; 353 continue; 354 } else if (length <= 0) { 355 total_read = length; 356 break; 357 } 358 } 359 return total_read; 360 } 361 362 /** 363 * ksmbd_tcp_read() - read data from socket in given buffer 364 * @t: TCP transport instance 365 * @buf: buffer to store read data from socket 366 * @to_read: number of bytes to read from socket 367 * 368 * Return: on success return number of bytes read from socket, 369 * otherwise return error number 370 */ 371 static int ksmbd_tcp_read(struct ksmbd_transport *t, char *buf, 372 unsigned int to_read, int max_retries) 373 { 374 struct kvec iov; 375 376 iov.iov_base = buf; 377 iov.iov_len = to_read; 378 379 return ksmbd_tcp_readv(TCP_TRANS(t), &iov, 1, to_read, max_retries); 380 } 381 382 static int ksmbd_tcp_writev(struct ksmbd_transport *t, struct kvec *iov, 383 int nvecs, int size, bool need_invalidate, 384 unsigned int remote_key) 385 386 { 387 struct msghdr smb_msg = {.msg_flags = MSG_NOSIGNAL}; 388 389 return kernel_sendmsg(TCP_TRANS(t)->sock, &smb_msg, iov, nvecs, size); 390 } 391 392 static void ksmbd_tcp_disconnect(struct ksmbd_transport *t) 393 { 394 free_transport(TCP_TRANS(t)); 395 if (server_conf.max_connections) 396 atomic_dec(&active_num_conn); 397 } 398 399 static void tcp_destroy_socket(struct socket *ksmbd_socket) 400 { 401 int ret; 402 403 if (!ksmbd_socket) 404 return; 405 406 /* set zero to timeout */ 407 ksmbd_tcp_rcv_timeout(ksmbd_socket, 0); 408 ksmbd_tcp_snd_timeout(ksmbd_socket, 0); 409 410 ret = kernel_sock_shutdown(ksmbd_socket, SHUT_RDWR); 411 if (ret) 412 pr_err("Failed to shutdown socket: %d\n", ret); 413 sock_release(ksmbd_socket); 414 } 415 416 /** 417 * create_socket - create socket for ksmbd/0 418 * 419 * Return: 0 on success, error number otherwise 420 */ 421 static int create_socket(struct interface *iface) 422 { 423 int ret; 424 struct sockaddr_in6 sin6; 425 struct sockaddr_in sin; 426 struct socket *ksmbd_socket; 427 bool ipv4 = false; 428 429 ret = sock_create(PF_INET6, SOCK_STREAM, IPPROTO_TCP, &ksmbd_socket); 430 if (ret) { 431 if (ret != -EAFNOSUPPORT) 432 pr_err("Can't create socket for ipv6, fallback to ipv4: %d\n", ret); 433 ret = sock_create(PF_INET, SOCK_STREAM, IPPROTO_TCP, 434 &ksmbd_socket); 435 if (ret) { 436 pr_err("Can't create socket for ipv4: %d\n", ret); 437 goto out_clear; 438 } 439 440 sin.sin_family = PF_INET; 441 sin.sin_addr.s_addr = htonl(INADDR_ANY); 442 sin.sin_port = htons(server_conf.tcp_port); 443 ipv4 = true; 444 } else { 445 sin6.sin6_family = PF_INET6; 446 sin6.sin6_addr = in6addr_any; 447 sin6.sin6_port = htons(server_conf.tcp_port); 448 } 449 450 ksmbd_tcp_nodelay(ksmbd_socket); 451 ksmbd_tcp_reuseaddr(ksmbd_socket); 452 453 ret = sock_setsockopt(ksmbd_socket, 454 SOL_SOCKET, 455 SO_BINDTODEVICE, 456 KERNEL_SOCKPTR(iface->name), 457 strlen(iface->name)); 458 if (ret != -ENODEV && ret < 0) { 459 pr_err("Failed to set SO_BINDTODEVICE: %d\n", ret); 460 goto out_error; 461 } 462 463 if (ipv4) 464 ret = kernel_bind(ksmbd_socket, (struct sockaddr *)&sin, 465 sizeof(sin)); 466 else 467 ret = kernel_bind(ksmbd_socket, (struct sockaddr *)&sin6, 468 sizeof(sin6)); 469 if (ret) { 470 pr_err("Failed to bind socket: %d\n", ret); 471 goto out_error; 472 } 473 474 ksmbd_socket->sk->sk_rcvtimeo = KSMBD_TCP_RECV_TIMEOUT; 475 ksmbd_socket->sk->sk_sndtimeo = KSMBD_TCP_SEND_TIMEOUT; 476 477 ret = kernel_listen(ksmbd_socket, KSMBD_SOCKET_BACKLOG); 478 if (ret) { 479 pr_err("Port listen() error: %d\n", ret); 480 goto out_error; 481 } 482 483 iface->ksmbd_socket = ksmbd_socket; 484 ret = ksmbd_tcp_run_kthread(iface); 485 if (ret) { 486 pr_err("Can't start ksmbd main kthread: %d\n", ret); 487 goto out_error; 488 } 489 iface->state = IFACE_STATE_CONFIGURED; 490 491 return 0; 492 493 out_error: 494 tcp_destroy_socket(ksmbd_socket); 495 out_clear: 496 iface->ksmbd_socket = NULL; 497 return ret; 498 } 499 500 static int ksmbd_netdev_event(struct notifier_block *nb, unsigned long event, 501 void *ptr) 502 { 503 struct net_device *netdev = netdev_notifier_info_to_dev(ptr); 504 struct interface *iface; 505 int ret, found = 0; 506 507 switch (event) { 508 case NETDEV_UP: 509 if (netif_is_bridge_port(netdev)) 510 return NOTIFY_OK; 511 512 list_for_each_entry(iface, &iface_list, entry) { 513 if (!strcmp(iface->name, netdev->name)) { 514 found = 1; 515 if (iface->state != IFACE_STATE_DOWN) 516 break; 517 ret = create_socket(iface); 518 if (ret) 519 return NOTIFY_OK; 520 break; 521 } 522 } 523 if (!found && bind_additional_ifaces) { 524 iface = alloc_iface(kstrdup(netdev->name, GFP_KERNEL)); 525 if (!iface) 526 return NOTIFY_OK; 527 ret = create_socket(iface); 528 if (ret) 529 break; 530 } 531 break; 532 case NETDEV_DOWN: 533 list_for_each_entry(iface, &iface_list, entry) { 534 if (!strcmp(iface->name, netdev->name) && 535 iface->state == IFACE_STATE_CONFIGURED) { 536 tcp_stop_kthread(iface->ksmbd_kthread); 537 iface->ksmbd_kthread = NULL; 538 mutex_lock(&iface->sock_release_lock); 539 tcp_destroy_socket(iface->ksmbd_socket); 540 iface->ksmbd_socket = NULL; 541 mutex_unlock(&iface->sock_release_lock); 542 543 iface->state = IFACE_STATE_DOWN; 544 break; 545 } 546 } 547 break; 548 } 549 550 return NOTIFY_DONE; 551 } 552 553 static struct notifier_block ksmbd_netdev_notifier = { 554 .notifier_call = ksmbd_netdev_event, 555 }; 556 557 int ksmbd_tcp_init(void) 558 { 559 register_netdevice_notifier(&ksmbd_netdev_notifier); 560 561 return 0; 562 } 563 564 static void tcp_stop_kthread(struct task_struct *kthread) 565 { 566 int ret; 567 568 if (!kthread) 569 return; 570 571 ret = kthread_stop(kthread); 572 if (ret) 573 pr_err("failed to stop forker thread\n"); 574 } 575 576 void ksmbd_tcp_destroy(void) 577 { 578 struct interface *iface, *tmp; 579 580 unregister_netdevice_notifier(&ksmbd_netdev_notifier); 581 582 list_for_each_entry_safe(iface, tmp, &iface_list, entry) { 583 list_del(&iface->entry); 584 kfree(iface->name); 585 kfree(iface); 586 } 587 } 588 589 static struct interface *alloc_iface(char *ifname) 590 { 591 struct interface *iface; 592 593 if (!ifname) 594 return NULL; 595 596 iface = kzalloc(sizeof(struct interface), GFP_KERNEL); 597 if (!iface) { 598 kfree(ifname); 599 return NULL; 600 } 601 602 iface->name = ifname; 603 iface->state = IFACE_STATE_DOWN; 604 list_add(&iface->entry, &iface_list); 605 mutex_init(&iface->sock_release_lock); 606 return iface; 607 } 608 609 int ksmbd_tcp_set_interfaces(char *ifc_list, int ifc_list_sz) 610 { 611 int sz = 0; 612 613 if (!ifc_list_sz) { 614 struct net_device *netdev; 615 616 rtnl_lock(); 617 for_each_netdev(&init_net, netdev) { 618 if (netif_is_bridge_port(netdev)) 619 continue; 620 if (!alloc_iface(kstrdup(netdev->name, GFP_KERNEL))) 621 return -ENOMEM; 622 } 623 rtnl_unlock(); 624 bind_additional_ifaces = 1; 625 return 0; 626 } 627 628 while (ifc_list_sz > 0) { 629 if (!alloc_iface(kstrdup(ifc_list, GFP_KERNEL))) 630 return -ENOMEM; 631 632 sz = strlen(ifc_list); 633 if (!sz) 634 break; 635 636 ifc_list += sz + 1; 637 ifc_list_sz -= (sz + 1); 638 } 639 640 bind_additional_ifaces = 0; 641 642 return 0; 643 } 644 645 static struct ksmbd_transport_ops ksmbd_tcp_transport_ops = { 646 .read = ksmbd_tcp_read, 647 .writev = ksmbd_tcp_writev, 648 .disconnect = ksmbd_tcp_disconnect, 649 }; 650