/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.regression;

import hivemall.LearnerBaseUDTF;
import hivemall.model.FeatureValue;
import hivemall.model.IWeightValue;
import hivemall.model.PredictionModel;
import hivemall.model.PredictionResult;
import hivemall.model.WeightValue;
import hivemall.model.WeightValue.WeightValueWithCovar;
import hivemall.utils.collections.IMapIterator;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.FloatAccumulator;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.FloatWritable;

/**
 * The base class for regression algorithms. {@code RegressionBaseUDTF} provides a general
 * implementation for both online (per-example) training and mini-batch training.
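 *
 * <p>As an illustrative sketch (a hypothetical subclass, not one shipped with Hivemall), a
 * concrete regressor mainly needs to supply a gradient:
 *
 * <pre>{@code
 * public final class SquaredLossRegressionUDTF extends RegressionBaseUDTF {
 *     @Override
 *     protected float computeGradient(float target, float predicted) {
 *         // Negative gradient of the squared loss 0.5 * (target - predicted)^2;
 *         // onlineUpdate()/accumulateUpdate() scale this by x[i] when updating w[i].
 *         // A real learner would typically also fold in a learning rate here.
 *         return target - predicted;
 *     }
 * }
 * }</pre>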
*/
public abstract class RegressionBaseUDTF extends LearnerBaseUDTF {
private static final Log logger = LogFactory.getLog(RegressionBaseUDTF.class);
private ListObjectInspector featureListOI;
private PrimitiveObjectInspector featureInputOI;
private PrimitiveObjectInspector targetOI;
private boolean parseFeature;
protected PredictionModel model;
protected int count;
    // Accumulated weight deltas for mini-batch training, keyed by feature.
protected transient Map<Object, FloatAccumulator> accumulated;
protected int sampled;
public RegressionBaseUDTF() {
this(false);
}
public RegressionBaseUDTF(boolean enableNewModel) {
super(enableNewModel);
}
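
    /**
     * Validates the arguments (features, target [, options]), prepares the input inspectors,
     * creates the prediction model, and returns the output schema built by
     * {@link #getReturnOI(ObjectInspector)}.
     */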
@Override
public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
if (argOIs.length < 2) {
throw new UDFArgumentException(
"_FUNC_ takes 2 arguments: List<Int|BigInt|Text> features, float target [, constant string options]");
}
this.featureInputOI = processFeaturesOI(argOIs[0]);
this.targetOI = HiveUtils.asDoubleCompatibleOI(argOIs[1]);
processOptions(argOIs);
this.model = createModel();
this.count = 0;
this.sampled = 0;
return getReturnOI(getFeatureOutputOI(featureInputOI));
}
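
    /**
     * Validates the first argument as a list of int/bigint/text features and returns the primitive
     * inspector for its elements. Text features are later parsed in the "feature:value" format.
     */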
@Nonnull
protected PrimitiveObjectInspector processFeaturesOI(@Nonnull ObjectInspector arg)
throws UDFArgumentException {
this.featureListOI = (ListObjectInspector) arg;
ObjectInspector featureRawOI = featureListOI.getListElementObjectInspector();
HiveUtils.validateFeatureOI(featureRawOI);
this.parseFeature = HiveUtils.isStringOI(featureRawOI);
return HiveUtils.asPrimitiveObjectInspector(featureRawOI);
}
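
    /**
     * Builds the output struct inspector: (feature, weight) rows, plus a covar column when the
     * learner maintains covariances.
     */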
@Nonnull
protected StructObjectInspector getReturnOI(@Nonnull ObjectInspector featureOutputOI) {
ArrayList<String> fieldNames = new ArrayList<String>();
ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
fieldNames.add("feature");
fieldOIs.add(featureOutputOI);
fieldNames.add("weight");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
if (useCovariance()) {
fieldNames.add("covar");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
}
return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
}
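
    /**
     * Consumes one training row: parses the feature vector, extracts the float target, and runs a
     * single training step. Rows with an empty feature list are silently skipped.
     */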
@Override
public void process(Object[] args) throws HiveException {
if (is_mini_batch && accumulated == null) {
this.accumulated = new HashMap<Object, FloatAccumulator>(1024);
}
        List<?> features = featureListOI.getList(args[0]);
FeatureValue[] featureVector = parseFeatures(features);
if (featureVector == null) {
return;
}
float target = PrimitiveObjectInspectorUtils.getFloat(args[1], targetOI);
checkTargetValue(target);
count++;
train(featureVector, target);
}
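
    /**
     * Materializes the raw feature list as {@link FeatureValue}s. String features are parsed in
     * the "feature:value" format; non-string features are copied as-is with an implicit value of
     * 1.0. Returns null for an empty list, and keeps null elements as null slots.
     */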
@Nullable
protected final FeatureValue[] parseFeatures(@Nonnull final List<?> features) {
final int size = features.size();
if (size == 0) {
return null;
}
final ObjectInspector featureInspector = featureListOI.getListElementObjectInspector();
final FeatureValue[] featureVector = new FeatureValue[size];
for (int i = 0; i < size; i++) {
Object f = features.get(i);
if (f == null) {
continue;
}
final FeatureValue fv;
if (parseFeature) {
fv = FeatureValue.parse(f);
} else {
Object k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector);
fv = new FeatureValue(k, 1.f);
}
featureVector[i] = fv;
}
return featureVector;
}
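
    /** Hook for subclasses to reject invalid target values; the default is a no-op. */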
protected void checkTargetValue(float target) throws UDFArgumentException {}
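
    /** Runs one training step: predict with the current weights, then update them. */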
protected void train(@Nonnull final FeatureValue[] features, final float target) {
float p = predict(features);
update(features, target, p);
}
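
    /** Returns the inner product of the current weight vector and the given feature vector. */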
protected float predict(@Nonnull final FeatureValue[] features) {
float score = 0.f;
        for (FeatureValue f : features) { // score += w[i] * x[i]
if (f == null) {
continue;
}
final Object k = f.getFeature();
final float v = f.getValueAsFloat();
float old_w = model.getWeight(k);
if (old_w != 0f) {
score += (old_w * v);
}
}
return score;
}
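
    /** Computes the score w·x together with the squared L2 norm of the feature vector. */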
@Nonnull
protected PredictionResult calcScoreAndNorm(@Nonnull final FeatureValue[] features) {
float score = 0.f;
float squared_norm = 0.f;
        for (FeatureValue f : features) { // score += w[i] * x[i]
if (f == null) {
continue;
}
final Object k = f.getFeature();
final float v = f.getValueAsFloat();
float old_w = model.getWeight(k);
if (old_w != 0f) {
score += (old_w * v);
}
squared_norm += (v * v);
}
return new PredictionResult(score).squaredNorm(squared_norm);
}
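
    /**
     * Computes the score w·x and its variance x'Σx, where unseen weights contribute a unit
     * covariance (as used by confidence-weighted learners).
     */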
@Nonnull
protected PredictionResult calcScoreAndVariance(@Nonnull final FeatureValue[] features) {
float score = 0.f;
float variance = 0.f;
        for (FeatureValue f : features) { // score += w[i] * x[i], variance += cov[i] * x[i]^2
if (f == null) {
continue;
}
final Object k = f.getFeature();
final float v = f.getValueAsFloat();
IWeightValue old_w = model.get(k);
if (old_w == null) {
variance += (v * v);
} else {
score += (old_w.get() * v);
variance += (old_w.getCovariance() * v * v);
}
}
return new PredictionResult(score).variance(variance);
}
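
    /**
     * Applies the gradient for one example. With mini-batch training enabled the gradient is
     * accumulated and flushed via {@link #batchUpdate()} once {@code mini_batch_size} examples
     * have been sampled; otherwise the model is updated immediately.
     */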
protected void update(@Nonnull final FeatureValue[] features, final float target,
final float predicted) {
final float grad = computeGradient(target, predicted);
if (is_mini_batch) {
accumulateUpdate(features, grad);
if (sampled >= mini_batch_size) {
batchUpdate();
}
} else {
onlineUpdate(features, grad);
}
}
/**
     * Computes the per-example gradient coefficient from the loss function. Subclasses that rely
     * on the default {@link #update(FeatureValue[], float, float)} must override this; see the
     * class javadoc for an illustrative squared-loss sketch.
*/
protected float computeGradient(float target, float predicted) {
throw new UnsupportedOperationException();
}
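
    /**
     * Adds {@code coeff * x[i]} to each feature's accumulator instead of touching the model, and
     * counts this example toward the current mini-batch.
     */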
protected void accumulateUpdate(@Nonnull final FeatureValue[] features, final float coeff) {
for (int i = 0; i < features.length; i++) {
if (features[i] == null) {
continue;
}
final Object x = features[i].getFeature();
final float xi = features[i].getValueAsFloat();
float delta = xi * coeff;
FloatAccumulator acc = accumulated.get(x);
if (acc == null) {
acc = new FloatAccumulator(delta);
accumulated.put(x, acc);
} else {
acc.add(delta);
}
}
sampled++;
}
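
    /**
     * Flushes the accumulated deltas into the model and resets the mini-batch state. Each delta is
     * the sum (not the average) of the per-example updates in the batch.
     */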
protected void batchUpdate() {
if (accumulated.isEmpty()) {
this.sampled = 0;
return;
}
for (Map.Entry<Object, FloatAccumulator> e : accumulated.entrySet()) {
Object x = e.getKey();
FloatAccumulator v = e.getValue();
float delta = v.get();
float old_w = model.getWeight(x);
float new_w = old_w + delta;
model.set(x, new WeightValue(new_w));
}
accumulated.clear();
this.sampled = 0;
}
/**
     * Applies the online (per-example) update {@code w[i] += coeff * x[i]} to the model.
*/
protected void onlineUpdate(@Nonnull final FeatureValue[] features, float coeff) {
        for (FeatureValue f : features) { // w[i] += coeff * x[i]
if (f == null) {
continue;
}
final Object x = f.getFeature();
final float xi = f.getValueAsFloat();
float old_w = model.getWeight(x);
float new_w = old_w + (coeff * xi);
model.set(x, new WeightValue(new_w));
}
}
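
    /**
     * Finalizes training: applies any pending mini-batch deltas, forwards one (feature, weight
     * [, covar]) row per touched weight, and releases the model.
     */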
@Override
public final void close() throws HiveException {
super.close();
if (model != null) {
if (accumulated != null) { // Update model with accumulated delta
batchUpdate();
this.accumulated = null;
}
int numForwarded = 0;
if (useCovariance()) {
final WeightValueWithCovar probe = new WeightValueWithCovar();
final Object[] forwardMapObj = new Object[3];
final FloatWritable fv = new FloatWritable();
final FloatWritable cov = new FloatWritable();
final IMapIterator<Object, IWeightValue> itor = model.entries();
while (itor.next() != -1) {
itor.getValue(probe);
if (!probe.isTouched()) {
continue; // skip outputting untouched weights
}
Object k = itor.getKey();
fv.set(probe.get());
cov.set(probe.getCovariance());
forwardMapObj[0] = k;
forwardMapObj[1] = fv;
forwardMapObj[2] = cov;
forward(forwardMapObj);
numForwarded++;
}
} else {
final WeightValue probe = new WeightValue();
final Object[] forwardMapObj = new Object[2];
final FloatWritable fv = new FloatWritable();
final IMapIterator<Object, IWeightValue> itor = model.entries();
while (itor.next() != -1) {
itor.getValue(probe);
if (!probe.isTouched()) {
continue; // skip outputting untouched weights
}
Object k = itor.getKey();
fv.set(probe.get());
forwardMapObj[0] = k;
forwardMapObj[1] = fv;
forward(forwardMapObj);
numForwarded++;
}
}
long numMixed = model.getNumMixed();
this.model = null;
logger.info("Trained a prediction model using " + count + " training examples"
+ (numMixed > 0 ? "( numMixed: " + numMixed + " )" : ""));
logger.info("Forwarded the prediction model of " + numForwarded + " rows");
}
}
}