{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Licensed to the Apache Software Foundation (ASF) under one\n",
"# or more contributor license agreements. See the NOTICE file\n",
"# distributed with this work for additional information\n",
"# regarding copyright ownership. The ASF licenses this file\n",
"# to you under the Apache License, Version 2.0 (the\n",
"# \"License\"); you may not use this file except in compliance\n",
"# with the License. You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing,\n",
"# software distributed under the License is distributed on an\n",
"# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
"# KIND, either express or implied. See the License for the\n",
"# specific language governing permissions and limitations\n",
"# under the License."
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"# Fast Sign Adversary Generation Example\n",
"\n",
"This notebook demonstrates how to find adversarial examples with MXNet Gluon by taking advantage of the gradient of the loss with respect to the input, following the fast gradient sign method of [1].\n",
"\n",
"[1] Goodfellow, Ian J., Jonathon Shlens, and Christian Szegedy. \"Explaining and harnessing adversarial examples.\" arXiv preprint arXiv:1412.6572 (2014).\n",
"https://arxiv.org/abs/1412.6572"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 1,
"source": [
"%matplotlib inline\n",
"import mxnet as mx\n",
"import numpy as np\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.cm as cm\n",
"\n",
"from mxnet import gluon"
],
"outputs": [],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"We build a simple CNN to solve the MNIST digit recognition task. First, select the device to run on and the batch size."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 17,
"source": [
"ctx = mx.gpu() if mx.device.num_gpus() else mx.cpu()\n",
"batch_size = 128"
],
"outputs": [],
"metadata": {
"collapsed": true
}
},
{
"cell_type": "markdown",
"source": [
"## Data Loading"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 3,
"source": [
"transform = lambda x,y: (x.transpose((2,0,1)).astype('float32')/255., y)\n",
"\n",
"train_dataset = gluon.data.vision.MNIST(train=True).transform(transform)\n",
"test_dataset = gluon.data.vision.MNIST(train=False).transform(transform)\n",
"\n",
"train_data = gluon.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=5)\n",
"test_data = gluon.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## Create the network"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 4,
"source": [
"net = gluon.nn.HybridSequential()\n",
"with net.name_scope():\n",
" net.add(\n",
" gluon.nn.Conv2D(kernel_size=5, channels=20, activation='tanh'),\n",
" gluon.nn.MaxPool2D(pool_size=2, strides=2),\n",
" gluon.nn.Conv2D(kernel_size=5, channels=50, activation='tanh'),\n",
" gluon.nn.MaxPool2D(pool_size=2, strides=2),\n",
" gluon.nn.Flatten(),\n",
" gluon.nn.Dense(500, activation='tanh'),\n",
" gluon.nn.Dense(10)\n",
" )"
],
"outputs": [],
"metadata": {
"collapsed": true
}
},
{
"cell_type": "markdown",
"source": [
"## Initialize training"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 5,
"source": [
"# Initialize the network parameters with the Uniform initializer on the selected device (ctx)\n",
"net.initialize(mx.initializer.Uniform(), ctx=ctx)\n",
"# NOTE(review): hybridize() presumably switches the HybridSequential to Gluon's cached/compiled execution - confirm against the MXNet docs\n",
"net.hybridize()"
],
"outputs": [],
"metadata": {
"collapsed": true
}
},
{
"cell_type": "code",
"execution_count": 6,
"source": [
"# Softmax cross-entropy loss; used for training and, later, for the gradient with respect to the input\n",
"loss = gluon.loss.SoftmaxCELoss()"
],
"outputs": [],
"metadata": {
"collapsed": true
}
},
{
"cell_type": "code",
"execution_count": 7,
"source": [
"# SGD with momentum (learning rate 0.1, momentum 0.95) over all parameters of the network\n",
"trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1, 'momentum':0.95})"
],
"outputs": [],
"metadata": {
"collapsed": true
}
},
{
"cell_type": "markdown",
"source": [
"## Training loop"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 8,
"source": [
"epoch = 3\n",
"for e in range(epoch):\n",
" train_loss = 0.\n",
" acc = mx.gluon.metric.Accuracy()\n",
" for i, (data, label) in enumerate(train_data):\n",
" data = data.as_in_context(ctx)\n",
" label = label.as_in_context(ctx)\n",
" \n",
" with mx.autograd.record():\n",
" output = net(data)\n",
" l = loss(output, label)\n",
" \n",
" l.backward()\n",
" trainer.update(data.shape[0])\n",
" \n",
" train_loss += l.mean().item()\n",
" acc.update(label, output)\n",
" \n",
" print(\"Train Accuracy: %.2f\\t Train Loss: %.5f\" % (acc.get()[1], train_loss/(i+1)))"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Train Accuracy: 0.92\t Train Loss: 0.32142\n",
"Train Accuracy: 0.97\t Train Loss: 0.16773\n",
"Train Accuracy: 0.97\t Train Loss: 0.14660\n"
]
}
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"## Perturbation\n",
"\n",
"We first run a validation batch and measure the resulting accuracy.\n",
"We then perturb this batch by shifting every input value by a small step in the direction of the sign of the gradient of the loss with respect to the input, which is the direction that increases the loss."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 9,
"source": [
"# Get a batch from the testing set\n",
"for data, label in test_data:\n",
" data = data.as_in_context(ctx)\n",
" label = label.as_in_context(ctx)\n",
" break\n",
"\n",
"# Attach gradient to it to get the gradient of the loss with respect to the input\n",
"data.attach_grad()\n",
"with mx.autograd.record():\n",
" output = net(data) \n",
" l = loss(output, label)\n",
"l.backward()\n",
"\n",
"acc = mx.gluon.metric.Accuracy()\n",
"acc.update(label, output)\n",
"\n",
"print(\"Validation batch accuracy {}\".format(acc.get()[1]))"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Validation batch accuracy 0.96875\n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"Now we perturb the input"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 10,
"source": [
"data_perturbated = data + 0.15 * mx.np.sign(data.grad)\n",
"\n",
"output = net(data_perturbated) \n",
"\n",
"acc = mx.gluon.metric.Accuracy()\n",
"acc.update(label, output)\n",
"\n",
"print(\"Validation batch accuracy after perturbation {}\".format(acc.get()[1]))"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Validation batch accuracy after perturbation 0.40625\n"
]
}
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"## Visualization"
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"Let's visualize an example after perturbation.\n",
"\n",
"We can see that the prediction is often incorrect."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 16,
"source": [
"from random import randint\n",
"idx = randint(0, batch_size-1)\n",
"\n",
"plt.imshow(data_perturbated[idx, :].asnumpy().reshape(28,28), cmap=cm.Greys_r)\n",
"print(\"true label: %d\" % label.asnumpy()[idx])\n",
"print(\"predicted: %d\" % np.argmax(output.asnumpy(), axis=1)[idx])"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"true label: 1\n",
"predicted: 3\n"
]
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADpJJREFUeJzt3V+IXeW5x/Hfc9JsNbbMmLbGkAQdgxwZAxoZY+EMJy1tgo2F2AuluSg5IE0vIrbQi4q9qJeh9A9eSHGqobG2ScVWDConsaFgS0p1FI/G8VRNSWmGJGOxpCnIjJk8vdgrZYx7r7Wz1989z/cDw+xZ715rPbMmv6y997vW+5q7C0A8/1F3AQDqQfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwT1sSp31mq1fNmyZaVs+/Tp06Vs97yhoaHa9p0lrbYmq/O41X3M0n73rNref//9rm1nz57V/Py89VJDrvCb2W2SHpS0RNIj7r4r7fnLli3T+Ph4nl129eyzz5ay3fPS6i5731nKOqZlq/O41X3M0n73rNqmpqa6tk1PT/dcQ98v+81siaSHJH1R0qikbWY22u/2AFQrz3v+DZLecfc/u/ucpH2SthZTFoCy5Qn/Kkl/XfDz8WTZh5jZDjObNLPJubm5HLsDUKTSP+139wl3H3P3sVarVfbuAPQoT/inJa1Z8PPqZBmAAZAn/C9Jus7MRsysJekrkvYXUxaAsvXd1efuZ83sHkkH1O7q2+3ubxRWWQd1dg3V3Z3XVE0+LrfffnvdJTRarn5+d39O0nMF1QKgQlzeCwRF+IGgCD8QFOEHgiL8QFCEHwiq0vv5szS5z7jJBvW4NbkfPuuYlll71r5HRkYK2Q9nfiAowg8ERfiBoAg/EBThB4Ii/EBQ5u7V7cysup1doMndSnmldQ1l/d51dhPW+TcZ1O5RKb2rb3p6WrOzsz0N3c2ZHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCatQtvVkWc199HoN6XOq8bbZueX63tFl6LwZnfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IKlc/v5kdk3RG0ryks+4+lvb8oaEhjY+P59klOshzb3revvQ8ffXPPPNM6rr79u1Lbd+7d29q+/z8fGp7HovhGoQiLvL5nLv/rYDtAKgQL/uBoPKG3yUdNLOXzWxHEQUBqEbel/3j7j5tZldKet7M/t/dX1j4hOQ/hR2SdNlll+XcHYCi5Drzu/t08n1G0lOSNnR4zoS7j7n7WKvVyrM7AAXqO/xmdrmZfeL8Y0mbJR0pqjAA5crzsn+FpKfM7Px2fuHu/1tIVQBKV+m4/cPDw74Y+/kHeQz4vIaHh1PbDx482LXtlltuybXvO++8M7X9ySef7HvbTe7HT7ufn3H7AWQi/EBQhB8IivADQRF+ICjCDwQ1UEN3lylyd10eGzduTG1fvnx517ajR4+mrjs7O5vanqcrr8my/i2mTdF9MTjzA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQld7Sa2bV7QyFWLp0aWp7Vp/0tdde2/e+d+7cmdp+4MCBvrc9yNL6+bmlF0Amwg8ERfiBoAg/EBThB4Ii/EBQhB8IqtL7+bOm6Oae+s7KHEY665ivW7cutX3Tpk2p7Wn37J87dy513aj9+FXhzA8ERfiBoAg/EBThB4Ii/EBQhB8IivADQWX285vZbklfkjTj7uuSZcsl/VLSNZKOSbrL3f+et5g8/dlNvkagzume8x6XLVu2pLZnjb2f5sUXX+x7XeTXy5n/p5Juu2DZfZIOuft1kg4lPwMYIJnhd/cXJL13weKtkvYkj/dIuqPgugCUrN/3/Cvc/UTy+KSkFQXVA6AiuT/w8/YggF3H5jOzHWY2aWaTc3NzeXcHoCD9hv+Uma2UpOT7TLcnuvuEu4+5+1ir1epzdwCK1m/490vanjzeLunpYsoBUJXM8J
vZXkl/kPSfZnbczO6WtEvSJjN7W9IXkp8BDJDMfn5339al6fMF15JLnX3pdSvzGodbb7011/rz8/Nd23btynfOiPo3n5qaKmQ7XOEHBEX4gaAIPxAU4QeCIvxAUIQfCKrSobtRjjxdXuvXr09tHx0dTW1fu3Ztavvs7GzXtquvvjp13ax25MOZHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCqrSf//Tp06Xdfhr19s68brzxxtT2rH78LIcPH861PsrDmR8IivADQRF+ICjCDwRF+IGgCD8QFOEHguJ+/uCuv/76XOtnTcH2+OOP59p+Hkzbno4zPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8EldnPb2a7JX1J0oy7r0uWPSDpa5LeTZ52v7s/V1aR6N/NN9+c2n7DDTfk2n5WP/+pU6dybR/l6eXM/1NJt3VY/iN3vyn5IvjAgMkMv7u/IOm9CmoBUKE87/nvMbPXzGy3mV1RWEUAKtFv+H8saa2kmySdkPSDbk80sx1mNmlmk33uC0AJ+gq/u59y93l3PyfpJ5I2pDx3wt3H3H2s3yIBFK+v8JvZygU/flnSkWLKAVCVXrr69kr6rKRPmdlxSd+V9Fkzu0mSSzom6esl1gigBJnhd/dtHRY/WkItudR573YT7s3uZnh4OLXdzHJt/6233sq1fh6L9X79rN9rZGSk720vxBV+QFCEHwiK8ANBEX4gKMIPBEX4gaAYursAWV0zdXYFbt68ObX96NGjqe1r1qxJbX/iiScuuqaiZB3XtL9Lk7tnq8KZHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCqrSff2hoSOPj413bm3yLZh55rwPIWv+qq67q2nbppZemrpvl8OHDqe1HjqSP45Lnb0pffLk48wNBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUNzP3wB5r29Iu3Yi79DcBw4cSG0v89qMJo+TsBhw5geCIvxAUIQfCIrwA0ERfiAowg8ERfiBoDL7+c1sjaTHJK2Q5JIm3P1BM1su6ZeSrpF0TNJd7v73PMWUOa3xYjY0NNT3umfOnEltf/jhh/vedpa8/fR5/uZNvoYga99TU1OF7KeXM/9ZSd9y91FJn5G008xGJd0n6ZC7XyfpUPIzgAGRGX53P+HurySPz0h6U9IqSVsl7UmetkfSHWUVCaB4F/We38yukbRe0h8lrXD3E0nTSbXfFgAYED2H38w+LulXkr7p7v9Y2OburvbnAZ3W22Fmk2Y2OTc3l6tYAMXpKfxmtlTt4P/c3X+dLD5lZiuT9pWSZjqt6+4T7j7m7mOtVquImgEUIDP81r4t7FFJb7r7Dxc07Ze0PXm8XdLTxZcHoCy93NL7X5K+Kul1M3s1WXa/pF2SnjCzuyX9RdJd5ZSILGm39GZZvXp1avvGjRtT2z/44IO+951X3iHPmyqr7pGRkUL2kxl+d/+9pG43hX++kCoAVI4r/ICgCD8QFOEHgiL8QFCEHwiK8ANBLZqhuwd5GOesft0lS5aktqf1+65duzZ13ZMnT6a219mPn6XJw4YPwjUGnPmBoAg/EBThB4Ii/EBQhB8IivADQRF+IKhF088/yLL6jLP6+fMM3T0z03EApp4NQn92J3n76Qf1916IMz8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBGXtmbYq2plZ6s4G+Z78Ol155ZVd2+69997UdR966KHU9kceeaSvmlCetPEbpqenNTs7222o/Q/hzA8ERfiBoAg/EBThB4Ii/EBQhB8IivADQWX285vZGkmPSVohySVNuPuDZvaApK9Jejd56v3u/lzGtqq7qKBCZV+fkGcM+cVw3/kgKvPfxNTUVNe2i+nn72Uwj7OSvuXur5jZJyS9bGbPJ20/cvfv97IjAM2SGX53PyHpRPL4jJm9KWlV2YUBKNdFvec3s2skrZf0x2TRPWb2mpntNrMruqyzw8wmzWwyV6UACtVz+M3s45J+Jemb7v4PST+WtF
bSTWq/MvhBp/XcfcLdx9x9rIB6ARSkp/Cb2VK1g/9zd/+1JLn7KXefd/dzkn4iaUN5ZQIoWmb4zcwkPSrpTXf/4YLlKxc87cuSjhRfHoCy9NLVNy7pd5Jel3QuWXy/pG1qv+R3ScckfT35cDBtW4uyqw+4WHm6Aivr6nP330vqtLHUPn0AzcYVfkBQhB8IivADQRF+ICjCDwRF+IGgKh26+5JLLvFVqwbznqDR0dG+181zS27Zyr7lN8/txmlDVKMzhu4GkInwA0ERfiAowg8ERfiBoAg/EBThB4KqeorudyX9ZcGiT0n6W2UFXJym1tbUuiRq61eRtV3t7p/u5YmVhv8jOzebbOrYfk2tral1SdTWr7pq42U/EBThB4KqO/wTNe8/TVNra2pdErX1q5baan3PD6A+dZ/5AdSklvCb2W1m9icze8fM7qujhm7M7JiZvW5mr9Y9xVgyDdqMmR1ZsGy5mT1vZm8n3ztOk1ZTbQ+Y2XRy7F41sy011bbGzH5rZlNm9oaZfSNZXuuxS6mrluNW+ct+M1si6S1JmyQdl/SSpG3u3n0w8gqZ2TFJY+5ee5+wmf23pH9Keszd1yXLvifpPXfflfzHeYW7f7shtT0g6Z91z9ycTCizcuHM0pLukPQ/qvHYpdR1l2o4bnWc+TdIesfd/+zuc5L2SdpaQx2N5+4vSHrvgsVbJe1JHu9R+x9P5brU1gjufsLdX0ken5F0fmbpWo9dSl21qCP8qyT9dcHPx9WsKb9d0kEze9nMdtRdTAcrFsyMdFLSijqL6SBz5uYqXTCzdGOOXT8zXheND/w+atzdb5b0RUk7k5e3jeTt92xN6q7paebmqnSYWfrf6jx2/c54XbQ6wj8tac2Cn1cnyxrB3aeT7zOSnlLzZh8+dX6S1OT7TM31/FuTZm7uNLO0GnDsmjTjdR3hf0nSdWY2YmYtSV+RtL+GOj7CzC5PPoiRmV0uabOaN/vwfknbk8fbJT1dYy0f0pSZm7vNLK2aj13jZrx298q/JG1R+xP/o5K+U0cNXeq6VtL/JV9v1F2bpL1qvwz8QO3PRu6W9ElJhyS9Lek3kpY3qLafqT2b82tqB21lTbWNq/2S/jVJryZfW+o+dil11XLcuMIPCIoP/ICgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBPUv5DLnMbZADooAAAAASUVORK5CYII=",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {}
}
],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}