xref: /openbmc/qemu/target/arm/tcg/sme_helper.c (revision ec08d9a51e6af3cd3edbdbf2ca6e97a1e2b5f0d1)
1 /*
2  * ARM SME Operations
3  *
4  * Copyright (c) 2022 Linaro, Ltd.
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this library; if not, see <http://www.gnu.org/licenses/>.
18  */
19 
20 #include "qemu/osdep.h"
21 #include "cpu.h"
22 #include "internals.h"
23 #include "tcg/tcg-gvec-desc.h"
24 #include "exec/helper-proto.h"
25 #include "exec/cpu_ldst.h"
26 #include "exec/exec-all.h"
27 #include "qemu/int128.h"
28 #include "fpu/softfloat.h"
29 #include "vec_internal.h"
30 #include "sve_ldst_internal.h"
31 
/*
 * TCG helper for writes to SVCR: forwards @val and @mask unchanged to
 * aarch64_set_svcr(), which performs the actual update (presumably only
 * the bits selected by @mask are taken from @val).
 */
void helper_set_svcr(CPUARMState *env, uint32_t val, uint32_t mask)
{
    /* All of the work, including any side effects, lives in the common code. */
    aarch64_set_svcr(env, val, mask);
}
36 
/*
 * ZERO { mask }: clear the 64-bit ZA tiles selected by @imm.
 *
 * Bit n of @imm selects tile ZAn.D.  @svl is used both as the number of
 * ZA rows visited and as the number of bytes cleared in each row
 * (presumably the streaming vector length in bytes).
 */
void helper_sme_zero(CPUARMState *env, uint32_t imm, uint32_t svl)
{
    uint32_t row;

    /*
     * With all eight tiles selected, wipe the entire ZA array at once.
     * This also zeroes storage beyond SVL, which falls under the
     * CONSTRAINED UNPREDICTABLE treatment of that region.
     */
    if (imm == 0xff) {
        memset(env->zarray, 0, sizeof(env->zarray));
        return;
    }

    /*
     * Row m of ZAn.D is stored in ZA row n + 8*m, so the tiles are
     * interleaved row by row and a row's tile is its index modulo 8.
     */
    for (row = 0; row < svl; row++) {
        uint32_t tile = row % 8;

        if ((imm >> tile) & 1) {
            memset(&env->zarray[row], 0, svl);
        }
    }
}
61 
62 
/*
 * Index of element @i of a vertical tile slice when ZA is viewed as an
 * array of some element type T, whatever the size of T.
 *
 * The tiles are interleaved, so consecutive rows of a tile with N-byte
 * elements are N ZA rows apart.  Converting a byte offset to an index
 * of T divides by N, and stepping N ZA rows per tile row multiplies by
 * N.  The two cancel, leaving i * sizeof(ARMVectorReg).
 */
#define tile_vslice_index(i) (sizeof(ARMVectorReg) * (i))
75 
/*
 * Byte offset within ZA of the data @byteoff bytes into a vertical tile
 * slice, independent of the element size (@byteoff must be a multiple
 * of the element size).
 *
 * With 1-byte elements each ZA row holds one byte of the slice, so
 * slice byte 8 is in ZA row 8.  With 8-byte elements each row holds
 * 8 bytes of the slice.  Eight tiles are interleaved, though, so slice
 * byte 8 (tile row 1) again lands in ZA row 8.  Other sizes behave the
 * same way.  The offset is therefore always @byteoff rows of
 * sizeof(ARMVectorReg) bytes.
 */
#define tile_vslice_offset(byteoff) (sizeof(ARMVectorReg) * (byteoff))
93 
94 
95 /*
96  * Move Zreg vector to ZArray column.
97  */
98 #define DO_MOVA_C(NAME, TYPE, H)                                        \
99 void HELPER(NAME)(void *za, void *vn, void *vg, uint32_t desc)          \
100 {                                                                       \
101     int i, oprsz = simd_oprsz(desc);                                    \
102     for (i = 0; i < oprsz; ) {                                          \
103         uint16_t pg = *(uint16_t *)(vg + H1_2(i >> 3));                 \
104         do {                                                            \
105             if (pg & 1) {                                               \
106                 *(TYPE *)(za + tile_vslice_offset(i)) = *(TYPE *)(vn + H(i)); \
107             }                                                           \
108             i += sizeof(TYPE);                                          \
109             pg >>= sizeof(TYPE);                                        \
110         } while (i & 15);                                               \
111     }                                                                   \
112 }
113 
DO_MOVA_C(sme_mova_cz_b,uint8_t,H1)114 DO_MOVA_C(sme_mova_cz_b, uint8_t, H1)
115 DO_MOVA_C(sme_mova_cz_h, uint16_t, H1_2)
116 DO_MOVA_C(sme_mova_cz_s, uint32_t, H1_4)
117 
118 void HELPER(sme_mova_cz_d)(void *za, void *vn, void *vg, uint32_t desc)
119 {
120     int i, oprsz = simd_oprsz(desc) / 8;
121     uint8_t *pg = vg;
122     uint64_t *n = vn;
123     uint64_t *a = za;
124 
125     for (i = 0; i < oprsz; i++) {
126         if (pg[H1(i)] & 1) {
127             a[tile_vslice_index(i)] = n[i];
128         }
129     }
130 }
131 
HELPER(sme_mova_cz_q)132 void HELPER(sme_mova_cz_q)(void *za, void *vn, void *vg, uint32_t desc)
133 {
134     int i, oprsz = simd_oprsz(desc) / 16;
135     uint16_t *pg = vg;
136     Int128 *n = vn;
137     Int128 *a = za;
138 
139     /*
140      * Int128 is used here simply to copy 16 bytes, and to simplify
141      * the address arithmetic.
142      */
143     for (i = 0; i < oprsz; i++) {
144         if (pg[H2(i)] & 1) {
145             a[tile_vslice_index(i)] = n[i];
146         }
147     }
148 }
149 
150 #undef DO_MOVA_C
151 
152 /*
153  * Move ZArray column to Zreg vector.
154  */
155 #define DO_MOVA_Z(NAME, TYPE, H)                                        \
156 void HELPER(NAME)(void *vd, void *za, void *vg, uint32_t desc)          \
157 {                                                                       \
158     int i, oprsz = simd_oprsz(desc);                                    \
159     for (i = 0; i < oprsz; ) {                                          \
160         uint16_t pg = *(uint16_t *)(vg + H1_2(i >> 3));                 \
161         do {                                                            \
162             if (pg & 1) {                                               \
163                 *(TYPE *)(vd + H(i)) = *(TYPE *)(za + tile_vslice_offset(i)); \
164             }                                                           \
165             i += sizeof(TYPE);                                          \
166             pg >>= sizeof(TYPE);                                        \
167         } while (i & 15);                                               \
168     }                                                                   \
169 }
170 
DO_MOVA_Z(sme_mova_zc_b,uint8_t,H1)171 DO_MOVA_Z(sme_mova_zc_b, uint8_t, H1)
172 DO_MOVA_Z(sme_mova_zc_h, uint16_t, H1_2)
173 DO_MOVA_Z(sme_mova_zc_s, uint32_t, H1_4)
174 
175 void HELPER(sme_mova_zc_d)(void *vd, void *za, void *vg, uint32_t desc)
176 {
177     int i, oprsz = simd_oprsz(desc) / 8;
178     uint8_t *pg = vg;
179     uint64_t *d = vd;
180     uint64_t *a = za;
181 
182     for (i = 0; i < oprsz; i++) {
183         if (pg[H1(i)] & 1) {
184             d[i] = a[tile_vslice_index(i)];
185         }
186     }
187 }
188 
HELPER(sme_mova_zc_q)189 void HELPER(sme_mova_zc_q)(void *vd, void *za, void *vg, uint32_t desc)
190 {
191     int i, oprsz = simd_oprsz(desc) / 16;
192     uint16_t *pg = vg;
193     Int128 *d = vd;
194     Int128 *a = za;
195 
196     /*
197      * Int128 is used here simply to copy 16 bytes, and to simplify
198      * the address arithmetic.
199      */
200     for (i = 0; i < oprsz; i++, za += sizeof(ARMVectorReg)) {
201         if (pg[H2(i)] & 1) {
202             d[i] = a[tile_vslice_index(i)];
203         }
204     }
205 }
206 
207 #undef DO_MOVA_Z
208 
/*
 * Signature of the routines that zero the elements of a ZA tile slice
 * lying in slice bytes [off, off + len).  Horizontal and vertical slices
 * need different implementations, chosen through a pointer of this type.
 */

