blob: 4c4e11c40a1c55d3c709004b4862957f22ceb54a [file] [log] [blame]
package org.apache.wayang.spark.operators;
import org.apache.wayang.basic.data.Tuple2;
import org.apache.wayang.basic.model.LinearRegressionModel;
import org.apache.wayang.basic.operators.ModelTransformOperator;
import org.apache.wayang.core.platform.ChannelInstance;
import org.apache.wayang.java.channels.CollectionChannel;
import org.apache.wayang.spark.channels.RddChannel;
import org.apache.wayang.spark.operators.ml.SparkLinearRegressionOperator;
import org.apache.wayang.spark.operators.ml.SparkModelTransformOperator;
import org.junit.Assert;
import org.junit.Test;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
/**
 * Tests {@link SparkLinearRegressionOperator} (training) and {@link SparkModelTransformOperator}
 * applying the resulting {@link LinearRegressionModel} (inference).
 *
 * <p>All fixtures lie exactly on the plane {@code y = x1 + x2 + 1}, so the expected
 * coefficients, intercept and predictions are known in closed form.
 */
public class SparkLinearRegressionOperatorTest extends SparkOperatorTestBase {

    /**
     * Training samples as (features, label) pairs on {@code y = x1 + x2 + 1}.
     *
     * <p>The list is final and unmodifiable so that no test can replace or overwrite the
     * fixture seen by other tests; the contained arrays remain mutable and must not be written.
     */
    public static final List<Tuple2<double[], Double>> trainingData = Collections.unmodifiableList(Arrays.asList(
            new Tuple2<>(new double[]{1, 1}, 3D),
            new Tuple2<>(new double[]{1, -1}, 1D),
            new Tuple2<>(new double[]{3, 2}, 6D)
    ));

    /**
     * Feature vectors to predict; on {@code y = x1 + x2 + 1} they map to 4 and 0 respectively.
     * Unmodifiable for the same reason as {@link #trainingData}.
     */
    public static final List<double[]> inferenceData = Collections.unmodifiableList(Arrays.asList(
            new double[]{1, 2},
            new double[]{1, -2}
    ));

    /**
     * Trains a {@link LinearRegressionModel} on {@link #trainingData} with
     * {@link SparkLinearRegressionOperator} and returns it.
     *
     * <p>Fails the calling test if the operator does not emit exactly one model.
     *
     * @return the single trained model
     */
    public LinearRegressionModel getModel() {
        // Prepare test data.
        RddChannel.Instance input = this.createRddChannelInstance(trainingData);
        CollectionChannel.Instance output = this.createCollectionChannelInstance();
        // NOTE(review): the flag presumably enables fitting an intercept (testTraining expects
        // intercept 1) — confirm against SparkLinearRegressionOperator.
        SparkLinearRegressionOperator linearRegressionOperator = new SparkLinearRegressionOperator(true);

        // Set up the ChannelInstances.
        ChannelInstance[] inputs = new ChannelInstance[]{input};
        ChannelInstance[] outputs = new ChannelInstance[]{output};

        // Execute.
        this.evaluate(linearRegressionOperator, inputs, outputs);

        // Extract the model. Assert the count first so a missing or duplicated model fails with a
        // clear message instead of a NoSuchElementException (or going unnoticed).
        final Collection<LinearRegressionModel> models = output.provideCollection();
        Assert.assertEquals("expected exactly one trained model", 1, models.size());
        return models.iterator().next();
    }

    /** The fitted model should recover coefficients (1, 1) and intercept 1. */
    @Test
    public void testTraining() {
        final LinearRegressionModel model = getModel();
        Assert.assertArrayEquals(new double[]{1, 1}, model.getCoefficients(), 1e-6);
        Assert.assertEquals(1, model.getIntercept(), 1e-6);
    }

    /**
     * Applying the trained model to {@link #inferenceData} should pair each feature vector with
     * its prediction on {@code y = x1 + x2 + 1}.
     */
    @Test
    public void testInference() {
        // Prepare test data: the model goes in as a one-element collection, the features as an RDD.
        CollectionChannel.Instance input1 = this.createCollectionChannelInstance(Collections.singletonList(getModel()));
        RddChannel.Instance input2 = this.createRddChannelInstance(inferenceData);
        RddChannel.Instance output = this.createRddChannelInstance();
        SparkModelTransformOperator<double[], Double> transformOperator =
                new SparkModelTransformOperator<>(ModelTransformOperator.linearRegression());

        // Set up the ChannelInstances.
        ChannelInstance[] inputs = new ChannelInstance[]{input1, input2};
        ChannelInstance[] outputs = new ChannelInstance[]{output};

        // Execute.
        this.evaluate(transformOperator, inputs, outputs);

        // Verify the outcome. Check the features (field0) as well as the prediction (field1),
        // so a prediction attached to the wrong or a corrupted input vector is caught.
        final List<Tuple2<double[], Double>> results = output.<Tuple2<double[], Double>>provideRdd().collect();
        Assert.assertEquals(2, results.size());
        Assert.assertArrayEquals(inferenceData.get(0), results.get(0).field0, 1e-6);
        Assert.assertEquals(4, results.get(0).field1, 1e-6);
        Assert.assertArrayEquals(inferenceData.get(1), results.get(1).field0, 1e-6);
        Assert.assertEquals(0, results.get(1).field1, 1e-6);
    }
}