507 lines
14 KiB
Plaintext
507 lines
14 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Chapter 6 – Decision Trees**"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"_This notebook contains all the sample code and solutions to the exercises in chapter 6._"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Setup"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"First, let's make sure this notebook works well in both Python 2 and 3, import a few common modules, ensure Matplotlib plots figures inline and prepare a function to save the figures:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# To support both python 2 and python 3\n",
|
||
"from __future__ import division, print_function, unicode_literals\n",
|
||
"\n",
|
||
"# Common imports\n",
|
||
"import numpy as np\n",
|
||
"import numpy.random as rnd\n",
|
||
"import os\n",
|
||
"\n",
|
||
"# to make this notebook's output stable across runs\n",
|
||
"rnd.seed(42)\n",
|
||
"\n",
|
||
"# To plot pretty figures\n",
|
||
"%matplotlib inline\n",
|
||
"import matplotlib\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"plt.rcParams['axes.labelsize'] = 14\n",
|
||
"plt.rcParams['xtick.labelsize'] = 12\n",
|
||
"plt.rcParams['ytick.labelsize'] = 12\n",
|
||
"\n",
|
||
"# Where to save the figures (and the Graphviz .dot files exported later)\n",
|
||
"PROJECT_ROOT_DIR = \".\"\n",
|
||
"CHAPTER_ID = \"decision_trees\"\n",
|
||
"IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
|
||
"\n",
|
||
"# Writing a file into a missing directory fails, so create the output directory\n",
|
||
"# up front (no exist_ok argument, to keep Python 2 support).\n",
|
||
"if not os.path.isdir(IMAGES_PATH):\n",
|
||
"    os.makedirs(IMAGES_PATH)\n",
|
||
"\n",
|
||
"def image_path(fig_id):\n",
|
||
"    \"\"\"Return the path of file `fig_id` inside this chapter's image directory.\"\"\"\n",
|
||
"    return os.path.join(IMAGES_PATH, fig_id)\n",
|
||
"\n",
|
||
"def save_fig(fig_id, tight_layout=True):\n",
|
||
"    \"\"\"Save the current matplotlib figure as image_path(fig_id) + '.png' (PNG, 300 dpi).\n",
|
||
"\n",
|
||
"    If `tight_layout` is true, plt.tight_layout() is applied first to trim padding.\n",
|
||
"    \"\"\"\n",
|
||
"    print(\"Saving figure\", fig_id)\n",
|
||
"    if tight_layout:\n",
|
||
"        plt.tight_layout()\n",
|
||
"    plt.savefig(image_path(fig_id) + \".png\", format='png', dpi=300)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Training and visualizing"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.datasets import load_iris\n",
|
||
"from sklearn.tree import DecisionTreeClassifier, export_graphviz\n",
|
||
"\n",
|
||
"iris = load_iris()\n",
|
||
"# Keep only the last two features (petal length and petal width) so the\n",
|
||
"# decision boundaries can be drawn in 2D.\n",
|
||
"X, y = iris.data[:, 2:], iris.target\n",
|
||
"\n",
|
||
"tree_clf = DecisionTreeClassifier(random_state=42, max_depth=2)\n",
|
||
"tree_clf.fit(X, y)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"export_graphviz(tree_clf,\n",
|
||
"                out_file=image_path(\"iris_tree.dot\"),\n",
|
||
"                feature_names=iris.feature_names[2:],  # the two petal features kept in X\n",
|
||
"                class_names=iris.target_names,\n",
|
||
"                filled=True,\n",
|
||
"                rounded=True)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from matplotlib.colors import ListedColormap\n",
|
||
"\n",
|
||
"def plot_decision_boundary(clf, X, y, axes=[0, 7.5, 0, 3], iris=True, legend=False, plot_training=True):\n",
|
||
"    \"\"\"Plot the predicted classes of a 2-feature classifier over a rectangular region.\n",
|
||
"\n",
|
||
"    clf: fitted classifier whose predict() accepts an (n, 2) array.\n",
|
||
"    X, y: samples (2 columns) and labels 0/1/2, drawn when plot_training is true.\n",
|
||
"    axes: [xmin, xmax, ymin, ymax]; predictions are made on a 100x100 grid over it.\n",
|
||
"    iris: label the axes as petal length/width; otherwise use x1/x2 and outline the regions.\n",
|
||
"    legend: add a legend in the lower-right corner.\n",
|
||
"    \"\"\"\n",
|
||
"    x1s = np.linspace(axes[0], axes[1], 100)\n",
|
||
"    x2s = np.linspace(axes[2], axes[3], 100)\n",
|
||
"    x1, x2 = np.meshgrid(x1s, x2s)\n",
|
||
"    X_new = np.c_[x1.ravel(), x2.ravel()]\n",
|
||
"    y_pred = clf.predict(X_new).reshape(x1.shape)\n",
|
||
"    custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])\n",
|
||
"    # contourf() only fills regions: `linewidth` is not one of its options (matplotlib\n",
|
||
"    # reports it as unused), so the former linewidth=10 is dropped.\n",
|
||
"    plt.contourf(x1, x2, y_pred, alpha=0.3, cmap=custom_cmap)\n",
|
||
"    if not iris:\n",
|
||
"        custom_cmap2 = ListedColormap(['#7d7d58','#4c4c7f','#507d50'])\n",
|
||
"        plt.contour(x1, x2, y_pred, cmap=custom_cmap2, alpha=0.8)\n",
|
||
"    if plot_training:\n",
|
||
"        plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"yo\", label=\"Iris-Setosa\")\n",
|
||
"        plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"bs\", label=\"Iris-Versicolour\")\n",
|
||
"        plt.plot(X[:, 0][y==2], X[:, 1][y==2], \"g^\", label=\"Iris-Virginica\")\n",
|
||
"    # Fix the view to the requested region whether or not training points are drawn.\n",
|
||
"    plt.axis(axes)\n",
|
||
"    if iris:\n",
|
||
"        plt.xlabel(\"Petal length\", fontsize=14)\n",
|
||
"        plt.ylabel(\"Petal width\", fontsize=14)\n",
|
||
"    else:\n",
|
||
"        plt.xlabel(r\"$x_1$\", fontsize=18)\n",
|
||
"        plt.ylabel(r\"$x_2$\", fontsize=18, rotation=0)\n",
|
||
"    if legend:\n",
|
||
"        plt.legend(loc=\"lower right\", fontsize=14)\n",
|
||
"\n",
|
||
"plt.figure(figsize=(8, 4))\n",
|
||
"plot_decision_boundary(tree_clf, X, y)\n",
|
||
"# Hand-placed lines marking split thresholds (solid: depth 0, dashed: depth 1,\n",
|
||
"# dotted: depth 2); they must be updated by hand if the data or tree settings change.\n",
|
||
"plt.plot([2.45, 2.45], [0, 3], \"k-\", linewidth=2)\n",
|
||
"plt.plot([2.45, 7.5], [1.75, 1.75], \"k--\", linewidth=2)\n",
|
||
"plt.plot([4.95, 4.95], [0, 1.75], \"k:\", linewidth=2)\n",
|
||
"plt.plot([4.85, 4.85], [1.75, 3], \"k:\", linewidth=2)\n",
|
||
"plt.text(1.40, 1.0, \"Depth=0\", fontsize=15)\n",
|
||
"plt.text(3.2, 1.80, \"Depth=1\", fontsize=13)\n",
|
||
"plt.text(4.05, 0.5, \"(Depth=2)\", fontsize=11)\n",
|
||
"\n",
|
||
"save_fig(\"decision_tree_decision_boundaries_plot\")\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Predicting classes and class probabilities"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"new_flower = [[5, 1.5]]  # one sample: petal length 5, petal width 1.5\n",
|
||
"tree_clf.predict_proba(new_flower)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"tree_clf.predict(np.array([[5, 1.5]]))  # petal length 5, petal width 1.5"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Sensitivity to training set details"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 7,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"is_versicolour = (y == 1)\n",
|
||
"X[is_versicolour & (X[:, 1] == X[is_versicolour, 1].max())]  # widest Iris-Versicolour flower"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 8,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Drop the widest Iris-Versicolour flower(s) and retrain, to show how sensitive\n",
|
||
"# the tree is to small changes in the training set. The width is computed rather\n",
|
||
"# than hard-coded, and only versicolour samples of that width are removed.\n",
|
||
"widest_versicolour_width = X[:, 1][y==1].max()\n",
|
||
"not_widest_versicolour = ~((X[:, 1] == widest_versicolour_width) & (y == 1))\n",
|
||
"X_tweaked = X[not_widest_versicolour]\n",
|
||
"y_tweaked = y[not_widest_versicolour]\n",
|
||
"\n",
|
||
"tree_clf_tweaked = DecisionTreeClassifier(max_depth=2, random_state=40)\n",
|
||
"tree_clf_tweaked.fit(X_tweaked, y_tweaked)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"plt.figure(figsize=(8, 4))\n",
|
||
"plot_decision_boundary(tree_clf_tweaked, X_tweaked, y_tweaked, legend=False)\n",
|
||
"# Horizontal lines at petal width 0.8 (solid, labelled depth 0) and 1.75 (dashed, depth 1).\n",
|
||
"for width, style in ((0.8, \"k-\"), (1.75, \"k--\")):\n",
|
||
"    plt.plot([0, 7.5], [width, width], style, linewidth=2)\n",
|
||
"plt.text(1.0, 0.9, \"Depth=0\", fontsize=15)\n",
|
||
"plt.text(1.0, 1.80, \"Depth=1\", fontsize=13)\n",
|
||
"\n",
|
||
"save_fig(\"decision_tree_instability_plot\")\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.datasets import make_moons\n",
|
||
"Xm, ym = make_moons(n_samples=100, noise=0.25, random_state=53)\n",
|
||
"\n",
|
||
"# Same data, two trees: one unconstrained, one requiring at least 4 samples per leaf.\n",
|
||
"deep_tree_clf1 = DecisionTreeClassifier(random_state=42)\n",
|
||
"deep_tree_clf2 = DecisionTreeClassifier(min_samples_leaf=4, random_state=42)\n",
|
||
"for clf in (deep_tree_clf1, deep_tree_clf2):\n",
|
||
"    clf.fit(Xm, ym)\n",
|
||
"\n",
|
||
"moons_axes = [-1.5, 2.5, -1, 1.5]\n",
|
||
"plt.figure(figsize=(11, 4))\n",
|
||
"plt.subplot(121)\n",
|
||
"plot_decision_boundary(deep_tree_clf1, Xm, ym, axes=moons_axes, iris=False)\n",
|
||
"plt.title(\"No restrictions\", fontsize=16)\n",
|
||
"plt.subplot(122)\n",
|
||
"plot_decision_boundary(deep_tree_clf2, Xm, ym, axes=moons_axes, iris=False)\n",
|
||
"plt.title(\"min_samples_leaf = {}\".format(deep_tree_clf2.min_samples_leaf), fontsize=14)\n",
|
||
"\n",
|
||
"save_fig(\"min_samples_leaf_plot\")\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 11,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"angle = np.pi / 180 * 20  # 20 degrees, in radians\n",
|
||
"cos_a, sin_a = np.cos(angle), np.sin(angle)\n",
|
||
"rotation_matrix = np.array([[cos_a, -sin_a], [sin_a, cos_a]])\n",
|
||
"Xr = X.dot(rotation_matrix)  # the iris petal features, rotated\n",
|
||
"\n",
|
||
"tree_clf_r = DecisionTreeClassifier(random_state=42)\n",
|
||
"tree_clf_r.fit(Xr, y)\n",
|
||
"\n",
|
||
"plt.figure(figsize=(8, 3))\n",
|
||
"plot_decision_boundary(tree_clf_r, Xr, y, axes=[0.5, 7.5, -1.0, 1], iris=False)\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 12,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# 100 random points in [-0.5, 0.5)^2, labelled 2 when the first coordinate is positive, else 0.\n",
|
||
"rnd.seed(6)\n",
|
||
"Xs = rnd.rand(100, 2) - 0.5\n",
|
||
"ys = (Xs[:, 0] > 0).astype(np.float32) * 2\n",
|
||
"\n",
|
||
"# The same points rotated by 45 degrees.\n",
|
||
"angle = np.pi / 4\n",
|
||
"cos_a, sin_a = np.cos(angle), np.sin(angle)\n",
|
||
"rotation_matrix = np.array([[cos_a, -sin_a], [sin_a, cos_a]])\n",
|
||
"Xsr = Xs.dot(rotation_matrix)\n",
|
||
"\n",
|
||
"tree_clf_s = DecisionTreeClassifier(random_state=42)\n",
|
||
"tree_clf_sr = DecisionTreeClassifier(random_state=42)\n",
|
||
"tree_clf_s.fit(Xs, ys)\n",
|
||
"tree_clf_sr.fit(Xsr, ys)\n",
|
||
"\n",
|
||
"square_axes = [-0.7, 0.7, -0.7, 0.7]\n",
|
||
"plt.figure(figsize=(11, 4))\n",
|
||
"for position, clf, features in ((121, tree_clf_s, Xs), (122, tree_clf_sr, Xsr)):\n",
|
||
"    plt.subplot(position)\n",
|
||
"    plot_decision_boundary(clf, features, ys, axes=square_axes, iris=False)\n",
|
||
"\n",
|
||
"save_fig(\"sensitivity_to_rotation_plot\")\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Regression trees"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 145,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.tree import DecisionTreeRegressor\n",
|
||
"\n",
|
||
"# Quadratic training set + noise\n",
|
||
"rnd.seed(42)\n",
|
||
"m = 200\n",
|
||
"X = rnd.rand(m, 1)\n",
|
||
"y = 4 * (X - 0.5) ** 2\n",
|
||
"y = y + rnd.randn(m, 1) / 10\n",
|
||
"\n",
|
||
"tree_reg1 = DecisionTreeRegressor(random_state=42, max_depth=2)\n",
|
||
"tree_reg2 = DecisionTreeRegressor(random_state=42, max_depth=3)\n",
|
||
"tree_reg1.fit(X, y)\n",
|
||
"tree_reg2.fit(X, y)\n",
|
||
"\n",
|
||
"def plot_regression_predictions(tree_reg, X, y, axes=[0, 1, -0.2, 1], ylabel=\"$y$\"):\n",
|
||
"    \"\"\"Plot training data and a regressor's predictions for a single input feature.\n",
|
||
"\n",
|
||
"    tree_reg: fitted regressor, evaluated on 500 evenly spaced points in [axes[0], axes[1]].\n",
|
||
"    X, y: training data, drawn as blue dots.\n",
|
||
"    axes: [xmin, xmax, ymin, ymax] of the plot.\n",
|
||
"    ylabel: y-axis label, or a falsy value to leave the y axis unlabelled.\n",
|
||
"    \"\"\"\n",
|
||
"    x1 = np.linspace(axes[0], axes[1], 500).reshape(-1, 1)\n",
|
||
"    y_pred = tree_reg.predict(x1)\n",
|
||
"    plt.axis(axes)\n",
|
||
"    plt.xlabel(\"$x_1$\", fontsize=18)\n",
|
||
"    if ylabel:\n",
|
||
"        plt.ylabel(ylabel, fontsize=18, rotation=0)\n",
|
||
"    plt.plot(X, y, \"b.\")\n",
|
||
"    plt.plot(x1, y_pred, \"r.-\", linewidth=2, label=r\"$\\hat{y}$\")\n",
|
||
"\n",
|
||
"# Presumably the split thresholds of tree_reg1 (solid: depth 0, dashed: depth 1),\n",
|
||
"# hard-coded for drawing; re-check them against the fitted tree if the data change.\n",
|
||
"depth_0_1_splits = ((0.1973, \"k-\"), (0.0917, \"k--\"), (0.7718, \"k--\"))\n",
|
||
"\n",
|
||
"plt.figure(figsize=(11, 4))\n",
|
||
"plt.subplot(121)\n",
|
||
"plot_regression_predictions(tree_reg1, X, y)\n",
|
||
"for split, style in depth_0_1_splits:\n",
|
||
"    plt.plot([split, split], [-0.2, 1], style, linewidth=2)\n",
|
||
"plt.text(0.21, 0.65, \"Depth=0\", fontsize=15)\n",
|
||
"plt.text(0.01, 0.2, \"Depth=1\", fontsize=13)\n",
|
||
"plt.text(0.65, 0.8, \"Depth=1\", fontsize=13)\n",
|
||
"plt.legend(loc=\"upper center\", fontsize=18)\n",
|
||
"plt.title(\"max_depth=2\", fontsize=14)\n",
|
||
"\n",
|
||
"plt.subplot(122)\n",
|
||
"plot_regression_predictions(tree_reg2, X, y, ylabel=None)\n",
|
||
"for split, style in depth_0_1_splits:\n",
|
||
"    plt.plot([split, split], [-0.2, 1], style, linewidth=2)\n",
|
||
"# Extra (dotted) thresholds, labelled as depth-2 splits of the max_depth=3 tree.\n",
|
||
"for split in (0.0458, 0.1298, 0.2873, 0.9040):\n",
|
||
"    plt.plot([split, split], [-0.2, 1], \"k:\", linewidth=1)\n",
|
||
"plt.text(0.3, 0.5, \"Depth=2\", fontsize=13)\n",
|
||
"plt.title(\"max_depth=3\", fontsize=14)\n",
|
||
"\n",
|
||
"save_fig(\"tree_regression_plot\")\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 131,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"export_graphviz(tree_reg1,\n",
|
||
"                out_file=image_path(\"regression_tree.dot\"),\n",
|
||
"                feature_names=[\"x1\"],  # the single input feature\n",
|
||
"                filled=True,\n",
|
||
"                rounded=True)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 144,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Use new names so this cell does not rebind tree_reg1/tree_reg2 (the max_depth\n",
|
||
"# models above): re-running the export_graphviz cell after this one would\n",
|
||
"# otherwise silently export the unrestricted tree.\n",
|
||
"tree_reg_unrestricted = DecisionTreeRegressor(random_state=42)\n",
|
||
"tree_reg_min_leaf = DecisionTreeRegressor(random_state=42, min_samples_leaf=10)\n",
|
||
"tree_reg_unrestricted.fit(X, y)\n",
|
||
"tree_reg_min_leaf.fit(X, y)\n",
|
||
"\n",
|
||
"plt.figure(figsize=(11, 4))\n",
|
||
"\n",
|
||
"plt.subplot(121)\n",
|
||
"plot_regression_predictions(tree_reg_unrestricted, X, y, axes=[0, 1, -0.2, 1.1])\n",
|
||
"plt.legend(loc=\"upper center\", fontsize=18)\n",
|
||
"plt.title(\"No restrictions\", fontsize=14)\n",
|
||
"\n",
|
||
"plt.subplot(122)\n",
|
||
"plot_regression_predictions(tree_reg_min_leaf, X, y, axes=[0, 1, -0.2, 1.1], ylabel=None)\n",
|
||
"plt.title(\"min_samples_leaf={}\".format(tree_reg_min_leaf.min_samples_leaf), fontsize=14)\n",
|
||
"\n",
|
||
"save_fig(\"tree_regression_regularization_plot\")\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"source": [
|
||
"# Exercise solutions"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Coming soon**"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.5.1"
|
||
},
|
||
"nav_menu": {
|
||
"height": "309px",
|
||
"width": "468px"
|
||
},
|
||
"toc": {
|
||
"navigate_menu": true,
|
||
"number_sections": true,
|
||
"sideBar": true,
|
||
"threshold": 6,
|
||
"toc_cell": false,
|
||
"toc_section_display": "block",
|
||
"toc_window_display": false
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 0
|
||
}
|