blob: e6e819b9977d5b2ced3ab2e665812f3eb242dcb4 [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
source("scripts/nn/layers/lstm.dml") as lstm
# Positional script arguments:
#   $1 batch_size    - number of rows of every input matrix
#   $2 seq_length    - number of timesteps packed into each input row
#   $3 num_features  - number of features per timestep
#   $4 hidden_size   - width of the hidden/cell state
#   $5 debug         - 0/1; when 1 the outputs of both LSTM calls are printed
#   $6 seq           - 0/1; forwarded unchanged to both LSTM calls
#                      (presumably selects returning all timesteps - confirm in lstm.dml)
batch_size = as.integer($1)
seq_length = as.integer($2)
num_features = as.integer($3)
hidden_size = as.integer($4)
debug = as.logical(as.integer($5))
seq = as.logical(as.integer($6))

# All inputs are drawn directly with rand(); no layer initializer is called
# because every parameter and state matrix is assigned here before first use.
# Input: one row per example, seq_length*num_features columns, values in [-2, 2].
lstmIn = rand(rows=batch_size, cols=seq_length*num_features, min=-2, max=2, pdf="uniform")
# Weights: (num_features + hidden_size) x (4*hidden_size), values in [-1, 1].
W = rand(rows=num_features + hidden_size, cols=hidden_size*4, min=-1, max=1, pdf="uniform")
# Bias: 1 x (4*hidden_size), values in [-1, 1].
b = rand(rows=1, cols=4*hidden_size, min=-1, max=1, pdf="uniform")
# Initial output and cell state: batch_size x hidden_size each, values in [-1, 1].
out0 = rand(rows=batch_size, cols=hidden_size, min=-1, max=1, pdf="uniform")
c0 = rand(rows=batch_size, cols=hidden_size, min=-1, max=1, pdf="uniform")

# Uncomment to spot-check the first entry of each generated matrix.
#print(toString(b[1,1]))
#print(toString(W[1,1]))
#print(toString(lstmIn[1,1]))
#print(toString(out0[1,1]))
#print(toString(c0[1,1]))
# --- Built-in lstm -----------------------------------------------------------
# Timestamps are taken immediately around the call so that only the call itself
# is measured.
builtin_start = time()
[out, c, cache_out, cache_c, cache_ifog] = lstm(lstmIn, W, b, out0, c0, seq)
builtin_end = time()
if(debug){
  print(toString(out))
}

# --- DML-scripted lstm::forward ----------------------------------------------
# A fresh start timestamp is taken here, after the optional debug print above,
# so that printing the built-in's output is not charged to the DML-script run.
script_start = time()
[out2, c2, cache_out2, cache_c2, cache_ifog2] = lstm::forward(lstmIn, W,b,seq_length,num_features,seq,out0, c0)
script_end = time()
if(debug){
  print(toString(out2))
}

# Persist the caches of both implementations to the paths given as $7..$12:
# odd arguments receive the built-in's cache, even arguments the DML script's.
write(cache_out, $7, format="text");
write(cache_out2, $8, format="text");
write(cache_c, $9, format="text");
write(cache_c2, $10, format="text");
write(cache_ifog, $11, format="text");
write(cache_ifog2, $12, format="text");

# Divisor used to report elapsed time as milliseconds
# (assumes time() reports nanoseconds - confirm against the DML reference).
T = 1000000
print("built-in took: " + (builtin_end - builtin_start)/T + " ms")
print("dml-script took: " + (script_end - script_start)/T + " ms")