xref: /openbmc/qemu/target/arm/tcg/sme_helper.c (revision ec08d9a51e6af3cd3edbdbf2ca6e97a1e2b5f0d1)
1  /*
2   * ARM SME Operations
3   *
4   * Copyright (c) 2022 Linaro, Ltd.
5   *
6   * This library is free software; you can redistribute it and/or
7   * modify it under the terms of the GNU Lesser General Public
8   * License as published by the Free Software Foundation; either
9   * version 2.1 of the License, or (at your option) any later version.
10   *
11   * This library is distributed in the hope that it will be useful,
12   * but WITHOUT ANY WARRANTY; without even the implied warranty of
13   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14   * Lesser General Public License for more details.
15   *
16   * You should have received a copy of the GNU Lesser General Public
17   * License along with this library; if not, see <http://www.gnu.org/licenses/>.
18   */
19  
20  #include "qemu/osdep.h"
21  #include "cpu.h"
22  #include "internals.h"
23  #include "tcg/tcg-gvec-desc.h"
24  #include "exec/helper-proto.h"
25  #include "exec/cpu_ldst.h"
26  #include "exec/exec-all.h"
27  #include "qemu/int128.h"
28  #include "fpu/softfloat.h"
29  #include "vec_internal.h"
30  #include "sve_ldst_internal.h"
31  
/*
 * TCG helper for writes to SVCR: hands @val and @mask straight to
 * aarch64_set_svcr(), which does all of the work.  @mask presumably
 * selects which bits of @val take effect -- see that function.
 */
void helper_set_svcr(CPUARMState *env,
                     uint32_t val,
                     uint32_t mask)
{
    aarch64_set_svcr(env, val, mask);
}
36  
/*
 * Zero the 64-bit ZA tiles selected by @imm.
 *
 * @imm: 8-bit mask; bit n selects tile ZAn.D.
 * @svl: used both as the number of ZA rows visited and as the number of
 *       bytes cleared in each row (presumably the streaming vector
 *       length in bytes).
 */
void helper_sme_zero(CPUARMState *env, uint32_t imm, uint32_t svl)
{
    /*
     * All eight ZAn.D tiles together make up the whole of ZA, so clear the
     * entire array.  This also zeroes storage beyond SVL, which falls under
     * the CONSTRAINED UNPREDICTABLE behaviour for those parts.
     */
    if (imm == 0xff) {
        memset(env->zarray, 0, sizeof(env->zarray));
        return;
    }

    /*
     * Row m of tile ZAn.D lives in ZA[n + 8 * m], so the tiles are
     * interleaved a row at a time and row r belongs to tile r % 8.
     */
    for (uint32_t row = 0; row < svl; row++) {
        uint32_t tile = row & 7;

        if ((imm >> tile) & 1) {
            memset(&env->zarray[row], 0, svl);
        }
    }
}
61  
62  
/*
 * tile_vslice_index(i): given a pointer of element type T to the first
 * element of a vertical tile slice, the array index (in units of T) of
 * element i of that slice.
 *
 * The result does not depend on sizeof(T).  For W-byte elements the
 * tiles are W-way interleaved, so consecutive rows of one tile sit W ZA
 * rows apart, a byte distance of i * W * sizeof(ARMVectorReg).  Turning
 * that byte distance into an index of T divides by W again, which
 * leaves i * sizeof(ARMVectorReg).
 */
#define tile_vslice_index(i) (sizeof(ARMVectorReg) * (i))
75  
/*
 * tile_vslice_offset(byteoff): byte distance in the ZA storage from the
 * first element of a vertical tile slice to the element that starts
 * @byteoff bytes into that slice.  @byteoff must be a multiple of the
 * element size.
 *
 * As with tile_vslice_index(), the element size drops out because of the
 * tile interleaving.  With 1-byte elements each ZA row holds one byte of
 * the slice, so slice byte 8 is 8 ZA rows down.  With 8-byte elements
 * each row holds 8 bytes of the slice, so slice byte 8 is row 1 of the
 * tile.  There are 8 interleaved 64-bit tiles, so that is again 8 ZA rows
 * down.  Other element sizes work out the same way, giving
 * byteoff * sizeof(ARMVectorReg).
 */
#define tile_vslice_offset(byteoff) (sizeof(ARMVectorReg) * (byteoff))
93  
94  
95  /*
96   * Move Zreg vector to ZArray column.
97   */
98  #define DO_MOVA_C(NAME, TYPE, H)                                        \
99  void HELPER(NAME)(void *za, void *vn, void *vg, uint32_t desc)          \
100  {                                                                       \
101      int i, oprsz = simd_oprsz(desc);                                    \
102      for (i = 0; i < oprsz; ) {                                          \
103          uint16_t pg = *(uint16_t *)(vg + H1_2(i >> 3));                 \
104          do {                                                            \
105              if (pg & 1) {                                               \
106                  *(TYPE *)(za + tile_vslice_offset(i)) = *(TYPE *)(vn + H(i)); \
107              }                                                           \
108              i += sizeof(TYPE);                                          \
109              pg >>= sizeof(TYPE);                                        \
110          } while (i & 15);                                               \
111      }                                                                   \
112  }
113  
DO_MOVA_C(sme_mova_cz_b,uint8_t,H1)114  DO_MOVA_C(sme_mova_cz_b, uint8_t, H1)
115  DO_MOVA_C(sme_mova_cz_h, uint16_t, H1_2)
116  DO_MOVA_C(sme_mova_cz_s, uint32_t, H1_4)
117  
118  void HELPER(sme_mova_cz_d)(void *za, void *vn, void *vg, uint32_t desc)
119  {
120      int i, oprsz = simd_oprsz(desc) / 8;
121      uint8_t *pg = vg;
122      uint64_t *n = vn;
123      uint64_t *a = za;
124  
125      for (i = 0; i < oprsz; i++) {
126          if (pg[H1(i)] & 1) {
127              a[tile_vslice_index(i)] = n[i];
128          }
129      }
130  }
131  
HELPER(sme_mova_cz_q)132  void HELPER(sme_mova_cz_q)(void *za, void *vn, void *vg, uint32_t desc)
133  {
134      int i, oprsz = simd_oprsz(desc) / 16;
135      uint16_t *pg = vg;
136      Int128 *n = vn;
137      Int128 *a = za;
138  
139      /*
140       * Int128 is used here simply to copy 16 bytes, and to simplify
141       * the address arithmetic.
142       */
143      for (i = 0; i < oprsz; i++) {
144          if (pg[H2(i)] & 1) {
145              a[tile_vslice_index(i)] = n[i];
146          }
147      }
148  }
149  
150  #undef DO_MOVA_C
151  
152  /*
153   * Move ZArray column to Zreg vector.
154   */
155  #define DO_MOVA_Z(NAME, TYPE, H)                                        \
156  void HELPER(NAME)(void *vd, void *za, void *vg, uint32_t desc)          \
157  {                                                                       \
158      int i, oprsz = simd_oprsz(desc);                                    \
159      for (i = 0; i < oprsz; ) {                                          \
160          uint16_t pg = *(uint16_t *)(vg + H1_2(i >> 3));                 \
161          do {                                                            \
162              if (pg & 1) {                                               \
163                  *(TYPE *)(vd + H(i)) = *(TYPE *)(za + tile_vslice_offset(i)); \
164              }                                                           \
165              i += sizeof(TYPE);                                          \
166              pg >>= sizeof(TYPE);                                        \
167          } while (i & 15);                                               \
168      }                                                                   \
169  }
170  
DO_MOVA_Z(sme_mova_zc_b,uint8_t,H1)171  DO_MOVA_Z(sme_mova_zc_b, uint8_t, H1)
172  DO_MOVA_Z(sme_mova_zc_h, uint16_t, H1_2)
173  DO_MOVA_Z(sme_mova_zc_s, uint32_t, H1_4)
174  
175  void HELPER(sme_mova_zc_d)(void *vd, void *za, void *vg, uint32_t desc)
176  {
177      int i, oprsz = simd_oprsz(desc) / 8;
178      uint8_t *pg = vg;
179      uint64_t *d = vd;
180      uint64_t *a = za;
181  
182      for (i = 0; i < oprsz; i++) {
183          if (pg[H1(i)] & 1) {
184              d[i] = a[tile_vslice_index(i)];
185          }
186      }
187  }
188  
HELPER(sme_mova_zc_q)189  void HELPER(sme_mova_zc_q)(void *vd, void *za, void *vg, uint32_t desc)
190  {
191      int i, oprsz = simd_oprsz(desc) / 16;
192      uint16_t *pg = vg;
193      Int128 *d = vd;
194      Int128 *a = za;
195  
196      /*
197       * Int128 is used here simply to copy 16 bytes, and to simplify
198       * the address arithmetic.
199       */
200      for (i = 0; i < oprsz; i++, za += sizeof(ARMVectorReg)) {
201          if (pg[H2(i)] & 1) {
202              d[i] = a[tile_vslice_index(i)];
203          }
204      }
205  }
206  
207  #undef DO_MOVA_Z
208  
/*
 * ClearFn: zero the elements of a tile slice that occupy slice bytes
 * [off, off + len).  A horizontal slice is contiguous.  A vertical slice
 * places slice byte k at tile_vslice_offset(k), and there is one variant
 * per element size so each element is cleared with a single access.
 */
