From 4ba9496a87953f34f14d3b45c9a1fe1e474b32f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Sat, 19 Feb 2022 18:18:08 +1300 Subject: [PATCH] Update plot options --- 01_the_machine_learning_landscape.ipynb | 30 ++++++++++++------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/01_the_machine_learning_landscape.ipynb b/01_the_machine_learning_landscape.ipynb index e5644b5..90d6586 100644 --- a/01_the_machine_learning_landscape.ipynb +++ b/01_the_machine_learning_landscape.ipynb @@ -104,11 +104,11 @@ "source": [ "import matplotlib.pyplot as plt\n", "\n", - "plt.rc('font', size=14)\n", + "plt.rc('font', size=12)\n", "plt.rc('axes', labelsize=14, titlesize=14)\n", - "plt.rc('legend', fontsize=14)\n", - "plt.rc('xtick',labelsize=10)\n", - "plt.rc('ytick',labelsize=10)" + "plt.rc('legend', fontsize=12)\n", + "plt.rc('xtick', labelsize=10)\n", + "plt.rc('ytick', labelsize=10)" ] }, { @@ -214,7 +214,7 @@ "model = KNeighborsRegressor(n_neighbors=3)\n", "\n", "# Train the model\n", - "model.fit(X,y)\n", + "model.fit(X, y)\n", "\n", "# Make a prediction for Cyprus\n", "print(model.predict(X_new)) # outputs [[6.33333333]]\n" @@ -399,7 +399,7 @@ "metadata": {}, "outputs": [], "source": [ - "country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n", + "country_stats.plot(kind='scatter', figsize=(5, 3), grid=True,\n", " x=gdppc_col, y=lifesat_col)\n", "\n", "min_life_sat = 4\n", @@ -422,7 +422,7 @@ " plt.annotate(country, xy=(pos_data_x, pos_data_y),\n", " xytext=pos_text, fontsize=12,\n", " arrowprops=dict(facecolor='black', width=0.5,\n", - " shrink=0.15, headwidth=5))\n", + " shrink=0.08, headwidth=5))\n", " plt.plot(pos_data_x, pos_data_y, \"ro\")\n", "\n", "plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n", @@ -447,7 +447,7 @@ "metadata": {}, "outputs": [], "source": [ - "country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n", + "country_stats.plot(kind='scatter', figsize=(5, 3), grid=True,\n", " x=gdppc_col, y=lifesat_col)\n", "\n", "X = np.linspace(min_gdp, max_gdp, 1000)\n", @@ -497,7 +497,7 @@ "metadata": {}, "outputs": [], "source": [ - "country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n", + "country_stats.plot(kind='scatter', figsize=(5, 3), grid=True,\n", " x=gdppc_col, y=lifesat_col)\n", "\n", "X = np.linspace(min_gdp, max_gdp, 1000)\n", @@ -540,7 +540,7 @@ "metadata": {}, "outputs": [], "source": [ - "country_stats.plot(kind='scatter', figsize=(5,3), grid=True,\n", + "country_stats.plot(kind='scatter', figsize=(5, 3), grid=True,\n", " x=gdppc_col, y=lifesat_col)\n", "\n", "X = np.linspace(min_gdp, max_gdp, 1000)\n", @@ -598,7 +598,7 @@ "metadata": {}, "outputs": [], "source": [ - "full_country_stats.plot(kind='scatter', figsize=(8,3),\n", + "full_country_stats.plot(kind='scatter', figsize=(8, 3),\n", " x=gdppc_col, y=lifesat_col, grid=True)\n", "\n", "for country, pos_text in position_text_missing_countries.items():\n", @@ -606,7 +606,7 @@ " plt.annotate(country, xy=(pos_data_x, pos_data_y),\n", " xytext=pos_text, fontsize=12,\n", " arrowprops=dict(facecolor='black', width=0.5,\n", - " shrink=0.1, headwidth=5))\n", + " shrink=0.08, headwidth=5))\n", " plt.plot(pos_data_x, pos_data_y, \"rs\")\n", "\n", "X = np.linspace(0, 115_000, 1000)\n", @@ -636,7 +636,7 @@ "from sklearn import preprocessing\n", "from sklearn import pipeline\n", "\n", - "full_country_stats.plot(kind='scatter', figsize=(8,3),\n", + "full_country_stats.plot(kind='scatter', figsize=(8, 3),\n", " x=gdppc_col, y=lifesat_col, grid=True)\n", "\n", "poly = preprocessing.PolynomialFeatures(degree=10, include_bias=False)\n", @@ -683,7 +683,7 @@ "metadata": {}, "outputs": [], "source": [ - "country_stats.plot(kind='scatter', x=gdppc_col, y=lifesat_col, figsize=(8,3))\n", + "country_stats.plot(kind='scatter', x=gdppc_col, y=lifesat_col, figsize=(8, 3))\n", "missing_data.plot(kind='scatter', x=gdppc_col, y=lifesat_col,\n", " marker=\"s\", color=\"r\", grid=True, ax=plt.gca())\n", "\n", @@ -698,7 +698,7 @@ "t0ridge, t1ridge = ridge.intercept_[0], ridge.coef_[0][0]\n", "plt.plot(X, t0ridge + t1ridge * X, \"b--\",\n", " label=\"Regularized linear model on partial data\")\n", - "plt.legend(loc=\"lower right\", fontsize=13)\n", + "plt.legend(loc=\"lower right\")\n", "\n", "plt.axis([0, 115_000, min_life_sat, max_life_sat])\n", "\n",