| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.flink.ml.feature.stringindexer; |
| |
| import org.apache.flink.api.common.restartstrategy.RestartStrategies; |
| import org.apache.flink.configuration.Configuration; |
| import org.apache.flink.ml.common.param.HasHandleInvalid; |
| import org.apache.flink.ml.util.ReadWriteUtils; |
| import org.apache.flink.ml.util.TestUtils; |
| import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; |
| import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; |
| import org.apache.flink.table.api.Table; |
| import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; |
| import org.apache.flink.test.util.AbstractTestBase; |
| import org.apache.flink.types.Row; |
| |
| import org.apache.commons.collections.IteratorUtils; |
| import org.apache.commons.lang3.exception.ExceptionUtils; |
| import org.junit.Before; |
| import org.junit.Rule; |
| import org.junit.Test; |
| import org.junit.rules.TemporaryFolder; |
| |
| import java.util.Arrays; |
| import java.util.Collections; |
| import java.util.HashSet; |
| import java.util.List; |
| import java.util.Set; |
| |
| import static org.junit.Assert.assertArrayEquals; |
| import static org.junit.Assert.assertEquals; |
| import static org.junit.Assert.assertTrue; |
| import static org.junit.Assert.fail; |
| |
| /** Tests the {@link StringIndexer} and {@link StringIndexerModel}. */ |
| public class StringIndexerTest extends AbstractTestBase { |
| @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); |
| private StreamExecutionEnvironment env; |
| private StreamTableEnvironment tEnv; |
| private Table trainTable; |
| private Table predictTable; |
| |
| private final String[][] expectedAlphabeticAscModelData = |
| new String[][] {{"a", "b", "c", "d"}, {"-1.0", "0.0", "1.0", "2.0"}}; |
| private final List<Row> expectedAlphabeticAscPredictData = |
| Arrays.asList( |
| Row.of("a", 2.0, 0.0, 3.0), |
| Row.of("b", 1.0, 1.0, 2.0), |
| Row.of("e", 2.0, 4.0, 3.0)); |
| private final List<Row> expectedAlphabeticDescPredictData = |
| Arrays.asList( |
| Row.of("a", 2.0, 3.0, 0.0), |
| Row.of("b", 1.0, 2.0, 1.0), |
| Row.of("e", 2.0, 4.0, 0.0)); |
| private final List<Row> expectedFreqAscPredictData = |
| Arrays.asList( |
| Row.of("a", 2.0, 2.0, 3.0), |
| Row.of("b", 1.0, 3.0, 1.0), |
| Row.of("e", 2.0, 4.0, 3.0)); |
| private final List<Row> expectedFreqDescPredictData = |
| Arrays.asList( |
| Row.of("a", 2.0, 1.0, 0.0), |
| Row.of("b", 1.0, 0.0, 2.0), |
| Row.of("e", 2.0, 4.0, 0.0)); |
| |
| @Before |
| public void before() { |
| Configuration config = new Configuration(); |
| config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); |
| env = StreamExecutionEnvironment.getExecutionEnvironment(config); |
| env.setParallelism(4); |
| env.enableCheckpointing(100); |
| env.setRestartStrategy(RestartStrategies.noRestart()); |
| tEnv = StreamTableEnvironment.create(env); |
| |
| List<Row> trainData = |
| Arrays.asList( |
| Row.of("a", 1.0), |
| Row.of("b", 1.0), |
| Row.of("b", 2.0), |
| Row.of("c", 0.0), |
| Row.of("d", 2.0), |
| Row.of("a", 2.0), |
| Row.of("b", 2.0), |
| Row.of("b", -1.0), |
| Row.of("a", -1.0), |
| Row.of("c", -1.0)); |
| trainTable = |
| tEnv.fromDataStream(env.fromCollection(trainData)).as("inputCol1", "inputCol2"); |
| |
| List<Row> predictData = Arrays.asList(Row.of("a", 2.0), Row.of("b", 1.0), Row.of("e", 2.0)); |
| predictTable = |
| tEnv.fromDataStream(env.fromCollection(predictData)).as("inputCol1", "inputCol2"); |
| } |
| |
| @Test |
| public void testParam() { |
| StringIndexer stringIndexer = new StringIndexer(); |
| assertEquals(stringIndexer.getStringOrderType(), StringIndexerParams.ARBITRARY_ORDER); |
| assertEquals(stringIndexer.getHandleInvalid(), StringIndexerParams.ERROR_INVALID); |
| |
| stringIndexer |
| .setInputCols("inputCol1", "inputCol2") |
| .setOutputCols("outputCol1", "outputCol2") |
| .setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER) |
| .setHandleInvalid(StringIndexerParams.SKIP_INVALID); |
| |
| assertArrayEquals(new String[] {"inputCol1", "inputCol2"}, stringIndexer.getInputCols()); |
| assertArrayEquals(new String[] {"outputCol1", "outputCol2"}, stringIndexer.getOutputCols()); |
| assertEquals(stringIndexer.getStringOrderType(), StringIndexerParams.ALPHABET_ASC_ORDER); |
| assertEquals(stringIndexer.getHandleInvalid(), StringIndexerParams.SKIP_INVALID); |
| } |
| |
| @Test |
| public void testOutputSchema() { |
| StringIndexer stringIndexer = |
| new StringIndexer() |
| .setInputCols("inputCol1", "inputCol2") |
| .setOutputCols("outputCol1", "outputCol2") |
| .setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER) |
| .setHandleInvalid(StringIndexerParams.SKIP_INVALID); |
| Table output = stringIndexer.fit(trainTable).transform(predictTable)[0]; |
| |
| assertEquals( |
| Arrays.asList("inputCol1", "inputCol2", "outputCol1", "outputCol2"), |
| output.getResolvedSchema().getColumnNames()); |
| } |
| |
| @Test |
| @SuppressWarnings("all") |
| public void testStringOrderType() throws Exception { |
| StringIndexer stringIndexer = |
| new StringIndexer() |
| .setInputCols("inputCol1", "inputCol2") |
| .setOutputCols("outputCol1", "outputCol2") |
| .setHandleInvalid(StringIndexerParams.KEEP_INVALID); |
| Table output; |
| List<Row> predictedResult; |
| |
| // AlphabetAsc order. |
| stringIndexer.setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER); |
| output = stringIndexer.fit(trainTable).transform(predictTable)[0]; |
| predictedResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); |
| verifyPredictionResult(expectedAlphabeticAscPredictData, predictedResult); |
| |
| // AlphabetDesc order. |
| stringIndexer.setStringOrderType(StringIndexerParams.ALPHABET_DESC_ORDER); |
| output = stringIndexer.fit(trainTable).transform(predictTable)[0]; |
| predictedResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); |
| verifyPredictionResult(expectedAlphabeticDescPredictData, predictedResult); |
| |
| // FrequencyAsc order. |
| stringIndexer.setStringOrderType(StringIndexerParams.FREQUENCY_ASC_ORDER); |
| output = stringIndexer.fit(trainTable).transform(predictTable)[0]; |
| predictedResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); |
| verifyPredictionResult(expectedFreqAscPredictData, predictedResult); |
| |
| // FrequencyDesc order. |
| stringIndexer.setStringOrderType(StringIndexerParams.FREQUENCY_DESC_ORDER); |
| output = stringIndexer.fit(trainTable).transform(predictTable)[0]; |
| predictedResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); |
| verifyPredictionResult(expectedFreqDescPredictData, predictedResult); |
| |
| // Arbitrary order. |
| stringIndexer.setStringOrderType(StringIndexerParams.ARBITRARY_ORDER); |
| output = stringIndexer.fit(trainTable).transform(predictTable)[0]; |
| predictedResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); |
| |
| Set<Double> distinctStringsCol1 = new HashSet<>(); |
| Set<Double> distinctStringsCol2 = new HashSet<>(); |
| double index; |
| for (Row r : predictedResult) { |
| index = (Double) r.getField(2); |
| distinctStringsCol1.add(index); |
| assertTrue(index >= 0 && index <= 4); |
| index = (Double) r.getField(3); |
| assertTrue(index >= 0 && index <= 3); |
| distinctStringsCol2.add(index); |
| } |
| |
| assertEquals(3, distinctStringsCol1.size()); |
| assertEquals(2, distinctStringsCol2.size()); |
| } |
| |
| @Test |
| @SuppressWarnings("unchecked") |
| public void testHandleInvalid() throws Exception { |
| StringIndexer stringIndexer = |
| new StringIndexer() |
| .setInputCols("inputCol1", "inputCol2") |
| .setOutputCols("outputCol1", "outputCol2") |
| .setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER); |
| |
| Table output; |
| List<Row> expectedResult; |
| |
| // Keeps invalid data. |
| stringIndexer.setHandleInvalid(StringIndexerParams.KEEP_INVALID); |
| output = stringIndexer.fit(trainTable).transform(predictTable)[0]; |
| List<Row> predictedResult = |
| IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); |
| verifyPredictionResult(expectedAlphabeticAscPredictData, predictedResult); |
| |
| // Skips invalid data. |
| stringIndexer.setHandleInvalid(StringIndexerParams.SKIP_INVALID); |
| output = stringIndexer.fit(trainTable).transform(predictTable)[0]; |
| predictedResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); |
| expectedResult = Arrays.asList(Row.of("a", 2.0, 0.0, 3.0), Row.of("b", 1.0, 1.0, 2.0)); |
| verifyPredictionResult(expectedResult, predictedResult); |
| |
| // Throws an exception on invalid data. |
| stringIndexer.setHandleInvalid(StringIndexerParams.ERROR_INVALID); |
| try { |
| output = stringIndexer.fit(trainTable).transform(predictTable)[0]; |
| IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); |
| fail(); |
| } catch (Throwable e) { |
| assertEquals( |
| "The input contains unseen string: e. " |
| + "See " |
| + HasHandleInvalid.HANDLE_INVALID |
| + " parameter for more options.", |
| ExceptionUtils.getRootCause(e).getMessage()); |
| } |
| } |
| |
| @Test |
| @SuppressWarnings("unchecked") |
| public void testFitAndPredict() throws Exception { |
| StringIndexer stringIndexer = |
| new StringIndexer() |
| .setInputCols("inputCol1", "inputCol2") |
| .setOutputCols("outputCol1", "outputCol2") |
| .setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER) |
| .setHandleInvalid(StringIndexerParams.KEEP_INVALID); |
| Table output = stringIndexer.fit(trainTable).transform(predictTable)[0]; |
| |
| List<Row> predictedResult = |
| IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); |
| verifyPredictionResult(expectedAlphabeticAscPredictData, predictedResult); |
| } |
| |
| @Test |
| @SuppressWarnings("unchecked") |
| public void testSaveLoadAndPredict() throws Exception { |
| StringIndexer stringIndexer = |
| new StringIndexer() |
| .setInputCols("inputCol1", "inputCol2") |
| .setOutputCols("outputCol1", "outputCol2") |
| .setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER) |
| .setHandleInvalid(StringIndexerParams.KEEP_INVALID); |
| stringIndexer = |
| TestUtils.saveAndReload( |
| tEnv, stringIndexer, tempFolder.newFolder().getAbsolutePath()); |
| |
| StringIndexerModel model = stringIndexer.fit(trainTable); |
| model = TestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath()); |
| |
| assertEquals( |
| Collections.singletonList("stringArrays"), |
| model.getModelData()[0].getResolvedSchema().getColumnNames()); |
| |
| Table output = model.transform(predictTable)[0]; |
| List<Row> predictedResult = |
| IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); |
| verifyPredictionResult(expectedAlphabeticAscPredictData, predictedResult); |
| } |
| |
| @Test |
| @SuppressWarnings("unchecked") |
| public void testGetModelData() throws Exception { |
| StringIndexerModel model = |
| new StringIndexer() |
| .setInputCols("inputCol1", "inputCol2") |
| .setOutputCols("outputCol1", "outputCol2") |
| .setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER) |
| .fit(trainTable); |
| Table modelDataTable = model.getModelData()[0]; |
| |
| assertEquals( |
| Collections.singletonList("stringArrays"), |
| modelDataTable.getResolvedSchema().getColumnNames()); |
| |
| List<StringIndexerModelData> collectedModelData = |
| (List<StringIndexerModelData>) |
| (IteratorUtils.toList( |
| StringIndexerModelData.getModelDataStream(modelDataTable) |
| .executeAndCollect())); |
| assertEquals(1, collectedModelData.size()); |
| |
| StringIndexerModelData modelData = collectedModelData.get(0); |
| assertEquals(2, modelData.stringArrays.length); |
| assertArrayEquals(expectedAlphabeticAscModelData[0], modelData.stringArrays[0]); |
| assertArrayEquals(expectedAlphabeticAscModelData[1], modelData.stringArrays[1]); |
| } |
| |
| @Test |
| @SuppressWarnings("unchecked") |
| public void testSetModelData() throws Exception { |
| StringIndexerModel model = |
| new StringIndexer() |
| .setInputCols("inputCol1", "inputCol2") |
| .setOutputCols("outputCol1", "outputCol2") |
| .setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER) |
| .setHandleInvalid(StringIndexerParams.KEEP_INVALID) |
| .fit(trainTable); |
| |
| StringIndexerModel newModel = new StringIndexerModel(); |
| ReadWriteUtils.updateExistingParams(newModel, model.getParamMap()); |
| newModel.setModelData(model.getModelData()); |
| Table output = newModel.transform(predictTable)[0]; |
| |
| List<Row> predictedResult = |
| IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); |
| verifyPredictionResult(expectedAlphabeticAscPredictData, predictedResult); |
| } |
| |
| static void verifyPredictionResult(List<Row> expected, List<Row> result) { |
| compareResultCollections( |
| expected, |
| result, |
| (row1, row2) -> { |
| int arity = Math.min(row1.getArity(), row2.getArity()); |
| for (int i = 0; i < arity; i++) { |
| int cmp = |
| String.valueOf(row1.getField(i)) |
| .compareTo(String.valueOf(row2.getField(i))); |
| if (cmp != 0) { |
| return cmp; |
| } |
| } |
| return 0; |
| }); |
| } |
| } |