[FLINK-19692][core] Add a header for each keyGroup
This closes #168.
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/logger/UnboundedFeedbackLogger.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/logger/UnboundedFeedbackLogger.java
index 51e8e7b..7464f09 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/logger/UnboundedFeedbackLogger.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/logger/UnboundedFeedbackLogger.java
@@ -19,10 +19,13 @@
import static org.apache.flink.util.Preconditions.checkState;
+import java.io.BufferedInputStream;
import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
+import java.io.PushbackInputStream;
+import java.util.Arrays;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;
@@ -30,8 +33,11 @@
import java.util.function.ToIntFunction;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
+import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputSerializer;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
import org.apache.flink.statefun.flink.core.feedback.FeedbackConsumer;
@@ -104,6 +110,7 @@
// to this operator must be written to the underlying stream.
for (Integer keyGroupId : assignedKeyGroupIds) {
checkpointedStreamOperations.startNewKeyGroup(keyedStateOutputStream, keyGroupId);
+ Header.writeHeader(target);
@Nullable KeyGroupStream<T> stream = keyGroupStreams.get(keyGroupId);
if (stream == null) {
@@ -116,8 +123,8 @@
public void replyLoggedEnvelops(InputStream rawKeyedStateInputs, FeedbackConsumer<T> consumer)
throws Exception {
-
- DataInputViewStreamWrapper in = new DataInputViewStreamWrapper(rawKeyedStateInputs);
+ DataInputView in =
+ new DataInputViewStreamWrapper(Header.skipHeaderSilently(rawKeyedStateInputs));
KeyGroupStream.readFrom(in, serializer, consumer);
}
@@ -138,4 +145,37 @@
keyedStateOutputStream = null;
keyGroupStreams.clear();
}
+
+ @VisibleForTesting
+ static final class Header {
+ private static final int STATEFUN_VERSION = 0;
+ private static final int STATEFUN_MAGIC = 710818519;
+ private static final byte[] HEADER_BYTES = headerBytes();
+
+ public static void writeHeader(DataOutputView target) throws IOException {
+ target.write(HEADER_BYTES);
+ }
+
+ public static InputStream skipHeaderSilently(InputStream rawKeyedInput) throws IOException {
+ byte[] header = new byte[HEADER_BYTES.length];
+ PushbackInputStream input =
+ new PushbackInputStream(new BufferedInputStream(rawKeyedInput), header.length);
+ int bytesRead = input.read(header);
+ if (bytesRead > 0 && !Arrays.equals(header, HEADER_BYTES)) {
+ input.unread(header, 0, bytesRead);
+ }
+ return input;
+ }
+
+ private static byte[] headerBytes() {
+ DataOutputSerializer out = new DataOutputSerializer(8);
+ try {
+ out.writeInt(STATEFUN_VERSION);
+ out.writeInt(STATEFUN_MAGIC);
+ } catch (IOException e) {
+ throw new IllegalStateException("Unable to compute the header bytes");
+ }
+ return out.getCopyOfBuffer();
+ }
+ }
}
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/logger/UnboundedFeedbackLoggerTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/logger/UnboundedFeedbackLoggerTest.java
index ac7efdd..afe504a 100644
--- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/logger/UnboundedFeedbackLoggerTest.java
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/logger/UnboundedFeedbackLoggerTest.java
@@ -25,8 +25,11 @@
import java.util.function.Function;
import java.util.stream.IntStream;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputSerializer;
import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync;
import org.apache.flink.statefun.flink.core.di.ObjectContainer;
+import org.apache.flink.statefun.flink.core.logger.UnboundedFeedbackLogger.Header;
import org.hamcrest.Matchers;
import org.junit.AfterClass;
import org.junit.BeforeClass;
@@ -73,12 +76,62 @@
roundTrip(100, 1024);
}
+ @Test
+ public void roundTripWithoutElements() throws Exception {
+ roundTrip(0, 1024);
+ }
+
@Ignore
@Test
public void roundTripWithSpill() throws Exception {
roundTrip(1_000_000, 0);
}
+ @Test
+ public void roundTripWithHeader() throws IOException {
+ DataOutputSerializer out = new DataOutputSerializer(32);
+ Header.writeHeader(out);
+ out.writeInt(123);
+ out.writeInt(456);
+ InputStream in = new ByteArrayInputStream(out.getCopyOfBuffer());
+
+ DataInputViewStreamWrapper view = new DataInputViewStreamWrapper(Header.skipHeaderSilently(in));
+
+ assertThat(view.readInt(), is(123));
+ assertThat(view.readInt(), is(456));
+ }
+
+ @Test
+ public void roundTripWithoutHeader() throws IOException {
+ DataOutputSerializer out = new DataOutputSerializer(32);
+ out.writeInt(123);
+ out.writeInt(456);
+ InputStream in = new ByteArrayInputStream(out.getCopyOfBuffer());
+
+ DataInputViewStreamWrapper view = new DataInputViewStreamWrapper(Header.skipHeaderSilently(in));
+
+ assertThat(view.readInt(), is(123));
+ assertThat(view.readInt(), is(456));
+ }
+
+ @Test
+ public void emptyKeyGroupWithHeader() throws IOException {
+ DataOutputSerializer out = new DataOutputSerializer(32);
+ Header.writeHeader(out);
+ InputStream in = new ByteArrayInputStream(out.getCopyOfBuffer());
+
+ DataInputViewStreamWrapper view = new DataInputViewStreamWrapper(Header.skipHeaderSilently(in));
+
+ assertThat(view.read(), is(-1));
+ }
+
+ @Test
+ public void emptyKeyGroupWithoutHeader() throws IOException {
+ InputStream in = new ByteArrayInputStream(new byte[0]);
+ DataInputViewStreamWrapper view = new DataInputViewStreamWrapper(Header.skipHeaderSilently(in));
+ assertThat(view.read(), is(-1));
+ }
+
private void roundTrip(int numElements, int maxMemoryInBytes) throws Exception {
InputStream input = serializeKeyGroup(1, maxMemoryInBytes, numElements);