Improve a few figures (e.g., add missing labels, share axes)
parent
23b6366c39
commit
7e6489f8a4
|
@ -249,13 +249,14 @@
|
|||
"deep_tree_clf1.fit(Xm, ym)\n",
|
||||
"deep_tree_clf2.fit(Xm, ym)\n",
|
||||
"\n",
|
||||
"plt.figure(figsize=(11, 4))\n",
|
||||
"plt.subplot(121)\n",
|
||||
"plot_decision_boundary(deep_tree_clf1, Xm, ym, axes=[-1.5, 2.5, -1, 1.5], iris=False)\n",
|
||||
"fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)\n",
|
||||
"plt.sca(axes[0])\n",
|
||||
"plot_decision_boundary(deep_tree_clf1, Xm, ym, axes=[-1.5, 2.4, -1, 1.5], iris=False)\n",
|
||||
"plt.title(\"No restrictions\", fontsize=16)\n",
|
||||
"plt.subplot(122)\n",
|
||||
"plot_decision_boundary(deep_tree_clf2, Xm, ym, axes=[-1.5, 2.5, -1, 1.5], iris=False)\n",
|
||||
"plt.sca(axes[1])\n",
|
||||
"plot_decision_boundary(deep_tree_clf2, Xm, ym, axes=[-1.5, 2.4, -1, 1.5], iris=False)\n",
|
||||
"plt.title(\"min_samples_leaf = {}\".format(deep_tree_clf2.min_samples_leaf), fontsize=14)\n",
|
||||
"plt.ylabel(\"\")\n",
|
||||
"\n",
|
||||
"save_fig(\"min_samples_leaf_plot\")\n",
|
||||
"plt.show()"
|
||||
|
@ -299,11 +300,12 @@
|
|||
"tree_clf_sr = DecisionTreeClassifier(random_state=42)\n",
|
||||
"tree_clf_sr.fit(Xsr, ys)\n",
|
||||
"\n",
|
||||
"plt.figure(figsize=(11, 4))\n",
|
||||
"plt.subplot(121)\n",
|
||||
"fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)\n",
|
||||
"plt.sca(axes[0])\n",
|
||||
"plot_decision_boundary(tree_clf_s, Xs, ys, axes=[-0.7, 0.7, -0.7, 0.7], iris=False)\n",
|
||||
"plt.subplot(122)\n",
|
||||
"plt.sca(axes[1])\n",
|
||||
"plot_decision_boundary(tree_clf_sr, Xsr, ys, axes=[-0.7, 0.7, -0.7, 0.7], iris=False)\n",
|
||||
"plt.ylabel(\"\")\n",
|
||||
"\n",
|
||||
"save_fig(\"sensitivity_to_rotation_plot\")\n",
|
||||
"plt.show()"
|
||||
|
@ -365,8 +367,8 @@
|
|||
" plt.plot(X, y, \"b.\")\n",
|
||||
" plt.plot(x1, y_pred, \"r.-\", linewidth=2, label=r\"$\\hat{y}$\")\n",
|
||||
"\n",
|
||||
"plt.figure(figsize=(11, 4))\n",
|
||||
"plt.subplot(121)\n",
|
||||
"fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)\n",
|
||||
"plt.sca(axes[0])\n",
|
||||
"plot_regression_predictions(tree_reg1, X, y)\n",
|
||||
"for split, style in ((0.1973, \"k-\"), (0.0917, \"k--\"), (0.7718, \"k--\")):\n",
|
||||
" plt.plot([split, split], [-0.2, 1], style, linewidth=2)\n",
|
||||
|
@ -376,7 +378,7 @@
|
|||
"plt.legend(loc=\"upper center\", fontsize=18)\n",
|
||||
"plt.title(\"max_depth=2\", fontsize=14)\n",
|
||||
"\n",
|
||||
"plt.subplot(122)\n",
|
||||
"plt.sca(axes[1])\n",
|
||||
"plot_regression_predictions(tree_reg2, X, y, ylabel=None)\n",
|
||||
"for split, style in ((0.1973, \"k-\"), (0.0917, \"k--\"), (0.7718, \"k--\")):\n",
|
||||
" plt.plot([split, split], [-0.2, 1], style, linewidth=2)\n",
|
||||
|
@ -428,9 +430,9 @@
|
|||
"y_pred1 = tree_reg1.predict(x1)\n",
|
||||
"y_pred2 = tree_reg2.predict(x1)\n",
|
||||
"\n",
|
||||
"plt.figure(figsize=(11, 4))\n",
|
||||
"fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)\n",
|
||||
"\n",
|
||||
"plt.subplot(121)\n",
|
||||
"plt.sca(axes[0])\n",
|
||||
"plt.plot(X, y, \"b.\")\n",
|
||||
"plt.plot(x1, y_pred1, \"r.-\", linewidth=2, label=r\"$\\hat{y}$\")\n",
|
||||
"plt.axis([0, 1, -0.2, 1.1])\n",
|
||||
|
@ -439,7 +441,7 @@
|
|||
"plt.legend(loc=\"upper center\", fontsize=18)\n",
|
||||
"plt.title(\"No restrictions\", fontsize=14)\n",
|
||||
"\n",
|
||||
"plt.subplot(122)\n",
|
||||
"plt.sca(axes[1])\n",
|
||||
"plt.plot(X, y, \"b.\")\n",
|
||||
"plt.plot(x1, y_pred2, \"r.-\", linewidth=2, label=r\"$\\hat{y}$\")\n",
|
||||
"plt.axis([0, 1, -0.2, 1.1])\n",
|
||||
|
@ -720,7 +722,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.8"
|
||||
"version": "3.7.4"
|
||||
},
|
||||
"nav_menu": {
|
||||
"height": "309px",
|
||||
|
|
|
@ -260,7 +260,7 @@
|
|||
"source": [
|
||||
"from matplotlib.colors import ListedColormap\n",
|
||||
"\n",
|
||||
"def plot_decision_boundary(clf, X, y, axes=[-1.5, 2.5, -1, 1.5], alpha=0.5, contour=True):\n",
|
||||
"def plot_decision_boundary(clf, X, y, axes=[-1.5, 2.45, -1, 1.5], alpha=0.5, contour=True):\n",
|
||||
" x1s = np.linspace(axes[0], axes[1], 100)\n",
|
||||
" x2s = np.linspace(axes[2], axes[3], 100)\n",
|
||||
" x1, x2 = np.meshgrid(x1s, x2s)\n",
|
||||
|
@ -284,13 +284,14 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"plt.figure(figsize=(11,4))\n",
|
||||
"plt.subplot(121)\n",
|
||||
"fix, axes = plt.subplots(ncols=2, figsize=(10,4), sharey=True)\n",
|
||||
"plt.sca(axes[0])\n",
|
||||
"plot_decision_boundary(tree_clf, X, y)\n",
|
||||
"plt.title(\"Decision Tree\", fontsize=14)\n",
|
||||
"plt.subplot(122)\n",
|
||||
"plt.sca(axes[1])\n",
|
||||
"plot_decision_boundary(bag_clf, X, y)\n",
|
||||
"plt.title(\"Decision Trees with Bagging\", fontsize=14)\n",
|
||||
"plt.ylabel(\"\")\n",
|
||||
"save_fig(\"decision_tree_without_and_with_bagging_plot\")\n",
|
||||
"plt.show()"
|
||||
]
|
||||
|
@ -381,7 +382,7 @@
|
|||
" tree_clf = DecisionTreeClassifier(max_leaf_nodes=16, random_state=42 + i)\n",
|
||||
" indices_with_replacement = np.random.randint(0, len(X_train), len(X_train))\n",
|
||||
" tree_clf.fit(X[indices_with_replacement], y[indices_with_replacement])\n",
|
||||
" plot_decision_boundary(tree_clf, X, y, axes=[-1.5, 2.5, -1, 1.5], alpha=0.02, contour=False)\n",
|
||||
" plot_decision_boundary(tree_clf, X, y, axes=[-1.5, 2.45, -1, 1.5], alpha=0.02, contour=False)\n",
|
||||
"\n",
|
||||
"plt.show()"
|
||||
]
|
||||
|
@ -521,10 +522,10 @@
|
|||
"source": [
|
||||
"m = len(X_train)\n",
|
||||
"\n",
|
||||
"plt.figure(figsize=(11, 4))\n",
|
||||
"for subplot, learning_rate in ((121, 1), (122, 0.5)):\n",
|
||||
"fix, axes = plt.subplots(ncols=2, figsize=(10,4), sharey=True)\n",
|
||||
"for subplot, learning_rate in ((0, 1), (1, 0.5)):\n",
|
||||
" sample_weights = np.ones(m)\n",
|
||||
" plt.subplot(subplot)\n",
|
||||
" plt.sca(axes[subplot])\n",
|
||||
" for i in range(5):\n",
|
||||
" svm_clf = SVC(kernel=\"rbf\", C=0.05, gamma=\"scale\", random_state=42)\n",
|
||||
" svm_clf.fit(X_train, y_train, sample_weight=sample_weights)\n",
|
||||
|
@ -532,12 +533,14 @@
|
|||
" sample_weights[y_pred != y_train] *= (1 + learning_rate)\n",
|
||||
" plot_decision_boundary(svm_clf, X, y, alpha=0.2)\n",
|
||||
" plt.title(\"learning_rate = {}\".format(learning_rate), fontsize=16)\n",
|
||||
" if subplot == 121:\n",
|
||||
" if subplot == 0:\n",
|
||||
" plt.text(-0.7, -0.65, \"1\", fontsize=14)\n",
|
||||
" plt.text(-0.6, -0.10, \"2\", fontsize=14)\n",
|
||||
" plt.text(-0.5, 0.10, \"3\", fontsize=14)\n",
|
||||
" plt.text(-0.4, 0.55, \"4\", fontsize=14)\n",
|
||||
" plt.text(-0.3, 0.90, \"5\", fontsize=14)\n",
|
||||
" else:\n",
|
||||
" plt.ylabel(\"\")\n",
|
||||
"\n",
|
||||
"save_fig(\"boosting_plot\")\n",
|
||||
"plt.show()"
|
||||
|
@ -715,15 +718,18 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"plt.figure(figsize=(11,4))\n",
|
||||
"fix, axes = plt.subplots(ncols=2, figsize=(10,4), sharey=True)\n",
|
||||
"\n",
|
||||
"plt.subplot(121)\n",
|
||||
"plt.sca(axes[0])\n",
|
||||
"plot_predictions([gbrt], X, y, axes=[-0.5, 0.5, -0.1, 0.8], label=\"Ensemble predictions\")\n",
|
||||
"plt.title(\"learning_rate={}, n_estimators={}\".format(gbrt.learning_rate, gbrt.n_estimators), fontsize=14)\n",
|
||||
"plt.xlabel(\"$x_1$\", fontsize=16)\n",
|
||||
"plt.ylabel(\"$y$\", fontsize=16, rotation=0)\n",
|
||||
"\n",
|
||||
"plt.subplot(122)\n",
|
||||
"plt.sca(axes[1])\n",
|
||||
"plot_predictions([gbrt_slow], X, y, axes=[-0.5, 0.5, -0.1, 0.8])\n",
|
||||
"plt.title(\"learning_rate={}, n_estimators={}\".format(gbrt_slow.learning_rate, gbrt_slow.n_estimators), fontsize=14)\n",
|
||||
"plt.xlabel(\"$x_1$\", fontsize=16)\n",
|
||||
"\n",
|
||||
"save_fig(\"gbrt_learning_rate_plot\")\n",
|
||||
"plt.show()"
|
||||
|
@ -774,7 +780,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"plt.figure(figsize=(11, 4))\n",
|
||||
"plt.figure(figsize=(10, 4))\n",
|
||||
"\n",
|
||||
"plt.subplot(121)\n",
|
||||
"plt.plot(errors, \"b.-\")\n",
|
||||
|
@ -784,11 +790,14 @@
|
|||
"plt.text(bst_n_estimators, min_error*1.2, \"Minimum\", ha=\"center\", fontsize=14)\n",
|
||||
"plt.axis([0, 120, 0, 0.01])\n",
|
||||
"plt.xlabel(\"Number of trees\")\n",
|
||||
"plt.ylabel(\"Error\", fontsize=16)\n",
|
||||
"plt.title(\"Validation error\", fontsize=14)\n",
|
||||
"\n",
|
||||
"plt.subplot(122)\n",
|
||||
"plot_predictions([gbrt_best], X, y, axes=[-0.5, 0.5, -0.1, 0.8])\n",
|
||||
"plt.title(\"Best model (%d trees)\" % bst_n_estimators, fontsize=14)\n",
|
||||
"plt.ylabel(\"$y$\", fontsize=16, rotation=0)\n",
|
||||
"plt.xlabel(\"$x_1$\", fontsize=16)\n",
|
||||
"\n",
|
||||
"save_fig(\"early_stopping_gbrt_plot\")\n",
|
||||
"plt.show()"
|
||||
|
@ -1200,7 +1209,7 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"That's a significant improvement, and it's much better than each of the individual classifiers."
|
||||
"Nope, hard voting wins in this case."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1216,6 +1225,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"voting_clf.voting = \"hard\"\n",
|
||||
"voting_clf.score(X_test, y_test)"
|
||||
]
|
||||
},
|
||||
|
@ -1232,7 +1242,7 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The voting classifier reduced the error rate from about 4.0% for our best model (the `MLPClassifier`) to just 3.1%. That's about 22.5% less errors, not bad!"
|
||||
"The voting classifier only very slightly reduced the error rate of the best model in this case."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1346,8 +1356,15 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This stacking ensemble does not perform as well as the soft voting classifier we trained earlier, it's just as good as the best individual classifier."
|
||||
"This stacking ensemble does not perform as well as the voting classifier we trained earlier, it's not quite as good as the best individual classifier."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
@ -1366,7 +1383,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.8"
|
||||
"version": "3.7.4"
|
||||
},
|
||||
"nav_menu": {
|
||||
"height": "252px",
|
||||
|
|
Loading…
Reference in New Issue