Support AUC evaluation UDAF
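
Adds an `auc` UDAF with two evaluators: a classification evaluator that
computes ROC-AUC from (score, label) rows, and a ranking evaluator that
averages per-record AUC over ranked item lists.

A minimal usage sketch for the classification form (table and column names
`predictions`, `prob`, and `label` are illustrative; the evaluator expects
rows sorted by score in descending order, as exercised in AUCUDAFTest):

    SELECT auc(prob, label) AS auc
    FROM (
        SELECT prob, label
        FROM predictions
        ORDER BY prob DESC
    ) t;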
diff --git a/core/src/main/java/hivemall/evaluation/AUCUDAF.java b/core/src/main/java/hivemall/evaluation/AUCUDAF.java
new file mode 100644
index 0000000..eb6a1a9
--- /dev/null
+++ b/core/src/main/java/hivemall/evaluation/AUCUDAF.java
@@ -0,0 +1,589 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.evaluation;
+
+import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
+import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.javaLongObjectInspector;
+import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
+import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableLongObjectInspector;
+import hivemall.utils.hadoop.HiveUtils;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+import javax.annotation.Nonnull;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructField;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector;
+import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.io.LongWritable;
+
+@SuppressWarnings("deprecation")
+@Description(name = "auc",
+ value = "_FUNC_(array rankItems | double score, array correctItems | int label "
+ + "[, const int recommendSize = rankItems.size ])" + " - Returns AUC")
+public final class AUCUDAF extends AbstractGenericUDAFResolver {
+
+ @Override
+ public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfo) throws SemanticException {
+ if (typeInfo.length != 2 && typeInfo.length != 3) {
+ throw new UDFArgumentTypeException(typeInfo.length - 1,
+ "_FUNC_ takes two or three arguments");
+ }
+
+ if (HiveUtils.isNumberTypeInfo(typeInfo[0]) && HiveUtils.isIntegerTypeInfo(typeInfo[1])) {
+ return new ClassificationEvaluator();
+ } else {
+ ListTypeInfo arg1type = HiveUtils.asListTypeInfo(typeInfo[0]);
+ if (!HiveUtils.isPrimitiveTypeInfo(arg1type.getListElementTypeInfo())) {
+ throw new UDFArgumentTypeException(0,
+ "The first argument `array rankItems` is invalid form: " + typeInfo[0]);
+ }
+
+ ListTypeInfo arg2type = HiveUtils.asListTypeInfo(typeInfo[1]);
+ if (!HiveUtils.isPrimitiveTypeInfo(arg2type.getListElementTypeInfo())) {
+ throw new UDFArgumentTypeException(1,
+ "The second argument `array correctItems` is invalid form: " + typeInfo[1]);
+ }
+
+ return new RankingEvaluator();
+ }
+ }
+
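+ /**
+ * Computes ROC-AUC incrementally from (score, label) rows. Rows are expected to
+ * arrive sorted by score in descending order.
+ */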
+ public static class ClassificationEvaluator extends GenericUDAFEvaluator {
+
+ private PrimitiveObjectInspector scoreOI;
+ private PrimitiveObjectInspector labelOI;
+
+ private StructObjectInspector internalMergeOI;
+ private StructField indexScoreField;
+ private StructField areaField;
+ private StructField fpField;
+ private StructField tpField;
+ private StructField fpPrevField;
+ private StructField tpPrevField;
+ private StructField areaPartialMapField;
+ private StructField fpPartialMapField;
+ private StructField tpPartialMapField;
+ private StructField fpPrevPartialMapField;
+ private StructField tpPrevPartialMapField;
+
+ public ClassificationEvaluator() {}
+
+ @Override
+ public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException {
+ assert (parameters.length == 2 || parameters.length == 3) : parameters.length;
+ super.init(mode, parameters);
+
+ // initialize input
+ if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data
+ this.scoreOI = HiveUtils.asDoubleCompatibleOI(parameters[0]);
+ this.labelOI = HiveUtils.asIntegerOI(parameters[1]);
+ } else {// from partial aggregation
+ StructObjectInspector soi = (StructObjectInspector) parameters[0];
+ this.internalMergeOI = soi;
+ this.indexScoreField = soi.getStructFieldRef("indexScore");
+ this.areaField = soi.getStructFieldRef("area");
+ this.fpField = soi.getStructFieldRef("fp");
+ this.tpField = soi.getStructFieldRef("tp");
+ this.fpPrevField = soi.getStructFieldRef("fpPrev");
+ this.tpPrevField = soi.getStructFieldRef("tpPrev");
+ this.areaPartialMapField = soi.getStructFieldRef("areaPartialMap");
+ this.fpPartialMapField = soi.getStructFieldRef("fpPartialMap");
+ this.tpPartialMapField = soi.getStructFieldRef("tpPartialMap");
+ this.fpPrevPartialMapField = soi.getStructFieldRef("fpPrevPartialMap");
+ this.tpPrevPartialMapField = soi.getStructFieldRef("tpPrevPartialMap");
+ }
+
+ // initialize output
+ final ObjectInspector outputOI;
+ if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial
+ outputOI = internalMergeOI();
+ } else {// terminate
+ outputOI = writableDoubleObjectInspector;
+ }
+ return outputOI;
+ }
+
+ private static StructObjectInspector internalMergeOI() {
+ ArrayList<String> fieldNames = new ArrayList<String>();
+ ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+
+ fieldNames.add("indexScore");
+ fieldOIs.add(writableDoubleObjectInspector);
+ fieldNames.add("area");
+ fieldOIs.add(writableDoubleObjectInspector);
+ fieldNames.add("fp");
+ fieldOIs.add(writableLongObjectInspector);
+ fieldNames.add("tp");
+ fieldOIs.add(writableLongObjectInspector);
+ fieldNames.add("fpPrev");
+ fieldOIs.add(writableLongObjectInspector);
+ fieldNames.add("tpPrev");
+ fieldOIs.add(writableLongObjectInspector);
+
+ MapObjectInspector areaPartialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(
+ javaDoubleObjectInspector, javaDoubleObjectInspector);
+ fieldNames.add("areaPartialMap");
+ fieldOIs.add(areaPartialMapOI);
+
+ MapObjectInspector fpPartialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(
+ javaDoubleObjectInspector, javaLongObjectInspector);
+ fieldNames.add("fpPartialMap");
+ fieldOIs.add(fpPartialMapOI);
+
+ MapObjectInspector tpPartialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(
+ javaDoubleObjectInspector, javaLongObjectInspector);
+ fieldNames.add("tpPartialMap");
+ fieldOIs.add(tpPartialMapOI);
+
+ MapObjectInspector fpPrevPartialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(
+ javaDoubleObjectInspector, javaLongObjectInspector);
+ fieldNames.add("fpPrevPartialMap");
+ fieldOIs.add(fpPrevPartialMapOI);
+
+ MapObjectInspector tpPrevPartialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(
+ javaDoubleObjectInspector, javaLongObjectInspector);
+ fieldNames.add("tpPrevPartialMap");
+ fieldOIs.add(tpPrevPartialMapOI);
+
+ return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+ }
+
+ @Override
+ public AggregationBuffer getNewAggregationBuffer() throws HiveException {
+ AggregationBuffer myAggr = new ClassificationAUCAggregationBuffer();
+ reset(myAggr);
+ return myAggr;
+ }
+
+ @Override
+ public void reset(AggregationBuffer agg) throws HiveException {
+ ClassificationAUCAggregationBuffer myAggr = (ClassificationAUCAggregationBuffer) agg;
+ myAggr.reset();
+ }
+
+ @Override
+ public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
+ ClassificationAUCAggregationBuffer myAggr = (ClassificationAUCAggregationBuffer) agg;
+
+ if (parameters[0] == null) {
+ return;
+ }
+ if (parameters[1] == null) {
+ return;
+ }
+
+ double score = HiveUtils.getDouble(parameters[0], scoreOI);
+ if (score < 0.0d || score > 1.0d) {
+ throw new UDFArgumentException("score value MUST be in range [0,1]: " + score);
+ }
+
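+ // accept both {0, 1} and {-1, 1} label conventions by mapping -1 to 0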
+ int label = PrimitiveObjectInspectorUtils.getInt(parameters[1], labelOI);
+ if (label == -1) {
+ label = 0;
+ } else if (label != 0 && label != 1) {
+ throw new UDFArgumentException("label MUST be 0/1 or -1/1: " + label);
+ }
+
+ myAggr.iterate(score, label);
+ }
+
+ @Override
+ public Object terminatePartial(AggregationBuffer agg) throws HiveException {
+ ClassificationAUCAggregationBuffer myAggr = (ClassificationAUCAggregationBuffer) agg;
+
+ Object[] partialResult = new Object[11];
+ partialResult[0] = new DoubleWritable(myAggr.indexScore);
+ partialResult[1] = new DoubleWritable(myAggr.area);
+ partialResult[2] = new LongWritable(myAggr.fp);
+ partialResult[3] = new LongWritable(myAggr.tp);
+ partialResult[4] = new LongWritable(myAggr.fpPrev);
+ partialResult[5] = new LongWritable(myAggr.tpPrev);
+ partialResult[6] = myAggr.areaPartialMap;
+ partialResult[7] = myAggr.fpPartialMap;
+ partialResult[8] = myAggr.tpPartialMap;
+ partialResult[9] = myAggr.fpPrevPartialMap;
+ partialResult[10] = myAggr.tpPrevPartialMap;
+
+ return partialResult;
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public void merge(AggregationBuffer agg, Object partial) throws HiveException {
+ if (partial == null) {
+ return;
+ }
+
+ Object indexScoreObj = internalMergeOI.getStructFieldData(partial, indexScoreField);
+ Object areaObj = internalMergeOI.getStructFieldData(partial, areaField);
+ Object fpObj = internalMergeOI.getStructFieldData(partial, fpField);
+ Object tpObj = internalMergeOI.getStructFieldData(partial, tpField);
+ Object fpPrevObj = internalMergeOI.getStructFieldData(partial, fpPrevField);
+ Object tpPrevObj = internalMergeOI.getStructFieldData(partial, tpPrevField);
+ Object areaPartialMapObj = internalMergeOI.getStructFieldData(partial,
+ areaPartialMapField);
+ Object fpPartialMapObj = internalMergeOI.getStructFieldData(partial, fpPartialMapField);
+ Object tpPartialMapObj = internalMergeOI.getStructFieldData(partial, tpPartialMapField);
+ Object fpPrevPartialMapObj = internalMergeOI.getStructFieldData(partial,
+ fpPrevPartialMapField);
+ Object tpPrevPartialMapObj = internalMergeOI.getStructFieldData(partial,
+ tpPrevPartialMapField);
+
+ double indexScore = writableDoubleObjectInspector.get(indexScoreObj);
+ double area = writableDoubleObjectInspector.get(areaObj);
+ long fp = writableLongObjectInspector.get(fpObj);
+ long tp = writableLongObjectInspector.get(tpObj);
+ long fpPrev = writableLongObjectInspector.get(fpPrevObj);
+ long tpPrev = writableLongObjectInspector.get(tpPrevObj);
+
+ StandardMapObjectInspector ddMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(
+ javaDoubleObjectInspector, javaDoubleObjectInspector);
+ StandardMapObjectInspector dlMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(
+ javaDoubleObjectInspector, javaLongObjectInspector);
+
+ Map<Double, Double> areaPartialMap = (Map<Double, Double>) ddMapOI.getMap(HiveUtils.castLazyBinaryObject(areaPartialMapObj));
+ Map<Double, Long> fpPartialMap = (Map<Double, Long>) dlMapOI.getMap(HiveUtils.castLazyBinaryObject(fpPartialMapObj));
+ Map<Double, Long> tpPartialMap = (Map<Double, Long>) dlMapOI.getMap(HiveUtils.castLazyBinaryObject(tpPartialMapObj));
+ Map<Double, Long> fpPrevPartialMap = (Map<Double, Long>) dlMapOI.getMap(HiveUtils.castLazyBinaryObject(fpPrevPartialMapObj));
+ Map<Double, Long> tpPrevPartialMap = (Map<Double, Long>) dlMapOI.getMap(HiveUtils.castLazyBinaryObject(tpPrevPartialMapObj));
+
+ ClassificationAUCAggregationBuffer myAggr = (ClassificationAUCAggregationBuffer) agg;
+ myAggr.merge(indexScore, area, fp, tp, fpPrev, tpPrev, areaPartialMap, fpPartialMap,
+ tpPartialMap, fpPrevPartialMap, tpPrevPartialMap);
+ }
+
+ @Override
+ public DoubleWritable terminate(AggregationBuffer agg) throws HiveException {
+ ClassificationAUCAggregationBuffer myAggr = (ClassificationAUCAggregationBuffer) agg;
+ double result = myAggr.get();
+ return new DoubleWritable(result);
+ }
+
+ }
+
+ public static class ClassificationAUCAggregationBuffer implements
+ GenericUDAFEvaluator.AggregationBuffer {
+
+ double area, scorePrev, indexScore;
+ long fp, tp, fpPrev, tpPrev;
+ Map<Double, Double> areaPartialMap;
+ Map<Double, Long> fpPartialMap, tpPartialMap, fpPrevPartialMap, tpPrevPartialMap;
+
+ public ClassificationAUCAggregationBuffer() {
+ super();
+ }
+
+ void reset() {
+ this.area = 0.d;
+ this.scorePrev = Double.POSITIVE_INFINITY;
+ this.indexScore = 0.d;
+ this.fp = 0;
+ this.tp = 0;
+ this.fpPrev = 0;
+ this.tpPrev = 0;
+ this.areaPartialMap = new HashMap<Double, Double>();
+ this.fpPartialMap = new HashMap<Double, Long>();
+ this.tpPartialMap = new HashMap<Double, Long>();
+ this.fpPrevPartialMap = new HashMap<Double, Long>();
+ this.tpPrevPartialMap = new HashMap<Double, Long>();
+ }
+
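+ /**
+ * Merges another partial result: its AUC computation is finalized by adding the
+ * last trapezoid, and the result is stored keyed by its index score (the maximum
+ * score seen by that partition) so that get() can stitch partitions together in
+ * descending-score order.
+ */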
+ void merge(double o_indexScore, double o_area, long o_fp, long o_tp, long o_fpPrev,
+ long o_tpPrev, Map<Double, Double> o_areaPartialMap,
+ Map<Double, Long> o_fpPartialMap, Map<Double, Long> o_tpPartialMap,
+ Map<Double, Long> o_fpPrevPartialMap, Map<Double, Long> o_tpPrevPartialMap) {
+
+ // merge past partial results
+ areaPartialMap.putAll(o_areaPartialMap);
+ fpPartialMap.putAll(o_fpPartialMap);
+ tpPartialMap.putAll(o_tpPartialMap);
+ fpPrevPartialMap.putAll(o_fpPrevPartialMap);
+ tpPrevPartialMap.putAll(o_tpPrevPartialMap);
+
+ // finalize source AUC computation
+ o_area += trapezoidArea(o_fp, o_fpPrev, o_tp, o_tpPrev);
+
+ // store source results
+ areaPartialMap.put(o_indexScore, o_area);
+ fpPartialMap.put(o_indexScore, o_fp);
+ tpPartialMap.put(o_indexScore, o_tp);
+ fpPrevPartialMap.put(o_indexScore, o_fpPrev);
+ tpPrevPartialMap.put(o_indexScore, o_tpPrev);
+ }
+
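+ /**
+ * Reconstructs the global (unnormalized) area from the per-partition partial
+ * results, visiting partitions in descending order of index score. Each
+ * partition's ROC segment starts at the origin, so a rectangle of width fp_k and
+ * height tpAccum lifts it to where the preceding partitions ended. The sum is
+ * finally normalized by (total TP * total FP).
+ */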
+ double get() throws HiveException {
+ // store self results
+ areaPartialMap.put(indexScore, area);
+ fpPartialMap.put(indexScore, fp);
+ tpPartialMap.put(indexScore, tp);
+ fpPrevPartialMap.put(indexScore, fpPrev);
+ tpPrevPartialMap.put(indexScore, tpPrev);
+
+ SortedMap<Double, Double> areaPartialSortedMap = new TreeMap<Double, Double>(
+ Collections.reverseOrder());
+ areaPartialSortedMap.putAll(areaPartialMap);
+
+ // initialize with leftmost partial result
+ double firstKey = areaPartialSortedMap.firstKey();
+ double res = areaPartialSortedMap.get(firstKey);
+ long fpAccum = fpPartialMap.get(firstKey);
+ long tpAccum = tpPartialMap.get(firstKey);
+ long fpPrevAccum = fpPrevPartialMap.get(firstKey);
+ long tpPrevAccum = tpPrevPartialMap.get(firstKey);
+
+ // Merge from left (larger score) to right (smaller score)
+ for (double k : areaPartialSortedMap.keySet()) {
+ if (k == firstKey) { // variables are already initialized with the leftmost partial result
+ continue;
+ }
+
+ // sum up partial area
+ res += areaPartialSortedMap.get(k);
+
+ // lift this partition's ROC segment by the missing rectangle (width fp_k, height tpAccum)
+ res += trapezoidArea(0, fpPartialMap.get(k), tpAccum, tpAccum);
+
+ // sum up (prev) TP/FP count
+ fpPrevAccum = fpAccum + fpPrevPartialMap.get(k);
+ tpPrevAccum = tpAccum + tpPrevPartialMap.get(k);
+ fpAccum = fpAccum + fpPartialMap.get(k);
+ tpAccum = tpAccum + tpPartialMap.get(k);
+ }
+
+ if (tpAccum == 0 || fpAccum == 0) {
+ throw new HiveException(
+ "AUC score is not defined because there is only one class in `label`.");
+ }
+
+ // finalize by adding a trapezoid based on the last tp/fp counts
+ res += trapezoidArea(fpAccum, fpPrevAccum, tpAccum, tpPrevAccum);
+
+ return res / (tpAccum * fpAccum); // normalize by (total TP * total FP)
+ }
+
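+ /**
+ * Processes one (score, label) row. Whenever the score changes, the trapezoid
+ * under the ROC segment between (fpPrev, tpPrev) and (fp, tp) is accumulated;
+ * tied scores are thereby covered by a single trapezoid.
+ */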
+ void iterate(double score, int label) {
+ if (score != scorePrev) {
+ if (scorePrev == Double.POSITIVE_INFINITY) {
+ // store maximum score as an index
+ indexScore = score;
+ }
+ area += trapezoidArea(fp, fpPrev, tp, tpPrev); // under (fp, tp)-(fpPrev, tpPrev)
+ scorePrev = score;
+ fpPrev = fp;
+ tpPrev = tp;
+ }
+ if (label == 1) {
+ tp++; // eventually equals the total number of positive samples
+ } else {
+ fp++; // eventually equals the total number of negative samples
+ }
+ }
+
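+ // area of a trapezoid with parallel sides y1 and y2 and width |x1 - x2|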
+ private double trapezoidArea(double x1, double x2, double y1, double y2) {
+ double base = Math.abs(x1 - x2);
+ double height = (y1 + y2) / 2.d;
+ return base * height;
+ }
+ }
+
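+ /**
+ * Computes the mean of per-record AUC over ranked item lists
+ * (see BinaryResponsesMeasures#AUC).
+ */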
+ public static class RankingEvaluator extends GenericUDAFEvaluator {
+
+ private ListObjectInspector recommendListOI;
+ private ListObjectInspector truthListOI;
+ private WritableIntObjectInspector recommendSizeOI;
+
+ private StructObjectInspector internalMergeOI;
+ private StructField countField;
+ private StructField sumField;
+
+ public RankingEvaluator() {}
+
+ @Override
+ public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException {
+ assert (parameters.length == 2 || parameters.length == 3) : parameters.length;
+ super.init(mode, parameters);
+
+ // initialize input
+ if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data
+ this.recommendListOI = (ListObjectInspector) parameters[0];
+ this.truthListOI = (ListObjectInspector) parameters[1];
+ if (parameters.length == 3) {
+ this.recommendSizeOI = (WritableIntObjectInspector) parameters[2];
+ }
+ } else {// from partial aggregation
+ StructObjectInspector soi = (StructObjectInspector) parameters[0];
+ this.internalMergeOI = soi;
+ this.countField = soi.getStructFieldRef("count");
+ this.sumField = soi.getStructFieldRef("sum");
+ }
+
+ // initialize output
+ final ObjectInspector outputOI;
+ if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial
+ outputOI = internalMergeOI();
+ } else {// terminate
+ outputOI = writableDoubleObjectInspector;
+ }
+ return outputOI;
+ }
+
+ private static StructObjectInspector internalMergeOI() {
+ ArrayList<String> fieldNames = new ArrayList<String>();
+ ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+
+ fieldNames.add("sum");
+ fieldOIs.add(writableDoubleObjectInspector);
+ fieldNames.add("count");
+ fieldOIs.add(writableLongObjectInspector);
+
+ return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+ }
+
+ @Override
+ public AggregationBuffer getNewAggregationBuffer() throws HiveException {
+ AggregationBuffer myAggr = new RankingAUCAggregationBuffer();
+ reset(myAggr);
+ return myAggr;
+ }
+
+ @Override
+ public void reset(AggregationBuffer agg) throws HiveException {
+ RankingAUCAggregationBuffer myAggr = (RankingAUCAggregationBuffer) agg;
+ myAggr.reset();
+ }
+
+ @Override
+ public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
+ RankingAUCAggregationBuffer myAggr = (RankingAUCAggregationBuffer) agg;
+
+ List<?> recommendList = recommendListOI.getList(parameters[0]);
+ if (recommendList == null) {
+ recommendList = Collections.emptyList();
+ }
+ List<?> truthList = truthListOI.getList(parameters[1]);
+ if (truthList == null) {
+ return;
+ }
+
+ int recommendSize = recommendList.size();
+ if (parameters.length == 3) {
+ recommendSize = recommendSizeOI.get(parameters[2]);
+ }
+ if (recommendSize < 0 || recommendSize > recommendList.size()) {
+ throw new UDFArgumentException(
+ "The third argument `int recommendSize` must be in [0, " + recommendList.size()
+ + "]");
+ }
+
+ myAggr.iterate(recommendList, truthList, recommendSize);
+ }
+
+ @Override
+ public Object terminatePartial(AggregationBuffer agg) throws HiveException {
+ RankingAUCAggregationBuffer myAggr = (RankingAUCAggregationBuffer) agg;
+
+ Object[] partialResult = new Object[2];
+ partialResult[0] = new DoubleWritable(myAggr.sum);
+ partialResult[1] = new LongWritable(myAggr.count);
+ return partialResult;
+ }
+
+ @Override
+ public void merge(AggregationBuffer agg, Object partial) throws HiveException {
+ if (partial == null) {
+ return;
+ }
+
+ Object sumObj = internalMergeOI.getStructFieldData(partial, sumField);
+ Object countObj = internalMergeOI.getStructFieldData(partial, countField);
+ double sum = writableDoubleObjectInspector.get(sumObj);
+ long count = writableLongObjectInspector.get(countObj);
+
+ RankingAUCAggregationBuffer myAggr = (RankingAUCAggregationBuffer) agg;
+ myAggr.merge(sum, count);
+ }
+
+ @Override
+ public DoubleWritable terminate(AggregationBuffer agg) throws HiveException {
+ RankingAUCAggregationBuffer myAggr = (RankingAUCAggregationBuffer) agg;
+ double result = myAggr.get();
+ return new DoubleWritable(result);
+ }
+
+ }
+
+ public static class RankingAUCAggregationBuffer implements
+ GenericUDAFEvaluator.AggregationBuffer {
+
+ double sum;
+ long count;
+
+ public RankingAUCAggregationBuffer() {
+ super();
+ }
+
+ void reset() {
+ this.sum = 0.d;
+ this.count = 0;
+ }
+
+ void merge(double o_sum, long o_count) {
+ sum += o_sum;
+ count += o_count;
+ }
+
+ double get() {
+ if (count == 0) {
+ return 0.d;
+ }
+ return sum / count;
+ }
+
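+ // accumulates per-record AUC; get() returns the mean over all records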
+ void iterate(@Nonnull List<?> recommendList, @Nonnull List<?> truthList,
+ int recommendSize) {
+ sum += BinaryResponsesMeasures.AUC(recommendList, truthList, recommendSize);
+ count++;
+ }
+ }
+
+}
diff --git a/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java b/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java
index 09ae68c..4338d79 100644
--- a/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java
+++ b/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java
@@ -18,13 +18,16 @@
*/
package hivemall.evaluation;
+import hivemall.utils.lang.Preconditions;
+
import java.util.List;
+import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
/**
* Utility class of various measures.
- *
+ *
* See http://recsyswiki.com/wiki/Discounted_Cumulative_Gain
*/
public final class BinaryResponsesMeasures {
@@ -49,7 +52,7 @@
/**
* Computes the ideal DCG
- *
+ *
* @param n the number of positive items
* @return ideal DCG
*/
@@ -61,4 +64,43 @@
return idcg;
}
+ /**
+ * Computes the area under the ROC curve (AUC)
+ *
+ * @param rankedList a list of ranked item IDs (first item is highest-ranked)
+ * @param groundTruth a collection of positive/correct item IDs
+ * @param recommendSize top-`recommendSize` items in `rankedList` are recommended
+ * @return AUC
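+ *
+ * For example, rankedList=[1, 3, 2], groundTruth=[1, 2], recommendSize=3 yields
+ * 1 correctly ordered (TP, FP) pair out of 2 * 1 = 2 possible pairs, i.e., AUC = 0.5.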
+ */
+ public static double AUC(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth,
+ @Nonnegative final int recommendSize) {
+ Preconditions.checkArgument(recommendSize > 0);
+
+ int nTruePositive = 0, nCorrectPairs = 0;
+
+ // count pairs of items that are ranked in the correct order (i.e., a TP ranked above an FP)
+ final int k = Math.min(rankedList.size(), recommendSize);
+ for (int i = 0; i < k; i++) {
+ Object item_id = rankedList.get(i);
+ if (groundTruth.contains(item_id)) {
+ // # of true positives ranked higher than the i-th recommended item
+ nTruePositive++;
+ } else {
+ // for each FP item, the # of correctly ordered <TP, FP> pairs equals the # of TPs ranked above it
+ nCorrectPairs += nTruePositive;
+ }
+ }
+
+ // # of all possible <TP, FP> pairs
+ int nPairs = nTruePositive * (recommendSize - nTruePositive);
+
+ // if there is no TP or no FP, the metric is not informative; return the chance-level value (AUC = 0.5)
+ if (nPairs == 0) {
+ return 0.5d;
+ }
+
+ // AUC equals the proportion of correctly ordered pairs among all possible <TP, FP> pairs
+ return ((double) nCorrectPairs) / nPairs;
+ }
+
}
diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
index 1cc8607..0cd807b 100644
--- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
@@ -545,6 +545,13 @@
return PrimitiveObjectInspectorUtils.getInt(o, oi);
}
+ public static double getDouble(@Nullable Object o, @Nonnull PrimitiveObjectInspector oi) {
+ if (o == null) {
+ return 0.d;
+ }
+ return PrimitiveObjectInspectorUtils.getDouble(o, oi);
+ }
+
/**
* @return the number of true bits
*/
diff --git a/core/src/test/java/hivemall/evaluation/AUCUDAFTest.java b/core/src/test/java/hivemall/evaluation/AUCUDAFTest.java
new file mode 100644
index 0000000..020d23a
--- /dev/null
+++ b/core/src/test/java/hivemall/evaluation/AUCUDAFTest.java
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.evaluation;
+
+import java.util.ArrayList;
+
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
+import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+public class AUCUDAFTest {
+ AUCUDAF auc;
+ GenericUDAFEvaluator evaluator;
+ ObjectInspector[] inputOIs;
+ ObjectInspector[] partialOI;
+ AUCUDAF.ClassificationAUCAggregationBuffer agg;
+
+ @Before
+ public void setUp() throws Exception {
+ auc = new AUCUDAF();
+
+ inputOIs = new ObjectInspector[] {
+ PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.DOUBLE),
+ PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT)};
+
+ evaluator = auc.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+ ArrayList<String> fieldNames = new ArrayList<String>();
+ ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+ fieldNames.add("indexScore");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+ fieldNames.add("area");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+ fieldNames.add("fp");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+ fieldNames.add("tp");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+ fieldNames.add("fpPrev");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+ fieldNames.add("tpPrev");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+
+ MapObjectInspector areaPartialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(
+ PrimitiveObjectInspectorFactory.writableDoubleObjectInspector,
+ PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+ fieldNames.add("areaPartialMap");
+ fieldOIs.add(areaPartialMapOI);
+
+ MapObjectInspector fpPartialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(
+ PrimitiveObjectInspectorFactory.writableDoubleObjectInspector,
+ PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+ fieldNames.add("fpPartialMap");
+ fieldOIs.add(fpPartialMapOI);
+
+ MapObjectInspector tpPartialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(
+ PrimitiveObjectInspectorFactory.writableDoubleObjectInspector,
+ PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+ fieldNames.add("tpPartialMap");
+ fieldOIs.add(tpPartialMapOI);
+
+ MapObjectInspector fpPrevPartialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(
+ PrimitiveObjectInspectorFactory.writableDoubleObjectInspector,
+ PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+ fieldNames.add("fpPrevPartialMap");
+ fieldOIs.add(fpPrevPartialMapOI);
+
+ MapObjectInspector tpPrevPartialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(
+ PrimitiveObjectInspectorFactory.writableDoubleObjectInspector,
+ PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+ fieldNames.add("tpPrevPartialMap");
+ fieldOIs.add(tpPrevPartialMapOI);
+
+ partialOI = new ObjectInspector[2];
+ partialOI[0] = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+
+ agg = (AUCUDAF.ClassificationAUCAggregationBuffer) evaluator.getNewAggregationBuffer();
+ }
+
+ @Test
+ public void test() throws Exception {
+ // scores must be sorted in descending order
+ final double[] scores = new double[] {0.8, 0.7, 0.5, 0.5, 0.3, 0.2};
+ final int[] labels = new int[] {1, 1, 0, 1, 1, 0};
+
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+ evaluator.reset(agg);
+
+ for (int i = 0; i < scores.length; i++) {
+ evaluator.iterate(agg, new Object[] {scores[i], labels[i]});
+ }
+
+ Assert.assertEquals(0.8125, agg.get(), 1e-5);
+ }
+
+ @Test(expected = HiveException.class)
+ public void testAllTruePositive() throws Exception {
+ final double[] scores = new double[] {0.8, 0.7, 0.5, 0.3, 0.2};
+ final int[] labels = new int[] {1, 1, 1, 1, 1};
+
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+ evaluator.reset(agg);
+
+ for (int i = 0; i < scores.length; i++) {
+ evaluator.iterate(agg, new Object[] {scores[i], labels[i]});
+ }
+
+ // AUC is undefined when all labels are positive
+ agg.get();
+ }
+
+ @Test(expected = HiveException.class)
+ public void testAllFalsePositive() throws Exception {
+ final double[] scores = new double[] {0.8, 0.7, 0.5, 0.3, 0.2};
+ final int[] labels = new int[] {0, 0, 0, 0, 0};
+
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+ evaluator.reset(agg);
+
+ for (int i = 0; i < scores.length; i++) {
+ evaluator.iterate(agg, new Object[] {scores[i], labels[i]});
+ }
+
+ // AUC is undefined when all labels are negative
+ agg.get();
+ }
+
+ @Test
+ public void testMaxAUC() throws Exception {
+ final double[] scores = new double[] {0.8, 0.7, 0.5, 0.3, 0.2};
+ final int[] labels = new int[] {1, 1, 1, 1, 0};
+
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+ evaluator.reset(agg);
+
+ for (int i = 0; i < scores.length; i++) {
+ evaluator.iterate(agg, new Object[] {scores[i], labels[i]});
+ }
+
+ // All TPs are ranked higher than FPs => AUC=1.0
+ Assert.assertEquals(1.d, agg.get(), 1e-5);
+ }
+
+ @Test
+ public void testMinAUC() throws Exception {
+ final double[] scores = new double[] {0.8, 0.7, 0.5, 0.3, 0.2};
+ final int[] labels = new int[] {0, 0, 0, 1, 1};
+
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+ evaluator.reset(agg);
+
+ for (int i = 0; i < scores.length; i++) {
+ evaluator.iterate(agg, new Object[] {scores[i], labels[i]});
+ }
+
+ // All TPs are ranked lower than FPs => AUC=0.0
+ Assert.assertEquals(0.d, agg.get(), 1e-5);
+ }
+
+ @Test
+ public void testMidAUC() throws Exception {
+ final double[] scores = new double[] {0.8, 0.7, 0.5, 0.3, 0.2};
+
+ // if TP and FP appear alternately, AUC=0.5
+ final int[] labels1 = new int[] {1, 0, 1, 0, 1};
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+ evaluator.reset(agg);
+ for (int i = 0; i < scores.length; i++) {
+ evaluator.iterate(agg, new Object[] {scores[i], labels1[i]});
+ }
+ Assert.assertEquals(0.5, agg.get(), 1e-5);
+
+ final int[] labels2 = new int[] {0, 1, 0, 1, 0};
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+ evaluator.reset(agg);
+ for (int i = 0; i < scores.length; i++) {
+ evaluator.iterate(agg, new Object[] {scores[i], labels2[i]});
+ }
+ Assert.assertEquals(0.5, agg.get(), 1e-5);
+ }
+
+ @Test
+ public void testMerge() throws Exception {
+ final double[] scores = new double[] {0.8, 0.7, 0.5, 0.5, 0.3, 0.2};
+ final int[] labels = new int[] {1, 1, 0, 1, 1, 0};
+
+ Object[] partials = new Object[3];
+
+ // bin #1
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+ evaluator.reset(agg);
+ evaluator.iterate(agg, new Object[] {scores[0], labels[0]});
+ evaluator.iterate(agg, new Object[] {scores[1], labels[1]});
+ partials[0] = evaluator.terminatePartial(agg);
+
+ // bin #2
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+ evaluator.reset(agg);
+ evaluator.iterate(agg, new Object[] {scores[2], labels[2]});
+ evaluator.iterate(agg, new Object[] {scores[3], labels[3]});
+ partials[1] = evaluator.terminatePartial(agg);
+
+ // bin #3
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+ evaluator.reset(agg);
+ evaluator.iterate(agg, new Object[] {scores[4], labels[4]});
+ evaluator.iterate(agg, new Object[] {scores[5], labels[5]});
+ partials[2] = evaluator.terminatePartial(agg);
+
+ // merge bins
+ // merge in different orders; e.g., <bin0, bin1> and <bin1, bin0> should yield the same value
+ final int[][] orders = new int[][] { {0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 1, 0},
+ {2, 0, 1}};
+ for (int i = 0; i < orders.length; i++) {
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, partialOI);
+ evaluator.reset(agg);
+
+ evaluator.merge(agg, partials[orders[i][0]]);
+ evaluator.merge(agg, partials[orders[i][1]]);
+ evaluator.merge(agg, partials[orders[i][2]]);
+
+ Assert.assertEquals(0.8125, agg.get(), 1e-5);
+ }
+ }
+
+ @Test
+ public void test100() throws Exception {
+ final double[] scores = new double[] {0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.8, 0.8,
+ 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.7, 0.7, 0.7, 0.7,
+ 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6, 0.6, 0.6,
+ 0.6, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.4, 0.4,
+ 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3,
+ 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.1,
+ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1};
+ final int[] labels = new int[] {1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
+ 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1,
+ 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1,
+ 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0};
+
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+ evaluator.reset(agg);
+
+ for (int i = 0; i < scores.length; i++) {
+ evaluator.iterate(agg, new Object[] {scores[i], labels[i]});
+ }
+
+ // should match scikit-learn's result
+ Assert.assertEquals(0.567226890756, agg.get(), 1e-5);
+ }
+
+ @Test
+ public void testMerge100() throws Exception {
+ final double[] scores = new double[] {0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.8, 0.8,
+ 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.7, 0.7, 0.7, 0.7,
+ 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6, 0.6, 0.6,
+ 0.6, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.4, 0.4,
+ 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3,
+ 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.1,
+ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1};
+ final int[] labels = new int[] {1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
+ 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1,
+ 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1,
+ 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0};
+
+ Object[] partials = new Object[3];
+
+ // bin #1 (score is in [0.9, 0.7])
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+ evaluator.reset(agg);
+ int i = 0;
+ while (scores[i] > 0.6) {
+ evaluator.iterate(agg, new Object[] {scores[i], labels[i]});
+ i++;
+ }
+ partials[0] = evaluator.terminatePartial(agg);
+
+ // bin #2 (score is in [0.6, 0.4])
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+ evaluator.reset(agg);
+ while (scores[i] > 0.3) {
+ evaluator.iterate(agg, new Object[] {scores[i], labels[i]});
+ i++;
+ }
+ partials[1] = evaluator.terminatePartial(agg);
+
+ // bin #3 (score is in [0.3, 0.1])
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+ evaluator.reset(agg);
+ while (i < 100) {
+ evaluator.iterate(agg, new Object[] {scores[i], labels[i]});
+ i++;
+ }
+ partials[2] = evaluator.terminatePartial(agg);
+
+ // merge bins
+ // merge in different orders; e.g., <bin0, bin1> and <bin1, bin0> should yield the same value
+ final int[][] orders = new int[][] { {0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 1, 0},
+ {2, 0, 1}};
+ for (int j = 0; j < orders.length; j++) {
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, partialOI);
+ evaluator.reset(agg);
+
+ evaluator.merge(agg, partials[orders[j][0]]);
+ evaluator.merge(agg, partials[orders[j][1]]);
+ evaluator.merge(agg, partials[orders[j][2]]);
+
+ Assert.assertEquals(0.567226890756, agg.get(), 1e-5);
+ }
+ }
+}
diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive
index ee36635..33694e1 100644
--- a/resources/ddl/define-all-as-permanent.hive
+++ b/resources/ddl/define-all-as-permanent.hive
@@ -528,6 +528,9 @@
DROP FUNCTION IF EXISTS logloss;
CREATE FUNCTION logloss as 'hivemall.evaluation.LogarithmicLossUDAF' USING JAR '${hivemall_jar}';
+DROP FUNCTION IF EXISTS auc;
+CREATE FUNCTION auc as 'hivemall.evaluation.AUCUDAF' USING JAR '${hivemall_jar}';
+
--------------------------
-- Matrix Factorization --
--------------------------
diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive
index 022d821..e3c7a39 100644
--- a/resources/ddl/define-all.hive
+++ b/resources/ddl/define-all.hive
@@ -524,6 +524,9 @@
drop temporary function logloss;
create temporary function logloss as 'hivemall.evaluation.LogarithmicLossUDAF';
+drop temporary function if exists auc;
+create temporary function auc as 'hivemall.evaluation.AUCUDAF';
+
--------------------------
-- Matrix Factorization --
--------------------------
diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark
index 14c07a6..460d6e6 100644
--- a/resources/ddl/define-all.spark
+++ b/resources/ddl/define-all.spark
@@ -426,6 +426,9 @@
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS rmse")
sqlContext.sql("CREATE TEMPORARY FUNCTION rmse AS 'hivemall.evaluation.RootMeanSquaredErrorUDAF'")
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS auc")
+sqlContext.sql("CREATE TEMPORARY FUNCTION auc AS 'hivemall.evaluation.AUCUDAF'")
+
/**
* Matrix Factorization
*/
diff --git a/resources/ddl/define-udfs.td.hql b/resources/ddl/define-udfs.td.hql
index 0e704c7..f208d23 100644
--- a/resources/ddl/define-udfs.td.hql
+++ b/resources/ddl/define-udfs.td.hql
@@ -152,6 +152,7 @@
create temporary function tilex2lon as 'hivemall.geospatial.TileX2LonUDF';
create temporary function tiley2lat as 'hivemall.geospatial.TileY2LatUDF';
create temporary function haversine_distance as 'hivemall.geospatial.HaversineDistanceUDF';
+create temporary function auc as 'hivemall.evaluation.AUCUDAF';
-- NLP features
create temporary function tokenize_ja as 'hivemall.nlp.tokenizer.KuromojiUDF';