blob: 8e1152173ac7c7362db26e62590a2016c0e9d903 [file] [log] [blame]
# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
from __future__ import print_function
import sys
sys.path.insert(0, "../../python")
import numpy as np
import mxnet as mx
# The interface of a data iterator that supports bucketing:
#
# DataIter
# - default_bucket_key: the bucket key for the default symbol.
#
# DataBatch
# - provide_data: same as DataIter, but specific to this batch
# - provide_label: same as DataIter, but specific to this batch
# - bucket_key: the key for the bucket that should be used for this batch
def default_read_content(path):
    """Return the full text of the file at ``path`` with sentence ends marked.

    Every newline is replaced by ``' <eos> '`` first, and then every ``'. '``
    (a period followed by a space) is replaced by ``' <eos> '`` as well.
    """
    with open(path) as handle:
        text = handle.read()
    # order matters: newlines are rewritten before period-space pairs
    for boundary in ('\n', '. '):
        text = text.replace(boundary, ' <eos> ')
    return text
def default_build_vocab(path):
    """Build a token -> id vocabulary from the corpus at ``path``.

    The corpus is loaded with ``default_read_content`` and split on single
    spaces. Each distinct non-empty token receives an id, starting at 1 and
    assigned in sorted token order. Id 0 is left for zero-padding; it is
    mapped to the dummy key ``' '`` so that ``len(vocab)`` counts it.
    """
    tokens = default_read_content(path).split(' ')
    vocab = {' ': 0}  # dummy entry so that len(vocab) is correct
    # split(' ') never yields a token containing ' ', so no id collides
    # with the dummy key
    for token_id, token in enumerate(sorted({tok for tok in tokens if tok}), 1):
        vocab[token] = token_id
    return vocab
def default_text2id(sentence, the_vocab):
    """Map a space-separated sentence to a list of vocabulary ids.

    Empty tokens (from leading, trailing or repeated spaces) are skipped.
    Looking up a token missing from ``the_vocab`` raises whatever the
    mapping raises (``KeyError`` for a dict).
    """
    tokens = filter(None, sentence.split(' '))
    return [the_vocab[token] for token in tokens]
def default_gen_buckets(sentences, batch_size, the_vocab, text2id=None):
    """Derive bucket lengths from the histogram of sentence lengths.

    Sentence lengths (number of ids returned by ``text2id``) are visited in
    ascending order. A bucket is closed at the current length once at least
    ``batch_size`` sentences have accumulated since the previous bucket, so
    every sentence counted toward a bucket fits in it. Any remainder is
    covered by a final bucket of the maximum sentence length. Empty
    sentences are ignored. The length histogram is printed as a side effect.

    Parameters
    ----------
    sentences : iterable of str
    batch_size : int
    the_vocab : dict
        Vocabulary passed through to ``text2id``.
    text2id : callable, optional
        ``text2id(sentence, the_vocab) -> list of ids``. Defaults to
        ``default_text2id``.

    Returns
    -------
    list of int
        Strictly ascending bucket lengths. The list is empty when no
        sentence is non-empty.
    """
    if text2id is None:
        text2id = default_text2id
    len_dict = {}
    max_len = -1
    for sentence in sentences:
        lw = len(text2id(sentence, the_vocab))
        if lw == 0:
            continue
        max_len = max(max_len, lw)
        len_dict[lw] = len_dict.get(lw, 0) + 1
    print(len_dict)
    tl = 0
    buckets = []
    # Visit lengths in ascending order. Dict iteration order is first-seen
    # order (Python 3.7+) or arbitrary (older versions). Without sorting, a
    # bucket could be closed at a length shorter than sentences already
    # counted toward it, and max_len could be appended twice.
    for length, n in sorted(len_dict.items()):  # TODO: better heuristics exist
        if n + tl >= batch_size:
            buckets.append(length)
            tl = 0
        else:
            tl += n
    if tl > 0:
        buckets.append(max_len)
    return buckets
