001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *   http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.compress.utils;
019
020import java.io.File;
021import java.io.IOException;
022import java.nio.ByteBuffer;
023import java.nio.channels.ClosedChannelException;
024import java.nio.channels.NonWritableChannelException;
025import java.nio.channels.SeekableByteChannel;
026import java.nio.file.Files;
027import java.nio.file.Path;
028import java.nio.file.StandardOpenOption;
029import java.util.ArrayList;
030import java.util.Arrays;
031import java.util.Collections;
032import java.util.List;
033import java.util.Objects;
034
/**
 * Implements a read-only {@link SeekableByteChannel} that concatenates a collection of other {@link SeekableByteChannel}s.
 * <p>
 * This is a loose port of <a href=
 * "https://github.com/frugalmechanic/fm-common/blob/master/jvm/src/main/scala/fm/common/MultiReadOnlySeekableByteChannel.scala">
 * MultiReadOnlySeekableByteChannel</a>
 * by Tim Underwood.
 * </p>
 * <p>
 * Reading and positioning are synchronized on this instance.
 * </p>
 *
 * @since 1.19
 */
public class MultiReadOnlySeekableByteChannel implements SeekableByteChannel {

    private static final Path[] EMPTY_PATH_ARRAY = {};

    /**
     * Concatenates the given files.
     *
     * @param files the files to concatenate
     * @throws NullPointerException if files is null
     * @throws IOException          if opening a channel for one of the files fails
     * @return SeekableByteChannel that concatenates all provided files
     */
    public static SeekableByteChannel forFiles(final File... files) throws IOException {
        final List<Path> paths = new ArrayList<>();
        for (final File f : Objects.requireNonNull(files, "files")) {
            paths.add(f.toPath());
        }
        return forPaths(paths.toArray(EMPTY_PATH_ARRAY));
    }

    /**
     * Concatenates the given file paths.
     * <p>
     * If a single path is given, its channel is returned directly instead of being wrapped. If opening any path fails, every channel opened so far is
     * closed before the exception propagates.
     * </p>
     *
     * @param paths the file paths to concatenate, note that the LAST FILE of files should be the LAST SEGMENT(.zip) and these files should be added in correct
     *              order (e.g.: .z01, .z02... .z99, .zip)
     * @return SeekableByteChannel that concatenates all provided files
     * @throws NullPointerException if paths is null
     * @throws IOException          if opening a channel for one of the files fails
     * @since 1.22
     */
    public static SeekableByteChannel forPaths(final Path... paths) throws IOException {
        final List<SeekableByteChannel> channels = new ArrayList<>();
        try {
            for (final Path path : Objects.requireNonNull(paths, "paths")) {
                channels.add(Files.newByteChannel(path, StandardOpenOption.READ));
            }
        } catch (final IOException | RuntimeException e) {
            // Don't leak the channels that were already opened before the failure.
            closeAll(channels, e);
            throw e;
        }
        if (channels.size() == 1) {
            return channels.get(0);
        }
        return new MultiReadOnlySeekableByteChannel(channels);
    }

    /**
     * Concatenates the given channels.
     * <p>
     * If a single channel is given, it is returned directly instead of being wrapped.
     * </p>
     *
     * @param channels the channels to concatenate
     * @throws NullPointerException if channels is null
     * @return SeekableByteChannel that concatenates all provided channels
     */
    public static SeekableByteChannel forSeekableByteChannels(final SeekableByteChannel... channels) {
        if (Objects.requireNonNull(channels, "channels").length == 1) {
            return channels[0];
        }
        return new MultiReadOnlySeekableByteChannel(Arrays.asList(channels));
    }

    /**
     * Closes every channel in the list, recording each close failure as a suppressed exception of {@code cause}.
     *
     * @param channels the channels to close
     * @param cause    the exception that triggered the cleanup
     */
    private static void closeAll(final List<SeekableByteChannel> channels, final Throwable cause) {
        for (final SeekableByteChannel channel : channels) {
            try {
                channel.close();
            } catch (final IOException suppressed) {
                cause.addSuppressed(suppressed);
            }
        }
    }

    /** The wrapped channels, in concatenation order. */
    private final List<SeekableByteChannel> channelList;

    /** Position within the concatenated view; guarded by {@code this}. */
    private long globalPosition;

    /** Index into {@link #channelList} of the channel the next read starts from; guarded by {@code this}. */
    private int currentChannelIdx;

    /** Set once {@link #close()} has been invoked, so the channel stays closed even when {@link #channelList} is empty. */
    private boolean closed;

    /**
     * Concatenates the given channels.
     *
     * @param channels the channels to concatenate; the list is copied
     * @throws NullPointerException if channels is null
     */
    public MultiReadOnlySeekableByteChannel(final List<SeekableByteChannel> channels) {
        this.channelList = Collections.unmodifiableList(new ArrayList<>(Objects.requireNonNull(channels, "channels")));
    }

    /**
     * Closes all wrapped channels.
     * <p>
     * Every wrapped channel is closed even if closing an earlier one fails; the first failure is rethrown as the cause of an {@link IOException}.
     * </p>
     *
     * @throws IOException if closing any wrapped channel fails
     */
    @Override
    public void close() throws IOException {
        closed = true;
        IOException first = null;
        for (final SeekableByteChannel ch : channelList) {
            try {
                ch.close();
            } catch (final IOException ex) {
                if (first == null) {
                    first = ex;
                }
            }
        }
        if (first != null) {
            throw new IOException("failed to close wrapped channel", first);
        }
    }

