{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Chapter 3 - Classification**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_This notebook contains all the sample code and solutions to the exercises in chapter 3._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table align=\"left\">\n",
" <td>\n",
" <a href=\"https://colab.research.google.com/github/ageron/handson-ml3/blob/main/03_classification.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/ageron/handson-ml3/blob/main/03_classification.ipynb\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" /></a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This project requires Python 3.7 or above:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"assert sys.version_info >= (3, 7)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It also requires Scikit-Learn ≥ 1.0.1:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from packaging import version\n",
"import sklearn\n",
"\n",
"assert version.parse(sklearn.__version__) >= version.parse(\"1.0.1\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Just like in the previous chapter, let's define the default font sizes to make the figures prettier:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"plt.rc('font', size=14)\n",
"plt.rc('axes', labelsize=14, titlesize=14)\n",
"plt.rc('legend', fontsize=14)\n",
"plt.rc('xtick', labelsize=10)\n",
"plt.rc('ytick', labelsize=10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And let's create the `images/classification` folder (if it doesn't already exist), and define the `save_fig()` function, which is used throughout this notebook to save the figures in high-res for the book:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"IMAGES_PATH = Path() / \"images\" / \"classification\"\n",
"IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
"    path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
"    if tight_layout:\n",
"        plt.tight_layout()\n",
"    plt.savefig(path, format=fig_extension, dpi=resolution)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# MNIST"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import fetch_openml\n",
"\n",
"mnist = fetch_openml('mnist_784', as_frame=False)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"**Author**: Yann LeCun, Corinna Cortes, Christopher J.C. Burges \n",
"**Source**: [MNIST Website](http://yann.lecun.com/exdb/mnist/) - Date unknown \n",
"**Please cite**: \n",
"\n",
"The MNIST database of handwritten digits with 784 features, raw data available at: http://yann.lecun.com/exdb/mnist/. It can be split in a training set of the first 60,000 examples, and a test set of 10,000 examples \n",
"\n",
"It is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a fixed-size image. It is a good database for people who want to try learning techniques and pattern recognition methods on real-world data while spending minimal efforts on preprocessing and formatting. The original black and white (bilevel) images from NIST were size normalized to fit in a 20x20 pixel box while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing technique used by the normalization algorithm. the images were centered in a 28x28 image by computing the center of mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field. \n",
"\n",
"With some classification methods (particularly template-based methods, such as SVM and K-nearest neighbors), the error rate improves when the digits are centered by bounding box rather than center of mass. If you do this kind of pre-processing, you should report it in your publications. The MNIST database was constructed from NIST's NIST originally designated SD-3 as their training set and SD-1 as their test set. However, SD-3 is much cleaner and easier to recognize than SD-1. The reason for this can be found on the fact that SD-3 was collected among Census Bureau employees, while SD-1 was collected among high-school students. Drawing sensible conclusions from learning experiments requires that the result be independent of the choice of training set and test among the complete set of samples. Therefore it was necessary to build a new database by mixing NIST's datasets. \n",
"\n",
"The MNIST training set is composed of 30,000 patterns from SD-3 and 30,000 patterns from SD-1. Our test set was composed of 5,000 patterns from SD-3 and 5,000 patterns from SD-1. The 60,000 pattern training set contained examples from approximately 250 writers. We made sure that the sets of writers of the training set and test set were disjoint. SD-1 contains 58,527 digit images written by 500 different writers. In contrast to SD-3, where blocks of data from each writer appeared in sequence, the data in SD-1 is scrambled. Writer identities for SD-1 is available and we used this information to unscramble the writers. We then split SD-1 in two: characters written by the first 250 writers went into our new training set. The remaining 250 writers were placed in our test set. Thus we had two sets with nearly 30,000 examples each. The new training set was completed with enough examples from SD-3, starting at pattern # 0, to make a full set of 60,000 training patterns. Similarly, the new test set was completed with SD-3 examples starting at pattern # 35,000 to make a full set with 60,000 test patterns. Only a subset of 10,000 test images (5,000 from SD-1 and 5,000 from SD-3) is available on this site. The full 60,000 sample training set is available.\n",
"\n",
"Downloaded from openml.org.\n"
]
}
],
"source": [
"# extra code - it's a bit too long\n",
"print(mnist.DESCR)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['data', 'target', 'frame', 'categories', 'feature_names', 'target_names', 'DESCR', 'details', 'url'])"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mnist.keys()  # extra code - we only use data and target in this notebook"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0., 0., 0., ..., 0., 0., 0.],\n",
"       [0., 0., 0., ..., 0., 0., 0.],\n",
"       [0., 0., 0., ..., 0., 0., 0.],\n",
"       ...,\n",
"       [0., 0., 0., ..., 0., 0., 0.],\n",
"       [0., 0., 0., ..., 0., 0., 0.],\n",
"       [0., 0., 0., ..., 0., 0., 0.]])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X, y = mnist.data, mnist.target\n",
"X"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(70000, 784)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X.shape"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(['5', '0', '4', ..., '4', '5', '6'], dtype=object)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(70000,)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y.shape"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"784"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"28 * 28"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAARAAAAEQCAYAAAB4CisVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAHMklEQVR4nO3dO0iWbQDG8ce0o1jWZtEcuHSgcAg6Qk3WGg1Rk0HlokTg0BjUVrZFU9QiObgUCTVEEA5FB8hBiGioRUyooQi/4Zvzfrh6Ms3fb30v3vfuwL8HunltmZubqwASK/72AYClS0CAmIAAMQEBYgICxAQEiLUVXvd/vEDLr17wBALEBASICQgQExAgJiBATECAmIAAMQEBYgICxAQEiAkIEBMQICYgQExAgJiAADEBAWICAsQEBIgJCBATECAmIEBMQICYgAAxAQFiAgLEBASICQgQExAgJiBATECAmIAAMQEBYgICxAQEiAkIEBMQICYgQExAgJiAADEBAWICAsQEBIgJCBATECAmIECs7W8fgD/j58+fxc2XL18W4CT/Gx4errX79u1bcTM5OVnc3Lx5s7gZHBwsbu7du1fcVFVVrVmzpri5dOlScXP58uVan7dYeAIBYgICxAQEiAkIEBMQICYgQExAgJiAADEXyRrw4cOH4ub79+/FzbNnz4qbp0+f1jrTzMxMcTMyMlLrvRabrVu3FjcXLlwobkZHR4ubjo6OWmfavn17cbN///5a77WUeAIBYgICxAQEiAkIEBMQICYgQExAgJiAALGWubm5+V6f98V/3YsXL2rtDh06VNws5Ld/LWWtra3Fze3bt4ub9vb2Jo5Tbd68udZu48aNxc22bdt+9zh/S8uvXvAEAsQEBIgJCBATECAmIEBMQICYgAAxAQFiAgLE3ESdx/T0dK1dT09PcTM1NfW7x/kr6vza6tzCfPz4ca3PW7VqVXHjVu+CcxMVaJ6AADEBAWICAsQEBIgJCBATECAmIEDMz8adx6ZNm2rtrl27VtyMjY0VNzt37ixu+vv7a52pjh07dhQ34+PjxU2drw988+ZNnSNV169fr7VjcfAEAsQEBIgJCBATECAmIEBMQICYgAAxAQFivpFsgczOzhY3HR0dxU1fX1+tz7t161Zxc+fOneLm5MmTtT6Pf5pvJAOaJyBATECAmIAAMQEBYgICxAQEiAkIEPONZAtk/fr1jbzPhg0bGnmfqqp32ezEiRPFzYoV/h1arvzJAzEBAWICAsQEBIgJCBATECAmIEBMQICYgAAxX2m4xHz9+rXWrre3t7h58uRJcfPgwYPi5siRI3WOxNLlKw2B5gkIEBMQICYgQExAgJiAADEBAWICAsRcJPtHTU1NFTe7du0qbjo7O4ubgwcPFje7d+8ubqqqqs6dO1fctLT88l4Tf4aLZEDzBASICQgQExAgJiBATECAmIAAMQEBYi6SLWOjo6PFzZkzZ4qb2dnZJo5TVVVVXblypbg5depUcdPV1dXEcfifi2RA8wQEiAkIEBMQICYgQExAgJiAADEBAWIukjGv169fFzcDAwPFzfj4eBPHqaqqqs6ePVvcDA0NFTdbtmxp4jjLgYtkQPMEBIgJCBATECAmIEBMQICYgAAxAQFiLpLx22ZmZoqbsbGxWu91+vTp4qbwd7aqqqo6fPhwcfPo0aM6R8JFMuBPEBAgJiBATECAmIAAMQEBYgICxAQEiAkIEHMTlUVl9erVxc2PHz+Km5UrVxY3Dx8+LG4OHDhQ3CwDbqICzRMQICYgQExAgJiAADEBAWICAsQEBIi1/e0DsLi9evWquBkZGSluJiYman1enUtidXR3dxc3+/bta+SzljNPIEBMQICYgAAxAQFiAgLEBASICQgQExAg5iLZP2pycrK4uXHjRnFz//794ubTp0+1ztSUtrbyX9uurq7iZsUK/37+Lr+DQExAgJiAADEBAWICAsQEBIgJCBATECDmItkiUu
dC1t27d2u91/DwcHHz/v37Wu+1kPbs2VPcDA0NFTfHjh1r4jgUeAIBYgICxAQEiAkIEBMQICYgQExAgJiAADEXyRrw+fPn4ubt27fFzfnz54ubd+/e1TrTQurp6SluLl68WOu9jh8/Xtz4JrHFw58EEBMQICYgQExAgJiAADEBAWICAsQEBIgJCBBbtjdRp6eni5u+vr5a7/Xy5cviZmpqqtZ7LaS9e/cWNwMDA8XN0aNHi5u1a9fWOhNLiycQICYgQExAgJiAADEBAWICAsQEBIgJCBBbchfJnj9/XtxcvXq1uJmYmChuPn78WOtMC2ndunW1dv39/cVNnZ8x297eXuvzWJ48gQAxAQFiAgLEBASICQgQExAgJiBATECA2JK7SDY6OtrIpknd3d3FTW9vb3HT2tpa3AwODtY6U2dnZ60d/A5PIEBMQICYgAAxAQFiAgLEBASICQgQExAg1jI3Nzff6/O+CCwLLb96wRMIEBMQICYgQExAgJiAADEBAWICAsQEBIgJCBATECAmIEBMQICYgAAxAQFiAgLEBASICQgQExAgJiBATECAmIAAMQEBYgICxAQEiAkIEBMQICYgQKyt8PovfyYmgCcQICYgQExAgJiAADEBAWICAsT+A3RNA9lsM+CIAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"def plot_digit(image_data):\n",
"    image = image_data.reshape(28, 28)\n",
"    plt.imshow(image, cmap=\"binary\")\n",
"    plt.axis(\"off\")\n",
"\n",
"some_digit = X[0]\n",
"plot_digit(some_digit)\n",
"save_fig(\"some_digit_plot\")  # extra code\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'5'"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y[0]"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAgMAAAH3CAYAAAAmMFzFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOy9d3Sj15mn+QAgQBAgCJAgwATmnMnKuRRKkhWtlm253fbYctr1uNe7O9O73j27k86e45nenent2W23j6fdlke25HaSJSsHq1SlqlLlYs45gAQDQOQc9o/a75qsIFVgESgJzzk6pUOA4P3w3e/e977h98oSiQRp0qRJkyZNmk8v8mQPIE2aNGnSpEmTXNLGQJo0adKkSfMpJ20MpEmTJk2aNJ9y0sZAmjRp0qRJ8yknbQykSZMmTZo0n3LSxkCaNGnSpEnzKSfjY17/JNQdyvjkXAd8cq4lfR2pwyfpOuCTcy3p60gdPknXcU3SnoE0adKkSZPmU07aGEiTJk2aNGk+5aSNgTRp0qRJk+ZTTtoYSJMmTZo0aT7lfFwCYZqbIBqNEovFCAQCRCIRQqEQSqUSlUqFWq0mIyMDpVKZ7GGmSZPmJkgkEng8HsLhMMFgELlcjkKhICcnB5VKhUKhSPYQ03yK8fv9+Hw+AORyObm5ucjlN3/OTxsDm8j8/Dw2m43XXnuNyclJTp48SVNTE01NTTz44INUVlZSXV1NRkb6a0+T5m7B4/HwD//wDwwODvLuu+9iNBoxmUz8y3/5L+ns7MRsNt/S4psmzWbw+uuv87vf/Y5EIoHRaOQHP/gBeXl5N/05SdmV4vE4sVhsw89isRg2m41IJEI8HicSiRCJRCgrK0OtVjM1NYXP52NtbQ24bAEVFhaSk5NDaWlpUq3zcDhMIBCgp6eHyclJ+vr6mJubw2q1kpmZSSKRQK/Xs7q6Smlp6SfCGIjH47hcLlwuF5OTk1gsFkpKSlCr1Sl/UnI4HDgcDhYWFvD7/SQSCQwGA/X19Wg0GtRqdbKH+IklHo+L/0KhEC6Xi1gsRiwWo6CgAI1Gg0x23eqnLSccDuPxeBgdHWVsbIz5+Xm8Xi9ut5vV1VU8Hg/5+fkpbwyEQiGi0SiJRIJIJILX62V5eZnl5eWP/b7lcjl5eXkYDAaqqqpS4lql6/D5fKyuruJwOGhoaECv1yd7aFuGz+djbm6OkZERRkdHyc/PR6PRcKudiLd8V5I2er/fv+HnXq+Xl156CbfbTSgUYnV1lbW1Nb7zne9QXFzMj370I4aHhzl+/DgZGRmoVCo+//nP097ezje/+U10Ot1WX4pA2lh+/OMfc+bMGXw+nzB2JicnmZ6e5uzZs5SVlXHkyBE0Gk3SxrpZRCIRBgYGOH/+PH/zN3/DV7/6VZ555hksFkvKX19/fz8ffPABL7zwAuPj48RiMXbv3s1/+A//gaqqKsrKypI9xE8soVCISCRCIBBgaWmJCxcu4PV6CQQCPPnkk1RWVqJSqZI9TIHT6cRqtXLixAkWFhaQyWS4XC58Ph/z8/MsLCxQVlaW8gb+6uoqPp+PaDSKw+FgYGCA1157jddff/1jjfesrCwOHDjA3r17+f73v49SqUy6QRCLxXC5XAwPD/PHP/6R48eP83/9X/8Xu3btSuq4tpK5uTl+8pOfcO7cOYaHh7n//vspKCi45cPYps/gRCKB1+slFottOD2urKzg9/sJhUI4nU4mJyc3WDDhcJjBwUFCoRCxWIxoNArAiRMnMBgMdHV1sbS0hFKpRKfTYTAYKC0tpaSkJGknUcloOXPmDB9++CFTU1MEg0FisRhyuZyMjAwUCgUymYxoNIrX68VqtaJQKMjLy0Mul2/aKSgWi7G4uEg4HCYUCmEymcjPz9+Uz77e37PZbHg8Hsxmc1KNsRvF4/EwMTHB2bNn+eCDD7Db7cRiMTEPb9WiTnN9pHk/MzPD7Owsi4uLeDwelpeXcbvdLCwsEA
6HiUajaLVa6uvrOXjwYMp4ZzIzM9HpdJSWlgovgUQqz5dwOIzL5WJxcZHFxUX6+/ux2+0i/2FmZoaRkRESicRVXtorCYVCjI+Po1KpePXVV2loaKC2thaVSpUUL04ikSAUCjE1NcXPfvYz5ubmsNvtuN1u/H4/WVlZKeVd2mxCoRDnzp2jv7+f8+fPY7VaycjIoLi4mPLy8ls2TDfdGIjFYjidTuHmn5ubY25ujsHBQex2O16vl8XFRT788MOPfJj0ej0Gg4EPP/wQlUpFf38/kUgEtVqN0WikqKiI8vLypLrdA4EA8/PznDx5kt/+9rfY7XZCoRBw2bUmJQ0qFApcLheBQIC5uTmxwCiVyk0zZGKx2AYXpkwm2xJjwOv1UlxcTHZ29h37W5uF2+2mp6eHc+fOcfLkSaLR6IZFI5UWkCsNlPXPikwmu+ZYU2n88KdF2263093dzYcffsjw8DDLy8tMTU0RiUTEe+VyORqNhqWlJXbu3JkyxoBarUan02GxWPB6vUxNTSV7SDdEOBxmaWmJnp4eurq6OHXqFDabTYRnJMNAmjMfNXei0SgTExMEg0H0ej3xeJyysjIUCkVS1t5EIkE4HGZqaornnnsOrVZLTk4Obrcbn8+HWq1OuWdhMwkGg5w4cYLe3l7Onz+PTCZDpVJRVFQk7sutsKl3cnFxkYWFBf7mb/4Gu91OMBgkEAgQDAbxer1EIhGi0SihUOgjDQGZTEZnZycHDx7EYDCQkZHBzp07USgUGAwGcnJy0Ol0tLW1kZ+fv6UZ+pFIhHA4zPz8PAMDA/z85z9nfHwcu90uvBlweREpKiqira2NxsZGfvnLXzI7O8v/8X/8H9TV1fHYY4/R0dFBS0vLpo1raGiIlZUVlpaW0Ol0NDQ0bMpnX0kikSAajYq/tbKygs/nI5FIpPRpKRQKiROpFD9NNaLRKDMzMywuLtLb24vNZmNpaYnl5WUCgQAAZWVlPPTQQ2i1WrKyspieniYSibBz506MRiMWiyWp1yDFcI8fP87MzAznz59neXkZh8MhvIZ6vR6FQkFmZiZOpxO3201fXx8ej4fPfe5zWCyWO2rM3ihKpZKsrCxMJhM2mw2ZTJaS8+ZK5ufnee655xgYGGBoaAin0ykOKrc6/tXVVY4ePUpxcTENDQ1UVlamhEcwEAgQjUZZW1vD6XTecjb93UI0GmV8fJyZmRlisRiFhYUUFxeze/dudu7cSWZm5i197qYaA1JC0KVLl4Q78CP/+P9/apYIh8MAKBQKCgsLaW9vR6VSCetTpVJhMBhQq9Vis83KytrMS/hY3G43brebsbExBgcHuXTpkshzWI9MJkOpVGI2m2loaCA3N5fZ2Vm6u7vx+XzU19dTWlq6aeOKRqMsLi6yvLyM3W4XpSZ3Ainvw26343K5xAk7IyMjZR/CcDiMz+fbYLgoFAqUSiV6vZ78/Pykl4mFQiH8fj9TU1NiE52dncVqtbK0tEQwGCSRSFBZWUlhYSE6nQ6tVsvIyAiRSISCggISiURSjIFEIkEwGCQYDDI7OyvyASYnJzl9+rQ4CCgUClQqFZWVlajVarFwud1uHA4HmZmZwt2bCkheGIVCkfKJsesJBoPMzc0xOzvL7Oys+Ll0HTk5OQAbvAPwp2d7fahWIhQKsbCwIDy8Hxde2Cqi0ajweEQikbvCWJMIh8MihyYrK4usrKyPXENDoRA+nw+73Y7H4xGlhOXl5RQXF2MymW55LJtqDOTm5mI2mzEYDLhcro81BvLz8ykqKkKr1RKPx+nt7SUSiZCRkUFjYyMPP/wwwAZXlvRFSQ/oVhKPx/n1r3/NxYsXOXr0qEgkisfjV73X5/MxOjpKfX09gUCAyspK4vE4fX19+P1+Jicn6ezs3LSxBYNB3nvvPQKBAFVVVZv2uddCivW+8cYbJBIJdu/eTXV1NeXl5Sm5YIbDYbq7uzl37hy//e1vcTqdAJjNZkpKSvh3/+7fUVpaisViSZp7Op
FI0NXVxdjYGD/60Y+w2Ww4HA5RXVNbW0t2djYul4tgMMgPf/hD8TxEIhFxwt69ezft7e1b6iaNxWL4/X6OHz/OiRM
"text/plain": [
"<Figure size 648x648 with 100 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code - this cell generates and saves Figure 3-2\n",
"plt.figure(figsize=(9, 9))\n",
"for idx, image_data in enumerate(X[:100]):\n",
"    plt.subplot(10, 10, idx + 1)\n",
"    plot_digit(image_data)\n",
"plt.subplots_adjust(wspace=0, hspace=0)\n",
"save_fig(\"more_digits_plot\", tight_layout=False)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training a Binary Classifier"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"y_train_5 = (y_train == '5') # True for all 5s, False for all other digits\n",
"y_test_5 = (y_test == '5')"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"SGDClassifier(random_state=42)"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.linear_model import SGDClassifier\n",
"\n",
"sgd_clf = SGDClassifier(random_state=42)\n",
"sgd_clf.fit(X_train, y_train_5)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ True])"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sgd_clf.predict([some_digit])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Performance Measures"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Measuring Accuracy Using Cross-Validation"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.95035, 0.96035, 0.9604 ])"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import cross_val_score\n",
"\n",
"cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring=\"accuracy\")"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.95035\n",
"0.96035\n",
"0.9604\n"
]
}
],
"source": [
"from sklearn.model_selection import StratifiedKFold\n",
"from sklearn.base import clone\n",
"\n",
"skfolds = StratifiedKFold(n_splits=3)  # add shuffle=True if the dataset is not\n",
"                                       # already shuffled\n",
"for train_index, test_index in skfolds.split(X_train, y_train_5):\n",
"    clone_clf = clone(sgd_clf)\n",
"    X_train_folds = X_train[train_index]\n",
"    y_train_folds = y_train_5[train_index]\n",
"    X_test_fold = X_train[test_index]\n",
"    y_test_fold = y_train_5[test_index]\n",
"\n",
"    clone_clf.fit(X_train_folds, y_train_folds)\n",
"    y_pred = clone_clf.predict(X_test_fold)\n",
"    n_correct = sum(y_pred == y_test_fold)\n",
"    print(n_correct / len(y_pred))"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"False\n"
]
}
],
"source": [
"from sklearn.dummy import DummyClassifier\n",
"\n",
"dummy_clf = DummyClassifier()\n",
"dummy_clf.fit(X_train, y_train_5)\n",
"print(any(dummy_clf.predict(X_train)))"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.90965, 0.90965, 0.90965])"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cross_val_score(dummy_clf, X_train, y_train_5, cv=3, scoring=\"accuracy\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Confusion Matrix"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import cross_val_predict\n",
"\n",
"y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[53892,   687],\n",
"       [ 1891,  3530]])"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import confusion_matrix\n",
"\n",
"cm = confusion_matrix(y_train_5, y_train_pred)\n",
"cm"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[54579,     0],\n",
"       [    0,  5421]])"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_train_perfect_predictions = y_train_5 # pretend we reached perfection\n",
"confusion_matrix(y_train_5, y_train_perfect_predictions)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Precision and Recall"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8370879772350012"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import precision_score, recall_score\n",
"\n",
"precision_score(y_train_5, y_train_pred) # == 3530 / (687 + 3530)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8370879772350012"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code - this cell also computes the precision: TP / (FP + TP)\n",
"cm[1, 1] / (cm[0, 1] + cm[1, 1])"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.6511713705958311"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"recall_score(y_train_5, y_train_pred) # == 3530 / (1891 + 3530)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.6511713705958311"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code - this cell also computes the recall: TP / (FN + TP)\n",
"cm[1, 1] / (cm[1, 0] + cm[1, 1])"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.7325171197343846"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import f1_score\n",
"\n",
"f1_score(y_train_5, y_train_pred)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.7325171197343847"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code - this cell also computes the f1 score\n",
"cm[1, 1] / (cm[1, 1] + (cm[1, 0] + cm[0, 1]) / 2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Precision/Recall Trade-off"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([2164.22030239])"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_scores = sgd_clf.decision_function([some_digit])\n",
"y_scores"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"threshold = 0\n",
"y_some_digit_pred = (y_scores > threshold)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ True])"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_some_digit_pred"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ True])"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code - just shows that y_scores > 0 produces the same result as\n",
"#              calling predict()\n",
"y_scores > 0"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([False])"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"threshold = 3000\n",
"y_some_digit_pred = (y_scores > threshold)\n",
"y_some_digit_pred"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3,\n",
"                             method=\"decision_function\")"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import precision_recall_curve\n",
"\n",
"precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjAAAAEQCAYAAACutU7EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAABZo0lEQVR4nO3deZxN5R/A8c8zq1mM3dh3xk72qOyyJCqi7AopUVFp+VlCKSQqhKhIVEIKhRolspN9G/u+zzD7zPP745nd7HNnztyZ79vrvu655zznnO+9x535znOeRWmtEUIIIYSwJw5WByCEEEIIkVaSwAghhBDC7kgCI4QQQgi7IwmMEEIIIeyOJDBCCCGEsDuSwAghhBDC7qSYwCilFiilriqlDiSxXSmlZiqlTiil/lNK1bN9mEIIIYQQsVJTA/MV0D6Z7R2AylGPwcDsjIclhBBCCJG0FBMYrfVfwM1kinQBvtHGv0B+pVRxWwUohBBCCJGQLdrAlATOxXl9PmqdEEIIIUSmcLLBMVQi6xKdn0ApNRhzmwncqE9+G5w9N1Kgoj52hcJBOcR7jl6fWHkABxxQSiV6jJhdVJzlONuVUmZ/c9D4+0T/iyqDItmyWS0yMhIHB2m3np2l9xrdu3cPAA8PD1uHJOKQ71DmCQlx4No1VwIDE/+1XLbsPVxdIwG4fDkP/v7OiZZzdY2gbNnAmNfHjuVN8pze3sHkyxcGwO3bzly9mifJslWqBMQsnz7tQWho4v8P8uULw9s7GIDgYEfOnnVP8phJvafKle+ilI6K/9h1rXWRxPa3RQJzHigd53Up4GJiBbXWc4G5AOVrlNfjl4xHa42Oyneil6PnZ4peTmpdwn2SOk5weDDhkeFJ7pPc8VKzTWvNybMnKVqsKBE6gkgdSURkRMxyeGQ4ASEBBIYFcjv4NmGRYUTqyGQfoRGhBIcHx1sX77Mk9vOIJP42e6JQuDm74eHsgYujC27Obni6eOLk4BTzcHV0xcPFA0fliKODIw7KIWY5j2MevFy9yOOUh8LuhSmetzjeHt64Obvh4uiCh7MH+fPkp4BbAbb8vYUWLVpY/ZZFMnx9feUaZWNyfTKP1hAUlPT2PHkgOncMCYGIiMTLbd78F+3aPZKqY7q4gFNUFhAWZh5JcY+ThwQFmWMnxsnJHBdMjCEhSR8zqffk5gbRf0Mrpc4ktb8tEpifgWFKqaVAY+CO1vpSSjsVcitE3zp9bXD67CErvtha6/sSnZCIEPxD/PEP8ede6L0UE6MIHcHd0LuERYQRHhlOWGQYgWGBMclSRGRETLno1yERIdwLvUdIRAiBYYHcC7sXL0GLuxwYFkhQWBBhkWGERoQSEBJAeGR4vGNG7xOV/hEYFkhgWGDKH0AGOCgHirgUoeDBgjg7OuPk4ISLowtFPYqSzzUfhd0LU9CtIEXci+Dt6R2zvqBbQbw9vXFQ8ldndnbunLmLXbp06RRKCpE9KRU/SUiOq2vS21xcYv+gTcsxnZ3NIzXc3FJXztHRNu8pKSkmMEqp74AWQGGl1HlgLOAMoLWeA6wBOgIngEBgQNrDEKmhlDI1DzjGrHNzdiN/nvzWBZUB0QlZQGgAoRGhMQlPdI1ZdIIVnTwlljTdDb1LUFgQd0PvcunuJc77n8c/xJ/QiFBCIkK4E3yHOyF3uBl0kyshV7gSciVdsRZyK0QBtwIUdCuIm5MbBd0KUj5/eeoUq0OJvCUom68s5QuUx8nBFn8TiLTq06cPYP6QEMLeLF8Oa9ZA9+7QPrk+vyKeFH/aaq2fSWG7Bl6yWUQi14hOyLIiAQsMC2TFhhXUrV83JjG6E2wSm4DQAG4E3uBa4DVuBN7gyr0rXLl3hYCQAC7dvYR/iD83gm5wI+hGsueIrtFxd3anqEdRyucvT/Ui1SnsXpjKBSvzSNlH4rUtErbz7rvvWh2CEOn211+wYAFUry4JTFrIn4siV3B3dqekW0lqFK2R5n3DI8
O5HnidG4E38A/xJyg8iCt3r7Dvyj7O3jnLhYAL+N3y47z/ec77nwfg2I1jbD67Od5x8ufJT+cqnfEp5MMDxR+gednmeLhIo1NbaNOmjdUhZGvBwaZdQsL2t/v3w/jxUKYMNGhgHpUrx7Y/EFnj7FnzXKaMtXHYG0lghEiBk4MTxTyLUcyzWLz1z9SKXznpH+LPneA7+If4c/XeVY5cP8Lxm8e5GHCR9X7ruRl0k0X/LYop7+7sTs8aPXm00qM0K92Mkl4y+kB6+fn5AVChQgWLI7Gdy5dhxQooUQKaNYPChVPeJzLSJCmRkSYJmTUL1q6FX3+F4sXhqaegRQtzvGLFoF07c564KlWC4cPh5Zfjr0+ugafIGElg0kcSGCFsxMvVCy9XLwBqUIOW5VvGbNNas/vSbv48/SeHrh1i6/mtHLl+hAV7F7Bg7wIAfAr50LlKZ8oXKE/lgpVpWb6ltKlJpYEDBwL21QZm+XL44gtYvz52Xb58sH07VKkCAwbAunXx9ylRAnr1gvffh9On4fXXoUgR8PGBUaNiy/3zDxQtChMnxiYoly7BZ5+Zx0cfmX0bNIBffjFJjZcXbNkCJ07AsWOxxzp92tzaCAqCxo1r8fHH8OCDUktjS5LApI/8dBQiCyilqF+iPvVL1I9Zt+fSHlYfW80Gvw3svLiTozeOcnTr0ZjtLo4uPFH1CbpX707bim1jkiNxv/Hjx1sdQoo2b4bdu03Nxpkz0K3b/WXu3Int4fHoo/cnMBcvwp49MG+eqU3x9jZJUEJLl8KMGbFJRsWKMG2aOf+//0KTJmb9q69C//7w5JOmbEQEbNoEhQrFHuvPP2O74m7bVohmzcwv2ubNTbsNJ/ktkiGBgXD9uukB5O1tdTT2RemkOnNnsgYNGuidO3dacu7MIOMjZH/Z+RqFRoSywW8D2y9s54L/Bf44/Qd+t/xitns4e9CrVi/GtRhH8bw5d6aO7HyN0ktr+Pln6NrVvI4e7yK6e2mTJuaWTc2asHcvRHWoiiciAjZuhG3bTHJRp465DbRtm6k12bkTfvwRRo6Exx83NTW2+mv+8GGYOxcOHgQ3twts21aSK1egQgVTWxOdJJ05A6VL39/ORiTv6FGoWtV8nidPZvx4Oe07pJTapbVukNg2yZ2FyAZcHF3oWLkjHSt3jFl3+Nphlh5YytoTa9lxcQdzd89l/p759K7dm1ebvEod7zrSqynK0aOm5srHx8fiSGItX25qQn78Mf766IG+EvvbsVatxI/l6Gjaq7RrF39948bmAfDttxmPOTHVqsH06WbZ1/c4jzxSkjVrTK1M9H+/U6dMLU/VqvDGG/Dss7GDmYnkRUaa61o85/5dkmkkVxYim6pWpBrjW45n+6Dt7HthH63LtyZSR/LNvm944IsHqDizIgv3LCQiMokhOXORIUOGMGTIEMvOf+aM+WWuFIwZYxq8jhx5f/Kydq3911A4OMBjj5kxS6L9849pZHz4sGm7U748TJ2a/CiswqhWDX77Db76yupI7I+df5WEyB1qe9dmfZ/17Bi0gxfqv0Aht0Kcun2KgT8PpMpnVViwZwFW3Q7ODt5//33ef//9LD3nxYvQubNJWsqVi13/+++mPcOi2A5nTJ1qalxy6hgfvXvD+fPml3CNGuazef1187lMnGje++rV5laXELYiCYwQdkIpRYMSDZj92Gwuj7rM3MfmUiZfGfxu+fHcz8/xwBcP8NE/H3HuzrmUD5bDNG3alKZNm9r8uDdvmka3np7QoQMUKGAax4Jp7PrLL/fvs2yZeX74YfOLW2tTG5PTubhAv35mbJkVK6BUKdMDKjAQxo0zbXMaNoQPP0x6Hp/c6OJF8Pe3Ogr7JAmMEHbIycGJQfUHceLlEyzsspBinsXYd2Ufb254kzKflOGtDW8RHB5sdZhZ5sCBAxw4cCDDx4lOOAAGDjQNZj/7DO7dMz2Cbt+GKVNgw4b4tS7Hj5u2DFpD2bIZDsOuKWUaLJ85Yx
oYDxwIr70Gbdua7aNHm7Yy330niQxA376m+/xvv1kdif2RBEYIO+bs6Ez/uv05OuwoX3f9mmalmwEw+Z/JVP+8OvN
"text/plain": [
"<Figure size 576x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(8, 4))  # extra code: not needed, just formatting\n",
"plt.plot(thresholds, precisions[:-1], \"b--\", label=\"Precision\", linewidth=2)\n",
"plt.plot(thresholds, recalls[:-1], \"g-\", label=\"Recall\", linewidth=2)\n",
"plt.vlines(threshold, 0, 1.0, \"k\", \"dotted\", label=\"threshold\")\n",
"\n",
"# extra code: this section just beautifies and saves Figure 3-5\n",
"idx = (thresholds >= threshold).argmax() # first index ≥ threshold\n",
"plt.plot(thresholds[idx], precisions[idx], \"bo\")\n",
"plt.plot(thresholds[idx], recalls[idx], \"go\")\n",
"plt.axis([-50000, 50000, 0, 1])\n",
"plt.grid()\n",
"plt.xlabel(\"Threshold\")\n",
"plt.legend(loc=\"center right\")\n",
"save_fig(\"precision_recall_vs_threshold_plot\")\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAFYCAYAAAAV9ygtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAABce0lEQVR4nO3dd3hURdvA4d8kpIceQiAJCRCKQenSS6RJMfAKUqQIKCAKCL6g0qQIKCgoRWn6UgQV9LOBCEoLiCCgFJUSeu+9hPT5/thkSYVNyO7ZbJ77uvbylNkzz45knz3nzJlRWmuEEEIIW3MyOgAhhBB5kyQgIYQQhpAEJIQQwhCSgIQQQhhCEpAQQghDSAISQghhCJslIKXUAqXUJaXUv5nsV0qpmUqpI0qpv5VS1W0VmxBCCNuz5RnQIqDlA/a3AsolvfoBc2wQkxBCCIPYLAFprTcD1x5QpB3wuTb5AyiklCphm+iEEELYmj3dA/IHTqdYP5O0TQghhAPKZ3QAKagMtmU4TpBSqh+my3Q4eRSoUbRYcQq7Z/T2vC0xMREnJ3v6jWEfpF0yJ22TMWmXzB06dOiK1rpYdt5rTwnoDBCYYj0AOJdRQa31fGA+gFuJcnrYvBW82bKi9SPMZSIiIggLCzM6DLsj7ZI5aZuMSbtkTil1MrvvtaeUvgJ4Iak3XB3gptb6vNFBCSGEsA6bnQEppb4CwgAfpdQZYCzgAqC1ngv8DLQGjgBRQG9bxSaEEML2bJaAtNbPP2S/BgbYKBwhhBAGs6dLcEIIIfIQSUBCCCEMIQkojbsx8SzbcYqhX+/l7I17RocjhBAOy566YRvq/M17zNt0jG//OsPtmHgArt2NYWHvWiQkarYfu8qKvec4cukOH3SsQmkfL25ExeLu4oy7i3OOxqK1Ril5rkkI4djybAJKTNT89M95Cnm4EBF5maXbTxIbn5iqzMbIywQPX4Vvfjcu3Y4xb39qakSqcqElCvBig9K45XMivEpJAGLjE3FxVkTHJeLsZEom9+ISiI5LICFRs+vUdXafusHvR65w7PJdfLxduRYVS4mCHhy/cheAfo3K0Lt+MHei49l37hYx8QkUL+BO4/LFJEEJIXK9PJmALt2Kpu+Sv9h7+kaq7W2eKMGAp0JISNSEf7zlfvnbMZQq4klUbDxX7sSmO97+87cY9s1eAGZtOIyXWz52n7qRrtyDnLsZDWBOPgDzNx9j/uZjGZZ3d3EiOi6Rj7tWo7CnKx6uzni6OhNY2BMvtzz5v1UIkcvkuW+qgxdu0XPBDi7eun9GUyOoMOPbVuJx/4Lmbb8Pb8Jzc7ZSu3QRutcJokZQYeISNH+dvM7u09cJKOxJ2WJetJm5JdXxD12889AYPF2dqRlchOqlCnHqahSP+xfk6OU7VAksRGiJAuw5fYPRP9yftcK/kEe6+1HRcaaztYFf7k53/O9erUdsfCKJOsORjIQQwi7kqQT095kbdPtsO7ej43kyuDDVSxWmrK83z1UPwMkp9SUt/0IebBvRNNU213yKumWLUrdsUfO2E5PboLXm7zM3GbdyH+V8vWlSsThFvFxxd3EiIVFTyNOV29FxlPPNz+3oOHy83dLVl9Lj/gXpXicIMF3Kc81n6iuitebnfy4AMHbFvxmejQG0n73VFK8zHHoqi40khBA2kmcS0Mmrd3lx0U5uR8fzdKXizOhSLcc6DyilqBJYiO9frf/Qsh6uWaszOfkk19OmsmmGiuT/ptT10z/YevSqeT02AYKHryKwiAdTn6tC7TJF071HCCGMkicSUHRcAv0+/4srd2JpWM6Hj7tWx8XZ8Xqgf9m3jnk5ePgq8/Lpa/foPP8PAMr4eHE7Jp7l/epQppi3zWMUQohkDp+APlx7iJnrDwNQ2seLOd1rOGTySev4e615c+FabrsUZc2+C+btx5I6OTSZtokQX2+OXLpDheL5uRUdx42oOO7FJQDw06AG+BfywMM157uZCy
EEOHgC2n7sqjn5AMzoUhXvPNJDTClFmzKuhIXVIDFR89+v9+DkpPhu11lzmSOXTB0mIi/eTvf+Z2ZtSbetbpmiLHrxSdzySUISQjw6h/42/mjdIfNy19qlqBxQyLhgDOTkpJjepRoAH3aqypFLd/h1/wUu3IwmLkGTmKhpWN4HJ6VYsOU4f568nuFxth27SoXRawCoH1KU34+Y7je5uzgxrEUFXqxfmuTHk87fjE7qiOFMXEIit+7FUcDDJU+cfQohLOOwCeifMzf549g1AHy83XhLJqwzC/H1JsQ3JMN9rZ8wdW5ITNTEJiSy7sBFDl24zcwNR1KVS04+YOoSPnHVASauOmBR/SULuvP0436UL56f52oESFISIo9y2AT05Y5TAPSuH8yo1o+RT77kssTJSeHu5MwzlUtCZfhviwpcvh3Dfz75nedrBVK8gDsRkZe5eCs60zOmzJy7Gc3C308AMOK7f9LtX/t6Iw5cuE3VgEKUKuqZEx9HCGGHHDIBRcclsHKvaTbvrrVKSfLJIcXyu/H78Cbm9Y4178+gfvFWNFfuxFDUyw23fE44OysOXbiNu4szZYp54emaj/M371H3vQ04OykSEjN/SLb5R5vTbXu2mj9ebs4U8XKjX6MyeeZenhCOzCH/in8/coU7MfFUKlmAcsXzGx1OnlC8gDvFC7in2lYzuEiq9RIFPTgxuY15PT4hkX/P3SIuIZHb0XG8uOhPAHy8XdM9ZPv97vudJ1J2LHklrCynrkZx9PIdAot40upxP9pWKSk/OoTIBRwyAf2S1O346Up+BkciHiSfsxNVAwuZ11MmJ601529GM+K7f9h06DL5nBTxGZw1zYk4al4+eOE2a/df5L9f7zVvc3dxYslLtfHN70apIp4yiKsQdsThEpDWmk2HLgPQPLS4wdGI7FJKUbKQB4tfrJVu39U7MdSbvIGY+ESahxZn7f6LmR4nOi6RjnO3pdvu/8cGapcuQhEvV4rld+P52qUA09BHPt5uOfdBhBCZcrgEdPJqFBdvxVDUy5WKfnL5zREV9XYjcmKrDPdtjLzE9mPX+O3wZfadu5XpMc7euMd3KS7rvbf6YKr9BdzzUTO4CAGFPYhL0LR83A/f/G48VqJAznwIIYTjJaDtx03dg2uVLiKXW/Kgpyr48lQFX4a3ut/tPnmCv61HrhB58TZTft5PwwrF0Vqz7sClDI9zKzqeDQfv7/sqqVdlSo3KF6N/ozLUC/HJ+Q8iRB7gcAnozxOmLsFPprkBLvKu5B8i9UJ8qBfiQ+m4k4SF1TTvTzkD7d2YeD777Thbj16hgl9+Pt92MtPjbj50mc1Jl3sBGoT4MKJ1RUJLFJAfP0JYwOES0P7zpssulQMKPqSkECYpk4WXWz4GNyvH4GblAHin3ePmfYmJmtvR8Sz54wR/nbzOxsjLqY6z5ciVVPNDuTo7Mezp8nSrHSSTBAqRAYf6q4iNT+Rw0oRwFeVavchhTk6Kgp4uDGxSzrwtMVHz6/4LrPrngvnZs2SxCYm8+/NB3v35IGWLefH+c5WpESRn5kIkc6gEdPTyHWITEgku6ikPKgqbcHJStHy8BC0fL8HMLlXR2tTBYfPhy4z6/v6stkcv36XDHFNvvKfdDvPL4umcOnWKUqVKMWnSJLp162bURxDCMA71LX0g6fKb9FQSRlBKoRQEFvGkW+0gutUO4mZUHHUnrycq1jTNxZ19G/l0zcfoeNOU8CdPnqRfv34AkoREnpPrHxePu3qGP9d+D8DRCze48OVwbvy9AYCoqCjCwsJYvnw5ADdv3iQsLIzvvvsOgCtXrhAWFsbKlSsBuHDhAmFhYaxZYxrx+fTp04SFhbFu3ToAjh07RlhYGJs2bQIgMjKSsLAwtm41TYH977//EhYWxs6dOwHYs2cPYWFh7NmzB4CdO3cSFhbGv/+afhlv3bqVsLAwIiMjAdi0aRNhYWEcO3YMgHXr1hEWFsbp06cBWLNmDWFhYVy4YHrQdu
XKlYSFhXHlyhUAvvvuO8LCwrh58yYAGzZsICwsjKioKACWLl1KWFgYcXFxACxatIiwsDBzW3766ac0a9bMvD579mx
"text/plain": [
"<Figure size 432x360 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.patches as patches  # extra code: for the curved arrow\n",
"\n",
"plt.figure(figsize=(6, 5))  # extra code: not needed, just formatting\n",
"\n",
"plt.plot(recalls, precisions, linewidth=2, label=\"Precision/Recall curve\")\n",
"\n",
"# extra code: just beautifies and saves Figure 3-6\n",
"plt.plot([recalls[idx], recalls[idx]], [0., precisions[idx]], \"k:\")\n",
"plt.plot([0.0, recalls[idx]], [precisions[idx], precisions[idx]], \"k:\")\n",
"plt.plot([recalls[idx]], [precisions[idx]], \"ko\",\n",
" label=\"Point at threshold 3,000\")\n",
"plt.gca().add_patch(patches.FancyArrowPatch(\n",
" (0.79, 0.60), (0.61, 0.78),\n",
" connectionstyle=\"arc3,rad=.2\",\n",
" arrowstyle=\"Simple, tail_width=1.5, head_width=8, head_length=10\",\n",
" color=\"#444444\"))\n",
"plt.text(0.56, 0.62, \"Higher\\nthreshold\", color=\"#333333\")\n",
"plt.xlabel(\"Recall\")\n",
"plt.ylabel(\"Precision\")\n",
"plt.axis([0, 1, 0, 1])\n",
"plt.grid()\n",
"plt.legend(loc=\"lower left\")\n",
"save_fig(\"precision_vs_recall_plot\")\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3370.0194991439557"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"idx_for_90_precision = (precisions >= 0.90).argmax()\n",
"threshold_for_90_precision = thresholds[idx_for_90_precision]\n",
"threshold_for_90_precision"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"y_train_pred_90 = (y_scores >= threshold_for_90_precision)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9000345901072293"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"precision_score(y_train_5, y_train_pred_90)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.4799852425751706"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"recall_at_90_precision = recall_score(y_train_5, y_train_pred_90)\n",
"recall_at_90_precision"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The ROC Curve"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import roc_curve\n",
"\n",
"fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAFYCAYAAAAV9ygtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAABwY0lEQVR4nO3dd3gU1frA8e/ZTe8JJBBa6L13lC4CKiiKIiiCinBtVxFBrIAFRNH7Uy8ochEBBVREBRXBgpEiRUBQekvoEBJaQkjbfX9/zGZNQspussmmnM/z7JPdmdmZd0+SefecOXOOEhE0TdM0raSZ3B2ApmmaVjHpBKRpmqa5hU5AmqZpmlvoBKRpmqa5hU5AmqZpmlvoBKRpmqa5RYklIKXUPKVUnFJqVx7rlVLqPaXUIaXUX0qptiUVm6ZpmlbySrIGNB/on8/6m4AGtscY4IMSiEnTNE1zkxJLQCKyFjifzya3AQvFsAkIUUpFlkx0mqZpWkkrTdeAqgPHs7w+YVumaZqmlUMe7g4gC5XLslzHCVJKjcFopsPHx6ddrVq1ijOuMstqtWIylabvGKWDLpe86bLJnbPlIlmeZD632p6L7aclcxi0LNtk/Zlhsf20glLXbpeSIXiasr/PKmApydHVBNLOHooXkfDCvL00JaATQM0sr2sAp3LbUETmAHMAGjVqJPv37y/+6Mqg6Ohoevbs6e4wSh1dLnlzV9lYrUKGVbCKYLEKFhGs1qzP4Wq6hatpFjKsVjKsQoZFyLBYOXUpBW8PE1YRRMj1p9X++p/nB84mER7gRbpVSM8w9vn3yUtUCfIGMI5tBRHhXHwCwaGhxr6scOJiMlfTLAT5epKabuXkxauYlJEAHGUuYL1X4YvTrm64P94eZo4mXKFF9WD8vMyYTSbMJvAwmTCbFGaT4tj5ZFrVCMFsApNJ4WFSmJWyP7dYIcjXg7rhAVw6fZS3Xnmet2fNoU6NSMIDfY4WNr7SlIBWAI8rpT4DOgGXROS0m2PStHLHahWS0jJIy7CSbrGSmm4lPimVDKvw17kMYjfE4OflYZzkrVYyLEYi2HvmMmF+XiSnW/gj5jw1w/xItxjrM6xWEpLSOH4hmeohvllO+JkJwDjxZ00AVquxPjE1w91F4pj4+GsXJaXZn2cmH6Wwn7zNStlP8qkZFlLSrTSuGoi3h8l24r9Ku6gQ2zYmzCpLAjAp4pPSaBoZRLCvJ5UCvPAwmfAw/7MeoHKAt22ZCU+zwsNsomqQj329q206f5iYA3sxJ5+nckDtIu2rxBKQUmoJ0BOorJQ6AUwGPAFEZDawErgZOAQkAw+UVGya5g4p6RZS0i2kZlg5l5hKusVKWoaVNIuVfacT8fIw2b7pG9/Oz1xK4UpaBj6eZlLTrfx5/AKRwT7Zv+Vn1iLsJ3vjW/yZS1e5kJxOsK8nl66m5x/Ytj0OxX8wLinX5bEJyc4WhZ2Xh8l+0jYp7CdvkzIeZy6n0Lx6EGaTCU+TwsNsrD+akEybWqEowKTApBRKKftzk4nsr5Vxcj6XmErTakHGidt2Ar+ckkHtSv7GtrYksmvX37Ru1dKIx7ZvESEiyBtvDzPeniaCfT3xMptQqnhO/O4kIuzcuZPWrVvTuXNnDh06hLe3d5H3W2IJSESGFbBegMdKKBxNy1VqhoWjCclcuppuTwYXrqSRmJKB2aTsySDDKhyOS8Lf2/gXstiWWWzNQ3/EnqdWmB+/H06gWrCvbf0/TUdJLvrWf+TcFae2z5p8wvy98DQrvDxMJCSlUTXYBz9rCqFhoSSlZtC4aiBmk3Fi9jApzGbF+aQ0WtYIxsvDuPhQJcgHT7Ox3sOsAEWInycetqShspzwTeraJKBMxnMPk8LHs6BGKffxiNtLz0YR7g7DbWbPns3jjz/Oli1baNeunUuSD5SuJjhNc4rFKqRlWEnNsHAhOd3eHJR5srdYhfikVNItwu5Tl/H2MLHzxEWuXkphyfGtnL
2cyulLV0lJN2oeV9MtLo3v+PmrAJy8eDXPbbw8TIT4epKaYcTbJDIQLw8TCuPkXT8iwHZyN77xJ6VaiAjyJszfC28PE6kZVqoF+2Kyncgzaw6ZJ3izKXsSqBzghb+3B94euX9TN64BdXJpOWhl3/Dhw0lPT6dNmzYu3a9OQFqpkW6x8vfJS/y856ytF5CVdIuQbrGy7egFqof4siX2PL6eZuISU4t2sDNnC9zkunqV8DSb8DSbSLiSSmSwD5UDvDGblP1b/9nLqfakYba33RsXeRWKmmF++HmZCfb1tLfTe5gVniYT3h4mTMXUTq9pRbVz507eeecd5syZQ2BgIE888YTLj6ETkFasRIRTl1JITs0gLjGVuMQU9p5OxMfDxJ7TiRyKS3T4msG+M4kAJKb803zl7WGcyK+mW8iwCo2qBNquC5jsF2pPXrjK9fUrcelqOm1rhbLnwGFu6tICb08z1UN8iQj0ticaT7Mql234muasHTt28PPPP3P8+HHq1q1bLMfQCUhzCatVuHQ13ajB7D3LbwfOYbEKJy7k3fyUm0BvD4J8Pbm+fiXqhQfgYTbhZUsoglC7kj++XmZqhvpROcCrUMkiWo7Ts7keZEPTchIRjhw5Qr169Rg5ciR33HEHgYGBxXY8nYA0h2TYmsd+3ReHRYTtRy/i7+3B5piEbDWSvFQO8KZaiA9VgoxmrAtX0mgXFQpAgyoBtKwRQqifp659aJobvfDCC8yePZvdu3cTGRlZrMkHdAKq8DIsVuISU1l/MJ6j569gscK2o+eJCPIhNd3C5pjzDiWYTCF+nrSoHkyzasG0rhlC8+pBVA/x1YlF08qA0aNHU6VKFapWrVoix9MJqIJISbew70wi249eYNvRC2yOSeBicjoZTty6bTYpwvy9iArz46YWkVitQqOqgQT7etKgSgB+XvrPSdPKmo0bN7J69WqmTJlCnTp1ePLJJ0vs2PqMUU4dOZfEZ/vS+PToH2yJOc/lAmox9cL9qRceQM0wP6qH+JJmsVIvPABvDxOBPh40rRaEt0fpvU9D07TC+frrr/nqq68YO3YsISEhJXpsnYDKicSUdGLir/BB9GF+2HUmy5q4bNs1iAigUdVA6kcE0LF2GK1rheiai6ZVMBaLhfj4eKpUqcK0adN47rnnSjz5gE5AZVJKuoVdJy/x8944jp9PZs2+uFxvogz3VdzQvAbX16/MDU0idKLRNA2A+++/n23btrF9+3Z8fHwIDQ11Sxz6jFRG/H44nuk/7OOvE5fy3KZasA91wwNoUSOYJ29owKYN6+jZs2UJRqlpWlkwatQorr/+enx8fNwah05ApdSxhGSiD8Sx4VA8q3fnftd++6hQPM0m/tWjLu1rhxHgrX+dmqbl7pdffuHUqVPcd9999OzZs1RMSaLPWKXIldQMPlofw7wNMVxMzn3E4udvbszwzlG6OU3TNIeJCP/5z384c+YMw4YNw8OjdJw/SkcUFdSxhGTe+eUAJ85fZUvs+WvWh/h50rdpFW5qEUmnOmE66Wia5pS0tDTS0tIICAhg0aJFmM3mUpN8QCegEhd3OYV3fjnI74fi8xwDrXvDcCYNaEr9iIASjk7TtPLCarVy88034+vry4oVK9zSy60gOgGVAKtV2BSTwKLNx/j+r+yTvHqaFfdfV5uuDcJpVi2IygGumWdD07SKzWQyMWTIEAICAkrtSCQ6ARWjtAwr7/5ygFm/Hs623NvDxINd6/DA9bWJCHRvLxRN08qX5cuXU6lSJbp27cqYMWPcHU6+dAIqBscSknl66Q7+iL1gX2ZS4OflwbgbG3JPp1qlevZHTdPKpvT0dCZOnEiDBg3o2rWru8MpkE5ALpRusTJvfQyv/7Av2/IpA5tyX5famPXkY5qmFYOrV6/i5eWFp6cnq1evJiKibEwfrhOQi6w9cI4R87ZkW/a/Ee3p0ySi1La/appW9iUlJdGrVy/69OnD66+/TlRUlLtDcphOQEVgtQrzNs
Qw+7cjxCf9M0X0iC5RTBnYTE+3rGlasfP396d79+506dLF3aE4TSegQjpzKYXOr/+SbVn9iADeHdqaZtWC3RSVpmk
"text/plain": [
"<Figure size 432x360 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"idx_for_threshold_at_90 = (thresholds <= threshold_for_90_precision).argmax()\n",
"tpr_90, fpr_90 = tpr[idx_for_threshold_at_90], fpr[idx_for_threshold_at_90]\n",
"\n",
"plt.figure(figsize=(6, 5))  # extra code: not needed, just formatting\n",
"plt.plot(fpr, tpr, linewidth=2, label=\"ROC curve\")\n",
"plt.plot([0, 1], [0, 1], 'k:', label=\"Random classifier's ROC curve\")\n",
"plt.plot([fpr_90], [tpr_90], \"ko\", label=\"Threshold for 90% precision\")\n",
"\n",
"# extra code: just beautifies and saves Figure 3-7\n",
"plt.gca().add_patch(patches.FancyArrowPatch(\n",
" (0.20, 0.89), (0.07, 0.70),\n",
" connectionstyle=\"arc3,rad=.4\",\n",
" arrowstyle=\"Simple, tail_width=1.5, head_width=8, head_length=10\",\n",
" color=\"#444444\"))\n",
"plt.text(0.12, 0.71, \"Higher\\nthreshold\", color=\"#333333\")\n",
"plt.xlabel('False Positive Rate (Fall-Out)')\n",
"plt.ylabel('True Positive Rate (Recall)')\n",
"plt.grid()\n",
"plt.axis([0, 1, 0, 1])\n",
"plt.legend(loc=\"lower right\", fontsize=13)\n",
"save_fig(\"roc_curve_plot\")\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9604938554008616"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import roc_auc_score\n",
"\n",
"roc_auc_score(y_train_5, y_scores)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Warning:** the following cell may take a few minutes to run."
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"\n",
"forest_clf = RandomForestClassifier(random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3,\n",
" method=\"predict_proba\")"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.11, 0.89],\n",
" [0.99, 0.01]])"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_probas_forest[:2]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These are _estimated probabilities_. Among the images that the model classified as positive with a probability between 50% and 60%, there are actually about 94% positive images:"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"94.0%\n"
]
}
],
"source": [
"# Not in the code\n",
"idx_50_to_60 = (y_probas_forest[:, 1] > 0.50) & (y_probas_forest[:, 1] < 0.60)\n",
"print(f\"{(y_train_5[idx_50_to_60]).sum() / idx_50_to_60.sum():.1%}\")"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"y_scores_forest = y_probas_forest[:, 1]\n",
"precisions_forest, recalls_forest, thresholds_forest = precision_recall_curve(\n",
" y_train_5, y_scores_forest)"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAFYCAYAAAAV9ygtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAABDUUlEQVR4nO3dd3xUVf7/8ddJD0lIQkgCJEACoYViQ4ooglhA7G1X9uvaELvurqhY1sWyiqzl51rWxYa6a1vsgiCW0BVRUXoPvYWehPTz++MmQwIJJJnJ3JT38/GYx9y5c+fczxxCPjnnnnuOsdYiIiLibwFuByAiIk2TEpCIiLhCCUhERFyhBCQiIq5QAhIREVcoAYmIiCv8loCMMa8bY3YYYxZX8b4xxvzTGLPaGPObMeZEf8UmIiL+588W0ERg6FHeHwZ0Kn2MAv7lh5hERMQlfktA1tqZwO6jHHIh8JZ1fA/EGGNa+yc6ERHxt/p0DSgJ2Fju9abSfSIi0ggFuR1AOaaSfZXOE2SMGYXTTYcJa35SQFBbirPD6jI2EWkATGW/RQBjbIX3jTm0r7ygIEtgoCUgwGIMnmdjSggIMKWfL9t3qOzAwLLPOZ8JCKi8/MZo5cqVWdba+Np8tj4loE1A23Kvk4EtlR1orZ0ATAAIbd3JPjRhEn8e0rXuI2xgZs6cycCBA90Oo96pi3opKYHiYudRVHTouWy7oADy86t+lCn/C7L8c1XbR9tX9lxS4pw/L+/Q+fLynH2HfzYzcx0pKamUlBz5PSo7T1XnLP9cvm7KyispqXhMYSEcPAg5OYceBw4cehQWHvmdysoqLISyKS2rmtqyulNelp3HF6KioHVrSEyE+Hho2RICA53Yo6IgLs55PyUFkpOd7WbNfHd+fzHGrK/tZ+tTAvoMuM0Y8x7QF9hnrd1anQ8GBzfMf7i6FhZWonqphOqlahkZ6xk0KNXtMGqsfEIqLnb2HZ6syr9fWHhouyyRFhTA1q2wc6eTDHNznUR48CCsWLGBhIR2nmRc9igsdJJ5VpbzuZwc53P79x9KnitXVv97dO0KI0dC587Qti2kp0NIiO/rq77wWwIyxrwLDAJaGmM2AX8DggGstS8DU4BzgdVALnCtv2ITkYbN6fpy/hj1Rnp65fszMtYyaFC7apdjLezZ4yS0HTuc5LRr16GW2IEDzuvNmyEz03neuhWWL4fRow+VExICxx0Hw4dD//7Qvj2kpjaepOS3BGStvfIY71vgVj+FIyJSZ4yBFi2cR/fu1ftMURF88glMmQLbt8Pq1U7r6ccfnUeZsDDo1w8GDoQ//MFpLTVU9akLrlbaRgVwy+A0t8MQEfFKUBBcdpnzKLNvH8yZA1984bSOMjNh3TrIyHAejz4Kl18ODzwAPXtWPQijvqpPw7BrJdBAZGiDz6MiIkeIjoZzz4WXXoJvv4W1a53rTR9/DNdd5yStDz5wuuni4uCKK2DDBrejrr4Gn4B8LbegiPfmb+CuD35l896DbocjIlJBXBxcdBG89pqTkG6/HWJinGtO//sfdOkC117rJKyyARn1VYNPQNtzLZN+2uSz8jKzchnz0SI+/HkTD368qMJ7m/bkctm/5rIuKweAvbkF5BX6/l9Yy6SLSHUkJ8M//wm7d8OaNU4LKC8PJk6EIUOgRw+YMaP6w9D9rcH3XR0ssqzdmV3jz5WUWL5YtJWY8GC27D3I7/s4I1zS2zT3HPPdip2kjJlM5rjhAPyyYS8L1u9h8FMZFcp67KIenJrWkkWb93H+cW0AKCgqISSoYn7fd7CQvMJiikssP2/Ywy8b9rIuK4fzj2vNxSck8+3y7Tz2xTLWlia4UQM7cO2AFOIiQj1lzV+3m5NTYjENrbNXROqMMdChA7z/Pjz2mJOA/vtf57rRoEFw/PHw4INw6aUuB3qYBp+AamPH/jxuePsnft2417OvS6soTmgXC8Dnt53K+S/M9ry3aU8uybHNOJBXdH
hRnN45nhPbxbIuK4fb3/2F579dRURoEL9scMr+bezZNA8L5r6PFvHu/Mo7Z+8Z2gWAsOBAT/IBmDBzLRNmruXB4d0YeVoH9uQUcMW/55UeG0BeYQkvjDiB1tHh9EhqTmhQoFf1IiINX6dO8Pe/w1//CuPGwQsvwMKFzuCGa66B55+HyEi3o3Q0uQS0fNt+rn59Ptv3H7r9/KT2sUSFHaqKnsnRzBlzBpf9ay59U1tQUuLsv+ykZFJbRvDLxj0kxzYjLT6SA3mFdG0VxfDnZwGwcnvF1tgPa3dzVnoiic1DPfuahQTSO6UFJ7aLYcOuXFLiIgA4oW0sX9x+Kgs37uXBTw6tWpFf5AQwYdZaz768Qmffbe/8QmrLCO4c0okLj2/DBws20r1NNNn5RZTU13a3iNS5sDAYOxbGjIFXXoF773VaRvPnw6efQlo9GDzc5BLQ1n15ZGUXcHJKLCe2i6VjQiSXnZhMQEDFLq2kmHDm3Tekwr6QoAD6d4yjf8e4I8r95JYBjP18CZ0SIjmjayItIkJoFhJIemunS+/aAancOLAjB/IKaRkZesT5AMJDAumRFE2PpGj+r197oGJX3j3ndKFHm2gA/vbZYrKyC4gKDeKBc7txZnoivR/7mqzsQ4k1JBBWDvaiskSkwQsLcwYqnHGGM2R76VI49VRYsMC5huSmJpeABndJ4K3r+nBS+1jCgn3XZXVc2xg+vmVAle9Hhzu3aIeH1Oyc5a8jGWMY3stZoaLsubzOiZEVElBBMaSMmcyax88lsJKEJyJNR/fu8MMPcP75zsCEq65yRsq5eTm5SSSgvMJiduzPp12cMwHYgLSWLkdUN965oZ9nO2XMZADuOCONPbkFbNidyyUvzaVDywgO5Bfx/qh+dIivJx3BIuIXUVEwaZIzVDsjA2bNcmZUcEuDH4bdLMiQllD1L9Jnpq+k61+nMuy5mUxdXK25TRuFdU+cy+Wdg+mVHMOvG/dyyUtzAViblcPOA/mc8fQMHvr00HWmsZ8todtfp5IyZjIpYyazePM+9uYWVFW8iDRQLVvCHXc427fc4tzY6pYGn4ASmhkuObHyjsz563bzz29WAZBTUExSTNOZAtkYw/AOIZyZnsiQbon8/uS2XHJixfX9yo/KW7x5HwfL3dN03vOzOf6R6WzblwfA3NVZjP1sCflF9fzONhE5pjvvdGbeXrIEzjzTuYnVDY26C+7b5Ts82yP6tqNncrSL0bhr3KW9AHjmiuNZvSObr5ZuI711c6y1GGMYf1kvlm87wOuz17Fg/aGfxu378/h62XaemLKMnIJiJs7NBGBAWhxz1+zivRv60bdDHN8u386gzgme/uSt+/JoERHiuc5WUmIptpbgwAb/N49IgxcT41z/Of10+PVXZ7btr7/2/7I2DT4BFVvYn1dI87Aj52EfM6wrE+euIzI0mHuHasG6MmkJkaQlVByD2SE+kg7xkZzb0xncUFLiDOEOCDC0a9GswrBwgDmrd9E8LIhW0WHkFxVz3cQFlZ5r/v1DSGgeRnZBEb3GfgVAm+gwzunRis6JUVx2UrKSkogLWrd2ks6pp8K8ec4NrI8/7t8YGnwC2nighJcz1nBPFQlm2SNDKS6xBOmXXI2UHyYeGxFC5rjh7DyQz0UvzuHKPm1JbB7G+l25tI+L4LXZ66osZ8nW/SQ0D+OTXzZ79m3Zl8cbczIBuO+jRWSMHkS7Fs3ocP8UAKb/eSDLth3g+OQYz8AREfG9du2c+eP694ennoKrr3YGKPhLg09Alckvcka9tW3RDGMMQYEaguwL8VGhzBlzxhH7rz81lfN6tSYrO5+4iFBCgwIIDDSs2p7tuQ/qj/1TOCs9kf5PfEtggKG4tIX1xe2nktIygqemrfCUd9azMz3bzUICmXH3YAIDDB/9vIn9eUWMGthBM6CL+EjfvnD99fDqq3DfffDRR/47d6P8Xzx7VRbXv7mAC49vw3O/P8HtcJqExOZhJDYPq7DvpPaxFV
63jg73zKtXpqyr78bTO/DCd6sBaBkZQla2MwLvlI5xNA8P4teN+3hs8jIAz8ASgA9vPsVznvFTl5OWEMkFx7VRi1e
"text/plain": [
"<Figure size 432x360 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(6, 5)) # extra code not needed, just formatting\n",
"\n",
"plt.plot(recalls_forest, precisions_forest, \"b-\", linewidth=2,\n",
" label=\"Random Forest\")\n",
"plt.plot(recalls, precisions, \"--\", linewidth=2, label=\"SGD\")\n",
"\n",
"# extra code just beautifies and saves Figure 38\n",
"plt.xlabel(\"Recall\")\n",
"plt.ylabel(\"Precision\")\n",
"plt.axis([0, 1, 0, 1])\n",
"plt.grid()\n",
"plt.legend(loc=\"lower left\")\n",
"save_fig(\"pr_curve_comparison_plot\")\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We could use `cross_val_predict(forest_clf, X_train, y_train_5, cv=3)` to compute `y_train_pred_forest`, but since we already have the estimated probabilities, we can just use the default threshold of 50% probability to get the same predictions much faster:"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9274509803921569"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_train_pred_forest = y_probas_forest[:, 1] >= 0.5 # positive proba ≥ 50%\n",
"f1_score(y_train_5, y_train_pred_forest)"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9983436731328145"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"roc_auc_score(y_train_5, y_scores_forest)"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9897468089558485"
]
},
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"precision_score(y_train_5, y_train_pred_forest)"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8725327430363402"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"recall_score(y_train_5, y_train_pred_forest)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Multiclass Classification"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"SVMs do not scale well to large datasets, so let's only train on the first 2,000 instances, or else this section will take a very long time to run:"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"SVC(random_state=42)"
]
},
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.svm import SVC\n",
"\n",
"svm_clf = SVC(random_state=42)\n",
"svm_clf.fit(X_train[:2000], y_train[:2000]) # y_train, not y_train_5"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(['5'], dtype=object)"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"svm_clf.predict([some_digit])"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 3.79, 0.73, 6.06, 8.3 , -0.29, 9.3 , 1.75, 2.77, 7.21,\n",
" 4.82]])"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"some_digit_scores = svm_clf.decision_function([some_digit])\n",
"some_digit_scores.round(2)"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5"
]
},
"execution_count": 62,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_id = some_digit_scores.argmax()\n",
"class_id"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], dtype=object)"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"svm_clf.classes_"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'5'"
]
},
"execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"svm_clf.classes_[class_id]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want `decision_function()` to return all 45 scores, you can set the `decision_function_shape` hyperparameter to `\"ovo\"`. The default value is `\"ovr\"`, but don't let this confuse you: `SVC` always uses OvO for training. This hyperparameter only affects whether or not the 45 scores get aggregated or not:"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0.11, -0.21, -0.97, 0.51, -1.01, 0.19, 0.09, -0.31, -0.04,\n",
" -0.45, -1.28, 0.25, -1.01, -0.13, -0.32, -0.9 , -0.36, -0.93,\n",
" 0.79, -1. , 0.45, 0.24, -0.24, 0.25, 1.54, -0.77, 1.11,\n",
" 1.13, 1.04, 1.2 , -1.42, -0.53, -0.45, -0.99, -0.95, 1.21,\n",
" 1. , 1. , 1.08, -0.02, -0.67, -0.14, -0.3 , -0.13, 0.25]])"
]
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code shows how to get all 45 OvO scores if needed\n",
"svm_clf.decision_function_shape = \"ovo\"\n",
"some_digit_scores_ovo = svm_clf.decision_function([some_digit])\n",
"some_digit_scores_ovo.round(2)"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"OneVsRestClassifier(estimator=SVC(random_state=42))"
]
},
"execution_count": 66,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.multiclass import OneVsRestClassifier\n",
"\n",
"ovr_clf = OneVsRestClassifier(SVC(random_state=42))\n",
"ovr_clf.fit(X_train[:2000], y_train[:2000])"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(['5'], dtype='<U1')"
]
},
"execution_count": 67,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ovr_clf.predict([some_digit])"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"10"
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(ovr_clf.estimators_)"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(['3'], dtype='<U1')"
]
},
"execution_count": 69,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sgd_clf = SGDClassifier(random_state=42)\n",
"sgd_clf.fit(X_train, y_train)\n",
"sgd_clf.predict([some_digit])"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[-31893., -34420., -9531., 1824., -22320., -1386., -26189.,\n",
" -16148., -4604., -12051.]])"
]
},
"execution_count": 70,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sgd_clf.decision_function([some_digit]).round()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Warning:** the following two cells may take a few minutes each to run:"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.87365, 0.85835, 0.8689 ])"
]
},
"execution_count": 71,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cross_val_score(sgd_clf, X_train, y_train, cv=3, scoring=\"accuracy\")"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.8983, 0.891 , 0.9018])"
]
},
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"scaler = StandardScaler()\n",
"X_train_scaled = scaler.fit_transform(X_train.astype(\"float64\"))\n",
"cross_val_score(sgd_clf, X_train_scaled, y_train, cv=3, scoring=\"accuracy\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Error Analysis"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Warning:** the following cell will take a few minutes to run:"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"\n",
"y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)\n",
"plt.rc('font', size=9)  # extra code: make the text smaller\n",
"ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.rc('font', size=10) # extra code\n",
"ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred,\n",
" normalize=\"true\", values_format=\".0%\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"sample_weight = (y_train_pred != y_train)\n",
"plt.rc('font', size=10) # extra code\n",
"ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred,\n",
" sample_weight=sample_weight,\n",
" normalize=\"true\", values_format=\".0%\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's put all plots in a couple of figures for the book:"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 648x288 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 3-9\n",
"fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))\n",
"plt.rc('font', size=9)\n",
"ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred, ax=axs[0])\n",
"axs[0].set_title(\"Confusion matrix\")\n",
"plt.rc('font', size=10)\n",
"ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred, ax=axs[1],\n",
" normalize=\"true\", values_format=\".0%\")\n",
"axs[1].set_title(\"CM normalized by row\")\n",
"save_fig(\"confusion_matrix_plot_1\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 648x288 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 3-10\n",
"fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))\n",
"plt.rc('font', size=10)\n",
"ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred, ax=axs[0],\n",
" sample_weight=sample_weight,\n",
" normalize=\"true\", values_format=\".0%\")\n",
"axs[0].set_title(\"Errors normalized by row\")\n",
"ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred, ax=axs[1],\n",
" sample_weight=sample_weight,\n",
" normalize=\"pred\", values_format=\".0%\")\n",
"axs[1].set_title(\"Errors normalized by column\")\n",
"save_fig(\"confusion_matrix_plot_2\")\n",
"plt.show()\n",
"plt.rc('font', size=14) # make fonts great again"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
"cl_a, cl_b = '3', '5'\n",
"X_aa = X_train[(y_train == cl_a) & (y_train_pred == cl_a)]\n",
"X_ab = X_train[(y_train == cl_a) & (y_train_pred == cl_b)]\n",
"X_ba = X_train[(y_train == cl_b) & (y_train_pred == cl_a)]\n",
"X_bb = X_train[(y_train == cl_b) & (y_train_pred == cl_b)]"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 360x360 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 3-11\n",
"size = 5\n",
"pad = 0.2\n",
"plt.figure(figsize=(size, size))\n",
"for images, (label_col, label_row) in [(X_ba, (0, 0)), (X_bb, (1, 0)),\n",
" (X_aa, (0, 1)), (X_ab, (1, 1))]:\n",
" for idx, image_data in enumerate(images[:size*size]):\n",
" x = idx % size + label_col * (size + pad)\n",
" y = idx // size + label_row * (size + pad)\n",
" plt.imshow(image_data.reshape(28, 28), cmap=\"binary\",\n",
" extent=(x, x + 1, y, y + 1))\n",
"plt.xticks([size / 2, size + pad + size / 2], [str(cl_a), str(cl_b)])\n",
"plt.yticks([size / 2, size + pad + size / 2], [str(cl_b), str(cl_a)])\n",
"plt.plot([size + pad / 2, size + pad / 2], [0, 2 * size + pad], \"k:\")\n",
"plt.plot([0, 2 * size + pad], [size + pad / 2, size + pad / 2], \"k:\")\n",
"plt.axis([0, 2 * size + pad, 0, 2 * size + pad])\n",
"plt.xlabel(\"Predicted label\")\n",
"plt.ylabel(\"True label\")\n",
"save_fig(\"error_analysis_digits_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: there are several other ways you could code a plot like this one, but it's a bit hard to get the axis labels right:\n",
"* using [nested GridSpecs](https://matplotlib.org/stable/gallery/subplots_axes_and_figures/gridspec_nested.html)\n",
"* merging all the digits in each block into a single image (then using 2×2 subplots). For example:\n",
" ```python\n",
" X_aa[:25].reshape(5, 5, 28, 28).transpose(0, 2, 1, 3).reshape(5 * 28, 5 * 28)\n",
" ```\n",
"* using [subfigures](https://matplotlib.org/stable/gallery/subplots_axes_and_figures/subfigures.html) (since Matplotlib 3.4)"
2016-05-22 17:40:18 +02:00
]
},
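{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see why the reshape/transpose/reshape trick from the note above tiles the images correctly, here is a minimal sketch on toy data. The four 2×2 toy images and the names `images` and `grid` are assumptions made up for illustration; they are not part of the notebook:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Four toy 2x2 images, each filled with its own index (0..3) and flattened\n",
"# to a row of 4 values, the same way MNIST images are flattened to 784 values.\n",
"images = np.repeat(np.arange(4), 4).reshape(4, 4)\n",
"\n",
"rows, cols, h, w = 2, 2, 2, 2  # a 2x2 grid of 2x2 images\n",
"grid = (images.reshape(rows, cols, h, w)  # (grid row, grid col, pixel row, pixel col)\n",
"              .transpose(0, 2, 1, 3)      # (grid row, pixel row, grid col, pixel col)\n",
"              .reshape(rows * h, cols * w))\n",
"print(grid)\n",
"```\n",
"\n",
"Each toy image ends up as a contiguous 2×2 block of the grid. Without the `transpose()` call, the final `reshape()` would simply return one flattened image per row instead of a tiled mosaic."
]
},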
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Multilabel Classification"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"KNeighborsClassifier()"
]
},
"execution_count": 80,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"\n",
"y_train_large = (y_train >= '7')\n",
"y_train_odd = (y_train.astype('int8') % 2 == 1)\n",
"y_multilabel = np.c_[y_train_large, y_train_odd]\n",
"\n",
"knn_clf = KNeighborsClassifier()\n",
"knn_clf.fit(X_train, y_multilabel)"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[False, True]])"
]
},
"execution_count": 81,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"knn_clf.predict([some_digit])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Warning**: the following cell may take a few minutes to run:"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.976410265560605"
]
},
"execution_count": 82,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_train_knn_pred = cross_val_predict(knn_clf, X_train, y_multilabel, cv=3)\n",
"f1_score(y_multilabel, y_train_knn_pred, average=\"macro\")"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9778357403921755"
]
},
"execution_count": 83,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code: shows that we get a negligible performance improvement when we\n",
"# set average=\"weighted\" because the classes are already pretty\n",
"# well balanced.\n",
"f1_score(y_multilabel, y_train_knn_pred, average=\"weighted\")"
]
},
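{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the difference between the two `f1_score` calls above concrete, here is a minimal pure-Python sketch. It assumes the usual definitions: the macro average is the plain mean of the per-label F1 scores, while the weighted average weights each label by its support (its number of positive instances). The helper `f1` and the toy labels are made up for illustration and are not part of the notebook:\n",
"\n",
"```python\n",
"def f1(y_true, y_pred):\n",
"    # F1 for one binary label: 2*TP / (2*TP + FP + FN)\n",
"    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)\n",
"    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)\n",
"    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)\n",
"    return 2 * tp / (2 * tp + fp + fn)\n",
"\n",
"# Toy data (made up): label A is frequent and predicted perfectly,\n",
"# label B is rare and predicted poorly.\n",
"true_a, pred_a = [1, 1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1, 0, 0]\n",
"true_b, pred_b = [1, 1, 0, 0, 0, 0, 0, 0], [1, 0, 1, 0, 0, 0, 0, 0]\n",
"\n",
"scores = [f1(true_a, pred_a), f1(true_b, pred_b)]\n",
"supports = [sum(true_a), sum(true_b)]  # positive instances per label\n",
"\n",
"macro_f1 = sum(scores) / len(scores)\n",
"weighted_f1 = sum(s * n for s, n in zip(scores, supports)) / sum(supports)\n",
"print(macro_f1, weighted_f1)\n",
"```\n",
"\n",
"With strongly imbalanced labels like these, the two averages differ noticeably. In the notebook the two labels have fairly similar support, which is why the scores above are so close."
]
},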
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ClassifierChain(base_estimator=SVC(), cv=3, random_state=42)"
]
},
"execution_count": 84,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.multioutput import ClassifierChain\n",
"\n",
"chain_clf = ClassifierChain(SVC(), cv=3, random_state=42)\n",
"chain_clf.fit(X_train[:2000], y_multilabel[:2000])"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0., 1.]])"
]
},
"execution_count": 85,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain_clf.predict([some_digit])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Multioutput Classification"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(42)  # to make this code example reproducible\n",
"noise = np.random.randint(0, 100, (len(X_train), 784))\n",
"X_train_mod = X_train + noise\n",
"noise = np.random.randint(0, 100, (len(X_test), 784))\n",
"X_test_mod = X_test + noise\n",
"y_train_mod = X_train\n",
"y_test_mod = X_test"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 3-12\n",
"plt.subplot(121); plot_digit(X_test_mod[0])\n",
"plt.subplot(122); plot_digit(y_test_mod[0])\n",
"save_fig(\"noisy_digit_example_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"knn_clf = KNeighborsClassifier()\n",
"knn_clf.fit(X_train_mod, y_train_mod)\n",
"clean_digit = knn_clf.predict([X_test_mod[0]])\n",
"plot_digit(clean_digit)\n",
"save_fig(\"cleaned_digit_example_plot\")  # extra code: saves Figure 3-13\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Exercise solutions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. An MNIST Classifier With Over 97% Accuracy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Exercise: _Try to build a classifier for the MNIST dataset that achieves over 97% accuracy on the test set. Hint: the `KNeighborsClassifier` works quite well for this task; you just need to find good hyperparameter values (try a grid search on the `weights` and `n_neighbors` hyperparameters)._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's start with a simple K-Nearest Neighbors classifier and measure its performance on the test set. This will be our baseline:"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9688"
]
},
"execution_count": 89,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"knn_clf = KNeighborsClassifier()\n",
"knn_clf.fit(X_train, y_train)\n",
"baseline_accuracy = knn_clf.score(X_test, y_test)\n",
"baseline_accuracy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Great! A regular KNN classifier with the default hyperparameters is already very close to our goal."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's see if tuning the hyperparameters can help. To speed up the search, let's train only on the first 10,000 images:"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"GridSearchCV(cv=5, estimator=KNeighborsClassifier(),\n",
" param_grid=[{'n_neighbors': [3, 4, 5, 6],\n",
" 'weights': ['uniform', 'distance']}])"
]
},
"execution_count": 90,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"param_grid = [{'weights': [\"uniform\", \"distance\"], 'n_neighbors': [3, 4, 5, 6]}]\n",
"\n",
"knn_clf = KNeighborsClassifier()\n",
"grid_search = GridSearchCV(knn_clf, param_grid, cv=5)\n",
"grid_search.fit(X_train[:10_000], y_train[:10_000])"
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'n_neighbors': 4, 'weights': 'distance'}"
]
},
"execution_count": 91,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"grid_search.best_params_"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9441999999999998"
]
},
"execution_count": 92,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"grid_search.best_score_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The score dropped, but that was expected since we only trained on 10,000 images. So let's take the best model and train it again on the full training set:"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9714"
]
},
"execution_count": 93,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"grid_search.best_estimator_.fit(X_train, y_train)\n",
"tuned_accuracy = grid_search.score(X_test, y_test)\n",
"tuned_accuracy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We reached our goal of 97% accuracy! 🥳"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Data Augmentation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Exercise: _Write a function that can shift an MNIST image in any direction (left, right, up, or down) by one pixel. You can use the `shift()` function from the `scipy.ndimage` module. For example, `shift(image, [2, 1], cval=0)` shifts the image two pixels down and one pixel to the right. Then, for each image in the training set, create four shifted copies (one per direction) and add them to the training set. Finally, train your best model on this expanded training set and measure its accuracy on the test set. You should observe that your model performs even better now! This technique of artificially growing the training set is called _data augmentation_ or _training set expansion_._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's try augmenting the MNIST dataset by adding slightly shifted versions of each image."
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [],
"source": [
"from scipy.ndimage import shift"
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [],
"source": [
"def shift_image(image, dx, dy):\n",
"    image = image.reshape((28, 28))\n",
"    shifted_image = shift(image, [dy, dx], cval=0, mode=\"constant\")\n",
"    return shifted_image.reshape([-1])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's see if it works:"
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 864x216 with 3 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"image = X_train[1000]  # some random digit to demo\n",
"shifted_image_down = shift_image(image, 0, 5)\n",
"shifted_image_left = shift_image(image, -5, 0)\n",
"\n",
"plt.figure(figsize=(12, 3))\n",
"plt.subplot(131)\n",
"plt.title(\"Original\")\n",
"plt.imshow(image.reshape(28, 28),\n",
"           interpolation=\"nearest\", cmap=\"Greys\")\n",
"plt.subplot(132)\n",
"plt.title(\"Shifted down\")\n",
"plt.imshow(shifted_image_down.reshape(28, 28),\n",
"           interpolation=\"nearest\", cmap=\"Greys\")\n",
"plt.subplot(133)\n",
"plt.title(\"Shifted left\")\n",
"plt.imshow(shifted_image_left.reshape(28, 28),\n",
"           interpolation=\"nearest\", cmap=\"Greys\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Looks good! Now let's create an augmented training set by shifting every image left, right, up and down by one pixel:"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [],
"source": [
"X_train_augmented = [image for image in X_train]\n",
"y_train_augmented = [label for label in y_train]\n",
"\n",
"for dx, dy in ((-1, 0), (1, 0), (0, 1), (0, -1)):\n",
"    for image, label in zip(X_train, y_train):\n",
"        X_train_augmented.append(shift_image(image, dx, dy))\n",
"        y_train_augmented.append(label)\n",
"\n",
"X_train_augmented = np.array(X_train_augmented)\n",
"y_train_augmented = np.array(y_train_augmented)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's shuffle the augmented training set, or else all shifted images will be grouped together:"
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {},
"outputs": [],
"source": [
"shuffle_idx = np.random.permutation(len(X_train_augmented))\n",
"X_train_augmented = X_train_augmented[shuffle_idx]\n",
"y_train_augmented = y_train_augmented[shuffle_idx]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's train the model using the best hyperparameters we found in the previous exercise:"
]
},
{
"cell_type": "code",
"execution_count": 99,
"metadata": {},
"outputs": [],
"source": [
"knn_clf = KNeighborsClassifier(**grid_search.best_params_)"
]
},
{
"cell_type": "code",
"execution_count": 100,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"KNeighborsClassifier(n_neighbors=4, weights='distance')"
]
},
"execution_count": 100,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"knn_clf.fit(X_train_augmented, y_train_augmented)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Warning**: the following cell may take a few minutes to run:"
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9763"
]
},
"execution_count": 101,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"augmented_accuracy = knn_clf.score(X_test, y_test)\n",
"augmented_accuracy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"By simply augmenting the data, we gained about 0.5 percentage points of accuracy (from 97.14% to 97.63%). That may not sound impressive, but it means the error rate dropped significantly:"
]
},
{
"cell_type": "code",
"execution_count": 102,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"error_rate_change = -17%\n"
]
}
],
"source": [
"error_rate_change = (1 - augmented_accuracy) / (1 - tuned_accuracy) - 1\n",
"print(f\"error_rate_change = {error_rate_change:.0%}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The error rate dropped by about 17% thanks to data augmentation."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Tackle the Titanic dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Exercise: _Tackle the Titanic dataset. A great place to start is on [Kaggle](https://www.kaggle.com/c/titanic). Alternatively, you can download the data from https://homl.info/titanic.tgz and unzip this tarball like you did for the housing data in Chapter 2. This will give you two CSV files: _train.csv_ and _test.csv_ which you can load using `pandas.read_csv()`. The goal is to train a classifier that can predict the `Survived` column based on the other columns._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's fetch the data and load it:"
]
},
{
"cell_type": "code",
"execution_count": 103,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import pandas as pd\n",
"import tarfile\n",
"import urllib.request\n",
"\n",
"def load_titanic_data():\n",
"    tarball_path = Path(\"datasets/titanic.tgz\")\n",
"    if not tarball_path.is_file():\n",
"        Path(\"datasets\").mkdir(parents=True, exist_ok=True)\n",
"        url = \"https://github.com/ageron/data/raw/main/titanic.tgz\"\n",
"        urllib.request.urlretrieve(url, tarball_path)\n",
"        with tarfile.open(tarball_path) as titanic_tarball:\n",
"            titanic_tarball.extractall(path=\"datasets\")\n",
"    return [pd.read_csv(Path(\"datasets/titanic\") / filename)\n",
"            for filename in (\"train.csv\", \"test.csv\")]"
]
},
{
"cell_type": "code",
"execution_count": 104,
"metadata": {},
"outputs": [],
"source": [
"train_data, test_data = load_titanic_data()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The data is already split into a training set and a test set. However, the test data does *not* contain the labels: your goal is to train the best model you can on the training data, then make your predictions on the test data and upload them to Kaggle to see your final score."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's take a peek at the top few rows of the training set:"
]
},
{
"cell_type": "code",
"execution_count": 105,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>PassengerId</th>\n",
" <th>Survived</th>\n",
" <th>Pclass</th>\n",
" <th>Name</th>\n",
" <th>Sex</th>\n",
" <th>Age</th>\n",
" <th>SibSp</th>\n",
" <th>Parch</th>\n",
" <th>Ticket</th>\n",
" <th>Fare</th>\n",
" <th>Cabin</th>\n",
" <th>Embarked</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Braund, Mr. Owen Harris</td>\n",
" <td>male</td>\n",
" <td>22.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>A/5 21171</td>\n",
" <td>7.2500</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>Cumings, Mrs. John Bradley (Florence Briggs Th...</td>\n",
" <td>female</td>\n",
" <td>38.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>PC 17599</td>\n",
" <td>71.2833</td>\n",
" <td>C85</td>\n",
" <td>C</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>Heikkinen, Miss. Laina</td>\n",
" <td>female</td>\n",
" <td>26.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>STON/O2. 3101282</td>\n",
" <td>7.9250</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>Futrelle, Mrs. Jacques Heath (Lily May Peel)</td>\n",
" <td>female</td>\n",
" <td>35.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>113803</td>\n",
" <td>53.1000</td>\n",
" <td>C123</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Allen, Mr. William Henry</td>\n",
" <td>male</td>\n",
" <td>35.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>373450</td>\n",
" <td>8.0500</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" PassengerId Survived Pclass \\\n",
"0 1 0 3 \n",
"1 2 1 1 \n",
"2 3 1 3 \n",
"3 4 1 1 \n",
"4 5 0 3 \n",
"\n",
" Name Sex Age SibSp \\\n",
"0 Braund, Mr. Owen Harris male 22.0 1 \n",
"1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 1 \n",
"2 Heikkinen, Miss. Laina female 26.0 0 \n",
"3 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 1 \n",
"4 Allen, Mr. William Henry male 35.0 0 \n",
"\n",
" Parch Ticket Fare Cabin Embarked \n",
"0 0 A/5 21171 7.2500 NaN S \n",
"1 0 PC 17599 71.2833 C85 C \n",
"2 0 STON/O2. 3101282 7.9250 NaN S \n",
"3 0 113803 53.1000 C123 S \n",
"4 0 373450 8.0500 NaN S "
]
},
"execution_count": 105,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The attributes have the following meaning:\n",
"* **PassengerId**: a unique identifier for each passenger\n",
2017-10-04 10:57:40 +02:00
"* **Survived**: that's the target, 0 means the passenger did not survive, while 1 means he/she survived.\n",
"* **Pclass**: passenger class.\n",
"* **Name**, **Sex**, **Age**: self-explanatory\n",
"* **SibSp**: how many siblings & spouses of the passenger aboard the Titanic.\n",
"* **Parch**: how many children & parents of the passenger aboard the Titanic.\n",
"* **Ticket**: ticket id\n",
"* **Fare**: price paid (in pounds)\n",
"* **Cabin**: passenger's cabin number\n",
"* **Embarked**: where the passenger embarked the Titanic"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The goal is to predict whether or not a passenger survived based on attributes such as their age, sex, passenger class, where they embarked and so on."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's explicitly set the `PassengerId` column as the index column:"
]
},
{
"cell_type": "code",
"execution_count": 106,
"metadata": {},
"outputs": [],
"source": [
"train_data = train_data.set_index(\"PassengerId\")\n",
"test_data = test_data.set_index(\"PassengerId\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's get more info to see how much data is missing:"
]
},
{
"cell_type": "code",
"execution_count": 107,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"Int64Index: 891 entries, 1 to 891\n",
"Data columns (total 11 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 Survived 891 non-null int64 \n",
" 1 Pclass 891 non-null int64 \n",
" 2 Name 891 non-null object \n",
" 3 Sex 891 non-null object \n",
" 4 Age 714 non-null float64\n",
" 5 SibSp 891 non-null int64 \n",
" 6 Parch 891 non-null int64 \n",
" 7 Ticket 891 non-null object \n",
" 8 Fare 891 non-null float64\n",
" 9 Cabin 204 non-null object \n",
" 10 Embarked 889 non-null object \n",
"dtypes: float64(2), int64(4), object(5)\n",
"memory usage: 83.5+ KB\n"
]
}
],
"source": [
"train_data.info()"
]
},
{
"cell_type": "code",
"execution_count": 108,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"27.0"
]
},
"execution_count": 108,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data[train_data[\"Sex\"]==\"female\"][\"Age\"].median()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Okay, the **Age**, **Cabin** and **Embarked** attributes are sometimes null (less than 891 non-null), especially the **Cabin** (77% are null). We will ignore the **Cabin** for now and focus on the rest. The **Age** attribute has about 19% null values, so we will need to decide what to do with them. Replacing null values with the median age seems reasonable. We could be a bit smarter by predicting the age based on the other columns (for example, the median age is 37 in 1st class, 29 in 2nd class and 24 in 3rd class), but we'll keep things simple and just use the overall median age."
2017-10-04 10:57:40 +02:00
]
},
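{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough illustration of the \"smarter\" option mentioned above, here is a minimal sketch that fills each missing age with the median age of the passenger's class. It runs on a tiny made-up DataFrame (the `toy` variable, its values and the `AgeImputed` column are assumptions for illustration, not part of the Titanic data), and it is not the approach used in the rest of this solution:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Toy data, made up for illustration (not the real Titanic dataset)\n",
"toy = pd.DataFrame({\n",
"    \"Pclass\": [1, 1, 1, 3, 3, 3],\n",
"    \"Age\": [40.0, 36.0, np.nan, 20.0, np.nan, 26.0],\n",
"})\n",
"\n",
"# Median age of each passenger's class (NaN values are skipped)\n",
"class_median_age = toy.groupby(\"Pclass\")[\"Age\"].transform(\"median\")\n",
"# Fill each missing age with the median of the corresponding class\n",
"toy[\"AgeImputed\"] = toy[\"Age\"].fillna(class_median_age)\n",
"toy"
]
},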
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The **Name** and **Ticket** attributes may have some value, but they will be a bit tricky to convert into useful numbers that a model can consume. So for now, we will ignore them."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's take a look at the numerical attributes:"
]
},
{
"cell_type": "code",
"execution_count": 109,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Survived</th>\n",
" <th>Pclass</th>\n",
" <th>Age</th>\n",
" <th>SibSp</th>\n",
" <th>Parch</th>\n",
" <th>Fare</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>891.000000</td>\n",
" <td>891.000000</td>\n",
" <td>714.000000</td>\n",
" <td>891.000000</td>\n",
" <td>891.000000</td>\n",
" <td>891.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>0.383838</td>\n",
" <td>2.308642</td>\n",
" <td>29.699113</td>\n",
" <td>0.523008</td>\n",
" <td>0.381594</td>\n",
" <td>32.204208</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>0.486592</td>\n",
" <td>0.836071</td>\n",
" <td>14.526507</td>\n",
" <td>1.102743</td>\n",
" <td>0.806057</td>\n",
" <td>49.693429</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>0.000000</td>\n",
" <td>1.000000</td>\n",
" <td>0.416700</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>0.000000</td>\n",
" <td>2.000000</td>\n",
" <td>20.125000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>7.910400</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>0.000000</td>\n",
" <td>3.000000</td>\n",
" <td>28.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>14.454200</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>1.000000</td>\n",
" <td>3.000000</td>\n",
" <td>38.000000</td>\n",
" <td>1.000000</td>\n",
" <td>0.000000</td>\n",
" <td>31.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>1.000000</td>\n",
" <td>3.000000</td>\n",
" <td>80.000000</td>\n",
" <td>8.000000</td>\n",
" <td>6.000000</td>\n",
" <td>512.329200</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Survived Pclass Age SibSp Parch Fare\n",
"count 891.000000 891.000000 714.000000 891.000000 891.000000 891.000000\n",
"mean 0.383838 2.308642 29.699113 0.523008 0.381594 32.204208\n",
"std 0.486592 0.836071 14.526507 1.102743 0.806057 49.693429\n",
"min 0.000000 1.000000 0.416700 0.000000 0.000000 0.000000\n",
"25% 0.000000 2.000000 20.125000 0.000000 0.000000 7.910400\n",
"50% 0.000000 3.000000 28.000000 0.000000 0.000000 14.454200\n",
"75% 1.000000 3.000000 38.000000 1.000000 0.000000 31.000000\n",
"max 1.000000 3.000000 80.000000 8.000000 6.000000 512.329200"
]
},
"execution_count": 109,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data.describe()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* Yikes, only 38% **Survived**! 😭 That's close enough to 40%, so accuracy will be a reasonable metric to evaluate our model.\n",
2017-10-04 10:57:40 +02:00
"* The mean **Fare** was £32.20, which does not seem so expensive (but it was probably a lot of money back then).\n",
"* The mean **Age** was less than 30 years old."
]
},
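{
"cell_type": "markdown",
"metadata": {},
"source": [
"To put accuracy in perspective, it helps to know what a trivial model would score. A minimal sketch, assuming the class counts of the training set (549 passengers did not survive, 342 survived; `baseline_accuracy` is just an illustrative name): a model that always predicts the majority class would already be right about 62% of the time, so a useful classifier needs to do better than that."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Class counts in the Titanic training set\n",
"n_died, n_survived = 549, 342\n",
"\n",
"# Accuracy of a model that always predicts the majority class\n",
"baseline_accuracy = max(n_died, n_survived) / (n_died + n_survived)\n",
"print(f\"{baseline_accuracy:.1%}\")"
]
},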
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's check that the target is indeed 0 or 1:"
]
},
{
"cell_type": "code",
"execution_count": 110,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 549\n",
"1 342\n",
"Name: Survived, dtype: int64"
]
},
"execution_count": 110,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data[\"Survived\"].value_counts()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's take a quick look at all the categorical attributes:"
]
},
{
"cell_type": "code",
"execution_count": 111,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3 491\n",
"1 216\n",
"2 184\n",
"Name: Pclass, dtype: int64"
]
},
"execution_count": 111,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data[\"Pclass\"].value_counts()"
]
},
{
"cell_type": "code",
"execution_count": 112,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"male 577\n",
"female 314\n",
"Name: Sex, dtype: int64"
]
},
"execution_count": 112,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data[\"Sex\"].value_counts()"
]
},
{
"cell_type": "code",
"execution_count": 113,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"S 644\n",
"C 168\n",
"Q 77\n",
"Name: Embarked, dtype: int64"
]
},
"execution_count": 113,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data[\"Embarked\"].value_counts()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The Embarked attribute tells us where the passenger embarked: C=Cherbourg, Q=Queenstown, S=Southampton."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's build our preprocessing pipelines, starting with the pipeline for numerical attributes:"
]
},
{
"cell_type": "code",
"execution_count": 114,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import Pipeline\n",
"from sklearn.impute import SimpleImputer\n",
2017-10-04 10:57:40 +02:00
"\n",
"num_pipeline = Pipeline([\n",
" (\"imputer\", SimpleImputer(strategy=\"median\")),\n",
" (\"scaler\", StandardScaler())\n",
2017-10-04 10:57:40 +02:00
" ])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can build the pipeline for the categorical attributes:"
]
},
{
"cell_type": "code",
"execution_count": 115,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.preprocessing import OrdinalEncoder, OneHotEncoder"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: the `sparse` hyperparameter below was renamed to `sparse_output`."
]
},
{
"cell_type": "code",
"execution_count": 116,
"metadata": {},
"outputs": [],
"source": [
"cat_pipeline = Pipeline([\n",
" (\"ordinal_encoder\", OrdinalEncoder()), \n",
" (\"imputer\", SimpleImputer(strategy=\"most_frequent\")),\n",
" (\"cat_encoder\", OneHotEncoder(sparse_output=False)),\n",
" ])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, let's join the numerical and categorical pipelines:"
]
},
{
"cell_type": "code",
"execution_count": 117,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.compose import ColumnTransformer\n",
"\n",
"num_attribs = [\"Age\", \"SibSp\", \"Parch\", \"Fare\"]\n",
"cat_attribs = [\"Pclass\", \"Sex\", \"Embarked\"]\n",
"\n",
"preprocess_pipeline = ColumnTransformer([\n",
" (\"num\", num_pipeline, num_attribs),\n",
" (\"cat\", cat_pipeline, cat_attribs),\n",
" ])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Cool! Now we have a nice preprocessing pipeline that takes the raw data and outputs numerical input features that we can feed to any Machine Learning model we want."
]
},
{
"cell_type": "code",
"execution_count": 118,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[-0.56573582, 0.43279337, -0.47367361, ..., 0. ,\n",
" 0. , 1. ],\n",
" [ 0.6638609 , 0.43279337, -0.47367361, ..., 1. ,\n",
" 0. , 0. ],\n",
" [-0.25833664, -0.4745452 , -0.47367361, ..., 0. ,\n",
" 0. , 1. ],\n",
" ...,\n",
" [-0.10463705, 0.43279337, 2.00893337, ..., 0. ,\n",
" 0. , 1. ],\n",
" [-0.25833664, -0.4745452 , -0.47367361, ..., 1. ,\n",
" 0. , 0. ],\n",
" [ 0.20276213, -0.4745452 , -0.47367361, ..., 0. ,\n",
" 1. , 0. ]])"
]
},
"execution_count": 118,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train = preprocess_pipeline.fit_transform(train_data)\n",
"X_train"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's not forget to get the labels:"
]
},
{
"cell_type": "code",
"execution_count": 119,
"metadata": {},
"outputs": [],
"source": [
"y_train = train_data[\"Survived\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We are now ready to train a classifier. Let's start with a `RandomForestClassifier`:"
]
},
{
"cell_type": "code",
"execution_count": 120,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"RandomForestClassifier(random_state=42)"
]
},
"execution_count": 120,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"forest_clf = RandomForestClassifier(n_estimators=100, random_state=42)\n",
"forest_clf.fit(X_train, y_train)"
2017-10-04 10:57:40 +02:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Great, our model is trained, let's use it to make predictions on the test set:"
]
},
{
"cell_type": "code",
"execution_count": 121,
"metadata": {},
"outputs": [],
"source": [
"X_test = preprocess_pipeline.transform(test_data)\n",
"y_pred = forest_clf.predict(X_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And now we could just build a CSV file with these predictions (respecting the format expected by Kaggle), then upload it and hope for the best. But wait! We can do better than hope. Why don't we use cross-validation to have an idea of how good our model is?"
]
},
{
"cell_type": "code",
"execution_count": 122,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8137578027465668"
]
},
"execution_count": 122,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"forest_scores = cross_val_score(forest_clf, X_train, y_train, cv=10)\n",
"forest_scores.mean()"
2017-10-04 10:57:40 +02:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Okay, not too bad! Looking at the [leaderboard](https://www.kaggle.com/c/titanic/leaderboard) for the Titanic competition on Kaggle, you can see that our score is in the top 2%, woohoo! Some Kagglers reached 100% accuracy, but since you can easily find the [list of victims](https://www.encyclopedia-titanica.org/titanic-victims/) of the Titanic, it seems likely that there was little Machine Learning involved in their performance! 😆"
2017-10-04 10:57:40 +02:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's try an `SVC`:"
]
},
{
"cell_type": "code",
"execution_count": 123,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8249313358302123"
]
},
"execution_count": 123,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.svm import SVC\n",
"\n",
"svm_clf = SVC(gamma=\"auto\")\n",
"svm_scores = cross_val_score(svm_clf, X_train, y_train, cv=10)\n",
"svm_scores.mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Great! This model looks better."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"But instead of just looking at the mean accuracy across the 10 cross-validation folds, let's plot all 10 scores for each model, along with a box plot highlighting the lower and upper quartiles, and \"whiskers\" showing the extent of the scores (thanks to Nevin Yilmaz for suggesting this visualization). Note that the `boxplot()` function detects outliers (called \"fliers\") and does not include them within the whiskers. Specifically, if the lower quartile is $Q_1$ and the upper quartile is $Q_3$, then the interquartile range $IQR = Q_3 - Q_1$ (this is the box's height), and any score lower than $Q_1 - 1.5 \\times IQR$ is a flier, and so is any score greater than $Q3 + 1.5 \\times IQR$."
2018-05-26 15:01:33 +02:00
]
},
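{
"cell_type": "markdown",
"metadata": {},
"source": [
"The flier rule above can be sketched in a few lines of NumPy. The `scores` array below is made up for illustration (it is not the output of `cross_val_score()`), and the quartiles are computed with `np.percentile()` using its default linear interpolation, which is an assumption about how you want to define them:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical cross-validation scores, made up for illustration\n",
"scores = np.array([0.70, 0.80, 0.81, 0.82, 0.82, 0.83, 0.83, 0.84, 0.85, 0.95])\n",
"\n",
"# Lower and upper quartiles, then the interquartile range\n",
"q1, q3 = np.percentile(scores, [25, 75])\n",
"iqr = q3 - q1\n",
"\n",
"# Any score outside [Q1 - 1.5 IQR, Q3 + 1.5 IQR] is a flier\n",
"lower_bound = q1 - 1.5 * iqr\n",
"upper_bound = q3 + 1.5 * iqr\n",
"fliers = scores[(scores < lower_bound) | (scores > upper_bound)]\n",
"fliers"
]
},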
{
"cell_type": "code",
"execution_count": 124,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfwAAAD4CAYAAAAJtFSxAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAdYklEQVR4nO3df5QddZ3m8feT7vQKAjEJPaghP8gMG8iqRHI3NDoIgiMwozDqriZGB7Inh5M9ZGTcGRdQ58gedPyJLk6yYowIOgFkECW4KDiATNTpJLdNMCSZuD1NOomg2yTtqKxjp9Of/aOq4aa5HSrp23Vv33pe59xz68e36n6SdOe59a2qbykiMDMzs+Y2qd4FmJmZ2fhz4JuZmRWAA9/MzKwAHPhmZmYF4MA3MzMrgNZ6FzCeTj755JgzZ069yzAzM8tFV1fXMxHRXm1dUwf+nDlzKJfL9S7DzMwsF5J6R1vnLn0zM7MCcOCbmZkVgAPfzMysABz4ZmZmBeDANzMzKwAHvpmZWQE48M3MLLF3E2y4KXm3ptPU9+GbmVlGezfB7ZfBoQFoaYMr1sPMRfWuymrIgW9mViCSsjX863OOuDoialCN5cld+mZmBRIR1V97NhI3npK0ufGUZH60tg77CcmBb2ZmSff9FeuTaXfnNyUHvpmZJYZD3mHflBz4ZmZmBeDANzMzKwAHvpmZWQE48M3MzArAgW9mZlYAuQa+pEsk7ZLULem6KuunSLpf0uOStktaVrHu/emyJyTdKekledZuZmY2keUW+JJagNXApcB8YImk+SOaXQ3siIizgAuAmyS1SZoBvA8oRcSrgBZgcV61m5mZTXR5HuEvArojoiciBoC7gMtHtAngRCVjP54AHAAG03WtwHGSWoHjgafyKdvMzGziyzPwZwB7K+b3pcsqrQLOJAnzbcA1ETEUET8DPgPsAZ4G/jUiHqr2IZKuklSWVO7r66v1n8HMzGxCyjPwqz2xYeSAzBcDW4FXAguAVZJOkjSVpDfgtHTdSyW9p9qHRMSaiChFRKm9vb1WtZuZmU1oeQb+PmBmxfypvLBbfhlwbyS6gSeBM4A3AU9GRF9EHATuBV6XQ81mZmZNIc/A3wycLuk0SW0kF92tH9FmD3ARgKRTgHlAT7q8Q9Lx6fn9i4CduVVuuerq7Wf1o9109fbXuxQzs6bRmtcHRcSgpJXAgyRX2d8aEdslrUjX3wLcCNwmaRvJKYBrI+IZ4BlJ9wA/JrmIbwuwJq/aLT9dvf0sXdvJwOAQba2TWLe8g4Wzp9a7LDOzCU/N/FzjUqkU5XK53mVYFUlHzdg188+vWT1I8u/VBCapKyJK1dZ5pD2ri4io+irvPsC8Dz8AwLwPP0B594FR2/o/JTOz7Bz41lAWzp7KuuUdAO7ONzOrIQe+NZzhkHfYm5nVjgPfzMysABz4ZmZmBeDANzMzKwAHvpmZWQE48M3MzArAgW9mZlYADnwzM7MCcOCbmZkVgAPfzMysABz4ZmZmBeDANzMzKwAHvpmZWQE48M3MzArAgW9mZlYADnwzM7MCcOCbmZkVgAPfzMysAHINfEmXSNolqVvSdVXWT5F0v6THJW2XtKxi3csk3SPpnyXtlHRunrWbmZlNZLkFvqQWYDVwKTAfWCJp/ohmVwM7IuIs4ALgJklt6bqbge9GxBnAWcDOXAo3MzNrAnke4S8CuiOiJyIGgLuAy0e0CeBESQJOAA4Ag5JOAt4AfBkgIgYi4pe5VW5mZjbB5Rn4M4C9FfP70mWVVgFnAk8B24BrImIImAv0AV+RtEXSWkkvzaFmMzOzppBn4KvKshgxfzGwFXglsABYlR7dtwJnA1+IiNcCzwIvuAYAQNJVksqSyn19fTUq3czMbGLLM/D3ATMr5k8lOZKvtAy4NxLdwJPAGem2+yJiY9ruHpIvAC8QEWsiohQRpfb29pr+AczMzCaqPAN/M3C6pNPSC/EWA+tHtNkDXAQg6R
RgHtATET8H9kqal7a7CNiRT9lmZmYTX2teHxQRg5JWAg8CLcCtEbFd0op0/S3AjcBtkraRnAK4NiKeSXfx58C69MtCD0lvgJmZmWWgiJGn0ZtHqVSKcrlc7zLsGEiimX82zRqVf/cmNkldEVGqts4j7ZmZmRWAA9/MzKwAHPjWcO7YuOewdzMzGzsHvjWUOzbu4YPf3AbAB7+5zaFvZlYjDnxrKN954ukjzpuZ2bFx4FtDufRVrzjivJmZHZvc7sM3y+Ld58wCYOkn4W/e9urn5s3MbGx8hG8NZzjkHfZmZrXjwLeamzZtGpLG9ALGvI9p06bV+W/CzKxxuEvfaq6/v78hRuoa/uJgZmY+wjczMysEB76ZmVkBOPDNzMwKwIFvZmZWAA58MzOzAnDgm5mZFYAD38zMrAAc+GZmZgXgwDczMysAB741nK7eflY/2k1Xb3+9SzErlvJth79bU/HQutZQunr7Wbq2k4HBIdpaJ7FueQcLZ0+td1lmza98G3z7mmR6+L10Zb2qsXGQ6xG+pEsk7ZLULem6KuunSLpf0uOStktaNmJ9i6Qtkr6dX9WWp86e/QwMDjEUcHBwiM6e/fUuyawYdt535Hmb8HILfEktwGrgUmA+sETS/BHNrgZ2RMRZwAXATZLaKtZfA+zMoVyrk46502lrnUSLYHLrJDrmTq93SWbFcOblR563CS9T4Ev60zSwx2IR0B0RPRExANwFjPyJCuBEJY85OwE4AAymNZwK/Amwdox1WANbOHsq65Z38N/ePM/d+WZ5Kl0Jb7k5mX7Lze7Ob0JZz+GvA34t6Xbg1ojYdQyfNQPYWzG/DzhnRJtVwHrgKeBE4F0RMZSu+5/Af0+XWxNbOHuqg96sHkpXAssc9k0qa5f+y4GPAOcDOyT9QNIySS89is+q9nDykQ9NvxjYCrwSWACsknSSpLcA/zciul70Q6SrJJUllfv6+o6iPDOziW/atGlIOuYXMKbtJTFt2rQ6/y1YNZkCPyJ+HRFfjIgO4NXARuDjwNOSviSpI8Nu9gEzK+ZPJTmSr7QMuDcS3cCTwBnA64HLJO0mORVwoaS/G6XWNRFRiohSe3t7lj+emVnT6O/vJyLq+urv9y21jeioL9qLiB3A54A1QBvwLmCDpI2SXnOETTcDp0s6Lb0QbzFJ932lPcBFAJJOAeYBPRFxfUScGhFz0u0eiYj3HG3tZmZmRZU58CVNlvROSd8lOfK+EFgBnALMBn4KfH207SNiEFgJPEhypf3dEbFd0gpJK9JmNwKvk7QNeBi4NiKeOYY/l5mZmVVQxMjT6FUaSX8LLCE55/41YG16pF/ZZhawOyIaZvS+UqkU5XK53mUUjiSy/FwVpQ6zPDXCz30j1FBUkroiolRtXdar9OeTHJ3fm95SV81TwBuPoT4zMzMbZ5kCPyIuytBmEHhszBWZmZlZzWUdeOdjFefZK5evkHRj7csyMzOzWsp6vv29wJYqy7uAP6tdOWZmZjYesgb+7wHVRrHZT3KVvpmZmTWwrIG/BzivyvI3kAyoY1YzXb39rH60m65eD95hlqvybfC1tyXv1nSyXqX/ReBz6YA5j6TLLiIZbe+T41GYFVNXbz9L13YyMDhEW+skP0DHLC/l2+Db1yTT/5L+N+8x9ZtK1qF1byIJ/c+TDLDzU+Bm4EsR8anxK8+KprNnPwODQwwFHBwcorNnf71LMiuGnfcded4mvMyD5ETE9cDJQAdwLtAeEdeNV2FWTB1zp9PWOokWweTWSXTMnV7vksyK4czLjzxvE17WLn0AIuJZkjHxzcbFwtlTWbe8g86e/XTMne7ufLO8DHff77wvCXt35zedTEPrAkh6I8nwurNIHprznIi4sPaljZ2H1q2PRhlWs1HqMMtTI/zcN0INRXWkoXWzDrxzJfAd4ETgApJb9KYCZwM7Rt3QzMzMGkLWc/h/BayMiCXAQeD6iHgt8HfAb8arODMzM6uNrIE/F/iHdPp3wAnp9C
rgyhrXZGZmZjWWNfD3k3TnA/wMeFU6PR04rtZFmZmZWW1lvUp/A/BmYBtwN/B5SX9EMvjO98apNjMzM6uRrIG/Enh
"text/plain": [
"<Figure size 576x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(8, 4))\n",
"plt.plot([1]*10, svm_scores, \".\")\n",
"plt.plot([2]*10, forest_scores, \".\")\n",
"plt.boxplot([svm_scores, forest_scores], labels=(\"SVM\", \"Random Forest\"))\n",
"plt.ylabel(\"Accuracy\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The random forest classifier got a very high score on one of the 10 folds, but overall it had a lower mean score, as well as a bigger spread, so it looks like the SVM classifier is more likely to generalize well."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To improve this result further, you could:\n",
"* Compare many more models and tune hyperparameters using cross validation and grid search,\n",
"* Do more feature engineering, for example:\n",
" * Try to convert numerical attributes to categorical attributes: for example, different age groups had very different survival rates (see below), so it may help to create an age bucket category and use it instead of the age. Similarly, it may be useful to have a special category for people traveling alone since only 30% of them survived (see below).\n",
" * Replace **SibSp** and **Parch** with their sum.\n",
" * Try to identify parts of names that correlate well with the **Survived** attribute.\n",
" * Use the **Cabin** column, for example take its first letter and treat it as a categorical attribute."
]
},
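{
"cell_type": "markdown",
"metadata": {},
"source": [
"The last two ideas could be sketched as follows. This is only a rough illustration on three rows taken from the training set shown earlier; the `Title` and `Deck` column names and the regular expression are assumptions, and you would still need to plug the new columns into the preprocessing pipeline:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Three rows in the Titanic format, for illustration only\n",
"toy = pd.DataFrame({\n",
"    \"Name\": [\"Braund, Mr. Owen Harris\", \"Heikkinen, Miss. Laina\",\n",
"             \"Futrelle, Mrs. Jacques Heath (Lily May Peel)\"],\n",
"    \"Cabin\": [None, None, \"C123\"],\n",
"})\n",
"\n",
"# Title: the text between the comma and the first period of the name\n",
"toy[\"Title\"] = toy[\"Name\"].str.extract(r\",\\s*([^.]+)\\.\", expand=False)\n",
"# Deck: first letter of the cabin number, or \"Unknown\" when it is missing\n",
"toy[\"Deck\"] = toy[\"Cabin\"].str[0].fillna(\"Unknown\")\n",
"toy[[\"Title\", \"Deck\"]]"
]
},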
{
"cell_type": "code",
"execution_count": 125,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Survived</th>\n",
" </tr>\n",
" <tr>\n",
" <th>AgeBucket</th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0.0</th>\n",
" <td>0.576923</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15.0</th>\n",
" <td>0.362745</td>\n",
" </tr>\n",
" <tr>\n",
" <th>30.0</th>\n",
" <td>0.423256</td>\n",
" </tr>\n",
" <tr>\n",
" <th>45.0</th>\n",
" <td>0.404494</td>\n",
" </tr>\n",
" <tr>\n",
" <th>60.0</th>\n",
" <td>0.240000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75.0</th>\n",
" <td>1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Survived\n",
"AgeBucket \n",
"0.0 0.576923\n",
"15.0 0.362745\n",
"30.0 0.423256\n",
"45.0 0.404494\n",
"60.0 0.240000\n",
"75.0 1.000000"
]
},
"execution_count": 125,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data[\"AgeBucket\"] = train_data[\"Age\"] // 15 * 15\n",
"train_data[[\"AgeBucket\", \"Survived\"]].groupby(['AgeBucket']).mean()"
]
},
{
"cell_type": "code",
"execution_count": 126,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Survived</th>\n",
" </tr>\n",
" <tr>\n",
" <th>RelativesOnboard</th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.303538</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.552795</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.578431</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.724138</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.200000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>0.136364</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>0.333333</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>0.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Survived\n",
"RelativesOnboard \n",
"0 0.303538\n",
"1 0.552795\n",
"2 0.578431\n",
"3 0.724138\n",
"4 0.200000\n",
"5 0.136364\n",
"6 0.333333\n",
"7 0.000000\n",
"10 0.000000"
]
},
"execution_count": 126,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data[\"RelativesOnboard\"] = train_data[\"SibSp\"] + train_data[\"Parch\"]\n",
"train_data[[\"RelativesOnboard\", \"Survived\"]].groupby(\n",
" ['RelativesOnboard']).mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Spam classifier"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Exercise: _Build a spam classifier (a more challenging exercise):_\n",
"\n",
"* _Download examples of spam and ham from [Apache SpamAssassin's public datasets](https://homl.info/spamassassin)._\n",
"* _Unzip the datasets and familiarize yourself with the data format._\n",
"* _Split the datasets into a training set and a test set._\n",
"* _Write a data preparation pipeline to convert each email into a feature vector. Your preparation pipeline should transform an email into a (sparse) vector that indicates the presence or absence of each possible word. For example, if all emails only ever contain four words, \"Hello,\" \"how,\" \"are,\" \"you,\" then the email \"Hello you Hello Hello you\" would be converted into a vector [1, 0, 0, 1] (meaning [“Hello\" is present, \"how\" is absent, \"are\" is absent, \"you\" is present]), or [3, 0, 0, 2] if you prefer to count the number of occurrences of each word._\n",
"\n",
"_You may want to add hyperparameters to your preparation pipeline to control whether or not to strip off email headers, convert each email to lowercase, remove punctuation, replace all URLs with \"URL,\" replace all numbers with \"NUMBER,\" or even perform _stemming_ (i.e., trim off word endings; there are Python libraries available to do this)._\n",
"\n",
"_Finally, try out several classifiers and see if you can build a great spam classifier, with both high recall and high precision._"
]
},
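{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before diving in, the vectorization example from the exercise statement can be checked with a few lines of plain Python. This is just a sketch of the idea with a fixed, hypothetical `vocabulary` list; a real pipeline would learn the vocabulary from the training set and output sparse vectors:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"# Fixed four-word vocabulary from the exercise statement\n",
"vocabulary = [\"Hello\", \"how\", \"are\", \"you\"]\n",
"email_text = \"Hello you Hello Hello you\"\n",
"\n",
"# Count each word, then read the counts in vocabulary order\n",
"word_counts = Counter(email_text.split())\n",
"count_vector = [word_counts[word] for word in vocabulary]\n",
"# Presence/absence is just the count clipped to 0 or 1\n",
"presence_vector = [int(count > 0) for count in count_vector]\n",
"\n",
"print(presence_vector)  # [1, 0, 0, 1]\n",
"print(count_vector)     # [3, 0, 0, 2]"
]
},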
{
"cell_type": "code",
"execution_count": 127,
"metadata": {},
"outputs": [],
"source": [
"import tarfile\n",
"\n",
"def fetch_spam_data():\n",
" spam_root = \"http://spamassassin.apache.org/old/publiccorpus/\"\n",
" ham_url = spam_root + \"20030228_easy_ham.tar.bz2\"\n",
" spam_url = spam_root + \"20030228_spam.tar.bz2\"\n",
"\n",
" spam_path = Path() / \"datasets\" / \"spam\"\n",
" spam_path.mkdir(parents=True, exist_ok=True)\n",
" for dir_name, tar_name, url in ((\"easy_ham\", \"ham\", ham_url),\n",
" (\"spam\", \"spam\", spam_url)):\n",
" if not (spam_path / dir_name).is_dir():\n",
" path = (spam_path / tar_name).with_suffix(\".tar.bz2\")\n",
" print(\"Downloading\", path)\n",
" urllib.request.urlretrieve(url, path)\n",
" tar_bz2_file = tarfile.open(path)\n",
" tar_bz2_file.extractall(path=spam_path)\n",
" tar_bz2_file.close()\n",
" return [spam_path / dir_name for dir_name in (\"easy_ham\", \"spam\")]"
]
},
{
"cell_type": "code",
"execution_count": 128,
"metadata": {},
"outputs": [],
"source": [
"ham_dir, spam_dir = fetch_spam_data()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, let's load all the emails:"
]
},
{
"cell_type": "code",
"execution_count": 129,
"metadata": {},
"outputs": [],
"source": [
"ham_filenames = [f for f in sorted(ham_dir.iterdir()) if len(f.name) > 20]\n",
"spam_filenames = [f for f in sorted(spam_dir.iterdir()) if len(f.name) > 20]"
]
},
{
"cell_type": "code",
"execution_count": 130,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2500"
]
},
"execution_count": 130,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(ham_filenames)"
]
},
{
"cell_type": "code",
"execution_count": 131,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"500"
]
},
"execution_count": 131,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(spam_filenames)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can use Python's `email` module to parse these emails (this handles headers, encoding, and so on):"
]
},
{
"cell_type": "code",
"execution_count": 132,
"metadata": {},
"outputs": [],
"source": [
"import email\n",
"import email.policy\n",
"\n",
"def load_email(filepath):\n",
" with open(filepath, \"rb\") as f:\n",
" return email.parser.BytesParser(policy=email.policy.default).parse(f)"
]
},
{
"cell_type": "code",
"execution_count": 133,
"metadata": {},
"outputs": [],
"source": [
"ham_emails = [load_email(filepath) for filepath in ham_filenames]\n",
"spam_emails = [load_email(filepath) for filepath in spam_filenames]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's look at one example of ham and one example of spam, to get a feel of what the data looks like:"
]
},
{
"cell_type": "code",
"execution_count": 134,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Martin A posted:\n",
"Tassos Papadopoulos, the Greek sculptor behind the plan, judged that the\n",
" limestone of Mount Kerdylio, 70 miles east of Salonika and not far from the\n",
" Mount Athos monastic community, was ideal for the patriotic sculpture. \n",
" \n",
" As well as Alexander's granite features, 240 ft high and 170 ft wide, a\n",
" museum, a restored amphitheatre and car park for admiring crowds are\n",
"planned\n",
"---------------------\n",
"So is this mountain limestone or granite?\n",
"If it's limestone, it'll weather pretty fast.\n",
"\n",
"------------------------ Yahoo! Groups Sponsor ---------------------~-->\n",
"4 DVDs Free +s&p Join Now\n",
"http://us.click.yahoo.com/pt6YBB/NXiEAA/mG3HAA/7gSolB/TM\n",
"---------------------------------------------------------------------~->\n",
"\n",
"To unsubscribe from this group, send an email to:\n",
"forteana-unsubscribe@egroups.com\n",
"\n",
" \n",
"\n",
"Your use of Yahoo! Groups is subject to http://docs.yahoo.com/info/terms/\n"
]
}
],
"source": [
"print(ham_emails[1].get_content().strip())"
]
},
{
"cell_type": "code",
"execution_count": 135,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Help wanted. We are a 14 year old fortune 500 company, that is\n",
"growing at a tremendous rate. We are looking for individuals who\n",
"want to work from home.\n",
"\n",
"This is an opportunity to make an excellent income. No experience\n",
"is required. We will train you.\n",
"\n",
"So if you are looking to be employed from home with a career that has\n",
"vast opportunities, then go:\n",
"\n",
"http://www.basetel.com/wealthnow\n",
"\n",
"We are looking for energetic and self motivated people. If that is you\n",
"than click on the link and fill out the form, and one of our\n",
"employement specialist will contact you.\n",
"\n",
"To be removed from our link simple go to:\n",
"\n",
"http://www.basetel.com/remove.html\n",
"\n",
"\n",
"4139vOLW7-758DoDY1425FRhM1-764SMFc8513fCsLl40\n"
]
}
],
"source": [
"print(spam_emails[6].get_content().strip())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Some emails are actually multipart, with images and attachments (which can have their own attachments). Let's look at the various types of structures we have:"
]
},
{
"cell_type": "code",
"execution_count": 136,
"metadata": {},
"outputs": [],
"source": [
"def get_email_structure(email):\n",
" if isinstance(email, str):\n",
" return email\n",
" payload = email.get_payload()\n",
" if isinstance(payload, list):\n",
2021-11-21 22:17:46 +01:00
" multipart = \", \".join([get_email_structure(sub_email)\n",
" for sub_email in payload])\n",
" return f\"multipart({multipart})\"\n",
" else:\n",
" return email.get_content_type()"
]
},
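{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the nested structures more concrete, here is a minimal sketch that builds a small multipart message by hand with the standard library's `email.message.EmailMessage`. The message, its subject and its bodies are made up for illustration and are not part of the dataset. Walking the message lists the container first, then its parts. Passing this message to `get_email_structure()` should give `multipart(text/plain, text/html)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from email.message import EmailMessage\n",
"\n",
"# Toy message, made up for illustration (not from the dataset)\n",
"toy_email = EmailMessage()\n",
"toy_email[\"Subject\"] = \"Toy multipart example\"\n",
"toy_email.set_content(\"Plain-text body\")  # the message is text/plain at this point\n",
"# Adding an alternative turns the message into a multipart/alternative container\n",
"toy_email.add_alternative(\"<p>HTML body</p>\", subtype=\"html\")\n",
"\n",
"# walk() yields the container first, then each part in order\n",
"[part.get_content_type() for part in toy_email.walk()]"
]
},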
{
"cell_type": "code",
"execution_count": 137,
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
2017-10-04 10:57:40 +02:00
"\n",
"def structures_counter(emails):\n",
" structures = Counter()\n",
" for email in emails:\n",
" structure = get_email_structure(email)\n",
" structures[structure] += 1\n",
" return structures"
]
},
{
"cell_type": "code",
"execution_count": 138,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('text/plain', 2408),\n",
" ('multipart(text/plain, application/pgp-signature)', 66),\n",
" ('multipart(text/plain, text/html)', 8),\n",
" ('multipart(text/plain, text/plain)', 4),\n",
" ('multipart(text/plain)', 3),\n",
" ('multipart(text/plain, application/octet-stream)', 2),\n",
" ('multipart(text/plain, text/enriched)', 1),\n",
" ('multipart(text/plain, application/ms-tnef, text/plain)', 1),\n",
" ('multipart(multipart(text/plain, text/plain, text/plain), application/pgp-signature)',\n",
" 1),\n",
" ('multipart(text/plain, video/mng)', 1),\n",
" ('multipart(text/plain, multipart(text/plain))', 1),\n",
" ('multipart(text/plain, application/x-pkcs7-signature)', 1),\n",
" ('multipart(text/plain, multipart(text/plain, text/plain), text/rfc822-headers)',\n",
" 1),\n",
" ('multipart(text/plain, multipart(text/plain, text/plain), multipart(multipart(text/plain, application/x-pkcs7-signature)))',\n",
" 1),\n",
" ('multipart(text/plain, application/x-java-applet)', 1)]"
]
},
"execution_count": 138,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"structures_counter(ham_emails).most_common()"
]
},
{
"cell_type": "code",
"execution_count": 139,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('text/plain', 218),\n",
" ('text/html', 183),\n",
" ('multipart(text/plain, text/html)', 45),\n",
" ('multipart(text/html)', 20),\n",
" ('multipart(text/plain)', 19),\n",
" ('multipart(multipart(text/html))', 5),\n",
" ('multipart(text/plain, image/jpeg)', 3),\n",
" ('multipart(text/html, application/octet-stream)', 2),\n",
" ('multipart(text/plain, application/octet-stream)', 1),\n",
" ('multipart(text/html, text/plain)', 1),\n",
" ('multipart(multipart(text/html), application/octet-stream, image/jpeg)', 1),\n",
" ('multipart(multipart(text/plain, text/html), image/gif)', 1),\n",
" ('multipart/alternative', 1)]"
]
},
"execution_count": 139,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"structures_counter(spam_emails).most_common()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It seems that the ham emails are more often plain text, while spam has quite a lot of HTML. Moreover, quite a few ham emails are signed using PGP, while no spam is. In short, it seems that the email structure is useful information to have."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's take a look at the email headers:"
]
},
{
"cell_type": "code",
"execution_count": 140,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Return-Path : <12a1mailbot1@web.de>\n",
"Delivered-To : zzzz@localhost.spamassassin.taint.org\n",
"Received : from localhost (localhost [127.0.0.1])\tby phobos.labs.spamassassin.taint.org (Postfix) with ESMTP id 136B943C32\tfor <zzzz@localhost>; Thu, 22 Aug 2002 08:17:21 -0400 (EDT)\n",
"Received : from mail.webnote.net [193.120.211.219]\tby localhost with POP3 (fetchmail-5.9.0)\tfor zzzz@localhost (single-drop); Thu, 22 Aug 2002 13:17:21 +0100 (IST)\n",
"Received : from dd_it7 ([210.97.77.167])\tby webnote.net (8.9.3/8.9.3) with ESMTP id NAA04623\tfor <zzzz@spamassassin.taint.org>; Thu, 22 Aug 2002 13:09:41 +0100\n",
"From : 12a1mailbot1@web.de\n",
"Received : from r-smtp.korea.com - 203.122.2.197 by dd_it7 with Microsoft SMTPSVC(5.5.1775.675.6);\t Sat, 24 Aug 2002 09:42:10 +0900\n",
"To : dcek1a1@netsgo.com\n",
"Subject : Life Insurance - Why Pay More?\n",
"Date : Wed, 21 Aug 2002 20:31:57 -1600\n",
"MIME-Version : 1.0\n",
"Message-ID : <0103c1042001882DD_IT7@dd_it7>\n",
"Content-Type : text/html; charset=\"iso-8859-1\"\n",
"Content-Transfer-Encoding : quoted-printable\n"
]
}
],
"source": [
"for header, value in spam_emails[0].items():\n",
2021-12-08 03:16:42 +01:00
" print(header, \":\", value)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There's probably a lot of useful information in there, such as the sender's email address (12a1mailbot1@web.de looks fishy), but we will just focus on the `Subject` header:"
]
},
{
"cell_type": "code",
"execution_count": 141,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Life Insurance - Why Pay More?'"
]
},
"execution_count": 141,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"spam_emails[0][\"Subject\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Okay, before we learn too much about the data, let's not forget to split it into a training set and a test set:"
]
},
{
"cell_type": "code",
"execution_count": 142,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"X = np.array(ham_emails + spam_emails, dtype=object)\n",
"y = np.array([0] * len(ham_emails) + [1] * len(spam_emails))\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,\n",
"                                                    random_state=42)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Okay, let's start writing the preprocessing functions. First, we will need a function to convert HTML to plain text. Arguably the best way to do this would be to use the great [BeautifulSoup](https://www.crummy.com/software/BeautifulSoup/) library, but I would like to avoid adding another dependency to this project, so let's hack a quick & dirty solution using regular expressions (at the risk of [un̨ho͞ly radiańcé destro҉ying all enli̍̈́̂̈́ghtenment](https://stackoverflow.com/a/1732454/38626)). The following function first drops the `<head>` section, then converts all `<a>` tags to the word HYPERLINK, then it gets rid of all HTML tags, leaving only the plain text. For readability, it also replaces multiple newlines with single newlines, and finally it unescapes html entities (such as `&gt;` or `&nbsp;`):"
]
},
{
"cell_type": "code",
"execution_count": 143,
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"from html import unescape\n",
"\n",
"def html_to_plain_text(html):\n",
" text = re.sub('<head.*?>.*?</head>', '', html, flags=re.M | re.S | re.I)\n",
" text = re.sub('<a\\s.*?>', ' HYPERLINK ', text, flags=re.M | re.S | re.I)\n",
" text = re.sub('<.*?>', '', text, flags=re.M | re.S)\n",
" text = re.sub(r'(\\s*\\n)+', '\\n', text, flags=re.M | re.S)\n",
" return unescape(text)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's see if it works. This is HTML spam:"
]
},
{
"cell_type": "code",
"execution_count": 144,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<HTML><HEAD><TITLE></TITLE><META http-equiv=\"Content-Type\" content=\"text/html; charset=windows-1252\"><STYLE>A:link {TEX-DECORATION: none}A:active {TEXT-DECORATION: none}A:visited {TEXT-DECORATION: none}A:hover {COLOR: #0033ff; TEXT-DECORATION: underline}</STYLE><META content=\"MSHTML 6.00.2713.1100\" name=\"GENERATOR\"></HEAD>\n",
"<BODY text=\"#000000\" vLink=\"#0033ff\" link=\"#0033ff\" bgColor=\"#CCCC99\"><TABLE borderColor=\"#660000\" cellSpacing=\"0\" cellPadding=\"0\" border=\"0\" width=\"100%\"><TR><TD bgColor=\"#CCCC99\" valign=\"top\" colspan=\"2\" height=\"27\">\n",
"<font size=\"6\" face=\"Arial, Helvetica, sans-serif\" color=\"#660000\">\n",
"<b>OTC</b></font></TD></TR><TR><TD height=\"2\" bgcolor=\"#6a694f\">\n",
"<font size=\"5\" face=\"Times New Roman, Times, serif\" color=\"#FFFFFF\">\n",
"<b>&nbsp;Newsletter</b></font></TD><TD height=\"2\" bgcolor=\"#6a694f\"><div align=\"right\"><font color=\"#FFFFFF\">\n",
"<b>Discover Tomorrow's Winners&nbsp;</b></font></div></TD></TR><TR><TD height=\"25\" colspan=\"2\" bgcolor=\"#CCCC99\"><table width=\"100%\" border=\"0\" ...\n"
]
}
],
"source": [
"html_spam_emails = [email for email in X_train[y_train==1]\n",
" if get_email_structure(email) == \"text/html\"]\n",
"sample_html_spam = html_spam_emails[7]\n",
"print(sample_html_spam.get_content().strip()[:1000], \"...\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And this is the resulting plain text:"
]
},
{
"cell_type": "code",
"execution_count": 145,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"OTC\n",
" Newsletter\n",
"Discover Tomorrow's Winners \n",
"For Immediate Release\n",
"Cal-Bay (Stock Symbol: CBYI)\n",
"Watch for analyst \"Strong Buy Recommendations\" and several advisory newsletters picking CBYI. CBYI has filed to be traded on the OTCBB, share prices historically INCREASE when companies get listed on this larger trading exchange. CBYI is trading around 25 cents and should skyrocket to $2.66 - $3.25 a share in the near future.\n",
"Put CBYI on your watch list, acquire a position TODAY.\n",
"REASONS TO INVEST IN CBYI\n",
"A profitable company and is on track to beat ALL earnings estimates!\n",
"One of the FASTEST growing distributors in environmental & safety equipment instruments.\n",
"Excellent management team, several EXCLUSIVE contracts. IMPRESSIVE client list including the U.S. Air Force, Anheuser-Busch, Chevron Refining and Mitsubishi Heavy Industries, GE-Energy & Environmental Research.\n",
"RAPIDLY GROWING INDUSTRY\n",
"Industry revenues exceed $900 million, estimates indicate that there could be as much as $25 billi ...\n"
]
}
],
"source": [
"print(html_to_plain_text(sample_html_spam.get_content())[:1000], \"...\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Great! Now let's write a function that takes an email as input and returns its content as plain text, whatever its format is:"
]
},
{
"cell_type": "code",
"execution_count": 146,
"metadata": {},
"outputs": [],
"source": [
"def email_to_text(email):\n",
" html = None\n",
" for part in email.walk():\n",
" ctype = part.get_content_type()\n",
" if not ctype in (\"text/plain\", \"text/html\"):\n",
" continue\n",
" try:\n",
" content = part.get_content()\n",
" except: # in case of encoding issues\n",
" content = str(part.get_payload())\n",
" if ctype == \"text/plain\":\n",
" return content\n",
" else:\n",
" html = content\n",
" if html:\n",
" return html_to_plain_text(html)"
]
},
{
"cell_type": "code",
"execution_count": 147,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"OTC\n",
" Newsletter\n",
"Discover Tomorrow's Winners \n",
"For Immediate Release\n",
"Cal-Bay (Stock Symbol: CBYI)\n",
"Wat ...\n"
]
}
],
"source": [
"print(email_to_text(sample_html_spam)[:100], \"...\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's throw in some stemming! We will use the Natural Language Toolkit ([NLTK](http://www.nltk.org/)):"
]
},
{
"cell_type": "code",
"execution_count": 148,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Computations => comput\n",
"Computation => comput\n",
"Computing => comput\n",
"Computed => comput\n",
"Compute => comput\n",
"Compulsive => compuls\n"
]
}
],
"source": [
"import nltk\n",
"\n",
"stemmer = nltk.PorterStemmer()\n",
"for word in (\"Computations\", \"Computation\", \"Computing\", \"Computed\", \"Compute\",\n",
" \"Compulsive\"):\n",
" print(word, \"=>\", stemmer.stem(word))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will also need a way to replace URLs with the word \"URL\". For this, we could use hard core [regular expressions](https://mathiasbynens.be/demo/url-regex) but we will just use the [urlextract](https://github.com/lipoja/URLExtract) library:"
]
},
{
"cell_type": "code",
"execution_count": 149,
"metadata": {},
"outputs": [],
"source": [
"# Is this notebook running on Colab or Kaggle?\n",
"IS_COLAB = \"google.colab\" in sys.modules\n",
"IS_KAGGLE = \"kaggle_secrets\" in sys.modules\n",
"\n",
"# if running this notebook on Colab or Kaggle, we just pip install urlextract\n",
"if IS_COLAB or IS_KAGGLE:\n",
"    %pip install -q -U urlextract"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Note:** inside a Jupyter notebook, always use `%pip` instead of `!pip`, as `!pip` may install the library inside the wrong environment, while `%pip` makes sure it's installed inside the currently running environment."
]
},
{
"cell_type": "code",
"execution_count": 150,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['github.com', 'https://youtu.be/7Pq-S557XQU?t=3m32s']\n"
]
}
],
"source": [
"import urlextract # may require an Internet connection to download root domain\n",
" # names\n",
"\n",
"url_extractor = urlextract.URLExtract()\n",
"some_text = \"Will it detect github.com and https://youtu.be/7Pq-S557XQU?t=3m32s\"\n",
"print(url_extractor.find_urls(some_text))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We are ready to put all this together into a transformer that we will use to convert emails to word counters. Note that we split sentences into words using Python's `split()` method, which uses whitespaces for word boundaries. This works for many written languages, but not all. For example, Chinese and Japanese scripts generally don't use spaces between words, and Vietnamese often uses spaces even between syllables. It's okay in this exercise, because the dataset is (mostly) in English."
]
},
{
"cell_type": "code",
"execution_count": 151,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.base import BaseEstimator, TransformerMixin\n",
"\n",
"class EmailToWordCounterTransformer(BaseEstimator, TransformerMixin):\n",
" def __init__(self, strip_headers=True, lower_case=True,\n",
" remove_punctuation=True, replace_urls=True,\n",
" replace_numbers=True, stemming=True):\n",
" self.strip_headers = strip_headers\n",
" self.lower_case = lower_case\n",
" self.remove_punctuation = remove_punctuation\n",
" self.replace_urls = replace_urls\n",
" self.replace_numbers = replace_numbers\n",
" self.stemming = stemming\n",
" def fit(self, X, y=None):\n",
" return self\n",
" def transform(self, X, y=None):\n",
" X_transformed = []\n",
" for email in X:\n",
" text = email_to_text(email) or \"\"\n",
" if self.lower_case:\n",
" text = text.lower()\n",
" if self.replace_urls and url_extractor is not None:\n",
" urls = list(set(url_extractor.find_urls(text)))\n",
" urls.sort(key=lambda url: len(url), reverse=True)\n",
" for url in urls:\n",
" text = text.replace(url, \" URL \")\n",
" if self.replace_numbers:\n",
" text = re.sub(r'\\d+(?:\\.\\d*)?(?:[eE][+-]?\\d+)?', 'NUMBER', text)\n",
" if self.remove_punctuation:\n",
" text = re.sub(r'\\W+', ' ', text, flags=re.M)\n",
" word_counts = Counter(text.split())\n",
" if self.stemming and stemmer is not None:\n",
" stemmed_word_counts = Counter()\n",
" for word, count in word_counts.items():\n",
" stemmed_word = stemmer.stem(word)\n",
" stemmed_word_counts[stemmed_word] += count\n",
" word_counts = stemmed_word_counts\n",
" X_transformed.append(word_counts)\n",
" return np.array(X_transformed)"
]
},
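{
"cell_type": "markdown",
"metadata": {},
"source": [
"Two details of `transform()` are easy to miss, so here is a small self-contained sketch of them. The text and the URL list are made up for illustration: `toy_urls` merely stands in for what a URL extractor might return. First, URLs are replaced longest first, because replacing a short URL that is a prefix of a longer one would leave a dangling fragment behind. Second, the number pattern matches integers, decimals and exponent notation alike."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"\n",
"# Toy text and URL list, made up for illustration (not from the dataset)\n",
"toy_text = \"visit example.com or example.com/offers and save 12.5% on 3e5 items\"\n",
"toy_urls = [\"example.com\", \"example.com/offers\"]  # stands in for extractor output\n",
"\n",
"# Shortest first: the longer URL is only partly replaced, leaving \"/offers\"\n",
"naive = toy_text\n",
"for url in toy_urls:\n",
"    naive = naive.replace(url, \" URL \")\n",
"print(naive)\n",
"\n",
"# Longest first, as in the transformer: each URL is replaced as a whole\n",
"careful = toy_text\n",
"for url in sorted(set(toy_urls), key=len, reverse=True):\n",
"    careful = careful.replace(url, \" URL \")\n",
"print(careful)\n",
"\n",
"# Same number pattern as in the transformer\n",
"print(re.sub(r'\\d+(?:\\.\\d*)?(?:[eE][+-]?\\d+)?', 'NUMBER', careful))"
]
},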
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's try this transformer on a few emails:"
]
},
{
"cell_type": "code",
"execution_count": 152,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([Counter({'chuck': 1, 'murcko': 1, 'wrote': 1, 'stuff': 1, 'yawn': 1, 'r': 1}),\n",
" Counter({'the': 11, 'of': 9, 'and': 8, 'all': 3, 'christian': 3, 'to': 3, 'by': 3, 'jefferson': 2, 'i': 2, 'have': 2, 'superstit': 2, 'one': 2, 'on': 2, 'been': 2, 'ha': 2, 'half': 2, 'rogueri': 2, 'teach': 2, 'jesu': 2, 'some': 1, 'interest': 1, 'quot': 1, 'url': 1, 'thoma': 1, 'examin': 1, 'known': 1, 'word': 1, 'do': 1, 'not': 1, 'find': 1, 'in': 1, 'our': 1, 'particular': 1, 'redeem': 1, 'featur': 1, 'they': 1, 'are': 1, 'alik': 1, 'found': 1, 'fabl': 1, 'mytholog': 1, 'million': 1, 'innoc': 1, 'men': 1, 'women': 1, 'children': 1, 'sinc': 1, 'introduct': 1, 'burnt': 1, 'tortur': 1, 'fine': 1, 'imprison': 1, 'what': 1, 'effect': 1, 'thi': 1, 'coercion': 1, 'make': 1, 'world': 1, 'fool': 1, 'other': 1, 'hypocrit': 1, 'support': 1, 'error': 1, 'over': 1, 'earth': 1, 'six': 1, 'histor': 1, 'american': 1, 'john': 1, 'e': 1, 'remsburg': 1, 'letter': 1, 'william': 1, 'short': 1, 'again': 1, 'becom': 1, 'most': 1, 'pervert': 1, 'system': 1, 'that': 1, 'ever': 1, 'shone': 1, 'man': 1, 'absurd': 1, 'untruth': 1, 'were': 1, 'perpetr': 1, 'upon': 1, 'a': 1, 'larg': 1, 'band': 1, 'dupe': 1, 'import': 1, 'led': 1, 'paul': 1, 'first': 1, 'great': 1, 'corrupt': 1}),\n",
" Counter({'url': 4, 's': 3, 'group': 3, 'to': 3, 'in': 2, 'forteana': 2, 'martin': 2, 'an': 2, 'and': 2, 'we': 2, 'is': 2, 'yahoo': 2, 'unsubscrib': 2, 'y': 1, 'adamson': 1, 'wrote': 1, 'for': 1, 'altern': 1, 'rather': 1, 'more': 1, 'factual': 1, 'base': 1, 'rundown': 1, 'on': 1, 'hamza': 1, 'career': 1, 'includ': 1, 'hi': 1, 'belief': 1, 'that': 1, 'all': 1, 'non': 1, 'muslim': 1, 'yemen': 1, 'should': 1, 'be': 1, 'murder': 1, 'outright': 1, 'know': 1, 'how': 1, 'unbias': 1, 'memri': 1, 'don': 1, 't': 1, 'html': 1, 'rob': 1, 'sponsor': 1, 'number': 1, 'dvd': 1, 'free': 1, 'p': 1, 'join': 1, 'now': 1, 'from': 1, 'thi': 1, 'send': 1, 'email': 1, 'egroup': 1, 'com': 1, 'your': 1, 'use': 1, 'of': 1, 'subject': 1})],\n",
" dtype=object)"
]
},
"execution_count": 152,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_few = X_train[:3]\n",
"X_few_wordcounts = EmailToWordCounterTransformer().fit_transform(X_few)\n",
"X_few_wordcounts"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This looks about right!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we have the word counts, and we need to convert them to vectors. For this, we will build another transformer whose `fit()` method will build the vocabulary (an ordered list of the most common words) and whose `transform()` method will use the vocabulary to convert word counts to vectors. The output is a sparse matrix."
]
},
{
"cell_type": "code",
"execution_count": 153,
"metadata": {},
"outputs": [],
"source": [
"from scipy.sparse import csr_matrix\n",
"\n",
"class WordCounterToVectorTransformer(BaseEstimator, TransformerMixin):\n",
" def __init__(self, vocabulary_size=1000):\n",
" self.vocabulary_size = vocabulary_size\n",
" def fit(self, X, y=None):\n",
" total_count = Counter()\n",
" for word_count in X:\n",
" for word, count in word_count.items():\n",
" total_count[word] += min(count, 10)\n",
" most_common = total_count.most_common()[:self.vocabulary_size]\n",
" self.vocabulary_ = {word: index + 1\n",
" for index, (word, count) in enumerate(most_common)}\n",
" return self\n",
" def transform(self, X, y=None):\n",
" rows = []\n",
" cols = []\n",
" data = []\n",
" for row, word_count in enumerate(X):\n",
" for word, count in word_count.items():\n",
" rows.append(row)\n",
" cols.append(self.vocabulary_.get(word, 0))\n",
" data.append(count)\n",
" return csr_matrix((data, (rows, cols)),\n",
" shape=(len(X), self.vocabulary_size + 1))"
]
},
{
"cell_type": "code",
"execution_count": 154,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<3x11 sparse matrix of type '<class 'numpy.int64'>'\n",
"\twith 20 stored elements in Compressed Sparse Row format>"
]
},
"execution_count": 154,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vocab_transformer = WordCounterToVectorTransformer(vocabulary_size=10)\n",
"X_few_vectors = vocab_transformer.fit_transform(X_few_wordcounts)\n",
"X_few_vectors"
]
},
{
"cell_type": "code",
"execution_count": 155,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [99, 11, 9, 8, 3, 1, 3, 1, 3, 2, 3],\n",
" [67, 0, 1, 2, 3, 4, 1, 2, 0, 1, 0]])"
]
},
"execution_count": 155,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_few_vectors.toarray()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"What does this matrix mean? Well, the 99 in the second row, first column, means that the second email contains 99 words that are not part of the vocabulary. The 11 next to it means that the first word in the vocabulary is present 11 times in this email. The 9 next to it means that the second word is present 9 times, and so on. You can look at the vocabulary to know which words we are talking about. The first word is \"the\", the second word is \"of\", etc."
]
},
{
"cell_type": "code",
"execution_count": 156,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'the': 1,\n",
" 'of': 2,\n",
" 'and': 3,\n",
" 'to': 4,\n",
" 'url': 5,\n",
" 'all': 6,\n",
" 'in': 7,\n",
" 'christian': 8,\n",
" 'on': 9,\n",
" 'by': 10}"
]
},
"execution_count": 156,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vocab_transformer.vocabulary_"
]
},
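{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why does column 0 end up holding the total count of out-of-vocabulary words? Every unknown word is mapped to column 0, so a single email can produce several entries with the same (row, column) position, and `csr_matrix` sums such duplicate entries when it is built from `(data, (rows, cols))`. The sketch below reproduces this on toy word counts and a toy vocabulary, both made up for illustration and not taken from the dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"from scipy.sparse import csr_matrix\n",
"\n",
"# Toy data, made up for illustration (not from the dataset)\n",
"toy_vocabulary = {\"the\": 1, \"ham\": 2}  # column 0 is reserved for unknown words\n",
"toy_counts = [Counter({\"the\": 2, \"spam\": 1, \"offer\": 4}),\n",
"              Counter({\"the\": 1, \"ham\": 3})]\n",
"\n",
"rows, cols, data = [], [], []\n",
"for row, word_count in enumerate(toy_counts):\n",
"    for word, count in word_count.items():\n",
"        rows.append(row)\n",
"        cols.append(toy_vocabulary.get(word, 0))  # unknown words go to column 0\n",
"        data.append(count)\n",
"\n",
"# \"spam\" and \"offer\" both land on (row 0, column 0), so their counts add up to 5\n",
"toy_vectors = csr_matrix((data, (rows, cols)),\n",
"                         shape=(len(toy_counts), len(toy_vocabulary) + 1))\n",
"toy_vectors.toarray()  # rows: [5, 2, 0] and [0, 1, 3]"
]
},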
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We are now ready to train our first spam classifier! Let's transform the whole dataset:"
]
},
{
"cell_type": "code",
"execution_count": 157,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import Pipeline\n",
"\n",
"preprocess_pipeline = Pipeline([\n",
" (\"email_to_wordcount\", EmailToWordCounterTransformer()),\n",
" (\"wordcount_to_vector\", WordCounterToVectorTransformer()),\n",
"])\n",
"\n",
"X_train_transformed = preprocess_pipeline.fit_transform(X_train)"
]
},
{
"cell_type": "code",
"execution_count": 158,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.985"
]
},
"execution_count": 158,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.model_selection import cross_val_score\n",
"\n",
"log_clf = LogisticRegression(max_iter=1000, random_state=42)\n",
"score = cross_val_score(log_clf, X_train_transformed, y_train, cv=3)\n",
"score.mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Over 98.5%, not bad for a first try! :) However, remember that we are using the \"easy\" dataset. If you try the harder datasets, the results won't be as impressive: you would have to try multiple models, select the best ones, fine-tune them using cross-validation, and so on.\n",
"\n",
"But you get the picture, so let's stop now, and just print out the precision/recall we get on the test set:"
]
},
{
"cell_type": "code",
"execution_count": 159,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Precision: 96.88%\n",
"Recall: 97.89%\n"
]
}
],
"source": [
"from sklearn.metrics import precision_score, recall_score\n",
"\n",
"X_test_transformed = preprocess_pipeline.transform(X_test)\n",
"\n",
"log_clf = LogisticRegression(max_iter=1000, random_state=42)\n",
"log_clf.fit(X_train_transformed, y_train)\n",
"\n",
"y_pred = log_clf.predict(X_test_transformed)\n",
"\n",
"print(f\"Precision: {precision_score(y_test, y_pred):.2%}\")\n",
"print(f\"Recall: {recall_score(y_test, y_pred):.2%}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
},
"nav_menu": {},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}