1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * csum_partial_copy - do IP checksumming and copy
4  *
5  * (C) Copyright 1996 Linus Torvalds
 * accelerated versions (and 21264 assembly versions) contributed by
7  *	Rick Gorton	<rick.gorton@alpha-processor.com>
8  *
9  * Don't look at this too closely - you'll go mad. The things
10  * we do for performance..
11  */
12 
13 #include <linux/types.h>
14 #include <linux/string.h>
15 #include <linux/uaccess.h>
16 
17 
/*
 * Thin wrappers around individual Alpha byte-manipulation instructions.
 * In every extract/mask/insert wrapper only the low three bits of the
 * offset operand y are significant (a byte offset 0..7 within a quadword).
 *
 * ldq_u(x, y): x = the aligned quadword that contains address y.
 */
#define ldq_u(x,y) \
__asm__ __volatile__("ldq_u %0,%1":"=r" (x):"m" (*(const unsigned long *)(y)))

/* stq_u(x, y): store x into the aligned quadword that contains address y. */
#define stq_u(x,y) \
__asm__ __volatile__("stq_u %1,%0":"=m" (*(unsigned long *)(y)):"r" (x))

/*
 * extql(x, y, z): z = x shifted right by (y & 7) bytes, i.e. the bytes of
 * x from offset y upward moved down to the low end.
 */
#define extql(x,y,z) \
__asm__ __volatile__("extql %1,%2,%0":"=r" (z):"r" (x),"r" (y))

/*
 * extqh(x, y, z): companion of extql; moves the low bytes of x up to the
 * high end so the result can be OR-ed with extql() of the previous
 * quadword to rebuild an unaligned quadword.
 */
#define extqh(x,y,z) \
__asm__ __volatile__("extqh %1,%2,%0":"=r" (z):"r" (x),"r" (y))

/* mskql(x, y, z): z = x with the bytes at offsets >= (y & 7) cleared. */
#define mskql(x,y,z) \
__asm__ __volatile__("mskql %1,%2,%0":"=r" (z):"r" (x),"r" (y))

/* mskqh(x, y, z): z = x with the bytes at offsets < (y & 7) cleared. */
#define mskqh(x,y,z) \
__asm__ __volatile__("mskqh %1,%2,%0":"=r" (z):"r" (x),"r" (y))

/*
 * insql(x, y, z): z = x shifted left by (y & 7) bytes, i.e. the part of x
 * that lands in the quadword starting at byte offset y.
 */
#define insql(x,y,z) \
__asm__ __volatile__("insql %1,%2,%0":"=r" (z):"r" (x),"r" (y))

/*
 * insqh(x, y, z): the bytes of x that insql() shifted out of the top,
 * placed at the low end (the spill into the following quadword).
 */
#define insqh(x,y,z) \
__asm__ __volatile__("insqh %1,%2,%0":"=r" (z):"r" (x),"r" (y))
41 
42 
/*
 * __get_user_u(x, ptr) - fetch from user memory with ldq_u.
 *
 * ldq_u loads the aligned quadword containing @ptr into @x.  The statement
 * expression evaluates to __guu_err, which starts at 0 because it is tied
 * to the "1"(0) input.  EXC() and __m() are defined elsewhere.
 * EXC(1b,2b,%0,%1) presumably registers an exception-table fixup: a fault
 * at label 1 would resume at label 2, leave an error code in __guu_err, and
 * possibly clear @x.  Confirm this against asm/uaccess.h.
 */
