import json
import math
import mxnet as mx
from stt_layer_batchnorm import batchnorm
from stt_layer_conv import conv
from stt_layer_fc import sequence_fc
from stt_layer_gru import bi_gru_unroll, gru_unroll
from stt_layer_lstm import bi_lstm_unroll
from stt_layer_slice import slice_symbol_to_seq_symobls
from stt_layer_warpctc import warpctc_layer
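

# prepare_data returns the (name, shape) pairs describing the RNN initial states;
# the training script is expected to hand these to the data iterator (e.g. as
# provide_data entries) so zero-initialized states of the matching shape are bound.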
def prepare_data(args):
    rnn_type = args.config.get("arch", "rnn_type")
    num_rnn_layer = args.config.getint("arch", "num_rnn_layer")
    num_hidden_rnn_list = json.loads(args.config.get("arch", "num_hidden_rnn_list"))
    batch_size = args.config.getint("common", "batch_size")
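    # Every initial state is a (name, shape) pair; shapes are (batch_size, num_hidden)
    # per layer and, for bidirectional nets, per direction.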
    if rnn_type == 'lstm':
        init_c = [('l%d_init_c' % l, (batch_size, num_hidden_rnn_list[l])) for l in range(num_rnn_layer)]
        init_h = [('l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l])) for l in range(num_rnn_layer)]
    elif rnn_type == 'bilstm':
        forward_init_c = [('forward_l%d_init_c' % l, (batch_size, num_hidden_rnn_list[l])) for l in
                          range(num_rnn_layer)]
        backward_init_c = [('backward_l%d_init_c' % l, (batch_size, num_hidden_rnn_list[l])) for l in
                           range(num_rnn_layer)]
        init_c = forward_init_c + backward_init_c
        forward_init_h = [('forward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l])) for l in
                          range(num_rnn_layer)]
        backward_init_h = [('backward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l])) for l in
                           range(num_rnn_layer)]
        init_h = forward_init_h + backward_init_h
    elif rnn_type == 'gru':
        init_h = [('l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l])) for l in range(num_rnn_layer)]
    elif rnn_type == 'bigru':
        forward_init_h = [('forward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l])) for l in
                          range(num_rnn_layer)]
        backward_init_h = [('backward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l])) for l in
                           range(num_rnn_layer)]
        init_h = forward_init_h + backward_init_h
    else:
        raise Exception('rnn_type should be one of: lstm, bilstm, gru, bigru')
    if rnn_type == 'lstm' or rnn_type == 'bilstm':
        init_states = init_c + init_h
    elif rnn_type == 'gru' or rnn_type == 'bigru':
        init_states = init_h
    return init_states
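

# arch builds the training symbol: two 2D convolution layers (optionally batch
# normalized), an unrolled bilstm/gru/bigru stack, optional rear fully connected
# layers, and a WarpCTC loss on top of the per-timestep outputs.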
def arch(args):
    mode = args.config.get("common", "mode")
    if mode == "train":
        channel_num = args.config.getint("arch", "channel_num")
        conv_layer1_filter_dim = tuple(json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
        conv_layer1_stride = tuple(json.loads(args.config.get("arch", "conv_layer1_stride")))
        conv_layer2_filter_dim = tuple(json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
        conv_layer2_stride = tuple(json.loads(args.config.get("arch", "conv_layer2_stride")))
        rnn_type = args.config.get("arch", "rnn_type")
        num_rnn_layer = args.config.getint("arch", "num_rnn_layer")
        num_hidden_rnn_list = json.loads(args.config.get("arch", "num_hidden_rnn_list"))
        is_batchnorm = args.config.getboolean("arch", "is_batchnorm")
        seq_len = args.config.getint('arch', 'max_t_count')
        num_label = args.config.getint('arch', 'max_label_length')
        num_rear_fc_layers = args.config.getint("arch", "num_rear_fc_layers")
        num_hidden_rear_fc_list = json.loads(args.config.get("arch", "num_hidden_rear_fc_list"))
        act_type_rear_fc_list = json.loads(args.config.get("arch", "act_type_rear_fc_list"))
        # model symbol generation
        # input preparation
        data = mx.sym.Variable('data')
        label = mx.sym.Variable('label')
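        # Assuming 'data' arrives as (batch, seq_len, feat_dim), the Reshape below uses
        # -4/(-1, 1) to split the batch axis and insert a channel axis of size 1, while
        # the two 0s keep the remaining dimensions, giving (batch, 1, seq_len, feat_dim)
        # so the convolution layers see a single-channel image.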
        net = mx.sym.Reshape(data=data, shape=(-4, -1, 1, 0, 0))
        net = conv(net=net,
                   channels=channel_num,
                   filter_dimension=conv_layer1_filter_dim,
                   stride=conv_layer1_stride,
                   no_bias=is_batchnorm
                   )
        if is_batchnorm:
            # batch norm normalizes axis 1
            net = batchnorm(net)
        net = conv(net=net,
                   channels=channel_num,
                   filter_dimension=conv_layer2_filter_dim,
                   stride=conv_layer2_stride,
                   no_bias=is_batchnorm
                   )
        if is_batchnorm:
            # batch norm normalizes axis 1
            net = batchnorm(net)
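        # (batch, channel, time, feature) -> (batch, time, channel, feature); the -3 in
        # the following Reshape merges the last two axes into one, yielding a
        # (batch, time, channel * feature) tensor that can be sliced per timestep.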
        net = mx.sym.transpose(data=net, axes=(0, 2, 1, 3))
        net = mx.sym.Reshape(data=net, shape=(0, 0, -3))
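        # output length of an unpadded convolution along the time axis: floor((L - k) / s) + 1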
        seq_len_after_conv_layer1 = int(
            math.floor((seq_len - conv_layer1_filter_dim[0]) / conv_layer1_stride[0])) + 1
        seq_len_after_conv_layer2 = int(
            math.floor((seq_len_after_conv_layer1 - conv_layer2_filter_dim[0]) / conv_layer2_stride[0])) + 1
        net = slice_symbol_to_seq_symobls(net=net, seq_len=seq_len_after_conv_layer2, axis=1)
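        # unroll the configured RNN over the per-timestep slices produced above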
if rnn_type == "bilstm":
net = bi_lstm_unroll(net=net,
seq_len=seq_len_after_conv_layer2,
num_hidden_lstm_list=num_hidden_rnn_list,
num_lstm_layer=num_rnn_layer,
dropout=0.,
is_batchnorm=is_batchnorm)
elif rnn_type == "gru":
net = gru_unroll(net=net,
seq_len=seq_len_after_conv_layer2,
num_hidden_gru_list=num_hidden_rnn_list,
num_gru_layer=num_rnn_layer,
dropout=0.,
is_batchnorm=is_batchnorm)
elif rnn_type == "bigru":
net = bi_gru_unroll(net=net,
seq_len=seq_len_after_conv_layer2,
num_hidden_gru_list=num_hidden_rnn_list,
num_gru_layer=num_rnn_layer,
dropout=0.,
is_batchnorm=is_batchnorm)
        else:
            raise Exception('rnn_type should be one of: bilstm, gru, bigru')
        # rear fc layers
        net = sequence_fc(net=net, seq_len=seq_len_after_conv_layer2, num_layer=num_rear_fc_layers, prefix="rear",
                          num_hidden_list=num_hidden_rear_fc_list, act_type_list=act_type_rear_fc_list,
                          is_batchnorm=is_batchnorm)
        if is_batchnorm:
            hidden_all = []
            # batch norm normalizes axis 1
            for seq_index in range(seq_len_after_conv_layer2):
                hidden = net[seq_index]
                hidden = batchnorm(hidden)
                hidden_all.append(hidden)
            net = hidden_all
        # warpctc layer
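        # character_classes_count includes one extra class for the CTC blank label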
        net = warpctc_layer(net=net,
                            seq_len=seq_len_after_conv_layer2,
                            label=label,
                            num_label=num_label,
                            character_classes_count=(args.config.getint('arch', 'n_classes') + 1)
                            )
        args.config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2))
        return net
    else:
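        # non-train mode: only recompute the post-convolution sequence length and write
        # it back to max_t_count; the symbol itself is presumably restored from a
        # checkpoint elsewhere.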
        conv_layer1_filter_dim = tuple(json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
        conv_layer1_stride = tuple(json.loads(args.config.get("arch", "conv_layer1_stride")))
        conv_layer2_filter_dim = tuple(json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
        conv_layer2_stride = tuple(json.loads(args.config.get("arch", "conv_layer2_stride")))
        seq_len = args.config.getint('arch', 'max_t_count')
        seq_len_after_conv_layer1 = int(
            math.floor((seq_len - conv_layer1_filter_dim[0]) / conv_layer1_stride[0])) + 1
        seq_len_after_conv_layer2 = int(
            math.floor((seq_len_after_conv_layer1 - conv_layer2_filter_dim[0]) / conv_layer2_stride[0])) + 1
        args.config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2))