This Scala tutorial guides you through a classic computer vision application: identifying handwritten digits.
Let's train a 3-layer network (i.e., a multilayer perceptron) on the MNIST dataset to classify handwritten digits.
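For reference, the equations below show what a 3-layer perceptron of this shape computes for one image flattened to $\mathbf{x} \in \mathbb{R}^{784}$ (28 × 28 pixels). The widths 128, 64, and 10 are the `num_hidden` values given to `fc1`, `fc2`, and `fc3` in the model definition. The final softmax corresponds to the `SoftmaxOutput` layer, which turns the 10 scores into class probabilities.

$$
\begin{aligned}
\mathbf{h}_1 &= \operatorname{ReLU}(W_1\mathbf{x} + \mathbf{b}_1), & W_1 &\in \mathbb{R}^{128\times 784},\ \mathbf{b}_1 \in \mathbb{R}^{128}\\
\mathbf{h}_2 &= \operatorname{ReLU}(W_2\mathbf{h}_1 + \mathbf{b}_2), & W_2 &\in \mathbb{R}^{64\times 128},\ \mathbf{b}_2 \in \mathbb{R}^{64}\\
\mathbf{p} &= \operatorname{softmax}(W_3\mathbf{h}_2 + \mathbf{b}_3), & W_3 &\in \mathbb{R}^{10\times 64},\ \mathbf{b}_3 \in \mathbb{R}^{10}
\end{aligned}
$$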
First, define the neural network's architecture using the Symbol API:
```scala
import ml.dmlc.mxnet._
import ml.dmlc.mxnet.optimizer.SGD

// model definition
val data = Symbol.Variable("data")
val fc1 = Symbol.FullyConnected(name = "fc1")()(Map("data" -> data, "num_hidden" -> 128))
val act1 = Symbol.Activation(name = "relu1")()(Map("data" -> fc1, "act_type" -> "relu"))
val fc2 = Symbol.FullyConnected(name = "fc2")()(Map("data" -> act1, "num_hidden" -> 64))
val act2 = Symbol.Activation(name = "relu2")()(Map("data" -> fc2, "act_type" -> "relu"))
val fc3 = Symbol.FullyConnected(name = "fc3")()(Map("data" -> act2, "num_hidden" -> 10))
val mlp = Symbol.SoftmaxOutput(name = "sm")()(Map("data" -> fc3))
```
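To see what the symbol graph computes without MXNet, here is a plain-Scala sketch. It mirrors the layer widths of `mlp` (784 → 128 → 64 → 10), counts the trainable parameters, and runs one forward pass. The object and method names (`MlpSketch`, `dense`, `forward`) are hypothetical. The weights are random placeholders rather than trained values, so the output probabilities carry no meaning beyond their shape.

```scala
// Hypothetical plain-Scala sketch (no MXNet): same layer sizes as `mlp`.
object MlpSketch {
  // 28x28 input flattened to 784, then the widths of fc1, fc2 and fc3.
  val layerSizes: Seq[Int] = Seq(784, 128, 64, 10)

  // Each fully connected layer has an (out x in) weight matrix plus `out` biases.
  def paramCount(sizes: Seq[Int]): Int =
    sizes.sliding(2).map { case Seq(in, out) => in * out + out }.sum

  // y = W x + b, with W stored as `out` rows of length `in`.
  def dense(w: Array[Array[Double]], b: Array[Double], x: Array[Double]): Array[Double] =
    w.zip(b).map { case (row, bias) =>
      row.zip(x).map { case (wi, xi) => wi * xi }.sum + bias
    }

  def relu(v: Array[Double]): Array[Double] = v.map(z => math.max(0.0, z))

  // Numerically stable softmax: subtract the maximum before exponentiating.
  def softmax(v: Array[Double]): Array[Double] = {
    val m = v.max
    val e = v.map(z => math.exp(z - m))
    val s = e.sum
    e.map(_ / s)
  }

  // One forward pass with random placeholder weights (illustration only).
  def forward(x: Array[Double], seed: Long): Array[Double] = {
    val rng = new scala.util.Random(seed)
    val layers = layerSizes.sliding(2).toList.map { case Seq(in, out) =>
      (Array.fill(out, in)(rng.nextGaussian() * 0.01), Array.fill(out)(0.0))
    }
    // Hidden layers: FullyConnected followed by ReLU.
    val hidden = layers.init.foldLeft(x) { case (h, (w, b)) => relu(dense(w, b, h)) }
    // Output layer: FullyConnected followed by softmax.
    val (w3, b3) = layers.last
    softmax(dense(w3, b3, hidden))
  }
}
```

For these widths, `MlpSketch.paramCount(MlpSketch.layerSizes)` gives 100,480 + 8,256 + 650 = 109,386 parameters. Almost all of them sit in the first layer, because it connects every one of the 784 pixels to each of the 128 hidden units.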
Then, load the training and validation data using DataIterators.
You can download the MNIST data using the get_mnist_data script. We've already written a DataIterator for the MNIST dataset:
```scala
// load MNIST dataset
val trainDataIter = IO.MNISTIter(Map(
  "image" -> "data/train-images-idx3-ubyte",
  "label" -> "data/train-labels-idx1-ubyte",
  "data_shape" -> "(1, 28, 28)",
  "label_name" -> "sm_label",
  "batch_size" -> "50",
  "shuffle" -> "1",
  "flat" -> "0",
  "silent" -> "0",
  "seed" -> "10"))

val valDataIter = IO.MNISTIter(Map(
  "image" -> "data/t10k-images-idx3-ubyte",
  "label" -> "data/t10k-labels-idx1-ubyte",
  "data_shape" -> "(1, 28, 28)",
  "label_name" -> "sm_label",
  "batch_size" -> "50",
  "shuffle" -> "1",
  "flat" -> "0",
  "silent" -> "0"))
```
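The `batch_size`, `shuffle`, and `seed` options control how samples are grouped into mini-batches. The following is a rough plain-Scala sketch of that idea, and it is not how `MNISTIter` is implemented internally. The indices are optionally shuffled with a fixed seed so the order is reproducible, then cut into consecutive groups. `BatchSketch` and `batches` are hypothetical names. MNIST has 60,000 training and 10,000 test images, so a batch size of 50 gives 1,200 training batches and 200 validation batches per pass.

```scala
import scala.util.Random

// Hypothetical sketch of shuffled mini-batching (not MXNet's MNISTIter code).
object BatchSketch {
  def batches(numSamples: Int, batchSize: Int, shuffle: Boolean, seed: Long): Seq[Seq[Int]] = {
    val indices = (0 until numSamples).toVector
    // A fixed seed makes the shuffled order identical across runs.
    val order = if (shuffle) new Random(seed).shuffle(indices) else indices
    // grouped yields full batches, plus a shorter final batch if batchSize
    // does not divide numSamples evenly.
    order.grouped(batchSize).toList
  }
}
```

Fixing the seed, as the training iterator does with `"seed" -> "10"`, trades some randomness for reproducible runs.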
We can use the FeedForward builder to train our network:
```scala
// setup model and fit the training data
val model = FeedForward.newBuilder(mlp)
  .setContext(Context.cpu())
  .setNumEpoch(10)
  .setOptimizer(new SGD(learningRate = 0.1f, momentum = 0.9f, wd = 0.0001f))
  .setTrainData(trainDataIter)
  .setEvalData(valDataIter)
  .build()
```
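The optimizer settings `learningRate = 0.1f`, `momentum = 0.9f`, and `wd = 0.0001f` are easiest to read through a single update step. This is a simplified scalar sketch of the common momentum-SGD rule with L2 weight decay. It assumes the general form MXNet's `SGD` follows, but it leaves out details such as gradient rescaling and clipping, so treat it as illustrative rather than as the library's exact code. `SgdSketch` is a hypothetical name.

```scala
// Hypothetical scalar sketch of one momentum-SGD step with weight decay.
// Assumption: gradient rescaling and clipping are omitted.
object SgdSketch {
  final case class State(weight: Double, mom: Double)

  def step(s: State, grad: Double, lr: Double, momentum: Double, wd: Double): State = {
    // Weight decay adds wd * weight to the gradient (an L2 penalty);
    // momentum keeps a decaying sum of past updates.
    val mom = momentum * s.mom - lr * (grad + wd * s.weight)
    State(s.weight + mom, mom)
  }
}
```

As an example, take a weight of 1.0 with zero initial momentum and a gradient of 0.5. With the settings above, the step gives `mom = -0.1 * (0.5 + 0.0001) = -0.05001` and a new weight of `0.94999`. On later steps, 90% of that momentum carries over.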
Finally, let's make predictions against the validation dataset and compare the predicted labels with the real labels.
```scala
val probArrays = model.predict(valDataIter)
// in this case, we do not have multiple outputs
require(probArrays.length == 1)
val prob = probArrays(0)

// get real labels
import scala.collection.mutable.ListBuffer
valDataIter.reset()
val labels = ListBuffer.empty[NDArray]
while (valDataIter.hasNext) {
  val evalData = valDataIter.next()
  labels += evalData.label(0).copy()
}
val y = NDArray.concatenate(labels)

// get predicted labels
val predictedY = NDArray.argmaxChannel(prob)
require(y.shape == predictedY.shape)

// calculate accuracy
var numCorrect = 0
var numTotal = 0
for ((labelElem, predElem) <- y.toArray zip predictedY.toArray) {
  if (labelElem == predElem) {
    numCorrect += 1
  }
  numTotal += 1
}
val acc = numCorrect.toFloat / numTotal
println(s"Final accuracy = $acc")
```
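The evaluation logic itself does not depend on `NDArray`. This plain-Scala sketch does the same two steps on ordinary arrays. It first picks the index of the largest probability in each row, which is the role `NDArray.argmaxChannel` plays above, and then counts matches with `count` in place of the mutable counters. `AccuracySketch` and its methods are hypothetical names.

```scala
// Hypothetical plain-Scala sketch of the argmax-and-compare evaluation.
object AccuracySketch {
  // Index of the largest probability in each row (first index wins on ties).
  def argmaxRows(prob: Array[Array[Float]]): Array[Int] =
    prob.map(row => row.indices.maxBy(i => row(i)))

  // Fraction of positions where the predicted class equals the true label.
  def accuracy(labels: Array[Int], predicted: Array[Int]): Float = {
    require(labels.length == predicted.length, "label/prediction count mismatch")
    val numCorrect = labels.zip(predicted).count { case (l, p) => l == p }
    numCorrect.toFloat / labels.length
  }
}
```

For example, with the rows `(0.1, 0.7, 0.2)`, `(0.6, 0.3, 0.1)`, `(0.2, 0.2, 0.6)`, and `(0.5, 0.4, 0.1)`, the predicted classes are `1, 0, 2, 0`. Against the true labels `1, 0, 2, 2`, the accuracy is `0.75`.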
Check out more MXNet Scala examples below.