blob: 7d33e1a369c7a5f350ba655a5e3609edb327110f [file] [log] [blame]
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements; and to You under the Apache License, Version 2.0. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Classify images from MNIST using LeNet"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dataset\n",
"\n",
"Download the [dataset](http://deeplearning.net/data/mnist/mnist.pkl.gz) to your workspace (i.e. the notebook folder)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# __future__ imports must come before any other statement in the cell/module.\n",
"from __future__ import division\n",
"from __future__ import print_function\n",
"\n",
"from builtins import zip\n",
"from builtins import str\n",
"from builtins import range\n",
"from past.utils import old_div\n",
"from future import standard_library\n",
"standard_library.install_aliases()\n",
"\n",
"import pickle, gzip\n",
"from tqdm import tnrange, tqdm_notebook\n",
"\n",
"# Load the dataset; the third object stored in the pickle is not used here.\n",
"# NOTE: pickle.load can execute arbitrary code -- only load a trusted copy of mnist.pkl.gz.\n",
"# The `with` block closes the file even if unpickling fails.\n",
"with gzip.open('mnist.pkl.gz', 'rb') as f:\n",
"    train_set, valid_set, _ = pickle.load(f, encoding='latin1')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(50000, 784) (50000,)\n",
"(10000, 784) (10000,)\n"
]
}
],
"source": [
"# Show the shapes of element 0 (images) and element 1 (labels) of the train and validation splits.\n",
"for split in (train_set, valid_set):\n",
"    print(split[0].shape, split[1].shape)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Reshape the flat 784-pixel rows into (N, 1, 28, 28) arrays; -1 lets numpy infer N,\n",
"# so nothing here hard-codes the split sizes.\n",
"train_x = np.reshape(train_set[0], (-1, 1, 28, 28)).astype(np.float32, copy=False)\n",
"train_y = np.array(train_set[1]).astype(np.int32, copy=False)\n",
"# Cast the validation images to float32 as well, matching train_x.\n",
"valid_x = np.reshape(valid_set[0], (-1, 1, 28, 28)).astype(np.float32, copy=False)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7fdde5663438>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADoBJREFUeJzt3X2MXOV1x/HfyXq9jo1JvHHYboiLHeMEiGlMOjIgLKCi\nuA5CMiiKiRVFDiFxmuCktK4EdavGrWjlVgmRQynS0ri2I95CAsJ/0CR0FUGiwpbFMeYtvJlNY7Ps\nYjZgQ4i9Xp/+sdfRBnaeWc/cmTu75/uRVjtzz71zj6792zszz8x9zN0FIJ53Fd0AgGIQfiAowg8E\nRfiBoAg/EBThB4Ii/EBQhB8IivADQU1r5M6mW5vP0KxG7hII5bd6U4f9kE1k3ZrCb2YrJG2W1CLp\nP9x9U2r9GZqls+2iWnYJIKHHuye8btVP+82sRdJNkj4h6QxJq83sjGofD0Bj1fKaf6mk5919j7sf\nlnSHpJX5tAWg3moJ/8mSfjXm/t5s2e8xs7Vm1mtmvcM6VMPuAOSp7u/2u3uXu5fcvdSqtnrvDsAE\n1RL+fZLmjbn/wWwZgEmglvA/ImmRmS0ws+mSPi1pRz5tAai3qof63P2Ima2T9CONDvVtcfcnc+sM\nQF3VNM7v7vdJui+nXgA0EB/vBYIi/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK\n8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8I\nivADQRF+IKiaZuk1sz5JByWNSDri7qU8mkJ+bFr6n7jl/XPruv9n/np+2drIzKPJbU9ZOJisz/yK\nJesv3zC9bG1n6c7ktvtH3kzWz75rfbJ+6l89nKw3g5rCn/kTd9+fw+MAaCCe9gNB1Rp+l/RjM3vU\nzNbm0RCAxqj1af8yd99nZidJut/MfuHuD45dIfujsFaSZmhmjbsDkJeazvzuvi/7PSjpHklLx1mn\ny91L7l5qVVstuwOQo6rDb2azzGz2sduSlkt6Iq/GANRXLU/7OyTdY2bHHuc2d/9hLl0BqLuqw+/u\neyR9LMdepqyW0xcl697Wmqy/dMF7k/W3zik/Jt3+nvR49U8/lh7vLtJ//WZ2sv4v/7YiWe8587ay\ntReH30puu2ng4mT9Az/1ZH0yYKgPCIrwA0ERfiAowg8ERfiBoAg/EFQe3+oLb+TCjyfrN2y9KVn/\ncGv5r55OZcM+kqz//Y2fS9anvZkebjv3rnVla7P3HUlu27Y/PRQ4s7cnWZ8MOPMDQRF+ICjCDwRF\n+IGgCD8QFOEHgiL8QFCM8+eg7ZmXkvVHfzsvWf9w60Ce7eRqff85yfqeN9KX/t668Ptla68fTY/T\nd3z7f5L1epr8X9itjDM/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRl7o0b0TzR2v1su6hh+2sWQ1ee\nm6wfWJG+vHbL7hOS9ce+cuNx93TM9fv/KFl/5IL0OP7Ia68n635u+au7930tuakWrH4svQLeoce7\ndcCH0nOXZzjzA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQFcf5zWyLpEslDbr74mxZu6Q7Jc2X1Cdp\nlbv/utLOoo7zV9Iy933J+sirQ8n6i7eVH6t/8vwtyW2X/vNXk/WTbiruO/U4fnmP82+V9PaJ0K+T\n1O3uiyR1Z/cBTCIVw+/uD0p6+6lnpaRt2e1tki7LuS8AdVbta/4Od+/Pbr8sqSOnfgA0SM1v+Pno\nmwZl3zgws7Vm1mtmvcM6VOvuAOSk2vAPmFmnJGW/B8ut6O5d7l5y91Kr2qrcHYC8VRv+HZLWZLfX\nSLo3n3YANErF8JvZ7ZIekvQRM9trZldJ2iTpYjN7TtKfZvcBTCIVr9vv7qvLlBiwz8nI/ldr2n74\nwPSqt/3oZ55K1l+5uSX9AEdHqt43isUn/ICg
CD8QFOEHgiL8QFCEHwiK8ANBMUX3FHD6tc+WrV15\nZnpE9j9P6U7WL/jU1cn67DsfTtbRvDjzA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQjPNPAalpsl/9\n8unJbf9vx1vJ+nXXb0/W/2bV5cm6//w9ZWvz/umh5LZq4PTxEXHmB4Ii/EBQhB8IivADQRF+ICjC\nDwRF+IGgKk7RnSem6G4+Q58/N1m/9evfSNYXTJtR9b4/un1dsr7olv5k/cievqr3PVXlPUU3gCmI\n8ANBEX4gKMIPBEX4gaAIPxAU4QeCqjjOb2ZbJF0qadDdF2fLNkr6oqRXstU2uPt9lXbGOP/k4+ct\nSdZP3LQ3Wb/9Qz+qet+n/eQLyfpH/qH8dQwkaeS5PVXve7LKe5x/q6QV4yz/lrsvyX4qBh9Ac6kY\nfnd/UNJQA3oB0EC1vOZfZ2a7zWyLmc3JrSMADVFt+G+WtFDSEkn9kr5ZbkUzW2tmvWbWO6xDVe4O\nQN6qCr+7D7j7iLsflXSLpKWJdbvcveTupVa1VdsngJxVFX4z6xxz93JJT+TTDoBGqXjpbjO7XdKF\nkuaa2V5JX5d0oZktkeSS+iR9qY49AqgDvs+PmrR0nJSsv3TFqWVrPdduTm77rgpPTD/z4vJk/fVl\nrybrUxHf5wdQEeEHgiL8QFCEHwiK8ANBEX4gKIb6UJjv7U1P0T3Tpifrv/HDyfqlX72m/GPf05Pc\ndrJiqA9ARYQfCIrwA0ERfiAowg8ERfiBoAg/EFTF7/MjtqPL0pfufuFT6Sm6Fy/pK1urNI5fyY1D\nZyXrM+/trenxpzrO/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOP8U5yVFifrz34tPdZ+y3nbkvXz\nZ6S/U1+LQz6crD88tCD9AEf7c+xm6uHMDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBVRznN7N5krZL\n6pDkkrrcfbOZtUu6U9J8SX2SVrn7r+vXalzTFpySrL9w5QfK1jZecUdy20+esL+qnvKwYaCUrD+w\n+Zxkfc629HX/kTaRM/8RSevd/QxJ50i62szOkHSdpG53XySpO7sPYJKoGH5373f3ndntg5KelnSy\npJWSjn38a5uky+rVJID8HddrfjObL+ksST2SOtz92OcnX9boywIAk8SEw29mJ0j6gaRr3P3A2JqP\nTvg37qR/ZrbWzHrNrHdYh2pqFkB+JhR+M2vVaPBvdfe7s8UDZtaZ1TslDY63rbt3uXvJ3Uutasuj\nZwA5qBh+MzNJ35H0tLvfMKa0Q9Ka7PYaSffm3x6AepnIV3rPk/RZSY+b2a5s2QZJmyR9z8yukvRL\nSavq0+LkN23+Hybrr/9xZ7J+xT/+MFn/8/fenazX0/r+9HDcQ/9efjivfev/Jredc5ShvHqqGH53\n/5mkcvN9X5RvOwAahU/4AUERfiAowg8ERfiBoAg/EBThB4Li0t0TNK3zD8rWhrbMSm775QUPJOur\nZw9U1VMe1u1blqzvvDk9Rffc7z+RrLcfZKy+WXHmB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgwozz\nH/6z9GWiD//lULK+4dT7ytaWv/vNqnrKy8DIW2Vr5+9Yn9z2tL/7RbLe/lp6nP5osopmxpkfCIrw\nA0ERfiAowg8ERfiBoAg/EBThB4IKM87fd1n679yzZ95Vt33f9NrCZH3zA8uTdRspd+X0Uadd/2LZ\n2qKBnuS2I8kqpjLO/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QlLl7egWzeZK2S+qQ5JK63H2zmW2U\n9EVJr2SrbnD38l96l3SitfvZxqzeQL30eLcO+FD6gyGZiXzI54ik9e6+08xmS3rUzO7Pat9y929U\n2yiA4lQMv7v3S+rPbh80s6clnVzvxgDU13G95jez+ZLOknTsM6PrzGy3mW0xszlltllrZr1m1jus\nQzU1CyA/
Ew6/mZ0g6QeSrnH3A5JulrRQ0hKNPjP45njbuXuXu5fcvdSqthxaBpCHCYXfzFo1Gvxb\n3f1uSXL3AXcfcfejkm6RtLR+bQLIW8Xwm5lJ+o6kp939hjHLO8esdrmk9HStAJrKRN7tP0/SZyU9\nbma7smUbJK02syUaHf7rk/SlunQIoC4m8m7/zySNN26YHNMH0Nz4hB8QFOEHgiL8QFCEHwiK8ANB\nEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiCoipfuznVnZq9I+uWYRXMl7W9YA8enWXtr\n1r4keqtWnr2d4u7vn8iKDQ3/O3Zu1uvupcIaSGjW3pq1L4neqlVUbzztB4Ii/EBQRYe/q+D9pzRr\nb83al0Rv1Sqkt0Jf8wMoTtFnfgAFKST8ZrbCzJ4xs+fN7LoieijHzPrM7HEz22VmvQX3ssXMBs3s\niTHL2s3sfjN7Lvs97jRpBfW20cz2Zcdul5ldUlBv88zsJ2b2lJk9aWZ/kS0v9Ngl+irkuDX8ab+Z\ntUh6VtLFkvZKekTSand/qqGNlGFmfZJK7l74mLCZnS/pDUnb3X1xtuxfJQ25+6bsD+ccd7+2SXrb\nKOmNomduziaU6Rw7s7SkyyR9TgUeu0Rfq1TAcSvizL9U0vPuvsfdD0u6Q9LKAvpoeu7+oKShty1e\nKWlbdnubRv/zNFyZ3pqCu/e7+87s9kFJx2aWLvTYJfoqRBHhP1nSr8bc36vmmvLbJf3YzB41s7VF\nNzOOjmzadEl6WVJHkc2Mo+LMzY30tpmlm+bYVTPjdd54w++dlrn7xyV9QtLV2dPbpuSjr9maabhm\nQjM3N8o4M0v/TpHHrtoZr/NWRPj3SZo35v4Hs2VNwd33Zb8HJd2j5pt9eODYJKnZ78GC+/mdZpq5\nebyZpdUEx66ZZrwuIvyPSFpkZgvMbLqkT0vaUUAf72Bms7I3YmRmsyQtV/PNPrxD0prs9hpJ9xbY\ny+9plpmby80srYKPXdPNeO3uDf+RdIlG3/F/QdLfFtFDmb4+JOmx7OfJonuTdLtGnwYOa/S9kask\nvU9St6TnJP23pPYm6u27kh6XtFujQessqLdlGn1Kv1vSruznkqKPXaKvQo4bn/ADguINPyAowg8E\nRfiBoAg/EBThB4Ii/EBQhB8IivADQf0/sEWOix6VKakAAAAASUVORK5CYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7fdde7991f60>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Show the first training digit with its label. A gray colormap suits the single-channel\n",
"# image, and the trailing ';' suppresses the text repr of the returned object.\n",
"plt.imshow(train_x[0][0], cmap='gray')\n",
"plt.title('label: %d' % train_y[0]);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create the CNN model\n",
"\n",
"TODO: plot the net structure"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"('conv1', (32, 14, 14))\n",
"('relu1', (32, 14, 14))\n",
"('conv2', (32, 7, 7))\n",
"('relu2', (32, 7, 7))\n",
"('pool', (32, 4, 4))\n",
"('flat', (512,))\n",
"('dense', (10,))\n"
]
},
{
"data": {
"text/plain": [
"<singa.layer.Dense at 0x7fddc5ca3b00>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from singa import layer, loss, metric, optimizer\n",
"from singa import net as ffnet\n",
"from singa.layer import Activation, Conv2D, Dense, Dropout, Flatten, MaxPooling2D\n",
"\n",
"# Select the 'singacpp' engine for the layers constructed below.\n",
"layer.engine = 'singacpp'\n",
"\n",
"# conv1 -> relu1 -> conv2 -> relu2 -> pool -> flat -> dense (10 outputs);\n",
"# net.add prints each layer's output shape as it is appended.\n",
"net = ffnet.FeedForwardNet(loss.SoftmaxCrossEntropy(), metric.Accuracy())\n",
"for lyr in [Conv2D('conv1', 32, 3, 2, input_sample_shape=(1, 28, 28)),\n",
"            Activation('relu1'),\n",
"            Conv2D('conv2', 32, 3, 2),\n",
"            Activation('relu2'),\n",
"            MaxPooling2D('pool', 3, 2),\n",
"            Flatten('flat')]:\n",
"    net.add(lyr)\n",
"net.add(Dense('dense', 10))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Initialize the parameters\n",
"\n",
"* weight matrices - Gaussian distribution\n",
"* biases - 0"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"conv1/weight (32, 9) 0.07648436725139618\n",
"conv1/bias (32,) 0.0\n",
"conv2/weight (32, 288) 0.08030246943235397\n",
"conv2/bias (32,) 0.0\n",
"dense/weight (512, 10) 0.07954108715057373\n",
"dense/bias (10,) 0.0\n"
]
}
],
"source": [
"# 1-D parameters (the biases here) start at zero; all other parameters are drawn from\n",
"# a Gaussian with mean 0 and scale 0.1. Print each parameter's name, shape and l1() value.\n",
"for pname, pval in zip(net.param_names(), net.param_values()):\n",
"    if len(pval.shape) <= 1:\n",
"        pval.set_value(0)\n",
"    else:\n",
"        pval.gaussian(0, 0.1)\n",
"    print(pname, pval.shape, pval.l1())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Set up the optimizer and tensors"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from singa import tensor\n",
"from singa import device\n",
"from singa import utils\n",
"\n",
"cpu = device.get_default_device()\n",
"\n",
"opt = optimizer.SGD(momentum=0.9, weight_decay=1e-4)\n",
"batch_size = 32\n",
"# Floor division: any trailing partial batch is skipped in each epoch.\n",
"num_train_batch = train_x.shape[0] // batch_size\n",
"\n",
"# Reusable input/label buffers that each mini-batch is copied into,\n",
"# both placed explicitly on the same device.\n",
"tx = tensor.Tensor((batch_size, 1, 28, 28), cpu)\n",
"ty = tensor.Tensor((batch_size,), cpu, tensor.int32)\n",
"\n",
"# Sample indices; the training loop shuffles them to visit mini-batches in random order.\n",
"# (tnrange for the progress bar is imported in the first code cell.)\n",
"idx = np.arange(train_x.shape[0], dtype=np.int32)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conduct SGD\n",
"\n",
"1. Process the training data multiple times; each full pass over the data is called an epoch.\n",
"2. In each epoch, read the data as mini-batches in random order.\n",
"3. For each mini-batch, run back-propagation (BP) and update the parameters."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b9a8c3eff6744b90a2be87729cb60cd5"
}
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch = 0, training loss = 0.291366, training accuracy = 0.907370\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4cca5a2fff6b47859adafae922ea20d9"
}
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch = 1, training loss = 0.111163, training accuracy = 0.965089\n"
]
}
],
"source": [
"for epoch in range(2):\n",
"    np.random.shuffle(idx)\n",
"    # Per-epoch running sums. They are not named `loss`, which would shadow the singa\n",
"    # `loss` module imported when the net was built.\n",
"    train_loss, train_acc = 0.0, 0.0\n",
"\n",
"    bar = tnrange(num_train_batch, desc='Epoch %d' % epoch)\n",
"    for b in bar:\n",
"        x = train_x[idx[b * batch_size: (b + 1) * batch_size]]\n",
"        y = train_y[idx[b * batch_size: (b + 1) * batch_size]]\n",
"        tx.copy_from_numpy(x)\n",
"        ty.copy_from_numpy(y)\n",
"        grads, (l, a) = net.train(tx, ty)\n",
"        train_loss += l\n",
"        train_acc += a\n",
"        # Update every parameter with a fixed learning rate of 0.01.\n",
"        for (s, p, g) in zip(net.param_names(), net.param_values(), grads):\n",
"            opt.apply_with_lr(epoch, 0.01, g, p, str(s), b)\n",
"        # update progress bar\n",
"        bar.set_postfix(train_loss=l, train_accuracy=a)\n",
"    print('Epoch = %d, training loss = %f, training accuracy = %f'\n",
"          % (epoch, train_loss / num_train_batch, train_acc / num_train_batch))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Save model to disk"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Write the trained net to disk under the name reloaded by the following cells.\n",
"checkpoint_path = 'checkpoint'\n",
"net.save(checkpoint_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load model from disk"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"NOTE: If your model was saved using pickle, then set use_pickle=True for loading it\n"
]
}
],
"source": [
"# Zero every parameter first, so any values seen afterwards must come from the checkpoint.\n",
"for param in net.param_values():\n",
"    param.set_value(0)\n",
"net.load('checkpoint')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Do prediction"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from PIL import Image\n",
"\n",
"# Load the sample digit as 8-bit grayscale ('L') and resize it to the 28x28 input size;\n",
"# the `with` block closes the image file once the converted copy has been made.\n",
"with Image.open('static/digit.jpg') as raw_img:\n",
"    img = raw_img.convert('L').resize((28, 28))\n",
"# Scale pixel values from [0, 255] to [0, 1].\n",
"img = np.array(img, dtype=np.float32) / 255\n",
"img = tensor.from_numpy(img)\n",
"# Add the batch and channel axes: (1, 1, 28, 28). The return value is discarded,\n",
"# so this relies on Tensor.reshape acting in place.\n",
"img.reshape((1, 1, 28, 28))\n",
"y = net.predict(img)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7fddc4b29240>]"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAG7xJREFUeJzt3WtsY2d+3/Hvn9RtJNJzk4Zaz4xnxrbIZjZIuhvBu80C\n7aLrBva2tV+0KWw0vQSL+E2cbppFCqct3MJ9labYXgA3rZOmadNkHa+7KAbptA7QbNug6C483k22\nazuk5PHYM7MjSnM1KY2u/PcFeTQaWhdKOuQhz/l9AMMidYb8gxj99Mzz/M/zmLsjIiLxkoq6ABER\nCZ/CXUQkhhTuIiIxpHAXEYkhhbuISAwp3EVEYkjhLiISQwp3EZEYUriLiMRQX1RvPDo66qdPn47q\n7UVEetJbb7113d3HdrousnA/ffo0Fy5ciOrtRUR6kpl90Mp1mpYREYkhhbuISAwp3EVEYkjhLiIS\nQwp3EZEY2jHczew3zGzWzL6/xffNzP6VmU2b2ffM7NPhlykiIrvRysj9N4Entvn+k8BE47/ngF/d\nf1kiIrIfO4a7u/9v4OY2lzwN/Eev+xZwyMw+EVaBIr2murTK1y9cRkdYSpTCmHM/Dlze8PhK47mP\nMbPnzOyCmV2Ym5sL4a1Fus83vnOFX3z9e7xz7aOoS5EE6+iCqru/4u6T7j45Nrbj3bMiPelPZioA\nlMqViCuRJAsj3K8CJzc8PtF4TiSRSo1wL85UI65EkiyMcD8H/M1G18xngTvufi2E1xXpOe5OsayR\nu0Rvx43DzOxrwOeBUTO7AvwjoB/A3f8NcB74IjANLAA/3a5iRbrdzEeLVBZXSaeM4ozCXaKzY7i7\n+7M7fN+Bnw2tIpEeVirXp2J+/JGj/OHUdapLq2QGI9t8VRJMd6iKhCiYb//LP/IgAFOampGIKNxF\nQlQsVxjLDvLYmSOA5t0lOgp3kRCVyhUKuSwnjwwz1J9Sx4xERuEuEpJazZkqV8nnsqRTxsSxrEbu\nEhmFu0hIrty6y92VNQrjGQDyOYW7REfhLhKSoL89n8sCUBjPMFtZ4tb8cpRlSUIp3EVCEozSJxrh\nHvxfo3eJgsJdJCTFmQrHDx1Y72svKNwlQgp3kZCUyhUK49n1x584OER2sG99ukakkxTuIiFYWatx\ncW5+fb4dwMzIj2fX71oV6SSFu0gIPrgxz/Jabb1TJhB0zOjgDuk0hbtICIKblSaOZe97Pp/LcHth\nhbnKUhRlSYIp3EVCUCxXSBk8euz+kXuwqKp5d+k0hbtICKbKFU4fHWGoP33f8/nxoGNG8+7SWQp3\nkRAUy5X7FlMDo5lBjo4MrO8WKdIpCneRfVpcWePS9fn1UXqzfC6raRnpOIW7yD69N1el5vXF083k\ncxmmyhVqNXXMSOco3EX2KbgDtbDJtAzU593nl9e4evtuJ8uShFO4i+xTqVylP22cHh3Z9PtB6E/N\nampGOkfhLrJPpZkKj4xl6E9v/uMUbCCmgzukkxTuIvu0VadM4OCBfj5xcEgbiElHKdxF9qG6tMqV\nW3e3XEwNTOSyFNUOKR2kcBfZh6mmAzq2UshlmJ6rsqaOGekQhbvIPkw17jwtbNHjHsjnsiyv1vjg\nxnwnyhJRuIvsR7FcYag/xcnDw9teVxjXwR3SWQp3kX0oNRZTUynb9rpHj2UwU8eMdI7CXWQfijOV\nj23zu5nhgT5OHh7WyF06RuEuske3F5aZrSx97ICOrWiPGekkhbvIHgXb+O7UKRMojGe4dH2epdW1\ndpYlAijcRfYsGIXv1CkTyOeyrNac96+rY0bar6VwN7MnzKxoZtNm9sIm33/IzL5pZt81s++Z2RfD\nL1Wku5RmKmQH+xh/YKil64NfArqZSTphx3A3szTwMvAkcBZ41szONl32D4HX3P1TwDPAvw67UJFu\nUyxXyI9nMdu+UyZwZnSEdMq0qCod0crI/TFg
2t0vuvsy8CrwdNM1DjzQ+Pog8IPwShTpPu6+3gbZ\nqsG+NGdGR9QOKR3R18I1x4HLGx5fAT7TdM0/Bn7fzH4OGAEeD6U6kS41V13i9sIKhR32lGlWyGX5\n/g/utKkqkXvCWlB9FvhNdz8BfBH4LTP72Gub2XNmdsHMLszNzYX01iKdV2qMvrc6Wm8r+VyWD28u\nsLC82o6yRNa1Eu5XgZMbHp9oPLfRl4DXANz9/wJDwGjzC7n7K+4+6e6TY2Nje6tYpAsUW9wwrFlh\nPIM7TM9qakbaq5VwfxOYMLMzZjZAfcH0XNM1HwJfADCzH6Ie7hqaS2yVZiocHRlgNDO4qz937+AO\nLapKe+0Y7u6+CjwPvAG8S70r5m0ze8nMnmpc9hXgZ8zsj4GvAX/b3bW3qcTWTgd0bOXUkWEG+lLq\nmJG2a2VBFXc/D5xveu7FDV+/A3wu3NJEupO7M1Wu8JOTJ3e+uElfOsWjY5n1u1tF2kV3qIrs0tXb\nd5lfXtvTyB3qNzNp5C7tpnAX2aXS+mLq7togA/lclmt3FrlzdyXMskTuo3AX2aXgJqSJPY7cg18K\nUxq9Sxsp3EV2qVSu8ImDQxw80L+nPx9M52j7X2knhbvILu1224Fmxw8dYGQgvX7+qkg7KNxFdmGt\n5kzNVlve5nczqZQxkcuq113aSuEusgsf3JhnebXGxLG9LaYGCjl1zEh7KdxFdqG0ywM6tjKRy3Bj\nfpnr1aUwyhL5GIW7yC4UZ6qYwaP7Hbk3fjmUNDUjbaJwF9mF0myFh44MMzzQ0s3dWyo0FmQ1NSPt\nonAX2YXSzP46ZQJj2UEODfdTVMeMtInCXaRFS6trvH99fs93pm5kZuS1qCptpHAXadH71+dZrXko\nI3eo36lamqmgDVSlHRTuIi0K+tL32ykTKOSyVJZWuXZnMZTXE9lI4S7Soqlylb6U8fDo/qdl4N42\nBJqakXZQuIu0qFiucGZ0hIG+cH5sFO7STgp3kRbtd0+ZZodHBjiWHVzfZVIkTAp3kRYsLK/y4c2F\nUMMdUMeMtI3CXaQF07NV3KEwHs58eyCfyzI1W6FWU8eMhEvhLtKC4MzTsEfuhfEMiys1Lt9aCPV1\nRRTuIi0olSsM9KU4dXQk1NddP7hDe8xIyBTuIi0ozlR4dCxDOmWhvu6EOmakTRTuIi0olSuh3by0\nUWawj+OHDmiPGQmdwl1kBx8trnDtzmLo8+2BwnhWh2VL6BTuIjuYWj+gI9xOmUA+l+W9uSora7W2\nvL4kk8JdZAfBTUbtG7lnWFlzLl2fb8vrSzIp3EV2UCpXGBlIc/zQgba8/nrHjKZmJEQKd5EdFGcq\nTOSymIXbKRN4ZCxDynTknoRL4S6yg6nZyvqxeO0w1J/m9NGR9RulRMKgcBfZxvXqEtery+Tb0Aa5\nkfaYkbAp3EW2EQRuO0fuAPnxLJduzLO4stbW95HkaCnczewJMyua2bSZvbDFNX/NzN4xs7fN7HfC\nLVMkGsE8eBjnpm6nkMtS8/oGZSJh2DHczSwNvAw8CZwFnjWzs03XTAC/BHzO3T8J/HwbahXpuGK5\nyqHhfsayg219n+CXh6ZmJCytjNwfA6bd/aK7LwOvAk83XfMzwMvufgvA3WfDLVMkGlONAzra1SkT\nOD06Qn/atKgqoWkl3I8Dlzc8vtJ4bqM8kDez/2Nm3zKzJzZ7ITN7zswumNmFubm5vVUs0iHuTrHc\n3k6ZQH86xSNjGY3cJTRhLaj2ARPA54FngV8zs0PNF7n7K+4+6e6TY2NjIb21SHvMfLRIZXG17Z0y\ngXwuq61/JTSthPtV4OSGxycaz210BTjn7ivu/j5Qoh72Ij0rCNr8sfYupgYK41mu3r5LZXGlI+8n\n8dZKuL8JTJjZGTMbAJ4BzjVd81+oj9oxs1Hq0zQXQ6xTpOOCKZJ27SnTbKLxS2RKHTMSgh3D3d1X\ngeeBN4B3
gdfc/W0ze8nMnmpc9gZww8zeAb4J/KK732hX0SKdUCpXOZYd5PDIQEfeL9gvXtv/Shj6\nWrnI3c8D55uee3HD1w78QuM/kVho1wEdWzl5eJih/tT6LpQi+6E7VEU2Uas5pXKFiWOdC/dUyrQN\ngYRG4S6yicu3FlhcqbXtgI6t5HNZbf0roVC4i2xivVOmQ4upgUIuy1xliVvzyx19X4kfhbvIJoKO\nlYkOh/uEtiGQkCjcRTZRnKlw4vABMoMt9RyEJljAVbjLfincRTZRauwp02njDwyRHerTvLvsm8Jd\npMnKWo335qqRhLuZUchlKakdUvZJ4S7S5NL1eVbWvOOdMoH8eL1jpn77iMjeKNxFmgTb7kYxcof6\nXjZ37q4wV1mK5P0lHhTuIk2K5Qopg0fGohu5B3WI7JXCXaRJaabC6aMjDPWnI3n/YP94bf8r+6Fw\nF2kSVadM4GhmkNHMgNohZV8U7iIbLK6scenGfMcO6NhKfRsCdczI3incRTZ4b65KzenI0Xrbyeey\nTJcr1GrqmJG9UbiLbBBMhUTVBhnI57LML69x9fbdSOuQ3qVwF9mgOFOlP22cOjoSaR3BLxfNu8te\nKdxFNiiVKzwylqE/He2PRrBhmdohZa8U7iIbFGei7ZQJPDDUz4MHhyipHVL2SOEu0lBdWuXq7bsd\nPVpvOxO57PrdsiK7pXAXaQgOpu6GkTvUt/+dnquyulaLuhTpQQp3kYbSerhH2ykTyOeyLK/W+ODm\nQtSlSA9SuIs0FGeqDPWnOHl4OOpSgHu99pp3l71QuIs0BNsOpFIWdSkAPHosg5k6ZmRvFO4iDVHv\nKdPswECah44MM6VFVdkDhbsIcGt+mdnKUuTbDjSr7zGjkbvsnsJdhHuLqRNdspgaKOSyvH99nqXV\ntahLkR6jcBdh454yXTZyH8+yVnMuzs1HXYr0GIW7CPWj9bJDfYw/MBR1KfdZ75jR1IzsksJdhHpH\nSiGXxaw7OmUCZ0ZH6EuZwl12TeEuiefu9U6ZLpuSARjoS3FmdITijDpmZHdaCncze8LMimY2bWYv\nbHPdXzEzN7PJ8EoUaa+5yhK3F1bIH+uuxdRAfjyrkbvs2o7hbmZp4GXgSeAs8KyZnd3kuizwZeDb\nYRcp0k5Bq2E3jtyhPu/+4c0FFpZXoy5FekgrI/fHgGl3v+juy8CrwNObXPdPgF8GFkOsT6Ttgp0X\nu63HPRDcWDU9q6kZaV0r4X4cuLzh8ZXGc+vM7NPASXf/ryHWJtIRpZkKo5kBjmYGoy5lU8FGZkXt\nMSO7sO8FVTNLAV8FvtLCtc+Z2QUzuzA3N7fftxYJRbHLth1oduroCAN9Kc27y660Eu5XgZMbHp9o\nPBfIAj8M/E8zuwR8Fji32aKqu7/i7pPuPjk2Nrb3qkVCUqs5U10e7umUMXEsQ1F7zMgutBLubwIT\nZnbGzAaAZ4BzwTfd/Y67j7r7aXc/DXwLeMrdL7SlYpEQXb19l/nlta4Od6ivB2jrX9mNHcPd3VeB\n54E3gHeB19z9bTN7ycyeaneBIu00NRtsO9CdbZCB/HiWmY8WuXN3JepSpEf0tXKRu58Hzjc99+IW\n135+/2WJdEZwc9BEl4/cg0XVqXKFydNHIq5GeoHuUJVEK5UrPHhwiAeG+qMuZVvBtJG2/5VWKdwl\n0Yozla4ftQMcP3SAkYG05t2lZQp3Say1mjM9V+26bX43Y2bkx3Vwh7RO4S6J9cGNeZZXa13fKRMo\n5LI6ck9apnCXxFo/oKNHwn0il+XG/DLXq0tRlyI9QOEuiVWcqWIGj3bpbpDN1g/u0Ly7tEDhLolV\nKld46MgwBwbSUZfSknyjF1/z7tIKhbskVrfvKdNsLDPI4eF+7TEjLVG4SyItra5x6fp8z8y3Q6Nj\nJpdd36JYZDsKd0mk96/Ps1rzrj2gYyv5xh4z7h51KdLlFO6SSMHe6MFt/b
0iP56lsrTKtTs6E0e2\np3CXRCqVK/SljIdHeyvcC9qGQFqkcJdEKs5UOTNaPwSjlwT/0lA7pOykt/5mi4RkarbSc/PtAIeG\nB8g9MKhFVdmRwl0SZ2F5lQ9vLvRUp8xG9Y4Zjdxlewp3SZzp2SruvbeYGsjnskzNVlirqWNGtqZw\nl8S51ynTmyP3Qi7L4kqNyzcXoi5FupjCXRKnVK4w0Jfi1NGRqEvZk2CtQB0zsh2FuyROqVxl4liG\ndMqiLmVPJo7dO3JPZCsKd0mcUrnSs4upACODfZw4fICiOmZkGwp3SZQ7d1e4dmexJ47W206hsQ2B\nyFYU7pIowVRGYbw3O2UC+fEs781VWV6tRV2KdCmFuyRKsAjZq50ygUIuy2rNuXRjPupSpEsp3CVR\npspVRgbSHD90IOpS9iX45aSbmWQrCndJlOJMfdsBs97slAk8PDZCyrTHjGxN4S6JUipXyB/r7SkZ\ngKH+NKdHR9TrLltSuEtiXK8ucWN+uSc3DNtMQacyyTYU7pIYwRRGL/e4b5TPZbl0Y57FlbWoS5Eu\npHCXxAgWH/M93gYZKIxnca9vhCbSTOEuiVEsVzk83M9YZjDqUkKxfnCH5t1lEwp3SYxSucJErvc7\nZQKnjo4wkE5pUVU21VK4m9kTZlY0s2kze2GT7/+Cmb1jZt8zs/9hZqfCL1Vk79yd0kxv7ynTrD+d\n4uGxEbVDyqZ2DHczSwMvA08CZ4Fnzexs02XfBSbd/UeA14F/GnahIvtx7c4ilaXV2HTKBArj6piR\nzbUycn8MmHb3i+6+DLwKPL3xAnf/prsHJwd8CzgRbpki+xPMS8dp5A71jpmrt+9SWVyJuhTpMq2E\n+3Hg8obHVxrPbeVLwH/b7Btm9pyZXTCzC3Nzc61XKbJP650yPXq03laCbQim1DEjTUJdUDWznwIm\ngV/Z7Pvu/oq7T7r75NjYWJhvLbKt4kyVY9lBDg0PRF1KqIJ/iWjeXZr1tXDNVeDkhscnGs/dx8we\nB/4B8OfcfSmc8kTCUSpXKMRsvh3gxOEDHOhPq2NGPqaVkfubwISZnTGzAeAZ4NzGC8zsU8C/BZ5y\n99nwyxTZu1rNmZqt9Pw2v5tJpYx8LsOUFlWlyY7h7u6rwPPAG8C7wGvu/raZvWRmTzUu+xUgA3zd\nzP7IzM5t8XIiHXf51gKLK7XYLaYG8rmsRu7yMa1My+Du54HzTc+9uOHrx0OuSyQ0xZlg24H4hvvX\n37rCzflljozEa01B9k53qErsBZ0yE8fi1SkTCH5paRsC2UjhLrFXLFc5cfgAI4Mt/UO15xR0KpNs\nQuEusTdVjte2A81yDwzywFCfwl3uo3CXWFtZq/HeXDW28+0AZlbfhmBGHTNyj8JdYu3S9XlW1jzW\nI3eAiUbHjLtHXYp0CYW7xFrQIjgRs20HmhVyWe7cXWG2ovsHpU7hLrFWmqmQMnhkLN7hHtygVdQ2\nBNKgcJdYK5WrnB4dYag/HXUpbaVTmaSZwl1irRTzTpnA0cwgo5lBhbusU7hLbC2urHHpxnws95TZ\nTD6Xoag9ZqRB4S6xNT1bpeYkKNyzTJUr1GrqmBGFu8TY+ulL4/FeTA0UxrMsLK9x9fbdqEuRLqBw\nl9gqlasMpFOcOjoSdSkdkdc2BLKBwl1iq1Su8PDYCP3pZPw1DzpmtP2vgMJdYqw4E88DOraSHern\nwYNDOnJPAIW7xFRlcYWrt+/G8mi97eTHs+qYEUDhLjE1NVsPuCSN3KG+DcF7s1VW12pRlyIRU7hL\nLE0FnTIJC/d8LsvyWo0Pbi5EXYpETOEusVScqXKgP82JwweiLqWjgmkozbuLwl1iqVSuMJHLkEpZ\n1KV01CNjGczUMSMKd4mpYjlZnTKBAwNpTh0ZVq+7KNwlfm7NLzNXWUrcfHsgn8tq619RuEv8BKPW\nOB+tt53CeJZLNxZYWl2LuhSJkMJdYq
eU0E6ZQD6XZa3mXJybj7oUiZDCXWKnWK6QHeoj98Bg1KVE\nQnvMCCjcJYZKM1UKuSxmyeqUCZwZHaEvZZp3TziFu8SKu9c7ZRI63w4w0Jfi4bERjdwTTuEusTJX\nWeLO3ZXEzrcH8rksJe0xk2gKd4mV4OadJPa4b1TIZfnw5gILy6tRlyIRUbhLrATzzMHe5kk10fjl\nNqXRe2K1FO5m9oSZFc1s2sxe2OT7g2b2u43vf9vMToddqEgrSuUKo5kBjmaS2SkTCPaY0TYEybVj\nuJtZGngZeBI4CzxrZmebLvsScMvdHwX+OfDLYRcq0opiuZr4KRmAh44MM9iX0gZiCdbKyP0xYNrd\nL7r7MvAq8HTTNU8D/6Hx9evAFyypfWgSmVrNmU7onjLN0iljIpehNKtpmaTqa+Ga48DlDY+vAJ/Z\n6hp3XzWzO8BR4HoYRW702puX+bU/vBj2y/Ysj7qALlJzZ355LXGnL20ln8vye398jb/w1f8VdSld\nIayfFff9v9KXH8/z1I8+GEI1W2sl3ENjZs8BzwE89NBDe3qNQ8P9TCR8sayZoX8kBf70iUN84YeO\nRV1GV/jrnznF0motlDCKi9B+Vvb5MoeH+8OpYxuthPtV4OSGxycaz212zRUz6wMOAjeaX8jdXwFe\nAZicnNzT37if+OQ4P/HJ8b38UZFE+bFTh/mxU4ejLkMi0sqc+5vAhJmdMbMB4BngXNM154C/1fj6\nrwJ/4BouiIhEZseRe2MO/XngDSAN/Ia7v21mLwEX3P0c8O+A3zKzaeAm9V8AIiISkZbm3N39PHC+\n6bkXN3y9CPxkuKWJiMhe6Q5VEZEYUriLiMSQwl1EJIYU7iIiMaRwFxGJIYuqHd3M5oAP9vjHR2nD\n1gY9TJ/H/fR53KPP4n5x+DxOufvYThdFFu77YWYX3H0y6jq6hT6P++nzuEefxf2S9HloWkZEJIYU\n7iIiMdSr4f5K1AV0GX0e99PncY8+i/sl5vPoyTl3ERHZXq+O3EVEZBs9F+47HdadFGZ20sy+aWbv\nmNnbZvblqGvqBmaWNrPvmtnvRV1L1MzskJm9bmZ/YmbvmtmfibqmqJjZ3238nHzfzL5mZkNR19Ru\nPRXuLR7WnRSrwFfc/SzwWeBnE/xZbPRl4N2oi+gS/xL47+7+p4AfJaGfi5kdB/4OMOnuP0x96/LY\nb0veU+FOa4d1J4K7X3P37zS+rlD/wT0ebVXRMrMTwF8Efj3qWqJmZgeBP0v9rAXcfdndb0dbVaT6\ngAONk+KGgR9EXE/b9Vq4b3ZYd6IDDcDMTgOfAr4dbSWR+xfA3wNqURfSBc4Ac8C/b0xT/bqZjURd\nVBTc/Srwz4APgWvAHXf//Wirar9eC3dpYmYZ4D8DP+/uH0VdT1TM7C8Bs+7+VtS1dIk+4NPAr7r7\np4B5IJFrVGZ2mPq/8M8ADwIjZvZT0VbVfr0W7q0c1p0YZtZPPdh/292/EXU9Efsc8JSZXaI+Xffn\nzew/RVtSpK4AV9w9+Nfc69TDPokeB9539zl3XwG+Afx4xDW1Xa+FeyuHdSeCmRn1+dR33f2rUdcT\nNXf/JXc/4e6nqf+9+AN3j/3obCvuPgNcNrNC46kvAO9EWFKUPgQ+a2bDjZ+bL5CAxeWWzlDtFlsd\n1h1xWVH5HPA3gP9nZn/UeO7vN867FQH4OeC3GwOhi8BPR1xPJNz922b2OvAd6l1m3yUBd6rqDlUR\nkRjqtWkZERFpgcJdRCSGFO4iIjGkcBcRiSGFu4hIDCncRURiSOEuIhJDCncRkRj6/5J3/U/FAH0D\nAAAAAElFTkSuQmCC\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7fdde5663da0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"prob = tensor.to_numpy(y)[0]\n",
"# Network output for the sample digit, one value per class 0-9; the trailing ';'\n",
"# suppresses the text repr of the last matplotlib call.\n",
"plt.plot(list(range(10)), prob)\n",
"plt.xlabel('digit class')\n",
"plt.ylabel('predicted probability');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Debug\n",
"\n",
"Print the L1 norms of the parameters and layer features to diagnose problems with:\n",
"\n",
"1. parameter initialization\n",
"2. learning rate\n",
"3. weight decay\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"conv1/weight (32, 9) 7.971656799316406\n",
"conv1/bias (32,) 0.0\n",
"conv2/weight (32, 288) 8.005664825439453\n",
"conv2/bias (32,) 0.0\n",
"dense/weight (512, 10) 7.921195983886719\n",
"dense/bias (10,) 0.0\n",
"\n",
"\n",
"Epoch 0\n",
"-->conv1: 4.059905\n",
"conv1-->relu1: 2.407956\n",
"relu1-->conv2: 620.319519\n",
"conv2-->relu2: 302.810760\n",
"relu2-->pool: 965.994873\n",
"pool-->flat: 965.994873\n",
"flat-->dense: 276680.062500\n",
"-->dense: 0.270273\n",
"dense-->flat: 0.270273\n",
"flat-->pool: 0.074453\n",
"pool-->relu2: 0.046011\n",
"relu2-->conv2: 8.166893\n",
"conv2-->relu1: 2.553801\n",
"relu1-->conv1: 299.891296\n",
"\n",
" loss = 68.231674, params\n",
"conv1/weight 9.855006217956543\n",
"conv1/bias 9.832422256469727\n",
"conv2/weight 8.0507173538208\n",
"conv2/bias 0.1285100281238556\n",
"dense/weight 8.218809127807617\n",
"dense/bias 0.0013595324708148837\n",
"\n",
"\n",
"Epoch 1\n",
"-->conv1: 17.634811\n",
"conv1-->relu1: 3.629616\n",
"relu1-->conv2: 1683.136475\n",
"conv2-->relu2: 389.035248\n",
"relu2-->pool: 934.496582\n",
"pool-->flat: 934.496582\n",
"flat-->dense: 1198527.500000\n",
"-->dense: 0.322575\n",
"dense-->flat: 0.322575\n",
"flat-->pool: 0.094647\n",
"pool-->relu2: 0.039437\n",
"relu2-->conv2: 8.781067\n",
"conv2-->relu1: 2.430810\n",
"relu1-->conv1: 502.457764\n",
"\n",
" loss = 79.148743, params\n",
"conv1/weight 14.95351791381836\n",
"conv1/bias 28.66775131225586\n",
"conv2/weight 8.543733596801758\n",
"conv2/bias 0.320060670375824\n",
"dense/weight 9.1849365234375\n",
"dense/bias 0.0028705699369311333\n",
"\n",
"\n",
"Epoch 2\n",
"-->conv1: 44.400776\n",
"conv1-->relu1: 8.777474\n",
"relu1-->conv2: 12810.666016\n",
"conv2-->relu2: 1705.242798\n",
"relu2-->pool: 2268.597656\n",
"pool-->flat: 2268.597656\n",
"flat-->dense: 4815859.500000\n",
"-->dense: 0.347976\n",
"dense-->flat: 0.347976\n",
"flat-->pool: 0.109863\n",
"pool-->relu2: 0.024067\n",
"relu2-->conv2: 7.023767\n",
"conv2-->relu1: 1.075155\n",
"relu1-->conv1: 418.399963\n",
"\n",
" loss = 76.419479, params\n",
"conv1/weight 29.47024154663086\n",
"conv1/bias 52.10609436035156\n",
"conv2/weight 10.518251419067383\n",
"conv2/bias 0.5825839042663574\n",
"dense/weight 12.961698532104492\n",
"dense/bias 0.004623417742550373\n",
"\n",
"\n",
"Epoch 3\n",
"-->conv1: 82.561668\n",
"conv1-->relu1: 2.029232\n",
"relu1-->conv2: 19918.304688\n",
"conv2-->relu2: 2221.255859\n",
"relu2-->pool: 3186.028076\n",
"pool-->flat: 3186.028076\n",
"flat-->dense: 8799418.000000\n",
"-->dense: 0.585852\n",
"dense-->flat: 0.585852\n",
"flat-->pool: 0.186150\n",
"pool-->relu2: 0.025595\n",
"relu2-->conv2: 10.546330\n",
"conv2-->relu1: 1.008846\n",
"relu1-->conv1: 2478.874512\n",
"\n",
" loss = 87.336548, params\n",
"conv1/weight 45.3760871887207\n",
"conv1/bias 122.97279357910156\n",
"conv2/weight 12.587639808654785\n",
"conv2/bias 0.9543725252151489\n",
"dense/weight 21.757226943969727\n",
"dense/bias 0.007790021598339081\n",
"\n",
"\n",
"Epoch 4\n",
"-->conv1: 181.790573\n",
"conv1-->relu1: 0.000030\n",
"relu1-->conv2: 1.017071\n",
"conv2-->relu2: 0.120676\n",
"relu2-->pool: 0.305297\n",
"pool-->flat: 0.305297\n",
"flat-->dense: 2401.528809\n",
"-->dense: 0.703576\n",
"dense-->flat: 0.703576\n",
"flat-->pool: 0.229087\n",
"pool-->relu2: 0.029813\n",
"relu2-->conv2: 13.549859\n",
"conv2-->relu1: 0.006924\n",
"relu1-->conv1: 4.986567\n",
"\n",
" loss = 68.686081, params\n",
"conv1/weight 59.74170684814453\n",
"conv1/bias 191.91458129882812\n",
"conv2/weight 14.60444450378418\n",
"conv2/bias 1.4008562564849854\n",
"dense/weight 30.463171005249023\n",
"dense/bias 0.010233680717647076\n",
"\n",
"\n",
"Epoch 5\n",
"-->conv1: 260.486359\n",
"conv1-->relu1: 0.000000\n",
"relu1-->conv2: 1.400860\n",
"conv2-->relu2: 0.062520\n",
"relu2-->pool: 0.062520\n",
"pool-->flat: 0.062520\n",
"flat-->dense: 296.831329\n",
"-->dense: 2.126958\n",
"dense-->flat: 2.126958\n",
"flat-->pool: 0.694517\n",
"pool-->relu2: 0.014730\n",
"relu2-->conv2: 6.709585\n",
"conv2-->relu1: 0.000000\n",
"relu1-->conv1: 0.000000\n",
"\n",
" loss = 65.246483, params\n",
"conv1/weight 72.80561828613281\n",
"conv1/bias 253.96200561523438\n",
"conv2/weight 16.481393814086914\n",
"conv2/bias 1.8957630395889282\n",
"dense/weight 38.40126419067383\n",
"dense/bias 0.014084763824939728\n",
"\n",
"\n",
"Epoch 6\n",
"-->conv1: 347.085114\n",
"conv1-->relu1: 0.000000\n",
"relu1-->conv2: 1.895745\n",
"conv2-->relu2: 0.013967\n",
"relu2-->pool: 0.013967\n",
"pool-->flat: 0.013967\n",
"flat-->dense: 33.670372\n",
"-->dense: 1.395184\n",
"dense-->flat: 1.395184\n",
"flat-->pool: 0.455570\n",
"pool-->relu2: 0.009508\n",
"relu2-->conv2: 3.835509\n",
"conv2-->relu1: 0.000000\n",
"relu1-->conv1: 0.000000\n",
"\n",
" loss = 38.581238, params\n",
"conv1/weight 84.56834411621094\n",
"conv1/bias 309.804443359375\n",
"conv2/weight 18.20405387878418\n",
"conv2/bias 2.4867684841156006\n",
"dense/weight 45.590518951416016\n",
"dense/bias 0.016868162900209427\n",
"\n",
"\n",
"Epoch 7\n",
"-->conv1: 416.296112\n",
"conv1-->relu1: 0.000000\n",
"relu1-->conv2: 2.486759\n",
"conv2-->relu2: 0.000000\n",
"relu2-->pool: 0.000000\n",
"pool-->flat: 0.000000\n",
"flat-->dense: 0.016868\n",
"-->dense: 1.538023\n",
"dense-->flat: 1.538023\n",
"flat-->pool: 0.502211\n",
"pool-->relu2: 0.000000\n",
"relu2-->conv2: 0.000000\n",
"conv2-->relu1: 0.000000\n",
"relu1-->conv1: 0.000000\n",
"\n",
" loss = 2.308411, params\n",
"conv1/weight 95.16199493408203\n",
"conv1/bias 360.06231689453125\n",
"conv2/weight 19.77170753479004\n",
"conv2/bias 3.043811798095703\n",
"dense/weight 52.087501525878906\n",
"dense/bias 0.01910635642707348\n",
"\n",
"\n",
"Epoch 8\n",
"-->conv1: 469.519379\n",
"conv1-->relu1: 0.000000\n",
"relu1-->conv2: 3.043824\n",
"conv2-->relu2: 0.000000\n",
"relu2-->pool: 0.000000\n",
"pool-->flat: 0.000000\n",
"flat-->dense: 0.019106\n",
"-->dense: 1.641877\n",
"dense-->flat: 1.641877\n",
"flat-->pool: 0.536123\n",
"pool-->relu2: 0.000000\n",
"relu2-->conv2: 0.000000\n",
"conv2-->relu1: 0.000000\n",
"relu1-->conv1: 0.000000\n",
"\n",
" loss = 2.301430, params\n",
"conv1/weight 104.70797729492188\n",
"conv1/bias 405.2940368652344\n",
"conv2/weight 21.199617385864258\n",
"conv2/bias 3.545147657394409\n",
"dense/weight 57.95866775512695\n",
"dense/bias 0.02116413414478302\n",
"\n",
"\n",
"Epoch 9\n",
"-->conv1: 525.609802\n",
"conv1-->relu1: 0.000000\n",
"relu1-->conv2: 3.545132\n",
"conv2-->relu2: 0.000000\n",
"relu2-->pool: 0.000000\n",
"pool-->flat: 0.000000\n",
"flat-->dense: 0.021164\n",
"-->dense: 2.182356\n",
"dense-->flat: 2.182356\n",
"flat-->pool: 0.712606\n",
"pool-->relu2: 0.000000\n",
"relu2-->conv2: 0.000000\n",
"conv2-->relu1: 0.000000\n",
"relu1-->conv1: 0.000000\n",
"\n",
" loss = 2.316046, params\n",
"conv1/weight 113.30311584472656\n",
"conv1/bias 446.00213623046875\n",
"conv2/weight 22.49642562866211\n",
"conv2/bias 3.9963467121124268\n",
"dense/weight 63.25651931762695\n",
"dense/bias 0.02255747839808464\n"
]
}
],
"source": [
"np.random.shuffle(idx)\n",
"# Verbose mode makes net.train print per-layer norms (see the output above).\n",
"ffnet.verbose = True\n",
"# Re-initialise with a deliberately large weight scale (10 instead of 0.1) to observe\n",
"# how a poor initialisation affects the norms printed below.\n",
"for pname, pval in zip(net.param_names(), net.param_values()):\n",
"    if len(pval.shape) > 1:\n",
"        pval.gaussian(0, 10)\n",
"    else:\n",
"        pval.set_value(0)\n",
"    print(pname, pval.shape, pval.l1())\n",
"# Train on 10 mini-batches (not epochs), printing each parameter's l1() after every update.\n",
"for b in range(10):\n",
"    print(\"\\n\\nBatch %d\" % b)\n",
"    x = train_x[idx[b * batch_size: (b + 1) * batch_size]]\n",
"    y = train_y[idx[b * batch_size: (b + 1) * batch_size]]\n",
"    tx.copy_from_numpy(x)\n",
"    ty.copy_from_numpy(y)\n",
"    grads, (l, a) = net.train(tx, ty)\n",
"    print('\\n loss = %f, params' % l)\n",
"    for (s, p, g) in zip(net.param_names(), net.param_values(), grads):\n",
"        # Pass epoch 0 explicitly instead of relying on `epoch` leaked from the training loop.\n",
"        opt.apply_with_lr(0, 0.01, g, p, str(s), b)\n",
"        print(s, p.l1())"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def vis_square(data):\n",
"    \"\"\"Take an array of shape (n, height, width) or (n, height, width, 3)\n",
"    and visualize each (height, width) thing in a grid of size approx. sqrt(n) by sqrt(n).\n",
"\n",
"    Values are min-max normalized to [0, 1] over the whole array before display; a\n",
"    constant array (e.g. an all-zero ReLU output) is shown as all zeros instead of NaN.\n",
"    Unused grid cells and the 1-pixel gaps between tiles are filled with ones (white).\n",
"    \"\"\"\n",
"    # normalize data for display; guard against a zero range, which would give 0/0 = NaN\n",
"    value_range = data.max() - data.min()\n",
"    if value_range > 0:\n",
"        data = (data - data.min()) / value_range\n",
"    else:\n",
"        data = np.zeros(data.shape)\n",
"\n",
"    # force the number of filters to be square\n",
"    n = int(np.ceil(np.sqrt(data.shape[0])))\n",
"    padding = (((0, n ** 2 - data.shape[0]),\n",
"                (0, 1), (0, 1))                 # add some space between filters\n",
"               + ((0, 0),) * (data.ndim - 3))  # don't pad the last dimension (if there is one)\n",
"    data = np.pad(data, padding, mode='constant', constant_values=1)  # pad with ones (white)\n",
"\n",
"    # tile the filters into an image\n",
"    data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))\n",
"    data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])\n",
"\n",
"    plt.imshow(data)\n",
"    plt.axis('off')"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"NOTE: If your model was saved using pickle, then set use_pickle=True for loading it\n"
]
}
],
"source": [
"# Load the parameters stored in 'checkpoint', then push one shuffled mini-batch\n",
"# through the net, keeping the outputs of the relu1 and relu2 layers in r.\n",
"np.random.shuffle(idx)\n",
"ffnet.verbose = False\n",
"net.load('checkpoint')\n",
"\n",
"b = 1  # index of the mini-batch to inspect\n",
"x = train_x[idx[b * batch_size:(b + 1) * batch_size]]\n",
"tx.copy_from_numpy(x)\n",
"\n",
"r = net.forward(False, tx, ['relu1', 'relu2'])"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAGwdJREFUeJztnWmQHVd1x//9+s2iWTUzkmaTrNEyY1mWsGxLsmWMWYwt\nkJMQCGBMOSGh4gAJFBRZigpVBAJZWSoVioQlwQbsShnKEOLYsRxcIGNsIdmSbGNZ0mCto9FoZrTM\nvr3XnQ+nu2+PXs+MHL++3TP3//vy7tzbT33U3eed0/fec47lui4IIWaRSVoAQoh+qPiEGAgVnxAD\noeITYiBUfEIMhIpPiIFQ8QkxECo+IQZCxSfEQLI6T3Zb5j2p2Ca4s/sAAGB7y6aEJRF8eYD0yUR5\noknzPcs0dVpzHUuLT4iBUPEJMRAqPiEGQsUnxECo+IQYCBWfEAPRupw338m2XRG0c8dPJiiJYujO\nGwEAW//i2aAv59gAgMc6rwIArP3iVDDm7n8pVnnspUuD9uiWNu+k8lHeM6pNjjTj3CzLfyXnQ9fj\neJeMjY5GfqfY0OITYiBUfEIMJNWufra1JWjnTncDAEZ+5wYAQNXR4WBMl9s4vka5sVnP1bcb6oM+\nZ3WryLP3RS3yAMDAavntfvJbW4K+pV9/BgBQ9bEKAMDZm9TxjS/KLXdzuVjkcS5cCNplj/ZNG7Ob\nm9Rxm9bL54GDscgRhbV5AwCgd0t10Lfs2/sAACM71O67yv/eDwBwpyZjkaPrLXJf2v5zPOgbvVVk\n67lRXtPaPv1MLOf2ocUnxEBSbfFz3WcK+iof+iUAwF7dpo7TJE92qNAC5M+dD9qTm1cDAEo1yQMA\nox0TAICm+wvHxhplVm3xYdXnOvGGS1hZ9UjZjcsAKG8td6YnGMsuKgcAOLFKM53Tb64BACz/6r6g\nz5mQ6zdZrWxgTat4JnFN4E7Wyf/aeeFQ0Ff+gnxO3XZDLOe8FFp8QgyEik+IgaTa1Z+69bqgXfKT\n56aN5Y4e1ywN4CxSlyvqF7O8Z0SO0yQPAKz8gUiSfWJv0Oe7287qMQBA7SMh997JxyqPM64mrKyR\nkemDGVu1p3S9oCnGlsl1CMvoM7xcRbLWdXXHK0jU25Z/beYMqC0OtPiEGEiqLX75/mNBO295P4UJ\nlvzK7Npf0JddsTxoO50ndIoDACh7dG9B38U7NwMA7KxYXOsZfcuLYfIXB0QOb5LPHQotwQ4M6hEi\n5GVc+RV5nqJ8jYkG5afFtdTps+4fC+VwbtoIACg/a0d8o/jQ4hNiIKm2+OGlMh+rrAwA4HrLMEkz\ntaIhaFunuhKURHH29WK9Gn9cmbAkQr5XNvJYpWqh0xnXdP9Ccxrh5UQff0NY4259nmSu52xB34k7\nFgEA1nzheQDxzxPR4hNiIFR8Qgwk1a5+FGlx8f0ls+yZi0Gf/gUqRWbDuqBd3TIEAKj7kUw26lxe\nnIY3sZbxdulZFRXBUL6vL/Iruum+RSaNO/5UhTXrcvr9SU8AyFV4S42XLoHGBC0+IQYybyx+ptzb\n2x2x+SIJnBskmsr9xYE5jtTDiXeqKEH7Kfl0Rl9OSBpPjpoqAGpZL5PgUmyYbFNj0K7tFNsX9xJe\nFKffvzZot//H8CxHFh9afEIMhIpPiIHMG1c/LS6+T/a8TMLEu/N9buzFtQCA7PUqAcaKj4vbmORk\nIwC4eZlWzHoJOKLW0ZOg8+Org/aaz0mIrs4JUH8vykhr6Ky7X9AoAS0+IUaSbotvhUKVUjAxZK/v\nUH90pcN6nf8NyaRbXnIu6HMvDCQlTuCBAAA8y5aGexfGDT1WSXiSY7dfAwAoHUjO7tLiE2Igqbb4\n9jq13DGwUfbEV/9I3sniSoQ4G+4plQrMsjQFTs9B7f27pRFKvZXkvIO/dJdGBu6WGgStu5Kd/VjU\nLfNDy5yKOY6MD1p8QgyEik+IgViuxokXp6c9
XbM8hCxAMk2dc76H0uITYiBaJ/e2t16r83QzsvO0\npNDa3rJpjiP1sLNb7fdPm0w7rn5zwpIIj770UwDpe4aA9N2zy4EWnxADoeITYiBUfEIMhIpPiIFQ\n8QkxECo+IQZCxSfEQFIdpPOqwzljDpwJl8vKpaR4hl0jNd+rHi0J+m6qewUA8PCZ1wEAhu5vDcbq\n7nsmVnmsOhWWe+yuFgDA5GK5j5kpdVzTHgklqtrVGas8mUWLgnb+Wgmrzl4cU30vHS78UgIBWPYS\nCUJz2poBAJnBkIxHXin6+WjxCTGQdFv8EKPvklJHgyskV3vrY73B2PnNSwAAix8M5UbPFz841Ulj\nyGmpWPpPtjwedH3sCx8FAAyukb8P/+2/BmM7fngLACA/GE/RSuesype/6tuS5CIq5dbYO7YCAKzy\nsqDPjaGs1vgb1gft5s+I5cy5yt5tWSw1CO59cHvQt+LzT0sjbg8ylO03KKvVLwlVej68LRhbSotP\nCCkG88biX1wjln75VyURRz6UMmnyzUsBAPbSJUFfVGHC18zU1KzDVokUhdSZJGTwTZKspDbzcNBX\nf98eAMDSNSul4w/U8Vb9Yu+L8Vj8TCj1lrNE2plBsarh3PUVXZKMwp3jmr5W+jeqQp3Hn28HAKz7\npEps+f333w4AyN2u8trbdXUAgPxFVSUpDnp3qKSf9d+e/rwuOqcScWa8CkTO6GjRzk2LT4iBUPEJ\nMZB54+rnvTmgqKyok4tlEibf1x+rDJEZWUMTQFaJXE6drn71r8WN3vHYJ4K+Dkdc/d43yeTRY6Nq\nAi13/GSs8uROd6s/vLbdIOW9rDpVJBKj3jWaDLn6tl10eYbXqteL7ID8++H7ONLiPTtnVP47Z2io\n6HJEUX1y5udkqlLZ5GK6+D60+IQYSKotvl9xBABW3XscQHR1GD9PehxLeHORXbUyaOeOHtd+fufA\nQQDAuj+vDvosb4JtyV1i3T/y87uDsQ48p1E6T55a2WQ02aQm/uy9UtAzUxFvptn1nyv0cPKh52ps\ntVjduj1qA5Su56j0579S57xkrGww3to+tPiEGAgVnxADSbWr706qyY9c95lpY/1/pHY2NRxMrkCC\nOzSS2LnDhCek3G1Sounf1n4NAPD7f/+JyO/Egb8GHkXpK2oHX27C26UXs6sftZ9j+N1bg/a1a48C\nAAbuvyJWOaJwJwp3KlrXXw0AqDqidonG4fTT4hNiIKm2+NO4JFJvpFUtoy29V7KL6kzaby+V3YL5\nvr45jtRDeCK0865yAMA9ne8DAJQ8/mzkd2KXqdLbcVYl8qBbWXw/qlAr3jN05g2qq3+X7Hxc9eRe\ndZhWoaYztLoKAFD1g1/Geh5afEIMhIpPiIHMH1ffw71JJq5qjimHzM3FG+gRhVVaMvdBGvETOQDA\nv+y4DwDwmc9/EABQh3QkDQnvmLNDCTJ0YW3ZCABo6VCvZ9Y35JUtiT0gYTLl8jpU2V380OTI82k5\nCyEkVaTb4kek3jp1WyUAoO2LqlxQvHucFOEJtPzZ3lmO1IgXK3DqfW1B175Radd9b49+ebLqkXLq\nZTehdVKWYjOVlfrlCT1Dp94q8oweU97Guv95HoC+Z2gmBt4hZbiqH9yt5Xy0+IQYSLotfgRZb7+M\nMzY2+4ExEE4ykRaL7+91v/l9+4K+7x6SDSornRe1yxOe+5iqkffW0movjmBY/2an7HKVaNTaIpti\nWh+oCvqciE00SeBo1kRafEIMhIpPiIGk2tXPVKtQ0/53bwAANLysf+nOJy3ufRhnRNznE+9SOf9X\nV4v7msQCVTgRR8Zr+5EUs+3jj4t84+KgPdInr0XLhpJduouifo88W7oko8UnxEAs99VWq3kNOD3t\nSW6DJsQIMk2dcxYEoMUnxECo+IQYiNbJvTu23qHzdDPyyJ5HAADbWzYlLImws1vtQkybTGmTZ8ct\n70xYEuHR
J38UtN+26oYEJVE8duzyQ3lp8QkxECo+IQZCxSfEQKj4hBgIFZ8QA6HiE2Igqd6r75y/\noNp+4cCMFD7MbOxQB3rJKDL9Khe5KdgdawAAL38qtA/eyypR+4LUhm/53svBUP6CuqZxMPYOlbO+\n63a5L9kBsS+LetWGstYHOkWemLMUWxMqtiPfJNcoMxIKxfWeHaeiNOjKXBiOVaaMl2F44nUql/9U\ntTzXuXK5VhU9qqZE9pmXii9D0f9FQkjqSbfFjyoP7Ej8kvP8ywVDmRUqQi0qbddCZGKFRJ/Z59Wt\nXPugROyd3C5W7JU/WxeMtX36mVjl6blBlbpu/5Pp5wpH58XteQTn6VUehXtKko7OGQG3dlV8AgFw\nvapH2SdUAdNLFTGzaX3Q9lO+FTNpCC0+IQZCxSfEQFLt6keRbW4CAOTO9BSMuePKFbLKSgvGFyJ9\n14gbWBGqKerulVx7DStkD/nZrer33SqR6+JOqcmjYuK2zZwL8dDn1YRs+0fjLRHlY5Wq58AvUpld\ntTLoyx07UfidYXnFdKviKejpu+zhMmJ+gdig9kCnkiuO0HlafEIMZN5ZfNj2jENpq26jg5oTMlVV\nOlBYKvxCu1yrylOqLy5L79P6wMz3oGSp/szI4fLhPq6t7J2fpXjaRHJFzFV+PAueHxxUcniVdPzl\nRWRCNnmq+OnmaPEJMRAqPiEGkmpXP1wMwfUmPdzZCmlYc6YaW3BUPlQ4Sebv5htdIa8BHfeqnWhx\n724oe/TZgr7h994ojblTwRWdbFOj+sN/TRxTk8D5iL0ibkaPnOGSbFatN9HnZXK2Qq8jziQn9wgh\nRSDVFj+M5U1+5LpOJyxJ+rlwvZR+tke9GIZONbsXe972iKWnM3fIhGL7B/YVjMVN5JLcxcHCPo3e\nYtjSz4gTr29Gi0+Igcwbi4/szMt42dYWjYKkE39ZCgD6rxXr1fKk2PfwspFO7KvaAQBWJh1xE1ZO\nrkeu/1zBWHhTjy5pw/fs0ipN/oae2M4d679OCEklVHxCDGTeuPq54ydnHjRwGe9SJrddFbSdEnFW\nqw7IRGjhnj49HLtTJhk7/lhiB5yE5PBxBwp38QVjmpbwpmEX2l1//34xQ3CjoMUnxEDSbfFnseR+\nlB4Rut+glojqXxSLn8TSZ3jCqnSTJNuI2i+vCyu0LOZGWFh/Uk/n9KOVFbVzx8YLxuKe1POhxSfE\nQKj4hBhIql39WV3VWcJzTcK++koAwORaFcNQtzO5abSB33xd0J7clwK7MqquSz5i/T7K/Y8bPzlI\nOO+gv+fBOTrLJHYRScGdIYToJtUWP7wjj3v0o8m/dBgAsPbuhAXxWPyzo0G7+sHeWY7UgzNSGH0X\nJL1IiKgMw+7Jbq0y0OITYiBUfEIMxIojg+dMOD3t6YjWIGQBk2maO+MJLT4hBqJ1cm97yyadp5uR\nnd0HAAA71t2SsCTCo4eeDNppu0aUJxpfHiCdMs0FLT4hBkLFJ8RAqPiEGAgVnxADoeITYiBUfEIM\nJNV79e31qqwyvibVYN64pBMA8OKQqrLT9TcS2VT2yN5Y5bncbLV+pBUA4PTZmKTxznXl2oI+a1KK\nLEaVgNaJX8XG9SLkksj2m6muDtqWXwwzlAAjCZmyK1cAACZWLw367BG5Z/YRic7LXxyIVQZafEIM\nhIpPiIGk2tW3BkeC9uFur/jhh8Rd67l1WTD24S/9GADww8eXB31x14HPrm6T84Sys+Z/fUw+q1XY\np13q1YufLH6NcwA4/Xa5DlWnVXGsqQr5PR++W16HVnz+6VjOHUV2hboHvbeJS2t5ok3WqGvV+Esv\nD9+eF2OVZ/z164L2ibdL8ha3SuUdLquWbLbN/65yFpbuLCz8WUxGrpZ8kSe3q2QyliPJOWrXyFjF\nd5TcUYVRXyu0+IQYSKot/ugGlYhja9uvAQDnjkgSg+Yp9at98A/lOLtRTZ
bEnbhjeL1Y2spfdBaM\n5StLgrYdpAiLx+JPVcpn1Q8KrcLgX2+L5ZyzMd6uylLXnBCvK/vEcwCAzKb1wdjoiioAwKKsegTd\nXPErAAw3q39/yX75rL9f7Wm/cNcWAMDFtcobadolHpszXpgFtxgMrJLno/aw6lv2dbl/F353q3x2\nKHmqy4svDy0+IQZCxSfEQFLt6pc+ptbl9227CQCwEjJRde6m5mBs/ILk98h26clQCgDZMZmxinJP\nJ2vUZS2NuRRSrnLm3Cb2mP6yUJO16v9ednH6681Yc2XQLr0grwFxuPdhRpvUNajqcgvOmfPqf5QN\nqMzEcZevGl8in5WnQ/fOkedprFHkXdSnxuJ45aDFJ8RAUm3xw6z6u30AAPfaqwEA/ZvUL/nFp2XZ\naiXitfiZSmWxyl7wdlhFlYcKG1on3mxj7V+VJcSw3YzazaeLqsd/FbSdS0pE5SqUnak80CN9McvT\ndu8rQdv1LKcb2s033iA3q/pkSJKY09Gt+ueX5TQhzyLjLYOOLRPPo3FvvMvRtPiEGMi8sfj+e06u\nXpY2qjtUbvKGf6qI/E6xscLVeyLeAy3PG6n6VZ8WeQAgd6anoK/rDllqbHl6rGAsbpyRkYI+P+Zi\nslrZGXc83vdon1xPYazE2G9vDdqTi8W6V5xUe/bzBd8oLlF59Qfv2Djt77Kj6hmKwyuixSfEQKj4\nhBjIvHH1s8tlAq9ng+ypHjqpRF/20+LvZZ6LqHDO4TWyG63mJ126xZlWFspfosrs2q9djiiGrqwD\nADQ8q1zcKHdXFxfXqGen4XmZTMsfPKJdjvA9698kNnjpPnn1yJ2M9xmixSfEQOaNxXeW1AIAhq6T\nSb5lj5cmKU5AOBqt+pCXPCHmJbwozr/n2qBdezS5Mtlh7IZ6AMBkldiXzLmLwVgiEm6VCbTxBnV/\nWnZJghedFaV8prap2IV8mZy/br+U8s7HLA8tPiEGQsUnxEBS7epbJcqd77teXH0rI65+7QO7tcsT\ntYd7cLPK/Vf9M/0TRLBk59lIi9ou2PoPz+iXI4KxzasBADXHvT0YEXsOdHJ+gzf5eizUeeCQfkEy\nsh+k7xo1uVfzirj2+cNH9Yig5SyEkFSRaotvL1cReOffKNbW6iub6fDYcSMsvj0RmqbKx73nqxDn\n9dfIqZO7LNPJqN2No42ScKL++7J/P4kJPX8ZGAAG18hn28Nqd2Hc0YFR2FeKJzS4XkUvXvVlWd7M\nO3qeIVp8Qgwk1RY/nBe+/fekbXfIz7Z+2zodX47Kg2ovuP4FIQTv+M27443mulyyzSr1Vvl5uUtW\nqTdXE1Mqq9nIN9WpPzyXIzM2dWmXVqxhiaGoeVmlirNG9V4bWnxCDISKT4iBWDp3LDk97Yl4w4SY\nRKapc86ca7T4hBiI1sm97S2bdJ5uRnZ2S171tMkDpE+mtMnzpnvuSVgS4Wff+lbQTts1uhxo8Qkx\nECo+IQZCxSfEQKj4hBgIFZ8QA6HiE2IgVHxCDCTVQTrk8un96E1B2/V+zqtOS5BM7e5TwVjudLc2\nmbJNErAztaoJAGDlVUiMfUKCm/Jne2OVYahVPeLnr5PrYdeogKblSyQPYO+ulqBv2YHpxT4XIrT4\nhBgILf4CYVLVgUTbd49L3yoppXXwsyoZRcc9+ix+7w5JOFHZI8kuyvpV6OmJD0phzxVfUpl33ani\nhxYv3avqHyz55kszHtd68+KgPbFEwogzUws3tIQWnxADocVfIIyuVCmk/Pf4jPdp3XNdMGaVSY6u\nqDRixcZPAFp/n5Q4RyitlPUWmZOwG1UyilzX6aLLYPeqij3+FZrcvjnoW7THK6P9lNrnnr1dxp2S\nOYPc5i20+IQYCBWfEAOhqz/P8WsP1O+3Cwf9jLdDJUGXDhffJzvmNSIyx+a9NHxxuPdhopYvndKQ\nvbMLbZ9jL1wX34cWnxADocWf5/hLYE
u+UVg9Z/DOLQCApqe0ihTQ+s0XAahMtnZNTTBma3I87Kuv\nDNq5WqlcM7BKPfblD58r/JIB5tCA/yIh5FKo+IQYCF39Bcy5jTJJ1f41VZhEZ8EoZ2ho2t9d92wI\n2o3P6vH1R9tqCvqWPTtS0Ge3r9YhTmqgxSfEQGjxFyDuNimkWTogFl9nRN5sTIaMb/aJ57Sf34/U\nW/LI3oKx0bUNusVJFFp8QgyEik+IgdDVX4CcubkSAHDFw/0Akq8sPPzeGwEAdYeSqE2rqOouvBLZ\n5RKyPGGYCTTsv0sIAWjxFwx2naoDP9osljV/8EhS4kzj3AaZZFz52T3azz1VoWxb7YE+ANM9oNGr\nmzVLlA5o8QkxEFr8BUL/b60L2s2/SPZd2ie7ug0AYOW9aLeIKL3YsVSkXb7zKADAvnJt0OdmF34k\nXhS0+IQYCBWfEAOhq79AqPuOCsu1F9cCSH4ZD468crQ9dF7+TECE6h/vD9p+zlxreDR0RB1MhBaf\nEAOxXFdf7nCnp33hJionJCVkmjrnnLGkxSfEQKj4hBiI1sm9t12xee6DNPDYyWcBANtbNiUsibCz\n+8DcBxFSRGjxCTEQKj4hBkLFJ8RAqPiEGAgVnxADoeITYiCp3qvv5l5dFngrq++/c+ED2wAAQyvV\nJqnMlHwu/+mwOnD3C9pkIuRyocUnxEBSbfGzzU1B250Sc3r8Q1IEsfqU2va/+LsSmRb2EOK2/n23\niDzN/6vOk5kSmZZ+5WTQd+6tFQAAZzQcEUZIstDiE2IgVHxCDCTVrn6u56z6wwsfXvnlfQCATIt6\nDfAdfKukNHR8vGkfOtp65JwP9apTerXqz3zkiqCvbHmJNI68Eqs8hLwaaPEJMZBUW/yJHSqar7xb\nJsec5w/J59HjSYgUcORwCwCgY6qrYCzvqN9TP7MrIWmCFp8QA6HiE2IgqXb1K/edCtrupEyc5SOK\nMmTbZDIt36WvDvxVf3VMzhnqy916PQDgxAn1e9rhHtcmEyGXCy0+IQaSaouf7+sv7MzYAAC7qlL1\nacwU7JPv6yvo6/+YTEAue2ixbnEIeVXQ4hNiIKm2+GEujdSz6mqDdr777KWHa8NuqA/aE5NyOZvu\n352UOIRcFrT4hBgIFZ8QA5k3rr5VVgYAcKc8lz8T+s2KeV/+bBz5y46gXb7PkymByUZCXg20+IQY\nSKot/rQJPa/tT6Y5Z5Kb0AOUB1LTfiHoW/Yp2dRDe0/SDi0+IQZCxSfEQFLt6vu79AAA3h59Z2Aw\nIWGmM/CuawEAYxMjQZ+fiIOQtEOLT4iBpNriZzaqpbJTb68DAFzxX7J/3zmSbIKL+qckAceivsZE\n5SDk/wMtPiEGQsUnxEAsV+MuM6ennUvchMRMpqnTmvMYHYIQQtKFVotPCEkHtPiEGAgVnxADoeIT\nYiBUfEIMhIpPiIFQ8QkxECo+IQZCxSfEQKj4hBgIFZ8QA6HiE2IgVHxCDISKT4iBUPEJMRAqPiEG\nQsUnxECo+IQYCBWfEAOh4hNiIFR8QgyEik+IgVDxCTEQKj4hBvJ//HxftMJ4RmkAAAAASUVORK5C\nYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7fddc4be1e48>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# relu1 output for the first sample of the mini-batch, tiled as a grid\n",
"r1 = tensor.to_numpy(r['relu1'])[0]\n",
"vis_square(r1)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADeFJREFUeJzt3W1sHMUdx/Hdu3NsEuzEECcXOyQOid0kQHDqJCQUqVVF\nSUqlFgm16oPKQxEtUFQKUktLS1ErIVJRlZKqQFupwBue1Qf1IUlLoa0gzyEJMSTkTEMebOw44JAn\nbHy325fV7H/RXU67s3v+fz/vZjS3M7nzL6MZ7c66vu87AHTJJD0AAPYRfEAhgg8oRPABhQg+oBDB\nBxQi+IBCBB9QiOADCuVsdvap7BfEbYLuhAnW+l+3f7NRXtnaZa3vMOv7d4q6JMeUtvE4jhxT2sbj\nOOn7zTL5glvuc8z4gEIEH1CI4AMKEXxAIaube07II8D+6OgZXybTtVDUuX1DRrk0NCTajFd9d14q\n6kYvPm2U517TI9r4xWJsY4Ip19ZqlPuubhdtpq/ZYGk0zPiASgQfUIjgAwrZXeNHxNv5eqL9v3/V\nMqN8do/cTyj17rc1HKftp+XXhoN/mi/qpn1ubxzDqdj+1SuM8pzvbYzkutnmZlF36MYFRnlKb0m0\nmfj7zaIuKr03zTbK7XfbW8+HYcYHFCL4gEIEH1CI4AMKpW5zL9PYKOq8EyfKf3D5IrO86dWIRiSd\naDO/tsEledGm/Yf2NvcqYXsjb/hac+Ou+XG5cTc2/YNY+u67boGoO3f3mFGuX7s1lr4/zLxHDhjl\npG+dYsYHFCL4gEIEH1AodWv8itbzIYqT6oxynP+wab8yb74oPNYdY2+1KbimP/n5S0SbzuvjuWEm\n/0CyN8eEKfb1G+V9jywTbTpv2mJrOMz4gEYEH1CI4AMKEXxAodRt7lUr98/tifXd+udx8zXG5q77\nHhd1a56VTwxq0bQn2b8ZZnxAIYIPKETwAYVqcnHqr7hY1Lkbd1nrv7DGvBml41vxndwyXtz7g+tE\nXaOzKZa+cu2zRN3JC80HqRr+Yu9mGceRJ0PnH+QEHgCWEXxAIYIPKETwAYVcP+S1VnHxBjrsdQYo\nlckX3LJtbAwEQLoQfEAhgg8oZPUGnpWtXTa7E9b37zTKaRuP4yQ7prSNx3FCfrOZ8rQjN1N2SRuZ\ndQe3ibqVbYvNCov7ZmG/WSWY8QGFCD6gEMEHFCL4gELJP52XyZplT763PJLrVsmtmyDq/LHyr36q\n9nPjVXZBh6wcetcolo6+U/Y6YRt5ftF8IZWbk3/W++9ZapTzm+XfWdVP7FWxmXf06ytE3ekZ5r9t\n9mp5qpQ/OnrGfYVhxgcUIviAQgQfUCj5NX7KHblB3jDS8oh85XOQu3CuqPN37YlkTNl5c4xyqTdd\nr+QOU9pTEHVLd5rr7K1d0ezLZNrPE3Xtd5u/WXbquaJNRLtLFZn6m/J/Q3HeBsSMDyhE8AGFCD6g\nEMEHFErd5p7bfYGoG8lPNMr1f90q2mQWma9j8nrkZlI18i8OibrgJtDRb8ibMab+uvzmTdWOHS/b\nJJefbpSLA4NxjaZqf3hzkVGe6bwWyXUr2ex0Gxpk3ZILRZ2/rSeSMaUNMz6gEMEHFCL4gELpW+Pv\nOyjq6refKPu599sazc9EtDQrvdEr6oZuNtf0LQ/HuJ4PEXyYZexyeZOR83xyrw0PE/bas/Zvm/+O\nomgRndErzYd0nL/JfaLRRa2irj6uAVXCDTlZKKLTfZjxAYUIPqAQwQcUIviAQqnb3PNOlN/Iy82W\nT1+5/9hhlON8ssn2Zl45mWL630xWuLFO1HV+7bC1/s96YbdR9kLa1Ids+CUqxmO6mfEBhQg+oBDB\nBxQi+IBCqdvcq8TYTHlsknuoP5a+sgs7RZ1/oM8oe6dOxdJ3pbL/eiXR/isx7UW5uWeTNzJilEc/\nvVS0qV+bss29GDHjAwoRfEAhgg8o5PoW3+Xt
DXSk/04ToMZl8oWQx/oCbWwMBEC6EHxAIYIPKETw\nAYWs3sCzsrXLZnfC+v6dRjlt43GcZMcUNp5Vs5YkMJL/W3dwm1FO428mvqNsNO8ArMS6/Zur+hwz\nPqAQwQcUIviAQjX5kI522aYmo1w6Xv6VWtXyixUcep2Ra9pD37/EKE8clPdutTxtnoFeyelLaSS+\no5DvzK03D+o+fLs8Ev2sI+Z3dM7v4jvpiRkfUIjgAwoRfEAhgg8oVJObexdul/9f9XSHHZicnMKD\ny0Vdx22bIrn28GcWGuWmJ0OuG9xw80qR9B0q5Nrn3buh/MfiGEuF5m5tEHX7b2gXdd6reyPpz51/\nvlFuW13++4kTMz6gEMEHFCL4gEI1scbf99vAiajd6T8N9dnPrhF1d922LJJrT9l9zCiHrZUHv2ne\nQDP9l8muKdPmzaUjoi7bPBhbf96uPbFduxrM+IBCBB9QiOADChF8QKGa2NzrvDH9m3lBX3r6NlE3\nx4nmaSuvx7ypZOD2S0Wb/ANs5p2p0vCwtb68jy8WdZl/77DWPzM+oBDBBxQi+IBCNbHGT7t9D8sb\nczpvju/0lKBTbel6QAnl2VzPh/afaO8AEkHwAYUIPqAQwQcUqonNvRNfNE+zaXwqmpNsovKJxfLJ\nq/4Y+8tOmWyUOx47Jtqw3WcKHknu+/K47ziP987lpxvl4kB8TwJWghkfUIjgAwoRfEAhN2ytExdv\noMNeZ4BSmXzBLdvGxkAApAvBBxQi+IBCBB9QyOoNPCtbu2x2J6zv32mU0zYex0l2TGkbj+PUxm+2\nao55lLmbsxertYWXq/ocMz6gEMEHFCL4gEIEH1CoJp/OO90i/79qfabXKJcGj8Q2nuFrVxjlKb3v\nizbuy3ITKC5xHq+dbW4Wdd7JU0bZH/ugqmvnZuSNcunou1VdR8hkZZ1XiubaIfwPzH+/PzoaW19R\nYcYHFCL4gEIEH1CoJtb4wRN3GkPaxLeCkyZ+5W2j7F7xlsXepTS+Liu47xA2xuLbA0bZrZsQSd+H\n77xE1M28L33fUZKY8QGFCD6gEMEHFCL4gEI1sbkXlF3YKepKr++z1n99YDPvoQMviTa3zL7M0mgc\nx7/0YlHnbthllINHcjuO45SOvRdJ/8VPdou6JDccp/Ta3OqtjNt9gVF+r1NuUTc9ae/YeGZ8QCGC\nDyhE8AGFanKNH7aeD74iqXT8uK3hONfuuUbUTXL+a63/4HrecRwn02iuIaNaz4fJvbA9tmtX4+xn\nN4u6gz8ybyia9RO7exCZt8ybvpq2v2a1/yBmfEAhgg8oRPABhQg+oFBNbu6FsbmZF7xhZdKqdG1u\nOY7dI57DTrzJzQi8D76vX7SxOcYZm6o7JSgqpXciOl0oIsz4gEIEH1CI4AMKjZs1vk1XPPgfo/zC\nRZMSGsmHKw0P2+ss7ARbzyv/OTeeeSds76Du79ti6atabn29qLN5Oi8zPqAQwQcUIviAQgQfUMj1\nfd9aZ95Ah73OAKUy+YJbto2NgQBIF4IPKETwAYWs3sCzsrXLZnfC+n7z1dVpG4/jJDumsPGsmrVE\n1IXdfBKXtYWXjTK/mSlsPJVgxgcUIviAQgQfUIjgAwrxdJ4S1T4N5heLIRcz5wu3Tv4ZDT/XapZ3\ntIg2c3+21yh7J0+VHU9Ush+ZJ+pKb/TG16Fr3lNz6uplokndSfOJxgnrtsY2HGZ8QCGCDyhE8AGF\nxs0av/fny43yvDvie+XwkVvN1zG1vCLXprk9B0Wd1VNxAuI83cU7fVrUTb7SXC9PduT6+dB3zO+x\n7RfxnZJz4MdmX7Pvka/QyrbIfYjS0FA0Awg8DDfpOfmar+Lzs4xy5vB80cbr2SvqqsGMDyhE8AGF\nCD6gEMEHFBo3m3sTzovn5o9jX10h6mY89YZRduvqRJtix0x5sS3RbO65Sy8yyv7W3aKNd5n5xFjm\npeqe4opT
6/2BDba6CbH1FdzMy+WnizbFgcHY+q9E7nJzQ9jvviC2vpjxAYUIPqAQwQcUGjdr/Dk3\nHDDKIS91qkrzM6+IulLgZphsx/miTe7oCVEX8rhLVdzdBaMcdnSxzTV9dspkUVc69p61/qux97tz\nRN28O5Jd4wf5O6K5WScMMz6gEMEHFCL4gEIEH1CoJjf39j0kTy/pvGVLLH2FPdUWvIHGOSg3hYqD\nR2IZj+M4jjcyYpSzU88VbUpH34mt/yB/pLon/8LeY2/L/PvfEnVRbb5Gxotqi1pixgcUIviAQgQf\nUIjgAwrV5ObetI3ZRPvP3X/UKI/e3SbaZGLc3AvyZudlpcXNveBmY8Vce/NObqb5GxUP91nru1KZ\nhgajXPX3WklfsV0ZQGoRfEAhgg8o5Pp+2LNd8fAGOux1BiiVyRfcsm1sDARAuhB8QCGCDyhE8AGF\nrN7As7K1q3yjGK3vN4+jStt4AFuY8QGFCD6gEMEHFKrJh3TcxfLVQv6O1xIYCVCbmPEBhQg+oBDB\nBxQi+IBCNbG5N/DHBUY5f1X6N/LevX6FqDvn0Y0JjASQmPEBhQg+oBDBBxSqiTV+aVNz0kM4Yx+7\ndauo2/NoAgMBQjDjAwoRfEAhgg8oRPABhWpic69t9QajHHZyTdKn6QTtuOejoq7B2ZLASACJGR9Q\niOADChF8QKGaWOMHpW09f/zLy0Vd0xObEhgJUBlmfEAhgg8oRPABhQg+oFBNbO5lFs03yt6rexMa\nSbjBK8ZEXdMTCQwEqBAzPqAQwQcUIviAQq7v+9Y68wY67HUGKJXJF9yybWwMBEC6EHxAIYIPKETw\nAYWsbu4BSAdmfEAhgg8oRPABhQg+oBDBBxQi+IBCBB9QiOADChF8QCGCDyhE8AGFCD6gEMEHFCL4\ngEIEH1CI4AMKEXxAIYIPKETwAYUIPqAQwQcUIviAQgQfUOh/i3YTUOpYmq0AAAAASUVORK5CYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7fdde5696208>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# relu2 output for the first sample of the mini-batch, tiled as a grid\n",
"r2 = tensor.to_numpy(r['relu2'])[0]\n",
"vis_square(r2)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(32, 288)\n"
]
}
],
"source": [
"# Select the conv2 weights by name instead of by their position in param_values().\n",
"p = dict(zip(net.param_names(), net.param_values()))['conv2/weight']\n",
"print(p.shape)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACvFJREFUeJzt3Xts1lcdx/HztKUtpS2UDlool5TSMtYIExyDTTdAJUsW\n2KYuYRdiGAOBjGQYdSZTsznUxMSYiKOMLcpYlLjE6GTDibcoMgbCBlSuLWu5M2ih5dqu5Xn8Y/7t\n5xh/+W3J5/36+5PnnLbPp+efb87J5HK5AMBP3ke9AQAfDcoPmKL8gCnKD5ii/IApyg+YovyAKcoP\nmKL8gKmCNBf75BtPy3HCrpah8nNKj+v/WXu/sUZmxv3pMZkpLvlAZqoGX5aZEEL4a+NrMjO/bbbM\n7DhQJzP5l/JlpvWhtTIzbssimakbfU6vdWSEzLTdt05man+7RGYabzkhM8cuVshMCCE03/5Lmbl4\n45rMTP3LEzJTsa1IZnY/0yQzedUtGRkKnPyALcoPmKL8gCnKD5ii/IApyg+YovyAKcoPmEp1yGd0\nebfMXOqvTGEnH8p0FupQS7GMHKsaFLdgo45cvPOCzDT8uURmWvbXxOxIKjyuf0e9I/TXqGT41SS2\nE0JxVkZ6ntUDRbkVvUnsJoQQwoztX9GhywNk5NL4BDbzP+DkB0xRfsAU5QdMUX7AFOUHTFF+wBTl\nB0xRfsBUqkM+p9aPk5nSgfoSkhEPtiewmxAG1OjBk8z5MpmpqutIYjshhBC6N+tJjws7h8nMxLUn\n9WJf0pHvztc32ax64RGZ6auMeBNyho5k8vWQz4ml/fqDDpTrTAghTNeRgt36OxIG6Z+/9HjEfhLE\nyQ+YovyAKcoPmKL8gCnKD5ii/IApyg+YovyAqUwuFzF8kZDs2fr0FgNM8VwXgP+K8gOmKD9givID\npig/YIryA6YoP2CK8gOmUr3JZ9yWRTKTKdA3tRQc009oHV7YJDPjNy6VmRuD9a0wxYPjnn46eOcr\nMtPWd0VmZv/+qzKTV9onM0dn/1xm1nWPlJnXz02SmUNnhsvMkbs2yMzxfv37mfmbr8lM8fm4c+/A\n8jUyc/NLy2SmbEqnzPT258vMvmkbZSYWJz9givIDpig/YIryA6YoP2CK8gOmKD9givIDplId8ilq\nL5KZGzfrIY683qiLSqQFn/u7zPx6w0yZuTJJDybFWrhspcxU1uhhkO7xOhOj6af367Vu75GZgsKI\nJ7QiLFqwQmbGfuuMzLS36aGjWP0RT3F1Hh0qM0WdEWfxtJgdxeHkB0xRfsAU5QdMUX7AFOUHTFF+\nwBTlB0xRfsBUqkM+w/boQY+S1/XATNv9ybz6tWn13TJz5dP6lp78AckN+XQ8fk1met4rk5kRb0X8\njhboSNk8PTDTdXSYzBS36tuXwl06cnmUHhTr2q5vHyq9nMygWAghZId9IDNjNuqhq5Ozk9tTDE5+\nwBTlB0xRfsAU5QdMUX7AFOUHTFF+wBTlB0xRfsBUJpdLZlouRvZsfXqLAabyqluiRgU5+QFTlB8w\nRfkBU5QfMEX5AVOUHzBF+QFTlB8wleo1XnesXCozHQ/oa6zK/zhIZnY91yQztZsWy0xhh/4VlbXL\nSAghhF3P6j293XNDZha+rN+rK7iu99P85BqZ+cSPl+sPinClVl/h1nb/OpkZ/4tlMpMt1teqFXXG\nvWV4cIn+HU3e+ZDM1FZckJmWN+tkZv8KvZ9YnPyAKcoPmKL8gCnKD5ii/IApyg+YovyAKcoPmEp1\nyOf9GTqTvVIoM12zexLYTQgTXtKfc/TBUpnpG5TcG2srVj0hM71T9SDQZ6ftSWI7obhTX77UO69L\nZgZtr0hiO2HoLR0y0/3uTTLTM1a/wRjr2uEhMrN3qP4ehbF6ECpJnPyAKcoPmKL8gCnKD5ii/IAp\nyg+YovyAKcoPmEp1yKe0Xf+vuXqbHmAp31qs
F5ulI61P6h+/dIce4Cm/94xeLFLpQ6dl5sq2GpnZ\n/uIUvdgzO2Sk8tHjMnPk4CiZKdcXNEU5f65cZqqmnJeZjv3DkthOCCGEoov6O9JfqYelCjtTrSMn\nP+CK8gOmKD9givIDpig/YIryA6YoP2CK8gOmMrmcHj5ISvZsfXqLAabyqluirpbi5AdMUX7AFOUH\nTFF+wBTlB0xRfsAU5QdMUX7AVKpXhzzcpq/XmVD6vsxs+sndMrPruSaZGbdlkczkn9PPh3197msy\nE0IISwbrW3pq33xcZgpPD5CZ/F4953Fg2RqZmbhtgcxMHnlKZlou6Ce0dk99VWZizNj7RZk5/6/h\nUZ/V+oj+Hn3+4FyZOXpopMzMv2O7zHy/ap/MxOLkB0xRfsAU5QdMUX7AFOUHTFF+wBTlB0xRfsBU\nqkM+BzdOlJnu18t0ZmkSuwkh5hKjIbd0yswP35gXtd6Sh9fKzE1b9QBPxzT9pNmECXrwJkbN83o/\nOx9okJmqtyMWm6ojjauXy0wu4kirmNERsaE4ra3VMpN/XQ9dXeofmMR2onHyA6YoP2CK8gOmKD9g\nivIDpig/YIryA6YoP2Aq1SGf+vmHZWbXbWNlJnMxmVe/Ro64KDMd3aUy8/IX9I04H9L/azunZmWm\n6Hy+zFzfXKO384KOHLunSGZyxf0yM+ebb+nFIlxr6JWZgWU6091cGbfgFB0ZV6dvnzq9dZTMHOyu\n0otF/FljcfIDpig/YIryA6YoP2CK8gOmKD9givIDpig/YCqTi7nOJiHZs/XpLQaYyqtu0dcGBU5+\nwBblB0xRfsAU5QdMUX7AFOUHTFF+wBTlB0ylepPPrMcWy8y14XpLWX2RTfjn95pk5lPfWSYzl8br\ntbJjrutQCKF15nqZebR9pszsPKZvO5rb0CwzPxrxjszU/m6JzFTs0X+QvjI9d9K8Ut+IVPcr/VZb\nXp9eq6xNRkIIIbzzbf09avjbl2VmVt0RmTnwg0kys/X5iOuXInHyA6YoP2CK8gOmKD9givIDpig/\nYIryA6YoP2Aq1SGfkzP1cv1DbsjM5InHkthO1ABPw/R2mTnQPOb/38x/nFxVLzPlNfr3uLmgUWZi\nhnwGntJrXa3RFzT1Vei/a4xskX7ObHTjOZnpfbc6ie2EEEJ46tY/yMwrJ6bLTMNT+5PYTjROfsAU\n5QdMUX7AFOUHTFF+wBTlB0xRfsAU5QdMpTrkk1d7VWZK9pTKTOWt+nNi3CjSwymXeotlJq836nWk\nKJ2L9c+2pGGbzKzeNzOB3YRQ2KUzmZz++Xv6I65fitAw4XQin9NZk9y5t/7p+2Tm5Bz9XTteMlQv\nNuYfMVuKwskPmKL8gCnKD5ii/IApyg+YovyAKcoPmKL8gCnKD5jK5HJ68igp2bP16S0GmMqrboka\nOeXkB0xRfsAU5QdMUX7AFOUHTFF+wBTlB0xRfsBUqtd4TfjZMpnJj7gSKxux60OL18jMPYfulZmW\nPaNlpqDmmt5QCOHwZzZE5YA0cPIDpig/YIryA6YoP2CK8gOmKD9givIDpig/YCrVIZ/Cxm6ZqR16\nQWbOXC5PYjvh3KtjZGZIVn/O5ZEJbAZIGSc/YIryA6YoP2CK8gOmKD9givIDpig/YIryA6ZSHfIp\n36iHc5rnlMhMzSg9CBSjq1FP8JS9ly8z2ZN6z8DHDSc/YIryA6YoP2CK8gOmKD9givIDpig/YIry\nA6YyuVwutcWyZ+vTWwwwlVfdot+8C5z8gC3KD5ii/IApyg+YovyAKcoPmKL8gCnKD5hKdcgHwMcH\nJz9givIDpig/YIryA6YoP2CK8gOmKD9givIDpig/YIryA6YoP2CK8gOmKD9givIDpig/YIryA6Yo\nP2CK8gOmKD9givIDpig/YIryA6YoP2Dq36RpBgbrBfHPAAAAAElFTkSuQmCC\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7fddc4a323c8>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Row 0 is the first conv2 filter; presumably flattened as (input channels, 3, 3),\n",
"# so -1 infers the channel count from the row length (288 / 9 = 32 here).\n",
"vis_square(tensor.to_numpy(p)[0].reshape(-1, 3, 3))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "py3",
"language": "python",
"name": "py3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.3"
},
"widgets": {
"state": {
"0678ea185e8c48a9ab20bdb956e6b18a": {
"views": [
{
"cell_index": 13
}
]
},
"1373fb7b9e754b639a3f27ebf0372d70": {
"views": [
{
"cell_index": 13
}
]
},
"49561f2ab00b457f82357766d967ca8b": {
"views": [
{
"cell_index": 13
}
]
}
},
"version": "1.2.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}