From 7e6489f8a4354face51ceb474fdb282f11d261f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Sun, 13 Oct 2019 16:58:36 +0800 Subject: [PATCH] Improve a few figures (e.g., add missing labels, share axes, etc.) --- 06_decision_trees.ipynb | 32 ++++++------ 07_ensemble_learning_and_random_forests.ipynb | 51 ++++++++++++------- 2 files changed, 51 insertions(+), 32 deletions(-) diff --git a/06_decision_trees.ipynb b/06_decision_trees.ipynb index 15fef59..ca5962f 100644 --- a/06_decision_trees.ipynb +++ b/06_decision_trees.ipynb @@ -249,13 +249,14 @@ "deep_tree_clf1.fit(Xm, ym)\n", "deep_tree_clf2.fit(Xm, ym)\n", "\n", - "plt.figure(figsize=(11, 4))\n", - "plt.subplot(121)\n", - "plot_decision_boundary(deep_tree_clf1, Xm, ym, axes=[-1.5, 2.5, -1, 1.5], iris=False)\n", + "fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)\n", + "plt.sca(axes[0])\n", + "plot_decision_boundary(deep_tree_clf1, Xm, ym, axes=[-1.5, 2.4, -1, 1.5], iris=False)\n", "plt.title(\"No restrictions\", fontsize=16)\n", - "plt.subplot(122)\n", - "plot_decision_boundary(deep_tree_clf2, Xm, ym, axes=[-1.5, 2.5, -1, 1.5], iris=False)\n", + "plt.sca(axes[1])\n", + "plot_decision_boundary(deep_tree_clf2, Xm, ym, axes=[-1.5, 2.4, -1, 1.5], iris=False)\n", "plt.title(\"min_samples_leaf = {}\".format(deep_tree_clf2.min_samples_leaf), fontsize=14)\n", + "plt.ylabel(\"\")\n", "\n", "save_fig(\"min_samples_leaf_plot\")\n", "plt.show()" @@ -299,11 +300,12 @@ "tree_clf_sr = DecisionTreeClassifier(random_state=42)\n", "tree_clf_sr.fit(Xsr, ys)\n", "\n", - "plt.figure(figsize=(11, 4))\n", - "plt.subplot(121)\n", + "fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)\n", + "plt.sca(axes[0])\n", "plot_decision_boundary(tree_clf_s, Xs, ys, axes=[-0.7, 0.7, -0.7, 0.7], iris=False)\n", - "plt.subplot(122)\n", + "plt.sca(axes[1])\n", "plot_decision_boundary(tree_clf_sr, Xsr, ys, axes=[-0.7, 0.7, -0.7, 0.7], iris=False)\n", + "plt.ylabel(\"\")\n", "\n", "save_fig(\"sensitivity_to_rotation_plot\")\n", "plt.show()" @@ -365,8 +367,8 @@ " plt.plot(X, y, \"b.\")\n", " plt.plot(x1, y_pred, \"r.-\", linewidth=2, label=r\"$\\hat{y}$\")\n", "\n", - "plt.figure(figsize=(11, 4))\n", - "plt.subplot(121)\n", + "fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)\n", + "plt.sca(axes[0])\n", "plot_regression_predictions(tree_reg1, X, y)\n", "for split, style in ((0.1973, \"k-\"), (0.0917, \"k--\"), (0.7718, \"k--\")):\n", " plt.plot([split, split], [-0.2, 1], style, linewidth=2)\n", @@ -376,7 +378,7 @@ "plt.legend(loc=\"upper center\", fontsize=18)\n", "plt.title(\"max_depth=2\", fontsize=14)\n", "\n", - "plt.subplot(122)\n", + "plt.sca(axes[1])\n", "plot_regression_predictions(tree_reg2, X, y, ylabel=None)\n", "for split, style in ((0.1973, \"k-\"), (0.0917, \"k--\"), (0.7718, \"k--\")):\n", " plt.plot([split, split], [-0.2, 1], style, linewidth=2)\n", @@ -428,9 +430,9 @@ "y_pred1 = tree_reg1.predict(x1)\n", "y_pred2 = tree_reg2.predict(x1)\n", "\n", - "plt.figure(figsize=(11, 4))\n", + "fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)\n", "\n", - "plt.subplot(121)\n", + "plt.sca(axes[0])\n", "plt.plot(X, y, \"b.\")\n", "plt.plot(x1, y_pred1, \"r.-\", linewidth=2, label=r\"$\\hat{y}$\")\n", "plt.axis([0, 1, -0.2, 1.1])\n", @@ -439,7 +441,7 @@ "plt.legend(loc=\"upper center\", fontsize=18)\n", "plt.title(\"No restrictions\", fontsize=14)\n", "\n", - "plt.subplot(122)\n", + "plt.sca(axes[1])\n", "plt.plot(X, y, \"b.\")\n", "plt.plot(x1, y_pred2, \"r.-\", linewidth=2, label=r\"$\\hat{y}$\")\n", "plt.axis([0, 1, -0.2, 1.1])\n", @@ -720,7 +722,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.8" + "version": "3.7.4" }, "nav_menu": { "height": "309px", diff --git a/07_ensemble_learning_and_random_forests.ipynb b/07_ensemble_learning_and_random_forests.ipynb index 432e422..19e2584 100644 --- a/07_ensemble_learning_and_random_forests.ipynb +++ b/07_ensemble_learning_and_random_forests.ipynb @@ -260,7 +260,7 @@ "source": [ "from matplotlib.colors import ListedColormap\n", "\n", - "def plot_decision_boundary(clf, X, y, axes=[-1.5, 2.5, -1, 1.5], alpha=0.5, contour=True):\n", + "def plot_decision_boundary(clf, X, y, axes=[-1.5, 2.45, -1, 1.5], alpha=0.5, contour=True):\n", " x1s = np.linspace(axes[0], axes[1], 100)\n", " x2s = np.linspace(axes[2], axes[3], 100)\n", " x1, x2 = np.meshgrid(x1s, x2s)\n", @@ -284,13 +284,14 @@ "metadata": {}, "outputs": [], "source": [ - "plt.figure(figsize=(11,4))\n", - "plt.subplot(121)\n", + "fix, axes = plt.subplots(ncols=2, figsize=(10,4), sharey=True)\n", + "plt.sca(axes[0])\n", "plot_decision_boundary(tree_clf, X, y)\n", "plt.title(\"Decision Tree\", fontsize=14)\n", - "plt.subplot(122)\n", + "plt.sca(axes[1])\n", "plot_decision_boundary(bag_clf, X, y)\n", "plt.title(\"Decision Trees with Bagging\", fontsize=14)\n", + "plt.ylabel(\"\")\n", "save_fig(\"decision_tree_without_and_with_bagging_plot\")\n", "plt.show()" ] @@ -381,7 +382,7 @@ " tree_clf = DecisionTreeClassifier(max_leaf_nodes=16, random_state=42 + i)\n", " indices_with_replacement = np.random.randint(0, len(X_train), len(X_train))\n", " tree_clf.fit(X[indices_with_replacement], y[indices_with_replacement])\n", - " plot_decision_boundary(tree_clf, X, y, axes=[-1.5, 2.5, -1, 1.5], alpha=0.02, contour=False)\n", + " plot_decision_boundary(tree_clf, X, y, axes=[-1.5, 2.45, -1, 1.5], alpha=0.02, contour=False)\n", "\n", "plt.show()" ] @@ -521,10 +522,10 @@ "source": [ "m = len(X_train)\n", "\n", - "plt.figure(figsize=(11, 4))\n", - "for subplot, learning_rate in ((121, 1), (122, 0.5)):\n", + "fix, axes = plt.subplots(ncols=2, figsize=(10,4), sharey=True)\n", + "for subplot, learning_rate in ((0, 1), (1, 0.5)):\n", " sample_weights = np.ones(m)\n", - " plt.subplot(subplot)\n", + " plt.sca(axes[subplot])\n", " for i in range(5):\n", " svm_clf = SVC(kernel=\"rbf\", C=0.05, gamma=\"scale\", random_state=42)\n", " svm_clf.fit(X_train, y_train, sample_weight=sample_weights)\n", @@ -532,12 +533,14 @@ " sample_weights[y_pred != y_train] *= (1 + learning_rate)\n", " plot_decision_boundary(svm_clf, X, y, alpha=0.2)\n", " plt.title(\"learning_rate = {}\".format(learning_rate), fontsize=16)\n", - " if subplot == 121:\n", + " if subplot == 0:\n", " plt.text(-0.7, -0.65, \"1\", fontsize=14)\n", " plt.text(-0.6, -0.10, \"2\", fontsize=14)\n", " plt.text(-0.5, 0.10, \"3\", fontsize=14)\n", " plt.text(-0.4, 0.55, \"4\", fontsize=14)\n", " plt.text(-0.3, 0.90, \"5\", fontsize=14)\n", + " else:\n", + " plt.ylabel(\"\")\n", "\n", "save_fig(\"boosting_plot\")\n", "plt.show()" @@ -715,15 +718,18 @@ "metadata": {}, "outputs": [], "source": [ - "plt.figure(figsize=(11,4))\n", + "fix, axes = plt.subplots(ncols=2, figsize=(10,4), sharey=True)\n", "\n", - "plt.subplot(121)\n", + "plt.sca(axes[0])\n", "plot_predictions([gbrt], X, y, axes=[-0.5, 0.5, -0.1, 0.8], label=\"Ensemble predictions\")\n", "plt.title(\"learning_rate={}, n_estimators={}\".format(gbrt.learning_rate, gbrt.n_estimators), fontsize=14)\n", + "plt.xlabel(\"$x_1$\", fontsize=16)\n", + "plt.ylabel(\"$y$\", fontsize=16, rotation=0)\n", "\n", - "plt.subplot(122)\n", + "plt.sca(axes[1])\n", "plot_predictions([gbrt_slow], X, y, axes=[-0.5, 0.5, -0.1, 0.8])\n", "plt.title(\"learning_rate={}, n_estimators={}\".format(gbrt_slow.learning_rate, gbrt_slow.n_estimators), fontsize=14)\n", + "plt.xlabel(\"$x_1$\", fontsize=16)\n", "\n", "save_fig(\"gbrt_learning_rate_plot\")\n", "plt.show()" @@ -774,7 +780,7 @@ "metadata": {}, "outputs": [], "source": [ - "plt.figure(figsize=(11, 4))\n", + "plt.figure(figsize=(10, 4))\n", "\n", "plt.subplot(121)\n", "plt.plot(errors, \"b.-\")\n", @@ -784,11 +790,14 @@ "plt.text(bst_n_estimators, min_error*1.2, \"Minimum\", ha=\"center\", fontsize=14)\n", "plt.axis([0, 120, 0, 0.01])\n", "plt.xlabel(\"Number of trees\")\n", + "plt.ylabel(\"Error\", fontsize=16)\n", "plt.title(\"Validation error\", fontsize=14)\n", "\n", "plt.subplot(122)\n", "plot_predictions([gbrt_best], X, y, axes=[-0.5, 0.5, -0.1, 0.8])\n", "plt.title(\"Best model (%d trees)\" % bst_n_estimators, fontsize=14)\n", + "plt.ylabel(\"$y$\", fontsize=16, rotation=0)\n", + "plt.xlabel(\"$x_1$\", fontsize=16)\n", "\n", "save_fig(\"early_stopping_gbrt_plot\")\n", "plt.show()" @@ -1200,7 +1209,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "That's a significant improvement, and it's much better than each of the individual classifiers." + "Nope, hard voting wins in this case." ] }, { @@ -1216,6 +1225,7 @@ "metadata": {}, "outputs": [], "source": [ + "voting_clf.voting = \"hard\"\n", "voting_clf.score(X_test, y_test)" ] }, @@ -1232,7 +1242,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The voting classifier reduced the error rate from about 4.0% for our best model (the `MLPClassifier`) to just 3.1%. That's about 22.5% less errors, not bad!" + "The voting classifier only very slightly reduced the error rate of the best model in this case." ] }, { @@ -1346,8 +1356,15 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "This stacking ensemble does not perform as well as the soft voting classifier we trained earlier, it's just as good as the best individual classifier." + "This stacking ensemble does not perform as well as the voting classifier we trained earlier, it's not quite as good as the best individual classifier." ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -1366,7 +1383,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.8" + "version": "3.7.4" }, "nav_menu": { "height": "252px",