xref: /openbmc/linux/crypto/algif_skcipher.c (revision a89988a6)
1 /*
2  * algif_skcipher: User-space interface for skcipher algorithms
3  *
4  * This file provides the user-space API for symmetric key ciphers.
5  *
6  * Copyright (c) 2010 Herbert Xu <herbert@gondor.apana.org.au>
7  *
8  * This program is free software; you can redistribute it and/or modify it
9  * under the terms of the GNU General Public License as published by the Free
10  * Software Foundation; either version 2 of the License, or (at your option)
11  * any later version.
12  *
13  */
14 
15 #include <crypto/scatterwalk.h>
16 #include <crypto/skcipher.h>
17 #include <crypto/if_alg.h>
18 #include <linux/init.h>
19 #include <linux/list.h>
20 #include <linux/kernel.h>
21 #include <linux/sched/signal.h>
22 #include <linux/mm.h>
23 #include <linux/module.h>
24 #include <linux/net.h>
25 #include <net/sock.h>
26 
27 struct skcipher_sg_list {
28 	struct list_head list;
29 
30 	int cur;
31 
32 	struct scatterlist sg[0];
33 };
34 
35 struct skcipher_tfm {
36 	struct crypto_skcipher *skcipher;
37 	bool has_key;
38 };
39 
40 struct skcipher_ctx {
41 	struct list_head tsgl;
42 	struct af_alg_sgl rsgl;
43 
44 	void *iv;
45 
46 	struct af_alg_completion completion;
47 
48 	atomic_t inflight;
49 	size_t used;
50 
51 	unsigned int len;
52 	bool more;
53 	bool merge;
54 	bool enc;
55 
56 	struct skcipher_request req;
57 };
58 
59 struct skcipher_async_rsgl {
60 	struct af_alg_sgl sgl;
61 	struct list_head list;
62 };
63 
64 struct skcipher_async_req {
65 	struct kiocb *iocb;
66 	struct skcipher_async_rsgl first_sgl;
67 	struct list_head list;
68 	struct scatterlist *tsg;
69 	atomic_t *inflight;
70 	struct skcipher_request req;
71 };
72 
73 #define MAX_SGL_ENTS ((4096 - sizeof(struct skcipher_sg_list)) / \
74 		      sizeof(struct scatterlist) - 1)
75 
76 static void skcipher_free_async_sgls(struct skcipher_async_req *sreq)
77 {
78 	struct skcipher_async_rsgl *rsgl, *tmp;
79 	struct scatterlist *sgl;
80 	struct scatterlist *sg;
81 	int i, n;
82 
83 	list_for_each_entry_safe(rsgl, tmp, &sreq->list, list) {
84 		af_alg_free_sg(&rsgl->sgl);
85 		if (rsgl != &sreq->first_sgl)
86 			kfree(rsgl);
87 	}
88 	sgl = sreq->tsg;
89 	n = sg_nents(sgl);
90 	for_each_sg(sgl, sg, n, i)
91 		put_page(sg_page(sg));
92 
93 	kfree(sreq->tsg);
94 }
95 
96 static void skcipher_async_cb(struct crypto_async_request *req, int err)
97 {
98 	struct skcipher_async_req *sreq = req->data;
99 	struct kiocb *iocb = sreq->iocb;
100 
101 	atomic_dec(sreq->inflight);
102 	skcipher_free_async_sgls(sreq);
103 	kzfree(sreq);
104 	iocb->ki_complete(iocb, err, err);
105 }
106 
107 static inline int skcipher_sndbuf(struct sock *sk)
108 {
109 	struct alg_sock *ask = alg_sk(sk);
110 	struct skcipher_ctx *ctx = ask->private;
111 
112 	return max_t(int, max_t(int, sk->sk_sndbuf & PAGE_MASK, PAGE_SIZE) -
113 			  ctx->used, 0);
114 }
115 
116 static inline bool skcipher_writable(struct sock *sk)
117 {
118 	return PAGE_SIZE <= skcipher_sndbuf(sk);
119 }
120 
121 static int skcipher_alloc_sgl(struct sock *sk)
122 {
123 	struct alg_sock *ask = alg_sk(sk);
124 	struct skcipher_ctx *ctx = ask->private;
125 	struct skcipher_sg_list *sgl;
126 	struct scatterlist *sg = NULL;
127 
128 	sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
129 	if (!list_empty(&ctx->tsgl))
130 		sg = sgl->sg;
131 
132 	if (!sg || sgl->cur >= MAX_SGL_ENTS) {
133 		sgl = sock_kmalloc(sk, sizeof(*sgl) +
134 				       sizeof(sgl->sg[0]) * (MAX_SGL_ENTS + 1),
135 				   GFP_KERNEL);
136 		if (!sgl)
137 			return -ENOMEM;
138 
139 		sg_init_table(sgl->sg, MAX_SGL_ENTS + 1);
140 		sgl->cur = 0;
141 
142 		if (sg)
143 			sg_chain(sg, MAX_SGL_ENTS + 1, sgl->sg);
144 
145 		list_add_tail(&sgl->list, &ctx->tsgl);
146 	}
147 
148 	return 0;
149 }
150 
151 static void skcipher_pull_sgl(struct sock *sk, size_t used, int put)
152 {
153 	struct alg_sock *ask = alg_sk(sk);
154 	struct skcipher_ctx *ctx = ask->private;
155 	struct skcipher_sg_list *sgl;
156 	struct scatterlist *sg;
157 	int i;
158 
159 	while (!list_empty(&ctx->tsgl)) {
160 		sgl = list_first_entry(&ctx->tsgl, struct skcipher_sg_list,
161 				       list);
162 		sg = sgl->sg;
163 
164 		for (i = 0; i < sgl->cur; i++) {
165 			size_t plen = min_t(size_t, used, sg[i].length);
166 
167 			if (!sg_page(sg + i))
168 				continue;
169 
170 			sg[i].length -= plen;
171 			sg[i].offset += plen;
172 
173 			used -= plen;
174 			ctx->used -= plen;
175 
176 			if (sg[i].length)
177 				return;
178 			if (put)
179 				put_page(sg_page(sg + i));
180 			sg_assign_page(sg + i, NULL);
181 		}
182 
183 		list_del(&sgl->list);
184 		sock_kfree_s(sk, sgl,
185 			     sizeof(*sgl) + sizeof(sgl->sg[0]) *
186 					    (MAX_SGL_ENTS + 1));
187 	}
188 
189 	if (!ctx->used)
190 		ctx->merge = 0;
191 }
192 
193 static void skcipher_free_sgl(struct sock *sk)
194 {
195 	struct alg_sock *ask = alg_sk(sk);
196 	struct skcipher_ctx *ctx = ask->private;
197 
198 	skcipher_pull_sgl(sk, ctx->used, 1);
199 }
200 
201 static int skcipher_wait_for_wmem(struct sock *sk, unsigned flags)
202 {
203 	DEFINE_WAIT_FUNC(wait, woken_wake_function);
204 	int err = -ERESTARTSYS;
205 	long timeout;
206 
207 	if (flags & MSG_DONTWAIT)
208 		return -EAGAIN;
209 
210 	sk_set_bit(SOCKWQ_ASYNC_NOSPACE, sk);
211 
212 	add_wait_queue(sk_sleep(sk), &wait);
213 	for (;;) {
214 		if (signal_pending(current))
215 			break;
216 		timeout = MAX_SCHEDULE_TIMEOUT;
217 		if (sk_wait_event(sk, &timeout, skcipher_writable(sk), &wait)) {
218 			err = 0;
219 			break;
220 		}
221 	}
222 	remove_wait_queue(sk_sleep(sk), &wait);
223 
224 	return err;
225 }
226 
227 static void skcipher_wmem_wakeup(struct sock *sk)
228 {
229 	struct socket_wq *wq;
230 
231 	if (!skcipher_writable(sk))
232 		return;
233 
234 	rcu_read_lock();
235 	wq = rcu_dereference(sk->sk_wq);
236 	if (skwq_has_sleeper(wq))
237 		wake_up_interruptible_sync_poll(&wq->wait, POLLIN |
238 							   POLLRDNORM |
239 							   POLLRDBAND);
240 	sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_IN);
241 	rcu_read_unlock();
242 }
243 
244 static int skcipher_wait_for_data(struct sock *sk, unsigned flags)
245 {
246 	DEFINE_WAIT_FUNC(wait, woken_wake_function);
247 	struct alg_sock *ask = alg_sk(sk);
248 	struct skcipher_ctx *ctx = ask->private;
249 	long timeout;
250 	int err = -ERESTARTSYS;
251 
252 	if (flags & MSG_DONTWAIT) {
253 		return -EAGAIN;
254 	}
255 
256 	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
257 
258 	add_wait_queue(sk_sleep(sk), &wait);
259 	for (;;) {
260 		if (signal_pending(current))
261 			break;
262 		timeout = MAX_SCHEDULE_TIMEOUT;
263 		if (sk_wait_event(sk, &timeout, ctx->used, &wait)) {
264 			err = 0;
265 			break;
266 		}
267 	}
268 	remove_wait_queue(sk_sleep(sk), &wait);
269 
270 	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
271 
272 	return err;
273 }
274 
275 static void skcipher_data_wakeup(struct sock *sk)
276 {
277 	struct alg_sock *ask = alg_sk(sk);
278 	struct skcipher_ctx *ctx = ask->private;
279 	struct socket_wq *wq;
280 
281 	if (!ctx->used)
282 		return;
283 
284 	rcu_read_lock();
285 	wq = rcu_dereference(sk->sk_wq);
286 	if (skwq_has_sleeper(wq))
287 		wake_up_interruptible_sync_poll(&wq->wait, POLLOUT |
288 							   POLLRDNORM |
289 							   POLLRDBAND);
290 	sk_wake_async(sk, SOCK_WAKE_SPACE, POLL_OUT);
291 	rcu_read_unlock();
292 }
293 
294 static int skcipher_sendmsg(struct socket *sock, struct msghdr *msg,
295 			    size_t size)
296 {
297 	struct sock *sk = sock->sk;
298 	struct alg_sock *ask = alg_sk(sk);
299 	struct sock *psk = ask->parent;
300 	struct alg_sock *pask = alg_sk(psk);
301 	struct skcipher_ctx *ctx = ask->private;
302 	struct skcipher_tfm *skc = pask->private;
303 	struct crypto_skcipher *tfm = skc->skcipher;
304 	unsigned ivsize = crypto_skcipher_ivsize(tfm);
305 	struct skcipher_sg_list *sgl;
306 	struct af_alg_control con = {};
307 	long copied = 0;
308 	bool enc = 0;
309 	bool init = 0;
310 	int err;
311 	int i;
312 
313 	if (msg->msg_controllen) {
314 		err = af_alg_cmsg_send(msg, &con);
315 		if (err)
316 			return err;
317 
318 		init = 1;
319 		switch (con.op) {
320 		case ALG_OP_ENCRYPT:
321 			enc = 1;
322 			break;
323 		case ALG_OP_DECRYPT:
324 			enc = 0;
325 			break;
326 		default:
327 			return -EINVAL;
328 		}
329 
330 		if (con.iv && con.iv->ivlen != ivsize)
331 			return -EINVAL;
332 	}
333 
334 	err = -EINVAL;
335 
336 	lock_sock(sk);
337 	if (!ctx->more && ctx->used)
338 		goto unlock;
339 
340 	if (init) {
341 		ctx->enc = enc;
342 		if (con.iv)
343 			memcpy(ctx->iv, con.iv->iv, ivsize);
344 	}
345 
346 	while (size) {
347 		struct scatterlist *sg;
348 		unsigned long len = size;
349 		size_t plen;
350 
351 		if (ctx->merge) {
352 			sgl = list_entry(ctx->tsgl.prev,
353 					 struct skcipher_sg_list, list);
354 			sg = sgl->sg + sgl->cur - 1;
355 			len = min_t(unsigned long, len,
356 				    PAGE_SIZE - sg->offset - sg->length);
357 
358 			err = memcpy_from_msg(page_address(sg_page(sg)) +
359 					      sg->offset + sg->length,
360 					      msg, len);
361 			if (err)
362 				goto unlock;
363 
364 			sg->length += len;
365 			ctx->merge = (sg->offset + sg->length) &
366 				     (PAGE_SIZE - 1);
367 
368 			ctx->used += len;
369 			copied += len;
370 			size -= len;
371 			continue;
372 		}
373 
374 		if (!skcipher_writable(sk)) {
375 			err = skcipher_wait_for_wmem(sk, msg->msg_flags);
376 			if (err)
377 				goto unlock;
378 		}
379 
380 		len = min_t(unsigned long, len, skcipher_sndbuf(sk));
381 
382 		err = skcipher_alloc_sgl(sk);
383 		if (err)
384 			goto unlock;
385 
386 		sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
387 		sg = sgl->sg;
388 		if (sgl->cur)
389 			sg_unmark_end(sg + sgl->cur - 1);
390 		do {
391 			i = sgl->cur;
392 			plen = min_t(size_t, len, PAGE_SIZE);
393 
394 			sg_assign_page(sg + i, alloc_page(GFP_KERNEL));
395 			err = -ENOMEM;
396 			if (!sg_page(sg + i))
397 				goto unlock;
398 
399 			err = memcpy_from_msg(page_address(sg_page(sg + i)),
400 					      msg, plen);
401 			if (err) {
402 				__free_page(sg_page(sg + i));
403 				sg_assign_page(sg + i, NULL);
404 				goto unlock;
405 			}
406 
407 			sg[i].length = plen;
408 			len -= plen;
409 			ctx->used += plen;
410 			copied += plen;
411 			size -= plen;
412 			sgl->cur++;
413 		} while (len && sgl->cur < MAX_SGL_ENTS);
414 
415 		if (!size)
416 			sg_mark_end(sg + sgl->cur - 1);
417 
418 		ctx->merge = plen & (PAGE_SIZE - 1);
419 	}
420 
421 	err = 0;
422 
423 	ctx->more = msg->msg_flags & MSG_MORE;
424 
425 unlock:
426 	skcipher_data_wakeup(sk);
427 	release_sock(sk);
428 
429 	return copied ?: err;
430 }
431 
432 static ssize_t skcipher_sendpage(struct socket *sock, struct page *page,
433 				 int offset, size_t size, int flags)
434 {
435 	struct sock *sk = sock->sk;
436 	struct alg_sock *ask = alg_sk(sk);
437 	struct skcipher_ctx *ctx = ask->private;
438 	struct skcipher_sg_list *sgl;
439 	int err = -EINVAL;
440 
441 	if (flags & MSG_SENDPAGE_NOTLAST)
442 		flags |= MSG_MORE;
443 
444 	lock_sock(sk);
445 	if (!ctx->more && ctx->used)
446 		goto unlock;
447 
448 	if (!size)
449 		goto done;
450 
451 	if (!skcipher_writable(sk)) {
452 		err = skcipher_wait_for_wmem(sk, flags);
453 		if (err)
454 			goto unlock;
455 	}
456 
457 	err = skcipher_alloc_sgl(sk);
458 	if (err)
459 		goto unlock;
460 
461 	ctx->merge = 0;
462 	sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
463 
464 	if (sgl->cur)
465 		sg_unmark_end(sgl->sg + sgl->cur - 1);
466 
467 	sg_mark_end(sgl->sg + sgl->cur);
468 	get_page(page);
469 	sg_set_page(sgl->sg + sgl->cur, page, size, offset);
470 	sgl->cur++;
471 	ctx->used += size;
472 
473 done:
474 	ctx->more = flags & MSG_MORE;
475 
476 unlock:
477 	skcipher_data_wakeup(sk);
478 	release_sock(sk);
479 
480 	return err ?: size;
481 }
482 
483 static int skcipher_all_sg_nents(struct skcipher_ctx *ctx)
484 {
485 	struct skcipher_sg_list *sgl;
486 	struct scatterlist *sg;
487 	int nents = 0;
488 
489 	list_for_each_entry(sgl, &ctx->tsgl, list) {
490 		sg = sgl->sg;
491 
492 		while (!sg->length)
493 			sg++;
494 
495 		nents += sg_nents(sg);
496 	}
497 	return nents;
498 }
499 
500 static int skcipher_recvmsg_async(struct socket *sock, struct msghdr *msg,
501 				  int flags)
502 {
503 	struct sock *sk = sock->sk;
504 	struct alg_sock *ask = alg_sk(sk);
505 	struct sock *psk = ask->parent;
506 	struct alg_sock *pask = alg_sk(psk);
507 	struct skcipher_ctx *ctx = ask->private;
508 	struct skcipher_tfm *skc = pask->private;
509 	struct crypto_skcipher *tfm = skc->skcipher;
510 	struct skcipher_sg_list *sgl;
511 	struct scatterlist *sg;
512 	struct skcipher_async_req *sreq;
513 	struct skcipher_request *req;
514 	struct skcipher_async_rsgl *last_rsgl = NULL;
515 	unsigned int txbufs = 0, len = 0, tx_nents;
516 	unsigned int reqsize = crypto_skcipher_reqsize(tfm);
517 	unsigned int ivsize = crypto_skcipher_ivsize(tfm);
518 	int err = -ENOMEM;
519 	bool mark = false;
520 	char *iv;
521 
522 	sreq = kzalloc(sizeof(*sreq) + reqsize + ivsize, GFP_KERNEL);
523 	if (unlikely(!sreq))
524 		goto out;
525 
526 	req = &sreq->req;
527 	iv = (char *)(req + 1) + reqsize;
528 	sreq->iocb = msg->msg_iocb;
529 	INIT_LIST_HEAD(&sreq->list);
530 	sreq->inflight = &ctx->inflight;
531 
532 	lock_sock(sk);
533 	tx_nents = skcipher_all_sg_nents(ctx);
534 	sreq->tsg = kcalloc(tx_nents, sizeof(*sg), GFP_KERNEL);
535 	if (unlikely(!sreq->tsg))
536 		goto unlock;
537 	sg_init_table(sreq->tsg, tx_nents);
538 	memcpy(iv, ctx->iv, ivsize);
539 	skcipher_request_set_tfm(req, tfm);
540 	skcipher_request_set_callback(req, CRYPTO_TFM_REQ_MAY_SLEEP,
541 				      skcipher_async_cb, sreq);
542 
543 	while (iov_iter_count(&msg->msg_iter)) {
544 		struct skcipher_async_rsgl *rsgl;
545 		int used;
546 
547 		if (!ctx->used) {
548 			err = skcipher_wait_for_data(sk, flags);
549 			if (err)
550 				goto free;
551 		}
552 		sgl = list_first_entry(&ctx->tsgl,
553 				       struct skcipher_sg_list, list);
554 		sg = sgl->sg;
555 
556 		while (!sg->length)
557 			sg++;
558 
559 		used = min_t(unsigned long, ctx->used,
560 			     iov_iter_count(&msg->msg_iter));
561 		used = min_t(unsigned long, used, sg->length);
562 
563 		if (txbufs == tx_nents) {
564 			struct scatterlist *tmp;
565 			int x;
566 			/* Ran out of tx slots in async request
567 			 * need to expand */
568 			tmp = kcalloc(tx_nents * 2, sizeof(*tmp),
569 				      GFP_KERNEL);
570 			if (!tmp) {
571 				err = -ENOMEM;
572 				goto free;
573 			}
574 
575 			sg_init_table(tmp, tx_nents * 2);
576 			for (x = 0; x < tx_nents; x++)
577 				sg_set_page(&tmp[x], sg_page(&sreq->tsg[x]),
578 					    sreq->tsg[x].length,
579 					    sreq->tsg[x].offset);
580 			kfree(sreq->tsg);
581 			sreq->tsg = tmp;
582 			tx_nents *= 2;
583 			mark = true;
584 		}
585 		/* Need to take over the tx sgl from ctx
586 		 * to the asynch req - these sgls will be freed later */
587 		sg_set_page(sreq->tsg + txbufs++, sg_page(sg), sg->length,
588 			    sg->offset);
589 
590 		if (list_empty(&sreq->list)) {
591 			rsgl = &sreq->first_sgl;
592 			list_add_tail(&rsgl->list, &sreq->list);
593 		} else {
594 			rsgl = kmalloc(sizeof(*rsgl), GFP_KERNEL);
595 			if (!rsgl) {
596 				err = -ENOMEM;
597 				goto free;
598 			}
599 			list_add_tail(&rsgl->list, &sreq->list);
600 		}
601 
602 		used = af_alg_make_sg(&rsgl->sgl, &msg->msg_iter, used);
603 		err = used;
604 		if (used < 0)
605 			goto free;
606 		if (last_rsgl)
607 			af_alg_link_sg(&last_rsgl->sgl, &rsgl->sgl);
608 
609 		last_rsgl = rsgl;
610 		len += used;
611 		skcipher_pull_sgl(sk, used, 0);
612 		iov_iter_advance(&msg->msg_iter, used);
613 	}
614 
615 	if (mark)
616 		sg_mark_end(sreq->tsg + txbufs - 1);
617 
618 	skcipher_request_set_crypt(req, sreq->tsg, sreq->first_sgl.sgl.sg,
619 				   len, iv);
620 	err = ctx->enc ? crypto_skcipher_encrypt(req) :
621 			 crypto_skcipher_decrypt(req);
622 	if (err == -EINPROGRESS) {
623 		atomic_inc(&ctx->inflight);
624 		err = -EIOCBQUEUED;
625 		sreq = NULL;
626 		goto unlock;
627 	}
628 free:
629 	skcipher_free_async_sgls(sreq);
630 unlock:
631 	skcipher_wmem_wakeup(sk);
632 	release_sock(sk);
633 	kzfree(sreq);
634 out:
635 	return err;
636 }
637 
638 static int skcipher_recvmsg_sync(struct socket *sock, struct msghdr *msg,
639 				 int flags)
640 {
641 	struct sock *sk = sock->sk;
642 	struct alg_sock *ask = alg_sk(sk);
643 	struct sock *psk = ask->parent;
644 	struct alg_sock *pask = alg_sk(psk);
645 	struct skcipher_ctx *ctx = ask->private;
646 	struct skcipher_tfm *skc = pask->private;
647 	struct crypto_skcipher *tfm = skc->skcipher;
648 	unsigned bs = crypto_skcipher_blocksize(tfm);
649 	struct skcipher_sg_list *sgl;
650 	struct scatterlist *sg;
651 	int err = -EAGAIN;
652 	int used;
653 	long copied = 0;
654 
655 	lock_sock(sk);
656 	while (msg_data_left(msg)) {
657 		if (!ctx->used) {
658 			err = skcipher_wait_for_data(sk, flags);
659 			if (err)
660 				goto unlock;
661 		}
662 
663 		used = min_t(unsigned long, ctx->used, msg_data_left(msg));
664 
665 		used = af_alg_make_sg(&ctx->rsgl, &msg->msg_iter, used);
666 		err = used;
667 		if (err < 0)
668 			goto unlock;
669 
670 		if (ctx->more || used < ctx->used)
671 			used -= used % bs;
672 
673 		err = -EINVAL;
674 		if (!used)
675 			goto free;
676 
677 		sgl = list_first_entry(&ctx->tsgl,
678 				       struct skcipher_sg_list, list);
679 		sg = sgl->sg;
680 
681 		while (!sg->length)
682 			sg++;
683 
684 		skcipher_request_set_crypt(&ctx->req, sg, ctx->rsgl.sg, used,
685 					   ctx->iv);
686 
687 		err = af_alg_wait_for_completion(
688 				ctx->enc ?
689 					crypto_skcipher_encrypt(&ctx->req) :
690 					crypto_skcipher_decrypt(&ctx->req),
691 				&ctx->completion);
692 
693 free:
694 		af_alg_free_sg(&ctx->rsgl);
695 
696 		if (err)
697 			goto unlock;
698 
699 		copied += used;
700 		skcipher_pull_sgl(sk, used, 1);
701 		iov_iter_advance(&msg->msg_iter, used);
702 	}
703 
704 	err = 0;
705 
706 unlock:
707 	skcipher_wmem_wakeup(sk);
708 	release_sock(sk);
709 
710 	return copied ?: err;
711 }
712 
713 static int skcipher_recvmsg(struct socket *sock, struct msghdr *msg,
714 			    size_t ignored, int flags)
715 {
716 	return (msg->msg_iocb && !is_sync_kiocb(msg->msg_iocb)) ?
717 		skcipher_recvmsg_async(sock, msg, flags) :
718 		skcipher_recvmsg_sync(sock, msg, flags);
719 }
720 
721 static unsigned int skcipher_poll(struct file *file, struct socket *sock,
722 				  poll_table *wait)
723 {
724 	struct sock *sk = sock->sk;
725 	struct alg_sock *ask = alg_sk(sk);
726 	struct skcipher_ctx *ctx = ask->private;
727 	unsigned int mask;
728 
729 	sock_poll_wait(file, sk_sleep(sk), wait);
730 	mask = 0;
731 
732 	if (ctx->used)
733 		mask |= POLLIN | POLLRDNORM;
734 
735 	if (skcipher_writable(sk))
736 		mask |= POLLOUT | POLLWRNORM | POLLWRBAND;
737 
738 	return mask;
739 }
740 
741 static struct proto_ops algif_skcipher_ops = {
742 	.family		=	PF_ALG,
743 
744 	.connect	=	sock_no_connect,
745 	.socketpair	=	sock_no_socketpair,
746 	.getname	=	sock_no_getname,
747 	.ioctl		=	sock_no_ioctl,
748 	.listen		=	sock_no_listen,
749 	.shutdown	=	sock_no_shutdown,
750 	.getsockopt	=	sock_no_getsockopt,
751 	.mmap		=	sock_no_mmap,
752 	.bind		=	sock_no_bind,
753 	.accept		=	sock_no_accept,
754 	.setsockopt	=	sock_no_setsockopt,
755 
756 	.release	=	af_alg_release,
757 	.sendmsg	=	skcipher_sendmsg,
758 	.sendpage	=	skcipher_sendpage,
759 	.recvmsg	=	skcipher_recvmsg,
760 	.poll		=	skcipher_poll,
761 };
762 
763 static int skcipher_check_key(struct socket *sock)
764 {
765 	int err = 0;
766 	struct sock *psk;
767 	struct alg_sock *pask;
768 	struct skcipher_tfm *tfm;
769 	struct sock *sk = sock->sk;
770 	struct alg_sock *ask = alg_sk(sk);
771 
772 	lock_sock(sk);
773 	if (ask->refcnt)
774 		goto unlock_child;
775 
776 	psk = ask->parent;
777 	pask = alg_sk(ask->parent);
778 	tfm = pask->private;
779 
780 	err = -ENOKEY;
781 	lock_sock_nested(psk, SINGLE_DEPTH_NESTING);
782 	if (!tfm->has_key)
783 		goto unlock;
784 
785 	if (!pask->refcnt++)
786 		sock_hold(psk);
787 
788 	ask->refcnt = 1;
789 	sock_put(psk);
790 
791 	err = 0;
792 
793 unlock:
794 	release_sock(psk);
795 unlock_child:
796 	release_sock(sk);
797 
798 	return err;
799 }
800 
801 static int skcipher_sendmsg_nokey(struct socket *sock, struct msghdr *msg,
802 				  size_t size)
803 {
804 	int err;
805 
806 	err = skcipher_check_key(sock);
807 	if (err)
808 		return err;
809 
810 	return skcipher_sendmsg(sock, msg, size);
811 }
812 
813 static ssize_t skcipher_sendpage_nokey(struct socket *sock, struct page *page,
814 				       int offset, size_t size, int flags)
815 {
816 	int err;
817 
818 	err = skcipher_check_key(sock);
819 	if (err)
820 		return err;
821 
822 	return skcipher_sendpage(sock, page, offset, size, flags);
823 }
824 
825 static int skcipher_recvmsg_nokey(struct socket *sock, struct msghdr *msg,
826 				  size_t ignored, int flags)
827 {
828 	int err;
829 
830 	err = skcipher_check_key(sock);
831 	if (err)
832 		return err;
833 
834 	return skcipher_recvmsg(sock, msg, ignored, flags);
835 }
836 
837 static struct proto_ops algif_skcipher_ops_nokey = {
838 	.family		=	PF_ALG,
839 
840 	.connect	=	sock_no_connect,
841 	.socketpair	=	sock_no_socketpair,
842 	.getname	=	sock_no_getname,
843 	.ioctl		=	sock_no_ioctl,
844 	.listen		=	sock_no_listen,
845 	.shutdown	=	sock_no_shutdown,
846 	.getsockopt	=	sock_no_getsockopt,
847 	.mmap		=	sock_no_mmap,
848 	.bind		=	sock_no_bind,
849 	.accept		=	sock_no_accept,
850 	.setsockopt	=	sock_no_setsockopt,
851 
852 	.release	=	af_alg_release,
853 	.sendmsg	=	skcipher_sendmsg_nokey,
854 	.sendpage	=	skcipher_sendpage_nokey,
855 	.recvmsg	=	skcipher_recvmsg_nokey,
856 	.poll		=	skcipher_poll,
857 };
858 
859 static void *skcipher_bind(const char *name, u32 type, u32 mask)
860 {
861 	struct skcipher_tfm *tfm;
862 	struct crypto_skcipher *skcipher;
863 
864 	tfm = kzalloc(sizeof(*tfm), GFP_KERNEL);
865 	if (!tfm)
866 		return ERR_PTR(-ENOMEM);
867 
868 	skcipher = crypto_alloc_skcipher(name, type, mask);
869 	if (IS_ERR(skcipher)) {
870 		kfree(tfm);
871 		return ERR_CAST(skcipher);
872 	}
873 
874 	tfm->skcipher = skcipher;
875 
876 	return tfm;
877 }
878 
879 static void skcipher_release(void *private)
880 {
881 	struct skcipher_tfm *tfm = private;
882 
883 	crypto_free_skcipher(tfm->skcipher);
884 	kfree(tfm);
885 }
886 
887 static int skcipher_setkey(void *private, const u8 *key, unsigned int keylen)
888 {
889 	struct skcipher_tfm *tfm = private;
890 	int err;
891 
892 	err = crypto_skcipher_setkey(tfm->skcipher, key, keylen);
893 	tfm->has_key = !err;
894 
895 	return err;
896 }
897 
898 static void skcipher_wait(struct sock *sk)
899 {
900 	struct alg_sock *ask = alg_sk(sk);
901 	struct skcipher_ctx *ctx = ask->private;
902 	int ctr = 0;
903 
904 	while (atomic_read(&ctx->inflight) && ctr++ < 100)
905 		msleep(100);
906 }
907 
908 static void skcipher_sock_destruct(struct sock *sk)
909 {
910 	struct alg_sock *ask = alg_sk(sk);
911 	struct skcipher_ctx *ctx = ask->private;
912 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(&ctx->req);
913 
914 	if (atomic_read(&ctx->inflight))
915 		skcipher_wait(sk);
916 
917 	skcipher_free_sgl(sk);
918 	sock_kzfree_s(sk, ctx->iv, crypto_skcipher_ivsize(tfm));
919 	sock_kfree_s(sk, ctx, ctx->len);
920 	af_alg_release_parent(sk);
921 }
922 
923 static int skcipher_accept_parent_nokey(void *private, struct sock *sk)
924 {
925 	struct skcipher_ctx *ctx;
926 	struct alg_sock *ask = alg_sk(sk);
927 	struct skcipher_tfm *tfm = private;
928 	struct crypto_skcipher *skcipher = tfm->skcipher;
929 	unsigned int len = sizeof(*ctx) + crypto_skcipher_reqsize(skcipher);
930 
931 	ctx = sock_kmalloc(sk, len, GFP_KERNEL);
932 	if (!ctx)
933 		return -ENOMEM;
934 
935 	ctx->iv = sock_kmalloc(sk, crypto_skcipher_ivsize(skcipher),
936 			       GFP_KERNEL);
937 	if (!ctx->iv) {
938 		sock_kfree_s(sk, ctx, len);
939 		return -ENOMEM;
940 	}
941 
942 	memset(ctx->iv, 0, crypto_skcipher_ivsize(skcipher));
943 
944 	INIT_LIST_HEAD(&ctx->tsgl);
945 	ctx->len = len;
946 	ctx->used = 0;
947 	ctx->more = 0;
948 	ctx->merge = 0;
949 	ctx->enc = 0;
950 	atomic_set(&ctx->inflight, 0);
951 	af_alg_init_completion(&ctx->completion);
952 
953 	ask->private = ctx;
954 
955 	skcipher_request_set_tfm(&ctx->req, skcipher);
956 	skcipher_request_set_callback(&ctx->req, CRYPTO_TFM_REQ_MAY_SLEEP |
957 						 CRYPTO_TFM_REQ_MAY_BACKLOG,
958 				      af_alg_complete, &ctx->completion);
959 
960 	sk->sk_destruct = skcipher_sock_destruct;
961 
962 	return 0;
963 }
964 
965 static int skcipher_accept_parent(void *private, struct sock *sk)
966 {
967 	struct skcipher_tfm *tfm = private;
968 
969 	if (!tfm->has_key && crypto_skcipher_has_setkey(tfm->skcipher))
970 		return -ENOKEY;
971 
972 	return skcipher_accept_parent_nokey(private, sk);
973 }
974 
975 static const struct af_alg_type algif_type_skcipher = {
976 	.bind		=	skcipher_bind,
977 	.release	=	skcipher_release,
978 	.setkey		=	skcipher_setkey,
979 	.accept		=	skcipher_accept_parent,
980 	.accept_nokey	=	skcipher_accept_parent_nokey,
981 	.ops		=	&algif_skcipher_ops,
982 	.ops_nokey	=	&algif_skcipher_ops_nokey,
983 	.name		=	"skcipher",
984 	.owner		=	THIS_MODULE
985 };
986 
987 static int __init algif_skcipher_init(void)
988 {
989 	return af_alg_register_type(&algif_type_skcipher);
990 }
991 
992 static void __exit algif_skcipher_exit(void)
993 {
994 	int err = af_alg_unregister_type(&algif_type_skcipher);
995 	BUG_ON(err);
996 }
997 
998 module_init(algif_skcipher_init);
999 module_exit(algif_skcipher_exit);
1000 MODULE_LICENSE("GPL");
1001