1 // SPDX-License-Identifier: GPL-2.0-only 2 /* 3 * Access kernel memory without faulting. 4 */ 5 #include <linux/export.h> 6 #include <linux/mm.h> 7 #include <linux/uaccess.h> 8 9 static __always_inline long 10 probe_read_common(void *dst, const void __user *src, size_t size) 11 { 12 long ret; 13 14 pagefault_disable(); 15 ret = __copy_from_user_inatomic(dst, src, size); 16 pagefault_enable(); 17 18 return ret ? -EFAULT : 0; 19 } 20 21 /** 22 * probe_kernel_read(): safely attempt to read from a kernel-space location 23 * @dst: pointer to the buffer that shall take the data 24 * @src: address to read from 25 * @size: size of the data chunk 26 * 27 * Safely read from address @src to the buffer at @dst. If a kernel fault 28 * happens, handle that and return -EFAULT. 29 * 30 * We ensure that the copy_from_user is executed in atomic context so that 31 * do_page_fault() doesn't attempt to take mmap_sem. This makes 32 * probe_kernel_read() suitable for use within regions where the caller 33 * already holds mmap_sem, or other locks which nest inside mmap_sem. 34 */ 35 36 long __weak probe_kernel_read(void *dst, const void *src, size_t size) 37 __attribute__((alias("__probe_kernel_read"))); 38 39 long __probe_kernel_read(void *dst, const void *src, size_t size) 40 { 41 long ret; 42 mm_segment_t old_fs = get_fs(); 43 44 set_fs(KERNEL_DS); 45 ret = probe_read_common(dst, (__force const void __user *)src, size); 46 set_fs(old_fs); 47 48 return ret; 49 } 50 EXPORT_SYMBOL_GPL(probe_kernel_read); 51 52 /** 53 * probe_user_read(): safely attempt to read from a user-space location 54 * @dst: pointer to the buffer that shall take the data 55 * @src: address to read from. This must be a user address. 56 * @size: size of the data chunk 57 * 58 * Safely read from user address @src to the buffer at @dst. If a kernel fault 59 * happens, handle that and return -EFAULT. 60 */ 61 62 long __weak probe_user_read(void *dst, const void __user *src, size_t size) 63 __attribute__((alias("__probe_user_read"))); 64 65 long __probe_user_read(void *dst, const void __user *src, size_t size) 66 { 67 long ret = -EFAULT; 68 mm_segment_t old_fs = get_fs(); 69 70 set_fs(USER_DS); 71 if (access_ok(src, size)) 72 ret = probe_read_common(dst, src, size); 73 set_fs(old_fs); 74 75 return ret; 76 } 77 EXPORT_SYMBOL_GPL(probe_user_read); 78 79 /** 80 * probe_kernel_write(): safely attempt to write to a location 81 * @dst: address to write to 82 * @src: pointer to the data that shall be written 83 * @size: size of the data chunk 84 * 85 * Safely write to address @dst from the buffer at @src. If a kernel fault 86 * happens, handle that and return -EFAULT. 87 */ 88 long __weak probe_kernel_write(void *dst, const void *src, size_t size) 89 __attribute__((alias("__probe_kernel_write"))); 90 91 long __probe_kernel_write(void *dst, const void *src, size_t size) 92 { 93 long ret; 94 mm_segment_t old_fs = get_fs(); 95 96 set_fs(KERNEL_DS); 97 pagefault_disable(); 98 ret = __copy_to_user_inatomic((__force void __user *)dst, src, size); 99 pagefault_enable(); 100 set_fs(old_fs); 101 102 return ret ? -EFAULT : 0; 103 } 104 EXPORT_SYMBOL_GPL(probe_kernel_write); 105 106 107 /** 108 * strncpy_from_unsafe: - Copy a NUL terminated string from unsafe address. 109 * @dst: Destination address, in kernel space. This buffer must be at 110 * least @count bytes long. 111 * @unsafe_addr: Unsafe address. 112 * @count: Maximum number of bytes to copy, including the trailing NUL. 113 * 114 * Copies a NUL-terminated string from unsafe address to kernel buffer. 115 * 116 * On success, returns the length of the string INCLUDING the trailing NUL. 117 * 118 * If access fails, returns -EFAULT (some data may have been copied 119 * and the trailing NUL added). 120 * 121 * If @count is smaller than the length of the string, copies @count-1 bytes, 122 * sets the last byte of @dst buffer to NUL and returns @count. 123 */ 124 long strncpy_from_unsafe(char *dst, const void *unsafe_addr, long count) 125 { 126 mm_segment_t old_fs = get_fs(); 127 const void *src = unsafe_addr; 128 long ret; 129 130 if (unlikely(count <= 0)) 131 return 0; 132 133 set_fs(KERNEL_DS); 134 pagefault_disable(); 135 136 do { 137 ret = __get_user(*dst++, (const char __user __force *)src++); 138 } while (dst[-1] && ret == 0 && src - unsafe_addr < count); 139 140 dst[-1] = '\0'; 141 pagefault_enable(); 142 set_fs(old_fs); 143 144 return ret ? -EFAULT : src - unsafe_addr; 145 } 146 147 /** 148 * strncpy_from_unsafe_user: - Copy a NUL terminated string from unsafe user 149 * address. 150 * @dst: Destination address, in kernel space. This buffer must be at 151 * least @count bytes long. 152 * @unsafe_addr: Unsafe user address. 153 * @count: Maximum number of bytes to copy, including the trailing NUL. 154 * 155 * Copies a NUL-terminated string from unsafe user address to kernel buffer. 156 * 157 * On success, returns the length of the string INCLUDING the trailing NUL. 158 * 159 * If access fails, returns -EFAULT (some data may have been copied 160 * and the trailing NUL added). 161 * 162 * If @count is smaller than the length of the string, copies @count-1 bytes, 163 * sets the last byte of @dst buffer to NUL and returns @count. 164 */ 165 long strncpy_from_unsafe_user(char *dst, const void __user *unsafe_addr, 166 long count) 167 { 168 mm_segment_t old_fs = get_fs(); 169 long ret; 170 171 if (unlikely(count <= 0)) 172 return 0; 173 174 set_fs(USER_DS); 175 pagefault_disable(); 176 ret = strncpy_from_user(dst, unsafe_addr, count); 177 pagefault_enable(); 178 set_fs(old_fs); 179 180 if (ret >= count) { 181 ret = count; 182 dst[ret - 1] = '\0'; 183 } else if (ret > 0) { 184 ret++; 185 } 186 187 return ret; 188 } 189 190 /** 191 * strnlen_unsafe_user: - Get the size of a user string INCLUDING final NUL. 192 * @unsafe_addr: The string to measure. 193 * @count: Maximum count (including NUL) 194 * 195 * Get the size of a NUL-terminated string in user space without pagefault. 196 * 197 * Returns the size of the string INCLUDING the terminating NUL. 198 * 199 * If the string is too long, returns a number larger than @count. User 200 * has to check the return value against "> count". 201 * On exception (or invalid count), returns 0. 202 * 203 * Unlike strnlen_user, this can be used from IRQ handler etc. because 204 * it disables pagefaults. 205 */ 206 long strnlen_unsafe_user(const void __user *unsafe_addr, long count) 207 { 208 mm_segment_t old_fs = get_fs(); 209 int ret; 210 211 set_fs(USER_DS); 212 pagefault_disable(); 213 ret = strnlen_user(unsafe_addr, count); 214 pagefault_enable(); 215 set_fs(old_fs); 216 217 return ret; 218 } 219