{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Chapter 7 – Ensemble Learning and Random Forests**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_This notebook contains all the sample code and solutions to the exercises in chapter 7._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table align=\"left\">\n",
" <td>\n",
" <a href=\"https://colab.research.google.com/github/ageron/handson-ml3/blob/main/07_ensemble_learning_and_random_forests.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/ageron/handson-ml3/blob/main/07_ensemble_learning_and_random_forests.ipynb\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" /></a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This project requires Python 3.7 or above:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"assert sys.version_info >= (3, 7)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It also requires Scikit-Learn ≥ 1.0.1:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from packaging import version\n",
"import sklearn\n",
"\n",
"assert version.parse(sklearn.__version__) >= version.parse(\"1.0.1\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we did in previous chapters, let's define the default font sizes to make the figures prettier:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"plt.rc('font', size=14)\n",
"plt.rc('axes', labelsize=14, titlesize=14)\n",
"plt.rc('legend', fontsize=14)\n",
"plt.rc('xtick', labelsize=10)\n",
"plt.rc('ytick', labelsize=10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And let's create the `images/ensembles` folder (if it doesn't already exist), and define the `save_fig()` function, which is used throughout this notebook to save the figures in high resolution for the book:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"IMAGES_PATH = Path() / \"images\" / \"ensembles\"\n",
"IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
"    path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
"    if tight_layout:\n",
"        plt.tight_layout()\n",
"    plt.savefig(path, format=fig_extension, dpi=resolution)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Voting Classifiers"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 576x252 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – this cell generates and saves Figure 7–3\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"heads_proba = 0.51\n",
"np.random.seed(42)\n",
"coin_tosses = (np.random.rand(10000, 10) < heads_proba).astype(np.int32)\n",
"cumulative_heads = coin_tosses.cumsum(axis=0)\n",
"cumulative_heads_ratio = cumulative_heads / np.arange(1, 10001).reshape(-1, 1)\n",
"\n",
"plt.figure(figsize=(8, 3.5))\n",
"plt.plot(cumulative_heads_ratio)\n",
"plt.plot([0, 10000], [0.51, 0.51], \"k--\", linewidth=2, label=\"51%\")\n",
"plt.plot([0, 10000], [0.5, 0.5], \"k-\", label=\"50%\")\n",
"plt.xlabel(\"Number of coin tosses\")\n",
"plt.ylabel(\"Heads ratio\")\n",
"plt.legend(loc=\"lower right\")\n",
"plt.axis([0, 10000, 0.42, 0.58])\n",
"plt.grid()\n",
"save_fig(\"law_of_large_numbers_plot\")\n",
"plt.show()"
]
},
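{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough cross-check of this law-of-large-numbers argument, the following sketch (an addition, not part of the book's code) uses `scipy.stats.binom` to compute the probability that a strict majority of _n_ independent tosses of a coin with a 51% chance of heads come up heads. Reading each toss as a classifier that is right 51% of the time assumes that the classifiers make independent errors, which real ensemble members never fully do, so treat these numbers as optimistic. The helper name `majority_heads_proba()` is made up for this illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from scipy.stats import binom\n",
"\n",
"heads_proba = 0.51\n",
"\n",
"def majority_heads_proba(n_tosses, p=heads_proba):\n",
"    # P(X > n_tosses // 2) for X ~ Binomial(n_tosses, p); binom.sf(k, n, p) = P(X > k)\n",
"    return binom.sf(n_tosses // 2, n_tosses, p)\n",
"\n",
"for n_tosses in (1, 10, 100, 1000, 10000):\n",
"    print(n_tosses, majority_heads_proba(n_tosses))"
]
},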
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's build a voting classifier:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"VotingClassifier(estimators=[('lr', LogisticRegression(random_state=42)),\n",
" ('rf', RandomForestClassifier(random_state=42)),\n",
" ('svc', SVC(random_state=42))])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.datasets import make_moons\n",
"from sklearn.ensemble import RandomForestClassifier, VotingClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.svm import SVC\n",
"\n",
"X, y = make_moons(n_samples=500, noise=0.30, random_state=42)\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)\n",
"\n",
"voting_clf = VotingClassifier(\n",
"    estimators=[\n",
"        ('lr', LogisticRegression(random_state=42)),\n",
"        ('rf', RandomForestClassifier(random_state=42)),\n",
"        ('svc', SVC(random_state=42))\n",
"    ]\n",
")\n",
"voting_clf.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"lr = 0.864\n",
"rf = 0.896\n",
"svc = 0.896\n"
]
}
],
"source": [
"for name, clf in voting_clf.named_estimators_.items():\n",
"    print(name, \"=\", clf.score(X_test, y_test))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"voting_clf.predict(X_test[:1])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[array([1]), array([1]), array([0])]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[clf.predict(X_test[:1]) for clf in voting_clf.estimators_]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.912"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"voting_clf.score(X_test, y_test)"
]
},
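{
"cell_type": "markdown",
"metadata": {},
"source": [
"Hard voting simply predicts the class that gets the most votes. As a sanity check, the following sketch (an addition to the book's code, which rebuilds the same data and classifiers so that it runs on its own) recomputes the majority vote from the fitted classifiers stored in `estimators_` and compares it with `voting_clf.predict()`. The `preds.sum(axis=0) >= 2` shortcut assumes exactly three voters and the binary labels 0 and 1 of the moons dataset; `VotingClassifier` itself also handles more classes and optional voter weights. The names `preds` and `hard_vote` are just illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import make_moons\n",
"from sklearn.ensemble import RandomForestClassifier, VotingClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.svm import SVC\n",
"\n",
"X, y = make_moons(n_samples=500, noise=0.30, random_state=42)\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)\n",
"\n",
"voting_clf = VotingClassifier(\n",
"    estimators=[\n",
"        ('lr', LogisticRegression(random_state=42)),\n",
"        ('rf', RandomForestClassifier(random_state=42)),\n",
"        ('svc', SVC(random_state=42))\n",
"    ]\n",
")\n",
"voting_clf.fit(X_train, y_train)\n",
"\n",
"# one row of 0/1 predictions per fitted classifier: shape (3, n_test)\n",
"preds = np.array([clf.predict(X_test) for clf in voting_clf.estimators_])\n",
"# with three voters and labels 0/1, class 1 wins when at least two vote for it\n",
"hard_vote = (preds.sum(axis=0) >= 2).astype(int)\n",
"np.all(hard_vote == voting_clf.predict(X_test))"
]
},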
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's use soft voting:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.92"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"voting_clf.voting = \"soft\"\n",
"voting_clf.named_estimators[\"svc\"].probability = True\n",
"voting_clf.fit(X_train, y_train)\n",
"voting_clf.score(X_test, y_test)"
]
},
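{
"cell_type": "markdown",
"metadata": {},
"source": [
"With soft voting, the ensemble averages the class probabilities estimated by the individual classifiers and predicts the class with the highest average. The sketch below (an illustrative addition that rebuilds the data so it is self-contained; `soft_clf` and `avg_proba` are made-up names) checks this by averaging `predict_proba()` over the fitted classifiers. `SVC` only provides `predict_proba()` when `probability=True`, which is why the cell above sets that hyperparameter before refitting. Using the column index as the predicted label assumes the classes are 0 and 1, as they are here."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import make_moons\n",
"from sklearn.ensemble import RandomForestClassifier, VotingClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.svm import SVC\n",
"\n",
"X, y = make_moons(n_samples=500, noise=0.30, random_state=42)\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)\n",
"\n",
"soft_clf = VotingClassifier(\n",
"    estimators=[\n",
"        ('lr', LogisticRegression(random_state=42)),\n",
"        ('rf', RandomForestClassifier(random_state=42)),\n",
"        ('svc', SVC(probability=True, random_state=42))\n",
"    ],\n",
"    voting='soft'\n",
")\n",
"soft_clf.fit(X_train, y_train)\n",
"\n",
"# average the (n_test, 2) probability arrays of the three classifiers\n",
"avg_proba = np.mean([clf.predict_proba(X_test) for clf in soft_clf.estimators_],\n",
"                    axis=0)\n",
"soft_vote = avg_proba.argmax(axis=1)  # column index equals the label (0 or 1) here\n",
"np.all(soft_vote == soft_clf.predict(X_test))"
]
},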
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Bagging and Pasting\n",
"## Bagging and Pasting in Scikit-Learn"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"BaggingClassifier(base_estimator=DecisionTreeClassifier(), max_samples=100,\n",
" n_estimators=500, n_jobs=-1, random_state=42)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.ensemble import BaggingClassifier\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"bag_clf = BaggingClassifier(DecisionTreeClassifier(), n_estimators=500,\n",
"                            max_samples=100, n_jobs=-1, random_state=42)\n",
"bag_clf.fit(X_train, y_train)"
]
},
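{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the mechanics concrete, here is a simplified from-scratch sketch of bagging with 100 instances per predictor. It is an illustration, not Scikit-Learn's actual implementation, and the 100 trees, the seed and the variable names are arbitrary choices. Each tree is trained on 100 training instances drawn with replacement, and the predictions are aggregated by averaging the trees' class probabilities. This is how `BaggingClassifier` aggregates when its base estimator has a `predict_proba()` method."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import make_moons\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"X, y = make_moons(n_samples=500, noise=0.30, random_state=42)\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)\n",
"\n",
"rng = np.random.default_rng(42)\n",
"n_trees, n_bootstrap = 100, 100\n",
"trees = []\n",
"for _ in range(n_trees):\n",
"    # bootstrap sample: n_bootstrap training indices drawn with replacement\n",
"    idx = rng.integers(0, len(X_train), size=n_bootstrap)\n",
"    tree = DecisionTreeClassifier(random_state=42)\n",
"    tree.fit(X_train[idx], y_train[idx])\n",
"    trees.append(tree)\n",
"\n",
"# soft aggregation: average the trees' class probabilities, keep the top class\n",
"bag_proba = np.mean([tree.predict_proba(X_test) for tree in trees], axis=0)\n",
"y_pred_manual = bag_proba.argmax(axis=1)\n",
"(y_pred_manual == y_test).mean()"
]
},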
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 720x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – this cell generates and saves Figure 7–5\n",
"\n",
"def plot_decision_boundary(clf, X, y, alpha=1.0):\n",
"    axes = [-1.5, 2.4, -1, 1.5]\n",
"    x1, x2 = np.meshgrid(np.linspace(axes[0], axes[1], 100),\n",
"                         np.linspace(axes[2], axes[3], 100))\n",
"    X_new = np.c_[x1.ravel(), x2.ravel()]\n",
"    y_pred = clf.predict(X_new).reshape(x1.shape)\n",
"\n",
"    plt.contourf(x1, x2, y_pred, alpha=0.3 * alpha, cmap='Wistia')\n",
"    plt.contour(x1, x2, y_pred, cmap=\"Greys\", alpha=0.8 * alpha)\n",
"    colors = [\"#78785c\", \"#c47b27\"]\n",
"    markers = (\"o\", \"^\")\n",
"    for idx in (0, 1):\n",
"        plt.plot(X[:, 0][y == idx], X[:, 1][y == idx],\n",
"                 color=colors[idx], marker=markers[idx], linestyle=\"none\")\n",
"    plt.axis(axes)\n",
"    plt.xlabel(r\"$x_1$\")\n",
"    plt.ylabel(r\"$x_2$\", rotation=0)\n",
"\n",
"tree_clf = DecisionTreeClassifier(random_state=42)\n",
"tree_clf.fit(X_train, y_train)\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)\n",
"plt.sca(axes[0])\n",
"plot_decision_boundary(tree_clf, X_train, y_train)\n",
"plt.title(\"Decision Tree\")\n",
"plt.sca(axes[1])\n",
"plot_decision_boundary(bag_clf, X_train, y_train)\n",
"plt.title(\"Decision Trees with Bagging\")\n",
"plt.ylabel(\"\")\n",
"save_fig(\"decision_tree_without_and_with_bagging_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Out-of-Bag Evaluation"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.896"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bag_clf = BaggingClassifier(DecisionTreeClassifier(), n_estimators=500,\n",
"                            oob_score=True, n_jobs=-1, random_state=42)\n",
"bag_clf.fit(X_train, y_train)\n",
"bag_clf.oob_score_"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.32352941, 0.67647059],\n",
" [0.3375 , 0.6625 ],\n",
" [1. , 0. ]])"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bag_clf.oob_decision_function_[:3] # probas for the first 3 instances"
]
},
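{
"cell_type": "markdown",
"metadata": {},
"source": [
"The OOB score is the accuracy obtained when each training instance is classified only by the trees that did not see it during training. The following sketch (an addition that rebuilds the model so it runs on its own; `oob_pred` and `oob_accuracy` are illustrative names) checks this by taking the most likely class in `oob_decision_function_` for each training instance and comparing it with `y_train`. It assumes that every training instance was out-of-bag for at least one of the 500 trees, which is virtually certain here."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import make_moons\n",
"from sklearn.ensemble import BaggingClassifier\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"X, y = make_moons(n_samples=500, noise=0.30, random_state=42)\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)\n",
"\n",
"bag_clf = BaggingClassifier(DecisionTreeClassifier(), n_estimators=500,\n",
"                            oob_score=True, n_jobs=-1, random_state=42)\n",
"bag_clf.fit(X_train, y_train)\n",
"\n",
"# most likely class according to the trees for which each instance was OOB\n",
"oob_pred = bag_clf.oob_decision_function_.argmax(axis=1)\n",
"oob_accuracy = (oob_pred == y_train).mean()\n",
"print(oob_accuracy, bag_clf.oob_score_)"
]
},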
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"0.92"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import accuracy_score\n",
"\n",
"y_pred = bag_clf.predict(X_test)\n",
"accuracy_score(y_test, y_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you randomly draw one instance from a dataset of size _m_, each instance in the dataset obviously has probability 1/_m_ of getting picked, and therefore it has a probability 1 – 1/_m_ of _not_ getting picked. If you draw _m_ instances with replacement, all draws are independent and therefore each instance has a probability (1 – 1/_m_)<sup>_m_</sup> of _not_ getting picked. Now let's use the fact that exp(_x_) is equal to the limit of (1 + _x_/_m_)<sup>_m_</sup> as _m_ approaches infinity. Setting _x_ = –1 shows that, if _m_ is large, the ratio of out-of-bag instances will be about exp(–1) ≈ 0.37. So roughly 63% (1 – 0.37) of the training instances will be sampled at least once."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.6323045752290363\n",
"0.6321205588285577\n"
]
}
],
"source": [
"# extra code – shows how to compute the 63% proba\n",
"print(1 - (1 - 1 / 1000) ** 1000)\n",
"print(1 - np.exp(-1))"
]
},
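{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also check this result empirically. This sketch is an illustrative addition, and the dataset size of 100,000, the seed and the variable names are arbitrary. It draws _m_ indices with replacement and measures the fraction of distinct indices that were drawn at least once, which should be close to 1 – exp(–1) ≈ 0.632."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(42)\n",
"n_instances = 100_000\n",
"# draw n_instances indices with replacement, as bootstrapping does\n",
"drawn = rng.integers(0, n_instances, size=n_instances)\n",
"sampled_ratio = len(np.unique(drawn)) / n_instances\n",
"print(sampled_ratio, 1 - np.exp(-1))"
]
},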
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Random Forests"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"\n",
"rnd_clf = RandomForestClassifier(n_estimators=500, max_leaf_nodes=16,\n",
"                                 n_jobs=-1, random_state=42)\n",
"rnd_clf.fit(X_train, y_train)\n",
"y_pred_rf = rnd_clf.predict(X_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A Random Forest is equivalent to a bag of decision trees:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"bag_clf = BaggingClassifier(\n",
"    DecisionTreeClassifier(max_features=\"sqrt\", max_leaf_nodes=16),\n",
"    n_estimators=500, n_jobs=-1, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code – verifies that the predictions are identical\n",
"bag_clf.fit(X_train, y_train)\n",
"y_pred_bag = bag_clf.predict(X_test)\n",
"np.all(y_pred_bag == y_pred_rf) # same predictions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Feature Importance"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.11 sepal length (cm)\n",
"0.02 sepal width (cm)\n",
"0.44 petal length (cm)\n",
"0.42 petal width (cm)\n"
]
}
],
"source": [
"from sklearn.datasets import load_iris\n",
"\n",
"iris = load_iris(as_frame=True)\n",
"rnd_clf = RandomForestClassifier(n_estimators=500, random_state=42)\n",
"rnd_clf.fit(iris.data, iris.target)\n",
"for score, name in zip(rnd_clf.feature_importances_, iris.data.columns):\n",
"    print(round(score, 2), name)"
]
},
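{
"cell_type": "markdown",
"metadata": {},
"source": [
"A Random Forest's `feature_importances_` is essentially the average of the feature importances of its individual trees, rescaled so that the scores sum to 1. The sketch below recomputes that average from `estimators_`. It is an added check that uses 100 trees to keep it fast, and `forest_clf` is a made-up name chosen so that `rnd_clf` is left untouched. It assumes every tree has at least one split, which is the case on the iris dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"\n",
"iris = load_iris(as_frame=True)\n",
"forest_clf = RandomForestClassifier(n_estimators=100, random_state=42)\n",
"forest_clf.fit(iris.data, iris.target)\n",
"\n",
"# average the per-tree importances, then renormalize so they sum to 1\n",
"tree_importances = np.mean([tree.feature_importances_\n",
"                            for tree in forest_clf.estimators_], axis=0)\n",
"tree_importances /= tree_importances.sum()\n",
"np.allclose(tree_importances, forest_clf.feature_importances_)"
]
},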
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – this cell generates and saves Figure 7–6\n",
"\n",
"from sklearn.datasets import fetch_openml\n",
"\n",
"X_mnist, y_mnist = fetch_openml('mnist_784', return_X_y=True, as_frame=False,\n",
"                                parser='auto')\n",
"\n",
"rnd_clf = RandomForestClassifier(n_estimators=100, random_state=42)\n",
"rnd_clf.fit(X_mnist, y_mnist)\n",
"\n",
"heatmap_image = rnd_clf.feature_importances_.reshape(28, 28)\n",
"plt.imshow(heatmap_image, cmap=\"hot\")\n",
"cbar = plt.colorbar(ticks=[rnd_clf.feature_importances_.min(),\n",
"                           rnd_clf.feature_importances_.max()])\n",
"cbar.ax.set_yticklabels(['Not important', 'Very important'], fontsize=14)\n",
"plt.axis(\"off\")\n",
"save_fig(\"mnist_feature_importance_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Boosting\n",
"## AdaBoost"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 720x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 7-8\n",
"\n",
"m = len(X_train)\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)\n",
"for subplot, learning_rate in ((0, 1), (1, 0.5)):\n",
"    sample_weights = np.ones(m) / m\n",
"    plt.sca(axes[subplot])\n",
"    for i in range(5):\n",
"        svm_clf = SVC(C=0.2, gamma=0.6, random_state=42)\n",
"        svm_clf.fit(X_train, y_train, sample_weight=sample_weights * m)\n",
"        y_pred = svm_clf.predict(X_train)\n",
"\n",
"        error_weights = sample_weights[y_pred != y_train].sum()\n",
"        r = error_weights / sample_weights.sum()  # equation 7-1\n",
"        alpha = learning_rate * np.log((1 - r) / r)  # equation 7-2\n",
"        sample_weights[y_pred != y_train] *= np.exp(alpha)  # equation 7-3\n",
"        sample_weights /= sample_weights.sum()  # normalization step\n",
"\n",
"        plot_decision_boundary(svm_clf, X_train, y_train, alpha=0.4)\n",
"    plt.title(f\"learning_rate = {learning_rate}\")\n",
"    if subplot == 0:\n",
"        plt.text(-0.75, -0.95, \"1\", fontsize=16)\n",
"        plt.text(-1.05, -0.95, \"2\", fontsize=16)\n",
"        plt.text(1.0, -0.95, \"3\", fontsize=16)\n",
"        plt.text(-1.45, -0.5, \"4\", fontsize=16)\n",
"        plt.text(1.36, -0.95, \"5\", fontsize=16)\n",
"    else:\n",
"        plt.ylabel(\"\")\n",
"\n",
"save_fig(\"boosting_plot\")\n",
"plt.show()"
]
},
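{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an extra illustration (a hypothetical toy example, not part of the book), the next cell applies the same weight update by hand to five instances with equal weights, one of which is misclassified. The weighted error rate is r = 0.2, so with a learning rate of 1 the predictor weight is α = ln(0.8 / 0.2) = ln 4 ≈ 1.39. The misclassified instance's weight is therefore multiplied by 4, and after normalization it carries half of the total weight (0.5), while each of the other four instances drops to 0.125."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code: a hypothetical toy example of equations 7-1 to 7-3\n",
"import numpy as np\n",
"\n",
"toy_weights = np.ones(5) / 5  # 5 instances, equal initial weights\n",
"toy_misclassified = np.array([False, False, True, False, False])\n",
"toy_learning_rate = 1.0\n",
"\n",
"toy_r = toy_weights[toy_misclassified].sum() / toy_weights.sum()  # equation 7-1\n",
"toy_alpha = toy_learning_rate * np.log((1 - toy_r) / toy_r)  # equation 7-2\n",
"toy_weights[toy_misclassified] *= np.exp(toy_alpha)  # equation 7-3\n",
"toy_weights /= toy_weights.sum()  # normalization step\n",
"\n",
"print(\"weighted error rate r:\", toy_r)\n",
"print(\"predictor weight alpha:\", toy_alpha)\n",
"print(\"new instance weights:\", toy_weights)"
]
},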
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AdaBoostClassifier(base_estimator=DecisionTreeClassifier(max_depth=1),\n",
" learning_rate=0.5, n_estimators=30, random_state=42)"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.ensemble import AdaBoostClassifier\n",
"\n",
"ada_clf = AdaBoostClassifier(\n",
"    DecisionTreeClassifier(max_depth=1), n_estimators=30,\n",
"    learning_rate=0.5, random_state=42)\n",
"ada_clf.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: in case you're curious to see what the decision boundary\n",
"# looks like for the AdaBoost classifier\n",
"plot_decision_boundary(ada_clf, X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Gradient Boosting"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's create a simple quadratic dataset and fit a `DecisionTreeRegressor` to it:"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DecisionTreeRegressor(max_depth=2, random_state=42)"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"from sklearn.tree import DecisionTreeRegressor\n",
"\n",
"np.random.seed(42)\n",
"X = np.random.rand(100, 1) - 0.5\n",
"y = 3 * X[:, 0] ** 2 + 0.05 * np.random.randn(100)  # y = 3x² + Gaussian noise\n",
"\n",
"tree_reg1 = DecisionTreeRegressor(max_depth=2, random_state=42)\n",
"tree_reg1.fit(X, y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's train another decision tree regressor on the residual errors made by the previous predictor:"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DecisionTreeRegressor(max_depth=2, random_state=43)"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y2 = y - tree_reg1.predict(X)\n",
"tree_reg2 = DecisionTreeRegressor(max_depth=2, random_state=43)\n",
"tree_reg2.fit(X, y2)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DecisionTreeRegressor(max_depth=2, random_state=44)"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y3 = y2 - tree_reg2.predict(X)\n",
"tree_reg3 = DecisionTreeRegressor(max_depth=2, random_state=44)\n",
"tree_reg3.fit(X, y3)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.49484029, 0.04021166, 0.75026781])"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_new = np.array([[-0.4], [0.], [0.5]])\n",
"sum(tree.predict(X_new) for tree in (tree_reg1, tree_reg2, tree_reg3))"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 792x792 with 6 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 7-9\n",
"\n",
"def plot_predictions(regressors, X, y, axes, style,\n",
"                     label=None, data_style=\"b.\", data_label=None):\n",
"    x1 = np.linspace(axes[0], axes[1], 500)\n",
"    y_pred = sum(regressor.predict(x1.reshape(-1, 1))\n",
"                 for regressor in regressors)\n",
"    plt.plot(X[:, 0], y, data_style, label=data_label)\n",
"    plt.plot(x1, y_pred, style, linewidth=2, label=label)\n",
"    if label or data_label:\n",
"        plt.legend(loc=\"upper center\")\n",
"    plt.axis(axes)\n",
"\n",
"plt.figure(figsize=(11, 11))\n",
"\n",
"plt.subplot(3, 2, 1)\n",
"plot_predictions([tree_reg1], X, y, axes=[-0.5, 0.5, -0.2, 0.8], style=\"g-\",\n",
" label=\"$h_1(x_1)$\", data_label=\"Training set\")\n",
"plt.ylabel(\"$y$ \", rotation=0)\n",
"plt.title(\"Residuals and tree predictions\")\n",
"\n",
"plt.subplot(3, 2, 2)\n",
"plot_predictions([tree_reg1], X, y, axes=[-0.5, 0.5, -0.2, 0.8], style=\"r-\",\n",
" label=\"$h(x_1) = h_1(x_1)$\", data_label=\"Training set\")\n",
"plt.title(\"Ensemble predictions\")\n",
"\n",
"plt.subplot(3, 2, 3)\n",
"plot_predictions([tree_reg2], X, y2, axes=[-0.5, 0.5, -0.4, 0.6], style=\"g-\",\n",
" label=\"$h_2(x_1)$\", data_style=\"k+\",\n",
" data_label=\"Residuals: $y - h_1(x_1)$\")\n",
"plt.ylabel(\"$y$ \", rotation=0)\n",
"\n",
"plt.subplot(3, 2, 4)\n",
"plot_predictions([tree_reg1, tree_reg2], X, y, axes=[-0.5, 0.5, -0.2, 0.8],\n",
" style=\"r-\", label=\"$h(x_1) = h_1(x_1) + h_2(x_1)$\")\n",
"\n",
"plt.subplot(3, 2, 5)\n",
"plot_predictions([tree_reg3], X, y3, axes=[-0.5, 0.5, -0.4, 0.6], style=\"g-\",\n",
" label=\"$h_3(x_1)$\", data_style=\"k+\",\n",
" data_label=\"Residuals: $y - h_1(x_1) - h_2(x_1)$\")\n",
"plt.xlabel(\"$x_1$\")\n",
"plt.ylabel(\"$y$ \", rotation=0)\n",
"\n",
"plt.subplot(3, 2, 6)\n",
"plot_predictions([tree_reg1, tree_reg2, tree_reg3], X, y,\n",
" axes=[-0.5, 0.5, -0.2, 0.8], style=\"r-\",\n",
" label=\"$h(x_1) = h_1(x_1) + h_2(x_1) + h_3(x_1)$\")\n",
"plt.xlabel(\"$x_1$\")\n",
"\n",
"save_fig(\"gradient_boosting_plot\")\n",
"plt.show()"
]
},
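{
"cell_type": "markdown",
"metadata": {},
"source": [
"The three manual steps above generalize to a loop. The next cell is a minimal sketch of this idea, not Scikit-Learn's actual implementation: at each iteration it fits a shallow tree to the residuals of the current ensemble and adds that tree's predictions, scaled down by a learning rate, to the ensemble. The learning rate of 0.1 and the 50 iterations are arbitrary assumptions, and `X_demo`, `y_demo`, and `predict_demo()` are hypothetical names used only here. Like the manual example, this sketch starts from a prediction of 0, whereas `GradientBoostingRegressor` by default starts from the mean of the targets."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code: a minimal sketch of gradient boosting with shrinkage (hypothetical settings)\n",
"import numpy as np\n",
"from sklearn.tree import DecisionTreeRegressor\n",
"\n",
"# a noisy quadratic dataset similar to the one above, drawn from a separate RNG\n",
"rng = np.random.default_rng(42)\n",
"X_demo = rng.random((100, 1)) - 0.5\n",
"y_demo = 3 * X_demo[:, 0] ** 2 + 0.05 * rng.standard_normal(100)\n",
"\n",
"demo_learning_rate = 0.1  # shrinks each tree's contribution\n",
"demo_trees = []\n",
"y_demo_pred = np.zeros_like(y_demo)  # the ensemble starts by predicting 0\n",
"for _ in range(50):\n",
"    residuals = y_demo - y_demo_pred  # errors the ensemble still makes\n",
"    tree = DecisionTreeRegressor(max_depth=2, random_state=42)\n",
"    tree.fit(X_demo, residuals)\n",
"    y_demo_pred += demo_learning_rate * tree.predict(X_demo)\n",
"    demo_trees.append(tree)\n",
"\n",
"def predict_demo(X_new):\n",
"    return demo_learning_rate * sum(tree.predict(X_new) for tree in demo_trees)\n",
"\n",
"print(\"training MSE:\", np.mean((y_demo - y_demo_pred) ** 2))\n",
"print(predict_demo(np.array([[-0.4], [0.], [0.5]])))"
]
},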
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's try a gradient boosting regressor:"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"GradientBoostingRegressor(learning_rate=1.0, max_depth=2, n_estimators=3,\n",
" random_state=42)"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.ensemble import GradientBoostingRegressor\n",
"\n",
"gbrt = GradientBoostingRegressor(max_depth=2, n_estimators=3,\n",
" learning_rate=1.0, random_state=42)\n",
"gbrt.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"GradientBoostingRegressor(learning_rate=0.05, max_depth=2, n_estimators=500,\n",
" n_iter_no_change=10, random_state=42)"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gbrt_best = GradientBoostingRegressor(\n",
" max_depth=2, learning_rate=0.05, n_estimators=500,\n",
" n_iter_no_change=10, random_state=42)\n",
"gbrt_best.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"92"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gbrt_best.n_estimators_"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 720x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 7-10\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)\n",
"\n",
"plt.sca(axes[0])\n",
"plot_predictions([gbrt], X, y, axes=[-0.5, 0.5, -0.1, 0.8], style=\"r-\",\n",
" label=\"Ensemble predictions\")\n",
"plt.title(f\"learning_rate={gbrt.learning_rate}, \"\n",
" f\"n_estimators={gbrt.n_estimators_}\")\n",
"plt.xlabel(\"$x_1$\")\n",
"plt.ylabel(\"$y$\", rotation=0)\n",
"\n",
"plt.sca(axes[1])\n",
"plot_predictions([gbrt_best], X, y, axes=[-0.5, 0.5, -0.1, 0.8], style=\"r-\")\n",
"plt.title(f\"learning_rate={gbrt_best.learning_rate}, \"\n",
" f\"n_estimators={gbrt_best.n_estimators_}\")\n",
"plt.xlabel(\"$x_1$\")\n",
"\n",
"save_fig(\"gbrt_learning_rate_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"# extra code: not covered in this chapter (it's presented in chapter 2)\n",
"\n",
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"import tarfile\n",
"import urllib.request\n",
"from pathlib import Path\n",
"\n",
"def load_housing_data():\n",
"    tarball_path = Path(\"datasets/housing.tgz\")\n",
"    if not tarball_path.is_file():\n",
"        Path(\"datasets\").mkdir(parents=True, exist_ok=True)\n",
"        url = \"https://github.com/ageron/data/raw/main/housing.tgz\"\n",
"        urllib.request.urlretrieve(url, tarball_path)\n",
"        with tarfile.open(tarball_path) as housing_tarball:\n",
"            housing_tarball.extractall(path=\"datasets\")\n",
"    return pd.read_csv(Path(\"datasets/housing/housing.csv\"))\n",
"\n",
"housing = load_housing_data()\n",
"\n",
"train_set, test_set = train_test_split(housing, test_size=0.2, random_state=42)\n",
"housing_labels = train_set[\"median_house_value\"]\n",
"housing = train_set.drop(\"median_house_value\", axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(steps=[('columntransformer',\n",
" ColumnTransformer(remainder='passthrough',\n",
" transformers=[('ordinalencoder',\n",
" OrdinalEncoder(),\n",
" ['ocean_proximity'])])),\n",
" ('histgradientboostingregressor',\n",
" HistGradientBoostingRegressor(categorical_features=[0],\n",
" random_state=42))])"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.compose import make_column_transformer\n",
"from sklearn.ensemble import HistGradientBoostingRegressor\n",
"from sklearn.preprocessing import OrdinalEncoder\n",
"\n",
"hgb_reg = make_pipeline(\n",
" make_column_transformer((OrdinalEncoder(), [\"ocean_proximity\"]),\n",
" remainder=\"passthrough\"),\n",
" HistGradientBoostingRegressor(categorical_features=[0], random_state=42)\n",
")\n",
"hgb_reg.fit(housing, housing_labels)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"count 10.000000\n",
"mean 47613.307194\n",
"std 1295.422509\n",
"min 44963.213061\n",
"25% 47001.233485\n",
"50% 48000.963564\n",
"75% 48488.093243\n",
"max 49176.368465\n",
"dtype: float64"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code: evaluate the RMSE stats for the hgb_reg model\n",
"\n",
"from sklearn.model_selection import cross_val_score\n",
"\n",
"hgb_rmses = -cross_val_score(hgb_reg, housing, housing_labels,\n",
" scoring=\"neg_root_mean_squared_error\", cv=10)\n",
"pd.Series(hgb_rmses).describe()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Stacking"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"StackingClassifier(cv=5,\n",
" estimators=[('lr', LogisticRegression(random_state=42)),\n",
" ('rf', RandomForestClassifier(random_state=42)),\n",
" ('svc', SVC(probability=True, random_state=42))],\n",
" final_estimator=RandomForestClassifier(random_state=43))"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.ensemble import StackingClassifier\n",
"\n",
"stacking_clf = StackingClassifier(\n",
" estimators=[\n",
" ('lr', LogisticRegression(random_state=42)),\n",
" ('rf', RandomForestClassifier(random_state=42)),\n",
" ('svc', SVC(probability=True, random_state=42))\n",
" ],\n",
" final_estimator=RandomForestClassifier(random_state=43),\n",
" cv=5 # number of cross-validation folds\n",
")\n",
"stacking_clf.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.928"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"stacking_clf.score(X_test, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Exercise solutions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. to 7."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. If you have trained five different models and they all achieve 95% precision, you can try combining them into a voting ensemble, which will often give you even better results. It works better if the models are very different (e.g., an SVM classifier, a Decision Tree classifier, a Logistic Regression classifier, and so on). It is even better if they are trained on different training instances (that's the whole point of bagging and pasting ensembles), but if not, this will still be effective as long as the models are very different.\n",
"2. A hard voting classifier just counts the votes of each classifier in the ensemble and picks the class that gets the most votes. A soft voting classifier computes the average estimated class probability for each class and picks the class with the highest probability. This gives high-confidence votes more weight and often performs better, but it works only if every classifier is able to estimate class probabilities (e.g., for the SVM classifiers in Scikit-Learn you must set `probability=True`).\n",
"3. It is quite possible to speed up training of a bagging ensemble by distributing it across multiple servers, since each predictor in the ensemble is independent of the others. The same goes for pasting ensembles and Random Forests, for the same reason. However, each predictor in a boosting ensemble is built based on the previous predictor, so training is necessarily sequential, and you will not gain anything by distributing training across multiple servers. Regarding stacking ensembles, all the predictors in a given layer are independent of each other, so they can be trained in parallel on multiple servers. However, the predictors in one layer can only be trained after the predictors in the previous layer have all been trained.\n",
"4. With out-of-bag evaluation, each predictor in a bagging ensemble is evaluated using instances that it was not trained on (they were held out). This makes it possible to have a fairly unbiased evaluation of the ensemble without the need for an additional validation set. Thus, you have more instances available for training, and your ensemble can perform slightly better.\n",
"5. When you are growing a tree in a Random Forest, only a random subset of the features is considered for splitting at each node. This is true as well for Extra-Trees, but they go one step further: rather than searching for the best possible thresholds, like regular Decision Trees do, they use random thresholds for each feature. This extra randomness acts like a form of regularization: if a Random Forest overfits the training data, Extra-Trees might perform better. Moreover, since Extra-Trees don't search for the best possible thresholds, they are much faster to train than Random Forests. However, they are neither faster nor slower than Random Forests when making predictions.\n",
"6. If your AdaBoost ensemble underfits the training data, you can try increasing the number of estimators or reducing the regularization hyperparameters of the base estimator. You may also try slightly increasing the learning rate.\n",
"7. If your Gradient Boosting ensemble overfits the training set, you should try decreasing the learning rate. You could also use early stopping to find the right number of predictors (you probably have too many)."
]
},
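{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make answer 2 concrete, the next cell is a hypothetical example with made-up class probabilities from three classifiers for a single instance. Two classifiers slightly prefer class 1, while the third is very confident about class 0. Hard voting picks class 1 (two votes out of three), but soft voting picks class 0, because the average probabilities are 0.60 for class 0 and 0.40 for class 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code: hard vs soft voting on hypothetical probabilities\n",
"import numpy as np\n",
"\n",
"probas = np.array([\n",
"    [0.45, 0.55],  # classifier A: slightly prefers class 1\n",
"    [0.40, 0.60],  # classifier B: slightly prefers class 1\n",
"    [0.95, 0.05],  # classifier C: very confident about class 0\n",
"])\n",
"\n",
"votes = probas.argmax(axis=1)  # predicted class of each classifier\n",
"hard_vote = np.bincount(votes).argmax()  # most frequent class\n",
"soft_vote = probas.mean(axis=0).argmax()  # highest average probability\n",
"print(\"hard voting:\", hard_vote)\n",
"print(\"soft voting:\", soft_vote)"
]
},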
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 8. Voting Classifier"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Exercise: _Load the MNIST data and split it into a training set, a validation set, and a test set (e.g., use 50,000 instances for training, 10,000 for validation, and 10,000 for testing)._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The MNIST dataset was loaded earlier. The dataset is already split into a training set (the first 60,000 instances) and a test set (the last 10,000 instances), and the training set is already shuffled. So all we need to do is to take the first 50,000 instances for the new training set, the next 10,000 for the validation set, and the last 10,000 for the test set:"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"X_train, y_train = X_mnist[:50_000], y_mnist[:50_000]\n",
"X_valid, y_valid = X_mnist[50_000:60_000], y_mnist[50_000:60_000]\n",
"X_test, y_test = X_mnist[60_000:], y_mnist[60_000:]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Exercise: _Then train various classifiers, such as a Random Forest classifier, an Extra-Trees classifier, and an SVM._"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.ensemble import ExtraTreesClassifier\n",
"from sklearn.svm import LinearSVC\n",
"from sklearn.neural_network import MLPClassifier"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: The `LinearSVC` has a `dual` hyperparameter whose default value will change from `True` to `\"auto\"` in Scikit-Learn 1.5. To ensure this notebook continues to produce the same outputs, I'm setting it explicitly to `True`. Please see the [documentation](https://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVC.html) for more details."
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"random_forest_clf = RandomForestClassifier(n_estimators=100, random_state=42)\n",
"extra_trees_clf = ExtraTreesClassifier(n_estimators=100, random_state=42)\n",
"svm_clf = LinearSVC(max_iter=100, tol=20, dual=True, random_state=42)\n",
"mlp_clf = MLPClassifier(random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training the RandomForestClassifier(random_state=42)\n",
"Training the ExtraTreesClassifier(random_state=42)\n",
"Training the LinearSVC(max_iter=100, random_state=42, tol=20)\n",
"Training the MLPClassifier(random_state=42)\n"
]
}
],
"source": [
"estimators = [random_forest_clf, extra_trees_clf, svm_clf, mlp_clf]\n",
"for estimator in estimators:\n",
" print(\"Training the\", estimator)\n",
" estimator.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0.9736, 0.9743, 0.8662, 0.966]"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[estimator.score(X_valid, y_valid) for estimator in estimators]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The linear SVM is far outperformed by the other classifiers. However, let's keep it for now since it may improve the voting classifier's performance."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Exercise: _Next, try to combine \\[the classifiers\\] into an ensemble that outperforms them all on the validation set, using a soft or hard voting classifier._"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.ensemble import VotingClassifier"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"named_estimators = [\n",
" (\"random_forest_clf\", random_forest_clf),\n",
" (\"extra_trees_clf\", extra_trees_clf),\n",
" (\"svm_clf\", svm_clf),\n",
" (\"mlp_clf\", mlp_clf),\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"voting_clf = VotingClassifier(named_estimators)"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"VotingClassifier(estimators=[('random_forest_clf',\n",
" RandomForestClassifier(random_state=42)),\n",
" ('extra_trees_clf',\n",
" ExtraTreesClassifier(random_state=42)),\n",
" ('svm_clf',\n",
" LinearSVC(max_iter=100, random_state=42, tol=20)),\n",
" ('mlp_clf', MLPClassifier(random_state=42))])"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"voting_clf.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9758"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"voting_clf.score(X_valid, y_valid)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `VotingClassifier` made a clone of each classifier, and it trained the clones using class indices as the labels, not the original class names. Therefore, to evaluate these clones we need to provide class indices as well. To convert the classes to class indices, we can use a `LabelEncoder`:"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.preprocessing import LabelEncoder\n",
"\n",
"encoder = LabelEncoder()\n",
"y_valid_encoded = encoder.fit_transform(y_valid)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"However, in the case of MNIST, it's simpler to just convert the class names to integers, since each digit label is equal to its class index:"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"y_valid_encoded = y_valid.astype(np.int64)"
]
},
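{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check of the two points above, here is a minimal, self-contained sketch. It assumes Scikit-Learn's small bundled `load_digits()` dataset (8x8 images) as a stand-in for MNIST, with its labels converted to strings to mimic MNIST's labels, and two small tree ensembles instead of the classifiers trained above. The trained clones in `estimators_` should report the class indices 0 to 9 in their `classes_`, while the `VotingClassifier` itself keeps the original string labels. Note that casting to integers only matches `LabelEncoder` because every label is a single digit: a label such as `'10'` would sort before `'2'` as a string."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_digits\n",
"from sklearn.ensemble import (ExtraTreesClassifier, RandomForestClassifier,\n",
"                              VotingClassifier)\n",
"from sklearn.preprocessing import LabelEncoder\n",
"\n",
"# Small bundled 8x8 digits dataset (a stand-in for MNIST), with string labels\n",
"X, y_int = load_digits(return_X_y=True)\n",
"y = y_int.astype(str)  # '0', '1', ..., '9', like the MNIST labels\n",
"\n",
"voting = VotingClassifier([\n",
"    (\"rf\", RandomForestClassifier(n_estimators=20, random_state=42)),\n",
"    (\"et\", ExtraTreesClassifier(n_estimators=20, random_state=42)),\n",
"])\n",
"voting.fit(X, y)\n",
"\n",
"print(voting.classes_)                 # the original string labels\n",
"print(voting.estimators_[0].classes_)  # the clone was trained on indices 0 to 9\n",
"\n",
"# Sorting single-character digit strings gives the same order as sorting\n",
"# the digits themselves, so both encodings agree\n",
"y_label_encoded = LabelEncoder().fit_transform(y)\n",
"y_cast = y.astype(np.int64)\n",
"print((y_label_encoded == y_cast).all())"
]
},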
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's evaluate the classifier clones:"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0.9736, 0.9743, 0.8662, 0.966]"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[estimator.score(X_valid, y_valid_encoded)\n",
" for estimator in voting_clf.estimators_]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's remove the SVM to see if performance improves. It is possible to remove an estimator by setting it to `\"drop\"` using `set_params()` like this:"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"VotingClassifier(estimators=[('random_forest_clf',\n",
" RandomForestClassifier(random_state=42)),\n",
" ('extra_trees_clf',\n",
" ExtraTreesClassifier(random_state=42)),\n",
" ('svm_clf', 'drop'),\n",
" ('mlp_clf', MLPClassifier(random_state=42))])"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"voting_clf.set_params(svm_clf=\"drop\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This updated the list of estimators:"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('random_forest_clf', RandomForestClassifier(random_state=42)),\n",
" ('extra_trees_clf', ExtraTreesClassifier(random_state=42)),\n",
" ('svm_clf', 'drop'),\n",
" ('mlp_clf', MLPClassifier(random_state=42))]"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"voting_clf.estimators"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"However, it did not update the list of _trained_ estimators:"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[RandomForestClassifier(random_state=42),\n",
" ExtraTreesClassifier(random_state=42),\n",
" LinearSVC(max_iter=100, random_state=42, tol=20),\n",
" MLPClassifier(random_state=42)]"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"voting_clf.estimators_"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'random_forest_clf': RandomForestClassifier(random_state=42),\n",
" 'extra_trees_clf': ExtraTreesClassifier(random_state=42),\n",
" 'svm_clf': LinearSVC(max_iter=100, random_state=42, tol=20),\n",
" 'mlp_clf': MLPClassifier(random_state=42)}"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"voting_clf.named_estimators_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So we can either fit the `VotingClassifier` again, or just remove the SVM from the list of trained estimators, both in `estimators_` and `named_estimators_`:"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"svm_clf_trained = voting_clf.named_estimators_.pop(\"svm_clf\")\n",
"voting_clf.estimators_.remove(svm_clf_trained)"
]
},
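{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a minimal, self-contained sketch comparing both options: removing the trained clone by hand after `set_params()`, and refitting after `set_params()`. It assumes the bundled `load_digits()` dataset and two small tree ensembles (with the hypothetical names `\"rf\"` and `\"et\"`) instead of the MNIST models. Under these assumptions, both ensembles should end up with a single trained estimator and make identical predictions, since the remaining `RandomForestClassifier` has a fixed `random_state`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import load_digits\n",
"from sklearn.ensemble import (ExtraTreesClassifier, RandomForestClassifier,\n",
"                              VotingClassifier)\n",
"\n",
"X, y = load_digits(return_X_y=True)\n",
"\n",
"def make_voting_clf():\n",
"    return VotingClassifier([\n",
"        (\"rf\", RandomForestClassifier(n_estimators=20, random_state=42)),\n",
"        (\"et\", ExtraTreesClassifier(n_estimators=20, random_state=42)),\n",
"    ])\n",
"\n",
"# Option 1: mark \"et\" as dropped, then remove its trained clone by hand\n",
"manual = make_voting_clf().fit(X, y)\n",
"manual.set_params(et=\"drop\")\n",
"n_before = len(manual.estimators_)  # still 2: set_params() leaves fitted clones alone\n",
"et_trained = manual.named_estimators_.pop(\"et\")\n",
"manual.estimators_.remove(et_trained)\n",
"\n",
"# Option 2: mark \"et\" as dropped, then refit from scratch\n",
"refit = make_voting_clf().set_params(et=\"drop\").fit(X, y)\n",
"\n",
"print(n_before, len(manual.estimators_), len(refit.estimators_))\n",
"print((manual.predict(X) == refit.predict(X)).all())"
]
},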
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's evaluate the `VotingClassifier` again:"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9769"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"voting_clf.score(X_valid, y_valid)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
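"A bit better! The SVM was hurting performance. Now let's try using a soft voting classifier. We do not actually need to retrain the classifier, since the voting mode is only used at prediction time and all the remaining estimators support `predict_proba()`, which soft voting relies on. We can just set `voting` to `\"soft\"`:"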
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"voting_clf.voting = \"soft\""
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9724"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"voting_clf.score(X_valid, y_valid)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Nope, hard voting wins in this case."
]
},
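{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see why no retraining was needed, here is a minimal sketch that recomputes soft voting by hand: it predicts the class with the highest average `predict_proba()` output across the trained clones. It assumes the bundled `load_digits()` dataset, a `train_test_split()` hold-out set, and two small tree ensembles instead of the MNIST models above, so its accuracies say nothing about which voting mode wins on MNIST."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_digits\n",
"from sklearn.ensemble import (ExtraTreesClassifier, RandomForestClassifier,\n",
"                              VotingClassifier)\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"X, y = load_digits(return_X_y=True)\n",
"X_tr, X_val, y_tr, y_val = train_test_split(X, y, random_state=42)\n",
"\n",
"voting = VotingClassifier([\n",
"    (\"rf\", RandomForestClassifier(n_estimators=20, random_state=42)),\n",
"    (\"et\", ExtraTreesClassifier(n_estimators=20, random_state=42)),\n",
"]).fit(X_tr, y_tr)\n",
"hard_pred = voting.predict(X_val)  # majority vote over the predicted classes\n",
"\n",
"voting.voting = \"soft\"  # no refit needed: it is only used at prediction time\n",
"soft_pred = voting.predict(X_val)\n",
"\n",
"# Soft voting by hand: average the clones' class probabilities, take the argmax\n",
"mean_proba = np.mean([clf.predict_proba(X_val) for clf in voting.estimators_],\n",
"                     axis=0)\n",
"manual_soft_pred = voting.classes_[mean_proba.argmax(axis=1)]\n",
"\n",
"print(\"hard:\", (hard_pred == y_val).mean(), \"soft:\", (soft_pred == y_val).mean())\n",
"print((manual_soft_pred == soft_pred).all())"
]
},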
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Exercise: _Once you have found \\[an ensemble that performs better than the individual predictors\\], try it on the test set. How much better does it perform compared to the individual classifiers?_"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9727"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"voting_clf.voting = \"hard\"\n",
"voting_clf.score(X_test, y_test)"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0.968, 0.9703, 0.965]"
]
},
"execution_count": 62,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[estimator.score(X_test, y_test.astype(np.int64))\n",
" for estimator in voting_clf.estimators_]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The voting classifier reduced the test error rate of the best individual model from about 3.0% to about 2.7%, which means roughly 8% fewer errors."
]
},
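{
"cell_type": "markdown",
"metadata": {},
"source": [
"The arithmetic behind this comparison, using the test accuracies printed above: 0.9703 for the best individual model (the `ExtraTreesClassifier`) and 0.9727 for the hard voting classifier. The relative reduction is the drop in error rate divided by the original error rate:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test-set accuracies reported above\n",
"best_single_accuracy = 0.9703  # ExtraTreesClassifier, the best individual model\n",
"voting_accuracy = 0.9727       # hard voting classifier, without the SVM\n",
"\n",
"best_single_error = 1 - best_single_accuracy\n",
"voting_error = 1 - voting_accuracy\n",
"relative_reduction = (best_single_error - voting_error) / best_single_error\n",
"print(f\"{best_single_error:.2%} -> {voting_error:.2%}: \"\n",
"      f\"{relative_reduction:.1%} fewer errors\")  # 2.97% -> 2.73%: 8.1% fewer errors"
]
},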
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 9. Stacking Ensemble"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Exercise: _Run the individual classifiers from the previous exercise to make predictions on the validation set, and create a new training set with the resulting predictions: each training instance is a vector containing the set of predictions from all your classifiers for an image, and the target is the image's class. Train a classifier on this new training set._"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
"X_valid_predictions = np.empty((len(X_valid), len(estimators)), dtype=object)\n",
"\n",
"for index, estimator in enumerate(estimators):\n",
" X_valid_predictions[:, index] = estimator.predict(X_valid)"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([['3', '3', '3', '3'],\n",
" ['8', '8', '8', '8'],\n",
" ['6', '6', '6', '6'],\n",
" ...,\n",
" ['5', '5', '5', '5'],\n",
" ['6', '6', '6', '6'],\n",
" ['8', '8', '8', '8']], dtype=object)"
]
},
"execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_valid_predictions"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"RandomForestClassifier(n_estimators=200, oob_score=True, random_state=42)"
]
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rnd_forest_blender = RandomForestClassifier(n_estimators=200, oob_score=True,\n",
" random_state=42)\n",
"rnd_forest_blender.fit(X_valid_predictions, y_valid)"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9722"
]
},
"execution_count": 66,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rnd_forest_blender.oob_score_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You could fine-tune this blender or try other types of blenders (e.g., an `MLPClassifier`), then select the best one using cross-validation, as always."
]
},
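{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a minimal sketch of that selection step. It assumes the bundled `load_digits()` dataset instead of MNIST, two small tree ensembles as base models, and two candidate blenders (a `RandomForestClassifier` and an `MLPClassifier`) with illustrative, untuned hyperparameters. As in the exercise, the base models are trained on one part of the data, and their predictions on a held-out part form the blender's training set, on which each candidate blender is scored with 3-fold cross-validation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_digits\n",
"from sklearn.ensemble import ExtraTreesClassifier, RandomForestClassifier\n",
"from sklearn.model_selection import cross_val_score, train_test_split\n",
"from sklearn.neural_network import MLPClassifier\n",
"\n",
"X, y = load_digits(return_X_y=True)\n",
"# Base models learn from one half, the blender from the other (hold-out) half\n",
"X_base, X_hold, y_base, y_hold = train_test_split(X, y, test_size=0.5,\n",
"                                                  random_state=42)\n",
"base_models = [RandomForestClassifier(n_estimators=20, random_state=42),\n",
"               ExtraTreesClassifier(n_estimators=20, random_state=42)]\n",
"for model in base_models:\n",
"    model.fit(X_base, y_base)\n",
"\n",
"# One column per base model: its predicted class for each hold-out image\n",
"X_blend = np.column_stack([model.predict(X_hold) for model in base_models])\n",
"\n",
"candidate_blenders = {\n",
"    \"random_forest\": RandomForestClassifier(n_estimators=50, random_state=42),\n",
"    \"mlp\": MLPClassifier(max_iter=500, random_state=42),\n",
"}\n",
"cv_scores = {name: cross_val_score(blender, X_blend, y_hold, cv=3).mean()\n",
"             for name, blender in candidate_blenders.items()}\n",
"print(cv_scores)\n",
"print(\"best blender:\", max(cv_scores, key=cv_scores.get))"
]
},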
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Exercise: _Congratulations, you have just trained a blender, and together with the classifiers they form a stacking ensemble! Now let's evaluate the ensemble on the test set. For each image in the test set, make predictions with all your classifiers, then feed the predictions to the blender to get the ensemble's predictions. How does it compare to the voting classifier you trained earlier?_"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
"X_test_predictions = np.empty((len(X_test), len(estimators)), dtype=object)\n",
"\n",
"for index, estimator in enumerate(estimators):\n",
" X_test_predictions[:, index] = estimator.predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [],
"source": [
"y_pred = rnd_forest_blender.predict(X_test_predictions)"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9705"
]
},
"execution_count": 69,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"accuracy_score(y_test, y_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This stacking ensemble does not perform as well as the voting classifier we trained earlier."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Exercise: _Now try again using a `StackingClassifier` instead: do you get better performance? If so, why?_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since `StackingClassifier` uses K-Fold cross-validation, we don't need a separate validation set, so let's join the training set and the validation set into a bigger training set:"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
"X_train_full, y_train_full = X_mnist[:60_000], y_mnist[:60_000]"
]
},
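{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a simplified sketch of the idea behind the K-fold approach, built by hand with `cross_val_predict()`. It assumes the bundled `load_digits()` dataset (1,797 images, 10 classes), two small tree ensembles as base models, and a random forest blender; it illustrates the concept and is not Scikit-Learn's actual implementation. Each base model's out-of-fold class probabilities become the blender's features, so no row is ever predicted by a model that saw it during training, and the base models are then refit on all the data for later predictions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_digits\n",
"from sklearn.ensemble import ExtraTreesClassifier, RandomForestClassifier\n",
"from sklearn.model_selection import cross_val_predict\n",
"\n",
"X, y = load_digits(return_X_y=True)\n",
"base_models = [RandomForestClassifier(n_estimators=20, random_state=42),\n",
"               ExtraTreesClassifier(n_estimators=20, random_state=42)]\n",
"\n",
"# Out-of-fold class probabilities (5 folds): shape (n_samples, n_models * n_classes)\n",
"oof_features = np.hstack([\n",
"    cross_val_predict(model, X, y, cv=5, method=\"predict_proba\")\n",
"    for model in base_models\n",
"])\n",
"\n",
"# The blender is trained on the out-of-fold features...\n",
"blender = RandomForestClassifier(n_estimators=50, random_state=42)\n",
"blender.fit(oof_features, y)\n",
"\n",
"# ...while the base models are refit on all the data for later predictions\n",
"for model in base_models:\n",
"    model.fit(X, y)\n",
"\n",
"# Models trained: 2 base models x 5 folds + 2 refits + 1 blender = 13\n",
"print(oof_features.shape)  # (1797, 20)"
]
},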
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's create and train the stacking classifier on the full training set:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Warning**: the following cell will take quite a while to run (15-30 minutes depending on your hardware), as it uses K-fold cross-validation with 5 folds by default. It will train each of the 4 classifiers 5 times on 80% of the full training set to make the out-of-fold predictions, plus one last time on the full training set, and lastly it will train the final estimator (the blender) on the predictions. That's a total of 25 models to train!"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"StackingClassifier(estimators=[('random_forest_clf',\n",
" RandomForestClassifier(random_state=42)),\n",
" ('extra_trees_clf',\n",
" ExtraTreesClassifier(random_state=42)),\n",
" ('svm_clf',\n",
" LinearSVC(max_iter=100, random_state=42,\n",
" tol=20)),\n",
" ('mlp_clf', MLPClassifier(random_state=42))],\n",
" final_estimator=RandomForestClassifier(n_estimators=200,\n",
" oob_score=True,\n",
" random_state=42))"
]
},
"execution_count": 71,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"stack_clf = StackingClassifier(named_estimators,\n",
" final_estimator=rnd_forest_blender)\n",
"stack_clf.fit(X_train_full, y_train_full)"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9784"
]
},
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"stack_clf.score(X_test, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
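"The `StackingClassifier` significantly outperforms the custom stacking implementation we tried earlier (0.9784 vs. 0.9705 test accuracy), and it even beats the voting classifier (0.9727). This is mainly for two reasons:\n",
"\n",
"* Since it did not need a separate validation set, the `StackingClassifier` was trained on a larger dataset: all 60,000 training instances, whereas our custom ensemble trained the base classifiers on 50,000 instances and the blender on only 10,000.\n",
"* For each base classifier, it used `predict_proba()` if available, or else `decision_function()` if available, or else `predict()`. This gave the blender much more nuanced inputs to work with than the bare class predictions we used earlier."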
]
},
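{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can check which method was used for each base estimator through the fitted `stack_method_` attribute. Here is a minimal sketch, assuming the bundled `load_digits()` dataset, a small random forest and a `LinearSVC` as base estimators, and `cv=3` to keep it fast. The forest should contribute 10 class probabilities per instance, and the SVM, which has no `predict_proba()`, 10 decision scores, so the blender sees 20 features (as returned by `transform()`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import load_digits\n",
"from sklearn.ensemble import RandomForestClassifier, StackingClassifier\n",
"from sklearn.svm import LinearSVC\n",
"\n",
"X, y = load_digits(return_X_y=True)\n",
"\n",
"stack = StackingClassifier(\n",
"    [(\"rf\", RandomForestClassifier(n_estimators=20, random_state=42)),\n",
"     (\"svm\", LinearSVC(dual=True, random_state=42))],\n",
"    final_estimator=RandomForestClassifier(n_estimators=50, random_state=42),\n",
"    cv=3,\n",
")\n",
"stack.fit(X, y)  # the SVM may emit a ConvergenceWarning on unscaled pixels\n",
"\n",
"print(stack.stack_method_)       # ['predict_proba', 'decision_function']\n",
"print(stack.transform(X).shape)  # (1797, 20): 10 probabilities + 10 scores"
]
},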
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And that's all for today, congratulations on finishing the chapter and the exercises!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
},
"nav_menu": {
"height": "252px",
"width": "333px"
},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}