| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License |
| */ |
| |
| package org.apache.toree.magic.builtin |
| |
| import java.io.{File, PrintStream} |
| import java.net.URL |
| import java.nio.file.{Files, Paths} |
| |
| import org.apache.toree.magic._ |
| import org.apache.toree.magic.builtin.AddJar._ |
| import org.apache.toree.magic.dependencies._ |
| import org.apache.toree.utils.{ArgumentParsingSupport, DownloadSupport, LogLike, FileUtils} |
| import com.typesafe.config.Config |
| import org.apache.toree.plugins.annotations.Event |
| |
object AddJar {

  // Directory that downloaded jars are stored in. Resolved on the first call to
  // getJarDir and returned unchanged by every later call.
  private var jarDir: Option[String] = None

  /**
   * Returns the directory into which `%AddJar` downloads jars.
   *
   * On the first call the directory is resolved from `config`:
   * - if the `jar_dir` key is present and that path exists, its value is used as-is;
   * - otherwise a directory is created via
   *   `FileUtils.createManagedTempDirectory("toree_add_jars")` and its absolute path is used.
   *
   * The result is memoized in this object, so subsequent calls return it
   * without consulting `config` again.
   *
   * @param config The kernel configuration (only read on the first call)
   * @return The configured `jar_dir` value, or the absolute path of the temp directory
   */
  def getJarDir(config: Config): String = jarDir match {
    case Some(cached) => cached
    case None =>
      val configuredKey = "jar_dir"
      val useConfigured =
        config.hasPath(configuredKey) && Files.exists(Paths.get(config.getString(configuredKey)))
      val resolved =
        if (useConfigured) config.getString(configuredKey)
        else FileUtils.createManagedTempDirectory("toree_add_jars").getAbsolutePath
      jarDir = Some(resolved)
      resolved
  }
}
| |
class AddJar
  extends LineMagic with IncludeInterpreter
  with IncludeOutputStream with DownloadSupport with ArgumentParsingSupport
  with IncludeKernel with IncludePluginManager with IncludeConfig with LogLike
{
  // Option to force re-downloading a jar even when a cached copy exists
  private val _force =
    parser.accepts("f", "forces re-download of specified jar")

  // Option to load the jar through the plugin manager (as a magic extension)
  // instead of adding it to the kernel's classpath
  private val _magic =
    parser.accepts("magic", "loads jar as a magic extension")

  // A def rather than a val because the outputStream is not provided at
  // construction; each call wraps whatever outputStream is currently set.
  private def printStream = new PrintStream(outputStream)

  /**
   * Retrieves file name from URL.
   *
   * Returns the last `/`-separated segment of `URL.getFile`. Note that
   * `URL.getFile` includes any query string, so `http://host/a.jar?x=1`
   * yields `a.jar?x=1`.
   *
   * @param location The remote location (URL)
   * @return The name of the remote URL, or an empty string if one does not exist
   * @throws java.net.MalformedURLException if `location` is not a valid URL
   */
  def getFileFromLocation(location: String): String =
    new URL(location).getFile.split("/").lastOption.getOrElse("")

  /**
   * Downloads and adds the specified jar to the
   * interpreter/compiler/cluster classpaths.
   *
   * Expects exactly one non-empty non-option argument (the jar location);
   * otherwise usage help is printed. With `-f` the jar is re-downloaded even
   * if cached; with `--magic` it is loaded as a plugin instead of added to
   * the kernel.
   *
   * @param code The line containing the location of the jar
   * @throws IllegalArgumentException if the location does not end in .jar or .zip
   */
  @Event(name = "addjar")
  override def execute(code: String): Unit = {
    val nonOptionArgs = parseArgs(code.trim)

    // Exactly one non-empty argument (the jar location) is required.
    // `head` is safe in the else branch because length == 1 there.
    if (nonOptionArgs.length != 1 || nonOptionArgs.head.isEmpty) {
      printHelp(printStream, """%AddJar <jar_url>""")
    } else {
      fetchAndRegisterJar(nonOptionArgs.head)
    }
  }

  /**
   * Validates the jar location, downloads it (or reuses the cached copy) and
   * registers it either as plugins or as a kernel jar.
   *
   * @param jarRemoteLocation The URL of the jar to add
   */
  private def fetchAndRegisterJar(jarRemoteLocation: String): Unit = {
    val jarName = getFileFromLocation(jarRemoteLocation)

    // Ensure the URL actually contains a jar or zip file
    if (!jarName.endsWith(".jar") && !jarName.endsWith(".zip")) {
      throw new IllegalArgumentException(
        s"The jar file $jarName must end in .jar or .zip."
      )
    }

    val fileDownloadLocation = new File(getJarDir(config), jarName)

    logger.debug(s"Downloading jar to ${fileDownloadLocation.getPath}")

    // Check if exists in cache or force applied
    if (_force || !fileDownloadLocation.exists()) {
      printStream.println(s"Starting download from $jarRemoteLocation")
      downloadJarAtomically(jarRemoteLocation, fileDownloadLocation)
      printStream.println(s"Finished download of $jarName")
    } else {
      printStream.println(s"Using cached version of $jarName")
    }

    if (_magic) {
      val plugins = pluginManager.loadPlugins(fileDownloadLocation)
      pluginManager.initializePlugins(plugins)
    } else {
      kernel.addJars(fileDownloadLocation.toURI)
    }
  }

  /**
   * Downloads `remoteLocation` into `destination` without ever leaving a
   * partially written file at `destination`.
   *
   * The cache check above only tests whether the destination exists, so a
   * download that failed midway must not leave a truncated jar there; it would
   * otherwise be reused as "cached" on the next run. The data is written to a
   * temporary file in the same directory and moved into place only after
   * `downloadFile` returns. On failure the temporary file is removed and any
   * previously cached copy is left untouched.
   *
   * @param remoteLocation The URL to download
   * @param destination    The final location of the jar
   */
  private def downloadJarAtomically(remoteLocation: String, destination: File): Unit = {
    import java.nio.file.StandardCopyOption

    // Prefix is the jar name (always >= 4 chars since it ends in .jar/.zip,
    // satisfying createTempFile's 3-char minimum); same directory keeps the
    // final move a simple rename.
    val partial = File.createTempFile(destination.getName, ".part", destination.getParentFile)
    try {
      downloadFile(new URL(remoteLocation), partial.toURI.toURL)
      Files.move(partial.toPath, destination.toPath, StandardCopyOption.REPLACE_EXISTING)
    } finally {
      // No-op after a successful move; cleans up the partial file otherwise
      Files.deleteIfExists(partial.toPath)
    }
  }
}