xref: /openbmc/linux/net/vmw_vsock/vsock_bpf.c (revision c67ce71d)
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2022 Bobby Eshleman <bobby.eshleman@bytedance.com>
3  *
4  * Based off of net/unix/unix_bpf.c
5  */
6 
7 #include <linux/bpf.h>
8 #include <linux/module.h>
9 #include <linux/skmsg.h>
10 #include <linux/socket.h>
11 #include <linux/wait.h>
12 #include <net/af_vsock.h>
13 #include <net/sock.h>
14 
/* Evaluates to true when @__sk's receive queue, or either of @__psock's
 * ingress queues (ingress_skb, ingress_msg), holds at least one entry.
 * Each argument is expanded exactly once.
 */
#define vsock_sk_has_data(__sk, __psock)			\
	(!skb_queue_empty(&(__sk)->sk_receive_queue) ||		\
	 !skb_queue_empty(&(__psock)->ingress_skb) ||		\
	 !list_empty(&(__psock)->ingress_msg))
20 
21 static struct proto *vsock_prot_saved __read_mostly;
22 static DEFINE_SPINLOCK(vsock_prot_lock);
23 static struct proto vsock_bpf_prot;
24 
25 static bool vsock_has_data(struct sock *sk, struct sk_psock *psock)
26 {
27 	struct vsock_sock *vsk = vsock_sk(sk);
28 	s64 ret;
29 
30 	ret = vsock_connectible_has_data(vsk);
31 	if (ret > 0)
32 		return true;
33 
34 	return vsock_sk_has_data(sk, psock);
35 }
36 
37 static bool vsock_msg_wait_data(struct sock *sk, struct sk_psock *psock, long timeo)
38 {
39 	bool ret;
40 
41 	DEFINE_WAIT_FUNC(wait, woken_wake_function);
42 
43 	if (sk->sk_shutdown & RCV_SHUTDOWN)
44 		return true;
45 
46 	if (!timeo)
47 		return false;
48 
49 	add_wait_queue(sk_sleep(sk), &wait);
50 	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
51 	ret = vsock_has_data(sk, psock);
52 	if (!ret) {
53 		wait_woken(&wait, TASK_INTERRUPTIBLE, timeo);
54 		ret = vsock_has_data(sk, psock);
55 	}
56 	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
57 	remove_wait_queue(sk_sleep(sk), &wait);
58 	return ret;
59 }
60 
61 static int __vsock_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags)
62 {
63 	struct socket *sock = sk->sk_socket;
64 	int err;
65 
66 	if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET)
67 		err = __vsock_connectible_recvmsg(sock, msg, len, flags);
68 	else if (sk->sk_type == SOCK_DGRAM)
69 		err = __vsock_dgram_recvmsg(sock, msg, len, flags);
70 	else
71 		err = -EPROTOTYPE;
72 
73 	return err;
74 }
75 
76 static int vsock_bpf_recvmsg(struct sock *sk, struct msghdr *msg,
77 			     size_t len, int flags, int *addr_len)
78 {
79 	struct sk_psock *psock;
80 	int copied;
81 
82 	psock = sk_psock_get(sk);
83 	if (unlikely(!psock))
84 		return __vsock_recvmsg(sk, msg, len, flags);
85 
86 	lock_sock(sk);
87 	if (vsock_has_data(sk, psock) && sk_psock_queue_empty(psock)) {
88 		release_sock(sk);
89 		sk_psock_put(sk, psock);
90 		return __vsock_recvmsg(sk, msg, len, flags);
91 	}
92 
93 	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
94 	while (copied == 0) {
95 		long timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
96 
97 		if (!vsock_msg_wait_data(sk, psock, timeo)) {
98 			copied = -EAGAIN;
99 			break;
100 		}
101 
102 		if (sk_psock_queue_empty(psock)) {
103 			release_sock(sk);
104 			sk_psock_put(sk, psock);
105 			return __vsock_recvmsg(sk, msg, len, flags);
106 		}
107 
108 		copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
109 	}
110 
111 	release_sock(sk);
112 	sk_psock_put(sk, psock);
113 
114 	return copied;
115 }
116 
117 static void vsock_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
118 {
119 	*prot        = *base;
120 	prot->close  = sock_map_close;
121 	prot->recvmsg = vsock_bpf_recvmsg;
122 	prot->sock_is_readable = sk_msg_is_readable;
123 }
124 
125 static void vsock_bpf_check_needs_rebuild(struct proto *ops)
126 {
127 	/* Paired with the smp_store_release() below. */
128 	if (unlikely(ops != smp_load_acquire(&vsock_prot_saved))) {
129 		spin_lock_bh(&vsock_prot_lock);
130 		if (likely(ops != vsock_prot_saved)) {
131 			vsock_bpf_rebuild_protos(&vsock_bpf_prot, ops);
132 			/* Make sure proto function pointers are updated before publishing the
133 			 * pointer to the struct.
134 			 */
135 			smp_store_release(&vsock_prot_saved, ops);
136 		}
137 		spin_unlock_bh(&vsock_prot_lock);
138 	}
139 }
140 
141 int vsock_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
142 {
143 	struct vsock_sock *vsk;
144 
145 	if (restore) {
146 		sk->sk_write_space = psock->saved_write_space;
147 		sock_replace_proto(sk, psock->sk_proto);
148 		return 0;
149 	}
150 
151 	vsk = vsock_sk(sk);
152 	if (!vsk->transport)
153 		return -ENODEV;
154 
155 	if (!vsk->transport->read_skb)
156 		return -EOPNOTSUPP;
157 
158 	vsock_bpf_check_needs_rebuild(psock->sk_proto);
159 	sock_replace_proto(sk, &vsock_bpf_prot);
160 	return 0;
161 }
162 
163 void __init vsock_bpf_build_proto(void)
164 {
165 	vsock_bpf_rebuild_protos(&vsock_bpf_prot, &vsock_proto);
166 }
167