typedef void ClearFn(void *slice, size_t off, size_t len);
214 
/* Zero @len bytes of a horizontal (contiguous) slice starting at byte @off. */
static void clear_horizontal(void *ptr, size_t off, size_t len)
{
    uint8_t *row = ptr;

    memset(row + off, 0, len);
}
219 
/* Zero the byte elements of a vertical slice covering slice bytes [off, off + len). */
static void clear_vertical_b(void *vptr, size_t off, size_t len)
{
    for (size_t pos = off, end = off + len; pos < end; pos++) {
        *(uint8_t *)(vptr + tile_vslice_offset(pos)) = 0;
    }
}
226 
/* Zero the 16-bit elements of a vertical slice covering slice bytes [off, off + len). */
static void clear_vertical_h(void *vptr, size_t off, size_t len)
{
    for (size_t pos = off, end = off + len; pos < end; pos += sizeof(uint16_t)) {
        *(uint16_t *)(vptr + tile_vslice_offset(pos)) = 0;
    }
}
233 
/* Zero the 32-bit elements of a vertical slice covering slice bytes [off, off + len). */
static void clear_vertical_s(void *vptr, size_t off, size_t len)
{
    for (size_t pos = off, end = off + len; pos < end; pos += sizeof(uint32_t)) {
        *(uint32_t *)(vptr + tile_vslice_offset(pos)) = 0;
    }
}
240 
/* Zero the 64-bit elements of a vertical slice covering slice bytes [off, off + len). */
static void clear_vertical_d(void *vptr, size_t off, size_t len)
{
    for (size_t pos = off, end = off + len; pos < end; pos += sizeof(uint64_t)) {
        *(uint64_t *)(vptr + tile_vslice_offset(pos)) = 0;
    }
}
247 
/* Zero the 128-bit elements of a vertical slice covering slice bytes [off, off + len). */
static void clear_vertical_q(void *vptr, size_t off, size_t len)
{
    for (size_t pos = off, end = off + len; pos < end; pos += 16) {
        memset(vptr + tile_vslice_offset(pos), 0, 16);
    }
}
254 
/*
 * Signature of the routines that copy @len bytes of elements from a
 * contiguous buffer @src into the ZA tile slice at @dst.  Horizontal and
 * vertical slices need different implementations.
 */

typedef void CopyFn(void *slice, const void *buf, size_t len);
260 
/* Copy @len bytes into a horizontal slice, which is contiguous in ZA. */
static void copy_horizontal(void *dst, const void *src, size_t len)
{
    const uint8_t *in = src;
    uint8_t *out = dst;

    memcpy(out, in, len);
}
265 
/* Scatter @len contiguous bytes from @vsrc down a vertical byte-element slice. */
static void copy_vertical_b(void *vdst, const void *vsrc, size_t len)
{
    const uint8_t *in = vsrc;
    uint8_t *out = vdst;

    for (size_t n = 0; n < len; n++) {
        out[tile_vslice_index(n)] = in[n];
    }
}
276 
/* Scatter @len bytes of contiguous 16-bit elements down a vertical slice. */
static void copy_vertical_h(void *vdst, const void *vsrc, size_t len)
{
    const uint16_t *in = vsrc;
    uint16_t *out = vdst;
    size_t count = len / sizeof(uint16_t);

    for (size_t n = 0; n < count; n++) {
        out[tile_vslice_index(n)] = in[n];
    }
}
287 
/* Scatter @len bytes of contiguous 32-bit elements down a vertical slice. */
static void copy_vertical_s(void *vdst, const void *vsrc, size_t len)
{
    const uint32_t *in = vsrc;
    uint32_t *out = vdst;
    size_t count = len / sizeof(uint32_t);

    for (size_t n = 0; n < count; n++) {
        out[tile_vslice_index(n)] = in[n];
    }
}
298 
/* Scatter @len bytes of contiguous 64-bit elements down a vertical slice. */
static void copy_vertical_d(void *vdst, const void *vsrc, size_t len)
{
    const uint64_t *in = vsrc;
    uint64_t *out = vdst;
    size_t count = len / sizeof(uint64_t);

    for (size_t n = 0; n < count; n++) {
        out[tile_vslice_index(n)] = in[n];
    }
}
309 
/* Scatter @len bytes of contiguous 128-bit elements down a vertical slice. */
static void copy_vertical_q(void *vdst, const void *vsrc, size_t len)
{
    const uint8_t *in = vsrc;
    uint8_t *out = vdst;

    for (size_t pos = 0; pos < len; pos += 16) {
        memcpy(out + tile_vslice_offset(pos), in + pos, 16);
    }
}
316 
/*
 * Per-element accessors for vertical tile slices, shaped like
 * sve_ldst1_host_fn / sve_ldst1_tlb_fn.  @off is the element's byte
 * offset within the slice and is mapped to its ZA row with
 * tile_vslice_offset().  The _host form reads straight from host memory.
 * The _tlb form goes through the cpu_*_data_ra accessor TLB, passing @ra
 * for fault unwinding.
 */

#define DO_LD(NAME, TYPE, HOST, TLB)                                        \
static inline void sme_##NAME##_v_host(void *za, intptr_t off, void *host)  \
{                                                                           \
    *(TYPE *)(za + tile_vslice_offset(off)) = HOST(host);                   \
}                                                                           \
static inline void sme_##NAME##_v_tlb(CPUARMState *env, void *za,           \
                        intptr_t off, target_ulong addr, uintptr_t ra)      \
{                                                                           \
    *(TYPE *)(za + tile_vslice_offset(off)) =                               \
        TLB(env, useronly_clean_ptr(addr), ra);                             \
}
333 
/*
 * Store counterparts of the vertical accessors: read the element at
 * slice byte @off (ZA row tile_vslice_offset(off)) and write it out
 * either directly to host memory or through the TLB accessor.
 */
