/*
 * RISC-V Vector Crypto Extension Helpers for QEMU.
 *
 * Copyright (C) 2023 SiFive, Inc.
 * Written by Codethink Ltd and SiFive.
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms and conditions of the GNU General Public License,
 * version 2 or later, as published by the Free Software Foundation.
 *
 * This program is distributed in the hope it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
 * more details.
 *
 * You should have received a copy of the GNU General Public License along with
 * this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include "qemu/osdep.h"
#include "qemu/host-utils.h"
#include "qemu/bitops.h"
#include "qemu/bswap.h"
#include "cpu.h"
#include "crypto/aes.h"
#include "crypto/aes-round.h"
#include "exec/memop.h"
#include "exec/exec-all.h"
#include "exec/helper-proto.h"
#include "internals.h"
#include "vector_internals.h"

/* Carry-less multiply: low 64 bits of the 128-bit product of x and y. */
static uint64_t clmul64(uint64_t y, uint64_t x)
{
    uint64_t result = 0;

    for (int j = 63; j >= 0; j--) {
        if ((y >> j) & 1) {
            result ^= (x << j);
        }
    }
    return result;
}

/* Carry-less multiply: high 64 bits of the same 128-bit product. */
static uint64_t clmulh64(uint64_t y, uint64_t x)
{
    uint64_t result = 0;

    for (int j = 63; j >= 1; j--) {
        if ((y >> j) & 1) {
            result ^= (x >> (64 - j));
        }
    }
    return result;
}

RVVCALL(OPIVV2, vclmul_vv, OP_UUU_D, H8, H8, H8, clmul64)
GEN_VEXT_VV(vclmul_vv, 8)
RVVCALL(OPIVX2, vclmul_vx, OP_UUU_D, H8, H8, clmul64)
GEN_VEXT_VX(vclmul_vx, 8)
RVVCALL(OPIVV2, vclmulh_vv, OP_UUU_D, H8, H8, H8, clmulh64)
GEN_VEXT_VV(vclmulh_vv, 8)
RVVCALL(OPIVX2, vclmulh_vx, OP_UUU_D, H8, H8, clmulh64)
GEN_VEXT_VX(vclmulh_vx, 8)

RVVCALL(OPIVV2, vror_vv_b, OP_UUU_B, H1, H1, H1, ror8)
RVVCALL(OPIVV2, vror_vv_h, OP_UUU_H, H2, H2, H2, ror16)
RVVCALL(OPIVV2, vror_vv_w, OP_UUU_W, H4, H4, H4, ror32)
RVVCALL(OPIVV2, vror_vv_d, OP_UUU_D, H8, H8, H8, ror64)
GEN_VEXT_VV(vror_vv_b, 1)
GEN_VEXT_VV(vror_vv_h, 2)
GEN_VEXT_VV(vror_vv_w, 4)
GEN_VEXT_VV(vror_vv_d, 8)

RVVCALL(OPIVX2, vror_vx_b, OP_UUU_B, H1, H1, ror8)
RVVCALL(OPIVX2, vror_vx_h, OP_UUU_H, H2, H2, ror16)
RVVCALL(OPIVX2, vror_vx_w, OP_UUU_W, H4, H4, ror32)
RVVCALL(OPIVX2, vror_vx_d, OP_UUU_D, H8, H8, ror64)
GEN_VEXT_VX(vror_vx_b, 1)
GEN_VEXT_VX(vror_vx_h, 2)
GEN_VEXT_VX(vror_vx_w, 4)
GEN_VEXT_VX(vror_vx_d, 8)

RVVCALL(OPIVV2, vrol_vv_b, OP_UUU_B, H1, H1, H1, rol8)
RVVCALL(OPIVV2, vrol_vv_h, OP_UUU_H, H2, H2, H2, rol16)
RVVCALL(OPIVV2, vrol_vv_w, OP_UUU_W, H4, H4, H4, rol32)
RVVCALL(OPIVV2, vrol_vv_d, OP_UUU_D, H8, H8, H8, rol64)
GEN_VEXT_VV(vrol_vv_b, 1)
GEN_VEXT_VV(vrol_vv_h, 2)
GEN_VEXT_VV(vrol_vv_w, 4)
GEN_VEXT_VV(vrol_vv_d, 8)

RVVCALL(OPIVX2, vrol_vx_b, OP_UUU_B, H1, H1, rol8)
RVVCALL(OPIVX2, vrol_vx_h, OP_UUU_H, H2, H2, rol16)
RVVCALL(OPIVX2, vrol_vx_w, OP_UUU_W, H4, H4, rol32)
RVVCALL(OPIVX2, vrol_vx_d, OP_UUU_D, H8, H8, rol64)
GEN_VEXT_VX(vrol_vx_b, 1)
GEN_VEXT_VX(vrol_vx_h, 2)
GEN_VEXT_VX(vrol_vx_w, 4)
GEN_VEXT_VX(vrol_vx_d, 8)

/* Reverse the bit order within each byte of a 64-bit value. */
static uint64_t brev8(uint64_t val)
{
    val = ((val & 0x5555555555555555ull) << 1) |
          ((val & 0xAAAAAAAAAAAAAAAAull) >> 1);
    val = ((val & 0x3333333333333333ull) << 2) |
          ((val & 0xCCCCCCCCCCCCCCCCull) >> 2);
    val = ((val & 0x0F0F0F0F0F0F0F0Full) << 4) |
          ((val & 0xF0F0F0F0F0F0F0F0ull) >> 4);

    return val;
}

RVVCALL(OPIVV1, vbrev8_v_b, OP_UU_B, H1, H1, brev8)
RVVCALL(OPIVV1, vbrev8_v_h, OP_UU_H, H2, H2, brev8)
RVVCALL(OPIVV1, vbrev8_v_w, OP_UU_W, H4, H4, brev8)
RVVCALL(OPIVV1, vbrev8_v_d, OP_UU_D, H8, H8, brev8)
GEN_VEXT_V(vbrev8_v_b, 1)
GEN_VEXT_V(vbrev8_v_h, 2)
GEN_VEXT_V(vbrev8_v_w, 4)
GEN_VEXT_V(vbrev8_v_d, 8)

#define DO_IDENTITY(a) (a)
RVVCALL(OPIVV1, vrev8_v_b, OP_UU_B, H1, H1, DO_IDENTITY)
RVVCALL(OPIVV1, vrev8_v_h, OP_UU_H, H2, H2, bswap16)
RVVCALL(OPIVV1, vrev8_v_w, OP_UU_W, H4, H4, bswap32)
RVVCALL(OPIVV1, vrev8_v_d, OP_UU_D, H8, H8, bswap64)
GEN_VEXT_V(vrev8_v_b, 1)
GEN_VEXT_V(vrev8_v_h, 2)
GEN_VEXT_V(vrev8_v_w, 4)
GEN_VEXT_V(vrev8_v_d, 8)

#define DO_ANDN(a, b) ((a) & ~(b))
RVVCALL(OPIVV2, vandn_vv_b, OP_UUU_B, H1, H1, H1, DO_ANDN)
RVVCALL(OPIVV2, vandn_vv_h, OP_UUU_H, H2, H2, H2, DO_ANDN)
RVVCALL(OPIVV2, vandn_vv_w, OP_UUU_W, H4, H4, H4, DO_ANDN)
RVVCALL(OPIVV2, vandn_vv_d, OP_UUU_D, H8, H8, H8, DO_ANDN)
GEN_VEXT_VV(vandn_vv_b, 1)
GEN_VEXT_VV(vandn_vv_h, 2)
GEN_VEXT_VV(vandn_vv_w, 4)
GEN_VEXT_VV(vandn_vv_d, 8)

RVVCALL(OPIVX2, vandn_vx_b, OP_UUU_B, H1, H1, DO_ANDN)
RVVCALL(OPIVX2, vandn_vx_h, OP_UUU_H, H2, H2, DO_ANDN)
RVVCALL(OPIVX2, vandn_vx_w, OP_UUU_W, H4, H4, DO_ANDN)
RVVCALL(OPIVX2, vandn_vx_d, OP_UUU_D, H8, H8, DO_ANDN)
GEN_VEXT_VX(vandn_vx_b, 1)
GEN_VEXT_VX(vandn_vx_h, 2)
GEN_VEXT_VX(vandn_vx_w, 4)
GEN_VEXT_VX(vandn_vx_d, 8)

RVVCALL(OPIVV1, vbrev_v_b, OP_UU_B, H1, H1, revbit8)
RVVCALL(OPIVV1, vbrev_v_h, OP_UU_H, H2, H2, revbit16)
RVVCALL(OPIVV1, vbrev_v_w, OP_UU_W, H4, H4, revbit32)
RVVCALL(OPIVV1, vbrev_v_d, OP_UU_D, H8, H8, revbit64)
GEN_VEXT_V(vbrev_v_b, 1)
GEN_VEXT_V(vbrev_v_h, 2)
GEN_VEXT_V(vbrev_v_w, 4)
GEN_VEXT_V(vbrev_v_d, 8)

RVVCALL(OPIVV1, vclz_v_b, OP_UU_B, H1, H1, clz8)
RVVCALL(OPIVV1, vclz_v_h, OP_UU_H, H2, H2, clz16)
RVVCALL(OPIVV1, vclz_v_w, OP_UU_W, H4, H4, clz32)
RVVCALL(OPIVV1, vclz_v_d, OP_UU_D, H8, H8, clz64)
GEN_VEXT_V(vclz_v_b, 1)
GEN_VEXT_V(vclz_v_h, 2)
GEN_VEXT_V(vclz_v_w, 4)
GEN_VEXT_V(vclz_v_d, 8)

RVVCALL(OPIVV1, vctz_v_b, OP_UU_B, H1, H1, ctz8)
RVVCALL(OPIVV1, vctz_v_h, OP_UU_H, H2, H2, ctz16)
RVVCALL(OPIVV1, vctz_v_w, OP_UU_W, H4, H4, ctz32)
RVVCALL(OPIVV1, vctz_v_d, OP_UU_D, H8, H8, ctz64)
GEN_VEXT_V(vctz_v_b, 1)
GEN_VEXT_V(vctz_v_h, 2)
GEN_VEXT_V(vctz_v_w, 4)
GEN_VEXT_V(vctz_v_d, 8)

RVVCALL(OPIVV1, vcpop_v_b, OP_UU_B, H1, H1, ctpop8)
RVVCALL(OPIVV1, vcpop_v_h, OP_UU_H, H2, H2, ctpop16)
RVVCALL(OPIVV1, vcpop_v_w, OP_UU_W, H4, H4, ctpop32)
RVVCALL(OPIVV1, vcpop_v_d, OP_UU_D, H8, H8, ctpop64)
GEN_VEXT_V(vcpop_v_b, 1)
GEN_VEXT_V(vcpop_v_h, 2)
GEN_VEXT_V(vcpop_v_w, 4)
GEN_VEXT_V(vcpop_v_d, 8)

#define DO_SLL(N, M) (N << (M & (sizeof(N) * 8 - 1)))
RVVCALL(OPIVV2, vwsll_vv_b, WOP_UUU_B, H2, H1, H1, DO_SLL)
RVVCALL(OPIVV2, vwsll_vv_h, WOP_UUU_H, H4, H2, H2, DO_SLL)
RVVCALL(OPIVV2, vwsll_vv_w, WOP_UUU_W, H8, H4, H4, DO_SLL)
GEN_VEXT_VV(vwsll_vv_b, 2)
GEN_VEXT_VV(vwsll_vv_h, 4)
GEN_VEXT_VV(vwsll_vv_w, 8)

RVVCALL(OPIVX2, vwsll_vx_b, WOP_UUU_B, H2, H1, DO_SLL)
RVVCALL(OPIVX2, vwsll_vx_h, WOP_UUU_H, H4, H2, DO_SLL)
RVVCALL(OPIVX2, vwsll_vx_w, WOP_UUU_W, H8, H4, DO_SLL)
GEN_VEXT_VX(vwsll_vx_b, 2)
GEN_VEXT_VX(vwsll_vx_h, 4)
GEN_VEXT_VX(vwsll_vx_w, 8)

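/*
 * The Zvkned AES helpers below operate on 128-bit element groups: with
 * SEW=32 one group is four consecutive elements, so the loops run from
 * vstart / 4 to vl / 4 and move the state as two 64-bit halves.
 * egs_check() enforces the element-group constraint that both vl and
 * vstart are multiples of the element group size, raising an
 * illegal-instruction exception otherwise (for example, with SEW=32 a
 * vl of 8 covers two groups, while a vl of 6 would trap).
 */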
void HELPER(egs_check)(uint32_t egs, CPURISCVState *env)
{
    uint32_t vl = env->vl;
    uint32_t vstart = env->vstart;

    if (vl % egs != 0 || vstart % egs != 0) {
        riscv_raise_exception(env, RISCV_EXCP_ILLEGAL_INST, GETPC());
    }
}

/* AddRoundKey: XOR the 128-bit round key into the state. */
static inline void xor_round_key(AESState *round_state, AESState *round_key)
{
    round_state->v = round_state->v ^ round_key->v;
}

/* The .vv form takes a per-group round key from vs2. */
#define GEN_ZVKNED_HELPER_VV(NAME, ...) \
void HELPER(NAME)(void *vd, void *vs2, CPURISCVState *env, \
                  uint32_t desc) \
{ \
    uint32_t vl = env->vl; \
    uint32_t total_elems = vext_get_total_elems(env, desc, 4); \
    uint32_t vta = vext_vta(desc); \
 \
    for (uint32_t i = env->vstart / 4; i < env->vl / 4; i++) { \
        AESState round_key; \
        round_key.d[0] = *((uint64_t *)vs2 + H8(i * 2 + 0)); \
        round_key.d[1] = *((uint64_t *)vs2 + H8(i * 2 + 1)); \
        AESState round_state; \
        round_state.d[0] = *((uint64_t *)vd + H8(i * 2 + 0)); \
        round_state.d[1] = *((uint64_t *)vd + H8(i * 2 + 1)); \
        __VA_ARGS__; \
        *((uint64_t *)vd + H8(i * 2 + 0)) = round_state.d[0]; \
        *((uint64_t *)vd + H8(i * 2 + 1)) = round_state.d[1]; \
    } \
    env->vstart = 0; \
    /* set tail elements to 1s */ \
    vext_set_elems_1s(vd, vta, vl * 4, total_elems * 4); \
}

/* The .vs form uses the single round key in element group 0 of vs2. */
#define GEN_ZVKNED_HELPER_VS(NAME, ...) \
void HELPER(NAME)(void *vd, void *vs2, CPURISCVState *env, \
                  uint32_t desc) \
{ \
    uint32_t vl = env->vl; \
    uint32_t total_elems = vext_get_total_elems(env, desc, 4); \
    uint32_t vta = vext_vta(desc); \
 \
    for (uint32_t i = env->vstart / 4; i < env->vl / 4; i++) { \
        AESState round_key; \
        round_key.d[0] = *((uint64_t *)vs2 + H8(0)); \
        round_key.d[1] = *((uint64_t *)vs2 + H8(1)); \
        AESState round_state; \
        round_state.d[0] = *((uint64_t *)vd + H8(i * 2 + 0)); \
        round_state.d[1] = *((uint64_t *)vd + H8(i * 2 + 1)); \
        __VA_ARGS__; \
        *((uint64_t *)vd + H8(i * 2 + 0)) = round_state.d[0]; \
        *((uint64_t *)vd + H8(i * 2 + 1)) = round_state.d[1]; \
    } \
    env->vstart = 0; \
    /* set tail elements to 1s */ \
    vext_set_elems_1s(vd, vta, vl * 4, total_elems * 4); \
}

GEN_ZVKNED_HELPER_VV(vaesef_vv,
    aesenc_SB_SR_AK(&round_state, &round_state, &round_key, false);)
GEN_ZVKNED_HELPER_VS(vaesef_vs,
    aesenc_SB_SR_AK(&round_state, &round_state, &round_key, false);)
GEN_ZVKNED_HELPER_VV(vaesdf_vv,
    aesdec_ISB_ISR_AK(&round_state, &round_state, &round_key, false);)
GEN_ZVKNED_HELPER_VS(vaesdf_vs,
    aesdec_ISB_ISR_AK(&round_state, &round_state, &round_key, false);)
GEN_ZVKNED_HELPER_VV(vaesem_vv,
    aesenc_SB_SR_MC_AK(&round_state, &round_state, &round_key, false);)
GEN_ZVKNED_HELPER_VS(vaesem_vs,
    aesenc_SB_SR_MC_AK(&round_state, &round_state, &round_key, false);)
GEN_ZVKNED_HELPER_VV(vaesdm_vv,
    aesdec_ISB_ISR_AK_IMC(&round_state, &round_state, &round_key, false);)
GEN_ZVKNED_HELPER_VS(vaesdm_vs,
    aesdec_ISB_ISR_AK_IMC(&round_state, &round_state, &round_key, false);)
GEN_ZVKNED_HELPER_VS(vaesz_vs, xor_round_key(&round_state, &round_key);)

/* vaeskf1.vi: one round of the AES-128 forward key schedule. */
void HELPER(vaeskf1_vi)(void *vd_vptr, void *vs2_vptr, uint32_t uimm,
                        CPURISCVState *env, uint32_t desc)
{
    uint32_t *vd = vd_vptr;
    uint32_t *vs2 = vs2_vptr;
    uint32_t vl = env->vl;
    uint32_t total_elems = vext_get_total_elems(env, desc, 4);
    uint32_t vta = vext_vta(desc);

    uimm &= 0b1111;
    /* Map reserved round numbers into [1, 10] by flipping bit 3. */
    if (uimm > 10 || uimm == 0) {
        uimm ^= 0b1000;
    }

    for (uint32_t i = env->vstart / 4; i < env->vl / 4; i++) {
        uint32_t rk[8], tmp;
        static const uint32_t rcon[] = {
            0x00000001, 0x00000002, 0x00000004, 0x00000008, 0x00000010,
            0x00000020, 0x00000040, 0x00000080, 0x0000001B, 0x00000036,
        };

        rk[0] = vs2[i * 4 + H4(0)];
        rk[1] = vs2[i * 4 + H4(1)];
        rk[2] = vs2[i * 4 + H4(2)];
        rk[3] = vs2[i * 4 + H4(3)];
        tmp = ror32(rk[3], 8);

        rk[4] = rk[0] ^ (((uint32_t)AES_sbox[(tmp >> 24) & 0xff] << 24) |
                         ((uint32_t)AES_sbox[(tmp >> 16) & 0xff] << 16) |
                         ((uint32_t)AES_sbox[(tmp >> 8) & 0xff] << 8) |
                         ((uint32_t)AES_sbox[(tmp >> 0) & 0xff] << 0))
                      ^ rcon[uimm - 1];
        rk[5] = rk[1] ^ rk[4];
        rk[6] = rk[2] ^ rk[5];
        rk[7] = rk[3] ^ rk[6];

        vd[i * 4 + H4(0)] = rk[4];
        vd[i * 4 + H4(1)] = rk[5];
        vd[i * 4 + H4(2)] = rk[6];
        vd[i * 4 + H4(3)] = rk[7];
    }
    env->vstart = 0;
    /* set tail elements to 1s */
    vext_set_elems_1s(vd, vta, vl * 4, total_elems * 4);
}

/* vaeskf2.vi: one round of the AES-256 forward key schedule. */
void HELPER(vaeskf2_vi)(void *vd_vptr, void *vs2_vptr, uint32_t uimm,
                        CPURISCVState *env, uint32_t desc)
{
    uint32_t *vd = vd_vptr;
    uint32_t *vs2 = vs2_vptr;
    uint32_t vl = env->vl;
    uint32_t total_elems = vext_get_total_elems(env, desc, 4);
    uint32_t vta = vext_vta(desc);

    uimm &= 0b1111;
    /* Map reserved round numbers into [2, 14] by flipping bit 3. */
    if (uimm > 14 || uimm < 2) {
        uimm ^= 0b1000;
    }

    for (uint32_t i = env->vstart / 4; i < env->vl / 4; i++) {
        uint32_t rk[12], tmp;
        static const uint32_t rcon[] = {
            0x00000001, 0x00000002, 0x00000004, 0x00000008, 0x00000010,
            0x00000020, 0x00000040, 0x00000080, 0x0000001B, 0x00000036,
        };

        rk[0] = vd[i * 4 + H4(0)];
        rk[1] = vd[i * 4 + H4(1)];
        rk[2] = vd[i * 4 + H4(2)];
        rk[3] = vd[i * 4 + H4(3)];
        rk[4] = vs2[i * 4 + H4(0)];
        rk[5] = vs2[i * 4 + H4(1)];
        rk[6] = vs2[i * 4 + H4(2)];
        rk[7] = vs2[i * 4 + H4(3)];

        /* Even round numbers: RotWord + SubWord + Rcon; odd: SubWord only. */
        if (uimm % 2 == 0) {
            tmp = ror32(rk[7], 8);
            rk[8] = rk[0] ^ (((uint32_t)AES_sbox[(tmp >> 24) & 0xff] << 24) |
                             ((uint32_t)AES_sbox[(tmp >> 16) & 0xff] << 16) |
                             ((uint32_t)AES_sbox[(tmp >> 8) & 0xff] << 8) |
                             ((uint32_t)AES_sbox[(tmp >> 0) & 0xff] << 0))
                          ^ rcon[(uimm - 1) / 2];
        } else {
            rk[8] = rk[0] ^ (((uint32_t)AES_sbox[(rk[7] >> 24) & 0xff] << 24) |
                             ((uint32_t)AES_sbox[(rk[7] >> 16) & 0xff] << 16) |
                             ((uint32_t)AES_sbox[(rk[7] >> 8) & 0xff] << 8) |
                             ((uint32_t)AES_sbox[(rk[7] >> 0) & 0xff] << 0));
        }
        rk[9] = rk[1] ^ rk[8];
        rk[10] = rk[2] ^ rk[9];
        rk[11] = rk[3] ^ rk[10];

        vd[i * 4 + H4(0)] = rk[8];
        vd[i * 4 + H4(1)] = rk[9];
        vd[i * 4 + H4(2)] = rk[10];
        vd[i * 4 + H4(3)] = rk[11];
    }
    env->vstart = 0;
    /* set tail elements to 1s */
    vext_set_elems_1s(vd, vta, vl * 4, total_elems * 4);
}
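
/*
 * Illustrative sketch, kept out of the build: the scalar AES-128
 * key-expansion step that vaeskf1_vi above applies to each 128-bit element
 * group, written with the FIPS-197 names RotWord / SubWord / Rcon made
 * explicit.  The function name and signature are invented for this sketch;
 * only ror32() and AES_sbox[] from the headers included above are assumed.
 */
#if 0
static void aes128_expand_step(uint32_t next[4], const uint32_t prev[4],
                               uint32_t rcon)
{
    /* RotWord, expressed as a rotate on the little-endian packed word. */
    uint32_t tmp = ror32(prev[3], 8);
    /* SubWord: the AES S-box applied to each byte of the rotated word. */
    uint32_t sub = ((uint32_t)AES_sbox[(tmp >> 24) & 0xff] << 24) |
                   ((uint32_t)AES_sbox[(tmp >> 16) & 0xff] << 16) |
                   ((uint32_t)AES_sbox[(tmp >> 8) & 0xff] << 8) |
                   ((uint32_t)AES_sbox[(tmp >> 0) & 0xff] << 0);

    next[0] = prev[0] ^ sub ^ rcon;
    next[1] = prev[1] ^ next[0];
    next[2] = prev[2] ^ next[1];
    next[3] = prev[3] ^ next[2];
}
#endif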