/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.classifier;
import hivemall.annotations.Experimental;
import hivemall.annotations.VisibleForTesting;
import hivemall.model.FeatureValue;
import hivemall.model.PredictionModel;
import hivemall.model.PredictionResult;
import hivemall.optimizer.LossFunctions;
import hivemall.utils.collections.Fastutil;
import hivemall.utils.hashing.HashFunction;
import hivemall.utils.lang.Preconditions;
import it.unimi.dsi.fastutil.ints.Int2FloatMap;
import it.unimi.dsi.fastutil.ints.Int2FloatOpenHashMap;
import java.util.ArrayList;
import java.util.List;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;
/**
* A Passive-Aggressive binary classifier with degree-2 polynomial kernel expansion.
*
* <pre>
* Hideki Isozaki and Hideto Kazawa: Efficient Support Vector Classifiers for Named Entity Recognition, Proc. COLING, 2002
* </pre>
*
* @since v0.5-rc.1
*/
@Description(name = "train_kpa",
value = "_FUNC_(array<string|int|bigint> features, int label [, const string options])"
+ " - returns a relation <h int, hk int, float w0, float w1, float w2, float w3>")
@Experimental
public final class KernelExpansionPassiveAggressiveUDTF extends BinaryOnlineClassifierUDTF {
// ------------------------------------
// Hyper parameters
private float _pkc;
// Algorithm
private Algorithm _algo;
// ------------------------------------
// Model parameters
private float _w0;
private Int2FloatMap _w1;
private Int2FloatMap _w2;
private Int2FloatMap _w3;
// ------------------------------------
private float _loss;
public KernelExpansionPassiveAggressiveUDTF() {}
@VisibleForTesting
float getLoss() { // only used for testing purposes at the moment
return _loss;
}
@Override
protected Options getOptions() {
Options opts = new Options();
opts.addOption("pkc", true,
"Constant c inside polynomial kernel K = (dot(xi,xj) + c)^2 [default 1.0]");
opts.addOption("algo", "algorithm", true,
"Algorithm for calculating loss [pa, pa1 (default), pa2]");
opts.addOption("c", "aggressiveness", true,
"Aggressiveness parameter C for PA-1 and PA-2 [default 1.0]");
return opts;
}
@Override
protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
float pkc = 1.f;
float c = 1.f;
String algo = "pa1";
final CommandLine cl = super.processOptions(argOIs);
if (cl != null) {
String pkc_str = cl.getOptionValue("pkc");
if (pkc_str != null) {
pkc = Float.parseFloat(pkc_str);
}
String c_str = cl.getOptionValue("c");
if (c_str != null) {
c = Float.parseFloat(c_str);
if (c <= 0.f) {
throw new UDFArgumentException(
"Aggressiveness parameter C must be C > 0: " + c);
}
}
algo = cl.getOptionValue("algo", algo);
}
if ("pa1".equalsIgnoreCase(algo)) {
this._algo = new PA1(c);
} else if ("pa2".equalsIgnoreCase(algo)) {
this._algo = new PA2(c);
} else if ("pa".equalsIgnoreCase(algo)) {
this._algo = new PA();
} else {
throw new UDFArgumentException("Unsupported algorithm: " + algo);
}
this._pkc = pkc;
return cl;
}
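/**
* Strategy for deriving the learning rate (eta) from the hinge loss and the squared norm
* carried by the {@link PredictionResult}.
*/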
interface Algorithm {
float eta(final float loss, @Nonnull final PredictionResult margin);
}
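/** Plain PA: eta = loss / ||x||^2, with no cap on the step size. */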
static class PA implements Algorithm {
PA() {}
@Override
public float eta(float loss, PredictionResult margin) {
return loss / margin.getSquaredNorm();
}
}
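/** PA-1: eta = min(C, loss / ||x||^2); the aggressiveness parameter C caps the step size. */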
static class PA1 implements Algorithm {
private final float c;
PA1(float c) {
this.c = c;
}
@Override
public float eta(float loss, PredictionResult margin) {
float squared_norm = margin.getSquaredNorm();
float eta = loss / squared_norm;
return Math.min(c, eta);
}
}
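/** PA-2: eta = loss / (||x||^2 + 1/(2C)); C softly regularizes the step size instead of capping it. */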
static class PA2 implements Algorithm {
private final float c;
PA2(float c) {
this.c = c;
}
@Override
public float eta(float loss, PredictionResult margin) {
float squared_norm = margin.getSquaredNorm();
float eta = loss / (squared_norm + (0.5f / c));
return eta;
}
}
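/**
* The expanded weights are kept locally in {@code _w0}, {@code _w1}, {@code _w2} and {@code _w3}
* rather than in a {@link PredictionModel}, so this method only initializes those maps and
* returns {@code null}.
*/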
@Override
protected PredictionModel createModel() {
this._w0 = 0.f;
this._w1 = new Int2FloatOpenHashMap(16384);
_w1.defaultReturnValue(0.f);
this._w2 = new Int2FloatOpenHashMap(16384);
_w2.defaultReturnValue(0.f);
this._w3 = new Int2FloatOpenHashMap(16384);
_w3.defaultReturnValue(0.f);
return null;
}
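/**
* Defines the six output columns forwarded by {@link #close()}: h, w0, w1, w2, hk and w3.
*/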
@Override
protected StructObjectInspector getReturnOI(ObjectInspector featureRawOI) {
ArrayList<String> fieldNames = new ArrayList<String>();
ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
fieldNames.add("h");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
fieldNames.add("w0");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
fieldNames.add("w1");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
fieldNames.add("w2");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
fieldNames.add("hk");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
fieldNames.add("w3");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
}
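/**
* Parses the raw feature objects into {@link FeatureValue}s; null elements are skipped and an
* empty input yields {@code null}. The {@code true} flag passed to {@code FeatureValue.parse}
* presumably enables feature-name hashing, since features are addressed as int indices below.
*/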
@Nullable
FeatureValue[] parseFeatures(@Nonnull final List<?> features) {
final int size = features.size();
if (size == 0) {
return null;
}
final FeatureValue[] featureVector = new FeatureValue[size];
for (int i = 0; i < size; i++) {
Object f = features.get(i);
if (f == null) {
continue;
}
FeatureValue fv = FeatureValue.parse(f, true);
featureVector[i] = fv;
}
return featureVector;
}
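/**
* One passive-aggressive round: compute the kernelized margin, evaluate the hinge loss
* max(0, 1 - y * p), and update the expanded weights only when the loss is positive.
*/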
@Override
protected void train(@Nonnull final FeatureValue[] features, final int label) {
final float y = label > 0 ? 1.f : -1.f;
PredictionResult margin = calcScoreWithKernelAndNorm(features);
float p = margin.getScore();
float loss = LossFunctions.hingeLoss(p, y); // 1.0 - y * p
this._loss = loss;
if (loss > 0.f) { // y * p < 1
updateKernel(y, loss, margin, features);
}
}
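/**
* Evaluates the expanded model on a feature vector: the linear (w1) and squared (w2) terms per
* feature plus the pairwise interaction (w3) term for each hashed feature pair.
*/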
@Override
float predict(@Nonnull final FeatureValue[] features) {
float score = 0.f;
for (int i = 0; i < features.length; ++i) {
if (features[i] == null) {
continue;
}
int h = features[i].getFeatureAsInt();
float w1 = _w1.get(h);
float w2 = _w2.get(h);
double xi = features[i].getValue();
double xx = xi * xi;
score += w1 * xi;
score += w2 * xx;
for (int j = i + 1; j < features.length; ++j) {
int k = features[j].getFeatureAsInt();
int hk = HashFunction.hash(h, k, true);
float w3 = _w3.get(hk);
double xj = features[j].getValue();
score += xi * xj * w3;
}
}
return score;
}
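/**
* Same evaluation as {@link #predict(FeatureValue[])} but starts from the bias {@code _w0} and
* also accumulates the squared norm of the input, which the PA variants need to size eta.
*/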
@Nonnull
final PredictionResult calcScoreWithKernelAndNorm(@Nonnull final FeatureValue[] features) {
float score = _w0;
float norm = 0.f;
for (int i = 0; i < features.length; ++i) {
if (features[i] == null) {
continue;
}
int h = features[i].getFeatureAsInt();
float w1 = _w1.get(h);
float w2 = _w2.get(h);
double xi = features[i].getValue();
double xx = xi * xi;
score += w1 * xi;
score += w2 * xx;
norm += xx;
for (int j = i + 1; j < features.length; ++j) {
int k = features[j].getFeatureAsInt();
int hk = HashFunction.hash(h, k, true);
float w3 = _w3.get(hk);
double xj = features[j].getValue();
score += xi * xj * w3;
}
}
return new PredictionResult(score).squaredNorm(norm);
}
protected void updateKernel(final float label, final float loss,
@Nonnull final PredictionResult margin, @Nonnull final FeatureValue[] features) {
float eta = _algo.eta(loss, margin);
float coeff = eta * label;
expandKernel(features, coeff);
}
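/**
* Folds the support vector Zi into the expanded weights, using the identity
* (dot(Zi, x) + c)^2 = c^2 + 2c * sum_h Zi[h] x[h] + sum_h Zi[h]^2 x[h]^2
* + 2 * (sum over unordered feature pairs h, k) Zi[h] Zi[k] x[h] x[k],
* so W0, W1, W2 and W3 each accumulate the matching coefficient scaled by alpha.
*/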
private void expandKernel(@Nonnull final FeatureValue[] supportVector, final float alpha) {
final float pkc = _pkc;
// W0 += α c^2
this._w0 += alpha * pkc * pkc;
for (int i = 0; i < supportVector.length; ++i) {
final FeatureValue si = supportVector[i];
final int h = si.getFeatureAsInt();
float Zih = si.getValueAsFloat();
float alphaZih = alpha * Zih;
final float alphaZih2 = alphaZih * 2.f;
// W1[h] += 2 c α Zi[h]
_w1.put(h, _w1.get(h) + pkc * alphaZih2);
// W2[h] += α Zi[h]^2
_w2.put(h, _w2.get(h) + alphaZih * Zih);
for (int j = i + 1; j < supportVector.length; ++j) {
FeatureValue sj = supportVector[j];
int k = sj.getFeatureAsInt();
int hk = HashFunction.hash(h, k, true);
float Zjk = sj.getValueAsFloat();
// W3 += 2 α Zi[h] Zi[k]
_w3.put(hk, _w3.get(hk) + alphaZih2 * Zjk);
}
}
}
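/**
* Emits the trained model as rows: first the bias row (h = 0, w0), then one row per feature
* hash h with w1 and w2, and finally one row per feature-pair hash hk with w3.
*/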
@Override
public void close() throws HiveException {
final IntWritable h = new IntWritable(0); // row[0]
final FloatWritable w0 = new FloatWritable(_w0); // row[1]
final FloatWritable w1 = new FloatWritable(); // row[2]
final FloatWritable w2 = new FloatWritable(); // row[3]
final IntWritable hk = new IntWritable(0); // row[4]
final FloatWritable w3 = new FloatWritable(); // row[5]
final Object[] row = new Object[] {h, w0, null, null, null, null};
forward(row); // 0(f), w0
row[1] = null;
row[2] = w1;
row[3] = w2;
final Int2FloatMap w2map = _w2;
for (Int2FloatMap.Entry e : Fastutil.fastIterable(_w1)) {
int k = e.getIntKey();
Preconditions.checkArgument(k > 0, HiveException.class);
h.set(k);
w1.set(e.getFloatValue());
w2.set(w2map.get(k));
forward(row); // h(f), w1, w2
}
this._w1 = null;
this._w2 = null;
row[0] = null;
row[2] = null;
row[3] = null;
row[4] = hk;
row[5] = w3;
for (Int2FloatMap.Entry e : Fastutil.fastIterable(_w3)) {
int k = e.getIntKey();
Preconditions.checkArgument(k > 0, HiveException.class);
hk.set(k);
w3.set(e.getFloatValue());
forward(row); // hk(f), w3
}
this._w3 = null;
}
}