| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| |
| # MNIST Softmax - Train |
| # |
| # This script trains a softmax classifier on images of handwritten |
| # digits. |
| # |
| # Inputs: |
| # - train: File containing labeled MNIST training images. |
| # The format is "label, pixel_1, pixel_2, ..., pixel_n". |
| # - test: File containing labeled MNIST test images. |
| # The format is "label, pixel_1, pixel_2, ..., pixel_n". |
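| # - epochs: [DEFAULT: 1] Number of epochs, passed through to |
| # `mnist_softmax::train`. |
| # - out_dir: [DEFAULT: "."] Directory to store weights and bias |
| # matrices of trained model, as well as final test accuracy. |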
| # - fmt: [DEFAULT: "csv"] File format of `train` and `test` data. |
| # Options include: "csv", "mm", "text", and "binary". |
| # |
| # Outputs: |
| # - W: File containing the trained weights of the model. |
| # - b: File containing the trained biases of the model. |
| # - accuracy: File containing the final accuracy on the test data. |
| # |
| # Data: |
| # The MNIST dataset contains labeled images of handwritten digits, |
| # where each example is a 28x28 pixel image of grayscale values in |
| # the range [0,255] stretched out as 784 pixels, and each label is |
| # one of 10 possible digits in [0,9]. |
| # |
| # Sample Invocation (running from within the `examples` folder): |
| # 1. Download data (60,000 training examples, and 10,000 test examples) |
| # ``` |
| # nn/examples/get_mnist_data.sh |
| # ``` |
| # |
| # 2. Execute using Spark |
| # ``` |
| # spark-submit --master local[*] --driver-memory 10G |
| # --conf spark.driver.maxResultSize=0 --conf spark.rpc.message.maxSize=128 |
| # $SYSTEMDS_ROOT/target/SystemDS.jar -f nn/examples/mnist_softmax-train.dml |
| # -nvargs train=nn/examples/data/mnist/mnist_train.csv test=nn/examples/data/mnist/mnist_test.csv |
| # epochs=1 out_dir=nn/examples/model/mnist_softmax |
| # ``` |
| # |
| source("nn/examples/mnist_softmax.dml") as mnist_softmax |
| |
| # Read the input arguments and data. `fmt`, `epochs`, and `out_dir` fall |
| # back to their defaults when not supplied on the command line. |
| fmt = ifdef($fmt, "csv") |
| train = read($train, format=fmt) |
| test = read($test, format=fmt) |
| epochs = ifdef($epochs, 1) |
| out_dir = ifdef($out_dir, ".") |
| |
| # Number of leading training rows held out as the validation set |
| n_val = 5000 |
| |
| # Fail early with a clear message if the training file is too small to |
| # leave any rows for training once the validation rows are held out; |
| # otherwise the split below would request an invalid row range. |
| n = nrow(train) |
| n_test = nrow(test) |
| if (n <= n_val) { |
|   stop("mnist_softmax-train: need more than " + n_val + " training rows, but got " + n) |
| } |
| |
| # Extract images and labels (column 1 is the label, the rest are pixels) |
| images = train[,2:ncol(train)] |
| labels = train[,1] |
| X_test = test[,2:ncol(test)] |
| Y_test = test[,1] |
| |
| # Scale images to [0,1], and one-hot encode the labels. Labels are |
| # shifted by one so that digits 0..9 map to column indices 1..10. |
| classes = 10 |
| images = images / 255.0 |
| labels = table(seq(1, n), labels+1, n, classes) |
| X_test = X_test / 255.0 |
| Y_test = table(seq(1, n_test), Y_test+1, n_test, classes) |
| |
| # Split into validation (first `n_val` rows) and training (remaining rows). |
| # For the full 60,000-example MNIST training file this yields 5,000 |
| # validation examples and 55,000 training examples. |
| X_val = images[1:n_val,] |
| Y_val = labels[1:n_val,] |
| X = images[(n_val+1):n,] |
| Y = labels[(n_val+1):n,] |
| |
| # Train |
| [W, b] = mnist_softmax::train(X, Y, X_val, Y_val, epochs) |
| |
| # Write model out |
| write(W, out_dir+"/W") |
| write(b, out_dir+"/b") |
| |
| # Eval on test set |
| probs = mnist_softmax::predict(X_test, W, b) |
| [loss, accuracy] = mnist_softmax::eval(probs, Y_test) |
| |
| # Output results |
| print("Test Accuracy: " + accuracy) |
| write(accuracy, out_dir+"/accuracy") |
| |
| print("") |
| print("") |
| |