Merge pull request #1173 from apache/dev-postgresql
Merge Dev
diff --git a/examples/cnn_ms/README.md b/examples/cnn_ms/README.md
new file mode 100644
index 0000000..177ae4d
--- /dev/null
+++ b/examples/cnn_ms/README.md
@@ -0,0 +1,44 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+# Image Classification using Convolutional Neural Networks
+
+Examples inside this folder show how to train CNN models using
+SINGA for image classification.
+
+* `data` includes the scripts for preprocessing image datasets.
+  Currently, MNIST, CIFAR10 and CIFAR100 are supported.
+
+* `model` includes the CNN model construction code; each model is defined
+  as a subclass of `Module` that wraps its neural network operations.
+  The computational graph can then be enabled to optimize memory usage
+  and execution efficiency.
+
+* `autograd` includes the code to train CNN models by calling the
+  [neural network operations](../../python/singa/autograd.py) imperatively;
+  no computational graph is created.
+
+* `train_cnn.py` is the training script, which controls the training flow by
+  running back-propagation and applying SGD updates.
+
+* `train_multiprocess.py` is the script for distributed training on a single
+ node with multiple GPUs; it uses Python's multiprocessing module and NCCL.
+
+* `train_mpi.py` is the script for distributed training (among multiple nodes)
+ using MPI and NCCL for communication.
+
+* `benchmark.py` tests the training throughput using `ResNet50` as the workload.
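+
+As a quick sanity check, the throughput benchmark can also be driven from
+Python (a minimal sketch: it assumes a CUDA GPU is available and that the
+interpreter is started from this directory so that `benchmark.py` and the
+`model` package are importable):
+
+```python
+from benchmark import train_resnet
+
+# single-process run with the computational graph enabled
+train_resnet(DIST=False, graph=True, sequential=False, verbosity=0)
+```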
\ No newline at end of file
diff --git a/examples/cnn_ms/autograd/cifar10_multiprocess.py b/examples/cnn_ms/autograd/cifar10_multiprocess.py
new file mode 100644
index 0000000..815d011
--- /dev/null
+++ b/examples/cnn_ms/autograd/cifar10_multiprocess.py
@@ -0,0 +1,43 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from resnet_cifar10 import *
+import multiprocessing
+import sys
+
+if __name__ == '__main__':
+
+    # Generate an NCCL ID to be used for collective communication
+ nccl_id = singa.NcclIdHolder()
+
+ # Configure the number of GPUs to be used
+ world_size = int(sys.argv[1])
+
+    # Enable the experimental partial-parameter-update asynchronous training
+ partial_update = True
+
+ process = []
+ for local_rank in range(0, world_size):
+ process.append(
+ multiprocessing.Process(target=train_cifar10,
+ args=(True, local_rank, world_size, nccl_id,
+ partial_update)))
+
+    for p in process:
+        p.start()
+
+    # wait for all worker processes to finish
+    for p in process:
+        p.join()
diff --git a/examples/cnn_ms/autograd/xceptionnet.py b/examples/cnn_ms/autograd/xceptionnet.py
new file mode 100644
index 0000000..8fb23d8
--- /dev/null
+++ b/examples/cnn_ms/autograd/xceptionnet.py
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+from singa import autograd
+from singa import tensor
+from singa import device
+from singa import layer
+from singa import opt
+
+import numpy as np
+from tqdm import trange
+
+# the code is modified from
+# https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/xception.py
+
+
+class Block(layer.Layer):
+
+ def __init__(self,
+ in_filters,
+ out_filters,
+ reps,
+ strides=1,
+ padding=0,
+ start_with_relu=True,
+ grow_first=True):
+ super(Block, self).__init__()
+
+ if out_filters != in_filters or strides != 1:
+ self.skip = layer.Conv2d(in_filters,
+ out_filters,
+ 1,
+ stride=strides,
+ padding=padding,
+ bias=False)
+ self.skipbn = layer.BatchNorm2d(out_filters)
+ else:
+ self.skip = None
+
+ self.layers = []
+
+ filters = in_filters
+ if grow_first:
+ self.layers.append(layer.ReLU())
+ self.layers.append(
+ layer.SeparableConv2d(in_filters,
+ out_filters,
+ 3,
+ stride=1,
+ padding=1,
+ bias=False))
+ self.layers.append(layer.BatchNorm2d(out_filters))
+ filters = out_filters
+
+ for i in range(reps - 1):
+ self.layers.append(layer.ReLU())
+ self.layers.append(
+ layer.SeparableConv2d(filters,
+ filters,
+ 3,
+ stride=1,
+ padding=1,
+ bias=False))
+ self.layers.append(layer.BatchNorm2d(filters))
+
+ if not grow_first:
+ self.layers.append(layer.ReLU())
+ self.layers.append(
+ layer.SeparableConv2d(in_filters,
+ out_filters,
+ 3,
+ stride=1,
+ padding=1,
+ bias=False))
+ self.layers.append(layer.BatchNorm2d(out_filters))
+
+ if not start_with_relu:
+ self.layers = self.layers[1:]
+ else:
+ self.layers[0] = layer.ReLU()
+
+ if strides != 1:
+ self.layers.append(layer.MaxPool2d(3, strides, padding + 1))
+
+ self.register_layers(*self.layers)
+
+ self.add = layer.Add()
+
+ def forward(self, x):
+ y = self.layers[0](x)
+        # avoid shadowing the imported `layer` module inside this loop
+        for lyr in self.layers[1:]:
+            if isinstance(y, tuple):
+                y = y[0]
+            y = lyr(y)
+
+ if self.skip is not None:
+ skip = self.skip(x)
+ skip = self.skipbn(skip)
+ else:
+ skip = x
+ y = self.add(y, skip)
+ return y
+
+
+__all__ = ['Xception']
+
+
+class Xception(layer.Layer):
+ """
+ Xception optimized for the ImageNet dataset, as specified in
+ https://arxiv.org/pdf/1610.02357.pdf
+ """
+
+ def __init__(self, num_classes=1000):
+ """ Constructor
+ Args:
+ num_classes: number of classes
+ """
+ super(Xception, self).__init__()
+ self.num_classes = num_classes
+
+ self.conv1 = layer.Conv2d(3, 32, 3, 2, 0, bias=False)
+ self.bn1 = layer.BatchNorm2d(32)
+ self.relu1 = layer.ReLU()
+
+ self.conv2 = layer.Conv2d(32, 64, 3, 1, 1, bias=False)
+ self.bn2 = layer.BatchNorm2d(64)
+ self.relu2 = layer.ReLU()
+ # do relu here
+
+ self.block1 = Block(64,
+ 128,
+ 2,
+ 2,
+ padding=0,
+ start_with_relu=False,
+ grow_first=True)
+ self.block2 = Block(128,
+ 256,
+ 2,
+ 2,
+ padding=0,
+ start_with_relu=True,
+ grow_first=True)
+ self.block3 = Block(256,
+ 728,
+ 2,
+ 2,
+ padding=0,
+ start_with_relu=True,
+ grow_first=True)
+
+ self.block4 = Block(728,
+ 728,
+ 3,
+ 1,
+ start_with_relu=True,
+ grow_first=True)
+ self.block5 = Block(728,
+ 728,
+ 3,
+ 1,
+ start_with_relu=True,
+ grow_first=True)
+ self.block6 = Block(728,
+ 728,
+ 3,
+ 1,
+ start_with_relu=True,
+ grow_first=True)
+ self.block7 = Block(728,
+ 728,
+ 3,
+ 1,
+ start_with_relu=True,
+ grow_first=True)
+
+ self.block8 = Block(728,
+ 728,
+ 3,
+ 1,
+ start_with_relu=True,
+ grow_first=True)
+ self.block9 = Block(728,
+ 728,
+ 3,
+ 1,
+ start_with_relu=True,
+ grow_first=True)
+ self.block10 = Block(728,
+ 728,
+ 3,
+ 1,
+ start_with_relu=True,
+ grow_first=True)
+ self.block11 = Block(728,
+ 728,
+ 3,
+ 1,
+ start_with_relu=True,
+ grow_first=True)
+
+ self.block12 = Block(728,
+ 1024,
+ 2,
+ 2,
+ start_with_relu=True,
+ grow_first=False)
+
+ self.conv3 = layer.SeparableConv2d(1024, 1536, 3, 1, 1)
+ self.bn3 = layer.BatchNorm2d(1536)
+ self.relu3 = layer.ReLU()
+
+ # Relu Layer
+ self.conv4 = layer.SeparableConv2d(1536, 2048, 3, 1, 1)
+ self.bn4 = layer.BatchNorm2d(2048)
+
+ self.relu4 = layer.ReLU()
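+        # with the canonical 299x299 input, the feature map here is 10x10,
+        # so the max pooling below with kernel size 10 acts as global pooling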
+ self.globalpooling = layer.MaxPool2d(10, 1)
+ self.flatten = layer.Flatten()
+ self.fc = layer.Linear(2048, num_classes)
+
+ def features(self, input):
+ x = self.conv1(input)
+ x = self.bn1(x)
+ x = self.relu1(x)
+
+ x = self.conv2(x)
+ x = self.bn2(x)
+ x = self.relu2(x)
+
+ x = self.block1(x)
+ x = self.block2(x)
+ x = self.block3(x)
+ x = self.block4(x)
+ x = self.block5(x)
+ x = self.block6(x)
+ x = self.block7(x)
+ x = self.block8(x)
+ x = self.block9(x)
+ x = self.block10(x)
+ x = self.block11(x)
+ x = self.block12(x)
+
+ x = self.conv3(x)
+ x = self.bn3(x)
+ x = self.relu3(x)
+
+ x = self.conv4(x)
+ x = self.bn4(x)
+ return x
+
+ def logits(self, features):
+ x = self.relu4(features)
+ x = self.globalpooling(x)
+ x = self.flatten(x)
+ x = self.fc(x)
+ return x
+
+ def forward(self, input):
+ x = self.features(input)
+ x = self.logits(x)
+ return x
+
+
+if __name__ == '__main__':
+ model = Xception(num_classes=1000)
+    print('Start initialization............')
+ dev = device.create_cuda_gpu_on(0)
+ #dev = device.create_cuda_gpu()
+
+ niters = 20
+ batch_size = 16
+ IMG_SIZE = 299
+ sgd = opt.SGD(lr=0.1, momentum=0.9, weight_decay=1e-5)
+
+ tx = tensor.Tensor((batch_size, 3, IMG_SIZE, IMG_SIZE), dev)
+ ty = tensor.Tensor((batch_size,), dev, tensor.int32)
+ autograd.training = True
+ x = np.random.randn(batch_size, 3, IMG_SIZE, IMG_SIZE).astype(np.float32)
+ y = np.random.randint(0, 1000, batch_size, dtype=np.int32)
+ tx.copy_from_numpy(x)
+ ty.copy_from_numpy(y)
+
+ with trange(niters) as t:
+ for _ in t:
+ x = model(tx)
+ loss = autograd.softmax_cross_entropy(x, ty)
+ sgd(loss)
diff --git a/examples/cnn_ms/benchmark.py b/examples/cnn_ms/benchmark.py
new file mode 100644
index 0000000..9f69fee
--- /dev/null
+++ b/examples/cnn_ms/benchmark.py
@@ -0,0 +1,121 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+# the code is modified from
+# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
+
+from singa import opt
+from singa import device
+from singa import tensor
+
+import argparse
+import time
+import numpy as np
+from tqdm import trange
+
+
+def train_resnet(DIST=True, graph=True, sequential=False, verbosity=0):
+
+    # Define the hyperparameters for training
+ niters = 100
+ batch_size = 32
+ sgd = opt.SGD(lr=0.1, momentum=0.9, weight_decay=1e-5)
+
+ IMG_SIZE = 224
+
+ # For distributed training, sequential has better throughput in the current version
+    if DIST:
+ sgd = opt.DistOpt(sgd)
+ world_size = sgd.world_size
+ local_rank = sgd.local_rank
+ global_rank = sgd.global_rank
+ sequential = True
+ else:
+ local_rank = 0
+ world_size = 1
+ global_rank = 0
+ sequential = False
+
+ dev = device.create_cuda_gpu_on(local_rank)
+
+ tx = tensor.Tensor((batch_size, 3, IMG_SIZE, IMG_SIZE), dev)
+ ty = tensor.Tensor((batch_size,), dev, tensor.int32)
+ x = np.random.randn(batch_size, 3, IMG_SIZE, IMG_SIZE).astype(np.float32)
+ y = np.random.randint(0, 1000, batch_size, dtype=np.int32)
+ tx.copy_from_numpy(x)
+ ty.copy_from_numpy(y)
+
+ dev.SetVerbosity(verbosity)
+ dev.SetSkipIteration(5)
+
+ # Construct the model
+ from model import resnet
+ model = resnet.resnet50(num_channels=3, num_classes=1000)
+
+ model.train()
+ model.set_optimizer(sgd)
+ model.compile([tx], is_train=True, use_graph=graph, sequential=sequential)
+
+ # Train model
+ dev.Sync()
+ start = time.time()
+ with trange(niters) as t:
+ for _ in t:
+ model(tx, ty, dist_option='fp32', spars=None)
+
+ dev.Sync()
+ end = time.time()
+ titer = (end - start) / float(niters)
+ throughput = float(niters * batch_size * world_size) / (end - start)
+ if global_rank == 0:
+ print("\nThroughput = {} per second".format(throughput), flush=True)
+ print("TotalTime={}".format(end - start), flush=True)
+ print("Total={}".format(titer), flush=True)
+ dev.PrintTimeProfiling()
+
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(
+ description='Throughput test using Resnet 50')
+    parser.add_argument('--dist',
+                        '--enable-dist',
+                        default=False,
+                        action='store_true',
+                        help='enable distributed training',
+                        dest='DIST')
+    parser.add_argument('--no-graph',
+                        '--disable-graph',
+                        default=True,
+                        action='store_false',
+                        help='disable graph',
+                        dest='graph')
+ parser.add_argument('--verbosity',
+ '--log-verbosity',
+ default=0,
+ type=int,
+ help='logging verbosity',
+ dest='verbosity')
+
+ args = parser.parse_args()
+
+ train_resnet(DIST=args.DIST,
+ graph=args.graph,
+ sequential=False,
+ verbosity=args.verbosity)
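+
+# Example launches (assumed invocations; the distributed path relies on MPI
+# for rank discovery, so it is typically started through mpiexec/mpirun):
+#   single GPU:  python benchmark.py
+#   multi GPU :  mpiexec -np 4 python benchmark.py --dist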
diff --git a/examples/cnn_ms/pkg_model_code/model.py b/examples/cnn_ms/pkg_model_code/model.py
new file mode 100644
index 0000000..3fea914
--- /dev/null
+++ b/examples/cnn_ms/pkg_model_code/model.py
@@ -0,0 +1,357 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# =============================================================================
+'''
+This script includes Model class for python users
+to use Computational Graph in their model.
+'''
+
+import os
+import gc
+import time
+import json
+import zipfile
+import numpy as np
+from functools import wraps
+from collections.abc import Iterable
+
+from singa import tensor
+from singa import autograd
+from singa import layer
+from .tensor import Tensor
+from . import singa_wrap as singa
+
+
+class ModelMeta(layer.LayerMeta):
+
+ def buffer_operation(func):
+
+ def remove_creator(tensors):
+ if not tensors:
+ return
+
+ if isinstance(tensors, Iterable):
+ if isinstance(tensors, str):
+ return
+ else:
+ for item in tensors:
+ if isinstance(item, Iterable):
+ remove_creator(item)
+ elif isinstance(item, tensor.Tensor):
+ item.creator = None
+ elif isinstance(tensors, tensor.Tensor):
+ tensors.creator = None
+
+ @wraps(func)
+ def wrapper(self, *args, **kwargs):
+ if self.graph_mode and self.training:
+ if len(args) == 0:
+ raise ValueError('expect at least one input tensor')
+
+ if isinstance(args[0], list):
+ assert isinstance(
+ args[0][0],
+ Tensor), ('function expects PlaceHolders or Tensors')
+ dev = args[0][0].device
+ else:
+ assert isinstance(
+ args[0],
+ Tensor), ('function expects PlaceHolders or Tensors')
+ dev = args[0].device
+
+ if not self._buffered:
+ # buffer operations
+ dev.EnableGraph(True)
+ self._results = func(self, *args, **kwargs)
+ dev.Sync()
+ dev.EnableGraph(False)
+ self._buffered = True
+
+ # deconstruct Operations before running the entire graph
+ remove_creator(self._results)
+
+ # make sure all Operations are deallocated
+ gc.collect()
+
+ # run graph
+ dev.RunGraph(self.sequential)
+ return self._results
+ else:
+ return func(self, *args, **kwargs)
+
+ return wrapper
+
+ def __new__(cls, name, bases, attr):
+ if 'train_one_batch' in attr:
+ attr['train_one_batch'] = ModelMeta.buffer_operation(
+ attr['train_one_batch'])
+
+ return super(ModelMeta, cls).__new__(cls, name, bases, attr)
+
+
+class Model(layer.Layer, metaclass=ModelMeta):
+ """ Base class for your neural network models.
+
+ Example usage::
+
+ import numpy as np
+ from singa import opt
+ from singa import tensor
+ from singa import device
+ from singa import autograd
+ from singa import layer
+ from singa import model
+
+ class MyModel(model.Model):
+ def __init__(self):
+ super(MyModel, self).__init__()
+
+ self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
+ self.conv1 = layer.Conv2d(1, 20, 5, padding=0)
+ self.conv2 = layer.Conv2d(20, 50, 5, padding=0)
+ self.sgd = opt.SGD(lr=0.01)
+
+ def forward(self, x):
+ y = self.conv1(x)
+ y = self.conv2(y)
+ return y
+
+ def train_one_batch(self, x, y):
+ out = self.forward(x)
+ loss = self.softmax_cross_entropy(out, y)
+ self.sgd(loss)
+ return out, loss
+
+ """
+
+ # save load states constant
+ TENSOR_DICT_FILENAME = '/tensor_dict.npz'
+ STATES_ATTR_FILENAME = '/states_attr.json'
+ MODEL_STATE_TYPE = 0
+ AUX_STATE_TYPE = 1
+
+ def __init__(self):
+ """
+ Initializes internal Model state
+ """
+ super(Model, self).__init__()
+
+ self.training = True
+ self.graph_mode = True
+ self.sequential = False
+ self._buffered = False
+ self._results = None
+
+ def compile(self, inputs, is_train=True, use_graph=False, sequential=False):
+ """ Compile and initialize the model
+
+ This function will automatically derive the shape of parameters
+ in each sublayer based on the shape of input placeholders. It will
+ also do some settings.
+
+ Args:
+ inputs(list): the list of input tensors(placeholders)
+            is_train(bool): when is_train is True, this model will enter
+                training mode, otherwise it will enter evaluation mode
+ use_graph(bool): when use_graph is True, computational graph
+ will be used to train this model
+            sequential(bool): when sequential is True, the model will execute
+                ops in the graph in the order in which they were added
+ """
+ assert len(inputs) > 0 and isinstance(inputs[0], Tensor), (
+ 'compile function expects PlaceHolders or Tensors')
+
+ dev = inputs[0].device
+ dev.EnableGraph(True)
+ self.forward(*inputs)
+ dev.EnableGraph(False)
+ dev.ResetGraph()
+
+ autograd.training = is_train
+ self.training = is_train
+ self.graph_mode = use_graph
+ self.sequential = sequential
+
+ def forward(self, *input):
+ """Defines the computation performed in every forward propagation.
+
+ Should be overridden by all subclasses.
+
+ Args:
+ *input: the input training data for the model
+
+ Returns:
+ out: the outputs of the forward propagation.
+ """
+ raise NotImplementedError
+
+ def train_one_batch(self, *input, **kwargs):
+ """Defines the computation performed in every training iteration
+
+ Should be overridden by all subclasses.
+
+ Args:
+ *input: the arguments of train_one_batch
+ **kwargs: the keyword arguments of train_one_batch
+ """
+ raise NotImplementedError
+
+ def train(self, mode=True):
+ """Set the model in evaluation mode.
+
+ Args:
+ mode(bool): when mode is True, this model will enter training mode
+ """
+ self.training = mode
+ autograd.training = mode
+
+ def eval(self):
+ """Sets the model in evaluation mode.
+ """
+ self.train(mode=False)
+
+ def graph(self, mode=True, sequential=False):
+ """ Turn on the computational graph. Specify execution mode.
+
+ Args:
+ mode(bool): when mode is True, model will use computational graph
+            sequential(bool): when sequential is True, the model will execute
+                ops in the graph in the order in which they were added
+ """
+ self.graph_mode = mode
+ self.sequential = sequential
+
+ def __get_name__(self):
+ return self.__class__.__name__
+
+ def __call__(self, *input, **kwargs):
+ if self.training:
+ return self.train_one_batch(*input, **kwargs)
+ else:
+ return self.forward(*input, **kwargs)
+
+ def save_states(self, fpath, aux_states={}):
+ """Save states.
+
+ Args:
+            fpath: output file path (a zip archive is written at this path)
+ aux_states(dict): values are standard data types or Tensor,
+ e.g., epoch ID, learning rate, optimizer states
+ """
+ assert not os.path.isfile(fpath), (
+ "Failed to save states, %s is already existed." % fpath)
+
+ states = self.get_states()
+
+ # save states data and attr
+ tensor_dict = {}
+ states_attr = {}
+ for k, v in states.items():
+ assert isinstance(v, tensor.Tensor), "Only tensor state is allowed"
+ tensor_dict[k] = tensor.to_numpy(v)
+ states_attr[k] = {
+ 'state_type': self.MODEL_STATE_TYPE,
+ 'shape': v.shape,
+ 'dtype': v.dtype
+ }
+
+ for k, v in aux_states.items():
+ assert isinstance(v,
+ tensor.Tensor), "Only tensor aux state is allowed"
+ tensor_dict[k] = tensor.to_numpy(v)
+ states_attr[k] = {
+ 'state_type': self.AUX_STATE_TYPE,
+ 'shape': v.shape,
+ 'dtype': v.dtype
+ }
+
+ # save to files
+ timestamp = time.time()
+ tmp_dir = '/tmp/singa_save_states_%s' % timestamp
+ os.mkdir(tmp_dir)
+ tensor_dict_fp = tmp_dir + self.TENSOR_DICT_FILENAME
+ states_attr_fp = tmp_dir + self.STATES_ATTR_FILENAME
+
+ np.savez(tensor_dict_fp, **tensor_dict)
+
+ with open(states_attr_fp, 'w') as fp:
+ json.dump(states_attr, fp)
+
+ compression = zipfile.ZIP_DEFLATED
+ with zipfile.ZipFile(fpath, mode="w") as zf:
+ zf.write(tensor_dict_fp,
+ os.path.basename(tensor_dict_fp),
+ compress_type=compression)
+ zf.write(states_attr_fp,
+ os.path.basename(states_attr_fp),
+ compress_type=compression)
+
+ # clean up tmp files
+ os.remove(tensor_dict_fp)
+ os.remove(states_attr_fp)
+ os.rmdir(tmp_dir)
+
+ def load_states(self, fpath):
+ """Load the model states and auxiliary states from disk.
+
+ Usage:
+ m = MyModel()
+ m.compile(...)
+ aux_states = m.load_states('mymodel.zip')
+
+ Args:
+            fpath: input file path (the zip archive produced by save_states)
+ Returns:
+ dict
+ """
+
+ assert os.path.isfile(fpath), (
+ "Failed to load states, %s is not exist." % fpath)
+
+ timestamp = time.time()
+ tmp_dir = '/tmp/singa_load_states_%s' % timestamp
+ os.mkdir(tmp_dir)
+
+ with zipfile.ZipFile(fpath, 'r') as zf:
+ zf.extractall(tmp_dir)
+
+ tensor_dict_fp = tmp_dir + self.TENSOR_DICT_FILENAME
+ states_attr_fp = tmp_dir + self.STATES_ATTR_FILENAME
+
+ with open(states_attr_fp) as f:
+ states_attr = json.load(f)
+
+ tensor_dict = np.load(tensor_dict_fp)
+
+ # restore singa tensor from numpy
+ model_states = dict()
+ aux_states = dict()
+
+ for k in tensor_dict.files:
+ if states_attr[k]['state_type'] == self.MODEL_STATE_TYPE:
+ model_states[k] = tensor.from_numpy(tensor_dict[k])
+ elif states_attr[k]['state_type'] == self.AUX_STATE_TYPE:
+ aux_states[k] = tensor.from_numpy(tensor_dict[k])
+
+ # restore model_states
+ self.set_states(model_states)
+
+ # clean up tmp files
+ os.remove(tensor_dict_fp)
+ os.remove(states_attr_fp)
+ os.rmdir(tmp_dir)
+ return aux_states
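+
+
+# A minimal usage sketch (hypothetical checkpoint name; model states and the
+# values in aux_states must all be singa Tensors):
+#
+#   m = MyModel()
+#   m.compile([placeholder], is_train=True, use_graph=True)
+#   ...  # train for some epochs
+#   m.save_states('mymodel.zip',
+#                 aux_states={'epoch': tensor.from_numpy(np.array([10]))})
+#   aux_states = m.load_states('mymodel.zip')  # also restores model states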
\ No newline at end of file
diff --git a/examples/ms_model_mlp/model.py b/examples/ms_model_mlp/model.py
new file mode 100644
index 0000000..454b382
--- /dev/null
+++ b/examples/ms_model_mlp/model.py
@@ -0,0 +1,226 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from singa import layer
+from singa import model
+from singa import tensor
+from singa import opt
+from singa import device
+from singa.autograd import Operator
+from singa.layer import Layer
+from singa import singa_wrap as singa
+import argparse
+import numpy as np
+
+np_dtype = {"float16": np.float16, "float32": np.float32}
+
+singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}
+
+#### self-defined loss begin
+
+### adapted from MeanSquareError in autograd.py: sums all elements of the input
+class SumError(Operator):
+
+    def __init__(self):
+        super(SumError, self).__init__()
+
+    def forward(self, x):
+        self.data_x = x
+        loss = singa.SumAll(x)
+        return loss
+
+    def backward(self, dy=1.0):
+        # the gradient of sum(x) w.r.t. x is a tensor of ones, scaled by dy
+        dev = device.get_default_device()
+        dx = tensor.Tensor(self.data_x.shape, dev, singa_dtype['float32'])
+        dx.copy_from_numpy(np.ones(self.data_x.shape))
+        dx *= dy
+        return dx
+
+
+def se_loss(x):
+    return SumError()(x)[0]
+
+### from layer.py
+class SumErrorLayer(Layer):
+ """
+    Generate a SumError operator
+ """
+
+ def __init__(self):
+ super(SumErrorLayer, self).__init__()
+
+ def forward(self, x):
+ return se_loss(x)
+
+#### self-defined loss end
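+
+# A small sanity-check sketch for the self-defined loss (commented out;
+# assumes a usable default device -- SumError simply sums every element of
+# its input):
+#
+#   t = tensor.Tensor((2, 3), device.get_default_device())
+#   t.set_value(1.0)
+#   print(tensor.to_numpy(SumErrorLayer()(t)))   # ~ 6.0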
+
+class MSMLP(model.Model):
+
+ def __init__(self, data_size=10, perceptron_size=100, num_classes=10, layer_hidden_list=[10,10,10,10]):
+ super(MSMLP, self).__init__()
+ self.num_classes = num_classes
+ self.dimension = 2
+
+ self.relu = layer.ReLU()
+ self.linear1 = layer.Linear(layer_hidden_list[0])
+ self.linear2 = layer.Linear(layer_hidden_list[1])
+ self.linear3 = layer.Linear(layer_hidden_list[2])
+ self.linear4 = layer.Linear(layer_hidden_list[3])
+ self.linear5 = layer.Linear(num_classes)
+ self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
+ self.sum_error = SumErrorLayer()
+
+ def forward(self, inputs):
+ y = self.linear1(inputs)
+ y = self.relu(y)
+ y = self.linear2(y)
+ y = self.relu(y)
+ y = self.linear3(y)
+ y = self.relu(y)
+ y = self.linear4(y)
+ y = self.relu(y)
+ y = self.linear5(y)
+ return y
+
+    def train_one_batch(self, x, y, dist_option, spars, synflow_flag=False):
+        out = self.forward(x)
+        if synflow_flag:
+            # synflow uses the sum of all outputs as its objective
+            loss = self.sum_error(out)
+        else:  # normal training
+            loss = self.softmax_cross_entropy(out, y)
+
+        # pn_p_g_list is only produced by the plain optimizer path
+        pn_p_g_list = None
+        if dist_option == 'plain':
+            pn_p_g_list = self.optimizer(loss)
+        elif dist_option == 'half':
+            self.optimizer.backward_and_update_half(loss)
+        elif dist_option == 'partialUpdate':
+            self.optimizer.backward_and_partial_update(loss)
+        elif dist_option == 'sparseTopK':
+            self.optimizer.backward_and_sparse_update(loss,
+                                                      topK=True,
+                                                      spars=spars)
+        elif dist_option == 'sparseThreshold':
+            self.optimizer.backward_and_sparse_update(loss,
+                                                      topK=False,
+                                                      spars=spars)
+        return pn_p_g_list, out, loss
+
+ def set_optimizer(self, optimizer):
+ self.optimizer = optimizer
+
+
+def create_model(pretrained=False, **kwargs):
+ """Constructs a CNN model.
+
+ Args:
+ pretrained (bool): If True, returns a pre-trained model.
+
+ Returns:
+ The created CNN model.
+ """
+ model = MSMLP(**kwargs)
+
+ return model
+
+
+__all__ = ['MSMLP', 'create_model']
+
+if __name__ == "__main__":
+ np.random.seed(0)
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-p',
+ choices=['float32', 'float16'],
+ default='float32',
+ dest='precision')
+    parser.add_argument('-g',
+                        '--disable-graph',
+                        default=True,
+                        action='store_false',
+                        help='disable graph',
+                        dest='graph')
+ parser.add_argument('-m',
+ '--max-epoch',
+ default=1001,
+ type=int,
+ help='maximum epochs',
+ dest='max_epoch')
+ args = parser.parse_args()
+
+ # generate the boundary
+ f = lambda x: (5 * x + 1)
+ bd_x = np.linspace(-1.0, 1, 200)
+ bd_y = f(bd_x)
+
+ # generate the training data
+ x = np.random.uniform(-1, 1, 400)
+ y = f(x) + 2 * np.random.randn(len(x))
+
+ # choose one precision
+ precision = singa_dtype[args.precision]
+ np_precision = np_dtype[args.precision]
+
+ # convert training data to 2d space
+ label = np.asarray([5 * a + 1 > b for (a, b) in zip(x, y)]).astype(np.int32)
+ data = np.array([[a, b] for (a, b) in zip(x, y)], dtype=np_precision)
+
+ dev = device.create_cuda_gpu_on(0)
+ sgd = opt.SGD(0.1, 0.9, 1e-5, dtype=singa_dtype[args.precision])
+ tx = tensor.Tensor((400, 2), dev, precision)
+ ty = tensor.Tensor((400,), dev, tensor.int32)
+    model = MSMLP(data_size=2, perceptron_size=3, num_classes=2)
+
+ # attach model to graph
+ model.set_optimizer(sgd)
+ model.compile([tx], is_train=True, use_graph=args.graph, sequential=True)
+ model.train()
+
+ for i in range(args.max_epoch):
+ tx.copy_from_numpy(data)
+ ty.copy_from_numpy(label)
+        _, out, loss = model(tx, ty, dist_option='plain', spars=None)
+
+ if i % 100 == 0:
+ print("training loss = ", tensor.to_numpy(loss)[0])
diff --git a/examples/ms_model_mlp/native.py b/examples/ms_model_mlp/native.py
new file mode 100644
index 0000000..a82ec3b
--- /dev/null
+++ b/examples/ms_model_mlp/native.py
@@ -0,0 +1,137 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from singa import tensor
+from singa.tensor import Tensor
+from singa import autograd
+from singa import opt
+import numpy as np
+from singa import device
+import argparse
+
+np_dtype = {"float16": np.float16, "float32": np.float32}
+
+singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-p',
+ choices=['float32', 'float16'],
+ default='float32',
+ dest='precision')
+ parser.add_argument('-m',
+ '--max-epoch',
+ default=1001,
+ type=int,
+ help='maximum epochs',
+ dest='max_epoch')
+ args = parser.parse_args()
+
+ np.random.seed(0)
+
+ autograd.training = True
+
+ # prepare training data in numpy array
+
+ # generate the boundary
+ f = lambda x: (5 * x + 1)
+ bd_x = np.linspace(-1.0, 1, 200)
+ bd_y = f(bd_x)
+
+ # generate the training data
+ x = np.random.uniform(-1, 1, 400)
+ y = f(x) + 2 * np.random.randn(len(x))
+
+ # convert training data to 2d space
+ label = np.asarray([5 * a + 1 > b for (a, b) in zip(x, y)])
+ data = np.array([[a, b] for (a, b) in zip(x, y)], dtype=np.float32)
+
+ def to_categorical(y, num_classes):
+ """
+ Converts a class vector (integers) to binary class matrix.
+
+ Args:
+ y: class vector to be converted into a matrix
+ (integers from 0 to num_classes).
+ num_classes: total number of classes.
+
+ Returns:
+ A binary matrix representation of the input.
+ """
+ y = np.array(y, dtype="int")
+ n = y.shape[0]
+ categorical = np.zeros((n, num_classes))
+ categorical[np.arange(n), y] = 1
+ return categorical
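+
+    # e.g. to_categorical([0, 1, 1], 2) -> [[1., 0.], [0., 1.], [0., 1.]]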
+
+ label = to_categorical(label, 2).astype(np.float32)
+ print("train_data_shape:", data.shape)
+ print("train_label_shape:", label.shape)
+
+ precision = singa_dtype[args.precision]
+ np_precision = np_dtype[args.precision]
+
+ dev = device.create_cuda_gpu()
+
+ inputs = Tensor(data=data, device=dev)
+ target = Tensor(data=label, device=dev)
+
+ inputs = inputs.as_type(precision)
+ target = target.as_type(tensor.int32)
+
+ w0_np = np.random.normal(0, 0.1, (2, 3)).astype(np_precision)
+ w0 = Tensor(data=w0_np,
+ device=dev,
+ dtype=precision,
+ requires_grad=True,
+ stores_grad=True)
+ b0 = Tensor(shape=(3,),
+ device=dev,
+ dtype=precision,
+ requires_grad=True,
+ stores_grad=True)
+ b0.set_value(0.0)
+
+ w1_np = np.random.normal(0, 0.1, (3, 2)).astype(np_precision)
+ w1 = Tensor(data=w1_np,
+ device=dev,
+ dtype=precision,
+ requires_grad=True,
+ stores_grad=True)
+ b1 = Tensor(shape=(2,),
+ device=dev,
+ dtype=precision,
+ requires_grad=True,
+ stores_grad=True)
+ b1.set_value(0.0)
+
+ sgd = opt.SGD(0.05, 0.8)
+
+ # training process
+ for i in range(args.max_epoch):
+ x = autograd.matmul(inputs, w0)
+ x = autograd.add_bias(x, b0)
+ x = autograd.relu(x)
+ x = autograd.matmul(x, w1)
+ x = autograd.add_bias(x, b1)
+ loss = autograd.softmax_cross_entropy(x, target)
+ sgd(loss)
+
+ if i % 100 == 0:
+ print("%d, training loss = " % i, tensor.to_numpy(loss)[0])