{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Chapter 1 – The Machine Learning landscape**\n",
    "\n",
    "_This notebook contains the code examples in chapter 1. You'll also find the exercise solutions at the end of the notebook. The rest of this notebook is used to generate `lifesat.csv` from the original data sources, and some of this chapter's figures._\n",
    "\n",
    "You're welcome to go through the code in this notebook if you want, but the real action starts in the next chapter."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<table align=\"left\">\n",
    "  <td>\n",
    "    <a href=\"https://colab.research.google.com/github/ageron/handson-ml3/blob/main/01_the_machine_learning_landscape.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/ageron/handson-ml3/blob/main/01_the_machine_learning_landscape.ipynb\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" /></a>\n",
    "  </td>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This project requires Python 3.7 or above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "assert sys.version_info >= (3, 7)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Scikit-Learn ≥1.0.1 is required:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from packaging import version\n",
    "import sklearn\n",
    "\n",
    "assert version.parse(sklearn.__version__) >= version.parse(\"1.0.1\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's define the default font sizes, to plot pretty figures:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.rc('font', size=12)\n",
    "plt.rc('axes', labelsize=14, titlesize=14)\n",
    "plt.rc('legend', fontsize=12)\n",
    "plt.rc('xtick', labelsize=10)\n",
    "plt.rc('ytick', labelsize=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Make this notebook's output stable across runs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "np.random.seed(42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Code example 1-1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[6.30165767]]\n"
     ]
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.linear_model import LinearRegression\n",
    "\n",
    "# Download and prepare the data\n",
    "data_root = \"https://github.com/ageron/data/raw/main/\"\n",
    "lifesat = pd.read_csv(data_root + \"lifesat/lifesat.csv\")\n",
    "X = lifesat[[\"GDP per capita (USD)\"]].values\n",
    "y = lifesat[[\"Life satisfaction\"]].values\n",
    "\n",
    "# Visualize the data\n",
    "lifesat.plot(kind='scatter', grid=True,\n",
    "             x=\"GDP per capita (USD)\", y=\"Life satisfaction\")\n",
    "plt.axis([23_500, 62_500, 4, 9])\n",
    "plt.show()\n",
    "\n",
    "# Select a linear model\n",
    "model = LinearRegression()\n",
    "\n",
    "# Train the model\n",
    "model.fit(X, y)\n",
    "\n",
    "# Make a prediction for Cyprus\n",
    "X_new = [[37_655.2]]  # Cyprus' GDP per capita in 2020\n",
    "print(model.predict(X_new)) # outputs [[6.30165767]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Replacing the Linear Regression model with k-Nearest Neighbors regression (in this example, k = 3) in the previous code is as simple as replacing these two lines:\n",
    "\n",
    "```python\n",
    "from sklearn.linear_model import LinearRegression\n",
    "\n",
    "model = LinearRegression()\n",
    "```\n",
    "\n",
    "with these two:\n",
    "\n",
    "```python\n",
    "from sklearn.neighbors import KNeighborsRegressor\n",
    "\n",
    "model = KNeighborsRegressor(n_neighbors=3)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[6.33333333]]\n"
     ]
    }
   ],
   "source": [
    "# Select a 3-Nearest Neighbors regression model\n",
    "from sklearn.neighbors import KNeighborsRegressor\n",
    "\n",
    "model = KNeighborsRegressor(n_neighbors=3)\n",
    "\n",
    "# Train the model\n",
    "model.fit(X, y)\n",
    "\n",
    "# Make a prediction for Cyprus\n",
    "print(model.predict(X_new)) # outputs [[6.33333333]]\n"
   ]
  },
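  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the k-Nearest Neighbors idea concrete, here is a minimal pure-Python sketch of 3-nearest-neighbors regression on a single feature. It is an illustration only, not Scikit-Learn's implementation: `sample` and `knn_predict` are hypothetical names, and `sample` holds just the seven countries highlighted later in this notebook, so its prediction need not match the 6.33 printed above, which comes from the full `lifesat.csv`.\n",
    "\n",
    "```python\n",
    "# Assumed subset: (GDP per capita in USD, life satisfaction) for seven countries.\n",
    "sample = {\n",
    "    'Turkey': (28384.987785, 5.5),\n",
    "    'Hungary': (31007.768407, 5.6),\n",
    "    'France': (42025.617373, 6.5),\n",
    "    'New Zealand': (42404.393738, 7.3),\n",
    "    'Australia': (48697.837028, 7.3),\n",
    "    'Denmark': (55938.212809, 7.6),\n",
    "    'United States': (60235.728492, 6.9),\n",
    "}\n",
    "\n",
    "def knn_predict(gdp, data, k=3):\n",
    "    # Keep the k pairs whose GDP is closest to the query, then average their targets.\n",
    "    nearest = sorted(data.values(), key=lambda pair: abs(pair[0] - gdp))[:k]\n",
    "    return sum(target for _, target in nearest) / k\n",
    "\n",
    "cyprus_gdp = 37_655.2  # same query value as X_new above\n",
    "prediction = knn_predict(cyprus_gdp, sample)\n",
    "print(round(prediction, 2))\n",
    "```\n",
    "\n",
    "With this subset, the three countries closest to Cyprus in GDP per capita are France, New Zealand and Hungary, and the prediction is simply the mean of their life satisfaction values."
   ]
  },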
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generating the data and figures — please skip"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is the code I used to generate the `lifesat.csv` dataset. You can safely skip this."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a function to save the figures:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# Where to save the figures\n",
    "IMAGES_PATH = Path() / \"images\" / \"fundamentals\"\n",
    "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
    "    path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
    "    if tight_layout:\n",
    "        plt.tight_layout()\n",
    "    plt.savefig(path, format=fig_extension, dpi=resolution)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and prepare Life satisfaction data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To create `lifesat.csv`, I downloaded the Better Life Index (BLI) data from [OECD's website](http://stats.oecd.org/index.aspx?DataSetCode=BLI) (to get the Life Satisfaction for each country), and World Bank GDP per capita data from [OurWorldInData.org](https://ourworldindata.org/grapher/gdp-per-capita-worldbank). The BLI data is in `datasets/lifesat/oecd_bli.csv` (data from 2020), and the GDP per capita data is in `datasets/lifesat/gdp_per_capita.csv` (data up to 2020).\n",
    "\n",
    "If you want to grab the latest versions, please feel free to do so. However, there may be some changes (e.g., in the column names, or different countries missing data), so be prepared to have to tweak the code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading oecd_bli.csv\n",
      "Downloading gdp_per_capita.csv\n"
     ]
    }
   ],
   "source": [
    "import urllib.request\n",
    "\n",
    "datapath = Path() / \"datasets\" / \"lifesat\"\n",
    "datapath.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "data_root = \"https://github.com/ageron/data/raw/main/\"\n",
    "for filename in (\"oecd_bli.csv\", \"gdp_per_capita.csv\"):\n",
    "    if not (datapath / filename).is_file():\n",
    "        print(\"Downloading\", filename)\n",
    "        url = data_root + \"lifesat/\" + filename\n",
    "        urllib.request.urlretrieve(url, datapath / filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "oecd_bli = pd.read_csv(datapath / \"oecd_bli.csv\")\n",
    "gdp_per_capita = pd.read_csv(datapath / \"gdp_per_capita.csv\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Preprocess the GDP per capita data to keep only the year 2020:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>GDP per capita (USD)</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Country</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Afghanistan</th>\n",
       "      <td>1978.961579</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Africa Eastern and Southern</th>\n",
       "      <td>3387.594670</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Africa Western and Central</th>\n",
       "      <td>4003.158913</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Albania</th>\n",
       "      <td>13295.410885</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Algeria</th>\n",
       "      <td>10681.679297</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                             GDP per capita (USD)\n",
       "Country                                          \n",
       "Afghanistan                           1978.961579\n",
       "Africa Eastern and Southern           3387.594670\n",
       "Africa Western and Central            4003.158913\n",
       "Albania                              13295.410885\n",
       "Algeria                              10681.679297"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gdp_year = 2020\n",
    "gdppc_col = \"GDP per capita (USD)\"\n",
    "lifesat_col = \"Life satisfaction\"\n",
    "\n",
    "gdp_per_capita = gdp_per_capita[gdp_per_capita[\"Year\"] == gdp_year]\n",
    "gdp_per_capita = gdp_per_capita.drop([\"Code\", \"Year\"], axis=1)\n",
    "gdp_per_capita.columns = [\"Country\", gdppc_col]\n",
    "gdp_per_capita.set_index(\"Country\", inplace=True)\n",
    "\n",
    "gdp_per_capita.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Preprocess the OECD BLI data: keep only the total (`TOT`) rows, then pivot the table so that each indicator, including `Life satisfaction`, becomes a column:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th>Indicator</th>\n",
       "      <th>Air pollution</th>\n",
       "      <th>Dwellings without basic facilities</th>\n",
       "      <th>Educational attainment</th>\n",
       "      <th>Employees working very long hours</th>\n",
       "      <th>Employment rate</th>\n",
       "      <th>Feeling safe walking alone at night</th>\n",
       "      <th>Homicide rate</th>\n",
       "      <th>Household net adjusted disposable income</th>\n",
       "      <th>Household net wealth</th>\n",
       "      <th>Housing expenditure</th>\n",
       "      <th>...</th>\n",
       "      <th>Personal earnings</th>\n",
       "      <th>Quality of support network</th>\n",
       "      <th>Rooms per person</th>\n",
       "      <th>Self-reported health</th>\n",
       "      <th>Stakeholder engagement for developing regulations</th>\n",
       "      <th>Student skills</th>\n",
       "      <th>Time devoted to leisure and personal care</th>\n",
       "      <th>Voter turnout</th>\n",
       "      <th>Water quality</th>\n",
       "      <th>Years in education</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Country</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Australia</th>\n",
       "      <td>5.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>81.0</td>\n",
       "      <td>13.04</td>\n",
       "      <td>73.0</td>\n",
       "      <td>63.5</td>\n",
       "      <td>1.1</td>\n",
       "      <td>32759.0</td>\n",
       "      <td>427064.0</td>\n",
       "      <td>20.0</td>\n",
       "      <td>...</td>\n",
       "      <td>49126.0</td>\n",
       "      <td>95.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>85.0</td>\n",
       "      <td>2.7</td>\n",
       "      <td>502.0</td>\n",
       "      <td>14.35</td>\n",
       "      <td>91.0</td>\n",
       "      <td>93.0</td>\n",
       "      <td>21.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Austria</th>\n",
       "      <td>16.0</td>\n",
       "      <td>0.9</td>\n",
       "      <td>85.0</td>\n",
       "      <td>6.66</td>\n",
       "      <td>72.0</td>\n",
       "      <td>80.6</td>\n",
       "      <td>0.5</td>\n",
       "      <td>33541.0</td>\n",
       "      <td>308325.0</td>\n",
       "      <td>21.0</td>\n",
       "      <td>...</td>\n",
       "      <td>50349.0</td>\n",
       "      <td>92.0</td>\n",
       "      <td>1.6</td>\n",
       "      <td>70.0</td>\n",
       "      <td>1.3</td>\n",
       "      <td>492.0</td>\n",
       "      <td>14.55</td>\n",
       "      <td>80.0</td>\n",
       "      <td>92.0</td>\n",
       "      <td>17.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Belgium</th>\n",
       "      <td>15.0</td>\n",
       "      <td>1.9</td>\n",
       "      <td>77.0</td>\n",
       "      <td>4.75</td>\n",
       "      <td>63.0</td>\n",
       "      <td>70.1</td>\n",
       "      <td>1.0</td>\n",
       "      <td>30364.0</td>\n",
       "      <td>386006.0</td>\n",
       "      <td>21.0</td>\n",
       "      <td>...</td>\n",
       "      <td>49675.0</td>\n",
       "      <td>91.0</td>\n",
       "      <td>2.2</td>\n",
       "      <td>74.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>503.0</td>\n",
       "      <td>15.70</td>\n",
       "      <td>89.0</td>\n",
       "      <td>84.0</td>\n",
       "      <td>19.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Brazil</th>\n",
       "      <td>10.0</td>\n",
       "      <td>6.7</td>\n",
       "      <td>49.0</td>\n",
       "      <td>7.13</td>\n",
       "      <td>61.0</td>\n",
       "      <td>35.6</td>\n",
       "      <td>26.7</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>90.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>2.2</td>\n",
       "      <td>395.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>79.0</td>\n",
       "      <td>73.0</td>\n",
       "      <td>16.2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Canada</th>\n",
       "      <td>7.0</td>\n",
       "      <td>0.2</td>\n",
       "      <td>91.0</td>\n",
       "      <td>3.69</td>\n",
       "      <td>73.0</td>\n",
       "      <td>82.2</td>\n",
       "      <td>1.3</td>\n",
       "      <td>30854.0</td>\n",
       "      <td>423849.0</td>\n",
       "      <td>22.0</td>\n",
       "      <td>...</td>\n",
       "      <td>47622.0</td>\n",
       "      <td>93.0</td>\n",
       "      <td>2.6</td>\n",
       "      <td>88.0</td>\n",
       "      <td>2.9</td>\n",
       "      <td>523.0</td>\n",
       "      <td>14.56</td>\n",
       "      <td>68.0</td>\n",
       "      <td>91.0</td>\n",
       "      <td>17.3</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 24 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "Indicator  Air pollution  Dwellings without basic facilities  \\\n",
       "Country                                                        \n",
       "Australia            5.0                                 NaN   \n",
       "Austria             16.0                                 0.9   \n",
       "Belgium             15.0                                 1.9   \n",
       "Brazil              10.0                                 6.7   \n",
       "Canada               7.0                                 0.2   \n",
       "\n",
       "Indicator  Educational attainment  Employees working very long hours  \\\n",
       "Country                                                                \n",
       "Australia                    81.0                              13.04   \n",
       "Austria                      85.0                               6.66   \n",
       "Belgium                      77.0                               4.75   \n",
       "Brazil                       49.0                               7.13   \n",
       "Canada                       91.0                               3.69   \n",
       "\n",
       "Indicator  Employment rate  Feeling safe walking alone at night  \\\n",
       "Country                                                           \n",
       "Australia             73.0                                 63.5   \n",
       "Austria               72.0                                 80.6   \n",
       "Belgium               63.0                                 70.1   \n",
       "Brazil                61.0                                 35.6   \n",
       "Canada                73.0                                 82.2   \n",
       "\n",
       "Indicator  Homicide rate  Household net adjusted disposable income  \\\n",
       "Country                                                              \n",
       "Australia            1.1                                   32759.0   \n",
       "Austria              0.5                                   33541.0   \n",
       "Belgium              1.0                                   30364.0   \n",
       "Brazil              26.7                                       NaN   \n",
       "Canada               1.3                                   30854.0   \n",
       "\n",
       "Indicator  Household net wealth  Housing expenditure  ...  Personal earnings  \\\n",
       "Country                                               ...                      \n",
       "Australia              427064.0                 20.0  ...            49126.0   \n",
       "Austria                308325.0                 21.0  ...            50349.0   \n",
       "Belgium                386006.0                 21.0  ...            49675.0   \n",
       "Brazil                      NaN                  NaN  ...                NaN   \n",
       "Canada                 423849.0                 22.0  ...            47622.0   \n",
       "\n",
       "Indicator  Quality of support network  Rooms per person  Self-reported health  \\\n",
       "Country                                                                         \n",
       "Australia                        95.0               NaN                  85.0   \n",
       "Austria                          92.0               1.6                  70.0   \n",
       "Belgium                          91.0               2.2                  74.0   \n",
       "Brazil                           90.0               NaN                   NaN   \n",
       "Canada                           93.0               2.6                  88.0   \n",
       "\n",
       "Indicator  Stakeholder engagement for developing regulations  Student skills  \\\n",
       "Country                                                                        \n",
       "Australia                                                2.7           502.0   \n",
       "Austria                                                  1.3           492.0   \n",
       "Belgium                                                  2.0           503.0   \n",
       "Brazil                                                   2.2           395.0   \n",
       "Canada                                                   2.9           523.0   \n",
       "\n",
       "Indicator  Time devoted to leisure and personal care  Voter turnout  \\\n",
       "Country                                                               \n",
       "Australia                                      14.35           91.0   \n",
       "Austria                                        14.55           80.0   \n",
       "Belgium                                        15.70           89.0   \n",
       "Brazil                                           NaN           79.0   \n",
       "Canada                                         14.56           68.0   \n",
       "\n",
       "Indicator  Water quality  Years in education  \n",
       "Country                                       \n",
       "Australia           93.0                21.0  \n",
       "Austria             92.0                17.0  \n",
       "Belgium             84.0                19.3  \n",
       "Brazil              73.0                16.2  \n",
       "Canada              91.0                17.3  \n",
       "\n",
       "[5 rows x 24 columns]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "oecd_bli = oecd_bli[oecd_bli[\"INEQUALITY\"]==\"TOT\"]\n",
    "oecd_bli = oecd_bli.pivot(index=\"Country\", columns=\"Indicator\", values=\"Value\")\n",
    "\n",
    "oecd_bli.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's merge the life satisfaction data and the GDP per capita data, keeping only the GDP per capita and Life satisfaction columns:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>GDP per capita (USD)</th>\n",
       "      <th>Life satisfaction</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Country</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>South Africa</th>\n",
       "      <td>11466.189672</td>\n",
       "      <td>4.7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Colombia</th>\n",
       "      <td>13441.492952</td>\n",
       "      <td>6.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Brazil</th>\n",
       "      <td>14063.982505</td>\n",
       "      <td>6.4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Mexico</th>\n",
       "      <td>17887.750736</td>\n",
       "      <td>6.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Chile</th>\n",
       "      <td>23324.524751</td>\n",
       "      <td>6.5</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              GDP per capita (USD)  Life satisfaction\n",
       "Country                                              \n",
       "South Africa          11466.189672                4.7\n",
       "Colombia              13441.492952                6.3\n",
       "Brazil                14063.982505                6.4\n",
       "Mexico                17887.750736                6.5\n",
       "Chile                 23324.524751                6.5"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita,\n",
    "                              left_index=True, right_index=True)\n",
    "full_country_stats.sort_values(by=gdppc_col, inplace=True)\n",
    "full_country_stats = full_country_stats[[gdppc_col, lifesat_col]]\n",
    "\n",
    "full_country_stats.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate the risk of overfitting, I use only part of the data in most figures (all countries with a GDP per capita between `min_gdp` and `max_gdp`). Later in the chapter I reveal the missing countries, and show that they don't follow the same linear trend at all."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>GDP per capita (USD)</th>\n",
       "      <th>Life satisfaction</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Country</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Russia</th>\n",
       "      <td>26456.387938</td>\n",
       "      <td>5.8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Greece</th>\n",
       "      <td>27287.083401</td>\n",
       "      <td>5.4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Turkey</th>\n",
       "      <td>28384.987785</td>\n",
       "      <td>5.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Latvia</th>\n",
       "      <td>29932.493910</td>\n",
       "      <td>5.9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Hungary</th>\n",
       "      <td>31007.768407</td>\n",
       "      <td>5.6</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         GDP per capita (USD)  Life satisfaction\n",
       "Country                                         \n",
       "Russia           26456.387938                5.8\n",
       "Greece           27287.083401                5.4\n",
       "Turkey           28384.987785                5.5\n",
       "Latvia           29932.493910                5.9\n",
       "Hungary          31007.768407                5.6"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "min_gdp = 23_500\n",
    "max_gdp = 62_500\n",
    "\n",
    "country_stats = full_country_stats[(full_country_stats[gdppc_col] >= min_gdp) &\n",
    "                                   (full_country_stats[gdppc_col] <= max_gdp)]\n",
    "country_stats.head()"
   ]
  },
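  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The effect of restricting the GDP range can be sketched without pandas or Scikit-Learn. The snippet below is a rough illustration under stated assumptions, not the chapter's own analysis: it uses only a handful of rows copied from the tables in this notebook (the five lowest-GDP rows of `full_country_stats` and the seven countries highlighted in the figures), and `fit_line`, `low_gdp` and `in_range` are hypothetical names.\n",
    "\n",
    "```python\n",
    "# (GDP per capita in USD, life satisfaction) pairs copied from this notebook's tables.\n",
    "low_gdp = [  # below min_gdp, so excluded from country_stats\n",
    "    (11466.189672, 4.7),  # South Africa\n",
    "    (13441.492952, 6.3),  # Colombia\n",
    "    (14063.982505, 6.4),  # Brazil\n",
    "    (17887.750736, 6.5),  # Mexico\n",
    "    (23324.524751, 6.5),  # Chile\n",
    "]\n",
    "in_range = [  # between min_gdp and max_gdp\n",
    "    (28384.987785, 5.5),  # Turkey\n",
    "    (31007.768407, 5.6),  # Hungary\n",
    "    (42025.617373, 6.5),  # France\n",
    "    (42404.393738, 7.3),  # New Zealand\n",
    "    (48697.837028, 7.3),  # Australia\n",
    "    (55938.212809, 7.6),  # Denmark\n",
    "    (60235.728492, 6.9),  # United States\n",
    "]\n",
    "\n",
    "def fit_line(points):\n",
    "    # Ordinary least squares for a single feature: returns (intercept, slope).\n",
    "    n = len(points)\n",
    "    mean_x = sum(x for x, _ in points) / n\n",
    "    mean_y = sum(y for _, y in points) / n\n",
    "    sxy = sum((x - mean_x) * (y - mean_y) for x, y in points)\n",
    "    sxx = sum((x - mean_x) ** 2 for x, _ in points)\n",
    "    slope = sxy / sxx\n",
    "    return mean_y - slope * mean_x, slope\n",
    "\n",
    "_, slope_partial = fit_line(in_range)\n",
    "_, slope_all = fit_line(low_gdp + in_range)\n",
    "print(f'slope on the in-range rows:      {slope_partial:.2e}')\n",
    "print(f'slope with low-GDP rows added:   {slope_all:.2e}')\n",
    "```\n",
    "\n",
    "On these rows the two slopes differ noticeably, which is the point of the paragraph above: a line fitted on a restricted sample does not necessarily describe the countries left out of it."
   ]
  },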
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "country_stats.to_csv(datapath / \"lifesat.csv\")\n",
    "full_country_stats.to_csv(datapath / \"lifesat_full.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 360x216 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "country_stats.plot(kind='scatter', figsize=(5, 3), grid=True,\n",
    "                   x=gdppc_col, y=lifesat_col)\n",
    "\n",
    "min_life_sat = 4\n",
    "max_life_sat = 9\n",
    "\n",
    "position_text = {\n",
    "    \"Turkey\": (29_500, 4.2),\n",
    "    \"Hungary\": (28_000, 6.9),\n",
    "    \"France\": (40_000, 5),\n",
    "    \"New Zealand\": (28_000, 8.2),\n",
    "    \"Australia\": (50_000, 5.5),\n",
    "    \"United States\": (59_000, 5.3),\n",
    "    \"Denmark\": (46_000, 8.5)\n",
    "}\n",
    "\n",
    "for country, pos_text in position_text.items():\n",
    "    pos_data_x = country_stats[gdppc_col].loc[country]\n",
    "    pos_data_y = country_stats[lifesat_col].loc[country]\n",
    "    country = \"U.S.\" if country == \"United States\" else country\n",
    "    plt.annotate(country, xy=(pos_data_x, pos_data_y),\n",
    "                 xytext=pos_text, fontsize=12,\n",
    "                 arrowprops=dict(facecolor='black', width=0.5,\n",
    "                                 shrink=0.08, headwidth=5))\n",
    "    plt.plot(pos_data_x, pos_data_y, \"ro\")\n",
    "\n",
    "plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
    "\n",
    "save_fig('money_happy_scatterplot')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>GDP per capita (USD)</th>\n",
       "      <th>Life satisfaction</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Country</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Turkey</th>\n",
       "      <td>28384.987785</td>\n",
       "      <td>5.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Hungary</th>\n",
       "      <td>31007.768407</td>\n",
       "      <td>5.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>France</th>\n",
       "      <td>42025.617373</td>\n",
       "      <td>6.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>New Zealand</th>\n",
       "      <td>42404.393738</td>\n",
       "      <td>7.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Australia</th>\n",
       "      <td>48697.837028</td>\n",
       "      <td>7.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Denmark</th>\n",
       "      <td>55938.212809</td>\n",
       "      <td>7.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>United States</th>\n",
       "      <td>60235.728492</td>\n",
       "      <td>6.9</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "               GDP per capita (USD)  Life satisfaction\n",
       "Country                                               \n",
       "Turkey                 28384.987785                5.5\n",
       "Hungary                31007.768407                5.6\n",
       "France                 42025.617373                6.5\n",
       "New Zealand            42404.393738                7.3\n",
       "Australia              48697.837028                7.3\n",
       "Denmark                55938.212809                7.6\n",
       "United States          60235.728492                6.9"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "highlighted_countries = country_stats.loc[list(position_text.keys())]\n",
    "highlighted_countries[[gdppc_col, lifesat_col]].sort_values(by=gdppc_col)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 360x216 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "country_stats.plot(kind='scatter', figsize=(5, 3), grid=True,\n",
    "                   x=gdppc_col, y=lifesat_col)\n",
    "\n",
    "X = np.linspace(min_gdp, max_gdp, 1000)\n",
    "\n",
    "w1, w2 = 4.2, 0\n",
    "plt.plot(X, w1 + w2 * 1e-5 * X, \"r\")\n",
    "plt.text(40_000, 4.9, fr\"$\\theta_0 = {w1}$\", color=\"r\")\n",
    "plt.text(40_000, 4.4, fr\"$\\theta_1 = {w2}$\", color=\"r\")\n",
    "\n",
    "w1, w2 = 10, -9\n",
    "plt.plot(X, w1 + w2 * 1e-5 * X, \"g\")\n",
    "plt.text(26_000, 8.5, fr\"$\\theta_0 = {w1}$\", color=\"g\")\n",
    "plt.text(26_000, 8.0, fr\"$\\theta_1 = {w2} \\times 10^{{-5}}$\", color=\"g\")\n",
    "\n",
    "w1, w2 = 3, 8\n",
    "plt.plot(X, w1 + w2 * 1e-5 * X, \"b\")\n",
    "plt.text(48_000, 8.5, fr\"$\\theta_0 = {w1}$\", color=\"b\")\n",
    "plt.text(48_000, 8.0, fr\"$\\theta_1 = {w2} \\times 10^{{-5}}$\", color=\"b\")\n",
    "\n",
    "plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
    "\n",
    "save_fig('tweaking_model_params_plot')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "θ0=3.75, θ1=6.78e-05\n"
     ]
    }
   ],
   "source": [
    "from sklearn import linear_model\n",
    "\n",
    "X_sample = country_stats[[gdppc_col]].values\n",
    "y_sample = country_stats[[lifesat_col]].values\n",
    "\n",
    "lin1 = linear_model.LinearRegression()\n",
    "lin1.fit(X_sample, y_sample)\n",
    "\n",
    "t0, t1 = lin1.intercept_[0], lin1.coef_[0][0]\n",
    "print(f\"θ0={t0:.2f}, θ1={t1:.2e}\")"
   ]
  },
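  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a single feature, the least-squares line fitted above can also be written in closed form. This is the standard textbook result, given here for reference rather than as a description of how Scikit-Learn computes it internally. The symbols are assumptions of this note: $x_i$ and $y_i$ are a country's GDP per capita and life satisfaction, and $\\bar{x}$, $\\bar{y}$ are their means over the training countries.\n",
    "\n",
    "$$\n",
    "\\theta_1 = \\frac{\\sum_i (x_i - \\bar{x})(y_i - \\bar{y})}{\\sum_i (x_i - \\bar{x})^2},\n",
    "\\qquad\n",
    "\\theta_0 = \\bar{y} - \\theta_1 \\bar{x}\n",
    "$$"
   ]
  },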
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 360x216 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "country_stats.plot(kind='scatter', figsize=(5, 3), grid=True,\n",
    "                   x=gdppc_col, y=lifesat_col)\n",
    "\n",
    "X = np.linspace(min_gdp, max_gdp, 1000)\n",
    "plt.plot(X, t0 + t1 * X, \"b\")\n",
    "\n",
    "plt.text(max_gdp - 20_000, min_life_sat + 1.9,\n",
    "         fr\"$\\theta_0 = {t0:.2f}$\", color=\"b\")\n",
    "plt.text(max_gdp - 20_000, min_life_sat + 1.3,\n",
    "         fr\"$\\theta_1 = {t1 * 1e5:.2f} \\times 10^{{-5}}$\", color=\"b\")\n",
    "\n",
    "plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
    "\n",
    "save_fig('best_fit_model_plot')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "37655.1803457421"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cyprus_gdp_per_capita = gdp_per_capita[gdppc_col].loc[\"Cyprus\"]\n",
    "cyprus_gdp_per_capita"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "6.301656332738056"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cyprus_predicted_life_satisfaction = lin1.predict([[cyprus_gdp_per_capita]])[0, 0]\n",
    "cyprus_predicted_life_satisfaction"
   ]
  },
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 22,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 360x216 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {
|
||
"needs_background": "light"
|
||
},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"country_stats.plot(kind='scatter', figsize=(5, 3), grid=True,\n",
|
||
" x=gdppc_col, y=lifesat_col)\n",
|
||
"\n",
|
||
"X = np.linspace(min_gdp, max_gdp, 1000)\n",
|
||
"plt.plot(X, t0 + t1 * X, \"b\")\n",
|
||
"\n",
|
||
"plt.text(min_gdp + 22_000, max_life_sat - 1.1,\n",
|
||
" fr\"$\\theta_0 = {t0:.2f}$\", color=\"b\")\n",
|
||
"plt.text(min_gdp + 22_000, max_life_sat - 0.6,\n",
|
||
" fr\"$\\theta_1 = {t1 * 1e5:.2f} \\times 10^{{-5}}$\", color=\"b\")\n",
|
||
"\n",
|
||
"plt.plot([cyprus_gdp_per_capita, cyprus_gdp_per_capita],\n",
|
||
" [min_life_sat, cyprus_predicted_life_satisfaction], \"r--\")\n",
|
||
"plt.text(cyprus_gdp_per_capita + 1000, 5.0,\n",
|
||
" fr\"Prediction = {cyprus_predicted_life_satisfaction:.2f}\", color=\"r\")\n",
|
||
"plt.plot(cyprus_gdp_per_capita, cyprus_predicted_life_satisfaction, \"ro\")\n",
|
||
"\n",
|
||
"plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
|
||
"\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 23,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>GDP per capita (USD)</th>\n",
|
||
" <th>Life satisfaction</th>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>Country</th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>South Africa</th>\n",
|
||
" <td>11466.189672</td>\n",
|
||
" <td>4.7</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>Colombia</th>\n",
|
||
" <td>13441.492952</td>\n",
|
||
" <td>6.3</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>Brazil</th>\n",
|
||
" <td>14063.982505</td>\n",
|
||
" <td>6.4</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>Mexico</th>\n",
|
||
" <td>17887.750736</td>\n",
|
||
" <td>6.5</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>Chile</th>\n",
|
||
" <td>23324.524751</td>\n",
|
||
" <td>6.5</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>Norway</th>\n",
|
||
" <td>63585.903514</td>\n",
|
||
" <td>7.6</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>Switzerland</th>\n",
|
||
" <td>68393.306004</td>\n",
|
||
" <td>7.5</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>Ireland</th>\n",
|
||
" <td>89688.956958</td>\n",
|
||
" <td>7.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>Luxembourg</th>\n",
|
||
" <td>110261.157353</td>\n",
|
||
" <td>6.9</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" GDP per capita (USD) Life satisfaction\n",
|
||
"Country \n",
|
||
"South Africa 11466.189672 4.7\n",
|
||
"Colombia 13441.492952 6.3\n",
|
||
"Brazil 14063.982505 6.4\n",
|
||
"Mexico 17887.750736 6.5\n",
|
||
"Chile 23324.524751 6.5\n",
|
||
"Norway 63585.903514 7.6\n",
|
||
"Switzerland 68393.306004 7.5\n",
|
||
"Ireland 89688.956958 7.0\n",
|
||
"Luxembourg 110261.157353 6.9"
|
||
]
|
||
},
|
||
"execution_count": 23,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"missing_data = full_country_stats[(full_country_stats[gdppc_col] < min_gdp) |\n",
|
||
" (full_country_stats[gdppc_col] > max_gdp)]\n",
|
||
"missing_data"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 24,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"position_text_missing_countries = {\n",
|
||
" \"South Africa\": (20_000, 4.2),\n",
|
||
" \"Colombia\": (6_000, 8.2),\n",
|
||
" \"Brazil\": (18_000, 7.8),\n",
|
||
" \"Mexico\": (24_000, 7.4),\n",
|
||
" \"Chile\": (30_000, 7.0),\n",
|
||
" \"Norway\": (51_000, 6.2),\n",
|
||
" \"Switzerland\": (62_000, 5.7),\n",
|
||
" \"Ireland\": (81_000, 5.2),\n",
|
||
" \"Luxembourg\": (92_000, 4.7),\n",
|
||
"}"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 25,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 576x216 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {
|
||
"needs_background": "light"
|
||
},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"full_country_stats.plot(kind='scatter', figsize=(8, 3),\n",
|
||
" x=gdppc_col, y=lifesat_col, grid=True)\n",
|
||
"\n",
|
||
"for country, pos_text in position_text_missing_countries.items():\n",
|
||
" pos_data_x, pos_data_y = missing_data.loc[country]\n",
|
||
" plt.annotate(country, xy=(pos_data_x, pos_data_y),\n",
|
||
" xytext=pos_text, fontsize=12,\n",
|
||
" arrowprops=dict(facecolor='black', width=0.5,\n",
|
||
" shrink=0.08, headwidth=5))\n",
|
||
" plt.plot(pos_data_x, pos_data_y, \"rs\")\n",
|
||
"\n",
|
||
"X = np.linspace(0, 115_000, 1000)\n",
|
||
"plt.plot(X, t0 + t1 * X, \"b:\")\n",
|
||
"\n",
|
||
"lin_reg_full = linear_model.LinearRegression()\n",
|
||
"Xfull = np.c_[full_country_stats[gdppc_col]]\n",
|
||
"yfull = np.c_[full_country_stats[lifesat_col]]\n",
|
||
"lin_reg_full.fit(Xfull, yfull)\n",
|
||
"\n",
|
||
"t0full, t1full = lin_reg_full.intercept_[0], lin_reg_full.coef_[0][0]\n",
|
||
"X = np.linspace(0, 115_000, 1000)\n",
|
||
"plt.plot(X, t0full + t1full * X, \"k\")\n",
|
||
"\n",
|
||
"plt.axis([0, 115_000, min_life_sat, max_life_sat])\n",
|
||
"\n",
|
||
"save_fig('representative_training_data_scatterplot')\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 26,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 576x216 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {
|
||
"needs_background": "light"
|
||
},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"from sklearn import preprocessing\n",
|
||
"from sklearn import pipeline\n",
|
||
"\n",
|
||
"full_country_stats.plot(kind='scatter', figsize=(8, 3),\n",
|
||
" x=gdppc_col, y=lifesat_col, grid=True)\n",
|
||
"\n",
|
||
"poly = preprocessing.PolynomialFeatures(degree=10, include_bias=False)\n",
|
||
"scaler = preprocessing.StandardScaler()\n",
|
||
"lin_reg2 = linear_model.LinearRegression()\n",
|
||
"\n",
|
||
"pipeline_reg = pipeline.Pipeline([\n",
|
||
" ('poly', poly),\n",
|
||
" ('scal', scaler),\n",
|
||
" ('lin', lin_reg2)])\n",
|
||
"pipeline_reg.fit(Xfull, yfull)\n",
|
||
"curve = pipeline_reg.predict(X[:, np.newaxis])\n",
|
||
"plt.plot(X, curve)\n",
|
||
"\n",
|
||
"plt.axis([0, 115_000, min_life_sat, max_life_sat])\n",
|
||
"\n",
|
||
"save_fig('overfitting_model_plot')\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 27,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"Country\n",
|
||
"New Zealand 7.3\n",
|
||
"Sweden 7.3\n",
|
||
"Norway 7.6\n",
|
||
"Switzerland 7.5\n",
|
||
"Name: Life satisfaction, dtype: float64"
|
||
]
|
||
},
|
||
"execution_count": 27,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"w_countries = [c for c in full_country_stats.index if \"W\" in c.upper()]\n",
|
||
"full_country_stats.loc[w_countries][lifesat_col]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 28,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>GDP per capita (USD)</th>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>Country</th>\n",
|
||
" <th></th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>Malawi</th>\n",
|
||
" <td>1486.778248</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>Rwanda</th>\n",
|
||
" <td>2098.710362</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>Zimbabwe</th>\n",
|
||
" <td>2744.690758</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>Africa Western and Central</th>\n",
|
||
" <td>4003.158913</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>Papua New Guinea</th>\n",
|
||
" <td>4101.218882</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>Lower middle income</th>\n",
|
||
" <td>6722.809932</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>Eswatini</th>\n",
|
||
" <td>8392.717564</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>Low & middle income</th>\n",
|
||
" <td>10293.855325</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>Arab World</th>\n",
|
||
" <td>13753.707307</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>Botswana</th>\n",
|
||
" <td>16040.008473</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>World</th>\n",
|
||
" <td>16194.040310</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>New Zealand</th>\n",
|
||
" <td>42404.393738</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>Sweden</th>\n",
|
||
" <td>50683.323510</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>Norway</th>\n",
|
||
" <td>63585.903514</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>Switzerland</th>\n",
|
||
" <td>68393.306004</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" GDP per capita (USD)\n",
|
||
"Country \n",
|
||
"Malawi 1486.778248\n",
|
||
"Rwanda 2098.710362\n",
|
||
"Zimbabwe 2744.690758\n",
|
||
"Africa Western and Central 4003.158913\n",
|
||
"Papua New Guinea 4101.218882\n",
|
||
"Lower middle income 6722.809932\n",
|
||
"Eswatini 8392.717564\n",
|
||
"Low & middle income 10293.855325\n",
|
||
"Arab World 13753.707307\n",
|
||
"Botswana 16040.008473\n",
|
||
"World 16194.040310\n",
|
||
"New Zealand 42404.393738\n",
|
||
"Sweden 50683.323510\n",
|
||
"Norway 63585.903514\n",
|
||
"Switzerland 68393.306004"
|
||
]
|
||
},
|
||
"execution_count": 28,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"all_w_countries = [c for c in gdp_per_capita.index if \"W\" in c.upper()]\n",
|
||
"gdp_per_capita.loc[all_w_countries].sort_values(by=gdppc_col)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 29,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 576x216 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {
|
||
"needs_background": "light"
|
||
},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"country_stats.plot(kind='scatter', x=gdppc_col, y=lifesat_col, figsize=(8, 3))\n",
|
||
"missing_data.plot(kind='scatter', x=gdppc_col, y=lifesat_col,\n",
|
||
" marker=\"s\", color=\"r\", grid=True, ax=plt.gca())\n",
|
||
"\n",
|
||
"X = np.linspace(0, 115_000, 1000)\n",
|
||
"plt.plot(X, t0 + t1*X, \"b:\", label=\"Linear model on partial data\")\n",
|
||
"plt.plot(X, t0full + t1full * X, \"k-\", label=\"Linear model on all data\")\n",
|
||
"\n",
|
||
"ridge = linear_model.Ridge(alpha=10**9.5)\n",
|
||
"X_sample = country_stats[[gdppc_col]]\n",
|
||
"y_sample = country_stats[[lifesat_col]]\n",
|
||
"ridge.fit(X_sample, y_sample)\n",
|
||
"t0ridge, t1ridge = ridge.intercept_[0], ridge.coef_[0][0]\n",
|
||
"plt.plot(X, t0ridge + t1ridge * X, \"b--\",\n",
|
||
" label=\"Regularized linear model on partial data\")\n",
|
||
"plt.legend(loc=\"lower right\")\n",
|
||
"\n",
|
||
"plt.axis([0, 115_000, min_life_sat, max_life_sat])\n",
|
||
"\n",
|
||
"save_fig('ridge_model_plot')\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Exercise Solutions"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"1. Machine Learning is about building systems that can learn from data. Learning means getting better at some task, given some performance measure.\n",
|
||
"2. Machine Learning is great for complex problems for which we have no algorithmic solution, to replace long lists of hand-tuned rules, to build systems that adapt to fluctuating environments, and finally to help humans learn (e.g., data mining).\n",
|
||
"3. A labeled training set is a training set that contains the desired solution (a.k.a. a label) for each instance.\n",
|
||
"4. The two most common supervised tasks are regression and classification.\n",
|
||
"5. Common unsupervised tasks include clustering, visualization, dimensionality reduction, and association rule learning.\n",
|
||
"6. Reinforcement Learning is likely to perform best if we want a robot to learn to walk in various unknown terrains, since this is typically the type of problem that Reinforcement Learning tackles. It might be possible to express the problem as a supervised or semi-supervised learning problem, but it would be less natural.\n",
|
||
"7. If you don't know how to define the groups, then you can use a clustering algorithm (unsupervised learning) to segment your customers into clusters of similar customers. However, if you know what groups you would like to have, then you can feed many examples of each group to a classification algorithm (supervised learning), and it will classify all your customers into these groups.\n",
|
||
"8. Spam detection is a typical supervised learning problem: the algorithm is fed many emails along with their labels (spam or not spam).\n",
|
||
"9. An online learning system can learn incrementally, as opposed to a batch learning system. This makes it capable of adapting rapidly to both changing data and autonomous systems, and of training on very large quantities of data.\n",
|
||
"10. Out-of-core algorithms can handle vast quantities of data that cannot fit in a computer's main memory. An out-of-core learning algorithm chops the data into mini-batches and uses online learning techniques to learn from these mini-batches.\n",
|
||
"11. An instance-based learning system learns the training data by heart; then, when given a new instance, it uses a similarity measure to find the most similar learned instances and uses them to make predictions.\n",
|
||
"12. A model has one or more model parameters that determine what it will predict given a new instance (e.g., the slope of a linear model). A learning algorithm tries to find optimal values for these parameters such that the model generalizes well to new instances. A hyperparameter is a parameter of the learning algorithm itself, not of the model (e.g., the amount of regularization to apply).\n",
|
||
"13. Model-based learning algorithms search for an optimal value for the model parameters such that the model will generalize well to new instances. We usually train such systems by minimizing a cost function that measures how bad the system is at making predictions on the training data, plus a penalty for model complexity if the model is regularized. To make predictions, we feed the new instance's features into the model's prediction function, using the parameter values found by the learning algorithm.\n",
|
||
"14. Some of the main challenges in Machine Learning are the lack of data, poor data quality, nonrepresentative data, uninformative features, excessively simple models that underfit the training data, and excessively complex models that overfit the data.\n",
|
||
"15. If a model performs great on the training data but generalizes poorly to new instances, the model is likely overfitting the training data (or we got extremely lucky on the training data). Possible solutions to overfitting are getting more data, simplifying the model (selecting a simpler algorithm, reducing the number of parameters or features used, or regularizing the model), or reducing the noise in the training data.\n",
|
||
"16. A test set is used to estimate the generalization error that a model will make on new instances, before the model is launched in production.\n",
|
||
"17. A validation set is used to compare models. It makes it possible to select the best model and tune the hyperparameters.\n",
|
||
"18. The train-dev set is used when there is a risk of mismatch between the training data and the data used in the validation and test datasets (which should always be as close as possible to the data used once the model is in production). The train-dev set is a part of the training set that's held out (the model is not trained on it). The model is trained on the rest of the training set, and evaluated on both the train-dev set and the validation set. If the model performs well on the training set but not on the train-dev set, then the model is likely overfitting the training set. If it performs well on both the training set and the train-dev set, but not on the validation set, then there is probably a significant data mismatch between the training data and the validation + test data, and you should try to improve the training data to make it look more like the validation + test data.\n",
|
||
"19. If you tune hyperparameters using the test set, you risk overfitting the test set, and the generalization error you measure will be optimistic (you may launch a model that performs worse than you expect)."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.10.6"
|
||
},
|
||
"metadata": {
|
||
"interpreter": {
|
||
"hash": "22b0ec00cd9e253c751e6d2619fc0bb2d18ed12980de3246690d5be49479dd65"
|
||
}
|
||
},
|
||
"nav_menu": {},
|
||
"toc": {
|
||
"navigate_menu": true,
|
||
"number_sections": true,
|
||
"sideBar": true,
|
||
"threshold": 6,
|
||
"toc_cell": false,
|
||
"toc_section_display": "block",
|
||
"toc_window_display": true
|
||
},
|
||
"toc_position": {
|
||
"height": "616px",
|
||
"left": "0px",
|
||
"right": "20px",
|
||
"top": "106px",
|
||
"width": "213px"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 4
|
||
}
|