class SimpleBatch(object):
    """Lightweight batch container that also carries a bucket key.

    Holds parallel lists of input names/arrays and label names/arrays, plus
    ``bucket_key``, which identifies the bucket this batch belongs to.
    ``provide_data`` and ``provide_label`` describe each array as a
    ``(name, shape)`` pair.
    """

    def __init__(self, data_names, data, label_names, label, bucket_key):
        self.data_names, self.data = data_names, data
        self.label_names, self.label = label_names, label
        self.bucket_key = bucket_key
        self.pad = 0
        self.index = None  # TODO: what is index?

    @staticmethod
    def _shapes(names, arrays):
        """Pair each name with the ``shape`` attribute of its array."""
        return [(name, arr.shape) for name, arr in zip(names, arrays)]

    @property
    def provide_data(self):
        return self._shapes(self.data_names, self.data)

    @property
    def provide_label(self):
        return self._shapes(self.label_names, self.label)
class DummyIter(mx.io.DataIter):
    """Iterator that keeps returning the first batch of another iterator.

    Intended for speed testing: every call to ``next()`` returns the same
    cached batch object, so no data loading happens after construction.
    ``provide_data``, ``provide_label`` and ``batch_size`` are copied from
    ``real_iter``.
    """

    def __init__(self, real_iter):
        super(DummyIter, self).__init__()
        self.real_iter = real_iter
        self.provide_data = real_iter.provide_data
        self.provide_label = real_iter.provide_label
        self.batch_size = real_iter.batch_size
        # Pull only the first batch. If real_iter yields nothing, the_batch
        # stays unset and next() raises AttributeError.
        source = iter(real_iter)
        try:
            self.the_batch = next(source)
        except StopIteration:
            pass

    def __iter__(self):
        return self

    def next(self):
        """Return the cached batch; this never signals the end of iteration."""
        return self.the_batch
