# blob: ded0746ad40bc382a5ba42bf501e854535fa597c [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# Trains a LeNet model on a slice of a CSV data set, then measures how well the
# trained model reproduces the labels of its own training split.
#
# Script arguments:
#   $1  path of the input CSV file
#   $2  path the resulting accuracy is written to
path = $1
out_path = $2

# Load the data set. Rows 1-500 form the training split and rows 501-550 the
# validation split, so the input needs at least 550 rows.
data = read(path, format="csv")
train_data = data[1:500,]
val_data = data[501:550,]

# Image geometry handed to lenetTrain / lenetPredict.
C = 1
Hin = 28
Win = 28

# Column 1 is used as the integer label; all remaining columns are the image.
labels_int = train_data[,1]
images = train_data[,2:ncol(train_data)]
labels_int_val = val_data[,1]
images_val = val_data[,2:ncol(val_data)]

# Rescale the images (assuming raw values in [0, 255], this yields [-1, 1]) and
# one-hot encode the labels into 10 columns. Labels are shifted by one because
# table() places each row's entry at column index labels_int+1.
n_train = nrow(train_data)
images = (images / 255.0) * 2 - 1
labels = table(seq(1, n_train), labels_int+1, n_train, 10)

# Each split keeps its own row count so the one-hot matrix has one row per
# sample of that split.
n_val = nrow(val_data)
images_val = (images_val / 255.0) * 2 - 1
labels_val = table(seq(1, n_val), labels_int_val+1, n_val, 10)

# Train. The trailing hyper-parameters are passed positionally and unchanged.
# NOTE(review): confirm their meaning and order against lenetTrain's signature.
model = lenetTrain(images, labels, images_val, labels_val, C, Hin, Win, 128, 3,
0.007, 0.9, 0.95, 5e-04, TRUE, -1)

# Predict on the training set to test the capacity of the network.
probs = lenetPredict(model=model, X=images, C=C, Hin=Hin, Win=Win)

# Accuracy: fraction of rows whose highest-scoring column matches the column
# set in the one-hot label row.
correct_pred = rowIndexMax(probs) == rowIndexMax(labels)
accuracy = mean(correct_pred)
print(toString(accuracy))
write(accuracy, out_path)