Merge pull request #1166 from apache/dev-postgresql
Merge Dev branch
diff --git a/.github/workflows/ubuntu.yaml b/.github/workflows/ubuntu.yaml
index b67fcda..feb801a 100644
--- a/.github/workflows/ubuntu.yaml
+++ b/.github/workflows/ubuntu.yaml
@@ -40,22 +40,24 @@
# run: cd build && make
# - name: C++ test
# run: build/bin/test_singa
-
+
build-cpptest-on-cpu:
- runs-on: ubuntu-latest
+ runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v1
- name: get-oneDNN
run: wget https://github.com/oneapi-src/oneDNN/releases/download/v1.1/dnnl_lnx_1.1.0_cpu_gomp.tgz -P /tmp/ && tar zxf /tmp/dnnl_lnx_1.1.0_cpu_gomp.tgz -C /tmp
+ - name: setup-sys-env
+ run: sudo apt-get install -y curl wget git cmake
- name: install-build-dependencies
run: sudo apt-get install -y libgoogle-glog-dev libprotobuf-dev protobuf-compiler libncurses-dev libopenblas-dev gfortran libblas-dev liblapack-dev libatlas-base-dev swig dh-autoreconf lcov
- name: configure
run: mkdir build && cd build && cmake -DUSE_PYTHON=NO -DENABLE_TEST=YES -DCODE_COVERAGE=YES -DUSE_DNNL=YES ..
env:
- DNNL_ROOT: /tmp/dnnl_lnx_1.1.0_cpu_gomp/
+ DNNL_ROOT: /tmp/dnnl_lnx_1.1.0_cpu_gomp/
- name: build
- run: cd build && make
+ run: cd build && make -j8
- name: C++ test
run: build/bin/test_singa
- name: Upload coverage to Codecov
diff --git a/examples/cnn_ms/run.sh b/examples/cnn_ms/run.sh
new file mode 100644
index 0000000..a536a1e
--- /dev/null
+++ b/examples/cnn_ms/run.sh
@@ -0,0 +1,38 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+### mnist
+python train_cnn.py mlp mnist
+python train_cnn.py cnn mnist
+python train_cnn.py resnet mnist
+python train_cnn.py alexnet mnist
+
+### cifar10
+python train_cnn.py mlp cifar10
+python train_cnn.py cnn cifar10
+python train_cnn.py resnet cifar10
+python train_cnn.py alexnet cifar10
+
+### cifar100
+python train_cnn.py mlp cifar100
+python train_cnn.py cnn cifar100
+python train_cnn.py resnet cifar100
+python train_cnn.py alexnet cifar100
diff --git a/examples/cnn_ms/train_cnn.py b/examples/cnn_ms/train_cnn.py
new file mode 100644
index 0000000..d7f8f70
--- /dev/null
+++ b/examples/cnn_ms/train_cnn.py
@@ -0,0 +1,554 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from singa import singa_wrap as singa
+from singa import device
+from singa import tensor
+from singa import opt
+from singa import autograd
+from singa.opt import Optimizer
+from singa.opt import DecayScheduler
+from singa.opt import Constant
+import numpy as np
+import time
+import argparse
+from PIL import Image
+
+np_dtype = {"float16": np.float16, "float32": np.float32}
+
+singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}
+
+### MSOptimizer
+class MSOptimizer(Optimizer):
+ def __call__(self, loss):
+ pn_p_g_list = self.call_with_returns(loss)
+ self.step()
+ return pn_p_g_list
+
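+    # Hypothetical usage of the wrapper above (not part of the original file):
+    #   mssgd = MSSGD(lr=0.1)
+    #   pn_p_g = mssgd(loss)   # applies the update and also returns the
+    #                          # [name, param, grad] entries for inspection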
+    def call_with_returns(self, loss):
+        pn_p_g_list = []
+        for p, g in autograd.backward(loss):
+            if p.name is None:
+                p.name = id(p)
+            self.apply(p.name, p, g)
+            # collect [name, param, grad] entries; a list (not a tuple) so the
+            # SynFlow code in the training script can modify them in place
+            pn_p_g_list.append([p.name, p, g])
+        return pn_p_g_list
+
+# MSSGD -- actually no change of code
+class MSSGD(MSOptimizer):
+ """Implements stochastic gradient descent (optionally with momentum).
+
+ Nesterov momentum is based on the formula from `On the importance of initialization and momentum in deep learning`__.
+
+ Args:
+ lr(float): learning rate
+        momentum(float, optional): momentum factor (default: 0)
+        weight_decay(float, optional): weight decay (L2 penalty) (default: 0)
+        dampening(float, optional): dampening for momentum (default: 0)
+        nesterov(bool, optional): enables Nesterov momentum (default: False)
+
+ Typical usage example:
+        >>> from singa import opt
+        >>> optimizer = opt.SGD(lr=0.1, momentum=0.9)
+        >>> optimizer.update()
+
+    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
+
+ .. note::
+ The implementation of SGD with Momentum / Nesterov subtly differs from
+        Sutskever et al. and implementations in some other frameworks.
+
+ Considering the specific case of Momentum, the update can be written as
+
+ .. math::
+ v = \rho * v + g \\
+ p = p - lr * v
+
+        where p, g, v and :math:`\rho` denote the parameters, gradient,
+ velocity, and momentum respectively.
+
+        This is in contrast to Sutskever et al. and
+ other frameworks which employ an update of the form
+
+ .. math::
+ v = \rho * v + lr * g \\
+ p = p - v
+
+ The Nesterov version is analogously modified.
+ """
+
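+    # Worked one-step example of the two update forms above (illustrative only,
+    # not from the original code), with lr=0.1, rho=0.9, v=0, g=1.0:
+    #   this implementation: v = 0.9*0 + 1.0 = 1.0,   p <- p - 0.1*1.0 = p - 0.1
+    #   Sutskever et al.:    v = 0.9*0 + 0.1*1.0 = 0.1, p <- p - v     = p - 0.1
+    # The first steps coincide; the two schemes diverge once lr changes during
+    # training, because here lr rescales the entire accumulated velocity.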
+ def __init__(self,
+ lr=0.1,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ dtype=tensor.float32):
+ super(MSSGD, self).__init__(lr, dtype)
+
+ # init momentum
+ if type(momentum) == float or type(momentum) == int:
+ if momentum < 0.0:
+ raise ValueError("Invalid momentum value: {}".format(momentum))
+ self.momentum = Constant(momentum)
+ elif isinstance(momentum, DecayScheduler):
+ self.momentum = momentum
+ momentum = momentum.init_value
+ else:
+ raise TypeError("Wrong momentum type")
+ self.mom_value = self.momentum(self.step_counter).as_type(self.dtype)
+
+ # init dampening
+ if type(dampening) == float or type(dampening) == int:
+ self.dampening = Constant(dampening)
+ elif isinstance(dampening, DecayScheduler):
+ self.dampening = dampening
+ dampening = dampening.init_value
+ else:
+ raise TypeError("Wrong dampening type")
+ self.dam_value = self.dampening(self.step_counter).as_type(self.dtype)
+
+ # init weight_decay
+ if type(weight_decay) == float or type(weight_decay) == int:
+ if weight_decay < 0.0:
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay))
+ self.weight_decay = Constant(weight_decay)
+ elif isinstance(weight_decay, DecayScheduler):
+ self.weight_decay = weight_decay
+ else:
+ raise TypeError("Wrong weight_decay type")
+ self.decay_value = self.weight_decay(self.step_counter).as_type(
+ self.dtype)
+
+ # init other params
+ self.nesterov = nesterov
+ self.moments = dict()
+
+ # check value
+ if nesterov and (momentum <= 0 or dampening != 0):
+ raise ValueError(
+ "Nesterov momentum requires a momentum and zero dampening")
+
+ def apply(self, param_name, param_value, param_grad):
+ """Performs a single optimization step.
+
+ Args:
+ param_name(String): the name of the param
+            param_value(Tensor): param values to be updated in-place
+            param_grad(Tensor): param gradients; the values may be updated
+                                in this function; do not reuse them afterwards
+ """
+ assert param_value.shape == param_grad.shape, ("shape mismatch",
+ param_value.shape,
+ param_grad.shape)
+ self.device_check(param_value, self.step_counter, self.lr_value,
+ self.mom_value, self.dam_value, self.decay_value)
+
+ # derive dtype from input
+ assert param_value.dtype == self.dtype
+
+ # TODO add branch operator
+ # if self.decay_value != 0:
+ if self.weight_decay.init_value != 0:
+ singa.Axpy(self.decay_value.data, param_value.data, param_grad.data)
+
+ if self.momentum.init_value != 0:
+ if param_name not in self.moments:
+ flag = param_value.device.graph_enabled()
+ param_value.device.EnableGraph(False)
+ self.moments[param_name] = tensor.zeros_like(param_value)
+ param_value.device.EnableGraph(flag)
+
+ buf = self.moments[param_name]
+ buf *= self.mom_value
+ alpha = 1.0 - self.dam_value
+ singa.Axpy(alpha.data, param_grad.data, buf.data)
+
+ if self.nesterov:
+ singa.Axpy(self.mom_value.data, buf.data, param_grad.data)
+ else:
+ param_grad = buf
+
+ minus_lr = 0.0 - self.lr_value
+ singa.Axpy(minus_lr.data, param_grad.data, param_value.data)
+
+ def step(self):
+ # increment step counter, lr and moment
+ super().step()
+ mom_value = self.momentum(self.step_counter).as_type(self.dtype)
+ dam_value = self.dampening(self.step_counter).as_type(self.dtype)
+ decay_value = self.weight_decay(self.step_counter).as_type(self.dtype)
+ self.mom_value.copy_from(mom_value)
+ self.dam_value.copy_from(dam_value)
+ self.decay_value.copy_from(decay_value)
+
+ def get_states(self):
+ states = super().get_states()
+ if self.mom_value > 0:
+ states[
+ 'moments'] = self.moments # a dict for 1st order moments tensors
+ return states
+
+ def set_states(self, states):
+ super().set_states(states)
+ if 'moments' in states:
+ self.moments = states['moments']
+ self.mom_value = self.momentum(self.step_counter)
+
+
+# Data augmentation
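+# The helper below pads each image by 4 pixels per side ('symmetric' mode),
+# takes a random crop back to the original size, and applies a horizontal flip
+# with probability 0.5; it assumes NCHW input with square images.
+# Hypothetical call (not in the original script):
+#   x = augmentation(np.random.rand(8, 3, 32, 32).astype(np.float32), 8)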
+def augmentation(x, batch_size):
+ xpad = np.pad(x, [[0, 0], [0, 0], [4, 4], [4, 4]], 'symmetric')
+ for data_num in range(0, batch_size):
+ offset = np.random.randint(8, size=2)
+ x[data_num, :, :, :] = xpad[data_num, :,
+ offset[0]:offset[0] + x.shape[2],
+ offset[1]:offset[1] + x.shape[2]]
+ if_flip = np.random.randint(2)
+ if (if_flip):
+ x[data_num, :, :, :] = x[data_num, :, :, ::-1]
+ return x
+
+
+# Calculate accuracy
+def accuracy(pred, target):
+ # y is network output to be compared with ground truth (int)
+ y = np.argmax(pred, axis=1)
+ a = y == target
+ correct = np.array(a, "int").sum()
+ return correct
+
+
+# Data partition according to the rank
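+# Each rank keeps one contiguous, equally sized slice; when the sample count is
+# not divisible by world_size the remainder is dropped. For example, with 10
+# training samples and world_size=3, ranks 0/1/2 receive indices [0:3], [3:6]
+# and [6:9], and sample 9 is unused.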
+def partition(global_rank, world_size, train_x, train_y, val_x, val_y):
+ # Partition training data
+ data_per_rank = train_x.shape[0] // world_size
+ idx_start = global_rank * data_per_rank
+ idx_end = (global_rank + 1) * data_per_rank
+ train_x = train_x[idx_start:idx_end]
+ train_y = train_y[idx_start:idx_end]
+
+ # Partition evaluation data
+ data_per_rank = val_x.shape[0] // world_size
+ idx_start = global_rank * data_per_rank
+ idx_end = (global_rank + 1) * data_per_rank
+ val_x = val_x[idx_start:idx_end]
+ val_y = val_y[idx_start:idx_end]
+ return train_x, train_y, val_x, val_y
+
+
+# Function to all reduce NUMPY accuracy and loss from multiple devices
+def reduce_variable(variable, dist_opt, reducer):
+ reducer.copy_from_numpy(variable)
+ dist_opt.all_reduce(reducer.data)
+ dist_opt.wait()
+ output = tensor.to_numpy(reducer)
+ return output
+
+
+def resize_dataset(x, image_size):
+ num_data = x.shape[0]
+ dim = x.shape[1]
+ X = np.zeros(shape=(num_data, dim, image_size, image_size),
+ dtype=np.float32)
+ for n in range(0, num_data):
+ for d in range(0, dim):
+ X[n, d, :, :] = np.array(Image.fromarray(x[n, d, :, :]).resize(
+ (image_size, image_size), Image.BILINEAR),
+ dtype=np.float32)
+ return X
+
+
+def run(global_rank,
+ world_size,
+ local_rank,
+ max_epoch,
+ batch_size,
+ model,
+ data,
+ mssgd,
+ graph,
+ verbosity,
+ dist_option='plain',
+ spars=None,
+ precision='float32'):
+ # dev = device.create_cuda_gpu_on(local_rank) # need to change to CPU device for CPU-only machines
+ dev = device.get_default_device()
+ dev.SetRandSeed(0)
+ np.random.seed(0)
+
+ if data == 'cifar10':
+ from data import cifar10
+ train_x, train_y, val_x, val_y = cifar10.load()
+ elif data == 'cifar100':
+ from data import cifar100
+ train_x, train_y, val_x, val_y = cifar100.load()
+ elif data == 'mnist':
+ from data import mnist
+ train_x, train_y, val_x, val_y = mnist.load()
+
+
+ num_channels = train_x.shape[1]
+ image_size = train_x.shape[2]
+ data_size = np.prod(train_x.shape[1:train_x.ndim]).item()
+ num_classes = (np.max(train_y) + 1).item()
+
+ if model == 'resnet':
+ from model import resnet
+ model = resnet.resnet50(num_channels=num_channels,
+ num_classes=num_classes)
+ elif model == 'xceptionnet':
+ from model import xceptionnet
+ model = xceptionnet.create_model(num_channels=num_channels,
+ num_classes=num_classes)
+ elif model == 'cnn':
+ from model import cnn
+ model = cnn.create_model(num_channels=num_channels,
+ num_classes=num_classes)
+ elif model == 'alexnet':
+ from model import alexnet
+ model = alexnet.create_model(num_channels=num_channels,
+ num_classes=num_classes)
+ elif model == 'mlp':
+ import os, sys, inspect
+ current = os.path.dirname(
+ os.path.abspath(inspect.getfile(inspect.currentframe())))
+ parent = os.path.dirname(current)
+ sys.path.insert(0, parent)
+ from mlp import model
+ model = model.create_model(data_size=data_size,
+ num_classes=num_classes)
+
+ elif model == 'msmlp':
+ import os, sys, inspect
+ current = os.path.dirname(
+ os.path.abspath(inspect.getfile(inspect.currentframe())))
+ parent = os.path.dirname(current)
+ sys.path.insert(0, parent)
+ from msmlp import model
+ model = model.create_model(data_size=data_size,
+ num_classes=num_classes)
+
+ # For distributed training, sequential has better performance
+ if hasattr(mssgd, "communicator"):
+ DIST = True
+ sequential = True
+ else:
+ DIST = False
+ sequential = False
+
+ if DIST:
+ train_x, train_y, val_x, val_y = partition(global_rank, world_size,
+ train_x, train_y, val_x,
+ val_y)
+
+ if model.dimension == 4:
+ tx = tensor.Tensor(
+ (batch_size, num_channels, model.input_size, model.input_size), dev,
+ singa_dtype[precision])
+ elif model.dimension == 2:
+ tx = tensor.Tensor((batch_size, data_size), dev, singa_dtype[precision])
+        # np.reshape returns a new array, so the results must be assigned back
+        train_x = np.reshape(train_x, (train_x.shape[0], -1))
+        val_x = np.reshape(val_x, (val_x.shape[0], -1))
+
+ ty = tensor.Tensor((batch_size,), dev, tensor.int32)
+ num_train_batch = train_x.shape[0] // batch_size
+ num_val_batch = val_x.shape[0] // batch_size
+ idx = np.arange(train_x.shape[0], dtype=np.int32)
+
+ # Attach model to graph
+ model.set_optimizer(mssgd)
+ model.compile([tx], is_train=True, use_graph=graph, sequential=sequential)
+ dev.SetVerbosity(verbosity)
+
+ # Training and evaluation loop
+ for epoch in range(max_epoch):
+ start_time = time.time()
+ np.random.shuffle(idx)
+
+ if global_rank == 0:
+ print('Starting Epoch %d:' % (epoch))
+
+ # Training phase
+ train_correct = np.zeros(shape=[1], dtype=np.float32)
+ test_correct = np.zeros(shape=[1], dtype=np.float32)
+ train_loss = np.zeros(shape=[1], dtype=np.float32)
+
+ model.train()
+ print ("num_train_batch: \n", num_train_batch)
+ print ()
+ for b in range(num_train_batch):
+ if b % 200 == 0:
+ print ("b: \n", b)
+ # Generate the patch data in this iteration
+ x = train_x[idx[b * batch_size:(b + 1) * batch_size]]
+ if model.dimension == 4:
+ x = augmentation(x, batch_size)
+ if (image_size != model.input_size):
+ x = resize_dataset(x, model.input_size)
+ x = x.astype(np_dtype[precision])
+ y = train_y[idx[b * batch_size:(b + 1) * batch_size]]
+
+
+ synflow_flag = False
+ # Train the model
+            if epoch == (max_epoch - 1) and b == (num_train_batch - 1): ### synflow calculation for the last batch
+ print ("last epoch calculate synflow")
+ synflow_flag = True
+ ### step 1: all one input
+ # Copy the patch data into input tensors
+ tx.copy_from_numpy(np.ones(x.shape, dtype=np.float32))
+ ty.copy_from_numpy(y)
+ ### step 2: all weights turned to positive (done)
+ ### step 3: new loss (done)
+                pn_p_g_list, out, loss = model(tx, ty, synflow_flag, dist_option, spars)
+ ### step 4: calculate the multiplication of weights
+                synflow_score = 0.0
+                for pn_p_g_item in pn_p_g_list:
+                    print ("calculate weight param * grad parameter name: \n", pn_p_g_item[0])
+                    if len(pn_p_g_item[1].data.shape) == 2: # param_value.data is "weight"
+                        synflow_score += np.sum(np.absolute(tensor.to_numpy(pn_p_g_item[1].data) * tensor.to_numpy(pn_p_g_item[2].data)))
+                print ("synflow_score: \n", synflow_score)
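+                # Illustrative arithmetic for the score above (not from the
+                # original code): per layer, SynFlow is sum(|W * dL/dW|); for
+                # W = [[1., -2.]] and dL/dW = [[3., 4.]] the layer contributes
+                # |1*3| + |-2*4| = 11 to synflow_score.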
+ elif epoch == (max_epoch - 1) and b == (num_train_batch - 2): # all weights turned to positive
+ # Copy the patch data into input tensors
+ tx.copy_from_numpy(x)
+ ty.copy_from_numpy(y)
+                pn_p_g_list, out, loss = model(tx, ty, synflow_flag, dist_option, spars)
+ train_correct += accuracy(tensor.to_numpy(out), y)
+ train_loss += tensor.to_numpy(loss)[0]
+ # all params turned to positive
+ for pn_p_g_item in pn_p_g_list:
+ print ("absolute value parameter name: \n", pn_p_g_item[0])
+ pn_p_g_item[1] = tensor.abs(pn_p_g_item[1]) # tensor variables
+ else: # normal train steps
+ # Copy the patch data into input tensors
+ tx.copy_from_numpy(x)
+ ty.copy_from_numpy(y)
+ pn_p_g_list, out, loss = model(tx, ty, synflow_flag, dist_option, spars)
+ train_correct += accuracy(tensor.to_numpy(out), y)
+ train_loss += tensor.to_numpy(loss)[0]
+
+ if DIST:
+ # Reduce the evaluation accuracy and loss from multiple devices
+ reducer = tensor.Tensor((1,), dev, tensor.float32)
+ train_correct = reduce_variable(train_correct, mssgd, reducer)
+ train_loss = reduce_variable(train_loss, mssgd, reducer)
+
+ if global_rank == 0:
+ print('Training loss = %f, training accuracy = %f' %
+ (train_loss, train_correct /
+ (num_train_batch * batch_size * world_size)),
+ flush=True)
+
+ # Evaluation phase
+ model.eval()
+ for b in range(num_val_batch):
+ x = val_x[b * batch_size:(b + 1) * batch_size]
+ if model.dimension == 4:
+ if (image_size != model.input_size):
+ x = resize_dataset(x, model.input_size)
+ x = x.astype(np_dtype[precision])
+ y = val_y[b * batch_size:(b + 1) * batch_size]
+ tx.copy_from_numpy(x)
+ ty.copy_from_numpy(y)
+ out_test = model(tx)
+ test_correct += accuracy(tensor.to_numpy(out_test), y)
+
+ if DIST:
+            # Reduce the evaluation accuracy from multiple devices
+ test_correct = reduce_variable(test_correct, mssgd, reducer)
+
+ # Output the evaluation accuracy
+ if global_rank == 0:
+ print('Evaluation accuracy = %f, Elapsed Time = %fs' %
+ (test_correct / (num_val_batch * batch_size * world_size),
+ time.time() - start_time),
+ flush=True)
+
+ dev.PrintTimeProfiling()
+
+
+if __name__ == '__main__':
+ # Use argparse to get command config: max_epoch, model, data, etc., for single gpu training
+ parser = argparse.ArgumentParser(
+ description='Training using the autograd and graph.')
+ parser.add_argument(
+ 'model',
+ choices=['cnn', 'resnet', 'xceptionnet', 'mlp', 'msmlp', 'alexnet'],
+ default='cnn')
+ parser.add_argument('data',
+ choices=['mnist', 'cifar10', 'cifar100'],
+ default='mnist')
+ parser.add_argument('-p',
+ choices=['float32', 'float16'],
+ default='float32',
+ dest='precision')
+ parser.add_argument('-m',
+ '--max-epoch',
+ default=100,
+ type=int,
+ help='maximum epochs',
+ dest='max_epoch')
+ parser.add_argument('-b',
+ '--batch-size',
+ default=64,
+ type=int,
+ help='batch size',
+ dest='batch_size')
+ parser.add_argument('-l',
+ '--learning-rate',
+ default=0.005,
+ type=float,
+ help='initial learning rate',
+ dest='lr')
+ # Determine which gpu to use
+ parser.add_argument('-i',
+ '--device-id',
+ default=0,
+ type=int,
+ help='which GPU to use',
+ dest='device_id')
+ parser.add_argument('-g',
+ '--disable-graph',
+                        default=True,
+ action='store_false',
+ help='disable graph',
+ dest='graph')
+ parser.add_argument('-v',
+ '--log-verbosity',
+ default=0,
+ type=int,
+ help='logging verbosity',
+ dest='verbosity')
+
+ args = parser.parse_args()
+
+ mssgd = MSSGD(lr=args.lr, momentum=0.9, weight_decay=1e-5, dtype=singa_dtype[args.precision])
+ run(0,
+ 1,
+ args.device_id,
+ args.max_epoch,
+ args.batch_size,
+ args.model,
+ args.data,
+ mssgd,
+ args.graph,
+ args.verbosity,
+ precision=args.precision)
diff --git a/examples/cnn_ms/train_mpi.py b/examples/cnn_ms/train_mpi.py
new file mode 100644
index 0000000..563d4b2
--- /dev/null
+++ b/examples/cnn_ms/train_mpi.py
@@ -0,0 +1,91 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+from singa import singa_wrap as singa
+from singa import opt
+from singa import tensor
+import argparse
+import train_cnn
+
+singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}
+
+if __name__ == '__main__':
+ # Use argparse to get command config: max_epoch, model, data, etc., for single gpu training
+ parser = argparse.ArgumentParser(
+ description='Training using the autograd and graph.')
+ parser.add_argument('model',
+ choices=['cnn', 'resnet', 'xceptionnet', 'mlp'],
+ default='cnn')
+ parser.add_argument('data', choices=['mnist', 'cifar10', 'cifar100'], default='mnist')
+ parser.add_argument('-p',
+ choices=['float32', 'float16'],
+ default='float32',
+ dest='precision')
+ parser.add_argument('-m',
+ '--max-epoch',
+ default=10,
+ type=int,
+ help='maximum epochs',
+ dest='max_epoch')
+ parser.add_argument('-b',
+ '--batch-size',
+ default=64,
+ type=int,
+ help='batch size',
+ dest='batch_size')
+ parser.add_argument('-l',
+ '--learning-rate',
+ default=0.005,
+ type=float,
+ help='initial learning rate',
+ dest='lr')
+ parser.add_argument('-d',
+ '--dist-option',
+ default='plain',
+ choices=['plain','half','partialUpdate','sparseTopK','sparseThreshold'],
+                    help='distributed training options',
+ dest='dist_option') # currently partialUpdate support graph=False only
+ parser.add_argument('-s',
+ '--sparsification',
+ default='0.05',
+ type=float,
+                    help='the sparsity parameter used for sparsification, between 0 and 1',
+ dest='spars')
+ parser.add_argument('-g',
+ '--disable-graph',
+                    default=True,
+ action='store_false',
+ help='disable graph',
+ dest='graph')
+ parser.add_argument('-v',
+ '--log-verbosity',
+ default=0,
+ type=int,
+ help='logging verbosity',
+ dest='verbosity')
+
+ args = parser.parse_args()
+
+ sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5, dtype=singa_dtype[args.precision])
+ sgd = opt.DistOpt(sgd)
+
+ train_cnn.run(sgd.global_rank, sgd.world_size, sgd.local_rank, args.max_epoch,
+ args.batch_size, args.model, args.data, sgd, args.graph,
+ args.verbosity, args.dist_option, args.spars, args.precision)
diff --git a/examples/cnn_ms/train_multiprocess.py b/examples/cnn_ms/train_multiprocess.py
new file mode 100644
index 0000000..182dd35
--- /dev/null
+++ b/examples/cnn_ms/train_multiprocess.py
@@ -0,0 +1,111 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+from singa import singa_wrap as singa
+from singa import opt
+from singa import tensor
+import argparse
+import train_cnn
+import multiprocessing
+
+singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}
+
+def run(args, local_rank, world_size, nccl_id):
+ sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5, dtype=singa_dtype[args.precision])
+ sgd = opt.DistOpt(sgd, nccl_id=nccl_id, local_rank=local_rank, world_size=world_size)
+ train_cnn.run(sgd.global_rank, sgd.world_size, sgd.local_rank, args.max_epoch,
+ args.batch_size, args.model, args.data, sgd, args.graph,
+ args.verbosity, args.dist_option, args.spars, args.precision)
+
+
+if __name__ == '__main__':
+ # Use argparse to get command config: max_epoch, model, data, etc., for single gpu training
+ parser = argparse.ArgumentParser(
+ description='Training using the autograd and graph.')
+ parser.add_argument('model',
+ choices=['resnet', 'xceptionnet', 'cnn', 'mlp'],
+ default='cnn')
+ parser.add_argument('data', choices=['cifar10', 'cifar100', 'mnist'], default='mnist')
+ parser.add_argument('-p',
+ choices=['float32', 'float16'],
+ default='float32',
+ dest='precision')
+ parser.add_argument('-m',
+ '--max-epoch',
+ default=10,
+ type=int,
+ help='maximum epochs',
+ dest='max_epoch')
+ parser.add_argument('-b',
+ '--batch-size',
+ default=64,
+ type=int,
+ help='batch size',
+ dest='batch_size')
+ parser.add_argument('-l',
+ '--learning-rate',
+ default=0.005,
+ type=float,
+ help='initial learning rate',
+ dest='lr')
+ parser.add_argument('-w',
+ '--world-size',
+ default=2,
+ type=int,
+ help='number of gpus to be used',
+ dest='world_size')
+ parser.add_argument('-d',
+ '--dist-option',
+ default='plain',
+ choices=['plain','half','partialUpdate','sparseTopK','sparseThreshold'],
+                    help='distributed training options',
+ dest='dist_option') # currently partialUpdate support graph=False only
+ parser.add_argument('-s',
+ '--sparsification',
+ default='0.05',
+ type=float,
+                    help='the sparsity parameter used for sparsification, between 0 and 1',
+ dest='spars')
+ parser.add_argument('-g',
+ '--disable-graph',
+                    default=True,
+ action='store_false',
+ help='disable graph',
+ dest='graph')
+ parser.add_argument('-v',
+ '--log-verbosity',
+ default=0,
+ type=int,
+ help='logging verbosity',
+ dest='verbosity')
+
+ args = parser.parse_args()
+
+ # Generate a NCCL ID to be used for collective communication
+ nccl_id = singa.NcclIdHolder()
+
+ process = []
+ for local_rank in range(0, args.world_size):
+ process.append(
+ multiprocessing.Process(target=run,
+ args=(args, local_rank, args.world_size, nccl_id)))
+
+    for p in process:
+        p.start()
+    for p in process:  # wait for every worker process to finish
+        p.join()
diff --git a/examples/hfl/README.md b/examples/hfl/README.md
index cf20e64..2916bc5 100644
--- a/examples/hfl/README.md
+++ b/examples/hfl/README.md
@@ -27,7 +27,7 @@
## Preparation
-Go to the Conda environment that contains the Singa library, and run
+Go to the Conda environment that contains the Singa library, and install the required libraries.
```bash
pip install -r requirements.txt
@@ -41,18 +41,18 @@
# 3. run the following command which:
# (1) splits the dataset into N subsets
# (2) splits each subsets into train set and test set (8:2)
-python -m bank N
+python -m bank 3
```
## Run the example
-Run the server first (set the number of epochs to 3)
+Run the server first (set the maximum number of epochs to 3 with the "-m" parameter)
```bash
python -m src.server -m 3 --num_clients 3
```
-Then, start 3 clients in different terminal
+Then, start 3 clients in different terminals (similarly set the maximum number of epochs to 3)
```bash
python -m src.client --model mlp --data bank -m 3 -i 0 -d non-iid
@@ -60,4 +60,4 @@
python -m src.client --model mlp --data bank -m 3 -i 2 -d non-iid
```
-Finally, the server and clients finish the FL training.
\ No newline at end of file
+Finally, the server and clients finish the FL training.
diff --git a/examples/hfl/src/client.py b/examples/hfl/src/client.py
index 80ab11f..dbff42b 100644
--- a/examples/hfl/src/client.py
+++ b/examples/hfl/src/client.py
@@ -40,6 +40,7 @@
np_dtype = {"float16": np.float16, "float32": np.float32}
singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}
+
class Client:
"""Client sends and receives protobuf messages.
@@ -63,6 +64,7 @@
Args:
global_rank (int, optional): The rank in training process. Defaults to 0.
+ Provided by the '-i' parameter (device_id) in the running script.
host (str, optional): Host ip address. Defaults to '127.0.0.1'.
port (str, optional): Port. Defaults to 1234.
"""
diff --git a/examples/hfl/src/server.py b/examples/hfl/src/server.py
index 7450cc1..68780e1 100644
--- a/examples/hfl/src/server.py
+++ b/examples/hfl/src/server.py
@@ -80,6 +80,7 @@
"""Start pair each client to a global rank"""
for _ in range(self.num_clients):
conn, addr = self.sock.accept()
+ # rank is the global device_id when initializing the client
rank = utils.receive_int(conn)
self.conns[rank] = conn
self.addrs[rank] = addr
diff --git a/examples/model_selection/Trails/README.md b/examples/model_selection/Trails/README.md
index 39bd012..8f14525 100644
--- a/examples/model_selection/Trails/README.md
+++ b/examples/model_selection/Trails/README.md
@@ -23,7 +23,7 @@

-# Build & Run examples
+# Build & Run examples:
## Singa + PostgreSQL
diff --git a/examples/model_slicing_psql/README.md b/examples/model_slicing_psql/README.md
new file mode 100644
index 0000000..bfdbd5e
--- /dev/null
+++ b/examples/model_slicing_psql/README.md
@@ -0,0 +1,22 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+# Dynamic Model Slicing on PostgreSQL
+
+The examples in this folder show how to dynamically slice a model for a subset of database records specified by a corresponding SQL query inside an RDBMS such as PostgreSQL.
\ No newline at end of file
diff --git a/examples/msmlp/model.py b/examples/msmlp/model.py
new file mode 100644
index 0000000..2a4d0e6
--- /dev/null
+++ b/examples/msmlp/model.py
@@ -0,0 +1,202 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from singa import layer
+from singa import model
+from singa import tensor
+from singa import opt
+from singa import device
+from singa.autograd import Operator
+from singa.layer import Layer
+from singa import singa_wrap as singa
+import argparse
+import numpy as np
+
+np_dtype = {"float16": np.float16, "float32": np.float32}
+
+singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}
+
+#### self-defined loss begin
+
+### from autograd.py
+class SumError(Operator):
+
+ def __init__(self):
+ super(SumError, self).__init__()
+ # self.t = t.data
+
+ def forward(self, x):
+ # self.err = singa.__sub__(x, self.t)
+ self.data_x = x
+ # sqr = singa.Square(self.err)
+ # loss = singa.SumAll(sqr)
+ loss = singa.SumAll(x)
+ # self.n = 1
+ # for s in x.shape():
+ # self.n *= s
+ # loss /= self.n
+ return loss
+
+ def backward(self, dy=1.0):
+ # dx = self.err
+ dev = device.get_default_device()
+ dx = tensor.Tensor(self.data_x.shape, dev, singa_dtype['float32'])
+ dx.copy_from_numpy(np.ones(self.data_x.shape))
+ # dx *= float(2 / self.n)
+ dx *= dy
+ return dx
+
+def se_loss(x):
+ # assert x.shape == t.shape, "input and target shape different: %s, %s" % (
+ # x.shape, t.shape)
+ return SumError()(x)[0]
+
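+# Behaviour of the loss above, for illustration (not from the original code):
+# SumError simply sums every element of its input, so for x = [[1., 2.], [3., 4.]]
+# the loss is 10.0 and the gradient w.r.t. x is a tensor of ones (scaled by dy).
+# This is the loss used by MSMLP when synflow_flag is True.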
+### from layer.py
+class SumErrorLayer(Layer):
+ """
+    Generate a SumError operator (sums all elements of the input)
+ """
+
+ def __init__(self):
+ super(SumErrorLayer, self).__init__()
+
+ def forward(self, x):
+ return se_loss(x)
+
+#### self-defined loss end
+
+class MSMLP(model.Model):
+
+ def __init__(self, data_size=10, perceptron_size=100, num_classes=10):
+ super(MSMLP, self).__init__()
+ self.num_classes = num_classes
+ self.dimension = 2
+
+ self.relu = layer.ReLU()
+ self.linear1 = layer.Linear(perceptron_size)
+ self.linear2 = layer.Linear(num_classes)
+ self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
+ self.sum_error = SumErrorLayer()
+
+ def forward(self, inputs):
+ y = self.linear1(inputs)
+ y = self.relu(y)
+ y = self.linear2(y)
+ return y
+
+ def train_one_batch(self, x, y, synflow_flag, dist_option, spars):
+ out = self.forward(x)
+ if synflow_flag:
+ loss = self.sum_error(out)
+ else: # normal training
+ loss = self.softmax_cross_entropy(out, y)
+
+        pn_p_g_list = None  # only the 'plain' option returns [name, param, grad] entries
+        if dist_option == 'plain':
+ pn_p_g_list = self.optimizer(loss)
+ elif dist_option == 'half':
+ self.optimizer.backward_and_update_half(loss)
+ elif dist_option == 'partialUpdate':
+ self.optimizer.backward_and_partial_update(loss)
+ elif dist_option == 'sparseTopK':
+ self.optimizer.backward_and_sparse_update(loss,
+ topK=True,
+ spars=spars)
+ elif dist_option == 'sparseThreshold':
+ self.optimizer.backward_and_sparse_update(loss,
+ topK=False,
+ spars=spars)
+ return pn_p_g_list, out, loss
+
+ def set_optimizer(self, optimizer):
+ self.optimizer = optimizer
+
+
+def create_model(pretrained=False, **kwargs):
+ """Constructs a CNN model.
+
+ Args:
+ pretrained (bool): If True, returns a pre-trained model.
+
+ Returns:
+        The created MLP model.
+ """
+ model = MSMLP(**kwargs)
+
+ return model
+
+
+__all__ = ['MSMLP', 'create_model']
+
+if __name__ == "__main__":
+ np.random.seed(0)
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-p',
+ choices=['float32', 'float16'],
+ default='float32',
+ dest='precision')
+ parser.add_argument('-g',
+ '--disable-graph',
+                        default=True,
+ action='store_false',
+ help='disable graph',
+ dest='graph')
+ parser.add_argument('-m',
+ '--max-epoch',
+ default=1001,
+ type=int,
+ help='maximum epochs',
+ dest='max_epoch')
+ args = parser.parse_args()
+
+ # generate the boundary
+ f = lambda x: (5 * x + 1)
+ bd_x = np.linspace(-1.0, 1, 200)
+ bd_y = f(bd_x)
+
+ # generate the training data
+ x = np.random.uniform(-1, 1, 400)
+ y = f(x) + 2 * np.random.randn(len(x))
+
+ # choose one precision
+ precision = singa_dtype[args.precision]
+ np_precision = np_dtype[args.precision]
+
+ # convert training data to 2d space
+ label = np.asarray([5 * a + 1 > b for (a, b) in zip(x, y)]).astype(np.int32)
+ data = np.array([[a, b] for (a, b) in zip(x, y)], dtype=np_precision)
+
+ dev = device.create_cuda_gpu_on(0)
+ sgd = opt.SGD(0.1, 0.9, 1e-5, dtype=singa_dtype[args.precision])
+ tx = tensor.Tensor((400, 2), dev, precision)
+ ty = tensor.Tensor((400,), dev, tensor.int32)
+    model = MSMLP(data_size=2, perceptron_size=3, num_classes=2)
+
+ # attach model to graph
+ model.set_optimizer(sgd)
+ model.compile([tx], is_train=True, use_graph=args.graph, sequential=True)
+ model.train()
+
+ for i in range(args.max_epoch):
+ tx.copy_from_numpy(data)
+ ty.copy_from_numpy(label)
+        pn_p_g_list, out, loss = model(tx, ty, False, 'plain', spars=None)
+
+ if i % 100 == 0:
+ print("training loss = ", tensor.to_numpy(loss)[0])
diff --git a/examples/msmlp/native.py b/examples/msmlp/native.py
new file mode 100644
index 0000000..a82ec3b
--- /dev/null
+++ b/examples/msmlp/native.py
@@ -0,0 +1,137 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from singa import tensor
+from singa.tensor import Tensor
+from singa import autograd
+from singa import opt
+import numpy as np
+from singa import device
+import argparse
+
+np_dtype = {"float16": np.float16, "float32": np.float32}
+
+singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-p',
+ choices=['float32', 'float16'],
+ default='float32',
+ dest='precision')
+ parser.add_argument('-m',
+ '--max-epoch',
+ default=1001,
+ type=int,
+ help='maximum epochs',
+ dest='max_epoch')
+ args = parser.parse_args()
+
+ np.random.seed(0)
+
+ autograd.training = True
+
+ # prepare training data in numpy array
+
+ # generate the boundary
+ f = lambda x: (5 * x + 1)
+ bd_x = np.linspace(-1.0, 1, 200)
+ bd_y = f(bd_x)
+
+ # generate the training data
+ x = np.random.uniform(-1, 1, 400)
+ y = f(x) + 2 * np.random.randn(len(x))
+
+ # convert training data to 2d space
+ label = np.asarray([5 * a + 1 > b for (a, b) in zip(x, y)])
+ data = np.array([[a, b] for (a, b) in zip(x, y)], dtype=np.float32)
+
+ def to_categorical(y, num_classes):
+ """
+ Converts a class vector (integers) to binary class matrix.
+
+ Args:
+ y: class vector to be converted into a matrix
+ (integers from 0 to num_classes).
+ num_classes: total number of classes.
+
+ Returns:
+ A binary matrix representation of the input.
+ """
+ y = np.array(y, dtype="int")
+ n = y.shape[0]
+ categorical = np.zeros((n, num_classes))
+ categorical[np.arange(n), y] = 1
+ return categorical
+
+ label = to_categorical(label, 2).astype(np.float32)
+ print("train_data_shape:", data.shape)
+ print("train_label_shape:", label.shape)
+
+ precision = singa_dtype[args.precision]
+ np_precision = np_dtype[args.precision]
+
+ dev = device.create_cuda_gpu()
+
+ inputs = Tensor(data=data, device=dev)
+ target = Tensor(data=label, device=dev)
+
+ inputs = inputs.as_type(precision)
+ target = target.as_type(tensor.int32)
+
+ w0_np = np.random.normal(0, 0.1, (2, 3)).astype(np_precision)
+ w0 = Tensor(data=w0_np,
+ device=dev,
+ dtype=precision,
+ requires_grad=True,
+ stores_grad=True)
+ b0 = Tensor(shape=(3,),
+ device=dev,
+ dtype=precision,
+ requires_grad=True,
+ stores_grad=True)
+ b0.set_value(0.0)
+
+ w1_np = np.random.normal(0, 0.1, (3, 2)).astype(np_precision)
+ w1 = Tensor(data=w1_np,
+ device=dev,
+ dtype=precision,
+ requires_grad=True,
+ stores_grad=True)
+ b1 = Tensor(shape=(2,),
+ device=dev,
+ dtype=precision,
+ requires_grad=True,
+ stores_grad=True)
+ b1.set_value(0.0)
+
+ sgd = opt.SGD(0.05, 0.8)
+
+ # training process
+ for i in range(args.max_epoch):
+ x = autograd.matmul(inputs, w0)
+ x = autograd.add_bias(x, b0)
+ x = autograd.relu(x)
+ x = autograd.matmul(x, w1)
+ x = autograd.add_bias(x, b1)
+ loss = autograd.softmax_cross_entropy(x, target)
+ sgd(loss)
+
+ if i % 100 == 0:
+ print("%d, training loss = " % i, tensor.to_numpy(loss)[0])