1 /* 2 * SME outer product, 1 x 1. 3 * SPDX-License-Identifier: GPL-2.0-or-later 4 */ 5 6 #include <stdio.h> 7 8 static void foo(float *dst) 9 { 10 asm(".arch_extension sme\n\t" 11 "smstart\n\t" 12 "ptrue p0.s, vl4\n\t" 13 "fmov z0.s, #1.0\n\t" 14 /* 15 * An outer product of a vector of 1.0 by itself should be a matrix of 1.0. 16 * Note that we are using tile 1 here (za1.s) rather than tile 0. 17 */ 18 "zero {za}\n\t" 19 "fmopa za1.s, p0/m, p0/m, z0.s, z0.s\n\t" 20 /* 21 * Read the first 4x4 sub-matrix of elements from tile 1: 22 * Note that za1h should be interchangeable here. 23 */ 24 "mov w12, #0\n\t" 25 "mova z0.s, p0/m, za1v.s[w12, #0]\n\t" 26 "mova z1.s, p0/m, za1v.s[w12, #1]\n\t" 27 "mova z2.s, p0/m, za1v.s[w12, #2]\n\t" 28 "mova z3.s, p0/m, za1v.s[w12, #3]\n\t" 29 /* 30 * And store them to the input pointer (dst in the C code): 31 */ 32 "st1w {z0.s}, p0, [%0]\n\t" 33 "add x0, x0, #16\n\t" 34 "st1w {z1.s}, p0, [x0]\n\t" 35 "add x0, x0, #16\n\t" 36 "st1w {z2.s}, p0, [x0]\n\t" 37 "add x0, x0, #16\n\t" 38 "st1w {z3.s}, p0, [x0]\n\t" 39 "smstop" 40 : : "r"(dst) 41 : "x12", "d0", "d1", "d2", "d3", "memory"); 42 } 43 44 int main() 45 { 46 float dst[16] = { }; 47 48 foo(dst); 49 50 for (int i = 0; i < 16; i++) { 51 if (dst[i] != 1.0f) { 52 goto failure; 53 } 54 } 55 /* success */ 56 return 0; 57 58 failure: 59 for (int i = 0; i < 16; i++) { 60 printf("%f%c", dst[i], i % 4 == 3 ? '\n' : ' '); 61 } 62 return 1; 63 } 64