// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2022 Bobby Eshleman <bobby.eshleman@bytedance.com>
 *
 * Based off of net/unix/unix_bpf.c
 */
6634f1a71SBobby Eshleman
7634f1a71SBobby Eshleman #include <linux/bpf.h>
8634f1a71SBobby Eshleman #include <linux/module.h>
9634f1a71SBobby Eshleman #include <linux/skmsg.h>
10634f1a71SBobby Eshleman #include <linux/socket.h>
11634f1a71SBobby Eshleman #include <linux/wait.h>
12634f1a71SBobby Eshleman #include <net/af_vsock.h>
13634f1a71SBobby Eshleman #include <net/sock.h>
14634f1a71SBobby Eshleman
/* Evaluates to non-zero unless all three of these are empty: the
 * sk_receive_queue of @__sk, the ingress_skb queue of @__psock and the
 * ingress_msg list of @__psock. Evaluation stops at the first non-empty one.
 */
#define vsock_sk_has_data(__sk, __psock)				\
	(!(skb_queue_empty(&(__sk)->sk_receive_queue) &&		\
	   skb_queue_empty(&(__psock)->ingress_skb) &&			\
	   list_empty(&(__psock)->ingress_msg)))
20634f1a71SBobby Eshleman
21634f1a71SBobby Eshleman static struct proto *vsock_prot_saved __read_mostly;
22634f1a71SBobby Eshleman static DEFINE_SPINLOCK(vsock_prot_lock);
23634f1a71SBobby Eshleman static struct proto vsock_bpf_prot;
24634f1a71SBobby Eshleman
vsock_has_data(struct sock * sk,struct sk_psock * psock)25634f1a71SBobby Eshleman static bool vsock_has_data(struct sock *sk, struct sk_psock *psock)
26634f1a71SBobby Eshleman {
27634f1a71SBobby Eshleman struct vsock_sock *vsk = vsock_sk(sk);
28634f1a71SBobby Eshleman s64 ret;
29634f1a71SBobby Eshleman
30634f1a71SBobby Eshleman ret = vsock_connectible_has_data(vsk);
31634f1a71SBobby Eshleman if (ret > 0)
32634f1a71SBobby Eshleman return true;
33634f1a71SBobby Eshleman
34634f1a71SBobby Eshleman return vsock_sk_has_data(sk, psock);
35634f1a71SBobby Eshleman }
36634f1a71SBobby Eshleman
vsock_msg_wait_data(struct sock * sk,struct sk_psock * psock,long timeo)37634f1a71SBobby Eshleman static bool vsock_msg_wait_data(struct sock *sk, struct sk_psock *psock, long timeo)
38634f1a71SBobby Eshleman {
39634f1a71SBobby Eshleman bool ret;
40634f1a71SBobby Eshleman
41634f1a71SBobby Eshleman DEFINE_WAIT_FUNC(wait, woken_wake_function);
42634f1a71SBobby Eshleman
43634f1a71SBobby Eshleman if (sk->sk_shutdown & RCV_SHUTDOWN)
44634f1a71SBobby Eshleman return true;
45634f1a71SBobby Eshleman
46634f1a71SBobby Eshleman if (!timeo)
47634f1a71SBobby Eshleman return false;
48634f1a71SBobby Eshleman
49634f1a71SBobby Eshleman add_wait_queue(sk_sleep(sk), &wait);
50634f1a71SBobby Eshleman sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
51634f1a71SBobby Eshleman ret = vsock_has_data(sk, psock);
52634f1a71SBobby Eshleman if (!ret) {
53634f1a71SBobby Eshleman wait_woken(&wait, TASK_INTERRUPTIBLE, timeo);
54634f1a71SBobby Eshleman ret = vsock_has_data(sk, psock);
55634f1a71SBobby Eshleman }
56634f1a71SBobby Eshleman sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
57634f1a71SBobby Eshleman remove_wait_queue(sk_sleep(sk), &wait);
58634f1a71SBobby Eshleman return ret;
59634f1a71SBobby Eshleman }
60634f1a71SBobby Eshleman
/* Dispatch a receive on @sk to the vsock receive routine matching its
 * socket type.
 *
 * Returns the result of __vsock_connectible_recvmsg() for SOCK_STREAM and
 * SOCK_SEQPACKET, of __vsock_dgram_recvmsg() for SOCK_DGRAM, and
 * -EPROTOTYPE for any other type.
 */
static int __vsock_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags)
{
	struct socket *sock = sk->sk_socket;

	switch (sk->sk_type) {
	case SOCK_STREAM:
	case SOCK_SEQPACKET:
		return __vsock_connectible_recvmsg(sock, msg, len, flags);
	case SOCK_DGRAM:
		return __vsock_dgram_recvmsg(sock, msg, len, flags);
	default:
		return -EPROTOTYPE;
	}
}
75634f1a71SBobby Eshleman
/* recvmsg handler placed in vsock_bpf_prot.
 *
 * Hands the request to __vsock_recvmsg() when @sk has no psock, or when
 * sk_psock_queue_empty() reports nothing queued on the psock at the points
 * checked below. Otherwise data is taken from the psock through
 * sk_msg_recvmsg(), waiting via vsock_msg_wait_data() while that returns 0.
 *
 * @addr_len is not used.
 *
 * Returns the sk_msg_recvmsg() or __vsock_recvmsg() result, -ENODEV if the
 * socket has no transport, or -EAGAIN if the wait reports no data.
 */
