| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| package hivemall.factorization.fm; |
| |
| import hivemall.optimizer.EtaEstimator; |
| import hivemall.utils.lang.NumberUtils; |
| import hivemall.utils.math.MathUtils; |
| |
| import java.util.Arrays; |
| import java.util.Objects; |
| import java.util.Random; |
| |
| import javax.annotation.Nonnegative; |
| import javax.annotation.Nonnull; |
| import javax.annotation.Nullable; |
| |
| import org.apache.hadoop.hive.ql.metadata.HiveException; |
| |
/**
 * Abstract base class for Factorization Machine (FM) models.
 * <p>
 * Holds the hyperparameters shared by concrete FM implementations (number of latent
 * factors, learning-rate estimator, regularization coefficients, V-initialization
 * scheme) and implements:
 * <ul>
 *   <li>the FM model equation in its O(k*n) form ({@link #predict(Feature[])}),</li>
 *   <li>the derivative of the loss w.r.t. the model output ({@link #dloss(double, double)}),</li>
 *   <li>plain SGD updates for w0, W and V ({@link #updateW0}, {@link #updateWi}, {@link #updateV}),</li>
 *   <li>self-adaptive updates of the lambda regularizers
 *       ({@link #updateLambdaW0}, {@link #updateLambdaW}, {@link #updateLambdaV}).</li>
 * </ul>
 * Storage and retrieval of the actual model parameters (w0, W, V) is delegated to
 * subclasses via the abstract getters/setters. This class is NOT thread-safe: the
 * lambda fields are mutated without synchronization.
 */
public abstract class FactorizationMachineModel {

    /** true for binary classification, false for regression */
    protected final boolean _classification;
    /** number of latent factors k (length of each V row) */
    protected final int _factor;
    /** stddev-like init hyperparameter; NOTE(review): never read in this class — presumably used by subclasses, confirm */
    protected final double _sigma;
    /** learning-rate schedule used by callers to derive eta; opaque to this class */
    @Nonnull
    protected final EtaEstimator _eta;
    /** scheme used by {@link #initV()} to draw fresh factor vectors */
    @Nonnull
    protected final VInitScheme _initScheme;
    /** RNG seeded with params.seed; NOTE(review): never read in this class — presumably used by subclasses, confirm */
    @Nonnull
    protected final Random _rnd;

    // Hyperparameters for regression: predictions are clipped into [_min_target, _max_target]
    // before computing the loss derivative (see dloss).
    protected final double _min_target;
    protected final double _max_target;

    // Regularization coefficients. Mutable (non-final for W0/W) because the
    // updateLambda* methods adapt them online; clamped to >= 0 after each update.
    protected float _lambdaW0;
    protected float _lambdaW;
    /** per-factor regularizer for V; one entry per latent factor f */
    @Nonnull
    protected final float[] _lambdaV;

    /**
     * Copies all hyperparameters out of {@code params} and initializes the per-factor
     * lambda array with a single shared starting value.
     *
     * @param params hyperparameter bundle; {@code params.eta} and {@code params.vInit}
     *               must be non-null
     */
    public FactorizationMachineModel(@Nonnull FMHyperParameters params) {
        this._classification = params.classification;
        this._factor = params.factors;
        this._sigma = params.sigma;
        this._eta = Objects.requireNonNull(params.eta);
        this._initScheme = Objects.requireNonNull(params.vInit);
        this._rnd = new Random(params.seed);

        this._min_target = params.minTarget;
        this._max_target = params.maxTarget;

        // Regularization variables: every factor starts from the same lambdaV
        this._lambdaW0 = params.lambdaW0;
        this._lambdaW = params.lambdaW;
        this._lambdaV = new float[params.factors];
        Arrays.fill(_lambdaV, params.lambdaV);
    }

    /** @return the number of model entries (semantics defined by the subclass) */
    public abstract int getSize();

    /** Smallest feature index held by the model; optional operation. */
    protected int getMinIndex() {
        throw new UnsupportedOperationException();
    }

    /** Largest feature index held by the model; optional operation. */
    protected int getMaxIndex() {
        throw new UnsupportedOperationException();
    }

    /** @return the global bias term w0 */
    public abstract float getW0();

    protected abstract void setW0(float nextW0);

    /**
     * Index-based access to the linear weights W; optional operation.
     *
     * @param i index value >= 1
     */
    protected float getW(int i) {
        throw new UnsupportedOperationException();
    }

    /** @return the linear weight w_j for the given feature */
    public abstract float getW(@Nonnull Feature x);

    protected abstract void setW(@Nonnull Feature x, float nextWi);

