Update plot options

main
Aurélien Geron 2022-02-19 18:18:08 +13:00
parent b63019fd28
commit 4ba9496a87
1 changed files with 15 additions and 15 deletions

View File

@ -104,11 +104,11 @@
"source": [ "source": [
"import matplotlib.pyplot as plt\n", "import matplotlib.pyplot as plt\n",
"\n", "\n",
"plt.rc('font', size=14)\n", "plt.rc('font', size=12)\n",
"plt.rc('axes', labelsize=14, titlesize=14)\n", "plt.rc('axes', labelsize=14, titlesize=14)\n",
"plt.rc('legend', fontsize=14)\n", "plt.rc('legend', fontsize=12)\n",
"plt.rc('xtick',labelsize=10)\n", "plt.rc('xtick', labelsize=10)\n",
"plt.rc('ytick',labelsize=10)" "plt.rc('ytick', labelsize=10)"
] ]
}, },
{ {
@ -214,7 +214,7 @@
"model = KNeighborsRegressor(n_neighbors=3)\n", "model = KNeighborsRegressor(n_neighbors=3)\n",
"\n", "\n",
"# Train the model\n", "# Train the model\n",
"model.fit(X,y)\n", "model.fit(X, y)\n",
"\n", "\n",
"# Make a prediction for Cyprus\n", "# Make a prediction for Cyprus\n",
"print(model.predict(X_new)) # outputs [[6.33333333]]\n" "print(model.predict(X_new)) # outputs [[6.33333333]]\n"
@ -399,7 +399,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n", "country_stats.plot(kind='scatter', figsize=(5, 3), grid=True,\n",
" x=gdppc_col, y=lifesat_col)\n", " x=gdppc_col, y=lifesat_col)\n",
"\n", "\n",
"min_life_sat = 4\n", "min_life_sat = 4\n",
@ -422,7 +422,7 @@
" plt.annotate(country, xy=(pos_data_x, pos_data_y),\n", " plt.annotate(country, xy=(pos_data_x, pos_data_y),\n",
" xytext=pos_text, fontsize=12,\n", " xytext=pos_text, fontsize=12,\n",
" arrowprops=dict(facecolor='black', width=0.5,\n", " arrowprops=dict(facecolor='black', width=0.5,\n",
" shrink=0.15, headwidth=5))\n", " shrink=0.08, headwidth=5))\n",
" plt.plot(pos_data_x, pos_data_y, \"ro\")\n", " plt.plot(pos_data_x, pos_data_y, \"ro\")\n",
"\n", "\n",
"plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n", "plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
@ -447,7 +447,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n", "country_stats.plot(kind='scatter', figsize=(5, 3), grid=True,\n",
" x=gdppc_col, y=lifesat_col)\n", " x=gdppc_col, y=lifesat_col)\n",
"\n", "\n",
"X = np.linspace(min_gdp, max_gdp, 1000)\n", "X = np.linspace(min_gdp, max_gdp, 1000)\n",
@ -497,7 +497,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n", "country_stats.plot(kind='scatter', figsize=(5, 3), grid=True,\n",
" x=gdppc_col, y=lifesat_col)\n", " x=gdppc_col, y=lifesat_col)\n",
"\n", "\n",
"X = np.linspace(min_gdp, max_gdp, 1000)\n", "X = np.linspace(min_gdp, max_gdp, 1000)\n",
@ -540,7 +540,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n", "country_stats.plot(kind='scatter', figsize=(5, 3), grid=True,\n",
" x=gdppc_col, y=lifesat_col)\n", " x=gdppc_col, y=lifesat_col)\n",
"\n", "\n",
"X = np.linspace(min_gdp, max_gdp, 1000)\n", "X = np.linspace(min_gdp, max_gdp, 1000)\n",
@ -598,7 +598,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"full_country_stats.plot(kind='scatter', figsize=(8,3),\n", "full_country_stats.plot(kind='scatter', figsize=(8, 3),\n",
" x=gdppc_col, y=lifesat_col, grid=True)\n", " x=gdppc_col, y=lifesat_col, grid=True)\n",
"\n", "\n",
"for country, pos_text in position_text_missing_countries.items():\n", "for country, pos_text in position_text_missing_countries.items():\n",
@ -606,7 +606,7 @@
" plt.annotate(country, xy=(pos_data_x, pos_data_y),\n", " plt.annotate(country, xy=(pos_data_x, pos_data_y),\n",
" xytext=pos_text, fontsize=12,\n", " xytext=pos_text, fontsize=12,\n",
" arrowprops=dict(facecolor='black', width=0.5,\n", " arrowprops=dict(facecolor='black', width=0.5,\n",
" shrink=0.1, headwidth=5))\n", " shrink=0.08, headwidth=5))\n",
" plt.plot(pos_data_x, pos_data_y, \"rs\")\n", " plt.plot(pos_data_x, pos_data_y, \"rs\")\n",
"\n", "\n",
"X = np.linspace(0, 115_000, 1000)\n", "X = np.linspace(0, 115_000, 1000)\n",
@ -636,7 +636,7 @@
"from sklearn import preprocessing\n", "from sklearn import preprocessing\n",
"from sklearn import pipeline\n", "from sklearn import pipeline\n",
"\n", "\n",
"full_country_stats.plot(kind='scatter', figsize=(8,3),\n", "full_country_stats.plot(kind='scatter', figsize=(8, 3),\n",
" x=gdppc_col, y=lifesat_col, grid=True)\n", " x=gdppc_col, y=lifesat_col, grid=True)\n",
"\n", "\n",
"poly = preprocessing.PolynomialFeatures(degree=10, include_bias=False)\n", "poly = preprocessing.PolynomialFeatures(degree=10, include_bias=False)\n",
@ -683,7 +683,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"country_stats.plot(kind='scatter', x=gdppc_col, y=lifesat_col, figsize=(8,3))\n", "country_stats.plot(kind='scatter', x=gdppc_col, y=lifesat_col, figsize=(8, 3))\n",
"missing_data.plot(kind='scatter', x=gdppc_col, y=lifesat_col,\n", "missing_data.plot(kind='scatter', x=gdppc_col, y=lifesat_col,\n",
" marker=\"s\", color=\"r\", grid=True, ax=plt.gca())\n", " marker=\"s\", color=\"r\", grid=True, ax=plt.gca())\n",
"\n", "\n",
@ -698,7 +698,7 @@
"t0ridge, t1ridge = ridge.intercept_[0], ridge.coef_[0][0]\n", "t0ridge, t1ridge = ridge.intercept_[0], ridge.coef_[0][0]\n",
"plt.plot(X, t0ridge + t1ridge * X, \"b--\",\n", "plt.plot(X, t0ridge + t1ridge * X, \"b--\",\n",
" label=\"Regularized linear model on partial data\")\n", " label=\"Regularized linear model on partial data\")\n",
"plt.legend(loc=\"lower right\", fontsize=13)\n", "plt.legend(loc=\"lower right\")\n",
"\n", "\n",
"plt.axis([0, 115_000, min_life_sat, max_life_sat])\n", "plt.axis([0, 115_000, min_life_sat, max_life_sat])\n",
"\n", "\n",