xref: /openbmc/linux/net/handshake/netlink.c (revision 3e8bd1ba)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Generic netlink handshake service
4  *
5  * Author: Chuck Lever <chuck.lever@oracle.com>
6  *
7  * Copyright (c) 2023, Oracle and/or its affiliates.
8  */
9 
10 #include <linux/types.h>
11 #include <linux/socket.h>
12 #include <linux/kernel.h>
13 #include <linux/module.h>
14 #include <linux/skbuff.h>
15 #include <linux/mm.h>
16 
17 #include <net/sock.h>
18 #include <net/genetlink.h>
19 #include <net/netns/generic.h>
20 
21 #include <kunit/visibility.h>
22 
23 #include <uapi/linux/handshake.h>
24 #include "handshake.h"
25 #include "genl.h"
26 
27 #include <trace/events/handshake.h>
28 
29 /**
30  * handshake_genl_notify - Notify handlers that a request is waiting
31  * @net: target network namespace
32  * @proto: handshake protocol
33  * @flags: memory allocation control flags
34  *
35  * Returns zero on success or a negative errno if notification failed.
36  */
37 int handshake_genl_notify(struct net *net, const struct handshake_proto *proto,
38 			  gfp_t flags)
39 {
40 	struct sk_buff *msg;
41 	void *hdr;
42 
43 	/* Disable notifications during unit testing */
44 	if (!test_bit(HANDSHAKE_F_PROTO_NOTIFY, &proto->hp_flags))
45 		return 0;
46 
47 	if (!genl_has_listeners(&handshake_nl_family, net,
48 				proto->hp_handler_class))
49 		return -ESRCH;
50 
51 	msg = genlmsg_new(GENLMSG_DEFAULT_SIZE, flags);
52 	if (!msg)
53 		return -ENOMEM;
54 
55 	hdr = genlmsg_put(msg, 0, 0, &handshake_nl_family, 0,
56 			  HANDSHAKE_CMD_READY);
57 	if (!hdr)
58 		goto out_free;
59 
60 	if (nla_put_u32(msg, HANDSHAKE_A_ACCEPT_HANDLER_CLASS,
61 			proto->hp_handler_class) < 0) {
62 		genlmsg_cancel(msg, hdr);
63 		goto out_free;
64 	}
65 
66 	genlmsg_end(msg, hdr);
67 	return genlmsg_multicast_netns(&handshake_nl_family, net, msg,
68 				       0, proto->hp_handler_class, flags);
69 
70 out_free:
71 	nlmsg_free(msg);
72 	return -EMSGSIZE;
73 }
74 
75 /**
76  * handshake_genl_put - Create a generic netlink message header
77  * @msg: buffer in which to create the header
78  * @info: generic netlink message context
79  *
80  * Returns a ready-to-use header, or NULL.
81  */
82 struct nlmsghdr *handshake_genl_put(struct sk_buff *msg,
83 				    struct genl_info *info)
84 {
85 	return genlmsg_put(msg, info->snd_portid, info->snd_seq,
86 			   &handshake_nl_family, 0, info->genlhdr->cmd);
87 }
88 EXPORT_SYMBOL(handshake_genl_put);
89 
90 /*
91  * dup() a kernel socket for use as a user space file descriptor
92  * in the current process. The kernel socket must have an
93  * instatiated struct file.
94  *
95  * Implicit argument: "current()"
96  */
97 static int handshake_dup(struct socket *sock)
98 {
99 	struct file *file;
100 	int newfd;
101 
102 	file = get_file(sock->file);
103 	newfd = get_unused_fd_flags(O_CLOEXEC);
104 	if (newfd < 0) {
105 		fput(file);
106 		return newfd;
107 	}
108 
109 	fd_install(newfd, file);
110 	return newfd;
111 }
112 
113 int handshake_nl_accept_doit(struct sk_buff *skb, struct genl_info *info)
114 {
115 	struct net *net = sock_net(skb->sk);
116 	struct handshake_net *hn = handshake_pernet(net);
117 	struct handshake_req *req = NULL;
118 	struct socket *sock;
119 	int class, fd, err;
120 
121 	err = -EOPNOTSUPP;
122 	if (!hn)
123 		goto out_status;
124 
125 	err = -EINVAL;
126 	if (GENL_REQ_ATTR_CHECK(info, HANDSHAKE_A_ACCEPT_HANDLER_CLASS))
127 		goto out_status;
128 	class = nla_get_u32(info->attrs[HANDSHAKE_A_ACCEPT_HANDLER_CLASS]);
129 
130 	err = -EAGAIN;
131 	req = handshake_req_next(hn, class);
132 	if (!req)
133 		goto out_status;
134 
135 	sock = req->hr_sk->sk_socket;
136 	fd = handshake_dup(sock);
137 	if (fd < 0) {
138 		err = fd;
139 		goto out_complete;
140 	}
141 	err = req->hr_proto->hp_accept(req, info, fd);
142 	if (err) {
143 		fput(sock->file);
144 		goto out_complete;
145 	}
146 
147 	trace_handshake_cmd_accept(net, req, req->hr_sk, fd);
148 	return 0;
149 
150 out_complete:
151 	handshake_complete(req, -EIO, NULL);
152 out_status:
153 	trace_handshake_cmd_accept_err(net, req, NULL, err);
154 	return err;
155 }
156 
157 int handshake_nl_done_doit(struct sk_buff *skb, struct genl_info *info)
158 {
159 	struct net *net = sock_net(skb->sk);
160 	struct handshake_req *req;
161 	struct socket *sock;
162 	int fd, status, err;
163 
164 	if (GENL_REQ_ATTR_CHECK(info, HANDSHAKE_A_DONE_SOCKFD))
165 		return -EINVAL;
166 	fd = nla_get_u32(info->attrs[HANDSHAKE_A_DONE_SOCKFD]);
167 
168 	sock = sockfd_lookup(fd, &err);
169 	if (!sock)
170 		return err;
171 
172 	req = handshake_req_hash_lookup(sock->sk);
173 	if (!req) {
174 		err = -EBUSY;
175 		trace_handshake_cmd_done_err(net, req, sock->sk, err);
176 		fput(sock->file);
177 		return err;
178 	}
179 
180 	trace_handshake_cmd_done(net, req, sock->sk, fd);
181 
182 	status = -EIO;
183 	if (info->attrs[HANDSHAKE_A_DONE_STATUS])
184 		status = nla_get_u32(info->attrs[HANDSHAKE_A_DONE_STATUS]);
185 
186 	handshake_complete(req, status, info);
187 	fput(sock->file);
188 	return 0;
189 }
190 
191 static unsigned int handshake_net_id;
192 
193 static int __net_init handshake_net_init(struct net *net)
194 {
195 	struct handshake_net *hn = net_generic(net, handshake_net_id);
196 	unsigned long tmp;
197 	struct sysinfo si;
198 
199 	/*
200 	 * Arbitrary limit to prevent handshakes that do not make
201 	 * progress from clogging up the system. The cap scales up
202 	 * with the amount of physical memory on the system.
203 	 */
204 	si_meminfo(&si);
205 	tmp = si.totalram / (25 * si.mem_unit);
206 	hn->hn_pending_max = clamp(tmp, 3UL, 50UL);
207 
208 	spin_lock_init(&hn->hn_lock);
209 	hn->hn_pending = 0;
210 	hn->hn_flags = 0;
211 	INIT_LIST_HEAD(&hn->hn_requests);
212 	return 0;
213 }
214 
215 static void __net_exit handshake_net_exit(struct net *net)
216 {
217 	struct handshake_net *hn = net_generic(net, handshake_net_id);
218 	struct handshake_req *req;
219 	LIST_HEAD(requests);
220 
221 	/*
222 	 * Drain the net's pending list. Requests that have been
223 	 * accepted and are in progress will be destroyed when
224 	 * the socket is closed.
225 	 */
226 	spin_lock(&hn->hn_lock);
227 	set_bit(HANDSHAKE_F_NET_DRAINING, &hn->hn_flags);
228 	list_splice_init(&requests, &hn->hn_requests);
229 	spin_unlock(&hn->hn_lock);
230 
231 	while (!list_empty(&requests)) {
232 		req = list_first_entry(&requests, struct handshake_req, hr_list);
233 		list_del(&req->hr_list);
234 
235 		/*
236 		 * Requests on this list have not yet been
237 		 * accepted, so they do not have an fd to put.
238 		 */
239 
240 		handshake_complete(req, -ETIMEDOUT, NULL);
241 	}
242 }
243 
244 static struct pernet_operations handshake_genl_net_ops = {
245 	.init		= handshake_net_init,
246 	.exit		= handshake_net_exit,
247 	.id		= &handshake_net_id,
248 	.size		= sizeof(struct handshake_net),
249 };
250 
251 /**
252  * handshake_pernet - Get the handshake private per-net structure
253  * @net: network namespace
254  *
255  * Returns a pointer to the net's private per-net structure for the
256  * handshake module, or NULL if handshake_init() failed.
257  */
258 struct handshake_net *handshake_pernet(struct net *net)
259 {
260 	return handshake_net_id ?
261 		net_generic(net, handshake_net_id) : NULL;
262 }
263 EXPORT_SYMBOL_IF_KUNIT(handshake_pernet);
264 
265 static int __init handshake_init(void)
266 {
267 	int ret;
268 
269 	ret = handshake_req_hash_init();
270 	if (ret) {
271 		pr_warn("handshake: hash initialization failed (%d)\n", ret);
272 		return ret;
273 	}
274 
275 	ret = genl_register_family(&handshake_nl_family);
276 	if (ret) {
277 		pr_warn("handshake: netlink registration failed (%d)\n", ret);
278 		handshake_req_hash_destroy();
279 		return ret;
280 	}
281 
282 	/*
283 	 * ORDER: register_pernet_subsys must be done last.
284 	 *
285 	 *	If initialization does not make it past pernet_subsys
286 	 *	registration, then handshake_net_id will remain 0. That
287 	 *	shunts the handshake consumer API to return ENOTSUPP
288 	 *	to prevent it from dereferencing something that hasn't
289 	 *	been allocated.
290 	 */
291 	ret = register_pernet_subsys(&handshake_genl_net_ops);
292 	if (ret) {
293 		pr_warn("handshake: pernet registration failed (%d)\n", ret);
294 		genl_unregister_family(&handshake_nl_family);
295 		handshake_req_hash_destroy();
296 	}
297 
298 	return ret;
299 }
300 
301 static void __exit handshake_exit(void)
302 {
303 	unregister_pernet_subsys(&handshake_genl_net_ops);
304 	handshake_net_id = 0;
305 
306 	handshake_req_hash_destroy();
307 	genl_unregister_family(&handshake_nl_family);
308 }
309 
310 module_init(handshake_init);
311 module_exit(handshake_exit);
312