| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| source("scripts/nn/layers/bilstm.dml") as bilstm |
| source("scripts/nn/layers/lstm.dml") as lstm |
| |
| # Test driver: runs bilstm::forward and bilstm::backward on deterministic |
| # inputs and writes the maximum absolute deviation from three reference CSV |
| # files into a 1x3 matrix. |
| # |
| # Positional arguments: |
| #   $1  batch size |
| #   $2  sequence length |
| #   $3  number of input features |
| #   $4  hidden size |
| #   $5  debug flag (0/1); parsed here but not used further in this script |
| #   $6  seq flag (0/1); forwarded unchanged to the bilstm layer |
| #   $7  path prefix of the expected weight/bias gradients CSV |
| #   $8  path prefix of the expected initial-state gradients CSV |
| #   $9  path prefix of the expected input gradients CSV |
| #   $10 output path of the 1x3 error matrix |
| batch_size = as.integer($1) |
| seq_length = as.integer($2) |
| num_features = as.integer($3) |
| hidden_size = as.integer($4) |
| debug = as.logical(as.integer($5)) |
| seq_flag = as.logical(as.integer($6)) |
| |
| # All synthetic values are scaled by this constant. |
| scale = 0.01 |
| |
| # Frequently used sizes. |
| w_rows = num_features + hidden_size |
| last_cols = hidden_size*2 |
| seq_cols = hidden_size*seq_length*2 |
| |
| # Input: the values 0, 1, 2, ... scaled by `scale`, one row per batch entry. |
| in_cells = batch_size*seq_length*num_features |
| lstmIn = matrix(seq(0, in_cells - 1)*scale, rows=batch_size, cols=(seq_length*num_features)) |
| |
| # Weights: a scaled ramp shifted by a constant offset; the same matrix is |
| # passed for both weight arguments of the layer below. |
| w_ramp = seq(0, w_rows*hidden_size*4 - 1)*scale |
| w_ramp = w_ramp - w_rows*hidden_size*scale |
| W = matrix(w_ramp, rows=w_rows, cols=hidden_size*4) |
| b = matrix(1, rows=1, cols=4*hidden_size)*scale |
| |
| # Initial states: one block per pass, stacked row-wise (2*batch_size rows). |
| h_init = matrix(1, rows=batch_size, cols=hidden_size)*scale |
| c_init = matrix(0, rows=batch_size, cols=hidden_size)*scale |
| out0 = rbind(h_init, h_init) |
| c0 = rbind(c_init, c_init) |
| |
| [out, c, cache_out, cache_c, cache_ifog] = bilstm::forward(lstmIn, W, W, b, b, seq_length, num_features, seq_flag, out0, c0) |
| |
| # Upstream gradients. Specific batch sizes select specific gradient patterns. |
| dc = matrix(0, rows=batch_size*2, cols=hidden_size) |
| if(batch_size == 5){ |
|   # Ramp-valued dout with seq_length*hidden_size*2 columns. |
|   dout = matrix(seq(0, batch_size*seq_cols - 1), rows=batch_size, cols=seq_cols) |
| } else if(batch_size == 4) { |
|   # Zero dout, ramp-valued dc. |
|   dout = matrix(0, rows=batch_size, cols=last_cols) |
|   dc = matrix(seq(0, batch_size*hidden_size*2 - 1), rows=batch_size*2, cols=hidden_size) |
| } else if(batch_size == 3) { |
|   # NOTE(review): this branch assigns nothing, so dout is not defined here |
|   # before it is used by bilstm::backward below - confirm this is intended. |
| } else { |
|   # All-ones dout; its width depends on the seq flag. |
|   if(seq_flag){ |
|     dout = matrix(1, rows=batch_size, cols=seq_cols) |
|   } else { |
|     dout = matrix(1, rows=batch_size, cols=last_cols) |
|   } |
| } |
| |
| [dx, dw, db, dw_reverse, db_reverse, dout0, dc0] = bilstm::backward(dout, dc, lstmIn, W, W, b, b, seq_length, num_features, seq_flag, out0, c0, cache_out, cache_c, cache_ifog) |
| # Uncomment to inspect the computed gradients: |
| #print(toString(dx)) |
| #print(toString(dout0)) |
| #print(toString(dc0)) |
| #print(toString(dw)) |
| #print(toString(dw_reverse)) |
| #print(toString(db)) |
| #print(toString(db_reverse)) |
| |
| # One maximum absolute error per reference file. |
| max_errs = matrix(0, rows=1, cols=3) |
| |
| # (1) weight and bias gradients, stacked row-wise |
| ref_params = read($7 + "_" + $1 +"_" + $2 +"_" + $3 +"_" + $4 + ".csv", format="csv"); |
| got_params = rbind(dw, dw_reverse, db, db_reverse) |
| err_params = max(abs(ref_params - got_params)) |
| max_errs[1,1] = err_params |
| |
| # (2) gradients with respect to the initial states |
| ref_states = read($8 + "_" + $1 +"_" + $2 +"_" + $3 +"_" + $4 + ".csv", format="csv"); |
| got_states = rbind(dout0, dc0) |
| err_states = max(abs(ref_states - got_states)) |
| max_errs[1,2] = err_states |
| |
| # (3) gradient with respect to the input |
| ref_input = read($9 + "_" + $1 +"_" + $2 +"_" + $3 +"_" + $4 + ".csv", format="csv"); |
| err_input = max(abs(ref_input - dx)) |
| max_errs[1,3] = err_input |
| |
| write(max_errs, $10, format="text"); |
| |
| |