    /**
     * Index-based access to a whole V row; optional operation.
     *
     * @param i index value >= 1
     * @param init whether to lazily initialize a missing row (subclass-defined)
     * @return the k-dimensional factor vector, or null (subclass-defined)
     */
    @Nullable
    protected float[] getV(int i, boolean init) {
        throw new UnsupportedOperationException();
    }

    /** @return the latent factor v_{j,f} for the given feature and factor index */
    public abstract float getV(@Nonnull Feature x, int f);

    protected abstract void setV(@Nonnull Feature x, int f, float nextVif);

    /**
     * @param f index value >= 0 (factor index, must be < k)
     * @return the current regularizer for factor f
     */
    float getLambdaV(int f) {
        return _lambdaV[f];
    }

    /**
     * Convenience overload: predicts for x, then returns the loss derivative.
     *
     * @throws HiveException if the prediction is not finite
     */
    final double dloss(@Nonnull final Feature[] x, final double y) throws HiveException {
        double p = predict(x);
        return dloss(p, y);
    }

    /**
     * Derivative of the loss w.r.t. the model output p.
     * <p>
     * Classification: derivative of the log loss, (sigmoid(p*y) - 1) * y.
     * NOTE(review): this form assumes labels y are in {-1, +1} — confirm at call sites.
     * <p>
     * Regression: p is first clipped into [_min_target, _max_target], then the
     * squared-loss derivative is returned. The exact derivative of (p-y)^2 would be
     * 2*(p-y); the factor 2 is dropped here (effectively folded into the learning rate).
     */
    final double dloss(double p, final double y) {
        final double ret;
        if (_classification) {
            ret = (MathUtils.sigmoid(p * y) - 1.d) * y;
        } else { // regression: clip the prediction into the target range first
            p = Math.min(p, _max_target);
            p = Math.max(p, _min_target);
            // exact derivative is 2*(p - y); the constant is folded into eta
            ret = p - y;
        }
        return ret;
    }

    /**
     * Computes the FM model equation in its O(k*n) formulation:
     * <pre>
     *   y(x) = w0 + \sum_j w_j x_j
     *        + 0.5 * \sum_f [ (\sum_j v_jf x_j)^2 - \sum_j (v_jf x_j)^2 ]
     * </pre>
     *
     * @param x the (sparse) feature vector of the example
     * @return the raw model output (not squashed; callers apply sigmoid/clipping)
     * @throws HiveException if the result is NaN/Infinity, with a full variable dump
     *                       to help diagnose unnormalized inputs
     */
    protected double predict(@Nonnull final Feature[] x) throws HiveException {
        // w0
        double ret = getW0();

        // W
        for (Feature e : x) {
            double xj = e.getValue();
            float w = getW(e);
            double wx = w * xj;
            ret += wx;
        }

        // V: pairwise interactions, computed per factor in linear time
        for (int f = 0, k = _factor; f < k; f++) {
            double sumVjfXj = 0.d; // \sum_j v_jf * x_j
            double sumV2X2 = 0.d; // \sum_j (v_jf * x_j)^2

            for (Feature e : x) {
                double xj = e.getValue();
                float vjf = getV(e, f);
                double vx = vjf * xj;
                sumVjfXj += vx;
                sumV2X2 += (vx * vx);
            }
            ret += 0.5d * (sumVjfXj * sumVjfXj - sumV2X2);
            assert (!Double.isNaN(ret));
        }
        if (!NumberUtils.isFinite(ret)) {
            throw new HiveException(
                "Detected " + ret + " in predict. We recommend to normalize training examples.\n"
                        + "Dumping variables ...\n" + varDump(x));
        }
        return ret;
    }

    /**
     * Renders a human-readable dump of x, w0, W and V for the given example.
     * Used in error messages when predict produces a non-finite value.
     */
    protected String varDump(@Nonnull final Feature[] x) {
        final StringBuilder buf = new StringBuilder(1024);
        // input values: x[j] = value
        for (int i = 0; i < x.length; i++) {
            Feature e = x[i];
            String j = e.getFeature();
            double xj = e.getValue();
            if (i != 0) {
                buf.append(", ");
            }
            buf.append("x[").append(j).append("] = ").append(xj);
        }
        buf.append("\n");
        // bias and linear weights
        buf.append("W0 = ").append(getW0()).append('\n');
        for (int i = 0; i < x.length; i++) {
            Feature e = x[i];
            String j = e.getFeature();
            float wi = getW(e);
            if (i != 0) {
                buf.append(", ");
            }
            buf.append("W[").append(j).append("] = ").append(wi);
        }
        buf.append("\n");
        // latent factors, one line per factor index f
        for (int f = 0, k = _factor; f < k; f++) {
            for (int i = 0; i < x.length; i++) {
                Feature e = x[i];
                String j = e.getFeature();
                float vjf = getV(e, f);
                if (i != 0) {
                    buf.append(", ");
                }
                buf.append('V').append(f).append('[').append(j).append("] = ").append(vjf);
            }
            buf.append('\n');
        }
        return buf.toString();
    }

