/*
 * ARM SME Operations
 *
 * Copyright (c) 2022 Linaro, Ltd.
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, see <http://www.gnu.org/licenses/>.
 */

#include "qemu/osdep.h"
#include "cpu.h"
#include "internals.h"
#include "tcg/tcg-gvec-desc.h"
#include "exec/helper-proto.h"
#include "accel/tcg/cpu-ldst.h"
#include "accel/tcg/helper-retaddr.h"
#include "qemu/int128.h"
#include "fpu/softfloat.h"
#include "vec_internal.h"
#include "sve_ldst_internal.h"


static bool vectors_overlap(ARMVectorReg *x, unsigned nx,
                            ARMVectorReg *y, unsigned ny)
{
    return !(x + nx <= y || y + ny <= x);
}

void helper_set_svcr(CPUARMState *env, uint32_t val, uint32_t mask)
{
    aarch64_set_svcr(env, val, mask);
}

void helper_sme_zero(CPUARMState *env, uint32_t imm, uint32_t svl)
{
    uint32_t i;

    /*
     * Special case clearing the entire ZArray.
     * This falls into the CONSTRAINED UNPREDICTABLE zeroing of any
     * parts of the ZA storage outside of SVL.
     */
    if (imm == 0xff) {
        memset(env->za_state.za, 0, sizeof(env->za_state.za));
        return;
    }

    /*
     * Recall that ZAnH.D[m] is spread across ZA[n+8*m],
     * so each row is discontiguous within ZA[].
     */
    for (i = 0; i < svl; i++) {
        if (imm & (1 << (i % 8))) {
            memset(&env->za_state.za[i], 0, svl);
        }
    }
}


/*
 * When considering the ZA storage as an array of elements of
 * type T, the index within that array of the Nth element of
 * a vertical slice of a tile can be calculated like this,
 * regardless of the size of type T. This is because the tiles
 * are interleaved, so if type T is size N bytes then row 1 of
 * the tile is N rows away from row 0. The division by N to
 * convert a byte offset into an array index and the multiplication
 * by N to convert from vslice-index-within-the-tile to
 * the index within the ZA storage cancel out.
 */
#define tile_vslice_index(i) ((i) * sizeof(ARMVectorReg))

/*
 * When doing byte arithmetic on the ZA storage, the element
 * byteoff bytes away in a tile vertical slice is always this
 * many bytes away in the ZA storage, regardless of the
 * size of the tile element, assuming that byteoff is a multiple
 * of the element size. Again this is because of the interleaving
 * of the tiles. For instance if we have 1 byte per element then
 * each row of the ZA storage has one byte of the vslice data,
 * and (counting from 0) byte 8 goes in row 8 of the storage
 * at offset (8 * row-size-in-bytes).
 * If we have 8 bytes per element then each row of the ZA storage
 * has 8 bytes of the data, but there are 8 interleaved tiles and
 * so byte 8 of the data goes into row 1 of the tile,
 * which is again row 8 of the storage, so the offset is still
 * (8 * row-size-in-bytes). Similarly for other element sizes.
 */
#define tile_vslice_offset(byteoff) ((byteoff) * sizeof(ARMVectorReg))
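
/*
 * Worked example: for a .S tile (4-byte elements), consecutive rows of
 * the tile are interleaved 4 ZA storage rows apart, and each ZA storage
 * row occupies sizeof(ARMVectorReg) bytes.  Element 2 of a vertical
 * slice therefore sits in ZA storage row 8, at byte offset
 * 8 * sizeof(ARMVectorReg): as an index into a uint32_t view of ZA that
 * is tile_vslice_index(2), and as a byte offset from byteoff = 2 * 4 = 8
 * it is tile_vslice_offset(8).  Both reduce to the same scaling.
 */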

/*
 * Move Zreg vector to ZArray column.
 */
#define DO_MOVA_C(NAME, TYPE, H)                                        \
void HELPER(NAME)(void *za, void *vn, void *vg, uint32_t desc)          \
{                                                                       \
    int i, oprsz = simd_oprsz(desc);                                    \
    for (i = 0; i < oprsz; ) {                                          \
        uint16_t pg = *(uint16_t *)(vg + H1_2(i >> 3));                 \
        do {                                                            \
            if (pg & 1) {                                               \
                *(TYPE *)(za + tile_vslice_offset(i)) = *(TYPE *)(vn + H(i)); \
            }                                                           \
            i += sizeof(TYPE);                                          \
            pg >>= sizeof(TYPE);                                        \
        } while (i & 15);                                               \
    }                                                                   \
}

DO_MOVA_C(sme_mova_cz_b, uint8_t, H1)
DO_MOVA_C(sme_mova_cz_h, uint16_t, H1_2)
DO_MOVA_C(sme_mova_cz_s, uint32_t, H1_4)

void HELPER(sme_mova_cz_d)(void *za, void *vn, void *vg, uint32_t desc)
{
    int i, oprsz = simd_oprsz(desc) / 8;
    uint8_t *pg = vg;
    uint64_t *n = vn;
    uint64_t *a = za;

    for (i = 0; i < oprsz; i++) {
        if (pg[H1(i)] & 1) {
            a[tile_vslice_index(i)] = n[i];
        }
    }
}

void HELPER(sme_mova_cz_q)(void *za, void *vn, void *vg, uint32_t desc)
{
    int i, oprsz = simd_oprsz(desc) / 16;
    uint16_t *pg = vg;
    Int128 *n = vn;
    Int128 *a = za;

    /*
     * Int128 is used here simply to copy 16 bytes, and to simplify
     * the address arithmetic.
     */
    for (i = 0; i < oprsz; i++) {
        if (pg[H2(i)] & 1) {
            a[tile_vslice_index(i)] = n[i];
        }
    }
}

#undef DO_MOVA_C

/*
 * Move ZArray column to Zreg vector.
 */
#define DO_MOVA_Z(NAME, TYPE, H)                                        \
void HELPER(NAME)(void *vd, void *za, void *vg, uint32_t desc)          \
{                                                                       \
    int i, oprsz = simd_oprsz(desc);                                    \
    for (i = 0; i < oprsz; ) {                                          \
        uint16_t pg = *(uint16_t *)(vg + H1_2(i >> 3));                 \
        do {                                                            \
            if (pg & 1) {                                               \
                *(TYPE *)(vd + H(i)) = *(TYPE *)(za + tile_vslice_offset(i)); \
            }                                                           \
            i += sizeof(TYPE);                                          \
            pg >>= sizeof(TYPE);                                        \
        } while (i & 15);                                               \
    }                                                                   \
}

DO_MOVA_Z(sme_mova_zc_b, uint8_t, H1)
DO_MOVA_Z(sme_mova_zc_h, uint16_t, H1_2)
DO_MOVA_Z(sme_mova_zc_s, uint32_t, H1_4)

void HELPER(sme_mova_zc_d)(void *vd, void *za, void *vg, uint32_t desc)
{
    int i, oprsz = simd_oprsz(desc) / 8;
    uint8_t *pg = vg;
    uint64_t *d = vd;
    uint64_t *a = za;

    for (i = 0; i < oprsz; i++) {
        if (pg[H1(i)] & 1) {
            d[i] = a[tile_vslice_index(i)];
        }
    }
}

void HELPER(sme_mova_zc_q)(void *vd, void *za, void *vg, uint32_t desc)
{
    int i, oprsz = simd_oprsz(desc) / 16;
    uint16_t *pg = vg;
    Int128 *d = vd;
    Int128 *a = za;

    /*
     * Int128 is used here simply to copy 16 bytes, and to simplify
     * the address arithmetic.
     */
    for (i = 0; i < oprsz; i++) {
        if (pg[H2(i)] & 1) {
            d[i] = a[tile_vslice_index(i)];
        }
    }
}
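
/*
 * For the 128-bit MOVA forms above, the governing predicate is consulted
 * at 16-bit granularity: element i occupies bytes [16*i, 16*i + 15] of
 * the Zreg, so its controlling predicate bit is bit 0 of the i-th 16-bit
 * chunk, which is what pg[H2(i)] & 1 tests.
 */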

#undef DO_MOVA_Z

void HELPER(sme2_mova_zc_b)(void *vdst, void *vsrc, uint32_t desc)
{
    const uint8_t *src = vsrc;
    uint8_t *dst = vdst;
    size_t i, n = simd_oprsz(desc);

    for (i = 0; i < n; ++i) {
        dst[i] = src[tile_vslice_index(i)];
    }
}

void HELPER(sme2_mova_zc_h)(void *vdst, void *vsrc, uint32_t desc)
{
    const uint16_t *src = vsrc;
    uint16_t *dst = vdst;
    size_t i, n = simd_oprsz(desc) / 2;

    for (i = 0; i < n; ++i) {
        dst[i] = src[tile_vslice_index(i)];
    }
}

void HELPER(sme2_mova_zc_s)(void *vdst, void *vsrc, uint32_t desc)
{
    const uint32_t *src = vsrc;
    uint32_t *dst = vdst;
    size_t i, n = simd_oprsz(desc) / 4;

    for (i = 0; i < n; ++i) {
        dst[i] = src[tile_vslice_index(i)];
    }
}

void HELPER(sme2_mova_zc_d)(void *vdst, void *vsrc, uint32_t desc)
{
    const uint64_t *src = vsrc;
    uint64_t *dst = vdst;
    size_t i, n = simd_oprsz(desc) / 8;

    for (i = 0; i < n; ++i) {
        dst[i] = src[tile_vslice_index(i)];
    }
}

void HELPER(sme2p1_movaz_zc_b)(void *vdst, void *vsrc, uint32_t desc)
{
    uint8_t *src = vsrc;
    uint8_t *dst = vdst;
    size_t i, n = simd_oprsz(desc);

    for (i = 0; i < n; ++i) {
        dst[i] = src[tile_vslice_index(i)];
        src[tile_vslice_index(i)] = 0;
    }
}

void HELPER(sme2p1_movaz_zc_h)(void *vdst, void *vsrc, uint32_t desc)
{
    uint16_t *src = vsrc;
    uint16_t *dst = vdst;
    size_t i, n = simd_oprsz(desc) / 2;

    for (i = 0; i < n; ++i) {
        dst[i] = src[tile_vslice_index(i)];
        src[tile_vslice_index(i)] = 0;
    }
}

void HELPER(sme2p1_movaz_zc_s)(void *vdst, void *vsrc, uint32_t desc)
{
    uint32_t *src = vsrc;
    uint32_t *dst = vdst;
    size_t i, n = simd_oprsz(desc) / 4;

    for (i = 0; i < n; ++i) {
        dst[i] = src[tile_vslice_index(i)];
        src[tile_vslice_index(i)] = 0;
    }
}

void HELPER(sme2p1_movaz_zc_d)(void *vdst, void *vsrc, uint32_t desc)
{
    uint64_t *src = vsrc;
    uint64_t *dst = vdst;
    size_t i, n = simd_oprsz(desc) / 8;

    for (i = 0; i < n; ++i) {
        dst[i] = src[tile_vslice_index(i)];
        src[tile_vslice_index(i)] = 0;
    }
}

void HELPER(sme2p1_movaz_zc_q)(void *vdst, void *vsrc, uint32_t desc)
{
    Int128 *src = vsrc;
    Int128 *dst = vdst;
    size_t i, n = simd_oprsz(desc) / 16;

    for (i = 0; i < n; ++i) {
        dst[i] = src[tile_vslice_index(i)];
        memset(&src[tile_vslice_index(i)], 0, 16);
    }
}

/*
 * Clear elements in a tile slice comprising len bytes.
 */
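
/*
 * These, together with the copy functions below, are used by sme_ld1:
 * clr_fn zeroes the elements whose predicate bits are false, and cpy_fn
 * copies a completed slice out of the scratch buffer used on the MMIO
 * path.  Horizontal slices are contiguous in the ZA storage and can use
 * memset/memcpy; vertical slices must be walked element by element
 * because consecutive elements are a full storage row apart.
 */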

typedef void ClearFn(void *ptr, size_t off, size_t len);

static void clear_horizontal(void *ptr, size_t off, size_t len)
{
    memset(ptr + off, 0, len);
}

static void clear_vertical_b(void *vptr, size_t off, size_t len)
{
    for (size_t i = 0; i < len; ++i) {
        *(uint8_t *)(vptr + tile_vslice_offset(i + off)) = 0;
    }
}

static void clear_vertical_h(void *vptr, size_t off, size_t len)
{
    for (size_t i = 0; i < len; i += 2) {
        *(uint16_t *)(vptr + tile_vslice_offset(i + off)) = 0;
    }
}

static void clear_vertical_s(void *vptr, size_t off, size_t len)
{
    for (size_t i = 0; i < len; i += 4) {
        *(uint32_t *)(vptr + tile_vslice_offset(i + off)) = 0;
    }
}

static void clear_vertical_d(void *vptr, size_t off, size_t len)
{
    for (size_t i = 0; i < len; i += 8) {
        *(uint64_t *)(vptr + tile_vslice_offset(i + off)) = 0;
    }
}

static void clear_vertical_q(void *vptr, size_t off, size_t len)
{
    for (size_t i = 0; i < len; i += 16) {
        memset(vptr + tile_vslice_offset(i + off), 0, 16);
    }
}

/*
 * Copy elements from an array into a tile slice comprising len bytes.
 */

typedef void CopyFn(void *dst, const void *src, size_t len);

static void copy_horizontal(void *dst, const void *src, size_t len)
{
    memcpy(dst, src, len);
}

static void copy_vertical_b(void *vdst, const void *vsrc, size_t len)
{
    const uint8_t *src = vsrc;
    uint8_t *dst = vdst;
    size_t i;

    for (i = 0; i < len; ++i) {
        dst[tile_vslice_index(i)] = src[i];
    }
}

static void copy_vertical_h(void *vdst, const void *vsrc, size_t len)
{
    const uint16_t *src = vsrc;
    uint16_t *dst = vdst;
    size_t i;

    for (i = 0; i < len / 2; ++i) {
        dst[tile_vslice_index(i)] = src[i];
    }
}

static void copy_vertical_s(void *vdst, const void *vsrc, size_t len)
{
    const uint32_t *src = vsrc;
    uint32_t *dst = vdst;
    size_t i;

    for (i = 0; i < len / 4; ++i) {
        dst[tile_vslice_index(i)] = src[i];
    }
}

static void copy_vertical_d(void *vdst, const void *vsrc, size_t len)
{
    const uint64_t *src = vsrc;
    uint64_t *dst = vdst;
    size_t i;

    for (i = 0; i < len / 8; ++i) {
        dst[tile_vslice_index(i)] = src[i];
    }
}

static void copy_vertical_q(void *vdst, const void *vsrc, size_t len)
{
    for (size_t i = 0; i < len; i += 16) {
        memcpy(vdst + tile_vslice_offset(i), vsrc + i, 16);
    }
}

void HELPER(sme2_mova_cz_b)(void *vdst, void *vsrc, uint32_t desc)
{
    copy_vertical_b(vdst, vsrc, simd_oprsz(desc));
}

void HELPER(sme2_mova_cz_h)(void *vdst, void *vsrc, uint32_t desc)
{
    copy_vertical_h(vdst, vsrc, simd_oprsz(desc));
}

void HELPER(sme2_mova_cz_s)(void *vdst, void *vsrc, uint32_t desc)
{
    copy_vertical_s(vdst, vsrc, simd_oprsz(desc));
}

void HELPER(sme2_mova_cz_d)(void *vdst, void *vsrc, uint32_t desc)
{
    copy_vertical_d(vdst, vsrc, simd_oprsz(desc));
}

/*
 * Host and TLB primitives for vertical tile slice addressing.
 */
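
/*
 * The _v_host/_v_tlb functions defined here rescale the element's byte
 * offset within the slice with tile_vslice_offset(), so that sme_ld1
 * and sme_st1 below can share one code path for horizontal and vertical
 * slices, differing only in which host/tlb (and clear/copy) callbacks
 * are passed in.
 */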

#define DO_LD(NAME, TYPE, HOST, TLB)                                        \
static inline void sme_##NAME##_v_host(void *za, intptr_t off, void *host)  \
{                                                                           \
    TYPE val = HOST(host);                                                  \
    *(TYPE *)(za + tile_vslice_offset(off)) = val;                          \
}                                                                           \
static inline void sme_##NAME##_v_tlb(CPUARMState *env, void *za,           \
                intptr_t off, target_ulong addr, uintptr_t ra)              \
{                                                                           \
    TYPE val = TLB(env, useronly_clean_ptr(addr), ra);                      \
    *(TYPE *)(za + tile_vslice_offset(off)) = val;                          \
}

#define DO_ST(NAME, TYPE, HOST, TLB)                                        \
static inline void sme_##NAME##_v_host(void *za, intptr_t off, void *host)  \
{                                                                           \
    TYPE val = *(TYPE *)(za + tile_vslice_offset(off));                     \
    HOST(host, val);                                                        \
}                                                                           \
static inline void sme_##NAME##_v_tlb(CPUARMState *env, void *za,           \
                intptr_t off, target_ulong addr, uintptr_t ra)              \
{                                                                           \
    TYPE val = *(TYPE *)(za + tile_vslice_offset(off));                     \
    TLB(env, useronly_clean_ptr(addr), val, ra);                            \
}

#define DO_LDQ(HNAME, VNAME)                                                \
static inline void VNAME##_v_host(void *za, intptr_t off, void *host)       \
{                                                                           \
    HNAME##_host(za, tile_vslice_offset(off), host);                        \
}                                                                           \
static inline void VNAME##_v_tlb(CPUARMState *env, void *za, intptr_t off,  \
                                 target_ulong addr, uintptr_t ra)           \
{                                                                           \
    HNAME##_tlb(env, za, tile_vslice_offset(off), addr, ra);                \
}

#define DO_STQ(HNAME, VNAME)                                                \
static inline void VNAME##_v_host(void *za, intptr_t off, void *host)       \
{                                                                           \
    HNAME##_host(za, tile_vslice_offset(off), host);                        \
}                                                                           \
static inline void VNAME##_v_tlb(CPUARMState *env, void *za, intptr_t off,  \
                                 target_ulong addr, uintptr_t ra)           \
{                                                                           \
    HNAME##_tlb(env, za, tile_vslice_offset(off), addr, ra);                \
}

DO_LD(ld1b, uint8_t, ldub_p, cpu_ldub_data_ra)
DO_LD(ld1h_be, uint16_t, lduw_be_p, cpu_lduw_be_data_ra)
DO_LD(ld1h_le, uint16_t, lduw_le_p, cpu_lduw_le_data_ra)
DO_LD(ld1s_be, uint32_t, ldl_be_p, cpu_ldl_be_data_ra)
DO_LD(ld1s_le, uint32_t, ldl_le_p, cpu_ldl_le_data_ra)
DO_LD(ld1d_be, uint64_t, ldq_be_p, cpu_ldq_be_data_ra)
DO_LD(ld1d_le, uint64_t, ldq_le_p, cpu_ldq_le_data_ra)

DO_LDQ(sve_ld1qq_be, sme_ld1q_be)
DO_LDQ(sve_ld1qq_le, sme_ld1q_le)

DO_ST(st1b, uint8_t, stb_p, cpu_stb_data_ra)
DO_ST(st1h_be, uint16_t, stw_be_p, cpu_stw_be_data_ra)
DO_ST(st1h_le, uint16_t, stw_le_p, cpu_stw_le_data_ra)
DO_ST(st1s_be, uint32_t, stl_be_p, cpu_stl_be_data_ra)
DO_ST(st1s_le, uint32_t, stl_le_p, cpu_stl_le_data_ra)
DO_ST(st1d_be, uint64_t, stq_be_p, cpu_stq_be_data_ra)
DO_ST(st1d_le, uint64_t, stq_le_p, cpu_stq_le_data_ra)

DO_STQ(sve_st1qq_be, sme_st1q_be)
DO_STQ(sve_st1qq_le, sme_st1q_le)

#undef DO_LD
#undef DO_ST
#undef DO_LDQ
#undef DO_STQ

/*
 * Common helper for all contiguous predicated loads.
 */
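
/*
 * The overall flow is: find the active elements, probe the page(s) and
 * take any faults, handle watchpoints and MTE checks, and then either
 * fall back to the tlb functions through a scratch buffer (when MMIO is
 * involved) or load directly from host memory.  Inactive elements of
 * the destination slice are zeroed: wholesale for horizontal slices,
 * element by element via clr_fn for vertical slices.
 */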

static inline QEMU_ALWAYS_INLINE
void sme_ld1(CPUARMState *env, void *za, uint64_t *vg,
             const target_ulong addr, uint32_t desc, const uintptr_t ra,
             const int esz, uint32_t mtedesc, bool vertical,
             sve_ldst1_host_fn *host_fn,
             sve_ldst1_tlb_fn *tlb_fn,
             ClearFn *clr_fn,
             CopyFn *cpy_fn)
{
    const intptr_t reg_max = simd_oprsz(desc);
    const intptr_t esize = 1 << esz;
    intptr_t reg_off, reg_last;
    SVEContLdSt info;
    void *host;
    int flags;

    /* Find the active elements. */
    if (!sve_cont_ldst_elements(&info, addr, vg, reg_max, esz, esize)) {
        /* The entire predicate was false; no load occurs. */
        clr_fn(za, 0, reg_max);
        return;
    }

    /* Probe the page(s). Exit with exception for any invalid page. */
    sve_cont_ldst_pages(&info, FAULT_ALL, env, addr, MMU_DATA_LOAD, ra);

    /* Handle watchpoints for all active elements. */
    sve_cont_ldst_watchpoints(&info, env, vg, addr, esize, esize,
                              BP_MEM_READ, ra);

    /*
     * Handle mte checks for all active elements.
     * Since TBI must be set for MTE, !mtedesc => !mte_active.
     */
    if (mtedesc) {
        sve_cont_ldst_mte_check(&info, env, vg, addr, esize, esize,
                                mtedesc, ra);
    }

    flags = info.page[0].flags | info.page[1].flags;
    if (unlikely(flags != 0)) {
#ifdef CONFIG_USER_ONLY
        g_assert_not_reached();
#else
        /*
         * At least one page includes MMIO.
         * Any bus operation can fail with cpu_transaction_failed,
         * which for ARM will raise SyncExternal. Perform the load
         * into scratch memory to preserve register state until the end.
         */
        ARMVectorReg scratch = { };

        reg_off = info.reg_off_first[0];
        reg_last = info.reg_off_last[1];
        if (reg_last < 0) {
            reg_last = info.reg_off_split;
            if (reg_last < 0) {
                reg_last = info.reg_off_last[0];
            }
        }

        do {
            uint64_t pg = vg[reg_off >> 6];
            do {
                if ((pg >> (reg_off & 63)) & 1) {
                    tlb_fn(env, &scratch, reg_off, addr + reg_off, ra);
                }
                reg_off += esize;
            } while (reg_off & 63);
        } while (reg_off <= reg_last);

        cpy_fn(za, &scratch, reg_max);
        return;
#endif
    }

    /* The entire operation is in RAM, on valid pages. */

    reg_off = info.reg_off_first[0];
    reg_last = info.reg_off_last[0];
    host = info.page[0].host;

    if (!vertical) {
        memset(za, 0, reg_max);
    } else if (reg_off) {
        clr_fn(za, 0, reg_off);
    }

    set_helper_retaddr(ra);

    while (reg_off <= reg_last) {
        uint64_t pg = vg[reg_off >> 6];
        do {
            if ((pg >> (reg_off & 63)) & 1) {
                host_fn(za, reg_off, host + reg_off);
            } else if (vertical) {
                clr_fn(za, reg_off, esize);
            }
            reg_off += esize;
        } while (reg_off <= reg_last && (reg_off & 63));
    }

    clear_helper_retaddr();

    /*
     * Use the slow path to manage the cross-page misalignment.
     * But we know this is RAM and cannot trap.
     */
    reg_off = info.reg_off_split;
    if (unlikely(reg_off >= 0)) {
        tlb_fn(env, za, reg_off, addr + reg_off, ra);
    }

    reg_off = info.reg_off_first[1];
    if (unlikely(reg_off >= 0)) {
        reg_last = info.reg_off_last[1];
        host = info.page[1].host;

        set_helper_retaddr(ra);

        do {
            uint64_t pg = vg[reg_off >> 6];
            do {
                if ((pg >> (reg_off & 63)) & 1) {
                    host_fn(za, reg_off, host + reg_off);
                } else if (vertical) {
                    clr_fn(za, reg_off, esize);
                }
                reg_off += esize;
            } while (reg_off & 63);
        } while (reg_off <= reg_last);

        clear_helper_retaddr();
    }
}
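
/*
 * For the _mte variants below, the 64-bit desc packs the usual 32-bit
 * simd descriptor in the low half and the MTE descriptor in the high
 * half; sme_ld1 itself only ever sees the low 32 bits.
 */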

static inline QEMU_ALWAYS_INLINE
void sme_ld1_mte(CPUARMState *env, void *za, uint64_t *vg,
                 target_ulong addr, uint64_t desc, uintptr_t ra,
                 const int esz, bool vertical,
                 sve_ldst1_host_fn *host_fn,
                 sve_ldst1_tlb_fn *tlb_fn,
                 ClearFn *clr_fn,
                 CopyFn *cpy_fn)
{
    uint32_t mtedesc = desc >> 32;
    int bit55 = extract64(addr, 55, 1);

    /* Perform gross MTE suppression early. */
    if (!tbi_check(mtedesc, bit55) ||
        tcma_check(mtedesc, bit55, allocation_tag_from_addr(addr))) {
        mtedesc = 0;
    }

    sme_ld1(env, za, vg, addr, desc, ra, esz, mtedesc, vertical,
            host_fn, tlb_fn, clr_fn, cpy_fn);
}

#define DO_LD(L, END, ESZ)                                                  \
void HELPER(sme_ld1##L##END##_h)(CPUARMState *env, void *za, void *vg,      \
                                 target_ulong addr, uint64_t desc)          \
{                                                                           \
    sme_ld1(env, za, vg, addr, desc, GETPC(), ESZ, 0, false,                \
            sve_ld1##L##L##END##_host, sve_ld1##L##L##END##_tlb,            \
            clear_horizontal, copy_horizontal);                             \
}                                                                           \
void HELPER(sme_ld1##L##END##_v)(CPUARMState *env, void *za, void *vg,      \
                                 target_ulong addr, uint64_t desc)          \
{                                                                           \
    sme_ld1(env, za, vg, addr, desc, GETPC(), ESZ, 0, true,                 \
            sme_ld1##L##END##_v_host, sme_ld1##L##END##_v_tlb,              \
            clear_vertical_##L, copy_vertical_##L);                         \
}                                                                           \
void HELPER(sme_ld1##L##END##_h_mte)(CPUARMState *env, void *za, void *vg,  \
                                     target_ulong addr, uint64_t desc)      \
{                                                                           \
    sme_ld1_mte(env, za, vg, addr, desc, GETPC(), ESZ, false,               \
                sve_ld1##L##L##END##_host, sve_ld1##L##L##END##_tlb,        \
                clear_horizontal, copy_horizontal);                         \
}                                                                           \
void HELPER(sme_ld1##L##END##_v_mte)(CPUARMState *env, void *za, void *vg,  \
                                     target_ulong addr, uint64_t desc)      \
{                                                                           \
    sme_ld1_mte(env, za, vg, addr, desc, GETPC(), ESZ, true,                \
                sme_ld1##L##END##_v_host, sme_ld1##L##END##_v_tlb,          \
                clear_vertical_##L, copy_vertical_##L);                     \
}

DO_LD(b, , MO_8)
DO_LD(h, _be, MO_16)
DO_LD(h, _le, MO_16)
DO_LD(s, _be, MO_32)
DO_LD(s, _le, MO_32)
DO_LD(d, _be, MO_64)
DO_LD(d, _le, MO_64)
DO_LD(q, _be, MO_128)
DO_LD(q, _le, MO_128)

#undef DO_LD
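
/*
 * Each DO_LD expansion above provides the four helpers for one element
 * size and endianness; e.g. DO_LD(h, _be, MO_16) defines sme_ld1h_be_h,
 * sme_ld1h_be_v, sme_ld1h_be_h_mte and sme_ld1h_be_v_mte.
 */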

/*
 * Common helper for all contiguous predicated stores.
 */

static inline QEMU_ALWAYS_INLINE
void sme_st1(CPUARMState *env, void *za, uint64_t *vg,
             const target_ulong addr, uint32_t desc, const uintptr_t ra,
             const int esz, uint32_t mtedesc, bool vertical,
             sve_ldst1_host_fn *host_fn,
             sve_ldst1_tlb_fn *tlb_fn)
{
    const intptr_t reg_max = simd_oprsz(desc);
    const intptr_t esize = 1 << esz;
    intptr_t reg_off, reg_last;
    SVEContLdSt info;
    void *host;
    int flags;

    /* Find the active elements. */
    if (!sve_cont_ldst_elements(&info, addr, vg, reg_max, esz, esize)) {
        /* The entire predicate was false; no store occurs. */
        return;
    }

    /* Probe the page(s). Exit with exception for any invalid page. */
    sve_cont_ldst_pages(&info, FAULT_ALL, env, addr, MMU_DATA_STORE, ra);

    /* Handle watchpoints for all active elements. */
    sve_cont_ldst_watchpoints(&info, env, vg, addr, esize, esize,
                              BP_MEM_WRITE, ra);

    /*
     * Handle mte checks for all active elements.
     * Since TBI must be set for MTE, !mtedesc => !mte_active.
     */
    if (mtedesc) {
        sve_cont_ldst_mte_check(&info, env, vg, addr, esize, esize,
                                mtedesc, ra);
    }

    flags = info.page[0].flags | info.page[1].flags;
    if (unlikely(flags != 0)) {
#ifdef CONFIG_USER_ONLY
        g_assert_not_reached();
#else
        /*
         * At least one page includes MMIO.
         * Any bus operation can fail with cpu_transaction_failed,
         * which for ARM will raise SyncExternal. We cannot avoid
         * this fault and will leave with the store incomplete.
         */
        reg_off = info.reg_off_first[0];
        reg_last = info.reg_off_last[1];
        if (reg_last < 0) {
            reg_last = info.reg_off_split;
            if (reg_last < 0) {
                reg_last = info.reg_off_last[0];
            }
        }

        do {
            uint64_t pg = vg[reg_off >> 6];
            do {
                if ((pg >> (reg_off & 63)) & 1) {
                    tlb_fn(env, za, reg_off, addr + reg_off, ra);
                }
                reg_off += esize;
            } while (reg_off & 63);
        } while (reg_off <= reg_last);
        return;
#endif
    }

    reg_off = info.reg_off_first[0];
    reg_last = info.reg_off_last[0];
    host = info.page[0].host;

    set_helper_retaddr(ra);

    while (reg_off <= reg_last) {
        uint64_t pg = vg[reg_off >> 6];
        do {
            if ((pg >> (reg_off & 63)) & 1) {
                host_fn(za, reg_off, host + reg_off);
            }
            reg_off += 1 << esz;
        } while (reg_off <= reg_last && (reg_off & 63));
    }

    clear_helper_retaddr();

    /*
     * Use the slow path to manage the cross-page misalignment.
     * But we know this is RAM and cannot trap.
     */
    reg_off = info.reg_off_split;
    if (unlikely(reg_off >= 0)) {
        tlb_fn(env, za, reg_off, addr + reg_off, ra);
    }

    reg_off = info.reg_off_first[1];
    if (unlikely(reg_off >= 0)) {
        reg_last = info.reg_off_last[1];
        host = info.page[1].host;

        set_helper_retaddr(ra);

        do {
            uint64_t pg = vg[reg_off >> 6];
            do {
                if ((pg >> (reg_off & 63)) & 1) {
                    host_fn(za, reg_off, host + reg_off);
                }
                reg_off += 1 << esz;
            } while (reg_off & 63);
        } while (reg_off <= reg_last);

        clear_helper_retaddr();
    }
}
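
/*
 * Unlike the loads above, stores need no clear/copy callbacks: inactive
 * elements are simply not written, and the MMIO fallback can issue the
 * stores directly from ZA since nothing in ZA is modified.
 */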

static inline QEMU_ALWAYS_INLINE
void sme_st1_mte(CPUARMState *env, void *za, uint64_t *vg, target_ulong addr,
                 uint64_t desc, uintptr_t ra, int esz, bool vertical,
                 sve_ldst1_host_fn *host_fn,
                 sve_ldst1_tlb_fn *tlb_fn)
{
    uint32_t mtedesc = desc >> 32;
    int bit55 = extract64(addr, 55, 1);

    /* Perform gross MTE suppression early. */
    if (!tbi_check(mtedesc, bit55) ||
        tcma_check(mtedesc, bit55, allocation_tag_from_addr(addr))) {
        mtedesc = 0;
    }

    sme_st1(env, za, vg, addr, desc, ra, esz, mtedesc,
            vertical, host_fn, tlb_fn);
}

#define DO_ST(L, END, ESZ)                                                  \
void HELPER(sme_st1##L##END##_h)(CPUARMState *env, void *za, void *vg,      \
                                 target_ulong addr, uint64_t desc)          \
{                                                                           \
    sme_st1(env, za, vg, addr, desc, GETPC(), ESZ, 0, false,                \
            sve_st1##L##L##END##_host, sve_st1##L##L##END##_tlb);           \
}                                                                           \
void HELPER(sme_st1##L##END##_v)(CPUARMState *env, void *za, void *vg,      \
                                 target_ulong addr, uint64_t desc)          \
{                                                                           \
    sme_st1(env, za, vg, addr, desc, GETPC(), ESZ, 0, true,                 \
            sme_st1##L##END##_v_host, sme_st1##L##END##_v_tlb);             \
}                                                                           \
void HELPER(sme_st1##L##END##_h_mte)(CPUARMState *env, void *za, void *vg,  \
                                     target_ulong addr, uint64_t desc)      \
{                                                                           \
    sme_st1_mte(env, za, vg, addr, desc, GETPC(), ESZ, false,               \
                sve_st1##L##L##END##_host, sve_st1##L##L##END##_tlb);       \
}                                                                           \
void HELPER(sme_st1##L##END##_v_mte)(CPUARMState *env, void *za, void *vg,  \
                                     target_ulong addr, uint64_t desc)      \
{                                                                           \
    sme_st1_mte(env, za, vg, addr, desc, GETPC(), ESZ, true,                \
                sme_st1##L##END##_v_host, sme_st1##L##END##_v_tlb);         \
}

DO_ST(b, , MO_8)
DO_ST(h, _be, MO_16)
DO_ST(h, _le, MO_16)
DO_ST(s, _be, MO_32)
DO_ST(s, _le, MO_32)
DO_ST(d, _be, MO_64)
DO_ST(d, _le, MO_64)
DO_ST(q, _be, MO_128)
DO_ST(q, _le, MO_128)

#undef DO_ST

void HELPER(sme_addha_s)(void *vzda, void *vzn, void *vpn,
                         void *vpm, uint32_t desc)
{
    intptr_t row, col, oprsz = simd_oprsz(desc) / 4;
    uint64_t *pn = vpn, *pm = vpm;
    uint32_t *zda = vzda, *zn = vzn;

    for (row = 0; row < oprsz; ) {
        uint64_t pa = pn[row >> 4];
        do {
            if (pa & 1) {
                for (col = 0; col < oprsz; ) {
                    uint64_t pb = pm[col >> 4];
                    do {
                        if (pb & 1) {
                            zda[tile_vslice_index(row) + H4(col)] += zn[H4(col)];
                        }
                        pb >>= 4;
                    } while (++col & 15);
                }
            }
            pa >>= 4;
        } while (++row & 15);
    }
}

void HELPER(sme_addha_d)(void *vzda, void *vzn, void *vpn,
                         void *vpm, uint32_t desc)
{
    intptr_t row, col, oprsz = simd_oprsz(desc) / 8;
    uint8_t *pn = vpn, *pm = vpm;
    uint64_t *zda = vzda, *zn = vzn;

    for (row = 0; row < oprsz; ++row) {
        if (pn[H1(row)] & 1) {
            for (col = 0; col < oprsz; ++col) {
                if (pm[H1(col)] & 1) {
                    zda[tile_vslice_index(row) + col] += zn[col];
                }
            }
        }
    }
}

void HELPER(sme_addva_s)(void *vzda, void *vzn, void *vpn,
                         void *vpm, uint32_t desc)
{
    intptr_t row, col, oprsz = simd_oprsz(desc) / 4;
    uint64_t *pn = vpn, *pm = vpm;
    uint32_t *zda = vzda, *zn = vzn;

    for (row = 0; row < oprsz; ) {
        uint64_t pa = pn[row >> 4];
        do {
            if (pa & 1) {
                uint32_t zn_row = zn[H4(row)];
                for (col = 0; col < oprsz; ) {
                    uint64_t pb = pm[col >> 4];
                    do {
                        if (pb & 1) {
                            zda[tile_vslice_index(row) + H4(col)] += zn_row;
                        }
                        pb >>= 4;
                    } while (++col & 15);
                }
            }
            pa >>= 4;
        } while (++row & 15);
    }
}

void HELPER(sme_addva_d)(void *vzda, void *vzn, void *vpn,
                         void *vpm, uint32_t desc)
{
    intptr_t row, col, oprsz = simd_oprsz(desc) / 8;
    uint8_t *pn = vpn, *pm = vpm;
    uint64_t *zda = vzda, *zn = vzn;

    for (row = 0; row < oprsz; ++row) {
        if (pn[H1(row)] & 1) {
            uint64_t zn_row = zn[row];
            for (col = 0; col < oprsz; ++col) {
                if (pm[H1(col)] & 1) {
                    zda[tile_vslice_index(row) + col] += zn_row;
                }
            }
        }
    }
}
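
/*
 * For the floating-point outer products below, the FMOPS forms are
 * implemented by negating the row operand: either by XORing its sign
 * bit ahead of time (negx), or, for the sme_ah_* (FPCR.AH) forms, by
 * asking float*_muladd to negate the product instead (negf), so that
 * the sign of a NaN row element is left unchanged.
 */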

static void do_fmopa_h(void *vza, void *vzn, void *vzm, uint16_t *pn,
                       uint16_t *pm, float_status *fpst, uint32_t desc,
                       uint16_t negx, int negf)
{
    intptr_t row, col, oprsz = simd_maxsz(desc);

    for (row = 0; row < oprsz; ) {
        uint16_t pa = pn[H2(row >> 4)];
        do {
            if (pa & 1) {
                void *vza_row = vza + tile_vslice_offset(row);
                uint16_t n = *(uint16_t *)(vzn + H1_2(row)) ^ negx;

                for (col = 0; col < oprsz; ) {
                    uint16_t pb = pm[H2(col >> 4)];
                    do {
                        if (pb & 1) {
                            uint16_t *a = vza_row + H1_2(col);
                            uint16_t *m = vzm + H1_2(col);
                            *a = float16_muladd(n, *m, *a, negf, fpst);
                        }
                        col += 2;
                        pb >>= 2;
                    } while (col & 15);
                }
            }
            row += 2;
            pa >>= 2;
        } while (row & 15);
    }
}

void HELPER(sme_fmopa_h)(void *vza, void *vzn, void *vzm, void *vpn,
                         void *vpm, float_status *fpst, uint32_t desc)
{
    do_fmopa_h(vza, vzn, vzm, vpn, vpm, fpst, desc, 0, 0);
}

void HELPER(sme_fmops_h)(void *vza, void *vzn, void *vzm, void *vpn,
                         void *vpm, float_status *fpst, uint32_t desc)
{
    do_fmopa_h(vza, vzn, vzm, vpn, vpm, fpst, desc, 1u << 15, 0);
}

void HELPER(sme_ah_fmops_h)(void *vza, void *vzn, void *vzm, void *vpn,
                            void *vpm, float_status *fpst, uint32_t desc)
{
    do_fmopa_h(vza, vzn, vzm, vpn, vpm, fpst, desc, 0,
               float_muladd_negate_product);
}

static void do_fmopa_s(void *vza, void *vzn, void *vzm, uint16_t *pn,
                       uint16_t *pm, float_status *fpst, uint32_t desc,
                       uint32_t negx, int negf)
{
    intptr_t row, col, oprsz = simd_maxsz(desc);

    for (row = 0; row < oprsz; ) {
        uint16_t pa = pn[H2(row >> 4)];
        do {
            if (pa & 1) {
                void *vza_row = vza + tile_vslice_offset(row);
                uint32_t n = *(uint32_t *)(vzn + H1_4(row)) ^ negx;

                for (col = 0; col < oprsz; ) {
                    uint16_t pb = pm[H2(col >> 4)];
                    do {
                        if (pb & 1) {
                            uint32_t *a = vza_row + H1_4(col);
                            uint32_t *m = vzm + H1_4(col);
                            *a = float32_muladd(n, *m, *a, negf, fpst);
                        }
                        col += 4;
                        pb >>= 4;
                    } while (col & 15);
                }
            }
            row += 4;
            pa >>= 4;
        } while (row & 15);
    }
}

void HELPER(sme_fmopa_s)(void *vza, void *vzn, void *vzm, void *vpn,
                         void *vpm, float_status *fpst, uint32_t desc)
{
    do_fmopa_s(vza, vzn, vzm, vpn, vpm, fpst, desc, 0, 0);
}

void HELPER(sme_fmops_s)(void *vza, void *vzn, void *vzm, void *vpn,
                         void *vpm, float_status *fpst, uint32_t desc)
{
    do_fmopa_s(vza, vzn, vzm, vpn, vpm, fpst, desc, 1u << 31, 0);
}

void HELPER(sme_ah_fmops_s)(void *vza, void *vzn, void *vzm, void *vpn,
                            void *vpm, float_status *fpst, uint32_t desc)
{
    do_fmopa_s(vza, vzn, vzm, vpn, vpm, fpst, desc, 0,
               float_muladd_negate_product);
}

static void do_fmopa_d(uint64_t *za, uint64_t *zn, uint64_t *zm, uint8_t *pn,
                       uint8_t *pm, float_status *fpst, uint32_t desc,
                       uint64_t negx, int negf)
{
    intptr_t row, col, oprsz = simd_oprsz(desc) / 8;

    for (row = 0; row < oprsz; ++row) {
        if (pn[H1(row)] & 1) {
            uint64_t *za_row = &za[tile_vslice_index(row)];
            uint64_t n = zn[row] ^ negx;

            for (col = 0; col < oprsz; ++col) {
                if (pm[H1(col)] & 1) {
                    uint64_t *a = &za_row[col];
                    *a = float64_muladd(n, zm[col], *a, negf, fpst);
                }
            }
        }
    }
}

void HELPER(sme_fmopa_d)(void *vza, void *vzn, void *vzm, void *vpn,
                         void *vpm, float_status *fpst, uint32_t desc)
{
    do_fmopa_d(vza, vzn, vzm, vpn, vpm, fpst, desc, 0, 0);
}

void HELPER(sme_fmops_d)(void *vza, void *vzn, void *vzm, void *vpn,
                         void *vpm, float_status *fpst, uint32_t desc)
{
    do_fmopa_d(vza, vzn, vzm, vpn, vpm, fpst, desc, 1ull << 63, 0);
}

void HELPER(sme_ah_fmops_d)(void *vza, void *vzn, void *vzm, void *vpn,
                            void *vpm, float_status *fpst, uint32_t desc)
{
    do_fmopa_d(vza, vzn, vzm, vpn, vpm, fpst, desc, 0,
               float_muladd_negate_product);
}

static void do_bfmopa(void *vza, void *vzn, void *vzm, uint16_t *pn,
                      uint16_t *pm, float_status *fpst, uint32_t desc,
                      uint16_t negx, int negf)
{
    intptr_t row, col, oprsz = simd_maxsz(desc);

    for (row = 0; row < oprsz; ) {
        uint16_t pa = pn[H2(row >> 4)];
        do {
            if (pa & 1) {
                void *vza_row = vza + tile_vslice_offset(row);
                uint16_t n = *(uint16_t *)(vzn + H1_2(row)) ^ negx;

                for (col = 0; col < oprsz; ) {
                    uint16_t pb = pm[H2(col >> 4)];
                    do {
                        if (pb & 1) {
                            uint16_t *a = vza_row + H1_2(col);
                            uint16_t *m = vzm + H1_2(col);
                            *a = bfloat16_muladd(n, *m, *a, negf, fpst);
                        }
                        col += 2;
                        pb >>= 2;
                    } while (col & 15);
                }
            }
            row += 2;
            pa >>= 2;
        } while (row & 15);
    }
}

void HELPER(sme_bfmopa)(void *vza, void *vzn, void *vzm, void *vpn,
                        void *vpm, float_status *fpst, uint32_t desc)
{
    do_bfmopa(vza, vzn, vzm, vpn, vpm, fpst, desc, 0, 0);
}

void HELPER(sme_bfmops)(void *vza, void *vzn, void *vzm, void *vpn,
                        void *vpm, float_status *fpst, uint32_t desc)
{
    do_bfmopa(vza, vzn, vzm, vpn, vpm, fpst, desc, 1u << 15, 0);
}

void HELPER(sme_ah_bfmops)(void *vza, void *vzn, void *vzm, void *vpn,
                           void *vpm, float_status *fpst, uint32_t desc)
{
    do_bfmopa(vza, vzn, vzm, vpn, vpm, fpst, desc, 0,
              float_muladd_negate_product);
}

/*
 * Alter PAIR as needed for controlling predicates being false,
 * and for NEG on an enabled row element.
 */
static inline uint32_t f16mop_adj_pair(uint32_t pair, uint32_t pg, uint32_t neg)
{
    /*
     * The pseudocode uses a conditional negate after the conditional zero.
     * It is simpler here to unconditionally negate before conditional zero.
     */
    pair ^= neg;
    if (!(pg & 1)) {
        pair &= 0xffff0000u;
    }
    if (!(pg & 4)) {
        pair &= 0x0000ffffu;
    }
    return pair;
}

static inline uint32_t f16mop_ah_neg_adj_pair(uint32_t pair, uint32_t pg)
{
    uint32_t l = pg & 1 ? float16_ah_chs(pair) : 0;
    uint32_t h = pg & 4 ? float16_ah_chs(pair >> 16) : 0;
    return l | (h << 16);
}

static inline uint32_t bf16mop_ah_neg_adj_pair(uint32_t pair, uint32_t pg)
{
    uint32_t l = pg & 1 ? bfloat16_ah_chs(pair) : 0;
    uint32_t h = pg & 4 ? bfloat16_ah_chs(pair >> 16) : 0;
    return l | (h << 16);
}
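
/*
 * In the pair helpers above, PG is the predicate nibble covering the
 * 32-bit pair: with 16-bit elements the governing predicate bits are
 * two apart, so bit 0 controls the low half of the pair and bit 2 the
 * high half.
 */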

static float32 f16_dotadd(float32 sum, uint32_t e1, uint32_t e2,
                          float_status *s_f16, float_status *s_std,
                          float_status *s_odd)
{
    /*
     * We need three different float_status for different parts of this
     * operation:
     * - the input conversion of the float16 values must use the
     *   f16-specific float_status, so that the FPCR.FZ16 control is applied
     * - operations on float32 including the final accumulation must use
     *   the normal float_status, so that FPCR.FZ is applied
     * - we have pre-set-up copy of s_std which is set to round-to-odd,
     *   for the multiply (see below)
     */
    float16 h1r = e1 & 0xffff;
    float16 h1c = e1 >> 16;
    float16 h2r = e2 & 0xffff;
    float16 h2c = e2 >> 16;
    float32 t32;

    /* C.f. FPProcessNaNs4 */
    if (float16_is_any_nan(h1r) || float16_is_any_nan(h1c) ||
        float16_is_any_nan(h2r) || float16_is_any_nan(h2c)) {
        float16 t16;

        if (float16_is_signaling_nan(h1r, s_f16)) {
            t16 = h1r;
        } else if (float16_is_signaling_nan(h1c, s_f16)) {
            t16 = h1c;
        } else if (float16_is_signaling_nan(h2r, s_f16)) {
            t16 = h2r;
        } else if (float16_is_signaling_nan(h2c, s_f16)) {
            t16 = h2c;
        } else if (float16_is_any_nan(h1r)) {
            t16 = h1r;
        } else if (float16_is_any_nan(h1c)) {
            t16 = h1c;
        } else if (float16_is_any_nan(h2r)) {
            t16 = h2r;
        } else {
            t16 = h2c;
        }
        t32 = float16_to_float32(t16, true, s_f16);
    } else {
        float64 e1r = float16_to_float64(h1r, true, s_f16);
        float64 e1c = float16_to_float64(h1c, true, s_f16);
        float64 e2r = float16_to_float64(h2r, true, s_f16);
        float64 e2c = float16_to_float64(h2c, true, s_f16);
        float64 t64;

        /*
         * The ARM pseudocode function FPDot performs both multiplies
         * and the add with a single rounding operation. Emulate this
         * by performing the first multiply in round-to-odd, then doing
         * the second multiply as fused multiply-add, and rounding to
         * float32 all in one step.
         */
        t64 = float64_mul(e1r, e2r, s_odd);
        t64 = float64r32_muladd(e1c, e2c, t64, 0, s_std);

        /* This conversion is exact, because we've already rounded. */
        t32 = float64_to_float32(t64, s_std);
    }

    /* The final accumulation step is not fused. */
    return float32_add(sum, t32, s_std);
}
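
/*
 * The widening outer products below (f16 inputs, f32 ZA accumulator)
 * add a 2-way dot product of adjacent f16 pairs into each 32-bit ZA
 * element, using f16_dotadd above to combine one pair from the row
 * operand with one pair from the column operand.
 */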

static void do_fmopa_w_h(void *vza, void *vzn, void *vzm, uint16_t *pn,
                         uint16_t *pm, CPUARMState *env, uint32_t desc,
                         uint32_t negx, bool ah_neg)
{
    intptr_t row, col, oprsz = simd_maxsz(desc);
    float_status fpst_odd = env->vfp.fp_status[FPST_ZA];

    set_float_rounding_mode(float_round_to_odd, &fpst_odd);

    for (row = 0; row < oprsz; ) {
        uint16_t prow = pn[H2(row >> 4)];
        do {
            void *vza_row = vza + tile_vslice_offset(row);
            uint32_t n = *(uint32_t *)(vzn + H1_4(row));

            if (ah_neg) {
                n = f16mop_ah_neg_adj_pair(n, prow);
            } else {
                n = f16mop_adj_pair(n, prow, negx);
            }

            for (col = 0; col < oprsz; ) {
                uint16_t pcol = pm[H2(col >> 4)];
                do {
                    if (prow & pcol & 0b0101) {
                        uint32_t *a = vza_row + H1_4(col);
                        uint32_t m = *(uint32_t *)(vzm + H1_4(col));

                        m = f16mop_adj_pair(m, pcol, 0);
                        *a = f16_dotadd(*a, n, m,
                                        &env->vfp.fp_status[FPST_ZA_F16],
                                        &env->vfp.fp_status[FPST_ZA],
                                        &fpst_odd);
                    }
                    col += 4;
                    pcol >>= 4;
                } while (col & 15);
            }
            row += 4;
            prow >>= 4;
        } while (row & 15);
    }
}

void HELPER(sme_fmopa_w_h)(void *vza, void *vzn, void *vzm, void *vpn,
                           void *vpm, CPUARMState *env, uint32_t desc)
{
    do_fmopa_w_h(vza, vzn, vzm, vpn, vpm, env, desc, 0, false);
}

void HELPER(sme_fmops_w_h)(void *vza, void *vzn, void *vzm, void *vpn,
                           void *vpm, CPUARMState *env, uint32_t desc)
{
    do_fmopa_w_h(vza, vzn, vzm, vpn, vpm, env, desc, 0x80008000u, false);
}

void HELPER(sme_ah_fmops_w_h)(void *vza, void *vzn, void *vzm, void *vpn,
                              void *vpm, CPUARMState *env, uint32_t desc)
{
    do_fmopa_w_h(vza, vzn, vzm, vpn, vpm, env, desc, 0, true);
}

void HELPER(sme2_fdot_h)(void *vd, void *vn, void *vm, void *va,
                         CPUARMState *env, uint32_t desc)
{
    intptr_t i, oprsz = simd_maxsz(desc);
    bool za = extract32(desc, SIMD_DATA_SHIFT, 1);
    float_status *fpst_std = &env->vfp.fp_status[za ? FPST_ZA : FPST_A64];
    float_status *fpst_f16 = &env->vfp.fp_status[za ? FPST_ZA_F16 : FPST_A64_F16];
    float_status fpst_odd = *fpst_std;
    float32 *d = vd, *a = va;
    uint32_t *n = vn, *m = vm;

    set_float_rounding_mode(float_round_to_odd, &fpst_odd);

    for (i = 0; i < oprsz / sizeof(float32); ++i) {
        d[H4(i)] = f16_dotadd(a[H4(i)], n[H4(i)], m[H4(i)],
                              fpst_f16, fpst_std, &fpst_odd);
    }
}

void HELPER(sme2_fdot_idx_h)(void *vd, void *vn, void *vm, void *va,
                             CPUARMState *env, uint32_t desc)
{
    intptr_t i, j, oprsz = simd_maxsz(desc);
    intptr_t elements = oprsz / sizeof(float32);
    intptr_t eltspersegment = MIN(4, elements);
    int idx = extract32(desc, SIMD_DATA_SHIFT, 2);
    bool za = extract32(desc, SIMD_DATA_SHIFT + 2, 1);
    float_status *fpst_std = &env->vfp.fp_status[za ? FPST_ZA : FPST_A64];
    float_status *fpst_f16 = &env->vfp.fp_status[za ? FPST_ZA_F16 : FPST_A64_F16];
    float_status fpst_odd = *fpst_std;
    float32 *d = vd, *a = va;
    uint32_t *n = vn, *m = (uint32_t *)vm + H4(idx);

    set_float_rounding_mode(float_round_to_odd, &fpst_odd);

    for (i = 0; i < elements; i += eltspersegment) {
        uint32_t mm = m[i];
        for (j = 0; j < eltspersegment; ++j) {
            d[H4(i + j)] = f16_dotadd(a[H4(i + j)], n[H4(i + j)], mm,
                                      fpst_f16, fpst_std, &fpst_odd);
        }
    }
}
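
/*
 * For the indexed dot products above and below, the index selects one
 * 32-bit (f16 pair) element within each 128-bit segment of the
 * multiplicand, and that value is reused for every element of the
 * segment.
 */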

void HELPER(sme2_fvdot_idx_h)(void *vd, void *vn, void *vm, void *va,
                              CPUARMState *env, uint32_t desc)
{
    intptr_t i, j, oprsz = simd_maxsz(desc);
    intptr_t elements = oprsz / sizeof(float32);
    intptr_t eltspersegment = MIN(4, elements);
    int idx = extract32(desc, SIMD_DATA_SHIFT, 2);
    int sel = extract32(desc, SIMD_DATA_SHIFT + 2, 1);
    float_status fpst_odd, *fpst_std, *fpst_f16;
    float32 *d = vd, *a = va;
    uint16_t *n0 = vn;
    uint16_t *n1 = vn + sizeof(ARMVectorReg);
    uint32_t *m = (uint32_t *)vm + H4(idx);

    fpst_std = &env->vfp.fp_status[FPST_ZA];
    fpst_f16 = &env->vfp.fp_status[FPST_ZA_F16];
    fpst_odd = *fpst_std;
    set_float_rounding_mode(float_round_to_odd, &fpst_odd);

    for (i = 0; i < elements; i += eltspersegment) {
        uint32_t mm = m[i];
        for (j = 0; j < eltspersegment; ++j) {
            uint32_t nn = (n0[H2(2 * (i + j) + sel)])
                        | (n1[H2(2 * (i + j) + sel)] << 16);
            d[i + H4(j)] = f16_dotadd(a[i + H4(j)], nn, mm,
                                      fpst_f16, fpst_std, &fpst_odd);
        }
    }
}

static void do_bfmopa_w(void *vza, void *vzn, void *vzm,
                        uint16_t *pn, uint16_t *pm, CPUARMState *env,
                        uint32_t desc, uint32_t negx, bool ah_neg)
{
    intptr_t row, col, oprsz = simd_maxsz(desc);
    float_status fpst, fpst_odd;

    if (is_ebf(env, &fpst, &fpst_odd)) {
        for (row = 0; row < oprsz; ) {
            uint16_t prow = pn[H2(row >> 4)];
            do {
                void *vza_row = vza + tile_vslice_offset(row);
                uint32_t n = *(uint32_t *)(vzn + H1_4(row));

                if (ah_neg) {
                    n = bf16mop_ah_neg_adj_pair(n, prow);
                } else {
                    n = f16mop_adj_pair(n, prow, negx);
                }

                for (col = 0; col < oprsz; ) {
                    uint16_t pcol = pm[H2(col >> 4)];
                    do {
                        if (prow & pcol & 0b0101) {
                            uint32_t *a = vza_row + H1_4(col);
                            uint32_t m = *(uint32_t *)(vzm + H1_4(col));

                            m = f16mop_adj_pair(m, pcol, 0);
                            *a = bfdotadd_ebf(*a, n, m, &fpst, &fpst_odd);
                        }
                        col += 4;
                        pcol >>= 4;
                    } while (col & 15);
                }
                row += 4;
                prow >>= 4;
            } while (row & 15);
        }
    } else {
        for (row = 0; row < oprsz; ) {
            uint16_t prow = pn[H2(row >> 4)];
            do {
                void *vza_row = vza + tile_vslice_offset(row);
                uint32_t n = *(uint32_t *)(vzn + H1_4(row));

                if (ah_neg) {
                    n = bf16mop_ah_neg_adj_pair(n, prow);
                } else {
                    n = f16mop_adj_pair(n, prow, negx);
                }

                for (col = 0; col < oprsz; ) {
                    uint16_t pcol = pm[H2(col >> 4)];
                    do {
                        if (prow & pcol & 0b0101) {
                            uint32_t *a = vza_row + H1_4(col);
                            uint32_t m = *(uint32_t *)(vzm + H1_4(col));

                            m = f16mop_adj_pair(m, pcol, 0);
                            *a = bfdotadd(*a, n, m, &fpst);
                        }
                        col += 4;
                        pcol >>= 4;
                    } while (col & 15);
                }
                row += 4;
                prow >>= 4;
            } while (row & 15);
        }
    }
}

void HELPER(sme_bfmopa_w)(void *vza, void *vzn, void *vzm, void *vpn,
                          void *vpm, CPUARMState *env, uint32_t desc)
{
    do_bfmopa_w(vza, vzn, vzm, vpn, vpm, env, desc, 0, false);
}

void HELPER(sme_bfmops_w)(void *vza, void *vzn, void *vzm, void *vpn,
                          void *vpm, CPUARMState *env, uint32_t desc)
{
    do_bfmopa_w(vza, vzn, vzm, vpn, vpm, env, desc, 0x80008000u, false);
}

void HELPER(sme_ah_bfmops_w)(void *vza, void *vzn, void *vzm, void *vpn,
                             void *vpm, CPUARMState *env, uint32_t desc)
{
    do_bfmopa_w(vza, vzn, vzm, vpn, vpm, env, desc, 0, true);
}

typedef uint32_t IMOPFn32(uint32_t, uint32_t, uint32_t, uint8_t, bool);
static inline void do_imopa_s(uint32_t *za, uint32_t *zn, uint32_t *zm,
                              uint8_t *pn, uint8_t *pm,
                              uint32_t desc, IMOPFn32 *fn)
{
    intptr_t row, col, oprsz = simd_oprsz(desc) / 4;
    bool neg = simd_data(desc);

    for (row = 0; row < oprsz; ++row) {
        uint8_t pa = (pn[H1(row >> 1)] >> ((row & 1) * 4)) & 0xf;
        uint32_t *za_row = &za[tile_vslice_index(row)];
        uint32_t n = zn[H4(row)];

        for (col = 0; col < oprsz; ++col) {
            uint8_t pb = pm[H1(col >> 1)] >> ((col & 1) * 4);
            uint32_t *a = &za_row[H4(col)];

            *a = fn(n, zm[H4(col)], *a, pa & pb, neg);
        }
    }
}

typedef uint64_t IMOPFn64(uint64_t, uint64_t, uint64_t, uint8_t, bool);
static inline void do_imopa_d(uint64_t *za, uint64_t *zn, uint64_t *zm,
                              uint8_t *pn, uint8_t *pm,
                              uint32_t desc, IMOPFn64 *fn)
{
    intptr_t row, col, oprsz = simd_oprsz(desc) / 8;
    bool neg = simd_data(desc);

    for (row = 0; row < oprsz; ++row) {
        uint8_t pa = pn[H1(row)];
        uint64_t *za_row = &za[tile_vslice_index(row)];
        uint64_t n = zn[row];

        for (col = 0; col < oprsz; ++col) {
            uint8_t pb = pm[H1(col)];
            uint64_t *a = &za_row[col];

            *a = fn(n, zm[col], *a, pa & pb, neg);
        }
    }
}

#define DEF_IMOP_8x4_32(NAME, NTYPE, MTYPE) \
static uint32_t NAME(uint32_t n, uint32_t m, uint32_t a, uint8_t p, bool neg) \
{                                                                             \
    uint32_t sum = 0;                                                         \
    /* Apply P to N as a mask, making the inactive elements 0. */             \
    n &= expand_pred_b(p);                                                    \
    sum += (NTYPE)(n >> 0) * (MTYPE)(m >> 0);                                 \
    sum += (NTYPE)(n >> 8) * (MTYPE)(m >> 8);                                 \
    sum += (NTYPE)(n >> 16) * (MTYPE)(m >> 16);                               \
    sum += (NTYPE)(n >> 24) * (MTYPE)(m >> 24);                               \
    return neg ? a - sum : a + sum;                                           \
}

#define DEF_IMOP_16x4_64(NAME, NTYPE, MTYPE) \
static uint64_t NAME(uint64_t n, uint64_t m, uint64_t a, uint8_t p, bool neg) \
{                                                                             \
    uint64_t sum = 0;                                                         \
    /* Apply P to N as a mask, making the inactive elements 0. */             \
    n &= expand_pred_h(p);                                                    \
    sum += (int64_t)(NTYPE)(n >> 0) * (MTYPE)(m >> 0);                        \
    sum += (int64_t)(NTYPE)(n >> 16) * (MTYPE)(m >> 16);                      \
    sum += (int64_t)(NTYPE)(n >> 32) * (MTYPE)(m >> 32);                      \
    sum += (int64_t)(NTYPE)(n >> 48) * (MTYPE)(m >> 48);                      \
    return neg ? a - sum : a + sum;                                           \
}
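
/*
 * Each 32-bit (resp. 64-bit) ZA element accumulates a 4-way dot product
 * of 8-bit (resp. 16-bit) sub-elements.  The predicate bits covering
 * the sub-elements are expanded to a byte/halfword mask with
 * expand_pred_b/expand_pred_h; masking one operand of each product is
 * enough to make an inactive sub-element contribute zero to the sum.
 */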

DEF_IMOP_8x4_32(smopa_s, int8_t, int8_t)
DEF_IMOP_8x4_32(umopa_s, uint8_t, uint8_t)
DEF_IMOP_8x4_32(sumopa_s, int8_t, uint8_t)
DEF_IMOP_8x4_32(usmopa_s, uint8_t, int8_t)

DEF_IMOP_16x4_64(smopa_d, int16_t, int16_t)
DEF_IMOP_16x4_64(umopa_d, uint16_t, uint16_t)
DEF_IMOP_16x4_64(sumopa_d, int16_t, uint16_t)
DEF_IMOP_16x4_64(usmopa_d, uint16_t, int16_t)

#define DEF_IMOPH(P, NAME, S)                                               \
void HELPER(P##_##NAME##_##S)(void *vza, void *vzn, void *vzm,              \
                              void *vpn, void *vpm, uint32_t desc)          \
{ do_imopa_##S(vza, vzn, vzm, vpn, vpm, desc, NAME##_##S); }

DEF_IMOPH(sme, smopa, s)
DEF_IMOPH(sme, umopa, s)
DEF_IMOPH(sme, sumopa, s)
DEF_IMOPH(sme, usmopa, s)

DEF_IMOPH(sme, smopa, d)
DEF_IMOPH(sme, umopa, d)
DEF_IMOPH(sme, sumopa, d)
DEF_IMOPH(sme, usmopa, d)

static uint32_t bmopa_s(uint32_t n, uint32_t m, uint32_t a, uint8_t p, bool neg)
{
    uint32_t sum = ctpop32(~(n ^ m));
    if (neg) {
        sum = -sum;
    }
    if (!(p & 1)) {
        sum = 0;
    }
    return a + sum;
}

DEF_IMOPH(sme2, bmopa, s)

#define DEF_IMOP_16x2_32(NAME, NTYPE, MTYPE) \
static uint32_t NAME(uint32_t n, uint32_t m, uint32_t a, uint8_t p, bool neg) \
{                                                                             \
    uint32_t sum = 0;                                                         \
    /* Apply P to N as a mask, making the inactive elements 0. */             \
    n &= expand_pred_h(p);                                                    \
    sum += (NTYPE)(n >> 0) * (MTYPE)(m >> 0);                                 \
    sum += (NTYPE)(n >> 16) * (MTYPE)(m >> 16);                               \
    return neg ? a - sum : a + sum;                                           \
}

DEF_IMOP_16x2_32(smopa2_s, int16_t, int16_t)
DEF_IMOP_16x2_32(umopa2_s, uint16_t, uint16_t)

DEF_IMOPH(sme2, smopa2, s)
DEF_IMOPH(sme2, umopa2, s)

#define DO_VDOT_IDX(NAME, TYPED, TYPEN, TYPEM, HD, HN)                      \
void HELPER(NAME)(void *vd, void *vn, void *vm, uint32_t desc)              \
{                                                                           \
    intptr_t svl = simd_oprsz(desc);                                        \
    intptr_t elements = svl / sizeof(TYPED);                                \
    intptr_t eltperseg = 16 / sizeof(TYPED);                                \
    intptr_t nreg = sizeof(TYPED) / sizeof(TYPEN);                          \
    intptr_t vstride = (svl / nreg) * sizeof(ARMVectorReg);                 \
    intptr_t zstride = sizeof(ARMVectorReg) / sizeof(TYPEN);                \
    intptr_t idx = extract32(desc, SIMD_DATA_SHIFT, 2);                     \
    TYPEN *n = vn;                                                          \
    TYPEM *m = vm;                                                          \
    for (intptr_t r = 0; r < nreg; r++) {                                   \
        TYPED *d = vd + r * vstride;                                        \
        for (intptr_t seg = 0; seg < elements; seg += eltperseg) {          \
            intptr_t s = seg + idx;                                         \
            for (intptr_t e = seg; e < seg + eltperseg; e++) {              \
                TYPED sum = d[HD(e)];                                       \
                for (intptr_t i = 0; i < nreg; i++) {                       \
                    TYPED nn = n[i * zstride + HN(nreg * e + r)];           \
                    TYPED mm = m[HN(nreg * s + i)];                         \
                    sum += nn * mm;                                         \
                }                                                           \
                d[HD(e)] = sum;                                             \
            }                                                               \
        }                                                                   \
    }                                                                       \
}

DO_VDOT_IDX(sme2_svdot_idx_4b, int32_t, int8_t, int8_t, H4, H1)
DO_VDOT_IDX(sme2_uvdot_idx_4b, uint32_t, uint8_t, uint8_t, H4, H1)
DO_VDOT_IDX(sme2_suvdot_idx_4b, int32_t, int8_t, uint8_t, H4, H1)
DO_VDOT_IDX(sme2_usvdot_idx_4b, int32_t, uint8_t, int8_t, H4, H1)

DO_VDOT_IDX(sme2_svdot_idx_4h, int64_t, int16_t, int16_t, H8, H2)
DO_VDOT_IDX(sme2_uvdot_idx_4h, uint64_t, uint16_t, uint16_t, H8, H2)

DO_VDOT_IDX(sme2_svdot_idx_2h, int32_t, int16_t, int16_t, H4, H2)
DO_VDOT_IDX(sme2_uvdot_idx_2h, uint32_t, uint16_t, uint16_t, H4, H2)

#undef DO_VDOT_IDX
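
/*
 * The MLALL helpers below widen by a factor of four: each wide
 * accumulator element i adds (or subtracts) the product of the narrow
 * elements at index 4*i + sel, where sel selects which of the four
 * narrow sub-elements within the wide lane takes part.
 */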

#define DO_MLALL(NAME, TYPEW, TYPEN, TYPEM, HW, HN, OP)                     \
void HELPER(NAME)(void *vd, void *vn, void *vm, void *va, uint32_t desc)    \
{                                                                           \
    intptr_t elements = simd_oprsz(desc) / sizeof(TYPEW);                   \
    intptr_t sel = extract32(desc, SIMD_DATA_SHIFT, 2);                     \
    TYPEW *d = vd, *a = va; TYPEN *n = vn; TYPEM *m = vm;                   \
    for (intptr_t i = 0; i < elements; ++i) {                               \
        TYPEW nn = n[HN(i * 4 + sel)];                                      \
        TYPEM mm = m[HN(i * 4 + sel)];                                      \
        d[HW(i)] = a[HW(i)] OP (nn * mm);                                   \
    }                                                                       \
}

DO_MLALL(sme2_smlall_s, int32_t, int8_t, int8_t, H4, H1, +)
DO_MLALL(sme2_smlall_d, int64_t, int16_t, int16_t, H8, H2, +)
DO_MLALL(sme2_smlsll_s, int32_t, int8_t, int8_t, H4, H1, -)
DO_MLALL(sme2_smlsll_d, int64_t, int16_t, int16_t, H8, H2, -)

DO_MLALL(sme2_umlall_s, uint32_t, uint8_t, uint8_t, H4, H1, +)
DO_MLALL(sme2_umlall_d, uint64_t, uint16_t, uint16_t, H8, H2, +)
DO_MLALL(sme2_umlsll_s, uint32_t, uint8_t, uint8_t, H4, H1, -)
DO_MLALL(sme2_umlsll_d, uint64_t, uint16_t, uint16_t, H8, H2, -)

DO_MLALL(sme2_usmlall_s, uint32_t, uint8_t, int8_t, H4, H1, +)

#undef DO_MLALL

#define DO_MLALL_IDX(NAME, TYPEW, TYPEN, TYPEM, HW, HN, OP)                 \
void HELPER(NAME)(void *vd, void *vn, void *vm, void *va, uint32_t desc)    \
{                                                                           \
    intptr_t elements = simd_oprsz(desc) / sizeof(TYPEW);                   \
    intptr_t eltspersegment = 16 / sizeof(TYPEW);                           \
    intptr_t sel = extract32(desc, SIMD_DATA_SHIFT, 2);                     \
    intptr_t idx = extract32(desc, SIMD_DATA_SHIFT + 2, 4);                 \
    TYPEW *d = vd, *a = va; TYPEN *n = vn; TYPEM *m = vm;                   \
    for (intptr_t i = 0; i < elements; i += eltspersegment) {               \
        TYPEW mm = m[HN(i * 4 + idx)];                                      \
        for (intptr_t j = 0; j < eltspersegment; ++j) {                     \
            TYPEN nn = n[HN((i + j) * 4 + sel)];                            \
            d[HW(i + j)] = a[HW(i + j)] OP (nn * mm);                       \
        }                                                                   \
    }                                                                       \
}

DO_MLALL_IDX(sme2_smlall_idx_s, int32_t, int8_t, int8_t, H4, H1, +)
DO_MLALL_IDX(sme2_smlall_idx_d, int64_t, int16_t, int16_t, H8, H2, +)
DO_MLALL_IDX(sme2_smlsll_idx_s, int32_t, int8_t, int8_t, H4, H1, -)
DO_MLALL_IDX(sme2_smlsll_idx_d, int64_t, int16_t, int16_t, H8, H2, -)

DO_MLALL_IDX(sme2_umlall_idx_s, uint32_t, uint8_t, uint8_t, H4, H1, +)
DO_MLALL_IDX(sme2_umlall_idx_d, uint64_t, uint16_t, uint16_t, H8, H2, +)
DO_MLALL_IDX(sme2_umlsll_idx_s, uint32_t, uint8_t, uint8_t, H4, H1, -)
DO_MLALL_IDX(sme2_umlsll_idx_d, uint64_t, uint16_t, uint16_t, H8, H2, -)

DO_MLALL_IDX(sme2_usmlall_idx_s, uint32_t, uint8_t, int8_t, H4, H1, +)
DO_MLALL_IDX(sme2_sumlall_idx_s, uint32_t, int8_t, uint8_t, H4, H1, +)

#undef DO_MLALL_IDX

/* Convert and compress */
void HELPER(sme2_bfcvt)(void *vd, void *vs, float_status *fpst, uint32_t desc)
{
    ARMVectorReg scratch;
    size_t oprsz = simd_oprsz(desc);
    size_t i, n = oprsz / 4;
    float32 *s0 = vs;
    float32 *s1 = vs + sizeof(ARMVectorReg);
    bfloat16 *d = vd;

    if (vd == s1) {
        s1 = memcpy(&scratch, s1, oprsz);
    }

    for (i = 0; i < n; ++i) {
        d[H2(i)] = float32_to_bfloat16(s0[H4(i)], fpst);
    }
    for (i = 0; i < n; ++i) {
        d[H2(i) + n] = float32_to_bfloat16(s1[H4(i)], fpst);
    }
}

void HELPER(sme2_fcvt_n)(void *vd, void *vs, float_status *fpst, uint32_t desc)
{
    ARMVectorReg scratch;
    size_t oprsz = simd_oprsz(desc);
    size_t i, n = oprsz / 4;
    float32 *s0 = vs;
    float32 *s1 = vs + sizeof(ARMVectorReg);
    float16 *d = vd;

    if (vd == s1) {
        s1 = memcpy(&scratch, s1, oprsz);
    }

    for (i = 0; i < n; ++i) {
        d[H2(i)] = sve_f32_to_f16(s0[H4(i)], fpst);
    }
    for (i = 0; i < n; ++i) {
        d[H2(i) + n] = sve_f32_to_f16(s1[H4(i)], fpst);
    }
}
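
/*
 * The saturating-convert macros below follow the same pattern as the
 * two functions above: 2 or 4 consecutive source vectors are narrowed
 * and packed into disjoint chunks of the single destination, copying
 * through a scratch buffer when the destination overlaps a source.
 */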

#define SQCVT2(NAME, TW, TN, HW, HN, SAT)                                   \
void HELPER(NAME)(void *vd, void *vs, uint32_t desc)                        \
{                                                                           \
    ARMVectorReg scratch;                                                   \
    size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW);                \
    TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg);                           \
    TN *d = vd;                                                             \
    if (vectors_overlap(vd, 1, vs, 2)) {                                    \
        d = (TN *)&scratch;                                                 \
    }                                                                       \
    for (size_t i = 0; i < n; ++i) {                                        \
        d[HN(i)] = SAT(s0[HW(i)]);                                          \
        d[HN(i + n)] = SAT(s1[HW(i)]);                                      \
    }                                                                       \
    if (d != vd) {                                                          \
        memcpy(vd, d, oprsz);                                               \
    }                                                                       \
}

SQCVT2(sme2_sqcvt_sh, int32_t, int16_t, H4, H2, do_ssat_h)
SQCVT2(sme2_uqcvt_sh, uint32_t, uint16_t, H4, H2, do_usat_h)
SQCVT2(sme2_sqcvtu_sh, int32_t, uint16_t, H4, H2, do_usat_h)

#undef SQCVT2

#define SQCVT4(NAME, TW, TN, HW, HN, SAT)                                   \
void HELPER(NAME)(void *vd, void *vs, uint32_t desc)                        \
{                                                                           \
    ARMVectorReg scratch;                                                   \
    size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW);                \
    TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg);                           \
    TW *s2 = vs + 2 * sizeof(ARMVectorReg);                                 \
    TW *s3 = vs + 3 * sizeof(ARMVectorReg);                                 \
    TN *d = vd;                                                             \
    if (vectors_overlap(vd, 1, vs, 4)) {                                    \
        d = (TN *)&scratch;                                                 \
    }                                                                       \
    for (size_t i = 0; i < n; ++i) {                                        \
        d[HN(i)] = SAT(s0[HW(i)]);                                          \
        d[HN(i + n)] = SAT(s1[HW(i)]);                                      \
        d[HN(i + 2 * n)] = SAT(s2[HW(i)]);                                  \
        d[HN(i + 3 * n)] = SAT(s3[HW(i)]);                                  \
    }                                                                       \
    if (d != vd) {                                                          \
        memcpy(vd, d, oprsz);                                               \
    }                                                                       \
}

SQCVT4(sme2_sqcvt_sb, int32_t, int8_t, H4, H2, do_ssat_b)
SQCVT4(sme2_uqcvt_sb, uint32_t, uint8_t, H4, H2, do_usat_b)
SQCVT4(sme2_sqcvtu_sb, int32_t, uint8_t, H4, H2, do_usat_b)

SQCVT4(sme2_sqcvt_dh, int64_t, int16_t, H8, H2, do_ssat_h)
SQCVT4(sme2_uqcvt_dh, uint64_t, uint16_t, H8, H2, do_usat_h)
SQCVT4(sme2_sqcvtu_dh, int64_t, uint16_t, H8, H2, do_usat_h)

#undef SQCVT4

#define SQRSHR2(NAME, TW, TN, HW, HN, RSHR, SAT)                            \
void HELPER(NAME)(void *vd, void *vs, uint32_t desc)                        \
{                                                                           \
    ARMVectorReg scratch;                                                   \
    size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW);                \
    int shift = simd_data(desc);                                            \
    TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg);                           \
    TN *d = vd;                                                             \
    if (vectors_overlap(vd, 1, vs, 2)) {                                    \
        d = (TN *)&scratch;                                                 \
    }                                                                       \
    for (size_t i = 0; i < n; ++i) {                                        \
        d[HN(i)] = SAT(RSHR(s0[HW(i)], shift));                             \
        d[HN(i + n)] = SAT(RSHR(s1[HW(i)], shift));                         \
    }                                                                       \
    if (d != vd) {                                                          \
        memcpy(vd, d, oprsz);                                               \
    }                                                                       \
}

SQRSHR2(sme2_sqrshr_sh, int32_t, int16_t, H4, H2, do_srshr, do_ssat_h)
SQRSHR2(sme2_uqrshr_sh, uint32_t, uint16_t, H4, H2, do_urshr, do_usat_h)
SQRSHR2(sme2_sqrshru_sh, int32_t, uint16_t, H4, H2, do_srshr, do_usat_h)

#undef SQRSHR2

#define SQRSHR4(NAME, TW, TN, HW, HN, RSHR, SAT)                            \
void HELPER(NAME)(void *vd, void *vs, uint32_t desc)                        \
{                                                                           \
    ARMVectorReg scratch;                                                   \
    size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW);                \
    int shift = simd_data(desc);                                            \
    TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg);                           \
    TW *s2 = vs + 2 * sizeof(ARMVectorReg);                                 \
    TW *s3 = vs + 3 * sizeof(ARMVectorReg);                                 \
    TN *d = vd;                                                             \
    if (vectors_overlap(vd, 1, vs, 4)) {                                    \
        d = (TN *)&scratch;                                                 \
    }                                                                       \
    for (size_t i = 0; i < n; ++i) {                                        \
        d[HN(i)] = SAT(RSHR(s0[HW(i)], shift));                             \
        d[HN(i + n)] = SAT(RSHR(s1[HW(i)], shift));                         \
        d[HN(i + 2 * n)] = SAT(RSHR(s2[HW(i)], shift));                     \
        d[HN(i + 3 * n)] = SAT(RSHR(s3[HW(i)], shift));                     \
    }                                                                       \
    if (d != vd) {                                                          \
        memcpy(vd, d, oprsz);                                               \
    }                                                                       \
}

SQRSHR4(sme2_sqrshr_sb, int32_t, int8_t, H4, H2, do_srshr, do_ssat_b)
SQRSHR4(sme2_uqrshr_sb, uint32_t, uint8_t, H4, H2, do_urshr, do_usat_b)
SQRSHR4(sme2_sqrshru_sb, int32_t, uint8_t, H4, H2, do_srshr, do_usat_b)

SQRSHR4(sme2_sqrshr_dh, int64_t, int16_t, H8, H2, do_srshr, do_ssat_h)
SQRSHR4(sme2_uqrshr_dh, uint64_t, uint16_t, H8, H2, do_urshr, do_usat_h)
SQRSHR4(sme2_sqrshru_dh, int64_t, uint16_t, H8, H2, do_srshr, do_usat_h)

#undef SQRSHR4
#define SQRSHR2(NAME, TW, TN, HW, HN, RSHR, SAT)                        \
void HELPER(NAME)(void *vd, void *vs, uint32_t desc)                    \
{                                                                       \
    ARMVectorReg scratch;                                               \
    size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW);            \
    int shift = simd_data(desc);                                        \
    TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg);                       \
    TN *d = vd;                                                         \
    if (vectors_overlap(vd, 1, vs, 2)) {                                \
        d = (TN *)&scratch;                                             \
    }                                                                   \
    for (size_t i = 0; i < n; ++i) {                                    \
        d[HN(i)] = SAT(RSHR(s0[HW(i)], shift));                         \
        d[HN(i + n)] = SAT(RSHR(s1[HW(i)], shift));                     \
    }                                                                   \
    if (d != vd) {                                                      \
        memcpy(vd, d, oprsz);                                           \
    }                                                                   \
}

SQRSHR2(sme2_sqrshr_sh, int32_t, int16_t, H4, H2, do_srshr, do_ssat_h)
SQRSHR2(sme2_uqrshr_sh, uint32_t, uint16_t, H4, H2, do_urshr, do_usat_h)
SQRSHR2(sme2_sqrshru_sh, int32_t, uint16_t, H4, H2, do_srshr, do_usat_h)

#undef SQRSHR2

#define SQRSHR4(NAME, TW, TN, HW, HN, RSHR, SAT)                        \
void HELPER(NAME)(void *vd, void *vs, uint32_t desc)                    \
{                                                                       \
    ARMVectorReg scratch;                                               \
    size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW);            \
    int shift = simd_data(desc);                                        \
    TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg);                       \
    TW *s2 = vs + 2 * sizeof(ARMVectorReg);                             \
    TW *s3 = vs + 3 * sizeof(ARMVectorReg);                             \
    TN *d = vd;                                                         \
    if (vectors_overlap(vd, 1, vs, 4)) {                                \
        d = (TN *)&scratch;                                             \
    }                                                                   \
    for (size_t i = 0; i < n; ++i) {                                    \
        d[HN(i)] = SAT(RSHR(s0[HW(i)], shift));                         \
        d[HN(i + n)] = SAT(RSHR(s1[HW(i)], shift));                     \
        d[HN(i + 2 * n)] = SAT(RSHR(s2[HW(i)], shift));                 \
        d[HN(i + 3 * n)] = SAT(RSHR(s3[HW(i)], shift));                 \
    }                                                                   \
    if (d != vd) {                                                      \
        memcpy(vd, d, oprsz);                                           \
    }                                                                   \
}

SQRSHR4(sme2_sqrshr_sb, int32_t, int8_t, H4, H1, do_srshr, do_ssat_b)
SQRSHR4(sme2_uqrshr_sb, uint32_t, uint8_t, H4, H1, do_urshr, do_usat_b)
SQRSHR4(sme2_sqrshru_sb, int32_t, uint8_t, H4, H1, do_srshr, do_usat_b)

SQRSHR4(sme2_sqrshr_dh, int64_t, int16_t, H8, H2, do_srshr, do_ssat_h)
SQRSHR4(sme2_uqrshr_dh, uint64_t, uint16_t, H8, H2, do_urshr, do_usat_h)
SQRSHR4(sme2_sqrshru_dh, int64_t, uint16_t, H8, H2, do_srshr, do_usat_h)

#undef SQRSHR4

/* Convert and interleave */
void HELPER(sme2_bfcvtn)(void *vd, void *vs, float_status *fpst, uint32_t desc)
{
    size_t i, n = simd_oprsz(desc) / 4;
    float32 *s0 = vs;
    float32 *s1 = vs + sizeof(ARMVectorReg);
    bfloat16 *d = vd;

    for (i = 0; i < n; ++i) {
        bfloat16 d0 = float32_to_bfloat16(s0[H4(i)], fpst);
        bfloat16 d1 = float32_to_bfloat16(s1[H4(i)], fpst);
        d[H2(i * 2 + 0)] = d0;
        d[H2(i * 2 + 1)] = d1;
    }
}

void HELPER(sme2_fcvtn)(void *vd, void *vs, float_status *fpst, uint32_t desc)
{
    size_t i, n = simd_oprsz(desc) / 4;
    float32 *s0 = vs;
    float32 *s1 = vs + sizeof(ARMVectorReg);
    float16 *d = vd;

    for (i = 0; i < n; ++i) {
        float16 d0 = sve_f32_to_f16(s0[H4(i)], fpst);
        float16 d1 = sve_f32_to_f16(s1[H4(i)], fpst);
        d[H2(i * 2 + 0)] = d0;
        d[H2(i * 2 + 1)] = d1;
    }
}
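/*
 * The ...N (narrow-and-interleave) forms below differ from
 * SQCVT2/SQCVT4 and SQRSHR2/SQRSHR4 only in the output layout:
 * instead of concatenating the per-register results they interleave
 * them, matching sme2_bfcvtn/sme2_fcvtn above.  For two sources:
 *     d = { f(s0[0]), f(s1[0]), f(s0[1]), f(s1[1]), ... }
 * and with four sources the result from register k lands in
 * element 4*i + k.
 */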
#define SQCVTN2(NAME, TW, TN, HW, HN, SAT)                              \
void HELPER(NAME)(void *vd, void *vs, uint32_t desc)                    \
{                                                                       \
    ARMVectorReg scratch;                                               \
    size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW);            \
    TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg);                       \
    TN *d = vd;                                                         \
    if (vectors_overlap(vd, 1, vs, 2)) {                                \
        d = (TN *)&scratch;                                             \
    }                                                                   \
    for (size_t i = 0; i < n; ++i) {                                    \
        d[HN(2 * i + 0)] = SAT(s0[HW(i)]);                              \
        d[HN(2 * i + 1)] = SAT(s1[HW(i)]);                              \
    }                                                                   \
    if (d != vd) {                                                      \
        memcpy(vd, d, oprsz);                                           \
    }                                                                   \
}

SQCVTN2(sme2_sqcvtn_sh, int32_t, int16_t, H4, H2, do_ssat_h)
SQCVTN2(sme2_uqcvtn_sh, uint32_t, uint16_t, H4, H2, do_usat_h)
SQCVTN2(sme2_sqcvtun_sh, int32_t, uint16_t, H4, H2, do_usat_h)

#undef SQCVTN2

#define SQCVTN4(NAME, TW, TN, HW, HN, SAT)                              \
void HELPER(NAME)(void *vd, void *vs, uint32_t desc)                    \
{                                                                       \
    ARMVectorReg scratch;                                               \
    size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW);            \
    TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg);                       \
    TW *s2 = vs + 2 * sizeof(ARMVectorReg);                             \
    TW *s3 = vs + 3 * sizeof(ARMVectorReg);                             \
    TN *d = vd;                                                         \
    if (vectors_overlap(vd, 1, vs, 4)) {                                \
        d = (TN *)&scratch;                                             \
    }                                                                   \
    for (size_t i = 0; i < n; ++i) {                                    \
        d[HN(4 * i + 0)] = SAT(s0[HW(i)]);                              \
        d[HN(4 * i + 1)] = SAT(s1[HW(i)]);                              \
        d[HN(4 * i + 2)] = SAT(s2[HW(i)]);                              \
        d[HN(4 * i + 3)] = SAT(s3[HW(i)]);                              \
    }                                                                   \
    if (d != vd) {                                                      \
        memcpy(vd, d, oprsz);                                           \
    }                                                                   \
}

SQCVTN4(sme2_sqcvtn_sb, int32_t, int8_t, H4, H1, do_ssat_b)
SQCVTN4(sme2_uqcvtn_sb, uint32_t, uint8_t, H4, H1, do_usat_b)
SQCVTN4(sme2_sqcvtun_sb, int32_t, uint8_t, H4, H1, do_usat_b)

SQCVTN4(sme2_sqcvtn_dh, int64_t, int16_t, H8, H2, do_ssat_h)
SQCVTN4(sme2_uqcvtn_dh, uint64_t, uint16_t, H8, H2, do_usat_h)
SQCVTN4(sme2_sqcvtun_dh, int64_t, uint16_t, H8, H2, do_usat_h)

#undef SQCVTN4

#define SQRSHRN2(NAME, TW, TN, HW, HN, RSHR, SAT)                       \
void HELPER(NAME)(void *vd, void *vs, uint32_t desc)                    \
{                                                                       \
    ARMVectorReg scratch;                                               \
    size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW);            \
    int shift = simd_data(desc);                                        \
    TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg);                       \
    TN *d = vd;                                                         \
    if (vectors_overlap(vd, 1, vs, 2)) {                                \
        d = (TN *)&scratch;                                             \
    }                                                                   \
    for (size_t i = 0; i < n; ++i) {                                    \
        d[HN(2 * i + 0)] = SAT(RSHR(s0[HW(i)], shift));                 \
        d[HN(2 * i + 1)] = SAT(RSHR(s1[HW(i)], shift));                 \
    }                                                                   \
    if (d != vd) {                                                      \
        memcpy(vd, d, oprsz);                                           \
    }                                                                   \
}

SQRSHRN2(sme2_sqrshrn_sh, int32_t, int16_t, H4, H2, do_srshr, do_ssat_h)
SQRSHRN2(sme2_uqrshrn_sh, uint32_t, uint16_t, H4, H2, do_urshr, do_usat_h)
SQRSHRN2(sme2_sqrshrun_sh, int32_t, uint16_t, H4, H2, do_srshr, do_usat_h)

#undef SQRSHRN2

#define SQRSHRN4(NAME, TW, TN, HW, HN, RSHR, SAT)                       \
void HELPER(NAME)(void *vd, void *vs, uint32_t desc)                    \
{                                                                       \
    ARMVectorReg scratch;                                               \
    size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW);            \
    int shift = simd_data(desc);                                        \
    TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg);                       \
    TW *s2 = vs + 2 * sizeof(ARMVectorReg);                             \
    TW *s3 = vs + 3 * sizeof(ARMVectorReg);                             \
    TN *d = vd;                                                         \
    if (vectors_overlap(vd, 1, vs, 4)) {                                \
        d = (TN *)&scratch;                                             \
    }                                                                   \
    for (size_t i = 0; i < n; ++i) {                                    \
        d[HN(4 * i + 0)] = SAT(RSHR(s0[HW(i)], shift));                 \
        d[HN(4 * i + 1)] = SAT(RSHR(s1[HW(i)], shift));                 \
        d[HN(4 * i + 2)] = SAT(RSHR(s2[HW(i)], shift));                 \
        d[HN(4 * i + 3)] = SAT(RSHR(s3[HW(i)], shift));                 \
    }                                                                   \
    if (d != vd) {                                                      \
        memcpy(vd, d, oprsz);                                           \
    }                                                                   \
}

SQRSHRN4(sme2_sqrshrn_sb, int32_t, int8_t, H4, H1, do_srshr, do_ssat_b)
SQRSHRN4(sme2_uqrshrn_sb, uint32_t, uint8_t, H4, H1, do_urshr, do_usat_b)
SQRSHRN4(sme2_sqrshrun_sb, int32_t, uint8_t, H4, H1, do_srshr, do_usat_b)

SQRSHRN4(sme2_sqrshrn_dh, int64_t, int16_t, H8, H2, do_srshr, do_ssat_h)
SQRSHRN4(sme2_uqrshrn_dh, uint64_t, uint16_t, H8, H2, do_urshr, do_usat_h)
SQRSHRN4(sme2_sqrshrun_dh, int64_t, uint16_t, H8, H2, do_srshr, do_usat_h)

#undef SQRSHRN4

/* Expand and convert */
void HELPER(sme2_fcvt_w)(void *vd, void *vs, float_status *fpst, uint32_t desc)
{
    ARMVectorReg scratch;
    size_t oprsz = simd_oprsz(desc);
    size_t i, n = oprsz / 4;
    float16 *s = vs;
    float32 *d0 = vd;
    float32 *d1 = vd + sizeof(ARMVectorReg);

    if (vectors_overlap(vd, 2, vs, 1)) {
        s = memcpy(&scratch, s, oprsz);
    }

    for (i = 0; i < n; ++i) {
        d0[H4(i)] = sve_f16_to_f32(s[H2(i)], fpst);
    }
    for (i = 0; i < n; ++i) {
        d1[H4(i)] = sve_f16_to_f32(s[H2(n + i)], fpst);
    }
}
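/*
 * Unpack (sign- or zero-extend) SREG source registers into 2 * SREG
 * destination registers: the low half of source register r widens
 * into destination register 2*r, the high half into register 2*r + 1.
 */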
#define UNPK(NAME, SREG, TW, TN, HW, HN)                                \
void HELPER(NAME)(void *vd, void *vs, uint32_t desc)                    \
{                                                                       \
    ARMVectorReg scratch[SREG];                                         \
    size_t oprsz = simd_oprsz(desc);                                    \
    size_t n = oprsz / sizeof(TW);                                      \
    if (vectors_overlap(vd, 2 * SREG, vs, SREG)) {                      \
        vs = memcpy(scratch, vs, sizeof(scratch));                      \
    }                                                                   \
    for (size_t r = 0; r < SREG; ++r) {                                 \
        TN *s = vs + r * sizeof(ARMVectorReg);                          \
        for (size_t i = 0; i < 2; ++i) {                                \
            TW *d = vd + (2 * r + i) * sizeof(ARMVectorReg);            \
            for (size_t e = 0; e < n; ++e) {                            \
                d[HW(e)] = s[HN(i * n + e)];                            \
            }                                                           \
        }                                                               \
    }                                                                   \
}

UNPK(sme2_sunpk2_bh, 1, int16_t, int8_t, H2, H1)
UNPK(sme2_sunpk2_hs, 1, int32_t, int16_t, H4, H2)
UNPK(sme2_sunpk2_sd, 1, int64_t, int32_t, H8, H4)

UNPK(sme2_sunpk4_bh, 2, int16_t, int8_t, H2, H1)
UNPK(sme2_sunpk4_hs, 2, int32_t, int16_t, H4, H2)
UNPK(sme2_sunpk4_sd, 2, int64_t, int32_t, H8, H4)

UNPK(sme2_uunpk2_bh, 1, uint16_t, uint8_t, H2, H1)
UNPK(sme2_uunpk2_hs, 1, uint32_t, uint16_t, H4, H2)
UNPK(sme2_uunpk2_sd, 1, uint64_t, uint32_t, H8, H4)

UNPK(sme2_uunpk4_bh, 2, uint16_t, uint8_t, H2, H1)
UNPK(sme2_uunpk4_hs, 2, uint32_t, uint16_t, H4, H2)
UNPK(sme2_uunpk4_sd, 2, uint64_t, uint32_t, H8, H4)

#undef UNPK

/* Deinterleave and convert. */
void HELPER(sme2_fcvtl)(void *vd, void *vs, float_status *fpst, uint32_t desc)
{
    size_t i, n = simd_oprsz(desc) / 4;
    float16 *s = vs;
    float32 *d0 = vd;
    float32 *d1 = vd + sizeof(ARMVectorReg);

    for (i = 0; i < n; ++i) {
        float32 v0 = sve_f16_to_f32(s[H2(i * 2 + 0)], fpst);
        float32 v1 = sve_f16_to_f32(s[H2(i * 2 + 1)], fpst);
        d0[H4(i)] = v0;
        d1[H4(i)] = v1;
    }
}

void HELPER(sme2_scvtf)(void *vd, void *vs, float_status *fpst, uint32_t desc)
{
    size_t i, n = simd_oprsz(desc) / 4;
    int32_t *s = vs;
    float32 *d = vd;

    for (i = 0; i < n; ++i) {
        d[i] = int32_to_float32(s[i], fpst);
    }
}

void HELPER(sme2_ucvtf)(void *vd, void *vs, float_status *fpst, uint32_t desc)
{
    size_t i, n = simd_oprsz(desc) / 4;
    uint32_t *s = vs;
    float32 *d = vd;

    for (i = 0; i < n; ++i) {
        d[i] = uint32_to_float32(s[i], fpst);
    }
}
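/*
 * ZIP: interleave the elements of the sources across the destination
 * group.  For ZIP2, the first halves of Zn and Zm fill destination
 * register 0 and the second halves fill register 1, e.g.
 *     d0 = { n[0], m[0], n[1], m[1], ... }
 * ZIP4 does the same with four source and four destination registers.
 */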
#define ZIP2(NAME, TYPE, H)                                             \
void HELPER(NAME)(void *vd, void *vn, void *vm, uint32_t desc)          \
{                                                                       \
    ARMVectorReg scratch[2];                                            \
    size_t oprsz = simd_oprsz(desc);                                    \
    size_t pairs = oprsz / (sizeof(TYPE) * 2);                          \
    TYPE *n = vn, *m = vm;                                              \
    if (vectors_overlap(vd, 2, vn, 1)) {                                \
        n = memcpy(&scratch[0], vn, oprsz);                             \
    }                                                                   \
    if (vectors_overlap(vd, 2, vm, 1)) {                                \
        m = memcpy(&scratch[1], vm, oprsz);                             \
    }                                                                   \
    for (size_t r = 0; r < 2; ++r) {                                    \
        TYPE *d = vd + r * sizeof(ARMVectorReg);                        \
        size_t base = r * pairs;                                        \
        for (size_t p = 0; p < pairs; ++p) {                            \
            d[H(2 * p + 0)] = n[base + H(p)];                           \
            d[H(2 * p + 1)] = m[base + H(p)];                           \
        }                                                               \
    }                                                                   \
}

ZIP2(sme2_zip2_b, uint8_t, H1)
ZIP2(sme2_zip2_h, uint16_t, H2)
ZIP2(sme2_zip2_s, uint32_t, H4)
ZIP2(sme2_zip2_d, uint64_t, )
ZIP2(sme2_zip2_q, Int128, )

#undef ZIP2

#define ZIP4(NAME, TYPE, H)                                             \
void HELPER(NAME)(void *vd, void *vs, uint32_t desc)                    \
{                                                                       \
    ARMVectorReg scratch[4];                                            \
    size_t oprsz = simd_oprsz(desc);                                    \
    size_t quads = oprsz / (sizeof(TYPE) * 4);                          \
    TYPE *s0, *s1, *s2, *s3;                                            \
    if (vs == vd) {                                                     \
        vs = memcpy(scratch, vs, sizeof(scratch));                      \
    }                                                                   \
    s0 = vs;                                                            \
    s1 = vs + sizeof(ARMVectorReg);                                     \
    s2 = vs + 2 * sizeof(ARMVectorReg);                                 \
    s3 = vs + 3 * sizeof(ARMVectorReg);                                 \
    for (size_t r = 0; r < 4; ++r) {                                    \
        TYPE *d = vd + r * sizeof(ARMVectorReg);                        \
        size_t base = r * quads;                                        \
        for (size_t q = 0; q < quads; ++q) {                            \
            d[H(4 * q + 0)] = s0[base + H(q)];                          \
            d[H(4 * q + 1)] = s1[base + H(q)];                          \
            d[H(4 * q + 2)] = s2[base + H(q)];                          \
            d[H(4 * q + 3)] = s3[base + H(q)];                          \
        }                                                               \
    }                                                                   \
}

ZIP4(sme2_zip4_b, uint8_t, H1)
ZIP4(sme2_zip4_h, uint16_t, H2)
ZIP4(sme2_zip4_s, uint32_t, H4)
ZIP4(sme2_zip4_d, uint64_t, )
ZIP4(sme2_zip4_q, Int128, )

#undef ZIP4
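/*
 * UZP is the inverse permutation of ZIP: with the sources viewed as
 * one long vector, even-numbered elements are gathered into the first
 * destination register and odd-numbered elements into the second
 * (UZP2), or elements at positions 4*q + k into destination register
 * k (UZP4).
 */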
#define UZP2(NAME, TYPE, H)                                             \
void HELPER(NAME)(void *vd, void *vn, void *vm, uint32_t desc)          \
{                                                                       \
    ARMVectorReg scratch[2];                                            \
    size_t oprsz = simd_oprsz(desc);                                    \
    size_t pairs = oprsz / (sizeof(TYPE) * 2);                          \
    TYPE *d0 = vd, *d1 = vd + sizeof(ARMVectorReg);                     \
    if (vectors_overlap(vd, 2, vn, 1)) {                                \
        vn = memcpy(&scratch[0], vn, oprsz);                            \
    }                                                                   \
    if (vectors_overlap(vd, 2, vm, 1)) {                                \
        vm = memcpy(&scratch[1], vm, oprsz);                            \
    }                                                                   \
    for (size_t r = 0; r < 2; ++r) {                                    \
        TYPE *s = r ? vm : vn;                                          \
        size_t base = r * pairs;                                        \
        for (size_t p = 0; p < pairs; ++p) {                            \
            d0[base + H(p)] = s[H(2 * p + 0)];                          \
            d1[base + H(p)] = s[H(2 * p + 1)];                          \
        }                                                               \
    }                                                                   \
}

UZP2(sme2_uzp2_b, uint8_t, H1)
UZP2(sme2_uzp2_h, uint16_t, H2)
UZP2(sme2_uzp2_s, uint32_t, H4)
UZP2(sme2_uzp2_d, uint64_t, )
UZP2(sme2_uzp2_q, Int128, )

#undef UZP2

#define UZP4(NAME, TYPE, H)                                             \
void HELPER(NAME)(void *vd, void *vs, uint32_t desc)                    \
{                                                                       \
    ARMVectorReg scratch[4];                                            \
    size_t oprsz = simd_oprsz(desc);                                    \
    size_t quads = oprsz / (sizeof(TYPE) * 4);                          \
    TYPE *d0, *d1, *d2, *d3;                                            \
    if (vs == vd) {                                                     \
        vs = memcpy(scratch, vs, sizeof(scratch));                      \
    }                                                                   \
    d0 = vd;                                                            \
    d1 = vd + sizeof(ARMVectorReg);                                     \
    d2 = vd + 2 * sizeof(ARMVectorReg);                                 \
    d3 = vd + 3 * sizeof(ARMVectorReg);                                 \
    for (size_t r = 0; r < 4; ++r) {                                    \
        TYPE *s = vs + r * sizeof(ARMVectorReg);                        \
        size_t base = r * quads;                                        \
        for (size_t q = 0; q < quads; ++q) {                            \
            d0[base + H(q)] = s[H(4 * q + 0)];                          \
            d1[base + H(q)] = s[H(4 * q + 1)];                          \
            d2[base + H(q)] = s[H(4 * q + 2)];                          \
            d3[base + H(q)] = s[H(4 * q + 3)];                          \
        }                                                               \
    }                                                                   \
}

UZP4(sme2_uzp4_b, uint8_t, H1)
UZP4(sme2_uzp4_h, uint16_t, H2)
UZP4(sme2_uzp4_s, uint32_t, H4)
UZP4(sme2_uzp4_d, uint64_t, )
UZP4(sme2_uzp4_q, Int128, )

#undef UZP4
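/*
 * Clamp every element of the nreg-register destination group to the
 * inclusive range [n, m]: d = MIN(MAX(d, n), m), with signed or
 * unsigned comparisons according to the element type.
 */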
#define ICLAMP(NAME, TYPE, H)                                           \
void HELPER(NAME)(void *vd, void *vn, void *vm, uint32_t desc)          \
{                                                                       \
    size_t stride = sizeof(ARMVectorReg) / sizeof(TYPE);                \
    size_t elements = simd_oprsz(desc) / sizeof(TYPE);                  \
    size_t nreg = simd_data(desc);                                      \
    TYPE *d = vd, *n = vn, *m = vm;                                     \
    for (size_t e = 0; e < elements; e++) {                             \
        TYPE nn = n[H(e)], mm = m[H(e)];                                \
        for (size_t r = 0; r < nreg; r++) {                             \
            TYPE *dd = &d[r * stride + H(e)];                           \
            *dd = MIN(MAX(*dd, nn), mm);                                \
        }                                                               \
    }                                                                   \
}

ICLAMP(sme2_sclamp_b, int8_t, H1)
ICLAMP(sme2_sclamp_h, int16_t, H2)
ICLAMP(sme2_sclamp_s, int32_t, H4)
ICLAMP(sme2_sclamp_d, int64_t, H8)

ICLAMP(sme2_uclamp_b, uint8_t, H1)
ICLAMP(sme2_uclamp_h, uint16_t, H2)
ICLAMP(sme2_uclamp_s, uint32_t, H4)
ICLAMP(sme2_uclamp_d, uint64_t, H8)

#undef ICLAMP

/*
 * Note that the argument ordering to minnum and maxnum must match
 * the ARM pseudocode so that NaNs are propagated properly.
 */
#define FCLAMP(NAME, TYPE, H)                                           \
void HELPER(NAME)(void *vd, void *vn, void *vm,                         \
                  float_status *fpst, uint32_t desc)                    \
{                                                                       \
    size_t stride = sizeof(ARMVectorReg) / sizeof(TYPE);                \
    size_t elements = simd_oprsz(desc) / sizeof(TYPE);                  \
    size_t nreg = simd_data(desc);                                      \
    TYPE *d = vd, *n = vn, *m = vm;                                     \
    for (size_t e = 0; e < elements; e++) {                             \
        TYPE nn = n[H(e)], mm = m[H(e)];                                \
        for (size_t r = 0; r < nreg; r++) {                             \
            TYPE *dd = &d[r * stride + H(e)];                           \
            *dd = TYPE##_minnum(TYPE##_maxnum(nn, *dd, fpst), mm, fpst); \
        }                                                               \
    }                                                                   \
}

FCLAMP(sme2_fclamp_h, float16, H2)
FCLAMP(sme2_fclamp_s, float32, H4)
FCLAMP(sme2_fclamp_d, float64, H8)
FCLAMP(sme2_bfclamp, bfloat16, H2)

#undef FCLAMP
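/*
 * SEL with a predicate-as-counter.  As used below, decode_counter()
 * yields the number of leading elements for which the predicate is
 * true (false when p.invert is set); within each destination register
 * the code copies from Zn on the true side of that split point and
 * from Zm on the false side.  A non-zero p.lg2_stride means only every
 * (1 << p.lg2_stride)-th element within the true region comes from Zn,
 * the remaining lanes coming from Zm.
 */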
void HELPER(sme2_sel_b)(void *vd, void *vn, void *vm,
                        uint32_t png, uint32_t desc)
{
    int vl = simd_oprsz(desc);
    int nreg = simd_data(desc);
    int elements = vl / sizeof(uint8_t);
    DecodeCounter p = decode_counter(png, vl, MO_8);

    if (p.lg2_stride == 0) {
        if (p.invert) {
            for (int r = 0; r < nreg; r++) {
                uint8_t *d = vd + r * sizeof(ARMVectorReg);
                uint8_t *n = vn + r * sizeof(ARMVectorReg);
                uint8_t *m = vm + r * sizeof(ARMVectorReg);
                int split = p.count - r * elements;

                if (split <= 0) {
                    memcpy(d, n, vl); /* all true */
                } else if (elements <= split) {
                    memcpy(d, m, vl); /* all false */
                } else {
                    for (int e = 0; e < split; e++) {
                        d[H1(e)] = m[H1(e)];
                    }
                    for (int e = split; e < elements; e++) {
                        d[H1(e)] = n[H1(e)];
                    }
                }
            }
        } else {
            for (int r = 0; r < nreg; r++) {
                uint8_t *d = vd + r * sizeof(ARMVectorReg);
                uint8_t *n = vn + r * sizeof(ARMVectorReg);
                uint8_t *m = vm + r * sizeof(ARMVectorReg);
                int split = p.count - r * elements;

                if (split <= 0) {
                    memcpy(d, m, vl); /* all false */
                } else if (elements <= split) {
                    memcpy(d, n, vl); /* all true */
                } else {
                    for (int e = 0; e < split; e++) {
                        d[H1(e)] = n[H1(e)];
                    }
                    for (int e = split; e < elements; e++) {
                        d[H1(e)] = m[H1(e)];
                    }
                }
            }
        }
    } else {
        int estride = 1 << p.lg2_stride;
        if (p.invert) {
            for (int r = 0; r < nreg; r++) {
                uint8_t *d = vd + r * sizeof(ARMVectorReg);
                uint8_t *n = vn + r * sizeof(ARMVectorReg);
                uint8_t *m = vm + r * sizeof(ARMVectorReg);
                int split = p.count - r * elements;
                int e = 0;

                for (; e < MIN(split, elements); e++) {
                    d[H1(e)] = m[H1(e)];
                }
                for (; e < elements; e += estride) {
                    d[H1(e)] = n[H1(e)];
                    for (int i = 1; i < estride; i++) {
                        d[H1(e + i)] = m[H1(e + i)];
                    }
                }
            }
        } else {
            for (int r = 0; r < nreg; r++) {
                uint8_t *d = vd + r * sizeof(ARMVectorReg);
                uint8_t *n = vn + r * sizeof(ARMVectorReg);
                uint8_t *m = vm + r * sizeof(ARMVectorReg);
                int split = p.count - r * elements;
                int e = 0;

                for (; e < MIN(split, elements); e += estride) {
                    d[H1(e)] = n[H1(e)];
                    for (int i = 1; i < estride; i++) {
                        d[H1(e + i)] = m[H1(e + i)];
                    }
                }
                for (; e < elements; e++) {
                    d[H1(e)] = m[H1(e)];
                }
            }
        }
    }
}

void HELPER(sme2_sel_h)(void *vd, void *vn, void *vm,
                        uint32_t png, uint32_t desc)
{
    int vl = simd_oprsz(desc);
    int nreg = simd_data(desc);
    int elements = vl / sizeof(uint16_t);
    DecodeCounter p = decode_counter(png, vl, MO_16);

    if (p.lg2_stride == 0) {
        if (p.invert) {
            for (int r = 0; r < nreg; r++) {
                uint16_t *d = vd + r * sizeof(ARMVectorReg);
                uint16_t *n = vn + r * sizeof(ARMVectorReg);
                uint16_t *m = vm + r * sizeof(ARMVectorReg);
                int split = p.count - r * elements;

                if (split <= 0) {
                    memcpy(d, n, vl); /* all true */
                } else if (elements <= split) {
                    memcpy(d, m, vl); /* all false */
                } else {
                    for (int e = 0; e < split; e++) {
                        d[H2(e)] = m[H2(e)];
                    }
                    for (int e = split; e < elements; e++) {
                        d[H2(e)] = n[H2(e)];
                    }
                }
            }
        } else {
            for (int r = 0; r < nreg; r++) {
                uint16_t *d = vd + r * sizeof(ARMVectorReg);
                uint16_t *n = vn + r * sizeof(ARMVectorReg);
                uint16_t *m = vm + r * sizeof(ARMVectorReg);
                int split = p.count - r * elements;

                if (split <= 0) {
                    memcpy(d, m, vl); /* all false */
                } else if (elements <= split) {
                    memcpy(d, n, vl); /* all true */
                } else {
                    for (int e = 0; e < split; e++) {
                        d[H2(e)] = n[H2(e)];
                    }
                    for (int e = split; e < elements; e++) {
                        d[H2(e)] = m[H2(e)];
                    }
                }
            }
        }
    } else {
        int estride = 1 << p.lg2_stride;
        if (p.invert) {
            for (int r = 0; r < nreg; r++) {
                uint16_t *d = vd + r * sizeof(ARMVectorReg);
                uint16_t *n = vn + r * sizeof(ARMVectorReg);
                uint16_t *m = vm + r * sizeof(ARMVectorReg);
                int split = p.count - r * elements;
                int e = 0;

                for (; e < MIN(split, elements); e++) {
                    d[H2(e)] = m[H2(e)];
                }
                for (; e < elements; e += estride) {
                    d[H2(e)] = n[H2(e)];
                    for (int i = 1; i < estride; i++) {
                        d[H2(e + i)] = m[H2(e + i)];
                    }
                }
            }
        } else {
            for (int r = 0; r < nreg; r++) {
                uint16_t *d = vd + r * sizeof(ARMVectorReg);
                uint16_t *n = vn + r * sizeof(ARMVectorReg);
                uint16_t *m = vm + r * sizeof(ARMVectorReg);
                int split = p.count - r * elements;
                int e = 0;

                for (; e < MIN(split, elements); e += estride) {
                    d[H2(e)] = n[H2(e)];
                    for (int i = 1; i < estride; i++) {
                        d[H2(e + i)] = m[H2(e + i)];
                    }
                }
                for (; e < elements; e++) {
                    d[H2(e)] = m[H2(e)];
                }
            }
        }
    }
}

void HELPER(sme2_sel_s)(void *vd, void *vn, void *vm,
                        uint32_t png, uint32_t desc)
{
    int vl = simd_oprsz(desc);
    int nreg = simd_data(desc);
    int elements = vl / sizeof(uint32_t);
    DecodeCounter p = decode_counter(png, vl, MO_32);

    if (p.lg2_stride == 0) {
        if (p.invert) {
            for (int r = 0; r < nreg; r++) {
                uint32_t *d = vd + r * sizeof(ARMVectorReg);
                uint32_t *n = vn + r * sizeof(ARMVectorReg);
                uint32_t *m = vm + r * sizeof(ARMVectorReg);
                int split = p.count - r * elements;

                if (split <= 0) {
                    memcpy(d, n, vl); /* all true */
                } else if (elements <= split) {
                    memcpy(d, m, vl); /* all false */
                } else {
                    for (int e = 0; e < split; e++) {
                        d[H4(e)] = m[H4(e)];
                    }
                    for (int e = split; e < elements; e++) {
                        d[H4(e)] = n[H4(e)];
                    }
                }
            }
        } else {
            for (int r = 0; r < nreg; r++) {
                uint32_t *d = vd + r * sizeof(ARMVectorReg);
                uint32_t *n = vn + r * sizeof(ARMVectorReg);
                uint32_t *m = vm + r * sizeof(ARMVectorReg);
                int split = p.count - r * elements;

                if (split <= 0) {
                    memcpy(d, m, vl); /* all false */
                } else if (elements <= split) {
                    memcpy(d, n, vl); /* all true */
                } else {
                    for (int e = 0; e < split; e++) {
                        d[H4(e)] = n[H4(e)];
                    }
                    for (int e = split; e < elements; e++) {
                        d[H4(e)] = m[H4(e)];
                    }
                }
            }
        }
    } else {
        /* p.esz must be MO_64, so stride must be 2. */
        if (p.invert) {
            for (int r = 0; r < nreg; r++) {
                uint32_t *d = vd + r * sizeof(ARMVectorReg);
                uint32_t *n = vn + r * sizeof(ARMVectorReg);
                uint32_t *m = vm + r * sizeof(ARMVectorReg);
                int split = p.count - r * elements;
                int e = 0;

                for (; e < MIN(split, elements); e++) {
                    d[H4(e)] = m[H4(e)];
                }
                for (; e < elements; e += 2) {
                    d[H4(e)] = n[H4(e)];
                    d[H4(e + 1)] = m[H4(e + 1)];
                }
            }
        } else {
            for (int r = 0; r < nreg; r++) {
                uint32_t *d = vd + r * sizeof(ARMVectorReg);
                uint32_t *n = vn + r * sizeof(ARMVectorReg);
                uint32_t *m = vm + r * sizeof(ARMVectorReg);
                int split = p.count - r * elements;
                int e = 0;

                for (; e < MIN(split, elements); e += 2) {
                    d[H4(e)] = n[H4(e)];
                    d[H4(e + 1)] = m[H4(e + 1)];
                }
                for (; e < elements; e++) {
                    d[H4(e)] = m[H4(e)];
                }
            }
        }
    }
}

void HELPER(sme2_sel_d)(void *vd, void *vn, void *vm,
                        uint32_t png, uint32_t desc)
{
    int vl = simd_oprsz(desc);
    int nreg = simd_data(desc);
    int elements = vl / sizeof(uint64_t);
    DecodeCounter p = decode_counter(png, vl, MO_64);

    if (p.invert) {
        for (int r = 0; r < nreg; r++) {
            uint64_t *d = vd + r * sizeof(ARMVectorReg);
            uint64_t *n = vn + r * sizeof(ARMVectorReg);
            uint64_t *m = vm + r * sizeof(ARMVectorReg);
            int split = p.count - r * elements;

            if (split <= 0) {
                memcpy(d, n, vl); /* all true */
            } else if (elements <= split) {
                memcpy(d, m, vl); /* all false */
            } else {
                memcpy(d, m, split * sizeof(uint64_t));
                memcpy(d + split, n + split,
                       (elements - split) * sizeof(uint64_t));
            }
        }
    } else {
        for (int r = 0; r < nreg; r++) {
            uint64_t *d = vd + r * sizeof(ARMVectorReg);
            uint64_t *n = vn + r * sizeof(ARMVectorReg);
            uint64_t *m = vm + r * sizeof(ARMVectorReg);
            int split = p.count - r * elements;

            if (split <= 0) {
                memcpy(d, m, vl); /* all false */
            } else if (elements <= split) {
                memcpy(d, n, vl); /* all true */
            } else {
                memcpy(d, n, split * sizeof(uint64_t));
                memcpy(d + split, m + split,
                       (elements - split) * sizeof(uint64_t));
            }
        }
    }
}