blob: 6cad1b1071fe162e0e1ec23cc15c4631a7984416 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.ml.api;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.ExampleStages.SumEstimator;
import org.apache.flink.ml.api.ExampleStages.SumModel;
import org.apache.flink.ml.api.ExampleStages.UnionAlgoOperator;
import org.apache.flink.ml.builder.Graph;
import org.apache.flink.ml.builder.GraphBuilder;
import org.apache.flink.ml.builder.GraphModel;
import org.apache.flink.ml.builder.TableId;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.test.util.AbstractTestBase;
import org.junit.Before;
import org.junit.Test;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
/** Tests the behavior of {@link Graph} and {@link GraphModel}. */
public class GraphTest extends AbstractTestBase {
private StreamExecutionEnvironment env;
private StreamTableEnvironment tEnv;
@Before
public void before() {
Configuration config = new Configuration();
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
env = StreamExecutionEnvironment.getExecutionEnvironment(config);
env.setParallelism(4);
env.enableCheckpointing(100);
env.setRestartStrategy(RestartStrategies.noRestart());
tEnv = StreamTableEnvironment.create(env);
}
// Executes the given stage using the given inputs and verifies that it produces the expected
// output. Then repeats this procedure after saving and loading the given stage.
private static void executeSaveLoadAndCheckOutput(
StreamTableEnvironment tEnv,
Stage<?> stage,
List<List<Integer>> inputs,
List<Integer> expectedOutput,
List<List<Integer>> modelDataInputs,
List<Integer> expectedModelDataOutput,
boolean modelDataExists)
throws Exception {
StreamExecutionEnvironment env = TableUtils.getExecutionEnvironment(tEnv);
// Executes the given stage and verifies that it produces the expected output.
TestUtils.executeAndCheckOutput(
env, stage, inputs, expectedOutput, modelDataInputs, expectedModelDataOutput);
// Saves and loads the given stage.
String path = Files.createTempDirectory("").toString();
stage.save(path);
if (modelDataExists) {
env.execute();
}
Stage<?> loadedStage = null;
if (stage instanceof Estimator) {
loadedStage = Graph.load(tEnv, path);
} else {
loadedStage = GraphModel.load(tEnv, path);
}
// Executes the loaded stage and verifies that it produces the expected output.
TestUtils.executeAndCheckOutput(
env, loadedStage, inputs, expectedOutput, modelDataInputs, expectedModelDataOutput);
}
@Test
public void testGraphModelWithoutEstimator() throws Exception {
GraphBuilder builder = new GraphBuilder();
// Creates nodes.
SumModel stage1 = new SumModel().setModelData(tEnv.fromValues(2));
SumModel stage2 = new SumModel().setModelData(tEnv.fromValues(1));
AlgoOperator<?> stage3 = new UnionAlgoOperator();
// Creates inputs.
TableId input1 = builder.createTableId();
TableId input2 = builder.createTableId();
// Feeds inputs and gets outputs.
TableId output1 = builder.addAlgoOperator(stage1, input1)[0];
TableId output2 = builder.addAlgoOperator(stage2, input2)[0];
TableId output3 = builder.addAlgoOperator(stage3, output1, output2)[0];
// Builds a Model from the graph.
TableId[] inputs = new TableId[] {input1, input2};
TableId[] outputs = new TableId[] {output3};
Model<?> model = builder.buildModel(inputs, outputs);
// Executes the GraphModel and verifies that it produces the expected output.
List<List<Integer>> inputValues = new ArrayList<>();
inputValues.add(Arrays.asList(1, 2, 3));
inputValues.add(Arrays.asList(10, 11, 12));
List<Integer> expectedOutputValues = Arrays.asList(3, 4, 5, 11, 12, 13);
executeSaveLoadAndCheckOutput(
tEnv, model, inputValues, expectedOutputValues, null, null, true);
}
@Test
public void testGraphModelWithEstimator() throws Exception {
GraphBuilder builder = new GraphBuilder();
// Creates nodes.
Estimator<?, ?> stage1 = new SumEstimator();
Estimator<?, ?> stage2 = new SumEstimator();
AlgoOperator<?> stage3 = new UnionAlgoOperator();
// Creates inputs.
TableId input1 = builder.createTableId();
TableId input2 = builder.createTableId();
// Feeds inputs and gets outputs.
TableId output1 = builder.addEstimator(stage1, input1)[0];
TableId output2 = builder.addEstimator(stage2, input2)[0];
TableId output3 = builder.addAlgoOperator(stage3, output1, output2)[0];
// Builds a Model from the graph.
TableId[] inputs = new TableId[] {input1, input2};
TableId[] outputs = new TableId[] {output3};
Model<?> model = builder.buildModel(inputs, outputs);
// Executes the GraphModel and verifies that it produces the expected output.
List<List<Integer>> inputValues = new ArrayList<>();
inputValues.add(Arrays.asList(1, 2, 3));
inputValues.add(Arrays.asList(10, 11, 12));
List<Integer> expectedOutputValues = Arrays.asList(7, 8, 9, 43, 44, 45);
executeSaveLoadAndCheckOutput(
tEnv, model, inputValues, expectedOutputValues, null, null, false);
}
@Test
public void testGraphModelWithSetGetModelData() throws Exception {
GraphBuilder builder = new GraphBuilder();
// Creates nodes.
SumModel stage1 = new SumModel().setModelData(tEnv.fromValues(1));
SumModel stage2 = new SumModel();
SumModel stage3 = new SumModel().setModelData(tEnv.fromValues(3));
// Creates inputs and modelDataInputs.
TableId input = builder.createTableId();
TableId modelDataInput = builder.createTableId();
// Feeds inputs and gets outputs.
TableId output1 = builder.addAlgoOperator(stage1, input)[0];
TableId output2 = builder.addAlgoOperator(stage2, output1)[0];
builder.setModelDataOnModel(stage2, modelDataInput);
TableId output3 = builder.addAlgoOperator(stage3, output2)[0];
TableId modelDataOutput = builder.getModelDataFromModel(stage3)[0];
// Builds a Model from the graph.
TableId[] inputs = new TableId[] {input};
TableId[] outputs = new TableId[] {output3};
TableId[] modelDataInputs = new TableId[] {modelDataInput};
TableId[] modelDataOutputs = new TableId[] {modelDataOutput};
Model<?> model = builder.buildModel(inputs, outputs, modelDataInputs, modelDataOutputs);
// Executes the GraphModel and verifies that it produces the expected output.
List<List<Integer>> inputValues = Collections.singletonList(Arrays.asList(1, 2, 3));
List<Integer> expectedOutputValues = Arrays.asList(7, 8, 9);
List<List<Integer>> inputModelDataValues =
Collections.singletonList(Collections.singletonList(2));
List<Integer> expectedModelDataOutputValues = Collections.singletonList(3);
executeSaveLoadAndCheckOutput(
tEnv,
model,
inputValues,
expectedOutputValues,
inputModelDataValues,
expectedModelDataOutputValues,
true);
}
@Test
public void testGraphWithEstimator() throws Exception {
GraphBuilder builder = new GraphBuilder();
// Creates nodes.
Estimator<?, ?> stage1 = new SumEstimator();
Estimator<?, ?> stage2 = new SumEstimator();
AlgoOperator<?> stage3 = new UnionAlgoOperator();
// Creates inputs.
TableId input1 = builder.createTableId();
TableId input2 = builder.createTableId();
// Feeds inputs and gets outputs.
TableId output1 = builder.addEstimator(stage1, input1)[0];
TableId output2 = builder.addEstimator(stage2, input2)[0];
TableId output3 = builder.addAlgoOperator(stage3, output1, output2)[0];
// Builds an Estimator from the graph.
TableId[] inputs = new TableId[] {input1, input2};
TableId[] outputs = new TableId[] {output3};
Estimator<?, ?> estimator = builder.buildEstimator(inputs, outputs);
// Executes the Graph and verifies that it produces the expected output.
List<List<Integer>> inputValues = new ArrayList<>();
inputValues.add(Arrays.asList(1, 2, 3));
inputValues.add(Arrays.asList(10, 11, 12));
List<Integer> expectedOutputValues = Arrays.asList(7, 8, 9, 43, 44, 45);
executeSaveLoadAndCheckOutput(
tEnv, estimator, inputValues, expectedOutputValues, null, null, false);
}
@Test
public void testGraphWithSetGetModelData() throws Exception {
GraphBuilder builder = new GraphBuilder();
// Creates nodes.
Estimator<?, ?> stage1 = new SumEstimator();
SumModel stage2 = new SumModel();
AlgoOperator<?> stage3 = new UnionAlgoOperator();
// Creates inputs.
TableId input1 = builder.createTableId();
TableId input2 = builder.createTableId();
// Feeds inputs and gets outputs.
TableId output1 = builder.addEstimator(stage1, input1)[0];
TableId modelDataOutput = builder.getModelDataFromEstimator(stage1)[0];
TableId output2 = builder.addAlgoOperator(stage2, input2)[0];
builder.setModelDataOnModel(stage2, modelDataOutput);
TableId output3 = builder.addAlgoOperator(stage3, output1, output2)[0];
// Builds an Estimator from the graph.
TableId[] inputs = new TableId[] {input1, input2};
TableId[] outputs = new TableId[] {output3};
TableId[] modelDataOutputs = new TableId[] {modelDataOutput};
Estimator<?, ?> estimator = builder.buildEstimator(inputs, outputs, null, modelDataOutputs);
// Executes the Graph and verifies that it produces the expected output.
List<List<Integer>> inputValues = new ArrayList<>();
inputValues.add(Arrays.asList(1, 2, 3));
inputValues.add(Arrays.asList(10, 11, 12));
List<Integer> expectedOutputValues = Arrays.asList(7, 8, 9, 16, 17, 18);
List<Integer> expectedModelDataOutputValues = Collections.singletonList(6);
executeSaveLoadAndCheckOutput(
tEnv,
estimator,
inputValues,
expectedOutputValues,
null,
expectedModelDataOutputValues,
true);
}
}