datafu-spark subproject
Signed-off-by: Eyal Allweil <eyal@apache.org>
diff --git a/.gitignore b/.gitignore
index e219be2..28aa89b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -32,4 +32,12 @@
gradle/wrapper/gradle-wrapper.jar
gradle/wrapper/gradle-wrapper.properties
gradlew*
-.java-version
\ No newline at end of file
+.java-version
+datafu-spark/out
+datafu-spark/derby.log
+datafu-spark/spark-warehouse
+datafu-spark/metastore_db
+datafu-spark/bin
+datafu-spark/test-output
+datafu-spark/data
+datafu-spark/.apt_generated/
\ No newline at end of file
diff --git a/build.gradle b/build.gradle
index 4b7f396..cbe1a06 100644
--- a/build.gradle
+++ b/build.gradle
@@ -92,7 +92,11 @@
'datafu-pig/input*',
'datafu-pig/docs',
'datafu-pig/queries',
- 'datafu-pig/query'
+ 'datafu-pig/query',
+ 'datafu-spark/metastore_db/**',
+ 'datafu-spark/spark-warehouse/**',
+ 'datafu-spark/derby.log',
+ 'datafu-spark/data/**'
]
}
diff --git a/datafu-spark/README.md b/datafu-spark/README.md
new file mode 100644
index 0000000..429f136
--- /dev/null
+++ b/datafu-spark/README.md
@@ -0,0 +1,89 @@
+# datafu-spark
+
+datafu-spark contains a number of Spark APIs and a "Scala-Python bridge" that makes calling Scala code from Python, and vice-versa, easier.
+
+Here are some examples of things you can do with it:
+
+* ["Dedup" a table](https://github.com/apache/datafu/blob/spark-tmp/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala#L139) - remove duplicates based on a key and ordering (typically a date updated field, to get only the most recently updated record).
+
+* [Join a table with a numeric field with a table with a range](https://github.com/apache/datafu/blob/spark-tmp/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala#L361)
+
+* [Do a skewed join between tables](https://github.com/apache/datafu/blob/spark-tmp/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala#L274) (where the small table is still too big to fit in memory)
+
+* [Count distinct up to](https://github.com/apache/datafu/blob/spark-tmp/datafu-spark/src/main/scala/datafu/spark/SparkUDAFs.scala#L224) - an efficient implementation when you just want to verify that a certain minimum of distinct rows appear in a table
+
+It has been tested on Spark releases from 2.1.0 to 2.4.3, using Scala 2.10, 2.11 and 2.12. You can check if your Spark/Scala version combination has been tested by looking [here.](https://github.com/apache/datafu/blob/spark-tmp/datafu-spark/build_and_test_spark.sh#L20)
+
+-----------
+
+In order to call the datafu-spark APIs from PySpark, you can do the following (tested on a Hortonworks VM)
+
+First, call pyspark with the following parameters
+
+```bash
+export PYTHONPATH=datafu-spark_2.11_2.3.0-1.5.0-SNAPSHOT.jar
+
+pyspark --jars datafu-spark_2.11_2.3.0-1.5.0-SNAPSHOT.jar --conf spark.executorEnv.PYTHONPATH=datafu-spark_2.11_2.3.0-1.5.0-SNAPSHOT.jar
+```
+
+The following is an example of calling the Spark version of the datafu _dedup_ method
+
+```python
+from pyspark_utils.df_utils import PySparkDFUtils
+
+df_utils = PySparkDFUtils()
+
+df_people = sqlContext.createDataFrame([
+ ("a", "Alice", 34),
+ ("a", "Sara", 33),
+ ("b", "Bob", 36),
+ ("b", "Charlie", 30),
+ ("c", "David", 29),
+ ("c", "Esther", 32),
+ ("c", "Fanny", 36),
+ ("c", "Zoey", 36)],
+ ["id", "name", "age"])
+
+func_dedup_res = df_utils.dedup(dataFrame=df_people, groupCol=df_people.id,
+ orderCols=[df_people.age.desc(), df_people.name.desc()])
+
+func_dedup_res.registerTempTable("dedup")
+
+func_dedup_res.show()
+```
+
+This should produce the following output
+
+<pre>
++---+-----+---+
+| id| name|age|
++---+-----+---+
+| c| Zoey| 36|
+| b| Bob| 36|
+| a|Alice| 34|
++---+-----+---+
+</pre>
+
+-----------
+
+# Development
+
+Building and testing datafu-spark can be done as described in [the main DataFu README.](https://github.com/apache/datafu/blob/master/README.md#developers)
+
+If you wish to build for a specific Scala/Spark version, there are two options. One is to change the *scalaVersion* and *sparkVersion* in [the main gradle.properties file.](https://github.com/apache/datafu/blob/spark-tmp/gradle.properties#L22)
+
+The other is to pass these parameters in the command line. For example, to build and test for Scala 2.12 and Spark 2.4.0, you would use
+
+```bash
+./gradlew :datafu-spark:test -PscalaVersion=2.12 -PsparkVersion=2.4.0
+```
+
+There is a [script](https://github.com/apache/datafu/tree/spark-tmp/datafu-spark/build_and_test_spark.sh) for building and testing datafu-spark across the multiple Scala/Spark combinations.
+
+To see the available options run it like this:
+
+```bash
+./build_and_test_spark.sh -h
+```
+
+
diff --git a/datafu-spark/build.gradle b/datafu-spark/build.gradle
new file mode 100644
index 0000000..52da990
--- /dev/null
+++ b/datafu-spark/build.gradle
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+// Much of this file is a variation on the Apache Samza build.gradle file
+
+buildscript {
+ repositories {
+ mavenCentral()
+ }
+}
+
+plugins {
+ id "de.undercouch.download" version "3.4.3"
+}
+
+apply from: file("gradle/dependency-versions-scala-" + scalaVersion + ".gradle")
+
+apply plugin: 'scala'
+
+allprojects {
+ // For all scala compilation, add extra compiler options, taken from version-specific
+ // dependency-versions-scala file applied above.
+ tasks.withType(ScalaCompile) {
+ scalaCompileOptions.additionalParameters = [ scalaOptions ]
+ }
+}
+
+archivesBaseName = 'datafu-spark_' + scalaVersion + '_' + sparkVersion
+
+cleanEclipse {
+ doLast {
+ delete ".apt_generated"
+ delete ".settings"
+ delete ".factorypath"
+ delete "bin"
+ }
+}
+
+dependencies {
+ compile "org.scala-lang:scala-library:$scalaLibVersion"
+
+ testCompile "com.holdenkarau:spark-testing-base_" + scalaVersion + ":" + sparkVersion + "_" + sparkTestingBaseVersion
+  testCompile "org.scalatest:scalatest_" + scalaVersion + ":" + sparkVersion // NOTE(review): version resolves to the Spark release (e.g. 2.4.0), which is not a valid scalatest version — confirm the intended scalatest version property
+}
+
+// we need to set up the build for hadoop 3
+if (hadoopVersion.startsWith("2.")) {
+ dependencies {
+ compile "org.apache.hadoop:hadoop-common:$hadoopVersion"
+ compile "org.apache.hadoop:hadoop-hdfs:$hadoopVersion"
+ compile "org.apache.hadoop:hadoop-mapreduce-client-jobclient:$hadoopVersion"
+ compile "org.apache.spark:spark-core_" + scalaVersion + ":" + sparkVersion
+ compile "org.apache.spark:spark-hive_" + scalaVersion + ":" + sparkVersion
+ }
+} else {
+ dependencies {
+ compile "org.apache.hadoop:hadoop-core:$hadoopVersion"
+ }
+}
+
+project.ext.sparkFile = file("build/spark-zips/spark-" + sparkVersion + ".zip")
+project.ext.sparkUnzipped = "build/spark-unzipped/spark-" + sparkVersion
+
+// download pyspark for testing. This is not shipped with datafu-spark.
+task downloadPySpark (type: Download) {
+ src 'https://github.com/apache/spark/archive/v' + sparkVersion + '.zip'
+ dest project.sparkFile
+ onlyIfNewer true
+}
+
+downloadPySpark.onlyIf {
+ ! project.sparkFile.exists()
+}
+
+task unzipPySpark(dependsOn: downloadPySpark, type: Copy) {
+ from zipTree(downloadPySpark.dest)
+ into file("build/spark-unzipped/")
+}
+
+unzipPySpark.onlyIf {
+ ! file(project.sparkUnzipped).exists()
+}
+
+task zipPySpark(dependsOn: unzipPySpark, type: Zip) {
+ archiveName = "pyspark-" + sparkVersion + ".zip"
+ include "pyspark/**/*"
+ destinationDir = file("data/pysparks/")
+ from file(project.sparkUnzipped + "/python/")
+}
+
+zipPySpark.onlyIf {
+ ! file("data/pysparks/pyspark-" + sparkVersion + ".zip").exists()
+}
+
+// download py4j for testing. This is not shipped with datafu-spark.
+project.ext.py4js = [
+ "py4j-0.10.4-src.zip" : "https://files.pythonhosted.org/packages/93/a7/0e1719e8ad34d194aae72dc07a37e65fd3895db7c797a67a828333cd6067/py4j-0.10.4-py2.py3-none-any.whl",
+ "py4j-0.10.6-src.zip" : "https://files.pythonhosted.org/packages/4a/08/162710786239aa72bd72bb46c64f2b02f54250412ba928cb373b30699139/py4j-0.10.6-py2.py3-none-any.whl",
+ "py4j-0.10.7-src.zip" : "https://files.pythonhosted.org/packages/e3/53/c737818eb9a7dc32a7cd4f1396e787bd94200c3997c72c1dbe028587bd76/py4j-0.10.7-py2.py3-none-any.whl",
+ "py4j-0.10.8.1-src.zip" : "https://files.pythonhosted.org/packages/04/de/2d314a921ef4c20b283e1de94e0780273678caac901564df06b948e4ba9b/py4j-0.10.8.1-py2.py3-none-any.whl"
+]
+
+task downloadPy4js {
+ doLast {
+ for (s in py4js) {
+ download {
+ src s.value
+ dest file("data/py4js/" + s.key)
+ }
+ }
+ }
+}
+
+downloadPy4js.onlyIf {
+ ! file("data/py4js").exists()
+}
+
+
+// The downloads of pyspark and py4j must succeed in order to test the Scala Python bridge in Eclipse or Gradle
+tasks.eclipse.dependsOn('zipPySpark')
+tasks.compileTestScala.dependsOn('zipPySpark')
+tasks.eclipse.dependsOn('downloadPy4js')
+tasks.compileTestScala.dependsOn('downloadPy4js')
+
+
+test {
+ systemProperty 'datafu.jar.dir', file('build/libs')
+ systemProperty 'datafu.data.dir', file('data')
+
+ systemProperty 'datafu.spark.version', sparkVersion
+
+ maxHeapSize = "2G"
+}
diff --git a/datafu-spark/build_and_test_spark.sh b/datafu-spark/build_and_test_spark.sh
new file mode 100755
index 0000000..ff8cdcd
--- /dev/null
+++ b/datafu-spark/build_and_test_spark.sh
@@ -0,0 +1,120 @@
+#!/bin/bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+export SPARK_VERSIONS_FOR_SCALA_210="2.1.0 2.1.1 2.1.2 2.1.3 2.2.0 2.2.1 2.2.2"
+export SPARK_VERSIONS_FOR_SCALA_211="2.1.0 2.1.1 2.1.2 2.1.3 2.2.0 2.2.1 2.2.2 2.3.0 2.3.1 2.3.2 2.4.0 2.4.1 2.4.2 2.4.3"
+export SPARK_VERSIONS_FOR_SCALA_212="2.4.0 2.4.1 2.4.2 2.4.3"
+
+export LATEST_SPARK_VERSIONS_FOR_SCALA_210="2.1.3 2.2.2"
+export LATEST_SPARK_VERSIONS_FOR_SCALA_211="2.1.3 2.2.2 2.3.2 2.4.3"
+export LATEST_SPARK_VERSIONS_FOR_SCALA_212="2.4.3"
+
+STARTTIME=$(date +%s)
+
+function log {
+  echo "$1"
+  if [[ $LOG_FILE != "NONE" ]]; then
+    echo "$1" >> "$LOG_FILE"
+ fi
+}
+
+function build {
+ echo "----- Building versions for Scala $scala, Spark $spark ----"
+ if ./gradlew :datafu-spark:clean; then
+ echo "----- Clean for Scala $scala, spark $spark succeeded"
+ if ./gradlew :datafu-spark:assemble -PscalaVersion=$scala -PsparkVersion=$spark; then
+ echo "----- Build for Scala $scala, spark $spark succeeded"
+ if ./gradlew :datafu-spark:test -PscalaVersion=$scala -PsparkVersion=$spark $TEST_PARAMS; then
+ log "Testing for Scala $scala, spark $spark succeeded"
+ if [[ $JARS_DIR != "NONE" ]]; then
+ cp datafu-spark/build/libs/*.jar $JARS_DIR/
+ fi
+ else
+ log "Testing for Scala $scala, spark $spark failed (build succeeded)"
+ fi
+ else
+ log "Build for Scala $scala, spark $spark failed"
+ fi
+ else
+ log "Clean for Scala $scala, Spark $spark failed"
+ fi
+}
+
+# -------------------------------------
+
+export JARS_DIR=NONE
+export LOG_FILE=NONE
+
+while getopts "l:j:t:hq" arg; do
+ case $arg in
+ l)
+ LOG_FILE=$OPTARG
+ ;;
+ j)
+ JARS_DIR=$OPTARG
+ ;;
+ t)
+ TEST_PARAMS=$OPTARG
+ ;;
+ q)
+ SPARK_VERSIONS_FOR_SCALA_210=$LATEST_SPARK_VERSIONS_FOR_SCALA_210
+ SPARK_VERSIONS_FOR_SCALA_211=$LATEST_SPARK_VERSIONS_FOR_SCALA_211
+ SPARK_VERSIONS_FOR_SCALA_212=$LATEST_SPARK_VERSIONS_FOR_SCALA_212
+ ;;
+ h)
+ echo "Builds and tests datafu-spark in multiple Scala/Spark combinations"
+      echo "Usage: build_and_test_spark <options>"
+ echo " -t Optional. Name of param for passing to Gradle testing - for example, to only run the test 'FakeTest' pass '-Dtest.single=FakeTest'"
+ echo " -j Optional. Dir for putting artifacts that have compiled and tested successfully"
+ echo " -l Optional. Name of file for writing build summary log"
+ echo " -q Optional. Quick - only build and test the latest minor version of each major Spark release"
+ echo " -h Optional. Prints this help"
+ exit 0
+ ;;
+ esac
+done
+
+if [[ $LOG_FILE != "NONE" ]]; then
+ echo "Building datafu-spark: $TEST_PARAMS" > $LOG_FILE
+fi
+
+if [[ $JARS_DIR != "NONE" ]]; then
+  log "Copying successfully built and tested jars to $JARS_DIR"
+  mkdir -p "$JARS_DIR"
+fi
+
+export scala=2.10
+for spark in $SPARK_VERSIONS_FOR_SCALA_210; do
+ build
+done
+
+export scala=2.11
+for spark in $SPARK_VERSIONS_FOR_SCALA_211; do
+ build
+done
+
+export scala=2.12
+for spark in $SPARK_VERSIONS_FOR_SCALA_212; do
+ build
+done
+
+export ENDTIME=$(date +%s)
+
+log "Build took $(((($ENDTIME - $STARTTIME))/60)) minutes, $((($ENDTIME - $STARTTIME)%60)) seconds"
+
diff --git a/datafu-spark/gradle/dependency-versions-scala-2.10.gradle b/datafu-spark/gradle/dependency-versions-scala-2.10.gradle
new file mode 100644
index 0000000..75cec19
--- /dev/null
+++ b/datafu-spark/gradle/dependency-versions-scala-2.10.gradle
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+ext {
+ scalaVersion = "2.10"
+ scalaLibVersion = "2.10.7"
+ sparkTestingBaseVersion = "0.12.0"
+ // Extra options for the compiler:
+ // -feature: Give detailed warnings about language feature use (rather than just 'there were 4 warnings')
+ scalaOptions = "-feature"
+}
diff --git a/datafu-spark/gradle/dependency-versions-scala-2.11.gradle b/datafu-spark/gradle/dependency-versions-scala-2.11.gradle
new file mode 100644
index 0000000..faae562
--- /dev/null
+++ b/datafu-spark/gradle/dependency-versions-scala-2.11.gradle
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+ext {
+ scalaVersion = "2.11"
+ scalaLibVersion = "2.11.8"
+ sparkTestingBaseVersion = "0.12.0"
+ // Extra options for the compiler:
+ // -feature: Give detailed warnings about language feature use (rather than just 'there were 4 warnings')
+ scalaOptions = "-feature"
+}
diff --git a/datafu-spark/gradle/dependency-versions-scala-2.12.gradle b/datafu-spark/gradle/dependency-versions-scala-2.12.gradle
new file mode 100644
index 0000000..7e31630
--- /dev/null
+++ b/datafu-spark/gradle/dependency-versions-scala-2.12.gradle
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+ext {
+ scalaVersion = "2.12"
+ scalaLibVersion = "2.12.8"
+ sparkTestingBaseVersion = "0.12.0"
+ // Extra options for the compiler:
+ // -feature: Give detailed warnings about language feature use (rather than just 'there were 4 warnings')
+ scalaOptions = "-feature"
+}
diff --git a/datafu-spark/src/main/resources/META-INF/LICENSE b/datafu-spark/src/main/resources/META-INF/LICENSE
new file mode 100644
index 0000000..57bc88a
--- /dev/null
+++ b/datafu-spark/src/main/resources/META-INF/LICENSE
@@ -0,0 +1,202 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
diff --git a/datafu-spark/src/main/resources/META-INF/NOTICE b/datafu-spark/src/main/resources/META-INF/NOTICE
new file mode 100644
index 0000000..123f612
--- /dev/null
+++ b/datafu-spark/src/main/resources/META-INF/NOTICE
@@ -0,0 +1,6 @@
+Apache DataFu
+Copyright 2010-2018 The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
+
diff --git a/datafu-spark/src/main/resources/META-INF/services/datafu.spark.PythonResource b/datafu-spark/src/main/resources/META-INF/services/datafu.spark.PythonResource
new file mode 100644
index 0000000..ee69a0d
--- /dev/null
+++ b/datafu-spark/src/main/resources/META-INF/services/datafu.spark.PythonResource
@@ -0,0 +1,18 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+datafu.spark.CoreBridgeDirectory
diff --git a/datafu-spark/src/main/resources/pyspark_utils/__init__.py b/datafu-spark/src/main/resources/pyspark_utils/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/datafu-spark/src/main/resources/pyspark_utils/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/datafu-spark/src/main/resources/pyspark_utils/bridge_utils.py b/datafu-spark/src/main/resources/pyspark_utils/bridge_utils.py
new file mode 100644
index 0000000..40134d0
--- /dev/null
+++ b/datafu-spark/src/main/resources/pyspark_utils/bridge_utils.py
@@ -0,0 +1,72 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+
+from py4j.java_gateway import JavaGateway, GatewayClient
+from pyspark.conf import SparkConf
+from pyspark.context import SparkContext
+from pyspark.sql import SparkSession
+
+# use the JVM gateway to create a Java class instance from its fully-qualified class name
+def _getjvm_class(gateway, fullClassName):
+ return gateway.jvm.java.lang.Thread.currentThread().getContextClassLoader().loadClass(fullClassName).newInstance()
+
+
+class Context(object):
+
+ def __init__(self):
+ from py4j.java_gateway import java_import
+ """When running a Python script from Scala - this function is called
+ by the script to initialize the connection to the Java Gateway and get the spark context.
+ code is basically copied from:
+ https://github.com/apache/zeppelin/blob/master/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py#L30
+ """
+
+ if os.environ.get("SPARK_EXECUTOR_URI"):
+ SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"])
+
+ gateway = JavaGateway(GatewayClient(port=int(os.environ.get("PYSPARK_GATEWAY_PORT"))), auto_convert=True)
+ java_import(gateway.jvm, "org.apache.spark.SparkEnv")
+ java_import(gateway.jvm, "org.apache.spark.SparkConf")
+ java_import(gateway.jvm, "org.apache.spark.api.java.*")
+ java_import(gateway.jvm, "org.apache.spark.api.python.*")
+ java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
+ java_import(gateway.jvm, "org.apache.spark.sql.*")
+ java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
+
+ intp = gateway.entry_point
+
+ jSparkSession = intp.pyGetSparkSession()
+ jsc = intp.pyGetJSparkContext(jSparkSession)
+ jconf = intp.pyGetSparkConf(jsc)
+ conf = SparkConf(_jvm = gateway.jvm, _jconf = jconf)
+ self.sc = SparkContext(jsc=jsc, gateway=gateway, conf=conf)
+
+ # Spark 2
+ self.sparkSession = SparkSession(self.sc, jSparkSession)
+ self.sqlContext = self.sparkSession._wrapped
+
+ctx = None
+
+
+def get_contexts():
+ global ctx
+ if not ctx:
+ ctx = Context()
+
+ return ctx.sc, ctx.sqlContext, ctx.sparkSession
diff --git a/datafu-spark/src/main/resources/pyspark_utils/df_utils.py b/datafu-spark/src/main/resources/pyspark_utils/df_utils.py
new file mode 100644
index 0000000..16553ce
--- /dev/null
+++ b/datafu-spark/src/main/resources/pyspark_utils/df_utils.py
@@ -0,0 +1,171 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pyspark
+from pyspark.sql import DataFrame
+from pyspark_utils.bridge_utils import _getjvm_class
+
+
+def _get_utils(df):
+ _gateway = df._sc._gateway
+ return _getjvm_class(_gateway, "datafu.spark.SparkDFUtilsBridge")
+
+
+# public:
+
+
+def dedup_with_order(df, group_col, order_cols = []):
+ """
+    Used to get the 'latest' record (after ordering according to the provided order columns) in each group.
+ :param df: DataFrame to operate on
+ :param group_col: column to group by the records
+ :param order_cols: columns to order the records according to.
+ :return: DataFrame representing the data after the operation
+ """
+ java_cols = _cols_to_java_cols(order_cols)
+ jdf = _get_utils(df).dedupWithOrder(df._jdf, group_col._jc, java_cols)
+ return DataFrame(jdf, df.sql_ctx)
+
+
+def dedup_top_n(df, n, group_col, order_cols = []):
+ """
+    Used to get the top N records (after ordering according to the provided order columns) in each group.
+ :param df: DataFrame to operate on
+ :param n: number of records to return from each group
+ :param group_col: column to group by the records
+ :param order_cols: columns to order the records according to
+ :return: DataFrame representing the data after the operation
+ """
+ java_cols = _cols_to_java_cols(order_cols)
+ jdf = _get_utils(df).dedupTopN(df._jdf, n, group_col._jc, java_cols)
+ return DataFrame(jdf, df.sql_ctx)
+
+
+def dedup_with_combiner(df, group_col, order_by_col, desc = True, columns_filter = [], columns_filter_keep = True):
+ """
+    Used to get the 'latest' record (after ordering according to the provided order columns) in each group.
+ :param df: DataFrame to operate on
+ :param group_col: column to group by the records
+ :param order_by_col: column to order the records according to
+ :param desc: have the order as desc
+ :param columns_filter: columns to filter
+    :param columns_filter_keep: indicates whether we should filter the selected columns 'out' or alternatively keep
+        only those columns in the result
+ :return: DataFrame representing the data after the operation
+ """
+ jdf = _get_utils(df).dedupWithCombiner(df._jdf, group_col._jc, order_by_col._jc, desc, columns_filter, columns_filter_keep)
+ return DataFrame(jdf, df.sql_ctx)
+
+
+def change_schema(df, new_scheme = []):
+ """
+ Returns a DataFrame with the column names renamed to the column names in the new schema
+ :param df: DataFrame to operate on
+ :param new_scheme: new column names
+ :return: DataFrame representing the data after the operation
+ """
+ jdf = _get_utils(df).changeSchema(df._jdf, new_scheme)
+ return DataFrame(jdf, df.sql_ctx)
+
+
+def join_skewed(df_left, df_right, join_exprs, num_shards = 30, join_type="inner"):
+ """
+    Used to perform a join when the right df is relatively small but still too big to fit in memory for a broadcast join.
+ Use cases:
+ a. excluding keys that might be skew from a medium size list.
+ b. join a big skewed table with a table that has small number of very big rows.
+ :param df_left: left DataFrame
+ :param df_right: right DataFrame
+ :param join_exprs: join expression
+ :param num_shards: number of shards
+ :param join_type: join type
+ :return: DataFrame representing the data after the operation
+ """
+ jdf = _get_utils(df_left).joinSkewed(df_left._jdf, df_right._jdf, join_exprs._jc, num_shards, join_type)
+ return DataFrame(jdf, df_left.sql_ctx)
+
+
+def broadcast_join_skewed(not_skewed_df, skewed_df, join_col, number_of_custs_to_broadcast):
+ """
+ Suitable to perform a join in cases when one DF is skewed and the other is not skewed.
+ splits both of the DFs to two parts according to the skewed keys.
+ 1. Map-join: broadcasts the skewed-keys part of the not skewed DF to the skewed-keys part of the skewed DF
+ 2. Regular join: between the remaining two parts.
+ :param not_skewed_df: not skewed DataFrame
+ :param skewed_df: skewed DataFrame
+ :param join_col: join column
+ :param number_of_custs_to_broadcast: number of custs to broadcast
+ :return: DataFrame representing the data after the operation
+ """
+ jdf = _get_utils(skewed_df).broadcastJoinSkewed(not_skewed_df._jdf, skewed_df._jdf, join_col, number_of_custs_to_broadcast)
+ return DataFrame(jdf, not_skewed_df.sql_ctx)
+
+
+def join_with_range(df_single, col_single, df_range, col_range_start, col_range_end, decrease_factor):
+ """
+ Helper function to join a table with column to a table with range of the same column.
+ For example, ip table with whois data that has range of ips as lines.
+ The main problem which this handles is doing naive explode on the range can result in huge table.
+ requires:
+ 1. single table needs to be distinct on the join column, because there could be a few corresponding ranges so we dedup at the end - we choose the minimal range.
+ 2. the range and single columns to be numeric.
+ """
+ jdf = _get_utils(df_single).joinWithRange(df_single._jdf, col_single, df_range._jdf, col_range_start, col_range_end, decrease_factor)
+ return DataFrame(jdf, df_single.sql_ctx)
+
+
+def join_with_range_and_dedup(df_single, col_single, df_range, col_range_start, col_range_end, decrease_factor, dedup_small_range):
+ """
+ Helper function to join a table with column to a table with range of the same column.
+ For example, ip table with whois data that has range of ips as lines.
+ The main problem which this handles is doing naive explode on the range can result in huge table.
+ requires:
+ 1. single table needs to be distinct on the join column, because there could be a few corresponding ranges so we dedup at the end - we choose the minimal range.
+ 2. the range and single columns to be numeric.
+ """
+ jdf = _get_utils(df_single).joinWithRangeAndDedup(df_single._jdf, col_single, df_range._jdf, col_range_start, col_range_end, decrease_factor, dedup_small_range)
+ return DataFrame(jdf, df_single.sql_ctx)
+
+
+def _cols_to_java_cols(cols):
+ return _map_if_needed(lambda x: x._jc, cols)
+
+
+def _dfs_to_java_dfs(dfs):
+ return _map_if_needed(lambda x: x._jdf, dfs)
+
+
+def _map_if_needed(func, itr):
+ return map(func, itr) if itr is not None else itr
+
+
+def activate():
+ """Activate integration between datafu-spark and PySpark.
+ This function only needs to be called once.
+
+ This technique taken from pymongo_spark
+ https://github.com/mongodb/mongo-hadoop/blob/master/spark/src/main/python/pymongo_spark.py
+ """
+ pyspark.sql.DataFrame.dedup_with_order = dedup_with_order
+ pyspark.sql.DataFrame.dedup_top_n = dedup_top_n
+ pyspark.sql.DataFrame.dedup_with_combiner = dedup_with_combiner
+ pyspark.sql.DataFrame.change_schema = change_schema
+ pyspark.sql.DataFrame.join_skewed = join_skewed
+ pyspark.sql.DataFrame.broadcast_join_skewed = broadcast_join_skewed
+ pyspark.sql.DataFrame.join_with_range = join_with_range
+ pyspark.sql.DataFrame.join_with_range_and_dedup = join_with_range_and_dedup
+
diff --git a/datafu-spark/src/main/resources/pyspark_utils/init_spark_context.py b/datafu-spark/src/main/resources/pyspark_utils/init_spark_context.py
new file mode 100644
index 0000000..5d01169
--- /dev/null
+++ b/datafu-spark/src/main/resources/pyspark_utils/init_spark_context.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from pyspark_utils.bridge_utils import get_contexts
+sc, sqlContext, spark = get_contexts()
+
+print("initiated contexts")
diff --git a/datafu-spark/src/main/scala/datafu/spark/DataFrameOps.scala b/datafu-spark/src/main/scala/datafu/spark/DataFrameOps.scala
new file mode 100644
index 0000000..be4600d
--- /dev/null
+++ b/datafu-spark/src/main/scala/datafu/spark/DataFrameOps.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package datafu.spark
+
+import org.apache.spark.sql.{Column, DataFrame}
+
+/**
+ * implicit class to enable easier usage e.g:
+ *
+ * df.dedup(..)
+ *
+ * instead of:
+ *
+ * SparkDFUtils.dedup(...)
+ *
+ */
+object DataFrameOps {
+
+ implicit class someDataFrameUtils(df: DataFrame) {
+
+ def dedupWithOrder(groupCol: Column, orderCols: Column*): DataFrame =
+ SparkDFUtils.dedupWithOrder(df, groupCol, orderCols: _*)
+
+ def dedupTopN(n: Int, groupCol: Column, orderCols: Column*): DataFrame =
+ SparkDFUtils.dedupTopN(df, n, groupCol, orderCols: _*)
+
+ def dedupWithCombiner(groupCol: Column,
+ orderByCol: Column,
+ desc: Boolean = true,
+ moreAggFunctions: Seq[Column] = Nil,
+ columnsFilter: Seq[String] = Nil,
+ columnsFilterKeep: Boolean = true): DataFrame =
+ SparkDFUtils.dedupWithCombiner(df,
+ groupCol,
+ orderByCol,
+ desc,
+ moreAggFunctions,
+ columnsFilter,
+ columnsFilterKeep)
+
+ def flatten(colName: String): DataFrame = SparkDFUtils.flatten(df, colName)
+
+ def changeSchema(newScheme: String*): DataFrame =
+ SparkDFUtils.changeSchema(df, newScheme: _*)
+
+ def joinWithRange(colSingle: String,
+ dfRange: DataFrame,
+ colRangeStart: String,
+ colRangeEnd: String,
+ DECREASE_FACTOR: Long = 2 ^ 8): DataFrame =
+ SparkDFUtils.joinWithRange(df,
+ colSingle,
+ dfRange,
+ colRangeStart,
+ colRangeEnd,
+ DECREASE_FACTOR)
+
+ def joinWithRangeAndDedup(colSingle: String,
+ dfRange: DataFrame,
+ colRangeStart: String,
+ colRangeEnd: String,
+ DECREASE_FACTOR: Long = 2 ^ 8,
+ dedupSmallRange: Boolean = true): DataFrame =
+ SparkDFUtils.joinWithRangeAndDedup(df,
+ colSingle,
+ dfRange,
+ colRangeStart,
+ colRangeEnd,
+ DECREASE_FACTOR,
+ dedupSmallRange)
+
+ def broadcastJoinSkewed(skewed: DataFrame,
+ joinCol: String,
+ numberCustsToBroadcast: Int): DataFrame =
+ SparkDFUtils.broadcastJoinSkewed(df,
+ skewed,
+ joinCol,
+ numberCustsToBroadcast)
+
+ def joinSkewed(notSkewed: DataFrame,
+ joinExprs: Column,
+ numShards: Int = 1000,
+ joinType: String = "inner"): DataFrame =
+ SparkDFUtils.joinSkewed(df, notSkewed, joinExprs, numShards, joinType)
+ }
+}
diff --git a/datafu-spark/src/main/scala/datafu/spark/PythonPathsManager.scala b/datafu-spark/src/main/scala/datafu/spark/PythonPathsManager.scala
new file mode 100644
index 0000000..26377b8
--- /dev/null
+++ b/datafu-spark/src/main/scala/datafu/spark/PythonPathsManager.scala
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package datafu.spark
+
+import java.io.{File, IOException}
+import java.net.JarURLConnection
+import java.nio.file.Paths
+import java.util
+import java.util.{MissingResourceException, ServiceLoader}
+
+import scala.collection.JavaConverters._
+
+import org.apache.log4j.Logger
+
+
+/**
+ * Represents a resource that needs to be added to PYTHONPATH used by ScalaPythonBridge.
+ *
+ * To ensure your python resources (modules, files, etc.) are properly added to the bridge,
+ * do the following:
+ * 1) Put all the resource under some root directory with a unique name x, and make sure path/to/x
+ * is visible to the class loader (usually just use src/main/resources/x).
+ * 2) Extend this class like this:
+ * class MyResource extends PythonResource("x")
+ * This assumes x is under src/main/resources/x
+ * 3) (since we use ServiceLoader) Add a file to your jar/project:
+ * META-INF/services/spark.utils.PythonResource
+ * with a single line containing the full name (including package) of MyResource.
+ *
+ * This process involves scanning the entire jar and copying files from the jar to some temporary
+ * location, so if your jar is really big consider putting the resources in a smaller jar.
+ *
+ * @param resourcePath Path to the resource, will be loaded via
+ * getClass.getClassLoader.getResource()
+ * @param isAbsolutePath Set to true if the resource is in some absolute path rather than in jar
+ * (try to avoid that).
+ */
+abstract class PythonResource(val resourcePath: String,
+ val isAbsolutePath: Boolean = false)
+
+/**
+ * There are two phases of resolving python files path:
+ *
+ * 1) When launching spark:
+ * the files need to be added to spark.executorEnv.PYTHONPATH
+ *
+ * 2) When executing python file via bridge:
+ * the files need to be added to the process PYTHONPATH.
+ * This is different than the previous phase because
+ * this python process is spawned by datafu-spark, not by spark, and always on the driver.
+ */
+object PythonPathsManager {
+
+ case class ResolvedResource(resource: PythonResource,
+ resolvedLocation: String)
+
+ private val logger: Logger = Logger.getLogger(getClass)
+
+ val resources: Seq[ResolvedResource] =
+ ServiceLoader
+ .load(classOf[PythonResource])
+ .asScala
+ .map(p => ResolvedResource(p, resolveDependencyLocation(p)))
+ .toSeq
+
+ logResolved
+
+ def getAbsolutePaths(): Seq[String] = resources.map(_.resolvedLocation).distinct
+ def getAbsolutePathsForJava(): util.List[String] =
+ resources.map(_.resolvedLocation).distinct.asJava
+
+ def getPYTHONPATH(): String =
+ resources
+ .map(_.resolvedLocation)
+ .map(p => new File(p))
+ .map(_.getName) // get just the name of the file
+ .mkString(":")
+
+ private def resolveDependencyLocation(resource: PythonResource): String =
+ if (resource.isAbsolutePath) {
+ if (!new File(resource.resourcePath).exists()) {
+ throw new IOException(
+ "Could not find resource in absolute path: " + resource.resourcePath)
+ } else {
+ logger.info("Using file absolute path: " + resource.resourcePath)
+ resource.resourcePath
+ }
+ } else {
+ Option(getClass.getClassLoader.getResource(resource.resourcePath)) match {
+ case None =>
+ logger.error(
+ "Didn't find resource in classpath! resource path: " + resource.resourcePath)
+ throw new MissingResourceException(
+ "Didn't find resource in classpath!",
+ resource.getClass.getName,
+ resource.resourcePath)
+ case Some(p) =>
+ p.toURI.getScheme match {
+ case "jar" =>
+ // if dependency is inside jar file, use jar file path:
+ val jarPath = new File(
+ p.openConnection()
+ .asInstanceOf[JarURLConnection]
+ .getJarFileURL
+ .toURI).getPath
+ logger.info(
+ s"Dependency ${resource.resourcePath} found inside jar: " + jarPath)
+ jarPath
+ case "file" =>
+ val file = new File(p.getFile)
+ if (!file.exists()) {
+ logger.warn("Dependency not found, skipping: " + file.getPath)
+ null
+ } else {
+ if (file.isDirectory) {
+ val t_path =
+ if (System
+ .getProperty("os.name")
+ .toLowerCase()
+ .contains("win") && p.getPath().startsWith("/")) {
+ val path = p.getPath.substring(1)
+ logger.warn(
+ s"Fixing path for windows operating system! " +
+ s"converted ${p.getPath} to $path")
+ path
+ } else {
+ p.getPath
+ }
+ val path = Paths.get(t_path)
+ logger.info(
+ s"Dependency found as directory: ${t_path}\n\tusing " +
+ s"parent path: ${path.getParent}")
+ path.getParent.toString
+ } else {
+ logger.info("Dependency found as a file: " + p.getPath)
+ p.getPath
+ }
+ }
+ }
+ }
+ }
+
+ private def logResolved = {
+ logger.info(s"Discovered ${resources.size} python paths:\n" +
+ resources
+ .map(p =>
+ s"className: ${p.resource.getClass.getName}\n\tresource: " +
+ s"${p.resource.resourcePath}\n\tlocation: ${p.resolvedLocation}")
+ .mkString("\n")) + "\n\n"
+ }
+}
+
+/**
+ * Contains all python files needed by the bridge itself
+ */
+class CoreBridgeDirectory extends PythonResource("pyspark_utils")
diff --git a/datafu-spark/src/main/scala/datafu/spark/ScalaPythonBridge.scala b/datafu-spark/src/main/scala/datafu/spark/ScalaPythonBridge.scala
new file mode 100644
index 0000000..1726916
--- /dev/null
+++ b/datafu-spark/src/main/scala/datafu/spark/ScalaPythonBridge.scala
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package datafu.spark
+
+import java.io._
+import java.net.URL
+import java.nio.file.Files
+import java.util.UUID
+
+import org.slf4j.LoggerFactory
+
+import org.apache.spark.SparkConf
+import org.apache.spark.api.java.JavaSparkContext
+import org.apache.spark.deploy.SparkPythonRunner
+import org.apache.spark.sql.SparkSession
+
+
+/**
+ * This class lets the user invoke PySpark code from Scala.
+ * example usage:
+ *
+ * val runner = ScalaPythonBridgeRunner()
+ * runner.runPythonFile("my_package/my_pyspark_logic.py")
+ *
+ */
+case class ScalaPythonBridgeRunner(extraPath: String = "") {
+
+ val logger = LoggerFactory.getLogger(this.getClass)
+ // for the bridge we take the full resolved location,
+ // since this runs on the driver where the files are local:
+ logger.info("constructing PYTHONPATH")
+
+ // we include multiple options for py4j because on any given cluster only one should be found
+ val pythonPath = (PythonPathsManager.getAbsolutePaths() ++
+ Array("pyspark.zip",
+ "py4j-0.10.4-src.zip",
+ "py4j-0.10.6-src.zip",
+ "py4j-0.10.7-src.zip",
+ "py4j-0.10.8.1-src.zip") ++
+ Option(extraPath).getOrElse("").split(",")).distinct
+
+ logger.info("Bridge PYTHONPATH: " + pythonPath.mkString(":"))
+
+ val runner = SparkPythonRunner(pythonPath.mkString(","))
+
+ def runPythonFile(filename: String): String = {
+ val pyScript = resolveRunnableScript(filename)
+ logger.info(s"Running python file $pyScript")
+ runner.runPyFile(pyScript)
+ }
+
+ def runPythonString(str: String): String = {
+ val tmpFile = writeToTempFile(str, "pyspark-tmp-file-", ".py")
+ logger.info(
+ "Running tmp PySpark file: " + tmpFile.getAbsolutePath + " with content:\n" + str)
+ runner.runPyFile(tmpFile.getAbsolutePath)
+ }
+
+ private def resolveRunnableScript(path: String): String = {
+ logger.info("Resolving python script location for: " + path)
+
+ val res
+ : String = Option(this.getClass.getClassLoader.getResource(path)) match {
+ case None =>
+ logger.info("Didn't find script via classLoader, using as is: " + path)
+ path
+ case Some(resource) =>
+ resource.toURI.getScheme match {
+ case "jar" =>
+ // if inside jar, extract it and return cloned file:
+ logger.info("Script found inside jar, extracting...")
+ val outputFile = ResourceCloning.cloneResource(resource, path)
+ logger.info("Extracted file path: " + outputFile.getPath)
+ outputFile.getPath
+ case _ =>
+ logger.info("Using script original path: " + resource.getPath)
+ resource.getPath
+ }
+ }
+ res
+ }
+
+ private def writeToTempFile(contents: String,
+ prefix: String,
+ suffix: String): File = {
+ val tempFi = File.createTempFile(prefix, suffix)
+ tempFi.deleteOnExit()
+ val bw = new BufferedWriter(new FileWriter(tempFi))
+ bw.write(contents)
+ bw.close()
+ tempFi
+ }
+
+}
+
+/**
+ * Do not instantiate this object from Scala code.
+ * It should only be used by the Python side of the bridge.
+ */
+object ScalaPythonBridge { // need empty ctor for py4j gateway
+
+ /**
+ * members used to allow python script share context with main Scala program calling it.
+ * Python script calls :
+ * sc, sqlContext, spark = utils.get_contexts()
+ * our Python util function get_contexts
+ * uses the following to create Python wrappers around Java SparkContext and SQLContext.
+ */
+ // Called by python util get_contexts()
+ def pyGetSparkSession(): SparkSession = SparkSession.builder().getOrCreate()
+ def pyGetJSparkContext(sparkSession: SparkSession): JavaSparkContext =
+ new JavaSparkContext(sparkSession.sparkContext)
+ def pyGetSparkConf(jsc: JavaSparkContext): SparkConf = jsc.getConf
+
+}
+
+/**
+ * Utility for extracting resource from a jar and copy it to a temporary location
+ */
+object ResourceCloning {
+
+ private val logger = LoggerFactory.getLogger(this.getClass)
+
+ val uuid = UUID.randomUUID().toString.substring(6)
+ val outputTempDir = new File(System.getProperty("java.io.tmpdir"),
+ s"risk_tmp/$uuid/cloned_resources/")
+ forceMkdirs(outputTempDir)
+
+ def cloneResource(resource: URL, outputFileName: String): File = {
+ val outputTmpFile = new File(outputTempDir, outputFileName)
+ if (outputTmpFile.exists()) {
+ logger.info(s"resource $outputFileName already exists, skipping..")
+ outputTmpFile
+ } else {
+ logger.info("cloning resource: " + resource)
+ if (!outputTmpFile.exists()) {
+ // it is possible that the file was already extracted in the session
+ forceMkdirs(outputTmpFile.getParentFile)
+ val inputStream = resource.openStream()
+ streamToFile(outputTmpFile, inputStream)
+ }
+ outputTmpFile
+ }
+ }
+
+ private def forceMkdirs(dir: File) =
+ if (!dir.exists() && !dir.mkdirs()) {
+ throw new IOException("Failed to create " + dir.getPath)
+ }
+
+ private def streamToFile(outputFile: File, inputStream: InputStream) = {
+ try {
+ Files.copy(inputStream, outputFile.toPath)
+ } finally {
+ inputStream.close()
+ assert(outputFile.exists())
+ }
+ }
+}
diff --git a/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala b/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala
new file mode 100644
index 0000000..0ee1520
--- /dev/null
+++ b/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala
@@ -0,0 +1,510 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package datafu.spark
+
+import java.util.{List => JavaList}
+
+import org.apache.spark.sql.{Column, DataFrame}
+import org.apache.spark.sql.expressions.Window
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{LongType, SparkOverwriteUDAFs, StructType}
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * class definition so we could expose this functionality in PySpark
+ */
/**
 * Bridge class exposing the [[SparkDFUtils]] API to PySpark: every method
 * mirrors a SparkDFUtils method but accepts java.util.List arguments in
 * place of Scala varargs / Seq parameters so Py4J can call it directly.
 */
class SparkDFUtilsBridge {

  /** Python-facing wrapper for SparkDFUtils.dedupWithOrder. */
  def dedupWithOrder(df: DataFrame,
                     groupCol: Column,
                     orderCols: JavaList[Column]): DataFrame =
    SparkDFUtils.dedupWithOrder(df, groupCol, asScalaSeq(orderCols): _*)

  /** Python-facing wrapper for SparkDFUtils.dedupTopN. */
  def dedupTopN(df: DataFrame,
                n: Int,
                groupCol: Column,
                orderCols: JavaList[Column]): DataFrame =
    SparkDFUtils.dedupTopN(df, n, groupCol, asScalaSeq(orderCols): _*)

  /** Python-facing wrapper for SparkDFUtils.dedupWithCombiner
    * (moreAggFunctions is fixed to Nil for the python API). */
  def dedupWithCombiner(df: DataFrame,
                        groupCol: Column,
                        orderByCol: Column,
                        desc: Boolean,
                        columnsFilter: JavaList[String],
                        columnsFilterKeep: Boolean): DataFrame =
    SparkDFUtils.dedupWithCombiner(df,
                                   groupCol,
                                   orderByCol,
                                   desc,
                                   Nil,
                                   asScalaSeq(columnsFilter),
                                   columnsFilterKeep)

  /** Python-facing wrapper for SparkDFUtils.changeSchema. */
  def changeSchema(df: DataFrame, newScheme: JavaList[String]): DataFrame =
    SparkDFUtils.changeSchema(df, asScalaSeq(newScheme): _*)

  /** Python-facing wrapper for SparkDFUtils.joinSkewed. */
  def joinSkewed(dfLeft: DataFrame,
                 dfRight: DataFrame,
                 joinExprs: Column,
                 numShards: Int,
                 joinType: String): DataFrame =
    SparkDFUtils.joinSkewed(dfLeft, dfRight, joinExprs, numShards, joinType)

  /** Python-facing wrapper for SparkDFUtils.broadcastJoinSkewed. */
  def broadcastJoinSkewed(notSkewed: DataFrame,
                          skewed: DataFrame,
                          joinCol: String,
                          numRowsToBroadcast: Int): DataFrame =
    SparkDFUtils.broadcastJoinSkewed(notSkewed,
                                     skewed,
                                     joinCol,
                                     numRowsToBroadcast)

  /** Python-facing wrapper for SparkDFUtils.joinWithRange. */
  def joinWithRange(dfSingle: DataFrame,
                    colSingle: String,
                    dfRange: DataFrame,
                    colRangeStart: String,
                    colRangeEnd: String,
                    DECREASE_FACTOR: Long): DataFrame =
    SparkDFUtils.joinWithRange(dfSingle,
                               colSingle,
                               dfRange,
                               colRangeStart,
                               colRangeEnd,
                               DECREASE_FACTOR)

  /** Python-facing wrapper for SparkDFUtils.joinWithRangeAndDedup. */
  def joinWithRangeAndDedup(dfSingle: DataFrame,
                            colSingle: String,
                            dfRange: DataFrame,
                            colRangeStart: String,
                            colRangeEnd: String,
                            DECREASE_FACTOR: Long,
                            dedupSmallRange: Boolean): DataFrame =
    SparkDFUtils.joinWithRangeAndDedup(dfSingle,
                                       colSingle,
                                       dfRange,
                                       colRangeStart,
                                       colRangeEnd,
                                       DECREASE_FACTOR,
                                       dedupSmallRange)

  // Converts a java.util.List coming from Py4J into an immutable Scala List.
  private def asScalaSeq[T](javaList: JavaList[T]): Seq[T] =
    scala.collection.JavaConverters
      .asScalaBufferConverter(javaList)
      .asScala
      .toList
}
+
/**
 * Core DataFrame utilities: dedup variants, joins for skewed data and
 * point-to-range joins.
 */
object SparkDFUtils {

  /**
   * Used to get the 'latest' record (after ordering according to the provided order columns)
   * in each group.
   * Different from {@link org.apache.spark.sql.Dataset#dropDuplicates} because order matters.
   *
   * @param df DataFrame to operate on
   * @param groupCol column to group by the records
   * @param orderCols columns to order the records according to
   * @return DataFrame representing the data after the operation
   */
  def dedupWithOrder(df: DataFrame, groupCol: Column, orderCols: Column*): DataFrame = {
    // "latest" is simply the top-1 case of dedupTopN
    dedupTopN(df, 1, groupCol, orderCols: _*)
  }

  /**
   * Used get the top N records (after ordering according to the provided order columns)
   * in each group.
   *
   * @param df DataFrame to operate on
   * @param n number of records to return from each group
   * @param groupCol column to group by the records
   * @param orderCols columns to order the records according to
   * @return DataFrame representing the data after the operation
   */
  def dedupTopN(df: DataFrame,
                n: Int,
                groupCol: Column,
                orderCols: Column*): DataFrame = {
    // assign a row number per group in the requested order, keep the first n
    val w = Window.partitionBy(groupCol).orderBy(orderCols: _*)
    df.withColumn("rn", row_number.over(w)).where(col("rn") <= n).drop("rn")
  }

  /**
   * Used to get the 'latest' record (after ordering according to the provided order columns)
   * in each group.
   * the same functionality as {@link #dedupWithOrder} but implemented using UDAF to utilize
   * map side aggregation.
   * this function should be used in cases when you expect a large number of rows to get combined,
   * as they share the same group column.
   *
   * @param df DataFrame to operate on
   * @param groupCol column to group by the records
   * @param orderByCol column to order the records according to
   * @param desc have the order as desc
   * @param moreAggFunctions more aggregate functions
   * @param columnsFilter columns to filter
   * @param columnsFilterKeep indicates whether we should filter the selected columns 'out'
   *                          or alternatively have only those columns in the result
   * @return DataFrame representing the data after the operation
   */
  def dedupWithCombiner(df: DataFrame,
                        groupCol: Column,
                        orderByCol: Column,
                        desc: Boolean = true,
                        moreAggFunctions: Seq[Column] = Nil,
                        columnsFilter: Seq[String] = Nil,
                        columnsFilterKeep: Boolean = true): DataFrame = {
    // materialize the ordering expression as a temp column ("sort_by_column"),
    // optionally restricting the projected columns per columnsFilter/Keep
    val newDF =
      if (columnsFilter == Nil) {
        df.withColumn("sort_by_column", orderByCol)
      } else {
        if (columnsFilterKeep) {
          df.withColumn("sort_by_column", orderByCol)
            .select("sort_by_column", columnsFilter: _*)
        } else {
          df.select(
              df.columns
                .filter(colName => !columnsFilter.contains(colName))
                .map(colName => new Column(colName)): _*)
            .withColumn("sort_by_column", orderByCol)
        }
      }

    // max/minValueByKey are map-side-combinable aggregates that keep the whole
    // row (as a struct) belonging to the extreme sort key
    val aggFunc =
      if (desc) SparkOverwriteUDAFs.maxValueByKey(_: Column, _: Column)
      else SparkOverwriteUDAFs.minValueByKey(_: Column, _: Column)

    // h1 = winning row struct; h2 = extra aggregations (lit placeholder keeps
    // the struct non-empty when moreAggFunctions is Nil)
    val df2 = newDF
      .groupBy(groupCol.as("group_by_col"))
      .agg(aggFunc(expr("sort_by_column"), expr("struct(sort_by_column, *)"))
             .as("h1"),
           struct(lit(1).as("lit_placeholder_col") +: moreAggFunctions: _*)
             .as("h2"))
      .selectExpr("h2.*", "h1.*")
      .drop("lit_placeholder_col")
      .drop("sort_by_column")
    df2
  }

  /**
   * Returns a DataFrame with the given column (should be a StructType)
   * replaced by its inner fields.
   * This method only flattens a single level of nesting.
   *
   * +-------+----------+----------+----------+
   * |id     |s.sub_col1|s.sub_col2|s.sub_col3|
   * +-------+----------+----------+----------+
   * |123    |1         |2         |3         |
   * +-------+----------+----------+----------+
   *
   * +-------+----------+----------+----------+
   * |id     |sub_col1  |sub_col2  |sub_col3  |
   * +-------+----------+----------+----------+
   * |123    |1         |2         |3         |
   * +-------+----------+----------+----------+
   *
   * @param df DataFrame to operate on
   * @param colName column name for a column of type StructType
   * @return DataFrame representing the data after the operation
   */
  def flatten(df: DataFrame, colName: String): DataFrame = {
    assert(df.schema(colName).dataType.isInstanceOf[StructType],
           s"Column $colName must be of type Struct")
    // skip inner fields whose name would collide with an existing outer column
    val outerFields = df.schema.fields.map(_.name).toSet
    val flattenFields = df
      .schema(colName)
      .dataType
      .asInstanceOf[StructType]
      .fields
      .filter(f => !outerFields.contains(f.name))
      .map("`" + colName + "`.`" + _.name + "`")
    df.selectExpr("*" +: flattenFields: _*).drop(colName)
  }

  /**
   * Returns a DataFrame with the column names renamed to the column names in the new schema
   *
   * @param df DataFrame to operate on
   * @param newScheme new column names (positional; zipped against df.columns)
   * @return DataFrame representing the data after the operation
   */
  def changeSchema(df: DataFrame, newScheme: String*): DataFrame =
    df.select(df.columns.zip(newScheme).map {
      case (oldCol: String, newCol: String) => col(oldCol).as(newCol)
    }: _*)

  /**
   * Used to perform a join when the right df is relatively small
   * but still too big to fit in memory to perform map side broadcast join.
   * Use cases:
   * a. excluding keys that might be skewed from a medium size list.
   * b. join a big skewed table with a table that has small number of very large rows.
   *
   * @param dfLeft left DataFrame
   * @param dfRight right DataFrame
   * @param joinExprs join expression
   * @param numShards number of shards - number of times to duplicate the right DataFrame
   * @param joinType join type
   * @return joined DataFrame
   */
  def joinSkewed(dfLeft: DataFrame,
                 dfRight: DataFrame,
                 joinExprs: Column,
                 numShards: Int = 10,
                 joinType: String = "inner"): DataFrame = {
    // skew join based on salting:
    // salts the left DF by adding a random shard column, and joins against the
    // right DF cross-joined with all shard ids, so each skewed key is spread
    // over numShards partitions
    val ss = dfLeft.sparkSession
    import ss.implicits._
    val shards = 1.to(numShards).toDF("shard")
    dfLeft
      .withColumn("randLeft", ceil(rand() * numShards))
      .join(dfRight.crossJoin(shards),
            joinExprs and $"randLeft" === $"shard",
            joinType)
      .drop($"randLeft")
      .drop($"shard")
  }

  /**
   * Suitable to perform a join in cases when one DF is skewed and the other is not skewed.
   * splits both of the DFs to two parts according to the skewed keys.
   * 1. Map-join: broadcasts the skewed-keys part of the not skewed DF to the skewed-keys
   *    part of the skewed DF
   * 2. Regular join: between the remaining two parts.
   *
   * @param notSkewed not skewed DataFrame
   * @param skewed skewed DataFrame
   * @param joinCol join column
   * @param numRowsToBroadcast num of rows to broadcast
   * @return DataFrame representing the data after the operation
   */
  def broadcastJoinSkewed(notSkewed: DataFrame,
                          skewed: DataFrame,
                          joinCol: String,
                          numRowsToBroadcast: Int): DataFrame = {
    val ss = notSkewed.sparkSession
    import ss.implicits._
    // the numRowsToBroadcast most frequent join keys in the skewed DF
    val skewedKeys = skewed
      .groupBy(joinCol)
      .count()
      .sort($"count".desc)
      .limit(numRowsToBroadcast)
      .drop("count")
      .withColumnRenamed(joinCol, "skew_join_key")
      .cache()

    // tag each notSkewed row with whether its key is one of the skewed keys;
    // persisted because it is consumed by both branches below
    val notSkewedWithSkewIndicator = notSkewed
      .join(broadcast(skewedKeys), $"skew_join_key" === col(joinCol), "left")
      .withColumn("is_skewed_record", col("skew_join_key").isNotNull)
      .drop("skew_join_key")
      .persist(StorageLevel.DISK_ONLY)

    // broadcast map-join, sending the notSkewed data
    val bigRecordsJnd =
      broadcast(notSkewedWithSkewIndicator.filter("is_skewed_record"))
        .join(skewed, joinCol)

    // regular join for the rest
    val skewedWithoutSkewedKeys = skewed
      .join(broadcast(skewedKeys), $"skew_join_key" === col(joinCol), "left")
      .where("skew_join_key is null")
      .drop("skew_join_key")
    val smallRecordsJnd = notSkewedWithSkewIndicator
      .filter("not is_skewed_record")
      .join(skewedWithoutSkewedKeys, joinCol)

    smallRecordsJnd
      .union(bigRecordsJnd)
      .drop("is_skewed_record", "skew_join_key")
  }

  /**
   * Helper function to join a table with point column to a table with range column.
   * For example, join a table that contains specific time in minutes with a table that
   * contains time ranges.
   * The main problem this function addresses is that doing naive explode on the ranges can result
   * in a huge table.
   * requires:
   * 1. point table needs to be distinct on the point column. there could be a few corresponding
   *    ranges to each point, so we choose the minimal range.
   * 2. the range and point columns need to be numeric.
   *
   * TIMES:
   * +-------+
   * |time   |
   * +-------+
   * |11:55  |
   * +-------+
   *
   * TIME RANGES:
   * +----------+---------+----------+
   * |start_time|end_time |desc      |
   * +----------+---------+----------+
   * |10:00     |12:00    | meeting  |
   * +----------+---------+----------+
   * |11:50     |12:15    | lunch    |
   * +----------+---------+----------+
   *
   * OUTPUT:
   * +-------+----------+---------+---------+
   * |time   |start_time|end_time |desc     |
   * +-------+----------+---------+---------+
   * |11:55  |10:00     |12:00    | meeting |
   * +-------+----------+---------+---------+
   * |11:55  |11:50     |12:15    | lunch   |
   * +-------+----------+---------+---------+
   *
   * @param dfSingle DataFrame that contains the point column
   * @param colSingle the point column's name
   * @param dfRange DataFrame that contains the range column
   * @param colRangeStart the start range column's name
   * @param colRangeEnd the end range column's name
   * @param DECREASE_FACTOR resolution factor. instead of exploding the range column directly,
   *                        we first decrease its resolution by this factor
   * @return
   */
  def joinWithRange(dfSingle: DataFrame,
                    colSingle: String,
                    dfRange: DataFrame,
                    colRangeStart: String,
                    colRangeEnd: String,
                    DECREASE_FACTOR: Long): DataFrame = {
    val dfJoined = joinWithRangeInternal(dfSingle,
                                         colSingle,
                                         dfRange,
                                         colRangeStart,
                                         colRangeEnd,
                                         DECREASE_FACTOR)
    // drop all the helper columns introduced by joinWithRangeInternal
    dfJoined.drop("range_start",
                  "range_end",
                  "decreased_range_single",
                  "single",
                  "decreased_single",
                  "range_size")
  }

  // Explodes each range at reduced resolution (divided by DECREASE_FACTOR),
  // joins points to ranges on the reduced value, then filters to exact
  // containment. Leaves helper columns in place for the dedup variant.
  private def joinWithRangeInternal(dfSingle: DataFrame,
                                    colSingle: String,
                                    dfRange: DataFrame,
                                    colRangeStart: String,
                                    colRangeEnd: String,
                                    DECREASE_FACTOR: Long): DataFrame = {

    import org.apache.spark.sql.functions.udf
    val rangeUDF = udf((start: Long, end: Long) => (start to end).toArray)
    // NOTE(review): col / lit division produces a fractional value that is
    // implicitly cast to the UDF's Long parameters — confirm truncation
    // semantics match the intended bucketing.
    val dfRange_exploded = dfRange
      .withColumn("range_start", col(colRangeStart).cast(LongType))
      .withColumn("range_end", col(colRangeEnd).cast(LongType))
      .withColumn("decreased_range_single",
                  explode(
                    rangeUDF(col("range_start") / lit(DECREASE_FACTOR),
                             col("range_end") / lit(DECREASE_FACTOR))))

    dfSingle
      .withColumn("single", floor(col(colSingle).cast(LongType)))
      .withColumn("decreased_single",
                  floor(col(colSingle).cast(LongType) / lit(DECREASE_FACTOR)))
      .join(dfRange_exploded,
            col("decreased_single") === col("decreased_range_single"),
            "left_outer")
      .withColumn("range_size", expr("(range_end - range_start + 1)"))
      // exact containment check, since the reduced-resolution join is coarse
      .filter("single>=range_start and single<=range_end")
  }

  /**
   * Run joinWithRange and afterwards run dedup
   *
   * @param dedupSmallRange - by small/large range
   *
   * OUTPUT for dedupSmallRange = "true":
   * +-------+----------+---------+---------+
   * |time   |start_time|end_time |desc     |
   * +-------+----------+---------+---------+
   * |11:55  |11:50     |12:15    | lunch   |
   * +-------+----------+---------+---------+
   *
   * OUTPUT for dedupSmallRange = "false":
   * +-------+----------+---------+---------+
   * |time   |start_time|end_time |desc     |
   * +-------+----------+---------+---------+
   * |11:55  |10:00     |12:00    | meeting |
   * +-------+----------+---------+---------+
   *
   */
  def joinWithRangeAndDedup(dfSingle: DataFrame,
                            colSingle: String,
                            dfRange: DataFrame,
                            colRangeStart: String,
                            colRangeEnd: String,
                            DECREASE_FACTOR: Long,
                            dedupSmallRange: Boolean): DataFrame = {

    val dfJoined = joinWithRangeInternal(dfSingle,
                                         colSingle,
                                         dfRange,
                                         colRangeStart,
                                         colRangeEnd,
                                         DECREASE_FACTOR)

    // keep one range per point: smallest range wins when dedupSmallRange,
    // largest otherwise. "range_start" is here for consistency
    val dfDeduped = if (dedupSmallRange) {
      dedupWithCombiner(dfJoined,
                        col(colSingle),
                        struct("range_size", "range_start"),
                        desc = false)
    } else {
      dedupWithCombiner(dfJoined,
                        col(colSingle),
                        struct(expr("-range_size"), col("range_start")),
                        desc = true)
    }

    // drop all the helper columns introduced by joinWithRangeInternal
    dfDeduped.drop("range_start",
                  "range_end",
                  "decreased_range_single",
                  "single",
                  "decreased_single",
                  "range_size")
  }
}
diff --git a/datafu-spark/src/main/scala/datafu/spark/SparkUDAFs.scala b/datafu-spark/src/main/scala/datafu/spark/SparkUDAFs.scala
new file mode 100644
index 0000000..3e0810b
--- /dev/null
+++ b/datafu-spark/src/main/scala/datafu/spark/SparkUDAFs.scala
@@ -0,0 +1,278 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package datafu.spark
+
+import scala.collection.{mutable, Map}
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
+import org.apache.spark.sql.types.{ArrayType, _}
+
+
+// @TODO: add documentation and tests to all the functions. maybe also expose in python.
+
object SparkUDAFs {

  /**
   * Like Google's MultiSets.
   * Aggregate function that creates a map of key to its count.
   */
  class MultiSet() extends UserDefinedAggregateFunction {

    // Input: one (nullable) string key per row.
    def inputSchema: StructType = new StructType().add("key", StringType)

    // Buffer: running map of key -> count.
    def bufferSchema: StructType =
      new StructType().add("mp", MapType(StringType, IntegerType))

    def dataType: DataType = MapType(StringType, IntegerType, false)

    def deterministic: Boolean = true

    // This function is called whenever key changes
    def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = mutable.Map()
    }

    // Iterate over each entry of a group.
    // Null keys are ignored; the buffer map is replaced wholesale because
    // MutableAggregationBuffer only supports slot assignment.
    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      val key = input.getString(0)
      if (key != null) {
        buffer(0) = buffer.getMap(0) + (key -> (buffer
          .getMap(0)
          .getOrElse(key, 0) + 1))
      }
    }

    // Merge two partial aggregates by summing per-key counts.
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      val mp = mutable.Map[String, Int]() ++= buffer1.getMap(0)
      buffer2
        .getMap(0)
        .keys
        .foreach((key: String) =>
          if (key != null) {
            mp.put(key,
                   mp.getOrElse(key, 0) + buffer2.getMap(0).getOrElse(key, 0))
          })
      buffer1(0) = mp
    }

    // Called after all the entries are exhausted.
    def evaluate(buffer: Row): Any = {
      buffer(0)
    }

  }

  /**
   * Essentially the same as MultiSet, but gets an Array for input.
   * There is an extra option to limit the number of keys (like CountDistinctUpTo)
   */
  class MultiArraySet[T: Ordering](dt: DataType = StringType, maxKeys: Int = -1)
      extends UserDefinedAggregateFunction {

    // Input: an array of keys of element type `dt` per row.
    def inputSchema: StructType = new StructType().add("key", ArrayType(dt))

    def bufferSchema: StructType = new StructType().add("mp", dataType)

    def dataType: DataType = MapType(dt, IntegerType, false)

    def deterministic: Boolean = true

    // This function is called whenever key changes
    def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = mutable.Map()
    }

    // Iterate over each entry of a group.
    // Intermediate buffers are allowed to grow up to 3x maxKeys before
    // pruning (factor = 3); the final prune happens in evaluate().
    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      val mp = mutable.Map[T, Int]() ++= buffer.getMap(0)
      val keyArr: Seq[T] = Option(input.getAs[Seq[T]](0)).getOrElse(Nil)
      for (key <- keyArr; if key != null)
        mp.put(key, mp.getOrElse(key, 0) + 1)

      buffer(0) = limitKeys(mp, 3)
    }

    // Merge two partial aggregates by summing per-key counts, with the same
    // relaxed (3x) size limit as update().
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      val mp = mutable.Map[T, Int]() ++= buffer1.getMap(0)
      buffer2
        .getMap(0)
        .keys
        .foreach((key: T) =>
          if (key != null) {
            mp.put(key,
                   mp.getOrElse(key, 0) + buffer2.getMap(0).getOrElse(key, 0))
          })

      buffer1(0) = limitKeys(mp, 3)
    }

    // Prunes the map to (at most) maxKeys entries, keeping the most frequent
    // keys, but only when the map exceeds maxKeys * factor entries.
    // k is the count of the maxKeys-th most frequent key; everything with a
    // count >= k survives the first pass, then ties at exactly k are dropped
    // until only maxKeys entries remain.
    private def limitKeys(mp: Map[T, Int], factor: Int = 1): Map[T, Int] = {
      if (maxKeys > 0 && maxKeys * factor < mp.size) {
        val k = mp.toList.map(_.swap).sorted.reverse(maxKeys - 1)._1
        var mp2 = mutable.Map[T, Int]() ++= mp.filter((t: (T, Int)) =>
          t._2 >= k)
        var toRemove = mp2.size - maxKeys
        if (toRemove > 0) {
          mp2 = mp2.filter((t: (T, Int)) => {
            if (t._2 > k) {
              true
            } else {
              if (toRemove >= 0) {
                toRemove = toRemove - 1
              }
              toRemove < 0
            }
          })
        }
        mp2
      } else {
        mp
      }
    }

    // Called after all the entries are exhausted; applies the strict
    // (factor = 1) limit so at most maxKeys entries are returned.
    def evaluate(buffer: Row): Map[T, Int] = {
      limitKeys(buffer.getMap(0).asInstanceOf[Map[T, Int]])
    }

  }

  /**
   * Merge maps of kind string -> set<string>
   */
  class MapSetMerge extends UserDefinedAggregateFunction {

    def inputSchema: StructType = new StructType().add("key", dataType)

    def bufferSchema: StructType = inputSchema

    def dataType: DataType = MapType(StringType, ArrayType(StringType))

    def deterministic: Boolean = true

    // This function is called whenever key changes
    def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = mutable.Map()
    }

    // Iterate over each entry of a group; null input maps are skipped.
    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      val mp0 = input.getMap(0)
      if (mp0 != null) {
        val mp = mutable.Map[String, mutable.WrappedArray[String]]() ++= input
          .getMap(0)
        buffer(0) =
          merge(mp, buffer.getMap[String, mutable.WrappedArray[String]](0))
      }
    }

    // Merge two partial aggregates
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      val mp = mutable.Map[String, mutable.WrappedArray[String]]() ++= buffer1
        .getMap(0)
      buffer1(0) =
        merge(mp, buffer2.getMap[String, mutable.WrappedArray[String]](0))
    }

    // Set-union of the per-key arrays of two maps, into mpBuffer.
    def merge(mpBuffer: mutable.Map[String, mutable.WrappedArray[String]],
              mp: Map[String, mutable.WrappedArray[String]])
      : mutable.Map[String, mutable.WrappedArray[String]] = {
      if (mp != null) {
        mp.keys.foreach((key: String) => {
          // blah1/blah2: this key's array from each side (empty if absent)
          val blah1: mutable.WrappedArray[String] =
            mpBuffer.getOrElse(key, mutable.WrappedArray.empty)
          val blah2: mutable.WrappedArray[String] =
            mp.getOrElse(key, mutable.WrappedArray.empty)
          mpBuffer.put(
            key,
            mutable.WrappedArray.make(
              (Option(blah1).getOrElse(mutable.WrappedArray.empty) ++ Option(
                blah2).getOrElse(mutable.WrappedArray.empty)).toSet.toArray)
          )
        })
      }

      mpBuffer
    }

    // Called after all the entries are exhausted.
    def evaluate(buffer: Row): Any = {
      buffer(0)
    }

  }

  /**
   * Counts number of distinct records, but only up to a preset amount -
   * more efficient than an unbounded count
   *
   * NOTE(review): with the default maxItems = -1 no key is ever recorded and
   * evaluate returns -1; callers are expected to pass a positive maxItems.
   */
  class CountDistinctUpTo(maxItems: Int = -1)
      extends UserDefinedAggregateFunction {

    def inputSchema: StructType = new StructType().add("key", StringType)

    // Buffer: keys seen so far (value is an unused placeholder boolean).
    def bufferSchema: StructType =
      new StructType().add("mp", MapType(StringType, BooleanType))

    def dataType: DataType = IntegerType

    def deterministic: Boolean = true

    // This function is called whenever key changes
    def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = mutable.Map()
    }

    // Iterate over each entry of a group; stop recording once maxItems
    // distinct keys are in the buffer.
    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      if (buffer.getMap(0).size < maxItems) {
        val key = input.getString(0)
        if (key != null) {
          buffer(0) = buffer.getMap(0) + (key -> true)
        }
      }
    }

    // Merge two partial aggregates; skipped entirely when buffer1 is already
    // full, so the result is capped (not an exact distinct count).
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      if (buffer1.getMap(0).size < maxItems) {
        val mp = mutable.Map[String, Boolean]() ++= buffer1.getMap(0)
        buffer2
          .getMap(0)
          .keys
          .foreach((key: String) =>
            if (key != null) {
              mp.put(key, true)
            })
        buffer1(0) = mp
      }

    }

    // Called after all the entries are exhausted; clamp to maxItems since a
    // merge may briefly push the buffer past the limit.
    def evaluate(buffer: Row): Int = {
      math.min(buffer.getMap(0).size, maxItems)
    }

  }

}
diff --git a/datafu-spark/src/main/scala/spark/utils/overwrites/SparkOverwriteUDAFs.scala b/datafu-spark/src/main/scala/spark/utils/overwrites/SparkOverwriteUDAFs.scala
new file mode 100644
index 0000000..b1b81b8
--- /dev/null
+++ b/datafu-spark/src/main/scala/spark/utils/overwrites/SparkOverwriteUDAFs.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.types
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
+import org.apache.spark.sql.catalyst.util.TypeUtils
+
/**
 * Column-level entry points for the min/max-by-key aggregates defined below.
 * Lives in org.apache.spark.sql.types to access Catalyst internals.
 */
