blob: 5eb3aa15874db4673e6e4594238211fd8ca9a742 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.factorization.fm;
import hivemall.optimizer.EtaEstimator;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.math.MathUtils;
import java.util.Arrays;
import java.util.Objects;
import java.util.Random;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.metadata.HiveException;
public abstract class FactorizationMachineModel {

    // True for binary classification (log-sigmoid loss), false for regression; see dloss().
    protected final boolean _classification;
    // Number of latent factors k of the factorization.
    protected final int _factor;
    // Init parameter taken from FMHyperParameters#sigma; its use is not visible in this chunk.
    protected final double _sigma;
    // Learning-rate (eta) estimator supplied via the hyperparameters.
    @Nonnull
    protected final EtaEstimator _eta;
    // Scheme used by initV() to initialize the factor vectors V.
    @Nonnull
    protected final VInitScheme _initScheme;
    // RNG seeded from FMHyperParameters#seed (for subclass use; not read in this chunk).
    @Nonnull
    protected final Random _rnd;
    // Hyperparameter for regression: predictions are clamped into
    // [_min_target, _max_target] in dloss() before computing the gradient.
    protected final double _min_target;
    protected final double _max_target;
    // Regularization coefficients. Mutable (non-final) because the
    // updateLambda*() methods self-tune them during training.
    protected float _lambdaW0;
    protected float _lambdaW;
    @Nonnull
    protected final float[] _lambdaV; // one coefficient per factor f
/**
 * Builds the model state from the given hyperparameters.
 *
 * @param params FM hyperparameters; its {@code eta} and {@code vInit} must be non-null
 */
public FactorizationMachineModel(@Nonnull FMHyperParameters params) {
    this._classification = params.classification;
    this._factor = params.factors;
    this._sigma = params.sigma;
    this._eta = Objects.requireNonNull(params.eta);
    this._initScheme = Objects.requireNonNull(params.vInit);
    this._rnd = new Random(params.seed);
    this._min_target = params.minTarget;
    this._max_target = params.maxTarget;
    // Regulation Variables
    this._lambdaW0 = params.lambdaW0;
    this._lambdaW = params.lambdaW;
    // One lambda per factor, all starting from the same initial value.
    this._lambdaV = new float[params.factors];
    Arrays.fill(_lambdaV, params.lambdaV);
}
/** Returns the model size; exact semantics are defined by subclasses. */
public abstract int getSize();

// Optional operation (presumably the minimum feature index held by an
// index-addressable subclass — confirm in overriders); throws by default.
protected int getMinIndex() {
    throw new UnsupportedOperationException();
}

// Optional operation (presumably the maximum feature index); throws by default.
protected int getMaxIndex() {
    throw new UnsupportedOperationException();
}

/** Returns the global bias term w0 (base term of predict()). */
public abstract float getW0();

protected abstract void setW0(float nextW0);

/**
 * Optional operation: linear weight accessed by integer index; throws by default.
 *
 * @param i index value >= 1
 */
protected float getW(int i) {
    throw new UnsupportedOperationException();
}

/** Returns the linear weight for feature {@code x}. */
public abstract float getW(@Nonnull Feature x);

protected abstract void setW(@Nonnull Feature x, float nextWi);

/**
 * Optional operation: the factor vector V[i] accessed by integer index; throws by default.
 *
 * @param i index value >= 1
 * @param init presumably whether to lazily initialize a missing row — confirm in overriders
 */
@Nullable
protected float[] getV(int i, boolean init) {
    throw new UnsupportedOperationException();
}

/** Returns the factor weight V[x][f]. */
public abstract float getV(@Nonnull Feature x, int f);

protected abstract void setV(@Nonnull Feature x, int f, float nextVif);

/**
 * Returns the regularization coefficient for factor {@code f}.
 *
 * @param f index value >= 0
 */
float getLambdaV(int f) {
    return _lambdaV[f];
}
/**
 * Convenience overload: predicts on {@code x} and returns the loss gradient
 * against the target {@code y}.
 *
 * @throws HiveException if prediction detects a non-finite value
 */
final double dloss(@Nonnull final Feature[] x, final double y) throws HiveException {
    return dloss(predict(x), y);
}
/**
 * Gradient of the loss w.r.t. the model output.
 *
 * @param p raw model prediction
 * @param y target label (+1/-1 for classification, real value for regression)
 */
final double dloss(double p, final double y) {
    if (_classification) {
        // Gradient of the log-sigmoid loss: (sigmoid(p*y) - 1) * y
        return (MathUtils.sigmoid(p * y) - 1.d) * y;
    }
    // Regression: clamp the prediction into [_min_target, _max_target],
    // then take the gradient of the (halved) squared loss.
    final double clamped = Math.max(Math.min(p, _max_target), _min_target);
    return clamped - y;
}
/**
 * Computes the FM prediction
 * y(x) = w0 + sum_i w_i x_i + sum_f 0.5 * ((sum_i v_if x_i)^2 - sum_i (v_if x_i)^2).
 *
 * @throws HiveException if the accumulated value becomes non-finite
 */
protected double predict(@Nonnull final Feature[] x) throws HiveException {
    // Global bias term.
    double y = getW0();
    // Linear term.
    for (Feature xi : x) {
        double v = xi.getValue();
        y += getW(xi) * v;
    }
    // Pairwise interaction term, using the O(k*n) reformulation.
    for (int f = 0, k = _factor; f < k; f++) {
        double sum = 0.d;
        double sumSq = 0.d;
        for (Feature xi : x) {
            double v = xi.getValue();
            double vx = getV(xi, f) * v;
            sum += vx;
            sumSq += (vx * vx);
        }
        y += 0.5d * (sum * sum - sumSq);
        assert (!Double.isNaN(y));
    }
    if (!NumberUtils.isFinite(y)) {
        throw new HiveException(
            "Detected " + y + " in predict. We recommend to normalize training examples.\n"
                    + "Dumping variables ...\n" + varDump(x));
    }
    return y;
}
/**
 * Renders the model parameters touched by the given example (x, W0, W, V)
 * into a human-readable string for diagnostics.
 */
protected String varDump(@Nonnull final Feature[] x) {
    final StringBuilder buf = new StringBuilder(1024);
    // Input values.
    String sep = "";
    for (Feature e : x) {
        double xj = e.getValue();
        buf.append(sep).append("x[").append(e.getFeature()).append("] = ").append(xj);
        sep = ", ";
    }
    buf.append("\n");
    // Bias and linear weights.
    buf.append("W0 = ").append(getW0()).append('\n');
    sep = "";
    for (Feature e : x) {
        float wi = getW(e);
        buf.append(sep).append("W[").append(e.getFeature()).append("] = ").append(wi);
        sep = ", ";
    }
    buf.append("\n");
    // Factor weights, one line per factor.
    for (int f = 0, k = _factor; f < k; f++) {
        sep = "";
        for (Feature e : x) {
            float vjf = getV(e, f);
            buf.append(sep).append('V').append(f).append('[').append(e.getFeature())
               .append("] = ").append(vjf);
            sep = ", ";
        }
        buf.append('\n');
    }
    return buf.toString();
}
/**
 * SGD step on the global bias w0 with L2 penalty (2 * lambdaW0 * w0).
 *
 * @throws IllegalStateException if the updated bias is not finite
 */
final void updateW0(final double dloss, final float eta) {
    final float grad = (float) dloss;
    final float oldW0 = getW0();
    final float newW0 = oldW0 - eta * (grad + 2.f * _lambdaW0 * oldW0);
    if (!NumberUtils.isFinite(newW0)) {
        throw new IllegalStateException("Got " + newW0 + " for next W0\n" + "gradW0=" + grad
                + ", prevW0=" + oldW0 + ", dloss=" + dloss + ", eta=" + eta);
    }
    setW0(newW0);
}
/**
 * Performs one SGD step on the linear weight W[i] for feature {@code x},
 * with L2 penalty (2 * lambdaW * w_i).
 *
 * <p>Note: the previous Javadoc claimed {@code @return whether to update V or not},
 * but this method is {@code void}; that stale tag has been removed.</p>
 *
 * @param dloss gradient of the loss w.r.t. the model output
 * @param x the feature whose weight is updated
 * @param eta learning rate for this step
 * @throws IllegalStateException if the updated weight is not finite
 */
void updateWi(final double dloss, @Nonnull final Feature x, final float eta) {
    final double Xi = x.getValue();
    float gradWi = (float) (dloss * Xi);
    float wi = getW(x);
    float nextWi = wi - eta * (gradWi + 2.f * _lambdaW * wi);
    if (!NumberUtils.isFinite(nextWi)) {
        throw new IllegalStateException(
            "Got " + nextWi + " for next W[" + x.getFeature() + "]\n" + "Xi=" + Xi + ", gradWi="
                    + gradWi + ", wi=" + wi + ", dloss=" + dloss + ", eta=" + eta);
    }
    setW(x, nextWi);
}
/**
 * Performs one SGD step on the factor weight V[x][f], with L2 penalty
 * (2 * lambdaV[f] * v_if).
 *
 * @param sumViX precomputed sum_j v_jf * x_j for factor f
 * @throws IllegalStateException if the updated weight is not finite
 */
final void updateV(final double dloss, @Nonnull final Feature x, final int f,
        final double sumViX, final float eta) {
    final double xi = x.getValue();
    final float vif = getV(x, f);
    // Partial derivative of the model output w.r.t. v_if.
    final double h = gradV(xi, vif, sumViX);
    final float grad = (float) (dloss * h);
    final float lambdaVf = getLambdaV(f);
    final float nextVif = vif - eta * (grad + 2.f * lambdaVf * vif);
    if (!NumberUtils.isFinite(nextVif)) {
        throw new IllegalStateException(
            "Got " + nextVif + " for next V" + f + '[' + x.getFeature() + "]\n" + "Xi=" + xi
                    + ", Vif=" + vif + ", h=" + h + ", gradV=" + grad + ", lambdaVf="
                    + lambdaVf + ", dloss=" + dloss + ", sumViX=" + sumViX + ", eta=" + eta);
    }
    setV(x, f, nextVif);
}
/** Self-tunes the bias regularizer lambdaW0 (SGDA-style), keeping it non-negative. */
final void updateLambdaW0(final double dloss, final float eta) {
    final float grad = -2.f * eta * getW0();
    final float next = _lambdaW0 - (float) (eta * dloss * grad);
    // Project back onto the non-negative orthant.
    this._lambdaW0 = Math.max(0.f, next);
}
/** Self-tunes the linear-weight regularizer lambdaW (SGDA-style), keeping it non-negative. */
final void updateLambdaW(@Nonnull Feature[] x, double dloss, float eta) {
    // Dot product of the linear weights with the example: sum_i w_i * x_i.
    double dot = 0.d;
    for (Feature e : x) {
        assert (e != null) : Arrays.toString(x);
        double xi = e.getValue();
        dot += getW(e) * xi;
    }
    final double grad = -2.f * eta * dot;
    final float next = _lambdaW - (float) (eta * dloss * grad);
    this._lambdaW = Math.max(0.f, next);
}
/**
 * Self-tunes the per-factor regularization coefficients lambda_v (SGDA-style),
 * keeping each coefficient non-negative.
 *
 * <pre>
 * grad_lambdafg := (grad l(y(x),y)) * -2 * alpha * ((\sum_{j} x_j * v'_jf) * (\sum_{j \in group(g)} x_j * v_jf) - \sum_{j \in group(g)} x^2_j * v_jf * v'_jf)
 * := (grad l(y(x),y)) * -2 * alpha * (sum_f_dash * sum_f(g) - sum_f_dash_f(g))
 * sum_f_dash := \sum_{j} x_j * v'_lj, this is independent of the groups
 * sum_f(g) := \sum_{j \in group(g)} x_j * v_jf
 * sum_f_dash_f(g) := \sum_{j \in group(g)} x^2_j * v_jf * v'_jf
 * := \sum_{j \in group(g)} x_j * v'_jf * x_j * v_jf
 * v_jf' := v_jf - alpha ( grad_v_jf + 2 * lambda_v_f * v_jf)
 * </pre>
 */
final void updateLambdaV(@Nonnull final Feature[] x, final double dloss, final float eta) {
    for (int f = 0, k = _factor; f < k; f++) {
        double sum_f_dash = 0.d, sum_f = 0.d, sum_f_dash_f = 0.d;
        float lambdaVf = getLambdaV(f);
        final double sumVfX = sumVfX(x, f);
        for (Feature e : x) {
            assert (e != null) : Arrays.toString(x);
            double x_j = e.getValue();
            float v_jf = getV(e, f);
            // v_dash is v'_jf: the factor weight after a hypothetical SGD step
            // taken with the current lambdaVf.
            double gradV = gradV(x_j, v_jf, sumVfX);
            double v_dash = v_jf - eta * (gradV + 2.d * lambdaVf * v_jf);
            sum_f_dash += x_j * v_dash;
            sum_f += x_j * v_jf;
            sum_f_dash_f += x_j * v_dash * x_j * v_jf;
        }
        double lambda_v_grad = -2.f * eta * (sum_f_dash * sum_f - sum_f_dash_f);
        lambdaVf -= eta * dloss * lambda_v_grad;
        // Regularizers must stay non-negative.
        _lambdaV[f] = Math.max(0.f, lambdaVf);
    }
}
/** Computes sum_j v_jf * x_j for every factor f; ret[f] is the sum for factor f. */
double[] sumVfX(@Nonnull final Feature[] x) {
    final double[] sums = new double[_factor];
    for (int f = 0; f < sums.length; f++) {
        sums[f] = sumVfX(x, f);
    }
    return sums;
}
/**
 * Computes sum_j v_jf * x_j for a single factor {@code f}.
 *
 * @throws IllegalStateException if the accumulated sum is not finite
 */
private double sumVfX(@Nonnull final Feature[] x, final int f) {
    double sum = 0.d;
    for (Feature e : x) {
        final double xj = e.getValue();
        sum += getV(e, f) * xj;
    }
    if (!NumberUtils.isFinite(sum)) {
        throw new IllegalStateException(
            "Got " + sum + " for sumV[ " + f + "]X.\n" + "x = " + Arrays.toString(x));
    }
    return sum;
}
/**
 * Partial derivative of the model output w.r.t. a single factor weight v_if
 * (without the loss-gradient multiplier).
 *
 * <pre>
 * grad_v_if := multi * (x_i * (sum_f - v_if * x_i))
 * sum_f := \sum_j v_jf * x_j
 * </pre>
 *
 * Note: the previous {@code @Nonnull} annotation on the primitive {@code double}
 * parameter was meaningless (nullness annotations apply to reference types) and
 * has been removed.
 */
private double gradV(final double Xj, final float Vjf, final double sumVfX) {
    return Xj * (sumVfX - Vjf * Xj);
}
// Hook for validating an input example before use; no-op by default
// (subclasses may throw HiveException for invalid input).
public void check(@Nonnull Feature[] x) throws HiveException {}
/** Strategies for initializing the factor matrix V. */
public enum VInitScheme {
    adjustedRandom /* default */, libffmRandom, random, gaussian;

