blob: 32586864b9160482cbcbcb6c6d595390f89398c4 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.ml.dataset.feature.extractor;
import java.io.Serializable;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.ignite.binary.BinaryObject;
import org.apache.ignite.internal.util.typedef.internal.A;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
/**
* Class for extracting labeled vectors from upstream. This is an abstract class providing an API for extracting
* feature and label values from upstream objects by their "coordinates". For example, a {@link BinaryObject} can be
* an upstream object, and its coordinates are the names of fields with double values.
*
* @param <K> Type of keys in upstream.
* @param <V> Type of values in upstream.
* @param <C> Type of "coordinate" - index of feature value in upstream object.
* @param <L> Type of label for resulting vectors.
*/
public abstract class Vectorizer<K, V, C extends Serializable, L> implements FeatureLabelExtractor<K, V, L> {
    /** Label coordinate shortcut (first or last coordinate); resolved to a concrete coordinate on first use. */
    private LabelCoordinate lbCoordinateShortcut;

    /** Serial version uid. */
    private static final long serialVersionUID = 4301406952131379459L;

    /** If useAllValues == true then Vectorizer extracts all fields as features from upstream object (except label). */
    private final boolean useAllValues;

    /** Extraction coordinates. */
    private List<C> extractionCoordinates;

    /** Label coordinate. May be set explicitly or resolved lazily from {@link #lbCoordinateShortcut}. */
    private C labelCoord;

    /** Excluded coordinates. */
    private HashSet<C> excludedCoords = new HashSet<>();

    /**
     * Creates an instance of Vectorizer.
     *
     * @param coords Coordinates for feature extraction. If array is empty then Vectorizer will extract all fields from
     * upstream object.
     */
    @SafeVarargs
    public Vectorizer(C... coords) {
        // Arrays.asList returns a view backed by the given array, so copy the array first: otherwise a caller that
        // later modifies its array would silently change the extraction coordinates, while useAllValues (computed
        // once below) would keep reflecting the original length.
        extractionCoordinates = Arrays.asList(coords.clone());
        this.useAllValues = coords.length == 0;
    }

    /**
     * Extracts labeled vector from upstream object.
     *
     * If the Vectorizer was created without explicit coordinates, features are taken from all coordinates of the
     * upstream object except the label coordinate and the excluded coordinates. Otherwise exactly the coordinates
     * passed to the constructor are used, in that order.
     *
     * @param key Key.
     * @param value Value.
     * @return vector.
     */
    @Override public LabeledVector<L> apply(K key, V value) {
        // Resolve the label first: when a FIRST/LAST shortcut is configured this call also stores the resolved
        // coordinate in labelCoord, which the filter below uses to keep the label out of the feature vector.
        L lbl = isLabeled() ? label(labelCoord(key, value), key, value) : zero();

        List<C> coordinatesForExtraction;

        if (useAllValues) {
            coordinatesForExtraction = allCoords(key, value).stream()
                .filter(coord -> !coord.equals(labelCoord) && !excludedCoords.contains(coord))
                .collect(Collectors.toList());
        }
        else
            coordinatesForExtraction = extractionCoordinates;

        Vector vector = createVector(coordinatesForExtraction.size());

        for (int i = 0; i < coordinatesForExtraction.size(); i++) {
            Serializable feature = feature(coordinatesForExtraction.get(i), key, value);

            // A null feature is skipped, so that position keeps whatever createVector() initialized it with.
            if (feature != null)
                vector.setRaw(i, feature);
        }

        return new LabeledVector<>(vector, lbl);
    }

    /**
     * @return true if a label coordinate or a label coordinate shortcut is configured.
     */
    private boolean isLabeled() {
        return labelCoord != null || lbCoordinateShortcut != null;
    }

