Add (USD) after GDP per capita

main
Aurélien Geron 2019-07-10 17:08:12 +02:00
parent 48dfa2ce67
commit f8f2b9e4bb
1 changed files with 14 additions and 0 deletions

View File

@ -353,6 +353,7 @@
" plt.annotate(country, xy=(pos_data_x, pos_data_y), xytext=pos_text,\n", " plt.annotate(country, xy=(pos_data_x, pos_data_y), xytext=pos_text,\n",
" arrowprops=dict(facecolor='black', width=0.5, shrink=0.1, headwidth=5))\n", " arrowprops=dict(facecolor='black', width=0.5, shrink=0.1, headwidth=5))\n",
" plt.plot(pos_data_x, pos_data_y, \"ro\")\n", " plt.plot(pos_data_x, pos_data_y, \"ro\")\n",
"plt.xlabel(\"GDP per capita (USD)\")\n",
"save_fig('money_happy_scatterplot')\n", "save_fig('money_happy_scatterplot')\n",
"plt.show()" "plt.show()"
] ]
@ -384,6 +385,7 @@
"import numpy as np\n", "import numpy as np\n",
"\n", "\n",
"sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(5,3))\n", "sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(5,3))\n",
"plt.xlabel(\"GDP per capita (USD)\")\n",
"plt.axis([0, 60000, 0, 10])\n", "plt.axis([0, 60000, 0, 10])\n",
"X=np.linspace(0, 60000, 1000)\n", "X=np.linspace(0, 60000, 1000)\n",
"plt.plot(X, 2*X/100000, \"r\")\n", "plt.plot(X, 2*X/100000, \"r\")\n",
@ -421,6 +423,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(5,3))\n", "sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(5,3))\n",
"plt.xlabel(\"GDP per capita (USD)\")\n",
"plt.axis([0, 60000, 0, 10])\n", "plt.axis([0, 60000, 0, 10])\n",
"X=np.linspace(0, 60000, 1000)\n", "X=np.linspace(0, 60000, 1000)\n",
"plt.plot(X, t0 + t1*X, \"b\")\n", "plt.plot(X, t0 + t1*X, \"b\")\n",
@ -449,6 +452,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(5,3), s=1)\n", "sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(5,3), s=1)\n",
"plt.xlabel(\"GDP per capita (USD)\")\n",
"X=np.linspace(0, 60000, 1000)\n", "X=np.linspace(0, 60000, 1000)\n",
"plt.plot(X, t0 + t1*X, \"b\")\n", "plt.plot(X, t0 + t1*X, \"b\")\n",
"plt.axis([0, 60000, 0, 10])\n", "plt.axis([0, 60000, 0, 10])\n",
@ -598,6 +602,7 @@
"t0full, t1full = lin_reg_full.intercept_[0], lin_reg_full.coef_[0][0]\n", "t0full, t1full = lin_reg_full.intercept_[0], lin_reg_full.coef_[0][0]\n",
"X = np.linspace(0, 110000, 1000)\n", "X = np.linspace(0, 110000, 1000)\n",
"plt.plot(X, t0full + t1full * X, \"k\")\n", "plt.plot(X, t0full + t1full * X, \"k\")\n",
"plt.xlabel(\"GDP per capita (USD)\")\n",
"\n", "\n",
"save_fig('representative_training_data_scatterplot')\n", "save_fig('representative_training_data_scatterplot')\n",
"plt.show()" "plt.show()"
@ -623,6 +628,7 @@
"pipeline_reg.fit(Xfull, yfull)\n", "pipeline_reg.fit(Xfull, yfull)\n",
"curve = pipeline_reg.predict(X[:, np.newaxis])\n", "curve = pipeline_reg.predict(X[:, np.newaxis])\n",
"plt.plot(X, curve)\n", "plt.plot(X, curve)\n",
"plt.xlabel(\"GDP per capita (USD)\")\n",
"save_fig('overfitting_model_plot')\n", "save_fig('overfitting_model_plot')\n",
"plt.show()" "plt.show()"
] ]
@ -672,6 +678,7 @@
"\n", "\n",
"plt.legend(loc=\"lower right\")\n", "plt.legend(loc=\"lower right\")\n",
"plt.axis([0, 110000, 0, 10])\n", "plt.axis([0, 110000, 0, 10])\n",
"plt.xlabel(\"GDP per capita (USD)\")\n",
"save_fig('ridge_model_plot')\n", "save_fig('ridge_model_plot')\n",
"plt.show()" "plt.show()"
] ]
@ -726,6 +733,13 @@
"X_new = np.array([[22587.0]]) # Cyprus' GDP per capita\n", "X_new = np.array([[22587.0]]) # Cyprus' GDP per capita\n",
"print(model.predict(X_new)) # outputs [[ 5.76666667]]" "print(model.predict(X_new)) # outputs [[ 5.76666667]]"
] ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
} }
], ],
"metadata": { "metadata": {