{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "**Chapter 1 – The Machine Learning landscape**\n", "\n", "_This notebook contains the code examples in chapter 1. You'll also find the exercise solutions at the end of the notebook. The rest of this notebook is used to generate `lifesat.csv` from the original data sources, as well as some of this chapter's figures._\n", "\n", "You're welcome to go through the code in this notebook if you want, but the real action starts in the next chapter." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Python ≥3.8 is required:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "slideshow": { "slide_type": "-" } }, "outputs": [], "source": [ "import sys\n", "assert sys.version_info >= (3, 8)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Make this notebook's output stable across runs:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "np.random.seed(42)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Scikit-Learn ≥1.0 is required:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import sklearn\n", "\n", "assert sklearn.__version__ >= \"1.0\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To plot pretty figures directly within Jupyter:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import matplotlib as mpl\n", "\n", "mpl.rc('font', size=12)\n", "mpl.rc('axes', labelsize=14)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Download `lifesat.csv` from GitHub, unless it's already available locally:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "import urllib.request\n", "\n", "datapath = Path() / \"datasets\" / \"lifesat\"\n", "datapath.mkdir(parents=True, exist_ok=True)\n", "\n", "root = \"https://raw.githubusercontent.com/ageron/handson-ml3/main/\"\n", "filename = \"lifesat.csv\"\n", "if not (datapath / filename).is_file():\n", "    print(\"Downloading\", filename)\n", "    url = root + \"datasets/lifesat/\" + filename\n", "    urllib.request.urlretrieve(url, datapath / filename)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Code example 1-1" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {},
"outputs": [], "source": [ "from pathlib import Path\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "from sklearn.linear_model import LinearRegression\n", "\n", "# Load and prepare the data\n", "lifesat = pd.read_csv(Path() / \"datasets\" / \"lifesat\" / \"lifesat.csv\")\n", "X = lifesat[[\"GDP per capita (USD)\"]].values\n", "y = lifesat[[\"Life satisfaction\"]].values\n", "\n", "# Visualize the data\n", "lifesat.plot(kind='scatter', grid=True,\n", "             x=\"GDP per capita (USD)\", y=\"Life satisfaction\")\n", "plt.axis([23_500, 62_500, 4, 9])\n", "plt.show()\n", "\n", "# Select a linear model\n", "model = LinearRegression()\n", "\n", "# Train the model\n", "model.fit(X, y)\n", "\n", "# Make a prediction for Cyprus\n", "X_new = [[37_655.2]]  # Cyprus' GDP per capita in 2020\n", "print(model.predict(X_new))  # outputs [[6.30165767]]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Replacing the Linear Regression model with k-Nearest Neighbors (in this example, k = 3) regression in the previous code is as simple as replacing these two lines:\n", "\n", "```python\n", "from sklearn.linear_model import LinearRegression\n", "\n", "model = LinearRegression()\n", "```\n", "\n", "with these two:\n", "\n", "```python\n", "from sklearn.neighbors import KNeighborsRegressor\n", "\n", "model = KNeighborsRegressor(n_neighbors=3)\n", "```" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Select a 3-Nearest Neighbors regression model\n", "from sklearn.neighbors import KNeighborsRegressor\n", "\n", "model = KNeighborsRegressor(n_neighbors=3)\n", "\n", "# Train the model\n", "model.fit(X, y)\n", "\n", "# Make a prediction for Cyprus\n", "print(model.predict(X_new))  # outputs [[6.33333333]]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Generating the data and figures — please skip" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is the code
I used to generate the `lifesat.csv` dataset. You can safely skip this." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create a function to save the figures:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Where to save the figures\n", "IMAGES_PATH = Path() / \"images\" / \"fundamentals\"\n", "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n", "\n", "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n", "    path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n", "    if tight_layout:\n", "        plt.tight_layout()\n", "    plt.savefig(path, format=fig_extension, dpi=resolution)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load and prepare Life satisfaction data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To create `lifesat.csv`, I downloaded the Better Life Index (BLI) data from [OECD's website](http://stats.oecd.org/index.aspx?DataSetCode=BLI) (to get the Life Satisfaction for each country), and World Bank GDP per capita data from [OurWorldInData.org](https://ourworldindata.org/grapher/gdp-per-capita-worldbank). The BLI data is in `datasets/lifesat/oecd_bli.csv` (data from 2020), and the GDP per capita data is in `datasets/lifesat/gdp_per_capita.csv` (data up to 2020).\n", "\n", "If you want to grab the latest versions, please feel free to do so. However, there may be some changes (e.g., renamed columns, or different countries with missing data), so be prepared to tweak the code."
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "for filename in (\"oecd_bli.csv\", \"gdp_per_capita.csv\"):\n", "    if not (datapath / filename).is_file():\n", "        print(\"Downloading\", filename)\n", "        url = root + \"datasets/lifesat/\" + filename\n", "        urllib.request.urlretrieve(url, datapath / filename)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "oecd_bli = pd.read_csv(datapath / \"oecd_bli.csv\")\n", "gdp_per_capita = pd.read_csv(datapath / \"gdp_per_capita.csv\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Preprocess the GDP per capita data to keep only the year 2020:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "gdp_year = 2020\n", "gdppc_col = \"GDP per capita (USD)\"\n", "lifesat_col = \"Life satisfaction\"\n", "\n", "gdp_per_capita = gdp_per_capita[gdp_per_capita[\"Year\"] == gdp_year]\n", "gdp_per_capita = gdp_per_capita.drop([\"Code\", \"Year\"], axis=1)\n", "gdp_per_capita.columns = [\"Country\", gdppc_col]\n", "gdp_per_capita.set_index(\"Country\", inplace=True)\n", "\n", "gdp_per_capita.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Preprocess the OECD BLI data: keep only the rows for the total population (`INEQUALITY == \"TOT\"`), then pivot the table so that each indicator, including `Life satisfaction`, becomes a column:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "oecd_bli = oecd_bli[oecd_bli[\"INEQUALITY\"] == \"TOT\"]\n", "oecd_bli = oecd_bli.pivot(index=\"Country\", columns=\"Indicator\", values=\"Value\")\n", "\n", "oecd_bli.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's merge the life satisfaction data and the GDP per capita data, keeping only the GDP per capita and Life satisfaction columns:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita,\n", "                              left_index=True, right_index=True)\n",
"full_country_stats.sort_values(by=gdppc_col, inplace=True)\n", "full_country_stats = full_country_stats[[gdppc_col, lifesat_col]]\n", "\n", "full_country_stats.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To illustrate the risk of overfitting, I use only part of the data in most figures (all countries with a GDP per capita between `min_gdp` and `max_gdp`). Later in the chapter I reveal the missing countries, and show that they don't follow the same linear trend at all." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "min_gdp = 23_500\n", "max_gdp = 62_500\n", "\n", "country_stats = full_country_stats[(full_country_stats[gdppc_col] >= min_gdp) &\n", " (full_country_stats[gdppc_col] <= max_gdp)]\n", "country_stats.head()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "country_stats.to_csv(datapath / \"lifesat.csv\")\n", "full_country_stats.to_csv(datapath / \"lifesat_full.csv\")" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n", " x=gdppc_col, y=lifesat_col)\n", "\n", "min_life_sat = 4\n", "max_life_sat = 9\n", "\n", "position_text = {\n", " \"Hungary\": (28_000, 4.2),\n", " \"France\": (40_000, 5),\n", " \"New Zealand\": (28_000, 8.2),\n", " \"Australia\": (50_000, 5.5),\n", " \"United States\": (59_000, 5.5),\n", " \"Denmark\": (46_000, 8.5)\n", "}\n", "\n", "for country, pos_text in position_text.items():\n", " pos_data_x = country_stats[gdppc_col].loc[country]\n", " pos_data_y = country_stats[lifesat_col].loc[country]\n", " country = \"U.S.\" if country == \"United States\" else country\n", " plt.annotate(country, xy=(pos_data_x, pos_data_y),\n", " xytext=pos_text,\n", " arrowprops=dict(facecolor='black', width=0.5,\n", " shrink=0.15, headwidth=5))\n", " plt.plot(pos_data_x, pos_data_y, \"ro\")\n", "\n", "plt.axis([min_gdp, max_gdp, 
min_life_sat, max_life_sat])\n", "\n", "save_fig('money_happy_scatterplot')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "highlighted_countries = country_stats.loc[list(position_text.keys())]\n", "highlighted_countries[[gdppc_col, lifesat_col]].sort_values(by=gdppc_col)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n", " x=gdppc_col, y=lifesat_col)\n", "\n", "X = np.linspace(min_gdp, max_gdp, 1000)\n", "\n", "w1, w2 = 4.2, 0\n", "plt.plot(X, w1 + w2 * 1e-5 * X, \"r\")\n", "plt.text(40_000, 4.9, fr\"$\\theta_0 = {w1}$\", color=\"r\")\n", "plt.text(40_000, 4.4, fr\"$\\theta_1 = {w2}$\", color=\"r\")\n", "\n", "w1, w2 = 10, -9\n", "plt.plot(X, w1 + w2 * 1e-5 * X, \"g\")\n", "plt.text(26_000, 8.5, fr\"$\\theta_0 = {w1}$\", color=\"g\")\n", "plt.text(26_000, 8.0, fr\"$\\theta_1 = {w2} \\times 10^{{-5}}$\", color=\"g\")\n", "\n", "w1, w2 = 3, 8\n", "plt.plot(X, w1 + w2 * 1e-5 * X, \"b\")\n", "plt.text(48_000, 8.5, fr\"$\\theta_0 = {w1}$\", color=\"b\")\n", "plt.text(48_000, 8.0, fr\"$\\theta_1 = {w2} \\times 10^{{-5}}$\", color=\"b\")\n", "\n", "plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n", "\n", "save_fig('tweaking_model_params_plot')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "from sklearn import linear_model\n", "\n", "X_sample = country_stats[[gdppc_col]].values\n", "y_sample = country_stats[[lifesat_col]].values\n", "\n", "lin1 = linear_model.LinearRegression()\n", "lin1.fit(X_sample, y_sample)\n", "\n", "t0, t1 = lin1.intercept_[0], lin1.coef_[0][0]\n", "print(f\"θ0={t0:.2f}, θ1={t1:.2e}\")" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n", " x=gdppc_col, y=lifesat_col)\n", "\n", "X = 
np.linspace(min_gdp, max_gdp, 1000)\n", "plt.plot(X, t0 + t1 * X, \"b\")\n", "\n", "plt.text(max_gdp - 20_000, min_life_sat + 1.5,\n", " fr\"$\\theta_0 = {t0:.2f}$\", color=\"b\")\n", "plt.text(max_gdp - 20_000, min_life_sat + 1,\n", " fr\"$\\theta_1 = {t1 * 1e5:.2f} \\times 10^{{-5}}$\", color=\"b\")\n", "\n", "plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n", "\n", "save_fig('best_fit_model_plot')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "cyprus_gdp_per_capita = gdp_per_capita[gdppc_col].loc[\"Cyprus\"]\n", "cyprus_gdp_per_capita" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "cyprus_predicted_life_satisfaction = lin1.predict([[cyprus_gdp_per_capita]])[0, 0]\n", "cyprus_predicted_life_satisfaction" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n", " x=gdppc_col, y=lifesat_col)\n", "\n", "X = np.linspace(min_gdp, max_gdp, 1000)\n", "plt.plot(X, t0 + t1 * X, \"b\")\n", "\n", "plt.text(min_gdp + 15_000, max_life_sat - 1.5,\n", " fr\"$\\theta_0 = {t0:.2f}$\", color=\"b\")\n", "plt.text(min_gdp + 15_000, max_life_sat - 1,\n", " fr\"$\\theta_1 = {t1 * 1e5:.2f} \\times 10^{{-5}}$\", color=\"b\")\n", "\n", "plt.plot([cyprus_gdp_per_capita, cyprus_gdp_per_capita],\n", " [min_life_sat, cyprus_predicted_life_satisfaction], \"r--\")\n", "plt.text(cyprus_gdp_per_capita + 1000, 5.0,\n", " fr\"Prediction = {cyprus_predicted_life_satisfaction:.2f}\", color=\"r\")\n", "plt.plot(cyprus_gdp_per_capita, cyprus_predicted_life_satisfaction, \"ro\")\n", "\n", "plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n", "\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "missing_data = full_country_stats[(full_country_stats[gdppc_col] < min_gdp) |\n", " (full_country_stats[gdppc_col] > 
max_gdp)]\n", "missing_data" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "position_text_missing_countries = {\n", " \"South Africa\": (20_000, 4.2),\n", " \"Colombia\": (6_000, 8.2),\n", " \"Brazil\": (18_000, 7.8),\n", " \"Mexico\": (24_000, 7.4),\n", " \"Chile\": (30_000, 7.0),\n", " \"Norway\": (51_000, 6.2),\n", " \"Switzerland\": (62_000, 5.7),\n", " \"Ireland\": (81_000, 5.2),\n", " \"Luxembourg\": (92_000, 4.7),\n", "}" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "full_country_stats.plot(kind='scatter', figsize=(8,3),\n", " x=gdppc_col, y=lifesat_col, grid=True)\n", "\n", "for country, pos_text in position_text_missing_countries.items():\n", " pos_data_x, pos_data_y = missing_data.loc[country]\n", " plt.annotate(country, xy=(pos_data_x, pos_data_y),\n", " xytext=pos_text,\n", " arrowprops=dict(facecolor='black', width=0.5,\n", " shrink=0.1, headwidth=5))\n", " plt.plot(pos_data_x, pos_data_y, \"rs\")\n", "\n", "X = np.linspace(0, 115_000, 1000)\n", "plt.plot(X, t0 + t1 * X, \"b:\")\n", "\n", "lin_reg_full = linear_model.LinearRegression()\n", "Xfull = np.c_[full_country_stats[gdppc_col]]\n", "yfull = np.c_[full_country_stats[lifesat_col]]\n", "lin_reg_full.fit(Xfull, yfull)\n", "\n", "t0full, t1full = lin_reg_full.intercept_[0], lin_reg_full.coef_[0][0]\n", "X = np.linspace(0, 115_000, 1000)\n", "plt.plot(X, t0full + t1full * X, \"k\")\n", "\n", "plt.axis([0, 115_000, min_life_sat, max_life_sat])\n", "\n", "save_fig('representative_training_data_scatterplot')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "from sklearn import preprocessing\n", "from sklearn import pipeline\n", "\n", "full_country_stats.plot(kind='scatter', figsize=(8,3),\n", " x=gdppc_col, y=lifesat_col, grid=True)\n", "\n", "poly = preprocessing.PolynomialFeatures(degree=10, include_bias=False)\n", "scaler = 
preprocessing.StandardScaler()\n", "lin_reg2 = linear_model.LinearRegression()\n", "\n", "pipeline_reg = pipeline.Pipeline([\n", " ('poly', poly),\n", " ('scal', scaler),\n", " ('lin', lin_reg2)])\n", "pipeline_reg.fit(Xfull, yfull)\n", "curve = pipeline_reg.predict(X[:, np.newaxis])\n", "plt.plot(X, curve)\n", "\n", "plt.axis([0, 115_000, min_life_sat, max_life_sat])\n", "\n", "save_fig('overfitting_model_plot')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "w_countries = [c for c in full_country_stats.index if \"W\" in c.upper()]\n", "full_country_stats.loc[w_countries][lifesat_col]" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "all_w_countries = [c for c in gdp_per_capita.index if \"W\" in c.upper()]\n", "gdp_per_capita.loc[all_w_countries].sort_values(by=gdppc_col)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "country_stats.plot(kind='scatter', x=gdppc_col, y=lifesat_col, figsize=(8,3))\n", "missing_data.plot(kind='scatter', x=gdppc_col, y=lifesat_col,\n", " marker=\"s\", color=\"r\", grid=True, ax=plt.gca())\n", "\n", "X = np.linspace(0, 115_000, 1000)\n", "plt.plot(X, t0 + t1*X, \"b:\", label=\"Linear model on partial data\")\n", "plt.plot(X, t0full + t1full * X, \"k-\", label=\"Linear model on all data\")\n", "\n", "ridge = linear_model.Ridge(alpha=10**9.5)\n", "X_sample = country_stats[[gdppc_col]]\n", "y_sample = country_stats[[lifesat_col]]\n", "ridge.fit(X_sample, y_sample)\n", "t0ridge, t1ridge = ridge.intercept_[0], ridge.coef_[0][0]\n", "plt.plot(X, t0ridge + t1ridge * X, \"b--\",\n", " label=\"Regularized linear model on partial data\")\n", "plt.legend(loc=\"lower right\")\n", "\n", "plt.axis([0, 115_000, min_life_sat, max_life_sat])\n", "\n", "save_fig('ridge_model_plot')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Exercise Solutions" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "1. Machine Learning is about building systems that can learn from data. Learning means getting better at some task, given some performance measure.\n", "2. Machine Learning is great for complex problems for which we have no algorithmic solution, to replace long lists of hand-tuned rules, to build systems that adapt to fluctuating environments, and finally to help humans learn (e.g., data mining).\n", "3. A labeled training set is a training set that contains the desired solution (a.k.a. a label) for each instance.\n", "4. The two most common supervised tasks are regression and classification.\n", "5. Common unsupervised tasks include clustering, visualization, dimensionality reduction, and association rule learning.\n", "6. Reinforcement Learning is likely to perform best if we want a robot to learn to walk in various unknown terrains, since this is typically the type of problem that Reinforcement Learning tackles. It might be possible to express the problem as a supervised or semi-supervised learning problem, but it would be less natural.\n", "7. If you don't know how to define the groups, then you can use a clustering algorithm (unsupervised learning) to segment your customers into clusters of similar customers. However, if you know what groups you would like to have, then you can feed many examples of each group to a classification algorithm (supervised learning), and it will classify all your customers into these groups.\n", "8. Spam detection is a typical supervised learning problem: the algorithm is fed many emails along with their labels (spam or not spam).\n", "9. An online learning system can learn incrementally, as opposed to a batch learning system. This makes it capable of adapting rapidly to changing data (which is useful for autonomous systems), and of training on very large quantities of data.\n", "10. Out-of-core algorithms can handle vast quantities of data that cannot fit in a computer's main memory. 
An out-of-core learning algorithm chops the data into mini-batches and uses online learning techniques to learn from these mini-batches.\n", "11. An instance-based learning system learns the training data by heart; then, when given a new instance, it uses a similarity measure to find the most similar learned instances and uses them to make predictions.\n", "12. A model has one or more model parameters that determine what it will predict given a new instance (e.g., the slope of a linear model). A learning algorithm tries to find optimal values for these parameters such that the model generalizes well to new instances. A hyperparameter is a parameter of the learning algorithm itself, not of the model (e.g., the amount of regularization to apply).\n", "13. Model-based learning algorithms search for optimal values for the model parameters such that the model will generalize well to new instances. We usually train such systems by minimizing a cost function that measures how bad the system is at making predictions on the training data, plus a penalty for model complexity if the model is regularized. To make predictions, we feed the new instance's features into the model's prediction function, using the parameter values found by the learning algorithm.\n", "14. Some of the main challenges in Machine Learning are the lack of data, poor data quality, nonrepresentative data, uninformative features, excessively simple models that underfit the training data, and excessively complex models that overfit the data.\n", "15. If a model performs well on the training data but generalizes poorly to new instances, the model is likely overfitting the training data (or we got extremely lucky on the training data). Possible solutions to overfitting are getting more data, simplifying the model (selecting a simpler algorithm, reducing the number of parameters or features used, or regularizing the model), or reducing the noise in the training data.\n", "16. 
A test set is used to estimate the generalization error that a model will make on new instances, before the model is launched in production.\n", "17. A validation set is used to compare models. It makes it possible to select the best model and tune the hyperparameters.\n", "18. The train-dev set is used when there is a risk of mismatch between the training data and the data used in the validation and test datasets (which should always be as close as possible to the data used once the model is in production). The train-dev set is a part of the training set that's held out (the model is not trained on it). The model is trained on the rest of the training set, and evaluated on both the train-dev set and the validation set. If the model performs well on the training set but not on the train-dev set, then the model is likely overfitting the training set. If it performs well on both the training set and the train-dev set, but not on the validation set, then there is probably a significant data mismatch between the training data and the validation + test data, and you should try to improve the training data to make it look more like the validation + test data.\n", "19. If you tune hyperparameters using the test set, you risk overfitting the test set, and the generalization error you measure will be optimistic (you may launch a model that performs worse than you expect)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "homl3", "language": "python", "name": "homl3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" }, "metadata": { "interpreter": { "hash": "22b0ec00cd9e253c751e6d2619fc0bb2d18ed12980de3246690d5be49479dd65" } }, "nav_menu": {}, "toc": { "navigate_menu": true, "number_sections": true, "sideBar": true, "threshold": 6, "toc_cell": false, "toc_section_display": "block", "toc_window_display": true }, "toc_position": { "height": "616px", "left": "0px", "right": "20px", "top": "106px", "width": "213px" } }, "nbformat": 4, "nbformat_minor": 4 }