{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Chapter 1 – The Machine Learning landscape**\n",
    "\n",
    "_This notebook contains the code example from chapter 1, as well as all the code used to generate `lifesat.csv` and some of this chapter's figures._"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<table align=\"left\">\n",
    " <td>\n",
    " <a href=\"https://colab.research.google.com/github/ageron/handson-ml2/blob/master/01_the_machine_learning_landscape.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
    " </td>\n",
    " <td>\n",
    " <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/ageron/handson-ml2/blob/master/01_the_machine_learning_landscape.ipynb\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" /></a>\n",
    " </td>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Code example 1-1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "outputs": [],
   "source": [
    "# Python ≥3.8 is required\n",
    "import sys\n",
    "assert sys.version_info >= (3, 8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Make this notebook's output stable across runs\n",
    "np.random.seed(42)"
   ]
  },
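  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scikit-Learn ≥1.0 is required\n",
    "import sklearn\n",
    "from packaging import version\n",
    "\n",
    "# Compare parsed versions: a plain string comparison is lexicographic and can\n",
    "# misorder versions (e.g., \"1.0rc1\" >= \"1.0\" is True as strings)\n",
    "assert version.parse(sklearn.__version__) >= version.parse(\"1.0\")"
   ]
  },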
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To plot pretty figures directly within Jupyter\n",
    "%matplotlib inline\n",
    "import matplotlib as mpl\n",
    "\n",
    "mpl.rc('axes', labelsize=14)\n",
    "mpl.rc('xtick', labelsize=12)\n",
    "mpl.rc('ytick', labelsize=12)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the data\n",
    "from pathlib import Path\n",
    "import urllib.request\n",
    "\n",
    "datapath = Path() / \"datasets\" / \"lifesat\"\n",
    "datapath.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "root = \"https://raw.githubusercontent.com/ageron/handson-ml2/master/\"\n",
    "filename = \"lifesat.csv\"\n",
    "if not (datapath / filename).is_file():\n",
    "    print(\"Downloading\", filename)\n",
    "    url = root + \"datasets/lifesat/\" + filename\n",
    "    urllib.request.urlretrieve(url, datapath / filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Code example\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.linear_model import LinearRegression\n",
    "\n",
    "# Load the data\n",
    "lifesat = pd.read_csv(Path() / \"datasets\" / \"lifesat\" / \"lifesat.csv\")\n",
    "X = lifesat[[\"GDP per capita (USD)\"]].values\n",
    "y = lifesat[[\"Life satisfaction\"]].values\n",
    "\n",
    "# Visualize the data\n",
    "lifesat.plot(kind='scatter',\n",
    "             x=\"GDP per capita (USD)\", y='Life satisfaction')\n",
    "plt.axis([23_500, 62_500, 4, 9])\n",
    "plt.grid(True)\n",
    "plt.show()\n",
    "\n",
    "# Select a linear model\n",
    "model = LinearRegression()\n",
    "\n",
    "# Train the model\n",
    "model.fit(X, y)\n",
    "\n",
    "# Make a prediction for Cyprus\n",
    "X_new = [[37_655.2]]  # Cyprus' GDP per capita in 2020\n",
    "print(model.predict(X_new))  # outputs [[6.30165767]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Replacing the Linear Regression model with k-Nearest Neighbors regression (in this example, k = 3) in the previous code is as simple as replacing these two lines:\n",
    "\n",
    "```python\n",
    "from sklearn.linear_model import LinearRegression\n",
    "model = LinearRegression()\n",
    "```\n",
    "\n",
    "with these two:\n",
    "\n",
    "```python\n",
    "from sklearn.neighbors import KNeighborsRegressor\n",
    "model = KNeighborsRegressor(n_neighbors=3)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select a 3-Nearest Neighbors regression model\n",
    "from sklearn.neighbors import KNeighborsRegressor\n",
    "\n",
    "model = KNeighborsRegressor(n_neighbors=3)\n",
    "\n",
    "# Train the model\n",
    "model.fit(X, y)\n",
    "\n",
    "# Make a prediction for Cyprus\n",
    "print(model.predict(X_new))  # outputs [[6.33333333]]\n"
   ]
  },
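  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what `KNeighborsRegressor` computes, here is a minimal sketch on a small hypothetical dataset. The `X_toy` and `y_toy` values are made up for illustration and are not taken from `lifesat.csv`. With the default uniform weights, the prediction should simply be the average target value of the 3 training instances closest to the query point, and the cell checks this by hand."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.neighbors import KNeighborsRegressor\n",
    "\n",
    "# Hypothetical toy data (made up for illustration): GDP per capita -> life satisfaction\n",
    "X_toy = np.array([[20_000.0], [30_000.0], [35_000.0], [40_000.0], [60_000.0]])\n",
    "y_toy = np.array([[5.0], [6.0], [6.5], [7.0], [7.5]])\n",
    "x_query = np.array([[37_000.0]])\n",
    "\n",
    "knn = KNeighborsRegressor(n_neighbors=3)  # default weights=\"uniform\"\n",
    "knn.fit(X_toy, y_toy)\n",
    "sklearn_pred = knn.predict(x_query)[0, 0]\n",
    "\n",
    "# Manual version: find the 3 closest training instances and average their targets\n",
    "distances = np.abs(X_toy[:, 0] - x_query[0, 0])\n",
    "nearest = np.argsort(distances)[:3]\n",
    "manual_pred = y_toy[nearest, 0].mean()\n",
    "\n",
    "print(sklearn_pred, manual_pred)  # both 6.5, the mean of 6.5, 7.0 and 6.0\n",
    "assert np.isclose(sklearn_pred, manual_pred)"
   ]
  },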
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Note: you can safely ignore the rest of this notebook; it just generates many of the figures in chapter 1."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a function to save the figures:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Where to save the figures\n",
    "IMAGES_PATH = Path() / \"images\" / \"fundamentals\"\n",
    "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
    "    path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
    "    if tight_layout:\n",
    "        plt.tight_layout()\n",
    "    plt.savefig(path, format=fig_extension, dpi=resolution)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load and prepare Life satisfaction data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To create `lifesat.csv`, I downloaded the Better Life Index (BLI) data from [OECD's website](http://stats.oecd.org/index.aspx?DataSetCode=BLI) to get the life satisfaction score of each country, and the World Bank GDP per capita data from [OurWorldInData.org](https://ourworldindata.org/grapher/gdp-per-capita-worldbank). The BLI data is in `datasets/lifesat/oecd_bli.csv` (data from 2020), and the GDP per capita data is in `datasets/lifesat/gdp_per_capita.csv` (data up to 2020).\n",
    "\n",
    "If you want to grab the latest versions, feel free to do so. However, the data may have changed (e.g., different column names, or different countries with missing data), so be prepared to tweak the code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "for filename in (\"oecd_bli.csv\", \"gdp_per_capita.csv\"):\n",
    "    if not (datapath / filename).is_file():\n",
    "        print(\"Downloading\", filename)\n",
    "        url = root + \"datasets/lifesat/\" + filename\n",
    "        urllib.request.urlretrieve(url, datapath / filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "oecd_bli = pd.read_csv(datapath / \"oecd_bli.csv\")\n",
    "gdp_per_capita = pd.read_csv(datapath / \"gdp_per_capita.csv\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This function merges the OECD's life satisfaction data with the World Bank's 2020 GDP per capita data, keeping only the countries present in both datasets:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_country_stats(oecd_bli, gdp_per_capita):\n",
    "    gdp_year = 2020\n",
    "    gdppc = \"GDP per capita (USD)\"\n",
    "    oecd_bli = oecd_bli[oecd_bli[\"INEQUALITY\"] == \"TOT\"]\n",
    "    oecd_bli = oecd_bli.pivot(index=\"Country\", columns=\"Indicator\", values=\"Value\")\n",
    "    gdp_per_capita = gdp_per_capita[gdp_per_capita[\"Year\"] == gdp_year]\n",
    "    gdp_per_capita = gdp_per_capita.drop([\"Code\", \"Year\"], axis=1)\n",
    "    gdp_per_capita.columns = [\"Country\", gdppc]\n",
    "    gdp_per_capita.set_index(\"Country\", inplace=True)\n",
    "    full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita,\n",
    "                                  left_index=True, right_index=True)\n",
    "    full_country_stats.sort_values(by=gdppc, inplace=True)\n",
    "    return full_country_stats[[gdppc, 'Life satisfaction']]"
   ]
  },
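  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The filter, pivot, and merge steps of `prepare_country_stats()` can be traced on a tiny example. This is a sketch on made-up data: the country names, the values, and the raw GDP column names `Entity` and `GDP` are hypothetical. Since the function renames the two columns left after dropping `Code` and `Year` by position, it assumes the country column comes before the GDP column, as it does here. Note that `pd.merge()` performs an inner join by default, so any country missing from either table silently disappears from the result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Hypothetical miniature versions of the two raw tables: the values, the country\n",
    "# names and the \"Entity\"/\"GDP\" column names are made up for illustration\n",
    "toy_bli = pd.DataFrame({\n",
    "    \"Country\": [\"A\", \"A\", \"A\", \"B\", \"B\", \"C\"],\n",
    "    \"Indicator\": [\"Life satisfaction\", \"Life satisfaction\", \"Employment rate\",\n",
    "                  \"Life satisfaction\", \"Employment rate\", \"Life satisfaction\"],\n",
    "    \"INEQUALITY\": [\"TOT\", \"MN\", \"TOT\", \"TOT\", \"TOT\", \"TOT\"],\n",
    "    \"Value\": [6.5, 6.4, 70.0, 7.1, 75.0, 5.9],\n",
    "})\n",
    "toy_gdp = pd.DataFrame({\n",
    "    \"Entity\": [\"A\", \"B\", \"D\", \"A\"],\n",
    "    \"Code\": [\"AAA\", \"BBB\", \"DDD\", \"AAA\"],\n",
    "    \"Year\": [2020, 2020, 2020, 2019],\n",
    "    \"GDP\": [30_000.0, 45_000.0, 50_000.0, 29_000.0],\n",
    "})\n",
    "\n",
    "# Step 1: keep only the \"TOT\" rows, then pivot to one column per indicator\n",
    "bli = toy_bli[toy_bli[\"INEQUALITY\"] == \"TOT\"]\n",
    "bli = bli.pivot(index=\"Country\", columns=\"Indicator\", values=\"Value\")\n",
    "\n",
    "# Step 2: keep a single year of GDP data, indexed by country\n",
    "gdp = toy_gdp[toy_gdp[\"Year\"] == 2020].drop([\"Code\", \"Year\"], axis=1)\n",
    "gdp.columns = [\"Country\", \"GDP per capita (USD)\"]\n",
    "gdp = gdp.set_index(\"Country\")\n",
    "\n",
    "# Step 3: pd.merge() defaults to an inner join, so only the countries present\n",
    "# in both tables (A and B) survive; C and D are silently dropped\n",
    "merged = pd.merge(left=bli, right=gdp, left_index=True, right_index=True)\n",
    "merged[[\"GDP per capita (USD)\", \"Life satisfaction\"]]"
   ]
  },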
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "full_country_stats = prepare_country_stats(oecd_bli, gdp_per_capita)\n",
    "full_country_stats.to_csv(datapath / \"lifesat.csv\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate the risk of overfitting, I use only part of the data in most figures (all countries with a GDP per capita between `min_gdp` and `max_gdp`). Later in the chapter, I reveal the missing countries and show that they don't follow the same linear trend at all."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "gdppc = \"GDP per capita (USD)\"\n",
    "min_gdp = 23_500\n",
    "max_gdp = 62_500\n",
    "country_stats = full_country_stats[(full_country_stats[gdppc] >= min_gdp) &\n",
    "                                   (full_country_stats[gdppc] <= max_gdp)]\n",
    "country_stats.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "country_stats.plot(kind='scatter', figsize=(5,3),\n",
    "                   x=\"GDP per capita (USD)\", y='Life satisfaction')\n",
    "\n",
    "min_life_sat = 4\n",
    "max_life_sat = 9\n",
    "\n",
    "plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
    "position_text = {\n",
    "    \"Hungary\": (28_000, 4.2),\n",
    "    \"France\": (40_000, 5),\n",
    "    \"New Zealand\": (30_000, 8),\n",
    "    \"Australia\": (50_000, 5.5),\n",
    "    \"United States\": (59_000, 5.5),\n",
    "    \"Denmark\": (46_000, 8.5)\n",
    "}\n",
    "\n",
    "for country, pos_text in position_text.items():\n",
    "    pos_data_x, pos_data_y = country_stats[[\"GDP per capita (USD)\",\n",
    "                                            \"Life satisfaction\"]].loc[country]\n",
    "    country = \"U.S.\" if country == \"United States\" else country\n",
    "    plt.annotate(country, xy=(pos_data_x, pos_data_y), xytext=pos_text,\n",
    "                 arrowprops=dict(facecolor='black', width=0.5, shrink=0.2,\n",
    "                                 headwidth=5))\n",
    "    plt.plot(pos_data_x, pos_data_y, \"ro\")\n",
    "\n",
    "plt.grid(True)\n",
    "\n",
    "save_fig('money_happy_scatterplot')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "highlighted_countries = country_stats.loc[list(position_text.keys())]\n",
    "highlighted_countries[[\"Life satisfaction\"]].sort_values(by=\"Life satisfaction\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "country_stats.plot(kind='scatter', figsize=(5,3),\n",
    "                   x=\"GDP per capita (USD)\", y='Life satisfaction')\n",
    "plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
    "\n",
    "X = np.linspace(min_gdp, max_gdp, 1000)\n",
    "\n",
    "w1, w2 = 4.2, 0\n",
    "plt.plot(X, w1 + w2 * 1e-5 * X, \"r\")\n",
    "plt.text(40_000, 4.9, fr\"$\\theta_0 = {w1}$\",\n",
    "         fontsize=14, color=\"r\")\n",
    "plt.text(40_000, 4.4, fr\"$\\theta_1 = {w2}$\",\n",
    "         fontsize=14, color=\"r\")\n",
    "\n",
    "w1, w2 = 10, -9\n",
    "plt.plot(X, w1 + w2 * 1e-5 * X, \"g\")\n",
    "plt.text(26_000, 8.5, fr\"$\\theta_0 = {w1}$\",\n",
    "         fontsize=14, color=\"g\")\n",
    "plt.text(26_000, 8.0, fr\"$\\theta_1 = {w2} \\times 10^{{-5}}$\",\n",
    "         fontsize=14, color=\"g\")\n",
    "\n",
    "w1, w2 = 3, 8\n",
    "plt.plot(X, w1 + w2 * 1e-5 * X, \"b\")\n",
    "plt.text(48_000, 8.5, fr\"$\\theta_0 = {w1}$\",\n",
    "         fontsize=14, color=\"b\")\n",
    "plt.text(48_000, 8.0, fr\"$\\theta_1 = {w2} \\times 10^{{-5}}$\",\n",
    "         fontsize=14, color=\"b\")\n",
    "plt.grid(True)\n",
    "\n",
    "save_fig('tweaking_model_params_plot')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn import linear_model\n",
    "\n",
    "X_sample = country_stats[[\"GDP per capita (USD)\"]].values\n",
    "y_sample = country_stats[[\"Life satisfaction\"]].values\n",
    "\n",
    "lin1 = linear_model.LinearRegression()\n",
    "lin1.fit(X_sample, y_sample)\n",
    "\n",
    "t0, t1 = lin1.intercept_[0], lin1.coef_[0][0]\n",
    "t0, t1"
   ]
  },
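  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`LinearRegression` looks for the intercept $\\theta_0$ and slope $\\theta_1$ that minimize the sum of squared errors on the training set (ordinary least squares). As a sketch on hypothetical data (the `X_toy` and `y_toy` values are made up, not the notebook's country data), the same two parameters can be recovered with NumPy's least-squares solver, after adding a column of 1s to the features so that the intercept is treated like any other weight."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.linear_model import LinearRegression\n",
    "\n",
    "# Hypothetical toy data, made up for illustration\n",
    "X_toy = np.array([[25_000.0], [30_000.0], [40_000.0], [50_000.0], [60_000.0]])\n",
    "y_toy = np.array([5.8, 6.0, 6.7, 7.2, 7.4])\n",
    "\n",
    "# Prepend a column of 1s so that theta_0 is learned like any other weight\n",
    "X_b = np.c_[np.ones((len(X_toy), 1)), X_toy]\n",
    "\n",
    "# Solve min ||X_b @ theta - y||^2 without explicitly inverting X_b.T @ X_b\n",
    "theta, *_ = np.linalg.lstsq(X_b, y_toy, rcond=None)\n",
    "\n",
    "lin_toy = LinearRegression().fit(X_toy, y_toy)\n",
    "print(theta)  # [theta_0, theta_1]\n",
    "print(lin_toy.intercept_, lin_toy.coef_[0])  # same values, up to rounding\n",
    "assert np.allclose(theta, [lin_toy.intercept_, lin_toy.coef_[0]])"
   ]
  },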
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "country_stats.plot(kind='scatter', figsize=(5,3),\n",
    "                   x=\"GDP per capita (USD)\", y='Life satisfaction')\n",
    "plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
    "\n",
    "X = np.linspace(min_gdp, max_gdp, 1000)\n",
    "plt.plot(X, t0 + t1 * X, \"b\")\n",
    "\n",
    "plt.text(max_gdp - 20_000, min_life_sat + 1.5,\n",
    "         fr\"$\\theta_0 = {t0:.2f}$\",\n",
    "         fontsize=14, color=\"b\")\n",
    "plt.text(max_gdp - 20_000, min_life_sat + 1,\n",
    "         fr\"$\\theta_1 = {t1 * 1e5:.2f} \\times 10^{{-5}}$\",\n",
    "         fontsize=14, color=\"b\")\n",
    "plt.grid(True)\n",
    "\n",
    "save_fig('best_fit_model_plot')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "gdp_year = 2020\n",
    "gdp_per_capita_clean = gdp_per_capita[gdp_per_capita[\"Year\"] == gdp_year]\n",
    "gdp_per_capita_clean = gdp_per_capita_clean.drop([\"Code\", \"Year\"], axis=1)\n",
    "gdp_per_capita_clean.columns = [\"Country\", \"GDP per capita (USD)\"]\n",
    "gdp_per_capita_clean.set_index(\"Country\", inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "cyprus_gdp_per_capita = gdp_per_capita_clean.loc[\"Cyprus\", \"GDP per capita (USD)\"]\n",
    "print(cyprus_gdp_per_capita)\n",
    "cyprus_predicted_life_satisfaction = lin1.predict([[cyprus_gdp_per_capita]])[0, 0]\n",
    "cyprus_predicted_life_satisfaction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "country_stats.plot(kind='scatter', figsize=(5,3),\n",
    "                   x=\"GDP per capita (USD)\", y='Life satisfaction')\n",
    "\n",
    "X = np.linspace(min_gdp, max_gdp, 1000)\n",
    "plt.plot(X, t0 + t1 * X, \"b\")\n",
    "\n",
    "plt.text(min_gdp + 15_000, max_life_sat - 1.5,\n",
    "         fr\"$\\theta_0 = {t0:.2f}$\",\n",
    "         fontsize=14, color=\"b\")\n",
    "plt.text(min_gdp + 15_000, max_life_sat - 1,\n",
    "         fr\"$\\theta_1 = {t1 * 1e5:.2f} \\times 10^{{-5}}$\",\n",
    "         fontsize=14, color=\"b\")\n",
    "\n",
    "plt.plot([cyprus_gdp_per_capita, cyprus_gdp_per_capita],\n",
    "         [min_life_sat, cyprus_predicted_life_satisfaction], \"r--\")\n",
    "plt.text(cyprus_gdp_per_capita + 1000, 5.0,\n",
    "         fr\"Prediction = {cyprus_predicted_life_satisfaction:.2f}\",\n",
    "         fontsize=14, color=\"r\")\n",
    "plt.plot(cyprus_gdp_per_capita, cyprus_predicted_life_satisfaction, \"ro\")\n",
    "\n",
    "plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
    "plt.grid(True)\n",
    "\n",
    "save_fig('cyprus_prediction_plot')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "missing_data = full_country_stats[(full_country_stats[gdppc] < min_gdp) |\n",
    "                                  (full_country_stats[gdppc] > max_gdp)]\n",
    "missing_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "position_text2 = {\n",
    "    \"South Africa\": (20_000, 4.2),\n",
    "    \"Colombia\": (6_000, 8.2),\n",
    "    \"Brazil\": (18_000, 7.8),\n",
    "    \"Mexico\": (24_000, 7.4),\n",
    "    \"Chile\": (30_000, 7.0),\n",
    "    \"Norway\": (60_000, 6.2),\n",
    "    \"Switzerland\": (65_000, 5.7),\n",
    "    \"Ireland\": (80_000, 5.5),\n",
    "    \"Luxembourg\": (100_000, 5.0),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "full_country_stats.plot(kind='scatter', figsize=(8,3),\n",
    "                        x=\"GDP per capita (USD)\", y='Life satisfaction')\n",
    "\n",
    "for country, pos_text in position_text2.items():\n",
    "    pos_data_x, pos_data_y = missing_data.loc[country]\n",
    "    plt.annotate(country, xy=(pos_data_x, pos_data_y), xytext=pos_text,\n",
    "                 arrowprops=dict(facecolor='black', width=0.5, shrink=0.1,\n",
    "                                 headwidth=5))\n",
    "    plt.plot(pos_data_x, pos_data_y, \"rs\")\n",
    "\n",
    "X = np.linspace(0, 115_000, 1000)\n",
    "plt.plot(X, t0 + t1 * X, \"b:\")\n",
    "\n",
    "lin_reg_full = linear_model.LinearRegression()\n",
    "Xfull = np.c_[full_country_stats[\"GDP per capita (USD)\"]]\n",
    "yfull = np.c_[full_country_stats[\"Life satisfaction\"]]\n",
    "lin_reg_full.fit(Xfull, yfull)\n",
    "\n",
    "t0full, t1full = lin_reg_full.intercept_[0], lin_reg_full.coef_[0][0]\n",
    "plt.plot(X, t0full + t1full * X, \"k\")\n",
    "\n",
    "plt.axis([0, 115_000, min_life_sat, max_life_sat])\n",
    "plt.grid(True)\n",
    "\n",
    "save_fig('representative_training_data_scatterplot')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn import preprocessing\n",
    "from sklearn import pipeline\n",
    "\n",
    "full_country_stats.plot(kind='scatter', figsize=(8,3),\n",
    "                        x=\"GDP per capita (USD)\", y='Life satisfaction')\n",
    "plt.axis([0, 115_000, min_life_sat, max_life_sat])\n",
    "\n",
    "poly = preprocessing.PolynomialFeatures(degree=10, include_bias=False)\n",
    "scaler = preprocessing.StandardScaler()\n",
    "lin_reg2 = linear_model.LinearRegression()\n",
    "\n",
    "pipeline_reg = pipeline.Pipeline([\n",
    "    ('poly', poly),\n",
    "    ('scal', scaler),\n",
    "    ('lin', lin_reg2)])\n",
    "pipeline_reg.fit(Xfull, yfull)\n",
    "curve = pipeline_reg.predict(X[:, np.newaxis])\n",
    "plt.plot(X, curve)\n",
    "plt.grid(True)\n",
    "\n",
    "save_fig('overfitting_model_plot')\n",
    "plt.show()"
   ]
  },
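  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The degree-10 curve above hugs the training data because such a model has many free parameters. A minimal sketch on hypothetical data (12 noisy points around a straight line, generated with a fixed seed, not the notebook's countries) shows one symptom: since every lower-degree polynomial is a special case of a higher-degree one, the training $R^2$ score can only stay the same or increase with the degree (in exact arithmetic), even though the true relationship here is linear. A better fit on the training set is therefore no evidence that the model will generalize better."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.preprocessing import PolynomialFeatures, StandardScaler\n",
    "\n",
    "rng = np.random.default_rng(42)\n",
    "# Hypothetical data: 12 noisy points around the line y = 0.5 x + 5\n",
    "X_toy = rng.uniform(0, 10, size=(12, 1))\n",
    "y_toy = 0.5 * X_toy[:, 0] + 5 + rng.normal(0, 0.5, size=12)\n",
    "\n",
    "for degree in (1, 3, 10):\n",
    "    poly_model = make_pipeline(PolynomialFeatures(degree=degree, include_bias=False),\n",
    "                               StandardScaler(), LinearRegression())\n",
    "    poly_model.fit(X_toy, y_toy)\n",
    "    # Training R² should not decrease as the degree grows (up to floating-point error)\n",
    "    print(f\"degree={degree:2d}  training R²={poly_model.score(X_toy, y_toy):.4f}\")"
   ]
  },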
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "w_countries = [c for c in full_country_stats.index if \"W\" in c.upper()]\n",
    "full_country_stats.loc[w_countries][\"Life satisfaction\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_w_countries = [c for c in gdp_per_capita_clean.index if \"W\" in c.upper()]\n",
    "gdp_per_capita_clean.loc[all_w_countries].sort_values(by=gdppc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(8,3))\n",
    "\n",
    "plt.xlabel(\"GDP per capita (USD)\")\n",
    "plt.ylabel('Life satisfaction')\n",
    "\n",
    "plt.plot(list(country_stats[\"GDP per capita (USD)\"]),\n",
    "         list(country_stats[\"Life satisfaction\"]), \"bo\")\n",
    "plt.plot(list(missing_data[\"GDP per capita (USD)\"]),\n",
    "         list(missing_data[\"Life satisfaction\"]), \"rs\")\n",
    "\n",
    "X = np.linspace(0, 115_000, 1000)\n",
    "plt.plot(X, t0full + t1full * X, \"r--\", label=\"Linear model on all data\")\n",
    "plt.plot(X, t0 + t1 * X, \"b:\", label=\"Linear model on partial data\")\n",
    "\n",
    "ridge = linear_model.Ridge(alpha=10**9.5)\n",
    "Xsample = country_stats[[\"GDP per capita (USD)\"]]\n",
    "ysample = country_stats[[\"Life satisfaction\"]]\n",
    "ridge.fit(Xsample, ysample)\n",
    "t0ridge, t1ridge = ridge.intercept_[0], ridge.coef_[0][0]\n",
    "plt.plot(X, t0ridge + t1ridge * X, \"b\", label=\"Regularized linear model on partial data\")\n",
    "\n",
    "plt.legend(loc=\"lower right\")\n",
    "plt.axis([0, 115_000, min_life_sat, max_life_sat])\n",
    "plt.grid(True)\n",
    "\n",
    "save_fig('ridge_model_plot')\n",
    "plt.show()"
   ]
  },
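  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With a single feature, `Ridge` minimizes the sum of squared errors plus `alpha` $\\times \\theta_1^2$ (the intercept is not penalized), so its slope has the closed form $\\theta_1 = S_{xy} / (S_{xx} + \\alpha)$, where $S_{xx}$ is the sum of squared deviations of $x$ from its mean and $S_{xy}$ is the sum of the products of the deviations of $x$ and $y$. This sketch, on hypothetical data (made-up values, not the notebook's countries), checks that formula and shows the slope shrinking toward 0 as `alpha` grows. Because the GDP feature is not scaled, $S_{xx}$ is huge, which is presumably why the cell above needs an `alpha` as large as $10^{9.5}$ to have a visible effect."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.linear_model import Ridge\n",
    "\n",
    "# Hypothetical toy data, made up for illustration\n",
    "X_toy = np.array([[25_000.0], [30_000.0], [40_000.0], [50_000.0], [60_000.0]])\n",
    "y_toy = np.array([5.8, 6.0, 6.7, 7.2, 7.4])\n",
    "\n",
    "x = X_toy[:, 0]\n",
    "s_xx = np.sum((x - x.mean()) ** 2)\n",
    "s_xy = np.sum((x - x.mean()) * (y_toy - y_toy.mean()))\n",
    "print(\"unregularized slope:\", s_xy / s_xx)\n",
    "\n",
    "for alpha in (1e8, 1e9, 1e10):\n",
    "    ridge_toy = Ridge(alpha=alpha).fit(X_toy, y_toy)\n",
    "    closed_form = s_xy / (s_xx + alpha)  # 1-D ridge slope, intercept not penalized\n",
    "    print(f\"alpha={alpha:.0e}  slope={ridge_toy.coef_[0]:.3e}  closed form={closed_form:.3e}\")\n",
    "    assert np.isclose(ridge_toy.coef_[0], closed_form)"
   ]
  },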
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "homl3",
   "language": "python",
   "name": "homl3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "metadata": {
   "interpreter": {
    "hash": "22b0ec00cd9e253c751e6d2619fc0bb2d18ed12980de3246690d5be49479dd65"
   }
  },
  "nav_menu": {},
  "toc": {
   "navigate_menu": true,
   "number_sections": true,
   "sideBar": true,
   "threshold": 6,
   "toc_cell": false,
   "toc_section_display": "block",
   "toc_window_display": true
  },
  "toc_position": {
   "height": "616px",
   "left": "0px",
   "right": "20px",
   "top": "106px",
   "width": "213px"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}