| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| |
| # MNIST LeNet - Train |
| # |
| # This script trains a convolutional net using the "LeNet" architecture |
| # on images of handwritten digits. |
| # |
| # Inputs: |
| # - train: File containing labeled MNIST training images. |
| # The format is "label, pixel_1, pixel_2, ..., pixel_n". |
| # - test: File containing labeled MNIST test images. |
| # The format is "label, pixel_1, pixel_2, ..., pixel_n". |
| # - C: Number of color channels in the images. |
| # - Hin: Input image height. |
| # - Win: Input image width. |
| # - epochs: [DEFAULT: 10] Total number of full training loops over |
| # the full data set. |
| # - out_dir: [DEFAULT: "."] Directory to store weights and bias |
| # matrices of trained model, as well as final test accuracy. |
| # - fmt: [DEFAULT: "csv"] File format of `train` and `test` data. |
| # Options include: "csv", "mm", "text", and "binary". |
| # |
| # Outputs: |
| # - W1, W2, W3, W4: Files containing the trained weights of the model. |
| # - b1, b2, b3, b4: Files containing the trained biases of the model. |
| # - accuracy: File containing the final accuracy on the test data. |
| # |
| # Data: |
| # The MNIST dataset contains labeled images of handwritten digits, |
| # where each example is a 28x28 pixel image of grayscale values in |
| # the range [0,255] stretched out as 784 pixels, and each label is |
| # one of 10 possible digits in [0,9]. |
| # |
| # Sample Invocation (running from outside the `nn` folder): |
| # 1. Download data (60,000 training examples, and 10,000 test examples) |
| # ``` |
| # nn/examples/get_mnist_data.sh |
| # ``` |
| # |
| # 2. Execute using Spark |
| # ``` |
| # spark-submit --master local[*] --driver-memory 10G |
| # --conf spark.driver.maxResultSize=0 --conf spark.rpc.message.maxSize=128 |
| # $SYSTEMDS_ROOT/target/SystemDS.jar -f nn/examples/mnist_lenet-train.dml |
| # -nvargs train=nn/examples/data/mnist/mnist_train.csv test=nn/examples/data/mnist/mnist_test.csv |
| # C=1 Hin=28 Win=28 epochs=10 out_dir=nn/examples/model/mnist_lenet |
| # ``` |
| # |
| source("nn/examples/mnist_lenet.dml") as mnist_lenet |
| |
| # Read training data & settings |
| # `fmt`, `epochs`, and `out_dir` fall back to "csv", 10, and "." through |
| # ifdef() when the corresponding named argument is absent; `train`, `test`, |
| # `C`, `Hin`, and `Win` are read without any fallback. |
| fmt = ifdef($fmt, "csv") |
| # Both data files are read with the same `fmt`. |
| train = read($train, format=fmt) |
| test = read($test, format=fmt) |
| C = $C |
| Hin = $Hin |
| Win = $Win |
| epochs = ifdef($epochs, 10) |
| out_dir = ifdef($out_dir, ".") |
| |
| # Row counts of the two raw data sets |
| n_train = nrow(train) |
| n_test = nrow(test) |
| |
| # Column 1 holds the label; the remaining columns hold the pixel values |
| train_pixels = train[,2:ncol(train)] |
| train_digits = train[,1] |
| test_pixels = test[,2:ncol(test)] |
| test_digits = test[,1] |
| |
| # Rescale pixel values to [-1,1] (inputs are documented as lying in [0,255]) |
| X_all = (train_pixels / 255.0) * 2 - 1 |
| X_test = (test_pixels / 255.0) * 2 - 1 |
| |
| # One-hot encode the labels into 10 columns; the +1 shifts digit d to the |
| # 1-based column index d+1 |
| Y_all = table(seq(1, n_train), train_digits+1, n_train, 10) |
| Y_test = table(seq(1, n_test), test_digits+1, n_test, 10) |
| |
| # Hold out the first 5,000 rows for validation; every remaining row is used |
| # for training (55,000 rows for the 60,000-example MNIST training file) |
| X_val = X_all[1:5000,] |
| Y_val = Y_all[1:5000,] |
| X = X_all[5001:nrow(X_all),] |
| Y = Y_all[5001:nrow(X_all),] |
| |
| # Train |
| # mnist_lenet::train comes from the sourced nn/examples/mnist_lenet.dml and |
| # returns eight values, bound here to W1, b1, ..., W4, b4. The header of this |
| # script describes W* as the trained weights and b* as the trained biases. |
| [W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs) |
| |
| # Write model out |
| # One file per returned value, placed under `out_dir` and named after the |
| # variable. No explicit format is passed to write(). |
| write(W1, out_dir+"/W1") |
| write(b1, out_dir+"/b1") |
| write(W2, out_dir+"/W2") |
| write(b2, out_dir+"/b2") |
| write(W3, out_dir+"/W3") |
| write(b3, out_dir+"/b3") |
| write(W4, out_dir+"/W4") |
| write(b4, out_dir+"/b4") |
| |
| # Score the test images with the trained parameters |
| test_probs = mnist_lenet::predict(X_test, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4) |
| [test_loss, accuracy] = mnist_lenet::eval(test_probs, Y_test) |
| |
| # Print the test accuracy and store it under `out_dir` |
| # (the loss returned by eval is not used here) |
| accuracy_msg = "Test Accuracy: " + accuracy |
| print(accuracy_msg) |
| write(accuracy, out_dir+"/accuracy") |
| |
| # Finish with two empty output lines |
| print("") |
| print("") |
| |