# blob: 3645fad53ae96f4e715b48eeb8e0ae97e6e83830 [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# Compares the built-in lstm / lstm_backward operators against the nn-library
# DML implementation (lstm::forward / lstm::backward) on identical random inputs,
# writes the backward-pass results of both to text files, and prints the
# backward-pass timings.
source("scripts/nn/layers/lstm.dml") as lstm
# Positional arguments:
#   $1 batch_size, $2 seq_length, $3 num_features, $4 hidden_size (integers)
#   $5 debug: integer flag converted with as.logical; when true the forward
#      outputs of both implementations are printed
#   $6 seq: integer flag converted with as.logical; passed to both forward and
#      backward passes (presumably "return sequences": dout gets
#      hidden_size*seq_length columns when it is set)
#   $7..$16: output paths for the gradient matrices (see the write() calls)
batch_size = as.integer($1)
seq_length = as.integer($2)
num_features = as.integer($3)
hidden_size = as.integer($4)
debug = as.logical(as.integer($5))
seq = as.logical(as.integer($6))
# Random input, parameters, initial states and upstream gradients shared by the
# built-in and nn-library implementations. All values are drawn directly with
# rand(); no separate initializer is needed because every matrix is set here.
#
# lstmIn: batch_size x (seq_length*num_features), presumably seq_length
#         timesteps of num_features features each (both are passed to
#         lstm::forward alongside it).
lstmIn = rand(rows=batch_size, cols=seq_length*num_features, min=-2, max=2, pdf="uniform")
# W: (num_features + hidden_size) x (4*hidden_size); b: 1 x (4*hidden_size).
W = rand(rows=num_features + hidden_size, cols=hidden_size*4, min=-1, max=1, pdf="uniform")
b = rand(rows=1, cols=4*hidden_size, min=-1, max=1, pdf="uniform")
# Initial output and cell state, batch_size x hidden_size each.
out0 = rand(rows=batch_size, cols=hidden_size, min=-1, max=1, pdf="uniform")
c0 = rand(rows=batch_size, cols=hidden_size, min=-1, max=1, pdf="uniform")
# Upstream gradient for the forward output: hidden_size*seq_length columns when
# seq is set, otherwise hidden_size. The width is chosen first so that only one
# matrix is generated.
dout_cols = hidden_size
if(seq){
  dout_cols = hidden_size*seq_length
}
dout = rand(rows=batch_size, cols=dout_cols, min=-1, max=1, pdf="uniform")
# Upstream gradient fed to the backward passes as dc (batch_size x hidden_size).
dc = rand(rows=batch_size, cols=hidden_size, min=-1, max=1, pdf="uniform")
#print(toString(b[1,1]))
#print(toString(W[1,1]))
#print(toString(lstmIn[1,1]))
#print(toString(out0[1,1]))
#print(toString(c0[1,1]))
# Forward pass with both implementations on identical inputs: the built-in
# lstm operator and the nn-library lstm::forward. Each returns its own caches,
# which are passed only to the matching backward pass below.
[out, c, cache_out, cache_c, cache_ifog] = lstm(lstmIn, W, b, out0, c0, seq)
[out2, c2, cache_out2, cache_c2, cache_ifog2] = lstm::forward(lstmIn, W,b,seq_length,num_features,seq,out0, c0)
# Only the backward passes are timed: t0..t1 brackets the built-in
# lstm_backward, t1..t2 brackets lstm::backward. The differences are divided by
# 1000000 and reported in ms at the end of the script.
t0 = time()
[dx, dw, db, dout0, dc0] = lstm_backward(lstmIn, W, b, out0, c0, seq, dout, dc, cache_out, cache_c, cache_ifog)
t1 = time()
[dx2, dw2, db2, dout02, dc02] = lstm::backward(dout, dc, lstmIn, W,b,seq_length,num_features,seq,out0, c0,cache_out2, cache_c2, cache_ifog2)
t2 = time()
# In debug mode, print the full forward outputs of both implementations.
if(debug){
print(toString(out))
print(toString(out2))
}
# Always print one weight-gradient entry from each implementation as a quick
# spot check.
print(toString(dw[1,1]))
print(toString(dw2[1,1]))
#print(toString(dx[1,1]))
#print(toString(dx2[1,1]))
#print(toString(db[1,1]))
#print(toString(db2[1,1]))
#print(toString(dout0[1,1]))
#print(toString(dout02[1,1]))
#print(toString(dc0[1,1]))
#print(toString(dc02[1,1]))
# Write each backward result as a pair (built-in, then nn-library) in text
# format so the caller can compare them: dx ($7/$8), dw ($9/$10),
# db ($11/$12), dout0 ($13/$14), dc0 ($15/$16).
write(dx, $7, format="text");
write(dx2, $8, format="text");
write(dw, $9, format="text");
write(dw2, $10, format="text");
write(db, $11, format="text");
write(db2, $12, format="text");
write(dout0, $13, format="text");
write(dout02, $14, format="text");
write(dc0, $15, format="text");
write(dc02, $16, format="text");
# Report the backward-pass timings in milliseconds. The raw time() differences
# are divided by 1000000, i.e. they are treated as nanoseconds.
ns_per_ms = 1000000
builtin_ms = (t1 - t0) / ns_per_ms
script_ms = (t2 - t1) / ns_per_ms
print("built-in took: " + builtin_ms + " ms")
print("dml-script took: " + script_ms + " ms")