blob: 723cb0bc56d3969b9f5de755290d945c8b28a62d [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.ftvec.conv;
import hivemall.UDFWithOptions;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hashing.MurmurHash3;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.lang.StringUtils;
import hivemall.utils.struct.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
// @formatter:off
@Description(name = "to_libsvm_format",
value = "_FUNC_(array<string> feautres [, double/integer target, const string options])"
+ " - Returns a string representation of libsvm",
extended = "Usage:\n" +
" select to_libsvm_format(array('apple:3.4','orange:2.1'))\n" +
" > 6284535:3.4 8104713:2.1\n" +
" select to_libsvm_format(array('apple:3.4','orange:2.1'), '-features 10')\n" +
" > 3:2.1 7:3.4\n" +
" select to_libsvm_format(array('7:3.4','3:2.1'), 5.0)\n" +
" > 5.0 3:2.1 7:3.4")
// @formatter:on
@UDFType(deterministic = true, stateful = false)
public final class ToLibSVMFormatUDF extends UDFWithOptions {
private ListObjectInspector _featuresOI;
@Nullable
private PrimitiveObjectInspector _targetOI = null;
private int _numFeatures = MurmurHash3.DEFAULT_NUM_FEATURES;
private StringBuilder _buf;
@Override
protected Options getOptions() {
Options opts = new Options();
opts.addOption("features", "num_features", true,
"The number of features [default: 16777217 (2^24)]");
return opts;
}
@Override
protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException {
CommandLine cl = parseOptions(optionValue);
this._numFeatures = Primitives.parseInt(cl.getOptionValue("num_features"),
MurmurHash3.DEFAULT_NUM_FEATURES);
assumeTrue(_numFeatures > 0, "num_features must be greater than 0: " + _numFeatures);
return cl;
}
@Override
public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
assumeTrue(argOIs.length >= 1 || argOIs.length <= 3,
"to_libsvm_format UDF takes 1~3 arguments");
this._featuresOI = HiveUtils.asListOI(argOIs[0]);
if (argOIs.length == 2) {
ObjectInspector argOI1 = argOIs[1];
if (HiveUtils.isNumberOI(argOI1)) {
this._targetOI = HiveUtils.asNumberOI(argOI1);
} else if (HiveUtils.isConstString(argOI1)) { // no target
String opts = HiveUtils.getConstString(argOI1);
processOptions(opts);
} else {
throw new UDFArgumentException(
"Unexpected argument type for 2nd argument: " + argOI1.getTypeName());
}
} else if (argOIs.length == 3) {
this._targetOI = HiveUtils.asNumberOI(argOIs[1]);
String opts = HiveUtils.getConstString(argOIs[2]);
processOptions(opts);
}
this._buf = new StringBuilder();
return PrimitiveObjectInspectorFactory.javaStringObjectInspector;
}
@Nullable
@Override
public String evaluate(DeferredObject[] args) throws HiveException {
final StringBuilder buf = this._buf;
StringUtils.clear(buf);
Object arg0 = args[0].get();
if (arg0 == null) {
return null;
}
final int featureSize = _featuresOI.getListLength(arg0);
List<Pair<Integer, Double>> features = new ArrayList<>(featureSize);
for (int i = 0; i < featureSize; i++) {
Object e = _featuresOI.getListElement(arg0, i);
if (e == null) {
continue;
}
Pair<Integer, Double> fv = parse(e.toString(), _numFeatures);
features.add(fv);
}
Collections.sort(features, comparator);
if (_targetOI != null) {
Object arg1 = args[1].get();
if (arg1 == null) {
throw new HiveException("Detected NULL for the 2nd argument");
}
if (HiveUtils.isIntegerOI(_targetOI)) {
int label = HiveUtils.getInt(arg1, _targetOI);
buf.append(label);
} else {
double label = HiveUtils.getDouble(arg1, _targetOI);
buf.append(label);
}
buf.append(' ');
}
for (int i = 0, size = features.size(); i < size; i++) {
if (i != 0) {
buf.append(' ');
}
Pair<Integer, Double> fv = features.get(i);
buf.append(fv.getKey().intValue());
buf.append(':');
buf.append(fv.getValue().doubleValue());
}
return buf.toString();
}
@Nonnull
public static Pair<Integer, Double> parse(@Nonnull final String fv,
@Nonnegative final int numFeatures) throws UDFArgumentException {
final int headPos = fv.indexOf(':');
if (headPos == -1) {
if (NumberUtils.isDigits(fv)) {
final int f;
try {
f = Integer.parseInt(fv);
} catch (NumberFormatException e) {
throw new UDFArgumentException("Invalid feature value: " + fv);
}
return new Pair<>(f, 1.d);
} else {
return new Pair<>(mhash(fv, numFeatures), 1.d);
}
} else {
final int tailPos = fv.lastIndexOf(':');
if (headPos != tailPos) {
throw new UDFArgumentException("Unsupported feature format: " + fv);
}
String f = fv.substring(0, headPos);
String v = fv.substring(headPos + 1);
final double d;
try {
d = Double.parseDouble(v);
} catch (NumberFormatException e) {
throw new UDFArgumentException("Invalid feature value: " + fv);
}
if (NumberUtils.isDigits(f)) {
final int i;
try {
i = Integer.parseInt(f);
} catch (NumberFormatException e) {
throw new UDFArgumentException("Invalid feature value: " + fv);
}
return new Pair<>(i, d);
} else {
return new Pair<>(mhash(f, numFeatures), d);
}
}
}
private static int mhash(@Nonnull final String word, final int numFeatures) {
int r = MurmurHash3.murmurhash3_x86_32(word, 0, word.length(), 0x9747b28c) % numFeatures;
if (r < 0) {
r += numFeatures;
}
return r + 1;
}
private static final Comparator<Pair<Integer, Double>> comparator =
new Comparator<Pair<Integer, Double>>() {
@Override
public int compare(Pair<Integer, Double> l, Pair<Integer, Double> r) {
return l.getKey().compareTo(r.getKey());
}
};
@Override
public String getDisplayString(String[] args) {
return "to_libsvm_format( " + StringUtils.join(args, ',') + " )";
}
}