blob: 171508cfce430bbd0e5cf4659864857dc72c4005 [file]
#' Generate an RNN symbolic model - requires CUDA
#'
#' @param config Either seq-to-one or one-to-one
#' @param cell_type Type of RNN cell: either gru or lstm
#' @param num_rnn_layer int, number of stacked layers
#' @param num_hidden int, size of the state in each RNN layer
#' @param num_embed int, default = NULL - no embedding. Dimension of the embedding vectors
#' @param num_decode int, number of output variables in the decoding layer
#' @param input_size int, number of levels in the data - only used for embedding
#' @param dropout float, default = 0. Fraction of units to drop, applied within the fused RNN operator (passed as its `p` argument)
#'
#' @export
rnn.graph <- function (num_rnn_layer, input_size = NULL, num_embed = NULL,
                       num_hidden, num_decode, dropout = 0, ignore_label = -1,
                       bidirectional = FALSE, loss_output = NULL, config,
                       cell_type, masking = FALSE, output_last_state = FALSE,
                       rnn.state = NULL, rnn.state.cell = NULL, prefix = "") {

  # Graph inputs and learnable parameter symbols
  data <- mx.symbol.Variable("data")
  label <- mx.symbol.Variable("label")
  seq.mask <- mx.symbol.Variable("seq.mask")
  if (!is.null(num_embed)) {
    embed.weight <- mx.symbol.Variable("embed.weight")
  }
  rnn.params.weight <- mx.symbol.Variable("rnn.params.weight")
  if (is.null(rnn.state)) {
    rnn.state <- mx.symbol.Variable("rnn.state")
  }
  # LSTM carries an extra cell state alongside the hidden state
  if (cell_type == "lstm" && is.null(rnn.state.cell)) {
    rnn.state.cell <- mx.symbol.Variable("rnn.state.cell")
  }
  cls.weight <- mx.symbol.Variable("cls.weight")
  cls.bias <- mx.symbol.Variable("cls.bias")

  # Optional embedding lookup for discrete (integer-encoded) inputs
  if (!is.null(num_embed)) {
    data <- mx.symbol.Embedding(data = data, input_dim = input_size,
                                weight = embed.weight, output_dim = num_embed,
                                name = "embed")
  }

  # The fused RNN operator consumes time-major data: swap batch and time axes
  data <- mx.symbol.swapaxes(data = data, dim1 = 0, dim2 = 1,
                             name = paste0(prefix, "swap_pre"))

  if (cell_type == "lstm") {
    # LSTM mode requires the additional state_cell input
    rnn <- mx.symbol.RNN(data = data, state = rnn.state,
                         state_cell = rnn.state.cell, parameters = rnn.params.weight,
                         state.size = num_hidden, num.layers = num_rnn_layer,
                         bidirectional = bidirectional, mode = cell_type,
                         state.outputs = output_last_state,
                         p = dropout, name = paste0(prefix, "RNN"))
  } else {
    rnn <- mx.symbol.RNN(data = data, state = rnn.state,
                         parameters = rnn.params.weight, state.size = num_hidden,
                         num.layers = num_rnn_layer, bidirectional = bidirectional,
                         mode = cell_type, state.outputs = output_last_state,
                         p = dropout, name = paste0(prefix, "RNN"))
  }

  if (config == "seq-to-one") {
    # Keep only the last (valid, when masking) time step of the RNN output
    if (masking) {
      mask <- mx.symbol.SequenceLast(data = rnn[[1]], use.sequence.length = TRUE,
                                     sequence_length = seq.mask, name = "mask")
    } else {
      mask <- mx.symbol.SequenceLast(data = rnn[[1]], use.sequence.length = FALSE,
                                     name = "mask")
    }
    if (!is.null(loss_output)) {
      decode <- mx.symbol.FullyConnected(data = mask, weight = cls.weight,
                                         bias = cls.bias, num_hidden = num_decode,
                                         name = "decode")
      out <- switch(loss_output,
                    softmax = mx.symbol.SoftmaxOutput(data = decode, label = label,
                                                      use_ignore = ignore_label != -1,
                                                      ignore_label = ignore_label,
                                                      name = "loss"),
                    linear = mx.symbol.LinearRegressionOutput(data = decode, label = label, name = "loss"),
                    logistic = mx.symbol.LogisticRegressionOutput(data = decode, label = label, name = "loss"),
                    MAE = mx.symbol.MAERegressionOutput(data = decode, label = label, name = "loss"))
    } else {
      out <- mask
    }
  } else if (config == "one-to-one") {
    # Zero out padded steps (masking) or pass the outputs through unchanged
    if (masking) {
      mask <- mx.symbol.SequenceMask(data = rnn[[1]], use.sequence.length = TRUE,
                                     sequence_length = seq.mask, value = 0, name = "mask")
    } else {
      mask <- mx.symbol.identity(data = rnn[[1]], name = "mask")
    }
    # Back to batch-major before decoding
    mask <- mx.symbol.swapaxes(data = mask, dim1 = 0, dim2 = 1,
                               name = paste0(prefix, "swap_post"))
    if (!is.null(loss_output)) {
      # Merge batch and time dims so the decoder is applied per time step
      # (shapes inferred by MXNet via the reverse reshape)
      mask <- mx.symbol.reshape(data = mask, shape = c(0, -1), reverse = TRUE)
      label <- mx.symbol.reshape(data = label, shape = c(-1))
      decode <- mx.symbol.FullyConnected(data = mask, weight = cls.weight,
                                         bias = cls.bias, num_hidden = num_decode,
                                         flatten = TRUE, name = paste0(prefix, "decode"))
      out <- switch(loss_output,
                    softmax = mx.symbol.SoftmaxOutput(data = decode, label = label,
                                                      use_ignore = ignore_label != -1,
                                                      ignore_label = ignore_label,
                                                      name = "loss"),
                    linear = mx.symbol.LinearRegressionOutput(data = decode, label = label, name = "loss"),
                    logistic = mx.symbol.LogisticRegressionOutput(data = decode, label = label, name = "loss"),
                    MAE = mx.symbol.MAERegressionOutput(data = decode, label = label, name = "loss"))
    } else {
      out <- mask
    }
  }
  return(out)
}
# LSTM cell symbol
lstm.cell <- function(num_hidden, indata, prev.state, param, seqidx, layeridx, dropout = 0, prefix = "") {
  # Build one unrolled LSTM time step. Returns list(h = ..., c = ...) with the
  # next hidden and cell state symbols. `prev.state` may be NULL on the first
  # step, in which case the previous state is treated as zero.
  nm <- function(op) paste0(prefix, "t", seqidx, ".l", layeridx, ".", op)

  # Dropout on the input of every layer except the first
  if (dropout > 0 && layeridx > 1) {
    indata <- mx.symbol.Dropout(data = indata, p = dropout)
  }

  # Input and (optionally) hidden projections produce the 4 stacked gate
  # pre-activations in one fused FullyConnected each
  i2h <- mx.symbol.FullyConnected(data = indata, weight = param$i2h.weight,
                                  bias = param$i2h.bias,
                                  num_hidden = num_hidden * 4,
                                  name = nm("i2h"))
  gates <- if (is.null(prev.state)) {
    i2h
  } else {
    i2h + mx.symbol.FullyConnected(data = prev.state$h, weight = param$h2h.weight,
                                   bias = param$h2h.bias,
                                   num_hidden = num_hidden * 4,
                                   name = nm("h2h"))
  }

  gate.parts <- mx.symbol.split(gates, num.outputs = 4, axis = 1,
                                squeeze.axis = FALSE, name = nm("slice"))
  in.gate <- mx.symbol.Activation(gate.parts[[1]], act.type = "sigmoid")
  in.transform <- mx.symbol.Activation(gate.parts[[2]], act.type = "tanh")
  forget.gate <- mx.symbol.Activation(gate.parts[[3]], act.type = "sigmoid")
  out.gate <- mx.symbol.Activation(gate.parts[[4]], act.type = "sigmoid")

  # c_t = f * c_{t-1} + i * g  (forget term drops out when there is no history)
  next.c <- if (is.null(prev.state)) {
    in.gate * in.transform
  } else {
    (forget.gate * prev.state$c) + (in.gate * in.transform)
  }
  # h_t = o * tanh(c_t)
  next.h <- out.gate * mx.symbol.Activation(next.c, act.type = "tanh")

  list(h = next.h, c = next.c)
}
# GRU cell symbol
gru.cell <- function(num_hidden, indata, prev.state, param, seqidx, layeridx, dropout = 0, prefix) {
  # Build one unrolled GRU time step. Returns list(h = ...) with the next
  # hidden state symbol. `prev.state` may be NULL on the first step, in which
  # case the previous state is treated as zero.
  nm <- function(op) paste0(prefix, "t", seqidx, ".l", layeridx, ".", op)

  # Dropout on the input of every layer except the first
  if (dropout > 0 && layeridx > 1) {
    indata <- mx.symbol.Dropout(data = indata, p = dropout)
  }

  # Stacked pre-activations for the update and reset gates
  i2h <- mx.symbol.FullyConnected(data = indata, weight = param$gates.i2h.weight,
                                  bias = param$gates.i2h.bias,
                                  num_hidden = num_hidden * 2,
                                  name = nm("gates.i2h"))
  gates <- if (is.null(prev.state)) {
    i2h
  } else {
    i2h + mx.symbol.FullyConnected(data = prev.state$h,
                                   weight = param$gates.h2h.weight,
                                   bias = param$gates.h2h.bias,
                                   num_hidden = num_hidden * 2,
                                   name = nm("gates.h2h"))
  }

  gate.parts <- mx.symbol.split(gates, num.outputs = 2, axis = 1,
                                squeeze.axis = FALSE, name = nm("split"))
  update.gate <- mx.symbol.Activation(gate.parts[[1]], act.type = "sigmoid")
  reset.gate <- mx.symbol.Activation(gate.parts[[2]], act.type = "sigmoid")

  # Candidate state: input projection plus reset-gated previous state
  htrans.i2h <- mx.symbol.FullyConnected(data = indata,
                                         weight = param$trans.i2h.weight,
                                         bias = param$trans.i2h.bias,
                                         num_hidden = num_hidden,
                                         name = nm("trans.i2h"))
  # With no history, multiply by 0 to get a correctly-shaped zero symbol
  h.after.reset <- if (is.null(prev.state)) {
    reset.gate * 0
  } else {
    prev.state$h * reset.gate
  }
  htrans.h2h <- mx.symbol.FullyConnected(data = h.after.reset,
                                         weight = param$trans.h2h.weight,
                                         bias = param$trans.h2h.bias,
                                         num_hidden = num_hidden,
                                         name = nm("trans.h2h"))
  h.trans.active <- mx.symbol.Activation(htrans.i2h + htrans.h2h, act.type = "tanh")

  # h_t = h_{t-1} + z * (candidate - h_{t-1}): interpolate via the update gate
  next.h <- if (is.null(prev.state)) {
    update.gate * h.trans.active
  } else {
    prev.state$h + update.gate * (h.trans.active - prev.state$h)
  }

  list(h = next.h)
}
#' unroll representation of RNN running on non CUDA device
#'
#' @param config Either seq-to-one or one-to-one
#' @param cell_type Type of RNN cell: either gru or lstm
#' @param num_rnn_layer int, number of stacked layers
#' @param seq_len int, number of time steps to unroll
#' @param num_hidden int, size of the state in each RNN layer
#' @param num_embed int, default = NULL - no embedding. Dimension of the embedding vectors
#' @param num_decode int, number of output variables in the decoding layer
#' @param input_size int, number of levels in the data - only used for embedding
#' @param dropout float, default = 0. Fraction of units to drop on the inputs of all stacked layers except the first
#'
#' @export
rnn.graph.unroll <- function(num_rnn_layer,
                             seq_len,
                             input_size = NULL,
                             num_embed = NULL,
                             num_hidden,
                             num_decode,
                             dropout = 0,
                             ignore_label = -1,
                             loss_output = NULL,
                             init.state = NULL,
                             config,
                             cell_type = "lstm",
                             masking = FALSE,
                             output_last_state = FALSE,
                             prefix = "",
                             data_name = "data",
                             label_name = "label") {

  if (!is.null(num_embed)) {
    embed.weight <- mx.symbol.Variable(paste0(prefix, "embed.weight"))
  }

  # Initial state: only materialized as graph inputs when last-state outputs
  # are requested; otherwise cells receive a NULL previous state on step 1
  # (NULL[[i]] is NULL in R) and treat it as zero.
  if (is.null(init.state) && output_last_state) {
    init.state <- lapply(1:num_rnn_layer, function(i) {
      if (cell_type == "lstm") {
        list(h = mx.symbol.Variable(paste0("init_", prefix, i, "_h")),
             c = mx.symbol.Variable(paste0("init_", prefix, i, "_c")))
      } else if (cell_type == "gru") {
        list(h = mx.symbol.Variable(paste0("init_", prefix, i, "_h")))
      }
    })
  }

  cls.weight <- mx.symbol.Variable(paste0(prefix, "cls.weight"))
  cls.bias <- mx.symbol.Variable(paste0(prefix, "cls.bias"))

  # Per-layer parameter symbols, shared across every unrolled time step
  param.cells <- lapply(1:num_rnn_layer, function(i) {
    if (cell_type == "lstm") {
      list(i2h.weight = mx.symbol.Variable(paste0(prefix, "l", i, ".i2h.weight")),
           i2h.bias = mx.symbol.Variable(paste0(prefix, "l", i, ".i2h.bias")),
           h2h.weight = mx.symbol.Variable(paste0(prefix, "l", i, ".h2h.weight")),
           h2h.bias = mx.symbol.Variable(paste0(prefix, "l", i, ".h2h.bias")))
    } else if (cell_type == "gru") {
      list(gates.i2h.weight = mx.symbol.Variable(paste0(prefix, "l", i, ".gates.i2h.weight")),
           gates.i2h.bias = mx.symbol.Variable(paste0(prefix, "l", i, ".gates.i2h.bias")),
           gates.h2h.weight = mx.symbol.Variable(paste0(prefix, "l", i, ".gates.h2h.weight")),
           gates.h2h.bias = mx.symbol.Variable(paste0(prefix, "l", i, ".gates.h2h.bias")),
           trans.i2h.weight = mx.symbol.Variable(paste0(prefix, "l", i, ".trans.i2h.weight")),
           trans.i2h.bias = mx.symbol.Variable(paste0(prefix, "l", i, ".trans.i2h.bias")),
           trans.h2h.weight = mx.symbol.Variable(paste0(prefix, "l", i, ".trans.h2h.weight")),
           trans.h2h.bias = mx.symbol.Variable(paste0(prefix, "l", i, ".trans.h2h.bias")))
    }
  })

  # Graph inputs
  data <- mx.symbol.Variable(data_name)
  label <- mx.symbol.Variable(label_name)
  seq.mask <- mx.symbol.Variable(paste0(prefix, "seq.mask"))

  # Time-major layout, optional embedding, then one slice per time step
  data <- mx.symbol.swapaxes(data = data, dim1 = 0, dim2 = 1,
                             name = paste0(prefix, "swap_pre"))
  if (!is.null(num_embed)) {
    data <- mx.symbol.Embedding(data = data, input_dim = input_size,
                                weight = embed.weight, output_dim = num_embed,
                                name = paste0(prefix, "embed"))
  }
  data <- mx.symbol.split(data = data, axis = 0, num.outputs = seq_len,
                          squeeze_axis = TRUE)

  # Select the step builder once; invariant over both loops below
  cell.symbol <- switch(cell_type, lstm = lstm.cell, gru = gru.cell)

  last.hidden <- list()
  last.states <- list()
  for (seqidx in 1:seq_len) {
    hidden <- data[[seqidx]]
    # Stack layers bottom-up; each layer consumes the output of the one below
    for (i in 1:num_rnn_layer) {
      prev.state <- if (seqidx == 1) init.state[[i]] else last.states[[i]]
      next.state <- cell.symbol(num_hidden = num_hidden,
                                indata = hidden,
                                prev.state = prev.state,
                                param = param.cells[[i]],
                                seqidx = seqidx,
                                layeridx = i,
                                dropout = dropout,
                                prefix = prefix)
      hidden <- next.state$h
      last.states[[i]] <- next.state
    }
    # Aggregate outputs from each timestep
    last.hidden <- c(last.hidden, hidden)
  }

  if (output_last_state) {
    out.states <- mx.symbol.Group(unlist(last.states))
  }

  # Concat hidden units - concat seq_len blocks of dimension num_hidden x batch.size
  concat <- mx.symbol.concat(data = last.hidden, num.args = seq_len, dim = 0,
                             name = paste0(prefix, "concat"))
  concat <- mx.symbol.reshape(data = concat, shape = c(num_hidden, -1, seq_len),
                              name = paste0(prefix, "rnn_reshape"))

  if (config == "seq-to-one") {
    # Keep only the last (valid, when masking) time step
    if (masking) {
      mask <- mx.symbol.SequenceLast(data = concat, use.sequence.length = TRUE,
                                     sequence_length = seq.mask,
                                     name = paste0(prefix, "mask"))
    } else {
      mask <- mx.symbol.SequenceLast(data = concat, use.sequence.length = FALSE,
                                     name = paste0(prefix, "mask"))
    }
    if (!is.null(loss_output)) {
      decode <- mx.symbol.FullyConnected(data = mask,
                                         weight = cls.weight,
                                         bias = cls.bias,
                                         num_hidden = num_decode,
                                         name = paste0(prefix, "decode"))
      # Bug fix: logistic and MAE previously passed paste0(prefix, name = "loss")
      # as a positional argument instead of naming the output symbol.
      out <- switch(loss_output,
                    softmax = mx.symbol.SoftmaxOutput(data = decode, label = label,
                                                      use_ignore = ignore_label != -1,
                                                      ignore_label = ignore_label,
                                                      name = paste0(prefix, "loss")),
                    linear = mx.symbol.LinearRegressionOutput(data = decode, label = label,
                                                              name = paste0(prefix, "loss")),
                    logistic = mx.symbol.LogisticRegressionOutput(data = decode, label = label,
                                                                  name = paste0(prefix, "loss")),
                    MAE = mx.symbol.MAERegressionOutput(data = decode, label = label,
                                                        name = paste0(prefix, "loss")))
    } else {
      out <- mask
    }
  } else if (config == "one-to-one") {
    # Zero out padded steps (masking) or pass the outputs through unchanged
    if (masking) {
      mask <- mx.symbol.SequenceMask(data = concat, use.sequence.length = TRUE,
                                     sequence_length = seq.mask, value = 0,
                                     name = paste0(prefix, "mask"))
    } else {
      mask <- mx.symbol.identity(data = concat, name = paste0(prefix, "mask"))
    }
    # Back to batch-major before decoding
    mask <- mx.symbol.swapaxes(data = mask, dim1 = 0, dim2 = 1,
                               name = paste0(prefix, "swap_post"))
    if (!is.null(loss_output)) {
      # Merge batch and time dims so the decoder is applied per time step
      mask <- mx.symbol.reshape(data = mask, shape = c(0, -1), reverse = TRUE)
      label <- mx.symbol.reshape(data = label, shape = c(-1))
      decode <- mx.symbol.FullyConnected(data = mask, weight = cls.weight,
                                         bias = cls.bias, num_hidden = num_decode,
                                         flatten = TRUE, name = paste0(prefix, "decode"))
      out <- switch(loss_output,
                    softmax = mx.symbol.SoftmaxOutput(data = decode, label = label,
                                                      use_ignore = ignore_label != -1,
                                                      ignore_label = ignore_label,
                                                      name = paste0(prefix, "loss")),
                    linear = mx.symbol.LinearRegressionOutput(data = decode, label = label,
                                                              name = paste0(prefix, "loss")),
                    logistic = mx.symbol.LogisticRegressionOutput(data = decode, label = label,
                                                                  name = paste0(prefix, "loss")),
                    MAE = mx.symbol.MAERegressionOutput(data = decode, label = label,
                                                        name = paste0(prefix, "loss")))
    } else {
      out <- mask
    }
  }

  # When last states are requested, group them with the main output
  if (output_last_state) {
    return(mx.symbol.Group(c(out, out.states)))
  } else {
    return(out)
  }
}