| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| |
| /* |
| * MNIST Softmax Example |
| */ |
| # Imports |
| source("nn/layers/affine.dml") as affine |
| source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss |
| source("nn/layers/softmax.dml") as softmax |
| source("nn/optim/sgd_nesterov.dml") as sgd_nesterov |
| |
train = function(matrix[double] X, matrix[double] Y,
                 matrix[double] X_val, matrix[double] Y_val,
                 int epochs)
    return (matrix[double] W, matrix[double] b) {
  /*
   * Trains a softmax classifier (affine -> softmax) with mini-batch
   * SGD using Nesterov momentum.
   *
   * The input matrix, X, has N examples, each with D features.
   * The targets, Y, have K classes, and are one-hot encoded.
   *
   * Inputs:
   *  - X: Input data matrix, of shape (N, D).
   *  - Y: Target matrix, of shape (N, K).
   *  - X_val: Input validation data matrix, of shape (N_val, D).
   *  - Y_val: Target validation matrix, of shape (N_val, K).
   *  - epochs: Number of outer training loops.  Each loop runs a fixed
   *      number of mini-batch iterations (see `iters` below) that wrap
   *      around the training data, so one loop is not necessarily
   *      exactly one pass over the data set.
   *
   * Outputs:
   *  - W: Weights (parameters) matrix, of shape (D, K).
   *  - b: Biases vector, of shape (1, K).
   */
  N = nrow(X)  # num examples
  D = ncol(X)  # num features
  K = ncol(Y)  # num classes

  # Create softmax classifier:
  # affine -> softmax
  [W, b] = affine::init(D, K)
  # Rescale the initial weights by sqrt(1/D) / sqrt(2/D).
  # NOTE(review): presumably affine::init scales W by sqrt(2/D) and this
  # converts it to a sqrt(1/D) scaling -- confirm against nn/layers/affine.dml.
  W = W / sqrt(2.0/(D)) * sqrt(1/(D))

  # Initialize SGD w/ Nesterov momentum optimizer
  lr = 0.2  # learning rate
  mu = 0  # momentum
  decay = 0.99  # learning rate decay constant
  vW = sgd_nesterov::init(W)  # optimizer momentum state for W
  vb = sgd_nesterov::init(b)  # optimizer momentum state for b

  # Optimize
  print("Starting optimization")
  batch_size = 50
  iters = 1000  # fixed iterations per epoch (alternative: ceil(N / batch_size))
  log_interval = 100  # report metrics every `log_interval` iterations
  for (e in 1:epochs) {
    for(i in 1:iters) {
      # Get next batch; the modulo wraps the start index around the data
      # so `iters` may exceed the number of batches in X.
      beg = ((i-1) * batch_size) %% N + 1
      end = min(N, beg + batch_size - 1)
      X_batch = X[beg:end,]
      y_batch = Y[beg:end,]

      # Compute forward pass
      ## affine & softmax:
      out = affine::forward(X_batch, W, b)
      probs = softmax::forward(out)

      # Compute loss & accuracy for training & validation data only every
      # `log_interval` iterations: the validation forward pass runs over
      # all of X_val and is purely diagnostic, so doing it on every
      # mini-batch step wastes work without affecting the learned W and b.
      if (i %% log_interval == 0) {
        # Training metrics on the current mini-batch
        loss = cross_entropy_loss::forward(probs, y_batch)
        accuracy = mean(rowIndexMax(probs) == rowIndexMax(y_batch))

        # Validation metrics on the full validation set
        probs_val = predict(X_val, W, b)
        loss_val = cross_entropy_loss::forward(probs_val, Y_val)
        accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))

        print("Epoch: " + e + ", Iter: " + i + ", Train Loss: " + loss + ", Train Accuracy: " +
              accuracy + ", Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
      }

      # Compute backward pass
      ## loss:
      dprobs = cross_entropy_loss::backward(probs, y_batch)
      ## affine & softmax:
      dout = softmax::backward(dprobs, out)
      [dX_batch, dW, db] = affine::backward(dout, X_batch, W, b)

      # Optimize with SGD w/ Nesterov momentum
      [W, vW] = sgd_nesterov::update(W, dW, lr, mu, vW)
      [b, vb] = sgd_nesterov::update(b, db, lr, mu, vb)
    }
    # Anneal momentum towards 0.999; the denominator shrinks to 1 on the
    # final epoch, at which point mu becomes exactly 0.999.
    mu = mu + (0.999 - mu)/(1+epochs-e)
    # Decay learning rate
    lr = lr * decay
  }
}
| |
predict = function(matrix[double] X, matrix[double] W, matrix[double] b)
    return (matrix[double] probs) {
  /*
   * Runs the softmax classifier forward and returns the per-class
   * probabilities for every example.
   *
   * X holds N examples with D features each; K is the number of classes.
   *
   * Inputs:
   *  - X: Input data matrix, of shape (N, D).
   *  - W: Weights (parameters) matrix, of shape (D, K).
   *  - b: Biases vector, of shape (1, K).
   *
   * Outputs:
   *  - probs: Class probabilities, of shape (N, K).
   */
  # Affine layer output for each example ...
  scores = affine::forward(X, W, b)
  # ... passed through the softmax layer to obtain probabilities.
  probs = softmax::forward(scores)
}
| |
eval = function(matrix[double] probs, matrix[double] Y)
    return (double loss, double accuracy) {
  /*
   * Scores the predictions of a softmax classifier against the targets.
   *
   * `probs` holds the predicted class probabilities of K classes for
   * N examples; `Y` holds the matching one-hot encoded targets.
   *
   * Inputs:
   *  - probs: Class probabilities, of shape (N, K).
   *  - Y: Target matrix, of shape (N, K).
   *
   * Outputs:
   *  - loss: Scalar cross-entropy loss.
   *  - accuracy: Scalar accuracy, i.e. the fraction of rows whose
   *      highest-probability column matches the target column.
   */
  # Accuracy: per-row indicator of whether the arg-max columns agree,
  # averaged over all rows.
  hits = rowIndexMax(probs) == rowIndexMax(Y)
  accuracy = mean(hits)

  # Loss: cross-entropy between predictions and targets.
  loss = cross_entropy_loss::forward(probs, Y)
}
| |
generate_dummy_data = function()
    return (matrix[double] X, matrix[double] Y, int C, int Hin, int Win) {
  /*
   * Generate a dummy dataset similar to the MNIST dataset.
   *
   * Here N = 1024 examples, D = C*Hin*Win = 784 features, and
   * K = 10 classes.
   *
   * Outputs:
   *  - X: Input data matrix, of shape (N, D), drawn from a normal
   *      distribution.
   *  - Y: One-hot encoded target matrix, of shape (N, K).
   *  - C: Number of input channels (dimensionality of input depth).
   *  - Hin: Input height.
   *  - Win: Input width.
   */
  # Generate dummy input data
  N = 1024  # num examples
  C = 1  # num input channels
  Hin = 28  # input height
  Win = 28  # input width
  T = 10  # num targets
  X = rand(rows=N, cols=C*Hin*Win, pdf="normal")

  # Draw class labels uniformly from 1..T with replacement.  Rounding a
  # uniform draw over [1, T] would give the two end classes only half the
  # probability mass of the interior classes.
  classes = sample(T, N, TRUE)

  # One-hot encoding.  The output dimensions are given explicitly so that
  # Y always has exactly T columns, even if the highest class label does
  # not occur in `classes`.
  Y = table(seq(1, N), classes, N, T)
}
| |