blob: b93c857b70c6299f7630a64e69a45bf15db6b4ee [file] [log] [blame]
/*
* Hivemall: Hive scalable Machine Learning Library
*
* Copyright (C) 2015 Makoto YUI
* Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package hivemall.fm;
import hivemall.fm.Entry.AdaGradEntry;
import hivemall.fm.Entry.FTRLEntry;
import hivemall.fm.FMHyperParameters.FFMHyperParameters;
import hivemall.utils.buffer.HeapBuffer;
import hivemall.utils.collections.Int2LongOpenHashTable;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.math.MathUtils;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
/**
 * Field-aware factorization machine model whose per-feature parameters live in a
 * {@link HeapBuffer}; an {@link Int2LongOpenHashTable} maps integer feature keys to the offset of
 * the corresponding entry inside that buffer.
 *
 * <p>Two kinds of keys are stored in the same table:
 * <ul>
 * <li>{@code x.getFeatureIndex()} for the entry holding the linear weight W (updated by
 * {@link #updateWi} / {@link #updateWiFTRL});</li>
 * <li>{@code Feature.toIntFeature(x, yField, _numFields)} for the field-aware latent vector V.</li>
 * </ul>
 * NOTE(review): this class does not itself guarantee that those two key spaces are disjoint —
 * TODO confirm against {@code Feature.toIntFeature}.
 */
