/*
 * ARM AdvSIMD / SVE Vector Helpers
 *
 * Copyright (c) 2020 Linaro
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, see <http://www.gnu.org/licenses/>.
 */

#ifndef TARGET_ARM_VEC_INTERNAL_H
#define TARGET_ARM_VEC_INTERNAL_H

#include "fpu/softfloat.h"

typedef struct CPUArchState CPUARMState;

/*
 * Note that vector data is stored in host-endian 64-bit chunks,
 * so addressing units smaller than that needs a host-endian fixup.
 *
 * The H<N> macros are used when indexing an array of elements of size N.
 *
 * The H1_<N> macros are used when performing byte arithmetic and then
 * casting the final pointer to a type of size N.
 */
#if HOST_BIG_ENDIAN
#define H1(x) ((x) ^ 7)
#define H1_2(x) ((x) ^ 6)
#define H1_4(x) ((x) ^ 4)
#define H2(x) ((x) ^ 3)
#define H4(x) ((x) ^ 1)
#else
#define H1(x) (x)
#define H1_2(x) (x)
#define H1_4(x) (x)
#define H2(x) (x)
#define H4(x) (x)
#endif
/*
 * Access to 64-bit elements isn't host-endian dependent; we provide H8
 * and H1_8 so that when a function is being generated from a macro we
 * can pass these rather than an empty macro argument, for clarity.
 */
#define H8(x) (x)
#define H1_8(x) (x)
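
/*
 * Worked example (illustrative): element 0 occupies the least
 * significant 16 bits of its 64-bit chunk.  On a big-endian host
 * those are at byte offset 6, i.e. uint16_t array slot 3, which is
 * exactly H2(0) == 0 ^ 3 == 3; a helper therefore indexes it as
 *     ((uint16_t *)vd)[H2(0)]
 * On a little-endian host, H2(0) == 0 and no fixup is needed.
 */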

/*
 * Expand active predicate bits to bytes, for byte elements.
 */
extern const uint64_t expand_pred_b_data[256];
static inline uint64_t expand_pred_b(uint8_t byte)
{
    return expand_pred_b_data[byte];
}
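
/*
 * For example (illustrative): expand_pred_b(0x05) has bits 0 and 2
 * set, so bytes 0 and 2 of the result are active:
 *     expand_pred_b(0x05) == 0x0000000000ff00ffull
 */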

/* Similarly for half-word elements. */
extern const uint64_t expand_pred_h_data[0x55 + 1];
static inline uint64_t expand_pred_h(uint8_t byte)
{
    return expand_pred_h_data[byte & 0x55];
}
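
/*
 * For example (illustrative): only the even predicate bits are
 * significant for half-word elements, so 0x11 (bits 0 and 4) marks
 * half-words 0 and 2 active:
 *     expand_pred_h(0x11) == 0x0000ffff0000ffffull
 */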

static inline void clear_tail(void *vd, uintptr_t opr_sz, uintptr_t max_sz)
{
    uint64_t *d = vd + opr_sz;
    uintptr_t i;

    for (i = opr_sz; i < max_sz; i += 8) {
        *d++ = 0;
    }
}
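
/*
 * Typical use in a gvec out-of-line helper (sketch only; gvec_foo is
 * hypothetical, and simd_oprsz()/simd_maxsz() are the usual accessors
 * from tcg-gvec-desc.h):
 *
 *     void HELPER(gvec_foo)(void *vd, void *vn, uint32_t desc)
 *     {
 *         intptr_t i, opr_sz = simd_oprsz(desc);
 *         for (i = 0; i < opr_sz; i += 8) {
 *             ... operate on each 64-bit chunk ...
 *         }
 *         clear_tail(vd, opr_sz, simd_maxsz(desc));
 *     }
 */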

static inline int32_t do_sqrshl_bhs(int32_t src, int32_t shift, int bits,
                                    bool round, uint32_t *sat)
{
    if (shift <= -bits) {
        /* Rounding the sign bit always produces 0. */
        if (round) {
            return 0;
        }
        return src >> 31;
    } else if (shift < 0) {
        if (round) {
            src >>= -shift - 1;
            return (src >> 1) + (src & 1);
        }
        return src >> -shift;
    } else if (shift < bits) {
        int32_t val = src << shift;
        if (bits == 32) {
            if (!sat || val >> shift == src) {
                return val;
            }
        } else {
            int32_t extval = sextract32(val, 0, bits);
            if (!sat || val == extval) {
                return extval;
            }
        }
    } else if (!sat || src == 0) {
        return 0;
    }

    *sat = 1;
    return (1u << (bits - 1)) - (src >= 0);
}
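
/*
 * Worked examples (illustrative):
 *   do_sqrshl_bhs(5, -1, 8, true, NULL) == 3
 *     (5 >> 1 == 2, and the shifted-out 1 rounds the result up);
 *   with uint32_t sat = 0, do_sqrshl_bhs(100, 2, 8, false, &sat)
 *     overflows the 8-bit range (100 << 2 == 400), so it sets sat
 *     and returns INT8_MAX (127).
 */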

static inline uint32_t do_uqrshl_bhs(uint32_t src, int32_t shift, int bits,
                                     bool round, uint32_t *sat)
{
    if (shift <= -(bits + round)) {
        return 0;
    } else if (shift < 0) {
        if (round) {
            src >>= -shift - 1;
            return (src >> 1) + (src & 1);
        }
        return src >> -shift;
    } else if (shift < bits) {
        uint32_t val = src << shift;
        if (bits == 32) {
            if (!sat || val >> shift == src) {
                return val;
            }
        } else {
            uint32_t extval = extract32(val, 0, bits);
            if (!sat || val == extval) {
                return extval;
            }
        }
    } else if (!sat || src == 0) {
        return 0;
    }

    *sat = 1;
    return MAKE_64BIT_MASK(0, bits);
}
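
/*
 * Note the -(bits + round) bound above: with rounding enabled, a
 * right shift by exactly @bits can still round up into bit 0.  For
 * example (illustrative), do_uqrshl_bhs(0x80, -8, 8, true, NULL)
 * returns 1, because bit 7 of the source is the rounding bit.
 */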

static inline int32_t do_suqrshl_bhs(int32_t src, int32_t shift, int bits,
                                     bool round, uint32_t *sat)
{
    if (sat && src < 0) {
        *sat = 1;
        return 0;
    }
    return do_uqrshl_bhs(src, shift, bits, round, sat);
}

static inline int64_t do_sqrshl_d(int64_t src, int64_t shift,
                                  bool round, uint32_t *sat)
{
    if (shift <= -64) {
        /* Rounding the sign bit always produces 0. */
        if (round) {
            return 0;
        }
        return src >> 63;
    } else if (shift < 0) {
        if (round) {
            src >>= -shift - 1;
            return (src >> 1) + (src & 1);
        }
        return src >> -shift;
    } else if (shift < 64) {
        int64_t val = src << shift;
        if (!sat || val >> shift == src) {
            return val;
        }
    } else if (!sat || src == 0) {
        return 0;
    }

    *sat = 1;
    return src < 0 ? INT64_MIN : INT64_MAX;
}

static inline uint64_t do_uqrshl_d(uint64_t src, int64_t shift,
                                   bool round, uint32_t *sat)
{
    if (shift <= -(64 + round)) {
        return 0;
    } else if (shift < 0) {
        if (round) {
            src >>= -shift - 1;
            return (src >> 1) + (src & 1);
        }
        return src >> -shift;
    } else if (shift < 64) {
        uint64_t val = src << shift;
        if (!sat || val >> shift == src) {
            return val;
        }
    } else if (!sat || src == 0) {
        return 0;
    }

    *sat = 1;
    return UINT64_MAX;
}

static inline int64_t do_suqrshl_d(int64_t src, int64_t shift,
                                   bool round, uint32_t *sat)
{
    if (sat && src < 0) {
        *sat = 1;
        return 0;
    }
    return do_uqrshl_d(src, shift, round, sat);
}

int8_t do_sqrdmlah_b(int8_t, int8_t, int8_t, bool, bool);
int16_t do_sqrdmlah_h(int16_t, int16_t, int16_t, bool, bool, uint32_t *);
int32_t do_sqrdmlah_s(int32_t, int32_t, int32_t, bool, bool, uint32_t *);
int64_t do_sqrdmlah_d(int64_t, int64_t, int64_t, bool, bool);

#define do_ssat_b(val) MIN(MAX(val, INT8_MIN), INT8_MAX)
#define do_ssat_h(val) MIN(MAX(val, INT16_MIN), INT16_MAX)
#define do_ssat_s(val) MIN(MAX(val, INT32_MIN), INT32_MAX)
#define do_usat_b(val) MIN(MAX(val, 0), UINT8_MAX)
#define do_usat_h(val) MIN(MAX(val, 0), UINT16_MAX)
#define do_usat_s(val) MIN(MAX(val, 0), UINT32_MAX)
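
/*
 * For example (illustrative): do_ssat_b(200) == 127 (INT8_MAX) and
 * do_usat_b(-5) == 0, clamping to the target element range.
 */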

static inline uint64_t do_urshr(uint64_t x, unsigned sh)
{
    if (likely(sh < 64)) {
        return (x >> sh) + ((x >> (sh - 1)) & 1);
    } else if (sh == 64) {
        return x >> 63;
    } else {
        return 0;
    }
}

static inline int64_t do_srshr(int64_t x, unsigned sh)
{
    if (likely(sh < 64)) {
        return (x >> sh) + ((x >> (sh - 1)) & 1);
    } else {
        /* Rounding the sign bit always produces 0. */
        return 0;
    }
}
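
/*
 * Both shifts round to nearest, ties toward +infinity: the
 * shifted-out bit ((x >> (sh - 1)) & 1) supplies the rounding
 * increment.  For example (illustrative), do_urshr(7, 2) == 2,
 * since 7/4 == 1.75 rounds up.
 */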

/**
 * bfdotadd:
 * @sum: addend
 * @e1, @e2: multiplicand vectors
 * @fpst: floating-point status to use
 *
 * BFloat16 2-way dot product of @e1 & @e2, accumulating with @sum.
 * The @e1 and @e2 operands correspond to the 32-bit source vector
 * slots and contain two Bfloat16 values each.
 *
 * Corresponds to the ARM pseudocode function BFDotAdd, specialized
 * for the FPCR.EBF == 0 case.
 */
float32 bfdotadd(float32 sum, uint32_t e1, uint32_t e2, float_status *fpst);
/**
 * bfdotadd_ebf:
 * @sum: addend
 * @e1, @e2: multiplicand vectors
 * @fpst: floating-point status to use
 * @fpst_odd: floating-point status to use for round-to-odd operations
 *
 * BFloat16 2-way dot product of @e1 & @e2, accumulating with @sum.
 * The @e1 and @e2 operands correspond to the 32-bit source vector
 * slots and contain two Bfloat16 values each.
 *
 * Corresponds to the ARM pseudocode function BFDotAdd, specialized
 * for the FPCR.EBF == 1 case.
 */
float32 bfdotadd_ebf(float32 sum, uint32_t e1, uint32_t e2,
                     float_status *fpst, float_status *fpst_odd);

/**
 * is_ebf:
 * @env: CPU state
 * @statusp: pointer to floating point status to fill in
 * @oddstatusp: pointer to floating point status to fill in for round-to-odd
 *
 * Determine whether a BFDotAdd operation should use FPCR.EBF = 0
 * or FPCR.EBF = 1 semantics. On return, has initialized *statusp
 * and *oddstatusp to suitable float_status arguments to use with either
 * bfdotadd() or bfdotadd_ebf().
 * Returns true for EBF = 1, false for EBF = 0. (The caller should use this
 * to decide whether to call bfdotadd() or bfdotadd_ebf().)
 */
bool is_ebf(CPUARMState *env, float_status *statusp, float_status *oddstatusp);
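
/*
 * The intended calling pattern is therefore (sketch):
 *
 *     float_status fpst, fpst_odd;
 *     if (is_ebf(env, &fpst, &fpst_odd)) {
 *         sum = bfdotadd_ebf(sum, e1, e2, &fpst, &fpst_odd);
 *     } else {
 *         sum = bfdotadd(sum, e1, e2, &fpst);
 *     }
 */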

/*
 * Negate as for FPCR.AH=1 -- do not negate NaNs.
 */
static inline float16 bfloat16_ah_chs(float16 a)
{
    return bfloat16_is_any_nan(a) ? a : bfloat16_chs(a);
}

static inline float16 float16_ah_chs(float16 a)
{
    return float16_is_any_nan(a) ? a : float16_chs(a);
}

static inline float32 float32_ah_chs(float32 a)
{
    return float32_is_any_nan(a) ? a : float32_chs(a);
}

static inline float64 float64_ah_chs(float64 a)
{
    return float64_is_any_nan(a) ? a : float64_chs(a);
}

static inline float16 float16_maybe_ah_chs(float16 a, bool fpcr_ah)
{
    return fpcr_ah && float16_is_any_nan(a) ? a : float16_chs(a);
}

static inline float32 float32_maybe_ah_chs(float32 a, bool fpcr_ah)
{
    return fpcr_ah && float32_is_any_nan(a) ? a : float32_chs(a);
}

static inline float64 float64_maybe_ah_chs(float64 a, bool fpcr_ah)
{
    return fpcr_ah && float64_is_any_nan(a) ? a : float64_chs(a);
}
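
/*
 * For example (illustrative): float32_ah_chs(0x7fc00000) returns the
 * quiet NaN unchanged, whereas plain float32_chs() would flip its
 * sign bit to give 0xffc00000.
 */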

/* Not actually called directly as a helper, but uses similar machinery. */
bfloat16 helper_sme2_ah_fmax_b16(bfloat16 a, bfloat16 b, float_status *fpst);
bfloat16 helper_sme2_ah_fmin_b16(bfloat16 a, bfloat16 b, float_status *fpst);

float32 sve_f16_to_f32(float16 f, float_status *fpst);
float16 sve_f32_to_f16(float32 f, float_status *fpst);

/*
 * Decode helper functions for predicate as counter.
 */

typedef struct {
    unsigned count;
    unsigned lg2_stride;
    bool invert;
} DecodeCounter;

static inline DecodeCounter
decode_counter(unsigned png, unsigned vl, unsigned v_esz)
{
    DecodeCounter ret = { };

    /* C.f. Arm pseudocode CounterToPredicate. */
    if (likely(png & 0xf)) {
        unsigned p_esz = ctz32(png);

        /*
         * maxbit = log2(pl(bits) * 4)
         *        = log2(vl(bytes) * 4)
         *        = log2(vl) + 2
         * maxbit_mask = ones<maxbit:0>
         *             = (1 << (maxbit + 1)) - 1
         *             = (1 << (log2(vl) + 2 + 1)) - 1
         *             = (1 << (log2(vl) + 3)) - 1
         *             = (pow2ceil(vl) << 3) - 1
         */
        ret.count = png & (((unsigned)pow2ceil(vl) << 3) - 1);
        ret.count >>= p_esz + 1;

        ret.invert = (png >> 15) & 1;

        /*
         * The Arm pseudocode for CounterToPredicate expands the count to
         * a set of bits, and then the operation proceeds as for the original
         * interpretation of predicates as a set of bits.
         *
         * We can avoid the expansion by adjusting the count and supplying
         * an element stride.
         */
        if (unlikely(p_esz != v_esz)) {
            if (p_esz < v_esz) {
                /*
                 * For predicate esz < vector esz, the expanded predicate
                 * will have more bits set than will be consumed.
                 * Adjust the count down, rounding up.
                 * Consider p_esz = MO_8, v_esz = MO_64, count 14:
                 * The expanded predicate would be
                 *     0011 1111 1111 1111
                 * The significant bits are
                 *     ...1 ...1 ...1 ...1
                 */
                unsigned shift = v_esz - p_esz;
                unsigned trunc = ret.count >> shift;
                ret.count = trunc + (ret.count != (trunc << shift));
            } else {
                /*
                 * For predicate esz > vector esz, the expanded predicate
                 * will have bits set only at power-of-two multiples of
                 * the vector esz. Bits at other multiples will all be
                 * false. Adjust the count up, and supply the caller
                 * with a stride of elements to skip.
                 */
                unsigned shift = p_esz - v_esz;
                ret.count <<= shift;
                ret.lg2_stride = shift;
            }
        }
    }
    return ret;
}
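
/*
 * Worked example (illustrative), following the code above: with
 * vl == 32 bytes and png == 0x1a (binary 11010), the lowest set bit
 * gives p_esz == 1 (16-bit predicate elements), the count field is
 * (0x1a & 0xff) >> 2 == 6, and invert == 0.  For v_esz == 1 this
 * means the first 6 half-word elements are active.
 */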

/* Extract @len bits from an array of uint64_t at offset @pos bits. */
static inline uint64_t extractn(uint64_t *p, unsigned pos, unsigned len)
{
    uint64_t x;

    p += pos / 64;
    pos = pos % 64;

    x = p[0];
    if (pos + len > 64) {
        x = (x >> pos) | (p[1] << (-pos & 63));
        pos = 0;
    }
    return extract64(x, pos, len);
}
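
/*
 * For example (illustrative), extracting across a word boundary:
 * with p[0] == 0x8000000000000000ull and p[1] == 1,
 * extractn(p, 63, 2) == 3, combining bit 63 of p[0] with bit 0 of p[1].
 */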

/* Deposit @len bits into an array of uint64_t at offset @pos bits. */
static inline void depositn(uint64_t *p, unsigned pos,
                            unsigned len, uint64_t val)
{
    p += pos / 64;
    pos = pos % 64;

    if (pos + len <= 64) {
        p[0] = deposit64(p[0], pos, len, val);
    } else {
        unsigned len0 = 64 - pos;
        unsigned len1 = len - len0;

        p[0] = deposit64(p[0], pos, len0, val);
        p[1] = deposit64(p[1], 0, len1, val >> len0);
    }
}
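
/*
 * The mirror image of the extractn() example above (illustrative):
 * depositn(p, 63, 2, 3) sets bit 63 of p[0] and bit 0 of p[1].
 */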

#endif /* TARGET_ARM_VEC_INTERNAL_H */