# blob: 036a3219b18596c881341a07c743130198b8f9f5 [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
source("scripts/nn/layers/bilstm.dml") as bilstm
source("scripts/nn/layers/lstm.dml") as lstm
# ---------------------------------------------------------------
# Test configuration, taken from positional command-line arguments:
#   $1 batch size, $2 sequence length, $3 number of input features,
#   $4 hidden size, $5 and $6 boolean flags passed as integers.
# ---------------------------------------------------------------
batch_size   = as.integer($1)
seq_length   = as.integer($2)
num_features = as.integer($3)
hidden_size  = as.integer($4)
debug        = as.logical(as.integer($5))
# Note: `seq` (this flag) and the builtin `seq(...)` used below coexist in this script.
seq          = as.logical(as.integer($6))

# Common scale applied to every generated value.
factor = 0.01

# Input matrix: 0, 0.01, 0.02, ... laid out as batch_size x (seq_length*num_features).
n_inputs = batch_size*seq_length*num_features
x_vals = seq(0, n_inputs - 1)*factor
lstmIn = matrix(x_vals, rows=batch_size, cols=seq_length*num_features)

# Weight matrix of shape (num_features + hidden_size) x (4*hidden_size):
# an increasing ramp scaled by `factor`, then shifted down by a constant offset.
w_rows = num_features + hidden_size
w_cols = hidden_size*4
w_vals = seq(0, w_rows*w_cols - 1)*factor
w_shift = w_rows*hidden_size*factor
W = matrix(w_vals - w_shift, rows=w_rows, cols=w_cols)

# Bias row vector, every entry equal to 1*factor.
b = matrix(1, rows=1, cols=w_cols)*factor

# Initial output and cell state with 2*batch_size rows each, i.e. two
# identical batch_size x hidden_size halves stacked on top of each other
# (presumably one half per direction -- confirm against bilstm.dml).
out0 = matrix(1, rows=2*batch_size, cols=hidden_size)*factor
c0 = matrix(0, rows=2*batch_size, cols=hidden_size)*factor
# Forward pass. The same W and the same b are passed for both
# weight arguments and both bias arguments of the layer.
[out, c, cache_out, cache_c, cache_ifog] = bilstm::forward(
    lstmIn, W, W, b, b, seq_length, num_features, seq, out0, c0)

# Upstream gradients fed into the backward pass; the pattern depends on batch_size.
# By default the cell-state gradient is all zeros.
dc = matrix(0, rows=batch_size*2, cols=hidden_size)
if (batch_size == 5) {
  # Increasing ramp with seq_length*hidden_size*2 columns per row.
  dout = matrix(seq(0, batch_size*hidden_size*seq_length*2 - 1),
                rows=batch_size, cols=hidden_size*seq_length*2)
} else if (batch_size == 4) {
  # Only the cell-state gradient is non-zero here: dc is a ramp, dout is zero.
  dc = matrix(seq(0, batch_size*hidden_size*2 - 1), rows=batch_size*2, cols=hidden_size)
  dout = matrix(0, rows=batch_size, cols=hidden_size*2)
} else if (batch_size == 3) {
  # NOTE(review): nothing is assigned on this path, so `dout` is undefined
  # when bilstm::backward is called below -- confirm this case is never used.
} else {
  # All-ones gradient; the column count includes seq_length only if `seq` is set.
  if (seq) {
    dout = matrix(1, rows=batch_size, cols=hidden_size*seq_length*2)
  } else {
    dout = matrix(1, rows=batch_size, cols=hidden_size*2)
  }
}

# Backward pass, reusing the inputs and the caches produced by the forward pass.
[dx, dw, db, dw_reverse, db_reverse, dout0, dc0] = bilstm::backward(
    dout, dc, lstmIn, W, W, b, b, seq_length, num_features, seq,
    out0, c0, cache_out, cache_c, cache_ifog)
#print(toString(dx))
#print(toString(dout0))
#print(toString(dc0))
#print(toString(dw))
#print(toString(dw_reverse))
#print(toString(db))
#print(toString(db_reverse))
# Compare the computed gradients against stored reference results.
# Each reference file is named "<prefix>_<$1>_<$2>_<$3>_<$4>.csv",
# with the prefixes given by $7, $8 and $9.

# (1) dw, dw_reverse, db and db_reverse, stacked row-wise.
ref_params = read($7 + "_" + $1 +"_" + $2 +"_" + $3 +"_" + $4 + ".csv", format="csv")
param_grads = rbind(dw, dw_reverse, db, db_reverse)
err_params = max(abs(ref_params - param_grads))

# (2) dout0 and dc0, stacked row-wise.
ref_states = read($8 + "_" + $1 +"_" + $2 +"_" + $3 +"_" + $4 + ".csv", format="csv")
state_grads = rbind(dout0, dc0)
err_states = max(abs(ref_states - state_grads))

# (3) dx.
ref_input = read($9 + "_" + $1 +"_" + $2 +"_" + $3 +"_" + $4 + ".csv", format="csv")
err_input = max(abs(ref_input - dx))

# Collect the three maximum absolute errors in a 1x3 matrix and write it to $10.
errors = matrix(0, rows=1, cols=3)
errors[1,1] = err_params
errors[1,2] = err_states
errors[1,3] = err_input
write(errors, $10, format="text")