HuffmanDecoder.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.compress.compressors.deflate64;
import static org.apache.commons.compress.compressors.deflate64.HuffmanState.DYNAMIC_CODES;
import static org.apache.commons.compress.compressors.deflate64.HuffmanState.FIXED_CODES;
import static org.apache.commons.compress.compressors.deflate64.HuffmanState.INITIAL;
import static org.apache.commons.compress.compressors.deflate64.HuffmanState.STORED;
import java.io.Closeable;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteOrder;
import java.util.Arrays;
import org.apache.commons.compress.utils.BitInputStream;
import org.apache.commons.compress.utils.ByteUtils;
import org.apache.commons.compress.utils.ExactMath;
import org.apache.commons.lang3.ArrayFill;
/**
 * Decodes the blocks of a Deflate64 stream read from an {@link InputStream}.
 * <p>
 * The block structure (stored, fixed Huffman and dynamic Huffman blocks) follows RFC 1951. The Deflate64 differences are visible in the
 * tables below: a 64 KiB history window, length symbol 285 carrying 16 extra bits, and the additional distance symbols 30 and 31.
 * </p>
 * <p>
 * Malformed input is reported with unchecked exceptions such as {@link IllegalStateException} and {@link IllegalArgumentException}.
 * Truncated input is reported with an {@link EOFException}.
 * </p>
 * <p>
 * TODO This class can't be final because it is mocked by Mockito.
 * </p>
 */
class HuffmanDecoder implements Closeable {

    /**
     * Node of the binary tree used to decode Huffman codes one bit at a time.
     * <p>
     * A node is either a leaf carrying a decoded symbol ({@code literal != -1}) or an inner node. The children of an inner node are
     * created lazily while the tree is built.
     * </p>
     */
    private static final class BinaryTreeNode {

        /** Depth of this node, i.e. the number of code bits consumed to reach it. */
        private final int bits;

        /** Decoded symbol if this node is a leaf, {@code -1} otherwise. */
        int literal = -1;

        /** Child reached by a 0 bit; may be {@code null}. */
        BinaryTreeNode leftNode;

        /** Child reached by a 1 bit; may be {@code null}. */
        BinaryTreeNode rightNode;

        private BinaryTreeNode(final int bits) {
            this.bits = bits;
        }

        /**
         * Turns this node into a leaf for the given symbol, dropping any children.
         *
         * @param symbol the decoded symbol.
         */
        void leaf(final int symbol) {
            literal = symbol;
            leftNode = null;
            rightNode = null;
        }

        /**
         * Returns the 0-bit child, creating it on demand unless this node is a leaf.
         *
         * @return the 0-bit child, or {@code null} if this node is a leaf.
         */
        BinaryTreeNode left() {
            if (leftNode == null && literal == -1) {
                leftNode = new BinaryTreeNode(bits + 1);
            }
            return leftNode;
        }

        /**
         * Returns the 1-bit child, creating it on demand unless this node is a leaf.
         *
         * @return the 1-bit child, or {@code null} if this node is a leaf.
         */
        BinaryTreeNode right() {
            if (rightNode == null && literal == -1) {
                rightNode = new BinaryTreeNode(bits + 1);
            }
            return rightNode;
        }
    }

    /**
     * Decoding behavior for the kind of block currently being processed.
     */
    private abstract static class DecoderState {

        /**
         * Returns the number of bytes this state reports as readable right now. Used to implement {@link HuffmanDecoder#available()}.
         *
         * @return the number of available bytes.
         * @throws IOException if the underlying reader fails.
         */
        abstract int available() throws IOException;

        /**
         * Returns whether the current block still has data to deliver.
         *
         * @return {@code true} while the current block is not exhausted.
         */
        abstract boolean hasData();

        /**
         * Decodes up to {@code len} bytes into {@code b} starting at {@code off}.
         *
         * @param b   destination buffer.
         * @param off offset into {@code b}.
         * @param len maximum number of bytes to decode.
         * @return the number of bytes decoded.
         * @throws IOException if reading the input fails.
         */
        abstract int read(byte[] b, int off, int len) throws IOException;

