| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| |
| /* |
| * Example script to test different optimizers with AlexNet on ImageNet |
| * |
| * This script demonstrates how different optimizers perform on ImageNet, |
| * particularly focusing on large batch training scenarios. |
| */ |
| |
| source("imagenet_alexnet.dml") as imagenet_alexnet |
| |
# Input geometry and label space of the ImageNet task
C = 3       # colour channels (RGB)
Hin = 224   # image height in pixels
Win = 224   # image width in pixels
K = 1000    # number of target classes

# Opening banner
print("\n=======================================================")
print("Optimizer Comparison on ImageNet AlexNet")
print("=======================================================\n")
| |
# For demonstration, a small synthetic stand-in replaces ImageNet.
# In practice, you would load the full ImageNet dataset here.
print("Loading ImageNet subset for demonstration...")

# Simulated training data (5K samples for a faster demo): every row is one
# flattened C x Hin x Win image with uniform random values in [0, 1].
n_train = 5000
X = rand(rows=n_train, cols=C*Hin*Win, min=0, max=1, seed=42)

# One class id per sample, drawn uniformly from 1..K with replacement.
# sample() gives every class the same probability; rounding a continuous
# uniform draw on [1, K] would give classes 1 and K only half the weight.
train_classes = sample(K, n_train, TRUE, 42)

# One-hot encode in a single vectorized step: table() places a 1 at
# (row i, column train_classes[i]); the explicit output dimensions keep
# y at n_train x K even if the highest class ids were never drawn.
y = table(seq(1, n_train), train_classes, n_train, K)
| |
# Simulated validation data (500 samples), generated the same way as the
# training data but from a different seed.
n_val = 500
X_val = rand(rows=n_val, cols=C*Hin*Win, min=0, max=1, seed=43)

# One class id per sample, drawn uniformly from 1..K with replacement.
# A single seeded sample() call avoids per-row seeds, which would overlap
# with any other per-row seed sequence that is offset by one.
val_classes = sample(K, n_val, TRUE, 43)

# One-hot encode: table() places a 1 at (row i, column val_classes[i]);
# the explicit output dimensions keep y_val at n_val x K.
y_val = table(seq(1, n_val), val_classes, n_val, K)
| |
# Optimizers under test, paired position-by-position with the learning
# rate chosen for each of them
optimizers = list("sgd", "sgd_momentum", "adam", "lars")
learning_rates = list(0.01, 0.01, 0.001, 0.1)

# Shared training settings
batch_size = 512   # medium batch size, identical for every optimizer
epochs = 1         # kept small for the demonstration

# Result table, one row per optimizer.
# Columns: optimizer_id, top1_acc, top5_acc, final_loss, train_time
n_optimizers = length(optimizers)
results = matrix(0, rows=n_optimizers, cols=5)

# Echo the run configuration
print("Configuration:")
print("- Dataset: ImageNet subset (demonstration)")
print("- Model: AlexNet with Batch Normalization")
print("- Training samples: " + n_train)
print("- Validation samples: " + n_val)
print("- Epochs: " + epochs)
print("- Batch size: " + batch_size)
print("\n")
| |
# Train and evaluate AlexNet once per optimizer; row i of `results`
# receives the metrics of optimizers[i].
for (i in 1:length(optimizers)) {
  # Single-element list lookups are unwrapped to scalars.
  optimizer = as.scalar(optimizers[i])
  lr = as.scalar(learning_rates[i])

  print("\n=========================================")
  print("Testing optimizer: " + optimizer)
  print("Learning rate: " + lr)
  print("-----------------------------------------")

  # Train model and measure the wall-clock time of the training call only.
  start_time = time()
  model = imagenet_alexnet::train(X, y, X_val, y_val, C, Hin, Win,
                                  epochs, optimizer, lr, batch_size)
  # time() reports nanoseconds, so divide by 1e9 to obtain seconds.
  train_time = (time() - start_time) / 1000000000.0

  # Extract all weight/bias parameters from the returned model list
  W1 = as.matrix(model["W1"]); b1 = as.matrix(model["b1"])
  W2 = as.matrix(model["W2"]); b2 = as.matrix(model["b2"])
  W3 = as.matrix(model["W3"]); b3 = as.matrix(model["b3"])
  W4 = as.matrix(model["W4"]); b4 = as.matrix(model["b4"])
  W5 = as.matrix(model["W5"]); b5 = as.matrix(model["b5"])
  W6 = as.matrix(model["W6"]); b6 = as.matrix(model["b6"])
  W7 = as.matrix(model["W7"]); b7 = as.matrix(model["b7"])
  W8 = as.matrix(model["W8"]); b8 = as.matrix(model["b8"])

  # Extract batch-normalization parameters (scale, shift and the
  # "ema_*" mean/variance entries) for the five BN layers
  gamma1 = as.matrix(model["gamma1"]); beta1 = as.matrix(model["beta1"])
  ema_mean1 = as.matrix(model["ema_mean1"]); ema_var1 = as.matrix(model["ema_var1"])
  gamma2 = as.matrix(model["gamma2"]); beta2 = as.matrix(model["beta2"])
  ema_mean2 = as.matrix(model["ema_mean2"]); ema_var2 = as.matrix(model["ema_var2"])
  gamma3 = as.matrix(model["gamma3"]); beta3 = as.matrix(model["beta3"])
  ema_mean3 = as.matrix(model["ema_mean3"]); ema_var3 = as.matrix(model["ema_var3"])
  gamma4 = as.matrix(model["gamma4"]); beta4 = as.matrix(model["beta4"])
  ema_mean4 = as.matrix(model["ema_mean4"]); ema_var4 = as.matrix(model["ema_var4"])
  gamma5 = as.matrix(model["gamma5"]); beta5 = as.matrix(model["beta5"])
  ema_mean5 = as.matrix(model["ema_mean5"]); ema_var5 = as.matrix(model["ema_var5"])

  # Evaluate on the validation set
  probs_val = imagenet_alexnet::predict(X_val, C, Hin, Win,
                                        W1, b1, W2, b2, W3, b3, W4, b4,
                                        W5, b5, W6, b6, W7, b7, W8, b8,
                                        gamma1, beta1, ema_mean1, ema_var1,
                                        gamma2, beta2, ema_mean2, ema_var2,
                                        gamma3, beta3, ema_mean3, ema_var3,
                                        gamma4, beta4, ema_mean4, ema_var4,
                                        gamma5, beta5, ema_mean5, ema_var5)
  [loss_val, top1_acc, top5_acc] = imagenet_alexnet::eval(probs_val, y_val)

  print("\nFinal Results:")
  print("Validation Loss: " + loss_val)
  print("Top-1 Accuracy: " + top1_acc + " (" + (top1_acc * 100) + "%)")
  print("Top-5 Accuracy: " + top5_acc + " (" + (top5_acc * 100) + "%)")
  print("Training Time: " + train_time + " seconds")

  # Store results for the summary printed after the loop
  results[i, 1] = i            # optimizer id (position in `optimizers`)
  results[i, 2] = top1_acc
  results[i, 3] = top5_acc
  results[i, 4] = loss_val
  results[i, 5] = train_time   # seconds
}
| |
# Print a fixed-width comparison table, one line per row of `results`.
print("\n\n=======================================================")
print("OPTIMIZER COMPARISON SUMMARY")
print("=======================================================")
print("\nOptimizer | Top-1 Acc | Top-5 Acc | Val Loss | Time (s)")
print("---------------|-----------|-----------|----------|----------")