    /**
     * Evaluates label coordinate if needed. An explicitly set (or previously resolved) coordinate is returned as is;
     * otherwise the configured shortcut is resolved against the coordinates of the given upstream object and the
     * result is remembered for subsequent calls.
     *
     * @param key Key.
     * @param value Value.
     * @return label coordinate.
     */
    private C labelCoord(K key, V value) {
        A.ensure(isLabeled(), "isLabeled");

        if (labelCoord != null)
            return labelCoord;

        List<C> allCoords = allCoords(key, value);
        A.ensure(!allCoords.isEmpty(), "!allCoords.isEmpty()");

        switch (lbCoordinateShortcut) {
            case FIRST:
                labelCoord = allCoords.get(0);
                break;

            case LAST:
                labelCoord = allCoords.get(allCoords.size() - 1);
                break;

            default:
                throw new IllegalArgumentException(
                    "Unsupported label coordinate shortcut: " + lbCoordinateShortcut);
        }

        return labelCoord;
    }

    /**
     * Sets label coordinate for Vectorizer. By default it equals null and zero() will be used as label value.
     * Any previously configured label coordinate shortcut is discarded.
     *
     * @param labelCoord Label coordinate.
     * @return this.
     */
    public Vectorizer<K, V, C, L> labeled(C labelCoord) {
        this.labelCoord = labelCoord;
        this.lbCoordinateShortcut = null;

        return this;
    }

    /**
     * Sets label coordinate for Vectorizer by shortcut (first or last coordinate of upstream object). The shortcut
     * is resolved to a concrete coordinate when the first upstream object is processed. Any previously configured
     * explicit label coordinate is discarded.
     *
     * @param labelCoord Label coordinate shortcut.
     * @return this.
     */
    public Vectorizer<K, V, C, L> labeled(LabelCoordinate labelCoord) {
        this.lbCoordinateShortcut = labelCoord;
        this.labelCoord = null;

        return this;
    }

    /**
     * Exclude these coordinates from result vector. Exclusions are taken into account only when the Vectorizer
     * extracts all values, i.e. when it was created without explicit extraction coordinates.
     *
     * @param coords Coordinates.
     * @return this.
     */
    public Vectorizer<K, V, C, L> exclude(C... coords) {
        this.excludedCoords.addAll(Arrays.asList(coords));

        return this;
    }

    /**
     * Shortcuts for coordinates in feature vector.
     */
    public enum LabelCoordinate {
        /** First. */
        FIRST,

        /** Last. */
        LAST
    }

    /** {@inheritDoc} */
    @Override public LabeledVector<L> extract(K k, V v) {
        return apply(k, v);
    }

    /**
     * Extracts feature value by given coordinate.
     *
     * @param coord Coordinate.
     * @param key Key.
     * @param value Value.
     * @return feature value; a null value is not written to the resulting vector.
     */
    protected abstract Serializable feature(C coord, K key, V value);

    /**
     * Extract label value by given coordinate.
     *
     * @param coord Coordinate.
     * @param key Key.
     * @param value Value.
     * @return label value.
     */
    protected abstract L label(C coord, K key, V value);

    /**
     * Returns default label value for unlabeled data.
     *
     * @return label value.
     */
    protected abstract L zero();

    /**
     * Returns list of all coordinates with feature values.
     *
     * @param key Key.
     * @param value Value.
     * @return all coordinates list.
     */
    protected abstract List<C> allCoords(K key, V value);

    /**
     * Create an instance of vector.
     *
     * @param size Vector size.
     * @return vector.
     */
    protected Vector createVector(int size) {
        return new DenseVector(size);
    }

    /**
     * Utility class for convenient overridings: every extraction method throws {@link IllegalStateException}, so a
     * subclass overrides only the methods it actually needs.
     *
     * @param <K> Type of key.
     * @param <V> Type of value.
     * @param <C> Type of coordinate.
     * @param <L> Type of label.
     */
    public abstract static class VectorizerAdapter<K, V, C extends Serializable, L> extends Vectorizer<K, V, C, L> {
        /** {@inheritDoc} */
        @Override protected Double feature(C coord, K key, V value) {
            throw new IllegalStateException();
        }

        /** {@inheritDoc} */
        @Override protected L label(C coord, K key, V value) {
            throw new IllegalStateException();
        }

        /** {@inheritDoc} */
        @Override protected L zero() {
            throw new IllegalStateException();
        }

        /** {@inheritDoc} */
        @Override protected List<C> allCoords(K key, V value) {
            throw new IllegalStateException();
        }
    }
}