        /**
         * Returns the block type handled by this state.
         *
         * @return the block type, or {@link HuffmanState#INITIAL} once the current block is exhausted.
         */
        abstract HuffmanState state();
    }

    /**
     * Circular history of the most recently decoded bytes, used to resolve length/distance back references.
     */
    private static final class DecodingMemory {

        /** Circular buffer of decoded bytes; its length is a power of two. */
        private final byte[] memory;

        /** {@code memory.length - 1}, used to wrap indices. */
        private final int mask;

        /** Index at which the next byte is written. */
        private int wHead;

        /** Set once an index advanced by {@link #incCounter(int)} wraps to 0, i.e. the whole buffer holds history. */
        private boolean wrappedAround;

        /** Creates a 64 KiB (2^16 bytes) history. */
        private DecodingMemory() {
            this(16);
        }

        /**
         * Creates a history of {@code 2^bits} bytes.
         *
         * @param bits log2 of the history size.
         */
        private DecodingMemory(final int bits) {
            memory = new byte[1 << bits];
            mask = memory.length - 1;
        }

        /**
         * Appends one byte to the history.
         *
         * @param b the byte to append.
         * @return {@code b}, so calls can be chained into an assignment.
         */
        byte add(final byte b) {
            memory[wHead] = b;
            wHead = incCounter(wHead);
            return b;
        }

        /**
         * Appends {@code len} bytes of {@code b}, starting at {@code off}, to the history.
         *
         * @param b   source buffer.
         * @param off offset into {@code b}.
         * @param len number of bytes to append.
         */
        void add(final byte[] b, final int off, final int len) {
            for (int i = off; i < off + len; i++) {
                add(b[i]);
            }
        }

        /**
         * Advances a buffer index by one, wrapping at the end of the buffer.
         *
         * @param counter the index to advance.
         * @return the advanced index.
         */
        private int incCounter(final int counter) {
            final int newCounter = counter + 1 & mask;
            if (!wrappedAround && newCounter < counter) {
                wrappedAround = true;
            }
            return newCounter;
        }

        /**
         * Copies {@code length} bytes, starting {@code distance} bytes back in the history, into {@code buff}.
         * <p>
         * Each copied byte is appended to the history as it is copied. Copies with {@code length > distance} therefore repeat the most
         * recent bytes, as LZ77 back references require.
         * </p>
         *
         * @param distance how far back the copy starts.
         * @param length   number of bytes to copy.
         * @param buff     destination; must hold at least {@code length} bytes.
         * @throws IllegalStateException if {@code distance} exceeds the history size or reaches before the first byte written.
         */
        void recordToBuffer(final int distance, final int length, final byte[] buff) {
            if (distance > memory.length) {
                throw new IllegalStateException("Illegal distance parameter: " + distance);
            }
            final int start = wHead - distance & mask;
            if (!wrappedAround && start >= wHead) {
                throw new IllegalStateException("Attempt to read beyond memory: dist=" + distance);
            }
            for (int i = 0, pos = start; i < length; i++, pos = incCounter(pos)) {
                buff[i] = add(memory[pos]);
            }
        }
    }

    /**
     * Decodes a block compressed with either the fixed or dynamic Huffman codes.
     */
    private final class HuffmanCodes extends DecoderState {

        /** Set once the end-of-block symbol (256) has been decoded. */
        private boolean endOfBlock;

        /** Either {@link HuffmanState#FIXED_CODES} or {@link HuffmanState#DYNAMIC_CODES}. */
        private final HuffmanState state;

        /** Decoding tree for literal/length symbols. */
        private final BinaryTreeNode lengthTree;

        /** Decoding tree for distance symbols. */
        private final BinaryTreeNode distanceTree;

        /** Read position inside {@link #runBuffer}. */
        private int runBufferPos;

        /** Bytes produced by the last back reference that may not all have been handed out yet. */
        private byte[] runBuffer = ByteUtils.EMPTY_BYTE_ARRAY;

        /** Number of valid bytes in {@link #runBuffer}. */
        private int runBufferLength;

        /**
         * Creates a Huffman-coded block decoder.
         *
         * @param state    the block type.
         * @param lengths  code lengths of the literal/length alphabet.
         * @param distance code lengths of the distance alphabet.
         */
        HuffmanCodes(final HuffmanState state, final int[] lengths, final int[] distance) {
            this.state = state;
            lengthTree = buildTree(lengths);
            distanceTree = buildTree(distance);
        }