class BucketSentenceIter(mx.io.DataIter):
    """Bucketing sentence iterator whose labels are the sorted inputs.

    Reads a corpus with ``read_content`` and splits it into sentences on
    ``seperate_char``. Each sentence is mapped to ids with ``text2id`` and
    placed in the smallest bucket that can hold it; sentences longer than the
    largest bucket are dropped. Every yielded ``SimpleBatch`` holds a
    ``(batch_size, bucket_len)`` id matrix (zero-padded on the right) plus
    the ``init_states`` zero arrays. Its label is the same matrix with each
    row sorted in ascending order.

    Parameters
    ----------
    path : str
        Corpus location, passed to ``read_content``.
    vocab : dict
        Vocabulary passed to ``text2id``.
    buckets : list of int
        Bucket lengths. If empty, they are generated with
        ``default_gen_buckets``. The caller's list is not modified.
    batch_size : int
    init_states : list of (name, shape)
        Extra named inputs (presumably initial RNN states). A zero array of
        each shape is appended to every batch's data.
    data_name, label_name : str
        Stored as attributes only. Batches and ``provide_data`` /
        ``provide_label`` always use the names ``'data'`` and
        ``'softmax_label'``.
    seperate_char : str
        Sentence separator.
    text2id, read_content : callable, optional
        Override ``default_text2id`` / ``default_read_content``.

    Call ``reset()`` before iterating a second time: the per-bucket read
    cursors are only rewound there.
    """

    def __init__(self, path, vocab, buckets, batch_size,
                 init_states, data_name='data', label_name='label',
                 seperate_char=' <eos> ', text2id=None, read_content=None):
        super(BucketSentenceIter, self).__init__()
        self.text2id = default_text2id if text2id is None else text2id
        self.read_content = (default_read_content if read_content is None
                             else read_content)

        content = self.read_content(path)
        sentences = content.split(seperate_char)

        if len(buckets) == 0:
            # NOTE(review): lengths here are measured with default_text2id,
            # even when a custom text2id was supplied.
            buckets = default_gen_buckets(sentences, batch_size, vocab)
        print(buckets)

        self.vocab_size = len(vocab)
        self.data_name = data_name
        self.label_name = label_name

        # Sort a copy so the caller's list is not mutated.
        buckets = sorted(buckets)
        self.buckets = buckets
        # pre-allocate with the largest bucket for better memory sharing
        self.default_bucket_key = max(buckets)

        self.data = self._bucketize(sentences, vocab)

        # per-bucket sample counts, reported in the summary below
        bucket_sizes = [len(x) for x in self.data]
        print("Summary of dataset ==================")
        for bkt, size in zip(buckets, bucket_sizes):
            print("bucket of len %3d : %d samples" % (bkt, size))

        self.batch_size = batch_size
        self.make_data_iter_plan()

        self.init_states = init_states
        self.init_state_arrays = [mx.nd.zeros(x[1]) for x in init_states]

        self.provide_data = [('data', (batch_size, self.default_bucket_key))] + init_states
        self.provide_label = [('softmax_label', (self.batch_size, self.default_bucket_key))]

    def _bucketize(self, sentences, vocab):
        """Return one zero-padded id matrix (numpy array) per bucket.

        A sentence goes to the first bucket, in ``self.buckets`` order, whose
        length is >= its id count. Sentences with no ids, or longer than
        every bucket, are dropped.
        """
        per_bucket = [[] for _ in self.buckets]
        for sentence in sentences:
            ids = self.text2id(sentence, vocab)
            if len(ids) == 0:
                continue
            for i, bkt in enumerate(self.buckets):
                if bkt >= len(ids):
                    per_bucket[i].append(ids)
                    break
            # a sentence longer than the maximum bucket size is ignored
        # Convert to ndarrays for speed during training. Positions past a
        # sentence's end stay 0, the id default_build_vocab reserves for
        # padding.
        arrays = []
        for bkt, rows in zip(self.buckets, per_bucket):
            mat = np.zeros((len(rows), bkt))
            for j, ids in enumerate(rows):
                mat[j, :len(ids)] = ids
            arrays.append(mat)
        return arrays

    def make_data_iter_plan(self):
        "make a random data iteration plan"
        # Truncate each bucket to a multiple of batch_size. Floor division
        # keeps the batch count an int on Python 3; true division gave a
        # float, which np.zeros rejects as a length.
        bucket_n_batches = []
        for i, bucket in enumerate(self.data):
            n_batches = len(bucket) // self.batch_size
            bucket_n_batches.append(n_batches)
            self.data[i] = bucket[:n_batches * self.batch_size]
        # one entry per batch holding that batch's bucket index, shuffled
        bucket_plan = np.hstack([np.zeros(n, int) + i
                                 for i, n in enumerate(bucket_n_batches)])
        np.random.shuffle(bucket_plan)
        # random sample order within each bucket
        bucket_idx_all = [np.random.permutation(len(x)) for x in self.data]
        self.bucket_plan = bucket_plan
        self.bucket_idx_all = bucket_idx_all
        self.bucket_curr_idx = [0 for x in self.data]
        # reusable per-bucket (batch_size, bucket_len) staging buffers
        self.data_buffer = []
        self.label_buffer = []
        for i_bucket in range(len(self.data)):
            data = np.zeros((self.batch_size, self.buckets[i_bucket]))
            label = np.zeros((self.batch_size, self.buckets[i_bucket]))
            self.data_buffer.append(data)
            self.label_buffer.append(label)

    def __iter__(self):
        init_state_names = [x[0] for x in self.init_states]
        for i_bucket in self.bucket_plan:
            data = self.data_buffer[i_bucket]
            label = self.label_buffer[i_bucket]
            i_idx = self.bucket_curr_idx[i_bucket]
            idx = self.bucket_idx_all[i_bucket][i_idx:i_idx + self.batch_size]
            self.bucket_curr_idx[i_bucket] += self.batch_size
            data[:] = self.data[i_bucket][idx]
            # target of the sort task: every row's ids in ascending order
            label[:] = np.sort(data, axis=1)
            data_all = [mx.nd.array(data)] + self.init_state_arrays
            label_all = [mx.nd.array(label)]
            data_names = ['data'] + init_state_names
            label_names = ['softmax_label']
            data_batch = SimpleBatch(data_names, data_all, label_names, label_all,
                                     self.buckets[i_bucket])
            yield data_batch

    def reset(self):
        """Rewind every bucket's read cursor.

        The bucket order and per-bucket permutations are not reshuffled.
        """
        self.bucket_curr_idx = [0 for x in self.data]