handson-ml/06_decision_trees.ipynb

568 lines
16 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"**Chapter 6 Decision Trees**"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"_This notebook contains all the sample code and solutions to the exercices in chapter 6._"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"First, let's make sure this notebook works well in both python 2 and 3, import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# To support both python 2 and python 3\n",
"from __future__ import division, print_function, unicode_literals\n",
"\n",
"# Common imports\n",
"import numpy as np\n",
"import numpy.random as rnd\n",
"import os\n",
"\n",
"# to make this notebook's output stable across runs\n",
"rnd.seed(42)\n",
"\n",
"# To plot pretty figures\n",
"%matplotlib inline\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"plt.rcParams['axes.labelsize'] = 14\n",
"plt.rcParams['xtick.labelsize'] = 12\n",
"plt.rcParams['ytick.labelsize'] = 12\n",
"\n",
"# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n",
"CHAPTER_ID = \"decision_trees\"\n",
"\n",
"def image_path(fig_id):\n",
" return os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID, fig_id)\n",
"\n",
"def save_fig(fig_id, tight_layout=True):\n",
" print(\"Saving figure\", fig_id)\n",
" if tight_layout:\n",
" plt.tight_layout()\n",
" plt.savefig(image_path(fig_id) + \".png\", format='png', dpi=300)"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Training and visualizing"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.datasets import load_iris\n",
"from sklearn.tree import DecisionTreeClassifier, export_graphviz\n",
"\n",
"iris = load_iris()\n",
"X = iris.data[:, 2:] # petal length and width\n",
"y = iris.target\n",
"\n",
"tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)\n",
"tree_clf.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"export_graphviz(\n",
" tree_clf,\n",
" out_file=image_path(\"iris_tree.dot\"),\n",
" feature_names=iris.feature_names[2:],\n",
" class_names=iris.target_names,\n",
" rounded=True,\n",
" filled=True\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from matplotlib.colors import ListedColormap\n",
"\n",
"def plot_decision_boundary(clf, X, y, axes=[0, 7.5, 0, 3], iris=True, legend=False, plot_training=True):\n",
" x1s = np.linspace(axes[0], axes[1], 100)\n",
" x2s = np.linspace(axes[2], axes[3], 100)\n",
" x1, x2 = np.meshgrid(x1s, x2s)\n",
" X_new = np.c_[x1.ravel(), x2.ravel()]\n",
" y_pred = clf.predict(X_new).reshape(x1.shape)\n",
" custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])\n",
" plt.contourf(x1, x2, y_pred, alpha=0.3, cmap=custom_cmap, linewidth=10)\n",
" if not iris:\n",
" custom_cmap2 = ListedColormap(['#7d7d58','#4c4c7f','#507d50'])\n",
" plt.contour(x1, x2, y_pred, cmap=custom_cmap2, alpha=0.8)\n",
" if plot_training:\n",
" plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"yo\", label=\"Iris-Setosa\")\n",
" plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"bs\", label=\"Iris-Versicolor\")\n",
" plt.plot(X[:, 0][y==2], X[:, 1][y==2], \"g^\", label=\"Iris-Virginica\")\n",
" plt.axis(axes)\n",
" if iris:\n",
" plt.xlabel(\"Petal length\", fontsize=14)\n",
" plt.ylabel(\"Petal width\", fontsize=14)\n",
" else:\n",
" plt.xlabel(r\"$x_1$\", fontsize=18)\n",
" plt.ylabel(r\"$x_2$\", fontsize=18, rotation=0)\n",
" if legend:\n",
" plt.legend(loc=\"lower right\", fontsize=14)\n",
"\n",
"plt.figure(figsize=(8, 4))\n",
"plot_decision_boundary(tree_clf, X, y)\n",
"plt.plot([2.45, 2.45], [0, 3], \"k-\", linewidth=2)\n",
"plt.plot([2.45, 7.5], [1.75, 1.75], \"k--\", linewidth=2)\n",
"plt.plot([4.95, 4.95], [0, 1.75], \"k:\", linewidth=2)\n",
"plt.plot([4.85, 4.85], [1.75, 3], \"k:\", linewidth=2)\n",
"plt.text(1.40, 1.0, \"Depth=0\", fontsize=15)\n",
"plt.text(3.2, 1.80, \"Depth=1\", fontsize=13)\n",
"plt.text(4.05, 0.5, \"(Depth=2)\", fontsize=11)\n",
"\n",
"save_fig(\"decision_tree_decision_boundaries_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Predicting classes and class probabilities"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"tree_clf.predict_proba([[5, 1.5]])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"tree_clf.predict([[5, 1.5]])"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Sensitivity to training set details"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"X[(X[:, 1]==X[:, 1][y==1].max()) & (y==1)] # widest Iris-Versicolor flower"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"not_widest_versicolor = (X[:, 1]!=1.8) | (y==2)\n",
"X_tweaked = X[not_widest_versicolor]\n",
"y_tweaked = y[not_widest_versicolor]\n",
"\n",
"tree_clf_tweaked = DecisionTreeClassifier(max_depth=2, random_state=40)\n",
"tree_clf_tweaked.fit(X_tweaked, y_tweaked)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"plt.figure(figsize=(8, 4))\n",
"plot_decision_boundary(tree_clf_tweaked, X_tweaked, y_tweaked, legend=False)\n",
"plt.plot([0, 7.5], [0.8, 0.8], \"k-\", linewidth=2)\n",
"plt.plot([0, 7.5], [1.75, 1.75], \"k--\", linewidth=2)\n",
"plt.text(1.0, 0.9, \"Depth=0\", fontsize=15)\n",
"plt.text(1.0, 1.80, \"Depth=1\", fontsize=13)\n",
"\n",
"save_fig(\"decision_tree_instability_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.datasets import make_moons\n",
"Xm, ym = make_moons(n_samples=100, noise=0.25, random_state=53)\n",
"\n",
"deep_tree_clf1 = DecisionTreeClassifier(random_state=42)\n",
"deep_tree_clf2 = DecisionTreeClassifier(min_samples_leaf=4, random_state=42)\n",
"deep_tree_clf1.fit(Xm, ym)\n",
"deep_tree_clf2.fit(Xm, ym)\n",
"\n",
"plt.figure(figsize=(11, 4))\n",
"plt.subplot(121)\n",
"plot_decision_boundary(deep_tree_clf1, Xm, ym, axes=[-1.5, 2.5, -1, 1.5], iris=False)\n",
"plt.title(\"No restrictions\", fontsize=16)\n",
"plt.subplot(122)\n",
"plot_decision_boundary(deep_tree_clf2, Xm, ym, axes=[-1.5, 2.5, -1, 1.5], iris=False)\n",
"plt.title(\"min_samples_leaf = {}\".format(deep_tree_clf2.min_samples_leaf), fontsize=14)\n",
"\n",
"save_fig(\"min_samples_leaf_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"angle = np.pi / 180 * 20\n",
"rotation_matrix = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]])\n",
"Xr = X.dot(rotation_matrix)\n",
"\n",
"tree_clf_r = DecisionTreeClassifier(random_state=42)\n",
"tree_clf_r.fit(Xr, y)\n",
"\n",
"plt.figure(figsize=(8, 3))\n",
"plot_decision_boundary(tree_clf_r, Xr, y, axes=[0.5, 7.5, -1.0, 1], iris=False)\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"rnd.seed(6)\n",
"Xs = rnd.rand(100, 2) - 0.5\n",
"ys = (Xs[:, 0] > 0).astype(np.float32) * 2\n",
"\n",
"angle = np.pi / 4\n",
"rotation_matrix = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]])\n",
"Xsr = Xs.dot(rotation_matrix)\n",
"\n",
"tree_clf_s = DecisionTreeClassifier(random_state=42)\n",
"tree_clf_s.fit(Xs, ys)\n",
"tree_clf_sr = DecisionTreeClassifier(random_state=42)\n",
"tree_clf_sr.fit(Xsr, ys)\n",
"\n",
"plt.figure(figsize=(11, 4))\n",
"plt.subplot(121)\n",
"plot_decision_boundary(tree_clf_s, Xs, ys, axes=[-0.7, 0.7, -0.7, 0.7], iris=False)\n",
"plt.subplot(122)\n",
"plot_decision_boundary(tree_clf_sr, Xsr, ys, axes=[-0.7, 0.7, -0.7, 0.7], iris=False)\n",
"\n",
"save_fig(\"sensitivity_to_rotation_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Regression trees"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.tree import DecisionTreeRegressor\n",
"\n",
"# Quadratic training set + noise\n",
"rnd.seed(42)\n",
"m = 200\n",
"X = rnd.rand(m, 1)\n",
"y = 4 * (X - 0.5) ** 2\n",
"y = y + rnd.randn(m, 1) / 10\n",
"\n",
"tree_reg1 = DecisionTreeRegressor(random_state=42, max_depth=2)\n",
"tree_reg2 = DecisionTreeRegressor(random_state=42, max_depth=3)\n",
"tree_reg1.fit(X, y)\n",
"tree_reg2.fit(X, y)\n",
"\n",
"def plot_regression_predictions(tree_reg, X, y, axes=[0, 1, -0.2, 1], ylabel=\"$y$\"):\n",
" x1 = np.linspace(axes[0], axes[1], 500).reshape(-1, 1)\n",
" y_pred = tree_reg.predict(x1)\n",
" plt.axis(axes)\n",
" plt.xlabel(\"$x_1$\", fontsize=18)\n",
" if ylabel:\n",
" plt.ylabel(ylabel, fontsize=18, rotation=0)\n",
" plt.plot(X, y, \"b.\")\n",
" plt.plot(x1, y_pred, \"r.-\", linewidth=2, label=r\"$\\hat{y}$\")\n",
"\n",
"plt.figure(figsize=(11, 4))\n",
"plt.subplot(121)\n",
"plot_regression_predictions(tree_reg1, X, y)\n",
"for split, style in ((0.1973, \"k-\"), (0.0917, \"k--\"), (0.7718, \"k--\")):\n",
" plt.plot([split, split], [-0.2, 1], style, linewidth=2)\n",
"plt.text(0.21, 0.65, \"Depth=0\", fontsize=15)\n",
"plt.text(0.01, 0.2, \"Depth=1\", fontsize=13)\n",
"plt.text(0.65, 0.8, \"Depth=1\", fontsize=13)\n",
"plt.legend(loc=\"upper center\", fontsize=18)\n",
"plt.title(\"max_depth=2\", fontsize=14)\n",
"\n",
"plt.subplot(122)\n",
"plot_regression_predictions(tree_reg2, X, y, ylabel=None)\n",
"for split, style in ((0.1973, \"k-\"), (0.0917, \"k--\"), (0.7718, \"k--\")):\n",
" plt.plot([split, split], [-0.2, 1], style, linewidth=2)\n",
"for split in (0.0458, 0.1298, 0.2873, 0.9040):\n",
" plt.plot([split, split], [-0.2, 1], \"k:\", linewidth=1)\n",
"plt.text(0.3, 0.5, \"Depth=2\", fontsize=13)\n",
"plt.title(\"max_depth=3\", fontsize=14)\n",
"\n",
"save_fig(\"tree_regression_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"export_graphviz(\n",
" tree_reg1,\n",
" out_file=image_path(\"regression_tree.dot\"),\n",
" feature_names=[\"x1\"],\n",
" rounded=True,\n",
" filled=True\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"tree_reg1 = DecisionTreeRegressor(random_state=42)\n",
"tree_reg2 = DecisionTreeRegressor(random_state=42, min_samples_leaf=10)\n",
"tree_reg1.fit(X, y)\n",
"tree_reg2.fit(X, y)\n",
"\n",
"x1 = np.linspace(0, 1, 500).reshape(-1, 1)\n",
"y_pred1 = tree_reg1.predict(x1)\n",
"y_pred2 = tree_reg2.predict(x1)\n",
"\n",
"plt.figure(figsize=(11, 4))\n",
"\n",
"plt.subplot(121)\n",
"plt.plot(X, y, \"b.\")\n",
"plt.plot(x1, y_pred1, \"r.-\", linewidth=2, label=r\"$\\hat{y}$\")\n",
"plt.axis([0, 1, -0.2, 1.1])\n",
"plt.xlabel(\"$x_1$\", fontsize=18)\n",
"plt.ylabel(\"$y$\", fontsize=18, rotation=0)\n",
"plt.legend(loc=\"upper center\", fontsize=18)\n",
"plt.title(\"No restrictions\", fontsize=14)\n",
"\n",
"plt.subplot(122)\n",
"plt.plot(X, y, \"b.\")\n",
"plt.plot(x1, y_pred2, \"r.-\", linewidth=2, label=r\"$\\hat{y}$\")\n",
"plt.axis([0, 1, -0.2, 1.1])\n",
"plt.xlabel(\"$x_1$\", fontsize=18)\n",
"plt.title(\"min_samples_leaf={}\".format(tree_reg2.min_samples_leaf), fontsize=14)\n",
"\n",
"save_fig(\"tree_regression_regularization_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"source": [
"# Exercise solutions"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"**Coming soon**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2+"
},
"nav_menu": {
"height": "309px",
"width": "468px"
},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 0
}