        @Override
        int available() {
            return runBufferLength - runBufferPos;
        }

        /**
         * Hands out bytes still pending from the last back reference.
         *
         * @param b   destination buffer.
         * @param off offset into {@code b}.
         * @param len maximum number of bytes to copy.
         * @return the number of bytes copied.
         */
        private int copyFromRunBuffer(final byte[] b, final int off, final int len) {
            final int bytesInBuffer = runBufferLength - runBufferPos;
            int copiedBytes = 0;
            if (bytesInBuffer > 0) {
                copiedBytes = Math.min(len, bytesInBuffer);
                System.arraycopy(runBuffer, runBufferPos, b, off, copiedBytes);
                runBufferPos += copiedBytes;
            }
            return copiedBytes;
        }

        /**
         * Decodes symbols until {@code len} bytes have been produced or the end-of-block symbol is seen.
         *
         * @param b   destination buffer.
         * @param off offset into {@code b}.
         * @param len maximum number of bytes to produce.
         * @return the number of bytes produced (possibly 0 when the block ends immediately), or -1 if the block had already ended.
         * @throws IOException           if the input is truncated or cannot be read.
         * @throws IllegalStateException if the input contains an invalid code or an illegal length symbol.
         */
        private int decodeNext(final byte[] b, final int off, final int len) throws IOException {
            if (endOfBlock) {
                return -1;
            }
            int result = copyFromRunBuffer(b, off, len);
            while (result < len) {
                final int symbol = nextSymbol(reader, lengthTree);
                if (symbol < 256) {
                    // literal byte
                    b[off + result++] = memory.add((byte) symbol);
                } else if (symbol > 256) {
                    // length/distance pair
                    final int lengthIndex = symbol - 257;
                    // symbols 286 and 287 exist in the fixed (and possibly a dynamic) alphabet but have no meaning
                    if (lengthIndex >= RUN_LENGTH_TABLE.length) {
                        throw new IllegalStateException("Illegal length symbol: " + symbol);
                    }
                    final int runMask = RUN_LENGTH_TABLE[lengthIndex];
                    int run = runMask >>> 5;
                    final int runXtra = runMask & 0x1F;
                    run = ExactMath.add(run, readBits(runXtra));
                    // the distance tree holds at most 32 symbols, matching DISTANCE_TABLE's length
                    final int distSym = nextSymbol(reader, distanceTree);
                    final int distMask = DISTANCE_TABLE[distSym];
                    int dist = distMask >>> 4;
                    final int distXtra = distMask & 0xF;
                    dist = ExactMath.add(dist, readBits(distXtra));
                    if (runBuffer.length < run) {
                        runBuffer = new byte[run];
                    }
                    runBufferLength = run;
                    runBufferPos = 0;
                    memory.recordToBuffer(dist, run, runBuffer);
                    result += copyFromRunBuffer(b, off + result, len - result);
                } else {
                    // symbol 256: end of block
                    endOfBlock = true;
                    return result;
                }
            }
            return result;
        }

        @Override
        boolean hasData() {
            return !endOfBlock;
        }

        @Override
        int read(final byte[] b, final int off, final int len) throws IOException {
            if (len == 0) {
                return 0;
            }
            return decodeNext(b, off, len);
        }

        @Override
        HuffmanState state() {
            return endOfBlock ? INITIAL : state;
        }
    }

    /**
     * State before the first block header and between blocks; it holds no data.
     */
    private static final class InitialState extends DecoderState {

        @Override
        int available() {
            return 0;
        }

        @Override
        boolean hasData() {
            return false;
        }

        @Override
        int read(final byte[] b, final int off, final int len) throws IOException {
            if (len == 0) {
                return 0;
            }
            throw new IllegalStateException("Cannot read in this state");
        }

        @Override
        HuffmanState state() {
            return INITIAL;
        }
    }

    /**
     * Copies the content of a stored (uncompressed) block.
     */
    private final class UncompressedState extends DecoderState {