object SparkOverwriteUDAFs {
  /** Aggregate: the `value` from the row holding the minimal `key` in each group. */
  def minValueByKey(key: Column, value: Column): Column =
    Column(MinValueByKey(key.expr, value.expr).toAggregateExpression(false))
  /** Aggregate: the `value` from the row holding the maximal `key` in each group. */
  def maxValueByKey(key: Column, value: Column): Column =
    Column(MaxValueByKey(key.expr, value.expr).toAggregateExpression(false))
}
+
/** Keeps the value (child2) belonging to the smallest key (child1) seen. */
case class MinValueByKey(child1: Expression, child2: Expression)
    extends ExtramumValueByKey(child1, child2, LessThan)
/** Keeps the value (child2) belonging to the largest key (child1) seen. */
case class MaxValueByKey(child1: Expression, child2: Expression)
    extends ExtramumValueByKey(child1, child2, GreaterThan)
+
/**
 * Shared implementation of MinValueByKey / MaxValueByKey: a declarative
 * aggregate whose buffer holds the best key seen so far ("minmax") and the
 * value that accompanied it ("data"). `bComp` decides which key wins
 * (LessThan for min, GreaterThan for max).
 *
 * NOTE(review): "Extramum" looks like a misspelling of "Extremum"; kept
 * since the concrete case classes extend it by this name.
 */
