{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": [
     "remove-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Install datascience package if needed\n",
    "try:\n",
    "    import datascience\n",
    "except ImportError:\n",
    "    import micropip\n",
    "    await micropip.install('datascience')\n",
    "import matplotlib\n",
    "#matplotlib.use('Agg')\n",
    "path_data = '../../../assets/data/'\n",
    "from datascience import *\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "import numpy as np\n",
    "import math\n",
    "import scipy.stats as stats\n",
    "plt.style.use('fivethirtyeight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The Accuracy of the Classifier\n",
    "To see how well our classifier does, we might put 50% of the data into the training set and the other 50% into the test set.  Basically, we are setting aside some data for later use, so we can use it to measure the accuracy of our classifier.  We've been calling that the *test set*. Sometimes people will call the data that you set aside for testing a *hold-out set*, and they'll call this strategy for estimating accuracy the *hold-out method*.\n",
    "\n",
    "Note that this approach requires great discipline.  Before you start applying machine learning methods, you have to take some of your data and set it aside for testing.  You must avoid using the test set for developing your classifier: you shouldn't use it to help train your classifier or tweak its settings or for brainstorming ways to improve your classifier.  Instead, you should use it only once, at the very end, after you've finalized your classifier, when you want an unbiased estimate of its accuracy."
   ]
  },
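  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The hold-out split itself is just a random permutation of row positions, cut into two pieces.  Below is a minimal NumPy sketch, assuming the rows are numbered `0` to `n_rows - 1`; the function name `holdout_split` and its parameters are illustrative and are not part of the `datascience` library used in this notebook.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def holdout_split(n_rows, train_fraction=0.5, seed=None):\n",
    "    # Hypothetical helper: shuffle the row positions 0..n_rows-1 without replacement,\n",
    "    # then cut the shuffled order into a training part and a held-out test part.\n",
    "    rng = np.random.default_rng(seed)\n",
    "    shuffled = rng.permutation(n_rows)\n",
    "    n_train = int(round(n_rows * train_fraction))\n",
    "    return shuffled[:n_train], shuffled[n_train:]\n",
    "\n",
    "train_rows, test_rows = holdout_split(178, seed=0)\n",
    "print(len(train_rows), len(test_rows))  # 89 89\n",
    "```\n",
    "\n",
    "Because both pieces come from a single permutation, every row lands in exactly one of them, so no test row can leak into training."
   ]
  },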
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": [
     "remove-input"
    ]
   },
   "outputs": [],
   "source": [
    "def distance(point1, point2):\n",
    "    \"\"\"Returns the distance between point1 and point2\n",
    "    where each argument is an array \n",
    "    consisting of the coordinates of the point\"\"\"\n",
    "    return np.sqrt(np.sum((point1 - point2)**2))\n",
    "\n",
    "def all_distances(training, new_point):\n",
    "    \"\"\"Returns an array of distances\n",
    "    between each point in the training set\n",
    "    and the new point (which is a row of attributes)\"\"\"\n",
    "    attributes = training.drop('Class')\n",
    "    def distance_from_point(row):\n",
    "        return distance(np.array(new_point), np.array(row))\n",
    "    return attributes.apply(distance_from_point)\n",
    "\n",
    "def table_with_distances(training, new_point):\n",
    "    \"\"\"Augments the training table \n",
    "    with a column of distances from new_point\"\"\"\n",
    "    return training.with_column('Distance', all_distances(training, new_point))\n",
    "\n",
    "def closest(training, new_point, k):\n",
    "    \"\"\"Returns a table of the k rows of the augmented table\n",
    "    corresponding to the k smallest distances\"\"\"\n",
    "    with_dists = table_with_distances(training, new_point)\n",
    "    sorted_by_distance = with_dists.sort('Distance')\n",
    "    topk = sorted_by_distance.take(np.arange(k))\n",
    "    return topk\n",
    "\n",
    "def majority(topkclasses):\n",
    "    ones = topkclasses.where('Class', are.equal_to(1)).num_rows\n",
    "    zeros = topkclasses.where('Class', are.equal_to(0)).num_rows\n",
    "    if ones > zeros:\n",
    "        return 1\n",
    "    else:\n",
    "        return 0\n",
    "\n",
    "def classify(training, new_point, k):\n",
    "    closestk = closest(training, new_point, k)\n",
    "    topkclasses = closestk.select('Class')\n",
    "    return majority(topkclasses)"
   ]
  },
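  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The helper functions in the hidden cell above implement $k$-nearest neighbors with `datascience` tables: compute the distance from the new point to every training row, keep the $k$ closest rows, and take a majority vote in which ties go to class 0.  As a sketch of the same three steps in plain NumPy, assuming the attributes are stored in a 2-D array with one row per example and the labels are 0/1 (the name `knn_classify` is hypothetical):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def knn_classify(train_X, train_y, new_point, k):\n",
    "    # Step 1: Euclidean distance from new_point to every training row.\n",
    "    dists = np.sqrt(np.sum((train_X - new_point) ** 2, axis=1))\n",
    "    # Step 2: positions of the k smallest distances.\n",
    "    nearest = np.argsort(dists)[:k]\n",
    "    # Step 3: majority vote, assuming labels are only 0 or 1; ties go to 0, like majority().\n",
    "    ones = np.count_nonzero(train_y[nearest] == 1)\n",
    "    zeros = len(nearest) - ones\n",
    "    return 1 if ones > zeros else 0\n",
    "\n",
    "train_X = np.array([[0, 0], [0, 1], [5, 5], [6, 5], [5, 6]])\n",
    "train_y = np.array([0, 0, 1, 1, 1])\n",
    "print(knn_classify(train_X, train_y, np.array([0.2, 0.3]), 3))  # 0\n",
    "print(knn_classify(train_X, train_y, np.array([5.5, 5.5]), 3))  # 1\n",
    "```\n",
    "\n",
    "Broadcasting `train_X - new_point` subtracts the new point from every row at once, which plays the role of the row-by-row `apply` in `all_distances`."
   ]
  },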
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": [
     "remove-input"
    ]
   },
   "outputs": [],
   "source": [
    "wine = Table.read_table(path_data + 'wine.csv')\n",
    "\n",
    "# For converting Class to binary\n",
    "\n",
    "def is_one(x):\n",
    "    if x == 1:\n",
    "        return 1\n",
    "    else:\n",
    "        return 0\n",
    "    \n",
    "wine = wine.with_column('Class', wine.apply(is_one, 0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Measuring the Accuracy of Our Wine Classifier\n",
    "OK, so let's apply the hold-out method to evaluate the effectiveness of the $k$-nearest neighbor classifier for identifying wines.  The data set has 178 wines, so we'll randomly permute the data set and put 89 of them in the training set and the remaining 89 in the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "shuffled_wine = wine.sample(with_replacement=False) \n",
    "training_set = shuffled_wine.take(np.arange(89))\n",
    "test_set  = shuffled_wine.take(np.arange(89, 178))"
   ]
  },
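  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The discipline described earlier applies to choosing $k$ as well: if we tried several values of $k$ on `test_set` and kept the best one, the test set would no longer give an unbiased estimate.  One common remedy, sketched here as an assumption rather than something this notebook does, is to split the *training* set again into a fitting part and a validation part, choose $k$ on the validation part, and leave the test set untouched.  The helper names and the synthetic data below are illustrative.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def knn_classify(train_X, train_y, new_point, k):\n",
    "    # k-nearest-neighbor vote on 0/1 labels; ties go to 0.\n",
    "    dists = np.sqrt(np.sum((train_X - new_point) ** 2, axis=1))\n",
    "    nearest = np.argsort(dists)[:k]\n",
    "    ones = np.count_nonzero(train_y[nearest] == 1)\n",
    "    return 1 if ones > len(nearest) - ones else 0\n",
    "\n",
    "def choose_k(train_X, train_y, candidate_ks, validation_fraction=0.25, seed=None):\n",
    "    # Split the TRAINING data only: fit on one part, score on the other.\n",
    "    rng = np.random.default_rng(seed)\n",
    "    order = rng.permutation(len(train_y))\n",
    "    n_val = max(1, int(len(train_y) * validation_fraction))\n",
    "    val, fit = order[:n_val], order[n_val:]\n",
    "    scores = {}\n",
    "    for k in candidate_ks:\n",
    "        preds = np.array([knn_classify(train_X[fit], train_y[fit], x, k)\n",
    "                          for x in train_X[val]])\n",
    "        scores[k] = np.count_nonzero(preds == train_y[val]) / n_val\n",
    "    # max() keeps the first candidate among equally good ones.\n",
    "    best_k = max(candidate_ks, key=scores.get)\n",
    "    return best_k, scores\n",
    "\n",
    "# Synthetic example: two well-separated clusters labeled 0 and 1.\n",
    "rng = np.random.default_rng(0)\n",
    "X = np.vstack([rng.normal(0, 1, size=(40, 2)), rng.normal(4, 1, size=(40, 2))])\n",
    "y = np.array([0] * 40 + [1] * 40)\n",
    "best_k, scores = choose_k(X, y, [1, 3, 5, 7], seed=1)\n",
    "print(best_k, scores)\n",
    "```\n",
    "\n",
    "Only after $k$ is fixed this way would the classifier be evaluated on the test set, once."
   ]
  },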
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll train the classifier using the 89 wines in the training set, and evaluate how well it performs on the test set. To make our lives easier, we'll write a function to evaluate a classifier on every wine in the test set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_zero(array):\n",
    "    \"\"\"Counts the number of 0's in an array\"\"\"\n",
    "    return len(array) - np.count_nonzero(array)\n",
    "\n",
    "def count_equal(array1, array2):\n",
    "    \"\"\"Takes two numerical arrays of equal length\n",
    "    and counts the indices where the two are equal\"\"\"\n",
    "    return count_zero(array1 - array2)\n",
    "\n",
    "def evaluate_accuracy(training, test, k):\n",
    "    test_attributes = test.drop('Class')\n",
    "    def classify_testrow(row):\n",
    "        return classify(training, row, k)\n",
    "    c = test_attributes.apply(classify_testrow)\n",
    "    return count_equal(c, test.column('Class')) / test.num_rows"
   ]
  },
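  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`count_equal` finds matching positions by subtracting the two arrays and counting zeros, which works because the classes are numbers.  A sketch of a more direct alternative (the name `accuracy` is hypothetical) compares the arrays element by element, which also works when the labels are strings:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def accuracy(predicted, actual):\n",
    "    # Fraction of positions where the predicted label equals the actual label.\n",
    "    predicted = np.asarray(predicted)\n",
    "    actual = np.asarray(actual)\n",
    "    return np.count_nonzero(predicted == actual) / len(actual)\n",
    "\n",
    "print(accuracy([1, 0, 0, 1], [1, 0, 1, 1]))  # 0.75\n",
    "print(accuracy(['a', 'b'], ['a', 'c']))      # 0.5\n",
    "```"
   ]
  },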
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now for the grand reveal -- let's see how we did.  We'll arbitrarily use $k=5$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.898876404494382"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "evaluate_accuracy(training_set, test_set, 5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The accuracy rate isn't bad at all for a simple classifier."
   ]
  },
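  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to judge whether roughly 90% accuracy is good is to compare it with a trivial rule that ignores the attributes and always predicts the class that is most common in the training set.  The sketch below uses a hypothetical helper and made-up labels; passing it `training_set.column('Class')` and `test_set.column('Class')` would give the baseline for this particular split.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def majority_baseline_accuracy(train_labels, test_labels):\n",
    "    # Most common label in the training set (on ties, the smallest label, via argmax).\n",
    "    values, counts = np.unique(train_labels, return_counts=True)\n",
    "    most_common = values[np.argmax(counts)]\n",
    "    # Accuracy of predicting most_common for every test row.\n",
    "    test_labels = np.asarray(test_labels)\n",
    "    return np.count_nonzero(test_labels == most_common) / len(test_labels)\n",
    "\n",
    "print(majority_baseline_accuracy([0, 0, 1], [0, 1, 0, 0]))  # 0.75\n",
    "```"
   ]
  },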
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Breast Cancer Diagnosis\n",
    "\n",
    "Now I want to do an example based on diagnosing breast cancer.  I was inspired by Brittany Wenger, who won the Google national science fair in 2012 as a 17-year old high school student.  Here's Brittany:\n",
    "\n",
    "![Brittany Wenger](http://i.huffpost.com/gen/701499/thumbs/o-GSF83-570.jpg?3)\n",
    "\n",
    "Brittany's [science fair project](https://sites.google.com/a/googlesciencefair.com/science-fair-2012-project-64a91af142a459cfb486ed5cb05f803b2eb41354-1333130785-87/home) was to build a classification algorithm to diagnose breast cancer.  She won grand prize for building an algorithm whose accuracy was almost 99%. \n",
    "\n",
    "Let's see how well we can do, with the ideas we've learned in this course.\n",
    "\n",
    "So, let me tell you a little bit about the data set.  Basically, if a woman has a lump in her breast, the doctors may want to take a biopsy to see if it is cancerous.  There are several different procedures for doing that.  Brittany focused on fine needle aspiration (FNA), because it is less invasive than the alternatives.  The doctor gets a sample of the mass, puts it under a microscope, takes a picture, and a trained lab tech analyzes the picture to determine whether it is cancer or not.  We get a picture like one of the following:\n",
    "\n",
    "![benign](../../../images/benign.png)\n",
    "\n",
    "![cancer](../../../images/malignant.png)\n",
    "\n",
    "Unfortunately, distinguishing between benign vs malignant can be tricky.  So, researchers have studied the use of machine learning to help with this task.  The idea is that we'll ask the lab tech to analyze the image and compute various attributes: things like the typical size of a cell, how much variation there is among the cell sizes, and so on.  Then, we'll try to use this information to predict (classify) whether the sample is malignant or not.  We have a training set of past samples from women where the correct diagnosis is known, and we'll hope that our machine learning algorithm can use those to learn how to predict the diagnosis for future samples.\n",
    "\n",
    "We end up with the following data set.  For the \"Class\" column, 1 means malignant (cancer); 0 means benign (not cancer)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "    <thead>\n",
       "        <tr>\n",
       "            <th>Clump Thickness</th> <th>Uniformity of Cell Size</th> <th>Uniformity of Cell Shape</th> <th>Marginal Adhesion</th> <th>Single Epithelial Cell Size</th> <th>Bare Nuclei</th> <th>Bland Chromatin</th> <th>Normal Nucleoli</th> <th>Mitoses</th> <th>Class</th>\n",
       "        </tr>\n",
       "    </thead>\n",
       "    <tbody>\n",
       "        <tr>\n",
       "            <td>5              </td> <td>1                      </td> <td>1                       </td> <td>1                </td> <td>2                          </td> <td>1          </td> <td>3              </td> <td>1              </td> <td>1      </td> <td>0    </td>\n",
       "        </tr>\n",
       "        <tr>\n",
       "            <td>5              </td> <td>4                      </td> <td>4                       </td> <td>5                </td> <td>7                          </td> <td>10         </td> <td>3              </td> <td>2              </td> <td>1      </td> <td>0    </td>\n",
       "        </tr>\n",
       "        <tr>\n",
       "            <td>3              </td> <td>1                      </td> <td>1                       </td> <td>1                </td> <td>2                          </td> <td>2          </td> <td>3              </td> <td>1              </td> <td>1      </td> <td>0    </td>\n",
       "        </tr>\n",
       "        <tr>\n",
       "            <td>6              </td> <td>8                      </td> <td>8                       </td> <td>1                </td> <td>3                          </td> <td>4          </td> <td>3              </td> <td>7              </td> <td>1      </td> <td>0    </td>\n",
       "        </tr>\n",
       "        <tr>\n",
       "            <td>4              </td> <td>1                      </td> <td>1                       </td> <td>3                </td> <td>2                          </td> <td>1          </td> <td>3              </td> <td>1              </td> <td>1      </td> <td>0    </td>\n",
       "        </tr>\n",
       "        <tr>\n",
       "            <td>8              </td> <td>10                     </td> <td>10                      </td> <td>8                </td> <td>7                          </td> <td>10         </td> <td>9              </td> <td>7              </td> <td>1      </td> <td>1    </td>\n",
       "        </tr>\n",
       "        <tr>\n",
       "            <td>1              </td> <td>1                      </td> <td>1                       </td> <td>1                </td> <td>2                          </td> <td>10         </td> <td>3              </td> <td>1              </td> <td>1      </td> <td>0    </td>\n",
       "        </tr>\n",
       "        <tr>\n",
       "            <td>2              </td> <td>1                      </td> <td>2                       </td> <td>1                </td> <td>2                          </td> <td>1          </td> <td>3              </td> <td>1              </td> <td>1      </td> <td>0    </td>\n",
       "        </tr>\n",
       "        <tr>\n",
       "            <td>2              </td> <td>1                      </td> <td>1                       </td> <td>1                </td> <td>2                          </td> <td>1          </td> <td>1              </td> <td>1              </td> <td>5      </td> <td>0    </td>\n",
       "        </tr>\n",
       "        <tr>\n",
       "            <td>4              </td> <td>2                      </td> <td>1                       </td> <td>1                </td> <td>2                          </td> <td>1          </td> <td>2              </td> <td>1              </td> <td>1      </td> <td>0    </td>\n",
       "        </tr>\n",
       "    </tbody>\n",
       "</table>\n",
       "<p>... (673 rows omitted)</p>"
      ],
      "text/plain": [
       "Clump Thickness | Uniformity of Cell Size | Uniformity of Cell Shape | Marginal Adhesion | Single Epithelial Cell Size | Bare Nuclei | Bland Chromatin | Normal Nucleoli | Mitoses | Class\n",
       "5               | 1                       | 1                        | 1                 | 2                           | 1           | 3               | 1               | 1       | 0\n",
       "5               | 4                       | 4                        | 5                 | 7                           | 10          | 3               | 2               | 1       | 0\n",
       "3               | 1                       | 1                        | 1                 | 2                           | 2           | 3               | 1               | 1       | 0\n",
       "6               | 8                       | 8                        | 1                 | 3                           | 4           | 3               | 7               | 1       | 0\n",
       "4               | 1                       | 1                        | 3                 | 2                           | 1           | 3               | 1               | 1       | 0\n",
       "8               | 10                      | 10                       | 8                 | 7                           | 10          | 9               | 7               | 1       | 1\n",
       "1               | 1                       | 1                        | 1                 | 2                           | 10          | 3               | 1               | 1       | 0\n",
       "2               | 1                       | 2                        | 1                 | 2                           | 1           | 3               | 1               | 1       | 0\n",
       "2               | 1                       | 1                        | 1                 | 2                           | 1           | 1               | 1               | 5       | 0\n",
       "4               | 2                       | 1                        | 1                 | 2                           | 1           | 2               | 1               | 1       | 0\n",
       "... (673 rows omitted)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "patients = Table.read_table(path_data + 'breast-cancer.csv').drop('ID')\n",
    "patients"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So we have 9 different attributes.  I don't know how to make a 9-dimensional scatterplot of all of them, so I'm going to pick two and plot them:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "color_table = Table().with_columns(\n",
    "    'Class', make_array(1, 0),\n",
    "    'Color', make_array('darkblue', 'gold')\n",
    ")\n",
    "patients_with_colors = patients.join('Class', color_table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "patients_with_colors.scatter('Bland Chromatin', 'Single Epithelial Cell Size', group='Color')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Oops.  That plot is utterly misleading, because there are a bunch of points that have identical values for both the x- and y-coordinates.  To make it easier to see all the data points, I'm going to add a little bit of random jitter to the x- and y-values.  Here's how that looks:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "tags": [
     "remove-input"
    ]
   },
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def randomize_column(a):\n",
    "    return a + np.random.normal(0.0, 0.09, size=len(a))\n",
    "Table().with_columns(\n",
    "        'Bland Chromatin (jittered)', \n",
    "        randomize_column(patients.column('Bland Chromatin')),\n",
    "        'Single Epithelial Cell Size (jittered)', \n",
    "        randomize_column(patients.column('Single Epithelial Cell Size')),\n",
    "        'Class', patients.column('Class')\n",
    "    ).join('Class', color_table).scatter(1, 2, group='Color')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For instance, you can see there are lots of samples with chromatin = 2 and epithelial cell size = 2; all non-cancerous.\n",
    "\n",
    "Keep in mind that the jittering is just for visualization purposes, to make it easier to get a feeling for the data.  We're ready to work with the data now, and we'll use the original (unjittered) data."
   ]
  },
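  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The jitter changes only what is drawn, not the data.  A sketch of the idea, assuming the same noise scale (standard deviation 0.09) as the hidden plotting cell; the helper name `jitter` and its `seed` parameter are hypothetical:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def jitter(values, scale=0.09, seed=None):\n",
    "    # Return a NEW array: the values plus Gaussian noise with mean 0 and sd `scale`.\n",
    "    rng = np.random.default_rng(seed)\n",
    "    values = np.asarray(values, dtype=float)\n",
    "    return values + rng.normal(0.0, scale, size=len(values))\n",
    "\n",
    "original = np.array([2, 2, 2, 3])\n",
    "jittered = jitter(original, seed=0)\n",
    "print(original)  # [2 2 2 3]  (unchanged)\n",
    "print(jittered)  # values near 2, 2, 2 and 3, no longer identical\n",
    "```"
   ]
  },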
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First we'll create a training set and a test set. The data set has 683 patients, so we'll randomly permute the data set and put 342 of them in the training set and the remaining 341 in the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "shuffled_patients = patients.sample(683, with_replacement=False) \n",
    "training_set = shuffled_patients.take(np.arange(342))\n",
    "test_set  = shuffled_patients.take(np.arange(342, 683))"
   ]
  },
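  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The classifier never needs a 9-dimensional picture: the `distance` function from the hidden helper cell works for points with any number of coordinates.  As a worked example, here is the computation for the first two patients in the table above, with their nine attribute values copied by hand:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Attribute values of the first two rows of the patients table (Class excluded).\n",
    "patient_a = np.array([5, 1, 1, 1, 2, 1, 3, 1, 1])\n",
    "patient_b = np.array([5, 4, 4, 5, 7, 10, 3, 2, 1])\n",
    "\n",
    "# Same formula as distance(): square root of the sum of squared differences.\n",
    "d = np.sqrt(np.sum((patient_a - patient_b) ** 2))\n",
    "print(d)  # sqrt(141), about 11.87\n",
    "```\n",
    "\n",
    "The squared differences are 0, 9, 9, 16, 25, 81, 0, 1 and 0, which sum to 141; the gap in Bare Nuclei (1 versus 10) accounts for most of the distance."
   ]
  },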
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's stick with 5 nearest neighbors, and see how well our classifier does."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.967741935483871"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "evaluate_accuracy(training_set, test_set, 5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Over 96% accuracy.  Not bad!  Once again, pretty darn good for such a simple technique.\n",
    "\n",
    "As a footnote, you might have noticed that Brittany Wenger did even better.  What techniques did she use? One key innovation is that she incorporated a confidence score into her results: her algorithm had a way to determine when it was not able to make a confident prediction, and for those patients, it didn't even try to predict their diagnosis.  Her algorithm was 99% accurate on the patients where it made a prediction -- so that extension seemed to help quite a bit."
   ]
  }
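  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The text doesn't describe how Brittany Wenger's confidence score worked, so the following is only a hypothetical illustration of the general idea of abstaining, not her method: a $k$-nearest neighbor classifier that predicts only when at least `min_agreement` of the $k$ neighbors agree, with accuracy reported together with *coverage*, the fraction of patients that received a prediction.  All names are illustrative.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def knn_classify_or_abstain(train_X, train_y, new_point, k, min_agreement):\n",
    "    # Hypothetical confidence rule: predict only if at least min_agreement\n",
    "    # of the k nearest neighbors share a label; otherwise abstain (None).\n",
    "    dists = np.sqrt(np.sum((train_X - new_point) ** 2, axis=1))\n",
    "    nearest = np.argsort(dists)[:k]\n",
    "    ones = np.count_nonzero(train_y[nearest] == 1)\n",
    "    zeros = len(nearest) - ones\n",
    "    if max(ones, zeros) < min_agreement:\n",
    "        return None\n",
    "    return 1 if ones > zeros else 0\n",
    "\n",
    "def accuracy_and_coverage(predictions, actual):\n",
    "    # Accuracy over the rows where a prediction was made, plus the\n",
    "    # fraction of rows that received a prediction (coverage).\n",
    "    made = [i for i, p in enumerate(predictions) if p is not None]\n",
    "    coverage = len(made) / len(predictions)\n",
    "    if not made:\n",
    "        return None, coverage\n",
    "    correct = sum(predictions[i] == actual[i] for i in made)\n",
    "    return correct / len(made), coverage\n",
    "\n",
    "X = np.array([[0.0], [1.0], [2.0], [10.0], [11.0]])\n",
    "y = np.array([0, 0, 1, 1, 1])\n",
    "print(knn_classify_or_abstain(X, y, np.array([0.9]), 3, 3))   # None (vote is 2 to 1)\n",
    "print(knn_classify_or_abstain(X, y, np.array([10.5]), 3, 3))  # 1 (unanimous)\n",
    "print(accuracy_and_coverage([1, None, 0, 0], [1, 1, 1, 0]))  # accuracy 2/3, coverage 0.75\n",
    "```\n",
    "\n",
    "Raising `min_agreement` trades coverage for accuracy: the classifier answers fewer cases, but the ones it answers tend to be the clearer ones."
   ]
  }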
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}