Clarify the 'not in the book' comments, and rename decision_tree_instability_plot to decision_tree_high_variance_plot
parent
4eb68a8b7a
commit
37abd9c4d5
|
@ -202,13 +202,6 @@
|
||||||
"# Making Predictions"
|
"# Making Predictions"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"**Code to generate Figure 5–2. Decision Tree decision boundaries**"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 9,
|
"execution_count": 9,
|
||||||
|
@ -218,7 +211,7 @@
|
||||||
"import numpy as np\n",
|
"import numpy as np\n",
|
||||||
"import matplotlib.pyplot as plt\n",
|
"import matplotlib.pyplot as plt\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# not in the book\n",
|
"# not in the book – just formatting details\n",
|
||||||
"from matplotlib.colors import ListedColormap\n",
|
"from matplotlib.colors import ListedColormap\n",
|
||||||
"custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])\n",
|
"custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])\n",
|
||||||
"plt.figure(figsize=(8, 4))\n",
|
"plt.figure(figsize=(8, 4))\n",
|
||||||
|
@ -231,7 +224,7 @@
|
||||||
" plt.plot(X_iris[:, 0][y_iris == idx], X_iris[:, 1][y_iris == idx],\n",
|
" plt.plot(X_iris[:, 0][y_iris == idx], X_iris[:, 1][y_iris == idx],\n",
|
||||||
" style, label=f\"Iris {name}\")\n",
|
" style, label=f\"Iris {name}\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# not in the book\n",
|
"# not in the book – this section beautifies and saves Figure 5–2\n",
|
||||||
"tree_clf_deeper = DecisionTreeClassifier(max_depth=3, random_state=42)\n",
|
"tree_clf_deeper = DecisionTreeClassifier(max_depth=3, random_state=42)\n",
|
||||||
"tree_clf_deeper.fit(X_iris, y_iris)\n",
|
"tree_clf_deeper.fit(X_iris, y_iris)\n",
|
||||||
"th0, th1, th2a, th2b = tree_clf_deeper.tree_.threshold[[0, 2, 3, 6]]\n",
|
"th0, th1, th2a, th2b = tree_clf_deeper.tree_.threshold[[0, 2, 3, 6]]\n",
|
||||||
|
@ -324,13 +317,6 @@
|
||||||
"# Regularization Hyperparameters"
|
"# Regularization Hyperparameters"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"**Code to generate Figure 5–3. Regularization using min_samples_leaf:**"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 14,
|
"execution_count": 14,
|
||||||
|
@ -353,7 +339,7 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# not in the book\n",
|
"# not in the book – this cell generates and saves Figure 5–3\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def plot_decision_boundary(clf, X, y, axes, cmap):\n",
|
"def plot_decision_boundary(clf, X, y, axes, cmap):\n",
|
||||||
" x1, x2 = np.meshgrid(np.linspace(axes[0], axes[1], 100),\n",
|
" x1, x2 = np.meshgrid(np.linspace(axes[0], axes[1], 100),\n",
|
||||||
|
@ -443,20 +429,13 @@
|
||||||
"tree_reg.fit(X_quad, y_quad)"
|
"tree_reg.fit(X_quad, y_quad)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"**Code to generate Figure 5–4. A Decision Tree for regression:**"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 19,
|
"execution_count": 19,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# not in the book\n",
|
"# not in the book – we've already seen how to use export_graphviz()\n",
|
||||||
"export_graphviz(\n",
|
"export_graphviz(\n",
|
||||||
" tree_reg,\n",
|
" tree_reg,\n",
|
||||||
" out_file=str(IMAGES_PATH / \"regression_tree.dot\"),\n",
|
" out_file=str(IMAGES_PATH / \"regression_tree.dot\"),\n",
|
||||||
|
@ -477,13 +456,6 @@
|
||||||
"tree_reg2.fit(X_quad, y_quad)"
|
"tree_reg2.fit(X_quad, y_quad)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"**Code to generate Figure 5–5. Predictions of two Decision Tree regression models:**"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 21,
|
"execution_count": 21,
|
||||||
|
@ -508,7 +480,7 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# not in the book\n",
|
"# not in the book – this cell generates and saves Figure 5–5\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def plot_regression_predictions(tree_reg, X, y, axes=[-0.5, 0.5, -0.05, 0.25]):\n",
|
"def plot_regression_predictions(tree_reg, X, y, axes=[-0.5, 0.5, -0.05, 0.25]):\n",
|
||||||
" x1 = np.linspace(axes[0], axes[1], 500).reshape(-1, 1)\n",
|
" x1 = np.linspace(axes[0], axes[1], 500).reshape(-1, 1)\n",
|
||||||
|
@ -546,20 +518,13 @@
|
||||||
"plt.show()"
|
"plt.show()"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"**Code to generate Figure 5–6. Regularizing a Decision Tree regressor:**"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 24,
|
"execution_count": 24,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# not in the book\n",
|
"# not in the book – this cell generates and saves Figure 5–6\n",
|
||||||
"\n",
|
"\n",
|
||||||
"tree_reg1 = DecisionTreeRegressor(random_state=42)\n",
|
"tree_reg1 = DecisionTreeRegressor(random_state=42)\n",
|
||||||
"tree_reg2 = DecisionTreeRegressor(random_state=42, min_samples_leaf=10)\n",
|
"tree_reg2 = DecisionTreeRegressor(random_state=42, min_samples_leaf=10)\n",
|
||||||
|
@ -606,20 +571,13 @@
|
||||||
"Rotating the dataset also leads to completely different decision boundaries:"
|
"Rotating the dataset also leads to completely different decision boundaries:"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"**Code to generate Figure 5–7. Sensitivity to training set rotation**"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 25,
|
"execution_count": 25,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# not in the book\n",
|
"# not in the book – this cell generates and saves Figure 5–7\n",
|
||||||
"\n",
|
"\n",
|
||||||
"np.random.seed(6)\n",
|
"np.random.seed(6)\n",
|
||||||
"X_square = np.random.rand(100, 2) - 0.5\n",
|
"X_square = np.random.rand(100, 2) - 0.5\n",
|
||||||
|
@ -670,7 +628,7 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# not in the book\n",
|
"# not in the book – this cell generates and saves Figure 5–8\n",
|
||||||
"\n",
|
"\n",
|
||||||
"plt.figure(figsize=(8, 4))\n",
|
"plt.figure(figsize=(8, 4))\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -727,20 +685,13 @@
|
||||||
"tree_clf_tweaked.fit(X_iris, y_iris)"
|
"tree_clf_tweaked.fit(X_iris, y_iris)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"**Code to generate Figure 5–8. Sensitivity to training set details:**"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 29,
|
"execution_count": 29,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# not in the book\n",
|
"# not in the book – this cell generates and saves Figure 5–9\n",
|
||||||
"\n",
|
"\n",
|
||||||
"plt.figure(figsize=(8, 4))\n",
|
"plt.figure(figsize=(8, 4))\n",
|
||||||
"y_pred = tree_clf_tweaked.predict(X_iris_all).reshape(lengths.shape)\n",
|
"y_pred = tree_clf_tweaked.predict(X_iris_all).reshape(lengths.shape)\n",
|
||||||
|
@ -759,7 +710,7 @@
|
||||||
"plt.ylabel(\"Petal width (cm)\")\n",
|
"plt.ylabel(\"Petal width (cm)\")\n",
|
||||||
"plt.axis([0, 7.2, 0, 3])\n",
|
"plt.axis([0, 7.2, 0, 3])\n",
|
||||||
"plt.legend()\n",
|
"plt.legend()\n",
|
||||||
"save_fig(\"decision_tree_instability_plot\")\n",
|
"save_fig(\"decision_tree_high_variance_plot\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"plt.show()"
|
"plt.show()"
|
||||||
]
|
]
|
||||||
|
|
Loading…
Reference in New Issue