blob: 5ca255a6d2a967b095e32a4b99569f645068c99d [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
/*
* This file contains functions used for training a U-Net model, with or without a parameter server setup.
*/
source("scripts/nn/layers/affine.dml") as affine
source("scripts/nn/layers/conv2d_builtin.dml") as conv2d
source("scripts/nn/layers/conv2d_transpose.dml") as conv2d_transpose
source("scripts/nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
source("scripts/nn/layers/dropout.dml") as dropout
source("scripts/nn/layers/l2_reg.dml") as l2_reg
source("scripts/nn/layers/max_pool2d_builtin.dml") as max_pool2d
source("scripts/nn/layers/relu.dml") as relu
source("scripts/nn/layers/softmax2d.dml") as softmax2d
source("scripts/nn/optim/sgd_momentum.dml") as sgd_momentum
source("scripts/nn/layers/dropout.dml") as dropout
source("scripts/utils/image_utils.dml") as img_utils
/*
 * Pad input features X with extrapolated data by mirroring.
 * Only the height and width are padded, no extra channels are added.
 * Dimensions changed from (N,C*Hin*Win) to (N,C*(Hin+184)*(Win+184)).
 *
 * Inputs:
 *  - X: Features to pad, of shape (N, C*Hin*Win)
 *  - N: Number of rows (examples) of X
 *  - C: Number of channels of X
 *  - Hin: Height of each element of X
 *  - Win: Width of each element of X
 *  - useParfor: Loop over the rows with parfor if TRUE (default), with a plain for loop otherwise
 *
 * Outputs:
 *  - X_extrapolated: X padded with extrapolated data, of shape (N, C*(Hin+184)*(Win+184))
 *  - input_HW: Height of X_extrapolated, i.e. Hin+184. It equals the padded width
 *              only for square inputs (Hin == Win).
 */
extrapolate = function(matrix[double] X, int N, int C, int Hin, int Win, boolean useParfor = TRUE)
  return (matrix[double] X_extrapolated, int input_HW) {
  # 92 mirrored pixels on each side, 184 per spatial dimension in total.
  # Assumes filter HW 3 and conv stride 1 (as stated by the original implementation).
  pad_size = 92
  input_HW = Hin + 2 * pad_size
  # Cells of one padded channel. Both dimensions are used so that non-square
  # images produce the documented (Hin+184)*(Win+184) layout.
  channel_width = (Hin + 2 * pad_size) * (Win + 2 * pad_size)
  if (useParfor) {
    X_extrapolated = extrapolate_images_parfor(X, N, C, Hin, Win, pad_size, channel_width)
  } else {
    X_extrapolated = extrapolate_images_for(X, N, C, Hin, Win, pad_size, channel_width)
  }
}
/*
 * Mirror-pad every row of X sequentially (plain for loop).
 * Each row holds one image of C channels of size Hin x Win; only the spatial
 * dimensions grow, the channel count stays C.
 *
 * Inputs:
 *  - X: Features to pad, of shape (N, C*Hin*Win)
 *  - N: Number of rows of X
 *  - C: Number of channels per row
 *  - Hin: Image height
 *  - Win: Image width
 *  - pad_size: Number of mirrored pixels added on each side
 *  - channel_width: Number of cells of one padded channel
 *
 * Outputs:
 *  - X_extrapolated: Padded features, of shape (N, C*channel_width)
 */
extrapolate_images_for = function(matrix[double] X, int N, int C, int Hin, int Win, int pad_size, int channel_width)
  return (matrix[double] X_extrapolated) {
  padded_cols = C * channel_width
  X_extrapolated = matrix(0, rows=N, cols=padded_cols)
  for (n in 1:N) {
    source_row = X[n, ]
    X_extrapolated[n, ] = extrapolate_image(source_row, C, Hin, Win, pad_size, channel_width)
  }
}
/*
 * Mirror-pad every row of X, processing rows in parallel (parfor loop).
 * Each row holds one image of C channels of size Hin x Win; only the spatial
 * dimensions grow, the channel count stays C.
 *
 * Inputs:
 *  - X: Features to pad, of shape (N, C*Hin*Win)
 *  - N: Number of rows of X
 *  - C: Number of channels per row
 *  - Hin: Image height
 *  - Win: Image width
 *  - pad_size: Number of mirrored pixels added on each side
 *  - channel_width: Number of cells of one padded channel
 *
 * Outputs:
 *  - X_extrapolated: Padded features, of shape (N, C*channel_width)
 */
extrapolate_images_parfor = function(matrix[double] X, int N, int C, int Hin, int Win, int pad_size, int channel_width)
  return (matrix[double] X_extrapolated) {
  padded_cols = C * channel_width
  X_extrapolated = matrix(0, rows=N, cols=padded_cols)
  # Every iteration writes only its own row, so dependency checking is disabled.
  parfor (n in 1:N, check=0) {
    source_row = X[n, ]
    X_extrapolated[n, ] = extrapolate_image(source_row, C, Hin, Win, pad_size, channel_width)
  }
}
/*
 * Mirror-pad a single image, channel by channel.
 * The image is one row vector with the C channels stored back to back, each
 * channel being Hin*Win cells in row-major order.
 *
 * Inputs:
 *  - img: Image of dimension (1, C*Hin*Win) to pad
 *  - C: Number of channels of the image
 *  - Hin: Height of the image
 *  - Win: Width of the image
 *  - pad_size: Number of mirrored pixels added on each side
 *  - channel_width: Number of cells of one padded channel
 *
 * Outputs:
 *  - img_extrapolated: Padded image, of dimension (1, C*channel_width)
 */
extrapolate_image = function(matrix[double] img, int C, int Hin, int Win, int pad_size, int channel_width)
  return (matrix[double] img_extrapolated) {
  cells_in = Hin * Win
  img_extrapolated = matrix(0, rows=1, cols=C*channel_width)
  for (ch in 1:C) {
    # Offsets of this channel in the input and in the padded output
    in_offset = (ch - 1) * cells_in
    out_offset = (ch - 1) * channel_width
    plane = matrix(img[1, (in_offset + 1):(in_offset + cells_in)], rows=Hin, cols=Win)
    padded_plane = extrapolate_channel(plane, Hin, Win, pad_size, channel_width)
    img_extrapolated[1, (out_offset + 1):(out_offset + channel_width)] = padded_plane
  }
}
/*
 * Pad a single channel with extrapolated data by mirroring its borders.
 * Each side gets pad_size mirrored rows/columns (border pixel included), and each
 * corner is mirrored in both directions. The result is flattened row-major.
 * Dimensions changed from (Hin,Win) to (1,(Hin+2*pad_size)*(Win+2*pad_size)).
 *
 * Inputs:
 *  - img: Channel of dimension (Hin,Win) to pad
 *  - Hin: Height of the channel
 *  - Win: Width of the channel
 *  - pad_size: Number of mirrored pixels added on each side; must satisfy
 *              1 <= pad_size <= min(Hin, Win)
 *  - channel_width: Number of cells of the padded channel,
 *                   must equal (Hin+2*pad_size)*(Win+2*pad_size)
 *
 * Outputs:
 *  - channel_extrapolated: Padded channel flattened to (1, channel_width)
 */
extrapolate_channel = function(matrix[double] img, int Hin, int Win, int pad_size, int channel_width)
  return (matrix[double] channel_extrapolated) {
  # Mirroring reflects existing pixels, so at least pad_size rows and columns are required.
  if (pad_size < 1 | pad_size > Hin | pad_size > Win) {
    stop("extrapolate_channel: pad_size (" + pad_size + ") must be between 1 and min(Hin, Win) = min(" + Hin + ", " + Win + ")")
  }
  bottom_start = Hin - pad_size + 1  # first row of the bottom band
  right_start = Win - pad_size + 1   # first column of the right band
  # Side bands: t(rev(t(.))) flips columns left-right, rev(.) flips rows top-bottom.
  pad_left = t(rev(t(img[, 1:pad_size])))
  pad_right = t(rev(t(img[, right_start:Win])))
  pad_top = rev(img[1:pad_size, ])
  pad_bottom = rev(img[bottom_start:Hin, ])
  # Corners: the already left-right flipped bands, additionally flipped top-bottom.
  pad_top_left = rev(pad_left[1:pad_size, ])
  pad_top_right = rev(pad_right[1:pad_size, ])
  pad_bottom_left = rev(pad_left[bottom_start:Hin, ])
  pad_bottom_right = rev(pad_right[bottom_start:Hin, ])
  # Assemble three column stripes of height Hin+2*pad_size and join them side by side.
  pad_left_full = rbind(pad_top_left, pad_left, pad_bottom_left)
  pad_right_full = rbind(pad_top_right, pad_right, pad_bottom_right)
  pad_center_full = rbind(pad_top, img, pad_bottom)
  modified_channel = cbind(pad_left_full, pad_center_full, pad_right_full)
  channel_extrapolated = matrix(modified_channel, rows=1, cols=channel_width)
}
/*
 * Trains a U-Net model with a parameter server setup.
 *
 * The input matrix, X, has N examples, each represented as a 3D
 * volume unrolled into a single vector. The targets, y, hold for each
 * example a segmentation map of M elements with K classes per element.
 *
 * Inputs:
 *  - X: Input data matrix, of shape (N, C*Hin*Win). The features must already be
 *       padded with mirrored data (see extrapolate).
 *  - y: Target matrix, of shape (N, K*M), matching the softmax output of the network
 *  - X_val: Input validation data matrix, of shape (N_val, C*Hin*Win)
 *  - y_val: Target validation matrix, of shape (N_val, K*M)
 *  - C: Number of input channels
 *  - Hin: Input height (after padding)
 *  - Win: Input width (after padding)
 *  - epochs: Total number of full training loops over the full data set
 *  - workers: Number of parameter server workers (passed to paramserv as k)
 *  - utype: Update type (synchronous, asynchronous, etc.)
 *  - freq: Frequency of weight updates ("BATCH" or "EPOCH")
 *  - batch_size: Batch size
 *  - scheme: Parameter server training scheme
 *  - learning_rate: The learning rate for the SGD with momentum
 *  - seed: Base seed for the initialization of the convolution weights. If -1 (default)
 *          every layer gets a random seed; otherwise layer i is initialized with seed+i.
 *  - M: Size of the segmentation map per class (C times the output height times the
 *       output width of the network)
 *  - K: Number of output categories (for each element of segmentation map)
 *  - he: Homomorphic encryption activated (boolean)
 *  - F1: Number of filters of the top layer of the U-Net model. Default is 64.
 *
 * Outputs:
 *  - model_trained: List returned by paramserv. The model list passed in holds
 *                   W1-W23 (entries 1-23), b1-b23 (24-46), the momentum velocities
 *                   vW1-vW23 (47-69) and vb1-vb23 (70-92).
 */
train_paramserv = function(matrix[double] X, matrix[double] y,
    matrix[double] X_val, matrix[double] y_val,
    int C, int Hin, int Win, int epochs, int workers,
    string utype, string freq, int batch_size, string scheme, double learning_rate,
    int seed = -1, int M, int K, boolean he = FALSE, int F1 = 64)
  return (list[unknown] model_trained) {
  # Define model network constants
  Hf = 3  # convolution filter height
  Wf = 3  # convolution filter width
  conv_stride = 1
  pool_stride = 2
  pool_HWf = 2
  conv_t_HWf = 2
  conv_t_stride = 2
  pad = 0  # no zero padding: the 3x3 convolutions shrink the feature maps
  F2 = F1 * 2
  F3 = F2 * 2
  F4 = F3 * 2
  F5 = F4 * 2
  # Dropout parameters handed to the training forward pass.
  # NOTE(review): with the nn dropout layer, 1.0 keeps every activation -- confirm intent.
  dropProb = 1.0
  dropSeed = -1
  # One seed per layer (23 layers): seed+i for layer i, or -1 (random) for all layers
  lseed = list()
  for (i in 1:23) {
    if (seed == -1) {
      lseed = rbind(lseed, -1)
    } else {
      lseed = rbind(lseed, seed + i)
    }
  }
  # Initialize convolution weights
  # First step
  [W1, b1] = conv2d::init(F1, C, Hf, Wf, seed = as.integer(as.scalar(lseed[1]))) # inputs: (N, C*Hin*Win)
  [W2, b2] = conv2d::init(F1, F1, Hf, Wf, seed = as.integer(as.scalar(lseed[2])))
  # Second step
  [W3, b3] = conv2d::init(F2, F1, Hf, Wf, seed = as.integer(as.scalar(lseed[3])))
  [W4, b4] = conv2d::init(F2, F2, Hf, Wf, seed = as.integer(as.scalar(lseed[4])))
  # Third step
  [W5, b5] = conv2d::init(F3, F2, Hf, Wf, seed = as.integer(as.scalar(lseed[5])))
  [W6, b6] = conv2d::init(F3, F3, Hf, Wf, seed = as.integer(as.scalar(lseed[6])))
  # Fourth step
  [W7, b7] = conv2d::init(F4, F3, Hf, Wf, seed = as.integer(as.scalar(lseed[7])))
  [W8, b8] = conv2d::init(F4, F4, Hf, Wf, seed = as.integer(as.scalar(lseed[8])))
  # Fifth step
  [W9, b9] = conv2d::init(F5, F4, Hf, Wf, seed = as.integer(as.scalar(lseed[9])))
  [W10, b10] = conv2d::init(F5, F5, Hf, Wf, seed = as.integer(as.scalar(lseed[10])))
  # First Up-convolution
  [W11, b11] = conv2d_transpose::init_seed(F4, F5, conv_t_HWf, conv_t_HWf, seed = as.integer(as.scalar(lseed[11])))
  [W12, b12] = conv2d::init(F4, F5, Hf, Wf, seed = as.integer(as.scalar(lseed[12])))
  [W13, b13] = conv2d::init(F4, F4, Hf, Wf, seed = as.integer(as.scalar(lseed[13])))
  # Second Up-convolution
  [W14, b14] = conv2d_transpose::init_seed(F3, F4, conv_t_HWf, conv_t_HWf, seed = as.integer(as.scalar(lseed[14])))
  [W15, b15] = conv2d::init(F3, F4, Hf, Wf, seed = as.integer(as.scalar(lseed[15])))
  [W16, b16] = conv2d::init(F3, F3, Hf, Wf, seed = as.integer(as.scalar(lseed[16])))
  # Third Up-convolution
  [W17, b17] = conv2d_transpose::init_seed(F2, F3, conv_t_HWf, conv_t_HWf, seed = as.integer(as.scalar(lseed[17])))
  [W18, b18] = conv2d::init(F2, F3, Hf, Wf, seed = as.integer(as.scalar(lseed[18])))
  [W19, b19] = conv2d::init(F2, F2, Hf, Wf, seed = as.integer(as.scalar(lseed[19])))
  # Fourth Up-convolution
  [W20, b20] = conv2d_transpose::init_seed(F1, F2, conv_t_HWf, conv_t_HWf, seed = as.integer(as.scalar(lseed[20])))
  [W21, b21] = conv2d::init(F1, F2, Hf, Wf, seed = as.integer(as.scalar(lseed[21])))
  [W22, b22] = conv2d::init(F1, F1, Hf, Wf, seed = as.integer(as.scalar(lseed[22])))
  # Segmentation map: 1x1 convolution producing K*C output channels
  [W23, b23] = conv2d::init(K*C, F1, 1, 1, seed = as.integer(as.scalar(lseed[23])))
  # Initialize SGD with momentum
  vW1 = sgd_momentum::init(W1); vb1 = sgd_momentum::init(b1)
  vW2 = sgd_momentum::init(W2); vb2 = sgd_momentum::init(b2)
  vW3 = sgd_momentum::init(W3); vb3 = sgd_momentum::init(b3)
  vW4 = sgd_momentum::init(W4); vb4 = sgd_momentum::init(b4)
  vW5 = sgd_momentum::init(W5); vb5 = sgd_momentum::init(b5)
  vW6 = sgd_momentum::init(W6); vb6 = sgd_momentum::init(b6)
  vW7 = sgd_momentum::init(W7); vb7 = sgd_momentum::init(b7)
  vW8 = sgd_momentum::init(W8); vb8 = sgd_momentum::init(b8)
  vW9 = sgd_momentum::init(W9); vb9 = sgd_momentum::init(b9)
  vW10 = sgd_momentum::init(W10); vb10 = sgd_momentum::init(b10)
  vW11 = sgd_momentum::init(W11); vb11 = sgd_momentum::init(b11)
  vW12 = sgd_momentum::init(W12); vb12 = sgd_momentum::init(b12)
  vW13 = sgd_momentum::init(W13); vb13 = sgd_momentum::init(b13)
  vW14 = sgd_momentum::init(W14); vb14 = sgd_momentum::init(b14)
  vW15 = sgd_momentum::init(W15); vb15 = sgd_momentum::init(b15)
  vW16 = sgd_momentum::init(W16); vb16 = sgd_momentum::init(b16)
  vW17 = sgd_momentum::init(W17); vb17 = sgd_momentum::init(b17)
  vW18 = sgd_momentum::init(W18); vb18 = sgd_momentum::init(b18)
  vW19 = sgd_momentum::init(W19); vb19 = sgd_momentum::init(b19)
  vW20 = sgd_momentum::init(W20); vb20 = sgd_momentum::init(b20)
  vW21 = sgd_momentum::init(W21); vb21 = sgd_momentum::init(b21)
  vW22 = sgd_momentum::init(W22); vb22 = sgd_momentum::init(b22)
  vW23 = sgd_momentum::init(W23); vb23 = sgd_momentum::init(b23)
  # Define optimizer constants
  mu = 0.9  # momentum
  decay = 0.95  # learning rate decay constant
  # Regularization
  lambda = 5e-04
  # Create the model list; gradients/predict rely on this exact ordering
  model_list = list(
    W1, W2, W3, W4, W5, W6, W7, W8, W9, W10, W11, W12, W13, W14, W15, W16, W17, W18, W19, W20, W21, W22, W23,
    b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15, b16, b17, b18, b19, b20, b21, b22, b23,
    vW1, vW2, vW3, vW4, vW5, vW6, vW7, vW8, vW9, vW10, vW11, vW12, vW13, vW14, vW15, vW16, vW17, vW18, vW19, vW20, vW21, vW22, vW23,
    vb1, vb2, vb3, vb4, vb5, vb6, vb7, vb8, vb9, vb10, vb11, vb12, vb13, vb14, vb15, vb16, vb17, vb18, vb19, vb20, vb21, vb22, vb23)
  # Create the hyper parameter list (read by name in gradients)
  params = list(
    learning_rate=learning_rate, mu=mu, decay=decay, M=M, K=K, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf,
    conv_stride=conv_stride, pool_stride=pool_stride, pool_HWf=pool_HWf, conv_t_HWf=conv_t_HWf, conv_t_stride=conv_t_stride,
    pad=pad, lambda=lambda, F1=F1, F2=F2, F3=F3, F4=F4, F5=F5, dropProb=dropProb, dropSeed=dropSeed)
  # Use paramserv function
  model_trained = paramserv(model=model_list, features=X, labels=y, val_features=X_val, val_labels=y_val,
    upd="./scripts/nn/examples/u-net.dml::gradients",
    agg="./scripts/nn/examples/u-net.dml::aggregation",
    scheme=scheme, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
    k=workers, hyperparams=params, checkpointing="NONE", he=he, modelAvg=TRUE)
}
/*
 * Forward pass of the U-Net model on X, evaluated in mini-batches.
 *
 * Inputs:
 *  - X: Input features, of shape (N, C*Hin*Win). The features must already be
 *       padded with mirrored data (see extrapolate, which grows each spatial
 *       dimension by 184).
 *  - C: Number of input channels
 *  - Hin: Input height (after padding)
 *  - Win: Input width (after padding)
 *  - batch_size: Number of rows of X processed per forward pass
 *  - model: List of model parameters: weights W1-W23 in entries 1-23 and biases
 *           b1-b23 in entries 24-46. Any further entries (e.g. the momentum
 *           velocities in the list built by train_paramserv) are ignored.
 *  - M: Size of the segmentation map per class (C times the output height times
 *       the output width of the network)
 *  - K: Number of output categories (for each element of segmentation map)
 *  - F1: Number of filters of the top layer of the U-Net model. Default is 64.
 *
 * Output:
 *  - probs: Segmentation map probabilities of shape (N, K*M) generated by the
 *           forward pass of the U-Net model
 */
predict = function(matrix[double] X, int C, int Hin, int Win, int batch_size, list[unknown] model, int M, int K, int F1 = 64)
  return (matrix[double] probs) {
  # Unpack the weights (entries 1-23) and biases (entries 24-46)
  W1 = as.matrix(model[1])
  W2 = as.matrix(model[2])
  W3 = as.matrix(model[3])
  W4 = as.matrix(model[4])
  W5 = as.matrix(model[5])
  W6 = as.matrix(model[6])
  W7 = as.matrix(model[7])
  W8 = as.matrix(model[8])
  W9 = as.matrix(model[9])
  W10 = as.matrix(model[10])
  W11 = as.matrix(model[11])
  W12 = as.matrix(model[12])
  W13 = as.matrix(model[13])
  W14 = as.matrix(model[14])
  W15 = as.matrix(model[15])
  W16 = as.matrix(model[16])
  W17 = as.matrix(model[17])
  W18 = as.matrix(model[18])
  W19 = as.matrix(model[19])
  W20 = as.matrix(model[20])
  W21 = as.matrix(model[21])
  W22 = as.matrix(model[22])
  W23 = as.matrix(model[23])
  b1 = as.matrix(model[24])
  b2 = as.matrix(model[25])
  b3 = as.matrix(model[26])
  b4 = as.matrix(model[27])
  b5 = as.matrix(model[28])
  b6 = as.matrix(model[29])
  b7 = as.matrix(model[30])
  b8 = as.matrix(model[31])
  b9 = as.matrix(model[32])
  b10 = as.matrix(model[33])
  b11 = as.matrix(model[34])
  b12 = as.matrix(model[35])
  b13 = as.matrix(model[36])
  b14 = as.matrix(model[37])
  b15 = as.matrix(model[38])
  b16 = as.matrix(model[39])
  b17 = as.matrix(model[40])
  b18 = as.matrix(model[41])
  b19 = as.matrix(model[42])
  b20 = as.matrix(model[43])
  b21 = as.matrix(model[44])
  b22 = as.matrix(model[45])
  b23 = as.matrix(model[46])
  N = nrow(X)  # Number of inputs
  # Network constants; must match the values used by train_paramserv
  Hf = 3  # convolution filter height
  Wf = 3  # convolution filter width
  conv_stride = 1
  pool_stride = 2
  pool_HWf = 2
  conv_t_HWf = 2
  conv_t_stride = 2
  pad = 0  # no zero padding: the 3x3 convolutions shrink the feature maps
  F2 = F1 * 2
  F3 = F2 * 2
  F4 = F3 * 2
  F5 = F4 * 2
  # Compute predictions over mini-batches
  probs = matrix(0, rows=N, cols=K*M)
  iters = ceil(N / batch_size)
  for (i in 1:iters) {
    # Get next batch
    beg = ((i-1) * batch_size) %% N + 1
    end = min(N, beg + batch_size - 1)
    X_batch = X[beg:end,]
    # Down-Convolution
    [outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, W1, b1, C, Hin, Win, Hf, Wf, conv_stride, conv_stride, pad, pad)
    outr1 = relu::forward(outc1)
    [outc2, Houtc2, Woutc2] = conv2d::forward(outr1, W2, b2, F1, Houtc1, Woutc1, Hf, Wf, conv_stride, conv_stride, pad, pad)
    outr2 = relu::forward(outc2)
    [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr2, F1, Houtc2, Woutc2, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0)
    [outc3, Houtc3, Woutc3] = conv2d::forward(outp1, W3, b3, F1, Houtp1, Woutp1, Hf, Wf, conv_stride, conv_stride, pad, pad)
    outr3 = relu::forward(outc3)
    [outc4, Houtc4, Woutc4] = conv2d::forward(outr3, W4, b4, F2, Houtc3, Woutc3, Hf, Wf, conv_stride, conv_stride, pad, pad)
    outr4 = relu::forward(outc4)
    [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr4, F2, Houtc4, Woutc4, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0)
    [outc5, Houtc5, Woutc5] = conv2d::forward(outp2, W5, b5, F2, Houtp2, Woutp2, Hf, Wf, conv_stride, conv_stride, pad, pad)
    outr5 = relu::forward(outc5)
    [outc6, Houtc6, Woutc6] = conv2d::forward(outr5, W6, b6, F3, Houtc5, Woutc5, Hf, Wf, conv_stride, conv_stride, pad, pad)
    outr6 = relu::forward(outc6)
    [outp3, Houtp3, Woutp3] = max_pool2d::forward(outr6, F3, Houtc6, Woutc6, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0)
    [outc7, Houtc7, Woutc7] = conv2d::forward(outp3, W7, b7, F3, Houtp3, Woutp3, Hf, Wf, conv_stride, conv_stride, pad, pad)
    outr7 = relu::forward(outc7)
    [outc8, Houtc8, Woutc8] = conv2d::forward(outr7, W8, b8, F4, Houtc7, Woutc7, Hf, Wf, conv_stride, conv_stride, pad, pad)
    outr8 = relu::forward(outc8)
    [outp4, Houtp4, Woutp4] = max_pool2d::forward(outr8, F4, Houtc8, Woutc8, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0)
    # Bottom (dropout is a training-time regularizer and is not applied at inference)
    [outc9, Houtc9, Woutc9] = conv2d::forward(outp4, W9, b9, F4, Houtp4, Woutp4, Hf, Wf, conv_stride, conv_stride, pad, pad)
    outr9 = relu::forward(outc9)
    [outc10, Houtc10, Woutc10] = conv2d::forward(outr9, W10, b10, F5, Houtc9, Woutc9, Hf, Wf, conv_stride, conv_stride, pad, pad)
    outr10 = relu::forward(outc10)
    [outc11, Houtc11, Woutc11] = conv2d_transpose::forward(outr10, W11, b11, F5, Houtc10, Woutc10, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad, 0, 0)
    # Up-Convolution: concatenate cropped skip connection with the up-sampled map
    outConcat1 = cbind(img_utils::crop_channel(outr8, Houtc8, Woutc8, Houtc11, Woutc11, F4), outc11)
    [outc12, Houtc12, Woutc12] = conv2d::forward(outConcat1, W12, b12, F5, Houtc11, Woutc11, Hf, Wf, conv_stride, conv_stride, pad, pad)
    outr11 = relu::forward(outc12)
    [outc13, Houtc13, Woutc13] = conv2d::forward(outr11, W13, b13, F4, Houtc12, Woutc12, Hf, Wf, conv_stride, conv_stride, pad, pad)
    outr12 = relu::forward(outc13)
    [outc14, Houtc14, Woutc14] = conv2d_transpose::forward(outr12, W14, b14, F4, Houtc13, Woutc13, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad, 0, 0)
    outConcat2 = cbind(img_utils::crop_channel(outr6, Houtc6, Woutc6, Houtc14, Woutc14, F3), outc14)
    [outc15, Houtc15, Woutc15] = conv2d::forward(outConcat2, W15, b15, F4, Houtc14, Woutc14, Hf, Wf, conv_stride, conv_stride, pad, pad)
    outr13 = relu::forward(outc15)
    [outc16, Houtc16, Woutc16] = conv2d::forward(outr13, W16, b16, F3, Houtc15, Woutc15, Hf, Wf, conv_stride, conv_stride, pad, pad)
    outr14 = relu::forward(outc16)
    [outc17, Houtc17, Woutc17] = conv2d_transpose::forward(outr14, W17, b17, F3, Houtc16, Woutc16, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad, 0, 0)
    outConcat3 = cbind(img_utils::crop_channel(outr4, Houtc4, Woutc4, Houtc17, Woutc17, F2), outc17)
    [outc18, Houtc18, Woutc18] = conv2d::forward(outConcat3, W18, b18, F3, Houtc17, Woutc17, Hf, Wf, conv_stride, conv_stride, pad, pad)
    outr15 = relu::forward(outc18)
    [outc19, Houtc19, Woutc19] = conv2d::forward(outr15, W19, b19, F2, Houtc18, Woutc18, Hf, Wf, conv_stride, conv_stride, pad, pad)
    outr16 = relu::forward(outc19)
    [outc20, Houtc20, Woutc20] = conv2d_transpose::forward(outr16, W20, b20, F2, Houtc19, Woutc19, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad, 0, 0)
    outConcat4 = cbind(img_utils::crop_channel(outr2, Houtc2, Woutc2, Houtc20, Woutc20, F1), outc20)
    [outc21, Houtc21, Woutc21] = conv2d::forward(outConcat4, W21, b21, F2, Houtc20, Woutc20, Hf, Wf, conv_stride, conv_stride, pad, pad)
    outr17 = relu::forward(outc21)
    [outc22, Houtc22, Woutc22] = conv2d::forward(outr17, W22, b22, F1, Houtc21, Woutc21, Hf, Wf, conv_stride, conv_stride, pad, pad)
    outr18 = relu::forward(outc22)
    # This last conv2d needs to create the segmentation map (1x1 filter):
    [outc23, Houtc23, Woutc23] = conv2d::forward(outr18, W23, b23, F1, Houtc22, Woutc22, 1, 1, conv_stride, conv_stride, pad, pad)
    # Store predictions
    probs[beg:end,] = softmax2d::forward(outc23, K)
  }
}
/*
* Forward and backward pass of U-Net with gradients returned.
*
* Inputs:
* - model: List of model weights
* - hyperparams: List of hyper parameters containing:
* - (scalar[integer]) C: Number of input channels
* - (scalar[integer]) Hin: Input height
* - (scalar[integer]) Win: Input width
* - (scalar[integer]) Hf: Filter height
* - (scalar[integer]) Wf: Filter width
* - (scalar[integer]) pool_stride: Stride of the max pool operations
* - (scalar[integer]) pool_HWf: Filter height and width of the max pool operation
* - (scalar[integer]) conv_stride: Stride of all convolutions
* - (scalar[integer]) conv_t_HWf: Filter height and width of the transpose convolutions
* - (scalar[integer]) conv_t_stride: Stride of the transpose convolutions
* - (scalar[integer]) pad: Padding of all convolutions and transpose convolutions
* - (scalar[double]) lambda: Regularization strength
* - (scalar[integer]) F1, F2, F3, F4, F5: Number of filters of the convolutions in the five layers
* - (scalar[double]) dropProb: Dropout probability
* - (scalar[integer]) dropSeed: Dropout seed
* - (scalar[integer]) M: Size of the segmentation map (C*(Hin-pad)*(Win-pad))
* - (scalar[integer]) K: Number of output categories (for each element of segmentation map)
* - features: Features of size C*Hin*Win. The features need to be padded with mirrored data.
* The input feature size should result in an output size of the U-Net equal to the label size.
* See extrapolate function for how to pad the features by extrapolating.
* - labels: Labels of size C * (Hin-pad) * (Win-pad) representing a segmentation map of the input features.
* Output:
* - gradients: List of gradients
*/
gradients = function(list[unknown] model,
list[unknown] hyperparams,
matrix[double] features,
matrix[double] labels)
return (list[unknown] gradients) {
C = as.integer(as.scalar(hyperparams["C"]))
Hin = as.integer(as.scalar(hyperparams["Hin"]))
Win = as.integer(as.scalar(hyperparams["Win"]))
Hf = as.integer(as.scalar(hyperparams["Hf"]))
Wf = as.integer(as.scalar(hyperparams["Wf"]))
pool_stride = as.integer(as.scalar(hyperparams["pool_stride"]))
pool_HWf = as.integer(as.scalar(hyperparams["pool_HWf"]))
conv_stride = as.integer(as.scalar(hyperparams["conv_stride"]))
conv_t_HWf = as.integer(as.scalar(hyperparams["conv_t_HWf"]))
conv_t_stride = as.integer(as.scalar(hyperparams["conv_t_stride"]))
pad = as.integer(as.scalar(hyperparams["pad"]))
lambda = as.double(as.scalar(hyperparams["lambda"]))
F1 = as.integer(as.scalar(hyperparams["F1"]))
F2 = as.integer(as.scalar(hyperparams["F2"]))
F3 = as.integer(as.scalar(hyperparams["F3"]))
F4 = as.integer(as.scalar(hyperparams["F4"]))
F5 = as.integer(as.scalar(hyperparams["F5"]))
dropProb = as.double(as.scalar(hyperparams["dropProb"]))
dropSeed = as.integer(as.scalar(hyperparams["dropSeed"]))
M = as.integer(as.scalar(hyperparams["M"]))
K = as.integer(as.scalar(hyperparams["K"]))
W1 = as.matrix(model[1])
W2 = as.matrix(model[2])
W3 = as.matrix(model[3])
W4 = as.matrix(model[4])
W5 = as.matrix(model[5])
W6 = as.matrix(model[6])
W7 = as.matrix(model[7])
W8 = as.matrix(model[8])
W9 = as.matrix(model[9])
W10 = as.matrix(model[10])
W11 = as.matrix(model[11])
W12 = as.matrix(model[12])
W13 = as.matrix(model[13])
W14 = as.matrix(model[14])
W15 = as.matrix(model[15])
W16 = as.matrix(model[16])
W17 = as.matrix(model[17])
W18 = as.matrix(model[18])
W19 = as.matrix(model[19])
W20 = as.matrix(model[20])
W21 = as.matrix(model[21])
W22 = as.matrix(model[22])
W23 = as.matrix(model[23])
b1 = as.matrix(model[24])
b2 = as.matrix(model[25])
b3 = as.matrix(model[26])
b4 = as.matrix(model[27])
b5 = as.matrix(model[28])
b6 = as.matrix(model[29])
b7 = as.matrix(model[30])
b8 = as.matrix(model[31])
b9 = as.matrix(model[32])
b10 = as.matrix(model[33])
b11 = as.matrix(model[34])
b12 = as.matrix(model[35])
b13 = as.matrix(model[36])
b14 = as.matrix(model[37])
b15 = as.matrix(model[38])
b16 = as.matrix(model[39])
b17 = as.matrix(model[40])
b18 = as.matrix(model[41])
b19 = as.matrix(model[42])
b20 = as.matrix(model[43])
b21 = as.matrix(model[44])
b22 = as.matrix(model[45])
b23 = as.matrix(model[46])
# Down-Convolution
[outc1, Houtc1, Woutc1] = conv2d::forward(features, W1, b1, C, Hin, Win, Hf, Wf, conv_stride, conv_stride, pad, pad)
outr1 = relu::forward(outc1)
[outc2, Houtc2, Woutc2] = conv2d::forward(outr1, W2, b2, F1, Houtc1, Woutc1, Hf, Wf, conv_stride, conv_stride, pad, pad)
outr2 = relu::forward(outc2)
[outp1, Houtp1, Woutp1] = max_pool2d::forward(outr2, F1, Houtc2, Woutc2, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0)
[outc3, Houtc3, Woutc3] = conv2d::forward(outp1, W3, b3, F1, Houtp1, Woutp1, Hf, Wf, conv_stride, conv_stride, pad, pad)
outr3 = relu::forward(outc3)
[outc4, Houtc4, Woutc4] = conv2d::forward(outr3, W4, b4, F2, Houtc3, Woutc3, Hf, Wf, conv_stride, conv_stride, pad, pad)
outr4 = relu::forward(outc4)
[outp2, Houtp2, Woutp2] = max_pool2d::forward(outr4, F2, Houtc4, Woutc4, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0)
[outc5, Houtc5, Woutc5] = conv2d::forward(outp2, W5, b5, F2, Houtp2, Woutp2, Hf, Wf, conv_stride, conv_stride, pad, pad)
outr5 = relu::forward(outc5)
[outc6, Houtc6, Woutc6] = conv2d::forward(outr5, W6, b6, F3, Houtc5, Woutc5, Hf, Wf, conv_stride, conv_stride, pad, pad)
outr6 = relu::forward(outc6)
[outp3, Houtp3, Woutp3] = max_pool2d::forward(outr6, F3, Houtc6, Woutc6, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0)
[outc7, Houtc7, Woutc7] = conv2d::forward(outp3, W7, b7, F3, Houtp3, Woutp3, Hf, Wf, conv_stride, conv_stride, pad, pad)
outr7 = relu::forward(outc7)
[outc8, Houtc8, Woutc8] = conv2d::forward(outr7, W8, b8, F4, Houtc7, Woutc7, Hf, Wf, conv_stride, conv_stride, pad, pad)
outr8 = relu::forward(outc8)
[outp4, Houtp4, Woutp4] = max_pool2d::forward(outr8, F4, Houtc8, Woutc8, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0)
[outd1, mask1] = dropout::forward(outp4, dropProb, dropSeed)
# Bottom
[outc9, Houtc9, Woutc9] = conv2d::forward(outd1, W9, b9, F4, Houtp4, Woutp4, Hf, Wf, conv_stride, conv_stride, pad, pad)
outr9 = relu::forward(outc9)
[outc10, Houtc10, Woutc10] = conv2d::forward(outr9, W10, b10, F5, Houtc9, Woutc9, Hf, Wf, conv_stride, conv_stride, pad, pad)
outr10 = relu::forward(outc10)
[outc11, Houtc11, Woutc11] = conv2d_transpose::forward(outr10, W11, b11, F5, Houtc10, Woutc10, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad, 0, 0)
# Up-Convolution
outConcat1 = cbind(img_utils::crop_channel(outr8, Houtc8, Woutc8, Houtc11, Woutc11, F4),outc11)
[outc12, Houtc12, Woutc12] = conv2d::forward(outConcat1, W12, b12, F5, Houtc11, Woutc11, Hf, Wf, conv_stride, conv_stride, pad, pad)
outr11 = relu::forward(outc12)
[outc13, Houtc13, Woutc13] = conv2d::forward(outr11, W13, b13, F4, Houtc12, Woutc12, Hf, Wf, conv_stride, conv_stride, pad, pad)
outr12 = relu::forward(outc13)
[outc14, Houtc14, Woutc14] = conv2d_transpose::forward(outr12, W14, b14, F4, Houtc13, Woutc13, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad, 0, 0)
outConcat2 = cbind(img_utils::crop_channel(outr6, Houtc6, Woutc6, Houtc14, Woutc14, F3),outc14)
[outc15, Houtc15, Woutc15] = conv2d::forward(outConcat2, W15, b15, F4, Houtc14, Woutc14, Hf, Wf, conv_stride, conv_stride, pad, pad)
outr13 = relu::forward(outc15)
[outc16, Houtc16, Woutc16] = conv2d::forward(outr13, W16, b16, F3, Houtc15, Woutc15, Hf, Wf, conv_stride, conv_stride, pad, pad)
outr14 = relu::forward(outc16)
[outc17, Houtc17, Woutc17] = conv2d_transpose::forward(outr14, W17, b17, F3, Houtc16, Woutc16, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad, 0, 0)
outConcat3 = cbind(img_utils::crop_channel(outr4, Houtc4, Woutc4, Houtc17, Woutc17, F2), outc17)
[outc18, Houtc18, Woutc18] = conv2d::forward(outConcat3, W18, b18, F3, Houtc17, Woutc17, Hf, Wf, conv_stride, conv_stride, pad, pad)
outr15 = relu::forward(outc18)
[outc19, Houtc19, Woutc19] = conv2d::forward(outr15, W19, b19, F2, Houtc18, Woutc18, Hf, Wf, conv_stride, conv_stride, pad, pad)
outr16 = relu::forward(outc19)
[outc20, Houtc20, Woutc20] = conv2d_transpose::forward(outr16, W20, b20, F2, Houtc19, Woutc19, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad, 0, 0)
outConcat4 = cbind(img_utils::crop_channel(outr2, Houtc2, Woutc2, Houtc20, Woutc20, F1), outc20)
[outc21, Houtc21, Woutc21] = conv2d::forward(outConcat4, W21, b21, F2, Houtc20, Woutc20, Hf, Wf, conv_stride, conv_stride, pad, pad)
outr17 = relu::forward(outc21)
[outc22, Houtc22, Woutc22] = conv2d::forward(outr17, W22, b22, F1, Houtc21, Woutc21, Hf, Wf, conv_stride, conv_stride, pad, pad)
outr18 = relu::forward(outc22)
# This last conv2d needs to create the segmentation map (1x1 filter):
[outc23, Houtc23, Woutc23] = conv2d::forward(outr18, W23, b23, F1, Houtc22, Woutc22, 1, 1, conv_stride, conv_stride, pad, pad)
probs = softmax2d::forward(outc23, K)
# Compute loss & accuracy for training data
loss = cross_entropy_loss::forward(probs, labels)
accuracy = mean(rowIndexMax(probs) == rowIndexMax(labels))
# print("[+] Completed forward pass on batch: train loss: " + loss + ", train accuracy: " + accuracy)
# Compute data backward pass
## loss
dprobs = cross_entropy_loss::backward(probs, labels)
doutc23 = softmax2d::backward(dprobs, outc23, K)
# Up-Convolution
# conv2d parameters: (previous_gradient, output height, output width, input to original layer, layer weight, layer bias, layer input channel number, input height, input width, filter height, filter width, stride height, stride width, pad height, pad width)
[doutc22, dW23, db23] = conv2d::backward(doutc23, Houtc23, Woutc23, outr18, W23, b23, F1, Houtc22, Woutc22, 1, 1, conv_stride, conv_stride, pad, pad)
doutr18 = relu::backward(doutc22, outc22)
[doutc21, dW22, db22] = conv2d::backward(doutr18, Houtc22, Woutc22, outr17, W22, b22, F1, Houtc21, Woutc21, Hf, Wf, conv_stride, conv_stride, pad, pad)
doutr17 = relu::backward(doutc21, outc21)
[doutc20, dW21, db21] = conv2d::backward(doutr17, Houtc21, Woutc21, outConcat4, W21, b21, F2, Houtc20, Woutc20, Hf, Wf, conv_stride, conv_stride, pad, pad)
doutc20_cropped = doutc20[,(F1*Houtc20*Woutc20+1):(ncol(doutc20))] #Removing half of the gradients since they are related to a different layer.
[doutc19, dW20, db20] = conv2d_transpose::backward(doutc20_cropped, Houtc20, Woutc20, outr16, W20, b20, F2, Houtc19, Woutc19, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad)
doutr16 = relu::backward(doutc19, outc19)
[doutc18, dW19, db19] = conv2d::backward(doutr16, Houtc19, Woutc19, outr15, W19, b19, F2, Houtc18, Woutc18, Hf, Wf, conv_stride, conv_stride, pad, pad)
doutr15 = relu::backward(doutc18, outc18)
[doutc17, dW18, db18] = conv2d::backward(doutr15, Houtc18, Woutc18, outConcat3, W18, b18, F3, Houtc17, Woutc17, Hf, Wf, conv_stride, conv_stride, pad, pad)
doutc17_cropped = doutc17[,(F2*Houtc17*Woutc17+1):ncol(doutc17)]
[doutc16, dW17, db17] = conv2d_transpose::backward(doutc17_cropped, Houtc17, Woutc17, outr14, W17, b17, F3, Houtc16, Woutc16, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad)
doutr14 = relu::backward(doutc16, outc16)
[doutc15, dW16, db16] = conv2d::backward(doutr14, Houtc16, Woutc16, outr13, W16, b16, F3, Houtc15, Woutc15, Hf, Wf, conv_stride, conv_stride, pad, pad)
doutr13 = relu::backward(doutc15, outc15)
[doutc14, dW15, db15] = conv2d::backward(doutr13, Houtc15, Woutc15, outConcat2, W15, b15, F4, Houtc14, Woutc14, Hf, Wf, conv_stride, conv_stride, pad, pad)
doutc14_cropped = doutc14[,(F3*Houtc14*Woutc14+1):ncol(doutc14)]
[doutc13, dW14, db14] = conv2d_transpose::backward(doutc14_cropped, Houtc14, Woutc14, outr12, W14, b14, F4, Houtc13, Woutc13, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad)
doutr12 = relu::backward(doutc13, outc13)
[doutc12, dW13, db13] = conv2d::backward(doutr12, Houtc13, Woutc13, outr11, W13, b13, F4, Houtc12, Woutc12, Hf, Wf, conv_stride, conv_stride, pad, pad)
doutr11 = relu::backward(doutc12, outc12)
[doutc11, dW12, db12] = conv2d::backward(doutr11, Houtc12, Woutc12, outConcat1, W12, b12, F5, Houtc11, Woutc11, Hf, Wf, conv_stride, conv_stride, pad, pad)
# Bottom
doutc11_cropped = doutc11[,(F4*Houtc11*Woutc11+1):ncol(doutc11)]
[doutc10, dW11, db11] = conv2d_transpose::backward(doutc11_cropped, Houtc11, Woutc11, outr10, W11, b11, F5, Houtc10, Woutc10, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad)
doutr10 = relu::backward(doutc10, outc10)
[doutc9, dW10, db10] = conv2d::backward(doutr10, Houtc10, Woutc10, outr9, W10, b10, F5, Houtc9, Woutc9, Hf, Wf, conv_stride, conv_stride, pad, pad)
doutr9 = relu::backward(doutc9, outc9)
[doutc8, dW9, db9] = conv2d::backward(doutr9, Houtc9, Woutc9, outd1, W9, b9, F4, Houtp4, Woutp4, Hf, Wf, conv_stride, conv_stride, pad, pad)
# Down-Convolution
doutd1 = dropout::backward(doutc8, outp4, dropProb, mask1)
doutp4 = max_pool2d::backward(doutd1, Houtp4, Woutp4, outr8, F4, Houtc8, Woutc8, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0)
doutr8 = relu::backward(doutp4, outc8)
[doutc7, dW8, db8] = conv2d::backward(doutr8, Houtc8, Woutc8, outr7, W8, b8, F4, Houtc7, Woutc7, Hf, Wf, conv_stride, conv_stride, pad, pad)
doutr7 = relu::backward(doutc7, outc7)
[doutc6, dW7, db7] = conv2d::backward(doutr7, Houtc7, Woutc7, outp3, W7, b7, F3, Houtp3, Woutp3, Hf, Wf, conv_stride, conv_stride, pad, pad)
doutp3 = max_pool2d::backward(doutc6, Houtp3, Woutp3, outr6, F3, Houtc6, Woutc6, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0)
doutr6 = relu::backward(doutp3, outc6)
[doutc5, dW6, db6] = conv2d::backward(doutr6, Houtc6, Woutc6, outr5, W6, b6, F3, Houtc5, Woutc5, Hf, Wf, conv_stride, conv_stride, pad, pad)
doutr5 = relu::backward(doutc5, outc5)
[doutc4, dW5, db5] = conv2d::backward(doutr5, Houtc5, Woutc5, outp2, W5, b5, F2, Houtp2, Woutp2, Hf, Wf, conv_stride, conv_stride, pad, pad)
doutp2 = max_pool2d::backward(doutc4, Houtp2, Woutp2, outr4, F2, Houtc4, Woutc4, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0)
doutr4 = relu::backward(doutp2, outc4)
[doutc3, dW4, db4] = conv2d::backward(doutr4, Houtc4, Woutc4, outr3, W4, b4, F2, Houtc3, Woutc3, Hf, Wf, conv_stride, conv_stride, pad, pad)
doutr3 = relu::backward(doutc3, outc3)
[doutc2, dW3, db3] = conv2d::backward(doutr3, Houtc3, Woutc3, outp1, W3, b3, F1, Houtp1, Woutp1, Hf, Wf, conv_stride, conv_stride, pad, pad)
doutp1 = max_pool2d::backward(doutc2, Houtp1, Woutp1, outr2, F1, Houtc2, Woutc2, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0)
doutr2 = relu::backward(doutp1, outc2)
[doutc1, dW2, db2] = conv2d::backward(doutr2, Houtc2, Woutc2, outr1, W2, b2, F1, Houtc1, Woutc1, Hf, Wf, conv_stride, conv_stride, pad, pad)
doutr1 = relu::backward(doutc1,outc1)
[dx_batch, dW1, db1] = conv2d::backward(doutr1, Houtc1, Woutc1, features, W1, b1, C, Hin, Win, Hf, Wf, conv_stride, conv_stride, pad, pad)
# Compute regularization backward pass
dW1_reg = l2_reg::backward(W1, lambda)
dW2_reg = l2_reg::backward(W2, lambda)
dW3_reg = l2_reg::backward(W3, lambda)
dW4_reg = l2_reg::backward(W4, lambda)
dW5_reg = l2_reg::backward(W5, lambda)
dW6_reg = l2_reg::backward(W6, lambda)
dW7_reg = l2_reg::backward(W7, lambda)
dW8_reg = l2_reg::backward(W8, lambda)
dW9_reg = l2_reg::backward(W9, lambda)
dW10_reg = l2_reg::backward(W10, lambda)
dW11_reg = l2_reg::backward(W11, lambda)
dW12_reg = l2_reg::backward(W12, lambda)
dW13_reg = l2_reg::backward(W13, lambda)
dW14_reg = l2_reg::backward(W14, lambda)
dW15_reg = l2_reg::backward(W15, lambda)
dW16_reg = l2_reg::backward(W16, lambda)
dW17_reg = l2_reg::backward(W17, lambda)
dW18_reg = l2_reg::backward(W18, lambda)
dW19_reg = l2_reg::backward(W19, lambda)
dW20_reg = l2_reg::backward(W20, lambda)
dW21_reg = l2_reg::backward(W21, lambda)
dW22_reg = l2_reg::backward(W22, lambda)
dW23_reg = l2_reg::backward(W23, lambda)
dW1 = dW1 + dW1_reg
dW2 = dW2 + dW2_reg
dW3 = dW3 + dW3_reg
dW4 = dW4 + dW4_reg
dW5 = dW5 + dW5_reg
dW6 = dW6 + dW6_reg
dW7 = dW7 + dW7_reg
dW8 = dW8 + dW8_reg
dW9 = dW9 + dW9_reg
dW10 = dW10 + dW10_reg
dW11 = dW11 + dW11_reg
dW12 = dW12 + dW12_reg
dW13 = dW13 + dW13_reg
dW14 = dW14 + dW14_reg
dW15 = dW15 + dW15_reg
dW16 = dW16 + dW16_reg
dW17 = dW17 + dW17_reg
dW18 = dW18 + dW18_reg
dW19 = dW19 + dW19_reg
dW20 = dW20 + dW20_reg
dW21 = dW21 + dW21_reg
dW22 = dW22 + dW22_reg
dW23 = dW23 + dW23_reg
gradients = list(
dW1, dW2, dW3, dW4, dW5, dW6, dW7, dW8, dW9, dW10, dW11, dW12, dW13, dW14, dW15, dW16, dW17, dW18, dW19, dW20, dW21, dW22, dW23,
db1, db2, db3, db4, db5, db6, db7, db8, db9, db10, db11, db12, db13, db14, db15, db16, db17, db18, db19, db20, db21, db22, db23
)
}
/*
 * Applies one SGD-with-momentum step to every U-Net parameter and returns the
 * updated parameters together with their updated momentum velocities.
 *
 * Flat list layout (23 layers, index ranges are 1-based and inclusive):
 *   model:     [1..23]  weights W1..W23
 *              [24..46] biases b1..b23
 *              [47..69] weight velocities vW1..vW23
 *              [70..92] bias velocities vb1..vb23
 *   gradients: [1..23]  dW1..dW23
 *              [24..46] db1..db23
 *
 * Inputs:
 * - model: List of weights, biases and momentum velocities (layout above)
 * - hyperparams: List of hyper parameters containing (scalar[double]) learning_rate and (scalar[double]) mu
 * - gradients: List of weight and bias gradients (layout above)
 *
 * Outputs:
 * - model_result: List of updated weights, biases and velocities, in the same layout as model
 */
aggregation = function(list[unknown] model,
    list[unknown] hyperparams,
    list[unknown] gradients)
    return (list[unknown] model_result) {
  lr = as.double(as.scalar(hyperparams["learning_rate"]))
  mu = as.double(as.scalar(hyperparams["mu"]))

  # Weights: parameter at model[i], gradient at gradients[i], velocity at model[46 + i].
  [W1, vW1] = sgd_momentum::update(as.matrix(model[1]), as.matrix(gradients[1]), lr, mu, as.matrix(model[47]))
  [W2, vW2] = sgd_momentum::update(as.matrix(model[2]), as.matrix(gradients[2]), lr, mu, as.matrix(model[48]))
  [W3, vW3] = sgd_momentum::update(as.matrix(model[3]), as.matrix(gradients[3]), lr, mu, as.matrix(model[49]))
  [W4, vW4] = sgd_momentum::update(as.matrix(model[4]), as.matrix(gradients[4]), lr, mu, as.matrix(model[50]))
  [W5, vW5] = sgd_momentum::update(as.matrix(model[5]), as.matrix(gradients[5]), lr, mu, as.matrix(model[51]))
  [W6, vW6] = sgd_momentum::update(as.matrix(model[6]), as.matrix(gradients[6]), lr, mu, as.matrix(model[52]))
  [W7, vW7] = sgd_momentum::update(as.matrix(model[7]), as.matrix(gradients[7]), lr, mu, as.matrix(model[53]))
  [W8, vW8] = sgd_momentum::update(as.matrix(model[8]), as.matrix(gradients[8]), lr, mu, as.matrix(model[54]))
  [W9, vW9] = sgd_momentum::update(as.matrix(model[9]), as.matrix(gradients[9]), lr, mu, as.matrix(model[55]))
  [W10, vW10] = sgd_momentum::update(as.matrix(model[10]), as.matrix(gradients[10]), lr, mu, as.matrix(model[56]))
  [W11, vW11] = sgd_momentum::update(as.matrix(model[11]), as.matrix(gradients[11]), lr, mu, as.matrix(model[57]))
  [W12, vW12] = sgd_momentum::update(as.matrix(model[12]), as.matrix(gradients[12]), lr, mu, as.matrix(model[58]))
  [W13, vW13] = sgd_momentum::update(as.matrix(model[13]), as.matrix(gradients[13]), lr, mu, as.matrix(model[59]))
  [W14, vW14] = sgd_momentum::update(as.matrix(model[14]), as.matrix(gradients[14]), lr, mu, as.matrix(model[60]))
  [W15, vW15] = sgd_momentum::update(as.matrix(model[15]), as.matrix(gradients[15]), lr, mu, as.matrix(model[61]))
  [W16, vW16] = sgd_momentum::update(as.matrix(model[16]), as.matrix(gradients[16]), lr, mu, as.matrix(model[62]))
  [W17, vW17] = sgd_momentum::update(as.matrix(model[17]), as.matrix(gradients[17]), lr, mu, as.matrix(model[63]))
  [W18, vW18] = sgd_momentum::update(as.matrix(model[18]), as.matrix(gradients[18]), lr, mu, as.matrix(model[64]))
  [W19, vW19] = sgd_momentum::update(as.matrix(model[19]), as.matrix(gradients[19]), lr, mu, as.matrix(model[65]))
  [W20, vW20] = sgd_momentum::update(as.matrix(model[20]), as.matrix(gradients[20]), lr, mu, as.matrix(model[66]))
  [W21, vW21] = sgd_momentum::update(as.matrix(model[21]), as.matrix(gradients[21]), lr, mu, as.matrix(model[67]))
  [W22, vW22] = sgd_momentum::update(as.matrix(model[22]), as.matrix(gradients[22]), lr, mu, as.matrix(model[68]))
  [W23, vW23] = sgd_momentum::update(as.matrix(model[23]), as.matrix(gradients[23]), lr, mu, as.matrix(model[69]))

  # Biases: parameter at model[23 + i], gradient at gradients[23 + i], velocity at model[69 + i].
  [b1, vb1] = sgd_momentum::update(as.matrix(model[24]), as.matrix(gradients[24]), lr, mu, as.matrix(model[70]))
  [b2, vb2] = sgd_momentum::update(as.matrix(model[25]), as.matrix(gradients[25]), lr, mu, as.matrix(model[71]))
  [b3, vb3] = sgd_momentum::update(as.matrix(model[26]), as.matrix(gradients[26]), lr, mu, as.matrix(model[72]))
  [b4, vb4] = sgd_momentum::update(as.matrix(model[27]), as.matrix(gradients[27]), lr, mu, as.matrix(model[73]))
  [b5, vb5] = sgd_momentum::update(as.matrix(model[28]), as.matrix(gradients[28]), lr, mu, as.matrix(model[74]))
  [b6, vb6] = sgd_momentum::update(as.matrix(model[29]), as.matrix(gradients[29]), lr, mu, as.matrix(model[75]))
  [b7, vb7] = sgd_momentum::update(as.matrix(model[30]), as.matrix(gradients[30]), lr, mu, as.matrix(model[76]))
  [b8, vb8] = sgd_momentum::update(as.matrix(model[31]), as.matrix(gradients[31]), lr, mu, as.matrix(model[77]))
  [b9, vb9] = sgd_momentum::update(as.matrix(model[32]), as.matrix(gradients[32]), lr, mu, as.matrix(model[78]))
  [b10, vb10] = sgd_momentum::update(as.matrix(model[33]), as.matrix(gradients[33]), lr, mu, as.matrix(model[79]))
  [b11, vb11] = sgd_momentum::update(as.matrix(model[34]), as.matrix(gradients[34]), lr, mu, as.matrix(model[80]))
  [b12, vb12] = sgd_momentum::update(as.matrix(model[35]), as.matrix(gradients[35]), lr, mu, as.matrix(model[81]))
  [b13, vb13] = sgd_momentum::update(as.matrix(model[36]), as.matrix(gradients[36]), lr, mu, as.matrix(model[82]))
  [b14, vb14] = sgd_momentum::update(as.matrix(model[37]), as.matrix(gradients[37]), lr, mu, as.matrix(model[83]))
  [b15, vb15] = sgd_momentum::update(as.matrix(model[38]), as.matrix(gradients[38]), lr, mu, as.matrix(model[84]))
  [b16, vb16] = sgd_momentum::update(as.matrix(model[39]), as.matrix(gradients[39]), lr, mu, as.matrix(model[85]))
  [b17, vb17] = sgd_momentum::update(as.matrix(model[40]), as.matrix(gradients[40]), lr, mu, as.matrix(model[86]))
  [b18, vb18] = sgd_momentum::update(as.matrix(model[41]), as.matrix(gradients[41]), lr, mu, as.matrix(model[87]))
  [b19, vb19] = sgd_momentum::update(as.matrix(model[42]), as.matrix(gradients[42]), lr, mu, as.matrix(model[88]))
  [b20, vb20] = sgd_momentum::update(as.matrix(model[43]), as.matrix(gradients[43]), lr, mu, as.matrix(model[89]))
  [b21, vb21] = sgd_momentum::update(as.matrix(model[44]), as.matrix(gradients[44]), lr, mu, as.matrix(model[90]))
  [b22, vb22] = sgd_momentum::update(as.matrix(model[45]), as.matrix(gradients[45]), lr, mu, as.matrix(model[91]))
  [b23, vb23] = sgd_momentum::update(as.matrix(model[46]), as.matrix(gradients[46]), lr, mu, as.matrix(model[92]))

  # Re-pack in the same layout as the incoming model list.
  model_result = list(
    W1, W2, W3, W4, W5, W6, W7, W8, W9, W10, W11, W12, W13, W14, W15, W16, W17, W18, W19, W20, W21, W22, W23,
    b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15, b16, b17, b18, b19, b20, b21, b22, b23,
    vW1, vW2, vW3, vW4, vW5, vW6, vW7, vW8, vW9, vW10, vW11, vW12, vW13, vW14, vW15, vW16, vW17, vW18, vW19, vW20, vW21, vW22, vW23,
    vb1, vb2, vb3, vb4, vb5, vb6, vb7, vb8, vb9, vb10, vb11, vb12, vb13, vb14, vb15, vb16, vb17, vb18, vb19, vb20, vb21, vb22, vb23
  )
}
/*
 * Scores predicted class probabilities of a U-Net against the targets.
 *
 * Inputs:
 * - probs: Predicted class probabilities, presumably of shape (N, K*H*W)
 *          (K classes over an H x W map) - TODO confirm; must match y in shape
 * - y: Target matrix, same shape as probs
 *
 * Outputs:
 * - loss: Scalar cross-entropy loss, as computed by cross_entropy_loss::forward
 * - accuracy: Fraction of rows whose arg-max column in probs equals the arg-max column in y
 */
eval = function(matrix[double] probs, matrix[double] y)
    return (double loss, double accuracy) {
  # NOTE(review): rowIndexMax scans each whole flattened row (all classes and
  # pixels at once), so this is a per-sample, not a per-pixel, accuracy - confirm intent.
  matches = rowIndexMax(probs) == rowIndexMax(y)
  accuracy = mean(matches)
  loss = cross_entropy_loss::forward(probs, y)
}
/*
 * Computes validation loss and accuracy for a trained model.
 *
 * Runs predict on the validation features, then scores the resulting
 * probabilities against the validation labels with eval.
 *
 * Inputs:
 * - val_features: Validation data features
 * - val_labels: Validation data labels
 * - model: List of weights of the trained model
 * - hyperparams: Named list with entries C, Hin, Win, M and K; each is read
 *                as an integer and forwarded to predict
 * - F1: Number of filters of the top layer of the U-Net model. Default is 64.
 * - batch_size: Batch size of prediction. Default is 32.
 *
 * Outputs:
 * - loss: Scalar loss, of shape (1).
 * - accuracy: Scalar accuracy, of shape (1).
 */
validate = function(matrix[double] val_features, matrix[double] val_labels,
    list[unknown] model, list[unknown] hyperparams, int F1 = 64, int batch_size = 32)
    return (double loss, double accuracy)
{
  # Output classes and the extra M setting, then the input geometry.
  K = as.integer(as.scalar(hyperparams["K"]))
  M = as.integer(as.scalar(hyperparams["M"]))
  C = as.integer(as.scalar(hyperparams["C"]))
  Hin = as.integer(as.scalar(hyperparams["Hin"]))
  Win = as.integer(as.scalar(hyperparams["Win"]))

  val_probs = predict(val_features, C, Hin, Win, batch_size, model, M, K, F1)
  [loss, accuracy] = eval(val_probs, val_labels)
}