// blob: 48a2664577b784ab67606f761a36fb00bf2fd63e [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.ml.feature;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.feature.standardscaler.StandardScaler;
import org.apache.flink.ml.feature.standardscaler.StandardScalerModel;
import org.apache.flink.ml.feature.standardscaler.StandardScalerModelData;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.types.Row;
import org.apache.commons.collections.IteratorUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
/**
 * Tests {@link StandardScaler} and {@link StandardScalerModel}.
 *
 * <p>All expected values are derived from the three rows of {@code denseInput}: the expected mean
 * is the column-wise average and the expected standard deviation is the column-wise sample
 * (n - 1) standard deviation of those rows.
 */
public class StandardScalerTest extends AbstractTestBase {
    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();

    private StreamExecutionEnvironment env;
    private StreamTableEnvironment tEnv;

    /** Table view of {@link #denseInput} with a single column named "input". */
    private Table denseTable;

    /** Training and prediction data shared by most tests. */
    private final List<Row> denseInput =
            Arrays.asList(
                    Row.of(Vectors.dense(-2.5, 9, 1)),
                    Row.of(Vectors.dense(1.4, -5, 1)),
                    Row.of(Vectors.dense(2, -1, -2)));

    /** Expected output when only the mean is removed: {@code input - mean}. */
    private final List<DenseVector> expectedResWithMean =
            Arrays.asList(
                    Vectors.dense(-2.8, 8, 1),
                    Vectors.dense(1.1, -6, 1),
                    Vectors.dense(1.7, -2, -2));

    /** Expected output when only scaling by the standard deviation: {@code input / std}. */
    private final List<DenseVector> expectedResWithStd =
            Arrays.asList(
                    Vectors.dense(-1.0231819, 1.2480754, 0.5773502),
                    Vectors.dense(0.5729819, -0.6933752, 0.5773503),
                    Vectors.dense(0.8185455, -0.1386750, -1.1547005));

    /** Expected output when centering and scaling: {@code (input - mean) / std}. */
    private final List<DenseVector> expectedResWithMeanAndStd =
            Arrays.asList(
                    Vectors.dense(-1.1459637, 1.1094004, 0.5773503),
                    Vectors.dense(0.45020003, -0.8320503, 0.5773503),
                    Vectors.dense(0.69576368, -0.2773501, -1.1547005));

    private final double[] expectedMean = new double[] {0.3, 1, 0};
    private final double[] expectedStd = new double[] {2.4433583, 7.2111026, 1.7320508};

    /** Absolute tolerance used for every floating point comparison in this class. */
    private static final double TOLERANCE = 1e-7;

    /**
     * Creates a fresh execution environment (parallelism 4, checkpointing enabled with interval
     * argument 100, checkpoints after finished tasks enabled, no restarts) and the table
     * environment on top of it, then registers {@link #denseInput} as {@link #denseTable}.
     */
    @Before
    public void before() {
        Configuration config = new Configuration();
        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
        env.setParallelism(4);
        env.enableCheckpointing(100);
        env.setRestartStrategy(RestartStrategies.noRestart());
        tEnv = StreamTableEnvironment.create(env);
        denseTable = tEnv.fromDataStream(env.fromCollection(denseInput)).as("input");
    }

    /**
     * Executes {@code output}, reads the vectors stored in {@code predictionCol} and asserts that
     * they equal {@code expectedOutput} within {@link #TOLERANCE}.
     *
     * <p>The collected predictions and a copy of the expected vectors are both sorted with {@link
     * #compareVectors(DenseVector, DenseVector)} before being compared, so the check depends
     * neither on the order in which the job emits rows nor on the order of the expected list.
     *
     * @param expectedOutput the vectors the job is expected to produce; not modified
     * @param output the table to execute and collect
     * @param predictionCol name of the column holding the scaled vectors
     * @throws Exception if executing or collecting the job fails
     */
    @SuppressWarnings("unchecked")
    private void verifyPredictionResult(
            List<DenseVector> expectedOutput, Table output, String predictionCol) throws Exception {
        List<Row> collectedResult =
                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
        List<DenseVector> predictions = new ArrayList<>(collectedResult.size());
        for (Row row : collectedResult) {
            Vector vec = (Vector) row.getField(predictionCol);
            predictions.add(vec.toDense());
        }
        assertEquals(expectedOutput.size(), predictions.size());

        // Sort a copy so the shared fixture lists are never reordered by a test.
        List<DenseVector> sortedExpected = new ArrayList<>(expectedOutput);
        sortedExpected.sort(StandardScalerTest::compareVectors);
        predictions.sort(StandardScalerTest::compareVectors);

        for (int i = 0; i < predictions.size(); i++) {
            assertArrayEquals(sortedExpected.get(i).values, predictions.get(i).values, TOLERANCE);
        }
    }