    /**
     * Tells whether this channel is open.
     *
     * @return {@code true} if {@link #close()} has not been invoked and every wrapped channel is open
     */
    @Override
    public boolean isOpen() {
        return !closed && channelList.stream().allMatch(SeekableByteChannel::isOpen);
    }

    /**
     * Gets this channel's position.
     * <p>
     * This method violates the contract of {@link SeekableByteChannel#position()} as it will not throw any exception when invoked on a closed channel. Instead
     * it will return the position the channel had when close has been called.
     * </p>
     *
     * @return the position within the concatenated view
     */
    @Override
    public synchronized long position() {
        return globalPosition;
    }

    /**
     * Sets this channel's position.
     * <p>
     * Every wrapped channel is repositioned. Channels that lie wholly before {@code newPosition} are moved to their end. The channel containing it is moved
     * to the matching relative offset. All later channels are reset to 0. If {@code newPosition} is beyond the total size, every wrapped channel is moved to
     * its end.
     * </p>
     *
     * @param newPosition the new position, must not be negative
     * @return this channel
     * @throws IllegalArgumentException if {@code newPosition} is negative
     * @throws ClosedChannelException   if this channel is closed
     * @throws IOException              if querying or positioning a wrapped channel fails
     */
    @Override
    public synchronized SeekableByteChannel position(final long newPosition) throws IOException {
        if (newPosition < 0) {
            throw new IllegalArgumentException("Negative position: " + newPosition);
        }
        if (!isOpen()) {
            throw new ClosedChannelException();
        }
        globalPosition = newPosition;
        long pos = newPosition;
        for (int i = 0; i < channelList.size(); i++) {
            final SeekableByteChannel currentChannel = channelList.get(i);
            final long size = currentChannel.size();

            final long newChannelPos;
            if (pos == -1L) {
                // Position is already set for the correct channel,
                // the rest of the channels get reset to 0
                newChannelPos = 0;
            } else if (pos <= size) {
                // This channel is where we want to be
                currentChannelIdx = i;
                final long tmp = pos;
                pos = -1L; // Mark pos as already being set
                newChannelPos = tmp;
            } else {
                // newPosition is past this channel. Set channel
                // position to the end and subtract channel size from
                // pos
                pos -= size;
                newChannelPos = size;
            }
            currentChannel.position(newChannelPos);
        }
        return this;
    }

    /**
     * Sets the position based on the given channel number and relative offset.
     * <p>
     * The global position is computed as the sum of the sizes of the channels before {@code channelNumber}, plus {@code relativeOffset}. It is then
     * applied via {@link #position(long)}.
     * </p>
     *
     * @param channelNumber  the zero-based channel number
     * @param relativeOffset the relative offset in the corresponding channel
     * @return this channel
     * @throws ClosedChannelException    if this channel is closed
     * @throws IndexOutOfBoundsException if {@code channelNumber} is greater than the number of channels
     * @throws IllegalArgumentException  if the resulting global position is negative
     * @throws IOException               if positioning fails
     */
    public synchronized SeekableByteChannel position(final long channelNumber, final long relativeOffset) throws IOException {
        if (!isOpen()) {
            throw new ClosedChannelException();
        }
        // Named "target" so it does not shadow the globalPosition field.
        long target = relativeOffset;
        for (int i = 0; i < channelNumber; i++) {
            target += channelList.get(i).size();
        }

        return position(target);
    }

    /**
     * Reads bytes from the concatenated channels into {@code dst}.
     * <p>
     * Reading continues across channel boundaries until {@code dst} is full or every channel is exhausted.
     * </p>
     *
     * @param dst the buffer to fill
     * @return the number of bytes read; 0 if {@code dst} has no remaining space; -1 if no bytes could be read because all channels are exhausted
     * @throws ClosedChannelException if this channel is closed
     * @throws IOException            if reading a wrapped channel fails
     */
    @Override
    public synchronized int read(final ByteBuffer dst) throws IOException {
        if (!isOpen()) {
            throw new ClosedChannelException();
        }
        if (!dst.hasRemaining()) {
            return 0;
        }

        int totalBytesRead = 0;
        while (dst.hasRemaining() && currentChannelIdx < channelList.size()) {
            final SeekableByteChannel currentChannel = channelList.get(currentChannelIdx);
            final int newBytesRead = currentChannel.read(dst);
            if (newBytesRead == -1) {
                // EOF for this channel -- advance to next channel idx
                currentChannelIdx += 1;
                continue;
            }
            if (currentChannel.position() >= currentChannel.size()) {
                // we are at the end of the current channel
                currentChannelIdx++;
            }
            totalBytesRead += newBytesRead;
        }
        if (totalBytesRead > 0) {
            globalPosition += totalBytesRead;
            return totalBytesRead;
        }
        return -1;
    }

    /**
     * Gets the total size of all wrapped channels.
     *
     * @return the sum of the wrapped channels' sizes
     * @throws ClosedChannelException if this channel is closed
     * @throws IOException            if querying a wrapped channel's size fails
     */
    @Override
    public long size() throws IOException {
        if (!isOpen()) {
            throw new ClosedChannelException();
        }
        long acc = 0;
        for (final SeekableByteChannel ch : channelList) {
            acc += ch.size();
        }
        return acc;
    }

    /**
     * @throws NonWritableChannelException since this implementation is read-only.
     */
    @Override
    public SeekableByteChannel truncate(final long size) {
        throw new NonWritableChannelException();
    }

    /**
     * @throws NonWritableChannelException since this implementation is read-only.
     */
    @Override
    public int write(final ByteBuffer src) {
        throw new NonWritableChannelException();
    }

}