From 4a2d0ea1aed2ab5215ae7217b1942ecc405cbfb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Tue, 19 Oct 2021 13:24:41 +1300 Subject: [PATCH] Improve figure ridge_model_plot, and run the first cell again with lifesat.csv now partial --- 01_the_machine_learning_landscape.ipynb | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/01_the_machine_learning_landscape.ipynb b/01_the_machine_learning_landscape.ipynb index 68a6720..618aff4 100644 --- a/01_the_machine_learning_landscape.ipynb +++ b/01_the_machine_learning_landscape.ipynb @@ -645,13 +645,13 @@ "plt.xlabel(\"GDP per capita (USD)\")\n", "plt.ylabel('Life satisfaction')\n", "\n", - "plt.plot(list(country_stats[\"GDP per capita (USD)\"]),\n", - " list(country_stats[\"Life satisfaction\"]), \"bo\")\n", - "plt.plot(list(missing_data[\"GDP per capita (USD)\"]),\n", - " list(missing_data[\"Life satisfaction\"]), \"rs\")\n", + "country_stats.plot(ax=plt.gca(), kind='scatter',\n", + " x=gdppc, y='Life satisfaction')\n", + "missing_data.plot(ax=plt.gca(), kind='scatter',\n", + " x=gdppc, y='Life satisfaction', marker=\"s\", color=\"r\")\n", "\n", "X = np.linspace(0, 115_000, 1000)\n", - "plt.plot(X, t0full + t1full * X, \"r--\", label=\"Linear model on all data\")\n", + "plt.plot(X, t0full + t1full * X, \"k-\", label=\"Linear model on all data\")\n", "plt.plot(X, t0 + t1*X, \"b:\", label=\"Linear model on partial data\")\n", "\n", "ridge = linear_model.Ridge(alpha=10**9.5)\n", @@ -659,11 +659,10 @@ "ysample = country_stats[[\"Life satisfaction\"]]\n", "ridge.fit(Xsample, ysample)\n", "t0ridge, t1ridge = ridge.intercept_[0], ridge.coef_[0][0]\n", - "plt.plot(X, t0ridge + t1ridge * X, \"b\", label=\"Regularized linear model on partial data\")\n", + "plt.plot(X, t0ridge + t1ridge * X, \"b--\", label=\"Regularized linear model on partial data\")\n", "\n", "plt.legend(loc=\"lower right\")\n", "plt.axis([0, 115_000, 0, 10])\n", - "plt.xlabel(\"GDP per capita (USD)\")\n", "\n", "plt.axis([0, 115_000, min_life_sat, max_life_sat])\n", "plt.grid(True)\n",