| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| |
| /* |
| * This file contains functions used for training a U-Net model, with or without a parameter server setup. |
| */ |
| |
| source("scripts/nn/layers/affine.dml") as affine |
| source("scripts/nn/layers/conv2d_builtin.dml") as conv2d |
| source("scripts/nn/layers/conv2d_transpose.dml") as conv2d_transpose |
| source("scripts/nn/layers/cross_entropy_loss.dml") as cross_entropy_loss |
| source("scripts/nn/layers/dropout.dml") as dropout |
| source("scripts/nn/layers/l2_reg.dml") as l2_reg |
| source("scripts/nn/layers/max_pool2d_builtin.dml") as max_pool2d |
| source("scripts/nn/layers/relu.dml") as relu |
| source("scripts/nn/layers/softmax2d.dml") as softmax2d |
| source("scripts/nn/optim/sgd_momentum.dml") as sgd_momentum |
| source("scripts/nn/layers/dropout.dml") as dropout |
| source("scripts/utils/image_utils.dml") as img_utils |
| |
| /* |
| * Pad input features X with extrapolated data by mirroring. |
| * Only the height and width are padded, no extra channels are added. |
| * Dimensions changed from (N,C*Hin*Win) to (N,C*(Hin+184)*(Win+184)). |
| * |
| * 92 pixels are mirrored onto each side (184 per spatial dimension). This |
| * assumes 3x3 convolutions with stride 1 and no zero-padding in the network |
| * (TODO confirm for other architectures). Height and width are padded |
| * independently, so non-square inputs (Hin != Win) are supported. Because |
| * the padding reuses existing pixels, every channel must be at least 92x92. |
| * |
| * Inputs: |
| * - X: Features to pad, of shape (N, C*Hin*Win) |
| * - N: Number of input elements of X |
| * - C: Number of channels of X |
| * - Hin: Height of each element of X |
| * - Win: Width of each element of X |
| * - useParfor: Use parfor loop if true |
| * |
| * Outputs: |
| * - X_extrapolated: X padded with extrapolated data |
| * - input_HW: Height of X_extrapolated (Hin+184); equals its width only when Hin == Win |
| */ |
| extrapolate = function(matrix[double] X, int N, int C, int Hin, int Win, boolean useParfor = TRUE) |
| return (matrix[double] X_extrapolated, int input_HW){ |
| pad_size = 92 # Pixels mirrored onto each side; 2 * 92 = 184 per dimension |
| input_HW = Hin + 2 * pad_size |
| # Number of values in one padded channel. Height and width are used separately |
| # so that the per-channel reshape also works for non-square inputs. |
| channel_width = (Hin + 2 * pad_size) * (Win + 2 * pad_size) |
| |
| if ( useParfor ) { |
| X_extrapolated = extrapolate_images_parfor(X, N, C, Hin, Win, pad_size, channel_width) |
| } |
| else { |
| X_extrapolated = extrapolate_images_for(X, N, C, Hin, Win, pad_size, channel_width) |
| } |
| } |
| |
| /* |
| * Mirror-pad every example (row) of X, handling the rows one after another. |
| * Only height and width grow; the number of channels C stays the same. |
| * Shape goes from (N,C*Hin*Win) to (N,C*channel_width). |
| * |
| * Inputs: |
| * - X: Features to pad, one flattened image per row |
| * - N: Number of rows of X |
| * - C: Number of channels per image |
| * - Hin: Height of each image |
| * - Win: Width of each image |
| * - pad_size: Number of mirrored pixels added to each side |
| * - channel_width: Number of values in one padded channel |
| * |
| * Outputs: |
| * - X_extrapolated: Matrix of shape (N, C*channel_width) with the padded images |
| */ |
| extrapolate_images_for = function(matrix[double] X, int N, int C, int Hin, int Win, int pad_size, int channel_width) |
| return (matrix[double] X_extrapolated){ |
| # Allocate the whole result up front, then overwrite it one row at a time. |
| X_extrapolated = matrix(0, rows=N, cols=C*channel_width) |
| for ( idx in 1:N ){ |
| X_extrapolated[idx,] = extrapolate_image(X[idx,], C, Hin, Win, pad_size, channel_width) |
| } |
| } |
| |
| /* |
| * Mirror-pad every example (row) of X, handling the rows in a parfor loop. |
| * Only height and width grow; the number of channels C stays the same. |
| * Shape goes from (N,C*Hin*Win) to (N,C*channel_width). |
| * |
| * Inputs: |
| * - X: Features to pad, one flattened image per row |
| * - N: Number of rows of X |
| * - C: Number of channels per image |
| * - Hin: Height of each image |
| * - Win: Width of each image |
| * - pad_size: Number of mirrored pixels added to each side |
| * - channel_width: Number of values in one padded channel |
| * |
| * Outputs: |
| * - X_extrapolated: Matrix of shape (N, C*channel_width) with the padded images |
| */ |
| extrapolate_images_parfor = function(matrix[double] X, int N, int C, int Hin, int Win, int pad_size, int channel_width) |
| return (matrix[double] X_extrapolated){ |
| X_extrapolated = matrix(0, rows=N, cols=C*channel_width) |
| # check=0 skips parfor's dependency analysis; each iteration writes only its own row. |
| parfor ( idx in 1:N, check=0 ){ |
| X_extrapolated[idx,] = extrapolate_image(X[idx,], C, Hin, Win, pad_size, channel_width) |
| } |
| } |
| |
| /* |
| * Mirror-pad a single flattened image, one channel at a time. |
| * Only height and width grow; the number of channels C stays the same. |
| * Shape goes from (1,C*Hin*Win) to (1,C*channel_width). |
| * |
| * Inputs: |
| * - img: Flattened image of shape (1,C*Hin*Win), channels stored one after another |
| * - C: Number of channels of the image |
| * - Hin: Height of the image |
| * - Win: Width of the image |
| * - pad_size: Number of mirrored pixels added to each side |
| * - channel_width: Number of values in one padded channel |
| * |
| * Outputs: |
| * - img_extrapolated: Padded image of shape (1,C*channel_width) |
| */ |
| extrapolate_image = function(matrix[double] img, int C, int Hin, int Win, int pad_size, int channel_width) |
| return (matrix[double] img_extrapolated){ |
| img_extrapolated = matrix(0, rows=1, cols=C*channel_width) |
| in_len = Hin * Win # Number of values in one unpadded channel |
| for ( ch in 1:C ){ |
| # Zero-based offsets of this channel inside the input and the output row |
| in_off = (ch - 1) * in_len |
| out_off = (ch - 1) * channel_width |
| plane = matrix(img[1,(in_off+1):(in_off+in_len)], rows=Hin, cols=Win) |
| img_extrapolated[1,(out_off+1):(out_off+channel_width)] = extrapolate_channel(plane, Hin, Win, pad_size, channel_width) |
| } |
| } |
| |
| /* |
| * Pad single channel of input image img with extrapolated data by mirroring. |
| * Dimensions changed from (Hin,Win) to (1,(Hin+2*pad_size)*(Win+2*pad_size)). |
| * |
| * The mirrored border includes the edge pixel itself, e.g. with pad_size=2 |
| * a row (a b c) becomes (b a a b c c b). |
| * |
| * Inputs: |
| * - img: Image of dimension (Hin,Win) to pad |
| * - Hin: Height of image |
| * - Win: Width of image |
| * - pad_size: Pad size for each side of the image; must satisfy 1 <= pad_size <= min(Hin, Win) |
| * - channel_width: Width of a single padded channel, i.e. (Hin+2*pad_size)*(Win+2*pad_size) |
| * |
| * Outputs: |
| * - channel_extrapolated: channel of image padded with extrapolated data, flattened row-major to (1,channel_width) |
| */ |
| extrapolate_channel = function(matrix[double] img, int Hin, int Win, int pad_size, int channel_width) |
| return (matrix[double] channel_extrapolated){ |
| # Mirroring only reuses existing pixels, so the pad cannot exceed the image size. |
| if ( pad_size < 1 | pad_size > Hin | pad_size > Win ) { |
| stop("extrapolate_channel: pad_size must be between 1 and min(Hin, Win), got pad_size=" + pad_size + ", Hin=" + Hin + ", Win=" + Win) |
| } |
| |
| # Left/right strips: flip the column order (rev flips rows, so transpose around it). |
| pad_left = t(rev(t(img[,1:pad_size]))) |
| pad_right = t(rev(t(img[,(Win-pad_size+1):Win]))) |
| # Top/bottom strips: flip the row order. |
| pad_top = rev(img[1:pad_size,]) |
| pad_bottom = rev(img[(Hin-pad_size+1):Hin,]) |
| # Corners: take the first/last pad_size rows of the column-flipped strips and flip them vertically. |
| pad_top_left = rev(pad_left[1:pad_size,]) |
| pad_top_right = rev(pad_right[1:pad_size,]) |
| pad_bottom_left = rev(pad_left[(Hin-pad_size+1):Hin,]) |
| pad_bottom_right = rev(pad_right[(Hin-pad_size+1):Hin,]) |
| |
| # Stack each column band vertically, then join the three bands horizontally. |
| pad_left_full = rbind(pad_top_left, pad_left, pad_bottom_left) |
| pad_right_full = rbind(pad_top_right, pad_right, pad_bottom_right) |
| pad_center_full = rbind(pad_top, img, pad_bottom) |
| |
| modified_channel = cbind(pad_left_full, pad_center_full, pad_right_full) |
| channel_extrapolated = matrix(modified_channel, rows=1, cols=channel_width) |
| } |
| |
| /* |
| * Trains a U-Net model with a parameter server setup. |
| * |
| * The input matrix, X, has N examples, each represented as a 3D |
| * volume unrolled into a single vector. For every element of the |
| * segmentation map, the targets, y, hold one value per class (K classes), |
| * i.e. K*M values per example. |
| * Training itself is delegated to the paramserv builtin, which is pointed at |
| * the "gradients" and "aggregation" functions of ./scripts/nn/examples/u-net.dml |
| * (a path that paramserv presumably resolves relative to the working directory; verify). |
| * |
| * Inputs: |
| * - X: Input data matrix, of shape (N, C*Hin*Win); presumably already mirror-padded with extrapolate (TODO confirm with callers) |
| * - y: Target matrix, of shape (N, K*M), matching the K*M-column softmax2d output of the forward pass |
| * - X_val: Input validation data matrix, of shape (N_val, C*Hin*Win) |
| * - y_val: Target validation matrix, of shape (N_val, K*M) |
| * - C: Number of input channels |
| * - Hin: Input height |
| * - Win: Input width |
| * - epochs: Total number of full training loops over the full data set |
| * - workers: Number of federated workers (passed to paramserv as k) |
| * - utype: Update type (synchronous, asynchronous, etc.) |
| * - freq: Frequency of weight updates ("BATCH" or "EPOCH") |
| * - batch_size: Batch size |
| * - scheme: Parameter server training scheme |
| * - learning_rate: The learning rate for the SGD with momentum |
| * - seed: Seed for the initialization of the convolution weights. Layer i is initialized with seed+i. |
| * Default is -1 meaning that the seeds are random. |
| * - M: Size of the segmentation map per example, C*Hout*Wout of the last layer |
| * (presumably C*(Hin-184)*(Win-184) for inputs padded with extrapolate; TODO confirm) |
| * - K: Number of output categories (for each element of segmentation map) |
| * - he: Homomorphic encryption activated (boolean), forwarded to paramserv |
| * - F1: Number of filters of the top layer of the U-Net model. Default is 64. |
| * Deeper layers use F2=2*F1, F3=4*F1, F4=8*F1, F5=16*F1 filters. |
| * |
| * Outputs: |
| * - model_trained: List returned by paramserv; presumably in the same layout as model_list below |
| * (W1..W23, b1..b23, then the momentum velocities) |
| */ |
| train_paramserv = function(matrix[double] X, matrix[double] y, |
| matrix[double] X_val, matrix[double] y_val, |
| int C, int Hin, int Win, int epochs, int workers, |
| string utype, string freq, int batch_size, string scheme, double learning_rate, |
| int seed = -1, int M, int K, boolean he = FALSE, int F1 = 64) |
| return (list[unknown] model_trained) { |
| N = nrow(X) # Number of inputs (not used further in this function) |
| |
| # Define model network constants |
| # NOTE: these must agree with the constants hard-coded in predict. |
| Hf = 3 # convolution filter height |
| Wf = 3 # convolution filter width |
| conv_stride = 1 |
| pool_stride = 2 |
| pool_HWf = 2 |
| conv_t_HWf = 2 |
| conv_t_stride = 2 |
| pad = 0 # No zero-padding (valid convolutions), so each 3x3 conv shrinks H and W; same-size output would need pad = (Hf - 1) / 2 |
| F2 = F1 * 2 |
| F3 = F2 * 2 |
| F4 = F3 * 2 |
| F5 = F4 * 2 |
| # NOTE(review): dropProb is passed as the probability argument of dropout::forward; if that |
| # argument is the keep probability (as in the SystemDS nn dropout layer), 1.0 disables dropout. Confirm. |
| dropProb = 1.0 |
| dropSeed = -1 |
| |
| # Create different seeds for each layer, unless seed is -1 |
| # Result: a 23-entry list holding -1 for every layer, or seed+1 .. seed+23. |
| lseed = list() |
| for ( i in 1:23 ){ |
| if (seed == -1){ |
| lseed = rbind(lseed, -1) |
| } else { |
| lseed = rbind(lseed, seed+i) |
| } |
| } |
| |
| # Initialize convolution weights |
| # conv2d::init(F, C, Hf, Wf) is called with (output filters, input channels, filter height, filter width). |
| |
| # First step |
| [W1, b1] = conv2d::init(F1, C, Hf, Wf, seed = as.integer(as.scalar(lseed[1]))) # inputs: (N, C*Hin*Win) |
| [W2, b2] = conv2d::init(F1, F1, Hf, Wf, seed = as.integer(as.scalar(lseed[2]))) |
| # Second step |
| [W3, b3] = conv2d::init(F2, F1, Hf, Wf, seed = as.integer(as.scalar(lseed[3]))) |
| [W4, b4] = conv2d::init(F2, F2, Hf, Wf, seed = as.integer(as.scalar(lseed[4]))) |
| # Third step |
| [W5, b5] = conv2d::init(F3, F2, Hf, Wf, seed = as.integer(as.scalar(lseed[5]))) |
| [W6, b6] = conv2d::init(F3, F3, Hf, Wf, seed = as.integer(as.scalar(lseed[6]))) |
| # Fourth step |
| [W7, b7] = conv2d::init(F4, F3, Hf, Wf, seed = as.integer(as.scalar(lseed[7]))) |
| [W8, b8] = conv2d::init(F4, F4, Hf, Wf, seed = as.integer(as.scalar(lseed[8]))) |
| # Fifth step |
| [W9, b9] = conv2d::init(F5, F4, Hf, Wf, seed = as.integer(as.scalar(lseed[9]))) |
| [W10, b10] = conv2d::init(F5, F5, Hf, Wf, seed = as.integer(as.scalar(lseed[10]))) |
| # First Up-convolution |
| # Each up-step: a 2x2 transpose convolution, then two 3x3 convolutions. The first conv takes |
| # twice the transpose conv's channels because the forward pass concatenates a skip connection. |
| [W11, b11] = conv2d_transpose::init_seed(F4, F5, conv_t_HWf, conv_t_HWf, seed = as.integer(as.scalar(lseed[11]))) |
| [W12, b12] = conv2d::init(F4, F5, Hf, Wf, seed = as.integer(as.scalar(lseed[12]))) |
| [W13, b13] = conv2d::init(F4, F4, Hf, Wf, seed = as.integer(as.scalar(lseed[13]))) |
| # Second Up-convolution |
| [W14, b14] = conv2d_transpose::init_seed(F3, F4, conv_t_HWf, conv_t_HWf, seed = as.integer(as.scalar(lseed[14]))) |
| [W15, b15] = conv2d::init(F3, F4, Hf, Wf, seed = as.integer(as.scalar(lseed[15]))) |
| [W16, b16] = conv2d::init(F3, F3, Hf, Wf, seed = as.integer(as.scalar(lseed[16]))) |
| # Third Up-convolution |
| [W17, b17] = conv2d_transpose::init_seed(F2, F3, conv_t_HWf, conv_t_HWf, seed = as.integer(as.scalar(lseed[17]))) |
| [W18, b18] = conv2d::init(F2, F3, Hf, Wf, seed = as.integer(as.scalar(lseed[18]))) |
| [W19, b19] = conv2d::init(F2, F2, Hf, Wf, seed = as.integer(as.scalar(lseed[19]))) |
| # Fourth Up-convolution |
| [W20, b20] = conv2d_transpose::init_seed(F1, F2, conv_t_HWf, conv_t_HWf, seed = as.integer(as.scalar(lseed[20]))) |
| [W21, b21] = conv2d::init(F1, F2, Hf, Wf, seed = as.integer(as.scalar(lseed[21]))) |
| [W22, b22] = conv2d::init(F1, F1, Hf, Wf, seed = as.integer(as.scalar(lseed[22]))) |
| # Segmentation map |
| # 1x1 convolution producing K*C output channels (K class scores for each input channel). |
| [W23, b23] = conv2d::init(K*C, F1, 1, 1, seed = as.integer(as.scalar(lseed[23]))) |
| |
| # Initialize SGD with momentum |
| # One velocity matrix per weight and per bias, shaped like the parameter it belongs to. |
| vW1 = sgd_momentum::init(W1); vb1 = sgd_momentum::init(b1) |
| vW2 = sgd_momentum::init(W2); vb2 = sgd_momentum::init(b2) |
| vW3 = sgd_momentum::init(W3); vb3 = sgd_momentum::init(b3) |
| vW4 = sgd_momentum::init(W4); vb4 = sgd_momentum::init(b4) |
| vW5 = sgd_momentum::init(W5); vb5 = sgd_momentum::init(b5) |
| vW6 = sgd_momentum::init(W6); vb6 = sgd_momentum::init(b6) |
| vW7 = sgd_momentum::init(W7); vb7 = sgd_momentum::init(b7) |
| vW8 = sgd_momentum::init(W8); vb8 = sgd_momentum::init(b8) |
| vW9 = sgd_momentum::init(W9); vb9 = sgd_momentum::init(b9) |
| vW10 = sgd_momentum::init(W10); vb10 = sgd_momentum::init(b10) |
| vW11 = sgd_momentum::init(W11); vb11 = sgd_momentum::init(b11) |
| vW12 = sgd_momentum::init(W12); vb12 = sgd_momentum::init(b12) |
| vW13 = sgd_momentum::init(W13); vb13 = sgd_momentum::init(b13) |
| vW14 = sgd_momentum::init(W14); vb14 = sgd_momentum::init(b14) |
| vW15 = sgd_momentum::init(W15); vb15 = sgd_momentum::init(b15) |
| vW16 = sgd_momentum::init(W16); vb16 = sgd_momentum::init(b16) |
| vW17 = sgd_momentum::init(W17); vb17 = sgd_momentum::init(b17) |
| vW18 = sgd_momentum::init(W18); vb18 = sgd_momentum::init(b18) |
| vW19 = sgd_momentum::init(W19); vb19 = sgd_momentum::init(b19) |
| vW20 = sgd_momentum::init(W20); vb20 = sgd_momentum::init(b20) |
| vW21 = sgd_momentum::init(W21); vb21 = sgd_momentum::init(b21) |
| vW22 = sgd_momentum::init(W22); vb22 = sgd_momentum::init(b22) |
| vW23 = sgd_momentum::init(W23); vb23 = sgd_momentum::init(b23) |
| |
| # Define optimizer constants |
| mu = 0.9 # momentum |
| decay = 0.95 # learning rate decay constant |
| |
| # Regularization |
| lambda = 5e-04 |
| |
| # Create the model list |
| # Layout (positions are relied upon by predict and gradients): |
| # 1-23 weights W1..W23, 24-46 biases b1..b23, 47-69 velocities vW1..vW23, 70-92 velocities vb1..vb23. |
| model_list = list( |
| W1, W2, W3, W4, W5, W6, W7, W8, W9, W10, W11, W12, W13, W14, W15, W16, W17, W18, W19, W20, W21, W22, W23, |
| b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15, b16, b17, b18, b19, b20, b21, b22, b23, |
| vW1, vW2, vW3, vW4, vW5, vW6, vW7, vW8, vW9, vW10, vW11, vW12, vW13, vW14, vW15, vW16, vW17, vW18, vW19, vW20, vW21, vW22, vW23, |
| vb1, vb2, vb3, vb4, vb5, vb6, vb7, vb8, vb9, vb10, vb11, vb12, vb13, vb14, vb15, vb16, vb17, vb18, vb19, vb20, vb21, vb22, vb23) |
| |
| # Create the hyper parameter list |
| # The key names are read back by the gradients function via hyperparams["..."]. |
| params = list( |
| learning_rate=learning_rate, mu=mu, decay=decay, M=M, K=K, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, |
| conv_stride=conv_stride, pool_stride=pool_stride, pool_HWf=pool_HWf, conv_t_HWf=conv_t_HWf, conv_t_stride=conv_t_stride, |
| pad=pad, lambda=lambda, F1=F1, F2=F2, F3=F3, F4=F4, F5=F5, dropProb=dropProb, dropSeed=dropSeed) |
| |
| # Use paramserv function |
| # upd computes per-batch gradients and agg applies them to the model; both are named by file path and function. |
| model_trained = paramserv(model=model_list, features=X, labels=y, val_features=X_val, val_labels=y_val, |
| upd="./scripts/nn/examples/u-net.dml::gradients", |
| agg="./scripts/nn/examples/u-net.dml::aggregation", |
| scheme=scheme, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size, |
| k=workers, hyperparams=params, checkpointing="NONE", he=he, modelAvg=TRUE) |
| } |
| |
| /* |
| * Forward pass of U-Net model on X using specified batch size. |
| * |
| * X is processed in consecutive mini-batches of at most batch_size rows. The |
| * layer sequence and constants below duplicate the forward pass of the |
| * gradients function and must be kept in sync with it. |
| * |
| * Inputs: |
| * - X: Input features of size C*Hin*Win. |
| * The features need to be padded with mirrored data (see extrapolate), hence the actual input size before padding is C*(Hin-pad)*(Win-pad), |
| * where pad is the total mirrored border per dimension (184 when padded with extrapolate). |
| * - C: Number of input channels |
| * - Hin: Input height (after padding) |
| * - Win: Input width (after padding) |
| * - batch_size: Batch size |
| * - model: List of weights of the model (23 weights at positions 1-23, 23 biases at positions 24-46; later entries are not read) |
| * - M: Size of the segmentation map (C*(Hin-pad)*(Win-pad)) |
| * - K: Number of output categories (for each element of segmentation map) |
| * - F1: Number of filters of the top layer of the U-Net model. Default is 64. Must match the value the model was trained with. |
| * |
| * Output: |
| * - probs: Segmentation map probabilities of shape (N, K*M) generated by the forward pass of the U-Net model |
| */ |
| predict = function(matrix[double] X, int C, int Hin, int Win, int batch_size, list[unknown] model, int M, int K, int F1 = 64) |
| return (matrix[double] probs) { |
| # Unpack weights (model[1..23]) and biases (model[24..46]). |
| W1 = as.matrix(model[1]) |
| W2 = as.matrix(model[2]) |
| W3 = as.matrix(model[3]) |
| W4 = as.matrix(model[4]) |
| W5 = as.matrix(model[5]) |
| W6 = as.matrix(model[6]) |
| W7 = as.matrix(model[7]) |
| W8 = as.matrix(model[8]) |
| W9 = as.matrix(model[9]) |
| W10 = as.matrix(model[10]) |
| W11 = as.matrix(model[11]) |
| W12 = as.matrix(model[12]) |
| W13 = as.matrix(model[13]) |
| W14 = as.matrix(model[14]) |
| W15 = as.matrix(model[15]) |
| W16 = as.matrix(model[16]) |
| W17 = as.matrix(model[17]) |
| W18 = as.matrix(model[18]) |
| W19 = as.matrix(model[19]) |
| W20 = as.matrix(model[20]) |
| W21 = as.matrix(model[21]) |
| W22 = as.matrix(model[22]) |
| W23 = as.matrix(model[23]) |
| b1 = as.matrix(model[24]) |
| b2 = as.matrix(model[25]) |
| b3 = as.matrix(model[26]) |
| b4 = as.matrix(model[27]) |
| b5 = as.matrix(model[28]) |
| b6 = as.matrix(model[29]) |
| b7 = as.matrix(model[30]) |
| b8 = as.matrix(model[31]) |
| b9 = as.matrix(model[32]) |
| b10 = as.matrix(model[33]) |
| b11 = as.matrix(model[34]) |
| b12 = as.matrix(model[35]) |
| b13 = as.matrix(model[36]) |
| b14 = as.matrix(model[37]) |
| b15 = as.matrix(model[38]) |
| b16 = as.matrix(model[39]) |
| b17 = as.matrix(model[40]) |
| b18 = as.matrix(model[41]) |
| b19 = as.matrix(model[42]) |
| b20 = as.matrix(model[43]) |
| b21 = as.matrix(model[44]) |
| b22 = as.matrix(model[45]) |
| b23 = as.matrix(model[46]) |
| N = nrow(X) # Number of inputs |
| |
| # Network constants; these must agree with the ones used in train_paramserv. |
| Hf = 3 # convolution filter height |
| Wf = 3 # convolution filter width |
| conv_stride = 1 |
| pool_stride = 2 |
| pool_HWf = 2 |
| conv_t_HWf = 2 |
| conv_t_stride = 2 |
| pad = 0 # No zero-padding (valid convolutions); same-size output would need pad = (Hf - 1) / 2 |
| F2 = F1 * 2 |
| F3 = F2 * 2 |
| F4 = F3 * 2 |
| F5 = F4 * 2 |
| # NOTE(review): if dropout::forward treats this as the keep probability (as in the SystemDS nn |
| # dropout layer), 1.0 makes the dropout layer below an identity, as expected at inference time. Confirm. |
| dropProb = 1.0 |
| dropSeed = -1 |
| |
| # Compute predictions over mini-batches |
| probs = matrix(0, rows=N, cols=K*M) |
| iters = ceil(N / batch_size) |
| # NOTE(review): check=0 is a parfor option; this is a plain sequential for loop. |
| for(i in 1:iters, check=0) { |
| # Get next batch |
| # Since (iters-1)*batch_size < N, the %% N never wraps and batches are consecutive, non-overlapping row ranges. |
| beg = ((i-1) * batch_size) %% N + 1 |
| end = min(N, beg + batch_size - 1) |
| X_batch = X[beg:end,] |
| |
| # Down-Convolution |
| # Four levels of two 3x3 conv + relu followed by 2x2 max pooling; outr2/4/6/8 are kept for the skip connections. |
| [outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, W1, b1, C, Hin, Win, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr1 = relu::forward(outc1) |
| [outc2, Houtc2, Woutc2] = conv2d::forward(outr1, W2, b2, F1, Houtc1, Woutc1, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr2 = relu::forward(outc2) |
| [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr2, F1, Houtc2, Woutc2, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0) |
| [outc3, Houtc3, Woutc3] = conv2d::forward(outp1, W3, b3, F1, Houtp1, Woutp1, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr3 = relu::forward(outc3) |
| [outc4, Houtc4, Woutc4] = conv2d::forward(outr3, W4, b4, F2, Houtc3, Woutc3, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr4 = relu::forward(outc4) |
| [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr4, F2, Houtc4, Woutc4, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0) |
| [outc5, Houtc5, Woutc5] = conv2d::forward(outp2, W5, b5, F2, Houtp2, Woutp2, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr5 = relu::forward(outc5) |
| [outc6, Houtc6, Woutc6] = conv2d::forward(outr5, W6, b6, F3, Houtc5, Woutc5, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr6 = relu::forward(outc6) |
| [outp3, Houtp3, Woutp3] = max_pool2d::forward(outr6, F3, Houtc6, Woutc6, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0) |
| [outc7, Houtc7, Woutc7] = conv2d::forward(outp3, W7, b7, F3, Houtp3, Woutp3, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr7 = relu::forward(outc7) |
| [outc8, Houtc8, Woutc8] = conv2d::forward(outr7, W8, b8, F4, Houtc7, Woutc7, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr8 = relu::forward(outc8) |
| [outp4, Houtp4, Woutp4] = max_pool2d::forward(outr8, F4, Houtc8, Woutc8, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0) |
| [outd1, mask1] = dropout::forward(outp4, dropProb, dropSeed) |
| |
| # Bottom |
| [outc9, Houtc9, Woutc9] = conv2d::forward(outd1, W9, b9, F4, Houtp4, Woutp4, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr9 = relu::forward(outc9) |
| [outc10, Houtc10, Woutc10] = conv2d::forward(outr9, W10, b10, F5, Houtc9, Woutc9, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr10 = relu::forward(outc10) |
| [outc11, Houtc11, Woutc11] = conv2d_transpose::forward(outr10, W11, b11, F5, Houtc10, Woutc10, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad, 0, 0) |
| |
| # Up-Convolution |
| # Each level: crop the matching encoder activation to the decoder's spatial size (via img_utils::crop_channel), |
| # cbind it in front of the transpose-conv output (doubling the channel count), then two 3x3 conv + relu. |
| outConcat1 = cbind(img_utils::crop_channel(outr8, Houtc8, Woutc8, Houtc11, Woutc11, F4),outc11) |
| [outc12, Houtc12, Woutc12] = conv2d::forward(outConcat1, W12, b12, F5, Houtc11, Woutc11, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr11 = relu::forward(outc12) |
| [outc13, Houtc13, Woutc13] = conv2d::forward(outr11, W13, b13, F4, Houtc12, Woutc12, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr12 = relu::forward(outc13) |
| [outc14, Houtc14, Woutc14] = conv2d_transpose::forward(outr12, W14, b14, F4, Houtc13, Woutc13, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad, 0, 0) |
| outConcat2 = cbind(img_utils::crop_channel(outr6, Houtc6, Woutc6, Houtc14, Woutc14, F3),outc14) |
| [outc15, Houtc15, Woutc15] = conv2d::forward(outConcat2, W15, b15, F4, Houtc14, Woutc14, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr13 = relu::forward(outc15) |
| [outc16, Houtc16, Woutc16] = conv2d::forward(outr13, W16, b16, F3, Houtc15, Woutc15, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr14 = relu::forward(outc16) |
| [outc17, Houtc17, Woutc17] = conv2d_transpose::forward(outr14, W17, b17, F3, Houtc16, Woutc16, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad, 0, 0) |
| outConcat3 = cbind(img_utils::crop_channel(outr4, Houtc4, Woutc4, Houtc17, Woutc17, F2), outc17) |
| [outc18, Houtc18, Woutc18] = conv2d::forward(outConcat3, W18, b18, F3, Houtc17, Woutc17, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr15 = relu::forward(outc18) |
| [outc19, Houtc19, Woutc19] = conv2d::forward(outr15, W19, b19, F2, Houtc18, Woutc18, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr16 = relu::forward(outc19) |
| [outc20, Houtc20, Woutc20] = conv2d_transpose::forward(outr16, W20, b20, F2, Houtc19, Woutc19, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad, 0, 0) |
| outConcat4 = cbind(img_utils::crop_channel(outr2, Houtc2, Woutc2, Houtc20, Woutc20, F1), outc20) |
| [outc21, Houtc21, Woutc21] = conv2d::forward(outConcat4, W21, b21, F2, Houtc20, Woutc20, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr17 = relu::forward(outc21) |
| [outc22, Houtc22, Woutc22] = conv2d::forward(outr17, W22, b22, F1, Houtc21, Woutc21, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr18 = relu::forward(outc22) |
| |
| # This last conv2d needs to create the segmentation map (1x1 filter): |
| # W23 has K*C filters, so each example yields K*C*Houtc23*Woutc23 (= K*M) values. |
| [outc23, Houtc23, Woutc23] = conv2d::forward(outr18, W23, b23, F1, Houtc22, Woutc22, 1, 1, conv_stride, conv_stride, pad, pad) |
| |
| # Store predictions |
| # softmax2d::forward normalizes over the K class scores at each map position. |
| probs[beg:end,] = softmax2d::forward(outc23, K) |
| } |
| } |
| |
| /* |
| * Forward and backward pass of U-Net with gradients returned. |
| * |
| * Inputs: |
| * - model: List of model weights |
| * - hyperparams: List of hyper parameters containing: |
| * - (scalar[integer]) C: Number of input channels |
| * - (scalar[integer]) Hin: Input height |
| * - (scalar[integer]) Win: Input width |
| * - (scalar[integer]) Hf: Filter height |
| * - (scalar[integer]) Wf: Filter width |
| * - (scalar[integer]) pool_stride: Stride of the max pool operations |
| * - (scalar[integer]) pool_HWf: Filter height and width of the max pool operation |
| * - (scalar[integer]) conv_stride: Stride of all convolutions |
| * - (scalar[integer]) conv_t_HWf: Filter height and width of the transpose convolutions |
| * - (scalar[integer]) conv_t_stride: Stride of the transpose convolutions |
| * - (scalar[integer]) pad: Padding of all convolutions and transpose convolutions |
| * - (scalar[double]) lambda: Regularization strength |
| * - (scalar[integer]) F1, F2, F3, F4, F5: Number of filters of the convolutions in the five layers |
| * - (scalar[double]) dropProb: Dropout probability |
| * - (scalar[integer]) dropSeed: Dropout seed |
| * - (scalar[integer]) M: Size of the segmentation map (C*(Hin-pad)*(Win-pad)) |
| * - (scalar[integer]) K: Number of output categories (for each element of segmentation map) |
| * - features: Features of size C*Hin*Win. The features need to be padded with mirrored data. |
| * The input feature size should result in an output size of the U-Net equal to the label size. |
| * See extrapolate function for how to pad the features by extrapolating. |
| * - labels: Labels of size C * (Hin-pad) * (Win-pad) representing a segmentation map of the input features. |
| * Output: |
| * - gradients: List of gradients |
| */ |
| gradients = function(list[unknown] model, |
| list[unknown] hyperparams, |
| matrix[double] features, |
| matrix[double] labels) |
| return (list[unknown] gradients) { |
| C = as.integer(as.scalar(hyperparams["C"])) |
| Hin = as.integer(as.scalar(hyperparams["Hin"])) |
| Win = as.integer(as.scalar(hyperparams["Win"])) |
| Hf = as.integer(as.scalar(hyperparams["Hf"])) |
| Wf = as.integer(as.scalar(hyperparams["Wf"])) |
| pool_stride = as.integer(as.scalar(hyperparams["pool_stride"])) |
| pool_HWf = as.integer(as.scalar(hyperparams["pool_HWf"])) |
| conv_stride = as.integer(as.scalar(hyperparams["conv_stride"])) |
| conv_t_HWf = as.integer(as.scalar(hyperparams["conv_t_HWf"])) |
| conv_t_stride = as.integer(as.scalar(hyperparams["conv_t_stride"])) |
| pad = as.integer(as.scalar(hyperparams["pad"])) |
| lambda = as.double(as.scalar(hyperparams["lambda"])) |
| F1 = as.integer(as.scalar(hyperparams["F1"])) |
| F2 = as.integer(as.scalar(hyperparams["F2"])) |
| F3 = as.integer(as.scalar(hyperparams["F3"])) |
| F4 = as.integer(as.scalar(hyperparams["F4"])) |
| F5 = as.integer(as.scalar(hyperparams["F5"])) |
| dropProb = as.double(as.scalar(hyperparams["dropProb"])) |
| dropSeed = as.integer(as.scalar(hyperparams["dropSeed"])) |
| M = as.integer(as.scalar(hyperparams["M"])) |
| K = as.integer(as.scalar(hyperparams["K"])) |
| W1 = as.matrix(model[1]) |
| W2 = as.matrix(model[2]) |
| W3 = as.matrix(model[3]) |
| W4 = as.matrix(model[4]) |
| W5 = as.matrix(model[5]) |
| W6 = as.matrix(model[6]) |
| W7 = as.matrix(model[7]) |
| W8 = as.matrix(model[8]) |
| W9 = as.matrix(model[9]) |
| W10 = as.matrix(model[10]) |
| W11 = as.matrix(model[11]) |
| W12 = as.matrix(model[12]) |
| W13 = as.matrix(model[13]) |
| W14 = as.matrix(model[14]) |
| W15 = as.matrix(model[15]) |
| W16 = as.matrix(model[16]) |
| W17 = as.matrix(model[17]) |
| W18 = as.matrix(model[18]) |
| W19 = as.matrix(model[19]) |
| W20 = as.matrix(model[20]) |
| W21 = as.matrix(model[21]) |
| W22 = as.matrix(model[22]) |
| W23 = as.matrix(model[23]) |
| b1 = as.matrix(model[24]) |
| b2 = as.matrix(model[25]) |
| b3 = as.matrix(model[26]) |
| b4 = as.matrix(model[27]) |
| b5 = as.matrix(model[28]) |
| b6 = as.matrix(model[29]) |
| b7 = as.matrix(model[30]) |
| b8 = as.matrix(model[31]) |
| b9 = as.matrix(model[32]) |
| b10 = as.matrix(model[33]) |
| b11 = as.matrix(model[34]) |
| b12 = as.matrix(model[35]) |
| b13 = as.matrix(model[36]) |
| b14 = as.matrix(model[37]) |
| b15 = as.matrix(model[38]) |
| b16 = as.matrix(model[39]) |
| b17 = as.matrix(model[40]) |
| b18 = as.matrix(model[41]) |
| b19 = as.matrix(model[42]) |
| b20 = as.matrix(model[43]) |
| b21 = as.matrix(model[44]) |
| b22 = as.matrix(model[45]) |
| b23 = as.matrix(model[46]) |
| |
| # Down-Convolution |
| [outc1, Houtc1, Woutc1] = conv2d::forward(features, W1, b1, C, Hin, Win, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr1 = relu::forward(outc1) |
| [outc2, Houtc2, Woutc2] = conv2d::forward(outr1, W2, b2, F1, Houtc1, Woutc1, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr2 = relu::forward(outc2) |
| [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr2, F1, Houtc2, Woutc2, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0) |
| [outc3, Houtc3, Woutc3] = conv2d::forward(outp1, W3, b3, F1, Houtp1, Woutp1, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr3 = relu::forward(outc3) |
| [outc4, Houtc4, Woutc4] = conv2d::forward(outr3, W4, b4, F2, Houtc3, Woutc3, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr4 = relu::forward(outc4) |
| [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr4, F2, Houtc4, Woutc4, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0) |
| [outc5, Houtc5, Woutc5] = conv2d::forward(outp2, W5, b5, F2, Houtp2, Woutp2, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr5 = relu::forward(outc5) |
| [outc6, Houtc6, Woutc6] = conv2d::forward(outr5, W6, b6, F3, Houtc5, Woutc5, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr6 = relu::forward(outc6) |
| [outp3, Houtp3, Woutp3] = max_pool2d::forward(outr6, F3, Houtc6, Woutc6, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0) |
| [outc7, Houtc7, Woutc7] = conv2d::forward(outp3, W7, b7, F3, Houtp3, Woutp3, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr7 = relu::forward(outc7) |
| [outc8, Houtc8, Woutc8] = conv2d::forward(outr7, W8, b8, F4, Houtc7, Woutc7, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr8 = relu::forward(outc8) |
| [outp4, Houtp4, Woutp4] = max_pool2d::forward(outr8, F4, Houtc8, Woutc8, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0) |
| [outd1, mask1] = dropout::forward(outp4, dropProb, dropSeed) |
| |
| # Bottom |
| [outc9, Houtc9, Woutc9] = conv2d::forward(outd1, W9, b9, F4, Houtp4, Woutp4, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr9 = relu::forward(outc9) |
| [outc10, Houtc10, Woutc10] = conv2d::forward(outr9, W10, b10, F5, Houtc9, Woutc9, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr10 = relu::forward(outc10) |
| [outc11, Houtc11, Woutc11] = conv2d_transpose::forward(outr10, W11, b11, F5, Houtc10, Woutc10, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad, 0, 0) |
| |
| # Up-Convolution |
| outConcat1 = cbind(img_utils::crop_channel(outr8, Houtc8, Woutc8, Houtc11, Woutc11, F4),outc11) |
| [outc12, Houtc12, Woutc12] = conv2d::forward(outConcat1, W12, b12, F5, Houtc11, Woutc11, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr11 = relu::forward(outc12) |
| [outc13, Houtc13, Woutc13] = conv2d::forward(outr11, W13, b13, F4, Houtc12, Woutc12, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr12 = relu::forward(outc13) |
| [outc14, Houtc14, Woutc14] = conv2d_transpose::forward(outr12, W14, b14, F4, Houtc13, Woutc13, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad, 0, 0) |
| outConcat2 = cbind(img_utils::crop_channel(outr6, Houtc6, Woutc6, Houtc14, Woutc14, F3),outc14) |
| [outc15, Houtc15, Woutc15] = conv2d::forward(outConcat2, W15, b15, F4, Houtc14, Woutc14, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr13 = relu::forward(outc15) |
| [outc16, Houtc16, Woutc16] = conv2d::forward(outr13, W16, b16, F3, Houtc15, Woutc15, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr14 = relu::forward(outc16) |
| [outc17, Houtc17, Woutc17] = conv2d_transpose::forward(outr14, W17, b17, F3, Houtc16, Woutc16, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad, 0, 0) |
| outConcat3 = cbind(img_utils::crop_channel(outr4, Houtc4, Woutc4, Houtc17, Woutc17, F2), outc17) |
| [outc18, Houtc18, Woutc18] = conv2d::forward(outConcat3, W18, b18, F3, Houtc17, Woutc17, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr15 = relu::forward(outc18) |
| [outc19, Houtc19, Woutc19] = conv2d::forward(outr15, W19, b19, F2, Houtc18, Woutc18, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr16 = relu::forward(outc19) |
| [outc20, Houtc20, Woutc20] = conv2d_transpose::forward(outr16, W20, b20, F2, Houtc19, Woutc19, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad, 0, 0) |
| outConcat4 = cbind(img_utils::crop_channel(outr2, Houtc2, Woutc2, Houtc20, Woutc20, F1), outc20) |
| [outc21, Houtc21, Woutc21] = conv2d::forward(outConcat4, W21, b21, F2, Houtc20, Woutc20, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr17 = relu::forward(outc21) |
| [outc22, Houtc22, Woutc22] = conv2d::forward(outr17, W22, b22, F1, Houtc21, Woutc21, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| outr18 = relu::forward(outc22) |
| |
| # This last conv2d needs to create the segmentation map (1x1 filter): |
| [outc23, Houtc23, Woutc23] = conv2d::forward(outr18, W23, b23, F1, Houtc22, Woutc22, 1, 1, conv_stride, conv_stride, pad, pad) |
| probs = softmax2d::forward(outc23, K) |
| |
| # Compute loss & accuracy for training data |
| loss = cross_entropy_loss::forward(probs, labels) |
| accuracy = mean(rowIndexMax(probs) == rowIndexMax(labels)) |
| # print("[+] Completed forward pass on batch: train loss: " + loss + ", train accuracy: " + accuracy) |
| |
| # Compute data backward pass |
| |
| ## loss |
| dprobs = cross_entropy_loss::backward(probs, labels) |
| doutc23 = softmax2d::backward(dprobs, outc23, K) |
| |
| # Up-Convolution |
| # conv2d parameters: (previous_gradient, output height, output width, input to original layer, layer weight, layer bias, layer input channel number, input height, input width, filter height, filter width, stride height, stride width, pad height, pad width) |
| [doutc22, dW23, db23] = conv2d::backward(doutc23, Houtc23, Woutc23, outr18, W23, b23, F1, Houtc22, Woutc22, 1, 1, conv_stride, conv_stride, pad, pad) |
| doutr18 = relu::backward(doutc22, outc22) |
| [doutc21, dW22, db22] = conv2d::backward(doutr18, Houtc22, Woutc22, outr17, W22, b22, F1, Houtc21, Woutc21, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| doutr17 = relu::backward(doutc21, outc21) |
| [doutc20, dW21, db21] = conv2d::backward(doutr17, Houtc21, Woutc21, outConcat4, W21, b21, F2, Houtc20, Woutc20, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| doutc20_cropped = doutc20[,(F1*Houtc20*Woutc20+1):(ncol(doutc20))] #Removing half of the gradients since they are related to a different layer. |
| [doutc19, dW20, db20] = conv2d_transpose::backward(doutc20_cropped, Houtc20, Woutc20, outr16, W20, b20, F2, Houtc19, Woutc19, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad) |
| doutr16 = relu::backward(doutc19, outc19) |
| [doutc18, dW19, db19] = conv2d::backward(doutr16, Houtc19, Woutc19, outr15, W19, b19, F2, Houtc18, Woutc18, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| doutr15 = relu::backward(doutc18, outc18) |
| [doutc17, dW18, db18] = conv2d::backward(doutr15, Houtc18, Woutc18, outConcat3, W18, b18, F3, Houtc17, Woutc17, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| doutc17_cropped = doutc17[,(F2*Houtc17*Woutc17+1):ncol(doutc17)] |
| [doutc16, dW17, db17] = conv2d_transpose::backward(doutc17_cropped, Houtc17, Woutc17, outr14, W17, b17, F3, Houtc16, Woutc16, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad) |
| doutr14 = relu::backward(doutc16, outc16) |
| [doutc15, dW16, db16] = conv2d::backward(doutr14, Houtc16, Woutc16, outr13, W16, b16, F3, Houtc15, Woutc15, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| doutr13 = relu::backward(doutc15, outc15) |
| [doutc14, dW15, db15] = conv2d::backward(doutr13, Houtc15, Woutc15, outConcat2, W15, b15, F4, Houtc14, Woutc14, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| doutc14_cropped = doutc14[,(F3*Houtc14*Woutc14+1):ncol(doutc14)] |
| [doutc13, dW14, db14] = conv2d_transpose::backward(doutc14_cropped, Houtc14, Woutc14, outr12, W14, b14, F4, Houtc13, Woutc13, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad) |
| doutr12 = relu::backward(doutc13, outc13) |
| [doutc12, dW13, db13] = conv2d::backward(doutr12, Houtc13, Woutc13, outr11, W13, b13, F4, Houtc12, Woutc12, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| doutr11 = relu::backward(doutc12, outc12) |
| [doutc11, dW12, db12] = conv2d::backward(doutr11, Houtc12, Woutc12, outConcat1, W12, b12, F5, Houtc11, Woutc11, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| |
| # Bottom |
| doutc11_cropped = doutc11[,(F4*Houtc11*Woutc11+1):ncol(doutc11)] |
| [doutc10, dW11, db11] = conv2d_transpose::backward(doutc11_cropped, Houtc11, Woutc11, outr10, W11, b11, F5, Houtc10, Woutc10, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad) |
| doutr10 = relu::backward(doutc10, outc10) |
| [doutc9, dW10, db10] = conv2d::backward(doutr10, Houtc10, Woutc10, outr9, W10, b10, F5, Houtc9, Woutc9, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| doutr9 = relu::backward(doutc9, outc9) |
| [doutc8, dW9, db9] = conv2d::backward(doutr9, Houtc9, Woutc9, outd1, W9, b9, F4, Houtp4, Woutp4, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| |
| # Down-Convolution |
| doutd1 = dropout::backward(doutc8, outp4, dropProb, mask1) |
| doutp4 = max_pool2d::backward(doutd1, Houtp4, Woutp4, outr8, F4, Houtc8, Woutc8, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0) |
| doutr8 = relu::backward(doutp4, outc8) |
| [doutc7, dW8, db8] = conv2d::backward(doutr8, Houtc8, Woutc8, outr7, W8, b8, F4, Houtc7, Woutc7, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| doutr7 = relu::backward(doutc7, outc7) |
| [doutc6, dW7, db7] = conv2d::backward(doutr7, Houtc7, Woutc7, outp3, W7, b7, F3, Houtp3, Woutp3, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| doutp3 = max_pool2d::backward(doutc6, Houtp3, Woutp3, outr6, F3, Houtc6, Woutc6, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0) |
| doutr6 = relu::backward(doutp3, outc6) |
| [doutc5, dW6, db6] = conv2d::backward(doutr6, Houtc6, Woutc6, outr5, W6, b6, F3, Houtc5, Woutc5, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| doutr5 = relu::backward(doutc5, outc5) |
| [doutc4, dW5, db5] = conv2d::backward(doutr5, Houtc5, Woutc5, outp2, W5, b5, F2, Houtp2, Woutp2, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| doutp2 = max_pool2d::backward(doutc4, Houtp2, Woutp2, outr4, F2, Houtc4, Woutc4, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0) |
| doutr4 = relu::backward(doutp2, outc4) |
| [doutc3, dW4, db4] = conv2d::backward(doutr4, Houtc4, Woutc4, outr3, W4, b4, F2, Houtc3, Woutc3, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| doutr3 = relu::backward(doutc3, outc3) |
| [doutc2, dW3, db3] = conv2d::backward(doutr3, Houtc3, Woutc3, outp1, W3, b3, F1, Houtp1, Woutp1, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| doutp1 = max_pool2d::backward(doutc2, Houtp1, Woutp1, outr2, F1, Houtc2, Woutc2, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0) |
| doutr2 = relu::backward(doutp1, outc2) |
| [doutc1, dW2, db2] = conv2d::backward(doutr2, Houtc2, Woutc2, outr1, W2, b2, F1, Houtc1, Woutc1, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| doutr1 = relu::backward(doutc1,outc1) |
| [dx_batch, dW1, db1] = conv2d::backward(doutr1, Houtc1, Woutc1, features, W1, b1, C, Hin, Win, Hf, Wf, conv_stride, conv_stride, pad, pad) |
| |
| # Compute regularization backward pass |
| dW1_reg = l2_reg::backward(W1, lambda) |
| dW2_reg = l2_reg::backward(W2, lambda) |
| dW3_reg = l2_reg::backward(W3, lambda) |
| dW4_reg = l2_reg::backward(W4, lambda) |
| dW5_reg = l2_reg::backward(W5, lambda) |
| dW6_reg = l2_reg::backward(W6, lambda) |
| dW7_reg = l2_reg::backward(W7, lambda) |
| dW8_reg = l2_reg::backward(W8, lambda) |
| dW9_reg = l2_reg::backward(W9, lambda) |
| dW10_reg = l2_reg::backward(W10, lambda) |
| dW11_reg = l2_reg::backward(W11, lambda) |
| dW12_reg = l2_reg::backward(W12, lambda) |
| dW13_reg = l2_reg::backward(W13, lambda) |
| dW14_reg = l2_reg::backward(W14, lambda) |
| dW15_reg = l2_reg::backward(W15, lambda) |
| dW16_reg = l2_reg::backward(W16, lambda) |
| dW17_reg = l2_reg::backward(W17, lambda) |
| dW18_reg = l2_reg::backward(W18, lambda) |
| dW19_reg = l2_reg::backward(W19, lambda) |
| dW20_reg = l2_reg::backward(W20, lambda) |
| dW21_reg = l2_reg::backward(W21, lambda) |
| dW22_reg = l2_reg::backward(W22, lambda) |
| dW23_reg = l2_reg::backward(W23, lambda) |
| |
| dW1 = dW1 + dW1_reg |
| dW2 = dW2 + dW2_reg |
| dW3 = dW3 + dW3_reg |
| dW4 = dW4 + dW4_reg |
| dW5 = dW5 + dW5_reg |
| dW6 = dW6 + dW6_reg |
| dW7 = dW7 + dW7_reg |
| dW8 = dW8 + dW8_reg |
| dW9 = dW9 + dW9_reg |
| dW10 = dW10 + dW10_reg |
| dW11 = dW11 + dW11_reg |
| dW12 = dW12 + dW12_reg |
| dW13 = dW13 + dW13_reg |
| dW14 = dW14 + dW14_reg |
| dW15 = dW15 + dW15_reg |
| dW16 = dW16 + dW16_reg |
| dW17 = dW17 + dW17_reg |
| dW18 = dW18 + dW18_reg |
| dW19 = dW19 + dW19_reg |
| dW20 = dW20 + dW20_reg |
| dW21 = dW21 + dW21_reg |
| dW22 = dW22 + dW22_reg |
| dW23 = dW23 + dW23_reg |
| |
| gradients = list( |
| dW1, dW2, dW3, dW4, dW5, dW6, dW7, dW8, dW9, dW10, dW11, dW12, dW13, dW14, dW15, dW16, dW17, dW18, dW19, dW20, dW21, dW22, dW23, |
| db1, db2, db3, db4, db5, db6, db7, db8, db9, db10, db11, db12, db13, db14, db15, db16, db17, db18, db19, db20, db21, db22, db23 |
| ) |
| } |
| |
| /* |
| * Parameter-server aggregation step: applies one SGD-with-momentum update to every layer of the U-Net. |
| * |
| * The model list holds 92 entries in four consecutive groups of 23: weights W1..W23, biases b1..b23, |
| * weight velocities vW1..vW23 and bias velocities vb1..vb23. The gradients list holds the 23 weight |
| * gradients dW1..dW23 followed by the 23 bias gradients db1..db23. |
| * |
| * Inputs: |
| * - model: List of model weights, biases and their momentum velocities (layout described above) |
| * - hyperparams: List of hyper parameters containing (scalar[double]) learning_rate and (scalar[double]) mu |
| * - gradients: List of weight gradients followed by bias gradients |
| * |
| * Outputs: |
| * - model_result: List of updated weights, biases and velocities, in the same layout as model |
| */ |
| aggregation = function(list[unknown] model, |
| list[unknown] hyperparams, |
| list[unknown] gradients) |
| return (list[unknown] model_result) { |
| lr = as.double(as.scalar(hyperparams["learning_rate"])) |
| momentum = as.double(as.scalar(hyperparams["mu"])) |
| num_layers = 23 |
| |
| # One call per layer: unpack W, b, vW, vb, dW, db and apply SGD with momentum to W and b. |
| [W1, b1, vW1, vb1] = aggregation_update_layer(model, gradients, 1, num_layers, lr, momentum) |
| [W2, b2, vW2, vb2] = aggregation_update_layer(model, gradients, 2, num_layers, lr, momentum) |
| [W3, b3, vW3, vb3] = aggregation_update_layer(model, gradients, 3, num_layers, lr, momentum) |
| [W4, b4, vW4, vb4] = aggregation_update_layer(model, gradients, 4, num_layers, lr, momentum) |
| [W5, b5, vW5, vb5] = aggregation_update_layer(model, gradients, 5, num_layers, lr, momentum) |
| [W6, b6, vW6, vb6] = aggregation_update_layer(model, gradients, 6, num_layers, lr, momentum) |
| [W7, b7, vW7, vb7] = aggregation_update_layer(model, gradients, 7, num_layers, lr, momentum) |
| [W8, b8, vW8, vb8] = aggregation_update_layer(model, gradients, 8, num_layers, lr, momentum) |
| [W9, b9, vW9, vb9] = aggregation_update_layer(model, gradients, 9, num_layers, lr, momentum) |
| [W10, b10, vW10, vb10] = aggregation_update_layer(model, gradients, 10, num_layers, lr, momentum) |
| [W11, b11, vW11, vb11] = aggregation_update_layer(model, gradients, 11, num_layers, lr, momentum) |
| [W12, b12, vW12, vb12] = aggregation_update_layer(model, gradients, 12, num_layers, lr, momentum) |
| [W13, b13, vW13, vb13] = aggregation_update_layer(model, gradients, 13, num_layers, lr, momentum) |
| [W14, b14, vW14, vb14] = aggregation_update_layer(model, gradients, 14, num_layers, lr, momentum) |
| [W15, b15, vW15, vb15] = aggregation_update_layer(model, gradients, 15, num_layers, lr, momentum) |
| [W16, b16, vW16, vb16] = aggregation_update_layer(model, gradients, 16, num_layers, lr, momentum) |
| [W17, b17, vW17, vb17] = aggregation_update_layer(model, gradients, 17, num_layers, lr, momentum) |
| [W18, b18, vW18, vb18] = aggregation_update_layer(model, gradients, 18, num_layers, lr, momentum) |
| [W19, b19, vW19, vb19] = aggregation_update_layer(model, gradients, 19, num_layers, lr, momentum) |
| [W20, b20, vW20, vb20] = aggregation_update_layer(model, gradients, 20, num_layers, lr, momentum) |
| [W21, b21, vW21, vb21] = aggregation_update_layer(model, gradients, 21, num_layers, lr, momentum) |
| [W22, b22, vW22, vb22] = aggregation_update_layer(model, gradients, 22, num_layers, lr, momentum) |
| [W23, b23, vW23, vb23] = aggregation_update_layer(model, gradients, 23, num_layers, lr, momentum) |
| |
| # Re-pack in the same four-group layout as the incoming model list. |
| model_result = list( |
| W1, W2, W3, W4, W5, W6, W7, W8, W9, W10, W11, W12, W13, W14, W15, W16, W17, W18, W19, W20, W21, W22, W23, |
| b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15, b16, b17, b18, b19, b20, b21, b22, b23, |
| vW1, vW2, vW3, vW4, vW5, vW6, vW7, vW8, vW9, vW10, vW11, vW12, vW13, vW14, vW15, vW16, vW17, vW18, vW19, vW20, vW21, vW22, vW23, |
| vb1, vb2, vb3, vb4, vb5, vb6, vb7, vb8, vb9, vb10, vb11, vb12, vb13, vb14, vb15, vb16, vb17, vb18, vb19, vb20, vb21, vb22, vb23 |
| ) |
| } |
| |
| /* |
| * Applies one SGD-with-momentum step to the weights and bias of a single layer. |
| * |
| * Inputs: |
| * - model: Model list laid out as [W_1..W_n, b_1..b_n, vW_1..vW_n, vb_1..vb_n] |
| * - gradients: Gradient list laid out as [dW_1..dW_n, db_1..db_n] |
| * - layer: 1-based index of the layer to update |
| * - num_layers: Number of layers n in each group of the lists |
| * - learning_rate: Learning rate passed to sgd_momentum::update |
| * - mu: Momentum passed to sgd_momentum::update |
| * |
| * Outputs: |
| * - W, b: Updated weights and bias of the layer |
| * - vW, vb: Updated velocities of the layer |
| */ |
| aggregation_update_layer = function(list[unknown] model, list[unknown] gradients, |
| int layer, int num_layers, double learning_rate, double mu) |
| return (matrix[double] W, matrix[double] b, matrix[double] vW, matrix[double] vb) { |
| W = as.matrix(model[layer]) |
| b = as.matrix(model[num_layers + layer]) |
| vW = as.matrix(model[2 * num_layers + layer]) |
| vb = as.matrix(model[3 * num_layers + layer]) |
| dW = as.matrix(gradients[layer]) |
| db = as.matrix(gradients[num_layers + layer]) |
| [W, vW] = sgd_momentum::update(W, dW, learning_rate, mu, vW) |
| [b, vb] = sgd_momentum::update(b, db, learning_rate, mu, vb) |
| } |
| |
| /* |
| * Computes the cross-entropy loss and a row-wise accuracy for a batch of U-Net predictions. |
| * |
| * Accuracy is the fraction of rows of probs whose largest entry lies in the same column as the |
| * largest entry of the corresponding row of y. |
| * NOTE(review): this takes a single argmax over each whole flattened row, not one per pixel; |
| * confirm this is the intended metric for segmentation outputs. |
| * |
| * Inputs: |
| * - probs: Class probabilities, one row per example (documented as shape (N, C*Hin*Win)) |
| * - y: Target matrix with the same shape as probs |
| * |
| * Outputs: |
| * - loss: Scalar loss, of shape (1) |
| * - accuracy: Scalar accuracy, of shape (1) |
| */ |
| eval = function(matrix[double] probs, matrix[double] y) |
| return (double loss, double accuracy) { |
| # Loss: cross-entropy between predicted probabilities and targets |
| loss = cross_entropy_loss::forward(probs, y) |
| |
| # Accuracy: share of rows where predicted and target argmax columns agree |
| predicted_idx = rowIndexMax(probs) |
| target_idx = rowIndexMax(y) |
| accuracy = mean(predicted_idx == target_idx) |
| } |
| |
| /* |
| * Runs the model on a validation set and reports its loss and accuracy. |
| * |
| * Combines predict and eval: the predictions for val_features are scored against val_labels. |
| * |
| * Inputs: |
| * - val_features: Validation data features |
| * - val_labels: Validation data labels |
| * - model: List of weights of the trained model |
| * - hyperparams: Hyperparameters; must contain C, Hin, Win, M and K, all of which are forwarded to predict |
| * - F1: Number of filters of the top layer of the U-Net model. Default is 64. |
| * - batch_size: Batch size used for prediction. Default is 32. |
| * |
| * Outputs: |
| * - loss: Scalar loss, of shape (1). |
| * - accuracy: Scalar accuracy, of shape (1). |
| */ |
| validate = function(matrix[double] val_features, matrix[double] val_labels, |
| list[unknown] model, list[unknown] hyperparams, int F1 = 64, int batch_size = 32) |
| return (double loss, double accuracy) |
| { |
| # Unpack the integer hyperparameters required by predict |
| in_channels = as.integer(as.scalar(hyperparams["C"])) |
| in_height = as.integer(as.scalar(hyperparams["Hin"])) |
| in_width = as.integer(as.scalar(hyperparams["Win"])) |
| m_param = as.integer(as.scalar(hyperparams["M"])) |
| k_param = as.integer(as.scalar(hyperparams["K"])) |
| |
| # Predict on the validation features, then score against the labels |
| val_probs = predict(val_features, in_channels, in_height, in_width, batch_size, model, m_param, k_param, F1) |
| [loss, accuracy] = eval(val_probs, val_labels) |
| } |