blob: 14eaac3f21df57e5e249313a916c5b7e14ea7803 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.streaming.tests.queryablestate;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.state.MapState;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.utils.ParameterTool;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.contrib.streaming.state.RocksDBStateBackend;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.StateBackend;
import org.apache.flink.runtime.state.filesystem.FsStateBackend;
import org.apache.flink.runtime.state.memory.MemoryStateBackend;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
import org.apache.flink.util.Collector;
import org.apache.flink.shaded.guava18.com.google.common.collect.Iterables;
import java.time.Duration;
import java.time.Instant;
import java.util.Random;
/**
 * Streaming application that creates an {@link Email} pojo with random ids and increasing
 * timestamps and passes it to a stateful {@link org.apache.flink.api.common.functions.FlatMapFunction},
 * where it is exposed as queryable state.
 *
 * <p>Required program arguments:
 * <ul>
 *     <li>{@code --tmp-dir}: checkpoint path handed to the {@code rocksdb} and {@code fs} backends
 *     (not used by the {@code memory} backend).</li>
 *     <li>{@code --state-backend}: one of {@code rocksdb}, {@code fs} or {@code memory}.</li>
 * </ul>
 */
public class QsStateProducer {

	/**
	 * Builds and executes the producer job.
	 *
	 * @param args program arguments, see the class documentation
	 * @throws IllegalArgumentException if {@code --state-backend} names an unsupported backend
	 * @throws Exception if the job cannot be built or fails during execution
	 */
	public static void main(final String[] args) throws Exception {
		final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

		final ParameterTool tool = ParameterTool.fromArgs(args);
		final String tmpPath = tool.getRequired("tmp-dir");
		final String stateBackendType = tool.getRequired("state-backend");

		final StateBackend stateBackend;
		switch (stateBackendType) {
			case "rocksdb":
				stateBackend = new RocksDBStateBackend(tmpPath);
				break;
			case "fs":
				stateBackend = new FsStateBackend(tmpPath);
				break;
			case "memory":
				stateBackend = new MemoryStateBackend();
				break;
			default:
				// A bad argument value; IllegalArgumentException is still a RuntimeException,
				// so any existing handler for the previous exception type continues to match.
				throw new IllegalArgumentException("Unsupported state backend " + stateBackendType);
		}

		env.setStateBackend(stateBackend);
		env.enableCheckpointing(1000L);
		env.getCheckpointConfig().setMaxConcurrentCheckpoints(1);
		env.getCheckpointConfig().setMinPauseBetweenCheckpoints(0);

		env.addSource(new EmailSource())
			.keyBy(new KeySelector<Email, String>() {

				private static final long serialVersionUID = -1480525724620425363L;

				// Every record maps to the same key, so all emails end up in a single MapState.
				@Override
				public String getKey(Email value) throws Exception {
					return QsConstants.KEY;
				}
			})
			.flatMap(new TestFlatMap());

		env.execute();
	}

	/**
	 * Source that, after an initial 10 second pause, emits an {@link Email} with a random id
	 * roughly every 30 ms until cancelled.
	 */
	private static class EmailSource extends RichSourceFunction<Email> {

		private static final long serialVersionUID = -7286937645300388040L;

		private transient volatile boolean isRunning;

		private transient Random random;

		@Override
		public void open(Configuration parameters) throws Exception {
			super.open(parameters);
			this.random = new Random();
			this.isRunning = true;
		}

		@Override
		public void run(SourceContext<Email> ctx) throws Exception {
			// Sleep for 10 seconds on start to allow time to copy jobid
			Thread.sleep(10000L);

			// values() returns a fresh array copy on every call, so fetch it once instead of per record.
			final LabelSurrogate.Type[] labelTypes = LabelSurrogate.Type.values();

			while (isRunning) {
				final int r = random.nextInt(100);

				final EmailId emailId = new EmailId(Integer.toString(random.nextInt()));
				final Instant timestamp = Instant.now().minus(Duration.ofDays(1L));
				final String foo = String.format("foo #%d", r);
				final LabelSurrogate label = new LabelSurrogate(labelTypes[r % labelTypes.length], "bar");

				// Hold the checkpoint lock while emitting, as the SourceFunction contract requires,
				// so that record emission and checkpoints do not interleave.
				synchronized (ctx.getCheckpointLock()) {
					ctx.collect(new Email(emailId, timestamp, foo, label));
				}

				Thread.sleep(30L);
			}
		}

		@Override
		public void cancel() {
			isRunning = false;
		}
	}

	/**
	 * Stores every incoming {@link Email} in a queryable {@link MapState}, keyed by its {@link EmailId},
	 * and prints the number of stored entries whenever a checkpoint is taken.
	 */
	private static class TestFlatMap extends RichFlatMapFunction<Email, Object> implements CheckpointedFunction {

		private static final long serialVersionUID = 7821128115999005941L;

		private transient MapState<EmailId, EmailInformation> state;

		/**
		 * Number of entries in {@link #state}, or {@code -1} before the first record after a (re)start.
		 * Maintained incrementally; this is only accurate because the key selector in
		 * {@link QsStateProducer#main(String[])} routes every record to one single key.
		 */
		private transient int count;

		@Override
		public void open(Configuration parameters) {
			MapStateDescriptor<EmailId, EmailInformation> stateDescriptor =
				new MapStateDescriptor<>(
					QsConstants.STATE_NAME,
					TypeInformation.of(new TypeHint<EmailId>() {
					}),
					TypeInformation.of(new TypeHint<EmailInformation>() {
					})
				);
			stateDescriptor.setQueryable(QsConstants.QUERY_NAME);
			state = getRuntimeContext().getMapState(stateDescriptor);
			count = -1;
		}

		@Override
		public void flatMap(Email value, Collector<Object> out) throws Exception {
			final EmailId emailId = value.getEmailId();
			final boolean isNewId = !state.contains(emailId);
			state.put(emailId, new EmailInformation(value));

			if (count < 0) {
				// First record since (re)start: seed the counter from the full map, including any
				// restored entries. Doing the full scan only once avoids an O(n) pass per record.
				count = Iterables.size(state.keys());
			} else if (isNewId) {
				count++;
			}
		}

		@Override
		public void snapshotState(FunctionSnapshotContext context) {
			System.out.println("Count on snapshot: " + count); // we look for it in the test
		}

		@Override
		public void initializeState(FunctionInitializationContext context) {
			// Nothing to initialize: this function only needs the snapshotState() callback, and its
			// keyed MapState is obtained in open().
		}
	}
}