Clean code and set default font size

main
Aurélien Geron 2021-10-19 23:15:36 +13:00
parent 4a2d0ea1ae
commit f57e80968d
1 changed files with 253 additions and 237 deletions

View File

@ -27,7 +27,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Code example 1-1"
"# Setup"
]
},
{
@ -78,6 +78,7 @@
"%matplotlib inline\n",
"import matplotlib as mpl\n",
"\n",
"mpl.rc('font', size=12)\n",
"mpl.rc('axes', labelsize=14)\n",
"mpl.rc('xtick', labelsize=12)\n",
"mpl.rc('ytick', labelsize=12)"
@ -104,13 +105,19 @@
" urllib.request.urlretrieve(url, datapath / filename)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Code example 1-1"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Code example\n",
"from pathlib import Path\n",
"\n",
"import matplotlib.pyplot as plt\n",
@ -124,10 +131,9 @@
"y = lifesat[[\"Life satisfaction\"]].values\n",
"\n",
"# Visualize the data\n",
"lifesat.plot(kind='scatter',\n",
" x=\"GDP per capita (USD)\", y='Life satisfaction')\n",
"lifesat.plot(kind='scatter', grid=True,\n",
" x=\"GDP per capita (USD)\", y=\"Life satisfaction\")\n",
"plt.axis([23_500, 62_500, 4, 9])\n",
"plt.grid(True)\n",
"plt.show()\n",
"\n",
"# Select a linear model\n",
@ -149,15 +155,17 @@
"lines:\n",
"\n",
"```python\n",
"import sklearn.linear_model\n",
"model = sklearn.linear_model.LinearRegression()\n",
"from sklearn.linear_model import LinearRegression\n",
"\n",
"model = LinearRegression()\n",
"```\n",
"\n",
"with these two:\n",
"\n",
"```python\n",
"import sklearn.neighbors\n",
"model = sklearn.neighbors.KNeighborsRegressor(n_neighbors=3)\n",
"from sklearn.neighbors import KNeighborsRegressor\n",
"\n",
"model = KNeighborsRegressor(n_neighbors=3)\n",
"```"
]
},
@ -168,9 +176,9 @@
"outputs": [],
"source": [
"# Select a 3-Nearest Neighbors regression model\n",
"import sklearn.neighbors\n",
"from sklearn.neighbors import KNeighborsRegressor\n",
"\n",
"model = sklearn.neighbors.KNeighborsRegressor(n_neighbors=3)\n",
"model = KNeighborsRegressor(n_neighbors=3)\n",
"\n",
"# Train the model\n",
"model.fit(X,y)\n",
@ -274,7 +282,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"This function just merges the OECD's life satisfaction data and the World Bank's GDP per capita data:"
"Preprocess the GDP per capita data to keep only the year 2020:"
]
},
{
@ -283,19 +291,23 @@
"metadata": {},
"outputs": [],
"source": [
"def prepare_country_stats(oecd_bli, gdp_per_capita):\n",
"gdp_year = 2020\n",
" gdppc = \"GDP per capita (USD)\"\n",
" oecd_bli = oecd_bli[oecd_bli[\"INEQUALITY\"]==\"TOT\"]\n",
" oecd_bli = oecd_bli.pivot(index=\"Country\", columns=\"Indicator\", values=\"Value\")\n",
"gdppc_col = \"GDP per capita (USD)\"\n",
"lifesat_col = \"Life satisfaction\"\n",
"\n",
"gdp_per_capita = gdp_per_capita[gdp_per_capita[\"Year\"] == gdp_year]\n",
"gdp_per_capita = gdp_per_capita.drop([\"Code\", \"Year\"], axis=1)\n",
" gdp_per_capita.columns = [\"Country\", gdppc]\n",
"gdp_per_capita.columns = [\"Country\", gdppc_col]\n",
"gdp_per_capita.set_index(\"Country\", inplace=True)\n",
" full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita,\n",
" left_index=True, right_index=True)\n",
" full_country_stats.sort_values(by=gdppc, inplace=True)\n",
" return full_country_stats[[gdppc, 'Life satisfaction']]"
"\n",
"gdp_per_capita.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Preprocess the OECD BLI data to keep only the `Life satisfaction` column:"
]
},
{
@ -304,8 +316,31 @@
"metadata": {},
"outputs": [],
"source": [
"full_country_stats = prepare_country_stats(oecd_bli, gdp_per_capita)\n",
"full_country_stats.to_csv(datapath / \"lifesat_full.csv\")"
"oecd_bli = oecd_bli[oecd_bli[\"INEQUALITY\"]==\"TOT\"]\n",
"oecd_bli = oecd_bli.pivot(index=\"Country\", columns=\"Indicator\", values=\"Value\")\n",
"\n",
"oecd_bli.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's merge the life satisfaction data and the GDP per capita data, keeping only the GDP per capita and Life satisfaction columns:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita,\n",
" left_index=True, right_index=True)\n",
"full_country_stats.sort_values(by=gdppc_col, inplace=True)\n",
"full_country_stats = full_country_stats[[gdppc_col, lifesat_col]]\n",
"\n",
"full_country_stats.head()"
]
},
{
@ -315,56 +350,18 @@
"To illustrate the risk of overfitting, I use only part of the data in most figures (all countries with a GDP per capita between `min_gdp` and `max_gdp`). Later in the chapter I reveal the missing countries, and show that they don't follow the same linear trend at all."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"gdppc = \"GDP per capita (USD)\"\n",
"min_gdp = 23_500\n",
"max_gdp = 62_500\n",
"country_stats = full_country_stats[(full_country_stats[gdppc] >= min_gdp) &\n",
" (full_country_stats[gdppc] <= max_gdp)]\n",
"country_stats.to_csv(datapath / \"lifesat.csv\")\n",
"country_stats.head()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"country_stats.plot(kind='scatter', figsize=(5,3),\n",
" x=\"GDP per capita (USD)\", y='Life satisfaction')\n",
"min_gdp = 23_500\n",
"max_gdp = 62_500\n",
"\n",
"min_life_sat = 4\n",
"max_life_sat = 9\n",
"\n",
"plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
"position_text = {\n",
" \"Hungary\": (28_000, 4.2),\n",
" \"France\": (40_000, 5),\n",
" \"New Zealand\": (30_000, 8),\n",
" \"Australia\": (50_000, 5.5),\n",
" \"United States\": (59_000, 5.5),\n",
" \"Denmark\": (46_000, 8.5)\n",
"}\n",
"\n",
"for country, pos_text in position_text.items():\n",
" pos_data_x, pos_data_y = country_stats[[\"GDP per capita (USD)\",\n",
" \"Life satisfaction\"]].loc[country]\n",
" country = \"U.S.\" if country == \"United States\" else country\n",
" plt.annotate(country, xy=(pos_data_x, pos_data_y), xytext=pos_text,\n",
" arrowprops=dict(facecolor='black', width=0.5, shrink=0.2,\n",
" headwidth=5))\n",
" plt.plot(pos_data_x, pos_data_y, \"ro\")\n",
"\n",
"plt.grid(True)\n",
"\n",
"save_fig('money_happy_scatterplot')\n",
"plt.show()"
"country_stats = full_country_stats[(full_country_stats[gdppc_col] >= min_gdp) &\n",
" (full_country_stats[gdppc_col] <= max_gdp)]\n",
"country_stats.head()"
]
},
{
@ -373,8 +370,8 @@
"metadata": {},
"outputs": [],
"source": [
"highlighted_countries = country_stats.loc[list(position_text.keys())]\n",
"highlighted_countries[[gdppc, \"Life satisfaction\"]].sort_values(by=gdppc)"
"country_stats.to_csv(datapath / \"lifesat.csv\")\n",
"full_country_stats.to_csv(datapath / \"lifesat_full.csv\")"
]
},
{
@ -383,37 +380,34 @@
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n",
" x=gdppc_col, y=lifesat_col)\n",
"\n",
"min_life_sat = 4\n",
"max_life_sat = 9\n",
"\n",
"position_text = {\n",
" \"Hungary\": (28_000, 4.2),\n",
" \"France\": (40_000, 5),\n",
" \"New Zealand\": (28_000, 8.2),\n",
" \"Australia\": (50_000, 5.5),\n",
" \"United States\": (59_000, 5.5),\n",
" \"Denmark\": (46_000, 8.5)\n",
"}\n",
"\n",
"for country, pos_text in position_text.items():\n",
" pos_data_x = country_stats[gdppc_col].loc[country]\n",
" pos_data_y = country_stats[lifesat_col].loc[country]\n",
" country = \"U.S.\" if country == \"United States\" else country\n",
" plt.annotate(country, xy=(pos_data_x, pos_data_y),\n",
" xytext=pos_text,\n",
" arrowprops=dict(facecolor='black', width=0.5,\n",
" shrink=0.15, headwidth=5))\n",
" plt.plot(pos_data_x, pos_data_y, \"ro\")\n",
"\n",
"country_stats.plot(kind='scatter', figsize=(5,3),\n",
" x=\"GDP per capita (USD)\", y='Life satisfaction')\n",
"plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
"\n",
"X = np.linspace(min_gdp, max_gdp, 1000)\n",
"\n",
"w1, w2 = 4.2, 0\n",
"plt.plot(X, w1 + w2 * 1e-5 * X, \"r\")\n",
"plt.text(40_000, 4.9, fr\"$\\theta_0 = {w1}$\",\n",
" fontsize=14, color=\"r\")\n",
"plt.text(40_000, 4.4, fr\"$\\theta_1 = {w2}$\",\n",
" fontsize=14, color=\"r\")\n",
"\n",
"w1, w2 = 10, -9\n",
"plt.plot(X, w1 + w2 * 1e-5 * X, \"g\")\n",
"plt.text(26_000, 8.5, fr\"$\\theta_0 = {w1}$\",\n",
" fontsize=14, color=\"g\")\n",
"plt.text(26_000, 8.0, fr\"$\\theta_1 = {w2} \\times 10^{{-5}}$\",\n",
" fontsize=14, color=\"g\")\n",
"\n",
"w1, w2 = 3, 8\n",
"plt.plot(X, w1 + w2 * 1e-5 * X, \"b\")\n",
"plt.text(48_000, 8.5, fr\"$\\theta_0 = {w1}$\",\n",
" fontsize=14, color=\"b\")\n",
"plt.text(48_000, 8.0, fr\"$\\theta_1 = {w2} \\times 10^{{-5}}$\",\n",
" fontsize=14, color=\"b\")\n",
"plt.grid(True)\n",
"\n",
"save_fig('tweaking_model_params_plot')\n",
"save_fig('money_happy_scatterplot')\n",
"plt.show()"
]
},
@ -423,16 +417,8 @@
"metadata": {},
"outputs": [],
"source": [
"from sklearn import linear_model\n",
"\n",
"X_sample = country_stats[[\"GDP per capita (USD)\"]].values\n",
"y_sample = country_stats[[\"Life satisfaction\"]].values\n",
"\n",
"lin1 = linear_model.LinearRegression()\n",
"lin1.fit(X_sample, y_sample)\n",
"\n",
"t0, t1 = lin1.intercept_[0], lin1.coef_[0][0]\n",
"t0, t1"
"highlighted_countries = country_stats.loc[list(position_text.keys())]\n",
"highlighted_countries[[gdppc_col, lifesat_col]].sort_values(by=gdppc_col)"
]
},
{
@ -441,23 +427,29 @@
"metadata": {},
"outputs": [],
"source": [
"country_stats.plot(kind='scatter', figsize=(5,3),\n",
" x=\"GDP per capita (USD)\", y='Life satisfaction')\n",
"plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
"country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n",
" x=gdppc_col, y=lifesat_col)\n",
"\n",
"X = np.linspace(min_gdp, max_gdp, 1000)\n",
"plt.plot(X, t0 + t1 * X, \"b\")\n",
"\n",
"plt.text(max_gdp - 20_000, min_life_sat + 1.5,\n",
" fr\"$\\theta_0 = {t0:.2f}$\",\n",
" fontsize=14, color=\"b\")\n",
"plt.text(max_gdp - 20_000, min_life_sat + 1,\n",
" fr\"$\\theta_1 = {t1 * 1e5:.2f} \\times 10^{{-5}}$\",\n",
" fontsize=14, color=\"b\")\n",
"w1, w2 = 4.2, 0\n",
"plt.plot(X, w1 + w2 * 1e-5 * X, \"r\")\n",
"plt.text(40_000, 4.9, fr\"$\\theta_0 = {w1}$\", color=\"r\")\n",
"plt.text(40_000, 4.4, fr\"$\\theta_1 = {w2}$\", color=\"r\")\n",
"\n",
"w1, w2 = 10, -9\n",
"plt.plot(X, w1 + w2 * 1e-5 * X, \"g\")\n",
"plt.text(26_000, 8.5, fr\"$\\theta_0 = {w1}$\", color=\"g\")\n",
"plt.text(26_000, 8.0, fr\"$\\theta_1 = {w2} \\times 10^{{-5}}$\", color=\"g\")\n",
"\n",
"w1, w2 = 3, 8\n",
"plt.plot(X, w1 + w2 * 1e-5 * X, \"b\")\n",
"plt.text(48_000, 8.5, fr\"$\\theta_0 = {w1}$\", color=\"b\")\n",
"plt.text(48_000, 8.0, fr\"$\\theta_1 = {w2} \\times 10^{{-5}}$\", color=\"b\")\n",
"\n",
"plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
"plt.grid(True)\n",
"\n",
"save_fig('best_fit_model_plot')\n",
"save_fig('tweaking_model_params_plot')\n",
"plt.show()"
]
},
@ -467,11 +459,16 @@
"metadata": {},
"outputs": [],
"source": [
"gdp_year = 2020\n",
"gdp_per_capita_clean = gdp_per_capita[gdp_per_capita[\"Year\"] == gdp_year]\n",
"gdp_per_capita_clean = gdp_per_capita_clean.drop([\"Code\", \"Year\"], axis=1)\n",
"gdp_per_capita_clean.columns = [\"Country\", \"GDP per capita (USD)\"]\n",
"gdp_per_capita_clean.set_index(\"Country\", inplace=True)"
"from sklearn import linear_model\n",
"\n",
"X_sample = country_stats[[gdppc_col]].values\n",
"y_sample = country_stats[[lifesat_col]].values\n",
"\n",
"lin1 = linear_model.LinearRegression()\n",
"lin1.fit(X_sample, y_sample)\n",
"\n",
"t0, t1 = lin1.intercept_[0], lin1.coef_[0][0]\n",
"print(f\"θ0={t0:.2f}, θ1={t1:.2e}\")"
]
},
{
@ -480,10 +477,21 @@
"metadata": {},
"outputs": [],
"source": [
"cyprus_gdp_per_capita = gdp_per_capita_clean.loc[\"Cyprus\"][\"GDP per capita (USD)\"]\n",
"print(cyprus_gdp_per_capita)\n",
"cyprus_predicted_life_satisfaction = lin1.predict([[cyprus_gdp_per_capita]])[0, 0]\n",
"cyprus_predicted_life_satisfaction"
"country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n",
" x=gdppc_col, y=lifesat_col)\n",
"\n",
"X = np.linspace(min_gdp, max_gdp, 1000)\n",
"plt.plot(X, t0 + t1 * X, \"b\")\n",
"\n",
"plt.text(max_gdp - 20_000, min_life_sat + 1.5,\n",
" fr\"$\\theta_0 = {t0:.2f}$\", color=\"b\")\n",
"plt.text(max_gdp - 20_000, min_life_sat + 1,\n",
" fr\"$\\theta_1 = {t1 * 1e5:.2f} \\times 10^{{-5}}$\", color=\"b\")\n",
"\n",
"plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
"\n",
"save_fig('best_fit_model_plot')\n",
"plt.show()"
]
},
{
@ -492,31 +500,8 @@
"metadata": {},
"outputs": [],
"source": [
"country_stats.plot(kind='scatter', figsize=(5,3),\n",
" x=\"GDP per capita (USD)\", y='Life satisfaction')\n",
"\n",
"X = np.linspace(min_gdp, max_gdp, 1000)\n",
"plt.plot(X, t0 + t1 * X, \"b\")\n",
"\n",
"plt.text(min_gdp + 15_000, max_life_sat - 1.5,\n",
" fr\"$\\theta_0 = {t0:.2f}$\",\n",
" fontsize=14, color=\"b\")\n",
"plt.text(min_gdp + 15_000, max_life_sat - 1,\n",
" fr\"$\\theta_1 = {t1 * 1e5:.2f} \\times 10^{{-5}}$\",\n",
" fontsize=14, color=\"b\")\n",
"\n",
"plt.plot([cyprus_gdp_per_capita, cyprus_gdp_per_capita],\n",
" [min_life_sat, cyprus_predicted_life_satisfaction], \"r--\")\n",
"plt.text(cyprus_gdp_per_capita + 1000, 5.0,\n",
" fr\"Prediction = {cyprus_predicted_life_satisfaction:.2f}\",\n",
" fontsize=14, color=\"r\")\n",
"plt.plot(cyprus_gdp_per_capita, cyprus_predicted_life_satisfaction, \"ro\")\n",
"\n",
"plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
"plt.grid(True)\n",
"\n",
"save_fig('cyprus_prediction_plot')\n",
"plt.show()"
"cyprus_gdp_per_capita = gdp_per_capita[gdppc_col].loc[\"Cyprus\"]\n",
"cyprus_gdp_per_capita"
]
},
{
@ -525,9 +510,8 @@
"metadata": {},
"outputs": [],
"source": [
"missing_data = full_country_stats[(full_country_stats[gdppc] < min_gdp) |\n",
" (full_country_stats[gdppc] > max_gdp)]\n",
"missing_data"
"cyprus_predicted_life_satisfaction = lin1.predict([[cyprus_gdp_per_capita]])[0, 0]\n",
"cyprus_predicted_life_satisfaction"
]
},
{
@ -536,17 +520,27 @@
"metadata": {},
"outputs": [],
"source": [
"position_text2 = {\n",
" \"South Africa\": (20_000, 4.2),\n",
" \"Colombia\": (6_000, 8.2),\n",
" \"Brazil\": (18_000, 7.8),\n",
" \"Mexico\": (24_000, 7.4),\n",
" \"Chile\": (30_000, 7.0),\n",
" \"Norway\": (60_000, 6.2),\n",
" \"Switzerland\": (65_000, 5.7),\n",
" \"Ireland\": (80_000, 5.5),\n",
" \"Luxembourg\": (100_000, 5.0),\n",
"}"
"country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n",
" x=gdppc_col, y=lifesat_col)\n",
"\n",
"X = np.linspace(min_gdp, max_gdp, 1000)\n",
"plt.plot(X, t0 + t1 * X, \"b\")\n",
"\n",
"plt.text(min_gdp + 15_000, max_life_sat - 1.5,\n",
" fr\"$\\theta_0 = {t0:.2f}$\", color=\"b\")\n",
"plt.text(min_gdp + 15_000, max_life_sat - 1,\n",
" fr\"$\\theta_1 = {t1 * 1e5:.2f} \\times 10^{{-5}}$\", color=\"b\")\n",
"\n",
"plt.plot([cyprus_gdp_per_capita, cyprus_gdp_per_capita],\n",
" [min_life_sat, cyprus_predicted_life_satisfaction], \"r--\")\n",
"plt.text(cyprus_gdp_per_capita + 1000, 5.0,\n",
" fr\"Prediction = {cyprus_predicted_life_satisfaction:.2f}\", color=\"r\")\n",
"plt.plot(cyprus_gdp_per_capita, cyprus_predicted_life_satisfaction, \"ro\")\n",
"\n",
"plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
"\n",
"save_fig('cyprus_prediction_plot')\n",
"plt.show()"
]
},
{
@ -555,33 +549,9 @@
"metadata": {},
"outputs": [],
"source": [
"full_country_stats.plot(kind='scatter', figsize=(8,3),\n",
" x=\"GDP per capita (USD)\", y='Life satisfaction')\n",
"\n",
"for country, pos_text in position_text2.items():\n",
" pos_data_x, pos_data_y = missing_data.loc[country]\n",
" plt.annotate(country, xy=(pos_data_x, pos_data_y), xytext=pos_text,\n",
" arrowprops=dict(facecolor='black', width=0.5, shrink=0.1,\n",
" headwidth=5))\n",
" plt.plot(pos_data_x, pos_data_y, \"rs\")\n",
"\n",
"X = np.linspace(0, 115_000, 1000)\n",
"plt.plot(X, t0 + t1 * X, \"b:\")\n",
"\n",
"lin_reg_full = linear_model.LinearRegression()\n",
"Xfull = np.c_[full_country_stats[\"GDP per capita (USD)\"]]\n",
"yfull = np.c_[full_country_stats[\"Life satisfaction\"]]\n",
"lin_reg_full.fit(Xfull, yfull)\n",
"\n",
"t0full, t1full = lin_reg_full.intercept_[0], lin_reg_full.coef_[0][0]\n",
"X = np.linspace(0, 115_000, 1000)\n",
"plt.plot(X, t0full + t1full * X, \"k\")\n",
"\n",
"plt.axis([0, 115_000, min_life_sat, max_life_sat])\n",
"plt.grid(True)\n",
"\n",
"save_fig('representative_training_data_scatterplot')\n",
"plt.show()"
"missing_data = full_country_stats[(full_country_stats[gdppc_col] < min_gdp) |\n",
" (full_country_stats[gdppc_col] > max_gdp)]\n",
"missing_data"
]
},
{
@ -589,13 +559,66 @@
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"position_text_missing_countries = {\n",
" \"South Africa\": (20_000, 4.2),\n",
" \"Colombia\": (6_000, 8.2),\n",
" \"Brazil\": (18_000, 7.8),\n",
" \"Mexico\": (24_000, 7.4),\n",
" \"Chile\": (30_000, 7.0),\n",
" \"Norway\": (51_000, 6.2),\n",
" \"Switzerland\": (62_000, 5.7),\n",
" \"Ireland\": (81_000, 5.2),\n",
" \"Luxembourg\": (92_000, 4.7),\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"full_country_stats.plot(kind='scatter', figsize=(8,3),\n",
" x=gdppc_col, y=lifesat_col, grid=True)\n",
"\n",
"for country, pos_text in position_text_missing_countries.items():\n",
" pos_data_x, pos_data_y = missing_data.loc[country]\n",
" plt.annotate(country, xy=(pos_data_x, pos_data_y),\n",
" xytext=pos_text,\n",
" arrowprops=dict(facecolor='black', width=0.5,\n",
" shrink=0.1, headwidth=5))\n",
" plt.plot(pos_data_x, pos_data_y, \"rs\")\n",
"\n",
"X = np.linspace(0, 115_000, 1000)\n",
"plt.plot(X, t0 + t1 * X, \"b:\")\n",
"\n",
"lin_reg_full = linear_model.LinearRegression()\n",
"Xfull = np.c_[full_country_stats[gdppc_col]]\n",
"yfull = np.c_[full_country_stats[lifesat_col]]\n",
"lin_reg_full.fit(Xfull, yfull)\n",
"\n",
"t0full, t1full = lin_reg_full.intercept_[0], lin_reg_full.coef_[0][0]\n",
"X = np.linspace(0, 115_000, 1000)\n",
"plt.plot(X, t0full + t1full * X, \"k\")\n",
"\n",
"plt.axis([0, 115_000, min_life_sat, max_life_sat])\n",
"\n",
"save_fig('representative_training_data_scatterplot')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import preprocessing\n",
"from sklearn import pipeline\n",
"\n",
"full_country_stats.plot(kind='scatter', figsize=(8,3),\n",
" x=\"GDP per capita (USD)\", y='Life satisfaction')\n",
"plt.axis([0, 115_000, min_life_sat, max_life_sat])\n",
" x=gdppc_col, y=lifesat_col, grid=True)\n",
"\n",
"poly = preprocessing.PolynomialFeatures(degree=10, include_bias=False)\n",
"scaler = preprocessing.StandardScaler()\n",
@ -608,64 +631,57 @@
"pipeline_reg.fit(Xfull, yfull)\n",
"curve = pipeline_reg.predict(X[:, np.newaxis])\n",
"plt.plot(X, curve)\n",
"plt.grid(True)\n",
"\n",
"plt.axis([0, 115_000, min_life_sat, max_life_sat])\n",
"\n",
"save_fig('overfitting_model_plot')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"w_countries = [c for c in full_country_stats.index if \"W\" in c.upper()]\n",
"full_country_stats.loc[w_countries][\"Life satisfaction\"]"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"all_w_countries = [c for c in gdp_per_capita_clean.index if \"W\" in c.upper()]\n",
"gdp_per_capita_clean.loc[all_w_countries].sort_values(by=gdppc)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(8,3))\n",
"\n",
"plt.xlabel(\"GDP per capita (USD)\")\n",
"plt.ylabel('Life satisfaction')\n",
"\n",
"country_stats.plot(ax=plt.gca(), kind='scatter',\n",
" x=gdppc, y='Life satisfaction')\n",
"missing_data.plot(ax=plt.gca(), kind='scatter',\n",
" x=gdppc, y='Life satisfaction', marker=\"s\", color=\"r\")\n",
"w_countries = [c for c in full_country_stats.index if \"W\" in c.upper()]\n",
"full_country_stats.loc[w_countries][lifesat_col]"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"all_w_countries = [c for c in gdp_per_capita.index if \"W\" in c.upper()]\n",
"gdp_per_capita.loc[all_w_countries].sort_values(by=gdppc_col)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"country_stats.plot(kind='scatter', x=gdppc_col, y=lifesat_col, figsize=(8,3))\n",
"missing_data.plot(kind='scatter', x=gdppc_col, y=lifesat_col,\n",
" marker=\"s\", color=\"r\", grid=True, ax=plt.gca())\n",
"\n",
"X = np.linspace(0, 115_000, 1000)\n",
"plt.plot(X, t0full + t1full * X, \"k-\", label=\"Linear model on all data\")\n",
"plt.plot(X, t0 + t1*X, \"b:\", label=\"Linear model on partial data\")\n",
"plt.plot(X, t0full + t1full * X, \"k-\", label=\"Linear model on all data\")\n",
"\n",
"ridge = linear_model.Ridge(alpha=10**9.5)\n",
"Xsample = country_stats[[\"GDP per capita (USD)\"]]\n",
"ysample = country_stats[[\"Life satisfaction\"]]\n",
"ridge.fit(Xsample, ysample)\n",
"X_sample = country_stats[[gdppc_col]]\n",
"y_sample = country_stats[[lifesat_col]]\n",
"ridge.fit(X_sample, y_sample)\n",
"t0ridge, t1ridge = ridge.intercept_[0], ridge.coef_[0][0]\n",
"plt.plot(X, t0ridge + t1ridge * X, \"b--\", label=\"Regularized linear model on partial data\")\n",
"\n",
"plt.plot(X, t0ridge + t1ridge * X, \"b--\",\n",
" label=\"Regularized linear model on partial data\")\n",
"plt.legend(loc=\"lower right\")\n",
"plt.axis([0, 115_000, 0, 10])\n",
"\n",
"plt.axis([0, 115_000, min_life_sat, max_life_sat])\n",
"plt.grid(True)\n",
"\n",
"save_fig('ridge_model_plot')\n",
"plt.show()"