| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| package hivemall.xgboost; |
| |
| import hivemall.UDTFWithOptions; |
| import hivemall.utils.hadoop.HadoopUtils; |
| import hivemall.utils.hadoop.HiveUtils; |
| |
| import java.lang.reflect.Constructor; |
| import java.lang.reflect.InvocationTargetException; |
| import java.util.ArrayList; |
| import java.util.HashMap; |
| import java.util.List; |
| import java.util.Map; |
| |
| import javax.annotation.Nonnull; |
| |
| import ml.dmlc.xgboost4j.LabeledPoint; |
| import ml.dmlc.xgboost4j.java.Booster; |
| import ml.dmlc.xgboost4j.java.DMatrix; |
| import ml.dmlc.xgboost4j.java.XGBoostError; |
| |
| import org.apache.commons.cli.CommandLine; |
| import org.apache.commons.cli.Options; |
| import org.apache.commons.logging.Log; |
| import org.apache.commons.logging.LogFactory; |
| import org.apache.hadoop.hive.ql.exec.UDFArgumentException; |
| import org.apache.hadoop.hive.ql.metadata.HiveException; |
| import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; |
| import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; |
| import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; |
| import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; |
| import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; |
| import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; |
| import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; |
| |
| /** |
| * This is a base class to handle the options for XGBoost and provide common functions among various |
| * tasks. |
| */ |
| public abstract class XGBoostUDTF extends UDTFWithOptions { |
| private static final Log logger = LogFactory.getLog(XGBoostUDTF.class); |
| |
| // Settings for the XGBoost native library |
| static { |
| NativeLibLoader.initXGBoost(); |
| } |
| |
| // For input buffer |
| private final List<LabeledPoint> featuresList; |
| |
| // For input parameters |
| private ListObjectInspector featureListOI; |
| private PrimitiveObjectInspector featureElemOI; |
| private PrimitiveObjectInspector targetOI; |
| |
| // For XGBoost options |
| @Nonnull |
| protected final Map<String, Object> params = new HashMap<String, Object>(); |
| |
| // XGBoost options can be found in https://github.com/dmlc/xgboost/blob/master/doc/parameter.md |
| // Most of default parameters are set along with the official one. |
| { |
| /** General parameters */ |
| params.put("booster", "gbtree"); |
| params.put("num_round", 8); |
| params.put("silent", 1); |
| // Set to 1 by default because most of distributed systems assume |
| // each worker has a single vcore. |
| params.put("nthread", 1); |
| |
| /** Parameters for both boosters */ |
| params.put("alpha", 0.0); |
| // This default value depends on a booster type |
| // params.put("lambda", 0.0); |
| |
| /** Parameters for Tree Booster */ |
| params.put("eta", 0.3); |
| params.put("gamma", 0.0); |
| params.put("max_depth", 6); |
| params.put("min_child_weight", 1); |
| params.put("max_delta_step", 0); |
| params.put("subsample", 1.0); |
| params.put("colsample_bytree", 1.0); |
| params.put("colsample_bylevel", 1.0); |
| // The memory-based version of XGBoost only supports `exact` |
| params.put("tree_method", "exact"); |
| |
| /** Learning Task Parameters */ |
| params.put("base_score", 0.5); |
| } |
| |
| public XGBoostUDTF() { |
| this.featuresList = new ArrayList<>(1024); |
| } |
| |
| @Override |
| protected Options getOptions() { |
| final Options opts = new Options(); |
| |
| /** General parameters */ |
| opts.addOption("booster", true, |
| "Set a booster to use, gbtree or gblinear. [default: gbree]"); |
| opts.addOption("num_round", true, "Number of boosting iterations [default: 8]"); |
| opts.addOption("silent", true, |
| "0 means printing running messages, 1 means silent mode [default: 1]"); |
| opts.addOption("nthread", true, |
| "Number of parallel threads used to run xgboost [default: 1]"); |
| opts.addOption("num_pbuffer", true, |
| "Size of prediction buffer [set automatically by xgboost]"); |
| opts.addOption("num_feature", true, |
| "Feature dimension used in boosting [default: set automatically by xgboost]"); |
| |
| /** Parameters for both boosters */ |
| opts.addOption("alpha", true, "L1 regularization term on weights [default: 0.0]"); |
| opts.addOption("lambda", true, |
| "L2 regularization term on weights [default: 1.0 for gbtree, 0.0 for gblinear]"); |
| |
| /** Parameters for Tree Booster */ |
| opts.addOption("eta", true, |
| "Step size shrinkage used in update to prevents overfitting [default: 0.3]"); |
| opts.addOption("gamma", true, |
| "Minimum loss reduction required to make a further partition on a leaf node of the tree [default: 0.0]"); |
| opts.addOption("max_depth", true, "Max depth of decision tree [default: 6]"); |
| opts.addOption("min_child_weight", true, |
| "Minimum sum of instance weight(hessian) needed in a child [default: 1]"); |
| opts.addOption("max_delta_step", true, |
| "Maximum delta step we allow each tree's weight estimation to be [default: 0]"); |
| opts.addOption("subsample", true, |
| "Subsample ratio of the training instance [default: 1.0]"); |
| opts.addOption("colsample_bytree", true, |
| "Subsample ratio of columns when constructing each tree [default: 1.0]"); |
| opts.addOption("colsample_bylevel", true, |
| "Subsample ratio of columns for each split, in each level [default: 1.0]"); |
| |
| /** Parameters for Linear Booster */ |
| opts.addOption("lambda_bias", true, "L2 regularization term on bias [default: 0.0]"); |
| |
| /** Learning Task Parameters */ |
| opts.addOption("base_score", true, |
| "Initial prediction score of all instances, global bias [default: 0.5]"); |
| opts.addOption("eval_metric", true, |
| "Evaluation metrics for validation data [default according to objective]"); |
| |
| return opts; |
| } |
| |
| @Override |
| protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { |
| CommandLine cl = null; |
| if (argOIs.length >= 3) { |
| final String rawArgs = HiveUtils.getConstString(argOIs[2]); |
| cl = this.parseOptions(rawArgs); |
| |
| /** General parameters */ |
| if (cl.hasOption("booster")) { |
| params.put("booster", cl.getOptionValue("booster")); |
| } |
| if (cl.hasOption("num_round")) { |
| params.put("num_round", Integer.valueOf(cl.getOptionValue("num_round"))); |
| } |
| if (cl.hasOption("silent")) { |
| params.put("silent", Integer.valueOf(cl.getOptionValue("silent"))); |
| } |
| if (cl.hasOption("nthread")) { |
| params.put("nthread", Integer.valueOf(cl.getOptionValue("nthread"))); |
| } |
| if (cl.hasOption("num_pbuffer")) { |
| params.put("num_pbuffer", Integer.valueOf(cl.getOptionValue("num_pbuffer"))); |
| } |
| if (cl.hasOption("num_feature")) { |
| params.put("num_feature", Integer.valueOf(cl.getOptionValue("num_feature"))); |
| } |
| |
| /** Parameters for both boosters */ |
| if (cl.hasOption("alpha")) { |
| params.put("alpha", Double.valueOf(cl.getOptionValue("alpha"))); |
| } |
| if (cl.hasOption("lambda")) { |
| params.put("lambda", Double.valueOf(cl.getOptionValue("lambda"))); |
| } |
| |
| /** Parameters for Tree Booster */ |
| if (cl.hasOption("eta")) { |
| params.put("eta", Double.valueOf(cl.getOptionValue("eta"))); |
| } |
| if (cl.hasOption("gamma")) { |
| params.put("gamma", Double.valueOf(cl.getOptionValue("gamma"))); |
| } |
| if (cl.hasOption("max_depth")) { |
| params.put("max_depth", Integer.valueOf(cl.getOptionValue("max_depth"))); |
| } |
| if (cl.hasOption("min_child_weight")) { |
| params.put("min_child_weight", |
| Integer.valueOf(cl.getOptionValue("min_child_weight"))); |
| } |
| if (cl.hasOption("max_delta_step")) { |
| params.put("max_delta_step", Integer.valueOf(cl.getOptionValue("max_delta_step"))); |
| } |
| if (cl.hasOption("subsample")) { |
| params.put("subsample", Double.valueOf(cl.getOptionValue("subsample"))); |
| } |
| if (cl.hasOption("colsample_bytree")) { |
| params.put("colsamle_bytree", |
| Double.valueOf(cl.getOptionValue("colsample_bytree"))); |
| } |
| if (cl.hasOption("colsample_bylevel")) { |
| params.put("colsamle_bylevel", |
| Double.valueOf(cl.getOptionValue("colsample_bylevel"))); |
| } |
| |
| /** Parameters for Linear Booster */ |
| if (cl.hasOption("lambda_bias")) { |
| params.put("lambda_bias", Double.valueOf(cl.getOptionValue("lambda_bias"))); |
| } |
| |
| /** Learning Task Parameters */ |
| if (cl.hasOption("base_score")) { |
| params.put("base_score", Double.valueOf(cl.getOptionValue("base_score"))); |
| } |
| if (cl.hasOption("eval_metric")) { |
| params.put("eval_metric", cl.getOptionValue("eval_metric")); |
| } |
| } |
| |
| try { |
| // Try to create a `Booster` instance to check if given XGBoost options |
| // are valid, or not. |
| createXGBooster(params, featuresList); |
| } catch (Exception e) { |
| throw new UDFArgumentException(e); |
| } |
| |
| return cl; |
| } |
| |
| /** All the functions return (string model_id, byte[] pred_model) as built models */ |
| @Nonnull |
| private static StructObjectInspector getReturnOIs() { |
| final List<String> fieldNames = new ArrayList<>(2); |
| final List<ObjectInspector> fieldOIs = new ArrayList<>(2); |
| fieldNames.add("model_id"); |
| fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector); |
| fieldNames.add("pred_model"); |
| fieldOIs.add(PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector); |
| return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); |
| } |
| |
| @Override |
| public StructObjectInspector initialize(@Nonnull ObjectInspector[] argOIs) |
| throws UDFArgumentException { |
| processOptions(argOIs); |
| final ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]); |
| final ObjectInspector elemOI = listOI.getListElementObjectInspector(); |
| this.featureListOI = listOI; |
| this.featureElemOI = HiveUtils.asStringOI(elemOI); |
| this.targetOI = HiveUtils.asDoubleCompatibleOI(argOIs[1]); |
| return getReturnOIs(); |
| } |
| |
| /** It `target` has valid input range, it overrides this */ |
| protected void checkTargetValue(double target) throws HiveException {} |
| |
| @Override |
| public void process(@Nonnull Object[] args) throws HiveException { |
| if (args[0] == null) { |
| return; |
| } |
| |
| // TODO: Need to support dense inputs |
| final List<?> features = (List<?>) featureListOI.getList(args[0]); |
| final String[] fv = new String[features.size()]; |
| for (int i = 0; i < features.size(); i++) { |
| fv[i] = (String) featureElemOI.getPrimitiveJavaObject(features.get(i)); |
| } |
| double target = PrimitiveObjectInspectorUtils.getDouble(args[1], this.targetOI); |
| checkTargetValue(target); |
| final LabeledPoint point = XGBoostUtils.parseFeatures(target, fv); |
| if (point != null) { |
| this.featuresList.add(point); |
| } |
| } |
| |
| @Nonnull |
| private static String generateUniqueModelId() { |
| return "xgbmodel-" + HadoopUtils.getUniqueTaskIdString(); |
| } |
| |
| @Nonnull |
| private static Booster createXGBooster(final Map<String, Object> params, |
| final List<LabeledPoint> input) throws NoSuchMethodException, XGBoostError, |
| IllegalAccessException, InvocationTargetException, InstantiationException { |
| Class<?>[] args = {Map.class, DMatrix[].class}; |
| Constructor<Booster> ctor = Booster.class.getDeclaredConstructor(args); |
| ctor.setAccessible(true); |
| return ctor.newInstance( |
| new Object[] {params, new DMatrix[] {new DMatrix(input.iterator(), "")}}); |
| } |
| |
| @Override |
| public void close() throws HiveException { |
| try { |
| // Kick off training with XGBoost |
| final DMatrix trainData = new DMatrix(featuresList.iterator(), ""); |
| final Booster booster = createXGBooster(params, featuresList); |
| final int num_round = (Integer) params.get("num_round"); |
| for (int i = 0; i < num_round; i++) { |
| booster.update(trainData, i); |
| } |
| |
| // Output the built model |
| final String modelId = generateUniqueModelId(); |
| final byte[] predModel = booster.toByteArray(); |
| logger.info("model_id:" + modelId.toString() + " size:" + predModel.length); |
| forward(new Object[] {modelId, predModel}); |
| } catch (Exception e) { |
| throw new HiveException(e); |
| } |
| } |
| |
| } |