public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachineModel {
    private static final int DEFAULT_MAPSIZE = 65536;

    // LEARNING PARAMS
    private float _w0;
    /** Feature key -> offset of its entry in {@link #_buf}; a lookup result of -1 means absent. */
    @Nonnull
    private final Int2LongOpenHashTable _map;
    /** Backing storage for all entries; entries are allocated with {@link #_entrySize} each. */
    @Nonnull
    private final HeapBuffer _buf;

    // hyperparams
    private final int _numFeatures;
    private final int _numFields;

    // FTRL
    private final float _alpha;
    private final float _beta;
    private final float _lambda1;
    private final float _lambda2;

    /** Size of a single entry, chosen once from the optimizer flags. */
    private final int _entrySize;

    public FFMStringFeatureMapModel(@Nonnull FFMHyperParameters params) {
        super(params);
        this._w0 = 0.f;
        this._map = new Int2LongOpenHashTable(DEFAULT_MAPSIZE);
        this._buf = new HeapBuffer(HeapBuffer.DEFAULT_CHUNK_SIZE);
        this._numFeatures = params.numFeatures;
        this._numFields = params.numFields;
        this._alpha = params.alphaFTRL;
        this._beta = params.betaFTRL;
        this._lambda1 = params.lambda1;
        this._lambda2 = params.lamdda2; // the hyper-parameter field itself is spelled "lamdda2"
        this._entrySize = entrySize(_factor, _useFTRL, _useAdaGrad);
    }

    /**
     * Builds a prediction-only model that shares this model's key table and buffer.
     */
    @Nonnull
    FFMPredictionModel toPredictionModel() {
        return new FFMPredictionModel(_map, _buf, _w0, _factor, _numFeatures, _numFields);
    }

    /**
     * @return the number of keys in the key table (W entries and V entries together)
     */
    @Override
    public int getSize() {
        return _map.size();
    }

    @Override
    public float getW0() {
        return _w0;
    }

    @Override
    protected void setW0(float nextW0) {
        this._w0 = nextW0;
    }

    /**
     * @return the linear weight of {@code x}, or 0 when no entry exists (no entry is created)
     */
    @Override
    public float getW(@Nonnull final Feature x) {
        final Entry entry = findEntry(x.getFeatureIndex());
        return (entry == null) ? 0.f : entry.getW();
    }

    /**
     * Sets the linear weight of {@code x}, creating its entry (with a freshly initialized V) if
     * it does not exist yet.
     */
    @Override
    protected void setW(@Nonnull final Feature x, final float nextWi) {
        final int j = x.getFeatureIndex();
        final Entry entry = findEntry(j);
        if (entry != null) {
            entry.setW(nextWi);
            return;
        }
        final Entry created = newEntry(nextWi, initV());
        _map.put(j, created.getOffset());
    }

    /**
     * Plain gradient step with L2 regularization on the linear weight of {@code x}:
     * {@code w <- w - eta * (dloss * x_i + 2 * lambdaW * w)}.
     *
     * @throws IllegalStateException if the updated weight is NaN or infinite
     */
    @Override
    void updateWi(final double dloss, @Nonnull final Feature x, final float eta) {
        final double Xi = x.getValue();
        final float gradWi = (float) (dloss * Xi);

        final Entry theta = getEntry(x);
        final float wi = theta.getW();

        final float nextWi = wi - eta * (gradWi + 2.f * _lambdaW * wi);
        if (!NumberUtils.isFinite(nextWi)) {
            throw new IllegalStateException("Got " + nextWi + " for next W[" + x.getFeature()
                    + "]\n" + "Xi=" + Xi + ", gradWi=" + gradWi + ", wi=" + wi + ", dloss=" + dloss
                    + ", eta=" + eta);
        }
        theta.setW(nextWi);
    }

    /**
     * Update Wi using Follow-the-Regularized-Leader.
     *
     * <p>The entry's z and n accumulators are updated first. If {@code |z| <= lambda1} the entry
     * is removed from the key table, so {@link #getW} then reports 0 for this feature. Otherwise
     * {@code w = (sign(z) * lambda1 - z) / ((beta + sqrt(n)) / alpha + lambda2)}.
     *
     * <p>NOTE(review): removal also discards the accumulated z/n (a later access recreates the
     * entry), and {@link #removeEntry} does not return the buffer space — confirm this is
     * intended.
     *
     * @return {@code true} if the previous or the new weight is non-zero
     * @throws IllegalStateException if the updated weight is NaN or infinite
     */
    boolean updateWiFTRL(final double dloss, @Nonnull final Feature x, final float eta) {
        final double Xi = x.getValue();
        final float gradWi = (float) (dloss * Xi);

        final Entry theta = getEntry(x);
        final float wi = theta.getW();

        final float z = theta.updateZ(gradWi, _alpha);
        final double n = theta.updateN(gradWi);

        if (Math.abs(z) <= _lambda1) {
            removeEntry(x);
            return wi != 0;
        }

        final float nextWi = (float) ((MathUtils.sign(z) * _lambda1 - z)
                / ((_beta + Math.sqrt(n)) / _alpha + _lambda2));
        if (!NumberUtils.isFinite(nextWi)) {
            throw new IllegalStateException("Got " + nextWi + " for next W[" + x.getFeature()
                    + "]\n" + "Xi=" + Xi + ", gradWi=" + gradWi + ", wi=" + wi + ", dloss=" + dloss
                    + ", eta=" + eta + ", n=" + n + ", z=" + z);
        }
        theta.setW(nextWi);
        return (nextWi != 0) || (wi != 0);
    }

    /**
     * @return V_x,yField,f (the entry is created with a fresh V if missing)
     */
    @Override
    public float getV(@Nonnull final Feature x, final int yField, final int f) {
        return getOrCreateEntry(Feature.toIntFeature(x, yField, _numFields)).getV(f);
    }

    /**
     * Sets V_x,yField,f, creating the entry with a fresh V if missing.
     */
    @Override
    protected void setV(@Nonnull final Feature x, final int yField, final int f,
            final float nextVif) {
        getOrCreateEntry(Feature.toIntFeature(x, yField, _numFields)).setV(f, nextVif);
    }

    /**
     * @return the W entry of {@code x}, created with a fresh V if missing
     */
    @Nonnull
    @Override
    protected Entry getEntry(@Nonnull final Feature x) {
        return getOrCreateEntry(x.getFeatureIndex());
    }

    /**
     * @return the field-aware entry of {@code x} for {@code yField}, created if missing
     */
    @Nonnull
    @Override
    protected Entry getEntry(@Nonnull final Feature x, final int yField) {
        return getOrCreateEntry(Feature.toIntFeature(x, yField, _numFields));
    }

    /**
     * Drops the W entry of {@code x} from the key table. Only the mapping is removed; the
     * buffer region it pointed to is not released here.
     */
    protected void removeEntry(@Nonnull final Feature x) {
        _map.remove(x.getFeatureIndex());
    }

    /** Allocates a new entry and initializes both its weight and its V vector. */
    @Nonnull
    protected final Entry newEntry(final float W, @Nonnull final float[] V) {
        final Entry entry = newEntry();
        entry.setW(W);
        entry.setV(V);
        return entry;
    }

    /** Allocates a new entry and initializes its V vector. */
    @Nonnull
    protected final Entry newEntry(@Nonnull final float[] V) {
        final Entry entry = newEntry();
        entry.setV(V);
        return entry;
    }

    /** Allocates {@link #_entrySize} bytes in the buffer and wraps them in the matching view. */
    @Nonnull
    private Entry newEntry() {
        return wrapEntry(_buf.allocate(_entrySize));
    }

    /**
     * Looks up {@code key}; if absent, allocates an entry with a freshly initialized V and
     * registers it under {@code key}.
     */
    @Nonnull
    private Entry getOrCreateEntry(final int key) {
        Entry entry = findEntry(key);
        if (entry == null) {
            final float[] V = initV();
            entry = newEntry(V);
            _map.put(key, entry.getOffset());
        }
        return entry;
    }

    /**
     * @return the entry registered under {@code key}, or {@code null} when the table reports -1
     */
    @Nullable
    private Entry findEntry(final int key) {
        final long ptr = _map.get(key);
        if (ptr == -1L) {
            return null;
        }
        return wrapEntry(ptr);
    }

    /**
     * Wraps the buffer region at {@code ptr} in the Entry subtype selected by the optimizer
     * flags (FTRL takes precedence over AdaGrad).
     */
    @Nonnull
    private Entry wrapEntry(final long ptr) {
        if (_useFTRL) {
            return new FTRLEntry(_buf, _factor, ptr);
        }
        if (_useAdaGrad) {
            return new AdaGradEntry(_buf, _factor, ptr);
        }
        return new Entry(_buf, _factor, ptr);
    }

    /** @return the per-entry size for the given factor count and optimizer choice */
    private static int entrySize(int factors, boolean ftrl, boolean adagrad) {
        if (ftrl) {
            return FTRLEntry.sizeOf(factors);
        }
        if (adagrad) {
            return AdaGradEntry.sizeOf(factors);
        }
        return Entry.sizeOf(factors);
    }
}