    /**
     * Orders two vectors lexicographically by their components; when one vector is a prefix of
     * the other, the shorter one sorts first.
     */
    private static int compareVectors(DenseVector vec1, DenseVector vec2) {
        int size = Math.min(vec1.size(), vec2.size());
        for (int i = 0; i < size; i++) {
            int cmp = Double.compare(vec1.get(i), vec2.get(i));
            if (cmp != 0) {
                return cmp;
            }
        }
        return Integer.compare(vec1.size(), vec2.size());
    }

    /** Checks the default parameter values and that every setter is reflected by its getter. */
    @Test
    public void testParam() {
        StandardScaler standardScaler = new StandardScaler();
        assertEquals("input", standardScaler.getInputCol());
        assertEquals(false, standardScaler.getWithMean());
        assertEquals(true, standardScaler.getWithStd());
        assertEquals("output", standardScaler.getOutputCol());

        standardScaler
                .setInputCol("test_input")
                .setWithMean(true)
                .setWithStd(false)
                .setOutputCol("test_output");

        assertEquals("test_input", standardScaler.getInputCol());
        assertEquals(true, standardScaler.getWithMean());
        assertEquals(false, standardScaler.getWithStd());
        assertEquals("test_output", standardScaler.getOutputCol());
    }

    /** Checks that the transformed table keeps the input column and appends the output column. */
    @Test
    public void testOutputSchema() {
        Table tempTable = denseTable.as("test_input");
        StandardScaler standardScaler =
                new StandardScaler().setInputCol("test_input").setOutputCol("test_output");
        Table output = standardScaler.fit(tempTable).transform(tempTable)[0];

        assertEquals(
                Arrays.asList("test_input", "test_output"),
                output.getResolvedSchema().getColumnNames());
    }

    /** Fits and transforms with the default parameters (scaling by standard deviation only). */
    @Test
    public void testFitAndPredictWithStd() throws Exception {
        StandardScaler standardScaler = new StandardScaler();
        Table output = standardScaler.fit(denseTable).transform(denseTable)[0];
        verifyPredictionResult(expectedResWithStd, output, standardScaler.getOutputCol());
    }

    /** Fits and transforms with mean removal only. */
    @Test
    public void testFitAndPredictWithMean() throws Exception {
        StandardScaler standardScaler = new StandardScaler().setWithStd(false).setWithMean(true);
        Table output = standardScaler.fit(denseTable).transform(denseTable)[0];
        verifyPredictionResult(expectedResWithMean, output, standardScaler.getOutputCol());
    }

    /** Fits and transforms with both mean removal and scaling by standard deviation. */
    @Test
    public void testFitAndPredictWithMeanAndStd() throws Exception {
        StandardScaler standardScaler = new StandardScaler().setWithMean(true);
        Table output = standardScaler.fit(denseTable).transform(denseTable)[0];
        verifyPredictionResult(expectedResWithMeanAndStd, output, standardScaler.getOutputCol());
    }

    /**
     * Runs the mean-and-std scenario on the table returned by {@code
     * TestUtils.convertDataTypesToSparseInt}, asserting first that its only column is a {@link
     * SparseVector} and then that the scaled output is unchanged.
     */
    @Test
    public void testInputTypeConversion() throws Exception {
        // Use a local table instead of overwriting the shared fixture field.
        Table convertedTable = TestUtils.convertDataTypesToSparseInt(tEnv, denseTable);
        assertArrayEquals(
                new Class<?>[] {SparseVector.class},
                TestUtils.getColumnDataTypes(convertedTable));

        StandardScaler standardScaler = new StandardScaler().setWithMean(true);
        Table output = standardScaler.fit(convertedTable).transform(convertedTable)[0];
        verifyPredictionResult(expectedResWithMeanAndStd, output, standardScaler.getOutputCol());
    }

