// SPDX-License-Identifier: (BSD-3-Clause OR GPL-2.0-only)
/* Copyright(c) 2014 - 2020 Intel Corporation */
#include <linux/module.h>
#include <crypto/internal/rsa.h>
#include <crypto/internal/akcipher.h>
#include <crypto/akcipher.h>
#include <crypto/kpp.h>
#include <crypto/internal/kpp.h>
#include <crypto/dh.h>
#include <linux/dma-mapping.h>
#include <linux/fips.h>
#include <crypto/scatterwalk.h>
#include "icp_qat_fw_pke.h"
#include "adf_accel_devices.h"
#include "qat_algs_send.h"
#include "adf_transport.h"
#include "adf_common_drv.h"
#include "qat_crypto.h"

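/* algs_lock protects active_devs, the count of accel devices using these algorithms */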
static DEFINE_MUTEX(algs_lock);
static unsigned int active_devs;

struct qat_rsa_input_params {
	union {
		struct {
			dma_addr_t m;
			dma_addr_t e;
			dma_addr_t n;
		} enc;
		struct {
			dma_addr_t c;
			dma_addr_t d;
			dma_addr_t n;
		} dec;
		struct {
			dma_addr_t c;
			dma_addr_t p;
			dma_addr_t q;
			dma_addr_t dp;
			dma_addr_t dq;
			dma_addr_t qinv;
		} dec_crt;
		u64 in_tab[8];
	};
} __packed __aligned(64);

struct qat_rsa_output_params {
	union {
		struct {
			dma_addr_t c;
		} enc;
		struct {
			dma_addr_t m;
		} dec;
		u64 out_tab[8];
	};
} __packed __aligned(64);

struct qat_rsa_ctx {
	char *n;
	char *e;
	char *d;
	char *p;
	char *q;
	char *dp;
	char *dq;
	char *qinv;
	dma_addr_t dma_n;
	dma_addr_t dma_e;
	dma_addr_t dma_d;
	dma_addr_t dma_p;
	dma_addr_t dma_q;
	dma_addr_t dma_dp;
	dma_addr_t dma_dq;
	dma_addr_t dma_qinv;
	unsigned int key_sz;
	bool crt_mode;
	struct qat_crypto_instance *inst;
} __packed __aligned(64);

struct qat_dh_input_params {
	union {
		struct {
			dma_addr_t b;
			dma_addr_t xa;
			dma_addr_t p;
		} in;
		struct {
			dma_addr_t xa;
			dma_addr_t p;
		} in_g2;
		u64 in_tab[8];
	};
} __packed __aligned(64);

struct qat_dh_output_params {
	union {
		dma_addr_t r;
		u64 out_tab[8];
	};
} __packed __aligned(64);

struct qat_dh_ctx {
	char *g;
	char *xa;
	char *p;
	dma_addr_t dma_g;
	dma_addr_t dma_xa;
	dma_addr_t dma_p;
	unsigned int p_size;
	bool g2;
	struct qat_crypto_instance *inst;
} __packed __aligned(64);

struct qat_asym_request {
	union {
		struct qat_rsa_input_params rsa;
		struct qat_dh_input_params dh;
	} in;
	union {
		struct qat_rsa_output_params rsa;
		struct qat_dh_output_params dh;
	} out;
	dma_addr_t phy_in;
	dma_addr_t phy_out;
	char *src_align;
	char *dst_align;
	struct icp_qat_fw_pke_request req;
	union {
		struct qat_rsa_ctx *rsa;
		struct qat_dh_ctx *dh;
	} ctx;
	union {
		struct akcipher_request *rsa;
		struct kpp_request *dh;
	} areq;
	int err;
	void (*cb)(struct icp_qat_fw_pke_resp *resp);
	struct qat_alg_req alg_req;
} __aligned(64);

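/*
 * Populate the transport-layer request around the firmware PKE message and
 * submit it on the instance's PKE TX ring via qat_alg_send_message(), which
 * may return -ENOSPC if the request cannot be queued.
 */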
static int qat_alg_send_asym_message(struct qat_asym_request *qat_req,
				     struct qat_crypto_instance *inst,
				     struct crypto_async_request *base)
{
	struct qat_alg_req *alg_req = &qat_req->alg_req;

	alg_req->fw_req = (u32 *)&qat_req->req;
	alg_req->tx_ring = inst->pke_tx;
	alg_req->base = base;
	alg_req->backlog = &inst->backlog;

	return qat_alg_send_message(alg_req);
}

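/*
 * Completion callback for DH requests: decode the firmware status, copy the
 * result out of the bounce buffer (if one was used), release all DMA
 * mappings and complete the kpp request.
 */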
static void qat_dh_cb(struct icp_qat_fw_pke_resp *resp)
{
	struct qat_asym_request *req = (void *)(__force long)resp->opaque;
	struct kpp_request *areq = req->areq.dh;
	struct device *dev = &GET_DEV(req->ctx.dh->inst->accel_dev);
	int err = ICP_QAT_FW_PKE_RESP_PKE_STAT_GET(
				resp->pke_resp_hdr.comn_resp_flags);

	err = (err == ICP_QAT_FW_COMN_STATUS_FLAG_OK) ? 0 : -EINVAL;

	if (areq->src) {
		dma_unmap_single(dev, req->in.dh.in.b, req->ctx.dh->p_size,
				 DMA_TO_DEVICE);
		kfree_sensitive(req->src_align);
	}

	areq->dst_len = req->ctx.dh->p_size;
	if (req->dst_align) {
		scatterwalk_map_and_copy(req->dst_align, areq->dst, 0,
					 areq->dst_len, 1);
		kfree_sensitive(req->dst_align);
	}

	dma_unmap_single(dev, req->out.dh.r, req->ctx.dh->p_size,
			 DMA_FROM_DEVICE);

	dma_unmap_single(dev, req->phy_in, sizeof(struct qat_dh_input_params),
			 DMA_TO_DEVICE);
	dma_unmap_single(dev, req->phy_out,
			 sizeof(struct qat_dh_output_params),
			 DMA_TO_DEVICE);

	kpp_request_complete(areq, err);
}

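/*
 * Firmware PKE function IDs for the DH service, one per supported modulus
 * size. The _G2_ variants select the path used when the generator is 2.
 */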
#define PKE_DH_1536 0x390c1a49
#define PKE_DH_G2_1536 0x2e0b1a3e
#define PKE_DH_2048 0x4d0c1a60
#define PKE_DH_G2_2048 0x3e0b1a55
#define PKE_DH_3072 0x510c1a77
#define PKE_DH_G2_3072 0x3a0b1a6c
#define PKE_DH_4096 0x690c1a8e
#define PKE_DH_G2_4096 0x4a0b1a83

static unsigned long qat_dh_fn_id(unsigned int len, bool g2)
{
	unsigned int bitslen = len << 3;

	switch (bitslen) {
	case 1536:
		return g2 ? PKE_DH_G2_1536 : PKE_DH_1536;
	case 2048:
		return g2 ? PKE_DH_G2_2048 : PKE_DH_2048;
	case 3072:
		return g2 ? PKE_DH_G2_3072 : PKE_DH_3072;
	case 4096:
		return g2 ? PKE_DH_G2_4096 : PKE_DH_4096;
	default:
		return 0;
	}
}

static int qat_dh_compute_value(struct kpp_request *req)
{
	struct crypto_kpp *tfm = crypto_kpp_reqtfm(req);
	struct qat_dh_ctx *ctx = kpp_tfm_ctx(tfm);
	struct qat_crypto_instance *inst = ctx->inst;
	struct device *dev = &GET_DEV(inst->accel_dev);
	struct qat_asym_request *qat_req =
			PTR_ALIGN(kpp_request_ctx(req), 64);
	struct icp_qat_fw_pke_request *msg = &qat_req->req;
	gfp_t flags = qat_algs_alloc_flags(&req->base);
	int n_input_params = 0;
	u8 *vaddr;
	int ret;

	if (unlikely(!ctx->xa))
		return -EINVAL;

	if (req->dst_len < ctx->p_size) {
		req->dst_len = ctx->p_size;
		return -EOVERFLOW;
	}

	if (req->src_len > ctx->p_size)
		return -EINVAL;

	memset(msg, '\0', sizeof(*msg));
	ICP_QAT_FW_PKE_HDR_VALID_FLAG_SET(msg->pke_hdr,
					  ICP_QAT_FW_COMN_REQ_FLAG_SET);

	msg->pke_hdr.cd_pars.func_id = qat_dh_fn_id(ctx->p_size,
						    !req->src && ctx->g2);
	if (unlikely(!msg->pke_hdr.cd_pars.func_id))
		return -EINVAL;

	qat_req->cb = qat_dh_cb;
	qat_req->ctx.dh = ctx;
	qat_req->areq.dh = req;
	msg->pke_hdr.service_type = ICP_QAT_FW_COMN_REQ_CPM_FW_PKE;
	msg->pke_hdr.comn_req_flags =
		ICP_QAT_FW_COMN_FLAGS_BUILD(QAT_COMN_PTR_TYPE_FLAT,
					    QAT_COMN_CD_FLD_TYPE_64BIT_ADR);

	/*
	 * If no source is provided, use g as the base.
	 */
	if (req->src) {
		qat_req->in.dh.in.xa = ctx->dma_xa;
		qat_req->in.dh.in.p = ctx->dma_p;
		n_input_params = 3;
	} else {
		if (ctx->g2) {
			qat_req->in.dh.in_g2.xa = ctx->dma_xa;
			qat_req->in.dh.in_g2.p = ctx->dma_p;
			n_input_params = 2;
		} else {
			qat_req->in.dh.in.b = ctx->dma_g;
			qat_req->in.dh.in.xa = ctx->dma_xa;
			qat_req->in.dh.in.p = ctx->dma_p;
			n_input_params = 3;
		}
	}

	ret = -ENOMEM;
	if (req->src) {
		/*
		 * src can be any size in the valid range, but the HW expects
		 * it to be the same size as the modulus p. If the sizes
		 * differ, allocate a new buffer and copy the src data into
		 * it; otherwise just map the user-provided buffer, which
		 * must be contiguous.
		 */
		if (sg_is_last(req->src) && req->src_len == ctx->p_size) {
			qat_req->src_align = NULL;
			vaddr = sg_virt(req->src);
		} else {
			int shift = ctx->p_size - req->src_len;

			qat_req->src_align = kzalloc(ctx->p_size, flags);
			if (unlikely(!qat_req->src_align))
				return ret;

			scatterwalk_map_and_copy(qat_req->src_align + shift,
						 req->src, 0, req->src_len, 0);

			vaddr = qat_req->src_align;
		}

		qat_req->in.dh.in.b = dma_map_single(dev, vaddr, ctx->p_size,
						     DMA_TO_DEVICE);
		if (unlikely(dma_mapping_error(dev, qat_req->in.dh.in.b)))
			goto unmap_src;
	}
	/*
	 * dst can be any size in the valid range, but the HW expects it to
	 * be the same size as the modulus p. If the sizes differ, allocate
	 * a new buffer; the result is copied back to dst on completion.
	 * The buffer handed to the HW must be contiguous.
	 */
	if (sg_is_last(req->dst) && req->dst_len == ctx->p_size) {
		qat_req->dst_align = NULL;
		vaddr = sg_virt(req->dst);
	} else {
		qat_req->dst_align = kzalloc(ctx->p_size, flags);
		if (unlikely(!qat_req->dst_align))
			goto unmap_src;

		vaddr = qat_req->dst_align;
	}
	qat_req->out.dh.r = dma_map_single(dev, vaddr, ctx->p_size,
					   DMA_FROM_DEVICE);
	if (unlikely(dma_mapping_error(dev, qat_req->out.dh.r)))
		goto unmap_dst;

	qat_req->in.dh.in_tab[n_input_params] = 0;
	qat_req->out.dh.out_tab[1] = 0;
	/* in.in.b and in.in_g2.xa alias in the union, so mapping either is equivalent */
	qat_req->phy_in = dma_map_single(dev, &qat_req->in.dh,
					 sizeof(struct qat_dh_input_params),
					 DMA_TO_DEVICE);
	if (unlikely(dma_mapping_error(dev, qat_req->phy_in)))
		goto unmap_dst;

	qat_req->phy_out = dma_map_single(dev, &qat_req->out.dh,
					  sizeof(struct qat_dh_output_params),
					  DMA_TO_DEVICE);
	if (unlikely(dma_mapping_error(dev, qat_req->phy_out)))
		goto unmap_in_params;

	msg->pke_mid.src_data_addr = qat_req->phy_in;
	msg->pke_mid.dest_data_addr = qat_req->phy_out;
	msg->pke_mid.opaque = (u64)(__force long)qat_req;
	msg->input_param_count = n_input_params;
	msg->output_param_count = 1;

	ret = qat_alg_send_asym_message(qat_req, inst, &req->base);
	if (ret == -ENOSPC)
		goto unmap_all;

	return ret;

unmap_all:
	if (!dma_mapping_error(dev, qat_req->phy_out))
		dma_unmap_single(dev, qat_req->phy_out,
				 sizeof(struct qat_dh_output_params),
				 DMA_TO_DEVICE);
unmap_in_params:
	if (!dma_mapping_error(dev, qat_req->phy_in))
		dma_unmap_single(dev, qat_req->phy_in,
				 sizeof(struct qat_dh_input_params),
				 DMA_TO_DEVICE);
unmap_dst:
	if (!dma_mapping_error(dev, qat_req->out.dh.r))
		dma_unmap_single(dev, qat_req->out.dh.r, ctx->p_size,
				 DMA_FROM_DEVICE);
	kfree_sensitive(qat_req->dst_align);
unmap_src:
	if (req->src) {
		if (!dma_mapping_error(dev, qat_req->in.dh.in.b))
			dma_unmap_single(dev, qat_req->in.dh.in.b,
					 ctx->p_size,
					 DMA_TO_DEVICE);
		kfree_sensitive(qat_req->src_align);
	}
	return ret;
}

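/*
 * Only 1536, 2048, 3072 and 4096 bit moduli are supported; these are the
 * sizes for which qat_dh_fn_id() has a firmware function ID.
 */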
static int qat_dh_check_params_length(unsigned int p_len)
{
	switch (p_len) {
	case 1536:
	case 2048:
	case 3072:
	case 4096:
		return 0;
	}
	return -EINVAL;
}

static int qat_dh_set_params(struct qat_dh_ctx *ctx, struct dh *params)
{
	struct qat_crypto_instance *inst = ctx->inst;
	struct device *dev = &GET_DEV(inst->accel_dev);

	if (qat_dh_check_params_length(params->p_size << 3))
		return -EINVAL;

	ctx->p_size = params->p_size;
	ctx->p = dma_alloc_coherent(dev, ctx->p_size, &ctx->dma_p, GFP_KERNEL);
	if (!ctx->p)
		return -ENOMEM;
	memcpy(ctx->p, params->p, ctx->p_size);

	/* If g equals 2, don't copy it: the optimized g2 function can be used instead */
	if (params->g_size == 1 && *(char *)params->g == 0x02) {
		ctx->g2 = true;
		return 0;
	}

	ctx->g = dma_alloc_coherent(dev, ctx->p_size, &ctx->dma_g, GFP_KERNEL);
	if (!ctx->g)
		return -ENOMEM;
	memcpy(ctx->g + (ctx->p_size - params->g_size), params->g,
	       params->g_size);

	return 0;
}

static void qat_dh_clear_ctx(struct device *dev, struct qat_dh_ctx *ctx)
{
	if (ctx->g) {
		memset(ctx->g, 0, ctx->p_size);
		dma_free_coherent(dev, ctx->p_size, ctx->g, ctx->dma_g);
		ctx->g = NULL;
	}
	if (ctx->xa) {
		memset(ctx->xa, 0, ctx->p_size);
		dma_free_coherent(dev, ctx->p_size, ctx->xa, ctx->dma_xa);
		ctx->xa = NULL;
	}
	if (ctx->p) {
		memset(ctx->p, 0, ctx->p_size);
		dma_free_coherent(dev, ctx->p_size, ctx->p, ctx->dma_p);
		ctx->p = NULL;
	}
	ctx->p_size = 0;
	ctx->g2 = false;
}

static int qat_dh_set_secret(struct crypto_kpp *tfm, const void *buf,
			     unsigned int len)
{
	struct qat_dh_ctx *ctx = kpp_tfm_ctx(tfm);
	struct device *dev = &GET_DEV(ctx->inst->accel_dev);
	struct dh params;
	int ret;

	if (crypto_dh_decode_key(buf, len, &params) < 0)
		return -EINVAL;

	/* Free old secret if any */
	qat_dh_clear_ctx(dev, ctx);

	ret = qat_dh_set_params(ctx, &params);
	if (ret < 0)
		goto err_clear_ctx;

	ctx->xa = dma_alloc_coherent(dev, ctx->p_size, &ctx->dma_xa,
				     GFP_KERNEL);
	if (!ctx->xa) {
		ret = -ENOMEM;
		goto err_clear_ctx;
	}
	memcpy(ctx->xa + (ctx->p_size - params.key_size), params.key,
	       params.key_size);

	return 0;

err_clear_ctx:
	qat_dh_clear_ctx(dev, ctx);
	return ret;
}

static unsigned int qat_dh_max_size(struct crypto_kpp *tfm)
{
	struct qat_dh_ctx *ctx = kpp_tfm_ctx(tfm);

	return ctx->p_size;
}

static int qat_dh_init_tfm(struct crypto_kpp *tfm)
{
	struct qat_dh_ctx *ctx = kpp_tfm_ctx(tfm);
	struct qat_crypto_instance *inst =
			qat_crypto_get_instance_node(numa_node_id());

	if (!inst)
		return -EINVAL;

	kpp_set_reqsize(tfm, sizeof(struct qat_asym_request) + 64);

	ctx->p_size = 0;
	ctx->g2 = false;
	ctx->inst = inst;
	return 0;
}

static void qat_dh_exit_tfm(struct crypto_kpp *tfm)
{
	struct qat_dh_ctx *ctx = kpp_tfm_ctx(tfm);
	struct device *dev = &GET_DEV(ctx->inst->accel_dev);

	qat_dh_clear_ctx(dev, ctx);
	qat_crypto_put_instance(ctx->inst);
}

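/*
 * Completion callback for RSA requests. The input/output unions overlay the
 * same storage, so unmapping through the enc layout also covers the dec case.
 */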
static void qat_rsa_cb(struct icp_qat_fw_pke_resp *resp)
{
	struct qat_asym_request *req = (void *)(__force long)resp->opaque;
	struct akcipher_request *areq = req->areq.rsa;
	struct device *dev = &GET_DEV(req->ctx.rsa->inst->accel_dev);
	int err = ICP_QAT_FW_PKE_RESP_PKE_STAT_GET(
				resp->pke_resp_hdr.comn_resp_flags);

	err = (err == ICP_QAT_FW_COMN_STATUS_FLAG_OK) ? 0 : -EINVAL;

	kfree_sensitive(req->src_align);

	dma_unmap_single(dev, req->in.rsa.enc.m, req->ctx.rsa->key_sz,
			 DMA_TO_DEVICE);

	areq->dst_len = req->ctx.rsa->key_sz;
	if (req->dst_align) {
		scatterwalk_map_and_copy(req->dst_align, areq->dst, 0,
					 areq->dst_len, 1);

		kfree_sensitive(req->dst_align);
	}

	dma_unmap_single(dev, req->out.rsa.enc.c, req->ctx.rsa->key_sz,
			 DMA_FROM_DEVICE);

	dma_unmap_single(dev, req->phy_in, sizeof(struct qat_rsa_input_params),
			 DMA_TO_DEVICE);
	dma_unmap_single(dev, req->phy_out,
			 sizeof(struct qat_rsa_output_params),
			 DMA_TO_DEVICE);

	akcipher_request_complete(areq, err);
}

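/*
 * Ring-level response handler: dispatch to the per-request callback, then
 * try to resubmit any requests queued on the instance backlog.
 */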
void qat_alg_asym_callback(void *_resp)
{
	struct icp_qat_fw_pke_resp *resp = _resp;
	struct qat_asym_request *areq = (void *)(__force long)resp->opaque;
	struct qat_instance_backlog *backlog = areq->alg_req.backlog;

	areq->cb(resp);

	qat_alg_send_backlog(backlog);
}

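/*
 * Firmware PKE function IDs for RSA, per key size: EP is public-key encrypt,
 * DP1 is private-key decrypt and DP2 is private-key decrypt using CRT.
 */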
#define PKE_RSA_EP_512 0x1c161b21
#define PKE_RSA_EP_1024 0x35111bf7
#define PKE_RSA_EP_1536 0x4d111cdc
#define PKE_RSA_EP_2048 0x6e111dba
#define PKE_RSA_EP_3072 0x7d111ea3
#define PKE_RSA_EP_4096 0xa5101f7e

static unsigned long qat_rsa_enc_fn_id(unsigned int len)
{
	unsigned int bitslen = len << 3;

	switch (bitslen) {
	case 512:
		return PKE_RSA_EP_512;
	case 1024:
		return PKE_RSA_EP_1024;
	case 1536:
		return PKE_RSA_EP_1536;
	case 2048:
		return PKE_RSA_EP_2048;
	case 3072:
		return PKE_RSA_EP_3072;
	case 4096:
		return PKE_RSA_EP_4096;
	default:
		return 0;
	}
}

#define PKE_RSA_DP1_512 0x1c161b3c
#define PKE_RSA_DP1_1024 0x35111c12
#define PKE_RSA_DP1_1536 0x4d111cf7
#define PKE_RSA_DP1_2048 0x6e111dda
#define PKE_RSA_DP1_3072 0x7d111ebe
#define PKE_RSA_DP1_4096 0xa5101f98

static unsigned long qat_rsa_dec_fn_id(unsigned int len)
{
	unsigned int bitslen = len << 3;

	switch (bitslen) {
	case 512:
		return PKE_RSA_DP1_512;
	case 1024:
		return PKE_RSA_DP1_1024;
	case 1536:
		return PKE_RSA_DP1_1536;
	case 2048:
		return PKE_RSA_DP1_2048;
	case 3072:
		return PKE_RSA_DP1_3072;
	case 4096:
		return PKE_RSA_DP1_4096;
	default:
		return 0;
	}
}

#define PKE_RSA_DP2_512 0x1c131b57
#define PKE_RSA_DP2_1024 0x26131c2d
#define PKE_RSA_DP2_1536 0x45111d12
#define PKE_RSA_DP2_2048 0x59121dfa
#define PKE_RSA_DP2_3072 0x81121ed9
#define PKE_RSA_DP2_4096 0xb1111fb2

static unsigned long qat_rsa_dec_fn_id_crt(unsigned int len)
{
	unsigned int bitslen = len << 3;

	switch (bitslen) {
	case 512:
		return PKE_RSA_DP2_512;
	case 1024:
		return PKE_RSA_DP2_1024;
	case 1536:
		return PKE_RSA_DP2_1536;
	case 2048:
		return PKE_RSA_DP2_2048;
	case 3072:
		return PKE_RSA_DP2_3072;
	case 4096:
		return PKE_RSA_DP2_4096;
	default:
		return 0;
	}
}

static int qat_rsa_enc(struct akcipher_request *req)
{
	struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req);
	struct qat_rsa_ctx *ctx = akcipher_tfm_ctx(tfm);
	struct qat_crypto_instance *inst = ctx->inst;
	struct device *dev = &GET_DEV(inst->accel_dev);
	struct qat_asym_request *qat_req =
			PTR_ALIGN(akcipher_request_ctx(req), 64);
	struct icp_qat_fw_pke_request *msg = &qat_req->req;
	gfp_t flags = qat_algs_alloc_flags(&req->base);
	u8 *vaddr;
	int ret;

	if (unlikely(!ctx->n || !ctx->e))
		return -EINVAL;

	if (req->dst_len < ctx->key_sz) {
		req->dst_len = ctx->key_sz;
		return -EOVERFLOW;
	}

	if (req->src_len > ctx->key_sz)
		return -EINVAL;

	memset(msg, '\0', sizeof(*msg));
	ICP_QAT_FW_PKE_HDR_VALID_FLAG_SET(msg->pke_hdr,
					  ICP_QAT_FW_COMN_REQ_FLAG_SET);
	msg->pke_hdr.cd_pars.func_id = qat_rsa_enc_fn_id(ctx->key_sz);
	if (unlikely(!msg->pke_hdr.cd_pars.func_id))
		return -EINVAL;

	qat_req->cb = qat_rsa_cb;
	qat_req->ctx.rsa = ctx;
	qat_req->areq.rsa = req;
	msg->pke_hdr.service_type = ICP_QAT_FW_COMN_REQ_CPM_FW_PKE;
	msg->pke_hdr.comn_req_flags =
		ICP_QAT_FW_COMN_FLAGS_BUILD(QAT_COMN_PTR_TYPE_FLAT,
					    QAT_COMN_CD_FLD_TYPE_64BIT_ADR);

	qat_req->in.rsa.enc.e = ctx->dma_e;
	qat_req->in.rsa.enc.n = ctx->dma_n;
	ret = -ENOMEM;

	/*
	 * src can be any size in the valid range, but the HW expects it to
	 * be the same size as the modulus n. If the sizes differ, allocate
	 * a new buffer and copy the src data into it; otherwise just map
	 * the user-provided buffer, which must be contiguous.
	 */
	if (sg_is_last(req->src) && req->src_len == ctx->key_sz) {
		qat_req->src_align = NULL;
		vaddr = sg_virt(req->src);
	} else {
		int shift = ctx->key_sz - req->src_len;

		qat_req->src_align = kzalloc(ctx->key_sz, flags);
		if (unlikely(!qat_req->src_align))
			return ret;

		scatterwalk_map_and_copy(qat_req->src_align + shift, req->src,
					 0, req->src_len, 0);
		vaddr = qat_req->src_align;
	}

	qat_req->in.rsa.enc.m = dma_map_single(dev, vaddr, ctx->key_sz,
					       DMA_TO_DEVICE);
	if (unlikely(dma_mapping_error(dev, qat_req->in.rsa.enc.m)))
		goto unmap_src;

	if (sg_is_last(req->dst) && req->dst_len == ctx->key_sz) {
		qat_req->dst_align = NULL;
		vaddr = sg_virt(req->dst);
	} else {
		qat_req->dst_align = kzalloc(ctx->key_sz, flags);
		if (unlikely(!qat_req->dst_align))
			goto unmap_src;
		vaddr = qat_req->dst_align;
	}

	qat_req->out.rsa.enc.c = dma_map_single(dev, vaddr, ctx->key_sz,
						DMA_FROM_DEVICE);
	if (unlikely(dma_mapping_error(dev, qat_req->out.rsa.enc.c)))
		goto unmap_dst;

	qat_req->in.rsa.in_tab[3] = 0;
	qat_req->out.rsa.out_tab[1] = 0;
	qat_req->phy_in = dma_map_single(dev, &qat_req->in.rsa,
					 sizeof(struct qat_rsa_input_params),
					 DMA_TO_DEVICE);
	if (unlikely(dma_mapping_error(dev, qat_req->phy_in)))
		goto unmap_dst;

	qat_req->phy_out = dma_map_single(dev, &qat_req->out.rsa,
					  sizeof(struct qat_rsa_output_params),
					  DMA_TO_DEVICE);
	if (unlikely(dma_mapping_error(dev, qat_req->phy_out)))
		goto unmap_in_params;

	msg->pke_mid.src_data_addr = qat_req->phy_in;
	msg->pke_mid.dest_data_addr = qat_req->phy_out;
	msg->pke_mid.opaque = (u64)(__force long)qat_req;
	msg->input_param_count = 3;
	msg->output_param_count = 1;

	ret = qat_alg_send_asym_message(qat_req, inst, &req->base);
	if (ret == -ENOSPC)
		goto unmap_all;

	return ret;

unmap_all:
	if (!dma_mapping_error(dev, qat_req->phy_out))
		dma_unmap_single(dev, qat_req->phy_out,
				 sizeof(struct qat_rsa_output_params),
				 DMA_TO_DEVICE);
unmap_in_params:
	if (!dma_mapping_error(dev, qat_req->phy_in))
		dma_unmap_single(dev, qat_req->phy_in,
				 sizeof(struct qat_rsa_input_params),
				 DMA_TO_DEVICE);
unmap_dst:
	if (!dma_mapping_error(dev, qat_req->out.rsa.enc.c))
		dma_unmap_single(dev, qat_req->out.rsa.enc.c,
				 ctx->key_sz, DMA_FROM_DEVICE);
	kfree_sensitive(qat_req->dst_align);
unmap_src:
	if (!dma_mapping_error(dev, qat_req->in.rsa.enc.m))
		dma_unmap_single(dev, qat_req->in.rsa.enc.m, ctx->key_sz,
				 DMA_TO_DEVICE);
	kfree_sensitive(qat_req->src_align);
	return ret;
}

static int qat_rsa_dec(struct akcipher_request *req)
{
	struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req);
	struct qat_rsa_ctx *ctx = akcipher_tfm_ctx(tfm);
	struct qat_crypto_instance *inst = ctx->inst;
	struct device *dev = &GET_DEV(inst->accel_dev);
	struct qat_asym_request *qat_req =
			PTR_ALIGN(akcipher_request_ctx(req), 64);
	struct icp_qat_fw_pke_request *msg = &qat_req->req;
	gfp_t flags = qat_algs_alloc_flags(&req->base);
	u8 *vaddr;
	int ret;

	if (unlikely(!ctx->n || !ctx->d))
		return -EINVAL;

	if (req->dst_len < ctx->key_sz) {
		req->dst_len = ctx->key_sz;
		return -EOVERFLOW;
	}

	if (req->src_len > ctx->key_sz)
		return -EINVAL;

	memset(msg, '\0', sizeof(*msg));
	ICP_QAT_FW_PKE_HDR_VALID_FLAG_SET(msg->pke_hdr,
					  ICP_QAT_FW_COMN_REQ_FLAG_SET);
	msg->pke_hdr.cd_pars.func_id = ctx->crt_mode ?
		qat_rsa_dec_fn_id_crt(ctx->key_sz) :
		qat_rsa_dec_fn_id(ctx->key_sz);
	if (unlikely(!msg->pke_hdr.cd_pars.func_id))
		return -EINVAL;

	qat_req->cb = qat_rsa_cb;
	qat_req->ctx.rsa = ctx;
	qat_req->areq.rsa = req;
	msg->pke_hdr.service_type = ICP_QAT_FW_COMN_REQ_CPM_FW_PKE;
	msg->pke_hdr.comn_req_flags =
		ICP_QAT_FW_COMN_FLAGS_BUILD(QAT_COMN_PTR_TYPE_FLAT,
					    QAT_COMN_CD_FLD_TYPE_64BIT_ADR);

	if (ctx->crt_mode) {
		qat_req->in.rsa.dec_crt.p = ctx->dma_p;
		qat_req->in.rsa.dec_crt.q = ctx->dma_q;
		qat_req->in.rsa.dec_crt.dp = ctx->dma_dp;
		qat_req->in.rsa.dec_crt.dq = ctx->dma_dq;
		qat_req->in.rsa.dec_crt.qinv = ctx->dma_qinv;
	} else {
		qat_req->in.rsa.dec.d = ctx->dma_d;
		qat_req->in.rsa.dec.n = ctx->dma_n;
	}
	ret = -ENOMEM;

	/*
	 * src can be any size in the valid range, but the HW expects it to
	 * be the same size as the modulus n. If the sizes differ, allocate
	 * a new buffer and copy the src data into it; otherwise just map
	 * the user-provided buffer, which must be contiguous.
	 */
	if (sg_is_last(req->src) && req->src_len == ctx->key_sz) {
		qat_req->src_align = NULL;
		vaddr = sg_virt(req->src);
	} else {
		int shift = ctx->key_sz - req->src_len;

		qat_req->src_align = kzalloc(ctx->key_sz, flags);
		if (unlikely(!qat_req->src_align))
			return ret;

		scatterwalk_map_and_copy(qat_req->src_align + shift, req->src,
					 0, req->src_len, 0);
		vaddr = qat_req->src_align;
	}

	qat_req->in.rsa.dec.c = dma_map_single(dev, vaddr, ctx->key_sz,
					       DMA_TO_DEVICE);
	if (unlikely(dma_mapping_error(dev, qat_req->in.rsa.dec.c)))
		goto unmap_src;

	if (sg_is_last(req->dst) && req->dst_len == ctx->key_sz) {
		qat_req->dst_align = NULL;
		vaddr = sg_virt(req->dst);
	} else {
		qat_req->dst_align = kzalloc(ctx->key_sz, flags);
		if (unlikely(!qat_req->dst_align))
			goto unmap_src;
		vaddr = qat_req->dst_align;
	}
	qat_req->out.rsa.dec.m = dma_map_single(dev, vaddr, ctx->key_sz,
						DMA_FROM_DEVICE);
	if (unlikely(dma_mapping_error(dev, qat_req->out.rsa.dec.m)))
		goto unmap_dst;

	if (ctx->crt_mode)
		qat_req->in.rsa.in_tab[6] = 0;
	else
		qat_req->in.rsa.in_tab[3] = 0;
	qat_req->out.rsa.out_tab[1] = 0;
	qat_req->phy_in = dma_map_single(dev, &qat_req->in.rsa,
					 sizeof(struct qat_rsa_input_params),
					 DMA_TO_DEVICE);
	if (unlikely(dma_mapping_error(dev, qat_req->phy_in)))
		goto unmap_dst;

	qat_req->phy_out = dma_map_single(dev, &qat_req->out.rsa,
					  sizeof(struct qat_rsa_output_params),
					  DMA_TO_DEVICE);
	if (unlikely(dma_mapping_error(dev, qat_req->phy_out)))
		goto unmap_in_params;

	msg->pke_mid.src_data_addr = qat_req->phy_in;
	msg->pke_mid.dest_data_addr = qat_req->phy_out;
	msg->pke_mid.opaque = (u64)(__force long)qat_req;
	if (ctx->crt_mode)
		msg->input_param_count = 6;
	else
		msg->input_param_count = 3;

	msg->output_param_count = 1;

	ret = qat_alg_send_asym_message(qat_req, inst, &req->base);
	if (ret == -ENOSPC)
		goto unmap_all;

	return ret;

unmap_all:
	if (!dma_mapping_error(dev, qat_req->phy_out))
		dma_unmap_single(dev, qat_req->phy_out,
				 sizeof(struct qat_rsa_output_params),
				 DMA_TO_DEVICE);
unmap_in_params:
	if (!dma_mapping_error(dev, qat_req->phy_in))
		dma_unmap_single(dev, qat_req->phy_in,
				 sizeof(struct qat_rsa_input_params),
				 DMA_TO_DEVICE);
unmap_dst:
	if (!dma_mapping_error(dev, qat_req->out.rsa.dec.m))
		dma_unmap_single(dev, qat_req->out.rsa.dec.m,
				 ctx->key_sz, DMA_FROM_DEVICE);
	kfree_sensitive(qat_req->dst_align);
unmap_src:
	if (!dma_mapping_error(dev, qat_req->in.rsa.dec.c))
		dma_unmap_single(dev, qat_req->in.rsa.dec.c, ctx->key_sz,
				 DMA_TO_DEVICE);
	kfree_sensitive(qat_req->src_align);
	return ret;
}

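/*
 * The qat_rsa_set_* helpers below strip leading zeros from the key
 * components, validate their lengths and stage them in DMA coherent
 * memory for the firmware.
 */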
static int qat_rsa_set_n(struct qat_rsa_ctx *ctx, const char *value,
			 size_t vlen)
{
	struct qat_crypto_instance *inst = ctx->inst;
	struct device *dev = &GET_DEV(inst->accel_dev);
	const char *ptr = value;
	int ret;

	while (!*ptr && vlen) {
		ptr++;
		vlen--;
	}

	ctx->key_sz = vlen;
	ret = -EINVAL;
	/* invalid key size provided */
	if (!qat_rsa_enc_fn_id(ctx->key_sz))
		goto err;

	ret = -ENOMEM;
	ctx->n = dma_alloc_coherent(dev, ctx->key_sz, &ctx->dma_n, GFP_KERNEL);
	if (!ctx->n)
		goto err;

	memcpy(ctx->n, ptr, ctx->key_sz);
	return 0;
err:
	ctx->key_sz = 0;
	ctx->n = NULL;
	return ret;
}

static int qat_rsa_set_e(struct qat_rsa_ctx *ctx, const char *value,
			 size_t vlen)
{
	struct qat_crypto_instance *inst = ctx->inst;
	struct device *dev = &GET_DEV(inst->accel_dev);
	const char *ptr = value;

	while (!*ptr && vlen) {
		ptr++;
		vlen--;
	}

	if (!ctx->key_sz || !vlen || vlen > ctx->key_sz) {
		ctx->e = NULL;
		return -EINVAL;
	}

	ctx->e = dma_alloc_coherent(dev, ctx->key_sz, &ctx->dma_e, GFP_KERNEL);
	if (!ctx->e)
		return -ENOMEM;

	memcpy(ctx->e + (ctx->key_sz - vlen), ptr, vlen);
	return 0;
}

static int qat_rsa_set_d(struct qat_rsa_ctx *ctx, const char *value,
			 size_t vlen)
{
	struct qat_crypto_instance *inst = ctx->inst;
	struct device *dev = &GET_DEV(inst->accel_dev);
	const char *ptr = value;
	int ret;

	while (!*ptr && vlen) {
		ptr++;
		vlen--;
	}

	ret = -EINVAL;
	if (!ctx->key_sz || !vlen || vlen > ctx->key_sz)
		goto err;

	ret = -ENOMEM;
	ctx->d = dma_alloc_coherent(dev, ctx->key_sz, &ctx->dma_d, GFP_KERNEL);
	if (!ctx->d)
		goto err;

	memcpy(ctx->d + (ctx->key_sz - vlen), ptr, vlen);
	return 0;
err:
	ctx->d = NULL;
	return ret;
}

static void qat_rsa_drop_leading_zeros(const char **ptr, unsigned int *len)
{
	while (!**ptr && *len) {
		(*ptr)++;
		(*len)--;
	}
}

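/*
 * Cache the CRT components (p, q, dp, dq, qinv) in DMA coherent memory.
 * On any failure everything allocated so far is zeroized and freed, and
 * crt_mode is left disabled, so decryption falls back to the non-CRT path.
 */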
static void qat_rsa_setkey_crt(struct qat_rsa_ctx *ctx, struct rsa_key *rsa_key)
{
	struct qat_crypto_instance *inst = ctx->inst;
	struct device *dev = &GET_DEV(inst->accel_dev);
	const char *ptr;
	unsigned int len;
	unsigned int half_key_sz = ctx->key_sz / 2;

	/* p */
	ptr = rsa_key->p;
	len = rsa_key->p_sz;
	qat_rsa_drop_leading_zeros(&ptr, &len);
	if (!len)
		goto err;
	ctx->p = dma_alloc_coherent(dev, half_key_sz, &ctx->dma_p, GFP_KERNEL);
	if (!ctx->p)
		goto err;
	memcpy(ctx->p + (half_key_sz - len), ptr, len);

	/* q */
	ptr = rsa_key->q;
	len = rsa_key->q_sz;
	qat_rsa_drop_leading_zeros(&ptr, &len);
	if (!len)
		goto free_p;
	ctx->q = dma_alloc_coherent(dev, half_key_sz, &ctx->dma_q, GFP_KERNEL);
	if (!ctx->q)
		goto free_p;
	memcpy(ctx->q + (half_key_sz - len), ptr, len);

	/* dp */
	ptr = rsa_key->dp;
	len = rsa_key->dp_sz;
	qat_rsa_drop_leading_zeros(&ptr, &len);
	if (!len)
		goto free_q;
	ctx->dp = dma_alloc_coherent(dev, half_key_sz, &ctx->dma_dp,
				     GFP_KERNEL);
	if (!ctx->dp)
		goto free_q;
	memcpy(ctx->dp + (half_key_sz - len), ptr, len);

	/* dq */
	ptr = rsa_key->dq;
	len = rsa_key->dq_sz;
	qat_rsa_drop_leading_zeros(&ptr, &len);
	if (!len)
		goto free_dp;
	ctx->dq = dma_alloc_coherent(dev, half_key_sz, &ctx->dma_dq,
				     GFP_KERNEL);
	if (!ctx->dq)
		goto free_dp;
	memcpy(ctx->dq + (half_key_sz - len), ptr, len);

	/* qinv */
	ptr = rsa_key->qinv;
	len = rsa_key->qinv_sz;
	qat_rsa_drop_leading_zeros(&ptr, &len);
	if (!len)
		goto free_dq;
	ctx->qinv = dma_alloc_coherent(dev, half_key_sz, &ctx->dma_qinv,
				       GFP_KERNEL);
	if (!ctx->qinv)
		goto free_dq;
	memcpy(ctx->qinv + (half_key_sz - len), ptr, len);

	ctx->crt_mode = true;
	return;

free_dq:
	memset(ctx->dq, '\0', half_key_sz);
	dma_free_coherent(dev, half_key_sz, ctx->dq, ctx->dma_dq);
	ctx->dq = NULL;
free_dp:
	memset(ctx->dp, '\0', half_key_sz);
	dma_free_coherent(dev, half_key_sz, ctx->dp, ctx->dma_dp);
	ctx->dp = NULL;
free_q:
	memset(ctx->q, '\0', half_key_sz);
	dma_free_coherent(dev, half_key_sz, ctx->q, ctx->dma_q);
	ctx->q = NULL;
free_p:
	memset(ctx->p, '\0', half_key_sz);
	dma_free_coherent(dev, half_key_sz, ctx->p, ctx->dma_p);
	ctx->p = NULL;
err:
	ctx->crt_mode = false;
}

static void qat_rsa_clear_ctx(struct device *dev, struct qat_rsa_ctx *ctx)
{
	unsigned int half_key_sz = ctx->key_sz / 2;

	/* Free the old key if any */
	if (ctx->n)
		dma_free_coherent(dev, ctx->key_sz, ctx->n, ctx->dma_n);
	if (ctx->e)
		dma_free_coherent(dev, ctx->key_sz, ctx->e, ctx->dma_e);
	if (ctx->d) {
		memset(ctx->d, '\0', ctx->key_sz);
		dma_free_coherent(dev, ctx->key_sz, ctx->d, ctx->dma_d);
	}
	if (ctx->p) {
		memset(ctx->p, '\0', half_key_sz);
		dma_free_coherent(dev, half_key_sz, ctx->p, ctx->dma_p);
	}
	if (ctx->q) {
		memset(ctx->q, '\0', half_key_sz);
		dma_free_coherent(dev, half_key_sz, ctx->q, ctx->dma_q);
	}
	if (ctx->dp) {
		memset(ctx->dp, '\0', half_key_sz);
		dma_free_coherent(dev, half_key_sz, ctx->dp, ctx->dma_dp);
	}
	if (ctx->dq) {
		memset(ctx->dq, '\0', half_key_sz);
		dma_free_coherent(dev, half_key_sz, ctx->dq, ctx->dma_dq);
	}
	if (ctx->qinv) {
		memset(ctx->qinv, '\0', half_key_sz);
		dma_free_coherent(dev, half_key_sz, ctx->qinv, ctx->dma_qinv);
	}

	ctx->n = NULL;
	ctx->e = NULL;
	ctx->d = NULL;
	ctx->p = NULL;
	ctx->q = NULL;
	ctx->dp = NULL;
	ctx->dq = NULL;
	ctx->qinv = NULL;
	ctx->crt_mode = false;
	ctx->key_sz = 0;
}

static int qat_rsa_setkey(struct crypto_akcipher *tfm, const void *key,
			  unsigned int keylen, bool private)
{
	struct qat_rsa_ctx *ctx = akcipher_tfm_ctx(tfm);
	struct device *dev = &GET_DEV(ctx->inst->accel_dev);
	struct rsa_key rsa_key;
	int ret;

	qat_rsa_clear_ctx(dev, ctx);

	if (private)
		ret = rsa_parse_priv_key(&rsa_key, key, keylen);
	else
		ret = rsa_parse_pub_key(&rsa_key, key, keylen);
	if (ret < 0)
		goto free;

	ret = qat_rsa_set_n(ctx, rsa_key.n, rsa_key.n_sz);
	if (ret < 0)
		goto free;
	ret = qat_rsa_set_e(ctx, rsa_key.e, rsa_key.e_sz);
	if (ret < 0)
		goto free;
	if (private) {
		ret = qat_rsa_set_d(ctx, rsa_key.d, rsa_key.d_sz);
		if (ret < 0)
			goto free;
		qat_rsa_setkey_crt(ctx, &rsa_key);
	}

	if (!ctx->n || !ctx->e) {
		/* invalid key provided */
		ret = -EINVAL;
		goto free;
	}
	if (private && !ctx->d) {
		/* invalid private key provided */
		ret = -EINVAL;
		goto free;
	}

	return 0;
free:
	qat_rsa_clear_ctx(dev, ctx);
	return ret;
}

static int qat_rsa_setpubkey(struct crypto_akcipher *tfm, const void *key,
			     unsigned int keylen)
{
	return qat_rsa_setkey(tfm, key, keylen, false);
}

static int qat_rsa_setprivkey(struct crypto_akcipher *tfm, const void *key,
			      unsigned int keylen)
{
	return qat_rsa_setkey(tfm, key, keylen, true);
}

static unsigned int qat_rsa_max_size(struct crypto_akcipher *tfm)
{
	struct qat_rsa_ctx *ctx = akcipher_tfm_ctx(tfm);

	return ctx->key_sz;
}

static int qat_rsa_init_tfm(struct crypto_akcipher *tfm)
{
	struct qat_rsa_ctx *ctx = akcipher_tfm_ctx(tfm);
	struct qat_crypto_instance *inst =
			qat_crypto_get_instance_node(numa_node_id());

	if (!inst)
		return -EINVAL;

	akcipher_set_reqsize(tfm, sizeof(struct qat_asym_request) + 64);

	ctx->key_sz = 0;
	ctx->inst = inst;
	return 0;
}

static void qat_rsa_exit_tfm(struct crypto_akcipher *tfm)
{
	struct qat_rsa_ctx *ctx = akcipher_tfm_ctx(tfm);
	struct device *dev = &GET_DEV(ctx->inst->accel_dev);

	qat_rsa_clear_ctx(dev, ctx);
	qat_crypto_put_instance(ctx->inst);
}

static struct akcipher_alg rsa = {
	.encrypt = qat_rsa_enc,
	.decrypt = qat_rsa_dec,
	.set_pub_key = qat_rsa_setpubkey,
	.set_priv_key = qat_rsa_setprivkey,
	.max_size = qat_rsa_max_size,
	.init = qat_rsa_init_tfm,
	.exit = qat_rsa_exit_tfm,
	.base = {
		.cra_name = "rsa",
		.cra_driver_name = "qat-rsa",
		.cra_priority = 1000,
		.cra_module = THIS_MODULE,
		.cra_ctxsize = sizeof(struct qat_rsa_ctx),
	},
};

static struct kpp_alg dh = {
	.set_secret = qat_dh_set_secret,
	.generate_public_key = qat_dh_compute_value,
	.compute_shared_secret = qat_dh_compute_value,
	.max_size = qat_dh_max_size,
	.init = qat_dh_init_tfm,
	.exit = qat_dh_exit_tfm,
	.base = {
		.cra_name = "dh",
		.cra_driver_name = "qat-dh",
		.cra_priority = 1000,
		.cra_module = THIS_MODULE,
		.cra_ctxsize = sizeof(struct qat_dh_ctx),
	},
};

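/*
 * Registration is reference counted: the algorithms are registered with the
 * crypto API when the first accel device comes up and unregistered when the
 * last one goes away.
 */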
int qat_asym_algs_register(void)
{
	int ret = 0;

	mutex_lock(&algs_lock);
	if (++active_devs == 1) {
		rsa.base.cra_flags = 0;
		ret = crypto_register_akcipher(&rsa);
		if (ret)
			goto unlock;
		ret = crypto_register_kpp(&dh);
	}
unlock:
	mutex_unlock(&algs_lock);
	return ret;
}

void qat_asym_algs_unregister(void)
{
	mutex_lock(&algs_lock);
	if (--active_devs == 0) {
		crypto_unregister_akcipher(&rsa);
		crypto_unregister_kpp(&dh);
	}
	mutex_unlock(&algs_lock);
}