{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Support Vector Machines**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_This notebook is an extra chapter on Support Vector Machines. It also includes exercises and their solutions at the end._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table align=\"left\">\n",
" <td>\n",
" <a href=\"https://colab.research.google.com/github/ageron/handson-ml3/blob/main/05_support_vector_machines.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/ageron/handson-ml3/blob/main/05_support_vector_machines.ipynb\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" /></a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This project requires Python 3.7 or above:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"assert sys.version_info >= (3, 7)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It also requires Scikit-Learn ≥ 1.0.1:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import sklearn\n",
"from packaging import version\n",
"\n",
"assert version.parse(sklearn.__version__) >= version.parse(\"1.0.1\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we did in previous chapters, let's define the default font sizes to make the figures prettier:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"plt.rc('font', size=14)\n",
"plt.rc('axes', labelsize=14, titlesize=14)\n",
"plt.rc('legend', fontsize=14)\n",
"plt.rc('xtick', labelsize=10)\n",
"plt.rc('ytick', labelsize=10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And let's create the `images/svm` folder (if it doesn't already exist), and define the `save_fig()` function, which is used throughout this notebook to save the figures in high resolution for the book:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"IMAGES_PATH = Path() / \"images\" / \"svm\"\n",
"IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
"    path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
"    if tight_layout:\n",
"        plt.tight_layout()\n",
"    plt.savefig(path, format=fig_extension, dpi=resolution)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Linear SVM Classification"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The book starts with a few figures, before the first code example, so the next three cells generate and save these figures. You can skip them if you want."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 720x194.4 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 5-1\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from sklearn.svm import SVC\n",
"from sklearn import datasets\n",
"\n",
"iris = datasets.load_iris(as_frame=True)\n",
"X = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
"y = iris.target\n",
"\n",
"setosa_or_versicolor = (y == 0) | (y == 1)\n",
"X = X[setosa_or_versicolor]\n",
"y = y[setosa_or_versicolor]\n",
"\n",
"# SVM Classifier model\n",
"svm_clf = SVC(kernel=\"linear\", C=float(\"inf\"))\n",
"svm_clf.fit(X, y)\n",
"\n",
"# Bad models\n",
"x0 = np.linspace(0, 5.5, 200)\n",
"pred_1 = 5 * x0 - 20\n",
"pred_2 = x0 - 1.8\n",
"pred_3 = 0.1 * x0 + 0.5\n",
"\n",
"def plot_svc_decision_boundary(svm_clf, xmin, xmax):\n",
"    w = svm_clf.coef_[0]\n",
"    b = svm_clf.intercept_[0]\n",
"\n",
"    # At the decision boundary, w0*x0 + w1*x1 + b = 0\n",
"    # => x1 = -w0/w1 * x0 - b/w1\n",
"    x0 = np.linspace(xmin, xmax, 200)\n",
"    decision_boundary = -w[0] / w[1] * x0 - b / w[1]\n",
"\n",
"    # Vertical (x1) offset of the gutters, where w0*x0 + w1*x1 + b = +/-1\n",
"    margin = 1/w[1]\n",
"    gutter_up = decision_boundary + margin\n",
"    gutter_down = decision_boundary - margin\n",
"    svs = svm_clf.support_vectors_\n",
"\n",
"    plt.plot(x0, decision_boundary, \"k-\", linewidth=2, zorder=-2)\n",
"    plt.plot(x0, gutter_up, \"k--\", linewidth=2, zorder=-2)\n",
"    plt.plot(x0, gutter_down, \"k--\", linewidth=2, zorder=-2)\n",
"    plt.scatter(svs[:, 0], svs[:, 1], s=180, facecolors='#AAA',\n",
"                zorder=-1)\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(10, 2.7), sharey=True)\n",
"\n",
"plt.sca(axes[0])\n",
"plt.plot(x0, pred_1, \"g--\", linewidth=2)\n",
"plt.plot(x0, pred_2, \"m-\", linewidth=2)\n",
"plt.plot(x0, pred_3, \"r-\", linewidth=2)\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"bs\", label=\"Iris versicolor\")\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"yo\", label=\"Iris setosa\")\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.legend(loc=\"upper left\")\n",
"plt.axis([0, 5.5, 0, 2])\n",
"plt.gca().set_aspect(\"equal\")\n",
"plt.grid()\n",
"\n",
"plt.sca(axes[1])\n",
"plot_svc_decision_boundary(svm_clf, 0, 5.5)\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"bs\")\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"yo\")\n",
"plt.xlabel(\"Petal length\")\n",
"plt.axis([0, 5.5, 0, 2])\n",
"plt.gca().set_aspect(\"equal\")\n",
"plt.grid()\n",
"\n",
"save_fig(\"large_margin_classification_plot\")\n",
"plt.show()"
]
},
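{
"cell_type": "markdown",
"metadata": {},
"source": [
"A short sketch of the algebra behind `plot_svc_decision_boundary()` in the cell above, assuming $w_1 \\neq 0$ and writing $\\mathbf{w} = (w_0, w_1)$ for `svm_clf.coef_[0]` and $b$ for `svm_clf.intercept_[0]`. The decision boundary is the set of points where the decision function is zero, and the two gutters are the points where it equals $\\pm 1$:\n",
"\n",
"$$\n",
"\\begin{aligned}\n",
"w_0 x_0 + w_1 x_1 + b = 0 \\quad &\\Longrightarrow \\quad x_1 = -\\frac{w_0}{w_1} x_0 - \\frac{b}{w_1} \\\\\n",
"w_0 x_0 + w_1 x_1 + b = \\pm 1 \\quad &\\Longrightarrow \\quad x_1 = -\\frac{w_0}{w_1} x_0 - \\frac{b}{w_1} \\pm \\frac{1}{w_1}\n",
"\\end{aligned}\n",
"$$\n",
"\n",
"So `margin = 1/w[1]` is the vertical offset (along $x_1$) between the boundary and each gutter, which is all the plot needs. The perpendicular distance from the boundary to a gutter is $1 / \\lVert \\mathbf{w} \\rVert$."
]
},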
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 648x194.4 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 5-2\n",
"\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"Xs = np.array([[1, 50], [5, 20], [3, 80], [5, 60]]).astype(np.float64)\n",
"ys = np.array([0, 0, 1, 1])\n",
"svm_clf = SVC(kernel=\"linear\", C=100).fit(Xs, ys)\n",
"\n",
"scaler = StandardScaler()\n",
"X_scaled = scaler.fit_transform(Xs)\n",
"svm_clf_scaled = SVC(kernel=\"linear\", C=100).fit(X_scaled, ys)\n",
"\n",
"plt.figure(figsize=(9, 2.7))\n",
"plt.subplot(121)\n",
"plt.plot(Xs[:, 0][ys==1], Xs[:, 1][ys==1], \"bo\")\n",
"plt.plot(Xs[:, 0][ys==0], Xs[:, 1][ys==0], \"ms\")\n",
"plot_svc_decision_boundary(svm_clf, 0, 6)\n",
"plt.xlabel(\"$x_0$\")\n",
"plt.ylabel(\"$x_1$    \", rotation=0)\n",
"plt.title(\"Unscaled\")\n",
"plt.axis([0, 6, 0, 90])\n",
"plt.grid()\n",
"\n",
"plt.subplot(122)\n",
"plt.plot(X_scaled[:, 0][ys==1], X_scaled[:, 1][ys==1], \"bo\")\n",
"plt.plot(X_scaled[:, 0][ys==0], X_scaled[:, 1][ys==0], \"ms\")\n",
"plot_svc_decision_boundary(svm_clf_scaled, -2, 2)\n",
"plt.xlabel(\"$x'_0$\")\n",
"plt.ylabel(\"$x'_1$ \", rotation=0)\n",
"plt.title(\"Scaled\")\n",
"plt.axis([-2, 2, -2, 2])\n",
"plt.grid()\n",
"\n",
"save_fig(\"sensitivity_to_feature_scales_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Soft Margin Classification"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 720x194.4 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 5-3\n",
"\n",
"X_outliers = np.array([[3.4, 1.3], [3.2, 0.8]])\n",
"y_outliers = np.array([0, 0])\n",
"Xo1 = np.concatenate([X, X_outliers[:1]], axis=0)\n",
"yo1 = np.concatenate([y, y_outliers[:1]], axis=0)\n",
"Xo2 = np.concatenate([X, X_outliers[1:]], axis=0)\n",
"yo2 = np.concatenate([y, y_outliers[1:]], axis=0)\n",
"\n",
"svm_clf2 = SVC(kernel=\"linear\", C=10**9)\n",
"svm_clf2.fit(Xo2, yo2)\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(10, 2.7), sharey=True)\n",
"\n",
"plt.sca(axes[0])\n",
"plt.plot(Xo1[:, 0][yo1==1], Xo1[:, 1][yo1==1], \"bs\")\n",
"plt.plot(Xo1[:, 0][yo1==0], Xo1[:, 1][yo1==0], \"yo\")\n",
"plt.text(0.3, 1.0, \"Impossible!\", color=\"red\", fontsize=18)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.annotate(\n",
"    \"Outlier\",\n",
"    xy=(X_outliers[0][0], X_outliers[0][1]),\n",
"    xytext=(2.5, 1.7),\n",
"    ha=\"center\",\n",
"    arrowprops=dict(facecolor='black', shrink=0.1),\n",
")\n",
"plt.axis([0, 5.5, 0, 2])\n",
"plt.grid()\n",
"\n",
"plt.sca(axes[1])\n",
"plt.plot(Xo2[:, 0][yo2==1], Xo2[:, 1][yo2==1], \"bs\")\n",
"plt.plot(Xo2[:, 0][yo2==0], Xo2[:, 1][yo2==0], \"yo\")\n",
"plot_svc_decision_boundary(svm_clf2, 0, 5.5)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.annotate(\n",
"    \"Outlier\",\n",
"    xy=(X_outliers[1][0], X_outliers[1][1]),\n",
"    xytext=(3.2, 0.08),\n",
"    ha=\"center\",\n",
"    arrowprops=dict(facecolor='black', shrink=0.1),\n",
")\n",
"plt.axis([0, 5.5, 0, 2])\n",
"plt.grid()\n",
"\n",
"save_fig(\"sensitivity_to_outliers_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(steps=[('standardscaler', StandardScaler()),\n",
"                ('linearsvc', LinearSVC(C=1, random_state=42))])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.svm import LinearSVC\n",
"\n",
"iris = load_iris(as_frame=True)\n",
"X = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
"y = (iris.target == 2) # Iris virginica\n",
"\n",
"svm_clf = make_pipeline(StandardScaler(),\n",
"                        LinearSVC(C=1, random_state=42))\n",
"svm_clf.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ True, False])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_new = [[5.5, 1.7], [5.0, 1.5]]\n",
"svm_clf.predict(X_new)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 0.66163411, -0.22036063])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"svm_clf.decision_function(X_new)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 720x194.4 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 5-4\n",
"\n",
"scaler = StandardScaler()\n",
"svm_clf1 = LinearSVC(C=1, max_iter=10_000, random_state=42)\n",
"svm_clf2 = LinearSVC(C=100, max_iter=10_000, random_state=42)\n",
"\n",
"scaled_svm_clf1 = make_pipeline(scaler, svm_clf1)\n",
"scaled_svm_clf2 = make_pipeline(scaler, svm_clf2)\n",
"\n",
"scaled_svm_clf1.fit(X, y)\n",
"scaled_svm_clf2.fit(X, y)\n",
"\n",
"# Convert to unscaled parameters\n",
"b1 = svm_clf1.decision_function([-scaler.mean_ / scaler.scale_])\n",
"b2 = svm_clf2.decision_function([-scaler.mean_ / scaler.scale_])\n",
"w1 = svm_clf1.coef_[0] / scaler.scale_\n",
"w2 = svm_clf2.coef_[0] / scaler.scale_\n",
"svm_clf1.intercept_ = np.array([b1])\n",
"svm_clf2.intercept_ = np.array([b2])\n",
"svm_clf1.coef_ = np.array([w1])\n",
"svm_clf2.coef_ = np.array([w2])\n",
"\n",
"# Find support vectors (LinearSVC does not do this automatically)\n",
"t = y * 2 - 1\n",
"support_vectors_idx1 = (t * (X.dot(w1) + b1) < 1).ravel()\n",
"support_vectors_idx2 = (t * (X.dot(w2) + b2) < 1).ravel()\n",
"svm_clf1.support_vectors_ = X[support_vectors_idx1]\n",
"svm_clf2.support_vectors_ = X[support_vectors_idx2]\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(10, 2.7), sharey=True)\n",
"\n",
"plt.sca(axes[0])\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"g^\", label=\"Iris virginica\")\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"bs\", label=\"Iris versicolor\")\n",
"plot_svc_decision_boundary(svm_clf1, 4, 5.9)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.legend(loc=\"upper left\")\n",
"plt.title(f\"$C = {svm_clf1.C}$\")\n",
"plt.axis([4, 5.9, 0.8, 2.8])\n",
"plt.grid()\n",
"\n",
"plt.sca(axes[1])\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"g^\")\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"bs\")\n",
"plot_svc_decision_boundary(svm_clf2, 4, 5.99)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.title(f\"$C = {svm_clf2.C}$\")\n",
"plt.axis([4, 5.9, 0.8, 2.8])\n",
"plt.grid()\n",
"\n",
"save_fig(\"regularization_plot\")\n",
"plt.show()"
]
},
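{
"cell_type": "markdown",
"metadata": {},
"source": [
"A short sketch of why the \"Convert to unscaled parameters\" step in the cell above works, assuming `StandardScaler` maps $\\mathbf{x}$ to $(\\mathbf{x} - \\boldsymbol{\\mu}) \\oslash \\boldsymbol{\\sigma}$ elementwise, with $\\boldsymbol{\\mu}$ = `scaler.mean_` and $\\boldsymbol{\\sigma}$ = `scaler.scale_`. If the `LinearSVC` learned $\\mathbf{w}_s$ and $b_s$ on the scaled features, the pipeline's decision function on an unscaled point $\\mathbf{x}$ is:\n",
"\n",
"$$\n",
"\\mathbf{w}_s^\\top \\big( (\\mathbf{x} - \\boldsymbol{\\mu}) \\oslash \\boldsymbol{\\sigma} \\big) + b_s = (\\mathbf{w}_s \\oslash \\boldsymbol{\\sigma})^\\top \\mathbf{x} + \\big( b_s - \\mathbf{w}_s^\\top (\\boldsymbol{\\mu} \\oslash \\boldsymbol{\\sigma}) \\big)\n",
"$$\n",
"\n",
"The first term gives `w1 = svm_clf1.coef_[0] / scaler.scale_`. The second term is the scaled-space decision function evaluated at $-\\boldsymbol{\\mu} \\oslash \\boldsymbol{\\sigma}$ (the image of $\\mathbf{x} = \\mathbf{0}$), which is what `svm_clf1.decision_function([-scaler.mean_ / scaler.scale_])` computes. With $t = 2y - 1 \\in \\{-1, +1\\}$, the points flagged as support vectors are those with $t\\,(\\mathbf{w}^\\top \\mathbf{x} + b) < 1$, that is, points that lie inside the margin or on the wrong side of the boundary."
]
},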
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Nonlinear SVM Classification"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 720x216 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 5-5\n",
"\n",
"X1D = np.linspace(-4, 4, 9).reshape(-1, 1)\n",
"X2D = np.c_[X1D, X1D**2]\n",
"y = np.array([0, 0, 1, 1, 1, 1, 1, 0, 0])\n",
"\n",
"plt.figure(figsize=(10, 3))\n",
"\n",
"plt.subplot(121)\n",
"plt.grid(True, which='both')\n",
"plt.axhline(y=0, color='k')\n",
"plt.plot(X1D[:, 0][y==0], np.zeros(4), \"bs\")\n",
"plt.plot(X1D[:, 0][y==1], np.zeros(5), \"g^\")\n",
"plt.gca().get_yaxis().set_ticks([])\n",
"plt.xlabel(\"$x_1$\")\n",
"plt.axis([-4.5, 4.5, -0.2, 0.2])\n",
"\n",
"plt.subplot(122)\n",
"plt.grid(True, which='both')\n",
"plt.axhline(y=0, color='k')\n",
"plt.axvline(x=0, color='k')\n",
"plt.plot(X2D[:, 0][y==0], X2D[:, 1][y==0], \"bs\")\n",
"plt.plot(X2D[:, 0][y==1], X2D[:, 1][y==1], \"g^\")\n",
"plt.xlabel(\"$x_1$\")\n",
"plt.ylabel(\"$x_2$  \", rotation=0)\n",
"plt.gca().get_yaxis().set_ticks([0, 4, 8, 12, 16])\n",
"plt.plot([-4.5, 4.5], [6.5, 6.5], \"r--\", linewidth=3)\n",
"plt.axis([-4.5, 4.5, -1, 17])\n",
"\n",
"plt.subplots_adjust(right=1)\n",
"\n",
"save_fig(\"higher_dimensions_plot\", tight_layout=False)\n",
"plt.show()"
]
},
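{
"cell_type": "markdown",
"metadata": {},
"source": [
"The figure above suggests that adding the feature $x_2 = (x_1)^2$ makes the 1D dataset linearly separable. Here is a minimal sketch that confirms this numerically. It uses fresh `_demo` variable names (an assumption made here so that the notebook's own `X1D`, `X2D`, and `y` are left untouched): thresholding the squared feature at 6.5, the height of the dashed red line in the right-hand plot, should reproduce every label."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"x1_demo = np.linspace(-4, 4, 9)\n",
"labels_demo = np.array([0, 0, 1, 1, 1, 1, 1, 0, 0])\n",
"\n",
"x2_demo = x1_demo ** 2                          # the added second feature\n",
"y_thresh_demo = (x2_demo < 6.5).astype(int)     # one horizontal threshold in the lifted space\n",
"print((y_thresh_demo == labels_demo).all())     # True\n"
]
},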
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(steps=[('polynomialfeatures', PolynomialFeatures(degree=3)),\n",
" ('standardscaler', StandardScaler()),\n",
" ('linearsvc',\n",
" LinearSVC(C=10, max_iter=10000, random_state=42))])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.datasets import make_moons\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"\n",
"X, y = make_moons(n_samples=100, noise=0.15, random_state=42)\n",
"\n",
"polynomial_svm_clf = make_pipeline(\n",
"    PolynomialFeatures(degree=3),\n",
"    StandardScaler(),\n",
"    LinearSVC(C=10, max_iter=10_000, random_state=42)\n",
")\n",
"polynomial_svm_clf.fit(X, y)"
]
},
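{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what the `PolynomialFeatures(degree=3)` step of the pipeline above feeds to the linear SVM, here is a small sketch. It assumes Scikit-Learn 1.0 or later for `get_feature_names_out()`, and the single sample `[[2.0, 3.0]]` is an arbitrary value picked for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"\n",
"poly_demo = PolynomialFeatures(degree=3)\n",
"X_poly_demo = poly_demo.fit_transform(np.array([[2.0, 3.0]]))\n",
"\n",
"# Two input features at degree 3 expand to 10 columns (bias term included)\n",
"print(poly_demo.get_feature_names_out())\n",
"print(X_poly_demo.shape)   # (1, 10)\n"
]
},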
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 5-6\n",
"\n",
"def plot_dataset(X, y, axes):\n",
"    plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"bs\")\n",
"    plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"g^\")\n",
"    plt.axis(axes)\n",
"    plt.grid(True, which='both')\n",
"    plt.xlabel(\"$x_1$\")\n",
"    plt.ylabel(\"$x_2$\", rotation=0)\n",
"\n",
"def plot_predictions(clf, axes):\n",
"    x0s = np.linspace(axes[0], axes[1], 100)\n",
"    x1s = np.linspace(axes[2], axes[3], 100)\n",
"    x0, x1 = np.meshgrid(x0s, x1s)\n",
"    X = np.c_[x0.ravel(), x1.ravel()]\n",
"    y_pred = clf.predict(X).reshape(x0.shape)\n",
"    y_decision = clf.decision_function(X).reshape(x0.shape)\n",
"    plt.contourf(x0, x1, y_pred, cmap=plt.cm.brg, alpha=0.2)\n",
"    plt.contourf(x0, x1, y_decision, cmap=plt.cm.brg, alpha=0.1)\n",
"\n",
"plot_predictions(polynomial_svm_clf, [-1.5, 2.5, -1, 1.5])\n",
"plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])\n",
"\n",
"save_fig(\"moons_polynomial_svc_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Polynomial Kernel"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(steps=[('standardscaler', StandardScaler()),\n",
" ('svc', SVC(C=5, coef0=1, kernel='poly'))])"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.svm import SVC\n",
"\n",
"poly_kernel_svm_clf = make_pipeline(StandardScaler(),\n",
"                                    SVC(kernel=\"poly\", degree=3, coef0=1, C=5))\n",
"poly_kernel_svm_clf.fit(X, y)"
]
},
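{
"cell_type": "markdown",
"metadata": {},
"source": [
"The kernelized `SVC` above never materializes polynomial features. It only evaluates the polynomial kernel $K(\\mathbf{a}, \\mathbf{b}) = (\\gamma\\, \\mathbf{a}^\\top \\mathbf{b} + r)^d$, where $d$ is `degree` and $r$ is `coef0`. The sketch below compares a hand-written version with `sklearn.metrics.pairwise.polynomial_kernel`. The two vectors and the value `gamma=0.5` are arbitrary choices for illustration; the `SVC` above keeps its default `gamma=\"scale\"` instead."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics.pairwise import polynomial_kernel\n",
"\n",
"u_demo = np.array([[1.0, 2.0]])\n",
"v_demo = np.array([[3.0, -1.0]])\n",
"gamma_demo, coef0_demo, degree_demo = 0.5, 1, 3\n",
"\n",
"# Hand-written kernel: (gamma * <u, v> + coef0) ** degree\n",
"manual_demo = (gamma_demo * (u_demo @ v_demo.T) + coef0_demo) ** degree_demo\n",
"library_demo = polynomial_kernel(u_demo, v_demo, degree=degree_demo,\n",
"                                 gamma=gamma_demo, coef0=coef0_demo)\n",
"print(np.allclose(manual_demo, library_demo))   # True\n"
]
},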
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 756x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 5-7\n",
"\n",
"poly100_kernel_svm_clf = make_pipeline(\n",
"    StandardScaler(),\n",
"    SVC(kernel=\"poly\", degree=10, coef0=100, C=5)\n",
")\n",
"poly100_kernel_svm_clf.fit(X, y)\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(10.5, 4), sharey=True)\n",
"\n",
"plt.sca(axes[0])\n",
"plot_predictions(poly_kernel_svm_clf, [-1.5, 2.45, -1, 1.5])\n",
"plot_dataset(X, y, [-1.5, 2.4, -1, 1.5])\n",
"plt.title(\"degree=3, coef0=1, C=5\")\n",
"\n",
"plt.sca(axes[1])\n",
"plot_predictions(poly100_kernel_svm_clf, [-1.5, 2.45, -1, 1.5])\n",
"plot_dataset(X, y, [-1.5, 2.4, -1, 1.5])\n",
"plt.title(\"degree=10, coef0=100, C=5\")\n",
"plt.ylabel(\"\")\n",
"\n",
"save_fig(\"moons_kernelized_polynomial_svc_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Similarity Features"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 756x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 5-8\n",
"\n",
"def gaussian_rbf(x, landmark, gamma):\n",
"    return np.exp(-gamma * np.linalg.norm(x - landmark, axis=1)**2)\n",
"\n",
"gamma = 0.3\n",
"\n",
"x1s = np.linspace(-4.5, 4.5, 200).reshape(-1, 1)\n",
"x2s = gaussian_rbf(x1s, -2, gamma)\n",
"x3s = gaussian_rbf(x1s, 1, gamma)\n",
"\n",
"XK = np.c_[gaussian_rbf(X1D, -2, gamma), gaussian_rbf(X1D, 1, gamma)]\n",
"yk = np.array([0, 0, 1, 1, 1, 1, 1, 0, 0])\n",
"\n",
"plt.figure(figsize=(10.5, 4))\n",
"\n",
"plt.subplot(121)\n",
"plt.grid(True, which='both')\n",
"plt.axhline(y=0, color='k')\n",
"plt.scatter(x=[-2, 1], y=[0, 0], s=150, alpha=0.5, c=\"red\")\n",
"plt.plot(X1D[:, 0][yk==0], np.zeros(4), \"bs\")\n",
"plt.plot(X1D[:, 0][yk==1], np.zeros(5), \"g^\")\n",
"plt.plot(x1s, x2s, \"g--\")\n",
"plt.plot(x1s, x3s, \"b:\")\n",
"plt.gca().get_yaxis().set_ticks([0, 0.25, 0.5, 0.75, 1])\n",
"plt.xlabel(\"$x_1$\")\n",
"plt.ylabel(\"Similarity\")\n",
"plt.annotate(\n",
"    r'$\\mathbf{x}$',\n",
"    xy=(X1D[3, 0], 0),\n",
"    xytext=(-0.5, 0.20),\n",
"    ha=\"center\",\n",
"    arrowprops=dict(facecolor='black', shrink=0.1),\n",
"    fontsize=16,\n",
")\n",
"plt.text(-2, 0.9, \"$x_2$\", ha=\"center\", fontsize=15)\n",
"plt.text(1, 0.9, \"$x_3$\", ha=\"center\", fontsize=15)\n",
"plt.axis([-4.5, 4.5, -0.1, 1.1])\n",
"\n",
"plt.subplot(122)\n",
"plt.grid(True, which='both')\n",
"plt.axhline(y=0, color='k')\n",
"plt.axvline(x=0, color='k')\n",
"plt.plot(XK[:, 0][yk==0], XK[:, 1][yk==0], \"bs\")\n",
"plt.plot(XK[:, 0][yk==1], XK[:, 1][yk==1], \"g^\")\n",
"plt.xlabel(\"$x_2$\")\n",
"plt.ylabel(\"$x_3$  \", rotation=0)\n",
"plt.annotate(\n",
"    r'$\\phi\\left(\\mathbf{x}\\right)$',\n",
"    xy=(XK[3, 0], XK[3, 1]),\n",
"    xytext=(0.65, 0.50),\n",
"    ha=\"center\",\n",
"    arrowprops=dict(facecolor='black', shrink=0.1),\n",
"    fontsize=16,\n",
")\n",
"plt.plot([-0.1, 1.1], [0.57, -0.1], \"r--\", linewidth=3)\n",
"plt.axis([-0.1, 1.1, -0.1, 1.1])\n",
" \n",
"plt.subplots_adjust(right=1)\n",
"\n",
"save_fig(\"kernel_method_plot\")\n",
"plt.show()"
]
},
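{
"cell_type": "markdown",
"metadata": {},
"source": [
"The similarity features plotted above come from the Gaussian RBF $\\phi_\\gamma(\\mathbf{x}, \\ell) = \\exp(-\\gamma \\lVert \\mathbf{x} - \\ell \\rVert^2)$. As a quick sketch, the cell below recomputes the two features of the annotated instance $x_1 = -1$ for the landmarks at $-2$ and $1$ with $\\gamma = 0.3$, and cross-checks them against `sklearn.metrics.pairwise.rbf_kernel`. The `_demo` names are introduced here for illustration and are not part of the original notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics.pairwise import rbf_kernel\n",
"\n",
"x_demo = np.array([[-1.0]])\n",
"landmarks_demo = np.array([[-2.0], [1.0]])\n",
"rbf_gamma_demo = 0.3\n",
"\n",
"# Squared distances to the two landmarks are 1 and 4\n",
"manual_rbf_demo = np.exp(-rbf_gamma_demo * (x_demo - landmarks_demo.T) ** 2)\n",
"library_rbf_demo = rbf_kernel(x_demo, landmarks_demo, gamma=rbf_gamma_demo)\n",
"print(manual_rbf_demo.round(2))                          # roughly 0.74 and 0.30\n",
"print(np.allclose(manual_rbf_demo, library_rbf_demo))    # True\n"
]
},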
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Gaussian RBF Kernel"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(steps=[('standardscaler', StandardScaler()),\n",
" ('svc', SVC(C=0.001, gamma=5))])"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rbf_kernel_svm_clf = make_pipeline(StandardScaler(),\n",
"                                   SVC(kernel=\"rbf\", gamma=5, C=0.001))\n",
"rbf_kernel_svm_clf.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 756x504 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 5-9\n",
"\n",
"from sklearn.svm import SVC\n",
"\n",
"gamma1, gamma2 = 0.1, 5\n",
"C1, C2 = 0.001, 1000\n",
"hyperparams = (gamma1, C1), (gamma1, C2), (gamma2, C1), (gamma2, C2)\n",
"\n",
"svm_clfs = []\n",
"for gamma, C in hyperparams:\n",
"    rbf_kernel_svm_clf = make_pipeline(\n",
"        StandardScaler(),\n",
"        SVC(kernel=\"rbf\", gamma=gamma, C=C)\n",
"    )\n",
"    rbf_kernel_svm_clf.fit(X, y)\n",
"    svm_clfs.append(rbf_kernel_svm_clf)\n",
"\n",
"fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10.5, 7), sharex=True, sharey=True)\n",
"\n",
"for i, svm_clf in enumerate(svm_clfs):\n",
"    plt.sca(axes[i // 2, i % 2])\n",
"    plot_predictions(svm_clf, [-1.5, 2.45, -1, 1.5])\n",
"    plot_dataset(X, y, [-1.5, 2.45, -1, 1.5])\n",
"    gamma, C = hyperparams[i]\n",
"    plt.title(f\"gamma={gamma}, C={C}\")\n",
"    if i in (0, 1):\n",
"        plt.xlabel(\"\")\n",
"    if i in (1, 3):\n",
"        plt.ylabel(\"\")\n",
"\n",
"save_fig(\"moons_rbf_svc_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SVM Regression"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(steps=[('standardscaler', StandardScaler()),\n",
" ('linearsvr', LinearSVR(epsilon=0.5, random_state=42))])"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.svm import LinearSVR\n",
"\n",
"# extra code: these 3 lines generate a simple linear dataset\n",
"np.random.seed(42)\n",
"X = 2 * np.random.rand(50, 1)\n",
"y = 4 + 3 * X[:, 0] + np.random.randn(50)\n",
"\n",
"svm_reg = make_pipeline(StandardScaler(),\n",
"                        LinearSVR(epsilon=0.5, random_state=42))\n",
"svm_reg.fit(X, y)"
]
},
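{
"cell_type": "markdown",
"metadata": {},
"source": [
"`LinearSVR` tries to fit as many instances as possible inside a tube of half-width `epsilon` around its predictions. With its default `loss=\"epsilon_insensitive\"`, each instance contributes $\\max(0, |y - \\hat{y}| - \\varepsilon)$ to the data term. The sketch below evaluates that expression on made-up targets and predictions (hypothetical values, chosen only to show which points fall inside the tube); the helper `epsilon_insensitive_loss` is written here for illustration and is not a Scikit-Learn function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def epsilon_insensitive_loss(y_true, y_pred, epsilon):\n",
"    # Zero inside the tube, linear in the excess error outside it\n",
"    return np.maximum(0, np.abs(y_true - y_pred) - epsilon)\n",
"\n",
"y_true_demo = np.array([4.0, 5.0, 7.5, 9.0])\n",
"y_hat_demo = np.array([4.3, 5.5, 6.5, 9.0])\n",
"losses_demo = epsilon_insensitive_loss(y_true_demo, y_hat_demo, epsilon=0.5)\n",
"print(losses_demo)   # only the third point (absolute error 1.0) is penalized, by 0.5\n"
]
},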
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAnsAAAEQCAYAAADI77KTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAACRxUlEQVR4nOydeXiMVxfAf28mC7EXQYgICUKorZTYa6ut1SpVWvtOqX0vkti3opSqrbVUi1a1vqqitrZKo43YYgmRkEgQicgyM/f7Y8w0eyaZyWQS9/c88yTzLveemcycnHvuWRQhBBKJRCKRSCSSgolNXgsgkUgkEolEIsk9pLEnkUgkEolEUoCRxp5EIpFIJBJJAUYaexKJRCKRSCQFGGnsSSQSiUQikRRgpLEnkUgkEolEUoCRxp7E4iiKclxRlLUZPZdIJBJrR+oxSX5CGnuSvOAtYHpeC5EaRVFGKYpyS1GUeEVRziuK0iKL66soiiLSeXSylMwSiSTPsDo9pihKS0VRDiiKEvpcFw0w4p7WiqJ8ryjKPUVR4hRF+VdRlEEWEFdiQaSxJ7E4QoiHQoiYvJYjOYqi9AY+ARYA9YEzwCFFUSobcXsnoEKyx9HcklMikVgH1qjHgKLARWAc8MzIe5oBAUBPwAtYD2xUFOW9XJFQkidIY0+SBkXHFEVRbiiK8kxRlABFUfo9P6f3Zr2nKMqp516wK4qidEh2v52iKKsVRQlTFCVBUZQQRVEWJTuf6XaHoiilFEXZpijKo+fzH1EUpXay8wMURYlVFOU1RVEuKoryVFGUY4qiuJnwsicAW4UQnwshLgshxgL3gJFG3BslhLif7JFoghwSicQMvIh6TAjxkxBihhDiW0Br5D0LhBCzhBCnhRA3hRDrgX3A2zmVQ2J9SGNPkh6+wGBgNFALWAhsUBSlS7JrlgCrgXrAL8D3iqJUfH7uQ6AH8C7gAfQGrmZj/q1AE+ANoDEQB/xPUZTCya5xQLeFMghoCpQEPtOfVBSlxXNFmtljxvNr7YGGwOFUchxGt+rNin2KokQoinJaUZSe2XidEokk93ih9JiZKQ48yoVxJXmEbV4LILEuFEUpgs7L1UEIcfL54VuKojRGpzRHPT+2Xgix5/k944CO6LxgswBX4BpwUuiaL99Bty1qzPweQHeglRDixPNj7z8foy+w6fmltsBoIcTV59csA7YoimIjhNAC59Ap8Mx4+PxnGUAFhKc6Hw60y+T+WGAScBpQP5f7a0VR+gshvspibolEkku8oHrMLCiK0hV4DfA257iSvEUae5LU1AIKoVuBimTH7YDgZM9/1/8ihNAqivLn83tBt6L9BbimKMph4Cfg0HPllRWe6LYfko8frShKQLLxARL0CvI5Yc9lLAk8FEI8A64bMV9yRKrnSjrH/rtYiEhgebJD5xRFKQNMAaSxJ5HkHS+yHssxiqJ4AzuBD4UQZy01ryT3kcaeJDX6rf1u6FahyUlCZwBlihDib0VRqqBLXGgLbAP+URSlvRGKMrPxkyttdQbnbEC3/QEcymKuBUKIBUAkoAHKpzrvRFpvX1b8CQzM5j0SicS8vIh6zCQURWmOzqCd8zxuT1KAkMaeJDWXgATAVQiRJqv0ufIDeJXnWaeKoijoYlK+1V/3PEvtG+AbRVG2An8A7ui2RbKa3wZd/Ip++6M4UAfYko3XYfT2hxAiUVGU80D75zLraQ/szcacPJ/zXjbvkUgk5uWF02OmoChKS+BHYK4QYpWp40msD2nsSVIghIh5Hjey7LnyO4Eunf9VdNsS+iSGkYqiXEOXsj8KXXzLegBFUSagM3guoFtFvwc8Ae4aMX+QoijfowukHgY8Bvye378zG68ju9sfK4AvFUU5iy4GbwTgTMpg6YVAYyHEa8+f90f3+vzRvTfd0MUDTc3GvBKJxMy8qHpMUZSi6IxR0BmblRVFqYduS/jO82tS67HW6Ay9dcAORVH0OxwaIcQDY+
eWWDfS2JOkx2x025eT0Cm+J+gU3pJk10xDFwDdALgN9BBC6JVgDDAZXQabQGcMvS6EiDNy/oHAKuAAurib00Cn54ovVxBCfK0oSml0gdkV0NWq6iyEuJ3ssgpAtVS36gO5NehW+4NkcoZEYhW8cHoMaAQcS/Z83vPHNmDA82Op9dgAwBHd+zQp2fHbQJXcEVNiaRRdkpFEYhzPtz9uAa8IIc7lsTgSiUSSbaQek7xoyDp7EolEIpFIJAUYsxl7iqJsfl5Y9mKyY+8oihKoKIpWUZRG5ppLIpFIJBKJRGIcZtvGfZ7NEwtsF0J4PT+mrzW0AZgk3eUSiUQikUgklsVsCRpCiBPJ0tn1xy4D6JKhJBKJRCKRSCSWRsbsSSQSiUQikRRgrKL0yvM6RMMAChUq1LBy5cp5LJEOrVaLjU3e28NSjrRYiyxSjrSYSxYhBMnDTBRFydYuwbVr1yKFEGVNFsRIpB6TcmQXa5FFypGWAqfH9IKY44GuJs/FdI4fBxoZM0b16tWFtXDs2LG8FkEIIeVID2uRRcqRFmuRBTgnzKjfsvOQeiwtUo60WIssUo60WIss5tJj1mFCSyQSiUQikUhyBXOWXtkF/A7UUBTlrqIogxVF6aEoyl10/QF/VBTlZ3PNJ5FIJBKJRCLJGnNm4/bJ4NR+c80hkUgkEolEIskechtXIpFIJBKJpAAjjT2JRCKRSCSSAoxVlF7JDk+ePCEiIoKkpKRcn6tEiRJcvnw51+fJCDs7O5ycnPJsfolEkju8SHqsSJEiVKpUKc/ml0gk+czYe/LkCeHh4VSsWJHChQvnemeOmJgYihUrlqtzZIQQgmfPnhEaGopKpcoTGSQSifl5kfSYVqslNDSUyMjIPJlfIpHoyFfbuBEREVSsWBFHR8cC34JNURQcHR2pWLEiRYoUyWtxJBKJmXiR9JiNjQ3lypUjOjo6r0WRSF5o8pWxl5SUROHChfNaDItiiZW/RCKxHC+aHrOzs0OtVue1GBLJC02+MvaAF87wedFer0TyIvAifa9fpNcqkVgr+SpmTyKRWA9qtZqQkBDCwsJITEzE3t4eZ2fnvBZLIpFIjOZF0WPS2MsDHjx4wKeffsro0aMpW9ZifdolErMghODixYsEBQUBoNFoDOfCw8OxsbEhICAALy8v6dUpwEg9JsnPvGh6LN9t4xYERo4cyblz5xg9enReiyKRZAshBKdPnyYoKAiNRpNCQYJOYQohCAoK4vTp0+j6eEsKIlKPSfIrL6Iek8aehdm5cycODg4cPHgQOzs79uzZk9ciSSRGc/HiRSIiItIox9RoNBoiIiK4ePGihSSTWBKpxyT5mfygx37//XemTp1qtvHkNq6Fee+993jvvfcA2LFjRx5LI5EYj1qtNqyE9Zw86cKuXXWIinKkdOk4+vQJoE2be4BOUQYFBeHp6YmtrVQ1BQmpxyT5FWvXY5cvX2bMmDEcPXqU0qVLm21cqYElkgJORgHILi4u2VJeISEhKZ6fPOnChg2NSEzUjREZWYQNGxqhUvnTrFlwivvc3NzM8lqM4eLFiyxfvtxi80kkktynIOsxIQTR0dGULFmSokWLcv36dZYtW8aIESMoWrSoWeaQxp5EUkDJKgDZ398fDw8PowOQw8LCUoyxa1cdg4LUk5hoy44dtQ1KUqPREBYWZhFj7969e8yZM4fNmzfnWccIiURiXgqyHhNC8MMPP+Dr60vx4sU5cuQILi4u3Lx50+yds2TMnoWoVKkSK1asSHEsICCAQoUKcenSpTySSlJQMSYAWb89YWwAcmJiYornUVGO6V4XGZmyYLAl+r+CbuW9fft2xo0bx40bNywy54uG1GMSS1JQ9ZhWq+Wbb76hfv36vPHGG0RGRtK7d2+D/LnRIlUaexaiadOm/PXXXymOjR8/niFDhlCrVq08kkpSUMmNAGR7e/sUz0uXjkv3ujJlnqV4bmdnl+XYOUGtVrNhwwamT58OQOPGjQkJCWHFih
VmjXWR/IfUYxJLUlD12IYNG+jVqxfx8fFs27aNa9euMXTo0Fwt8SKNPQuRWkl+9913+Pv7M2/evDyUSlIQySgAedS
"text/plain": [
"<Figure size 648x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code this cell generates and saves Figure 510\n",
2016-09-27 23:31:21 +02:00
"\n",
"def find_support_vectors(svm_reg, X, y):\n",
" y_pred = svm_reg.predict(X)\n",
" epsilon = svm_reg[-1].epsilon\n",
" off_margin = np.abs(y - y_pred) >= epsilon\n",
2016-09-27 23:31:21 +02:00
" return np.argwhere(off_margin)\n",
"\n",
"def plot_svm_regression(svm_reg, X, y, axes):\n",
" x1s = np.linspace(axes[0], axes[1], 100).reshape(100, 1)\n",
" y_pred = svm_reg.predict(x1s)\n",
" epsilon = svm_reg[-1].epsilon\n",
" plt.plot(x1s, y_pred, \"k-\", linewidth=2, label=r\"$\\hat{y}$\", zorder=-2)\n",
" plt.plot(x1s, y_pred + epsilon, \"k--\", zorder=-2)\n",
" plt.plot(x1s, y_pred - epsilon, \"k--\", zorder=-2)\n",
" plt.scatter(X[svm_reg._support], y[svm_reg._support], s=180,\n",
" facecolors='#AAA', zorder=-1)\n",
2016-09-27 23:31:21 +02:00
" plt.plot(X, y, \"bo\")\n",
" plt.xlabel(\"$x_1$\")\n",
" plt.legend(loc=\"upper left\")\n",
2016-09-27 23:31:21 +02:00
" plt.axis(axes)\n",
"\n",
"svm_reg2 = make_pipeline(StandardScaler(),\n",
" LinearSVR(epsilon=1.2, random_state=42))\n",
"svm_reg2.fit(X, y)\n",
"\n",
"svm_reg._support = find_support_vectors(svm_reg, X, y)\n",
"svm_reg2._support = find_support_vectors(svm_reg2, X, y)\n",
"\n",
"eps_x1 = 1\n",
"eps_y_pred = svm_reg2.predict([[eps_x1]])\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(9, 4), sharey=True)\n",
"plt.sca(axes[0])\n",
"plot_svm_regression(svm_reg, X, y, [0, 2, 3, 11])\n",
"plt.title(f\"epsilon={svm_reg[-1].epsilon}\")\n",
"plt.ylabel(\"$y$\", rotation=0)\n",
"plt.grid()\n",
"plt.sca(axes[1])\n",
"plot_svm_regression(svm_reg2, X, y, [0, 2, 3, 11])\n",
"plt.title(f\"epsilon={svm_reg2[-1].epsilon}\")\n",
2016-09-27 23:31:21 +02:00
"plt.annotate(\n",
" '', xy=(eps_x1, eps_y_pred), xycoords='data',\n",
" xytext=(eps_x1, eps_y_pred - svm_reg2[-1].epsilon),\n",
2016-09-27 23:31:21 +02:00
" textcoords='data', arrowprops={'arrowstyle': '<->', 'linewidth': 1.5}\n",
" )\n",
"plt.text(0.90, 5.4, r\"$\\epsilon$\", fontsize=16)\n",
"plt.grid()\n",
2016-09-27 23:31:21 +02:00
"save_fig(\"svm_regression_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(steps=[('standardscaler', StandardScaler()),\n",
" ('svr', SVR(C=0.01, degree=2, kernel='poly'))])"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.svm import SVR\n",
"\n",
"# extra code these 3 lines generate a simple quadratic dataset\n",
"np.random.seed(42)\n",
"X = 2 * np.random.rand(50, 1) - 1\n",
"y = 0.2 + 0.1 * X[:, 0] + 0.5 * X[:, 0] ** 2 + np.random.randn(50) / 10\n",
"\n",
"svm_poly_reg = make_pipeline(StandardScaler(),\n",
" SVR(kernel=\"poly\", degree=2, C=0.01, epsilon=0.1))\n",
2017-06-01 09:23:37 +02:00
"svm_poly_reg.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAnsAAAEQCAYAAADI77KTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAACbb0lEQVR4nOydZ3hURReA30kv9IQAoQuhV0GqdJCOgBTpqEhTUBGVJkUElPYJKFIkgPReBAREAkIApUQIPUCogYSEEkL67nw/NlnTs5vsJpsw7/Psk9y7c2fObWfPzJw5R0gpUSgUCoVCoVDkTqyyWwCFQqFQKBQKhflQxp5CoVAoFApFLkYZewqFQqFQKBS5GGXsKRQKhUKhUORilLGnUCgUCoVCkYtRxp5CoVAoFApFLkYZeykghNgjhFiV3XIoXk2EEIOFEGGpbSsUxqD0mSIrEULcFkKMTW1bkT0oY+8VQwhhK4T4XghxQQjxUgjxUAixXghRKoP11RJCbBJCPBJCRAohbgghVgkhqmdCxmZCiLNx9d0SQgw34JiCQog1QojncZ81QogCScosEEKciav3dkblywI2Aa9ltxBJEUJUF0IcFUJECCEeCCEmCyFEOscMFUJ4CSGeCSGkEKJMFomreAWwdH0mhCgWJ89VIYQmNaNbCPGOEOKyECIq7m+3FMqMFEL4x8l1VgjRJCMyZQFvAIuzW4iECB1ThRABcfrriBCiajrHGHTvcgrK2MsmhBBWQgjrbGjaCXgdmBH3922gJLBfCGFjTEVCiE7A30AeYABQGXgXeAh8lxHhhBBlgX3ACaA2MAtYJIR4J51D16M7n/ZAu7j/1yQpYwWsBn7NiGxZhZQyQkoZlN1yJEQIkQ/4AwhEp8xHA18AY9I51Ak4CEw1p3yK7EXps1SxB4Ljjv87lXYbouvgrQNqxf3dIoSon6BMb2ABMBOdXjwB/J5Ro9acSCkfSynDs1uOJHwJfA6MQqe/goA/hBB50zgm3XuXo5BSvtIfdMpiFRCG7odsArAHWJWgjB3wPXAfeAmcBtomqacjcA2IBP5CpyQkUCbu+8FxbXQALgKxQDUD664C7AVeoHtINwBFTXgNqsTJWt3I6/YY2J3K9wUyKMv3gF+Sfb8AJ9M4pnKc/I0T7Hszbl/FFMqPBW5n4no1Ao4C4cAD4GcgX4LvjwBL0Cnnp3GfOYBVgjLdgQtABPAkrr4iCZ+VBGUTbcftGwbcAKLj/n6Y5HsJDAW2xD1Xt4D+mTjnEUAo4Jhg36S48xcGHF834fugPub5KH1mWfosSR2J7kOC/ZuAP5LsOwRsSLD9N7A8SRk/YJaRMqR5f4DmcdeuE/Bv3P0/C9RJUCY/uo50UNz3t4BPE3x/GxibxnYpYEfc/X8BbAdKJPh+atwz9S5wM67MTsA1g9ddoDPYJybY5xhX77DM3Luc9FEjezAXaAO8A7RC12tqmqTMSqAZ0Beojm506DchRE2AuN7VdnQKrCawEJidQlsO6H4gh6FTSHcMqLsYOmV7EagHtEbX89wthLCKK9NPCBGWzqdfGtcgX9zfp+lerf9oC7iSSo9XSvks/n8DZPs9waEN0Y0EJeQAUFcIYZuKLA3R/fCcSLDPG50ya2TEOaVL3HTOQWA3unvdHV1v3DNJ0X7oRhIborvfQ4FP4+ooCmxEd68ro3veko5CpiVDN+BH4Ad0P7ALgMVCiM5Jik4GdsXJuQnwFEKUTlDPpXTuy6UEdTUEjkkpIxLsOwC4A2UMlV1hdpQ+syx9Zgip6bxGce3ZAXVSKHMQ4/VbmvcnAXOBr9B10m4Be4UQTnHffRt3bCegEvA+uk5fugghBDrDrQjQEmiBTofsjPsunjJAb6Ab8Ba653hGgnqMeUbKAkVJcP3i9NhfmPj3waLJbmszOz/olEwU0C/JvmfEWfFAOUALlEpy7E5gcdz/s4ArJBjhQNejTtoTliTuIRlS9zfAn0m+LxhXV7247b
xA+XQ+eVO5BnboDKMUe7RpXLsv42QoaEDZ9GQrnqDsdWBykuObxrVVLJX6JwC3Uth/Cxifwv4Mj+yhmwJekWRfrTj53OK2j8SdR8LnYRJwP+7/1+PKl06ljcGkMbIXd788kxyzCjieYFuSoNcP2KAbieyfYF/pdO5L6QRlD6bQZqm4dhoacN3UyJ6ZPyh9Bhamz5Icl9rIXjQwMMm+gUBU3P/ucbI1TVJmMnDNiHM05P40j2srpWdoSNz2bmBlGu3cJpWRPXQdEU1CPYDOP1kLtI7bnopuxDB/gjITgRsJtg1+RtAZdDKF8/YEDhh47XL8yJ5RPg25kHLolMPJ+B1SyjAhhG+CMq+jGwa+nLjjgT1wOO7/SsBpGfdUxJHSHH8suqFxY+quAzQVKa/GLAf8I6WMHw43ijiflrVAAaCLsYcbWlBKecPIumWSbZHK/rSOiT8urWMyQh2gfJwPTcJ2QHc/4n3tTiV5Hk4C0+N8386jm6a5KIQ4GPf/VinlYwNlqEzykcTjJL+HF+L/kVLGCiEeA24J9t0xsD39IUm2DbkviqxD6TPL1GcGVZtkOyXdZUiZtDDk/sST0jNUJW7Xz8BWIcTr6Px4f5NSHjVQhspAgJTydoL6bwkhAuLqPxS3+46U8nmC4wJIrLsy8oxk9vrlaF51Y8+QF9wK3QPxBhCT5Lv4KS1DH5ooKaXGyLqt0E2npLR0PRB0Q9rA0nTaHialXBe/EacYN6Abjm8upQwxQP6EXI/7W5nE06fJSEWxJ+SYlLJ93P+P0A25J8QN3Q9LajI+AtyEECL+BypuSqAwcdfIhFih8yH8XwrfGTSVIaXUCCHeAhqgm6L4AJglhGgmpTxvoBwpPW9J9yV9piQJFmXFTdOWJnXuSCnjV6yldl/A9NdYkTGUPrM8fWYIqb1b8e9VMLrRsLTKGIIh9yddpJS/x7mDtEfnKrBXCLFFSvmeAYen9Wwl3J+e7jLmGXkUt10UuJfge2OvX47mVTf2bqB7qBqgm/JDCOGMzg/qZlwZH3QPaFEppVcq9VxBtwosIfUMaN+Qus8BvdD98CZ9AeLZTfqrhfQPdZzv20Z059lcSvko1aNS5yA6JTSOFHrRQogC8j8/l1rp1JVQ0ZwEuib5vg1wJo3zP4luqqEh/ynqhoAz6SjuDHAOqGpA775+QuMT3TMWIKUMBYjbfxI4KYT4BriEzkfFEGPvCroFKAlH994ELht+GoDOuT41P0hIrHBPAt8LIRyklJFx+9qg63HfNrJdhXlQ+szy9JkhnET3Ls1JsK8NcbpLShkthDgbt29LkjLbjGjHkPsTT0rPkD6KgZQyGJ2f8Zo4H8UNQojhUsqodOq9DBQXQpSJH90TQryGbqraGP1lzDPij87ga4NuQQpCCAegCbqIAq8G2T2PnN0fdEPS99A9CFXRObKHknj12lp0zsc90PkX1EXXM+0e931pdL4yc4GK6Jz2b5PAL4sUVlQaWLc7uqnB7UD9uDKtgWWk4reSzvnaoPPReIBuWL9ogo+jkXW9jc7fZG/c9SsTV+d0YG8G70dZdAsrfkDXyx4S18Y7Ccp0A66S2Nfvd8AXnZJqGPf/b0nqLo9OUc9HZ6TUivvYGSFfDXS+b0vQOQ2XR+eovDRBmSPophgWxD0PPdD5vMT7rTRA58P3Bjq/t7fjyvdP6VlJYbsruh/1jwAPdOEEYoDOCcpIoEcS2W+TwJfGyPuSH53CjP9R7Y7uPfk8QZl6cfelXoJ9ReOucd84mTrEbRfK7nc/N35Q+syi9FlcvbXiPn+hM1JqAVUSfN8I3czFeHRT6OPj3uf6Ccr0jpNtCDq9uADdorTSRsqS3v1pHnefLyd5hgIB57gy36DTQR5xsmwisT/dbVL32RPoDH5vdFP6ddEZu2eI8xElbjVuErlTfN6MOO+v0L0H3dHpr43ofgPyJijzK/CrMfcuJ32yXYDs/qAb/fk17sUJAr4meagC27gH8FbcC/co7sYndE
7uhG4qIBI4BrwX99KkGE7DyLo9gK3oVpdFoAuJsAgjjJQEdZWJkyulz+AE5VZhwCKGuBd2S5wyiEI3grAK3ehXRu9
"text/plain": [
"<Figure size 648x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code this cell generates and saves Figure 511\n",
"\n",
"svm_poly_reg2 = make_pipeline(StandardScaler(),\n",
" SVR(kernel=\"poly\", degree=2, C=100))\n",
"svm_poly_reg2.fit(X, y)\n",
"\n",
"svm_poly_reg._support = find_support_vectors(svm_poly_reg, X, y)\n",
"svm_poly_reg2._support = find_support_vectors(svm_poly_reg2, X, y)\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(9, 4), sharey=True)\n",
"plt.sca(axes[0])\n",
"plot_svm_regression(svm_poly_reg, X, y, [-1, 1, 0, 1])\n",
"plt.title(f\"degree={svm_poly_reg[-1].degree}, \"\n",
" f\"C={svm_poly_reg[-1].C}, \"\n",
" f\"epsilon={svm_poly_reg[-1].epsilon}\")\n",
"plt.ylabel(\"$y$\", rotation=0)\n",
"plt.grid()\n",
"\n",
"plt.sca(axes[1])\n",
2016-09-27 23:31:21 +02:00
"plot_svm_regression(svm_poly_reg2, X, y, [-1, 1, 0, 1])\n",
"plt.title(f\"degree={svm_poly_reg2[-1].degree}, \"\n",
" f\"C={svm_poly_reg2[-1].C}, \"\n",
" f\"epsilon={svm_poly_reg2[-1].epsilon}\")\n",
"plt.grid()\n",
2016-09-27 23:31:21 +02:00
"save_fig(\"svm_with_polynomial_kernel_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Under the hood"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAm4AAADWCAYAAABorg4iAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAA3iUlEQVR4nO3deXwURfrH8U8l5IYEAqKSVVl+sIiwnijiBYgHhyK3CEEuwVuUVVfFC1kU3fUGRBRcNSEcAiLrFdGAqCAiKooIyrIoIAgmXLmP+v1RgRBASCCZnk6+79drXtKVme4n40zl6aerq4y1FhEREREJfiFeByAiIiIiZaPETURERMQnlLiJiIiI+IQSNxERERGfUOImIiIi4hNK3ERERER8QombiIiIiE8ocRMRERHxCSVuUi0ZYy4yxrxljNlojLHGmIFexyQi1Zsx5iZjzDpjTI4x5ktjzIWHef7Dxf3Xvo/NgYpXvKHETaqrmsB3wHAg2+NYRKSaM8ZcDTwLPAqcAXwGvGuMOfEwL10NHL/P46+VGad4T4mbBIwxZokx5q59ticXnyEeV7wda4zZZYxpU9mxWGvfsdbeZ619Ayiq7OOJSPAJpj4JGAH821r7krV2lbX2VuBX4MbDvK7AWrt5n8fWyg9VvKTETQJpO1ALwBhTH+gJpAN1in8+APjJWruwLDszxtxnjNl9mMchLzWISLW2nSDok4wx4cBZQOp+P0oFzjvMYRsVD/lYZ4yZZoxpVJZYxb9qeB2AVCsZuEuUADcDc4DTgfjitpuAsQDGmLeAC4EPrbU9/2B/E4EZhznmxqOIV0SqtmDpk+oBocCW/dq3AJccYl+fAwOBH4D6wP3AZ8aY5tba3w8Th/iUEjcJpO1ALWNMJHADcBnwHFDHGHMJ7ix3WvFznwZewp3xHpS1Nh13diwiciS2E1x9kt1v2xykbd/jvVvqycYsAf5bHONTRxGHBDFdKpVA2nN2mwh8Z639BtiJ6xxvAV6w1uYCWGvTgF2H2pkulYrIUQqWPmkbUAgct197fQ6swv0ha+1uYCXQpKyvEf9RxU0CaTtuPMntwN+L23bgLk1cCgwr5/50qVREjsZ2gqBPstbmGWO+LD7mzH1+dCkwq6wHL64cngyklfU14j9K3CSQMoA2uI7rneK2nbjOcYa19rfy7OxoLksYY2oCjYs3Q4ATjTGnA+nW2p+PZJ8i4jtB0yfhLm2+boxZCnyKu3TbAJcMAmCMuQW4xVp7cvH2v4B5wM+46twDQAzw6hHGID6gxE0Cac9liWestXvGbezY0xbgWFpS+qx0VPHjVdxgXxGp+oKmT7LWTjfG1MXdYHA8bp7JTtba9fs8rR7QdJ/tPwEpxe1bgSXAufu9RqoYU/JZFQkuxpi2uLPLP7qDS0QkYNQnSTBQ4iZByRgzHzgNV/ZPB3pZaxd7G5WIVFfqkyRYBCxxM8acALyGu2umCJhkrX02IAcXERERqQICmbgdDxxvrV1ujKkFfAl0tdZ+H5AARERERHwuYPO4WWt/tdYuL/73LmAVkBCo44uIiIj4nScT8BpjGgJn4JbrEBEREZEyCPh0IMXzZ80CbrfW7tzvZ8MonvAwMjLyrBNPPDHQ4R21oqIiQkL8tyBFUVERxhivwyiXH3/8kSZN/DdBuLXWF5+RggLDpk1R5OSEYgzUr59FXFyh12GVy5o1a7ZZa4+pjH1Xlf7Kb997a63vYgZ/91d+e7/90sfur6z9VUDvKjXGhAH/Ad631h5yHbWmTZva1atXByawCrRgwQLatm3rdRjl9sEHH9C8eXOvwyiXhIQENm7038IIa9asCfrPyGefQY8esHkznHgizJkDO3f677NtjPnSWtuyso/j1/7Kj9/7lStX+i5m8G9/5cf32w997MGUtb8KWEpqXMo+GVh1uKRNRLzz0kvQtq1L2tq2hWXL4MwzvY5KREQgsGPczgf6AxcbY74ufnQK4PFF5BDy8u
CGG2DYMMjPh+HDITUVjqmUC40iInIkAjbGzVr7CeCvC+Ui1cTmzdCzJ3z6KUREwKRJcO21XkclIiL701qlItXc0qXQrRts2gR/+pMbz9ay0keFiYjIkfDfbRciUmFeeQUuvNAlbRde6MazKWkTEQlevqy4FRUVsW3bNrZv305hYXBNTxAXF8eqVau8DqPc6tatS3p6eoXv1xhDREQEUVFRvrulvCrLz4c77oDx4932zTfDU09BeLi3cYmIyKH5MnHbsGEDxhgaNmxIWFhYUCUEu3btolatWl6HUW47d+4kMjKyQvdpraWgoICtW7eya9cuYmNjK3T/cmR++w169YKPP3aJ2oQJMGSI11GJiEhZ+PJSaWZmJgkJCYSHhwdV0ialGWMICwvjuOOOo6CgwOtwBPjyS3cp9OOPoUEDWLhQSZuIiJ/4MnEDfDkrcnWl/1fB4bXX4Pzz4Zdf4Lzz3Hi2c8/1OioRESkP/UUVqeIKCtx4tgEDIDfXzdOWlgbHH+91ZCIiUl6+HOMmImWzbRtcfTV89BGEhcHzz8P113sdlYiIHCklbiJV1FdfufnZ1q+H446DN95wl0pFRMS/dKlUpApKSXFJ2vr1cM45bjybkjYREf9T4hZg1lqeeOIJmjZtSlRUFPXr16dHjx4Vtv9zzz2Xf/7zn3u3hwwZgjGGzZs3A27aj1q1arFw4cJD7mfWrFnUqlWL9evX720bMWIEJ598Mlu2bKmweKViFRTAXXdB376QnQ2DB7s7RxMSvI5MREQqghK3APvnP//JK6+8woQJE/jhhx946623uPTSSw943qOPPkrNmjUP+Vi0aNEBr6tduza7du0C4LfffuONN94gPj6ejIwMAF599VUaN25MmzZtDhln9+7dadGiBWPHjgXgqaeeYsaMGcybN49jjz32aN8GqQTp6dCpE/zrX1CjBowbBy+/DBU8PZ+IiFQwa8v+3Cozxs2r6dzK82YDvPfee3Tq1In27dsDcNJJJ3HuQeZkuOGGG+jdu/ch95VwkDJKnTp12L17NwDjx4+nW7dufP3113tXRZgwYQL33HMPAF26dGHRokW0b9+eKVOmlNqPMYZHHnmErl270qhRI8aOHcv7779PkyZNAJfYffrpp7Rr145p06aV702QCvftt9C1K/z3v3DMMW4820UXeR2ViIgcSnq6m6pp0qSyv6bKJG5+0aVLF/72t7/xzTff0KtXL3r06EG9evUOeF58fDzx8fHl3v+eiltOTg4TJ04kNTWV2267jYyMDObPn09GRgZ9+vQB4I477mDo0KG8+uqrB93XpZdeSsuWLXnooYeYPXs2LfdZxPK2225j8ODBJCUllTtGqVgzZ8LAgZCVBWed5RaJP+EEr6MSEZGDsRY++wxefNH13zk55Xt9lblUaq03j/K6/fbbWb16NR06dGDChAn83//930HXNj3SS6V7Km5JSUm0aNGC0047jdjYWDIyMhg3bhw33ngjERERALRr1+6Qy3OlpaWxYsUKrLUHXB5t27atL5f2qkoKC+G++6B3b5e09e8PixYpaRMRCUYZGfDcc/DXv8IFF8Drr7uk7bLLYNassu9HFTcPNG7cmDvvvJPhw4dTt25dVqxYQbNmzUo950gvle6puD3zzDM8/vjjgFv4/uuvv+aDDz5gUhnrsStWrODqq6/m6aef5t133+X+++/n7bffLuNvKJUtIwP69YN334XQUHjySbjtNu+GDIiIyIGshSVLXHVt+vSS6lr9+u7msaFDoVGj8u1TiVsAPf744xx77LGcc8451KhRg1dffZXw8HDatm17wHOP9FJpnTp1WLhwIQkJCXTq1AmA2NhYJk2aRO/evalfv/5h97F+/Xquuuoqhg8fzsCBAzn77LM566yzWLhw4WFvapDKt3KlG8/2009Qt64rtbdr53VUIiKyx44dkJTkErZvvy1pv+QSNwl6ly4QHn5k+1biFkC5ubk8/vjjrF+/nujoaM4991w+/P
DDCr1Lc8+l0ttvvx1TXH6Ji4vb23Y46enpXHnllXTs2JGRI0cC0Lx5c3r06MEDDzzAxx9/XGGxSvnNmQPXXgu7d8P
"text/plain": [
"<Figure size 648x230.4 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code this cell generates and saves Figure 512\n",
"\n",
"import matplotlib.patches as patches\n",
"\n",
2016-09-27 23:31:21 +02:00
"def plot_2D_decision_function(w, b, ylabel=True, x1_lim=[-3, 3]):\n",
" x1 = np.linspace(x1_lim[0], x1_lim[1], 200)\n",
" y = w * x1 + b\n",
" half_margin = 1 / w\n",
2016-09-27 23:31:21 +02:00
"\n",
" plt.plot(x1, y, \"b-\", linewidth=2, label=r\"$s = w_1 x_1$\")\n",
" plt.axhline(y=0, color='k', linewidth=1)\n",
" plt.axvline(x=0, color='k', linewidth=1)\n",
" rect = patches.Rectangle((-half_margin, -2), 2 * half_margin, 4,\n",
" edgecolor='none', facecolor='gray', alpha=0.2)\n",
" plt.gca().add_patch(rect)\n",
" plt.plot([-3, 3], [1, 1], \"k--\", linewidth=1)\n",
" plt.plot([-3, 3], [-1, -1], \"k--\", linewidth=1)\n",
" plt.plot(half_margin, 1, \"k.\")\n",
" plt.plot(-half_margin, -1, \"k.\")\n",
2016-09-27 23:31:21 +02:00
" plt.axis(x1_lim + [-2, 2])\n",
" plt.xlabel(\"$x_1$\")\n",
2016-09-27 23:31:21 +02:00
" if ylabel:\n",
" plt.ylabel(\"$s$\", rotation=0, labelpad=5)\n",
" plt.legend()\n",
" plt.text(1.02, -1.6, \"Margin\", ha=\"left\", va=\"center\", color=\"k\")\n",
"\n",
" plt.annotate(\n",
" '', xy=(-half_margin, -1.6), xytext=(half_margin, -1.6),\n",
" arrowprops={'ec': 'k', 'arrowstyle': '<->', 'linewidth': 1.5}\n",
" )\n",
" plt.title(f\"$w_1 = {w}$\")\n",
2016-09-27 23:31:21 +02:00
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(9, 3.2), sharey=True)\n",
"plt.sca(axes[0])\n",
2016-09-27 23:31:21 +02:00
"plot_2D_decision_function(1, 0)\n",
"plt.grid()\n",
"plt.sca(axes[1])\n",
2016-09-27 23:31:21 +02:00
"plot_2D_decision_function(0.5, 0, ylabel=False)\n",
"plt.grid()\n",
2016-09-27 23:31:21 +02:00
"save_fig(\"small_w_large_margin_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAj4AAADlCAYAAABTVP1pAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAA62klEQVR4nO3dd5wU9f3H8dfnGr0jqDS7FPUgIBpURKMJGruIRixgFGsMiRpFjWASSYy9xi5q/MVTRMVeQcQAAlKkGlRAiiAq5eCAK9/fH99dWJbrt7uz5f18PPZxuzOzM5+52/3ee2a+M2POOUREREQyQVbQBYiIiIgkioKPiIiIZAwFHxEREckYCj4iIiKSMRR8REREJGMo+IiIiEjGUPARERGRjKHgIyIiIhlDwSeKmY02szeCrgOSq5ZUY2YtzGy1me0bdC2JZGZjzOyPQdchUldm1svMnJntFYN59QvNq3UMSqtrLYPNrDDoOjJZxgSfikJEOV+u3wPnJbQ4iYcbgbecc1+FB5jZFWb2jZltMbMZZnZUdWdmZn3NbJyZrQh9XgbHo+iaMLM7zeydqMG3AjebWbMgapLEM7PdzOxhM1tiZltDgf9DMzs+6NrizcwmmNmD5QwfYGaRtyX4L7AH8EPCipOklTHBp7qcc+udc+uCrkNqz8waAhcDT0YMOxu4DxgF9MA3hG+bWcdqzrYxMBcfjItiWnDtHQp8FjnAOfcF8DUK75nkZaA38FvgAOAk4G2gVZBFAZhZlpllB12Hc26bc+47p3s0CQo+u4jeMxTaonjYzEaZ2VozWxPa0s6KmKaRmT1rZoWhra3hZvaGmY2OmMbM7E9m9pWZFZnZF2ZWo39OZlbPzO4NLWOLmU0xsyMjxvcNDSs0s/VmNtXMDqpqXDyYWfvQnpGzzewjM9tsZrPNrHNoL9vE0LDPosOHmd1sZnNCtX4f+ps0iBh/VmjLtlPEsPtCv9u2wIlAGfBpxGz/CIx2zj3unFvgnPsdsAq4vDrr45x7yzl3o3NuTGjecRPxWVkU+qysMbOXI8bnmtk2oC/w59DveV7ELMYBv4lnjZIczKw5cBRwg3PuQ+fcUufcNOfcnc65FyKma2Nmr4U+T0vN7CIzm2tmIyOmcWY2IGr+S8zs2ojXfwx9NzeF9n4+EaohPH5w6Ht7opnNBbYBXcwsz8xuN7PlofdOM7NfRS2rv5ktDLVtn+BDXKx+Tzsd6oqo8xeh38MmMxtvZntHvW94qL0tDLXxI8xsSdQ0Q8xsfqjuL83sDxbx/6Ga9V1qZovNbFvo5yXljP8ytIzvzexdM8sJjTvY/B6+DWa2MdTOHlOrX1SGUPCpnkFACdAHuAoYBpwdMf4u4GjgdOBYIB/fGEX6G36L7EqgK/B34FEz+3UN6vhnaLkX4fdafAG8Y2Z7hL4ErwGTQss/DL+Ho7SycRUtyMxuDH3ZK3tUdqioe+jn5fjDL4cB9YDRofW4Efg50BofSiLlhN7XDf8P/Hj87zxsTGjdbw7Vem1ouv7OudX43/2M8NadmeUBPYH3opbzHv5vmmyuA4YAVwCdgVOA9yPGl+J/d+B/r3sAR0aM/wzoHRkWJW0Vhh6nmFn9SqYbDewHHAecBlwA7FWL5ZXhv4vdgHPxe5oeiJqmPv67eSm+rVsKPI1vI88FDgaeAV43s3wAM+sAvIr/nHcPzfOftaivJuoBw/Ht6c+B5sAj4ZFmdg4wArgJ+BmwgKi2KhRQRgG3AF2Aa4Dr8d/dajGz04EHgXuBg/Bt88NmdnJofC/gIXw7eiD+bxh5iPv/8BtxvfH/F0YCW6q7/IzknMuIB/6LX8KOhiL82Aw4YK+I6d6IeN8EYHLUvN4Hngg9b4zfqjknYnwj4Cf8Hobw6yLgqKj53Ivvh1JZzW9EzGMbcEHE+GzgK3yoahlaj6PLmU+F4ypZdkt8Q1nZo0El778JWAe0jRj2APA90Cpi2NNAQRW1PAY8EzXsl0
AxcAOwETg0YtyrkdMDe4bWv2/UPG4BFtXis1QIDI7jZ/Uj4K4qpjkJ2ABYOeMOCa3vvvGqUY/keQBnAj/i/9lNBu4EDosYf0Do83BExLBO+AA9MmKYAwZEzXsJcG0ly+4PbAWyQq8Hh+bTM2KaffGBqWPUe18FHg49HwV8Gfl5xoen7W1zBcufEGoXo9v1IsBFTNcvNK/WUXUeGDHNoNC8wusyGXgkannvAUsiXi8Dzo+aZhgwv5KaBwOFEa8/BZ6KmmY0MCn0/AxgPdCkgvltAC4M+nOYSo9M2+MzEb81Efk4txrvmxP1eiXQJvR8XyCXiL4WzrlN+P4gYV3xW0HvRO4xwe/VqO5ZR+HlbD9845wrxX85uzrnfsR/Wd41szdDu6Q7hKarcFxFnHM/OucWV/GorK9Ld3xoWx0xrCMw1jn3Q9Swb8IvzKyDmd1v/lDgj6Hf0xBgeVR97wHT8KFvoHNuWsToBpS/xRN9fN/KGRYTZva30K71yh79Knj7OGCYmX0Q2sVd3pkoPYDZLtTyRQn/XbTHJwM4517Gh/uT8X17+gBTzOzG0CRd8MEjso1aim/HasTMjjWz90OHrDYCY4E8YPeIyUqAWRGvf4b/rs2Pav9+zY72rwswJerzPLmaZRWwa7t+XTXet9U5tyji9Up8G9s89LozUX3ogKnhJ2a2G9ABv+c+cr3+QfXbdfDr/mnUsEn4/xvgN7SXAt+Y2fNmdqGZNYmY9m7gCfNdCm4ys841WHZGyrTgszn6nzdR/1ArUBz12rHjd2cRwyoSnvZkdv5ydsPvuaiOypbjN9ecG4I/9DERf3jky/Bx9MrGlbuwuh/qygemRA3rwa6NWT4wM7TMVvgwsztwLf6QVS98iJkVVV/4kKIBkeEKYC3QIup1KTs3zuDDa/R7Y+VefINW2SO6UQXAOXcvfpf2O/hd5l+ZWZeoyboT+r2Vo2Xo5/e1LV5Si3Nui3PufefcX5xzffAd+0eGDvNaFW/fPptyps0NPzHfp+5N/CGfs/CHjy8Kjc6LeM/W0EZZWFZo3oeyc/vXJeL91a2xPOvLade/q8b7SqJeh9vWrHKGlSc83WXsvF4H4dv2mqisXd+ID48D8XuYhgMLzWzP0PiR+JD0Kj70zjGzi8qZn4RkWvCJh8X4YNQ7PMD8WUWRHYfn43cHdypnr8nSGixnGxF9OcyfLfHz0PwBcM7Nds7d7pzrh98NfGF1xpXjEXbdiop+TC/vjWbWCL/FMzNiWEv81lHksA74M0/Cw36N3zN2tnPuXefcPHxYa0xE8An1CxgL/A7/Zf97VAkz2bG1hHNuGzAD31co0vH4s7tizjm31jm3sIrH5krev9g5dyc++Bn+8FWkfHbdExl2ELAyam+bZJb5+L5y9fFBJQsfPAAwf0LBnlHv+R7fXyw8TdvI1/jPYh7wB+fcZOfcl+XMozwz8Z/h3ctp/1ZE1HuYmUUGoMOrMe94WkhEux6y/XXo+7UCf0h5lz3iNVjOAnbuo0fodWS7XuKc+8g5NxzfFjTCH+4Oj/+fc+5+59yv8aH34hosP+PkBF1AqnPOFZrZU8DtZrYW38nsZnZs5eCc22hmdwJ3hr7YE/H/zA8Hypxzj1VjOZvM7F/AP0LL+Qb4A9AW3xFub3xnwnH4L+M++C/IvyobV8nyfsT3G6iN8D/p2RHDeuDD3/yIYd3xx+PDjcQP+N/LaWb2BXACvhP0xvA0oa3Ot4C7nXNPmdln+C2cfs65CaH5vIv/e7SKOKx2N/BcaPpP8VtpexLRmbEyZtYY368J/N+2o5l1B350zi2rzjyquZzr8XuhPsNvkV6ID7wToibNATqHtvo2u50vwXAUO3d+lDQV2kv6EvAUPghvxAeUPwEfOuc2ABvMX+/pUTMbij8Ueje7XpbhI+BKM/svfg/pKHY+ZPw//Gd/mJmNxbdfw6qq0Tn3pZk9D4w2s2uAz/F7JfsBXzvnxuK/h9cA95rZw/gO0JfV7LcRc/
cBT5vZNOAT/Mkrh+H7b4aNBB4ws3X4dikXv3emnXMueoOsIncAL5nZDHwfov74/kZnAJjZSfgNyYn4NvkYoAmwIHQ
"text/plain": [
"<Figure size 590.4x216 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code this cell generates and saves Figure 513\n",
"\n",
"s = np.linspace(-2.5, 2.5, 200)\n",
"hinge_pos = np.where(1 - s < 0, 0, 1 - s) # max(0, 1 - s)\n",
"hinge_neg = np.where(1 + s < 0, 0, 1 + s) # max(0, 1 + s)\n",
"\n",
"titles = (r\"Hinge loss = $max(0, 1 - s\\,t)$\", \"Squared Hinge loss\")\n",
"\n",
"fix, axs = plt.subplots(1, 2, sharey=True, figsize=(8.2, 3))\n",
"\n",
"for ax, loss_pos, loss_neg, title in zip(\n",
" axs, (hinge_pos, hinge_pos ** 2), (hinge_neg, hinge_neg ** 2), titles):\n",
" ax.plot(s, loss_pos, \"g-\", linewidth=2, zorder=10, label=\"$t=1$\")\n",
" ax.plot(s, loss_neg, \"r--\", linewidth=2, zorder=10, label=\"$t=-1$\")\n",
" ax.grid(True, which='both')\n",
" ax.axhline(y=0, color='k')\n",
" ax.axvline(x=0, color='k')\n",
" ax.set_xlabel(r\"$s = \\mathbf{w}^\\intercal \\mathbf{x} + b$\")\n",
" ax.axis([-2.5, 2.5, -0.5, 2.5])\n",
" ax.legend(loc=\"center right\")\n",
" ax.set_title(title)\n",
" ax.set_yticks(np.arange(0, 2.5, 1))\n",
" ax.set_aspect(\"equal\")\n",
2016-09-27 23:31:21 +02:00
"\n",
"save_fig(\"hinge_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Extra Material"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Linear SVM classifier implementation using Batch Gradient Descent"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"X = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
"y = (iris.target == 2)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.base import BaseEstimator\n",
"\n",
"class MyLinearSVC(BaseEstimator):\n",
" def __init__(self, C=1, eta0=1, eta_d=10000, n_epochs=1000,\n",
" random_state=None):\n",
2016-09-27 23:31:21 +02:00
" self.C = C\n",
" self.eta0 = eta0\n",
" self.n_epochs = n_epochs\n",
" self.random_state = random_state\n",
" self.eta_d = eta_d\n",
"\n",
" def eta(self, epoch):\n",
" return self.eta0 / (epoch + self.eta_d)\n",
" \n",
" def fit(self, X, y):\n",
" # Random initialization\n",
" if self.random_state:\n",
" np.random.seed(self.random_state)\n",
" w = np.random.randn(X.shape[1], 1) # n feature weights\n",
2016-09-27 23:31:21 +02:00
" b = 0\n",
"\n",
" m = len(X)\n",
" t = np.array(y, dtype=np.float64).reshape(-1, 1) * 2 - 1\n",
2016-09-27 23:31:21 +02:00
" X_t = X * t\n",
" self.Js=[]\n",
"\n",
" # Training\n",
" for epoch in range(self.n_epochs):\n",
" support_vectors_idx = (X_t.dot(w) + t * b < 1).ravel()\n",
" X_t_sv = X_t[support_vectors_idx]\n",
" t_sv = t[support_vectors_idx]\n",
"\n",
" J = 1/2 * (w * w).sum() + self.C * ((1 - X_t_sv.dot(w)).sum() - b * t_sv.sum())\n",
2016-09-27 23:31:21 +02:00
" self.Js.append(J)\n",
"\n",
" w_gradient_vector = w - self.C * X_t_sv.sum(axis=0).reshape(-1, 1)\n",
" b_derivative = -self.C * t_sv.sum()\n",
2016-09-27 23:31:21 +02:00
" \n",
" w = w - self.eta(epoch) * w_gradient_vector\n",
" b = b - self.eta(epoch) * b_derivative\n",
" \n",
"\n",
" self.intercept_ = np.array([b])\n",
" self.coef_ = np.array([w])\n",
2017-12-19 22:40:17 +01:00
" support_vectors_idx = (X_t.dot(w) + t * b < 1).ravel()\n",
2016-09-27 23:31:21 +02:00
" self.support_vectors_ = X[support_vectors_idx]\n",
" return self\n",
"\n",
" def decision_function(self, X):\n",
" return X.dot(self.coef_[0]) + self.intercept_[0]\n",
"\n",
" def predict(self, X):\n",
" return self.decision_function(X) >= 0"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ True],\n",
" [False]])"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"C = 2\n",
"svm_clf = MyLinearSVC(C=C, eta0 = 10, eta_d = 1000, n_epochs=60000,\n",
" random_state=2)\n",
2016-09-27 23:31:21 +02:00
"svm_clf.fit(X, y)\n",
"svm_clf.predict(np.array([[5, 2], [4, 1]]))"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZkAAAEOCAYAAABbxmo1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAcXklEQVR4nO3de5hcdZ3n8fenu9OdSyeEkASzCWu4xGhE5JJxuOl0DIwOMMIIjDjiZlzGeGEUbzsmOCoz6wVcdRxnXSUrSFgYJKArPIZBINqwohC5SiA0uQKBkAQIJJ2QTrr7u3+c00l1pxNSSZ061ac/r+fpp0796tSp77fTnU+fX506RxGBmZlZFuryLsDMzIrLIWNmZplxyJiZWWYcMmZmlhmHjJmZZcYhY2ZmmalayEi6WtJ6SUtKxsZIulPSsvT24JLH5kpaLqlN0nuqVaeZmVVONfdkrgHe22dsDrAoIqYAi9L7SJoGXAC8NX3O/5JUX71SzcysEqoWMhFxD/Byn+Gzgfnp8nzgnJLxn0ZER0SsApYD76hGnWZmVjkNOb/+oRGxFiAi1koan45PBO4rWW9NOrYbSbOB2QB1w0ad0HDQeCaPqmP1pm4AJo8qzttO3d3d1NUVp5++3N/AVuT+itwbwFNPPfViRIzLYtt5h8yeqJ+xfs9/ExHzgHkATROmxIRZ36Pt8jOZPGchAG2Xn5lZkdXW2tpKS0tL3mVkxv0NbEXur8i9AUh6Oqtt5x3N6yRNAEhv16fja4DDStabBDxf5drMzOwA5R0ytwKz0uVZwC0l4xdIapJ0ODAFWJxDfWZmdgCqNl0m6QagBRgraQ3wVeByYIGki4BngPMBIuJxSQuAJ4BO4OKI6KpWrWZmVhlVC5mI+OAeHpq5h/W/Dnw9u4rMzCxreU+XmZlZgTlkzMwsMw4ZMzPLjEPGzMwyU9iQ+cD0w15/JTMzy1RhQ2ZEUwPNTbV6QgMzs8GhsCGj/k5MY2ZmVVXYkDEzs/wVOmQi+j2nppmZVUlhQ8azZWZm+StsyJiZWf4KHTKeLDMzy1dhQ8ZHl5mZ5a+wIWNmZvkrdMj44DIzs3wVNmTk+TIzs9wVNmTMzCx/hQ6Z8PFlZma5KmzIeLLMzCx/hQ0ZMzPLX6FDxkeXmZnlq7gh4/kyM7PcFTdkzMwsd4UOGc+WmZnlq7AhI8+XmZnlrrAhY2Zm+St2yHi+zMwsV4UNGZ+6zMwsf4UNGTMzy1+hQ8bnLjMzy1dhQ8azZWZm+StsyJiZWf4KHTI+d5mZWb4KGzI+uszMLH+FDRkzM8tfTYSMpM9KelzSEkk3SBoqaYykOyUtS28PLne7ni0zM8tX7iEjaSLwaWB6RBwN1AMXAHOARRExBViU3t/37fr4MjOz3OUeMqkGYJikBmA48DxwNjA/fXw+cE4+pZmZ2f5qyLuAiHhO0reBZ4DXgDsi4g5Jh0bE2nSdtZLG9/d8SbOB2QCNbzgKgNbWVp5+ejvd3UFra2s12qiK9vb2QvXTl/sb2IrcX5F7y1ruIZO+13I2cDjwCnCTpAv39fkRMQ+YB9A0YUoAtLS08OD2NrRqOS0tLRWvOS+tra2F6qcv9zewFbm/IveWtVqYLjsNWBURGyJiB/Bz4GRgnaQJAOnt+hxrNDOz/VALIfMMcKKk4ZIEzASWArcCs9J1ZgG3lLthH11mZpav3KfLIuJ+STcDDwGdwMMk01/NwAJJF5EE0fnlbNfHlpmZ5S/3kAGIiK8CX+0z3EGyV2NmZgNULUyXZcbnLjMzy1dxQ8YnLzMzy11xQ8bMzHLnkDEzs8wUNmQ8WWZmlr/ChoyZmeWv8CETPsTMzCw3hQ0ZH1xmZpa/woaMmZnlr/Ah49kyM7P8FDZkfGVMM7P8FTZkzMwsf4UPGc+WmZnlp7Ah46PLzMzyV9iQMTOz/BU+ZPxhTDOz/BQ2ZDxbZmaWv8KGjJmZ5a/wIePJMjOz/BQ2ZHx0mZlZ/gobMm
Zmlr/Ch4wPLjMzy09hQ0aeLzMzy11hQ8bMzPJX+JAJH19mZpabwoeMmZnlxyFjZmaZKXzI+OgyM7P8FDZkfHCZmVn+ChsyZmaWP4eMmZllprAhI5/s38wsd4UNmR5+49/MLD+FDRm/8W9mlr+aCBlJoyXdLOlJSUslnSRpjKQ7JS1Lbw/Ou04zMytPTYQM8K/A7RHxZuDtwFJgDrAoIqYAi9L7ZfNpZczM8pN7yEgaBbwLuAogIrZHxCvA2cD8dLX5wDllbbdyJZqZ2X5S5PzOuKRjgXnAEyR7MQ8ClwDPRcTokvU2RsRuU2aSZgOzARrfcNQJE2Z9j2veO4LbVm1nQdsOrjxtOE0NxYic9vZ2mpub8y4jM+5vYCtyf0XuDWDGjBkPRsT0LLbdkMVGy9QAHA98KiLul/SvlDE1FhHzSEKKpglTAqClpYU2rYC2Jzn1ne9kRFMttHngWltbaWlpybuMzLi/ga3I/RW5t6zlPl0GrAHWRMT96f2bSUJnnaQJAOnt+nI26qPLzMzyl3vIRMQLwLOSpqZDM0mmzm4FZqVjs4BbcijPzMwOQK3MI30KuF5SI7AS+AhJAC6QdBHwDHD+/mzYx5aZmeWnJkImIh4B+nvTaeb+btOnlTEzy1/u02VZ+bdfLwNg87YdOVdiZjZ4FTZkNm3rBOCj1z6QcyVmZoNXYUOmx5LnNuVdgpnZoFX4kDEzs/wccMhIGlKJQiqtzu/7m5nlrqyQkfRpSeeW3L8KeE1SW8nnXGrCKUeNzbsEM7NBr9w9mU8DGwAkvQv4a+BvgEeA71S0sgP0xfe+Oe8SzMwGvXJDZiKwOl3+S+CmiFgAXAacWLmyDtzY5qady79b/mKOlZiZDV7lhswmYFy6fDrJdV4AdgBDK1VUJRw6alfI3P3UhhwrMTMbvMoNmTuA/52+F3MU8B/p+FuBVZUs7ECp5AyZV96zMsdKzMwGr3JD5mLgXmAscF5EvJyOHw/cUMnCzMxs4Cvr3GURsYnkZJZ9x79asYoO0Mih/bfU1R3U+7hmM7OqKvcQ5mmlhypLOl3SdZLmSqqvfHnlGdUobrn4lH4fW7mhvcrVmJlZudNlVwHHAUiaRHKNlzEk02hfq2xp5RszVBwxrv9LpJ7+L/dUuRozMys3ZN4CPJQunw/cHxFnAB8GPljJwirhJx/5k7xLMDMb1MoNmXpge7o8E7gtXV4BHFqpoiplxtTxeZdgZjaolRsyS4BPSHonScjcno5PBGr+E49Lnns17xLMzAaVckPmi8BHgVbghoh4LB1/H7C4gnVVzFHjd71Hc9a//TbHSszMBp+yQiYi7iH5xP/YiPivJQ9dCXyikoVVyl2f+7Ne9yMip0rMzAafsk/1HxFdJGdePlrSWyUNjYjVEbE+g/oq7orb2/Iuwcxs0Cj3czINkv4HsBF4FHgM2CjpW7V6XZm+fnT3irxLMDMbNMrdk/kWcCHwceBNwBSSabIPA9+sbGmVc9lfTsu7BDOzQanckPkb4KKImB8RK9Kva4C/Az5U8eoq5G9PObzX/Sdf2JRTJWZmg0u5IXMQyWdi+loBjD7gaqpk/aaOvEswMxsUyg2ZR0mujtnXJeljNetnnzh553J7R2eOlZiZDR7lhsw/ALMkPSVpvqRrJLWRvE/zhcqXVzknvPHgncufvP6hvaxpZmaVsj+fk3kTcBPQDIxKl99D/3s4ZmY2iO3P52Sej4gvRcS5EfH+iPhHYAtwbuXLq6xV3zxj5/IDq1/ey5pmZlYJZYfMQFZ6SebzfvT7HCsxMxscBlXIAKy+/Mydy/cur/lzepqZDWiDLmQAPjD9MAA+9OP7c67EzKzYGvZlJUm3vs4qoypQS9Vccd4x3PjAswAc+8938MhX/jzniszMimlf92Reep2vVcC1WRSYlcWXzgTgla07ePDpjTlXY2ZWTPu0JxMRH8m6kGobP2oobzq0mafWtXPuD3/Hym+cQV2dXv+JZma2zwblezI97vjsrm
vNHHHpbXtZ08zM9kfNhIykekkPS/plen+MpDslLUtvD369beyP0s/OTJ6zMIuXMDMbtGomZEjOf7a05P4cYFFETAE
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(range(svm_clf.n_epochs), svm_clf.Js)\n",
"plt.axis([0, svm_clf.n_epochs, 0, 100])\n",
"plt.xlabel(\"Epochs\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.grid()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[-15.56761653] [[[2.28120287]\n",
" [2.71621742]]]\n"
]
}
],
"source": [
"print(svm_clf.intercept_, svm_clf.coef_)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[-15.51721253] [[2.27128546 2.71287145]]\n"
]
}
],
"source": [
"svm_clf2 = SVC(kernel=\"linear\", C=C)\n",
"svm_clf2.fit(X, y.ravel())\n",
"print(svm_clf2.intercept_, svm_clf2.coef_)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAq4AAADwCAYAAADB5OaoAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAACRF0lEQVR4nOydZ3hURRuG70kPLbTQWxqhhCYivYkIIqAUBZQmItKUIiLSERAEBFGKgBQpomgA6YJIL9IhlBSSICWhhN7S5/ux2f2SsEk2YTtzX9e5sjunzHPOnn0zO2fmeYWUEoVCoVAoFAqFwtpxsLQAhUKhUCgUCoXCEFTDVaFQKBQKhUJhE6iGq0KhUCgUCoXCJlANV4VCoVAoFAqFTaAargqFQqFQKBQKm0A1XBUKhUKhUCgUNoFquCpsGiHEeCHEWUvrUCgUCoVCYXpUw1VhdIQQy4QQUgjxk55101LWbTLwWE1Sti+cwSYzgMbPo9eYCCE8hRDzhBCXhBBxQogbQoidQojmKevP6LsuKetapZxr+VRl7YUQ/wgh7gkhHgshgoQQk4UQRcx1TgqF4sUhsxim4pfCGlANV4WpuAJ0EkLk1hYIIZyAbsBlY1UipXwkpbxtrOPlFCGEkxBCAIHAK8CHQHmgNbAVKJSy6WLSXZdU9AL2SSlDU445GfgdOJVynErAIKAc0M9U56JQKF5oMothKn4pLI+UUi1qMeoCLAM2ASeAD1KVvwVEAj+nrG8EJADF0u0/GTiT8roJIIHCGdQ1Hjirp+5BwDXgLrAUyJVqGwEMB8KBp0AQ0DXdcacCISnrLwHTALf09QI9U46TBJRK0fpaJtemIBCb+rqklHsC8UD3lPevpBxraAbHyW/pz1ktalGLfS1A/sximIpfarGGRfW4KkzJYjS/wrX0QtOIlABSyr1oGn3dtRsIIRxS3i9+jnobAgHAa0AnoB2ahqyWSWh6Ewag6QWYAiwQQryZapvHKXorAv2BzsCodPV4Ae8B7wDVgOvAI6CtEMJNnzAp5R1gPWmvC2h6op8Cf6S8fz9Fww8ZHOeevnKFQqF4Dh6RSQxT8UthDaiGq8KU/AK8LITwE0IUA1qi6RFNzU/AB6netwCKACufo94HQD8p5QUp5XY0j6uaAaQ84hoK9JZSbpNSRkopfwEWoWnIAiClnCilPCClvCSl3AJ8DXRJV48L0E1KeUJKeVZKmYimB7YrcE8IcUgIMUMIUVvPOTdIPRYMzT+CX6SUT1Le+wHhUsqE57gOCoVCYTAGxjAVvxQWRTVcFSZDSnkXWIcmqPUAdksp049v/RnwFkLUS3nfC1gvn2/c6vmUAKwlCk1jGDQ9rG7ANiHEI+2CZsyVj3YHIURHIcR+IcT1lPWzgDLp6rkqpbyRukBKGQiUANqgGRdWDzgshBiZarOdaIZM9EqpqzZQGc0/BJ2EHJy3QqFQPBcGxDAVvxQWRTVcFaZmCZpH/71SXqdBSnkL2AD0EkIUAtryfMMEQDNuNk01/P9e1/5tA1RPtVQGXgcQQtQBfgX+StmuBjAacE533Mf6KpdSxkopd0gpv5JS1kNzPuOFEC4p6yWaIRPdhRCOaIYtnJZSHk91mFDAR7uPQqFQmIvMYpiKXwpLoxquClOzE82g/cJoxkbpYxHwLvAxcAP424R6zgNxQFkp5cV0y38p29QHrqUMFzgqpQwDyj5nnU5oenq1LAWKohkf25m0vRWgGWaRGxio74BCiPzPoUehUCiyQ/oYpuKXwmI4WVqAwr6RUkohRFVASCnjMthsB3AbGAdMlVIm69kmQAhxL13ZmRzoeSiEmAHMSLGv2gvkAeoAyVLKhWh6C0oKId4HDqEZd5t+fOszpPQY/46mZ/kM8BB4GY2DwU4p5YNUOq4KIf4C5qHpyV2VTue/QohpwHQhRCk0FjVX0UwI+xC4CEzI7vkrFApFRhgaw1T8UlgS1XBVmBwp5cMs1kshxFI0FlNLM9
hsl56yvDmUNAZNz+4wYD6ayVyn0FheIaXcKISYDnwHuAPbgbFognRmPAIOo3Ew8AVc0Vhy/YLGySA9PwFvoJnUcDf9SinlF0KIY2gmjX2I5vsaCfxpgBaFQqHILtmJYSp+KSyC0AxXUSgsixBiPuArpWxuaS0KhUKhUCisE9XjqrAoQggPoCaaCVzvWliOQqFQKBQKK0Y1XBWW5k80WVYWSyk3W1qMQqFQKBQK60UNFVAoFAqFQqFQ2ATKDkuhUCgUCoVCYROohqtCoVAoFAqFwiaw6zGu+fPnl76+vpaWwePHj8mdO7elZQBKS0YoLfpRWvRz/PjxGCmlp6V1mAJriZtgXZ+50qIfpUU/SsuzGC1uSinNsgCl0XhxXgDOAYP0bPM5Gj/NU8BZIAkomLLuEhCUsu6YIXWWL19eWgO7du2ytAQdSot+lBb9KC36MTQG2eJiLXFTSuv6zJUW/Sgt+lFansVYcdOcPa6JwGdSyhNCiLzAcSHEDinlee0GUsrpwHQAIUQbYIiU8k6qYzSVUsaYUbNCoVAoFAqFwkow2xhXKWW0lPJEyuuHaHpeS2aySxdgtTm0KRQKhUKhUCisH4vYYQkhyqHJER8gU+VvT7U+F5q8xr7aHlchRCRwF5DAAqnJKa/v2H2APgCFCxeu+csvv+Ds7GyS8zCUR48ekSdPHotq0KK06Edp0Y/Sop+mTZsel1K+bGkdxiJ13PT09Kw5Z84cPD09EUJYVJc1feZKi36UFv0oLc9itLhpjPEG2VmAPMBxoH0m23QCNqYrK5HytwhwGmiUVV2lSpWSbm5uctiwYfLWrVs5GpNhDKxlfImUSktGKC36UVr0gx2PcS1TpowEZKlSpeSiRYtkfHy8MS9dtrCmz1xp0Y/Soh+l5VmMFTfNaoclhHAGAoFVUsq1mWzamXTDBKSUUSl/bwLr0GRbyhRHR0diY2OZMWMG3t7eTJgwgQcPnungVSgUCkUKzs7OVK1alatXr/LRRx9RqVIlVq9eTXJysqWlKRQKhfkmZwnNM6fFwAUp5cxMtvMAGgNdU5XlBhyklA9TXr8OfJVVna6urhw7dozRo0ezbds2xo8fzw8//MDy5ctp2bIlV69e5fHjx899blnh4eHBhQsXTF6PISgt+rGklty5c1OqVCkcHJStssLyODo6cvLkSdasWcPYsWMJCwvjvffeY8qUKfz+++/4+/tbWqJCoXiBMaerQH2gGxAkhDiVUjYSKAMgpfwxpawdsF1KmbpFWRRYlzLeygn4RUq5zZBKa9asydatW9m7dy8jR47k8OHDlC9fnpiYGIQQ+Pv7m7zB8PDhQ/LmzWvSOgxFadGPpbQkJydz7do1YmJiKFKkiNnrVyj04eDgQOfOnenYsSM///wzEyZM4MaNG5QqVcrS0hQKxQuOOV0F9ksphZSyqpSyesqyRUr5Y6pGK1LKZVLKzun2jZBSVktZKkspJ2e3/kaNGrFv3z5OnTqFr68v9+7do0iRIkRGRnL79m3tOFqFwqw4ODhQtGhR7t+/b2kpCsUzODk58eGHHxIWFsb27dt1JuYPHz6kY8eOHD582MIKFQrFi8YL9WxSCEFAQAAASUlJPH36lLt37xIZGcn58+e5d++easAqzI6zszOJiYmWlqFQZIirqyvVqlXTvZ8zZw6BgYHUrVuXtm3bcubMGQuqUygULxIvVMM1PR4eHpQrVw4XFxeePn3KxYsXCQ4OVhO4FGbF0pZDCkV26du3LyNHjiRXrlxs3LiRatWq0aVLF0JDQy0tTaFQ2DkvdMNVCEHhwoUJCAigdOnSODk58fjxY0JDQwkPD7e0PIVCobBKChQowOTJk4mIiGDQoEG4uLjw66+/UqlSJaZMmWJpeQqFwo55oRuuWrTjDKtUqULJkiVxdHQkV65clpaVY5o0acLAgQNNdvyePXvSunXr5z7Ovn37EEIQE2N4Ft9ly5ZZhZGyQqGAokWL8t133xEWFkbv3r0BqFq1qo
VVKRQKe8acrgJWj6OjI8WLF38mY8yNGzd4/PgxJUqUwM3NzWL6evbsSUxMDJs2bcp0u7Vr15o0W9js2bONMha4du3
"text/plain": [
"<Figure size 792x230.4 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"yr = y.ravel()\n",
"fig, axes = plt.subplots(ncols=2, figsize=(11, 3.2), sharey=True)\n",
"plt.sca(axes[0])\n",
"plt.plot(X[:, 0][yr==1], X[:, 1][yr==1], \"g^\", label=\"Iris virginica\")\n",
"plt.plot(X[:, 0][yr==0], X[:, 1][yr==0], \"bs\", label=\"Not Iris virginica\")\n",
"plot_svc_decision_boundary(svm_clf, 4, 6)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.title(\"MyLinearSVC\")\n",
"plt.axis([4, 6, 0.8, 2.8])\n",
"plt.legend(loc=\"upper left\")\n",
"plt.grid()\n",
"\n",
"plt.sca(axes[1])\n",
"plt.plot(X[:, 0][yr==1], X[:, 1][yr==1], \"g^\")\n",
"plt.plot(X[:, 0][yr==0], X[:, 1][yr==0], \"bs\")\n",
"plot_svc_decision_boundary(svm_clf2, 4, 6)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.title(\"SVC\")\n",
"plt.axis([4, 6, 0.8, 2.8])\n",
"plt.grid()\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[-12.52988101 1.94162342 1.84544824]\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 396x230.4 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from sklearn.linear_model import SGDClassifier\n",
"\n",
"sgd_clf = SGDClassifier(loss=\"hinge\", alpha=0.017, max_iter=1000, tol=1e-3,\n",
" random_state=42)\n",
"sgd_clf.fit(X, y)\n",
"\n",
"m = len(X)\n",
"t = np.array(y).reshape(-1, 1) * 2 - 1 # -1 if y == 0, or +1 if y == 1\n",
"X_b = np.c_[np.ones((m, 1)), X] # Add bias input x0=1\n",
"X_b_t = X_b * t\n",
"sgd_theta = np.r_[sgd_clf.intercept_[0], sgd_clf.coef_[0]]\n",
"print(sgd_theta)\n",
"support_vectors_idx = (X_b_t.dot(sgd_theta) < 1).ravel()\n",
"sgd_clf.support_vectors_ = X[support_vectors_idx]\n",
"sgd_clf.C = C\n",
"\n",
"plt.figure(figsize=(5.5, 3.2))\n",
"plt.plot(X[:, 0][yr==1], X[:, 1][yr==1], \"g^\")\n",
"plt.plot(X[:, 0][yr==0], X[:, 1][yr==0], \"bs\")\n",
"plot_svc_decision_boundary(sgd_clf, 4, 6)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.title(\"SGDClassifier\")\n",
"plt.axis([4, 6, 0.8, 2.8])\n",
"plt.grid()\n",
"\n",
"plt.show()"
]
},
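{
"cell_type": "markdown",
"metadata": {},
"source": [
"A brief note on the mask computed above (a sketch of the reasoning, using notation that is not in the original cell): with targets $t^{(i)} \\in \\{-1, +1\\}$, bias-augmented inputs $\\mathbf{x}_b^{(i)}$, and $\\boldsymbol{\\theta} = (b, \\mathbf{w})$, an instance is flagged as a support vector when it lies inside the margin or on the wrong side of it:\n",
"\n",
"$$t^{(i)} \\, \\boldsymbol{\\theta}^\\top \\mathbf{x}_b^{(i)} < 1$$\n",
"\n",
"This is exactly the condition under which the hinge loss $\\max\\left(0,\\, 1 - t^{(i)} \\, \\boldsymbol{\\theta}^\\top \\mathbf{x}_b^{(i)}\\right)$ is nonzero. `SGDClassifier` does not track support vectors itself; the `support_vectors_` and `C` attributes are attached by hand here, presumably so that `plot_svc_decision_boundary()` can be reused unchanged."
]
},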
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Exercise solutions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. to 8."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. The fundamental idea behind Support Vector Machines is to fit the widest possible \"street\" between the classes. In other words, the goal is to have the largest possible margin between the decision boundary that separates the two classes and the training instances. When performing soft margin classification, the SVM searches for a compromise between perfectly separating the two classes and having the widest possible street (i.e., a few instances may end up on the street). Another key idea is to use kernels when training on nonlinear datasets. SVMs can also be tweaked to perform linear and nonlinear regression, as well as novelty detection.\n",
"2. After training an SVM, a _support vector_ is any instance located on the \"street\" (see the previous answer), including its border. The decision boundary is entirely determined by the support vectors. Any instance that is _not_ a support vector (i.e., is off the street) has no influence whatsoever; you could remove them, add more instances, or move them around, and as long as they stay off the street they won't affect the decision boundary. Computing the predictions with a kernelized SVM only involves the support vectors, not the whole training set.\n",
"3. SVMs try to fit the largest possible \"street\" between the classes (see the first answer), so if the training set is not scaled, the SVM will tend to neglect small features (see Figure 5-2).\n",
"4. You can use the `decision_function()` method to get confidence scores. These scores represent the distance between the instance and the decision boundary. However, they cannot be directly converted into an estimation of the class probability. If you set `probability=True` when creating an `SVC`, then at the end of training it will use 5-fold cross-validation to generate out-of-sample scores for the training samples, and it will train a `LogisticRegression` model to map these scores to estimated probabilities. The `predict_proba()` and `predict_log_proba()` methods will then be available.\n",
"5. All three classes can be used for large-margin linear classification. The `SVC` class also supports the kernel trick, which makes it capable of handling nonlinear tasks. However, this comes at a cost: the `SVC` class does not scale well to datasets with many instances. It does scale well to a large number of features, though. The `LinearSVC` class implements an optimized algorithm for linear SVMs, while `SGDClassifier` uses Stochastic Gradient Descent. Depending on the dataset, `LinearSVC` may be a bit faster than `SGDClassifier`, but not always, and `SGDClassifier` is more flexible, plus it supports incremental learning.\n",
"6. If an SVM classifier trained with an RBF kernel underfits the training set, there might be too much regularization. To decrease it, you need to increase `gamma` or `C` (or both).\n",
"7. A Regression SVM model tries to fit as many instances as possible within a small margin around its predictions. If you add instances within this margin, the model will not be affected at all: it is said to be _ϵ-insensitive_.\n",
"8. The kernel trick is a mathematical technique that makes it possible to train a nonlinear SVM model. The resulting model is equivalent to mapping the inputs to another space using a nonlinear transformation, then training a linear SVM on the resulting high-dimensional inputs. The kernel trick gives the same result without having to transform the inputs at all."
]
},
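{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make answer 4 concrete, here is a minimal sketch contrasting raw confidence scores with estimated probabilities. It is not part of the original solutions: the `*_demo` names and the two sample flowers are made up for illustration, and the RBF kernel is an arbitrary choice."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import load_iris\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.svm import SVC\n",
"\n",
"# Hypothetical demo data: petal length and width, Iris virginica vs. the rest\n",
"iris_demo = load_iris()\n",
"X_demo = iris_demo.data[:, 2:]\n",
"y_demo = (iris_demo.target == 2)\n",
"\n",
"# probability=True adds the internal cross-validated calibration step\n",
"demo_clf = make_pipeline(StandardScaler(),\n",
"                         SVC(kernel=\"rbf\", probability=True, random_state=42))\n",
"demo_clf.fit(X_demo, y_demo)\n",
"\n",
"X_new_demo = [[5.5, 1.7], [5.0, 1.5]]  # made-up petal measurements (cm)\n",
"print(demo_clf.decision_function(X_new_demo))  # signed scores, one per instance\n",
"print(demo_clf.predict_proba(X_new_demo))  # one row per instance, columns follow classes_"
]
},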
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 9."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_Exercise: Train a `LinearSVC` on a linearly separable dataset. Then train an `SVC` and a `SGDClassifier` on the same dataset. See if you can get them to produce roughly the same model._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's use the Iris dataset: the Iris Setosa and Iris Versicolor classes are linearly separable."
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import datasets\n",
"\n",
"iris = datasets.load_iris(as_frame=True)\n",
"X = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
"y = iris.target\n",
"\n",
"setosa_or_versicolor = (y == 0) | (y == 1)\n",
"X = X[setosa_or_versicolor]\n",
"y = y[setosa_or_versicolor]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's build and train 3 models:\n",
"* Remember that `LinearSVC` uses `loss=\"squared_hinge\"` by default, so if we want all 3 models to produce similar results, we need to set `loss=\"hinge\"`.\n",
"* Also, the `SVC` class uses an RBF kernel by default, so we need to set `kernel=\"linear\"` to get results similar to those of the other two models.\n",
"* Lastly, the `SGDClassifier` class does not have a `C` hyperparameter, but it has another regularization hyperparameter called `alpha`, so we can tweak it to get results similar to those of the other two models."
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.svm import SVC, LinearSVC\n",
"from sklearn.linear_model import SGDClassifier\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"C = 5\n",
"alpha = 0.05\n",
"\n",
"scaler = StandardScaler()\n",
"X_scaled = scaler.fit_transform(X)\n",
"\n",
"lin_clf = LinearSVC(loss=\"hinge\", C=C, random_state=42).fit(X_scaled, y)\n",
"svc_clf = SVC(kernel=\"linear\", C=C).fit(X_scaled, y)\n",
"sgd_clf = SGDClassifier(alpha=alpha, random_state=42).fit(X_scaled, y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's plot the decision boundaries of these three models:"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 792x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"def compute_decision_boundary(model):\n",
"    # Slope and intercept of the boundary line in scaled feature space\n",
"    w = -model.coef_[0, 0] / model.coef_[0, 1]\n",
"    b = -model.intercept_[0] / model.coef_[0, 1]\n",
"    # Map two far-apart points on that line back to the original feature units\n",
"    return scaler.inverse_transform([[-10, -10 * w + b], [10, 10 * w + b]])\n",
"\n",
"lin_line = compute_decision_boundary(lin_clf)\n",
"svc_line = compute_decision_boundary(svc_clf)\n",
"sgd_line = compute_decision_boundary(sgd_clf)\n",
"\n",
"# Plot all three decision boundaries\n",
"plt.figure(figsize=(11, 4))\n",
"plt.plot(lin_line[:, 0], lin_line[:, 1], \"k:\", label=\"LinearSVC\")\n",
"plt.plot(svc_line[:, 0], svc_line[:, 1], \"b--\", linewidth=2, label=\"SVC\")\n",
"plt.plot(sgd_line[:, 0], sgd_line[:, 1], \"r-\", label=\"SGDClassifier\")\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"bs\") # label=\"Iris versicolor\"\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"yo\") # label=\"Iris setosa\"\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.legend(loc=\"upper center\")\n",
"plt.axis([0, 5.5, 0, 2])\n",
"plt.grid()\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Close enough!"
]
},
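{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the slope and intercept in `compute_decision_boundary()` come from setting the linear decision function to zero. Writing $w_0$, $w_1$ for the two entries of `model.coef_[0]` and $b_0$ for `model.intercept_[0]` (symbols introduced here for illustration only), in scaled feature space:\n",
"\n",
"$$w_0 x_0 + w_1 x_1 + b_0 = 0 \\quad\\Longrightarrow\\quad x_1 = -\\frac{w_0}{w_1}\\, x_0 - \\frac{b_0}{w_1}$$\n",
"\n",
"The function evaluates this line at $x_0 = \\pm 10$ and passes the two points through `scaler.inverse_transform()`. Since standardization is an affine map, the result is still a straight line, now expressed in the original units."
]
},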
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 10."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_Exercise: Train an SVM classifier on the Wine dataset, which you can load using `sklearn.datasets.load_wine()`. This dataset contains the chemical analysis of 178 wine samples produced by 3 different cultivators: the goal is to train a classification model capable of predicting the cultivator based on the wine's chemical analysis. Since SVM classifiers are binary classifiers, you will need to use one-versus-all to classify all 3 classes. What accuracy can you reach?_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, let's fetch the dataset, look at its description, then split it into a training set and a test set:"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import load_wine\n",
"\n",
"wine = load_wine(as_frame=True)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
".. _wine_dataset:\n",
"\n",
"Wine recognition dataset\n",
"------------------------\n",
"\n",
"**Data Set Characteristics:**\n",
"\n",
" :Number of Instances: 178 (50 in each of three classes)\n",
" :Number of Attributes: 13 numeric, predictive attributes and the class\n",
" :Attribute Information:\n",
" \t\t- Alcohol\n",
" \t\t- Malic acid\n",
" \t\t- Ash\n",
"\t\t- Alcalinity of ash \n",
" \t\t- Magnesium\n",
"\t\t- Total phenols\n",
" \t\t- Flavanoids\n",
" \t\t- Nonflavanoid phenols\n",
" \t\t- Proanthocyanins\n",
"\t\t- Color intensity\n",
" \t\t- Hue\n",
" \t\t- OD280/OD315 of diluted wines\n",
" \t\t- Proline\n",
"\n",
" - class:\n",
" - class_0\n",
" - class_1\n",
" - class_2\n",
"\t\t\n",
" :Summary Statistics:\n",
" \n",
" ============================= ==== ===== ======= =====\n",
" Min Max Mean SD\n",
" ============================= ==== ===== ======= =====\n",
" Alcohol: 11.0 14.8 13.0 0.8\n",
"<<26 more lines>>\n",
"wine.\n",
"\n",
"Original Owners: \n",
"\n",
"Forina, M. et al, PARVUS - \n",
"An Extendible Package for Data Exploration, Classification and Correlation. \n",
"Institute of Pharmaceutical and Food Analysis and Technologies,\n",
"Via Brigata Salerno, 16147 Genoa, Italy.\n",
"\n",
"Citation:\n",
"\n",
"Lichman, M. (2013). UCI Machine Learning Repository\n",
"[https://archive.ics.uci.edu/ml]. Irvine, CA: University of California,\n",
"School of Information and Computer Science. \n",
"\n",
".. topic:: References\n",
"\n",
" (1) S. Aeberhard, D. Coomans and O. de Vel, \n",
" Comparison of Classifiers in High Dimensional Settings, \n",
" Tech. Rep. no. 92-02, (1992), Dept. of Computer Science and Dept. of \n",
" Mathematics and Statistics, James Cook University of North Queensland. \n",
" (Also submitted to Technometrics). \n",
"\n",
" The data was used with many others for comparing various \n",
" classifiers. The classes are separable, though only RDA \n",
" has achieved 100% correct classification. \n",
" (RDA : 100%, QDA 99.4%, LDA 98.9%, 1NN 96.1% (z-transformed data)) \n",
" (All results using the leave-one-out technique) \n",
"\n",
" (2) S. Aeberhard, D. Coomans and O. de Vel, \n",
" \"THE CLASSIFICATION PERFORMANCE OF RDA\" \n",
" Tech. Rep. no. 92-01, (1992), Dept. of Computer Science and Dept. of \n",
" Mathematics and Statistics, James Cook University of North Queensland. \n",
" (Also submitted to Journal of Chemometrics).\n",
"\n"
]
}
],
"source": [
"print(wine.DESCR)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
" wine.data, wine.target, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>alcohol</th>\n",
" <th>malic_acid</th>\n",
" <th>ash</th>\n",
" <th>alcalinity_of_ash</th>\n",
" <th>magnesium</th>\n",
" <th>total_phenols</th>\n",
" <th>flavanoids</th>\n",
" <th>nonflavanoid_phenols</th>\n",
" <th>proanthocyanins</th>\n",
" <th>color_intensity</th>\n",
" <th>hue</th>\n",
" <th>od280/od315_of_diluted_wines</th>\n",
" <th>proline</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>13.16</td>\n",
" <td>2.36</td>\n",
" <td>2.67</td>\n",
" <td>18.6</td>\n",
" <td>101.0</td>\n",
" <td>2.80</td>\n",
" <td>3.24</td>\n",
" <td>0.30</td>\n",
" <td>2.81</td>\n",
" <td>5.68</td>\n",
" <td>1.03</td>\n",
" <td>3.17</td>\n",
" <td>1185.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>100</th>\n",
" <td>12.08</td>\n",
" <td>2.08</td>\n",
" <td>1.70</td>\n",
" <td>17.5</td>\n",
" <td>97.0</td>\n",
" <td>2.23</td>\n",
" <td>2.17</td>\n",
" <td>0.26</td>\n",
" <td>1.40</td>\n",
" <td>3.30</td>\n",
" <td>1.27</td>\n",
" <td>2.96</td>\n",
" <td>710.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>122</th>\n",
" <td>12.42</td>\n",
" <td>4.43</td>\n",
" <td>2.73</td>\n",
" <td>26.5</td>\n",
" <td>102.0</td>\n",
" <td>2.20</td>\n",
" <td>2.13</td>\n",
" <td>0.43</td>\n",
" <td>1.71</td>\n",
" <td>2.08</td>\n",
" <td>0.92</td>\n",
" <td>3.12</td>\n",
" <td>365.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>154</th>\n",
" <td>12.58</td>\n",
" <td>1.29</td>\n",
" <td>2.10</td>\n",
" <td>20.0</td>\n",
" <td>103.0</td>\n",
" <td>1.48</td>\n",
" <td>0.58</td>\n",
" <td>0.53</td>\n",
" <td>1.40</td>\n",
" <td>7.60</td>\n",
" <td>0.58</td>\n",
" <td>1.55</td>\n",
" <td>640.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>51</th>\n",
" <td>13.83</td>\n",
" <td>1.65</td>\n",
" <td>2.60</td>\n",
" <td>17.2</td>\n",
" <td>94.0</td>\n",
" <td>2.45</td>\n",
" <td>2.99</td>\n",
" <td>0.22</td>\n",
" <td>2.29</td>\n",
" <td>5.60</td>\n",
" <td>1.24</td>\n",
" <td>3.37</td>\n",
" <td>1265.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" alcohol malic_acid ash alcalinity_of_ash magnesium total_phenols \\\n",
"2 13.16 2.36 2.67 18.6 101.0 2.80 \n",
"100 12.08 2.08 1.70 17.5 97.0 2.23 \n",
"122 12.42 4.43 2.73 26.5 102.0 2.20 \n",
"154 12.58 1.29 2.10 20.0 103.0 1.48 \n",
"51 13.83 1.65 2.60 17.2 94.0 2.45 \n",
"\n",
" flavanoids nonflavanoid_phenols proanthocyanins color_intensity hue \\\n",
"2 3.24 0.30 2.81 5.68 1.03 \n",
"100 2.17 0.26 1.40 3.30 1.27 \n",
"122 2.13 0.43 1.71 2.08 0.92 \n",
"154 0.58 0.53 1.40 7.60 0.58 \n",
"51 2.99 0.22 2.29 5.60 1.24 \n",
"\n",
" od280/od315_of_diluted_wines proline \n",
"2 3.17 1185.0 \n",
"100 2.96 710.0 \n",
"122 3.12 365.0 \n",
"154 1.55 640.0 \n",
"51 3.37 1265.0 "
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train.head()"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2 0\n",
"100 1\n",
"122 1\n",
"154 2\n",
"51 0\n",
"Name: target, dtype: int64"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_train.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's start simple, with a linear SVM classifier. It will automatically use the One-vs-All (also called One-vs-the-Rest, OvR) strategy, so there's nothing special we need to do to handle multiple classes. Easy, right?"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/ageron/miniconda3/envs/homl3/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"LinearSVC(random_state=42)"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lin_clf = LinearSVC(random_state=42)\n",
"lin_clf.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
2017-12-19 22:40:17 +01:00
"metadata": {},
2017-06-01 09:23:37 +02:00
"source": [
"Oh no! It failed to converge. Can you guess why? Do you think we must just increase the number of training iterations? Let's see:"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/ageron/miniconda3/envs/homl3/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"LinearSVC(max_iter=1000000, random_state=42)"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lin_clf = LinearSVC(max_iter=1_000_000, random_state=42)\n",
"lin_clf.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Even with one million iterations, it still did not converge. There must be another problem.\n",
"\n",
"Let's still evaluate this model with `cross_val_score`; it will serve as a baseline:"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/ageron/miniconda3/envs/homl3/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" warnings.warn(\n",
"/Users/ageron/miniconda3/envs/homl3/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" warnings.warn(\n",
"/Users/ageron/miniconda3/envs/homl3/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" warnings.warn(\n",
"/Users/ageron/miniconda3/envs/homl3/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" warnings.warn(\n",
"/Users/ageron/miniconda3/envs/homl3/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"0.90997150997151"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
2017-06-01 09:23:37 +02:00
"source": [
"from sklearn.model_selection import cross_val_score\n",
"\n",
"cross_val_score(lin_clf, X_train, y_train).mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Well, 91% accuracy on this dataset is not great. So did you guess what the problem is?\n",
"\n",
"That's right, we forgot to scale the features! Always remember to scale the features when using SVMs:"
2017-06-01 09:23:37 +02:00
]
},
{
"cell_type": "code",
"execution_count": 45,
2017-12-19 22:40:17 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(steps=[('standardscaler', StandardScaler()),\n",
" ('linearsvc', LinearSVC(random_state=42))])"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
2017-06-01 09:23:37 +02:00
"source": [
"lin_clf = make_pipeline(StandardScaler(),\n",
" LinearSVC(random_state=42))\n",
"lin_clf.fit(X_train, y_train)"
2017-06-01 09:23:37 +02:00
]
},
{
"cell_type": "markdown",
2017-12-19 22:40:17 +01:00
"metadata": {},
2017-06-01 09:23:37 +02:00
"source": [
"Now it converges without any problem. Let's measure its performance:"
2017-06-01 09:23:37 +02:00
]
},
{
"cell_type": "code",
"execution_count": 46,
2017-12-19 22:40:17 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"text/plain": [
"0.9774928774928775"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
2017-06-01 09:23:37 +02:00
"source": [
"from sklearn.model_selection import cross_val_score\n",
2017-06-01 09:23:37 +02:00
"\n",
"cross_val_score(lin_clf, X_train, y_train).mean()"
2017-06-01 09:23:37 +02:00
]
},
{
"cell_type": "markdown",
2017-12-19 22:40:17 +01:00
"metadata": {},
2017-06-01 09:23:37 +02:00
"source": [
"Nice! We get 97.7% accuracy, which is much better."
2017-06-01 09:23:37 +02:00
]
},
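{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see why scaling matters here, it can help to compare the spread of the raw features with their spread after standardization. The cell below is only a minimal sketch: it assumes that the data used in this exercise is the Wine dataset returned by `load_wine()`, and the names `X_wine`, `raw_std` and `scaled_std` are introduced just for this illustration. Features with very different scales are a plausible reason why the unscaled `LinearSVC` struggled to converge."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import load_wine\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# assumption: this is the same Wine data as in this exercise\n",
"X_wine, _ = load_wine(return_X_y=True)\n",
"\n",
"# per-feature standard deviation before and after standardization\n",
"raw_std = X_wine.std(axis=0)\n",
"scaled_std = StandardScaler().fit_transform(X_wine).std(axis=0)\n",
"\n",
"print(\"Largest / smallest raw std:\", raw_std.max() / raw_std.min())\n",
"print(\"Scaled std range:\", scaled_std.min(), scaled_std.max())"
]
},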
{
"cell_type": "markdown",
2017-12-19 22:40:17 +01:00
"metadata": {},
2017-06-01 09:23:37 +02:00
"source": [
"Let's see if a kernelized SVM will do better. We will use a default `SVC` for now:"
2017-06-01 09:23:37 +02:00
]
},
{
"cell_type": "code",
"execution_count": 47,
2017-12-19 22:40:17 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"text/plain": [
"0.9698005698005698"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
2017-06-01 09:23:37 +02:00
"source": [
"svm_clf = make_pipeline(StandardScaler(), SVC(random_state=42))\n",
"cross_val_score(svm_clf, X_train, y_train).mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That's not better, but perhaps we need to do a bit of hyperparameter tuning:"
2017-06-01 09:23:37 +02:00
]
},
{
"cell_type": "code",
"execution_count": 48,
2017-12-19 22:40:17 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(steps=[('standardscaler', StandardScaler()),\n",
" ('svc',\n",
" SVC(C=9.925589984899778, gamma=0.011986281799901176,\n",
" random_state=42))])"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
2017-06-01 09:23:37 +02:00
"source": [
"from sklearn.model_selection import RandomizedSearchCV\n",
"from scipy.stats import reciprocal, uniform\n",
"\n",
"param_distrib = {\n",
" \"svc__gamma\": reciprocal(0.001, 0.1),\n",
" \"svc__C\": uniform(1, 10)\n",
"}\n",
"rnd_search_cv = RandomizedSearchCV(svm_clf, param_distrib, n_iter=100, cv=5,\n",
" random_state=42)\n",
"rnd_search_cv.fit(X_train, y_train)\n",
"rnd_search_cv.best_estimator_"
2017-06-01 09:23:37 +02:00
]
},
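{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two distributions used in the search above behave quite differently, and a quick look at samples drawn from them may help. This is just an illustrative sketch, and the names `gamma_samples` and `C_samples` are made up for it. `reciprocal(a, b)` is log-uniform between `a` and `b`, so it spends as many samples between 0.001 and 0.01 as between 0.01 and 0.1. `uniform(loc, scale)` samples from `loc` to `loc + scale`, so `uniform(1, 10)` covers the range 1 to 11, not 1 to 10."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy.stats import reciprocal, uniform\n",
"\n",
"gamma_samples = reciprocal(0.001, 0.1).rvs(10_000, random_state=42)\n",
"C_samples = uniform(1, 10).rvs(10_000, random_state=42)\n",
"\n",
"# log-uniform: about half of the samples fall below the geometric mean (0.01)\n",
"print(\"gamma range:\", gamma_samples.min(), gamma_samples.max())\n",
"print(\"fraction of gamma samples below 0.01:\", np.mean(gamma_samples < 0.01))\n",
"\n",
"# uniform(loc=1, scale=10) covers the interval from 1 to 11\n",
"print(\"C range:\", C_samples.min(), C_samples.max())"
]
},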
{
"cell_type": "code",
"execution_count": 49,
2017-12-19 22:40:17 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"text/plain": [
"0.9925925925925926"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
2017-06-01 09:23:37 +02:00
"source": [
"rnd_search_cv.best_score_"
2017-06-01 09:23:37 +02:00
]
},
{
"cell_type": "markdown",
2017-12-19 22:40:17 +01:00
"metadata": {},
2017-06-01 09:23:37 +02:00
"source": [
"Ah, this looks excellent! Let's select this model. Now we can test it on the test set:"
2017-06-01 09:23:37 +02:00
]
},
{
"cell_type": "code",
"execution_count": 50,
2017-12-19 22:40:17 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"text/plain": [
"0.9777777777777777"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
2017-06-01 09:23:37 +02:00
"source": [
"rnd_search_cv.score(X_test, y_test)"
2017-06-01 09:23:37 +02:00
]
},
{
"cell_type": "markdown",
2017-12-19 22:40:17 +01:00
"metadata": {},
2017-06-01 09:23:37 +02:00
"source": [
"This tuned kernelized SVM performs better than the `LinearSVC` model, but its score on the test set (about 97.8%) is lower than the one we measured using cross-validation (about 99.3%). This is quite common: since we did so much hyperparameter tuning, we ended up slightly overfitting the cross-validation folds. It's tempting to tweak the hyperparameters a bit more until we get a better result on the test set, but this would probably not help, as we would just start overfitting the test set. Anyway, this score is not bad at all, so let's stop here."
2017-06-01 09:23:37 +02:00
]
},
{
"cell_type": "markdown",
2017-12-19 22:40:17 +01:00
"metadata": {},
2017-06-01 09:23:37 +02:00
"source": [
"## 11."
2017-06-01 09:23:37 +02:00
]
},
{
"cell_type": "markdown",
2017-12-19 22:40:17 +01:00
"metadata": {},
2017-06-01 09:23:37 +02:00
"source": [
"_Exercise: Train and fine-tune an SVM regressor on the California housing dataset. You can use the original dataset rather than the tweaked version we used in Chapter 2. The original dataset can be fetched using `sklearn.datasets.fetch_california_housing()`. The targets represent hundreds of thousands of dollars. Since there are over 20,000 instances, SVMs can be slow, so for hyperparameter tuning you should use far fewer instances (e.g., 2,000) to test many more hyperparameter combinations. What is your best model's RMSE?_"
2017-06-01 09:23:37 +02:00
]
},
{
"cell_type": "markdown",
2017-12-19 22:40:17 +01:00
"metadata": {},
2017-06-01 09:23:37 +02:00
"source": [
"Let's load the dataset:"
2017-06-01 09:23:37 +02:00
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
2017-06-01 09:23:37 +02:00
"outputs": [],
"source": [
"from sklearn.datasets import fetch_california_housing\n",
"\n",
"housing = fetch_california_housing()\n",
"X = housing.data\n",
"y = housing.target"
2017-06-01 09:23:37 +02:00
]
},
{
"cell_type": "markdown",
2017-12-19 22:40:17 +01:00
"metadata": {},
2017-06-01 09:23:37 +02:00
"source": [
"Split it into a training set and a test set:"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
2017-06-01 09:23:37 +02:00
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,\n",
" random_state=42)"
2017-06-01 09:23:37 +02:00
]
},
{
"cell_type": "markdown",
2017-12-19 22:40:17 +01:00
"metadata": {},
2017-06-01 09:23:37 +02:00
"source": [
"Don't forget to scale the data. We will do this with a `StandardScaler` inside each pipeline below."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's train a simple `LinearSVR` first:"
]
},
2017-06-01 09:23:37 +02:00
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/ageron/miniconda3/envs/homl3/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"Pipeline(steps=[('standardscaler', StandardScaler()),\n",
" ('linearsvr', LinearSVR(random_state=42))])"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
2017-06-01 09:23:37 +02:00
"source": [
"from sklearn.svm import LinearSVR\n",
2017-06-01 09:23:37 +02:00
"\n",
"lin_svr = make_pipeline(StandardScaler(), LinearSVR(random_state=42))\n",
"lin_svr.fit(X_train, y_train)"
2017-06-01 09:23:37 +02:00
]
},
{
"cell_type": "markdown",
2017-12-19 22:40:17 +01:00
"metadata": {},
2017-06-01 09:23:37 +02:00
"source": [
"It did not converge, so let's increase `max_iter`:"
2017-06-01 09:23:37 +02:00
]
},
{
"cell_type": "code",
"execution_count": 54,
2017-12-19 22:40:17 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(steps=[('standardscaler', StandardScaler()),\n",
" ('linearsvr', LinearSVR(max_iter=5000, random_state=42))])"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
2017-06-01 09:23:37 +02:00
"source": [
"lin_svr = make_pipeline(StandardScaler(),\n",
" LinearSVR(max_iter=5000, random_state=42))\n",
"lin_svr.fit(X_train, y_train)"
2017-06-01 09:23:37 +02:00
]
},
{
"cell_type": "markdown",
2017-12-19 22:40:17 +01:00
"metadata": {},
2017-06-01 09:23:37 +02:00
"source": [
"Let's see how it performs on the training set:"
]
},
{
"cell_type": "code",
"execution_count": 55,
2017-12-19 22:40:17 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"text/plain": [
"0.9595484665813285"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
2017-06-01 09:23:37 +02:00
"source": [
"from sklearn.metrics import mean_squared_error\n",
"\n",
"y_pred = lin_svr.predict(X_train)\n",
2017-06-01 09:23:37 +02:00
"mse = mean_squared_error(y_train, y_pred)\n",
"mse"
]
},
{
"cell_type": "markdown",
2017-12-19 22:40:17 +01:00
"metadata": {},
2017-06-01 09:23:37 +02:00
"source": [
"Let's look at the RMSE:"
]
},
{
"cell_type": "code",
"execution_count": 56,
2017-12-19 22:40:17 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"text/plain": [
"0.979565447829459"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
2017-06-01 09:23:37 +02:00
"source": [
"np.sqrt(mse)"
]
},
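{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a reminder, the RMSE is simply the square root of the MSE, so it is expressed in the same unit as the targets. The toy example below is a hedged sketch with made-up values (`y_true_demo` and `y_pred_demo` are not taken from the dataset). It converts the RMSE to dollars using the fact, stated in the exercise, that the targets are in hundreds of thousands of dollars."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# made-up targets and predictions, in hundreds of thousands of dollars\n",
"y_true_demo = np.array([1.0, 2.0, 3.0])\n",
"y_pred_demo = np.array([1.5, 2.0, 2.0])\n",
"\n",
"mse_demo = np.mean((y_true_demo - y_pred_demo) ** 2)  # mean of squared errors\n",
"rmse_demo = np.sqrt(mse_demo)                         # back to the target's unit\n",
"rmse_dollars = rmse_demo * 100_000                    # convert to dollars\n",
"\n",
"print(mse_demo, rmse_demo, rmse_dollars)"
]
},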
{
"cell_type": "markdown",
2017-12-19 22:40:17 +01:00
"metadata": {},
2017-06-01 09:23:37 +02:00
"source": [
"In this dataset, the targets represent hundreds of thousands of dollars. The RMSE gives a rough idea of the kind of error you should expect (with a higher weight for large errors), so with this model we can expect errors close to $98,000! Not great. Let's see if we can do better with an RBF kernel. We will use randomized search with cross-validation to find appropriate hyperparameter values for `C` and `gamma`:"
2017-06-01 09:23:37 +02:00
]
},
{
"cell_type": "code",
"execution_count": 57,
2017-12-19 22:40:17 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"text/plain": [
"RandomizedSearchCV(cv=3,\n",
" estimator=Pipeline(steps=[('standardscaler',\n",
" StandardScaler()),\n",
" ('svr', SVR())]),\n",
" n_iter=100,\n",
" param_distributions={'svr__C': <scipy.stats._distn_infrastructure.rv_frozen object at 0x7ff030704ee0>,\n",
" 'svr__gamma': <scipy.stats._distn_infrastructure.rv_frozen object at 0x7ff030704fd0>},\n",
" random_state=42)"
]
},
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
],
2017-06-01 09:23:37 +02:00
"source": [
"from sklearn.svm import SVR\n",
"from sklearn.model_selection import RandomizedSearchCV\n",
"from scipy.stats import reciprocal, uniform\n",
"\n",
"# this is a regressor, so avoid reusing the classifier's variable name\n",
"svm_reg = make_pipeline(StandardScaler(), SVR())\n",
"\n",
"param_distrib = {\n",
"    \"svr__gamma\": reciprocal(0.001, 0.1),\n",
"    \"svr__C\": uniform(1, 10)\n",
"}\n",
"rnd_search_cv = RandomizedSearchCV(svm_reg, param_distrib,\n",
" n_iter=100, cv=3, random_state=42)\n",
"rnd_search_cv.fit(X_train[:2000], y_train[:2000])"
2017-06-01 09:23:37 +02:00
]
},
{
"cell_type": "code",
"execution_count": 58,
2017-12-19 22:40:17 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(steps=[('standardscaler', StandardScaler()),\n",
" ('svr', SVR(C=4.63629602379294, gamma=0.08781408196485974))])"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
2017-06-01 09:23:37 +02:00
"source": [
"rnd_search_cv.best_estimator_"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"text/plain": [
"array([0.58835648, 0.57468589, 0.58085278, 0.57109886, 0.59853029])"
]
},
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"-cross_val_score(rnd_search_cv.best_estimator_, X_train, y_train,\n",
" scoring=\"neg_root_mean_squared_error\")"
]
},
2017-06-01 09:23:37 +02:00
{
"cell_type": "markdown",
2017-12-19 22:40:17 +01:00
"metadata": {},
2017-06-01 09:23:37 +02:00
"source": [
"This looks much better than the linear model. Let's select this model and evaluate it on the test set:"
2017-06-01 09:23:37 +02:00
]
},
{
"cell_type": "code",
"execution_count": 60,
2017-12-19 22:40:17 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"text/plain": [
"0.5854732265172222"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
2017-06-01 09:23:37 +02:00
"source": [
"y_pred = rnd_search_cv.best_estimator_.predict(X_test)\n",
"# np.sqrt works across Scikit-Learn versions; the `squared` argument was removed in 1.6\n",
"rmse = np.sqrt(mean_squared_error(y_test, y_pred))\n",
"rmse"
2017-06-01 09:23:37 +02:00
]
},
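{
"cell_type": "markdown",
"metadata": {},
"source": [
"A note on the minus sign in front of `cross_val_score` earlier: Scikit-Learn scorers follow a \"greater is better\" convention, so error metrics are exposed as negated scores such as `\"neg_root_mean_squared_error\"`. The sketch below illustrates this on small synthetic data. The names `X_demo`, `y_demo`, `scores` and `rmse_per_fold` are invented for this example, and `LinearRegression` is used only to keep it fast."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.model_selection import cross_val_score\n",
"\n",
"# small synthetic regression problem (not the housing data)\n",
"rng = np.random.default_rng(42)\n",
"X_demo = rng.normal(size=(60, 2))\n",
"y_demo = X_demo[:, 0] + rng.normal(scale=0.1, size=60)\n",
"\n",
"# the scorer returns the negated RMSE for each fold\n",
"scores = cross_val_score(LinearRegression(), X_demo, y_demo, cv=3,\n",
"                         scoring=\"neg_root_mean_squared_error\")\n",
"rmse_per_fold = -scores  # flip the sign to get the usual RMSE\n",
"\n",
"print(scores)\n",
"print(rmse_per_fold)"
]
},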
{
"cell_type": "markdown",
2017-12-19 22:40:17 +01:00
"metadata": {},
2017-06-01 09:23:37 +02:00
"source": [
"So SVMs worked very well on the Wine dataset, but not so much on the California Housing dataset. In Chapter 2, we found that Random Forests worked better for that dataset."
2017-06-01 09:23:37 +02:00
]
},
{
"cell_type": "markdown",
2017-12-19 22:40:17 +01:00
"metadata": {},
2017-06-01 09:23:37 +02:00
"source": [
"And that's all for today!"
2017-06-01 09:23:37 +02:00
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
2016-09-27 23:31:21 +02:00
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
2016-09-27 23:31:21 +02:00
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.10"
2016-09-27 23:31:21 +02:00
},
"nav_menu": {},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
2020-04-06 09:13:12 +02:00
"nbformat_minor": 4
2016-09-27 23:31:21 +02:00
}