| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| |
| # MNIST LeNet - Train |
| # |
| # This script trains a convolutional net using the "LeNet" architecture |
| # on generated dummy data using distributed synchronous SGD. This is |
| # mainly for engine development purposes. |
| # |
| # Inputs: |
| # - N: [DEFAULT: 1024] Number of train dummy images to generate. |
| # - Nval: [DEFAULT: 512] Number of val dummy images to generate. |
| # - Ntest: [DEFAULT: 512] Number of test dummy images to generate. |
| # - C: [DEFAULT: 3] Number of color channels in the images. |
| # - Hin: [DEFAULT: 224] Input image height. |
| # - Win: [DEFAULT: 224] Input image width. |
| # - K: [DEFAULT: 10] Number of dummy classes to use. |
| # - batch_size: [DEFAULT: 32] Number of examples in individual batches. |
| # - parallel_batches: [DEFAULT: 4] Number of batches to run in parallel. |
| # - epochs: [DEFAULT: 10] Total number of full training loops over |
| # the full data set. |
| # - out_dir: [DEFAULT: "."] Directory to store weights and bias |
| # matrices of trained model, as well as final test accuracy. |
| # |
| # Outputs: |
| # - W1, W2, W3, W4: Files containing the trained weights of the model. |
| # - b1, b2, b3, b4: Files containing the trained biases of the model. |
| # - accuracy: File containing the final accuracy on the test data. |
| # |
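| # Data: |
| # This script does not read the MNIST files: the train, validation, |
| # and test sets are all produced by `generate_dummy_data` from the |
| # N/Nval/Ntest, C, Hin, Win, and K settings above. For reference, |
| # the MNIST dataset contains labeled images of handwritten digits, |
| # where each example is a 28x28 pixel image of grayscale values in |
| # the range [0,255] stretched out as 784 pixels, and each label is |
| # one of 10 possible digits in [0,9]. |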
| # |
| # Sample Invocation (running from within the `examples` folder): |
| # 1. Download data (60,000 training examples, and 10,000 test examples) |
| # ``` |
| # nn/examples/get_mnist_data.sh |
| # ``` |
| # |
| # 2. Execute using Spark |
| # ``` |
| # spark-submit --master local[*] --driver-memory 10G --conf spark.driver.maxResultSize=0 |
| # --conf spark.rpc.message.maxSize=128 $SYSTEMDS_ROOT/target/SystemDS.jar |
| # -f nn/examples/mnist_lenet_distrib_sgd-train-dummy-data.dml -nvargs N=1024 Nval=512 Ntest=512 |
| # C=1 Hin=28 Win=28 K=10 batch_size=32 parallel_batches=8 epochs=10 |
| # out_dir=nn/examples/model/mnist_lenet |
| # ``` |
| # |
| source("nn/examples/mnist_lenet_distrib_sgd.dml") as mnist_lenet |
| |
| # Read training data & settings |
| # ---- Settings: use each named script argument, or fall back to its default ---- |
| # Sizes of the generated train / validation / test sets |
| N = ifdef($N, 1024) |
| Nval = ifdef($Nval, 512) |
| Ntest = ifdef($Ntest, 512) |
| |
| # Image dimensions (channels, height, width) and number of dummy classes |
| C = ifdef($C, 3) |
| Hin = ifdef($Hin, 224) |
| Win = ifdef($Win, 224) |
| K = ifdef($K, 10) |
| |
| # Training schedule |
| batch_size = ifdef($batch_size, 32) |
| parallel_batches = ifdef($parallel_batches, 4) |
| epochs = ifdef($epochs, 10) |
| |
| # Every output file below is written under this prefix |
| out_dir = ifdef($out_dir, ".") |
| out_prefix = out_dir + "/" |
| |
| # ---- Dummy data: one (X, Y) pair per split, generated in train/val/test order ---- |
| [X_train, Y_train] = mnist_lenet::generate_dummy_data(N, C, Hin, Win, K) |
| [X_val, Y_val] = mnist_lenet::generate_dummy_data(Nval, C, Hin, Win, K) |
| [X_test, Y_test] = mnist_lenet::generate_dummy_data(Ntest, C, Hin, Win, K) |
| |
| # ---- Training: returns the four weight matrices and four bias matrices ---- |
| [W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X_train, Y_train, X_val, Y_val, |
| C, Hin, Win, batch_size, parallel_batches, epochs) |
| |
| # ---- Persist the trained model: weights first, then biases ---- |
| write(W1, out_prefix + "W1") |
| write(W2, out_prefix + "W2") |
| write(W3, out_prefix + "W3") |
| write(W4, out_prefix + "W4") |
| write(b1, out_prefix + "b1") |
| write(b2, out_prefix + "b2") |
| write(b3, out_prefix + "b3") |
| write(b4, out_prefix + "b4") |
| |
| # ---- Evaluation on the generated test split ---- |
| test_probs = mnist_lenet::predict(X_test, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4) |
| # The loss is returned alongside the accuracy; only the accuracy is reported below |
| [test_loss, test_accuracy] = mnist_lenet::eval(test_probs, Y_test) |
| |
| # ---- Report: print the accuracy and store it next to the model files ---- |
| print("Test Accuracy: " + test_accuracy) |
| write(test_accuracy, out_prefix + "accuracy") |
| |
| # Two empty prints trail the output, as in the original script |
| print("") |
| print("") |
| |