blob: 26922dcd519c6915a38f868c63b5e5d0ced63d2d [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
source("nn/layers/affine.dml") as affine
# Test script: compares the gradients that autoDiff() derives from the lineage
# of an affine forward pass against the gradients returned by the hand-written
# affine::backward function. Arguments: $1 = output path for dX (backward
# script), $2 = output path for ad_dX (autoDiff).

# Input setup: X_batch and W_1 are M x N matrices generated with rand()
# (sparsity=1); b_1 is a 1 x M row of zeros.
# NOTE(review): the original comment here mentioned initializing the matrices by
# hand because of parsing issues with the rand command within lineage, yet the
# code uses rand() - confirm whether that concern still applies.
# NOTE(review): b_1 uses cols=M while W_1 uses cols=N; the two agree only
# because M == N here - confirm the intended bias width before changing M or N.
M = 5; N = 5
X_batch = rand(rows=M, cols=N, sparsity=1)
W_1 = rand(rows=M, cols=N, sparsity=1)
b_1 = matrix(0, rows=1, cols=M)

# Forward pass through the affine layer, then capture the lineage of its result.
prob = affine::forward(X_batch, W_1, b_1)
lin = lineage(prob)

# TODO: stop the instruction parser from parsing the lineage string.
# For now this is worked around by appending the string "foo" to it.
# NOTE(review): the sum(prob) > 0 guard presumably makes the append depend on
# runtime data so it is not resolved at compile time - confirm.
if(sum(prob) > 0)
lin = lin+"foo"

# The lineage is passed as a list item: even after appending "foo" the compiler
# keeps parsing the lineage instructions, and wrapping the string in a list
# avoids that.
# autoDiff builds the derivatives by parsing the lineage instructions; the
# entries named 'dX', 'dW' and 'dB' of its result are read back as matrices.
diffs = autoDiff(output=prob, lineage=list(lin));
ad_dX = as.matrix(diffs['dX'])
ad_dW = as.matrix(diffs['dW'])
ad_dB = as.matrix(diffs['dB'])

# Reference derivatives from the layer's backward function. prob is passed as
# the first argument (presumably the upstream gradient - verify against
# nn/layers/affine.dml).
[dX, dW, dB] = affine::backward(prob, X_batch, W_1, b_1)

# Despite their names, sameX/sameW/sameB are element-wise inequality masks
# (non-zero where the two results differ); output is TRUE only when all three
# masks sum to zero, i.e. the results match exactly.
# NOTE(review): output is not used anywhere in the visible script.
sameX = dX != ad_dX
sameW = dW != ad_dW
sameB = dB != ad_dB
output = ((sum(sameX) == 0) & (sum(sameW) == 0) & (sum(sameB) == 0))

# Only the dX pair is written out; dW/dB and their autoDiff counterparts are
# not exported by this script.
write(dX, $1)
write(ad_dX, $2)