1 /*------------------------------------------------------------------------- 2 * Filename: mini_inflate.c 3 * Version: $Id: mini_inflate.c,v 1.3 2002/01/24 22:58:42 rfeany Exp $ 4 * Copyright: Copyright (C) 2001, Russ Dill 5 * Author: Russ Dill <Russ.Dill@asu.edu> 6 * Description: Mini inflate implementation (RFC 1951) 7 *-----------------------------------------------------------------------*/ 8 /* 9 * 10 * This program is free software; you can redistribute it and/or modify 11 * it under the terms of the GNU General Public License as published by 12 * the Free Software Foundation; either version 2 of the License, or 13 * (at your option) any later version. 14 * 15 * This program is distributed in the hope that it will be useful, 16 * but WITHOUT ANY WARRANTY; without even the implied warranty of 17 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 18 * GNU General Public License for more details. 19 * 20 * You should have received a copy of the GNU General Public License 21 * along with this program; if not, write to the Free Software 22 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 23 * 24 */ 25 26 #include <config.h> 27 #include <jffs2/mini_inflate.h> 28 29 /* The order that the code lengths in section 3.2.7 are in */ 30 static unsigned char huffman_order[] = {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 31 11, 4, 12, 3, 13, 2, 14, 1, 15}; 32 33 inline void cramfs_memset(int *s, const int c, size n) 34 { 35 n--; 36 for (;n > 0; n--) s[n] = c; 37 s[0] = c; 38 } 39 40 /* associate a stream with a block of data and reset the stream */ 41 static void init_stream(struct bitstream *stream, unsigned char *data, 42 void *(*inflate_memcpy)(void *, const void *, size)) 43 { 44 stream->error = NO_ERROR; 45 stream->memcpy = inflate_memcpy; 46 stream->decoded = 0; 47 stream->data = data; 48 stream->bit = 0; /* The first bit of the stream is the lsb of the 49 * first byte */ 50 51 /* really sorry about all this initialization, think of a better 
way, 52 * let me know and it will get cleaned up */ 53 stream->codes.bits = 8; 54 stream->codes.num_symbols = 19; 55 stream->codes.lengths = stream->code_lengths; 56 stream->codes.symbols = stream->code_symbols; 57 stream->codes.count = stream->code_count; 58 stream->codes.first = stream->code_first; 59 stream->codes.pos = stream->code_pos; 60 61 stream->lengths.bits = 16; 62 stream->lengths.num_symbols = 288; 63 stream->lengths.lengths = stream->length_lengths; 64 stream->lengths.symbols = stream->length_symbols; 65 stream->lengths.count = stream->length_count; 66 stream->lengths.first = stream->length_first; 67 stream->lengths.pos = stream->length_pos; 68 69 stream->distance.bits = 16; 70 stream->distance.num_symbols = 32; 71 stream->distance.lengths = stream->distance_lengths; 72 stream->distance.symbols = stream->distance_symbols; 73 stream->distance.count = stream->distance_count; 74 stream->distance.first = stream->distance_first; 75 stream->distance.pos = stream->distance_pos; 76 77 } 78 79 /* pull 'bits' bits out of the stream. The last bit pulled it returned as the 80 * msb. 
(section 3.1.1) 81 */ 82 inline unsigned long pull_bits(struct bitstream *stream, 83 const unsigned int bits) 84 { 85 unsigned long ret; 86 int i; 87 88 ret = 0; 89 for (i = 0; i < bits; i++) { 90 ret += ((*(stream->data) >> stream->bit) & 1) << i; 91 92 /* if, before incrementing, we are on bit 7, 93 * go to the lsb of the next byte */ 94 if (stream->bit++ == 7) { 95 stream->bit = 0; 96 stream->data++; 97 } 98 } 99 return ret; 100 } 101 102 inline int pull_bit(struct bitstream *stream) 103 { 104 int ret = ((*(stream->data) >> stream->bit) & 1); 105 if (stream->bit++ == 7) { 106 stream->bit = 0; 107 stream->data++; 108 } 109 return ret; 110 } 111 112 /* discard bits up to the next whole byte */ 113 static void discard_bits(struct bitstream *stream) 114 { 115 if (stream->bit != 0) { 116 stream->bit = 0; 117 stream->data++; 118 } 119 } 120 121 /* No decompression, the data is all literals (section 3.2.4) */ 122 static void decompress_none(struct bitstream *stream, unsigned char *dest) 123 { 124 unsigned int length; 125 126 discard_bits(stream); 127 length = *(stream->data++); 128 length += *(stream->data++) << 8; 129 pull_bits(stream, 16); /* throw away the inverse of the size */ 130 131 stream->decoded += length; 132 stream->memcpy(dest, stream->data, length); 133 stream->data += length; 134 } 135 136 /* Read in a symbol from the stream (section 3.2.2) */ 137 static int read_symbol(struct bitstream *stream, struct huffman_set *set) 138 { 139 int bits = 0; 140 int code = 0; 141 while (!(set->count[bits] && code < set->first[bits] + 142 set->count[bits])) { 143 code = (code << 1) + pull_bit(stream); 144 if (++bits > set->bits) { 145 /* error decoding (corrupted data?) 
*/ 146 stream->error = CODE_NOT_FOUND; 147 return -1; 148 } 149 } 150 return set->symbols[set->pos[bits] + code - set->first[bits]]; 151 } 152 153 /* decompress a stream of data encoded with the passed length and distance 154 * huffman codes */ 155 static void decompress_huffman(struct bitstream *stream, unsigned char *dest) 156 { 157 struct huffman_set *lengths = &(stream->lengths); 158 struct huffman_set *distance = &(stream->distance); 159 160 int symbol, length, dist, i; 161 162 do { 163 if ((symbol = read_symbol(stream, lengths)) < 0) return; 164 if (symbol < 256) { 165 *(dest++) = symbol; /* symbol is a literal */ 166 stream->decoded++; 167 } else if (symbol > 256) { 168 /* Determine the length of the repitition 169 * (section 3.2.5) */ 170 if (symbol < 265) length = symbol - 254; 171 else if (symbol == 285) length = 258; 172 else { 173 length = pull_bits(stream, (symbol - 261) >> 2); 174 length += (4 << ((symbol - 261) >> 2)) + 3; 175 length += ((symbol - 1) % 4) << 176 ((symbol - 261) >> 2); 177 } 178 179 /* Determine how far back to go */ 180 if ((symbol = read_symbol(stream, distance)) < 0) 181 return; 182 if (symbol < 4) dist = symbol + 1; 183 else { 184 dist = pull_bits(stream, (symbol - 2) >> 1); 185 dist += (2 << ((symbol - 2) >> 1)) + 1; 186 dist += (symbol % 2) << ((symbol - 2) >> 1); 187 } 188 stream->decoded += length; 189 for (i = 0; i < length; i++) { 190 *dest = dest[-dist]; 191 dest++; 192 } 193 } 194 } while (symbol != 256); /* 256 is the end of the data block */ 195 } 196 197 /* Fill the lookup tables (section 3.2.2) */ 198 static void fill_code_tables(struct huffman_set *set) 199 { 200 int code = 0, i, length; 201 202 /* fill in the first code of each bit length, and the pos pointer */ 203 set->pos[0] = 0; 204 for (i = 1; i < set->bits; i++) { 205 code = (code + set->count[i - 1]) << 1; 206 set->first[i] = code; 207 set->pos[i] = set->pos[i - 1] + set->count[i - 1]; 208 } 209 210 /* Fill in the table of symbols in order of their huffman 
code */ 211 for (i = 0; i < set->num_symbols; i++) { 212 if ((length = set->lengths[i])) 213 set->symbols[set->pos[length]++] = i; 214 } 215 216 /* reset the pos pointer */ 217 for (i = 1; i < set->bits; i++) set->pos[i] -= set->count[i]; 218 } 219 220 static void init_code_tables(struct huffman_set *set) 221 { 222 cramfs_memset(set->lengths, 0, set->num_symbols); 223 cramfs_memset(set->count, 0, set->bits); 224 cramfs_memset(set->first, 0, set->bits); 225 } 226 227 /* read in the huffman codes for dynamic decoding (section 3.2.7) */ 228 static void decompress_dynamic(struct bitstream *stream, unsigned char *dest) 229 { 230 /* I tried my best to minimize the memory footprint here, while still 231 * keeping up performance. I really dislike the _lengths[] tables, but 232 * I see no way of eliminating them without a sizable performance 233 * impact. The first struct table keeps track of stats on each bit 234 * length. The _length table keeps a record of the bit length of each 235 * symbol. The _symbols table is for looking up symbols by the huffman 236 * code (the pos element points to the first place in the symbol table 237 * where that bit length occurs). I also hate the initization of these 238 * structs, if someone knows how to compact these, lemme know. 
*/ 239 240 struct huffman_set *codes = &(stream->codes); 241 struct huffman_set *lengths = &(stream->lengths); 242 struct huffman_set *distance = &(stream->distance); 243 244 int hlit = pull_bits(stream, 5) + 257; 245 int hdist = pull_bits(stream, 5) + 1; 246 int hclen = pull_bits(stream, 4) + 4; 247 int length, curr_code, symbol, i, last_code; 248 249 last_code = 0; 250 251 init_code_tables(codes); 252 init_code_tables(lengths); 253 init_code_tables(distance); 254 255 /* fill in the count of each bit length' as well as the lengths 256 * table */ 257 for (i = 0; i < hclen; i++) { 258 length = pull_bits(stream, 3); 259 codes->lengths[huffman_order[i]] = length; 260 if (length) codes->count[length]++; 261 262 } 263 fill_code_tables(codes); 264 265 /* Do the same for the length codes, being carefull of wrap through 266 * to the distance table */ 267 curr_code = 0; 268 while (curr_code < hlit) { 269 if ((symbol = read_symbol(stream, codes)) < 0) return; 270 if (symbol == 0) { 271 curr_code++; 272 last_code = 0; 273 } else if (symbol < 16) { /* Literal length */ 274 lengths->lengths[curr_code] = last_code = symbol; 275 lengths->count[symbol]++; 276 curr_code++; 277 } else if (symbol == 16) { /* repeat the last symbol 3 - 6 278 * times */ 279 length = 3 + pull_bits(stream, 2); 280 for (;length; length--, curr_code++) 281 if (curr_code < hlit) { 282 lengths->lengths[curr_code] = 283 last_code; 284 lengths->count[last_code]++; 285 } else { /* wrap to the distance table */ 286 distance->lengths[curr_code - hlit] = 287 last_code; 288 distance->count[last_code]++; 289 } 290 } else if (symbol == 17) { /* repeat a bit length 0 */ 291 curr_code += 3 + pull_bits(stream, 3); 292 last_code = 0; 293 } else { /* same, but more times */ 294 curr_code += 11 + pull_bits(stream, 7); 295 last_code = 0; 296 } 297 } 298 fill_code_tables(lengths); 299 300 /* Fill the distance table, don't need to worry about wrapthrough 301 * here */ 302 curr_code -= hlit; 303 while (curr_code < hdist) { 304 
if ((symbol = read_symbol(stream, codes)) < 0) return; 305 if (symbol == 0) { 306 curr_code++; 307 last_code = 0; 308 } else if (symbol < 16) { 309 distance->lengths[curr_code] = last_code = symbol; 310 distance->count[symbol]++; 311 curr_code++; 312 } else if (symbol == 16) { 313 length = 3 + pull_bits(stream, 2); 314 for (;length; length--, curr_code++) { 315 distance->lengths[curr_code] = 316 last_code; 317 distance->count[last_code]++; 318 } 319 } else if (symbol == 17) { 320 curr_code += 3 + pull_bits(stream, 3); 321 last_code = 0; 322 } else { 323 curr_code += 11 + pull_bits(stream, 7); 324 last_code = 0; 325 } 326 } 327 fill_code_tables(distance); 328 329 decompress_huffman(stream, dest); 330 } 331 332 /* fill in the length and distance huffman codes for fixed encoding 333 * (section 3.2.6) */ 334 static void decompress_fixed(struct bitstream *stream, unsigned char *dest) 335 { 336 /* let gcc fill in the initial values */ 337 struct huffman_set *lengths = &(stream->lengths); 338 struct huffman_set *distance = &(stream->distance); 339 340 cramfs_memset(lengths->count, 0, 16); 341 cramfs_memset(lengths->first, 0, 16); 342 cramfs_memset(lengths->lengths, 8, 144); 343 cramfs_memset(lengths->lengths + 144, 9, 112); 344 cramfs_memset(lengths->lengths + 256, 7, 24); 345 cramfs_memset(lengths->lengths + 280, 8, 8); 346 lengths->count[7] = 24; 347 lengths->count[8] = 152; 348 lengths->count[9] = 112; 349 350 cramfs_memset(distance->count, 0, 16); 351 cramfs_memset(distance->first, 0, 16); 352 cramfs_memset(distance->lengths, 5, 32); 353 distance->count[5] = 32; 354 355 356 fill_code_tables(lengths); 357 fill_code_tables(distance); 358 359 360 decompress_huffman(stream, dest); 361 } 362 363 /* returns the number of bytes decoded, < 0 if there was an error. Note that 364 * this function assumes that the block starts on a byte boundry 365 * (non-compliant, but I don't see where this would happen). 
section 3.2.3.
 *
 * Decodes the whole deflate stream at 'source' into 'dest', block by block.
 * Stops after the block with BFINAL set, or at the first error.
 * No size is passed for 'dest', so the caller must provide room for the
 * entire decompressed output. 'inflate_memcpy' is used to copy the
 * contents of stored (uncompressed) blocks.
 *
 * Returns the number of bytes written to dest, or the negated stream error
 * code on failure.
 */
long decompress_block(unsigned char *dest, unsigned char *source,
		      void *(*inflate_memcpy)(void *, const void *, size))
{
	int bfinal, btype;
	struct bitstream stream;

	init_stream(&stream, source, inflate_memcpy);
	do {
		bfinal = pull_bit(&stream);
		btype = pull_bits(&stream, 2);

		/* each block decoder writes at the current end of the output */
		if (btype == NO_COMP)
			decompress_none(&stream, dest + stream.decoded);
		else if (btype == DYNAMIC_COMP)
			decompress_dynamic(&stream, dest + stream.decoded);
		else if (btype == FIXED_COMP)
			decompress_fixed(&stream, dest + stream.decoded);
		else
			stream.error = COMP_UNKNOWN;
	} while (!bfinal && !stream.error);

	return stream.error ? -stream.error : stream.decoded;
}