From 787a34b5748a5d7a6eceafd7cc3288fb0ff3c27e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Mon, 8 Nov 2021 17:24:30 +1300 Subject: [PATCH] Big update for the 3rd edition of the book --- 05_decision_trees.ipynb | 1110 ++++++++++++++++++++++++++++----------- 1 file changed, 813 insertions(+), 297 deletions(-) diff --git a/05_decision_trees.ipynb b/05_decision_trees.ipynb index d7c59af..83e7659 100644 --- a/05_decision_trees.ipynb +++ b/05_decision_trees.ipynb @@ -30,7 +30,9 @@ }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "tags": [] + }, "source": [ "# Setup" ] @@ -39,7 +41,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures." + "This project requires Python 3.8 or above:" ] }, { @@ -48,30 +50,64 @@ "metadata": {}, "outputs": [], "source": [ - "# Python ≥3.8 is required\n", "import sys\n", - "assert sys.version_info >= (3, 8)\n", "\n", - "# Scikit-Learn ≥1.0 is required\n", + "assert sys.version_info >= (3, 8)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It also requires Scikit-Learn ≥ 1.0.1:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ "import sklearn\n", - "assert sklearn.__version__ >= \"1.0\"\n", "\n", - "# Common imports\n", - "import numpy as np\n", + "assert sklearn.__version__ >= \"1.0.1\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As we did in previous chapters, let's define the default font sizes to make the figures prettier:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib as mpl\n", + "\n", + "mpl.rc('font', size=12)\n", + "mpl.rc('axes', labelsize=14, titlesize=14)\n", + "mpl.rc('legend', fontsize=14)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": 
[ + "And let's create the `images/decision_trees` folder (if it doesn't already exist), and define the `save_fig()` function which is used throughout this notebook to save the figures in high-res for the book:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ "from pathlib import Path\n", "\n", - "# to make this notebook's output stable across runs\n", - "np.random.seed(42)\n", - "\n", - "# To plot pretty figures\n", - "%matplotlib inline\n", - "import matplotlib as mpl\n", - "import matplotlib.pyplot as plt\n", - "mpl.rc('axes', labelsize=14)\n", - "mpl.rc('xtick', labelsize=12)\n", - "mpl.rc('ytick', labelsize=12)\n", - "\n", - "# Where to save the figures\n", "IMAGES_PATH = Path() / \"images\" / \"decision_trees\"\n", "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n", "\n", @@ -91,110 +127,171 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import load_iris\n", "from sklearn.tree import DecisionTreeClassifier\n", "\n", - "iris = load_iris()\n", - "X = iris.data[:, 2:] # petal length and width\n", - "y = iris.target\n", + "iris = load_iris(as_frame=True)\n", + "X_iris = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n", + "y_iris = iris.target\n", "\n", "tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)\n", - "tree_clf.fit(X, y)" + "tree_clf.fit(X_iris, y_iris)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "**This code example generates Figure 6–1. Iris Decision Tree:**" + "**This code example generates Figure 5–1. 
Iris Decision Tree:**" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ - "from graphviz import Source\n", "from sklearn.tree import export_graphviz\n", "\n", "export_graphviz(\n", " tree_clf,\n", - " out_file=IMAGES_PATH / \"iris_tree.dot\",\n", - " feature_names=iris.feature_names[2:],\n", + " out_file=str(IMAGES_PATH / \"iris_tree.dot\"), # path differs in the book\n", + " feature_names=[\"petal length (cm)\", \"petal width (cm)\"],\n", " class_names=iris.target_names,\n", " rounded=True,\n", " filled=True\n", - " )\n", - "\n", - "Source.from_file(IMAGES_PATH / \"iris_tree.dot\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Making Predictions" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Code to generate Figure 6–2. Decision Tree decision boundaries**" + " )" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ + "from graphviz import Source\n", + "\n", + "Source.from_file(IMAGES_PATH / \"iris_tree.dot\") # path differs in the book" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Graphviz also provides the `dot` command line tool to convert `.dot` files to a variety of formats. The following command converts the dot file to a png image:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "# not in the book\n", + "!dot -Tpng {IMAGES_PATH / \"iris_tree.dot\"} -o {IMAGES_PATH / \"iris_tree.png\"}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Making Predictions" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Code to generate Figure 5–2. 
Decision Tree decision boundaries**" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# not in the book\n", "from matplotlib.colors import ListedColormap\n", - "\n", - "def plot_decision_boundary(clf, X, y, axes=[0, 7.5, 0, 3], iris=True, legend=False, plot_training=True):\n", - " x1s = np.linspace(axes[0], axes[1], 100)\n", - " x2s = np.linspace(axes[2], axes[3], 100)\n", - " x1, x2 = np.meshgrid(x1s, x2s)\n", - " X_new = np.c_[x1.ravel(), x2.ravel()]\n", - " y_pred = clf.predict(X_new).reshape(x1.shape)\n", - " custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])\n", - " plt.contourf(x1, x2, y_pred, alpha=0.3, cmap=custom_cmap)\n", - " if not iris:\n", - " custom_cmap2 = ListedColormap(['#7d7d58','#4c4c7f','#507d50'])\n", - " plt.contour(x1, x2, y_pred, cmap=custom_cmap2, alpha=0.8)\n", - " if plot_training:\n", - " plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"yo\", label=\"Iris setosa\")\n", - " plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"bs\", label=\"Iris versicolor\")\n", - " plt.plot(X[:, 0][y==2], X[:, 1][y==2], \"g^\", label=\"Iris virginica\")\n", - " plt.axis(axes)\n", - " if iris:\n", - " plt.xlabel(\"Petal length\", fontsize=14)\n", - " plt.ylabel(\"Petal width\", fontsize=14)\n", - " else:\n", - " plt.xlabel(r\"$x_1$\", fontsize=18)\n", - " plt.ylabel(r\"$x_2$\", fontsize=18, rotation=0)\n", - " if legend:\n", - " plt.legend(loc=\"lower right\", fontsize=14)\n", - "\n", + "custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])\n", "plt.figure(figsize=(8, 4))\n", - "plot_decision_boundary(tree_clf, X, y)\n", - "plt.plot([2.45, 2.45], [0, 3], \"k-\", linewidth=2)\n", - "plt.plot([2.45, 7.5], [1.75, 1.75], \"k--\", linewidth=2)\n", - "plt.plot([4.95, 4.95], [0, 1.75], \"k:\", linewidth=2)\n", - "plt.plot([4.85, 4.85], [1.75, 3], \"k:\", linewidth=2)\n", - "plt.text(1.40, 1.0, \"Depth=0\", fontsize=15)\n", - 
"plt.text(3.2, 1.80, \"Depth=1\", fontsize=13)\n", - "plt.text(4.05, 0.5, \"(Depth=2)\", fontsize=11)\n", "\n", + "lengths, widths = np.meshgrid(np.linspace(0, 7.2, 100), np.linspace(0, 3, 100))\n", + "X_iris_all = np.c_[lengths.ravel(), widths.ravel()]\n", + "y_pred = tree_clf.predict(X_iris_all).reshape(lengths.shape)\n", + "plt.contourf(lengths, widths, y_pred, alpha=0.3, cmap=custom_cmap)\n", + "for idx, (name, style) in enumerate(zip(iris.target_names, (\"yo\", \"bs\", \"g^\"))):\n", + " plt.plot(X_iris[:, 0][y_iris == idx], X_iris[:, 1][y_iris == idx],\n", + " style, label=f\"Iris {name}\")\n", + "\n", + "# not in the book\n", + "tree_clf_deeper = DecisionTreeClassifier(max_depth=3, random_state=42)\n", + "tree_clf_deeper.fit(X_iris, y_iris)\n", + "th0, th1, th2a, th2b = tree_clf_deeper.tree_.threshold[[0, 2, 3, 6]]\n", + "plt.xlabel(\"Petal length (cm)\")\n", + "plt.ylabel(\"Petal width (cm)\")\n", + "plt.plot([th0, th0], [0, 3], \"k-\", linewidth=2)\n", + "plt.plot([th0, 7.2], [th1, th1], \"k--\", linewidth=2)\n", + "plt.plot([th2a, th2a], [0, th1], \"k:\", linewidth=2)\n", + "plt.plot([th2b, th2b], [th1, 3], \"k:\", linewidth=2)\n", + "plt.text(th0 - 0.05, 1.0, \"Depth=0\", horizontalalignment=\"right\", fontsize=15)\n", + "plt.text(3.2, th1 + 0.02, \"Depth=1\", verticalalignment=\"bottom\", fontsize=13)\n", + "plt.text(th2a + 0.05, 0.5, \"(Depth=2)\", fontsize=11)\n", + "plt.axis([0, 7.2, 0, 3])\n", + "plt.legend()\n", "save_fig(\"decision_tree_decision_boundaries_plot\")\n", + "\n", "plt.show()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can access the tree structure via the `tree_` attribute:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "tree_clf.tree_" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For more information, check out this class's documentation:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + 
"metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# help(sklearn.tree._tree.Tree)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "See the extra material section below for an example." + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -204,16 +301,16 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ - "tree_clf.predict_proba([[5, 1.5]])" + "tree_clf.predict_proba([[5, 1.5]]).round(3)" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -224,146 +321,89 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Regularization Hyperparameters" + "# Regularization Hyperparameters" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "We've seen that small changes in the dataset (such as a rotation) may produce a very different Decision Tree.\n", - "Now let's show that training the same model on the same data may produce a very different model every time, since the CART training algorithm used by Scikit-Learn is stochastic. To show this, we will set `random_state` to a different value than earlier:" + "**Code to generate Figure 5–3. Regularization using min_samples_leaf:**" ] }, { "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "tree_clf_tweaked = DecisionTreeClassifier(max_depth=2, random_state=40)\n", - "tree_clf_tweaked.fit(X, y)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Code to generate Figure 6–8. 
Sensitivity to training set details:**" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "plt.figure(figsize=(8, 4))\n", - "plot_decision_boundary(tree_clf_tweaked, X, y, legend=False)\n", - "plt.plot([0, 7.5], [0.8, 0.8], \"k-\", linewidth=2)\n", - "plt.plot([0, 7.5], [1.75, 1.75], \"k--\", linewidth=2)\n", - "plt.text(1.0, 0.9, \"Depth=0\", fontsize=15)\n", - "plt.text(1.0, 1.80, \"Depth=1\", fontsize=13)\n", - "\n", - "save_fig(\"decision_tree_instability_plot\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Code to generate Figure 6–3. Regularization using min_samples_leaf:**" - ] - }, - { - "cell_type": "code", - "execution_count": 9, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import make_moons\n", - "Xm, ym = make_moons(n_samples=100, noise=0.25, random_state=53)\n", "\n", - "deep_tree_clf1 = DecisionTreeClassifier(random_state=42)\n", - "deep_tree_clf2 = DecisionTreeClassifier(min_samples_leaf=4, random_state=42)\n", - "deep_tree_clf1.fit(Xm, ym)\n", - "deep_tree_clf2.fit(Xm, ym)\n", + "X_moons, y_moons = make_moons(n_samples=150, noise=0.2, random_state=42)\n", + "\n", + "tree_clf1 = DecisionTreeClassifier(random_state=42)\n", + "tree_clf2 = DecisionTreeClassifier(min_samples_leaf=5, random_state=42)\n", + "tree_clf1.fit(X_moons, y_moons)\n", + "tree_clf2.fit(X_moons, y_moons)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "# not in the book\n", + "\n", + "def plot_decision_boundary(clf, X, y, axes, cmap):\n", + " x1, x2 = np.meshgrid(np.linspace(axes[0], axes[1], 100),\n", + " np.linspace(axes[2], axes[3], 100))\n", + " X_new = np.c_[x1.ravel(), x2.ravel()]\n", + " y_pred = clf.predict(X_new).reshape(x1.shape)\n", + " \n", + " plt.contourf(x1, x2, y_pred, alpha=0.3, cmap=cmap)\n", + " plt.contour(x1, x2, y_pred, cmap=\"Greys\", 
alpha=0.8)\n", + " colors = {\"Wistia\": [\"blue\", \"black\"], \"Pastel1\": [\"red\", \"blue\"]}[cmap]\n", + " markers = (\"s\", \"^\")\n", + " for idx in (0, 1):\n", + " plt.plot(X[:, 0][y == idx], X[:, 1][y == idx],\n", + " color=colors[idx], marker=markers[idx], linestyle=\"none\")\n", + " plt.axis(axes)\n", + " plt.xlabel(r\"$x_1$\")\n", + " plt.ylabel(r\"$x_2$\", rotation=0)\n", "\n", "fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)\n", "plt.sca(axes[0])\n", - "plot_decision_boundary(deep_tree_clf1, Xm, ym, axes=[-1.5, 2.4, -1, 1.5], iris=False)\n", - "plt.title(\"No restrictions\", fontsize=16)\n", + "plot_decision_boundary(tree_clf1, X_moons, y_moons,\n", + " axes=[-1.5, 2.4, -1, 1.5], cmap=\"Wistia\")\n", + "plt.title(\"No restrictions\")\n", "plt.sca(axes[1])\n", - "plot_decision_boundary(deep_tree_clf2, Xm, ym, axes=[-1.5, 2.4, -1, 1.5], iris=False)\n", - "plt.title(\"min_samples_leaf = {}\".format(deep_tree_clf2.min_samples_leaf), fontsize=14)\n", + "plot_decision_boundary(tree_clf2, X_moons, y_moons,\n", + " axes=[-1.5, 2.4, -1, 1.5], cmap=\"Wistia\")\n", + "plt.title(f\"min_samples_leaf = {tree_clf2.min_samples_leaf}\")\n", "plt.ylabel(\"\")\n", - "\n", "save_fig(\"min_samples_leaf_plot\")\n", "plt.show()" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 16, "metadata": {}, + "outputs": [], "source": [ - "Rotating the dataset also leads to completely different decision boundaries:" + "X_moons_test, y_moons_test = make_moons(n_samples=1000, noise=0.2,\n", + " random_state=43)\n", + "tree_clf1.score(X_moons_test, y_moons_test)" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ - "angle = np.pi / 180 * 20\n", - "rotation_matrix = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]])\n", - "Xr = X.dot(rotation_matrix)\n", - "\n", - "tree_clf_r = DecisionTreeClassifier(random_state=42)\n", - "tree_clf_r.fit(Xr, y)\n", - 
"\n", - "plt.figure(figsize=(8, 3))\n", - "plot_decision_boundary(tree_clf_r, Xr, y, axes=[0.5, 7.5, -1.0, 1], iris=False)\n", - "\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Code to generate Figure 6–7. Sensitivity to training set rotation**" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "np.random.seed(6)\n", - "Xs = np.random.rand(100, 2) - 0.5\n", - "ys = (Xs[:, 0] > 0).astype(np.float32) * 2\n", - "\n", - "angle = np.pi / 4\n", - "rotation_matrix = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]])\n", - "Xsr = Xs.dot(rotation_matrix)\n", - "\n", - "tree_clf_s = DecisionTreeClassifier(random_state=42)\n", - "tree_clf_s.fit(Xs, ys)\n", - "tree_clf_sr = DecisionTreeClassifier(random_state=42)\n", - "tree_clf_sr.fit(Xsr, ys)\n", - "\n", - "fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)\n", - "plt.sca(axes[0])\n", - "plot_decision_boundary(tree_clf_s, Xs, ys, axes=[-0.7, 0.7, -0.7, 0.7], iris=False)\n", - "plt.sca(axes[1])\n", - "plot_decision_boundary(tree_clf_sr, Xsr, ys, axes=[-0.7, 0.7, -0.7, 0.7], iris=False)\n", - "plt.ylabel(\"\")\n", - "\n", - "save_fig(\"sensitivity_to_rotation_plot\")\n", - "plt.show()" + "tree_clf2.score(X_moons_test, y_moons_test)" ] }, { @@ -377,21 +417,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Let's prepare a simple linear dataset:" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "# Quadratic training set + noise\n", - "np.random.seed(42)\n", - "m = 200\n", - "X = np.random.rand(m, 1)\n", - "y = 4 * (X - 0.5) ** 2\n", - "y = y + np.random.randn(m, 1) / 10" + "Let's prepare a simple quadratic training set:" ] }, { @@ -403,65 +429,118 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "from sklearn.tree import 
DecisionTreeRegressor\n", "\n", + "np.random.seed(42)\n", + "X_quad = np.random.rand(200, 1) - 0.5 # a single random input feature\n", + "y_quad = X_quad ** 2 + 0.025 * np.random.randn(200, 1)\n", + "\n", "tree_reg = DecisionTreeRegressor(max_depth=2, random_state=42)\n", - "tree_reg.fit(X, y)" + "tree_reg.fit(X_quad, y_quad)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "**Code to generate Figure 6–5. Predictions of two Decision Tree regression models:**" + "**Code to generate Figure 5–4. A Decision Tree for regression:**" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ - "from sklearn.tree import DecisionTreeRegressor\n", + "# not in the book\n", + "export_graphviz(\n", + " tree_reg,\n", + " out_file=str(IMAGES_PATH / \"regression_tree.dot\"),\n", + " feature_names=[\"x1\"],\n", + " rounded=True,\n", + " filled=True\n", + ")\n", + "Source.from_file(IMAGES_PATH / \"regression_tree.dot\")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "tree_reg2 = DecisionTreeRegressor(max_depth=3, random_state=42)\n", + "tree_reg2.fit(X_quad, y_quad)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Code to generate Figure 5–5. 
Predictions of two Decision Tree regression models:**" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "tree_reg.tree_.threshold" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "tree_reg2.tree_.threshold" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "# not in the book\n", "\n", - "tree_reg1 = DecisionTreeRegressor(random_state=42, max_depth=2)\n", - "tree_reg2 = DecisionTreeRegressor(random_state=42, max_depth=3)\n", - "tree_reg1.fit(X, y)\n", - "tree_reg2.fit(X, y)\n", - "\n", - "def plot_regression_predictions(tree_reg, X, y, axes=[0, 1, -0.2, 1], ylabel=\"$y$\"):\n", + "def plot_regression_predictions(tree_reg, X, y, axes=[-0.5, 0.5, -0.05, 0.25]):\n", " x1 = np.linspace(axes[0], axes[1], 500).reshape(-1, 1)\n", " y_pred = tree_reg.predict(x1)\n", " plt.axis(axes)\n", - " plt.xlabel(\"$x_1$\", fontsize=18)\n", - " if ylabel:\n", - " plt.ylabel(ylabel, fontsize=18, rotation=0)\n", + " plt.xlabel(\"$x_1$\")\n", " plt.plot(X, y, \"b.\")\n", " plt.plot(x1, y_pred, \"r.-\", linewidth=2, label=r\"$\\hat{y}$\")\n", "\n", "fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)\n", "plt.sca(axes[0])\n", - "plot_regression_predictions(tree_reg1, X, y)\n", - "for split, style in ((0.1973, \"k-\"), (0.0917, \"k--\"), (0.7718, \"k--\")):\n", - " plt.plot([split, split], [-0.2, 1], style, linewidth=2)\n", - "plt.text(0.21, 0.65, \"Depth=0\", fontsize=15)\n", - "plt.text(0.01, 0.2, \"Depth=1\", fontsize=13)\n", - "plt.text(0.65, 0.8, \"Depth=1\", fontsize=13)\n", - "plt.legend(loc=\"upper center\", fontsize=18)\n", - "plt.title(\"max_depth=2\", fontsize=14)\n", + "plot_regression_predictions(tree_reg, X_quad, y_quad)\n", + "\n", + "th0, th1a, th1b = tree_reg.tree_.threshold[[0, 1, 4]]\n", + "for split, style in ((th0, \"k-\"), (th1a, \"k--\"), (th1b, \"k--\")):\n", + " 
plt.plot([split, split], [-0.05, 0.25], style, linewidth=2)\n", + "plt.text(th0, 0.16, \"Depth=0\", fontsize=15)\n", + "plt.text(th1a + 0.01, -0.01, \"Depth=1\", horizontalalignment=\"center\", fontsize=13)\n", + "plt.text(th1b + 0.01, -0.01, \"Depth=1\", fontsize=13)\n", + "plt.ylabel(\"$y$\", rotation=0)\n", + "plt.legend(loc=\"upper center\", fontsize=16)\n", + "plt.title(\"max_depth=2\")\n", "\n", "plt.sca(axes[1])\n", - "plot_regression_predictions(tree_reg2, X, y, ylabel=None)\n", - "for split, style in ((0.1973, \"k-\"), (0.0917, \"k--\"), (0.7718, \"k--\")):\n", - " plt.plot([split, split], [-0.2, 1], style, linewidth=2)\n", - "for split in (0.0458, 0.1298, 0.2873, 0.9040):\n", - " plt.plot([split, split], [-0.2, 1], \"k:\", linewidth=1)\n", - "plt.text(0.3, 0.5, \"Depth=2\", fontsize=13)\n", - "plt.title(\"max_depth=3\", fontsize=14)\n", + "th2s = tree_reg2.tree_.threshold[[2, 5, 9, 12]]\n", + "plot_regression_predictions(tree_reg2, X_quad, y_quad)\n", + "for split, style in ((th0, \"k-\"), (th1a, \"k--\"), (th1b, \"k--\")):\n", + " plt.plot([split, split], [-0.05, 0.25], style, linewidth=2)\n", + "for split in th2s:\n", + " plt.plot([split, split], [-0.05, 0.25], \"k:\", linewidth=1)\n", + "plt.text(th2s[2] + 0.01, 0.15, \"Depth=2\", fontsize=13)\n", + "plt.title(\"max_depth=3\")\n", "\n", "save_fig(\"tree_regression_plot\")\n", "plt.show()" @@ -471,77 +550,497 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "**Code to generate Figure 6-4. A Decision Tree for regression:**" + "**Code to generate Figure 5–6. 
Regularizing a Decision Tree regressor:**" ] }, { "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "export_graphviz(\n", - " tree_reg1,\n", - " out_file=IMAGES_PATH / \"regression_tree.dot\",\n", - " feature_names=[\"x1\"],\n", - " rounded=True,\n", - " filled=True\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "Source.from_file(IMAGES_PATH / \"regression_tree.dot\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Code to generate Figure 6–6. Regularizing a Decision Tree regressor:**" - ] - }, - { - "cell_type": "code", - "execution_count": 17, + "execution_count": 24, "metadata": {}, "outputs": [], "source": [ + "# not in the book\n", + "\n", "tree_reg1 = DecisionTreeRegressor(random_state=42)\n", "tree_reg2 = DecisionTreeRegressor(random_state=42, min_samples_leaf=10)\n", - "tree_reg1.fit(X, y)\n", - "tree_reg2.fit(X, y)\n", + "tree_reg1.fit(X_quad, y_quad)\n", + "tree_reg2.fit(X_quad, y_quad)\n", "\n", - "x1 = np.linspace(0, 1, 500).reshape(-1, 1)\n", + "x1 = np.linspace(-0.5, 0.5, 500).reshape(-1, 1)\n", "y_pred1 = tree_reg1.predict(x1)\n", "y_pred2 = tree_reg2.predict(x1)\n", "\n", "fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)\n", "\n", "plt.sca(axes[0])\n", - "plt.plot(X, y, \"b.\")\n", + "plt.plot(X_quad, y_quad, \"b.\")\n", "plt.plot(x1, y_pred1, \"r.-\", linewidth=2, label=r\"$\\hat{y}$\")\n", - "plt.axis([0, 1, -0.2, 1.1])\n", - "plt.xlabel(\"$x_1$\", fontsize=18)\n", - "plt.ylabel(\"$y$\", fontsize=18, rotation=0)\n", - "plt.legend(loc=\"upper center\", fontsize=18)\n", - "plt.title(\"No restrictions\", fontsize=14)\n", + "plt.axis([-0.5, 0.5, -0.05, 0.25])\n", + "plt.xlabel(\"$x_1$\")\n", + "plt.ylabel(\"$y$\", rotation=0)\n", + "plt.legend(loc=\"upper center\")\n", + "plt.title(\"No restrictions\")\n", "\n", "plt.sca(axes[1])\n", - "plt.plot(X, y, \"b.\")\n", + "plt.plot(X_quad, 
y_quad, \"b.\")\n", "plt.plot(x1, y_pred2, \"r.-\", linewidth=2, label=r\"$\\hat{y}$\")\n", - "plt.axis([0, 1, -0.2, 1.1])\n", - "plt.xlabel(\"$x_1$\", fontsize=18)\n", - "plt.title(\"min_samples_leaf={}\".format(tree_reg2.min_samples_leaf), fontsize=14)\n", + "plt.axis([-0.5, 0.5, -0.05, 0.25])\n", + "plt.xlabel(\"$x_1$\")\n", + "plt.title(f\"min_samples_leaf={tree_reg2.min_samples_leaf}\")\n", "\n", "save_fig(\"tree_regression_regularization_plot\")\n", "plt.show()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Sensitivity to axis orientation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Rotating the dataset also leads to completely different decision boundaries:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Code to generate Figure 5–7. Sensitivity to training set rotation**" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "# not in the book\n", + "\n", + "np.random.seed(6)\n", + "X_square = np.random.rand(100, 2) - 0.5\n", + "y_square = (X_square[:, 0] > 0).astype(np.int64)\n", + "\n", + "angle = np.pi / 4 # 45 degrees\n", + "rotation_matrix = np.array([[np.cos(angle), -np.sin(angle)],\n", + " [np.sin(angle), np.cos(angle)]])\n", + "X_rotated_square = X_square.dot(rotation_matrix)\n", + "\n", + "tree_clf_square = DecisionTreeClassifier(random_state=42)\n", + "tree_clf_square.fit(X_square, y_square)\n", + "tree_clf_rotated_square = DecisionTreeClassifier(random_state=42)\n", + "tree_clf_rotated_square.fit(X_rotated_square, y_square)\n", + "\n", + "fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)\n", + "plt.sca(axes[0])\n", + "plot_decision_boundary(tree_clf_square, X_square, y_square,\n", + " axes=[-0.7, 0.7, -0.7, 0.7], cmap=\"Pastel1\")\n", + "plt.sca(axes[1])\n", + "plot_decision_boundary(tree_clf_rotated_square, X_rotated_square, y_square,\n", + " axes=[-0.7, 0.7, -0.7, 0.7], 
cmap=\"Pastel1\")\n", + "plt.ylabel(\"\")\n", + "\n", + "save_fig(\"sensitivity_to_rotation_plot\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.decomposition import PCA\n", + "from sklearn.pipeline import make_pipeline\n", + "from sklearn.preprocessing import StandardScaler\n", + "\n", + "pca_pipeline = make_pipeline(StandardScaler(), PCA())\n", + "X_iris_rotated = pca_pipeline.fit_transform(X_iris)\n", + "tree_clf_pca = DecisionTreeClassifier(max_depth=2, random_state=42)\n", + "tree_clf_pca.fit(X_iris_rotated, y_iris)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "# not in the book\n", + "\n", + "plt.figure(figsize=(8, 4))\n", + "\n", + "axes = [-2.2, 2.4, -0.6, 0.7]\n", + "z0s, z1s = np.meshgrid(np.linspace(axes[0], axes[1], 100),\n", + " np.linspace(axes[2], axes[3], 100))\n", + "X_iris_pca_all = np.c_[z0s.ravel(), z1s.ravel()]\n", + "y_pred = tree_clf_pca.predict(X_iris_pca_all).reshape(z0s.shape)\n", + "\n", + "plt.contourf(z0s, z1s, y_pred, alpha=0.3, cmap=custom_cmap)\n", + "for idx, (name, style) in enumerate(zip(iris.target_names, (\"yo\", \"bs\", \"g^\"))):\n", + " plt.plot(X_iris_rotated[:, 0][y_iris == idx],\n", + " X_iris_rotated[:, 1][y_iris == idx],\n", + " style, label=f\"Iris {name}\")\n", + "\n", + "plt.xlabel(\"$z_1$\")\n", + "plt.ylabel(\"$z_2$\", rotation=0)\n", + "th1, th2 = tree_clf_pca.tree_.threshold[[0, 2]]\n", + "plt.plot([th1, th1], axes[2:], \"k-\", linewidth=2)\n", + "plt.plot([th2, th2], axes[2:], \"k--\", linewidth=2)\n", + "plt.text(th1 - 0.01, axes[2] + 0.05, \"Depth=0\",\n", + " horizontalalignment=\"right\", fontsize=15)\n", + "plt.text(th2 - 0.01, axes[2] + 0.05, \"Depth=1\",\n", + " horizontalalignment=\"right\", fontsize=13)\n", + "plt.axis(axes)\n", + "plt.legend(loc=(0.32, 0.67))\n", + "save_fig(\"pca_preprocessing_plot\")\n", + "\n", + "plt.show()" + ] + }, 
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Decision Trees Have High Variance" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We've seen that small changes in the dataset (such as a rotation) may produce a very different Decision Tree.\n", + "Now let's show that training the same model on the same data may produce a very different model every time, since the CART training algorithm used by Scikit-Learn is stochastic. To show this, we will set `random_state` to a different value than earlier:" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "tree_clf_tweaked = DecisionTreeClassifier(max_depth=2, random_state=40)\n", + "tree_clf_tweaked.fit(X_iris, y_iris)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Code to generate Figure 5–8. Sensitivity to training set details:**" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "# not in the book\n", + "\n", + "plt.figure(figsize=(8, 4))\n", + "y_pred = tree_clf_tweaked.predict(X_iris_all).reshape(lengths.shape)\n", + "plt.contourf(lengths, widths, y_pred, alpha=0.3, cmap=custom_cmap)\n", + "\n", + "for idx, (name, style) in enumerate(zip(iris.target_names, (\"yo\", \"bs\", \"g^\"))):\n", + " plt.plot(X_iris[:, 0][y_iris == idx], X_iris[:, 1][y_iris == idx],\n", + " style, label=f\"Iris {name}\")\n", + "\n", + "th0, th1 = tree_clf_tweaked.tree_.threshold[[0, 2]]\n", + "plt.plot([0, 7.2], [th0, th0], \"k-\", linewidth=2)\n", + "plt.plot([0, 7.2], [th1, th1], \"k--\", linewidth=2)\n", + "plt.text(1.8, th0 + 0.05, \"Depth=0\", verticalalignment=\"bottom\", fontsize=15)\n", + "plt.text(2.3, th1 + 0.05, \"Depth=1\", verticalalignment=\"bottom\", fontsize=13)\n", + "plt.xlabel(\"Petal length (cm)\")\n", + "plt.ylabel(\"Petal width (cm)\")\n", + "plt.axis([0, 7.2, 0, 3])\n", + "plt.legend()\n", + 
"save_fig(\"decision_tree_instability_plot\")\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Extra Material – Accessing the tree structure" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A trained `DecisionTreeClassifier` has a `tree_` attribute that stores the tree's structure:" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "tree = tree_clf.tree_\n", + "tree" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can get the total number of nodes in the tree:" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "tree.node_count" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And other self-explanatory attributes are available:" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [], + "source": [ + "tree.max_depth" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "tree.max_n_classes" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "tree.n_features" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "tree.n_outputs" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [], + "source": [ + "tree.n_leaves" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "All the information about the nodes is stored in NumPy arrays. For example, the impurity of each node:" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [], + "source": [ + "tree.impurity" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The root node is at index 0. 
The left and right child nodes of node _i_ are `tree.children_left[i]` and `tree.children_right[i]`. For example, the children of the root node are:" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [], + "source": [ + "tree.children_left[0], tree.children_right[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When the left and right child node ids of a node are equal, it means this node is a leaf (Scikit-Learn sets both ids to -1 for leaf nodes):" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "tree.children_left[3], tree.children_right[3]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "So you can get the leaf node ids like this:" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "is_leaf = (tree.children_left == tree.children_right)\n", + "np.arange(tree.node_count)[is_leaf]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Non-leaf nodes are called _split nodes_. The feature they split is available via the `feature` array. 
def compute_depth(tree_clf):
    """Return the depth of every node of a fitted decision tree.

    Parameters
    ----------
    tree_clf : object exposing a ``tree_`` attribute
        Only ``tree_.node_count``, ``tree_.children_left`` and
        ``tree_.children_right`` are read.

    Returns
    -------
    numpy.ndarray of int, shape (node_count,)
        ``depth[i]`` is the number of edges between the root (node 0)
        and node ``i``.
    """
    tree = tree_clf.tree_
    # Depths are whole numbers, so store them as integers rather than floats.
    depth = np.zeros(tree.node_count, dtype=int)
    stack = [(0, 0)]  # (node id, depth of that node), starting at the root
    while stack:
        node, node_depth = stack.pop()
        depth[node] = node_depth
        left = tree.children_left[node]
        right = tree.children_right[node]
        # Equal child ids mark a leaf: there is nothing below it to visit.
        if left != right:
            stack.append((left, node_depth + 1))
            stack.append((right, node_depth + 1))
    return depth
"tree_clf.tree_.feature[(depth == 1) & (~is_leaf)]" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [], + "source": [ + "tree_clf.tree_.threshold[(depth == 1) & (~is_leaf)]" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -593,13 +1092,13 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 49, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import make_moons\n", "\n", - "X, y = make_moons(n_samples=10000, noise=0.4, random_state=42)" + "X_moons, y_moons = make_moons(n_samples=10000, noise=0.4, random_state=42)" ] }, { @@ -611,13 +1110,15 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 50, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "\n", - "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)" + "X_train, X_test, y_train, y_test = train_test_split(X_moons, y_moons,\n", + " test_size=0.2,\n", + " random_state=42)" ] }, { @@ -629,21 +1130,27 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 51, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import GridSearchCV\n", "\n", - "params = {'max_leaf_nodes': list(range(2, 100)), 'min_samples_split': [2, 3, 4]}\n", - "grid_search_cv = GridSearchCV(DecisionTreeClassifier(random_state=42), params, verbose=1, cv=3)\n", + "params = {\n", + " 'max_leaf_nodes': list(range(2, 100)),\n", + " 'max_depth': list(range(1, 7)),\n", + " 'min_samples_split': [2, 3, 4]\n", + "}\n", + "grid_search_cv = GridSearchCV(DecisionTreeClassifier(random_state=42),\n", + " params,\n", + " cv=3)\n", "\n", "grid_search_cv.fit(X_train, y_train)" ] }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 52, "metadata": {}, "outputs": [], "source": [ @@ -666,7 +1173,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 53, "metadata": {}, "outputs": [], 
"source": [ @@ -699,7 +1206,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 54, "metadata": {}, "outputs": [], "source": [ @@ -710,7 +1217,9 @@ "\n", "mini_sets = []\n", "\n", - "rs = ShuffleSplit(n_splits=n_trees, test_size=len(X_train) - n_instances, random_state=42)\n", + "rs = ShuffleSplit(n_splits=n_trees, test_size=len(X_train) - n_instances,\n", + " random_state=42)\n", + "\n", "for mini_train_index, mini_test_index in rs.split(X_train):\n", " X_mini_train = X_train[mini_train_index]\n", " y_mini_train = y_train[mini_train_index]\n", @@ -726,7 +1235,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 55, "metadata": {}, "outputs": [], "source": [ @@ -754,7 +1263,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 56, "metadata": {}, "outputs": [], "source": [ @@ -766,7 +1275,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 57, "metadata": {}, "outputs": [], "source": [ @@ -784,17 +1293,24 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 58, "metadata": {}, "outputs": [], "source": [ "accuracy_score(y_test, y_pred_majority_votes.reshape([-1]))" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" },