From 37abd9c4d54dab8ecfb1418c803784c1380d1f8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Mon, 22 Nov 2021 10:19:04 +1300 Subject: [PATCH] Clarify the 'not in the book' comments, and rename decision_tree_instability_plot to decision_tree_high_variance_plot --- 05_decision_trees.ipynb | 69 ++++++----------------------------------- 1 file changed, 10 insertions(+), 59 deletions(-) diff --git a/05_decision_trees.ipynb b/05_decision_trees.ipynb index febe245..5ba2354 100644 --- a/05_decision_trees.ipynb +++ b/05_decision_trees.ipynb @@ -202,13 +202,6 @@ "# Making Predictions" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Code to generate Figure 5–2. Decision Tree decision boundaries**" - ] - }, { "cell_type": "code", "execution_count": 9, @@ -218,7 +211,7 @@ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", - "# not in the book\n", + "# not in the book – just formatting details\n", "from matplotlib.colors import ListedColormap\n", "custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])\n", "plt.figure(figsize=(8, 4))\n", @@ -231,7 +224,7 @@ " plt.plot(X_iris[:, 0][y_iris == idx], X_iris[:, 1][y_iris == idx],\n", " style, label=f\"Iris {name}\")\n", "\n", - "# not in the book\n", + "# not in the book – this section beautifies and saves Figure 5–2\n", "tree_clf_deeper = DecisionTreeClassifier(max_depth=3, random_state=42)\n", "tree_clf_deeper.fit(X_iris, y_iris)\n", "th0, th1, th2a, th2b = tree_clf_deeper.tree_.threshold[[0, 2, 3, 6]]\n", @@ -324,13 +317,6 @@ "# Regularization Hyperparameters" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Code to generate Figure 5–3. Regularization using min_samples_leaf:**" - ] - }, { "cell_type": "code", "execution_count": 14, @@ -353,7 +339,7 @@ "metadata": {}, "outputs": [], "source": [ - "# not in the book\n", + "# not in the book – this cell generates and saves Figure 5–3\n", "\n", "def plot_decision_boundary(clf, X, y, axes, cmap):\n", " x1, x2 = np.meshgrid(np.linspace(axes[0], axes[1], 100),\n", @@ -443,20 +429,13 @@ "tree_reg.fit(X_quad, y_quad)" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Code to generate Figure 5–4. A Decision Tree for regression:**" - ] - }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ - "# not in the book\n", + "# not in the book – we've already seen how to use export_graphviz()\n", "export_graphviz(\n", " tree_reg,\n", " out_file=str(IMAGES_PATH / \"regression_tree.dot\"),\n", @@ -477,13 +456,6 @@ "tree_reg2.fit(X_quad, y_quad)" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Code to generate Figure 5–5. Predictions of two Decision Tree regression models:**" - ] - }, { "cell_type": "code", "execution_count": 21, @@ -508,7 +480,7 @@ "metadata": {}, "outputs": [], "source": [ - "# not in the book\n", + "# not in the book – this cell generates and saves Figure 5–5\n", "\n", "def plot_regression_predictions(tree_reg, X, y, axes=[-0.5, 0.5, -0.05, 0.25]):\n", " x1 = np.linspace(axes[0], axes[1], 500).reshape(-1, 1)\n", @@ -546,20 +518,13 @@ "plt.show()" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Code to generate Figure 5–6. Regularizing a Decision Tree regressor:**" - ] - }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ - "# not in the book\n", + "# not in the book – this cell generates and saves Figure 5–6\n", "\n", "tree_reg1 = DecisionTreeRegressor(random_state=42)\n", "tree_reg2 = DecisionTreeRegressor(random_state=42, min_samples_leaf=10)\n", @@ -606,20 +571,13 @@ "Rotating the dataset also leads to completely different decision boundaries:" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Code to generate Figure 5–7. Sensitivity to training set rotation**" - ] - }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ - "# not in the book\n", + "# not in the book – this cell generates and saves Figure 5–7\n", "\n", "np.random.seed(6)\n", "X_square = np.random.rand(100, 2) - 0.5\n", @@ -670,7 +628,7 @@ "metadata": {}, "outputs": [], "source": [ - "# not in the book\n", + "# not in the book – this cell generates and saves Figure 5–8\n", "\n", "plt.figure(figsize=(8, 4))\n", "\n", @@ -727,20 +685,13 @@ "tree_clf_tweaked.fit(X_iris, y_iris)" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Code to generate Figure 5–8. Sensitivity to training set details:**" - ] - }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ - "# not in the book\n", + "# not in the book – this cell generates and saves Figure 5–9\n", "\n", "plt.figure(figsize=(8, 4))\n", "y_pred = tree_clf_tweaked.predict(X_iris_all).reshape(lengths.shape)\n", @@ -759,7 +710,7 @@ "plt.ylabel(\"Petal width (cm)\")\n", "plt.axis([0, 7.2, 0, 3])\n", "plt.legend()\n", - "save_fig(\"decision_tree_instability_plot\")\n", + "save_fig(\"decision_tree_high_variance_plot\")\n", "\n", "plt.show()" ]