Add ChecksumInputStream and test (#548)
This class is intended to replace the equivalent checksum-verifying input stream in Apache Commons Compress.
diff --git a/src/main/java/org/apache/commons/io/input/ChecksumInputStream.java b/src/main/java/org/apache/commons/io/input/ChecksumInputStream.java
new file mode 100644
index 0000000..f1a2551
--- /dev/null
+++ b/src/main/java/org/apache/commons/io/input/ChecksumInputStream.java
@@ -0,0 +1,242 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.io.input;
+
+import static org.apache.commons.io.IOUtils.EOF;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.zip.CheckedInputStream;
+import java.util.zip.Checksum;
+
+import org.apache.commons.io.build.AbstractStreamBuilder;
+
+/**
+ * Automatically verifies a {@link Checksum} value once the stream is exhausted or the count threshold is reached.
+ * <p>
+ * If the {@link Checksum} does not meet the expected value when exhausted, then the input stream throws an
+ * {@link IOException}.
+ * </p>
+ * <p>
+ * If you do not need the verification or threshold feature, then use a plain {@link CheckedInputStream}.
+ * </p>
+ *
+ * @since 2.16.0
+ */
+public final class ChecksumInputStream extends CountingInputStream {
+
+ // @formatter:off
+ /**
+ * Builds a new {@link ChecksumInputStream} instance.
+ * <p>
+ * There is no default {@link Checksum}; you MUST provide one.
+ * </p>
+ * <h2>Using NIO</h2>
+ *
+ * <pre>{@code
+ * ChecksumInputStream s = ChecksumInputStream.builder()
+ * .setPath(Paths.get("MyFile.xml"))
+ * .setChecksum(new CRC32())
+ * .setExpectedChecksumValue(12345)
+ * .get();
+ * }</pre>
+ *
+ * <h2>Using IO</h2>
+ *
+ * <pre>{@code
+ * ChecksumInputStream s = ChecksumInputStream.builder()
+ * .setFile(new File("MyFile.xml"))
+ * .setChecksum(new CRC32())
+ * .setExpectedChecksumValue(12345)
+ * .get();
+ * }</pre>
+ *
+ * <h2>Validating only part of an InputStream</h2>
+ * <p>
+ * The following validates the first 100 bytes of the given input.
+ * </p>
+ * <pre>{@code
+ * ChecksumInputStream s = ChecksumInputStream.builder()
+ * .setPath(Paths.get("MyFile.xml"))
+ * .setChecksum(new CRC32())
+ * .setExpectedChecksumValue(12345)
+ * .setCountThreshold(100)
+ * .get();
+ * }</pre>
+ * <p>
+ * To validate input <em>after</em> the beginning of a stream, build an instance with an InputStream starting where you want to validate.
+ * </p>
+ * <pre>{@code
+ * InputStream inputStream = ...;
+ * inputStream.read(...);
+ * inputStream.skip(...);
+ * ChecksumInputStream s = ChecksumInputStream.builder()
+ * .setInputStream(inputStream)
+ * .setChecksum(new CRC32())
+ * .setExpectedChecksumValue(12345)
+ * .setCountThreshold(100)
+ * .get();
+ * }</pre>
+ */
+ // @formatter:on
+ public static class Builder extends AbstractStreamBuilder<ChecksumInputStream, Builder> {
+
+ /**
+ * The {@link Checksum} to verify. There is no default checksum, you MUST provide one. This avoids any issue
+ * with a default {@link Checksum} being proven deficient or insecure in the future.
+ */
+ private Checksum checksum;
+
+ /**
+ * The count threshold to limit how much input is consumed to update the {@link Checksum} before the input
+ * stream validates its value.
+ * <p>
+ * By default ({@code -1}), there is no threshold: all input updates the {@link Checksum}.
+ * </p>
+ */
+ private long countThreshold = -1;
+
+ /**
+ * The expected {@link Checksum} value at the end of the stream or at the count threshold; {@code 0} unless set.
+ */
+ private long expectedChecksumValue;
+
+ /**
+ * Builds a new {@link ChecksumInputStream} from this builder's origin, checksum, expected value, and threshold.
+ * <p>
+ * The origin must be convertible by {@link #getInputStream()}; otherwise, this call throws an
+ * {@link UnsupportedOperationException}. There is no default {@link Checksum}; call
+ * {@link #setChecksum(Checksum)} before calling this method.
+ * </p>
+ *
+ * @return a new instance.
+ * @throws IOException if an I/O error occurs while converting the origin to an InputStream.
+ * @throws UnsupportedOperationException if the origin cannot provide an InputStream.
+ * @see #getInputStream()
+ * @see #setChecksum(Checksum)
+ */
+ @SuppressWarnings("resource")
+ @Override
+ public ChecksumInputStream get() throws IOException {
+ return new ChecksumInputStream(getInputStream(), checksum, expectedChecksumValue, countThreshold);
+ }
+
+ /**
+ * Sets the {@link Checksum} to verify; there is no default, you MUST provide one.
+ *
+ * @param checksum the Checksum.
+ * @return this.
+ */
+ public Builder setChecksum(final Checksum checksum) {
+ this.checksum = checksum;
+ return this;
+ }
+
+ /**
+ * Sets the count threshold to limit how much input is consumed to update the {@link Checksum} before the input
+ * stream validates its value.
+ * <p>
+ * By default, all input updates the {@link Checksum} and its value is validated only at the end of the stream.
+ * </p>
+ *
+ * @param countThreshold the count threshold. A negative number or zero means the threshold is unbound.
+ * @return this.
+ */
+ public Builder setCountThreshold(final long countThreshold) {
+ this.countThreshold = countThreshold;
+ return this;
+ }
+
+ /**
+ * Sets the expected {@link Checksum} value once the stream is exhausted or the count threshold is reached.
+ *
+ * @param expectedChecksumValue The expected Checksum value.
+ * @return this.
+ */
+ public Builder setExpectedChecksumValue(final long expectedChecksumValue) {
+ this.expectedChecksumValue = expectedChecksumValue;
+ return this;
+ }
+
+ }
+
+ /**
+ * Constructs a new {@link Builder} for configuring and creating {@link ChecksumInputStream} instances.
+ *
+ * @return a new {@link Builder}.
+ */
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ /** The expected {@link Checksum} value, compared with the actual value when the stream verifies. */
+ private final long expectedChecksumValue;
+
+ /**
+ * The count threshold to limit how much input is consumed to update the {@link Checksum} before the input stream
+ * validates its value.
+ * <p>
+ * When this value is not positive, there is no threshold and validation happens only at the end of the stream.
+ * </p>
+ */
+ private final long countThreshold;
+
+ /**
+ * Constructs a new instance, failing fast on a null Checksum instead of failing later on the first read.
+ *
+ * @param in the stream to wrap.
+ * @param checksum a Checksum implementation, must not be null (a NullPointerException is thrown otherwise).
+ * @param expectedChecksumValue the expected checksum.
+ * @param countThreshold the count threshold to limit how much input is consumed, a negative number means the
+ * threshold is unbound.
+ */
+ private ChecksumInputStream(final InputStream in, final Checksum checksum, final long expectedChecksumValue,
+ final long countThreshold) {
+ super(new CheckedInputStream(in, java.util.Objects.requireNonNull(checksum, "checksum")));
+ this.countThreshold = countThreshold;
+ this.expectedChecksumValue = expectedChecksumValue;
+ }
+
+ @Override
+ protected synchronized void afterRead(final int n) throws IOException {
+ super.afterRead(n);
+ // Verification is due once a positive threshold has been reached, or at EOF.
+ final boolean verify = countThreshold > 0 && getByteCount() >= countThreshold || n == EOF;
+ if (verify && getChecksum().getValue() != expectedChecksumValue) {
+ throw new IOException("Checksum verification failed.");
+ }
+ }
+
+ /**
+ * Gets the {@link Checksum} of the wrapped {@link CheckedInputStream}.
+ *
+ * @return the Checksum instance; call {@link Checksum#getValue()} for the current checksum value.
+ */
+ private Checksum getChecksum() {
+ return ((CheckedInputStream) in).getChecksum();
+ }
+
+ /**
+ * Gets the byte count remaining to read before the count threshold: the threshold minus {@link #getByteCount()}.
+ *
+ * @return bytes remaining to read, a negative number means the threshold is unbound or has been passed.
+ */
+ public long getRemaining() {
+ return countThreshold - getByteCount();
+ }
+
+}
diff --git a/src/test/java/org/apache/commons/io/input/ChecksumInputStreamTest.java b/src/test/java/org/apache/commons/io/input/ChecksumInputStreamTest.java
new file mode 100644
index 0000000..94290bc
--- /dev/null
+++ b/src/test/java/org/apache/commons/io/input/ChecksumInputStreamTest.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.io.input;
+
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.zip.Adler32;
+import java.util.zip.CRC32;
+
+import org.junit.jupiter.api.Test;
+
+/**
+ * Tests {@link ChecksumInputStream}.
+ */
+public class ChecksumInputStreamTest {
+
+ @Test
+ public void testDefaultThresholdFailure() throws IOException {
+ final byte[] input = new byte[3];
+ final long wrongChecksum = -68; // arbitrary value that cannot match the checksum of the input
+ try (ChecksumInputStream stream = ChecksumInputStream.builder()
+ // @formatter:off
+ .setByteArray(input)
+ .setChecksum(new Adler32())
+ .setExpectedChecksumValue(wrongChecksum)
+ .get()) {
+ // @formatter:on
+ assertEquals(0, stream.getByteCount());
+ assertEquals(-1, stream.getRemaining());
+ // Request one byte more than is available: only the available bytes are returned.
+ assertEquals(input.length, stream.read(new byte[input.length + 1]));
+ // The following read hits EOF, which triggers the failing verification.
+ assertThrows(IOException.class, () -> stream.read(new byte[1]));
+ assertEquals(input.length, stream.getByteCount());
+ assertEquals(-4, stream.getRemaining());
+ }
+ }
+
+ @Test
+ public void testDefaultThresholdSuccess() throws IOException {
+ final byte[] data = new byte[3];
+ // Compute the reference value with an independent Adler32 instance.
+ final Adler32 reference = new Adler32();
+ reference.update(data);
+ final long expected = reference.getValue();
+ // Stream under test, using its own Adler32 instance.
+ final Adler32 actual = new Adler32();
+ try (ChecksumInputStream stream = ChecksumInputStream.builder()
+ // @formatter:off
+ .setByteArray(data)
+ .setChecksum(actual)
+ .setExpectedChecksumValue(expected)
+ .get()) {
+ // @formatter:on
+ assertEquals(0, stream.getByteCount());
+ assertEquals(-1, stream.getRemaining());
+ assertEquals(3, stream.read(data));
+ assertEquals(data.length, stream.getByteCount());
+ assertEquals(-4, stream.getRemaining());
+ assertEquals(-1, stream.read(data));
+ assertEquals(data.length, stream.getByteCount());
+ assertEquals(-4, stream.getRemaining());
+ }
+ }
+
+ @Test
+ public void testReadTakingByteArrayThrowsException() throws IOException {
+ // A negative threshold is unbound, so verification only happens at EOF.
+ final long negativeThreshold = -1859L;
+ final byte[] input = new byte[3];
+ try (ChecksumInputStream stream = ChecksumInputStream.builder()
+ // @formatter:off
+ .setByteArray(input)
+ .setChecksum(new Adler32())
+ .setExpectedChecksumValue(-68)
+ .setCountThreshold(negativeThreshold)
+ .get()) {
+ // @formatter:on
+ assertEquals(0, stream.getByteCount());
+ assertEquals(negativeThreshold, stream.getRemaining());
+ // Request one byte more than is available.
+ assertEquals(input.length, stream.read(new byte[input.length + 1]));
+ // The following read hits EOF, which triggers the failing verification.
+ assertThrows(IOException.class, () -> stream.read(new byte[1]));
+ assertEquals(input.length, stream.getByteCount());
+ assertEquals(negativeThreshold - input.length, stream.getRemaining());
+ }
+ }
+
+ @Test
+ public void testReadTakingNoArgumentsThrowsException() throws IOException {
+ final byte[] input = new byte[9];
+ // With a threshold of one byte, the very first single-byte read triggers verification.
+ try (ChecksumInputStream stream = ChecksumInputStream.builder()
+ // @formatter:off
+ .setByteArray(input)
+ .setChecksum(new CRC32())
+ .setExpectedChecksumValue(1)
+ .setCountThreshold(1)
+ .get()) {
+ // @formatter:on
+ assertEquals(0, stream.getByteCount());
+ assertEquals(1, stream.getRemaining());
+ assertThrows(IOException.class, stream::read);
+ assertEquals(1, stream.getByteCount());
+ assertEquals(0, stream.getRemaining());
+ }
+ }
+
+ @Test
+ public void testSkip() throws IOException {
+ // sanity-check
+ final CRC32 sanityCheck = new CRC32();
+ final byte[] byteArray = new byte[4];
+ sanityCheck.update(byteArray);
+ final long expectedChecksum = sanityCheck.getValue();
+ // actual
+ final CRC32 crc32 = new CRC32();
+ final InputStream byteArrayInputStream = new ByteArrayInputStream(byteArray);
+ try (ChecksumInputStream checksum = ChecksumInputStream.builder()
+ // @formatter:off
+ .setInputStream(byteArrayInputStream)
+ .setChecksum(crc32)
+ .setExpectedChecksumValue(expectedChecksum)
+ .setCountThreshold(33)
+ .get()) {
+ // @formatter:on
+ assertEquals(0, checksum.getByteCount());
+ // Consume all input and verify the read count instead of discarding it.
+ assertEquals(byteArray.length, checksum.read(byteArray));
+ assertEquals(byteArray.length, checksum.getByteCount());
+ assertEquals(29, checksum.getRemaining());
+ final long skipReturnValue = checksum.skip(1);
+ assertEquals(byteArray.length, checksum.getByteCount());
+ assertEquals(29, checksum.getRemaining());
+ assertEquals(558161692L, crc32.getValue());
+ assertEquals(0, byteArrayInputStream.available());
+ assertArrayEquals(new byte[4], byteArray);
+ assertEquals(0L, skipReturnValue);
+ assertEquals(29, checksum.getRemaining());
+ }
+ }
+
+}