1 /* SPDX-License-Identifier: GPL-2.0-or-later */ 2 /* 3 * decompress_common.h - Code shared by the XPRESS and LZX decompressors 4 * 5 * Copyright (C) 2015 Eric Biggers 6 */ 7 8 #include <linux/string.h> 9 #include <linux/compiler.h> 10 #include <linux/types.h> 11 #include <linux/slab.h> 12 #include <asm/unaligned.h> 13 14 15 /* "Force inline" macro (not required, but helpful for performance) */ 16 #define forceinline __always_inline 17 18 /* Enable whole-word match copying on selected architectures */ 19 #if defined(__i386__) || defined(__x86_64__) || defined(__ARM_FEATURE_UNALIGNED) 20 # define FAST_UNALIGNED_ACCESS 21 #endif 22 23 /* Size of a machine word */ 24 #define WORDBYTES (sizeof(size_t)) 25 26 static forceinline void 27 copy_unaligned_word(const void *src, void *dst) 28 { 29 put_unaligned(get_unaligned((const size_t *)src), (size_t *)dst); 30 } 31 32 33 /* Generate a "word" with platform-dependent size whose bytes all contain the 34 * value 'b'. 35 */ 36 static forceinline size_t repeat_byte(u8 b) 37 { 38 size_t v; 39 40 v = b; 41 v |= v << 8; 42 v |= v << 16; 43 v |= v << ((WORDBYTES == 8) ? 32 : 0); 44 return v; 45 } 46 47 /* Structure that encapsulates a block of in-memory data being interpreted as a 48 * stream of bits, optionally with interwoven literal bytes. Bits are assumed 49 * to be stored in little endian 16-bit coding units, with the bits ordered high 50 * to low. 51 */ 52 struct input_bitstream { 53 54 /* Bits that have been read from the input buffer. The bits are 55 * left-justified; the next bit is always bit 31. 56 */ 57 u32 bitbuf; 58 59 /* Number of bits currently held in @bitbuf. */ 60 u32 bitsleft; 61 62 /* Pointer to the next byte to be retrieved from the input buffer. */ 63 const u8 *next; 64 65 /* Pointer to just past the end of the input buffer. */ 66 const u8 *end; 67 }; 68 69 /* Initialize a bitstream to read from the specified input buffer. */ 70 static forceinline void init_input_bitstream(struct input_bitstream *is, 71 const void *buffer, u32 size) 72 { 73 is->bitbuf = 0; 74 is->bitsleft = 0; 75 is->next = buffer; 76 is->end = is->next + size; 77 } 78 79 /* Ensure the bit buffer variable for the bitstream contains at least @num_bits 80 * bits. Following this, bitstream_peek_bits() and/or bitstream_remove_bits() 81 * may be called on the bitstream to peek or remove up to @num_bits bits. Note 82 * that @num_bits must be <= 16. 83 */ 84 static forceinline void bitstream_ensure_bits(struct input_bitstream *is, 85 u32 num_bits) 86 { 87 if (is->bitsleft < num_bits) { 88 if (is->end - is->next >= 2) { 89 is->bitbuf |= (u32)get_unaligned_le16(is->next) 90 << (16 - is->bitsleft); 91 is->next += 2; 92 } 93 is->bitsleft += 16; 94 } 95 } 96 97 /* Return the next @num_bits bits from the bitstream, without removing them. 98 * There must be at least @num_bits remaining in the buffer variable, from a 99 * previous call to bitstream_ensure_bits(). 100 */ 101 static forceinline u32 102 bitstream_peek_bits(const struct input_bitstream *is, const u32 num_bits) 103 { 104 return (is->bitbuf >> 1) >> (sizeof(is->bitbuf) * 8 - num_bits - 1); 105 } 106 107 /* Remove @num_bits from the bitstream. There must be at least @num_bits 108 * remaining in the buffer variable, from a previous call to 109 * bitstream_ensure_bits(). 110 */ 111 static forceinline void 112 bitstream_remove_bits(struct input_bitstream *is, u32 num_bits) 113 { 114 is->bitbuf <<= num_bits; 115 is->bitsleft -= num_bits; 116 } 117 118 /* Remove and return @num_bits bits from the bitstream. There must be at least 119 * @num_bits remaining in the buffer variable, from a previous call to 120 * bitstream_ensure_bits(). 121 */ 122 static forceinline u32 123 bitstream_pop_bits(struct input_bitstream *is, u32 num_bits) 124 { 125 u32 bits = bitstream_peek_bits(is, num_bits); 126 127 bitstream_remove_bits(is, num_bits); 128 return bits; 129 } 130 131 /* Read and return the next @num_bits bits from the bitstream. */ 132 static forceinline u32 133 bitstream_read_bits(struct input_bitstream *is, u32 num_bits) 134 { 135 bitstream_ensure_bits(is, num_bits); 136 return bitstream_pop_bits(is, num_bits); 137 } 138 139 /* Read and return the next literal byte embedded in the bitstream. */ 140 static forceinline u8 141 bitstream_read_byte(struct input_bitstream *is) 142 { 143 if (unlikely(is->end == is->next)) 144 return 0; 145 return *is->next++; 146 } 147 148 /* Read and return the next 16-bit integer embedded in the bitstream. */ 149 static forceinline u16 150 bitstream_read_u16(struct input_bitstream *is) 151 { 152 u16 v; 153 154 if (unlikely(is->end - is->next < 2)) 155 return 0; 156 v = get_unaligned_le16(is->next); 157 is->next += 2; 158 return v; 159 } 160 161 /* Read and return the next 32-bit integer embedded in the bitstream. */ 162 static forceinline u32 163 bitstream_read_u32(struct input_bitstream *is) 164 { 165 u32 v; 166 167 if (unlikely(is->end - is->next < 4)) 168 return 0; 169 v = get_unaligned_le32(is->next); 170 is->next += 4; 171 return v; 172 } 173 174 /* Read into @dst_buffer an array of literal bytes embedded in the bitstream. 175 * Return either a pointer to the byte past the last written, or NULL if the 176 * read overflows the input buffer. 177 */ 178 static forceinline void *bitstream_read_bytes(struct input_bitstream *is, 179 void *dst_buffer, size_t count) 180 { 181 if ((size_t)(is->end - is->next) < count) 182 return NULL; 183 memcpy(dst_buffer, is->next, count); 184 is->next += count; 185 return (u8 *)dst_buffer + count; 186 } 187 188 /* Align the input bitstream on a coding-unit boundary. */ 189 static forceinline void bitstream_align(struct input_bitstream *is) 190 { 191 is->bitsleft = 0; 192 is->bitbuf = 0; 193 } 194 195 extern int make_huffman_decode_table(u16 decode_table[], const u32 num_syms, 196 const u32 num_bits, const u8 lens[], 197 const u32 max_codeword_len, 198 u16 working_space[]); 199 200 201 /* Reads and returns the next Huffman-encoded symbol from a bitstream. If the 202 * input data is exhausted, the Huffman symbol is decoded as if the missing bits 203 * are all zeroes. 204 */ 205 static forceinline u32 read_huffsym(struct input_bitstream *istream, 206 const u16 decode_table[], 207 u32 table_bits, 208 u32 max_codeword_len) 209 { 210 u32 entry; 211 u32 key_bits; 212 213 bitstream_ensure_bits(istream, max_codeword_len); 214 215 /* Index the decode table by the next table_bits bits of the input. */ 216 key_bits = bitstream_peek_bits(istream, table_bits); 217 entry = decode_table[key_bits]; 218 if (entry < 0xC000) { 219 /* Fast case: The decode table directly provided the 220 * symbol and codeword length. The low 11 bits are the 221 * symbol, and the high 5 bits are the codeword length. 222 */ 223 bitstream_remove_bits(istream, entry >> 11); 224 return entry & 0x7FF; 225 } 226 /* Slow case: The codeword for the symbol is longer than 227 * table_bits, so the symbol does not have an entry 228 * directly in the first (1 << table_bits) entries of the 229 * decode table. Traverse the appropriate binary tree 230 * bit-by-bit to decode the symbol. 231 */ 232 bitstream_remove_bits(istream, table_bits); 233 do { 234 key_bits = (entry & 0x3FFF) + bitstream_pop_bits(istream, 1); 235 } while ((entry = decode_table[key_bits]) >= 0xC000); 236 return entry; 237 } 238 239 /* 240 * Copy an LZ77 match at (dst - offset) to dst. 241 * 242 * The length and offset must be already validated --- that is, (dst - offset) 243 * can't underrun the output buffer, and (dst + length) can't overrun the output 244 * buffer. Also, the length cannot be 0. 245 * 246 * @bufend points to the byte past the end of the output buffer. This function 247 * won't write any data beyond this position. 248 * 249 * Returns dst + length. 250 */ 251 static forceinline u8 *lz_copy(u8 *dst, u32 length, u32 offset, const u8 *bufend, 252 u32 min_length) 253 { 254 const u8 *src = dst - offset; 255 256 /* 257 * Try to copy one machine word at a time. On i386 and x86_64 this is 258 * faster than copying one byte at a time, unless the data is 259 * near-random and all the matches have very short lengths. Note that 260 * since this requires unaligned memory accesses, it won't necessarily 261 * be faster on every architecture. 262 * 263 * Also note that we might copy more than the length of the match. For 264 * example, if a word is 8 bytes and the match is of length 5, then 265 * we'll simply copy 8 bytes. This is okay as long as we don't write 266 * beyond the end of the output buffer, hence the check for (bufend - 267 * end >= WORDBYTES - 1). 268 */ 269 #ifdef FAST_UNALIGNED_ACCESS 270 u8 * const end = dst + length; 271 272 if (bufend - end >= (ptrdiff_t)(WORDBYTES - 1)) { 273 274 if (offset >= WORDBYTES) { 275 /* The source and destination words don't overlap. */ 276 277 /* To improve branch prediction, one iteration of this 278 * loop is unrolled. Most matches are short and will 279 * fail the first check. But if that check passes, then 280 * it becomes increasing likely that the match is long 281 * and we'll need to continue copying. 282 */ 283 284 copy_unaligned_word(src, dst); 285 src += WORDBYTES; 286 dst += WORDBYTES; 287 288 if (dst < end) { 289 do { 290 copy_unaligned_word(src, dst); 291 src += WORDBYTES; 292 dst += WORDBYTES; 293 } while (dst < end); 294 } 295 return end; 296 } else if (offset == 1) { 297 298 /* Offset 1 matches are equivalent to run-length 299 * encoding of the previous byte. This case is common 300 * if the data contains many repeated bytes. 301 */ 302 size_t v = repeat_byte(*(dst - 1)); 303 304 do { 305 put_unaligned(v, (size_t *)dst); 306 src += WORDBYTES; 307 dst += WORDBYTES; 308 } while (dst < end); 309 return end; 310 } 311 /* 312 * We don't bother with special cases for other 'offset < 313 * WORDBYTES', which are usually rarer than 'offset == 1'. Extra 314 * checks will just slow things down. Actually, it's possible 315 * to handle all the 'offset < WORDBYTES' cases using the same 316 * code, but it still becomes more complicated doesn't seem any 317 * faster overall; it definitely slows down the more common 318 * 'offset == 1' case. 319 */ 320 } 321 #endif /* FAST_UNALIGNED_ACCESS */ 322 323 /* Fall back to a bytewise copy. */ 324 325 if (min_length >= 2) { 326 *dst++ = *src++; 327 length--; 328 } 329 if (min_length >= 3) { 330 *dst++ = *src++; 331 length--; 332 } 333 do { 334 *dst++ = *src++; 335 } while (--length); 336 337 return dst; 338 } 339