    /**
     * SGD update of the bias: w0 := w0 - eta * (dloss + 2 * lambdaW0 * w0).
     *
     * @throws IllegalStateException if the updated value is not finite
     */
    final void updateW0(final double dloss, final float eta) {
        float gradW0 = (float) dloss; // d y / d w0 = 1, so grad = dloss
        float prevW0 = getW0();
        float nextW0 = prevW0 - eta * (gradW0 + 2.f * _lambdaW0 * prevW0);
        if (!NumberUtils.isFinite(nextW0)) {
            throw new IllegalStateException("Got " + nextW0 + " for next W0\n" + "gradW0=" + gradW0
                    + ", prevW0=" + prevW0 + ", dloss=" + dloss + ", eta=" + eta);
        }
        setW0(nextW0);
    }

    /**
     * SGD update of a linear weight: w_i := w_i - eta * (dloss * x_i + 2 * lambdaW * w_i).
     *
     * @throws IllegalStateException if the updated value is not finite
     */
    void updateWi(final double dloss, @Nonnull final Feature x, final float eta) {
        final double Xi = x.getValue();
        float gradWi = (float) (dloss * Xi); // d y / d w_i = x_i
        float wi = getW(x);
        float nextWi = wi - eta * (gradWi + 2.f * _lambdaW * wi);
        if (!NumberUtils.isFinite(nextWi)) {
            throw new IllegalStateException(
                "Got " + nextWi + " for next W[" + x.getFeature() + "]\n" + "Xi=" + Xi + ", gradWi="
                        + gradWi + ", wi=" + wi + ", dloss=" + dloss + ", eta=" + eta);
        }
        setW(x, nextWi);
    }

    /**
     * SGD update of one latent factor v_if using the precomputed \sum_j v_jf x_j:
     * v_if := v_if - eta * (dloss * x_i * (sumViX - v_if * x_i) + 2 * lambdaV_f * v_if).
     *
     * @param sumViX precomputed \sum_j v_jf * x_j for factor f (see {@link #sumVfX})
     * @throws IllegalStateException if the updated value is not finite
     */
    final void updateV(final double dloss, @Nonnull final Feature x, final int f,
            final double sumViX, final float eta) {
        final double Xi = x.getValue();
        float Vif = getV(x, f);
        double h = gradV(Xi, Vif, sumViX); // d y / d v_if
        float gradV = (float) (dloss * h);
        float LambdaVf = getLambdaV(f);
        float nextVif = Vif - eta * (gradV + 2.f * LambdaVf * Vif);
        if (!NumberUtils.isFinite(nextVif)) {
            throw new IllegalStateException(
                "Got " + nextVif + " for next V" + f + '[' + x.getFeature() + "]\n" + "Xi=" + Xi
                        + ", Vif=" + Vif + ", h=" + h + ", gradV=" + gradV + ", lambdaVf="
                        + LambdaVf + ", dloss=" + dloss + ", sumViX=" + sumViX + ", eta=" + eta);
        }
        setV(x, f, nextVif);
    }

    /**
     * Self-adaptive update of lambdaW0 (SGDA-style, as in libFM's adaptive
     * regularization). The result is clamped to be non-negative.
     */
    final void updateLambdaW0(final double dloss, final float eta) {
        float lambda_w_grad = -2.f * eta * getW0();
        float lambdaW0 = _lambdaW0 - (float) (eta * dloss * lambda_w_grad);
        this._lambdaW0 = Math.max(0.f, lambdaW0); // keep the regularizer non-negative
    }

    /**
     * Self-adaptive update of lambdaW based on \sum_j w_j x_j over the example.
     * The result is clamped to be non-negative.
     */
    final void updateLambdaW(@Nonnull Feature[] x, double dloss, float eta) {
        double sumWX = 0.d;
        for (Feature e : x) {
            assert (e != null) : Arrays.toString(x);
            double xi = e.getValue();
            sumWX += getW(e) * xi;
        }
        double lambda_w_grad = -2.f * eta * sumWX;
        float lambdaW = _lambdaW - (float) (eta * dloss * lambda_w_grad);
        this._lambdaW = Math.max(0.f, lambdaW); // keep the regularizer non-negative
    }

