Merge pull request #1228 from npcmaci/dev-postgresql

create the application folder for the healthcare model zoo
diff --git a/examples/healthcare/data/malaria.py b/examples/healthcare/data/malaria.py
new file mode 100644
index 0000000..46422b7
--- /dev/null
+++ b/examples/healthcare/data/malaria.py
@@ -0,0 +1,122 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+try:
+    import cPickle as pickle  # fast pickle on Python 2
+except ImportError:
+    import pickle
+
+import numpy as np
+import os
+import sys
+from PIL import Image
+
+
+# The dataset must already be downloaded and extracted under dir_path.
+def load_train_data(dir_path="/tmp/malaria", resize_size=(128, 128)):
+    dir_path = check_dataset_exist(dirpath=dir_path)
+    path_train_label_1 = os.path.join(dir_path, "training_set/Parasitized")
+    path_train_label_0 = os.path.join(dir_path, "training_set/Uninfected")
+    train_label_1 = load_image_path(os.listdir(path_train_label_1))
+    train_label_0 = load_image_path(os.listdir(path_train_label_0))
+    labels = []
+    Images = np.empty((len(train_label_1) + len(train_label_0),
+                       3, resize_size[0], resize_size[1]), dtype=np.uint8)
+    for i in range(len(train_label_0)):
+        image_path = os.path.join(path_train_label_0, train_label_0[i])
+        temp_image = np.array(Image.open(image_path).resize(
+            resize_size).convert("RGB")).transpose(2, 0, 1)
+        Images[i] = temp_image
+        labels.append(0)
+    for i in range(len(train_label_1)):
+        image_path = os.path.join(path_train_label_1, train_label_1[i])
+        temp_image = np.array(Image.open(image_path).resize(
+            resize_size).convert("RGB")).transpose(2, 0, 1)
+        Images[i + len(train_label_0)] = temp_image
+        labels.append(1)
+
+    Images = np.array(Images, dtype=np.float32)
+    labels = np.array(labels, dtype=np.int32)
+    return Images, labels
+
+
+# The dataset must already be downloaded and extracted under dir_path.
+def load_test_data(dir_path='/tmp/malaria', resize_size=(128, 128)):
+    dir_path = check_dataset_exist(dirpath=dir_path)
+    path_test_label_1 = os.path.join(dir_path, "testing_set/Parasitized")
+    path_test_label_0 = os.path.join(dir_path, "testing_set/Uninfected")
+    test_label_1 = load_image_path(os.listdir(path_test_label_1))
+    test_label_0 = load_image_path(os.listdir(path_test_label_0))
+    labels = []
+    Images = np.empty((len(test_label_1) + len(test_label_0),
+                       3, resize_size[0], resize_size[1]), dtype=np.uint8)
+    for i in range(len(test_label_0)):
+        image_path = os.path.join(path_test_label_0, test_label_0[i])
+        temp_image = np.array(Image.open(image_path).resize(
+            resize_size).convert("RGB")).transpose(2, 0, 1)
+        Images[i] = temp_image
+        labels.append(0)
+    for i in range(len(test_label_1)):
+        image_path = os.path.join(path_test_label_1, test_label_1[i])
+        temp_image = np.array(Image.open(image_path).resize(
+            resize_size).convert("RGB")).transpose(2, 0, 1)
+        Images[i + len(test_label_0)] = temp_image
+        labels.append(1)
+
+    Images = np.array(Images, dtype=np.float32)
+    labels = np.array(labels, dtype=np.int32)
+    return Images, labels
+
+
+def load_image_path(image_paths):
+    new_list = []
+    for image_path in image_paths:
+        if image_path.endswith((".png", ".jpg")):
+            new_list.append(image_path)
+    return new_list
+
+
+def check_dataset_exist(dirpath):
+    if not os.path.exists(dirpath):
+        print(
+            'Please download the malaria dataset and extract it to %s first' % dirpath
+        )
+        sys.exit(1)
+    return dirpath
+
+
+def normalize(train_x, val_x):
+    mean = [0.5339, 0.4180, 0.4460]  # mean for malaria dataset
+    std = [0.3329, 0.2637, 0.2761]  # std for malaria dataset
+    train_x /= 255
+    val_x /= 255
+    for ch in range(3):  # normalize all three RGB channels
+        train_x[:, ch, :, :] -= mean[ch]
+        train_x[:, ch, :, :] /= std[ch]
+        val_x[:, ch, :, :] -= mean[ch]
+        val_x[:, ch, :, :] /= std[ch]
+    return train_x, val_x
+
+
+def load(dir_path):
+    train_x, train_y = load_train_data(dir_path=dir_path)
+    val_x, val_y = load_test_data(dir_path=dir_path)
+    train_x, val_x = normalize(train_x, val_x)
+    train_y = train_y.flatten()
+    val_y = val_y.flatten()
+    return train_x, train_y, val_x, val_y
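
A minimal usage sketch for this loader (not part of the diff): it assumes the malaria cell-image dataset has already been downloaded and extracted to /tmp/malaria with the training_set/{Parasitized,Uninfected} and testing_set/{Parasitized,Uninfected} sub-folders that the functions above expect, and that the module is importable as `malaria`.

    import malaria

    # Returns normalized float32 images of shape (N, 3, 128, 128) and
    # flattened int32 labels (0 = Uninfected, 1 = Parasitized).
    train_x, train_y, val_x, val_y = malaria.load("/tmp/malaria")
    print(train_x.shape, train_x.dtype)   # (N, 3, 128, 128) float32
    print(train_y[:10])                   # e.g. [0 0 0 ... 1 1]
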
diff --git a/examples/healthcare/models/malaria_net.py b/examples/healthcare/models/malaria_net.py
new file mode 100644
index 0000000..2a10a70
--- /dev/null
+++ b/examples/healthcare/models/malaria_net.py
@@ -0,0 +1,146 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from singa import layer
+from singa import model
+from singa import tensor
+from singa import opt
+from singa import device
+
+import numpy as np
+
+np_dtype = {"float16": np.float16, "float32": np.float32}
+
+singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}
+
+class CNN(model.Model):
+
+    def __init__(self, num_classes=10, num_channels=1):
+        super(CNN, self).__init__()
+        self.num_classes = num_classes
+        self.input_size = 128
+        self.dimension = 4
+        self.conv1 = layer.Conv2d(num_channels, 32, 3, padding=0, activation="RELU")
+        self.conv2 = layer.Conv2d(32, 64, 3, padding=0, activation="RELU")
+        self.conv3 = layer.Conv2d(64, 64, 3, padding=0, activation="RELU")
+        self.linear1 = layer.Linear(128)
+        self.linear2 = layer.Linear(num_classes)
+        self.pooling1 = layer.MaxPool2d(2, 2, padding=0)
+        self.pooling2 = layer.MaxPool2d(2, 2, padding=0)
+        self.pooling3 = layer.MaxPool2d(2, 2, padding=0)
+        self.relu = layer.ReLU()
+        self.flatten = layer.Flatten()
+        self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
+
+    def forward(self, x):
+        y = self.conv1(x)
+        y = self.pooling1(y)
+        y = self.conv2(y)
+        y = self.pooling2(y)
+        y = self.conv3(y)
+        y = self.pooling3(y)
+        y = self.flatten(y)
+        y = self.linear1(y)
+        y = self.relu(y)
+        y = self.linear2(y)
+        return y
+
+    def train_one_batch(self, x, y, dist_option, spars):
+        out = self.forward(x)
+        loss = self.softmax_cross_entropy(out, y)
+
+        if dist_option == 'plain':
+            self.optimizer(loss)
+        elif dist_option == 'half':
+            self.optimizer.backward_and_update_half(loss)
+        elif dist_option == 'partialUpdate':
+            self.optimizer.backward_and_partial_update(loss)
+        elif dist_option == 'sparseTopK':
+            self.optimizer.backward_and_sparse_update(loss,
+                                                      topK=True,
+                                                      spars=spars)
+        elif dist_option == 'sparseThreshold':
+            self.optimizer.backward_and_sparse_update(loss,
+                                                      topK=False,
+                                                      spars=spars)
+        return out, loss
+
+    def set_optimizer(self, optimizer):
+        self.optimizer = optimizer
+
+
+class MLP(model.Model):
+
+    def __init__(self, perceptron_size=100, num_classes=10):
+        super(MLP, self).__init__()
+        self.num_classes = num_classes
+        self.dimension = 2
+
+        self.relu = layer.ReLU()
+        self.linear1 = layer.Linear(perceptron_size)
+        self.linear2 = layer.Linear(num_classes)
+        self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
+
+    def forward(self, inputs):
+        y = self.linear1(inputs)
+        y = self.relu(y)
+        y = self.linear2(y)
+        return y
+
+    def train_one_batch(self, x, y, dist_option, spars):
+        out = self.forward(x)
+        loss = self.softmax_cross_entropy(out, y)
+
+        if dist_option == 'plain':
+            self.optimizer(loss)
+        elif dist_option == 'half':
+            self.optimizer.backward_and_update_half(loss)
+        elif dist_option == 'partialUpdate':
+            self.optimizer.backward_and_partial_update(loss)
+        elif dist_option == 'sparseTopK':
+            self.optimizer.backward_and_sparse_update(loss,
+                                                      topK=True,
+                                                      spars=spars)
+        elif dist_option == 'sparseThreshold':
+            self.optimizer.backward_and_sparse_update(loss,
+                                                      topK=False,
+                                                      spars=spars)
+        return out, loss
+
+    def set_optimizer(self, optimizer):
+        self.optimizer = optimizer
+
+
+def create_model(model_option='cnn', **kwargs):
+    """Constructs a CNN (default) or MLP model.
+
+    Args:
+        model_option (str): 'cnn' or 'mlp'.
+        **kwargs: forwarded to the chosen model's constructor.
+
+    Returns:
+        The created model.
+    """
+    if model_option == 'mlp':
+        model = MLP(**kwargs)
+    else:
+        model = CNN(**kwargs)
+
+    return model
+
+
+__all__ = ['CNN', 'MLP', 'create_model']
\ No newline at end of file
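
A minimal training sketch for the model above (not part of this PR). The SGD hyper-parameters, batch size, and random placeholder tensors below are assumptions; a real training script would feed batches produced by examples/healthcare/data/malaria.py and loop over epochs.

    import numpy as np
    from singa import device, opt, tensor

    from malaria_net import create_model  # assumes this module is on the path

    dev = device.get_default_device()      # or device.create_cuda_gpu()
    model = create_model('cnn', num_classes=2, num_channels=3)
    model.set_optimizer(opt.SGD(lr=0.005, momentum=0.9))

    # Placeholder batch: 16 RGB images of 128x128 and their integer labels.
    x = tensor.Tensor((16, 3, 128, 128), dev)
    y = tensor.Tensor((16,), dev, tensor.int32)
    x.gaussian(0.0, 1.0)
    y.copy_from_numpy(np.random.randint(0, 2, 16).astype(np.int32))

    model.compile([x], is_train=True, use_graph=False, sequential=False)
    model.train()
    out, loss = model(x, y, 'plain', None)  # dist_option='plain', no sparsification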