[MINOR] Mark old nn examples as deprecated and add new MNIST examples

Updated the examples and replicated the setup from McMahan et al. (2017).
Added Nesterov momentum SGD to the example and increased the number of layers.
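
For context on the optimizer change: nn/optim/sgd_nesterov.dml implements
the velocity-correction form of Nesterov momentum. A minimal sketch of one
update step for a single parameter matrix W (assumed formulation, shown here
for review convenience; the library file is authoritative):

    v_prev = v                      # velocity from the previous step
    v = mu*v - lr*dW                # new velocity from the current gradient dW
    W = W - mu*v_prev + (1+mu)*v    # position update with Nesterov correction

The new 2NN example anneals mu from 0 towards 0.999 across epochs and decays
lr by a factor of 0.99 per epoch.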

Closes #1014
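
---
Both new example scripts one-hot encode the label column with table(). A
minimal sketch of that trick on toy data (hypothetical values, not part of
this patch):

    labels = matrix("2 0 1", rows=3, cols=1)  # class ids in 0..9
    n = nrow(labels)
    Y = table(seq(1, n), labels + 1, n, 10)   # (n x 10) one-hot matrix
    # table() builds a contingency table: row i gets a 1 in column
    # labels[i]+1; the trailing n, 10 arguments fix the output shape even
    # when some classes are absent from the data.
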
diff --git a/scripts/nn/examples/Example - MNIST LeNet.ipynb b/scripts/nn/examples/Example - MNIST LeNet deprecated.ipynb
similarity index 98%
rename from scripts/nn/examples/Example - MNIST LeNet.ipynb
rename to scripts/nn/examples/Example - MNIST LeNet deprecated.ipynb
index fe9a09b..ea4fdc5 100644
--- a/scripts/nn/examples/Example - MNIST LeNet.ipynb
+++ b/scripts/nn/examples/Example - MNIST LeNet deprecated.ipynb
@@ -4,7 +4,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Quick Setup"
+    "## Quick Setup - Warning: Deprecated"
    ]
   },
   {
@@ -186,4 +186,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 1
-}
+}
\ No newline at end of file
diff --git a/scripts/nn/examples/Example - MNIST Softmax Classifier.ipynb b/scripts/nn/examples/Example - MNIST Softmax Classifier deprecated.ipynb
similarity index 98%
rename from scripts/nn/examples/Example - MNIST Softmax Classifier.ipynb
rename to scripts/nn/examples/Example - MNIST Softmax Classifier deprecated.ipynb
index dcef4ea..f0939f7 100644
--- a/scripts/nn/examples/Example - MNIST Softmax Classifier.ipynb
+++ b/scripts/nn/examples/Example - MNIST Softmax Classifier deprecated.ipynb
@@ -4,7 +4,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Quick Setup"
+    "## Quick Setup - Warning: Deprecated"
    ]
   },
   {
@@ -176,4 +176,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 1
-}
+}
\ No newline at end of file
diff --git a/scripts/nn/examples/Example-MNIST_2NN_ReLu_Softmax.dml b/scripts/nn/examples/Example-MNIST_2NN_ReLu_Softmax.dml
new file mode 100644
index 0000000..7b225dc
--- /dev/null
+++ b/scripts/nn/examples/Example-MNIST_2NN_ReLu_Softmax.dml
@@ -0,0 +1,70 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * The MNIST Data can be downloaded as follows:
+ * mkdir -p mnist_data/
+ * cd mnist_data/
+ * curl -O https://pjreddie.com/media/files/mnist_train.csv
+ * curl -O https://pjreddie.com/media/files/mnist_test.csv
+ */
+
+source("nn/examples/mnist_2NN.dml") as mnist_2NN
+
+# Read training data
+data = read("mnist_data/mnist_train.csv", format="csv")
+n = nrow(data)
+
+# Extract images and labels
+images = data[,2:ncol(data)]
+labels = data[,1]
+
+# Scale images to [0,1], and one-hot encode the labels
+images = images / 255.0
+labels = table(seq(1, n), labels+1, n, 10)
+
+# Split into training (55,000 examples) and validation (5,000 examples)
+X = images[5001:nrow(images),]
+X_val = images[1:5000,]
+y = labels[5001:nrow(images),]
+y_val = labels[1:5000,]
+
+# Train
+epochs = 5
+[W_1, b_1, W_2, b_2, W_3, b_3] = mnist_2NN::train(X, y, X_val, y_val, epochs)
+
+# Read test data
+data = read("mnist_data/mnist_test.csv", format="csv")
+n = nrow(data)
+
+# Extract images and labels
+X_test = data[,2:ncol(data)]
+y_test = data[,1]
+
+# Scale images to [0,1], and one-hot encode the labels
+X_test = X_test / 255.0
+y_test = table(seq(1, n), y_test+1, n, 10)
+
+# Eval on test set
+probs = mnist_2NN::predict(X_test, W_1, b_1, W_2, b_2, W_3, b_3)
+[loss, accuracy] = mnist_2NN::eval(probs, y_test)
+
+print("Test Accuracy: " + accuracy)
\ No newline at end of file
diff --git a/scripts/nn/examples/Example-MNIST_Softmax.dml b/scripts/nn/examples/Example-MNIST_Softmax.dml
new file mode 100644
index 0000000..ede1905
--- /dev/null
+++ b/scripts/nn/examples/Example-MNIST_Softmax.dml
@@ -0,0 +1,70 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * The MNIST Data can be downloaded as follows:
+ * mkdir -p mnist_data/
+ * cd mnist_data/
+ * curl -O https://pjreddie.com/media/files/mnist_train.csv
+ * curl -O https://pjreddie.com/media/files/mnist_test.csv
+ */
+
+source("nn/examples/mnist_softmax.dml") as mnist_softmax
+
+# Read training data
+data = read("mnist_data/mnist_train.csv", format="csv")
+n = nrow(data)
+
+# Extract images and labels
+images = data[,2:ncol(data)]
+labels = data[,1]
+
+# Scale images to [0,1], and one-hot encode the labels
+images = images / 255.0
+labels = table(seq(1, n), labels+1, n, 10)
+
+# Split into training (55,000 examples) and validation (5,000 examples)
+X = images[5001:nrow(images),]
+X_val = images[1:5000,]
+y = labels[5001:nrow(images),]
+y_val = labels[1:5000,]
+
+# Train
+epochs = 1
+[W, b] = mnist_softmax::train(X, y, X_val, y_val, epochs)
+
+# Read test data
+data = read("mnist_data/mnist_test.csv", format="csv")
+n = nrow(data)
+
+# Extract images and labels
+X_test = data[,2:ncol(data)]
+y_test = data[,1]
+
+# Scale images to [0,1], and one-hot encode the labels
+X_test = X_test / 255.0
+y_test = table(seq(1, n), y_test+1, n, 10)
+
+# Eval on test set
+probs = mnist_softmax::predict(X_test, W, b)
+[loss, accuracy] = mnist_softmax::eval(probs, y_test)
+
+print("Test Accuracy: " + accuracy)
\ No newline at end of file
diff --git a/scripts/nn/examples/mnist_2NN.dml b/scripts/nn/examples/mnist_2NN.dml
new file mode 100644
index 0000000..3287867
--- /dev/null
+++ b/scripts/nn/examples/mnist_2NN.dml
@@ -0,0 +1,178 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * MNIST 2NN ReLU Example
+ */
+
+# Imports
+source("nn/layers/affine.dml") as affine
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("nn/layers/relu.dml") as relu
+source("nn/layers/softmax.dml") as softmax
+source("nn/optim/sgd_nesterov.dml") as sgd_nesterov
+
+train = function(matrix[double] X, matrix[double] Y,
+                 matrix[double] X_val, matrix[double] Y_val,
+                 int epochs)
+    return (matrix[double] W_1, matrix[double] b_1,
+            matrix[double] W_2, matrix[double] b_2,
+            matrix[double] W_3, matrix[double] b_3) {
+  /*
+   * Trains a 2-hidden-layer (2NN) ReLU softmax classifier.
+   *
+   * The input matrix, X, has N examples, each with D features.
+   * The targets, Y, have K classes, and are one-hot encoded.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, D).
+   *  - Y: Target matrix, of shape (N, K).
+   *  - X_val: Input validation data matrix, of shape (N, D).
+   *  - Y_val: Target validation matrix, of shape (N, K).
+   *  - epochs: Total number of full training loops over the full data set.
+   *
+   * Outputs:
+   *  - W_1, W_2, W_3: Weight matrices, of shapes (D, 200), (200, 200), (200, K).
+   *  - b_1, b_2, b_3: Bias vectors, of shapes (1, 200), (1, 200), (1, K).
+   */
+  N = nrow(X)  # num examples
+  D = ncol(X)  # num features
+  K = ncol(Y)  # num classes
+
+  # Create the network:
+  # input -> 200-neuron affine -> ReLU -> 200-neuron affine -> ReLU -> K-neuron affine -> softmax
+  [W_1, b_1] = affine::init(D, 200)
+  [W_2, b_2] = affine::init(200, 200)
+  [W_3, b_3] = affine::init(200, K)
+
+  # Initialize SGD
+  lr = 0.2  # learning rate
+  mu = 0  # momentum
+  decay = 0.99  # learning rate decay constant
+  vW_1 = sgd_nesterov::init(W_1)  # optimizer momentum state for W_1
+  vb_1 = sgd_nesterov::init(b_1)  # optimizer momentum state for b_1
+  vW_2 = sgd_nesterov::init(W_2)  # optimizer momentum state for W_2
+  vb_2 = sgd_nesterov::init(b_2)  # optimizer momentum state for b_2
+  vW_3 = sgd_nesterov::init(W_3)  # optimizer momentum state for W_3
+  vb_3 = sgd_nesterov::init(b_3)  # optimizer momentum state for b_3
+
+  # Optimize
+  print("Starting optimization")
+  batch_size = 50
+  iters = 1000
+  for (e in 1:epochs) {
+    for(i in 1:iters) {
+      # Get next batch
+      beg = ((i-1) * batch_size) %% N + 1
+      end = min(N, beg + batch_size - 1)
+      X_batch = X[beg:end,]
+      y_batch = Y[beg:end,]
+
+      # Compute forward pass
+      ## input -> 200-neuron affine -> ReLU -> 200-neuron affine -> ReLU -> K-neuron affine -> softmax
+      out_1 = affine::forward(X_batch, W_1, b_1)
+      out_1_relu = relu::forward(out_1)
+      out_2 = affine::forward(out_1_relu, W_2, b_2)
+      out_2_relu = relu::forward(out_2)
+      out_3 = affine::forward(out_2_relu, W_3, b_3)
+      probs = softmax::forward(out_3)
+
+      # Compute loss & accuracy for training & validation data
+      loss = cross_entropy_loss::forward(probs, y_batch)
+      accuracy = mean(rowIndexMax(probs) == rowIndexMax(y_batch))
+      probs_val = predict(X_val, W_1, b_1, W_2, b_2, W_3, b_3)
+      loss_val = cross_entropy_loss::forward(probs_val, Y_val)
+      accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
+      print("Epoch: " + e + ", Iter: " + i + ", Train Loss: " + loss + ", Train Accuracy: " +
+            accuracy + ", Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
+
+      # Compute backward pass
+      ## loss:
+      dprobs = cross_entropy_loss::backward(probs, y_batch)
+      dout_3 = softmax::backward(dprobs, out_3)
+      [dout_2_relu, dW_3, db_3] = affine::backward(dout_3, out_2_relu, W_3, b_3)
+      dout_2 = relu::backward(dout_2_relu, out_2)
+      [dout_1_relu, dW_2, db_2] = affine::backward(dout_2, out_1_relu, W_2, b_2)
+      dout_1 = relu::backward(dout_1_relu, out_1)
+      [dX_batch, dW_1, db_1] = affine::backward(dout_1, X_batch, W_1, b_1)
+
+      # Optimize with SGD
+      [W_3, vW_3] = sgd_nesterov::update(W_3, dW_3, lr, mu, vW_3)
+      [b_3, vb_3] = sgd_nesterov::update(b_3, db_3, lr, mu, vb_3)
+      [W_2, vW_2] = sgd_nesterov::update(W_2, dW_2, lr, mu, vW_2)
+      [b_2, vb_2] = sgd_nesterov::update(b_2, db_2, lr, mu, vb_2)
+      [W_1, vW_1] = sgd_nesterov::update(W_1, dW_1, lr, mu, vW_1)
+      [b_1, vb_1] = sgd_nesterov::update(b_1, db_1, lr, mu, vb_1)
+    }
+    # Anneal momentum towards 0.999
+    mu = mu + (0.999 - mu)/(1+epochs-e)
+    # Decay learning rate
+    lr = lr * decay
+  }
+}
+
+predict = function(matrix[double] X,
+                    matrix[double] W_1, matrix[double] b_1,
+                    matrix[double] W_2, matrix[double] b_2,
+                    matrix[double] W_3, matrix[double] b_3)
+    return (matrix[double] probs) {
+  /*
+   * Computes the class probability predictions of the 2NN ReLU softmax classifier.
+   *
+   * The input matrix, X, has N examples, each with D features.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, D).
+   *  - W_1, W_2, W_3: Weight matrices, of shapes (D, 200), (200, 200), (200, K).
+   *  - b_1, b_2, b_3: Bias vectors, of shapes (1, 200), (1, 200), (1, K).
+   *
+   * Outputs:
+   *  - probs: Class probabilities, of shape (N, K).
+   */
+  # Compute forward pass
+  ## input -> 200-neuron affine -> ReLU -> 200-neuron affine -> ReLU -> K-neuron affine -> softmax
+  out_1_relu = relu::forward(affine::forward(X, W_1, b_1))
+  out_2_relu = relu::forward(affine::forward(out_1_relu, W_2, b_2))
+  probs = softmax::forward(affine::forward(out_2_relu, W_3, b_3))
+}
+
+eval = function(matrix[double] probs, matrix[double] Y)
+    return (double loss, double accuracy) {
+  /*
+   * Evaluates a classifier.
+   *
+   * The probs matrix contains the class probability predictions
+   * of K classes over N examples.  The targets, Y, have K classes,
+   * and are one-hot encoded.
+   *
+   * Inputs:
+   *  - probs: Class probabilities, of shape (N, K).
+   *  - Y: Target matrix, of shape (N, K).
+   *
+   * Outputs:
+   *  - loss: Scalar loss, of shape (1).
+   *  - accuracy: Scalar accuracy, of shape (1).
+   */
+  # Compute loss & accuracy
+  loss = cross_entropy_loss::forward(probs, Y)
+  correct_pred = rowIndexMax(probs) == rowIndexMax(Y)
+  accuracy = mean(correct_pred)
+}
diff --git a/src/test/java/org/apache/sysds/test/applications/NNTest.java b/src/test/java/org/apache/sysds/test/applications/NNTest.java
index d03b768..db39a01 100644
--- a/src/test/java/org/apache/sysds/test/applications/NNTest.java
+++ b/src/test/java/org/apache/sysds/test/applications/NNTest.java
@@ -38,6 +38,7 @@
 	public void testNNLibrary() {
 		Script script = dmlFromFile(TEST_SCRIPT);
 		String stdOut = executeAndCaptureStdOut(ml, script).getRight();
+		System.out.println(stdOut);
 		assertTrue(stdOut, !stdOut.contains(ERROR_STRING));
 	}
 }