1 // SPDX-License-Identifier: GPL-2.0 2 /* Copyright (c) 2020 Cloudflare Ltd https://cloudflare.com */ 3 4 #include <linux/skmsg.h> 5 #include <net/sock.h> 6 #include <net/udp.h> 7 #include <net/inet_common.h> 8 9 #include "udp_impl.h" 10 11 static struct proto *udpv6_prot_saved __read_mostly; 12 13 static int sk_udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, 14 int noblock, int flags, int *addr_len) 15 { 16 #if IS_ENABLED(CONFIG_IPV6) 17 if (sk->sk_family == AF_INET6) 18 return udpv6_prot_saved->recvmsg(sk, msg, len, noblock, flags, 19 addr_len); 20 #endif 21 return udp_prot.recvmsg(sk, msg, len, noblock, flags, addr_len); 22 } 23 24 static bool udp_sk_has_data(struct sock *sk) 25 { 26 return !skb_queue_empty(&udp_sk(sk)->reader_queue) || 27 !skb_queue_empty(&sk->sk_receive_queue); 28 } 29 30 static bool psock_has_data(struct sk_psock *psock) 31 { 32 return !skb_queue_empty(&psock->ingress_skb) || 33 !sk_psock_queue_empty(psock); 34 } 35 36 #define udp_msg_has_data(__sk, __psock) \ 37 ({ udp_sk_has_data(__sk) || psock_has_data(__psock); }) 38 39 static int udp_msg_wait_data(struct sock *sk, struct sk_psock *psock, 40 long timeo) 41 { 42 DEFINE_WAIT_FUNC(wait, woken_wake_function); 43 int ret = 0; 44 45 if (sk->sk_shutdown & RCV_SHUTDOWN) 46 return 1; 47 48 if (!timeo) 49 return ret; 50 51 add_wait_queue(sk_sleep(sk), &wait); 52 sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); 53 ret = udp_msg_has_data(sk, psock); 54 if (!ret) { 55 wait_woken(&wait, TASK_INTERRUPTIBLE, timeo); 56 ret = udp_msg_has_data(sk, psock); 57 } 58 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); 59 remove_wait_queue(sk_sleep(sk), &wait); 60 return ret; 61 } 62 63 static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, 64 int nonblock, int flags, int *addr_len) 65 { 66 struct sk_psock *psock; 67 int copied, ret; 68 69 if (unlikely(flags & MSG_ERRQUEUE)) 70 return inet_recv_error(sk, msg, len, addr_len); 71 72 psock = sk_psock_get(sk); 73 if (unlikely(!psock)) 74 return sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len); 75 76 if (!psock_has_data(psock)) { 77 ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len); 78 goto out; 79 } 80 81 msg_bytes_ready: 82 copied = sk_msg_recvmsg(sk, psock, msg, len, flags); 83 if (!copied) { 84 long timeo; 85 int data; 86 87 timeo = sock_rcvtimeo(sk, nonblock); 88 data = udp_msg_wait_data(sk, psock, timeo); 89 if (data) { 90 if (psock_has_data(psock)) 91 goto msg_bytes_ready; 92 ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len); 93 goto out; 94 } 95 copied = -EAGAIN; 96 } 97 ret = copied; 98 out: 99 sk_psock_put(sk, psock); 100 return ret; 101 } 102 103 enum { 104 UDP_BPF_IPV4, 105 UDP_BPF_IPV6, 106 UDP_BPF_NUM_PROTS, 107 }; 108 109 static DEFINE_SPINLOCK(udpv6_prot_lock); 110 static struct proto udp_bpf_prots[UDP_BPF_NUM_PROTS]; 111 112 static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base) 113 { 114 *prot = *base; 115 prot->unhash = sock_map_unhash; 116 prot->close = sock_map_close; 117 prot->recvmsg = udp_bpf_recvmsg; 118 } 119 120 static void udp_bpf_check_v6_needs_rebuild(struct proto *ops) 121 { 122 if (unlikely(ops != smp_load_acquire(&udpv6_prot_saved))) { 123 spin_lock_bh(&udpv6_prot_lock); 124 if (likely(ops != udpv6_prot_saved)) { 125 udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV6], ops); 126 smp_store_release(&udpv6_prot_saved, ops); 127 } 128 spin_unlock_bh(&udpv6_prot_lock); 129 } 130 } 131 132 static int __init udp_bpf_v4_build_proto(void) 133 { 134 udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV4], &udp_prot); 135 return 0; 136 } 137 late_initcall(udp_bpf_v4_build_proto); 138 139 int udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) 140 { 141 int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6; 142 143 if (restore) { 144 sk->sk_write_space = psock->saved_write_space; 145 WRITE_ONCE(sk->sk_prot, psock->sk_proto); 146 return 0; 147 } 148 149 if (sk->sk_family == AF_INET6) 150 udp_bpf_check_v6_needs_rebuild(psock->sk_proto); 151 152 WRITE_ONCE(sk->sk_prot, &udp_bpf_prots[family]); 153 return 0; 154 } 155 EXPORT_SYMBOL_GPL(udp_bpf_update_proto); 156