# Display names, listed in the same order as the rows of `results`.
optimizer_names = list("SGD", "SGD+Momentum", "Adam", "LARS")
for(i in 1:nrow(results)) {
  opt_name = as.scalar(optimizer_names[i])
  # DML has no sprintf(); print() itself takes a format string followed
  # by the scalar values to substitute.
  print("%-14s | %9.4f | %9.4f | %8.4f | %8.2f",
        opt_name,
        as.scalar(results[i,2]),
        as.scalar(results[i,3]),
        as.scalar(results[i,4]),
        as.scalar(results[i,5]))
}
| |
# Find best performers.
# rowIndexMax/rowIndexMin return, for every ROW, the column index of the
# extreme value. A metric column is therefore transposed into a single
# row first, so that the 1x1 result is the position of the best optimizer
# (applied to the n x 1 column directly, the result would be an n x 1
# matrix of ones that as.scalar() cannot convert).
best_top1_idx = as.integer(as.scalar(rowIndexMax(t(results[,2]))))
best_top5_idx = as.integer(as.scalar(rowIndexMax(t(results[,3]))))
fastest_idx = as.integer(as.scalar(rowIndexMin(t(results[,5]))))

# Resolve display names and the matching metric values.
best_top1_name = as.scalar(optimizer_names[best_top1_idx])
best_top5_name = as.scalar(optimizer_names[best_top5_idx])
fastest_name = as.scalar(optimizer_names[fastest_idx])
best_top1_acc = as.scalar(results[best_top1_idx,2])
best_top5_acc = as.scalar(results[best_top5_idx,3])
fastest_time = as.scalar(results[fastest_idx,5])

print("\nBest Performers:")
print("- Highest Top-1 Accuracy: " + best_top1_name + " (" + best_top1_acc + ")")
print("- Highest Top-5 Accuracy: " + best_top5_name + " (" + best_top5_acc + ")")
print("- Fastest Training: " + fastest_name + " (" + fastest_time + "s)")
| |
# General takeaways, printed one per line in the order given.
observations = list(
  "1. SGD with momentum typically provides good baseline performance",
  "2. Adam converges quickly but may not achieve best final accuracy",
  "3. LARS excels with large batch sizes (not fully demonstrated here)",
  "4. Proper learning rate tuning is crucial for each optimizer",
  "5. Batch normalization helps stabilize training across optimizers")

print("\nKey Observations:")
for (k in 1:length(observations)) {
  print(as.scalar(observations[k]))
}

# Caveats of this demo and what a full ImageNet run would need.
requirements = list(
  "- 1.2M+ training images",
  "- 90+ epochs",
  "- Proper data augmentation",
  "- Learning rate scheduling")

print("\nNote: This is a demonstration with limited data and epochs.")
print("Full ImageNet training would require:")
for (k in 1:length(requirements)) {
  print(as.scalar(requirements[k]))
}
print("=======================================================\n")