Improve a few figures (e.g., add missing labels, share axes, etc.)

main
Aurélien Geron 2019-10-13 17:19:39 +08:00
parent 7e6489f8a4
commit 89653bfade
1 changed files with 51 additions and 44 deletions

View File

@ -133,9 +133,9 @@
" plt.plot(x0, gutter_up, \"k--\", linewidth=2)\n",
" plt.plot(x0, gutter_down, \"k--\", linewidth=2)\n",
"\n",
"plt.figure(figsize=(12,2.7))\n",
"fig, axes = plt.subplots(ncols=2, figsize=(10,2.7), sharey=True)\n",
"\n",
"plt.subplot(121)\n",
"plt.sca(axes[0])\n",
"plt.plot(x0, pred_1, \"g--\", linewidth=2)\n",
"plt.plot(x0, pred_2, \"m-\", linewidth=2)\n",
"plt.plot(x0, pred_3, \"r-\", linewidth=2)\n",
@ -146,7 +146,7 @@
"plt.legend(loc=\"upper left\", fontsize=14)\n",
"plt.axis([0, 5.5, 0, 2])\n",
"\n",
"plt.subplot(122)\n",
"plt.sca(axes[1])\n",
"plot_svc_decision_boundary(svm_clf, 0, 5.5)\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"bs\")\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"yo\")\n",
@ -175,13 +175,13 @@
"svm_clf = SVC(kernel=\"linear\", C=100)\n",
"svm_clf.fit(Xs, ys)\n",
"\n",
"plt.figure(figsize=(12,3.2))\n",
"plt.figure(figsize=(9,2.7))\n",
"plt.subplot(121)\n",
"plt.plot(Xs[:, 0][ys==1], Xs[:, 1][ys==1], \"bo\")\n",
"plt.plot(Xs[:, 0][ys==0], Xs[:, 1][ys==0], \"ms\")\n",
"plot_svc_decision_boundary(svm_clf, 0, 6)\n",
"plt.xlabel(\"$x_0$\", fontsize=20)\n",
"plt.ylabel(\"$x_1$ \", fontsize=20, rotation=0)\n",
"plt.ylabel(\"$x_1$    \", fontsize=20, rotation=0)\n",
"plt.title(\"Unscaled\", fontsize=16)\n",
"plt.axis([0, 6, 0, 90])\n",
"\n",
@ -195,6 +195,7 @@
"plt.plot(X_scaled[:, 0][ys==0], X_scaled[:, 1][ys==0], \"ms\")\n",
"plot_svc_decision_boundary(svm_clf, -2, 2)\n",
"plt.xlabel(\"$x_0$\", fontsize=20)\n",
"plt.ylabel(\"$x'_1$ \", fontsize=20, rotation=0)\n",
"plt.title(\"Scaled\", fontsize=16)\n",
"plt.axis([-2, 2, -2, 2])\n",
"\n",
@ -224,9 +225,9 @@
"svm_clf2 = SVC(kernel=\"linear\", C=10**9)\n",
"svm_clf2.fit(Xo2, yo2)\n",
"\n",
"plt.figure(figsize=(12,2.7))\n",
"fig, axes = plt.subplots(ncols=2, figsize=(10,2.7), sharey=True)\n",
"\n",
"plt.subplot(121)\n",
"plt.sca(axes[0])\n",
"plt.plot(Xo1[:, 0][yo1==1], Xo1[:, 1][yo1==1], \"bs\")\n",
"plt.plot(Xo1[:, 0][yo1==0], Xo1[:, 1][yo1==0], \"yo\")\n",
"plt.text(0.3, 1.0, \"Impossible!\", fontsize=24, color=\"red\")\n",
@ -241,7 +242,7 @@
" )\n",
"plt.axis([0, 5.5, 0, 2])\n",
"\n",
"plt.subplot(122)\n",
"plt.sca(axes[1])\n",
"plt.plot(Xo2[:, 0][yo2==1], Xo2[:, 1][yo2==1], \"bs\")\n",
"plt.plot(Xo2[:, 0][yo2==0], Xo2[:, 1][yo2==0], \"yo\")\n",
"plot_svc_decision_boundary(svm_clf2, 0, 5.5)\n",
@ -366,24 +367,25 @@
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(12,3.2))\n",
"plt.subplot(121)\n",
"fig, axes = plt.subplots(ncols=2, figsize=(10,2.7), sharey=True)\n",
"\n",
"plt.sca(axes[0])\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"g^\", label=\"Iris virginica\")\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"bs\", label=\"Iris versicolor\")\n",
"plot_svc_decision_boundary(svm_clf1, 4, 6)\n",
"plot_svc_decision_boundary(svm_clf1, 4, 5.9)\n",
"plt.xlabel(\"Petal length\", fontsize=14)\n",
"plt.ylabel(\"Petal width\", fontsize=14)\n",
"plt.legend(loc=\"upper left\", fontsize=14)\n",
"plt.title(\"$C = {}$\".format(svm_clf1.C), fontsize=16)\n",
"plt.axis([4, 6, 0.8, 2.8])\n",
"plt.axis([4, 5.9, 0.8, 2.8])\n",
"\n",
"plt.subplot(122)\n",
"plt.sca(axes[1])\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"g^\")\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"bs\")\n",
"plot_svc_decision_boundary(svm_clf2, 4, 6)\n",
"plot_svc_decision_boundary(svm_clf2, 4, 5.99)\n",
"plt.xlabel(\"Petal length\", fontsize=14)\n",
"plt.title(\"$C = {}$\".format(svm_clf2.C), fontsize=16)\n",
"plt.axis([4, 6, 0.8, 2.8])\n",
"plt.axis([4, 5.9, 0.8, 2.8])\n",
"\n",
"save_fig(\"regularization_plot\")"
]
@ -407,7 +409,7 @@
"X2D = np.c_[X1D, X1D**2]\n",
"y = np.array([0, 0, 1, 1, 1, 1, 1, 0, 0])\n",
"\n",
"plt.figure(figsize=(11, 4))\n",
"plt.figure(figsize=(10, 3))\n",
"\n",
"plt.subplot(121)\n",
"plt.grid(True, which='both')\n",
@ -425,7 +427,7 @@
"plt.plot(X2D[:, 0][y==0], X2D[:, 1][y==0], \"bs\")\n",
"plt.plot(X2D[:, 0][y==1], X2D[:, 1][y==1], \"g^\")\n",
"plt.xlabel(r\"$x_1$\", fontsize=20)\n",
"plt.ylabel(r\"$x_2$\", fontsize=20, rotation=0)\n",
"plt.ylabel(r\"$x_2$  \", fontsize=20, rotation=0)\n",
"plt.gca().get_yaxis().set_ticks([0, 4, 8, 12, 16])\n",
"plt.plot([-4.5, 4.5], [6.5, 6.5], \"r--\", linewidth=3)\n",
"plt.axis([-4.5, 4.5, -1, 17])\n",
@ -533,17 +535,18 @@
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(11, 4))\n",
"fig, axes = plt.subplots(ncols=2, figsize=(10.5, 4), sharey=True)\n",
"\n",
"plt.subplot(121)\n",
"plot_predictions(poly_kernel_svm_clf, [-1.5, 2.5, -1, 1.5])\n",
"plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])\n",
"plt.sca(axes[0])\n",
"plot_predictions(poly_kernel_svm_clf, [-1.5, 2.45, -1, 1.5])\n",
"plot_dataset(X, y, [-1.5, 2.4, -1, 1.5])\n",
"plt.title(r\"$d=3, r=1, C=5$\", fontsize=18)\n",
"\n",
"plt.subplot(122)\n",
"plot_predictions(poly100_kernel_svm_clf, [-1.5, 2.5, -1, 1.5])\n",
"plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])\n",
"plt.sca(axes[1])\n",
"plot_predictions(poly100_kernel_svm_clf, [-1.5, 2.45, -1, 1.5])\n",
"plot_dataset(X, y, [-1.5, 2.4, -1, 1.5])\n",
"plt.title(r\"$d=10, r=100, C=5$\", fontsize=18)\n",
"plt.ylabel(\"\")\n",
"\n",
"save_fig(\"moons_kernelized_polynomial_svc_plot\")\n",
"plt.show()"
@ -569,7 +572,7 @@
"XK = np.c_[gaussian_rbf(X1D, -2, gamma), gaussian_rbf(X1D, 1, gamma)]\n",
"yk = np.array([0, 0, 1, 1, 1, 1, 1, 0, 0])\n",
"\n",
"plt.figure(figsize=(11, 4))\n",
"plt.figure(figsize=(10.5, 4))\n",
"\n",
"plt.subplot(121)\n",
"plt.grid(True, which='both')\n",
@ -600,7 +603,7 @@
"plt.plot(XK[:, 0][yk==0], XK[:, 1][yk==0], \"bs\")\n",
"plt.plot(XK[:, 0][yk==1], XK[:, 1][yk==1], \"g^\")\n",
"plt.xlabel(r\"$x_2$\", fontsize=20)\n",
"plt.ylabel(r\"$x_3$ \", fontsize=20, rotation=0)\n",
"plt.ylabel(r\"$x_3$  \", fontsize=20, rotation=0)\n",
"plt.annotate(r'$\\phi\\left(\\mathbf{x}\\right)$',\n",
" xy=(XK[3, 0], XK[3, 1]),\n",
" xytext=(0.65, 0.50),\n",
@ -665,14 +668,18 @@
" rbf_kernel_svm_clf.fit(X, y)\n",
" svm_clfs.append(rbf_kernel_svm_clf)\n",
"\n",
"plt.figure(figsize=(11, 7))\n",
"fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10.5, 7), sharex=True, sharey=True)\n",
"\n",
"for i, svm_clf in enumerate(svm_clfs):\n",
" plt.subplot(221 + i)\n",
" plot_predictions(svm_clf, [-1.5, 2.5, -1, 1.5])\n",
" plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])\n",
" plt.sca(axes[i // 2, i % 2])\n",
" plot_predictions(svm_clf, [-1.5, 2.45, -1, 1.5])\n",
" plot_dataset(X, y, [-1.5, 2.45, -1, 1.5])\n",
" gamma, C = hyperparams[i]\n",
" plt.title(r\"$\\gamma = {}, C = {}$\".format(gamma, C), fontsize=16)\n",
" if i in (0, 1):\n",
" plt.xlabel(\"\")\n",
" if i in (1, 3):\n",
" plt.ylabel(\"\")\n",
"\n",
"save_fig(\"moons_rbf_svc_plot\")\n",
"plt.show()"
@ -750,8 +757,8 @@
" plt.legend(loc=\"upper left\", fontsize=18)\n",
" plt.axis(axes)\n",
"\n",
"plt.figure(figsize=(9, 4))\n",
"plt.subplot(121)\n",
"fig, axes = plt.subplots(ncols=2, figsize=(9, 4), sharey=True)\n",
"plt.sca(axes[0])\n",
"plot_svm_regression(svm_reg1, X, y, [0, 2, 3, 11])\n",
"plt.title(r\"$\\epsilon = {}$\".format(svm_reg1.epsilon), fontsize=18)\n",
"plt.ylabel(r\"$y$\", fontsize=18, rotation=0)\n",
@ -762,7 +769,7 @@
" textcoords='data', arrowprops={'arrowstyle': '<->', 'linewidth': 1.5}\n",
" )\n",
"plt.text(0.91, 5.6, r\"$\\epsilon$\", fontsize=20)\n",
"plt.subplot(122)\n",
"plt.sca(axes[1])\n",
"plot_svm_regression(svm_reg2, X, y, [0, 2, 3, 11])\n",
"plt.title(r\"$\\epsilon = {}$\".format(svm_reg2.epsilon), fontsize=18)\n",
"save_fig(\"svm_regression_plot\")\n",
@ -820,12 +827,12 @@
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(9, 4))\n",
"plt.subplot(121)\n",
"fig, axes = plt.subplots(ncols=2, figsize=(9, 4), sharey=True)\n",
"plt.sca(axes[0])\n",
"plot_svm_regression(svm_poly_reg1, X, y, [-1, 1, 0, 1])\n",
"plt.title(r\"$degree={}, C={}, \\epsilon = {}$\".format(svm_poly_reg1.degree, svm_poly_reg1.C, svm_poly_reg1.epsilon), fontsize=18)\n",
"plt.ylabel(r\"$y$\", fontsize=18, rotation=0)\n",
"plt.subplot(122)\n",
"plt.sca(axes[1])\n",
"plot_svm_regression(svm_poly_reg2, X, y, [-1, 1, 0, 1])\n",
"plt.title(r\"$degree={}, C={}, \\epsilon = {}$\".format(svm_poly_reg2.degree, svm_poly_reg2.C, svm_poly_reg2.epsilon), fontsize=18)\n",
"save_fig(\"svm_with_polynomial_kernel_plot\")\n",
@ -923,13 +930,13 @@
" plt.axis(x1_lim + [-2, 2])\n",
" plt.xlabel(r\"$x_1$\", fontsize=16)\n",
" if ylabel:\n",
" plt.ylabel(r\"$w_1 x_1$ \", rotation=0, fontsize=16)\n",
" plt.ylabel(r\"$w_1 x_1$  \", rotation=0, fontsize=16)\n",
" plt.title(r\"$w_1 = {}$\".format(w), fontsize=16)\n",
"\n",
"plt.figure(figsize=(12, 3.2))\n",
"plt.subplot(121)\n",
"fig, axes = plt.subplots(ncols=2, figsize=(9, 3.2), sharey=True)\n",
"plt.sca(axes[0])\n",
"plot_2D_decision_function(1, 0)\n",
"plt.subplot(122)\n",
"plt.sca(axes[1])\n",
"plot_2D_decision_function(0.5, 0, ylabel=False)\n",
"save_fig(\"small_w_large_margin_plot\")\n",
"plt.show()"
@ -1154,8 +1161,8 @@
"outputs": [],
"source": [
"yr = y.ravel()\n",
"plt.figure(figsize=(12,3.2))\n",
"plt.subplot(121)\n",
"fig, axes = plt.subplots(ncols=2, figsize=(11, 3.2), sharey=True)\n",
"plt.sca(axes[0])\n",
"plt.plot(X[:, 0][yr==1], X[:, 1][yr==1], \"g^\", label=\"Iris virginica\")\n",
"plt.plot(X[:, 0][yr==0], X[:, 1][yr==0], \"bs\", label=\"Not Iris virginica\")\n",
"plot_svc_decision_boundary(svm_clf, 4, 6)\n",
@ -1165,7 +1172,7 @@
"plt.axis([4, 6, 0.8, 2.8])\n",
"plt.legend(loc=\"upper left\")\n",
"\n",
"plt.subplot(122)\n",
"plt.sca(axes[1])\n",
"plt.plot(X[:, 0][yr==1], X[:, 1][yr==1], \"g^\")\n",
"plt.plot(X[:, 0][yr==0], X[:, 1][yr==0], \"bs\")\n",
"plot_svc_decision_boundary(svm_clf2, 4, 6)\n",
@ -1816,7 +1823,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
"version": "3.7.4"
},
"nav_menu": {},
"toc": {