blob: 272614f67f2d93b9302b02eeb12863aaeadbef06 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.xgboost;
import hivemall.UDTFWithOptions;
import hivemall.utils.hadoop.HadoopUtils;
import hivemall.utils.hadoop.HiveUtils;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nonnull;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
/**
* This is a base class to handle the options for XGBoost and provide common functions among various
* tasks.
*/
public abstract class XGBoostUDTF extends UDTFWithOptions {
    private static final Log logger = LogFactory.getLog(XGBoostUDTF.class);

    // Settings for the XGBoost native library: load it once per JVM before any
    // xgboost4j call is made.
    static {
        NativeLibLoader.initXGBoost();
    }

    // Input buffer: training examples are accumulated here by process() and
    // consumed by close() when training starts.
    private final List<LabeledPoint> featuresList;

    // Object inspectors for the input arguments; assigned in initialize().
    private ListObjectInspector featureListOI;
    private PrimitiveObjectInspector featureElemOI;
    private PrimitiveObjectInspector targetOI;

    // For XGBoost options. Populated with defaults below and overridden by
    // user-supplied CLI options in processOptions().
    @Nonnull
    protected final Map<String, Object> params = new HashMap<String, Object>();

    // XGBoost options can be found in https://github.com/dmlc/xgboost/blob/master/doc/parameter.md
    // Most of default parameters are set along with the official one.
    {
        /** General parameters */
        params.put("booster", "gbtree");
        params.put("num_round", 8);
        params.put("silent", 1);
        // Set to 1 by default because most of distributed systems assume
        // each worker has a single vcore.
        params.put("nthread", 1);
        /** Parameters for both boosters */
        params.put("alpha", 0.0);
        // This default value depends on a booster type
        // params.put("lambda", 0.0);
        /** Parameters for Tree Booster */
        params.put("eta", 0.3);
        params.put("gamma", 0.0);
        params.put("max_depth", 6);
        params.put("min_child_weight", 1);
        params.put("max_delta_step", 0);
        params.put("subsample", 1.0);
        params.put("colsample_bytree", 1.0);
        params.put("colsample_bylevel", 1.0);
        // The memory-based version of XGBoost only supports `exact`
        params.put("tree_method", "exact");
        /** Learning Task Parameters */
        params.put("base_score", 0.5);
    }

    public XGBoostUDTF() {
        // Presize the buffer; it grows as rows stream in via process().
        this.featuresList = new ArrayList<>(1024);
    }

    /** Declares the CLI options accepted in the third UDTF argument. */
    @Override
    protected Options getOptions() {
        final Options opts = new Options();
        /** General parameters */
        opts.addOption("booster", true,
            "Set a booster to use, gbtree or gblinear. [default: gbtree]");
        opts.addOption("num_round", true, "Number of boosting iterations [default: 8]");
        opts.addOption("silent", true,
            "0 means printing running messages, 1 means silent mode [default: 1]");
        opts.addOption("nthread", true,
            "Number of parallel threads used to run xgboost [default: 1]");
        opts.addOption("num_pbuffer", true,
            "Size of prediction buffer [set automatically by xgboost]");
        opts.addOption("num_feature", true,
            "Feature dimension used in boosting [default: set automatically by xgboost]");
        /** Parameters for both boosters */
        opts.addOption("alpha", true, "L1 regularization term on weights [default: 0.0]");
        opts.addOption("lambda", true,
            "L2 regularization term on weights [default: 1.0 for gbtree, 0.0 for gblinear]");
        /** Parameters for Tree Booster */
        opts.addOption("eta", true,
            "Step size shrinkage used in update to prevents overfitting [default: 0.3]");
        opts.addOption("gamma", true,
            "Minimum loss reduction required to make a further partition on a leaf node of the tree [default: 0.0]");
        opts.addOption("max_depth", true, "Max depth of decision tree [default: 6]");
        opts.addOption("min_child_weight", true,
            "Minimum sum of instance weight(hessian) needed in a child [default: 1]");
        opts.addOption("max_delta_step", true,
            "Maximum delta step we allow each tree's weight estimation to be [default: 0]");
        opts.addOption("subsample", true,
            "Subsample ratio of the training instance [default: 1.0]");
        opts.addOption("colsample_bytree", true,
            "Subsample ratio of columns when constructing each tree [default: 1.0]");
        opts.addOption("colsample_bylevel", true,
            "Subsample ratio of columns for each split, in each level [default: 1.0]");
        /** Parameters for Linear Booster */
        opts.addOption("lambda_bias", true, "L2 regularization term on bias [default: 0.0]");
        /** Learning Task Parameters */
        opts.addOption("base_score", true,
            "Initial prediction score of all instances, global bias [default: 0.5]");
        opts.addOption("eval_metric", true,
            "Evaluation metrics for validation data [default according to objective]");
        return opts;
    }

