blob: e47f841a0cee3ffb469bec2e2d04cf8540bc13f0 [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
source("scripts/nn/layers/bilstm.dml") as bilstm
source("scripts/nn/layers/lstm.dml") as lstm
# Checks bilstm::forward against a reference output read from CSV and
# writes the maximum absolute elementwise error.
#
# Script arguments:
#   $1  batch_size
#   $2  seq_length
#   $3  num_features
#   $4  hidden_size
#   $5  debug flag (0/1); when set, the computed error is also printed
#   $6  return-sequences flag (0/1), forwarded to bilstm::forward
#   $7  path prefix of the reference CSV; the file read is
#       <$7>_<$1>_<$2>_<$3>_<$4>.csv
#   $8  output path for the max absolute error (text format)
batch_size = as.integer($1)
seq_length = as.integer($2)
num_features = as.integer($3)
hidden_size = as.integer($4)
debug = as.logical(as.integer($5))
# Deliberately not called `seq`: that name would shadow the builtin seq()
# used below to generate the deterministic test data.
return_seq = as.logical(as.integer($6))
factor = 0.01

# Input: the ramp 0, 0.01, 0.02, ... laid out as batch_size rows of
# seq_length*num_features columns.
lstmIn = matrix(seq(0, batch_size*seq_length*num_features - 1)*factor,
  rows=batch_size, cols=(seq_length*num_features))

# Weights: a ramp of (num_features + hidden_size)*hidden_size*4 values,
# shifted down by (num_features + hidden_size)*hidden_size*factor so that it
# contains negative entries, then reshaped to
# (num_features + hidden_size) x (4*hidden_size).
W = seq(0, (num_features + hidden_size)*hidden_size*4 - 1)*factor
W = W - (num_features + hidden_size)*hidden_size*factor
W = matrix(W, rows=num_features + hidden_size, cols=hidden_size*4)

# Bias and initial output/cell states. Both branches produce states with
# 2*batch_size rows, presumably one batch_size block per direction
# (NOTE(review): confirm the row order expected by bilstm::forward).
if(batch_size == 2){
  b = (matrix(seq(0, 4*hidden_size - 1), rows=1, cols=4*hidden_size) - 2*hidden_size)*factor
  c0 = (matrix(seq(0, 2*batch_size*hidden_size - 1), rows=batch_size*2, cols=hidden_size) - 2*hidden_size)*factor
  out0 = (matrix(seq(0, 2*batch_size*hidden_size - 1), rows=batch_size*2, cols=hidden_size) + 2*hidden_size)*factor
} else {
  b = matrix(1, rows=1, cols=4*hidden_size)*factor
  out0 = matrix(1, rows=batch_size, cols=hidden_size)*factor
  c0 = matrix(0, rows=batch_size, cols=hidden_size)*factor
  # Same initial state for both directions.
  c0 = rbind(c0, c0)
  out0 = rbind(out0, out0)
}

# Both directions share the same weights W and bias b.
[out2, c2, cache_out2, cache_c2, cache_ifog2] = bilstm::forward(lstmIn, W, W, b, b,
  seq_length, num_features, return_seq, out0, c0)

expected = read($7 + "_" + $1 + "_" + $2 + "_" + $3 + "_" + $4 + ".csv", format="csv");
if(!return_seq){
  # The reference stores the full sequence, 2*hidden_size columns per
  # timestep. Keep columns 1..hidden_size of the last timestep and columns
  # hidden_size+1..2*hidden_size of the first timestep -- presumably the
  # final forward and final backward outputs (TODO confirm against the
  # reference generator).
  expectedA = expected[, (seq_length-1)*hidden_size*2 + 1 : (seq_length-1)*hidden_size*2 + hidden_size]
  expectedB = expected[, hidden_size + 1 : hidden_size*2]
  expected = cbind(expectedA, expectedB)
}

error = max(abs(expected - out2))
if(debug){
  print("bilstm forward max abs error: " + error)
}
write(error, $8, format="text");