        /** Length of the stored block, taken from its LEN field. */
        private final long blockLength;

        /** Number of bytes of this block delivered so far. */
        private long read;

        private UncompressedState(final long blockLength) {
            this.blockLength = blockLength;
        }

        @Override
        int available() throws IOException {
            return (int) Math.min(blockLength - read, reader.bitsAvailable() / Byte.SIZE);
        }

        @Override
        boolean hasData() {
            return read < blockLength;
        }

        @Override
        int read(final byte[] b, final int off, final int len) throws IOException {
            if (len == 0) {
                return 0;
            }
            // as len is an int and (blockLength - read) is >= 0 the min must fit into an int as well
            final int max = (int) Math.min(blockLength - read, len);
            int readSoFar = 0;
            while (readSoFar < max) {
                final int readNow;
                if (reader.bitsCached() > 0) {
                    // bits already buffered by the bit reader must be consumed before touching the raw stream
                    final byte next = (byte) readBits(Byte.SIZE);
                    b[off + readSoFar] = memory.add(next);
                    readNow = 1;
                } else {
                    // NOTE(review): bytes read directly from 'in' bypass the bit reader and presumably are not reflected in
                    // getBytesRead() - TODO confirm against BitInputStream.
                    readNow = in.read(b, off + readSoFar, max - readSoFar);
                    if (readNow == -1) {
                        throw new EOFException("Truncated Deflate64 Stream");
                    }
                    memory.add(b, off + readSoFar, readNow);
                }
                read += readNow;
                readSoFar += readNow;
            }
            return max;
        }

        @Override
        HuffmanState state() {
            return read < blockLength ? STORED : INITIAL;
        }
    }

    /**
     * <pre>
     * --------------------------------------------------------------------
     * idx  xtra  base     idx  xtra  base     idx  xtra  base
     * --------------------------------------------------------------------
     * 257   0     3       267   1   15,16     277   4   67-82
     * 258   0     4       268   1   17,18     278   4   83-98
     * 259   0     5       269   2   19-22     279   4   99-114
     * 260   0     6       270   2   23-26     280   4   115-130
     * 261   0     7       271   2   27-30     281   5   131-162
     * 262   0     8       272   2   31-34     282   5   163-194
     * 263   0     9       273   3   35-42     283   5   195-226
     * 264   0     10      274   3   43-50     284   5   227-258
     * 265   1   11,12     275   3   51-58     285   16  3-65538
     * 266   1   13,14     276   3   59-66
     * --------------------------------------------------------------------
     * </pre>
     *
     * value = (base of run length) << 5 | (number of extra bits to read)
     */
    private static final short[] RUN_LENGTH_TABLE = { 96, 128, 160, 192, 224, 256, 288, 320, 353, 417, 481, 545, 610, 738, 866, 994, 1123, 1379, 1635, 1891,
            2148, 2660, 3172, 3684, 4197, 5221, 6245, 7269, 112 };

    /**
     * <pre>
     * --------------------------------------------------------------------
     * idx  xtra  dist     idx  xtra  dist       idx  xtra  dist
     * --------------------------------------------------------------------
     * 0    0     1        10   4     33-48      20    9   1025-1536
     * 1    0     2        11   4     49-64      21    9   1537-2048
     * 2    0     3        12   5     65-96      22   10   2049-3072
     * 3    0     4        13   5     97-128     23   10   3073-4096
     * 4    1     5,6      14   6     129-192    24   11   4097-6144
     * 5    1     7,8      15   6     193-256    25   11   6145-8192
     * 6    2     9-12     16   7     257-384    26   12   8193-12288
     * 7    2     13-16    17   7     385-512    27   12   12289-16384
     * 8    3     17-24    18   8     513-768    28   13   16385-24576
     * 9    3     25-32    19   8     769-1024   29   13   24577-32768
     * 30   14   32769-49152
     * 31   14   49153-65536
     * --------------------------------------------------------------------
     * </pre>
     *
     * value = (base of distance) << 4 | (number of extra bits to read)
     */
    private static final int[] DISTANCE_TABLE = { 16, 32, 48, 64, 81, 113, 146, 210, 275, 403, // 0-9
            532, 788, 1045, 1557, 2070, 3094, 4119, 6167, 8216, 12312, // 10-19
            16409, 24601, 32794, 49178, 65563, 98331, 131100, 196636, 262173, 393245, // 20-29
            524318, 786462 // 30-31
    };

