| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| package hivemall.factorization.fm; |
| |
| import static org.apache.hadoop.hive.ql.util.JavaDataModel.JAVA64_ARRAY_META; |
| import static org.apache.hadoop.hive.ql.util.JavaDataModel.JAVA64_REF; |
| import static org.apache.hadoop.hive.ql.util.JavaDataModel.PRIMITIVES2; |
| |
| import hivemall.utils.hadoop.HiveUtils; |
| import hivemall.utils.hadoop.WritableUtils; |
| |
| import java.util.ArrayList; |
| |
| import javax.annotation.Nonnull; |
| import javax.annotation.Nullable; |
| |
| import org.apache.hadoop.hive.ql.exec.Description; |
| import org.apache.hadoop.hive.ql.exec.UDFArgumentException; |
| import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; |
| import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; |
| import org.apache.hadoop.hive.ql.metadata.HiveException; |
| import org.apache.hadoop.hive.ql.parse.SemanticException; |
| import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; |
| import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; |
| import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer; |
| import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationType; |
| import org.apache.hadoop.hive.serde2.io.DoubleWritable; |
| import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray; |
| import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; |
| import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; |
| import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; |
| import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; |
| import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; |
| import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector; |
| import org.apache.hadoop.hive.serde2.objectinspector.StructField; |
| import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; |
| import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; |
| import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; |
| import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableDoubleObjectInspector; |
| import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; |
| import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; |
| |
| @Description(name = "fm_predict", |
| value = "_FUNC_(Float Wj, array<float> Vjf, float Xj) - Returns a prediction value in Double") |
| public final class FMPredictGenericUDAF extends AbstractGenericUDAFResolver { |
| |
| public FMPredictGenericUDAF() { |
| super(); |
| } |
| |
| @Override |
| public Evaluator getEvaluator(TypeInfo[] typeInfo) throws SemanticException { |
| if (typeInfo.length != 3) { |
| throw new UDFArgumentLengthException( |
| "Expected argument length is 3 but given argument length was " + typeInfo.length); |
| } |
| if (!HiveUtils.isNumberTypeInfo(typeInfo[0])) { |
| throw new UDFArgumentTypeException(0, |
| "Number type is expected for the first argument Wj: " + typeInfo[0].getTypeName()); |
| } |
| if (typeInfo[1].getCategory() != Category.LIST) { |
| throw new UDFArgumentTypeException(1, |
| "List type is expected for the second argument Vjf: " + typeInfo[1].getTypeName()); |
| } |
| ListTypeInfo typeInfo1 = (ListTypeInfo) typeInfo[1]; |
| if (!HiveUtils.isNumberTypeInfo(typeInfo1.getListElementTypeInfo())) { |
| throw new UDFArgumentTypeException(1, |
| "Number type is expected for the element type of list Vjf: " |
| + typeInfo1.getTypeName()); |
| } |
| if (!HiveUtils.isNumberTypeInfo(typeInfo[2])) { |
| throw new UDFArgumentTypeException(2, |
| "Number type is expected for the third argument Xj: " + typeInfo[2].getTypeName()); |
| } |
| return new Evaluator(); |
| } |
| |
| public static class Evaluator extends GenericUDAFEvaluator { |
| |
| // input OI |
| private PrimitiveObjectInspector wOI; |
| private ListObjectInspector vOI; |
| private PrimitiveObjectInspector vElemOI; |
| private PrimitiveObjectInspector xOI; |
| |
| // merge OI |
| private StructObjectInspector internalMergeOI; |
| private StructField retField, sumVjXjField, sumV2X2Field; |
| private WritableDoubleObjectInspector retOI; |
| private StandardListObjectInspector sumVjXjOI, sumV2X2OI; |
| |
| public Evaluator() {} |
| |
| @Override |
| public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException { |
| assert (parameters.length == 3); |
| super.init(mode, parameters); |
| |
| // initialize input |
| if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data |
| this.wOI = HiveUtils.asDoubleCompatibleOI(parameters, 0); |
| this.vOI = HiveUtils.asListOI(parameters, 1); |
| this.vElemOI = HiveUtils.asDoubleCompatibleOI(vOI.getListElementObjectInspector()); |
| this.xOI = HiveUtils.asDoubleCompatibleOI(parameters, 2); |
| } else {// from partial aggregation |
| StructObjectInspector soi = (StructObjectInspector) parameters[0]; |
| this.internalMergeOI = soi; |
| this.retField = soi.getStructFieldRef("ret"); |
| this.sumVjXjField = soi.getStructFieldRef("sumVjXj"); |
| this.sumV2X2Field = soi.getStructFieldRef("sumV2X2"); |
| this.retOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; |
| this.sumVjXjOI = ObjectInspectorFactory.getStandardListObjectInspector( |
| PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); |
| this.sumV2X2OI = ObjectInspectorFactory.getStandardListObjectInspector( |
| PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); |
| } |
| |
| // initialize output |
| final ObjectInspector outputOI; |
| if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial |
| outputOI = internalMergeOI(); |
| } else { |
| outputOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; |
| } |
| return outputOI; |
| } |
| |
| private static StructObjectInspector internalMergeOI() { |
| ArrayList<String> fieldNames = new ArrayList<String>(); |
| ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); |
| |
| fieldNames.add("ret"); |
| fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); |
| fieldNames.add("sumVjXj"); |
| fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector( |
| PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)); |
| fieldNames.add("sumV2X2"); |
| fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector( |
| PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)); |
| |
| return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); |
| } |
| |
| @Override |
| public FMPredictAggregationBuffer getNewAggregationBuffer() throws HiveException { |
| FMPredictAggregationBuffer buf = new FMPredictAggregationBuffer(); |
| buf.reset(); |
| return buf; |
| } |
| |
| @Override |
| public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg) |
| throws HiveException { |
| FMPredictAggregationBuffer buf = (FMPredictAggregationBuffer) agg; |
| buf.reset(); |
| } |
| |
| @Override |
| public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg, |
| Object[] parameters) throws HiveException { |
| if (parameters[0] == null) { |
| return; |
| } |
| FMPredictAggregationBuffer buf = (FMPredictAggregationBuffer) agg; |
| |
| double w = PrimitiveObjectInspectorUtils.getDouble(parameters[0], wOI); |
| if (parameters[1] == null || /* for TD */vOI.getListLength(parameters[1]) == 0) {// Vif was null |
| buf.iterate(w); |
| } else { |
| if (parameters[2] == null) { |
| throw new UDFArgumentException("The third argument Xj must not be null"); |
| } |
| double x = PrimitiveObjectInspectorUtils.getDouble(parameters[2], xOI); |
| buf.iterate(w, x, parameters[1], vOI, vElemOI); |
| } |
| } |
| |
| @Override |
| public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg) |
| throws HiveException { |
| FMPredictAggregationBuffer buf = (FMPredictAggregationBuffer) agg; |
| |
| final Object[] partialResult = new Object[3]; |
| partialResult[0] = new DoubleWritable(buf.ret); |
| if (buf.sumVjXj != null) { |
| partialResult[1] = WritableUtils.toWritableList(buf.sumVjXj); |
| partialResult[2] = WritableUtils.toWritableList(buf.sumV2X2); |
| } |
| return partialResult; |
| } |
| |
| @Override |
| public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial) |
| throws HiveException { |
| if (partial == null) { |
| return; |
| } |
| FMPredictAggregationBuffer buf = (FMPredictAggregationBuffer) agg; |
| |
| Object o1 = internalMergeOI.getStructFieldData(partial, retField); |
| double ret = retOI.get(o1); |
| |
| Object sumVjXj = internalMergeOI.getStructFieldData(partial, sumVjXjField); |
| Object sumV2X2 = internalMergeOI.getStructFieldData(partial, sumV2X2Field); |
| |
| // -------------------------------------------------------------- |
| // [workaround] |
| // java.lang.ClassCastException: org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray |
| // cannot be cast to [Ljava.lang.Object; |
| if (sumVjXj instanceof LazyBinaryArray) { |
| sumVjXj = ((LazyBinaryArray) sumVjXj).getList(); |
| } |
| if (sumV2X2 instanceof LazyBinaryArray) { |
| sumV2X2 = ((LazyBinaryArray) sumV2X2).getList(); |
| } |
| // -------------------------------------------------------------- |
| |
| buf.merge(ret, sumVjXj, sumV2X2, sumVjXjOI, sumV2X2OI); |
| } |
| |
| @Override |
| public DoubleWritable terminate(@SuppressWarnings("deprecation") AggregationBuffer agg) |
| throws HiveException { |
| FMPredictAggregationBuffer buf = (FMPredictAggregationBuffer) agg; |
| double predict = buf.getPrediction(); |
| return new DoubleWritable(predict); |
| } |
| |
| } |
| |
| @AggregationType(estimable = true) |
| public static class FMPredictAggregationBuffer extends AbstractAggregationBuffer { |
| |
| private double ret; |
| private double[] sumVjXj; |
| private double[] sumV2X2; |
| |
| FMPredictAggregationBuffer() { |
| super(); |
| } |
| |
| void reset() { |
| this.ret = 0.d; |
| this.sumVjXj = null; |
| this.sumV2X2 = null; |
| } |
| |
| void iterate(double Wj) { |
| this.ret += Wj; |
| } |
| |
| void iterate(final double Wj, final double Xj, @Nonnull final Object Vif, |
| @Nonnull final ListObjectInspector vOI, |
| @Nonnull final PrimitiveObjectInspector vElemOI) throws HiveException { |
| this.ret += (Wj * Xj); |
| |
| final int factors = vOI.getListLength(Vif); |
| if (factors < 1) { |
| throw new HiveException("# of Factor should be more than 0: " + factors); |
| } |
| |
| if (sumVjXj == null) { |
| this.sumVjXj = new double[factors]; |
| this.sumV2X2 = new double[factors]; |
| } else if (sumVjXj.length != factors) { |
| throw new HiveException("Mismatch in the number of factors"); |
| } |
| |
| for (int f = 0; f < factors; f++) { |
| Object o = vOI.getListElement(Vif, f); |
| if (o == null) { |
| throw new HiveException("Vj" + f + " should not be null"); |
| } |
| double v = PrimitiveObjectInspectorUtils.getDouble(o, vElemOI); |
| double vx = v * Xj; |
| |
| sumVjXj[f] += vx; |
| sumV2X2[f] += (vx * vx); |
| } |
| } |
| |
| void merge(final double o_ret, @Nullable final Object o_sumVjXj, |
| @Nullable final Object o_sumV2X2, |
| @Nonnull final StandardListObjectInspector sumVjXjOI, |
| @Nonnull final StandardListObjectInspector sumV2X2OI) throws HiveException { |
| this.ret += o_ret; |
| if (o_sumVjXj == null) { |
| return; |
| } |
| |
| if (o_sumV2X2 == null) {//sanity check |
| throw new HiveException("o_sumV2X2 should not be null"); |
| } |
| |
| final int factors = sumVjXjOI.getListLength(o_sumVjXj); |
| if (sumVjXj == null) { |
| this.sumVjXj = new double[factors]; |
| this.sumV2X2 = new double[factors]; |
| } else if (sumVjXj.length != factors) {//sanity check |
| throw new HiveException("Mismatch in the number of factors"); |
| } |
| |
| final WritableDoubleObjectInspector doubleOI = |
| PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; |
| for (int f = 0; f < factors; f++) { |
| Object o1 = sumVjXjOI.getListElement(o_sumVjXj, f); |
| Object o2 = sumV2X2OI.getListElement(o_sumV2X2, f); |
| double d1 = doubleOI.get(o1); |
| double d2 = doubleOI.get(o2); |
| sumVjXj[f] += d1; |
| sumV2X2[f] += d2; |
| } |
| } |
| |
| double getPrediction() { |
| double predict = this.ret; |
| if (sumVjXj != null) { |
| final int factors = sumVjXj.length; |
| for (int f = 0; f < factors; f++) { |
| double d1 = sumVjXj[f]; |
| double d2 = sumV2X2[f]; |
| predict += 0.5d * (d1 * d1 - d2); |
| } |
| } |
| return predict; |
| } |
| |
| @Override |
| public int estimate() { |
| if (sumVjXj == null) { |
| return PRIMITIVES2 + 2 * JAVA64_REF; |
| } else { |
| // model.array() = JAVA64_ARRAY_META + JAVA64_REF |
| return PRIMITIVES2 + 2 * (JAVA64_ARRAY_META + PRIMITIVES2 * sumVjXj.length); |
| } |
| } |
| } |
| |
| } |