    /**
     * Self-adaptive update of the per-factor lambdaV regularizers. Each factor's
     * lambda follows the gradient of the loss w.r.t. the regularizer, evaluated via
     * the one-step-lookahead weights v':
     *
     * <pre>
     * grad_lambdafg := (grad l(y(x),y)) * -2 * alpha * ((\sum_{j} x_j * v'_jf) * (\sum_{j \in group(g)} x_j * v_jf) - \sum_{j \in group(g)} x^2_j * v_jf * v'_jf)
     *               := (grad l(y(x),y)) * -2 * alpha * (sum_f_dash * sum_f(g) - sum_f_dash_f(g))
     *     sum_f_dash := \sum_{j} x_j * v'_lj, this is independent of the groups
     *     sum_f(g) := \sum_{j \in group(g)} x_j * v_jf
     *     sum_f_dash_f(g) := \sum_{j \in group(g)} x^2_j * v_jf * v'_jf
     *                     := \sum_{j \in group(g)} x_j * v'_jf * x_j * v_jf
     *     v_jf' := v_jf - alpha ( grad_v_jf + 2 * lambda_v_f * v_jf)
     * </pre>
     *
     * Each updated lambda is clamped to be non-negative.
     */
    final void updateLambdaV(@Nonnull final Feature[] x, final double dloss, final float eta) {
        for (int f = 0, k = _factor; f < k; f++) {
            double sum_f_dash = 0.d, sum_f = 0.d, sum_f_dash_f = 0.d;
            float lambdaVf = getLambdaV(f);

            final double sumVfX = sumVfX(x, f);
            for (Feature e : x) {
                assert (e != null) : Arrays.toString(x);
                double x_j = e.getValue();

                float v_jf = getV(e, f);
                // one-step lookahead: what v_jf would become after its own SGD update
                double v_dash = v_jf - eta * (gradV(x_j, v_jf, sumVfX) + 2.d * lambdaVf * v_jf);

                sum_f_dash += x_j * v_dash;
                sum_f += x_j * v_jf;
                sum_f_dash_f += x_j * v_dash * x_j * v_jf;
            }

            double lambda_v_grad = -2.f * eta * (sum_f_dash * sum_f - sum_f_dash_f);
            lambdaVf -= eta * dloss * lambda_v_grad;
            _lambdaV[f] = Math.max(0.f, lambdaVf); // keep the regularizer non-negative
        }
    }

    /**
     * Precomputes \sum_j v_jf * x_j for every factor f of the example.
     *
     * @return an array of length k with one inner product per factor
     */
    double[] sumVfX(@Nonnull final Feature[] x) {
        final int k = _factor;
        final double[] ret = new double[k];
        for (int f = 0; f < k; f++) {
            ret[f] = sumVfX(x, f);
        }
        return ret;
    }

    /**
     * Computes \sum_j v_jf * x_j for a single factor f.
     *
     * @throws IllegalStateException if the sum is not finite
     */
    private double sumVfX(@Nonnull final Feature[] x, final int f) {
        double ret = 0.d;
        for (Feature e : x) {
            double xj = e.getValue();
            float Vjf = getV(e, f);
            ret += Vjf * xj;
        }
        if (!NumberUtils.isFinite(ret)) {
            throw new IllegalStateException(
                "Got " + ret + " for sumV[ " + f + "]X.\n" + "x = " + Arrays.toString(x));
        }
        return ret;
    }

    /**
     * Partial derivative of the model output w.r.t. v_jf (loss multiplier applied by
     * the caller):
     *
     * <pre>
     * grad_v_if := multi * (x_i * (sum_f - v_if * x_i))
     *     sum_f := \sum_j v_jf * x_j
     * </pre>
     */
    private double gradV(@Nonnull final double Xj, final float Vjf, final double sumVfX) {
        return Xj * (sumVfX - Vjf * Xj);
    }

    /** Validation hook for subclasses; the default implementation accepts any input. */
    public void check(@Nonnull Feature[] x) throws HiveException {}

    /**
     * Initialization schemes for the latent-factor matrix V, with the scheme-specific
     * parameters (max init value, stddev, RNGs) attached to each enum constant.
     * Note: {@link #initRandom} must be called before the scheme is used by
     * {@link FactorizationMachineModel#initV()}.
     */
    public enum VInitScheme {
        adjustedRandom /* default */, libffmRandom, random, gaussian;

        @Nonnegative
        float maxInitValue;
        @Nonnegative
        double initStdDev;
        // gaussian keeps one RNG per factor; the other schemes share a single RNG (see initRandom)
        Random[] rand;

