Improve figure ridge_model_plot, and run the first cell again with lifesat.csv now partial

main
Aurélien Geron 2021-10-19 13:24:41 +13:00
parent 84f173b600
commit 4a2d0ea1ae
1 changed files with 6 additions and 7 deletions

View File

@ -645,13 +645,13 @@
"plt.xlabel(\"GDP per capita (USD)\")\n", "plt.xlabel(\"GDP per capita (USD)\")\n",
"plt.ylabel('Life satisfaction')\n", "plt.ylabel('Life satisfaction')\n",
"\n", "\n",
"plt.plot(list(country_stats[\"GDP per capita (USD)\"]),\n", "country_stats.plot(ax=plt.gca(), kind='scatter',\n",
" list(country_stats[\"Life satisfaction\"]), \"bo\")\n", " x=gdppc, y='Life satisfaction')\n",
"plt.plot(list(missing_data[\"GDP per capita (USD)\"]),\n", "missing_data.plot(ax=plt.gca(), kind='scatter',\n",
" list(missing_data[\"Life satisfaction\"]), \"rs\")\n", " x=gdppc, y='Life satisfaction', marker=\"s\", color=\"r\")\n",
"\n", "\n",
"X = np.linspace(0, 115_000, 1000)\n", "X = np.linspace(0, 115_000, 1000)\n",
"plt.plot(X, t0full + t1full * X, \"r--\", label=\"Linear model on all data\")\n", "plt.plot(X, t0full + t1full * X, \"k-\", label=\"Linear model on all data\")\n",
"plt.plot(X, t0 + t1*X, \"b:\", label=\"Linear model on partial data\")\n", "plt.plot(X, t0 + t1*X, \"b:\", label=\"Linear model on partial data\")\n",
"\n", "\n",
"ridge = linear_model.Ridge(alpha=10**9.5)\n", "ridge = linear_model.Ridge(alpha=10**9.5)\n",
@ -659,11 +659,10 @@
"ysample = country_stats[[\"Life satisfaction\"]]\n", "ysample = country_stats[[\"Life satisfaction\"]]\n",
"ridge.fit(Xsample, ysample)\n", "ridge.fit(Xsample, ysample)\n",
"t0ridge, t1ridge = ridge.intercept_[0], ridge.coef_[0][0]\n", "t0ridge, t1ridge = ridge.intercept_[0], ridge.coef_[0][0]\n",
"plt.plot(X, t0ridge + t1ridge * X, \"b\", label=\"Regularized linear model on partial data\")\n", "plt.plot(X, t0ridge + t1ridge * X, \"b--\", label=\"Regularized linear model on partial data\")\n",
"\n", "\n",
"plt.legend(loc=\"lower right\")\n", "plt.legend(loc=\"lower right\")\n",
"plt.axis([0, 115_000, 0, 10])\n", "plt.axis([0, 115_000, 0, 10])\n",
"plt.xlabel(\"GDP per capita (USD)\")\n",
"\n", "\n",
"plt.axis([0, 115_000, min_life_sat, max_life_sat])\n", "plt.axis([0, 115_000, min_life_sat, max_life_sat])\n",
"plt.grid(True)\n", "plt.grid(True)\n",