| # pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme |
| # pylint: disable=superfluous-parens, no-member, invalid-name |
| from __future__ import print_function |
| import sys |
| sys.path.insert(0, "../../python") |
| import numpy as np |
| import mxnet as mx |
| |
| # The interface of a data iter that works for bucketing |
| # |
| # DataIter |
| # - default_bucket_key: the bucket key for the default symbol. |
| # |
| # DataBatch |
| # - provide_data: same as DataIter, but specific to this batch |
| # - provide_label: same as DataIter, but specific to this batch |
| # - bucket_key: the key for the bucket that should be used for this batch |
| |
def default_read_content(path):
    """Read the text file at *path* and mark sentence boundaries.

    Both newlines and ". " sequences are rewritten to the ' <eos> '
    token so the result can later be split into sentences.
    """
    with open(path) as handle:
        raw = handle.read()
    marked = raw.replace('\n', ' <eos> ')
    return marked.replace('. ', ' <eos> ')
| |
def default_build_vocab(path):
    """Build a word -> integer-id vocabulary from the text file at *path*.

    Id 0 is reserved for zero-padding; a dummy ' ' entry occupies it so
    that ``len(vocab)`` counts the padding id. Real words are numbered
    from 1 upward in order of first appearance.

    Parameters
    ----------
    path : str
        Path of the text file, read via ``default_read_content``.

    Returns
    -------
    dict
        Mapping from word to integer id.
    """
    content = default_read_content(path)
    the_vocab = {' ': 0}  # dummy element so that len(vocab) is correct
    idx = 1  # 0 is left for zero-padding
    for word in content.split(' '):
        # skip empty tokens from repeated separators; `word not in` is
        # the idiomatic membership test (was `not word in`)
        if word and word not in the_vocab:
            the_vocab[word] = idx
            idx += 1
    return the_vocab
| |
def default_text2id(sentence, the_vocab):
    """Map a space-separated *sentence* to a list of vocabulary ids.

    Empty tokens (produced by repeated spaces) are dropped; every
    remaining token must already be a key of *the_vocab*.
    """
    return [the_vocab[token] for token in sentence.split(' ') if token]
| |
def default_gen_buckets(sentences, batch_size, the_vocab):
    """Heuristically choose bucket lengths for *sentences*.

    Counts how many sentences have each token length, then greedily
    accumulates counts until at least *batch_size* sentences would fall
    into a bucket; a final bucket of the maximum observed length
    catches any remainder.
    """
    len_dict = {}
    max_len = -1
    for sentence in sentences:
        ids = default_text2id(sentence, the_vocab)
        if not ids:
            continue
        n_words = len(ids)
        max_len = max(max_len, n_words)
        len_dict[n_words] = len_dict.get(n_words, 0) + 1
    print(len_dict)

    buckets = []
    pending = 0
    for length, count in len_dict.items():  # TODO: There are better heuristic ways to do this
        if count + pending >= batch_size:
            buckets.append(length)
            pending = 0
        else:
            pending += count
    if pending > 0:
        buckets.append(max_len)
    return buckets
| |
class SimpleBatch(object):
    """A minimal DataBatch carrying arrays plus their bucket key.

    Exposes ``provide_data`` / ``provide_label`` descriptors built from
    the stored names, arrays and layouts, as the bucketing executor
    expects of a batch object.
    """
    def __init__(self, data_names, data, data_layouts, label_names, label, label_layouts, bucket_key):
        self.data_names = data_names
        self.data = data
        self.data_layouts = data_layouts
        self.label_names = label_names
        self.label = label
        self.label_layouts = label_layouts
        self.bucket_key = bucket_key
        self.pad = 0
        self.index = None  # TODO: what is index?

    @property
    def provide_data(self):
        """DataDesc list describing each data array of this batch."""
        triples = zip(self.data_names, self.data, self.data_layouts)
        return [mx.io.DataDesc(name, arr.shape, layout=layout)
                for name, arr, layout in triples]

    @property
    def provide_label(self):
        """DataDesc list describing each label array of this batch."""
        triples = zip(self.label_names, self.label, self.label_layouts)
        return [mx.io.DataDesc(name, arr.shape, layout=layout)
                for name, arr, layout in triples]
