#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# TODO move to builtin functions (needs fix for imports in builtin functions)
# TODO scale up to real EfficientNet-B0
# Trains a partial EfficientNet-B0 model.
# This script trains the top and bottom parts of EfficientNet-B0.
# The original EfficientNet-B0 consists of the following layers:
#----------------------------------------------------------------
# Layer                       Dimension   Filters   Nr. Repeats
#----------------------------------------------------------------
# 1. Conv3x3                  224x224     32        1
# 2. MBConv1, k3x3            112x112     16        1
# 3. MBConv6, k3x3            56x56       24        2
# 4. MBConv6, k5x5            28x28       40        2
# 5. MBConv6, k3x3            14x14       80        3
# 6. MBConv6, k5x5            14x14       112       3
# 7. MBConv6, k5x5            7x7         192       4
# 8. MBConv6, k3x3            7x7         320       1
# 9. Conv1x1 & Pooling & FC   7x7         1280      1
#----------------------------------------------------------------
# In this partial implementation we implement layers 1 and 2 plus the prediction layer 9.
# The init method exists purely for convenience; there is no problem with manually initializing the
# weights and biases. To extend the current implementation to the full EfficientNet-B0, only the
# intermediate MBConv layers need to be added; both the stem and the top part are already complete,
# as is the first MBConv layer.
# The number after MBConv is the corresponding expansion factor, followed by the kernel size; stride
# and padding can be derived from the layer dimensions. If a layer is repeated, the skip connection
# is activated, otherwise it is not.
#----------------------------------------------------------------
source("nn/layers/batch_norm2d.dml") as batchnorm
source("nn/layers/conv2d_builtin.dml") as conv2d
source("nn/layers/conv2d_depthwise.dml") as depthwise
source("nn/layers/global_avg_pool2d.dml") as global_avg_pool
source("nn/layers/silu.dml") as silu
source("nn/layers/upsample2d.dml") as upsample
source("nn/layers/mbconv.dml") as mbconv
source("nn/layers/affine.dml") as affine
source("nn/layers/softmax.dml") as softmax
source("nn/optim/sgd.dml") as sgd
source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
initNetwork = function(int InputChannels, int NumberOutputClasses, int seed)
return(list[unknown] model)
{
/*
* Convenience function for initialization of all required weights and biases.
*
* Inputs:
* - InputChannels: Number of Input Channels for the model (Cin)
* - NumberOutputClasses: Number of classes for the network
* - seed: seed for the random generation of the weights
*
* Outputs:
* - model: A list containing all 36 matrices needed for the computation of the
* mini EfficientNet
*/
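# The 36 matrices: 6 for the stem (conv W/b, BN gamma/beta, BN EMA mean/var),
# 22 for the MBConv block, and 8 for the top (conv W/b, BN gamma/beta,
# BN EMA mean/var, affine W/b).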
# Layer 1
[CW_stem, Cb_stem] = conv2d::init(32, InputChannels, 3, 3, seed)
seed = ifelse(seed==-1, -1, seed + 1);
[Gamma_stem, Beta_stem, EmaMean_stem, EmaVar_stem] = batchnorm::init(32)
# Layer 2
[mb_parameters] = mbconv::init(32, 16, 3, 3, 1, 0.25, seed)
seed = ifelse(seed==-1, -1, seed + 1);
# Layer 9
[CW_top, Cb_top] = conv2d::init(1280, 16, 1, 1, seed)
seed = ifelse(seed==-1, -1, seed + 1);
[Gamma_top, Beta_top, EmaMean_top, EmaVar_top] = batchnorm::init(1280)
[DW_top, Db_top] = affine::init(1280, NumberOutputClasses, seed)
model = list(CW_stem, Cb_stem, Gamma_stem, Beta_stem, EmaMean_stem, EmaVar_stem,
as.matrix(mb_parameters[1]), as.matrix(mb_parameters[2]),
as.matrix(mb_parameters[3]), as.matrix(mb_parameters[4]),
as.matrix(mb_parameters[5]), as.matrix(mb_parameters[6]),
as.matrix(mb_parameters[7]), as.matrix(mb_parameters[8]),
as.matrix(mb_parameters[9]), as.matrix(mb_parameters[10]),
as.matrix(mb_parameters[11]), as.matrix(mb_parameters[12]),
as.matrix(mb_parameters[13]), as.matrix(mb_parameters[14]),
as.matrix(mb_parameters[15]), as.matrix(mb_parameters[16]),
as.matrix(mb_parameters[17]), as.matrix(mb_parameters[18]),
as.matrix(mb_parameters[19]), as.matrix(mb_parameters[20]),
as.matrix(mb_parameters[21]), as.matrix(mb_parameters[22]),
CW_top, Cb_top, Gamma_top, Beta_top, EmaMean_top, EmaVar_top, DW_top, Db_top)
}
netPredict = function(matrix[double] X, list[unknown] model, int Cin, int Hin, int Win)
return(matrix[double] pred)
{
/*
* Computes the model's prediction for an input X.
*
* Inputs:
* - X: Input features of format (N, Cin * Hin * Win)
* - model: the list of 36 matrices generated by the initNetwork function
* - Cin: Number of input channels (dimensionality of depth).
* - Hin: Input height.
* - Win: Input width.
*
* Outputs:
* - pred: The output of the final softmax layer of the mini EfficientNet
*/
CW_stem = as.matrix(model[1])
Cb_stem = as.matrix(model[2])
Gamma_stem = as.matrix(model[3])
Beta_stem = as.matrix(model[4])
EmaMean_stem = as.matrix(model[5])
EmaVar_stem = as.matrix(model[6])
MBConv_params = model[7:28]
CW_top = as.matrix(model[29])
Cb_top = as.matrix(model[30])
Gamma_top = as.matrix(model[31])
Beta_top = as.matrix(model[32])
EmaMean_top = as.matrix(model[33])
EmaVar_top = as.matrix(model[34])
DW_top = as.matrix(model[35])
Db_top = as.matrix(model[36])
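# Pad such that the stride-2 3x3 stem convolution halves the spatial dimensions
# (rounding down): pad by 1 if the input dimension is even, by 0 if it is odd.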
padh = (Hin + 1) %% 2
padw = (Win + 1) %% 2
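# Layer 1 (stem): 3x3 stride-2 convolution, batch norm ("test" mode applies the EMA statistics), SiLU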
[stem_out, stem_h, stem_w] = conv2d::forward(X, CW_stem, Cb_stem, Cin, Hin, Win, 3, 3, 2, 2, padh, padw)
[bn_stem_out, update_EmaMean_stem, update_EmaVar_stem, cache_EmaMean_stem, cache_EmaVar_stem] = batchnorm::forward(
stem_out, Gamma_stem, Beta_stem, 32, stem_h, stem_w, "test", EmaMean_stem, EmaVar_stem, 0.9, 1e-5)
silu_out = silu::forward(bn_stem_out)
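# Layer 2: MBConv1 block, 32 -> 16 channels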
[mbconv_out, intermediate_mbconv, mbconvbatchnorm_updates, mbconv_h, mbconv_w] = mbconv::forward(
silu_out, MBConv_params, 32, 16, stem_h, stem_w, 3, 3, 2, 2, padh, padw, FALSE, 1, "test", 0.25)
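# Layer 9 (top): 1x1 convolution, batch norm, SiLU, global average pooling, affine, softmax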
[top_out, outh, outw] = conv2d::forward(mbconv_out, CW_top, Cb_top, 16, mbconv_h, mbconv_w, 1, 1, 1, 1, 0, 0)
[bntop_out, update_EmaMean_top, update_EmaVar_top, cache_EmaMean_top, cache_EmaVar_top] = batchnorm::forward(
top_out, Gamma_top, Beta_top, 1280, outh, outw, "test", EmaMean_top, EmaVar_top, 0.9, 1e-5)
silu_out2 = silu::forward(bntop_out)
[pool_out, None, None] = global_avg_pool::forward(silu_out2, 1280, outh, outw)
dense_out = affine::forward(pool_out, DW_top, Db_top)
pred = softmax::forward(dense_out)
}
netTrain = function(list[unknown] model, matrix[double] X, int Cin, int Hin, int Win,
matrix[double] Y, int epochs, int batch_size, double learning_rate, double lr_decay, boolean verbose)
return(list[unknown] trained_model)
{
/*
* Trains the given model with an SGD optimizer, using the given batch size, for the
* specified number of epochs.
*
* Inputs:
* - model: the list of 36 matrices generated by the initNetwork function
* - X: Input features of format (N, Cin * Hin * Win)
* - Cin: Number of input channels (dimensionality of depth).
* - Hin: Input height.
* - Win: Input width.
* - Y: One-hot encoded target labels of shape (N, NumberOutputClasses)
* - epochs: Number of epochs to train for
* - batch_size: Size of batch used for a single update step
* - learning_rate: Learning rate used for each SGD update step
* - lr_decay: The learning rate is multiplied with lr_decay after each epoch.
* - verbose: Whether the accuracy and the cross-entropy loss should be printed after each update step
*
* Outputs:
* - trained_model: The new list of the updated 36 matrices
*/
CW_stem = as.matrix(model[1])
Cb_stem = as.matrix(model[2])
Gamma_stem = as.matrix(model[3])
Beta_stem = as.matrix(model[4])
EmaMean_stem = as.matrix(model[5])
EmaVar_stem = as.matrix(model[6])
MBConv_params = model[7:28]
CW_top = as.matrix(model[29])
Cb_top = as.matrix(model[30])
Gamma_top = as.matrix(model[31])
Beta_top = as.matrix(model[32])
EmaMean_top = as.matrix(model[33])
EmaVar_top = as.matrix(model[34])
DW_top = as.matrix(model[35])
Db_top = as.matrix(model[36])
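# Pad such that the stride-2 3x3 convolutions halve the spatial dimensions (see netPredict).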
padh = (Hin + 1) %% 2
padw = (Win + 1) %% 2
N = nrow(X)
lr = learning_rate
# Optimize
iters = ceil(N / batch_size)
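# Each epoch processes ceil(N / batch_size) mini-batches; the last batch may be smaller than batch_size.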
for (e in 1:epochs) {
for(i in 1:iters) {
# Get next batch
beg = ((i-1) * batch_size) %% N + 1
end = min(N, beg + batch_size - 1)
X_batch = X[beg:end,]
y_batch = Y[beg:end,]
# Compute forward pass
[stem_out, stem_h, stem_w] = conv2d::forward(X_batch, CW_stem, Cb_stem, Cin, Hin, Win, 3, 3, 2, 2, padh, padw)
[bn_stem_out, update_EmaMean_stem, update_EmaVar_stem, cache_EmaMean_stem, cache_EmaVar_stem] = batchnorm::forward(
stem_out, Gamma_stem, Beta_stem, 32, stem_h, stem_w, "train", EmaMean_stem, EmaVar_stem, 0.9, 1e-5)
silu_out = silu::forward(bn_stem_out)
[mbconv_out, intermediate_mbconv, mbconvbatchnorm_updates, mbconv_h, mbconv_w] = mbconv::forward(
silu_out, MBConv_params, 32, 16, stem_h, stem_w, 3, 3, 2, 2, padh, padw, FALSE, 1, "train", 0.25)
[top_out, outh, outw] = conv2d::forward(mbconv_out, CW_top, Cb_top, 16, mbconv_h, mbconv_w, 1, 1, 1, 1, 0, 0)
[bntop_out, update_EmaMean_top, update_EmaVar_top, cache_EmaMean_top, cache_EmaVar_top] = batchnorm::forward(
top_out, Gamma_top, Beta_top, 1280, outh, outw, "train", EmaMean_top, EmaVar_top, 0.9, 1e-5)
silu_out2 = silu::forward(bntop_out)
[pool_out, None, None] = global_avg_pool::forward(silu_out2, 1280, outh, outw)
dense_out = affine::forward(pool_out, DW_top, Db_top)
pred = softmax::forward(dense_out)
# Compute loss & accuracy for training
loss = cross_entropy_loss::forward(pred, y_batch)
if(verbose) {
accuracy = mean(rowIndexMax(pred) == rowIndexMax(y_batch))
print("Epoch: " + e + ", Iter: " + i + ", Train Loss: " + loss + ", Train Accuracy: " + accuracy)
}
# Compute backward pass
## loss:
dprobs = cross_entropy_loss::backward(pred, y_batch)
## TOP
d_softmax = softmax::backward(dprobs, dense_out)
[d_dense_back, dDenseW_top, dDenseb_top] = affine::backward(d_softmax, pool_out, DW_top, Db_top)
d_pool_back = global_avg_pool::backward(d_dense_back, silu_out2, 1280, outh, outw)
d_silu2_back = silu::backward(d_pool_back, bntop_out)
[d_bntop_back, dGamma_top, dBeta_top] = batchnorm::backward(
d_silu2_back, cache_EmaMean_top, cache_EmaVar_top, top_out, Gamma_top, 1280, outh, outw, 1e-5)
[dtop_back, d_ConvW_top, d_Convb_top] = conv2d::backward(
d_bntop_back, outh, outw, mbconv_out, CW_top, Cb_top, 16, mbconv_h, mbconv_w, 1, 1, 1, 1, 0, 0)
# MBCONV
[d_mbconv_back, mbconv_gradients] = mbconv::backward(
dtop_back, silu_out, MBConv_params, intermediate_mbconv, mbconvbatchnorm_updates,
32, 16, stem_h, stem_w, 3, 3, 2, 2, padh, padw, FALSE, 1, "train", 0.25)
## STEM
d_silu_back = silu::backward(d_mbconv_back, bn_stem_out)
[d_bn_stem_back, dGamma_stem, dBeta_stem] = batchnorm::backward(
d_silu_back, cache_EmaMean_stem, cache_EmaVar_stem, stem_out, Gamma_stem, 32, stem_h, stem_w, 1e-5)
[dconv_back, dW_stem, db_stem] = conv2d::backward(
d_bn_stem_back, stem_h, stem_w, X_batch, CW_stem, Cb_stem, Cin, Hin, Win, 3, 3, 2, 2, padh, padw)
# Optimize with SGD
# Update Stem
CW_stem = sgd::update(CW_stem, dW_stem, lr)
Cb_stem = sgd::update(Cb_stem, db_stem, lr)
Gamma_stem = sgd::update(Gamma_stem, dGamma_stem, lr)
Beta_stem = sgd::update(Beta_stem, dBeta_stem, lr)
EmaMean_stem = update_EmaMean_stem
EmaVar_stem = update_EmaVar_stem
# Update MBConv
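# Note: mbconv_gradients is ordered from the output side backwards
# (output BN, output conv, excite, squeeze, depthwise BN, depthwise conv),
# hence the reversed index mapping between parameters and gradients below.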
update_depth_W = sgd::update(as.matrix(MBConv_params[7]), as.matrix(mbconv_gradients[11]), lr)
update_depth_b = sgd::update(as.matrix(MBConv_params[8]), as.matrix(mbconv_gradients[12]), lr)
update_gamma_depth = sgd::update(as.matrix(MBConv_params[9]), as.matrix(mbconv_gradients[9]), lr)
update_beta_depth = sgd::update(as.matrix(MBConv_params[10]), as.matrix(mbconv_gradients[10]), lr)
update_ema_mean_depth = as.matrix(mbconvbatchnorm_updates[5])
update_ema_var_depth = as.matrix(mbconvbatchnorm_updates[6])
update_squeeze_W = sgd::update(as.matrix(MBConv_params[13]), as.matrix(mbconv_gradients[7]), lr)
update_squeeze_b = sgd::update(as.matrix(MBConv_params[14]), as.matrix(mbconv_gradients[8]), lr)
update_excite_W = sgd::update(as.matrix(MBConv_params[15]), as.matrix(mbconv_gradients[5]), lr)
update_excite_b = sgd::update(as.matrix(MBConv_params[16]), as.matrix(mbconv_gradients[6]), lr)
update_out_W = sgd::update(as.matrix(MBConv_params[17]), as.matrix(mbconv_gradients[3]), lr)
update_out_b = sgd::update(as.matrix(MBConv_params[18]), as.matrix(mbconv_gradients[4]), lr)
update_out_gamma = sgd::update(as.matrix(MBConv_params[19]), as.matrix(mbconv_gradients[1]), lr)
update_out_beta = sgd::update(as.matrix(MBConv_params[20]), as.matrix(mbconv_gradients[2]), lr)
update_ema_mean_out = as.matrix(mbconvbatchnorm_updates[9])
update_ema_var_out = as.matrix(mbconvbatchnorm_updates[10])
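# Rebuild the parameter list. Entries 1-6 (the expansion convolution and its batch norm) receive
# no gradient updates here, presumably because the expansion factor is 1, so the initial values
# from the model are carried over unchanged (identical to MBConv_params[1:6]).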
MBConv_params = list(
as.matrix(model[7]), as.matrix(model[8]),
as.matrix(model[9]), as.matrix(model[10]),
as.matrix(model[11]), as.matrix(model[12]),
update_depth_W, update_depth_b,
update_gamma_depth, update_beta_depth,
update_ema_mean_depth, update_ema_var_depth,
update_squeeze_W, update_squeeze_b,
update_excite_W, update_excite_b,
update_out_W, update_out_b,
update_out_gamma, update_out_beta,
update_ema_mean_out, update_ema_var_out)
# Update Top
CW_top = sgd::update(CW_top, d_ConvW_top, lr)
Cb_top = sgd::update(Cb_top, d_Convb_top, lr)
Gamma_top = sgd::update(Gamma_top, dGamma_top, lr)
Beta_top = sgd::update(Beta_top, dBeta_top, lr)
EmaMean_top = update_EmaMean_top
EmaVar_top = update_EmaVar_top
DW_top = sgd::update(DW_top, dDenseW_top, lr)
Db_top = sgd::update(Db_top, dDenseb_top, lr)
}
# Decay learning rate
lr = lr * lr_decay
}
# Pack everything into model format
trained_model = list(CW_stem, Cb_stem, Gamma_stem, Beta_stem, EmaMean_stem, EmaVar_stem,
as.matrix(MBConv_params[1]), as.matrix(MBConv_params[2]),
as.matrix(MBConv_params[3]), as.matrix(MBConv_params[4]),
as.matrix(MBConv_params[5]), as.matrix(MBConv_params[6]),
as.matrix(MBConv_params[7]), as.matrix(MBConv_params[8]),
as.matrix(MBConv_params[9]), as.matrix(MBConv_params[10]),
as.matrix(MBConv_params[11]), as.matrix(MBConv_params[12]),
as.matrix(MBConv_params[13]), as.matrix(MBConv_params[14]),
as.matrix(MBConv_params[15]), as.matrix(MBConv_params[16]),
as.matrix(MBConv_params[17]), as.matrix(MBConv_params[18]),
as.matrix(MBConv_params[19]), as.matrix(MBConv_params[20]),
as.matrix(MBConv_params[21]), as.matrix(MBConv_params[22]),
CW_top, Cb_top, Gamma_top, Beta_top, EmaMean_top, EmaVar_top, DW_top, Db_top)
}