diff --git a/python/EvenOdd.ipynb b/python/EvenOdd.ipynb new file mode 100644 index 000000000..0d1045ccb --- /dev/null +++ b/python/EvenOdd.ipynb @@ -0,0 +1,421 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Handwritten Digit & Parity Recognition\n", + "MNIST is a database containing a total of 70,000 images of handwritten digits collected from a combination of Census Bureau employees and high school students. The database also includes labels for each of the digits. For example, a handwritten \"7\" has a label 7 to serve as the \"ground truth\". In other words, we can train the model to classify digits given the images because we have the correct answer. Because MNIST is sanitized and has these labels, it makes for a good dataset to run image recognition algorithms on - something deep learning performs well at.\n", + "\n", + "The most basic use case for MNIST is to train a model that classifies digits. For example, we can build a model that classifies each image as one of 0, 1,..., or 9. To view a tutorial on this basic use case, see [Handwritten Digit Recognition](https://github.com/dmlc/mxnet-notebooks/blob/master/python/tutorials/mnist.ipynb). This tutorial will take it up a notch by showing how to train a model that can classify digits *as well* as distinguish between even and odd numbers. For example, given an image that looks like a \"7\", the model should classify the image as digit 7 and parity odd.\n", + "\n", + "## Prepare data\n", + "First we download the dataset and obtain data iterators for MNIST training and validation sets. Read more about the [MNISTIter API](http://mxnet.io/api/python/io.html#mxnet.io.MNISTIter)." + ] + }, + { + "cell_type": "code", + "execution_count": 199, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import sys\n", + "import os\n", + "import mxnet as mx\n", + "import mxnet.metric\n", + "import numpy as np\n", + "\n", + "# Finds MXNet source directory so we can import some utility functions\n", + "mxnet_path = os.path.dirname(os.path.abspath(os.path.expanduser(mxnet.__file__)))\n", + "sys.path.append(os.path.join(mxnet_path, \"../../tests/python/common\"))\n", + "sys.path.append(os.path.join(mxnet_path, \"../../example/python-howto\"))\n", + "\n", + "# Calls a utility function to provide two MNISTIter objects\n", + "from data import mnist_iterator\n", + "train, val = mnist_iterator(batch_size=100, input_shape = (784,))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data iterator\n", + "\n", + "Now we will create a new iterator that will allow us to multitask, classifying digits *and* parity. We can build off of the MNIST iterator, which returns batches of input data and labels for digits 0, 1,...,9. We need this iterator to return batches of data with both labels for digits 0, 1,...,9 as well as labels for even and odd. 
To create this second set of labels, we simply calculate the parity based on the ground truth we already have - the digit labels.\n", + "\n", + "* For **provide_data()** we return the MNIST iterator's data description unchanged, since we want to keep the input data format the same.\n", + "* For **provide_label()** we return a list of 2 data descriptions: the first item is the name and shape of the digit labels, the second item is the name and shape of the parity labels.\n", + "* For **next()** we return a DataBatch object whose data field is set to the batch input, and whose label field is set to a list of 2 arrays: one with the digit labels and the other with the parity labels. This way the model can evaluate its accuracy in predicting both labels.\n", + "\n", + "[Read more about the DataIter API](http://mxnet.io/api/python/io.html?highlight=dataiter#mxnet.io.DataIter) and see a tutorial on [developing new iterators](https://github.com/dmlc/mxnet-notebooks/blob/master/python/basic/data.ipynb)." + ] + }, + { + "cell_type": "code", + "execution_count": 200, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class DigitParity_iterator(mx.io.DataIter): # Extend DataIter\n", + "    '''Multi-label iterator: returns digit and parity labels for each batch'''\n", + "\n", + "    def __init__(self, data_iter):\n", + "        super(DigitParity_iterator, self).__init__()\n", + "        self.data_iter = data_iter\n", + "        self.batch_size = self.data_iter.batch_size\n", + "\n", + "    @property\n", + "    def provide_data(self):\n", + "        # Return the name and shape of the input data as it is\n", + "        return self.data_iter.provide_data\n", + "\n", + "    @property\n", + "    def provide_label(self):\n", + "        # Return the digit and parity labels' names/shapes/data types as a list\n", + "        # You will see later that we name our Softmax symbols \"softmaxdigit\" and \"softmaxparity\"\n", + "        # MXNet will create labels called \"softmaxdigit_label\" and \"softmaxparity_label\"\n", + "        # So those are the names we need to provide here\n", + "        return [mx.io.DataDesc('softmaxdigit_label', (self.batch_size,), np.float32),\n", + "                mx.io.DataDesc('softmaxparity_label', (self.batch_size,), np.float32)]\n", + "\n", + "    def hard_reset(self):\n", + "        self.data_iter.hard_reset()\n", + "\n", + "    def reset(self):\n", + "        self.data_iter.reset()\n", + "\n", + "    def next(self):\n", + "        # Grab the next batch from the underlying MNIST iterator\n", + "        batch = self.data_iter.next()\n", + "\n", + "        # Initialize \"labels\" with the original digit ground truth we already have\n", + "        labels = []\n", + "        labels.append(batch.label[0])\n", + "        eolabels = []\n", + "\n", + "        # Calculate the parity labels from the digit labels\n", + "        for i in batch.label[0].asnumpy():\n", + "            eolabels.append(i % 2)\n", + "        eolabels = mx.nd.array(np.array(eolabels))\n", + "\n", + "        # Append the new parity labels to \"labels\"\n", + "        labels.append(eolabels)\n", + "\n", + "        return mx.io.DataBatch(data=batch.data, label=labels, pad=batch.pad, index=batch.index)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we just pass the MNIST iterators to the iterator we just created." + ] + }, + { + "cell_type": "code", + "execution_count": 201, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "_train = DigitParity_iterator(train)\n", + "_val = DigitParity_iterator(val)" + ] + },
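+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As a quick, optional sanity check (purely illustrative, feel free to skip it), we can pull one batch from the new iterator and confirm that it carries two sets of labels - the digit labels and the derived parity labels, where 0 means even and 1 means odd." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Peek at the label descriptions and at one batch to confirm both label sets are present\n", + "print _train.provide_label\n", + "batch = _train.next()\n", + "print batch.label[0].asnumpy()[:10]  # digit labels\n", + "print batch.label[1].asnumpy()[:10]  # parity labels (0 = even, 1 = odd)\n", + "\n", + "# Rewind so that training later starts from the beginning of the data\n", + "_train.reset()" + ] + },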
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluation\n", + "\n", + "We create a class to help evaluate how accurately the model is classifying digits *and* parity. We know that for each example, the model will make 2 predictions (one for each of the above). So for a batch of n examples, the model makes 2n predictions. We want to keep track of how many of the 2n predictions were correct.\n", + "\n", + "Our class can extend the mx.metric.EvalMetric class. We will increment **num_inst** once for each of the 2n predictions and increment **sum_metric** every time the model predicts correctly. Then, EvalMetric can use those two numbers to return an evaluation.\n", + "\n", + "[Read more about the EvalMetric API](http://mxnet.io/api/python/model.html?highlight=evalmetric#mxnet.metric.EvalMetric).\n", + "\n", + "To show the status of training, we will also write a function to be called after every epoch. What would be helpful in those status reports is a simple visualization of how many correct and incorrect predictions there are for each label. Take parity for example:\n", + "\n", + "```\n", + "[[90, 10]\n", + " [20, 80]]\n", + "```\n", + "\n", + "The first row tells us that the model correctly labeled even examples 90 times, and incorrectly labeled even examples as odd 10 times. The second row tells us that the model correctly labeled odd examples 80 times, and incorrectly labeled odd examples as even 20 times. What we want to see are big numbers on the main diagonal, which tell us that each label is guessed correctly often.\n", + "\n",
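+ "To make the connection to **sum_metric** and **num_inst** concrete, here is the arithmetic for the example matrix above (the numbers are only illustrative). The correct predictions sit on the main diagonal, so:\n", + "\n", + "```\n", + "accuracy = (90 + 80) / (90 + 10 + 20 + 80) = 170 / 200 = 0.85\n", + "```\n", + "\n", + "In other words, these parity predictions would contribute 170 to **sum_metric** and 200 to **num_inst**.\n", + "\n",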
+ "So, we will also build this state into the evaluation class and have the status function print this state out after every epoch." + ] + }, + { + "cell_type": "code", + "execution_count": 202, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class MNISTPerDigitAccuracy(mx.metric.EvalMetric): # Extend EvalMetric\n", + "    \"\"\"Calculate accuracy across the digit and parity predictions\"\"\"\n", + "\n", + "    def reset(self):\n", + "        if hasattr(self, 'cms'):\n", + "            for cm in self.cms:\n", + "                cm.fill(0)\n", + "\n", + "    def __init__(self, sizes):\n", + "        super(MNISTPerDigitAccuracy, self).__init__('perdigitaccuracy')\n", + "\n", + "        # Create a confusion matrix for each label type\n", + "        # Each matrix is x*x, where x is the number of classes in the label type\n", + "        # For the digit label it will be 10*10\n", + "        # For the parity label it will be 2*2\n", + "        self.cms = [np.zeros((x, x), dtype=int) for x in sizes]\n", + "        self.reset()\n", + "        # Let the parent class initialize sum_metric and num_inst\n", + "        super(MNISTPerDigitAccuracy, self).reset()\n", + "\n", + "    def update(self, labels, preds):\n", + "        mx.metric.check_label_shapes(labels, preds)\n", + "\n", + "        # Go through each label type: digit and parity\n", + "        for i in range(len(labels)):\n", + "            # For the current label type, zip up the predictions versus the ground truths\n", + "            for label, pred_label in zip(labels[i].asnumpy(), preds[i].asnumpy()):\n", + "                # Predictions come in the form of probability distributions\n", + "                # For digit labels, it will be a vector of 10 numbers. Ex:\n", + "                # [0, .01, .09, 0, 0, .06, .04, .8, 0, 0] shows \"7\" has the highest probability\n", + "                # So we take the maximum from the prediction\n", + "                pred_label = int(np.argmax(pred_label))\n", + "\n", + "                # Keep track of what predictions the model makes for what ground truths\n", + "                # It is a confusion matrix; big numbers on the main diagonal are our goal\n", + "                label = int(label)\n", + "                self.cms[i][label, pred_label] += 1\n", + "\n", + "                # Keep track of how many correct predictions the model makes\n", + "                self.sum_metric += (pred_label == label)\n", + "\n", + "                # Keep track of how many predictions are made overall\n", + "                self.num_inst += 1\n", + "\n", + "def perDigitMetric(params):\n", + "    # Print what predictions the model makes for each ground truth label\n", + "    for i, cm in enumerate(params.eval_metric.cms):\n", + "        print \"Label=\", i\n", + "        print cm" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model\n", + "\n", + "Now we build the model. It will be similar to the one we built for digit classification, except we add another fully connected layer, fc3parity, to predict the parity." + ] + }, + { + "cell_type": "code", + "execution_count": 203, + "metadata": { + "collapsed": false, + "scrolled": false + }, + "outputs": [], + "source": [ + "data = mx.symbol.Variable('data')\n", + "fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)\n", + "act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type=\"relu\")\n", + "fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)\n", + "act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type=\"relu\")\n", + "fc3digit = mx.symbol.FullyConnected(data = act2, name='fc3digit', num_hidden=10)\n", + "\n", + "# A fully connected layer for the parity. The hidden size is 2 to account for even and odd.\n", + "fc3parity = mx.symbol.FullyConnected(data = act2, name='fc3parity', num_hidden=2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we create the output symbols. We group the two symbols together so we can train on both tasks at once." + ] + }, + { + "cell_type": "code", + "execution_count": 204, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "mlp1 = mx.symbol.SoftmaxOutput(data = fc3digit, name = 'softmaxdigit')\n", + "mlp2 = mx.symbol.SoftmaxOutput(data = fc3parity, name = 'softmaxparity')\n", + "\n", + "# Group the digit and parity symbols to multitask\n", + "mlp = mx.symbol.Group([mlp1,mlp2])\n", + "\n", + "# Optimizer settings that we will pass to fit() during training\n", + "optimizer_params = (('learning_rate', 0.1),('momentum', 0.9), ('wd', 0.00001))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training\n", + "Now let's train the model.\n",
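+ "\n", + "Before calling fit, you can optionally double-check the naming convention the iterator relied on: the grouped symbol exposes one output per Softmax, plus the auto-created label arguments. For example (illustrative only, output not shown):\n", + "\n", + "```python\n", + "print mlp.list_outputs()    # ['softmaxdigit_output', 'softmaxparity_output']\n", + "print mlp.list_arguments()  # includes 'softmaxdigit_label' and 'softmaxparity_label'\n", + "```"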
+ ] + }, + { + "cell_type": "code", + "execution_count": 205, + "metadata": { + "collapsed": false, + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:Epoch[0] Batch [100]\tSpeed: 33022.90 samples/sec\tTrain-perdigitaccuracy=0.390198\n", + "INFO:root:Epoch[0] Batch [200]\tSpeed: 34679.12 samples/sec\tTrain-perdigitaccuracy=0.603881\n", + "INFO:root:Epoch[0] Batch [300]\tSpeed: 33373.04 samples/sec\tTrain-perdigitaccuracy=0.712957\n", + "INFO:root:Epoch[0] Batch [400]\tSpeed: 30580.56 samples/sec\tTrain-perdigitaccuracy=0.770224\n", + "INFO:root:Epoch[0] Batch [500]\tSpeed: 32021.01 samples/sec\tTrain-perdigitaccuracy=0.807515\n", + "INFO:root:Epoch[0] Train-perdigitaccuracy=0.832458\n", + "INFO:root:Epoch[0] Time cost=1.826\n", + "INFO:root:Epoch[0] Validation-perdigitaccuracy=0.851321\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Label= 0\n", + "[[ 965 0 0 3 0 2 4 3 3 0]\n", + " [ 0 1104 1 10 0 1 1 0 18 0]\n", + " [ 12 1 981 16 3 1 9 4 5 0]\n", + " [ 0 0 5 990 0 5 0 4 6 0]\n", + " [ 2 0 2 0 940 0 14 2 2 20]\n", + " [ 5 0 0 56 4 796 7 2 17 5]\n", + " [ 10 3 0 1 4 9 926 0 5 0]\n", + " [ 1 8 20 9 3 0 0 968 2 17]\n", + " [ 7 1 6 22 3 1 8 2 920 4]\n", + " [ 3 5 0 10 27 6 2 8 2 946]]\n", + "Label= 1\n", + "[[4791 135]\n", + " [ 111 4963]]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:Epoch[1] Batch [100]\tSpeed: 32412.70 samples/sec\tTrain-perdigitaccuracy=0.865418\n", + "INFO:root:Epoch[1] Batch [200]\tSpeed: 36032.52 samples/sec\tTrain-perdigitaccuracy=0.877009\n", + "INFO:root:Epoch[1] Batch [300]\tSpeed: 34819.55 samples/sec\tTrain-perdigitaccuracy=0.886309\n", + "INFO:root:Epoch[1] Batch [400]\tSpeed: 33539.16 samples/sec\tTrain-perdigitaccuracy=0.894074\n", + "INFO:root:Epoch[1] Batch [500]\tSpeed: 34900.47 samples/sec\tTrain-perdigitaccuracy=0.900733\n", + "INFO:root:Epoch[1] Train-perdigitaccuracy=0.906185\n", + "INFO:root:Epoch[1] Time cost=1.754\n", + "INFO:root:Epoch[1] Validation-perdigitaccuracy=0.911046\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Label= 0\n", + "[[ 974 0 0 1 0 0 2 1 2 0]\n", + " [ 0 1122 4 3 1 0 1 0 4 0]\n", + " [ 7 1 1011 3 4 0 3 2 1 0]\n", + " [ 5 2 10 972 0 8 0 4 6 3]\n", + " [ 1 0 3 0 953 0 6 1 0 18]\n", + " [ 8 1 1 13 4 843 7 1 9 5]\n", + " [ 11 3 3 0 4 6 925 0 6 0]\n", + " [ 2 7 20 7 2 0 0 968 2 20]\n", + " [ 23 0 21 3 3 1 10 2 905 6]\n", + " [ 4 2 0 5 21 6 1 2 1 967]]\n", + "Label= 1\n", + "[[4885 41]\n", + " [ 114 4960]]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:Epoch[2] Batch [100]\tSpeed: 32870.39 samples/sec\tTrain-perdigitaccuracy=0.915363\n", + "INFO:root:Epoch[2] Batch [200]\tSpeed: 36177.10 samples/sec\tTrain-perdigitaccuracy=0.919403\n", + "INFO:root:Epoch[2] Batch [300]\tSpeed: 34680.55 samples/sec\tTrain-perdigitaccuracy=0.922878\n", + "INFO:root:Epoch[2] Batch [400]\tSpeed: 33414.76 samples/sec\tTrain-perdigitaccuracy=0.925980\n", + "INFO:root:Epoch[2] Batch [500]\tSpeed: 33036.99 samples/sec\tTrain-perdigitaccuracy=0.928848\n", + "INFO:root:Epoch[2] Train-perdigitaccuracy=0.931305\n", + "INFO:root:Epoch[2] Time cost=1.770\n", + "INFO:root:Epoch[2] Validation-perdigitaccuracy=0.933336\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Label= 0\n", + "[[ 969 0 3 0 0 0 7 1 0 0]\n", + " [ 0 1111 8 3 0 0 5 2 5 1]\n", + " [ 7 0 1014 0 2 0 3 2 4 0]\n", + " [ 1 0 15 959 0 14 0 6 9 6]\n", + " [ 1 0 5 0 961 0 8 0 
0 7]\n", + " [ 6 0 0 7 1 856 11 1 6 4]\n", + " [ 6 3 1 0 5 2 934 0 7 0]\n", + " [ 1 4 18 2 3 0 0 986 3 11]\n", + " [ 15 0 28 0 6 3 15 2 899 6]\n", + " [ 5 1 0 2 29 6 1 4 1 960]]\n", + "Label= 1\n", + "[[4902 24]\n", + " [ 146 4928]]\n" + ] + } + ], + "source": [ + "# Because we have two sets of labels, one for each of digit and parity SoftmaxOutput,\n", + "# we have to explicitly call them out here\n", + "mod = mx.mod.Module(mlp, label_names=['softmaxdigit_label','softmaxparity_label'])\n", + "mod.fit(_train, \n", + " eval_data=_val,\n", + " optimizer_params=optimizer_params,\n", + " eval_metric=MNISTPerDigitAccuracy([10,2]),\n", + " eval_end_callback=perDigitMetric,\n", + " num_epoch=3, batch_end_callback=mx.callback.Speedometer(100,100))" + ] + } + ], + "metadata": { + "anaconda-cloud": {}, + "celltoolbar": "Raw Cell Format", + "kernelspec": { + "display_name": "Python 2", + "language": "python", + "name": "python2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +}