blob: 0d35df58b84cbef87d333200ceabe6e916b18e9d [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.ml.common;
import java.util.UUID;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.binary.BinaryObject;
import org.apache.ignite.binary.BinaryObjectBuilder;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.internal.util.IgniteUtils;
import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
import org.apache.ignite.ml.dataset.feature.extractor.impl.BinaryObjectVectorizer;
import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
import org.junit.Test;
/**
 * Test for IGNITE-10700: trains a model on a cache whose values are {@link BinaryObject}s, reading them through a
 * {@link CacheBasedDatasetBuilder} configured with {@code withKeepBinary(true)}.
 */
public class KeepBinaryTest extends GridCommonAbstractTest {
    /** Number of nodes in grid. */
    private static final int NODE_COUNT = 2;

    /** Number of samples. */
    public static final int NUMBER_OF_SAMPLES = 1000;

    /** Half of samples. */
    public static final int HALF = NUMBER_OF_SAMPLES / 2;

    /** Ignite instance the test runs against (the grid with index {@link #NODE_COUNT}). */
    private Ignite ignite;

    /** {@inheritDoc} */
    @Override protected void beforeTestsStarted() throws Exception {
        // Grids are started with indices 1..NODE_COUNT, so grid(NODE_COUNT) is the last one started.
        for (int i = 1; i <= NODE_COUNT; i++)
            startGrid(i);
    }

    /** {@inheritDoc} */
    @Override protected void afterTestsStopped() {
        stopAllGrids();
    }

    /** {@inheritDoc} */
    @Override protected void beforeTest() {
        ignite = grid(NODE_COUNT);

        // NOTE(review): this mutates the configuration of a node that is already started; confirm it has an
        // effect, otherwise it belongs in getConfiguration(String).
        ignite.configuration().setPeerClassLoadingEnabled(true);

        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
    }

    /**
     * Populates a cache with binary objects, fits a {@link KMeansTrainer} on it with keep-binary enabled, and checks
     * that the centre predicted for the point {@code 0.0} has its first coordinate equal to {@code 0.0}.
     */
    @Test
    public void test() {
        IgniteCache<Integer, BinaryObject> dataCache = populateCache(ignite);

        try {
            KMeansTrainer trainer = new KMeansTrainer();

            CacheBasedDatasetBuilder<Integer, BinaryObject> datasetBuilder =
                new CacheBasedDatasetBuilder<>(ignite, dataCache).withKeepBinary(true);

            KMeansModel mdl = trainer.fit(datasetBuilder,
                new BinaryObjectVectorizer<Integer>("feature1").labeled("label"));

            int zeroCentre = mdl.predict(VectorUtils.num2Vec(0.0));

            // assertEquals (delta 0.0) has the same pass/fail semantics as "== 0" but reports the actual value.
            assertEquals(0.0, mdl.getCenters()[zeroCentre].get(0), 0.0);
        }
        finally {
            // The cache name is random, so nothing else would ever remove it from the running grid.
            ignite.destroyCache(dataCache.getName());
        }
    }

    /**
     * Creates a cache with a random {@code TEST_}-prefixed name and fills it with {@link #NUMBER_OF_SAMPLES} binary
     * objects of type {@code testType}. Keys below {@link #HALF} get {@code feature1 = 0.0, label = 0.0}; the
     * remaining keys get {@code feature1 = 10.0, label = 1.0}.
     *
     * @param node Ignite node used to create the cache and the binary objects.
     * @return The created cache, in keep-binary mode.
     */
    private IgniteCache<Integer, BinaryObject> populateCache(Ignite node) {
        CacheConfiguration<Integer, BinaryObject> cacheCfg = new CacheConfiguration<>();
        cacheCfg.setName("TEST_" + UUID.randomUUID());

        IgniteCache<Integer, BinaryObject> cache = node.createCache(cacheCfg).withKeepBinary();

        // One builder is reused; both fields are overwritten before every build().
        BinaryObjectBuilder builder = node.binary().builder("testType");

        for (int i = 0; i < NUMBER_OF_SAMPLES; i++) {
            if (i < HALF)
                cache.put(i, builder.setField("feature1", 0.0).setField("label", 0.0).build());
            else
                cache.put(i, builder.setField("feature1", 10.0).setField("label", 1.0).build());
        }

        return cache;
    }
}