xref: /openbmc/qemu/tests/tcg/s390x/fma.c (revision 2c471a8291c182130a77702d9bd4c910d987c6a9)
1 /*
2  * Test floating-point multiply-and-add instructions.
3  *
4  * SPDX-License-Identifier: GPL-2.0-or-later
5  */
6 #include <fenv.h>
7 #include <stdbool.h>
8 #include <stdio.h>
9 #include <stdlib.h>
10 #include <string.h>
11 #include "float.h"
12 
/*
 * One test operand or result, viewable in each supported format.
 * The format index used throughout this file selects the member:
 * 0 -> e (float), 1 -> d (double), 2 -> x (long double).
 * buf exposes the raw bytes, e.g. for snan_to_qnan() and memcmp().
 */
union val {
    float e;       /* fmt 0 */
    double d;      /* fmt 1 */
    long double x; /* fmt 2 */
    char buf[16];  /* raw representation */
};
19 
/*
 * Multiply-and-add result tables from the z/Architecture Principles of
 * Operation (PoP), transcribed as closely to the original as possible.
 */
23 static const char *table1[N_SIGNED_CLASSES][N_SIGNED_CLASSES] = {
24      /*         -inf           -Fn          -0             +0             +Fn          +inf           QNaN         SNaN     */
25     {/* -inf */ "P(+inf)",     "P(+inf)",   "Xi: T(dNaN)", "Xi: T(dNaN)", "P(-inf)",   "P(-inf)",     "P(b)",      "Xi: T(b*)"},
26     {/* -Fn  */ "P(+inf)",     "P(a*b)",    "P(+0)",       "P(-0)",       "P(a*b)",    "P(-inf)",     "P(b)",      "Xi: T(b*)"},
27     {/* -0   */ "Xi: T(dNaN)", "P(+0)",     "P(+0)",       "P(-0)",       "P(-0)",     "Xi: T(dNaN)", "P(b)",      "Xi: T(b*)"},
28     {/* +0   */ "Xi: T(dNaN)", "P(-0)",     "P(-0)",       "P(+0)",       "P(+0)",     "Xi: T(dNaN)", "P(b)",      "Xi: T(b*)"},
29     {/* +Fn  */ "P(-inf)",     "P(a*b)",    "P(-0)",       "P(+0)",       "P(a*b)",    "P(+inf)",     "P(b)",      "Xi: T(b*)"},
30     {/* +inf */ "P(-inf)",     "P(-inf)",   "Xi: T(dNaN)", "Xi: T(dNaN)", "P(+inf)",   "P(+inf)",     "P(b)",      "Xi: T(b*)"},
31     {/* QNaN */ "P(a)",        "P(a)",      "P(a)",        "P(a)",        "P(a)",      "P(a)",        "P(a)",      "Xi: T(b*)"},
32     {/* SNaN */ "Xi: T(a*)",   "Xi: T(a*)", "Xi: T(a*)",   "Xi: T(a*)",   "Xi: T(a*)", "Xi: T(a*)",   "Xi: T(a*)", "Xi: T(a*)"},
33 };
34 
35 static const char *table2[N_SIGNED_CLASSES][N_SIGNED_CLASSES] = {
36      /*         -inf           -Fn        -0         +0         +Fn        +inf           QNaN    SNaN     */
37     {/* -inf */ "T(-inf)",     "T(-inf)", "T(-inf)", "T(-inf)", "T(-inf)", "Xi: T(dNaN)", "T(c)", "Xi: T(c*)"},
38     {/* -Fn  */ "T(-inf)",     "R(p+c)",  "R(p)",    "R(p)",    "R(p+c)",  "T(+inf)",     "T(c)", "Xi: T(c*)"},
39     {/* -0   */ "T(-inf)",     "R(c)",    "T(-0)",   "Rezd",    "R(c)",    "T(+inf)",     "T(c)", "Xi: T(c*)"},
40     {/* +0   */ "T(-inf)",     "R(c)",    "Rezd",    "T(+0)",   "R(c)",    "T(+inf)",     "T(c)", "Xi: T(c*)"},
41     {/* +Fn  */ "T(-inf)",     "R(p+c)",  "R(p)",    "R(p)",    "R(p+c)",  "T(+inf)",     "T(c)", "Xi: T(c*)"},
42     {/* +inf */ "Xi: T(dNaN)", "T(+inf)", "T(+inf)", "T(+inf)", "T(+inf)", "T(+inf)",     "T(c)", "Xi: T(c*)"},
43     {/* QNaN */ "T(p)",        "T(p)",    "T(p)",    "T(p)",    "T(p)",    "T(p)",        "T(p)", "Xi: T(c*)"},
44      /* SNaN: can't happen */
45 };
46 
interpret_tables(union val * r,bool * xi,int fmt,int cls_a,const union val * a,int cls_b,const union val * b,int cls_c,const union val * c)47 static void interpret_tables(union val *r, bool *xi, int fmt,
48                              int cls_a, const union val *a,
49                              int cls_b, const union val *b,
50                              int cls_c, const union val *c)
51 {
52     const char *spec1 = table1[cls_a][cls_b];
53     const char *spec2;
54     union val p;
55     int cls_p;
56 
57     *xi = false;
58 
59     if (strcmp(spec1, "P(-inf)") == 0) {
60         cls_p = CLASS_MINUS_INF;
61     } else if (strcmp(spec1, "P(+inf)") == 0) {
62         cls_p = CLASS_PLUS_INF;
63     } else if (strcmp(spec1, "P(-0)") == 0) {
64         cls_p = CLASS_MINUS_ZERO;
65     } else if (strcmp(spec1, "P(+0)") == 0) {
66         cls_p = CLASS_PLUS_ZERO;
67     } else if (strcmp(spec1, "P(a)") == 0) {
68         cls_p = cls_a;
69         memcpy(&p, a, sizeof(p));
70     } else if (strcmp(spec1, "P(b)") == 0) {
71         cls_p = cls_b;
72         memcpy(&p, b, sizeof(p));
73     } else if (strcmp(spec1, "P(a*b)") == 0) {
74         /*
75          * In the general case splitting fma into multiplication and addition
76          * doesn't work, but this is the case with our test inputs.
77          */
78         cls_p = cls_a == cls_b ? CLASS_PLUS_FN : CLASS_MINUS_FN;
79         switch (fmt) {
80         case 0:
81             p.e = a->e * b->e;
82             break;
83         case 1:
84             p.d = a->d * b->d;
85             break;
86         case 2:
87             p.x = a->x * b->x;
88             break;
89         default:
90             fprintf(stderr, "Unsupported fmt: %d\n", fmt);
91             exit(1);
92         }
93     } else if (strcmp(spec1, "Xi: T(dNaN)") == 0) {
94         memcpy(r, default_nans[fmt], sizeof(*r));
95         *xi = true;
96         return;
97     } else if (strcmp(spec1, "Xi: T(a*)") == 0) {
98         memcpy(r, a, sizeof(*r));
99         snan_to_qnan(r->buf, fmt);
100         *xi = true;
101         return;
102     } else if (strcmp(spec1, "Xi: T(b*)") == 0) {
103         memcpy(r, b, sizeof(*r));
104         snan_to_qnan(r->buf, fmt);
105         *xi = true;
106         return;
107     } else {
108         fprintf(stderr, "Unsupported spec1: %s\n", spec1);
109         exit(1);
110     }
111 
112     spec2 = table2[cls_p][cls_c];
113     if (strcmp(spec2, "T(-inf)") == 0) {
114         memcpy(r, signed_floats[fmt][CLASS_MINUS_INF].v[0], sizeof(*r));
115     } else if (strcmp(spec2, "T(+inf)") == 0) {
116         memcpy(r, signed_floats[fmt][CLASS_PLUS_INF].v[0], sizeof(*r));
117     } else if (strcmp(spec2, "T(-0)") == 0) {
118         memcpy(r, signed_floats[fmt][CLASS_MINUS_ZERO].v[0], sizeof(*r));
119     } else if (strcmp(spec2, "T(+0)") == 0 || strcmp(spec2, "Rezd") == 0) {
120         memcpy(r, signed_floats[fmt][CLASS_PLUS_ZERO].v[0], sizeof(*r));
121     } else if (strcmp(spec2, "R(c)") == 0 || strcmp(spec2, "T(c)") == 0) {
122         memcpy(r, c, sizeof(*r));
123     } else if (strcmp(spec2, "R(p)") == 0 || strcmp(spec2, "T(p)") == 0) {
124         memcpy(r, &p, sizeof(*r));
125     } else if (strcmp(spec2, "R(p+c)") == 0 || strcmp(spec2, "T(p+c)") == 0) {
126         switch (fmt) {
127         case 0:
128             r->e = p.e + c->e;
129             break;
130         case 1:
131             r->d = p.d + c->d;
132             break;
133         case 2:
134             r->x = p.x + c->x;
135             break;
136         default:
137             fprintf(stderr, "Unsupported fmt: %d\n", fmt);
138             exit(1);
139         }
140     } else if (strcmp(spec2, "Xi: T(dNaN)") == 0) {
141         memcpy(r, default_nans[fmt], sizeof(*r));
142         *xi = true;
143     } else if (strcmp(spec2, "Xi: T(c*)") == 0) {
144         memcpy(r, c, sizeof(*r));
145         snan_to_qnan(r->buf, fmt);
146         *xi = true;
147     } else {
148         fprintf(stderr, "Unsupported spec2: %s\n", spec2);
149         exit(1);
150     }
151 }
152 
/*
 * Position in the exhaustive walk over test inputs.  It holds a
 * floating-point format and, for each of the three operands, a class and
 * the index of a sample value within that class.  main() uses operand 0 as
 * the addend and operands 1 and 2 as the factors.
 */
