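Baidu-WarpCTC is Baidu's fast, parallel implementation of the CTC (Connectionist Temporal Classification) loss, with both CPU and GPU support. Combined with an LSTM, CTC addresses the label-alignment problem, where the input sequence is longer than the label sequence and the position of each label within the input is unknown. This problem arises in many areas, such as OCR and speech recognition.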
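To make the alignment problem concrete, here is a minimal plain-Python sketch of the standard CTC forward (alpha) recursion. It sums the probability of every frame-level path that collapses to the target label after merging repeated symbols and removing blanks; the CTC loss is the negative log of that sum. This illustrates the technique only and is not WarpCTC's own code (a practical implementation would typically work in log space for numerical stability and run in parallel). The function names are hypothetical, and the blank is index 0, as elsewhere in this document.

```python
import itertools

BLANK = 0  # index 0 is the blank label, as in the rest of this document


def collapse(path, blank=BLANK):
    """Map a frame-level path to a label sequence: merge repeats, then drop blanks."""
    out, prev = [], None
    for s in path:
        if s != prev and s != blank:
            out.append(s)
        prev = s
    return out


def ctc_forward(probs, label, blank=BLANK):
    """Return p(label | probs), summed over every alignment that collapses to label.

    probs[t][k] is the (softmax) probability of symbol k at frame t.
    """
    # Extended label: blanks around and between symbols, e.g. [0, l1, 0, l2, 0].
    ext = [blank]
    for sym in label:
        ext += [sym, blank]
    S = len(ext)
    # alpha[s]: total probability of all partial paths ending at ext[s] so far.
    alpha = [0.0] * S
    alpha[0] = probs[0][ext[0]]
    if S > 1:
        alpha[1] = probs[0][ext[1]]
    for t in range(1, len(probs)):
        new = [0.0] * S
        for s in range(S):
            a = alpha[s]                  # stay on the same position
            if s >= 1:
                a += alpha[s - 1]         # advance by one position
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                a += alpha[s - 2]         # skip a blank between two distinct symbols
            new[s] = a * probs[t][ext[s]]
        alpha = new
    # A valid path ends on the last symbol or on the trailing blank.
    return alpha[-1] + (alpha[-2] if S > 1 else 0.0)


def ctc_brute_force(probs, label, blank=BLANK):
    """Same quantity by enumerating all K**T paths (only feasible for tiny inputs)."""
    total = 0.0
    for path in itertools.product(range(len(probs[0])), repeat=len(probs)):
        if collapse(path, blank) == list(label):
            p = 1.0
            for t, k in enumerate(path):
                p *= probs[t][k]
            total += p
    return total


# 4 frames, 3 symbols (blank plus two labels); each row sums to 1.
probs = [[0.5, 0.3, 0.2], [0.4, 0.4, 0.2], [0.1, 0.6, 0.3], [0.3, 0.2, 0.5]]
for target in ([1, 2], [1, 1], [2]):
    print(target, ctc_forward(probs, target), ctc_brute_force(probs, target))
```

The brute-force function exists only to check the recursion on tiny inputs: the forward recursion costs O(T·S) for T frames and an extended label of length S, whereas enumeration costs O(K^T) for K symbols.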
The source code for the examples is available on GitHub. To use WarpCTC with MXNet, first build and install the library:
cd ~/
git clone https://github.com/baidu-research/warp-ctc
cd warp-ctc
mkdir build
cd build
cmake ..
make
sudo make install
Next, enable the WarpCTC plugin in MXNet by uncommenting the following lines in make/config.mk:
WARPCTC_PATH = $(HOME)/warp-ctc
MXNET_PLUGINS += plugin/warpctc/warpctc.mk
Then rebuild MXNet:
make clean && make -j4
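If you prefer to script the configuration step, the sketch below uncomments the two WarpCTC settings in the text of config.mk and appends them if they are missing entirely. It assumes the settings appear in make/config.mk exactly as written above, optionally prefixed with '#'; enable_warpctc and enable_warpctc_file are hypothetical helpers, not part of MXNet.

```python
from pathlib import Path

# The two settings from make/config.mk that enable the WarpCTC plugin.
# Assumption: in the stock file they appear commented out with a leading '#'.
WARPCTC_LINES = (
    "WARPCTC_PATH = $(HOME)/warp-ctc",
    "MXNET_PLUGINS += plugin/warpctc/warpctc.mk",
)


def enable_warpctc(config_text):
    """Return config_text with the WarpCTC settings uncommented (or appended)."""
    out, found = [], set()
    for line in config_text.splitlines():
        body = line.strip().lstrip("#").strip()
        if body in WARPCTC_LINES:
            out.append(body)  # keep the setting, drop the leading '#'
            found.add(body)
        else:
            out.append(line)
    # Append any setting that was not present at all.
    out.extend(setting for setting in WARPCTC_LINES if setting not in found)
    return "\n".join(out) + "\n"


def enable_warpctc_file(path="make/config.mk"):
    """Rewrite config.mk in place; run this from the MXNet source root."""
    p = Path(path)
    p.write_text(enable_warpctc(p.read_text()))
```

The function is idempotent, so running it twice leaves the file unchanged. After calling enable_warpctc_file(), rebuild MXNet with make clean && make -j4.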
There are two examples: a toy example that validates the CTC integration, and an OCR example that combines LSTM and CTC. To run the OCR example, type:
cd examples/warpctc
python lstm_ocr.py
The following code shows the detailed construction of the OCR network:
import mxnet as mx
# LSTMParam, LSTMState and lstm are assumed to come from the lstm.py helper
# module that accompanies this example.
from lstm import LSTMParam, LSTMState, lstm
def lstm_unroll(num_lstm_layer, seq_len,
                num_hidden, num_label):
    param_cells = []
    last_states = []
    for i in range(num_lstm_layer):
        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                          h=mx.sym.Variable("l%d_init_h" % i))
        last_states.append(state)
    assert len(last_states) == num_lstm_layer
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')
    # Every column of the image is one input step, so there are seq_len inputs.
    wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1)
    hidden_all = []
    for seqidx in range(seq_len):
        hidden = wordvec[seqidx]
        # Feed this step through each stacked LSTM layer in turn.
        for i in range(num_lstm_layer):
            next_state = lstm(num_hidden, indata=hidden,
                              prev_state=last_states[i],
                              param=param_cells[i],
                              seqidx=seqidx, layeridx=i)
            hidden = next_state.h
            last_states[i] = next_state
        hidden_all.append(hidden)
    # Stack the per-step outputs along axis 0 (time-major: step t, then batch).
    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
    # 11 output classes: index 0 is the CTC blank, the other 10 are label symbols.
    pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11)
    # Here we do NOT need to transpose the label as other LSTM examples do.
    label = mx.sym.Reshape(data=label, target_shape=(0,))
    # The label must be an integer type, so cast it.
    label = mx.sym.Cast(data=label, dtype='int32')
    sm = mx.sym.WarpCTC(data=pred, label=label, label_length=num_label, input_length=seq_len)
    return sm
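The comment about not transposing the label comes down to memory layout. As the warp-ctc C API documents it, activations are laid out time-major (frame t, then sample b), and the labels of a minibatch are passed as one flat concatenation, sample after sample. The NumPy sketch below, with made-up sizes, mimics Concat(dim=0) over the per-step outputs and the flattening Reshape of the label, to show that both already come out in those layouts; it illustrates the ordering only and does not call MXNet.

```python
import numpy as np

batch_size, seq_len, num_hidden = 2, 3, 4

# Stand-in for hidden_all: one (batch_size, num_hidden) array per time step.
# Every value encodes 10 * t + b so the row order after concatenation is visible.
hidden_all = [
    np.array([[10 * t + b] * num_hidden for b in range(batch_size)])
    for t in range(seq_len)
]

# NumPy analogue of mx.sym.Concat(*hidden_all, dim=0):
# row index = t * batch_size + b, i.e. time-major.
hidden_concat = np.concatenate(hidden_all, axis=0)

# Label of shape (batch_size, num_label); flattening keeps it batch-major,
# so each sample's labels stay contiguous, as a flat label list expects.
label = np.array([[1, 2], [3, 4]])
flat_label = label.reshape(-1)

print(hidden_concat.shape)           # (6, 4)
print(hidden_concat[:, 0].tolist())  # [0, 1, 10, 11, 20, 21]
print(flat_label.tolist())           # [1, 2, 3, 4]
```

Other LSTM examples transpose the label so that it lines up row by row with a time-major softmax output; CTC does not need a per-frame label, so the batch-major flattening is already what it consumes.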
Provide labels with a fixed length b (num_label in the code above). For samples whose label is shorter than b, pad the label with 0 up to length b.
Index 0 is reserved for the blank label, so it never denotes a real symbol.
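A minimal sketch of the padding rule, assuming labels arrive as Python lists of class indices. Because 0 is the blank, real symbols must use indices from 1 upward; encode_digits shows one hypothetical way to do that for digit strings (digit d becomes d + 1, which fits the 11 output classes in the network above), and pad_label is likewise a hypothetical helper.

```python
BLANK = 0  # reserved for the CTC blank label


def encode_digits(text):
    """Hypothetical encoding: digit d -> class d + 1, keeping 0 free for the blank."""
    return [int(ch) + 1 for ch in text]


def pad_label(label, b, blank=BLANK):
    """Right-pad a label with 0 (blank) so that it has exactly length b."""
    if len(label) > b:
        raise ValueError("label length %d exceeds b=%d" % (len(label), b))
    return list(label) + [blank] * (b - len(label))


# Build a fixed-length label batch of width b = 4.
batch = [pad_label(encode_digits(s), 4) for s in ("307", "1234", "9")]
print(batch)  # [[4, 1, 8, 0], [2, 3, 4, 5], [10, 0, 0, 0]]
```

Raising on over-long labels, rather than silently truncating, makes a mismatch between the data and num_label visible early.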