/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.scala.migration

import java.util
import org.apache.flink.api.common.accumulators.IntCounter
import org.apache.flink.api.common.functions.RichFlatMapFunction
import org.apache.flink.api.common.state._
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.functions.KeySelector
import org.apache.flink.api.java.tuple.Tuple2
import org.apache.flink.api.scala.createTypeInformation
import org.apache.flink.api.scala.migration.CustomEnum.CustomEnum
import org.apache.flink.configuration.Configuration
import org.apache.flink.contrib.streaming.state.RocksDBStateBackend
import org.apache.flink.runtime.state.memory.MemoryStateBackend
import org.apache.flink.runtime.state.{FunctionInitializationContext, FunctionSnapshotContext, StateBackendLoader}
import org.apache.flink.streaming.api.TimeCharacteristic
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
import org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction
import org.apache.flink.streaming.api.functions.source.SourceFunction
import org.apache.flink.streaming.api.watermark.Watermark
import org.apache.flink.test.checkpointing.utils.SavepointMigrationTestBase
import org.apache.flink.testutils.migration.MigrationVersion
import org.apache.flink.util.Collector
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import org.junit.{Assert, Ignore, Test}
import scala.util.{Failure, Try}
object StatefulJobWBroadcastStateMigrationITCase {
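  // Each run restores the savepoint written for one (Flink version, state backend)
  // combination; the savepoint binaries are checked into src/test/resources.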
@Parameterized.Parameters(name = "Migrate Savepoint / Backend: {0}")
def parameters: util.Collection[(MigrationVersion, String)] = {
util.Arrays.asList(
(MigrationVersion.v1_5, StateBackendLoader.MEMORY_STATE_BACKEND_NAME),
(MigrationVersion.v1_5, StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME),
(MigrationVersion.v1_6, StateBackendLoader.MEMORY_STATE_BACKEND_NAME),
(MigrationVersion.v1_6, StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME),
(MigrationVersion.v1_7, StateBackendLoader.MEMORY_STATE_BACKEND_NAME),
(MigrationVersion.v1_7, StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME),
(MigrationVersion.v1_8, StateBackendLoader.MEMORY_STATE_BACKEND_NAME),
(MigrationVersion.v1_8, StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME),
(MigrationVersion.v1_9, StateBackendLoader.MEMORY_STATE_BACKEND_NAME),
(MigrationVersion.v1_9, StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME),
(MigrationVersion.v1_10, StateBackendLoader.MEMORY_STATE_BACKEND_NAME),
(MigrationVersion.v1_10, StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME),
(MigrationVersion.v1_11, StateBackendLoader.MEMORY_STATE_BACKEND_NAME),
(MigrationVersion.v1_11, StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME),
(MigrationVersion.v1_12, StateBackendLoader.MEMORY_STATE_BACKEND_NAME),
(MigrationVersion.v1_12, StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME),
(MigrationVersion.v1_13, StateBackendLoader.MEMORY_STATE_BACKEND_NAME),
(MigrationVersion.v1_13, StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME))
}
  // TODO To generate savepoints for a specific Flink version / backend type, change these
  // TODO values accordingly; e.g. to generate for 1.3 with RocksDB, set them to
  // TODO (MigrationVersion.v1_3, StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME).
  // TODO Note: generate the savepoint from the corresponding release branch, not from master.
val GENERATE_SAVEPOINT_VER: MigrationVersion = MigrationVersion.v1_9
val GENERATE_SAVEPOINT_BACKEND_TYPE: String = StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME
val NUM_ELEMENTS = 4
}

/**
 * ITCase for migrating Scala state types (including broadcast state) across different
 * Flink versions.
 */
@RunWith(classOf[Parameterized])
class StatefulJobWBroadcastStateMigrationITCase(
migrationVersionAndBackend: (MigrationVersion, String))
extends SavepointMigrationTestBase with Serializable {
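  // Ignored by default; run manually to (re)generate the savepoint binaries that
  // testRestoreSavepointWithBroadcast reads from src/test/resources.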
@Test
@Ignore
def testCreateSavepointWithBroadcastState(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
StatefulJobWBroadcastStateMigrationITCase.GENERATE_SAVEPOINT_BACKEND_TYPE match {
case StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME =>
env.setStateBackend(new RocksDBStateBackend(new MemoryStateBackend()))
case StateBackendLoader.MEMORY_STATE_BACKEND_NAME =>
env.setStateBackend(new MemoryStateBackend())
case _ => throw new UnsupportedOperationException
}
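    // Broadcast state descriptors; names and key/value types must match those declared in
    // the process functions below and in the job that generated the savepoint.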
lazy val firstBroadcastStateDesc = new MapStateDescriptor[Long, Long](
"broadcast-state-1",
BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]],
BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]])
lazy val secondBroadcastStateDesc = new MapStateDescriptor[String, String](
"broadcast-state-2",
BasicTypeInfo.STRING_TYPE_INFO,
BasicTypeInfo.STRING_TYPE_INFO)
env.enableCheckpointing(500)
env.setParallelism(4)
env.setMaxParallelism(4)
val stream = env
.addSource(
new CheckpointedSource(4)).setMaxParallelism(1).uid("checkpointedSource")
.keyBy(
new KeySelector[(Long, Long), Long] {
override def getKey(value: (Long, Long)): Long = value._1
}
)
.flatMap(new StatefulFlatMapper)
.keyBy(
new KeySelector[(Long, Long), Long] {
override def getKey(value: (Long, Long)): Long = value._1
}
)
val broadcastStream = env
.addSource(
new CheckpointedSource(4)).setMaxParallelism(1).uid("checkpointedBroadcastSource")
.broadcast(firstBroadcastStateDesc, secondBroadcastStateDesc)
stream
.connect(broadcastStream)
.process(new TestBroadcastProcessFunction)
.addSink(new AccumulatorCountingSink)
executeAndSavepoint(
env,
s"src/test/resources/stateful-scala-with-broadcast" +
s"-udf-migration-itcase-flink" +
s"${StatefulJobWBroadcastStateMigrationITCase.GENERATE_SAVEPOINT_VER}" +
s"-${StatefulJobWBroadcastStateMigrationITCase.GENERATE_SAVEPOINT_BACKEND_TYPE}-savepoint",
new Tuple2(
AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR,
StatefulJobWBroadcastStateMigrationITCase.NUM_ELEMENTS
)
)
}
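  // Restores the savepoint generated for the (version, backend) combination under test and
  // verifies both the keyed state pipeline and the broadcast state contents.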
@Test
def testRestoreSavepointWithBroadcast(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
migrationVersionAndBackend._2 match {
case StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME =>
env.setStateBackend(new RocksDBStateBackend(new MemoryStateBackend()))
case StateBackendLoader.MEMORY_STATE_BACKEND_NAME =>
env.setStateBackend(new MemoryStateBackend())
case _ => throw new UnsupportedOperationException
}
lazy val firstBroadcastStateDesc = new MapStateDescriptor[Long, Long](
"broadcast-state-1",
BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]],
BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]])
lazy val secondBroadcastStateDesc = new MapStateDescriptor[String, String](
"broadcast-state-2",
BasicTypeInfo.STRING_TYPE_INFO,
BasicTypeInfo.STRING_TYPE_INFO)
env.enableCheckpointing(500)
env.setParallelism(4)
env.setMaxParallelism(4)
val stream = env
.addSource(
new CheckpointedSource(4)).setMaxParallelism(1).uid("checkpointedSource")
.keyBy(
new KeySelector[(Long, Long), Long] {
override def getKey(value: (Long, Long)): Long = value._1
}
)
.flatMap(new StatefulFlatMapper)
.keyBy(
new KeySelector[(Long, Long), Long] {
override def getKey(value: (Long, Long)): Long = value._1
}
)
val broadcastStream = env
.addSource(
new CheckpointedSource(4)).setMaxParallelism(1).uid("checkpointedBroadcastSource")
.broadcast(firstBroadcastStateDesc, secondBroadcastStateDesc)
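    // The generating job broadcast the elements (i, i) for i in 0 until NUM_ELEMENTS, so
    // both broadcast states are expected to contain exactly these entries.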
val expectedFirstState: Map[Long, Long] =
Map(0L -> 0L, 1L -> 1L, 2L -> 2L, 3L -> 3L)
val expectedSecondState: Map[String, String] =
Map("0" -> "0", "1" -> "1", "2" -> "2", "3" -> "3")
stream
.connect(broadcastStream)
.process(new VerifyingBroadcastProcessFunction(expectedFirstState, expectedSecondState))
.addSink(new AccumulatorCountingSink)
restoreAndExecute(
env,
SavepointMigrationTestBase.getResourceFilename(
s"stateful-scala-with-broadcast" +
s"-udf-migration-itcase-flink${migrationVersionAndBackend._1}" +
s"-${migrationVersionAndBackend._2}-savepoint"),
new Tuple2(
AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR,
StatefulJobWBroadcastStateMigrationITCase.NUM_ELEMENTS)
)
}
}
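// Used when generating the savepoint: forwards keyed elements unchanged and writes every
// broadcast element into both broadcast states.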
class TestBroadcastProcessFunction
extends KeyedBroadcastProcessFunction
[Long, (Long, Long), (Long, Long), (Long, Long)] {
lazy val firstBroadcastStateDesc = new MapStateDescriptor[Long, Long](
"broadcast-state-1",
BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]],
BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]])
  lazy val secondBroadcastStateDesc = new MapStateDescriptor[String, String](
"broadcast-state-2",
BasicTypeInfo.STRING_TYPE_INFO,
BasicTypeInfo.STRING_TYPE_INFO)
@throws[Exception]
override def processElement(
value: (Long, Long),
ctx: KeyedBroadcastProcessFunction
[Long, (Long, Long), (Long, Long), (Long, Long)]#ReadOnlyContext,
out: Collector[(Long, Long)]): Unit = {
out.collect(value)
}
@throws[Exception]
override def processBroadcastElement(
value: (Long, Long),
ctx: KeyedBroadcastProcessFunction
[Long, (Long, Long), (Long, Long), (Long, Long)]#Context,
out: Collector[(Long, Long)]): Unit = {
ctx.getBroadcastState(firstBroadcastStateDesc).put(value._1, value._2)
ctx.getBroadcastState(secondBroadcastStateDesc).put(value._1.toString, value._2.toString)
}
}
@SerialVersionUID(1L)
private object CheckpointedSource {
  val CHECKPOINTED_STRING = "Here be dragons!"
}
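// Source that emits (i, i) for i in 0 until numElements and then idles (without emitting a
// final watermark) so that checkpoints and savepoints can be taken; it also snapshots a
// small operator list state to exercise Scala case class serialization.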
@SerialVersionUID(1L)
private class CheckpointedSource(val numElements: Int)
extends SourceFunction[(Long, Long)] with CheckpointedFunction {
  // volatile: cancel() is called from a different thread than run()
  @volatile private var isRunning = true
private var state: ListState[CustomCaseClass] = _
@throws[Exception]
  override def run(ctx: SourceFunction.SourceContext[(Long, Long)]): Unit = {
ctx.emitWatermark(new Watermark(0))
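    // Emit all elements under the checkpoint lock so that no checkpoint interleaves with
    // the emission.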
    ctx.getCheckpointLock.synchronized {
      var i = 0L
      while (i < numElements) {
        ctx.collect((i, i))
        i += 1
      }
}
// don't emit a final watermark so that we don't trigger the registered event-time
// timers
while (isRunning) Thread.sleep(20)
}
  override def cancel(): Unit = {
    isRunning = false
  }
override def initializeState(context: FunctionInitializationContext): Unit = {
state = context.getOperatorStateStore.getListState(
new ListStateDescriptor[CustomCaseClass](
"sourceState", createTypeInformation[CustomCaseClass]))
}
override def snapshotState(context: FunctionSnapshotContext): Unit = {
state.clear()
    state.add(CustomCaseClass(CheckpointedSource.CHECKPOINTED_STRING, 123))
}
}
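// Sink that counts received elements in an accumulator; the test harness compares the
// accumulator value against the expected element count.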
@SerialVersionUID(1L)
private object AccumulatorCountingSink {
  val NUM_ELEMENTS_ACCUMULATOR: String =
    classOf[AccumulatorCountingSink[_]].toString + "_NUM_ELEMENTS"
}
@SerialVersionUID(1L)
private class AccumulatorCountingSink[T] extends RichSinkFunction[T] {
@throws[Exception]
  override def open(parameters: Configuration): Unit = {
super.open(parameters)
getRuntimeContext.addAccumulator(
AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR, new IntCounter)
}
@throws[Exception]
  override def invoke(value: T): Unit = {
getRuntimeContext.getAccumulator(
AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR).add(1)
}
}
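// Keyed function that populates one ValueState per Scala type whose serializer
// compatibility is under test: case classes (plain and nested), List, Try (Success and
// Failure), Option (Some and None), Either (Left and Right), and enumerations.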
class StatefulFlatMapper extends RichFlatMapFunction[(Long, Long), (Long, Long)] {
private var caseClassState: ValueState[CustomCaseClass] = _
private var caseClassWithNestingState: ValueState[CustomCaseClassWithNesting] = _
private var collectionState: ValueState[List[CustomCaseClass]] = _
private var tryState: ValueState[Try[CustomCaseClass]] = _
private var tryFailureState: ValueState[Try[CustomCaseClass]] = _
private var optionState: ValueState[Option[CustomCaseClass]] = _
private var optionNoneState: ValueState[Option[CustomCaseClass]] = _
private var eitherLeftState: ValueState[Either[CustomCaseClass, String]] = _
private var eitherRightState: ValueState[Either[CustomCaseClass, String]] = _
private var enumOneState: ValueState[CustomEnum] = _
private var enumThreeState: ValueState[CustomEnum] = _
override def open(parameters: Configuration): Unit = {
caseClassState = getRuntimeContext.getState(
new ValueStateDescriptor[CustomCaseClass](
"caseClassState", createTypeInformation[CustomCaseClass]))
caseClassWithNestingState = getRuntimeContext.getState(
new ValueStateDescriptor[CustomCaseClassWithNesting](
"caseClassWithNestingState", createTypeInformation[CustomCaseClassWithNesting]))
collectionState = getRuntimeContext.getState(
new ValueStateDescriptor[List[CustomCaseClass]](
"collectionState", createTypeInformation[List[CustomCaseClass]]))
tryState = getRuntimeContext.getState(
new ValueStateDescriptor[Try[CustomCaseClass]](
"tryState", createTypeInformation[Try[CustomCaseClass]]))
tryFailureState = getRuntimeContext.getState(
new ValueStateDescriptor[Try[CustomCaseClass]](
"tryFailureState", createTypeInformation[Try[CustomCaseClass]]))
optionState = getRuntimeContext.getState(
new ValueStateDescriptor[Option[CustomCaseClass]](
"optionState", createTypeInformation[Option[CustomCaseClass]]))
optionNoneState = getRuntimeContext.getState(
new ValueStateDescriptor[Option[CustomCaseClass]](
"optionNoneState", createTypeInformation[Option[CustomCaseClass]]))
eitherLeftState = getRuntimeContext.getState(
new ValueStateDescriptor[Either[CustomCaseClass, String]](
"eitherLeftState", createTypeInformation[Either[CustomCaseClass, String]]))
eitherRightState = getRuntimeContext.getState(
new ValueStateDescriptor[Either[CustomCaseClass, String]](
"eitherRightState", createTypeInformation[Either[CustomCaseClass, String]]))
enumOneState = getRuntimeContext.getState(
new ValueStateDescriptor[CustomEnum](
"enumOneState", createTypeInformation[CustomEnum]))
enumThreeState = getRuntimeContext.getState(
new ValueStateDescriptor[CustomEnum](
"enumThreeState", createTypeInformation[CustomEnum]))
}
override def flatMap(in: (Long, Long), collector: Collector[(Long, Long)]): Unit = {
caseClassState.update(CustomCaseClass(in._1.toString, in._2 * 2))
caseClassWithNestingState.update(
CustomCaseClassWithNesting(in._1, CustomCaseClass(in._1.toString, in._2 * 2)))
collectionState.update(List(CustomCaseClass(in._1.toString, in._2 * 2)))
tryState.update(Try(CustomCaseClass(in._1.toString, in._2 * 5)))
tryFailureState.update(Failure(new RuntimeException))
optionState.update(Some(CustomCaseClass(in._1.toString, in._2 * 2)))
optionNoneState.update(None)
eitherLeftState.update(Left(CustomCaseClass(in._1.toString, in._2 * 2)))
eitherRightState.update(Right((in._1 * 3).toString))
enumOneState.update(CustomEnum.ONE)
    enumThreeState.update(CustomEnum.THREE)
collector.collect(in)
}
}
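// Used when restoring: asserts that the restored broadcast state matches the expected maps
// before forwarding each element downstream.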
class VerifyingBroadcastProcessFunction(
firstExpectedBroadcastState: Map[Long, Long],
secondExpectedBroadcastState: Map[String, String])
extends KeyedBroadcastProcessFunction
[Long, (Long, Long), (Long, Long), (Long, Long)] {
lazy val firstBroadcastStateDesc = new MapStateDescriptor[Long, Long](
"broadcast-state-1",
BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]],
BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]])
  lazy val secondBroadcastStateDesc = new MapStateDescriptor[String, String](
"broadcast-state-2",
BasicTypeInfo.STRING_TYPE_INFO,
BasicTypeInfo.STRING_TYPE_INFO)
@throws[Exception]
override def processElement(
value: (Long, Long),
ctx: KeyedBroadcastProcessFunction
[Long, (Long, Long), (Long, Long), (Long, Long)]#ReadOnlyContext,
out: Collector[(Long, Long)]): Unit = {
    import scala.collection.JavaConverters._

    var actualFirstState = Map[Long, Long]()
    for (entry <- ctx.getBroadcastState(firstBroadcastStateDesc).immutableEntries().asScala) {
      val expected = firstExpectedBroadcastState(entry.getKey)
      Assert.assertEquals(expected, entry.getValue)
      actualFirstState += (entry.getKey -> entry.getValue)
    }
Assert.assertEquals(firstExpectedBroadcastState, actualFirstState)
    var actualSecondState = Map[String, String]()
    for (entry <- ctx.getBroadcastState(secondBroadcastStateDesc).immutableEntries().asScala) {
      val expected = secondExpectedBroadcastState(entry.getKey)
      Assert.assertEquals(expected, entry.getValue)
      actualSecondState += (entry.getKey -> entry.getValue)
    }
Assert.assertEquals(secondExpectedBroadcastState, actualSecondState)
out.collect(value)
}
@throws[Exception]
override def processBroadcastElement(
value: (Long, Long),
ctx: KeyedBroadcastProcessFunction
[Long, (Long, Long), (Long, Long), (Long, Long)]#Context,
out: Collector[(Long, Long)]): Unit = {
// do nothing
}
}