[TOREE-408] Add support for hdfs and s3 to AddJar
Closes #125
diff --git a/kernel/src/main/scala/org/apache/toree/magic/builtin/AddJar.scala b/kernel/src/main/scala/org/apache/toree/magic/builtin/AddJar.scala
index 48a8124..ef5e927 100644
--- a/kernel/src/main/scala/org/apache/toree/magic/builtin/AddJar.scala
+++ b/kernel/src/main/scala/org/apache/toree/magic/builtin/AddJar.scala
@@ -18,17 +18,18 @@
package org.apache.toree.magic.builtin
import java.io.{File, PrintStream}
-import java.net.URL
+import java.net.{URL, URI}
import java.nio.file.{Files, Paths}
-
import org.apache.toree.magic._
import org.apache.toree.magic.builtin.AddJar._
import org.apache.toree.magic.dependencies._
import org.apache.toree.utils.{ArgumentParsingSupport, DownloadSupport, LogLike, FileUtils}
import com.typesafe.config.Config
+import org.apache.hadoop.fs.Path
import org.apache.toree.plugins.annotations.Event
object AddJar {
+ val HADOOP_FS_SCHEMES = Set("hdfs", "s3", "s3n", "file")
private var jarDir:Option[String] = None
@@ -63,18 +64,18 @@
private def printStream = new PrintStream(outputStream)
/**
- * Retrieves file name from URL.
+ * Retrieves file name from a URI.
*
- * @param location The remote location (URL)
- * @return The name of the remote URL, or an empty string if one does not exist
+ * @param location a URI
+ * @return The file name of the remote URI, or an empty string if one does not exist
*/
def getFileFromLocation(location: String): String = {
- val url = new URL(location)
- val file = url.getFile.split("/")
- if (file.length > 0) {
- file.last
+ val uri = new URI(location)
+ val pathParts = uri.getPath.split("/")
+ if (pathParts.nonEmpty) {
+ pathParts.last
} else {
- ""
+ ""
}
}
@@ -122,10 +123,27 @@
// Report beginning of download
printStream.println(s"Starting download from $jarRemoteLocation")
- downloadFile(
- new URL(jarRemoteLocation),
- new File(downloadLocation).toURI.toURL
- )
+ val jar = URI.create(jarRemoteLocation)
+ if (HADOOP_FS_SCHEMES.contains(jar.getScheme)) {
+ val conf = kernel.sparkContext.hadoopConfiguration
+ val jarPath = new Path(jarRemoteLocation)
+ val fs = jarPath.getFileSystem(conf)
+ val destPath = if (downloadLocation.startsWith("file:")) {
+ new Path(downloadLocation)
+ } else {
+ new Path("file:" + downloadLocation)
+ }
+
+ fs.copyToLocalFile(
+ false /* keep original file */,
+ jarPath, destPath,
+ true /* don't create checksum files */)
+ } else {
+ downloadFile(
+ new URL(jarRemoteLocation),
+ new File(downloadLocation).toURI.toURL
+ )
+ }
// Report download finished
printStream.println(s"Finished download of $jarName")
diff --git a/kernel/src/test/scala/org/apache/toree/magic/builtin/AddJarSpec.scala b/kernel/src/test/scala/org/apache/toree/magic/builtin/AddJarSpec.scala
index 1c7b3fc..8d1f44b 100644
--- a/kernel/src/test/scala/org/apache/toree/magic/builtin/AddJarSpec.scala
+++ b/kernel/src/test/scala/org/apache/toree/magic/builtin/AddJarSpec.scala
@@ -91,7 +91,8 @@
url = """http://www.example.com/remotecontent?filepath=/path/to/someJar.jar"""
jarName = addJarMagic.getFileFromLocation(url)
- assert(jarName == "someJar.jar")
+ // File names come from the path, not from the query fragment
+ assert(jarName == "remotecontent")
url = """http://www.example.com/"""
jarName = addJarMagic.getFileFromLocation(url)