#define __get_user_u(x,ptr)				\
({							\
	long __guu_err;					\
	__asm__ __volatile__(				\
	"1:	ldq_u %0,%2\n"				\
	"2:\n"						\
	EXC(1b,2b,%0,%1)				\
		: "=r"(x), "=r"(__guu_err)		\
		: "m"(__m(ptr)), "1"(0));		\
	__guu_err;					\
})
54 
/*
 * __put_user_u(x, ptr) - store quadword @x to user memory with stq_u.
 *
 * stq_u writes the aligned quadword containing @ptr.  The statement
 * expression evaluates to __puu_err, which starts at 0 because it is tied
 * to the "0"(0) input.  EXC() (defined elsewhere) is presumably what sets
 * it to an error code when the store faults; confirm against
 * asm/uaccess.h.
 *
 * Two details here are deliberate:
 *  - The memory operand names the macro parameter @ptr.  An identifier
 *    that is not a parameter would either fail to compile or silently
 *    bind to a caller's variable.
 *  - The value is printed with the %r modifier.  A constant-zero @x is
 *    allowed by the "J" constraint, and %r emits it as register $31
 *    rather than as a bare "0", which stq_u cannot accept.
 */
#define __put_user_u(x,ptr)				\
({							\
	long __puu_err;					\
	__asm__ __volatile__(				\
	"1:	stq_u %r2,%1\n"				\
	"2:\n"						\
	EXC(1b,2b,$31,%0)				\
		: "=r"(__puu_err)			\
		: "m"(__m(ptr)), "rJ"(x), "0"(0));	\
	__puu_err;					\
})
66 
67 
/*
 * Fold a 64-bit one's-complement accumulator down to 16 bits.
 *
 * The value is viewed through a union of 64-, 32- and 16-bit lanes.  Lane 0
 * is the least significant, which relies on Alpha's little-endian layout.
 * Reading lanes this way (extract instructions) is a bit more efficient
 * than a shift/bitmask formulation.
 *
 * Step 1 adds the two 32-bit halves.  The sum fits in 33 bits, so its top
 * 16-bit lane is always zero and can be ignored.
 *
 * Step 2 adds the remaining three 16-bit lanes.  That sum fits in 17 bits
 * (it never exceeds 0x1fffe), so its third lane is zero.  The final
 * two-lane add therefore cannot overflow 16 bits.
 */
