| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.flink.ml.feature; |
| |
| import org.apache.flink.api.common.restartstrategy.RestartStrategies; |
| import org.apache.flink.configuration.Configuration; |
| import org.apache.flink.ml.feature.vectorslicer.VectorSlicer; |
| import org.apache.flink.ml.linalg.DenseVector; |
| import org.apache.flink.ml.linalg.SparseVector; |
| import org.apache.flink.ml.linalg.Vectors; |
| import org.apache.flink.ml.util.TestUtils; |
| import org.apache.flink.streaming.api.datastream.DataStream; |
| import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; |
| import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; |
| import org.apache.flink.table.api.Table; |
| import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; |
| import org.apache.flink.test.util.AbstractTestBase; |
| import org.apache.flink.types.Row; |
| |
| import org.apache.commons.collections.IteratorUtils; |
| import org.apache.commons.lang3.exception.ExceptionUtils; |
| import org.junit.Before; |
| import org.junit.Test; |
| |
| import java.util.Arrays; |
| import java.util.List; |
| |
| import static org.junit.Assert.assertArrayEquals; |
| import static org.junit.Assert.assertEquals; |
| import static org.junit.Assert.fail; |
| |
| /** Tests {@link VectorSlicer}. */ |
| public class VectorSlicerTest extends AbstractTestBase { |
| |
| private StreamTableEnvironment tEnv; |
| private Table inputDataTable; |
| |
| private static final List<Row> INPUT_DATA = |
| Arrays.asList( |
| Row.of( |
| 0, |
| Vectors.dense(2.1, 3.1, 2.3, 3.4, 5.3, 5.1), |
| Vectors.sparse(5, new int[] {1, 3, 4}, new double[] {0.1, 0.2, 0.3})), |
| Row.of( |
| 1, |
| Vectors.dense(2.3, 4.1, 1.3, 2.4, 5.1, 4.1), |
| Vectors.sparse(5, new int[] {1, 2, 4}, new double[] {0.1, 0.2, 0.3}))); |
| |
| private static final DenseVector EXPECTED_OUTPUT_DATA_1 = Vectors.dense(2.1, 3.1, 2.3); |
| private static final DenseVector EXPECTED_OUTPUT_DATA_2 = Vectors.dense(2.3, 4.1, 1.3); |
| |
| private static final SparseVector EXPECTED_OUTPUT_DATA_3 = |
| Vectors.sparse(3, new int[] {1}, new double[] {0.1}); |
| private static final SparseVector EXPECTED_OUTPUT_DATA_4 = |
| Vectors.sparse(3, new int[] {1, 2}, new double[] {0.1, 0.2}); |
| |
| @Before |
| public void before() { |
| Configuration config = new Configuration(); |
| config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); |
| StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config); |
| env.setParallelism(4); |
| env.enableCheckpointing(100); |
| env.setRestartStrategy(RestartStrategies.noRestart()); |
| tEnv = StreamTableEnvironment.create(env); |
| DataStream<Row> dataStream = env.fromCollection(INPUT_DATA); |
| inputDataTable = tEnv.fromDataStream(dataStream).as("id", "vec", "sparseVec"); |
| } |
| |
| private void verifyOutputResult(Table output, String outputCol, boolean isSparse) |
| throws Exception { |
| DataStream<Row> dataStream = tEnv.toDataStream(output); |
| List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect()); |
| assertEquals(2, results.size()); |
| for (Row result : results) { |
| if (result.getField(0) == (Object) 0) { |
| if (isSparse) { |
| assertEquals(EXPECTED_OUTPUT_DATA_3, result.getField(outputCol)); |
| } else { |
| assertEquals(EXPECTED_OUTPUT_DATA_1, result.getField(outputCol)); |
| } |
| } else if (result.getField(0) == (Object) 1) { |
| if (isSparse) { |
| assertEquals(EXPECTED_OUTPUT_DATA_4, result.getField(outputCol)); |
| } else { |
| assertEquals(EXPECTED_OUTPUT_DATA_2, result.getField(outputCol)); |
| } |
| } else { |
| throw new RuntimeException("Result id value is error, it must be 0 or 1."); |
| } |
| } |
| } |
| |
| @Test |
| public void testParam() { |
| VectorSlicer vectorSlicer = new VectorSlicer(); |
| assertEquals("input", vectorSlicer.getInputCol()); |
| assertEquals("output", vectorSlicer.getOutputCol()); |
| vectorSlicer.setInputCol("vec").setOutputCol("sliceVec").setIndices(0, 1, 2); |
| assertEquals("vec", vectorSlicer.getInputCol()); |
| assertEquals("sliceVec", vectorSlicer.getOutputCol()); |
| assertArrayEquals(new Integer[] {0, 1, 2}, vectorSlicer.getIndices()); |
| } |
| |
| @Test |
| public void testSaveLoadAndTransform() throws Exception { |
| VectorSlicer vectorSlicer = |
| new VectorSlicer().setInputCol("vec").setOutputCol("sliceVec").setIndices(0, 1, 2); |
| VectorSlicer loadedVectorSlicer = |
| TestUtils.saveAndReload( |
| tEnv, vectorSlicer, TEMPORARY_FOLDER.newFolder().getAbsolutePath()); |
| Table output = loadedVectorSlicer.transform(inputDataTable)[0]; |
| verifyOutputResult(output, loadedVectorSlicer.getOutputCol(), false); |
| } |
| |
| @Test |
| public void testEmptyIndices() { |
| try { |
| VectorSlicer vectorSlicer = |
| new VectorSlicer().setInputCol("vec").setOutputCol("sliceVec").setIndices(); |
| vectorSlicer.transform(inputDataTable); |
| fail(); |
| } catch (Exception e) { |
| assertEquals("Parameter indices is given an invalid value {}", e.getMessage()); |
| } |
| } |
| |
| @Test |
| public void testIndicesLargerThanVectorSize() { |
| try { |
| VectorSlicer vectorSlicer = |
| new VectorSlicer() |
| .setInputCol("vec") |
| .setOutputCol("sliceVec") |
| .setIndices(1, 2, 10); |
| Table output = vectorSlicer.transform(inputDataTable)[0]; |
| DataStream<Row> dataStream = tEnv.toDataStream(output); |
| IteratorUtils.toList(dataStream.executeAndCollect()); |
| fail(); |
| } catch (Exception e) { |
| assertEquals( |
| "Index value 10 is greater than vector size:6", |
| ExceptionUtils.getRootCause(e).getMessage()); |
| } |
| } |
| |
| @Test |
| public void testIndicesSmallerThanZero() { |
| try { |
| new VectorSlicer().setInputCol("vec").setOutputCol("sliceVec").setIndices(1, -2); |
| fail(); |
| } catch (Exception e) { |
| assertEquals("Parameter indices is given an invalid value {1,-2}", e.getMessage()); |
| } |
| } |
| |
| @Test |
| public void testDuplicateIndices() { |
| try { |
| new VectorSlicer().setInputCol("vec").setOutputCol("sliceVec").setIndices(1, 1, 3); |
| fail(); |
| } catch (Exception e) { |
| assertEquals("Parameter indices is given an invalid value {1,1,3}", e.getMessage()); |
| } |
| } |
| |
| @Test |
| public void testDenseTransform() throws Exception { |
| VectorSlicer vectorSlicer = |
| new VectorSlicer().setInputCol("vec").setOutputCol("sliceVec").setIndices(0, 1, 2); |
| |
| Table output = vectorSlicer.transform(inputDataTable)[0]; |
| verifyOutputResult(output, vectorSlicer.getOutputCol(), false); |
| } |
| |
| @Test |
| public void testDenseTransformWithUnorderedIndices() throws Exception { |
| VectorSlicer vectorSlicer = |
| new VectorSlicer().setInputCol("vec").setOutputCol("sliceVec").setIndices(0, 2, 1); |
| |
| Table output = vectorSlicer.transform(inputDataTable)[0]; |
| DataStream<Row> dataStream = tEnv.toDataStream(output); |
| List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect()); |
| assertEquals(2, results.size()); |
| for (Row result : results) { |
| if (result.getField(0) == (Object) 0) { |
| assertEquals( |
| Vectors.dense(2.1, 2.3, 3.1), result.getField(vectorSlicer.getOutputCol())); |
| |
| } else if (result.getField(0) == (Object) 1) { |
| assertEquals( |
| Vectors.dense(2.3, 1.3, 4.1), result.getField(vectorSlicer.getOutputCol())); |
| } else { |
| throw new RuntimeException("Result id value is error, it must be 0 or 1."); |
| } |
| } |
| } |
| |
| @Test |
| public void testSparseTransform() throws Exception { |
| VectorSlicer vectorSlicer = |
| new VectorSlicer() |
| .setInputCol("sparseVec") |
| .setOutputCol("sliceVec") |
| .setIndices(0, 1, 2); |
| Table output = vectorSlicer.transform(inputDataTable)[0]; |
| verifyOutputResult(output, vectorSlicer.getOutputCol(), true); |
| } |
| } |