{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Chapter 9 – Unsupervised Learning**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_This notebook contains all the sample code and solutions to the exercises in chapter 9._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table align=\"left\">\n",
" <td>\n",
" <a href=\"https://colab.research.google.com/github/ageron/handson-ml3/blob/main/09_unsupervised_learning.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/ageron/handson-ml3/blob/main/09_unsupervised_learning.ipynb\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" /></a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This project requires Python 3.8 or above:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"assert sys.version_info >= (3, 8)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It also requires Scikit-Learn ≥ 1.0.1:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import sklearn\n",
"\n",
"assert sklearn.__version__ >= \"1.0.1\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we did in previous chapters, let's define the default font sizes to make the figures prettier:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"plt.rc('font', size=14)\n",
"plt.rc('axes', labelsize=14, titlesize=14)\n",
"plt.rc('legend', fontsize=14)\n",
"plt.rc('xtick', labelsize=10)\n",
"plt.rc('ytick', labelsize=10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And let's create the `images/unsupervised_learning` folder (if it doesn't already exist), and define the `save_fig()` function, which is used throughout this notebook to save the figures in high resolution for the book:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"IMAGES_PATH = Path() / \"images\" / \"unsupervised_learning\"\n",
"IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
" if tight_layout:\n",
" plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Clustering"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Introduction – Classification _vs_ Clustering**"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# extra code – this cell generates and saves Figure 9–1\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.datasets import load_iris\n",
"\n",
"data = load_iris()\n",
"X = data.data\n",
"y = data.target\n",
"data.target_names\n",
"\n",
"plt.figure(figsize=(9, 3.5))\n",
"\n",
"plt.subplot(121)\n",
"plt.plot(X[y==0, 2], X[y==0, 3], \"yo\", label=\"Iris setosa\")\n",
"plt.plot(X[y==1, 2], X[y==1, 3], \"bs\", label=\"Iris versicolor\")\n",
"plt.plot(X[y==2, 2], X[y==2, 3], \"g^\", label=\"Iris virginica\")\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.grid()\n",
"plt.legend()\n",
"\n",
"plt.subplot(122)\n",
"plt.scatter(X[:, 2], X[:, 3], c=\"k\", marker=\".\")\n",
"plt.xlabel(\"Petal length\")\n",
"plt.tick_params(labelleft=False)\n",
"plt.gca().set_axisbelow(True)\n",
"plt.grid()\n",
"\n",
"save_fig(\"classification_vs_clustering_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Note**: the next cell shows how a Gaussian mixture model (explained later in this chapter) can actually separate these clusters pretty well using all 4 features: petal length & width, and sepal length & width. The code then maps each cluster to a class. Instead of hard-coding the mapping, it uses the `scipy.stats.mode()` function to find, for each class, the cluster that most of its instances were assigned to:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# extra code\n",
"\n",
"import numpy as np\n",
"from scipy import stats\n",
"from sklearn.mixture import GaussianMixture\n",
"\n",
"y_pred = GaussianMixture(n_components=3, random_state=42).fit(X).predict(X)\n",
"\n",
"mapping = {}\n",
"for class_id in np.unique(y):\n",
"    # mode() returns a 1-element array in older SciPy versions, a scalar in newer ones\n",
"    mode = np.atleast_1d(stats.mode(y_pred[y==class_id])[0])[0]\n",
"    mapping[mode] = class_id\n",
"\n",
"y_pred = np.array([mapping[cluster_id] for cluster_id in y_pred])\n",
"\n",
"plt.plot(X[y_pred==0, 2], X[y_pred==0, 3], \"yo\", label=\"Cluster 1\")\n",
"plt.plot(X[y_pred==1, 2], X[y_pred==1, 3], \"bs\", label=\"Cluster 2\")\n",
"plt.plot(X[y_pred==2, 2], X[y_pred==2, 3], \"g^\", label=\"Cluster 3\")\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.legend(loc=\"upper left\")\n",
"plt.grid()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"What's the ratio of iris plants we assigned to the right cluster?"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"(y_pred==y).sum() / len(y_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## K-Means"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Fit and predict**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's train a K-Means clusterer on a dataset of blobs. It will try to find each blob's center and assign each instance to the closest blob:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.cluster import KMeans\n",
"from sklearn.datasets import make_blobs\n",
"\n",
"# extra code – the exact arguments of make_blobs() are not important\n",
"blob_centers = np.array([[ 0.2, 2.3], [-1.5 , 2.3], [-2.8, 1.8],\n",
" [-2.8, 2.8], [-2.8, 1.3]])\n",
"blob_std = np.array([0.4, 0.3, 0.1, 0.1, 0.1])\n",
"X, y = make_blobs(n_samples=2000, centers=blob_centers, cluster_std=blob_std,\n",
" random_state=7)\n",
"\n",
"k = 5\n",
"kmeans = KMeans(n_clusters=k, random_state=42)\n",
"y_pred = kmeans.fit_predict(X)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's plot them:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# extra code – this cell generates and saves Figure 9–2\n",
"\n",
"def plot_clusters(X, y=None):\n",
" plt.scatter(X[:, 0], X[:, 1], c=y, s=1)\n",
" plt.xlabel(\"$x_1$\")\n",
" plt.ylabel(\"$x_2$\", rotation=0)\n",
"\n",
"plt.figure(figsize=(8, 4))\n",
"plot_clusters(X)\n",
"plt.gca().set_axisbelow(True)\n",
"plt.grid()\n",
"save_fig(\"blobs_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each instance was assigned to one of the 5 clusters:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"y_pred"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"y_pred is kmeans.labels_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And the following 5 _centroids_ (i.e., cluster centers) were estimated:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"kmeans.cluster_centers_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that the `KMeans` instance preserves the labels of the instances it was trained on. Somewhat confusingly, in this context, the _label_ of an instance is the index of the cluster that instance gets assigned to (they are not targets, they are predictions):"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"kmeans.labels_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Of course, we can predict the labels of new instances:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"X_new = np.array([[0, 2], [3, 2], [-3, 3], [-3, 2.5]])\n",
"kmeans.predict(X_new)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Decision Boundaries**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's plot the model's decision boundaries. This gives us a _Voronoi diagram_:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# extra code – this cell generates and saves Figure 9–3\n",
"\n",
"def plot_data(X):\n",
" plt.plot(X[:, 0], X[:, 1], 'k.', markersize=2)\n",
"\n",
"def plot_centroids(centroids, weights=None, circle_color='w', cross_color='k'):\n",
" if weights is not None:\n",
" centroids = centroids[weights > weights.max() / 10]\n",
" plt.scatter(centroids[:, 0], centroids[:, 1],\n",
" marker='o', s=35, linewidths=8,\n",
" color=circle_color, zorder=10, alpha=0.9)\n",
" plt.scatter(centroids[:, 0], centroids[:, 1],\n",
" marker='x', s=2, linewidths=12,\n",
" color=cross_color, zorder=11, alpha=1)\n",
"\n",
"def plot_decision_boundaries(clusterer, X, resolution=1000, show_centroids=True,\n",
" show_xlabels=True, show_ylabels=True):\n",
" mins = X.min(axis=0) - 0.1\n",
" maxs = X.max(axis=0) + 0.1\n",
" xx, yy = np.meshgrid(np.linspace(mins[0], maxs[0], resolution),\n",
" np.linspace(mins[1], maxs[1], resolution))\n",
" Z = clusterer.predict(np.c_[xx.ravel(), yy.ravel()])\n",
" Z = Z.reshape(xx.shape)\n",
"\n",
" plt.contourf(Z, extent=(mins[0], maxs[0], mins[1], maxs[1]),\n",
" cmap=\"Pastel2\")\n",
" plt.contour(Z, extent=(mins[0], maxs[0], mins[1], maxs[1]),\n",
" linewidths=1, colors='k')\n",
" plot_data(X)\n",
" if show_centroids:\n",
" plot_centroids(clusterer.cluster_centers_)\n",
"\n",
" if show_xlabels:\n",
" plt.xlabel(\"$x_1$\")\n",
" else:\n",
" plt.tick_params(labelbottom=False)\n",
" if show_ylabels:\n",
" plt.ylabel(\"$x_2$\", rotation=0)\n",
" else:\n",
" plt.tick_params(labelleft=False)\n",
"\n",
"plt.figure(figsize=(8, 4))\n",
"plot_decision_boundaries(kmeans, X)\n",
"save_fig(\"voronoi_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Not bad! Some of the instances near the edges were probably assigned to the wrong cluster, but overall it looks pretty good."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Hard Clustering _vs_ Soft Clustering**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Rather than assigning each instance to a single cluster (the one with the closest centroid), which is called _hard clustering_, it can be useful to give each instance a score per cluster, which is called _soft clustering_. For example, the `transform()` method measures the distance from each instance to each of the 5 centroids:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"kmeans.transform(X_new).round(2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can verify that this is indeed the Euclidean distance between each instance and each centroid:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# extra code – computes the same distances using NumPy broadcasting\n",
"np.linalg.norm(X_new[:, np.newaxis] - kmeans.cluster_centers_, axis=2).round(2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### The K-Means Algorithm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The K-Means algorithm is one of the fastest clustering algorithms, and also one of the simplest:\n",
"* First initialize $k$ centroids randomly: e.g., $k$ distinct instances are chosen randomly from the dataset and the centroids are placed at their locations.\n",
"* Repeat until convergence (i.e., until the centroids stop moving):\n",
" * Assign each instance to the closest centroid.\n",
" * Update the centroids to be the mean of the instances that are assigned to them."
]
},
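{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make these steps concrete, here is a minimal NumPy sketch of the original algorithm: random initialization from the dataset, then alternating assignment and update steps until the assignments stop changing (at which point the centroids stop moving too). It is only an illustration under a few assumptions: `simple_kmeans()` and the `demo` variables are hypothetical names introduced here (so the notebook's `X`, `y` and `kmeans` variables are left untouched), an empty cluster simply keeps its previous centroid, and this is not how Scikit-Learn's optimized `KMeans` class is implemented."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def simple_kmeans(X, k, max_iter=100, seed=42):\n",
"    # hypothetical helper: the original K-Means algorithm, with random init\n",
"    rng = np.random.default_rng(seed)\n",
"    # initialize the centroids at k distinct instances chosen at random\n",
"    centroids = X[rng.choice(len(X), size=k, replace=False)].astype(float)\n",
"    labels = None\n",
"    for _ in range(max_iter):\n",
"        # assignment step: index of the closest centroid for each instance\n",
"        distances = np.linalg.norm(X[:, np.newaxis] - centroids, axis=2)\n",
"        new_labels = distances.argmin(axis=1)\n",
"        if labels is not None and np.array_equal(new_labels, labels):\n",
"            break  # assignments unchanged, so the centroids stopped moving\n",
"        labels = new_labels\n",
"        # update step: move each centroid to the mean of its instances\n",
"        for j in range(k):\n",
"            if np.any(labels == j):  # an empty cluster keeps its centroid\n",
"                centroids[j] = X[labels == j].mean(axis=0)\n",
"    return centroids, labels\n",
"\n",
"# hypothetical demo data: three blobs of 50 instances each\n",
"rng_demo = np.random.default_rng(0)\n",
"X_demo = np.concatenate([rng_demo.normal(loc, 0.3, size=(50, 2))\n",
"                         for loc in ([0, 0], [5, 5], [0, 5])])\n",
"demo_centroids, demo_labels = simple_kmeans(X_demo, k=3)\n",
"demo_centroids.round(2)"
]
},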
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `KMeans` class uses an optimized initialization technique by default. To get the original K-Means algorithm (for educational purposes only), you must set `init=\"random\"` and `n_init=1`. More on this later in this chapter."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's run the K-Means algorithm for 1, 2 and 3 iterations, to see how the centroids move around:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# extra code – this cell generates and saves Figure 9–4\n",
"\n",
"kmeans_iter1 = KMeans(n_clusters=5, init=\"random\", n_init=1, max_iter=1,\n",
" random_state=5)\n",
"kmeans_iter2 = KMeans(n_clusters=5, init=\"random\", n_init=1, max_iter=2,\n",
" random_state=5)\n",
"kmeans_iter3 = KMeans(n_clusters=5, init=\"random\", n_init=1, max_iter=3,\n",
" random_state=5)\n",
"kmeans_iter1.fit(X)\n",
"kmeans_iter2.fit(X)\n",
"kmeans_iter3.fit(X)\n",
"\n",
"plt.figure(figsize=(10, 8))\n",
"\n",
"plt.subplot(321)\n",
"plot_data(X)\n",
"plot_centroids(kmeans_iter1.cluster_centers_, circle_color='r', cross_color='w')\n",
"plt.ylabel(\"$x_2$\", rotation=0)\n",
"plt.tick_params(labelbottom=False)\n",
"plt.title(\"Update the centroids (initially randomly)\")\n",
"\n",
"plt.subplot(322)\n",
"plot_decision_boundaries(kmeans_iter1, X, show_xlabels=False,\n",
" show_ylabels=False)\n",
"plt.title(\"Label the instances\")\n",
"\n",
"plt.subplot(323)\n",
"plot_decision_boundaries(kmeans_iter1, X, show_centroids=False,\n",
" show_xlabels=False)\n",
"plot_centroids(kmeans_iter2.cluster_centers_)\n",
"\n",
"plt.subplot(324)\n",
"plot_decision_boundaries(kmeans_iter2, X, show_xlabels=False,\n",
" show_ylabels=False)\n",
"\n",
"plt.subplot(325)\n",
"plot_decision_boundaries(kmeans_iter2, X, show_centroids=False)\n",
"plot_centroids(kmeans_iter3.cluster_centers_)\n",
"\n",
"plt.subplot(326)\n",
"plot_decision_boundaries(kmeans_iter3, X, show_ylabels=False)\n",
"\n",
"save_fig(\"kmeans_algorithm_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**K-Means Variability**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the original K-Means algorithm, the centroids are just initialized randomly, and the algorithm simply runs iterations to gradually improve the centroids, as we saw above.\n",
"\n",
"However, one major problem with this approach is that if you run K-Means multiple times (or with different random seeds), it can converge to very different solutions, as you can see below:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"# extra code – this cell generates and saves Figure 9–5\n",
"\n",
"def plot_clusterer_comparison(clusterer1, clusterer2, X, title1=None,\n",
" title2=None):\n",
" clusterer1.fit(X)\n",
" clusterer2.fit(X)\n",
"\n",
" plt.figure(figsize=(10, 3.2))\n",
"\n",
" plt.subplot(121)\n",
" plot_decision_boundaries(clusterer1, X)\n",
" if title1:\n",
" plt.title(title1)\n",
"\n",
" plt.subplot(122)\n",
" plot_decision_boundaries(clusterer2, X, show_ylabels=False)\n",
" if title2:\n",
" plt.title(title2)\n",
"\n",
"kmeans_rnd_init1 = KMeans(n_clusters=5, init=\"random\", n_init=1, random_state=2)\n",
"kmeans_rnd_init2 = KMeans(n_clusters=5, init=\"random\", n_init=1, random_state=9)\n",
"\n",
"plot_clusterer_comparison(kmeans_rnd_init1, kmeans_rnd_init2, X,\n",
" \"Solution 1\",\n",
" \"Solution 2 (with a different random init)\")\n",
"\n",
"save_fig(\"kmeans_variability_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"good_init = np.array([[-3, 3], [-3, 2], [-3, 1], [-1, 2], [0, 2]])\n",
"kmeans = KMeans(n_clusters=5, init=good_init, n_init=1, random_state=42)\n",
"kmeans.fit(X)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# extra code\n",
"plt.figure(figsize=(8, 4))\n",
"plot_decision_boundaries(kmeans, X)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Inertia"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To select the best model, we will need a way to evaluate a K-Means model's performance. Unfortunately, clustering is an unsupervised task, so we do not have the targets. But at least we can measure the distance between each instance and its centroid. This is the idea behind the _inertia_ metric:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"kmeans.inertia_"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"kmeans_rnd_init1.inertia_ # extra code"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"kmeans_rnd_init2.inertia_ # extra code"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As you can easily verify, inertia is the sum of the squared distances between each training instance and its closest centroid:"
]
},
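{
"cell_type": "markdown",
"metadata": {},
"source": [
"In symbols, writing $\\mathbf{x}_1, \\dots, \\mathbf{x}_m$ for the $m$ training instances and $\\mathbf{c}_1, \\dots, \\mathbf{c}_k$ for the $k$ centroids, the inertia is:\n",
"\n",
"$$\\text{inertia} = \\sum_{i=1}^{m} \\min_{1 \\le j \\le k} \\lVert \\mathbf{x}_i - \\mathbf{c}_j \\rVert^2$$\n",
"\n",
"Both steps of the K-Means algorithm can only decrease this quantity or leave it unchanged: the assignment step gives each instance its closest centroid, and the update step moves each centroid to the mean of its instances, which is the point that minimizes the sum of squared distances to them."
]
},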
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"# extra code\n",
"X_dist = kmeans.transform(X)\n",
"(X_dist[np.arange(len(X_dist)), kmeans.labels_] ** 2).sum()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `score()` method returns the negative inertia. Why negative? Well, it is because a predictor's `score()` method must always respect the \"_greater is better_\" rule."
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"kmeans.score(X)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Multiple Initializations"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So one approach to solve the variability issue is to simply run the K-Means algorithm multiple times with different random initializations, and select the solution that minimizes the inertia."
]
},
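{
"cell_type": "markdown",
"metadata": {},
"source": [
"When you set the `n_init` hyperparameter, Scikit-Learn runs the original algorithm `n_init` times and selects the solution that minimizes the inertia. Up to Scikit-Learn 1.3, the default is `n_init=10`. Since Scikit-Learn 1.4, the default is `n_init=\"auto\"`, which means 10 runs when `init=\"random\"`, but a single run with the default `init=\"k-means++\"`."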
]
},
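{
"cell_type": "markdown",
"metadata": {},
"source": [
"Conceptually, this amounts to the following manual loop, sketched here on hypothetical demo blobs (`X_demo`, `demo_runs` and `best_demo_run` are names introduced for this illustration only). Scikit-Learn derives the randomness of its `n_init` runs from `random_state` internally, so its individual runs will not match the ones below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.cluster import KMeans\n",
"from sklearn.datasets import make_blobs\n",
"\n",
"# hypothetical demo data, so that the notebook's X is left untouched\n",
"X_demo, y_demo = make_blobs(n_samples=1000, centers=5, random_state=42)\n",
"\n",
"# run the original algorithm (random init, a single run) with 10 different seeds\n",
"demo_runs = [KMeans(n_clusters=5, init=\"random\", n_init=1, random_state=seed).fit(X_demo)\n",
"             for seed in range(10)]\n",
"# keep the run with the lowest inertia, like n_init does\n",
"best_demo_run = min(demo_runs, key=lambda run: run.inertia_)\n",
"[round(run.inertia_, 1) for run in demo_runs], round(best_demo_run.inertia_, 1)"
]
},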
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"# extra code\n",
"kmeans_rnd_10_inits = KMeans(n_clusters=5, init=\"random\", n_init=10,\n",
" random_state=2)\n",
"kmeans_rnd_10_inits.fit(X)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As you can see, we end up with the same solution as the model we initialized with `good_init`, which is very likely the optimal K-Means solution (at least in terms of inertia, and assuming $k=5$)."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"# extra code\n",
"plt.figure(figsize=(8, 4))\n",
"plot_decision_boundaries(kmeans_rnd_10_inits, X)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"kmeans_rnd_10_inits.inertia_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Centroid Initialization Methods"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Instead of initializing the centroids entirely randomly, it is preferable to initialize them using the following algorithm, proposed in a [2006 paper](https://goo.gl/eNUPw6) by David Arthur and Sergei Vassilvitskii:\n",
"* Take one centroid $c_1$, chosen uniformly at random from the dataset.\n",
"* Take a new centroid $c_i$, choosing an instance $\\mathbf{x}_i$ with probability $\\dfrac{D(\\mathbf{x}_i)^2}{\\sum\\limits_{j=1}^{m}{D(\\mathbf{x}_j)}^2}$, where $D(\\mathbf{x}_i)$ is the distance between the instance $\\mathbf{x}_i$ and the closest centroid that was already chosen. This probability distribution ensures that instances farther away from already chosen centroids are much more likely to be selected as centroids.\n",
"* Repeat the previous step until all $k$ centroids have been chosen."
]
},
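{
"cell_type": "markdown",
"metadata": {},
"source": [
"The seeding steps above can be sketched in a few lines of NumPy. This is a minimal illustration, not Scikit-Learn's code: `kmeans_plusplus_init()` and `X_demo` are hypothetical names, and Scikit-Learn's own `\"k-means++\"` initialization additionally tries several candidate centroids at each step and keeps the best one (a greedy variant), so it will generally not pick the same centroids."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def kmeans_plusplus_init(X, k, seed=42):\n",
"    # hypothetical helper: plain K-Means++ seeding (no greedy trials)\n",
"    rng = np.random.default_rng(seed)\n",
"    # first centroid: an instance chosen uniformly at random\n",
"    centroids = [X[rng.integers(len(X))]]\n",
"    for _ in range(1, k):\n",
"        # D(x)^2: squared distance from each instance to its closest chosen centroid\n",
"        diffs = X[:, np.newaxis] - np.array(centroids)\n",
"        d2 = (diffs ** 2).sum(axis=2).min(axis=1)\n",
"        # next centroid: an instance picked with probability D(x)^2 / sum of all D(x)^2\n",
"        centroids.append(X[rng.choice(len(X), p=d2 / d2.sum())])\n",
"    return np.array(centroids)\n",
"\n",
"# hypothetical demo data: three blobs of 50 instances each\n",
"rng_demo = np.random.default_rng(0)\n",
"X_demo = np.concatenate([rng_demo.normal(loc, 0.3, size=(50, 2))\n",
"                         for loc in ([0, 0], [5, 5], [0, 5])])\n",
"init_centroids = kmeans_plusplus_init(X_demo, k=3)\n",
"init_centroids.round(2)"
]
},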
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The rest of the K-Means++ algorithm is just regular K-Means. With this initialization, the K-Means algorithm is much less likely to converge to a suboptimal solution, so it is possible to reduce `n_init` considerably. Most of the time, this largely compensates for the additional complexity of the initialization process."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To set the initialization to K-Means++, simply set `init=\"k-means++\"` (this is actually the default)."
]
},
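{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, here is a single K-Means++ run next to a single purely random run on hypothetical demo blobs (`X_demo`, `kmeans_pp_demo` and `kmeans_rnd_demo` are illustrative names). Which run ends up with the lower inertia depends on the data and the seeds, so treat this as a quick experiment rather than a benchmark."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.cluster import KMeans\n",
"from sklearn.datasets import make_blobs\n",
"\n",
"# hypothetical demo data, so that the notebook's X is left untouched\n",
"X_demo, y_demo = make_blobs(n_samples=1000, centers=5, random_state=42)\n",
"\n",
"# one K-Means++ run vs. one run with purely random initialization\n",
"kmeans_pp_demo = KMeans(n_clusters=5, init=\"k-means++\", n_init=1, random_state=42)\n",
"kmeans_rnd_demo = KMeans(n_clusters=5, init=\"random\", n_init=1, random_state=42)\n",
"kmeans_pp_demo.fit(X_demo)\n",
"kmeans_rnd_demo.fit(X_demo)\n",
"kmeans_pp_demo.inertia_, kmeans_rnd_demo.inertia_"
]
},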
{
"cell_type": "markdown",
"metadata": {},
2017-06-02 16:02:35 +02:00
"source": [
2019-01-15 05:36:29 +01:00
"### Accelerated K-Means"
2016-09-27 23:31:21 +02:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The K-Means algorithm can sometimes be accelerated by avoiding many unnecessary distance calculations: this is achieved by exploiting the triangle inequality (given three points A, B and C, the distance AC is always such that AC ≤ AB + BC) and by keeping track of lower and upper bounds for distances between instances and centroids (see this [2003 paper](https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf) by Charles Elkan for more details)."
]
},
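{
"cell_type": "markdown",
"metadata": {},
"source": [
"To show how the triangle inequality helps, here is a small numerical check of one of the bounds this approach relies on. It demonstrates only the underlying lemma, not Scikit-Learn's implementation. If two centroids $c$ and $c'$ satisfy $d(c, c') \\ge 2\\,d(\\mathbf{x}, c)$, then $d(\\mathbf{x}, c') \\ge d(c, c') - d(\\mathbf{x}, c) \\ge d(\\mathbf{x}, c)$. So $c'$ cannot be closer to $\\mathbf{x}$ than $c$, and the distance $d(\\mathbf{x}, c')$ does not need to be computed. The random points below are arbitrary:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng_demo = np.random.default_rng(0)\n",
"x_demo = rng_demo.uniform(-5, 5, size=2)  # one instance\n",
"centroids_demo = rng_demo.uniform(-5, 5, size=(20, 2))  # 20 centroids\n",
"c_demo = centroids_demo[0]  # centroid x_demo is assigned to (arbitrary here)\n",
"d_x_c = np.linalg.norm(x_demo - c_demo)\n",
"\n",
"n_skippable = 0\n",
"for c_other in centroids_demo[1:]:\n",
"    if np.linalg.norm(c_demo - c_other) >= 2 * d_x_c:\n",
"        # by the lemma, d(x, c_other) >= d(x, c), so this distance could be\n",
"        # skipped; we compute it here only to check the claim\n",
"        assert np.linalg.norm(x_demo - c_other) >= d_x_c\n",
"        n_skippable += 1\n",
"\n",
"print(f\"{n_skippable} of {len(centroids_demo) - 1} distance computations \"\n",
"      \"could be skipped\")"
]
},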
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For Elkan's variant of K-Means, use `algorithm=\"elkan\"`. For regular K-Means (Lloyd's algorithm), use `algorithm=\"lloyd\"`, which was called `\"full\"` before Scikit-Learn 1.1. Since Scikit-Learn 1.1, the default is Lloyd's algorithm; before that, the default was `\"auto\"`, which used Elkan's algorithm."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Mini-Batch K-Means"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Scikit-Learn also implements a variant of the K-Means algorithm that supports mini-batches (see [this paper](http://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf)):"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.cluster import MiniBatchKMeans\n",
"\n",
"minibatch_kmeans = MiniBatchKMeans(n_clusters=5, random_state=42)\n",
"minibatch_kmeans.fit(X)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"minibatch_kmeans.inertia_"
]
},
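{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a simplified NumPy sketch of the mini-batch update described in the paper above. Each centroid moves toward the instances assigned to it, with a per-centroid learning rate equal to 1 / (number of instances assigned to that centroid so far). Scikit-Learn's `MiniBatchKMeans` differs in several details, such as initialization and early stopping, so treat this only as an illustration. `minibatch_kmeans_sketch()` and the toy data are hypothetical:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def minibatch_kmeans_sketch(X, k, batch_size=32, n_iter=100, seed=42):\n",
"    \"\"\"Hypothetical helper: simplified mini-batch K-Means updates.\"\"\"\n",
"    rng = np.random.default_rng(seed)\n",
"    centroids = X[rng.choice(len(X), size=k, replace=False)].astype(float)\n",
"    counts = np.zeros(k)  # instances assigned to each centroid so far\n",
"    for _ in range(n_iter):\n",
"        batch = X[rng.choice(len(X), size=batch_size, replace=False)]\n",
"        # assign each instance of the batch to its closest centroid\n",
"        sq_dists = ((batch[:, np.newaxis] - centroids[np.newaxis]) ** 2).sum(axis=2)\n",
"        closest = sq_dists.argmin(axis=1)\n",
"        for x, j in zip(batch, closest):\n",
"            counts[j] += 1\n",
"            eta = 1 / counts[j]  # per-centroid learning rate\n",
"            centroids[j] = (1 - eta) * centroids[j] + eta * x\n",
"    return centroids\n",
"\n",
"\n",
"rng_demo = np.random.default_rng(0)\n",
"X_demo = np.vstack([rng_demo.normal(loc, 0.3, size=(200, 2))\n",
"                    for loc in ([0, 0], [4, 0], [2, 4])])\n",
"centroids_demo = minibatch_kmeans_sketch(X_demo, k=3)\n",
"centroids_demo.round(2)"
]
},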
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Using `MiniBatchKMeans` along with `memmap`** (not in the book)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the dataset does not fit in memory, the simplest option is to use the `memmap` class, just like we did for incremental PCA in the previous chapter. First, let's load MNIST:"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import fetch_openml\n",
"\n",
"mnist = fetch_openml('mnist_784', as_frame=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's split the dataset:"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"X_train, y_train = mnist.data[:60000], mnist.target[:60000]\n",
"X_test, y_test = mnist.data[60000:], mnist.target[60000:]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, let's write the training set to a `memmap`:"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"filename = \"my_mnist.mmap\"\n",
"X_memmap = np.memmap(filename, dtype='float32', mode='write',\n",
"                     shape=X_train.shape)\n",
"X_memmap[:] = X_train\n",
"X_memmap.flush()"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.cluster import MiniBatchKMeans\n",
"\n",
"minibatch_kmeans = MiniBatchKMeans(n_clusters=10, batch_size=10,\n",
"                                   random_state=42)\n",
"minibatch_kmeans.fit(X_memmap)"
]
},
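{
"cell_type": "markdown",
"metadata": {},
"source": [
"Alternatively, `MiniBatchKMeans` has a `partial_fit()` method, which lets you feed it the data one chunk at a time and control how much is read from disk at once. Here is a sketch on a small random dataset written to a temporary `memmap` file. The file name, the chunk size, and the data are all assumptions made for this example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tempfile\n",
"from pathlib import Path\n",
"\n",
"import numpy as np\n",
"from sklearn.cluster import MiniBatchKMeans\n",
"\n",
"demo_path = Path(tempfile.mkdtemp()) / \"demo_data.mmap\"  # hypothetical file\n",
"demo_shape = (10_000, 20)\n",
"rng_demo = np.random.default_rng(42)\n",
"\n",
"X_demo_mm = np.memmap(demo_path, dtype=\"float32\", mode=\"w+\", shape=demo_shape)\n",
"X_demo_mm[:] = rng_demo.normal(size=demo_shape)\n",
"X_demo_mm.flush()\n",
"del X_demo_mm  # release the writable map\n",
"\n",
"# reopen read-only: data is read from disk as each chunk is accessed\n",
"X_demo_mm = np.memmap(demo_path, dtype=\"float32\", mode=\"r\", shape=demo_shape)\n",
"mbk_demo = MiniBatchKMeans(n_clusters=10, random_state=42)\n",
"chunk_size = 1_000\n",
"for chunk_start in range(0, len(X_demo_mm), chunk_size):\n",
"    mbk_demo.partial_fit(X_demo_mm[chunk_start:chunk_start + chunk_size])\n",
"\n",
"mbk_demo.cluster_centers_.shape"
]
},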
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's plot the inertia and the training time of Mini-batch K-Means and regular K-Means as a function of $k$ (each training time is the total for 10 runs):"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"# extra code – this cell generates and saves Figure 9–6\n",
"\n",
"from timeit import timeit\n",
"\n",
"max_k = 100\n",
"times = np.empty((max_k, 2))\n",
"inertias = np.empty((max_k, 2))\n",
"for k in range(1, max_k + 1):\n",
"    kmeans_ = KMeans(n_clusters=k, algorithm=\"full\", random_state=42)\n",
"    minibatch_kmeans = MiniBatchKMeans(n_clusters=k, random_state=42)\n",
"    print(f\"\\r{k}/{max_k}\", end=\"\")  # \\r returns to the start of line\n",
"    times[k - 1, 0] = timeit(\"kmeans_.fit(X)\", number=10, globals=globals())\n",
"    times[k - 1, 1] = timeit(\"minibatch_kmeans.fit(X)\", number=10,\n",
"                             globals=globals())\n",
"    inertias[k - 1, 0] = kmeans_.inertia_\n",
"    inertias[k - 1, 1] = minibatch_kmeans.inertia_\n",
"\n",
"plt.figure(figsize=(10, 4))\n",
"\n",
"plt.subplot(121)\n",
"plt.plot(range(1, max_k + 1), inertias[:, 0], \"r--\", label=\"K-Means\")\n",
"plt.plot(range(1, max_k + 1), inertias[:, 1], \"b.-\", label=\"Mini-batch K-Means\")\n",
"plt.xlabel(\"$k$\")\n",
"plt.title(\"Inertia\")\n",
"plt.legend()\n",
"plt.axis([1, max_k, 0, 100])\n",
"plt.grid()\n",
"\n",
"plt.subplot(122)\n",
"plt.plot(range(1, max_k + 1), times[:, 0], \"r--\", label=\"K-Means\")\n",
"plt.plot(range(1, max_k + 1), times[:, 1], \"b.-\", label=\"Mini-batch K-Means\")\n",
"plt.xlabel(\"$k$\")\n",
"plt.title(\"Training time (seconds)\")\n",
"plt.axis([1, max_k, 0, 4])\n",
"plt.grid()\n",
"\n",
"save_fig(\"minibatch_kmeans_vs_kmeans_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Finding the optimal number of clusters"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"What if the number of clusters were set to a lower or greater value than 5?"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"# extra code – this cell generates and saves Figure 9–7\n",
"\n",
"kmeans_k3 = KMeans(n_clusters=3, random_state=42)\n",
"kmeans_k8 = KMeans(n_clusters=8, random_state=42)\n",
"\n",
"plot_clusterer_comparison(kmeans_k3, kmeans_k8, X, \"$k=3$\", \"$k=8$\")\n",
"save_fig(\"bad_n_clusters_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ouch, these two models don't look great. What about their inertias?"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"kmeans_k3.inertia_"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"kmeans_k8.inertia_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We cannot simply take the value of $k$ that minimizes the inertia, since it keeps getting lower as we increase $k$. Indeed, the more clusters there are, the closer each instance will be to its closest centroid, and therefore the lower the inertia will be. However, we can plot the inertia as a function of $k$ and analyze the resulting curve:"
]
},
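{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a reminder, the inertia is the sum of the squared distances between each instance and its closest centroid. The following sketch recomputes it by hand for a few values of $k$ and compares the result with the `inertia_` attribute. It uses a toy dataset (an assumption) with new variable names, so the notebook's `X` is left untouched:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.cluster import KMeans\n",
"from sklearn.datasets import make_blobs\n",
"\n",
"X_demo, _ = make_blobs(n_samples=300, centers=4, random_state=0)\n",
"inertia_checks = []\n",
"for k_demo in (2, 4, 8):\n",
"    model_demo = KMeans(n_clusters=k_demo, n_init=10, random_state=42)\n",
"    model_demo.fit(X_demo)\n",
"    # centroid of the cluster each instance is assigned to\n",
"    closest_centroids = model_demo.cluster_centers_[model_demo.labels_]\n",
"    manual_inertia = ((X_demo - closest_centroids) ** 2).sum()\n",
"    inertia_checks.append((manual_inertia, model_demo.inertia_))\n",
"    print(f\"k={k_demo}: manual={manual_inertia:.1f}, \"\n",
"          f\"inertia_={model_demo.inertia_:.1f}\")"
]
},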
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"# extra code – this cell generates and saves Figure 9–8\n",
"\n",
"kmeans_per_k = [KMeans(n_clusters=k, random_state=42).fit(X)\n",
"                for k in range(1, 10)]\n",
"inertias = [model.inertia_ for model in kmeans_per_k]\n",
"\n",
"plt.figure(figsize=(8, 3.5))\n",
"plt.plot(range(1, 10), inertias, \"bo-\")\n",
"plt.xlabel(\"$k$\")\n",
"plt.ylabel(\"Inertia\")\n",
"plt.annotate(\"\", xy=(4, inertias[3]), xytext=(4.45, 650),\n",
"             arrowprops=dict(facecolor='black', shrink=0.1))\n",
"plt.text(4.5, 650, \"Elbow\", horizontalalignment=\"center\")\n",
"plt.axis([1, 8.5, 0, 1300])\n",
"plt.grid()\n",
"save_fig(\"inertia_vs_k_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As you can see, there is an elbow at $k=4$, which means that fewer clusters than that would be bad, and more clusters would not help much and might cut clusters in half. So $k=4$ is a pretty good choice. Of course, in this example it is not perfect, since it means that the two blobs in the lower left will be considered as just a single cluster, but it's a pretty good clustering nonetheless."
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"# extra code\n",
"plot_decision_boundaries(kmeans_per_k[4 - 1], X)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Another approach is to look at the _silhouette score_, which is the mean _silhouette coefficient_ over all the instances. An instance's silhouette coefficient is equal to (_b_ - _a_) / max(_a_, _b_), where _a_ is the mean distance to the other instances in the same cluster (it is the _mean intra-cluster distance_), and _b_ is the _mean nearest-cluster distance_, that is, the mean distance to the instances of the next closest cluster (defined as the one that minimizes _b_, excluding the instance's own cluster). The silhouette coefficient can vary between -1 and +1: a coefficient close to +1 means that the instance is well inside its own cluster and far from other clusters, while a coefficient close to 0 means that it is close to a cluster boundary, and finally a coefficient close to -1 means that the instance may have been assigned to the wrong cluster."
]
},
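{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make this definition concrete, here is a sketch that computes every silhouette coefficient directly from the formula above and compares the result with Scikit-Learn's `silhouette_samples()` function. The helper `silhouette_coefficients_manual()` and the toy dataset are hypothetical, and the sketch assumes that every cluster contains at least two instances:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.cluster import KMeans\n",
"from sklearn.datasets import make_blobs\n",
"from sklearn.metrics import pairwise_distances, silhouette_samples\n",
"\n",
"\n",
"def silhouette_coefficients_manual(X, labels):\n",
"    \"\"\"Hypothetical helper: (b - a) / max(a, b) for every instance.\n",
"    Assumes that every cluster contains at least two instances.\"\"\"\n",
"    dists = pairwise_distances(X)\n",
"    coeffs = np.empty(len(X))\n",
"    for i in range(len(X)):\n",
"        same = labels == labels[i]\n",
"        # a: mean distance to the *other* instances of the same cluster\n",
"        # (the distance to itself is 0, so we just divide by size - 1)\n",
"        a = dists[i, same].sum() / (same.sum() - 1)\n",
"        # b: lowest mean distance to the instances of another cluster\n",
"        b = min(dists[i, labels == other].mean()\n",
"                for other in np.unique(labels) if other != labels[i])\n",
"        coeffs[i] = (b - a) / max(a, b)\n",
"    return coeffs\n",
"\n",
"\n",
"X_demo, _ = make_blobs(n_samples=200, centers=3, random_state=42)\n",
"labels_demo = KMeans(n_clusters=3, n_init=10,\n",
"                     random_state=42).fit_predict(X_demo)\n",
"silhouette_manual = silhouette_coefficients_manual(X_demo, labels_demo)\n",
"np.allclose(silhouette_manual, silhouette_samples(X_demo, labels_demo))"
]
},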
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's plot the silhouette score as a function of $k$:"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import silhouette_score"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"silhouette_score(X, kmeans.labels_)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"# extra code – this cell generates and saves Figure 9–9\n",
"\n",
"silhouette_scores = [silhouette_score(X, model.labels_)\n",
"                     for model in kmeans_per_k[1:]]\n",
"\n",
"plt.figure(figsize=(8, 3))\n",
"plt.plot(range(2, 10), silhouette_scores, \"bo-\")\n",
"plt.xlabel(\"$k$\")\n",
"plt.ylabel(\"Silhouette score\")\n",
"plt.axis([1.8, 8.5, 0.55, 0.7])\n",
"plt.grid()\n",
"save_fig(\"silhouette_score_vs_k_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As you can see, this visualization is much richer than the previous one: although it confirms that $k=4$ is a very good choice, it also highlights the fact that $k=5$ is quite good as well."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"An even more informative visualization is obtained by plotting every instance's silhouette coefficient, sorted by the cluster it is assigned to and by the value of the coefficient. This is called a _silhouette diagram_:"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"# extra code – this cell generates and saves Figure 9–10\n",
"\n",
"from sklearn.metrics import silhouette_samples\n",
"from matplotlib.ticker import FixedLocator, FixedFormatter\n",
"\n",
"plt.figure(figsize=(11, 9))\n",
"\n",
"for k in (3, 4, 5, 6):\n",
"    plt.subplot(2, 2, k - 2)\n",
"\n",
"    y_pred = kmeans_per_k[k - 1].labels_\n",
"    silhouette_coefficients = silhouette_samples(X, y_pred)\n",
"\n",
"    padding = len(X) // 30\n",
"    pos = padding\n",
"    ticks = []\n",
"    for i in range(k):\n",
"        coeffs = silhouette_coefficients[y_pred == i]\n",
"        coeffs.sort()\n",
"\n",
"        color = plt.cm.Spectral(i / k)\n",
"        plt.fill_betweenx(np.arange(pos, pos + len(coeffs)), 0, coeffs,\n",
"                          facecolor=color, edgecolor=color, alpha=0.7)\n",
"        ticks.append(pos + len(coeffs) // 2)\n",
"        pos += len(coeffs) + padding\n",
"\n",
"    plt.gca().yaxis.set_major_locator(FixedLocator(ticks))\n",
"    plt.gca().yaxis.set_major_formatter(FixedFormatter(range(k)))\n",
"    if k in (3, 5):\n",
"        plt.ylabel(\"Cluster\")\n",
"\n",
"    if k in (5, 6):\n",
"        plt.gca().set_xticks([-0.1, 0, 0.2, 0.4, 0.6, 0.8, 1])\n",
"        plt.xlabel(\"Silhouette Coefficient\")\n",
"    else:\n",
"        plt.tick_params(labelbottom=False)\n",
"\n",
"    plt.axvline(x=silhouette_scores[k - 2], color=\"red\", linestyle=\"--\")\n",
"    plt.title(f\"$k={k}$\")\n",
"\n",
"save_fig(\"silhouette_analysis_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As you can see, $k=5$ looks like the best option here, as all clusters are roughly the same size, and they all cross the dashed line, which represents the mean silhouette score."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Limits of K-Means"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's generate a more difficult dataset, with elongated blobs and varying densities, and show that K-Means struggles to cluster it correctly:"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"# extra code – this cell generates and saves Figure 9–11\n",
"\n",
"X1, y1 = make_blobs(n_samples=1000, centers=((4, -4), (0, 0)), random_state=42)\n",
"X1 = X1.dot(np.array([[0.374, 0.95], [0.732, 0.598]]))\n",
"X2, y2 = make_blobs(n_samples=250, centers=1, random_state=42)\n",
"X2 = X2 + [6, -8]\n",
"X = np.r_[X1, X2]\n",
"y = np.r_[y1, y2]\n",
"\n",
"kmeans_good = KMeans(n_clusters=3,\n",
"                     init=np.array([[-1.5, 2.5], [0.5, 0], [4, 0]]),\n",
"                     n_init=1, random_state=42)\n",
"kmeans_bad = KMeans(n_clusters=3, random_state=42)\n",
"kmeans_good.fit(X)\n",
"kmeans_bad.fit(X)\n",
"\n",
"plt.figure(figsize=(10, 3.2))\n",
"\n",
"plt.subplot(121)\n",
"plot_decision_boundaries(kmeans_good, X)\n",
"plt.title(f\"Inertia = {kmeans_good.inertia_:.1f}\")\n",
"\n",
"plt.subplot(122)\n",
"plot_decision_boundaries(kmeans_bad, X, show_ylabels=False)\n",
"plt.title(f\"Inertia = {kmeans_bad.inertia_:.1f}\")\n",
"\n",
"save_fig(\"bad_kmeans_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using Clustering for Image Segmentation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Download the ladybug image:"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"# extra code – downloads the ladybug image\n",
"\n",
"import urllib.request\n",
"\n",
"homl3_root = \"https://github.com/ageron/handson-ml3/raw/main/\"\n",
"filename = \"ladybug.png\"\n",
"filepath = IMAGES_PATH / filename\n",
"if not filepath.is_file():\n",
"    print(\"Downloading\", filename)\n",
"    url = f\"{homl3_root}images/unsupervised_learning/{filename}\"\n",
"    urllib.request.urlretrieve(url, filepath)"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"import PIL.Image\n",
"\n",
"image = np.asarray(PIL.Image.open(filepath))\n",
"image.shape"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"X = image.reshape(-1, 3)\n",
"kmeans = KMeans(n_clusters=8, random_state=42).fit(X)\n",
"segmented_img = kmeans.cluster_centers_[kmeans.labels_]\n",
"segmented_img = segmented_img.reshape(image.shape)"
]
},
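{
"cell_type": "markdown",
"metadata": {},
"source": [
"The expression `kmeans.cluster_centers_[kmeans.labels_]` relies on NumPy fancy indexing. Each pixel's cluster index selects the corresponding row of the centroid array, so every pixel is replaced with the mean color of its cluster. Here is the same idea on a tiny made-up 2×2 image (the pixel values are arbitrary):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.cluster import KMeans\n",
"\n",
"tiny_image = np.array([[[250, 10, 10], [240, 20, 15]],\n",
"                       [[10, 10, 240], [20, 15, 250]]], dtype=np.uint8)\n",
"tiny_pixels = tiny_image.reshape(-1, 3)  # one row per pixel: shape (4, 3)\n",
"tiny_kmeans = KMeans(n_clusters=2, n_init=10, random_state=42).fit(tiny_pixels)\n",
"# fancy indexing: row labels_[i] of cluster_centers_ becomes pixel i's color\n",
"tiny_segmented = tiny_kmeans.cluster_centers_[tiny_kmeans.labels_]\n",
"tiny_segmented.reshape(tiny_image.shape)"
]
},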
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"# extra code – this cell generates and saves Figure 9–12\n",
"\n",
"segmented_imgs = []\n",
"n_colors = (10, 8, 6, 4, 2)\n",
"for n_clusters in n_colors:\n",
"    kmeans = KMeans(n_clusters=n_clusters, random_state=42).fit(X)\n",
"    segmented_img = kmeans.cluster_centers_[kmeans.labels_]\n",
"    segmented_imgs.append(segmented_img.reshape(image.shape))\n",
"\n",
"plt.figure(figsize=(10, 5))\n",
"plt.subplots_adjust(wspace=0.05, hspace=0.1)\n",
"\n",
"plt.subplot(2, 3, 1)\n",
"plt.imshow(image)\n",
"plt.title(\"Original image\")\n",
"plt.axis('off')\n",
"\n",
"for idx, n_clusters in enumerate(n_colors):\n",
"    plt.subplot(2, 3, 2 + idx)\n",
"    plt.imshow(segmented_imgs[idx] / 255)\n",
"    plt.title(f\"{n_clusters} colors\")\n",
"    plt.axis('off')\n",
"\n",
"save_fig('image_segmentation_plot', tight_layout=False)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using Clustering for Semi-Supervised Learning"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Another use case for clustering is semi-supervised learning, when we have plenty of unlabeled instances and very few labeled instances."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's tackle the _digits dataset_, which is a simple MNIST-like dataset containing 1,797 grayscale 8×8 images representing digits 0 to 9."
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import load_digits\n",
"\n",
"X_digits, y_digits = load_digits(return_X_y=True)\n",
"X_train, y_train = X_digits[:1400], y_digits[:1400]\n",
"X_test, y_test = X_digits[1400:], y_digits[1400:]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's look at the performance of a logistic regression model when we only have 50 labeled instances:"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.linear_model import LogisticRegression\n",
"\n",
"n_labeled = 50\n",
"log_reg = LogisticRegression(max_iter=10_000)\n",
"log_reg.fit(X_train[:n_labeled], y_train[:n_labeled])"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"log_reg.score(X_test, y_test)"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"# extra code – measure the accuracy when we use the whole training set\n",
"log_reg_full = LogisticRegression(max_iter=10_000)\n",
"log_reg_full.fit(X_train, y_train)\n",
"log_reg_full.score(X_test, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Of course, this is much lower than the accuracy we get when training on the full training set. Let's see how we can do better. First, let's cluster the training set into 50 clusters, then for each cluster let's find the image closest to the centroid. We will call these images the representative images:"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [
"k = 50\n",
"kmeans = KMeans(n_clusters=k, random_state=42)\n",
"X_digits_dist = kmeans.fit_transform(X_train)\n",
"representative_digit_idx = X_digits_dist.argmin(axis=0)\n",
"X_representative_digits = X_train[representative_digit_idx]"
]
},
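{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the cell above, `fit_transform()` returns one row per instance and one column per cluster, containing the distance from each instance to each centroid. Therefore `argmin(axis=0)` gives, for each cluster, the index of the instance closest to its centroid. A tiny made-up distance matrix shows the idea (here the result is `[1, 2, 3]`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# rows = 4 instances, columns = 3 clusters (made-up distances)\n",
"dist_demo = np.array([[0.5, 3.0, 4.0],\n",
"                      [0.2, 2.5, 3.5],\n",
"                      [2.8, 0.1, 1.9],\n",
"                      [3.9, 1.7, 0.3]])\n",
"closest_idx_demo = dist_demo.argmin(axis=0)  # one instance index per cluster\n",
"closest_idx_demo"
]
},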
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's plot these representative images and label them manually:"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
"# extra code – this cell generates and saves Figure 9–13\n",
"\n",
"plt.figure(figsize=(8, 2))\n",
"for index, X_representative_digit in enumerate(X_representative_digits):\n",
"    plt.subplot(k // 10, 10, index + 1)\n",
"    plt.imshow(X_representative_digit.reshape(8, 8), cmap=\"binary\",\n",
"               interpolation=\"bilinear\")\n",
"    plt.axis('off')\n",
"\n",
"save_fig(\"representative_images_plot\", tight_layout=False)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"y_representative_digits = np.array([\n",
"    1, 3, 6, 0, 7, 9, 2, 4, 8, 9,\n",
"    5, 4, 7, 1, 2, 6, 1, 2, 5, 1,\n",
"    4, 1, 3, 3, 8, 8, 2, 5, 6, 9,\n",
"    1, 4, 0, 6, 8, 3, 4, 6, 7, 2,\n",
"    4, 1, 0, 7, 5, 1, 9, 9, 3, 7\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we have a dataset with just 50 labeled instances, but instead of being completely random instances, each of them is a representative image of its cluster. Let's see if the performance is any better:"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [],
"source": [
"log_reg = LogisticRegression(max_iter=10_000)\n",
"log_reg.fit(X_representative_digits, y_representative_digits)\n",
"log_reg.score(X_test, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Wow! We jumped from 74.8% accuracy to 84.9%, although we are still only training the model on 50 instances. Since it's often costly and painful to label instances, especially when it has to be done manually by experts, it's a good idea to have them label representative instances rather than just random instances."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"But perhaps we can go one step further: what if we propagated the labels to all the other instances in the same cluster?"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"y_train_propagated = np.empty(len(X_train), dtype=np.int64)\n",
"for i in range(k):\n",
"    y_train_propagated[kmeans.labels_ == i] = y_representative_digits[i]"
]
},
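{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since `y_representative_digits[i]` is the label of cluster `i`, the loop above is equivalent to the single fancy-indexing operation `y_representative_digits[kmeans.labels_]`. Here is a check on made-up cluster indices and labels:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"cluster_of_instance = np.array([2, 0, 1, 2, 0])  # made-up cluster indices\n",
"label_of_cluster = np.array([7, 3, 9])  # made-up label of each representative\n",
"\n",
"# loop version, as in the cell above\n",
"propagated_loop = np.empty(len(cluster_of_instance), dtype=np.int64)\n",
"for cluster_id in range(len(label_of_cluster)):\n",
"    propagated_loop[cluster_of_instance == cluster_id] = label_of_cluster[cluster_id]\n",
"\n",
"# vectorized version: index the label array with the cluster indices\n",
"propagated_vec = label_of_cluster[cluster_of_instance]\n",
"(propagated_loop == propagated_vec).all()"
]
},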
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"log_reg = LogisticRegression(max_iter=10_000)\n",
"log_reg.fit(X_train, y_train_propagated)"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"log_reg.score(X_test, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We got another significant accuracy boost! Let's see if we can do even better by ignoring the 1% of instances that are farthest from their cluster center, which should eliminate some outliers:"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"percentile_closest = 99\n",
"\n",
"X_cluster_dist = X_digits_dist[np.arange(len(X_train)), kmeans.labels_]\n",
"for i in range(k):\n",
"    in_cluster = (kmeans.labels_ == i)\n",
"    cluster_dist = X_cluster_dist[in_cluster]\n",
"    cutoff_distance = np.percentile(cluster_dist, percentile_closest)\n",
"    above_cutoff = (X_cluster_dist > cutoff_distance)\n",
"    X_cluster_dist[in_cluster & above_cutoff] = -1\n",
"\n",
"partially_propagated = (X_cluster_dist != -1)\n",
"X_train_partially_propagated = X_train[partially_propagated]\n",
"y_train_partially_propagated = y_train_propagated[partially_propagated]"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
"log_reg = LogisticRegression(max_iter=10_000)\n",
"log_reg.fit(X_train_partially_propagated, y_train_partially_propagated)\n",
"log_reg.score(X_test, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2021-11-22 09:36:00 +01:00
"Wow, another accuracy boost! We have even slightly surpassed the performance we got by training on the fully labeled training set!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our propagated labels are actually pretty good: their accuracy is about 97.6%:"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
"(y_train_partially_propagated == y_train[partially_propagated]).mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You could now do a few iterations of *active learning*:\n",
"1. Manually label the instances that the classifier is least sure about, if possible by picking them in distinct clusters.\n",
"2. Train a new model with these additional labels."
]
},
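{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a minimal sketch of one active-learning round based on *uncertainty sampling*: the model queries the unlabeled instances whose highest predicted class probability is lowest. Several details are assumptions rather than part of the text. The digits dataset is reloaded under new `_al`-suffixed names so the notebook's variables stay untouched. The model starts from 50 randomly chosen labeled instances. `al_budget` is a hypothetical labeling budget of 20 instances per round. The human expert is simulated by reading the true labels. The sketch also skips the refinement of picking the queried instances from distinct clusters:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_digits\n",
"from sklearn.linear_model import LogisticRegression\n",
"\n",
"# separate copy of the digits data: 1400 training images, 397 test images\n",
"X_al, y_al = load_digits(return_X_y=True)\n",
"X_al_train, y_al_train = X_al[:1400], y_al[:1400]\n",
"X_al_test, y_al_test = X_al[1400:], y_al[1400:]\n",
"\n",
"rng_al = np.random.default_rng(42)\n",
"labeled_idx = rng_al.choice(len(X_al_train), size=50, replace=False)\n",
"al_budget = 20  # hypothetical number of labels a human provides per round\n",
"\n",
"clf_al = LogisticRegression(max_iter=10_000)\n",
"clf_al.fit(X_al_train[labeled_idx], y_al_train[labeled_idx])\n",
"score_before = clf_al.score(X_al_test, y_al_test)\n",
"\n",
"# uncertainty sampling: a low top-class probability means the model is unsure\n",
"unlabeled_idx = np.setdiff1d(np.arange(len(X_al_train)), labeled_idx)\n",
"confidence = clf_al.predict_proba(X_al_train[unlabeled_idx]).max(axis=1)\n",
"query_idx = unlabeled_idx[np.argsort(confidence)[:al_budget]]\n",
"\n",
"# simulated expert: reveal the true labels of the queried instances, then retrain\n",
"labeled_idx = np.concatenate([labeled_idx, query_idx])\n",
"clf_al.fit(X_al_train[labeled_idx], y_al_train[labeled_idx])\n",
"score_after = clf_al.score(X_al_test, y_al_test)\n",
"score_before, score_after"
]
},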
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## DBSCAN"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.cluster import DBSCAN\n",
"from sklearn.datasets import make_moons\n",
"\n",
"X, y = make_moons(n_samples=1000, noise=0.05, random_state=42)\n",
"dbscan = DBSCAN(eps=0.05, min_samples=5)\n",
"dbscan.fit(X)"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
"dbscan.labels_[:10]"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
"dbscan.core_sample_indices_[:10]"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [],
"source": [
"dbscan.components_"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [],
"source": [
"# extra code – this cell generates and saves Figure 9–14\n",
"\n",
"def plot_dbscan(dbscan, X, size, show_xlabels=True, show_ylabels=True):\n",
" core_mask = np.zeros_like(dbscan.labels_, dtype=bool)\n",
" core_mask[dbscan.core_sample_indices_] = True\n",
" anomalies_mask = dbscan.labels_ == -1\n",
" non_core_mask = ~(core_mask | anomalies_mask)\n",
"\n",
" cores = dbscan.components_\n",
" anomalies = X[anomalies_mask]\n",
" non_cores = X[non_core_mask]\n",
" \n",
" plt.scatter(cores[:, 0], cores[:, 1],\n",
" c=dbscan.labels_[core_mask], marker='o', s=size, cmap=\"Paired\")\n",
" plt.scatter(cores[:, 0], cores[:, 1], marker='*', s=20,\n",
" c=dbscan.labels_[core_mask])\n",
" plt.scatter(anomalies[:, 0], anomalies[:, 1],\n",
" c=\"r\", marker=\"x\", s=100)\n",
" plt.scatter(non_cores[:, 0], non_cores[:, 1],\n",
" c=dbscan.labels_[non_core_mask], marker=\".\")\n",
" if show_xlabels:\n",
" plt.xlabel(\"$x_1$\")\n",
" else:\n",
" plt.tick_params(labelbottom=False)\n",
" if show_ylabels:\n",
" plt.ylabel(\"$x_2$\", rotation=0)\n",
" else:\n",
" plt.tick_params(labelleft=False)\n",
" plt.title(f\"eps={dbscan.eps:.2f}, min_samples={dbscan.min_samples}\")\n",
" plt.grid()\n",
" plt.gca().set_axisbelow(True)\n",
"\n",
"dbscan2 = DBSCAN(eps=0.2)\n",
"dbscan2.fit(X)\n",
"\n",
"plt.figure(figsize=(9, 3.2))\n",
"\n",
"plt.subplot(121)\n",
"plot_dbscan(dbscan, X, size=100)\n",
"\n",
"plt.subplot(122)\n",
"plot_dbscan(dbscan2, X, size=600, show_ylabels=False)\n",
"\n",
"save_fig(\"dbscan_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
"dbscan = dbscan2 # extra code – the text says we now use eps=0.2"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.neighbors import KNeighborsClassifier\n",
"\n",
"knn = KNeighborsClassifier(n_neighbors=50)\n",
"knn.fit(dbscan.components_, dbscan.labels_[dbscan.core_sample_indices_])"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [],
"source": [
"X_new = np.array([[-0.5, 0], [0, 0.5], [1, -0.1], [2, 1]])\n",
"knn.predict(X_new)"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [],
"source": [
"knn.predict_proba(X_new)"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
"# extra code – this cell generates and saves Figure 9–15\n",
"\n",
"plt.figure(figsize=(6, 3))\n",
"plot_decision_boundaries(knn, X, show_centroids=False)\n",
"plt.scatter(X_new[:, 0], X_new[:, 1], c=\"b\", marker=\"+\", s=200, zorder=10)\n",
"save_fig(\"cluster_classification_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [],
"source": [
"y_dist, y_pred_idx = knn.kneighbors(X_new, n_neighbors=1)\n",
"y_pred = dbscan.labels_[dbscan.core_sample_indices_][y_pred_idx]\n",
"y_pred[y_dist > 0.2] = -1\n",
"y_pred.ravel()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Other Clustering Algorithms"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The code in this section is bonus material, not in the book."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Spectral Clustering"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.cluster import SpectralClustering"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [],
"source": [
"sc1 = SpectralClustering(n_clusters=2, gamma=100, random_state=42)\n",
"sc1.fit(X)"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
"sc1.affinity_matrix_.round(2)"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [],
"source": [
"sc2 = SpectralClustering(n_clusters=2, gamma=1, random_state=42)\n",
"sc2.fit(X)"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [],
"source": [
"def plot_spectral_clustering(sc, X, size, alpha, show_xlabels=True,\n",
" show_ylabels=True):\n",
" plt.scatter(X[:, 0], X[:, 1], marker='o', s=size, c='gray', alpha=alpha)\n",
" plt.scatter(X[:, 0], X[:, 1], marker='o', s=30, c='w')\n",
" plt.scatter(X[:, 0], X[:, 1], marker='.', s=10, c=sc.labels_, cmap=\"Paired\")\n",
" \n",
" if show_xlabels:\n",
" plt.xlabel(\"$x_1$\")\n",
" else:\n",
" plt.tick_params(labelbottom=False)\n",
" if show_ylabels:\n",
" plt.ylabel(\"$x_2$\", rotation=0)\n",
" else:\n",
" plt.tick_params(labelleft=False)\n",
" plt.title(f\"RBF gamma={sc.gamma}\")"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(9, 3.2))\n",
"\n",
"plt.subplot(121)\n",
"plot_spectral_clustering(sc1, X, size=500, alpha=0.1)\n",
"\n",
"plt.subplot(122)\n",
"plot_spectral_clustering(sc2, X, size=4000, alpha=0.01, show_ylabels=False)\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Agglomerative Clustering"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.cluster import AgglomerativeClustering"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [],
"source": [
"X = np.array([0, 2, 5, 8.5]).reshape(-1, 1)\n",
"agg = AgglomerativeClustering(linkage=\"complete\").fit(X)"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [],
"source": [
"def learned_parameters(estimator):\n",
" return [attrib for attrib in dir(estimator)\n",
" if attrib.endswith(\"_\") and not attrib.startswith(\"_\")]"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [],
"source": [
"learned_parameters(agg)"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"agg.children_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Gaussian Mixtures"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's generate the same dataset as earlier, with three ellipsoidal clusters (the one K-Means had trouble with):"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [],
"source": [
"X1, y1 = make_blobs(n_samples=1000, centers=((4, -4), (0, 0)), random_state=42)\n",
"X1 = X1.dot(np.array([[0.374, 0.95], [0.732, 0.598]]))\n",
"X2, y2 = make_blobs(n_samples=250, centers=1, random_state=42)\n",
"X2 = X2 + [6, -8]\n",
"X = np.r_[X1, X2]\n",
"y = np.r_[y1, y2]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's train a Gaussian mixture model on the previous dataset:"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.mixture import GaussianMixture"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [],
"source": [
"gm = GaussianMixture(n_components=3, n_init=10, random_state=42)\n",
"gm.fit(X)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's look at the parameters that the EM algorithm estimated:"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [],
"source": [
"gm.weights_"
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [],
"source": [
"gm.means_"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [],
"source": [
"gm.covariances_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Did the algorithm actually converge?"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [],
"source": [
"gm.converged_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Yes, good. How many iterations did it take?"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [],
"source": [
"gm.n_iter_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can now use the model to predict which cluster each instance belongs to (hard clustering) or the probabilities that it came from each cluster (soft clustering). For this, use the `predict()` method or the `predict_proba()` method, respectively:"
]
},
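{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, here is a sketch on a small toy dataset. The dataset is an assumption: three `make_blobs` clusters, stored under `_hard`-suffixed names so the notebook's `X` and `gm` are not overwritten. The check should confirm that each row returned by `predict_proba()` sums to 1, and that `predict()` returns the component with the highest probability:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import make_blobs\n",
"from sklearn.mixture import GaussianMixture\n",
"\n",
"X_hard, _ = make_blobs(n_samples=500, centers=3, random_state=42)\n",
"gm_hard = GaussianMixture(n_components=3, n_init=5, random_state=42).fit(X_hard)\n",
"\n",
"proba_hard = gm_hard.predict_proba(X_hard)  # soft clustering: one row of probabilities per instance\n",
"labels_hard = gm_hard.predict(X_hard)       # hard clustering: one label per instance\n",
"\n",
"# each row sums to 1, and the hard label is the most probable component\n",
"print(np.allclose(proba_hard.sum(axis=1), 1.0))\n",
"print((labels_hard == proba_hard.argmax(axis=1)).all())"
]
},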
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [],
"source": [
"gm.predict(X)"
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {},
"outputs": [],
"source": [
"gm.predict_proba(X).round(3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is a generative model, so you can sample new instances from it (and get their labels):"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [],
"source": [
"X_new, y_new = gm.sample(6)\n",
"X_new"
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {},
"outputs": [],
"source": [
"y_new"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
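"Notice that they are sampled sequentially from each cluster: all the instances generated by the first cluster come first, then those of the second cluster, and so on."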
]
},
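{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you need the generated instances in random order, you have to shuffle them yourself, applying the same permutation to the instances and their labels. A minimal sketch, assuming a toy three-blob dataset; the `_smp`-suffixed names are illustrative and leave the notebook's variables untouched:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import make_blobs\n",
"from sklearn.mixture import GaussianMixture\n",
"\n",
"X_smp, _ = make_blobs(n_samples=500, centers=3, random_state=42)\n",
"gm_smp = GaussianMixture(n_components=3, random_state=42).fit(X_smp)\n",
"\n",
"X_sampled, y_sampled = gm_smp.sample(20)\n",
"# the labels come out grouped by component, in non-decreasing order\n",
"print(y_sampled)\n",
"print(np.all(np.diff(y_sampled) >= 0))\n",
"\n",
"# shuffle instances and labels together with one shared permutation\n",
"perm_smp = np.random.default_rng(42).permutation(len(X_sampled))\n",
"X_shuffled, y_shuffled = X_sampled[perm_smp], y_sampled[perm_smp]\n",
"y_shuffled"
]
},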
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also estimate the log of the _probability density function_ (PDF) at any location using the `score_samples()` method:"
]
},
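{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what `score_samples()` computes, the following sketch rebuilds the mixture density by hand. It takes the weighted sum of the components' Gaussian PDFs using SciPy's `multivariate_normal`, then compares its log with the output of `score_samples()`. It assumes the default full covariance matrices and a toy dataset stored under `_pdf`-suffixed names:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy.stats import multivariate_normal\n",
"from sklearn.datasets import make_blobs\n",
"from sklearn.mixture import GaussianMixture\n",
"\n",
"X_pdf, _ = make_blobs(n_samples=500, centers=3, random_state=42)\n",
"gm_pdf = GaussianMixture(n_components=3, random_state=42).fit(X_pdf)\n",
"\n",
"# p(x) = sum over components of (component weight) * (Gaussian PDF of that component)\n",
"manual_pdf = np.zeros(len(X_pdf))\n",
"for w_pdf, mu_pdf, cov_pdf in zip(gm_pdf.weights_, gm_pdf.means_, gm_pdf.covariances_):\n",
"    manual_pdf += w_pdf * multivariate_normal(mean=mu_pdf, cov=cov_pdf).pdf(X_pdf)\n",
"\n",
"# score_samples() returns the log of that same density\n",
"print(np.allclose(np.log(manual_pdf), gm_pdf.score_samples(X_pdf)))"
]
},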
{
"cell_type": "code",
"execution_count": 99,
"metadata": {},
"outputs": [],
"source": [
"gm.score_samples(X).round(2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's check that the PDF integrates to 1 over the whole space. We just take a large square around the clusters and chop it into a grid of tiny squares. Then we compute the approximate probability that an instance will be generated in each tiny square, by multiplying the PDF at one corner of the tiny square by the area of the square. Finally, we sum all these probabilities. The result is very close to 1:"
]
},
{
"cell_type": "code",
"execution_count": 100,
"metadata": {},
"outputs": [],
"source": [
"# extra code – bonus material\n",
"\n",
"resolution = 100\n",
"grid = np.arange(-10, 10, 1 / resolution)\n",
"xx, yy = np.meshgrid(grid, grid)\n",
"X_full = np.vstack([xx.ravel(), yy.ravel()]).T\n",
"\n",
"pdf = np.exp(gm.score_samples(X_full))\n",
"pdf_probas = pdf * (1 / resolution) ** 2\n",
"pdf_probas.sum()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's plot the resulting decision boundaries (dashed lines) and density contours:"
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {},
"outputs": [],
"source": [
"# extra code – this cell generates and saves Figure 9–16\n",
"\n",
"from matplotlib.colors import LogNorm\n",
"\n",
"def plot_gaussian_mixture(clusterer, X, resolution=1000, show_ylabels=True):\n",
" mins = X.min(axis=0) - 0.1\n",
" maxs = X.max(axis=0) + 0.1\n",
" xx, yy = np.meshgrid(np.linspace(mins[0], maxs[0], resolution),\n",
" np.linspace(mins[1], maxs[1], resolution))\n",
" Z = -clusterer.score_samples(np.c_[xx.ravel(), yy.ravel()])\n",
" Z = Z.reshape(xx.shape)\n",
"\n",
" plt.contourf(xx, yy, Z,\n",
" norm=LogNorm(vmin=1.0, vmax=30.0),\n",
" levels=np.logspace(0, 2, 12))\n",
" plt.contour(xx, yy, Z,\n",
" norm=LogNorm(vmin=1.0, vmax=30.0),\n",
" levels=np.logspace(0, 2, 12),\n",
" linewidths=1, colors='k')\n",
"\n",
" Z = clusterer.predict(np.c_[xx.ravel(), yy.ravel()])\n",
" Z = Z.reshape(xx.shape)\n",
" plt.contour(xx, yy, Z,\n",
" linewidths=2, colors='r', linestyles='dashed')\n",
" \n",
" plt.plot(X[:, 0], X[:, 1], 'k.', markersize=2)\n",
" plot_centroids(clusterer.means_, clusterer.weights_)\n",
"\n",
" plt.xlabel(\"$x_1$\")\n",
" if show_ylabels:\n",
" plt.ylabel(\"$x_2$\", rotation=0)\n",
" else:\n",
" plt.tick_params(labelleft=False)\n",
"\n",
"plt.figure(figsize=(8, 4))\n",
"\n",
"plot_gaussian_mixture(gm, X)\n",
"\n",
"save_fig(\"gaussian_mixtures_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can impose constraints on the covariance matrices that the algorithm looks for by setting the `covariance_type` hyperparameter:\n",
"* `\"spherical\"`: all clusters must be spherical, but they can have different diameters (i.e., different variances).\n",
"* `\"diag\"`: clusters can take on any ellipsoidal shape of any size, but the ellipsoid's axes must be parallel to the coordinate axes (i.e., the covariance matrices must be diagonal).\n",
"* `\"tied\"`: all clusters must have the same shape, which can be any ellipsoid (i.e., they all share the same covariance matrix).\n",
"* `\"full\"` (default): no constraint, all clusters can take on any ellipsoidal shape of any size."
]
},
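{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each constraint shows up in the shape of the learned `covariances_` attribute. A minimal sketch on a 2D toy dataset (an assumption; the `_cov`-suffixed names are illustrative):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import make_blobs\n",
"from sklearn.mixture import GaussianMixture\n",
"\n",
"X_cov, _ = make_blobs(n_samples=500, centers=3, random_state=42)  # 2 features\n",
"cov_shapes = {}\n",
"for cov_type in ('full', 'tied', 'diag', 'spherical'):\n",
"    gm_cov = GaussianMixture(n_components=3, covariance_type=cov_type, random_state=42)\n",
"    gm_cov.fit(X_cov)\n",
"    cov_shapes[cov_type] = gm_cov.covariances_.shape\n",
"\n",
"# full: one 2x2 matrix per cluster; tied: a single shared 2x2 matrix;\n",
"# diag: one variance per axis per cluster; spherical: one variance per cluster\n",
"cov_shapes"
]
},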
{
"cell_type": "code",
"execution_count": 102,
"metadata": {},
"outputs": [],
"source": [
"# extra code – this cell generates and saves Figure 9–17\n",
"\n",
"gm_full = GaussianMixture(n_components=3, n_init=10,\n",
" covariance_type=\"full\", random_state=42)\n",
"gm_tied = GaussianMixture(n_components=3, n_init=10,\n",
" covariance_type=\"tied\", random_state=42)\n",
"gm_spherical = GaussianMixture(n_components=3, n_init=10,\n",
" covariance_type=\"spherical\", random_state=42)\n",
"gm_diag = GaussianMixture(n_components=3, n_init=10,\n",
" covariance_type=\"diag\", random_state=42)\n",
"gm_full.fit(X)\n",
"gm_tied.fit(X)\n",
"gm_spherical.fit(X)\n",
"gm_diag.fit(X)\n",
"\n",
"def compare_gaussian_mixtures(gm1, gm2, X):\n",
" plt.figure(figsize=(9, 4))\n",
"\n",
" plt.subplot(121)\n",
" plot_gaussian_mixture(gm1, X)\n",
" plt.title(f'covariance_type=\"{gm1.covariance_type}\"')\n",
"\n",
" plt.subplot(122)\n",
" plot_gaussian_mixture(gm2, X, show_ylabels=False)\n",
" plt.title(f'covariance_type=\"{gm2.covariance_type}\"')\n",
"\n",
"compare_gaussian_mixtures(gm_tied, gm_spherical, X)\n",
"\n",
"save_fig(\"covariance_type_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 103,
"metadata": {},
"outputs": [],
"source": [
"# extra code – comparing covariance_type=\"full\" and covariance_type=\"diag\"\n",
"compare_gaussian_mixtures(gm_full, gm_diag, X)\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Anomaly Detection Using Gaussian Mixtures"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Gaussian mixtures can be used for _anomaly detection_: instances located in low-density regions can be considered anomalies. You must define what density threshold you want to use. For example, in a manufacturing company that tries to detect defective products, the ratio of defective products is usually well known. Say it is equal to 2%: you can then set the density threshold to the value that results in 2% of the instances being located in areas below that threshold density:"
]
},
{
"cell_type": "code",
"execution_count": 104,
"metadata": {},
"outputs": [],
"source": [
"densities = gm.score_samples(X)\n",
"density_threshold = np.percentile(densities, 2)\n",
"anomalies = X[densities < density_threshold]"
]
},
{
"cell_type": "code",
"execution_count": 105,
"metadata": {},
"outputs": [],
"source": [
"# extra code – this cell generates and saves Figure 9–18\n",
"\n",
"plt.figure(figsize=(8, 4))\n",
"\n",
"plot_gaussian_mixture(gm, X)\n",
"plt.scatter(anomalies[:, 0], anomalies[:, 1], color='r', marker='*')\n",
"plt.ylim(top=5.1)\n",
"\n",
"save_fig(\"mixture_anomaly_detection_plot\")\n",
"plt.show()"
]
},
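{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once the threshold is set, the same test can flag new instances. Here is a sketch under a few assumptions: a fresh toy dataset, the 2% anomaly rate from the example above, and two made-up instances, one at a cluster center and one far from every cluster. The `_ad`-suffixed names are illustrative. Note that `score_samples()` returns log densities, so the threshold is a log density too; since the log is monotonic, this does not change which instances get flagged:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import make_blobs\n",
"from sklearn.mixture import GaussianMixture\n",
"\n",
"X_ad, _ = make_blobs(n_samples=1000, centers=3, random_state=42)\n",
"gm_ad = GaussianMixture(n_components=3, n_init=5, random_state=42).fit(X_ad)\n",
"\n",
"log_densities_ad = gm_ad.score_samples(X_ad)\n",
"threshold_ad = np.percentile(log_densities_ad, 2)  # assumed 2% anomaly rate\n",
"frac_ad = (log_densities_ad < threshold_ad).mean()  # fraction flagged in the training set\n",
"\n",
"# two new instances: the center of component 0, and a point far from all clusters\n",
"X_new_ad = np.vstack([gm_ad.means_[0], gm_ad.means_[0] + 50])\n",
"flags_ad = gm_ad.score_samples(X_new_ad) < threshold_ad\n",
"print(frac_ad, flags_ad)"
]
},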
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Selecting the Number of Clusters"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We cannot use the inertia or the silhouette score because they both assume that the clusters are spherical. Instead, we can try to find the model that minimizes a theoretical information criterion such as the Bayesian Information Criterion (BIC) or the Akaike Information Criterion (AIC):\n",
"\n",
"${BIC} = {\\log(m)p - 2\\log({\\hat L})}$\n",
"\n",
"${AIC} = 2p - 2\\log(\\hat L)$\n",
"\n",
"* $m$ is the number of instances.\n",
"* $p$ is the number of parameters learned by the model.\n",
"* $\\hat L$ is the maximized value of the likelihood function of the model. This is the conditional probability of the observed data $\\mathbf{X}$, given the model and its optimized parameters.\n",
"\n",
"Both BIC and AIC penalize models that have more parameters to learn (e.g., more clusters), and reward models that fit the data well (i.e., models that give a high likelihood to the observed data)."
]
},
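{
"cell_type": "markdown",
"metadata": {},
"source": [
"In practice, you would typically fit one model per candidate number of clusters and keep the one with the lowest criterion (lower is better for both BIC and AIC). A minimal sketch, assuming a toy dataset of three blobs and candidate values of $k$ from 1 to 6; the `_ic`-suffixed names are illustrative:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import make_blobs\n",
"from sklearn.mixture import GaussianMixture\n",
"\n",
"X_ic, _ = make_blobs(n_samples=600, centers=3, random_state=42)\n",
"\n",
"ks_ic = list(range(1, 7))\n",
"models_ic = []\n",
"for k_ic in ks_ic:\n",
"    models_ic.append(GaussianMixture(n_components=k_ic, n_init=3, random_state=42).fit(X_ic))\n",
"bics_ic = [model.bic(X_ic) for model in models_ic]\n",
"aics_ic = [model.aic(X_ic) for model in models_ic]\n",
"\n",
"# keep the k that minimizes each criterion\n",
"best_k_bic = ks_ic[int(np.argmin(bics_ic))]\n",
"best_k_aic = ks_ic[int(np.argmin(aics_ic))]\n",
"best_k_bic, best_k_aic"
]
},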
{
"cell_type": "code",
"execution_count": 106,
"metadata": {},
"outputs": [],
"source": [
"# extra code – this cell generates and saves Figure 9–19\n",
"\n",
"from scipy.stats import norm\n",
"\n",
"x_val = 2.5\n",
"std_val = 1.3\n",
"x_range = [-6, 4]\n",
"x_proba_range = [-2, 2]\n",
"stds_range = [1, 2]\n",
"\n",
"xs = np.linspace(x_range[0], x_range[1], 501)\n",
"stds = np.linspace(stds_range[0], stds_range[1], 501)\n",
"Xs, Stds = np.meshgrid(xs, stds)\n",
"Z = 2 * norm.pdf(Xs - 1.0, 0, Stds) + norm.pdf(Xs + 4.0, 0, Stds)\n",
"Z = Z / Z.sum(axis=1)[:, np.newaxis] / (xs[1] - xs[0])\n",
"\n",
"x_example_idx = (xs >= x_val).argmax() # index of the first value >= x_val\n",
"max_idx = Z[:, x_example_idx].argmax()\n",
"max_val = Z[:, x_example_idx].max()\n",
"s_example_idx = (stds >= std_val).argmax()\n",
"x_range_min_idx = (xs >= x_proba_range[0]).argmax()\n",
"x_range_max_idx = (xs >= x_proba_range[1]).argmax()\n",
"log_max_idx = np.log(Z[:, x_example_idx]).argmax()\n",
"log_max_val = np.log(Z[:, x_example_idx]).max()\n",
"\n",
"plt.figure(figsize=(8, 4.5))\n",
"\n",
"plt.subplot(2, 2, 1)\n",
"plt.contourf(Xs, Stds, Z, cmap=\"GnBu\")\n",
"plt.plot([-6, 4], [std_val, std_val], \"k-\", linewidth=2)\n",
"plt.plot([x_val, x_val], [1, 2], \"b-\", linewidth=2)\n",
"plt.ylabel(r\"$\\theta$\", rotation=0, labelpad=10)\n",
"plt.title(r\"Model $f(x; \\theta)$\")\n",
"\n",
"plt.subplot(2, 2, 2)\n",
"plt.plot(stds, Z[:, x_example_idx], \"b-\")\n",
"plt.plot(stds[max_idx], max_val, \"r.\")\n",
"plt.plot([stds[max_idx], stds[max_idx]], [0, max_val], \"r:\")\n",
"plt.plot([0, stds[max_idx]], [max_val, max_val], \"r:\")\n",
"plt.text(stds[max_idx] + 0.01, 0.081, r\"$\\hat{\\theta}$\")\n",
"plt.text(stds[max_idx] + 0.01, max_val - 0.006, r\"$Max$\")\n",
"plt.text(1.01, max_val - 0.008, r\"$\\hat{\\mathcal{L}}$\")\n",
"plt.ylabel(r\"$\\mathcal{L}$\", rotation=0, labelpad=10)\n",
"plt.title(fr\"$\\mathcal{{L}}(\\theta|x={x_val}) = f(x={x_val}; \\theta)$\")\n",
"plt.grid()\n",
"plt.axis([1, 2, 0.08, 0.12])\n",
"\n",
"plt.subplot(2, 2, 3)\n",
"plt.plot(xs, Z[s_example_idx], \"k-\")\n",
"plt.fill_between(xs[x_range_min_idx:x_range_max_idx+1],\n",
" Z[s_example_idx, x_range_min_idx:x_range_max_idx+1], alpha=0.2)\n",
"plt.xlabel(r\"$x$\")\n",
"plt.ylabel(\"PDF\")\n",
"plt.title(fr\"PDF $f(x; \\theta={std_val})$\")\n",
"plt.grid()\n",
"plt.axis([-6, 4, 0, 0.25])\n",
"\n",
"plt.subplot(2, 2, 4)\n",
"plt.plot(stds, np.log(Z[:, x_example_idx]), \"b-\")\n",
"plt.plot(stds[log_max_idx], log_max_val, \"r.\")\n",
"plt.plot([stds[log_max_idx], stds[log_max_idx]], [-5, log_max_val], \"r:\")\n",
"plt.plot([0, stds[log_max_idx]], [log_max_val, log_max_val], \"r:\")\n",
"plt.text(stds[log_max_idx] + 0.01, log_max_val - 0.06, r\"$Max$\")\n",
"plt.text(stds[log_max_idx] + 0.01, -2.49, r\"$\\hat{\\theta}$\")\n",
"plt.text(1.01, log_max_val - 0.08, r\"$\\log \\, \\hat{\\mathcal{L}}$\")\n",
"plt.xlabel(r\"$\\theta$\")\n",
"plt.ylabel(r\"$\\log\\mathcal{L}$\", rotation=0, labelpad=10)\n",
"plt.title(fr\"$\\log \\, \\mathcal{{L}}(\\theta|x={x_val})$\")\n",
"plt.grid()\n",
"plt.axis([1, 2, -2.5, -2.1])\n",
"\n",
"save_fig(\"likelihood_function_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 107,
"metadata": {},
"outputs": [],
"source": [
"gm.bic(X)"
]
},
{
"cell_type": "code",
"execution_count": 108,
"metadata": {},
"outputs": [],
"source": [
"gm.aic(X)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We could compute the BIC and AIC manually like this:"
]
},
{
"cell_type": "code",
"execution_count": 109,
"metadata": {},
"outputs": [],
"source": [
"# extra code – bonus material\n",
"n_clusters = 3\n",
"n_dims = 2\n",
"n_params_for_weights = n_clusters - 1\n",
"n_params_for_means = n_clusters * n_dims\n",
"n_params_for_covariance = n_clusters * n_dims * (n_dims + 1) // 2\n",
"n_params = n_params_for_weights + n_params_for_means + n_params_for_covariance\n",
"max_log_likelihood = gm.score(X) * len(X) # log(L^)\n",
"bic = np.log(len(X)) * n_params - 2 * max_log_likelihood\n",
"aic = 2 * n_params - 2 * max_log_likelihood\n",
"print(f\"bic = {bic}\")\n",
"print(f\"aic = {aic}\")\n",
"print(f\"n_params = {n_params}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There's one weight per cluster, but the weights must sum to 1, so they have one degree of freedom less, hence the -1. Similarly, an $n \\times n$ covariance matrix does not have $n^2$ degrees of freedom: since it is symmetric, only the entries on and below the diagonal are free, so it has $1 + 2 + \\dots + n = \\dfrac{n (n+1)}{2}$."
]
},
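{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same counting extends to the other covariance types: a full or tied matrix has $n(n+1)/2$ free entries, a diagonal one has $n$, and a spherical one has 1. The sketch below defines a hypothetical helper, `n_gmm_params()`, that applies this count for each `covariance_type`. It then checks the resulting BIC against Scikit-Learn's `bic()` method on a toy dataset (an assumption; the `_df`-suffixed names are illustrative):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import make_blobs\n",
"from sklearn.mixture import GaussianMixture\n",
"\n",
"def n_gmm_params(n_clusters, n_dims, covariance_type='full'):\n",
"    # hypothetical helper: number of free parameters of a Gaussian mixture\n",
"    cov_params = {\n",
"        'full': n_clusters * n_dims * (n_dims + 1) // 2,  # one symmetric matrix per cluster\n",
"        'tied': n_dims * (n_dims + 1) // 2,               # one shared symmetric matrix\n",
"        'diag': n_clusters * n_dims,                      # one variance per axis per cluster\n",
"        'spherical': n_clusters,                          # one variance per cluster\n",
"    }[covariance_type]\n",
"    mean_params = n_clusters * n_dims\n",
"    weight_params = n_clusters - 1  # the weights must sum to 1\n",
"    return cov_params + mean_params + weight_params\n",
"\n",
"X_df, _ = make_blobs(n_samples=500, centers=3, random_state=42)\n",
"bic_matches_df = {}\n",
"for cov_type in ('full', 'tied', 'diag', 'spherical'):\n",
"    gm_df = GaussianMixture(n_components=3, covariance_type=cov_type, random_state=42).fit(X_df)\n",
"    p_df = n_gmm_params(3, 2, cov_type)\n",
"    manual_bic_df = np.log(len(X_df)) * p_df - 2 * gm_df.score(X_df) * len(X_df)\n",
"    bic_matches_df[cov_type] = bool(np.isclose(manual_bic_df, gm_df.bic(X_df)))\n",
"bic_matches_df"
]
},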
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's train Gaussian mixture models with various values of $k$ and measure their BIC and AIC:"
]
},
{
"cell_type": "code",
"execution_count": 110,
"metadata": {},
"outputs": [],
"source": [
"# extra code – this cell generates and saves Figure 9–20\n",
"\n",
"gms_per_k = [GaussianMixture(n_components=k, n_init=10, random_state=42).fit(X)\n",
" for k in range(1, 11)]\n",
"bics = [model.bic(X) for model in gms_per_k]\n",
"aics = [model.aic(X) for model in gms_per_k]\n",
"\n",
"plt.figure(figsize=(8, 3))\n",
"plt.plot(range(1, 11), bics, \"bo-\", label=\"BIC\")\n",
"plt.plot(range(1, 11), aics, \"go--\", label=\"AIC\")\n",
"plt.xlabel(\"$k$\")\n",
"plt.ylabel(\"Information Criterion\")\n",
"plt.axis([1, 9.5, min(aics) - 50, max(aics) + 50])\n",
"plt.annotate(\"\", xy=(3, bics[2]), xytext=(3.4, 8650),\n",
" arrowprops=dict(facecolor='black', shrink=0.1))\n",
"plt.text(3.5, 8660, \"Minimum\", horizontalalignment=\"center\")\n",
"plt.legend()\n",
"plt.grid()\n",
"save_fig(\"aic_bic_vs_k_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Bayesian Gaussian Mixture Models"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Rather than manually searching for the optimal number of clusters, you can use the `BayesianGaussianMixture` class instead, which is capable of giving weights equal (or close) to zero to unnecessary clusters. Just set the number of components to a value that you believe is greater than the optimal number of clusters, and the algorithm will eliminate the unnecessary clusters automatically."
]
},
{
"cell_type": "code",
"execution_count": 111,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.mixture import BayesianGaussianMixture\n",
"\n",
"bgm = BayesianGaussianMixture(n_components=10, n_init=10, random_state=42)\n",
"bgm.fit(X)\n",
"bgm.weights_.round(2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The algorithm automatically detected that only 3 components are needed!"
]
},
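{
"cell_type": "markdown",
"metadata": {},
"source": [
"Rather than reading the weights by eye, you could count the components whose weight exceeds a small cutoff. The sketch below is self-contained: it generates its own three blobs with `make_blobs()` and uses new variable names, so the `bgm` model above is left untouched. The 1% cutoff is an arbitrary assumption, not a Scikit-Learn rule."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import make_blobs\n",
"from sklearn.mixture import BayesianGaussianMixture\n",
"\n",
"# synthetic data with 3 blobs (an assumption of this sketch)\n",
"X_blobs, _ = make_blobs(n_samples=1000, centers=3, random_state=42)\n",
"\n",
"bgm_demo = BayesianGaussianMixture(n_components=10, n_init=3,\n",
"                                   max_iter=500, random_state=42)\n",
"bgm_demo.fit(X_blobs)\n",
"\n",
"# count the components whose weight is above an arbitrary 1% cutoff\n",
"n_active = (bgm_demo.weights_ > 0.01).sum()\n",
"print(bgm_demo.weights_.round(2))\n",
"print('active components:', n_active)"
]
},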
{
"cell_type": "code",
"execution_count": 112,
"metadata": {},
"outputs": [],
"source": [
"# extra code – this figure is almost identical to Figure 9–16\n",
"plt.figure(figsize=(8, 5))\n",
"plot_gaussian_mixture(bgm, X)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 113,
"metadata": {},
"outputs": [],
"source": [
"# extra code – this cell generates and saves Figure 9–21\n",
"\n",
"X_moons, y_moons = make_moons(n_samples=1000, noise=0.05, random_state=42)\n",
"\n",
"bgm = BayesianGaussianMixture(n_components=10, n_init=10, random_state=42)\n",
"bgm.fit(X_moons)\n",
"\n",
"plt.figure(figsize=(9, 3.2))\n",
"\n",
"plt.subplot(121)\n",
"plot_data(X_moons)\n",
"plt.xlabel(\"$x_1$\")\n",
"plt.ylabel(\"$x_2$\", rotation=0)\n",
"plt.grid()\n",
"\n",
"plt.subplot(122)\n",
"plot_gaussian_mixture(bgm, X_moons, show_ylabels=False)\n",
"\n",
"save_fig(\"moons_vs_bgm_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Oops, not great... instead of detecting 2 moon-shaped clusters, the algorithm detected 8 ellipsoidal clusters. However, the density plot does not look too bad, so it might be usable for anomaly detection."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Exercise solutions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. to 9."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. In Machine Learning, clustering is the unsupervised task of grouping similar instances together. The notion of similarity depends on the task at hand: for example, in some cases two nearby instances will be considered similar, while in others similar instances may be far apart as long as they belong to the same densely packed group. Popular clustering algorithms include K-Means, DBSCAN, agglomerative clustering, BIRCH, Mean-Shift, affinity propagation, and spectral clustering.\n",
"2. The main applications of clustering algorithms include data analysis, customer segmentation, recommender systems, search engines, image segmentation, semi-supervised learning, dimensionality reduction, anomaly detection, and novelty detection.\n",
"3. The elbow rule is a simple technique to select the number of clusters when using K-Means: just plot the inertia (the sum of the squared distances from each instance to its nearest centroid) as a function of the number of clusters, and find the point in the curve where the inertia stops dropping fast (the \"elbow\"). This is generally close to the optimal number of clusters. Another approach is to plot the silhouette score as a function of the number of clusters. There will often be a peak, and the optimal number of clusters is generally nearby. The silhouette score is the mean silhouette coefficient over all instances. This coefficient varies from +1 for instances that are well inside their cluster and far from other clusters, to –1 for instances that are very close to another cluster. You may also plot the silhouette diagrams and perform a more thorough analysis.\n",
"4. Labeling a dataset is costly and time-consuming. Therefore, it is common to have plenty of unlabeled instances, but few labeled instances. Label propagation is a technique that consists of copying some (or all) of the labels from the labeled instances to similar unlabeled instances. This can greatly extend the number of labeled instances, and thereby allow a supervised algorithm to reach better performance (this is a form of semi-supervised learning). One approach is to use a clustering algorithm such as K-Means on all the instances, then for each cluster find the most common label or the label of the most representative instance (i.e., the one closest to the centroid) and propagate it to the unlabeled instances in the same cluster.\n",
"5. K-Means and BIRCH scale well to large datasets. DBSCAN and Mean-Shift look for regions of high density.\n",
"6. Active learning is useful whenever you have plenty of unlabeled instances but labeling is costly. In this case (which is very common), rather than randomly selecting instances to label, it is often preferable to perform active learning, where human experts interact with the learning algorithm, providing labels for specific instances when the algorithm requests them. A common approach is uncertainty sampling (see the _Active Learning_ section in chapter 9).\n",
"7. Many people use the terms _anomaly detection_ and _novelty detection_ interchangeably, but they are not exactly the same. In anomaly detection, the algorithm is trained on a dataset that may contain outliers, and the goal is typically to identify these outliers (within the training set), as well as outliers among new instances. In novelty detection, the algorithm is trained on a dataset that is presumed to be \"clean,\" and the objective is to detect novelties strictly among new instances. Some algorithms work best for anomaly detection (e.g., Isolation Forest), while others are better suited for novelty detection (e.g., one-class SVM).\n",
"8. A Gaussian mixture model (GMM) is a probabilistic model that assumes that the instances were generated from a mixture of several Gaussian distributions whose parameters are unknown. In other words, the assumption is that the data is grouped into a finite number of clusters, each with an ellipsoidal shape (but the clusters may have different ellipsoidal shapes, sizes, orientations, and densities), and we don't know which cluster each instance belongs to. This model is useful for density estimation, clustering, and anomaly detection.\n",
"9. One way to find the right number of clusters when using a Gaussian mixture model is to plot the Bayesian information criterion (BIC) or the Akaike information criterion (AIC) as a function of the number of clusters, then choose the number of clusters that minimizes the BIC or AIC. Another technique is to use a Bayesian Gaussian mixture model, which automatically selects the number of clusters."
]
},
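{
"cell_type": "markdown",
"metadata": {},
"source": [
"The K-Means approach to label propagation described in answer 4 can be sketched in a few lines. This is a minimal illustration under assumptions of our own, not a reference solution. It uses Scikit-Learn's small digits dataset and 50 clusters. It also pretends that a human labeled only the 50 representative instances (the ones closest to each centroid). Each cluster then inherits the label of its representative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.cluster import KMeans\n",
"from sklearn.datasets import load_digits\n",
"\n",
"X_dig, y_dig = load_digits(return_X_y=True)\n",
"\n",
"k = 50  # number of clusters (an arbitrary choice for this sketch)\n",
"kmeans_lp = KMeans(n_clusters=k, n_init=10, random_state=42)\n",
"X_dist = kmeans_lp.fit_transform(X_dig)  # distance to each centroid\n",
"\n",
"# representative instance = the instance closest to each centroid\n",
"representative_idx = np.argmin(X_dist, axis=0)\n",
"# pretend a human labels only these k instances\n",
"representative_labels = y_dig[representative_idx]\n",
"\n",
"# propagate each representative's label to every instance in its cluster\n",
"y_propagated = representative_labels[kmeans_lp.labels_]\n",
"\n",
"# y_dig is fully known here, so we can measure the propagation accuracy\n",
"(y_propagated == y_dig).mean()"
]
},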
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 10. Cluster the Olivetti Faces Dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Exercise: The classic Olivetti faces dataset contains 400 grayscale 64 × 64–pixel images of faces. Each image is flattened to a 1D vector of size 4,096. 40 different people were photographed (10 times each), and the usual task is to train a model that can predict which person is represented in each picture. Load the dataset using the `sklearn.datasets.fetch_olivetti_faces()` function.*"
]
},
{
"cell_type": "code",
"execution_count": 114,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import fetch_olivetti_faces\n",
"\n",
"olivetti = fetch_olivetti_faces()"
]
},
{
"cell_type": "code",
"execution_count": 115,
"metadata": {},
"outputs": [],
"source": [
"print(olivetti.DESCR)"
]
},
{
"cell_type": "code",
"execution_count": 116,
"metadata": {},
"outputs": [],
"source": [
"olivetti.target"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Exercise: Then split it into a training set, a validation set, and a test set (note that the dataset is already scaled between 0 and 1). Since the dataset is quite small, you probably want to use stratified sampling to ensure that there are the same number of images per person in each set.*"
]
},
{
"cell_type": "code",
"execution_count": 117,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import StratifiedShuffleSplit\n",
"\n",
"strat_split = StratifiedShuffleSplit(n_splits=1, test_size=40, random_state=42)\n",
"train_valid_idx, test_idx = next(strat_split.split(olivetti.data,\n",
" olivetti.target))\n",
"X_train_valid = olivetti.data[train_valid_idx]\n",
"y_train_valid = olivetti.target[train_valid_idx]\n",
"X_test = olivetti.data[test_idx]\n",
"y_test = olivetti.target[test_idx]\n",
"\n",
"strat_split = StratifiedShuffleSplit(n_splits=1, test_size=80, random_state=43)\n",
"train_idx, valid_idx = next(strat_split.split(X_train_valid, y_train_valid))\n",
"X_train = X_train_valid[train_idx]\n",
"y_train = y_train_valid[train_idx]\n",
"X_valid = X_train_valid[valid_idx]\n",
"y_valid = y_train_valid[valid_idx]"
]
},
{
"cell_type": "code",
"execution_count": 118,
"metadata": {},
"outputs": [],
"source": [
"print(X_train.shape, y_train.shape)\n",
"print(X_valid.shape, y_valid.shape)\n",
"print(X_test.shape, y_test.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To speed things up, we'll reduce the data's dimensionality using PCA:"
]
},
{
"cell_type": "code",
"execution_count": 119,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.decomposition import PCA\n",
"\n",
"pca = PCA(0.99)\n",
"X_train_pca = pca.fit_transform(X_train)\n",
"X_valid_pca = pca.transform(X_valid)\n",
"X_test_pca = pca.transform(X_test)\n",
"\n",
"pca.n_components_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Exercise: Next, cluster the images using K-Means, and ensure that you have a good number of clusters (using one of the techniques discussed in this chapter).*"
]
},
{
"cell_type": "code",
"execution_count": 120,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.cluster import KMeans\n",
"\n",
"k_range = range(5, 150, 5)\n",
"kmeans_per_k = []\n",
"for k in k_range:\n",
" print(f\"k={k}\")\n",
" kmeans = KMeans(n_clusters=k, random_state=42)\n",
" kmeans.fit(X_train_pca)\n",
" kmeans_per_k.append(kmeans)"
]
},
{
"cell_type": "code",
"execution_count": 121,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import silhouette_score\n",
"\n",
"silhouette_scores = [silhouette_score(X_train_pca, model.labels_)\n",
" for model in kmeans_per_k]\n",
"best_index = np.argmax(silhouette_scores)\n",
"best_k = k_range[best_index]\n",
"best_score = silhouette_scores[best_index]\n",
"\n",
"plt.figure(figsize=(8, 3))\n",
"plt.plot(k_range, silhouette_scores, \"bo-\")\n",
"plt.xlabel(\"$k$\")\n",
"plt.ylabel(\"Silhouette score\")\n",
"plt.plot(best_k, best_score, \"rs\")\n",
"plt.grid()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 122,
"metadata": {},
"outputs": [],
"source": [
"best_k"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It looks like the best number of clusters is quite high, at 120. You might have expected it to be 40, since there are 40 different people in the pictures. However, the same person may look quite different in different pictures (e.g., with or without glasses, or simply shifted left or right)."
]
},
{
"cell_type": "code",
"execution_count": 123,
"metadata": {},
"outputs": [],
"source": [
"inertias = [model.inertia_ for model in kmeans_per_k]\n",
"best_inertia = inertias[best_index]\n",
"\n",
"plt.figure(figsize=(8, 3.5))\n",
"plt.plot(k_range, inertias, \"bo-\")\n",
"plt.xlabel(\"$k$\")\n",
"plt.ylabel(\"Inertia\")\n",
"plt.plot(best_k, best_inertia, \"rs\")\n",
"plt.grid()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The optimal number of clusters is not clear on this inertia diagram, as there is no obvious elbow, so let's stick with k=120."
]
},
{
"cell_type": "code",
"execution_count": 124,
"metadata": {},
"outputs": [],
"source": [
"best_model = kmeans_per_k[best_index]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Exercise: Visualize the clusters: do you see similar faces in each cluster?*"
]
},
{
"cell_type": "code",
"execution_count": 125,
"metadata": {},
"outputs": [],
"source": [
"def plot_faces(faces, labels, n_cols=5):\n",
" faces = faces.reshape(-1, 64, 64)\n",
" n_rows = (len(faces) - 1) // n_cols + 1\n",
" plt.figure(figsize=(n_cols, n_rows * 1.1))\n",
" for index, (face, label) in enumerate(zip(faces, labels)):\n",
" plt.subplot(n_rows, n_cols, index + 1)\n",
" plt.imshow(face, cmap=\"gray\")\n",
" plt.axis(\"off\")\n",
" plt.title(label)\n",
" plt.show()\n",
"\n",
"for cluster_id in np.unique(best_model.labels_):\n",
" print(\"Cluster\", cluster_id)\n",
" in_cluster = best_model.labels_==cluster_id\n",
" faces = X_train[in_cluster]\n",
" labels = y_train[in_cluster]\n",
" plot_faces(faces, labels)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"About 2 out of 3 clusters are useful: that is, they contain at least 2 pictures, all of the same person. However, the rest of the clusters have either one or more intruders, or they have just a single picture.\n",
"\n",
"Clustering images this way may be too imprecise to be directly useful when training a model (as we will see below), but it can be tremendously useful when labeling images in a new dataset: it will usually make labeling much faster."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 11. Using Clustering as Preprocessing for Classification"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Exercise: Continuing with the Olivetti faces dataset, train a classifier to predict which person is represented in each picture, and evaluate it on the validation set.*"
]
},
{
"cell_type": "code",
"execution_count": 126,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"\n",
"clf = RandomForestClassifier(n_estimators=150, random_state=42)\n",
"clf.fit(X_train_pca, y_train)\n",
"clf.score(X_valid_pca, y_valid)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Exercise: Next, use K-Means as a dimensionality reduction tool, and train a classifier on the reduced set.*"
]
},
{
"cell_type": "code",
"execution_count": 127,
"metadata": {},
"outputs": [],
"source": [
"X_train_reduced = best_model.transform(X_train_pca)\n",
"X_valid_reduced = best_model.transform(X_valid_pca)\n",
"X_test_reduced = best_model.transform(X_test_pca)\n",
"\n",
"clf = RandomForestClassifier(n_estimators=150, random_state=42)\n",
"clf.fit(X_train_reduced, y_train)\n",
" \n",
"clf.score(X_valid_reduced, y_valid)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Yikes! That's not better at all! Let's see if tuning the number of clusters helps."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Exercise: Search for the number of clusters that allows the classifier to get the best performance: what performance can you reach?*"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We could use a `GridSearchCV` like we did earlier in this notebook, but since we already have a validation set, we don't need K-fold cross-validation. And since we're only exploring a single hyperparameter, it's simpler to just run a loop manually:"
]
},
{
"cell_type": "code",
"execution_count": 128,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import make_pipeline\n",
"\n",
"for n_clusters in k_range:\n",
" pipeline = make_pipeline(\n",
" KMeans(n_clusters=n_clusters, random_state=42),\n",
" RandomForestClassifier(n_estimators=150, random_state=42)\n",
" )\n",
" pipeline.fit(X_train_pca, y_train)\n",
" print(n_clusters, pipeline.score(X_valid_pca, y_valid))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Oh well, even by tuning the number of clusters, we never get beyond 80% accuracy. Looks like the distances to the cluster centroids are not as informative as the original images."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Exercise: What if you append the features from the reduced set to the original features (again, searching for the best number of clusters)?*"
]
},
{
"cell_type": "code",
"execution_count": 129,
"metadata": {},
"outputs": [],
"source": [
"X_train_extended = np.c_[X_train_pca, X_train_reduced]\n",
"X_valid_extended = np.c_[X_valid_pca, X_valid_reduced]\n",
"X_test_extended = np.c_[X_test_pca, X_test_reduced]"
]
},
{
"cell_type": "code",
"execution_count": 130,
"metadata": {},
"outputs": [],
"source": [
"clf = RandomForestClassifier(n_estimators=150, random_state=42)\n",
"clf.fit(X_train_extended, y_train)\n",
"clf.score(X_valid_extended, y_valid)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That's a bit better, but still worse than without the cluster features. The clusters are not useful to directly train a classifier in this case (but they can still help when labeling new training instances)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 12. A Gaussian Mixture Model for the Olivetti Faces Dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Exercise: Train a Gaussian mixture model on the Olivetti faces dataset. To speed up the algorithm, you should probably reduce the dataset's dimensionality (e.g., use PCA, preserving 99% of the variance).*"
]
},
{
"cell_type": "code",
"execution_count": 131,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.mixture import GaussianMixture\n",
"\n",
"gm = GaussianMixture(n_components=40, random_state=42)\n",
"y_pred = gm.fit_predict(X_train_pca)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Exercise: Use the model to generate some new faces (using the `sample()` method), and visualize them (if you used PCA, you will need to use its `inverse_transform()` method).*"
]
},
{
"cell_type": "code",
"execution_count": 132,
"metadata": {},
"outputs": [],
"source": [
"n_gen_faces = 20\n",
"gen_faces_reduced, y_gen_faces = gm.sample(n_samples=n_gen_faces)\n",
"gen_faces = pca.inverse_transform(gen_faces_reduced)"
]
},
{
"cell_type": "code",
"execution_count": 133,
"metadata": {},
"outputs": [],
"source": [
"plot_faces(gen_faces, y_gen_faces)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Exercise: Try to modify some images (e.g., rotate, flip, darken) and see if the model can detect the anomalies (i.e., compare the output of the `score_samples()` method for normal images and for anomalies).*"
]
},
{
"cell_type": "code",
"execution_count": 134,
"metadata": {},
"outputs": [],
"source": [
"n_rotated = 4\n",
"rotated = np.transpose(X_train[:n_rotated].reshape(-1, 64, 64), axes=[0, 2, 1])\n",
"rotated = rotated.reshape(-1, 64*64)\n",
"y_rotated = y_train[:n_rotated]\n",
"\n",
"n_flipped = 3\n",
"flipped = X_train[:n_flipped].reshape(-1, 64, 64)[:, ::-1]\n",
"flipped = flipped.reshape(-1, 64*64)\n",
"y_flipped = y_train[:n_flipped]\n",
"\n",
"n_darkened = 3\n",
"darkened = X_train[:n_darkened].copy()\n",
"darkened[:, 1:-1] *= 0.3\n",
"y_darkened = y_train[:n_darkened]\n",
"\n",
"X_bad_faces = np.r_[rotated, flipped, darkened]\n",
"y_bad = np.concatenate([y_rotated, y_flipped, y_darkened])\n",
"\n",
"plot_faces(X_bad_faces, y_bad)"
]
},
{
"cell_type": "code",
"execution_count": 135,
"metadata": {},
"outputs": [],
"source": [
"X_bad_faces_pca = pca.transform(X_bad_faces)"
]
},
{
"cell_type": "code",
"execution_count": 136,
"metadata": {},
"outputs": [],
"source": [
"gm.score_samples(X_bad_faces_pca)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The bad faces are all considered highly unlikely by the Gaussian Mixture model. Compare this to the scores of some training instances:"
]
},
{
"cell_type": "code",
"execution_count": 137,
"metadata": {},
"outputs": [],
"source": [
"gm.score_samples(X_train_pca[:10])"
]
},
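{
"cell_type": "markdown",
"metadata": {},
"source": [
"To turn these scores into an anomaly detector, one option is to flag any instance whose log density falls below a low percentile of the training log densities. The sketch below is self-contained and uses synthetic 2D blobs with two hand-placed outliers instead of the faces. The 2nd-percentile threshold is an assumption of this example and should be tuned to the expected proportion of anomalies."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import make_blobs\n",
"from sklearn.mixture import GaussianMixture\n",
"\n",
"X_normal, _ = make_blobs(n_samples=500, centers=3, random_state=42)\n",
"X_outliers = np.array([[20.0, 20.0], [-20.0, 20.0]])  # far from all blobs\n",
"\n",
"gm_demo = GaussianMixture(n_components=3, n_init=5, random_state=42)\n",
"gm_demo.fit(X_normal)\n",
"\n",
"# score_samples() returns the log of the probability density\n",
"demo_densities = gm_demo.score_samples(X_normal)\n",
"demo_threshold = np.percentile(demo_densities, 2)\n",
"\n",
"# True means the instance is flagged as an anomaly\n",
"gm_demo.score_samples(X_outliers) < demo_threshold"
]
},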
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 13. Using Dimensionality Reduction Techniques for Anomaly Detection"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Exercise: Some dimensionality reduction techniques can also be used for anomaly detection. For example, take the Olivetti faces dataset and reduce it with PCA, preserving 99% of the variance. Then compute the reconstruction error for each image. Next, take some of the modified images you built in the previous exercise, and look at their reconstruction error: notice how much larger the reconstruction error is. If you plot a reconstructed image, you will see why: it tries to reconstruct a normal face.*"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We already reduced the dataset using PCA earlier:"
]
},
{
"cell_type": "code",
"execution_count": 138,
"metadata": {},
"outputs": [],
"source": [
"X_train_pca.round(2)"
]
},
{
"cell_type": "code",
"execution_count": 139,
"metadata": {},
"outputs": [],
"source": [
"def reconstruction_errors(pca, X):\n",
" X_pca = pca.transform(X)\n",
" X_reconstructed = pca.inverse_transform(X_pca)\n",
" mse = np.square(X_reconstructed - X).mean(axis=-1)\n",
" return mse"
]
},
{
"cell_type": "code",
"execution_count": 140,
"metadata": {},
"outputs": [],
"source": [
"reconstruction_errors(pca, X_train).mean()"
]
},
{
"cell_type": "code",
"execution_count": 141,
"metadata": {},
"outputs": [],
"source": [
"reconstruction_errors(pca, X_bad_faces).mean()"
]
},
{
"cell_type": "code",
"execution_count": 142,
"metadata": {},
"outputs": [],
"source": [
"plot_faces(X_bad_faces, y_bad)"
]
},
{
"cell_type": "code",
"execution_count": 143,
"metadata": {},
"outputs": [],
"source": [
"X_bad_faces_reconstructed = pca.inverse_transform(X_bad_faces_pca)\n",
"plot_faces(X_bad_faces_reconstructed, y_bad)"
]
},
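{
"cell_type": "markdown",
"metadata": {},
"source": [
"The exercise compares average reconstruction errors. To flag individual images, one option is to threshold each instance's error, for example at a high percentile of the training errors. Below is a self-contained sketch on synthetic data: 500 points near a random 3D subspace of a 20D space, plus 5 random points off that subspace. The data generator and the 99th-percentile cutoff are assumptions of this example, not part of the exercise."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.decomposition import PCA\n",
"\n",
"rng = np.random.default_rng(42)\n",
"basis = rng.normal(size=(3, 20))  # spans a random 3D subspace of 20D space\n",
"X_inliers = rng.normal(size=(500, 3)) @ basis\n",
"X_inliers += 0.01 * rng.normal(size=X_inliers.shape)  # small noise\n",
"X_off = rng.normal(size=(5, 20))  # random points, not on the subspace\n",
"\n",
"def reconstruction_mse(pca, X):\n",
"    # per-instance MSE between X and its PCA reconstruction\n",
"    X_reconstructed = pca.inverse_transform(pca.transform(X))\n",
"    return np.square(X_reconstructed - X).mean(axis=-1)\n",
"\n",
"pca_demo = PCA(0.99).fit(X_inliers)\n",
"threshold = np.percentile(reconstruction_mse(pca_demo, X_inliers), 99)\n",
"\n",
"# True means the instance is flagged as an anomaly\n",
"reconstruction_mse(pca_demo, X_off) > threshold"
]
},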
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}