| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| source("nn/layers/affine.dml") as affine |
| |
| |
| # Lineage-based autodiff check: run an affine forward pass, derive gradients from the |
| # lineage trace of its output via autoDiff, and compare them with the gradients |
| # returned by affine::backward. |
| # |
| # Inputs are generated with fully dense rand() calls (sparsity=1). |
| # NOTE(review): the original remark here mentioned parsing issues with the rand command |
| # within lineage and initializing the matrix by hand -- confirm whether rand() is fine. |
| # NOTE(review): X_batch and W_1 are both M x N and b_1 is 1 x M; presumably |
| # affine::forward computes X %*% W + b, in which case the shapes only line up because |
| # M == N -- TODO confirm before changing either constant. |
| M = 5; N = 5 |
| X_batch = rand(rows=M, cols=N, sparsity=1) |
| |
| W_1 = rand(rows=M, cols=N, sparsity=1) |
| b_1 = matrix(0, rows=1, cols=M) |
| |
| # Forward pass, then capture the lineage trace of its result as a string. |
| prob = affine::forward(X_batch, W_1, b_1) |
| lin = lineage(prob) |
| |
| # TODO: stop the instruction parser from parsing the lineage string. |
| # Work-around for now: append the string "foo" to the lineage string. The append is |
| # guarded by sum(prob) > 0, so it is only applied when that condition holds. |
| if(sum(prob) > 0) |
| lin = lin+"foo" |
| |
| # Even after appending "foo" the compiler keeps parsing the lineage instructions, |
| # so the lineage string is wrapped in a list and passed as a list item to avoid that. |
| # autoDiff then derives the gradients by parsing the lineage instructions. |
| diffs = autoDiff(output=prob, lineage=list(lin)); |
| |
| # Pull the entries 'dX', 'dW' and 'dB' out of the autoDiff result and cast each to a matrix. |
| ad_dX = as.matrix(diffs['dX']) |
| ad_dW = as.matrix(diffs['dW']) |
| ad_dB = as.matrix(diffs['dB']) |
| |
| # Reference gradients from the hand-written backward script. |
| # NOTE(review): prob is passed as the first argument, presumably the upstream |
| # gradient (dout) -- confirm against nn/layers/affine.dml. |
| [dX, dW, dB] = affine::backward(prob, X_batch, W_1, b_1) |
| |
| # NOTE: despite the "same" prefix, these hold the result of an element-wise != comparison, |
| # i.e. they mark the cells where the two gradients DIFFER. The comparison is exact |
| # (no tolerance is applied). |
| sameX = dX != ad_dX |
| sameW = dW != ad_dW |
| sameB = dB != ad_dB |
| |
| # True only when none of the three comparisons flagged a differing cell. |
| # NOTE(review): output is not written or otherwise used in the lines below. |
| output = ((sum(sameX) == 0) & (sum(sameW) == 0) & (sum(sameB) == 0)) |
| |
| # Write the reference dX to $1 and the autodiff dX to $2 |
| # (presumably compared by the calling test harness -- confirm). |
| write(dX, $1) |
| write(ad_dX, $2) |