/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.ignite.ml.nn;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.ignite.ml.TestUtils;
import org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer;
import org.apache.ignite.ml.math.primitives.matrix.Matrix;
import org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
import org.apache.ignite.ml.optimization.LossFunctions;
import org.apache.ignite.ml.optimization.SmoothParametrized;
import org.apache.ignite.ml.optimization.updatecalculators.NesterovParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.NesterovUpdateCalculator;
import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
import org.apache.ignite.ml.structures.LabeledVector;
import org.junit.Before;
import org.junit.Test;
import org.junit.experimental.runners.Enclosed;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Tests for {@link MLPTrainer} that don't require starting the whole Ignite infrastructure.
 */
@RunWith(Enclosed.class)
public class MLPTrainerTest {
    /**
     * Parameterized tests.
     */
    @RunWith(Parameterized.class)
    public static class ComponentParamTests {
        /** Number of parts to be tested. */
        private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7};

        /** Batch sizes to be tested. */
        private static final int[] batchSizesToBeTested = new int[] {1, 2, 3, 4};

        /** Parameters. */
        @Parameterized.Parameters(name = "Data divided on {0} partitions, training with batch size {1}")
        public static Iterable<Integer[]> data() {
            List<Integer[]> res = new ArrayList<>();

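            // Build the cartesian product of partition counts and batch sizes.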
            for (int part : partsToBeTested)
                for (int batchSize : batchSizesToBeTested)
                    res.add(new Integer[] {part, batchSize});

            return res;
        }

        /** Number of partitions. */
        @Parameterized.Parameter
        public int parts;

        /** Batch size. */
        @Parameterized.Parameter(1)
        public int batchSize;

        /**
         * Test 'XOR' operation training with {@link SimpleGDUpdateCalculator} updater.
         */
        @Test
        public void testXORSimpleGD() {
            xorTest(new UpdatesStrategy<>(
                new SimpleGDUpdateCalculator(0.2),
                SimpleGDParameterUpdate.SUM_LOCAL,
                SimpleGDParameterUpdate.AVG
            ));
        }

        /**
         * Test 'XOR' operation training with {@link RPropUpdateCalculator}.
         */
        @Test
        public void testXORRProp() {
            xorTest(new UpdatesStrategy<>(
                new RPropUpdateCalculator(),
                RPropParameterUpdate.SUM_LOCAL,
                RPropParameterUpdate.AVG
            ));
        }

        /**
         * Test 'XOR' operation training with {@link NesterovUpdateCalculator}.
         */
        @Test
        public void testXORNesterov() {
            xorTest(new UpdatesStrategy<>(
                new NesterovUpdateCalculator<MultilayerPerceptron>(0.1, 0.7),
                NesterovParameterUpdate::sum,
                NesterovParameterUpdate::avg
            ));
        }

        /**
         * Common method for testing 'XOR' with various updaters.
         *
         * @param updatesStgy Update strategy.
         * @param <P> Updater parameters type.
         */
        private <P extends Serializable> void xorTest(UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy) {
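            // Training set of the XOR function; each label is a one-element array.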
            Map<Integer, LabeledVector<double[]>> xorData = new HashMap<>();
            xorData.put(0, VectorUtils.of(0.0, 0.0).labeled(new double[] {0.0}));
            xorData.put(1, VectorUtils.of(0.0, 1.0).labeled(new double[] {1.0}));
            xorData.put(2, VectorUtils.of(1.0, 0.0).labeled(new double[] {1.0}));
            xorData.put(3, VectorUtils.of(1.0, 1.0).labeled(new double[] {0.0}));

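            // Network: 2 inputs -> hidden layer of 10 ReLU neurons -> 1 sigmoid output.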
            MLPArchitecture arch = new MLPArchitecture(2).
                withAddedLayer(10, true, Activators.RELU).
                withAddedLayer(1, false, Activators.SIGMOID);

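            // Trainer arguments, in order: architecture, loss function, update strategy,
            // max iterations (3000), batch size, local iterations (50) and random seed (123).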
            MLPTrainer<P> trainer = new MLPTrainer<>(
                arch,
                LossFunctions.MSE,
                updatesStgy,
                3000,
                batchSize,
                50,
                123L
            );

            MultilayerPerceptron mlp = trainer.fit(xorData, parts, new LabeledDummyVectorizer<>());

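            // Predict on all four XOR inputs; only the (0, 0) -> ~0 row is asserted below.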
            Matrix predict = mlp.predict(new DenseMatrix(new double[][] {
                {0.0, 0.0},
                {0.0, 1.0},
                {1.0, 0.0},
                {1.0, 1.0}
            }));

            TestUtils.checkIsInEpsilonNeighbourhood(new DenseVector(new double[] {0.0}), predict.getRow(0), 1E-1);
        }

        /** */
        @Test
        public void testUpdate() {
            UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> updatesStgy = new UpdatesStrategy<>(
                new SimpleGDUpdateCalculator(0.2),
                SimpleGDParameterUpdate.SUM_LOCAL,
                SimpleGDParameterUpdate.AVG
            );

            Map<Integer, LabeledVector<double[]>> xorData = new HashMap<>();
            xorData.put(0, VectorUtils.of(0.0, 0.0).labeled(new double[] {0.0}));
            xorData.put(1, VectorUtils.of(0.0, 1.0).labeled(new double[] {1.0}));
            xorData.put(2, VectorUtils.of(1.0, 0.0).labeled(new double[] {1.0}));
            xorData.put(3, VectorUtils.of(1.0, 1.0).labeled(new double[] {0.0}));

            MLPArchitecture arch = new MLPArchitecture(2).
                withAddedLayer(10, true, Activators.RELU).
                withAddedLayer(1, false, Activators.SIGMOID);

            MLPTrainer<SimpleGDParameterUpdate> trainer = new MLPTrainer<>(
                arch,
                LossFunctions.MSE,
                updatesStgy,
                3000,
                batchSize,
                50,
                123L
            );

            MultilayerPerceptron originalMdl = trainer.fit(xorData, parts, new LabeledDummyVectorizer<>());

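            // Update the trained model twice: on the same dataset and on an empty one.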
            MultilayerPerceptron updatedOnSameDS = trainer.update(
                originalMdl,
                xorData,
                parts,
                new LabeledDummyVectorizer<>()
            );

            MultilayerPerceptron updatedOnEmptyDS = trainer.update(
                originalMdl,
                new HashMap<Integer, LabeledVector<double[]>>(),
                parts,
                new LabeledDummyVectorizer<>()
            );

            DenseMatrix matrix = new DenseMatrix(new double[][] {
                {0.0, 0.0},
                {0.0, 1.0},
                {1.0, 0.0},
                {1.0, 1.0}
            });

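            // Neither update should noticeably change the original model's predictions.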
            TestUtils.checkIsInEpsilonNeighbourhood(originalMdl.predict(matrix).getRow(0), updatedOnSameDS.predict(matrix).getRow(0), 1E-1);
            TestUtils.checkIsInEpsilonNeighbourhood(originalMdl.predict(matrix).getRow(0), updatedOnEmptyDS.predict(matrix).getRow(0), 1E-1);
        }
    }

    /**
     * Non-parameterized tests.
     */
    public static class ComponentSingleTests {
        /** Data. */
        private double[] data;

        /** Initialization. */
        @Before
        public void init() {
            data = new double[10];

            for (int i = 0; i < 10; i++)
                data[i] = i;
        }

        /** */
        @Test
        public void testBatchWithSingleColumnAndSingleRow() {
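            // With 10 total rows the flat array is a single 10x1 column, so row 1 is {1.0}.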
            double[] res = MLPTrainer.batch(data, new int[] {1}, 10);

            TestUtils.assertEquals(new double[] {1.0}, res, 1e-12);
        }

        /** */
        @Test
        public void testBatchWithMultiColumnAndSingleRow() {
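            // batch() reads the array column-major: with 5 total rows it forms a 5x2 matrix
            // with columns {0..4} and {5..9}, so row 1 is {1.0, 6.0}.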
            double[] res = MLPTrainer.batch(data, new int[] {1}, 5);

            TestUtils.assertEquals(new double[] {1.0, 6.0}, res, 1e-12);
        }

        /** */
        @Test
        public void testBatchWithMultiColumnAndMultiRow() {
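            // Rows 1 and 3 of the 5x2 column-major matrix, flattened column-major: {1, 3, 6, 8}.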
            double[] res = MLPTrainer.batch(data, new int[] {1, 3}, 5);

            TestUtils.assertEquals(new double[] {1.0, 3.0, 6.0, 8.0}, res, 1e-12);
        }
    }
}