import mxnet as mx | |
def slice_symbol_to_seq_symobls(net, seq_len, axis=1, squeeze_axis=True):
    """Split a symbol along ``axis`` into a Python list of ``seq_len`` symbols.

    Wraps ``mx.sym.SliceChannel`` and unpacks its grouped outputs into
    separate per-slice symbols.

    Args:
        net: Input ``mx.sym.Symbol`` to split.
        seq_len: Number of equal slices to produce along ``axis`` (passed as
            ``num_outputs``). Per the SliceChannel docs, the input's size
            along ``axis`` must be divisible by it.
        axis: Axis to split along (default 1).
        squeeze_axis: If True, drop the split axis from each output. Per the
            SliceChannel docs, this requires ``input.shape[axis] == seq_len``.

    Returns:
        list: ``seq_len`` symbols, one per slice, in slice order.
    """
    sliced = mx.sym.SliceChannel(
        data=net, num_outputs=seq_len, axis=axis, squeeze_axis=squeeze_axis
    )
    # SliceChannel returns one grouped symbol with seq_len outputs; index each
    # output so callers receive independent per-slice symbols.
    return [sliced[step] for step in range(seq_len)]


# Correctly spelled alias; the original (misspelled) name is kept for callers.
slice_symbol_to_seq_symbols = slice_symbol_to_seq_symobls