{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Chapter 1 – The Machine Learning landscape**\n",
    "\n",
    "_This notebook contains the code examples in chapter 1, as well as all the code used to generate `lifesat.csv` from the original data sources, and some of this chapter's figures._\n",
    "\n",
    "You're welcome to go through it if you want, but it's just a teaser: the real action starts in the next chapter."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<table align=\"left\">\n",
    "  <td>\n",
    "    <a href=\"https://colab.research.google.com/github/ageron/handson-ml3/blob/main/01_the_machine_learning_landscape.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/ageron/handson-ml3/blob/main/01_the_machine_learning_landscape.ipynb\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" /></a>\n",
    "  </td>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Python 3.8 or later is required:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "assert sys.version_info >= (3, 8)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Make this notebook's output stable across runs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "np.random.seed(42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Scikit-Learn ≥1.0 is required:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from packaging import version\n",
    "import sklearn\n",
    "\n",
    "# Compare parsed versions: plain string comparison would rank \"1.10\" below \"1.9\"\n",
    "assert version.parse(sklearn.__version__) >= version.parse(\"1.0\")"
   ]
  },
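  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A side note on version checks: comparing version strings directly is lexicographic, which can misorder releases. Here is a minimal sketch of the pitfall. `version_tuple` is a hypothetical helper written for this illustration, not part of Scikit-Learn, and it only handles plain dotted numeric versions (for anything more complex, `packaging.version.parse` is likely the more robust choice):\n",
    "\n",
    "```python\n",
    "def version_tuple(v):\n",
    "    # Keep only the leading numeric components, e.g. '1.10.2' becomes (1, 10, 2)\n",
    "    parts = []\n",
    "    for piece in v.split('.'):\n",
    "        if not piece.isdigit():\n",
    "            break\n",
    "        parts.append(int(piece))\n",
    "    return tuple(parts)\n",
    "\n",
    "# Plain string comparison is lexicographic, so it gets this case wrong\n",
    "print('1.10' >= '1.9')                                # False\n",
    "print(version_tuple('1.10') >= version_tuple('1.9'))  # True\n",
    "```\n",
    "\n",
    "Comparing tuples of integers orders releases numerically, so a hypothetical `1.10` correctly counts as newer than `1.9`."
   ]
  },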
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To plot pretty figures directly within Jupyter:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib as mpl\n",
    "\n",
    "mpl.rc('font', size=12)\n",
    "mpl.rc('axes', labelsize=14)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Download `lifesat.csv` from GitHub, unless it's already available locally:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import urllib.request\n",
    "\n",
    "datapath = Path() / \"datasets\" / \"lifesat\"\n",
    "datapath.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "root = \"https://raw.githubusercontent.com/ageron/handson-ml3/main/\"\n",
    "filename = \"lifesat.csv\"\n",
    "if not (datapath / filename).is_file():\n",
    "    print(\"Downloading\", filename)\n",
    "    url = root + \"datasets/lifesat/\" + filename\n",
    "    urllib.request.urlretrieve(url, datapath / filename)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Code example 1-1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.linear_model import LinearRegression\n",
    "\n",
    "# Load and prepare the data\n",
    "lifesat = pd.read_csv(Path() / \"datasets\" / \"lifesat\" / \"lifesat.csv\")\n",
    "X = lifesat[[\"GDP per capita (USD)\"]].values\n",
    "y = lifesat[[\"Life satisfaction\"]].values\n",
    "\n",
    "# Visualize the data\n",
    "lifesat.plot(kind='scatter', grid=True,\n",
    "             x=\"GDP per capita (USD)\", y=\"Life satisfaction\")\n",
    "plt.axis([23_500, 62_500, 4, 9])\n",
    "plt.show()\n",
    "\n",
    "# Select a linear model\n",
    "model = LinearRegression()\n",
    "\n",
    "# Train the model\n",
    "model.fit(X, y)\n",
    "\n",
    "# Make a prediction for Cyprus\n",
    "X_new = [[37_655.2]]  # Cyprus' GDP per capita in 2020\n",
    "print(model.predict(X_new))  # outputs [[6.30165767]]"
   ]
  },
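  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a single feature, the line that `LinearRegression` fits can also be worked out by hand. The following is a minimal pure-Python sketch of ordinary least squares, not Scikit-Learn's actual implementation. The `fit_line` helper and the toy numbers are made up for illustration:\n",
    "\n",
    "```python\n",
    "def fit_line(xs, ys):\n",
    "    # Return (theta0, theta1) minimizing the squared error of theta0 + theta1 * x\n",
    "    n = len(xs)\n",
    "    x_mean = sum(xs) / n\n",
    "    y_mean = sum(ys) / n\n",
    "    # Slope = covariance(x, y) / variance(x); the 1/n factors cancel out\n",
    "    s_xy = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys))\n",
    "    s_xx = sum((x - x_mean) ** 2 for x in xs)\n",
    "    theta1 = s_xy / s_xx\n",
    "    # The fitted line always passes through the point (x_mean, y_mean)\n",
    "    theta0 = y_mean - theta1 * x_mean\n",
    "    return theta0, theta1\n",
    "\n",
    "# Toy data lying exactly on y = 1 + 2x (hypothetical values, not the lifesat data)\n",
    "theta0, theta1 = fit_line([0.0, 1.0, 2.0, 3.0], [1.0, 3.0, 5.0, 7.0])\n",
    "print(theta0, theta1)          # 1.0 2.0\n",
    "print(theta0 + theta1 * 4.0)   # 9.0, the prediction for a new x = 4\n",
    "```\n",
    "\n",
    "Predicting is then just evaluating `theta0 + theta1 * x` for the new input, which is conceptually what `model.predict(X_new)` does for a one-feature linear model."
   ]
  },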
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Replacing the Linear Regression model with k-Nearest Neighbors regression (in this example, k = 3) in the previous code is as simple as replacing these two lines:\n",
    "\n",
    "```python\n",
    "from sklearn.linear_model import LinearRegression\n",
    "\n",
    "model = LinearRegression()\n",
    "```\n",
    "\n",
    "with these two:\n",
    "\n",
    "```python\n",
    "from sklearn.neighbors import KNeighborsRegressor\n",
    "\n",
    "model = KNeighborsRegressor(n_neighbors=3)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select a 3-Nearest Neighbors regression model\n",
    "from sklearn.neighbors import KNeighborsRegressor\n",
    "\n",
    "model = KNeighborsRegressor(n_neighbors=3)\n",
    "\n",
    "# Train the model\n",
    "model.fit(X, y)\n",
    "\n",
    "# Make a prediction for Cyprus\n",
    "print(model.predict(X_new))  # outputs [[6.33333333]]"
   ]
  },
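  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Roughly speaking, a k-Nearest Neighbors regressor predicts the average target of the k training points closest to the new input. A minimal pure-Python sketch for one feature follows. `knn_predict` and the toy data are hypothetical, and Scikit-Learn's `KNeighborsRegressor` offers more (for example distance weighting and faster neighbor searches):\n",
    "\n",
    "```python\n",
    "def knn_predict(xs, ys, x_new, k=3):\n",
    "    # Sort training points by distance to x_new and keep the k closest\n",
    "    nearest = sorted(zip(xs, ys), key=lambda pair: abs(pair[0] - x_new))[:k]\n",
    "    # The prediction is the plain average of their targets\n",
    "    return sum(y for _, y in nearest) / k\n",
    "\n",
    "# Toy data (hypothetical values, not the lifesat data)\n",
    "xs = [1.0, 2.0, 3.0, 10.0]\n",
    "ys = [5.0, 6.0, 7.0, 9.0]\n",
    "print(knn_predict(xs, ys, x_new=2.2))  # averages 6.0, 7.0 and 5.0, giving 6.0\n",
    "```\n",
    "\n",
    "Unlike the linear model, nothing is learned ahead of time here: the training data itself is the model, which is why this family of methods is called instance-based learning."
   ]
  },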
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Note: you can safely ignore the rest of this notebook; it just generates many of the figures in chapter 1."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a function to save the figures:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Where to save the figures\n",
    "IMAGES_PATH = Path() / \"images\" / \"fundamentals\"\n",
    "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
    "    path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
    "    if tight_layout:\n",
    "        plt.tight_layout()\n",
    "    plt.savefig(path, format=fig_extension, dpi=resolution)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load and prepare Life satisfaction data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To create `lifesat.csv`, I downloaded the Better Life Index (BLI) data from [OECD's website](http://stats.oecd.org/index.aspx?DataSetCode=BLI) (to get the Life Satisfaction for each country), and World Bank GDP per capita data from [OurWorldInData.org](https://ourworldindata.org/grapher/gdp-per-capita-worldbank). The BLI data is in `datasets/lifesat/oecd_bli.csv` (data from 2020), and the GDP per capita data is in `datasets/lifesat/gdp_per_capita.csv` (data up to 2020).\n",
    "\n",
    "If you want to grab the latest versions, please feel free to do so. However, there may be some changes (e.g., in the column names, or different countries missing data), so be prepared to have to tweak the code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "for filename in (\"oecd_bli.csv\", \"gdp_per_capita.csv\"):\n",
    "    if not (datapath / filename).is_file():\n",
    "        print(\"Downloading\", filename)\n",
    "        url = root + \"datasets/lifesat/\" + filename\n",
    "        urllib.request.urlretrieve(url, datapath / filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "oecd_bli = pd.read_csv(datapath / \"oecd_bli.csv\")\n",
    "gdp_per_capita = pd.read_csv(datapath / \"gdp_per_capita.csv\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Preprocess the GDP per capita data to keep only the year 2020:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "gdp_year = 2020\n",
    "gdppc_col = \"GDP per capita (USD)\"\n",
    "lifesat_col = \"Life satisfaction\"\n",
    "\n",
    "gdp_per_capita = gdp_per_capita[gdp_per_capita[\"Year\"] == gdp_year]\n",
    "gdp_per_capita = gdp_per_capita.drop([\"Code\", \"Year\"], axis=1)\n",
    "gdp_per_capita.columns = [\"Country\", gdppc_col]\n",
    "gdp_per_capita.set_index(\"Country\", inplace=True)\n",
    "\n",
    "gdp_per_capita.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Preprocess the OECD BLI data to keep only the `Life satisfaction` column:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "oecd_bli = oecd_bli[oecd_bli[\"INEQUALITY\"] == \"TOT\"]\n",
    "oecd_bli = oecd_bli.pivot(index=\"Country\", columns=\"Indicator\", values=\"Value\")\n",
    "\n",
    "oecd_bli.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's merge the life satisfaction data and the GDP per capita data, keeping only the GDP per capita and Life satisfaction columns:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita,\n",
    "                              left_index=True, right_index=True)\n",
    "full_country_stats.sort_values(by=gdppc_col, inplace=True)\n",
    "full_country_stats = full_country_stats[[gdppc_col, lifesat_col]]\n",
    "\n",
    "full_country_stats.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate the risk of overfitting, I use only part of the data in most figures (all countries with a GDP per capita between `min_gdp` and `max_gdp`). Later in the chapter I reveal the missing countries, and show that they don't follow the same linear trend at all."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "min_gdp = 23_500\n",
    "max_gdp = 62_500\n",
    "\n",
    "country_stats = full_country_stats[(full_country_stats[gdppc_col] >= min_gdp) &\n",
    "                                   (full_country_stats[gdppc_col] <= max_gdp)]\n",
    "country_stats.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "country_stats.to_csv(datapath / \"lifesat.csv\")\n",
    "full_country_stats.to_csv(datapath / \"lifesat_full.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n",
    "                   x=gdppc_col, y=lifesat_col)\n",
    "\n",
    "min_life_sat = 4\n",
    "max_life_sat = 9\n",
    "\n",
    "position_text = {\n",
    "    \"Hungary\": (28_000, 4.2),\n",
    "    \"France\": (40_000, 5),\n",
    "    \"New Zealand\": (28_000, 8.2),\n",
    "    \"Australia\": (50_000, 5.5),\n",
    "    \"United States\": (59_000, 5.5),\n",
    "    \"Denmark\": (46_000, 8.5)\n",
    "}\n",
    "\n",
    "for country, pos_text in position_text.items():\n",
    "    pos_data_x = country_stats[gdppc_col].loc[country]\n",
    "    pos_data_y = country_stats[lifesat_col].loc[country]\n",
    "    country = \"U.S.\" if country == \"United States\" else country\n",
    "    plt.annotate(country, xy=(pos_data_x, pos_data_y),\n",
    "                 xytext=pos_text,\n",
    "                 arrowprops=dict(facecolor='black', width=0.5,\n",
    "                                 shrink=0.15, headwidth=5))\n",
    "    plt.plot(pos_data_x, pos_data_y, \"ro\")\n",
    "\n",
    "plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
    "\n",
    "save_fig('money_happy_scatterplot')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "highlighted_countries = country_stats.loc[list(position_text.keys())]\n",
    "highlighted_countries[[gdppc_col, lifesat_col]].sort_values(by=gdppc_col)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n",
    "                   x=gdppc_col, y=lifesat_col)\n",
    "\n",
    "X = np.linspace(min_gdp, max_gdp, 1000)\n",
    "\n",
    "w1, w2 = 4.2, 0\n",
    "plt.plot(X, w1 + w2 * 1e-5 * X, \"r\")\n",
    "plt.text(40_000, 4.9, fr\"$\\theta_0 = {w1}$\", color=\"r\")\n",
    "plt.text(40_000, 4.4, fr\"$\\theta_1 = {w2}$\", color=\"r\")\n",
    "\n",
    "w1, w2 = 10, -9\n",
    "plt.plot(X, w1 + w2 * 1e-5 * X, \"g\")\n",
    "plt.text(26_000, 8.5, fr\"$\\theta_0 = {w1}$\", color=\"g\")\n",
    "plt.text(26_000, 8.0, fr\"$\\theta_1 = {w2} \\times 10^{{-5}}$\", color=\"g\")\n",
    "\n",
    "w1, w2 = 3, 8\n",
    "plt.plot(X, w1 + w2 * 1e-5 * X, \"b\")\n",
    "plt.text(48_000, 8.5, fr\"$\\theta_0 = {w1}$\", color=\"b\")\n",
    "plt.text(48_000, 8.0, fr\"$\\theta_1 = {w2} \\times 10^{{-5}}$\", color=\"b\")\n",
    "\n",
    "plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
    "\n",
    "save_fig('tweaking_model_params_plot')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn import linear_model\n",
    "\n",
    "X_sample = country_stats[[gdppc_col]].values\n",
    "y_sample = country_stats[[lifesat_col]].values\n",
    "\n",
    "lin1 = linear_model.LinearRegression()\n",
    "lin1.fit(X_sample, y_sample)\n",
    "\n",
    "t0, t1 = lin1.intercept_[0], lin1.coef_[0][0]\n",
    "print(f\"θ0={t0:.2f}, θ1={t1:.2e}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n",
    "                   x=gdppc_col, y=lifesat_col)\n",
    "\n",
    "X = np.linspace(min_gdp, max_gdp, 1000)\n",
    "plt.plot(X, t0 + t1 * X, \"b\")\n",
    "\n",
    "plt.text(max_gdp - 20_000, min_life_sat + 1.5,\n",
    "         fr\"$\\theta_0 = {t0:.2f}$\", color=\"b\")\n",
    "plt.text(max_gdp - 20_000, min_life_sat + 1,\n",
    "         fr\"$\\theta_1 = {t1 * 1e5:.2f} \\times 10^{{-5}}$\", color=\"b\")\n",
    "\n",
    "plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
    "\n",
    "save_fig('best_fit_model_plot')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "cyprus_gdp_per_capita = gdp_per_capita[gdppc_col].loc[\"Cyprus\"]\n",
    "cyprus_gdp_per_capita"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "cyprus_predicted_life_satisfaction = lin1.predict([[cyprus_gdp_per_capita]])[0, 0]\n",
    "cyprus_predicted_life_satisfaction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n",
    "                   x=gdppc_col, y=lifesat_col)\n",
    "\n",
    "X = np.linspace(min_gdp, max_gdp, 1000)\n",
    "plt.plot(X, t0 + t1 * X, \"b\")\n",
    "\n",
    "plt.text(min_gdp + 15_000, max_life_sat - 1.5,\n",
    "         fr\"$\\theta_0 = {t0:.2f}$\", color=\"b\")\n",
    "plt.text(min_gdp + 15_000, max_life_sat - 1,\n",
    "         fr\"$\\theta_1 = {t1 * 1e5:.2f} \\times 10^{{-5}}$\", color=\"b\")\n",
    "\n",
    "plt.plot([cyprus_gdp_per_capita, cyprus_gdp_per_capita],\n",
    "         [min_life_sat, cyprus_predicted_life_satisfaction], \"r--\")\n",
    "plt.text(cyprus_gdp_per_capita + 1000, 5.0,\n",
    "         fr\"Prediction = {cyprus_predicted_life_satisfaction:.2f}\", color=\"r\")\n",
    "plt.plot(cyprus_gdp_per_capita, cyprus_predicted_life_satisfaction, \"ro\")\n",
    "\n",
    "plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "missing_data = full_country_stats[(full_country_stats[gdppc_col] < min_gdp) |\n",
    "                                  (full_country_stats[gdppc_col] > max_gdp)]\n",
    "missing_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "position_text_missing_countries = {\n",
    "    \"South Africa\": (20_000, 4.2),\n",
    "    \"Colombia\": (6_000, 8.2),\n",
    "    \"Brazil\": (18_000, 7.8),\n",
    "    \"Mexico\": (24_000, 7.4),\n",
    "    \"Chile\": (30_000, 7.0),\n",
    "    \"Norway\": (51_000, 6.2),\n",
    "    \"Switzerland\": (62_000, 5.7),\n",
    "    \"Ireland\": (81_000, 5.2),\n",
    "    \"Luxembourg\": (92_000, 4.7),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "full_country_stats.plot(kind='scatter', figsize=(8,3),\n",
    "                        x=gdppc_col, y=lifesat_col, grid=True)\n",
    "\n",
    "for country, pos_text in position_text_missing_countries.items():\n",
    "    pos_data_x, pos_data_y = missing_data.loc[country]\n",
    "    plt.annotate(country, xy=(pos_data_x, pos_data_y),\n",
    "                 xytext=pos_text,\n",
    "                 arrowprops=dict(facecolor='black', width=0.5,\n",
    "                                 shrink=0.1, headwidth=5))\n",
    "    plt.plot(pos_data_x, pos_data_y, \"rs\")\n",
    "\n",
    "X = np.linspace(0, 115_000, 1000)\n",
    "plt.plot(X, t0 + t1 * X, \"b:\")\n",
    "\n",
    "lin_reg_full = linear_model.LinearRegression()\n",
    "Xfull = np.c_[full_country_stats[gdppc_col]]\n",
    "yfull = np.c_[full_country_stats[lifesat_col]]\n",
    "lin_reg_full.fit(Xfull, yfull)\n",
    "\n",
    "t0full, t1full = lin_reg_full.intercept_[0], lin_reg_full.coef_[0][0]\n",
    "X = np.linspace(0, 115_000, 1000)\n",
    "plt.plot(X, t0full + t1full * X, \"k\")\n",
    "\n",
    "plt.axis([0, 115_000, min_life_sat, max_life_sat])\n",
    "\n",
    "save_fig('representative_training_data_scatterplot')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn import preprocessing\n",
    "from sklearn import pipeline\n",
    "\n",
    "full_country_stats.plot(kind='scatter', figsize=(8,3),\n",
    "                        x=gdppc_col, y=lifesat_col, grid=True)\n",
    "\n",
    "poly = preprocessing.PolynomialFeatures(degree=10, include_bias=False)\n",
    "scaler = preprocessing.StandardScaler()\n",
    "lin_reg2 = linear_model.LinearRegression()\n",
    "\n",
    "pipeline_reg = pipeline.Pipeline([\n",
    "    ('poly', poly),\n",
    "    ('scal', scaler),\n",
    "    ('lin', lin_reg2)])\n",
    "pipeline_reg.fit(Xfull, yfull)\n",
    "curve = pipeline_reg.predict(X[:, np.newaxis])\n",
    "plt.plot(X, curve)\n",
    "\n",
    "plt.axis([0, 115_000, min_life_sat, max_life_sat])\n",
    "\n",
    "save_fig('overfitting_model_plot')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "w_countries = [c for c in full_country_stats.index if \"W\" in c.upper()]\n",
    "full_country_stats.loc[w_countries][lifesat_col]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_w_countries = [c for c in gdp_per_capita.index if \"W\" in c.upper()]\n",
    "gdp_per_capita.loc[all_w_countries].sort_values(by=gdppc_col)"
   ]
  },
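  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next cell fits a `Ridge` model, which adds a penalty `alpha` on the squared slope. For one feature with an unpenalized intercept, the effect can be sketched in a few lines of pure Python. This is an illustration under those assumptions, not Scikit-Learn's solver, and `fit_ridge_line` and the toy data are made up:\n",
    "\n",
    "```python\n",
    "def fit_ridge_line(xs, ys, alpha):\n",
    "    n = len(xs)\n",
    "    x_mean = sum(xs) / n\n",
    "    y_mean = sum(ys) / n\n",
    "    s_xy = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys))\n",
    "    s_xx = sum((x - x_mean) ** 2 for x in xs)\n",
    "    # alpha > 0 inflates the denominator, pulling the slope toward 0\n",
    "    theta1 = s_xy / (s_xx + alpha)\n",
    "    # The intercept is not penalized, so the line still passes through the means\n",
    "    theta0 = y_mean - theta1 * x_mean\n",
    "    return theta0, theta1\n",
    "\n",
    "# Toy data lying exactly on y = 1 + 2x (hypothetical values)\n",
    "xs, ys = [0.0, 1.0, 2.0, 3.0], [1.0, 3.0, 5.0, 7.0]\n",
    "print(fit_ridge_line(xs, ys, alpha=0.0))  # (1.0, 2.0): plain least squares\n",
    "print(fit_ridge_line(xs, ys, alpha=5.0))  # (2.5, 1.0): the slope is halved\n",
    "```\n",
    "\n",
    "Because `alpha` is added to the sum of squared deviations of the feature, it only has a visible effect when it is comparable to that sum. That is probably why the cell below needs such a large `alpha` (`10**9.5`): GDP per capita is left unscaled, in the tens of thousands of dollars."
   ]
  },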
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "country_stats.plot(kind='scatter', x=gdppc_col, y=lifesat_col, figsize=(8,3))\n",
    "missing_data.plot(kind='scatter', x=gdppc_col, y=lifesat_col,\n",
    "                  marker=\"s\", color=\"r\", grid=True, ax=plt.gca())\n",
    "\n",
    "X = np.linspace(0, 115_000, 1000)\n",
    "plt.plot(X, t0 + t1 * X, \"b:\", label=\"Linear model on partial data\")\n",
    "plt.plot(X, t0full + t1full * X, \"k-\", label=\"Linear model on all data\")\n",
    "\n",
    "ridge = linear_model.Ridge(alpha=10**9.5)\n",
    "X_sample = country_stats[[gdppc_col]]\n",
    "y_sample = country_stats[[lifesat_col]]\n",
    "ridge.fit(X_sample, y_sample)\n",
    "t0ridge, t1ridge = ridge.intercept_[0], ridge.coef_[0][0]\n",
    "plt.plot(X, t0ridge + t1ridge * X, \"b--\",\n",
    "         label=\"Regularized linear model on partial data\")\n",
    "plt.legend(loc=\"lower right\")\n",
    "\n",
    "plt.axis([0, 115_000, min_life_sat, max_life_sat])\n",
    "\n",
    "save_fig('ridge_model_plot')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "homl3",
   "language": "python",
   "name": "homl3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "metadata": {
   "interpreter": {
    "hash": "22b0ec00cd9e253c751e6d2619fc0bb2d18ed12980de3246690d5be49479dd65"
   }
  },
  "nav_menu": {},
  "toc": {
   "navigate_menu": true,
   "number_sections": true,
   "sideBar": true,
   "threshold": 6,
   "toc_cell": false,
   "toc_section_display": "block",
   "toc_window_display": true
  },
  "toc_position": {
   "height": "616px",
   "left": "0px",
   "right": "20px",
   "top": "106px",
   "width": "213px"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}