| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| source("scripts/nn/layers/bilstm.dml") as bilstm |
| source("scripts/nn/layers/lstm.dml") as lstm |
| |
| # Test driver: feeds deterministic inputs to bilstm::forward and writes the |
| # largest absolute deviation from a reference CSV to the location in $8. |
| # |
| # Arguments: |
| #   $1 batch size, $2 sequence length, $3 number of features, $4 hidden size |
| #   $5 debug flag (0/1) - parsed, but not used anywhere below |
| #   $6 flag (0/1) forwarded to bilstm::forward as its 8th argument |
| #   $7 prefix of the reference CSV ("<$7>_<$1>_<$2>_<$3>_<$4>.csv") |
| #   $8 output location of the scalar error |
| batch_size = as.integer($1) |
| seq_length = as.integer($2) |
| num_features = as.integer($3) |
| hidden_size = as.integer($4) |
| debug = as.logical(as.integer($5)) |
| seq_flag = as.logical(as.integer($6)) |
| |
| factor = 0.01 |
| |
| # Input: the ramp 0, 0.01, 0.02, ... shaped as one row per batch entry. |
| in_cols = seq_length*num_features |
| X = matrix(seq(0, batch_size*in_cols - 1)*factor, rows=batch_size, cols=in_cols) |
| |
| # Weights: a scaled ramp shifted down by a quarter of its length, so the |
| # matrix holds both negative and positive entries. |
| w_rows = num_features + hidden_size |
| gate_cols = hidden_size*4 |
| w_ramp = seq(0, w_rows*gate_cols - 1)*factor |
| w_ramp = w_ramp - w_rows*hidden_size*factor |
| W_shared = matrix(w_ramp, rows=w_rows, cols=gate_cols) |
| |
| # Bias and initial states. Both initial-state matrices carry 2*batch_size |
| # rows - presumably one batch_size block per direction; confirm against |
| # bilstm::forward. |
| if (batch_size == 2) { |
|   state_rows = batch_size*2 |
|   shift = 2*hidden_size |
|   state_ramp = matrix(seq(0, state_rows*hidden_size - 1), rows=state_rows, cols=hidden_size) |
|   bias = (matrix(seq(0, gate_cols - 1), rows=1, cols=gate_cols) - shift)*factor |
|   c_init = (state_ramp - shift)*factor |
|   out_init = (state_ramp + shift)*factor |
| } else { |
|   bias = matrix(1, rows=1, cols=gate_cols)*factor |
|   out_half = matrix(1, rows=batch_size, cols=hidden_size)*factor |
|   c_half = matrix(0, rows=batch_size, cols=hidden_size)*factor |
|   c_init = rbind(c_half, c_half) |
|   out_init = rbind(out_half, out_half) |
| } |
| |
| # The same weight matrix and the same bias are supplied twice. |
| [out, c, cache_out, cache_c, cache_ifog] = bilstm::forward(X, W_shared, W_shared, bias, bias, seq_length, num_features, seq_flag, out_init, c_init) |
| |
| ref_path = $7 + "_" + $1 +"_" + $2 +"_" + $3 +"_" + $4 + ".csv" |
| ref = read(ref_path, format="csv") |
| if (seq_flag == FALSE) { |
|   # Keep only two hidden_size-wide column slices of the reference: the first |
|   # half of the last 2*hidden_size block and the second half of the first |
|   # block (presumably last forward step / first backward step - TODO confirm). |
|   last_block = (seq_length-1)*hidden_size*2 |
|   ref_a = ref[, (last_block + 1):(last_block + hidden_size)] |
|   ref_b = ref[, (hidden_size + 1):(hidden_size*2)] |
|   ref = cbind(ref_a, ref_b) |
| } |
| |
| # Maximum absolute element-wise difference between reference and output. |
| max_err = max(abs(ref - out)) |
| # print(max_err) |
| write(max_err, $8, format="text") |