xref: /openbmc/qemu/target/arm/tcg/sme_helper.c (revision c017386f28c03a03b8f14444f8671d3d8f7180fe)
1 /*
2  * ARM SME Operations
3  *
4  * Copyright (c) 2022 Linaro, Ltd.
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this library; if not, see <http://www.gnu.org/licenses/>.
18  */
19 
20 #include "qemu/osdep.h"
21 #include "cpu.h"
22 #include "internals.h"
23 #include "tcg/tcg-gvec-desc.h"
24 #include "exec/helper-proto.h"
25 #include "accel/tcg/cpu-ldst.h"
26 #include "accel/tcg/helper-retaddr.h"
27 #include "qemu/int128.h"
28 #include "fpu/softfloat.h"
29 #include "vec_internal.h"
30 #include "sve_ldst_internal.h"
31 
32 
33 static bool vectors_overlap(ARMVectorReg *x, unsigned nx,
34                             ARMVectorReg *y, unsigned ny)
35 {
36     return !(x + nx <= y || y + ny <= x);
37 }
38 
39 void helper_set_svcr(CPUARMState *env, uint32_t val, uint32_t mask)
40 {
41     aarch64_set_svcr(env, val, mask);
42 }
43 
44 void helper_sme_zero(CPUARMState *env, uint32_t imm, uint32_t svl)
45 {
46     uint32_t i;
47 
48     /*
49      * Special case clearing the entire ZArray.
50      * This falls into the CONSTRAINED UNPREDICTABLE zeroing of any
51      * parts of the ZA storage outside of SVL.
52      */
53     if (imm == 0xff) {
54         memset(env->za_state.za, 0, sizeof(env->za_state.za));
55         return;
56     }
57 
58     /*
59      * Recall that ZAnH.D[m] is spread across ZA[n+8*m],
60      * so each row is discontiguous within ZA[].
61      */
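    /*
     * Illustrative example (not from the architecture text): with SVL = 256
     * bits, svl is 32 and tile ZA1.D occupies ZA[] rows 1, 9, 17 and 25, so
     * the loop below clears row i whenever bit (i % 8) of IMM is set, which
     * zeroes exactly the interleaved rows of the selected tiles.
     */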
62     for (i = 0; i < svl; i++) {
63         if (imm & (1 << (i % 8))) {
64             memset(&env->za_state.za[i], 0, svl);
65         }
66     }
67 }
68 
69 
70 /*
71  * When considering the ZA storage as an array of elements of
72  * type T, the index within that array of the Nth element of
73  * a vertical slice of a tile can be calculated like this,
74  * regardless of the size of type T. This is because the tiles
75  * are interleaved, so if type T is size N bytes then row 1 of
76  * the tile is N rows away from row 0. The division by N to
77  * convert a byte offset into an array index and the multiplication
78  * by N to convert from vslice-index-within-the-tile to
79  * the index within the ZA storage cancel out.
80  */
81 #define tile_vslice_index(i) ((i) * sizeof(ARMVectorReg))
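
/*
 * Worked example (illustrative, assuming sizeof(ARMVectorReg) is 256 bytes):
 * with 4-byte elements, consecutive tile rows are 4 storage rows apart and a
 * storage row holds 256 / 4 == 64 elements, so element i of a vertical slice
 * sits at array index i * 4 * 64 == i * 256 == i * sizeof(ARMVectorReg),
 * independent of the element size.
 */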
82 
83 /*
84  * When doing byte arithmetic on the ZA storage, the element
85  * byteoff bytes away in a tile vertical slice is always this
86  * many bytes away in the ZA storage, regardless of the
87  * size of the tile element, assuming that byteoff is a multiple
88  * of the element size. Again this is because of the interleaving
89  * of the tiles. For instance if we have 1 byte per element then
90  * each row of the ZA storage has one byte of the vslice data,
91  * and (counting from 0) byte 8 goes in row 8 of the storage
92  * at offset (8 * row-size-in-bytes).
93  * If we have 8 bytes per element then each row of the ZA storage
94  * has 8 bytes of the data, but there are 8 interleaved tiles and
95  * so byte 8 of the data goes into row 1 of the tile,
96  * which is again row 8 of the storage, so the offset is still
97  * (8 * row-size-in-bytes). Similarly for other element sizes.
98  */
99 #define tile_vslice_offset(byteoff) ((byteoff) * sizeof(ARMVectorReg))
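
/*
 * Worked example (illustrative): for a .d tile (8-byte elements) the element
 * 16 bytes into a vertical slice is element 2 of that slice, hence tile row 2,
 * which is storage row 16; its offset from the slice's base address is
 * therefore 16 * sizeof(ARMVectorReg), i.e. tile_vslice_offset(16).
 */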
100 
101 
102 /*
103  * Move Zreg vector to ZArray column.
104  */
105 #define DO_MOVA_C(NAME, TYPE, H)                                        \
106 void HELPER(NAME)(void *za, void *vn, void *vg, uint32_t desc)          \
107 {                                                                       \
108     int i, oprsz = simd_oprsz(desc);                                    \
109     for (i = 0; i < oprsz; ) {                                          \
110         uint16_t pg = *(uint16_t *)(vg + H1_2(i >> 3));                 \
111         do {                                                            \
112             if (pg & 1) {                                               \
113                 *(TYPE *)(za + tile_vslice_offset(i)) = *(TYPE *)(vn + H(i)); \
114             }                                                           \
115             i += sizeof(TYPE);                                          \
116             pg >>= sizeof(TYPE);                                        \
117         } while (i & 15);                                               \
118     }                                                                   \
119 }
120 
121 DO_MOVA_C(sme_mova_cz_b, uint8_t, H1)
122 DO_MOVA_C(sme_mova_cz_h, uint16_t, H1_2)
123 DO_MOVA_C(sme_mova_cz_s, uint32_t, H1_4)
124 
125 void HELPER(sme_mova_cz_d)(void *za, void *vn, void *vg, uint32_t desc)
126 {
127     int i, oprsz = simd_oprsz(desc) / 8;
128     uint8_t *pg = vg;
129     uint64_t *n = vn;
130     uint64_t *a = za;
131 
132     for (i = 0; i < oprsz; i++) {
133         if (pg[H1(i)] & 1) {
134             a[tile_vslice_index(i)] = n[i];
135         }
136     }
137 }
138 
139 void HELPER(sme_mova_cz_q)(void *za, void *vn, void *vg, uint32_t desc)
140 {
141     int i, oprsz = simd_oprsz(desc) / 16;
142     uint16_t *pg = vg;
143     Int128 *n = vn;
144     Int128 *a = za;
145 
146     /*
147      * Int128 is used here simply to copy 16 bytes, and to simplify
148      * the address arithmetic.
149      */
150     for (i = 0; i < oprsz; i++) {
151         if (pg[H2(i)] & 1) {
152             a[tile_vslice_index(i)] = n[i];
153         }
154     }
155 }
156 
157 #undef DO_MOVA_C
158 
159 /*
160  * Move ZArray column to Zreg vector.
161  */
162 #define DO_MOVA_Z(NAME, TYPE, H)                                        \
163 void HELPER(NAME)(void *vd, void *za, void *vg, uint32_t desc)          \
164 {                                                                       \
165     int i, oprsz = simd_oprsz(desc);                                    \
166     for (i = 0; i < oprsz; ) {                                          \
167         uint16_t pg = *(uint16_t *)(vg + H1_2(i >> 3));                 \
168         do {                                                            \
169             if (pg & 1) {                                               \
170                 *(TYPE *)(vd + H(i)) = *(TYPE *)(za + tile_vslice_offset(i)); \
171             }                                                           \
172             i += sizeof(TYPE);                                          \
173             pg >>= sizeof(TYPE);                                        \
174         } while (i & 15);                                               \
175     }                                                                   \
176 }
177 
178 DO_MOVA_Z(sme_mova_zc_b, uint8_t, H1)
179 DO_MOVA_Z(sme_mova_zc_h, uint16_t, H1_2)
180 DO_MOVA_Z(sme_mova_zc_s, uint32_t, H1_4)
181 
182 void HELPER(sme_mova_zc_d)(void *vd, void *za, void *vg, uint32_t desc)
183 {
184     int i, oprsz = simd_oprsz(desc) / 8;
185     uint8_t *pg = vg;
186     uint64_t *d = vd;
187     uint64_t *a = za;
188 
189     for (i = 0; i < oprsz; i++) {
190         if (pg[H1(i)] & 1) {
191             d[i] = a[tile_vslice_index(i)];
192         }
193     }
194 }
195 
196 void HELPER(sme_mova_zc_q)(void *vd, void *za, void *vg, uint32_t desc)
197 {
198     int i, oprsz = simd_oprsz(desc) / 16;
199     uint16_t *pg = vg;
200     Int128 *d = vd;
201     Int128 *a = za;
202 
203     /*
204      * Int128 is used here simply to copy 16 bytes, and to simplify
205      * the address arithmetic.
206      */
207     for (i = 0; i < oprsz; i++) {
208         if (pg[H2(i)] & 1) {
209             d[i] = a[tile_vslice_index(i)];
210         }
211     }
212 }
213 
214 #undef DO_MOVA_Z
215 
216 void HELPER(sme2_mova_zc_b)(void *vdst, void *vsrc, uint32_t desc)
217 {
218     const uint8_t *src = vsrc;
219     uint8_t *dst = vdst;
220     size_t i, n = simd_oprsz(desc);
221 
222     for (i = 0; i < n; ++i) {
223         dst[i] = src[tile_vslice_index(i)];
224     }
225 }
226 
227 void HELPER(sme2_mova_zc_h)(void *vdst, void *vsrc, uint32_t desc)
228 {
229     const uint16_t *src = vsrc;
230     uint16_t *dst = vdst;
231     size_t i, n = simd_oprsz(desc) / 2;
232 
233     for (i = 0; i < n; ++i) {
234         dst[i] = src[tile_vslice_index(i)];
235     }
236 }
237 
238 void HELPER(sme2_mova_zc_s)(void *vdst, void *vsrc, uint32_t desc)
239 {
240     const uint32_t *src = vsrc;
241     uint32_t *dst = vdst;
242     size_t i, n = simd_oprsz(desc) / 4;
243 
244     for (i = 0; i < n; ++i) {
245         dst[i] = src[tile_vslice_index(i)];
246     }
247 }
248 
249 void HELPER(sme2_mova_zc_d)(void *vdst, void *vsrc, uint32_t desc)
250 {
251     const uint64_t *src = vsrc;
252     uint64_t *dst = vdst;
253     size_t i, n = simd_oprsz(desc) / 8;
254 
255     for (i = 0; i < n; ++i) {
256         dst[i] = src[tile_vslice_index(i)];
257     }
258 }
259 
260 void HELPER(sme2p1_movaz_zc_b)(void *vdst, void *vsrc, uint32_t desc)
261 {
262     uint8_t *src = vsrc;
263     uint8_t *dst = vdst;
264     size_t i, n = simd_oprsz(desc);
265 
266     for (i = 0; i < n; ++i) {
267         dst[i] = src[tile_vslice_index(i)];
268         src[tile_vslice_index(i)] = 0;
269     }
270 }
271 
272 void HELPER(sme2p1_movaz_zc_h)(void *vdst, void *vsrc, uint32_t desc)
273 {
274     uint16_t *src = vsrc;
275     uint16_t *dst = vdst;
276     size_t i, n = simd_oprsz(desc) / 2;
277 
278     for (i = 0; i < n; ++i) {
279         dst[i] = src[tile_vslice_index(i)];
280         src[tile_vslice_index(i)] = 0;
281     }
282 }
283 
284 void HELPER(sme2p1_movaz_zc_s)(void *vdst, void *vsrc, uint32_t desc)
285 {
286     uint32_t *src = vsrc;
287     uint32_t *dst = vdst;
288     size_t i, n = simd_oprsz(desc) / 4;
289 
290     for (i = 0; i < n; ++i) {
291         dst[i] = src[tile_vslice_index(i)];
292         src[tile_vslice_index(i)] = 0;
293     }
294 }
295 
296 void HELPER(sme2p1_movaz_zc_d)(void *vdst, void *vsrc, uint32_t desc)
297 {
298     uint64_t *src = vsrc;
299     uint64_t *dst = vdst;
300     size_t i, n = simd_oprsz(desc) / 8;
301 
302     for (i = 0; i < n; ++i) {
303         dst[i] = src[tile_vslice_index(i)];
304         src[tile_vslice_index(i)] = 0;
305     }
306 }
307 
308 void HELPER(sme2p1_movaz_zc_q)(void *vdst, void *vsrc, uint32_t desc)
309 {
310     Int128 *src = vsrc;
311     Int128 *dst = vdst;
312     size_t i, n = simd_oprsz(desc) / 16;
313 
314     for (i = 0; i < n; ++i) {
315         dst[i] = src[tile_vslice_index(i)];
316         memset(&src[tile_vslice_index(i)], 0, 16);
317     }
318 }
319 
320 /*
321  * Clear elements in a tile slice comprising len bytes.
322  */
323 
324 typedef void ClearFn(void *ptr, size_t off, size_t len);
325 
326 static void clear_horizontal(void *ptr, size_t off, size_t len)
327 {
328     memset(ptr + off, 0, len);
329 }
330 
331 static void clear_vertical_b(void *vptr, size_t off, size_t len)
332 {
333     for (size_t i = 0; i < len; ++i) {
334         *(uint8_t *)(vptr + tile_vslice_offset(i + off)) = 0;
335     }
336 }
337 
338 static void clear_vertical_h(void *vptr, size_t off, size_t len)
339 {
340     for (size_t i = 0; i < len; i += 2) {
341         *(uint16_t *)(vptr + tile_vslice_offset(i + off)) = 0;
342     }
343 }
344 
345 static void clear_vertical_s(void *vptr, size_t off, size_t len)
346 {
347     for (size_t i = 0; i < len; i += 4) {
348         *(uint32_t *)(vptr + tile_vslice_offset(i + off)) = 0;
349     }
350 }
351 
352 static void clear_vertical_d(void *vptr, size_t off, size_t len)
353 {
354     for (size_t i = 0; i < len; i += 8) {
355         *(uint64_t *)(vptr + tile_vslice_offset(i + off)) = 0;
356     }
357 }
358 
359 static void clear_vertical_q(void *vptr, size_t off, size_t len)
360 {
361     for (size_t i = 0; i < len; i += 16) {
362         memset(vptr + tile_vslice_offset(i + off), 0, 16);
363     }
364 }
365 
366 /*
367  * Copy elements from an array into a tile slice comprising len bytes.
368  */
369 
370 typedef void CopyFn(void *dst, const void *src, size_t len);
371 
372 static void copy_horizontal(void *dst, const void *src, size_t len)
373 {
374     memcpy(dst, src, len);
375 }
376 
377 static void copy_vertical_b(void *vdst, const void *vsrc, size_t len)
378 {
379     const uint8_t *src = vsrc;
380     uint8_t *dst = vdst;
381     size_t i;
382 
383     for (i = 0; i < len; ++i) {
384         dst[tile_vslice_index(i)] = src[i];
385     }
386 }
387 
388 static void copy_vertical_h(void *vdst, const void *vsrc, size_t len)
389 {
390     const uint16_t *src = vsrc;
391     uint16_t *dst = vdst;
392     size_t i;
393 
394     for (i = 0; i < len / 2; ++i) {
395         dst[tile_vslice_index(i)] = src[i];
396     }
397 }
398 
399 static void copy_vertical_s(void *vdst, const void *vsrc, size_t len)
400 {
401     const uint32_t *src = vsrc;
402     uint32_t *dst = vdst;
403     size_t i;
404 
405     for (i = 0; i < len / 4; ++i) {
406         dst[tile_vslice_index(i)] = src[i];
407     }
408 }
409 
410 static void copy_vertical_d(void *vdst, const void *vsrc, size_t len)
411 {
412     const uint64_t *src = vsrc;
413     uint64_t *dst = vdst;
414     size_t i;
415 
416     for (i = 0; i < len / 8; ++i) {
417         dst[tile_vslice_index(i)] = src[i];
418     }
419 }
420 
421 static void copy_vertical_q(void *vdst, const void *vsrc, size_t len)
422 {
423     for (size_t i = 0; i < len; i += 16) {
424         memcpy(vdst + tile_vslice_offset(i), vsrc + i, 16);
425     }
426 }
427 
428 void HELPER(sme2_mova_cz_b)(void *vdst, void *vsrc, uint32_t desc)
429 {
430     copy_vertical_b(vdst, vsrc, simd_oprsz(desc));
431 }
432 
433 void HELPER(sme2_mova_cz_h)(void *vdst, void *vsrc, uint32_t desc)
434 {
435     copy_vertical_h(vdst, vsrc, simd_oprsz(desc));
436 }
437 
438 void HELPER(sme2_mova_cz_s)(void *vdst, void *vsrc, uint32_t desc)
439 {
440     copy_vertical_s(vdst, vsrc, simd_oprsz(desc));
441 }
442 
443 void HELPER(sme2_mova_cz_d)(void *vdst, void *vsrc, uint32_t desc)
444 {
445     copy_vertical_d(vdst, vsrc, simd_oprsz(desc));
446 }
447 
448 /*
449  * Host and TLB primitives for vertical tile slice addressing.
450  */
451 
452 #define DO_LD(NAME, TYPE, HOST, TLB)                                        \
453 static inline void sme_##NAME##_v_host(void *za, intptr_t off, void *host)  \
454 {                                                                           \
455     TYPE val = HOST(host);                                                  \
456     *(TYPE *)(za + tile_vslice_offset(off)) = val;                          \
457 }                                                                           \
458 static inline void sme_##NAME##_v_tlb(CPUARMState *env, void *za,           \
459                         intptr_t off, target_ulong addr, uintptr_t ra)      \
460 {                                                                           \
461     TYPE val = TLB(env, useronly_clean_ptr(addr), ra);                      \
462     *(TYPE *)(za + tile_vslice_offset(off)) = val;                          \
463 }
464 
465 #define DO_ST(NAME, TYPE, HOST, TLB)                                        \
466 static inline void sme_##NAME##_v_host(void *za, intptr_t off, void *host)  \
467 {                                                                           \
468     TYPE val = *(TYPE *)(za + tile_vslice_offset(off));                     \
469     HOST(host, val);                                                        \
470 }                                                                           \
471 static inline void sme_##NAME##_v_tlb(CPUARMState *env, void *za,           \
472                         intptr_t off, target_ulong addr, uintptr_t ra)      \
473 {                                                                           \
474     TYPE val = *(TYPE *)(za + tile_vslice_offset(off));                     \
475     TLB(env, useronly_clean_ptr(addr), val, ra);                            \
476 }
477 
478 #define DO_LDQ(HNAME, VNAME) \
479 static inline void VNAME##_v_host(void *za, intptr_t off, void *host)       \
480 {                                                                           \
481     HNAME##_host(za, tile_vslice_offset(off), host);                        \
482 }                                                                           \
483 static inline void VNAME##_v_tlb(CPUARMState *env, void *za, intptr_t off,  \
484                                target_ulong addr, uintptr_t ra)             \
485 {                                                                           \
486     HNAME##_tlb(env, za, tile_vslice_offset(off), addr, ra);                \
487 }
488 
489 #define DO_STQ(HNAME, VNAME) \
490 static inline void VNAME##_v_host(void *za, intptr_t off, void *host)       \
491 {                                                                           \
492     HNAME##_host(za, tile_vslice_offset(off), host);                        \
493 }                                                                           \
494 static inline void VNAME##_v_tlb(CPUARMState *env, void *za, intptr_t off,  \
495                                target_ulong addr, uintptr_t ra)             \
496 {                                                                           \
497     HNAME##_tlb(env, za, tile_vslice_offset(off), addr, ra);                \
498 }
499 
500 DO_LD(ld1b, uint8_t, ldub_p, cpu_ldub_data_ra)
501 DO_LD(ld1h_be, uint16_t, lduw_be_p, cpu_lduw_be_data_ra)
502 DO_LD(ld1h_le, uint16_t, lduw_le_p, cpu_lduw_le_data_ra)
503 DO_LD(ld1s_be, uint32_t, ldl_be_p, cpu_ldl_be_data_ra)
504 DO_LD(ld1s_le, uint32_t, ldl_le_p, cpu_ldl_le_data_ra)
505 DO_LD(ld1d_be, uint64_t, ldq_be_p, cpu_ldq_be_data_ra)
506 DO_LD(ld1d_le, uint64_t, ldq_le_p, cpu_ldq_le_data_ra)
507 
508 DO_LDQ(sve_ld1qq_be, sme_ld1q_be)
509 DO_LDQ(sve_ld1qq_le, sme_ld1q_le)
510 
511 DO_ST(st1b, uint8_t, stb_p, cpu_stb_data_ra)
512 DO_ST(st1h_be, uint16_t, stw_be_p, cpu_stw_be_data_ra)
513 DO_ST(st1h_le, uint16_t, stw_le_p, cpu_stw_le_data_ra)
514 DO_ST(st1s_be, uint32_t, stl_be_p, cpu_stl_be_data_ra)
515 DO_ST(st1s_le, uint32_t, stl_le_p, cpu_stl_le_data_ra)
516 DO_ST(st1d_be, uint64_t, stq_be_p, cpu_stq_be_data_ra)
517 DO_ST(st1d_le, uint64_t, stq_le_p, cpu_stq_le_data_ra)
518 
519 DO_STQ(sve_st1qq_be, sme_st1q_be)
520 DO_STQ(sve_st1qq_le, sme_st1q_le)
521 
522 #undef DO_LD
523 #undef DO_ST
524 #undef DO_LDQ
525 #undef DO_STQ
526 
527 /*
528  * Common helper for all contiguous predicated loads.
529  */
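
/*
 * Rough sketch of the flow (shared with the SVE contiguous loads): find the
 * active elements from the predicate, probe the page(s), check watchpoints
 * and MTE, then perform the load through host pointers when all pages are
 * ordinary RAM, falling back to the TLB path for MMIO pages and for the
 * element that straddles the page boundary.
 */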
530 
531 static inline QEMU_ALWAYS_INLINE
532 void sme_ld1(CPUARMState *env, void *za, uint64_t *vg,
533              const target_ulong addr, uint32_t desc, const uintptr_t ra,
534              const int esz, uint32_t mtedesc, bool vertical,
535              sve_ldst1_host_fn *host_fn,
536              sve_ldst1_tlb_fn *tlb_fn,
537              ClearFn *clr_fn,
538              CopyFn *cpy_fn)
539 {
540     const intptr_t reg_max = simd_oprsz(desc);
541     const intptr_t esize = 1 << esz;
542     intptr_t reg_off, reg_last;
543     SVEContLdSt info;
544     void *host;
545     int flags;
546 
547     /* Find the active elements.  */
548     if (!sve_cont_ldst_elements(&info, addr, vg, reg_max, esz, esize)) {
549         /* The entire predicate was false; no load occurs.  */
550         clr_fn(za, 0, reg_max);
551         return;
552     }
553 
554     /* Probe the page(s).  Exit with exception for any invalid page. */
555     sve_cont_ldst_pages(&info, FAULT_ALL, env, addr, MMU_DATA_LOAD, ra);
556 
557     /* Handle watchpoints for all active elements. */
558     sve_cont_ldst_watchpoints(&info, env, vg, addr, esize, esize,
559                               BP_MEM_READ, ra);
560 
561     /*
562      * Handle mte checks for all active elements.
563      * Since TBI must be set for MTE, !mtedesc => !mte_active.
564      */
565     if (mtedesc) {
566         sve_cont_ldst_mte_check(&info, env, vg, addr, esize, esize,
567                                 mtedesc, ra);
568     }
569 
570     flags = info.page[0].flags | info.page[1].flags;
571     if (unlikely(flags != 0)) {
572 #ifdef CONFIG_USER_ONLY
573         g_assert_not_reached();
574 #else
575         /*
576          * At least one page includes MMIO.
577          * Any bus operation can fail with cpu_transaction_failed,
578          * which for ARM will raise SyncExternal.  Perform the load
579          * into scratch memory to preserve register state until the end.
580          */
581         ARMVectorReg scratch = { };
582 
583         reg_off = info.reg_off_first[0];
584         reg_last = info.reg_off_last[1];
585         if (reg_last < 0) {
586             reg_last = info.reg_off_split;
587             if (reg_last < 0) {
588                 reg_last = info.reg_off_last[0];
589             }
590         }
591 
592         do {
593             uint64_t pg = vg[reg_off >> 6];
594             do {
595                 if ((pg >> (reg_off & 63)) & 1) {
596                     tlb_fn(env, &scratch, reg_off, addr + reg_off, ra);
597                 }
598                 reg_off += esize;
599             } while (reg_off & 63);
600         } while (reg_off <= reg_last);
601 
602         cpy_fn(za, &scratch, reg_max);
603         return;
604 #endif
605     }
606 
607     /* The entire operation is in RAM, on valid pages. */
608 
609     reg_off = info.reg_off_first[0];
610     reg_last = info.reg_off_last[0];
611     host = info.page[0].host;
612 
613     if (!vertical) {
614         memset(za, 0, reg_max);
615     } else if (reg_off) {
616         clr_fn(za, 0, reg_off);
617     }
618 
619     set_helper_retaddr(ra);
620 
621     while (reg_off <= reg_last) {
622         uint64_t pg = vg[reg_off >> 6];
623         do {
624             if ((pg >> (reg_off & 63)) & 1) {
625                 host_fn(za, reg_off, host + reg_off);
626             } else if (vertical) {
627                 clr_fn(za, reg_off, esize);
628             }
629             reg_off += esize;
630         } while (reg_off <= reg_last && (reg_off & 63));
631     }
632 
633     clear_helper_retaddr();
634 
635     /*
636      * Use the slow path to manage the cross-page misalignment.
637      * But we know this is RAM and cannot trap.
638      */
639     reg_off = info.reg_off_split;
640     if (unlikely(reg_off >= 0)) {
641         tlb_fn(env, za, reg_off, addr + reg_off, ra);
642     }
643 
644     reg_off = info.reg_off_first[1];
645     if (unlikely(reg_off >= 0)) {
646         reg_last = info.reg_off_last[1];
647         host = info.page[1].host;
648 
649         set_helper_retaddr(ra);
650 
651         do {
652             uint64_t pg = vg[reg_off >> 6];
653             do {
654                 if ((pg >> (reg_off & 63)) & 1) {
655                     host_fn(za, reg_off, host + reg_off);
656                 } else if (vertical) {
657                     clr_fn(za, reg_off, esize);
658                 }
659                 reg_off += esize;
660             } while (reg_off & 63);
661         } while (reg_off <= reg_last);
662 
663         clear_helper_retaddr();
664     }
665 }
666 
667 static inline QEMU_ALWAYS_INLINE
668 void sme_ld1_mte(CPUARMState *env, void *za, uint64_t *vg,
669                  target_ulong addr, uint64_t desc, uintptr_t ra,
670                  const int esz, bool vertical,
671                  sve_ldst1_host_fn *host_fn,
672                  sve_ldst1_tlb_fn *tlb_fn,
673                  ClearFn *clr_fn,
674                  CopyFn *cpy_fn)
675 {
676     uint32_t mtedesc = desc >> 32;
677     int bit55 = extract64(addr, 55, 1);
678 
679     /* Perform gross MTE suppression early. */
680     if (!tbi_check(mtedesc, bit55) ||
681         tcma_check(mtedesc, bit55, allocation_tag_from_addr(addr))) {
682         mtedesc = 0;
683     }
684 
685     sme_ld1(env, za, vg, addr, desc, ra, esz, mtedesc, vertical,
686             host_fn, tlb_fn, clr_fn, cpy_fn);
687 }
688 
689 #define DO_LD(L, END, ESZ)                                                 \
690 void HELPER(sme_ld1##L##END##_h)(CPUARMState *env, void *za, void *vg,     \
691                                  target_ulong addr, uint64_t desc)         \
692 {                                                                          \
693     sme_ld1(env, za, vg, addr, desc, GETPC(), ESZ, 0, false,               \
694             sve_ld1##L##L##END##_host, sve_ld1##L##L##END##_tlb,           \
695             clear_horizontal, copy_horizontal);                            \
696 }                                                                          \
697 void HELPER(sme_ld1##L##END##_v)(CPUARMState *env, void *za, void *vg,     \
698                                  target_ulong addr, uint64_t desc)         \
699 {                                                                          \
700     sme_ld1(env, za, vg, addr, desc, GETPC(), ESZ, 0, true,                \
701             sme_ld1##L##END##_v_host, sme_ld1##L##END##_v_tlb,             \
702             clear_vertical_##L, copy_vertical_##L);                        \
703 }                                                                          \
704 void HELPER(sme_ld1##L##END##_h_mte)(CPUARMState *env, void *za, void *vg, \
705                                      target_ulong addr, uint64_t desc)     \
706 {                                                                          \
707     sme_ld1_mte(env, za, vg, addr, desc, GETPC(), ESZ, false,              \
708                 sve_ld1##L##L##END##_host, sve_ld1##L##L##END##_tlb,       \
709                 clear_horizontal, copy_horizontal);                        \
710 }                                                                          \
711 void HELPER(sme_ld1##L##END##_v_mte)(CPUARMState *env, void *za, void *vg, \
712                                      target_ulong addr, uint64_t desc)     \
713 {                                                                          \
714     sme_ld1_mte(env, za, vg, addr, desc, GETPC(), ESZ, true,               \
715                 sme_ld1##L##END##_v_host, sme_ld1##L##END##_v_tlb,         \
716                 clear_vertical_##L, copy_vertical_##L);                    \
717 }
718 
719 DO_LD(b, , MO_8)
720 DO_LD(h, _be, MO_16)
721 DO_LD(h, _le, MO_16)
722 DO_LD(s, _be, MO_32)
723 DO_LD(s, _le, MO_32)
724 DO_LD(d, _be, MO_64)
725 DO_LD(d, _le, MO_64)
726 DO_LD(q, _be, MO_128)
727 DO_LD(q, _le, MO_128)
728 
729 #undef DO_LD
730 
731 /*
732  * Common helper for all contiguous predicated stores.
733  */
734 
735 static inline QEMU_ALWAYS_INLINE
736 void sme_st1(CPUARMState *env, void *za, uint64_t *vg,
737              const target_ulong addr, uint32_t desc, const uintptr_t ra,
738              const int esz, uint32_t mtedesc, bool vertical,
739              sve_ldst1_host_fn *host_fn,
740              sve_ldst1_tlb_fn *tlb_fn)
741 {
742     const intptr_t reg_max = simd_oprsz(desc);
743     const intptr_t esize = 1 << esz;
744     intptr_t reg_off, reg_last;
745     SVEContLdSt info;
746     void *host;
747     int flags;
748 
749     /* Find the active elements.  */
750     if (!sve_cont_ldst_elements(&info, addr, vg, reg_max, esz, esize)) {
751         /* The entire predicate was false; no store occurs.  */
752         return;
753     }
754 
755     /* Probe the page(s).  Exit with exception for any invalid page. */
756     sve_cont_ldst_pages(&info, FAULT_ALL, env, addr, MMU_DATA_STORE, ra);
757 
758     /* Handle watchpoints for all active elements. */
759     sve_cont_ldst_watchpoints(&info, env, vg, addr, esize, esize,
760                               BP_MEM_WRITE, ra);
761 
762     /*
763      * Handle mte checks for all active elements.
764      * Since TBI must be set for MTE, !mtedesc => !mte_active.
765      */
766     if (mtedesc) {
767         sve_cont_ldst_mte_check(&info, env, vg, addr, esize, esize,
768                                 mtedesc, ra);
769     }
770 
771     flags = info.page[0].flags | info.page[1].flags;
772     if (unlikely(flags != 0)) {
773 #ifdef CONFIG_USER_ONLY
774         g_assert_not_reached();
775 #else
776         /*
777          * At least one page includes MMIO.
778          * Any bus operation can fail with cpu_transaction_failed,
779          * which for ARM will raise SyncExternal.  We cannot avoid
780          * this fault and will leave with the store incomplete.
781          */
782         reg_off = info.reg_off_first[0];
783         reg_last = info.reg_off_last[1];
784         if (reg_last < 0) {
785             reg_last = info.reg_off_split;
786             if (reg_last < 0) {
787                 reg_last = info.reg_off_last[0];
788             }
789         }
790 
791         do {
792             uint64_t pg = vg[reg_off >> 6];
793             do {
794                 if ((pg >> (reg_off & 63)) & 1) {
795                     tlb_fn(env, za, reg_off, addr + reg_off, ra);
796                 }
797                 reg_off += esize;
798             } while (reg_off & 63);
799         } while (reg_off <= reg_last);
800         return;
801 #endif
802     }
803 
804     reg_off = info.reg_off_first[0];
805     reg_last = info.reg_off_last[0];
806     host = info.page[0].host;
807 
808     set_helper_retaddr(ra);
809 
810     while (reg_off <= reg_last) {
811         uint64_t pg = vg[reg_off >> 6];
812         do {
813             if ((pg >> (reg_off & 63)) & 1) {
814                 host_fn(za, reg_off, host + reg_off);
815             }
816             reg_off += 1 << esz;
817         } while (reg_off <= reg_last && (reg_off & 63));
818     }
819 
820     clear_helper_retaddr();
821 
822     /*
823      * Use the slow path to manage the cross-page misalignment.
824      * But we know this is RAM and cannot trap.
825      */
826     reg_off = info.reg_off_split;
827     if (unlikely(reg_off >= 0)) {
828         tlb_fn(env, za, reg_off, addr + reg_off, ra);
829     }
830 
831     reg_off = info.reg_off_first[1];
832     if (unlikely(reg_off >= 0)) {
833         reg_last = info.reg_off_last[1];
834         host = info.page[1].host;
835 
836         set_helper_retaddr(ra);
837 
838         do {
839             uint64_t pg = vg[reg_off >> 6];
840             do {
841                 if ((pg >> (reg_off & 63)) & 1) {
842                     host_fn(za, reg_off, host + reg_off);
843                 }
844                 reg_off += 1 << esz;
845             } while (reg_off & 63);
846         } while (reg_off <= reg_last);
847 
848         clear_helper_retaddr();
849     }
850 }
851 
852 static inline QEMU_ALWAYS_INLINE
853 void sme_st1_mte(CPUARMState *env, void *za, uint64_t *vg, target_ulong addr,
854                  uint64_t desc, uintptr_t ra, int esz, bool vertical,
855                  sve_ldst1_host_fn *host_fn,
856                  sve_ldst1_tlb_fn *tlb_fn)
857 {
858     uint32_t mtedesc = desc >> 32;
859     int bit55 = extract64(addr, 55, 1);
860 
861     /* Perform gross MTE suppression early. */
862     if (!tbi_check(mtedesc, bit55) ||
863         tcma_check(mtedesc, bit55, allocation_tag_from_addr(addr))) {
864         mtedesc = 0;
865     }
866 
867     sme_st1(env, za, vg, addr, desc, ra, esz, mtedesc,
868             vertical, host_fn, tlb_fn);
869 }
870 
871 #define DO_ST(L, END, ESZ)                                                 \
872 void HELPER(sme_st1##L##END##_h)(CPUARMState *env, void *za, void *vg,     \
873                                  target_ulong addr, uint64_t desc)         \
874 {                                                                          \
875     sme_st1(env, za, vg, addr, desc, GETPC(), ESZ, 0, false,               \
876             sve_st1##L##L##END##_host, sve_st1##L##L##END##_tlb);          \
877 }                                                                          \
878 void HELPER(sme_st1##L##END##_v)(CPUARMState *env, void *za, void *vg,     \
879                                  target_ulong addr, uint64_t desc)         \
880 {                                                                          \
881     sme_st1(env, za, vg, addr, desc, GETPC(), ESZ, 0, true,                \
882             sme_st1##L##END##_v_host, sme_st1##L##END##_v_tlb);            \
883 }                                                                          \
884 void HELPER(sme_st1##L##END##_h_mte)(CPUARMState *env, void *za, void *vg, \
885                                      target_ulong addr, uint64_t desc)     \
886 {                                                                          \
887     sme_st1_mte(env, za, vg, addr, desc, GETPC(), ESZ, false,              \
888                 sve_st1##L##L##END##_host, sve_st1##L##L##END##_tlb);      \
889 }                                                                          \
890 void HELPER(sme_st1##L##END##_v_mte)(CPUARMState *env, void *za, void *vg, \
891                                      target_ulong addr, uint64_t desc)     \
892 {                                                                          \
893     sme_st1_mte(env, za, vg, addr, desc, GETPC(), ESZ, true,               \
894                 sme_st1##L##END##_v_host, sme_st1##L##END##_v_tlb);        \
895 }
896 
897 DO_ST(b, , MO_8)
898 DO_ST(h, _be, MO_16)
899 DO_ST(h, _le, MO_16)
900 DO_ST(s, _be, MO_32)
901 DO_ST(s, _le, MO_32)
902 DO_ST(d, _be, MO_64)
903 DO_ST(d, _le, MO_64)
904 DO_ST(q, _be, MO_128)
905 DO_ST(q, _le, MO_128)
906 
907 #undef DO_ST
908 
909 void HELPER(sme_addha_s)(void *vzda, void *vzn, void *vpn,
910                          void *vpm, uint32_t desc)
911 {
912     intptr_t row, col, oprsz = simd_oprsz(desc) / 4;
913     uint64_t *pn = vpn, *pm = vpm;
914     uint32_t *zda = vzda, *zn = vzn;
915 
916     for (row = 0; row < oprsz; ) {
917         uint64_t pa = pn[row >> 4];
918         do {
919             if (pa & 1) {
920                 for (col = 0; col < oprsz; ) {
921                     uint64_t pb = pm[col >> 4];
922                     do {
923                         if (pb & 1) {
924                             zda[tile_vslice_index(row) + H4(col)] += zn[H4(col)];
925                         }
926                         pb >>= 4;
927                     } while (++col & 15);
928                 }
929             }
930             pa >>= 4;
931         } while (++row & 15);
932     }
933 }
934 
935 void HELPER(sme_addha_d)(void *vzda, void *vzn, void *vpn,
936                          void *vpm, uint32_t desc)
937 {
938     intptr_t row, col, oprsz = simd_oprsz(desc) / 8;
939     uint8_t *pn = vpn, *pm = vpm;
940     uint64_t *zda = vzda, *zn = vzn;
941 
942     for (row = 0; row < oprsz; ++row) {
943         if (pn[H1(row)] & 1) {
944             for (col = 0; col < oprsz; ++col) {
945                 if (pm[H1(col)] & 1) {
946                     zda[tile_vslice_index(row) + col] += zn[col];
947                 }
948             }
949         }
950     }
951 }
952 
953 void HELPER(sme_addva_s)(void *vzda, void *vzn, void *vpn,
954                          void *vpm, uint32_t desc)
955 {
956     intptr_t row, col, oprsz = simd_oprsz(desc) / 4;
957     uint64_t *pn = vpn, *pm = vpm;
958     uint32_t *zda = vzda, *zn = vzn;
959 
960     for (row = 0; row < oprsz; ) {
961         uint64_t pa = pn[row >> 4];
962         do {
963             if (pa & 1) {
964                 uint32_t zn_row = zn[H4(row)];
965                 for (col = 0; col < oprsz; ) {
966                     uint64_t pb = pm[col >> 4];
967                     do {
968                         if (pb & 1) {
969                             zda[tile_vslice_index(row) + H4(col)] += zn_row;
970                         }
971                         pb >>= 4;
972                     } while (++col & 15);
973                 }
974             }
975             pa >>= 4;
976         } while (++row & 15);
977     }
978 }
979 
980 void HELPER(sme_addva_d)(void *vzda, void *vzn, void *vpn,
981                          void *vpm, uint32_t desc)
982 {
983     intptr_t row, col, oprsz = simd_oprsz(desc) / 8;
984     uint8_t *pn = vpn, *pm = vpm;
985     uint64_t *zda = vzda, *zn = vzn;
986 
987     for (row = 0; row < oprsz; ++row) {
988         if (pn[H1(row)] & 1) {
989             uint64_t zn_row = zn[row];
990             for (col = 0; col < oprsz; ++col) {
991                 if (pm[H1(col)] & 1) {
992                     zda[tile_vslice_index(row) + col] += zn_row;
993                 }
994             }
995         }
996     }
997 }
998 
999 static void do_fmopa_h(void *vza, void *vzn, void *vzm, uint16_t *pn,
1000                        uint16_t *pm, float_status *fpst, uint32_t desc,
1001                        uint16_t negx, int negf)
1002 {
1003     intptr_t row, col, oprsz = simd_maxsz(desc);
1004 
1005     for (row = 0; row < oprsz; ) {
1006         uint16_t pa = pn[H2(row >> 4)];
1007         do {
1008             if (pa & 1) {
1009                 void *vza_row = vza + tile_vslice_offset(row);
1010                 uint16_t n = *(uint16_t *)(vzn + H1_2(row)) ^ negx;
1011 
1012                 for (col = 0; col < oprsz; ) {
1013                     uint16_t pb = pm[H2(col >> 4)];
1014                     do {
1015                         if (pb & 1) {
1016                             uint16_t *a = vza_row + H1_2(col);
1017                             uint16_t *m = vzm + H1_2(col);
1018                             *a = float16_muladd(n, *m, *a, negf, fpst);
1019                         }
1020                         col += 2;
1021                         pb >>= 2;
1022                     } while (col & 15);
1023                 }
1024             }
1025             row += 2;
1026             pa >>= 2;
1027         } while (row & 15);
1028     }
1029 }
1030 
1031 void HELPER(sme_fmopa_h)(void *vza, void *vzn, void *vzm, void *vpn,
1032                          void *vpm, float_status *fpst, uint32_t desc)
1033 {
1034     do_fmopa_h(vza, vzn, vzm, vpn, vpm, fpst, desc, 0, 0);
1035 }
1036 
1037 void HELPER(sme_fmops_h)(void *vza, void *vzn, void *vzm, void *vpn,
1038                          void *vpm, float_status *fpst, uint32_t desc)
1039 {
1040     do_fmopa_h(vza, vzn, vzm, vpn, vpm, fpst, desc, 1u << 15, 0);
1041 }
1042 
1043 void HELPER(sme_ah_fmops_h)(void *vza, void *vzn, void *vzm, void *vpn,
1044                             void *vpm, float_status *fpst, uint32_t desc)
1045 {
1046     do_fmopa_h(vza, vzn, vzm, vpn, vpm, fpst, desc, 0,
1047                float_muladd_negate_product);
1048 }
1049 
1050 static void do_fmopa_s(void *vza, void *vzn, void *vzm, uint16_t *pn,
1051                        uint16_t *pm, float_status *fpst, uint32_t desc,
1052                        uint32_t negx, int negf)
1053 {
1054     intptr_t row, col, oprsz = simd_maxsz(desc);
1055 
1056     for (row = 0; row < oprsz; ) {
1057         uint16_t pa = pn[H2(row >> 4)];
1058         do {
1059             if (pa & 1) {
1060                 void *vza_row = vza + tile_vslice_offset(row);
1061                 uint32_t n = *(uint32_t *)(vzn + H1_4(row)) ^ negx;
1062 
1063                 for (col = 0; col < oprsz; ) {
1064                     uint16_t pb = pm[H2(col >> 4)];
1065                     do {
1066                         if (pb & 1) {
1067                             uint32_t *a = vza_row + H1_4(col);
1068                             uint32_t *m = vzm + H1_4(col);
1069                             *a = float32_muladd(n, *m, *a, negf, fpst);
1070                         }
1071                         col += 4;
1072                         pb >>= 4;
1073                     } while (col & 15);
1074                 }
1075             }
1076             row += 4;
1077             pa >>= 4;
1078         } while (row & 15);
1079     }
1080 }
1081 
1082 void HELPER(sme_fmopa_s)(void *vza, void *vzn, void *vzm, void *vpn,
1083                          void *vpm, float_status *fpst, uint32_t desc)
1084 {
1085     do_fmopa_s(vza, vzn, vzm, vpn, vpm, fpst, desc, 0, 0);
1086 }
1087 
1088 void HELPER(sme_fmops_s)(void *vza, void *vzn, void *vzm, void *vpn,
1089                          void *vpm, float_status *fpst, uint32_t desc)
1090 {
1091     do_fmopa_s(vza, vzn, vzm, vpn, vpm, fpst, desc, 1u << 31, 0);
1092 }
1093 
1094 void HELPER(sme_ah_fmops_s)(void *vza, void *vzn, void *vzm, void *vpn,
1095                             void *vpm, float_status *fpst, uint32_t desc)
1096 {
1097     do_fmopa_s(vza, vzn, vzm, vpn, vpm, fpst, desc, 0,
1098                float_muladd_negate_product);
1099 }
1100 
1101 static void do_fmopa_d(uint64_t *za, uint64_t *zn, uint64_t *zm, uint8_t *pn,
1102                        uint8_t *pm, float_status *fpst, uint32_t desc,
1103                        uint64_t negx, int negf)
1104 {
1105     intptr_t row, col, oprsz = simd_oprsz(desc) / 8;
1106 
1107     for (row = 0; row < oprsz; ++row) {
1108         if (pn[H1(row)] & 1) {
1109             uint64_t *za_row = &za[tile_vslice_index(row)];
1110             uint64_t n = zn[row] ^ negx;
1111 
1112             for (col = 0; col < oprsz; ++col) {
1113                 if (pm[H1(col)] & 1) {
1114                     uint64_t *a = &za_row[col];
1115                     *a = float64_muladd(n, zm[col], *a, negf, fpst);
1116                 }
1117             }
1118         }
1119     }
1120 }
1121 
1122 void HELPER(sme_fmopa_d)(void *vza, void *vzn, void *vzm, void *vpn,
1123                          void *vpm, float_status *fpst, uint32_t desc)
1124 {
1125     do_fmopa_d(vza, vzn, vzm, vpn, vpm, fpst, desc, 0, 0);
1126 }
1127 
1128 void HELPER(sme_fmops_d)(void *vza, void *vzn, void *vzm, void *vpn,
1129                          void *vpm, float_status *fpst, uint32_t desc)
1130 {
1131     do_fmopa_d(vza, vzn, vzm, vpn, vpm, fpst, desc, 1ull << 63, 0);
1132 }
1133 
1134 void HELPER(sme_ah_fmops_d)(void *vza, void *vzn, void *vzm, void *vpn,
1135                             void *vpm, float_status *fpst, uint32_t desc)
1136 {
1137     do_fmopa_d(vza, vzn, vzm, vpn, vpm, fpst, desc, 0,
1138                float_muladd_negate_product);
1139 }
1140 
1141 static void do_bfmopa(void *vza, void *vzn, void *vzm, uint16_t *pn,
1142                       uint16_t *pm, float_status *fpst, uint32_t desc,
1143                       uint16_t negx, int negf)
1144 {
1145     intptr_t row, col, oprsz = simd_maxsz(desc);
1146 
1147     for (row = 0; row < oprsz; ) {
1148         uint16_t pa = pn[H2(row >> 4)];
1149         do {
1150             if (pa & 1) {
1151                 void *vza_row = vza + tile_vslice_offset(row);
1152                 uint16_t n = *(uint16_t *)(vzn + H1_2(row)) ^ negx;
1153 
1154                 for (col = 0; col < oprsz; ) {
1155                     uint16_t pb = pm[H2(col >> 4)];
1156                     do {
1157                         if (pb & 1) {
1158                             uint16_t *a = vza_row + H1_2(col);
1159                             uint16_t *m = vzm + H1_2(col);
1160                             *a = bfloat16_muladd(n, *m, *a, negf, fpst);
1161                         }
1162                         col += 2;
1163                         pb >>= 2;
1164                     } while (col & 15);
1165                 }
1166             }
1167             row += 2;
1168             pa >>= 2;
1169         } while (row & 15);
1170     }
1171 }
1172 
1173 void HELPER(sme_bfmopa)(void *vza, void *vzn, void *vzm, void *vpn,
1174                         void *vpm, float_status *fpst, uint32_t desc)
1175 {
1176     do_bfmopa(vza, vzn, vzm, vpn, vpm, fpst, desc, 0, 0);
1177 }
1178 
1179 void HELPER(sme_bfmops)(void *vza, void *vzn, void *vzm, void *vpn,
1180                         void *vpm, float_status *fpst, uint32_t desc)
1181 {
1182     do_bfmopa(vza, vzn, vzm, vpn, vpm, fpst, desc, 1u << 15, 0);
1183 }
1184 
1185 void HELPER(sme_ah_bfmops)(void *vza, void *vzn, void *vzm, void *vpn,
1186                            void *vpm, float_status *fpst, uint32_t desc)
1187 {
1188     do_bfmopa(vza, vzn, vzm, vpn, vpm, fpst, desc, 0,
1189               float_muladd_negate_product);
1190 }
1191 
1192 /*
1193  * Alter PAIR as needed for controlling predicates being false,
1194  * and for NEG on an enabled row element.
1195  */
1196 static inline uint32_t f16mop_adj_pair(uint32_t pair, uint32_t pg, uint32_t neg)
1197 {
1198     /*
1199      * The pseudocode uses a conditional negate after the conditional zero.
1200      * It is simpler here to unconditionally negate before conditional zero.
1201      */
1202     pair ^= neg;
1203     if (!(pg & 1)) {
1204         pair &= 0xffff0000u;
1205     }
1206     if (!(pg & 4)) {
1207         pair &= 0x0000ffffu;
1208     }
1209     return pair;
1210 }
1211 
1212 static inline uint32_t f16mop_ah_neg_adj_pair(uint32_t pair, uint32_t pg)
1213 {
1214     uint32_t l = pg & 1 ? float16_ah_chs(pair) : 0;
1215     uint32_t h = pg & 4 ? float16_ah_chs(pair >> 16) : 0;
1216     return l | (h << 16);
1217 }
1218 
1219 static inline uint32_t bf16mop_ah_neg_adj_pair(uint32_t pair, uint32_t pg)
1220 {
1221     uint32_t l = pg & 1 ? bfloat16_ah_chs(pair) : 0;
1222     uint32_t h = pg & 4 ? bfloat16_ah_chs(pair >> 16) : 0;
1223     return l | (h << 16);
1224 }
1225 
1226 static float32 f16_dotadd(float32 sum, uint32_t e1, uint32_t e2,
1227                           float_status *s_f16, float_status *s_std,
1228                           float_status *s_odd)
1229 {
1230     /*
1231      * We need three different float_status for different parts of this
1232      * operation:
1233      *  - the input conversion of the float16 values must use the
1234      *    f16-specific float_status, so that the FPCR.FZ16 control is applied
1235      *  - operations on float32 including the final accumulation must use
1236      *    the normal float_status, so that FPCR.FZ is applied
1237      *  - we have a pre-set-up copy of s_std which is set to round-to-odd,
1238      *    for the multiply (see below)
1239      */
1240     float16 h1r = e1 & 0xffff;
1241     float16 h1c = e1 >> 16;
1242     float16 h2r = e2 & 0xffff;
1243     float16 h2c = e2 >> 16;
1244     float32 t32;
1245 
1246     /* C.f. FPProcessNaNs4 */
1247     if (float16_is_any_nan(h1r) || float16_is_any_nan(h1c) ||
1248         float16_is_any_nan(h2r) || float16_is_any_nan(h2c)) {
1249         float16 t16;
1250 
1251         if (float16_is_signaling_nan(h1r, s_f16)) {
1252             t16 = h1r;
1253         } else if (float16_is_signaling_nan(h1c, s_f16)) {
1254             t16 = h1c;
1255         } else if (float16_is_signaling_nan(h2r, s_f16)) {
1256             t16 = h2r;
1257         } else if (float16_is_signaling_nan(h2c, s_f16)) {
1258             t16 = h2c;
1259         } else if (float16_is_any_nan(h1r)) {
1260             t16 = h1r;
1261         } else if (float16_is_any_nan(h1c)) {
1262             t16 = h1c;
1263         } else if (float16_is_any_nan(h2r)) {
1264             t16 = h2r;
1265         } else {
1266             t16 = h2c;
1267         }
1268         t32 = float16_to_float32(t16, true, s_f16);
1269     } else {
1270         float64 e1r = float16_to_float64(h1r, true, s_f16);
1271         float64 e1c = float16_to_float64(h1c, true, s_f16);
1272         float64 e2r = float16_to_float64(h2r, true, s_f16);
1273         float64 e2c = float16_to_float64(h2c, true, s_f16);
1274         float64 t64;
1275 
1276         /*
1277          * The ARM pseudocode function FPDot performs both multiplies
1278          * and the add with a single rounding operation.  Emulate this
1279          * by performing the first multiply in round-to-odd, then doing
1280          * the second multiply as fused multiply-add, and rounding to
1281          * float32 all in one step.
1282          */
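        /*
         * (Round-to-odd keeps the sticky information from the first product
         *  in the low result bit, so the float32 rounding of the fused step
         *  below is not distorted by double rounding.)
         */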
1283         t64 = float64_mul(e1r, e2r, s_odd);
1284         t64 = float64r32_muladd(e1c, e2c, t64, 0, s_std);
1285 
1286         /* This conversion is exact, because we've already rounded. */
1287         t32 = float64_to_float32(t64, s_std);
1288     }
1289 
1290     /* The final accumulation step is not fused. */
1291     return float32_add(sum, t32, s_std);
1292 }
1293 
1294 static void do_fmopa_w_h(void *vza, void *vzn, void *vzm, uint16_t *pn,
1295                          uint16_t *pm, CPUARMState *env, uint32_t desc,
1296                          uint32_t negx, bool ah_neg)
1297 {
1298     intptr_t row, col, oprsz = simd_maxsz(desc);
1299     float_status fpst_odd = env->vfp.fp_status[FPST_ZA];
1300 
1301     set_float_rounding_mode(float_round_to_odd, &fpst_odd);
1302 
1303     for (row = 0; row < oprsz; ) {
1304         uint16_t prow = pn[H2(row >> 4)];
1305         do {
1306             void *vza_row = vza + tile_vslice_offset(row);
1307             uint32_t n = *(uint32_t *)(vzn + H1_4(row));
1308 
1309             if (ah_neg) {
1310                 n = f16mop_ah_neg_adj_pair(n, prow);
1311             } else {
1312                 n = f16mop_adj_pair(n, prow, negx);
1313             }
1314 
1315             for (col = 0; col < oprsz; ) {
1316                 uint16_t pcol = pm[H2(col >> 4)];
1317                 do {
1318                     if (prow & pcol & 0b0101) {
1319                         uint32_t *a = vza_row + H1_4(col);
1320                         uint32_t m = *(uint32_t *)(vzm + H1_4(col));
1321 
1322                         m = f16mop_adj_pair(m, pcol, 0);
1323                         *a = f16_dotadd(*a, n, m,
1324                                         &env->vfp.fp_status[FPST_ZA_F16],
1325                                         &env->vfp.fp_status[FPST_ZA],
1326                                         &fpst_odd);
1327                     }
1328                     col += 4;
1329                     pcol >>= 4;
1330                 } while (col & 15);
1331             }
1332             row += 4;
1333             prow >>= 4;
1334         } while (row & 15);
1335     }
1336 }
1337 
1338 void HELPER(sme_fmopa_w_h)(void *vza, void *vzn, void *vzm, void *vpn,
1339                            void *vpm, CPUARMState *env, uint32_t desc)
1340 {
1341     do_fmopa_w_h(vza, vzn, vzm, vpn, vpm, env, desc, 0, false);
1342 }
1343 
1344 void HELPER(sme_fmops_w_h)(void *vza, void *vzn, void *vzm, void *vpn,
1345                            void *vpm, CPUARMState *env, uint32_t desc)
1346 {
1347     do_fmopa_w_h(vza, vzn, vzm, vpn, vpm, env, desc, 0x80008000u, false);
1348 }
1349 
1350 void HELPER(sme_ah_fmops_w_h)(void *vza, void *vzn, void *vzm, void *vpn,
1351                               void *vpm, CPUARMState *env, uint32_t desc)
1352 {
1353     do_fmopa_w_h(vza, vzn, vzm, vpn, vpm, env, desc, 0, true);
1354 }
1355 
1356 void HELPER(sme2_fdot_h)(void *vd, void *vn, void *vm, void *va,
1357                          CPUARMState *env, uint32_t desc)
1358 {
1359     intptr_t i, oprsz = simd_maxsz(desc);
1360     bool za = extract32(desc, SIMD_DATA_SHIFT, 1);
1361     float_status *fpst_std = &env->vfp.fp_status[za ? FPST_ZA : FPST_A64];
1362     float_status *fpst_f16 = &env->vfp.fp_status[za ? FPST_ZA_F16 : FPST_A64_F16];
1363     float_status fpst_odd = *fpst_std;
1364     float32 *d = vd, *a = va;
1365     uint32_t *n = vn, *m = vm;
1366 
1367     set_float_rounding_mode(float_round_to_odd, &fpst_odd);
1368 
1369     for (i = 0; i < oprsz / sizeof(float32); ++i) {
1370         d[H4(i)] = f16_dotadd(a[H4(i)], n[H4(i)], m[H4(i)],
1371                               fpst_f16, fpst_std, &fpst_odd);
1372     }
1373 }
1374 
1375 void HELPER(sme2_fdot_idx_h)(void *vd, void *vn, void *vm, void *va,
1376                              CPUARMState *env, uint32_t desc)
1377 {
1378     intptr_t i, j, oprsz = simd_maxsz(desc);
1379     intptr_t elements = oprsz / sizeof(float32);
1380     intptr_t eltspersegment = MIN(4, elements);
1381     int idx = extract32(desc, SIMD_DATA_SHIFT, 2);
1382     bool za = extract32(desc, SIMD_DATA_SHIFT + 2, 1);
1383     float_status *fpst_std = &env->vfp.fp_status[za ? FPST_ZA : FPST_A64];
1384     float_status *fpst_f16 = &env->vfp.fp_status[za ? FPST_ZA_F16 : FPST_A64_F16];
1385     float_status fpst_odd = *fpst_std;
1386     float32 *d = vd, *a = va;
1387     uint32_t *n = vn, *m = (uint32_t *)vm + H4(idx);
1388 
1389     set_float_rounding_mode(float_round_to_odd, &fpst_odd);
1390 
1391     for (i = 0; i < elements; i += eltspersegment) {
1392         uint32_t mm = m[i];
1393         for (j = 0; j < eltspersegment; ++j) {
1394             d[H4(i + j)] = f16_dotadd(a[H4(i + j)], n[H4(i + j)], mm,
1395                                       fpst_f16, fpst_std, &fpst_odd);
1396         }
1397     }
1398 }
1399 
1400 void HELPER(sme2_fvdot_idx_h)(void *vd, void *vn, void *vm, void *va,
1401                               CPUARMState *env, uint32_t desc)
1402 {
1403     intptr_t i, j, oprsz = simd_maxsz(desc);
1404     intptr_t elements = oprsz / sizeof(float32);
1405     intptr_t eltspersegment = MIN(4, elements);
1406     int idx = extract32(desc, SIMD_DATA_SHIFT, 2);
1407     int sel = extract32(desc, SIMD_DATA_SHIFT + 2, 1);
1408     float_status fpst_odd, *fpst_std, *fpst_f16;
1409     float32 *d = vd, *a = va;
1410     uint16_t *n0 = vn;
1411     uint16_t *n1 = vn + sizeof(ARMVectorReg);
1412     uint32_t *m = (uint32_t *)vm + H4(idx);
1413 
1414     fpst_std = &env->vfp.fp_status[FPST_ZA];
1415     fpst_f16 = &env->vfp.fp_status[FPST_ZA_F16];
1416     fpst_odd = *fpst_std;
1417     set_float_rounding_mode(float_round_to_odd, &fpst_odd);
1418 
1419     for (i = 0; i < elements; i += eltspersegment) {
1420         uint32_t mm = m[i];
1421         for (j = 0; j < eltspersegment; ++j) {
1422             uint32_t nn = (n0[H2(2 * (i + j) + sel)])
1423                         | (n1[H2(2 * (i + j) + sel)] << 16);
1424             d[i + H4(j)] = f16_dotadd(a[i + H4(j)], nn, mm,
1425                                       fpst_f16, fpst_std, &fpst_odd);
1426         }
1427     }
1428 }
1429 
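/*
 * Widening BFMOPA/BFMOPS outer product.  For each predicated (row, col)
 * pair, a pair of bfloat16 elements from row `row` of Zn and a pair from
 * column `col` of Zm are multiplied and summed into the float32 ZA tile
 * element.  `negx` is applied to the Zn pair (0x80008000 flips the sign
 * of both halves for the subtracting form), while `ah_neg` selects the
 * alternate-handling negation via bf16mop_ah_neg_adj_pair.  is_ebf()
 * picks the extended-bfloat path (bfdotadd_ebf, with the extra fpst_odd
 * status) or plain bfdotadd.
 */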
1430 static void do_bfmopa_w(void *vza, void *vzn, void *vzm,
1431                         uint16_t *pn, uint16_t *pm, CPUARMState *env,
1432                         uint32_t desc, uint32_t negx, bool ah_neg)
1433 {
1434     intptr_t row, col, oprsz = simd_maxsz(desc);
1435     float_status fpst, fpst_odd;
1436 
1437     if (is_ebf(env, &fpst, &fpst_odd)) {
1438         for (row = 0; row < oprsz; ) {
1439             uint16_t prow = pn[H2(row >> 4)];
1440             do {
1441                 void *vza_row = vza + tile_vslice_offset(row);
1442                 uint32_t n = *(uint32_t *)(vzn + H1_4(row));
1443 
1444                 if (ah_neg) {
1445                     n = bf16mop_ah_neg_adj_pair(n, prow);
1446                 } else {
1447                     n = f16mop_adj_pair(n, prow, negx);
1448                 }
1449 
1450                 for (col = 0; col < oprsz; ) {
1451                     uint16_t pcol = pm[H2(col >> 4)];
1452                     do {
1453                         if (prow & pcol & 0b0101) {
1454                             uint32_t *a = vza_row + H1_4(col);
1455                             uint32_t m = *(uint32_t *)(vzm + H1_4(col));
1456 
1457                             m = f16mop_adj_pair(m, pcol, 0);
1458                             *a = bfdotadd_ebf(*a, n, m, &fpst, &fpst_odd);
1459                         }
1460                         col += 4;
1461                         pcol >>= 4;
1462                     } while (col & 15);
1463                 }
1464                 row += 4;
1465                 prow >>= 4;
1466             } while (row & 15);
1467         }
1468     } else {
1469         for (row = 0; row < oprsz; ) {
1470             uint16_t prow = pn[H2(row >> 4)];
1471             do {
1472                 void *vza_row = vza + tile_vslice_offset(row);
1473                 uint32_t n = *(uint32_t *)(vzn + H1_4(row));
1474 
1475                 if (ah_neg) {
1476                     n = bf16mop_ah_neg_adj_pair(n, prow);
1477                 } else {
1478                     n = f16mop_adj_pair(n, prow, negx);
1479                 }
1480 
1481                 for (col = 0; col < oprsz; ) {
1482                     uint16_t pcol = pm[H2(col >> 4)];
1483                     do {
1484                         if (prow & pcol & 0b0101) {
1485                             uint32_t *a = vza_row + H1_4(col);
1486                             uint32_t m = *(uint32_t *)(vzm + H1_4(col));
1487 
1488                             m = f16mop_adj_pair(m, pcol, 0);
1489                             *a = bfdotadd(*a, n, m, &fpst);
1490                         }
1491                         col += 4;
1492                         pcol >>= 4;
1493                     } while (col & 15);
1494                 }
1495                 row += 4;
1496                 prow >>= 4;
1497             } while (row & 15);
1498         }
1499     }
1500 }
1501 
1502 void HELPER(sme_bfmopa_w)(void *vza, void *vzn, void *vzm, void *vpn,
1503                           void *vpm, CPUARMState *env, uint32_t desc)
1504 {
1505     do_bfmopa_w(vza, vzn, vzm, vpn, vpm, env, desc, 0, false);
1506 }
1507 
1508 void HELPER(sme_bfmops_w)(void *vza, void *vzn, void *vzm, void *vpn,
1509                           void *vpm, CPUARMState *env, uint32_t desc)
1510 {
1511     do_bfmopa_w(vza, vzn, vzm, vpn, vpm, env, desc, 0x80008000u, false);
1512 }
1513 
1514 void HELPER(sme_ah_bfmops_w)(void *vza, void *vzn, void *vzm, void *vpn,
1515                              void *vpm, CPUARMState *env, uint32_t desc)
1516 {
1517     do_bfmopa_w(vza, vzn, vzm, vpn, vpm, env, desc, 0, true);
1518 }
1519 
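/*
 * Integer outer-product accumulate (SMOPA/UMOPA and friends).  The
 * do_imopa_[sd] loops walk the ZA tile row by row; the predicates
 * contribute one bit per byte of the wide element (a nibble per 32-bit
 * element, a byte per 64-bit element), and the per-element IMOPFn masks
 * the inactive bytes of N, sums the sub-element products, and adds or
 * subtracts the result as selected by simd_data(desc).
 */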
1520 typedef uint32_t IMOPFn32(uint32_t, uint32_t, uint32_t, uint8_t, bool);
1521 static inline void do_imopa_s(uint32_t *za, uint32_t *zn, uint32_t *zm,
1522                               uint8_t *pn, uint8_t *pm,
1523                               uint32_t desc, IMOPFn32 *fn)
1524 {
1525     intptr_t row, col, oprsz = simd_oprsz(desc) / 4;
1526     bool neg = simd_data(desc);
1527 
1528     for (row = 0; row < oprsz; ++row) {
1529         uint8_t pa = (pn[H1(row >> 1)] >> ((row & 1) * 4)) & 0xf;
1530         uint32_t *za_row = &za[tile_vslice_index(row)];
1531         uint32_t n = zn[H4(row)];
1532 
1533         for (col = 0; col < oprsz; ++col) {
1534             uint8_t pb = pm[H1(col >> 1)] >> ((col & 1) * 4);
1535             uint32_t *a = &za_row[H4(col)];
1536 
1537             *a = fn(n, zm[H4(col)], *a, pa & pb, neg);
1538         }
1539     }
1540 }
1541 
1542 typedef uint64_t IMOPFn64(uint64_t, uint64_t, uint64_t, uint8_t, bool);
1543 static inline void do_imopa_d(uint64_t *za, uint64_t *zn, uint64_t *zm,
1544                               uint8_t *pn, uint8_t *pm,
1545                               uint32_t desc, IMOPFn64 *fn)
1546 {
1547     intptr_t row, col, oprsz = simd_oprsz(desc) / 8;
1548     bool neg = simd_data(desc);
1549 
1550     for (row = 0; row < oprsz; ++row) {
1551         uint8_t pa = pn[H1(row)];
1552         uint64_t *za_row = &za[tile_vslice_index(row)];
1553         uint64_t n = zn[row];
1554 
1555         for (col = 0; col < oprsz; ++col) {
1556             uint8_t pb = pm[H1(col)];
1557             uint64_t *a = &za_row[col];
1558 
1559             *a = fn(n, zm[col], *a, pa & pb, neg);
1560         }
1561     }
1562 }
1563 
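/*
 * For example, smopa_s with a fully-true predicate nibble (p == 0xf)
 * computes a + n0*m0 + n1*m1 + n2*m2 + n3*m3, where nK/mK are the
 * signed bytes of the two 32-bit inputs; clearing a predicate bit
 * zeroes the corresponding byte of n before the products are formed.
 */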
1564 #define DEF_IMOP_8x4_32(NAME, NTYPE, MTYPE) \
1565 static uint32_t NAME(uint32_t n, uint32_t m, uint32_t a, uint8_t p, bool neg) \
1566 {                                                                           \
1567     uint32_t sum = 0;                                                       \
1568     /* Apply P to N as a mask, making the inactive elements 0. */           \
1569     n &= expand_pred_b(p);                                                  \
1570     sum += (NTYPE)(n >> 0) * (MTYPE)(m >> 0);                               \
1571     sum += (NTYPE)(n >> 8) * (MTYPE)(m >> 8);                               \
1572     sum += (NTYPE)(n >> 16) * (MTYPE)(m >> 16);                             \
1573     sum += (NTYPE)(n >> 24) * (MTYPE)(m >> 24);                             \
1574     return neg ? a - sum : a + sum;                                         \
1575 }
1576 
1577 #define DEF_IMOP_16x4_64(NAME, NTYPE, MTYPE) \
1578 static uint64_t NAME(uint64_t n, uint64_t m, uint64_t a, uint8_t p, bool neg) \
1579 {                                                                           \
1580     uint64_t sum = 0;                                                       \
1581     /* Apply P to N as a mask, making the inactive elements 0. */           \
1582     n &= expand_pred_h(p);                                                  \
1583     sum += (int64_t)(NTYPE)(n >> 0) * (MTYPE)(m >> 0);                      \
1584     sum += (int64_t)(NTYPE)(n >> 16) * (MTYPE)(m >> 16);                    \
1585     sum += (int64_t)(NTYPE)(n >> 32) * (MTYPE)(m >> 32);                    \
1586     sum += (int64_t)(NTYPE)(n >> 48) * (MTYPE)(m >> 48);                    \
1587     return neg ? a - sum : a + sum;                                         \
1588 }
1589 
1590 DEF_IMOP_8x4_32(smopa_s, int8_t, int8_t)
1591 DEF_IMOP_8x4_32(umopa_s, uint8_t, uint8_t)
1592 DEF_IMOP_8x4_32(sumopa_s, int8_t, uint8_t)
1593 DEF_IMOP_8x4_32(usmopa_s, uint8_t, int8_t)
1594 
1595 DEF_IMOP_16x4_64(smopa_d, int16_t, int16_t)
1596 DEF_IMOP_16x4_64(umopa_d, uint16_t, uint16_t)
1597 DEF_IMOP_16x4_64(sumopa_d, int16_t, uint16_t)
1598 DEF_IMOP_16x4_64(usmopa_d, uint16_t, int16_t)
1599 
1600 #define DEF_IMOPH(P, NAME, S) \
1601     void HELPER(P##_##NAME##_##S)(void *vza, void *vzn, void *vzm,          \
1602                                   void *vpn, void *vpm, uint32_t desc)      \
1603     { do_imopa_##S(vza, vzn, vzm, vpn, vpm, desc, NAME##_##S); }
1604 
1605 DEF_IMOPH(sme, smopa, s)
1606 DEF_IMOPH(sme, umopa, s)
1607 DEF_IMOPH(sme, sumopa, s)
1608 DEF_IMOPH(sme, usmopa, s)
1609 
1610 DEF_IMOPH(sme, smopa, d)
1611 DEF_IMOPH(sme, umopa, d)
1612 DEF_IMOPH(sme, sumopa, d)
1613 DEF_IMOPH(sme, usmopa, d)
1614 
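/*
 * BMOPA/BMOPS: each 32-bit ZA element accumulates the number of bit
 * positions at which the 32-bit Zn and Zm elements agree (popcount of
 * the XNOR), negated for the subtracting form; only bit 0 of the
 * predicate nibble gates the element.
 */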
1615 static uint32_t bmopa_s(uint32_t n, uint32_t m, uint32_t a, uint8_t p, bool neg)
1616 {
1617     uint32_t sum = ctpop32(~(n ^ m));
1618     if (neg) {
1619         sum = -sum;
1620     }
1621     if (!(p & 1)) {
1622         sum = 0;
1623     }
1624     return a + sum;
1625 }
1626 
1627 DEF_IMOPH(sme2, bmopa, s)
1628 
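/*
 * Two-way 16-bit variant used by SMOPA2/UMOPA2: two halfword products
 * per 32-bit accumulator element.
 */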
1629 #define DEF_IMOP_16x2_32(NAME, NTYPE, MTYPE) \
1630 static uint32_t NAME(uint32_t n, uint32_t m, uint32_t a, uint8_t p, bool neg) \
1631 {                                                                           \
1632     uint32_t sum = 0;                                                       \
1633     /* Apply P to N as a mask, making the inactive elements 0. */           \
1634     n &= expand_pred_h(p);                                                  \
1635     sum += (NTYPE)(n >> 0) * (MTYPE)(m >> 0);                               \
1636     sum += (NTYPE)(n >> 16) * (MTYPE)(m >> 16);                             \
1637     return neg ? a - sum : a + sum;                                         \
1638 }
1639 
1640 DEF_IMOP_16x2_32(smopa2_s, int16_t, int16_t)
1641 DEF_IMOP_16x2_32(umopa2_s, uint16_t, uint16_t)
1642 
1643 DEF_IMOPH(sme2, smopa2, s)
1644 DEF_IMOPH(sme2, umopa2, s)
1645 
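/*
 * Vertical dot product by index (SVDOT/UVDOT/SUVDOT/USVDOT).
 * vn holds nreg consecutive vectors.  For destination vector r and wide
 * element e, the accumulator gains sum_i n_i[nreg*e + r] * m[nreg*s + i],
 * where n_i is the i-th source vector and s = seg + idx selects the
 * idx-th group of nreg narrow elements in each 16-byte segment of Zm.
 * The nreg destination vectors are vstride bytes apart in the backing
 * storage.
 */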
1646 #define DO_VDOT_IDX(NAME, TYPED, TYPEN, TYPEM, HD, HN) \
1647 void HELPER(NAME)(void *vd, void *vn, void *vm, uint32_t desc)            \
1648 {                                                                         \
1649     intptr_t svl = simd_oprsz(desc);                                      \
1650     intptr_t elements = svl / sizeof(TYPED);                              \
1651     intptr_t eltperseg = 16 / sizeof(TYPED);                              \
1652     intptr_t nreg = sizeof(TYPED) / sizeof(TYPEN);                        \
1653     intptr_t vstride = (svl / nreg) * sizeof(ARMVectorReg);               \
1654     intptr_t zstride = sizeof(ARMVectorReg) / sizeof(TYPEN);              \
1655     intptr_t idx = extract32(desc, SIMD_DATA_SHIFT, 2);                   \
1656     TYPEN *n = vn;                                                        \
1657     TYPEM *m = vm;                                                        \
1658     for (intptr_t r = 0; r < nreg; r++) {                                 \
1659         TYPED *d = vd + r * vstride;                                      \
1660         for (intptr_t seg = 0; seg < elements; seg += eltperseg) {        \
1661             intptr_t s = seg + idx;                                       \
1662             for (intptr_t e = seg; e < seg + eltperseg; e++) {            \
1663                 TYPED sum = d[HD(e)];                                     \
1664                 for (intptr_t i = 0; i < nreg; i++) {                     \
1665                     TYPED nn = n[i * zstride + HN(nreg * e + r)];         \
1666                     TYPED mm = m[HN(nreg * s + i)];                       \
1667                     sum += nn * mm;                                       \
1668                 }                                                         \
1669                 d[HD(e)] = sum;                                           \
1670             }                                                             \
1671         }                                                                 \
1672     }                                                                     \
1673 }
1674 
1675 DO_VDOT_IDX(sme2_svdot_idx_4b, int32_t, int8_t, int8_t, H4, H1)
1676 DO_VDOT_IDX(sme2_uvdot_idx_4b, uint32_t, uint8_t, uint8_t, H4, H1)
1677 DO_VDOT_IDX(sme2_suvdot_idx_4b, int32_t, int8_t, uint8_t, H4, H1)
1678 DO_VDOT_IDX(sme2_usvdot_idx_4b, int32_t, uint8_t, int8_t, H4, H1)
1679 
1680 DO_VDOT_IDX(sme2_svdot_idx_4h, int64_t, int16_t, int16_t, H8, H2)
1681 DO_VDOT_IDX(sme2_uvdot_idx_4h, uint64_t, uint16_t, uint16_t, H8, H2)
1682 
1683 DO_VDOT_IDX(sme2_svdot_idx_2h, int32_t, int16_t, int16_t, H4, H2)
1684 DO_VDOT_IDX(sme2_uvdot_idx_2h, uint32_t, uint16_t, uint16_t, H4, H2)
1685 
1686 #undef DO_VDOT_IDX
1687 
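/*
 * Long multiply-accumulate (SMLALL/UMLALL/USMLALL, and the *MLSLL
 * subtracting forms): each wide destination element accumulates the
 * product of the sel-th narrow element of the corresponding group of
 * four in Zn and Zm.
 */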
1688 #define DO_MLALL(NAME, TYPEW, TYPEN, TYPEM, HW, HN, OP) \
1689 void HELPER(NAME)(void *vd, void *vn, void *vm, void *va, uint32_t desc) \
1690 {                                                               \
1691     intptr_t elements = simd_oprsz(desc) / sizeof(TYPEW);       \
1692     intptr_t sel = extract32(desc, SIMD_DATA_SHIFT, 2);         \
1693     TYPEW *d = vd, *a = va; TYPEN *n = vn; TYPEM *m = vm;       \
1694     for (intptr_t i = 0; i < elements; ++i) {                   \
1695         TYPEW nn = n[HN(i * 4 + sel)];                          \
1696         TYPEM mm = m[HN(i * 4 + sel)];                          \
1697         d[HW(i)] = a[HW(i)] OP (nn * mm);                       \
1698     }                                                           \
1699 }
1700 
1701 DO_MLALL(sme2_smlall_s, int32_t, int8_t, int8_t, H4, H1, +)
1702 DO_MLALL(sme2_smlall_d, int64_t, int16_t, int16_t, H8, H2, +)
1703 DO_MLALL(sme2_smlsll_s, int32_t, int8_t, int8_t, H4, H1, -)
1704 DO_MLALL(sme2_smlsll_d, int64_t, int16_t, int16_t, H8, H2, -)
1705 
1706 DO_MLALL(sme2_umlall_s, uint32_t, uint8_t, uint8_t, H4, H1, +)
1707 DO_MLALL(sme2_umlall_d, uint64_t, uint16_t, uint16_t, H8, H2, +)
1708 DO_MLALL(sme2_umlsll_s, uint32_t, uint8_t, uint8_t, H4, H1, -)
1709 DO_MLALL(sme2_umlsll_d, uint64_t, uint16_t, uint16_t, H8, H2, -)
1710 
1711 DO_MLALL(sme2_usmlall_s, uint32_t, uint8_t, int8_t, H4, H1, +)
1712 
1713 #undef DO_MLALL
1714 
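/*
 * Indexed form: the Zm operand is the single narrow element selected by
 * idx within each 128-bit segment, shared by every element of that
 * segment.
 */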
1715 #define DO_MLALL_IDX(NAME, TYPEW, TYPEN, TYPEM, HW, HN, OP) \
1716 void HELPER(NAME)(void *vd, void *vn, void *vm, void *va, uint32_t desc) \
1717 {                                                               \
1718     intptr_t elements = simd_oprsz(desc) / sizeof(TYPEW);       \
1719     intptr_t eltspersegment = 16 / sizeof(TYPEW);               \
1720     intptr_t sel = extract32(desc, SIMD_DATA_SHIFT, 2);         \
1721     intptr_t idx = extract32(desc, SIMD_DATA_SHIFT + 2, 4);     \
1722     TYPEW *d = vd, *a = va; TYPEN *n = vn; TYPEM *m = vm;       \
1723     for (intptr_t i = 0; i < elements; i += eltspersegment) {   \
1724         TYPEW mm = m[HN(i * 4 + idx)];                          \
1725         for (intptr_t j = 0; j < eltspersegment; ++j) {         \
1726             TYPEN nn = n[HN((i + j) * 4 + sel)];                \
1727             d[HW(i + j)] = a[HW(i + j)] OP (nn * mm);           \
1728         }                                                       \
1729     }                                                           \
1730 }
1731 
1732 DO_MLALL_IDX(sme2_smlall_idx_s, int32_t, int8_t, int8_t, H4, H1, +)
1733 DO_MLALL_IDX(sme2_smlall_idx_d, int64_t, int16_t, int16_t, H8, H2, +)
1734 DO_MLALL_IDX(sme2_smlsll_idx_s, int32_t, int8_t, int8_t, H4, H1, -)
1735 DO_MLALL_IDX(sme2_smlsll_idx_d, int64_t, int16_t, int16_t, H8, H2, -)
1736 
1737 DO_MLALL_IDX(sme2_umlall_idx_s, uint32_t, uint8_t, uint8_t, H4, H1, +)
1738 DO_MLALL_IDX(sme2_umlall_idx_d, uint64_t, uint16_t, uint16_t, H8, H2, +)
1739 DO_MLALL_IDX(sme2_umlsll_idx_s, uint32_t, uint8_t, uint8_t, H4, H1, -)
1740 DO_MLALL_IDX(sme2_umlsll_idx_d, uint64_t, uint16_t, uint16_t, H8, H2, -)
1741 
1742 DO_MLALL_IDX(sme2_usmlall_idx_s, uint32_t, uint8_t, int8_t, H4, H1, +)
1743 DO_MLALL_IDX(sme2_sumlall_idx_s, uint32_t, int8_t, uint8_t, H4, H1, +)
1744 
1745 #undef DO_MLALL_IDX
1746 
1747 /* Convert and compress */
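/*
 * The two (or four) consecutive source vectors are narrowed and written
 * back to back: results from source vector k fill the k-th fraction of
 * the destination.  The interleaving forms appear further down under
 * "Convert and interleave".
 */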
1748 void HELPER(sme2_bfcvt)(void *vd, void *vs, float_status *fpst, uint32_t desc)
1749 {
1750     ARMVectorReg scratch;
1751     size_t oprsz = simd_oprsz(desc);
1752     size_t i, n = oprsz / 4;
1753     float32 *s0 = vs;
1754     float32 *s1 = vs + sizeof(ARMVectorReg);
1755     bfloat16 *d = vd;
1756 
1757     if (vd == s1) {
1758         s1 = memcpy(&scratch, s1, oprsz);
1759     }
1760 
1761     for (i = 0; i < n; ++i) {
1762         d[H2(i)] = float32_to_bfloat16(s0[H4(i)], fpst);
1763     }
1764     for (i = 0; i < n; ++i) {
1765         d[H2(i) + n] = float32_to_bfloat16(s1[H4(i)], fpst);
1766     }
1767 }
1768 
1769 void HELPER(sme2_fcvt_n)(void *vd, void *vs, float_status *fpst, uint32_t desc)
1770 {
1771     ARMVectorReg scratch;
1772     size_t oprsz = simd_oprsz(desc);
1773     size_t i, n = oprsz / 4;
1774     float32 *s0 = vs;
1775     float32 *s1 = vs + sizeof(ARMVectorReg);
1776     float16 *d = vd;
1777 
1778     if (vd == s1) {
1779         s1 = memcpy(&scratch, s1, oprsz);
1780     }
1781 
1782     for (i = 0; i < n; ++i) {
1783         d[H2(i)] = sve_f32_to_f16(s0[H4(i)], fpst);
1784     }
1785     for (i = 0; i < n; ++i) {
1786         d[H2(i) + n] = sve_f32_to_f16(s1[H4(i)], fpst);
1787     }
1788 }
1789 
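/*
 * Saturating narrow (SQCVT/UQCVT/SQCVTU): if the destination overlaps
 * the source group, results are assembled in a scratch register and
 * copied out at the end.
 */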
1790 #define SQCVT2(NAME, TW, TN, HW, HN, SAT)                       \
1791 void HELPER(NAME)(void *vd, void *vs, uint32_t desc)            \
1792 {                                                               \
1793     ARMVectorReg scratch;                                       \
1794     size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW);    \
1795     TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg);               \
1796     TN *d = vd;                                                 \
1797     if (vectors_overlap(vd, 1, vs, 2)) {                        \
1798         d = (TN *)&scratch;                                     \
1799     }                                                           \
1800     for (size_t i = 0; i < n; ++i) {                            \
1801         d[HN(i)] = SAT(s0[HW(i)]);                              \
1802         d[HN(i + n)] = SAT(s1[HW(i)]);                          \
1803     }                                                           \
1804     if (d != vd) {                                              \
1805         memcpy(vd, d, oprsz);                                   \
1806     }                                                           \
1807 }
1808 
1809 SQCVT2(sme2_sqcvt_sh, int32_t, int16_t, H4, H2, do_ssat_h)
1810 SQCVT2(sme2_uqcvt_sh, uint32_t, uint16_t, H4, H2, do_usat_h)
1811 SQCVT2(sme2_sqcvtu_sh, int32_t, uint16_t, H4, H2, do_usat_h)
1812 
1813 #undef SQCVT2
1814 
1815 #define SQCVT4(NAME, TW, TN, HW, HN, SAT)                       \
1816 void HELPER(NAME)(void *vd, void *vs, uint32_t desc)            \
1817 {                                                               \
1818     ARMVectorReg scratch;                                       \
1819     size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW);    \
1820     TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg);               \
1821     TW *s2 = vs + 2 * sizeof(ARMVectorReg);                     \
1822     TW *s3 = vs + 3 * sizeof(ARMVectorReg);                     \
1823     TN *d = vd;                                                 \
1824     if (vectors_overlap(vd, 1, vs, 4)) {                        \
1825         d = (TN *)&scratch;                                     \
1826     }                                                           \
1827     for (size_t i = 0; i < n; ++i) {                            \
1828         d[HN(i)] = SAT(s0[HW(i)]);                              \
1829         d[HN(i + n)] = SAT(s1[HW(i)]);                          \
1830         d[HN(i + 2 * n)] = SAT(s2[HW(i)]);                      \
1831         d[HN(i + 3 * n)] = SAT(s3[HW(i)]);                      \
1832     }                                                           \
1833     if (d != vd) {                                              \
1834         memcpy(vd, d, oprsz);                                   \
1835     }                                                           \
1836 }
1837 
1838 SQCVT4(sme2_sqcvt_sb, int32_t, int8_t, H4, H1, do_ssat_b)
1839 SQCVT4(sme2_uqcvt_sb, uint32_t, uint8_t, H4, H1, do_usat_b)
1840 SQCVT4(sme2_sqcvtu_sb, int32_t, uint8_t, H4, H1, do_usat_b)
1841 
1842 SQCVT4(sme2_sqcvt_dh, int64_t, int16_t, H8, H2, do_ssat_h)
1843 SQCVT4(sme2_uqcvt_dh, uint64_t, uint16_t, H8, H2, do_usat_h)
1844 SQCVT4(sme2_sqcvtu_dh, int64_t, uint16_t, H8, H2, do_usat_h)
1845 
1846 #undef SQCVT4
1847 
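/*
 * As above, but with a rounding right shift, by the amount encoded in
 * simd_data(desc), applied before saturation (SQRSHR/UQRSHR/SQRSHRU).
 */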
1848 #define SQRSHR2(NAME, TW, TN, HW, HN, RSHR, SAT)                \
1849 void HELPER(NAME)(void *vd, void *vs, uint32_t desc)            \
1850 {                                                               \
1851     ARMVectorReg scratch;                                       \
1852     size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW);    \
1853     int shift = simd_data(desc);                                \
1854     TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg);               \
1855     TN *d = vd;                                                 \
1856     if (vectors_overlap(vd, 1, vs, 2)) {                        \
1857         d = (TN *)&scratch;                                     \
1858     }                                                           \
1859     for (size_t i = 0; i < n; ++i) {                            \
1860         d[HN(i)] = SAT(RSHR(s0[HW(i)], shift));                 \
1861         d[HN(i + n)] = SAT(RSHR(s1[HW(i)], shift));             \
1862     }                                                           \
1863     if (d != vd) {                                              \
1864         memcpy(vd, d, oprsz);                                   \
1865     }                                                           \
1866 }
1867 
1868 SQRSHR2(sme2_sqrshr_sh, int32_t, int16_t, H4, H2, do_srshr, do_ssat_h)
1869 SQRSHR2(sme2_uqrshr_sh, uint32_t, uint16_t, H4, H2, do_urshr, do_usat_h)
1870 SQRSHR2(sme2_sqrshru_sh, int32_t, uint16_t, H4, H2, do_srshr, do_usat_h)
1871 
1872 #undef SQRSHR2
1873 
1874 #define SQRSHR4(NAME, TW, TN, HW, HN, RSHR, SAT)                \
1875 void HELPER(NAME)(void *vd, void *vs, uint32_t desc)            \
1876 {                                                               \
1877     ARMVectorReg scratch;                                       \
1878     size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW);    \
1879     int shift = simd_data(desc);                                \
1880     TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg);               \
1881     TW *s2 = vs + 2 * sizeof(ARMVectorReg);                     \
1882     TW *s3 = vs + 3 * sizeof(ARMVectorReg);                     \
1883     TN *d = vd;                                                 \
1884     if (vectors_overlap(vd, 1, vs, 4)) {                        \
1885         d = (TN *)&scratch;                                     \
1886     }                                                           \
1887     for (size_t i = 0; i < n; ++i) {                            \
1888         d[HN(i)] = SAT(RSHR(s0[HW(i)], shift));                 \
1889         d[HN(i + n)] = SAT(RSHR(s1[HW(i)], shift));             \
1890         d[HN(i + 2 * n)] = SAT(RSHR(s2[HW(i)], shift));         \
1891         d[HN(i + 3 * n)] = SAT(RSHR(s3[HW(i)], shift));         \
1892     }                                                           \
1893     if (d != vd) {                                              \
1894         memcpy(vd, d, oprsz);                                   \
1895     }                                                           \
1896 }
1897 
1898 SQRSHR4(sme2_sqrshr_sb, int32_t, int8_t, H4, H2, do_srshr, do_ssat_b)
1899 SQRSHR4(sme2_uqrshr_sb, uint32_t, uint8_t, H4, H2, do_urshr, do_usat_b)
1900 SQRSHR4(sme2_sqrshru_sb, int32_t, uint8_t, H4, H2, do_srshr, do_usat_b)
1901 
1902 SQRSHR4(sme2_sqrshr_dh, int64_t, int16_t, H8, H2, do_srshr, do_ssat_h)
1903 SQRSHR4(sme2_uqrshr_dh, uint64_t, uint16_t, H8, H2, do_urshr, do_usat_h)
1904 SQRSHR4(sme2_sqrshru_dh, int64_t, uint16_t, H8, H2, do_srshr, do_usat_h)
1905 
1906 #undef SQRSHR4
1907 
1908 /* Convert and interleave */
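/*
 * Same conversions as above, but interleaved: element i of source
 * vector k is written to destination element 2*i + k (4*i + k for the
 * four-vector forms).
 */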
1909 void HELPER(sme2_bfcvtn)(void *vd, void *vs, float_status *fpst, uint32_t desc)
1910 {
1911     size_t i, n = simd_oprsz(desc) / 4;
1912     float32 *s0 = vs;
1913     float32 *s1 = vs + sizeof(ARMVectorReg);
1914     bfloat16 *d = vd;
1915 
1916     for (i = 0; i < n; ++i) {
1917         bfloat16 d0 = float32_to_bfloat16(s0[H4(i)], fpst);
1918         bfloat16 d1 = float32_to_bfloat16(s1[H4(i)], fpst);
1919         d[H2(i * 2 + 0)] = d0;
1920         d[H2(i * 2 + 1)] = d1;
1921     }
1922 }
1923 
1924 void HELPER(sme2_fcvtn)(void *vd, void *vs, float_status *fpst, uint32_t desc)
1925 {
1926     size_t i, n = simd_oprsz(desc) / 4;
1927     float32 *s0 = vs;
1928     float32 *s1 = vs + sizeof(ARMVectorReg);
1929     float16 *d = vd;
1930 
1931     for (i = 0; i < n; ++i) {
1932         float16 d0 = sve_f32_to_f16(s0[H4(i)], fpst);
1933         float16 d1 = sve_f32_to_f16(s1[H4(i)], fpst);
1934         d[H2(i * 2 + 0)] = d0;
1935         d[H2(i * 2 + 1)] = d1;
1936     }
1937 }
1938 
1939 #define SQCVTN2(NAME, TW, TN, HW, HN, SAT)                      \
1940 void HELPER(NAME)(void *vd, void *vs, uint32_t desc)            \
1941 {                                                               \
1942     ARMVectorReg scratch;                                       \
1943     size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW);    \
1944     TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg);               \
1945     TN *d = vd;                                                 \
1946     if (vectors_overlap(vd, 1, vs, 2)) {                        \
1947         d = (TN *)&scratch;                                     \
1948     }                                                           \
1949     for (size_t i = 0; i < n; ++i) {                            \
1950         d[HN(2 * i + 0)] = SAT(s0[HW(i)]);                      \
1951         d[HN(2 * i + 1)] = SAT(s1[HW(i)]);                      \
1952     }                                                           \
1953     if (d != vd) {                                              \
1954         memcpy(vd, d, oprsz);                                   \
1955     }                                                           \
1956 }
1957 
1958 SQCVTN2(sme2_sqcvtn_sh, int32_t, int16_t, H4, H2, do_ssat_h)
1959 SQCVTN2(sme2_uqcvtn_sh, uint32_t, uint16_t, H4, H2, do_usat_h)
1960 SQCVTN2(sme2_sqcvtun_sh, int32_t, uint16_t, H4, H2, do_usat_h)
1961 
1962 #undef SQCVTN2
1963 
1964 #define SQCVTN4(NAME, TW, TN, HW, HN, SAT)                      \
1965 void HELPER(NAME)(void *vd, void *vs, uint32_t desc)            \
1966 {                                                               \
1967     ARMVectorReg scratch;                                       \
1968     size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW);    \
1969     TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg);               \
1970     TW *s2 = vs + 2 * sizeof(ARMVectorReg);                     \
1971     TW *s3 = vs + 3 * sizeof(ARMVectorReg);                     \
1972     TN *d = vd;                                                 \
1973     if (vectors_overlap(vd, 1, vs, 4)) {                        \
1974         d = (TN *)&scratch;                                     \
1975     }                                                           \
1976     for (size_t i = 0; i < n; ++i) {                            \
1977         d[HN(4 * i + 0)] = SAT(s0[HW(i)]);                      \
1978         d[HN(4 * i + 1)] = SAT(s1[HW(i)]);                      \
1979         d[HN(4 * i + 2)] = SAT(s2[HW(i)]);                      \
1980         d[HN(4 * i + 3)] = SAT(s3[HW(i)]);                      \
1981     }                                                           \
1982     if (d != vd) {                                              \
1983         memcpy(vd, d, oprsz);                                   \
1984     }                                                           \
1985 }
1986 
1987 SQCVTN4(sme2_sqcvtn_sb, int32_t, int8_t, H4, H1, do_ssat_b)
1988 SQCVTN4(sme2_uqcvtn_sb, uint32_t, uint8_t, H4, H1, do_usat_b)
1989 SQCVTN4(sme2_sqcvtun_sb, int32_t, uint8_t, H4, H1, do_usat_b)
1990 
1991 SQCVTN4(sme2_sqcvtn_dh, int64_t, int16_t, H8, H2, do_ssat_h)
1992 SQCVTN4(sme2_uqcvtn_dh, uint64_t, uint16_t, H8, H2, do_usat_h)
1993 SQCVTN4(sme2_sqcvtun_dh, int64_t, uint16_t, H8, H2, do_usat_h)
1994 
1995 #undef SQCVTN4
1996 
1997 #define SQRSHRN2(NAME, TW, TN, HW, HN, RSHR, SAT)               \
1998 void HELPER(NAME)(void *vd, void *vs, uint32_t desc)            \
1999 {                                                               \
2000     ARMVectorReg scratch;                                       \
2001     size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW);    \
2002     int shift = simd_data(desc);                                \
2003     TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg);               \
2004     TN *d = vd;                                                 \
2005     if (vectors_overlap(vd, 1, vs, 2)) {                        \
2006         d = (TN *)&scratch;                                     \
2007     }                                                           \
2008     for (size_t i = 0; i < n; ++i) {                            \
2009         d[HN(2 * i + 0)] = SAT(RSHR(s0[HW(i)], shift));         \
2010         d[HN(2 * i + 1)] = SAT(RSHR(s1[HW(i)], shift));         \
2011     }                                                           \
2012     if (d != vd) {                                              \
2013         memcpy(vd, d, oprsz);                                   \
2014     }                                                           \
2015 }
2016 
2017 SQRSHRN2(sme2_sqrshrn_sh, int32_t, int16_t, H4, H2, do_srshr, do_ssat_h)
2018 SQRSHRN2(sme2_uqrshrn_sh, uint32_t, uint16_t, H4, H2, do_urshr, do_usat_h)
2019 SQRSHRN2(sme2_sqrshrun_sh, int32_t, uint16_t, H4, H2, do_srshr, do_usat_h)
2020 
2021 #undef SQRSHRN2
2022 
2023 #define SQRSHRN4(NAME, TW, TN, HW, HN, RSHR, SAT)               \
2024 void HELPER(NAME)(void *vd, void *vs, uint32_t desc)            \
2025 {                                                               \
2026     ARMVectorReg scratch;                                       \
2027     size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW);    \
2028     int shift = simd_data(desc);                                \
2029     TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg);               \
2030     TW *s2 = vs + 2 * sizeof(ARMVectorReg);                     \
2031     TW *s3 = vs + 3 * sizeof(ARMVectorReg);                     \
2032     TN *d = vd;                                                 \
2033     if (vectors_overlap(vd, 1, vs, 4)) {                        \
2034         d = (TN *)&scratch;                                     \
2035     }                                                           \
2036     for (size_t i = 0; i < n; ++i) {                            \
2037         d[HN(4 * i + 0)] = SAT(RSHR(s0[HW(i)], shift));         \
2038         d[HN(4 * i + 1)] = SAT(RSHR(s1[HW(i)], shift));         \
2039         d[HN(4 * i + 2)] = SAT(RSHR(s2[HW(i)], shift));         \
2040         d[HN(4 * i + 3)] = SAT(RSHR(s3[HW(i)], shift));         \
2041     }                                                           \
2042     if (d != vd) {                                              \
2043         memcpy(vd, d, oprsz);                                   \
2044     }                                                           \
2045 }
2046 
2047 SQRSHRN4(sme2_sqrshrn_sb, int32_t, int8_t, H4, H1, do_srshr, do_ssat_b)
2048 SQRSHRN4(sme2_uqrshrn_sb, uint32_t, uint8_t, H4, H1, do_urshr, do_usat_b)
2049 SQRSHRN4(sme2_sqrshrun_sb, int32_t, uint8_t, H4, H1, do_srshr, do_usat_b)
2050 
2051 SQRSHRN4(sme2_sqrshrn_dh, int64_t, int16_t, H8, H2, do_srshr, do_ssat_h)
2052 SQRSHRN4(sme2_uqrshrn_dh, uint64_t, uint16_t, H8, H2, do_urshr, do_usat_h)
2053 SQRSHRN4(sme2_sqrshrun_dh, int64_t, uint16_t, H8, H2, do_srshr, do_usat_h)
2054 
2055 #undef SQRSHRN4
2056 
2057 /* Expand and convert */
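/*
 * One source vector widens into two destination vectors: the low half
 * of the source feeds the first destination and the high half the
 * second, with a scratch copy of the source if it overlaps the
 * destination pair.
 */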
2058 void HELPER(sme2_fcvt_w)(void *vd, void *vs, float_status *fpst, uint32_t desc)
2059 {
2060     ARMVectorReg scratch;
2061     size_t oprsz = simd_oprsz(desc);
2062     size_t i, n = oprsz / 4;
2063     float16 *s = vs;
2064     float32 *d0 = vd;
2065     float32 *d1 = vd + sizeof(ARMVectorReg);
2066 
2067     if (vectors_overlap(vd, 2, vs, 1)) {
2068         s = memcpy(&scratch, s, oprsz);
2069     }
2070 
2071     for (i = 0; i < n; ++i) {
2072         d0[H4(i)] = sve_f16_to_f32(s[H2(i)], fpst);
2073     }
2074     for (i = 0; i < n; ++i) {
2075         d1[H4(i)] = sve_f16_to_f32(s[H2(n + i)], fpst);
2076     }
2077 }
2078 
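/*
 * UNPK widens SREG source vectors into 2 * SREG destination vectors by
 * sign- or zero-extension, low half of each source register first.
 */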
2079 #define UNPK(NAME, SREG, TW, TN, HW, HN)                        \
2080 void HELPER(NAME)(void *vd, void *vs, uint32_t desc)            \
2081 {                                                               \
2082     ARMVectorReg scratch[SREG];                                 \
2083     size_t oprsz = simd_oprsz(desc);                            \
2084     size_t n = oprsz / sizeof(TW);                              \
2085     if (vectors_overlap(vd, 2 * SREG, vs, SREG)) {              \
2086         vs = memcpy(scratch, vs, sizeof(scratch));              \
2087     }                                                           \
2088     for (size_t r = 0; r < SREG; ++r) {                         \
2089         TN *s = vs + r * sizeof(ARMVectorReg);                  \
2090         for (size_t i = 0; i < 2; ++i) {                        \
2091             TW *d = vd + (2 * r + i) * sizeof(ARMVectorReg);    \
2092             for (size_t e = 0; e < n; ++e) {                    \
2093                 d[HW(e)] = s[HN(i * n + e)];                    \
2094             }                                                   \
2095         }                                                       \
2096     }                                                           \
2097 }
2098 
2099 UNPK(sme2_sunpk2_bh, 1, int16_t, int8_t, H2, H1)
2100 UNPK(sme2_sunpk2_hs, 1, int32_t, int16_t, H4, H2)
2101 UNPK(sme2_sunpk2_sd, 1, int64_t, int32_t, H8, H4)
2102 
2103 UNPK(sme2_sunpk4_bh, 2, int16_t, int8_t, H2, H1)
2104 UNPK(sme2_sunpk4_hs, 2, int32_t, int16_t, H4, H2)
2105 UNPK(sme2_sunpk4_sd, 2, int64_t, int32_t, H8, H4)
2106 
2107 UNPK(sme2_uunpk2_bh, 1, uint16_t, uint8_t, H2, H1)
2108 UNPK(sme2_uunpk2_hs, 1, uint32_t, uint16_t, H4, H2)
2109 UNPK(sme2_uunpk2_sd, 1, uint64_t, uint32_t, H8, H4)
2110 
2111 UNPK(sme2_uunpk4_bh, 2, uint16_t, uint8_t, H2, H1)
2112 UNPK(sme2_uunpk4_hs, 2, uint32_t, uint16_t, H4, H2)
2113 UNPK(sme2_uunpk4_sd, 2, uint64_t, uint32_t, H8, H4)
2114 
2115 #undef UNPK
2116 
2117 /* Deinterleave and convert. */
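/*
 * Even-numbered fp16 elements of the source widen into the first
 * destination vector and odd-numbered elements into the second.
 */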
2118 void HELPER(sme2_fcvtl)(void *vd, void *vs, float_status *fpst, uint32_t desc)
2119 {
2120     size_t i, n = simd_oprsz(desc) / 4;
2121     float16 *s = vs;
2122     float32 *d0 = vd;
2123     float32 *d1 = vd + sizeof(ARMVectorReg);
2124 
2125     for (i = 0; i < n; ++i) {
2126         float32 v0 = sve_f16_to_f32(s[H2(i * 2 + 0)], fpst);
2127         float32 v1 = sve_f16_to_f32(s[H2(i * 2 + 1)], fpst);
2128         d0[H4(i)] = v0;
2129         d1[H4(i)] = v1;
2130     }
2131 }
2132 
2133 void HELPER(sme2_scvtf)(void *vd, void *vs, float_status *fpst, uint32_t desc)
2134 {
2135     size_t i, n = simd_oprsz(desc) / 4;
2136     float32 *d = vd;
2137     int32_t *s = vs;
2138 
2139     for (i = 0; i < n; ++i) {
2140         d[i] = int32_to_float32(s[i], fpst);
2141     }
2142 }
2143 
2144 void HELPER(sme2_ucvtf)(void *vd, void *vs, float_status *fpst, uint32_t desc)
2145 {
2146     size_t i, n = simd_oprsz(desc) / 4;
2147     float32 *d = vd;
2148     uint32_t *s = vs;
2149 
2150     for (i = 0; i < n; ++i) {
2151         d[i] = uint32_to_float32(s[i], fpst);
2152     }
2153 }
2154 
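/*
 * Multi-vector ZIP/UZP.  ZIP2/ZIP4 interleave two or four source
 * vectors into a two- or four-register destination group; UZP2/UZP4
 * perform the inverse de-interleave.  The _q forms work on 128-bit
 * granules.
 */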
2155 #define ZIP2(NAME, TYPE, H)                                     \
2156 void HELPER(NAME)(void *vd, void *vn, void *vm, uint32_t desc)  \
2157 {                                                               \
2158     ARMVectorReg scratch[2];                                    \
2159     size_t oprsz = simd_oprsz(desc);                            \
2160     size_t pairs = oprsz / (sizeof(TYPE) * 2);                  \
2161     TYPE *n = vn, *m = vm;                                      \
2162     if (vectors_overlap(vd, 2, vn, 1)) {                        \
2163         n = memcpy(&scratch[0], vn, oprsz);                     \
2164     }                                                           \
2165     if (vectors_overlap(vd, 2, vm, 1)) {                        \
2166         m = memcpy(&scratch[1], vm, oprsz);                     \
2167     }                                                           \
2168     for (size_t r = 0; r < 2; ++r) {                            \
2169         TYPE *d = vd + r * sizeof(ARMVectorReg);                \
2170         size_t base = r * pairs;                                \
2171         for (size_t p = 0; p < pairs; ++p) {                    \
2172             d[H(2 * p + 0)] = n[base + H(p)];                   \
2173             d[H(2 * p + 1)] = m[base + H(p)];                   \
2174         }                                                       \
2175     }                                                           \
2176 }
2177 
2178 ZIP2(sme2_zip2_b, uint8_t, H1)
2179 ZIP2(sme2_zip2_h, uint16_t, H2)
2180 ZIP2(sme2_zip2_s, uint32_t, H4)
2181 ZIP2(sme2_zip2_d, uint64_t, )
2182 ZIP2(sme2_zip2_q, Int128, )
2183 
2184 #undef ZIP2
2185 
2186 #define ZIP4(NAME, TYPE, H)                                     \
2187 void HELPER(NAME)(void *vd, void *vs, uint32_t desc)            \
2188 {                                                               \
2189     ARMVectorReg scratch[4];                                    \
2190     size_t oprsz = simd_oprsz(desc);                            \
2191     size_t quads = oprsz / (sizeof(TYPE) * 4);                  \
2192     TYPE *s0, *s1, *s2, *s3;                                    \
2193     if (vs == vd) {                                             \
2194         vs = memcpy(scratch, vs, sizeof(scratch));              \
2195     }                                                           \
2196     s0 = vs;                                                    \
2197     s1 = vs + sizeof(ARMVectorReg);                             \
2198     s2 = vs + 2 * sizeof(ARMVectorReg);                         \
2199     s3 = vs + 3 * sizeof(ARMVectorReg);                         \
2200     for (size_t r = 0; r < 4; ++r) {                            \
2201         TYPE *d = vd + r * sizeof(ARMVectorReg);                \
2202         size_t base = r * quads;                                \
2203         for (size_t q = 0; q < quads; ++q) {                    \
2204             d[H(4 * q + 0)] = s0[base + H(q)];                  \
2205             d[H(4 * q + 1)] = s1[base + H(q)];                  \
2206             d[H(4 * q + 2)] = s2[base + H(q)];                  \
2207             d[H(4 * q + 3)] = s3[base + H(q)];                  \
2208         }                                                       \
2209     }                                                           \
2210 }
2211 
2212 ZIP4(sme2_zip4_b, uint8_t, H1)
2213 ZIP4(sme2_zip4_h, uint16_t, H2)
2214 ZIP4(sme2_zip4_s, uint32_t, H4)
2215 ZIP4(sme2_zip4_d, uint64_t, )
2216 ZIP4(sme2_zip4_q, Int128, )
2217 
2218 #undef ZIP4
2219 
2220 #define UZP2(NAME, TYPE, H)                                     \
2221 void HELPER(NAME)(void *vd, void *vn, void *vm, uint32_t desc)  \
2222 {                                                               \
2223     ARMVectorReg scratch[2];                                    \
2224     size_t oprsz = simd_oprsz(desc);                            \
2225     size_t pairs = oprsz / (sizeof(TYPE) * 2);                  \
2226     TYPE *d0 = vd, *d1 = vd + sizeof(ARMVectorReg);             \
2227     if (vectors_overlap(vd, 2, vn, 1)) {                        \
2228         vn = memcpy(&scratch[0], vn, oprsz);                    \
2229     }                                                           \
2230     if (vectors_overlap(vd, 2, vm, 1)) {                        \
2231         vm = memcpy(&scratch[1], vm, oprsz);                    \
2232     }                                                           \
2233     for (size_t r = 0; r < 2; ++r) {                            \
2234         TYPE *s = r ? vm : vn;                                  \
2235         size_t base = r * pairs;                                \
2236         for (size_t p = 0; p < pairs; ++p) {                    \
2237             d0[base + H(p)] = s[H(2 * p + 0)];                  \
2238             d1[base + H(p)] = s[H(2 * p + 1)];                  \
2239         }                                                       \
2240     }                                                           \
2241 }
2242 
2243 UZP2(sme2_uzp2_b, uint8_t, H1)
2244 UZP2(sme2_uzp2_h, uint16_t, H2)
2245 UZP2(sme2_uzp2_s, uint32_t, H4)
2246 UZP2(sme2_uzp2_d, uint64_t, )
2247 UZP2(sme2_uzp2_q, Int128, )
2248 
2249 #undef UZP2
2250 
2251 #define UZP4(NAME, TYPE, H)                                     \
2252 void HELPER(NAME)(void *vd, void *vs, uint32_t desc)            \
2253 {                                                               \
2254     ARMVectorReg scratch[4];                                    \
2255     size_t oprsz = simd_oprsz(desc);                            \
2256     size_t quads = oprsz / (sizeof(TYPE) * 4);                  \
2257     TYPE *d0, *d1, *d2, *d3;                                    \
2258     if (vs == vd) {                                             \
2259         vs = memcpy(scratch, vs, sizeof(scratch));              \
2260     }                                                           \
2261     d0 = vd;                                                    \
2262     d1 = vd + sizeof(ARMVectorReg);                             \
2263     d2 = vd + 2 * sizeof(ARMVectorReg);                         \
2264     d3 = vd + 3 * sizeof(ARMVectorReg);                         \
2265     for (size_t r = 0; r < 4; ++r) {                            \
2266         TYPE *s = vs + r * sizeof(ARMVectorReg);                \
2267         size_t base = r * quads;                                \
2268         for (size_t q = 0; q < quads; ++q) {                    \
2269             d0[base + H(q)] = s[H(4 * q + 0)];                  \
2270             d1[base + H(q)] = s[H(4 * q + 1)];                  \
2271             d2[base + H(q)] = s[H(4 * q + 2)];                  \
2272             d3[base + H(q)] = s[H(4 * q + 3)];                  \
2273         }                                                       \
2274     }                                                           \
2275 }
2276 
2277 UZP4(sme2_uzp4_b, uint8_t, H1)
2278 UZP4(sme2_uzp4_h, uint16_t, H2)
2279 UZP4(sme2_uzp4_s, uint32_t, H4)
2280 UZP4(sme2_uzp4_d, uint64_t, )
2281 UZP4(sme2_uzp4_q, Int128, )
2282 
2283 #undef UZP4
2284 
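/*
 * Clamp every element of an nreg-register group to [Zn, Zm]: the lower
 * bound comes from Zn and the upper bound from Zm, element-wise.
 */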
2285 #define ICLAMP(NAME, TYPE, H) \
2286 void HELPER(NAME)(void *vd, void *vn, void *vm, uint32_t desc)  \
2287 {                                                               \
2288     size_t stride = sizeof(ARMVectorReg) / sizeof(TYPE);        \
2289     size_t elements = simd_oprsz(desc) / sizeof(TYPE);          \
2290     size_t nreg = simd_data(desc);                              \
2291     TYPE *d = vd, *n = vn, *m = vm;                             \
2292     for (size_t e = 0; e < elements; e++) {                     \
2293         TYPE nn = n[H(e)], mm = m[H(e)];                        \
2294         for (size_t r = 0; r < nreg; r++) {                     \
2295             TYPE *dd = &d[r * stride + H(e)];                   \
2296             *dd = MIN(MAX(*dd, nn), mm);                        \
2297         }                                                       \
2298     }                                                           \
2299 }
2300 
2301 ICLAMP(sme2_sclamp_b, int8_t, H1)
2302 ICLAMP(sme2_sclamp_h, int16_t, H2)
2303 ICLAMP(sme2_sclamp_s, int32_t, H4)
2304 ICLAMP(sme2_sclamp_d, int64_t, H8)
2305 
2306 ICLAMP(sme2_uclamp_b, uint8_t, H1)
2307 ICLAMP(sme2_uclamp_h, uint16_t, H2)
2308 ICLAMP(sme2_uclamp_s, uint32_t, H4)
2309 ICLAMP(sme2_uclamp_d, uint64_t, H8)
2310 
2311 #undef ICLAMP
2312 
2313 /*
2314  * Note the argument ordering to minnum and maxnum must match
2315  * the ARM pseudocode so that NaNs are propagated properly.
2316  */
2317 #define FCLAMP(NAME, TYPE, H) \
2318 void HELPER(NAME)(void *vd, void *vn, void *vm,                 \
2319                   float_status *fpst, uint32_t desc)            \
2320 {                                                               \
2321     size_t stride = sizeof(ARMVectorReg) / sizeof(TYPE);        \
2322     size_t elements = simd_oprsz(desc) / sizeof(TYPE);          \
2323     size_t nreg = simd_data(desc);                              \
2324     TYPE *d = vd, *n = vn, *m = vm;                             \
2325     for (size_t e = 0; e < elements; e++) {                     \
2326         TYPE nn = n[H(e)], mm = m[H(e)];                        \
2327         for (size_t r = 0; r < nreg; r++) {                     \
2328             TYPE *dd = &d[r * stride + H(e)];                   \
2329             *dd = TYPE##_minnum(TYPE##_maxnum(nn, *dd, fpst), mm, fpst); \
2330         }                                                       \
2331     }                                                           \
2332 }
2333 
2334 FCLAMP(sme2_fclamp_h, float16, H2)
2335 FCLAMP(sme2_fclamp_s, float32, H4)
2336 FCLAMP(sme2_fclamp_d, float64, H8)
2337 FCLAMP(sme2_bfclamp, bfloat16, H2)
2338 
2339 #undef FCLAMP
2340 
2341 void HELPER(sme2_sel_b)(void *vd, void *vn, void *vm,
2342                         uint32_t png, uint32_t desc)
2343 {
2344     int vl = simd_oprsz(desc);
2345     int nreg = simd_data(desc);
2346     int elements = vl / sizeof(uint8_t);
2347     DecodeCounter p = decode_counter(png, vl, MO_8);
2348 
2349     if (p.lg2_stride == 0) {
2350         if (p.invert) {
2351             for (int r = 0; r < nreg; r++) {
2352                 uint8_t *d = vd + r * sizeof(ARMVectorReg);
2353                 uint8_t *n = vn + r * sizeof(ARMVectorReg);
2354                 uint8_t *m = vm + r * sizeof(ARMVectorReg);
2355                 int split = p.count - r * elements;
2356 
2357                 if (split <= 0) {
2358                     memcpy(d, n, vl);  /* all true */
2359                 } else if (elements <= split) {
2360                     memcpy(d, m, vl);  /* all false */
2361                 } else {
2362                     for (int e = 0; e < split; e++) {
2363                         d[H1(e)] = m[H1(e)];
2364                     }
2365                     for (int e = split; e < elements; e++) {
2366                         d[H1(e)] = n[H1(e)];
2367                     }
2368                 }
2369             }
2370         } else {
2371             for (int r = 0; r < nreg; r++) {
2372                 uint8_t *d = vd + r * sizeof(ARMVectorReg);
2373                 uint8_t *n = vn + r * sizeof(ARMVectorReg);
2374                 uint8_t *m = vm + r * sizeof(ARMVectorReg);
2375                 int split = p.count - r * elements;
2376 
2377                 if (split <= 0) {
2378                     memcpy(d, m, vl);  /* all false */
2379                 } else if (elements <= split) {
2380                     memcpy(d, n, vl);  /* all true */
2381                 } else {
2382                     for (int e = 0; e < split; e++) {
2383                         d[H1(e)] = n[H1(e)];
2384                     }
2385                     for (int e = split; e < elements; e++) {
2386                         d[H1(e)] = m[H1(e)];
2387                     }
2388                 }
2389             }
2390         }
2391     } else {
2392         int estride = 1 << p.lg2_stride;
2393         if (p.invert) {
2394             for (int r = 0; r < nreg; r++) {
2395                 uint8_t *d = vd + r * sizeof(ARMVectorReg);
2396                 uint8_t *n = vn + r * sizeof(ARMVectorReg);
2397                 uint8_t *m = vm + r * sizeof(ARMVectorReg);
2398                 int split = p.count - r * elements;
2399                 int e = 0;
2400 
2401                 for (; e < MIN(split, elements); e++) {
2402                     d[H1(e)] = m[H1(e)];
2403                 }
2404                 for (; e < elements; e += estride) {
2405                     d[H1(e)] = n[H1(e)];
2406                     for (int i = 1; i < estride; i++) {
2407                         d[H1(e + i)] = m[H1(e + i)];
2408                     }
2409                 }
2410             }
2411         } else {
2412             for (int r = 0; r < nreg; r++) {
2413                 uint8_t *d = vd + r * sizeof(ARMVectorReg);
2414                 uint8_t *n = vn + r * sizeof(ARMVectorReg);
2415                 uint8_t *m = vm + r * sizeof(ARMVectorReg);
2416                 int split = p.count - r * elements;
2417                 int e = 0;
2418 
2419                 for (; e < MIN(split, elements); e += estride) {
2420                     d[H1(e)] = n[H1(e)];
2421                     for (int i = 1; i < estride; i++) {
2422                         d[H1(e + i)] = m[H1(e + i)];
2423                     }
2424                 }
2425                 for (; e < elements; e++) {
2426                     d[H1(e)] = m[H1(e)];
2427                 }
2428             }
2429         }
2430     }
2431 }
2432 
2433 void HELPER(sme2_sel_h)(void *vd, void *vn, void *vm,
2434                         uint32_t png, uint32_t desc)
2435 {
2436     int vl = simd_oprsz(desc);
2437     int nreg = simd_data(desc);
2438     int elements = vl / sizeof(uint16_t);
2439     DecodeCounter p = decode_counter(png, vl, MO_16);
2440 
2441     if (p.lg2_stride == 0) {
2442         if (p.invert) {
2443             for (int r = 0; r < nreg; r++) {
2444                 uint16_t *d = vd + r * sizeof(ARMVectorReg);
2445                 uint16_t *n = vn + r * sizeof(ARMVectorReg);
2446                 uint16_t *m = vm + r * sizeof(ARMVectorReg);
2447                 int split = p.count - r * elements;
2448 
2449                 if (split <= 0) {
2450                     memcpy(d, n, vl);  /* all true */
2451                 } else if (elements <= split) {
2452                     memcpy(d, m, vl);  /* all false */
2453                 } else {
2454                     for (int e = 0; e < split; e++) {
2455                         d[H2(e)] = m[H2(e)];
2456                     }
2457                     for (int e = split; e < elements; e++) {
2458                         d[H2(e)] = n[H2(e)];
2459                     }
2460                 }
2461             }
2462         } else {
2463             for (int r = 0; r < nreg; r++) {
2464                 uint16_t *d = vd + r * sizeof(ARMVectorReg);
2465                 uint16_t *n = vn + r * sizeof(ARMVectorReg);
2466                 uint16_t *m = vm + r * sizeof(ARMVectorReg);
2467                 int split = p.count - r * elements;
2468 
2469                 if (split <= 0) {
2470                     memcpy(d, m, vl);  /* all false */
2471                 } else if (elements <= split) {
2472                     memcpy(d, n, vl);  /* all true */
2473                 } else {
2474                     for (int e = 0; e < split; e++) {
2475                         d[H2(e)] = n[H2(e)];
2476                     }
2477                     for (int e = split; e < elements; e++) {
2478                         d[H2(e)] = m[H2(e)];
2479                     }
2480                 }
2481             }
2482         }
2483     } else {
2484         int estride = 1 << p.lg2_stride;
2485         if (p.invert) {
2486             for (int r = 0; r < nreg; r++) {
2487                 uint16_t *d = vd + r * sizeof(ARMVectorReg);
2488                 uint16_t *n = vn + r * sizeof(ARMVectorReg);
2489                 uint16_t *m = vm + r * sizeof(ARMVectorReg);
2490                 int split = p.count - r * elements;
2491                 int e = 0;
2492 
2493                 for (; e < MIN(split, elements); e++) {
2494                     d[H2(e)] = m[H2(e)];
2495                 }
2496                 for (; e < elements; e += estride) {
2497                     d[H2(e)] = n[H2(e)];
2498                     for (int i = 1; i < estride; i++) {
2499                         d[H2(e + i)] = m[H2(e + i)];
2500                     }
2501                 }
2502             }
2503         } else {
2504             for (int r = 0; r < nreg; r++) {
2505                 uint16_t *d = vd + r * sizeof(ARMVectorReg);
2506                 uint16_t *n = vn + r * sizeof(ARMVectorReg);
2507                 uint16_t *m = vm + r * sizeof(ARMVectorReg);
2508                 int split = p.count - r * elements;
2509                 int e = 0;
2510 
2511                 for (; e < MIN(split, elements); e += estride) {
2512                     d[H2(e)] = n[H2(e)];
2513                     for (int i = 1; i < estride; i++) {
2514                         d[H2(e + i)] = m[H2(e + i)];
2515                     }
2516                 }
2517                 for (; e < elements; e++) {
2518                     d[H2(e)] = m[H2(e)];
2519                 }
2520             }
2521         }
2522     }
2523 }
2524 
2525 void HELPER(sme2_sel_s)(void *vd, void *vn, void *vm,
2526                         uint32_t png, uint32_t desc)
2527 {
2528     int vl = simd_oprsz(desc);
2529     int nreg = simd_data(desc);
2530     int elements = vl / sizeof(uint32_t);
2531     DecodeCounter p = decode_counter(png, vl, MO_32);
2532 
2533     if (p.lg2_stride == 0) {
2534         if (p.invert) {
2535             for (int r = 0; r < nreg; r++) {
2536                 uint32_t *d = vd + r * sizeof(ARMVectorReg);
2537                 uint32_t *n = vn + r * sizeof(ARMVectorReg);
2538                 uint32_t *m = vm + r * sizeof(ARMVectorReg);
2539                 int split = p.count - r * elements;
2540 
2541                 if (split <= 0) {
2542                     memcpy(d, n, vl);  /* all true */
2543                 } else if (elements <= split) {
2544                     memcpy(d, m, vl);  /* all false */
2545                 } else {
2546                     for (int e = 0; e < split; e++) {
2547                         d[H4(e)] = m[H4(e)];
2548                     }
2549                     for (int e = split; e < elements; e++) {
2550                         d[H4(e)] = n[H4(e)];
2551                     }
2552                 }
2553             }
2554         } else {
2555             for (int r = 0; r < nreg; r++) {
2556                 uint32_t *d = vd + r * sizeof(ARMVectorReg);
2557                 uint32_t *n = vn + r * sizeof(ARMVectorReg);
2558                 uint32_t *m = vm + r * sizeof(ARMVectorReg);
2559                 int split = p.count - r * elements;
2560 
2561                 if (split <= 0) {
2562                     memcpy(d, m, vl);  /* all false */
2563                 } else if (elements <= split) {
2564                     memcpy(d, n, vl);  /* all true */
2565                 } else {
2566                     for (int e = 0; e < split; e++) {
2567                         d[H4(e)] = n[H4(e)];
2568                     }
2569                     for (int e = split; e < elements; e++) {
2570                         d[H4(e)] = m[H4(e)];
2571                     }
2572                 }
2573             }
2574         }
2575     } else {
2576         /* p.esz must be MO_64, so stride must be 2. */
2577         if (p.invert) {
2578             for (int r = 0; r < nreg; r++) {
2579                 uint32_t *d = vd + r * sizeof(ARMVectorReg);
2580                 uint32_t *n = vn + r * sizeof(ARMVectorReg);
2581                 uint32_t *m = vm + r * sizeof(ARMVectorReg);
2582                 int split = p.count - r * elements;
2583                 int e = 0;
2584 
2585                 for (; e < MIN(split, elements); e++) {
2586                     d[H4(e)] = m[H4(e)];
2587                 }
2588                 for (; e < elements; e += 2) {
2589                     d[H4(e)] = n[H4(e)];
2590                     d[H4(e + 1)] = m[H4(e + 1)];
2591                 }
2592             }
2593         } else {
2594             for (int r = 0; r < nreg; r++) {
2595                 uint32_t *d = vd + r * sizeof(ARMVectorReg);
2596                 uint32_t *n = vn + r * sizeof(ARMVectorReg);
2597                 uint32_t *m = vm + r * sizeof(ARMVectorReg);
2598                 int split = p.count - r * elements;
2599                 int e = 0;
2600 
2601                 for (; e < MIN(split, elements); e += 2) {
2602                     d[H4(e)] = n[H4(e)];
2603                     d[H4(e + 1)] = m[H4(e + 1)];
2604                 }
2605                 for (; e < elements; e++) {
2606                     d[H4(e)] = m[H4(e)];
2607                 }
2608             }
2609         }
2610     }
2611 }
2612 
2613 void HELPER(sme2_sel_d)(void *vd, void *vn, void *vm,
2614                         uint32_t png, uint32_t desc)
2615 {
2616     int vl = simd_oprsz(desc);
2617     int nreg = simd_data(desc);
2618     int elements = vl / sizeof(uint64_t);
2619     DecodeCounter p = decode_counter(png, vl, MO_64);
2620 
2621     if (p.invert) {
2622         for (int r = 0; r < nreg; r++) {
2623             uint64_t *d = vd + r * sizeof(ARMVectorReg);
2624             uint64_t *n = vn + r * sizeof(ARMVectorReg);
2625             uint64_t *m = vm + r * sizeof(ARMVectorReg);
2626             int split = p.count - r * elements;
2627 
2628             if (split <= 0) {
2629                 memcpy(d, n, vl);  /* all true */
2630             } else if (elements <= split) {
2631                 memcpy(d, m, vl);  /* all false */
2632             } else {
2633                 memcpy(d, m, split * sizeof(uint64_t));
2634                 memcpy(d + split, n + split,
2635                        (elements - split) * sizeof(uint64_t));
2636             }
2637         }
2638     } else {
2639         for (int r = 0; r < nreg; r++) {
2640             uint64_t *d = vd + r * sizeof(ARMVectorReg);
2641             uint64_t *n = vn + r * sizeof(ARMVectorReg);
2642             uint64_t *m = vm + r * sizeof(ARMVectorReg);
2643             int split = p.count - r * elements;
2644 
2645             if (split <= 0) {
2646                 memcpy(d, m, vl);  /* all false */
2647             } else if (elements <= split) {
2648                 memcpy(d, n, vl);  /* all true */
2649             } else {
2650                 memcpy(d, n, split * sizeof(uint64_t));
2651                 memcpy(d + split, m + split,
2652                        (elements - split) * sizeof(uint64_t));
2653             }
2654         }
2655     }
2656 }
2657