id: version-3.1.0-onnx
title: ONNX
original_id: onnx

ONNX is an open representation format for machine learning models, which enables AI developers to use models across different libraries and tools. SINGA supports loading ONNX format models for training and inference, and saving models defined using SINGA APIs (e.g., Module) into ONNX format.

SINGA has been tested with the following version of ONNX.

| ONNX version | File format version | Opset version ai.onnx | Opset version ai.onnx.ml | Opset version ai.onnx.training |
| ------------ | ------------------- | --------------------- | ------------------------ | ------------------------------ |
| 1.6.0        | 6                   | 11                    | 2                        | -                              |
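
As a quick sanity check against the tested configuration above, you can print the version of the installed onnx package (a minimal sketch):

import onnx

# the configuration tested with SINGA is ONNX 1.6.0 (opset 11)
print(onnx.__version__)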

General usage

Loading an ONNX Model into SINGA

After loading an ONNX model from disk with onnx.load, you only need to update the batch size of the input using tensor.PlaceHolder (since SINGA v3.0); the shapes of the internal tensors will be inferred automatically.

Then, you should define a class inheriting from sonnx.SONNXModel and implement two methods: forward for the forward pass and train_one_batch for training. After you call model.compile, SONNX iterates over and translates all the nodes within the ONNX model's graph into SINGA operators, loads all stored weights, and infers the shape of each intermediate tensor.

import onnx
from singa import device
from singa import sonnx
from singa import tensor

class MyModel(sonnx.SONNXModel):

    def __init__(self, onnx_model):
        super(MyModel, self).__init__(onnx_model)

    def forward(self, *x):
        y = super(MyModel, self).forward(*x)
        # Since SINGA model returns the output as a list,
        # if there is only one output,
        # you just need to take the first element.
        return y[0]

    def train_one_batch(self, x, y):
        pass

model_path = "PATH/To/ONNX/MODEL"
onnx_model = onnx.load(model_path)

# convert onnx model into SINGA model
dev = device.create_cuda_gpu()
# INPUT stands for your input data (e.g., a numpy array); only its shape is needed here
x = tensor.PlaceHolder(INPUT.shape, device=dev)
model = MyModel(onnx_model)
model.compile([x], is_train=False, use_graph=True, sequential=True)

Inference with the SINGA model

Once the model is created, you can do inference by calling model.forward. The input and output must be SINGA Tensor instances.

x = tensor.Tensor(device=dev, data=INPUT)
y = model.forward(x)
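
For example, a minimal sketch assuming the model expects a batch of 224x224 RGB images (the shape here is only an illustration; use the shape your ONNX model actually expects):

import numpy as np
from singa import tensor

# random data standing in for a real input batch
data = np.random.randn(1, 3, 224, 224).astype(np.float32)
x = tensor.Tensor(device=dev, data=data)
y = model.forward(x)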

Saving SINGA model into ONNX Format

Given the input tensors and the output tensors generated by the operators of the model, you can trace back all internal operations. Therefore, a SINGA model is defined by its input and output tensors. To export a SINGA model into ONNX format, you just need to provide the input and output tensor lists.

# x is the input tensor, y is the output tensor
sonnx.to_onnx([x], [y])
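
to_onnx returns the ONNX model object; assuming it is a standard onnx ModelProto, it can then be written to disk with the onnx package (the file name is a placeholder):

import onnx

# assumption: to_onnx returns a standard onnx ModelProto
onnx_model = sonnx.to_onnx([x], [y])
onnx.save(onnx_model, "singa_model.onnx")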

Re-training an ONNX model

To train (or refine) an ONNX model using SINGA, you need to implement train_one_batch of sonnx.SONNXModel and set is_train=True when calling model.compile.

from singa import opt
from singa import autograd
from singa import sonnx

class MyModel(sonnx.SONNXModel):

    def __init__(self, onnx_model):
        super(MyModel, self).__init__(onnx_model)

    def forward(self, *x):
        y = super(MyModel, self).forward(*x)
        return y[0]

    def train_one_batch(self, x, y, dist_option, spars):
        out = self.forward(x)
        loss = autograd.softmax_cross_entropy(out, y)
        if dist_option == 'fp32':
            self.optimizer.backward_and_update(loss)
        elif dist_option == 'fp16':
            self.optimizer.backward_and_update_half(loss)
        elif dist_option == 'partialUpdate':
            self.optimizer.backward_and_partial_update(loss)
        elif dist_option == 'sparseTopK':
            self.optimizer.backward_and_sparse_update(loss,
                                                      topK=True,
                                                      spars=spars)
        elif dist_option == 'sparseThreshold':
            self.optimizer.backward_and_sparse_update(loss,
                                                      topK=False,
                                                      spars=spars)
        return out, loss

    def set_optimizer(self, optimizer):
        self.optimizer = optimizer

# model is an instance of MyModel created from the loaded onnx_model,
# and tx is the placeholder tensor for the input batch
sgd = opt.SGD(lr=0.005, momentum=0.9, weight_decay=1e-5)
model.set_optimizer(sgd)
model.compile([tx], is_train=True, use_graph=graph, sequential=True)
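
A rough training-loop sketch follows, assuming x_batch and y_batch are numpy arrays holding one mini-batch (these names and num_epochs are illustrative only, not part of the SINGA API):

from singa import tensor

model.train()
for epoch in range(num_epochs):
    # copy one mini-batch into the input placeholder and build the label tensor
    tx.copy_from_numpy(x_batch)
    ty = tensor.from_numpy(y_batch)
    ty.to_device(dev)
    # 'fp32' selects the plain backward_and_update path defined above
    out, loss = model.train_one_batch(tx, ty, 'fp32', None)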

Transfer-learning an ONNX model

You can also append some layers to the end of the ONNX model to do transfer learning. last_layers accepts a negative integer indicating where to cut off the model. For example, -1 means cutting off after the final output (i.e., not cutting off any layer), and -2 means cutting off after the second-to-last layer.

from singa import opt
from singa import autograd
from singa import layer
from singa import sonnx

class MyModel(sonnx.SONNXModel):

    def __init__(self, onnx_model):
        super(MyModel, self).__init__(onnx_model)
        self.linear = layer.Linear(1000, 3)

    def forward(self, *x):
        # cut off the model after the third-to-last layer
        # and append a linear layer
        y = super(MyModel, self).forward(*x, last_layers=-3)[0]
        y = self.linear(y)
        return y

    def train_one_batch(self, x, y, dist_option, spars):
        out = self.forward(x)
        loss = autograd.softmax_cross_entropy(out, y)
        if dist_option == 'fp32':
            self.optimizer.backward_and_update(loss)
        elif dist_option == 'fp16':
            self.optimizer.backward_and_update_half(loss)
        elif dist_option == 'partialUpdate':
            self.optimizer.backward_and_partial_update(loss)
        elif dist_option == 'sparseTopK':
            self.optimizer.backward_and_sparse_update(loss,
                                                      topK=True,
                                                      spars=spars)
        elif dist_option == 'sparseThreshold':
            self.optimizer.backward_and_sparse_update(loss,
                                                      topK=False,
                                                      spars=spars)
        return out, loss

    def set_optimizer(self, optimizer):
        self.optimizer = optimizer

sgd = opt.SGD(lr=0.005, momentum=0.9, weight_decay=1e-5)
model.set_optimizer(sgd)
model.compile([tx], is_train=True, use_graph=graph, sequential=True)

ONNX model zoo

The ONNX Model Zoo is a collection of pre-trained, state-of-the-art models in the ONNX format contributed by community members. SINGA currently supports several CV and NLP models; more models will be supported soon.

Image Classification

These models take images as input and classify the major objects in the images into 1000 object categories, such as keyboard, mouse, pencil, and many animals.

| Model Class | Reference | Description | Link |
| ----------- | --------- | ----------- | ---- |
| MobileNet | Sandler et al. | Light-weight deep neural network best suited for mobile and embedded vision applications. Top-5 error from paper - ~10% | Open In Colab |
| ResNet18 | He et al. | A CNN model (up to 152 layers) that uses shortcut connections to achieve higher accuracy when classifying images. Top-5 error from paper - ~3.6% | Open In Colab |
| VGG16 | Simonyan et al. | Deep CNN model (up to 19 layers). Similar to AlexNet but uses multiple smaller kernel-sized filters that provide higher accuracy when classifying images. Top-5 error from paper - ~8% | Open In Colab |
| ShuffleNet_V2 | Ma et al. | Extremely computation-efficient CNN model designed specifically for mobile devices. The network architecture design considers direct metrics such as speed, instead of indirect metrics like FLOPs. Top-1 error from paper - ~30.6% | Open In Colab |

We also provide some re-training examples using VGG and ResNet; please check examples/onnx/training.

Object Detection

Object detection models detect the presence of multiple objects in an image and segment out areas of the image where the objects are detected.

| Model Class | Reference | Description | Link |
| ----------- | --------- | ----------- | ---- |
| Tiny YOLOv2 | Redmon et al. | A real-time CNN for object detection that detects 20 different classes. A smaller version of the more complex full YOLOv2 network. | Open In Colab |

Face Analysis

Face detection models identify and/or recognize human faces and emotions in given images.

| Model Class | Reference | Description | Link |
| ----------- | --------- | ----------- | ---- |
| ArcFace | Deng et al. | A CNN-based model for face recognition that learns discriminative features of faces and produces embeddings for input face images. | Open In Colab |
| Emotion FerPlus | Barsoum et al. | Deep CNN for emotion recognition trained on images of faces. | Open In Colab |

Machine Comprehension

This subset of natural language processing models answers questions about a given context paragraph.

| Model Class | Reference | Description | Link |
| ----------- | --------- | ----------- | ---- |
| BERT-Squad | Devlin et al. | This model answers questions based on the context of the given input paragraph. | Open In Colab |
| RoBERTa | Liu et al. | A large transformer-based model that predicts sentiment based on the given input text. | Open In Colab |
| GPT-2 | Radford et al. | A large transformer-based language model that, given a sequence of words within some text, predicts the next word. | Open In Colab |

Supported operators

The following operators are supported:

  • Acos
  • Acosh
  • Add
  • And
  • Asin
  • Asinh
  • Atan
  • Atanh
  • AveragePool
  • BatchNormalization
  • Cast
  • Ceil
  • Clip
  • Concat
  • ConstantOfShape
  • Conv
  • Cos
  • Cosh
  • Div
  • Dropout
  • Elu
  • Equal
  • Erf
  • Expand
  • Flatten
  • Gather
  • Gemm
  • GlobalAveragePool
  • Greater
  • HardSigmoid
  • Identity
  • LeakyRelu
  • Less
  • Log
  • MatMul
  • Max
  • MaxPool
  • Mean
  • Min
  • Mul
  • Neg
  • NonZero
  • Not
  • OneHot
  • Or
  • Pad
  • Pow
  • PRelu
  • Reciprocal
  • ReduceMean
  • ReduceSum
  • Relu
  • Reshape
  • ScatterElements
  • Selu
  • Shape
  • Sigmoid
  • Sign
  • Sin
  • Sinh
  • Slice
  • Softmax
  • Softplus
  • Softsign
  • Split
  • Sqrt
  • Squeeze
  • Sub
  • Sum
  • Tan
  • Tanh
  • Tile
  • Transpose
  • Unsqueeze
  • Upsample
  • Where
  • Xor

Special comments for ONNX backend

  • Conv, MaxPool and AveragePool

    The input must be 1d (N*C*H) or 2d (N*C*H*W) in shape, and the dilation must be 1.

  • BatchNormalization

    epsilon is 1e-05 and cannot be changed.

  • Cast

    Only float32 and int32 are supported; other types are cast to these two types.

  • Squeeze and Unsqueeze

    If you encounter errors when you Squeeze or Unsqueeze between a Tensor and a Scalar, please report them to us.

  • Empty tensor

    An empty tensor is illegal in SINGA.

Implementation

The code of SINGA ONNX is located at python/singa/sonnx.py. There are four main classes: SingaFrontend, SingaBackend, SingaRep and SONNXModel. SingaFrontend translates a SINGA model into an ONNX model; SingaBackend translates an ONNX model into a SingaRep object which stores all SINGA operators and tensors (the tensors in this doc mean SINGA Tensor); SingaRep can be run like a SINGA model. SONNXModel inherits from model.Model and defines a unified API for SINGA.

SingaFrontend

The entry function of SingaFrontend is singa_to_onnx_model, which is also exposed as to_onnx. singa_to_onnx_model creates the ONNX model, and it also creates an ONNX graph by using singa_to_onnx_graph.

singa_to_onnx_graph accepts the outputs of the model and recursively iterates the SINGA model's graph from the outputs to collect all operators into a queue. The input and intermediate tensors, i.e., the trainable weights, of the SINGA model are picked up at the same time. The inputs are stored in onnx_model.graph.input; the outputs are stored in onnx_model.graph.output; and the trainable weights are stored in onnx_model.graph.initializer.
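
For reference, these fields can be inspected directly with the onnx Python package (a minimal sketch; "model.onnx" is a placeholder path):

import onnx

m = onnx.load("model.onnx")
print([i.name for i in m.graph.input])        # graph inputs
print([o.name for o in m.graph.output])       # graph outputs
print([w.name for w in m.graph.initializer])  # stored (trainable) weights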

Then the SINGA operators in the queue are translated into ONNX operators one by one. _rename_operators defines the operator name mapping between SINGA and ONNX. _special_operators defines which function is used to translate each special operator.

In addition, some operators in SINGA have definitions different from ONNX; that is, ONNX regards some attributes of SINGA operators as inputs, so _unhandled_operators defines which function handles each such special operator.

Since the bool type is regarded as int32 in SINGA, _bool_operators defines the operators whose type should be changed to bool.

SingaBackend

The entry function of SingaBackend is prepare, which checks the version of the ONNX model and then calls _onnx_model_to_singa_ops.

The purpose of _onnx_model_to_singa_ops is to get the SINGA tensors and operators. The tensors are stored in a dictionary keyed by their ONNX names, and the operators are stored in a queue in the form of namedtuple('SingaOps', ['node', 'operator']). For each operator, node is an instance of OnnxNode, which is defined to store some basic information of an ONNX node; operator is the SINGA operator's forward function.

_onnx_model_to_singa_ops proceeds in several steps. First, it calls _parse_graph_params to get all tensors, stored as params. Then it calls _parse_graph_inputs_outputs to get all input and output information, stored as inputs and outputs. Finally, it iterates over all nodes within the ONNX graph and parses each of them by _onnx_node_to_singa_op into SINGA operators or layers, stored as outputs. Some weights are stored within an ONNX node called Constant; SONNX handles them by _onnx_constant_to_np and stores them into params.

This class finally returns a SingaRep object storing the above params, inputs, outputs, and layers.

SingaRep

SingaRep stores all SINGA tensors and operators. run accepts the inputs of the model and runs the SINGA operators one by one following the operator queue. The user can use last_layers to cut off the model after the last few layers.
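
A rough usage sketch, assuming SingaBackend.prepare is exposed as sonnx.prepare and returns a SingaRep (the exact entry point and signature are assumptions based on earlier SINGA releases, not confirmed for this version):

from singa import device, sonnx

dev = device.create_cuda_gpu()
# assumed API: prepare translates the loaded ONNX model into a SingaRep bound to a device
sg_ir = sonnx.prepare(onnx_model, device=dev)
# run executes the stored SINGA operators in order; x is a SINGA Tensor input
outputs = sg_ir.run([x])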

SONNXModel

SONNXModel inherits from model.Model and implements the method forward to provide a unified API with other SINGA models.