    /**
     * Parses the optional third argument and overrides the default XGBoost
     * parameters in {@link #params} with user-supplied values. Finally tries
     * to instantiate a {@code Booster} to fail fast on invalid options.
     *
     * @param argOIs the object inspectors for the UDTF arguments
     * @return the parsed command line, or {@code null} if no options argument was given
     * @throws UDFArgumentException if the options cannot be parsed or are rejected by XGBoost
     */
    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        CommandLine cl = null;
        if (argOIs.length >= 3) {
            final String rawArgs = HiveUtils.getConstString(argOIs[2]);
            cl = this.parseOptions(rawArgs);
            /** General parameters */
            if (cl.hasOption("booster")) {
                params.put("booster", cl.getOptionValue("booster"));
            }
            if (cl.hasOption("num_round")) {
                params.put("num_round", Integer.valueOf(cl.getOptionValue("num_round")));
            }
            if (cl.hasOption("silent")) {
                params.put("silent", Integer.valueOf(cl.getOptionValue("silent")));
            }
            if (cl.hasOption("nthread")) {
                params.put("nthread", Integer.valueOf(cl.getOptionValue("nthread")));
            }
            if (cl.hasOption("num_pbuffer")) {
                params.put("num_pbuffer", Integer.valueOf(cl.getOptionValue("num_pbuffer")));
            }
            if (cl.hasOption("num_feature")) {
                params.put("num_feature", Integer.valueOf(cl.getOptionValue("num_feature")));
            }
            /** Parameters for both boosters */
            if (cl.hasOption("alpha")) {
                params.put("alpha", Double.valueOf(cl.getOptionValue("alpha")));
            }
            if (cl.hasOption("lambda")) {
                params.put("lambda", Double.valueOf(cl.getOptionValue("lambda")));
            }
            /** Parameters for Tree Booster */
            if (cl.hasOption("eta")) {
                params.put("eta", Double.valueOf(cl.getOptionValue("eta")));
            }
            if (cl.hasOption("gamma")) {
                params.put("gamma", Double.valueOf(cl.getOptionValue("gamma")));
            }
            if (cl.hasOption("max_depth")) {
                params.put("max_depth", Integer.valueOf(cl.getOptionValue("max_depth")));
            }
            if (cl.hasOption("min_child_weight")) {
                params.put("min_child_weight",
                    Integer.valueOf(cl.getOptionValue("min_child_weight")));
            }
            if (cl.hasOption("max_delta_step")) {
                params.put("max_delta_step", Integer.valueOf(cl.getOptionValue("max_delta_step")));
            }
            if (cl.hasOption("subsample")) {
                params.put("subsample", Double.valueOf(cl.getOptionValue("subsample")));
            }
            // BUGFIX: keys below were previously misspelled as "colsamle_*", so
            // user-supplied values were silently ignored by XGBoost.
            if (cl.hasOption("colsample_bytree")) {
                params.put("colsample_bytree",
                    Double.valueOf(cl.getOptionValue("colsample_bytree")));
            }
            if (cl.hasOption("colsample_bylevel")) {
                params.put("colsample_bylevel",
                    Double.valueOf(cl.getOptionValue("colsample_bylevel")));
            }
            /** Parameters for Linear Booster */
            if (cl.hasOption("lambda_bias")) {
                params.put("lambda_bias", Double.valueOf(cl.getOptionValue("lambda_bias")));
            }
            /** Learning Task Parameters */
            if (cl.hasOption("base_score")) {
                params.put("base_score", Double.valueOf(cl.getOptionValue("base_score")));
            }
            if (cl.hasOption("eval_metric")) {
                params.put("eval_metric", cl.getOptionValue("eval_metric"));
            }
        }
        try {
            // Try to create a `Booster` instance to check if given XGBoost options
            // are valid, or not.
            createXGBooster(params, featuresList);
        } catch (Exception e) {
            throw new UDFArgumentException(e);
        }
        return cl;
    }

    /** All the functions return (string model_id, byte[] pred_model) as built models */
    @Nonnull
    private static StructObjectInspector getReturnOIs() {
        final List<String> fieldNames = new ArrayList<>(2);
        final List<ObjectInspector> fieldOIs = new ArrayList<>(2);
        fieldNames.add("model_id");
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        fieldNames.add("pred_model");
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Validates the arguments (array&lt;string&gt; features, double-compatible target
     * [, const string options]) and declares the output schema.
     */
    @Override
    public StructObjectInspector initialize(@Nonnull ObjectInspector[] argOIs)
            throws UDFArgumentException {
        processOptions(argOIs);
        final ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]);
        final ObjectInspector elemOI = listOI.getListElementObjectInspector();
        this.featureListOI = listOI;
        this.featureElemOI = HiveUtils.asStringOI(elemOI);
        this.targetOI = HiveUtils.asDoubleCompatibleOI(argOIs[1]);
        return getReturnOIs();
    }

    /**
     * If `target` has a restricted valid input range, subclasses override this to
     * reject out-of-range values. The default implementation accepts any value.
     */
    protected void checkTargetValue(double target) throws HiveException {}

    /** Buffers one (features, target) training example; rows with null features are skipped. */
    @Override
    public void process(@Nonnull Object[] args) throws HiveException {
        if (args[0] == null) {
            return;
        }
        // TODO: Need to support dense inputs
        final List<?> features = (List<?>) featureListOI.getList(args[0]);
        final String[] fv = new String[features.size()];
        for (int i = 0; i < features.size(); i++) {
            fv[i] = (String) featureElemOI.getPrimitiveJavaObject(features.get(i));
        }
        double target = PrimitiveObjectInspectorUtils.getDouble(args[1], this.targetOI);
        checkTargetValue(target);
        final LabeledPoint point = XGBoostUtils.parseFeatures(target, fv);
        if (point != null) {
            this.featuresList.add(point);
        }
    }

    /** Returns a model id unique per task attempt, e.g. "xgbmodel-&lt;taskId&gt;". */
    @Nonnull
    private static String generateUniqueModelId() {
        return "xgbmodel-" + HadoopUtils.getUniqueTaskIdString();
    }

    /**
     * Instantiates a {@code Booster} via reflection because its constructor is not
     * publicly accessible in xgboost4j.
     */
    @Nonnull
    private static Booster createXGBooster(final Map<String, Object> params,
            final List<LabeledPoint> input) throws NoSuchMethodException, XGBoostError,
            IllegalAccessException, InvocationTargetException, InstantiationException {
        Class<?>[] args = {Map.class, DMatrix[].class};
        Constructor<Booster> ctor = Booster.class.getDeclaredConstructor(args);
        ctor.setAccessible(true);
        return ctor.newInstance(
            new Object[] {params, new DMatrix[] {new DMatrix(input.iterator(), "")}});
    }

    /** Trains the model over the buffered examples and forwards (model_id, serialized model). */
    @Override
    public void close() throws HiveException {
        try {
            // Kick off training with XGBoost
            final DMatrix trainData = new DMatrix(featuresList.iterator(), "");
            final Booster booster = createXGBooster(params, featuresList);
            final int num_round = (Integer) params.get("num_round");
            for (int i = 0; i < num_round; i++) {
                booster.update(trainData, i);
            }
            // Output the built model
            final String modelId = generateUniqueModelId();
            final byte[] predModel = booster.toByteArray();
            logger.info("model_id:" + modelId + " size:" + predModel.length);
            forward(new Object[] {modelId, predModel});
        } catch (Exception e) {
            throw new HiveException(e);
        }
    }
}