PyTorch vs Apache MXNet

PyTorch is a popular deep learning framework thanks to its easy-to-understand API and its fully imperative approach. Apache MXNet includes the Gluon API, which gives you the simplicity and flexibility of PyTorch and also lets you hybridize your network to take advantage of the performance optimizations of the symbolic graph. As of April 2019, NVIDIA performance benchmarks show that Apache MXNet outperforms PyTorch by ~77% on training ResNet-50: 10,925 images per second vs. 6,175.

In the next 10 minutes, we'll do a quick comparison between the two frameworks and show how small the learning curve can be when switching from PyTorch to Apache MXNet.

Installation

PyTorch uses conda for installation by default, for example:

# !conda install pytorch torchvision cpuonly -c pytorch

For MXNet we use pip:

# !pip install mxnet

To install Apache MXNet with GPU support, you need to specify the CUDA version in the package name. For example, the snippet below installs Apache MXNet with CUDA 10.2 support:

# !pip install mxnet-cu102

Data manipulation

Both PyTorch and Apache MXNet rely on multidimensional arrays as their core data structure. While PyTorch follows Torch's naming convention and refers to multidimensional arrays as “tensors”, Apache MXNet follows NumPy's conventions and refers to them as “NDArrays”.

In the code snippets below, we create a two-dimensional matrix with every element initialized to 1, then add 1 to each element and print the result.

PyTorch:

import torch

x = torch.ones(5,3)
y = x + 1
y

MXNet:

from mxnet import np

x = np.ones((5,3))
y = x + 1
y

Apart from the package name, the main difference is that MXNet's shape argument must be passed as a tuple, as in NumPy.

Both frameworks support multiple functions to create and manipulate tensors / NDArrays. You can find more of them in the documentation.

Model training

After covering the basics of data creation and manipulation, let's dive deeper and compare how model training is done in both frameworks. To do so, we are going to solve an image classification task on the MNIST data set using a multilayer perceptron (MLP) in both frameworks. We divide the task into four steps.

1. Read data

The first step is to obtain the data. We download the MNIST data set from the web and load it into memory so that we can read batches one by one.

PyTorch:

from torchvision import datasets, transforms

trans = transforms.Compose([transforms.ToTensor(),
                            transforms.Normalize((0.13,), (0.31,))])
pt_train_data = torch.utils.data.DataLoader(datasets.MNIST(
    root='.', train=True, download=True, transform=trans),
    batch_size=128, shuffle=True, num_workers=4)

MXNet:

from mxnet import gluon
from mxnet.gluon.data.vision import datasets, transforms

trans = transforms.Compose([transforms.ToTensor(),
                            transforms.Normalize(0.13, 0.31)])
mx_train_data = gluon.data.DataLoader(
    datasets.MNIST(train=True).transform_first(trans),
    batch_size=128, shuffle=True, num_workers=4)

Both frameworks let you download the MNIST data set from their own sources and specify that only the training part of the data set is required.

The main difference between the code snippets is that MXNet uses the transform_first method to indicate that the transformation is applied only to the first element of each sample, the MNIST image, rather than to the second element, the label.
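To make the idea concrete, here is a minimal pure-Python sketch of what transform_first does conceptually: each sample is an (image, label) pair and the function is applied to the first element only. The helper transform_first_sketch and the toy scaling function are hypothetical illustrations, not MXNet's implementation (which transforms lazily, sample by sample).

```python
def transform_first_sketch(samples, fn):
    """Apply fn to the first element of each (data, label) sample only.

    Hypothetical stand-in for Dataset.transform_first; it returns a new
    list instead of a lazily transformed dataset.
    """
    return [(fn(data), label) for data, label in samples]


# Toy "images": lists of pixel intensities in [0, 255], paired with labels.
samples = [([0, 255], 7), ([51, 102], 3)]


def scale_to_unit(pixels):
    # Mimics the scaling part of ToTensor(): map [0, 255] to [0, 1].
    return [p / 255 for p in pixels]


transformed = transform_first_sketch(samples, scale_to_unit)
print(transformed)  # labels 7 and 3 are passed through unchanged
```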

2. Creating the model

Below we define a Multilayer Perceptron (MLP) with a single hidden layer and 10 units in the output layer.

PyTorch:

import torch.nn as pt_nn

pt_net = pt_nn.Sequential(
    pt_nn.Linear(28*28, 256),
    pt_nn.ReLU(),
    pt_nn.Linear(256, 10))

MXNet:

import mxnet.gluon.nn as mx_nn

mx_net = mx_nn.Sequential()
mx_net.add(mx_nn.Dense(256, activation='relu'),
           mx_nn.Dense(10))
mx_net.initialize()

We used the Sequential container to stack layers one after the other in order to construct the neural network. Apache MXNet differs from PyTorch in the following ways:

  • In PyTorch you have to specify the input size as the first argument of the Linear object. Apache MXNet gives the network structure extra flexibility by inferring the input size automatically during the first forward pass.

  • In Apache MXNet you can specify activation functions directly in fully connected and convolutional layers.

  • After the model structure is defined, Apache MXNet requires you to explicitly call the model initialization function.
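The deferred shape inference from the first bullet can be sketched in NumPy: the layer stores only its number of output units and creates its weight the first time it sees an input. The class LazyDense and its small uniform initialization are hypothetical choices for illustration; Gluon's real Dense layer differs in many details.

```python
import numpy as np


class LazyDense:
    """Hypothetical dense layer that infers in_units on the first call."""

    def __init__(self, units, seed=0):
        self.units = units
        self.weight = None  # shape unknown until the first forward pass
        self.bias = np.zeros(units)
        self._rng = np.random.default_rng(seed)

    def __call__(self, x):
        # Flatten everything except the batch axis, as Gluon's Dense does by default.
        x = x.reshape(x.shape[0], -1)
        if self.weight is None:
            in_units = x.shape[1]
            # Arbitrary small uniform init, chosen only for this sketch.
            self.weight = self._rng.uniform(-0.07, 0.07, size=(self.units, in_units))
        return x @ self.weight.T + self.bias


layer = LazyDense(256)
batch = np.ones((128, 1, 28, 28))   # NCHW batch of MNIST-sized images
out = layer(batch)
print(layer.weight.shape, out.shape)
```

The weight shape (256, 784) is only known after the first call, which is why Gluon can omit in_units while PyTorch's Linear needs it up front.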

With a Sequential block, layers are executed one after the other. For a different execution flow in PyTorch, you can inherit from nn.Module and implement a custom .forward() method. In Apache MXNet, you can do the same by inheriting from gluon.Block and overriding its forward() method.

3. Loss function and optimization algorithm

The next step is to define the loss function and pick an optimization algorithm. Both PyTorch and Apache MXNet provide multiple options to choose from, and for our particular case we are going to use the cross-entropy loss function and the Stochastic Gradient Descent (SGD) optimization algorithm.

PyTorch:

pt_loss_fn = pt_nn.CrossEntropyLoss()
pt_trainer = torch.optim.SGD(pt_net.parameters(), lr=0.1)

MXNet:

mx_loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
mx_trainer = gluon.Trainer(mx_net.collect_params(),
                           'sgd', {'learning_rate': 0.1})

The code difference between the frameworks is small. The main difference is that in Apache MXNet we use the Trainer class, which accepts the name of the optimization algorithm as an argument. We also use the .collect_params() method to get the parameters of the network.
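Both loss objects take raw scores (logits) and apply the softmax internally. One visible difference: PyTorch's CrossEntropyLoss averages over the batch by default (reduction='mean'), while Gluon's SoftmaxCrossEntropyLoss returns one value per sample, which is why the MXNet training loop below calls loss.mean(). The NumPy sketch of the computation and of a plain SGD step with learning rate 0.1 uses hypothetical function names and ignores options such as sample weights.

```python
import numpy as np


def softmax_cross_entropy(logits, labels):
    """Per-sample cross-entropy from raw scores and integer class labels."""
    # Subtract the row-wise max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels]


def sgd_step(param, grad, lr=0.1):
    """Plain SGD update: param <- param - lr * grad."""
    return param - lr * grad


logits = np.array([[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
labels = np.array([0, 2])
per_sample = softmax_cross_entropy(logits, labels)
print(per_sample)         # one value per sample, like Gluon's loss output
print(per_sample.mean())  # a single scalar, like PyTorch's default reduction

w = np.array([1.0, -1.0])
print(sgd_step(w, grad=np.array([0.5, 0.5])))
```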

4. Training

Finally, we implement the training algorithm. Note that the results for each run may vary because the weights will get different initialization values and the data will be read in a different order due to shuffling.

PyTorch:

import time

for epoch in range(5):
    total_loss = .0
    tic = time.time()
    for X, y in pt_train_data:
        pt_trainer.zero_grad()
        loss = pt_loss_fn(pt_net(X.view(-1, 28*28)), y)
        loss.backward()
        pt_trainer.step()
        total_loss += loss.item()  # .item() avoids keeping the autograd graph alive
    print('epoch %d, avg loss %.4f, time %.2f' % (
        epoch, total_loss/len(pt_train_data), time.time()-tic))

MXNet:

from mxnet import autograd

for epoch in range(5):
    total_loss = .0
    tic = time.time()
    for X, y in mx_train_data:
        with autograd.record():
            loss = mx_loss_fn(mx_net(X), y)
        loss.backward()
        mx_trainer.step(batch_size=X.shape[0])  # the last batch may hold fewer than 128 samples
        total_loss += loss.mean().item()
    print('epoch %d, avg loss %.4f, time %.2f' % (
        epoch, total_loss/len(mx_train_data), time.time()-tic))

Some of the differences in Apache MXNet when compared to PyTorch are as follows:

  • In Apache MXNet, you don't need to flatten the 4-D input into 2-D before the forward pass, because Dense layers flatten their input by default.

  • In Apache MXNet, you need to perform the calculation within the autograd.record() scope so that it can be automatically differentiated in the backward pass.

  • It is not necessary to clear the gradients every time, as PyTorch requires with zero_grad(), because by default the new gradient overwrites the old one instead of being accumulated.

  • You need to pass the batch size to step() on the trainer: backward() on the per-sample loss vector yields the gradient of the summed loss, and step() normalizes it by 1/batch_size.

  • You need to call .item() to convert a single-element array into a Python scalar.

  • In this sample, Apache MXNet runs about twice as fast as PyTorch, though you should be cautious with such toy comparisons.
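Why does trainer.step() need the batch size? A small NumPy check, using a hypothetical linear model with squared error in place of the MLP, shows that dividing the gradient of the summed loss by the batch size gives exactly the gradient of the mean loss, i.e. the normalization step(batch_size) applies.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(128, 3))     # batch of 128 samples, 3 features
y = rng.normal(size=128)
w = np.zeros(3)                   # hypothetical linear model: pred = X @ w

residual = X @ w - y
# Gradient of the summed squared error 0.5 * sum(residual**2) w.r.t. w.
grad_of_sum = X.T @ residual
# Gradient of the mean squared error 0.5 * mean(residual**2) w.r.t. w.
grad_of_mean = X.T @ residual / len(y)

batch_size = X.shape[0]
# Rescaling the summed gradient by 1/batch_size recovers the mean gradient.
print(np.allclose(grad_of_sum / batch_size, grad_of_mean))
```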

Conclusion

As we saw above, the Apache MXNet Gluon API and PyTorch have many similarities. The main differences lie in terminology (Tensor vs. NDArray) and in how gradients are handled: gradients are accumulated in PyTorch and overwritten in Apache MXNet. The rest of the code is very similar, and it is quite straightforward to move code from one framework to the other.

Recommended Next Steps

While the Apache MXNet Gluon API is very similar to PyTorch, it offers some extra functionality that can make your code even faster.

  • Check out the Hybridize tutorial to learn how to write imperative code that can be converted into a symbolic graph.

  • Also, check out how to extend Apache MXNet with your own custom layers.

Appendix

Below you can find a detailed comparison of various PyTorch functions and their equivalent in Gluon API of Apache MXNet.

Tensor operation

Here is the list of function names in PyTorch Tensor that are different from Apache MXNet NDArray.

Function | PyTorch | MXNet Gluon
--- | --- | ---
Element-wise inverse cosine | x.acos() or torch.acos(x) | nd.arccos(x)
Batch matrix product and accumulation | torch.addbmm(M, batch1, batch2) | nd.linalg_gemm(batch1, batch2, M) (leading n-2 dims are reduced)
Element-wise division of t1 by t2, multiplied by v and added to t | torch.addcdiv(t, t1, t2, value=v) | t + v*(t1/t2)
Matrix product and accumulation | torch.addmm(M, mat1, mat2) | nd.linalg_gemm(mat1, mat2, M)
Outer product of two vectors added to a matrix | m.addr(vec1, vec2) | Not available
Applies a function element-wise | x.apply_(callable) | Not available, but there is nd.Custom(x, op_type='op')
Element-wise inverse sine | x.asin() or torch.asin(x) | nd.arcsin(x)
Element-wise inverse tangent | x.atan() or torch.atan(x) | nd.arctan(x)
Element-wise arctangent of x/y (quadrant-aware) | x.atan2(y) or torch.atan2(x, y) | Not available
Batch matrix product | x.bmm(y) or torch.bmm(x, y) | nd.linalg_gemm2(x, y)
Draws a sample from a Bernoulli distribution | x.bernoulli() | Not available
Fills a tensor with numbers drawn from a Cauchy distribution | x.cauchy_() | Not available
Splits a tensor into chunks along a given dimension | x.chunk(num_of_chunk) | nd.split(x, num_outputs=num_of_chunk)
Limits the values of a tensor to between min and max | x.clamp(min, max) | nd.clip(x, min, max)
Returns a copy of the tensor | x.clone() | x.copy()
Cross product | x.cross(y) | Not available
Cumulative product along an axis | x.cumprod(1) | Not available
Cumulative sum along an axis | x.cumsum(1) | Not available
Address of the first element | x.data_ptr() | Not available
Creates a diagonal tensor | x.diag() | Not available
Computes the norm of a tensor | x.norm() | nd.norm(x) (only calculates the L2 norm)
Computes the Gauss error function | x.erf() | Not available
Broadcasts/expands a tensor to a new shape | x.expand(3,4) | x.broadcast_to([3, 4])
Fills a tensor with samples drawn from an exponential distribution | x.exponential_() | nd.random_exponential()
Element-wise modulo | x.fmod(3) | nd.modulo(x, 3)
Fractional portion of a tensor | x.frac() | x - nd.trunc(x)
Gathers values along an axis specified by dim | torch.gather(x, 1, torch.LongTensor([[0,0],[1,0]])) | nd.gather_nd(x, nd.array([[[0,0],[1,1]],[[0,0],[1,0]]]))
Solves least squares and least norm problems | B.gels(A) | Not available
Draws from a geometric distribution | x.geometric_(p) | Not available
Device context of a tensor | x.device | x.context
Repeats a tensor | x.repeat(4,2) | x.tile(reps=(4, 2))
Data type of a tensor | x.type() | x.dtype
Scatter | torch.zeros(2, 4).scatter_(1, torch.LongTensor([[2], [3]]), 1.23) | nd.scatter_nd(nd.array([1.23,1.23]), nd.array([[0,1],[2,3]]), (2,4))
Returns the shape of a tensor | x.size() | x.shape
Number of elements in a tensor | x.numel() | x.size
Returns this tensor as a NumPy ndarray | x.numpy() | x.asnumpy()
Eigendecomposition of a symmetric matrix | e, v = a.symeig() | v, e = nd.linalg.syevd(a)
Transpose | x.t() | x.T
Samples uniformly | x.uniform_() | nd.sample_uniform()
Inserts a new dimension | x.unsqueeze(0) | nd.expand_dims(x, axis=0)
Reshape | x.view(16) | x.reshape((16,))
Views with the shape of a specified tensor | x.view_as(y) | x.reshape_like(y)
Returns a copy of the tensor cast to a specified type | x.type(type) | x.astype(dtype)
Copies the value of one tensor to another | dst.copy_(src) | src.copyto(dst)
Returns a zero tensor with the specified shape | x = torch.zeros(2,3) | x = nd.zeros((2,3))
Returns a ones tensor with the specified shape | x = torch.ones(2,3) | x = nd.ones((2,3))
Returns a tensor filled with the scalar value 1, with the same size as the input | y = torch.ones_like(x) | y = nd.ones_like(x)
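The gather and scatter rows are the least obvious mappings: torch.gather takes an index per output position along one dim, while gather_nd and scatter_nd take one row of coordinates per input dimension. The NumPy sketch below reproduces both rows on an assumed 2x2 input, using np.take_along_axis and advanced indexing; it illustrates the indexing semantics only and calls neither framework.

```python
import numpy as np

x = np.array([[1, 2],
              [3, 4]])

# torch.gather(x, 1, idx): out[i][j] = x[i][idx[i][j]]
idx = np.array([[0, 0], [1, 0]])
gathered = np.take_along_axis(x, idx, axis=1)

# nd.gather_nd(x, indices): indices[0] holds row coordinates, indices[1] column coordinates
indices = np.array([[[0, 0], [1, 1]],
                    [[0, 0], [1, 0]]])
gathered_nd = x[indices[0], indices[1]]
print(gathered)
print(gathered_nd)

# torch.zeros(2, 4).scatter_(1, [[2], [3]], 1.23) writes 1.23 at (0, 2) and (1, 3)
scattered = np.zeros((2, 4))
coords = np.array([[0, 1], [2, 3]])  # same layout as the nd.scatter_nd indices
scattered[coords[0], coords[1]] = 1.23
print(scattered)
```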

Functional

GPU

Just like a PyTorch Tensor, an MXNet NDArray can be copied to and operated on a GPU. This is done by specifying a context.

Function | PyTorch | MXNet Gluon
--- | --- | ---
Copy to GPU | y = torch.FloatTensor(1).cuda() | y = mx.nd.ones((1,), ctx=mx.gpu(0))
Convert to numpy array | x = y.cpu().numpy() | x = y.asnumpy()

Context scope

PyTorch:

with torch.cuda.device(1):
    y = torch.cuda.FloatTensor(1)

MXNet Gluon:

with mx.gpu(1):
    y = mx.nd.ones((3,5))

Cross-device

Just like a PyTorch Tensor, an MXNet NDArray can be copied across multiple GPUs.

Copy from GPU 0 to GPU 1

PyTorch:

x = torch.cuda.FloatTensor(1)
y = x.cuda(1)

MXNet Gluon:

x = mx.nd.ones((1,), ctx=mx.gpu(0))
y = x.as_in_context(mx.gpu(1))

Function | PyTorch | MXNet Gluon
--- | --- | ---
Copy Tensor/NDArray on different GPUs | y.copy_(x) | x.copyto(y)

Autograd

Variable wrapper vs autograd scope

The autograd packages of PyTorch and MXNet enable automatic differentiation of Tensors and NDArrays, respectively.

Recording computation

PyTorch:

x = torch.ones(1, requires_grad=True)  # Variable(...) wrappers are deprecated since PyTorch 0.4
y = x * 2
y.backward()

MXNet Gluon:

x = mx.nd.ones((1,))
x.attach_grad()
with mx.autograd.record():
    y = x * 2
y.backward()

Scope override (pause, train_mode, predict_mode)

Some operators (Dropout, BatchNorm, etc.) behave differently during training and prediction. In MXNet, this can be controlled with the train_mode and predict_mode scopes. The pause scope is for code whose gradients do not need to be calculated.

Scope override

PyTorch:

No scope for train/predict mode (use model.train() and model.eval() instead); torch.no_grad() is the closest equivalent of pause.

MXNet Gluon:

x = mx.nd.ones((1,))
with autograd.train_mode():
    y = mx.nd.Dropout(x)
    with autograd.predict_mode():
        z = mx.nd.Dropout(y)

w = mx.nd.ones((1,))
w.attach_grad()
with autograd.record():
    y = x * w
    y.backward()
    with autograd.pause():
        w += w.grad

Batch-end synchronization is needed

Apache MXNet uses lazy evaluation to achieve superior performance: the Python thread only pushes operations into the backend engine and then returns. During training, you therefore need a synchronization point at the end of each batch, e.g. asnumpy(), wait_to_read(), or metric.update(...); otherwise the Python loop can keep queuing work many batches ahead of the engine.

Batch-end synchronization

PyTorch: Not available

MXNet Gluon:

for (data, label) in train_data:
    with autograd.record():
        output = net(data)
        L = loss(output, label)
        L.backward()
    trainer.step(data.shape[0])
    metric.update([label], [output])
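A plain-Python analogy may help: submitting work to an executor returns a future immediately, much like an MXNet operator call returns before its result is computed, and calling result() blocks the way wait_to_read() or asnumpy() does. The sketch uses concurrent.futures purely as an analogy; MXNet's dependency engine is implemented in C++ and differs in detail.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def slow_square(v):
    time.sleep(0.05)  # stand-in for an expensive kernel
    return v * v


with ThreadPoolExecutor(max_workers=2) as engine:
    t0 = time.perf_counter()
    # Submitting returns at once; the work runs in the background.
    pending = [engine.submit(slow_square, i) for i in range(4)]
    submit_time = time.perf_counter() - t0
    # "Batch-end synchronization": block until every queued result is ready.
    results = [f.result() for f in pending]
    total_time = time.perf_counter() - t0

print(results)
print(f"submit: {submit_time:.4f}s, after sync: {total_time:.4f}s")
```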

PyTorch module and Gluon blocks

Defining a new block in Gluon is similar to defining a new module in PyTorch.

New block definition

PyTorch:

class Net(torch.nn.Module):
    def __init__(self, D_in, D_out):
        super(Net, self).__init__()
        self.linear = torch.nn.Linear(D_in, D_out)
    def forward(self, x):
        return self.linear(x)

MXNet Gluon:

class Net(mx.gluon.Block):
    def __init__(self, D_in, D_out):
        super(Net, self).__init__()
        self.dense = mx.gluon.nn.Dense(D_out, in_units=D_in)
    def forward(self, x):
        return self.dense(x)

Parameter and Initializer

When you create a new layer in PyTorch, you do not need to specify a parameter initializer; each layer type has its own default initialization. When you create a new layer in the Gluon API, you can either specify its initializer or leave it unset. The parameters are initialized only when you call net.initialize(<init method>): every parameter uses the given init method, except those of layers whose initializer was specified explicitly.

Function | PyTorch | MXNet Gluon
--- | --- | ---
Get all parameters | net.parameters() | net.collect_params()
Initialize network | No single call; layers initialize their parameters on creation (net.apply(fn) can re-initialize them) | net.initialize(mx.init.Xavier())

Specify layer initializer

PyTorch:

layer = torch.nn.Linear(20, 10)
torch.nn.init.normal_(layer.weight, 0, 0.01)

MXNet Gluon:

layer = mx.gluon.nn.Dense(10, weight_initializer=mx.init.Normal(0.01))
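For reference, a NumPy sketch of Xavier (Glorot) uniform initialization for a Dense weight of shape (units, in_units). We assume the defaults of mx.init.Xavier() are uniform sampling, an averaged fan, and magnitude 3, giving the bound sqrt(3 / ((fan_in + fan_out) / 2)); please check the MXNet initializer documentation before relying on these constants. The function name xavier_uniform is hypothetical.

```python
import numpy as np


def xavier_uniform(shape, magnitude=3.0, seed=0):
    """Sample a (fan_out, fan_in) weight from U(-scale, scale) with an averaged fan."""
    fan_out, fan_in = shape
    scale = np.sqrt(magnitude / ((fan_in + fan_out) / 2.0))
    rng = np.random.default_rng(seed)
    return rng.uniform(-scale, scale, size=shape), scale


weight, scale = xavier_uniform((10, 20))   # like Dense(10) on 20 input features
print(round(scale, 4))                      # bound sqrt(3 / 15)
print(np.abs(weight).max() <= scale)
```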

Usage of existing blocks looks alike

Function | PyTorch | MXNet Gluon
--- | --- | ---
Usage of existing blocks | y = net(x) | y = net(x)

HybridBlock can be hybridized, and allows partial-shape info

HybridBlock supports forwarding with both Symbol and NDArray. Once hybridized, a HybridBlock creates a symbolic graph representing the forward computation and caches it. Most of the built-in blocks (Dense, Conv2D, MaxPool2D, BatchNorm, etc.) are HybridBlocks.

Instead of explicitly declaring the number of inputs to a layer, we can simply state the number of outputs. The shape will be inferred on the fly once the network is provided with some input.

Partial-shape inference and hybridization

PyTorch: Not available

MXNet Gluon:

net = mx.gluon.nn.HybridSequential()
net.add(mx.gluon.nn.Dense(10))
net.hybridize()

SymbolBlock

A SymbolBlock constructs a block from a symbol. This is useful for using pre-trained models as feature extractors.

SymbolBlock

PyTorch: Not available

MXNet Gluon:

alexnet = mx.gluon.model_zoo.vision.alexnet(pretrained=True)
inputs = mx.sym.var('data')  # symbolic placeholder for the input images
out = alexnet(inputs)
internals = out.get_internals()
outputs = [internals['model_dense0_relu_fwd_output']]
feat_model = gluon.SymbolBlock(outputs, inputs, params=alexnet.collect_params())

PyTorch optimizer vs Gluon Trainer

For Gluon API calling zero_grad is not necessary most of the time

Calling zero_grad clears the gradients of all parameters. In the Gluon API, there is no need to clear the gradients for every batch as long as grad_req = 'write' (the default).

Clearing the gradients

PyTorch:

optm = torch.optim.SGD(model.parameters(), lr=0.1)
optm.zero_grad()
loss_fn(model(input), target).backward()
optm.step()

MXNet Gluon:

trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})
with autograd.record():
    loss = loss_fn(net(data), label)
loss.backward()
trainer.step(batch_size)
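The difference between accumulating gradients (PyTorch's behavior, and Gluon's grad_req='add') and Gluon's default grad_req='write' can be simulated in a few lines of NumPy. The GradBuffer class is a hypothetical stand-in for a parameter's gradient buffer, not either framework's API.

```python
import numpy as np


class GradBuffer:
    """Hypothetical gradient buffer with 'write' or 'add' semantics."""

    def __init__(self, size, grad_req="write"):
        self.grad = np.zeros(size)
        self.grad_req = grad_req

    def backward(self, new_grad):
        if self.grad_req == "write":
            self.grad = new_grad.copy()       # overwrite: the old gradient is discarded
        else:
            self.grad = self.grad + new_grad  # 'add': accumulate across calls

    def zero_grad(self):
        self.grad[:] = 0


g = np.array([1.0, 2.0])
write, add = GradBuffer(2, "write"), GradBuffer(2, "add")
for _ in range(3):          # three batches, no zero_grad() in between
    write.backward(g)
    add.backward(g)
print(write.grad)  # only the latest batch
print(add.grad)    # sum of all three batches
```

With 'add' semantics you must call zero_grad() yourself between updates, which is exactly what PyTorch's optimizer.zero_grad() does.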

Multi-GPU training

Data parallelism

PyTorch:

net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
output = net(data)

MXNet Gluon:

ctx = [mx.gpu(i) for i in range(3)]
data = gluon.utils.split_and_load(data, ctx)
label = gluon.utils.split_and_load(label, ctx)
with autograd.record():
    losses = [loss(net(X), Y) for X, Y in zip(data, label)]
for l in losses:
    l.backward()
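Conceptually, split_and_load slices the batch along its first axis into one shard per device, and the per-device gradients are then combined. The NumPy sketch below emulates the slicing with np.array_split and checks that the per-shard gradients of a summed loss add up to the full-batch gradient. The linear model and the three "devices" are hypothetical, and the real split_and_load also copies each shard to its context.

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(12, 4))       # batch of 12 samples
y = rng.normal(size=12)
w = rng.normal(size=4)             # hypothetical linear model: pred = X @ w
num_devices = 3

# One shard per "device", split along the batch axis.
X_shards = np.array_split(X, num_devices)
y_shards = np.array_split(y, num_devices)


def grad_sum_sq_error(Xs, ys, w):
    """Gradient of 0.5 * sum((Xs @ w - ys)**2) with respect to w."""
    return Xs.T @ (Xs @ w - ys)


per_device = [grad_sum_sq_error(Xs, ys, w) for Xs, ys in zip(X_shards, y_shards)]
full_batch = grad_sum_sq_error(X, y, w)
print([s.shape for s in X_shards])
print(np.allclose(sum(per_device), full_batch))
```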

Distributed training

Distributed data parallelism

PyTorch:

torch.distributed.init_process_group(...)
model = torch.nn.parallel.DistributedDataParallel(model, ...)

MXNet Gluon:

store = mx.kv.create('dist_sync')
trainer = gluon.Trainer(net.collect_params(), ..., kvstore=store)

Monitoring

Apache MXNet has pre-defined metrics

Gluon provides several predefined metrics that can evaluate the performance of a model online, i.e. while it is being trained.

Metric

PyTorch: Not available

MXNet Gluon:

metric = mx.metric.Accuracy()
with autograd.record():
    output = net(data)
    L = loss(output, label)
L.backward()
trainer.step(batch_size)
metric.update(label, output)
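A metric such as Accuracy keeps running counts across update() calls, so it can be reported at any point during an epoch. The class below is a hypothetical NumPy re-implementation of that idea, not mx.metric.Accuracy itself; it assumes predictions are per-class scores and takes the argmax.

```python
import numpy as np


class RunningAccuracy:
    """Hypothetical accuracy metric that accumulates over batches."""

    def __init__(self):
        self.num_correct = 0
        self.num_inst = 0

    def update(self, labels, scores):
        preds = scores.argmax(axis=1)
        self.num_correct += int((preds == labels).sum())
        self.num_inst += len(labels)

    def get(self):
        return "accuracy", self.num_correct / self.num_inst

    def reset(self):
        self.num_correct = self.num_inst = 0


metric = RunningAccuracy()
metric.update(np.array([0, 1]), np.array([[0.9, 0.1], [0.8, 0.2]]))  # 1 of 2 correct
metric.update(np.array([1, 1]), np.array([[0.3, 0.7], [0.4, 0.6]]))  # 2 of 2 correct
print(metric.get())  # ('accuracy', 0.75)
```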

Data visualization

TensorboardX (PyTorch) and MXBoard (MXNet) can be used to visualize your network and plot quantitative metrics about the execution of your graph.

PyTorch:

sw = tensorboardX.SummaryWriter()
...
for name, param in model.named_parameters():
    grad = param.grad.detach().cpu().numpy()
    sw.add_histogram(name, grad, n_iter)
...
sw.close()

MXNet:

sw = mxboard.SummaryWriter()
...
for name, param in net.collect_params().items():
    grad = param.grad().asnumpy().flatten()
    sw.add_histogram(tag=name, values=grad, bins=200, global_step=i)
...
sw.close()

I/O and deploy

Data loading

Dataset and DataLoader are the basic components for loading data.

Class | PyTorch | MXNet Gluon
--- | --- | ---
Dataset holding arrays | torch.utils.data.TensorDataset(data_tensor, label_tensor) | gluon.data.ArrayDataset(data_array, label_array)
Data loader | torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, drop_last=False) | gluon.data.DataLoader(dataset, batch_size=None, shuffle=False, sampler=None, last_batch='keep', batch_sampler=None, batchify_fn=None, num_workers=0)
Sequentially applied sampler | torch.utils.data.sampler.SequentialSampler(data_source) | gluon.data.SequentialSampler(length)
Random order sampler | torch.utils.data.sampler.RandomSampler(data_source) | gluon.data.RandomSampler(length)
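The samplers in the table only produce indices; the data loader uses them to pick samples and group them into batches. A plain-Python sketch of that division of labor follows, with hypothetical function names, a list of (data, label) pairs standing in for an ArrayDataset, and only the 'keep' and 'discard' policies for a short last batch.

```python
import random


def sequential_sampler(length):
    """Yield indices 0, 1, ..., length - 1 in order."""
    return iter(range(length))


def random_sampler(length, seed=0):
    """Yield a random permutation of range(length)."""
    indices = list(range(length))
    random.Random(seed).shuffle(indices)
    return iter(indices)


def batches(dataset, sampler, batch_size, last_batch="keep"):
    """Group sampled items into batches; 'discard' drops a short final batch."""
    batch = []
    for idx in sampler:
        batch.append(dataset[idx])
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch and last_batch == "keep":
        yield batch


dataset = [(x, x % 2) for x in range(5)]   # (data, label) pairs
seq = list(batches(dataset, sequential_sampler(len(dataset)), batch_size=2))
print([len(b) for b in seq])               # [2, 2, 1]
shuffled = list(batches(dataset, random_sampler(len(dataset)), 2, "discard"))
print([len(b) for b in shuffled])          # [2, 2]
```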

Some commonly used computer vision datasets are provided in the mx.gluon.data.vision package.

Class | PyTorch | MXNet Gluon
--- | --- | ---
MNIST handwritten digits dataset | torchvision.datasets.MNIST | mx.gluon.data.vision.MNIST
CIFAR10 dataset | torchvision.datasets.CIFAR10 | mx.gluon.data.vision.CIFAR10
CIFAR100 dataset | torchvision.datasets.CIFAR100 | mx.gluon.data.vision.CIFAR100
A generic data loader where the images are arranged in folders | torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>) | mx.gluon.data.vision.ImageFolderDataset(root, flag, transform=None)

Serialization

In Gluon, serialization and deserialization are achieved by calling save_parameters and load_parameters; in PyTorch, by saving and loading the model's state_dict.

Function | PyTorch | MXNet Gluon
--- | --- | ---
Save model parameters | torch.save(the_model.state_dict(), filename) | model.save_parameters(filename)
Load parameters | the_model.load_state_dict(torch.load(filename)) | model.load_parameters(filename, ctx, allow_missing=False, ignore_extra=False)
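Both formats essentially store a mapping from parameter names to arrays. As a framework-neutral sketch of that idea, the snippet below saves and restores a hypothetical name-to-array dictionary with NumPy's .npz format; the parameter names are made up, and the resulting file is not something PyTorch or MXNet can load directly.

```python
import os
import tempfile

import numpy as np

params = {"dense0_weight": np.ones((10, 20)), "dense0_bias": np.zeros(10)}

path = os.path.join(tempfile.mkdtemp(), "model_params.npz")
np.savez(path, **params)                      # save: name -> array

with np.load(path) as archive:                # load into a fresh dictionary
    restored = {name: archive[name] for name in archive.files}

# Strict loading, in the spirit of allow_missing=False and ignore_extra=False:
# the set of names must match exactly.
assert set(restored) == set(params)
print(all(np.array_equal(restored[k], params[k]) for k in params))
```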