blob: ec1a7fcbf2b5c5b29ea9f69a953977464deafa6d [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# This script provides a snippet of a common use case for our deep learning algorithms.
# Imports
source("nn/layers/conv2d_builtin.dml") as conv2d
source("nn/layers/max_pool2d_builtin.dml") as max_pool2d
# Hyperparameters & Settings
# batch_size: number of examples per mini-batch.
# num_batches: number of mini-batches; the script generates
# batch_size * num_batches examples in total.
batch_size = 32
num_batches = 2
# Generate dummy input data
# X has one row per example (N rows); each row holds C*Hin*Win random values
# drawn from a normal distribution (pdf="normal").
N = batch_size * num_batches # num examples
C = 3 # num input channels
Hin = 224 # input height
Win = 224 # input width
X = rand(rows=N, cols=C*Hin*Win, pdf="normal")
# Create network:
# A single conv layer ("conv1") with F1 filters of size Hf x Wf over C input
# channels, using a stride of 1 in both directions.
Hf = 3 # filter height
Wf = 3 # filter width
stride = 1
# Padding chosen so the conv output keeps the input's height/width. Note that
# (Hf - stride) / 2 coincides with the usual same-size padding (Hf - 1) / 2
# only because stride = 1 here; it gives pad = 1 for Hf = 3.
pad = 1 # For same dimensions, (Hf - stride) / 2
F1 = 32 # num conv filters in conv1
# Initialize conv1 weights (Wc1) and biases (bc1). Their exact shapes are set by
# conv2d::init (presumably Wc1 is F1 x C*Hf*Wf - confirm in
# nn/layers/conv2d_builtin.dml). The layer consumes rows shaped (N, C*Hin*Win).
[Wc1, bc1] = conv2d::init(F1, C, Hf, Wf) # inputs: (N, C*Hin*Win)
# Create data structure to store gradients computed in parallel
# One row per mini-batch: row j receives mini-batch j's gradients, flattened
# into a single row. Both buffers start as all zeros; in this script only the
# row for the hardcoded first iteration (j = 1) is ever written.
# The doutc1_agg width assumes the conv output keeps the input's spatial size
# (Houtc1 == Hin, Woutc1 == Win), which holds for the 3x3, stride-1, pad-1
# setup above.
doutc1_agg = matrix(0, rows=num_batches, cols=batch_size*F1*Hin*Win)
dWc1_agg = matrix(0, rows=num_batches, cols=nrow(Wc1)*ncol(Wc1))
# Imagine that a loop over mini-batches would go here, but for this test, we will
# hardcode to the first iteration.
j = 1
# Empty loop that never executes; it serves only as a DAG cut, so the
# statements before and after it are compiled as separate DAGs.
while(FALSE){} # cut DAG!
# Get the j-th mini-batch: rows beg..end (1-based, inclusive) of X, i.e.
# batch_size consecutive examples.
beg = ((j-1) * batch_size) + 1
end = beg + batch_size - 1
X_batch = X[beg:end,]
# Note: This causes Spark execution unless constant folding + IPA second-chance compilation is
# enabled. The reason is that the Houtc1 & Woutc1 DAGs are unevaluated, so the sizes of downstream
# nodes in the graph that use Houtc1 & Woutc1 become unknown. Mathematically, those DAGs can be
# computed during initial compilation, because they depend only on scalar arguments passed into
# conv2d::forward. However, the initial constant folding (static rewrite) pass can't replace the
# following Houtc1 & Woutc1, because they depend on transient reads of the arguments passed into
# conv2d::forward. Once IPA runs, IPA scalar replacement replaces those transient reads with
# scalars. After that, a second pass of constant folding (static rewrite) can evaluate Houtc1 &
# Woutc1 to literals. A subsequent second pass of IPA can then use those literals during size
# propagation, so downstream ops get known sizes. Overall, with a second pass of static
# rewrites + IPA, all sizes in this script are known during initial compilation, and no Spark
# instructions are compiled or run. Conversely, without that second pass, these sizes remain
# unknown even during recompilation.
# Forward pass: conv1, then max pooling. The pooling literals 2, 2, 2, 2, 0, 0
# appear to be pool height/width, stride height/width and padding height/width
# (2x2 window, stride 2, no padding) - confirm against max_pool2d_builtin.dml.
[outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, Wc1, bc1, C, Hin, Win, Hf, Wf, stride, stride,
pad, pad)
[outp1, Houtp1, Woutp1] = max_pool2d::forward(outc1, F1, Houtc1, Woutc1, 2, 2, 2, 2, 0, 0)
# Random stand-in for the upstream gradient w.r.t. the pooling output (this
# snippet has no loss layer), shaped like outp1.
doutp1 = rand(rows=nrow(outp1), cols=F1*Houtp1*Woutp1)
# Backward pass through the pooling layer, then through conv1. The outputs are
# presumably the gradients w.r.t. the batch input (dX_batch), weights (dWc1)
# and biases (dbc1).
doutc1 = max_pool2d::backward(doutp1, Houtp1, Woutp1, outc1, F1, Houtc1, Woutc1, 2, 2, 2, 2, 0, 0)
[dX_batch, dWc1, dbc1] = conv2d::backward(doutc1, Houtc1, Woutc1, X_batch, Wc1, bc1, C, Hin, Win,
Hf, Wf, stride, stride, pad, pad)
# Reshape this mini-batch's gradients into single rows and store them in row j
# of the aggregation buffers. The doutc1 reshape relies on Houtc1 == Hin and
# Woutc1 == Win (same-size conv output).
doutc1_agg[j,] = matrix(doutc1, rows=1, cols=batch_size*F1*Hin*Win)
dWc1_agg[j,] = matrix(dWc1, rows=1, cols=nrow(Wc1)*ncol(Wc1))
# Print outputs to force execution
# Another never-executed loop used as a DAG cut, separating the output from the
# computation above. Printing the sums of both aggregation buffers makes them
# used results, presumably so the computation producing them is not eliminated
# as unused.
while(FALSE){} # cut DAG!
print(sum(doutc1_agg) + " " + sum(dWc1_agg))