| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.flink.ml.feature; |
| |
| import org.apache.flink.api.common.restartstrategy.RestartStrategies; |
| import org.apache.flink.api.java.tuple.Tuple2; |
| import org.apache.flink.configuration.Configuration; |
| import org.apache.flink.ml.common.param.HasHandleInvalid; |
| import org.apache.flink.ml.feature.onehotencoder.OneHotEncoder; |
| import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModel; |
| import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModelData; |
| import org.apache.flink.ml.linalg.Vector; |
| import org.apache.flink.ml.linalg.Vectors; |
| import org.apache.flink.ml.util.ReadWriteUtils; |
| import org.apache.flink.ml.util.TestUtils; |
| import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; |
| import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; |
| import org.apache.flink.table.api.Table; |
| import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; |
| import org.apache.flink.types.Row; |
| import org.apache.flink.util.CloseableIterator; |
| |
| import org.apache.commons.lang3.exception.ExceptionUtils; |
| import org.junit.Assert; |
| import org.junit.Before; |
| import org.junit.Rule; |
| import org.junit.Test; |
| import org.junit.rules.TemporaryFolder; |
| |
| import java.util.Arrays; |
| import java.util.HashMap; |
| import java.util.List; |
| import java.util.Map; |
| |
| import static org.junit.Assert.assertArrayEquals; |
| import static org.junit.Assert.assertEquals; |
| import static org.junit.Assert.assertFalse; |
| import static org.junit.Assert.assertTrue; |
| |
| /** Tests OneHotEncoder and OneHotEncoderModel. */ |
| public class OneHotEncoderTest { |
| @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); |
| |
| private StreamExecutionEnvironment env; |
| private StreamTableEnvironment tEnv; |
| private Table trainTable; |
| private Table predictTable; |
| private Map<Double, Vector>[] expectedOutput; |
| private OneHotEncoder estimator; |
| |
| @Before |
| public void before() { |
| Configuration config = new Configuration(); |
| config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); |
| env = StreamExecutionEnvironment.getExecutionEnvironment(config); |
| env.setParallelism(4); |
| env.enableCheckpointing(100); |
| env.setRestartStrategy(RestartStrategies.noRestart()); |
| tEnv = StreamTableEnvironment.create(env); |
| |
| List<Row> trainData = Arrays.asList(Row.of(0.0), Row.of(1.0), Row.of(2.0), Row.of(0.0)); |
| |
| trainTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("input"); |
| |
| List<Row> predictData = Arrays.asList(Row.of(0.0), Row.of(1.0), Row.of(2.0)); |
| |
| predictTable = tEnv.fromDataStream(env.fromCollection(predictData)).as("input"); |
| |
| expectedOutput = |
| new HashMap[] { |
| new HashMap<Double, Vector>() { |
| { |
| put(0.0, Vectors.sparse(2, new int[] {0}, new double[] {1.0})); |
| put(1.0, Vectors.sparse(2, new int[] {1}, new double[] {1.0})); |
| put(2.0, Vectors.sparse(2, new int[0], new double[0])); |
| } |
| } |
| }; |
| |
| estimator = new OneHotEncoder().setInputCols("input").setOutputCols("output"); |
| } |
| |
| /** |
| * Executes a given table and collect its results. Results are returned as a map array. Each |
| * element in the array is a map corresponding to a input column whose key is the original value |
| * in the input column, value is the one-hot encoding result of that value. |
| * |
| * @param table A table to be executed and to have its result collected |
| * @param inputCols Name of the input columns |
| * @param outputCols Name of the output columns containing one-hot encoding result |
| * @return An array of map containing the collected results for each input column |
| */ |
| private static Map<Double, Vector>[] executeAndCollect( |
| Table table, String[] inputCols, String[] outputCols) { |
| Map<Double, Vector>[] maps = new HashMap[inputCols.length]; |
| for (int i = 0; i < inputCols.length; i++) { |
| maps[i] = new HashMap<>(); |
| } |
| for (CloseableIterator<Row> it = table.execute().collect(); it.hasNext(); ) { |
| Row row = it.next(); |
| for (int i = 0; i < inputCols.length; i++) { |
| maps[i].put( |
| ((Number) row.getField(inputCols[i])).doubleValue(), |
| (Vector) row.getField(outputCols[i])); |
| } |
| } |
| return maps; |
| } |
| |
| @Test |
| public void testParam() { |
| OneHotEncoder estimator = new OneHotEncoder(); |
| |
| assertTrue(estimator.getDropLast()); |
| |
| estimator.setInputCols("test_input").setOutputCols("test_output").setDropLast(false); |
| |
| assertArrayEquals(new String[] {"test_input"}, estimator.getInputCols()); |
| assertArrayEquals(new String[] {"test_output"}, estimator.getOutputCols()); |
| assertFalse(estimator.getDropLast()); |
| |
| OneHotEncoderModel model = new OneHotEncoderModel(); |
| |
| assertTrue(model.getDropLast()); |
| |
| model.setInputCols("test_input").setOutputCols("test_output").setDropLast(false); |
| |
| assertArrayEquals(new String[] {"test_input"}, model.getInputCols()); |
| assertArrayEquals(new String[] {"test_output"}, model.getOutputCols()); |
| assertFalse(model.getDropLast()); |
| } |
| |
| @Test |
| public void testFitAndPredict() { |
| OneHotEncoderModel model = estimator.fit(trainTable); |
| Table outputTable = model.transform(predictTable)[0]; |
| Map<Double, Vector>[] actualOutput = |
| executeAndCollect(outputTable, model.getInputCols(), model.getOutputCols()); |
| assertArrayEquals(expectedOutput, actualOutput); |
| } |
| |
| @Test |
| public void testInputTypeConversion() throws Exception { |
| trainTable = TestUtils.convertDataTypesToSparseInt(tEnv, trainTable); |
| predictTable = TestUtils.convertDataTypesToSparseInt(tEnv, predictTable); |
| assertArrayEquals(new Class<?>[] {Integer.class}, TestUtils.getColumnDataTypes(trainTable)); |
| assertArrayEquals( |
| new Class<?>[] {Integer.class}, TestUtils.getColumnDataTypes(predictTable)); |
| |
| OneHotEncoderModel model = estimator.fit(trainTable); |
| Table outputTable = model.transform(predictTable)[0]; |
| Map<Double, Vector>[] actualOutput = |
| executeAndCollect(outputTable, model.getInputCols(), model.getOutputCols()); |
| assertArrayEquals(expectedOutput, actualOutput); |
| } |
| |
| @Test |
| public void testDropLast() { |
| estimator.setDropLast(false); |
| |
| expectedOutput = |
| new HashMap[] { |
| new HashMap<Double, Vector>() { |
| { |
| put(0.0, Vectors.sparse(3, new int[] {0}, new double[] {1.0})); |
| put(1.0, Vectors.sparse(3, new int[] {1}, new double[] {1.0})); |
| put(2.0, Vectors.sparse(3, new int[] {2}, new double[] {1.0})); |
| } |
| } |
| }; |
| |
| OneHotEncoderModel model = estimator.fit(trainTable); |
| Table outputTable = model.transform(predictTable)[0]; |
| Map<Double, Vector>[] actualOutput = |
| executeAndCollect(outputTable, model.getInputCols(), model.getOutputCols()); |
| assertArrayEquals(expectedOutput, actualOutput); |
| } |
| |
| @Test |
| public void testInputDataType() { |
| List<Row> trainData = Arrays.asList(Row.of(0), Row.of(1), Row.of(2), Row.of(0)); |
| |
| trainTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("input"); |
| |
| List<Row> predictData = Arrays.asList(Row.of(0), Row.of(1), Row.of(2)); |
| predictTable = tEnv.fromDataStream(env.fromCollection(predictData)).as("input"); |
| |
| expectedOutput = |
| new HashMap[] { |
| new HashMap<Double, Vector>() { |
| { |
| put(0.0, Vectors.sparse(2, new int[] {0}, new double[] {1.0})); |
| put(1.0, Vectors.sparse(2, new int[] {1}, new double[] {1.0})); |
| put(2.0, Vectors.sparse(2, new int[0], new double[0])); |
| } |
| } |
| }; |
| |
| OneHotEncoderModel model = estimator.fit(trainTable); |
| Table outputTable = model.transform(predictTable)[0]; |
| Map<Double, Vector>[] actualOutput = |
| executeAndCollect(outputTable, model.getInputCols(), model.getOutputCols()); |
| assertArrayEquals(expectedOutput, actualOutput); |
| } |
| |
| @Test |
| public void testNotSupportedHandleInvalidOptions() { |
| estimator.setHandleInvalid(HasHandleInvalid.SKIP_INVALID); |
| try { |
| estimator.fit(trainTable); |
| Assert.fail("Expected IllegalArgumentException"); |
| } catch (Exception e) { |
| assertEquals(IllegalArgumentException.class, ((Throwable) e).getClass()); |
| } |
| } |
| |
| @Test |
| public void testNonIndexedTrainData() { |
| List<Row> trainData = Arrays.asList(Row.of(0.5), Row.of(1.0), Row.of(2.0), Row.of(0.0)); |
| |
| trainTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("input"); |
| OneHotEncoderModel model = estimator.fit(trainTable); |
| Table outputTable = model.transform(predictTable)[0]; |
| try { |
| outputTable.execute().collect().next(); |
| Assert.fail("Expected IllegalArgumentException"); |
| } catch (Exception e) { |
| Throwable exception = ExceptionUtils.getRootCause(e); |
| assertEquals(IllegalArgumentException.class, exception.getClass()); |
| assertEquals("Value 0.5 cannot be parsed as indexed integer.", exception.getMessage()); |
| } |
| } |
| |
| @Test |
| public void testNonIndexedPredictData() { |
| List<Row> predictData = Arrays.asList(Row.of(0.5), Row.of(1.0), Row.of(2.0), Row.of(0.0)); |
| |
| predictTable = tEnv.fromDataStream(env.fromCollection(predictData)).as("input"); |
| OneHotEncoderModel model = estimator.fit(trainTable); |
| Table outputTable = model.transform(predictTable)[0]; |
| try { |
| outputTable.execute().collect().next(); |
| Assert.fail("Expected IllegalArgumentException"); |
| } catch (Exception e) { |
| Throwable exception = e; |
| while (exception.getCause() != null) { |
| exception = exception.getCause(); |
| } |
| assertEquals(IllegalArgumentException.class, exception.getClass()); |
| assertEquals("Value 0.5 cannot be parsed as indexed integer.", exception.getMessage()); |
| } |
| } |
| |
| @Test |
| public void testSaveLoad() throws Exception { |
| estimator = |
| TestUtils.saveAndReload(tEnv, estimator, tempFolder.newFolder().getAbsolutePath()); |
| OneHotEncoderModel model = estimator.fit(trainTable); |
| model = TestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath()); |
| Table outputTable = model.transform(predictTable)[0]; |
| Map<Double, Vector>[] actualOutput = |
| executeAndCollect(outputTable, model.getInputCols(), model.getOutputCols()); |
| assertArrayEquals(expectedOutput, actualOutput); |
| } |
| |
| @Test |
| public void testGetModelData() throws Exception { |
| OneHotEncoderModel model = estimator.fit(trainTable); |
| Tuple2<Integer, Integer> expected = new Tuple2<>(0, 2); |
| Tuple2<Integer, Integer> actual = |
| OneHotEncoderModelData.getModelDataStream(model.getModelData()[0]) |
| .executeAndCollect() |
| .next(); |
| assertEquals(expected, actual); |
| } |
| |
| @Test |
| public void testSetModelData() { |
| OneHotEncoderModel modelA = estimator.fit(trainTable); |
| |
| Table modelData = modelA.getModelData()[0]; |
| OneHotEncoderModel modelB = new OneHotEncoderModel().setModelData(modelData); |
| ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap()); |
| |
| Table outputTable = modelB.transform(predictTable)[0]; |
| Map<Double, Vector>[] actualOutput = |
| executeAndCollect(outputTable, modelB.getInputCols(), modelB.getOutputCols()); |
| assertArrayEquals(expectedOutput, actualOutput); |
| } |
| } |