handson-ml/07_ensemble_learning_and_ra...

894 lines
24 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"**Chapter 7 Ensemble Learning and Random Forests**"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"_This notebook contains all the sample code and solutions to the exercises in chapter 7._"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"First, let's make sure this notebook works well in both Python 2 and 3, import a few common modules, ensure Matplotlib plots figures inline, and prepare a function to save the figures:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# To support both python 2 and python 3\n",
"from __future__ import division, print_function, unicode_literals\n",
"\n",
"# Common imports\n",
"import numpy as np\n",
"import numpy.random as rnd\n",
"import os\n",
"\n",
"# to make this notebook's output stable across runs\n",
"rnd.seed(42)\n",
"\n",
"# To plot pretty figures\n",
"%matplotlib inline\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"plt.rcParams['axes.labelsize'] = 14\n",
"plt.rcParams['xtick.labelsize'] = 12\n",
"plt.rcParams['ytick.labelsize'] = 12\n",
"\n",
"# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n",
"CHAPTER_ID = \"ensembles\"\n",
"\n",
"def image_path(fig_id):\n",
"    \"\"\"Join PROJECT_ROOT_DIR, \"images\", CHAPTER_ID and `fig_id` into one path.\"\"\"\n",
"    chapter_dir = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"    return os.path.join(chapter_dir, fig_id)\n",
"\n",
"def save_fig(fig_id, tight_layout=True):\n",
"    \"\"\"Save the current Matplotlib figure as a 300-dpi PNG named after `fig_id`.\n",
"\n",
"    The file is written to image_path(fig_id) + \".png\"; the target directory\n",
"    is created first if it does not exist. When `tight_layout` is true,\n",
"    plt.tight_layout() is applied before saving.\n",
"    \"\"\"\n",
"    path = image_path(fig_id) + \".png\"\n",
"    target_dir = os.path.dirname(path)\n",
"    # savefig() does not create missing directories, so make sure it exists\n",
"    # (isdir check instead of exist_ok=True to keep Python 2 compatibility).\n",
"    if target_dir and not os.path.isdir(target_dir):\n",
"        os.makedirs(target_dir)\n",
"    print(\"Saving figure\", fig_id)\n",
"    if tight_layout:\n",
"        plt.tight_layout()\n",
"    plt.savefig(path, format='png', dpi=300)"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Voting classifiers"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"heads_proba = 0.51\n",
"coin_tosses = (rnd.rand(10000, 10) < heads_proba).astype(np.int32)\n",
"cumulative_heads_ratio = np.cumsum(coin_tosses, axis=0) / np.arange(1, 10001).reshape(-1, 1)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"plt.figure(figsize=(8,3.5))\n",
"plt.plot(cumulative_heads_ratio)\n",
"plt.plot([0, 10000], [0.51, 0.51], \"k--\", linewidth=2, label=\"51%\")\n",
"plt.plot([0, 10000], [0.5, 0.5], \"k-\", label=\"50%\")\n",
"plt.xlabel(\"Number of coin tosses\")\n",
"plt.ylabel(\"Heads ratio\")\n",
"plt.legend(loc=\"lower right\")\n",
"plt.axis([0, 10000, 0.42, 0.58])\n",
"save_fig(\"law_of_large_numbers_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"from sklearn.datasets import make_moons\n",
"\n",
"X, y = make_moons(n_samples=500, noise=0.30, random_state=42)\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)\n",
"\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.ensemble import VotingClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.svm import SVC\n",
"\n",
"log_clf = LogisticRegression(random_state=42)\n",
"rnd_clf = RandomForestClassifier(random_state=42)\n",
"svm_clf = SVC(probability=True, random_state=42)\n",
"\n",
"voting_clf = VotingClassifier(\n",
" estimators=[('lr', log_clf), ('rf', rnd_clf), ('svc', svm_clf)],\n",
" voting='soft'\n",
" )\n",
"voting_clf.fit(X_train, y_train)\n",
"\n",
"from sklearn.metrics import accuracy_score\n",
"\n",
"for clf in (log_clf, rnd_clf, svm_clf, voting_clf):\n",
" clf.fit(X_train, y_train)\n",
" y_pred = clf.predict(X_test)\n",
" print(clf.__class__.__name__, accuracy_score(y_test, y_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Bagging ensembles"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.datasets import make_moons\n",
"from sklearn.ensemble import BaggingClassifier\n",
"from sklearn.metrics import accuracy_score\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"bag_clf = BaggingClassifier(\n",
" DecisionTreeClassifier(random_state=42), n_estimators=500,\n",
" max_samples=100, bootstrap=True, n_jobs=-1, random_state=42\n",
" )\n",
"bag_clf.fit(X_train, y_train)\n",
"y_pred = bag_clf.predict(X_test)\n",
"print(accuracy_score(y_test, y_pred))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"tree_clf = DecisionTreeClassifier(random_state=42)\n",
"tree_clf.fit(X_train, y_train)\n",
"y_pred_tree = tree_clf.predict(X_test)\n",
"print(accuracy_score(y_test, y_pred_tree))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from matplotlib.colors import ListedColormap\n",
"\n",
"def plot_decision_boundary(clf, X, y, axes=[-1.5, 2.5, -1, 1.5], alpha=0.5, contour=True):\n",
"    \"\"\"Draw the decision regions of `clf` together with the samples in `X`.\n",
"\n",
"    clf     -- fitted estimator; only its predict() method is used\n",
"    X, y    -- samples (two feature columns) and labels; label 0 is drawn as\n",
"               yellow circles, label 1 as blue squares\n",
"    axes    -- [x1_min, x1_max, x2_min, x2_max] of the plotted area (read only)\n",
"    alpha   -- opacity of the sample markers\n",
"    contour -- when true, also draw the boundary lines over the filled regions\n",
"    \"\"\"\n",
"    # Evaluate the classifier on a 100x100 grid covering the plotted area.\n",
"    x1s = np.linspace(axes[0], axes[1], 100)\n",
"    x2s = np.linspace(axes[2], axes[3], 100)\n",
"    x1, x2 = np.meshgrid(x1s, x2s)\n",
"    X_new = np.c_[x1.ravel(), x2.ravel()]\n",
"    y_pred = clf.predict(X_new).reshape(x1.shape)\n",
"    custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])\n",
"    # contourf() draws filled regions; it does not use a `linewidth` keyword.\n",
"    plt.contourf(x1, x2, y_pred, alpha=0.3, cmap=custom_cmap)\n",
"    if contour:\n",
"        custom_cmap2 = ListedColormap(['#7d7d58','#4c4c7f','#507d50'])\n",
"        plt.contour(x1, x2, y_pred, cmap=custom_cmap2, alpha=0.8)\n",
"    plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"yo\", alpha=alpha)\n",
"    plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"bs\", alpha=alpha)\n",
"    plt.axis(axes)\n",
"    plt.xlabel(r\"$x_1$\", fontsize=18)\n",
"    plt.ylabel(r\"$x_2$\", fontsize=18, rotation=0)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"plt.figure(figsize=(11,4))\n",
"plt.subplot(121)\n",
"plot_decision_boundary(tree_clf, X, y)\n",
"plt.title(\"Decision Tree\", fontsize=14)\n",
"plt.subplot(122)\n",
"plot_decision_boundary(bag_clf, X, y)\n",
"plt.title(\"Decision Trees with Bagging\", fontsize=14)\n",
"save_fig(\"decision_tree_without_and_with_bagging_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Random Forests"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"bag_clf = BaggingClassifier(\n",
" DecisionTreeClassifier(splitter=\"random\", max_leaf_nodes=16, random_state=42),\n",
" n_estimators=500, max_samples=1.0, bootstrap=True,\n",
" n_jobs=-1, random_state=42\n",
" )\n",
"bag_clf.fit(X_train, y_train)\n",
"y_pred = bag_clf.predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"\n",
"rnd_clf = RandomForestClassifier(n_estimators=500, max_leaf_nodes=16, n_jobs=-1, random_state=42)\n",
"rnd_clf.fit(X_train, y_train)\n",
"\n",
"y_pred_rf = rnd_clf.predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"np.sum(y_pred == y_pred_rf) / len(y_pred) # almost identical predictions"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.datasets import load_iris\n",
"iris = load_iris()\n",
"rnd_clf = RandomForestClassifier(n_estimators=500, n_jobs=-1, random_state=42)\n",
"rnd_clf.fit(iris[\"data\"], iris[\"target\"])\n",
"for name, importance in zip(iris[\"feature_names\"], rnd_clf.feature_importances_):\n",
" print(name, \"=\", importance)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"rnd_clf.feature_importances_"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"plt.figure(figsize=(6, 4))\n",
"\n",
"# Overlay 15 trees, each trained on a bootstrap sample of the training set.\n",
"for i in range(15):\n",
"    tree_clf = DecisionTreeClassifier(max_leaf_nodes=16, random_state=42+i)\n",
"    # The indices are drawn in [0, len(X_train)), so they must index\n",
"    # X_train/y_train; indexing the full X/y would sample from the wrong set.\n",
"    indices_with_replacement = rnd.randint(0, len(X_train), len(X_train))\n",
"    tree_clf.fit(X_train[indices_with_replacement], y_train[indices_with_replacement])\n",
"    plot_decision_boundary(tree_clf, X, y, axes=[-1.5, 2.5, -1, 1.5], alpha=0.02, contour=False)\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"## Out-of-Bag evaluation"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"bag_clf = BaggingClassifier(\n",
" DecisionTreeClassifier(random_state=42), n_estimators=500,\n",
" bootstrap=True, n_jobs=-1, oob_score=True, random_state=40\n",
")\n",
"bag_clf.fit(X_train, y_train)\n",
"bag_clf.oob_score_"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"bag_clf.oob_decision_function_[:10]"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.metrics import accuracy_score\n",
"y_pred = bag_clf.predict(X_test)\n",
"accuracy_score(y_test, y_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"## Feature importance"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# fetch_mldata() relied on the defunct mldata.org service and was removed from\n",
"# scikit-learn; prefer fetch_openml() and only fall back on old versions.\n",
"try:\n",
"    from sklearn.datasets import fetch_openml\n",
"    mnist = fetch_openml('mnist_784', version=1)\n",
"except ImportError:\n",
"    from sklearn.datasets import fetch_mldata\n",
"    mnist = fetch_mldata('MNIST original')\n",
"rnd_clf = RandomForestClassifier(random_state=42)\n",
"rnd_clf.fit(mnist[\"data\"], mnist[\"target\"])"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"def plot_digit(data):\n",
"    \"\"\"Show `data`, reshaped to 28x28, as a \"hot\" colormap image with the axes hidden.\"\"\"\n",
"    pixels = data.reshape(28, 28)\n",
"    plt.imshow(pixels, interpolation=\"nearest\", cmap=matplotlib.cm.hot)\n",
"    plt.axis(\"off\")"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"plot_digit(rnd_clf.feature_importances_)\n",
"\n",
"cbar = plt.colorbar(ticks=[rnd_clf.feature_importances_.min(), rnd_clf.feature_importances_.max()])\n",
"cbar.ax.set_yticklabels(['Not important', 'Very important'])\n",
"\n",
"save_fig(\"mnist_feature_importance_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# AdaBoost"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.ensemble import AdaBoostClassifier\n",
"\n",
"ada_clf = AdaBoostClassifier(\n",
" DecisionTreeClassifier(max_depth=2), n_estimators=200,\n",
" algorithm=\"SAMME.R\", learning_rate=0.5, random_state=42\n",
" )\n",
"ada_clf.fit(X_train, y_train)\n",
"plot_decision_boundary(ada_clf, X, y)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"m = len(X_train)\n",
"\n",
"# For each learning rate, train 5 consecutive SVMs, re-weighting the training\n",
"# instances after each one, and overlay their decision boundaries.\n",
"plt.figure(figsize=(11, 4))\n",
"for subplot, learning_rate in ((121, 1), (122, 0.5)):\n",
"    sample_weights = np.ones(m)\n",
"    for i in range(5):\n",
"        plt.subplot(subplot)\n",
"        svm_clf = SVC(kernel=\"rbf\", C=0.05)\n",
"        svm_clf.fit(X_train, y_train, sample_weight=sample_weights)\n",
"        y_pred = svm_clf.predict(X_train)\n",
"        # Boost the weight of every misclassified training instance.\n",
"        sample_weights[y_pred != y_train] *= (1 + learning_rate)\n",
"        plot_decision_boundary(svm_clf, X, y, alpha=0.2)\n",
"    # Show the rate actually used in the weight update above.\n",
"    plt.title(\"learning_rate = {}\".format(learning_rate), fontsize=16)\n",
"\n",
"plt.subplot(121)\n",
"plt.text(-0.7, -0.65, \"1\", fontsize=14)\n",
"plt.text(-0.6, -0.10, \"2\", fontsize=14)\n",
"plt.text(-0.5, 0.10, \"3\", fontsize=14)\n",
"plt.text(-0.4, 0.55, \"4\", fontsize=14)\n",
"plt.text(-0.3, 0.90, \"5\", fontsize=14)\n",
"save_fig(\"boosting_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"list(m for m in dir(ada_clf) if not m.startswith(\"_\") and m.endswith(\"_\"))"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Gradient Boosting"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.tree import DecisionTreeRegressor\n",
"\n",
"rnd.seed(42)\n",
"X = rnd.rand(100, 1) - 0.5\n",
"y = 3*X[:, 0]**2 + 0.05 * rnd.randn(100)\n",
"\n",
"tree_reg1 = DecisionTreeRegressor(max_depth=2, random_state=42)\n",
"tree_reg1.fit(X, y)\n",
"\n",
"y2 = y - tree_reg1.predict(X)\n",
"tree_reg2 = DecisionTreeRegressor(max_depth=2, random_state=42)\n",
"tree_reg2.fit(X, y2)\n",
"\n",
"y3 = y2 - tree_reg2.predict(X)\n",
"tree_reg3 = DecisionTreeRegressor(max_depth=2, random_state=42)\n",
"tree_reg3.fit(X, y3)\n",
"\n",
"X_new = np.array([[0.8]])\n",
"y_pred = sum(tree.predict(X_new) for tree in (tree_reg1, tree_reg2, tree_reg3))\n",
"print(y_pred)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"def plot_predictions(regressors, X, y, axes, label=None, style=\"r-\", data_style=\"b.\", data_label=None):\n",
"    \"\"\"Plot the points (X[:, 0], y) and the summed predictions of `regressors`.\n",
"\n",
"    The predictions of all `regressors` are added up over 500 evenly spaced\n",
"    inputs between axes[0] and axes[1]. `axes` is [x_min, x_max, y_min, y_max];\n",
"    `style`/`label` apply to the prediction curve and `data_style`/`data_label`\n",
"    to the data points. A legend is drawn only if at least one label is given.\n",
"    \"\"\"\n",
"    x1 = np.linspace(axes[0], axes[1], 500)\n",
"    y_pred = sum(regressor.predict(x1.reshape(-1, 1)) for regressor in regressors)\n",
"    plt.plot(X[:, 0], y, data_style, label=data_label)\n",
"    plt.plot(x1, y_pred, style, linewidth=2, label=label)\n",
"    if label or data_label:\n",
"        plt.legend(loc=\"upper center\", fontsize=16)\n",
"    plt.axis(axes)\n",
"\n",
"plt.figure(figsize=(11,11))\n",
"\n",
"plt.subplot(321)\n",
"plot_predictions([tree_reg1], X, y, axes=[-0.5, 0.5, -0.1, 0.8], label=\"$h_1(x_1)$\", style=\"g-\", data_label=\"Training set\")\n",
"plt.ylabel(\"$y$\", fontsize=16, rotation=0)\n",
"plt.title(\"Residuals and tree predictions\", fontsize=16)\n",
"\n",
"plt.subplot(322)\n",
"plot_predictions([tree_reg1], X, y, axes=[-0.5, 0.5, -0.1, 0.8], label=\"$h(x_1) = h_1(x_1)$\", data_label=\"Training set\")\n",
"plt.ylabel(\"$y$\", fontsize=16, rotation=0)\n",
"plt.title(\"Ensemble predictions\", fontsize=16)\n",
"\n",
"plt.subplot(323)\n",
"plot_predictions([tree_reg2], X, y2, axes=[-0.5, 0.5, -0.5, 0.5], label=\"$h_2(x_1)$\", style=\"g-\", data_style=\"k+\", data_label=\"Residuals\")\n",
"plt.ylabel(\"$y - h_1(x_1)$\", fontsize=16)\n",
"\n",
"plt.subplot(324)\n",
"plot_predictions([tree_reg1, tree_reg2], X, y, axes=[-0.5, 0.5, -0.1, 0.8], label=\"$h(x_1) = h_1(x_1) + h_2(x_1)$\")\n",
"plt.ylabel(\"$y$\", fontsize=16, rotation=0)\n",
"\n",
"plt.subplot(325)\n",
"plot_predictions([tree_reg3], X, y3, axes=[-0.5, 0.5, -0.5, 0.5], label=\"$h_3(x_1)$\", style=\"g-\", data_style=\"k+\")\n",
"plt.ylabel(\"$y - h_1(x_1) - h_2(x_1)$\", fontsize=16)\n",
"plt.xlabel(\"$x_1$\", fontsize=16)\n",
"\n",
"plt.subplot(326)\n",
"plot_predictions([tree_reg1, tree_reg2, tree_reg3], X, y, axes=[-0.5, 0.5, -0.1, 0.8], label=\"$h(x_1) = h_1(x_1) + h_2(x_1) + h_3(x_1)$\")\n",
"plt.xlabel(\"$x_1$\", fontsize=16)\n",
"plt.ylabel(\"$y$\", fontsize=16, rotation=0)\n",
"\n",
"save_fig(\"gradient_boosting_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.ensemble import GradientBoostingRegressor\n",
"\n",
"gbrt = GradientBoostingRegressor(max_depth=2, n_estimators=3, learning_rate=0.1, random_state=42)\n",
"gbrt.fit(X, y)\n",
"\n",
"gbrt_slow = GradientBoostingRegressor(max_depth=2, n_estimators=200, learning_rate=0.1, random_state=42)\n",
"gbrt_slow.fit(X, y)\n",
"\n",
"plt.figure(figsize=(11,4))\n",
"\n",
"plt.subplot(121)\n",
"plot_predictions([gbrt], X, y, axes=[-0.5, 0.5, -0.1, 0.8], label=\"Ensemble predictions\")\n",
"plt.title(\"learning_rate={}, n_estimators={}\".format(gbrt.learning_rate, gbrt.n_estimators), fontsize=14)\n",
"\n",
"plt.subplot(122)\n",
"plot_predictions([gbrt_slow], X, y, axes=[-0.5, 0.5, -0.1, 0.8])\n",
"plt.title(\"learning_rate={}, n_estimators={}\".format(gbrt_slow.learning_rate, gbrt_slow.n_estimators), fontsize=14)\n",
"\n",
"save_fig(\"gbrt_learning_rate_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"## Gradient Boosting with Early stopping"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import mean_squared_error\n",
"\n",
"X_train, X_val, y_train, y_val = train_test_split(X, y)\n",
"\n",
"gbrt = GradientBoostingRegressor(max_depth=2, n_estimators=120, learning_rate=0.1, random_state=42)\n",
"gbrt.fit(X_train, y_train)\n",
"\n",
"errors = [mean_squared_error(y_val, y_pred) for y_pred in gbrt.staged_predict(X_val)]"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"best_n_estimators = np.argmin(errors)\n",
"min_error = errors[best_n_estimators]\n",
"\n",
"gbrt_best = GradientBoostingRegressor(max_depth=2, n_estimators=best_n_estimators, learning_rate=0.1, random_state=42)\n",
"gbrt_best.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"plt.figure(figsize=(11, 4))\n",
"\n",
"plt.subplot(121)\n",
"plt.plot(errors, \"b.-\")\n",
"plt.plot([best_n_estimators, best_n_estimators], [0, min_error], \"k--\")\n",
"plt.plot([0, 120], [min_error, min_error], \"k--\")\n",
"plt.plot(best_n_estimators, min_error, \"ko\")\n",
"plt.text(best_n_estimators, min_error*1.2, \"Minimum\", ha=\"center\", fontsize=14)\n",
"plt.axis([0, 120, 0, 0.01])\n",
"plt.xlabel(\"Number of trees\")\n",
"plt.title(\"Validation error\", fontsize=14)\n",
"\n",
"plt.subplot(122)\n",
"plot_predictions([gbrt_best], X, y, axes=[-0.5, 0.5, -0.1, 0.8])\n",
"plt.title(\"Best model (55 trees)\", fontsize=14)\n",
"\n",
"save_fig(\"early_stopping_gbrt_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"gbrt = GradientBoostingRegressor(max_depth=2, n_estimators=1, learning_rate=0.1, random_state=42, warm_start=True)\n",
"\n",
"min_val_error = float(\"inf\")\n",
"error_going_up = 0\n",
"for n_estimators in range(1, 120):\n",
" gbrt.n_estimators = n_estimators\n",
" gbrt.fit(X_train, y_train)\n",
" y_pred = gbrt.predict(X_val)\n",
" val_error = mean_squared_error(y_val, y_pred)\n",
" if val_error < min_val_error:\n",
" min_val_error = val_error\n",
" error_going_up = 0\n",
" else:\n",
" error_going_up += 1\n",
" if error_going_up == 5:\n",
" break # early stopping"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"print(gbrt.n_estimators)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"source": [
"# Exercise solutions"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"**Coming soon**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2+"
},
"nav_menu": {
"height": "252px",
"width": "333px"
},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 0
}