xref: /openbmc/qemu/target/arm/tcg/vec_internal.h (revision 784155cd)
1 /*
2  * ARM AdvSIMD / SVE Vector Helpers
3  *
4  * Copyright (c) 2020 Linaro
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this library; if not, see <http://www.gnu.org/licenses/>.
18  */
19 
20 #ifndef TARGET_ARM_VEC_INTERNAL_H
21 #define TARGET_ARM_VEC_INTERNAL_H
22 
23 /*
24  * Note that vector data is stored in host-endian 64-bit chunks,
25  * so addressing units smaller than that needs a host-endian fixup.
26  *
27  * The H<N> macros are used when indexing an array of elements of size N.
28  *
29  * The H1_<N> macros are used when performing byte arithmetic and then
30  * casting the final pointer to a type of size N.
31  */
#if !HOST_BIG_ENDIAN
/* Little-endian host: indices and byte offsets are used unchanged. */
#define H1(x)   (x)
#define H1_2(x) (x)
#define H1_4(x) (x)
#define H2(x)   (x)
#define H4(x)   (x)
#else
/*
 * Big-endian host: XOR the low bits of the index (H<N>) or of the
 * byte offset (H1_<N>) to apply the fixup within each 64-bit chunk.
 */
#define H1(x)   ((x) ^ 7)
#define H1_2(x) ((x) ^ 6)
#define H1_4(x) ((x) ^ 4)
#define H2(x)   ((x) ^ 3)
#define H4(x)   ((x) ^ 1)
#endif
/*
 * 64-bit elements need no host-endian fixup.  H8 and H1_8 exist only so
 * that macro-generated functions can be handed a named macro instead of
 * an empty macro argument, which reads more clearly.
 */
#define H8(x)   (x)
#define H1_8(x) (x)
52 
/*
 * Expand active predicate bits to bytes, for byte elements.
 * The result is a plain table lookup: @byte indexes the 256-entry
 * expand_pred_b_data array, which is defined out of line.
 */
extern const uint64_t expand_pred_b_data[256];
static inline uint64_t expand_pred_b(uint8_t byte)
{
    const uint64_t *table = expand_pred_b_data;

    return table[byte];
}
61 
/*
 * Expand active predicate bits for half-word elements.
 * Only the bits of @byte selected by the mask 0x55 take part in the
 * lookup, so the table needs just 0x55 + 1 entries.
 */
extern const uint64_t expand_pred_h_data[0x55 + 1];
static inline uint64_t expand_pred_h(uint8_t byte)
{
    unsigned int idx = byte & 0x55;

    return expand_pred_h_data[idx];
}
68 
/*
 * Zero the part of the buffer at @vd that lies between byte offset
 * @opr_sz and byte offset @max_sz.
 *
 * Stores are done one 64-bit word at a time, so the cleared range is
 * rounded up to a whole number of 8-byte words.
 * NOTE(review): callers presumably pass sizes that are multiples of 8
 * -- confirm.  Nothing is written when @opr_sz >= @max_sz.
 */
static inline void clear_tail(void *vd, uintptr_t opr_sz, uintptr_t max_sz)
{
    uint64_t *tail = (uint64_t *)((char *)vd + opr_sz);
    uintptr_t ofs = opr_sz;

    while (ofs < max_sz) {
        *tail = 0;
        tail++;
        ofs += 8;
    }
}
78 
/*
 * Signed shift of a @bits-wide element held in a 32-bit container.
 *
 * @src:   value to shift
 * @shift: shift count; positive shifts left, negative shifts right
 * @bits:  element width in bits
 * @round: when true, right shifts round by adding back the last bit
 *         shifted out
 * @sat:   when non-NULL, left shifts saturate and *@sat is set to 1 on
 *         saturation; when NULL, the left-shifted value is simply
 *         truncated to @bits
 *
 * On saturation the result is (1u << (@bits - 1)) - 1 for non-negative
 * @src and 1u << (@bits - 1) for negative @src.
 */
