
Deep Learning on Spark

Here comes MXNet on Spark. It is built on the MXNet Scala Package and brings deep learning to Spark.

Now you have an end-to-end solution for large-scale deep models, which means you can take advantage of both the flexible parallel training approaches and GPU support with MXNet, and the fast data processing flow with Spark, to build the full pipeline from raw data to efficient deep learning.

MXNet on Spark is still in an experimental stage. Any suggestions or contributions are highly appreciated.


Check out the Installation Guide for instructions on installing MXNet. Remember to enable distributed training, i.e., set USE_DIST_KVSTORE = 1.

Compile the Scala package with

make scalapkg

This automatically builds the spark submodule. You can then submit Spark jobs with the built jars.

A sample submit script is provided in the bin directory of the spark module. Remember to set the variables and versions according to your own environment.


Here is an example Spark job that trains a deep network.

First, define the parameters for the training procedure:

val conf = new SparkConf().setAppName("MXNet")
val sc = new SparkContext(conf)

val mxnet = new MXNet()
  .setContext(Context.cpu()) // or GPU if you like
  .setNetwork(network) // e.g. MLP model
  .setNumServer(2)
  .setNumWorker(4)
  // These jars are required by the KVStores at runtime.
  // They will be uploaded and distributed to each node automatically
  .setExecutorJars(cmdLine.jars)

Now load the data and run distributed training:

val trainData = parseRawData(sc, cmdLine.input)
val model = mxnet.fit(trainData)
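parseRawData is not defined in this README, so its input format is up to you. Below is a minimal sketch assuming a hypothetical text format of one example per line, "<label> <f1>,<f2>,...". The parsing logic is kept free of Spark so it can be tested on its own; RawDataParser and parseLine are illustrative names, not part of MXNet.

```scala
object RawDataParser {
  // Parses one line of the hypothetical format "<label> <f1>,<f2>,...,<fn>",
  // e.g. "3 0.0,0.5,1.0", into the label and its dense feature values.
  def parseLine(line: String): (Double, Array[Double]) = {
    // Split on the first run of whitespace only: label left, features right.
    val parts = line.trim.split("\\s+", 2)
    require(parts.length == 2, s"expected '<label> <features>', got: '$line'")
    val label = parts(0).toDouble
    val features = parts(1).split(',').map(_.trim.toDouble)
    (label, features)
  }

  def main(args: Array[String]): Unit = {
    val (label, features) = parseLine("3 0.0,0.5,1.0")
    println(s"label=$label features=${features.mkString(",")}")
  }
}
```

Inside a Spark job, the same function could back parseRawData, for example sc.textFile(path).map { s => val (y, x) = RawDataParser.parseLine(s); LabeledPoint(y, Vectors.dense(x)) }, using MLlib's LabeledPoint and Vectors. That an RDD[LabeledPoint] is what mxnet.fit expects is an assumption here; check the signature in your version of the package.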

In this example, the parameter-server (PS) scheduler starts on the driver and 2 servers are launched. The input data is split into 4 pieces, which are trained in dist_async mode.
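To see why the asynchronous mode can matter, here is a toy throughput model, not MXNet code. In dist_sync, gradients from every worker are aggregated before the weights are updated, so each round is paced by the slowest worker; in dist_async, each worker's push is applied as soon as it arrives. The sketch only counts how many gradients reach the servers within a time budget, given hypothetical per-step times for the 4 workers; it ignores communication cost and the fact that async gradients may be computed on stale weights.

```scala
object SyncVsAsync {
  // stepTimes(i): hypothetical time worker i needs to compute one gradient.
  // budget: total time available, in the same units.

  // dist_sync: every round ends with a barrier, so it lasts as long as the
  // slowest worker; each completed round contributes one gradient per worker.
  def syncGradients(stepTimes: Seq[Int], budget: Int): Int = {
    require(stepTimes.nonEmpty && stepTimes.forall(_ > 0))
    val rounds = budget / stepTimes.max
    rounds * stepTimes.size
  }

  // dist_async: no barrier; each worker pushes whenever its own step finishes.
  def asyncGradients(stepTimes: Seq[Int], budget: Int): Int = {
    require(stepTimes.nonEmpty && stepTimes.forall(_ > 0))
    stepTimes.map(budget / _).sum
  }

  def main(args: Array[String]): Unit = {
    val workers = Seq(1, 1, 1, 3) // three fast workers and one slow one
    println(s"dist_sync:  ${syncGradients(workers, 6)} gradients")  // 8
    println(s"dist_async: ${asyncGradients(workers, 6)} gradients") // 20
  }
}
```

With one worker three times slower than the others, the synchronous schedule applies 8 gradients in 6 time units versus 20 for the asynchronous one. The price of async is staleness: a gradient may be applied to weights that have changed since the worker pulled them, which this model does not capture.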

To save the output model, simply call the save method:

model.save(sc, cmdLine.output + "/model")

Prediction is straightforward:

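val valData = parseRawData(sc, cmdLine.inputVal)
val brModel = sc.broadcast(model)
val res = valData.mapPartitions { data =>
  // Keep only the feature vectors of this partition for prediction
  val points = data.map(_.features)
  val probArrays = brModel.value.predict(points)
  require(probArrays.length == 1)
  val prob = probArrays(0)
  // Index of the most probable class for each example
  val py = NDArray.argmax_channel(prob.get)
  val labels = py.toArray.mkString(",")
  // Release the native memory held by the NDArrays
  py.dispose()
  prob.get.dispose()
  // mapPartitions must return an iterator
  Iterator(labels)
}
res.saveAsTextFile(cmdLine.output + "/data")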
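NDArray.argmax_channel turns the (batch × classes) probability output into one predicted class index per example. To check the logic without a native MXNet build, here is a minimal pure-Scala sketch of the same reduction on plain arrays. It mirrors our reading of argmax_channel on a 2-D input (argmax along the last axis, first index on ties, indices returned as floats); ArgmaxChannelSketch and argmaxChannel are illustrative names, not MXNet APIs.

```scala
object ArgmaxChannelSketch {
  // probs(i)(c): probability of class c for example i (shape: batch x classes).
  // Returns the index of the largest entry in each row, as a Float,
  // matching the Float values that py.toArray yields in the job above.
  def argmaxChannel(probs: Array[Array[Float]]): Array[Float] =
    probs.map { row =>
      require(row.nonEmpty, "each row needs at least one class score")
      var best = 0
      // Strict ">" keeps the first maximal index when entries tie.
      for (i <- 1 until row.length) if (row(i) > row(best)) best = i
      best.toFloat
    }

  def main(args: Array[String]): Unit = {
    val probs = Array(
      Array(0.1f, 0.7f, 0.2f), // most likely class 1
      Array(0.6f, 0.3f, 0.1f)  // most likely class 0
    )
    // Same formatting as `labels` in the prediction job.
    println(argmaxChannel(probs).mkString(",")) // 1.0,0.0
  }
}
```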


  • Sometimes you have to specify the java argument to help MXNet find the right Java binary on worker nodes.
  • MXNet and ps-lite currently do NOT support multiple instances in one process (we plan to fix this in the future, though with lower priority), so you must run Spark jobs in cluster mode (standalone, yarn-client, or yarn-cluster). Local mode is NOT supported: it runs tasks as multiple threads within one process, which blocks the initialization of the KVStore. (Hint: if you only have one physical node and want to test the Spark package, you can start N workers on that node by setting export SPARK_WORKER_INSTANCES=N in conf/spark-env.sh.) Also, remember to set --executor-cores 1 so that only one task runs in each Spark executor.
  • Fault tolerance is not fully supported. If some of your tasks fail, please restart the whole application. We will solve it soon.
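Because local mode and multi-core executors are the easiest of these pitfalls to hit by accident, a job can check its own configuration before training starts. A minimal sketch, assuming you want to fail fast: ClusterModeCheck is a hypothetical helper that takes settings as a plain Map so the check itself needs no Spark dependency. It relies only on the standard Spark keys spark.master (set by --master) and spark.executor.cores (set by --executor-cores).

```scala
object ClusterModeCheck {
  // Returns human-readable problems with the given Spark settings; empty means OK.
  def problems(settings: Map[String, String]): Seq[String] = {
    val master = settings.getOrElse("spark.master", "")
    // Local mode ("local", "local[4]", "local[*]", ...) runs all tasks as
    // threads of one process, which MXNet on Spark does not support.
    val localMode = master == "local" || master.startsWith("local[")
    val masterProblem =
      if (localMode) Some(s"spark.master=$master is local mode; use a cluster manager instead")
      else None
    // Each executor should run exactly one task at a time.
    val coresProblem = settings.get("spark.executor.cores") match {
      case Some("1") => None
      case Some(c)   => Some(s"spark.executor.cores=$c; set --executor-cores 1")
      case None      => Some("spark.executor.cores is not set; set --executor-cores 1")
    }
    Seq(masterProblem, coresProblem).flatten
  }

  def main(args: Array[String]): Unit = {
    val settings = Map("spark.master" -> "local[4]", "spark.executor.cores" -> "1")
    problems(settings).foreach(println)
  }
}
```

In the driver, this could run right after creating the SparkContext, e.g. val issues = ClusterModeCheck.problems(sc.getConf.getAll.toMap); require(issues.isEmpty, issues.mkString("; ")). Treating an unset spark.executor.cores as a problem is a deliberately strict choice, since its default depends on the cluster manager.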