{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Support Vector Machines**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_This notebook is an extra chapter on Support Vector Machines. It also includes exercises and their solutions at the end._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table align=\"left\">\n",
" <td>\n",
" <a href=\"https://colab.research.google.com/github/ageron/handson-ml3/blob/main/05_support_vector_machines.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/ageron/handson-ml3/blob/main/05_support_vector_machines.ipynb\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" /></a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This project requires Python 3.8 or above:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"assert sys.version_info >= (3, 8)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It also requires Scikit-Learn ≥ 1.0.1:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from packaging import version\n",
"import sklearn\n",
"\n",
"assert version.parse(sklearn.__version__) >= version.parse(\"1.0.1\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we did in previous chapters, let's define the default font sizes to make the figures prettier:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"plt.rc('font', size=14)\n",
"plt.rc('axes', labelsize=14, titlesize=14)\n",
"plt.rc('legend', fontsize=14)\n",
"plt.rc('xtick', labelsize=10)\n",
"plt.rc('ytick', labelsize=10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And let's create the `images/svm` folder (if it doesn't already exist), and define the `save_fig()` function which is used throughout this notebook to save the figures in high-res for the book:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"IMAGES_PATH = Path() / \"images\" / \"svm\"\n",
"IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
" if tight_layout:\n",
" plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Linear SVM Classification"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The book starts with a few figures before the first code example, so the next three cells generate and save these figures. You can skip them if you want."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 720x194.4 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code - this cell generates and saves Figure 5-1\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from sklearn.svm import SVC\n",
"from sklearn import datasets\n",
"\n",
"iris = datasets.load_iris(as_frame=True)\n",
"X = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
"y = iris.target\n",
"\n",
"setosa_or_versicolor = (y == 0) | (y == 1)\n",
"X = X[setosa_or_versicolor]\n",
"y = y[setosa_or_versicolor]\n",
"\n",
"# SVM Classifier model\n",
"svm_clf = SVC(kernel=\"linear\", C=float(\"inf\"))\n",
"svm_clf.fit(X, y)\n",
"\n",
"# Bad models\n",
"x0 = np.linspace(0, 5.5, 200)\n",
"pred_1 = 5 * x0 - 20\n",
"pred_2 = x0 - 1.8\n",
"pred_3 = 0.1 * x0 + 0.5\n",
"\n",
"def plot_svc_decision_boundary(svm_clf, xmin, xmax):\n",
" w = svm_clf.coef_[0]\n",
" b = svm_clf.intercept_[0]\n",
"\n",
" # At the decision boundary, w0*x0 + w1*x1 + b = 0\n",
" # => x1 = -w0/w1 * x0 - b/w1\n",
" x0 = np.linspace(xmin, xmax, 200)\n",
" decision_boundary = -w[0] / w[1] * x0 - b / w[1]\n",
"\n",
" margin = 1/w[1]\n",
" gutter_up = decision_boundary + margin\n",
" gutter_down = decision_boundary - margin\n",
" svs = svm_clf.support_vectors_\n",
"\n",
" plt.plot(x0, decision_boundary, \"k-\", linewidth=2, zorder=-2)\n",
" plt.plot(x0, gutter_up, \"k--\", linewidth=2, zorder=-2)\n",
" plt.plot(x0, gutter_down, \"k--\", linewidth=2, zorder=-2)\n",
" plt.scatter(svs[:, 0], svs[:, 1], s=180, facecolors='#AAA',\n",
" zorder=-1)\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(10, 2.7), sharey=True)\n",
"\n",
"plt.sca(axes[0])\n",
"plt.plot(x0, pred_1, \"g--\", linewidth=2)\n",
"plt.plot(x0, pred_2, \"m-\", linewidth=2)\n",
"plt.plot(x0, pred_3, \"r-\", linewidth=2)\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"bs\", label=\"Iris versicolor\")\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"yo\", label=\"Iris setosa\")\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.legend(loc=\"upper left\")\n",
"plt.axis([0, 5.5, 0, 2])\n",
"plt.gca().set_aspect(\"equal\")\n",
"plt.grid()\n",
"\n",
"plt.sca(axes[1])\n",
"plot_svc_decision_boundary(svm_clf, 0, 5.5)\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"bs\")\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"yo\")\n",
"plt.xlabel(\"Petal length\")\n",
"plt.axis([0, 5.5, 0, 2])\n",
"plt.gca().set_aspect(\"equal\")\n",
"plt.grid()\n",
"\n",
"save_fig(\"large_margin_classification_plot\")\n",
"plt.show()"
]
},
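{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `plot_svc_decision_boundary()` function above draws the decision boundary where $w_0 x_0 + w_1 x_1 + b = 0$ and the dashed gutters where this expression equals ±1, which is why the gutters are offset vertically by `1/w[1]`. The next cell is an illustrative sanity check added to this notebook (it is not from the book): it reloads the setosa/versicolor data under new variable names so that it can run on its own, and it assumes that a large finite `C` (`1e6`) is a good stand-in for `C=float(\"inf\")`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code - illustrative check (not in the book): the boundary and the gutters\n",
"import numpy as np\n",
"from sklearn import datasets\n",
"from sklearn.svm import SVC\n",
"\n",
"iris_check = datasets.load_iris()\n",
"X_check = iris_check.data[:, 2:4]  # petal length, petal width\n",
"y_check = iris_check.target\n",
"mask = (y_check == 0) | (y_check == 1)  # setosa or versicolor\n",
"X_check, y_check = X_check[mask], y_check[mask]\n",
"\n",
"# a large finite C approximates the hard margin classifier used above\n",
"clf = SVC(kernel=\"linear\", C=1e6).fit(X_check, y_check)\n",
"w, b = clf.coef_[0], clf.intercept_[0]\n",
"\n",
"x0 = np.linspace(0, 5.5, 5)\n",
"x1_boundary = -w[0] / w[1] * x0 - b / w[1]  # same formula as in plot_svc_decision_boundary()\n",
"x1_gutter_up = x1_boundary + 1 / w[1]\n",
"\n",
"# the decision function is 0 on the boundary and +1 on the upper gutter\n",
"on_boundary = clf.decision_function(np.c_[x0, x1_boundary])\n",
"on_gutter = clf.decision_function(np.c_[x0, x1_gutter_up])\n",
"assert np.allclose(on_boundary, 0, atol=1e-8)\n",
"assert np.allclose(on_gutter, 1, atol=1e-8)"
]
},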
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 648x194.4 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code - this cell generates and saves Figure 5-2\n",
"\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"Xs = np.array([[1, 50], [5, 20], [3, 80], [5, 60]]).astype(np.float64)\n",
"ys = np.array([0, 0, 1, 1])\n",
"svm_clf = SVC(kernel=\"linear\", C=100).fit(Xs, ys)\n",
"\n",
"scaler = StandardScaler()\n",
"X_scaled = scaler.fit_transform(Xs)\n",
"svm_clf_scaled = SVC(kernel=\"linear\", C=100).fit(X_scaled, ys)\n",
"\n",
"plt.figure(figsize=(9, 2.7))\n",
"plt.subplot(121)\n",
"plt.plot(Xs[:, 0][ys==1], Xs[:, 1][ys==1], \"bo\")\n",
"plt.plot(Xs[:, 0][ys==0], Xs[:, 1][ys==0], \"ms\")\n",
"plot_svc_decision_boundary(svm_clf, 0, 6)\n",
"plt.xlabel(\"$x_0$\")\n",
"plt.ylabel(\"$x_1$    \", rotation=0)\n",
"plt.title(\"Unscaled\")\n",
"plt.axis([0, 6, 0, 90])\n",
"plt.grid()\n",
"\n",
"plt.subplot(122)\n",
"plt.plot(X_scaled[:, 0][ys==1], X_scaled[:, 1][ys==1], \"bo\")\n",
"plt.plot(X_scaled[:, 0][ys==0], X_scaled[:, 1][ys==0], \"ms\")\n",
"plot_svc_decision_boundary(svm_clf_scaled, -2, 2)\n",
"plt.xlabel(\"$x'_0$\")\n",
"plt.ylabel(\"$x'_1$ \", rotation=0)\n",
"plt.title(\"Scaled\")\n",
"plt.axis([-2, 2, -2, 2])\n",
"plt.grid()\n",
"\n",
"save_fig(\"sensitivity_to_feature_scales_plot\")\n",
"plt.show()"
]
},
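{
"cell_type": "markdown",
"metadata": {},
"source": [
"Another way to look at the figure above: with a (nearly) hard margin, every support vector sits exactly on a gutter, so its decision score is ±1, and the width of the street is 2 / ‖w‖ in whatever units the features are expressed in. This illustrative cell (not from the book) refits both models on the same four points and prints these values; it assumes that `C=100` is large enough to behave like a hard margin on this tiny, linearly separable dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code - illustrative check (not in the book)\n",
"import numpy as np\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.svm import SVC\n",
"\n",
"Xs = np.array([[1, 50], [5, 20], [3, 80], [5, 60]], dtype=np.float64)\n",
"ys = np.array([0, 0, 1, 1])\n",
"clf_unscaled = SVC(kernel=\"linear\", C=100).fit(Xs, ys)\n",
"clf_scaled = SVC(kernel=\"linear\", C=100).fit(StandardScaler().fit_transform(Xs), ys)\n",
"\n",
"for name, clf in [(\"unscaled\", clf_unscaled), (\"scaled\", clf_scaled)]:\n",
"    sv_scores = clf.decision_function(clf.support_vectors_)\n",
"    width = 2 / np.linalg.norm(clf.coef_[0])  # margin width, in that feature space\n",
"    print(f\"{name}: support vector scores = {sv_scores.round(3)}, \"\n",
"          f\"margin width = {width:.3f}\")\n",
"    # with a (nearly) hard margin, every support vector lies on a gutter\n",
"    assert np.allclose(np.abs(sv_scores), 1, atol=1e-2)"
]
},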
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Soft Margin Classification"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 720x194.4 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code - this cell generates and saves Figure 5-3\n",
"\n",
"X_outliers = np.array([[3.4, 1.3], [3.2, 0.8]])\n",
"y_outliers = np.array([0, 0])\n",
"Xo1 = np.concatenate([X, X_outliers[:1]], axis=0)\n",
"yo1 = np.concatenate([y, y_outliers[:1]], axis=0)\n",
"Xo2 = np.concatenate([X, X_outliers[1:]], axis=0)\n",
"yo2 = np.concatenate([y, y_outliers[1:]], axis=0)\n",
"\n",
"svm_clf2 = SVC(kernel=\"linear\", C=10**9)\n",
"svm_clf2.fit(Xo2, yo2)\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(10, 2.7), sharey=True)\n",
"\n",
"plt.sca(axes[0])\n",
"plt.plot(Xo1[:, 0][yo1==1], Xo1[:, 1][yo1==1], \"bs\")\n",
"plt.plot(Xo1[:, 0][yo1==0], Xo1[:, 1][yo1==0], \"yo\")\n",
"plt.text(0.3, 1.0, \"Impossible!\", color=\"red\", fontsize=18)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.annotate(\n",
" \"Outlier\",\n",
" xy=(X_outliers[0][0], X_outliers[0][1]),\n",
" xytext=(2.5, 1.7),\n",
" ha=\"center\",\n",
" arrowprops=dict(facecolor='black', shrink=0.1),\n",
")\n",
"plt.axis([0, 5.5, 0, 2])\n",
"plt.grid()\n",
"\n",
"plt.sca(axes[1])\n",
"plt.plot(Xo2[:, 0][yo2==1], Xo2[:, 1][yo2==1], \"bs\")\n",
"plt.plot(Xo2[:, 0][yo2==0], Xo2[:, 1][yo2==0], \"yo\")\n",
"plot_svc_decision_boundary(svm_clf2, 0, 5.5)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.annotate(\n",
" \"Outlier\",\n",
" xy=(X_outliers[1][0], X_outliers[1][1]),\n",
" xytext=(3.2, 0.08),\n",
" ha=\"center\",\n",
" arrowprops=dict(facecolor='black', shrink=0.1),\n",
")\n",
"plt.axis([0, 5.5, 0, 2])\n",
"plt.grid()\n",
"\n",
"save_fig(\"sensitivity_to_outliers_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(steps=[('standardscaler', StandardScaler()),\n",
" ('linearsvc', LinearSVC(C=1, random_state=42))])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.svm import LinearSVC\n",
"\n",
"iris = load_iris(as_frame=True)\n",
"X = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
"y = (iris.target == 2) # Iris virginica\n",
"\n",
"svm_clf = make_pipeline(StandardScaler(),\n",
" LinearSVC(C=1, random_state=42))\n",
"svm_clf.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ True, False])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_new = [[5.5, 1.7], [5.0, 1.5]]\n",
"svm_clf.predict(X_new)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 0.66163411, -0.22036063])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"svm_clf.decision_function(X_new)"
]
},
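{
"cell_type": "markdown",
"metadata": {},
"source": [
"The scores returned by `decision_function()` are just $\\mathbf{w}^\\top \\mathbf{x}' + b$, where $\\mathbf{x}'$ is the input after standardization, and `predict()` returns the positive class wherever the score is positive. Here is a quick illustrative check (not from the book). It accesses the pipeline's steps by index (`pipeline[0]`, `pipeline[1]`), which assumes Scikit-Learn 0.21 or later; the version check at the top of this notebook already guarantees that."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code - illustrative check (not in the book)\n",
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.svm import LinearSVC\n",
"\n",
"iris_check = load_iris()\n",
"X_check = iris_check.data[:, 2:4]  # petal length, petal width\n",
"y_check = iris_check.target == 2  # Iris virginica\n",
"\n",
"pipeline = make_pipeline(StandardScaler(), LinearSVC(C=1, random_state=42))\n",
"pipeline.fit(X_check, y_check)\n",
"\n",
"X_new_check = np.array([[5.5, 1.7], [5.0, 1.5]])\n",
"scaler_step, svc_step = pipeline[0], pipeline[1]  # pipeline indexing\n",
"scores = scaler_step.transform(X_new_check) @ svc_step.coef_[0] + svc_step.intercept_[0]\n",
"\n",
"assert np.allclose(scores, pipeline.decision_function(X_new_check))\n",
"# predict() returns the positive class (True) wherever the score is positive\n",
"assert (pipeline.predict(X_new_check) == (scores > 0)).all()"
]
},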
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 720x194.4 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code - this cell generates and saves Figure 5-4\n",
"\n",
"scaler = StandardScaler()\n",
"svm_clf1 = LinearSVC(C=1, max_iter=10_000, random_state=42)\n",
"svm_clf2 = LinearSVC(C=100, max_iter=10_000, random_state=42)\n",
"\n",
"scaled_svm_clf1 = make_pipeline(scaler, svm_clf1)\n",
"scaled_svm_clf2 = make_pipeline(scaler, svm_clf2)\n",
"\n",
"scaled_svm_clf1.fit(X, y)\n",
"scaled_svm_clf2.fit(X, y)\n",
"\n",
"# Convert to unscaled parameters\n",
"b1 = svm_clf1.decision_function([-scaler.mean_ / scaler.scale_])\n",
"b2 = svm_clf2.decision_function([-scaler.mean_ / scaler.scale_])\n",
"w1 = svm_clf1.coef_[0] / scaler.scale_\n",
"w2 = svm_clf2.coef_[0] / scaler.scale_\n",
"svm_clf1.intercept_ = np.array([b1])\n",
"svm_clf2.intercept_ = np.array([b2])\n",
"svm_clf1.coef_ = np.array([w1])\n",
"svm_clf2.coef_ = np.array([w2])\n",
"\n",
"# Find support vectors (LinearSVC does not do this automatically)\n",
"t = y * 2 - 1\n",
"support_vectors_idx1 = (t * (X.dot(w1) + b1) < 1).ravel()\n",
"support_vectors_idx2 = (t * (X.dot(w2) + b2) < 1).ravel()\n",
"svm_clf1.support_vectors_ = X[support_vectors_idx1]\n",
"svm_clf2.support_vectors_ = X[support_vectors_idx2]\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(10, 2.7), sharey=True)\n",
"\n",
"plt.sca(axes[0])\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"g^\", label=\"Iris virginica\")\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"bs\", label=\"Iris versicolor\")\n",
"plot_svc_decision_boundary(svm_clf1, 4, 5.9)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.legend(loc=\"upper left\")\n",
"plt.title(f\"$C = {svm_clf1.C}$\")\n",
"plt.axis([4, 5.9, 0.8, 2.8])\n",
"plt.grid()\n",
"\n",
"plt.sca(axes[1])\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"g^\")\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"bs\")\n",
"plot_svc_decision_boundary(svm_clf2, 4, 5.99)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.title(f\"$C = {svm_clf2.C}$\")\n",
"plt.axis([4, 5.9, 0.8, 2.8])\n",
"plt.grid()\n",
"\n",
"save_fig(\"regularization_plot\")\n",
"plt.show()"
]
},
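{
"cell_type": "markdown",
"metadata": {},
"source": [
"To put a number on the figure above, we can compute the street width 2 / ‖w‖ (in the standardized feature space) for both values of `C`. This illustrative cell (not from the book) swaps in `SVC(kernel=\"linear\")` for `LinearSVC`: `SVC` minimizes the regular hinge loss without regularizing the bias term, so the margin width should not grow as `C` increases (up to solver tolerance), which makes the comparison cleaner."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code - illustrative check (not in the book)\n",
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.svm import SVC\n",
"\n",
"iris_check = load_iris()\n",
"X_check = iris_check.data[:, 2:4]  # petal length, petal width\n",
"y_check = iris_check.target == 2  # Iris virginica\n",
"\n",
"widths = {}\n",
"for C in (1, 100):\n",
"    clf = make_pipeline(StandardScaler(), SVC(kernel=\"linear\", C=C))\n",
"    clf.fit(X_check, y_check)\n",
"    widths[C] = 2 / np.linalg.norm(clf[-1].coef_[0])  # width in the scaled space\n",
"    print(f\"C={C}: margin width = {widths[C]:.3f}, \"\n",
"          f\"support vectors = {clf[-1].n_support_.sum()}\")\n",
"\n",
"assert widths[1] > widths[100]  # smaller C => wider margin"
]
},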
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Nonlinear SVM Classification"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 720x216 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 5-5\n",
"\n",
"X1D = np.linspace(-4, 4, 9).reshape(-1, 1)\n",
"X2D = np.c_[X1D, X1D**2]\n",
"y = np.array([0, 0, 1, 1, 1, 1, 1, 0, 0])\n",
"\n",
"plt.figure(figsize=(10, 3))\n",
"\n",
"plt.subplot(121)\n",
"plt.grid(True, which='both')\n",
"plt.axhline(y=0, color='k')\n",
"plt.plot(X1D[:, 0][y==0], np.zeros(4), \"bs\")\n",
"plt.plot(X1D[:, 0][y==1], np.zeros(5), \"g^\")\n",
"plt.gca().get_yaxis().set_ticks([])\n",
"plt.xlabel(r\"$x_1$\")\n",
"plt.axis([-4.5, 4.5, -0.2, 0.2])\n",
"\n",
"plt.subplot(122)\n",
"plt.grid(True, which='both')\n",
"plt.axhline(y=0, color='k')\n",
"plt.axvline(x=0, color='k')\n",
"plt.plot(X2D[:, 0][y==0], X2D[:, 1][y==0], \"bs\")\n",
"plt.plot(X2D[:, 0][y==1], X2D[:, 1][y==1], \"g^\")\n",
"plt.xlabel(r\"$x_1$\")\n",
"plt.ylabel(r\"$x_2$  \", rotation=0)\n",
"plt.gca().get_yaxis().set_ticks([0, 4, 8, 12, 16])\n",
"plt.plot([-4.5, 4.5], [6.5, 6.5], \"r--\", linewidth=3)\n",
"plt.axis([-4.5, 4.5, -1, 17])\n",
"\n",
"plt.subplots_adjust(right=1)\n",
"\n",
"save_fig(\"higher_dimensions_plot\", tight_layout=False)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(steps=[('polynomialfeatures', PolynomialFeatures(degree=3)),\n",
" ('standardscaler', StandardScaler()),\n",
" ('linearsvc',\n",
" LinearSVC(C=10, max_iter=10000, random_state=42))])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.datasets import make_moons\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"\n",
"X, y = make_moons(n_samples=100, noise=0.15, random_state=42)\n",
"\n",
"polynomial_svm_clf = make_pipeline(\n",
" PolynomialFeatures(degree=3),\n",
" StandardScaler(),\n",
" LinearSVC(C=10, max_iter=10_000, random_state=42)\n",
")\n",
"polynomial_svm_clf.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 5-6\n",
"\n",
"def plot_dataset(X, y, axes):\n",
" plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"bs\")\n",
" plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"g^\")\n",
" plt.axis(axes)\n",
" plt.grid(True, which='both')\n",
" plt.xlabel(r\"$x_1$\")\n",
" plt.ylabel(r\"$x_2$\", rotation=0)\n",
"\n",
"def plot_predictions(clf, axes):\n",
" x0s = np.linspace(axes[0], axes[1], 100)\n",
" x1s = np.linspace(axes[2], axes[3], 100)\n",
" x0, x1 = np.meshgrid(x0s, x1s)\n",
" X = np.c_[x0.ravel(), x1.ravel()]\n",
" y_pred = clf.predict(X).reshape(x0.shape)\n",
" y_decision = clf.decision_function(X).reshape(x0.shape)\n",
" plt.contourf(x0, x1, y_pred, cmap=plt.cm.brg, alpha=0.2)\n",
" plt.contourf(x0, x1, y_decision, cmap=plt.cm.brg, alpha=0.1)\n",
"\n",
"plot_predictions(polynomial_svm_clf, [-1.5, 2.5, -1, 1.5])\n",
"plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])\n",
"\n",
"save_fig(\"moons_polynomial_svc_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Polynomial Kernel"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(steps=[('standardscaler', StandardScaler()),\n",
" ('svc', SVC(C=5, coef0=1, kernel='poly'))])"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.svm import SVC\n",
"\n",
"poly_kernel_svm_clf = make_pipeline(StandardScaler(),\n",
" SVC(kernel=\"poly\", degree=3, coef0=1, C=5))\n",
"poly_kernel_svm_clf.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 756x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 5-7\n",
"\n",
"poly100_kernel_svm_clf = make_pipeline(\n",
" StandardScaler(),\n",
" SVC(kernel=\"poly\", degree=10, coef0=100, C=5)\n",
")\n",
"poly100_kernel_svm_clf.fit(X, y)\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(10.5, 4), sharey=True)\n",
"\n",
"plt.sca(axes[0])\n",
"plot_predictions(poly_kernel_svm_clf, [-1.5, 2.45, -1, 1.5])\n",
"plot_dataset(X, y, [-1.5, 2.4, -1, 1.5])\n",
"plt.title(r\"$d=3, r=1, C=5$\")\n",
"\n",
"plt.sca(axes[1])\n",
"plot_predictions(poly100_kernel_svm_clf, [-1.5, 2.45, -1, 1.5])\n",
"plot_dataset(X, y, [-1.5, 2.4, -1, 1.5])\n",
"plt.title(r\"$d=10, r=100, C=5$\")\n",
"plt.ylabel(\"\")\n",
"\n",
"save_fig(\"moons_kernelized_polynomial_svc_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Similarity Features"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 756x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 5-8\n",
"\n",
"def gaussian_rbf(x, landmark, gamma):\n",
" return np.exp(-gamma * np.linalg.norm(x - landmark, axis=1)**2)\n",
"\n",
"gamma = 0.3\n",
"\n",
"x1s = np.linspace(-4.5, 4.5, 200).reshape(-1, 1)\n",
"x2s = gaussian_rbf(x1s, -2, gamma)\n",
"x3s = gaussian_rbf(x1s, 1, gamma)\n",
"\n",
"XK = np.c_[gaussian_rbf(X1D, -2, gamma), gaussian_rbf(X1D, 1, gamma)]\n",
"yk = np.array([0, 0, 1, 1, 1, 1, 1, 0, 0])\n",
"\n",
"plt.figure(figsize=(10.5, 4))\n",
"\n",
"plt.subplot(121)\n",
"plt.grid(True, which='both')\n",
"plt.axhline(y=0, color='k')\n",
"plt.scatter(x=[-2, 1], y=[0, 0], s=150, alpha=0.5, c=\"red\")\n",
"plt.plot(X1D[:, 0][yk==0], np.zeros(4), \"bs\")\n",
"plt.plot(X1D[:, 0][yk==1], np.zeros(5), \"g^\")\n",
"plt.plot(x1s, x2s, \"g--\")\n",
"plt.plot(x1s, x3s, \"b:\")\n",
"plt.gca().get_yaxis().set_ticks([0, 0.25, 0.5, 0.75, 1])\n",
"plt.xlabel(r\"$x_1$\")\n",
"plt.ylabel(r\"Similarity\")\n",
"plt.annotate(\n",
" r'$\\mathbf{x}$',\n",
" xy=(X1D[3, 0], 0),\n",
" xytext=(-0.5, 0.20),\n",
" ha=\"center\",\n",
" arrowprops=dict(facecolor='black', shrink=0.1),\n",
" fontsize=16,\n",
")\n",
"plt.text(-2, 0.9, \"$x_2$\", ha=\"center\", fontsize=15)\n",
"plt.text(1, 0.9, \"$x_3$\", ha=\"center\", fontsize=15)\n",
"plt.axis([-4.5, 4.5, -0.1, 1.1])\n",
"\n",
"plt.subplot(122)\n",
"plt.grid(True, which='both')\n",
"plt.axhline(y=0, color='k')\n",
"plt.axvline(x=0, color='k')\n",
"plt.plot(XK[:, 0][yk==0], XK[:, 1][yk==0], \"bs\")\n",
"plt.plot(XK[:, 0][yk==1], XK[:, 1][yk==1], \"g^\")\n",
"plt.xlabel(r\"$x_2$\")\n",
"plt.ylabel(r\"$x_3$  \", rotation=0)\n",
"plt.annotate(\n",
" r'$\\phi\\left(\\mathbf{x}\\right)$',\n",
" xy=(XK[3, 0], XK[3, 1]),\n",
" xytext=(0.65, 0.50),\n",
" ha=\"center\",\n",
" arrowprops=dict(facecolor='black', shrink=0.1),\n",
" fontsize=16,\n",
")\n",
"plt.plot([-0.1, 1.1], [0.57, -0.1], \"r--\", linewidth=3)\n",
"plt.axis([-0.1, 1.1, -0.1, 1.1])\n",
" \n",
"plt.subplots_adjust(right=1)\n",
"\n",
"save_fig(\"kernel_method_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Gaussian RBF Kernel"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(steps=[('standardscaler', StandardScaler()),\n",
" ('svc', SVC(C=0.001, gamma=5))])"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rbf_kernel_svm_clf = make_pipeline(StandardScaler(),\n",
" SVC(kernel=\"rbf\", gamma=5, C=0.001))\n",
"rbf_kernel_svm_clf.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 756x504 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 5-9\n",
"\n",
"from sklearn.svm import SVC\n",
"\n",
"gamma1, gamma2 = 0.1, 5\n",
"C1, C2 = 0.001, 1000\n",
"hyperparams = (gamma1, C1), (gamma1, C2), (gamma2, C1), (gamma2, C2)\n",
"\n",
"svm_clfs = []\n",
"for gamma, C in hyperparams:\n",
" rbf_kernel_svm_clf = make_pipeline(\n",
" StandardScaler(),\n",
" SVC(kernel=\"rbf\", gamma=gamma, C=C)\n",
" )\n",
" rbf_kernel_svm_clf.fit(X, y)\n",
" svm_clfs.append(rbf_kernel_svm_clf)\n",
"\n",
"fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10.5, 7), sharex=True, sharey=True)\n",
"\n",
"for i, svm_clf in enumerate(svm_clfs):\n",
" plt.sca(axes[i // 2, i % 2])\n",
" plot_predictions(svm_clf, [-1.5, 2.45, -1, 1.5])\n",
" plot_dataset(X, y, [-1.5, 2.45, -1, 1.5])\n",
" gamma, C = hyperparams[i]\n",
" plt.title(fr\"$\\gamma = {gamma}, C = {C}$\")\n",
" if i in (0, 1):\n",
" plt.xlabel(\"\")\n",
" if i in (1, 3):\n",
" plt.ylabel(\"\")\n",
"\n",
"save_fig(\"moons_rbf_svc_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SVM Regression"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(steps=[('standardscaler', StandardScaler()),\n",
" ('linearsvr', LinearSVR(epsilon=0.5, random_state=42))])"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.svm import LinearSVR\n",
"\n",
"# extra code: these 3 lines generate a simple linear dataset\n",
"np.random.seed(42)\n",
"X = 2 * np.random.rand(50, 1)\n",
"y = 4 + 3 * X[:, 0] + np.random.randn(50)\n",
"\n",
"svm_reg = make_pipeline(StandardScaler(),\n",
" LinearSVR(epsilon=0.5, random_state=42))\n",
"svm_reg.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 648x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 5-10\n",
"\n",
"def find_support_vectors(svm_reg, X, y):\n",
" y_pred = svm_reg.predict(X)\n",
" epsilon = svm_reg[-1].epsilon\n",
" off_margin = np.abs(y - y_pred) >= epsilon\n",
" return np.argwhere(off_margin)\n",
"\n",
"def plot_svm_regression(svm_reg, X, y, axes):\n",
" x1s = np.linspace(axes[0], axes[1], 100).reshape(100, 1)\n",
" y_pred = svm_reg.predict(x1s)\n",
" epsilon = svm_reg[-1].epsilon\n",
" plt.plot(x1s, y_pred, \"k-\", linewidth=2, label=r\"$\\hat{y}$\", zorder=-2)\n",
" plt.plot(x1s, y_pred + epsilon, \"k--\", zorder=-2)\n",
" plt.plot(x1s, y_pred - epsilon, \"k--\", zorder=-2)\n",
" plt.scatter(X[svm_reg._support], y[svm_reg._support], s=180,\n",
" facecolors='#AAA', zorder=-1)\n",
" plt.plot(X, y, \"bo\")\n",
" plt.xlabel(r\"$x_1$\")\n",
" plt.legend(loc=\"upper left\")\n",
" plt.axis(axes)\n",
"\n",
"svm_reg2 = make_pipeline(StandardScaler(),\n",
" LinearSVR(epsilon=1.2, random_state=42))\n",
"svm_reg2.fit(X, y)\n",
"\n",
"svm_reg._support = find_support_vectors(svm_reg, X, y)\n",
"svm_reg2._support = find_support_vectors(svm_reg2, X, y)\n",
"\n",
"eps_x1 = 1\n",
"eps_y_pred = svm_reg2.predict([[eps_x1]])\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(9, 4), sharey=True)\n",
"plt.sca(axes[0])\n",
"plot_svm_regression(svm_reg, X, y, [0, 2, 3, 11])\n",
"plt.title(fr\"$\\epsilon = {svm_reg[-1].epsilon}$\")\n",
"plt.ylabel(r\"$y$\", rotation=0)\n",
"plt.grid()\n",
"plt.sca(axes[1])\n",
"plot_svm_regression(svm_reg2, X, y, [0, 2, 3, 11])\n",
"plt.title(fr\"$\\epsilon = {svm_reg2[-1].epsilon}$\")\n",
"plt.annotate(\n",
" '', xy=(eps_x1, eps_y_pred), xycoords='data',\n",
" xytext=(eps_x1, eps_y_pred - svm_reg2[-1].epsilon),\n",
" textcoords='data', arrowprops={'arrowstyle': '<->', 'linewidth': 1.5}\n",
" )\n",
"plt.text(0.90, 5.4, r\"$\\epsilon$\", fontsize=16)\n",
"plt.grid()\n",
"save_fig(\"svm_regression_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(steps=[('standardscaler', StandardScaler()),\n",
" ('svr', SVR(C=0.01, degree=2, kernel='poly'))])"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.svm import SVR\n",
"\n",
"# extra code: these 3 lines generate a simple quadratic dataset\n",
"np.random.seed(42)\n",
"X = 2 * np.random.rand(50, 1) - 1\n",
"y = 0.2 + 0.1 * X[:, 0] + 0.5 * X[:, 0] ** 2 + np.random.randn(50) / 10\n",
"\n",
"svm_poly_reg = make_pipeline(StandardScaler(),\n",
" SVR(kernel=\"poly\", degree=2, C=0.01, epsilon=0.1))\n",
"svm_poly_reg.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAnsAAAEQCAYAAADI77KTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAACUtklEQVR4nOydeXhM1xvHPycrsRMJIbHGHhS1xb4WpVpVRZWqpahS1aJUKfpri6qdqlYttbQUVUoRxL7FElViD0EkSGRf5vz+mGSaZZLMTGaSSZzP88zD3HvuOe+9c+837z3L+wopJQqFQqFQKBSK/IlNbhugUCgUCoVCobAcytlTKBQKhUKhyMcoZ0+hUCgUCoUiH6OcPYVCoVAoFIp8jHL2FAqFQqFQKPIxytlTKBQKhUKhyMcoZ0+hUCgUCoUiH6OcPYVCoVAoFIp8jHL2LIAQ4j0hREBu26FQKBTmQGmaQpG3Uc6eZagPnMtlG3IMIcQkIcQpIUS4EOKREOIPIUQdE+qpL4RYJ4S4K4SIEULcFEKsFULUNbO9I5PqjhFCnBFCtMzuMUKIVkKI7UKIe0IIKYQYZE6bzYWJ554nzk1hUeqjNM0qNc3Q59OQZ98UfchpjLVR6ZcW5exZhvqAnyUbEELYWbJ+I2kDLAGaA+2ABGCvEKKkoRUkPYCngVigD1ANGJi0e4y5DBVC9AHmA18CLwBHgV1CCI9sHlMY8E+yNdpc9poTU849Cas/N4XFqY/SNKvUNAx4Pg159rOhDzmGiTYq/QKQUqpPNj5oH+B9aG+ii0BjIBLomrS/HLAaCAWeApsB1zR11AB8UtTRHIgHWiftLw9I4E1gPxADDDakfkPat8A1KQwkAt0NLN8MrZh+mMH+kma07QSwIs22AOB/5joGiAAGmWDbILR/UKOAcOA4YJeb526uc1OfvPNRmqb3mlitpqWpV+/zacizbyZ9sGoNe571S/XsZQMhhCdwCu1bgxcwEdgEOAHnhBCVgLPAPaAF2rdFZ2BZijpqACeT6mkATAI2AnbAhaRi9ZP+nQDMAWoD27Kq35D205zPp0KIiCw+hnTrF0Hba/zEgLIAc4ETUsp5+nZKKR+bw1YhhAPQENiTpro9aP8YpcOUY0xBCNEd7RvrN0BNoClaAUvQUzZHzl3x/KE0LUOsUtMMwZBn3xz6YKiGmXqeSsOyhzV1m+dFFgE7pZTJXfLXhBA9gR5SyiAhxG5gpZTy0+QDhBAzgC0p6lgA/C2l/CTp+2UhxBtAKyllsrDUQ/vm21tKeS1FXb9kUf8yA9pPyTK0wp4Z97LYD9oH/hxwLKuCSX9cmgF9Dag3JabY6gzYAg/TbH8IdMigDlOOMYUaQCCwO8Ufgn8yKJtT5654/lCaph9r1TRDMOTZN4c+GKphpp6n0rBsoJw9ExFCuAOd0A5xpCQe7RuwR9L+lkKID1Lst0XbxZ1cR0f+e8tNJhY4n+J7fbQCnFIUM63fkPbTkvSApnvjNAYhxLdo37hbSCkTDTikQdK/p41pJ5u2yjTfhZ5t5jjGGFaindcTKoSIBJpKKf31GpLz5654DlCapp88omkGNZHmu75nPzv6YJCGmeE8lYaZgHL2TOcFtHM4zqfZ3gA4iFbMwtF2O6clLkXZBP4b2kimJto5M8nUQ/vGnZKs6jek/VQIIT4FPtW3LwVdpJS+GRw/D+0cnLZSyhtZ1JOMU9K/EQaWT27LFFtD0P5mZdKUcyH922J2jjGKpInp69EOT72Hdh7SzUzK59S5K54vlKalP97aNc0QDHn2s6UPxmhYNs5TaVg2UM6e6Ui0b5SOJAlN0jyDJsB3aN+GCwEPpJQZPfSJSXU4oZ0AjRCiIeANfJv0vRBQBe1DlJJM6xdC1Dag/bSYPIwghJiPVhTbSCn/NbA90M4NAmiNdl5P2nqdpJT63tqNtlVKGSeEOIO25+HXFLs6op3knQ5TjjGBV4HaUsrOBpbPkXNXPHcoTU
vdntVrmiEY8uybQR+M0TCTzlNpWDbJ7RUiefUDlEU752Q5WuHqBtxGK5g1gBLAI+B3tG/MVdDelIsBm6Q6yqBdrbYkaf9LwOWkOionlWmGVkALp2k/0/oNad+M12Ix2jfudknnlPwpbODxfwLBaFdyeQJVgTeAv9EOnZjT1j5o/5ANQdvbMB/tG3iFFGXeB/418pjCaHse6qMdUpqa9H8PA2x6K6n+gUBFoBbwLlAot889u+emPnnnozQtlS15SdOyfD4NfPazLJOJDVarYUq/kq5DbhuQlz9oJ+DeRrtC6wAwHe3bbLLwNUIbVuAp8Azt8MikNHW8AdxKOm4H2tVpwSn2v0eaP74p9mVavyHtm+k6yAw+01KUGZS0raKe4x2BT9AO/UQmXc8zwBdAAQvYOzLpmscmtdMqzf5pgDTymDYZXINVWV0DtD3sc4A7SfU/BH630D1ryrlneW7qkz8+StN07eQZTTP0+czq2TekTF7UMKVf2o9IuhgKK0AIIYC/gOtSypG5bY85EUJMB14H6kk94USeB9Q1UDxvKE3LXzyP55xfUHP2chEhRAu0QwNngVLAh2i7l9/JRbMsRVfg/edcINQ1UORrlKble57Hc84XmM3ZE0L8CLyMtrs+XQ7BpDe8+Whvlii0UazTTtB93igDfI02IvwjtMMmDaWUQblplCWQUr6Y2zbkNuoaKJ4DlKblY57Hc84vmG0YVwjRCu1EydUZOHtdgdFonb0mwHwpZROzNK5QKBQKhUKh0IvZ0qVJKQ+ReaDEV9A6glJKeRwoLoQoa672FQqFQqFQKBTpycncuOXQplJJ5m7SNoVCoVAoFAqFhcjJBRpCzza9Y8hCiGHAMIACBQo09PDwsKRdWaLRaLCxyUm/WNmh7FB2ZJerV6+GSClL51b71qZjYD2/jbLDumxQdlivHWbTMTPHv6kI+GewbznQN8X3K0DZrOqsVq2azG18fHxy2wQppbIjLcqO1Cg7UgOcllYQ30paiY5JaT2/jbLDumyQUtmRFmuxw1w6lpNu63bgbaGlKRAmpbyfg+0rFAqFQqFQPHeYM/TKerSRqp2FEHeBzwF7ACnlMmAn2pW419CGXsmPcZcUCoVCoVAorAqzOXtSyr5Z7JfAKHO1p1AoFAqFQqHImtyffahQKBQKhUKhSEVMTIzZ6lLOnkKhUCgUCoUVERQURIUKFcxWX57OjavRaLh79y6RkZEWbadYsWJcvnzZom1khb29Pba2trlqg0KhsAzh4eEEBwcTHx9v0XZyW8vs7e1xcXHJtfYVirzCtGnTePLkidnqy9POXkhICEIIqlevbtF4OM+ePaNIkSIWqz8rpJRER0cTGRlJeHg4RYsWzTVbFAqFeQkPD+fhw4eUK1eOggULok0jbhlyU8uSdezevXvqxVWhyITLly+zcuVK3n//fRYsWGCWOvP0MO7Tp09xdXW1isCHlkQIgZOTE+XKlSM4ODi3zVEoFGYkODiYcuXK4eTkZFFHL7dJqWOFChXKbXMUCqtl0qRJFCpUiClTppitzjztJSUmJmJvb5/bZuQYDg4OFh/mUSgUOUt8fDwFCxbMbTNyDEv3XioUeZm4uDhsbGyYOHEipUubLwFQnh7GBZ4r0XiezlWheJ54np7t5+lcFQpjcXBwYMuWLcmZxsxGnu7ZUygUCoVCocgPHD58mCtXrgDmfylSzl4u8+jRI6ZNm8ajR49y2xSFQqEwGaVlCoXpxMbG8vbbb9OvXz+z9+qBcvZynREjRnD69GlGjVLJRRQKRd5FaZlCYTqLFy/m5s2b/O9//7PIVAfl7OUiv/zyC46OjuzYsQN7e3s2bdqU2yYpFAqF0SgtUyhM5/Hjx8yYMYPOnTvTqVMni7SR5xdo5GX69etHv379AFi3bl0uW6NQKBSmobRMoTCdGTNmEB4ezjfffGOxNlTPnkKhUCgUCkUu4ejoyIgRI6hbt67F2lA9ewqFQqFQKBS5xFdffWWRRRkpUT
17uUD58uX59ttvU227ePEiBQoU4J9//sklqxQKhcI4lJYpFKZz/Phx9u3bB1g+/qTq2csFmjVrxqlTp1JtGzt2LEO
"text/plain": [
"<Figure size 648x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 5-11\n",
"\n",
"svm_poly_reg2 = make_pipeline(StandardScaler(),\n",
" SVR(kernel=\"poly\", degree=2, C=100))\n",
"svm_poly_reg2.fit(X, y)\n",
"\n",
"svm_poly_reg._support = find_support_vectors(svm_poly_reg, X, y)\n",
"svm_poly_reg2._support = find_support_vectors(svm_poly_reg2, X, y)\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(9, 4), sharey=True)\n",
"plt.sca(axes[0])\n",
"plot_svm_regression(svm_poly_reg, X, y, [-1, 1, 0, 1])\n",
"plt.title(f\"$degree={svm_poly_reg[-1].degree}, \"\n",
" f\"C={svm_poly_reg[-1].C}, \"\n",
" fr\"\\epsilon={svm_poly_reg[-1].epsilon}$\")\n",
"plt.ylabel(r\"$y$\", rotation=0)\n",
"plt.grid()\n",
"\n",
"plt.sca(axes[1])\n",
"plot_svm_regression(svm_poly_reg2, X, y, [-1, 1, 0, 1])\n",
"plt.title(f\"$degree={svm_poly_reg2[-1].degree}, \"\n",
" f\"C={svm_poly_reg2[-1].C}, \"\n",
" fr\"\\epsilon={svm_poly_reg2[-1].epsilon}$\")\n",
"plt.grid()\n",
"save_fig(\"svm_with_polynomial_kernel_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Under the hood"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAm4AAADWCAYAAABorg4iAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAA3iUlEQVR4nO3deXwURfrH8U8l5IYEAqKSVVl+sIiwnijiBYgHhyK3CEEuwVuUVVfFC1kU3fUGRBRcNSEcAiLrFdGAqCAiKooIyrIoIAgmXLmP+v1RgRBASCCZnk6+79drXtKVme4n40zl6aerq4y1FhEREREJfiFeByAiIiIiZaPETURERMQnlLiJiIiI+IQSNxERERGfUOImIiIi4hNK3ERERER8QombiIiIiE8ocRMRERHxCSVuUi0ZYy4yxrxljNlojLHGmIFexyQi1Zsx5iZjzDpjTI4x5ktjzIWHef7Dxf3Xvo/NgYpXvKHETaqrmsB3wHAg2+NYRKSaM8ZcDTwLPAqcAXwGvGuMOfEwL10NHL/P46+VGad4T4mbBIwxZokx5q59ticXnyEeV7wda4zZZYxpU9mxWGvfsdbeZ619Ayiq7OOJSPAJpj4JGAH821r7krV2lbX2VuBX4MbDvK7AWrt5n8fWyg9VvKTETQJpO1ALwBhTH+gJpAN1in8+APjJWruwLDszxtxnjNl9mMchLzWISLW2nSDok4wx4cBZQOp+P0oFzjvMYRsVD/lYZ4yZZoxpVJZYxb9qeB2AVCsZuEuUADcDc4DTgfjitpuAsQDGmLeAC4EPrbU9/2B/E4EZhznmxqOIV0SqtmDpk+oBocCW/dq3AJccYl+fAwOBH4D6wP3AZ8aY5tba3w8Th/iUEjcJpO1ALWNMJHADcBnwHFDHGHMJ7ix3WvFznwZewp3xHpS1Nh13diwiciS2E1x9kt1v2xykbd/jvVvqycYsAf5bHONTRxGHBDFdKpVA2nN2mwh8Z639BtiJ6xxvAV6w1uYCWGvTgF2H2pkulYrIUQqWPmkbUAgct197fQ6swv0ha+1uYCXQpKyvEf9RxU0CaTtuPMntwN+L23bgLk1cCgwr5/50qVREjsZ2gqBPstbmGWO+LD7mzH1+dCkwq6wHL64cngyklfU14j9K3CSQMoA2uI7rneK2nbjOcYa19rfy7OxoLksYY2oCjYs3Q4ATjTGnA+nW2p+PZJ8i4jtB0yfhLm2+boxZCnyKu3TbAJcMAmCMuQW4xVp7cvH2v4B5wM+46twDQAzw6hHGID6gxE0Cac9liWestXvGbezY0xbgWFpS+qx0VPHjVdxgXxGp+oKmT7LWTjfG1MXdYHA8bp7JTtba9fs8rR7QdJ/tPwEpxe1bgSXAufu9RqoYU/JZFQkuxpi2uLPLP7qDS0QkYNQnSTBQ4iZByRgzHzgNV/ZPB3pZaxd7G5WIVFfqkyRYBCxxM8acALyGu2umCJhkrX02IAcXERERqQICmbgdDxxvrV1ujKkFfAl0tdZ+H5AARERERHwuYPO4WWt/tdYuL/73LmAVkBCo44uIiIj4nScT8BpjGgJn4JbrEBEREZEyCPh0IMXzZ80CbrfW7tzvZ8MonvAwMjLyrBNPPDHQ4R21oqIiQkL8tyBFUVERxhivwyiXH3/8kSZN/DdBuLXWF5+RggLDpk1R5OSEYgzUr59FXFyh12GVy5o1a7ZZa4+pjH1Xlf7Kb997a63vYgZ/91d+e7/90sfur6z9VUDvKjXGhAH/Ad631h5yHbWmTZva1atXByawCrRgwQLatm3rdRjl9sEHH9C8eXOvwyiXhIQENm7038IIa9asCfrPyGefQY8esHkznHgizJkDO3f677NtjPnSWtuyso/j1/7Kj9/7lStX+i5m8G9/5cf32w997MGUtb8KWEpqXMo+GVh1uKRNRLzz0kvQtq1L2tq2hWXL4MwzvY5KREQgsGPczgf6AxcbY74ufnQK4PFF5BDy8u
CGG2DYMMjPh+HDITUVjqmUC40iInIkAjbGzVr7CeCvC+Ui1cTmzdCzJ3z6KUREwKRJcO21XkclIiL701qlItXc0qXQrRts2gR/+pMbz9ay0keFiYjIkfDfbRciUmFeeQUuvNAlbRde6MazKWkTEQlevqy4FRUVsW3bNrZv305hYXBNTxAXF8eqVau8DqPc6tatS3p6eoXv1xhDREQEUVFRvrulvCrLz4c77oDx4932zTfDU09BeLi3cYmIyKH5MnHbsGEDxhgaNmxIWFhYUCUEu3btolatWl6HUW47d+4kMjKyQvdpraWgoICtW7eya9cuYmNjK3T/cmR++w169YKPP3aJ2oQJMGSI11GJiEhZ+PJSaWZmJgkJCYSHhwdV0ialGWMICwvjuOOOo6CgwOtwBPjyS3cp9OOPoUEDWLhQSZuIiJ/4MnEDfDkrcnWl/1fB4bXX4Pzz4Zdf4Lzz3Hi2c8/1OioRESkP/UUVqeIKCtx4tgEDIDfXzdOWlgbHH+91ZCIiUl6+HOMmImWzbRtcfTV89BGEhcHzz8P113sdlYiIHCklbiJV1FdfufnZ1q+H446DN95wl0pFRMS/dKlUpApKSXFJ2vr1cM45bjybkjYREf9T4hZg1lqeeOIJmjZtSlRUFPXr16dHjx4Vtv9zzz2Xf/7zn3u3hwwZgjGGzZs3A27aj1q1arFw4cJD7mfWrFnUqlWL9evX720bMWIEJ598Mlu2bKmweKViFRTAXXdB376QnQ2DB7s7RxMSvI5MREQqghK3APvnP//JK6+8woQJE/jhhx946623uPTSSw943qOPPkrNmjUP+Vi0aNEBr6tduza7du0C4LfffuONN94gPj6ejIwMAF599VUaN25MmzZtDhln9+7dadGiBWPHjgXgqaeeYsaMGcybN49jjz32aN8GqQTp6dCpE/zrX1CjBowbBy+/DBU8PZ+IiFQwa8v+3Cozxs2r6dzK82YDvPfee3Tq1In27dsDcNJJJ3HuQeZkuOGGG+jdu/ch95VwkDJKnTp12L17NwDjx4+nW7dufP3113tXRZgwYQL33HMPAF26dGHRokW0b9+eKVOmlNqPMYZHHnmErl270qhRI8aOHcv7779PkyZNAJfYffrpp7Rr145p06aV702QCvftt9C1K/z3v3DMMW4820UXeR2ViIgcSnq6m6pp0qSyv6bKJG5+0aVLF/72t7/xzTff0KtXL3r06EG9evUOeF58fDzx8fHl3v+eiltOTg4TJ04kNTWV2267jYyMDObPn09GRgZ9+vQB4I477mDo0KG8+uqrB93XpZdeSsuWLXnooYeYPXs2LfdZxPK2225j8ODBJCUllTtGqVgzZ8LAgZCVBWed5RaJP+EEr6MSEZGDsRY++wxefNH13zk55Xt9lblUaq03j/K6/fbbWb16NR06dGDChAn83//930HXNj3SS6V7Km5JSUm0aNGC0047jdjYWDIyMhg3bhw33ngjERERALRr1+6Qy3OlpaWxYsUKrLUHXB5t27atL5f2qkoKC+G++6B3b5e09e8PixYpaRMRCUYZGfDcc/DXv8IFF8Drr7uk7bLLYNassu9HFTcPNG7cmDvvvJPhw4dTt25dVqxYQbNmzUo950gvle6puD3zzDM8/vjjgFv4/uuvv+aDDz5gUhnrsStWrODqq6/m6aef5t133+X+++/n7bffLuNvKJUtIwP69YN334XQUHjySbjtNu+GDIiIyIGshSVLXHVt+vSS6lr9+u7msaFDoVGj8u1TiVsAPf744xx77LGcc8451KhRg1dffZXw8HDatm17wHOP9FJpnTp1WLhwIQkJCXTq1AmA2NhYJk2aRO/evalfv/5h97F+/Xquuuoqhg8fzsCBAzn77LM566yzWLhw4WFvapDKt3KlG8/2009Qt64rtbdr53VUIiKyx44dkJTkErZvvy1pv+QSNwl6ly4QHn5k+1biFkC5ubk8/vjjrF+/nujoaM4991w+/P
DDCr1Lc8+l0ttvvx1TXH6Ji4vb23Y46enpXHnllXTs2JGRI0cC0Lx5c3r06MEDDzzAxx9/XGGxSvnNmQPXXgu7d8P
"text/plain": [
"<Figure size 648x230.4 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 5-12\n",
"\n",
"import matplotlib.patches as patches\n",
"\n",
"def plot_2D_decision_function(w, b, ylabel=True, x1_lim=[-3, 3]):\n",
" x1 = np.linspace(x1_lim[0], x1_lim[1], 200)\n",
" y = w * x1 + b\n",
" half_margin = 1 / w\n",
"\n",
" plt.plot(x1, y, \"b-\", linewidth=2, label=r\"$s = w_1 x_1$\")\n",
" plt.axhline(y=0, color='k', linewidth=1)\n",
" plt.axvline(x=0, color='k', linewidth=1)\n",
" rect = patches.Rectangle((-half_margin, -2), 2 * half_margin, 4,\n",
" edgecolor='none', facecolor='gray', alpha=0.2)\n",
" plt.gca().add_patch(rect)\n",
" plt.plot([-3, 3], [1, 1], \"k--\", linewidth=1)\n",
" plt.plot([-3, 3], [-1, -1], \"k--\", linewidth=1)\n",
" plt.plot(half_margin, 1, \"k.\")\n",
" plt.plot(-half_margin, -1, \"k.\")\n",
" plt.axis(x1_lim + [-2, 2])\n",
" plt.xlabel(r\"$x_1$\")\n",
" if ylabel:\n",
" plt.ylabel(\"$s$\", rotation=0, labelpad=5)\n",
" plt.legend()\n",
" plt.text(1.02, -1.6, \"Margin\", ha=\"left\", va=\"center\", color=\"k\")\n",
"\n",
" plt.annotate(\n",
" '', xy=(-half_margin, -1.6), xytext=(half_margin, -1.6),\n",
" arrowprops={'ec': 'k', 'arrowstyle': '<->', 'linewidth': 1.5}\n",
" )\n",
" plt.title(fr\"$w_1 = {w}$\")\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(9, 3.2), sharey=True)\n",
"plt.sca(axes[0])\n",
"plot_2D_decision_function(1, 0)\n",
"plt.grid()\n",
"plt.sca(axes[1])\n",
"plot_2D_decision_function(0.5, 0, ylabel=False)\n",
"plt.grid()\n",
"save_fig(\"small_w_large_margin_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAj4AAADlCAYAAABTVP1pAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAA62klEQVR4nO3dd5wU9f3H8dfnGr0jqDS7FPUgIBpURKMJGruIRixgFGsMiRpFjWASSYy9xi5q/MVTRMVeQcQAAlKkGlRAiiAq5eCAK9/fH99dWJbrt7uz5f18PPZxuzOzM5+52/3ee2a+M2POOUREREQyQVbQBYiIiIgkioKPiIiIZAwFHxEREckYCj4iIiKSMRR8REREJGMo+IiIiEjGUPARERGRjKHgIyIiIhlDwSeKmY02szeCrgOSq5ZUY2YtzGy1me0bdC2JZGZjzOyPQdchUldm1svMnJntFYN59QvNq3UMSqtrLYPNrDDoOjJZxgSfikJEOV+u3wPnJbQ4iYcbgbecc1+FB5jZFWb2jZltMbMZZnZUdWdmZn3NbJyZrQh9XgbHo+iaMLM7zeydqMG3AjebWbMgapLEM7PdzOxhM1tiZltDgf9DMzs+6NrizcwmmNmD5QwfYGaRtyX4L7AH8EPCipOklTHBp7qcc+udc+uCrkNqz8waAhcDT0YMOxu4DxgF9MA3hG+bWcdqzrYxMBcfjItiWnDtHQp8FjnAOfcF8DUK75nkZaA38FvgAOAk4G2gVZBFAZhZlpllB12Hc26bc+47p3s0CQo+u4jeMxTaonjYzEaZ2VozWxPa0s6KmKaRmT1rZoWhra3hZvaGmY2OmMbM7E9m9pWZFZnZF2ZWo39OZlbPzO4NLWOLmU0xsyMjxvcNDSs0s/VmNtXMDqpqXDyYWfvQnpGzzewjM9tsZrPNrHNoL9vE0LDPosOHmd1sZnNCtX4f+ps0iBh/VmjLtlPEsPtCv9u2wIlAGfBpxGz/CIx2zj3unFvgnPsdsAq4vDrr45x7yzl3o3NuTGjecRPxWVkU+qysMbOXI8bnmtk2oC/w59DveV7ELMYBv4lnjZIczKw5cBRwg3PuQ+fcUufcNOfcnc65FyKma2Nmr4U+T0vN7CIzm2tmIyOmcWY2IGr+S8zs2ojXfwx9NzeF9n4+EaohPH5w6Ht7opnNBbYBXcwsz8xuN7PlofdOM7NfRS2rv5ktDLVtn+BDXKx+Tzsd6oqo8xeh38MmMxtvZntHvW94qL0tDLXxI8xsSdQ0Q8xsfqjuL83sDxbx/6Ga9V1qZovNbFvo5yXljP8ytIzvzexdM8sJjTvY/B6+DWa2MdTOHlOrX1SGUPCpnkFACdAHuAoYBpwdMf4u4GjgdOBYIB/fGEX6G36L7EqgK/B34FEz+3UN6vhnaLkX4fdafAG8Y2Z7hL4ErwGTQss/DL+Ho7SycRUtyMxuDH3ZK3tUdqioe+jn5fjDL4cB9YDRofW4Efg50BofSiLlhN7XDf8P/Hj87zxsTGjdbw7Vem1ouv7OudX43/2M8NadmeUBPYH3opbzHv5vmmyuA4YAVwCdgVOA9yPGl+J/d+B/r3sAR0aM/wzoHRkWJW0Vhh6nmFn9SqYbDewHHAecBlwA7FWL5ZXhv4vdgHPxe5oeiJqmPv67eSm+rVsKPI1vI88FDgaeAV43s3wAM+sAvIr/nHcPzfOftaivJuoBw/Ht6c+B5sAj4ZFmdg4wArgJ+BmwgKi2KhRQRgG3AF2Aa4Dr8d/dajGz04EHgXuBg/Bt88NmdnJofC/gIXw7eiD+bxh5iPv/8BtxvfH/F0YCW6q7/IzknMuIB/6LX8KOhiL82Aw4YK+I6d6IeN8EYHLUvN4Hngg9b4zfqjknYnwj4Cf8Hobw6yLgqKj53Ivvh1JZzW9EzGMbcEHE+GzgK3yoahlaj6PLmU+F4ypZdkt8Q1nZo0El778JWAe0jRj2APA90Cpi2NNAQRW1PAY8EzXsl0
AxcAOwETg0YtyrkdMDe4bWv2/UPG4BFtXis1QIDI7jZ/Uj4K4qpjkJ2ABYOeMOCa3vvvGqUY/keQBnAj/i/9lNBu4EDosYf0Do83BExLBO+AA9MmKYAwZEzXsJcG0ly+4PbAWyQq8Hh+bTM2KaffGBqWPUe18FHg49HwV8Gfl5xoen7W1zBcufEGoXo9v1IsBFTNcvNK/WUXUeGDHNoNC8wusyGXgkannvAUsiXi8Dzo+aZhgwv5KaBwOFEa8/BZ6KmmY0MCn0/AxgPdCkgvltAC4M+nOYSo9M2+MzEb81Efk4txrvmxP1eiXQJvR8XyCXiL4WzrlN+P4gYV3xW0HvRO4xwe/VqO5ZR+HlbD9845wrxX85uzrnfsR/Wd41szdDu6Q7hKarcFxFnHM/OucWV/GorK9Ld3xoWx0xrCMw1jn3Q9Swb8IvzKyDmd1v/lDgj6Hf0xBgeVR97wHT8KFvoHNuWsToBpS/xRN9fN/KGRYTZva30K71yh79Knj7OGCYmX0Q2sVd3pkoPYDZLtTyRQn/XbTHJwM4517Gh/uT8X17+gBTzOzG0CRd8MEjso1aim/HasTMjjWz90OHrDYCY4E8YPeIyUqAWRGvf4b/rs2Pav9+zY72rwswJerzPLmaZRWwa7t+XTXet9U5tyji9Up8G9s89LozUX3ogKnhJ2a2G9ABv+c+cr3+QfXbdfDr/mnUsEn4/xvgN7SXAt+Y2fNmdqGZNYmY9m7gCfNdCm4ys841WHZGyrTgszn6nzdR/1ArUBz12rHjd2cRwyoSnvZkdv5ydsPvuaiOypbjN9ecG4I/9DERf3jky/Bx9MrGlbuwuh/qygemRA3rwa6NWT4wM7TMVvgwsztwLf6QVS98iJkVVV/4kKIBkeEKYC3QIup1KTs3zuDDa/R7Y+VefINW2SO6UQXAOXcvfpf2O/hd5l+ZWZeoyboT+r2Vo2Xo5/e1LV5Si3Nui3PufefcX5xzffAd+0eGDvNaFW/fPptyps0NPzHfp+5N/CGfs/CHjy8Kjc6LeM/W0EZZWFZo3oeyc/vXJeL91a2xPOvLade/q8b7SqJeh9vWrHKGlSc83WXsvF4H4dv2mqisXd+ID48D8XuYhgMLzWzP0PiR+JD0Kj70zjGzi8qZn4RkWvCJh8X4YNQ7PMD8WUWRHYfn43cHdypnr8nSGixnGxF9OcyfLfHz0PwBcM7Nds7d7pzrh98NfGF1xpXjEXbdiop+TC/vjWbWCL/FMzNiWEv81lHksA74M0/Cw36N3zN2tnPuXefcPHxYa0xE8An1CxgL/A7/Zf97VAkz2bG1hHNuGzAD31co0vH4s7tizjm31jm3sIrH5krev9g5dyc++Bn+8FWkfHbdExl2ELAyam+bZJb5+L5y9fFBJQsfPAAwf0LBnlHv+R7fXyw8TdvI1/jPYh7wB+fcZOfcl+XMozwz8Z/h3ctp/1ZE1HuYmUUGoMOrMe94WkhEux6y/XXo+7UCf0h5lz3iNVjOAnbuo0fodWS7XuKc+8g5NxzfFjTCH+4Oj/+fc+5+59yv8aH34hosP+PkBF1AqnPOFZrZU8DtZrYW38nsZnZs5eCc22hmdwJ3hr7YE/H/zA8Hypxzj1VjOZvM7F/AP0LL+Qb4A9AW3xFub3xnwnH4L+M++C/IvyobV8nyfsT3G6iN8D/p2RHDeuDD3/yIYd3xx+PDjcQP+N/LaWb2BXACvhP0xvA0oa3Ot4C7nXNPmdln+C2cfs65CaH5vIv/e7SKOKx2N/BcaPpP8VtpexLRmbEyZtYY368J/N+2o5l1B350zi2rzjyquZzr8XuhPsNvkV6ID7wToibNATqHtvo2u50vwXAUO3d+lDQV2kv6EvAUPghvxAeUPwEfOuc2ABvMX+/pUTMbij8Ueje7XpbhI+BKM/svfg/pKHY+ZPw//Gd/mJmNxbdfw6qq0Tn3pZk9D4w2s2uAz/F7JfsBXzvnxuK/h9cA95rZw/gO0JfV7LcRc/
cBT5vZNOAT/Mkrh+H7b4aNBB4ws3X4dikXv3emnXMueoOsIncAL5nZDHwfov74/kZnAJjZSfgNyYn4NvkYoAmwIHQ
"text/plain": [
"<Figure size 590.4x216 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 5-13\n",
"\n",
"s = np.linspace(-2.5, 2.5, 200)\n",
"hinge_pos = np.where(1 - s < 0, 0, 1 - s) # max(0, 1 - s)\n",
"hinge_neg = np.where(1 + s < 0, 0, 1 + s) # max(0, 1 + s)\n",
"\n",
"titles = (r\"Hinge loss = $max(0, 1 - s\\,t)$\", r\"Squared Hinge loss\")\n",
"\n",
"fig, axs = plt.subplots(1, 2, sharey=True, figsize=(8.2, 3))\n",
"\n",
"for ax, loss_pos, loss_neg, title in zip(\n",
" axs, (hinge_pos, hinge_pos ** 2), (hinge_neg, hinge_neg ** 2), titles):\n",
" ax.plot(s, loss_pos, \"g-\", linewidth=2, zorder=10, label=\"$t=1$\")\n",
" ax.plot(s, loss_neg, \"r--\", linewidth=2, zorder=10, label=\"$t=-1$\")\n",
" ax.grid(True, which='both')\n",
" ax.axhline(y=0, color='k')\n",
" ax.axvline(x=0, color='k')\n",
" ax.set_xlabel(r\"$s = \\mathbf{w}^\\intercal \\mathbf{x} + b$\")\n",
" ax.axis([-2.5, 2.5, -0.5, 2.5])\n",
" ax.legend(loc=\"center right\")\n",
" ax.set_title(title)\n",
" ax.set_yticks(np.arange(0, 2.5, 1))\n",
" ax.set_aspect(\"equal\")\n",
"\n",
"save_fig(\"hinge_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Extra Material"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Linear SVM classifier implementation using Batch Gradient Descent"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"X = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
"y = (iris.target == 2)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.base import BaseEstimator\n",
"\n",
"class MyLinearSVC(BaseEstimator):\n",
" def __init__(self, C=1, eta0=1, eta_d=10000, n_epochs=1000,\n",
" random_state=None):\n",
" self.C = C\n",
" self.eta0 = eta0\n",
" self.n_epochs = n_epochs\n",
" self.random_state = random_state\n",
" self.eta_d = eta_d\n",
"\n",
" def eta(self, epoch):\n",
" return self.eta0 / (epoch + self.eta_d)\n",
" \n",
" def fit(self, X, y):\n",
" # Random initialization\n",
" if self.random_state is not None: # so that random_state=0 is honored too\n",
" np.random.seed(self.random_state)\n",
" w = np.random.randn(X.shape[1], 1) # n feature weights\n",
" b = 0\n",
"\n",
" m = len(X)\n",
" t = np.array(y, dtype=np.float64).reshape(-1, 1) * 2 - 1\n",
" X_t = X * t\n",
" self.Js = [] # cost J recorded at each epoch\n",
"\n",
" # Training\n",
" for epoch in range(self.n_epochs):\n",
" # margin violators: instances where t * (x @ w + b) < 1\n",
" support_vectors_idx = (X_t.dot(w) + t * b < 1).ravel()\n",
" X_t_sv = X_t[support_vectors_idx]\n",
" t_sv = t[support_vectors_idx]\n",
"\n",
" # regularized hinge loss (only margin violators contribute to the sum)\n",
" J = 1/2 * (w * w).sum() + self.C * ((1 - X_t_sv.dot(w)).sum() - b * t_sv.sum())\n",
" self.Js.append(J)\n",
"\n",
" # subgradients of J with respect to w and b\n",
" w_gradient_vector = w - self.C * X_t_sv.sum(axis=0).reshape(-1, 1)\n",
" b_derivative = -self.C * t_sv.sum()\n",
" \n",
" w = w - self.eta(epoch) * w_gradient_vector\n",
" b = b - self.eta(epoch) * b_derivative\n",
" \n",
"\n",
" self.intercept_ = np.array([b])\n",
" self.coef_ = np.array([w])\n",
" support_vectors_idx = (X_t.dot(w) + t * b < 1).ravel()\n",
" self.support_vectors_ = X[support_vectors_idx]\n",
" return self\n",
"\n",
" def decision_function(self, X):\n",
" return X.dot(self.coef_[0]) + self.intercept_[0]\n",
"\n",
" def predict(self, X):\n",
" return self.decision_function(X) >= 0"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ True],\n",
" [False]])"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"C = 2\n",
"svm_clf = MyLinearSVC(C=C, eta0=10, eta_d=1000, n_epochs=60000,\n",
" random_state=2)\n",
"svm_clf.fit(X, y)\n",
"svm_clf.predict(np.array([[5, 2], [4, 1]]))"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZkAAAEOCAYAAABbxmo1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAcXklEQVR4nO3de5hcdZ3n8fenu9OdSyeEkASzCWu4xGhE5JJxuOl0DIwOMMIIjDjiZlzGeGEUbzsmOCoz6wVcdRxnXSUrSFgYJKArPIZBINqwohC5SiA0uQKBkAQIJJ2QTrr7u3+c00l1pxNSSZ061ac/r+fpp0796tSp77fTnU+fX506RxGBmZlZFuryLsDMzIrLIWNmZplxyJiZWWYcMmZmlhmHjJmZZcYhY2ZmmalayEi6WtJ6SUtKxsZIulPSsvT24JLH5kpaLqlN0nuqVaeZmVVONfdkrgHe22dsDrAoIqYAi9L7SJoGXAC8NX3O/5JUX71SzcysEqoWMhFxD/Byn+Gzgfnp8nzgnJLxn0ZER0SsApYD76hGnWZmVjkNOb/+oRGxFiAi1koan45PBO4rWW9NOrYbSbOB2QB1w0ad0HDQeCaPqmP1pm4AJo8qzttO3d3d1NUVp5++3N/AVuT+itwbwFNPPfViRIzLYtt5h8yeqJ+xfs9/ExHzgHkATROmxIRZ36Pt8jOZPGchAG2Xn5lZkdXW2tpKS0tL3mVkxv0NbEXur8i9AUh6Oqtt5x3N6yRNAEhv16fja4DDStabBDxf5drMzOwA5R0ytwKz0uVZwC0l4xdIapJ0ODAFWJxDfWZmdgCqNl0m6QagBRgraQ3wVeByYIGki4BngPMBIuJxSQuAJ4BO4OKI6KpWrWZmVhlVC5mI+OAeHpq5h/W/Dnw9u4rMzCxreU+XmZlZgTlkzMwsMw4ZMzPLjEPGzMwyU9iQ+cD0w15/JTMzy1RhQ2ZEUwPNTbV6QgMzs8GhsCGj/k5MY2ZmVVXYkDEzs/wVOmQi+j2nppmZVUlhQ8azZWZm+StsyJiZWf4KHTKeLDMzy1dhQ8ZHl5mZ5a+wIWNmZvkrdMj44DIzs3wVNmTk+TIzs9wVNmTMzCx/hQ6Z8PFlZma5KmzIeLLMzCx/hQ0ZMzPLX6FDxkeXmZnlq7gh4/kyM7PcFTdkzMwsd4UOGc+WmZnlq7AhI8+XmZnlrrAhY2Zm+St2yHi+zMwsV4UNGZ+6zMwsf4UNGTMzy1+hQ8bnLjMzy1dhQ8azZWZm+StsyJiZWf4KHTI+d5mZWb4KGzI+uszMLH+FDRkzM8tfTYSMpM9KelzSEkk3SBoqaYykOyUtS28PLne7ni0zM8tX7iEjaSLwaWB6RBwN1AMXAHOARRExBViU3t/37fr4MjOz3OUeMqkGYJikBmA48DxwNjA/fXw+cE4+pZmZ2f5qyLuAiHhO0reBZ4DXgDsi4g5Jh0bE2nSdtZLG9/d8SbOB2QCNbzgKgNbWVp5+ejvd3UFra2s12qiK9vb2QvXTl/sb2IrcX5F7y1ruIZO+13I2cDjwCnCTpAv39fkRMQ+YB9A0YUoAtLS08OD2NrRqOS0tLRWvOS+tra2F6qcv9zewFbm/IveWtVqYLjsNWBURGyJiB/Bz4GRgnaQJAOnt+hxrNDOz/VALIfMMcKKk4ZIEzASWArcCs9J1ZgG3lLthH11mZpav3KfLIuJ+STcDDwGdwMMk01/NwAJJF5EE0fnlbNfHlpmZ5S/3kAGIiK8CX+0z3EGyV2NmZgNULUyXZcbnLjMzy1dxQ8YnLzMzy11xQ8bMzHLnkDEzs8wUNmQ8WWZmlr/ChoyZmeWv8CETPsTMzCw3hQ0ZH1xmZpa/woaMmZnlr/Ah49kyM7P8FDZkfGVMM7P8FTZkzMwsf4UPGc+WmZnlp7Ah46PLzMzyV9iQMTOz/BU+ZPxhTDOz/BQ2ZDxbZmaWv8KGjJmZ5a/wIePJMjOz/BQ2ZHx0mZlZ/gobMm
Zmlr/Ch4wPLjMzy09hQ0aeLzMzy11hQ8bMzPJX+JAJH19mZpabwoeMmZnlxyFjZmaZKXzI+OgyM7P8FDZkfHCZmVn+ChsyZmaWP4eMmZllprAhI5/s38wsd4UNmR5+49/MLD+FDRm/8W9mlr+aCBlJoyXdLOlJSUslnSRpjKQ7JS1Lbw/Ou04zMytPTYQM8K/A7RHxZuDtwFJgDrAoIqYAi9L7ZfNpZczM8pN7yEgaBbwLuAogIrZHxCvA2cD8dLX5wDllbbdyJZqZ2X5S5PzOuKRjgXnAEyR7MQ8ClwDPRcTokvU2RsRuU2aSZgOzARrfcNQJE2Z9j2veO4LbVm1nQdsOrjxtOE0NxYic9vZ2mpub8y4jM+5vYCtyf0XuDWDGjBkPRsT0LLbdkMVGy9QAHA98KiLul/SvlDE1FhHzSEKKpglTAqClpYU2rYC2Jzn1ne9kRFMttHngWltbaWlpybuMzLi/ga3I/RW5t6zlPl0GrAHWRMT96f2bSUJnnaQJAOnt+nI26qPLzMzyl3vIRMQLwLOSpqZDM0mmzm4FZqVjs4BbcijPzMwOQK3MI30KuF5SI7AS+AhJAC6QdBHwDHD+/mzYx5aZmeWnJkImIh4B+nvTaeb+btOnlTEzy1/u02VZ+bdfLwNg87YdOVdiZjZ4FTZkNm3rBOCj1z6QcyVmZoNXYUOmx5LnNuVdgpnZoFX4kDEzs/wccMhIGlKJQiqtzu/7m5nlrqyQkfRpSeeW3L8KeE1SW8nnXGrCKUeNzbsEM7NBr9w9mU8DGwAkvQv4a+BvgEeA71S0sgP0xfe+Oe8SzMwGvXJDZiKwOl3+S+CmiFgAXAacWLmyDtzY5qady79b/mKOlZiZDV7lhswmYFy6fDrJdV4AdgBDK1VUJRw6alfI3P3UhhwrMTMbvMoNmTuA/52+F3MU8B/p+FuBVZUs7ECp5AyZV96zMsdKzMwGr3JD5mLgXmAscF5EvJyOHw/cUMnCzMxs4Cvr3GURsYnkZJZ9x79asYoO0Mih/bfU1R3U+7hmM7OqKvcQ5mmlhypLOl3SdZLmSqqvfHnlGdUobrn4lH4fW7mhvcrVmJlZudNlVwHHAUiaRHKNlzEk02hfq2xp5RszVBwxrv9LpJ7+L/dUuRozMys3ZN4CPJQunw/cHxFnAB8GPljJwirhJx/5k7xLMDMb1MoNmXpge7o8E7gtXV4BHFqpoiplxtTxeZdgZjaolRsyS4BPSHonScjcno5PBGr+E49Lnns17xLMzAaVckPmi8BHgVbghoh4LB1/H7C4gnVVzFHjd71Hc9a//TbHSszMBp+yQiYi7iH5xP/YiPivJQ9dCXyikoVVyl2f+7Ne9yMip0rMzAafsk/1HxFdJGdePlrSWyUNjYjVEbE+g/oq7orb2/Iuwcxs0Cj3czINkv4HsBF4FHgM2CjpW7V6XZm+fnT3irxLMDMbNMrdk/kWcCHwceBNwBSSabIPA9+sbGmVc9lfTsu7BDOzQanckPkb4KKImB8RK9Kva4C/Az5U8eoq5G9PObzX/Sdf2JRTJWZmg0u5IXMQyWdi+loBjD7gaqpk/aaOvEswMxsUyg2ZR0mujtnXJeljNetnnzh553J7R2eOlZiZDR7lhsw/ALMkPSVpvqRrJLWRvE/zhcqXVzknvPHgncufvP6hvaxpZmaVsj+fk3kTcBPQDIxKl99D/3s4ZmY2iO3P52Sej4gvRcS5EfH+iPhHYAtwbuXLq6xV3zxj5/IDq1/ey5pmZlYJZYfMQFZ6SebzfvT7HCsxMxscBlXIAKy+/Mydy/cur/lzepqZDWiDLmQAPjD9MAA+9OP7c67EzKzYGvZlJUm3vs4qoypQS9Vccd4x3PjAswAc+8938MhX/jzniszMimlf92Reep2vVcC1WRSYlcWXzgTgla07ePDpjTlXY2ZWTPu0JxMRH8m6kGobP2oobzq0mafWtXPuD3/Hym+cQV2dXv+JZma2zwblezI97vjsrm
vNHHHpbXtZ08zM9kfNhIykekkPS/plen+MpDslLUtvD369beyP0s/OTJ6zMIuXMDMbtGomZEjOf7a05P4cYFFETAE
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(range(svm_clf.n_epochs), svm_clf.Js)\n",
"plt.axis([0, svm_clf.n_epochs, 0, 100])\n",
"plt.xlabel(\"Epochs\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.grid()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[-15.56761653] [[[2.28120287]\n",
" [2.71621742]]]\n"
]
}
],
"source": [
"print(svm_clf.intercept_, svm_clf.coef_)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[-15.51721253] [[2.27128546 2.71287145]]\n"
]
}
],
"source": [
"svm_clf2 = SVC(kernel=\"linear\", C=C)\n",
"svm_clf2.fit(X, y.ravel())\n",
"print(svm_clf2.intercept_, svm_clf2.coef_)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAq4AAADwCAYAAADB5OaoAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAACRF0lEQVR4nOydZ3hURRuG70kPLbTQWxqhhCYivYkIIqAUBZQmItKUIiLSERAEBFGKgBQpomgA6YJIL9IhlBSSICWhhN7S5/ux2f2SsEk2YTtzX9e5sjunzHPOnn0zO2fmeYWUEoVCoVAoFAqFwtpxsLQAhUKhUCgUCoXCEFTDVaFQKBQKhUJhE6iGq0KhUCgUCoXCJlANV4VCoVAoFAqFTaAargqFQqFQKBQKm0A1XBUKhUKhUCgUNoFquCpsGiHEeCHEWUvrUCgUCoVCYXpUw1VhdIQQy4QQUgjxk55101LWbTLwWE1Sti+cwSYzgMbPo9eYCCE8hRDzhBCXhBBxQogbQoidQojmKevP6LsuKetapZxr+VRl7YUQ/wgh7gkhHgshgoQQk4UQRcx1TgqF4sUhsxim4pfCGlANV4WpuAJ0EkLk1hYIIZyAbsBlY1UipXwkpbxtrOPlFCGEkxBCAIHAK8CHQHmgNbAVKJSy6WLSXZdU9AL2SSlDU445GfgdOJVynErAIKAc0M9U56JQKF5oMothKn4pLI+UUi1qMeoCLAM2ASeAD1KVvwVEAj+nrG8EJADF0u0/GTiT8roJIIHCGdQ1Hjirp+5BwDXgLrAUyJVqGwEMB8KBp0AQ0DXdcacCISnrLwHTALf09QI9U46TBJRK0fpaJtemIBCb+rqklHsC8UD3lPevpBxraAbHyW/pz1ktalGLfS1A/sximIpfarGGRfW4KkzJYjS/wrX0QtOIlABSyr1oGn3dtRsIIRxS3i9+jnobAgHAa0AnoB2ahqyWSWh6Ewag6QWYAiwQQryZapvHKXorAv2BzsCodPV4Ae8B7wDVgOvAI6CtEMJNnzAp5R1gPWmvC2h6op8Cf6S8fz9Fww8ZHOeevnKFQqF4Dh6RSQxT8UthDaiGq8KU/AK8LITwE0IUA1qi6RFNzU/AB6netwCKACufo94HQD8p5QUp5XY0j6uaAaQ84hoK9JZSbpNSRkopfwEWoWnIAiClnCilPCClvCSl3AJ8DXRJV48L0E1KeUJKeVZKmYimB7YrcE8IcUgIMUMIUVvPOTdIPRYMzT+CX6SUT1Le+wHhUsqE57gOCoVCYTAGxjAVvxQWRTVcFSZDSnkXWIcmqPUAdksp049v/RnwFkLUS3nfC1gvn2/c6vmUAKwlCk1jGDQ9rG7ANiHEI+2CZsyVj3YHIURHIcR+IcT1lPWzgDLp6rkqpbyRukBKGQiUANqgGRdWDzgshBiZarOdaIZM9EqpqzZQGc0/BJ2EHJy3QqFQPBcGxDAVvxQWRTVcFaZmCZpH/71SXqdBSnkL2AD0EkIUAtryfMMEQDNuNk01/P9e1/5tA1RPtVQGXgcQQtQBfgX+StmuBjAacE533Mf6KpdSxkopd0gpv5JS1kNzPuOFEC4p6yWaIRPdhRCOaIYtnJZSHk91mFDAR7uPQqFQmIvMYpiKXwpLoxquClOzE82g/cJoxkbpYxHwLvAxcAP424R6zgNxQFkp5cV0y38p29QHrqUMFzgqpQwDyj5nnU5oenq1LAWKohkf25m0vRWgGWaRGxio74BCiPzPoUehUCiyQ/oYpuKXwmI4WVqAwr6RUkohRFVASCnjMthsB3AbGAdMlVIm69kmQAhxL13ZmRzoeSiEmAHMSLGv2gvkAeoAyVLKhWh6C0oKId4HDqEZd5t+fOszpPQY/46mZ/kM8BB4GY2DwU4p5YNUOq4KIf4C5qHpyV2VTue/QohpwHQhRCk0FjVX0UwI+xC4CEzI7vkrFApFRhgaw1T8UlgS1XBVmBwp5cMs1kshxFI0FlNLM9
hsl56yvDmUNAZNz+4wYD6ayVyn0FheIaXcKISYDnwHuAPbgbFognRmPAIOo3Ew8AVc0Vhy/YLGySA9PwFvoJnUcDf9SinlF0KIY2gmjX2I5vsaCfxpgBaFQqHILtmJYSp+KSyC0AxXUSgsixBiPuArpWxuaS0KhUKhUCisE9XjqrAoQggPoCaaCVzvWliOQqFQKBQKK0Y1XBWW5k80WVYWSyk3W1qMQqFQKBQK60UNFVAoFAqFQqFQ2ATKDkuhUCgUCoVCYROohqtCoVAoFAqFwiaw6zGu+fPnl76+vpaWwePHj8mdO7elZQBKS0YoLfpRWvRz/PjxGCmlp6V1mAJriZtgXZ+50qIfpUU/SsuzGC1uSinNsgCl0XhxXgDOAYP0bPM5Gj/NU8BZIAkomLLuEhCUsu6YIXWWL19eWgO7du2ytAQdSot+lBb9KC36MTQG2eJiLXFTSuv6zJUW/Sgt+lFansVYcdOcPa6JwGdSyhNCiLzAcSHEDinlee0GUsrpwHQAIUQbYIiU8k6qYzSVUsaYUbNCoVAoFAqFwkow2xhXKWW0lPJEyuuHaHpeS2aySxdgtTm0KRQKhUKhUCisH4vYYQkhyqHJER8gU+VvT7U+F5q8xr7aHlchRCRwF5DAAqnJKa/v2H2APgCFCxeu+csvv+Ds7GyS8zCUR48ekSdPHotq0KK06Edp0Y/Sop+mTZsel1K+bGkdxiJ13PT09Kw5Z84cPD09EUJYVJc1feZKi36UFv0oLc9itLhpjPEG2VmAPMBxoH0m23QCNqYrK5HytwhwGmiUVV2lSpWSbm5uctiwYfLWrVs5GpNhDKxlfImUSktGKC36UVr0gx2PcS1TpowEZKlSpeSiRYtkfHy8MS9dtrCmz1xp0Y/Soh+l5VmMFTfNaoclhHAGAoFVUsq1mWzamXTDBKSUUSl/bwLr0GRbyhRHR0diY2OZMWMG3t7eTJgwgQcPnungVSgUCkUKzs7OVK1alatXr/LRRx9RqVIlVq9eTXJysqWlKRQKhfkmZwnNM6fFwAUp5cxMtvMAGgNdU5XlBhyklA9TXr8OfJVVna6urhw7dozRo0ezbds2xo8fzw8//MDy5ctp2bIlV69e5fHjx899blnh4eHBhQsXTF6PISgt+rGklty5c1OqVCkcHJStssLyODo6cvLkSdasWcPYsWMJCwvjvffeY8qUKfz+++/4+/tbWqJCoXiBMaerQH2gGxAkhDiVUjYSKAMgpfwxpawdsF1KmbpFWRRYlzLeygn4RUq5zZBKa9asydatW9m7dy8jR47k8OHDlC9fnpiYGIQQ+Pv7m7zB8PDhQ/LmzWvSOgxFadGPpbQkJydz7do1YmJiKFKkiNnrVyj04eDgQOfOnenYsSM///wzEyZM4MaNG5QqVcrS0hQKxQuOOV0F9ksphZSyqpSyesqyRUr5Y6pGK1LKZVLKzun2jZBSVktZKkspJ2e3/kaNGrFv3z5OnTqFr68v9+7do0iRIkRGRnL79m3tOFqFwqw4ODhQtGhR7t+/b2kpCsUzODk58eGHHxIWFsb27dt1JuYPHz6kY8eOHD582MIKFQrFi8YL9WxSCEFAQAAASUlJPH36lLt37xIZGcn58+e5d++easAqzI6zszOJiYmWlqFQZIirqyvVqlXTvZ8zZw6BgYHUrVuXtm3bcubMGQuqUygULxIvVMM1PR4eHpQrVw4XFxeePn3KxYsXCQ4OVhO4FGbF0pZDCkV26du3LyNHjiRXrlxs3LiRatWq0aVLF0JDQy0tTaFQ2DkvdMNVCEHhwoUJCAigdOnSODk58fjxY0JDQwkPD7e0PIVCobBKChQowOTJk4mIiGDQoEG4uLjw66+/UqlSJaZMmWJpeQqFwo55oRuuWrTjDKtUqULJkiVxdHQkV65clpaVY5o0acLAgQNNdvyePXvSunXr5z7Ovn37EEIQE2N4Ft9ly5ZZhZGyQqGAokWL8t133xEWFkbv3r0BqFq1qo
VVKRQKe8acrgJWj6OjI8WLF38mY8yNGzd4/PgxJUqUwM3NzWL6evbsSUxMDJs2bcp0u7Vr15o0W9js2bONMha4du3
"text/plain": [
"<Figure size 792x230.4 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"yr = y.ravel()\n",
"fig, axes = plt.subplots(ncols=2, figsize=(11, 3.2), sharey=True)\n",
"plt.sca(axes[0])\n",
"plt.plot(X[:, 0][yr==1], X[:, 1][yr==1], \"g^\", label=\"Iris virginica\")\n",
"plt.plot(X[:, 0][yr==0], X[:, 1][yr==0], \"bs\", label=\"Not Iris virginica\")\n",
"plot_svc_decision_boundary(svm_clf, 4, 6)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.title(\"MyLinearSVC\")\n",
"plt.axis([4, 6, 0.8, 2.8])\n",
"plt.legend(loc=\"upper left\")\n",
"plt.grid()\n",
"\n",
"plt.sca(axes[1])\n",
"plt.plot(X[:, 0][yr==1], X[:, 1][yr==1], \"g^\")\n",
"plt.plot(X[:, 0][yr==0], X[:, 1][yr==0], \"bs\")\n",
"plot_svc_decision_boundary(svm_clf2, 4, 6)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.title(\"SVC\")\n",
"plt.axis([4, 6, 0.8, 2.8])\n",
"plt.grid()\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[-12.52988101 1.94162342 1.84544824]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXsAAADwCAYAAADhPsSkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAABQQElEQVR4nO2dd3hU1faw35UQQJr0Jl0ggPQmAipwUUGFiwX1IgoK8qGCXMpFLIAUQS/qDxH0ioqIIiq9iYUuKCUgIYEkFI2CQQgdhBCSrO+PMxknYZJMMpOcmWS/z3Me5uyzyzonhzV71l5rL1FVDAaDwZC/CbJbAIPBYDDkPkbZGwwGQwHAKHuDwWAoABhlbzAYDAUAo+wNBoOhAGCUvcFgMBQAjLI3GLJARPqLyMU8HE9F5EGX8wYi8pOIJIhIrLs6BkNWGGVv8AtEpIKIvCsisSJyRUSOi8g6EbnDpU4dEflQRH5z1IkTkQ0i0k9ECrvUU5fjkoj8IiKfi0jHDMa+X0TWi8hZEflLRCJE5FURqZgX9+6GKsBKl/PJwCWgAdAmgzoGQ6YYZW/wFxYDbYEBQH3gXmANUA5ARFoDPwONgaFAE+BuYDbQj7+VYCpPYSnEho4+E4HNIvIf10oi8iqwENjjGLMRMAyoBTzt0zv0EFX9U1WvuBTVBbaoaqyqxmdQJ1uISCEREW9lNQQQqmoOc9h6AKUBBbpmcF2AfUAYEJRRHZfPCjzops4UIAmo6zhv66g7IiO5HP/2By66lN8ILAf+BP4CdgP3pmt7P7AXuAycBjYBlRzXqjvan8aasUcDj7iT3/HZ9XjF3T0CNwBfAGccx2qgnsv1V4BIx70cBpKBEnb/7c2Rd4eZ2Rv8gYuOo6eIFHVzvTnWjPsNVU1x14E6NFoWvIn1a7aX4/xRLGX9TgZ9ns2gnxJYvzruAJph/SpZIiINAESkMpbi/QTrl8VtwKcu7d8FigGdgZuAfwMZjVUFiHHIXgV4I30FESkGbAASgNuBW4BjwFrHtVRqA32A3g65EzIY05APKWS3AAaDqiaJSH/gA2CQiPwMbAUWqup2LLMOWEoPABG5HvjDpZspqjoli3FOicgJoI6jqB5wWFWvZlPecCDcpehVEekBPIhlX68KhACLVPU3R51Il/o1gcWOfgB+zWSsP0UkCeuXxZ8ZVHsE69fPE6lfeiLy/4ATWKaprxz1CgOPqepxz+7UkJ8wM3uDX6Cqi7GUZA+sWXN7YJuIvJhBkwtYM/7mQByWIvMEwTKBpH7ONiJSXET+KyL7ReSMw1OnNVDDUSUcWAtEishiEXlaRCq4dPE28LLDw2ayiLTKiRwutMKatV8QkYsOec4BZbBMTqkcNYq+4GKUvcFvUNUEVf1eVSeqanvgIyxbc6yjSgOXuimqekhVD2EtvmaJiJQHKgC/OIoOADe6evJ4yBtYppCxWGaT5sAOHF84qpoM3Ok49mItEB8UkWaO6x9hKeePsX61/Cgir2RTBleCsBaYm6c76gPvu9T7y4sxDAGOUfYGf2Y/lqkxGogCRotIsBf9jQRSsBZHAT4HigND3FUWkdIZ9NMRmKeqi1V1L3CUtDNo1OInVZ2A5SkUBzzscv2oqs5W1YeAccCgHN+VtUBcFziZ+gXocpz2ol9DPsLY7A22IyLlsNwf52DNhC9gmUVGA+tU9ZzDpr8W+MnhLhkFBAMdgGpY3iWulHYslBbGUsT9gMeB0Y5fA6jqdhH5LzBNRKphLbQexZp1DwAOARPciHwAuE9ElgNXgfGAc2FZRNoBXYFvgeNACywPnP2O629jmaoOAKWAbqnXcsh8YBSwXETGAb87xvsn8D9VPehF34Z8glH2Bn/gIrANy7+9LlAEa/H1c6wFT1R1h4i0BF7A8p6pjOXWuBd4CfgwXZ8fOP69guWZsg3opKqbXSup6vMiEgY8i6XgC2EtmC7H8ppxxwgsE9MPWG6O03FR9lj28g5Y8QClgS
PAJFX9zHE9yHEP1bG+2NZh/erIEap6SURuA17D+tK8HuuXxAaHfAaD5ZtsMBgMhvyNsdkbDAZDAcAoe4PBYCgAGGVvMBgMBQCj7A0Gg6EAYJS9wWAwFADytetl+fLltVatWnaLkS3Onj3LkSNHKFSoEA0bNrRbHIPBYDO7du06qaoVsq6ZOXmm7EWkOjAPyz86BZitqm+nq/MfrJ0IU2VrCFRQ1dOODD0XsIJnklS1dVZj1qpVi7CwMN/dRB5x+fJlTpw4Qc2aNQH4/fffWbhwIc899xwhISE2S2cwGPISEfkt61pZk5dmnCRgpKo2BNoBz4pII9cKqjpNVZuranOs4JlN6cK9OzuuZ6noA5nrrrvOqegBRowYwahRo2jWrBnr16+3UTKDwRCo5JmyV9Vjqrrb8fkCVrj7DZk0+RewIC9k83cGDBhA3bp1iYqK4h//+AcPP/wwR48etVssg8EQQNiyQCsitbD2C9mewfViWPuFLHYpVuA7EdklIhluGiUig0QkTETC4uPjfSi1fXTv3p3IyEheffVVrrvuOr766itCQ0N57bXXSEz0aMNHg8FQwMnz7RJEpARWirZXVXVJBnUeBvqqag+XsqqqGudIAv09MDT9Pifpad26tQaizT4zfv/9d0aOHMmiRYsIDg4mPDycm266yW6xDAZDLiEiu3xhus5TbxwRCcGarc/PSNE7eIR0JhxVjXP8e0JElmLlD81U2edHatSowcKFC/n+++/TKHpV5fjx41SuXNlmCQ0Ggz+SZ2YcRyb7j4AoVX0rk3rXYyWEWO5SVlxESqZ+xkoKEem+h4LBHXfcwahRo5zny5cvp06dOkyePJmEBJNa1GAwpCUvbfYdgMeALiKyx3HcLSKDRWSwS737gO9U1TWrTiVgi4iEY2UEWq2q3+Sd6P7Pli1buHz5MmPHjqVx48asXr3abpEMBoMfka+3OM6PNvvM2LBhA0OGDGH/fisPRo8ePZg+fTp16tTJoqXBYPBXfGWzN9sl5CM6d+7Mnj17eOuttyhZsiQrV66kUaNGxjffYDAYZZ/fCAkJYfjw4cTExPDYY49RuXJl2rVrZ7dYBoPBZoyyz6dUqVKFefPmER4eTrFixQA4f/48jz32GAcOHLBZOoPBkNcYZZ/Puf76652fp06dymeffUaTJk148cUX+euvvzJpaTAY8hNG2Rcghg8fzpNPPkliYiJTp06lQYMGLFy4kPy8SG8wGCyMsi9AVKxYkY8++oiffvqJli1bcvToUR566CHuuOMOoqOj7RbPYDDkIkbZF0DatWvHjh07eO+99yhTpgzr1q3jt998souqwWDwU4yyL6AEBwczePBgDhw4wLvvvstdd93lvLZz505j2jEY8hn5WtkfOHCAqKgou8Xwa8qXL8/TTz/tPA8LC+Pmm2+mc+fOREYW6B0pDIZ8Rb5W9hcuXKBp06aMHj2aCxcu2C1OQHDs2DHKlSvHpk2baN68Of/+9785d+6c3WIZDAYvydfKvnz58iQnJzNt2jQaNGjAggULjHkiC3r06EFMTAzPPPMMqsrbb79NaGgo8+bNIyUlxW7xDAZDDsnXyr5mzZps376dtm3bEhcXR58+fejSpYsxT2RB2bJlmTVrFmFhYdxyyy0cP36cfv36MXnyZLtFMxgMOSQvtziuLiIbRCRKRPaJyDA3dTqJyDmXXTHHuVzrJiIxInJIRMZ4Om6bNm346aef+OCDDyhXrhwbN26kefPmjBgxwpgnsqBFixZs2bKFuXPnUrduXZ566im7RTIYDDlFVfPkAKoALR2fSwIHgEbp6nQCVrlpGwwcBuoAhYHw9G3dHa1atVJXTp06pc8884wGBQUpoJUqVdJ58+ZpSkqKGjInOTnZ+fnq1avatWtX/eijj9KUGwz+Stz5OL3t49v02IVjdouSLeLOxynluaA+0MH+nHDclbbAIVX9RVUTgS+Af2ZXBlfzRPv27Tl+/DiPP/44t956K+Hh4dntrk
ARFPT3q7JkyRLWrl3LgAEDaN++PQVpG2lDYDJp8yS2/L6FSZsm2S1Ktpi0eRKEUMIXffljwvFbRCRcRNaISGpy1Ru
"text/plain": [
"<Figure size 396x230.4 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from sklearn.linear_model import SGDClassifier\n",
"\n",
"sgd_clf = SGDClassifier(loss=\"hinge\", alpha=0.017, max_iter=1000, tol=1e-3,\n",
" random_state=42)\n",
"sgd_clf.fit(X, y)\n",
"\n",
"m = len(X)\n",
"t = np.array(y).reshape(-1, 1) * 2 - 1 # -1 if y == 0, or +1 if y == 1\n",
"X_b = np.c_[np.ones((m, 1)), X] # Add bias input x0=1\n",
"X_b_t = X_b * t\n",
"sgd_theta = np.r_[sgd_clf.intercept_[0], sgd_clf.coef_[0]]\n",
"print(sgd_theta)\n",
"support_vectors_idx = (X_b_t.dot(sgd_theta) < 1).ravel()\n",
"sgd_clf.support_vectors_ = X[support_vectors_idx]\n",
"sgd_clf.C = C\n",
"\n",
"plt.figure(figsize=(5.5, 3.2))\n",
"plt.plot(X[:, 0][yr==1], X[:, 1][yr==1], \"g^\")\n",
"plt.plot(X[:, 0][yr==0], X[:, 1][yr==0], \"bs\")\n",
"plot_svc_decision_boundary(sgd_clf, 4, 6)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.title(\"SGDClassifier\")\n",
"plt.axis([4, 6, 0.8, 2.8])\n",
"\n",
"plt.show()"
]
},
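{
"cell_type": "markdown",
"metadata": {},
"source": [
"`SGDClassifier` does not store support vectors, so the cell above identifies them manually. With the targets $t^{(i)} \\in \\{-1, +1\\}$ and the bias input $x_0 = 1$ prepended to each instance, an instance is selected when its hinge loss is strictly positive:\n",
"\n",
"$$\\max\\left(0, 1 - t^{(i)} \\boldsymbol{\\theta}^\\top \\mathbf{x}^{(i)}\\right) > 0 \\iff t^{(i)} \\boldsymbol{\\theta}^\\top \\mathbf{x}^{(i)} < 1$$\n",
"\n",
"These are the instances located strictly inside the street or on the wrong side of the decision boundary, which is exactly the test `X_b_t.dot(sgd_theta) < 1`. Instances lying exactly on the street's border are left out by this strict inequality."
]
},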
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Exercise solutions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. to 8."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. The fundamental idea behind Support Vector Machines is to fit the widest possible \"street\" between the classes. In other words, the goal is to have the largest possible margin between the decision boundary that separates the two classes and the training instances. When performing soft margin classification, the SVM searches for a compromise between perfectly separating the two classes and having the widest possible street (i.e., a few instances may end up on the street). Another key idea is to use kernels when training on nonlinear datasets. SVMs can also be tweaked to perform linear and nonlinear regression, as well as novelty detection.\n",
"2. After training an SVM, a _support vector_ is any instance located on the \"street\" (see the previous answer), including its border. The decision boundary is entirely determined by the support vectors. Any instance that is _not_ a support vector (i.e., is off the street) has no influence whatsoever: you could remove such instances, add more instances, or move them around, and as long as they stay off the street they won't affect the decision boundary. Computing the predictions with a kernelized SVM only involves the support vectors, not the whole training set.\n",
"3. SVMs try to fit the largest possible \"street\" between the classes (see the first answer), so if the training set is not scaled, the SVM will tend to neglect small features (see Figure 5-2).\n",
"4. You can use the `decision_function()` method to get confidence scores. These scores are proportional to the signed distance between the instance and the decision boundary. However, they cannot be directly converted into an estimation of the class probability. If you set `probability=True` when creating an `SVC`, then at the end of training it will use 5-fold cross-validation to generate out-of-sample scores for the training samples, and it will fit a logistic regression to these scores (a technique known as Platt scaling) to map them to estimated probabilities. The `predict_proba()` and `predict_log_proba()` methods will then be available.\n",
"5. `LinearSVC`, `SVC`, and `SGDClassifier` can all be used for large-margin linear classification. The `SVC` class also supports the kernel trick, which makes it capable of handling nonlinear tasks. However, this comes at a cost: the `SVC` class does not scale well to datasets with many instances. It does scale well to a large number of features, though. The `LinearSVC` class implements an optimized algorithm for linear SVMs, while `SGDClassifier` uses Stochastic Gradient Descent. Depending on the dataset, `LinearSVC` may be a bit faster than `SGDClassifier`, but not always. `SGDClassifier` is also more flexible, and it supports incremental learning.\n",
"6. If an SVM classifier trained with an RBF kernel underfits the training set, there might be too much regularization. To decrease it, you need to increase `gamma` or `C` (or both).\n",
"7. A Regression SVM model tries to fit as many instances as possible within a small margin around its predictions. If you add instances within this margin, the model will not be affected at all: it is said to be _ϵ-insensitive_.\n",
"8. The kernel trick is a mathematical technique that makes it possible to train a nonlinear SVM model. The resulting model is equivalent to mapping the inputs to another space using a nonlinear transformation, then training a linear SVM on the resulting high-dimensional inputs. The kernel trick gives the same result without having to transform the inputs at all."
]
},
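{
"cell_type": "markdown",
"metadata": {},
"source": [
"To illustrate answer 8, here is a small sketch (not part of the original solution) that uses the 2nd-degree polynomial kernel $K(\\mathbf{a}, \\mathbf{b}) = (\\mathbf{a}^\\top \\mathbf{b})^2$. This kernel corresponds to Scikit-Learn's `polynomial_kernel()` with `degree=2`, `gamma=1`, and `coef0=0`. For 2D inputs, it gives the same result as first mapping both vectors with $\\phi(\\mathbf{x}) = (x_1^2, \\sqrt{2}\\, x_1 x_2, x_2^2)$ and then computing a dot product in that 3D space. The vectors `a` and `b` are arbitrary examples:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics.pairwise import polynomial_kernel\n",
"\n",
"def phi(x):\n",
"    \"\"\"Explicit 2nd-degree polynomial feature map for a 2D vector.\"\"\"\n",
"    x1, x2 = x\n",
"    return np.array([x1 ** 2, np.sqrt(2) * x1 * x2, x2 ** 2])\n",
"\n",
"a = np.array([2.0, 3.0])\n",
"b = np.array([1.0, -1.0])\n",
"\n",
"explicit = phi(a) @ phi(b)  # dot product after the explicit transformation\n",
"kernel = (a @ b) ** 2  # kernel trick: computed directly on the 2D inputs\n",
"sklearn_kernel = polynomial_kernel([a], [b], degree=2, gamma=1, coef0=0)[0, 0]\n",
"\n",
"print(np.isclose(explicit, kernel), np.isclose(kernel, sklearn_kernel))"
]
},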
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 9."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_Exercise: Train a `LinearSVC` on a linearly separable dataset. Then train an `SVC` and an `SGDClassifier` on the same dataset. See if you can get them to produce roughly the same model._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's use the Iris dataset: the Iris setosa and Iris versicolor classes are linearly separable."
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import datasets\n",
"\n",
"iris = datasets.load_iris(as_frame=True)\n",
"X = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
"y = iris.target\n",
"\n",
"setosa_or_versicolor = (y == 0) | (y == 1)\n",
"X = X[setosa_or_versicolor]\n",
"y = y[setosa_or_versicolor]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's build and train 3 models:\n",
"* Remember that `LinearSVC` uses `loss=\"squared_hinge\"` by default, so if we want all 3 models to produce similar results, we need to set `loss=\"hinge\"`.\n",
"* Also, the `SVC` class uses an RBF kernel by default, so we need to set `kernel=\"linear\"` to get similar results as the other two models.\n",
"* Lastly, the `SGDClassifier` class does not have a `C` hyperparameter, but it has another regularization hyperparameter called `alpha`, so we can tweak it to get similar results as the other two models."
]
},
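{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough starting point for `alpha` (a heuristic derived from the objectives documented for each class, not something this solution depends on): a linear `SVC` minimizes the first objective below, while `SGDClassifier` with `loss=\"hinge\"` and `penalty=\"l2\"` minimizes the second one. Here $m$ is the number of training instances and $t^{(i)} \\in \\{-1, +1\\}$:\n",
"\n",
"$$\\frac{1}{2}\\|\\mathbf{w}\\|^2 + C \\sum_{i=1}^{m} \\max\\left(0, 1 - t^{(i)}\\left(\\mathbf{w}^\\top \\mathbf{x}^{(i)} + b\\right)\\right)$$\n",
"\n",
"$$\\frac{\\alpha}{2}\\|\\mathbf{w}\\|^2 + \\frac{1}{m} \\sum_{i=1}^{m} \\max\\left(0, 1 - t^{(i)}\\left(\\mathbf{w}^\\top \\mathbf{x}^{(i)} + b\\right)\\right)$$\n",
"\n",
"Dividing the first objective by $mC$ shows that both objectives have the same minimizer when $\\alpha = \\frac{1}{mC}$. With the $m = 100$ Iris setosa and Iris versicolor instances and $C = 5$, this gives $\\alpha = 0.002$. The next cell uses a hand-tweaked `alpha` instead, which is fine because the goal is only to get roughly similar models."
]
},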
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.svm import SVC, LinearSVC\n",
"from sklearn.linear_model import SGDClassifier\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"C = 5\n",
"alpha = 0.05\n",
"\n",
"scaler = StandardScaler()\n",
"X_scaled = scaler.fit_transform(X)\n",
"\n",
"lin_clf = LinearSVC(loss=\"hinge\", C=C, random_state=42).fit(X_scaled, y)\n",
"svc_clf = SVC(kernel=\"linear\", C=C).fit(X_scaled, y)\n",
"sgd_clf = SGDClassifier(alpha=alpha, random_state=42).fit(X_scaled, y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's plot the decision boundaries of these three models:"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAqMAAAEOCAYAAAC99R7FAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAABuIElEQVR4nO3dd3xO1x/A8c/JEklolAgSETOE2rVbCTVLqVF7q01b1Py1RUsXbdXW2pSiatVsiVU1G1QQI7FiVa0gZJzfH/fxCBKekCdPxvf9et2X59577r3n5mZ8nXvO+SqtNUIIIYQQQtiCna0rIIQQQgghMi4JRoUQQgghhM1IMCqEEEIIIWxGglEhhBBCCGEzEowKIYQQQgibkWBUCCGEEELYTIoFo0qpvEqpzUqpI0qpw0qp9xIoo5RS3yulTiilDiqlysbbV1cpdcy0b0hK1VsIIYQQQlhPSraMxgADtNbFgEpAb6WU/2Nl6gGFTUs3YAqAUsoemGTa7w+0SuBYIYQQQgiRxqRYMKq1vqC13m/6fAs4Ang9VqwRMFcb/gLclVK5gQrACa31Ka31fWCRqawQQgghhEjDHGxxUaWUL1AG2PXYLi/gbLz1c6ZtCW2vmMi5u2G0quLs7FzOx8cneSotUlRcXBx2dtKlOa2S55d2ybNL2+T5pW3p+fmFhob+q7X2SGhfigejSik34Bfgfa31zcd3J3CIfsr2JzdqPR2YDuDn56ePHTv2ArUVthIUFERAQICtqyGekzy/tEueXdomzy9tS8/PTyl1OrF9KRqMKqUcMQLRBVrrZQkUOQfkjbfuDUQATolsF0IIIYQQaVhKjqZXwAzgiNb6m0SKrQTam0bVVwJuaK0vAHuAwkqp/EopJ6ClqawQQgghhEjDUrJltCrQDjiklAo2bRsG+ABoracCa4D6wAngDtDJtC9GKdUHWA/YAzO11odTsO5CCCGEEMIKUiwY1VpvJ+G+n/HLaKB3IvvWYASrQgghhBAinUifQ7aEEEIIIUSaYJOpnYQQIr6bN29y+fJloqOjbV2VDO+ll17iyJEjtq6GRRwdHcmZMydZs2a1dVWEEC9AglEhhE3dvHmTS5cu4eXlRebMmTHGOgpbuXXrFlmyZLF1NZ5Ja83du3c5f/48gASkQqRh8ppeCGFTly9fxsvLCxcXFwlEhcWUUri4uODl5cXly5dtXR0hxAuQYFQIYVPR0dFkzpzZ1tUQaVTmzJmle4cQaZwEo0IIm5MWUfG85HtHiLRPglEhhBBCCGEzEowKIYQQQgibkWBUCCGSma+vL2PHjrV1NYQQIk2QYFQIIZ5Dx44dadCgQYL79uzZQ69evVK4RonbsmULNWvWJEeOHLi4uFCwYEHatGnDzZs32bdvH0optm/fnuCx77zzDlWrVjWv37p1i48++gh/f38yZ86Mp6cnAQEBLFy4kLi4uJS6JSFEOiLBqBBCJDMPDw9cXFxsXQ3u379PSEgIdevWpWTJkmzevJl//vmHKVOm8NJLL3Hv3j3KlStHmTJlmDFjxhPHX716lZUrV9KlSxcArl+/TuXKlZk5cyYffvghe/fuZfv27XTo0IFPP/2UM2fOpPQtCiHSAQlGhRAimT3+ml4pxfTp02nevDmurq4UKFCA+fPnP3LM+fPnadmyJdmyZSNbtmy8+eabHD9+3Lz/5MmTNGrUiFy5cuHq6krZsmVZvXr1E9cdMWIEnTt3xt3dnTZt2rBhwwayZ8/Ot99+yyuvvEKBAgWoXbs2kydPxsPDA4AuXbqwZMkSIiMjHznf/PnzcXR0pEWLFgAMGzaMsLAwdu3aRadOnShevDiFCxemU6dO7N+/n1y5ciXr11EIkTGk62A0Lk6m/BAiLQoICGD27NmAMQ9pQECAOXi7c+cOAQEB/PzzzwDcuHGDgIAAli1bBsC///5LQEAAq1atAuDixYsEBASwbt06AM6ePUtAQAC///47AKdOnUqRexo1ah
SNGjXiwIEDtGjRgs6dO3P69GnzPQUGBuLs7MyWLVvYuXMnuXPn5o033uDOnTsAREZGUq9ePTZu3MiBAwdo2rQpTZo04ejRo49c55tvvqFo0aLs3buXMWPGkCtXLq5cucLmzZsTrVubNm2IjY01f00fmDlzJi1btsTV1ZW4uDgWLVpEmzZt8Pb2fuIczs7OODs7v+iXSQiRAaXrYNTOTgNw/z58/jlERNi4QkKIDKtdu3a0bduWQoUK8emnn+Lg4MC2bdsAWLRoEVprZs2aRcmSJSlatCjTpk0jMjLS3PpZqlQpevTowSuvvEKhQoUYPnw4ZcuWZenSpY9cp3r16gwaNIhChQpRuHBhmjdvTuvWralRowaenp40bNiQb775hitXrpiPcXd3p2nTpo+8qt+zZw8HDx6ka9eugBHkX7t2jWLFiln7SyVEqpMrFyj15CIvA5JHus5Nn+nKFQgLY9mu/AwbBh99BG+/Db16QUCA8Y0khEh9goKCzJ8dHR0fWXdxcXlk/aWXXnpkPUeOHI+s58qV65H1vHnzPrJeoECB5Kv4U5QsWdL82cHBAQ8PD3May3379hEWFvZETvg7d+5w8uRJAG7fvs3IkSNZvXo1Fy5cIDo6mqioqEfOC1C+fPlH1u3t7Zk1axafffYZmzZt4q+//uLrr79m9OjRbN26leLFiwPGq/oaNWoQGhpKuXLlmDlzJiVKlKBixYqAkQteiIzq0qWkbRdJk65bRp2uXYOCBak7qQGfVV2L0nEsXQo1akDx4jBxIty4YetaCiEyAkdHx0fWlVLm0edxcXGULl2a4ODgR5bQ0FC6d+8OwMCBA1myZAmffvopW7ZsITg4mAoVKnD//v1Hzuvq6prg9b28vGjXrh2TJk0iJCQEOzs7vv76a/P+gIAAChUqxLx587h79y4LFy40D1wCY1BWtmzZOHLkSLJ8PYQQ4oF0HYzeLlAA/vc/3I/vZfiO+tzNW4R1tb+hWK5rHDkCfftCrVq2rqUQIqMrW7YsJ06cIEeOHBQqVOiR5eWXXwZg+/bttG/fnqZNm1KyZEm8vb3NraZJlS1bNnLnzv3IgCWlFJ07d2bhwoUsXLiQu3fv0q5dO/N+Ozs7WrRowYIFCzh37twT54yKiiIqKuq56iOEyNjSdTAa5+AAo0bBmTOwcCEOeXNTZ8MADt/w4lTNrnR79W/at39Y/swZWLAA7t2zXZ2FEGnHzZs3n2jNDA8PT/J52rRpg6enJ40aNWLLli2EhYWxdetWBgwYYB5RX6RIEX799Vf279/PoUOHaNu2rUXB37Rp0+jZsycbNmzg5MmTHD58mMGDB3Po0CEaN278SNmOHTty9epVBg4cSOPGjcmePfsj+8eMGYOPjw8VK1Zk1qxZHD58mBMnTjBv3jzKlSvHxYsXk3zvQgiRYn1GlVIzgQbAZa11iQT2fwi0iVevYoCH1vo/pVQ4cAuIBWK01uUfP/6pnJygZUtjOXAANXky+efPZ9qdGWiHypCtNzRrxpQpmfjiC/jgA+jaFbp3h3z5XuCmhRDp2rZt2yhTpswj25o2bZrk87i4uLB161aGDBlC8+bNuXHjBnny5CEwMJBs2bIBxij5Ll268Nprr5EtWzbef/99i4LRChUq8Oeff9KzZ08iIiJwcXGhcOHCzJ07l7Zt2z5SNnfu3NSuXZu1a9eaBy7Fly1bNv766y+++uorvvzyS8LDw8maNSv+/v589NFH+Pj4JPnehRBCpVSndKXU60AkMDehYPSxsg2BD7TWNUzr4UB5rfW/Sbmmn5+fPnbsWMI7r1+H2bNh8mQ4fhxy5uSfyu/S/1h3Nh7NC4CdHbz5pjHgqXZtY12kjKCgIAICAmxdDfGckvL8jhw5IiO0U5Fbt249MZAqtZPvoYfkd6d15MqV8GAlT09IzhcC6fn5KaX2JdaYmGLhldZ6K/CfhcVbAQutWB1wd4f334ejR2H9eqhUiRKrPmd9qC//vt6EMTX/wMFes2
oV1KsH771n1doIIYQQIpW6eBG0fnKRninJI9W19SmlXIC6wC/xNmtgg1Jqn1KqW7Je0M7OaPZcsQJOnkQNGkT2w1s
"text/plain": [
"<Figure size 792x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"def compute_decision_boundary(model):\n",
" w = -model.coef_[0, 0] / model.coef_[0, 1]\n",
" b = -model.intercept_[0] / model.coef_[0, 1]\n",
" return scaler.inverse_transform([[-10, -10 * w + b], [10, 10 * w + b]])\n",
"\n",
"lin_line = compute_decision_boundary(lin_clf)\n",
"svc_line = compute_decision_boundary(svc_clf)\n",
"sgd_line = compute_decision_boundary(sgd_clf)\n",
"\n",
"# Plot all three decision boundaries\n",
"plt.figure(figsize=(11, 4))\n",
"plt.plot(lin_line[:, 0], lin_line[:, 1], \"k:\", label=\"LinearSVC\")\n",
"plt.plot(svc_line[:, 0], svc_line[:, 1], \"b--\", linewidth=2, label=\"SVC\")\n",
"plt.plot(sgd_line[:, 0], sgd_line[:, 1], \"r-\", label=\"SGDClassifier\")\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"bs\") # label=\"Iris versicolor\"\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"yo\") # label=\"Iris setosa\"\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.legend(loc=\"upper center\")\n",
"plt.axis([0, 5.5, 0, 2])\n",
"plt.grid()\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Close enough!"
]
},
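{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, `compute_decision_boundary()` relies on the fact that each model's decision boundary is the line where its decision function equals zero. Let $\\theta_0$ denote `model.intercept_[0]`, and let $\\theta_1, \\theta_2$ denote `model.coef_[0, 0]` and `model.coef_[0, 1]`. Solving for the second (scaled) feature gives:\n",
"\n",
"$$\\theta_0 + \\theta_1 x_1 + \\theta_2 x_2 = 0 \\iff x_2 = -\\frac{\\theta_1}{\\theta_2} x_1 - \\frac{\\theta_0}{\\theta_2}$$\n",
"\n",
"So the function's `w` is the slope $-\\theta_1 / \\theta_2$, and its `b` is the intercept $-\\theta_0 / \\theta_2$. The two points at $x_1 = \\pm 10$ are computed in the scaled feature space, then mapped back to the original feature space with `scaler.inverse_transform()`. Since standardization is an affine transformation, the boundary is still a straight line after this mapping."
]
},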
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 10."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_Exercise: Train an SVM classifier on the Wine dataset, which you can load using `sklearn.datasets.load_wine()`. This dataset contains the chemical analysis of 178 wine samples produced by 3 different cultivators: the goal is to train a classification model capable of predicting the cultivator based on the wine's chemical analysis. Since SVM classifiers are binary classifiers, you will need to use one-versus-all to classify all 3 classes. What accuracy can you reach?_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, let's fetch the dataset, look at its description, then split it into a training set and a test set:"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import load_wine\n",
"\n",
"wine = load_wine(as_frame=True)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
".. _wine_dataset:\n",
"\n",
"Wine recognition dataset\n",
"------------------------\n",
"\n",
"**Data Set Characteristics:**\n",
"\n",
" :Number of Instances: 178 (50 in each of three classes)\n",
" :Number of Attributes: 13 numeric, predictive attributes and the class\n",
" :Attribute Information:\n",
" \t\t- Alcohol\n",
" \t\t- Malic acid\n",
" \t\t- Ash\n",
"\t\t- Alcalinity of ash \n",
" \t\t- Magnesium\n",
"\t\t- Total phenols\n",
" \t\t- Flavanoids\n",
" \t\t- Nonflavanoid phenols\n",
" \t\t- Proanthocyanins\n",
"\t\t- Color intensity\n",
" \t\t- Hue\n",
" \t\t- OD280/OD315 of diluted wines\n",
" \t\t- Proline\n",
"\n",
" - class:\n",
" - class_0\n",
" - class_1\n",
" - class_2\n",
"\t\t\n",
" :Summary Statistics:\n",
" \n",
" ============================= ==== ===== ======= =====\n",
" Min Max Mean SD\n",
" ============================= ==== ===== ======= =====\n",
" Alcohol: 11.0 14.8 13.0 0.8\n",
"<<26 more lines>>\n",
"wine.\n",
"\n",
"Original Owners: \n",
"\n",
"Forina, M. et al, PARVUS - \n",
"An Extendible Package for Data Exploration, Classification and Correlation. \n",
"Institute of Pharmaceutical and Food Analysis and Technologies,\n",
"Via Brigata Salerno, 16147 Genoa, Italy.\n",
"\n",
"Citation:\n",
"\n",
"Lichman, M. (2013). UCI Machine Learning Repository\n",
"[https://archive.ics.uci.edu/ml]. Irvine, CA: University of California,\n",
"School of Information and Computer Science. \n",
"\n",
".. topic:: References\n",
"\n",
" (1) S. Aeberhard, D. Coomans and O. de Vel, \n",
" Comparison of Classifiers in High Dimensional Settings, \n",
" Tech. Rep. no. 92-02, (1992), Dept. of Computer Science and Dept. of \n",
" Mathematics and Statistics, James Cook University of North Queensland. \n",
" (Also submitted to Technometrics). \n",
"\n",
" The data was used with many others for comparing various \n",
" classifiers. The classes are separable, though only RDA \n",
" has achieved 100% correct classification. \n",
" (RDA : 100%, QDA 99.4%, LDA 98.9%, 1NN 96.1% (z-transformed data)) \n",
" (All results using the leave-one-out technique) \n",
"\n",
" (2) S. Aeberhard, D. Coomans and O. de Vel, \n",
" \"THE CLASSIFICATION PERFORMANCE OF RDA\" \n",
" Tech. Rep. no. 92-01, (1992), Dept. of Computer Science and Dept. of \n",
" Mathematics and Statistics, James Cook University of North Queensland. \n",
" (Also submitted to Journal of Chemometrics).\n",
"\n"
]
}
],
"source": [
"print(wine.DESCR)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
" wine.data, wine.target, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>alcohol</th>\n",
" <th>malic_acid</th>\n",
" <th>ash</th>\n",
" <th>alcalinity_of_ash</th>\n",
" <th>magnesium</th>\n",
" <th>total_phenols</th>\n",
" <th>flavanoids</th>\n",
" <th>nonflavanoid_phenols</th>\n",
" <th>proanthocyanins</th>\n",
" <th>color_intensity</th>\n",
" <th>hue</th>\n",
" <th>od280/od315_of_diluted_wines</th>\n",
" <th>proline</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>13.16</td>\n",
" <td>2.36</td>\n",
" <td>2.67</td>\n",
" <td>18.6</td>\n",
" <td>101.0</td>\n",
" <td>2.80</td>\n",
" <td>3.24</td>\n",
" <td>0.30</td>\n",
" <td>2.81</td>\n",
" <td>5.68</td>\n",
" <td>1.03</td>\n",
" <td>3.17</td>\n",
" <td>1185.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>100</th>\n",
" <td>12.08</td>\n",
" <td>2.08</td>\n",
" <td>1.70</td>\n",
" <td>17.5</td>\n",
" <td>97.0</td>\n",
" <td>2.23</td>\n",
" <td>2.17</td>\n",
" <td>0.26</td>\n",
" <td>1.40</td>\n",
" <td>3.30</td>\n",
" <td>1.27</td>\n",
" <td>2.96</td>\n",
" <td>710.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>122</th>\n",
" <td>12.42</td>\n",
" <td>4.43</td>\n",
" <td>2.73</td>\n",
" <td>26.5</td>\n",
" <td>102.0</td>\n",
" <td>2.20</td>\n",
" <td>2.13</td>\n",
" <td>0.43</td>\n",
" <td>1.71</td>\n",
" <td>2.08</td>\n",
" <td>0.92</td>\n",
" <td>3.12</td>\n",
" <td>365.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>154</th>\n",
" <td>12.58</td>\n",
" <td>1.29</td>\n",
" <td>2.10</td>\n",
" <td>20.0</td>\n",
" <td>103.0</td>\n",
" <td>1.48</td>\n",
" <td>0.58</td>\n",
" <td>0.53</td>\n",
" <td>1.40</td>\n",
" <td>7.60</td>\n",
" <td>0.58</td>\n",
" <td>1.55</td>\n",
" <td>640.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>51</th>\n",
" <td>13.83</td>\n",
" <td>1.65</td>\n",
" <td>2.60</td>\n",
" <td>17.2</td>\n",
" <td>94.0</td>\n",
" <td>2.45</td>\n",
" <td>2.99</td>\n",
" <td>0.22</td>\n",
" <td>2.29</td>\n",
" <td>5.60</td>\n",
" <td>1.24</td>\n",
" <td>3.37</td>\n",
" <td>1265.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" alcohol malic_acid ash alcalinity_of_ash magnesium total_phenols \\\n",
"2 13.16 2.36 2.67 18.6 101.0 2.80 \n",
"100 12.08 2.08 1.70 17.5 97.0 2.23 \n",
"122 12.42 4.43 2.73 26.5 102.0 2.20 \n",
"154 12.58 1.29 2.10 20.0 103.0 1.48 \n",
"51 13.83 1.65 2.60 17.2 94.0 2.45 \n",
"\n",
" flavanoids nonflavanoid_phenols proanthocyanins color_intensity hue \\\n",
"2 3.24 0.30 2.81 5.68 1.03 \n",
"100 2.17 0.26 1.40 3.30 1.27 \n",
"122 2.13 0.43 1.71 2.08 0.92 \n",
"154 0.58 0.53 1.40 7.60 0.58 \n",
"51 2.99 0.22 2.29 5.60 1.24 \n",
"\n",
" od280/od315_of_diluted_wines proline \n",
"2 3.17 1185.0 \n",
"100 2.96 710.0 \n",
"122 3.12 365.0 \n",
"154 1.55 640.0 \n",
"51 3.37 1265.0 "
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train.head()"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2 0\n",
"100 1\n",
"122 1\n",
"154 2\n",
"51 0\n",
"Name: target, dtype: int64"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_train.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's start simple, with a linear SVM classifier. It will automatically use the One-vs-All (also called One-vs-the-Rest, OvR) strategy, so there's nothing special we need to do to handle multiple classes. Easy, right?"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/ageron/miniconda3/envs/homl3/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"LinearSVC(random_state=42)"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lin_clf = LinearSVC(random_state=42)\n",
"lin_clf.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Oh no! It failed to converge. Can you guess why? Do you think we just need to increase the number of training iterations? Let's see:"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/ageron/miniconda3/envs/homl3/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"LinearSVC(max_iter=1000000, random_state=42)"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lin_clf = LinearSVC(max_iter=1_000_000, random_state=42)\n",
"lin_clf.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Even with one million iterations, it still did not converge. There must be another problem.\n",
"\n",
"Let's still evaluate this model with `cross_val_score()`; it will serve as a baseline:"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/ageron/miniconda3/envs/homl3/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" warnings.warn(\n",
"/Users/ageron/miniconda3/envs/homl3/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" warnings.warn(\n",
"/Users/ageron/miniconda3/envs/homl3/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" warnings.warn(\n",
"/Users/ageron/miniconda3/envs/homl3/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" warnings.warn(\n",
"/Users/ageron/miniconda3/envs/homl3/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"0.90997150997151"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import cross_val_score\n",
"\n",
"cross_val_score(lin_clf, X_train, y_train).mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Well, 91% accuracy on this dataset is not great. So, did you guess what the problem is?\n",
"\n",
"That's right, we forgot to scale the features! Always remember to scale the features when using SVMs:"
]
},
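{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see why scaling matters so much here, the following quick check compares the standard deviations of the Wine features. It reloads the dataset so that it runs on its own, and `wine_features` and `feature_std` are illustrative names. The spreads differ by several orders of magnitude, so without scaling, the SVM would tend to neglect the features with a small spread:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import load_wine\n",
"\n",
"wine_features = load_wine(as_frame=True).data\n",
"feature_std = wine_features.std()\n",
"\n",
"# Features with the largest and the smallest standard deviation\n",
"print(feature_std.idxmax(), round(feature_std.max(), 1))\n",
"print(feature_std.idxmin(), round(feature_std.min(), 2))\n",
"# Ratio between the largest and the smallest spread\n",
"print(round(feature_std.max() / feature_std.min()))"
]
},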
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(steps=[('standardscaler', StandardScaler()),\n",
" ('linearsvc', LinearSVC(random_state=42))])"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lin_clf = make_pipeline(StandardScaler(),\n",
" LinearSVC(random_state=42))\n",
"lin_clf.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now it converges without any problem. Let's measure its performance:"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9774928774928775"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import cross_val_score\n",
"\n",
"cross_val_score(lin_clf, X_train, y_train).mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Nice! We get 97.7% accuracy, which is much better."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's see if a kernelized SVM will do better. We will use a default `SVC` for now:"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9698005698005698"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"svm_clf = make_pipeline(StandardScaler(), SVC(random_state=42))\n",
"cross_val_score(svm_clf, X_train, y_train).mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That's not better, but perhaps we need to do a bit of hyperparameter tuning:"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(steps=[('standardscaler', StandardScaler()),\n",
" ('svc',\n",
" SVC(C=9.925589984899778, gamma=0.011986281799901176,\n",
" random_state=42))])"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import RandomizedSearchCV\n",
"from scipy.stats import reciprocal, uniform\n",
"\n",
"param_distrib = {\n",
" \"svc__gamma\": reciprocal(0.001, 0.1),\n",
" \"svc__C\": uniform(1, 10)\n",
"}\n",
"rnd_search_cv = RandomizedSearchCV(svm_clf, param_distrib, n_iter=100, cv=5,\n",
" random_state=42)\n",
"rnd_search_cv.fit(X_train, y_train)\n",
"rnd_search_cv.best_estimator_"
]
},
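{
"cell_type": "markdown",
"metadata": {},
"source": [
"The search above samples `gamma` from `reciprocal(0.001, 0.1)`, a log-uniform distribution, so each order of magnitude between 0.001 and 0.1 is roughly equally likely. By contrast, `uniform(1, 10)` samples `C` uniformly between 1 and 11, because SciPy's `uniform` takes a `loc` and a `scale` rather than a minimum and a maximum. The next cell is a minimal sketch that checks both points empirically; the sample size, the seed, and the variable names are arbitrary choices made for this illustration:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from scipy.stats import reciprocal, uniform\n",
"\n",
"# Same distributions as in the randomized search above\n",
"gamma_distrib = reciprocal(0.001, 0.1)  # log-uniform over [0.001, 0.1]\n",
"C_distrib = uniform(1, 10)  # uniform over [loc, loc + scale] = [1, 11]\n",
"\n",
"gamma_samples = gamma_distrib.rvs(size=10_000, random_state=42)\n",
"C_samples = C_distrib.rvs(size=10_000, random_state=42)\n",
"\n",
"# One decade each: about half of the gamma samples should fall below 0.01\n",
"frac_low_decade = (gamma_samples < 0.01).mean()\n",
"print(f\"Fraction of gamma samples below 0.01: {frac_low_decade:.2f}\")\n",
"print(f\"Range of C samples: [{C_samples.min():.2f}, {C_samples.max():.2f}]\")"
]
},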
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9925925925925926"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rnd_search_cv.best_score_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ah, this looks excellent! Let's select this model. Now we can test it on the test set:"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9777777777777777"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rnd_search_cv.score(X_test, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This tuned kernelized SVM performs better than the `LinearSVC` model, but we get a lower score on the test set (97.8%) than we measured using cross-validation (99.3%). This is quite common: since we did so much hyperparameter tuning, we ended up slightly overfitting the validation folds used during cross-validation. It's tempting to tweak the hyperparameters a bit more until we get a better result on the test set, but this would probably not help, as we would just start overfitting the test set. Anyway, this score is not bad at all, so let's stop here."
]
},
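{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to estimate the performance of the whole tuning procedure without that optimistic bias is *nested cross-validation*: an outer cross-validation loop evaluates the search itself, which is rerun from scratch on each outer training split, so each outer validation fold is never used for tuning. The next cell is a minimal sketch of this technique, not part of the original solution. It reloads the full Wine dataset with `load_wine()` (so, unlike the cells above, it also includes the test instances and is only an illustration), and it uses a smaller search (`n_iter=20`, `cv=3`) to keep it fast; `X_wine`, `y_wine`, `inner_search`, and `nested_scores` are illustrative names:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from scipy.stats import reciprocal, uniform\n",
"from sklearn.datasets import load_wine\n",
"from sklearn.model_selection import RandomizedSearchCV, cross_val_score\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.svm import SVC\n",
"\n",
"X_wine, y_wine = load_wine(return_X_y=True)\n",
"\n",
"# Inner loop: a smaller randomized search over the same distributions\n",
"inner_search = RandomizedSearchCV(\n",
"    make_pipeline(StandardScaler(), SVC(random_state=42)),\n",
"    {\"svc__gamma\": reciprocal(0.001, 0.1), \"svc__C\": uniform(1, 10)},\n",
"    n_iter=20, cv=3, random_state=42)\n",
"\n",
"# Outer loop: each outer fold reruns the whole search on its training part,\n",
"# so the outer validation fold is never seen during tuning\n",
"nested_scores = cross_val_score(inner_search, X_wine, y_wine, cv=5)\n",
"nested_scores.mean()"
]
},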
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 11."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_Exercise: Train and fine-tune an SVM regressor on the California housing dataset. You can use the original dataset rather than the tweaked version we used in Chapter 2. The original dataset can be fetched using `sklearn.datasets.fetch_california_housing()`. The targets represent hundreds of thousands of dollars. Since there are over 20,000 instances, SVMs can be slow, so for hyperparameter tuning you should use far fewer instances (e.g., 2,000) so that you can test many more hyperparameter combinations. What is your best model's RMSE?_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's load the dataset:"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import fetch_california_housing\n",
"\n",
"housing = fetch_california_housing()\n",
"X = housing.data\n",
"y = housing.target"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Split it into a training set and a test set:"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,\n",
" random_state=42)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Don't forget to scale the data! We will do this by placing a `StandardScaler` at the start of each pipeline."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's train a simple `LinearSVR` first:"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/ageron/miniconda3/envs/homl3/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"Pipeline(steps=[('standardscaler', StandardScaler()),\n",
" ('linearsvr', LinearSVR(random_state=42))])"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.svm import LinearSVR\n",
"\n",
"lin_svr = make_pipeline(StandardScaler(), LinearSVR(random_state=42))\n",
"lin_svr.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It did not converge, so let's increase `max_iter`:"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(steps=[('standardscaler', StandardScaler()),\n",
" ('linearsvr', LinearSVR(max_iter=5000, random_state=42))])"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lin_svr = make_pipeline(StandardScaler(),\n",
" LinearSVR(max_iter=5000, random_state=42))\n",
"lin_svr.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's see how it performs on the training set:"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9595484665813285"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import mean_squared_error\n",
"\n",
"y_pred = lin_svr.predict(X_train)\n",
"mse = mean_squared_error(y_train, y_pred)\n",
"mse"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's look at the RMSE:"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.979565447829459"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.sqrt(mse)"
]
},
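{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since the targets are expressed in units of 100,000 dollars, converting an RMSE into dollars is a single multiplication. Here is a minimal sketch; the RMSE value is copied from the output above (rather than recomputed) so that this cell does not depend on the previous ones, and `target_unit` is an illustrative name:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rmse_linear = 0.979565447829459  # training RMSE of the LinearSVR, copied from above\n",
"target_unit = 100_000  # one target unit is 100,000 dollars\n",
"rmse_linear_dollars = rmse_linear * target_unit\n",
"round(rmse_linear_dollars)  # roughly 98,000 dollars"
]
},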
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this dataset, the targets represent hundreds of thousands of dollars. The RMSE gives a rough idea of the kind of error you should expect (with a higher weight for large errors), so with this model we can expect errors of around $98,000, even on the training set! Not great. Let's see if we can do better with an RBF kernel. We will use randomized search with cross-validation to find appropriate hyperparameter values for `C` and `gamma`:"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"RandomizedSearchCV(cv=3,\n",
" estimator=Pipeline(steps=[('standardscaler',\n",
" StandardScaler()),\n",
" ('svr', SVR())]),\n",
" n_iter=100,\n",
" param_distributions={'svr__C': <scipy.stats._distn_infrastructure.rv_frozen object at 0x7ff030704ee0>,\n",
" 'svr__gamma': <scipy.stats._distn_infrastructure.rv_frozen object at 0x7ff030704fd0>},\n",
" random_state=42)"
]
},
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.svm import SVR\n",
"from sklearn.model_selection import RandomizedSearchCV\n",
"from scipy.stats import reciprocal, uniform\n",
"\n",
"# An SVM regressor (RBF kernel by default), preceded by feature scaling\n",
"svm_reg = make_pipeline(StandardScaler(), SVR())\n",
"\n",
"param_distrib = {\n",
" \"svr__gamma\": reciprocal(0.001, 0.1),\n",
" \"svr__C\": uniform(1, 10)\n",
"}\n",
"rnd_search_cv = RandomizedSearchCV(svm_reg, param_distrib,\n",
" n_iter=100, cv=3, random_state=42)\n",
"# Only use the first 2,000 training instances to keep the search fast\n",
"rnd_search_cv.fit(X_train[:2000], y_train[:2000])"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(steps=[('standardscaler', StandardScaler()),\n",
" ('svr', SVR(C=4.63629602379294, gamma=0.08781408196485974))])"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rnd_search_cv.best_estimator_"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.58835648, 0.57468589, 0.58085278, 0.57109886, 0.59853029])"
]
},
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"-cross_val_score(rnd_search_cv.best_estimator_, X_train, y_train,\n",
" scoring=\"neg_root_mean_squared_error\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A cross-validation RMSE of about 0.58 looks much better than the linear model's training RMSE of about 0.98. Let's select this model and evaluate it on the test set:"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.5854732265172222"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_pred = rnd_search_cv.best_estimator_.predict(X_test)\n",
"# Take the square root explicitly: the squared=False argument is deprecated\n",
"# in recent Scikit-Learn versions\n",
"rmse = np.sqrt(mean_squared_error(y_test, y_pred))\n",
"rmse"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So SVMs worked very well on the Wine dataset, but not as well on the California housing dataset. In Chapter 2, we found that Random Forests worked better for that dataset."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And that's all for today!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
},
"nav_menu": {},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}