[FLINK-23399][checkpoint] Add a benchmark for rescaling
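
Adds a JMH benchmark that measures the time to rescale keyed state. The common
RescalingBenchmarkBase provides the RESCALE_IN (2 -> 1) and RESCALE_OUT (1 -> 2)
parameters, a deterministic byte[] record generator and a keyed function that
fills a ValueState with random bytes for every key; executors are provided for
HashMapStateBackend (~270 MB of state) and EmbeddedRocksDBStateBackend
(~2.2 GB of state).
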
diff --git a/src/main/java/org/apache/flink/state/benchmark/HashMapStateBackendRescalingBenchmarkExecutor.java b/src/main/java/org/apache/flink/state/benchmark/HashMapStateBackendRescalingBenchmarkExecutor.java
new file mode 100644
index 0000000..33ee9b3
--- /dev/null
+++ b/src/main/java/org/apache/flink/state/benchmark/HashMapStateBackendRescalingBenchmarkExecutor.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.state.benchmark;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.config.ConfigUtil;
+import org.apache.flink.config.StateBenchmarkOptions;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.contrib.streaming.state.benchmark.RescalingBenchmarkBuilder;
+import org.apache.flink.runtime.state.hashmap.HashMapStateBackend;
+import org.apache.flink.runtime.state.storage.FileSystemCheckpointStorage;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Level;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.TearDown;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.runner.RunnerException;
+
+import java.io.IOException;
+import java.net.URI;
+
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static org.openjdk.jmh.annotations.Mode.AverageTime;
+
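+/** Rescaling benchmark for the heap-based {@link HashMapStateBackend}. */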
+@OutputTimeUnit(MILLISECONDS)
+@BenchmarkMode(AverageTime)
+@Warmup(iterations = 3)
+public class HashMapStateBackendRescalingBenchmarkExecutor extends RescalingBenchmarkBase {
+    // numberOfKeys = 1_250_000, keyLen = 96, valueLen = 128, state size ~= 270MB
+    private final int numberOfKeys = 1_250_000;
+    private final int keyLen = 96;
+
+    public static void main(String[] args) throws RunnerException {
+        runBenchmark(HashMapStateBackendRescalingBenchmarkExecutor.class);
+    }
+
+    @Setup(Level.Trial)
+    public void setUp() throws Exception {
+        // FsStateBackend is deprecated; use HashMapStateBackend with an explicit checkpoint storage.
+        HashMapStateBackend stateBackend = new HashMapStateBackend();
+        Configuration benchMarkConfig = ConfigUtil.loadBenchMarkConf();
+        String stateDataDirPath = benchMarkConfig.getString(StateBenchmarkOptions.STATE_DATA_DIR);
+        benchmark =
+                new RescalingBenchmarkBuilder<byte[]>()
+                        .setMaxParallelism(128)
+                        .setParallelismBefore(rescaleType.getParallelismBefore())
+                        .setParallelismAfter(rescaleType.getParallelismAfter())
+                        .setCheckpointStorageAccess(
+                                new FileSystemCheckpointStorage(new URI("file://" + stateDataDirPath), 0)
+                                        .createCheckpointStorage(new JobID()))
+                        .setStateBackend(stateBackend)
+                        .setStreamRecordGenerator(new ByteArrayRecordGenerator(numberOfKeys, keyLen))
+                        .setStateProcessFunctionSupplier(TestKeyedFunction::new)
+                        .build();
+        benchmark.setUp();
+    }
+
+    @Setup(Level.Iteration)
+    public void setUpPerIteration() throws Exception {
+        benchmark.prepareStateForOperator(rescaleType.getSubtaskIndex());
+    }
+
+    @TearDown(Level.Trial)
+    public void tearDown() throws IOException {
+        benchmark.tearDown();
+    }
+
+    @Benchmark
+    public void rescaleHeap() throws Exception {
+        benchmark.rescale();
+    }
+
+    @TearDown(Level.Iteration)
+    public void tearDownPerIteration() throws Exception {
+        benchmark.closeOperator();
+    }
+}
diff --git a/src/main/java/org/apache/flink/state/benchmark/RescalingBenchmarkBase.java b/src/main/java/org/apache/flink/state/benchmark/RescalingBenchmarkBase.java
new file mode 100644
index 0000000..26fc415
--- /dev/null
+++ b/src/main/java/org/apache/flink/state/benchmark/RescalingBenchmarkBase.java
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.state.benchmark;
+
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.benchmark.BenchmarkBase;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.contrib.streaming.state.benchmark.RescalingBenchmark;
+import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.runner.Runner;
+import org.openjdk.jmh.runner.RunnerException;
+import org.openjdk.jmh.runner.options.Options;
+import org.openjdk.jmh.runner.options.OptionsBuilder;
+import org.openjdk.jmh.runner.options.VerboseMode;
+
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.Random;
+
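+/**
+ * Base class for the rescaling benchmarks: JMH runner entry point, rescale parameters, and the
+ * record generator and keyed function used to populate the state that is rescaled.
+ */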
+public class RescalingBenchmarkBase extends BenchmarkBase {
+
+    @Param({"RESCALE_IN", "RESCALE_OUT"})
+    protected RescaleType rescaleType;
+
+    protected RescalingBenchmark<byte[]> benchmark;
+
+    public static void runBenchmark(Class<?> clazz) throws RunnerException {
+        Options options =
+                new OptionsBuilder()
+                        .verbosity(VerboseMode.NORMAL)
+                        .include(".*" + clazz.getCanonicalName() + ".*")
+                        .build();
+
+        new Runner(options).run();
+    }
+
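+    /** Parallelism change to benchmark: scale out from 1 to 2 subtasks or in from 2 to 1. */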
+    @State(Scope.Thread)
+    public enum RescaleType {
+        RESCALE_OUT(1, 2, 0),
+        RESCALE_IN(2, 1, 0);
+
+        private final int parallelismBefore;
+        private final int parallelismAfter;
+        private final int subtaskIndex;
+
+        RescaleType(int parallelismBefore, int parallelismAfter, int subtaskIdx) {
+            this.parallelismBefore = parallelismBefore;
+            this.parallelismAfter = parallelismAfter;
+            this.subtaskIndex = subtaskIdx;
+        }
+
+        public int getParallelismBefore() {
+            return parallelismBefore;
+        }
+
+        public int getParallelismAfter() {
+            return parallelismAfter;
+        }
+
+        public int getSubtaskIndex() {
+            return subtaskIndex;
+        }
+    }
+
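+    /** Generates deterministic byte[] records whose first four bytes encode a counter. */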
+    protected static class ByteArrayRecordGenerator
+            implements RescalingBenchmark.StreamRecordGenerator<byte[]> {
+        private final Random random = new Random(0);
+        private final int numberOfKeys;
+        private final byte[] fatArray;
+        private int count = 0;
+
+        protected ByteArrayRecordGenerator(final int numberOfKeys, final int keyLen) {
+            this.numberOfKeys = numberOfKeys;
+            fatArray = new byte[keyLen];
+        }
+
+        // generate deterministic elements for source
+        @Override
+        public Iterator<StreamRecord<byte[]>> generate() {
+            return new Iterator<StreamRecord<byte[]>>() {
+                @Override
+                public boolean hasNext() {
+                    return count < numberOfKeys;
+                }
+
+                @Override
+                public StreamRecord<byte[]> next() {
+                    random.nextBytes(fatArray);
+                    // overwrite the first four bytes with the counter so that every key is unique
+                    changePrefixOfArray(count, fatArray);
+                    StreamRecord<byte[]> record =
+                            new StreamRecord<>(Arrays.copyOf(fatArray, fatArray.length), 0);
+                    count += 1;
+                    return record;
+                }
+            };
+        }
+
+        @Override
+        public TypeInformation<byte[]> getTypeInformation() {
+            return PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO;
+        }
+
+        private void changePrefixOfArray(int number, byte[] fatArray) {
+            fatArray[0] = (byte) ((number >> 24) & 0xFF);
+            fatArray[1] = (byte) ((number >> 16) & 0xFF);
+            fatArray[2] = (byte) ((number >> 8) & 0xFF);
+            fatArray[3] = (byte) (number & 0xFF);
+        }
+    }
+
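+    /** Keyed function that stores a random byte[] value in {@link ValueState} for every key. */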
+    protected static class TestKeyedFunction extends KeyedProcessFunction<byte[], byte[], Void> {
+
+        private static final long serialVersionUID = 1L;
+        private final Random random = new Random(0);
+        private final int valueLen = 128;
+
+        private ValueState<byte[]> randomState;
+        private final byte[] stateArray = new byte[valueLen];
+
+        @Override
+        public void open(Configuration parameters) throws Exception {
+            super.open(parameters);
+            randomState =
+                    this.getRuntimeContext()
+                            .getState(new ValueStateDescriptor<>("RandomState", byte[].class));
+        }
+
+        @Override
+        public void processElement(
+                byte[] value,
+                KeyedProcessFunction<byte[], byte[], Void>.Context ctx,
+                Collector<Void> out)
+                throws Exception {
+            random.nextBytes(stateArray);
+            randomState.update(Arrays.copyOf(stateArray, stateArray.length));
+        }
+    }
+}
diff --git a/src/main/java/org/apache/flink/state/benchmark/RocksdbStateBackendRescalingBenchmarkExecutor.java b/src/main/java/org/apache/flink/state/benchmark/RocksdbStateBackendRescalingBenchmarkExecutor.java
new file mode 100644
index 0000000..99a5118
--- /dev/null
+++ b/src/main/java/org/apache/flink/state/benchmark/RocksdbStateBackendRescalingBenchmarkExecutor.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.state.benchmark;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.config.ConfigUtil;
+import org.apache.flink.config.StateBenchmarkOptions;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.contrib.streaming.state.EmbeddedRocksDBStateBackend;
+import org.apache.flink.contrib.streaming.state.benchmark.RescalingBenchmarkBuilder;
+import org.apache.flink.runtime.state.storage.FileSystemCheckpointStorage;
+
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Level;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.TearDown;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.runner.RunnerException;
+
+import java.io.IOException;
+
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static org.openjdk.jmh.annotations.Mode.AverageTime;
+
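+/** Rescaling benchmark for {@link EmbeddedRocksDBStateBackend} with incremental checkpoints. */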
+@OutputTimeUnit(MILLISECONDS)
+@BenchmarkMode(AverageTime)
+@Warmup(iterations = 3)
+public class RocksdbStateBackendRescalingBenchmarkExecutor extends RescalingBenchmarkBase {
+    // numberOfKeys = 10_000_000, keyLen = 96, valueLen = 128, state size ~= 2.2GB
+    private final int numberOfKeys = 10_000_000;
+    private final int keyLen = 96;
+
+    public static void main(String[] args) throws RunnerException {
+        runBenchmark(RocksdbStateBackendRescalingBenchmarkExecutor.class);
+    }
+
+    @Setup(Level.Trial)
+    public void setUp() throws Exception {
+        EmbeddedRocksDBStateBackend stateBackend = new EmbeddedRocksDBStateBackend(true);
+        Configuration benchMarkConfig = ConfigUtil.loadBenchMarkConf();
+        String stateDataDirPath = benchMarkConfig.getString(StateBenchmarkOptions.STATE_DATA_DIR);
+        benchmark =
+                new RescalingBenchmarkBuilder<byte[]>()
+                        .setMaxParallelism(128)
+                        .setParallelismBefore(rescaleType.getParallelismBefore())
+                        .setParallelismAfter(rescaleType.getParallelismAfter())
+                        .setManagedMemorySize(512 * 1024 * 1024)
+                        .setCheckpointStorageAccess(
+                                new FileSystemCheckpointStorage("file://" + stateDataDirPath)
+                                        .createCheckpointStorage(new JobID()))
+                        .setStateBackend(stateBackend)
+                        .setStreamRecordGenerator(new ByteArrayRecordGenerator(numberOfKeys, keyLen))
+                        .setStateProcessFunctionSupplier(TestKeyedFunction::new)
+                        .build();
+        benchmark.setUp();
+    }
+
+    @Setup(Level.Iteration)
+    public void setUpPerIteration() throws Exception {
+        benchmark.prepareStateForOperator(rescaleType.getSubtaskIndex());
+    }
+
+    @TearDown(Level.Trial)
+    public void tearDown() throws IOException {
+        benchmark.tearDown();
+    }
+
+    @Benchmark
+    public void rescaleRocksDB() throws Exception {
+        benchmark.rescale();
+    }
+
+    @TearDown(Level.Iteration)
+    public void tearDownPerIteration() throws Exception {
+        benchmark.closeOperator();
+    }
+}