xref: /openbmc/linux/net/vmw_vsock/vsock_bpf.c (revision d37cf9b63113f13d742713881ce691fc615d8b3b)
1634f1a71SBobby Eshleman // SPDX-License-Identifier: GPL-2.0
2634f1a71SBobby Eshleman /* Copyright (c) 2022 Bobby Eshleman <bobby.eshleman@bytedance.com>
3634f1a71SBobby Eshleman  *
4634f1a71SBobby Eshleman  * Based off of net/unix/unix_bpf.c
5634f1a71SBobby Eshleman  */
6634f1a71SBobby Eshleman 
7634f1a71SBobby Eshleman #include <linux/bpf.h>
8634f1a71SBobby Eshleman #include <linux/module.h>
9634f1a71SBobby Eshleman #include <linux/skmsg.h>
10634f1a71SBobby Eshleman #include <linux/socket.h>
11634f1a71SBobby Eshleman #include <linux/wait.h>
12634f1a71SBobby Eshleman #include <net/af_vsock.h>
13634f1a71SBobby Eshleman #include <net/sock.h>
14634f1a71SBobby Eshleman 
/* Evaluates to true when at least one of the socket's receive queue, the
 * psock's ingress skb queue, or the psock's ingress message list is
 * non-empty.  The three checks are made in that order and stop at the
 * first non-empty container.  Note that __psock is expanded twice.
 */
#define vsock_sk_has_data(__sk, __psock)				\
		(!(skb_queue_empty(&(__sk)->sk_receive_queue) &&	\
		   skb_queue_empty(&(__psock)->ingress_skb) &&		\
		   list_empty(&(__psock)->ingress_msg)))
20634f1a71SBobby Eshleman 
21634f1a71SBobby Eshleman static struct proto *vsock_prot_saved __read_mostly;
22634f1a71SBobby Eshleman static DEFINE_SPINLOCK(vsock_prot_lock);
23634f1a71SBobby Eshleman static struct proto vsock_bpf_prot;
24634f1a71SBobby Eshleman 
vsock_has_data(struct sock * sk,struct sk_psock * psock)25634f1a71SBobby Eshleman static bool vsock_has_data(struct sock *sk, struct sk_psock *psock)
26634f1a71SBobby Eshleman {
27634f1a71SBobby Eshleman 	struct vsock_sock *vsk = vsock_sk(sk);
28634f1a71SBobby Eshleman 	s64 ret;
29634f1a71SBobby Eshleman 
30634f1a71SBobby Eshleman 	ret = vsock_connectible_has_data(vsk);
31634f1a71SBobby Eshleman 	if (ret > 0)
32634f1a71SBobby Eshleman 		return true;
33634f1a71SBobby Eshleman 
34634f1a71SBobby Eshleman 	return vsock_sk_has_data(sk, psock);
35634f1a71SBobby Eshleman }
36634f1a71SBobby Eshleman 
vsock_msg_wait_data(struct sock * sk,struct sk_psock * psock,long timeo)37634f1a71SBobby Eshleman static bool vsock_msg_wait_data(struct sock *sk, struct sk_psock *psock, long timeo)
38634f1a71SBobby Eshleman {
39634f1a71SBobby Eshleman 	bool ret;
40634f1a71SBobby Eshleman 
41634f1a71SBobby Eshleman 	DEFINE_WAIT_FUNC(wait, woken_wake_function);
42634f1a71SBobby Eshleman 
43634f1a71SBobby Eshleman 	if (sk->sk_shutdown & RCV_SHUTDOWN)
44634f1a71SBobby Eshleman 		return true;
45634f1a71SBobby Eshleman 
46634f1a71SBobby Eshleman 	if (!timeo)
47634f1a71SBobby Eshleman 		return false;
48634f1a71SBobby Eshleman 
49634f1a71SBobby Eshleman 	add_wait_queue(sk_sleep(sk), &wait);
50634f1a71SBobby Eshleman 	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
51634f1a71SBobby Eshleman 	ret = vsock_has_data(sk, psock);
52634f1a71SBobby Eshleman 	if (!ret) {
53634f1a71SBobby Eshleman 		wait_woken(&wait, TASK_INTERRUPTIBLE, timeo);
54634f1a71SBobby Eshleman 		ret = vsock_has_data(sk, psock);
55634f1a71SBobby Eshleman 	}
56634f1a71SBobby Eshleman 	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
57634f1a71SBobby Eshleman 	remove_wait_queue(sk_sleep(sk), &wait);
58634f1a71SBobby Eshleman 	return ret;
59634f1a71SBobby Eshleman }
60634f1a71SBobby Eshleman 
/* Dispatch a receive to the regular (non-BPF) vsock receive routine that
 * matches the socket type.
 *
 * Returns whatever the selected routine returns, or -EPROTOTYPE when the
 * socket type is not stream, seqpacket or datagram.
 */
