| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| |
# This script aims to provide a snippet of a common use case for our deep learning algorithms.
| |
| # Imports |
| source("nn/layers/conv2d_builtin.dml") as conv2d |
| source("nn/layers/max_pool2d_builtin.dml") as max_pool2d |
| |
# Hyperparameters & Settings
batch_size = 32   # examples per mini-batch
num_batches = 2   # number of mini-batches worth of dummy data to generate

# Generate dummy input data
N = batch_size * num_batches # num examples
C = 3 # num input channels
Hin = 224 # input height
Win = 224 # input width
# One example per row, flattened in (C, Hin, Win) order to match the nn-layer convention.
X = rand(rows=N, cols=C*Hin*Win, pdf="normal")

# Create network:
# A single conv layer with 3x3 filters, stride 1, and "same" padding, so the
# spatial output dims of conv1 equal the input dims (Hout = Hin, Wout = Win).
Hf = 3 # filter height
Wf = 3 # filter width
stride = 1
pad = 1 # For same dimensions, (Hf - stride) / 2
F1 = 32 # num conv filters in conv1
[Wc1, bc1] = conv2d::init(F1, C, Hf, Wf) # inputs: (N, C*Hin*Win)

# Create data structure to store gradients computed in parallel
# One row per mini-batch; each row holds a flattened copy of that batch's gradient.
# Note: cols=batch_size*F1*Hin*Win relies on the "same" padding above keeping
# conv1's output spatial dims equal to (Hin, Win).
doutc1_agg = matrix(0, rows=num_batches, cols=batch_size*F1*Hin*Win)
dWc1_agg = matrix(0, rows=num_batches, cols=nrow(Wc1)*ncol(Wc1))
| |
# Imagine that a loop over mini-batches would go here, but for this test, we will
# hardcode to the first iteration.
j = 1
while(FALSE){} # cut DAG!

# Get a mini-batch in this group
# Rows [beg, end] of X (1-indexed, inclusive) form mini-batch j; with j = 1
# this selects rows 1..batch_size.
beg = ((j-1) * batch_size) + 1
end = beg + batch_size - 1
X_batch = X[beg:end,]
| |
# Note: This causes Spark execution, unless constant folding + IPA second chance compilation is
# enabled, due to the Houtc1 & Woutc1 DAGs being unevaluated, and thus the size of downstream nodes
# in the graph that use Houtc1 & Woutc1 becoming unknown. From a mathematical standpoint, those
# DAGs can be computed during initial compilation since they are dependent only on scalar arguments
# passed into conv2d::forward. However, the initial constant folding (static rewrite) pass can't
# replace the following Houtc1 & Woutc1 due to them being dependent on transient reads of the
# arguments passed into conv2d::forward. Once IPA runs, those transient reads will be replaced with
# scalars due to the IPA scalar replacement. Once that is complete, a second pass of constant
# folding (static rewrite) can evaluate Houtc1 & Woutc1 to literals. Given that, a subsequent
# second pass of IPA can make use of those literals during size propagation, thus allowing
# downstream ops to have known sizes. Overall, with a second pass of static rewrites +
# IPA, all sizes in this script can be known during initial compilation, and thus no Spark
# instructions will be compiled or run. On the contrary, without that second pass, these sizes
# will remain unknown even during recompilation.
# Forward pass: conv1, then 2x2 max-pooling with stride 2 and no padding.
[outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, Wc1, bc1, C, Hin, Win, Hf, Wf, stride, stride,
pad, pad)
[outp1, Houtp1, Woutp1] = max_pool2d::forward(outc1, F1, Houtc1, Woutc1, 2, 2, 2, 2, 0, 0)

# Backward pass: a random upstream gradient stands in for a loss gradient;
# propagate it through the pooling layer and then the conv layer.
doutp1 = rand(rows=nrow(outp1), cols=F1*Houtp1*Woutp1)
doutc1 = max_pool2d::backward(doutp1, Houtp1, Woutp1, outc1, F1, Houtc1, Woutc1, 2, 2, 2, 2, 0, 0)
[dX_batch, dWc1, dbc1] = conv2d::backward(doutc1, Houtc1, Woutc1, X_batch, Wc1, bc1, C, Hin, Win,
Hf, Wf, stride, stride, pad, pad)

# Store this batch's gradients as flattened rows of the aggregation matrices.
doutc1_agg[j,] = matrix(doutc1, rows=1, cols=batch_size*F1*Hin*Win)
dWc1_agg[j,] = matrix(dWc1, rows=1, cols=nrow(Wc1)*ncol(Wc1))
| |
# Print outputs to force execution
# Summing both aggregates into a single printed string forces evaluation of the
# full forward/backward pipeline above.
while(FALSE){} # cut DAG!
print(sum(doutc1_agg) + " " + sum(dWc1_agg))