From 89653bfaded025074eca17b817640d7dc8e6a79d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Sun, 13 Oct 2019 17:19:39 +0800 Subject: [PATCH] Improve a few figures (e.g., add missing labels, share axes, etc.) --- 05_support_vector_machines.ipynb | 95 +++++++++++++++++--------------- 1 file changed, 51 insertions(+), 44 deletions(-) diff --git a/05_support_vector_machines.ipynb b/05_support_vector_machines.ipynb index d25c556..a99b181 100644 --- a/05_support_vector_machines.ipynb +++ b/05_support_vector_machines.ipynb @@ -133,9 +133,9 @@ " plt.plot(x0, gutter_up, \"k--\", linewidth=2)\n", " plt.plot(x0, gutter_down, \"k--\", linewidth=2)\n", "\n", - "plt.figure(figsize=(12,2.7))\n", + "fig, axes = plt.subplots(ncols=2, figsize=(10,2.7), sharey=True)\n", "\n", - "plt.subplot(121)\n", + "plt.sca(axes[0])\n", "plt.plot(x0, pred_1, \"g--\", linewidth=2)\n", "plt.plot(x0, pred_2, \"m-\", linewidth=2)\n", "plt.plot(x0, pred_3, \"r-\", linewidth=2)\n", @@ -146,7 +146,7 @@ "plt.legend(loc=\"upper left\", fontsize=14)\n", "plt.axis([0, 5.5, 0, 2])\n", "\n", - "plt.subplot(122)\n", + "plt.sca(axes[1])\n", "plot_svc_decision_boundary(svm_clf, 0, 5.5)\n", "plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"bs\")\n", "plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"yo\")\n", @@ -175,13 +175,13 @@ "svm_clf = SVC(kernel=\"linear\", C=100)\n", "svm_clf.fit(Xs, ys)\n", "\n", - "plt.figure(figsize=(12,3.2))\n", + "plt.figure(figsize=(9,2.7))\n", "plt.subplot(121)\n", "plt.plot(Xs[:, 0][ys==1], Xs[:, 1][ys==1], \"bo\")\n", "plt.plot(Xs[:, 0][ys==0], Xs[:, 1][ys==0], \"ms\")\n", "plot_svc_decision_boundary(svm_clf, 0, 6)\n", "plt.xlabel(\"$x_0$\", fontsize=20)\n", - "plt.ylabel(\"$x_1$ \", fontsize=20, rotation=0)\n", + "plt.ylabel(\"$x_1$    \", fontsize=20, rotation=0)\n", "plt.title(\"Unscaled\", fontsize=16)\n", "plt.axis([0, 6, 0, 90])\n", "\n", @@ -195,6 +195,7 @@ "plt.plot(X_scaled[:, 0][ys==0], X_scaled[:, 1][ys==0], \"ms\")\n", "plot_svc_decision_boundary(svm_clf, -2, 2)\n", "plt.xlabel(\"$x_0$\", fontsize=20)\n", + "plt.ylabel(\"$x'_1$ \", fontsize=20, rotation=0)\n", "plt.title(\"Scaled\", fontsize=16)\n", "plt.axis([-2, 2, -2, 2])\n", "\n", @@ -224,9 +225,9 @@ "svm_clf2 = SVC(kernel=\"linear\", C=10**9)\n", "svm_clf2.fit(Xo2, yo2)\n", "\n", - "plt.figure(figsize=(12,2.7))\n", + "fig, axes = plt.subplots(ncols=2, figsize=(10,2.7), sharey=True)\n", "\n", - "plt.subplot(121)\n", + "plt.sca(axes[0])\n", "plt.plot(Xo1[:, 0][yo1==1], Xo1[:, 1][yo1==1], \"bs\")\n", "plt.plot(Xo1[:, 0][yo1==0], Xo1[:, 1][yo1==0], \"yo\")\n", "plt.text(0.3, 1.0, \"Impossible!\", fontsize=24, color=\"red\")\n", @@ -241,7 +242,7 @@ " )\n", "plt.axis([0, 5.5, 0, 2])\n", "\n", - "plt.subplot(122)\n", + "plt.sca(axes[1])\n", "plt.plot(Xo2[:, 0][yo2==1], Xo2[:, 1][yo2==1], \"bs\")\n", "plt.plot(Xo2[:, 0][yo2==0], Xo2[:, 1][yo2==0], \"yo\")\n", "plot_svc_decision_boundary(svm_clf2, 0, 5.5)\n", @@ -366,24 +367,25 @@ "metadata": {}, "outputs": [], "source": [ - "plt.figure(figsize=(12,3.2))\n", - "plt.subplot(121)\n", + "fig, axes = plt.subplots(ncols=2, figsize=(10,2.7), sharey=True)\n", + "\n", + "plt.sca(axes[0])\n", "plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"g^\", label=\"Iris virginica\")\n", "plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"bs\", label=\"Iris versicolor\")\n", - "plot_svc_decision_boundary(svm_clf1, 4, 6)\n", + "plot_svc_decision_boundary(svm_clf1, 4, 5.9)\n", "plt.xlabel(\"Petal length\", fontsize=14)\n", "plt.ylabel(\"Petal width\", fontsize=14)\n", "plt.legend(loc=\"upper left\", fontsize=14)\n", "plt.title(\"$C = {}$\".format(svm_clf1.C), fontsize=16)\n", - "plt.axis([4, 6, 0.8, 2.8])\n", + "plt.axis([4, 5.9, 0.8, 2.8])\n", "\n", - "plt.subplot(122)\n", + "plt.sca(axes[1])\n", "plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"g^\")\n", "plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"bs\")\n", - "plot_svc_decision_boundary(svm_clf2, 4, 6)\n", + "plot_svc_decision_boundary(svm_clf2, 4, 5.99)\n", "plt.xlabel(\"Petal length\", fontsize=14)\n", "plt.title(\"$C = {}$\".format(svm_clf2.C), fontsize=16)\n", - "plt.axis([4, 6, 0.8, 2.8])\n", + "plt.axis([4, 5.9, 0.8, 2.8])\n", "\n", "save_fig(\"regularization_plot\")" ] @@ -407,7 +409,7 @@ "X2D = np.c_[X1D, X1D**2]\n", "y = np.array([0, 0, 1, 1, 1, 1, 1, 0, 0])\n", "\n", - "plt.figure(figsize=(11, 4))\n", + "plt.figure(figsize=(10, 3))\n", "\n", "plt.subplot(121)\n", "plt.grid(True, which='both')\n", @@ -425,7 +427,7 @@ "plt.plot(X2D[:, 0][y==0], X2D[:, 1][y==0], \"bs\")\n", "plt.plot(X2D[:, 0][y==1], X2D[:, 1][y==1], \"g^\")\n", "plt.xlabel(r\"$x_1$\", fontsize=20)\n", - "plt.ylabel(r\"$x_2$\", fontsize=20, rotation=0)\n", + "plt.ylabel(r\"$x_2$  \", fontsize=20, rotation=0)\n", "plt.gca().get_yaxis().set_ticks([0, 4, 8, 12, 16])\n", "plt.plot([-4.5, 4.5], [6.5, 6.5], \"r--\", linewidth=3)\n", "plt.axis([-4.5, 4.5, -1, 17])\n", @@ -533,17 +535,18 @@ "metadata": {}, "outputs": [], "source": [ - "plt.figure(figsize=(11, 4))\n", + "fig, axes = plt.subplots(ncols=2, figsize=(10.5, 4), sharey=True)\n", "\n", - "plt.subplot(121)\n", - "plot_predictions(poly_kernel_svm_clf, [-1.5, 2.5, -1, 1.5])\n", - "plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])\n", + "plt.sca(axes[0])\n", + "plot_predictions(poly_kernel_svm_clf, [-1.5, 2.45, -1, 1.5])\n", + "plot_dataset(X, y, [-1.5, 2.4, -1, 1.5])\n", "plt.title(r\"$d=3, r=1, C=5$\", fontsize=18)\n", "\n", - "plt.subplot(122)\n", - "plot_predictions(poly100_kernel_svm_clf, [-1.5, 2.5, -1, 1.5])\n", - "plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])\n", + "plt.sca(axes[1])\n", + "plot_predictions(poly100_kernel_svm_clf, [-1.5, 2.45, -1, 1.5])\n", + "plot_dataset(X, y, [-1.5, 2.4, -1, 1.5])\n", "plt.title(r\"$d=10, r=100, C=5$\", fontsize=18)\n", + "plt.ylabel(\"\")\n", "\n", "save_fig(\"moons_kernelized_polynomial_svc_plot\")\n", "plt.show()" @@ -569,7 +572,7 @@ "XK = np.c_[gaussian_rbf(X1D, -2, gamma), gaussian_rbf(X1D, 1, gamma)]\n", "yk = np.array([0, 0, 1, 1, 1, 1, 1, 0, 0])\n", "\n", - "plt.figure(figsize=(11, 4))\n", + "plt.figure(figsize=(10.5, 4))\n", "\n", "plt.subplot(121)\n", "plt.grid(True, which='both')\n", @@ -600,7 +603,7 @@ "plt.plot(XK[:, 0][yk==0], XK[:, 1][yk==0], \"bs\")\n", "plt.plot(XK[:, 0][yk==1], XK[:, 1][yk==1], \"g^\")\n", "plt.xlabel(r\"$x_2$\", fontsize=20)\n", - "plt.ylabel(r\"$x_3$ \", fontsize=20, rotation=0)\n", + "plt.ylabel(r\"$x_3$  \", fontsize=20, rotation=0)\n", "plt.annotate(r'$\\phi\\left(\\mathbf{x}\\right)$',\n", " xy=(XK[3, 0], XK[3, 1]),\n", " xytext=(0.65, 0.50),\n", @@ -665,14 +668,18 @@ " rbf_kernel_svm_clf.fit(X, y)\n", " svm_clfs.append(rbf_kernel_svm_clf)\n", "\n", - "plt.figure(figsize=(11, 7))\n", + "fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10.5, 7), sharex=True, sharey=True)\n", "\n", "for i, svm_clf in enumerate(svm_clfs):\n", - " plt.subplot(221 + i)\n", - " plot_predictions(svm_clf, [-1.5, 2.5, -1, 1.5])\n", - " plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])\n", + " plt.sca(axes[i // 2, i % 2])\n", + " plot_predictions(svm_clf, [-1.5, 2.45, -1, 1.5])\n", + " plot_dataset(X, y, [-1.5, 2.45, -1, 1.5])\n", " gamma, C = hyperparams[i]\n", " plt.title(r\"$\\gamma = {}, C = {}$\".format(gamma, C), fontsize=16)\n", + " if i in (0, 1):\n", + " plt.xlabel(\"\")\n", + " if i in (1, 3):\n", + " plt.ylabel(\"\")\n", "\n", "save_fig(\"moons_rbf_svc_plot\")\n", "plt.show()" @@ -750,8 +757,8 @@ " plt.legend(loc=\"upper left\", fontsize=18)\n", " plt.axis(axes)\n", "\n", - "plt.figure(figsize=(9, 4))\n", - "plt.subplot(121)\n", + "fig, axes = plt.subplots(ncols=2, figsize=(9, 4), sharey=True)\n", + "plt.sca(axes[0])\n", "plot_svm_regression(svm_reg1, X, y, [0, 2, 3, 11])\n", "plt.title(r\"$\\epsilon = {}$\".format(svm_reg1.epsilon), fontsize=18)\n", "plt.ylabel(r\"$y$\", fontsize=18, rotation=0)\n", @@ -762,7 +769,7 @@ " textcoords='data', arrowprops={'arrowstyle': '<->', 'linewidth': 1.5}\n", " )\n", "plt.text(0.91, 5.6, r\"$\\epsilon$\", fontsize=20)\n", - "plt.subplot(122)\n", + "plt.sca(axes[1])\n", "plot_svm_regression(svm_reg2, X, y, [0, 2, 3, 11])\n", "plt.title(r\"$\\epsilon = {}$\".format(svm_reg2.epsilon), fontsize=18)\n", "save_fig(\"svm_regression_plot\")\n", @@ -820,12 +827,12 @@ "metadata": {}, "outputs": [], "source": [ - "plt.figure(figsize=(9, 4))\n", - "plt.subplot(121)\n", + "fig, axes = plt.subplots(ncols=2, figsize=(9, 4), sharey=True)\n", + "plt.sca(axes[0])\n", "plot_svm_regression(svm_poly_reg1, X, y, [-1, 1, 0, 1])\n", "plt.title(r\"$degree={}, C={}, \\epsilon = {}$\".format(svm_poly_reg1.degree, svm_poly_reg1.C, svm_poly_reg1.epsilon), fontsize=18)\n", "plt.ylabel(r\"$y$\", fontsize=18, rotation=0)\n", - "plt.subplot(122)\n", + "plt.sca(axes[1])\n", "plot_svm_regression(svm_poly_reg2, X, y, [-1, 1, 0, 1])\n", "plt.title(r\"$degree={}, C={}, \\epsilon = {}$\".format(svm_poly_reg2.degree, svm_poly_reg2.C, svm_poly_reg2.epsilon), fontsize=18)\n", "save_fig(\"svm_with_polynomial_kernel_plot\")\n", @@ -923,13 +930,13 @@ " plt.axis(x1_lim + [-2, 2])\n", " plt.xlabel(r\"$x_1$\", fontsize=16)\n", " if ylabel:\n", - " plt.ylabel(r\"$w_1 x_1$ \", rotation=0, fontsize=16)\n", + " plt.ylabel(r\"$w_1 x_1$  \", rotation=0, fontsize=16)\n", " plt.title(r\"$w_1 = {}$\".format(w), fontsize=16)\n", "\n", - "plt.figure(figsize=(12, 3.2))\n", - "plt.subplot(121)\n", + "fig, axes = plt.subplots(ncols=2, figsize=(9, 3.2), sharey=True)\n", + "plt.sca(axes[0])\n", "plot_2D_decision_function(1, 0)\n", - "plt.subplot(122)\n", + "plt.sca(axes[1])\n", "plot_2D_decision_function(0.5, 0, ylabel=False)\n", "save_fig(\"small_w_large_margin_plot\")\n", "plt.show()" @@ -1154,8 +1161,8 @@ "outputs": [], "source": [ "yr = y.ravel()\n", - "plt.figure(figsize=(12,3.2))\n", - "plt.subplot(121)\n", + "fig, axes = plt.subplots(ncols=2, figsize=(11, 3.2), sharey=True)\n", + "plt.sca(axes[0])\n", "plt.plot(X[:, 0][yr==1], X[:, 1][yr==1], \"g^\", label=\"Iris virginica\")\n", "plt.plot(X[:, 0][yr==0], X[:, 1][yr==0], \"bs\", label=\"Not Iris virginica\")\n", "plot_svc_decision_boundary(svm_clf, 4, 6)\n", @@ -1165,7 +1172,7 @@ "plt.axis([4, 6, 0.8, 2.8])\n", "plt.legend(loc=\"upper left\")\n", "\n", - "plt.subplot(122)\n", + "plt.sca(axes[1])\n", "plt.plot(X[:, 0][yr==1], X[:, 1][yr==1], \"g^\")\n", "plt.plot(X[:, 0][yr==0], X[:, 1][yr==0], \"bs\")\n", "plot_svc_decision_boundary(svm_clf2, 4, 6)\n", @@ -1816,7 +1823,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.8" + "version": "3.7.4" }, "nav_menu": {}, "toc": {