typedef void ClearFn(void *ptr, size_t off, size_t len);

static void clear_horizontal(void *ptr, size_t off, size_t len)
{
    memset(ptr + off, 0, len);
}

static void clear_vertical_b(void *vptr, size_t off, size_t len)
{
    const size_t end = off + len;

    for (size_t pos = off; pos < end; pos += 1) {
        *(uint8_t *)(vptr + tile_vslice_offset(pos)) = 0;
    }
}

static void clear_vertical_h(void *vptr, size_t off, size_t len)
{
    const size_t end = off + len;

    for (size_t pos = off; pos < end; pos += 2) {
        *(uint16_t *)(vptr + tile_vslice_offset(pos)) = 0;
    }
}

static void clear_vertical_s(void *vptr, size_t off, size_t len)
{
    const size_t end = off + len;

    for (size_t pos = off; pos < end; pos += 4) {
        *(uint32_t *)(vptr + tile_vslice_offset(pos)) = 0;
    }
}

static void clear_vertical_d(void *vptr, size_t off, size_t len)
{
    const size_t end = off + len;

    for (size_t pos = off; pos < end; pos += 8) {
        *(uint64_t *)(vptr + tile_vslice_offset(pos)) = 0;
    }
}

static void clear_vertical_q(void *vptr, size_t off, size_t len)
{
    const size_t end = off + len;

    for (size_t pos = off; pos < end; pos += 16) {
        memset(vptr + tile_vslice_offset(pos), 0, 16);
    }
}
254  
/*
 * CopyFn: copy @len bytes of linearly packed elements from @src into the
 * tile slice at @dst.  A horizontal slice is a plain memcpy.  The
 * vertical variants place source byte k at tile_vslice_offset(k), moving
 * one element per access.
 */
typedef void CopyFn(void *dst, const void *src, size_t len);

static void copy_horizontal(void *dst, const void *src, size_t len)
{
    memcpy(dst, src, len);
}

static void copy_vertical_b(void *vdst, const void *vsrc, size_t len)
{
    const uint8_t *from = vsrc;
    uint8_t *to = vdst;
    const size_t count = len / sizeof(*from);

    for (size_t k = 0; k < count; k++) {
        to[tile_vslice_index(k)] = from[k];
    }
}

static void copy_vertical_h(void *vdst, const void *vsrc, size_t len)
{
    const uint16_t *from = vsrc;
    uint16_t *to = vdst;
    const size_t count = len / sizeof(*from);

    for (size_t k = 0; k < count; k++) {
        to[tile_vslice_index(k)] = from[k];
    }
}

static void copy_vertical_s(void *vdst, const void *vsrc, size_t len)
{
    const uint32_t *from = vsrc;
    uint32_t *to = vdst;
    const size_t count = len / sizeof(*from);

    for (size_t k = 0; k < count; k++) {
        to[tile_vslice_index(k)] = from[k];
    }
}

static void copy_vertical_d(void *vdst, const void *vsrc, size_t len)
{
    const uint64_t *from = vsrc;
    uint64_t *to = vdst;
    const size_t count = len / sizeof(*from);

    for (size_t k = 0; k < count; k++) {
        to[tile_vslice_index(k)] = from[k];
    }
}

static void copy_vertical_q(void *vdst, const void *vsrc, size_t len)
{
    for (size_t pos = 0; pos < len; pos += 16) {
        memcpy(vdst + tile_vslice_offset(pos), vsrc + pos, 16);
    }
}
316  
/*
 * Host and TLB primitives for vertical tile slice addressing.
 *
 * Each instance defines sme_<NAME>_v_host(), which accesses guest memory
 * through the host pointer @host, and sme_<NAME>_v_tlb(), which goes
 * through the TLB accessor passed as TLB.  @off is the byte offset of the
 * element within the slice.  The ZA side of the access is at
 * @za + tile_vslice_offset(@off).
 */

#define DO_LD(NAME, TYPE, HOST, TLB)                                        \
static inline void sme_##NAME##_v_host(void *za, intptr_t off, void *host)  \
{                                                                           \
    TYPE *dst = za + tile_vslice_offset(off);                               \
    *dst = HOST(host);                                                      \
}                                                                           \
static inline void sme_##NAME##_v_tlb(CPUARMState *env, void *za,           \
                        intptr_t off, target_ulong addr, uintptr_t ra)      \
{                                                                           \
    TYPE *dst = za + tile_vslice_offset(off);                               \
    *dst = TLB(env, useronly_clean_ptr(addr), ra);                          \
}