static int vsock_bpf_recvmsg(struct sock *sk, struct msghdr *msg,
			     size_t len, int flags, int *addr_len)
{
	struct sk_psock *psock = sk_psock_get(sk);
	int copied;

	if (unlikely(!psock))
		return __vsock_recvmsg(sk, msg, len, flags);

	lock_sock(sk);

	if (WARN_ON_ONCE(!vsock_sk(sk)->transport)) {
		copied = -ENODEV;
		goto out;
	}

	if (vsock_has_data(sk, psock) && sk_psock_queue_empty(psock))
		goto passthrough;

	for (;;) {
		long timeo;

		copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
		if (copied != 0)
			break;

		/* The timeout is looked up afresh on every pass. */
		timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
		if (!vsock_msg_wait_data(sk, psock, timeo)) {
			copied = -EAGAIN;
			break;
		}

		if (sk_psock_queue_empty(psock))
			goto passthrough;
	}

out:
	release_sock(sk);
	sk_psock_put(sk, psock);

	return copied;

passthrough:
	/* Drop the socket lock and the psock reference before handing over. */
	release_sock(sk);
	sk_psock_put(sk, psock);

	return __vsock_recvmsg(sk, msg, len, flags);
}
125634f1a71SBobby Eshleman
vsock_bpf_rebuild_protos(struct proto * prot,const struct proto * base)126634f1a71SBobby Eshleman static void vsock_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
127634f1a71SBobby Eshleman {
128634f1a71SBobby Eshleman *prot = *base;
129634f1a71SBobby Eshleman prot->close = sock_map_close;
130634f1a71SBobby Eshleman prot->recvmsg = vsock_bpf_recvmsg;
131634f1a71SBobby Eshleman prot->sock_is_readable = sk_msg_is_readable;
132634f1a71SBobby Eshleman }
133634f1a71SBobby Eshleman
vsock_bpf_check_needs_rebuild(struct proto * ops)134634f1a71SBobby Eshleman static void vsock_bpf_check_needs_rebuild(struct proto *ops)
135634f1a71SBobby Eshleman {
136634f1a71SBobby Eshleman /* Paired with the smp_store_release() below. */
137634f1a71SBobby Eshleman if (unlikely(ops != smp_load_acquire(&vsock_prot_saved))) {
138634f1a71SBobby Eshleman spin_lock_bh(&vsock_prot_lock);
139634f1a71SBobby Eshleman if (likely(ops != vsock_prot_saved)) {
140634f1a71SBobby Eshleman vsock_bpf_rebuild_protos(&vsock_bpf_prot, ops);
141634f1a71SBobby Eshleman /* Make sure proto function pointers are updated before publishing the
142634f1a71SBobby Eshleman * pointer to the struct.
143634f1a71SBobby Eshleman */
144634f1a71SBobby Eshleman smp_store_release(&vsock_prot_saved, ops);
145634f1a71SBobby Eshleman }
146634f1a71SBobby Eshleman spin_unlock_bh(&vsock_prot_lock);
147634f1a71SBobby Eshleman }
148634f1a71SBobby Eshleman }
149634f1a71SBobby Eshleman
/* Install or remove the BPF proto override on @sk.
 *
 * With @restore set, sk_write_space and the proto are put back from the
 * values saved in @psock. Otherwise vsock_bpf_prot is (re)built from
 * @psock->sk_proto if needed and installed on @sk.
 *
 * Returns 0 on success, -ENODEV if the socket has no transport, or
 * -EOPNOTSUPP if the transport has no read_skb callback.
 */
int vsock_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
{
	struct vsock_sock *vsk = vsock_sk(sk);

	if (restore) {
		sk->sk_write_space = psock->saved_write_space;
		sock_replace_proto(sk, psock->sk_proto);
		return 0;
	}

	/* The transport checks apply to the install path only. */
	if (!vsk->transport)
		return -ENODEV;

	if (!vsk->transport->read_skb)
		return -EOPNOTSUPP;

	vsock_bpf_check_needs_rebuild(psock->sk_proto);
	sock_replace_proto(sk, &vsock_bpf_prot);

	return 0;
}
171634f1a71SBobby Eshleman
vsock_bpf_build_proto(void)172634f1a71SBobby Eshleman void __init vsock_bpf_build_proto(void)
173634f1a71SBobby Eshleman {
174634f1a71SBobby Eshleman vsock_bpf_rebuild_protos(&vsock_bpf_prot, &vsock_proto);
175634f1a71SBobby Eshleman }
176