| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.mahout |
| |
| import java.io._ |
| |
| import org.apache.log4j.Logger |
| import org.apache.mahout.logging._ |
| import org.apache.mahout.math.drm._ |
| import org.apache.mahout.math.{Matrix, MatrixWritable, Vector, VectorWritable} |
| import org.apache.mahout.sparkbindings.drm.{CheckpointedDrmSpark, CheckpointedDrmSparkOps, SparkBCast} |
| import org.apache.mahout.util.IOUtilsScala |
| import org.apache.spark.broadcast.Broadcast |
| import org.apache.spark.rdd.RDD |
| import org.apache.spark.{SparkConf, SparkContext} |
| import org.apache.spark.mllib.regression.LabeledPoint |
| import org.apache.spark.mllib.linalg.{DenseVector => DenseSparkVector, SparseVector => SparseSparkVector, Vector => SparkVector} |
| import org.apache.spark.sql.DataFrame |
| |
| import collection._ |
| import collection.generic.Growable |
| import scala.reflect.{ClassTag,classTag} |
| |
| /** Public api for Spark-specific operators */ |
| package object sparkbindings { |
| |
| private final implicit val log: Logger = getLog(getClass) |
| |
| /** Row-wise organized DRM rdd type */ |
| type DrmRdd[K] = RDD[DrmTuple[K]] |
| |
| /** |
| * Blockified DRM rdd (keys of the original DRM are grouped into an array corresponding to the rows of the |
| * Matrix object value). |
| */ |
| type BlockifiedDrmRdd[K] = RDD[BlockifiedDrmTuple[K]] |
| |
| /** |
| * Create a proper Spark context that includes local Mahout jars. |
| * @param masterUrl URL of the Spark master. |
| * @param appName Application name to launch. |
| * @param customJars Custom jars to ship with the application. |
| * @param sparkConf A SparkConf instance to configure and create a SparkContext with. Default is new SparkConf(). |
| * @param addMahoutJars Flag controlling whether to add the Mahout jars. Defaults to true. |
| * @return a SparkDistributedContext with the above attributes. |
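| * |
| * A minimal usage sketch (assuming MAHOUT_HOME and SPARK_HOME are set; names are illustrative): |
| * {{{ |
| * implicit val ctx = mahoutSparkContext(masterUrl = "local[2]", appName = "my-mahout-app") |
| * // ... run distributed algebra against ctx ... |
| * ctx.close() |
| * }}} |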
| */ |
| def mahoutSparkContext(masterUrl: String, appName: String, customJars: TraversableOnce[String] = Nil, |
| sparkConf: SparkConf = new SparkConf(), addMahoutJars: Boolean = true): |
| SparkDistributedContext = { |
| |
| val closeables = mutable.ListBuffer.empty[Closeable] |
| |
| try { |
| |
| // When not including the artifact (e.g. for viennacl), we always need to load all Mahout jars; |
| // this will need to be handled properly at some point. |
| |
| if (addMahoutJars) { |
| |
| // context specific jars |
| val mcjars = findMahoutContextJars(closeables) |
| |
| if (log.isDebugEnabled) { |
| log.debug("Mahout jars:") |
| mcjars.foreach(j => log.debug(j)) |
| } |
| |
| sparkConf.setJars(jars = mcjars.toSeq ++ customJars) |
| // seems to kill drivers |
| // if (!(customJars.size > 0)) sparkConf.setJars(customJars.toSeq) |
| |
| } else { |
| // We now always add the jars, even in local mode, since the artifacts are not included. |
| sparkConf.setJars(customJars.toSeq) |
| } |
| |
| sparkConf.setAppName(appName).setMaster(masterUrl).set("spark.serializer", |
| "org.apache.spark.serializer.KryoSerializer").set("spark.kryo.registrator", |
| "org.apache.mahout.sparkbindings.io.MahoutKryoRegistrator") |
| |
| if (System.getenv("SPARK_HOME") != null) { |
| sparkConf.setSparkHome(System.getenv("SPARK_HOME")) |
| } |
| |
| new SparkDistributedContext(new SparkContext(config = sparkConf)) |
| |
| } finally { |
| IOUtilsScala.close(closeables) |
| } |
| } |
| |
| implicit def sdc2sc(sdc: SparkDistributedContext): SparkContext = sdc.sc |
| |
| implicit def sc2sdc(sc: SparkContext): SparkDistributedContext = new SparkDistributedContext(sc) |
| |
| implicit def dc2sc(dc: DistributedContext): SparkContext = { |
| assert(dc.isInstanceOf[SparkDistributedContext], "distributed context must be Spark-specific.") |
| sdc2sc(dc.asInstanceOf[SparkDistributedContext]) |
| } |
| |
| /** Broadcast transforms */ |
| implicit def sb2bc[T](b: Broadcast[T]): BCast[T] = new SparkBCast(b) |
| |
| /** Adding Spark-specific ops */ |
| implicit def cpDrm2cpDrmSparkOps[K](drm: CheckpointedDrm[K]): CheckpointedDrmSparkOps[K] = |
| new CheckpointedDrmSparkOps[K](drm) |
| |
| implicit def drm2cpDrmSparkOps[K](drm: DrmLike[K]): CheckpointedDrmSparkOps[K] = drm: CheckpointedDrm[K] |
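| |
| // With these conversions in scope, Spark-specific ops become available directly on a checkpointed DRM. |
| // A minimal sketch (assuming `drmA` is a Spark-backed CheckpointedDrm[Int]; names are illustrative): |
| //   val rows: DrmRdd[Int] = drmA.rdd   // the underlying row-wise RDD, via CheckpointedDrmSparkOps |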
| |
| private[sparkbindings] implicit def m2w(m: Matrix): MatrixWritable = new MatrixWritable(m) |
| |
| private[sparkbindings] implicit def w2m(w: MatrixWritable): Matrix = w.get() |
| |
| private[sparkbindings] implicit def v2w(v: Vector): VectorWritable = new VectorWritable(v) |
| |
| private[sparkbindings] implicit def w2v(w: VectorWritable): Vector = w.get() |
| |
| /** |
| * ==Wrap an existing RDD into a matrix== |
| * |
| * @param rdd source rdd conforming to [[org.apache.mahout.sparkbindings.DrmRdd]] |
| * @param nrow optional, number of rows. If not specified, we'll try to figure it out on our own. |
| * @param ncol optional, number of columns. If not specified, we'll try to figure it out on our own. |
| * @param cacheHint optional, desired cache policy for that rdd. |
| * @param canHaveMissingRows optional. For int-keyed rows, there might be implied but missing rows. |
| * If the underlying rdd may have that condition, we need to know, since some |
| * operators consider that a deficiency and we'll need to fix it lazily |
| * before proceeding with such operators. It is only meaningful if `nrow` is |
| * also specified (otherwise, we'll run a quick test to figure out whether |
| * rows may be missing, at the time we count the rows). |
| * @tparam K row key type |
| * @return wrapped DRM |
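| * |
| * A minimal usage sketch (assuming an implicit SparkDistributedContext `sdc` and |
| * org.apache.mahout.math.DenseVector in scope; names are illustrative): |
| * {{{ |
| * val rows = Seq[(Int, Vector)]( |
| *   0 -> new DenseVector(Array(1.0, 2.0)), |
| *   1 -> new DenseVector(Array(3.0, 4.0))) |
| * val drmA = drmWrap(rdd = sdc.parallelize(rows), nrow = 2, ncol = 2) |
| * }}} |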
| */ |
| def drmWrap[K: ClassTag](rdd: DrmRdd[K], nrow: Long = -1, ncol: Int = -1, cacheHint: CacheHint.CacheHint = |
| CacheHint.NONE, canHaveMissingRows: Boolean = false): CheckpointedDrm[K] = |
| |
| new CheckpointedDrmSpark[K](rddInput = rdd, _nrow = nrow, _ncol = ncol, cacheHint = cacheHint, |
| _canHaveMissingRows = canHaveMissingRows) |
| |
| /** A drmWrap version that takes an RDD[org.apache.spark.mllib.regression.LabeledPoint] and |
| * returns a DRM in which the label is the last column. */ |
| def drmWrapMLLibLabeledPoint(rdd: RDD[LabeledPoint], |
| nrow: Long = -1, |
| ncol: Int = -1, |
| cacheHint: CacheHint.CacheHint = CacheHint.NONE, |
| canHaveMissingRows: Boolean = false): CheckpointedDrm[Int] = { |
| val drmRDD: DrmRdd[Int] = rdd.zipWithIndex.map(lv => { |
| lv._1.features match { |
| case _: DenseSparkVector => (lv._2.toInt, new org.apache.mahout.math.DenseVector( lv._1.features.toArray ++ Array(lv._1.label) )) |
| case _: SparseSparkVector => (lv._2.toInt, |
| new org.apache.mahout.math.RandomAccessSparseVector(new org.apache.mahout.math.DenseVector( lv._1.features.toArray ++ Array(lv._1.label) )) ) |
| } |
| }) |
| |
| drmWrap(drmRDD, nrow, ncol, cacheHint, canHaveMissingRows) |
| } |
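| |
| // A minimal usage sketch (assuming org.apache.spark.mllib.regression.LabeledPoint and |
| // org.apache.spark.mllib.linalg.Vectors are imported and `sc` is a SparkContext; names are illustrative): |
| //   val points = sc.parallelize(Seq(LabeledPoint(1.0, Vectors.dense(2.0, 3.0)))) |
| //   val drmA = drmWrapMLLibLabeledPoint(points)   // the label becomes the last column |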
| |
| /** A drmWrap version that takes a DataFrame of Row[Double] */ |
| def drmWrapDataFrame(df: DataFrame, |
| nrow: Long = -1, |
| ncol: Int = -1, |
| cacheHint: CacheHint.CacheHint = CacheHint.NONE, |
| canHaveMissingRows: Boolean = false): CheckpointedDrm[Int] = { |
| val drmRDD: DrmRdd[Int] = df.rdd |
| .zipWithIndex |
| .map( o => (o._2.toInt, o._1.mkString(",").split(",").map(s => s.toDouble)) ) |
| .map(o => (o._1, new org.apache.mahout.math.DenseVector( o._2 ))) |
| |
| drmWrap(drmRDD, nrow, ncol, cacheHint, canHaveMissingRows) |
| } |
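| |
| // A minimal usage sketch (assuming a SparkSession `spark` and a DataFrame with numeric-only columns; |
| // names are illustrative): |
| //   val df = spark.createDataFrame(Seq((1.0, 2.0), (3.0, 4.0))).toDF("x", "y") |
| //   val drmA = drmWrapDataFrame(df) |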
| |
| /** A drmWrap version that takes an RDD[org.apache.spark.mllib.linalg.Vector] */ |
| def drmWrapMLLibVector(rdd: RDD[SparkVector], |
| nrow: Long = -1, |
| ncol: Int = -1, |
| cacheHint: CacheHint.CacheHint = CacheHint.NONE, |
| canHaveMissingRows: Boolean = false): CheckpointedDrm[Int] = { |
| val drmRDD: DrmRdd[Int] = rdd.zipWithIndex.map( v => { |
| v._1 match { |
| case _: DenseSparkVector => (v._2.toInt, new org.apache.mahout.math.DenseVector(v._1.toArray)) |
| case _: SparseSparkVector => (v._2.toInt, new org.apache.mahout.math.RandomAccessSparseVector(new org.apache.mahout.math.DenseVector(v._1.toArray)) ) |
| } |
| }) |
| drmWrap(drmRDD, nrow, ncol, cacheHint, canHaveMissingRows) |
| } |
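| |
| // A minimal usage sketch (assuming org.apache.spark.mllib.linalg.Vectors is imported and `sc` is a |
| // SparkContext; names are illustrative): |
| //   val vecs = sc.parallelize(Seq(Vectors.dense(1.0, 2.0), Vectors.sparse(2, Array(1), Array(3.0)))) |
| //   val drmA = drmWrapMLLibVector(vecs) |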
| |
| /** Another drmWrap version that takes in vertical block-partitioned input to form the matrix. */ |
| def drmWrapBlockified[K: ClassTag](blockifiedDrmRdd: BlockifiedDrmRdd[K], nrow: Long = -1, ncol: Int = -1, |
| cacheHint: CacheHint.CacheHint = CacheHint.NONE, |
| canHaveMissingRows: Boolean = false): CheckpointedDrm[K] = |
| |
| drmWrap(drm.deblockify(blockifiedDrmRdd), nrow, ncol, cacheHint, canHaveMissingRows) |
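| |
| // A minimal usage sketch (assuming drm.blockify(rdd, blockncol) is available in this package, as |
| // deblockify is above; names are illustrative): |
| //   val blockified: BlockifiedDrmRdd[Int] = drm.blockify(rowWiseRdd, blockncol = 2) |
| //   val drmA = drmWrapBlockified(blockified) |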
| |
| private[sparkbindings] def getMahoutHome() = { |
| var mhome = System.getenv("MAHOUT_HOME") |
| if (mhome == null) mhome = System.getProperty("mahout.home") |
| require(mhome != null, "MAHOUT_HOME is required to spawn mahout-based spark jobs") |
| mhome |
| } |
| |
| /** Acquire proper Mahout jars to be added to task context based on current MAHOUT_HOME. */ |
| private[sparkbindings] def findMahoutContextJars(closeables: Growable[Closeable]) = { |
| |
| // Figure out the Mahout classpath using the "$MAHOUT_HOME/bin/mahout -spark classpath" command. |
| val fmhome = new File(getMahoutHome()) |
| val bin = new File(fmhome, "bin") |
| val exec = new File(bin, "mahout") |
| if (!exec.canExecute) |
| throw new IllegalArgumentException("Cannot execute %s.".format(exec.getAbsolutePath)) |
| |
| // Find out where our Spark jars are. |
| val p = Runtime.getRuntime.exec(Array(exec.getAbsolutePath, "-spark", "classpath")) |
| |
| closeables += new Closeable { |
| def close() { |
| p.destroy() |
| } |
| } |
| |
| val r = new BufferedReader(new InputStreamReader(p.getInputStream)) |
| closeables += r |
| |
| val w = new StringWriter() |
| closeables += w |
| |
| var continue = true |
| val jars = new mutable.ArrayBuffer[String]() |
| do { |
| val cp = r.readLine() |
| if (cp == null) |
| throw new IllegalArgumentException("Unable to read output from \"mahout -spark classpath\". Is SPARK_HOME " + |
| "defined?") |
| |
| val j = cp.split(File.pathSeparatorChar) |
| if (j.length > 10) { |
| // assume this is a valid classpath line |
| jars ++= j |
| continue = false |
| } |
| } while (continue) |
| |
| // jars.foreach(j => log.info(j)) |
| // context specific jars |
| val mcjars = jars.filter(j => |
| j.matches(".*core-\\d.*\\.jar") || |
| j.matches(".*mahout-math-scala_\\d.*\\.jar") || |
| j.matches(".*mahout-hdfs-\\d.*\\.jar") || |
| // no need for mapreduce jar in Spark |
| // j.matches(".*mahout-mr-\\d.*\\.jar") || |
| j.matches(".*spark_\\d.*\\.jar") || |
| // vcl jars: mahout-native-viennacl_2.11.jar, |
| // mahout-native-viennacl-omp_2.11.jar |
| // j.matches(".*mahout-native-viennacl_\\d.*\\\\.jar") || |
| // j.matches(".*mahout-native-viennacl-omp_\\d.*\\.jar")|| |
| j.matches(".*mahout-native-viennacl*.jar")|| |
| // while WIP on MAHOUT-1894, use single wildcard |
| // TODO: remove after 1894 is closed out |
| j.matches(".*spark*-dependency-reduced.jar") |
| |
| ) |
| // Tune out "bad" classifiers |
| .filter(n => |
| !n.matches(".*-tests.jar") && |
| !n.matches(".*-sources.jar") && |
| !n.matches(".*-job.jar") && |
| // During maven tests, the maven classpath also creeps in for some reason |
| !n.matches(".*/.m2/.*") |
| ) |
| /* verify jar passed to context */ |
| info("\n\n\n") |
| mcjars.foreach(j => info(j)) |
| info("\n\n\n") |
| /**/ |
| mcjars |
| } |
| |
| private[sparkbindings] def validateBlockifiedDrmRdd[K](rdd: BlockifiedDrmRdd[K]): Boolean = { |
| // Here, each partition must contain exactly one block. |
| val part1Req = rdd.mapPartitions(piter => Iterator(piter.size == 1)).reduce(_ && _) |
| |
| if (!part1Req) warn("blockified rdd: condition not met: exactly 1 block per partition") |
| |
| part1Req |
| } |
| |
| } |