    /**
     * Order in which the code lengths of the code-length alphabet are stored in a dynamic Huffman block header.
     */
    private static final int[] CODE_LENGTHS_ORDER = { 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15 };

    /**
     * Code lengths of the fixed Huffman codes used by blocks of type 1: 288 literal/length symbols and 32 distance symbols.
     */
    private static final int[] FIXED_LITERALS;
    private static final int[] FIXED_DISTANCE;

    static {
        FIXED_LITERALS = new int[288];
        Arrays.fill(FIXED_LITERALS, 0, 144, 8);
        Arrays.fill(FIXED_LITERALS, 144, 256, 9);
        Arrays.fill(FIXED_LITERALS, 256, 280, 7);
        Arrays.fill(FIXED_LITERALS, 280, 288, 8);
        FIXED_DISTANCE = ArrayFill.fill(new int[32], 5);
    }

    /**
     * Builds the decoding tree of a canonical Huffman code.
     *
     * @param litTable code length per symbol (index = symbol); 0 means the symbol is unused.
     * @return the root of the decoding tree.
     * @throws IllegalStateException    if a code's path runs through a node that is already a leaf.
     * @throws IllegalArgumentException if a code length is out of range.
     */
    private static BinaryTreeNode buildTree(final int[] litTable) {
        final int[] literalCodes = getCodes(litTable);
        final BinaryTreeNode root = new BinaryTreeNode(0);
        for (int symbol = 0; symbol < litTable.length; symbol++) {
            final int len = litTable[symbol];
            if (len == 0) {
                continue;
            }
            final int code = literalCodes[len - 1];
            BinaryTreeNode node = root;
            // walk the code from its most significant bit down to bit 0
            for (int p = len - 1; p >= 0; p--) {
                node = (code & (1 << p)) == 0 ? node.left() : node.right();
                if (node == null) {
                    throw new IllegalStateException("node doesn't exist in Huffman tree");
                }
            }
            node.leaf(symbol);
            // symbols of equal length receive consecutive codes
            literalCodes[len - 1]++;
        }
        return root;
    }

    /**
     * Computes the first canonical Huffman code for every code length (RFC 1951, section 3.2.2).
     *
     * @param litTable code length per symbol; 0 means the symbol is unused.
     * @return an array whose element {@code i} is the first code assigned to symbols of length {@code i + 1}.
     * @throws IllegalArgumentException if a code length is negative or greater than 64.
     */
    private static int[] getCodes(final int[] litTable) {
        int max = 0;
        int[] blCount = new int[65];
        for (final int codeLength : litTable) {
            if (codeLength < 0 || codeLength > 64) {
                throw new IllegalArgumentException("Invalid code " + codeLength + " in literal table");
            }
            max = Math.max(max, codeLength);
            blCount[codeLength]++;
        }
        blCount = Arrays.copyOf(blCount, max + 1);
        // Unused symbols take no part in code assignment (RFC 1951 sets bl_count[0] = 0). Counting them only added bits above each
        // code's length, which buildTree never inspects, so the resulting trees are unchanged.
        blCount[0] = 0;
        int code = 0;
        final int[] nextCode = new int[max + 1];
        for (int i = 0; i <= max; i++) {
            code = (code + blCount[i]) << 1;
            nextCode[i] = code;
        }
        return nextCode;
    }

    /**
     * Reads bits until they identify a symbol of {@code tree}.
     *
     * @param reader source of the code bits.
     * @param tree   decoding tree to walk.
     * @return the decoded symbol.
     * @throws IOException           if the input is truncated or cannot be read.
     * @throws IllegalStateException if the bits read do not correspond to any symbol of the tree. Incomplete codes are not valid input;
     *                               previously such bits silently produced -1, which callers decoded as the literal byte 0xFF.
     */
    private static int nextSymbol(final BitInputStream reader, final BinaryTreeNode tree) throws IOException {
        BinaryTreeNode node = tree;
        while (node != null && node.literal == -1) {
            final long bit = readBits(reader, 1);
            node = bit == 0 ? node.leftNode : node.rightNode;
        }
        if (node == null) {
            throw new IllegalStateException("Invalid Huffman code: bit sequence does not map to a symbol");
        }
        return node.literal;
    }

