001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 * http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.apache.commons.compress.compressors.lz4;
020
021import java.io.IOException;
022import java.io.InputStream;
023import java.util.Arrays;
024import java.util.zip.CheckedInputStream;
025
026import org.apache.commons.compress.compressors.CompressorInputStream;
027import org.apache.commons.compress.utils.ByteUtils;
028import org.apache.commons.compress.utils.IOUtils;
029import org.apache.commons.compress.utils.InputStreamStatistics;
030import org.apache.commons.io.input.BoundedInputStream;
031
032/**
033 * CompressorInputStream for the LZ4 frame format.
034 *
035 * <p>
036 * Based on the "spec" in the version "1.5.1 (31/03/2015)"
037 * </p>
038 *
039 * @see <a href="https://lz4.github.io/lz4/lz4_Frame_format.html">LZ4 Frame Format Description</a>
040 * @since 1.14
041 * @NotThreadSafe
042 */
043public class FramedLZ4CompressorInputStream extends CompressorInputStream implements InputStreamStatistics {
044
045    /** Used by FramedLZ4CompressorOutputStream as well. */
046    static final byte[] LZ4_SIGNATURE = { 4, 0x22, 0x4d, 0x18 };
047    private static final byte[] SKIPPABLE_FRAME_TRAILER = { 0x2a, 0x4d, 0x18 };
048    private static final byte SKIPPABLE_FRAME_PREFIX_BYTE_MASK = 0x50;
049
050    static final int VERSION_MASK = 0xC0;
051    static final int SUPPORTED_VERSION = 0x40;
052    static final int BLOCK_INDEPENDENCE_MASK = 0x20;
053    static final int BLOCK_CHECKSUM_MASK = 0x10;
054    static final int CONTENT_SIZE_MASK = 0x08;
055    static final int CONTENT_CHECKSUM_MASK = 0x04;
056    static final int BLOCK_MAX_SIZE_MASK = 0x70;
057    static final int UNCOMPRESSED_FLAG_MASK = 0x80000000;
058
059    private static boolean isSkippableFrameSignature(final byte[] b) {
060        if ((b[0] & SKIPPABLE_FRAME_PREFIX_BYTE_MASK) != SKIPPABLE_FRAME_PREFIX_BYTE_MASK) {
061            return false;
062        }
063        for (int i = 1; i < 4; i++) {
064            if (b[i] != SKIPPABLE_FRAME_TRAILER[i - 1]) {
065                return false;
066            }
067        }
068        return true;
069    }
070
071    /**
072     * Checks if the signature matches what is expected for a .lz4 file.
073     * <p>
074     * .lz4 files start with a four byte signature.
075     * </p>
076     *
077     * @param signature the bytes to check
078     * @param length    the number of bytes to check
079     * @return true if this is a .sz stream, false otherwise
080     */
081    public static boolean matches(final byte[] signature, final int length) {
082
083        if (length < LZ4_SIGNATURE.length) {
084            return false;
085        }
086
087        byte[] shortenedSig = signature;
088        if (signature.length > LZ4_SIGNATURE.length) {
089            shortenedSig = Arrays.copyOf(signature, LZ4_SIGNATURE.length);
090        }
091
092        return Arrays.equals(shortenedSig, LZ4_SIGNATURE);
093    }
094
095    /** Used in no-arg read method. */
096    private final byte[] oneByte = new byte[1];
097    private final ByteUtils.ByteSupplier supplier = this::readOneByte;
098
099    private final BoundedInputStream inputStream;
100    private final boolean decompressConcatenated;
101    private boolean expectBlockChecksum;
102    private boolean expectBlockDependency;
103
104    private boolean expectContentChecksum;
105
106    private InputStream currentBlock;
107
108    private boolean endReached, inUncompressed;
109
110    /** Used for frame header checksum and content checksum, if present. */
111    private final org.apache.commons.codec.digest.XXHash32 contentHash = new org.apache.commons.codec.digest.XXHash32();
112
113    /** Used for block checksum, if present. */
114    private final org.apache.commons.codec.digest.XXHash32 blockHash = new org.apache.commons.codec.digest.XXHash32();
115
116    /** Only created if the frame doesn't set the block independence flag. */
117    private byte[] blockDependencyBuffer;
118
119    /**
120     * Creates a new input stream that decompresses streams compressed using the LZ4 frame format and stops after decompressing the first frame.
121     *
122     * @param in the InputStream from which to read the compressed data
123     * @throws IOException if reading fails
124     */
125    public FramedLZ4CompressorInputStream(final InputStream in) throws IOException {
126        this(in, false);
127    }
128
129    /**
130     * Creates a new input stream that decompresses streams compressed using the LZ4 frame format.
131     *
132     * @param in                     the InputStream from which to read the compressed data
133     * @param decompressConcatenated if true, decompress until the end of the input; if false, stop after the first LZ4 frame and leave the input position to
134     *                               point to the next byte after the frame stream
135     * @throws IOException if reading fails
136     */
137    public FramedLZ4CompressorInputStream(final InputStream in, final boolean decompressConcatenated) throws IOException {
138        this.inputStream = BoundedInputStream.builder().setInputStream(in).get();
139        this.decompressConcatenated = decompressConcatenated;
140        init(true);
141    }
142
143    private void appendToBlockDependencyBuffer(final byte[] b, final int off, int len) {
144        len = Math.min(len, blockDependencyBuffer.length);
145        if (len > 0) {
146            final int keep = blockDependencyBuffer.length - len;
147            if (keep > 0) {
148                // move last keep bytes towards the start of the buffer
149                System.arraycopy(blockDependencyBuffer, len, blockDependencyBuffer, 0, keep);
150            }
151            // append new data
152            System.arraycopy(b, off, blockDependencyBuffer, keep, len);
153        }
154    }
155
156    /** {@inheritDoc} */
157    @Override
158    public void close() throws IOException {
159        try {
160            if (currentBlock != null) {
161                currentBlock.close();
162                currentBlock = null;
163            }
164        } finally {
165            inputStream.close();
166        }
167    }
168
169    /**
170     * @since 1.17
171     */
172    @Override
173    public long getCompressedCount() {
174        return inputStream.getCount();
175    }
176
177    private void init(final boolean firstFrame) throws IOException {
178        if (readSignature(firstFrame)) {
179            readFrameDescriptor();
180            nextBlock();
181        }
182    }
183
184    private void maybeFinishCurrentBlock() throws IOException {
185        if (currentBlock != null) {
186            currentBlock.close();
187            currentBlock = null;
188            if (expectBlockChecksum) {
189                verifyChecksum(blockHash, "block");
190                blockHash.reset();
191            }
192        }
193    }
194
195    private void nextBlock() throws IOException {
196        maybeFinishCurrentBlock();
197        final long len = ByteUtils.fromLittleEndian(supplier, 4);
198        final boolean uncompressed = (len & UNCOMPRESSED_FLAG_MASK) != 0;
199        final int realLen = (int) (len & ~UNCOMPRESSED_FLAG_MASK);
200        if (realLen == 0) {
201            verifyContentChecksum();
202            if (!decompressConcatenated) {
203                endReached = true;
204            } else {
205                init(false);
206            }
207            return;
208        }
209        // @formatter:off
210        InputStream capped = BoundedInputStream.builder()
211                .setInputStream(inputStream)
212                .setMaxCount(realLen)
213                .setPropagateClose(false)
214                .get();
215        // @formatter:on
216        if (expectBlockChecksum) {
217            capped = new CheckedInputStream(capped, blockHash);
218        }
219        if (uncompressed) {
220            inUncompressed = true;
221            currentBlock = capped;
222        } else {
223            inUncompressed = false;
224            final BlockLZ4CompressorInputStream s = new BlockLZ4CompressorInputStream(capped);
225            if (expectBlockDependency) {
226                s.prefill(blockDependencyBuffer);
227            }
228            currentBlock = s;
229        }
230    }
231
232    /** {@inheritDoc} */
233    @Override
234    public int read() throws IOException {
235        return read(oneByte, 0, 1) == -1 ? -1 : oneByte[0] & 0xFF;
236    }
237
238    /** {@inheritDoc} */
239    @Override
240    public int read(final byte[] b, final int off, final int len) throws IOException {
241        if (len == 0) {
242            return 0;
243        }
244        if (endReached) {
245            return -1;
246        }
247        int r = readOnce(b, off, len);
248        if (r == -1) {
249            nextBlock();
250            if (!endReached) {
251                r = readOnce(b, off, len);
252            }
253        }
254        if (r != -1) {
255            if (expectBlockDependency) {
256                appendToBlockDependencyBuffer(b, off, r);
257            }
258            if (expectContentChecksum) {
259                contentHash.update(b, off, r);
260            }
261        }
262        return r;
263    }
264
265    private void readFrameDescriptor() throws IOException {
266        final int flags = readOneByte();
267        if (flags == -1) {
268            throw new IOException("Premature end of stream while reading frame flags");
269        }
270        contentHash.update(flags);
271        if ((flags & VERSION_MASK) != SUPPORTED_VERSION) {
272            throw new IOException("Unsupported version " + (flags >> 6));
273        }
274        expectBlockDependency = (flags & BLOCK_INDEPENDENCE_MASK) == 0;
275        if (expectBlockDependency) {
276            if (blockDependencyBuffer == null) {
277                blockDependencyBuffer = new byte[BlockLZ4CompressorInputStream.WINDOW_SIZE];
278            }
279        } else {
280            blockDependencyBuffer = null;
281        }
282        expectBlockChecksum = (flags & BLOCK_CHECKSUM_MASK) != 0;
283        final boolean expectContentSize = (flags & CONTENT_SIZE_MASK) != 0;
284        expectContentChecksum = (flags & CONTENT_CHECKSUM_MASK) != 0;
285        final int bdByte = readOneByte();
286        if (bdByte == -1) { // max size is irrelevant for this implementation
287            throw new IOException("Premature end of stream while reading frame BD byte");
288        }
289        contentHash.update(bdByte);
290        if (expectContentSize) { // for now, we don't care, contains the uncompressed size
291            final byte[] contentSize = new byte[8];
292            final int skipped = IOUtils.readFully(inputStream, contentSize);
293            count(skipped);
294            if (8 != skipped) {
295                throw new IOException("Premature end of stream while reading content size");
296            }
297            contentHash.update(contentSize, 0, contentSize.length);
298        }
299        final int headerHash = readOneByte();
300        if (headerHash == -1) { // partial hash of header.
301            throw new IOException("Premature end of stream while reading frame header checksum");
302        }
303        final int expectedHash = (int) (contentHash.getValue() >> 8 & 0xff);
304        contentHash.reset();
305        if (headerHash != expectedHash) {
306            throw new IOException("Frame header checksum mismatch");
307        }
308    }
309
310    private int readOnce(final byte[] b, final int off, final int len) throws IOException {
311        if (inUncompressed) {
312            final int cnt = currentBlock.read(b, off, len);
313            count(cnt);
314            return cnt;
315        }
316        final BlockLZ4CompressorInputStream l = (BlockLZ4CompressorInputStream) currentBlock;
317        final long before = l.getBytesRead();
318        final int cnt = currentBlock.read(b, off, len);
319        count(l.getBytesRead() - before);
320        return cnt;
321    }
322
323    private int readOneByte() throws IOException {
324        final int b = inputStream.read();
325        if (b != -1) {
326            count(1);
327            return b & 0xFF;
328        }
329        return -1;
330    }
331
332    private boolean readSignature(final boolean firstFrame) throws IOException {
333        final String garbageMessage = firstFrame ? "Not a LZ4 frame stream" : "LZ4 frame stream followed by garbage";
334        final byte[] b = new byte[4];
335        int read = IOUtils.readFully(inputStream, b);
336        count(read);
337        if (0 == read && !firstFrame) {
338            // good LZ4 frame and nothing after it
339            endReached = true;
340            return false;
341        }
342        if (4 != read) {
343            throw new IOException(garbageMessage);
344        }
345
346        read = skipSkippableFrame(b);
347        if (0 == read && !firstFrame) {
348            // good LZ4 frame with only some skippable frames after it
349            endReached = true;
350            return false;
351        }
352        if (4 != read || !matches(b, 4)) {
353            throw new IOException(garbageMessage);
354        }
355        return true;
356    }
357
358    /**
359     * Skips over the contents of a skippable frame as well as skippable frames following it.
360     * <p>
361     * It then tries to read four more bytes which are supposed to hold an LZ4 signature and returns the number of bytes read while storing the bytes in the
362     * given array.
363     * </p>
364     */
365    private int skipSkippableFrame(final byte[] b) throws IOException {
366        int read = 4;
367        while (read == 4 && isSkippableFrameSignature(b)) {
368            final long len = ByteUtils.fromLittleEndian(supplier, 4);
369            if (len < 0) {
370                throw new IOException("Found illegal skippable frame with negative size");
371            }
372            final long skipped = org.apache.commons.io.IOUtils.skip(inputStream, len);
373            count(skipped);
374            if (len != skipped) {
375                throw new IOException("Premature end of stream while skipping frame");
376            }
377            read = IOUtils.readFully(inputStream, b);
378            count(read);
379        }
380        return read;
381    }
382
383    private void verifyChecksum(final org.apache.commons.codec.digest.XXHash32 hash, final String kind) throws IOException {
384        final byte[] checksum = new byte[4];
385        final int read = IOUtils.readFully(inputStream, checksum);
386        count(read);
387        if (4 != read) {
388            throw new IOException("Premature end of stream while reading " + kind + " checksum");
389        }
390        final long expectedHash = hash.getValue();
391        if (expectedHash != ByteUtils.fromLittleEndian(checksum)) {
392            throw new IOException(kind + " checksum mismatch.");
393        }
394    }
395
396    private void verifyContentChecksum() throws IOException {
397        if (expectContentChecksum) {
398            verifyChecksum(contentHash, "content");
399        }
400        contentHash.reset();
401    }
402}