[IO-853] BoundedInputStream.reset() not updating count
diff --git a/src/changes/changes.xml b/src/changes/changes.xml
index 813ef3b..281f0a0 100644
--- a/src/changes/changes.xml
+++ b/src/changes/changes.xml
@@ -51,6 +51,7 @@
<action dev="ggregory" type="fix" due-to="Gary Gregory">Reimplement FileSystemUtils using NIO.</action>
<action dev="ggregory" type="fix" issue="IO-851" due-to="Sebb, Gary Gregory">FileSystemUtils no longer throws IllegalStateException.</action>
<action dev="ggregory" type="fix" due-to="Gary Gregory">Avoid possible NullPointerException in FileUtils.listAccumulate(File, IOFileFilter, IOFileFilter, FileVisitOption...).</action>
+ <action dev="ggregory" type="fix" issue="IO-853" due-to="Mike Drob, Gary Gregory">BoundedInputStream.reset() not updating count.</action>
<!-- UPDATE -->
<action dev="ggregory" type="update" due-to="Gary Gregory">Bump commons.bytebuddy.version from 1.14.12 to 1.14.13 #605.</action>
<action dev="ggregory" type="update" due-to="Gary Gregory, Dependabot">Bump org.apache.commons:commons-parent from 67 to 69 #608.</action>
diff --git a/src/main/java/org/apache/commons/io/input/BoundedInputStream.java b/src/main/java/org/apache/commons/io/input/BoundedInputStream.java
index 91e56f7..31767e4 100644
--- a/src/main/java/org/apache/commons/io/input/BoundedInputStream.java
+++ b/src/main/java/org/apache/commons/io/input/BoundedInputStream.java
@@ -249,6 +249,9 @@
/** The current count of bytes counted. */
private long count;
+ /** The current mark. */
+ private long mark;
+
/** The max count of bytes to read. */
private final long maxCount;
@@ -347,7 +350,7 @@
* @return The count of bytes read.
* @since 2.12.0
*/
- public long getCount() {
+ public synchronized long getCount() {
return count;
}
@@ -404,6 +407,7 @@
@Override
public synchronized void mark(final int readLimit) {
in.mark(readLimit);
+ mark = count;
}
/**
@@ -482,6 +486,7 @@
@Override
public synchronized void reset() throws IOException {
in.reset();
+ count = mark;
}
/**
@@ -504,7 +509,7 @@
* @throws IOException if an I/O error occurs.
*/
@Override
- public long skip(final long n) throws IOException {
+ public synchronized long skip(final long n) throws IOException {
final long skip = super.skip(toReadLen(n));
count += skip;
return skip;
diff --git a/src/test/java/org/apache/commons/io/input/BoundedInputStreamTest.java b/src/test/java/org/apache/commons/io/input/BoundedInputStreamTest.java
index ad3d50e..8265f76 100644
--- a/src/test/java/org/apache/commons/io/input/BoundedInputStreamTest.java
+++ b/src/test/java/org/apache/commons/io/input/BoundedInputStreamTest.java
@@ -27,6 +27,7 @@
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.commons.io.IOUtils;
+import org.apache.commons.lang3.mutable.MutableInt;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
@@ -36,10 +37,12 @@
*/
public class BoundedInputStreamTest {
- private void compare(final String msg, final byte[] expected, final byte[] actual) {
- assertEquals(expected.length, actual.length, msg + " length");
+ private void compare(final String message, final byte[] expected, final byte[] actual) {
+ assertEquals(expected.length, actual.length, () -> message + " (array length equals check)");
+ final MutableInt mi = new MutableInt();
for (int i = 0; i < expected.length; i++) {
- assertEquals(expected[i], actual[i], msg + " byte[" + i + "]");
+ mi.setValue(i);
+ assertEquals(expected[i], actual[i], () -> message + " byte[" + mi + "]");
}
}
@@ -144,6 +147,107 @@
}
}
+ @Test
+ public void testMarkReset() throws Exception {
+ final byte[] helloWorld = "Hello World".getBytes(StandardCharsets.UTF_8);
+ final int helloWorldLen = helloWorld.length;
+ final byte[] hello = "Hello".getBytes(StandardCharsets.UTF_8);
+ final byte[] world = " World".getBytes(StandardCharsets.UTF_8);
+ final int helloLen = hello.length;
+ // limit = -1
+ try (BoundedInputStream bounded = BoundedInputStream.builder().setInputStream(new ByteArrayInputStream(helloWorld)).get()) {
+ assertTrue(bounded.markSupported());
+ bounded.mark(0);
+ compare("limit = -1", helloWorld, IOUtils.toByteArray(bounded));
+ // should be invariant
+ assertTrue(bounded.markSupported());
+ // again
+ bounded.reset();
+ compare("limit = -1", hello, IOUtils.toByteArray(bounded, helloLen));
+ bounded.mark(helloWorldLen);
+ compare("limit = -1", world, IOUtils.toByteArray(bounded));
+ bounded.reset();
+ compare("limit = -1", world, IOUtils.toByteArray(bounded));
+ // should be invariant
+ assertTrue(bounded.markSupported());
+ }
+ // limit = 0
+ try (BoundedInputStream bounded = BoundedInputStream.builder().setInputStream(new ByteArrayInputStream(helloWorld)).setMaxCount(0).get()) {
+ assertTrue(bounded.markSupported());
+ bounded.mark(0);
+ compare("limit = 0", IOUtils.EMPTY_BYTE_ARRAY, IOUtils.toByteArray(bounded));
+ // should be invariant
+ assertTrue(bounded.markSupported());
+ // again
+ bounded.reset();
+ compare("limit = 0", IOUtils.EMPTY_BYTE_ARRAY, IOUtils.toByteArray(bounded));
+ bounded.mark(helloWorldLen);
+ compare("limit = 0", IOUtils.EMPTY_BYTE_ARRAY, IOUtils.toByteArray(bounded));
+ // should be invariant
+ assertTrue(bounded.markSupported());
+ }
+ // limit = length
+ try (BoundedInputStream bounded = BoundedInputStream.builder().setInputStream(new ByteArrayInputStream(helloWorld))
+ .setMaxCount(helloWorld.length).get()) {
+ assertTrue(bounded.markSupported());
+ bounded.mark(0);
+ compare("limit = length", helloWorld, IOUtils.toByteArray(bounded));
+ // should be invariant
+ assertTrue(bounded.markSupported());
+ // again
+ bounded.reset();
+ compare("limit = length", hello, IOUtils.toByteArray(bounded, helloLen));
+ bounded.mark(helloWorldLen);
+ compare("limit = length", world, IOUtils.toByteArray(bounded));
+ bounded.reset();
+ compare("limit = length", world, IOUtils.toByteArray(bounded));
+ // should be invariant
+ assertTrue(bounded.markSupported());
+ }
+ // limit > length
+ try (BoundedInputStream bounded = BoundedInputStream.builder().setInputStream(new ByteArrayInputStream(helloWorld))
+ .setMaxCount(helloWorld.length + 1).get()) {
+ assertTrue(bounded.markSupported());
+ bounded.mark(0);
+ compare("limit > length", helloWorld, IOUtils.toByteArray(bounded));
+ // should be invariant
+ assertTrue(bounded.markSupported());
+ // again
+ bounded.reset();
+ compare("limit > length", helloWorld, IOUtils.toByteArray(bounded));
+ bounded.reset();
+ compare("limit > length", hello, IOUtils.toByteArray(bounded, helloLen));
+ bounded.mark(helloWorldLen);
+ compare("limit > length", world, IOUtils.toByteArray(bounded));
+ bounded.reset();
+ compare("limit > length", world, IOUtils.toByteArray(bounded));
+ // should be invariant
+ assertTrue(bounded.markSupported());
+ }
+ // limit < length
+ try (BoundedInputStream bounded = BoundedInputStream.builder().setInputStream(new ByteArrayInputStream(helloWorld))
+ .setMaxCount(helloWorld.length - (hello.length + 1)).get()) {
+ assertTrue(bounded.markSupported());
+ bounded.mark(0);
+ compare("limit < length", hello, IOUtils.toByteArray(bounded));
+ // should be invariant
+ assertTrue(bounded.markSupported());
+ // again
+ bounded.reset();
+ compare("limit < length", hello, IOUtils.toByteArray(bounded));
+ //
+ bounded.reset();
+ compare("limit < length", hello, IOUtils.toByteArray(bounded, helloLen));
+ bounded.mark(helloWorldLen);
+ compare("limit < length", IOUtils.EMPTY_BYTE_ARRAY, IOUtils.toByteArray(bounded));
+ bounded.reset();
+ compare("limit < length", IOUtils.EMPTY_BYTE_ARRAY, IOUtils.toByteArray(bounded));
+
+ // should be invariant
+ assertTrue(bounded.markSupported());
+ }
+ }
+
@SuppressWarnings("deprecation")
@Test
public void testOnMaxLength() throws Exception {
@@ -326,35 +430,70 @@
public void testReset() throws Exception {
final byte[] helloWorld = "Hello World".getBytes(StandardCharsets.UTF_8);
final byte[] hello = "Hello".getBytes(StandardCharsets.UTF_8);
+ // limit = -1
try (BoundedInputStream bounded = BoundedInputStream.builder().setInputStream(new ByteArrayInputStream(helloWorld)).get()) {
assertTrue(bounded.markSupported());
+ bounded.reset();
+ compare("limit = -1", helloWorld, IOUtils.toByteArray(bounded));
+ // should be invariant
+ assertTrue(bounded.markSupported());
+ // again
+ bounded.reset();
compare("limit = -1", helloWorld, IOUtils.toByteArray(bounded));
// should be invariant
assertTrue(bounded.markSupported());
}
+ // limit = 0
try (BoundedInputStream bounded = BoundedInputStream.builder().setInputStream(new ByteArrayInputStream(helloWorld)).setMaxCount(0).get()) {
assertTrue(bounded.markSupported());
+ bounded.reset();
+ compare("limit = 0", IOUtils.EMPTY_BYTE_ARRAY, IOUtils.toByteArray(bounded));
+ // should be invariant
+ assertTrue(bounded.markSupported());
+ // again
+ bounded.reset();
compare("limit = 0", IOUtils.EMPTY_BYTE_ARRAY, IOUtils.toByteArray(bounded));
// should be invariant
assertTrue(bounded.markSupported());
}
+ // limit = length
try (BoundedInputStream bounded = BoundedInputStream.builder().setInputStream(new ByteArrayInputStream(helloWorld))
.setMaxCount(helloWorld.length).get()) {
assertTrue(bounded.markSupported());
+ bounded.reset();
+ compare("limit = length", helloWorld, IOUtils.toByteArray(bounded));
+ // should be invariant
+ assertTrue(bounded.markSupported());
+ // again
+ bounded.reset();
compare("limit = length", helloWorld, IOUtils.toByteArray(bounded));
// should be invariant
assertTrue(bounded.markSupported());
}
+ // limit > length
try (BoundedInputStream bounded = BoundedInputStream.builder().setInputStream(new ByteArrayInputStream(helloWorld))
.setMaxCount(helloWorld.length + 1).get()) {
assertTrue(bounded.markSupported());
+ bounded.reset();
+ compare("limit > length", helloWorld, IOUtils.toByteArray(bounded));
+ // should be invariant
+ assertTrue(bounded.markSupported());
+ // again
+ bounded.reset();
compare("limit > length", helloWorld, IOUtils.toByteArray(bounded));
// should be invariant
assertTrue(bounded.markSupported());
}
+ // limit < length
try (BoundedInputStream bounded = BoundedInputStream.builder().setInputStream(new ByteArrayInputStream(helloWorld))
.setMaxCount(helloWorld.length - 6).get()) {
assertTrue(bounded.markSupported());
+ bounded.reset();
+ compare("limit < length", hello, IOUtils.toByteArray(bounded));
+ // should be invariant
+ assertTrue(bounded.markSupported());
+ // again
+ bounded.reset();
compare("limit < length", hello, IOUtils.toByteArray(bounded));
// should be invariant
assertTrue(bounded.markSupported());