{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Chapter 1 – The Machine Learning landscape**\n",
"\n",
"_This notebook contains the code example from chapter 1, as well as all the code used to generate `lifesat.csv` and some of this chapter's figures._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table align=\"left\">\n",
" <td>\n",
" <a href=\"https://colab.research.google.com/github/ageron/handson-ml2/blob/master/01_the_machine_learning_landscape.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/ageron/handson-ml2/blob/master/01_the_machine_learning_landscape.ipynb\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" /></a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Setup"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [],
"source": [
"# Python ≥3.8 is required\n",
"import sys\n",
"assert sys.version_info >= (3, 8)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Make this notebook's output stable across runs\n",
"np.random.seed(42)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
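"# Scikit-Learn ≥1.0 is required\n",
"from packaging import version\n",
"import sklearn\n",
"\n",
"# Compare parsed versions rather than raw strings, which compare lexicographically\n",
"assert version.parse(sklearn.__version__) >= version.parse(\"1.0\")"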
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# To plot pretty figures directly within Jupyter\n",
"%matplotlib inline\n",
"import matplotlib as mpl\n",
"\n",
"mpl.rc('font', size=12)\n",
"mpl.rc('axes', labelsize=14)\n",
"mpl.rc('xtick', labelsize=12)\n",
"mpl.rc('ytick', labelsize=12)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Download the data\n",
"from pathlib import Path\n",
"import urllib.request\n",
"\n",
"datapath = Path() / \"datasets\" / \"lifesat\"\n",
"datapath.mkdir(parents=True, exist_ok=True)\n",
"\n",
"root = \"https://raw.githubusercontent.com/ageron/handson-ml2/master/\"\n",
"filename = \"lifesat.csv\"\n",
"if not (datapath / filename).is_file():\n",
" print(\"Downloading\", filename)\n",
" url = root + \"datasets/lifesat/\" + filename\n",
" urllib.request.urlretrieve(url, datapath / filename)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Code example 1-1"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.linear_model import LinearRegression\n",
"\n",
"# Load and prepare the data\n",
"lifesat = pd.read_csv(Path() / \"datasets\" / \"lifesat\" / \"lifesat.csv\")\n",
"X = lifesat[[\"GDP per capita (USD)\"]].values\n",
"y = lifesat[[\"Life satisfaction\"]].values\n",
"\n",
"# Visualize the data\n",
"lifesat.plot(kind='scatter', grid=True,\n",
" x=\"GDP per capita (USD)\", y=\"Life satisfaction\")\n",
"plt.axis([23_500, 62_500, 4, 9])\n",
"plt.show()\n",
"\n",
"# Select a linear model\n",
"model = LinearRegression()\n",
"\n",
"# Train the model\n",
"model.fit(X, y)\n",
"\n",
"# Make a prediction for Cyprus\n",
"X_new = [[37_655.2]] # Cyprus' GDP per capita in 2020\n",
"print(model.predict(X_new)) # outputs [[6.30165767]]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To replace the linear regression model with k-nearest neighbors regression (here with k = 3), simply replace these two lines in the previous code:\n",
"\n",
"```python\n",
"from sklearn.linear_model import LinearRegression\n",
"\n",
"model = LinearRegression()\n",
"```\n",
"\n",
"with these two:\n",
"\n",
"```python\n",
"from sklearn.neighbors import KNeighborsRegressor\n",
"\n",
"model = KNeighborsRegressor(n_neighbors=3)\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Select a 3-Nearest Neighbors regression model\n",
"from sklearn.neighbors import KNeighborsRegressor\n",
"\n",
"model = KNeighborsRegressor(n_neighbors=3)\n",
"\n",
"# Train the model\n",
"model.fit(X, y)\n",
"\n",
"# Make a prediction for Cyprus\n",
"print(model.predict(X_new)) # outputs [[6.33333333]]"
]
},
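{
"cell_type": "markdown",
"metadata": {},
"source": [
"For intuition, k-nearest neighbors regression predicts the average target value of the k training instances closest to the new instance. The sketch below reimplements that idea with NumPy and checks it against `KNeighborsRegressor`. It uses a small made-up dataset (the GDP and life satisfaction values are hypothetical, not taken from `lifesat.csv`), and the helper `knn_predict()` is illustrative only:\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.neighbors import KNeighborsRegressor\n",
"\n",
"# Hypothetical toy data: GDP per capita (USD) -> life satisfaction\n",
"X_toy = np.array([[25_000], [30_000], [35_000], [40_000], [45_000], [50_000]])\n",
"y_toy = np.array([5.5, 5.9, 6.4, 6.9, 7.1, 7.3])\n",
"\n",
"def knn_predict(X_train, y_train, X_query, k=3):\n",
"    # Hypothetical helper: average the targets of the k nearest training instances\n",
"    preds = []\n",
"    for x in X_query:\n",
"        distances = np.linalg.norm(X_train - x, axis=1)  # Euclidean distances\n",
"        nearest = np.argsort(distances)[:k]  # indices of the k closest instances\n",
"        preds.append(y_train[nearest].mean())\n",
"    return np.array(preds)\n",
"\n",
"X_query = np.array([[37_655.2]])\n",
"manual = knn_predict(X_toy, y_toy, X_query, k=3)\n",
"sk_knn = KNeighborsRegressor(n_neighbors=3).fit(X_toy, y_toy)\n",
"print(manual, sk_knn.predict(X_query))  # both ≈ [6.8]\n",
"```\n",
"\n",
"Here the three nearest neighbors of 37,655.2 are the instances at 40,000, 35,000 and 45,000, so the prediction is (6.9 + 6.4 + 7.1) / 3 = 6.8. This is instance-based learning: rather than learning parameters, the model keeps the training data and does most of its work at prediction time."
]
},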
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Note: you can safely ignore the rest of this notebook; it just generates many of the figures in chapter 1."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create a function to save the figures:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Where to save the figures\n",
"IMAGES_PATH = Path() / \"images\" / \"fundamentals\"\n",
"IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
" if tight_layout:\n",
" plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load and prepare Life satisfaction data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To create `lifesat.csv`, I downloaded the Better Life Index (BLI) data from [OECD's website](http://stats.oecd.org/index.aspx?DataSetCode=BLI) (to get the life satisfaction score of each country), and the World Bank's GDP per capita data from [OurWorldInData.org](https://ourworldindata.org/grapher/gdp-per-capita-worldbank). The BLI data is in `datasets/lifesat/oecd_bli.csv` (data from 2020), and the GDP per capita data is in `datasets/lifesat/gdp_per_capita.csv` (data up to 2020).\n",
"\n",
"Feel free to download the latest versions instead, but some things may have changed (e.g., the column names, or which countries are missing data), so be prepared to tweak the code."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"for filename in (\"oecd_bli.csv\", \"gdp_per_capita.csv\"):\n",
" if not (datapath / filename).is_file():\n",
" print(\"Downloading\", filename)\n",
" url = root + \"datasets/lifesat/\" + filename\n",
" urllib.request.urlretrieve(url, datapath / filename)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"oecd_bli = pd.read_csv(datapath / \"oecd_bli.csv\")\n",
"gdp_per_capita = pd.read_csv(datapath / \"gdp_per_capita.csv\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Preprocess the GDP per capita data to keep only the year 2020:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"gdp_year = 2020\n",
"gdppc_col = \"GDP per capita (USD)\"\n",
"lifesat_col = \"Life satisfaction\"\n",
"\n",
"gdp_per_capita = gdp_per_capita[gdp_per_capita[\"Year\"] == gdp_year]\n",
"gdp_per_capita = gdp_per_capita.drop([\"Code\", \"Year\"], axis=1)\n",
"gdp_per_capita.columns = [\"Country\", gdppc_col]\n",
"gdp_per_capita.set_index(\"Country\", inplace=True)\n",
"\n",
"gdp_per_capita.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Preprocess the OECD BLI data: keep only the totals (rows where `INEQUALITY` is `TOT`), then pivot the table so that each indicator, including `Life satisfaction`, becomes a column:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"oecd_bli = oecd_bli[oecd_bli[\"INEQUALITY\"] == \"TOT\"]\n",
"oecd_bli = oecd_bli.pivot(index=\"Country\", columns=\"Indicator\", values=\"Value\")\n",
"\n",
"oecd_bli.head()"
]
},
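{
"cell_type": "markdown",
"metadata": {},
"source": [
"If `pivot()` is unfamiliar: the raw BLI data is in long format, with one row per country, indicator and inequality group. The sketch below uses a tiny made-up DataFrame with the same `Country`, `Indicator`, `INEQUALITY` and `Value` columns (all values are hypothetical) to show how filtering on `TOT` and pivoting yields one row per country and one column per indicator:\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"# Hypothetical long-format data mimicking the layout of oecd_bli.csv\n",
"long_df = pd.DataFrame({\n",
"    'Country': ['A', 'A', 'A', 'B', 'B', 'B'],\n",
"    'Indicator': ['Life satisfaction', 'Life satisfaction', 'Air pollution',\n",
"                  'Life satisfaction', 'Life satisfaction', 'Air pollution'],\n",
"    'INEQUALITY': ['TOT', 'MN', 'TOT', 'TOT', 'WMN', 'TOT'],\n",
"    'Value': [6.5, 6.4, 12.0, 7.1, 7.2, 8.0],\n",
"})\n",
"\n",
"# Keep only the totals, so each (Country, Indicator) pair appears exactly once\n",
"tot = long_df[long_df['INEQUALITY'] == 'TOT']\n",
"\n",
"# One row per country, one column per indicator\n",
"wide = tot.pivot(index='Country', columns='Indicator', values='Value')\n",
"print(wide)\n",
"```\n",
"\n",
"Without the filter, `pivot()` raises a `ValueError` because some (Country, Indicator) pairs would appear more than once; `pivot_table()` with an aggregation function is the usual alternative when duplicates are expected."
]
},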
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's merge the life satisfaction data and the GDP per capita data, keeping only the GDP per capita and Life satisfaction columns:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita,\n",
" left_index=True, right_index=True)\n",
"full_country_stats.sort_values(by=gdppc_col, inplace=True)\n",
"full_country_stats = full_country_stats[[gdppc_col, lifesat_col]]\n",
"\n",
"full_country_stats.head()"
]
},
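{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that `pd.merge()` with `left_index=True` and `right_index=True` performs an inner join by default, so any country missing from either table is silently dropped. The sketch below shows this on two tiny made-up tables (countries `A` to `D` and their values are hypothetical), and uses `how='outer'` with `indicator=True` to list the rows that the inner join drops:\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"# Hypothetical tables indexed by country name\n",
"bli = pd.DataFrame({'Life satisfaction': [6.5, 7.1, 5.8]},\n",
"                   index=['A', 'B', 'C'])\n",
"gdp = pd.DataFrame({'GDP per capita (USD)': [40_000, 55_000, 30_000]},\n",
"                   index=['B', 'C', 'D'])\n",
"\n",
"# Inner join on the index: only countries present in both tables remain\n",
"merged = pd.merge(left=bli, right=gdp, left_index=True, right_index=True)\n",
"print(merged)\n",
"\n",
"# Outer join with an indicator column reveals what the inner join dropped\n",
"check = pd.merge(left=bli, right=gdp, left_index=True, right_index=True,\n",
"                 how='outer', indicator=True)\n",
"print(check[check['_merge'] != 'both'])\n",
"```\n",
"\n",
"Inspecting the dropped rows can help spot countries whose names are spelled differently in the two data sources."
]
},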
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To illustrate the risk of overfitting, I use only part of the data in most figures (all countries with a GDP per capita between `min_gdp` and `max_gdp`). Later in the chapter, I reveal the missing countries and show that they don't follow the same linear trend at all."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"min_gdp = 23_500\n",
"max_gdp = 62_500\n",
"\n",
"country_stats = full_country_stats[(full_country_stats[gdppc_col] >= min_gdp) &\n",
" (full_country_stats[gdppc_col] <= max_gdp)]\n",
"country_stats.head()"
]
},
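{
"cell_type": "markdown",
"metadata": {},
"source": [
"The parentheses around each comparison in the filter above are required, because `&` binds more tightly than `>=` and `<=` in Python. An equivalent, arguably more readable filter uses `Series.between()`, which includes both bounds by default. A small check on made-up GDP values (the series below is hypothetical):\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"# Hypothetical GDP per capita values\n",
"gdp = pd.Series([20_000, 23_500, 40_000, 62_500, 70_000], index=list('ABCDE'))\n",
"min_gdp, max_gdp = 23_500, 62_500\n",
"\n",
"mask_ops = (gdp >= min_gdp) & (gdp <= max_gdp)\n",
"mask_between = gdp.between(min_gdp, max_gdp)  # both bounds included by default\n",
"print(gdp[mask_between])  # B, C and D\n",
"```"
]
},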
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"country_stats.to_csv(datapath / \"lifesat.csv\")\n",
"full_country_stats.to_csv(datapath / \"lifesat_full.csv\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n",
" x=gdppc_col, y=lifesat_col)\n",
"\n",
"min_life_sat = 4\n",
"max_life_sat = 9\n",
"\n",
"position_text = {\n",
" \"Hungary\": (28_000, 4.2),\n",
" \"France\": (40_000, 5),\n",
" \"New Zealand\": (28_000, 8.2),\n",
" \"Australia\": (50_000, 5.5),\n",
" \"United States\": (59_000, 5.5),\n",
" \"Denmark\": (46_000, 8.5)\n",
"}\n",
"\n",
"for country, pos_text in position_text.items():\n",
" pos_data_x = country_stats[gdppc_col].loc[country]\n",
" pos_data_y = country_stats[lifesat_col].loc[country]\n",
" country = \"U.S.\" if country == \"United States\" else country\n",
" plt.annotate(country, xy=(pos_data_x, pos_data_y),\n",
" xytext=pos_text,\n",
" arrowprops=dict(facecolor='black', width=0.5,\n",
" shrink=0.15, headwidth=5))\n",
" plt.plot(pos_data_x, pos_data_y, \"ro\")\n",
"\n",
"plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
"\n",
"save_fig('money_happy_scatterplot')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"highlighted_countries = country_stats.loc[list(position_text.keys())]\n",
"highlighted_countries[[gdppc_col, lifesat_col]].sort_values(by=gdppc_col)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n",
" x=gdppc_col, y=lifesat_col)\n",
"\n",
"X = np.linspace(min_gdp, max_gdp, 1000)\n",
"\n",
"w1, w2 = 4.2, 0\n",
"plt.plot(X, w1 + w2 * 1e-5 * X, \"r\")\n",
"plt.text(40_000, 4.9, fr\"$\\theta_0 = {w1}$\", color=\"r\")\n",
"plt.text(40_000, 4.4, fr\"$\\theta_1 = {w2}$\", color=\"r\")\n",
"\n",
"w1, w2 = 10, -9\n",
"plt.plot(X, w1 + w2 * 1e-5 * X, \"g\")\n",
"plt.text(26_000, 8.5, fr\"$\\theta_0 = {w1}$\", color=\"g\")\n",
"plt.text(26_000, 8.0, fr\"$\\theta_1 = {w2} \\times 10^{{-5}}$\", color=\"g\")\n",
"\n",
"w1, w2 = 3, 8\n",
"plt.plot(X, w1 + w2 * 1e-5 * X, \"b\")\n",
"plt.text(48_000, 8.5, fr\"$\\theta_0 = {w1}$\", color=\"b\")\n",
"plt.text(48_000, 8.0, fr\"$\\theta_1 = {w2} \\times 10^{{-5}}$\", color=\"b\")\n",
"\n",
"plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
"\n",
"save_fig('tweaking_model_params_plot')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import linear_model\n",
"\n",
"X_sample = country_stats[[gdppc_col]].values\n",
"y_sample = country_stats[[lifesat_col]].values\n",
"\n",
"lin1 = linear_model.LinearRegression()\n",
"lin1.fit(X_sample, y_sample)\n",
"\n",
"t0, t1 = lin1.intercept_[0], lin1.coef_[0][0]\n",
"print(f\"θ0={t0:.2f}, θ1={t1:.2e}\")"
]
},
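{
"cell_type": "markdown",
"metadata": {},
"source": [
"For a single feature, the ordinary least squares parameters that `LinearRegression` computes have a closed form: $\\theta_1 = \\mathrm{cov}(x, y) / \\mathrm{var}(x)$ and $\\theta_0 = \\bar{y} - \\theta_1 \\bar{x}$. The sketch below checks this on synthetic data generated from a line plus Gaussian noise (the line's parameters are arbitrary and not fitted to the OECD data), and shows that a prediction is simply $\\theta_0 + \\theta_1 x$:\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.linear_model import LinearRegression\n",
"\n",
"rng = np.random.default_rng(42)\n",
"x = rng.uniform(20_000, 60_000, size=50)  # hypothetical GDP per capita values\n",
"y = 3.75 + 6.8e-5 * x + rng.normal(0, 0.3, size=50)  # hypothetical noisy targets\n",
"\n",
"# Closed-form OLS solution for one feature (population covariance and variance)\n",
"theta1 = np.cov(x, y, bias=True)[0, 1] / np.var(x)\n",
"theta0 = y.mean() - theta1 * x.mean()\n",
"\n",
"lin = LinearRegression().fit(x.reshape(-1, 1), y)\n",
"print(theta0, lin.intercept_)\n",
"print(theta1, lin.coef_[0])\n",
"\n",
"# A prediction is just theta0 + theta1 * x\n",
"x_new = 37_655.2\n",
"print(theta0 + theta1 * x_new, lin.predict([[x_new]])[0])\n",
"```"
]
},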
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n",
" x=gdppc_col, y=lifesat_col)\n",
"\n",
"X = np.linspace(min_gdp, max_gdp, 1000)\n",
"plt.plot(X, t0 + t1 * X, \"b\")\n",
"\n",
"plt.text(max_gdp - 20_000, min_life_sat + 1.5,\n",
" fr\"$\\theta_0 = {t0:.2f}$\", color=\"b\")\n",
"plt.text(max_gdp - 20_000, min_life_sat + 1,\n",
" fr\"$\\theta_1 = {t1 * 1e5:.2f} \\times 10^{{-5}}$\", color=\"b\")\n",
"\n",
"plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
"\n",
"save_fig('best_fit_model_plot')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"cyprus_gdp_per_capita = gdp_per_capita[gdppc_col].loc[\"Cyprus\"]\n",
"cyprus_gdp_per_capita"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"cyprus_predicted_life_satisfaction = lin1.predict([[cyprus_gdp_per_capita]])[0, 0]\n",
"cyprus_predicted_life_satisfaction"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n",
" x=gdppc_col, y=lifesat_col)\n",
"\n",
"X = np.linspace(min_gdp, max_gdp, 1000)\n",
"plt.plot(X, t0 + t1 * X, \"b\")\n",
"\n",
"plt.text(min_gdp + 15_000, max_life_sat - 1.5,\n",
" fr\"$\\theta_0 = {t0:.2f}$\", color=\"b\")\n",
"plt.text(min_gdp + 15_000, max_life_sat - 1,\n",
" fr\"$\\theta_1 = {t1 * 1e5:.2f} \\times 10^{{-5}}$\", color=\"b\")\n",
"\n",
"plt.plot([cyprus_gdp_per_capita, cyprus_gdp_per_capita],\n",
" [min_life_sat, cyprus_predicted_life_satisfaction], \"r--\")\n",
"plt.text(cyprus_gdp_per_capita + 1000, 5.0,\n",
" fr\"Prediction = {cyprus_predicted_life_satisfaction:.2f}\", color=\"r\")\n",
"plt.plot(cyprus_gdp_per_capita, cyprus_predicted_life_satisfaction, \"ro\")\n",
"\n",
"plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
"\n",
"save_fig('cyprus_prediction_plot')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"missing_data = full_country_stats[(full_country_stats[gdppc_col] < min_gdp) |\n",
" (full_country_stats[gdppc_col] > max_gdp)]\n",
"missing_data"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"position_text_missing_countries = {\n",
" \"South Africa\": (20_000, 4.2),\n",
" \"Colombia\": (6_000, 8.2),\n",
" \"Brazil\": (18_000, 7.8),\n",
" \"Mexico\": (24_000, 7.4),\n",
" \"Chile\": (30_000, 7.0),\n",
" \"Norway\": (51_000, 6.2),\n",
" \"Switzerland\": (62_000, 5.7),\n",
" \"Ireland\": (81_000, 5.2),\n",
" \"Luxembourg\": (92_000, 4.7),\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"full_country_stats.plot(kind='scatter', figsize=(8,3),\n",
" x=gdppc_col, y=lifesat_col, grid=True)\n",
"\n",
"for country, pos_text in position_text_missing_countries.items():\n",
" pos_data_x, pos_data_y = missing_data.loc[country]\n",
" plt.annotate(country, xy=(pos_data_x, pos_data_y),\n",
" xytext=pos_text,\n",
" arrowprops=dict(facecolor='black', width=0.5,\n",
" shrink=0.1, headwidth=5))\n",
" plt.plot(pos_data_x, pos_data_y, \"rs\")\n",
"\n",
"X = np.linspace(0, 115_000, 1000)\n",
"plt.plot(X, t0 + t1 * X, \"b:\")\n",
"\n",
"lin_reg_full = linear_model.LinearRegression()\n",
"Xfull = full_country_stats[[gdppc_col]].values\n",
"yfull = full_country_stats[[lifesat_col]].values\n",
"lin_reg_full.fit(Xfull, yfull)\n",
"\n",
"t0full, t1full = lin_reg_full.intercept_[0], lin_reg_full.coef_[0][0]\n",
"X = np.linspace(0, 115_000, 1000)\n",
"plt.plot(X, t0full + t1full * X, \"k\")\n",
"\n",
"plt.axis([0, 115_000, min_life_sat, max_life_sat])\n",
"\n",
"save_fig('representative_training_data_scatterplot')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import preprocessing\n",
"from sklearn import pipeline\n",
"\n",
"full_country_stats.plot(kind='scatter', figsize=(8,3),\n",
" x=gdppc_col, y=lifesat_col, grid=True)\n",
"\n",
"poly = preprocessing.PolynomialFeatures(degree=10, include_bias=False)\n",
"scaler = preprocessing.StandardScaler()\n",
"lin_reg2 = linear_model.LinearRegression()\n",
"\n",
"pipeline_reg = pipeline.Pipeline([\n",
" ('poly', poly),\n",
" ('scal', scaler),\n",
" ('lin', lin_reg2)])\n",
"pipeline_reg.fit(Xfull, yfull)\n",
"X = np.linspace(0, 115_000, 1000)\n",
"curve = pipeline_reg.predict(X[:, np.newaxis])\n",
"plt.plot(X, curve)\n",
"\n",
"plt.axis([0, 115_000, min_life_sat, max_life_sat])\n",
"\n",
"save_fig('overfitting_model_plot')\n",
"plt.show()"
]
},
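{
"cell_type": "markdown",
"metadata": {},
"source": [
"The degree-10 polynomial passes much closer to the training points than a straight line, yet its wild oscillations are unlikely to reflect reality: that is overfitting. Because each polynomial model contains the lower-degree ones as special cases, its training error can never increase as the degree grows, so the training error alone cannot detect this. The sketch below compares training and held-out errors on synthetic data (a linear trend plus noise, made up for illustration) with the same kind of pipeline as above:\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.metrics import mean_squared_error\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import PolynomialFeatures, StandardScaler\n",
"\n",
"# Hypothetical data: a linear trend plus Gaussian noise\n",
"rng = np.random.default_rng(0)\n",
"x_train = rng.uniform(0, 3, size=(20, 1))\n",
"y_train = 1 + 0.5 * x_train[:, 0] + rng.normal(0, 0.2, size=20)\n",
"x_test = rng.uniform(0, 3, size=(200, 1))\n",
"y_test = 1 + 0.5 * x_test[:, 0] + rng.normal(0, 0.2, size=200)\n",
"\n",
"errors = {}\n",
"for degree in (1, 3, 10):\n",
"    poly_model = make_pipeline(\n",
"        PolynomialFeatures(degree=degree, include_bias=False),\n",
"        StandardScaler(),\n",
"        LinearRegression())\n",
"    poly_model.fit(x_train, y_train)\n",
"    train_mse = mean_squared_error(y_train, poly_model.predict(x_train))\n",
"    test_mse = mean_squared_error(y_test, poly_model.predict(x_test))\n",
"    errors[degree] = (train_mse, test_mse)\n",
"    print(f'degree={degree:2d}  train MSE={train_mse:.4f}  test MSE={test_mse:.4f}')\n",
"```\n",
"\n",
"The training MSE shrinks as the degree increases, while the held-out MSE of the high-degree model is typically the worst; the exact numbers depend on the random seed."
]
},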
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"w_countries = [c for c in full_country_stats.index if \"W\" in c.upper()]\n",
"full_country_stats.loc[w_countries][lifesat_col]"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"all_w_countries = [c for c in gdp_per_capita.index if \"W\" in c.upper()]\n",
"gdp_per_capita.loc[all_w_countries].sort_values(by=gdppc_col)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"country_stats.plot(kind='scatter', x=gdppc_col, y=lifesat_col, figsize=(8,3))\n",
"missing_data.plot(kind='scatter', x=gdppc_col, y=lifesat_col,\n",
" marker=\"s\", color=\"r\", grid=True, ax=plt.gca())\n",
"\n",
"X = np.linspace(0, 115_000, 1000)\n",
"plt.plot(X, t0 + t1*X, \"b:\", label=\"Linear model on partial data\")\n",
"plt.plot(X, t0full + t1full * X, \"k-\", label=\"Linear model on all data\")\n",
"\n",
"ridge = linear_model.Ridge(alpha=10**9.5)\n",
"X_sample = country_stats[[gdppc_col]]\n",
"y_sample = country_stats[[lifesat_col]]\n",
"ridge.fit(X_sample, y_sample)\n",
"t0ridge, t1ridge = ridge.intercept_[0], ridge.coef_[0][0]\n",
"plt.plot(X, t0ridge + t1ridge * X, \"b--\",\n",
" label=\"Regularized linear model on partial data\")\n",
"plt.legend(loc=\"lower right\")\n",
"\n",
"plt.axis([0, 115_000, min_life_sat, max_life_sat])\n",
"\n",
"save_fig('ridge_model_plot')\n",
"plt.show()"
]
},
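{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ridge regression adds a penalty $\\alpha \\theta_1^2$ on the slope (the intercept is not penalized) to the sum of squared errors. For a single feature, centering the data gives the closed form $\\theta_1 = \\frac{\\sum_i (x_i - \\bar{x})(y_i - \\bar{y})}{\\sum_i (x_i - \\bar{x})^2 + \\alpha}$ and $\\theta_0 = \\bar{y} - \\theta_1 \\bar{x}$, so a larger $\\alpha$ pulls the slope toward zero, which is why the regularized line above is flatter. The sketch below checks this against `Ridge` on made-up data; the helper `ridge_1d()` is hypothetical:\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.linear_model import Ridge\n",
"\n",
"# Hypothetical data: a noisy linear relationship\n",
"rng = np.random.default_rng(42)\n",
"x = rng.uniform(0, 10, size=30)\n",
"y = 2 + 0.8 * x + rng.normal(0, 1, size=30)\n",
"\n",
"def ridge_1d(x, y, alpha):\n",
"    # Hypothetical helper: closed-form ridge for one feature, intercept unpenalized\n",
"    xc, yc = x - x.mean(), y - y.mean()\n",
"    theta1 = (xc @ yc) / (xc @ xc + alpha)\n",
"    theta0 = y.mean() - theta1 * x.mean()\n",
"    return theta0, theta1\n",
"\n",
"results = []\n",
"for alpha in (1, 100, 10_000):\n",
"    t0_r, t1_r = ridge_1d(x, y, alpha)\n",
"    sk_ridge = Ridge(alpha=alpha).fit(x.reshape(-1, 1), y)\n",
"    results.append((alpha, t0_r, t1_r, sk_ridge.intercept_, sk_ridge.coef_[0]))\n",
"    print(f'alpha={alpha}: slope={t1_r:.4f} (Ridge: {sk_ridge.coef_[0]:.4f})')\n",
"```\n",
"\n",
"The formula also suggests why the figure above needs an $\\alpha$ as large as `10**9.5`: to have a visible effect, $\\alpha$ must be roughly comparable to $\\sum_i (x_i - \\bar{x})^2$, which is huge when GDP per capita is measured in US dollars."
]
},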
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "homl3",
"language": "python",
"name": "homl3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
},
"metadata": {
"interpreter": {
"hash": "22b0ec00cd9e253c751e6d2619fc0bb2d18ed12980de3246690d5be49479dd65"
}
},
"nav_menu": {},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": true
},
"toc_position": {
"height": "616px",
"left": "0px",
"right": "20px",
"top": "106px",
"width": "213px"
}
},
"nbformat": 4,
"nbformat_minor": 4
}