    // Upper bound used to scale uniform random draws.
    @Nonnegative
    float maxInitValue;
    // Standard deviation used by the gaussian scheme.
    @Nonnegative
    double initStdDev;
    // RNG(s) prepared by initRandom(); gaussian keeps one per factor.
    Random[] rand;

    /** Resolves the option string, defaulting to {@link #adjustedRandom}. */
    @Nonnull
    public static VInitScheme resolve(@Nullable String opt) {
        return resolve(opt, adjustedRandom);
    }

    /** Resolves the option string, falling back to {@code defaultScheme} when unrecognized. */
    @Nonnull
    public static VInitScheme resolve(@Nullable String opt,
            @Nonnull VInitScheme defaultScheme) {
        if (opt == null) {
            return defaultScheme;
        }
        if ("adjusted_random".equalsIgnoreCase(opt) || "adjustedRandom".equalsIgnoreCase(opt)) {
            return adjustedRandom;
        }
        if ("libffm_random".equalsIgnoreCase(opt) || "libffmRandom".equalsIgnoreCase(opt)
                || "libffm".equalsIgnoreCase(opt)) {
            return libffmRandom;
        }
        if ("random".equalsIgnoreCase(opt)) {
            return random;
        }
        if ("gaussian".equalsIgnoreCase(opt)) {
            return gaussian;
        }
        return defaultScheme;
    }