static inline unsigned short from64to16(unsigned long x)
{
	union {
		unsigned long	ul;
		unsigned int	ui[2];
		unsigned short	us[4];
	} whole, folded32, folded17;

	whole.ul = x;

	/* 32 + 32 bits -> at most 33 bits; folded32.us[3] is always zero. */
	folded32.ul = (unsigned long) whole.ui[1] + (unsigned long) whole.ui[0];

	/* 16 + 16 + 16 bits -> at most 17 bits; folded17.us[2] is zero. */
	folded17.ul = (unsigned long) folded32.us[2]
		    + (unsigned long) folded32.us[1]
		    + (unsigned long) folded32.us[0];

	return folded17.us[1] + folded17.us[0];
}
90 
91 
92 
93 /*
94  * Ok. This isn't fun, but this is the EASY case.
95  */
96 static inline unsigned long
97 csum_partial_cfu_aligned(const unsigned long __user *src, unsigned long *dst,
98 			 long len, unsigned long checksum,
99 			 int *errp)
100 {
101 	unsigned long carry = 0;
102 	int err = 0;
103 
104 	while (len >= 0) {
105 		unsigned long word;
106 		err |= __get_user(word, src);
107 		checksum += carry;
108 		src++;
109 		checksum += word;
110 		len -= 8;
111 		carry = checksum < word;
112 		*dst = word;
113 		dst++;
114 	}
115 	len += 8;
116 	checksum += carry;
117 	if (len) {
118 		unsigned long word, tmp;
119 		err |= __get_user(word, src);
120 		tmp = *dst;
121 		mskql(word, len, word);
122 		checksum += word;
123 		mskqh(tmp, len, tmp);
124 		carry = checksum < word;
125 		*dst = word | tmp;
126 		checksum += carry;
127 	}
128 	if (err && errp) *errp = err;
129 	return checksum;
130 }
131 
132 /*
133  * This is even less fun, but this is still reasonably
134  * easy.
135  */
136 static inline unsigned long
137 csum_partial_cfu_dest_aligned(const unsigned long __user *src,
138 			      unsigned long *dst,
139 			      unsigned long soff,
140 			      long len, unsigned long checksum,
141 			      int *errp)
142 {
143 	unsigned long first;
144 	unsigned long word, carry;
145 	unsigned long lastsrc = 7+len+(unsigned long)src;
146 	int err = 0;
147 
148 	err |= __get_user_u(first,src);
149 	carry = 0;
150 	while (len >= 0) {
151 		unsigned long second;
152 
153 		err |= __get_user_u(second, src+1);
154 		extql(first, soff, word);
155 		len -= 8;
156 		src++;
157 		extqh(second, soff, first);
158 		checksum += carry;
159 		word |= first;
160 		first = second;
161 		checksum += word;
162 		*dst = word;
163 		dst++;
164 		carry = checksum < word;
165 	}
166 	len += 8;
167 	checksum += carry;
168 	if (len) {
169 		unsigned long tmp;
170 		unsigned long second;
171 		err |= __get_user_u(second, lastsrc);
172 		tmp = *dst;
173 		extql(first, soff, word);
174 		extqh(second, soff, first);
175 		word |= first;
176 		mskql(word, len, word);
177 		checksum += word;
178 		mskqh(tmp, len, tmp);
179 		carry = checksum < word;
180 		*dst = word | tmp;
181 		checksum += carry;
182 	}
183 	if (err && errp) *errp = err;
184 	return checksum;
185 }
186 
187 /*
188  * This is slightly less fun than the above..
189  */
190 static inline unsigned long
191 csum_partial_cfu_src_aligned(const unsigned long __user *src,
192 			     unsigned long *dst,
193 			     unsigned long doff,
194 			     long len, unsigned long checksum,
195 			     unsigned long partial_dest,
196 			     int *errp)
197 {
198 	unsigned long carry = 0;
199 	unsigned long word;
200 	unsigned long second_dest;
201 	int err = 0;
202 
203 	mskql(partial_dest, doff, partial_dest);
204 	while (len >= 0) {
205 		err |= __get_user(word, src);
206 		len -= 8;
207 		insql(word, doff, second_dest);
208 		checksum += carry;
209 		stq_u(partial_dest | second_dest, dst);
210 		src++;
211 		checksum += word;
212 		insqh(word, doff, partial_dest);
213 		carry = checksum < word;
214 		dst++;
215 	}
216 	len += 8;
217 	if (len) {
218 		checksum += carry;
219 		err |= __get_user(word, src);
220 		mskql(word, len, word);
221 		len -= 8;
222 		checksum += word;
223 		insql(word, doff, second_dest);
224 		len += doff;
225 		carry = checksum < word;
226 		partial_dest |= second_dest;
227 		if (len >= 0) {
228 			stq_u(partial_dest, dst);
229 			if (!len) goto out;
230 			dst++;
231 			insqh(word, doff, partial_dest);
232 		}
233 		doff = len;
234 	}
235 	ldq_u(second_dest, dst);
236 	mskqh(second_dest, doff, second_dest);
237 	stq_u(partial_dest | second_dest, dst);
238 out:
239 	checksum += carry;
240 	if (err && errp) *errp = err;
241 	return checksum;
242 }
243 
244 /*
245  * This is so totally un-fun that it's frightening. Don't
246  * look at this too closely, you'll go blind.
247  */
248 static inline unsigned long
249 csum_partial_cfu_unaligned(const unsigned long __user * src,
250 			   unsigned long * dst,
251 			   unsigned long soff, unsigned long doff,
252 			   long len, unsigned long checksum,
253 			   unsigned long partial_dest,
254 			   int *errp)
255 {
256 	unsigned long carry = 0;
257 	unsigned long first;
258 	unsigned long lastsrc;
259 	int err = 0;
260 
261 	err |= __get_user_u(first, src);
262 	lastsrc = 7+len+(unsigned long)src;
263 	mskql(partial_dest, doff, partial_dest);
264 	while (len >= 0) {
265 		unsigned long second, word;
266 		unsigned long second_dest;
267 
268 		err |= __get_user_u(second, src+1);
269 		extql(first, soff, word);
270 		checksum += carry;
271 		len -= 8;
272 		extqh(second, soff, first);
273 		src++;
274 		word |= first;
275 		first = second;
276 		insql(word, doff, second_dest);
277 		checksum += word;
278 		stq_u(partial_dest | second_dest, dst);
279 		carry = checksum < word;
280 		insqh(word, doff, partial_dest);
281 		dst++;
282 	}
283 	len += doff;
284 	checksum += carry;
285 	if (len >= 0) {
286 		unsigned long second, word;
287 		unsigned long second_dest;
288 
289 		err |= __get_user_u(second, lastsrc);
290 		extql(first, soff, word);
291 		extqh(second, soff, first);
292 		word |= first;
293 		first = second;
294 		mskql(word, len-doff, word);
295 		checksum += word;
296 		insql(word, doff, second_dest);
297 		carry = checksum < word;
298 		stq_u(partial_dest | second_dest, dst);
299 		if (len) {
300 			ldq_u(second_dest, dst+1);
301 			insqh(word, doff, partial_dest);
302 			mskqh(second_dest, len, second_dest);
303 			stq_u(partial_dest | second_dest, dst+1);
304 		}
305 		checksum += carry;
306 	} else {
307 		unsigned long second, word;
308 		unsigned long second_dest;
309 
310 		err |= __get_user_u(second, lastsrc);
311 		extql(first, soff, word);
312 		extqh(second, soff, first);
313 		word |= first;
314 		ldq_u(second_dest, dst);
315 		mskql(word, len-doff, word);
316 		checksum += word;
317 		mskqh(second_dest, len, second_dest);
318 		carry = checksum < word;
319 		insql(word, doff, word);
320 		stq_u(partial_dest | word | second_dest, dst);
321 		checksum += carry;
322 	}
323 	if (err && errp) *errp = err;
324 	return checksum;
325 }
326 
327 __wsum
328 csum_partial_copy_from_user(const void __user *src, void *dst, int len,
329 			       __wsum sum, int *errp)
330 {
331 	unsigned long checksum = (__force u32) sum;
332 	unsigned long soff = 7 & (unsigned long) src;
333 	unsigned long doff = 7 & (unsigned long) dst;
334 
335 	if (len) {
336 		if (!access_ok(src, len)) {
337 			if (errp) *errp = -EFAULT;
338 			memset(dst, 0, len);
339 			return sum;
340 		}
341 		if (!doff) {
342 			if (!soff)
343 				checksum = csum_partial_cfu_aligned(
344 					(const unsigned long __user *) src,
345 					(unsigned long *) dst,
346 					len-8, checksum, errp);
347 			else
348 				checksum = csum_partial_cfu_dest_aligned(
349 					(const unsigned long __user *) src,
350 					(unsigned long *) dst,
351 					soff, len-8, checksum, errp);
352 		} else {
353 			unsigned long partial_dest;
354 			ldq_u(partial_dest, dst);
355 			if (!soff)
356 				checksum = csum_partial_cfu_src_aligned(
357 					(const unsigned long __user *) src,
358 					(unsigned long *) dst,
359 					doff, len-8, checksum,
360 					partial_dest, errp);
361 			else
362 				checksum = csum_partial_cfu_unaligned(
363 					(const unsigned long __user *) src,
364 					(unsigned long *) dst,
365 					soff, doff, len-8, checksum,
366 					partial_dest, errp);
367 		}
368 		checksum = from64to16 (checksum);
369 	}
370 	return (__force __wsum)checksum;
371 }
372 EXPORT_SYMBOL(csum_partial_copy_from_user);
373 
374 __wsum
375 csum_partial_copy_nocheck(const void *src, void *dst, int len, __wsum sum)
376 {
377 	__wsum checksum;
378 	mm_segment_t oldfs = get_fs();
379 	set_fs(KERNEL_DS);
380 	checksum = csum_partial_copy_from_user((__force const void __user *)src,
381 						dst, len, sum, NULL);
382 	set_fs(oldfs);
383 	return checksum;
384 }
385 EXPORT_SYMBOL(csum_partial_copy_nocheck);
386