/* SPDX-License-Identifier: GPL-2.0-only OR BSD-3-Clause */
/*
 * AES CTR mode by8 optimization with AVX instructions. (x86_64)
 *
 * Copyright(c) 2014 Intel Corporation.
 *
 * Contact Information:
 * James Guilford <james.guilford@intel.com>
 * Sean Gulley <sean.m.gulley@intel.com>
 * Chandramouli Narayanan <mouli@linux.intel.com>
 */
/*
 * This is the AES-128/192/256 CTR mode optimization implementation. It
 * requires support for the Intel(R) AES-NI and AVX instructions.
 *
 * This work was inspired by the AES CTR mode optimization published
 * in the Intel Optimized IPSEC Cryptographic library.
 * Additional information on it can be found at:
 *    https://github.com/intel/intel-ipsec-mb
 */
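/*
 * CTR mode turns the AES block cipher into a stream cipher: each 16-byte
 * output block is the input block XORed with the encryption of a per-block
 * counter, which also makes encryption and decryption the same operation.
 * Illustrative C-like pseudocode (for reference only, not part of the
 * build; the names used here are not defined anywhere in this file):
 *
 *	for (i = 0; i < nblocks; i++)
 *		out[i] = in[i] ^ AES_encrypt(key_schedule, counter + i);
 *
 * The "by8" routines below compute up to eight such counter blocks per
 * loop iteration so that the AES-NI pipeline stays full.
 */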

#include <linux/linkage.h>

#define VMOVDQ		vmovdqu

#define xdata0		%xmm0
#define xdata1		%xmm1
#define xdata2		%xmm2
#define xdata3		%xmm3
#define xdata4		%xmm4
#define xdata5		%xmm5
#define xdata6		%xmm6
#define xdata7		%xmm7
#define xcounter	%xmm8
#define xbyteswap	%xmm9
#define xkey0		%xmm10
#define xkey4		%xmm11
#define xkey8		%xmm12
#define xkey12		%xmm13
#define xkeyA		%xmm14
#define xkeyB		%xmm15

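/*
 * Function arguments, per the x86_64 System V calling convention
 * (%rdi, %rsi, %rdx, %rcx, %r8 carry the first five arguments).
 */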
#define p_in		%rdi
#define p_iv		%rsi
#define p_keys		%rdx
#define p_out		%rcx
#define num_bytes	%r8

#define tmp		%r10
#define	DDQ_DATA	0
#define	XDATA		1
#define KEY_128		1
#define KEY_192		2
#define KEY_256		3
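/* key length selectors: pick the 10-, 12- or 14-round paths below */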

.section .rodata
.align 16

byteswap_const:
	.octa 0x000102030405060708090A0B0C0D0E0F
ddq_low_msk:
	.octa 0x0000000000000000FFFFFFFFFFFFFFFF
ddq_high_add_1:
	.octa 0x00000000000000010000000000000000
ddq_add_1:
	.octa 0x00000000000000000000000000000001
ddq_add_2:
	.octa 0x00000000000000000000000000000002
ddq_add_3:
	.octa 0x00000000000000000000000000000003
ddq_add_4:
	.octa 0x00000000000000000000000000000004
ddq_add_5:
	.octa 0x00000000000000000000000000000005
ddq_add_6:
	.octa 0x00000000000000000000000000000006
ddq_add_7:
	.octa 0x00000000000000000000000000000007
ddq_add_8:
	.octa 0x00000000000000000000000000000008
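/*
 * byteswap_const reverses byte order so that the big-endian CTR block can
 * be incremented with ordinary little-endian arithmetic; ddq_add_1..8 add
 * 1..8 to the low 64-bit half of the counter, ddq_low_msk selects that
 * half, and ddq_high_add_1 propagates a carry into the high half.
 */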

.text

/* generate a unique variable for xmm register */
.macro setxdata n
	var_xdata = %xmm\n
.endm

/* combine ("club") the numeric 'id' with the symbol 'name' */

.macro club name, id
.altmacro
	.if \name == XDATA
		setxdata %\id
	.endif
.noaltmacro
.endm

/*
 * do_aes num_in_par load_keys key_len
 * This increments p_in, but not p_out
 */
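/*
 * Example (illustrative): "do_aes 8, 0, KEY_128" encrypts eight counter
 * blocks in xdata0..xdata7 with the 128-bit key schedule, assuming
 * xkey0/xkey4/xkey8/xkey12 have already been loaded (load_keys == 0).
 * The do_aes_load/do_aes_noload wrappers further below fix the load_keys
 * flag to 1 and 0, respectively.
 */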
.macro do_aes b, k, key_len
	.set by, \b
	.set load_keys, \k
	.set klen, \key_len

	.if (load_keys)
		vmovdqa	0*16(p_keys), xkey0
	.endif

	vpshufb	xbyteswap, xcounter, xdata0

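	/*
	 * Build the remaining (by - 1) counter blocks: add i to the low
	 * 64-bit half of the counter and, if that addition wraps to zero,
	 * carry one into the high half (of both the block and the running
	 * counter); then byte-swap each block back to big-endian before
	 * encryption.
	 */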
	.set i, 1
	.rept (by - 1)
		club XDATA, i
		vpaddq	(ddq_add_1 + 16 * (i - 1))(%rip), xcounter, var_xdata
		vptest	ddq_low_msk(%rip), var_xdata
		jnz 1f
		vpaddq	ddq_high_add_1(%rip), var_xdata, var_xdata
		vpaddq	ddq_high_add_1(%rip), xcounter, xcounter
		1:
		vpshufb	xbyteswap, var_xdata, var_xdata
		.set i, (i +1)
	.endr

	vmovdqa	1*16(p_keys), xkeyA

	vpxor	xkey0, xdata0, xdata0
	vpaddq	(ddq_add_1 + 16 * (by - 1))(%rip), xcounter, xcounter
	vptest	ddq_low_msk(%rip), xcounter
	jnz	1f
	vpaddq	ddq_high_add_1(%rip), xcounter, xcounter
	1:

	.set i, 1
	.rept (by - 1)
		club XDATA, i
		vpxor	xkey0, var_xdata, var_xdata
		.set i, (i +1)
	.endr

	vmovdqa	2*16(p_keys), xkeyB

	.set i, 0
	.rept by
		club XDATA, i
		vaesenc	xkeyA, var_xdata, var_xdata		/* key 1 */
		.set i, (i +1)
	.endr

	.if (klen == KEY_128)
		.if (load_keys)
			vmovdqa	3*16(p_keys), xkey4
		.endif
	.else
		vmovdqa	3*16(p_keys), xkeyA
	.endif

	.set i, 0
	.rept by
		club XDATA, i
		vaesenc	xkeyB, var_xdata, var_xdata		/* key 2 */
		.set i, (i +1)
	.endr

	add	$(16*by), p_in

	.if (klen == KEY_128)
		vmovdqa	4*16(p_keys), xkeyB
	.else
		.if (load_keys)
			vmovdqa	4*16(p_keys), xkey4
		.endif
	.endif

	.set i, 0
	.rept by
		club XDATA, i
		/* key 3 */
		.if (klen == KEY_128)
			vaesenc	xkey4, var_xdata, var_xdata
		.else
			vaesenc	xkeyA, var_xdata, var_xdata
		.endif
		.set i, (i +1)
	.endr

	vmovdqa	5*16(p_keys), xkeyA

	.set i, 0
	.rept by
		club XDATA, i
		/* key 4 */
		.if (klen == KEY_128)
			vaesenc	xkeyB, var_xdata, var_xdata
		.else
			vaesenc	xkey4, var_xdata, var_xdata
		.endif
		.set i, (i +1)
	.endr

	.if (klen == KEY_128)
		.if (load_keys)
			vmovdqa	6*16(p_keys), xkey8
		.endif
	.else
		vmovdqa	6*16(p_keys), xkeyB
	.endif

	.set i, 0
	.rept by
		club XDATA, i
		vaesenc	xkeyA, var_xdata, var_xdata		/* key 5 */
		.set i, (i +1)
	.endr

	vmovdqa	7*16(p_keys), xkeyA

	.set i, 0
	.rept by
		club XDATA, i
		/* key 6 */
		.if (klen == KEY_128)
			vaesenc	xkey8, var_xdata, var_xdata
		.else
			vaesenc	xkeyB, var_xdata, var_xdata
		.endif
		.set i, (i +1)
	.endr

	.if (klen == KEY_128)
		vmovdqa	8*16(p_keys), xkeyB
	.else
		.if (load_keys)
			vmovdqa	8*16(p_keys), xkey8
		.endif
	.endif

	.set i, 0
	.rept by
		club XDATA, i
		vaesenc	xkeyA, var_xdata, var_xdata		/* key 7 */
		.set i, (i +1)
	.endr

	.if (klen == KEY_128)
		.if (load_keys)
			vmovdqa	9*16(p_keys), xkey12
		.endif
	.else
		vmovdqa	9*16(p_keys), xkeyA
	.endif

	.set i, 0
	.rept by
		club XDATA, i
		/* key 8 */
		.if (klen == KEY_128)
			vaesenc	xkeyB, var_xdata, var_xdata
		.else
			vaesenc	xkey8, var_xdata, var_xdata
		.endif
		.set i, (i +1)
	.endr

	vmovdqa	10*16(p_keys), xkeyB

	.set i, 0
	.rept by
		club XDATA, i
		/* key 9 */
		.if (klen == KEY_128)
			vaesenc	xkey12, var_xdata, var_xdata
		.else
			vaesenc	xkeyA, var_xdata, var_xdata
		.endif
		.set i, (i +1)
	.endr

	.if (klen != KEY_128)
		vmovdqa	11*16(p_keys), xkeyA
	.endif

	.set i, 0
	.rept by
		club XDATA, i
		/* key 10 */
		.if (klen == KEY_128)
			vaesenclast	xkeyB, var_xdata, var_xdata
		.else
			vaesenc	xkeyB, var_xdata, var_xdata
		.endif
		.set i, (i +1)
	.endr

	.if (klen != KEY_128)
		.if (load_keys)
			vmovdqa	12*16(p_keys), xkey12
		.endif

		.set i, 0
		.rept by
			club XDATA, i
			vaesenc	xkeyA, var_xdata, var_xdata	/* key 11 */
			.set i, (i +1)
		.endr

		.if (klen == KEY_256)
			vmovdqa	13*16(p_keys), xkeyA
		.endif

		.set i, 0
		.rept by
			club XDATA, i
			.if (klen == KEY_256)
				/* key 12 */
				vaesenc	xkey12, var_xdata, var_xdata
			.else
				vaesenclast xkey12, var_xdata, var_xdata
			.endif
			.set i, (i +1)
		.endr

		.if (klen == KEY_256)
			vmovdqa	14*16(p_keys), xkeyB

			.set i, 0
			.rept by
				club XDATA, i
				/* key 13 */
				vaesenc	xkeyA, var_xdata, var_xdata
				.set i, (i +1)
			.endr

			.set i, 0
			.rept by
				club XDATA, i
				/* key 14 */
				vaesenclast	xkeyB, var_xdata, var_xdata
				.set i, (i +1)
			.endr
		.endif
	.endif

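	/*
	 * XOR the keystream with the input blocks and write the result to
	 * p_out; p_in has already been advanced by 16*by bytes, hence the
	 * negative offsets, and xkeyA/xkeyB are reused as plaintext scratch
	 * registers since the round keys they held are no longer needed.
	 */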
	.set i, 0
	.rept (by / 2)
		.set j, (i+1)
		VMOVDQ	(i*16 - 16*by)(p_in), xkeyA
		VMOVDQ	(j*16 - 16*by)(p_in), xkeyB
		club XDATA, i
		vpxor	xkeyA, var_xdata, var_xdata
		club XDATA, j
		vpxor	xkeyB, var_xdata, var_xdata
		.set i, (i+2)
	.endr

	.if (i < by)
		VMOVDQ	(i*16 - 16*by)(p_in), xkeyA
		club XDATA, i
		vpxor	xkeyA, var_xdata, var_xdata
	.endif

	.set i, 0
	.rept by
		club XDATA, i
		VMOVDQ	var_xdata, i*16(p_out)
		.set i, (i+1)
	.endr
.endm

.macro do_aes_load val, key_len
	do_aes \val, 1, \key_len
.endm

.macro do_aes_noload val, key_len
	do_aes \val, 0, \key_len
.endm

/*
 * do_aes_ctrmain key_len
 * Main CTR body: process the 1-7 block remainder first (if any), then
 * loop over 8-block chunks, and finally write the updated IV back out.
 */

.macro do_aes_ctrmain key_len
	cmp	$16, num_bytes
	jb	.Ldo_return2\key_len

	vmovdqa	byteswap_const(%rip), xbyteswap
	vmovdqu	(p_iv), xcounter
	vpshufb	xbyteswap, xcounter, xcounter

	mov	num_bytes, tmp
	and	$(7*16), tmp
	jz	.Lmult_of_8_blks\key_len

	/* tmp = 1 to 7 remainder blocks (16 to 112 bytes) */
	cmp	$(4*16), tmp
	jg	.Lgt4\key_len
	je	.Leq4\key_len

.Llt4\key_len:
	cmp	$(2*16), tmp
	jg	.Leq3\key_len
	je	.Leq2\key_len

.Leq1\key_len:
	do_aes_load	1, \key_len
	add	$(1*16), p_out
	and	$(~7*16), num_bytes
	jz	.Ldo_return2\key_len
	jmp	.Lmain_loop2\key_len

.Leq2\key_len:
	do_aes_load	2, \key_len
	add	$(2*16), p_out
	and	$(~7*16), num_bytes
	jz	.Ldo_return2\key_len
	jmp	.Lmain_loop2\key_len

.Leq3\key_len:
	do_aes_load	3, \key_len
	add	$(3*16), p_out
	and	$(~7*16), num_bytes
	jz	.Ldo_return2\key_len
	jmp	.Lmain_loop2\key_len

.Leq4\key_len:
	do_aes_load	4, \key_len
	add	$(4*16), p_out
	and	$(~7*16), num_bytes
	jz	.Ldo_return2\key_len
	jmp	.Lmain_loop2\key_len

.Lgt4\key_len:
	cmp	$(6*16), tmp
	jg	.Leq7\key_len
	je	.Leq6\key_len

.Leq5\key_len:
	do_aes_load	5, \key_len
	add	$(5*16), p_out
	and	$(~7*16), num_bytes
	jz	.Ldo_return2\key_len
	jmp	.Lmain_loop2\key_len

.Leq6\key_len:
	do_aes_load	6, \key_len
	add	$(6*16), p_out
	and	$(~7*16), num_bytes
	jz	.Ldo_return2\key_len
	jmp	.Lmain_loop2\key_len

.Leq7\key_len:
	do_aes_load	7, \key_len
	add	$(7*16), p_out
	and	$(~7*16), num_bytes
	jz	.Ldo_return2\key_len
	jmp	.Lmain_loop2\key_len

.Lmult_of_8_blks\key_len:
	.if (\key_len != KEY_128)
		vmovdqa	0*16(p_keys), xkey0
		vmovdqa	4*16(p_keys), xkey4
		vmovdqa	8*16(p_keys), xkey8
		vmovdqa	12*16(p_keys), xkey12
	.else
		vmovdqa	0*16(p_keys), xkey0
		vmovdqa	3*16(p_keys), xkey4
		vmovdqa	6*16(p_keys), xkey8
		vmovdqa	9*16(p_keys), xkey12
	.endif
.align 16
.Lmain_loop2\key_len:
	/* num_bytes is a multiple of 8 blocks (128 bytes) and > 0 */
	do_aes_noload	8, \key_len
	add	$(8*16), p_out
	sub	$(8*16), num_bytes
	jne	.Lmain_loop2\key_len

.Ldo_return2\key_len:
	/* return updated IV */
	vpshufb	xbyteswap, xcounter, xcounter
	vmovdqu	xcounter, (p_iv)
	RET
.endm

/*
 * routine to do AES128 CTR enc/decrypt "by8"
 * XMM registers are clobbered.
 * Saving/restoring must be done at a higher level.
 * aes_ctr_enc_128_avx_by8(void *in, void *iv, void *keys, void *out,
 *			unsigned int num_bytes)
 */
SYM_FUNC_START(aes_ctr_enc_128_avx_by8)
	/* call the aes main loop */
	do_aes_ctrmain KEY_128

SYM_FUNC_END(aes_ctr_enc_128_avx_by8)

/*
 * routine to do AES192 CTR enc/decrypt "by8"
 * XMM registers are clobbered.
 * Saving/restoring must be done at a higher level.
 * aes_ctr_enc_192_avx_by8(void *in, void *iv, void *keys, void *out,
 *			unsigned int num_bytes)
 */
SYM_FUNC_START(aes_ctr_enc_192_avx_by8)
	/* call the aes main loop */
	do_aes_ctrmain KEY_192

SYM_FUNC_END(aes_ctr_enc_192_avx_by8)

/*
 * routine to do AES256 CTR enc/decrypt "by8"
 * XMM registers are clobbered.
 * Saving/restoring must be done at a higher level.
 * aes_ctr_enc_256_avx_by8(void *in, void *iv, void *keys, void *out,
 *			unsigned int num_bytes)
 */
SYM_FUNC_START(aes_ctr_enc_256_avx_by8)
	/* call the aes main loop */
	do_aes_ctrmain KEY_256

SYM_FUNC_END(aes_ctr_enc_256_avx_by8)
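
/*
 * A minimal sketch of a C caller, for reference only: the wrapper name,
 * exact pointer types and surrounding glue are assumptions; only the three
 * entry points above are defined in this file.  The key schedule must
 * already be expanded, len is assumed to be a whole number of 16-byte
 * blocks, and SIMD state must be saved around the call:
 *
 *	asmlinkage void aes_ctr_enc_128_avx_by8(const u8 *in, u8 *iv,
 *						const void *keys, u8 *out,
 *						unsigned int num_bytes);
 *
 *	static void ctr_crypt_by8_128(struct crypto_aes_ctx *ctx, u8 *out,
 *				      const u8 *in, unsigned int len, u8 *iv)
 *	{
 *		kernel_fpu_begin();
 *		aes_ctr_enc_128_avx_by8(in, iv, ctx->key_enc, out, len);
 *		kernel_fpu_end();
 *	}
 *
 * The 192- and 256-bit variants are called the same way with the
 * corresponding key schedule.
 */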