| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| source("scripts/nn/layers/lstm.dml") as lstm |
| |
| # Positional script arguments $1..$4: the integer problem dimensions. |
| batch_size   = as.integer($1) |
| seq_length   = as.integer($2) |
| num_features = as.integer($3) |
| hidden_size  = as.integer($4) |
| |
| # $5 and $6 are flags: each is parsed as an integer first, then turned into a boolean. |
| debug_flag = as.integer($5) |
| seq_flag   = as.integer($6) |
| debug = as.logical(debug_flag) |
| seq   = as.logical(seq_flag) |
| |
| # Library initialisation of W, b, out0 and c0. |
| # NOTE(review): all four results are replaced by explicit random matrices below. |
| [W, b, out0, c0] = lstm::init(batch_size, num_features, hidden_size) |
| |
| # Dimensions shared by the random matrices. |
| input_cols  = seq_length * num_features |
| weight_rows = num_features + hidden_size |
| gate_cols   = 4 * hidden_size |
| |
| # Input in [-2, 2]; parameters and initial states in [-1, 1]. |
| lstmIn = rand(rows=batch_size, cols=input_cols, min=-2, max=2, pdf="uniform") |
| W = rand(rows=weight_rows, cols=gate_cols, min=-1, max=1, pdf="uniform") |
| b = rand(rows=1, cols=gate_cols, min=-1, max=1, pdf="uniform") |
| out0 = rand(rows=batch_size, cols=hidden_size, min=-1, max=1, pdf="uniform") |
| c0 = rand(rows=batch_size, cols=hidden_size, min=-1, max=1, pdf="uniform") |
| |
| # dout and dc are handed to both backward calls. dout has hidden_size columns |
| # by default and hidden_size*seq_length columns when seq is set. |
| dout = rand(rows=batch_size, cols=hidden_size, min=-1, max=1, pdf="uniform") |
| if (seq) { |
|   dout = rand(rows=batch_size, cols=hidden_size*seq_length, min=-1, max=1, pdf="uniform") |
| } |
| dc = rand(rows=batch_size, cols=hidden_size, min=-1, max=1, pdf="uniform") |
| |
| #print(toString(b[1,1])) |
| #print(toString(W[1,1])) |
| #print(toString(lstmIn[1,1])) |
| #print(toString(out0[1,1])) |
| #print(toString(c0[1,1])) |
| |
| # Forward pass on identical inputs: first the built-in lstm, then the forward |
| # function of the sourced layer script. Each returns the caches consumed by |
| # its matching backward call below. |
| [out, c, cache_out, cache_c, cache_ifog] = lstm(lstmIn, W, b, out0, c0, seq) |
| [out2, c2, cache_out2, cache_c2, cache_ifog2] = lstm::forward(lstmIn, W, b, |
|     seq_length, num_features, seq, out0, c0) |
| |
| # Backward pass. t0, t1 and t2 bracket the two calls so their durations can be |
| # reported separately; the statement order here must not change. |
| t0 = time() |
| [dx, dw, db, dout0, dc0] = lstm_backward(lstmIn, W, b, out0, c0, seq, dout, dc, |
|     cache_out, cache_c, cache_ifog) |
| t1 = time() |
| [dx2, dw2, db2, dout02, dc02] = lstm::backward(dout, dc, lstmIn, W, b, |
|     seq_length, num_features, seq, out0, c0, cache_out2, cache_c2, cache_ifog2) |
| t2 = time() |
| |
| # With debug enabled, dump both forward outputs in full (built-in first). |
| if (debug) { |
|   builtin_out_str = toString(out) |
|   print(builtin_out_str) |
|   script_out_str = toString(out2) |
|   print(script_out_str) |
| } |
| |
| # Always print the top-left entry of each weight gradient (built-in, then script). |
| dw_first = dw[1,1] |
| print(toString(dw_first)) |
| dw2_first = dw2[1,1] |
| print(toString(dw2_first)) |
| #print(toString(dx[1,1])) |
| #print(toString(dx2[1,1])) |
| #print(toString(db[1,1])) |
| #print(toString(db2[1,1])) |
| #print(toString(dout0[1,1])) |
| #print(toString(dout02[1,1])) |
| #print(toString(dc0[1,1])) |
| #print(toString(dc02[1,1])) |
| |
| # Persist every gradient in text format to the paths given as $7..$16. |
| # Each pair is the built-in result followed by the DML-script result. |
| |
| # Gradient w.r.t. the input. |
| write(dx, $7, format="text") |
| write(dx2, $8, format="text") |
| |
| # Gradient w.r.t. the weights. |
| write(dw, $9, format="text") |
| write(dw2, $10, format="text") |
| |
| # Gradient w.r.t. the bias. |
| write(db, $11, format="text") |
| write(db2, $12, format="text") |
| |
| # Gradient w.r.t. the initial output state. |
| write(dout0, $13, format="text") |
| write(dout02, $14, format="text") |
| |
| # Gradient w.r.t. the initial cell state. |
| write(dc0, $15, format="text") |
| write(dc02, $16, format="text") |
| |
| # Report how long each backward call took. The time() differences are divided |
| # by T before being labelled "ms". |
| # NOTE(review): this presumes time() reports nanoseconds - confirm. |
| T = 1000000 |
| builtin_elapsed = (t1 - t0)/T |
| script_elapsed = (t2 - t1)/T |
| print("built-in took: " + builtin_elapsed + " ms") |
| print("dml-script took: " + script_elapsed + " ms") |