blob: e5b3d43ac07d641ee6a4ca66ad21c163e1addda0 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C0111, too-many-statements, too-many-locals
# pylint: too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
"""
architecture file for deep speech 2 model
"""
import json
import math
import argparse
import mxnet as mx
from stt_layer_batchnorm import batchnorm
from stt_layer_conv import conv
from stt_layer_fc import sequence_fc
from stt_layer_gru import bi_gru_unroll, gru_unroll
from stt_layer_lstm import bi_lstm_unroll
from stt_layer_slice import slice_symbol_to_seq_symobls
from stt_layer_warpctc import warpctc_layer
def prepare_data(args):
"""
set atual shape of data
"""
rnn_type = args.config.get("arch", "rnn_type")
num_rnn_layer = args.config.getint("arch", "num_rnn_layer")
num_hidden_rnn_list = json.loads(args.config.get("arch", "num_hidden_rnn_list"))
batch_size = args.config.getint("common", "batch_size")
if rnn_type == 'lstm':
init_c = [('l%d_init_c' % l, (batch_size, num_hidden_rnn_list[l]))
for l in range(num_rnn_layer)]
init_h = [('l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l]))
for l in range(num_rnn_layer)]
elif rnn_type == 'bilstm':
forward_init_c = [('forward_l%d_init_c' % l, (batch_size, num_hidden_rnn_list[l]))
for l in range(num_rnn_layer)]
backward_init_c = [('backward_l%d_init_c' % l, (batch_size, num_hidden_rnn_list[l]))
for l in range(num_rnn_layer)]
init_c = forward_init_c + backward_init_c
forward_init_h = [('forward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l]))
for l in range(num_rnn_layer)]
backward_init_h = [('backward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l]))
for l in range(num_rnn_layer)]
init_h = forward_init_h + backward_init_h
elif rnn_type == 'gru':
init_h = [('l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l]))
for l in range(num_rnn_layer)]
elif rnn_type == 'bigru':
forward_init_h = [('forward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l]))
for l in range(num_rnn_layer)]
backward_init_h = [('backward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l]))
for l in range(num_rnn_layer)]
init_h = forward_init_h + backward_init_h
else:
raise Exception('network type should be one of the lstm,bilstm,gru,bigru')
if rnn_type == 'lstm' or rnn_type == 'bilstm':
init_states = init_c + init_h
elif rnn_type == 'gru' or rnn_type == 'bigru':
init_states = init_h
return init_states
def arch(args, seq_len=None):
"""
define deep speech 2 network
"""
if isinstance(args, argparse.Namespace):
mode = args.config.get("common", "mode")
is_bucketing = args.config.getboolean("arch", "is_bucketing")
if mode == "train" or is_bucketing:
channel_num = args.config.getint("arch", "channel_num")
conv_layer1_filter_dim = \
tuple(json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
conv_layer1_stride = tuple(json.loads(args.config.get("arch", "conv_layer1_stride")))
conv_layer2_filter_dim = \
tuple(json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
conv_layer2_stride = tuple(json.loads(args.config.get("arch", "conv_layer2_stride")))
rnn_type = args.config.get("arch", "rnn_type")
num_rnn_layer = args.config.getint("arch", "num_rnn_layer")
num_hidden_rnn_list = json.loads(args.config.get("arch", "num_hidden_rnn_list"))
is_batchnorm = args.config.getboolean("arch", "is_batchnorm")
if seq_len is None:
seq_len = args.config.getint('arch', 'max_t_count')
num_label = args.config.getint('arch', 'max_label_length')
num_rear_fc_layers = args.config.getint("arch", "num_rear_fc_layers")
num_hidden_rear_fc_list = json.loads(args.config.get("arch", "num_hidden_rear_fc_list"))
act_type_rear_fc_list = json.loads(args.config.get("arch", "act_type_rear_fc_list"))
# model symbol generation
# input preparation
data = mx.sym.Variable('data')
label = mx.sym.Variable('label')
net = mx.sym.Reshape(data=data, shape=(-4, -1, 1, 0, 0))
net = conv(net=net,
channels=channel_num,
filter_dimension=conv_layer1_filter_dim,
stride=conv_layer1_stride,
no_bias=is_batchnorm,
name='conv1')
if is_batchnorm:
# batch norm normalizes axis 1
net = batchnorm(net, name="conv1_batchnorm")
net = conv(net=net,
channels=channel_num,
filter_dimension=conv_layer2_filter_dim,
stride=conv_layer2_stride,
no_bias=is_batchnorm,
name='conv2')
if is_batchnorm:
# batch norm normalizes axis 1
net = batchnorm(net, name="conv2_batchnorm")
net = mx.sym.transpose(data=net, axes=(0, 2, 1, 3))
net = mx.sym.Reshape(data=net, shape=(0, 0, -3))
seq_len_after_conv_layer1 = int(
math.floor((seq_len - conv_layer1_filter_dim[0]) / conv_layer1_stride[0])) + 1
seq_len_after_conv_layer2 = int(
math.floor((seq_len_after_conv_layer1 - conv_layer2_filter_dim[0])
/ conv_layer2_stride[0])) + 1
net = slice_symbol_to_seq_symobls(net=net, seq_len=seq_len_after_conv_layer2, axis=1)
if rnn_type == "bilstm":
net = bi_lstm_unroll(net=net,
seq_len=seq_len_after_conv_layer2,
num_hidden_lstm_list=num_hidden_rnn_list,
num_lstm_layer=num_rnn_layer,
dropout=0.,
is_batchnorm=is_batchnorm,
is_bucketing=is_bucketing)
elif rnn_type == "gru":
net = gru_unroll(net=net,
seq_len=seq_len_after_conv_layer2,
num_hidden_gru_list=num_hidden_rnn_list,
num_gru_layer=num_rnn_layer,
dropout=0.,
is_batchnorm=is_batchnorm,
is_bucketing=is_bucketing)
elif rnn_type == "bigru":
net = bi_gru_unroll(net=net,
seq_len=seq_len_after_conv_layer2,
num_hidden_gru_list=num_hidden_rnn_list,
num_gru_layer=num_rnn_layer,
dropout=0.,
is_batchnorm=is_batchnorm,
is_bucketing=is_bucketing)
else:
raise Exception('rnn_type should be one of the followings, bilstm,gru,bigru')
# rear fc layers
net = sequence_fc(net=net, seq_len=seq_len_after_conv_layer2,
num_layer=num_rear_fc_layers, prefix="rear",
num_hidden_list=num_hidden_rear_fc_list,
act_type_list=act_type_rear_fc_list,
is_batchnorm=is_batchnorm)
# warpctc layer
net = warpctc_layer(net=net,
seq_len=seq_len_after_conv_layer2,
label=label,
num_label=num_label,
character_classes_count=
(args.config.getint('arch', 'n_classes') + 1))
args.config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2))
return net
elif mode == 'load' or mode == 'predict':
conv_layer1_filter_dim = \
tuple(json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
conv_layer1_stride = tuple(json.loads(args.config.get("arch", "conv_layer1_stride")))
conv_layer2_filter_dim = \
tuple(json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
conv_layer2_stride = tuple(json.loads(args.config.get("arch", "conv_layer2_stride")))
if seq_len is None:
seq_len = args.config.getint('arch', 'max_t_count')
seq_len_after_conv_layer1 = int(
math.floor((seq_len - conv_layer1_filter_dim[0]) / conv_layer1_stride[0])) + 1
seq_len_after_conv_layer2 = int(
math.floor((seq_len_after_conv_layer1 - conv_layer2_filter_dim[0])
/ conv_layer2_stride[0])) + 1
args.config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2))
else:
raise Exception('mode must be the one of the followings - train,predict,load')
class BucketingArch(object):
def __init__(self, args):
self.args = args
def sym_gen(self, seq_len):
args = self.args
net = arch(args, seq_len)
init_states = prepare_data(args)
init_state_names = [x[0] for x in init_states]
init_state_names.insert(0, 'data')
return net, init_state_names, ('label',)
def get_sym_gen(self):
return self.sym_gen