[HIVEMALL-314] fixed Spark DDLs
## What changes were proposed in this pull request?
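Fixed the Spark DDLs in `resources/ddl/define-all.spark`: registrations that referenced `*Wrapper` classes now point at the plain UDF/UDTF classes, the `XGBoostOnlinePredictUDTFF` class-name typo is corrected, several function registrations are removed, and `taskid`/`rownum` are added. `HadoopUtils.getTaskId()` now falls back to the Spark partition id when no Hive `MapredContext` is set, and the Spark installation guide is updated accordingly.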
## What type of PR is it?
Bug Fix
## What is the Jira issue?
https://issues.apache.org/jira/browse/HIVEMALL-314
## How was this patch tested?
manual tests
## Checklist
- [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit?
- [x] Did you run system tests on Hive (or Spark)?
Author: Makoto Yui <myui@apache.org>
Closes #244 from myui/HIVEMALL-314-fix-spark-ddls.
diff --git a/core/src/main/java/hivemall/dataset/LogisticRegressionDataGeneratorUDTF.java b/core/src/main/java/hivemall/dataset/LogisticRegressionDataGeneratorUDTF.java
index 5b87183..24cd9d5 100644
--- a/core/src/main/java/hivemall/dataset/LogisticRegressionDataGeneratorUDTF.java
+++ b/core/src/main/java/hivemall/dataset/LogisticRegressionDataGeneratorUDTF.java
@@ -155,13 +155,9 @@
public void process(Object[] argOIs) throws HiveException {
if (rnd1 == null) {
assert (rnd2 == null);
- final int taskid = HadoopUtils.getTaskId(-1);
- final long seed;
- if (taskid == -1) {
- seed = r_seed; // Non-MR local task
- } else {
- seed = r_seed + taskid;
- }
+ int threadId = (int) Thread.currentThread().getId();
+ int taskid = HadoopUtils.getTaskId(threadId);
+ long seed = r_seed + taskid;
this.rnd1 = new Random(seed);
this.rnd2 = new Random(seed + 1);
}
diff --git a/core/src/main/java/hivemall/utils/hadoop/HadoopUtils.java b/core/src/main/java/hivemall/utils/hadoop/HadoopUtils.java
index 10a17dc..a85798a 100644
--- a/core/src/main/java/hivemall/utils/hadoop/HadoopUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/HadoopUtils.java
@@ -27,6 +27,8 @@
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
import java.net.URI;
import java.util.Iterator;
import java.util.Map.Entry;
@@ -138,7 +140,12 @@
public static int getTaskId() {
MapredContext ctx = MapredContextAccessor.get();
if (ctx == null) {
- throw new IllegalStateException("MapredContext is not set");
+ final int sparkTaskId = getSparkTaskId(-1);
+ if (sparkTaskId != -1) {
+ return sparkTaskId;
+ }
+ throw new IllegalStateException(
+ "Both hive.ql.exec.MapredContext and spark.TaskContext is not set");
}
JobConf jobconf = ctx.getJobConf();
if (jobconf == null) {
@@ -175,6 +182,46 @@
return taskid;
}
+    /**
+     * Resolves the Spark partition id via reflection on org.apache.spark.TaskContext.
+     *
+     * @param defaultValue returned when the class, its methods, or an active TaskContext is missing
+     * @return org.apache.spark.TaskContext.get().partitionId(), or defaultValue
+     */
+    public static int getSparkTaskId(final int defaultValue) {
+        final Class<?> clazz;
+        try {
+            clazz = Class.forName("org.apache.spark.TaskContext");
+        } catch (ClassNotFoundException e) {
+            return defaultValue;
+        }
+        final Method getMethod;
+        final Method partitionIdMethod;
+        try {
+            getMethod = clazz.getDeclaredMethod("get");
+            partitionIdMethod = clazz.getDeclaredMethod("partitionId");
+        } catch (NoSuchMethodException | SecurityException e) {
+            return defaultValue;
+        }
+        final Object taskContextInstance;
+        try {
+            taskContextInstance = getMethod.invoke(null);
+        } catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) {
+            return defaultValue;
+        }
+        if (taskContextInstance == null) {
+            // No active TaskContext: invoking partitionId() on null would throw an uncaught NPE.
+            return defaultValue;
+        }
+        final Object result;
+        try {
+            result = partitionIdMethod.invoke(taskContextInstance);
+        } catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) {
+            return defaultValue;
+        }
+        return (result instanceof Integer) ? ((Integer) result).intValue() : defaultValue;
+    }
+
public static String getUniqueTaskIdString() {
MapredContext ctx = MapredContextAccessor.get();
if (ctx != null) {
diff --git a/docs/gitbook/spark/getting_started/installation.md b/docs/gitbook/spark/getting_started/installation.md
index d30b230..9398f79 100644
--- a/docs/gitbook/spark/getting_started/installation.md
+++ b/docs/gitbook/spark/getting_started/installation.md
@@ -20,9 +20,9 @@
Prerequisites
============
-* Spark v2.1 or later
-* Java 7 or later
-* `hivemall-spark-xxx-with-dependencies.jar` that can be found in [the ASF distribution mirror](https://www.apache.org/dyn/closer.cgi/incubator/hivemall/).
+* Spark v2.2 or later
+* Java 8 or later
+* `hivemall-all-<version>.jar` that can be found in [Maven central](https://search.maven.org/search?q=a:hivemall-all) (or use packages built by `bin/build.sh`).
* [define-all.spark](https://github.com/apache/incubator-hivemall/blob/master/resources/ddl/define-all.spark)
Installation
@@ -43,15 +43,11 @@
$ spark-shell --packages org.apache.hivemall:hivemall-all:<version>
```
-You find available Hivemall versions on [Maven repository](https://mvnrepository.com/artifact/org.apache.hivemall/hivemall-all/0.5.2-incubating).
+You can find the available Hivemall versions on the [Maven repository](https://mvnrepository.com/artifact/org.apache.hivemall/hivemall-all/).
-> #### Notice
-> If you would like to try Hivemall functions on the latest release of Spark, you just say `bin/spark-shell` in a Hivemall package.
-> This command automatically downloads the latest Spark version, compiles Hivemall for the version, and invokes spark-shell with the compiled Hivemall binary.
-
Then, you load scripts for Hivemall functions.
```
-scala> :load resources/ddl/define-all.spark
+scala> :load ~/workspace/incubator-hivemall/resources/ddl/define-all.spark
```
\ No newline at end of file
diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark
index c9e7efc..30c465d 100644
--- a/resources/ddl/define-all.spark
+++ b/resources/ddl/define-all.spark
@@ -139,7 +139,7 @@
*/
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS minhashes")
-sqlContext.sql("CREATE TEMPORARY FUNCTION minhashes AS 'hivemall.knn.lsh.MinHashesUDFWrapper'")
+sqlContext.sql("CREATE TEMPORARY FUNCTION minhashes AS 'hivemall.knn.lsh.MinHashesUDF'")
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS minhash")
sqlContext.sql("CREATE TEMPORARY FUNCTION minhash AS 'hivemall.knn.lsh.MinHashUDTF'")
@@ -239,16 +239,16 @@
sqlContext.sql("CREATE TEMPORARY FUNCTION rand_amplify AS 'hivemall.ftvec.amplify.RandomAmplifierUDTF'")
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS add_bias")
-sqlContext.sql("CREATE TEMPORARY FUNCTION add_bias AS 'hivemall.ftvec.AddBiasUDFWrapper'")
+sqlContext.sql("CREATE TEMPORARY FUNCTION add_bias AS 'hivemall.ftvec.AddBiasUDF'")
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS sort_by_feature")
sqlContext.sql("CREATE TEMPORARY FUNCTION sort_by_feature AS 'hivemall.ftvec.SortByFeatureUDF'")
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS extract_feature")
-sqlContext.sql("CREATE TEMPORARY FUNCTION extract_feature AS 'hivemall.ftvec.ExtractFeatureUDFWrapper'")
+sqlContext.sql("CREATE TEMPORARY FUNCTION extract_feature AS 'hivemall.ftvec.ExtractFeatureUDF'")
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS extract_weight")
-sqlContext.sql("CREATE TEMPORARY FUNCTION extract_weight AS 'hivemall.ftvec.ExtractWeightUDFWrapper'")
+sqlContext.sql("CREATE TEMPORARY FUNCTION extract_weight AS 'hivemall.ftvec.ExtractWeightUDF'")
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS add_feature_index")
sqlContext.sql("CREATE TEMPORARY FUNCTION add_feature_index AS 'hivemall.ftvec.AddFeatureIndexUDF'")
@@ -398,9 +398,6 @@
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS float_array")
sqlContext.sql("CREATE TEMPORARY FUNCTION float_array AS 'hivemall.tools.array.AllocFloatArrayUDF'")
-sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS array_remove")
-sqlContext.sql("CREATE TEMPORARY FUNCTION array_remove AS 'hivemall.tools.array.ArrayRemoveUDF'")
-
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS sort_and_uniq_array")
sqlContext.sql("CREATE TEMPORARY FUNCTION sort_and_uniq_array AS 'hivemall.tools.array.SortAndUniqArrayUDF'")
@@ -428,21 +425,12 @@
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS to_string_array")
sqlContext.sql("CREATE TEMPORARY FUNCTION to_string_array AS 'hivemall.tools.array.ToStringArrayUDF'")
-sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS array_intersect")
-sqlContext.sql("CREATE TEMPORARY FUNCTION array_intersect AS 'hivemall.tools.array.ArrayIntersectUDF'")
-
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS select_k_best")
sqlContext.sql("CREATE TEMPORARY FUNCTION select_k_best AS 'hivemall.tools.array.SelectKBestUDF'")
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS array_append")
sqlContext.sql("CREATE TEMPORARY FUNCTION array_append AS 'hivemall.tools.array.ArrayAppendUDF'")
-sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS element_at")
-sqlContext.sql("CREATE TEMPORARY FUNCTION element_at AS 'hivemall.tools.array.ArrayElementAtUDF'")
-
-sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS array_union")
-sqlContext.sql("CREATE TEMPORARY FUNCTION array_union AS 'hivemall.tools.array.ArrayUnionUDF'")
-
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS first_element")
sqlContext.sql("CREATE TEMPORARY FUNCTION first_element AS 'hivemall.tools.array.FirstElementUDF'")
@@ -579,18 +567,14 @@
* MAPRED functions
*/
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS taskid")
+sqlContext.sql("CREATE TEMPORARY FUNCTION taskid AS 'hivemall.tools.mapred.TaskIdUDF'")
+
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS rowid")
-sqlContext.sql("CREATE TEMPORARY FUNCTION rowid AS 'hivemall.tools.mapred.RowIdUDFWrapper'")
+sqlContext.sql("CREATE TEMPORARY FUNCTION rowid AS 'hivemall.tools.mapred.RowIdUDF'")
-/**
- * JSON functions
- */
-
-sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS to_json")
-sqlContext.sql("CREATE TEMPORARY FUNCTION to_json AS 'hivemall.tools.json.ToJsonUDF'")
-
-sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS from_json")
-sqlContext.sql("CREATE TEMPORARY FUNCTION from_json AS 'hivemall.tools.json.FromJsonUDF'")
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS rownum")
+sqlContext.sql("CREATE TEMPORARY FUNCTION rownum AS 'hivemall.tools.mapred.RowNumberUDF'")
/**
* Sanity Check functions
@@ -599,9 +583,6 @@
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS assert")
sqlContext.sql("CREATE TEMPORARY FUNCTION assert AS 'hivemall.tools.sanity.AssertUDF'")
-sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS raise_error")
-sqlContext.sql("CREATE TEMPORARY FUNCTION raise_error AS 'hivemall.tools.sanity.RaiseErrorUDF'")
-
/**
* MISC functions
*/
@@ -663,7 +644,7 @@
*/
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS lr_datagen")
-sqlContext.sql("CREATE TEMPORARY FUNCTION lr_datagen AS 'hivemall.dataset.LogisticRegressionDataGeneratorUDTFWrapper'")
+sqlContext.sql("CREATE TEMPORARY FUNCTION lr_datagen AS 'hivemall.dataset.LogisticRegressionDataGeneratorUDTF'")
/**
* Evaluating functions
@@ -839,13 +820,6 @@
sqlContext.sql("CREATE TEMPORARY FUNCTION train_slim AS 'hivemall.recommend.SlimUDTF'")
/**
- * Data Sketch
- */
-
-sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS approx_count_distinct")
-sqlContext.sql("CREATE TEMPORARY FUNCTION approx_count_distinct AS 'hivemall.sketch.hll.ApproxCountDistinctUDAF'")
-
-/**
* Bloom Filter
*/
@@ -871,12 +845,6 @@
* Aggregation
*/
-sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS max_by")
-sqlContext.sql("CREATE TEMPORARY FUNCTION max_by AS 'hivemall.tools.aggr.MaxByUDAF'")
-
-sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS min_by")
-sqlContext.sql("CREATE TEMPORARY FUNCTION min_by AS 'hivemall.tools.aggr.MinByUDAF'")
-
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS majority_vote")
sqlContext.sql("CREATE TEMPORARY FUNCTION majority_vote AS 'hivemall.tools.aggr.MajorityVoteUDAF'")
@@ -910,7 +878,7 @@
sqlContext.sql("CREATE TEMPORARY FUNCTION train_xgboost AS 'hivemall.xgboost.XGBoostTrainUDTF'")
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS xgboost_predict")
-sqlContext.sql("CREATE TEMPORARY FUNCTION xgboost_predict AS 'hivemall.xgboost.XGBoostOnlinePredictUDTFF'")
+sqlContext.sql("CREATE TEMPORARY FUNCTION xgboost_predict AS 'hivemall.xgboost.XGBoostOnlinePredictUDTF'")
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS xgboost_batch_predict")
sqlContext.sql("CREATE TEMPORARY FUNCTION xgboost_batch_predict AS 'hivemall.xgboost.XGBoostBatchPredictUDTF'")