# Source blob: 6b434ec1049bc3564470ff973bc2f2c30ca659c6
import mxnet as mx
def slice_symbol_to_seq_symobls(net, seq_len, axis=1, squeeze_axis=True):
    """Slice ``net`` into ``seq_len`` outputs and return them as a Python list.

    The symbol is passed to ``mx.sym.SliceChannel`` with
    ``num_outputs=seq_len``; each output is then picked out by index so the
    caller receives a plain list rather than a single grouped symbol.

    Args:
        net: Symbol to slice (passed as ``data`` to ``SliceChannel``).
        seq_len: Number of outputs to request and number of list entries
            returned.
        axis: Axis forwarded to ``SliceChannel``. Defaults to 1.
        squeeze_axis: Forwarded unchanged to ``SliceChannel``. Defaults to
            True.

    Returns:
        A list of length ``seq_len`` whose ``i``-th element is the ``i``-th
        output of the ``SliceChannel`` symbol.
    """
    sliced = mx.sym.SliceChannel(
        data=net,
        num_outputs=seq_len,
        axis=axis,
        squeeze_axis=squeeze_axis,
    )
    # Index each output explicitly so callers get a list of per-step symbols.
    return [sliced[step] for step in range(seq_len)]