{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Chapter 4 – Training Models**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_This notebook contains all the sample code and solutions to the exercises in chapter 4._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table align=\"left\">\n",
" <td>\n",
" <a href=\"https://colab.research.google.com/github/ageron/handson-ml3/blob/main/04_training_linear_models.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/ageron/handson-ml3/blob/main/04_training_linear_models.ipynb\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" /></a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This project requires Python 3.7 or above:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"assert sys.version_info >= (3, 7)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It also requires Scikit-Learn ≥ 1.0.1:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import sklearn\n",
"\n",
"assert sklearn.__version__ >= \"1.0.1\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we did in previous chapters, let's define the default font sizes to make the figures prettier:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"plt.rc('font', size=14)\n",
"plt.rc('axes', labelsize=14, titlesize=14)\n",
"plt.rc('legend', fontsize=14)\n",
"plt.rc('xtick', labelsize=10)\n",
"plt.rc('ytick', labelsize=10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And let's create the `images/training_linear_models` folder (if it doesn't already exist), and define the `save_fig()` function, which is used throughout this notebook to save the figures in high resolution for the book:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"IMAGES_PATH = Path() / \"images\" / \"training_linear_models\"\n",
"IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
" if tight_layout:\n",
" plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Linear Regression"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The Normal Equation"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"np.random.seed(42) # to make this code example reproducible\n",
"m = 100 # number of instances\n",
"X = 2 * np.random.rand(m, 1) # column vector\n",
"y = 4 + 3 * X + np.random.randn(m, 1) # column vector"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – generates and saves Figure 4–1\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"plt.figure(figsize=(6, 4))\n",
"plt.plot(X, y, \"b.\")\n",
"plt.xlabel(\"$x_1$\")\n",
"plt.ylabel(\"$y$\", rotation=0)\n",
"plt.axis([0, 2, 0, 15])\n",
"plt.grid()\n",
"save_fig(\"generated_data_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.preprocessing import add_dummy_feature\n",
"\n",
"X_b = add_dummy_feature(X) # add x0 = 1 to each instance\n",
"theta_best = np.linalg.inv(X_b.T @ X_b) @ X_b.T @ y"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[4.21509616],\n",
" [2.77011339]])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"theta_best"
]
},
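{
"cell_type": "markdown",
"metadata": {},
"source": [
"Explicitly inverting $\\mathbf{X}^\\intercal \\mathbf{X}$ is not strictly necessary to apply the Normal Equation: an alternative is to solve the linear system $\\mathbf{X}^\\intercal \\mathbf{X} \\boldsymbol{\\theta} = \\mathbf{X}^\\intercal \\mathbf{y}$ with `np.linalg.solve()`, which avoids forming the inverse and is usually more numerically stable. The following is only a sketch: it generates a similar dataset with a local `np.random.default_rng(42)` generator (so the values differ slightly from `X` and `y` above), and the names `X_demo`, `y_demo`, `X_demo_b`, `theta_solve` and `theta_inv` are hypothetical, chosen so that the notebook's own variables are not overwritten:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# hypothetical demo data, generated like X and y above but with a local RNG\n",
"rng = np.random.default_rng(42)\n",
"m_demo = 100\n",
"X_demo = 2 * rng.random((m_demo, 1))\n",
"y_demo = 4 + 3 * X_demo + rng.standard_normal((m_demo, 1))\n",
"X_demo_b = np.c_[np.ones((m_demo, 1)), X_demo]  # add x0 = 1 to each instance\n",
"\n",
"# solve (X^T X) theta = X^T y without computing the inverse explicitly\n",
"theta_solve = np.linalg.solve(X_demo_b.T @ X_demo_b, X_demo_b.T @ y_demo)\n",
"theta_inv = np.linalg.inv(X_demo_b.T @ X_demo_b) @ X_demo_b.T @ y_demo\n",
"print(theta_solve.ravel())\n",
"print(np.allclose(theta_solve, theta_inv))  # both approaches should agree"
]
},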
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[4.21509616],\n",
" [9.75532293]])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_new = np.array([[0], [2]])\n",
"X_new_b = add_dummy_feature(X_new) # add x0 = 1 to each instance\n",
"y_predict = X_new_b @ theta_best\n",
"y_predict"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"plt.figure(figsize=(6, 4)) # extra code – not needed, just formatting\n",
"plt.plot(X_new, y_predict, \"r-\", label=\"Predictions\")\n",
"plt.plot(X, y, \"b.\")\n",
"\n",
"# extra code – beautifies and saves Figure 4–2\n",
"plt.xlabel(\"$x_1$\")\n",
"plt.ylabel(\"$y$\", rotation=0)\n",
"plt.axis([0, 2, 0, 15])\n",
"plt.grid()\n",
"plt.legend(loc=\"upper left\")\n",
"save_fig(\"linear_model_predictions_plot\")\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([4.21509616]), array([[2.77011339]]))"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.linear_model import LinearRegression\n",
"\n",
"lin_reg = LinearRegression()\n",
"lin_reg.fit(X, y)\n",
"lin_reg.intercept_, lin_reg.coef_"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[4.21509616],\n",
" [9.75532293]])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lin_reg.predict(X_new)"
]
},
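{
"cell_type": "markdown",
"metadata": {},
"source": [
"For a linear model, `predict()` computes $\\hat{y} = \\theta_0 + \\theta_1 x_1$, so the predictions can be reproduced by hand from the learned `intercept_` and `coef_` attributes. A minimal sketch on hypothetical demo data (`X_demo`, `y_demo`, `model`, `X_query` and `y_manual` are illustrative names, not part of the original notebook):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.linear_model import LinearRegression\n",
"\n",
"# hypothetical demo data, similar to X and y above\n",
"rng = np.random.default_rng(42)\n",
"X_demo = 2 * rng.random((100, 1))\n",
"y_demo = 4 + 3 * X_demo + rng.standard_normal((100, 1))\n",
"\n",
"model = LinearRegression().fit(X_demo, y_demo)\n",
"X_query = np.array([[0], [2]])\n",
"\n",
"# y_hat = intercept + X @ coef^T (coef_ has shape (1, 1) since y_demo is 2D)\n",
"y_manual = model.intercept_ + X_query @ model.coef_.T\n",
"print(np.allclose(y_manual, model.predict(X_query)))"
]
},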
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `LinearRegression` class is based on the `scipy.linalg.lstsq()` function (the name stands for \"least squares\"). NumPy provides a similar function, `np.linalg.lstsq()`, which you can call directly:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[4.21509616],\n",
" [2.77011339]])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"theta_best_svd, residuals, rank, s = np.linalg.lstsq(X_b, y, rcond=1e-6)\n",
"theta_best_svd"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This function computes $\\mathbf{X}^+\\mathbf{y}$, where $\\mathbf{X}^{+}$ is the _pseudoinverse_ of $\\mathbf{X}$ (specifically the Moore-Penrose inverse). You can use `np.linalg.pinv()` to compute the pseudoinverse directly:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[4.21509616],\n",
" [2.77011339]])"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.linalg.pinv(X_b) @ y"
]
},
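{
"cell_type": "markdown",
"metadata": {},
"source": [
"`np.linalg.pinv()` computes the pseudoinverse using a Singular Value Decomposition (SVD): if $\\mathbf{X} = \\mathbf{U}\\boldsymbol{\\Sigma}\\mathbf{V}^\\intercal$, then $\\mathbf{X}^+ = \\mathbf{V}\\boldsymbol{\\Sigma}^+\\mathbf{U}^\\intercal$, where $\\boldsymbol{\\Sigma}^+$ is obtained by inverting the singular values that exceed a small cutoff and setting the others to zero. The sketch below rebuilds $\\mathbf{X}^+$ by hand on hypothetical demo data; the cutoff of `1e-15` times the largest singular value is an assumed value, and all variable names are illustrative:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# hypothetical demo data, similar to X_b and y above\n",
"rng = np.random.default_rng(42)\n",
"X_demo = 2 * rng.random((100, 1))\n",
"y_demo = 4 + 3 * X_demo + rng.standard_normal((100, 1))\n",
"X_demo_b = np.c_[np.ones((100, 1)), X_demo]  # add x0 = 1 to each instance\n",
"\n",
"# reduced SVD: X = U @ diag(s) @ Vt\n",
"U_demo, s_demo, Vt_demo = np.linalg.svd(X_demo_b, full_matrices=False)\n",
"\n",
"# invert the singular values above the (assumed) cutoff, zero out the rest\n",
"cutoff = 1e-15 * s_demo.max()\n",
"s_plus = np.zeros_like(s_demo)\n",
"s_plus[s_demo > cutoff] = 1 / s_demo[s_demo > cutoff]\n",
"\n",
"X_pinv = Vt_demo.T @ np.diag(s_plus) @ U_demo.T\n",
"print(np.allclose(X_pinv, np.linalg.pinv(X_demo_b)))\n",
"print((X_pinv @ y_demo).ravel())  # least-squares estimate of theta"
]
},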
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Gradient Descent\n",
"## Batch Gradient Descent"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"eta = 0.1 # learning rate\n",
"n_epochs = 1000\n",
"m = len(X_b) # number of instances\n",
"\n",
"np.random.seed(42)\n",
"theta = np.random.randn(2, 1) # randomly initialized model parameters\n",
"\n",
"for epoch in range(n_epochs):\n",
" gradients = 2 / m * X_b.T @ (X_b @ theta - y)\n",
" theta = theta - eta * gradients"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The trained model parameters:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[4.21509616],\n",
" [2.77011339]])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"theta"
]
},
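{
"cell_type": "markdown",
"metadata": {},
"source": [
"The gradient vector used above, $\\nabla_{\\boldsymbol{\\theta}}\\,\\text{MSE}(\\boldsymbol{\\theta}) = \\frac{2}{m}\\mathbf{X}^\\intercal(\\mathbf{X}\\boldsymbol{\\theta} - \\mathbf{y})$, can be sanity-checked numerically with central finite differences. This is a sketch under assumptions: the demo data, the helper `mse_demo()`, the test point `theta_test` and the step `h = 1e-6` are illustrative choices, not part of the original notebook:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# hypothetical demo data, similar to X_b and y above\n",
"rng = np.random.default_rng(42)\n",
"m_demo = 100\n",
"X_demo = 2 * rng.random((m_demo, 1))\n",
"y_demo = 4 + 3 * X_demo + rng.standard_normal((m_demo, 1))\n",
"X_demo_b = np.c_[np.ones((m_demo, 1)), X_demo]  # add x0 = 1 to each instance\n",
"\n",
"def mse_demo(theta_vec):  # hypothetical helper: MSE of the linear model\n",
"    return np.mean((X_demo_b @ theta_vec - y_demo) ** 2)\n",
"\n",
"theta_test = np.array([[1.0], [-1.0]])  # arbitrary point at which to check\n",
"analytic = 2 / m_demo * X_demo_b.T @ (X_demo_b @ theta_test - y_demo)\n",
"\n",
"h = 1e-6  # assumed finite-difference step\n",
"numeric = np.zeros_like(theta_test)\n",
"for j in range(len(theta_test)):\n",
"    step = np.zeros_like(theta_test)\n",
"    step[j] = h\n",
"    numeric[j] = (mse_demo(theta_test + step) - mse_demo(theta_test - step)) / (2 * h)\n",
"\n",
"print(np.allclose(analytic, numeric, atol=1e-4))"
]
},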
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 720x288 with 3 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – generates and saves Figure 4–8\n",
"\n",
"import matplotlib as mpl\n",
"\n",
"def plot_gradient_descent(theta, eta):\n",
" m = len(X_b)\n",
" plt.plot(X, y, \"b.\")\n",
" n_epochs = 1000\n",
" n_shown = 20\n",
" theta_path = []\n",
" for epoch in range(n_epochs):\n",
" if epoch < n_shown:\n",
" y_predict = X_new_b @ theta\n",
" color = mpl.colors.rgb2hex(plt.cm.OrRd(epoch / n_shown + 0.15))\n",
" plt.plot(X_new, y_predict, linestyle=\"solid\", color=color)\n",
" gradients = 2 / m * X_b.T @ (X_b @ theta - y)\n",
" theta = theta - eta * gradients\n",
" theta_path.append(theta)\n",
" plt.xlabel(\"$x_1$\")\n",
" plt.axis([0, 2, 0, 15])\n",
" plt.grid()\n",
" plt.title(fr\"$\\eta = {eta}$\")\n",
" return theta_path\n",
"\n",
"np.random.seed(42)\n",
"theta = np.random.randn(2, 1) # random initialization\n",
"\n",
"plt.figure(figsize=(10, 4))\n",
"plt.subplot(131)\n",
"plot_gradient_descent(theta, eta=0.02)\n",
"plt.ylabel(\"$y$\", rotation=0)\n",
"plt.subplot(132)\n",
"theta_path_bgd = plot_gradient_descent(theta, eta=0.1)\n",
"plt.gca().axes.yaxis.set_ticklabels([])\n",
"plt.subplot(133)\n",
"plt.gca().axes.yaxis.set_ticklabels([])\n",
"plot_gradient_descent(theta, eta=0.5)\n",
"save_fig(\"gradient_descent_plot\")\n",
"plt.show()"
]
},
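{
"cell_type": "markdown",
"metadata": {},
"source": [
"Instead of running a fixed number of epochs, a common variant of Batch Gradient Descent stops as soon as the gradient vector becomes tiny, i.e., when its norm drops below a tolerance $\\epsilon$, since further steps would then barely move $\\boldsymbol{\\theta}$. A minimal sketch on hypothetical demo data, assuming a tolerance of `1e-6` and a cap of 10,000 epochs (both arbitrary choices, as are the variable names):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# hypothetical demo data, similar to X_b and y above\n",
"rng = np.random.default_rng(42)\n",
"m_demo = 100\n",
"X_demo = 2 * rng.random((m_demo, 1))\n",
"y_demo = 4 + 3 * X_demo + rng.standard_normal((m_demo, 1))\n",
"X_demo_b = np.c_[np.ones((m_demo, 1)), X_demo]  # add x0 = 1 to each instance\n",
"\n",
"eta_demo = 0.1  # learning rate\n",
"tolerance = 1e-6  # assumed threshold on the gradient norm\n",
"max_epochs = 10_000  # assumed safety cap\n",
"\n",
"theta_demo = rng.standard_normal((2, 1))  # random initialization\n",
"for epoch_demo in range(max_epochs):\n",
"    gradients_demo = 2 / m_demo * X_demo_b.T @ (X_demo_b @ theta_demo - y_demo)\n",
"    if np.linalg.norm(gradients_demo) < tolerance:\n",
"        break  # converged: the gradient is (almost) zero\n",
"    theta_demo = theta_demo - eta_demo * gradients_demo\n",
"\n",
"theta_exact = np.linalg.lstsq(X_demo_b, y_demo, rcond=None)[0]\n",
"print(epoch_demo, np.allclose(theta_demo, theta_exact, atol=1e-4))"
]
},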
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Stochastic Gradient Descent"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"theta_path_sgd = [] # extra code – we need to store the path of theta in the\n",
" # parameter space to plot the next figure"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"n_epochs = 50\n",
"t0, t1 = 5, 50 # learning schedule hyperparameters\n",
"\n",
"def learning_schedule(t):\n",
" return t0 / (t + t1)\n",
"\n",
"np.random.seed(42)\n",
"theta = np.random.randn(2, 1) # random initialization\n",
"\n",
"n_shown = 20 # extra code – just needed to generate the figure below\n",
"plt.figure(figsize=(6, 4)) # extra code – not needed, just formatting\n",
"\n",
"for epoch in range(n_epochs):\n",
2021-11-03 23:35:15 +01:00
" for iteration in range(m):\n",
"\n",
2022-02-19 06:17:36 +01:00
" # extra code – these 4 lines are used to generate the figure\n",
2021-11-03 23:35:15 +01:00
" if epoch == 0 and iteration < n_shown:\n",
" y_predict = X_new_b @ theta\n",
" color = mpl.colors.rgb2hex(plt.cm.OrRd(iteration / n_shown + 0.15))\n",
" plt.plot(X_new, y_predict, color=color)\n",
"\n",
2017-05-29 23:20:14 +02:00
" random_index = np.random.randint(m)\n",
2021-11-03 23:35:15 +01:00
" xi = X_b[random_index : random_index + 1]\n",
" yi = y[random_index : random_index + 1]\n",
2022-02-19 06:17:36 +01:00
" gradients = 2 * xi.T @ (xi @ theta - yi) # for SGD, do not divide by m\n",
2021-11-03 23:35:15 +01:00
" eta = learning_schedule(epoch * m + iteration)\n",
2016-05-22 16:01:18 +02:00
" theta = theta - eta * gradients\n",
2022-02-19 06:17:36 +01:00
" theta_path_sgd.append(theta) # extra code – to generate the figure\n",
2017-05-29 23:20:14 +02:00
"\n",
2022-02-19 06:17:36 +01:00
"# extra code – this section beautifies and saves Figure 4– 10\n",
2021-11-03 23:35:15 +01:00
"plt.plot(X, y, \"b.\")\n",
"plt.xlabel(\"$x_1$\")\n",
"plt.ylabel(\"$y$\", rotation=0)\n",
"plt.axis([0, 2, 0, 15])\n",
"plt.grid()\n",
"save_fig(\"sgd_plot\")\n",
"plt.show()"
2016-05-22 16:01:18 +02:00
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"array([[4.21076011],\n",
"       [2.74856079]])"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"theta"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"SGDRegressor(n_iter_no_change=100, penalty=None, random_state=42, tol=1e-05)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.linear_model import SGDRegressor\n",
"\n",
"sgd_reg = SGDRegressor(max_iter=1000, tol=1e-5, penalty=None, eta0=0.01,\n",
"                       n_iter_no_change=100, random_state=42)\n",
"sgd_reg.fit(X, y.ravel())  # y.ravel() because fit() expects 1D targets"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([4.21278812]), array([2.77270267]))"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sgd_reg.intercept_, sgd_reg.coef_"
]
},
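{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check (not in the book): the cell below compares `SGDRegressor` with the closed-form least-squares solution from `np.linalg.lstsq()`. It is a minimal sketch on freshly generated data that only assumes the same kind of linear relationship as above (y = 4 + 3x + Gaussian noise), and it uses its own variable names (`X_demo`, `y_demo`, `m_demo`, ...) so that it does not overwrite `X`, `y` or `m`. SGD should end up close to, but not exactly at, the least-squares parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.linear_model import SGDRegressor\n",
"\n",
"# Synthetic linear data (assumption: same kind as above, y = 4 + 3x + noise)\n",
"rng_demo = np.random.default_rng(42)\n",
"m_demo = 200\n",
"X_demo = 2 * rng_demo.random((m_demo, 1))\n",
"y_demo = 4 + 3 * X_demo[:, 0] + rng_demo.standard_normal(m_demo)\n",
"\n",
"# Closed-form least squares on [1, x] gives the reference parameters\n",
"X_demo_b = np.c_[np.ones((m_demo, 1)), X_demo]\n",
"theta_lstsq, *_ = np.linalg.lstsq(X_demo_b, y_demo, rcond=None)\n",
"\n",
"# Same SGDRegressor hyperparameters as in the cell above\n",
"sgd_demo = SGDRegressor(max_iter=1000, tol=1e-5, penalty=None, eta0=0.01,\n",
"                        n_iter_no_change=100, random_state=42)\n",
"sgd_demo.fit(X_demo, y_demo)\n",
"theta_sgd = np.r_[sgd_demo.intercept_, sgd_demo.coef_]\n",
"\n",
"print('least squares:', theta_lstsq)\n",
"print('SGD:          ', theta_sgd)\n",
"print('max abs diff: ', np.abs(theta_lstsq - theta_sgd).max())"
]
},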
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Mini-batch gradient descent"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The code in this section is used to generate the next figure; it is not in the book."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 504x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – this cell generates and saves Figure 4–11\n",
"\n",
"from math import ceil\n",
"\n",
"n_epochs = 50\n",
"minibatch_size = 20\n",
"n_batches_per_epoch = ceil(m / minibatch_size)\n",
"\n",
"np.random.seed(42)\n",
"theta = np.random.randn(2, 1)  # random initialization\n",
"\n",
"t0, t1 = 200, 1000  # learning schedule hyperparameters\n",
"\n",
"def learning_schedule(t):\n",
"    return t0 / (t + t1)\n",
"\n",
"theta_path_mgd = []\n",
"for epoch in range(n_epochs):\n",
"    shuffled_indices = np.random.permutation(m)\n",
"    X_b_shuffled = X_b[shuffled_indices]\n",
"    y_shuffled = y[shuffled_indices]\n",
"    for iteration in range(0, n_batches_per_epoch):\n",
"        idx = iteration * minibatch_size\n",
"        xi = X_b_shuffled[idx : idx + minibatch_size]\n",
"        yi = y_shuffled[idx : idx + minibatch_size]\n",
"        gradients = 2 / minibatch_size * xi.T @ (xi @ theta - yi)\n",
"        # global step count, so the schedule does not restart every epoch\n",
"        eta = learning_schedule(epoch * n_batches_per_epoch + iteration)\n",
"        theta = theta - eta * gradients\n",
"        theta_path_mgd.append(theta)\n",
"\n",
"theta_path_bgd = np.array(theta_path_bgd)\n",
"theta_path_sgd = np.array(theta_path_sgd)\n",
"theta_path_mgd = np.array(theta_path_mgd)\n",
"\n",
"plt.figure(figsize=(7, 4))\n",
"plt.plot(theta_path_sgd[:, 0], theta_path_sgd[:, 1], \"r-s\", linewidth=1,\n",
"         label=\"Stochastic\")\n",
"plt.plot(theta_path_mgd[:, 0], theta_path_mgd[:, 1], \"g-+\", linewidth=2,\n",
"         label=\"Mini-batch\")\n",
"plt.plot(theta_path_bgd[:, 0], theta_path_bgd[:, 1], \"b-o\", linewidth=3,\n",
"         label=\"Batch\")\n",
"plt.legend(loc=\"upper left\")\n",
"plt.xlabel(r\"$\\theta_0$\")\n",
"plt.ylabel(r\"$\\theta_1$ \", rotation=0)\n",
"plt.axis([2.6, 4.6, 2.3, 3.4])\n",
"plt.grid()\n",
"save_fig(\"gradient_descent_paths_plot\")\n",
"plt.show()"
]
},
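{
"cell_type": "markdown",
"metadata": {},
"source": [
"The mini-batch loop above indexes the batches by hand. Below is a minimal sketch (not in the book) of the same algorithm with the batching factored into a generator. The helper names `iterate_minibatches()` and `minibatch_gd()` are hypothetical, the data is freshly generated (assuming y = 4 + 3x + Gaussian noise), and the learning schedule is driven by a step counter that keeps growing across epochs. Since the last batch can be smaller than `batch_size` when the dataset size is not a multiple of it, the gradient is divided by `len(xi)` rather than by the nominal batch size."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def iterate_minibatches(X, y, batch_size, rng):\n",
"    # Yield (X_batch, y_batch) pairs covering one shuffled pass over the data;\n",
"    # the last batch is smaller when len(X) is not a multiple of batch_size\n",
"    indices = rng.permutation(len(X))\n",
"    for start in range(0, len(X), batch_size):\n",
"        batch_idx = indices[start:start + batch_size]\n",
"        yield X[batch_idx], y[batch_idx]\n",
"\n",
"def minibatch_gd(X_b, y, n_epochs=50, batch_size=20, t0=200, t1=1000, seed=42):\n",
"    # Mini-batch GD for linear regression (MSE cost), eta = t0 / (step + t1)\n",
"    rng = np.random.default_rng(seed)\n",
"    theta = rng.standard_normal((X_b.shape[1], 1))  # random initialization\n",
"    step = 0  # counts updates across all epochs, so the schedule never restarts\n",
"    for _ in range(n_epochs):\n",
"        for xi, yi in iterate_minibatches(X_b, y, batch_size, rng):\n",
"            gradients = 2 / len(xi) * xi.T @ (xi @ theta - yi)\n",
"            eta = t0 / (step + t1)\n",
"            theta = theta - eta * gradients\n",
"            step += 1\n",
"    return theta\n",
"\n",
"# Usage on freshly generated data (assumption: y = 4 + 3x + noise)\n",
"rng_mb = np.random.default_rng(0)\n",
"X_mb = 2 * rng_mb.random((100, 1))\n",
"y_mb = 4 + 3 * X_mb + rng_mb.standard_normal((100, 1))\n",
"X_mb_b = np.c_[np.ones((100, 1)), X_mb]\n",
"theta_mb = minibatch_gd(X_mb_b, y_mb)\n",
"theta_mb_exact, *_ = np.linalg.lstsq(X_mb_b, y_mb, rcond=None)\n",
"print('mini-batch GD:', theta_mb.ravel())\n",
"print('least squares:', theta_mb_exact.ravel())"
]
},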
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Polynomial Regression"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(42)\n",
"m = 100\n",
"X = 6 * np.random.rand(m, 1) - 3\n",
"y = 0.5 * X ** 2 + X + 2 + np.random.randn(m, 1)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – this cell generates and saves Figure 4–12\n",
"plt.figure(figsize=(6, 4))\n",
"plt.plot(X, y, \"b.\")\n",
"plt.xlabel(\"$x_1$\")\n",
"plt.ylabel(\"$y$\", rotation=0)\n",
"plt.axis([-3, 3, 0, 10])\n",
"plt.grid()\n",
"save_fig(\"quadratic_data_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([-0.75275929])"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.preprocessing import PolynomialFeatures\n",
"\n",
"poly_features = PolynomialFeatures(degree=2, include_bias=False)\n",
"X_poly = poly_features.fit_transform(X)\n",
"X[0]"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([-0.75275929, 0.56664654])"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_poly[0]"
]
},
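{
"cell_type": "markdown",
"metadata": {},
"source": [
"`X_poly` contains the original feature followed by its square. When there are several input features, `PolynomialFeatures` also adds all the cross terms up to the given degree. The sketch below (not in the book) shows this for one instance with two hypothetical features `a` = 2 and `b` = 3; it relies on `get_feature_names_out()`, which assumes Scikit-Learn 1.0 or later."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"\n",
"# One instance with two hypothetical features a=2 and b=3\n",
"X_ab = np.array([[2.0, 3.0]])\n",
"pf_demo = PolynomialFeatures(degree=2, include_bias=False)\n",
"X_ab_poly = pf_demo.fit_transform(X_ab)\n",
"print(pf_demo.get_feature_names_out(['a', 'b']))  # ['a' 'b' 'a^2' 'a b' 'b^2']\n",
"print(X_ab_poly)  # [[2. 3. 4. 6. 9.]]"
]
},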
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([1.78134581]), array([[0.93366893, 0.56456263]]))"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lin_reg = LinearRegression()\n",
"lin_reg.fit(X_poly, y)\n",
"lin_reg.intercept_, lin_reg.coef_"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – this cell generates and saves Figure 4–13\n",
"\n",
"X_new = np.linspace(-3, 3, 100).reshape(100, 1)\n",
"X_new_poly = poly_features.transform(X_new)\n",
"y_new = lin_reg.predict(X_new_poly)\n",
"\n",
"plt.figure(figsize=(6, 4))\n",
"plt.plot(X, y, \"b.\")\n",
"plt.plot(X_new, y_new, \"r-\", linewidth=2, label=\"Predictions\")\n",
"plt.xlabel(\"$x_1$\")\n",
"plt.ylabel(\"$y$\", rotation=0)\n",
"plt.legend(loc=\"upper left\")\n",
"plt.axis([-3, 3, 0, 10])\n",
"plt.grid()\n",
"save_fig(\"quadratic_predictions_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – this cell generates and saves Figure 4–14\n",
"\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.pipeline import make_pipeline\n",
"\n",
"plt.figure(figsize=(6, 4))\n",
"\n",
"for style, width, degree in ((\"r-+\", 2, 1), (\"b--\", 2, 2), (\"g-\", 1, 300)):\n",
"    polybig_features = PolynomialFeatures(degree=degree, include_bias=False)\n",
"    std_scaler = StandardScaler()\n",
"    lin_reg = LinearRegression()\n",
"    polynomial_regression = make_pipeline(polybig_features, std_scaler, lin_reg)\n",
"    polynomial_regression.fit(X, y)\n",
"    y_newbig = polynomial_regression.predict(X_new)\n",
"    label = f\"{degree} degree{'s' if degree > 1 else ''}\"\n",
"    plt.plot(X_new, y_newbig, style, label=label, linewidth=width)\n",
"\n",
"plt.plot(X, y, \"b.\", linewidth=3)\n",
"plt.legend(loc=\"upper left\")\n",
"plt.xlabel(\"$x_1$\")\n",
"plt.ylabel(\"$y$\", rotation=0)\n",
"plt.axis([-3, 3, 0, 10])\n",
"plt.grid()\n",
"save_fig(\"high_degree_polynomials_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Learning Curves"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from sklearn.model_selection import learning_curve\n",
"\n",
"train_sizes, train_scores, valid_scores = learning_curve(\n",
" LinearRegression(), X, y, train_sizes=np.linspace(0.01, 1.0, 40), cv=5,\n",
" scoring=\"neg_root_mean_squared_error\")\n",
"train_errors = -train_scores.mean(axis=1)\n",
"valid_errors = -valid_scores.mean(axis=1)\n",
"\n",
"plt.figure(figsize=(6, 4)) # extra code – not needed, just formatting\n",
"plt.plot(train_sizes, train_errors, \"r-+\", linewidth=2, label=\"train\")\n",
"plt.plot(train_sizes, valid_errors, \"b-\", linewidth=3, label=\"valid\")\n",
"\n",
"# extra code – beautifies and saves Figure 4–15\n",
"plt.xlabel(\"Training set size\")\n",
"plt.ylabel(\"RMSE\")\n",
"plt.grid()\n",
"plt.legend(loc=\"upper right\")\n",
"plt.axis([0, 80, 0, 2.5])\n",
"save_fig(\"underfitting_learning_curves_plot\")\n",
"\n",
"plt.show()"
]
},
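{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference: scikit-learn scorers follow a \"greater is better\" convention, so `scoring=\"neg_root_mean_squared_error\"` returns the negated RMSE for each fold. That is why the mean scores are negated before plotting, which recovers\n",
"\n",
"$$\\text{RMSE} = \\sqrt{\\frac{1}{m}\\sum_{i=1}^{m}\\left(\\hat{y}^{(i)} - y^{(i)}\\right)^2}$$"
]
},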
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import make_pipeline\n",
"\n",
"polynomial_regression = make_pipeline(\n",
" PolynomialFeatures(degree=10, include_bias=False),\n",
" LinearRegression())\n",
"\n",
"train_sizes, train_scores, valid_scores = learning_curve(\n",
" polynomial_regression, X, y, train_sizes=np.linspace(0.01, 1.0, 40), cv=5,\n",
" scoring=\"neg_root_mean_squared_error\")"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – generates and saves Figure 4–16\n",
"\n",
"train_errors = -train_scores.mean(axis=1)\n",
"valid_errors = -valid_scores.mean(axis=1)\n",
"\n",
"plt.figure(figsize=(6, 4))\n",
"plt.plot(train_sizes, train_errors, \"r-+\", linewidth=2, label=\"train\")\n",
"plt.plot(train_sizes, valid_errors, \"b-\", linewidth=3, label=\"valid\")\n",
"plt.legend(loc=\"upper right\")\n",
"plt.xlabel(\"Training set size\")\n",
"plt.ylabel(\"RMSE\")\n",
"plt.grid()\n",
"plt.axis([0, 80, 0, 2.5])\n",
"save_fig(\"learning_curves_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Regularized Linear Models"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Ridge Regression"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's generate a very small and noisy linear dataset:"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"# extra code – we've done this type of generation several times before\n",
"np.random.seed(42)\n",
"m = 20\n",
"X = 3 * np.random.rand(m, 1)\n",
"y = 1 + 0.5 * X + np.random.randn(m, 1) / 1.5\n",
"X_new = np.linspace(0, 3, 100).reshape(100, 1)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – a quick peek at the dataset we just generated\n",
"plt.figure(figsize=(6, 4))\n",
"plt.plot(X, y, \".\")\n",
"plt.xlabel(\"$x_1$\")\n",
"plt.ylabel(\"$y$ \", rotation=0)\n",
"plt.axis([0, 3, 0, 3.5])\n",
"plt.grid()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1.55325833]])"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.linear_model import Ridge\n",
"\n",
"ridge_reg = Ridge(alpha=0.1, solver=\"cholesky\")\n",
"ridge_reg.fit(X, y)\n",
"ridge_reg.predict([[1.5]])"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 648x252 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – this cell generates and saves Figure 4–17\n",
"\n",
"def plot_model(model_class, polynomial, alphas, **model_kwargs):\n",
" plt.plot(X, y, \"b.\", linewidth=3)\n",
" for alpha, style in zip(alphas, (\"b:\", \"g--\", \"r-\")):\n",
" if alpha > 0:\n",
" model = model_class(alpha, **model_kwargs)\n",
" else:\n",
" model = LinearRegression()\n",
" if polynomial:\n",
" model = make_pipeline(\n",
" PolynomialFeatures(degree=10, include_bias=False),\n",
" StandardScaler(),\n",
" model)\n",
" model.fit(X, y)\n",
" y_new_regul = model.predict(X_new)\n",
" plt.plot(X_new, y_new_regul, style, linewidth=2,\n",
" label=fr\"$\\alpha = {alpha}$\")\n",
" plt.legend(loc=\"upper left\")\n",
" plt.xlabel(\"$x_1$\")\n",
" plt.axis([0, 3, 0, 3.5])\n",
" plt.grid()\n",
"\n",
"plt.figure(figsize=(9, 3.5))\n",
"plt.subplot(121)\n",
"plot_model(Ridge, polynomial=False, alphas=(0, 10, 100), random_state=42)\n",
"plt.ylabel(\"$y$ \", rotation=0)\n",
"plt.subplot(122)\n",
"plot_model(Ridge, polynomial=True, alphas=(0, 10**-5, 1), random_state=42)\n",
"plt.gca().axes.yaxis.set_ticklabels([])\n",
"save_fig(\"ridge_regression_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1.55302613])"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sgd_reg = SGDRegressor(penalty=\"l2\", alpha=0.1 / m, tol=None,\n",
" max_iter=1000, eta0=0.01, random_state=42)\n",
"sgd_reg.fit(X, y.ravel()) # y.ravel() because fit() expects 1D targets\n",
"sgd_reg.predict([[1.5]])"
]
},
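{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why `alpha=0.1 / m`? Here is a short derivation, assuming the objectives given in the scikit-learn user guide. `Ridge` minimizes\n",
"\n",
"$$\\sum_{i=1}^{m}\\left(y^{(i)} - \\mathbf{w}^\\top\\mathbf{x}^{(i)} - b\\right)^2 + \\alpha_{\\text{ridge}}\\|\\mathbf{w}\\|_2^2$$\n",
"\n",
"while `SGDRegressor` with its default squared-error loss and `penalty=\"l2\"` minimizes\n",
"\n",
"$$\\frac{1}{m}\\sum_{i=1}^{m}\\frac{1}{2}\\left(y^{(i)} - \\mathbf{w}^\\top\\mathbf{x}^{(i)} - b\\right)^2 + \\frac{\\alpha_{\\text{sgd}}}{2}\\|\\mathbf{w}\\|_2^2$$\n",
"\n",
"Multiplying the second objective by $2m$ does not change its minimizer, and it gives $\\sum_{i=1}^{m}\\left(y^{(i)} - \\mathbf{w}^\\top\\mathbf{x}^{(i)} - b\\right)^2 + m\\,\\alpha_{\\text{sgd}}\\|\\mathbf{w}\\|_2^2$. Both models therefore target the same solution when $\\alpha_{\\text{sgd}} = \\alpha_{\\text{ridge}} / m$. With $\\alpha_{\\text{ridge}} = 0.1$, this gives `alpha=0.1 / m`. SGD only approximates that solution, so the predictions agree closely but not exactly."
]
},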
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1.55321535]])"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code – shows that we get roughly the same solution as earlier when\n",
"# we use Stochastic Average GD (solver=\"sag\")\n",
"ridge_reg = Ridge(alpha=0.1, solver=\"sag\", random_state=42)\n",
"ridge_reg.fit(X, y)\n",
"ridge_reg.predict([[1.5]])"
]
},
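{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the closed-form Ridge solution evaluated in the next cell is\n",
"\n",
"$$\\hat{\\boldsymbol{\\theta}} = \\left(\\mathbf{X}_b^\\top \\mathbf{X}_b + \\alpha \\mathbf{A}\\right)^{-1} \\mathbf{X}_b^\\top \\mathbf{y}$$\n",
"\n",
"where $\\mathbf{X}_b$ (`X_b`) is $\\mathbf{X}$ with a column of 1s prepended, and $\\mathbf{A}$ (`A`) is the identity matrix except for a 0 in its top-left cell. That 0 leaves the bias term $\\theta_0$ unregularized."
]
},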
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.97898394],\n",
" [0.3828496 ]])"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code – shows the closed-form solution of Ridge regression;\n",
"# compare it with the learned parameters of the Ridge model displayed below\n",
"alpha = 0.1\n",
"A = np.array([[0., 0.], [0., 1.]])\n",
"X_b = np.c_[np.ones(m), X]\n",
"np.linalg.inv(X_b.T @ X_b + alpha * A) @ X_b.T @ y"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([0.97944909]), array([[0.38251084]]))"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ridge_reg.intercept_, ridge_reg.coef_ # extra code"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Lasso Regression"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1.53788174])"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.linear_model import Lasso\n",
"\n",
"lasso_reg = Lasso(alpha=0.1)\n",
"lasso_reg.fit(X, y)\n",
"lasso_reg.predict([[1.5]])"
]
},
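{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, and assuming the objective given in the scikit-learn user guide, `Lasso` minimizes\n",
"\n",
"$$\\frac{1}{2m}\\sum_{i=1}^{m}\\left(y^{(i)} - \\mathbf{w}^\\top\\mathbf{x}^{(i)} - b\\right)^2 + \\alpha\\|\\mathbf{w}\\|_1$$\n",
"\n",
"Multiplying by 2 shows this is equivalent to minimizing $\\text{MSE} + 2\\alpha\\|\\mathbf{w}\\|_1$, with the bias $b$ left unregularized. `Ridge` does not scale its squared-error term by $\\frac{1}{2m}$, so `alpha` values are not directly comparable between `Ridge` and `Lasso`."
]
},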
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 648x252 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – this cell generates and saves Figure 4–18\n",
"plt.figure(figsize=(9, 3.5))\n",
"plt.subplot(121)\n",
"plot_model(Lasso, polynomial=False, alphas=(0, 0.1, 1), random_state=42)\n",
"plt.ylabel(\"$y$ \", rotation=0)\n",
"plt.subplot(122)\n",
"plot_model(Lasso, polynomial=True, alphas=(0, 1e-2, 1), random_state=42)\n",
"plt.gca().axes.yaxis.set_ticklabels([])\n",
"save_fig(\"lasso_regression_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 727.2x576 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – this BIG cell generates and saves Figure 4–19\n",
"\n",
"t1a, t1b, t2a, t2b = -1, 3, -1.5, 1.5\n",
"\n",
"t1s = np.linspace(t1a, t1b, 500)\n",
"t2s = np.linspace(t2a, t2b, 500)\n",
"t1, t2 = np.meshgrid(t1s, t2s)\n",
"T = np.c_[t1.ravel(), t2.ravel()]\n",
"Xr = np.array([[1, 1], [1, -1], [1, 0.5]])\n",
"yr = 2 * Xr[:, :1] + 0.5 * Xr[:, 1:]\n",
"\n",
"J = (1 / len(Xr) * ((T @ Xr.T - yr.T) ** 2).sum(axis=1)).reshape(t1.shape)\n",
"\n",
"N1 = np.linalg.norm(T, ord=1, axis=1).reshape(t1.shape)\n",
"N2 = np.linalg.norm(T, ord=2, axis=1).reshape(t1.shape)\n",
"\n",
"t_min_idx = np.unravel_index(J.argmin(), J.shape)\n",
"t1_min, t2_min = t1[t_min_idx], t2[t_min_idx]\n",
"\n",
"t_init = np.array([[0.25], [-1]])\n",
"\n",
"def bgd_path(theta, X, y, l1, l2, core=1, eta=0.05, n_iterations=200):\n",
" path = [theta]\n",
" for iteration in range(n_iterations):\n",
" gradients = (core * 2 / len(X) * X.T @ (X @ theta - y)\n",
" + l1 * np.sign(theta) + l2 * theta)\n",
" theta = theta - eta * gradients\n",
" path.append(theta)\n",
" return np.array(path)\n",
"\n",
"fig, axes = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(10.1, 8))\n",
"\n",
"for i, N, l1, l2, title in ((0, N1, 2.0, 0, \"Lasso\"), (1, N2, 0, 2.0, \"Ridge\")):\n",
" JR = J + l1 * N1 + l2 * 0.5 * N2 ** 2\n",
"\n",
" tr_min_idx = np.unravel_index(JR.argmin(), JR.shape)\n",
" t1r_min, t2r_min = t1[tr_min_idx], t2[tr_min_idx]\n",
"\n",
" levels = np.exp(np.linspace(0, 1, 20)) - 1\n",
" levelsJ = levels * (J.max() - J.min()) + J.min()\n",
" levelsJR = levels * (JR.max() - JR.min()) + JR.min()\n",
" levelsN = np.linspace(0, N.max(), 10)\n",
"\n",
" path_J = bgd_path(t_init, Xr, yr, l1=0, l2=0)\n",
" path_JR = bgd_path(t_init, Xr, yr, l1, l2)\n",
" path_N = bgd_path(theta=np.array([[2.0], [0.5]]), X=Xr, y=yr,\n",
" l1=np.sign(l1) / 3, l2=np.sign(l2), core=0)\n",
" ax = axes[i, 0]\n",
" ax.grid()\n",
" ax.axhline(y=0, color=\"k\")\n",
" ax.axvline(x=0, color=\"k\")\n",
" ax.contourf(t1, t2, N / 2.0, levels=levelsN)\n",
" ax.plot(path_N[:, 0], path_N[:, 1], \"y--\")\n",
" ax.plot(0, 0, \"ys\")\n",
" ax.plot(t1_min, t2_min, \"ys\")\n",
" ax.set_title(fr\"$\\ell_{i + 1}$ penalty\")\n",
" ax.axis([t1a, t1b, t2a, t2b])\n",
" if i == 1:\n",
" ax.set_xlabel(r\"$\\theta_1$\")\n",
" ax.set_ylabel(r\"$\\theta_2$\", rotation=0)\n",
"\n",
" ax = axes[i, 1]\n",
" ax.grid()\n",
" ax.axhline(y=0, color=\"k\")\n",
" ax.axvline(x=0, color=\"k\")\n",
" ax.contourf(t1, t2, JR, levels=levelsJR, alpha=0.9)\n",
" ax.plot(path_JR[:, 0], path_JR[:, 1], \"w-o\")\n",
" ax.plot(path_N[:, 0], path_N[:, 1], \"y--\")\n",
" ax.plot(0, 0, \"ys\")\n",
" ax.plot(t1_min, t2_min, \"ys\")\n",
" ax.plot(t1r_min, t2r_min, \"rs\")\n",
" ax.set_title(title)\n",
" ax.axis([t1a, t1b, t2a, t2b])\n",
" if i == 1:\n",
" ax.set_xlabel(r\"$\\theta_1$\")\n",
"\n",
"save_fig(\"lasso_vs_ridge_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Elastic Net"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1.54333232])"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.linear_model import ElasticNet\n",
"\n",
"elastic_net = ElasticNet(alpha=0.1, l1_ratio=0.5)\n",
"elastic_net.fit(X, y)\n",
"elastic_net.predict([[1.5]])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Early Stopping"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's go back to the quadratic dataset we used earlier:"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAEQCAYAAAD2/KAsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAABKjklEQVR4nO3dd3hTZfvA8e/dRQtl7ynK3kMEBH6AiqKiKO/rQFABGSqguHGLA8fr9nXhQPR1bxEZLiqiMmXIBgEVKBspld3evz+epElKKAHapE3uz3Wdq8k5J+c8edrmzrNFVTHGGGPCLS7SCTDGGBObLAAZY4yJCAtAxhhjIsICkDHGmIiwAGSMMSYiLAAZY4yJiLAFIBFJFpFZIrJARBaLyP1BzukqIjtFZL5nuzdc6TPGGBNeCWG81z7gdFXNFJFEYLqITFLVGbnO+1FVzwtjuowxxkRA2AKQuhGvmZ6niZ7NRsEaY0yMCmsbkIjEi8h8YDPwjarODHLaqZ5qukki0iSc6TPGGBM+EompeESkDPAZcJ2qLvLbXwrI9lTTnQs8q6r1grx+CDAEoDycXCspiX9q1845vmFDCpmZrnBXpcpeSpU6UIDvpnDKzs4mLs76mARjeROc5Utwli/BrVixYquqVjyea0QkAAGIyH3AP6r6RB7nrAXaqOrWw53TRkTnNGoES5bk7BsxAp57zj1+/HG45Zb8SnXRkZaWRteuXSOdjELJ8iY4y5fgLF+CE5G5qtrmeK4Rzl5wFT0lH0QkBegGLMt1ThUREc/jtp70bTvixQ8ElnCqVvU9Tk8/rmQbY4wpIOHsBVcVeFNE4nGB5UNVnSAi1wCo6svARcC1InIQ2AP01lCKaKtWwa5dULKku5EFIGOMKfTC2QtuIdAqyP6X/R4/Dzx/TDd4+WW49VbAApAxxhQF4SwBFYislBQYNiygoccCkDHGFH5FPgDtrlnT9TTwYwHIGGMKvyIfgIIpXx4SE13fhIwM2L0bihePdKqMCa+MjAw2b97MgQOhDUMoXbo0S5cuLeBUFT2xmi8lSpSgRo0aBdoFvcgHoPi9e12Rp1UrmDgRABGoUgX++sudk54OdepEMJHGhFlGRgabNm2ievXqpKSk4Olcmqddu3ZR0tORx/jEYr5kZ2ezfv16tm7dSqVKlQrsPtExumrjRtiyJWCXVcOZWLZ582aqV69O8eLFQwo+xviLi4ujcuXK7Ny5s2DvU6BXDwP1/nPt3h2w3wKQiWUHDhwgJSUl0skwRVhiYiIHDx4s0HsU+QCEt34yjwC0YUMY02NMIWElH3M8wvH3U+QDkB4mANWq5Xv8xx9hTJAxxpiQFPkAxGGq4PzmJmXt2rClxhgTYaNGjaJp06aHfR7M8OHD82W+t1DuZXyKfAAKKAH5zdrjH4CsBGRM4Xf++efTrVu3oMeWLl2KiPDNN98c9XVvueUWfvjhh+NNXoC1a9ciIsyZM6fA71UQunbtyvDhwyOdjKLfDRuAUaMgORmysyE+HrASkDFFzaBBg+jVqxdr166ltv8/MPD6669zwgkncMYZZxz1dVNTU0lNTc2nVBaee0WDIl8CAuC++2DkyJzgA1C5MhQr5h5v3+4GpBpjCq8ePXpQuXJl3njjjYD9Bw4c4H//+x9XXXUVqsrAgQM58cQTSUlJoV69evznP/8hOzv7sNfNXS2WlZXFLbfcQtmyZSlbtiw33HADWVlZAa+ZPHky//d//0fZsmWpVasW3bt3DxiMeuKJJwJwyimnICI51Xe575Wdnc2DDz5IzZo1KVasGM2aNeOLL77IOe4tSX3yySeceeaZFC9enMaNGx+xpDdt2jTat29PamoqpUuXpl27dixalLO0Gj///DNdunShePHiVK9enWuvvZYMz4dg//79+eGHH3jhhRcQEUSEtRH6lh4dASiIuDg44QTfc6uGM6ZwS0hIoF+/fowbNy4goHz55Zds3bqVAQMGkJ2dTfXq1f
nwww9ZunQpo0eP5uGHHz4kaOXlySef5NVXX2XMmDH88ssvZGVl8c477wSc888//3DDDTcwa9YsJk6cSOnSpTn//PPZv38/ALNmzQJcoEpPT+fTTz8Neq9nn32Wxx9/nMcee4zffvuNXr168a9//Yv58+cHnHfXXXdx/fXXs2DBAk455RR69+5NZmZm0GsePHiQCy64gE6dOrFgwQJmzpzJiBEjiPd8Af/tt98466yz6NmzJwsWLODTTz9l/vz5XHXVVTlpOvXUUxkwYADp6emkp6dTs2bNkPMvX6lqkd7q16+v+v33qh98oPr33+rvrLNUXcOQ6vjxGlOmTp0a6SQUWrGQN0uWLDlkn/d/IRJbqFasWKGATpkyJWffueeeq2efffZhXzNy5Eg944wzcp7fd9992qRJk8M+r1q1qj700EM5z7OysrRevXrapUuXoNfPyMjQzMxMjYuL0x9//FFVVdesWaOAzp49O+Dc3PeqVq2a3n///QHndOnSRfv27RtwnZdffjnn+Lp16xTIuVdu27ZtU0DT0tKCHr/iiiv0qquuCtg3b948BXTTpk05aRg2bFjQ1/sL9nfkBczR4/z8jo4S0A03wKWXHtLYY+1AxhQt9erVo3PnzowdOxaADRs2MGXKFAYNGpRzzssvv0ybNm2oWLEiqampPP300/z5558hXX/nzp2kp6dz6qmn5uyLi4ujXbt2Aef9/vvv9OnThzp16lC9enUqV65MdnZ2yPcBNx3Shg0b6NixY8D+Tp06scRvBWeA5s2b5zyuVq0a4GazCKZcuXL079+f7t2706NHD5566in+8s47BsydO5e33347pz0qNTU1Jw2///57yOkPh+gIQN6ZRvPoim1VcMYUDYMGDeLzzz9n+/btjBs3jnLlytGzZ08APvjgA2644Qb69+/PlClTmD9/PkOHDs2pGssv559/Plu2bGHMmDF8//33zJs3j4SEhGO6T7ABnbn3JSYmHnIsr3atN954g5kzZ9K5c2fGjx9P/fr1mTJlSs7rBg0axPz583O2BQsWsHLlSlq2bHnU6S9IMROArARkYt2RKsoyMnYVWCXc0bjoootITk7m7bffZuzYsVx55ZU5H9DTp0+nXbt2DB8+nNatW1O3bt2j+lZfunRpqlatyowZM/zyRXPadAC2bdvG0qVLufPOO+nWrRsNGjRg165dAdPSJCUlARzSecFfqVKlqFatGtOnTw/YP336dBo3bhxymg+nRYsWjBw5krS0NLp27cqbb74JQOvWrVm8eDF169Y9ZPNOz5SUlJRn2sMlOrphe2eq3bgxYLd/AFqzJnzJMcYcu5SUFPr06cOoUaPYsWMHAwcOzDlWv359xo0bx6RJk6hbty7vv/8+P/zwA2XLlg35+iNGjOCRRx6hfv36NGvWjBdffJH09HSqeubvKlu2LBUqVODVV1+lZs2arFy5kvvuu4+EBN/HZaVKlUhJSWHKlCnUrl2b5ORkSpcufci9br31Vu69917q1avHySefzNtvv82PP/7I3Llzjzl/1qxZw5gxY+jZsyfVq1dn9erVLFy4kGuvvRaAkSNH0r59e6655hquvvpqSpYsybJly/jyyy8ZM2YMALVr12bWrFmsXbuW1NRUypUrV6DLLhxOdJSAOnVyPydNCtjtvwTDihVH/03MGBMZgwYNYseOHXTo0IFGjRrl7L/66qu55JJL6NOnD6eccgpr167l5ptvPqpr33zzzQwYMIBBgwbRrl07srOz6du3b87xuLg4PvjgAxYuXEjTpk25+eabefDBBynmHdeB67H33HPP8dprr1GtWjUuuOCCoPe6/vrrufXWW7ntttto2rQpn332GZ988slxVYUVL16cFStWcPHFF1O/fn369etH3759GTlyJODak6ZNm8batWvp0qULLVq04I477qBy5co517jllltISkqicePGVKxY8ajatvKTaBH/VG7QoIEunzAB6teH0qXdoB9PJFeFChXcLn
DrA9WoEcHEhpG3WG4OFQt5s3Tp0oAP7lDE4ro3oYjlfMnr70hE5qpqm+O5fnSUgOrVc8ugJiTA33/n7BaBhg19p8X
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from copy import deepcopy\n",
"from sklearn.metrics import mean_squared_error\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# extra code – creates the same quadratic dataset as earlier and splits it\n",
"np.random.seed(42)\n",
"m = 100\n",
"X = 6 * np.random.rand(m, 1) - 3\n",
"y = 0.5 * X ** 2 + X + 2 + np.random.randn(m, 1)\n",
"X_train, y_train = X[: m // 2], y[: m // 2, 0]\n",
"X_valid, y_valid = X[m // 2 :], y[m // 2 :, 0]\n",
"\n",
"preprocessing = make_pipeline(PolynomialFeatures(degree=90, include_bias=False),\n",
" StandardScaler())\n",
"X_train_prep = preprocessing.fit_transform(X_train)\n",
"X_valid_prep = preprocessing.transform(X_valid)\n",
"sgd_reg = SGDRegressor(penalty=None, eta0=0.002, random_state=42)\n",
"n_epochs = 500\n",
"best_valid_rmse = float('inf')\n",
"train_errors, val_errors = [], [] # extra code – it's for the figure below\n",
"\n",
"for epoch in range(n_epochs):\n",
" sgd_reg.partial_fit(X_train_prep, y_train)\n",
" y_valid_predict = sgd_reg.predict(X_valid_prep)\n",
" val_error = mean_squared_error(y_valid, y_valid_predict, squared=False)\n",
" if val_error < best_valid_rmse:\n",
" best_valid_rmse = val_error\n",
" best_model = deepcopy(sgd_reg)\n",
"\n",
" # extra code – we evaluate the train error and save it for the figure\n",
" y_train_predict = sgd_reg.predict(X_train_prep)\n",
" train_error = mean_squared_error(y_train, y_train_predict, squared=False)\n",
" val_errors.append(val_error)\n",
" train_errors.append(train_error)\n",
"\n",
"# extra code – this section generates and saves Figure 4–20\n",
"best_epoch = np.argmin(val_errors)\n",
"plt.figure(figsize=(6, 4))\n",
"plt.annotate('Best model',\n",
" xy=(best_epoch, best_valid_rmse),\n",
" xytext=(best_epoch, best_valid_rmse + 0.5),\n",
" ha=\"center\",\n",
" arrowprops=dict(facecolor='black', shrink=0.05))\n",
"plt.plot([0, n_epochs], [best_valid_rmse, best_valid_rmse], \"k:\", linewidth=2)\n",
"plt.plot(val_errors, \"b-\", linewidth=3, label=\"Validation set\")\n",
"plt.plot(best_epoch, best_valid_rmse, \"bo\")\n",
"plt.plot(train_errors, \"r--\", linewidth=2, label=\"Training set\")\n",
"plt.legend(loc=\"upper right\")\n",
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"RMSE\")\n",
"plt.axis([0, n_epochs, 0, 3.5])\n",
"plt.grid()\n",
"save_fig(\"early_stopping_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Logistic Regression"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Estimating Probabilities"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjAAAADICAYAAAD2r9syAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAwS0lEQVR4nO3de3xU9Z3/8dc3F+6BcFGQW8EbQkHEC962emy1AnVxu9iWVrxQLYuLVdrlp9JusbtUl1Z6oa2sFSt0F2/11uIqeKkdQBENgQCJMEAChhAghHACIdeZfH9/fJMhCYEMkjCZ5P18PL6Pmc+ZcyafnByGz3zP93yPsdYiIiIiEk8SYp2AiIiIyKlSASMiIiJxRwWMiIiIxB0VMCIiIhJ3VMCIiIhI3FEBIyIiInEnKdYJNCY1NdWef/75sU6jTTl69Chdu3aNdRptjvZr8woGg4TDYUaMGBHrVNocHavNT/u0ZaSnpxdaa89qar1WWcD07duXdevWxTqNNiUQCOB5XqzTaHO0X5uX53n4vq9//y1Ax2rz0z5tGcaYz6JZT6eQREREJO6ogBEREZG4owJGRERE4o4KGBEREYk7KmBEREQk7rTKq5BOprq6mry8PI4ePRrrVOJKjx492LJlS6zTOOOSk5M5++yz6d69e6xTERGRZhR3BUxhYSHGGIYNG0ZCgjqQonXkyBFSUlJincYZZa2lrKyMPXv2AKiIERFpQ5qsAIwxzxpjCowxmSd43RhjfmuM2WGM2WSMubTOa+OMMcGa1x5pjoR936dv374qXqRJxhi6dOnCgAEDKCgoiHU6IiLSjKKpApYA407y+njggpo2DfhvAGNMIvBkzesjgG8bY057es1wOExycvLpvo20I507d6aqqirWaYiISDNq8hSStXaVMWbISVa5Ffgfa60F1hpjUo0x5wBDgB3W2hwAY8yLNet+erpJG2NO9y2kHdHxIiISW9ZCdTWEw+7xZC1azXEeZgCwu06cV7PsRMubtHv3bpYsWQJAVVUVnuexdOlSwA3iDQaDFBUVARAKhQgGgxw6dCiyfjAYxPf9enFxcTEAlZWVBINBDh8+DEBFRQXBYJAjR44AUF5eTjAYpKSkBICysjKCwWBk0HBpaSnBYJDS0lLA3QsjGAxSVlYGQElJCcFgkPLycsCNPQkGg1RUVABw+PBhgsEglZWVABQXFxMMBiM9BL7v14sPHTpEMBgkFAoBUFRUFLlfDMDBgwcJBoNU1/zVCwsLCQaDkX154MCBenFBQQHbtm2LxPv372f79u2ReN++fezYsSMS7927l+zs7Eicn59PTk5OJN6zZw87d+6MxHl5eezatave3/Kzz47NCp2bm0tubm4k/uyzz9i9+9hhsmvXLvLy8iLxzp07I2NYAHJycsjPz4/E2dnZ7N27NxLv2LGDffv2ReLt27ezf//+SDx+/HgWLlwYiW+88UYWLVoUiT3PO+GxV1paiud5vPTSS4D7282cOZPXXnsNcPve8zzeeOMNwO1Lz/NYsWJFZF94nsd7770X+V08z2PlypWAuw+Q53msWbMGgMzMTDzPIy0tDYCMjAw8zyMjIwOAtLQ0PM8jM9Od3V2zZg2e50X+3itXrsTzvMjf67333sPzvMj+XrFiBZ7nRfbXG2+8ged5FBYWAvDaa6/heV7k385LL72E53mRY3/p0qV4nhc5VpcsWVJvWvVFixZx4403RuKFCxcyfvz4SLxgwQImTpwYiefPn09WVlYknjdvHpMnT47Ec+fOZcqUKZF4zpw5TJ06NRLPnj2badOmReJZs2YxY8aMSDxz5kxmzpwZiWfMmMGsWbMi8bRp05g9e3Yknjp1KnPmzInEU6ZMYe7cuZF48uTJzJs3LxJPmjSJ+fPnR+KJEyeyYMGCSNzcx57nead07M2cOVPHXo3Gjr1JkyZF4lgee6EQTJkynVmz5pGXB9nZcMstD3PffYv46CMIBMDzHueuu1
7hlVfg+efh8st/z223vcPvfw+//jWMGLGECRM+Ys4cmD0bhg59Fc/bxH33wfe+B/37v81VV21n8mSYNAn69PmA0aN3c/PN8OUvV5OamsGFFxZw1VVw6aVhunXbweDBPsOHw/nnh+ncOZ+zzipl4EDo27eaDh0O0b17JT17QkpKNYmJZXTsGKZjR0hKshgDCQmQlAQdO0LnztC1K6SkQI8e0LMn9O4NZ50FffsSteYYxNvY11t7kuWNv4kx03CnoEhOTmbr1q0EAgFCoRC+77NlyxYCgUBkIGZZWRlHjhwhHA4TCoUoKysjKSmpXpyYmEgoFIrECQkJVFVVEQqFKC0txRhDZWVlJAZX0NTG1tpIfPToUaqrqykvL4/E4XCYsrKySFz7c0KhECUlJVRVVVFaWhqJKysr68XJycn11k9KSqr3fomJifVebxgnJCRE8ikpKcEYE4nrFmThcJhwOMyRI0cice3rFRUVx8V1t28Y1+6vunFVVVW9uO77VVZWYq2NxLUfOHVjY0y9uLq6+oTrh0IhKisr68UVFRUnjMPhMBUVFZSXlxMIBCgqKmLbtm0EAgHgWIFYG/u+f8Jjr7y8HN/3ycrKIhAIUFJSQjgcJjMzk169elFcXIzv+2zevJmUlBSKiorwfZ9NmzbRqVMnCgoK8H2fjRs3kpSURH5+Pr7vs2HDBqy15Obm4vs+69evp7Kykp07d+L7Punp6Rw9epQdO3ZE7hNUm6fv+6SlpVFYWEhmZia+7/Pxxx+zd+9eMjIy8H2ftWvXkpuby8aNG/F9n48++ojs7Gw2bdqE7/usWbOGXr16sXnzZnzf58MPP6RHjx6R91u9ejXdunUjKysL3/dZtWoVnTp1YsuWLfi+z8qVK0lKSorkU7sva79Y1Mbbtm2jqKgoEm/fvp2DBw9G4uzsbKqqqgiHwwQCAXJycigoKIi8vnPnTvbv3x+Jd+3axYEDByJxbm4uxcXFkXj37t1UVFRE4trCuDbes2cPHTt2jMT5+fkcPXo0Eu/bty+SC7hiv0OHDpG4oKCAnJycSHzgwAGys7Mj8cGDB9m+fXskbu5jz/f9Uzr2wuGwjr2THHt1j6WTHXvWwvbte9m3r4KtWy3BYDppad0oLu7MT37yKWVlibz33igqKpL5zndyKS9PYMOGfyQU6khaWiHl5Yns2nU/1dWd+POfS6moSMD3n6C6ugNPPFFNOJwAPAXAL39JjZ8D8NRTtfGPAPif/6mN7yc9HV59tTa+my1bYPny2ngSu3ZBTb0K3MzevfDxx7XxP3DwIGzaBK5f4xJqakcgETifYxf+JgL9qfmOXrN+T46dpU8AOlPzHZ26pUBCgsUYS3V1iKQkQ2JiAhCmsrKCTp2SSUpKpLo6TE3/QZOMO/PTxEruFNL/WWtHNvLaH4CAtfaFmjgIeLhTSD+11t5cs3w2gLX2v5r6ecOGDbN1ew3q2rJlC8OHD28yZ6mvPV6FVFdLHTe6mVvzqr2ZY+03fWk+OlaPCYWgqAgOHjzWiorA9+HQoWPN96G4uH47cuTUTnOcKmNcD0WnTq517Oha3ecdO0KHDvUfk5Pd89qWnHysNYyTklyrfV53WVISJCYeH9cuq/vYsCUk1H+s+9wY16LbBybdWnt5U+s1Rw/MMuD+mjEuVwLF1tq9xpgDwAXGmKHAHmAy8J1m+HkSJ1atWsX8+fNJT08nPz+fxYsXc/fdd8c6LRFpY6x1hcW+fa7t33/ssaAADhw49njggCtMTkeXLu70R1JSKX37diElxcVdu0K3bsdaly5uWdeux5537uyed+ly7HmnTu55586uONCwveg0WcAYY17A9aj0McbkAY8CyQDW2qeAt4AJwA6gFJha81rIGHM/8Dauz+lZa23WcT9A2qySkhJGjhzJnXfeyZ133hnrdEQkDlnrekdyc2H37mOPe/a4lp/vHk9lblNjjo27qG29erlldVtqqmvdu7uxGj16uOdJNf9zBgKfqFcrhq
K5CunbTbxugRkneO0tXIEj7dCECROYMGECgHpeROSESkthxw43YHXnzmMtJwc++8y93pTOneGcc6BfPzcQtPaxb18
"text/plain": [
"<Figure size 576x216 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – generates and saves Figure 4–21\n",
"\n",
"lim = 6\n",
"t = np.linspace(-lim, lim, 100)\n",
"sig = 1 / (1 + np.exp(-t))\n",
"\n",
"plt.figure(figsize=(8, 3))\n",
"plt.plot([-lim, lim], [0, 0], \"k-\")\n",
"plt.plot([-lim, lim], [0.5, 0.5], \"k:\")\n",
"plt.plot([-lim, lim], [1, 1], \"k:\")\n",
"plt.plot([0, 0], [-1.1, 1.1], \"k-\")\n",
"plt.plot(t, sig, \"b-\", linewidth=2, label=r\"$\\sigma(t) = \\dfrac{1}{1 + e^{-t}}$\")\n",
"plt.xlabel(\"t\")\n",
"plt.legend(loc=\"upper left\")\n",
"plt.axis([-lim, lim, -0.1, 1.1])\n",
"plt.gca().set_yticks([0, 0.25, 0.5, 0.75, 1])\n",
"plt.grid()\n",
"save_fig(\"logistic_function_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Decision Boundaries"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['data',\n",
" 'target',\n",
" 'frame',\n",
" 'target_names',\n",
" 'DESCR',\n",
" 'feature_names',\n",
" 'filename',\n",
" 'data_module']"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.datasets import load_iris\n",
"\n",
"iris = load_iris(as_frame=True)\n",
"list(iris)"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
".. _iris_dataset:\n",
"\n",
"Iris plants dataset\n",
"--------------------\n",
"\n",
"**Data Set Characteristics:**\n",
"\n",
" :Number of Instances: 150 (50 in each of three classes)\n",
" :Number of Attributes: 4 numeric, predictive attributes and the class\n",
" :Attribute Information:\n",
" - sepal length in cm\n",
" - sepal width in cm\n",
" - petal length in cm\n",
" - petal width in cm\n",
" - class:\n",
" - Iris-Setosa\n",
" - Iris-Versicolour\n",
" - Iris-Virginica\n",
" \n",
" :Summary Statistics:\n",
"\n",
" ============== ==== ==== ======= ===== ====================\n",
" Min Max Mean SD Class Correlation\n",
" ============== ==== ==== ======= ===== ====================\n",
" sepal length: 4.3 7.9 5.84 0.83 0.7826\n",
" sepal width: 2.0 4.4 3.05 0.43 -0.4194\n",
" petal length: 1.0 6.9 3.76 1.76 0.9490 (high!)\n",
" petal width: 0.1 2.5 1.20 0.76 0.9565 (high!)\n",
" ============== ==== ==== ======= ===== ====================\n",
"\n",
" :Missing Attribute Values: None\n",
" :Class Distribution: 33.3% for each of 3 classes.\n",
" :Creator: R.A. Fisher\n",
" :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)\n",
" :Date: July, 1988\n",
"\n",
"The famous Iris database, first used by Sir R.A. Fisher. The dataset is taken\n",
"from Fisher's paper. Note that it's the same as in R, but not as in the UCI\n",
"Machine Learning Repository, which has two wrong data points.\n",
"\n",
"This is perhaps the best known database to be found in the\n",
"pattern recognition literature. Fisher's paper is a classic in the field and\n",
"is referenced frequently to this day. (See Duda & Hart, for example.) The\n",
"data set contains 3 classes of 50 instances each, where each class refers to a\n",
"type of iris plant. One class is linearly separable from the other 2; the\n",
"latter are NOT linearly separable from each other.\n",
"\n",
".. topic:: References\n",
"\n",
" - Fisher, R.A. \"The use of multiple measurements in taxonomic problems\"\n",
" Annual Eugenics, 7, Part II, 179-188 (1936); also in \"Contributions to\n",
" Mathematical Statistics\" (John Wiley, NY, 1950).\n",
" - Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene Analysis.\n",
" (Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218.\n",
" - Dasarathy, B.V. (1980) \"Nosing Around the Neighborhood: A New System\n",
" Structure and Classification Rule for Recognition in Partially Exposed\n",
" Environments\". IEEE Transactions on Pattern Analysis and Machine\n",
" Intelligence, Vol. PAMI-2, No. 1, 67-71.\n",
" - Gates, G.W. (1972) \"The Reduced Nearest Neighbor Rule\". IEEE Transactions\n",
" on Information Theory, May 1972, 431-433.\n",
" - See also: 1988 MLC Proceedings, 54-64. Cheeseman et al\"s AUTOCLASS II\n",
" conceptual clustering system finds 3 classes in the data.\n",
" - Many, many more ...\n"
]
}
],
"source": [
"print(iris.DESCR) # extra code – it's a bit too long"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>sepal length (cm)</th>\n",
" <th>sepal width (cm)</th>\n",
" <th>petal length (cm)</th>\n",
" <th>petal width (cm)</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>5.1</td>\n",
" <td>3.5</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>4.9</td>\n",
" <td>3.0</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>4.7</td>\n",
" <td>3.2</td>\n",
" <td>1.3</td>\n",
" <td>0.2</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" sepal length (cm) sepal width (cm) petal length (cm) petal width (cm)\n",
"0 5.1 3.5 1.4 0.2\n",
"1 4.9 3.0 1.4 0.2\n",
"2 4.7 3.2 1.3 0.2"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"iris.data.head(3)"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 0\n",
"1 0\n",
"2 0\n",
"Name: target, dtype: int64"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"iris.target.head(3) # note that the instances are not shuffled"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(['setosa', 'versicolor', 'virginica'], dtype='<U10')"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"iris.target_names"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"LogisticRegression(random_state=42)"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"X = iris.data[[\"petal width (cm)\"]].values\n",
"y = iris.target_names[iris.target] == 'virginica'\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)\n",
"\n",
"log_reg = LogisticRegression(random_state=42)\n",
"log_reg.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjAAAADICAYAAAD2r9syAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAABkEklEQVR4nO3dd3gUxRvA8e+kV0ILLaFDgBBASggBlCDwE6UpRarSi4CKiooVEAVRKSooHaUIiiBSBEUgYJBAQpAaOqEl9JpA2mV+f+ylXOpduORS5vM8+9zd7uzue5tL7s3M7IyQUqIoiqIoilKYWFk6AEVRFEVRFFOpBEZRFEVRlEJHJTCKoiiKohQ6KoFRFEVRFKXQUQmMoiiKoiiFjkpgFEVRFEUpdPItgRFCLBFCXBdCHM1iuxBCfCOEOCOEOCyEaJJfsSmKoiiKUrjY5OO5fgDmAMuy2P4sUFu/+AHf6x+zVbJkSVmrVi0zhVi8xcTE4OzsbOkwigR1LSEhKYFzd85Ro1QNbK1sc3WMhwkPOXnzJHXL1sXR1tEiMTzuMSLuRnDr4S3KOpelqlvVXMVw+9Ftzt85T43SNSjlUCrf9we4EXODi/cuUrVkVco6lTV5f3P8LMzxPu7H3ef07dN4lfbC1d41d/vfOo1Xmdztrxg6cODATSmle272zbcERkq5WwhRLZsi3YBlUhtZL1gIUVIIUVFKGZXdccuXL09oaKg5Qy22AgMDCQgIsHQYRYK6ljB682iOHjhK66atmdtpbq6O4fOdD0k3khDugtDRpv+emyOGxz2GmCwAuMlNbky8kasY7D+1Bx1ctr7M2Q/P5vv+ANaTrQG4xCUiJkaYvL85fhbmeB+lp5eGWLjhcIOT757M9/0VQ0KIC7ndNz9rYHLiAVxK8/qyfl22CcylS064uoK1NVhZaUvy8wED4MsvtXJHj8LzzxtuT/u4fDl4e2tlv/wSNm7MeDwrK/Dygq+/Tj3/iy8alrO1BRsbbenXD1q31sodOACbNqVuT/84ZIh2DIDAQHjwIPOyFSpAlSpaucREuH8fHBzA3l47v6IUBFEPolj631KSZBJL/1vKR20+ooJLBZOO8V/Ufxy7cQyAYzeOcfjaYRqWb5ivMTzuMYb+NtTg9ciNI5nfZb5JMaw+spp4XTwA8bp4fj3+Kz29e+bb/gDf7/+eJJIASCKJhWELGd5kuNH7m+NnYY73se3sNu7E3gHgTuwddpzfwdPVn863/YsaXZKOOF0ccYlxxOniiE2MTXme/JigSyAhKYF4XXymzx9HQUpgRCbrMp3nQAgxAhgBYGPTgEePMj/gyZORBAaeAuDECVfOnm2a5cmDgkK5fj0agF27vPjnn0qZljt37gGBgQcASEqCNWsCsjymk9NJEhO1/GvDhkrMmuWVZdmaNQMR+iswYkRTTp/OvGqyU6dIxo8/pX9/Lowa1Sxlm7V1Era2Eju7JOzskvj888PUrBkDwOrVlQkOLoOtrbbN1jYJR0cdDg5JeHg8omfPy0RHRxMYGMi2beWxs9Ph6JiEg4NOv2jP3dwSsLdPyvJ9KJrka1lczTo1i0RdIgAJugRGrRrFuNrjTDrG4JDBBq+7LevGUt+lZo+hZ0/tS/DXX3/N9TGysuTwEoPXC8IW0Ne1r9H7A7y06yWD133X9KVsG+ObcB53f4Axu8YYvB65cSS179c2en9zfB7M8T66B3U3eN1tZTc2tt6Yb/vnB53U8TDxIQ91D4nVxRKbFEusLpa4pDge6R4RlxSnrdevS1sm+TE+KZ74pHgt2UiKJ0EmpD5Ps04ndRZ9rwUpgbkMVE7z2hOIzKyglHIBsADAy6uuDA3VkomkJNDpUp87OlaiZEktEWnRAtq3T92e/rFu3WYkd1nw8IB33sl4PJ0OSpRwpXXrAH0c8PPPqdt0Oq1WJCFBe2zbtg7e3nUAcHYGFxfD7cmPOh20bRuQ8v46dYLTpzOWTUiAVq0qER
BQKeWYJUtCbCzExYFOZ4VOB7GxWlVM8+a+NGigHfPHH+HQocwvvJ8fzJlTi8DAQFq1CqBt26x/SAsWwHD9P14rVsDkyeDmlvkyaRIpSdn+/VoNU5kyULasdi1EZilrEVGcm5CiHkTx156/SJTaF1aiTOSv638xr+88o//r/i/qPyJ2RRisi3gYQel6pY2qhTElhps3b+bJ+0hf+5Js1YNVRtfCrD6ymkQSDdYlksjNcjeNqn143P1Bq32R6f6XlEhOlzhtVC2MOT4P5ngf285uI1oXbbAuWhdNUtUko2pRHnd/Yz1KeMStR7e49fAWtx/d5tajW9x5dIf7cfe5F3fP4PF+3H3uxaZ5HnePhwkPzRZLTgQCext77K3tsbexx8HGIeW5vbU9dtZ22FnbYWttqz1a2Ro+t7JlEYtyf/78nMxR3wdmk5TSJ5NtnYCxwHNonXe/kVI2z+mYderUkSdPqnZIKbUEJy5OW2JjoVw5sLPTtp86BZGRqcnOo0cQEwMPH2pJRb9+2peun18Aw4dr25K3Jz+PidGaz/T/sPLll1qilxknJ618snr14MSJ1Nd2dloiU6YMDB4Mb7yhrb98Gdau1bZVqAAVK2pLyZKFK+EpzgnM6M2jWXxwcUp1P4CdtR3DGg8zuu+Dz3c+Kc1HadV3r8/R0ZneyGj2GB73GMl9XzIjJxr3d9f+U3uD86eNI+7DuDzfH7S+L8nNR2lZYYVuYs7/gZvjZ2GO91F6eumU5p+0SjmU4va7t/Nkfykl0fHRXI2+mmG5+fAmtx6lJinJCcujxCyaFIwkEJSwL4GrvSvOts442TrhbKd/tE33mH69/rWjjaOWjOgTkbTP066zsbJBPOYfZiHEASlls5xLZpRvNTBCiFVAAFBWCHEZmAjYAkgp5wF/oCUvZ4CHwODMj6RkRggtKbCzA9dMWp+8vLQlJ46OWs2KMUaOhK5d4d69jIsu3d+1hg21Y9+6BTdvaolRZKS23E7zu3/0KIwbl/Fc9vZaIhMYCFX1N3Js3Kgdy9NT6xdUpYp2DsWy9l7em+HLJl4Xz7+X/zX6GGfvZN5BM6v1eRGDOY7xuDL70s5uvbn3BzJNXrJbn545rqM53sfd2LsmrTe23J3YO6w4vIKL9y5y+f5lgyTlWsw1k2tE7KztKONYhjJOZVIeSzmUws3ejRL2JXBz0B6Tl+T1yYuznTNWongM8ZavNTB5QdXAmE9+1ho8fJiazJQpk9ox+fBhWLgQbtyAq1chKkpbHjzQtt+9qzVPAXToAH//bXjccuW0BKdrV/jwQ21dfDycPAk1a2o1Q/mhONfAmFteX8umTbW+cQcOHMizcxQE6jNpmnhdPBF3Izhz+wxnbp/hwt0LXLx/kUv3LnHmxhlux9/O0LSWGUcbRyq6VqS8c3kquFSggksFyjuXx93ZPUOiUtqxNM62zo9dq1GYFIoaGEVJy8lJWypXNlzfsCF8+23G8tHRWiJTokTquk6dtGamy5fhwgW4dAmuX9eWRo1Sy506pR0XtNqa2rW12qjatbUlIMDwuErxEhYWZukQFAtJTErk7O2zhN8M58ztM5y9fZYzd7SE5eK9iyTJrGuZrIQVHq4eVHGrQhW3KlQuURnPEp5Ucq1kkKy42LkUq4QkP6kERikUXFy0ZCOt9E1NOp2W5Fy4kFpLA1ptT506cO6cluxcvgw7d6ZuP3oU6tfXni9erCVCPj7QoIFWa2OjfkuKNDWOVNGXnKgcu3GM4zeOpzyeuHkiy2YoK2FFtZLVqFmqJrVK16J6yepaouJWmSvHr9Djfz2wsVJ/HCxJXX2lyLC21mpYPD0N1zdvrnUgTkyEixe1GplTp7Q7vU6d0pKUZMuWwe7dqa/t7bXxgRo0gGefhT598ue9KPknuQmpMIuOj6bPr32oXrI6T1Z9khaeLahconKx/M//UcIjDl87zMGrBwmLCiMsKoyj148Sp8u8s2/lEpXxdvfGq4xXSrJSq3QtqpWshr
2Nfab7BJ4LVMlLAaB+AkqxYWMDNWpoS8eOmZd5/XXw99dqZY4c0RKegwe1xckpNYE5f167TbxZM2154gnVgVixnOs
"text/plain": [
"<Figure size 576x216 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"X_new = np.linspace(0, 3, 1000).reshape(-1, 1) # reshape to get a column vector\n",
"y_proba = log_reg.predict_proba(X_new)\n",
"decision_boundary = X_new[y_proba[:, 1] >= 0.5][0, 0]\n",
"\n",
"plt.figure(figsize=(8, 3)) # extra code – not needed, just formatting\n",
"plt.plot(X_new, y_proba[:, 0], \"b--\", linewidth=2,\n",
" label=\"Not Iris virginica proba\")\n",
"plt.plot(X_new, y_proba[:, 1], \"g-\", linewidth=2, label=\"Iris virginica proba\")\n",
"plt.plot([decision_boundary, decision_boundary], [0, 1], \"k:\", linewidth=2,\n",
" label=\"Decision boundary\")\n",
"\n",
"# extra code – this section beautifies and saves Figure 4–23\n",
"plt.arrow(x=decision_boundary, y=0.08, dx=-0.3, dy=0,\n",
" head_width=0.05, head_length=0.1, fc=\"b\", ec=\"b\")\n",
"plt.arrow(x=decision_boundary, y=0.92, dx=0.3, dy=0,\n",
" head_width=0.05, head_length=0.1, fc=\"g\", ec=\"g\")\n",
"plt.plot(X_train[y_train == 0], y_train[y_train == 0], \"bs\")\n",
"plt.plot(X_train[y_train == 1], y_train[y_train == 1], \"g^\")\n",
"plt.xlabel(\"Petal width (cm)\")\n",
"plt.ylabel(\"Probability\")\n",
"plt.legend(loc=\"center left\")\n",
"plt.axis([0, 3, -0.02, 1.02])\n",
"plt.grid()\n",
"save_fig(\"logistic_regression_plot\")\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.6516516516516517"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"decision_boundary"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ True, False])"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"log_reg.predict([[1.7], [1.5]])"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAsAAAAEQCAYAAAC++cJdAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAACq5UlEQVR4nOydd1gURx+A3zl6taFgwW7svQD2Envvms8kmsQSY9fEGBNjjSaxd2NsUWPv3cReAHuJvYsKdun1br4/DhAEFA64A533efbxbnZ25rd7cP6YnX1HSClRKBQKhUKhUCg+FDSmDkChUCgUCoVCoTAmKgFWKBQKhUKhUHxQqARYoVAoFAqFQvFBoRJghUKhUCgUCsUHhUqAFQqFQqFQKBQfFCoBVigUCoVCoVB8UJibOoD0xMnJSRYsWNCgY4ODg7Gzs0vbgDIoL3zDef4oHKe8VmRzsTK4HWNes8B7Lwh9HoxjwexYZzfB5yQl4f9dBwSWxQogrBNet/DwcKys3n09U3rddMFBRN65hiZLNizyFQRhur9joyJDeeF3GSvrLDg6FUYYKRZj/aw9Dr1HYORznG0K4mCRPd37S28+pO+1tEJds5SjrplhqOuWck6fPv1MSpnToIOllO/tVrlyZWkoBw4cMPjYzEZYSJT8pdNZ2YRdckavizIyQmtQO8a8ZqEvguWGujPkTPpLr9E7pU6nM1rfsTGcvCBvuVSXN7NUksH/HE3WMVeuXJETJkyIF29Kr5tOp5MvFv4mrxdD3u/oLiOfPU7R8WnNhSPz5KyBZvLvSeVl4Asfo/RprJ+1gPDn8luvOrLJLuTKG2NN8nOWlnxI32tphbpmKUddM8NQ1y3lAKekgTmimgKhwMrGjOGrytP5h8Ls+uMBP7c4TbB/pKnDeivW2WxpvacvJT6vxonRu/jns+Vow40bs3WVsrh6r8PcNTcPm/bE/8+1b63/9OlTmjdvzsiRI+nWrRthYWEG9SuEINtX35J71gbCr57Hp6M7ETevGNRWWlC2Zh9a9t6B/7PbrJ3qxlOfsyaLJa1xsMzO+Kp7aJDnM5bfHMWUi58ToQs3dVgKhUKhSCUqAVYAoNEIuk/4iEGLynB+/wuG1vDm8b1QU4f1Vswszfl4yf9wH9ecaytOsbnhXEKfBxs1Bov8ech3bDW2DTx40vNHnn3/O1KnS7TulClTuH37NgB///03DRs25NmzZwb3bd+4HfmWH0SGheDT2YMQz/0Gt5VaCpRsTIfBxxAaMzbMqMWdi9tMFktaY6mxYmjZpXxadCz7Hi1n8oVPTR2SQqFQKFKJSoAV8Wj8RT7G76nCswdhDHbz5NqJV6YO6a0IIaj6Y2Mar/qcxyfusd5jKq9uPDFqDGaO9uTZvgDH3l14+etC/LoMRheacHR3/Pjx9O7dO/b90aNH8fDwwMfHx+C+rctXw3WdN+Yu+Xj4ZWP81y82uK3U4pSnLJ2GeJPNpSQ7/mzD+UMzTRZLWiOE4JOiP/FduZW0LzjM1OEoFAqFIpWoBFiRgAr1czDV0x1LGzOG1z3BsY1+pg7pnXzUpTJt9vUj7EUI69yn8vDILaP2L8zNyTVvDE6ThxO0fjcP631K1OP4o7vm5ubMmzePyZMnI4QA4ObNm/Tr14/Dhw8b3LdF3gLkW30MW7d6PPnhS55N+SHJUej0xi5Lbtr1P0ihsq04vGEgh9b3R6eNMkks6UG9PJ9QPGs1U4ehUCgUilSiEmBFouQvac90b3cKlXdgQodzrJ98B/1884xLnhqF6eg1BGsnezZ/PJtrK08atX8hBNmGfknuDbMIv3ANH/dOhF++maDO0KFD2bBhAzY2NgAEBATw8ccfs3z5coP7NnPIQp4/duDYuRcvF0zEb3BXdGGmmcJiYWVH0y/WU6HeEC4cns2OP9sQERZoklgUCoVCoUgMlQArkiRrLism7a9GzQ4uLPr2GrP6XCIq0jQji8kla9GcdPQcTG6PQu
ztthzvMbuMnrjbt21EvkMrkKFhPKjemZB9xxPUadu2LYcOHcLZ2RmAyMhIPvvsM0aPHm1wvMLCglxj5+P03W8E7VrLw8/qE/XcuNNBYtBozKjVdgp1O83j3pXdbJhRi6CXD0wSi0KhUCgUb6ISYMVbsbIx4/vV5ek8IhMZIrLb0XpvX0p8Fm2I+HyF8Q0RVcu9NkQ0+Qr/ResS1KlatSre3t4UKlQotmzMmDGsXr3a4H5jDBEus9YTfuVcxjBE9Nr+XhoiFAqFQpF5UQmw4p1oNILuv7w2RAyrmUkMEUv/h9vYZlxbfpLNjUxgiCiQV2+IqO/Ok69G8uyHKQnm5hYoUICZM2fSqFEjANq3b0/nzp1T3bdD4/bkW3EIGRpsekNEqSZ0GHQUITR6Q8R/200Wi0KhUCgUoBJgRQpo/EU+xu2uwlOfaEPESX9Th/RWhBBU+6kJjf/+HD+vu3pDxM2nRo0hniFi4oJEDRH29vZs376dSZMm8ddff6HRpM2vZQJDxIYladKuITjlLUenoSfI5lyCHQtbc/7QLJPFolAoFAqFSoAVKaJigxxMOR5tiKjjzbFNj00d0jv5qGtl2u7vH2uIeHTUyIYICwu9IeL3aENE/c+IevI8Xh0LCwuGDx+Ora1tvHKtVsv58+cN7tsiX0HyrTqKTbW6PBnxBc+mjjStIWLAIQqWacnhDQM4tH4AOp3WJLFkFDL6g6UKhULxvqISYEWKKVBKb4goWM6BCe3PZi5DRA47NjUwkSFi2JfkXj+T8PNX8XHvSMSVm+88bvDgwVSrVo0VK1YY3LeZY1byLtyJY6eevJz/C35DPjGpIaLZlxuiDRGz2LGwDRHhQSaJJSMQo8PL6L8/CoVC8b6hEmCFQWTNZcWvB14bImZ/fRmdNmP/J/6mIeLEuN3GN0S0a6w3RISE4eORuCEihgULFjBr1iwiIiL49NNPGTNmTOoMEeMWkOPbXwnauYaHnzcg6oVxp4PEEGuI6DiXe5d3smF6LYJePTRJLKbEN+Q2W+7NYurFL1h+cxQPg2+YOiSFQqH4YFAJsMJgYgwRnb4vxM4FPqwfEUlwQMZe9CCuIcJ71E7+7b7SdIaIfC48bPIVtjuPJlqvSZMmlC5dOvb96NGj+eyzzwgPDzeoXyEE2Xt+pzdEXD7Lg47uRNy6alBbaUHZWl/TsvcO/J/dYu0UN54+OGeyWIzNhReHGOzlzv5Hy3kQfJX7QVcY5FmN+0GmM3YoFArFh4RKgBWpQqMR9JhYnEF/luHeGR3Danjx5H7mMURc/esEWxrPI+yFiQwR9dzI9vtfPBsxOVFDxLFjx2jYsGFs2YoVK2jYsCHPnz9/s8lkozdEHEQXEqQ3RHgdMLit1PLaECHYML0mdy7tMFksxsL7yXZ+OtWEKk5NGFxmMVPdj/NjxfVUzdmclTdHo9Vl7D8iFQqF4n1AJcCKNKHxl/no+KsFT33CGOTmlWkMEY1Wfoav5x3WeUwzviEiiwN5dvxBcItavJz0B35dExoismTJwo4dO+jVq1ds2ZEjR/Dw8ODmzXfPIU4K6/JuekNErjw8/KIRARuXGtxWanHKW46OQ73Jmqs4O/5o9V4bIm76n2HCuQ58nLc7fUrOpKBDmdh9eWyL8jz8EZKMPZVIoVAo3gdUAqxIMwpWNtMbIqw1mcYQUfyTKrTd14+w58EmM0S8GtJNb4hYuytJQ8T8+fP5/fffY8tu3LiBu7s7R48mPn0iOVjkK0i+1cewqVaXx9/34Nm0H01miLDPkof2Aw/HGiIObxj43hkigqMCWHhtKJWdGtOt6GjsLbLG7gvThnD2+T/ksS2KucbCdEEqFArFB4JKgBVpSoFS9kzzem2I2DAlExgiahbRGyKy27KpwWyurzpt3ACiDREu62claYgQQjBs2DDWr1+PtbU1AM+fP6dBgwZs2bLF4K5jDREdv+LlvAl6Q0R42LsPTAdiDRF1B3P+0Mz3zhARFPEC35Bb1HTuQD
Yr59jyMG0Iu30WEqYNplKORiaMUKFQKD4cVAKsSHOyOesNETXaO/PnsGvM7nsZbZRpRhaTS9aiOengOQQX94Ls+WS
"text/plain": [
"<Figure size 720x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
2016-05-22 16:01:18 +02:00
"source": [
2022-05-24 14:37:17 +02:00
"# extra code – this cell generates and saves Figure 4–24\n",
2016-05-22 16:01:18 +02:00
"\n",
2021-11-03 23:35:15 +01:00
"X = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
"y = iris.target_names[iris.target] == 'virginica'\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)\n",
2016-05-22 16:01:18 +02:00
"\n",
2021-11-03 23:35:15 +01:00
"log_reg = LogisticRegression(C=2, random_state=42)\n",
"log_reg.fit(X_train, y_train)\n",
2016-05-22 16:01:18 +02:00
"\n",
2021-11-03 23:35:15 +01:00
"# for the contour plot\n",
"x0, x1 = np.meshgrid(np.linspace(2.9, 7, 500).reshape(-1, 1),\n",
" np.linspace(0.8, 2.7, 200).reshape(-1, 1))\n",
"X_new = np.c_[x0.ravel(), x1.ravel()] # one instance per point on the figure\n",
2016-05-22 16:01:18 +02:00
"y_proba = log_reg.predict_proba(X_new)\n",
"zz = y_proba[:, 1].reshape(x0.shape)\n",
"\n",
2021-11-03 23:35:15 +01:00
"# for the decision boundary\n",
2016-05-22 16:01:18 +02:00
"left_right = np.array([2.9, 7])\n",
2021-11-03 23:35:15 +01:00
"boundary = -((log_reg.coef_[0, 0] * left_right + log_reg.intercept_[0])\n",
" / log_reg.coef_[0, 1])\n",
2016-05-22 16:01:18 +02:00
"\n",
2021-11-03 23:35:15 +01:00
"plt.figure(figsize=(10, 4))\n",
"plt.plot(X_train[y_train == 0, 0], X_train[y_train == 0, 1], \"bs\")\n",
"plt.plot(X_train[y_train == 1, 0], X_train[y_train == 1, 1], \"g^\")\n",
"contour = plt.contour(x0, x1, zz, cmap=plt.cm.brg)\n",
"plt.clabel(contour, inline=1)\n",
2016-05-22 16:01:18 +02:00
"plt.plot(left_right, boundary, \"k--\", linewidth=3)\n",
2021-11-27 01:38:47 +01:00
"plt.text(3.5, 1.27, \"Not Iris virginica\", color=\"b\", ha=\"center\")\n",
"plt.text(6.5, 2.3, \"Iris virginica\", color=\"g\", ha=\"center\")\n",
2021-11-03 23:35:15 +01:00
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
2016-05-22 16:01:18 +02:00
"plt.axis([2.9, 7, 0.8, 2.7])\n",
2021-11-03 23:35:15 +01:00
"plt.grid()\n",
2016-05-22 16:01:18 +02:00
"save_fig(\"logistic_regression_contour_plot\")\n",
"plt.show()"
]
},
2021-11-03 23:35:15 +01:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Softmax Regression"
]
},
2016-05-22 16:01:18 +02:00
{
"cell_type": "code",
2022-02-19 06:17:36 +01:00
"execution_count": 58,
2018-02-21 23:04:09 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"text/plain": [
"LogisticRegression(C=30, random_state=42)"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
2016-05-22 16:01:18 +02:00
"source": [
2021-11-03 23:35:15 +01:00
"X = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
2016-05-22 16:01:18 +02:00
"y = iris[\"target\"]\n",
2021-11-03 23:35:15 +01:00
"X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)\n",
2016-05-22 16:01:18 +02:00
"\n",
2021-11-03 23:35:15 +01:00
"softmax_reg = LogisticRegression(C=30, random_state=42)\n",
"softmax_reg.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
2022-02-19 06:17:36 +01:00
"execution_count": 59,
2021-11-03 23:35:15 +01:00
"metadata": {
"tags": []
},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"text/plain": [
"array([2])"
]
},
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
],
2021-11-03 23:35:15 +01:00
"source": [
"softmax_reg.predict([[5, 2]])"
]
},
{
"cell_type": "code",
2022-02-19 06:17:36 +01:00
"execution_count": 60,
2021-11-03 23:35:15 +01:00
"metadata": {
"tags": []
},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"text/plain": [
"array([[0. , 0.04, 0.96]])"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
2021-11-03 23:35:15 +01:00
"source": [
"softmax_reg.predict_proba([[5, 2]]).round(2)"
2017-05-29 23:20:14 +02:00
]
},
{
"cell_type": "code",
2022-02-19 06:17:36 +01:00
"execution_count": 61,
2018-02-21 23:04:09 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAsAAAAEQCAYAAAC++cJdAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAD+5UlEQVR4nOydd3hUddbHP3dKepn0XkmjE3oICUgTQSmCIE0g7Lq7rmtd9V17X9fuWlelV0GqCiKKmITee0ISkpDeJ8mkTrnvHwlISCCTEJr8Ps8zmkx+5dx7h5kz557zPZIsywgEAoFAIBAIBLcLihttgEAgEAgEAoFAcD0RDrBAIBAIBAKB4LZCOMACgUAgEAgEgtsK4QALBAKBQCAQCG4rhAMsEAgEAoFAILitEA6wQCAQCAQCgeC24ro5wJIkWUmStE+SpKOSJJ2UJOmVFsYMlSSpXJKkI42PF6+XfQKBQCAQCASC2wPVddyrDhgmy7JOkiQ1kChJ0hZZlvdcMi5BluW7r6NdAoFAIBAIBILbiOvmAMsNHTd0jb+qGx+iC4dAIBAIBAKB4LpyXXOAJUlSSpJ0BCgEtsmyvLeFYVGNaRJbJEnqej3tEwgEAoFAIBD88ZFuRCtkSZI0wHrgH7Isn7joeQfA1JgmMQb4SJbl0BbmPwg8CGBlZdXH39/n+hh+DTGZQCFKEm8K2nMtDAYlqAzXxqDbGMkkISvEjaKbBXE9bh7Etbh5ENfi5uLsmbPFsiy7tTbuhjjAAJIkvQRUybL87hXGZAB9ZVkuvtyYPn1C5D173rsGFl5fEhNh8OAbbYUA2nctFiwYj1Xc4mtj0G2MXaIdusG61gcKrgvietw8iGtx8yCuxc3FHIs5B2VZ7tvauOupAuHWGPlFkiRrYASQdMkYT0mSpMaf+zfaV3K9bBQI2ktc3Eamo7nRZggEAoFAIDCD66kC4QUsliRJSYNju1qW5e8lSforgCzLXwCTgb9JkmQAaoD75RsVohYI2ogsy8QeKWXZ2yYMdhupyy7FqK1GNpqQ1EpUGhtUrvZY+Dpj1ckdqzBPbLr6onKyvdGmCwQCgUBwW3E9VSCOAZEtPP/FRT9/AnxyvWwSCDoCk9HIwVUJbHvrWwqSswGw99AgB7qgcrVDUiqQDUYM2mpqU/KpzylD1hsvzLcIcMG2dxD2AzphFx2KbWQgCovr+d1UIBAIBILbC/EpKxBcBVWllSyZ9R5J247g3T2Q+7/4O13H9MXB0wloOTdYNpqoO1dCbXIe1cezqDqSSdWBdMrWHwBAYW2BXVQIDkM74ziiK7a9A5FEhaRAIBAIBB2GcIAFgnZSU17Ff4c9R1FqLlM+/RtR80aiuMRRjYvbCGhYgfbCc5JSgVWQG1ZBbmhG97jwvL6gnMqdZ6jceYaKHUlkv7iW7BfXonKxw3F4VxxH90BzV0/ULnbX6QgFAoFAIPhjIhxggaCdrH74CwrP5PDX718kfFjPK46tXTC7VZUItYcjzvf2w/nefkCDQ1z+y0nKt51Au+0EJav3gkLCbmAITndH4jSuN9Zhnh12PAKBQCAQ3C4IB1ggaAcpvx3n0DcJ3PXitFadX2iIBK9o4x5qD0dcpw/CdfogZJOJqkMZaL8/Qtnmo2Q9u5qsZ1dj3dkbp/F9cJ7YF5te/jSKqAgEAoFAILgCwgEWCNrBL++tx8HTieH/nGj2nOlo2q0XLCkU2PUNxq5vML4v30vduRLKNh2idONBct/+nty3vsMyxAPnCX1wntSvIW9YOMMCgUAgELSIcIAFgjZSWagl6acjjHjqXtRWFm2a21JOcHuw9HfB8+GReD48En1RBWXfHabk233kffAjee9uxrKTOy6T++Ny3wCsu/sKZ1ggEAgEgosQpeUCQRtJ2nYE2WSi58SB7V5jOhpqF8zuEHvUbg64xw2h8+an6J39X4L+NxfLQDdy3/mB431f4Hiv58h5cxO1qQUdsp9AIBAIBLc6Ig
IsELSR9D1JWDnY4NMr+KrW6aho8MWoXexwnzsE97lD0BdVULruACWr9pD98jqyX16Hbb9gXKdF4Ty5Hxaemg7bVyAQCASCWwkRARYI2kju8Uy8uwc0kzxrLx0ZDb4YtZsDHn8ZRpdfn6VX2nv4vzUVud5A5hPLORz4OElj36V4+S6MutoO31sgEAgEgpsZEQEWCNpIydl8Oo/u3aFrXoto8MVY+rng9cRdeD1xF9WncihZtYeSVbtJm/slCmsLnMb3xnVmNI7DuiCplNfEBoFAIBAIbhaEAywQtAGT0UhlgRaNj+uNNqXd2HTxwebVSfi+PBHd7lSKV+6mZM0+SlbtQe3piMvUgbjOjMa2p/+NNlUgEAgEgmuCcIAFgjZQXVaFLMvYudpfk/WnX8Mo8KVICgX20WHYR4cR8O40tFuOUbxiNwWf/Uz+R1ux7uaL28xoXKZHiXxhgUAgEPyhEA6wQNAG6iprALC0t75me0xHw8L/3Yljv/9gcSYfZUkliqo6kCRMtpYYNTYYvJ3Q+7ugD3AF5dXnIiusLHCe2BfniX3Rl+goWb2H4uW7OPd/33Du2dU4juyG28xonMb1RmHdNuk3gUAgEAhuNoQDLBC0AX1tPQBqy453AmVZRrf9GEWf/UDvn6cg19S3PketpC7ci9ruftRGBlDTvxO1kQHIV+Gkql3s8PzbCDz/NoKapFyKl++ieMUuUmd9gdLRGpf7BuD6wGDsBnQS+sICgUAguCURDrBA0AZMBiMAig4uFNPnl5E5+310vx5H5emEy9wR2MV0wzLCl9W/TMPiz+tBllFU1aEs0aHK02KRUYTFmQIsT+dgG5+EZuVuoMEprukTRPXgMKqGdqZ6cBiyjWW77LKO8Mbvtcn4vnIvFb+epnjZTopX7KLw6x1YhXriOnMQrjOjsfRz6cjTIRAIBALBNUU4wAJBW2iMeMqy3GFL1p7JIe3OFzGUVuLzwZ9wmTcKxUUd5mZ33Qm4swItRhtLjG4O1Ed4U33JOsp8Ldb707HZnYJN4hlcPtyK67ubMVmoqBkUim5kN3Sje1DXzffCcZiLpFDgOLwrjsO7EvjfWZSs3U/x0p1kv7SO7JfX4ziiK64zBuE0oQ/KdjrbAoFAIBBcL4QDLBC0AUVjvq1s6hgH2KirIX3iG8j1BkJ/ewubKzTXaK1AzuipQXdPJLp7IgGQquqw2XkG2+0nsfvpBB7PrcHjuTXo/ZypHNsL3ZheVN3RGdlS3SablfbWuM+JxX1OLLVnCylekkjR8p2kzfkSpYM1LlMH4DpLpEgIBAKB4OZFOMACQRtQWTT8kzHU6TtkvYI3VlOXkkunba9d0fk9T+2C2VjFLTZrbdnWkqpR3aka1Z3Ct0CVW4bd1mPYfX8EzdKdOH+xHaOdFbrRPagc3xvdXT0xObStuM8q2B3fl+/F58UJVMQnU7QonuJluyj8agdWYZ64zY7BdcYgLLyd2rSuQCAQCATXEuEACwRtQG3dcHu/vqbuqtcyFJVT9OkPOM28A/sh3c2ed3HXOHOdYQCDtxPauUPQzh2CVFuP7a+nsf/uMPabDuH47T5MFiqqRnSlYlJ/Ku/phUlja/bakkKB49DOOA7tjOG/NZSu20/RogSynltD1gvformzO26zY9CM7YWijRFngUAgEAg6GuEACwRtwMK20QGuuvr2wWXfJCDX1uP+5MR2r9FeZ1i2skB3V090d/Uk7+MHsN6bhsP6Azis24/95qPIaiW6kd2ouG8AlfdEtikyrHK4KEUiJZ+iJYkULduJ9v5PUTnb4jItCrfZMdj2CmjTsQoEAoFA0FEIB1ggaAOWdlbA73rAV0P5pr1YdfXHumvHdFw77wy3xREGQKmgZlAoNYNCKXj7fqz3n8Vh7X4cvt2H/eajmCxV6Eb3pHzqAHRje7VJYs0q1LNBReLleyn/5SRFixIo/GoHBZ/+jE0PP9zihuAydSBqF7u22SwQCAQCwVUgHGCBoA0olEos7a
yovUoH2FSvp2p3Eq5/vauDLPud9kaFAZAkavp3oqZ/Jwr+PQXrfWdxWL0Xh2/34bDxIEY7KyrH96b8/oFUDe8KZsr
"text/plain": [
"<Figure size 720x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
2017-05-29 23:20:14 +02:00
"source": [
2022-05-24 14:37:17 +02:00
"# extra code – this cell generates and saves Figure 4–25\n",
2021-11-21 05:36:22 +01:00
"\n",
2021-11-03 23:35:15 +01:00
"from matplotlib.colors import ListedColormap\n",
"\n",
"custom_cmap = ListedColormap([\"#fafab0\", \"#9898ff\", \"#a0faa0\"])\n",
2016-05-22 16:01:18 +02:00
"\n",
2021-11-03 23:35:15 +01:00
"x0, x1 = np.meshgrid(np.linspace(0, 8, 500).reshape(-1, 1),\n",
" np.linspace(0, 3.5, 200).reshape(-1, 1))\n",
"X_new = np.c_[x0.ravel(), x1.ravel()]\n",
2016-05-22 16:01:18 +02:00
"\n",
"y_proba = softmax_reg.predict_proba(X_new)\n",
"y_predict = softmax_reg.predict(X_new)\n",
"\n",
"zz1 = y_proba[:, 1].reshape(x0.shape)\n",
"zz = y_predict.reshape(x0.shape)\n",
"\n",
"plt.figure(figsize=(10, 4))\n",
2021-11-03 23:35:15 +01:00
"plt.plot(X[y == 2, 0], X[y == 2, 1], \"g^\", label=\"Iris virginica\")\n",
"plt.plot(X[y == 1, 0], X[y == 1, 1], \"bs\", label=\"Iris versicolor\")\n",
"plt.plot(X[y == 0, 0], X[y == 0, 1], \"yo\", label=\"Iris setosa\")\n",
2016-05-22 16:01:18 +02:00
"\n",
2018-03-15 18:38:58 +01:00
"plt.contourf(x0, x1, zz, cmap=custom_cmap)\n",
2021-11-03 23:35:15 +01:00
"contour = plt.contour(x0, x1, zz1, cmap=\"hot\")\n",
"plt.clabel(contour, inline=1)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.legend(loc=\"center left\")\n",
"plt.axis([0.5, 7, 0, 3.5])\n",
"plt.grid()\n",
2016-05-22 16:01:18 +02:00
"save_fig(\"softmax_regression_contour_plot\")\n",
"plt.show()"
]
},
2016-09-27 16:39:16 +02:00
{
"cell_type": "markdown",
2018-02-21 23:04:09 +01:00
"metadata": {},
2016-09-27 16:39:16 +02:00
"source": [
"# Exercise solutions"
]
},
{
"cell_type": "markdown",
2018-02-21 23:04:09 +01:00
"metadata": {},
2017-05-29 23:20:14 +02:00
"source": [
"## 1. to 11."
]
},
{
"cell_type": "markdown",
2018-02-21 23:04:09 +01:00
"metadata": {},
2017-05-29 23:20:14 +02:00
"source": [
2021-11-25 09:45:32 +01:00
"1. If you have a training set with millions of features, you can use Stochastic Gradient Descent or Mini-batch Gradient Descent, and perhaps Batch Gradient Descent if the training set fits in memory. But you cannot use the Normal Equation or the SVD approach because the computational complexity grows quickly (more than quadratically) with the number of features.\n",
"2. If the features in your training set have very different scales, the cost function will have the shape of an elongated bowl, so the Gradient Descent algorithms will take a long time to converge. To solve this, you should scale the data before training the model. Note that the Normal Equation or SVD approach will work just fine without scaling. Moreover, regularized models may converge to a suboptimal solution if the features are not scaled: since regularization penalizes large weights, features with smaller values will tend to be ignored compared to features with larger values.\n",
"3. Gradient Descent cannot get stuck in a local minimum when training a Logistic Regression model because the cost function is convex. _Convex_ means that if you draw a straight line between any two points on the curve, the line never crosses the curve; as a result, any local minimum is also the global minimum.\n",
"4. If the optimization problem is convex (such as Linear Regression or Logistic Regression), and assuming the learning rate is not too high, then all Gradient Descent algorithms will approach the global optimum and end up producing fairly similar models. However, unless you gradually reduce the learning rate, Stochastic GD and Mini-batch GD will never truly converge; instead, they will keep jumping back and forth around the global optimum. This means that even if you let them run for a very long time, these Gradient Descent algorithms will produce slightly different models.\n",
"5. If the validation error consistently goes up after every epoch, then one possibility is that the learning rate is too high and the algorithm is diverging. If the training error also goes up, then this is clearly the problem and you should reduce the learning rate. However, if the training error is not going up, then your model is overfitting the training set and you should stop training.\n",
"6. Due to their random nature, neither Stochastic Gradient Descent nor Mini-batch Gradient Descent is guaranteed to make progress at every single training iteration. So if you immediately stop training when the validation error goes up, you may stop much too early, before the optimum is reached. A better option is to save the model at regular intervals; then, when it has not improved for a long time (meaning it will probably never beat the record), you can revert to the best saved model.\n",
"7. Stochastic Gradient Descent (or Mini-batch GD with a very small mini-batch size) has the fastest training iterations, since it considers only one training instance (or a handful) at a time, so it is generally the first to reach the vicinity of the global optimum. However, only Batch Gradient Descent will actually converge, given enough training time. As mentioned, Stochastic GD and Mini-batch GD will bounce around the optimum unless you gradually reduce the learning rate.\n",
"8. If the validation error is much higher than the training error, this is likely because your model is overfitting the training set. One way to try to fix this is to reduce the polynomial degree: a model with fewer degrees of freedom is less likely to overfit. Another thing you can try is to regularize the model, for example by adding an ℓ₂ penalty (Ridge) or an ℓ₁ penalty (Lasso) to the cost function. This will also reduce the degrees of freedom of the model. Lastly, you can try to increase the size of the training set.\n",
"9. If both the training error and the validation error are almost equal and fairly high, the model is likely underfitting the training set, which means it has a high bias. You should try reducing the regularization hyperparameter _α_.\n",
"10. Let's see:\n",
" * A model with some regularization typically performs better than a model without any regularization, so you should generally prefer Ridge Regression over plain Linear Regression.\n",
" * Lasso Regression uses an ℓ₁ penalty, which tends to push the weights down to exactly zero. This leads to sparse models, where all weights are zero except for the most important weights. This is a way to perform feature selection automatically, which is good if you suspect that only a few features actually matter. When you are not sure, you should prefer Ridge Regression.\n",
" * Elastic Net is generally preferred over Lasso since Lasso may behave erratically in some cases (when several features are strongly correlated or when there are more features than training instances). However, it does add an extra hyperparameter to tune. If you want Lasso without the erratic behavior, you can just use Elastic Net with an `l1_ratio` close to 1.\n",
"11. If you want to classify pictures as outdoor/indoor and daytime/nighttime, since these are not exclusive classes (i.e., all four combinations are possible), you should train two Logistic Regression classifiers."
2017-05-29 23:20:14 +02:00
]
},
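{
"cell_type": "markdown",
"metadata": {},
"source": [
"Answers 4 and 7 mention gradually reducing the learning rate so that Stochastic GD settles near the optimum instead of bouncing around it. Here is a minimal NumPy sketch of that idea on a synthetic linear regression problem. The schedule `eta = t0 / (t + t1)` and the values `t0 = 5` and `t1 = 50` are illustrative assumptions, not tuned settings:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(42)\n",
"m = 200\n",
"X = 2 * rng.random((m, 1))\n",
"y = 4 + 3 * X[:, 0] + rng.normal(size=m)  # true parameters: 4 and 3\n",
"X_b = np.c_[np.ones(m), X]  # add the bias feature x0 = 1\n",
"\n",
"t0, t1 = 5, 50  # assumed learning schedule hyperparameters\n",
"\n",
"def learning_schedule(t):\n",
"    # the learning rate shrinks as the step counter t grows\n",
"    return t0 / (t + t1)\n",
"\n",
"theta = rng.normal(size=2)  # random initialization\n",
"n_epochs = 50\n",
"for epoch in range(n_epochs):\n",
"    for iteration in range(m):\n",
"        idx = rng.integers(m)  # pick one random training instance\n",
"        xi, yi = X_b[idx], y[idx]\n",
"        gradient = 2 * xi * (xi @ theta - yi)  # MSE gradient for this instance\n",
"        eta = learning_schedule(epoch * m + iteration)\n",
"        theta = theta - eta * gradient\n",
"\n",
"print(theta)  # should end up close to [4, 3]\n",
"```\n",
"\n",
"With a constant learning rate, the last steps would keep jumping around the optimum. With the decaying schedule, the final steps are tiny (about 0.0005 here), so `theta` stays close to it."
]
},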
{
"cell_type": "markdown",
2018-02-21 23:04:09 +01:00
"metadata": {},
2017-05-29 23:20:14 +02:00
"source": [
"## 12. Batch Gradient Descent with early stopping for Softmax Regression\n",
2022-02-19 06:17:36 +01:00
"Exercise: _Implement Batch Gradient Descent with early stopping for Softmax Regression without using Scikit-Learn, only NumPy. Use it on a classification task such as the iris dataset._"
2017-05-29 23:20:14 +02:00
]
},
{
"cell_type": "markdown",
2018-02-21 23:04:09 +01:00
"metadata": {},
2017-05-29 23:20:14 +02:00
"source": [
"Let's start by loading the data. We will just reuse the Iris dataset we loaded earlier."
]
},
{
"cell_type": "code",
2022-02-19 06:17:36 +01:00
"execution_count": 62,
2018-03-15 18:38:58 +01:00
"metadata": {},
2017-05-29 23:20:14 +02:00
"outputs": [],
"source": [
2021-11-03 23:35:15 +01:00
"X = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
"y = iris[\"target\"].values"
2017-05-29 23:20:14 +02:00
]
},
{
"cell_type": "markdown",
2018-02-21 23:04:09 +01:00
"metadata": {},
2017-05-29 23:20:14 +02:00
"source": [
2021-11-03 23:35:15 +01:00
"We need to add the bias term for every instance ($x_0 = 1$). The easiest option to do this would be to use Scikit-Learn's `add_dummy_feature()` function, but the point of this exercise is to get a better understanding of the algorithms by implementing them manually. So here is one possible implementation:"
2017-05-29 23:20:14 +02:00
]
},
{
"cell_type": "code",
2022-02-19 06:17:36 +01:00
"execution_count": 63,
2018-03-15 18:38:58 +01:00
"metadata": {},
2017-05-29 23:20:14 +02:00
"outputs": [],
"source": [
2021-11-03 23:35:15 +01:00
"X_with_bias = np.c_[np.ones(len(X)), X]"
2017-05-29 23:20:14 +02:00
]
},
{
"cell_type": "markdown",
2018-02-21 23:04:09 +01:00
"metadata": {},
2017-05-29 23:20:14 +02:00
"source": [
2022-05-24 14:37:17 +02:00
"The easiest option to split the dataset into a training set, a validation set and a test set would be to use Scikit-Learn's `train_test_split()` function, but again, we want to do it manually:"
2017-05-29 23:20:14 +02:00
]
},
{
"cell_type": "code",
2022-02-19 06:17:36 +01:00
"execution_count": 64,
2018-02-21 23:04:09 +01:00
"metadata": {},
2017-05-29 23:20:14 +02:00
"outputs": [],
"source": [
"test_ratio = 0.2\n",
"validation_ratio = 0.2\n",
"total_size = len(X_with_bias)\n",
"\n",
"test_size = int(total_size * test_ratio)\n",
"validation_size = int(total_size * validation_ratio)\n",
"train_size = total_size - test_size - validation_size\n",
"\n",
2021-11-03 23:35:15 +01:00
"np.random.seed(42)\n",
2017-05-29 23:20:14 +02:00
"rnd_indices = np.random.permutation(total_size)\n",
"\n",
"X_train = X_with_bias[rnd_indices[:train_size]]\n",
"y_train = y[rnd_indices[:train_size]]\n",
"X_valid = X_with_bias[rnd_indices[train_size:-test_size]]\n",
"y_valid = y[rnd_indices[train_size:-test_size]]\n",
"X_test = X_with_bias[rnd_indices[-test_size:]]\n",
"y_test = y[rnd_indices[-test_size:]]"
]
},
{
"cell_type": "markdown",
2018-02-21 23:04:09 +01:00
"metadata": {},
2017-05-29 23:20:14 +02:00
"source": [
2022-05-24 14:37:17 +02:00
"The targets are currently class indices (0, 1 or 2), but we need target class probabilities to train the Softmax Regression model. Each instance will have target class probabilities equal to 0.0 for all classes except for the target class, which will have a probability of 1.0 (in other words, the vector of class probabilities for any given instance is a one-hot vector). Let's write a small function to convert the vector of class indices into a matrix containing a one-hot vector for each instance. To understand this code, you need to know that `np.diag(np.ones(n))` creates an n×n matrix full of 0s except for 1s on the main diagonal. Moreover, if `a` is a NumPy array, then `a[[1, 3, 2]]` returns an array with 3 rows equal to `a[1]`, `a[3]` and `a[2]` (this is [advanced NumPy indexing](https://numpy.org/doc/stable/user/basics.indexing.html#advanced-indexing))."
2017-05-29 23:20:14 +02:00
]
},
{
"cell_type": "code",
2022-02-19 06:17:36 +01:00
"execution_count": 65,
2018-03-15 18:38:58 +01:00
"metadata": {},
2017-05-29 23:20:14 +02:00
"outputs": [],
"source": [
"def to_one_hot(y):\n",
2021-11-03 23:35:15 +01:00
" return np.diag(np.ones(y.max() + 1))[y]"
2017-05-29 23:20:14 +02:00
]
},
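{
"cell_type": "markdown",
"metadata": {},
"source": [
"An equivalent, slightly more direct way to build the same matrix is `np.eye(n)`, which returns the n×n identity matrix. The following sketch checks that both approaches give the same result. The name `to_one_hot_eye` and its `n_classes` argument are illustrative additions, not part of the solution above:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def to_one_hot_eye(y, n_classes=None):\n",
"    # row y[i] of the identity matrix is the one-hot vector for class y[i]\n",
"    if n_classes is None:\n",
"        n_classes = y.max() + 1\n",
"    return np.eye(n_classes)[y]\n",
"\n",
"y_demo = np.array([1, 0, 2, 1])\n",
"one_hot = to_one_hot_eye(y_demo)\n",
"assert (one_hot == np.diag(np.ones(3))[y_demo]).all()\n",
"assert (one_hot.argmax(axis=1) == y_demo).all()  # argmax undoes the encoding\n",
"```\n",
"\n",
"Passing `n_classes` explicitly can matter. If a small subset, such as a validation set, happens to contain no instance of the highest class index, then `y.max() + 1` undercounts the classes, and the resulting matrix has too few columns."
]
},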
{
"cell_type": "markdown",
2018-02-21 23:04:09 +01:00
"metadata": {},
2017-05-29 23:20:14 +02:00
"source": [
"Let's test this function on the first 10 instances:"
]
},
{
"cell_type": "code",
2022-02-19 06:17:36 +01:00
"execution_count": 66,
2018-02-21 23:04:09 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"text/plain": [
"array([1, 0, 2, 1, 1, 0, 1, 2, 1, 1])"
]
},
"execution_count": 66,
"metadata": {},
"output_type": "execute_result"
}
],
2017-05-29 23:20:14 +02:00
"source": [
"y_train[:10]"
]
},
{
"cell_type": "code",
2022-02-19 06:17:36 +01:00
"execution_count": 67,
2018-02-21 23:04:09 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"text/plain": [
"array([[0., 1., 0.],\n",
" [1., 0., 0.],\n",
" [0., 0., 1.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.],\n",
" [1., 0., 0.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.]])"
]
},
"execution_count": 67,
"metadata": {},
"output_type": "execute_result"
}
],
2017-05-29 23:20:14 +02:00
"source": [
"to_one_hot(y_train[:10])"
]
},
{
"cell_type": "markdown",
2018-02-21 23:04:09 +01:00
"metadata": {},
2017-05-29 23:20:14 +02:00
"source": [
"Looks good, so let's create the target class probabilities matrix for the training set, the validation set, and the test set:"
]
},
{
"cell_type": "code",
2022-02-19 06:17:36 +01:00
"execution_count": 68,
2018-03-15 18:38:58 +01:00
"metadata": {},
2017-05-29 23:20:14 +02:00
"outputs": [],
"source": [
"Y_train_one_hot = to_one_hot(y_train)\n",
"Y_valid_one_hot = to_one_hot(y_valid)\n",
"Y_test_one_hot = to_one_hot(y_test)"
]
},
2021-11-03 23:35:15 +01:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's scale the inputs. We compute the mean and standard deviation of each feature on the training set (except for the bias feature), then we center and scale each feature in the training set, the validation set, and the test set:"
]
},
{
"cell_type": "code",
2022-02-19 06:17:36 +01:00
"execution_count": 69,
2021-11-03 23:35:15 +01:00
"metadata": {},
"outputs": [],
"source": [
"mean = X_train[:, 1:].mean(axis=0)\n",
"std = X_train[:, 1:].std(axis=0)\n",
"X_train[:, 1:] = (X_train[:, 1:] - mean) / std\n",
"X_valid[:, 1:] = (X_valid[:, 1:] - mean) / std\n",
"X_test[:, 1:] = (X_test[:, 1:] - mean) / std"
]
},
2017-05-29 23:20:14 +02:00
{
"cell_type": "markdown",
2018-02-21 23:04:09 +01:00
"metadata": {},
2017-05-29 23:20:14 +02:00
"source": [
"Now let's implement the Softmax function. Recall that it is defined by the following equation:\n",
"\n",
"$\\sigma\\left(\\mathbf{s}(\\mathbf{x})\\right)_k = \\dfrac{\\exp\\left(s_k(\\mathbf{x})\\right)}{\\sum\\limits_{j=1}^{K}{\\exp\\left(s_j(\\mathbf{x})\\right)}}$"
]
},
{
"cell_type": "code",
2022-02-19 06:17:36 +01:00
"execution_count": 70,
2018-03-15 18:38:58 +01:00
"metadata": {},
2017-05-29 23:20:14 +02:00
"outputs": [],
"source": [
"def softmax(logits):\n",
" exps = np.exp(logits)\n",
2021-11-03 23:35:15 +01:00
" exp_sums = exps.sum(axis=1, keepdims=True)\n",
2017-05-29 23:20:14 +02:00
" return exps / exp_sums"
]
},
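{
"cell_type": "markdown",
"metadata": {},
"source": [
"This function follows the equation directly, but `np.exp()` overflows for large logits (beyond roughly 709 in float64). The result is `inf`, which then turns into `nan` in the ratio. Subtracting the same constant from all the logits of an instance does not change its softmax output, so a common safeguard is to subtract each row's maximum first. Here is a sketch (the name `stable_softmax` is illustrative). The logits in this exercise appear to stay small, so the version above should be fine here:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def stable_softmax(logits):\n",
"    # shifting each row by its max leaves the result unchanged,\n",
"    # but ensures every exponent is <= 0, so exp() cannot overflow\n",
"    shifted = logits - logits.max(axis=1, keepdims=True)\n",
"    exps = np.exp(shifted)\n",
"    return exps / exps.sum(axis=1, keepdims=True)\n",
"\n",
"logits = np.array([[1.0, 2.0, 3.0],\n",
"                   [1001.0, 1002.0, 1003.0]])  # np.exp(1001.0) overflows\n",
"proba = stable_softmax(logits)\n",
"print(proba.round(3))  # both rows are identical\n",
"```"
]
},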
{
"cell_type": "markdown",
2018-02-21 23:04:09 +01:00
"metadata": {},
2017-05-29 23:20:14 +02:00
"source": [
"We are almost ready to start training. Let's define the number of inputs and outputs:"
]
},
{
"cell_type": "code",
2022-02-19 06:17:36 +01:00
"execution_count": 71,
2018-03-15 18:38:58 +01:00
"metadata": {},
2016-09-27 16:39:16 +02:00
"outputs": [],
2017-05-29 23:20:14 +02:00
"source": [
2021-11-03 23:35:15 +01:00
"n_inputs = X_train.shape[1] # == 3 (2 features plus the bias term)\n",
"n_outputs = len(np.unique(y_train)) # == 3 (there are 3 iris classes)"
2017-05-29 23:20:14 +02:00
]
},
{
"cell_type": "markdown",
2018-02-21 23:04:09 +01:00
"metadata": {},
2017-05-29 23:20:14 +02:00
"source": [
"Now here comes the hardest part: training! Theoretically, it's simple: it's just a matter of translating the math equations into Python code. But in practice, it can be quite tricky: in particular, it's easy to mix up the order of the terms, or the indices. You can even end up with code that looks like it's working but is actually not computing exactly the right thing. When unsure, you should write down the shape of each term in the equation and make sure the corresponding terms in your code match closely. It can also help to evaluate each term independently and print them out. The good news is that you won't have to do this every day, since all this is well implemented by Scikit-Learn, but it will help you understand what's going on under the hood.\n",
"\n",
"So the equations we will need are the cost function:\n",
"\n",
"$J(\\mathbf{\\Theta}) =\n",
"- \\dfrac{1}{m}\\sum\\limits_{i=1}^{m}\\sum\\limits_{k=1}^{K}{y_k^{(i)}\\log\\left(\\hat{p}_k^{(i)}\\right)}$\n",
"\n",
"And the equation for the gradients:\n",
"\n",
"$\\nabla_{\\mathbf{\\theta}^{(k)}} \\, J(\\mathbf{\\Theta}) = \\dfrac{1}{m} \\sum\\limits_{i=1}^{m}{ \\left ( \\hat{p}^{(i)}_k - y_k^{(i)} \\right ) \\mathbf{x}^{(i)}}$\n",
"\n",
"Note that $\\log\\left(\\hat{p}_k^{(i)}\\right)$ may not be computable if $\\hat{p}_k^{(i)} = 0$. So we will add a tiny value $\\epsilon$ to $\\hat{p}_k^{(i)}$ before taking the log, to avoid getting `nan` values."
]
},
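{
"cell_type": "markdown",
"metadata": {},
"source": [
"A finite-difference gradient check is a practical way to evaluate each term independently. Nudge each parameter up and down by a tiny amount `h`, measure how much the cost changes, and compare the result with the analytical gradient given above. The following self-contained sketch runs the check on random data. The names `cost` and `analytic_gradient` are illustrative, and $\\epsilon$ is omitted because every estimated probability is strictly positive here:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"m, n_inputs, n_outputs = 20, 3, 3\n",
"X = np.c_[np.ones(m), rng.normal(size=(m, n_inputs - 1))]  # bias + 2 features\n",
"Y = np.eye(n_outputs)[rng.integers(n_outputs, size=m)]  # one-hot targets\n",
"Theta = rng.normal(size=(n_inputs, n_outputs))\n",
"\n",
"def softmax(logits):\n",
"    exps = np.exp(logits - logits.max(axis=1, keepdims=True))\n",
"    return exps / exps.sum(axis=1, keepdims=True)\n",
"\n",
"def cost(Theta):\n",
"    # J(Theta) = -(1/m) * sum over i and k of y_k * log(p_k)\n",
"    P = softmax(X @ Theta)\n",
"    return -(Y * np.log(P)).sum(axis=1).mean()\n",
"\n",
"def analytic_gradient(Theta):\n",
"    # column k is (1/m) * sum over i of (p_k - y_k) * x_i\n",
"    return X.T @ (softmax(X @ Theta) - Y) / m\n",
"\n",
"h = 1e-6\n",
"numeric = np.zeros_like(Theta)\n",
"for i in range(n_inputs):\n",
"    for k in range(n_outputs):\n",
"        step = np.zeros_like(Theta)\n",
"        step[i, k] = h\n",
"        # central difference estimate of dJ / dTheta[i, k]\n",
"        numeric[i, k] = (cost(Theta + step) - cost(Theta - step)) / (2 * h)\n",
"\n",
"max_diff = np.abs(numeric - analytic_gradient(Theta)).max()\n",
"print(max_diff)  # a tiny value (far below 1e-6) means the formula checks out\n",
"```\n",
"\n",
"If `max_diff` were large, comparing `numeric` and `analytic_gradient(Theta)` column by column would show which class has the wrong term. These checks evaluate the cost twice per parameter, so they are slow and only meant for debugging."
]
},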
{
"cell_type": "code",
2022-02-19 06:17:36 +01:00
"execution_count": 72,
2018-02-21 23:04:09 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 3.7085808486476917\n",
"1000 0.14519367480830644\n",
"2000 0.1301309575504088\n",
"3000 0.12009639326384539\n",
"4000 0.11372961364786884\n",
"5000 0.11002459532472425\n"
]
}
],
2017-05-29 23:20:14 +02:00
"source": [
2021-11-03 23:35:15 +01:00
"eta = 0.5\n",
"n_epochs = 5001\n",
2017-05-29 23:20:14 +02:00
"m = len(X_train)\n",
2021-11-03 23:35:15 +01:00
"epsilon = 1e-5\n",
2017-05-29 23:20:14 +02:00
"\n",
2021-11-03 23:35:15 +01:00
"np.random.seed(42)\n",
2017-05-29 23:20:14 +02:00
"Theta = np.random.randn(n_inputs, n_outputs)\n",
"\n",
2021-11-03 23:35:15 +01:00
"for epoch in range(n_epochs):\n",
" logits = X_train @ Theta\n",
2017-05-29 23:20:14 +02:00
" Y_proba = softmax(logits)\n",
2021-11-03 23:35:15 +01:00
" if epoch % 1000 == 0:\n",
" Y_proba_valid = softmax(X_valid @ Theta)\n",
" xentropy_losses = -(Y_valid_one_hot * np.log(Y_proba_valid + epsilon))\n",
" print(epoch, xentropy_losses.sum(axis=1).mean())\n",
2021-03-02 05:26:41 +01:00
" error = Y_proba - Y_train_one_hot\n",
2021-11-03 23:35:15 +01:00
" gradients = 1 / m * X_train.T @ error\n",
2017-05-29 23:20:14 +02:00
" Theta = Theta - eta * gradients"
]
},
{
"cell_type": "markdown",
2018-02-21 23:04:09 +01:00
"metadata": {},
2017-05-29 23:20:14 +02:00
"source": [
"And that's it! The Softmax model is trained. Let's look at the model parameters:"
]
},
{
"cell_type": "code",
2022-02-19 06:17:36 +01:00
"execution_count": 73,
2018-02-21 23:04:09 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0.41931626, 6.11112089, -5.52429876],\n",
" [-6.53054533, -0.74608616, 8.33137102],\n",
" [-5.28115784, 0.25152675, 6.90680425]])"
]
},
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
}
],
2017-05-29 23:20:14 +02:00
"source": [
"Theta"
]
},
{
"cell_type": "markdown",
2018-02-21 23:04:09 +01:00
"metadata": {},
2017-05-29 23:20:14 +02:00
"source": [
"Let's make predictions for the validation set and check the accuracy score:"
]
},
{
"cell_type": "code",
2022-02-19 06:17:36 +01:00
"execution_count": 74,
2018-02-21 23:04:09 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"text/plain": [
"0.9333333333333333"
]
},
"execution_count": 74,
"metadata": {},
"output_type": "execute_result"
}
],
2017-05-29 23:20:14 +02:00
"source": [
2021-11-03 23:35:15 +01:00
"logits = X_valid @ Theta\n",
2017-05-29 23:20:14 +02:00
"Y_proba = softmax(logits)\n",
2021-11-03 23:35:15 +01:00
"y_predict = Y_proba.argmax(axis=1)\n",
2017-05-29 23:20:14 +02:00
"\n",
2021-11-03 23:35:15 +01:00
"accuracy_score = (y_predict == y_valid).mean()\n",
2017-05-29 23:20:14 +02:00
"accuracy_score"
]
},
{
"cell_type": "markdown",
2018-02-21 23:04:09 +01:00
"metadata": {},
2017-05-29 23:20:14 +02:00
"source": [
2021-11-03 23:35:15 +01:00
"Well, this model looks pretty good. For the sake of the exercise, let's add a bit of $\\ell_2$ regularization. The following training code is similar to the one above, but the loss now has an additional $\\ell_2$ penalty, and the gradients have the proper additional term. Note that we don't regularize the first row of `Theta`, since it contains the bias terms."
2017-05-29 23:20:14 +02:00
]
},
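{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, here is one way to write the regularized cost and gradients implemented in the next cell. Here $n$ is the number of input features (excluding the bias), and $\\tilde{\\mathbf{\\theta}}^{(k)}$ denotes $\\mathbf{\\theta}^{(k)}$ with its first (bias) element replaced by 0:\n",
"\n",
"$J(\\mathbf{\\Theta}) =\n",
"- \\dfrac{1}{m}\\sum\\limits_{i=1}^{m}\\sum\\limits_{k=1}^{K}{y_k^{(i)}\\log\\left(\\hat{p}_k^{(i)}\\right)} + \\dfrac{\\alpha}{2}\\sum\\limits_{k=1}^{K}\\sum\\limits_{j=1}^{n}{\\left(\\theta_j^{(k)}\\right)^2}$\n",
"\n",
"$\\nabla_{\\mathbf{\\theta}^{(k)}} \\, J(\\mathbf{\\Theta}) = \\dfrac{1}{m} \\sum\\limits_{i=1}^{m}{ \\left ( \\hat{p}^{(i)}_k - y_k^{(i)} \\right ) \\mathbf{x}^{(i)}} + \\alpha \\, \\tilde{\\mathbf{\\theta}}^{(k)}$\n",
"\n",
"In the code, `alpha * l2_loss` (with `l2_loss = 1 / 2 * (Theta[1:] ** 2).sum()`) is the penalty term. Likewise, `np.r_[np.zeros([1, n_outputs]), alpha * Theta[1:]]` stacks $\\alpha \\, \\tilde{\\mathbf{\\theta}}^{(k)}$ for all classes at once, with a row of zeros for the bias."
]
},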
{
"cell_type": "code",
2022-02-19 06:17:36 +01:00
"execution_count": 75,
2018-02-21 23:04:09 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 3.7372\n",
"1000 0.3259\n",
"2000 0.3259\n",
"3000 0.3259\n",
"4000 0.3259\n",
"5000 0.3259\n"
]
}
],
2017-05-29 23:20:14 +02:00
"source": [
2021-11-03 23:35:15 +01:00
"eta = 0.5\n",
"n_epochs = 5001\n",
2017-05-29 23:20:14 +02:00
"m = len(X_train)\n",
2021-11-03 23:35:15 +01:00
"epsilon = 1e-5\n",
"alpha = 0.01 # regularization hyperparameter\n",
2017-05-29 23:20:14 +02:00
"\n",
2021-11-03 23:35:15 +01:00
"np.random.seed(42)\n",
2017-05-29 23:20:14 +02:00
"Theta = np.random.randn(n_inputs, n_outputs)\n",
"\n",
2021-11-03 23:35:15 +01:00
"for epoch in range(n_epochs):\n",
" logits = X_train @ Theta\n",
2017-05-29 23:20:14 +02:00
" Y_proba = softmax(logits)\n",
2021-11-03 23:35:15 +01:00
" if epoch % 1000 == 0:\n",
" Y_proba_valid = softmax(X_valid @ Theta)\n",
" xentropy_losses = -(Y_valid_one_hot * np.log(Y_proba_valid + epsilon))\n",
" l2_loss = 1 / 2 * (Theta[1:] ** 2).sum()\n",
" total_loss = xentropy_losses.sum(axis=1).mean() + alpha * l2_loss\n",
" print(epoch, total_loss.round(4))\n",
2021-03-02 05:26:41 +01:00
" error = Y_proba - Y_train_one_hot\n",
2021-11-03 23:35:15 +01:00
" gradients = 1 / m * X_train.T @ error\n",
" gradients += np.r_[np.zeros([1, n_outputs]), alpha * Theta[1:]]\n",
2017-05-29 23:20:14 +02:00
" Theta = Theta - eta * gradients"
]
},
{
"cell_type": "markdown",
2018-02-21 23:04:09 +01:00
"metadata": {},
"source": [
"Because of the additional $\\ell_2$ penalty, the loss seems greater than earlier, but perhaps this model will perform better? Let's find out:"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9333333333333333"
]
},
"execution_count": 76,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"logits = X_valid @ Theta\n",
"Y_proba = softmax(logits)\n",
"y_predict = Y_proba.argmax(axis=1)\n",
"\n",
"accuracy_score = (y_predict == y_valid).mean()\n",
"accuracy_score"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this case, the $\\ell_2$ penalty did not change the validation accuracy. Perhaps try fine-tuning `alpha`?"
]
},
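{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to follow up on fine-tuning `alpha` is a simple sweep: train once per candidate value and keep the value with the best validation accuracy. The sketch below uses hypothetical synthetic data (three 2D Gaussian blobs produced by the made-up helper `make_blobs_demo`) instead of the iris split, starts from `Theta = 0`, and trains for a fixed number of epochs, so its numbers say nothing about the iris results. It also predicts with `argmax` on the raw logits. Softmax preserves the ordering within each row, so this gives the same classes as `softmax(logits).argmax(axis=1)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def stable_softmax(logits):\n",
"    # Subtract each row's max before exponentiating to avoid overflow\n",
"    exps = np.exp(logits - logits.max(axis=1, keepdims=True))\n",
"    return exps / exps.sum(axis=1, keepdims=True)\n",
"\n",
"def make_blobs_demo(n_per_class, seed):\n",
"    # Hypothetical 3-class, 2-feature data: one Gaussian blob per class\n",
"    rng = np.random.default_rng(seed)\n",
"    centers = np.array([[1.5, 0.3], [4.3, 1.3], [5.5, 2.0]])\n",
"    X = np.vstack([rng.normal(c, 0.4, size=(n_per_class, 2)) for c in centers])\n",
"    y = np.repeat(np.arange(3), n_per_class)\n",
"    return X, y\n",
"\n",
"def train_softmax(X_b, y, alpha, eta=0.5, n_epochs=2001):\n",
"    # Batch GD on mean cross entropy + alpha/2 * ||Theta[1:]||^2 (bias row not penalized)\n",
"    Y_one_hot = np.eye(3)[y]\n",
"    Theta = np.zeros((X_b.shape[1], 3))\n",
"    for _ in range(n_epochs):\n",
"        error = stable_softmax(X_b @ Theta) - Y_one_hot\n",
"        gradients = X_b.T @ error / len(X_b)\n",
"        gradients += np.r_[np.zeros([1, 3]), alpha * Theta[1:]]\n",
"        Theta -= eta * gradients\n",
"    return Theta\n",
"\n",
"X_sweep_demo, y_sweep_demo = make_blobs_demo(50, seed=42)\n",
"perm_demo = np.random.default_rng(0).permutation(len(X_sweep_demo))\n",
"train_demo, valid_demo = perm_demo[:100], perm_demo[100:]\n",
"# Scale with the training rows' statistics, then prepend the bias column\n",
"mean_sweep = X_sweep_demo[train_demo].mean(axis=0)\n",
"std_sweep = X_sweep_demo[train_demo].std(axis=0)\n",
"X_sweep_b = np.c_[np.ones(len(X_sweep_demo)), (X_sweep_demo - mean_sweep) / std_sweep]\n",
"\n",
"valid_accuracies_demo = {}\n",
"for alpha_candidate in [0.0, 0.001, 0.01, 0.1, 1.0]:\n",
"    Theta_candidate = train_softmax(X_sweep_b[train_demo], y_sweep_demo[train_demo], alpha_candidate)\n",
"    # argmax of the logits equals argmax of softmax(logits)\n",
"    y_pred_demo = (X_sweep_b[valid_demo] @ Theta_candidate).argmax(axis=1)\n",
"    valid_accuracies_demo[alpha_candidate] = (y_pred_demo == y_sweep_demo[valid_demo]).mean()\n",
"\n",
"best_alpha_demo = max(valid_accuracies_demo, key=valid_accuracies_demo.get)\n",
"print(valid_accuracies_demo, best_alpha_demo)"
]
},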
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's add early stopping. For this, we just need to measure the loss on the validation set at every epoch and stop as soon as it starts growing."
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 3.7372\n",
"281 0.3256\n",
"282 0.3256 early stopping!\n"
]
}
],
"source": [
"eta = 0.5\n",
"n_epochs = 50_001\n",
"m = len(X_train)\n",
"epsilon = 1e-5\n",
"C = 100 # inverse of the regularization hyperparameter: the penalty weight is 1 / C = 0.01\n",
"best_loss = np.inf\n",
"\n",
"np.random.seed(42)\n",
"Theta = np.random.randn(n_inputs, n_outputs)\n",
"\n",
"for epoch in range(n_epochs):\n",
" logits = X_train @ Theta\n",
" Y_proba = softmax(logits)\n",
" Y_proba_valid = softmax(X_valid @ Theta)\n",
" xentropy_losses = -(Y_valid_one_hot * np.log(Y_proba_valid + epsilon))\n",
" l2_loss = 1 / 2 * (Theta[1:] ** 2).sum()\n",
" total_loss = xentropy_losses.sum(axis=1).mean() + 1 / C * l2_loss\n",
" if epoch % 1000 == 0:\n",
" print(epoch, total_loss.round(4))\n",
" if total_loss < best_loss:\n",
" best_loss = total_loss\n",
" else:\n",
" print(epoch - 1, best_loss.round(4))\n",
" print(epoch, total_loss.round(4), \"early stopping!\")\n",
" break\n",
" error = Y_proba - Y_train_one_hot\n",
" gradients = 1 / m * X_train.T @ error\n",
" gradients += np.r_[np.zeros([1, n_outputs]), 1 / C * Theta[1:]]\n",
" Theta = Theta - eta * gradients"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9333333333333333"
]
},
"execution_count": 78,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"logits = X_valid @ Theta\n",
"Y_proba = softmax(logits)\n",
"y_predict = Y_proba.argmax(axis=1)\n",
"\n",
"accuracy_score = (y_predict == y_valid).mean()\n",
"accuracy_score"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Oh well, still no change in validation accuracy, but at least early stopping shortened training considerably: it stopped at epoch 282 instead of running all 5,001 epochs like the previous version."
]
},
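{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loop above stops at the first increase of the validation loss and keeps the `Theta` from the epoch where the loss went up, one step past the best one. A common variant saves a copy of the best parameters and only stops after several epochs without improvement, which is more robust when the validation loss is noisy. Here is a rough sketch of it under a few assumptions: the helpers `loss_with_l2` and `train_with_early_stopping` are hypothetical, `patience=20` is an arbitrary choice, and the data is synthetic (three 2D Gaussian blobs) rather than the iris split."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def stable_softmax(logits):\n",
"    # Subtract each row's max before exponentiating to avoid overflow\n",
"    exps = np.exp(logits - logits.max(axis=1, keepdims=True))\n",
"    return exps / exps.sum(axis=1, keepdims=True)\n",
"\n",
"def loss_with_l2(Theta, X_b, Y_one_hot, alpha, epsilon=1e-5):\n",
"    # Same loss as above: mean cross entropy + alpha/2 * ||Theta[1:]||^2\n",
"    Y_proba = stable_softmax(X_b @ Theta)\n",
"    xentropy = -(Y_one_hot * np.log(Y_proba + epsilon)).sum(axis=1).mean()\n",
"    return xentropy + alpha * 0.5 * (Theta[1:] ** 2).sum()\n",
"\n",
"def train_with_early_stopping(X_train_b, Y_train, X_valid_b, Y_valid,\n",
"                              alpha=0.01, eta=0.5, max_epochs=5001, patience=20):\n",
"    # Keeps a copy of the Theta with the lowest validation loss and stops once\n",
"    # `patience` epochs have passed without any improvement\n",
"    Theta = np.random.default_rng(42).normal(size=(X_train_b.shape[1], Y_train.shape[1]))\n",
"    best_loss, best_Theta, best_epoch = np.inf, Theta.copy(), 0\n",
"    for epoch in range(max_epochs):\n",
"        valid_loss = loss_with_l2(Theta, X_valid_b, Y_valid, alpha)\n",
"        if valid_loss < best_loss:\n",
"            best_loss, best_Theta, best_epoch = valid_loss, Theta.copy(), epoch\n",
"        elif epoch - best_epoch >= patience:\n",
"            break\n",
"        error = stable_softmax(X_train_b @ Theta) - Y_train\n",
"        gradients = X_train_b.T @ error / len(X_train_b)\n",
"        gradients += np.r_[np.zeros([1, Theta.shape[1]]), alpha * Theta[1:]]\n",
"        Theta = Theta - eta * gradients\n",
"    return best_Theta, best_epoch, best_loss\n",
"\n",
"# Hypothetical data: three 2D Gaussian blobs, 100 training and 50 validation samples\n",
"rng_es = np.random.default_rng(7)\n",
"centers_es = np.array([[1.5, 0.3], [4.3, 1.3], [5.5, 2.0]])\n",
"X_es = np.vstack([rng_es.normal(c, 0.4, size=(50, 2)) for c in centers_es])\n",
"Y_es = np.eye(3)[np.repeat(np.arange(3), 50)]\n",
"perm_es = rng_es.permutation(150)\n",
"train_es, valid_es = perm_es[:100], perm_es[100:]\n",
"mean_es, std_es = X_es[train_es].mean(axis=0), X_es[train_es].std(axis=0)\n",
"X_es_b = np.c_[np.ones(150), (X_es - mean_es) / std_es]\n",
"\n",
"best_Theta_es, best_epoch_es, best_loss_es = train_with_early_stopping(\n",
"    X_es_b[train_es], Y_es[train_es], X_es_b[valid_es], Y_es[valid_es])\n",
"print(best_epoch_es, round(best_loss_es, 4))"
]
},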
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's plot the model's predictions on the whole dataset (remember to scale all features fed to the model):"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 720x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"custom_cmap = mpl.colors.ListedColormap(['#fafab0', '#9898ff', '#a0faa0'])\n",
"\n",
"x0, x1 = np.meshgrid(np.linspace(0, 8, 500).reshape(-1, 1),\n",
" np.linspace(0, 3.5, 200).reshape(-1, 1))\n",
"X_new = np.c_[x0.ravel(), x1.ravel()]\n",
"X_new = (X_new - mean) / std\n",
"X_new_with_bias = np.c_[np.ones(len(X_new)), X_new]\n",
"\n",
"logits = X_new_with_bias @ Theta\n",
"Y_proba = softmax(logits)\n",
"y_predict = Y_proba.argmax(axis=1)\n",
"\n",
"zz1 = Y_proba[:, 1].reshape(x0.shape)\n",
"zz = y_predict.reshape(x0.shape)\n",
"\n",
"plt.figure(figsize=(10, 4))\n",
"plt.plot(X[y == 2, 0], X[y == 2, 1], \"g^\", label=\"Iris virginica\")\n",
"plt.plot(X[y == 1, 0], X[y == 1, 1], \"bs\", label=\"Iris versicolor\")\n",
"plt.plot(X[y == 0, 0], X[y == 0, 1], \"yo\", label=\"Iris setosa\")\n",
"\n",
"plt.contourf(x0, x1, zz, cmap=custom_cmap)\n",
"contour = plt.contour(x0, x1, zz1, cmap=\"hot\")\n",
"plt.clabel(contour, inline=1)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.legend(loc=\"upper left\")\n",
"plt.axis([0, 7, 0, 3.5])\n",
"plt.grid()\n",
"plt.show()"
]
},
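{
"cell_type": "markdown",
"metadata": {},
"source": [
"The scaling step matters because `Theta` was learned on standardized inputs. Any new point, including each grid point in the plot, must be transformed with the same mean and standard deviation that were used to scale the training set (not with its own statistics), and then get the bias column prepended. Here is a small sketch of that pipeline on hypothetical unscaled data. `add_bias_and_scale` is an illustrative helper, not a function defined elsewhere in this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def add_bias_and_scale(X_raw, mean, std):\n",
"    # Standardize with the training-set statistics, then prepend a column of 1s for the bias\n",
"    X_scaled = (X_raw - mean) / std\n",
"    return np.c_[np.ones(len(X_scaled)), X_scaled]\n",
"\n",
"# Hypothetical unscaled training features (e.g. petal length, petal width)\n",
"rng_scale = np.random.default_rng(0)\n",
"X_train_raw_demo = rng_scale.normal([4.0, 1.2], [1.7, 0.7], size=(90, 2))\n",
"mean_scale = X_train_raw_demo.mean(axis=0)\n",
"std_scale = X_train_raw_demo.std(axis=0)\n",
"\n",
"X_train_ready_demo = add_bias_and_scale(X_train_raw_demo, mean_scale, std_scale)\n",
"# New points (such as grid points) reuse mean_scale and std_scale, not their own statistics\n",
"grid_demo = np.c_[np.linspace(0, 8, 5), np.linspace(0, 3.5, 5)]\n",
"grid_ready_demo = add_bias_and_scale(grid_demo, mean_scale, std_scale)\n",
"print(X_train_ready_demo[:, 1:].mean(axis=0).round(6), X_train_ready_demo[:, 1:].std(axis=0).round(6))\n",
"print(grid_ready_demo.shape)  # (5, 3)"
]
},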
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And now let's measure the final model's accuracy on the test set:"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9666666666666667"
]
},
"execution_count": 80,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"logits = X_test @ Theta\n",
"Y_proba = softmax(logits)\n",
"y_predict = Y_proba.argmax(axis=1)\n",
"\n",
"accuracy_score = (y_predict == y_test).mean()\n",
"accuracy_score"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Well, we get even better performance on the test set (about 96.7%, versus 93.3% on the validation set). This variability is likely due to the very small size of the dataset: depending on how you sample the training, validation, and test sets, you can get quite different results. Try changing the random seed and running the code again a few times; you will see that the results vary."
]
},
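{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see this variability without rerunning the whole notebook, here is a rough sketch that repeats a random 80/20 train/test split with ten different seeds and trains the same $\\ell_2$-regularized softmax regression each time. It uses a hypothetical small synthetic dataset (150 samples, 3 classes) and a simplified split with no validation set, so the spread it prints only illustrates the effect. It is not a measurement on the iris data. `split_and_score` is an illustrative helper name."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def stable_softmax(logits):\n",
"    # Subtract each row's max before exponentiating to avoid overflow\n",
"    exps = np.exp(logits - logits.max(axis=1, keepdims=True))\n",
"    return exps / exps.sum(axis=1, keepdims=True)\n",
"\n",
"def split_and_score(X, y, seed, alpha=0.01, eta=0.5, n_epochs=1001):\n",
"    # Random 80/20 train/test split, scaled with the training-set statistics\n",
"    rng = np.random.default_rng(seed)\n",
"    idx = rng.permutation(len(X))\n",
"    n_test = len(X) // 5\n",
"    test_idx, train_idx = idx[:n_test], idx[n_test:]\n",
"    mean, std = X[train_idx].mean(axis=0), X[train_idx].std(axis=0)\n",
"    X_b = np.c_[np.ones(len(X)), (X - mean) / std]\n",
"    Y_one_hot = np.eye(y.max() + 1)[y]\n",
"    Theta = np.zeros((X_b.shape[1], Y_one_hot.shape[1]))\n",
"    for _ in range(n_epochs):\n",
"        # Batch GD with the same l2-regularized gradients as above\n",
"        error = stable_softmax(X_b[train_idx] @ Theta) - Y_one_hot[train_idx]\n",
"        gradients = X_b[train_idx].T @ error / len(train_idx)\n",
"        gradients += np.r_[np.zeros([1, Theta.shape[1]]), alpha * Theta[1:]]\n",
"        Theta -= eta * gradients\n",
"    y_pred = (X_b[test_idx] @ Theta).argmax(axis=1)\n",
"    return (y_pred == y[test_idx]).mean()\n",
"\n",
"# Hypothetical small dataset: 3 classes x 50 samples, 2 features\n",
"rng_var = np.random.default_rng(123)\n",
"centers_var = np.array([[1.5, 0.3], [4.3, 1.3], [5.5, 2.0]])\n",
"X_var_demo = np.vstack([rng_var.normal(c, 0.4, size=(50, 2)) for c in centers_var])\n",
"y_var_demo = np.repeat(np.arange(3), 50)\n",
"\n",
"accuracies_demo = [split_and_score(X_var_demo, y_var_demo, seed) for seed in range(10)]\n",
"print(np.round(accuracies_demo, 3))\n",
"print('min:', min(accuracies_demo), 'max:', max(accuracies_demo))"
]
},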
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
},
"nav_menu": {},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}