struct iter {
    int fmt;    /* format index: 0 = float, 1 = double, 2 = long double */
    int cls[3]; /* per-operand class, indexes signed_floats[fmt][] */
    int val[3]; /* per-operand sample, indexes signed_floats[fmt][cls].v[] */
};
158 
iter_next(struct iter * it)159 static bool iter_next(struct iter *it)
160 {
161     int i;
162 
163     for (i = 2; i >= 0; i--) {
164         if (++it->val[i] != signed_floats[it->fmt][it->cls[i]].n) {
165             return true;
166         }
167         it->val[i] = 0;
168 
169         if (++it->cls[i] != N_SIGNED_CLASSES) {
170             return true;
171         }
172         it->cls[i] = 0;
173     }
174 
175     return ++it->fmt != N_FORMATS;
176 }
177 
main(void)178 int main(void)
179 {
180     int ret = EXIT_SUCCESS;
181     struct iter it = {};
182 
183     do {
184         size_t n = float_sizes[it.fmt];
185         union val a, b, c, exp, res;
186         bool xi_exp, xi;
187 
188         memcpy(&a, signed_floats[it.fmt][it.cls[0]].v[it.val[0]], sizeof(a));
189         memcpy(&b, signed_floats[it.fmt][it.cls[1]].v[it.val[1]], sizeof(b));
190         memcpy(&c, signed_floats[it.fmt][it.cls[2]].v[it.val[2]], sizeof(c));
191 
192         interpret_tables(&exp, &xi_exp, it.fmt,
193                          it.cls[1], &b, it.cls[2], &c, it.cls[0], &a);
194 
195         memcpy(&res, &a, sizeof(res));
196         feclearexcept(FE_ALL_EXCEPT);
197         switch (it.fmt) {
198         case 0:
199             asm("maebr %[a],%[b],%[c]"
200                 : [a] "+f" (res.e) : [b] "f" (b.e), [c] "f" (c.e));
201             break;
202         case 1:
203             asm("madbr %[a],%[b],%[c]"
204                 : [a] "+f" (res.d) : [b] "f" (b.d), [c] "f" (c.d));
205             break;
206         case 2:
207             asm("wfmaxb %[a],%[c],%[b],%[a]"
208                 : [a] "+v" (res.x) : [b] "v" (b.x), [c] "v" (c.x));
209             break;
210         default:
211             fprintf(stderr, "Unsupported fmt: %d\n", it.fmt);
212             exit(1);
213         }
214         xi = fetestexcept(FE_ALL_EXCEPT) == FE_INVALID;
215 
216         if (memcmp(&res, &exp, n) != 0 || xi != xi_exp) {
217             fprintf(stderr, "[  FAILED  ] ");
218             dump_v(stderr, &b, n);
219             fprintf(stderr, " * ");
220             dump_v(stderr, &c, n);
221             fprintf(stderr, " + ");
222             dump_v(stderr, &a, n);
223             fprintf(stderr, ": actual=");
224             dump_v(stderr, &res, n);
225             fprintf(stderr, "/%d, expected=", (int)xi);
226             dump_v(stderr, &exp, n);
227             fprintf(stderr, "/%d\n", (int)xi_exp);
228             ret = EXIT_FAILURE;
229         }
230     } while (iter_next(&it));
231 
232     return ret;
233 }
234