# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
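"""Bucketing data iterator for the speech-to-text (STT) example.

Utterances are grouped into buckets by audio duration so that each batch only
needs to be padded to its bucket's sequence length rather than to the global
maximum length.
"""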
from __future__ import print_function
import sys
# make the in-repo mxnet python package take precedence, if present
sys.path.insert(0, "../../python")
import bisect
import random

import mxnet as mx
import numpy as np
BATCH_SIZE = 1
SEQ_LENGTH = 0
NUM_GPU = 1
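# module-level defaults (not referenced further down in this file)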
def get_label(buf, num_label):
    """Convert a string of digit characters into a zero-padded label vector
    of length num_label."""
    ret = np.zeros(num_label)
    for i in range(len(buf)):
        ret[i] = int(buf[i])
    return ret
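# For example, with the digit-string labels the int() cast above assumes:
#   get_label("301", 5) -> array([3., 0., 1., 0., 0.])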
class BucketSTTIter(mx.io.DataIter):
    def __init__(self, count, datagen, batch_size, num_label, init_states, seq_length, width, height,
                 sort_by_duration=True,
                 is_bi_graphemes=False,
                 partition="train",
                 buckets=None,
                 save_feature_as_csvfile=False):
        super(BucketSTTIter, self).__init__()
        self.maxLabelLength = num_label
        # global param
        self.batch_size = batch_size
        self.count = count
        self.num_label = num_label
        self.init_states = init_states
        self.init_state_arrays = [mx.nd.zeros(x[1]) for x in init_states]
        self.width = width
        self.height = height
        self.datagen = datagen
        self.label = None
        self.is_bi_graphemes = is_bi_graphemes
        # self.partition = datagen.partition
        if partition == 'train':
            durations = datagen.train_durations
            audio_paths = datagen.train_audio_paths
            texts = datagen.train_texts
        elif partition == 'validation':
            durations = datagen.val_durations
            audio_paths = datagen.val_audio_paths
            texts = datagen.val_texts
        elif partition == 'test':
            durations = datagen.test_durations
            audio_paths = datagen.test_audio_paths
            texts = datagen.test_texts
        else:
            raise Exception("Invalid partition to load metadata. "
                            "Must be one of 'train', 'validation' or 'test'.")
        # SortaGrad: present the first epoch in order of increasing duration
        if sort_by_duration:
            durations, audio_paths, texts = datagen.sort_by_duration(durations,
                                                                     audio_paths,
                                                                     texts)
        # materialize as a list so it can be indexed while filling the buckets
        # (zip() returns a one-shot iterator on Python 3)
        self.trainDataList = list(zip(durations, audio_paths, texts))
        self.trainDataIter = iter(self.trainDataList)
        self.is_first_epoch = True
        # lengths in hundredths of a second (durations assumed to be seconds)
        data_lengths = [int(d * 100) for d in durations]
        if not buckets:
            # one bucket per length that occurs at least batch_size times
            buckets = [i for i, j in enumerate(np.bincount(data_lengths))
                       if j >= batch_size]
        if len(buckets) == 0:
            raise Exception('No valid buckets. batch_size may be too large for '
                            'every bucket. max bincount:%d batch_size:%d'
                            % (max(np.bincount(data_lengths)), batch_size))
        buckets.sort()
        # assign every utterance to the smallest bucket that can hold it;
        # utterances longer than the largest bucket are discarded
        ndiscard = 0
        self.data = [[] for _ in buckets]
        for i, sent in enumerate(data_lengths):
            buck = bisect.bisect_left(buckets, sent)
            if buck == len(buckets):
                ndiscard += 1
                continue
            self.data[buck].append(self.trainDataList[i])
        if ndiscard != 0:
            print("WARNING: discarded %d sentences longer than the largest bucket." % ndiscard)
        self.buckets = buckets
        self.nddata = []
        self.ndlabel = []
        self.default_bucket_key = max(buckets)

        # self.idx holds one (bucket, offset) pair per full batch; a partial
        # batch at the end of a bucket is dropped
        self.idx = []
        for i, buck in enumerate(self.data):
            self.idx.extend([(i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)])
        self.curr_idx = 0

        self.provide_data = [('data', (self.batch_size, self.default_bucket_key, width * height))] + init_states
        self.provide_label = [('label', (self.batch_size, self.maxLabelLength))]
        self.save_feature_as_csvfile = save_feature_as_csvfile
        #self.reset()
    def reset(self):
        """Resets the iterator to the beginning of the data and reshuffles it."""
        self.curr_idx = 0
        # shuffle the order in which batches are visited ...
        random.shuffle(self.idx)
        # ... and the utterances within each bucket
        for buck in self.data:
            np.random.shuffle(buck)
        # after one full pass the cached features can be reused
        self.is_first_epoch = False
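    # Note: mx.mod.BaseModule.fit() calls reset() at the end of each epoch, so
    # the shuffling above first takes effect on the second epoch, preserving
    # the sorted-by-duration (SortaGrad) order of the first one.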
    def next(self):
        """Returns the next batch of data."""
        if self.curr_idx == len(self.idx):
            raise StopIteration
        i, j = self.idx[self.curr_idx]
        self.curr_idx += 1

        audio_paths = []
        texts = []
        for duration, audio_path, text in self.data[i][j:j+self.batch_size]:
            audio_paths.append(audio_path)
            texts.append(text)

        # on the first epoch the features are computed (and optionally cached);
        # afterwards the cached features are reused
        data_set = self.datagen.prepare_minibatch(audio_paths, texts,
                                                  overwrite=self.is_first_epoch,
                                                  is_bi_graphemes=self.is_bi_graphemes,
                                                  seq_length=self.buckets[i],
                                                  save_feature_as_csvfile=self.save_feature_as_csvfile)

        data_all = [mx.nd.array(data_set['x'])] + self.init_state_arrays
        label_all = [mx.nd.array(data_set['y'])]
        self.label = label_all
        # the data shape is bucket-specific, so advertise it per batch
        provide_data = [('data', (self.batch_size, self.buckets[i], self.width * self.height))] + self.init_states

        return mx.io.DataBatch(data_all, label_all, pad=0,
                               bucket_key=self.buckets[i],
                               provide_data=provide_data,
                               provide_label=self.provide_label)
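# A minimal usage sketch (names and shapes are hypothetical; `datagen` is
# assumed to be this example's DataGenerator and the init states those of an
# RNN with 1024 hidden units):
#
#   init_states = [('l0_init_c', (BATCH_SIZE, 1024)),
#                  ('l0_init_h', (BATCH_SIZE, 1024))]
#   data_train = BucketSTTIter(count=datagen.count, datagen=datagen,
#                              batch_size=BATCH_SIZE, num_label=num_label,
#                              init_states=init_states, seq_length=SEQ_LENGTH,
#                              width=spectrogram_width, height=feature_dim,
#                              partition='train')
#   for batch in data_train:
#       module.forward_backward(batch)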