static inline int32_t do_sqrshl_bhs(int32_t src, int32_t shift, int bits,
                                    bool round, uint32_t *sat)
{
    uint32_t limit;

    if (shift <= -bits) {
        /*
         * Only copies of the sign bit remain; rounding the sign bit
         * always produces 0.
         */
        return round ? 0 : src >> 31;
    }

    if (shift < 0) {
        if (!round) {
            return src >> -shift;
        }
        /* Stop one bit short, then fold the last dropped bit back in. */
        src >>= -shift - 1;
        return (src >> 1) + (src & 1);
    }

    if (shift < bits) {
        int32_t shifted = src << shift;

        if (bits == 32) {
            /* Full width: the shift was exact iff it can be undone. */
            if (!sat || shifted >> shift == src) {
                return shifted;
            }
        } else {
            int32_t narrowed = sextract32(shifted, 0, bits);

            /* Narrow width: exact iff sign-extension changes nothing. */
            if (!sat || shifted == narrowed) {
                return narrowed;
            }
        }
    } else if (!sat || src == 0) {
        /* Every bit is shifted out; only a non-zero source saturates. */
        return 0;
    }

    *sat = 1;
    limit = 1u << (bits - 1);
    if (src >= 0) {
        limit -= 1;
    }
    return limit;
}
113 
/*
 * Unsigned shift of a @bits-wide element held in a 32-bit container.
 *
 * @src:   value to shift
 * @shift: shift count; positive shifts left, negative shifts right
 * @bits:  element width in bits
 * @round: when true, right shifts round by adding back the last bit
 *         shifted out
 * @sat:   when non-NULL, left shifts saturate and *@sat is set to 1 on
 *         saturation; when NULL, the left-shifted value is simply
 *         truncated to @bits
 *
 * On saturation the result is the all-ones mask of @bits bits.
 */
static inline uint32_t do_uqrshl_bhs(uint32_t src, int32_t shift, int bits,
                                     bool round, uint32_t *sat)
{
    /*
     * With rounding, a right shift by exactly @bits can still round up
     * to 1, so the "everything is gone" threshold is one step further.
     */
    int all_gone = -(bits + round);

    if (shift <= all_gone) {
        return 0;
    }

    if (shift < 0) {
        if (!round) {
            return src >> -shift;
        }
        /* Stop one bit short, then fold the last dropped bit back in. */
        src >>= -shift - 1;
        return (src >> 1) + (src & 1);
    }

    if (shift < bits) {
        uint32_t shifted = src << shift;

        if (bits == 32) {
            /* Full width: the shift was exact iff it can be undone. */
            if (!sat || shifted >> shift == src) {
                return shifted;
            }
        } else {
            uint32_t narrowed = extract32(shifted, 0, bits);

            /* Narrow width: exact iff no bit landed above @bits. */
            if (!sat || shifted == narrowed) {
                return narrowed;
            }
        }
    } else if (!sat || src == 0) {
        /* Every bit is shifted out; only a non-zero source saturates. */
        return 0;
    }

    *sat = 1;
    return MAKE_64BIT_MASK(0, bits);
}
144 
/*
 * Shift a signed @src with an unsigned result.
 *
 * When saturation is enabled (@sat non-NULL) a negative @src clamps to
 * 0 and sets *@sat to 1.  In every other case the work is delegated to
 * do_uqrshl_bhs() with the same arguments.
 */
static inline int32_t do_suqrshl_bhs(int32_t src, int32_t shift, int bits,
                                     bool round, uint32_t *sat)
{
    if (!sat || src >= 0) {
        return do_uqrshl_bhs(src, shift, bits, round, sat);
    }

    *sat = 1;
    return 0;
}
154 
/*
 * Signed shift of a 64-bit element.
 *
 * @src:   value to shift
 * @shift: shift count; positive shifts left, negative shifts right
 * @round: when true, right shifts round by adding back the last bit
 *         shifted out
 * @sat:   when non-NULL, left shifts saturate to INT64_MIN/INT64_MAX
 *         and *@sat is set to 1 on saturation; when NULL, overflowing
 *         bits are simply discarded
 */
