2016-09-27 23:31:21 +02:00
{
"cells": [
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"**Chapter 7 – Ensemble Learning and Random Forests**"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"_This notebook contains all the sample code and solutions to the exercises in chapter 7._"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"First, let's make sure this notebook works well in both Python 2 and 3, import a few common modules, ensure Matplotlib plots figures inline and prepare a function to save the figures:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": true,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"# To support both python 2 and python 3\n",
"from __future__ import division, print_function, unicode_literals\n",
"\n",
"# Common imports\n",
"import numpy as np\n",
"import numpy.random as rnd\n",
"import os\n",
"\n",
"# to make this notebook's output stable across runs\n",
"rnd.seed(42)\n",
"\n",
"# To plot pretty figures\n",
"%matplotlib inline\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"plt.rcParams['axes.labelsize'] = 14\n",
"plt.rcParams['xtick.labelsize'] = 12\n",
"plt.rcParams['ytick.labelsize'] = 12\n",
"\n",
"# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n",
"CHAPTER_ID = \"ensembles\"\n",
"\n",
"def image_path(fig_id):\n",
"    \"\"\"Return the path of figure `fig_id` (without extension) in the images folder.\"\"\"\n",
"    return os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID, fig_id)\n",
"\n",
"def save_fig(fig_id, tight_layout=True):\n",
"    \"\"\"Save the current figure as a 300 dpi PNG named `fig_id`.png.\n",
"\n",
"    The images/<CHAPTER_ID> folder under PROJECT_ROOT_DIR is created first if it\n",
"    does not exist yet, so saving also works on a fresh checkout.\n",
"    \"\"\"\n",
"    path = image_path(fig_id) + \".png\"\n",
"    folder = os.path.dirname(path)\n",
"    # os.makedirs(..., exist_ok=True) is Python 3 only, and this notebook supports Python 2\n",
"    if not os.path.isdir(folder):\n",
"        os.makedirs(folder)\n",
"    print(\"Saving figure\", fig_id)\n",
"    if tight_layout:\n",
"        plt.tight_layout()\n",
"    plt.savefig(path, format='png', dpi=300)"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"# Voting classifiers"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"heads_proba = 0.51\n",
"coin_tosses = (rnd.rand(10000, 10) < heads_proba).astype(np.int32)\n",
"cumulative_heads_ratio = np.cumsum(coin_tosses, axis=0) / np.arange(1, 10001).reshape(-1, 1)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"plt.figure(figsize=(8,3.5))\n",
"plt.plot(cumulative_heads_ratio)\n",
"plt.plot([0, 10000], [0.51, 0.51], \"k--\", linewidth=2, label=\"51%\")\n",
"plt.plot([0, 10000], [0.5, 0.5], \"k-\", label=\"50%\")\n",
"plt.xlabel(\"Number of coin tosses\")\n",
"plt.ylabel(\"Heads ratio\")\n",
"plt.legend(loc=\"lower right\")\n",
"plt.axis([0, 10000, 0.42, 0.58])\n",
"save_fig(\"law_of_large_numbers_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
2017-06-02 10:57:06 +02:00
"collapsed": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
2016-11-05 14:25:56 +01:00
"from sklearn.model_selection import train_test_split\n",
2016-09-27 23:31:21 +02:00
"from sklearn.datasets import make_moons\n",
"\n",
"X, y = make_moons(n_samples=500, noise=0.30, random_state=42)\n",
2017-06-02 10:57:06 +02:00
"X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
2016-09-27 23:31:21 +02:00
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.ensemble import VotingClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.svm import SVC\n",
"\n",
"log_clf = LogisticRegression(random_state=42)\n",
"rnd_clf = RandomForestClassifier(random_state=42)\n",
2017-06-02 10:57:06 +02:00
"svm_clf = SVC(random_state=42)\n",
2016-09-27 23:31:21 +02:00
"\n",
"voting_clf = VotingClassifier(\n",
2017-06-02 10:57:06 +02:00
" estimators=[('lr', log_clf), ('rf', rnd_clf), ('svc', svm_clf)],\n",
" voting='hard')\n",
"voting_clf.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn.metrics import accuracy_score\n",
2016-09-27 23:31:21 +02:00
"\n",
2017-06-02 10:57:06 +02:00
"for clf in (log_clf, rnd_clf, svm_clf, voting_clf):\n",
" clf.fit(X_train, y_train)\n",
" y_pred = clf.predict(X_test)\n",
" print(clf.__class__.__name__, accuracy_score(y_test, y_pred))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"log_clf = LogisticRegression(random_state=42)\n",
"rnd_clf = RandomForestClassifier(random_state=42)\n",
"svm_clf = SVC(probability=True, random_state=42)\n",
"\n",
"voting_clf = VotingClassifier(\n",
" estimators=[('lr', log_clf), ('rf', rnd_clf), ('svc', svm_clf)],\n",
" voting='soft')\n",
"voting_clf.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
2016-09-27 23:31:21 +02:00
"from sklearn.metrics import accuracy_score\n",
"\n",
"for clf in (log_clf, rnd_clf, svm_clf, voting_clf):\n",
" clf.fit(X_train, y_train)\n",
" y_pred = clf.predict(X_test)\n",
" print(clf.__class__.__name__, accuracy_score(y_test, y_pred))"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"# Bagging ensembles"
]
},
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 9,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-06-02 10:57:06 +02:00
"collapsed": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"from sklearn.ensemble import BaggingClassifier\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"bag_clf = BaggingClassifier(\n",
2017-06-02 10:57:06 +02:00
" DecisionTreeClassifier(random_state=42), n_estimators=500,\n",
" max_samples=100, bootstrap=True, n_jobs=-1, random_state=42)\n",
2016-09-27 23:31:21 +02:00
"bag_clf.fit(X_train, y_train)\n",
2017-06-02 10:57:06 +02:00
"y_pred = bag_clf.predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.metrics import accuracy_score\n",
2016-09-27 23:31:21 +02:00
"print(accuracy_score(y_test, y_pred))"
]
},
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 11,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"tree_clf = DecisionTreeClassifier(random_state=42)\n",
"tree_clf.fit(X_train, y_train)\n",
"y_pred_tree = tree_clf.predict(X_test)\n",
"print(accuracy_score(y_test, y_pred_tree))"
]
},
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 12,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": true,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"from matplotlib.colors import ListedColormap\n",
"\n",
"def plot_decision_boundary(clf, X, y, axes=(-1.5, 2.5, -1, 1.5), alpha=0.5, contour=True):\n",
"    \"\"\"Plot the decision regions of a fitted 2-feature classifier, plus the data.\n",
"\n",
"    clf is evaluated on a 100x100 grid spanning axes = (xmin, xmax, ymin, ymax).\n",
"    Instances of class 0 are drawn as yellow circles and class 1 as blue squares,\n",
"    with transparency `alpha`. If `contour` is True, the boundary lines are drawn too.\n",
"    \"\"\"\n",
"    x1s = np.linspace(axes[0], axes[1], 100)\n",
"    x2s = np.linspace(axes[2], axes[3], 100)\n",
"    x1, x2 = np.meshgrid(x1s, x2s)\n",
"    X_new = np.c_[x1.ravel(), x2.ravel()]\n",
"    y_pred = clf.predict(X_new).reshape(x1.shape)\n",
"    custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])\n",
"    # contourf() only fills regions: it takes no `linewidth` keyword\n",
"    plt.contourf(x1, x2, y_pred, alpha=0.3, cmap=custom_cmap)\n",
"    if contour:\n",
"        custom_cmap2 = ListedColormap(['#7d7d58','#4c4c7f','#507d50'])\n",
"        plt.contour(x1, x2, y_pred, cmap=custom_cmap2, alpha=0.8)\n",
"    plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"yo\", alpha=alpha)\n",
"    plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"bs\", alpha=alpha)\n",
"    plt.axis(axes)\n",
"    plt.xlabel(r\"$x_1$\", fontsize=18)\n",
"    plt.ylabel(r\"$x_2$\", fontsize=18, rotation=0)"
]
},
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 13,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"plt.figure(figsize=(11,4))\n",
"plt.subplot(121)\n",
"plot_decision_boundary(tree_clf, X, y)\n",
"plt.title(\"Decision Tree\", fontsize=14)\n",
"plt.subplot(122)\n",
"plot_decision_boundary(bag_clf, X, y)\n",
"plt.title(\"Decision Trees with Bagging\", fontsize=14)\n",
"save_fig(\"decision_tree_without_and_with_bagging_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"# Random Forests"
]
},
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 14,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"bag_clf = BaggingClassifier(\n",
2017-06-02 10:57:06 +02:00
" DecisionTreeClassifier(splitter=\"random\", max_leaf_nodes=16, random_state=42),\n",
" n_estimators=500, max_samples=1.0, bootstrap=True, n_jobs=-1, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
2016-09-27 23:31:21 +02:00
"bag_clf.fit(X_train, y_train)\n",
"y_pred = bag_clf.predict(X_test)"
]
},
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 16,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": true,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"\n",
"rnd_clf = RandomForestClassifier(n_estimators=500, max_leaf_nodes=16, n_jobs=-1, random_state=42)\n",
"rnd_clf.fit(X_train, y_train)\n",
"\n",
"y_pred_rf = rnd_clf.predict(X_test)"
]
},
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 17,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"np.sum(y_pred == y_pred_rf) / len(y_pred) # almost identical predictions"
]
},
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 18,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"from sklearn.datasets import load_iris\n",
"iris = load_iris()\n",
"rnd_clf = RandomForestClassifier(n_estimators=500, n_jobs=-1, random_state=42)\n",
"rnd_clf.fit(iris[\"data\"], iris[\"target\"])\n",
2017-06-02 10:57:06 +02:00
"for name, score in zip(iris[\"feature_names\"], rnd_clf.feature_importances_):\n",
" print(name, score)"
2016-09-27 23:31:21 +02:00
]
},
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 19,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"rnd_clf.feature_importances_"
]
},
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 20,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"plt.figure(figsize=(6, 4))\n",
"\n",
"for i in range(15):\n",
"    tree_clf = DecisionTreeClassifier(max_leaf_nodes=16, random_state=42 + i)\n",
"    # bootstrap sample of the training set: len(X_train) indices drawn with replacement\n",
"    indices_with_replacement = rnd.randint(0, len(X_train), len(X_train))\n",
"    tree_clf.fit(X_train[indices_with_replacement], y_train[indices_with_replacement])\n",
"    plot_decision_boundary(tree_clf, X, y, axes=[-1.5, 2.5, -1, 1.5], alpha=0.02, contour=False)\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"## Out-of-Bag evaluation"
]
},
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 21,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"bag_clf = BaggingClassifier(\n",
" DecisionTreeClassifier(random_state=42), n_estimators=500,\n",
2017-06-02 10:57:06 +02:00
" bootstrap=True, n_jobs=-1, oob_score=True, random_state=40)\n",
2016-09-27 23:31:21 +02:00
"bag_clf.fit(X_train, y_train)\n",
"bag_clf.oob_score_"
]
},
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 22,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
2017-06-02 10:57:06 +02:00
"bag_clf.oob_decision_function_"
2016-09-27 23:31:21 +02:00
]
},
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 23,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"from sklearn.metrics import accuracy_score\n",
"y_pred = bag_clf.predict(X_test)\n",
"accuracy_score(y_test, y_pred)"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"## Feature importance"
]
},
2017-04-07 21:33:53 +02:00
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 24,
2017-04-07 21:33:53 +02:00
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"try:\n",
"    # fetch_mldata was removed from scikit-learn (mldata.org is offline),\n",
"    # so use OpenML's copy of MNIST when fetch_openml is available\n",
"    from sklearn.datasets import fetch_openml\n",
"    try:\n",
"        mnist = fetch_openml('mnist_784', version=1, as_frame=False)\n",
"    except TypeError:  # scikit-learn < 0.22 has no as_frame parameter\n",
"        mnist = fetch_openml('mnist_784', version=1)\n",
"    mnist.target = mnist.target.astype(np.int64)  # OpenML stores the labels as strings\n",
"except ImportError:  # scikit-learn < 0.20 has no fetch_openml\n",
"    from sklearn.datasets import fetch_mldata\n",
"    mnist = fetch_mldata('MNIST original')"
2017-04-07 21:33:53 +02:00
]
},
2016-09-27 23:31:21 +02:00
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 25,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"rnd_clf = RandomForestClassifier(random_state=42)\n",
"rnd_clf.fit(mnist[\"data\"], mnist[\"target\"])"
]
},
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 26,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": true,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"def plot_digit(data):\n",
"    \"\"\"Display a flat array of 784 values as a 28x28 'hot' colormap image, axes hidden.\"\"\"\n",
"    plt.imshow(data.reshape(28, 28),\n",
"               cmap=matplotlib.cm.hot,\n",
"               interpolation=\"nearest\")\n",
"    plt.axis(\"off\")"
]
},
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 27,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"plot_digit(rnd_clf.feature_importances_)\n",
"\n",
"cbar = plt.colorbar(ticks=[rnd_clf.feature_importances_.min(), rnd_clf.feature_importances_.max()])\n",
"cbar.ax.set_yticklabels(['Not important', 'Very important'])\n",
"\n",
"save_fig(\"mnist_feature_importance_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"# AdaBoost"
]
},
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 28,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-06-02 10:57:06 +02:00
"collapsed": false
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"from sklearn.ensemble import AdaBoostClassifier\n",
"\n",
"ada_clf = AdaBoostClassifier(\n",
2017-06-02 10:57:06 +02:00
" DecisionTreeClassifier(max_depth=1), n_estimators=200,\n",
" algorithm=\"SAMME.R\", learning_rate=0.5, random_state=42)\n",
"ada_clf.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
2016-09-27 23:31:21 +02:00
"plot_decision_boundary(ada_clf, X, y)"
]
},
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 30,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"m = len(X_train)\n",
"\n",
"plt.figure(figsize=(11, 4))\n",
"for subplot, learning_rate in ((121, 1), (122, 0.5)):\n",
"    plt.subplot(subplot)\n",
"    sample_weights = np.ones(m)\n",
"    for i in range(5):\n",
"        svm_clf = SVC(kernel=\"rbf\", C=0.05)\n",
"        svm_clf.fit(X_train, y_train, sample_weight=sample_weights)\n",
"        y_pred = svm_clf.predict(X_train)\n",
"        # boost the weights of the misclassified instances for the next predictor\n",
"        sample_weights[y_pred != y_train] *= (1 + learning_rate)\n",
"        plot_decision_boundary(svm_clf, X, y, alpha=0.2)\n",
"    # show the learning rate actually used for this panel (1 or 0.5)\n",
"    plt.title(\"learning_rate = {}\".format(learning_rate), fontsize=16)\n",
"\n",
"plt.subplot(121)\n",
"plt.text(-0.7, -0.65, \"1\", fontsize=14)\n",
"plt.text(-0.6, -0.10, \"2\", fontsize=14)\n",
"plt.text(-0.5, 0.10, \"3\", fontsize=14)\n",
"plt.text(-0.4, 0.55, \"4\", fontsize=14)\n",
"plt.text(-0.3, 0.90, \"5\", fontsize=14)\n",
"save_fig(\"boosting_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 31,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"list(m for m in dir(ada_clf) if not m.startswith(\"_\") and m.endswith(\"_\"))"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"# Gradient Boosting"
]
},
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 32,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"rnd.seed(42)\n",
"X = rnd.rand(100, 1) - 0.5\n",
"y = 3*X[:, 0]**2 + 0.05 * rnd.randn(100)"
]
},
{
"cell_type": "code",
"execution_count": 33,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"from sklearn.tree import DecisionTreeRegressor\n",
"\n",
"tree_reg1 = DecisionTreeRegressor(max_depth=2, random_state=42)\n",
2017-06-02 10:57:06 +02:00
"tree_reg1.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
2016-09-27 23:31:21 +02:00
"y2 = y - tree_reg1.predict(X)\n",
"tree_reg2 = DecisionTreeRegressor(max_depth=2, random_state=42)\n",
2017-06-02 10:57:06 +02:00
"tree_reg2.fit(X, y2)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
2016-09-27 23:31:21 +02:00
"y3 = y2 - tree_reg2.predict(X)\n",
"tree_reg3 = DecisionTreeRegressor(max_depth=2, random_state=42)\n",
2017-06-02 10:57:06 +02:00
"tree_reg3.fit(X, y3)"
2016-09-27 23:31:21 +02:00
]
},
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 36,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"X_new = np.array([[0.8]])"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"y_pred = sum(tree.predict(X_new) for tree in (tree_reg1, tree_reg2, tree_reg3))"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"y_pred"
]
},
{
"cell_type": "code",
"execution_count": 39,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"def plot_predictions(regressors, X, y, axes, label=None, style=\"r-\", data_style=\"b.\", data_label=None):\n",
"    \"\"\"Plot the data points (X[:, 0], y) and the summed predictions of `regressors`.\n",
"\n",
"    The ensemble curve is the sum of every regressor's prediction, evaluated on 500\n",
"    evenly spaced points between axes[0] and axes[1] (axes = [xmin, xmax, ymin, ymax]).\n",
"    A legend is added only when `label` or `data_label` is given.\n",
"    \"\"\"\n",
"    x1 = np.linspace(axes[0], axes[1], 500)\n",
"    x1_column = x1.reshape(-1, 1)\n",
"    y_pred = sum(regressor.predict(x1_column) for regressor in regressors)\n",
"    plt.plot(X[:, 0], y, data_style, label=data_label)\n",
"    plt.plot(x1, y_pred, style, linewidth=2, label=label)\n",
"    if label or data_label:\n",
"        plt.legend(loc=\"upper center\", fontsize=16)\n",
"    plt.axis(axes)"
"\n",
"plt.figure(figsize=(11,11))\n",
"\n",
"plt.subplot(321)\n",
"plot_predictions([tree_reg1], X, y, axes=[-0.5, 0.5, -0.1, 0.8], label=\"$h_1(x_1)$\", style=\"g-\", data_label=\"Training set\")\n",
"plt.ylabel(\"$y$\", fontsize=16, rotation=0)\n",
"plt.title(\"Residuals and tree predictions\", fontsize=16)\n",
"\n",
"plt.subplot(322)\n",
"plot_predictions([tree_reg1], X, y, axes=[-0.5, 0.5, -0.1, 0.8], label=\"$h(x_1) = h_1(x_1)$\", data_label=\"Training set\")\n",
"plt.ylabel(\"$y$\", fontsize=16, rotation=0)\n",
"plt.title(\"Ensemble predictions\", fontsize=16)\n",
"\n",
"plt.subplot(323)\n",
"plot_predictions([tree_reg2], X, y2, axes=[-0.5, 0.5, -0.5, 0.5], label=\"$h_2(x_1)$\", style=\"g-\", data_style=\"k+\", data_label=\"Residuals\")\n",
"plt.ylabel(\"$y - h_1(x_1)$\", fontsize=16)\n",
"\n",
"plt.subplot(324)\n",
"plot_predictions([tree_reg1, tree_reg2], X, y, axes=[-0.5, 0.5, -0.1, 0.8], label=\"$h(x_1) = h_1(x_1) + h_2(x_1)$\")\n",
"plt.ylabel(\"$y$\", fontsize=16, rotation=0)\n",
"\n",
"plt.subplot(325)\n",
"plot_predictions([tree_reg3], X, y3, axes=[-0.5, 0.5, -0.5, 0.5], label=\"$h_3(x_1)$\", style=\"g-\", data_style=\"k+\")\n",
"plt.ylabel(\"$y - h_1(x_1) - h_2(x_1)$\", fontsize=16)\n",
"plt.xlabel(\"$x_1$\", fontsize=16)\n",
"\n",
"plt.subplot(326)\n",
"plot_predictions([tree_reg1, tree_reg2, tree_reg3], X, y, axes=[-0.5, 0.5, -0.1, 0.8], label=\"$h(x_1) = h_1(x_1) + h_2(x_1) + h_3(x_1)$\")\n",
"plt.xlabel(\"$x_1$\", fontsize=16)\n",
"plt.ylabel(\"$y$\", fontsize=16, rotation=0)\n",
"\n",
"save_fig(\"gradient_boosting_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 40,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"from sklearn.ensemble import GradientBoostingRegressor\n",
"\n",
2017-06-02 10:57:06 +02:00
"gbrt = GradientBoostingRegressor(max_depth=2, n_estimators=3, learning_rate=1.0, random_state=42)\n",
"gbrt.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
2016-09-27 23:31:21 +02:00
"gbrt_slow = GradientBoostingRegressor(max_depth=2, n_estimators=200, learning_rate=0.1, random_state=42)\n",
2017-06-02 10:57:06 +02:00
"gbrt_slow.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
2016-09-27 23:31:21 +02:00
"plt.figure(figsize=(11,4))\n",
"\n",
"plt.subplot(121)\n",
"plot_predictions([gbrt], X, y, axes=[-0.5, 0.5, -0.1, 0.8], label=\"Ensemble predictions\")\n",
"plt.title(\"learning_rate={}, n_estimators={}\".format(gbrt.learning_rate, gbrt.n_estimators), fontsize=14)\n",
"\n",
"plt.subplot(122)\n",
"plot_predictions([gbrt_slow], X, y, axes=[-0.5, 0.5, -0.1, 0.8])\n",
"plt.title(\"learning_rate={}, n_estimators={}\".format(gbrt_slow.learning_rate, gbrt_slow.n_estimators), fontsize=14)\n",
"\n",
"save_fig(\"gbrt_learning_rate_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"## Gradient Boosting with Early Stopping"
]
},
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 43,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
2017-06-02 10:57:06 +02:00
"import numpy as np\n",
2016-11-05 14:25:56 +01:00
"from sklearn.model_selection import train_test_split\n",
2016-09-27 23:31:21 +02:00
"from sklearn.metrics import mean_squared_error\n",
"\n",
2017-06-02 10:57:06 +02:00
"X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=49)\n",
2016-09-27 23:31:21 +02:00
"\n",
2017-06-02 10:57:06 +02:00
"gbrt = GradientBoostingRegressor(max_depth=2, n_estimators=120, random_state=42)\n",
2016-09-27 23:31:21 +02:00
"gbrt.fit(X_train, y_train)\n",
"\n",
2017-06-02 10:57:06 +02:00
"# staged_predict() yields the predictions after 1, 2, ..., n_estimators trees,\n",
"# so errors[i] is the validation error of an ensemble of i + 1 trees\n",
"errors = [mean_squared_error(y_val, y_pred)\n",
"          for y_pred in gbrt.staged_predict(X_val)]\n",
"bst_n_estimators = np.argmin(errors) + 1\n",
"\n",
"gbrt_best = GradientBoostingRegressor(max_depth=2, n_estimators=bst_n_estimators, random_state=42)\n",
"gbrt_best.fit(X_train, y_train)"
2016-09-27 23:31:21 +02:00
]
},
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 44,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-06-02 10:57:06 +02:00
"collapsed": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
2017-06-02 10:57:06 +02:00
"min_error = np.min(errors)"
2016-09-27 23:31:21 +02:00
]
},
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 45,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"plt.figure(figsize=(11, 4))\n",
"\n",
"# errors[i] is the validation error after i + 1 trees, so plot it against 1..len(errors)\n",
"n_trees = np.arange(1, len(errors) + 1)\n",
"best_n_trees = n_trees[np.argmin(errors)]\n",
"\n",
"plt.subplot(121)\n",
"plt.plot(n_trees, errors, \"b.-\")\n",
"plt.plot([best_n_trees, best_n_trees], [0, min_error], \"k--\")\n",
"plt.plot([0, len(errors)], [min_error, min_error], \"k--\")\n",
"plt.plot(best_n_trees, min_error, \"ko\")\n",
"plt.text(best_n_trees, min_error*1.2, \"Minimum\", ha=\"center\", fontsize=14)\n",
"plt.axis([0, len(errors), 0, 0.01])\n",
"plt.xlabel(\"Number of trees\")\n",
"plt.title(\"Validation error\", fontsize=14)\n",
"\n",
"plt.subplot(122)\n",
"plot_predictions([gbrt_best], X, y, axes=[-0.5, 0.5, -0.1, 0.8])\n",
"# report the size of the model actually plotted\n",
"plt.title(\"Best model (%d trees)\" % gbrt_best.n_estimators, fontsize=14)\n",
"\n",
"save_fig(\"early_stopping_gbrt_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 46,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
2017-06-02 10:57:06 +02:00
"gbrt = GradientBoostingRegressor(max_depth=2, warm_start=True, random_state=42)\n",
2016-09-27 23:31:21 +02:00
"\n",
"min_val_error = float(\"inf\")\n",
"error_going_up = 0\n",
"for n_estimators in range(1, 120):\n",
" gbrt.n_estimators = n_estimators\n",
" gbrt.fit(X_train, y_train)\n",
" y_pred = gbrt.predict(X_val)\n",
" val_error = mean_squared_error(y_val, y_pred)\n",
" if val_error < min_val_error:\n",
" min_val_error = val_error\n",
" error_going_up = 0\n",
" else:\n",
" error_going_up += 1\n",
" if error_going_up == 5:\n",
" break # early stopping"
]
},
{
"cell_type": "code",
2017-06-02 10:57:06 +02:00
"execution_count": 47,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"print(gbrt.n_estimators)"
]
},
{
"cell_type": "markdown",
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": true,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"source": [
"# Exercise solutions"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"**Coming soon**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": true,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2017-06-02 10:57:06 +02:00
"version": "3.5.3"
2016-09-27 23:31:21 +02:00
},
"nav_menu": {
"height": "252px",
"width": "333px"
},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 0
}