1 // SPDX-License-Identifier: GPL-2.0-only 2 // Copyright (C) 2019-2020 Arm Ltd. 3 4 #include <linux/compiler.h> 5 #include <linux/kasan-checks.h> 6 #include <linux/kernel.h> 7 8 #include <net/checksum.h> 9 10 /* Looks dumb, but generates nice-ish code */ 11 static u64 accumulate(u64 sum, u64 data) 12 { 13 __uint128_t tmp = (__uint128_t)sum + data; 14 return tmp + (tmp >> 64); 15 } 16 17 /* 18 * We over-read the buffer and this makes KASAN unhappy. Instead, disable 19 * instrumentation and call kasan explicitly. 20 */ 21 unsigned int __no_sanitize_address do_csum(const unsigned char *buff, int len) 22 { 23 unsigned int offset, shift, sum; 24 const u64 *ptr; 25 u64 data, sum64 = 0; 26 27 if (unlikely(len == 0)) 28 return 0; 29 30 offset = (unsigned long)buff & 7; 31 /* 32 * This is to all intents and purposes safe, since rounding down cannot 33 * result in a different page or cache line being accessed, and @buff 34 * should absolutely not be pointing to anything read-sensitive. We do, 35 * however, have to be careful not to piss off KASAN, which means using 36 * unchecked reads to accommodate the head and tail, for which we'll 37 * compensate with an explicit check up-front. 38 */ 39 kasan_check_read(buff, len); 40 ptr = (u64 *)(buff - offset); 41 len = len + offset - 8; 42 43 /* 44 * Head: zero out any excess leading bytes. Shifting back by the same 45 * amount should be at least as fast as any other way of handling the 46 * odd/even alignment, and means we can ignore it until the very end. 47 */ 48 shift = offset * 8; 49 data = *ptr++; 50 #ifdef __LITTLE_ENDIAN 51 data = (data >> shift) << shift; 52 #else 53 data = (data << shift) >> shift; 54 #endif 55 56 /* 57 * Body: straightforward aligned loads from here on (the paired loads 58 * underlying the quadword type still only need dword alignment). The 59 * main loop strictly excludes the tail, so the second loop will always 60 * run at least once. 61 */ 62 while (unlikely(len > 64)) { 63 __uint128_t tmp1, tmp2, tmp3, tmp4; 64 65 tmp1 = *(__uint128_t *)ptr; 66 tmp2 = *(__uint128_t *)(ptr + 2); 67 tmp3 = *(__uint128_t *)(ptr + 4); 68 tmp4 = *(__uint128_t *)(ptr + 6); 69 70 len -= 64; 71 ptr += 8; 72 73 /* This is the "don't dump the carry flag into a GPR" idiom */ 74 tmp1 += (tmp1 >> 64) | (tmp1 << 64); 75 tmp2 += (tmp2 >> 64) | (tmp2 << 64); 76 tmp3 += (tmp3 >> 64) | (tmp3 << 64); 77 tmp4 += (tmp4 >> 64) | (tmp4 << 64); 78 tmp1 = ((tmp1 >> 64) << 64) | (tmp2 >> 64); 79 tmp1 += (tmp1 >> 64) | (tmp1 << 64); 80 tmp3 = ((tmp3 >> 64) << 64) | (tmp4 >> 64); 81 tmp3 += (tmp3 >> 64) | (tmp3 << 64); 82 tmp1 = ((tmp1 >> 64) << 64) | (tmp3 >> 64); 83 tmp1 += (tmp1 >> 64) | (tmp1 << 64); 84 tmp1 = ((tmp1 >> 64) << 64) | sum64; 85 tmp1 += (tmp1 >> 64) | (tmp1 << 64); 86 sum64 = tmp1 >> 64; 87 } 88 while (len > 8) { 89 __uint128_t tmp; 90 91 sum64 = accumulate(sum64, data); 92 tmp = *(__uint128_t *)ptr; 93 94 len -= 16; 95 ptr += 2; 96 97 #ifdef __LITTLE_ENDIAN 98 data = tmp >> 64; 99 sum64 = accumulate(sum64, tmp); 100 #else 101 data = tmp; 102 sum64 = accumulate(sum64, tmp >> 64); 103 #endif 104 } 105 if (len > 0) { 106 sum64 = accumulate(sum64, data); 107 data = *ptr; 108 len -= 8; 109 } 110 /* 111 * Tail: zero any over-read bytes similarly to the head, again 112 * preserving odd/even alignment. 113 */ 114 shift = len * -8; 115 #ifdef __LITTLE_ENDIAN 116 data = (data << shift) >> shift; 117 #else 118 data = (data >> shift) << shift; 119 #endif 120 sum64 = accumulate(sum64, data); 121 122 /* Finally, folding */ 123 sum64 += (sum64 >> 32) | (sum64 << 32); 124 sum = sum64 >> 32; 125 sum += (sum >> 16) | (sum << 16); 126 if (offset & 1) 127 return (u16)swab32(sum); 128 129 return sum >> 16; 130 } 131 132 __sum16 csum_ipv6_magic(const struct in6_addr *saddr, 133 const struct in6_addr *daddr, 134 __u32 len, __u8 proto, __wsum csum) 135 { 136 __uint128_t src, dst; 137 u64 sum = (__force u64)csum; 138 139 src = *(const __uint128_t *)saddr->s6_addr; 140 dst = *(const __uint128_t *)daddr->s6_addr; 141 142 sum += (__force u32)htonl(len); 143 #ifdef __LITTLE_ENDIAN 144 sum += (u32)proto << 24; 145 #else 146 sum += proto; 147 #endif 148 src += (src >> 64) | (src << 64); 149 dst += (dst >> 64) | (dst << 64); 150 151 sum = accumulate(sum, src >> 64); 152 sum = accumulate(sum, dst >> 64); 153 154 sum += ((sum >> 32) | (sum << 32)); 155 return csum_fold((__force __wsum)(sum >> 32)); 156 } 157 EXPORT_SYMBOL(csum_ipv6_magic); 158