abstract class ExtramumValueByKey(
    child1: Expression,
    child2: Expression,
    bComp: (Expression, Expression) => BinaryComparison)
    extends DeclarativeAggregate
    with ExpectsInputTypes {

  override def children: Seq[Expression] = child1 :: child2 :: Nil

  // Empty groups evaluate to null.
  override def nullable: Boolean = true

  // Return data type.
  override def dataType: DataType = child2.dataType

  // Expected input data type.
  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, AnyDataType)

  // The key type must be orderable for bComp to be meaningful.
  override def checkInputDataTypes(): TypeCheckResult =
    TypeUtils.checkForOrderingExpr(child1.dataType, "function minmax")

  // Aggregation buffer slots: winning key and its associated value.
  private lazy val minmax = AttributeReference("minmax", child1.dataType)()
  private lazy val data = AttributeReference("data", child2.dataType)()

  override lazy val aggBufferAttributes
    : Seq[AttributeReference] = minmax :: data :: Nil

  // Both slots start null; a null key means "nothing seen yet".
  override lazy val initialValues: Seq[Expression] = Seq(
    Literal.create(null, child1.dataType),
    Literal.create(null, child2.dataType)
  )

  override lazy val updateExpressions: Seq[Expression] =
    chooseKeyValue(minmax, data, child1, child2)

  override lazy val mergeExpressions: Seq[Expression] =
    chooseKeyValue(minmax.left, data.left, minmax.right, data.right)

  // Picks the winning (key, value) pair: a null key always loses; otherwise
  // bComp decides, and the value follows whichever key won.
  private def chooseKeyValue(key1: Expression,
                             value1: Expression,
                             key2: Expression,
                             value2: Expression) = Seq(
    If(IsNull(key1),
       key2,
       If(IsNull(key2), key1, If(bComp(key1, key2), key1, key2))),
    If(IsNull(key1),
       value2,
       If(IsNull(key2), value1, If(bComp(key1, key2), value1, value2)))
  )

  // The final result is the stored value, not the key.
  override lazy val evaluateExpression: AttributeReference = data
}
diff --git a/datafu-spark/src/main/scala/spark/utils/overwrites/SparkPythonRunner.scala b/datafu-spark/src/main/scala/spark/utils/overwrites/SparkPythonRunner.scala
new file mode 100644
index 0000000..e86bed5
--- /dev/null
+++ b/datafu-spark/src/main/scala/spark/utils/overwrites/SparkPythonRunner.scala
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy
+
+import java.io._
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+import datafu.spark.ScalaPythonBridge
+import org.apache.log4j.Logger
+
+import org.apache.spark.api.python.PythonUtils
+import org.apache.spark.util.Utils
+
+/**
+ * Internal class - should not be used by user
+ *
+ * background:
+ * We had to "override" Spark's PythonRunner because we failed on premature python process closing.
+ * In PythonRunner the python process exits immediately when finished to read the file,
+ * this caused us to Accumulators Exceptions when the driver tries to get accumulation data
+ * from the python gateway.
+ * Instead, like in Zeppelin, we create an "interactive" python process, feed it the python
+ * script and not closing the gateway.
+ */
case class SparkPythonRunner(pyPaths: String,
                             otherArgs: Array[String] = Array()) {

  val logger: Logger = Logger.getLogger(getClass)
  // One long-lived interactive python process per runner; reader/writer wrap
  // its combined stdout/stderr and its stdin.
  val (reader, writer, process) = initPythonEnv()

  /**
   * Executes the given python file inside the long-lived python process and
   * returns the output captured until the flush marker is printed.
   *
   * @param pythonFile path of the python script to run
   * @return captured process output
   */
  def runPyFile(pythonFile: String): String = {

    val formattedPythonFile = PythonRunner.formatPath(pythonFile)
    execFile(formattedPythonFile, writer, reader)

  }

  /**
   * Starts the Py4J gateway and launches an interactive python process wired
   * to it via environment variables (PYTHONPATH, PYSPARK_GATEWAY_PORT, ...).
   */
  private def initPythonEnv(): (BufferedReader, BufferedWriter, Process) = {

    val pythonExec =
      sys.env.getOrElse("PYSPARK_DRIVER_PYTHON",
                        sys.env.getOrElse("PYSPARK_PYTHON", "python"))

    // Format python filename paths before adding them to the PYTHONPATH
    val formattedPyFiles = PythonRunner.formatPaths(pyPaths)

    // Launch a Py4J gateway server for the process to connect to; this will let it see our
    // Java system properties and such
    val gatewayServer = new py4j.GatewayServer(ScalaPythonBridge, 0)
    val thread = new Thread(new Runnable() {
      override def run(): Unit = Utils.logUncaughtExceptions {
        gatewayServer.start()
      }
    })
    thread.setName("py4j-gateway-init")
    thread.setDaemon(true)
    thread.start()

    // Wait until the gateway server has started, so that we know which port is it bound to.
    // `gatewayServer.start()` will start a new thread and run the server code there, after
    // initializing the socket, so the thread started above will end as soon as the server is
    // ready to serve connections.
    thread.join()

    // Build up a PYTHONPATH that includes the Spark assembly JAR (where this class is), the
    // python directories in SPARK_HOME (if set), and any files in the pyPaths argument
    val pathElements = new ArrayBuffer[String]
    pathElements ++= formattedPyFiles
    pathElements += PythonUtils.sparkPythonPath
    pathElements += sys.env.getOrElse("PYTHONPATH", "")
    val pythonPath = PythonUtils.mergePythonPaths(pathElements: _*)
    logger.info(
      s"Running python with PYTHONPATH:\n\t${formattedPyFiles.mkString(",")}")

    // Launch Python process ("-i" keeps it interactive, "-u" unbuffered)
    val builder = new ProcessBuilder(
      (Seq(pythonExec, "-iu") ++ otherArgs).asJava)
    val env = builder.environment()
    env.put("PYTHONPATH", pythonPath)
    // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
    env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string
    env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
    env.put("PYSPARK_ALLOW_INSECURE_GATEWAY", "1") // needed for Spark 2.4.1 and newer, will stop working in Spark 3.x
    builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
    val process = builder.start()
    val writer = new BufferedWriter(
      new OutputStreamWriter(process.getOutputStream))
    val reader = new BufferedReader(
      new InputStreamReader(process.getInputStream))

    (reader, writer, process)
  }

  /**
   * Feeds the python process an execfile() wrapped in try/except and reads
   * output until one of the flush markers appears.
   *
   * @throws RuntimeException if the script raised (error marker seen) or the
   *                          python process terminated before printing a marker
   */
  private def execFile(filename: String,
                       writer: BufferedWriter,
                       reader: BufferedReader): String = {
    writer.write("import traceback\n")
    writer.write("try:\n")
    writer.write("    execfile('" + filename + "')\n")
    writer.write("    print (\"*!?flush reader!?*\")\n")
    writer.write("except Exception as e:\n")
    writer.write("    traceback.print_exc()\n")
    writer.write("    print (\"*!?flush error reader!?*\")\n\n")
    writer.flush()
    var output = ""
    var line: String = reader.readLine
    // Guard against null: readLine returns null if the python process dies
    // before printing a marker (the original code NPE'd on line.contains).
    while (line != null && !line.contains("*!?flush reader!?*") && !line.contains(
             "*!?flush error reader!?*")) {
      logger.info(line)
      if (line == "...") {
        output += "Syntax error ! "
      }
      output += "\r" + line + "\n"
      line = reader.readLine
    }

    if (line == null) {
      throw new RuntimeException(
        "python process terminated unexpectedly: " + output)
    }

    if (line.contains("*!?flush error reader!?*")) {
      throw new RuntimeException("python bridge error: " + output)
    }

    output
  }

}
diff --git a/datafu-spark/src/test/resources/META-INF/services/datafu.spark.PythonResource b/datafu-spark/src/test/resources/META-INF/services/datafu.spark.PythonResource
new file mode 100644
index 0000000..072cde9
--- /dev/null
+++ b/datafu-spark/src/test/resources/META-INF/services/datafu.spark.PythonResource
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+datafu.spark.ExampleFiles
+datafu.spark.Py4JResource
+datafu.spark.PysparkResource
diff --git a/datafu-spark/src/test/resources/log4j.properties b/datafu-spark/src/test/resources/log4j.properties
new file mode 100644
index 0000000..bc52d61
--- /dev/null
+++ b/datafu-spark/src/test/resources/log4j.properties
@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+log4j.rootCategory=WARN, stdout
+
+log4j.appender.stdout=org.apache.log4j.ConsoleAppender
+log4j.appender.stdout.layout=org.apache.log4j.PatternLayout
+log4j.appender.stdout.layout.ConversionPattern=%d %5p %c{3} - %m%n
diff --git a/datafu-spark/src/test/resources/python_tests/df_utils_tests.py b/datafu-spark/src/test/resources/python_tests/df_utils_tests.py
new file mode 100644
index 0000000..bf5d068
--- /dev/null
+++ b/datafu-spark/src/test/resources/python_tests/df_utils_tests.py
@@ -0,0 +1,90 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# This file is used by the datafu-spark unit tests
+
+import os
+import sys
+from pprint import pprint as p
+
+from pyspark_utils import df_utils
+
+p('CHECKING IF PATHS EXIST:')
+for x in sys.path:
+ p('PATH ' + x + ': ' + str(os.path.exists(x)))
+
+
+df_utils.activate()
+
+df_people = sqlContext.createDataFrame([
+ ("a", "Alice", 34),
+ ("a", "Sara", 33),
+ ("b", "Bob", 36),
+ ("b", "Charlie", 30),
+ ("c", "David", 29),
+ ("c", "Esther", 32),
+ ("c", "Fanny", 36),
+ ("c", "Zoey", 36)],
+ ["id", "name", "age"])
+
+func_dedup_res = df_people.dedup_with_order(group_col=df_people.id,
+ order_cols=[df_people.age.desc(), df_people.name.desc()])
+func_dedup_res.registerTempTable("dedup_with_order")
+
+func_dedupTopN_res = df_people.dedup_top_n(n=2, group_col=df_people.id,
+ order_cols=[df_people.age.desc(), df_people.name.desc()])
+func_dedupTopN_res.registerTempTable("dedupTopN")
+
+func_dedup2_res = df_people.dedup_with_combiner(group_col=df_people.id, order_by_col=df_people.age, desc=True,
+ columns_filter=["name"], columns_filter_keep=False)
+func_dedup2_res.registerTempTable("dedup_with_combiner")
+
+func_changeSchema_res = df_people.change_schema(new_scheme=["id1", "name1", "age1"])
+func_changeSchema_res.registerTempTable("changeSchema")
+
+df_people2 = sqlContext.createDataFrame([
+ ("a", "Laura", 34),
+ ("a", "Stephani", 33),
+ ("b", "Margaret", 36)],
+ ["id", "name", "age"])
+
+simpleDF = sqlContext.createDataFrame([
+ ("a", "1")],
+ ["id", "value"])
+from pyspark.sql.functions import expr
+
+func_joinSkewed_res = df_utils.join_skewed(df_left=df_people2.alias("df1"), df_right=simpleDF.alias("df2"),
+ join_exprs=expr("df1.id == df2.id"), num_shards=5,
+ join_type="inner")
+func_joinSkewed_res.registerTempTable("joinSkewed")
+
+func_broadcastJoinSkewed_res = df_utils.broadcast_join_skewed(not_skewed_df=df_people2, skewed_df=simpleDF, join_col="id",
+ number_of_custs_to_broadcast=5)
+func_broadcastJoinSkewed_res.registerTempTable("broadcastJoinSkewed")
+
+dfRange = sqlContext.createDataFrame([
+ ("a", 34, 36)],
+ ["id1", "start", "end"])
+func_joinWithRange_res = df_utils.join_with_range(df_single=df_people2, col_single="age", df_range=dfRange,
+ col_range_start="start", col_range_end="end",
+ decrease_factor=5)
+func_joinWithRange_res.registerTempTable("joinWithRange")
+
+func_joinWithRangeAndDedup_res = df_utils.join_with_range_and_dedup(df_single=df_people2, col_single="age", df_range=dfRange,
+ col_range_start="start", col_range_end="end",
+ decrease_factor=5, dedup_small_range=True)
+func_joinWithRangeAndDedup_res.registerTempTable("joinWithRangeAndDedup")
diff --git a/datafu-spark/src/test/resources/python_tests/pyfromscala.py b/datafu-spark/src/test/resources/python_tests/pyfromscala.py
new file mode 100644
index 0000000..3162ff4
--- /dev/null
+++ b/datafu-spark/src/test/resources/python_tests/pyfromscala.py
@@ -0,0 +1,94 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Some examples of cross python-Scala functionality
+# This file is used by the datafu-spark unit tests
+
+
+# print the PYTHONPATH
+import sys
+from pprint import pprint as p
+p(sys.path)
+
+from pyspark.sql import functions as F
+
+
+import os
+print os.getcwd()
+
+
+###############################################################
+# query scala defined DF
+###############################################################
+dfout = sqlContext.sql("select num * 2 as d from dfin")
+dfout.registerTempTable("dfout")
+dfout.groupBy(dfout['d']).count().show()
+sqlContext.sql("select count(*) as cnt from dfout").show()
+dfout.groupBy(dfout['d']).agg(F.count(F.col('d')).alias('cnt')).show()
+
+sqlContext.sql("select d * 4 as d from dfout").registerTempTable("dfout2")
+
+
+###############################################################
+# check python UDFs
+###############################################################
+
+def magic_func(s):
+
+ return s + " magic"
+
+sqlContext.udf.register("magic", magic_func)
+
+
+###############################################################
+# check sc.textFile
+###############################################################
+
+DEL = '\x10'
+
+from pyspark.sql.types import StructType, StructField
+from pyspark.sql.types import StringType
+
+schema = StructType([
+ StructField("A", StringType()),
+ StructField("B", StringType())
+])
+
+txt_df = sqlContext.read.csv('src/test/resources/text.csv', sep=DEL, schema=schema)
+
+print type(txt_df)
+print dir(txt_df)
+print txt_df.count()
+
+txt_df.show()
+
+txt_df2 = sc.textFile('src/test/resources/text.csv').map(lambda x: x.split(DEL)).toDF()
+txt_df2.show()
+
+
+###############################################################
+# convert python dict to DataFrame
+###############################################################
+
+d = {'a': 0.1, 'b': 2}
+d = [(k,1.0*d[k]) for k in d]
+stats_df = sc.parallelize(d, 1).toDF(["name", "val"])
+stats_df.registerTempTable('stats')
+
+sqlContext.table("stats").show()
+
+
diff --git a/datafu-spark/src/test/resources/python_tests/pyfromscala_with_error.py b/datafu-spark/src/test/resources/python_tests/pyfromscala_with_error.py
new file mode 100644
index 0000000..d784662
--- /dev/null
+++ b/datafu-spark/src/test/resources/python_tests/pyfromscala_with_error.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# This file is used by the datafu-spark unit tests
+
+sqlContext.sql("select * from edw.table_not_exists")
diff --git a/datafu-spark/src/test/resources/text.csv b/datafu-spark/src/test/resources/text.csv
new file mode 100644
index 0000000..c0d981d
--- /dev/null
+++ b/datafu-spark/src/test/resources/text.csv
@@ -0,0 +1,5 @@
+14
+52
+47
+38
+03
\ No newline at end of file
diff --git a/datafu-spark/src/test/scala/datafu/spark/PySparkLibTestResources.scala b/datafu-spark/src/test/scala/datafu/spark/PySparkLibTestResources.scala
new file mode 100644
index 0000000..5086295
--- /dev/null
+++ b/datafu-spark/src/test/scala/datafu/spark/PySparkLibTestResources.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package datafu.spark
+
+import java.io.File
+
+class PysparkResource extends PythonResource(PathsResolver.pyspark, true)
+
+class Py4JResource extends PythonResource(PathsResolver.py4j, true)
+
+object PathsResolver {
+
+ val sparkSystemVersion = System.getProperty("datafu.spark.version")
+
+ val py4js = Map(
+ "2.1.0" -> "0.10.4",
+ "2.1.1" -> "0.10.4",
+ "2.1.2" -> "0.10.4",
+ "2.1.3" -> "0.10.4",
+ "2.2.0" -> "0.10.7",
+ "2.2.1" -> "0.10.7",
+ "2.2.2" -> "0.10.7",
+ "2.3.0" -> "0.10.6",
+ "2.3.1" -> "0.10.7",
+ "2.3.2" -> "0.10.7",
+ "2.4.0" -> "0.10.8.1",
+ "2.4.1" -> "0.10.8.1",
+ "2.4.2" -> "0.10.8.1",
+ "2.4.3" -> "0.10.8.1"
+ )
+
+ val sparkVersion = if (sparkSystemVersion == null) "2.3.0" else sparkSystemVersion
+
+ val py4jVersion = py4js.getOrElse(sparkVersion, "0.10.6") // our default
+
+ val pyspark = ResourceCloning.cloneResource(new File("data/pysparks/pyspark-" + sparkVersion + ".zip").toURI().toURL(),
+ "pyspark_cloned.zip").getPath
+ val py4j = ResourceCloning.cloneResource(new File("data/py4js/py4j-" + py4jVersion + "-src.zip").toURI().toURL(),
+ "py4j_cloned.zip").getPath
+}
diff --git a/datafu-spark/src/test/scala/datafu/spark/TestScalaPythonBridge.scala b/datafu-spark/src/test/scala/datafu/spark/TestScalaPythonBridge.scala
new file mode 100644
index 0000000..aa6ed52
--- /dev/null
+++ b/datafu-spark/src/test/scala/datafu/spark/TestScalaPythonBridge.scala
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package datafu.spark
+
+import java.io.File
+
+import scala.util.Try
+
+import com.holdenkarau.spark.testing.Utils
+import org.junit._
+import org.junit.runner.RunWith
+import org.scalatest.FunSuite
+import org.scalatest.junit.JUnitRunner
+import org.slf4j.LoggerFactory
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.SparkSession
+
+object TestScalaPythonBridge {
+
+ val logger = LoggerFactory.getLogger(this.getClass)
+
+ def getNewRunner(): ScalaPythonBridgeRunner = {
+ val runner = ScalaPythonBridgeRunner()
+ runner.runPythonFile("pyspark_utils/init_spark_context.py")
+ runner
+ }
+
+ def getNewSparkSession(): SparkSession = {
+
+ val tempDir = Utils.createTempDir()
+ val localMetastorePath = new File(tempDir, "metastore").getCanonicalPath
+ val localWarehousePath = new File(tempDir, "warehouse").getCanonicalPath
+ val pythonPath =
+ PythonPathsManager.getAbsolutePaths().mkString(File.pathSeparator)
+ logger.info("Creating SparkConf with PYTHONPATH: " + pythonPath)
+ val sparkConf = new SparkConf()
+ .setMaster("local[1]")
+ .set("spark.sql.warehouse.dir", localWarehousePath)
+ .set("javax.jdo.option.ConnectionURL",
+ s"jdbc:derby:;databaseName=$localMetastorePath;create=true")
+ .setExecutorEnv(Seq(("PYTHONPATH", pythonPath)))
+ .setAppName("Spark Unit Test")
+
+ val builder = SparkSession.builder().config(sparkConf).enableHiveSupport()
+ val spark = builder.getOrCreate()
+
+ spark
+ }
+}
+
+@RunWith(classOf[JUnitRunner])
+class TestScalaPythonBridge extends FunSuite {
+
+ private val spark = TestScalaPythonBridge.getNewSparkSession
+ private lazy val runner = TestScalaPythonBridge.getNewRunner()
+
+ def assertTable(tableName: String, expected: String): Unit =
+ Assert.assertEquals(
+ expected,
+ spark.table(tableName).collect().sortBy(_.toString).mkString(", "))
+
+ test("pyfromscala.py") {
+
+ import spark.implicits._
+
+ val dfin = spark.sparkContext.parallelize(1 to 10).toDF("num")
+ dfin.createOrReplaceTempView("dfin")
+
+ runner.runPythonFile("python_tests/pyfromscala.py")
+
+ // try to invoke python udf from scala code
+ assert(
+ spark
+ .sql("select magic('python_udf')")
+ .collect()
+ .mkString(",") == "[python_udf magic]")
+
+ assertTable("dfout",
+ "[10], [12], [14], [16], [18], [20], [2], [4], [6], [8]")
+ assertTable("dfout2",
+ "[16], [24], [32], [40], [48], [56], [64], [72], [80], [8]")
+ assertTable("stats", "[a,0.1], [b,2.0]")
+ }
+
+ test("pyfromscala_with_error.py") {
+ val t = Try(runner.runPythonFile("python_tests/pyfromscala_with_error.py"))
+ assert(t.isFailure)
+ assert(t.failed.get.isInstanceOf[RuntimeException])
+ }
+
+ test("SparkDFUtilsBridge") {
+ runner.runPythonFile("python_tests/df_utils_tests.py")
+ assertTable("dedup_with_order", "[a,Alice,34], [b,Bob,36], [c,Zoey,36]")
+ assertTable(
+ "dedupTopN",
+ "[a,Alice,34], [a,Sara,33], [b,Bob,36], [b,Charlie,30], [c,Fanny,36], [c,Zoey,36]")
+ assertTable("dedup_with_combiner", "[a,34], [b,36], [c,36]")
+ assertTable(
+ "changeSchema",
+ "[a,Alice,34], [a,Sara,33], [b,Bob,36], [b,Charlie,30], [c,David,29], [c,Esther,32], " +
+ "[c,Fanny,36], [c,Zoey,36]")
+ assertTable("joinSkewed", "[a,Laura,34,a,1], [a,Stephani,33,a,1]")
+ assertTable("broadcastJoinSkewed", "[a,Laura,34,1], [a,Stephani,33,1]")
+ assertTable("joinWithRange",
+ "[a,Laura,34,a,34,36], [b,Margaret,36,a,34,36]")
+ assertTable("joinWithRangeAndDedup",
+ "[a,Laura,34,a,34,36], [b,Margaret,36,a,34,36]")
+ }
+
+}
+
+class ExampleFiles extends PythonResource("python_tests")
diff --git a/datafu-spark/src/test/scala/datafu/spark/TestSparkDFUtils.scala b/datafu-spark/src/test/scala/datafu/spark/TestSparkDFUtils.scala
new file mode 100644
index 0000000..b148c35
--- /dev/null
+++ b/datafu-spark/src/test/scala/datafu/spark/TestSparkDFUtils.scala
@@ -0,0 +1,378 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package datafu.spark
+
+import com.holdenkarau.spark.testing.DataFrameSuiteBase
+import org.junit.runner.RunWith
+import org.scalatest.FunSuite
+import org.scalatest.junit.JUnitRunner
+
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+
+
+@RunWith(classOf[JUnitRunner])
+class DataFrameOpsTests extends FunSuite with DataFrameSuiteBase {
+
+ import DataFrameOps._
+
+ import spark.implicits._
+
+ val inputSchema = List(
+ StructField("col_grp", StringType, true),
+ StructField("col_ord", IntegerType, false),
+ StructField("col_str", StringType, true)
+ )
+
+ val dedupSchema = List(
+ StructField("col_grp", StringType, true),
+ StructField("col_ord", IntegerType, false)
+ )
+
+ lazy val inputRDD = sc.parallelize(
+ Seq(Row("a", 1, "asd1"),
+ Row("a", 2, "asd2"),
+ Row("a", 3, "asd3"),
+ Row("b", 1, "asd4")))
+
+ lazy val inputDataFrame =
+ sqlContext.createDataFrame(inputRDD, StructType(inputSchema)).cache
+
+ test("dedup") {
+ val expected: DataFrame =
+ sqlContext.createDataFrame(sc.parallelize(Seq(Row("b", 1), Row("a", 3))),
+ StructType(dedupSchema))
+
+ assertDataFrameEquals(expected,
+ inputDataFrame
+ .dedupWithOrder($"col_grp", $"col_ord".desc)
+ .select($"col_grp", $"col_ord"))
+ }
+
+ case class dedupExp(col2: String,
+ col_grp: String,
+ col_ord: Option[Int],
+ col_str: String)
+
+ test("dedup2_by_int") {
+
+ val expectedByIntDf: DataFrame = sqlContext.createDataFrame(
+ List(dedupExp("asd4", "b", Option(1), "asd4"),
+ dedupExp("asd1", "a", Option(3), "asd3")))
+
+ val actual = inputDataFrame.dedupWithCombiner($"col_grp",
+ $"col_ord",
+ moreAggFunctions = Seq(min($"col_str")))
+
+ assertDataFrameEquals(expectedByIntDf, actual)
+ }
+
+ case class dedupExp2(col_grp: String, col_ord: Option[Int], col_str: String)
+
+ test("dedup2_by_string_asc") {
+
+ val actual = inputDataFrame.dedupWithCombiner($"col_grp", $"col_str", desc = false)
+
+ val expectedByStringDf: DataFrame = sqlContext.createDataFrame(
+ List(dedupExp2("b", Option(1), "asd4"),
+ dedupExp2("a", Option(1), "asd1")))
+
+ assertDataFrameEquals(expectedByStringDf, actual)
+ }
+
+ test("test_dedup2_by_complex_column") {
+
+ val actual = inputDataFrame.dedupWithCombiner($"col_grp",
+ expr("cast(concat('-',col_ord) as int)"),
+ desc = false)
+
+ val expectedComplex: DataFrame = sqlContext.createDataFrame(
+ List(dedupExp2("b", Option(1), "asd4"),
+ dedupExp2("a", Option(3), "asd3")))
+
+ assertDataFrameEquals(expectedComplex, actual)
+ }
+
+ case class Inner(col_grp: String, col_ord: Int)
+
+ case class expComplex(
+ col_grp: String,
+ col_ord: Option[Int],
+ col_str: String,
+ arr_col: Array[String],
+ struct_col: Inner,
+ map_col: Map[String, Int]
+ )
+
+ test("test_dedup2_with_other_complex_column") {
+
+ val actual = inputDataFrame
+ .withColumn("arr_col", expr("array(col_grp, col_ord)"))
+ .withColumn("struct_col", expr("struct(col_grp, col_ord)"))
+ .withColumn("map_col", expr("map(col_grp, col_ord)"))
+ .withColumn("map_col_blah", expr("map(col_grp, col_ord)"))
+ .dedupWithCombiner($"col_grp", expr("cast(concat('-',col_ord) as int)"))
+ .drop("map_col_blah")
+
+ val expected: DataFrame = sqlContext.createDataFrame(
+ List(
+ expComplex("b",
+ Option(1),
+ "asd4",
+ Array("b", "1"),
+ Inner("b", 1),
+ Map("b" -> 1)),
+ expComplex("a",
+ Option(1),
+ "asd1",
+ Array("a", "1"),
+ Inner("a", 1),
+ Map("a" -> 1))
+ ))
+
+ assertDataFrameEquals(expected, actual)
+ }
+
+ val dedupTopNExpectedSchema = List(
+ StructField("col_grp", StringType, true),
+ StructField("col_ord", IntegerType, false)
+ )
+
+ test("test_dedup_top_n") {
+ val actual = inputDataFrame
+ .dedupTopN(2, $"col_grp", $"col_ord".desc)
+ .select($"col_grp", $"col_ord")
+
+ val expected = sqlContext.createDataFrame(
+ sc.parallelize(Seq(Row("b", 1), Row("a", 3), Row("a", 2))),
+ StructType(dedupTopNExpectedSchema))
+
+ assertDataFrameEquals(expected, actual)
+ }
+
+ val schema2 = List(
+ StructField("start", IntegerType, false),
+ StructField("end", IntegerType, false),
+ StructField("desc", StringType, true)
+ )
+
+ val expectedSchemaRangedJoin = List(
+ StructField("col_grp", StringType, true),
+ StructField("col_ord", IntegerType, false),
+ StructField("col_str", StringType, true),
+ StructField("start", IntegerType, true),
+ StructField("end", IntegerType, true),
+ StructField("desc", StringType, true)
+ )
+
+ test("join_with_range") {
+ val joinWithRangeDataFrame =
+ sqlContext.createDataFrame(sc.parallelize(
+ Seq(Row(1, 2, "asd1"),
+ Row(1, 4, "asd2"),
+ Row(3, 5, "asd3"),
+ Row(3, 10, "asd4"))),
+ StructType(schema2))
+
+ val expected = sqlContext.createDataFrame(
+ sc.parallelize(
+ Seq(
+ Row("b", 1, "asd4", 1, 2, "asd1"),
+ Row("a", 2, "asd2", 1, 2, "asd1"),
+ Row("a", 1, "asd1", 1, 2, "asd1"),
+ Row("b", 1, "asd4", 1, 4, "asd2"),
+ Row("a", 3, "asd3", 1, 4, "asd2"),
+ Row("a", 2, "asd2", 1, 4, "asd2"),
+ Row("a", 1, "asd1", 1, 4, "asd2"),
+ Row("a", 3, "asd3", 3, 5, "asd3"),
+ Row("a", 3, "asd3", 3, 10, "asd4")
+ )),
+ StructType(expectedSchemaRangedJoin)
+ )
+
+ val actual = inputDataFrame.joinWithRange("col_ord",
+ joinWithRangeDataFrame,
+ "start",
+ "end")
+
+ assertDataFrameEquals(expected, actual)
+ }
+
+ val expectedSchemaRangedJoinWithDedup = List(
+ StructField("col_grp", StringType, true),
+ StructField("col_ord", IntegerType, true),
+ StructField("col_str", StringType, true),
+ StructField("start", IntegerType, true),
+ StructField("end", IntegerType, true),
+ StructField("desc", StringType, true)
+ )
+
+ test("join_with_range_and_dedup") {
+ val df = sc
+ .parallelize(
+ List(("a", 1, "asd1"),
+ ("a", 2, "asd2"),
+ ("a", 3, "asd3"),
+ ("b", 1, "asd4")))
+ .toDF("col_grp", "col_ord", "col_str")
+ val dfr = sc
+ .parallelize(
+ List((1, 2, "asd1"), (1, 4, "asd2"), (3, 5, "asd3"), (3, 10, "asd4")))
+ .toDF("start", "end", "desc")
+
+ val expected = sqlContext.createDataFrame(
+ sc.parallelize(
+ Seq(
+ Row("b", 1, "asd4", 1, 2, "asd1"),
+ Row("a", 3, "asd3", 3, 5, "asd3"),
+ Row("a", 2, "asd2", 1, 2, "asd1")
+ )),
+ StructType(expectedSchemaRangedJoinWithDedup)
+ )
+
+ val actual = df.joinWithRangeAndDedup("col_ord", dfr, "start", "end")
+
+ assertDataFrameEquals(expected, actual)
+ }
+
+ test("broadcastJoinSkewed") {
+ val skewedList = List(("1", "a"),
+ ("1", "b"),
+ ("1", "c"),
+ ("1", "d"),
+ ("1", "e"),
+ ("2", "k"),
+ ("0", "k"))
+ val skewed =
+ sqlContext.createDataFrame(skewedList).toDF("key", "val_skewed")
+ val notSkewed = sqlContext
+ .createDataFrame((1 to 10).map(i => (i.toString, s"str$i")))
+ .toDF("key", "val")
+
+ val expected = sqlContext
+ .createDataFrame(
+ List(
+ ("1", "str1", "a"),
+ ("1", "str1", "b"),
+ ("1", "str1", "c"),
+ ("1", "str1", "d"),
+ ("1", "str1", "e"),
+ ("2", "str2", "k")
+ ))
+ .toDF("key", "val", "val_skewed")
+
+ val actual1 = notSkewed.broadcastJoinSkewed(skewed, "key", 1)
+
+ assertDataFrameEquals(expected, actual1.sort($"val_skewed"))
+
+ val actual2 = notSkewed.broadcastJoinSkewed(skewed, "key", 2)
+
+ assertDataFrameEquals(expected, actual2.sort($"val_skewed"))
+ }
+
+ // because of nulls in expected data, an actual schema needs to be used
+ case class expJoinSkewed(str1: String,
+ str2: String,
+ str3: String,
+ str4: String)
+
+ test("joinSkewed") {
+ val skewedList = List(("1", "a"),
+ ("1", "b"),
+ ("1", "c"),
+ ("1", "d"),
+ ("1", "e"),
+ ("2", "k"),
+ ("0", "k"))
+ val skewed =
+ sqlContext.createDataFrame(skewedList).toDF("key", "val_skewed")
+ val notSkewed = sqlContext
+ .createDataFrame((1 to 10).map(i => (i.toString, s"str$i")))
+ .toDF("key", "val")
+
+ val actual1 =
+ skewed.as("a").joinSkewed(notSkewed.as("b"), expr("a.key = b.key"), 3)
+
+ val expected1 = sqlContext
+ .createDataFrame(
+ List(
+ ("1", "a", "1", "str1"),
+ ("1", "b", "1", "str1"),
+ ("1", "c", "1", "str1"),
+ ("1", "d", "1", "str1"),
+ ("1", "e", "1", "str1"),
+ ("2", "k", "2", "str2")
+ ))
+ .toDF("key", "val_skewed", "key", "val")
+
+ // assertDataFrameEquals cares about order but we don't
+ assertDataFrameEquals(expected1, actual1.sort($"val_skewed"))
+
+ val actual2 = skewed
+ .as("a")
+ .joinSkewed(notSkewed.as("b"), expr("a.key = b.key"), 3, "left_outer")
+
+ val expected2 = sqlContext
+ .createDataFrame(
+ List(
+ expJoinSkewed("1", "a", "1", "str1"),
+ expJoinSkewed("1", "b", "1", "str1"),
+ expJoinSkewed("1", "c", "1", "str1"),
+ expJoinSkewed("1", "d", "1", "str1"),
+ expJoinSkewed("1", "e", "1", "str1"),
+ expJoinSkewed("2", "k", "2", "str2"),
+ expJoinSkewed("0", "k", null, null)
+ ))
+ .toDF("key", "val_skewed", "key", "val")
+
+ // assertDataFrameEquals cares about order but we don't
+ assertDataFrameEquals(expected2, actual2.sort($"val_skewed"))
+ }
+
+ val changedSchema = List(
+ StructField("fld1", StringType, true),
+ StructField("fld2", IntegerType, false),
+ StructField("fld3", StringType, true)
+ )
+
+ test("test_changeSchema") {
+
+ val actual = inputDataFrame.changeSchema("fld1", "fld2", "fld3")
+
+ val expected =
+ sqlContext.createDataFrame(inputRDD, StructType(changedSchema))
+
+ assertDataFrameEquals(expected, actual)
+ }
+
+ test("test_flatten") {
+
+ val input = inputDataFrame
+ .withColumn("struct_col", expr("struct(col_grp, col_ord)"))
+ .select("struct_col")
+
+ val expected: DataFrame = inputDataFrame.select("col_grp", "col_ord")
+
+ val actual = input.flatten("struct_col")
+
+ assertDataFrameEquals(expected, actual)
+ }
+}
diff --git a/datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala b/datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala
new file mode 100644
index 0000000..8df9fba
--- /dev/null
+++ b/datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala
@@ -0,0 +1,261 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package datafu.spark
+
+import com.holdenkarau.spark.testing.DataFrameSuiteBase
+import org.junit.Assert
+import org.junit.runner.RunWith
+import org.scalatest.FunSuite
+import org.scalatest.junit.JUnitRunner
+import org.slf4j.LoggerFactory
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.expressions.Window
+import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
+import org.apache.spark.sql.types._
+
+@RunWith(classOf[JUnitRunner])
+class UdafTests extends FunSuite with DataFrameSuiteBase {
+
+ import spark.implicits._
+
+ /**
+ * Taken from https://github.com/holdenk/spark-testing-base/issues/234#issuecomment-390150835
+ *
+ * Works around a Hive catalog problem in Spark 2.3.0 when using spark-testing-base.
+ */
+ override def conf: SparkConf =
+ super.conf.set(CATALOG_IMPLEMENTATION.key, "hive")
+
+ val logger = LoggerFactory.getLogger(this.getClass)
+
+ val inputSchema = List(
+ StructField("col_grp", StringType, true),
+ StructField("col_ord", IntegerType, false),
+ StructField("col_str", StringType, true)
+ )
+
+ lazy val inputRDD = sc.parallelize(
+ Seq(Row("a", 1, "asd1"),
+ Row("a", 2, "asd2"),
+ Row("a", 3, "asd3"),
+ Row("b", 1, "asd4")))
+
+ lazy val df =
+ sqlContext.createDataFrame(inputRDD, StructType(inputSchema)).cache
+
+ case class mapExp(map_col: Map[String, Int])
+ case class mapArrExp(map_col: Map[String, Array[String]])
+
+ test("test multiset simple") {
+ val ms = new SparkUDAFs.MultiSet()
+ val expected: DataFrame =
+ sqlContext.createDataFrame(List(mapExp(Map("b" -> 1, "a" -> 3))))
+ assertDataFrameEquals(expected, df.agg(ms($"col_grp").as("map_col")))
+ }
+
+ val mas = new SparkUDAFs.MultiArraySet[String]()
+
+ test("test multiarrayset simple") {
+ assertDataFrameEquals(
+ sqlContext.createDataFrame(List(mapExp(Map("tre" -> 1, "asd" -> 2)))),
+ spark
+ .sql("select array('asd','tre','asd') arr")
+ .groupBy()
+ .agg(mas($"arr").as("map_col"))
+ )
+ }
+
+ test("test multiarrayset all nulls") {
+ // edge case: every inserted array value is null
+ spark.sql("drop table if exists mas_table")
+ spark.sql("create table mas_table (arr array<string>)")
+ spark.sql(
+ "insert overwrite table mas_table select case when 1=2 then array('asd') end " +
+ "from (select 1)z")
+ spark.sql(
+ "insert into table mas_table select case when 1=2 then array('asd') end from (select 1)z")
+ spark.sql(
+ "insert into table mas_table select case when 1=2 then array('asd') end from (select 1)z")
+ spark.sql(
+ "insert into table mas_table select case when 1=2 then array('asd') end from (select 1)z")
+ spark.sql(
+ "insert into table mas_table select case when 1=2 then array('asd') end from (select 1)z")
+
+ val expected = sqlContext.createDataFrame(List(mapExp(Map())))
+
+ val actual =
+ spark.table("mas_table").groupBy().agg(mas($"arr").as("map_col"))
+
+ assertDataFrameEquals(expected, actual)
+ }
+
+ test("test multiarrayset max keys") {
+ // maxKeys cap: verifies only the most frequent keys are retained
+ spark.sql("drop table if exists mas_table2")
+ spark.sql("create table mas_table2 (arr array<string>)")
+ spark.sql(
+ "insert overwrite table mas_table2 select array('asd','dsa') from (select 1)z")
+ spark.sql(
+ "insert into table mas_table2 select array('asd','abc') from (select 1)z")
+ spark.sql(
+ "insert into table mas_table2 select array('asd') from (select 1)z")
+ spark.sql(
+ "insert into table mas_table2 select array('asd') from (select 1)z")
+ spark.sql(
+ "insert into table mas_table2 select array('asd') from (select 1)z")
+ spark.sql(
+ "insert into table mas_table2 select array('asd2') from (select 1)z")
+
+ val mas2 = new SparkUDAFs.MultiArraySet[String](maxKeys = 2)
+
+ assertDataFrameEquals(
+ sqlContext.createDataFrame(List(mapExp(Map("dsa" -> 1, "asd" -> 5)))),
+ spark.table("mas_table2").groupBy().agg(mas2($"arr").as("map_col")))
+
+ val mas1 = new SparkUDAFs.MultiArraySet[String](maxKeys = 1)
+ assertDataFrameEquals(
+ sqlContext.createDataFrame(List(mapExp(Map("asd" -> 5)))),
+ spark.table("mas_table2").groupBy().agg(mas1($"arr").as("map_col")))
+ }
+
+ test("test multiarrayset big input") {
+ val N = 100000
+ val blah = spark.sparkContext
+ .parallelize(1 to N, 20)
+ .toDF("num")
+ .selectExpr("array('asd',concat('dsa',num)) as arr")
+ val mas = new SparkUDAFs.MultiArraySet[String](maxKeys = 3)
+ val time1 = System.currentTimeMillis()
+ val mp = blah
+ .groupBy()
+ .agg(mas($"arr"))
+ .collect()
+ .map(_.getMap[String, Int](0))
+ .head
+ Assert.assertEquals(3, mp.size)
+ Assert.assertEquals("asd", mp.maxBy(_._2)._1)
+ Assert.assertEquals(N, mp.maxBy(_._2)._2)
+ val time2 = System.currentTimeMillis()
+ logger.info("time took: " + (time2 - time1) / 1000 + " secs")
+ }
+
+ test("test mapmerge") {
+ val mapMerge = new SparkUDAFs.MapSetMerge()
+
+ spark.sql("drop table if exists mapmerge_table")
+ spark.sql("create table mapmerge_table (c map<string, array<string>>)")
+ spark.sql(
+ "insert overwrite table mapmerge_table select map('k1', array('v1')) from (select 1) z")
+ spark.sql(
+ "insert into table mapmerge_table select map('k1', array('v1')) from (select 1) z")
+ spark.sql(
+ "insert into table mapmerge_table select map('k2', array('v3')) from (select 1) z")
+
+ assertDataFrameEquals(
+ sqlContext.createDataFrame(
+ List(mapArrExp(Map("k1" -> Array("v1"), "k2" -> Array("v3"))))),
+ spark.table("mapmerge_table").groupBy().agg(mapMerge($"c").as("map_col"))
+ )
+ }
+
+ test("minKeyValue") {
+ assertDataFrameEquals(
+ sqlContext.createDataFrame(List(("b", "asd4"), ("a", "asd1"))),
+ df.groupBy($"col_grp".as("_1"))
+ .agg(SparkOverwriteUDAFs.minValueByKey($"col_ord", $"col_str").as("_2"))
+ )
+ }
+
+ case class Exp4(colGrp: String, colOrd: Int, colStr: String, asd: String)
+
+ val minKeyValueWindowExpectedSchema = List(
+ StructField("col_grp", StringType, true),
+ StructField("col_ord", IntegerType, false),
+ StructField("col_str", StringType, true),
+ StructField("asd", StringType, true)
+ )
+
+ test("minKeyValue window") {
+ assertDataFrameEquals(
+ sqlContext.createDataFrame(
+ sc.parallelize(
+ Seq(
+ Row("b", 1, "asd4", "asd4"),
+ Row("a", 1, "asd1", "asd1"),
+ Row("a", 2, "asd2", "asd1"),
+ Row("a", 3, "asd3", "asd1")
+ )),
+ StructType(minKeyValueWindowExpectedSchema)
+ ),
+ df.withColumn("asd",
+ SparkOverwriteUDAFs
+ .minValueByKey($"col_ord", $"col_str")
+ .over(Window.partitionBy("col_grp")))
+ )
+ }
+
+ case class Exp5(col_grp: String, col_ord: Option[Int])
+
+ test("countDistinctUpTo") {
+ import datafu.spark.SparkUDAFs.CountDistinctUpTo
+
+ val countDistinctUpTo3 = new CountDistinctUpTo(3)
+ val countDistinctUpTo6 = new CountDistinctUpTo(6)
+
+ val inputDF = sqlContext.createDataFrame(
+ List(
+ Exp5("b", Option(1)),
+ Exp5("a", Option(1)),
+ Exp5("a", Option(2)),
+ Exp5("a", Option(3)),
+ Exp5("a", Option(4))
+ ))
+
+ val results3DF = sqlContext.createDataFrame(
+ List(
+ Exp5("b", Option(1)),
+ Exp5("a", Option(3))
+ ))
+
+ val results6DF = sqlContext.createDataFrame(
+ List(
+ Exp5("b", Option(1)),
+ Exp5("a", Option(4))
+ ))
+
+ inputDF
+ .groupBy("col_grp")
+ .agg(countDistinctUpTo3($"col_ord").as("col_ord"))
+ .show
+
+ assertDataFrameEquals(results3DF,
+ inputDF
+ .groupBy("col_grp")
+ .agg(countDistinctUpTo3($"col_ord").as("col_ord")))
+
+ assertDataFrameEquals(results6DF,
+ inputDF
+ .groupBy("col_grp")
+ .agg(countDistinctUpTo6($"col_ord").as("col_ord")))
+ }
+
+}
diff --git a/gradle.properties b/gradle.properties
index 4d40c7e..05d1749 100644
--- a/gradle.properties
+++ b/gradle.properties
@@ -19,4 +19,6 @@
version=1.5.0
gradleVersion=4.8.1
org.gradle.jvmargs="-XX:MaxPermSize=512m"
+scalaVersion=2.11
+sparkVersion=2.3.0
release=false
diff --git a/settings.gradle b/settings.gradle
index 7556761..8c478a6 100644
--- a/settings.gradle
+++ b/settings.gradle
@@ -17,4 +17,4 @@
* under the License.
*/
-include "build-plugin","datafu-pig","datafu-hourglass"
\ No newline at end of file
+include "build-plugin","datafu-pig","datafu-hourglass","datafu-spark"