/*
 * SME outer product, 1 x 1.
 * SPDX-License-Identifier: GPL-2.0-or-later
 */
51f51573fSRichard Henderson
61f51573fSRichard Henderson #include <stdio.h>
71f51573fSRichard Henderson
/*
 * foo - fill dst with a 4x4 block read back from SME ZA tile 1.
 *
 * Implemented by the top-level asm() below.  It forms the outer product of
 * a vector of 1.0f with itself into tile za1.s, then stores the first four
 * 32-bit column slices of that tile (4 floats each) to dst.
 *
 * dst: must have room for at least 16 floats (64 bytes); all 16 are written.
 */
extern void foo(float *dst);

/*
 * Hand-written AArch64 body of foo(); per AAPCS64 the dst argument arrives
 * in x0.  Needs a CPU (or emulator) implementing the SME extension.
 */
asm(
" .arch_extension sme\n"
" .type foo, @function\n"
"foo:\n"
/*
 * Prologue: allocate an 80-byte frame holding x29/x30 and d8-d15.
 * d8-d15 are callee-saved under AAPCS64.  NOTE(review): they are preserved
 * here presumably because entering/leaving streaming mode (smstart/smstop)
 * resets the SVE Z registers, whose low 64 bits alias d8-d15.
 */
" stp x29, x30, [sp, -80]!\n"
" mov x29, sp\n"
" stp d8, d9, [sp, 16]\n"
" stp d10, d11, [sp, 32]\n"
" stp d12, d13, [sp, 48]\n"
" stp d14, d15, [sp, 64]\n"
/* Enter streaming SVE mode and enable the ZA array. */
" smstart\n"
/* p0: only the first 4 x 32-bit lanes active; z0: every .s lane = 1.0. */
" ptrue p0.s, vl4\n"
" fmov z0.s, #1.0\n"
/*
 * An outer product of a vector of 1.0 by itself should be a matrix of 1.0.
 * Note that we are using tile 1 here (za1.s) rather than tile 0.
 * ZA is cleared first because fmopa accumulates into the tile; with both
 * row and column predicates = p0, only the top-left 4x4 of za1 is updated.
 */
" zero {za}\n"
" fmopa za1.s, p0/m, p0/m, z0.s, z0.s\n"
/*
 * Read the first 4x4 sub-matrix of elements from tile 1:
 * za1v.s[w12, #k] is vertical slice (column) w12 + k, so with w12 = 0
 * z0..z3 receive columns 0..3 (4 active lanes each, per p0).
 * Note that za1h should be interchangeable here: the all-1.0 result is
 * symmetric, so horizontal and vertical slices hold the same values.
 */
" mov w12, #0\n"
" mova z0.s, p0/m, za1v.s[w12, #0]\n"
" mova z1.s, p0/m, za1v.s[w12, #1]\n"
" mova z2.s, p0/m, za1v.s[w12, #2]\n"
" mova z3.s, p0/m, za1v.s[w12, #3]\n"
/*
 * And store them to the input pointer (dst in the C code):
 * each st1w writes the 4 active lanes (16 bytes) and x0 then advances by
 * 16, so slice k lands in dst[4*k] .. dst[4*k + 3].
 */
" st1w {z0.s}, p0, [x0]\n"
" add x0, x0, #16\n"
" st1w {z1.s}, p0, [x0]\n"
" add x0, x0, #16\n"
" st1w {z2.s}, p0, [x0]\n"
" add x0, x0, #16\n"
" st1w {z3.s}, p0, [x0]\n"
/* Leave streaming mode (also disabling ZA), restore saved regs, return. */
" smstop\n"
" ldp d8, d9, [sp, 16]\n"
" ldp d10, d11, [sp, 32]\n"
" ldp d12, d13, [sp, 48]\n"
" ldp d14, d15, [sp, 64]\n"
" ldp x29, x30, [sp], 80\n"
" ret\n"
" .size foo, . - foo"
);
571f51573fSRichard Henderson
/*
 * Test driver: call foo() and verify that all 16 floats it stores are 1.0.
 *
 * Returns 0 on success.  On failure, prints the 16 values as four lines of
 * four (in the order foo stored them) and returns 1.
 */
int main(void)
{
    /* foo() always writes a DIM x DIM block of floats. */
    enum { DIM = 4, COUNT = DIM * DIM };
    float dst[COUNT];
    int ok = 1;

    foo(dst);

    for (int i = 0; i < COUNT; i++) {
        if (dst[i] != 1.0f) {
            ok = 0;
            break;
        }
    }

    if (ok) {
        return 0; /* success */
    }

    /* failure: dump the block so the mismatching elements are visible */
    for (int i = 0; i < DIM; ++i) {
        for (int j = 0; j < DIM; ++j) {
            printf("%f ", (double)dst[i * DIM + j]);
        }
        printf("\n");
    }
    return 1;
}
84