diff --git a/01_the_machine_learning_landscape.ipynb b/01_the_machine_learning_landscape.ipynb index 618aff4..f2101b9 100644 --- a/01_the_machine_learning_landscape.ipynb +++ b/01_the_machine_learning_landscape.ipynb @@ -27,7 +27,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Code example 1-1" + "# Setup" ] }, { @@ -78,6 +78,7 @@ "%matplotlib inline\n", "import matplotlib as mpl\n", "\n", + "mpl.rc('font', size=12)\n", "mpl.rc('axes', labelsize=14)\n", "mpl.rc('xtick', labelsize=12)\n", "mpl.rc('ytick', labelsize=12)" @@ -104,13 +105,19 @@ " urllib.request.urlretrieve(url, datapath / filename)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Code example 1-1" + ] + }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ - "# Code example\n", "from pathlib import Path\n", "\n", "import matplotlib.pyplot as plt\n", @@ -124,10 +131,9 @@ "y = lifesat[[\"Life satisfaction\"]].values\n", "\n", "# Visualize the data\n", - "lifesat.plot(kind='scatter',\n", - " x=\"GDP per capita (USD)\", y='Life satisfaction')\n", + "lifesat.plot(kind='scatter', grid=True,\n", + " x=\"GDP per capita (USD)\", y=\"Life satisfaction\")\n", "plt.axis([23_500, 62_500, 4, 9])\n", - "plt.grid(True)\n", "plt.show()\n", "\n", "# Select a linear model\n", @@ -149,15 +155,17 @@ "lines:\n", "\n", "```python\n", - "import sklearn.linear_model\n", - "model = sklearn.linear_model.LinearRegression()\n", + "from sklearn.linear_model import LinearRegression\n", + "\n", + "model = LinearRegression()\n", "```\n", "\n", "with these two:\n", "\n", "```python\n", - "import sklearn.neighbors\n", - "model = sklearn.neighbors.KNeighborsRegressor(n_neighbors=3)\n", + "from sklearn.neighbors import KNeighborsRegressor\n", + "\n", + "model = KNeighborsRegressor(n_neighbors=3)\n", "```" ] }, @@ -168,9 +176,9 @@ "outputs": [], "source": [ "# Select a 3-Nearest Neighbors regression model\n", - "import sklearn.neighbors\n", + "from sklearn.neighbors import KNeighborsRegressor\n", "\n", - "model = sklearn.neighbors.KNeighborsRegressor(n_neighbors=3)\n", + "model = KNeighborsRegressor(n_neighbors=3)\n", "\n", "# Train the model\n", "model.fit(X,y)\n", @@ -274,7 +282,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "This function just merges the OECD's life satisfaction data and the World Bank's GDP per capita data:" + "Preprocess the GDP per capita data to keep only the year 2020:" ] }, { @@ -283,19 +291,23 @@ "metadata": {}, "outputs": [], "source": [ - "def prepare_country_stats(oecd_bli, gdp_per_capita):\n", - " gdp_year = 2020\n", - " gdppc = \"GDP per capita (USD)\"\n", - " oecd_bli = oecd_bli[oecd_bli[\"INEQUALITY\"]==\"TOT\"]\n", - " oecd_bli = oecd_bli.pivot(index=\"Country\", columns=\"Indicator\", values=\"Value\")\n", - " gdp_per_capita = gdp_per_capita[gdp_per_capita[\"Year\"] == gdp_year]\n", - " gdp_per_capita = gdp_per_capita.drop([\"Code\", \"Year\"], axis=1)\n", - " gdp_per_capita.columns = [\"Country\", gdppc]\n", - " gdp_per_capita.set_index(\"Country\", inplace=True)\n", - " full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita,\n", - " left_index=True, right_index=True)\n", - " full_country_stats.sort_values(by=gdppc, inplace=True)\n", - " return full_country_stats[[gdppc, 'Life satisfaction']]" + "gdp_year = 2020\n", + "gdppc_col = \"GDP per capita (USD)\"\n", + "lifesat_col = \"Life satisfaction\"\n", + "\n", + "gdp_per_capita = gdp_per_capita[gdp_per_capita[\"Year\"] == gdp_year]\n", + "gdp_per_capita = gdp_per_capita.drop([\"Code\", \"Year\"], axis=1)\n", + "gdp_per_capita.columns = [\"Country\", gdppc_col]\n", + "gdp_per_capita.set_index(\"Country\", inplace=True)\n", + "\n", + "gdp_per_capita.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Preprocess the OECD BLI data to keep only the `Life satisfaction` column:" ] }, { @@ -304,8 +316,31 @@ "metadata": {}, "outputs": [], "source": [ - "full_country_stats = prepare_country_stats(oecd_bli, gdp_per_capita)\n", - "full_country_stats.to_csv(datapath / \"lifesat_full.csv\")" + "oecd_bli = oecd_bli[oecd_bli[\"INEQUALITY\"]==\"TOT\"]\n", + "oecd_bli = oecd_bli.pivot(index=\"Country\", columns=\"Indicator\", values=\"Value\")\n", + "\n", + "oecd_bli.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now let's merge the life satisfaction data and the GDP per capita data, keeping only the GDP per capita and Life satisfaction columns:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita,\n", + " left_index=True, right_index=True)\n", + "full_country_stats.sort_values(by=gdppc_col, inplace=True)\n", + "full_country_stats = full_country_stats[[gdppc_col, lifesat_col]]\n", + "\n", + "full_country_stats.head()" ] }, { @@ -315,56 +350,18 @@ "To illustrate the risk of overfitting, I use only part of the data in most figures (all countries with a GDP per capita between `min_gdp` and `max_gdp`). Later in the chapter I reveal the missing countries, and show that they don't follow the same linear trend at all." ] }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "gdppc = \"GDP per capita (USD)\"\n", - "min_gdp = 23_500\n", - "max_gdp = 62_500\n", - "country_stats = full_country_stats[(full_country_stats[gdppc] >= min_gdp) &\n", - " (full_country_stats[gdppc] <= max_gdp)]\n", - "country_stats.to_csv(datapath / \"lifesat.csv\")\n", - "country_stats.head()" - ] - }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ - "country_stats.plot(kind='scatter', figsize=(5,3),\n", - " x=\"GDP per capita (USD)\", y='Life satisfaction')\n", + "min_gdp = 23_500\n", + "max_gdp = 62_500\n", "\n", - "min_life_sat = 4\n", - "max_life_sat = 9\n", - "\n", - "plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n", - "position_text = {\n", - " \"Hungary\": (28_000, 4.2),\n", - " \"France\": (40_000, 5),\n", - " \"New Zealand\": (30_000, 8),\n", - " \"Australia\": (50_000, 5.5),\n", - " \"United States\": (59_000, 5.5),\n", - " \"Denmark\": (46_000, 8.5)\n", - "}\n", - "\n", - "for country, pos_text in position_text.items():\n", - " pos_data_x, pos_data_y = country_stats[[\"GDP per capita (USD)\",\n", - " \"Life satisfaction\"]].loc[country]\n", - " country = \"U.S.\" if country == \"United States\" else country\n", - " plt.annotate(country, xy=(pos_data_x, pos_data_y), xytext=pos_text,\n", - " arrowprops=dict(facecolor='black', width=0.5, shrink=0.2,\n", - " headwidth=5))\n", - " plt.plot(pos_data_x, pos_data_y, \"ro\")\n", - "\n", - "plt.grid(True)\n", - "\n", - "save_fig('money_happy_scatterplot')\n", - "plt.show()" + "country_stats = full_country_stats[(full_country_stats[gdppc_col] >= min_gdp) &\n", + " (full_country_stats[gdppc_col] <= max_gdp)]\n", + "country_stats.head()" ] }, { @@ -373,8 +370,8 @@ "metadata": {}, "outputs": [], "source": [ - "highlighted_countries = country_stats.loc[list(position_text.keys())]\n", - "highlighted_countries[[gdppc, \"Life satisfaction\"]].sort_values(by=gdppc)" + "country_stats.to_csv(datapath / \"lifesat.csv\")\n", + "full_country_stats.to_csv(datapath / \"lifesat_full.csv\")" ] }, { @@ -383,37 +380,34 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", + "country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n", + " x=gdppc_col, y=lifesat_col)\n", + "\n", + "min_life_sat = 4\n", + "max_life_sat = 9\n", + "\n", + "position_text = {\n", + " \"Hungary\": (28_000, 4.2),\n", + " \"France\": (40_000, 5),\n", + " \"New Zealand\": (28_000, 8.2),\n", + " \"Australia\": (50_000, 5.5),\n", + " \"United States\": (59_000, 5.5),\n", + " \"Denmark\": (46_000, 8.5)\n", + "}\n", + "\n", + "for country, pos_text in position_text.items():\n", + " pos_data_x = country_stats[gdppc_col].loc[country]\n", + " pos_data_y = country_stats[lifesat_col].loc[country]\n", + " country = \"U.S.\" if country == \"United States\" else country\n", + " plt.annotate(country, xy=(pos_data_x, pos_data_y),\n", + " xytext=pos_text,\n", + " arrowprops=dict(facecolor='black', width=0.5,\n", + " shrink=0.15, headwidth=5))\n", + " plt.plot(pos_data_x, pos_data_y, \"ro\")\n", "\n", - "country_stats.plot(kind='scatter', figsize=(5,3),\n", - " x=\"GDP per capita (USD)\", y='Life satisfaction')\n", "plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n", "\n", - "X = np.linspace(min_gdp, max_gdp, 1000)\n", - "\n", - "w1, w2 = 4.2, 0\n", - "plt.plot(X, w1 + w2 * 1e-5 * X, \"r\")\n", - "plt.text(40_000, 4.9, fr\"$\\theta_0 = {w1}$\",\n", - " fontsize=14, color=\"r\")\n", - "plt.text(40_000, 4.4, fr\"$\\theta_1 = {w2}$\",\n", - " fontsize=14, color=\"r\")\n", - "\n", - "w1, w2 = 10, -9\n", - "plt.plot(X, w1 + w2 * 1e-5 * X, \"g\")\n", - "plt.text(26_000, 8.5, fr\"$\\theta_0 = {w1}$\",\n", - " fontsize=14, color=\"g\")\n", - "plt.text(26_000, 8.0, fr\"$\\theta_1 = {w2} \\times 10^{{-5}}$\",\n", - " fontsize=14, color=\"g\")\n", - "\n", - "w1, w2 = 3, 8\n", - "plt.plot(X, w1 + w2 * 1e-5 * X, \"b\")\n", - "plt.text(48_000, 8.5, fr\"$\\theta_0 = {w1}$\",\n", - " fontsize=14, color=\"b\")\n", - "plt.text(48_000, 8.0, fr\"$\\theta_1 = {w2} \\times 10^{{-5}}$\",\n", - " fontsize=14, color=\"b\")\n", - "plt.grid(True)\n", - "\n", - "save_fig('tweaking_model_params_plot')\n", + "save_fig('money_happy_scatterplot')\n", "plt.show()" ] }, @@ -423,16 +417,8 @@ "metadata": {}, "outputs": [], "source": [ - "from sklearn import linear_model\n", - "\n", - "X_sample = country_stats[[\"GDP per capita (USD)\"]].values\n", - "y_sample = country_stats[[\"Life satisfaction\"]].values\n", - "\n", - "lin1 = linear_model.LinearRegression()\n", - "lin1.fit(X_sample, y_sample)\n", - "\n", - "t0, t1 = lin1.intercept_[0], lin1.coef_[0][0]\n", - "t0, t1" + "highlighted_countries = country_stats.loc[list(position_text.keys())]\n", + "highlighted_countries[[gdppc_col, lifesat_col]].sort_values(by=gdppc_col)" ] }, { @@ -441,23 +427,29 @@ "metadata": {}, "outputs": [], "source": [ - "country_stats.plot(kind='scatter', figsize=(5,3),\n", - " x=\"GDP per capita (USD)\", y='Life satisfaction')\n", - "plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n", + "country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n", + " x=gdppc_col, y=lifesat_col)\n", "\n", "X = np.linspace(min_gdp, max_gdp, 1000)\n", - "plt.plot(X, t0 + t1 * X, \"b\")\n", "\n", - "plt.text(max_gdp - 20_000, min_life_sat + 1.5,\n", - " fr\"$\\theta_0 = {t0:.2f}$\",\n", - " fontsize=14, color=\"b\")\n", - "plt.text(max_gdp - 20_000, min_life_sat + 1,\n", - " fr\"$\\theta_1 = {t1 * 1e5:.2f} \\times 10^{{-5}}$\",\n", - " fontsize=14, color=\"b\")\n", + "w1, w2 = 4.2, 0\n", + "plt.plot(X, w1 + w2 * 1e-5 * X, \"r\")\n", + "plt.text(40_000, 4.9, fr\"$\\theta_0 = {w1}$\", color=\"r\")\n", + "plt.text(40_000, 4.4, fr\"$\\theta_1 = {w2}$\", color=\"r\")\n", + "\n", + "w1, w2 = 10, -9\n", + "plt.plot(X, w1 + w2 * 1e-5 * X, \"g\")\n", + "plt.text(26_000, 8.5, fr\"$\\theta_0 = {w1}$\", color=\"g\")\n", + "plt.text(26_000, 8.0, fr\"$\\theta_1 = {w2} \\times 10^{{-5}}$\", color=\"g\")\n", + "\n", + "w1, w2 = 3, 8\n", + "plt.plot(X, w1 + w2 * 1e-5 * X, \"b\")\n", + "plt.text(48_000, 8.5, fr\"$\\theta_0 = {w1}$\", color=\"b\")\n", + "plt.text(48_000, 8.0, fr\"$\\theta_1 = {w2} \\times 10^{{-5}}$\", color=\"b\")\n", + "\n", "plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n", - "plt.grid(True)\n", "\n", - "save_fig('best_fit_model_plot')\n", + "save_fig('tweaking_model_params_plot')\n", "plt.show()" ] }, @@ -467,11 +459,16 @@ "metadata": {}, "outputs": [], "source": [ - "gdp_year = 2020\n", - "gdp_per_capita_clean = gdp_per_capita[gdp_per_capita[\"Year\"] == gdp_year]\n", - "gdp_per_capita_clean = gdp_per_capita_clean.drop([\"Code\", \"Year\"], axis=1)\n", - "gdp_per_capita_clean.columns = [\"Country\", \"GDP per capita (USD)\"]\n", - "gdp_per_capita_clean.set_index(\"Country\", inplace=True)" + "from sklearn import linear_model\n", + "\n", + "X_sample = country_stats[[gdppc_col]].values\n", + "y_sample = country_stats[[lifesat_col]].values\n", + "\n", + "lin1 = linear_model.LinearRegression()\n", + "lin1.fit(X_sample, y_sample)\n", + "\n", + "t0, t1 = lin1.intercept_[0], lin1.coef_[0][0]\n", + "print(f\"θ0={t0:.2f}, θ1={t1:.2e}\")" ] }, { @@ -480,10 +477,21 @@ "metadata": {}, "outputs": [], "source": [ - "cyprus_gdp_per_capita = gdp_per_capita_clean.loc[\"Cyprus\"][\"GDP per capita (USD)\"]\n", - "print(cyprus_gdp_per_capita)\n", - "cyprus_predicted_life_satisfaction = lin1.predict([[cyprus_gdp_per_capita]])[0, 0]\n", - "cyprus_predicted_life_satisfaction" + "country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n", + " x=gdppc_col, y=lifesat_col)\n", + "\n", + "X = np.linspace(min_gdp, max_gdp, 1000)\n", + "plt.plot(X, t0 + t1 * X, \"b\")\n", + "\n", + "plt.text(max_gdp - 20_000, min_life_sat + 1.5,\n", + " fr\"$\\theta_0 = {t0:.2f}$\", color=\"b\")\n", + "plt.text(max_gdp - 20_000, min_life_sat + 1,\n", + " fr\"$\\theta_1 = {t1 * 1e5:.2f} \\times 10^{{-5}}$\", color=\"b\")\n", + "\n", + "plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n", + "\n", + "save_fig('best_fit_model_plot')\n", + "plt.show()" ] }, { @@ -492,31 +500,8 @@ "metadata": {}, "outputs": [], "source": [ - "country_stats.plot(kind='scatter', figsize=(5,3),\n", - " x=\"GDP per capita (USD)\", y='Life satisfaction')\n", - "\n", - "X = np.linspace(min_gdp, max_gdp, 1000)\n", - "plt.plot(X, t0 + t1 * X, \"b\")\n", - "\n", - "plt.text(min_gdp + 15_000, max_life_sat - 1.5,\n", - " fr\"$\\theta_0 = {t0:.2f}$\",\n", - " fontsize=14, color=\"b\")\n", - "plt.text(min_gdp + 15_000, max_life_sat - 1,\n", - " fr\"$\\theta_1 = {t1 * 1e5:.2f} \\times 10^{{-5}}$\",\n", - " fontsize=14, color=\"b\")\n", - "\n", - "plt.plot([cyprus_gdp_per_capita, cyprus_gdp_per_capita],\n", - " [min_life_sat, cyprus_predicted_life_satisfaction], \"r--\")\n", - "plt.text(cyprus_gdp_per_capita + 1000, 5.0,\n", - " fr\"Prediction = {cyprus_predicted_life_satisfaction:.2f}\",\n", - " fontsize=14, color=\"r\")\n", - "plt.plot(cyprus_gdp_per_capita, cyprus_predicted_life_satisfaction, \"ro\")\n", - "\n", - "plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n", - "plt.grid(True)\n", - "\n", - "save_fig('cyprus_prediction_plot')\n", - "plt.show()" + "cyprus_gdp_per_capita = gdp_per_capita[gdppc_col].loc[\"Cyprus\"]\n", + "cyprus_gdp_per_capita" ] }, { @@ -525,9 +510,8 @@ "metadata": {}, "outputs": [], "source": [ - "missing_data = full_country_stats[(full_country_stats[gdppc] < min_gdp) |\n", - " (full_country_stats[gdppc] > max_gdp)]\n", - "missing_data" + "cyprus_predicted_life_satisfaction = lin1.predict([[cyprus_gdp_per_capita]])[0, 0]\n", + "cyprus_predicted_life_satisfaction" ] }, { @@ -536,17 +520,27 @@ "metadata": {}, "outputs": [], "source": [ - "position_text2 = {\n", - " \"South Africa\": (20_000, 4.2),\n", - " \"Colombia\": (6_000, 8.2),\n", - " \"Brazil\": (18_000, 7.8),\n", - " \"Mexico\": (24_000, 7.4),\n", - " \"Chile\": (30_000, 7.0),\n", - " \"Norway\": (60_000, 6.2),\n", - " \"Switzerland\": (65_000, 5.7),\n", - " \"Ireland\": (80_000, 5.5),\n", - " \"Luxembourg\": (100_000, 5.0),\n", - "}" + "country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n", + " x=gdppc_col, y=lifesat_col)\n", + "\n", + "X = np.linspace(min_gdp, max_gdp, 1000)\n", + "plt.plot(X, t0 + t1 * X, \"b\")\n", + "\n", + "plt.text(min_gdp + 15_000, max_life_sat - 1.5,\n", + " fr\"$\\theta_0 = {t0:.2f}$\", color=\"b\")\n", + "plt.text(min_gdp + 15_000, max_life_sat - 1,\n", + " fr\"$\\theta_1 = {t1 * 1e5:.2f} \\times 10^{{-5}}$\", color=\"b\")\n", + "\n", + "plt.plot([cyprus_gdp_per_capita, cyprus_gdp_per_capita],\n", + " [min_life_sat, cyprus_predicted_life_satisfaction], \"r--\")\n", + "plt.text(cyprus_gdp_per_capita + 1000, 5.0,\n", + " fr\"Prediction = {cyprus_predicted_life_satisfaction:.2f}\", color=\"r\")\n", + "plt.plot(cyprus_gdp_per_capita, cyprus_predicted_life_satisfaction, \"ro\")\n", + "\n", + "plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n", + "\n", + "save_fig('cyprus_prediction_plot')\n", + "plt.show()" ] }, { @@ -555,33 +549,9 @@ "metadata": {}, "outputs": [], "source": [ - "full_country_stats.plot(kind='scatter', figsize=(8,3),\n", - " x=\"GDP per capita (USD)\", y='Life satisfaction')\n", - "\n", - "for country, pos_text in position_text2.items():\n", - " pos_data_x, pos_data_y = missing_data.loc[country]\n", - " plt.annotate(country, xy=(pos_data_x, pos_data_y), xytext=pos_text,\n", - " arrowprops=dict(facecolor='black', width=0.5, shrink=0.1,\n", - " headwidth=5))\n", - " plt.plot(pos_data_x, pos_data_y, \"rs\")\n", - "\n", - "X = np.linspace(0, 115_000, 1000)\n", - "plt.plot(X, t0 + t1 * X, \"b:\")\n", - "\n", - "lin_reg_full = linear_model.LinearRegression()\n", - "Xfull = np.c_[full_country_stats[\"GDP per capita (USD)\"]]\n", - "yfull = np.c_[full_country_stats[\"Life satisfaction\"]]\n", - "lin_reg_full.fit(Xfull, yfull)\n", - "\n", - "t0full, t1full = lin_reg_full.intercept_[0], lin_reg_full.coef_[0][0]\n", - "X = np.linspace(0, 115_000, 1000)\n", - "plt.plot(X, t0full + t1full * X, \"k\")\n", - "\n", - "plt.axis([0, 115_000, min_life_sat, max_life_sat])\n", - "plt.grid(True)\n", - "\n", - "save_fig('representative_training_data_scatterplot')\n", - "plt.show()" + "missing_data = full_country_stats[(full_country_stats[gdppc_col] < min_gdp) |\n", + " (full_country_stats[gdppc_col] > max_gdp)]\n", + "missing_data" ] }, { @@ -589,13 +559,66 @@ "execution_count": 25, "metadata": {}, "outputs": [], + "source": [ + "position_text_missing_countries = {\n", + " \"South Africa\": (20_000, 4.2),\n", + " \"Colombia\": (6_000, 8.2),\n", + " \"Brazil\": (18_000, 7.8),\n", + " \"Mexico\": (24_000, 7.4),\n", + " \"Chile\": (30_000, 7.0),\n", + " \"Norway\": (51_000, 6.2),\n", + " \"Switzerland\": (62_000, 5.7),\n", + " \"Ireland\": (81_000, 5.2),\n", + " \"Luxembourg\": (92_000, 4.7),\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "full_country_stats.plot(kind='scatter', figsize=(8,3),\n", + " x=gdppc_col, y=lifesat_col, grid=True)\n", + "\n", + "for country, pos_text in position_text_missing_countries.items():\n", + " pos_data_x, pos_data_y = missing_data.loc[country]\n", + " plt.annotate(country, xy=(pos_data_x, pos_data_y),\n", + " xytext=pos_text,\n", + " arrowprops=dict(facecolor='black', width=0.5,\n", + " shrink=0.1, headwidth=5))\n", + " plt.plot(pos_data_x, pos_data_y, \"rs\")\n", + "\n", + "X = np.linspace(0, 115_000, 1000)\n", + "plt.plot(X, t0 + t1 * X, \"b:\")\n", + "\n", + "lin_reg_full = linear_model.LinearRegression()\n", + "Xfull = np.c_[full_country_stats[gdppc_col]]\n", + "yfull = np.c_[full_country_stats[lifesat_col]]\n", + "lin_reg_full.fit(Xfull, yfull)\n", + "\n", + "t0full, t1full = lin_reg_full.intercept_[0], lin_reg_full.coef_[0][0]\n", + "X = np.linspace(0, 115_000, 1000)\n", + "plt.plot(X, t0full + t1full * X, \"k\")\n", + "\n", + "plt.axis([0, 115_000, min_life_sat, max_life_sat])\n", + "\n", + "save_fig('representative_training_data_scatterplot')\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], "source": [ "from sklearn import preprocessing\n", "from sklearn import pipeline\n", "\n", "full_country_stats.plot(kind='scatter', figsize=(8,3),\n", - " x=\"GDP per capita (USD)\", y='Life satisfaction')\n", - "plt.axis([0, 115_000, min_life_sat, max_life_sat])\n", + " x=gdppc_col, y=lifesat_col, grid=True)\n", "\n", "poly = preprocessing.PolynomialFeatures(degree=10, include_bias=False)\n", "scaler = preprocessing.StandardScaler()\n", @@ -608,64 +631,57 @@ "pipeline_reg.fit(Xfull, yfull)\n", "curve = pipeline_reg.predict(X[:, np.newaxis])\n", "plt.plot(X, curve)\n", - "plt.grid(True)\n", + "\n", + "plt.axis([0, 115_000, min_life_sat, max_life_sat])\n", "\n", "save_fig('overfitting_model_plot')\n", "plt.show()" ] }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [], - "source": [ - "w_countries = [c for c in full_country_stats.index if \"W\" in c.upper()]\n", - "full_country_stats.loc[w_countries][\"Life satisfaction\"]" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [], - "source": [ - "all_w_countries = [c for c in gdp_per_capita_clean.index if \"W\" in c.upper()]\n", - "gdp_per_capita_clean.loc[all_w_countries].sort_values(by=gdppc)" - ] - }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ - "plt.figure(figsize=(8,3))\n", - "\n", - "plt.xlabel(\"GDP per capita (USD)\")\n", - "plt.ylabel('Life satisfaction')\n", - "\n", - "country_stats.plot(ax=plt.gca(), kind='scatter',\n", - " x=gdppc, y='Life satisfaction')\n", - "missing_data.plot(ax=plt.gca(), kind='scatter',\n", - " x=gdppc, y='Life satisfaction', marker=\"s\", color=\"r\")\n", + "w_countries = [c for c in full_country_stats.index if \"W\" in c.upper()]\n", + "full_country_stats.loc[w_countries][lifesat_col]" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "all_w_countries = [c for c in gdp_per_capita.index if \"W\" in c.upper()]\n", + "gdp_per_capita.loc[all_w_countries].sort_values(by=gdppc_col)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "country_stats.plot(kind='scatter', x=gdppc_col, y=lifesat_col, figsize=(8,3))\n", + "missing_data.plot(kind='scatter', x=gdppc_col, y=lifesat_col,\n", + " marker=\"s\", color=\"r\", grid=True, ax=plt.gca())\n", "\n", "X = np.linspace(0, 115_000, 1000)\n", - "plt.plot(X, t0full + t1full * X, \"k-\", label=\"Linear model on all data\")\n", "plt.plot(X, t0 + t1*X, \"b:\", label=\"Linear model on partial data\")\n", + "plt.plot(X, t0full + t1full * X, \"k-\", label=\"Linear model on all data\")\n", "\n", "ridge = linear_model.Ridge(alpha=10**9.5)\n", - "Xsample = country_stats[[\"GDP per capita (USD)\"]]\n", - "ysample = country_stats[[\"Life satisfaction\"]]\n", - "ridge.fit(Xsample, ysample)\n", + "X_sample = country_stats[[gdppc_col]]\n", + "y_sample = country_stats[[lifesat_col]]\n", + "ridge.fit(X_sample, y_sample)\n", "t0ridge, t1ridge = ridge.intercept_[0], ridge.coef_[0][0]\n", - "plt.plot(X, t0ridge + t1ridge * X, \"b--\", label=\"Regularized linear model on partial data\")\n", - "\n", + "plt.plot(X, t0ridge + t1ridge * X, \"b--\",\n", + " label=\"Regularized linear model on partial data\")\n", "plt.legend(loc=\"lower right\")\n", - "plt.axis([0, 115_000, 0, 10])\n", "\n", "plt.axis([0, 115_000, min_life_sat, max_life_sat])\n", - "plt.grid(True)\n", "\n", "save_fig('ridge_model_plot')\n", "plt.show()"