| |
class DummyIter(mx.io.DataIter):
    """A dummy iterator that always returns the same batch.

    Wraps *real_iter*, caches its first batch at construction time, and
    replays that batch forever — useful for speed testing without data
    loading overhead.
    """
    def __init__(self, real_iter):
        super(DummyIter, self).__init__()
        self.real_iter = real_iter
        self.provide_data = real_iter.provide_data
        self.provide_label = real_iter.provide_label
        self.batch_size = real_iter.batch_size

        # grab and keep only the first batch from the real iterator
        for first_batch in real_iter:
            self.the_batch = first_batch
            break

    def __iter__(self):
        return self

    def next(self):
        # always replay the cached batch; never raises StopIteration
        return self.the_batch
| |
class BucketSentenceIter(mx.io.DataIter):
    """Data iterator that groups sentences of similar length into buckets.

    Sentences read from *path* are converted to id sequences with
    *text2id*, assigned to the smallest bucket that fits them (sentences
    longer than the largest bucket are dropped), zero-padded to the
    bucket length and served in shuffled batches of *batch_size*.

    Parameters
    ----------
    path : str
        Text file to read; split into sentences on *seperate_char*.
    vocab : dict
        Word -> id mapping, e.g. from ``default_build_vocab``.
    buckets : list of int
        Candidate bucket lengths; if empty, generated by
        ``default_gen_buckets``.
    batch_size : int
        Number of sentences per batch.
    init_states : list
        Extra (name, shape) inputs for the recurrent initial states.
        NOTE(review): ``__iter__`` reads ``.layout`` from every entry of
        ``provide_data``, so these should be ``mx.io.DataDesc``
        instances rather than plain tuples — confirm with callers.
    data_name, label_name : str
        Stored but currently unused; 'data' / 'softmax_label' are
        hard-coded in provide_data/provide_label and in ``__iter__``.
    seperate_char : str
        Sentence separator (sic; name kept for interface compatibility).
    text2id, read_content : callable or None
        Overrides for the default conversion / reading helpers.
    time_major : bool
        If True, batches are laid out (time, batch) instead of
        (batch, time).
    """
    def __init__(self, path, vocab, buckets, batch_size,
                 init_states, data_name='data', label_name='label',
                 seperate_char=' <eos> ', text2id=None, read_content=None,
                 time_major=True):
        super(BucketSentenceIter, self).__init__()

        if text2id is None:
            self.text2id = default_text2id
        else:
            self.text2id = text2id
        if read_content is None:
            self.read_content = default_read_content
        else:
            self.read_content = read_content
        content = self.read_content(path)
        sentences = content.split(seperate_char)

        if len(buckets) == 0:
            buckets = default_gen_buckets(sentences, batch_size, vocab)

        self.vocab_size = len(vocab)
        self.data_name = data_name
        self.label_name = label_name
        self.time_major = time_major

        buckets.sort()
        self.buckets = buckets
        self.data = [[] for _ in buckets]

        # pre-allocate with the largest bucket for better memory sharing
        self.default_bucket_key = max(buckets)

        for sentence in sentences:
            sentence = self.text2id(sentence, vocab)
            if len(sentence) == 0:
                continue
            for i, bkt in enumerate(buckets):
                if bkt >= len(sentence):
                    self.data[i].append(sentence)
                    break
            # we just ignore the sentence if it is longer than the
            # maximum bucket size here

        # convert data into ndarrays for better speed during training;
        # rows are zero-padded out to the bucket length
        data = [np.zeros((len(x), buckets[i])) for i, x in enumerate(self.data)]
        for i_bucket in range(len(self.buckets)):
            for j in range(len(self.data[i_bucket])):
                sentence = self.data[i_bucket][j]
                data[i_bucket][j, :len(sentence)] = sentence
        self.data = data

        # Get the size of each bucket, so that we could sample
        # uniformly from the bucket
        bucket_sizes = [len(x) for x in self.data]

        print("Summary of dataset ==================")
        for bkt, size in zip(buckets, bucket_sizes):
            print("bucket of len %3d : %d samples" % (bkt, size))

        self.batch_size = batch_size
        self.make_data_iter_plan()

        self.init_states = init_states
        self.init_state_arrays = [mx.nd.zeros(x[1]) for x in init_states]

        if self.time_major:
            self.provide_data = [mx.io.DataDesc('data', (self.default_bucket_key, batch_size), layout='TN')] + init_states
            self.provide_label = [mx.io.DataDesc('softmax_label', (self.default_bucket_key, batch_size), layout='TN')]
        else:
            # BUGFIX: these were plain (name, shape) tuples, but __iter__
            # reads ``.layout`` from every provide_data/provide_label
            # entry, so the batch-major branch crashed with
            # AttributeError. Use DataDesc (a tuple subclass, hence
            # backward-compatible) with the batch-major 'NT' layout.
            self.provide_data = [mx.io.DataDesc('data', (batch_size, self.default_bucket_key), layout='NT')] + init_states
            self.provide_label = [mx.io.DataDesc('softmax_label', (batch_size, self.default_bucket_key), layout='NT')]

    def make_data_iter_plan(self):
        """Make a random data iteration plan.

        Truncates each bucket to a multiple of batch_size, builds a
        shuffled sequence of bucket indices (one entry per batch), a
        random within-bucket sample permutation, and reusable data/label
        buffers shaped for each bucket.
        """
        # truncate each bucket into multiple of batch-size
        bucket_n_batches = []
        for i in range(len(self.data)):
            # BUGFIX: floor division — under Python 3 `/` yields a float,
            # which breaks np.zeros(n, int) below.
            bucket_n_batches.append(len(self.data[i]) // self.batch_size)
            self.data[i] = self.data[i][:int(bucket_n_batches[i]*self.batch_size)]

        bucket_plan = np.hstack([np.zeros(n, int)+i for i, n in enumerate(bucket_n_batches)])
        np.random.shuffle(bucket_plan)

        bucket_idx_all = [np.random.permutation(len(x)) for x in self.data]

        self.bucket_plan = bucket_plan
        self.bucket_idx_all = bucket_idx_all
        self.bucket_curr_idx = [0 for x in self.data]

        self.data_buffer = []
        self.label_buffer = []
        for i_bucket in range(len(self.data)):
            if self.time_major:
                data = np.zeros((self.buckets[i_bucket], self.batch_size))
                label = np.zeros((self.buckets[i_bucket], self.batch_size))
            else:
                data = np.zeros((self.batch_size, self.buckets[i_bucket]))
                label = np.zeros((self.batch_size, self.buckets[i_bucket]))

            self.data_buffer.append(data)
            self.label_buffer.append(label)

    def __iter__(self):
        """Yield one SimpleBatch per planned batch.

        The label is the data shifted one step along the time axis with
        the final step zero-filled (next-word prediction target).
        """
        for i_bucket in self.bucket_plan:
            data = self.data_buffer[i_bucket]
            i_idx = self.bucket_curr_idx[i_bucket]
            idx = self.bucket_idx_all[i_bucket][i_idx:i_idx+self.batch_size]
            self.bucket_curr_idx[i_bucket] += self.batch_size

            init_state_names = [x[0] for x in self.init_states]

            if self.time_major:
                data[:] = self.data[i_bucket][idx].T
            else:
                data[:] = self.data[i_bucket][idx]

            label = self.label_buffer[i_bucket]
            if self.time_major:
                label[:-1, :] = data[1:, :]
                label[-1, :] = 0
            else:
                label[:, :-1] = data[:, 1:]
                label[:, -1] = 0

            data_all = [mx.nd.array(data)] + self.init_state_arrays
            label_all = [mx.nd.array(label)]
            data_names = ['data'] + init_state_names
            label_names = ['softmax_label']

            data_batch = SimpleBatch(data_names, data_all, [x.layout for x in self.provide_data],
                                     label_names, label_all, [x.layout for x in self.provide_label],
                                     self.buckets[i_bucket])
            yield data_batch

    def reset(self):
        """Rewind all buckets so a new epoch can start."""
        self.bucket_curr_idx = [0 for x in self.data]