| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| |
| # Script arguments: $1 is the CSV file to read, $2 is the location the |
| # resulting accuracy value is written to. |
| in_file = $1 |
| result_file = $2 |
| |
| # Input geometry forwarded to lenetTrain / lenetPredict. |
| C = 1 |
| Hin = 28 |
| Win = 28 |
| # NOTE(review): `epochs` is never referenced below; the lenetTrain call |
| # passes literal values instead. Kept unchanged. |
| epochs = 10 |
| |
| # Read the CSV and split it by row: rows 1..500 are used for training, |
| # rows 501..550 for validation. |
| raw = read(in_file, format="csv") |
| train_rows = raw[1:500,] |
| val_rows = raw[501:550,] |
| |
| # The first column is treated as the integer label, all remaining columns |
| # as the image values. |
| y_train_int = train_rows[,1] |
| y_val_int = val_rows[,1] |
| |
| # Rescale the image values with (x / 255) * 2 - 1, which maps 0..255 onto |
| # [-1, 1] (assumes the inputs are in 0..255 -- confirm against the data). |
| X_train = (train_rows[,2:ncol(train_rows)] / 255.0) * 2 - 1 |
| X_val = (val_rows[,2:ncol(val_rows)] / 255.0) * 2 - 1 |
| |
| # One-hot encode the labels into 10 columns. The label is shifted by one |
| # so it can act as a column index (presumably labels are 0..9). |
| n_train = nrow(train_rows) |
| Y_train = table(seq(1, n_train), y_train_int+1, n_train, 10) |
| |
| n_val = nrow(val_rows) |
| Y_val = table(seq(1, n_val), y_val_int+1, n_val, 10) |
| |
| # Train the network on the training split, with the validation split |
| # passed alongside. |
| model = lenetTrain(X_train, Y_train, X_val, Y_val, C, Hin, Win, |
| 128, 3, 0.007, 0.9, 0.95, 5e-04, TRUE, -1) |
| |
| # Predict on the training set itself to test the capacity of the network. |
| probs = lenetPredict(model=model, X=X_train, C=C, Hin=Hin, Win=Win) |
| |
| # Accuracy: fraction of rows whose highest-scoring column matches the |
| # column set in the one-hot label row. |
| accuracy = mean(rowIndexMax(probs) == rowIndexMax(Y_train)) |
| print(toString(accuracy)) |
| write(accuracy, result_file) |