| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| |
| # This script provides a snippet of a common use case for our deep learning algorithms: a conv2d forward and backward pass on one mini-batch. |
| |
| # Imports |
| source("nn/layers/conv2d_builtin.dml") as conv2d |
| |
| # Hyperparameters & settings: mini-batch size and number of mini-batches |
| batch_size = 32 |
| num_batches = 2 |
| |
| # Generate dummy input data: X has N rows of C*Hin*Win normally-distributed random values |
| N = batch_size * num_batches # total num examples across all mini-batches |
| C = 3 # num input channels |
| Hin = 224 # input height |
| Win = 224 # input width |
| X = rand(rows=N, cols=C*Hin*Win, pdf="normal") |
| |
| # Create network: a single conv2d layer (conv1) with F1 filters of size Hf x Wf |
| Hf = 3 # filter height |
| Wf = 3 # filter width |
| stride = 1 |
| pad = 1 # Padding chosen to keep output dimensions equal to input dimensions: (Hf - stride) / 2 |
| F1 = 32 # num conv filters in conv1 |
| [Wc1, bc1] = conv2d::init(F1, C, Hf, Wf) # conv1 weights & biases; layer inputs: (N, C*Hin*Win) |
| |
| # Per-mini-batch stores (one row each) for gradients computed in parallel; NOTE(review): Hin*Win assumes Houtc1==Hin, Woutc1==Win |
| doutc1_agg = matrix(0, rows=num_batches, cols=batch_size*F1*Hin*Win) |
| dWc1_agg = matrix(0, rows=num_batches, cols=nrow(Wc1)*ncol(Wc1)) |
| |
| # A loop over mini-batches (index j) would normally go here, but for this test, j is |
| # hardcoded to the first iteration. |
| j = 1 |
| while(FALSE){} # cut DAG! Never-executed empty loop, present only to split the DAG at this point. |
| |
| # Get mini-batch j: rows beg..end (inclusive) of X, i.e. batch_size consecutive rows |
| beg = ((j-1) * batch_size) + 1 |
| end = beg + batch_size - 1 |
| X_batch = X[beg:end,] |
| |
| [outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, Wc1, bc1, C, Hin, Win, Hf, Wf, |
| stride, stride, pad, pad) |
| |
| # Note: The rand() below causes Spark execution, unless constant folding + IPA second chance compilation is |
| # enabled, due to the Houtc1 & Woutc1 DAGs being unevaluated, and thus the size of doutc1 remaining |
| # unknown. From a mathematical standpoint, those DAGs can be computed during initial compilation |
| # since they are dependent only on scalar arguments passed into conv2d::forward. However, the |
| # initial constant folding (static rewrite) pass can't replace the following Houtc1 & Woutc1 due to |
| # them being dependent on transient reads of the arguments passed into conv2d::forward. Once IPA |
| # runs, those transient reads will be replaced with scalars due to the IPA scalar replacement. |
| # Once that is complete, a second pass of constant folding (static rewrite) can evaluate Houtc1 & |
| # Woutc1 to literals. Given that, a subsequent second pass of IPA can make use of those literals |
| # during size propagation, thus allowing doutc1 to have a known size. Additionally, this allows |
| # known sizes and known literals to be passed into the conv2d::backward function below, once again |
| # allowing for the propagation of known sizes. Overall, with a second pass of static rewrites + |
| # IPA, all sizes in this script can be known during initial compilation, and thus no Spark |
| # instructions will be compiled or run. In contrast, without that second pass, these sizes |
| # will remain unknown even during recompilation. |
| doutc1 = rand(rows=nrow(outc1), cols=F1*Houtc1*Woutc1) |
| [dX_batch, dWc1, dbc1] = conv2d::backward(doutc1, Houtc1, Woutc1, X_batch, Wc1, bc1, C, |
| Hin, Win, Hf, Wf, stride, stride, pad, pad) |
| |
| doutc1_agg[j,] = matrix(doutc1, rows=1, cols=batch_size*F1*Hin*Win) |
| dWc1_agg[j,] = matrix(dWc1, rows=1, cols=nrow(Wc1)*ncol(Wc1)) |
| |
| # Print the sums of both aggregate matrices to force execution of the computations above |
| while(FALSE){} # cut DAG! Never-executed empty loop, present only to split the DAG at this point. |
| print(sum(doutc1_agg) + " " + sum(dWc1_agg)) |