Merge pull request #816 from apache/master
Update dev branch due to master branch hot-fix of onnx example error
diff --git a/examples/onnx/densenet121.py b/examples/onnx/densenet121.py
index 14cfa5f..7027164 100644
--- a/examples/onnx/densenet121.py
+++ b/examples/onnx/densenet121.py
@@ -25,7 +25,7 @@
from singa import autograd
from singa import sonnx
import onnx
-from utils import download_model, update_batch_size, check_exist_or_download
+from utils import download_model, check_exist_or_download
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)-15s %(message)s')
@@ -44,7 +44,7 @@
return img
-def get_image_labe():
+def get_image_label():
# download label
label_url = 'https://s3.amazonaws.com/onnx-model-zoo/synset.txt'
with open(check_exist_or_download(label_url), 'r') as f:
@@ -56,56 +56,52 @@
return img, labels
-class Infer:
+class MyModel(sonnx.SONNXModel):
- def __init__(self, sg_ir):
- self.sg_ir = sg_ir
- for idx, tens in sg_ir.tensor_map.items():
- # allow the tensors to be updated
- tens.requires_grad = True
- tens.stores_grad = True
- sg_ir.tensor_map[idx] = tens
+ def __init__(self, onnx_model):
+ super(MyModel, self).__init__(onnx_model)
- def forward(self, x):
- return sg_ir.run([x])[0]
+ def forward(self, *x):
+ y = super(MyModel, self).forward(*x)
+ return y[0]
+
+ def train_one_batch(self, x, y):
+ pass
if __name__ == "__main__":
+
+ download_dir = '/tmp'
url = 'https://s3.amazonaws.com/download.onnx/models/opset_9/densenet121.tar.gz'
- download_dir = '/tmp/'
model_path = os.path.join(download_dir, 'densenet121', 'model.onnx')
logging.info("onnx load model...")
download_model(url)
onnx_model = onnx.load(model_path)
- # set batch size
- onnx_model = update_batch_size(onnx_model, 1)
+ # inference demo
+ logging.info("preprocessing...")
+ img, labels = get_image_label()
+ img = preprocess(img)
+ # sg_ir = sonnx.prepare(onnx_model) # run without graph
+ # y = sg_ir.run([img])
- # prepare the model
- logging.info("prepare model...")
+    logging.info("model compiling...")
dev = device.create_cuda_gpu()
- sg_ir = sonnx.prepare(onnx_model, device=dev)
- autograd.training = False
- model = Infer(sg_ir)
+ x = tensor.Tensor(device=dev, data=img)
+ model = MyModel(onnx_model)
+ model.compile([x], is_train=False, use_graph=True, sequential=True)
# verifty the test
# from utils import load_dataset
- # inputs, ref_outputs = load_dataset(
- # os.path.join('/tmp', 'densenet121', 'test_data_set_0'))
+ # inputs, ref_outputs = load_dataset(os.path.join('/tmp', 'densenet121', 'test_data_set_0'))
# x_batch = tensor.Tensor(device=dev, data=inputs[0])
- # outputs = model.forward(x_batch)
+ # outputs = sg_ir.run([x_batch])
# for ref_o, o in zip(ref_outputs, outputs):
# np.testing.assert_almost_equal(ref_o, tensor.to_numpy(o), 4)
- # inference
- logging.info("preprocessing...")
- img, labels = get_image_labe()
- img = preprocess(img)
-
logging.info("model running...")
- x_batch = tensor.Tensor(device=dev, data=img)
- y = model.forward(x_batch)
+ y = model.forward(x)
logging.info("postprocessing...")
y = tensor.softmax(y)
@@ -113,4 +109,4 @@
scores = np.squeeze(scores)
a = np.argsort(scores)[::-1]
for i in a[0:5]:
- logging.info('class=%s ; probability=%f' % (labels[i], scores[i]))
+ logging.info('class=%s ; probability=%f' % (labels[i], scores[i]))
\ No newline at end of file
diff --git a/examples/onnx/shufflenetv1.py b/examples/onnx/shufflenetv1.py
index 35fa777..2ace96e 100644
--- a/examples/onnx/shufflenetv1.py
+++ b/examples/onnx/shufflenetv1.py
@@ -26,9 +26,7 @@
from singa import autograd
from singa import sonnx
import onnx
-from utils import download_model
-from utils import update_batch_size
-from utils import check_exist_or_download
+from utils import download_model, check_exist_or_download
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
@@ -56,51 +54,53 @@
return img, labels
-class Infer:
+class MyModel(sonnx.SONNXModel):
- def __init__(self, sg_ir):
- self.sg_ir = sg_ir
- for idx, tens in sg_ir.tensor_map.items():
- tens.require_grad = True
- tens.store_grad = True
- sg_ir.tensor_map[idx] = tens
+ def __init__(self, onnx_model):
+ super(MyModel, self).__init__(onnx_model)
- def forward(self, x):
- return sg_ir.run([x])[0]
+ def forward(self, *x):
+ y = super(MyModel, self).forward(*x)
+ return y[0]
+
+ def train_one_batch(self, x, y):
+ pass
if __name__ == '__main__':
+ download_dir = '/tmp'
url = 'https://github.com/onnx/models/raw/master/vision/classification/shufflenet/model/shufflenet-9.tar.gz'
- download_dir = "/tmp/"
model_path = os.path.join(download_dir, 'shufflenet', 'model.onnx')
- logging.info("onnx load model....")
+
+ logging.info("onnx load model...")
download_model(url)
onnx_model = onnx.load(model_path)
- # setting batch size
- onnx_model = update_batch_size(onnx_model, 1)
- # preparing the model
- logging.info("preparing model...")
- dev = device.create_cuda_gpu()
- sg_ir = sonnx.prepare(onnx_model, device=dev)
- autograd.training = False
- model = Infer(sg_ir)
-
- # verifying the test dataset
- #from utils import load_dataset
- #inputs,ref_outputs = load_dataset(os.path.join('/tmp','shufflenet','test_data_set_0'))
- #x_batch = tensor.Tensor(device = dev,data=inputs[0])
- #outputs = model.forward(x_batch)
- # for ref_o,o in zip(ref_outputs,outputs):
- # np.testing.assert_almost_equal(ref_o,tensor.to_numpy(o),4)
-
- # inference
+
+ # inference demo
logging.info("preprocessing...")
img, labels = get_image_label()
img = preprocess(img)
- x_batch = tensor.Tensor(device=dev, data=img)
- logging.info("model running....")
- y = model.forward(x_batch)
- logging.info("postprocessing....")
+ # sg_ir = sonnx.prepare(onnx_model) # run without graph
+ # y = sg_ir.run([img])
+
+    logging.info("model compiling...")
+ dev = device.create_cuda_gpu()
+ x = tensor.Tensor(device=dev, data=img)
+ model = MyModel(onnx_model)
+ model.compile([x], is_train=False, use_graph=True, sequential=True)
+
+    # verify the test
+ # from utils import load_dataset
+ # inputs, ref_outputs = load_dataset(os.path.join('/tmp', 'shufflenet', 'test_data_set_0'))
+ # x_batch = tensor.Tensor(device=dev, data=inputs[0])
+ # outputs = sg_ir.run([x_batch])
+ # for ref_o, o in zip(ref_outputs, outputs):
+ # np.testing.assert_almost_equal(ref_o, tensor.to_numpy(o), 4)
+
+ logging.info("model running...")
+ y = model.forward(x)
+
+ logging.info("postprocessing...")
y = tensor.softmax(y)
scores = tensor.to_numpy(y)
scores = np.squeeze(scores)
diff --git a/examples/onnx/shufflenetv2.py b/examples/onnx/shufflenetv2.py
index 74dd794..60f84a4 100644
--- a/examples/onnx/shufflenetv2.py
+++ b/examples/onnx/shufflenetv2.py
@@ -43,7 +43,7 @@
return img
-def get_image_labe():
+def get_image_label():
# download label
label_url = 'https://s3.amazonaws.com/onnx-model-zoo/synset.txt'
with open(check_exist_or_download(label_url), 'r') as f:
@@ -81,7 +81,7 @@
# inference
logging.info("preprocessing...")
- img, labels = get_image_labe()
+ img, labels = get_image_label()
img = preprocess(img)
# sg_ir = sonnx.prepare(onnx_model) # run without graph
# y = sg_ir.run([img])
diff --git a/examples/onnx/squeezenet.py b/examples/onnx/squeezenet.py
index 8a6ecf5..09d7708 100644
--- a/examples/onnx/squeezenet.py
+++ b/examples/onnx/squeezenet.py
@@ -25,7 +25,7 @@
from singa import autograd
from singa import sonnx
import onnx
-from utils import download_model, update_batch_size, check_exist_or_download
+from utils import download_model, check_exist_or_download
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)-15s %(message)s')
@@ -56,24 +56,32 @@
return img, labels
-class Infer:
-
- def __init__(self, sg_ir):
- self.sg_ir = sg_ir
- for idx, tens in sg_ir.tensor_map.items():
- # allow the tensors to be updated
- tens.requires_grad = True
- tens.stores_grad = True
- sg_ir.tensor_map[idx] = tens
-
- def forward(self, x):
- return sg_ir.run([x])[0]
+def get_image_label():
+ # download label
+ label_url = 'https://s3.amazonaws.com/onnx-model-zoo/synset.txt'
+ with open(check_exist_or_download(label_url), 'r') as f:
+ labels = [l.rstrip() for l in f]
+ image_url = 'https://s3.amazonaws.com/model-server/inputs/kitten.jpg'
+ img = Image.open(check_exist_or_download(image_url))
+ return img, labels
-if __name__ == "__main__":
+class MyModel(sonnx.SONNXModel):
+ def __init__(self, onnx_model):
+ super(MyModel, self).__init__(onnx_model)
+
+ def forward(self, *x):
+ y = super(MyModel, self).forward(*x)
+ return y[0]
+
+ def train_one_batch(self, x, y):
+ pass
+
+
+if __name__ == '__main__':
+ download_dir = '/tmp'
url = 'https://github.com/onnx/models/raw/master/vision/classification/squeezenet/model/squeezenet1.1-7.tar.gz'
- download_dir = '/tmp/'
model_path = os.path.join(download_dir, 'squeezenet1.1',
'squeezenet1.1.onnx')
@@ -81,33 +89,29 @@
download_model(url)
onnx_model = onnx.load(model_path)
- # set batch size
- onnx_model = update_batch_size(onnx_model, 1)
-
- # prepare the model
- logging.info("prepare model...")
- dev = device.create_cuda_gpu()
- sg_ir = sonnx.prepare(onnx_model, device=dev)
- autograd.training = False
- model = Infer(sg_ir)
-
- # verify the test
- # from utils import load_dataset
- # inputs, ref_outputs = load_dataset(
- # os.path.join('/tmp', 'squeezenet1.1', 'test_data_set_0'))
- # x_batch = tensor.Tensor(device=dev, data=inputs[0])
- # outputs = model.forward(x_batch)
- # for ref_o, o in zip(ref_outputs, outputs):
- # np.testing.assert_almost_equal(ref_o, tensor.to_numpy(o), 4)
-
- # inference
+ # inference demo
logging.info("preprocessing...")
img, labels = get_image_label()
img = preprocess(img)
+ # sg_ir = sonnx.prepare(onnx_model) # run without graph
+ # y = sg_ir.run([img])
+
+    logging.info("model compiling...")
+ dev = device.create_cuda_gpu()
+ x = tensor.Tensor(device=dev, data=img)
+ model = MyModel(onnx_model)
+ model.compile([x], is_train=False, use_graph=True, sequential=True)
+
+    # verify the test
+ # from utils import load_dataset
+ # inputs, ref_outputs = load_dataset(os.path.join('/tmp', 'squeezenet1.1', 'test_data_set_0'))
+ # x_batch = tensor.Tensor(device=dev, data=inputs[0])
+ # outputs = sg_ir.run([x_batch])
+ # for ref_o, o in zip(ref_outputs, outputs):
+ # np.testing.assert_almost_equal(ref_o, tensor.to_numpy(o), 4)
logging.info("model running...")
- x_batch = tensor.Tensor(device=dev, data=img)
- y = model.forward(x_batch)
+ y = model.forward(x)
logging.info("postprocessing...")
y = tensor.softmax(y)
diff --git a/examples/onnx/vgg19.py b/examples/onnx/vgg19.py
index 49606cb..a2c3ea7 100644
--- a/examples/onnx/vgg19.py
+++ b/examples/onnx/vgg19.py
@@ -22,10 +22,9 @@
from singa import device
from singa import tensor
-from singa import autograd
from singa import sonnx
import onnx
-from utils import download_model, update_batch_size, check_exist_or_download
+from utils import download_model, check_exist_or_download
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)-15s %(message)s')
@@ -44,7 +43,7 @@
return img
-def get_image_label():
+def get_image_labe():
# download label
label_url = 'https://s3.amazonaws.com/onnx-model-zoo/synset.txt'
with open(check_exist_or_download(label_url), 'r') as f:
@@ -56,22 +55,21 @@
return img, labels
-class Infer:
+class MyModel(sonnx.SONNXModel):
- def __init__(self, sg_ir):
- self.sg_ir = sg_ir
- for idx, tens in sg_ir.tensor_map.items():
- # allow the tensors to be updated
- tens.requires_grad = True
- tens.stores_grad = True
- sg_ir.tensor_map[idx] = tens
+ def __init__(self, onnx_model):
+ super(MyModel, self).__init__(onnx_model)
- def forward(self, x):
- return sg_ir.run([x])[0]
+ def forward(self, *x):
+ y = super(MyModel, self).forward(*x)
+ return y[0]
+
+ def train_one_batch(self, x, y):
+ pass
if __name__ == "__main__":
- url = 'https://github.com/onnx/models/raw/master/vision/classification/vgg/model/vgg19-7.tar.gz'
+ url = 'https://s3.amazonaws.com/onnx-model-zoo/vgg/vgg19/vgg19.tar.gz'
download_dir = '/tmp/'
model_path = os.path.join(download_dir, 'vgg19', 'vgg19.onnx')
@@ -79,33 +77,30 @@
download_model(url)
onnx_model = onnx.load(model_path)
- # set batch size
- onnx_model = update_batch_size(onnx_model, 1)
+ # inference
+ logging.info("preprocessing...")
+ img, labels = get_image_labe()
+ img = preprocess(img)
+ # sg_ir = sonnx.prepare(onnx_model) # run without graph
+ # y = sg_ir.run([img])
- # prepare the model
- logging.info("prepare model...")
- # dev = device.get_default_device()
+    logging.info("model compiling...")
dev = device.create_cuda_gpu()
- sg_ir = sonnx.prepare(onnx_model, device=dev)
- autograd.training = False
- model = Infer(sg_ir)
+ x = tensor.PlaceHolder(img.shape, device=dev)
+ model = MyModel(onnx_model)
+ model.compile([x], is_train=False, use_graph=True, sequential=True)
- # verify the test
+    # verify the test
# from utils import load_dataset
# inputs, ref_outputs = load_dataset(os.path.join('/tmp', 'vgg19', 'test_data_set_0'))
# x_batch = tensor.Tensor(device=dev, data=inputs[0])
- # outputs = model.forward(x_batch)
+ # outputs = sg_ir.run([x_batch])
# for ref_o, o in zip(ref_outputs, outputs):
# np.testing.assert_almost_equal(ref_o, tensor.to_numpy(o), 4)
- # inference
- logging.info("preprocessing...")
- img, labels = get_image_label()
- img = preprocess(img)
-
logging.info("model running...")
- x_batch = tensor.Tensor(device=dev, data=img)
- y = model.forward(x_batch)
+ x = tensor.Tensor(device=dev, data=img)
+ y = model.forward(x)
logging.info("postprocessing...")
y = tensor.softmax(y)