blob: 7eea30128f121d7de25578ee054ccb1cdaba30bb [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.ml.api;
import org.apache.flink.api.common.state.BroadcastState;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.junit.Assert;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
/** Defines a few Stage subclasses to be used in unit tests. */
public class ExampleStages {
/**
 * A Model subclass that increments every value in the input stream by `delta` and outputs the
 * resulting values.
 *
 * <p>The delta is provided as model data through {@link #setModelData(Table...)}; the model data
 * must be set before {@link #transform(Table...)} or {@link #save(String)} is called.
 */
public static class SumModel implements Model<SumModel> {
    private final Map<Param<?>, Object> paramMap = new HashMap<>();
    // Stream view of the model data table; null until setModelData() has been called.
    private DataStream<Integer> modelData;
    // The table most recently passed to setModelData(); null until then.
    private Table modelDataTable;

    // This empty constructor is necessary in order for SumModel to be loaded by
    // ReadWriteUtils.createStageWithParam
    public SumModel() {
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

    @Override
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }

    /**
     * Adds the model's delta to every integer of the single input table.
     *
     * @param inputs exactly one table whose rows are converted to {@code Integer} values
     * @return a one-element array holding the table of incremented values
     * @throws IllegalStateException if the model data has not been set
     */
    @Override
    public Table[] transform(Table... inputs) {
        Assert.assertEquals(1, inputs.length);
        requireModelData();
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
        DataStream<Integer> input = tEnv.toDataStream(inputs[0], Integer.class);
        // The model data is broadcast so that it is sent to every parallel instance of the
        // operator that applies the delta.
        DataStream<Integer> output =
                input.connect(modelData.broadcast())
                        .transform(
                                "ApplyDeltaOperator",
                                BasicTypeInfo.INT_TYPE_INFO,
                                new ApplyDeltaOperator());
        return new Table[] {tEnv.fromDataStream(output)};
    }

    /**
     * Sets the model data of this model. Only {@code inputs[0]} is used.
     *
     * @param inputs the model data tables; the first one carries the delta as integers
     * @return this model
     */
    @Override
    public SumModel setModelData(Table... inputs) {
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
        modelData = tEnv.toDataStream(inputs[0], Integer.class);
        modelDataTable = inputs[0];
        return this;
    }

    /**
     * Returns a one-element array holding the table last passed to {@link
     * #setModelData(Table...)}; the element is null if the model data was never set.
     */
    @Override
    public Table[] getModelData() {
        return new Table[] {modelDataTable};
    }

    /**
     * Saves the model data and the metadata of this model under the given path.
     *
     * @throws IllegalStateException if the model data has not been set
     */
    @Override
    public void save(String path) throws IOException {
        requireModelData();
        ReadWriteUtils.saveModelData(modelData, path, new TestUtils.IntEncoder());
        ReadWriteUtils.saveMetadata(this, path);
    }

    /** Loads a SumModel, including its model data, from the given path. */
    public static SumModel load(StreamTableEnvironment tEnv, String path) throws IOException {
        Table modelDataTable =
                ReadWriteUtils.loadModelData(tEnv, path, new TestUtils.IntegerStreamFormat());
        SumModel model = ReadWriteUtils.loadStageParam(path);
        return model.setModelData(modelDataTable);
    }

    // Fails fast with a descriptive message instead of a bare NullPointerException when the
    // model is used before its model data has been provided.
    private void requireModelData() {
        if (modelData == null) {
            throw new IllegalStateException(
                    "Model data is not set; call setModelData() before using this SumModel.");
        }
    }
}
/**
 * Two-input operator that adds the delta received on the 2nd input to every element of the 1st
 * input and emits the resulting values.
 *
 * <p>Elements of the 1st input that arrive before the delta are parked in list state and are
 * emitted as soon as the delta shows up.
 */
private static class ApplyDeltaOperator extends AbstractStreamOperator<Integer>
        implements TwoInputStreamOperator<Integer, Integer, Integer> {
    // Key under which the single delta value is kept in the broadcast state.
    private static final String DELTA_KEY = "delta";

    private ListState<Integer> unProcessedValues;
    private BroadcastState<String, Integer> broadcastState;

    /** Registers the list state for parked elements and the broadcast state for the delta. */
    @Override
    public void initializeState(StateInitializationContext context) throws Exception {
        ListStateDescriptor<Integer> pendingDescriptor =
                new ListStateDescriptor<>("unProcessedValues", Integer.class);
        unProcessedValues = context.getOperatorStateStore().getListState(pendingDescriptor);

        MapStateDescriptor<String, Integer> deltaDescriptor =
                new MapStateDescriptor<>("broadcastState", String.class, Integer.class);
        broadcastState = context.getOperatorStateStore().getBroadcastState(deltaDescriptor);
    }

    /** Emits value + delta if the delta is already known, otherwise parks the value. */
    @Override
    public void processElement1(StreamRecord<Integer> record) throws Exception {
        Integer delta = broadcastState.get(DELTA_KEY);
        if (delta != null) {
            output.collect(new StreamRecord<>(record.getValue() + delta));
            return;
        }
        unProcessedValues.add(record.getValue());
    }

    /**
     * Stores the delta and flushes every parked value with the delta applied.
     *
     * @throws IllegalStateException if a delta has already been received
     */
    @Override
    public void processElement2(StreamRecord<Integer> record) throws Exception {
        Integer existingDelta = broadcastState.get(DELTA_KEY);
        if (existingDelta != null) {
            throw new IllegalStateException("Model data should have exactly one value");
        }

        Integer delta = record.getValue();
        broadcastState.put(DELTA_KEY, delta);
        for (Integer pending : unProcessedValues.get()) {
            output.collect(new StreamRecord<>(pending + delta));
        }
        unProcessedValues.clear();
    }
}
/**
 * An Estimator subclass which calculates the sum of input values and instantiates a SumModel
 * instance with delta=sum(inputs).
 */
public static class SumEstimator implements Estimator<SumEstimator, SumModel> {
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    public SumEstimator() {
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

    @Override
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }

    /**
     * Builds a SumModel whose model data is the sum of all integers in the single input table.
     *
     * @param inputs exactly one table whose rows are converted to {@code Integer} values
     * @return a new SumModel with its model data set to the stream produced by SumOperator
     */
    @Override
    public SumModel fit(Table... inputs) {
        Assert.assertEquals(1, inputs.length);
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
        DataStream<Integer> input = tEnv.toDataStream(inputs[0], Integer.class);
        // The summing operator is pinned to parallelism 1; presumably so that one instance
        // sees all values and emits a single sum, which is what ApplyDeltaOperator expects.
        DataStream<Integer> modelData =
                input.transform("SumOperator", BasicTypeInfo.INT_TYPE_INFO, new SumOperator())
                        .setParallelism(1);
        // Neither the SumModel constructor nor setModelData declares a checked exception, so
        // failures are propagated as they are instead of being re-wrapped.
        return new SumModel().setModelData(tEnv.fromDataStream(modelData));
    }

    /** Saves the metadata of this estimator under the given path. */
    @Override
    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata(this, path);
    }

    /** Loads a SumEstimator from the given path; {@code tEnv} is not used. */
    public static SumEstimator load(StreamTableEnvironment tEnv, String path)
            throws IOException {
        return ReadWriteUtils.loadStageParam(path);
    }
}
/**
 * One-input operator that accumulates the sum of all received integers and emits that sum as a
 * single record once the input has ended.
 */
private static class SumOperator extends AbstractStreamOperator<Integer>
        implements OneInputStreamOperator<Integer, Integer>, BoundedOneInput {
    // Running total of all values seen so far. NOTE(review): this is a plain field rather than
    // Flink managed state, so it is presumably not restored after a failure - confirm this is
    // acceptable for the tests using it.
    private int sum = 0;

    /** Emits the accumulated sum when the input ends. */
    @Override
    public void endInput() throws Exception {
        output.collect(new StreamRecord<>(sum));
    }

    /** Adds the value of the incoming record to the running total; emits nothing. */
    @Override
    public void processElement(StreamRecord<Integer> input) throws Exception {
        sum += input.getValue();
    }
}
/**
 * A Transformer subclass that accepts exactly two input tables and produces a single output
 * table containing the union of the integers of both inputs.
 */
public static class UnionAlgoOperator implements Transformer<UnionAlgoOperator> {
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    public UnionAlgoOperator() {
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

    @Override
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }

    /**
     * Unions the two input tables.
     *
     * @param inputs exactly two tables whose rows are converted to {@code Integer} values
     * @return a one-element array holding the table built from the unioned stream
     */
    @Override
    public Table[] transform(Table... inputs) {
        Assert.assertEquals(2, inputs.length);
        Table first = inputs[0];
        Table second = inputs[1];
        // The table environment is taken from the first input.
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) first).getTableEnvironment();
        DataStream<Integer> merged =
                tEnv.toDataStream(first, Integer.class)
                        .union(tEnv.toDataStream(second, Integer.class));
        Table result = tEnv.fromDataStream(merged);
        return new Table[] {result};
    }

    /** Saves the metadata of this stage under the given path. */
    @Override
    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata(this, path);
    }

    /** Loads a UnionAlgoOperator from the given path; {@code tEnv} is not used. */
    public static UnionAlgoOperator load(StreamTableEnvironment tEnv, String path)
            throws IOException {
        UnionAlgoOperator stage = ReadWriteUtils.loadStageParam(path);
        return stage;
    }
}
}