static int __vsock_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags)
{
	struct socket *sock = sk->sk_socket;

	switch (sk->sk_type) {
	case SOCK_STREAM:
	case SOCK_SEQPACKET:
		return __vsock_connectible_recvmsg(sock, msg, len, flags);
	case SOCK_DGRAM:
		return __vsock_dgram_recvmsg(sock, msg, len, flags);
	default:
		return -EPROTOTYPE;
	}
}
75634f1a71SBobby Eshleman 
/* recvmsg handler installed in vsock_bpf_prot.
 *
 * If the socket has no psock, the call goes straight to __vsock_recvmsg().
 * Otherwise, with the socket locked:
 *   - a missing transport triggers a one-time warning and yields -ENODEV;
 *   - if data is pending but the psock queue is empty, the call is handed
 *     to __vsock_recvmsg() after dropping the lock and the psock reference;
 *   - otherwise messages are read with sk_msg_recvmsg(), waiting via
 *     vsock_msg_wait_data() for as long as it returns 0.  A failed wait
 *     yields -EAGAIN; a wait that succeeds while the psock queue is empty
 *     hands off to __vsock_recvmsg().
 *
 * @addr_len is not used by this function.
 *
 * Returns the first non-zero sk_msg_recvmsg() result, the __vsock_recvmsg()
 * result on the passthrough paths, or -ENODEV / -EAGAIN as described above.
 */
static int vsock_bpf_recvmsg(struct sock *sk, struct msghdr *msg,
			     size_t len, int flags, int *addr_len)
{
	struct sk_psock *psock = sk_psock_get(sk);
	struct vsock_sock *vsk;
	int copied;

	if (unlikely(!psock))
		return __vsock_recvmsg(sk, msg, len, flags);

	lock_sock(sk);
	vsk = vsock_sk(sk);

	if (WARN_ON_ONCE(!vsk->transport)) {
		copied = -ENODEV;
		goto unlock;
	}

	if (vsock_has_data(sk, psock) && sk_psock_queue_empty(psock))
		goto passthrough;

	for (;;) {
		long timeo;

		copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
		if (copied != 0)
			break;

		timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
		if (!vsock_msg_wait_data(sk, psock, timeo)) {
			copied = -EAGAIN;
			break;
		}

		if (sk_psock_queue_empty(psock))
			goto passthrough;
	}

unlock:
	release_sock(sk);
	sk_psock_put(sk, psock);

	return copied;

passthrough:
	/* Nothing queued on the psock: let the regular receive path run,
	 * without the socket lock or the psock reference held.
	 */
	release_sock(sk);
	sk_psock_put(sk, psock);

	return __vsock_recvmsg(sk, msg, len, flags);
}
125634f1a71SBobby Eshleman 
vsock_bpf_rebuild_protos(struct proto * prot,const struct proto * base)126634f1a71SBobby Eshleman static void vsock_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
127634f1a71SBobby Eshleman {
128634f1a71SBobby Eshleman 	*prot        = *base;
129634f1a71SBobby Eshleman 	prot->close  = sock_map_close;
130634f1a71SBobby Eshleman 	prot->recvmsg = vsock_bpf_recvmsg;
131634f1a71SBobby Eshleman 	prot->sock_is_readable = sk_msg_is_readable;
132634f1a71SBobby Eshleman }
133634f1a71SBobby Eshleman 
vsock_bpf_check_needs_rebuild(struct proto * ops)134634f1a71SBobby Eshleman static void vsock_bpf_check_needs_rebuild(struct proto *ops)
135634f1a71SBobby Eshleman {
136634f1a71SBobby Eshleman 	/* Paired with the smp_store_release() below. */
137634f1a71SBobby Eshleman 	if (unlikely(ops != smp_load_acquire(&vsock_prot_saved))) {
138634f1a71SBobby Eshleman 		spin_lock_bh(&vsock_prot_lock);
139634f1a71SBobby Eshleman 		if (likely(ops != vsock_prot_saved)) {
140634f1a71SBobby Eshleman 			vsock_bpf_rebuild_protos(&vsock_bpf_prot, ops);
141634f1a71SBobby Eshleman 			/* Make sure proto function pointers are updated before publishing the
142634f1a71SBobby Eshleman 			 * pointer to the struct.
143634f1a71SBobby Eshleman 			 */
144634f1a71SBobby Eshleman 			smp_store_release(&vsock_prot_saved, ops);
145634f1a71SBobby Eshleman 		}
146634f1a71SBobby Eshleman 		spin_unlock_bh(&vsock_prot_lock);
147634f1a71SBobby Eshleman 	}
148634f1a71SBobby Eshleman }
149634f1a71SBobby Eshleman 
/* Switch @sk between its original proto and the BPF-aware vsock proto.
 *
 * With @restore set, the write_space callback and proto saved in @psock are
 * put back on @sk.  Otherwise vsock_bpf_prot is installed, after refreshing
 * it against @psock->sk_proto if needed.
 *
 * Returns 0 on success, -ENODEV if the socket has no transport assigned,
 * or -EOPNOTSUPP if the transport provides no read_skb callback.
 */
int vsock_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
{
	struct vsock_sock *vsk;

	if (restore) {
		sk->sk_write_space = psock->saved_write_space;
		sock_replace_proto(sk, psock->sk_proto);
		return 0;
	}

	vsk = vsock_sk(sk);

	/* Both checks must pass before the proto is swapped. */
	if (!vsk->transport)
		return -ENODEV;
	if (!vsk->transport->read_skb)
		return -EOPNOTSUPP;

	vsock_bpf_check_needs_rebuild(psock->sk_proto);
	sock_replace_proto(sk, &vsock_bpf_prot);

	return 0;
}
171634f1a71SBobby Eshleman 
vsock_bpf_build_proto(void)172634f1a71SBobby Eshleman void __init vsock_bpf_build_proto(void)
173634f1a71SBobby Eshleman {
174634f1a71SBobby Eshleman 	vsock_bpf_rebuild_protos(&vsock_bpf_prot, &vsock_proto);
175634f1a71SBobby Eshleman }
176