{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Chapter 7 – Ensemble Learning and Random Forests**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_This notebook contains all the sample code and solutions to the exercises in chapter 7._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/ageron/handson-ml2/blob/master/07_ensemble_learning_and_random_forests.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/pizzaz93/handson-ml2/blob/add-kaggle-badge/07_ensemble_learning_and_random_forests.ipynb\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" /></a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, let's import a few common modules, ensure Matplotlib plots figures inline, and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated, so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥0.20."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Python ≥3.5 is required\n",
"import sys\n",
"assert sys.version_info >= (3, 5)\n",
"\n",
"# Scikit-Learn ≥0.20 is required\n",
"import sklearn\n",
"assert sklearn.__version__ >= \"0.20\"\n",
"\n",
"# Common imports\n",
"import numpy as np\n",
"import os\n",
"\n",
"# to make this notebook's output stable across runs\n",
"np.random.seed(42)\n",
"\n",
"# To plot pretty figures\n",
"%matplotlib inline\n",
"import matplotlib as mpl\n",
"import matplotlib.pyplot as plt\n",
"mpl.rc('axes', labelsize=14)\n",
"mpl.rc('xtick', labelsize=12)\n",
"mpl.rc('ytick', labelsize=12)\n",
"\n",
"# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n",
"CHAPTER_ID = \"ensembles\"\n",
"IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"os.makedirs(IMAGES_PATH, exist_ok=True)\n",
"\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
"    path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n",
"    print(\"Saving figure\", fig_id)\n",
"    if tight_layout:\n",
"        plt.tight_layout()\n",
"    plt.savefig(path, format=fig_extension, dpi=resolution)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Voting classifiers"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"heads_proba = 0.51\n",
"coin_tosses = (np.random.rand(10000, 10) < heads_proba).astype(np.int32)\n",
"cumulative_heads_ratio = np.cumsum(coin_tosses, axis=0) / np.arange(1, 10001).reshape(-1, 1)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saving figure law_of_large_numbers_plot\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 576x252 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(8,3.5))\n",
"plt.plot(cumulative_heads_ratio)\n",
"plt.plot([0, 10000], [0.51, 0.51], \"k--\", linewidth=2, label=\"51%\")\n",
"plt.plot([0, 10000], [0.5, 0.5], \"k-\", label=\"50%\")\n",
"plt.xlabel(\"Number of coin tosses\")\n",
"plt.ylabel(\"Heads ratio\")\n",
"plt.legend(loc=\"lower right\")\n",
"plt.axis([0, 10000, 0.42, 0.58])\n",
"save_fig(\"law_of_large_numbers_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"from sklearn.datasets import make_moons\n",
"\n",
"X, y = make_moons(n_samples=500, noise=0.30, random_state=42)\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Note**: to be future-proof, we set `solver=\"lbfgs\"`, `n_estimators=100`, and `gamma=\"scale\"` since these will be the default values in upcoming Scikit-Learn versions."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.ensemble import VotingClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.svm import SVC\n",
"\n",
"log_clf = LogisticRegression(solver=\"lbfgs\", random_state=42)\n",
"rnd_clf = RandomForestClassifier(n_estimators=100, random_state=42)\n",
"svm_clf = SVC(gamma=\"scale\", random_state=42)\n",
"\n",
"voting_clf = VotingClassifier(\n",
"    estimators=[('lr', log_clf), ('rf', rnd_clf), ('svc', svm_clf)],\n",
"    voting='hard')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"VotingClassifier(estimators=[('lr', LogisticRegression(random_state=42)),\n",
" ('rf', RandomForestClassifier(random_state=42)),\n",
" ('svc', SVC(random_state=42))])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"voting_clf.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LogisticRegression 0.864\n",
"RandomForestClassifier 0.896\n",
"SVC 0.896\n",
"VotingClassifier 0.912\n"
]
}
],
"source": [
"from sklearn.metrics import accuracy_score\n",
"\n",
"for clf in (log_clf, rnd_clf, svm_clf, voting_clf):\n",
"    clf.fit(X_train, y_train)\n",
"    y_pred = clf.predict(X_test)\n",
"    print(clf.__class__.__name__, accuracy_score(y_test, y_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Note**: the results in this notebook may differ slightly from the book, as Scikit-Learn algorithms sometimes get tweaked."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Soft voting:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"VotingClassifier(estimators=[('lr', LogisticRegression(random_state=42)),\n",
" ('rf', RandomForestClassifier(random_state=42)),\n",
" ('svc', SVC(probability=True, random_state=42))],\n",
" voting='soft')"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"log_clf = LogisticRegression(solver=\"lbfgs\", random_state=42)\n",
"rnd_clf = RandomForestClassifier(n_estimators=100, random_state=42)\n",
"svm_clf = SVC(gamma=\"scale\", probability=True, random_state=42)\n",
"\n",
"voting_clf = VotingClassifier(\n",
"    estimators=[('lr', log_clf), ('rf', rnd_clf), ('svc', svm_clf)],\n",
"    voting='soft')\n",
"voting_clf.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LogisticRegression 0.864\n",
"RandomForestClassifier 0.896\n",
"SVC 0.896\n",
"VotingClassifier 0.92\n"
]
}
],
"source": [
"from sklearn.metrics import accuracy_score\n",
"\n",
"for clf in (log_clf, rnd_clf, svm_clf, voting_clf):\n",
"    clf.fit(X_train, y_train)\n",
"    y_pred = clf.predict(X_test)\n",
"    print(clf.__class__.__name__, accuracy_score(y_test, y_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Bagging ensembles"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.ensemble import BaggingClassifier\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"bag_clf = BaggingClassifier(\n",
"    DecisionTreeClassifier(), n_estimators=500,\n",
"    max_samples=100, bootstrap=True, random_state=42)\n",
"bag_clf.fit(X_train, y_train)\n",
"y_pred = bag_clf.predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.904\n"
]
}
],
"source": [
"from sklearn.metrics import accuracy_score\n",
"print(accuracy_score(y_test, y_pred))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.856\n"
]
}
],
"source": [
"tree_clf = DecisionTreeClassifier(random_state=42)\n",
"tree_clf.fit(X_train, y_train)\n",
"y_pred_tree = tree_clf.predict(X_test)\n",
"print(accuracy_score(y_test, y_pred_tree))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"from matplotlib.colors import ListedColormap\n",
"\n",
"def plot_decision_boundary(clf, X, y, axes=[-1.5, 2.45, -1, 1.5], alpha=0.5, contour=True):\n",
"    x1s = np.linspace(axes[0], axes[1], 100)\n",
"    x2s = np.linspace(axes[2], axes[3], 100)\n",
"    x1, x2 = np.meshgrid(x1s, x2s)\n",
"    X_new = np.c_[x1.ravel(), x2.ravel()]\n",
"    y_pred = clf.predict(X_new).reshape(x1.shape)\n",
"    custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])\n",
"    plt.contourf(x1, x2, y_pred, alpha=0.3, cmap=custom_cmap)\n",
"    if contour:\n",
"        custom_cmap2 = ListedColormap(['#7d7d58','#4c4c7f','#507d50'])\n",
"        plt.contour(x1, x2, y_pred, cmap=custom_cmap2, alpha=0.8)\n",
"    plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"yo\", alpha=alpha)\n",
"    plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"bs\", alpha=alpha)\n",
"    plt.axis(axes)\n",
"    plt.xlabel(r\"$x_1$\", fontsize=18)\n",
"    plt.ylabel(r\"$x_2$\", fontsize=18, rotation=0)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saving figure decision_tree_without_and_with_bagging_plot\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 720x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig, axes = plt.subplots(ncols=2, figsize=(10,4), sharey=True)\n",
"plt.sca(axes[0])\n",
"plot_decision_boundary(tree_clf, X, y)\n",
"plt.title(\"Decision Tree\", fontsize=14)\n",
"plt.sca(axes[1])\n",
"plot_decision_boundary(bag_clf, X, y)\n",
"plt.title(\"Decision Trees with Bagging\", fontsize=14)\n",
"plt.ylabel(\"\")\n",
"save_fig(\"decision_tree_without_and_with_bagging_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Random Forests"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"bag_clf = BaggingClassifier(\n",
"    DecisionTreeClassifier(max_features=\"sqrt\", max_leaf_nodes=16),\n",
"    n_estimators=500, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"bag_clf.fit(X_train, y_train)\n",
"y_pred = bag_clf.predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"\n",
"rnd_clf = RandomForestClassifier(n_estimators=500, max_leaf_nodes=16, random_state=42)\n",
"rnd_clf.fit(X_train, y_train)\n",
"\n",
"y_pred_rf = rnd_clf.predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.0"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.sum(y_pred == y_pred_rf) / len(y_pred) # very similar predictions"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sepal length (cm) 0.11249225099876375\n",
"sepal width (cm) 0.02311928828251033\n",
"petal length (cm) 0.4410304643639577\n",
"petal width (cm) 0.4233579963547682\n"
]
}
],
"source": [
"from sklearn.datasets import load_iris\n",
"iris = load_iris()\n",
"rnd_clf = RandomForestClassifier(n_estimators=500, random_state=42)\n",
"rnd_clf.fit(iris[\"data\"], iris[\"target\"])\n",
"for name, score in zip(iris[\"feature_names\"], rnd_clf.feature_importances_):\n",
"    print(name, score)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.11249225, 0.02311929, 0.44103046, 0.423358 ])"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rnd_clf.feature_importances_"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(6, 4))\n",
"\n",
"for i in range(15):\n",
"    tree_clf = DecisionTreeClassifier(max_leaf_nodes=16, random_state=42 + i)\n",
"    indices_with_replacement = np.random.randint(0, len(X_train), len(X_train))\n",
"    tree_clf.fit(X_train[indices_with_replacement], y_train[indices_with_replacement])\n",
"    plot_decision_boundary(tree_clf, X, y, axes=[-1.5, 2.45, -1, 1.5], alpha=0.02, contour=False)\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Out-of-Bag evaluation"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8986666666666666"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bag_clf = BaggingClassifier(\n",
"    DecisionTreeClassifier(), n_estimators=500,\n",
"    bootstrap=True, oob_score=True, random_state=40)\n",
"bag_clf.fit(X_train, y_train)\n",
"bag_clf.oob_score_"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.32275132, 0.67724868],\n",
" [0.34117647, 0.65882353],\n",
" [1. , 0. ],\n",
" [0. , 1. ],\n",
" [0. , 1. ],\n",
" [0.09497207, 0.90502793],\n",
" [0.31147541, 0.68852459],\n",
" [0.01754386, 0.98245614],\n",
" [0.97109827, 0.02890173],\n",
" [0.97765363, 0.02234637],\n",
" [0.74404762, 0.25595238],\n",
" [0. , 1. ],\n",
" [0.7173913 , 0.2826087 ],\n",
" [0.85026738, 0.14973262],\n",
" [0.97222222, 0.02777778],\n",
" [0.0625 , 0.9375 ],\n",
" [0. , 1. ],\n",
" [0.97837838, 0.02162162],\n",
" [0.94642857, 0.05357143],\n",
" [1. , 0. ],\n",
" [0.01704545, 0.98295455],\n",
" [0.39473684, 0.60526316],\n",
" [0.88700565, 0.11299435],\n",
" [1. , 0. ],\n",
" [0.97790055, 0.02209945],\n",
" [0. , 1. ],\n",
" [0.99428571, 0.00571429],\n",
" [1. , 0. ],\n",
" [0. , 1. ],\n",
" [0.62569832, 0.37430168],\n",
" [0. , 1. ],\n",
" [1. , 0. ],\n",
" [0. , 1. ],\n",
" [0. , 1. ],\n",
" [0.13402062, 0.86597938],\n",
" [1. , 0. ],\n",
" [0. , 1. ],\n",
" [0.38251366, 0.61748634],\n",
" [0. , 1. ],\n",
" [1. , 0. ],\n",
" [0.27093596, 0.72906404],\n",
" [0.34146341, 0.65853659],\n",
" [1. , 0. ],\n",
" [1. , 0. ],\n",
" [0. , 1. ],\n",
" [1. , 0. ],\n",
" [1. , 0. ],\n",
" [0. , 1. ],\n",
" [1. , 0. ],\n",
" [0.00531915, 0.99468085],\n",
" [0.98843931, 0.01156069],\n",
" [0.91428571, 0.08571429],\n",
" [0.97282609, 0.02717391],\n",
" [0.98019802, 0.01980198],\n",
" [0. , 1. ],\n",
" [0.07361963, 0.92638037],\n",
" [0.98019802, 0.01980198],\n",
" [0.0052356 , 0.9947644 ],\n",
" [0. , 1. ],\n",
" [0. , 1. ],\n",
" [0.97790055, 0.02209945],\n",
" [0.8 , 0.2 ],\n",
" [0.42424242, 0.57575758],\n",
" [1. , 0. ],\n",
" [0. , 1. ],\n",
" [0.66477273, 0.33522727],\n",
" [1. , 0. ],\n",
" [1. , 0. ],\n",
" [0.86781609, 0.13218391],\n",
" [1. , 0. ],\n",
" [0.56725146, 0.43274854],\n",
" [0.1576087 , 0.8423913 ],\n",
" [0.66492147, 0.33507853],\n",
" [0.91709845, 0.08290155],\n",
" [0. , 1. ],\n",
" [0.16759777, 0.83240223],\n",
" [0.87434555, 0.12565445],\n",
" [1. , 0. ],\n",
" [0. , 1. ],\n",
" [0.995 , 0.005 ],\n",
" [0. , 1. ],\n",
" [0.07878788, 0.92121212],\n",
" [0.05418719, 0.94581281],\n",
" [0.29015544, 0.70984456],\n",
" [1. , 0. ],\n",
" [0. , 1. ],\n",
" [0.83040936, 0.16959064],\n",
" [0.01092896, 0.98907104],\n",
" [0. , 1. ],\n",
" [0. , 1. ],\n",
" [0.21465969, 0.78534031],\n",
" [1. , 0. ],\n",
" [0. , 1. ],\n",
" [0. , 1. ],\n",
" [0. , 1. ],\n",
" [0.94660194, 0.05339806],\n",
" [0.77094972, 0.22905028],\n",
" [0. , 1. ],\n",
" [1. , 0. ],\n",
" [0.16574586, 0.83425414],\n",
" [0.65306122, 0.34693878],\n",
" [0. , 1. ],\n",
" [0.02564103, 0.97435897],\n",
" [0.50555556, 0.49444444],\n",
" [1. , 0. ],\n",
" [0.03208556, 0.96791444],\n",
" [0.99435028, 0.00564972],\n",
" [0.23699422, 0.76300578],\n",
" [0.49509804, 0.50490196],\n",
" [0.9947644 , 0.0052356 ],\n",
" [0.00555556, 0.99444444],\n",
" [0.98963731, 0.01036269],\n",
" [0.26153846, 0.73846154],\n",
" [0.92972973, 0.07027027],\n",
" [1. , 0. ],\n",
" [1. , 0. ],\n",
" [0. , 1. ],\n",
" [0. , 1. ],\n",
" [0.80113636, 0.19886364],\n",
" [1. , 0. ],\n",
" [0.0106383 , 0.9893617 ],\n",
" [1. , 0. ],\n",
" [1. , 0. ],\n",
" [1. , 0. ],\n",
" [0.98181818, 0.01818182],\n",
" [1. , 0. ],\n",
" [0.01036269, 0.98963731],\n",
" [0.97752809, 0.02247191],\n",
" [0.99453552, 0.00546448],\n",
" [0.01960784, 0.98039216],\n",
" [0.17857143, 0.82142857],\n",
" [0.98387097, 0.01612903],\n",
" [0.29533679, 0.70466321],\n",
" [0.98295455, 0.01704545],\n",
" [0. , 1. ],\n",
" [0.00561798, 0.99438202],\n",
" [0.75690608, 0.24309392],\n",
" [0.38624339, 0.61375661],\n",
" [0.40625 , 0.59375 ],\n",
" [0.87368421, 0.12631579],\n",
" [0.92462312, 0.07537688],\n",
" [0.05181347, 0.94818653],\n",
" [0.82802548, 0.17197452],\n",
" [0.01546392, 0.98453608],\n",
" [0. , 1. ],\n",
" [0.02298851, 0.97701149],\n",
" [0.9726776 , 0.0273224 ],\n",
" [1. , 0. ],\n",
" [1. , 0. ],\n",
" [0.01041667, 0.98958333],\n",
" [0. , 1. ],\n",
" [0.03804348, 0.96195652],\n",
" [0.02040816, 0.97959184],\n",
" [1. , 0. ],\n",
" [1. , 0. ],\n",
" [0.94915254, 0.05084746],\n",
" [1. , 0. ],\n",
" [1. , 0. ],\n",
" [0.99462366, 0.00537634],\n",
" [0. , 1. ],\n",
" [0.39378238, 0.60621762],\n",
" [0.33152174, 0.66847826],\n",
" [0.00609756, 0.99390244],\n",
" [0. , 1. ],\n",
" [0.3172043 , 0.6827957 ],\n",
" [1. , 0. ],\n",
" [1. , 0. ],\n",
" [0. , 1. ],\n",
" [1. , 0. ],\n",
" [0.00588235, 0.99411765],\n",
" [0. , 1. ],\n",
" [0.98924731, 0.01075269],\n",
" [0. , 1. ],\n",
" [0. , 1. ],\n",
" [1. , 0. ],\n",
" [0. , 1. ],\n",
" [0.62893082, 0.37106918],\n",
" [0.92344498, 0.07655502],\n",
" [0. , 1. ],\n",
" [0.99526066, 0.00473934],\n",
" [1. , 0. ],\n",
" [0.98888889, 0.01111111],\n",
" [0. , 1. ],\n",
" [0. , 1. ],\n",
" [1. , 0. ],\n",
" [0.06989247, 0.93010753],\n",
" [1. , 0. ],\n",
" [0.03608247, 0.96391753],\n",
" [0. , 1. ],\n",
" [1. , 0. ],\n",
" [0. , 1. ],\n",
" [0.02185792, 0.97814208],\n",
" [1. , 0. ],\n",
" [0.95808383, 0.04191617],\n",
" [0.78362573, 0.21637427],\n",
" [0.56650246, 0.43349754],\n",
" [0. , 1. ],\n",
" [0.18023256, 0.81976744],\n",
" [1. , 0. ],\n",
" [0.93121693, 0.06878307],\n",
" [0.97175141, 0.02824859],\n",
" [1. , 0. ],\n",
" [0.00531915, 0.99468085],\n",
" [0. , 1. ],\n",
" [0.43010753, 0.56989247],\n",
" [0.85858586, 0.14141414],\n",
" [0. , 1. ],\n",
" [0. , 1. ],\n",
" [1. , 0. ],\n",
" [0.00558659, 0.99441341],\n",
" [0. , 1. ],\n",
" [0.96923077, 0.03076923],\n",
" [0. , 1. ],\n",
" [0.21649485, 0.78350515],\n",
" [0. , 1. ],\n",
" [1. , 0. ],\n",
" [0. , 1. ],\n",
" [0. , 1. ],\n",
" [0.98477157, 0.01522843],\n",
" [0.8 , 0.2 ],\n",
" [0.99441341, 0.00558659],\n",
" [0. , 1. ],\n",
" [0.09497207, 0.90502793],\n",
" [0.99492386, 0.00507614],\n",
" [0.01714286, 0.98285714],\n",
" [0. , 1. ],\n",
" [0.02747253, 0.97252747],\n",
" [1. , 0. ],\n",
" [0.77005348, 0.22994652],\n",
" [0. , 1. ],\n",
" [0.90229885, 0.09770115],\n",
" [0.98387097, 0.01612903],\n",
" [0.22222222, 0.77777778],\n",
" [0.20348837, 0.79651163],\n",
" [1. , 0. ],\n",
" [0. , 1. ],\n",
" [0. , 1. ],\n",
" [0. , 1. ],\n",
" [0.20338983, 0.79661017],\n",
" [0.98181818, 0.01818182],\n",
" [0. , 1. ],\n",
" [1. , 0. ],\n",
" [0.98969072, 0.01030928],\n",
" [0. , 1. ],\n",
" [0.48663102, 0.51336898],\n",
" [1. , 0. ],\n",
" [0.00529101, 0.99470899],\n",
" [1. , 0. ],\n",
" [0. , 1. ],\n",
" [0. , 1. ],\n",
" [0.08379888, 0.91620112],\n",
" [0.12352941, 0.87647059],\n",
" [0.99415205, 0.00584795],\n",
" [0.03517588, 0.96482412],\n",
" [1. , 0. ],\n",
" [0.39790576, 0.60209424],\n",
" [0.05434783, 0.94565217],\n",
" [0.53191489, 0.46808511],\n",
" [0.51898734, 0.48101266],\n",
" [0. , 1. ],\n",
" [1. , 0. ],\n",
" [0. , 1. ],\n",
" [0. , 1. ],\n",
" [0.60869565, 0.39130435],\n",
" [0. , 1. ],\n",
" [1. , 0. ],\n",
" [0.24157303, 0.75842697],\n",
" [0.81578947, 0.18421053],\n",
" [0.08717949, 0.91282051],\n",
" [0.99453552, 0.00546448],\n",
" [0.82142857, 0.17857143],\n",
" [0. , 1. ],\n",
" [0. , 1. ],\n",
" [0.11904762, 0.88095238],\n",
" [0.04188482, 0.95811518],\n",
" [0. , 1. ],\n",
" [1. , 0. ],\n",
" [0.89150943, 0.10849057],\n",
" [0.19230769, 0.80769231],\n",
" [0.95238095, 0.04761905],\n",
" [0.00515464, 0.99484536],\n",
" [0.59375 , 0.40625 ],\n",
" [0.07692308, 0.92307692],\n",
" [0.99484536, 0.00515464],\n",
" [0.83684211, 0.16315789],\n",
" [0. , 1. ],\n",
" [0.99484536, 0.00515464],\n",
" [0.95360825, 0.04639175],\n",
" [0. , 1. ],\n",
" [0. , 1. ],\n",
" [1. , 0. ],\n",
" [0. , 1. ],\n",
" [1. , 0. ],\n",
" [0.26395939, 0.73604061],\n",
" [0.98461538, 0.01538462],\n",
" [1. , 0. ],\n",
" [0. , 1. ],\n",
" [0.00574713, 0.99425287],\n",
" [0.85142857, 0.14857143],\n",
" [0. , 1. ],\n",
" [1. , 0. ],\n",
" [0.75301205, 0.24698795],\n",
" [0.8969697 , 0.1030303 ],\n",
" [1. , 0. ],\n",
" [0.75555556, 0.24444444],\n",
" [0.48863636, 0.51136364],\n",
" [0. , 1. ],\n",
" [0.92473118, 0.07526882],\n",
" [0. , 1. ],\n",
" [1. , 0. ],\n",
" [0.87709497, 0.12290503],\n",
" [1. , 0. ],\n",
" [1. , 0. ],\n",
" [0.74752475, 0.25247525],\n",
" [0.09146341, 0.90853659],\n",
" [0.42268041, 0.57731959],\n",
" [0.22395833, 0.77604167],\n",
" [0. , 1. ],\n",
" [0.87046632, 0.12953368],\n",
" [0.78212291, 0.21787709],\n",
" [0.00507614, 0.99492386],\n",
" [1. , 0. ],\n",
" [1. , 0. ],\n",
" [1. , 0. ],\n",
" [0. , 1. ],\n",
" [0.02884615, 0.97115385],\n",
" [0.96 , 0.04 ],\n",
" [0.93478261, 0.06521739],\n",
" [1. , 0. ],\n",
" [0.50731707, 0.49268293],\n",
" [1. , 0. ],\n",
" [0. , 1. ],\n",
" [1. , 0. ],\n",
" [0.01604278, 0.98395722],\n",
" [1. , 0. ],\n",
" [1. , 0. ],\n",
" [1. , 0. ],\n",
" [0. , 1. ],\n",
" [0.96987952, 0.03012048],\n",
" [0. , 1. ],\n",
" [0.05172414, 0.94827586],\n",
" [0. , 1. ],\n",
" [0. , 1. ],\n",
" [1. , 0. ],\n",
" [1. , 0. ],\n",
" [0. , 1. ],\n",
" [0.99494949, 0.00505051],\n",
" [0.01675978, 0.98324022],\n",
" [1. , 0. ],\n",
" [0.14583333, 0.85416667],\n",
" [0. , 1. ],\n",
" [0.00546448, 0.99453552],\n",
" [0. , 1. ],\n",
" [0.41836735, 0.58163265],\n",
" [0.13095238, 0.86904762],\n",
" [0.22110553, 0.77889447],\n",
" [1. , 0. ],\n",
" [0.97647059, 0.02352941],\n",
" [0.21195652, 0.78804348],\n",
" [0.98882682, 0.01117318],\n",
" [0. , 1. ],\n",
" [0. , 1. ],\n",
" [1. , 0. ],\n",
" [0.96428571, 0.03571429],\n",
" [0.34554974, 0.65445026],\n",
" [0.98235294, 0.01764706],\n",
" [1. , 0. ],\n",
" [0. , 1. ],\n",
" [0.99465241, 0.00534759],\n",
" [0. , 1. ],\n",
" [0.06043956, 0.93956044],\n",
" [0.98214286, 0.01785714],\n",
" [1. , 0. ],\n",
" [0.03108808, 0.96891192],\n",
" [0.58854167, 0.41145833]])"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bag_clf.oob_decision_function_"
]
},
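{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each row of `oob_decision_function_` holds the class probabilities estimated for one training instance, averaged over the predictors for which that instance was out-of-bag. As a minimal sanity check (a sketch, assuming every training instance was out-of-bag for at least one of the 500 predictors, so that no row is `NaN`; `oob_pred` is a name introduced here for illustration), picking the most likely class in each row and comparing it with `y_train` should reproduce `bag_clf.oob_score_`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Most likely class for each training instance according to the out-of-bag votes\n",
"oob_pred = bag_clf.classes_[bag_clf.oob_decision_function_.argmax(axis=1)]\n",
"\n",
"# Fraction of OOB predictions that match the labels, shown next to oob_score_\n",
"np.mean(oob_pred == y_train), bag_clf.oob_score_"
]
},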
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.912"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import accuracy_score\n",
"y_pred = bag_clf.predict(X_test)\n",
"accuracy_score(y_test, y_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Feature importance"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Warning:** since Scikit-Learn 0.24, `fetch_openml()` returns a Pandas `DataFrame` by default. To get NumPy arrays instead and keep the same code as in the book, we set `as_frame=False`."
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import fetch_openml\n",
"\n",
"mnist = fetch_openml('mnist_784', version=1, as_frame=False)\n",
"mnist.target = mnist.target.astype(np.uint8)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"RandomForestClassifier(random_state=42)"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rnd_clf = RandomForestClassifier(n_estimators=100, random_state=42)\n",
"rnd_clf.fit(mnist[\"data\"], mnist[\"target\"])"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"def plot_digit(data):\n",
" image = data.reshape(28, 28)\n",
"    plt.imshow(image, cmap=mpl.cm.hot,\n",
" interpolation=\"nearest\")\n",
" plt.axis(\"off\")"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saving figure mnist_feature_importance_plot\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYsAAAEYCAYAAACtEtpmAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWgklEQVR4nO3dfbBdVX3G8ecJxFISQiQiNLxFBLQZXzIUB7UyorYSWixWrYoKSkVFRxBHcLTFii/gqFMGGFsZCqgRivG1ikixDEWHF63B8YU6WiAmBFJCQkyAhIiQ1T/OvvZwSbJ/l6yV3PPL9zNzh3vPfc7a+9xcznPW3mfd7VKKAADYkinbewcAAJMfZQEA6EVZAAB6URYAgF6UBQCgF2UBAOhFWQDAFtje3/aDtnfa3vuyPVEWACY929fY/ugmbj/W9j22d2617VLKnaWU6aWUR1ttI8r2HNul1uO1faTtuyJZygLAKPi8pONte9ztx0u6vJTySHSglsXS0vbeb8oCwCj4N0l7SDpi7AbbT5Z0jKQFtqfY/oDtO2zfZ/vLtvfocmOvxt9q+05J19m+yvYpwxuw/TPbrxy/4fGv5m1fb/vjtm/qDk9daXuW7ctt32/7R7bnDN2/2D7V9mLbq2x/2vaU7ntTbJ9pe6nte20vsL375vZb0ve7Ydd0236B7afbvq573Ku6/Zg5tP0ltk/vHt9a2wtt72J7mqSrJc3uxnrQ9uzN/QN4S3/uY5rN3wIBdjDrShn/6n3C5s+fX1atWjWh+9xyyy3/LWnD0E0XlVIuGvvC9r9o8Jx1Uvf1OyS9s5Qyz/Zpkl4v6TWSVkq6QNKMUspx3RP3ryV9UdI7JW2U9ApJ7yulHN6N9VwNnoz/qJTy8PB+Dd1/ainlEdvXS9pX0lGSVkm6WdLOkt4l6XpJl0p6tJRyYnf/0t3+aknTJV0r6VOllItt/62kv5P0ckn3SlogaV0p5fjN7Pdew/vSjX+QpKdpUCQzJH1N0o9LKad131/Sjf3K7ud7o6TzSykX2j5S0mWllH239G+j7gECQFWrVq3SokWLJnQf2xtKKYdtIfIFSVfZPqWU8pCkE7rbJOkdkt5dSrmrG+ssSXfaPn7o/meVUtZ13/+mpAttH1xKuU2Dw1kLxxfFFnyulHJHN9bVkuaWUq7tvv6KpI+Ny3+ylLJa0mrb50k6TtLFkt4o6dxSyuLuvh+UdKvtEzez34/bkVLK7ZJu775caftcSR8eF7uglLK8G+NKSfOCj/P3KAsADRRJ4dMIsRFLucH2SknH2v4vSc+T9Kru2wdI+obtjUN3eVSDV+Jjlg2N9VvbX5b0Jtsf0eDJ+zUT2J0VQ58/tImvp4/LLxv6fKmkscM9s7uvh7+38+b2e1NsP1WDmdQRknbT4PTCb8bF7hn6fP3Q9sM4ZwGgkUcm+BGyQIMZxfGSvltKGXuSXibp6FLKzKGPXUopdw/dd/xh9S9o8Mr+ZZLWl1JunuADnIj9hj7fX9Ly7vPlGhTd8Pce0WPLp2zm8zGf6G5/TillhqQ3SYoeSgyfaqAsADQwNrNoUhZ/Jult+v9DUJJ0oaSzbR8gSbb3tH3sFvdwUA4bJf2jBucFWjrD9pNt7yfpPZIWdrdfIem9tp9me7qkczQ4HLa5H8jKbp8PHLptN0kPanDSex9JZ0xgv1ZImjV2Un1LKAsADbQpi1LKEkk3SZom6VtD3zq/+/q7th+Q9ANJhweGXCDp2ZIuC+3AE/dNSbdI+omkqyRd0t1+qQZF9X0NTlxvkHTKJu4vSSqlrJd0tqQbba+x/XxJH5F0qKS13dhfj+5UKeWXGhTW4m483g0FIKbGu6EOO2xeWbTo2gndx97zlp4T3NXZPkHS20spL2q4jSLp4O5E9MjiBDeABuqf4K7N9q4avN31n7f3vowCDkMBaKDZOYsqbB+lwfH/FZL+dZtufEQxswDQyOSdWZRSrtHgvMe22NZWH9abDCgLAA
0UDZY5IAvKAkADk/+cBSaGsgDQAGWRDWUBoBHKIhPKAkADzCyyoSwANEBZZENZAGiAssiGsgDQAGWRDWWxDexUaZyplbYVWbYfGWdDf0S/C2S2Jd75vy1RFplQFgAaYGaRDWUBoAHKIhvKAkADlEU2lAWABiiLbCgLAI1QFplQFgAaYGaRDWUBoAHKIhvKAkADXM8iG8piK0UWr+0SyEQWys0MZGYHMpF93i2QeUYgc2Ig8+1KmcjP+eeBzG8DGSm2KHHHfbpkZpENZQGgEcoiE8oCQAPMLLKhLAA0QFlkQ1kAaICyyIayANAAZZENZQGgEcoiE8oCQAPMLLKhLAA0QFlkQ1lsRvTqdpGFYJEFbgcEMrMCmYMCmSMCmQMDme8EMp8KZCKL2yLuCWSeFMhEr+4X+bd/uOL2RgtlkQ1lAaAByiIbygJAI5RFJpQFgAaYWWRDWQBoYKPqnY3CZEBZAGhkx/2buxlRFgAa4DBUNpQFgAYoi2woCwANUBbZpCuLyGK6qYHMrsHt7R3IPC+QiRzdfWkg87rIg4usAAycm3zmssBP6VXreyOPfqN/mEv7I7o9kFkWyCwJZKTY6dsHApmVgczoLdyjLLJJVxYAJgPKIhvKAkAjlEUmlAWABphZZENZAGiAssiGsgDQAGWRDWUBoAHKIhvKAkAjlEUmlAWABphZZJOuLKYEMpEFd5GroEmxq+AdFshErkw3O5A5J7B66w2B1WtzTghs7P39C+7uDyy4+2pgU5Er3O0eyEQW0kXWLErSikAmsgiw1rYm15/toyyySVcWACYDyiIbygJAA5RFNpQFgEYm14ExbB3KAkADzCyyoSwANEBZZENZAGiAssiGsgDQCGWRCWUBoAFmFtmMVFlEroIXWbwVGWdmIBPNRa5yNjeQWRTIzIlkpgdCzwhk/qI/MuPm/syLb+jPrOmPhK5wt1cgszqQkWKLANcFMvcFMpHFhNH93jYoi2xGqiwAjArKIhvKAkAbhXUWmVAWANrYuL13ADVRFgDqK2IBdzKUBYD6KIt0KAsAbXAYKhXKAkB9zCzSoSwAtMHMIhXKAkB9zCzSGamyiFwyNSJyydSpwbEi/z/8MJB5VqVMZKWv9glkAquqNS2QeaDOMHsEMq8LZL4XyCwOZKTYiunI6v1ZgczyQGbSoSxSGamyADAiijgMlQxlAaANZhapUBYA6uOcRTqUBYA2OAyVCmUBoD5mFulQFgDaYGaRCmUBoD5mFulQFgDqoyzSGamyiMxqI4vpIgulIgv3JOnwQCZw9VHNOyEQCqwou3tpYJzItT5/HsgE3PPT/sze+wYGCqxcWxvY1gsDm7otkJGk2wOZXQOZyIK70GLLyYbDUKmMVFkAGBHMLNKhLAC0QVmkQlkAqI8/95EOZQGgDWYWqVAWAOpjZpEOZQGgDWYWqVAWAOrj3VDpUBYA2uAwVCojVRaRK+VFXsxExokuylsRyNwTGeiAQOaY/sgD/9Sf+cFd/ZnI458XWAG5d2SVZORJJXCpvDmBf9gfBLb13P6IpNgV9X4RyAQuJjh6z7vMLNIZqbIAMEIoi1QoCwD18W6odCgLAG0ws0iFsgBQHzOLdCgLAG0ws0iFsgBQH++GSoeyANAGh6FSoSwA1MfMIp0dsix2C2RmB8c6NpB56VMCoTMDmWX9kWcEFuVFrgR3yBmB0KH9kfuP689sCFwq7qmv7c/c+5/9ma/2R0K/H1LsKolLApmVgczIPe9SFunskGUBYBvgMFQqlAWA+phZpENZAGiDmUUqlAWA+phZpENZAGiDskiFsgBQH3/uIx3KAkAbzCxSoSwA1MfMIp10ZRG5wts+gcys4PbmRkLrA5k/qbMxv7o/c8jawLbOCWQ+1B9ZHRhmTuAqeLq4P7IoMMxOgczBgYwUW9w4LThWSswsUklXFgAmAd4NlQ5lAaANDkOlQlkAqI+ZRTqUBYD6KIt0KAsAbXAYKhXKAkB9zCzSoSwAtMHMIhXKAk
B9zCzSSVcWkd/PhwOZ6JXyfhnI7P2sQOj9gczNgcyUQOYNgczOpT+z3r2RyHo7fak/csfL+zM/DmxqSSBzdyAjSW8
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_digit(rnd_clf.feature_importances_)\n",
"\n",
"cbar = plt.colorbar(ticks=[rnd_clf.feature_importances_.min(), rnd_clf.feature_importances_.max()])\n",
"cbar.ax.set_yticklabels(['Not important', 'Very important'])\n",
"\n",
"save_fig(\"mnist_feature_importance_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# AdaBoost"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AdaBoostClassifier(base_estimator=DecisionTreeClassifier(max_depth=1),\n",
" learning_rate=0.5, n_estimators=200, random_state=42)"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.ensemble import AdaBoostClassifier\n",
"\n",
"ada_clf = AdaBoostClassifier(\n",
" DecisionTreeClassifier(max_depth=1), n_estimators=200,\n",
" algorithm=\"SAMME.R\", learning_rate=0.5, random_state=42)\n",
"ada_clf.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAY8AAAEWCAYAAACe8xtsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAABtPklEQVR4nO39eZwcV33oDX9PVfW+zCaPRqPRSJZsSZZsjLEx2AZsSLAxbwgQSB4C5AIJj28c4F4gTkLeSx5jQgKX+AGycMn1G8AJkBsCGEKC2QIYvIHxbkmWZGsbjWZGo9l6mV6r6rx/9HRPd093T+/L6Hw/H4GnupbT1VXnd367kFKiUCgUCkUtaJ0egEKhUCh6DyU8FAqFQlEzSngoFAqFomaU8FAoFApFzSjhoVAoFIqaUcJDoVAoFDWjhIdCoVAoaqbjwkMI8V4hxKNCiKQQ4u4K+71TCGEJIaJ5/25o20AVCoVCkcPo9ACAKeBjwE2AZ519H5ZSvqz1Q1IoFApFJTouPKSU9wAIIa4Cxjo8HIVCoVBUQceFR41cIYSYAxaALwEfl1KapXYUQtwC3ALg83mv3Lv3ovaNUqFYh0TiNEI4AJG3VSJlGrd7W6eGpVAU8NhjT89JKS8o9VkvCY+fAZcCp4D9wFcBE/h4qZ2llHcBdwFcddXl8pFHvt+mYSrOZ6LRQ8zP30siMYnbPcbQ0Gvx+/et2e/UqTtJp0MYRl9um2mGcDj62L79tnYOWaEoi65vOVXus447zKtFSnlcSnlCSmlLKZ8BPgq8udPjUiiyRKOHmJz8HOl0CKdzlHQ6xOTk54hGD63Zd2jotZjmEqYZQkob0wxhmksMDb22AyNXKGqnlzSPYiSFOr9ig1Dt6r3T5yxmfv5eDKM/p01k/39+/t411/L79zE2dmvBmEZGfrvpY1IoWkXHhYcQwlgZhw7oQgg3YBb7MoQQNwOPSynPCiH2An8GfK3tA1a0lOzq3TD6C1bvY2O31j2xtuKcpUgkJnE6Rwu26XqARGKy5P5+/z4lLBQ9S8eFB/Bh4Pa8v98O3CGE+AJwCNgnpZwAfgW4WwjhB84CXwb+st2DPZ/pttV7Lee0bYvl5YOYZhjDCOJwjDR0zlK43WNr/BiWFcHtVkGEio1Hx30eUsqPSClF0b+PSCknpJT+FcGBlPI2KeVmKaVPSrlTSvn/SCnTnR7/+UIt9vxGSCQm0fVAwbZKq/dqiEQOsLz8LJaVQNcDWFaC5eVniUQONDrcApQfQ3E+0XHhoegN8jUCITQMow/D6Gd+/t6mXsftHsOyIgXbGl29W1YYITQ0zQ0INM2NEBqWFW5wtIVk/RgORx+p1BQOR1/TTWMKRbfQDWYrRQ9Qqz2/XoaGXsvk5Ody57esCKa5xMjIb9d9Tl0PkE4vYdsJhHAhZRKQazScZqD8GIrzBSU8FFXRLnt+K6KQAoHL0DQfqdQMlhVG14O43Tvw+XY2ceSKRmiHP03RXJTwUFRFKzSCcjR79T409Fri8c/h9+8vGLvyRXQH7YqGUzQX5fNQVEUv2/N7eeznA+3ypymai9I8FFXTy/b8Xh77Rqdd/jRFc1HCQ6HIQ9ne24/Kj+lNlPDoctRkVjv13rN82zs4mJ//MWfPfp3+/usZHX1Hzfdd/XbV0U5/mqJ5KJ9HF9OuxLyNRCP3LGt7t6wkkchjAOj6AMvLz9R839VvVz3KJ9WbKM2jQVq5umxFqY6NTrX3rNTvlrW9R6MH0TT3SlKhxLIiOQdutfdd/Xa1oXxSvYcSHg3Q6hDDjehIbLUpp5p7Vu530zQPlhXBssJoWiaB0LaTGEaw5vu+EX87hSIfZbZqgFaHGLaiVEcnaYcpp5p7Vu53A4FpLiGEEykT2Hbmn8dzcc33faP9dgpFMUp4NEArivjls9EK7bUjnt/j2UMo9CDnzn2bpaUHicWOr7ln5X43KZMrWuOlpNOLCAGBwJVomrPm+77RfjuFohhltm
qAVocY9nLDoEo+hXyaKWyj0UMsLPwAj2cvqdQ06fQ8phlifPwDBfes0u/m9+9j9+5PFozf4dhc833v5d9OoagGJTwaoB0hhr3oSFzPp9AqYZuv2Xi9mbpVphkiHj8CvC63XzW/WzPuey/+dgpFtSjh0QAbdXXZqFO7XKSRbacwzSWgNcK2Ws1mo/5uCkU7UcKjQTba6rIZEWTlJnHLmmrppF2LGXGj/W4KRbtRwkNRQDPyE9bzKbRq0u7mTGWVba7YaCjhoSigFqd2uQmxU5N4o+aoVk3wquS4YiOihIeigGpNP+tNiJ3yKdSr2TRjgi8nfFS2uWIjooSHooBqtYb1JsRe8yk0OsFXEj4q21yxEVHCQ1FAtVrDRpsQG/0+lYRPJW1O+UIUvYoSHoo1VKM11JMg2Uip9GqPq/caQjhZXLwPKVPoehCP5yJ03VV1Dkol4bN16y0ltblg8GrlC1H0LKo8iaIuai2/UW9dq1qOa+QaqdQMphkGHNh2nHD4YeLxk1WXE8nWskqlzhEKPcT8/PdYWroPIVxlS47H40dU+1VFz6I0D0Vd1OoUr9enUMtx+f04otGDWFYYIZxMTd3N7t2frHgNt3s7TucW4vHnMM0whhHA5RqtWgMYGnotJ09+gljsOLoeQAgHphkhmZwiGj1UUps7c+auDWX6U5xfKOGhqJtanOL1+hRqOS6zzUEk8thKP44AUiZYXPxZbgKvdA3D0HA6LwBASptUaqqq7waZe+F0jpBKzSFlCk0L4vNdhq67ygrIrOnPtlM5oaVpTny+y6q+rkLRKZTwUFSkWQ7deotI1nKc2z3G7Ox3sKwlpLTQNBea5sfpHKyo4eT7O0AHwLbjOJ2bKgqdYqRMMTBwA0JoedvssgKyGm1FoehWlM9jAxGNHuLUqTs5cuT9nDp1Z8N9MprZf6PeEuX5xyWTZzl37j84e/ZrzM5+i6NH/6RgLB7PHpLJ01hWGiGcWFZixccwWnYCz/d3WJZJIjFBInEK207jcIzU9H1r7eGR1VYMIwik0TQPweA1eDw7lN9D0fUo4bFBaEWjpWb236i3T3X2ONtOMj//n6TTszidI2hagKWlBzl58hO57xiPH8Hl2oamOVZMR+6VezFVdgLP+jv6+69DymWE0NA0D4YRxOvdWdP3rUdAZrWVoaHX0N9/LS7XsPJ7KHoCZbbaILQiizkSeWZlRR7Jha86nZvqntjqTRz0+/fhcg3jdm8BWOktDkII0um53HdMJCYJBq/K+TyEcCFlglRqoewEnu/vMIw+XK4xpATbjpBMzhKPP0cqNQOwrsmunsz6VveEUShahRIeG4RmJ+1Fo4dIJk8jpUDXg9h2gkjkUbzevfh8O5sx5JqIRJ4hkTgNCDTNhWEMYhheTDOc+47ZiTgQuIp4/PlctNXAwCvWTOBZX040egBNO4rPdxmGEcSyEit76EQijwICp3NkTQ5GOV9QrQKym4s5KhSVUGarDUKze2bPz9+Lx7MXkEiZRNNcgCAWO9z2VqpZQZZBQ0qTdHqKdHqxIJEvazbSdRd9fS+lr+8avN5djI6+c835siY+v/8KTDPC0tKD6PpgTtOSUgICkHi9uwtMdtWaCKvxQdVrzlMoOo3SPDYIzV7BJhKTeDw7MIxAXu5DEMPoq7vwYK1RW9lj5ua+h21LNM2FbccAJ1JCKnUWj+fFOWFWrdko38RnGH0Eg9ewvHyAZPIk/f3XAZKlpftxOkfwenfnwnezmlw1JsJaCi32Wh0whQKU8NgwNLuSbdYE5HRekJs8TTOEw9G3zpFrqadibf4xUko0zb3iBPdiWUsAOJ1D7NjxoYJzVDMRF5v4XK5hnM4bSKWm2L37fwJw6tSdZX0R1ZgIVSXdjcvtt/uZmNDXbB8ft7jjjmjXnbdVdFx4CCHeC7wTuAz4P1LKd1bY9wPAnwAe4BvArVLKZBuG2RM0cwVbrMnE4yeJxw/jcm3j1Kk7a8r3qGciLdYObDuBYfSh62
76+n49J8halXNSSpPLhPCOFPhJsoK1+PiNVjhyI2KaJg899BNmZ6drOu5nP7uZ4eFwie1Bvv7179Y9nladt14Mo7J
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_decision_boundary(ada_clf, X, y)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saving figure boosting_plot\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAsgAAAEYCAYAAABBfQDEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOz9d3hk13WnC7/7xErIuQM6kt0km5lNMUokRYlBVKZkW4ly0lgezXh8Z+axLevescceh/mu09jy2L62RVFpZJGUSFEMEjObsUmKzc4ZQKMRC0ABqHCqTtjfH6cKjVAFVCE2us/7PGATVSfsqlO18Dtr//ZaQkpJQEBAQEBAQEBAQICPstIDCAgICAgICAgICDibCARyQEBAQEBAQEBAwCQCgRwQEBAQEBAQEBAwiUAgBwQEBAQEBAQEBEwiEMgBAQEBAQEBAQEBkwgEckBAQEBAQEBAQMAkAoEcsCgIIf5ACHFW1wxcDWMsByHEfxJCfGKlx1EKIcSHhRDfFUIcEUJ4QojnV3pMAQHLyWqINathjOVwtsdDACHEx4QQPxdCWEKITiHE14QQahn7fVEIIYv8vLMMwz7v0VZ6AAEBy8g/A0+u9CAWgf8E7AIeXuFxlOJjwBXAa0BoRUcSEBBQiiAeLgNCiDuAh4B/Af4v4ErgT4Aq4HfKPMyngO5Jv6cWc4wBxQkEcsCqRQhhSimz5W4vpexmapA5K6j0dawCfl1K6QEIIXat9GACAs4Hgnh41vJnwC4p5Zfyvz8nhIgBXxNC/JWUsq+MY7wjpTy2dEMMKEZgsQhYMoQQmhDi94QQh4QQWSFEjxDiL4QQoWnb/aEQ4m0hxKgQIi6EeFYIcd20bW7JTy19Qgjx/wkhBoH+/HPPCyF2CSFuzx8nLYTYJ4T42LRjzJhSzB/zj4UQ/1EIcVIIMS6EeEEIccm07dT8dr354z8rhNie3/8PKnhPClNm7xVC/EAIkQBezz+3UwjxoBCiWwiREUIcFkL8iRAiPGn/DmAD8NlJ0233T3r+ciHEo0KIkfwxXhZC3Fzu+BaDgjgOCAg4QxAPi74n53Q8FEKsx59N+/a0p74F6MBdyzWWgMoJMsgBS8m3gQ8Dfw68AlwE/BGwEfjkpO3WAn+Fn82IAp8DXhRCXCOlfHfaMf8WeAL4PFOn77cAfwP8KRAH/jPwoBBiexl33p8DDgO/BRjA/w94JL+vk9/mD4Gv5p97GrgKeHTut6Ak3wG+B9zLme9hO/AOcD8wDlwC/D/AZuAX89t8HHgc2AP8Qf6xQQAhxFXAS8DPgV8H0sBvAE8LIW6QUr5VajBCCAHM6YkDpJTSLWO7gICAqQTxsDTnajws3Fjsm7bTSSFEGri4jHMA7BJCNAEDwCPAV6WUw2XuGzBfpJTBT/Cz4B/84CQn/X4zIIEvTNvus/nHryhxHBU/QB4G/mbS47fk9/thkX2eB2zggkmPNQMufiApOsb8YxI4CuiTHrs3//gN+d/rgCTw99P2/b/y2/1BBe/TF/P7/NUc24n8+/A5wAMaJj3XAXy7yD7PAAcBY9r7eRD40RznK7y/c/08X+HnYlel+wQ/wc9q/wniYdnv0zkdD4HP5LfbXuS5buBf5tj/DvybkbuBW4Gv4d8s7AVCK/05P9d/ggxywFJxJ5ADHhJCTP6c/TT/73vxswMIIW4Hfh+4DKiftO3JIsf9YYnzHZVSHi38IqUcEEIM4Gch5uJnUkp70u978/+242d6LsXP5Pxg2n4PAn9RxvGLMeN1CCGq8d+He4H1+FNwBS4AhkodLD/t+D78xR/etPf8afw/xLPxFrCzjHGPl7FNQEDAVIJ4ODvnajwU+X+LVQsRRR6bgpTyKeCpSQ89J4TYC/wI/2bhn8sYY8A8CQRywFLRjD89lyzxfANMTIM9jh8EfhXoxc90/DPFKyD0ljhesemmbIljzLVvYYFIYd+2/L8D07brL+PYpSj2Or4B3I4/jfgO/k
rla4GvM/frqMfPjvzf+Z8ZCCEUWdofnMyfcy5WfVmogIAVIIiHs3OuxsPCe1lf5Llail+nuXgU/73YSSCQl5RAIAcsFUOAhT+1WIye/L+fBBzgE5OzFkKIOiBRZL+VEGiF4N0M7J/0eMsCjjl9cUwI+Cj+9OTfTHr80jKPl8Cfevw68EDRE86+eO59wHNlnOcF/OnHgICA8gni4eycq/Gw8P5cArxaeFAIsRGIAAfKOEcpgmTFEhMI5ICl4kn8Go81UspnZtkugp8hmfiyCyFuw5/OKzaluBLsxb9j/xRTg+anFvEcJn7Gw572+BeLbJsFwpMfkFKmhBAvAZcDb88R/IsRWCwCApaOIB5WxjkRD6WUXUKIPfiWjsnZ3s/hv7YnKhwX+HXmo+SrfQQsHYFADlgSpJTPCyG+h79y+i+BN/Dv6DfiLzj4HSnlEfw/HP8JuF8I8Q3gQvwpsdMrMe5iSClHhBB/DXxVCDHOmVXbv5rfZMFlzaSUo0KI14D/LIToxV95/iv4K9qncwC4WQhxD9AHxKWUHfiLZF4EnhJC/At+pqcxP1ZVSvm7s5x/HHhzoa8DQAixgTN/XBrwPYD35n/fLaXsXIzzBASsFoJ4WPE5zpl4iF/t4zEhxD/iV+q4En+x3d/ISTWQhRD/D76dZEshRgohfoZ/E7IPyAA3Av8Fv2rHdxdpfAElCOogBywln8NfKX0vfmmaB4Gv4K+S7oeJRQj/Ef+L/xh+EPwCcLYVRf9v+CWT7sP3gN3FmWzG6CKd45fwMxdfxy9t1Idfamk6v4e/qv3fgN3kyxtJKd/GF6ZDwP/CXwD0N/iLal5cpDGWw634C3h+AGzHL2VU+P3WZRxHQMDZRBAPK+OciIdSysfxr/l1+N7y38ZfPDhdoCv4WfPJi/f245fw+w7wk/z//wPwPnluNVM5KxFSBjaWgID5IIT4FH5Qfq+U8qWVHk9AQEDAShHEw4BzjUAgBwSUgRDiPcCH8H1fFnA1fgbgMH590OCLFBAQcF4QxMOA84EVt1gIIb4ihHhT+K03759luy8KIVwhRHLSzy3LNtCA850kfq3SB/B9gr+Fny25u/DHQPitZGf7mbPuZUBAQMAqIIiHAec8K55BFkJ8At/UfwcQllJ+scR2XwR+TUp50/KNLiCgPPJle+ZaZX6rlPL5pR9NQEBAwMoRxMOAc4EVr2IhpXwYQAhxDbBuhYcTEDBfepi7LNDh5RhIQEBAwAoTxMOAVc+KZ5ALCCH+GFg3Rwb56/ilToaBbwF/KqV0Smz/JeBLANFo6Opt28rpsBkQsBAkyWQa1+3DdRUKi5GlFEjFQ1M8tMg6QrqxssM8T/E8j/FEEisjcISHFoL6mhiKWtpp5touqdEklqXiqS5GWKGutgqA1GiKbNrFdhWE6VBTG0XaLunRDNmsitQdhAY1jdWoqrpcLxOA428fj0spm8rdPoiXAStFOt2BEDqTHRdSSqS0iUQ2rtzAAipmYGCYbBZsTyB0B00KbFtD6jZGWKO6vnqlh1iUUvFyxTPIFfAisAPoxO9K8338jkN/WmxjKeU/Af8EcPXV2+Rrr/39Mg0z4FwlHt9FX9/j5HJxDKOR1ta7aWz0HT/9/cM89NCz7HlX5YZbfohuZLGcfO16xSWkZ9H0KFfd9JfUVcdW8FWcvySGRnnoXx/j0MEmkk0DhBpy/OYvf4SWprqS+3Qf6ealh15g37523C3HidVr/Off+iUUIfjJ3/6QjiMGfZ5HqDXOjto6ug5rDEoP2TiEosKaC9q4+WM3YYbNZXyl8HHj4xXVmg7iZcBiM1u8nMy+fV/F8zLoes3EY7Y9iqKE2bHjT5ZzyAEL5I/+6J/p7a2nX9hozXEaMlH6B+pRNpwiUh3i07/96ZUeYlFKxctVI5CllCcm/bpXCPHfgf9KCYEcELCYxOO76Or6JpoWwzBa8LwkXV3fxPNc9u4N8cwzJzjRH8ZtHWDf4Hpu2HoQYUuyrk5Yc2iuN9
lw4RepCcTxipAcS/HsYy9z6lSEpJmEUAZF0QiZi5PNb0ZhpMdlIFmF2NiFEdK58WPXsWH7hkU5fkDAaqJUvARmiOT
"text/plain": [
"<Figure size 720x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"m = len(X_train)\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)\n",
"for subplot, learning_rate in ((0, 1), (1, 0.5)):\n",
" sample_weights = np.ones(m) / m\n",
" plt.sca(axes[subplot])\n",
" for i in range(5):\n",
" svm_clf = SVC(kernel=\"rbf\", C=0.2, gamma=0.6, random_state=42)\n",
" svm_clf.fit(X_train, y_train, sample_weight=sample_weights * m)\n",
" y_pred = svm_clf.predict(X_train)\n",
"\n",
" r = sample_weights[y_pred != y_train].sum() / sample_weights.sum() # equation 7-1\n",
" alpha = learning_rate * np.log((1 - r) / r) # equation 7-2\n",
" sample_weights[y_pred != y_train] *= np.exp(alpha) # equation 7-3\n",
" sample_weights /= sample_weights.sum() # normalization step\n",
"\n",
" plot_decision_boundary(svm_clf, X, y, alpha=0.2)\n",
" plt.title(\"learning_rate = {}\".format(learning_rate), fontsize=16)\n",
" if subplot == 0:\n",
" plt.text(-0.75, -0.95, \"1\", fontsize=14)\n",
" plt.text(-1.05, -0.95, \"2\", fontsize=14)\n",
" plt.text(1.0, -0.95, \"3\", fontsize=14)\n",
" plt.text(-1.45, -0.5, \"4\", fontsize=14)\n",
" plt.text(1.36, -0.95, \"5\", fontsize=14)\n",
" else:\n",
" plt.ylabel(\"\")\n",
"\n",
"save_fig(\"boosting_plot\")\n",
"plt.show()"
]
},
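{
"cell_type": "markdown",
"metadata": {},
"source": [
"The weight update performed inside the loop above can be written out as follows. This is a sketch of what the code computes rather than a quotation of the book's equations; the symbols $w^{(i)}$ (instance weight), $\\hat{y}^{(i)}$ (prediction on instance $i$), $r$ (weighted error rate), $\\alpha$ (predictor weight) and $\\eta$ (`learning_rate`) are notation chosen here to mirror the variable names, and $\\log$ is the natural logarithm (`np.log`).\n",
"\n",
"$$r = \\frac{\\sum_{i:\\, \\hat{y}^{(i)} \\ne y^{(i)}} w^{(i)}}{\\sum_{i=1}^{m} w^{(i)}}, \\qquad \\alpha = \\eta \\, \\log\\frac{1 - r}{r}$$\n",
"\n",
"$$w^{(i)} \\leftarrow \\begin{cases} w^{(i)} \\exp(\\alpha) & \\text{if } \\hat{y}^{(i)} \\ne y^{(i)} \\\\ w^{(i)} & \\text{otherwise} \\end{cases} \\qquad \\text{followed by} \\qquad w^{(i)} \\leftarrow \\frac{w^{(i)}}{\\sum_{j=1}^{m} w^{(j)}}$$\n",
"\n",
"As long as $r < 0.5$ we have $\\alpha > 0$, so misclassified instances gain weight before the next SVM is trained. Halving $\\eta$ halves $\\alpha$ and therefore makes each boost gentler, which is the only difference between the two panels."
]
},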
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Gradient Boosting"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(42)\n",
"X = np.random.rand(100, 1) - 0.5\n",
"y = 3*X[:, 0]**2 + 0.05 * np.random.randn(100)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DecisionTreeRegressor(max_depth=2, random_state=42)"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.tree import DecisionTreeRegressor\n",
"\n",
"tree_reg1 = DecisionTreeRegressor(max_depth=2, random_state=42)\n",
"tree_reg1.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DecisionTreeRegressor(max_depth=2, random_state=42)"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y2 = y - tree_reg1.predict(X)\n",
"tree_reg2 = DecisionTreeRegressor(max_depth=2, random_state=42)\n",
"tree_reg2.fit(X, y2)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DecisionTreeRegressor(max_depth=2, random_state=42)"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y3 = y2 - tree_reg2.predict(X)\n",
"tree_reg3 = DecisionTreeRegressor(max_depth=2, random_state=42)\n",
"tree_reg3.fit(X, y3)"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"X_new = np.array([[0.8]])"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"y_pred = sum(tree.predict(X_new) for tree in (tree_reg1, tree_reg2, tree_reg3))"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.75026781])"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_pred"
]
},
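{
"cell_type": "markdown",
"metadata": {},
"source": [
"The three trees above form a small hand-built gradient boosting ensemble: each new tree is trained on the residual errors left by the trees before it, and the ensemble's prediction is the sum of the individual predictions. In notation introduced here purely for illustration ($h_1, h_2, h_3$ for `tree_reg1`, `tree_reg2`, `tree_reg3`, and $r_2, r_3$ for the residual targets `y2`, `y3`):\n",
"\n",
"$$r_2 = y - h_1(x), \\qquad r_3 = r_2 - h_2(x) = y - h_1(x) - h_2(x)$$\n",
"\n",
"$$\\hat{y}(x) = h_1(x) + h_2(x) + h_3(x)$$\n",
"\n",
"Because $h_2$ is fit to $r_2$ and $h_3$ to $r_3$, each added term corrects part of what the previous trees got wrong. The sum is what `y_pred` evaluates at `X_new`."
]
},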
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"def plot_predictions(regressors, X, y, axes, label=None, style=\"r-\", data_style=\"b.\", data_label=None):\n",
" x1 = np.linspace(axes[0], axes[1], 500)\n",
" y_pred = sum(regressor.predict(x1.reshape(-1, 1)) for regressor in regressors)\n",
" plt.plot(X[:, 0], y, data_style, label=data_label)\n",
" plt.plot(x1, y_pred, style, linewidth=2, label=label)\n",
" if label or data_label:\n",
" plt.legend(loc=\"upper center\", fontsize=16)\n",
" plt.axis(axes)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saving figure gradient_boosting_plot\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAwcAAAMQCAYAAAByixlsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAADlPUlEQVR4nOzdeZxbdb3/8ddnlk5LFyhtKQiUAkLZhAqDdFRssWxeUeFyES8ooFYQuAj83ADBtqAtKCKKgKDsm6IssihXLlBkmWJboCqIBaSUvS2F0pZuM/P5/XGSaSaTPSc5Ocn7+XjkMZOTk5NvTpLPOZ/z3czdERERERERaYq6ACIiIiIiUhuUHIiIiIiICKDkQEREREREEpQciIiIiIgIoORAREREREQSlByIiIiIiAig5CCWzOw4M/OU2zoze9HMZpjZwAq9ppvZtALWm2VmsypRhsT2xybKclylXiNMKZ/V2DzrnWZm/1mlYsVG+vfOzKaZWVHjL5vZ+MTzNs23fREJZDjOpN7ejbp8pUp5Xx/Ms16sjjX5ZDoWmdlCM7u2hO18pZDtS3y1RF0AKcsRwKvAUOAw4MzE/6dU4LU6Eq8llXEa8Chwe8TlqHW/Bu4r8jnjganAjcCytMf0vRbJLXmcSdUVRUEkdIcB7xX5nOMIzh2vTlt+L0E8faP8YknUlBzE29Pu/kLi//vNbAfgq2Z2qrv3hPlC7j47zO1J6cyszd3XRl2OXMysGTB3D/Ukwt1fJcSTeX2vRfJKPc5IBCoV8939qRC3tQRYEtb2JFpqVlRfngQGASOTC8xsIzO7wMxeSjQ/esnMvmdmTSnrDDGzS8xskZmtNbO3zOz/zGynlHX6Nb8wsy+Y2XOJ5zxjZoelFyhbVWOm5iFm9j9m1mlmy8zsXTObbWafzvemzWxvM7vfzN42s/fN7N9mdlme5ww0s5+a2T/MbKWZvWlmd6e+57TyTzCzm8zsPTN73cx+nt6Ey8y2M7N7E2VYYmY/A9oKKP9CYBvg6JRq+2tT95OZ7WZm/2tmK4FbE4/l/WwT6400s8vN7LXEZ/WcmR1fQLkmJV77cDO71szeSbz/m8xsRNq6bmY/NLMzzOwlYB3wocRjE83sATNbYWarEu9jt7TnN5vZD8zsjcT+m2Vmu2YoU6bvTYuZfdfMnjWzNYl9f5+Z7WRBk4BrEqs+n7J/x6aUe1ra9g5OfA9Xm9lyM7vTzMalrTPLzB41s/3N7MlEmf9hZoemrbejmd1hZosTZVtkZr8zM12YkbpQaIxM/E7Ps6AJ7BozW5r4DX08bXtfM7P5KetcZWlNAhOv9wMz+6aZvZyIK/ea2WaJ262J3+4rZvbdLEX/QOK3vdKCY8elZjaogPebN55led61ZvaqmX3UzOYk3t9CMzslbb3k/vxEIla8CzyRsg/PtA3H3dfN7CdW4rHIMjQrMrNtzewGC46Jay04nv4s8dgsYCLwsZRYOiut3GNTttWa+JwWWnCMWpi435qyTrL51glmdq4Fx4B3LTgeb5VWtqPM7KnEZ7bczP5uZifk2/dSPB2g6stYYDnwNgSBBPhfYBfgPODvwATgHGBT4JuJ5/0U+CxwFvA8MAL4GLBJthcys/2BmwmqEr8JjAJ+BrQC/yqj/L8GFhJ8Nz8D3GNm/+Huf8pSjiGJ9/hXgurOFYntfDTPa7URNMH6AUE16KbAScBsM9vJ3d9MW/8G4BbgPwmqTqcB7xA0V8HMBgD3EyRnJwOLgRMS6+dzGPBHYH5iu9D/CswfgKuAC4CeQj9bMxsGPJYo1zTgJeAg4HILrkZdUkD5Lgb+D/hvYAdgBvABYL+09Y4D/g18C1gFvG5BcvcHgu/JFxPrfRd4xMx2d/dXEsumEXz/LgL+DLQDdxVQNoDfAIemlHMg8Algi8Tr/gA4m77NIzJWfZvZwYnnPAgcCQwBzgUeNbPx7v5ayurbE3znZwJLCfb57xPfn+SV1nuAd4ETE+
tsCfwHujAj8dGcIZntyVA7nTNGEvzuTwe+BzwNDCP4nfee+JvZ+QS/o58D3yb4vfwA2M3MPuru3Smv9yXgHwRxezTB7/96grj+J+BKgt/8+Wb2d3f/Y1p5byS40HIZ8BHg+8BggjiWURHxLJthwG8J4vgLwBeAn5vZCne/Nm3dmwj253+x4VztRoLj4gXA48DOBPF/LHB4oowlH4vMbFuCY+n7BJ/b88DWwIGJVU5KlKE5sU3I3SzpOuDzBMeMRwm+F2cD2wFHpa17ZuI9fQXYDPhJYh9MTJTt44nXTn43moCdyHGeImVwd91idiMIXg6MIwgawwl+UF3A/6Ss96XEep9Ie/73CK7sbpa4/w/gojyv6cC0lPuPAc8CTSnL9kmsNytDWcembW9a8PXL+npNiff2Z+APKcvHJrZ3XOJ+e+L+7mXu02ZgI4Lk4vQM5Z+etv49wIKU+19LrDch7T08k+n9Z3j9hcCNGZZPSzz/1LTlhX625wBrgB3S1vsVwclqS44yTUq8xn1py49OLJ+c9v14HRiUtu4LwANpy4YlXvvixP3hwErgl2nrfTfD967P9wb4ZGKdbxTwe/lgAd/ruQQHxJaUZdsC60n5jQCzEst2SFm2GdANnJW4PzKx/c+W893UTbcobim/m0y3ezKsly9G3gPcnuP1xiZ+P99PW/6xxPYPTVnmwIK03+lFieVnpyxrITg5viZDedPjzfcSr79jSnl6jzWJZXnjWY73d21ie19IW34/8DJBM8zU8v00bb19E8uPSVuejMfjE/cLPhYRHHeuTbl/PUEs/kCO9zELeDTH92Vs4v5upMXXxPKzSTlmp+znh9PW+1Zi+QdS7i+L+nfRKDddvYq35whOUJYRXFW+wt1/kfL4wQRB5/FEdWRL4grQnwmu8E9IrDcHOM7MzjKzdgvai2eVeHxv4PeecvXI3Z8gCDYlMbO9zOweM3uLINFZDxxAkARl8zzBldkrzOyLZrZ1Ea/3eTN7IlFt20VwtXtIlte7N+3+34ExKfc7gFc8pQ17Yt/cWmh58rgj7X6hn+3BBFXSL6Wt978ENUS7FPDa6e/hd0APwXtOdZ+7r07esaAPzPbATWmv/T7QSXB1H4LmR4MzvM5vCijbgQQHkF8VsG5OZjYY2BP4raf0lXD3lwiS4YlpT3ne3Z9PWW8xwYlI8nvxNkFNyvkWNJXYodwyikTgMIJ4n3o7LcN6+WLkHOA/LGh++PHEFe5UBxCcxKbHiycIrk5/Im39+71vn6bnEn//N7kg8fgLBFe/02WKN00EtQj9FBHPcukGbsvwumMIaklSZYr564DbMsR8Ul6/nGPRgQSJ3+sFrJtPsjw3pi1P3k+Pp5m+P7DhOzQHGG5mN5rZIWa2SQhllCyUHMRbMmj/B0FzipPM7JiUxzcjaMu+Pu3218TjyXbjpwBXENQ+zAEWW9Aef6MsrzuS4AT0rQyPZVqWV+Kk/gGCKuZTCJoF7U0wMk3W4VndfTlB85bXCaqHF1nQ9vvwPK/3GYLq3X8SVG/uk3i9JVleL32Um7X0bcO5BSHujwzSm8EU+tluRhCk09f7Xdp6ufR5D+6+jqC5QPrBLFMZIUhc01//kJTX3iLT62S4n8kIgqtJq/Oumd9wwMjc5OhNUpo/JKR/JyD4XgyERPVGcMIzl6Dp0YJE+90TQyirSLX8w93npt0ydVDOFyNnEDRV+SzwCPC2mV1jZsk+csl48QL948Uw+seqd9Lur8uxPFNMzxZv0uNaUqHxLJd33H19ga+bKZ4OILiyn/raixOPp8bTUo9FIwhvwIdkvEx/H2+mPZ6U6fsDG+LpwwTNxLYmSJyWWNA3cvdwiiup1Ocg3v6RDNJm9iDwN+DHZnabu68iuHL5EkGbv0wWArj7SoL2fmea2TYEbRzPJwiqmTpzLSUISqMzPDaa4Ip20prE3/SrROmB9GBgY+DzHoxIQ+J9ZUtQern708Dhiaso7Yn3cquZ7eHu/8jytC
8AL7j7cSmv1Ur/gFWoN4B+HWjJvI9K4Wn3C/psE+stBk7Nsl4h/UP6vIfEFb/hwGtp62UqIwSfx/9l2G7yYJ48eIw
"text/plain": [
"<Figure size 792x792 with 6 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
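"# Left column: each tree fitted to the current residuals; right column: the running sum of the trees' predictions\n",
"plt.figure(figsize=(11,11))\n",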
"\n",
"plt.subplot(321)\n",
"plot_predictions([tree_reg1], X, y, axes=[-0.5, 0.5, -0.1, 0.8], label=\"$h_1(x_1)$\", style=\"g-\", data_label=\"Training set\")\n",
"plt.ylabel(\"$y$\", fontsize=16, rotation=0)\n",
"plt.title(\"Residuals and tree predictions\", fontsize=16)\n",
"\n",
"plt.subplot(322)\n",
"plot_predictions([tree_reg1], X, y, axes=[-0.5, 0.5, -0.1, 0.8], label=\"$h(x_1) = h_1(x_1)$\", data_label=\"Training set\")\n",
"plt.ylabel(\"$y$\", fontsize=16, rotation=0)\n",
"plt.title(\"Ensemble predictions\", fontsize=16)\n",
"\n",
"plt.subplot(323)\n",
"plot_predictions([tree_reg2], X, y2, axes=[-0.5, 0.5, -0.5, 0.5], label=\"$h_2(x_1)$\", style=\"g-\", data_style=\"k+\", data_label=\"Residuals\")\n",
"plt.ylabel(\"$y - h_1(x_1)$\", fontsize=16)\n",
"\n",
"plt.subplot(324)\n",
"plot_predictions([tree_reg1, tree_reg2], X, y, axes=[-0.5, 0.5, -0.1, 0.8], label=\"$h(x_1) = h_1(x_1) + h_2(x_1)$\")\n",
"plt.ylabel(\"$y$\", fontsize=16, rotation=0)\n",
"\n",
"plt.subplot(325)\n",
"plot_predictions([tree_reg3], X, y3, axes=[-0.5, 0.5, -0.5, 0.5], label=\"$h_3(x_1)$\", style=\"g-\", data_style=\"k+\")\n",
"plt.ylabel(\"$y - h_1(x_1) - h_2(x_1)$\", fontsize=16)\n",
"plt.xlabel(\"$x_1$\", fontsize=16)\n",
"\n",
"plt.subplot(326)\n",
"plot_predictions([tree_reg1, tree_reg2, tree_reg3], X, y, axes=[-0.5, 0.5, -0.1, 0.8], label=\"$h(x_1) = h_1(x_1) + h_2(x_1) + h_3(x_1)$\")\n",
"plt.xlabel(\"$x_1$\", fontsize=16)\n",
"plt.ylabel(\"$y$\", fontsize=16, rotation=0)\n",
"\n",
"save_fig(\"gradient_boosting_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"GradientBoostingRegressor(learning_rate=1.0, max_depth=2, n_estimators=3,\n",
" random_state=42)"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.ensemble import GradientBoostingRegressor\n",
"\n",
"# learning_rate=1.0 applies no shrinkage: each of the 3 trees is added at full weight\n",
"gbrt = GradientBoostingRegressor(max_depth=2, n_estimators=3, learning_rate=1.0, random_state=42)\n",
"gbrt.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"GradientBoostingRegressor(max_depth=2, n_estimators=200, random_state=42)"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
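"# Many more trees (200), each scaled by a small learning rate (0.1 is the scikit-learn default)\n",
"gbrt_slow = GradientBoostingRegressor(max_depth=2, n_estimators=200, learning_rate=0.1, random_state=42)\n",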
"gbrt_slow.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saving figure gbrt_learning_rate_plot\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAr4AAAEYCAYAAAC6KvUbAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAABdUklEQVR4nO3dd5hU1f3H8ffZXXbZpUlHuhRFUEEFdLGAIjb0Z0ONLZrYS0RNIvYSVDTRRJNgj73FmthjiajoKqJRBAsqIr0jddlld8/vj3Pv7uzszOzM7vT5vJ5nnpm5c++dc6d85zvnnmKstYiIiIiIZLu8VBdARERERCQZlPiKiIiISE5Q4isiIiIiOUGJr4iIiIjkBCW+IiIiIpITlPiKiIiISE5Q4ttMxpiHjDEvp7ocPmOMNcZMSHU5JH0ZY04zxmxMdTkkOykmSqZRTMwtSnyzz7bAS6kuRDSS9YNkjLnSGPOBMWaTMSbqgauNMecZY340xmwxxnxqjNknkeVMhDCv8T+Bfkl47jHe83dK9HPFyhgz2BjzjjFmuff+zjPG3GSMKUx12STuFBMbPk97Y8yjxph13uVRY8w2jWyzrzHmRWPMYq+cpyW6nImgmBiaV7Z/G2OWGmM2G2NmGWN+HWK90d7voR83zwmxzjHGmK+MMRXe9VHJOYroKPHNAMaYAmOMiWZda+0ya21FossUjjEmzxiTn6rnD6MIeB64PdoNjDHHA3cANwG7Ah8CrxljeieigMlkrS231q5IdTliYYxpEeddVgIPAwcCOwAXAacDN8T5eSQBFBOb7QlgN+AQ4GDv9qONbNMamA1MBMoTWrokU0wEYBTwJTAB2Am4C7jXGHNiwHNuB7yK+z3cFZgC/M0Yc0zAOqW4PxKPA8O862eMMXvEubxNZ63VpRkX4CHg5YD7BrgU+AEXHL4ETg7a5mbgW+/x+cAfgZYBj1+HCzCnefupxgUdC5wFPANsAuaF2LcFJni3+3r3jwHeBDYDXwHjgrYZ75VnC/Ae8Atvu75RHP9pwEbgUK/MVbgvzQjgDWAVsB6YDpQGbDffew7/Mj/gscOBT73y/AjcCBTG4b2a4D7yUa37MXBf0LLvgCkxPud84CrgHu91WAT8Pobt2wH3AiuADcC7wPCgxx/1Ht/ifSYuivQa++9ZiM/bqd42G4EHgULgPGAhsBr4M5AXsN3JwCdeuVZ4n8seQZ+9wMtD3mNFuD8hy70yfwTsHbDfMd76hwIzcEnqYUAv4N/AGtxn+RvgF3H8Lv8ZKEtUrMiVC4qJp5HGMRHY0dv/XgHL9vaW7RDlPjYCpzXx+eejmJgpMfFp4LmA+7cA3wWtcz8BcROX9L4ZtM5bwJPxKlezjyvVBcj0Cw2D/I24gHkwsB1wIi4gjw9Y52pgL++LcCiwAJgc8Ph13jZv4P6J7wQUeB/8Rd6XawDu31Yl0Cdg21BB/htc4ByIq+VaDbT21ukNVHhf4B1wyeECYgvyVbh/gHsB2wNtgP2BU3BBdhDwd2At0MnbrrP3HGcA3YDO3vKDcMHwV0B/YD/v9bw14DnvxgWiSJfeIcoaVeKLC25VwLFBy6cC78b4+Zjvvd4XeO/Zb7zjLo1iW4P7cXwFGOltP9l7fbb11vkb8Ln3eF9cgDy2kdf4NBoG+Y24WvGdvPdgI/AaLtjvCBwFbAWOCdju17jPbz/v+d8B3vMeyweO9p5/sPf87bzH7gCW4pKLHYH7vOfzj2mMt92XuBrZft6xvIRLVobivlsHAwcHlGdOI5+JORFe6wG4BOjmVMeUTL+gmHgaaRwTcd/bDYAJijUbgV9F+R43N/FVTEzzmOht/zpwf8D994CpQesc670OLbz7Cwj6IwP8Hvgp1bGptjypLkCmXwgI8kArXI3FPkHr3A68GmEf5wDfB9y/zvsgdQ1azxJQ44gL/JsJqOEgdJA/O+DxHt6yvb37U4CvqR8EryC2IG+B3RtZz3hf7JBlDV
j2HnB10LIjvS+p8e53wQW8SJeCEGWINvHt7pVt36Dl1wDfxvj5mE/QP11czfFVUWy7v3fcxUHLPwcu9W6/CDwYYR+hXuPTaBjky/GCsLfsWWAlAbVKwDTg7xGea5D3fD29+2O8+50C1mmFS0x+GbAsH1eLd0PQdscE7X8WcG2E5+/TyGeiT4htPsTVsFhcLVJeuP3rEvVn/iEUE9M2JnrHMi9EeeYBl0f5Hjc38VVMrFsnrWJiwLaH4b5zIwOWzQWuCVpvX69sfpJe71i8Zb8EKpryeUnEpQCJp8FAS+D1oE5ULXBfdgC8hvUX4T54rXEf8uA2YIustctDPMcs/4a1tsoYsxIX9CKZFXB7iXftbzMI+MR6n07Px43sL1gVLvDUMsZ0wf0T3w/oiju+YlxtSiS7AyONMZMCluV523YDllrXFisZ7bFs0H0TYlk0ZgXdX0Lj7xm416IEWBnUnLElruYHXDusZ40xu+H++b9krX23CWVcYK1dF3B/OTDXWlsZtKy23N5zXotrx9UB9/qAe48XhXme/rjvwwf+AmtttTGmDPf9CTQz6P4dwN3GmIOBt4EXrLWfBuznp0gHGMbxuNq4ocCfgEm4xEfiQzHRk2YxMVQca2p8awrFxDrpFhMxxuyFawd+obV2RtDDoX4Xg5fH67czIZT4xpffWfBwXHV/oK0Axpg9gaeA64GLgZ+B/wNuDVp/U5jn2Bp039J4J8Xabay11gsY/jbx+EBWWGurg5Y9jAvuF+N+4CpwX8zGes3n4V6bZ0I8thLAGHM37tRmJIOttcHvQbRW4doQdgta3gUX6GLVlPcMb53lQKjRJNYDWGtfM8b0wXVSGQu8Yox5xlr7qziUMdSyfABjTCvgP7i2W6fgfnQ7Ae8T+T0OFSQJs6zed8Ba+w9jzH9wpxIPAD40xkyx1l7nlWkOroYjnJ+stUOC9rnQu/mV1wHpfmPMn6y1VRH2I9FTTKyTLjFxGdDFGGP85N7rKNiZpsW3plBMrJNWMdEYszeuA9s11tq7gtZfRujfxSpc85VI6yTrs9UoJb7x9RUumPWx1v43zDp7AYuttZP9Bd6XNFW+Bo4IWjYyDvvdG/dv8RUAY0xX3LBCgbbSsFbnM2CQtfb7CPu+hoY/isGWNPJ4WNbaSmPMp8A46v/YjAOea+p+m+Az3A9ljbV2XriVrLWrcJ05HjXGvAY8aYw5x7qe7KFe43gYhAvqV1hrfwQwxhwdtI5fMxL4/N97y/fGnVrFSzhLcTUMEVlrF+GaJNzr1YBNxJ2WBBf8I/V0Dv7RCpaHi4n5uEAuzaeYWCddYmIZrla9FNfUB+92q4D76UoxMUg8Y6IxZl9c++nrrLW3h1i/DNfMJtA4YKa1dmvAOuNwZ9AC10mbz5YS3ziy1m4wxtwK3Or9g34PF2D2xH1R78W1kelhjDkJ9wE5CDghVWXGdYq4xCv3fcAQ4GzvsebUeswFTjbGfIwLqH+k7kvvmw+MNca8i6shWQv8AXjZGPMTrkep3yN6pLX2UoBYT+t5Q5B1wLXvwxgzzHvoe2vtRm/ZN7i2Wn/3HvszLmjOwJ2COgfX9vfuaJ83Dt7ynvvfxphLcR1yuuE6MLxlrX3fGPMH3I/BHNz3+Whc+z1/+Kb5NHyN42EBLqG5wBgzFdchY3LQOj/hPkPjjTEvAeXW2o3GmLuAm40xq3A91C/G/ZjdGekJjTF34DqXzAXa4l6Hr/zHYzmtZ4w5Bde290vc53I4ronDszaFQ19lG8XEetIiJlprvzbGvA7cY4w5E1fjeA+uXfa3/nrBMdEY0xrXFAXcn8TeXixd04yza7FSTAwQ55g4Bpf03gk8bozxa22rrbUrvdt3e8d3O+4zsxeufXTg9/UO4D1jzOXAC7hOgPvhEvv0kIqGxdl0IfTQPb+hrqZjJa6d0biAdaZ4y/1eo+cS0OkKbyiVEM8VqlH+fOB3odahriPH8Ej7wTVin4tLBN7H9R62BHUkCXP8px
HQKSBg+VBcu7hyXCP9U3DDw1wXsM7huE4NW6k/dM+BXjk2405fzQQuaOZ7ZENcxgS9JtcFbXcedackP6VhZ7frAt+