blob: 9af4d945fd1df73a0d95a35501ea9b4fe6f751a3 [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml") as mnist_lenet
source("scripts/nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
# Driver script: generates a dummy data set, holds out part of it for
# validation, trains a model through mnist_lenet::train, and prints the
# validation loss and accuracy of the returned parameters.

# Generate the data set once. C, Hin and Win are returned alongside the data
# and are forwarded unchanged to train/predict below.
# (A second, redundant call to generate_dummy_data() used to follow here; its
# X/Y results were overwritten by the split below before ever being read, and
# it rebound C/Hin/Win to values from a data set that was not the one used.)
[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
n = nrow(images)

# Split into training and validation: the first val_size rows are held out
# for validation, the remaining rows are used for training.
# NOTE(review): val_size is not rounded, so this presumably relies on n being
# a multiple of 10 -- confirm against generate_dummy_data.
val_size = n * 0.1
X = images[(val_size+1):n,]
X_val = images[1:val_size,]
Y = labels[(val_size+1):n,]
Y_val = labels[1:val_size,]

# Arguments
epochs = 10
workers = 2
# batchsize is only passed to predict; train is not given a batch size here.
batchsize = 64

# Train: returns four weight/bias pairs (W1..W4, b1..b4).
[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers)

# Compute validation loss & accuracy
probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4)
loss_val = cross_entropy_loss::forward(probs_val, Y_val)
# Accuracy = fraction of rows whose highest-scoring column in probs_val
# matches the highest-valued column in Y_val.
accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))

# Output results
print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)