        /** Resolves the option string, falling back to {@link #adjustedRandom}. */
        @Nonnull
        public static VInitScheme resolve(@Nullable String opt) {
            return resolve(opt, adjustedRandom);
        }

        /**
         * Resolves an option string (case-insensitive, several aliases per scheme)
         * to a scheme; unknown or null strings yield {@code defaultScheme}.
         */
        @Nonnull
        public static VInitScheme resolve(@Nullable String opt,
                @Nonnull VInitScheme defaultScheme) {
            if (opt == null) {
                return defaultScheme;
            } else if ("adjusted_random".equalsIgnoreCase(opt)
                    || "adjustedRandom".equalsIgnoreCase(opt)) {
                return adjustedRandom;
            } else if ("libffm_random".equalsIgnoreCase(opt) || "libffmRandom".equalsIgnoreCase(opt)
                    || "libffm".equalsIgnoreCase(opt)) {
                return VInitScheme.libffmRandom;
            } else if ("random".equalsIgnoreCase(opt)) {
                return random;
            } else if ("gaussian".equalsIgnoreCase(opt)) {
                return gaussian;
            }
            return defaultScheme;
        }

        public void setMaxInitValue(float maxInitValue) {
            this.maxInitValue = maxInitValue;
        }

        public void setInitStdDev(double initStdDev) {
            this.initStdDev = initStdDev;
        }

        /**
         * Allocates the RNGs for this scheme: {@code factor} independent RNGs for
         * gaussian (one per factor, seeded seed+i), a single RNG otherwise.
         */
        public void initRandom(int factor, long seed) {
            final int size = (this != gaussian) ? 1 : factor;
            this.rand = new Random[size];
            for (int i = 0; i < size; i++) {
                rand[i] = new Random(seed + i);
            }
        }
    }

    /**
     * Draws a fresh k-dimensional factor vector according to {@code _initScheme}.
     * Requires {@link VInitScheme#initRandom} to have been called on the scheme.
     */
    @Nonnull
    protected final float[] initV() {
        final float[] ret = new float[_factor];
        switch (_initScheme) {
            case adjustedRandom:
                adjustedRandomFill(ret, _initScheme.rand[0], _initScheme.maxInitValue);
                break;
            case libffmRandom:
                libffmRandomFill(ret, _initScheme.rand[0], _initScheme.maxInitValue);
                break;
            case random:
                randomFill(ret, _initScheme.rand[0], _initScheme.maxInitValue);
                break;
            case gaussian:
                gaussianFill(ret, _initScheme.rand, _initScheme.initStdDev);
                break;
            default:
                throw new IllegalStateException(
                    "Unsupported V initialization scheme: " + _initScheme);
        }
        return ret;
    }

    /** Uniform fill in [0, maxInitValue / k): the bound shrinks linearly with the factor count. */
    protected static final void adjustedRandomFill(@Nonnull final float[] a,
            @Nonnull final Random rand, final float maxInitValue) {
        final int k = a.length;
        final float basev = maxInitValue / k;
        for (int i = 0; i < k; i++) {
            float v = rand.nextFloat() * basev;
            a[i] = v;
        }
    }

    // libffm's V initialization scheme: uniform in [0, maxInitValue / sqrt(k))
    // https://github.com/guestwalk/libffm/blob/master/ffm.cpp#L287
    protected static final void libffmRandomFill(@Nonnull final float[] a,
            @Nonnull final Random rand, final float maxInitValue) {
        final int k = a.length;
        final float basev = maxInitValue / (float) Math.sqrt(k);
        for (int i = 0; i < k; i++) {
            float v = rand.nextFloat() * basev;
            a[i] = v;
        }
    }

    /** Plain uniform fill in [0, maxInitValue). */
    protected static final void randomFill(@Nonnull final float[] a, @Nonnull final Random rand,
            final float maxInitValue) {
        final int k = a.length;
        for (int i = 0; i < k; i++) {
            float v = rand.nextFloat() * maxInitValue;
            a[i] = v;
        }
    }

    // libfm uses gaussian for initialization: N(0, stddev^2) with an independent RNG per factor
    // https://github.com/srendle/libfm/blob/30b9c799c41d043f31565cbf827bf41d0dc3e2ab/src/fm_core/fm_model.h#L96
    protected static final void gaussianFill(@Nonnull final float[] a, @Nonnull final Random[] rand,
            final double stddev) {
        for (int i = 0, k = a.length; i < k; i++) {
            float v = (float) MathUtils.gaussian(0.d, stddev, rand[i]);
            a[i] = v;
        }
    }

}