    /**
     * Round-trips both the estimator and the fitted model through {@code TestUtils.saveAndReload}
     * and checks that the reloaded model still exposes the "mean"/"std" model data columns and
     * produces the expected predictions.
     */
    @Test
    public void testSaveLoadAndPredict() throws Exception {
        StandardScaler standardScaler = new StandardScaler();
        standardScaler =
                TestUtils.saveAndReload(
                        tEnv, standardScaler, tempFolder.newFolder().getAbsolutePath());

        StandardScalerModel model = standardScaler.fit(denseTable);
        model = TestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath());

        assertEquals(
                Arrays.asList("mean", "std"),
                model.getModelData()[0].getResolvedSchema().getColumnNames());

        Table output = model.transform(denseTable)[0];
        verifyPredictionResult(expectedResWithStd, output, standardScaler.getOutputCol());
    }

    /** Checks the schema and the numeric content of the model data produced by fitting. */
    @Test
    @SuppressWarnings("unchecked")
    public void testGetModelData() throws Exception {
        StandardScaler standardScaler = new StandardScaler();
        StandardScalerModel model = standardScaler.fit(denseTable);
        Table modelDataTable = model.getModelData()[0];

        assertEquals(
                Arrays.asList("mean", "std"), modelDataTable.getResolvedSchema().getColumnNames());

        List<StandardScalerModelData> collectedModelData =
                (List<StandardScalerModelData>)
                        IteratorUtils.toList(
                                StandardScalerModelData.getModelDataStream(modelDataTable)
                                        .executeAndCollect());
        assertEquals(1, collectedModelData.size());

        StandardScalerModelData modelData = collectedModelData.get(0);
        assertArrayEquals(expectedMean, modelData.mean.values, TOLERANCE);
        assertArrayEquals(expectedStd, modelData.std.values, TOLERANCE);
    }

    /**
     * Checks that a newly constructed model given the parameters and model data of a fitted model
     * produces the same predictions as the fitted model.
     */
    @Test
    public void testSetModelData() throws Exception {
        StandardScaler standardScaler = new StandardScaler();
        StandardScalerModel model = standardScaler.fit(denseTable);

        StandardScalerModel newModel = new StandardScalerModel();
        ReadWriteUtils.updateExistingParams(newModel, model.getParamMap());
        newModel.setModelData(model.getModelData());

        Table output = newModel.transform(denseTable)[0];
        verifyPredictionResult(expectedResWithStd, output, standardScaler.getOutputCol());
    }

    /** Fits and transforms sparse vectors with the default parameters. */
    @Test
    public void testSparseInput() throws Exception {
        final List<Row> sparseInput =
                Arrays.asList(
                        Row.of(Vectors.sparse(3, new int[] {0, 1}, new double[] {-2.5, 1})),
                        Row.of(Vectors.sparse(3, new int[] {1, 2}, new double[] {2, -2})),
                        Row.of(Vectors.sparse(3, new int[] {0, 2}, new double[] {1.4, 1})));
        Table sparseTable = tEnv.fromDataStream(env.fromCollection(sparseInput)).as("input");

        // Named distinctly so it does not shadow the dense expectedResWithStd field.
        final List<DenseVector> expectedSparseResWithStd =
                Arrays.asList(
                        Vectors.dense(-1.2653836, 1, 0),
                        Vectors.dense(0, 2, -1.30930734),
                        Vectors.dense(0.7086148, 0, 0.6546537));

        StandardScaler standardScaler = new StandardScaler();
        Table output = standardScaler.fit(sparseTable).transform(sparseTable)[0];
        verifyPredictionResult(expectedSparseResWithStd, output, standardScaler.getOutputCol());
    }

    /**
     * Checks that collecting the model data of a model fitted on an empty table fails with the
     * root-cause message "The training set is empty.".
     */
    @Test
    @SuppressWarnings("unchecked")
    public void testFitOnEmptyData() throws Exception {
        // No row of denseInput has arity 0, so the filter yields an empty stream.
        Table emptyTable =
                tEnv.fromDataStream(env.fromCollection(denseInput).filter(x -> x.getArity() == 0))
                        .as("input");
        StandardScalerModel model = new StandardScaler().fit(emptyTable);
        Table modelDataTable = model.getModelData()[0];

        try {
            IteratorUtils.toList(
                    StandardScalerModelData.getModelDataStream(modelDataTable).executeAndCollect());
            fail("Expected the job to fail because the training set is empty.");
        } catch (Exception e) {
            // Catch Exception rather than Throwable so the AssertionError raised by fail() above
            // propagates unchanged instead of being inspected as if it were the job failure.
            assertEquals("The training set is empty.", ExceptionUtils.getRootCause(e).getMessage());
        }
    }
}