    /**
     * Reads the run-length encoded code lengths of a dynamic Huffman block into {@code literals} and {@code distances}.
     *
     * @param reader    source of the header bits.
     * @param literals  receives the literal/length code lengths; its length is the number of code lengths to read for that alphabet.
     * @param distances receives the distance code lengths; its length is the number of code lengths to read for that alphabet.
     * @throws IOException           if the input is truncated or cannot be read.
     * @throws IllegalStateException if the code-length codes are invalid.
     */
    private static void populateDynamicTables(final BitInputStream reader, final int[] literals, final int[] distances) throws IOException {
        // 4 to 19 code lengths of the code-length alphabet, 3 bits each, in CODE_LENGTHS_ORDER
        final int codeLengths = (int) (readBits(reader, 4) + 4);
        final int[] codeLengthValues = new int[19];
        for (int cLen = 0; cLen < codeLengths; cLen++) {
            codeLengthValues[CODE_LENGTHS_ORDER[cLen]] = (int) readBits(reader, 3);
        }
        final BinaryTreeNode codeLengthTree = buildTree(codeLengthValues);
        // literal/length and distance code lengths form one continuous sequence; repeats may cross the boundary
        final int[] auxBuffer = new int[literals.length + distances.length];
        // a repeat (16) before any explicit length keeps value at -1, which getCodes later rejects
        int value = -1;
        int length = 0;
        int off = 0;
        // NOTE(review): a repeat running past the end of auxBuffer is silently truncated
        while (off < auxBuffer.length) {
            if (length > 0) {
                auxBuffer[off++] = value;
                length--;
            } else {
                final int symbol = nextSymbol(reader, codeLengthTree);
                if (symbol < 16) {
                    // explicit code length 0-15
                    value = symbol;
                    auxBuffer[off++] = value;
                } else {
                    switch (symbol) {
                    case 16:
                        // repeat the previous length 3-6 times
                        length = (int) (readBits(reader, 2) + 3);
                        break;
                    case 17:
                        // repeat length 0 3-10 times
                        value = 0;
                        length = (int) (readBits(reader, 3) + 3);
                        break;
                    case 18:
                        // repeat length 0 11-138 times
                        value = 0;
                        length = (int) (readBits(reader, 7) + 11);
                        break;
                    default:
                        // unreachable: the code-length alphabet only has symbols 0-18
                        break;
                    }
                }
            }
        }
        System.arraycopy(auxBuffer, 0, literals, 0, literals.length);
        System.arraycopy(auxBuffer, literals.length, distances, 0, distances.length);
    }

    /**
     * Reads {@code numBits} bits, treating end of input as an error.
     *
     * @param reader  source of the bits.
     * @param numBits number of bits to read.
     * @return the bits read.
     * @throws EOFException if the input ends before {@code numBits} bits could be read.
     * @throws IOException  if the input cannot be read.
     */
    private static long readBits(final BitInputStream reader, final int numBits) throws IOException {
        final long r = reader.readBits(numBits);
        if (r == -1) {
            throw new EOFException("Truncated Deflate64 Stream");
        }
        return r;
    }

    /** Whether the most recently read block header had its final-block bit set. */
    private boolean finalBlock;

    /** Decoder for the block currently being processed. */
    private DecoderState state;

    /** Little-endian bit reader over {@link #in}; set to {@code null} by {@link #close()}. */
    private BitInputStream reader;

    /** The raw compressed input; stored blocks are copied from it directly once no bits are cached. */
    private final InputStream in;

    /** History used to resolve back references across blocks. */
    private final DecodingMemory memory = new DecodingMemory();

    /**
     * Creates a decoder reading compressed data from {@code in}.
     *
     * @param in the compressed input.
     */
    HuffmanDecoder(final InputStream in) {
        this.reader = new BitInputStream(in, ByteOrder.LITTLE_ENDIAN);
        this.in = in;
        state = new InitialState();
    }

