[FLINK-10095][state] Swap serialization order in TTL value: first timestamp then user value
This closes #6510.
(cherry picked from commit b535fab4c4529a15e50174a30c3743275cedaab4)
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java
index 303285a..2f291d3 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java
@@ -190,9 +190,10 @@
/** Serializer for user state value with TTL. */
private static class TtlSerializer<T> extends CompositeSerializer<TtlValue<T>> {
+ private static final long serialVersionUID = 131020282727167064L;
TtlSerializer(TypeSerializer<T> userValueSerializer) {
- super(true, userValueSerializer, LongSerializer.INSTANCE);
+ super(true, LongSerializer.INSTANCE, userValueSerializer);
}
TtlSerializer(PrecomputedParameters precomputed, TypeSerializer<?> ... fieldSerializers) {
@@ -203,7 +204,7 @@
@Override
public TtlValue<T> createInstance(@Nonnull Object ... values) {
Preconditions.checkArgument(values.length == 2);
- return new TtlValue<>((T) values[0], (long) values[1]);
+ return new TtlValue<>((T) values[1], (long) values[0]);
}
@Override
@@ -213,7 +214,7 @@
@Override
protected Object getField(@Nonnull TtlValue<T> v, int index) {
- return index == 0 ? v.getUserValue() : v.getLastAccessTimestamp();
+ return index == 0 ? v.getLastAccessTimestamp() : v.getUserValue();
}
@SuppressWarnings("unchecked")
@@ -223,7 +224,7 @@
TypeSerializer<?> ... originalSerializers) {
Preconditions.checkNotNull(originalSerializers);
Preconditions.checkArgument(originalSerializers.length == 2);
- return new TtlSerializer<>(precomputed, (TypeSerializer<T>) originalSerializers[0]);
+ return new TtlSerializer<>(precomputed, (TypeSerializer<T>) originalSerializers[1]);
}
}
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateSnapshotTransformer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateSnapshotTransformer.java
index 228d045..1ee05cd 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateSnapshotTransformer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateSnapshotTransformer.java
@@ -19,16 +19,14 @@
package org.apache.flink.runtime.state.ttl;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
-import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.ByteArrayDataInputView;
import org.apache.flink.runtime.state.StateSnapshotTransformer;
import org.apache.flink.runtime.state.StateSnapshotTransformer.CollectionStateSnapshotTransformer;
import org.apache.flink.util.FlinkRuntimeException;
-import org.apache.flink.util.Preconditions;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
-import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.Optional;
@@ -36,10 +34,12 @@
abstract class TtlStateSnapshotTransformer<T> implements CollectionStateSnapshotTransformer<T> {
private final TtlTimeProvider ttlTimeProvider;
final long ttl;
+ private final ByteArrayDataInputView div;
TtlStateSnapshotTransformer(@Nonnull TtlTimeProvider ttlTimeProvider, long ttl) {
this.ttlTimeProvider = ttlTimeProvider;
this.ttl = ttl;
+ this.div = new ByteArrayDataInputView();
}
<V> TtlValue<V> filterTtlValue(TtlValue<V> value) {
@@ -54,10 +54,9 @@
return TtlUtils.expired(ts, ttl, ttlTimeProvider);
}
- private static long deserializeTs(
- byte[] value, int offset) throws IOException {
- return LongSerializer.INSTANCE.deserialize(
- new DataInputViewStreamWrapper(new ByteArrayInputStream(value, offset, Long.BYTES)));
+ long deserializeTs(byte[] value) throws IOException {
+ div.setData(value, 0, Long.BYTES);
+ return LongSerializer.INSTANCE.deserialize(div);
}
@Override
@@ -88,10 +87,9 @@
if (value == null) {
return null;
}
- Preconditions.checkArgument(value.length >= Long.BYTES);
long ts;
try {
- ts = deserializeTs(value, value.length - Long.BYTES);
+ ts = deserializeTs(value);
} catch (IOException e) {
throw new FlinkRuntimeException("Unexpected timestamp deserialization failure");
}
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index f7af354..17ba985 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -1400,6 +1400,10 @@
if (stateDesc instanceof ListStateDescriptor) {
Optional<StateSnapshotTransformer<SEV>> original = snapshotTransformFactory.createForDeserializedState();
return original.map(est -> createRocksDBListStateTransformer(stateDesc, est)).orElse(null);
+ } else if (stateDesc instanceof MapStateDescriptor) {
+ Optional<StateSnapshotTransformer<byte[]>> original = snapshotTransformFactory.createForSerializedState();
+ return (StateSnapshotTransformer<SV>) original
+ .map(RocksDBMapState.StateSnapshotTransformerWrapper::new).orElse(null);
} else {
Optional<StateSnapshotTransformer<byte[]>> original = snapshotTransformFactory.createForSerializedState();
return (StateSnapshotTransformer<SV>) original.orElse(null);
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBMapState.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBMapState.java
index 4ec1f77..b08eade 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBMapState.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBMapState.java
@@ -24,6 +24,8 @@
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.MapSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.memory.ByteArrayDataInputView;
+import org.apache.flink.core.memory.ByteArrayDataOutputView;
import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
@@ -31,6 +33,7 @@
import org.apache.flink.queryablestate.client.state.serialization.KvStateSerializer;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
+import org.apache.flink.runtime.state.StateSnapshotTransformer;
import org.apache.flink.runtime.state.internal.InternalMapState;
import org.apache.flink.util.FlinkRuntimeException;
import org.apache.flink.util.Preconditions;
@@ -44,9 +47,11 @@
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
import java.io.IOException;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Iterator;
import java.util.Map;
@@ -625,4 +630,74 @@
(Map<UK, UV>) stateDesc.getDefaultValue(),
backend);
}
+
+ /**
+ * RocksDB map state specific byte value transformer wrapper.
+ *
+ * <p>This specific transformer wrapper checks the first byte to detect null user value entries
+ * and if not null forward the rest of byte array to the original byte value transformer.
+ */
+ static class StateSnapshotTransformerWrapper implements StateSnapshotTransformer<byte[]> {
+ private static final byte[] NULL_VALUE;
+ private static final byte NON_NULL_VALUE_PREFIX;
+ static {
+ ByteArrayDataOutputView dov = new ByteArrayDataOutputView(1);
+ try {
+ dov.writeBoolean(true);
+ NULL_VALUE = dov.toByteArray();
+ dov.reset();
+ dov.writeBoolean(false);
+ NON_NULL_VALUE_PREFIX = dov.toByteArray()[0];
+ } catch (IOException e) {
+ throw new FlinkRuntimeException("Failed to serialize boolean flag of map user null value", e);
+ }
+ }
+
+ private final StateSnapshotTransformer<byte[]> elementTransformer;
+ private final ByteArrayDataInputView div;
+
+ StateSnapshotTransformerWrapper(StateSnapshotTransformer<byte[]> originalTransformer) {
+ this.elementTransformer = originalTransformer;
+ this.div = new ByteArrayDataInputView();
+ }
+
+ @Override
+ @Nullable
+ public byte[] filterOrTransform(@Nullable byte[] value) {
+ if (value == null || isNull(value)) {
+ return NULL_VALUE;
+ } else {
+ // we have to skip the first byte indicating null user value
+ // TODO: optimization here could be to work with slices and not byte arrays
+ // and copy slice sub-array only when needed
+ byte[] woNullByte = Arrays.copyOfRange(value, 1, value.length);
+ byte[] filteredValue = elementTransformer.filterOrTransform(woNullByte);
+ if (filteredValue == null) {
+ filteredValue = NULL_VALUE;
+ } else if (filteredValue != woNullByte) {
+ filteredValue = prependWithNonNullByte(filteredValue, value);
+ } else {
+ filteredValue = value;
+ }
+ return filteredValue;
+ }
+ }
+
+ private boolean isNull(byte[] value) {
+ try {
+ div.setData(value, 0, 1);
+ return div.readBoolean();
+ } catch (IOException e) {
+ throw new FlinkRuntimeException("Failed to deserialize boolean flag of map user null value", e);
+ }
+ }
+
+ private static byte[] prependWithNonNullByte(byte[] value, byte[] reuse) {
+ int len = 1 + value.length;
+ byte[] result = reuse.length == len ? reuse : new byte[len];
+ result[0] = NON_NULL_VALUE_PREFIX;
+ System.arraycopy(value, 0, result, 1, value.length);
+ return result;
+ }
+ }
}