#define DO_ST(NAME, TYPE, HOST, TLB)                                        \
static inline void sme_##NAME##_v_host(void *za, intptr_t off, void *host)  \
{                                                                           \
    const TYPE *src = za + tile_vslice_offset(off);                         \
    HOST(host, *src);                                                       \
}                                                                           \
static inline void sme_##NAME##_v_tlb(CPUARMState *env, void *za,           \
                        intptr_t off, target_ulong addr, uintptr_t ra)      \
{                                                                           \
    const TYPE *src = za + tile_vslice_offset(off);                         \
    TLB(env, useronly_clean_ptr(addr), *src, ra);                           \
}
346  
/*
 * The ARMVectorReg elements are stored in host-endian 64-bit units.
 * For 128-bit quantities, the sequence defined by the Elem[] pseudocode
 * corresponds to storing the two 64-bit pieces in little-endian order,
 * i.e. the less significant half in ptr[0] and the more significant
 * half in ptr[1].
 *
 * The doubleword at the lower guest address therefore lives in ptr[BE].
 * For a little-endian access it is the less significant half; for a
 * big-endian access it is the more significant half.  The HNAME
 * functions take a plain (horizontal) byte offset.  The VNAME wrappers
 * adapt them to vertical slices by passing tile_vslice_offset(off).
 */
#define DO_LDQ(HNAME, VNAME, BE, HOST, TLB)                                 \
static inline void HNAME##_host(void *za, intptr_t off, void *host)         \
{                                                                           \
    uint64_t first = HOST(host);                                            \
    uint64_t second = HOST(host + 8);                                       \
    uint64_t *ptr = za + off;                                               \
    ptr[BE] = first;                                                        \
    ptr[!BE] = second;                                                      \
}                                                                           \
static inline void VNAME##_v_host(void *za, intptr_t off, void *host)       \
{                                                                           \
    HNAME##_host(za, tile_vslice_offset(off), host);                        \
}                                                                           \
static inline void HNAME##_tlb(CPUARMState *env, void *za, intptr_t off,    \
                               target_ulong addr, uintptr_t ra)             \
{                                                                           \
    uint64_t first = TLB(env, useronly_clean_ptr(addr), ra);                \
    uint64_t second = TLB(env, useronly_clean_ptr(addr + 8), ra);           \
    uint64_t *ptr = za + off;                                               \
    ptr[BE] = first;                                                        \
    ptr[!BE] = second;                                                      \
}                                                                           \
static inline void VNAME##_v_tlb(CPUARMState *env, void *za, intptr_t off,  \
                               target_ulong addr, uintptr_t ra)             \
{                                                                           \
    HNAME##_tlb(env, za, tile_vslice_offset(off), addr, ra);                \
}

#define DO_STQ(HNAME, VNAME, BE, HOST, TLB)                                 \
static inline void HNAME##_host(void *za, intptr_t off, void *host)         \
{                                                                           \
    const uint64_t *ptr = za + off;                                         \
    uint64_t first = ptr[BE];                                               \
    uint64_t second = ptr[!BE];                                             \
    HOST(host, first);                                                      \
    HOST(host + 8, second);                                                 \
}                                                                           \
static inline void VNAME##_v_host(void *za, intptr_t off, void *host)       \
{                                                                           \
    HNAME##_host(za, tile_vslice_offset(off), host);                        \
}                                                                           \
static inline void HNAME##_tlb(CPUARMState *env, void *za, intptr_t off,    \
                               target_ulong addr, uintptr_t ra)             \
{                                                                           \
    const uint64_t *ptr = za + off;                                         \
    uint64_t first = ptr[BE];                                               \
    uint64_t second = ptr[!BE];                                             \
    TLB(env, useronly_clean_ptr(addr), first, ra);                          \
    TLB(env, useronly_clean_ptr(addr + 8), second, ra);                     \
}                                                                           \
static inline void VNAME##_v_tlb(CPUARMState *env, void *za, intptr_t off,  \
                               target_ulong addr, uintptr_t ra)             \
{                                                                           \
    HNAME##_tlb(env, za, tile_vslice_offset(off), addr, ra);                \
}
400  
/* Byte elements: a single variant, no endianness. */
DO_LD(ld1b, uint8_t, ldub_p, cpu_ldub_data_ra)
DO_ST(st1b, uint8_t, stb_p, cpu_stb_data_ra)

/* Halfword elements. */
DO_LD(ld1h_be, uint16_t, lduw_be_p, cpu_lduw_be_data_ra)
DO_LD(ld1h_le, uint16_t, lduw_le_p, cpu_lduw_le_data_ra)
DO_ST(st1h_be, uint16_t, stw_be_p, cpu_stw_be_data_ra)
DO_ST(st1h_le, uint16_t, stw_le_p, cpu_stw_le_data_ra)

/* Word elements. */
DO_LD(ld1s_be, uint32_t, ldl_be_p, cpu_ldl_be_data_ra)
DO_LD(ld1s_le, uint32_t, ldl_le_p, cpu_ldl_le_data_ra)
DO_ST(st1s_be, uint32_t, stl_be_p, cpu_stl_be_data_ra)
DO_ST(st1s_le, uint32_t, stl_le_p, cpu_stl_le_data_ra)

/* Doubleword elements. */
DO_LD(ld1d_be, uint64_t, ldq_be_p, cpu_ldq_be_data_ra)
DO_LD(ld1d_le, uint64_t, ldq_le_p, cpu_ldq_le_data_ra)
DO_ST(st1d_be, uint64_t, stq_be_p, cpu_stq_be_data_ra)
DO_ST(st1d_le, uint64_t, stq_le_p, cpu_stq_le_data_ra)

/* Quadword elements, built from two doubleword accesses. */
DO_LDQ(sve_ld1qq_be, sme_ld1q_be, 1, ldq_be_p, cpu_ldq_be_data_ra)
DO_LDQ(sve_ld1qq_le, sme_ld1q_le, 0, ldq_le_p, cpu_ldq_le_data_ra)
DO_STQ(sve_st1qq_be, sme_st1q_be, 1, stq_be_p, cpu_stq_be_data_ra)
DO_STQ(sve_st1qq_le, sme_st1q_le, 0, stq_le_p, cpu_stq_le_data_ra)