#define DO_ST(NAME, TYPE, HOST, TLB)                                        \
static inline void sme_##NAME##_v_host(void *za, intptr_t off, void *host)  \
{                                                                           \
    HOST(host, *(TYPE *)(za + tile_vslice_offset(off)));                    \
}                                                                           \
static inline void sme_##NAME##_v_tlb(CPUARMState *env, void *za,           \
                        intptr_t off, target_ulong addr, uintptr_t ra)      \
{                                                                           \
    TLB(env, useronly_clean_ptr(addr),                                      \
        *(TYPE *)(za + tile_vslice_offset(off)), ra);                       \
}
346 
/*
 * 128-bit element loads.  ARMVectorReg elements are stored in host-endian
 * 64-bit units, and the Elem[] pseudocode order for a 128-bit quantity
 * keeps its two 64-bit halves in little-endian order.  With BE set, the
 * first doubleword read from memory is therefore stored into the second
 * unit, and vice versa.
 *
 * HNAME##_host/_tlb address the element at byte @off of a horizontal
 * slice (za + off).  VNAME##_v_host/_v_tlb handle a vertical slice by
 * remapping @off through tile_vslice_offset().  The lower-addressed
 * doubleword is always accessed first.
 */
#define DO_LDQ(HNAME, VNAME, BE, HOST, TLB)                                 \
static inline void HNAME##_host(void *za, intptr_t off, void *host)         \
{                                                                           \
    uint64_t *pair = za + off;                                              \
    uint64_t first = HOST(host);                                            \
    uint64_t second = HOST(host + 8);                                       \
    pair[BE] = first;                                                       \
    pair[!BE] = second;                                                     \
}                                                                           \
static inline void VNAME##_v_host(void *za, intptr_t off, void *host)       \
{                                                                           \
    HNAME##_host(za, tile_vslice_offset(off), host);                        \
}                                                                           \
static inline void HNAME##_tlb(CPUARMState *env, void *za, intptr_t off,    \
                               target_ulong addr, uintptr_t ra)             \
{                                                                           \
    uint64_t *pair = za + off;                                              \
    uint64_t first = TLB(env, useronly_clean_ptr(addr), ra);                \
    uint64_t second = TLB(env, useronly_clean_ptr(addr + 8), ra);           \
    pair[BE] = first;                                                       \
    pair[!BE] = second;                                                     \
}                                                                           \
static inline void VNAME##_v_tlb(CPUARMState *env, void *za, intptr_t off,  \
                               target_ulong addr, uintptr_t ra)             \
{                                                                           \
    HNAME##_tlb(env, za, tile_vslice_offset(off), addr, ra);                \
}
376 
/*
 * 128-bit element stores, the inverse of DO_LDQ: the 64-bit unit selected
 * by BE is written to the lower address first, then the other unit to
 * address + 8.  HNAME##_* address a horizontal slice (za + off);
 * VNAME##_v_* remap @off for a vertical slice via tile_vslice_offset().
 */
#define DO_STQ(HNAME, VNAME, BE, HOST, TLB)                                 \
static inline void HNAME##_host(void *za, intptr_t off, void *host)         \
{                                                                           \
    const uint64_t *pair = za + off;                                        \
    uint64_t first = pair[BE], second = pair[!BE];                          \
    HOST(host, first);                                                      \
    HOST(host + 8, second);                                                 \
}                                                                           \
static inline void VNAME##_v_host(void *za, intptr_t off, void *host)       \
{                                                                           \
    HNAME##_host(za, tile_vslice_offset(off), host);                        \
}                                                                           \
static inline void HNAME##_tlb(CPUARMState *env, void *za, intptr_t off,    \
                               target_ulong addr, uintptr_t ra)             \
{                                                                           \
    const uint64_t *pair = za + off;                                        \
    uint64_t first = pair[BE], second = pair[!BE];                          \
    TLB(env, useronly_clean_ptr(addr), first, ra);                          \
    TLB(env, useronly_clean_ptr(addr + 8), second, ra);                     \
}                                                                           \
static inline void VNAME##_v_tlb(CPUARMState *env, void *za, intptr_t off,  \
                               target_ulong addr, uintptr_t ra)             \
{                                                                           \
    HNAME##_tlb(env, za, tile_vslice_offset(off), addr, ra);                \
}
400 
/* Byte elements: no endianness variants. */
DO_LD(ld1b, uint8_t, ldub_p, cpu_ldub_data_ra)
DO_ST(st1b, uint8_t, stb_p, cpu_stb_data_ra)

/* Halfword elements. */
DO_LD(ld1h_be, uint16_t, lduw_be_p, cpu_lduw_be_data_ra)
DO_LD(ld1h_le, uint16_t, lduw_le_p, cpu_lduw_le_data_ra)
DO_ST(st1h_be, uint16_t, stw_be_p, cpu_stw_be_data_ra)
DO_ST(st1h_le, uint16_t, stw_le_p, cpu_stw_le_data_ra)

/* Word elements. */
DO_LD(ld1s_be, uint32_t, ldl_be_p, cpu_ldl_be_data_ra)
DO_LD(ld1s_le, uint32_t, ldl_le_p, cpu_ldl_le_data_ra)
DO_ST(st1s_be, uint32_t, stl_be_p, cpu_stl_be_data_ra)
DO_ST(st1s_le, uint32_t, stl_le_p, cpu_stl_le_data_ra)

/* Doubleword elements. */
DO_LD(ld1d_be, uint64_t, ldq_be_p, cpu_ldq_be_data_ra)
DO_LD(ld1d_le, uint64_t, ldq_le_p, cpu_ldq_le_data_ra)
DO_ST(st1d_be, uint64_t, stq_be_p, cpu_stq_be_data_ra)
DO_ST(st1d_le, uint64_t, stq_le_p, cpu_stq_le_data_ra)

/*
 * Quadword elements: these also define the horizontal sve_{ld,st}1qq_*
 * accessors alongside the vertical sme_{ld,st}1q_* ones.
 */
DO_LDQ(sve_ld1qq_be, sme_ld1q_be, 1, ldq_be_p, cpu_ldq_be_data_ra)
DO_LDQ(sve_ld1qq_le, sme_ld1q_le, 0, ldq_le_p, cpu_ldq_le_data_ra)
DO_STQ(sve_st1qq_be, sme_st1q_be, 1, stq_be_p, cpu_stq_be_data_ra)
DO_STQ(sve_st1qq_le, sme_st1q_le, 0, stq_le_p, cpu_stq_le_data_ra)