static inline int64_t do_sqrshl_d(int64_t src, int64_t shift,
                                  bool round, uint32_t *sat)
{
    if (shift <= -64) {
        /*
         * Only copies of the sign bit remain; rounding the sign bit
         * always produces 0.
         */
        return round ? 0 : src >> 63;
    }

    if (shift < 0) {
        if (!round) {
            return src >> -shift;
        }
        /* Stop one bit short, then fold the last dropped bit back in. */
        src >>= -shift - 1;
        return (src >> 1) + (src & 1);
    }

    if (shift < 64) {
        int64_t shifted = src << shift;

        /* The shift was exact iff shifting back recovers the source. */
        if (!sat || shifted >> shift == src) {
            return shifted;
        }
    } else if (!sat || src == 0) {
        /* Every bit is shifted out; only a non-zero source saturates. */
        return 0;
    }

    *sat = 1;
    if (src < 0) {
        return INT64_MIN;
    }
    return INT64_MAX;
}
182 
/*
 * Unsigned shift of a 64-bit element.
 *
 * @src:   value to shift
 * @shift: shift count; positive shifts left, negative shifts right
 * @round: when true, right shifts round by adding back the last bit
 *         shifted out
 * @sat:   when non-NULL, left shifts saturate to UINT64_MAX and *@sat
 *         is set to 1 on saturation; when NULL, overflowing bits are
 *         simply discarded
 */
static inline uint64_t do_uqrshl_d(uint64_t src, int64_t shift,
                                   bool round, uint32_t *sat)
{
    /*
     * With rounding, a right shift by exactly 64 can still round up
     * to 1, so the "everything is gone" threshold is one step further.
     */
    int all_gone = -(64 + round);

    if (shift <= all_gone) {
        return 0;
    }

    if (shift < 0) {
        if (!round) {
            return src >> -shift;
        }
        /* Stop one bit short, then fold the last dropped bit back in. */
        src >>= -shift - 1;
        return (src >> 1) + (src & 1);
    }

    if (shift < 64) {
        uint64_t shifted = src << shift;

        /* The shift was exact iff shifting back recovers the source. */
        if (!sat || shifted >> shift == src) {
            return shifted;
        }
    } else if (!sat || src == 0) {
        /* Every bit is shifted out; only a non-zero source saturates. */
        return 0;
    }

    *sat = 1;
    return UINT64_MAX;
}
206 
/*
 * Shift a signed 64-bit @src with an unsigned result.
 *
 * When saturation is enabled (@sat non-NULL) a negative @src clamps to
 * 0 and sets *@sat to 1.  In every other case the work is delegated to
 * do_uqrshl_d() with the same arguments.
 */
static inline int64_t do_suqrshl_d(int64_t src, int64_t shift,
                                   bool round, uint32_t *sat)
{
    if (!sat || src >= 0) {
        return do_uqrshl_d(src, shift, round, sat);
    }

    *sat = 1;
    return 0;
}
216 
/*
 * Out-of-line helpers, declared here and defined elsewhere; one variant
 * per element type (int8_t, int16_t, int32_t, int64_t).
 *
 * NOTE(review): the prototypes carry no parameter names.  Judging by the
 * function names these presumably implement an SQRDMLAH-style
 * multiply-accumulate on three operands, with the two bool arguments
 * selecting variants and the uint32_t pointer (present only on the _h
 * and _s forms) reporting saturation -- confirm against the definitions.
 */
int8_t do_sqrdmlah_b(int8_t, int8_t, int8_t, bool, bool);
int16_t do_sqrdmlah_h(int16_t, int16_t, int16_t, bool, bool, uint32_t *);
int32_t do_sqrdmlah_s(int32_t, int32_t, int32_t, bool, bool, uint32_t *);
int64_t do_sqrdmlah_d(int64_t, int64_t, int64_t, bool, bool);
221 
/**
 * bfdotadd:
 * @sum: addend
 * @e1, @e2: multiplicand vectors
 *
 * BFloat16 2-way dot product of @e1 & @e2, accumulating with @sum.
 * The @e1 and @e2 operands correspond to the 32-bit source vector
 * slots and contain two BFloat16 values each.
 *
 * Corresponds to the ARM pseudocode function BFDotAdd.
 *
 * Returns: a float32; presumably @sum plus the two-way dot product --
 * the exact rounding behaviour is set by the out-of-line definition
 * (confirm there).
 */
float32 bfdotadd(float32 sum, uint32_t e1, uint32_t e2);
234 
235 #endif /* TARGET_ARM_VEC_INTERNAL_H */
236