#undef DO_LD
#undef DO_ST
#undef DO_LDQ
#undef DO_STQ
427  
428  /*
429   * Common helper for all contiguous predicated loads.
430   */
431  
432  static inline QEMU_ALWAYS_INLINE
433  void sme_ld1(CPUARMState *env, void *za, uint64_t *vg,
434               const target_ulong addr, uint32_t desc, const uintptr_t ra,
435               const int esz, uint32_t mtedesc, bool vertical,
436               sve_ldst1_host_fn *host_fn,
437               sve_ldst1_tlb_fn *tlb_fn,
438               ClearFn *clr_fn,
439               CopyFn *cpy_fn)
440  {
441      const intptr_t reg_max = simd_oprsz(desc);
442      const intptr_t esize = 1 << esz;
443      intptr_t reg_off, reg_last;
444      SVEContLdSt info;
445      void *host;
446      int flags;
447  
448      /* Find the active elements.  */
449      if (!sve_cont_ldst_elements(&info, addr, vg, reg_max, esz, esize)) {
450          /* The entire predicate was false; no load occurs.  */
451          clr_fn(za, 0, reg_max);
452          return;
453      }
454  
455      /* Probe the page(s).  Exit with exception for any invalid page. */
456      sve_cont_ldst_pages(&info, FAULT_ALL, env, addr, MMU_DATA_LOAD, ra);
457  
458      /* Handle watchpoints for all active elements. */
459      sve_cont_ldst_watchpoints(&info, env, vg, addr, esize, esize,
460                                BP_MEM_READ, ra);
461  
462      /*
463       * Handle mte checks for all active elements.
464       * Since TBI must be set for MTE, !mtedesc => !mte_active.
465       */
466      if (mtedesc) {
467          sve_cont_ldst_mte_check(&info, env, vg, addr, esize, esize,
468                                  mtedesc, ra);
469      }
470  
471      flags = info.page[0].flags | info.page[1].flags;
472      if (unlikely(flags != 0)) {
473  #ifdef CONFIG_USER_ONLY
474          g_assert_not_reached();
475  #else
476          /*
477           * At least one page includes MMIO.
478           * Any bus operation can fail with cpu_transaction_failed,
479           * which for ARM will raise SyncExternal.  Perform the load
480           * into scratch memory to preserve register state until the end.
481           */
482          ARMVectorReg scratch = { };
483  
484          reg_off = info.reg_off_first[0];
485          reg_last = info.reg_off_last[1];
486          if (reg_last < 0) {
487              reg_last = info.reg_off_split;
488              if (reg_last < 0) {
489                  reg_last = info.reg_off_last[0];
490              }
491          }
492  
493          do {
494              uint64_t pg = vg[reg_off >> 6];
495              do {
496                  if ((pg >> (reg_off & 63)) & 1) {
497                      tlb_fn(env, &scratch, reg_off, addr + reg_off, ra);
498                  }
499                  reg_off += esize;
500              } while (reg_off & 63);
501          } while (reg_off <= reg_last);
502  
503          cpy_fn(za, &scratch, reg_max);
504          return;
505  #endif
506      }
507  
508      /* The entire operation is in RAM, on valid pages. */
509  
510      reg_off = info.reg_off_first[0];
511      reg_last = info.reg_off_last[0];
512      host = info.page[0].host;
513  
514      if (!vertical) {
515          memset(za, 0, reg_max);
516      } else if (reg_off) {
517          clr_fn(za, 0, reg_off);
518      }
519  
520      set_helper_retaddr(ra);
521  
522      while (reg_off <= reg_last) {
523          uint64_t pg = vg[reg_off >> 6];
524          do {
525              if ((pg >> (reg_off & 63)) & 1) {
526                  host_fn(za, reg_off, host + reg_off);
527              } else if (vertical) {
528                  clr_fn(za, reg_off, esize);
529              }
530              reg_off += esize;
531          } while (reg_off <= reg_last && (reg_off & 63));
532      }
533  
534      clear_helper_retaddr();
535  
536      /*
537       * Use the slow path to manage the cross-page misalignment.
538       * But we know this is RAM and cannot trap.
539       */
540      reg_off = info.reg_off_split;
541      if (unlikely(reg_off >= 0)) {
542          tlb_fn(env, za, reg_off, addr + reg_off, ra);
543      }
544  
545      reg_off = info.reg_off_first[1];
546      if (unlikely(reg_off >= 0)) {
547          reg_last = info.reg_off_last[1];
548          host = info.page[1].host;
549  
550          set_helper_retaddr(ra);
551  
552          do {
553              uint64_t pg = vg[reg_off >> 6];
554              do {
555                  if ((pg >> (reg_off & 63)) & 1) {
556                      host_fn(za, reg_off, host + reg_off);
557                  } else if (vertical) {
558                      clr_fn(za, reg_off, esize);
559                  }
560                  reg_off += esize;
561              } while (reg_off & 63);
562          } while (reg_off <= reg_last);
563  
564          clear_helper_retaddr();
565      }
566  }
567  
568  static inline QEMU_ALWAYS_INLINE
sme_ld1_mte(CPUARMState * env,void * za,uint64_t * vg,target_ulong addr,uint32_t desc,uintptr_t ra,const int esz,bool vertical,sve_ldst1_host_fn * host_fn,sve_ldst1_tlb_fn * tlb_fn,ClearFn * clr_fn,CopyFn * cpy_fn)569  void sme_ld1_mte(CPUARMState *env, void *za, uint64_t *vg,
570                   target_ulong addr, uint32_t desc, uintptr_t ra,
571                   const int esz, bool vertical,
572                   sve_ldst1_host_fn *host_fn,
573                   sve_ldst1_tlb_fn *tlb_fn,
574                   ClearFn *clr_fn,
575                   CopyFn *cpy_fn)
576  {
577      uint32_t mtedesc = desc >> (SIMD_DATA_SHIFT + SVE_MTEDESC_SHIFT);
578      int bit55 = extract64(addr, 55, 1);
579  
580      /* Remove mtedesc from the normal sve descriptor. */
581      desc = extract32(desc, 0, SIMD_DATA_SHIFT + SVE_MTEDESC_SHIFT);
582  
583      /* Perform gross MTE suppression early. */
584      if (!tbi_check(mtedesc, bit55) ||
585          tcma_check(mtedesc, bit55, allocation_tag_from_addr(addr))) {
586          mtedesc = 0;
587      }
588  
589      sme_ld1(env, za, vg, addr, desc, ra, esz, mtedesc, vertical,
590              host_fn, tlb_fn, clr_fn, cpy_fn);
591  }
592  
593  #define DO_LD(L, END, ESZ)                                                 \
594  void HELPER(sme_ld1##L##END##_h)(CPUARMState *env, void *za, void *vg,     \
595                                   target_ulong addr, uint32_t desc)         \
596  {                                                                          \
597      sme_ld1(env, za, vg, addr, desc, GETPC(), ESZ, 0, false,               \
598              sve_ld1##L##L##END##_host, sve_ld1##L##L##END##_tlb,           \
599              clear_horizontal, copy_horizontal);                            \
600  }                                                                          \
601  void HELPER(sme_ld1##L##END##_v)(CPUARMState *env, void *za, void *vg,     \
602                                   target_ulong addr, uint32_t desc)         \
603  {                                                                          \
604      sme_ld1(env, za, vg, addr, desc, GETPC(), ESZ, 0, true,                \
605              sme_ld1##L##END##_v_host, sme_ld1##L##END##_v_tlb,             \
606              clear_vertical_##L, copy_vertical_##L);                        \
607  }                                                                          \
608  void HELPER(sme_ld1##L##END##_h_mte)(CPUARMState *env, void *za, void *vg, \
609                                       target_ulong addr, uint32_t desc)     \
610  {                                                                          \
611      sme_ld1_mte(env, za, vg, addr, desc, GETPC(), ESZ, false,              \
612                  sve_ld1##L##L##END##_host, sve_ld1##L##L##END##_tlb,       \
613                  clear_horizontal, copy_horizontal);                        \
614  }                                                                          \
615  void HELPER(sme_ld1##L##END##_v_mte)(CPUARMState *env, void *za, void *vg, \
616                                       target_ulong addr, uint32_t desc)     \
617  {                                                                          \
618      sme_ld1_mte(env, za, vg, addr, desc, GETPC(), ESZ, true,               \
619                  sme_ld1##L##END##_v_host, sme_ld1##L##END##_v_tlb,         \
620                  clear_vertical_##L, copy_vertical_##L);                    \
621  }
622  
623  DO_LD(b, , MO_8)
DO_LD(h,_be,MO_16)624  DO_LD(h, _be, MO_16)
625  DO_LD(h, _le, MO_16)
626  DO_LD(s, _be, MO_32)
627  DO_LD(s, _le, MO_32)
628  DO_LD(d, _be, MO_64)
629  DO_LD(d, _le, MO_64)
630  DO_LD(q, _be, MO_128)
631  DO_LD(q, _le, MO_128)
632  
633  #undef DO_LD
634  
635  /*
636   * Common helper for all contiguous predicated stores.
637   */
638  
639  static inline QEMU_ALWAYS_INLINE
640  void sme_st1(CPUARMState *env, void *za, uint64_t *vg,
641               const target_ulong addr, uint32_t desc, const uintptr_t ra,
642               const int esz, uint32_t mtedesc, bool vertical,
643               sve_ldst1_host_fn *host_fn,
644               sve_ldst1_tlb_fn *tlb_fn)
645  {
646      const intptr_t reg_max = simd_oprsz(desc);
647      const intptr_t esize = 1 << esz;
648      intptr_t reg_off, reg_last;
649      SVEContLdSt info;
650      void *host;
651      int flags;
652  
653      /* Find the active elements.  */
654      if (!sve_cont_ldst_elements(&info, addr, vg, reg_max, esz, esize)) {
655          /* The entire predicate was false; no store occurs.  */
656          return;
657      }
658  
659      /* Probe the page(s).  Exit with exception for any invalid page. */
660      sve_cont_ldst_pages(&info, FAULT_ALL, env, addr, MMU_DATA_STORE, ra);
661  
662      /* Handle watchpoints for all active elements. */
663      sve_cont_ldst_watchpoints(&info, env, vg, addr, esize, esize,
664                                BP_MEM_WRITE, ra);
665  
666      /*
667       * Handle mte checks for all active elements.
668       * Since TBI must be set for MTE, !mtedesc => !mte_active.
669       */
670      if (mtedesc) {
671          sve_cont_ldst_mte_check(&info, env, vg, addr, esize, esize,
672                                  mtedesc, ra);
673      }
674  
675      flags = info.page[0].flags | info.page[1].flags;
676      if (unlikely(flags != 0)) {
677  #ifdef CONFIG_USER_ONLY
678          g_assert_not_reached();
679  #else
680          /*
681           * At least one page includes MMIO.
682           * Any bus operation can fail with cpu_transaction_failed,
683           * which for ARM will raise SyncExternal.  We cannot avoid
684           * this fault and will leave with the store incomplete.
685           */
686          reg_off = info.reg_off_first[0];
687          reg_last = info.reg_off_last[1];
688          if (reg_last < 0) {
689              reg_last = info.reg_off_split;
690              if (reg_last < 0) {
691                  reg_last = info.reg_off_last[0];
692              }
693          }
694  
695          do {
696              uint64_t pg = vg[reg_off >> 6];
697              do {
698                  if ((pg >> (reg_off & 63)) & 1) {
699                      tlb_fn(env, za, reg_off, addr + reg_off, ra);
700                  }
701                  reg_off += esize;
702              } while (reg_off & 63);
703          } while (reg_off <= reg_last);
704          return;
705  #endif
706      }
707  
708      reg_off = info.reg_off_first[0];
709      reg_last = info.reg_off_last[0];
710      host = info.page[0].host;
711  
712      set_helper_retaddr(ra);
713  
714      while (reg_off <= reg_last) {
715          uint64_t pg = vg[reg_off >> 6];
716          do {
717              if ((pg >> (reg_off & 63)) & 1) {
718                  host_fn(za, reg_off, host + reg_off);
719              }
720              reg_off += 1 << esz;
721          } while (reg_off <= reg_last && (reg_off & 63));
722      }
723  
724      clear_helper_retaddr();
725  
726      /*
727       * Use the slow path to manage the cross-page misalignment.
728       * But we know this is RAM and cannot trap.
729       */
730      reg_off = info.reg_off_split;
731      if (unlikely(reg_off >= 0)) {
732          tlb_fn(env, za, reg_off, addr + reg_off, ra);
733      }
734  
735      reg_off = info.reg_off_first[1];
736      if (unlikely(reg_off >= 0)) {
737          reg_last = info.reg_off_last[1];
738          host = info.page[1].host;
739  
740          set_helper_retaddr(ra);
741  
742          do {
743              uint64_t pg = vg[reg_off >> 6];
744              do {
745                  if ((pg >> (reg_off & 63)) & 1) {
746                      host_fn(za, reg_off, host + reg_off);
747                  }
748                  reg_off += 1 << esz;
749              } while (reg_off & 63);
750          } while (reg_off <= reg_last);
751  
752          clear_helper_retaddr();
753      }
754  }
755  
756  static inline QEMU_ALWAYS_INLINE
sme_st1_mte(CPUARMState * env,void * za,uint64_t * vg,target_ulong addr,uint32_t desc,uintptr_t ra,int esz,bool vertical,sve_ldst1_host_fn * host_fn,sve_ldst1_tlb_fn * tlb_fn)757  void sme_st1_mte(CPUARMState *env, void *za, uint64_t *vg, target_ulong addr,
758                   uint32_t desc, uintptr_t ra, int esz, bool vertical,
759                   sve_ldst1_host_fn *host_fn,
760                   sve_ldst1_tlb_fn *tlb_fn)
761  {
762      uint32_t mtedesc = desc >> (SIMD_DATA_SHIFT + SVE_MTEDESC_SHIFT);
763      int bit55 = extract64(addr, 55, 1);
764  
765      /* Remove mtedesc from the normal sve descriptor. */
766      desc = extract32(desc, 0, SIMD_DATA_SHIFT + SVE_MTEDESC_SHIFT);
767  
768      /* Perform gross MTE suppression early. */
769      if (!tbi_check(mtedesc, bit55) ||
770          tcma_check(mtedesc, bit55, allocation_tag_from_addr(addr))) {
771          mtedesc = 0;
772      }
773  
774      sme_st1(env, za, vg, addr, desc, ra, esz, mtedesc,
775              vertical, host_fn, tlb_fn);
776  }
777  
778  #define DO_ST(L, END, ESZ)                                                 \
779  void HELPER(sme_st1##L##END##_h)(CPUARMState *env, void *za, void *vg,     \
780                                   target_ulong addr, uint32_t desc)         \
781  {                                                                          \
782      sme_st1(env, za, vg, addr, desc, GETPC(), ESZ, 0, false,               \
783              sve_st1##L##L##END##_host, sve_st1##L##L##END##_tlb);          \
784  }                                                                          \
785  void HELPER(sme_st1##L##END##_v)(CPUARMState *env, void *za, void *vg,     \
786                                   target_ulong addr, uint32_t desc)         \
787  {                                                                          \
788      sme_st1(env, za, vg, addr, desc, GETPC(), ESZ, 0, true,                \
789              sme_st1##L##END##_v_host, sme_st1##L##END##_v_tlb);            \
790  }                                                                          \
791  void HELPER(sme_st1##L##END##_h_mte)(CPUARMState *env, void *za, void *vg, \
792                                       target_ulong addr, uint32_t desc)     \
793  {                                                                          \
794      sme_st1_mte(env, za, vg, addr, desc, GETPC(), ESZ, false,              \
795                  sve_st1##L##L##END##_host, sve_st1##L##L##END##_tlb);      \
796  }                                                                          \
797  void HELPER(sme_st1##L##END##_v_mte)(CPUARMState *env, void *za, void *vg, \
798                                       target_ulong addr, uint32_t desc)     \
799  {                                                                          \
800      sme_st1_mte(env, za, vg, addr, desc, GETPC(), ESZ, true,               \
801                  sme_st1##L##END##_v_host, sme_st1##L##END##_v_tlb);        \
802  }
803  
804  DO_ST(b, , MO_8)
DO_ST(h,_be,MO_16)805  DO_ST(h, _be, MO_16)
806  DO_ST(h, _le, MO_16)
807  DO_ST(s, _be, MO_32)
808  DO_ST(s, _le, MO_32)
809  DO_ST(d, _be, MO_64)
810  DO_ST(d, _le, MO_64)
811  DO_ST(q, _be, MO_128)
812  DO_ST(q, _le, MO_128)
813  
814  #undef DO_ST
815  
816  void HELPER(sme_addha_s)(void *vzda, void *vzn, void *vpn,
817                           void *vpm, uint32_t desc)
818  {
819      intptr_t row, col, oprsz = simd_oprsz(desc) / 4;
820      uint64_t *pn = vpn, *pm = vpm;
821      uint32_t *zda = vzda, *zn = vzn;
822  
823      for (row = 0; row < oprsz; ) {
824          uint64_t pa = pn[row >> 4];
825          do {
826              if (pa & 1) {
827                  for (col = 0; col < oprsz; ) {
828                      uint64_t pb = pm[col >> 4];
829                      do {
830                          if (pb & 1) {
831                              zda[tile_vslice_index(row) + H4(col)] += zn[H4(col)];
832                          }
833                          pb >>= 4;
834                      } while (++col & 15);
835                  }
836              }
837              pa >>= 4;
838          } while (++row & 15);
839      }
840  }
841  
HELPER(sme_addha_d)842  void HELPER(sme_addha_d)(void *vzda, void *vzn, void *vpn,
843                           void *vpm, uint32_t desc)
844  {
845      intptr_t row, col, oprsz = simd_oprsz(desc) / 8;
846      uint8_t *pn = vpn, *pm = vpm;
847      uint64_t *zda = vzda, *zn = vzn;
848  
849      for (row = 0; row < oprsz; ++row) {
850          if (pn[H1(row)] & 1) {
851              for (col = 0; col < oprsz; ++col) {
852                  if (pm[H1(col)] & 1) {
853                      zda[tile_vslice_index(row) + col] += zn[col];
854                  }
855              }
856          }
857      }
858  }
859  
HELPER(sme_addva_s)860  void HELPER(sme_addva_s)(void *vzda, void *vzn, void *vpn,
861                           void *vpm, uint32_t desc)
862  {
863      intptr_t row, col, oprsz = simd_oprsz(desc) / 4;
864      uint64_t *pn = vpn, *pm = vpm;
865      uint32_t *zda = vzda, *zn = vzn;
866  
867      for (row = 0; row < oprsz; ) {
868          uint64_t pa = pn[row >> 4];
869          do {
870              if (pa & 1) {
871                  uint32_t zn_row = zn[H4(row)];
872                  for (col = 0; col < oprsz; ) {
873                      uint64_t pb = pm[col >> 4];
874                      do {
875                          if (pb & 1) {
876                              zda[tile_vslice_index(row) + H4(col)] += zn_row;
877                          }
878                          pb >>= 4;
879                      } while (++col & 15);
880                  }
881              }
882              pa >>= 4;
883          } while (++row & 15);
884      }
885  }
886  
HELPER(sme_addva_d)887  void HELPER(sme_addva_d)(void *vzda, void *vzn, void *vpn,
888                           void *vpm, uint32_t desc)
889  {
890      intptr_t row, col, oprsz = simd_oprsz(desc) / 8;
891      uint8_t *pn = vpn, *pm = vpm;
892      uint64_t *zda = vzda, *zn = vzn;
893  
894      for (row = 0; row < oprsz; ++row) {
895          if (pn[H1(row)] & 1) {
896              uint64_t zn_row = zn[row];
897              for (col = 0; col < oprsz; ++col) {
898                  if (pm[H1(col)] & 1) {
899                      zda[tile_vslice_index(row) + col] += zn_row;
900                  }
901              }
902          }
903      }
904  }
905  
HELPER(sme_fmopa_s)906  void HELPER(sme_fmopa_s)(void *vza, void *vzn, void *vzm, void *vpn,
907                           void *vpm, void *vst, uint32_t desc)
908  {
909      intptr_t row, col, oprsz = simd_maxsz(desc);
910      uint32_t neg = simd_data(desc) << 31;
911      uint16_t *pn = vpn, *pm = vpm;
912      float_status fpst;
913  
914      /*
915       * Make a copy of float_status because this operation does not
916       * update the cumulative fp exception status.  It also produces
917       * default nans.
918       */
919      fpst = *(float_status *)vst;
920      set_default_nan_mode(true, &fpst);
921  
922      for (row = 0; row < oprsz; ) {
923          uint16_t pa = pn[H2(row >> 4)];
924          do {
925              if (pa & 1) {
926                  void *vza_row = vza + tile_vslice_offset(row);
927                  uint32_t n = *(uint32_t *)(vzn + H1_4(row)) ^ neg;
928  
929                  for (col = 0; col < oprsz; ) {
930                      uint16_t pb = pm[H2(col >> 4)];
931                      do {
932                          if (pb & 1) {
933                              uint32_t *a = vza_row + H1_4(col);
934                              uint32_t *m = vzm + H1_4(col);
935                              *a = float32_muladd(n, *m, *a, 0, &fpst);
936                          }
937                          col += 4;
938                          pb >>= 4;
939                      } while (col & 15);
940                  }
941              }
942              row += 4;
943              pa >>= 4;
944          } while (row & 15);
945      }
946  }
947  
HELPER(sme_fmopa_d)948  void HELPER(sme_fmopa_d)(void *vza, void *vzn, void *vzm, void *vpn,
949                           void *vpm, void *vst, uint32_t desc)
950  {
951      intptr_t row, col, oprsz = simd_oprsz(desc) / 8;
952      uint64_t neg = (uint64_t)simd_data(desc) << 63;
953      uint64_t *za = vza, *zn = vzn, *zm = vzm;
954      uint8_t *pn = vpn, *pm = vpm;
955      float_status fpst = *(float_status *)vst;
956  
957      set_default_nan_mode(true, &fpst);
958  
959      for (row = 0; row < oprsz; ++row) {
960          if (pn[H1(row)] & 1) {
961              uint64_t *za_row = &za[tile_vslice_index(row)];
962              uint64_t n = zn[row] ^ neg;
963  
964              for (col = 0; col < oprsz; ++col) {
965                  if (pm[H1(col)] & 1) {
966                      uint64_t *a = &za_row[col];
967                      *a = float64_muladd(n, zm[col], *a, 0, &fpst);
968                  }
969              }
970          }
971      }
972  }
973  
/*
 * Prepare a pair of packed 16-bit lanes for a widening outer product.
 *
 * PAIR holds two lanes, bits [15:0] and [31:16].  NEG is XORed in first;
 * callers in this file pass either 0 or simd_data(desc) * 0x80008000u.
 * Any lane whose governing predicate bit in PG is clear is then forced to
 * zero: bit 0 governs the low lane and bit 2 the high lane (one predicate
 * bit per byte).  No other bits of PG are examined.
 *
 * The pseudocode uses a conditional negate after the conditional zero;
 * negating unconditionally before zeroing gives the same result and is
 * simpler here.
 */
static inline uint32_t f16mop_adj_pair(uint32_t pair, uint32_t pg, uint32_t neg)
{
    uint32_t keep = 0;

    if (pg & 1) {
        keep |= 0x0000ffffu;
    }
    if (pg & 4) {
        keep |= 0xffff0000u;
    }
    return (pair ^ neg) & keep;
}
993  
f16_dotadd(float32 sum,uint32_t e1,uint32_t e2,float_status * s_f16,float_status * s_std,float_status * s_odd)994  static float32 f16_dotadd(float32 sum, uint32_t e1, uint32_t e2,
995                            float_status *s_f16, float_status *s_std,
996                            float_status *s_odd)
997  {
998      /*
999       * We need three different float_status for different parts of this
1000       * operation:
1001       *  - the input conversion of the float16 values must use the
1002       *    f16-specific float_status, so that the FPCR.FZ16 control is applied
1003       *  - operations on float32 including the final accumulation must use
1004       *    the normal float_status, so that FPCR.FZ is applied
1005       *  - we have pre-set-up copy of s_std which is set to round-to-odd,
1006       *    for the multiply (see below)
1007       */
1008      float64 e1r = float16_to_float64(e1 & 0xffff, true, s_f16);
1009      float64 e1c = float16_to_float64(e1 >> 16, true, s_f16);
1010      float64 e2r = float16_to_float64(e2 & 0xffff, true, s_f16);
1011      float64 e2c = float16_to_float64(e2 >> 16, true, s_f16);
1012      float64 t64;
1013      float32 t32;
1014  
1015      /*
1016       * The ARM pseudocode function FPDot performs both multiplies
1017       * and the add with a single rounding operation.  Emulate this
1018       * by performing the first multiply in round-to-odd, then doing
1019       * the second multiply as fused multiply-add, and rounding to
1020       * float32 all in one step.
1021       */
1022      t64 = float64_mul(e1r, e2r, s_odd);
1023      t64 = float64r32_muladd(e1c, e2c, t64, 0, s_std);
1024  
1025      /* This conversion is exact, because we've already rounded. */
1026      t32 = float64_to_float32(t64, s_std);
1027  
1028      /* The final accumulation step is not fused. */
1029      return float32_add(sum, t32, s_std);
1030  }
1031  
HELPER(sme_fmopa_h)1032  void HELPER(sme_fmopa_h)(void *vza, void *vzn, void *vzm, void *vpn,
1033                           void *vpm, CPUARMState *env, uint32_t desc)
1034  {
1035      intptr_t row, col, oprsz = simd_maxsz(desc);
1036      uint32_t neg = simd_data(desc) * 0x80008000u;
1037      uint16_t *pn = vpn, *pm = vpm;
1038      float_status fpst_odd, fpst_std, fpst_f16;
1039  
1040      /*
1041       * Make copies of fp_status and fp_status_f16, because this operation
1042       * does not update the cumulative fp exception status.  It also
1043       * produces default NaNs. We also need a second copy of fp_status with
1044       * round-to-odd -- see above.
1045       */
1046      fpst_f16 = env->vfp.fp_status_f16;
1047      fpst_std = env->vfp.fp_status;
1048      set_default_nan_mode(true, &fpst_std);
1049      set_default_nan_mode(true, &fpst_f16);
1050      fpst_odd = fpst_std;
1051      set_float_rounding_mode(float_round_to_odd, &fpst_odd);
1052  
1053      for (row = 0; row < oprsz; ) {
1054          uint16_t prow = pn[H2(row >> 4)];
1055          do {
1056              void *vza_row = vza + tile_vslice_offset(row);
1057              uint32_t n = *(uint32_t *)(vzn + H1_4(row));
1058  
1059              n = f16mop_adj_pair(n, prow, neg);
1060  
1061              for (col = 0; col < oprsz; ) {
1062                  uint16_t pcol = pm[H2(col >> 4)];
1063                  do {
1064                      if (prow & pcol & 0b0101) {
1065                          uint32_t *a = vza_row + H1_4(col);
1066                          uint32_t m = *(uint32_t *)(vzm + H1_4(col));
1067  
1068                          m = f16mop_adj_pair(m, pcol, 0);
1069                          *a = f16_dotadd(*a, n, m,
1070                                          &fpst_f16, &fpst_std, &fpst_odd);
1071                      }
1072                      col += 4;
1073                      pcol >>= 4;
1074                  } while (col & 15);
1075              }
1076              row += 4;
1077              prow >>= 4;
1078          } while (row & 15);
1079      }
1080  }
1081  
HELPER(sme_bfmopa)1082  void HELPER(sme_bfmopa)(void *vza, void *vzn, void *vzm,
1083                          void *vpn, void *vpm, CPUARMState *env, uint32_t desc)
1084  {
1085      intptr_t row, col, oprsz = simd_maxsz(desc);
1086      uint32_t neg = simd_data(desc) * 0x80008000u;
1087      uint16_t *pn = vpn, *pm = vpm;
1088      float_status fpst, fpst_odd;
1089  
1090      if (is_ebf(env, &fpst, &fpst_odd)) {
1091          for (row = 0; row < oprsz; ) {
1092              uint16_t prow = pn[H2(row >> 4)];
1093              do {
1094                  void *vza_row = vza + tile_vslice_offset(row);
1095                  uint32_t n = *(uint32_t *)(vzn + H1_4(row));
1096  
1097                  n = f16mop_adj_pair(n, prow, neg);
1098  
1099                  for (col = 0; col < oprsz; ) {
1100                      uint16_t pcol = pm[H2(col >> 4)];
1101                      do {
1102                          if (prow & pcol & 0b0101) {
1103                              uint32_t *a = vza_row + H1_4(col);
1104                              uint32_t m = *(uint32_t *)(vzm + H1_4(col));
1105  
1106                              m = f16mop_adj_pair(m, pcol, 0);
1107                              *a = bfdotadd_ebf(*a, n, m, &fpst, &fpst_odd);
1108                          }
1109                          col += 4;
1110                          pcol >>= 4;
1111                      } while (col & 15);
1112                  }
1113                  row += 4;
1114                  prow >>= 4;
1115              } while (row & 15);
1116          }
1117      } else {
1118          for (row = 0; row < oprsz; ) {
1119              uint16_t prow = pn[H2(row >> 4)];
1120              do {
1121                  void *vza_row = vza + tile_vslice_offset(row);
1122                  uint32_t n = *(uint32_t *)(vzn + H1_4(row));
1123  
1124                  n = f16mop_adj_pair(n, prow, neg);
1125  
1126                  for (col = 0; col < oprsz; ) {
1127                      uint16_t pcol = pm[H2(col >> 4)];
1128                      do {
1129                          if (prow & pcol & 0b0101) {
1130                              uint32_t *a = vza_row + H1_4(col);
1131                              uint32_t m = *(uint32_t *)(vzm + H1_4(col));
1132  
1133                              m = f16mop_adj_pair(m, pcol, 0);
1134                              *a = bfdotadd(*a, n, m, &fpst);
1135                          }
1136                          col += 4;
1137                          pcol >>= 4;
1138                      } while (col & 15);
1139                  }
1140                  row += 4;
1141                  prow >>= 4;
1142              } while (row & 15);
1143          }
1144      }
1145  }
1146  
/* Per-element kernel for 32-bit integer outer products. */
typedef uint32_t IMOPFn32(uint32_t n, uint32_t m, uint32_t a,
                          uint8_t pg, bool neg);

/*
 * Generic driver for the 32-bit integer outer products.
 *
 * FN is called for every (row, col) element of the tile, including ones
 * whose predicates are clear.  It receives zn[row], zm[col], the current
 * tile value, the AND of the row and column predicate nibbles, and the
 * negate flag from simd_data(desc).  Each 32-bit element owns 4 predicate
 * bits, so one predicate byte covers two elements: the even element uses
 * the low nibble and the odd element the high nibble.  Only the row
 * nibble is masked to 4 bits; the AND limits the result to that nibble.
 */
static inline void do_imopa_s(uint32_t *za, uint32_t *zn, uint32_t *zm,
                              uint8_t *pn, uint8_t *pm,
                              uint32_t desc, IMOPFn32 *fn)
{
    intptr_t nelem = simd_oprsz(desc) / 4;
    bool negate = simd_data(desc);
    intptr_t r, c;

    for (r = 0; r < nelem; r++) {
        int rshift = (r & 1) * 4;
        uint8_t rmask = (pn[H1(r >> 1)] >> rshift) & 0xf;
        uint32_t *dst = &za[tile_vslice_index(r)];
        uint32_t nval = zn[H4(r)];

        for (c = 0; c < nelem; c++) {
            int cshift = (c & 1) * 4;
            uint8_t cmask = pm[H1(c >> 1)] >> cshift;

            dst[H4(c)] = fn(nval, zm[H4(c)], dst[H4(c)], rmask & cmask, negate);
        }
    }
}
1168  
/* Per-element kernel for 64-bit integer outer products. */
typedef uint64_t IMOPFn64(uint64_t n, uint64_t m, uint64_t a,
                          uint8_t pg, bool neg);

/*
 * Generic driver for the 64-bit integer outer products.
 *
 * FN is called for every (row, col) element of the tile, including ones
 * whose predicates are clear.  It receives zn[row], zm[col], the current
 * tile value, the AND of the row and column predicate bytes (one full
 * predicate byte per 64-bit element), and the negate flag from
 * simd_data(desc).
 */
static inline void do_imopa_d(uint64_t *za, uint64_t *zn, uint64_t *zm,
                              uint8_t *pn, uint8_t *pm,
                              uint32_t desc, IMOPFn64 *fn)
{
    intptr_t nelem = simd_oprsz(desc) / 8;
    bool negate = simd_data(desc);
    intptr_t r, c;

    for (r = 0; r < nelem; r++) {
        uint8_t rmask = pn[H1(r)];
        uint64_t *dst = &za[tile_vslice_index(r)];
        uint64_t nval = zn[r];

        for (c = 0; c < nelem; c++) {
            uint8_t cmask = pm[H1(c)];

            dst[c] = fn(nval, zm[c], dst[c], rmask & cmask, negate);
        }
    }
}
1190  
/*
 * DEF_IMOP_32(NAME, NTYPE, MTYPE) defines NAME(), a kernel for
 * do_imopa_s(): the four 8-bit lanes of N (read as NTYPE) and M (read as
 * MTYPE) are multiplied pairwise and summed, and the sum is added to (or,
 * when NEG, subtracted from) A with 32-bit wraparound.  Lanes of N whose
 * predicate is inactive in P are first masked to zero via expand_pred_b(),
 * so a zero P leaves A unchanged.
 */
#define DEF_IMOP_32(NAME, NTYPE, MTYPE)                                     \
static uint32_t NAME(uint32_t n, uint32_t m, uint32_t a, uint8_t p, bool neg) \
{                                                                           \
    uint32_t dot = 0;                                                       \
    int sh;                                                                 \
                                                                            \
    /* Apply P to N as a mask, making the inactive elements 0. */           \
    n &= expand_pred_b(p);                                                  \
    for (sh = 0; sh < 32; sh += 8) {                                        \
        dot += (NTYPE)(n >> sh) * (MTYPE)(m >> sh);                         \
    }                                                                       \
    return neg ? a - dot : a + dot;                                         \
}
1203  
/*
 * DEF_IMOP_64(NAME, NTYPE, MTYPE) defines NAME(), a kernel for
 * do_imopa_d(): the four 16-bit lanes of N (read as NTYPE) and M (read as
 * MTYPE) are multiplied pairwise and summed, and the sum is added to (or,
 * when NEG, subtracted from) A with 64-bit wraparound.  The N lane is
 * widened to int64_t before the multiply so that 16x16-bit products cannot
 * overflow int.  Lanes of N whose predicate is inactive in P are first
 * masked to zero via expand_pred_h().
 */
#define DEF_IMOP_64(NAME, NTYPE, MTYPE)                                     \
static uint64_t NAME(uint64_t n, uint64_t m, uint64_t a, uint8_t p, bool neg) \
{                                                                           \
    uint64_t dot = 0;                                                       \
    int sh;                                                                 \
                                                                            \
    /* Apply P to N as a mask, making the inactive elements 0. */           \
    n &= expand_pred_h(p);                                                  \
    for (sh = 0; sh < 64; sh += 16) {                                       \
        dot += (int64_t)(NTYPE)(n >> sh) * (MTYPE)(m >> sh);                \
    }                                                                       \
    return neg ? a - dot : a + dot;                                         \
}
1216  
1217  DEF_IMOP_32(smopa_s, int8_t, int8_t)
1218  DEF_IMOP_32(umopa_s, uint8_t, uint8_t)
1219  DEF_IMOP_32(sumopa_s, int8_t, uint8_t)
1220  DEF_IMOP_32(usmopa_s, uint8_t, int8_t)
1221  
1222  DEF_IMOP_64(smopa_d, int16_t, int16_t)
1223  DEF_IMOP_64(umopa_d, uint16_t, uint16_t)
1224  DEF_IMOP_64(sumopa_d, int16_t, uint16_t)
1225  DEF_IMOP_64(usmopa_d, uint16_t, int16_t)
1226  
1227  #define DEF_IMOPH(NAME, S) \
1228      void HELPER(sme_##NAME##_##S)(void *vza, void *vzn, void *vzm,          \
1229                                    void *vpn, void *vpm, uint32_t desc)      \
1230      { do_imopa_##S(vza, vzn, vzm, vpn, vpm, desc, NAME##_##S); }
1231  
1232  DEF_IMOPH(smopa, s)
1233  DEF_IMOPH(umopa, s)
1234  DEF_IMOPH(sumopa, s)
1235  DEF_IMOPH(usmopa, s)
1236  
1237  DEF_IMOPH(smopa, d)
1238  DEF_IMOPH(umopa, d)
1239  DEF_IMOPH(sumopa, d)
1240  DEF_IMOPH(usmopa, d)
1241