    /**
     * Returns the number of bytes the current block state reports as available.
     *
     * @return the number of available bytes.
     * @throws IOException if the underlying reader fails.
     */
    int available() throws IOException {
        return state.available();
    }

    /**
     * Resets the decoder state and drops the bit reader. The underlying stream is not closed here. Later calls that need to read input
     * will fail because the reader is gone.
     */
    @Override
    public void close() {
        state = new InitialState();
        reader = null;
    }

    /**
     * Decodes into the whole of {@code b}.
     *
     * @param b destination buffer.
     * @return the number of bytes decoded, or -1 at the end of the stream.
     * @throws IOException if the input is truncated or cannot be read.
     */
    public int decode(final byte[] b) throws IOException {
        return decode(b, 0, b.length);
    }

    /**
     * Decodes up to {@code len} bytes into {@code b} starting at {@code off}, reading block headers as needed.
     *
     * @param b   destination buffer.
     * @param off offset into {@code b}.
     * @param len maximum number of bytes to decode.
     * @return the number of bytes decoded (0 only when {@code len} is 0), or -1 once the final block has been fully decoded.
     * @throws IOException if the input is truncated or cannot be read.
     */
    public int decode(final byte[] b, final int off, final int len) throws IOException {
        if (len == 0) {
            // Every state returns 0 for an empty request, so the loop below would never terminate once a block is open.
            return 0;
        }
        while (!finalBlock || state.hasData()) {
            if (state.state() == INITIAL) {
                // block header: final-block bit followed by the 2-bit block type
                finalBlock = readBits(1) == 1;
                final int mode = (int) readBits(2);
                switch (mode) {
                case 0:
                    switchToUncompressedState();
                    break;
                case 1:
                    state = new HuffmanCodes(FIXED_CODES, FIXED_LITERALS, FIXED_DISTANCE);
                    break;
                case 2: {
                    final int[][] tables = readDynamicTables();
                    state = new HuffmanCodes(DYNAMIC_CODES, tables[0], tables[1]);
                    break;
                }
                default:
                    throw new IllegalStateException("Unsupported compression: " + mode);
                }
            } else {
                final int r = state.read(b, off, len);
                // 0 means the block ended without producing data; continue with the next block header
                if (r != 0) {
                    return r;
                }
            }
        }
        return -1;
    }

    /**
     * Returns the number of bytes consumed by the bit reader.
     *
     * @return the number of compressed bytes read through the bit reader.
     * @since 1.17
     */
    long getBytesRead() {
        return reader.getBytesRead();
    }

    /**
     * Reads {@code numBits} bits from this decoder's reader, treating end of input as an error.
     *
     * @param numBits number of bits to read.
     * @return the bits read.
     * @throws IOException if the input is truncated or cannot be read.
     */
    private long readBits(final int numBits) throws IOException {
        return readBits(reader, numBits);
    }

    /**
     * Reads the header of a dynamic Huffman block.
     *
     * @return the literal/length code lengths (257-288 entries) at index 0 and the distance code lengths (1-32 entries) at index 1.
     * @throws IOException if the input is truncated or cannot be read.
     */
    private int[][] readDynamicTables() throws IOException {
        final int[][] result = new int[2][];
        final int literals = (int) (readBits(5) + 257);
        result[0] = new int[literals];
        final int distances = (int) (readBits(5) + 1);
        result[1] = new int[distances];
        populateDynamicTables(reader, result[0], result[1]);
        return result;
    }

    /**
     * Starts a stored block: skips to the next byte boundary, reads LEN and NLEN (16 bits each) and checks that NLEN is the one's
     * complement of LEN.
     *
     * @throws IOException           if the input is truncated or cannot be read.
     * @throws IllegalStateException if LEN and NLEN do not match.
     */
    private void switchToUncompressedState() throws IOException {
        reader.alignWithByteBoundary();
        final long bLen = readBits(16);
        final long bNLen = readBits(16);
        if (((bLen ^ 0xFFFF) & 0xFFFF) != bNLen) {
            // noinspection DuplicateStringLiteralInspection
            throw new IllegalStateException("Illegal LEN / NLEN values");
        }
        state = new UncompressedState(bLen);
    }
}