#undef DO_LD
#undef DO_ST
#undef DO_LDQ
#undef DO_STQ
427 
428 /*
429  * Common helper for all contiguous predicated loads.
430  */
431 
432 static inline QEMU_ALWAYS_INLINE
433 void sme_ld1(CPUARMState *env, void *za, uint64_t *vg,
434              const target_ulong addr, uint32_t desc, const uintptr_t ra,
435              const int esz, uint32_t mtedesc, bool vertical,
436              sve_ldst1_host_fn *host_fn,
437              sve_ldst1_tlb_fn *tlb_fn,
438              ClearFn *clr_fn,
439              CopyFn *cpy_fn)
440 {
441     const intptr_t reg_max = simd_oprsz(desc);
442     const intptr_t esize = 1 << esz;
443     intptr_t reg_off, reg_last;
444     SVEContLdSt info;
445     void *host;
446     int flags;
447 
448     /* Find the active elements.  */
449     if (!sve_cont_ldst_elements(&info, addr, vg, reg_max, esz, esize)) {
450         /* The entire predicate was false; no load occurs.  */
451         clr_fn(za, 0, reg_max);
452         return;
453     }
454 
455     /* Probe the page(s).  Exit with exception for any invalid page. */
456     sve_cont_ldst_pages(&info, FAULT_ALL, env, addr, MMU_DATA_LOAD, ra);
457 
458     /* Handle watchpoints for all active elements. */
459     sve_cont_ldst_watchpoints(&info, env, vg, addr, esize, esize,
460                               BP_MEM_READ, ra);
461 
462     /*
463      * Handle mte checks for all active elements.
464      * Since TBI must be set for MTE, !mtedesc => !mte_active.
465      */
466     if (mtedesc) {
467         sve_cont_ldst_mte_check(&info, env, vg, addr, esize, esize,
468                                 mtedesc, ra);
469     }
470 
471     flags = info.page[0].flags | info.page[1].flags;
472     if (unlikely(flags != 0)) {
473 #ifdef CONFIG_USER_ONLY
474         g_assert_not_reached();
475 #else
476         /*
477          * At least one page includes MMIO.
478          * Any bus operation can fail with cpu_transaction_failed,
479          * which for ARM will raise SyncExternal.  Perform the load
480          * into scratch memory to preserve register state until the end.
481          */
482         ARMVectorReg scratch = { };
483 
484         reg_off = info.reg_off_first[0];
485         reg_last = info.reg_off_last[1];
486         if (reg_last < 0) {
487             reg_last = info.reg_off_split;
488             if (reg_last < 0) {
489                 reg_last = info.reg_off_last[0];
490             }
491         }
492 
493         do {
494             uint64_t pg = vg[reg_off >> 6];
495             do {
496                 if ((pg >> (reg_off & 63)) & 1) {
497                     tlb_fn(env, &scratch, reg_off, addr + reg_off, ra);
498                 }
499                 reg_off += esize;
500             } while (reg_off & 63);
501         } while (reg_off <= reg_last);
502 
503         cpy_fn(za, &scratch, reg_max);
504         return;
505 #endif
506     }
507 
508     /* The entire operation is in RAM, on valid pages. */
509 
510     reg_off = info.reg_off_first[0];
511     reg_last = info.reg_off_last[0];
512     host = info.page[0].host;
513 
514     if (!vertical) {
515         memset(za, 0, reg_max);
516     } else if (reg_off) {
517         clr_fn(za, 0, reg_off);
518     }
519 
520     set_helper_retaddr(ra);
521 
522     while (reg_off <= reg_last) {
523         uint64_t pg = vg[reg_off >> 6];
524         do {
525             if ((pg >> (reg_off & 63)) & 1) {
526                 host_fn(za, reg_off, host + reg_off);
527             } else if (vertical) {
528                 clr_fn(za, reg_off, esize);
529             }
530             reg_off += esize;
531         } while (reg_off <= reg_last && (reg_off & 63));
532     }
533 
534     clear_helper_retaddr();
535 
536     /*
537      * Use the slow path to manage the cross-page misalignment.
538      * But we know this is RAM and cannot trap.
539      */
540     reg_off = info.reg_off_split;
541     if (unlikely(reg_off >= 0)) {
542         tlb_fn(env, za, reg_off, addr + reg_off, ra);
543     }
544 
545     reg_off = info.reg_off_first[1];
546     if (unlikely(reg_off >= 0)) {
547         reg_last = info.reg_off_last[1];
548         host = info.page[1].host;
549 
550         set_helper_retaddr(ra);
551 
552         do {
553             uint64_t pg = vg[reg_off >> 6];
554             do {
555                 if ((pg >> (reg_off & 63)) & 1) {
556                     host_fn(za, reg_off, host + reg_off);
557                 } else if (vertical) {
558                     clr_fn(za, reg_off, esize);
559                 }
560                 reg_off += esize;
561             } while (reg_off & 63);
562         } while (reg_off <= reg_last);
563 
564         clear_helper_retaddr();
565     }
566 }
567 
568 static inline QEMU_ALWAYS_INLINE
sme_ld1_mte(CPUARMState * env,void * za,uint64_t * vg,target_ulong addr,uint32_t desc,uintptr_t ra,const int esz,bool vertical,sve_ldst1_host_fn * host_fn,sve_ldst1_tlb_fn * tlb_fn,ClearFn * clr_fn,CopyFn * cpy_fn)569 void sme_ld1_mte(CPUARMState *env, void *za, uint64_t *vg,
570                  target_ulong addr, uint32_t desc, uintptr_t ra,
571                  const int esz, bool vertical,
572                  sve_ldst1_host_fn *host_fn,
573                  sve_ldst1_tlb_fn *tlb_fn,
574                  ClearFn *clr_fn,
575                  CopyFn *cpy_fn)
576 {
577     uint32_t mtedesc = desc >> (SIMD_DATA_SHIFT + SVE_MTEDESC_SHIFT);
578     int bit55 = extract64(addr, 55, 1);
579 
580     /* Remove mtedesc from the normal sve descriptor. */
581     desc = extract32(desc, 0, SIMD_DATA_SHIFT + SVE_MTEDESC_SHIFT);
582 
583     /* Perform gross MTE suppression early. */
584     if (!tbi_check(mtedesc, bit55) ||
585         tcma_check(mtedesc, bit55, allocation_tag_from_addr(addr))) {
586         mtedesc = 0;
587     }
588 
589     sme_ld1(env, za, vg, addr, desc, ra, esz, mtedesc, vertical,
590             host_fn, tlb_fn, clr_fn, cpy_fn);
591 }
592 
593 #define DO_LD(L, END, ESZ)                                                 \
594 void HELPER(sme_ld1##L##END##_h)(CPUARMState *env, void *za, void *vg,     \
595                                  target_ulong addr, uint32_t desc)         \
596 {                                                                          \
597     sme_ld1(env, za, vg, addr, desc, GETPC(), ESZ, 0, false,               \
598             sve_ld1##L##L##END##_host, sve_ld1##L##L##END##_tlb,           \
599             clear_horizontal, copy_horizontal);                            \
600 }                                                                          \
601 void HELPER(sme_ld1##L##END##_v)(CPUARMState *env, void *za, void *vg,     \
602                                  target_ulong addr, uint32_t desc)         \
603 {                                                                          \
604     sme_ld1(env, za, vg, addr, desc, GETPC(), ESZ, 0, true,                \
605             sme_ld1##L##END##_v_host, sme_ld1##L##END##_v_tlb,             \
606             clear_vertical_##L, copy_vertical_##L);                        \
607 }                                                                          \
608 void HELPER(sme_ld1##L##END##_h_mte)(CPUARMState *env, void *za, void *vg, \
609                                      target_ulong addr, uint32_t desc)     \
610 {                                                                          \
611     sme_ld1_mte(env, za, vg, addr, desc, GETPC(), ESZ, false,              \
612                 sve_ld1##L##L##END##_host, sve_ld1##L##L##END##_tlb,       \
613                 clear_horizontal, copy_horizontal);                        \
614 }                                                                          \
615 void HELPER(sme_ld1##L##END##_v_mte)(CPUARMState *env, void *za, void *vg, \
616                                      target_ulong addr, uint32_t desc)     \
617 {                                                                          \
618     sme_ld1_mte(env, za, vg, addr, desc, GETPC(), ESZ, true,               \
619                 sme_ld1##L##END##_v_host, sme_ld1##L##END##_v_tlb,         \
620                 clear_vertical_##L, copy_vertical_##L);                    \
621 }
622 
623 DO_LD(b, , MO_8)
DO_LD(h,_be,MO_16)624 DO_LD(h, _be, MO_16)
625 DO_LD(h, _le, MO_16)
626 DO_LD(s, _be, MO_32)
627 DO_LD(s, _le, MO_32)
628 DO_LD(d, _be, MO_64)
629 DO_LD(d, _le, MO_64)
630 DO_LD(q, _be, MO_128)
631 DO_LD(q, _le, MO_128)
632 
633 #undef DO_LD
634 
635 /*
636  * Common helper for all contiguous predicated stores.
637  */
638 
639 static inline QEMU_ALWAYS_INLINE
640 void sme_st1(CPUARMState *env, void *za, uint64_t *vg,
641              const target_ulong addr, uint32_t desc, const uintptr_t ra,
642              const int esz, uint32_t mtedesc, bool vertical,
643              sve_ldst1_host_fn *host_fn,
644              sve_ldst1_tlb_fn *tlb_fn)
645 {
646     const intptr_t reg_max = simd_oprsz(desc);
647     const intptr_t esize = 1 << esz;
648     intptr_t reg_off, reg_last;
649     SVEContLdSt info;
650     void *host;
651     int flags;
652 
653     /* Find the active elements.  */
654     if (!sve_cont_ldst_elements(&info, addr, vg, reg_max, esz, esize)) {
655         /* The entire predicate was false; no store occurs.  */
656         return;
657     }
658 
659     /* Probe the page(s).  Exit with exception for any invalid page. */
660     sve_cont_ldst_pages(&info, FAULT_ALL, env, addr, MMU_DATA_STORE, ra);
661 
662     /* Handle watchpoints for all active elements. */
663     sve_cont_ldst_watchpoints(&info, env, vg, addr, esize, esize,
664                               BP_MEM_WRITE, ra);
665 
666     /*
667      * Handle mte checks for all active elements.
668      * Since TBI must be set for MTE, !mtedesc => !mte_active.
669      */
670     if (mtedesc) {
671         sve_cont_ldst_mte_check(&info, env, vg, addr, esize, esize,
672                                 mtedesc, ra);
673     }
674 
675     flags = info.page[0].flags | info.page[1].flags;
676     if (unlikely(flags != 0)) {
677 #ifdef CONFIG_USER_ONLY
678         g_assert_not_reached();
679 #else
680         /*
681          * At least one page includes MMIO.
682          * Any bus operation can fail with cpu_transaction_failed,
683          * which for ARM will raise SyncExternal.  We cannot avoid
684          * this fault and will leave with the store incomplete.
685          */
686         reg_off = info.reg_off_first[0];
687         reg_last = info.reg_off_last[1];
688         if (reg_last < 0) {
689             reg_last = info.reg_off_split;
690             if (reg_last < 0) {
691                 reg_last = info.reg_off_last[0];
692             }
693         }
694 
695         do {
696             uint64_t pg = vg[reg_off >> 6];
697             do {
698                 if ((pg >> (reg_off & 63)) & 1) {
699                     tlb_fn(env, za, reg_off, addr + reg_off, ra);
700                 }
701                 reg_off += esize;
702             } while (reg_off & 63);
703         } while (reg_off <= reg_last);
704         return;
705 #endif
706     }
707 
708     reg_off = info.reg_off_first[0];
709     reg_last = info.reg_off_last[0];
710     host = info.page[0].host;
711 
712     set_helper_retaddr(ra);
713 
714     while (reg_off <= reg_last) {
715         uint64_t pg = vg[reg_off >> 6];
716         do {
717             if ((pg >> (reg_off & 63)) & 1) {
718                 host_fn(za, reg_off, host + reg_off);
719             }
720             reg_off += 1 << esz;
721         } while (reg_off <= reg_last && (reg_off & 63));
722     }
723 
724     clear_helper_retaddr();
725 
726     /*
727      * Use the slow path to manage the cross-page misalignment.
728      * But we know this is RAM and cannot trap.
729      */
730     reg_off = info.reg_off_split;
731     if (unlikely(reg_off >= 0)) {
732         tlb_fn(env, za, reg_off, addr + reg_off, ra);
733     }
734 
735     reg_off = info.reg_off_first[1];
736     if (unlikely(reg_off >= 0)) {
737         reg_last = info.reg_off_last[1];
738         host = info.page[1].host;
739 
740         set_helper_retaddr(ra);
741 
742         do {
743             uint64_t pg = vg[reg_off >> 6];
744             do {
745                 if ((pg >> (reg_off & 63)) & 1) {
746                     host_fn(za, reg_off, host + reg_off);
747                 }
748                 reg_off += 1 << esz;
749             } while (reg_off & 63);
750         } while (reg_off <= reg_last);
751 
752         clear_helper_retaddr();
753     }
754 }
755 
756 static inline QEMU_ALWAYS_INLINE
sme_st1_mte(CPUARMState * env,void * za,uint64_t * vg,target_ulong addr,uint32_t desc,uintptr_t ra,int esz,bool vertical,sve_ldst1_host_fn * host_fn,sve_ldst1_tlb_fn * tlb_fn)757 void sme_st1_mte(CPUARMState *env, void *za, uint64_t *vg, target_ulong addr,
758                  uint32_t desc, uintptr_t ra, int esz, bool vertical,
759                  sve_ldst1_host_fn *host_fn,
760                  sve_ldst1_tlb_fn *tlb_fn)
761 {
762     uint32_t mtedesc = desc >> (SIMD_DATA_SHIFT + SVE_MTEDESC_SHIFT);
763     int bit55 = extract64(addr, 55, 1);
764 
765     /* Remove mtedesc from the normal sve descriptor. */
766     desc = extract32(desc, 0, SIMD_DATA_SHIFT + SVE_MTEDESC_SHIFT);
767 
768     /* Perform gross MTE suppression early. */
769     if (!tbi_check(mtedesc, bit55) ||
770         tcma_check(mtedesc, bit55, allocation_tag_from_addr(addr))) {
771         mtedesc = 0;
772     }
773 
774     sme_st1(env, za, vg, addr, desc, ra, esz, mtedesc,
775             vertical, host_fn, tlb_fn);
776 }
777 
/*
 * DO_ST(L, END, ESZ) expands to the four SME ST1 store helpers for one
 * element type:
 *   sme_st1<L><END>_h      horizontal slice, via sme_st1()
 *   sme_st1<L><END>_v      vertical slice, via sme_st1()
 *   sme_st1<L><END>_h_mte  horizontal slice, via sme_st1_mte()
 *   sme_st1<L><END>_v_mte  vertical slice, via sme_st1_mte()
 *
 * L is the element-size letter (b/h/s/d/q).  END is the endianness suffix
 * (empty for bytes, otherwise _be or _le).  ESZ is the log2 element size
 * (MO_8 .. MO_128).
 *
 * Horizontal slices reuse the contiguous SVE ST1 host/tlb accessors
 * (sve_st1<L><L><END>_*).  Vertical slices use the SME-specific
 * sme_st1<L><END>_v_* accessors.  The non-MTE variants pass 0 as the
 * mtedesc argument of sme_st1().
 */
#define DO_ST(L, END, ESZ)                                                 \
void HELPER(sme_st1##L##END##_h)(CPUARMState *env, void *za, void *vg,     \
                                 target_ulong addr, uint32_t desc)         \
{                                                                          \
    sme_st1(env, za, vg, addr, desc, GETPC(), ESZ, 0, false,               \
            sve_st1##L##L##END##_host, sve_st1##L##L##END##_tlb);          \
}                                                                          \
void HELPER(sme_st1##L##END##_v)(CPUARMState *env, void *za, void *vg,     \
                                 target_ulong addr, uint32_t desc)         \
{                                                                          \
    sme_st1(env, za, vg, addr, desc, GETPC(), ESZ, 0, true,                \
            sme_st1##L##END##_v_host, sme_st1##L##END##_v_tlb);            \
}                                                                          \
void HELPER(sme_st1##L##END##_h_mte)(CPUARMState *env, void *za, void *vg, \
                                     target_ulong addr, uint32_t desc)     \
{                                                                          \
    sme_st1_mte(env, za, vg, addr, desc, GETPC(), ESZ, false,              \
                sve_st1##L##L##END##_host, sve_st1##L##L##END##_tlb);      \
}                                                                          \
void HELPER(sme_st1##L##END##_v_mte)(CPUARMState *env, void *za, void *vg, \
                                     target_ulong addr, uint32_t desc)     \
{                                                                          \
    sme_st1_mte(env, za, vg, addr, desc, GETPC(), ESZ, true,               \
                sme_st1##L##END##_v_host, sme_st1##L##END##_v_tlb);        \
}

/* One instantiation per element size and (for multi-byte) endianness. */
DO_ST(b, , MO_8)
DO_ST(h, _be, MO_16)
DO_ST(h, _le, MO_16)
DO_ST(s, _be, MO_32)
DO_ST(s, _le, MO_32)
DO_ST(d, _be, MO_64)
DO_ST(d, _le, MO_64)
DO_ST(q, _be, MO_128)
DO_ST(q, _le, MO_128)

#undef DO_ST
815 
816 void HELPER(sme_addha_s)(void *vzda, void *vzn, void *vpn,
817                          void *vpm, uint32_t desc)
818 {
819     intptr_t row, col, oprsz = simd_oprsz(desc) / 4;
820     uint64_t *pn = vpn, *pm = vpm;
821     uint32_t *zda = vzda, *zn = vzn;
822 
823     for (row = 0; row < oprsz; ) {
824         uint64_t pa = pn[row >> 4];
825         do {
826             if (pa & 1) {
827                 for (col = 0; col < oprsz; ) {
828                     uint64_t pb = pm[col >> 4];
829                     do {
830                         if (pb & 1) {
831                             zda[tile_vslice_index(row) + H4(col)] += zn[H4(col)];
832                         }
833                         pb >>= 4;
834                     } while (++col & 15);
835                 }
836             }
837             pa >>= 4;
838         } while (++row & 15);
839     }
840 }
841 
HELPER(sme_addha_d)842 void HELPER(sme_addha_d)(void *vzda, void *vzn, void *vpn,
843                          void *vpm, uint32_t desc)
844 {
845     intptr_t row, col, oprsz = simd_oprsz(desc) / 8;
846     uint8_t *pn = vpn, *pm = vpm;
847     uint64_t *zda = vzda, *zn = vzn;
848 
849     for (row = 0; row < oprsz; ++row) {
850         if (pn[H1(row)] & 1) {
851             for (col = 0; col < oprsz; ++col) {
852                 if (pm[H1(col)] & 1) {
853                     zda[tile_vslice_index(row) + col] += zn[col];
854                 }
855             }
856         }
857     }
858 }
859 
HELPER(sme_addva_s)860 void HELPER(sme_addva_s)(void *vzda, void *vzn, void *vpn,
861                          void *vpm, uint32_t desc)
862 {
863     intptr_t row, col, oprsz = simd_oprsz(desc) / 4;
864     uint64_t *pn = vpn, *pm = vpm;
865     uint32_t *zda = vzda, *zn = vzn;
866 
867     for (row = 0; row < oprsz; ) {
868         uint64_t pa = pn[row >> 4];
869         do {
870             if (pa & 1) {
871                 uint32_t zn_row = zn[H4(row)];
872                 for (col = 0; col < oprsz; ) {
873                     uint64_t pb = pm[col >> 4];
874                     do {
875                         if (pb & 1) {
876                             zda[tile_vslice_index(row) + H4(col)] += zn_row;
877                         }
878                         pb >>= 4;
879                     } while (++col & 15);
880                 }
881             }
882             pa >>= 4;
883         } while (++row & 15);
884     }
885 }
886 
HELPER(sme_addva_d)887 void HELPER(sme_addva_d)(void *vzda, void *vzn, void *vpn,
888                          void *vpm, uint32_t desc)
889 {
890     intptr_t row, col, oprsz = simd_oprsz(desc) / 8;
891     uint8_t *pn = vpn, *pm = vpm;
892     uint64_t *zda = vzda, *zn = vzn;
893 
894     for (row = 0; row < oprsz; ++row) {
895         if (pn[H1(row)] & 1) {
896             uint64_t zn_row = zn[row];
897             for (col = 0; col < oprsz; ++col) {
898                 if (pm[H1(col)] & 1) {
899                     zda[tile_vslice_index(row) + col] += zn_row;
900                 }
901             }
902         }
903     }
904 }
905 
HELPER(sme_fmopa_s)906 void HELPER(sme_fmopa_s)(void *vza, void *vzn, void *vzm, void *vpn,
907                          void *vpm, void *vst, uint32_t desc)
908 {
909     intptr_t row, col, oprsz = simd_maxsz(desc);
910     uint32_t neg = simd_data(desc) << 31;
911     uint16_t *pn = vpn, *pm = vpm;
912     float_status fpst;
913 
914     /*
915      * Make a copy of float_status because this operation does not
916      * update the cumulative fp exception status.  It also produces
917      * default nans.
918      */
919     fpst = *(float_status *)vst;
920     set_default_nan_mode(true, &fpst);
921 
922     for (row = 0; row < oprsz; ) {
923         uint16_t pa = pn[H2(row >> 4)];
924         do {
925             if (pa & 1) {
926                 void *vza_row = vza + tile_vslice_offset(row);
927                 uint32_t n = *(uint32_t *)(vzn + H1_4(row)) ^ neg;
928 
929                 for (col = 0; col < oprsz; ) {
930                     uint16_t pb = pm[H2(col >> 4)];
931                     do {
932                         if (pb & 1) {
933                             uint32_t *a = vza_row + H1_4(col);
934                             uint32_t *m = vzm + H1_4(col);
935                             *a = float32_muladd(n, *m, *a, 0, &fpst);
936                         }
937                         col += 4;
938                         pb >>= 4;
939                     } while (col & 15);
940                 }
941             }
942             row += 4;
943             pa >>= 4;
944         } while (row & 15);
945     }
946 }
947 
HELPER(sme_fmopa_d)948 void HELPER(sme_fmopa_d)(void *vza, void *vzn, void *vzm, void *vpn,
949                          void *vpm, void *vst, uint32_t desc)
950 {
951     intptr_t row, col, oprsz = simd_oprsz(desc) / 8;
952     uint64_t neg = (uint64_t)simd_data(desc) << 63;
953     uint64_t *za = vza, *zn = vzn, *zm = vzm;
954     uint8_t *pn = vpn, *pm = vpm;
955     float_status fpst = *(float_status *)vst;
956 
957     set_default_nan_mode(true, &fpst);
958 
959     for (row = 0; row < oprsz; ++row) {
960         if (pn[H1(row)] & 1) {
961             uint64_t *za_row = &za[tile_vslice_index(row)];
962             uint64_t n = zn[row] ^ neg;
963 
964             for (col = 0; col < oprsz; ++col) {
965                 if (pm[H1(col)] & 1) {
966                     uint64_t *a = &za_row[col];
967                     *a = float64_muladd(n, zm[col], *a, 0, &fpst);
968                 }
969             }
970         }
971     }
972 }
973 
/*
 * Prepare a 32-bit container holding two 16-bit elements for a widening
 * outer-product step.
 *
 * First XOR NEG into PAIR (NEG is a mask of sign bits; 0 leaves PAIR
 * unchanged).  Then clear each half whose governing predicate bit in PG is
 * clear: bit 0 keeps the low half, bit 2 keeps the high half.
 *
 * The architectural pseudocode zeroes first and negates afterwards; here
 * the XOR is simply applied before the mask.
 */
static inline uint32_t f16mop_adj_pair(uint32_t pair, uint32_t pg, uint32_t neg)
{
    uint32_t keep = 0;

    if (pg & 1) {
        keep |= 0x0000ffffu;
    }
    if (pg & 4) {
        keep |= 0xffff0000u;
    }
    return (pair ^ neg) & keep;
}
993 
f16_dotadd(float32 sum,uint32_t e1,uint32_t e2,float_status * s_f16,float_status * s_std,float_status * s_odd)994 static float32 f16_dotadd(float32 sum, uint32_t e1, uint32_t e2,
995                           float_status *s_f16, float_status *s_std,
996                           float_status *s_odd)
997 {
998     /*
999      * We need three different float_status for different parts of this
1000      * operation:
1001      *  - the input conversion of the float16 values must use the
1002      *    f16-specific float_status, so that the FPCR.FZ16 control is applied
1003      *  - operations on float32 including the final accumulation must use
1004      *    the normal float_status, so that FPCR.FZ is applied
1005      *  - we have pre-set-up copy of s_std which is set to round-to-odd,
1006      *    for the multiply (see below)
1007      */
1008     float64 e1r = float16_to_float64(e1 & 0xffff, true, s_f16);
1009     float64 e1c = float16_to_float64(e1 >> 16, true, s_f16);
1010     float64 e2r = float16_to_float64(e2 & 0xffff, true, s_f16);
1011     float64 e2c = float16_to_float64(e2 >> 16, true, s_f16);
1012     float64 t64;
1013     float32 t32;
1014 
1015     /*
1016      * The ARM pseudocode function FPDot performs both multiplies
1017      * and the add with a single rounding operation.  Emulate this
1018      * by performing the first multiply in round-to-odd, then doing
1019      * the second multiply as fused multiply-add, and rounding to
1020      * float32 all in one step.
1021      */
1022     t64 = float64_mul(e1r, e2r, s_odd);
1023     t64 = float64r32_muladd(e1c, e2c, t64, 0, s_std);
1024 
1025     /* This conversion is exact, because we've already rounded. */
1026     t32 = float64_to_float32(t64, s_std);
1027 
1028     /* The final accumulation step is not fused. */
1029     return float32_add(sum, t32, s_std);
1030 }
1031 
/*
 * Widening half-precision outer product with 32-bit float accumulators.
 * Each 32-bit container of ZN and ZM holds a pair of float16 values.  For
 * each (row, col) that has at least one active half in both operands,
 * the 32-bit ZA element is updated with f16_dotadd() of the adjusted row
 * pair and column pair.
 *
 * neg is 0x80008000 when simd_data(desc) is set, flipping the sign of
 * both float16 values of the row operand.
 *
 * oprsz is the vector length in bytes; row and col advance by 4 bytes
 * (one container).  Each 16 bytes of vector are governed by one uint16_t
 * of predicate.  Within a container's 4-bit predicate group, bit 0
 * enables the low half and bit 2 the high half.
 */
void HELPER(sme_fmopa_h)(void *vza, void *vzn, void *vzm, void *vpn,
                         void *vpm, CPUARMState *env, uint32_t desc)
{
    intptr_t row, col, oprsz = simd_maxsz(desc);
    uint32_t neg = simd_data(desc) * 0x80008000u;
    uint16_t *pn = vpn, *pm = vpm;
    float_status fpst_odd, fpst_std, fpst_f16;

    /*
     * Make copies of fp_status and fp_status_f16, because this operation
     * does not update the cumulative fp exception status.  It also
     * produces default NaNs. We also need a second copy of fp_status with
     * round-to-odd -- see above.
     */
    fpst_f16 = env->vfp.fp_status_f16;
    fpst_std = env->vfp.fp_status;
    set_default_nan_mode(true, &fpst_std);
    set_default_nan_mode(true, &fpst_f16);
    fpst_odd = fpst_std;
    set_float_rounding_mode(float_round_to_odd, &fpst_odd);

    for (row = 0; row < oprsz; ) {
        /* One 16-bit predicate chunk covers 16 bytes (4 containers). */
        uint16_t prow = pn[H2(row >> 4)];
        do {
            void *vza_row = vza + tile_vslice_offset(row);
            uint32_t n = *(uint32_t *)(vzn + H1_4(row));

            /* Zero inactive row halves and apply the negation once. */
            n = f16mop_adj_pair(n, prow, neg);

            for (col = 0; col < oprsz; ) {
                uint16_t pcol = pm[H2(col >> 4)];
                do {
                    /* Skip unless some half is active in both operands. */
                    if (prow & pcol & 0b0101) {
                        uint32_t *a = vza_row + H1_4(col);
                        uint32_t m = *(uint32_t *)(vzm + H1_4(col));

                        m = f16mop_adj_pair(m, pcol, 0);
                        *a = f16_dotadd(*a, n, m,
                                        &fpst_f16, &fpst_std, &fpst_odd);
                    }
                    /* Next container: 4 bytes and 4 predicate bits on. */
                    col += 4;
                    pcol >>= 4;
                } while (col & 15);
            }
            row += 4;
            prow >>= 4;
        } while (row & 15);
    }
}
1081 
HELPER(sme_bfmopa)1082 void HELPER(sme_bfmopa)(void *vza, void *vzn, void *vzm,
1083                         void *vpn, void *vpm, CPUARMState *env, uint32_t desc)
1084 {
1085     intptr_t row, col, oprsz = simd_maxsz(desc);
1086     uint32_t neg = simd_data(desc) * 0x80008000u;
1087     uint16_t *pn = vpn, *pm = vpm;
1088     float_status fpst, fpst_odd;
1089 
1090     if (is_ebf(env, &fpst, &fpst_odd)) {
1091         for (row = 0; row < oprsz; ) {
1092             uint16_t prow = pn[H2(row >> 4)];
1093             do {
1094                 void *vza_row = vza + tile_vslice_offset(row);
1095                 uint32_t n = *(uint32_t *)(vzn + H1_4(row));
1096 
1097                 n = f16mop_adj_pair(n, prow, neg);
1098 
1099                 for (col = 0; col < oprsz; ) {
1100                     uint16_t pcol = pm[H2(col >> 4)];
1101                     do {
1102                         if (prow & pcol & 0b0101) {
1103                             uint32_t *a = vza_row + H1_4(col);
1104                             uint32_t m = *(uint32_t *)(vzm + H1_4(col));
1105 
1106                             m = f16mop_adj_pair(m, pcol, 0);
1107                             *a = bfdotadd_ebf(*a, n, m, &fpst, &fpst_odd);
1108                         }
1109                         col += 4;
1110                         pcol >>= 4;
1111                     } while (col & 15);
1112                 }
1113                 row += 4;
1114                 prow >>= 4;
1115             } while (row & 15);
1116         }
1117     } else {
1118         for (row = 0; row < oprsz; ) {
1119             uint16_t prow = pn[H2(row >> 4)];
1120             do {
1121                 void *vza_row = vza + tile_vslice_offset(row);
1122                 uint32_t n = *(uint32_t *)(vzn + H1_4(row));
1123 
1124                 n = f16mop_adj_pair(n, prow, neg);
1125 
1126                 for (col = 0; col < oprsz; ) {
1127                     uint16_t pcol = pm[H2(col >> 4)];
1128                     do {
1129                         if (prow & pcol & 0b0101) {
1130                             uint32_t *a = vza_row + H1_4(col);
1131                             uint32_t m = *(uint32_t *)(vzm + H1_4(col));
1132 
1133                             m = f16mop_adj_pair(m, pcol, 0);
1134                             *a = bfdotadd(*a, n, m, &fpst);
1135                         }
1136                         col += 4;
1137                         pcol >>= 4;
1138                     } while (col & 15);
1139                 }
1140                 row += 4;
1141                 prow >>= 4;
1142             } while (row & 15);
1143         }
1144     }
1145 }
1146 
typedef uint32_t IMOPFn32(uint32_t, uint32_t, uint32_t, uint8_t, bool);

/*
 * Driver for the 32-bit-accumulator integer outer products.
 * For every (row, col) of the tile, with no predicate short-cut:
 *     ZA[row][col] = FN(ZN[row], ZM[col], ZA[row][col],
 *                       rowpred & colpred, neg)
 * FN itself applies the combined predicate to its lanes.
 *
 * Each 32-bit element owns one nibble of the predicate byte array: the
 * low nibble for even elements, the high nibble for odd ones.
 */
static inline void do_imopa_s(uint32_t *za, uint32_t *zn, uint32_t *zm,
                              uint8_t *pn, uint8_t *pm,
                              uint32_t desc, IMOPFn32 *fn)
{
    intptr_t nelem = simd_oprsz(desc) / 4;
    bool neg = simd_data(desc);
    intptr_t i, j;

    for (i = 0; i < nelem; i++) {
        int row_shift = (i & 1) * 4;
        uint8_t prow = (pn[H1(i >> 1)] >> row_shift) & 0xf;
        uint32_t *dst = za + tile_vslice_index(i);
        uint32_t zn_i = zn[H4(i)];

        for (j = 0; j < nelem; j++) {
            uint8_t pcol = pm[H1(j >> 1)] >> ((j & 1) * 4);
            uint32_t *acc = dst + H4(j);

            *acc = fn(zn_i, zm[H4(j)], *acc, prow & pcol, neg);
        }
    }
}
1168 
typedef uint64_t IMOPFn64(uint64_t, uint64_t, uint64_t, uint8_t, bool);

/*
 * Driver for the 64-bit-accumulator integer outer products.
 * For every (row, col) of the tile, with no predicate short-cut:
 *     ZA[row][col] = FN(ZN[row], ZM[col], ZA[row][col],
 *                       rowpred & colpred, neg)
 * FN itself applies the combined predicate to its lanes.  Each 64-bit
 * element owns one full predicate byte.
 */
static inline void do_imopa_d(uint64_t *za, uint64_t *zn, uint64_t *zm,
                              uint8_t *pn, uint8_t *pm,
                              uint32_t desc, IMOPFn64 *fn)
{
    intptr_t nelem = simd_oprsz(desc) / 8;
    bool neg = simd_data(desc);
    intptr_t i, j;

    for (i = 0; i < nelem; i++) {
        uint8_t prow = pn[H1(i)];
        uint64_t *dst = za + tile_vslice_index(i);
        uint64_t zn_i = zn[i];

        for (j = 0; j < nelem; j++) {
            dst[j] = fn(zn_i, zm[j], dst[j], prow & pm[H1(j)], neg);
        }
    }
}
1190 
/*
 * DEF_IMOP_32(NAME, NTYPE, MTYPE) defines NAME(), the element kernel for
 * the 32-bit integer outer products.  N and M each pack four 8-bit lanes.
 * Lanes of N whose predicate bit in P is clear are masked to zero.  Each
 * lane of N is then read as NTYPE and the matching lane of M as MTYPE,
 * and the four products are summed modulo 2^32.  The sum is added to A,
 * or subtracted from it when NEG is set.
 */
#define DEF_IMOP_32(NAME, NTYPE, MTYPE)                                     \
static uint32_t NAME(uint32_t n, uint32_t m, uint32_t a, uint8_t p, bool neg) \
{                                                                           \
    uint32_t sum = 0;                                                       \
    int shift;                                                              \
    /* Inactive lanes of N become 0 and so contribute nothing. */           \
    n &= expand_pred_b(p);                                                  \
    for (shift = 0; shift < 32; shift += 8) {                               \
        sum += (NTYPE)(n >> shift) * (MTYPE)(m >> shift);                   \
    }                                                                       \
    return neg ? a - sum : a + sum;                                         \
}

/*
 * DEF_IMOP_64(NAME, NTYPE, MTYPE) is the 64-bit counterpart of
 * DEF_IMOP_32: four 16-bit lanes, predicate expanded with expand_pred_h,
 * products summed modulo 2^64.
 */
#define DEF_IMOP_64(NAME, NTYPE, MTYPE)                                     \
static uint64_t NAME(uint64_t n, uint64_t m, uint64_t a, uint8_t p, bool neg) \
{                                                                           \
    uint64_t sum = 0;                                                       \
    int shift;                                                              \
    /* Inactive lanes of N become 0 and so contribute nothing. */           \
    n &= expand_pred_h(p);                                                  \
    for (shift = 0; shift < 64; shift += 16) {                              \
        /* Widen first: a 16x16-bit product can overflow int. */            \
        sum += (int64_t)(NTYPE)(n >> shift) * (MTYPE)(m >> shift);          \
    }                                                                       \
    return neg ? a - sum : a + sum;                                         \
}

DEF_IMOP_32(smopa_s, int8_t, int8_t)
DEF_IMOP_32(umopa_s, uint8_t, uint8_t)
DEF_IMOP_32(sumopa_s, int8_t, uint8_t)
DEF_IMOP_32(usmopa_s, uint8_t, int8_t)

DEF_IMOP_64(smopa_d, int16_t, int16_t)
DEF_IMOP_64(umopa_d, uint16_t, uint16_t)
DEF_IMOP_64(sumopa_d, int16_t, uint16_t)
DEF_IMOP_64(usmopa_d, uint16_t, int16_t)
1226 
1227 #define DEF_IMOPH(NAME, S) \
1228     void HELPER(sme_##NAME##_##S)(void *vza, void *vzn, void *vzm,          \
1229                                   void *vpn, void *vpm, uint32_t desc)      \
1230     { do_imopa_##S(vza, vzn, vzm, vpn, vpm, desc, NAME##_##S); }
1231 
1232 DEF_IMOPH(smopa, s)
1233 DEF_IMOPH(umopa, s)
1234 DEF_IMOPH(sumopa, s)
1235 DEF_IMOPH(usmopa, s)
1236 
1237 DEF_IMOPH(smopa, d)
1238 DEF_IMOPH(umopa, d)
1239 DEF_IMOPH(sumopa, d)
1240 DEF_IMOPH(usmopa, d)
1241