blob: 622388993c8ca3415ff44a91198bf9f1a45592cb [file] [log] [blame]
# LSTM cell symbol: builds one step of an LSTM at time `seqidx` in layer
# `layeridx` and returns list(c=next cell state, h=next hidden state).
lstm <- function(num.hidden, indata, prev.state, param, seqidx, layeridx,
                 dropout = 0) {
  if (dropout > 0) {
    indata <- mx.symbol.Dropout(data = indata, p = dropout)
  }
  # All four gate pre-activations are produced stacked together, hence the
  # num.hidden * 4 output width on both projections.
  name.prefix <- paste0("t", seqidx, ".l", layeridx)
  input.proj <- mx.symbol.FullyConnected(
    data = indata,
    weight = param$i2h.weight,
    bias = param$i2h.bias,
    num.hidden = num.hidden * 4,
    name = paste0(name.prefix, ".i2h"))
  hidden.proj <- mx.symbol.FullyConnected(
    data = prev.state$h,
    weight = param$h2h.weight,
    bias = param$h2h.bias,
    num.hidden = num.hidden * 4,
    name = paste0(name.prefix, ".h2h"))
  gate.sum <- input.proj + hidden.proj
  gate.slices <- mx.symbol.SliceChannel(
    gate.sum, num.outputs = 4,
    name = paste0(name.prefix, ".slice"))
  # Slice order: input gate, candidate transform, forget gate, output gate.
  gate.in <- mx.symbol.Activation(gate.slices[[1]], act.type = "sigmoid")
  candidate <- mx.symbol.Activation(gate.slices[[2]], act.type = "tanh")
  gate.forget <- mx.symbol.Activation(gate.slices[[3]], act.type = "sigmoid")
  gate.out <- mx.symbol.Activation(gate.slices[[4]], act.type = "sigmoid")
  state.c <- (gate.forget * prev.state$c) + (gate.in * candidate)
  state.h <- gate.out * mx.symbol.Activation(state.c, act.type = "tanh")
  list(c = state.c, h = state.h)
}
# Unrolled LSTM network symbol for training: embeds the input tokens, runs
# `num.lstm.layer` stacked LSTM layers over `seq.len` time steps, then
# applies a shared softmax decoder over the concatenated hidden states.
# Returns the SoftmaxOutput symbol named "sm".
lstm.unroll <- function(num.lstm.layer, seq.len, input.size,
                        num.hidden, num.embed, num.label, dropout = 0.) {
  embed.weight <- mx.symbol.Variable("embed.weight")
  cls.weight <- mx.symbol.Variable("cls.weight")
  cls.bias <- mx.symbol.Variable("cls.bias")
  # Per-layer LSTM parameters; shared across all time steps.
  param.cells <- lapply(seq_len(num.lstm.layer), function(i) {
    list(i2h.weight = mx.symbol.Variable(paste0("l", i, ".i2h.weight")),
         i2h.bias = mx.symbol.Variable(paste0("l", i, ".i2h.bias")),
         h2h.weight = mx.symbol.Variable(paste0("l", i, ".h2h.weight")),
         h2h.bias = mx.symbol.Variable(paste0("l", i, ".h2h.bias")))
  })
  # Initial cell/hidden state variables for each layer.
  last.states <- lapply(seq_len(num.lstm.layer), function(i) {
    list(c = mx.symbol.Variable(paste0("l", i, ".init.c")),
         h = mx.symbol.Variable(paste0("l", i, ".init.h")))
  })
  # Embedding layer, sliced into one symbol per time step.
  label <- mx.symbol.Variable("label")
  data <- mx.symbol.Variable("data")
  embed <- mx.symbol.Embedding(data = data, input_dim = input.size,
                               weight = embed.weight, output_dim = num.embed,
                               name = "embed")
  wordvec <- mx.symbol.SliceChannel(data = embed, num_outputs = seq.len,
                                    squeeze_axis = 1)
  # Preallocate rather than growing the list one element per step.
  last.hidden <- vector("list", seq.len)
  for (seqidx in seq_len(seq.len)) {
    hidden <- wordvec[[seqidx]]
    # Stacked LSTM; no dropout on the first layer's input (the embedding).
    for (i in seq_len(num.lstm.layer)) {
      dp <- if (i == 1) 0 else dropout  # scalar choice: if/else, not ifelse()
      next.state <- lstm(num.hidden, indata = hidden,
                         prev.state = last.states[[i]],
                         param = param.cells[[i]],
                         seqidx = seqidx, layeridx = i,
                         dropout = dp)
      hidden <- next.state$h
      last.states[[i]] <- next.state
    }
    # Decoder input: optionally dropped-out top-layer hidden state.
    if (dropout > 0) {
      hidden <- mx.symbol.Dropout(data = hidden, p = dropout)
    }
    last.hidden[[seqidx]] <- hidden
  }
  # mx.varg.symbol.Concat takes its options inside the argument list.
  last.hidden$dim <- 0
  last.hidden$num.args <- seq.len
  concat <- mxnet:::mx.varg.symbol.Concat(last.hidden)
  fc <- mx.symbol.FullyConnected(data = concat,
                                 weight = cls.weight,
                                 bias = cls.bias,
                                 num.hidden = num.label)
  # Transpose + reshape flattens the label so it lines up with the
  # concatenated per-step predictions.
  label <- mx.symbol.transpose(data = label)
  label <- mx.symbol.Reshape(data = label, target.shape = c(0))
  mx.symbol.SoftmaxOutput(data = fc, label = label, name = "sm")
}
# LSTM inference model symbol: a single time step whose final per-layer
# states are exposed as additional outputs (via BlockGrad) so the caller can
# feed them back in as the next step's init states.
lstm.inference.symbol <- function(num.lstm.layer, input.size,
                                  num.hidden, num.embed, num.label,
                                  dropout = 0.) {
  seqidx <- 0
  embed.weight <- mx.symbol.Variable("embed.weight")
  cls.weight <- mx.symbol.Variable("cls.weight")
  cls.bias <- mx.symbol.Variable("cls.bias")
  # Per-layer LSTM parameters.
  param.cells <- lapply(seq_len(num.lstm.layer), function(i) {
    list(i2h.weight = mx.symbol.Variable(paste0("l", i, ".i2h.weight")),
         i2h.bias = mx.symbol.Variable(paste0("l", i, ".i2h.bias")),
         h2h.weight = mx.symbol.Variable(paste0("l", i, ".h2h.weight")),
         h2h.bias = mx.symbol.Variable(paste0("l", i, ".h2h.bias")))
  })
  # Initial cell/hidden state variables for each layer.
  last.states <- lapply(seq_len(num.lstm.layer), function(i) {
    list(c = mx.symbol.Variable(paste0("l", i, ".init.c")),
         h = mx.symbol.Variable(paste0("l", i, ".init.h")))
  })
  # Embedding layer.
  data <- mx.symbol.Variable("data")
  hidden <- mx.symbol.Embedding(data = data, input_dim = input.size,
                                weight = embed.weight, output_dim = num.embed,
                                name = "embed")
  # Stacked LSTM; no dropout on the first layer's input.
  for (i in seq_len(num.lstm.layer)) {
    dp <- if (i == 1) 0 else dropout  # scalar choice: if/else, not ifelse()
    next.state <- lstm(num.hidden, indata = hidden,
                       prev.state = last.states[[i]],
                       param = param.cells[[i]],
                       seqidx = seqidx, layeridx = i,
                       dropout = dp)
    hidden <- next.state$h
    last.states[[i]] <- next.state
  }
  # Decoder.
  if (dropout > 0) {
    hidden <- mx.symbol.Dropout(data = hidden, p = dropout)
  }
  fc <- mx.symbol.FullyConnected(data = hidden, num_hidden = num.label,
                                 weight = cls.weight, bias = cls.bias,
                                 name = "pred")
  sm <- mx.symbol.SoftmaxOutput(data = fc, name = "sm")
  # BlockGrad both exposes each final state as an output and stops
  # gradients from flowing through it.
  unpack.c <- lapply(seq_len(num.lstm.layer), function(i) {
    mx.symbol.BlockGrad(last.states[[i]]$c, name = paste0("l", i, ".last.c"))
  })
  unpack.h <- lapply(seq_len(num.lstm.layer), function(i) {
    mx.symbol.BlockGrad(last.states[[i]]$h, name = paste0("l", i, ".last.h"))
  })
  mx.symbol.Group(c(sm, unpack.c, unpack.h))
}
#' Training LSTM Unrolled Model
#'
#' @param train.data mx.io.DataIter or list(data=R.array, label=R.array)
#' The Training set.
#' @param eval.data mx.io.DataIter or list(data=R.array, label=R.array), optional
#' The validation set used for validation evaluation during the progress.
#' @param num.lstm.layer integer
#' The number of the layer of lstm.
#' @param seq.len integer
#' The length of the input sequence.
#' @param num.hidden integer
#' The number of hidden nodes.
#' @param num.embed integer
#' The output dim of embedding.
#' @param num.label integer
#' The number of labels.
#' @param batch.size integer
#' The batch size used for R array training.
#' @param input.size integer
#' The input dim of one-hot encoding of embedding
#' @param ctx mx.context, optional
#' The device used to perform training.
#' @param num.round integer, default=10
#' The number of iterations over training data to train the model.
#' @param update.period integer, default=1
#' The number of iterations to update parameters during training period.
#' @param initializer initializer object. default=mx.init.uniform(0.01)
#' The initialization scheme for parameters.
#' @param dropout float, default=0
#' A number in [0,1) containing the dropout ratio from the last hidden layer to the output layer.
#' @param optimizer string, default="sgd"
#' The optimization method.
#' @param epoch.end.callback function, optional
#' The callback when iteration ends.
#' @param batch.end.callback function, optional
#' The callback when one mini-batch iteration ends.
#' @param ... other optional parameters passed on to \code{train.rnn}.
#' @return model A trained lstm unrolled model.
#'
#' @export
mx.lstm <- function(train.data, eval.data = NULL,
                    num.lstm.layer, seq.len,
                    num.hidden, num.embed, num.label,
                    batch.size, input.size,
                    ctx = mx.ctx.default(),
                    num.round = 10, update.period = 1,
                    initializer = mx.init.uniform(0.01),
                    dropout = 0, optimizer = "sgd",
                    epoch.end.callback = NULL, batch.end.callback = NULL,
                    model,
                    arg.params,
                    ...) {
  # Check data and wrap it into iterators.
  train.data <- check.data(train.data, batch.size, TRUE)
  eval.data <- check.data(eval.data, batch.size, FALSE)
  # Get the unrolled LSTM symbol, either freshly built or reused from a
  # previously trained model.
  if (missing(model)) {
    rnn.sym <- lstm.unroll(num.lstm.layer = num.lstm.layer,
                           num.hidden = num.hidden,
                           seq.len = seq.len,
                           input.size = input.size,
                           num.embed = num.embed,
                           num.label = num.label,
                           dropout = dropout)
  } else {
    rnn.sym <- model$symbol  # was `=`; use `<-` for assignment
  }
  # Names of the per-layer initial state variables expected by the symbol.
  init.states.c <- lapply(seq_len(num.lstm.layer), function(i) {
    paste0("l", i, ".init.c")
  })
  init.states.h <- lapply(seq_len(num.lstm.layer), function(i) {
    paste0("l", i, ".init.h")
  })
  init.states.name <- c(init.states.c, init.states.h)
  # Set up the LSTM model executor.
  model <- setup.rnn.model(rnn.sym = rnn.sym,
                           ctx = ctx,
                           num.rnn.layer = num.lstm.layer,
                           seq.len = seq.len,
                           num.hidden = num.hidden,
                           num.embed = num.embed,
                           num.label = num.label,
                           batch.size = batch.size,
                           input.size = input.size,
                           init.states.name = init.states.name,
                           initializer = initializer,
                           dropout = dropout)
  # Restore parameter arrays (weights/biases only) from a previous run.
  if (!missing(arg.params)) {
    arg.names <- names(model$rnn.exec$ref.arg.arrays)
    for (k in names(arg.params)) {
      if ((k %in% arg.names) && is.param.name(k)) {
        rnn.input <- list()
        rnn.input[[k]] <- arg.params[[k]]
        mx.exec.update.arg.arrays(model$rnn.exec, rnn.input, match.name = TRUE)
      }
    }
  }
  # Train the LSTM model.
  model <- train.rnn(model, train.data, eval.data,
                     num.round = num.round,
                     update.period = update.period,
                     ctx = ctx,
                     init.states.name = init.states.name,
                     epoch.end.callback = epoch.end.callback,
                     batch.end.callback = batch.end.callback,
                     ...)
  # Repackage the executor state as an MXFeedForwardModel.
  model <- list(symbol = model$symbol,
                arg.params = model$rnn.exec$ref.arg.arrays,
                aux.params = model$rnn.exec$ref.aux.arrays)
  return(structure(model, class = "MXFeedForwardModel"))
}
#' Create a LSTM Inference Model
#'
#' @param num.lstm.layer integer
#' The number of the layer of lstm.
#' @param input.size integer
#' The input dim of one-hot encoding of embedding
#' @param num.hidden integer
#' The number of hidden nodes.
#' @param num.embed integer
#' The output dim of embedding.
#' @param num.label integer
#' The number of labels.
#' @param batch.size integer, default=1
#' The batch size used for R array training.
#' @param arg.params list
#'     Model parameter, list of name to NDArray of net's weights.
#' @param ctx mx.context, optional
#'     The device used to perform prediction.
#' @param dropout float, default=0
#' A number in [0,1) containing the dropout ratio from the last hidden layer to the output layer.
#' @return model list(rnn.exec=integer, symbol=mxnet symbol, num.rnn.layer=integer, num.hidden=integer, seq.len=integer, batch.size=integer, num.embed=integer)
#' A lstm inference model.
#'
#' @export
mx.lstm.inference <- function(num.lstm.layer,
                              input.size,
                              num.hidden,
                              num.embed,
                              num.label,
                              batch.size = 1,
                              arg.params,
                              ctx = mx.cpu(),
                              dropout = 0.) {
  # Build the single-step inference symbol.
  sym <- lstm.inference.symbol(num.lstm.layer = num.lstm.layer,
                               input.size = input.size,
                               num.hidden = num.hidden,
                               num.embed = num.embed,
                               num.label = num.label,
                               dropout = dropout)
  # Names of the per-layer initial state variables expected by the symbol.
  init.states.c <- lapply(seq_len(num.lstm.layer), function(i) {
    paste0("l", i, ".init.c")
  })
  init.states.h <- lapply(seq_len(num.lstm.layer), function(i) {
    paste0("l", i, ".init.h")
  })
  init.states.name <- c(init.states.c, init.states.h)
  # Inference consumes one token at a time.
  seq.len <- 1
  # Set up the LSTM model executor.
  model <- setup.rnn.model(rnn.sym = sym,
                           ctx = ctx,
                           num.rnn.layer = num.lstm.layer,
                           seq.len = seq.len,
                           num.hidden = num.hidden,
                           num.embed = num.embed,
                           num.label = num.label,
                           batch.size = batch.size,
                           input.size = input.size,
                           init.states.name = init.states.name,
                           initializer = mx.init.uniform(0.01),
                           dropout = dropout)
  # Copy the trained parameter arrays (weights/biases only) into the
  # executor.
  arg.names <- names(model$rnn.exec$ref.arg.arrays)
  for (k in names(arg.params)) {
    if ((k %in% arg.names) && is.param.name(k)) {
      rnn.input <- list()
      rnn.input[[k]] <- arg.params[[k]]
      mx.exec.update.arg.arrays(model$rnn.exec, rnn.input, match.name = TRUE)
    }
  }
  # Zero all initial states (multiplying an existing NDArray by 0 keeps its
  # shape and device).
  init.states <- list()
  for (i in seq_len(num.lstm.layer)) {
    c.name <- paste0("l", i, ".init.c")
    h.name <- paste0("l", i, ".init.h")
    init.states[[c.name]] <- model$rnn.exec$ref.arg.arrays[[c.name]] * 0
    init.states[[h.name]] <- model$rnn.exec$ref.arg.arrays[[h.name]] * 0
  }
  mx.exec.update.arg.arrays(model$rnn.exec, init.states, match.name = TRUE)
  return(model)
}
#' Using forward function to predict in lstm inference model
#'
#' @param model lstm model
#' A Lstm inference model
#' @param input.data array.matrix
#' The input data for forward function
#' @param new.seq boolean, default=FALSE
#' Whether the input is the start of a new sequence
#'
#' @return result A list(prob=prob, model=model) containing the result probability of each label and the model.
#'
#' @export
# Run one forward step of the LSTM inference model, carrying the recurrent
# state across calls. Returns list(prob = softmax output, model = model).
mx.lstm.forward <- function(model, input.data, new.seq = FALSE) {
  n.layer <- model$num.rnn.layer
  # Starting a fresh sequence: zero out every layer's init states.
  if (new.seq) {
    zero.states <- list()
    for (layer in seq_len(n.layer)) {
      c.name <- paste0("l", layer, ".init.c")
      h.name <- paste0("l", layer, ".init.h")
      zero.states[[c.name]] <- model$rnn.exec$ref.arg.arrays[[c.name]] * 0
      zero.states[[h.name]] <- model$rnn.exec$ref.arg.arrays[[h.name]] * 0
    }
    mx.exec.update.arg.arrays(model$rnn.exec, zero.states, match.name = TRUE)
  }
  # Reshape to one value per batch element and run the step.
  dim(input.data) <- c(model$batch.size)
  mx.exec.update.arg.arrays(model$rnn.exec,
                            list(data = mx.nd.array(input.data)),
                            match.name = TRUE)
  mx.exec.forward(model$rnn.exec, is.train = FALSE)
  # Feed each layer's produced last states back in as the next step's
  # init states.
  carry.states <- list()
  for (layer in seq_len(n.layer)) {
    carry.states[[paste0("l", layer, ".init.c")]] <-
      model$rnn.exec$ref.outputs[[paste0("l", layer, ".last.c_output")]]
    carry.states[[paste0("l", layer, ".init.h")]] <-
      model$rnn.exec$ref.outputs[[paste0("l", layer, ".last.h_output")]]
  }
  mx.exec.update.arg.arrays(model$rnn.exec, carry.states, match.name = TRUE)
  prob <- model$rnn.exec$ref.outputs[["sm_output"]]
  list(prob = prob, model = model)
}