[HIVEMALL-261][HIVEMALL-262] argmin/argmax/argsort UDF
## What changes were proposed in this pull request?
Introduce the argmin, argmax, argsort, argrank, and arange UDFs, together with the supporting array utility methods.
## What type of PR is it?
Feature
## What is the Jira issue?
https://issues.apache.org/jira/browse/HIVEMALL-261
https://issues.apache.org/jira/browse/HIVEMALL-262
## How was this patch tested?
unit tests, manual tests on EMR
## How to use this feature?
```sql
SELECT argmax(array(5,2,0,1));
> 0
SELECT array_slice(array(5,2,0,1), argmax(array(5,2,0,1)));
> 5
SELECT argmin(array(5,2,0,1));
> 2
SELECT argsort(array(5,2,0,1));
> 2, 3, 1, 0
SELECT array_slice(array(5,2,0,1), argsort(array(5,2,0,1)));
> 0, 1, 2, 5
SELECT argsort(argsort(array(5,2,0,1))), argrank(array(5,2,0,1));
> 3, 2, 0, 1
SELECT arange(5), arange(1, 5), arange(1, 5, 1), arange(0, 5, 1);
> [0,1,2,3,4] [1,2,3,4] [1,2,3,4] [0,1,2,3,4]
SELECT arange(1, 6, 2);
> 1, 3, 5
SELECT arange(-1, -6, 2);
> -1, -3, -5
SELECT argsort(array(5, 2, 0, 1)), argrank(array(5, 2, 0, 1)), argsort(argsort(array(5, 2, 0, 1)));
> [2,3,1,0] [3,2,0,1] [3,2,0,1]
```
## Checklist
(Please remove this section if not needed; check `x` for YES, blank for NO)
- [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit?
- [ ] Did you run system tests on Hive (or Spark)?
Author: Makoto Yui <myui@apache.org>
Closes #197 from myui/argmax.
diff --git a/core/src/main/java/hivemall/tools/array/ArangeUDF.java b/core/src/main/java/hivemall/tools/array/ArangeUDF.java
new file mode 100644
index 0000000..cde282b
--- /dev/null
+++ b/core/src/main/java/hivemall/tools/array/ArangeUDF.java
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.tools.array;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.ArrayUtils;
+import hivemall.utils.lang.StringUtils;
+
+import java.util.Arrays;
+import java.util.List;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nullable;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.io.IntWritable;
+
+// @formatter:off
+@Description(name = "arange",
+value = "_FUNC_([int start=0, ] int stop, [int step=1]) - Return evenly spaced values within a given interval",
+extended = "SELECT arange(5), arange(1, 5), arange(1, 5, 1), arange(0, 5, 1);\n" +
+        "> [0,1,2,3,4] [1,2,3,4] [1,2,3,4] [0,1,2,3,4]\n" +
+        "\n" +
+        "SELECT arange(1, 6, 2);\n" +
+        "> 1, 3, 5\n" +
+        "\n" +
+        "SELECT arange(-1, -6, 2);\n" +
+        "> -1, -3, -5")
+// @formatter:on
+@UDFType(deterministic = true, stateful = false)
+public final class ArangeUDF extends GenericUDF {
+
+    /** Inspector for {@code start}; only set by the 2- and 3-argument forms. */
+    @Nullable
+    private PrimitiveObjectInspector startOI;
+    /** Inspector for {@code stop}; set by every accepted form. */
+    private PrimitiveObjectInspector stopOI;
+    /** Inspector for {@code step}; only set by the 3-argument form. */
+    @Nullable
+    private PrimitiveObjectInspector stepOI;
+
+    /**
+     * Validates that 1 to 3 integer arguments were given and remembers their inspectors.
+     *
+     * @return an inspector describing a list of writable ints
+     * @throws UDFArgumentException on a wrong argument count or a non-integer argument
+     */
+    @Override
+    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+        switch (argOIs.length) {
+            case 1:
+                if (!HiveUtils.isIntegerOI(argOIs[0])) {
+                    throw new UDFArgumentException(
+                        "arange(int stop) expects integer for the 1st argument: "
+                                + argOIs[0].getTypeName());
+                }
+                this.stopOI = HiveUtils.asIntegerOI(argOIs[0]);
+                break;
+            case 3:
+                if (!HiveUtils.isIntegerOI(argOIs[2])) {
+                    throw new UDFArgumentException(
+                        "arange(int start, int stop, int step) expects integer for the 3rd argument: "
+                                + argOIs[2].getTypeName());
+                }
+                this.stepOI = HiveUtils.asIntegerOI(argOIs[2]);
+                // fall through: the 3-argument form shares start/stop handling with the 2-argument form
+            case 2:
+                if (!HiveUtils.isIntegerOI(argOIs[0])) {
+                    throw new UDFArgumentException(
+                        "arange(int start, int stop) expects integer for the 1st argument: "
+                                + argOIs[0].getTypeName());
+                }
+                this.startOI = HiveUtils.asIntegerOI(argOIs[0]);
+                if (!HiveUtils.isIntegerOI(argOIs[1])) {
+                    throw new UDFArgumentException(
+                        "arange(int start, int stop) expects integer for the 2nd argument: "
+                                + argOIs[1].getTypeName());
+                }
+                this.stopOI = HiveUtils.asIntegerOI(argOIs[1]);
+                break;
+            default:
+                throw new UDFArgumentException(
+                    "arange([int start=0, ] int stop, [int step=1]) takes 1~3 arguments: "
+                            + argOIs.length);
+        }
+
+        return ObjectInspectorFactory.getStandardListObjectInspector(
+            ObjectInspectorUtils.getStandardObjectInspector(
+                PrimitiveObjectInspectorFactory.writableIntObjectInspector));
+    }
+
+    /**
+     * Builds the sequence for the given arguments.
+     *
+     * @return the generated values, or {@code null} if any supplied argument is null
+     * @throws HiveException on an unsupported argument count or a non-positive step
+     */
+    @Nullable
+    @Override
+    public List<IntWritable> evaluate(DeferredObject[] arguments) throws HiveException {
+        int start = 0, step = 1;
+        final int stop;
+        switch (arguments.length) {
+            case 1: {
+                Object arg0 = arguments[0].get();
+                if (arg0 == null) {
+                    return null;
+                }
+                stop = PrimitiveObjectInspectorUtils.getInt(arg0, stopOI);
+                break;
+            }
+            case 3: {
+                Object arg2 = arguments[2].get();
+                if (arg2 == null) {
+                    return null;
+                }
+                step = PrimitiveObjectInspectorUtils.getInt(arg2, stepOI);
+                // fall through: read start/stop exactly as the 2-argument form does
+            }
+            case 2: {
+                Object arg0 = arguments[0].get();
+                if (arg0 == null) {
+                    return null;
+                }
+                start = PrimitiveObjectInspectorUtils.getInt(arg0, startOI);
+                Object arg1 = arguments[1].get();
+                if (arg1 == null) {
+                    return null;
+                }
+                stop = PrimitiveObjectInspectorUtils.getInt(arg1, stopOI);
+                break;
+            }
+            default:
+                throw new UDFArgumentException(
+                    "arange([int start=0, ] int stop, [int step=1]) takes 1~3 arguments: "
+                            + arguments.length);
+        }
+
+        return Arrays.asList(range(start, stop, step));
+    }
+
+    /**
+     * Return evenly spaced values within a given interval. When {@code stop} is smaller than
+     * {@code start} the values descend from {@code start} by {@code step}.
+     *
+     * @param start inclusive index of the start
+     * @param stop exclusive index of the end
+     * @param step positive interval value
+     * @throws UDFArgumentException if {@code step} is not positive or the sequence would hold
+     *         more than {@code Integer.MAX_VALUE} elements
+     */
+    private static IntWritable[] range(final int start, final int stop,
+            @Nonnegative final int step) throws UDFArgumentException {
+        if (step <= 0) {
+            throw new UDFArgumentException("Invalid step value: " + step);
+        }
+
+        // Widen to long: both (stop - start) and (span + step - 1) can overflow int.
+        final long diff = (long) stop - (long) start;
+        final long count = (Math.abs(diff) + step - 1L) / step;
+        if (count > Integer.MAX_VALUE) {
+            throw new UDFArgumentException("Too many elements to generate: " + count);
+        }
+
+        final IntWritable[] r = new IntWritable[(int) count];
+        final int delta = (diff < 0L) ? -step : step;
+        for (int i = 0, value = start; i < r.length; i++, value += delta) {
+            r[i] = new IntWritable(value);
+        }
+        return r;
+    }
+
+    @Override
+    public String getDisplayString(String[] children) {
+        return "arange(" + StringUtils.join(children, ',') + ')';
+    }
+
+}
diff --git a/core/src/main/java/hivemall/tools/array/ArgmaxUDF.java b/core/src/main/java/hivemall/tools/array/ArgmaxUDF.java
new file mode 100644
index 0000000..d697907
--- /dev/null
+++ b/core/src/main/java/hivemall/tools/array/ArgmaxUDF.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.tools.array;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.StringUtils;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.IntWritable;
+
+@Description(name = "argmax",
+        value = "_FUNC_(array<T> a) - Returns the first index of the maximum value",
+        extended = "SELECT argmax(array(5,2,0,1));\n" + "> 0")
+@UDFType(deterministic = true, stateful = false)
+public final class ArgmaxUDF extends GenericUDF {
+
+    private ListObjectInspector listOI;
+    private PrimitiveObjectInspector elemOI;
+
+    /** Reused output holder, allocated in {@link #initialize(ObjectInspector[])}. */
+    private IntWritable result;
+
+    /**
+     * Accepts exactly one list argument and captures the list and element inspectors.
+     *
+     * @return the writable-int inspector describing the returned index
+     */
+    @Override
+    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+        final int numArgs = argOIs.length;
+        if (numArgs != 1) {
+            throw new UDFArgumentException("argmax takes exactly one argument: " + numArgs);
+        }
+
+        this.listOI = HiveUtils.asListOI(argOIs[0]);
+        this.elemOI = HiveUtils.asPrimitiveObjectInspector(listOI.getListElementObjectInspector());
+        this.result = new IntWritable();
+
+        return PrimitiveObjectInspectorFactory.writableIntObjectInspector;
+    }
+
+    /**
+     * Scans the list once, skipping null elements, and keeps the first index holding the
+     * largest element.
+     *
+     * @return {@code null} for a null list; otherwise the index, which is -1 when the list
+     *         has no non-null element
+     */
+    @Override
+    public IntWritable evaluate(DeferredObject[] arguments) throws HiveException {
+        final Object list = arguments[0].get();
+        if (list == null) {
+            return null;
+        }
+
+        int bestIndex = -1;
+        Object best = null;
+        for (int pos = 0, length = listOI.getListLength(list); pos < length; pos++) {
+            final Object candidate = listOI.getListElement(list, pos);
+            if (candidate == null) {
+                continue;
+            }
+            // strict '>' keeps the earliest index among equal maxima
+            if (best == null
+                    || ObjectInspectorUtils.compare(candidate, elemOI, best, elemOI) > 0) {
+                best = candidate;
+                bestIndex = pos;
+            }
+        }
+
+        result.set(bestIndex);
+        return result;
+    }
+
+    @Override
+    public String getDisplayString(String[] children) {
+        return "argmax(" + StringUtils.join(children, ',') + ')';
+    }
+
+}
diff --git a/core/src/main/java/hivemall/tools/array/ArgminUDF.java b/core/src/main/java/hivemall/tools/array/ArgminUDF.java
new file mode 100644
index 0000000..a2664b7
--- /dev/null
+++ b/core/src/main/java/hivemall/tools/array/ArgminUDF.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.tools.array;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.StringUtils;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.IntWritable;
+
+@Description(name = "argmin",
+        value = "_FUNC_(array<T> a) - Returns the first index of the minimum value",
+        extended = "SELECT argmin(array(5,2,0,1));\n" + "> 2")
+@UDFType(deterministic = true, stateful = false)
+public final class ArgminUDF extends GenericUDF {
+
+    private ListObjectInspector listOI;
+    private PrimitiveObjectInspector elemOI;
+
+    /** Reused output holder, allocated in {@link #initialize(ObjectInspector[])}. */
+    private IntWritable result;
+
+    /**
+     * Accepts exactly one list argument and captures the list and element inspectors.
+     *
+     * @return the writable-int inspector describing the returned index
+     */
+    @Override
+    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+        final int numArgs = argOIs.length;
+        if (numArgs != 1) {
+            throw new UDFArgumentException("argmin takes exactly one argument: " + numArgs);
+        }
+
+        this.listOI = HiveUtils.asListOI(argOIs[0]);
+        this.elemOI = HiveUtils.asPrimitiveObjectInspector(listOI.getListElementObjectInspector());
+        this.result = new IntWritable();
+
+        return PrimitiveObjectInspectorFactory.writableIntObjectInspector;
+    }
+
+    /**
+     * Scans the list once, skipping null elements, and keeps the first index holding the
+     * smallest element.
+     *
+     * @return {@code null} for a null list; otherwise the index, which is -1 when the list
+     *         has no non-null element
+     */
+    @Override
+    public IntWritable evaluate(DeferredObject[] arguments) throws HiveException {
+        final Object list = arguments[0].get();
+        if (list == null) {
+            return null;
+        }
+
+        int bestIndex = -1;
+        Object best = null;
+        for (int pos = 0, length = listOI.getListLength(list); pos < length; pos++) {
+            final Object candidate = listOI.getListElement(list, pos);
+            if (candidate == null) {
+                continue;
+            }
+            // strict '<' keeps the earliest index among equal minima
+            if (best == null
+                    || ObjectInspectorUtils.compare(candidate, elemOI, best, elemOI) < 0) {
+                best = candidate;
+                bestIndex = pos;
+            }
+        }
+
+        result.set(bestIndex);
+        return result;
+    }
+
+    @Override
+    public String getDisplayString(String[] children) {
+        return "argmin(" + StringUtils.join(children, ',') + ')';
+    }
+
+}
diff --git a/core/src/main/java/hivemall/tools/array/ArgrankUDF.java b/core/src/main/java/hivemall/tools/array/ArgrankUDF.java
new file mode 100644
index 0000000..f283d39
--- /dev/null
+++ b/core/src/main/java/hivemall/tools/array/ArgrankUDF.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.tools.array;
+
+import hivemall.utils.hadoop.HiveUtils;
+
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.List;
+
+import javax.annotation.Nullable;
+
+import org.apache.commons.lang.StringUtils;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.IntWritable;
+
+// @formatter:off
+@Description(name = "argrank",
+        value = "_FUNC_(array<ANY> a) - Returns the rank of each element, i.e., its position in the sorted order of the array.",
+        extended = "SELECT argrank(array(5,2,0,1)), argsort(argsort(array(5,2,0,1)));\n" +
+                "> [3, 2, 0, 1] [3, 2, 0, 1]")
+// @formatter:on
+@UDFType(deterministic = true, stateful = false)
+public final class ArgrankUDF extends GenericUDF {
+
+    private ListObjectInspector listOI;
+    private ObjectInspector elemOI;
+
+    /**
+     * Accepts exactly one list argument and captures the list and element inspectors.
+     *
+     * @return an inspector describing a list of writable ints
+     * @throws UDFArgumentException if the argument count is not 1 or the argument is not a list
+     */
+    @Override
+    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+        if (argOIs.length != 1) {
+            throw new UDFArgumentLengthException(
+                "argrank(array<ANY> a) takes exactly 1 argument: " + argOIs.length);
+        }
+        ObjectInspector argOI0 = argOIs[0];
+        if (argOI0.getCategory() != Category.LIST) {
+            throw new UDFArgumentException(
+                "argrank(array<ANY> a) expects array<ANY> for the first argument: "
+                        + argOI0.getTypeName());
+        }
+
+        this.listOI = HiveUtils.asListOI(argOI0);
+        this.elemOI = listOI.getListElementObjectInspector();
+
+        return ObjectInspectorFactory.getStandardListObjectInspector(
+            PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+    }
+
+    /**
+     * Computes, for every position of the input list, the position that element takes once
+     * the list is sorted (equivalent to {@code argsort(argsort(a))}).
+     *
+     * @return the ranks, or {@code null} when the input list is null
+     */
+    @Nullable
+    @Override
+    public List<IntWritable> evaluate(DeferredObject[] arguments) throws HiveException {
+        final Object arg0 = arguments[0].get();
+        if (arg0 == null) {
+            return null;
+        }
+
+        final int size = listOI.getListLength(arg0);
+
+        // Step 1: argsort, i.e. order[r] is the index of the element with rank r.
+        final Integer[] order = new Integer[size];
+        for (int i = 0; i < size; i++) {
+            order[i] = i;
+        }
+        Arrays.sort(order, new Comparator<Integer>() {
+            @Override
+            public int compare(Integer i, Integer j) {
+                Object ei = listOI.getListElement(arg0, i.intValue());
+                Object ej = listOI.getListElement(arg0, j.intValue());
+                return ObjectInspectorUtils.compare(ei, elemOI, ej, elemOI);
+            }
+        });
+
+        // Step 2: order is a permutation of 0..size-1, so argsort(order) is simply its
+        // inverse; invert it in O(n) instead of running a second comparison sort.
+        final IntWritable[] ranks = new IntWritable[size];
+        for (int rank = 0; rank < size; rank++) {
+            ranks[order[rank].intValue()] = new IntWritable(rank);
+        }
+        return Arrays.asList(ranks);
+    }
+
+    @Override
+    public String getDisplayString(String[] children) {
+        return "argrank(" + StringUtils.join(children, ',') + ')';
+    }
+
+}
diff --git a/core/src/main/java/hivemall/tools/array/ArgsortUDF.java b/core/src/main/java/hivemall/tools/array/ArgsortUDF.java
new file mode 100644
index 0000000..f8e6b8f
--- /dev/null
+++ b/core/src/main/java/hivemall/tools/array/ArgsortUDF.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.tools.array;
+
+import hivemall.utils.hadoop.HiveUtils;
+
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.List;
+
+import javax.annotation.Nullable;
+
+import org.apache.commons.lang.StringUtils;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.IntWritable;
+
+// @formatter:off
+@Description(name = "argsort",
+ value = "_FUNC_(array<ANY> a) - Returns the indices that would sort an array.",
+ extended = "SELECT argsort(array(5,2,0,1));\n" +
+ "> 2, 3, 1, 0\n" +
+ "\n" +
+ "SELECT array_slice(array(5,2,0,1), argsort(array(5,2,0,1)));\n" +
+ "> 0, 1, 2, 5")
+// @formatter:on
+@UDFType(deterministic = true, stateful = false)
+public final class ArgsortUDF extends GenericUDF {
+
+ private ListObjectInspector listOI;
+ private ObjectInspector elemOI;
+
+ @Override
+ public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ if (argOIs.length != 1) {
+ throw new UDFArgumentLengthException(
+ "argsort(array<ANY> a) takes exactly 1 argument: " + argOIs.length);
+ }
+ ObjectInspector argOI0 = argOIs[0];
+ if (argOI0.getCategory() != Category.LIST) {
+ throw new UDFArgumentException(
+ "argsort(array<ANY> a) expects array<ANY> for the first argument: "
+ + argOI0.getTypeName());
+ }
+
+ this.listOI = HiveUtils.asListOI(argOI0);
+ this.elemOI = listOI.getListElementObjectInspector();
+
+ return ObjectInspectorFactory.getStandardListObjectInspector(
+ PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ }
+
+ @Nullable
+ @Override
+ public List<IntWritable> evaluate(DeferredObject[] arguments) throws HiveException {
+ final Object arg0 = arguments[0].get();
+ if (arg0 == null) {
+ return null;
+ }
+
+ final int size = listOI.getListLength(arg0);
+
+ final Integer[] indexes = new Integer[size];
+ for (int i = 0; i < size; i++) {
+ indexes[i] = i;
+ }
+ Arrays.sort(indexes, new Comparator<Integer>() {
+ @Override
+ public int compare(Integer i, Integer j) {
+ Object ei = listOI.getListElement(arg0, i.intValue());
+ Object ej = listOI.getListElement(arg0, j.intValue());
+ return ObjectInspectorUtils.compare(ei, elemOI, ej, elemOI);
+ }
+ });
+
+ final IntWritable[] ret = new IntWritable[size];
+ for (int i = 0; i < size; i++) {
+ ret[i] = new IntWritable(indexes[i].intValue());
+ }
+ return Arrays.asList(ret);
+ }
+
+ @Override
+ public String getDisplayString(String[] children) {
+ return "argsort(" + StringUtils.join(children, ',') + ')';
+ }
+
+}
diff --git a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
index 5945118..562a7b5 100644
--- a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
+++ b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
@@ -18,11 +18,14 @@
*/
package hivemall.utils.lang;
+import static hivemall.utils.lang.Preconditions.checkElementIndex;
+
import hivemall.utils.random.PRNG;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Comparator;
import java.util.List;
import java.util.Random;
@@ -43,6 +46,36 @@
private ArrayUtils() {}
+ /**
+ * Creates a new array of the given length using the runtime class of {@code a},
+ * delegating to the {@code Class}-based overload.
+ *
+ * @param a array whose runtime class selects the type of the new array
+ * @param newLength length of the array to create
+ * @return a newly allocated array of length {@code newLength}
+ */
+ @SuppressWarnings("unchecked")
+ public static <T> T[] newInstance(@Nonnull final T[] a, @Nonnegative final int newLength) {
+ return (T[]) newInstance(a.getClass(), newLength);
+ }
+
+    /**
+     * Allocates an array of the requested array class and length.
+     *
+     * @param newType array class of the array to create
+     * @param newLength length of the array to create
+     * @return a newly allocated array of length {@code newLength}
+     */
+    @SuppressWarnings("unchecked")
+    public static <T> T[] newInstance(@Nonnull final Class<? extends T[]> newType,
+            @Nonnegative final int newLength) {
+        final Object created;
+        if ((Object) newType != (Object) Object[].class) {
+            // any other array class goes through reflection on its component type
+            created = Array.newInstance(newType.getComponentType(), newLength);
+        } else {
+            created = new Object[newLength];
+        }
+        return (T[]) created;
+    }
+
+    /**
+     * Returns {@code a[index]} after validating the index with {@code checkElementIndex}.
+     *
+     * @param a array to read from
+     * @param index position to read
+     * @return the element at {@code index}
+     */
+    public static int get(final int[] a, final int index) {
+        // no @Nonnull on the return: a primitive int can never be null
+        return a[checkElementIndex(index, a.length)];
+    }
+
+    /**
+     * Returns {@code a[index]} after validating the index with {@code checkElementIndex}.
+     *
+     * @param a array to read from
+     * @param index position to read
+     * @return the element at {@code index}
+     */
+    public static double get(final double[] a, final int index) {
+        // no @Nonnull on the return: a primitive double can never be null
+        return a[checkElementIndex(index, a.length)];
+    }
+
+    /**
+     * Returns {@code a[index]} after validating the index with {@code checkElementIndex}.
+     *
+     * @param a array to read from
+     * @param index position to read
+     * @return the element at {@code index}; null when the array stores null there
+     */
+    public static <T> T get(final T[] a, final int index) {
+        // not annotated @Nonnull: the stored element itself may be null
+        return a[checkElementIndex(index, a.length)];
+    }
+
@Nonnull
public static double[] set(@Nonnull double[] src, final int index, final double value) {
if (index >= src.length) {
@@ -149,6 +182,36 @@
}
@Nonnull
+ public static int[] slice(@Nonnull final int[] a, @Nonnull final int... indexes) {
+ final int size = indexes.length;
+ final int[] ret = new int[size];
+ for (int i = 0; i < size; i++) {
+ ret[i] = get(a, indexes[i]);
+ }
+ return ret;
+ }
+
+ @Nonnull
+ public static double[] slice(@Nonnull final double[] a, @Nonnull final int... indexes) {
+ final int size = indexes.length;
+ final double[] ret = new double[size];
+ for (int i = 0; i < size; i++) {
+ ret[i] = get(a, indexes[i]);
+ }
+ return ret;
+ }
+
+ @Nonnull
+ public static <T> T[] slice(@Nonnull final T[] a, @Nonnull final int... indexes) {
+ final int size = indexes.length;
+ final T[] ret = newInstance(a, size);
+ for (int i = 0; i < size; i++) {
+ ret[i] = get(a, indexes[i]);
+ }
+ return ret;
+ }
+
+ @Nonnull
public static <T> T[] shuffle(@Nonnull final T[] array) {
shuffle(array, array.length);
return array;
@@ -835,6 +898,265 @@
}
}
+ @Nonnull
+ public static int[] argsort(@Nonnull final int[] a) {
+ final int size = a.length;
+ final Integer[] indexes = new Integer[size];
+ for (int i = 0; i < size; i++) {
+ indexes[i] = i;
+ }
+ Arrays.sort(indexes, new Comparator<Integer>() {
+ @Override
+ public int compare(Integer i, Integer j) {
+ return Integer.compare(a[i], a[j]);
+ }
+ });
+
+ final int[] ret = new int[size];
+ for (int i = 0; i < size; i++) {
+ ret[i] = indexes[i].intValue();
+ }
+ return ret;
+ }
+
+ @Nonnull
+ public static int[] argsort(@Nonnull final double[] a) {
+ final int size = a.length;
+ final Integer[] indexes = new Integer[size];
+ for (int i = 0; i < size; i++) {
+ indexes[i] = i;
+ }
+ Arrays.sort(indexes, new Comparator<Integer>() {
+ @Override
+ public int compare(Integer i, Integer j) {
+ return Double.compare(a[i], a[j]);
+ }
+ });
+
+ final int[] ret = new int[size];
+ for (int i = 0; i < size; i++) {
+ ret[i] = indexes[i].intValue();
+ }
+ return ret;
+ }
+
+ @Nonnull
+ public static <T extends Comparable<T>> int[] argsort(@Nonnull final T[] a) {
+ final int size = a.length;
+ final Integer[] indexes = new Integer[size];
+ for (int i = 0; i < size; i++) {
+ indexes[i] = i;
+ }
+ Arrays.sort(indexes, new Comparator<Integer>() {
+ @Override
+ public int compare(Integer i, Integer j) {
+ T ai = a[i.intValue()];
+ T aj = a[j.intValue()];
+ return ai.compareTo(aj);
+ }
+ });
+
+ final int[] ret = new int[size];
+ for (int i = 0; i < size; i++) {
+ ret[i] = indexes[i].intValue();
+ }
+ return ret;
+ }
+
+ @Nonnull
+ public static <T> int[] argsort(@Nonnull final T[] a, @Nonnull final Comparator<? super T> c) {
+ final int size = a.length;
+ final Integer[] indexes = new Integer[size];
+ for (int i = 0; i < size; i++) {
+ indexes[i] = i;
+ }
+ Arrays.sort(indexes, new Comparator<Integer>() {
+ @Override
+ public int compare(Integer i, Integer j) {
+ return c.compare(a[i.intValue()], a[j.intValue()]);
+ }
+ });
+
+ final int[] ret = new int[size];
+ for (int i = 0; i < size; i++) {
+ ret[i] = indexes[i].intValue();
+ }
+ return ret;
+ }
+
+    /**
+     * Applies {@code argsort} twice: first to the values, then to the resulting positions.
+     *
+     * @param a values to rank
+     * @return the result of {@code argsort(argsort(a))}
+     */
+    public static int[] argrank(@Nonnull final int[] a) {
+        final int[] order = argsort(a);
+        return argsort(order);
+    }
+
+    /**
+     * Applies {@code argsort} twice: first to the values, then to the resulting positions.
+     *
+     * @param a values to rank
+     * @return the result of {@code argsort(argsort(a))}
+     */
+    public static int[] argrank(@Nonnull final double[] a) {
+        final int[] order = argsort(a);
+        return argsort(order);
+    }
+
+    /**
+     * Sorts the positions of {@code a} with the supplied comparator, then applies the
+     * {@code int[]} overload of {@code argsort} to the resulting positions.
+     *
+     * @param a elements to rank
+     * @param c comparator applied to the elements
+     * @return the result of {@code argsort(argsort(a, c))}
+     */
+    @Nonnull
+    public static <T> int[] argrank(@Nonnull final T[] a, @Nonnull final Comparator<? super T> c) {
+        // the first pass already yields plain int positions, so the second pass is the
+        // int[] overload and takes no comparator
+        final int[] order = argsort(a, c);
+        return argsort(order);
+    }
+
+    /**
+     * Returns the first index holding the smallest value of {@code a}.
+     *
+     * @param a values to scan
+     * @return index of the first minimum; 0 when no element is strictly smaller than
+     *         {@code Double.POSITIVE_INFINITY}
+     * @throws IllegalArgumentException if {@code a} is empty
+     */
+    public static int argmin(@Nonnull final double[] a) {
+        final int size = a.length;
+        if (size == 0) {
+            // message previously said "argmax" (copy/paste slip)
+            throw new IllegalArgumentException("attempt to get argmin of an empty array");
+        }
+
+        int minIdx = 0;
+        double minValue = Double.POSITIVE_INFINITY;
+        for (int i = 0; i < size; i++) {
+            final double v = a[i];
+            if (v < minValue) {
+                minIdx = i;
+                minValue = v;
+            }
+        }
+        return minIdx;
+    }
+
+    /**
+     * Returns the first index holding the smallest non-null element of {@code a}.
+     *
+     * @param a elements to scan; null elements are skipped
+     * @return index of the first minimum; 0 when every element is null
+     * @throws IllegalArgumentException if {@code a} is empty
+     */
+    public static <T extends Comparable<T>> int argmin(@Nonnull final T[] a) {
+        final int size = a.length;
+        if (size == 0) {
+            // message previously said "argmax" (copy/paste slip)
+            throw new IllegalArgumentException("attempt to get argmin of an empty array");
+        }
+
+        int minIdx = 0;
+        T minValue = null;
+        for (int i = 0; i < size; i++) {
+            final T v = a[i];
+            if (v == null) {
+                continue;
+            }
+            if (minValue == null || v.compareTo(minValue) < 0) {
+                minIdx = i;
+                minValue = v;
+            }
+        }
+        return minIdx;
+    }
+
+    /**
+     * Returns the first index holding the smallest element of {@code a} according to
+     * {@code c}. Elements, including nulls, are handed to the comparator as they are.
+     *
+     * @param a elements to scan
+     * @param c comparator defining the order
+     * @return index of the first minimum
+     * @throws IllegalArgumentException if {@code a} is empty
+     */
+    public static <T> int argmin(@Nonnull final T[] a, @Nonnull final Comparator<? super T> c) {
+        final int size = a.length;
+        if (size == 0) {
+            // message previously said "argmax" (copy/paste slip)
+            throw new IllegalArgumentException("attempt to get argmin of an empty array");
+        }
+
+        // A single-element array needs no special case: the loop body never runs.
+        int minIdx = 0;
+        T minValue = a[0];
+        for (int i = 1; i < size; i++) {
+            final T v = a[i];
+            if (c.compare(v, minValue) < 0) {
+                minIdx = i;
+                minValue = v;
+            }
+        }
+        return minIdx;
+    }
+
+    /**
+     * Returns the first index holding the largest non-null element of {@code a}.
+     *
+     * @param a elements to scan; null elements are skipped
+     * @return index of the first maximum; 0 when every element is null
+     * @throws IllegalArgumentException if {@code a} is empty
+     */
+    public static <T extends Comparable<T>> int argmax(@Nonnull final T[] a) {
+        if (a.length == 0) {
+            throw new IllegalArgumentException("attempt to get argmax of an empty array");
+        }
+
+        int bestIdx = 0;
+        T best = null;
+        for (int pos = 0; pos < a.length; pos++) {
+            final T candidate = a[pos];
+            // strict '>' keeps the earliest index among equal maxima
+            if (candidate != null && (best == null || candidate.compareTo(best) > 0)) {
+                best = candidate;
+                bestIdx = pos;
+            }
+        }
+        return bestIdx;
+    }
+
+    /**
+     * Returns the first index holding the largest element of {@code a} according to
+     * {@code c}. Elements, including nulls, are handed to the comparator as they are.
+     *
+     * @param a elements to scan
+     * @param c comparator defining the order
+     * @return index of the first maximum
+     * @throws IllegalArgumentException if {@code a} is empty
+     */
+    public static <T> int argmax(@Nonnull final T[] a, @Nonnull final Comparator<? super T> c) {
+        final int size = a.length;
+        if (size == 0) {
+            throw new IllegalArgumentException("attempt to get argmax of an empty array");
+        }
+
+        // A single-element array needs no special case: the loop body never runs.
+        int maxIdx = 0;
+        T maxValue = a[0]; // may be null; the comparator decides how nulls order
+        for (int i = 1; i < size; i++) {
+            final T v = a[i];
+            if (c.compare(v, maxValue) > 0) {
+                maxIdx = i;
+                maxValue = v;
+            }
+        }
+        return maxIdx;
+    }
+
+ /**
+ * Shorthand for {@code range(0, stop)}.
+ *
+ * @param stop exclusive end of the sequence
+ */
+ public static int[] range(final int stop) {
+ return range(0, stop);
+ }
+
+    /**
+     * Returns consecutive integers from {@code start} (inclusive) towards {@code stop}
+     * (exclusive), counting down when {@code stop} is smaller than {@code start}.
+     *
+     * @param start first value of the sequence
+     * @param stop exclusive end of the sequence
+     * @return {@code |stop - start|} consecutive values beginning at {@code start}
+     */
+    public static int[] range(final int start, final int stop) {
+        final int span = stop - start;
+        final boolean descending = span < 0;
+        final int[] values = new int[descending ? -span : span];
+        for (int offset = 0; offset < values.length; offset++) {
+            values[offset] = descending ? start - offset : start + offset;
+        }
+        return values;
+    }
+
+    /**
+     * Return evenly spaced values within a given interval. When {@code stop} is smaller
+     * than {@code start} the values descend from {@code start} by {@code step}.
+     *
+     * @param start inclusive index of the start
+     * @param stop exclusive index of the end
+     * @param step positive interval value
+     * @throws IllegalArgumentException if {@code step} is not positive or the sequence
+     *         would hold more than {@code Integer.MAX_VALUE} elements
+     */
+    public static int[] range(final int start, final int stop, @Nonnegative final int step) {
+        if (step <= 0) {
+            throw new IllegalArgumentException("Invalid step value: " + step);
+        }
+
+        // Widen to long: both (stop - start) and (span + step - 1) can overflow int.
+        final long diff = (long) stop - (long) start;
+        final long count = (Math.abs(diff) + step - 1L) / step;
+        if (count > Integer.MAX_VALUE) {
+            throw new IllegalArgumentException("Too many elements in the range: " + count);
+        }
+
+        final int[] r = new int[(int) count];
+        final int delta = (diff < 0L) ? -step : step;
+        for (int i = 0, value = start; i < r.length; i++, value += delta) {
+            r[i] = value;
+        }
+        return r;
+    }
+
+    /**
+     * Divides {@code num} by {@code divisor}, rounding the quotient up.
+     *
+     * @param num non-negative dividend
+     * @param divisor positive divisor
+     * @return the smallest integer that is at least {@code num / divisor}
+     * @throws ArithmeticException if {@code divisor} is zero
+     */
+    public static int divideAndRoundUp(@Nonnegative final int num,
+            @Nonnegative final int divisor) {
+        // add in long so that (num + divisor - 1) cannot wrap around for large operands
+        return (int) (((long) num + divisor - 1L) / divisor);
+    }
+
public static int count(@Nonnull final int[] values, final int valueToFind) {
int cnt = 0;
for (int i = 0; i < values.length; i++) {
diff --git a/core/src/main/java/hivemall/utils/lang/Preconditions.java b/core/src/main/java/hivemall/utils/lang/Preconditions.java
index 3843a27..99aa2dc 100644
--- a/core/src/main/java/hivemall/utils/lang/Preconditions.java
+++ b/core/src/main/java/hivemall/utils/lang/Preconditions.java
@@ -132,4 +132,17 @@
}
}
+    /**
+     * Ensures that {@code index} addresses an existing element of a container holding
+     * {@code size} elements.
+     *
+     * @param index position to validate
+     * @param size number of elements in the container
+     * @return {@code index}, unchanged, so the call can be used inline
+     * @throws IndexOutOfBoundsException if {@code index} is negative, {@code size} is negative,
+     *         or {@code index} is not less than {@code size}
+     */
+    public static int checkElementIndex(final int index, final int size) {
+        if (index < 0) {
+            throw new IndexOutOfBoundsException("index (" + index + ") must not be negative");
+        } else if (size < 0) {
+            throw new IndexOutOfBoundsException("negative size: " + size);
+        } else if (index >= size) {
+            throw new IndexOutOfBoundsException(
+                "index (" + index + ") must be less than size (" + size + ")");
+        }
+        return index;
+    }
+
}
diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java b/core/src/main/java/hivemall/utils/math/MathUtils.java
index ed5b2fd..2163e3a 100644
--- a/core/src/main/java/hivemall/utils/math/MathUtils.java
+++ b/core/src/main/java/hivemall/utils/math/MathUtils.java
@@ -18,6 +18,8 @@
*/
package hivemall.utils.math;
+import static java.lang.Math.abs;
+
import java.util.Random;
import javax.annotation.Nonnegative;
@@ -38,6 +40,15 @@
return 1.d / Math.cos(d);
}
+    /**
+     * Divides the magnitude of {@code num} by the magnitude of {@code divisor}, rounding up,
+     * and applies the sign of the exact quotient; a non-exact negative quotient is therefore
+     * rounded away from zero.
+     *
+     * @param num dividend
+     * @param divisor divisor, must not be zero
+     * @return the signed, rounded quotient
+     * @throws ArithmeticException if {@code divisor} is zero
+     */
+    public static int divideAndRoundUp(final int num, final int divisor) {
+        if (divisor == 0) {
+            throw new ArithmeticException("/ by zero");
+        }
+        // zero counts as "not positive" on both sides, exactly like the sign product it replaces
+        final boolean positiveResult = (num > 0) == (divisor > 0);
+        final int magnitude = Math.abs(divisor);
+        final int padded = Math.abs(num) + magnitude - 1;
+        final int signedPadded = positiveResult ? padded : -padded;
+        return signedPadded / magnitude;
+    }
+
/**
* Returns a bit mask for the specified number of bits.
*/
diff --git a/core/src/test/java/hivemall/utils/lang/ArrayUtilsTest.java b/core/src/test/java/hivemall/utils/lang/ArrayUtilsTest.java
index 42acc0b..5b98073 100644
--- a/core/src/test/java/hivemall/utils/lang/ArrayUtilsTest.java
+++ b/core/src/test/java/hivemall/utils/lang/ArrayUtilsTest.java
@@ -18,11 +18,22 @@
*/
package hivemall.utils.lang;
+import static hivemall.utils.lang.ArrayUtils.argmax;
+import static hivemall.utils.lang.ArrayUtils.argmin;
+import static hivemall.utils.lang.ArrayUtils.argrank;
+import static hivemall.utils.lang.ArrayUtils.argsort;
+import static hivemall.utils.lang.ArrayUtils.newInstance;
+import static hivemall.utils.lang.ArrayUtils.range;
+import static hivemall.utils.lang.ArrayUtils.slice;
+import static java.lang.Math.abs;
+
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Comparator;
import java.util.List;
import java.util.Random;
+import org.apache.commons.collections.ComparatorUtils;
import org.junit.Assert;
import org.junit.Test;
@@ -92,4 +103,119 @@
Assert.assertEquals(ArrayList.class, actual.getClass());
}
+    /** The array returned by newInstance must equal an all-null array of the requested length. */
+    @Test
+    public void testNewInstance() {
+        Assert.assertArrayEquals(new Integer[2], newInstance(new Integer[] {1, 2, 3}, 2));
+    }
+
+    /** Slicing by an index array returns the addressed elements in the requested order. */
+    @Test
+    public void testSlice() {
+        final int[] source = {5, 2, 0, 1};
+        final int[] expected = {1, 2};
+        Assert.assertArrayEquals(expected, slice(source, new int[] {3, 1}));
+    }
+
+    /** An index beyond the end of the source array must raise IndexOutOfBoundsException. */
+    @Test(expected = IndexOutOfBoundsException.class)
+    public void testSliceIndexOutOfBound() {
+        final int[] source = {5, 2, 0, 1};
+        final int[] neverCompared = {1, 2};
+        Assert.assertArrayEquals(neverCompared, slice(source, new int[] {3, 5}));
+    }
+
+    /** argsort yields the sorting permutation; applying it twice yields the ranks. */
+    @Test
+    public void testArgsortTArray() {
+        final Double[] values = {5d, 2d, 0d, 1d};
+        final int[] expectedOrder = {2, 3, 1, 0};
+        final Double[] expectedSorted = {0d, 1d, 2d, 5d};
+        final int[] expectedRanks = {3, 2, 0, 1};
+
+        Assert.assertArrayEquals(expectedOrder, argsort(values));
+        Assert.assertArrayEquals(expectedSorted, slice(values, argsort(values)));
+        Assert.assertArrayEquals(expectedRanks, argsort(argsort(values)));
+    }
+
+    /** argsort with a custom comparator orders the indices by absolute value. */
+    @Test
+    public void testArgsortTArrayComparatorOfQsuperT() {
+        final Double[] values = {5d, -2d, 0d, -1d};
+        final Comparator<Double> byMagnitude = new Comparator<Double>() {
+            @Override
+            public int compare(Double lhs, Double rhs) {
+                final double lhsAbs = Math.abs(lhs.doubleValue());
+                final double rhsAbs = Math.abs(rhs.doubleValue());
+                return Double.compare(lhsAbs, rhsAbs);
+            }
+        };
+        final int[] expectedOrder = {2, 3, 1, 0};
+        final Double[] expectedSorted = {0d, -1d, -2d, 5d};
+
+        Assert.assertArrayEquals(expectedOrder, argsort(values, byMagnitude));
+        Assert.assertArrayEquals(expectedSorted, slice(values, argsort(values, byMagnitude)));
+    }
+
+    /** argrank on an int array reports the rank of every element. */
+    @Test
+    public void testArgrankIntArray() {
+        final int[] values = {5, 2, 0, 1};
+        final int[] expectedRanks = {3, 2, 0, 1};
+        Assert.assertArrayEquals(expectedRanks, argrank(values));
+    }
+
+    /** argrank on a double array reports the rank of every element. */
+    @Test
+    public void testArgrankDoubleArray() {
+        final double[] values = {5.1d, 2.1d, 0.1d, 1.1d};
+        final int[] expectedRanks = {3, 2, 0, 1};
+        Assert.assertArrayEquals(expectedRanks, argrank(values));
+    }
+
+    /** argmin on a double array returns the index of the smallest value. */
+    @Test
+    public void testArgminDoubleArray() {
+        final double[] values = {5d, 2d, 0.1d, 1d};
+        final double[] expectedMin = {0.1d};
+
+        Assert.assertEquals(2, argmin(values));
+        Assert.assertArrayEquals(expectedMin, slice(values, argmin(values)), 1e-8);
+    }
+
+    /** argmin on an object array returns the index of the smallest value. */
+    @Test
+    public void testArgminTArray() {
+        final Double[] values = {5d, 2d, 0.1d, 1d};
+        final Double[] expectedMin = {0.1d};
+
+        Assert.assertEquals(2, argmin(values));
+        Assert.assertArrayEquals(expectedMin, slice(values, argmin(values)));
+    }
+
+    /** argmin with a custom comparator picks the element with the smallest absolute value. */
+    @Test
+    public void testArgminTArrayComparatorOfQsuperT() {
+        final Double[] values = {5d, -2d, 0.1d, -1d};
+        final Comparator<Double> byMagnitude = new Comparator<Double>() {
+            @Override
+            public int compare(Double lhs, Double rhs) {
+                final double lhsAbs = Math.abs(lhs.doubleValue());
+                final double rhsAbs = Math.abs(rhs.doubleValue());
+                return Double.compare(lhsAbs, rhsAbs);
+            }
+        };
+        final Double[] expectedMin = {0.1d};
+
+        Assert.assertEquals(2, argmin(values, byMagnitude));
+        Assert.assertArrayEquals(expectedMin, slice(values, argmin(values, byMagnitude)));
+    }
+
+    /** argmax on an object array returns the index of the largest value. */
+    @Test
+    public void testArgmaxTArray() {
+        final Double[] values = {5d, 2d, 0.1d, 1d};
+        final Double[] expectedMax = {5d};
+
+        Assert.assertEquals(0, argmax(values));
+        Assert.assertArrayEquals(expectedMax, slice(values, argmax(values)));
+    }
+
+ // Runs argmax over an array containing a null with three different comparators;
+ // the expected index changes with the comparator supplied.
+ @SuppressWarnings("unchecked")
+ @Test
+ public void testArgmaxTArrayComparatorOfQsuperT() {
+ Double[] a = new Double[] {2d, 5d, 0d, null, 1d};
+ // expects index 2 (0d); presumably the reversed ordering makes the smallest value the maximum
+ Assert.assertEquals(2, argmax(a, ComparatorUtils.nullLowComparator(
+ ComparatorUtils.reversedComparator(ComparatorUtils.naturalComparator()))));
+ // expects index 1 (5d)
+ Assert.assertEquals(1,
+ argmax(a, ComparatorUtils.nullLowComparator(ComparatorUtils.naturalComparator())));
+ // expects index 3 (the null element)
+ Assert.assertEquals(3,
+ argmax(a, ComparatorUtils.nullHighComparator(ComparatorUtils.naturalComparator())));
+ }
+
+    /** Exercises range() with an implicit start, an explicit start, and an explicit step. */
+    @Test
+    public void testRange() {
+        final int[] zeroToFour = {0, 1, 2, 3, 4};
+        final int[] oneToFour = {1, 2, 3, 4};
+        final int[] zeroDownToMinusTwo = {0, -1, -2};
+        final int[] fiveDownToTwo = {5, 4, 3, 2};
+        final int[] threeDownToMinusOne = {3, 2, 1, 0, -1};
+
+        // implicit step
+        Assert.assertArrayEquals(zeroToFour, range(5));
+        Assert.assertArrayEquals(oneToFour, range(1, 5));
+        Assert.assertArrayEquals(zeroDownToMinusTwo, range(-3));
+        Assert.assertArrayEquals(fiveDownToTwo, range(5, 1));
+        Assert.assertArrayEquals(threeDownToMinusOne, range(3, -2));
+
+        // explicit step of 1 must match the implicit-step results
+        Assert.assertArrayEquals(zeroToFour, range(0, 5, 1));
+        Assert.assertArrayEquals(oneToFour, range(1, 5, 1));
+        Assert.assertArrayEquals(zeroDownToMinusTwo, range(0, -3, 1));
+        Assert.assertArrayEquals(fiveDownToTwo, range(5, 1, 1));
+        Assert.assertArrayEquals(threeDownToMinusOne, range(3, -2, 1));
+
+        // step of 2, ascending and descending
+        Assert.assertArrayEquals(new int[] {1, 3}, range(1, 5, 2));
+        Assert.assertArrayEquals(new int[] {1, 3, 5}, range(1, 6, 2));
+        Assert.assertArrayEquals(new int[] {6, 4, 2}, range(6, 1, 2));
+        Assert.assertArrayEquals(new int[] {-1, -3, -5}, range(-1, -6, 2));
+    }
+
}
diff --git a/docs/gitbook/misc/generic_funcs.md b/docs/gitbook/misc/generic_funcs.md
index aa4b462..9dfbc2b 100644
--- a/docs/gitbook/misc/generic_funcs.md
+++ b/docs/gitbook/misc/generic_funcs.md
@@ -23,6 +23,45 @@
# Array
+- `arange([int start=0, ] int stop, [int step=1])` - Return evenly spaced values within a given interval
+ ```sql
+ SELECT arange(5), arange(1, 5), arange(1, 5, 1), arange(0, 5, 1);
+ > [0,1,2,3,4] [1,2,3,4] [1,2,3,4] [0,1,2,3,4]
+
+ SELECT arange(1, 6, 2);
+ > 1, 3, 5
+
+ SELECT arange(-1, -6, 2);
+ > -1, -3, -5
+ ```
+
+- `argmax(array<T> a)` - Returns the first index of the maximum value
+ ```sql
+ SELECT argmax(array(5,2,0,1));
+ > 0
+ ```
+
+- `argmin(array<T> a)` - Returns the first index of the minimum value
+ ```sql
+ SELECT argmin(array(5,2,0,1));
+ > 2
+ ```
+
+- `argrank(array<ANY> a)` - Returns the rank of each element, i.e., the index each element would occupy in the sorted array (equivalent to `argsort(argsort(a))`).
+ ```sql
+ SELECT argrank(array(5,2,0,1)), argsort(argsort(array(5,2,0,1)));
+ > [3, 2, 0, 1] [3, 2, 0, 1]
+ ```
+
+- `argsort(array<ANY> a)` - Returns the indices that would sort an array.
+ ```sql
+ SELECT argsort(array(5,2,0,1));
+ > 2, 3, 1, 0
+
+ SELECT array_slice(array(5,2,0,1), argsort(array(5,2,0,1)));
+ > 0, 1, 2, 5
+ ```
+
- `array_append(array<T> arr, T elem)` - Append an element to the end of an array
```sql
SELECT array_append(array(1,2),3);
diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive
index 343215a..1fcb3b8 100644
--- a/resources/ddl/define-all-as-permanent.hive
+++ b/resources/ddl/define-all-as-permanent.hive
@@ -464,6 +464,21 @@
DROP FUNCTION IF EXISTS conditional_emit;
CREATE FUNCTION conditional_emit as 'hivemall.tools.array.ConditionalEmitUDTF' USING JAR '${hivemall_jar}';
+DROP FUNCTION IF EXISTS argmin;
+CREATE FUNCTION argmin as 'hivemall.tools.array.ArgminUDF' USING JAR '${hivemall_jar}';
+
+DROP FUNCTION IF EXISTS argmax;
+CREATE FUNCTION argmax as 'hivemall.tools.array.ArgmaxUDF' USING JAR '${hivemall_jar}';
+
+DROP FUNCTION IF EXISTS arange;
+CREATE FUNCTION arange as 'hivemall.tools.array.ArangeUDF' USING JAR '${hivemall_jar}';
+
+DROP FUNCTION IF EXISTS argrank;
+CREATE FUNCTION argrank as 'hivemall.tools.array.ArgrankUDF' USING JAR '${hivemall_jar}';
+
+DROP FUNCTION IF EXISTS argsort;
+CREATE FUNCTION argsort as 'hivemall.tools.array.ArgsortUDF' USING JAR '${hivemall_jar}';
+
-----------------------------
-- bit operation functions --
-----------------------------
diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive
index 2a9b437..2658e5e 100644
--- a/resources/ddl/define-all.hive
+++ b/resources/ddl/define-all.hive
@@ -456,6 +456,21 @@
drop temporary function if exists conditional_emit;
create temporary function conditional_emit as 'hivemall.tools.array.ConditionalEmitUDTF';
+drop temporary function if exists argmin;
+create temporary function argmin as 'hivemall.tools.array.ArgminUDF';
+
+drop temporary function if exists argmax;
+create temporary function argmax as 'hivemall.tools.array.ArgmaxUDF';
+
+drop temporary function if exists arange;
+create temporary function arange as 'hivemall.tools.array.ArangeUDF';
+
+drop temporary function if exists argrank;
+create temporary function argrank as 'hivemall.tools.array.ArgrankUDF';
+
+drop temporary function if exists argsort;
+create temporary function argsort as 'hivemall.tools.array.ArgsortUDF';
+
-----------------------------
-- bit operation functions --
-----------------------------
@@ -892,4 +907,3 @@
create temporary macro tfidf(tf FLOAT, df_t DOUBLE, n_docs DOUBLE)
tf * (log(10, n_docs / max2(1,df_t)) + 1.0);
-
diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark
index d62e3a2..89f3def 100644
--- a/resources/ddl/define-all.spark
+++ b/resources/ddl/define-all.spark
@@ -455,6 +455,21 @@
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS conditional_emit")
sqlContext.sql("CREATE TEMPORARY FUNCTION conditional_emit AS 'hivemall.tools.array.ConditionalEmitUDTF'")
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS argmin")
+sqlContext.sql("CREATE TEMPORARY FUNCTION argmin AS 'hivemall.tools.array.ArgminUDF'")
+
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS argmax")
+sqlContext.sql("CREATE TEMPORARY FUNCTION argmax AS 'hivemall.tools.array.ArgmaxUDF'")
+
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS arange")
+sqlContext.sql("CREATE TEMPORARY FUNCTION arange AS 'hivemall.tools.array.ArangeUDF'")
+
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS argrank")
+sqlContext.sql("CREATE TEMPORARY FUNCTION argrank AS 'hivemall.tools.array.ArgrankUDF'")
+
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS argsort")
+sqlContext.sql("CREATE TEMPORARY FUNCTION argsort AS 'hivemall.tools.array.ArgsortUDF'")
+
/**
* Bit operation functions
*/