    public void setMaxInitValue(float maxInitValue) {
        this.maxInitValue = maxInitValue;
    }

    public void setInitStdDev(double initStdDev) {
        this.initStdDev = initStdDev;
    }

    /** Seeds the RNGs: one per factor for gaussian, a single shared RNG otherwise. */
    public void initRandom(int factor, long seed) {
        final int size = (this == gaussian) ? factor : 1;
        this.rand = new Random[size];
        for (int i = 0; i < size; i++) {
            rand[i] = new Random(seed + i);
        }
    }
}
/**
 * Allocates and initializes one factor vector of length {@code _factor}
 * according to the configured {@link VInitScheme}.
 */
@Nonnull
protected final float[] initV() {
    final float[] v = new float[_factor];
    final VInitScheme scheme = _initScheme;
    switch (scheme) {
        case adjustedRandom:
            adjustedRandomFill(v, scheme.rand[0], scheme.maxInitValue);
            break;
        case libffmRandom:
            libffmRandomFill(v, scheme.rand[0], scheme.maxInitValue);
            break;
        case random:
            randomFill(v, scheme.rand[0], scheme.maxInitValue);
            break;
        case gaussian:
            gaussianFill(v, scheme.rand, scheme.initStdDev);
            break;
        default:
            throw new IllegalStateException(
                "Unsupported V initialization scheme: " + scheme);
    }
    return v;
}
/** Fills {@code a} with uniform draws scaled by maxInitValue / k (k = a.length). */
protected static final void adjustedRandomFill(@Nonnull final float[] a,
        @Nonnull final Random rand, final float maxInitValue) {
    final float scale = maxInitValue / a.length;
    for (int i = 0; i < a.length; i++) {
        a[i] = rand.nextFloat() * scale;
    }
}
// libffm's V initialization scheme: 1/sqrt(k)
// https://github.com/guestwalk/libffm/blob/master/ffm.cpp#L287
protected static final void libffmRandomFill(@Nonnull final float[] a,
        @Nonnull final Random rand, final float maxInitValue) {
    // Scale each uniform draw by maxInitValue / sqrt(k).
    final float scale = maxInitValue / (float) Math.sqrt(a.length);
    for (int i = 0; i < a.length; i++) {
        a[i] = rand.nextFloat() * scale;
    }
}
/** Fills {@code a} with uniform draws in [0, maxInitValue). */
protected static final void randomFill(@Nonnull final float[] a, @Nonnull final Random rand,
        final float maxInitValue) {
    for (int i = 0; i < a.length; i++) {
        a[i] = rand.nextFloat() * maxInitValue;
    }
}
// libfm uses gaussian for initialization
// https://github.com/srendle/libfm/blob/30b9c799c41d043f31565cbf827bf41d0dc3e2ab/src/fm_core/fm_model.h#L96
protected static final void gaussianFill(@Nonnull final float[] a, @Nonnull final Random[] rand,
        final double stddev) {
    // One RNG per element (see VInitScheme#initRandom for the gaussian case).
    for (int i = 0; i < a.length; i++) {
        a[i] = (float) MathUtils.gaussian(0.d, stddev, rand[i]);
    }
}
}