{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Chapter 4 – Training Models**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_This notebook contains all the sample code and solutions to the exercises in chapter 4._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table align=\"left\">\n",
" <td>\n",
" <a href=\"https://colab.research.google.com/github/ageron/handson-ml3/blob/main/04_training_linear_models.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/ageron/handson-ml3/blob/main/04_training_linear_models.ipynb\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" /></a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This project requires Python 3.7 or above:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"assert sys.version_info >= (3, 7)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It also requires Scikit-Learn ≥ 1.0.1:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from packaging import version\n",
"import sklearn\n",
"\n",
"assert version.parse(sklearn.__version__) >= version.parse(\"1.0.1\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we did in previous chapters, let's define the default font sizes to make the figures prettier:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"plt.rc('font', size=14)\n",
"plt.rc('axes', labelsize=14, titlesize=14)\n",
"plt.rc('legend', fontsize=14)\n",
"plt.rc('xtick', labelsize=10)\n",
"plt.rc('ytick', labelsize=10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And let's create the `images/training_linear_models` folder (if it doesn't already exist), and define the `save_fig()` function, which is used throughout this notebook to save the figures in high resolution for the book:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"IMAGES_PATH = Path() / \"images\" / \"training_linear_models\"\n",
"IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
"    path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
"    if tight_layout:\n",
"        plt.tight_layout()\n",
"    plt.savefig(path, format=fig_extension, dpi=resolution)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Linear Regression"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The Normal Equation"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"np.random.seed(42) # to make this code example reproducible\n",
"m = 100 # number of instances\n",
"X = 2 * np.random.rand(m, 1) # column vector\n",
"y = 4 + 3 * X + np.random.randn(m, 1) # column vector"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – generates and saves Figure 4–1\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"plt.figure(figsize=(6, 4))\n",
"plt.plot(X, y, \"b.\")\n",
"plt.xlabel(\"$x_1$\")\n",
"plt.ylabel(\"$y$\", rotation=0)\n",
"plt.axis([0, 2, 0, 15])\n",
"plt.grid()\n",
"save_fig(\"generated_data_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.preprocessing import add_dummy_feature\n",
"\n",
"X_b = add_dummy_feature(X) # add x0 = 1 to each instance\n",
"theta_best = np.linalg.inv(X_b.T @ X_b) @ X_b.T @ y"
]
},
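Explicitly inverting XᵀX works for this small problem, but a common alternative (a sketch, not the book's code) is to solve the linear system (XᵀX)θ = Xᵀy directly with `np.linalg.solve()`, which avoids forming the inverse. The block recreates the same synthetic data so it runs on its own; `theta_solve` and `theta_inv` are hypothetical names, and `np.c_` is used in place of `add_dummy_feature()` to prepend the bias column.

```python
import numpy as np

# Recreate the same synthetic data as above so this block is self-contained
np.random.seed(42)
m = 100
X = 2 * np.random.rand(m, 1)
y = 4 + 3 * X + np.random.randn(m, 1)

# Prepend the bias feature x0 = 1 (same effect as add_dummy_feature)
X_b = np.c_[np.ones((m, 1)), X]

# Normal Equation: solve (X^T X) theta = X^T y instead of inverting X^T X
theta_solve = np.linalg.solve(X_b.T @ X_b, X_b.T @ y)

# Compare with the explicit-inverse version used above
theta_inv = np.linalg.inv(X_b.T @ X_b) @ X_b.T @ y
print(np.allclose(theta_solve, theta_inv))  # True
```

Both give the same parameters here; `solve()` is generally the more numerically careful choice when XᵀX is poorly conditioned.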
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[4.21509616],\n",
" [2.77011339]])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"theta_best"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[4.21509616],\n",
" [9.75532293]])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_new = np.array([[0], [2]])\n",
"X_new_b = add_dummy_feature(X_new) # add x0 = 1 to each instance\n",
"y_predict = X_new_b @ theta_best\n",
"y_predict"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"plt.figure(figsize=(6, 4))  # extra code – not needed, just formatting\n",
"plt.plot(X_new, y_predict, \"r-\", label=\"Predictions\")\n",
"plt.plot(X, y, \"b.\")\n",
"\n",
"# extra code – beautifies and saves Figure 4–2\n",
"plt.xlabel(\"$x_1$\")\n",
"plt.ylabel(\"$y$\", rotation=0)\n",
"plt.axis([0, 2, 0, 15])\n",
"plt.grid()\n",
"plt.legend(loc=\"upper left\")\n",
"save_fig(\"linear_model_predictions_plot\")\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([4.21509616]), array([[2.77011339]]))"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.linear_model import LinearRegression\n",
"\n",
"lin_reg = LinearRegression()\n",
"lin_reg.fit(X, y)\n",
"lin_reg.intercept_, lin_reg.coef_"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[4.21509616],\n",
" [9.75532293]])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lin_reg.predict(X_new)"
]
},
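To check that `LinearRegression` found the same parameters as the Normal Equation, and to get a number for how well the line fits, you could compare the two and compute the training-set mean squared error with `sklearn.metrics.mean_squared_error()`. This is a sketch; the data is regenerated so the block stands alone, and `mse` is a hypothetical name.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

np.random.seed(42)
m = 100
X = 2 * np.random.rand(m, 1)
y = 4 + 3 * X + np.random.randn(m, 1)

lin_reg = LinearRegression().fit(X, y)

# Normal Equation on the same data (bias column prepended by hand)
X_b = np.c_[np.ones((m, 1)), X]
theta_best = np.linalg.inv(X_b.T @ X_b) @ X_b.T @ y

# intercept_ plays the role of theta_0 and coef_ the role of theta_1
print(np.allclose([lin_reg.intercept_[0], lin_reg.coef_[0, 0]],
                  theta_best.ravel()))  # True

# Training MSE; the Gaussian noise added to y has variance 1
mse = mean_squared_error(y, lin_reg.predict(X))
print(mse)
```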
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `LinearRegression` class is based on the `scipy.linalg.lstsq()` function (the name stands for \"least squares\"). You could call it directly, or use NumPy's equivalent, `np.linalg.lstsq()`, as below:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[4.21509616],\n",
" [2.77011339]])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"theta_best_svd, residuals, rank, s = np.linalg.lstsq(X_b, y, rcond=1e-6)\n",
"theta_best_svd"
]
},
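One practical reason to prefer `lstsq()` over the explicit inverse: if a feature is duplicated, XᵀX becomes singular and has no inverse. The sketch below deliberately duplicates the feature column (an artificial setup for illustration) and shows that `lstsq()` still returns a solution while reporting the reduced rank through its `rank` and `s` (singular values) outputs.

```python
import numpy as np

np.random.seed(42)
m = 100
X = 2 * np.random.rand(m, 1)
y = 4 + 3 * X + np.random.randn(m, 1)

# Bias column plus the same feature twice: the columns are linearly dependent
X_dup = np.c_[np.ones((m, 1)), X, X]

theta_dup, residuals, rank, s = np.linalg.lstsq(X_dup, y, rcond=None)
print(rank)  # 2, not 3: one column is redundant
print(s)     # the smallest singular value should be numerically close to 0

# lstsq returns the minimum-norm solution, so the slope is split evenly
# between the two identical columns; their sum matches the usual slope
print(theta_dup.ravel())
```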
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This function computes $\\mathbf{X}^+\\mathbf{y}$, where $\\mathbf{X}^{+}$ is the _pseudoinverse_ of $\\mathbf{X}$ (specifically the Moore-Penrose inverse). You can use `np.linalg.pinv()` to compute the pseudoinverse directly:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[4.21509616],\n",
" [2.77011339]])"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.linalg.pinv(X_b) @ y"
]
},
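`np.linalg.pinv()` computes the pseudoinverse using a Singular Value Decomposition. The sketch below shows the idea in simplified form: decompose X = UΣVᵀ, invert only the singular values that are not negligible, and form X⁺ = VΣ⁺Uᵀ. The helper name `pseudoinverse` and its relative `cutoff` of 1e-10 are assumptions for illustration, not NumPy's exact default.

```python
import numpy as np

np.random.seed(42)
m = 100
X = 2 * np.random.rand(m, 1)
y = 4 + 3 * X + np.random.randn(m, 1)
X_b = np.c_[np.ones((m, 1)), X]

def pseudoinverse(A, cutoff=1e-10):
    """Hypothetical helper: Moore-Penrose pseudoinverse via SVD."""
    U, s, Vt = np.linalg.svd(A, full_matrices=False)
    # Invert the singular values, zeroing out the negligible ones
    s_inv = np.zeros_like(s)
    mask = s > cutoff * s.max()
    s_inv[mask] = 1.0 / s[mask]
    return Vt.T @ np.diag(s_inv) @ U.T

X_pinv = pseudoinverse(X_b)
print(np.allclose(X_pinv, np.linalg.pinv(X_b)))  # True
print(X_pinv @ y)  # same parameters as np.linalg.pinv(X_b) @ y
```

Zeroing tiny singular values is what lets the pseudoinverse exist even when XᵀX is singular.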
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Gradient Descent\n",
"## Batch Gradient Descent"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"eta = 0.1 # learning rate\n",
"n_epochs = 1000\n",
"m = len(X_b) # number of instances\n",
"\n",
"np.random.seed(42)\n",
"theta = np.random.randn(2, 1) # randomly initialized model parameters\n",
"\n",
"for epoch in range(n_epochs):\n",
"    gradients = 2 / m * X_b.T @ (X_b @ theta - y)\n",
"    theta = theta - eta * gradients"
]
},
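The loop above relies on the closed-form gradient of the MSE cost, ∇θ MSE(θ) = (2/m) Xᵀ(Xθ − y). One way to convince yourself that this formula is right is to compare it with a central finite-difference approximation at a random θ. This is a sketch; `mse` and `numerical_gradient` are hypothetical helpers.

```python
import numpy as np

np.random.seed(42)
m = 100
X = 2 * np.random.rand(m, 1)
y = 4 + 3 * X + np.random.randn(m, 1)
X_b = np.c_[np.ones((m, 1)), X]

def mse(theta):
    """MSE cost of the linear model with parameters theta (shape (2, 1))."""
    return np.mean((X_b @ theta - y) ** 2)

def numerical_gradient(f, theta, eps=1e-6):
    """Central finite differences, one parameter at a time."""
    grad = np.zeros_like(theta)
    for i in range(theta.shape[0]):
        step = np.zeros_like(theta)
        step[i] = eps
        grad[i] = (f(theta + step) - f(theta - step)) / (2 * eps)
    return grad

theta = np.random.randn(2, 1)
analytic = 2 / m * X_b.T @ (X_b @ theta - y)
numeric = numerical_gradient(mse, theta)
print(np.allclose(analytic, numeric, atol=1e-5))  # True
```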
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The trained model parameters:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[4.21509616],\n",
" [2.77011339]])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"theta"
]
},
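Running a fixed 1,000 epochs is simple, but a common alternative is to stop as soon as the gradient vector becomes tiny, because at that point the parameters barely move anymore. A sketch under that assumption (the `tolerance` value and the `max_epochs` cap are arbitrary choices, not from the book):

```python
import numpy as np

np.random.seed(42)
m = 100
X = 2 * np.random.rand(m, 1)
y = 4 + 3 * X + np.random.randn(m, 1)
X_b = np.c_[np.ones((m, 1)), X]

eta = 0.1
tolerance = 1e-6   # stop when the gradient norm drops below this
max_epochs = 10_000

theta = np.random.randn(2, 1)  # random initialization
for epoch in range(max_epochs):
    gradients = 2 / m * X_b.T @ (X_b @ theta - y)
    if np.linalg.norm(gradients) < tolerance:
        break  # converged: further steps would change theta very little
    theta = theta - eta * gradients

print(epoch, theta.ravel())
```

A tighter tolerance gives a more precise answer at the cost of more epochs.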
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 720x288 with 3 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – generates and saves Figure 4–8\n",
"\n",
"import matplotlib as mpl\n",
"\n",
"def plot_gradient_descent(theta, eta):\n",
"    m = len(X_b)\n",
"    plt.plot(X, y, \"b.\")\n",
"    n_epochs = 1000\n",
"    n_shown = 20\n",
"    theta_path = []\n",
"    for epoch in range(n_epochs):\n",
"        if epoch < n_shown:\n",
"            y_predict = X_new_b @ theta\n",
"            color = mpl.colors.rgb2hex(plt.cm.OrRd(epoch / n_shown + 0.15))\n",
"            plt.plot(X_new, y_predict, linestyle=\"solid\", color=color)\n",
"        gradients = 2 / m * X_b.T @ (X_b @ theta - y)\n",
"        theta = theta - eta * gradients\n",
"        theta_path.append(theta)\n",
"    plt.xlabel(\"$x_1$\")\n",
"    plt.axis([0, 2, 0, 15])\n",
"    plt.grid()\n",
"    plt.title(fr\"$\\eta = {eta}$\")\n",
"    return theta_path\n",
"\n",
"np.random.seed(42)\n",
"theta = np.random.randn(2, 1) # random initialization\n",
"\n",
"plt.figure(figsize=(10, 4))\n",
"plt.subplot(131)\n",
"plot_gradient_descent(theta, eta=0.02)\n",
"plt.ylabel(\"$y$\", rotation=0)\n",
"plt.subplot(132)\n",
"theta_path_bgd = plot_gradient_descent(theta, eta=0.1)\n",
"plt.gca().axes.yaxis.set_ticklabels([])\n",
"plt.subplot(133)\n",
"plt.gca().axes.yaxis.set_ticklabels([])\n",
"plot_gradient_descent(theta, eta=0.5)\n",
"save_fig(\"gradient_descent_plot\")\n",
"plt.show()"
]
},
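Why does η = 0.5 diverge while η = 0.1 converges? For the MSE cost the Hessian is constant, H = (2/m) XᵀX, and batch Gradient Descent on a quadratic cost converges from any starting point only if η < 2/λ_max, where λ_max is the largest eigenvalue of H. The sketch below computes that threshold for this dataset; `eta_limit` is a hypothetical name.

```python
import numpy as np

np.random.seed(42)
m = 100
X = 2 * np.random.rand(m, 1)
y = 4 + 3 * X + np.random.randn(m, 1)
X_b = np.c_[np.ones((m, 1)), X]

# Hessian of the MSE cost (it does not depend on theta for a linear model)
H = 2 / m * X_b.T @ X_b
lambda_max = np.linalg.eigvalsh(H).max()
eta_limit = 2 / lambda_max

print(f"eta must be below about {eta_limit:.3f}")
for eta in (0.02, 0.1, 0.5):
    print(eta, "converges" if eta < eta_limit else "diverges")
```

The threshold depends only on the inputs X, which is one reason feature scaling affects which learning rates are usable.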
{
"cell_type": "markdown",
"metadata": {},
"source": [
2021-10-02 13:14:44 +02:00
"## Stochastic Gradient Descent"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"theta_path_sgd = []  # extra code – we need to store the path of theta in the\n",
"                     #              parameter space to plot the next figure"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"n_epochs = 50\n",
"t0, t1 = 5, 50 # learning schedule hyperparameters\n",
"\n",
"def learning_schedule(t):\n",
"    return t0 / (t + t1)\n",
"\n",
"np.random.seed(42)\n",
"theta = np.random.randn(2, 1) # random initialization\n",
"\n",
"n_shown = 20  # extra code – just needed to generate the figure below\n",
"plt.figure(figsize=(6, 4))  # extra code – not needed, just formatting\n",
"\n",
"for epoch in range(n_epochs):\n",
"    for iteration in range(m):\n",
"\n",
"        # extra code – these 4 lines are used to generate the figure\n",
"        if epoch == 0 and iteration < n_shown:\n",
"            y_predict = X_new_b @ theta\n",
"            color = mpl.colors.rgb2hex(plt.cm.OrRd(iteration / n_shown + 0.15))\n",
"            plt.plot(X_new, y_predict, color=color)\n",
"\n",
"        random_index = np.random.randint(m)\n",
"        xi = X_b[random_index : random_index + 1]\n",
"        yi = y[random_index : random_index + 1]\n",
"        gradients = 2 * xi.T @ (xi @ theta - yi)  # for SGD, do not divide by m\n",
"        eta = learning_schedule(epoch * m + iteration)\n",
"        theta = theta - eta * gradients\n",
"        theta_path_sgd.append(theta)  # extra code – to generate the figure\n",
"\n",
"# extra code – this section beautifies and saves Figure 4–10\n",
"plt.plot(X, y, \"b.\")\n",
"plt.xlabel(\"$x_1$\")\n",
"plt.ylabel(\"$y$\", rotation=0)\n",
"plt.axis([0, 2, 0, 15])\n",
"plt.grid()\n",
"save_fig(\"sgd_plot\")\n",
"plt.show()"
]
},
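The loop above picks instances at random, so within an epoch some instances may be picked several times while others are never picked. An alternative is to shuffle the training set at the start of each epoch and then go through it instance by instance. The sketch below implements that variant with the same learning schedule; using `np.random.default_rng()` and `rng.permutation()` is an implementation choice made here, not the book's code.

```python
import numpy as np

np.random.seed(42)
m = 100
X = 2 * np.random.rand(m, 1)
y = 4 + 3 * X + np.random.randn(m, 1)
X_b = np.c_[np.ones((m, 1)), X]

n_epochs = 50
t0, t1 = 5, 50  # same learning schedule hyperparameters as above

def learning_schedule(t):
    return t0 / (t + t1)

rng = np.random.default_rng(42)
theta = rng.standard_normal((2, 1))  # random initialization

for epoch in range(n_epochs):
    shuffled_indices = rng.permutation(m)  # visit every instance once per epoch
    for iteration, idx in enumerate(shuffled_indices):
        xi = X_b[idx : idx + 1]
        yi = y[idx : idx + 1]
        gradients = 2 * xi.T @ (xi @ theta - yi)
        eta = learning_schedule(epoch * m + iteration)
        theta = theta - eta * gradients

print(theta.ravel())  # should land close to the Normal Equation solution
```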
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"array([[4.21076011],\n",
" [2.74856079]])"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"theta"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"SGDRegressor(n_iter_no_change=100, penalty=None, random_state=42, tol=1e-05)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.linear_model import SGDRegressor\n",
"\n",
"sgd_reg = SGDRegressor(max_iter=1000, tol=1e-5, penalty=None, eta0=0.01,\n",
" n_iter_no_change=100, random_state=42)\n",
"sgd_reg.fit(X, y.ravel()) # y.ravel() because fit() expects 1D targets\n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([4.21278812]), array([2.77270267]))"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sgd_reg.intercept_, sgd_reg.coef_"
]
},
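With these settings, `SGDRegressor` stops once the training loss fails to improve by at least `tol` for `n_iter_no_change` consecutive epochs, so it may run fewer than `max_iter` epochs. The sketch below refits the same model on regenerated data, reads the number of epochs it actually ran from the `n_iter_` attribute, and compares its parameters with `LinearRegression`.

```python
import numpy as np
from sklearn.linear_model import LinearRegression, SGDRegressor

np.random.seed(42)
m = 100
X = 2 * np.random.rand(m, 1)
y = 4 + 3 * X + np.random.randn(m, 1)

sgd_reg = SGDRegressor(max_iter=1000, tol=1e-5, penalty=None, eta0=0.01,
                       n_iter_no_change=100, random_state=42)
sgd_reg.fit(X, y.ravel())
lin_reg = LinearRegression().fit(X, y.ravel())

print(sgd_reg.n_iter_)  # epochs actually run (at most max_iter)
print(sgd_reg.intercept_, sgd_reg.coef_)
print(lin_reg.intercept_, lin_reg.coef_)
```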
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Mini-batch gradient descent"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "The code in this section generates the next figure; it is not in the book."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 504x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
    "# extra code: this cell generates and saves Figure 4-11\n",
"\n",
"from math import ceil\n",
"\n",
"n_epochs = 50\n",
"minibatch_size = 20\n",
"n_batches_per_epoch = ceil(m / minibatch_size)\n",
"\n",
"np.random.seed(42)\n",
"theta = np.random.randn(2, 1) # random initialization\n",
"\n",
"t0, t1 = 200, 1000 # learning schedule hyperparameters\n",
"\n",
"def learning_schedule(t):\n",
" return t0 / (t + t1)\n",
"\n",
"theta_path_mgd = []\n",
"for epoch in range(n_epochs):\n",
" shuffled_indices = np.random.permutation(m)\n",
" X_b_shuffled = X_b[shuffled_indices]\n",
" y_shuffled = y[shuffled_indices]\n",
" for iteration in range(0, n_batches_per_epoch):\n",
" idx = iteration * minibatch_size\n",
" xi = X_b_shuffled[idx : idx + minibatch_size]\n",
" yi = y_shuffled[idx : idx + minibatch_size]\n",
" gradients = 2 / minibatch_size * xi.T @ (xi @ theta - yi)\n",
    "        eta = learning_schedule(epoch * n_batches_per_epoch + iteration)  # global step, so eta keeps decaying across epochs\n",
" theta = theta - eta * gradients\n",
" theta_path_mgd.append(theta)\n",
"\n",
"theta_path_bgd = np.array(theta_path_bgd)\n",
"theta_path_sgd = np.array(theta_path_sgd)\n",
"theta_path_mgd = np.array(theta_path_mgd)\n",
"\n",
"plt.figure(figsize=(7, 4))\n",
"plt.plot(theta_path_sgd[:, 0], theta_path_sgd[:, 1], \"r-s\", linewidth=1,\n",
" label=\"Stochastic\")\n",
"plt.plot(theta_path_mgd[:, 0], theta_path_mgd[:, 1], \"g-+\", linewidth=2,\n",
" label=\"Mini-batch\")\n",
"plt.plot(theta_path_bgd[:, 0], theta_path_bgd[:, 1], \"b-o\", linewidth=3,\n",
" label=\"Batch\")\n",
"plt.legend(loc=\"upper left\")\n",
"plt.xlabel(r\"$\\theta_0$\")\n",
"plt.ylabel(r\"$\\theta_1$ \", rotation=0)\n",
"plt.axis([2.6, 4.6, 2.3, 3.4])\n",
"plt.grid()\n",
"save_fig(\"gradient_descent_paths_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Polynomial Regression"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(42)\n",
"m = 100\n",
"X = 6 * np.random.rand(m, 1) - 3\n",
"y = 0.5 * X ** 2 + X + 2 + np.random.randn(m, 1)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
    "# extra code: this cell generates and saves Figure 4-12\n",
"plt.figure(figsize=(6, 4))\n",
"plt.plot(X, y, \"b.\")\n",
"plt.xlabel(\"$x_1$\")\n",
"plt.ylabel(\"$y$\", rotation=0)\n",
"plt.axis([-3, 3, 0, 10])\n",
"plt.grid()\n",
"save_fig(\"quadratic_data_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([-0.75275929])"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.preprocessing import PolynomialFeatures\n",
"\n",
"poly_features = PolynomialFeatures(degree=2, include_bias=False)\n",
"X_poly = poly_features.fit_transform(X)\n",
"X[0]"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([-0.75275929, 0.56664654])"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_poly[0]"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([1.78134581]), array([[0.93366893, 0.56456263]]))"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lin_reg = LinearRegression()\n",
"lin_reg.fit(X_poly, y)\n",
"lin_reg.intercept_, lin_reg.coef_"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
    "# extra code: this cell generates and saves Figure 4-13\n",
"\n",
"X_new = np.linspace(-3, 3, 100).reshape(100, 1)\n",
"X_new_poly = poly_features.transform(X_new)\n",
"y_new = lin_reg.predict(X_new_poly)\n",
"\n",
"plt.figure(figsize=(6, 4))\n",
"plt.plot(X, y, \"b.\")\n",
"plt.plot(X_new, y_new, \"r-\", linewidth=2, label=\"Predictions\")\n",
"plt.xlabel(\"$x_1$\")\n",
"plt.ylabel(\"$y$\", rotation=0)\n",
"plt.legend(loc=\"upper left\")\n",
"plt.axis([-3, 3, 0, 10])\n",
"plt.grid()\n",
"save_fig(\"quadratic_predictions_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
    "# extra code: this cell generates and saves Figure 4-14\n",
"\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.pipeline import make_pipeline\n",
"\n",
"plt.figure(figsize=(6, 4))\n",
"\n",
"for style, width, degree in ((\"r-+\", 2, 1), (\"b--\", 2, 2), (\"g-\", 1, 300)):\n",
" polybig_features = PolynomialFeatures(degree=degree, include_bias=False)\n",
" std_scaler = StandardScaler()\n",
" lin_reg = LinearRegression()\n",
" polynomial_regression = make_pipeline(polybig_features, std_scaler, lin_reg)\n",
" polynomial_regression.fit(X, y)\n",
" y_newbig = polynomial_regression.predict(X_new)\n",
" label = f\"{degree} degree{'s' if degree > 1 else ''}\"\n",
" plt.plot(X_new, y_newbig, style, label=label, linewidth=width)\n",
"\n",
"plt.plot(X, y, \"b.\", linewidth=3)\n",
"plt.legend(loc=\"upper left\")\n",
"plt.xlabel(\"$x_1$\")\n",
"plt.ylabel(\"$y$\", rotation=0)\n",
"plt.axis([-3, 3, 0, 10])\n",
"plt.grid()\n",
"save_fig(\"high_degree_polynomials_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Learning Curves"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from sklearn.model_selection import learning_curve\n",
"\n",
"train_sizes, train_scores, valid_scores = learning_curve(\n",
" LinearRegression(), X, y, train_sizes=np.linspace(0.01, 1.0, 40), cv=5,\n",
" scoring=\"neg_root_mean_squared_error\")\n",
"train_errors = -train_scores.mean(axis=1)\n",
"valid_errors = -valid_scores.mean(axis=1)\n",
"\n",
    "plt.figure(figsize=(6, 4))  # extra code: not needed, just formatting\n",
"plt.plot(train_sizes, train_errors, \"r-+\", linewidth=2, label=\"train\")\n",
"plt.plot(train_sizes, valid_errors, \"b-\", linewidth=3, label=\"valid\")\n",
"\n",
    "# extra code: beautifies and saves Figure 4-15\n",
"plt.xlabel(\"Training set size\")\n",
"plt.ylabel(\"RMSE\")\n",
"plt.grid()\n",
"plt.legend(loc=\"upper right\")\n",
"plt.axis([0, 80, 0, 2.5])\n",
"save_fig(\"underfitting_learning_curves_plot\")\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import make_pipeline\n",
"\n",
"polynomial_regression = make_pipeline(\n",
" PolynomialFeatures(degree=10, include_bias=False),\n",
" LinearRegression())\n",
"\n",
"train_sizes, train_scores, valid_scores = learning_curve(\n",
" polynomial_regression, X, y, train_sizes=np.linspace(0.01, 1.0, 40), cv=5,\n",
" scoring=\"neg_root_mean_squared_error\")"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: generates and saves Figure 4-16\n",
"\n",
"train_errors = -train_scores.mean(axis=1)\n",
"valid_errors = -valid_scores.mean(axis=1)\n",
"\n",
"plt.figure(figsize=(6, 4))\n",
"plt.plot(train_sizes, train_errors, \"r-+\", linewidth=2, label=\"train\")\n",
"plt.plot(train_sizes, valid_errors, \"b-\", linewidth=3, label=\"valid\")\n",
"plt.legend(loc=\"upper right\")\n",
"plt.xlabel(\"Training set size\")\n",
"plt.ylabel(\"RMSE\")\n",
"plt.grid()\n",
"plt.axis([0, 80, 0, 2.5])\n",
"save_fig(\"learning_curves_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Regularized Linear Models"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Ridge Regression"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's generate a very small and noisy linear dataset:"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"# extra code: we've done this type of generation several times before\n",
"np.random.seed(42)\n",
"m = 20\n",
"X = 3 * np.random.rand(m, 1)\n",
"y = 1 + 0.5 * X + np.random.randn(m, 1) / 1.5\n",
"X_new = np.linspace(0, 3, 100).reshape(100, 1)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: a quick peek at the dataset we just generated\n",
"plt.figure(figsize=(6, 4))\n",
"plt.plot(X, y, \".\")\n",
"plt.xlabel(\"$x_1$\")\n",
"plt.ylabel(\"$y$ \", rotation=0)\n",
"plt.axis([0, 3, 0, 3.5])\n",
"plt.grid()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1.55325833]])"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.linear_model import Ridge\n",
"\n",
"ridge_reg = Ridge(alpha=0.1, solver=\"cholesky\")\n",
"ridge_reg.fit(X, y)\n",
"ridge_reg.predict([[1.5]])"
]
},
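{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an illustrative aside (a sketch, not code from the book), the next cell shows how the Ridge slope shrinks toward 0 as `alpha` grows, while the bias term is left unregularized. It regenerates the same small dataset with a local `np.random.RandomState(42)`, which draws the same numbers as the `np.random.seed(42)` cell above without touching the global random state. The `_demo` names and the alpha values are arbitrary choices for this example."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.linear_model import Ridge\n",
"\n",
"# same generation as above, using a local RandomState seeded with 42\n",
"rng = np.random.RandomState(42)\n",
"X_demo = 3 * rng.rand(20, 1)\n",
"y_demo = 1 + 0.5 * X_demo + rng.randn(20, 1) / 1.5\n",
"\n",
"alphas_demo = [0.1, 1, 10, 100]  # arbitrary regularization strengths\n",
"for alpha_demo in alphas_demo:\n",
"    ridge_demo = Ridge(alpha=alpha_demo, solver=\"cholesky\").fit(X_demo, y_demo)\n",
"    # coef_ has shape (1, 1) because y_demo is 2D; intercept_ is not penalized\n",
"    print(f\"alpha={alpha_demo:>5}: slope={ridge_demo.coef_[0, 0]:.4f}, \"\n",
"          f\"intercept={ridge_demo.intercept_[0]:.4f}\")"
]
},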
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 648x252 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 4-17\n",
"\n",
"def plot_model(model_class, polynomial, alphas, **model_kwargs):\n",
"    plt.plot(X, y, \"b.\", linewidth=3)\n",
"    for alpha, style in zip(alphas, (\"b:\", \"g--\", \"r-\")):\n",
"        if alpha > 0:\n",
"            model = model_class(alpha, **model_kwargs)\n",
"        else:\n",
"            model = LinearRegression()\n",
"        if polynomial:\n",
"            model = make_pipeline(\n",
"                PolynomialFeatures(degree=10, include_bias=False),\n",
"                StandardScaler(),\n",
"                model)\n",
"        model.fit(X, y)\n",
"        y_new_regul = model.predict(X_new)\n",
"        plt.plot(X_new, y_new_regul, style, linewidth=2,\n",
"                 label=fr\"$\\alpha = {alpha}$\")\n",
"    plt.legend(loc=\"upper left\")\n",
"    plt.xlabel(\"$x_1$\")\n",
"    plt.axis([0, 3, 0, 3.5])\n",
"    plt.grid()\n",
"\n",
"plt.figure(figsize=(9, 3.5))\n",
"plt.subplot(121)\n",
"plot_model(Ridge, polynomial=False, alphas=(0, 10, 100), random_state=42)\n",
"plt.ylabel(\"$y$ \", rotation=0)\n",
"plt.subplot(122)\n",
"plot_model(Ridge, polynomial=True, alphas=(0, 10**-5, 1), random_state=42)\n",
"plt.gca().axes.yaxis.set_ticklabels([])\n",
"save_fig(\"ridge_regression_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1.55302613])"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sgd_reg = SGDRegressor(penalty=\"l2\", alpha=0.1 / m, tol=None,\n",
"                       max_iter=1000, eta0=0.01, random_state=42)\n",
"sgd_reg.fit(X, y.ravel()) # y.ravel() because fit() expects 1D targets\n",
"sgd_reg.predict([[1.5]])"
]
},
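{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why `alpha=0.1 / m`? A short derivation, assuming the objective documented for `SGDRegressor` with the default squared-error loss and `penalty=\"l2\"`: the mean of $\\frac{1}{2}$ times the squared errors, plus `alpha` times $\\frac{1}{2}\\lVert \\mathbf{w} \\rVert_2^2$. Writing $\\alpha_{\\text{SGD}}$ for `SGDRegressor`'s `alpha`:\n",
"\n",
"$$\n",
"\\frac{1}{m}\\sum_{i=1}^{m}\\frac{1}{2}\\left(y^{(i)} - \\hat{y}^{(i)}\\right)^2 + \\frac{\\alpha_{\\text{SGD}}}{2}\\lVert \\mathbf{w} \\rVert_2^2\n",
"= \\frac{1}{2m}\\left(\\lVert \\mathbf{y} - \\hat{\\mathbf{y}} \\rVert_2^2 + m\\,\\alpha_{\\text{SGD}}\\lVert \\mathbf{w} \\rVert_2^2\\right)\n",
"$$\n",
"\n",
"Multiplying by the positive constant $2m$ does not move the minimum, so this is the `Ridge` objective $\\lVert \\mathbf{y} - \\hat{\\mathbf{y}} \\rVert_2^2 + \\alpha \\lVert \\mathbf{w} \\rVert_2^2$ with $\\alpha = m\\,\\alpha_{\\text{SGD}}$, i.e., $\\alpha_{\\text{SGD}} = 0.1 / m$ for $\\alpha = 0.1$. In both objectives the bias term is excluded from $\\mathbf{w}$. Since SGD only approximates the minimum, its prediction is close to, but not exactly equal to, the `Ridge` one."
]
},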
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1.55321535]])"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code: show that we get roughly the same solution as earlier when\n",
"# we use Stochastic Average GD (solver=\"sag\")\n",
"ridge_reg = Ridge(alpha=0.1, solver=\"sag\", random_state=42)\n",
"ridge_reg.fit(X, y)\n",
"ridge_reg.predict([[1.5]])"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.97898394],\n",
"       [0.3828496 ]])"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code: shows the closed-form solution of Ridge regression;\n",
"# compare it with the learned parameters of the Ridge model in the next cell\n",
"alpha = 0.1\n",
"A = np.array([[0., 0.], [0., 1.]])\n",
"X_b = np.c_[np.ones(m), X]\n",
"np.linalg.inv(X_b.T @ X_b + alpha * A) @ X_b.T @ y"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([0.97944909]), array([[0.38251084]]))"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ridge_reg.intercept_, ridge_reg.coef_ # extra code"
]
},
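{
"cell_type": "markdown",
"metadata": {},
"source": [
"A tighter check (a sketch, not from the book): the `\"cholesky\"` solver computes the exact minimizer, so its parameters should match the closed-form solution up to floating-point error, unlike the iterative `\"sag\"` solver used above. The data is regenerated locally so the cell runs on its own, and the `_demo` names are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.linear_model import Ridge\n",
"\n",
"rng = np.random.RandomState(42)  # same data as the np.random.seed(42) cell above\n",
"X_demo = 3 * rng.rand(20, 1)\n",
"y_demo = 1 + 0.5 * X_demo + rng.randn(20, 1) / 1.5\n",
"\n",
"alpha_demo = 0.1\n",
"A_demo = np.array([[0., 0.], [0., 1.]])  # top-left 0: the bias is not regularized\n",
"X_b_demo = np.c_[np.ones(len(X_demo)), X_demo]  # add x0 = 1 to each instance\n",
"theta_closed = (np.linalg.inv(X_b_demo.T @ X_b_demo + alpha_demo * A_demo)\n",
"                @ X_b_demo.T @ y_demo)\n",
"\n",
"ridge_chol = Ridge(alpha=alpha_demo, solver=\"cholesky\").fit(X_demo, y_demo)\n",
"theta_chol = np.r_[ridge_chol.intercept_, ridge_chol.coef_.ravel()]\n",
"print(theta_closed.ravel(), theta_chol)\n",
"print(np.allclose(theta_closed.ravel(), theta_chol))"
]
},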
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Lasso Regression"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1.53788174])"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.linear_model import Lasso\n",
"\n",
"lasso_reg = Lasso(alpha=0.1)\n",
"lasso_reg.fit(X, y)\n",
"lasso_reg.predict([[1.5]])"
]
},
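{
"cell_type": "markdown",
"metadata": {},
"source": [
"Lasso tends to set the weights of the least important features exactly to zero. The next cell is a sketch to check this (not code from the book): it fits `Lasso(alpha=0.01)` on the same kind of degree-10, scaled polynomial features used in the right panel of the next figure, then counts the weights that are exactly zero. The `_demo` names and `max_iter=100_000` (raised to reduce convergence warnings) are assumptions for this example, and the exact count depends on the data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.linear_model import Lasso\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import PolynomialFeatures, StandardScaler\n",
"\n",
"rng = np.random.RandomState(42)  # same data as the np.random.seed(42) cell above\n",
"X_demo = 3 * rng.rand(20, 1)\n",
"y_demo = 1 + 0.5 * X_demo.ravel() + rng.randn(20) / 1.5  # 1D targets\n",
"\n",
"lasso_poly = make_pipeline(PolynomialFeatures(degree=10, include_bias=False),\n",
"                           StandardScaler(),\n",
"                           Lasso(alpha=0.01, max_iter=100_000))\n",
"lasso_poly.fit(X_demo, y_demo)\n",
"weights = lasso_poly[-1].coef_  # the Lasso model is the last pipeline step\n",
"print(np.round(weights, 3))\n",
"print(f\"{np.sum(weights == 0)} of {weights.size} weights are exactly zero\")"
]
},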
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 648x252 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 4-18\n",
"plt.figure(figsize=(9, 3.5))\n",
"plt.subplot(121)\n",
"plot_model(Lasso, polynomial=False, alphas=(0, 0.1, 1), random_state=42)\n",
"plt.ylabel(\"$y$ \", rotation=0)\n",
"plt.subplot(122)\n",
"plot_model(Lasso, polynomial=True, alphas=(0, 1e-2, 1), random_state=42)\n",
"plt.gca().axes.yaxis.set_ticklabels([])\n",
"save_fig(\"lasso_regression_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 727.2x576 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this BIG cell generates and saves Figure 4-19\n",
"\n",
"t1a, t1b, t2a, t2b = -1, 3, -1.5, 1.5\n",
"\n",
"t1s = np.linspace(t1a, t1b, 500)\n",
"t2s = np.linspace(t2a, t2b, 500)\n",
"t1, t2 = np.meshgrid(t1s, t2s)\n",
"T = np.c_[t1.ravel(), t2.ravel()]\n",
"Xr = np.array([[1, 1], [1, -1], [1, 0.5]])\n",
"yr = 2 * Xr[:, :1] + 0.5 * Xr[:, 1:]\n",
"\n",
"J = (1 / len(Xr) * ((T @ Xr.T - yr.T) ** 2).sum(axis=1)).reshape(t1.shape)\n",
"\n",
"N1 = np.linalg.norm(T, ord=1, axis=1).reshape(t1.shape)\n",
"N2 = np.linalg.norm(T, ord=2, axis=1).reshape(t1.shape)\n",
"\n",
"t_min_idx = np.unravel_index(J.argmin(), J.shape)\n",
"t1_min, t2_min = t1[t_min_idx], t2[t_min_idx]\n",
"\n",
"t_init = np.array([[0.25], [-1]])\n",
"\n",
"def bgd_path(theta, X, y, l1, l2, core=1, eta=0.05, n_iterations=200):\n",
"    path = [theta]\n",
"    for iteration in range(n_iterations):\n",
"        gradients = (core * 2 / len(X) * X.T @ (X @ theta - y)\n",
"                     + l1 * np.sign(theta) + l2 * theta)\n",
"        theta = theta - eta * gradients\n",
"        path.append(theta)\n",
"    return np.array(path)\n",
"\n",
"fig, axes = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(10.1, 8))\n",
"\n",
"for i, N, l1, l2, title in ((0, N1, 2.0, 0, \"Lasso\"), (1, N2, 0, 2.0, \"Ridge\")):\n",
"    JR = J + l1 * N1 + l2 * 0.5 * N2 ** 2\n",
"\n",
"    tr_min_idx = np.unravel_index(JR.argmin(), JR.shape)\n",
"    t1r_min, t2r_min = t1[tr_min_idx], t2[tr_min_idx]\n",
"\n",
"    levels = np.exp(np.linspace(0, 1, 20)) - 1\n",
"    levelsJ = levels * (J.max() - J.min()) + J.min()\n",
"    levelsJR = levels * (JR.max() - JR.min()) + JR.min()\n",
"    levelsN = np.linspace(0, N.max(), 10)\n",
"\n",
"    path_J = bgd_path(t_init, Xr, yr, l1=0, l2=0)\n",
"    path_JR = bgd_path(t_init, Xr, yr, l1, l2)\n",
"    path_N = bgd_path(theta=np.array([[2.0], [0.5]]), X=Xr, y=yr,\n",
"                      l1=np.sign(l1) / 3, l2=np.sign(l2), core=0)\n",
"    ax = axes[i, 0]\n",
"    ax.grid()\n",
"    ax.axhline(y=0, color=\"k\")\n",
"    ax.axvline(x=0, color=\"k\")\n",
"    ax.contourf(t1, t2, N / 2.0, levels=levelsN)\n",
"    ax.plot(path_N[:, 0], path_N[:, 1], \"y--\")\n",
"    ax.plot(0, 0, \"ys\")\n",
"    ax.plot(t1_min, t2_min, \"ys\")\n",
"    ax.set_title(fr\"$\\ell_{i + 1}$ penalty\")\n",
"    ax.axis([t1a, t1b, t2a, t2b])\n",
"    if i == 1:\n",
"        ax.set_xlabel(r\"$\\theta_1$\")\n",
"    ax.set_ylabel(r\"$\\theta_2$\", rotation=0)\n",
"\n",
"    ax = axes[i, 1]\n",
"    ax.grid()\n",
"    ax.axhline(y=0, color=\"k\")\n",
"    ax.axvline(x=0, color=\"k\")\n",
"    ax.contourf(t1, t2, JR, levels=levelsJR, alpha=0.9)\n",
"    ax.plot(path_JR[:, 0], path_JR[:, 1], \"w-o\")\n",
"    ax.plot(path_N[:, 0], path_N[:, 1], \"y--\")\n",
"    ax.plot(0, 0, \"ys\")\n",
"    ax.plot(t1_min, t2_min, \"ys\")\n",
"    ax.plot(t1r_min, t2r_min, \"rs\")\n",
"    ax.set_title(title)\n",
"    ax.axis([t1a, t1b, t2a, t2b])\n",
"    if i == 1:\n",
"        ax.set_xlabel(r\"$\\theta_1$\")\n",
"\n",
"save_fig(\"lasso_vs_ridge_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Elastic Net"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1.54333232])"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.linear_model import ElasticNet\n",
"\n",
"elastic_net = ElasticNet(alpha=0.1, l1_ratio=0.5)\n",
"elastic_net.fit(X, y)\n",
"elastic_net.predict([[1.5]])"
]
},
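{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see how `l1_ratio` mixes the two penalties, here is a sketch (not from the book). Scikit-Learn documents the `ElasticNet` penalty as `alpha * l1_ratio * ||w||_1 + 0.5 * alpha * (1 - l1_ratio) * ||w||_2^2`, so with `l1_ratio=1.0` only the ℓ1 term remains and the model should match `Lasso` with the same `alpha`. The `_demo` names are hypothetical, and the data is regenerated locally so the cell runs on its own."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.linear_model import ElasticNet, Lasso\n",
"\n",
"rng = np.random.RandomState(42)  # same data as the np.random.seed(42) cell above\n",
"X_demo = 3 * rng.rand(20, 1)\n",
"y_demo = 1 + 0.5 * X_demo.ravel() + rng.randn(20) / 1.5  # 1D targets\n",
"\n",
"# with l1_ratio=1.0 the l2 term vanishes, leaving a pure Lasso penalty\n",
"enet_pure_l1 = ElasticNet(alpha=0.1, l1_ratio=1.0).fit(X_demo, y_demo)\n",
"lasso_demo = Lasso(alpha=0.1).fit(X_demo, y_demo)\n",
"print(enet_pure_l1.predict([[1.5]]), lasso_demo.predict([[1.5]]))\n",
"print(np.allclose(enet_pure_l1.coef_, lasso_demo.coef_))"
]
},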
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Early Stopping"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's go back to the quadratic dataset we used earlier:"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from copy import deepcopy\n",
"from sklearn.metrics import mean_squared_error\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# extra code – creates the same quadratic dataset as earlier and splits it\n",
"np.random.seed(42)\n",
"m = 100\n",
"X = 6 * np.random.rand(m, 1) - 3\n",
"y = 0.5 * X ** 2 + X + 2 + np.random.randn(m, 1)\n",
"X_train, y_train = X[: m // 2], y[: m // 2, 0]\n",
"X_valid, y_valid = X[m // 2 :], y[m // 2 :, 0]\n",
"\n",
"preprocessing = make_pipeline(PolynomialFeatures(degree=90, include_bias=False),\n",
" StandardScaler())\n",
"X_train_prep = preprocessing.fit_transform(X_train)\n",
"X_valid_prep = preprocessing.transform(X_valid)\n",
"sgd_reg = SGDRegressor(penalty=None, eta0=0.002, random_state=42)\n",
"n_epochs = 500\n",
"best_valid_rmse = float('inf')\n",
"train_errors, val_errors = [], [] # extra code – it's for the figure below\n",
"\n",
"for epoch in range(n_epochs):\n",
" sgd_reg.partial_fit(X_train_prep, y_train)\n",
" y_valid_predict = sgd_reg.predict(X_valid_prep)\n",
" val_error = mean_squared_error(y_valid, y_valid_predict, squared=False)\n",
" if val_error < best_valid_rmse:\n",
" best_valid_rmse = val_error\n",
" best_model = deepcopy(sgd_reg)\n",
"\n",
" # extra code – we evaluate the train error and save it for the figure\n",
" y_train_predict = sgd_reg.predict(X_train_prep)\n",
" train_error = mean_squared_error(y_train, y_train_predict, squared=False)\n",
" val_errors.append(val_error)\n",
" train_errors.append(train_error)\n",
"\n",
"# extra code – this section generates and saves Figure 4–20\n",
"best_epoch = np.argmin(val_errors)\n",
"plt.figure(figsize=(6, 4))\n",
"plt.annotate('Best model',\n",
" xy=(best_epoch, best_valid_rmse),\n",
" xytext=(best_epoch, best_valid_rmse + 0.5),\n",
" ha=\"center\",\n",
" arrowprops=dict(facecolor='black', shrink=0.05))\n",
"plt.plot([0, n_epochs], [best_valid_rmse, best_valid_rmse], \"k:\", linewidth=2)\n",
"plt.plot(val_errors, \"b-\", linewidth=3, label=\"Validation set\")\n",
"plt.plot(best_epoch, best_valid_rmse, \"bo\")\n",
"plt.plot(train_errors, \"r--\", linewidth=2, label=\"Training set\")\n",
"plt.legend(loc=\"upper right\")\n",
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"RMSE\")\n",
"plt.axis([0, n_epochs, 0, 3.5])\n",
"plt.grid()\n",
"save_fig(\"early_stopping_plot\")\n",
"plt.show()"
]
},
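The loop above runs all 500 epochs and simply keeps a deep copy of the model with the lowest validation RMSE. A common variant stops training once the validation error has failed to improve for a given number of consecutive epochs (the "patience"). The minimal sketch below applies that rule in plain Python to a precomputed list of validation errors; the helper name `early_stop` and its `patience` parameter are illustrative assumptions, not part of the notebook. Scikit-Learn's `SGDRegressor` also offers built-in `early_stopping` and `n_iter_no_change` options, which hold out a validation fraction internally.

```python
def early_stop(val_errors, patience=10):
    """Scan validation errors epoch by epoch (hypothetical helper).

    Returns (best_epoch, stop_epoch): the epoch with the lowest error seen,
    and the epoch at which training would stop because `patience`
    consecutive epochs passed without a new best.
    """
    best_error = float("inf")
    best_epoch = -1
    for epoch, error in enumerate(val_errors):
        if error < best_error:
            # New best model: this is where you would save a deepcopy
            best_error, best_epoch = error, epoch
        elif epoch - best_epoch >= patience:
            # No improvement for `patience` epochs: stop early
            return best_epoch, epoch
    # Ran out of epochs without triggering the patience rule
    return best_epoch, len(val_errors) - 1


errors = [3.0, 2.0, 1.5, 1.6, 1.7, 1.8, 1.4]
print(early_stop(errors, patience=3))  # (2, 5): the late dip to 1.4 is never seen
print(early_stop(errors, patience=5))  # (6, 6)
```

A small patience saves epochs but can stop just before a late improvement; a larger one is more robust to noisy validation curves at the cost of extra training.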
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Logistic Regression"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Estimating Probabilities"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "<base64-encoded PNG data omitted: Figure 4–21, logistic function plot>",
"text/plain": [
"<Figure size 576x216 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – generates and saves Figure 4–21\n",
"\n",
"lim = 6\n",
"t = np.linspace(-lim, lim, 100)\n",
"sig = 1 / (1 + np.exp(-t))\n",
"\n",
"plt.figure(figsize=(8, 3))\n",
"plt.plot([-lim, lim], [0, 0], \"k-\")\n",
"plt.plot([-lim, lim], [0.5, 0.5], \"k:\")\n",
"plt.plot([-lim, lim], [1, 1], \"k:\")\n",
"plt.plot([0, 0], [-1.1, 1.1], \"k-\")\n",
"plt.plot(t, sig, \"b-\", linewidth=2, label=r\"$\\sigma(t) = \\dfrac{1}{1 + e^{-t}}$\")\n",
"plt.xlabel(\"t\")\n",
"plt.legend(loc=\"upper left\")\n",
"plt.axis([-lim, lim, -0.1, 1.1])\n",
"plt.gca().set_yticks([0, 0.25, 0.5, 0.75, 1])\n",
"plt.grid()\n",
"save_fig(\"logistic_function_plot\")\n",
"plt.show()"
]
},
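The cell above evaluates sigma(t) = 1 / (1 + e^(-t)) directly with NumPy, which is fine for t in [-6, 6]. For very negative t, however, e^(-t) overflows: NumPy returns `inf` with a RuntimeWarning, while Python's `math.exp` raises `OverflowError`. The sketch below shows one common numerically stable formulation in plain Python, using the equivalent form e^t / (1 + e^t) when t is negative; the function name `sigmoid` is an illustrative choice.

```python
import math

def sigmoid(t: float) -> float:
    """Logistic function sigma(t) = 1 / (1 + exp(-t)), safe for large |t|."""
    if t >= 0:
        # exp(-t) <= 1 here, so it cannot overflow
        return 1.0 / (1.0 + math.exp(-t))
    # For t < 0, use the equivalent form exp(t) / (1 + exp(t)); exp(t) < 1
    z = math.exp(t)
    return z / (1.0 + z)


print(sigmoid(0.0))      # 0.5
print(sigmoid(-1000.0))  # 0.0 (the naive formula would raise OverflowError)
```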
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Decision Boundaries"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['data',\n",
" 'target',\n",
" 'frame',\n",
" 'target_names',\n",
" 'DESCR',\n",
" 'feature_names',\n",
" 'filename',\n",
" 'data_module']"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.datasets import load_iris\n",
"\n",
"iris = load_iris(as_frame=True)\n",
"list(iris)"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
".. _iris_dataset:\n",
"\n",
"Iris plants dataset\n",
"--------------------\n",
"\n",
"**Data Set Characteristics:**\n",
"\n",
" :Number of Instances: 150 (50 in each of three classes)\n",
" :Number of Attributes: 4 numeric, predictive attributes and the class\n",
" :Attribute Information:\n",
" - sepal length in cm\n",
" - sepal width in cm\n",
" - petal length in cm\n",
" - petal width in cm\n",
" - class:\n",
" - Iris-Setosa\n",
" - Iris-Versicolour\n",
" - Iris-Virginica\n",
" \n",
" :Summary Statistics:\n",
"\n",
" ============== ==== ==== ======= ===== ====================\n",
" Min Max Mean SD Class Correlation\n",
" ============== ==== ==== ======= ===== ====================\n",
" sepal length: 4.3 7.9 5.84 0.83 0.7826\n",
" sepal width: 2.0 4.4 3.05 0.43 -0.4194\n",
" petal length: 1.0 6.9 3.76 1.76 0.9490 (high!)\n",
" petal width: 0.1 2.5 1.20 0.76 0.9565 (high!)\n",
" ============== ==== ==== ======= ===== ====================\n",
"\n",
" :Missing Attribute Values: None\n",
" :Class Distribution: 33.3% for each of 3 classes.\n",
" :Creator: R.A. Fisher\n",
" :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)\n",
" :Date: July, 1988\n",
"\n",
"The famous Iris database, first used by Sir R.A. Fisher. The dataset is taken\n",
"from Fisher's paper. Note that it's the same as in R, but not as in the UCI\n",
"Machine Learning Repository, which has two wrong data points.\n",
"\n",
"This is perhaps the best known database to be found in the\n",
"pattern recognition literature. Fisher's paper is a classic in the field and\n",
"is referenced frequently to this day. (See Duda & Hart, for example.) The\n",
"data set contains 3 classes of 50 instances each, where each class refers to a\n",
"type of iris plant. One class is linearly separable from the other 2; the\n",
"latter are NOT linearly separable from each other.\n",
"\n",
".. topic:: References\n",
"\n",
" - Fisher, R.A. \"The use of multiple measurements in taxonomic problems\"\n",
" Annual Eugenics, 7, Part II, 179-188 (1936); also in \"Contributions to\n",
" Mathematical Statistics\" (John Wiley, NY, 1950).\n",
" - Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene Analysis.\n",
" (Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218.\n",
" - Dasarathy, B.V. (1980) \"Nosing Around the Neighborhood: A New System\n",
" Structure and Classification Rule for Recognition in Partially Exposed\n",
" Environments\". IEEE Transactions on Pattern Analysis and Machine\n",
" Intelligence, Vol. PAMI-2, No. 1, 67-71.\n",
" - Gates, G.W. (1972) \"The Reduced Nearest Neighbor Rule\". IEEE Transactions\n",
" on Information Theory, May 1972, 431-433.\n",
" - See also: 1988 MLC Proceedings, 54-64. Cheeseman et al\"s AUTOCLASS II\n",
" conceptual clustering system finds 3 classes in the data.\n",
" - Many, many more ...\n"
]
}
],
"source": [
"print(iris.DESCR) # extra code – it's a bit too long"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>sepal length (cm)</th>\n",
" <th>sepal width (cm)</th>\n",
" <th>petal length (cm)</th>\n",
" <th>petal width (cm)</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>5.1</td>\n",
" <td>3.5</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>4.9</td>\n",
" <td>3.0</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>4.7</td>\n",
" <td>3.2</td>\n",
" <td>1.3</td>\n",
" <td>0.2</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" sepal length (cm) sepal width (cm) petal length (cm) petal width (cm)\n",
"0 5.1 3.5 1.4 0.2\n",
"1 4.9 3.0 1.4 0.2\n",
"2 4.7 3.2 1.3 0.2"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"iris.data.head(3)"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 0\n",
"1 0\n",
"2 0\n",
"Name: target, dtype: int64"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"iris.target.head(3) # note that the instances are not shuffled"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(['setosa', 'versicolor', 'virginica'], dtype='<U10')"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"iris.target_names"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"LogisticRegression(random_state=42)"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"X = iris.data[[\"petal width (cm)\"]].values\n",
"y = iris.target_names[iris.target] == 'virginica'\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)\n",
"\n",
"log_reg = LogisticRegression(random_state=42)\n",
"log_reg.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "<base64-encoded PNG data omitted: Figure 4–23, estimated probabilities and decision boundary>",
"text/plain": [
"<Figure size 576x216 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"X_new = np.linspace(0, 3, 1000).reshape(-1, 1) # reshape to get a column vector\n",
"y_proba = log_reg.predict_proba(X_new)\n",
"decision_boundary = X_new[y_proba[:, 1] >= 0.5][0, 0]\n",
"\n",
"plt.figure(figsize=(8, 3)) # extra code – not needed, just formatting\n",
"plt.plot(X_new, y_proba[:, 0], \"b--\", linewidth=2,\n",
" label=\"Not Iris virginica proba\")\n",
"plt.plot(X_new, y_proba[:, 1], \"g-\", linewidth=2, label=\"Iris virginica proba\")\n",
"plt.plot([decision_boundary, decision_boundary], [0, 1], \"k:\", linewidth=2,\n",
" label=\"Decision boundary\")\n",
"\n",
"# extra code – this section beautifies and saves Figure 4–23\n",
"plt.arrow(x=decision_boundary, y=0.08, dx=-0.3, dy=0,\n",
" head_width=0.05, head_length=0.1, fc=\"b\", ec=\"b\")\n",
"plt.arrow(x=decision_boundary, y=0.92, dx=0.3, dy=0,\n",
" head_width=0.05, head_length=0.1, fc=\"g\", ec=\"g\")\n",
"plt.plot(X_train[y_train == 0], y_train[y_train == 0], \"bs\")\n",
"plt.plot(X_train[y_train == 1], y_train[y_train == 1], \"g^\")\n",
"plt.xlabel(\"Petal width (cm)\")\n",
"plt.ylabel(\"Probability\")\n",
"plt.legend(loc=\"center left\")\n",
"plt.axis([0, 3, -0.02, 1.02])\n",
"plt.grid()\n",
"save_fig(\"logistic_regression_plot\")\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.6516516516516517"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"decision_boundary"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ True, False])"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"log_reg.predict([[1.7], [1.5]])"
]
},
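The notebook locates the decision boundary by scanning 1,000 grid points and keeping the first one whose estimated probability is at least 50%, giving about 1.65 cm; this is consistent with `log_reg.predict([[1.7], [1.5]])` returning `[True, False]`. Since sigma(t) = 0.5 exactly when t = 0, the boundary for a single feature can also be computed in closed form from theta0 + theta1 * x = 0, i.e. x = -theta0 / theta1 (for the fitted model, `-log_reg.intercept_[0] / log_reg.coef_[0, 0]`). The plain-Python sketch below compares both approaches using made-up parameters theta0 = -6 and theta1 = 4, not the values fitted above; the grid answer lands within one grid step (3/999, about 0.003) of the exact one.

```python
import math

def proba(x, theta0, theta1):
    """Estimated probability sigma(theta0 + theta1 * x) for one feature."""
    return 1.0 / (1.0 + math.exp(-(theta0 + theta1 * x)))

def grid_boundary(theta0, theta1, lo=0.0, hi=3.0, n=1000):
    """First grid point with proba >= 0.5 (mirrors the notebook's approach)."""
    step = (hi - lo) / (n - 1)
    for i in range(n):
        x = lo + i * step
        if proba(x, theta0, theta1) >= 0.5:
            return x
    return None  # no grid point reaches 50%

def exact_boundary(theta0, theta1):
    """Solve theta0 + theta1 * x = 0, where sigma(...) is exactly 0.5."""
    return -theta0 / theta1


theta0, theta1 = -6.0, 4.0  # hypothetical parameters, not fitted values
print(exact_boundary(theta0, theta1))  # 1.5
print(grid_boundary(theta0, theta1))   # slightly above 1.5, at most one grid step
```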
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "<base64-encoded PNG data omitted: Figure 4–24, linear decision boundary with probability contours>",
"text/plain": [
"<Figure size 720x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – this cell generates and saves Figure 4–24\n",
"\n",
"X = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
"y = iris.target_names[iris.target] == 'virginica'\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)\n",
"\n",
"log_reg = LogisticRegression(C=2, random_state=42)\n",
"log_reg.fit(X_train, y_train)\n",
"\n",
"# for the contour plot\n",
"x0, x1 = np.meshgrid(np.linspace(2.9, 7, 500).reshape(-1, 1),\n",
" np.linspace(0.8, 2.7, 200).reshape(-1, 1))\n",
"X_new = np.c_[x0.ravel(), x1.ravel()] # one instance per point on the figure\n",
"y_proba = log_reg.predict_proba(X_new)\n",
"zz = y_proba[:, 1].reshape(x0.shape)\n",
"\n",
"# for the decision boundary\n",
"left_right = np.array([2.9, 7])\n",
"boundary = -((log_reg.coef_[0, 0] * left_right + log_reg.intercept_[0])\n",
" / log_reg.coef_[0, 1])\n",
"\n",
"plt.figure(figsize=(10, 4))\n",
"plt.plot(X_train[y_train == 0, 0], X_train[y_train == 0, 1], \"bs\")\n",
"plt.plot(X_train[y_train == 1, 0], X_train[y_train == 1, 1], \"g^\")\n",
"contour = plt.contour(x0, x1, zz, cmap=plt.cm.brg)\n",
"plt.clabel(contour, inline=1)\n",
"plt.plot(left_right, boundary, \"k--\", linewidth=3)\n",
"plt.text(3.5, 1.27, \"Not Iris virginica\", color=\"b\", ha=\"center\")\n",
"plt.text(6.5, 2.3, \"Iris virginica\", color=\"g\", ha=\"center\")\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.axis([2.9, 7, 0.8, 2.7])\n",
"plt.grid()\n",
"save_fig(\"logistic_regression_contour_plot\")\n",
"plt.show()"
]
},
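With two features, the dashed line in this figure is the set of points where theta0 + theta1 * x1 + theta2 * x2 = 0, which is exactly what the `boundary` expression computes: x2 = -(theta0 + theta1 * x1) / theta2, with theta0 = `log_reg.intercept_[0]` and (theta1, theta2) = `log_reg.coef_[0]`. The plain-Python check below uses hypothetical parameter values (not the fitted ones) to confirm that every point on that line gets an estimated probability of exactly 50%, which is why it coincides with the 0.5 contour.

```python
import math

def proba(x1, x2, theta0, theta1, theta2):
    """sigma(theta0 + theta1 * x1 + theta2 * x2) for two features."""
    return 1.0 / (1.0 + math.exp(-(theta0 + theta1 * x1 + theta2 * x2)))

def boundary_x2(x1, theta0, theta1, theta2):
    """x2 on the decision boundary for a given x1 (requires theta2 != 0)."""
    return -(theta0 + theta1 * x1) / theta2


theta0, theta1, theta2 = -17.0, 2.5, 3.0  # hypothetical parameters
for x1 in (2.9, 5.0, 7.0):
    x2 = boundary_x2(x1, theta0, theta1, theta2)
    p = proba(x1, x2, theta0, theta1, theta2)
    print(f"x1={x1}: x2={x2:.3f}, proba={p:.3f}")
```

Points above the line (larger x2, since theta2 > 0 here) get a probability above 50%, and points below it get less.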
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Softmax Regression"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"LogisticRegression(C=30, random_state=42)"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
"y = iris[\"target\"]\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)\n",
"\n",
"softmax_reg = LogisticRegression(C=30, random_state=42)\n",
"softmax_reg.fit(X_train, y_train)"
]
},
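When the target has more than two classes, `LogisticRegression` with its default `lbfgs` solver trains a multinomial model, i.e. Softmax Regression: it computes one score s_k(x) per class and turns the scores into probabilities with the softmax function, p_k = exp(s_k) / sum_j exp(s_j). The plain-Python sketch below shows the usual numerically stable version, which subtracts the largest score before exponentiating; this leaves the result unchanged because the common factor exp(-max) cancels in the ratio. The function name and the example scores are illustrative.

```python
import math

def softmax(scores):
    """Convert a list of class scores into probabilities that sum to 1."""
    m = max(scores)
    # Subtracting the max keeps exp() from overflowing; it cancels in the ratio
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


print(softmax([1000.0, 1000.0]))  # [0.5, 0.5]; a naive math.exp(1000) would overflow
print(softmax([-1.0, 1.5, 4.7]))
```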
{
"cell_type": "code",
"execution_count": 59,
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"array([2])"
]
},
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"softmax_reg.predict([[5, 2]])"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"array([[0. , 0.04, 0.96]])"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"softmax_reg.predict_proba([[5, 2]]).round(2)"
]
},
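`predict()` returns the class with the highest estimated probability, so the `[0., 0.04, 0.96]` row above corresponds to class 2 (Iris virginica). During training, Softmax Regression minimizes the cross-entropy cost: the average of -log(p_k) over the instances, where k is each instance's true class. The sketch below computes both quantities in plain Python for probability rows like the one printed above; it illustrates the formulas only and is not Scikit-Learn's internal implementation (which also adds an l2 penalty whose strength is controlled by `C`). The helper names and the `eps` guard are illustrative assumptions.

```python
import math

def predict(proba_rows):
    """Index of the most probable class for each row of probabilities."""
    return [max(range(len(row)), key=row.__getitem__) for row in proba_rows]

def cross_entropy(proba_rows, labels, eps=1e-15):
    """Mean of -log(p[true class]); eps guards against log(0)."""
    total = 0.0
    for row, k in zip(proba_rows, labels):
        total -= math.log(max(row[k], eps))
    return total / len(labels)


probas = [[0.0, 0.04, 0.96]]  # the rounded row printed above
print(predict(probas))             # [2]
print(cross_entropy(probas, [2]))  # -log(0.96), about 0.041
```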
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "<base64-encoded PNG data omitted: Figure 4–25, Softmax Regression decision boundaries>",
"text/plain": [
"<Figure size 720x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – this cell generates and saves Figure 4–25\n",
"\n",
"from matplotlib.colors import ListedColormap\n",
"\n",
"custom_cmap = ListedColormap([\"#fafab0\", \"#9898ff\", \"#a0faa0\"])\n",
"\n",
"x0, x1 = np.meshgrid(np.linspace(0, 8, 500).reshape(-1, 1),\n",
"                     np.linspace(0, 3.5, 200).reshape(-1, 1))\n",
"X_new = np.c_[x0.ravel(), x1.ravel()]\n",
"\n",
"y_proba = softmax_reg.predict_proba(X_new)\n",
"y_predict = softmax_reg.predict(X_new)\n",
"\n",
"zz1 = y_proba[:, 1].reshape(x0.shape)\n",
"zz = y_predict.reshape(x0.shape)\n",
"\n",
"plt.figure(figsize=(10, 4))\n",
"plt.plot(X[y == 2, 0], X[y == 2, 1], \"g^\", label=\"Iris virginica\")\n",
"plt.plot(X[y == 1, 0], X[y == 1, 1], \"bs\", label=\"Iris versicolor\")\n",
"plt.plot(X[y == 0, 0], X[y == 0, 1], \"yo\", label=\"Iris setosa\")\n",
"\n",
"plt.contourf(x0, x1, zz, cmap=custom_cmap)\n",
"contour = plt.contour(x0, x1, zz1, cmap=\"hot\")\n",
"plt.clabel(contour, inline=1)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.legend(loc=\"center left\")\n",
"plt.axis([0.5, 7, 0, 3.5])\n",
"plt.grid()\n",
"save_fig(\"softmax_regression_contour_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Exercise solutions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. to 11."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. If you have a training set with millions of features, you can use Stochastic Gradient Descent or Mini-batch Gradient Descent, and perhaps Batch Gradient Descent if the training set fits in memory. But you cannot use the Normal Equation or the SVD approach, because the computational complexity grows quickly (more than quadratically) with the number of features.\n",
"2. If the features in your training set have very different scales, the cost function will have the shape of an elongated bowl, so the Gradient Descent algorithms will take a long time to converge. To solve this, you should scale the data before training the model. Note that the Normal Equation or SVD approach will work just fine without scaling. Moreover, regularized models may converge to a suboptimal solution if the features are not scaled: since regularization penalizes large weights, features with smaller values will tend to be ignored compared to features with larger values.\n",
"3. Gradient Descent cannot get stuck in a local minimum when training a Logistic Regression model because the cost function is convex. _Convex_ means that if you draw a straight line between any two points on the curve, the line never crosses the curve.\n",
"4. If the optimization problem is convex (such as Linear Regression or Logistic Regression), and assuming the learning rate is not too high, then all Gradient Descent algorithms will approach the global optimum and end up producing fairly similar models. However, unless you gradually reduce the learning rate, Stochastic GD and Mini-batch GD will never truly converge; instead, they will keep jumping back and forth around the global optimum. This means that even if you let them run for a very long time, these Gradient Descent algorithms will produce slightly different models.\n",
"5. If the validation error consistently goes up after every epoch, then one possibility is that the learning rate is too high and the algorithm is diverging. If the training error also goes up, then this is clearly the problem and you should reduce the learning rate. However, if the training error is not going up, then your model is overfitting the training set and you should stop training.\n",
"6. Due to their random nature, neither Stochastic Gradient Descent nor Mini-batch Gradient Descent is guaranteed to make progress at every single training iteration. So if you immediately stop training when the validation error goes up, you may stop much too early, before the optimum is reached. A better option is to save the model at regular intervals; then, when it has not improved for a long time (meaning it will probably never beat the record), you can revert to the best saved model.\n",
"7. Stochastic Gradient Descent has the fastest training iteration since it considers only one training instance at a time, so it is generally the first to reach the vicinity of the global optimum (or Mini-batch GD with a very small mini-batch size). However, only Batch Gradient Descent will actually converge, given enough training time. As mentioned, Stochastic GD and Mini-batch GD will bounce around the optimum, unless you gradually reduce the learning rate.\n",
"8. If the validation error is much higher than the training error, this is likely because your model is overfitting the training set. One way to try to fix this is to reduce the polynomial degree: a model with fewer degrees of freedom is less likely to overfit. Another thing you can try is to regularize the model—for example, by adding an ℓ₂ penalty (Ridge) or an ℓ₁ penalty (Lasso) to the cost function. This will also reduce the degrees of freedom of the model. Lastly, you can try to increase the size of the training set.\n",
"9. If both the training error and the validation error are almost equal and fairly high, the model is likely underfitting the training set, which means it has a high bias. You should try reducing the regularization hyperparameter _α_.\n",
"10. Let's see:\n",
" * A model with some regularization typically performs better than a model without any regularization, so you should generally prefer Ridge Regression over plain Linear Regression.\n",
" * Lasso Regression uses an ℓ₁ penalty, which tends to push the weights down to exactly zero. This leads to sparse models, where all weights are zero except for the most important weights. This is a way to perform feature selection automatically, which is good if you suspect that only a few features actually matter. When you are not sure, you should prefer Ridge Regression.\n",
" * Elastic Net is generally preferred over Lasso since Lasso may behave erratically in some cases (when several features are strongly correlated or when there are more features than training instances). However, it does add an extra hyperparameter to tune. If you want Lasso without the erratic behavior, you can just use Elastic Net with an `l1_ratio` close to 1.\n",
"11. If you want to classify pictures as outdoor/indoor and daytime/nighttime, since these are not exclusive classes (i.e., all four combinations are possible), you should train two Logistic Regression classifiers."
]
},
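{
"cell_type": "markdown",
"metadata": {},
"source": [
"To illustrate answer 10, here is a minimal sketch (not part of the original solutions) comparing Lasso and Ridge on a synthetic dataset in which only the first two of ten features influence the target. The data and the `alpha=0.1` value are arbitrary assumptions chosen for illustration; with settings like these, Lasso typically drives the weights of the irrelevant features to exactly zero, while Ridge only shrinks them:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.linear_model import Lasso, Ridge\n",
"\n",
"# synthetic data (an assumption for illustration): only features 0 and 1 matter\n",
"rng = np.random.default_rng(42)\n",
"X_demo = rng.standard_normal((100, 10))\n",
"y_demo = 3 * X_demo[:, 0] - 2 * X_demo[:, 1] + 0.1 * rng.standard_normal(100)\n",
"\n",
"lasso_demo = Lasso(alpha=0.1).fit(X_demo, y_demo)\n",
"ridge_demo = Ridge(alpha=0.1).fit(X_demo, y_demo)\n",
"\n",
"# count the weights that are exactly zero in each model\n",
"print(\"Lasso zero weights:\", (lasso_demo.coef_ == 0).sum())\n",
"print(\"Ridge zero weights:\", (ridge_demo.coef_ == 0).sum())"
]
},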
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 12. Batch Gradient Descent with early stopping for Softmax Regression\n",
"Exercise: _Implement Batch Gradient Descent with early stopping for Softmax Regression without using Scikit-Learn, only NumPy. Use it on a classification task such as the iris dataset._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's start by loading the data. We will just reuse the Iris dataset we loaded earlier."
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"X = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
"y = iris[\"target\"].values"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We need to add the bias term for every instance ($x_0 = 1$). The easiest option to do this would be to use Scikit-Learn's `add_dummy_feature()` function, but the point of this exercise is to get a better understanding of the algorithms by implementing them manually. So here is one possible implementation:"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
"X_with_bias = np.c_[np.ones(len(X)), X]"
]
},
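{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (a sketch, not part of the original solution), the manual `np.c_` approach should produce the same matrix as Scikit-Learn's `add_dummy_feature()`, which prepends a column of 1s by default. The three rows below are made-up petal measurements used only for illustration:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.preprocessing import add_dummy_feature\n",
"\n",
"X_demo = np.array([[1.4, 0.2], [4.7, 1.4], [6.0, 2.5]])  # illustrative values\n",
"X_demo_with_bias = np.c_[np.ones(len(X_demo)), X_demo]\n",
"print(np.allclose(X_demo_with_bias, add_dummy_feature(X_demo)))"
]
},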
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The easiest option to split the dataset into a training set, a validation set and a test set would be to use Scikit-Learn's `train_test_split()` function, but again, we want to do it manually:"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
"test_ratio = 0.2\n",
"validation_ratio = 0.2\n",
"total_size = len(X_with_bias)\n",
"\n",
"test_size = int(total_size * test_ratio)\n",
"validation_size = int(total_size * validation_ratio)\n",
"train_size = total_size - test_size - validation_size\n",
"\n",
"np.random.seed(42)\n",
"rnd_indices = np.random.permutation(total_size)\n",
"\n",
"X_train = X_with_bias[rnd_indices[:train_size]]\n",
"y_train = y[rnd_indices[:train_size]]\n",
"X_valid = X_with_bias[rnd_indices[train_size:-test_size]]\n",
"y_valid = y[rnd_indices[train_size:-test_size]]\n",
"X_test = X_with_bias[rnd_indices[-test_size:]]\n",
"y_test = y[rnd_indices[-test_size:]]"
]
},
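{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison, here is a sketch of a similar 90/30/30 split using `train_test_split()` twice: first carving out the test set, then the validation set. It reloads the iris data with `load_iris(return_X_y=True)` to stay self-contained and uses new variable names (`X_train_sk` and so on, chosen for illustration) so it does not overwrite the sets above. It will not select the same instances as the manual permutation, so downstream results may differ slightly:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import load_iris\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_iris, y_iris = load_iris(return_X_y=True)\n",
"X_iris = X_iris[:, 2:]  # petal length and petal width\n",
"\n",
"# an integer test_size is an absolute number of instances: 30 = 20% of 150\n",
"X_tmp, X_test_sk, y_tmp, y_test_sk = train_test_split(\n",
"    X_iris, y_iris, test_size=30, random_state=42)\n",
"X_train_sk, X_valid_sk, y_train_sk, y_valid_sk = train_test_split(\n",
"    X_tmp, y_tmp, test_size=30, random_state=42)\n",
"print(len(X_train_sk), len(X_valid_sk), len(X_test_sk))"
]
},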
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The targets are currently class indices (0, 1 or 2), but we need target class probabilities to train the Softmax Regression model. Each instance will have target class probabilities equal to 0.0 for all classes except for the target class which will have a probability of 1.0 (in other words, the vector of class probabilities for any given instance is a one-hot vector). Let's write a small function to convert the vector of class indices into a matrix containing a one-hot vector for each instance. To understand this code, you need to know that `np.diag(np.ones(n))` creates an n×n matrix full of 0s except for 1s on the main diagonal. Moreover, if `a` is a NumPy array, then `a[[1, 3, 2]]` returns an array with 3 rows equal to `a[1]`, `a[3]` and `a[2]` (this is [advanced NumPy indexing](https://numpy.org/doc/stable/user/basics.indexing.html#advanced-indexing))."
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
"def to_one_hot(y):\n",
"    return np.diag(np.ones(y.max() + 1))[y]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's test this function on the first 10 instances:"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1, 0, 2, 1, 1, 0, 1, 2, 1, 1])"
]
},
"execution_count": 66,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_train[:10]"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0., 1., 0.],\n",
" [1., 0., 0.],\n",
" [0., 0., 1.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [1., 0., 0.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.]])"
]
},
"execution_count": 67,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"to_one_hot(y_train[:10])"
]
},
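{
"cell_type": "markdown",
"metadata": {},
"source": [
"Side note: `np.eye(n)` builds the same n×n identity matrix as `np.diag(np.ones(n))`, so the function could also be written as below (`to_one_hot_eye` is just an illustrative name, not used elsewhere in this notebook). Like the original, it infers the number of classes from `y.max()`, so it assumes the highest class index actually appears in `y`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def to_one_hot_eye(y):\n",
"    # row y[i] of the identity matrix is the one-hot vector for class y[i]\n",
"    return np.eye(y.max() + 1)[y]\n",
"\n",
"y_demo = np.array([1, 0, 2, 1])\n",
"print(to_one_hot_eye(y_demo))\n",
"print(np.array_equal(to_one_hot_eye(y_demo), np.diag(np.ones(3))[y_demo]))"
]
},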
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Looks good, so let's create the target class probabilities matrix for the training set, the validation set, and the test set:"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [],
"source": [
"Y_train_one_hot = to_one_hot(y_train)\n",
"Y_valid_one_hot = to_one_hot(y_valid)\n",
"Y_test_one_hot = to_one_hot(y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's scale the inputs. We compute the mean and standard deviation of each feature on the training set (except for the bias feature), then we center and scale each feature in the training set, the validation set, and the test set:"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [],
"source": [
"mean = X_train[:, 1:].mean(axis=0)\n",
"std = X_train[:, 1:].std(axis=0)\n",
"X_train[:, 1:] = (X_train[:, 1:] - mean) / std\n",
"X_valid[:, 1:] = (X_valid[:, 1:] - mean) / std\n",
"X_test[:, 1:] = (X_test[:, 1:] - mean) / std"
]
},
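{
"cell_type": "markdown",
"metadata": {},
"source": [
"This should give the same result as Scikit-Learn's `StandardScaler`, which also uses the population standard deviation (the same as `np.std()`'s default `ddof=0`) and should likewise be fitted on the training set only. Here is a minimal check on random data (the data is an arbitrary assumption, and the `_demo` variables are new names so the scaled sets above are left untouched):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"rng = np.random.default_rng(0)\n",
"X_train_demo = rng.normal(loc=[4.0, 1.2], scale=[1.7, 0.7], size=(90, 2))\n",
"X_valid_demo = rng.normal(loc=[4.0, 1.2], scale=[1.7, 0.7], size=(30, 2))\n",
"\n",
"# statistics are computed on the training set only, then applied to other sets\n",
"mean_demo = X_train_demo.mean(axis=0)\n",
"std_demo = X_train_demo.std(axis=0)\n",
"X_valid_manual = (X_valid_demo - mean_demo) / std_demo\n",
"\n",
"scaler = StandardScaler().fit(X_train_demo)\n",
"print(np.allclose(X_valid_manual, scaler.transform(X_valid_demo)))"
]
},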
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's implement the Softmax function. Recall that it is defined by the following equation:\n",
"\n",
"$\\sigma\\left(\\mathbf{s}(\\mathbf{x})\\right)_k = \\dfrac{\\exp\\left(s_k(\\mathbf{x})\\right)}{\\sum\\limits_{j=1}^{K}{\\exp\\left(s_j(\\mathbf{x})\\right)}}$"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
"def softmax(logits):\n",
"    exps = np.exp(logits)\n",
"    exp_sums = exps.sum(axis=1, keepdims=True)\n",
"    return exps / exp_sums"
]
},
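{
"cell_type": "markdown",
"metadata": {},
"source": [
"One caveat: this `softmax()` can overflow when some logits are large (e.g., `np.exp(1000.0)` is `inf`, so the ratio becomes `nan`). A common remedy, sketched here as an alternative rather than the chapter's code, is to subtract each row's largest logit before exponentiating. The result is mathematically unchanged, since the factor $\\exp(-\\max_j s_j)$ cancels out between the numerator and the denominator:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def softmax_stable(logits):\n",
"    # shifting each row by its max leaves the output unchanged,\n",
"    # but keeps np.exp() from overflowing\n",
"    shifted = logits - logits.max(axis=1, keepdims=True)\n",
"    exps = np.exp(shifted)\n",
"    return exps / exps.sum(axis=1, keepdims=True)\n",
"\n",
"big_logits = np.array([[1000.0, 1001.0, 1002.0]])\n",
"print(softmax_stable(big_logits))  # finite probabilities that sum to 1"
]
},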
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We are almost ready to start training. Let's define the number of inputs and outputs:"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
"n_inputs = X_train.shape[1]  # == 3 (2 features plus the bias term)\n",
"n_outputs = len(np.unique(y_train))  # == 3 (there are 3 iris classes)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now here comes the hardest part: training! Theoretically, it's simple: it's just a matter of translating the math equations into Python code. But in practice, it can be quite tricky: in particular, it's easy to mix up the order of the terms, or the indices. You can even end up with code that looks like it's working but is actually not computing exactly the right thing. When unsure, you should write down the shape of each term in the equation and make sure the corresponding terms in your code match closely. It can also help to evaluate each term independently and print them out. The good news is that you won't have to do this every day, since all of this is well implemented in Scikit-Learn, but it will help you understand what's going on under the hood.\n",
"\n",
"So the equations we will need are the cost function:\n",
"\n",
"$J(\\mathbf{\\Theta}) =\n",
"- \\dfrac{1}{m}\\sum\\limits_{i=1}^{m}\\sum\\limits_{k=1}^{K}{y_k^{(i)}\\log\\left(\\hat{p}_k^{(i)}\\right)}$\n",
"\n",
"And the equation for the gradients:\n",
"\n",
"$\\nabla_{\\mathbf{\\theta}^{(k)}} \\, J(\\mathbf{\\Theta}) = \\dfrac{1}{m} \\sum\\limits_{i=1}^{m}{ \\left ( \\hat{p}^{(i)}_k - y_k^{(i)} \\right ) \\mathbf{x}^{(i)}}$\n",
"\n",
"Note that $\\log\\left(\\hat{p}_k^{(i)}\\right)$ may not be computable if $\\hat{p}_k^{(i)} = 0$. So we will add a tiny value $\\epsilon$ to $\\hat{p}_k^{(i)}$ before taking the log, to avoid getting `nan` values."
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 3.7085808486476917\n",
"1000 0.14519367480830644\n",
"2000 0.1301309575504088\n",
"3000 0.12009639326384539\n",
"4000 0.11372961364786884\n",
"5000 0.11002459532472425\n"
]
}
],
"source": [
"eta = 0.5\n",
"n_epochs = 5001\n",
"m = len(X_train)\n",
"epsilon = 1e-5\n",
"\n",
"np.random.seed(42)\n",
"Theta = np.random.randn(n_inputs, n_outputs)\n",
"\n",
"for epoch in range(n_epochs):\n",
"    logits = X_train @ Theta\n",
"    Y_proba = softmax(logits)\n",
"    if epoch % 1000 == 0:\n",
"        Y_proba_valid = softmax(X_valid @ Theta)\n",
"        xentropy_losses = -(Y_valid_one_hot * np.log(Y_proba_valid + epsilon))\n",
"        print(epoch, xentropy_losses.sum(axis=1).mean())\n",
"    error = Y_proba - Y_train_one_hot\n",
"    gradients = 1 / m * X_train.T @ error\n",
"    Theta = Theta - eta * gradients"
]
},
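{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since gradient code is easy to get subtly wrong, a numerical gradient check can help: compare the vectorized gradient used above, $\\frac{1}{m}\\mathbf{X}^\\intercal(\\hat{\\mathbf{P}} - \\mathbf{Y})$ (whose column $k$ is $\\nabla_{\\mathbf{\\theta}^{(k)}} \\, J(\\mathbf{\\Theta})$), with central finite differences of the cost $J(\\mathbf{\\Theta})$. Here is a sketch on a tiny random problem; the sizes, the data, and the step `h` are arbitrary assumptions, and `epsilon` is left out of the cost so that the two quantities should match closely:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def softmax_fn(logits):\n",
"    exps = np.exp(logits - logits.max(axis=1, keepdims=True))\n",
"    return exps / exps.sum(axis=1, keepdims=True)\n",
"\n",
"def cost(Theta, X, Y):\n",
"    # cross entropy averaged over the m instances (no epsilon here)\n",
"    return -(Y * np.log(softmax_fn(X @ Theta))).sum(axis=1).mean()\n",
"\n",
"rng = np.random.default_rng(42)\n",
"X_chk = np.c_[np.ones(5), rng.standard_normal((5, 2))]  # 5 instances + bias\n",
"Y_chk = np.eye(3)[rng.integers(0, 3, size=5)]  # random one-hot targets\n",
"Theta_chk = rng.standard_normal((3, 3))\n",
"\n",
"analytic = X_chk.T @ (softmax_fn(X_chk @ Theta_chk) - Y_chk) / len(X_chk)\n",
"\n",
"h = 1e-6\n",
"numeric = np.zeros_like(Theta_chk)\n",
"for idx in np.ndindex(*Theta_chk.shape):\n",
"    step = np.zeros_like(Theta_chk)\n",
"    step[idx] = h\n",
"    numeric[idx] = (cost(Theta_chk + step, X_chk, Y_chk)\n",
"                    - cost(Theta_chk - step, X_chk, Y_chk)) / (2 * h)\n",
"\n",
"print(np.abs(analytic - numeric).max())  # should be very small"
]
},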
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And that's it! The Softmax model is trained. Let's look at the model parameters:"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0.41931626, 6.11112089, -5.52429876],\n",
" [-6.53054533, -0.74608616, 8.33137102],\n",
" [-5.28115784, 0.25152675, 6.90680425]])"
]
},
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Theta"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's make predictions for the validation set and check the accuracy score:"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9333333333333333"
]
},
"execution_count": 74,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"logits = X_valid @ Theta\n",
"Y_proba = softmax(logits)\n",
"y_predict = Y_proba.argmax(axis=1)\n",
"\n",
"accuracy_score = (y_predict == y_valid).mean()\n",
"accuracy_score"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Well, this model looks pretty good. For the sake of the exercise, let's add a bit of $\\ell_2$ regularization. The following training code is similar to the one above, but the loss now has an additional $\\ell_2$ penalty, and the gradients have the corresponding additional term (note that we don't regularize the first row of `Theta`, since it holds the bias terms)."
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 3.7372\n",
"1000 0.3259\n",
"2000 0.3259\n",
"3000 0.3259\n",
"4000 0.3259\n",
"5000 0.3259\n"
]
}
],
"source": [
"eta = 0.5\n",
"n_epochs = 5001\n",
"m = len(X_train)\n",
"epsilon = 1e-5\n",
"alpha = 0.01  # regularization hyperparameter\n",
"\n",
"np.random.seed(42)\n",
"Theta = np.random.randn(n_inputs, n_outputs)\n",
"\n",
"for epoch in range(n_epochs):\n",
"    logits = X_train @ Theta\n",
"    Y_proba = softmax(logits)\n",
"    if epoch % 1000 == 0:\n",
"        Y_proba_valid = softmax(X_valid @ Theta)\n",
"        xentropy_losses = -(Y_valid_one_hot * np.log(Y_proba_valid + epsilon))\n",
"        l2_loss = 1 / 2 * (Theta[1:] ** 2).sum()\n",
"        total_loss = xentropy_losses.sum(axis=1).mean() + alpha * l2_loss\n",
"        print(epoch, total_loss.round(4))\n",
"    error = Y_proba - Y_train_one_hot\n",
"    gradients = 1 / m * X_train.T @ error\n",
"    gradients += np.r_[np.zeros([1, n_outputs]), alpha * Theta[1:]]\n",
"    Theta = Theta - eta * gradients"
]
},
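{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, here is the regularized cost that this code minimizes on the training set, written in the same notation as the equations above. It is the cross entropy plus `alpha` times the `l2_loss` term computed in the code, where $n$ is the number of features; the sum over $j$ starts at 1, so the bias terms $\\theta_0^{(k)}$ are not penalized:\n",
"\n",
"$J(\\mathbf{\\Theta}) = - \\dfrac{1}{m}\\sum\\limits_{i=1}^{m}\\sum\\limits_{k=1}^{K}{y_k^{(i)}\\log\\left(\\hat{p}_k^{(i)}\\right)} + \\alpha \\, \\dfrac{1}{2}\\sum\\limits_{k=1}^{K}\\sum\\limits_{j=1}^{n}{\\left(\\theta_j^{(k)}\\right)^2}$\n",
"\n",
"Accordingly, for every $j \\ge 1$, component $j$ of each gradient vector gains an extra term $\\alpha \\, \\theta_j^{(k)}$ (component 0 is unchanged), which is exactly what `np.r_[np.zeros([1, n_outputs]), alpha * Theta[1:]]` adds:\n",
"\n",
"$\\nabla_{\\mathbf{\\theta}^{(k)}} \\, J(\\mathbf{\\Theta}) = \\dfrac{1}{m} \\sum\\limits_{i=1}^{m}{ \\left ( \\hat{p}^{(i)}_k - y_k^{(i)} \\right ) \\mathbf{x}^{(i)}} + \\alpha \\begin{pmatrix} 0 & \\theta_1^{(k)} & \\cdots & \\theta_n^{(k)} \\end{pmatrix}^\\intercal$"
]
},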
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because of the additional $\\ell_2$ penalty, the loss seems greater than earlier, but perhaps this model will perform better? Let's find out:"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9333333333333333"
]
},
"execution_count": 76,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"logits = X_valid @ Theta\n",
"Y_proba = softmax(logits)\n",
"y_predict = Y_proba.argmax(axis=1)\n",
"\n",
"accuracy_score = (y_predict == y_valid).mean()\n",
"accuracy_score"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this case, the $\\ell_2$ penalty did not change the validation accuracy. Perhaps try fine-tuning `alpha`?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's add early stopping. For this, we just need to measure the loss on the validation set at every epoch and stop as soon as it starts growing."
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 3.7372\n",
"281 0.3256\n",
"282 0.3256 early stopping!\n"
]
}
],
"source": [
"eta = 0.5\n",
"n_epochs = 50_001\n",
"m = len(X_train)\n",
"epsilon = 1e-5\n",
"C = 100  # regularization hyperparameter (C = 1 / alpha)\n",
"best_loss = np.inf\n",
"\n",
"np.random.seed(42)\n",
"Theta = np.random.randn(n_inputs, n_outputs)\n",
"\n",
"for epoch in range(n_epochs):\n",
"    logits = X_train @ Theta\n",
"    Y_proba = softmax(logits)\n",
"    Y_proba_valid = softmax(X_valid @ Theta)\n",
"    xentropy_losses = -(Y_valid_one_hot * np.log(Y_proba_valid + epsilon))\n",
"    l2_loss = 1 / 2 * (Theta[1:] ** 2).sum()\n",
"    total_loss = xentropy_losses.sum(axis=1).mean() + 1 / C * l2_loss\n",
"    if epoch % 1000 == 0:\n",
"        print(epoch, total_loss.round(4))\n",
"    if total_loss < best_loss:\n",
"        best_loss = total_loss\n",
"    else:\n",
"        print(epoch - 1, best_loss.round(4))\n",
"        print(epoch, total_loss.round(4), \"early stopping!\")\n",
"        break\n",
"    error = Y_proba - Y_train_one_hot\n",
"    gradients = 1 / m * X_train.T @ error\n",
"    gradients += np.r_[np.zeros([1, n_outputs]), 1 / C * Theta[1:]]\n",
"    Theta = Theta - eta * gradients"
]
},
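{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that the loop above stops at the very first increase of the validation loss, and it keeps the `Theta` from that epoch, i.e., one step past the best one. As suggested in the answer to exercise 6, a more robust variant saves a copy of the best parameters and only stops once the loss has not improved for several epochs. Below is a minimal sketch of that logic as a generic helper; the names `early_stopping_train`, `step_fn`, `loss_fn` and `patience` are illustrative assumptions, not an existing API. It is demonstrated on a toy one-parameter problem whose training objective $(\\theta - 3)^2$ pulls $\\theta$ past the minimum of the toy validation loss $(\\theta - 2)^2$. To use it in this notebook, `step_fn` would perform one gradient step on `Theta`, and `loss_fn` would compute `total_loss` on the validation set."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def early_stopping_train(theta, step_fn, loss_fn, n_epochs, patience=10):\n",
"    # step_fn(theta) -> updated theta; loss_fn(theta) -> validation loss\n",
"    best_loss, best_theta, best_epoch = np.inf, theta.copy(), 0\n",
"    for epoch in range(n_epochs):\n",
"        loss = loss_fn(theta)\n",
"        if loss < best_loss:\n",
"            best_loss, best_theta, best_epoch = loss, theta.copy(), epoch\n",
"        elif epoch - best_epoch >= patience:\n",
"            break  # no improvement for `patience` epochs: give up\n",
"        theta = step_fn(theta)\n",
"    return best_theta, best_loss, best_epoch\n",
"\n",
"toy_theta, toy_loss, toy_epoch = early_stopping_train(\n",
"    np.array([0.0]),\n",
"    step_fn=lambda t: t - 0.1 * 2 * (t - 3),  # gradient step on (t - 3)**2\n",
"    loss_fn=lambda t: float(((t - 2) ** 2).sum()),  # toy validation loss\n",
"    n_epochs=1000, patience=10)\n",
"print(toy_epoch, toy_theta, toy_loss)"
]
},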
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9333333333333333"
]
},
"execution_count": 78,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"logits = X_valid @ Theta\n",
"Y_proba = softmax(logits)\n",
"y_predict = Y_proba.argmax(axis=1)\n",
"\n",
"accuracy_score = (y_predict == y_valid).mean()\n",
"accuracy_score"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Oh well, still no change in validation accuracy, but at least early stopping shortened training a bit."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's plot the model's predictions on the whole dataset (remember to scale all features fed to the model):"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmgAAAEOCAYAAAA9quuTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAADa1ElEQVR4nOyddZxUVf/H3+dOb8x2J7vUkqJiIYLYqBiAndjd7c8Wux8Vu+sxUYzHRMICA4UFpNmG7Z3d6Xt+f8zsxAa7sA33/XrNa2bOPffeM4dl9zPfFFJKNDQ0NDQ0NDQ0+g9KXy9AQ0NDQ0NDQ0MjHE2gaWhoaGhoaGj0MzSBpqGhoaGhoaHRz9AEmoaGhoaGhoZGP0MTaBoaGhoaGhoa/QxNoGloaGhoaGho9DN6TaAJIcxCiN+EEMuEECuEEHe2MWeyEKJOCPGX/3Fbb61PQ0NDQ0NDQ6O/oO/FezmBKVJKmxDCACwSQnwppfylxbyFUsqjenFdGhoaGhoaGhr9il4TaNJXEdfmf2vwP7QquRoaGhoaGhoaLehNCxpCCB3wOzAYeFpK+Wsb0/YVQiwDSoFrpZQr2rjO+cD5AGazeY/07PQeXPXOjVAFUtF08o6i7d+Oo+1d19D2r2to+7fjaHvXNdb/u75SSpnU0TzRF62ehBCxwMfAZVLK5SHjVkD1u0GnAk9IKYds61r5w/Llbcu1ULUdJWpRFLb9bR1P1GgTbf92HG3vuoa2f11D278dR9u7rnGW8azfpZR7djSvT7I4pZS1wHzg8Bbj9VJKm//1F4BBCJHY6wvU0NDQ0NDQ0OhDejOLM8lvOUMIYQEOBla1mJMqhBD+13v511fVW2vU0NDQ0NDQ0OgP9GYMWhrwmj8OTQH+K6WcJ4S4EEBKOQeYAVwkhPAAduAk2Rc+WA0NDQ0NDQ2NPqQ3szj/Bsa1MT4n5PV/gP/01po0NDQ0NDQ0NPojWicBDQ0NDQ0NDY1+Rq+W2egLohqiSK9KR+/e6T/qDqHEKqhr1L5eRr/DY/BQmlCKLVrLVNLQ0NDQ6H12atUS1RBFbmUuaRlpGMwG/PkHGiHobDq8Ud6+Xka/QkqJ2+HGWGJkIxs1kaahoaGh0evs1C7O9Kp00jLSMFqMmjjT6DRCCIwWI2kZaaRXaUWQNTQ0NDR6n51aoOndegxmQ18vQ2OAYjAbNNe4hoaGhkafsFMLNECznGnsMNrPjoaGhoZGX7HTCzQNDQ0NDQ0NjYGGJtA0NDQ0NDQ0NPoZmkAbwBwz5RhuuOyGHrv+pWdfyilHn9Ll6yyev5gkXRJVlZ3v2vXOq++QY83p8r01NDQ0NDQGIloEdD/k0rMvpbqymrc/e3ub81794FUMhp5Lgpj9+Gy6o9PW+P3Gs7xkOfEJ8Z0+59gTj+XgqQd3+d4aGhoaGhoDEc2C1knKG8qZ9to0KmwVfb0UXC4XAHHxcURFR/XYfawxVmJiYzpcR0cYjUZSUlO2K+jeYrGQlJzU6fkaGhoaGho7EwNeoHkdTuyrSlFdnh69zyMLH+HXol95ZMEjPXqftmh2NT754JOMyR7D2OyxQGsX57yP5jFpt0lkRWYxJHEI0w6cxpaKLW1e8/xTzuesGWeFjamqyticscx5fE7YfZs5ZsoxXHfxddx+3e0MTxnOkROPBODrz79mn4J9yIzI5OjJR/Pxux+TpEti88bNQGsXZ7P7csF3C5g4ZiI50Tkce9CxbNqwKXCvtlycX3/+NYftexhZkVkMTRrKqdNOxeFwAPD+m+9zyN6HkBuTS0FqAbNOmEVZSdl277WGhoaGhkZ/YMC7OB3lW/l7zM2gCEy5iZiHpmEenIJ5SArD9x+CdHkQxq59zPKGct5d9i6qVHln2Ttcc8A1pESldNMn6Bw/LfgJa4yV9754r023Y0V5Beefcj63zr6Vo44/ikZbI0t/Wdru9WacOoNZM2dRV1tHVJTPCvfTjz9RUVbBcScd1+5577/1Pm
ecdwaf/fgZUkqKNxdz9oyzmXXxLM48/0wK/ynktmtv6/DzuJwunnjgCZ548QlMZhOXnn0p1150Le9/9X6b87/76jvOOO4MLr/hcp586Uk8Hg/zv5mPqvraVLlcLq6//XqGDB9CdWU1d910F+efej6fzf+sw7VoaGhoaGj0Nwa8QDOnJpF3x1Qca8pxrKnAsbaChoWrURuduL88lSa1GIRAMRtQTHqE2YBiMqCYDQizAaHr2Ij4yMJHUKVPCKhS5ZEFj/Dg1Ad7+qOFYTabeeKlJzCZTG0eLy8tx+12c/T0o8nKyQKgYFRBu9ebctgUoq3RfPrpp5xysc9K9sHbHzBxykRSUtsXnzmDcrjr4bsC7++++W5y8nxjQggGDxvMujXrmH3r7G1+Ho/HwwNPPcDgYYMBuOTqS7j8nMtRVRVFaf1v8si9j3D09KO5+e6bA2Mjx4wMvD511qmB17l5uTz09EPsN3I/SotLSc/UugFoaGhoaAwsBrxA05lNJJ02IWxMSom7rBZDtRVTRgKq043qcKPa3ah1TRBigBJ6JSDWAsLN5BNzCBGwnrm8vngrl9fVJ1a04aOGtyvOAEaNHcWkgyYxccxEJh8ymUkHTeLoGUeTmJTY5ny9Xs8xJxzDB//9gFMuPgWn08m8j+Yx+/FtC6sxu48Je7921VrG7TkuLL5sj7326PDzmEymgDgDSE1Pxe12U1dbR1x8XKv5y/9czslnnNzu9Zb9sYyH73qY5cuWU1NdE7AyFm8u1gSahoaGhsaAY8DHoLWFEAJjehyKyYA+MRpjRjzm/BQsIzOI3C0Hy8gMzPnJGDPi0MVGAAJvbROukhoc67ZgLyyh8a9N2AtLeHDePahqeDPxZitabxIREbHN4zqdjvf/9z7vf/U+I8eM5K1X3mLvYXuzfNnyds+ZeepMFi9eTFlJGd98/g1ul5upx03d9joiw9chpYQdKLiv0+vC3jcLvGaX5fbQ2NjIiUeciCXCwtOvPc03v37De1+8B4Db5d7+xWloaGhoaPQxA96Ctt0IgWIygMmArkWCovSqyGZrm8ONdLj5fetfuNTwP/Iur4tf1/2Ec3Ol33Vq9Fve+nY7hRCM33c84/cdz7X/dy37j96fT/77CaPGjmpz/h5778GgQYP46N2PWPLzEo445ohAPFpnGVIwhK8+/Sps7I8lf+zwZ2iPUeNGseD7BZx+3umtjq1dtZaqyipuufcWcgb5EgtWf7S629egoaGhoaHRW+x6Am0bCJ2CiDChRARdiT9evggA6fQEXaV+8eattePx2EIuIFBM+qDL1OwTborZAErP9nVc+stSfvzuR6YcOoWklCT++fMfSopKGFYwbJvnzThhBm++9CZFG4t49cNXt/u+Z11wFnMem8Pt193O6eeezqoVq3j9+deB7u1ledVNV3HaMacxaPAgpp88HSkl87+Zzxnnn0FGdgYmk4mXnn6JWRfP4t+V/3L/7fd32701NDQ0NDR6m53SxdkTCJMendWCIdmKKTsB89BUIsZkETE2G8uwNEw5iRiSrQiTAdXuxl1eh3PjVuyrSmn8axNNy4txrK3AVVyNp7IBr82B9Hg7vnEnscZY+W3xb5w67VT2HrY3t113G9fceg0zT5u5zfNOOOEE1q5eizXGyuRDJm/3fbNysnjl/Vf46rOvmDxuMnOemMO1/3ctACZz+zFz28shUw/htQ9f47uvvmPKHlM45sBjWPTDIhRFITEpkadeeYov537J/qP25+G7Hw5LZNDQ0NDQ0BhoiO6oFN+X5A/Ll7ctb7usw4g1Ixg0fFAvr8iPlKhOD9LRbHVzoTo8SIcLqQb3XOhCkhRCrG695S7V2XR4o7pPKAI89+RzPHD7A6ytWttmRuZAYsOqDRQOKWz3eNSiKGz729o9rtE+2t51DW3/uoa2fzuOtndd4yzjWb9LKffsaJ7m4uwp/KU9MBvQtTgk3Z4wV6nqcOOtt+OpsrU6v/khQl7Tja7D7uClZ1
5i3J7jSEhK4PdffufRex7lpDNPGvDiTENDQ0NDo6/QBFofIAx6dAY9umhL2Lj0qi0sbm7UJieemsaweYrJgDDrgzF
"text/plain": [
"<Figure size 720x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from matplotlib.colors import ListedColormap\n",
"\n",
"custom_cmap = ListedColormap([\"#fafab0\", \"#9898ff\", \"#a0faa0\"])\n",
"\n",
"x0, x1 = np.meshgrid(np.linspace(0, 8, 500).reshape(-1, 1),\n",
"                     np.linspace(0, 3.5, 200).reshape(-1, 1))\n",
"X_new = np.c_[x0.ravel(), x1.ravel()]\n",
"X_new = (X_new - mean) / std\n",
"X_new_with_bias = np.c_[np.ones(len(X_new)), X_new]\n",
"\n",
"logits = X_new_with_bias @ Theta\n",
"Y_proba = softmax(logits)\n",
"y_predict = Y_proba.argmax(axis=1)\n",
"\n",
"zz1 = Y_proba[:, 1].reshape(x0.shape)\n",
"zz = y_predict.reshape(x0.shape)\n",
"\n",
"plt.figure(figsize=(10, 4))\n",
"plt.plot(X[y == 2, 0], X[y == 2, 1], \"g^\", label=\"Iris virginica\")\n",
"plt.plot(X[y == 1, 0], X[y == 1, 1], \"bs\", label=\"Iris versicolor\")\n",
"plt.plot(X[y == 0, 0], X[y == 0, 1], \"yo\", label=\"Iris setosa\")\n",
"\n",
"plt.contourf(x0, x1, zz, cmap=custom_cmap)\n",
"contour = plt.contour(x0, x1, zz1, cmap=\"hot\")\n",
"plt.clabel(contour, inline=1)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.legend(loc=\"upper left\")\n",
"plt.axis([0, 7, 0, 3.5])\n",
"plt.grid()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And now let's measure the final model's accuracy on the test set:"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9666666666666667"
]
},
"execution_count": 80,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"logits = X_test @ Theta\n",
"Y_proba = softmax(logits)\n",
"y_predict = Y_proba.argmax(axis=1)\n",
"\n",
"accuracy_score = (y_predict == y_test).mean()\n",
"accuracy_score"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Well, we get even better performance on the test set. This variability is likely due to the very small size of the dataset: depending on how you sample the training set, the validation set, and the test set, you can get quite different results. Try changing the random seed and running the code again a few times; you will see that the results vary."
]
},
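{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see this variability, here is a sketch that repeats a simplified version of the pipeline for a few different seeds: it trains on 90 randomly chosen instances and evaluates on 30 others, with scaling and $\\ell_2$-regularized Batch GD for a fixed number of epochs (no validation set and no early stopping, to keep it short). The function name `run_once`, its defaults, and the use of `np.random.default_rng()` are assumptions for illustration, so the numbers will not exactly match the ones above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"\n",
"def run_once(seed, eta=0.5, n_epochs=5001, alpha=0.01):\n",
"    X_all, y_all = load_iris(return_X_y=True)\n",
"    X_all = X_all[:, 2:]  # petal length and petal width\n",
"    rng = np.random.default_rng(seed)\n",
"    idx = rng.permutation(len(X_all))\n",
"    train_idx, test_idx = idx[:90], idx[-30:]\n",
"    # scale with training-set statistics only, then add the bias column\n",
"    mu, sigma = X_all[train_idx].mean(axis=0), X_all[train_idx].std(axis=0)\n",
"    X_s = np.c_[np.ones(len(X_all)), (X_all - mu) / sigma]\n",
"    Y = np.eye(3)[y_all]\n",
"    Theta = rng.standard_normal((3, 3))\n",
"    for _ in range(n_epochs):\n",
"        logits = X_s[train_idx] @ Theta\n",
"        exps = np.exp(logits - logits.max(axis=1, keepdims=True))\n",
"        P = exps / exps.sum(axis=1, keepdims=True)\n",
"        grads = X_s[train_idx].T @ (P - Y[train_idx]) / len(train_idx)\n",
"        grads[1:] += alpha * Theta[1:]  # no penalty on the bias row\n",
"        Theta -= eta * grads\n",
"    y_pred = (X_s[test_idx] @ Theta).argmax(axis=1)\n",
"    return (y_pred == y_all[test_idx]).mean()\n",
"\n",
"print([round(run_once(seed), 3) for seed in range(5)])"
]
},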
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"nav_menu": {},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}