Merge pull request #1204 from Zrealshadow/dev-postgresql-patch-2

diff --git a/examples/trans/data.py b/examples/trans/data.py
new file mode 100644
index 0000000..8be5157
--- /dev/null
+++ b/examples/trans/data.py
@@ -0,0 +1,181 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import re
+import numpy as np
+from collections import Counter
+
+
+class Vocab:
+    """
+    The class of vocab, include 2 dicts of token to index and index to token
+    """
+    def __init__(self, sentences):
+        """
+        Args:
+            sentences: a 2-D list of tokens; each inner list is one tokenized sentence
+        """
+        self.sentences = sentences
+        # Reserve indices 0-3 for the special tokens, then assign the remaining
+        # indices by descending corpus frequency.
+        self.token2index = {'<pad>': 0, '<bos>': 1, '<eos>': 2, '<unk>': 3}
+        counter = Counter(token for sentence in sentences for token in sentence)
+        self.token2index.update({
+            token: index + 4
+            for index, (token, _freq) in enumerate(counter.most_common())
+        })
+        self.index2token = {index: token for token, index in self.token2index.items()}
+
+    def __getitem__(self, query):
+        """Look up an index by token, a token by index, or map over a sequence of either."""
+        if isinstance(query, str):
+            return self.token2index.get(query, self.token2index['<unk>'])
+        elif isinstance(query, (int, np.integer)):
+            return self.index2token.get(query, '<unk>')
+        elif isinstance(query, (list, tuple, np.ndarray)):
+            return [self[item] for item in query]
+        else:
+            raise ValueError(f"The type of query is invalid: {type(query)}")
+
+    def __len__(self):
+        return len(self.index2token)
+
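+# A minimal usage sketch of Vocab with hypothetical tokens (not from the dataset):
+#
+#   vocab = Vocab([['hi', 'there'], ['hi', 'again']])
+#   vocab['hi']           # -> 4: 'hi' is the most frequent token; 0-3 are the specials
+#   vocab[['hi', 'xyz']]  # -> [4, 3]: unseen tokens fall back to '<unk>' (index 3)
+#   vocab[4]              # -> 'hi'
+#   len(vocab)            # -> 7: 4 special tokens + 3 distinct corpus tokens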
+
+class CmnDataset:
+    def __init__(self, path='cmn-eng/cmn.txt', shuffle=False, batch_size=32, train_ratio=0.8, random_seed=0):
+        """
+        cmn dataset, download from https://www.manythings.org/anki/, contains 29909 Chinese and English translation
+        pairs, the pair format: English + TAB + Chinese + TAB + Attribution
+        Args:
+            path: the path of the dataset, default 'cmn-eng/cnn.txt'
+            shuffle: shuffle the dataset, default False
+            batch_size: the size of every batch, default 32
+            train_ratio: the proportion of the training set to the total data set, default 0.8
+            random_seed: the random seed, used for shuffle operation, default 0
+        """
+        src_max_len, tgt_max_len, src_sts, tgt_sts = CmnDataset._split_sentences(path)
+        # Build source (English) and target (Chinese) vocabularies from the tokenized sentences.
+        en_vab, cn_vab = Vocab(src_sts), Vocab(tgt_sts)
+        src_np, tgt_in_np, tgt_out_np = CmnDataset._encoding_stc(src_sts, tgt_sts, src_max_len, tgt_max_len,
+                                                                 en_vab, cn_vab)
+
+        self.src_max_len, self.tgt_max_len = src_max_len, tgt_max_len
+        self.en_vab, self.cn_vab = en_vab, cn_vab
+        self.en_vab_size, self.cn_vab_size = len(en_vab), len(cn_vab)
+
+        self.src_inputs, self.tgt_inputs, self.tgt_outputs = src_np, tgt_in_np, tgt_out_np
+
+        self.shuffle, self.random_seed = shuffle, random_seed
+
+        assert batch_size > 0, "batch_size must be greater than 0"
+        self.batch_size = batch_size
+
+        assert (0 < train_ratio <= 1.0), "train_ratio must be in (0.0, 1.0]"
+        self.train_ratio = train_ratio
+
+        self.total_size = len(src_np)
+        self.train_size = int(self.total_size * train_ratio)
+        self.test_size = self.total_size - self.train_size
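+        # e.g. with the full 29909-pair file and train_ratio=0.8:
+        #   train_size = int(29909 * 0.8) = 23927, test_size = 29909 - 23927 = 5982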
+
+        if shuffle:
+            # Apply one shared permutation so the three arrays stay row-aligned.
+            np.random.seed(self.random_seed)
+            index = np.random.permutation(self.total_size)
+
+            self.src_inputs = src_np[index]
+            self.tgt_inputs = tgt_in_np[index]
+            self.tgt_outputs = tgt_out_np[index]
+
+        self.train_src_inputs, self.test_src_inputs = self.src_inputs[:self.train_size], self.src_inputs[self.train_size:]
+        self.train_tgt_inputs, self.test_tgt_inputs = self.tgt_inputs[:self.train_size], self.tgt_inputs[self.train_size:]
+        self.train_tgt_outputs, self.test_tgt_outputs = self.tgt_outputs[:self.train_size], self.tgt_outputs[self.train_size:]
+
+    @staticmethod
+    def _split_sentences(path):
+        """Read the raw file; return (en_max_len, cn_max_len, en_sentences, cn_sentences)."""
+        en_max_len, cn_max_len = 0, 0
+        en_sts, cn_sts = [], []
+        with open(path, 'r', encoding='utf-8') as f:
+            for line in f:
+                line_split = line.split('\t')
+                # Keep word characters, whitespace, apostrophes and hyphens in the
+                # English text, then lowercase it, e.g. "Let's go!" -> "let's go".
+                line_split[0] = re.sub(r'[^\w\s\'-]', '', line_split[0])
+                line_split[0] = line_split[0].lower()
+                # [\u4e00-\u9fa5] matches Chinese characters; everything else is dropped.
+                line_split[1] = re.sub("[^\u4e00-\u9fa5]", "", line_split[1])
+
+                en_stc = line_split[0].split()
+                cn_stc = list(line_split[1])  # character-level tokens for Chinese
+                en_sts.append(en_stc)
+                cn_sts.append(cn_stc)
+                en_max_len = max(en_max_len, len(en_stc))
+                cn_max_len = max(cn_max_len, len(cn_stc))
+        return en_max_len, cn_max_len, en_sts, cn_sts
+
+    @staticmethod
+    def _encoding_stc(src_tokens, tgt_tokens, src_max_len, tgt_max_len, src_vocab, tgt_vocab):
+        """Map tokens to indices and pad every sequence to a fixed length of max_len + 1."""
+        src_list = []
+        for line in src_tokens:
+            if len(line) > src_max_len:
+                line = line[:src_max_len]
+            # Pad the source to src_max_len + 1 so its length matches the targets below.
+            lst = src_vocab[line + ['<pad>'] * (src_max_len + 1 - len(line))]
+            src_list.append(lst)
+        tgt_in_list, tgt_out_list = [], []
+        for line in tgt_tokens:
+            if len(line) > tgt_max_len:
+                line = line[:tgt_max_len]
+            # Decoder input is prefixed with <bos>, decoder output is suffixed with
+            # <eos>; both end up tgt_max_len + 1 tokens long (teacher forcing).
+            in_lst = tgt_vocab[['<bos>'] + line + ['<pad>'] * (tgt_max_len - len(line))]
+            out_lst = tgt_vocab[line + ['<eos>'] + ['<pad>'] * (tgt_max_len - len(line))]
+            tgt_in_list.append(in_lst)
+            tgt_out_list.append(out_lst)
+        src_np = np.asarray(src_list, dtype=np.int32)
+        tgt_in_np = np.asarray(tgt_in_list, dtype=np.int32)
+        tgt_out_np = np.asarray(tgt_out_list, dtype=np.int32)
+        return src_np, tgt_in_np, tgt_out_np
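+    # A worked example of the encoding above, assuming tgt_max_len = 4 for illustration:
+    #   target tokens:  ['你', '好']
+    #   decoder input:  ['<bos>', '你', '好', '<pad>', '<pad>']   (tgt_max_len + 1 tokens)
+    #   decoder output: ['你', '好', '<eos>', '<pad>', '<pad>']   (same tokens shifted by one)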
+
+    def get_batch_data(self, batch, mode='train'):
+        """Return (src_inputs, tgt_inputs, tgt_outputs) for the given batch index."""
+        assert mode in ('train', 'test'), "The mode must be 'train' or 'test'."
+        if mode == 'train':
+            total_size = self.train_size
+            src, tgt_in, tgt_out = self.train_src_inputs, self.train_tgt_inputs, self.train_tgt_outputs
+        else:
+            total_size = self.test_size
+            src, tgt_in, tgt_out = self.test_src_inputs, self.test_tgt_inputs, self.test_tgt_outputs
+
+        max_batch = total_size // self.batch_size
+        if total_size % self.batch_size > 0:
+            max_batch += 1
+        assert batch < max_batch, "The batch number is out of bounds."
+
+        low = batch * self.batch_size
+        high = min((batch + 1) * self.batch_size, total_size)
+        if high - low != self.batch_size:
+            # The tail batch is short; wrap around to the start of the split so
+            # every returned batch holds exactly batch_size rows.
+            fill = self.batch_size - (high - low)
+            return (np.concatenate((src[low:high], src[:fill]), axis=0),
+                    np.concatenate((tgt_in[low:high], tgt_in[:fill]), axis=0),
+                    np.concatenate((tgt_out[low:high], tgt_out[:fill]), axis=0))
+        return src[low:high], tgt_in[low:high], tgt_out[low:high]
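+    # A worked example of the batch bookkeeping, assuming train_size = 100 and
+    # batch_size = 32: max_batch = 4; batch 3 covers rows [96:100) and wraps
+    # around with rows [0:28) so the returned batch still holds 32 samples.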
+
+    def __len__(self):
+        return self.src_inputs.shape[0]
+
+    def __getitem__(self, idx):
+        return self.src_inputs[idx], self.tgt_inputs[idx], self.tgt_outputs[idx]
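+
+
+if __name__ == '__main__':
+    # A minimal usage sketch, assuming cmn.txt has been downloaded from
+    # https://www.manythings.org/anki/ to the default path 'cmn-eng/cmn.txt'.
+    dataset = CmnDataset(shuffle=True, batch_size=32, train_ratio=0.8)
+    print("total pairs:", len(dataset))
+    print("en vocab size:", dataset.en_vab_size, "cn vocab size:", dataset.cn_vab_size)
+    src, tgt_in, tgt_out = dataset.get_batch_data(0, mode='train')
+    # Each batch is a trio of int32 arrays with batch_size rows.
+    print(src.shape, tgt_in.shape, tgt_out.shape)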