{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Support Vector Machines**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_This notebook is an extra chapter on Support Vector Machines. It also includes exercises and their solutions at the end._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table align=\"left\">\n",
" <td>\n",
" <a href=\"https://colab.research.google.com/github/ageron/handson-ml3/blob/main/05_support_vector_machines.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/ageron/handson-ml3/blob/main/05_support_vector_machines.ipynb\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" /></a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This project requires Python 3.8 or above:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"assert sys.version_info >= (3, 8)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It also requires Scikit-Learn ≥ 1.0.1:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import sklearn\n",
"\n",
"assert sklearn.__version__ >= \"1.0.1\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we did in previous chapters, let's define the default font sizes to make the figures prettier:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib as mpl\n",
"\n",
"mpl.rc('font', size=12)\n",
"mpl.rc('axes', labelsize=14, titlesize=14)\n",
"mpl.rc('legend', fontsize=14)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And let's create the `images/svm` folder (if it doesn't already exist), and define the `save_fig()` function, which is used throughout this notebook to save the figures in high resolution for the book:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"IMAGES_PATH = Path() / \"images\" / \"svm\"\n",
"IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
"    path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
"    if tight_layout:\n",
"        plt.tight_layout()\n",
"    plt.savefig(path, format=fig_extension, dpi=resolution)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Linear SVM Classification"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The book starts with a few figures, before the first code example, so the next three cells generate and save these figures. You can skip them if you want."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: this cell generates and saves Figure 5-1\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from sklearn.svm import SVC\n",
"from sklearn import datasets\n",
"\n",
"iris = datasets.load_iris(as_frame=True)\n",
"X = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
"y = iris.target\n",
"\n",
"setosa_or_versicolor = (y == 0) | (y == 1)\n",
"X = X[setosa_or_versicolor]\n",
"y = y[setosa_or_versicolor]\n",
"\n",
"# SVM Classifier model\n",
"svm_clf = SVC(kernel=\"linear\", C=float(\"inf\"))\n",
"svm_clf.fit(X, y)\n",
"\n",
"# Bad models\n",
"x0 = np.linspace(0, 5.5, 200)\n",
"pred_1 = 5 * x0 - 20\n",
"pred_2 = x0 - 1.8\n",
"pred_3 = 0.1 * x0 + 0.5\n",
"\n",
"def plot_svc_decision_boundary(svm_clf, xmin, xmax):\n",
" w = svm_clf.coef_[0]\n",
" b = svm_clf.intercept_[0]\n",
"\n",
" # At the decision boundary, w0*x0 + w1*x1 + b = 0\n",
" # => x1 = -w0/w1 * x0 - b/w1\n",
" x0 = np.linspace(xmin, xmax, 200)\n",
" decision_boundary = -w[0] / w[1] * x0 - b / w[1]\n",
"\n",
" margin = 1/w[1]\n",
" gutter_up = decision_boundary + margin\n",
" gutter_down = decision_boundary - margin\n",
" svs = svm_clf.support_vectors_\n",
"\n",
" plt.plot(x0, decision_boundary, \"k-\", linewidth=2, zorder=-2)\n",
" plt.plot(x0, gutter_up, \"k--\", linewidth=2, zorder=-2)\n",
" plt.plot(x0, gutter_down, \"k--\", linewidth=2, zorder=-2)\n",
" plt.scatter(svs[:, 0], svs[:, 1], s=180, facecolors='#AAA',\n",
" zorder=-1)\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(10,2.7), sharey=True)\n",
"\n",
"plt.sca(axes[0])\n",
"plt.plot(x0, pred_1, \"g--\", linewidth=2)\n",
"plt.plot(x0, pred_2, \"m-\", linewidth=2)\n",
"plt.plot(x0, pred_3, \"r-\", linewidth=2)\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"bs\", label=\"Iris versicolor\")\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"yo\", label=\"Iris setosa\")\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.legend(loc=\"upper left\")\n",
"plt.axis([0, 5.5, 0, 2])\n",
"plt.gca().set_aspect(\"equal\")\n",
"plt.grid()\n",
"\n",
"plt.sca(axes[1])\n",
"plot_svc_decision_boundary(svm_clf, 0, 5.5)\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"bs\")\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"yo\")\n",
"plt.xlabel(\"Petal length\")\n",
"plt.axis([0, 5.5, 0, 2])\n",
"plt.gca().set_aspect(\"equal\")\n",
"plt.grid()\n",
"\n",
"save_fig(\"large_margin_classification_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: this cell generates and saves Figure 5-2\n",
"\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"Xs = np.array([[1, 50], [5, 20], [3, 80], [5, 60]]).astype(np.float64)\n",
"ys = np.array([0, 0, 1, 1])\n",
"svm_clf = SVC(kernel=\"linear\", C=100).fit(Xs, ys)\n",
"\n",
"scaler = StandardScaler()\n",
"X_scaled = scaler.fit_transform(Xs)\n",
"svm_clf_scaled = SVC(kernel=\"linear\", C=100).fit(X_scaled, ys)\n",
"\n",
"plt.figure(figsize=(9,2.7))\n",
"plt.subplot(121)\n",
"plt.plot(Xs[:, 0][ys==1], Xs[:, 1][ys==1], \"bo\")\n",
"plt.plot(Xs[:, 0][ys==0], Xs[:, 1][ys==0], \"ms\")\n",
"plot_svc_decision_boundary(svm_clf, 0, 6)\n",
"plt.xlabel(\"$x_0$\")\n",
"plt.ylabel(\"$x_1$    \", rotation=0)\n",
"plt.title(\"Unscaled\")\n",
"plt.axis([0, 6, 0, 90])\n",
"plt.grid()\n",
"\n",
"plt.subplot(122)\n",
"plt.plot(X_scaled[:, 0][ys==1], X_scaled[:, 1][ys==1], \"bo\")\n",
"plt.plot(X_scaled[:, 0][ys==0], X_scaled[:, 1][ys==0], \"ms\")\n",
"plot_svc_decision_boundary(svm_clf_scaled, -2, 2)\n",
"plt.xlabel(\"$x'_0$\")\n",
"plt.ylabel(\"$x'_1$ \", rotation=0)\n",
"plt.title(\"Scaled\")\n",
"plt.axis([-2, 2, -2, 2])\n",
"plt.grid()\n",
"\n",
"save_fig(\"sensitivity_to_feature_scales_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Soft Margin Classification"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: this cell generates and saves Figure 5-3\n",
"\n",
"X_outliers = np.array([[3.4, 1.3], [3.2, 0.8]])\n",
"y_outliers = np.array([0, 0])\n",
"Xo1 = np.concatenate([X, X_outliers[:1]], axis=0)\n",
"yo1 = np.concatenate([y, y_outliers[:1]], axis=0)\n",
"Xo2 = np.concatenate([X, X_outliers[1:]], axis=0)\n",
"yo2 = np.concatenate([y, y_outliers[1:]], axis=0)\n",
"\n",
"svm_clf2 = SVC(kernel=\"linear\", C=10**9)\n",
"svm_clf2.fit(Xo2, yo2)\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(10,2.7), sharey=True)\n",
"\n",
"plt.sca(axes[0])\n",
"plt.plot(Xo1[:, 0][yo1==1], Xo1[:, 1][yo1==1], \"bs\")\n",
"plt.plot(Xo1[:, 0][yo1==0], Xo1[:, 1][yo1==0], \"yo\")\n",
"plt.text(0.3, 1.0, \"Impossible!\", color=\"red\", fontsize=18)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.annotate(\"Outlier\",\n",
" xy=(X_outliers[0][0], X_outliers[0][1]),\n",
" xytext=(2.5, 1.7),\n",
" ha=\"center\",\n",
" arrowprops=dict(facecolor='black', shrink=0.1),\n",
" fontsize=14,\n",
" )\n",
"plt.axis([0, 5.5, 0, 2])\n",
"plt.grid()\n",
"\n",
"plt.sca(axes[1])\n",
"plt.plot(Xo2[:, 0][yo2==1], Xo2[:, 1][yo2==1], \"bs\")\n",
"plt.plot(Xo2[:, 0][yo2==0], Xo2[:, 1][yo2==0], \"yo\")\n",
"plot_svc_decision_boundary(svm_clf2, 0, 5.5)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.annotate(\"Outlier\",\n",
" xy=(X_outliers[1][0], X_outliers[1][1]),\n",
" xytext=(3.2, 0.08),\n",
" ha=\"center\",\n",
" arrowprops=dict(facecolor='black', shrink=0.1),\n",
" fontsize=14,\n",
" )\n",
"plt.axis([0, 5.5, 0, 2])\n",
"plt.grid()\n",
"\n",
"save_fig(\"sensitivity_to_outliers_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**This is the first code example in chapter 5:**"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.svm import LinearSVC\n",
"\n",
"iris = load_iris(as_frame=True)\n",
"X = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
"y = (iris.target == 2) # Iris virginica\n",
"\n",
"svm_clf = make_pipeline(StandardScaler(),\n",
" LinearSVC(C=1, random_state=42))\n",
"svm_clf.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"X_new = [[5.5, 1.7], [5.0, 1.5]]\n",
"svm_clf.predict(X_new)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"svm_clf.decision_function(X_new)"
]
},
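{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check, not in the book: for a linear SVM, the score returned by `decision_function()` should simply be $\\mathbf{w}^\\intercal \\mathbf{x} + b$ computed on the *standardized* inputs, and `predict()` should return the positive class wherever that score is positive. The cell below is a sketch that recomputes the scores by hand from the fitted pipeline's `coef_` and `intercept_`; the names `scaler_step`, `linear_svc_step` and `scores_by_hand` are illustrative only."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: recompute the decision scores by hand (illustrative names)\n",
"scaler_step, linear_svc_step = svm_clf[0], svm_clf[-1]  # the two pipeline steps\n",
"X_new_scaled = scaler_step.transform(X_new)  # same scaling as during training\n",
"scores_by_hand = (X_new_scaled @ linear_svc_step.coef_[0]\n",
"                  + linear_svc_step.intercept_[0])  # w^T x + b\n",
"\n",
"assert np.allclose(scores_by_hand, svm_clf.decision_function(X_new))\n",
"assert (svm_clf.predict(X_new) == (scores_by_hand > 0)).all()\n",
"scores_by_hand"
]
},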
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: this cell generates and saves Figure 5-4\n",
"\n",
"scaler = StandardScaler()\n",
"svm_clf1 = LinearSVC(C=1, max_iter=10_000, random_state=42)\n",
"svm_clf2 = LinearSVC(C=100, max_iter=10_000, random_state=42)\n",
"\n",
"scaled_svm_clf1 = make_pipeline(scaler, svm_clf1)\n",
"scaled_svm_clf2 = make_pipeline(scaler, svm_clf2)\n",
"\n",
"scaled_svm_clf1.fit(X, y)\n",
"scaled_svm_clf2.fit(X, y)\n",
"\n",
"# Convert to unscaled parameters\n",
"b1 = svm_clf1.decision_function([-scaler.mean_ / scaler.scale_])\n",
"b2 = svm_clf2.decision_function([-scaler.mean_ / scaler.scale_])\n",
"w1 = svm_clf1.coef_[0] / scaler.scale_\n",
"w2 = svm_clf2.coef_[0] / scaler.scale_\n",
"svm_clf1.intercept_ = np.array([b1])\n",
"svm_clf2.intercept_ = np.array([b2])\n",
"svm_clf1.coef_ = np.array([w1])\n",
"svm_clf2.coef_ = np.array([w2])\n",
"\n",
"# Find support vectors (LinearSVC does not do this automatically)\n",
"t = y * 2 - 1\n",
"support_vectors_idx1 = (t * (X.dot(w1) + b1) < 1).ravel()\n",
"support_vectors_idx2 = (t * (X.dot(w2) + b2) < 1).ravel()\n",
"svm_clf1.support_vectors_ = X[support_vectors_idx1]\n",
"svm_clf2.support_vectors_ = X[support_vectors_idx2]\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(10,2.7), sharey=True)\n",
"\n",
"plt.sca(axes[0])\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"g^\", label=\"Iris virginica\")\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"bs\", label=\"Iris versicolor\")\n",
"plot_svc_decision_boundary(svm_clf1, 4, 5.9)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.legend(loc=\"upper left\")\n",
"plt.title(f\"$C = {svm_clf1.C}$\")\n",
"plt.axis([4, 5.9, 0.8, 2.8])\n",
"plt.grid()\n",
"\n",
"plt.sca(axes[1])\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"g^\")\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"bs\")\n",
"plot_svc_decision_boundary(svm_clf2, 4, 5.99)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.title(f\"$C = {svm_clf2.C}$\")\n",
"plt.axis([4, 5.9, 0.8, 2.8])\n",
"plt.grid()\n",
"\n",
"save_fig(\"regularization_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Nonlinear SVM Classification"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: this cell generates and saves Figure 5-5\n",
"\n",
"X1D = np.linspace(-4, 4, 9).reshape(-1, 1)\n",
"X2D = np.c_[X1D, X1D**2]\n",
"y = np.array([0, 0, 1, 1, 1, 1, 1, 0, 0])\n",
"\n",
"plt.figure(figsize=(10, 3))\n",
"\n",
"plt.subplot(121)\n",
"plt.grid(True, which='both')\n",
"plt.axhline(y=0, color='k')\n",
"plt.plot(X1D[:, 0][y==0], np.zeros(4), \"bs\")\n",
"plt.plot(X1D[:, 0][y==1], np.zeros(5), \"g^\")\n",
"plt.gca().get_yaxis().set_ticks([])\n",
"plt.xlabel(r\"$x_1$\")\n",
"plt.axis([-4.5, 4.5, -0.2, 0.2])\n",
"\n",
"plt.subplot(122)\n",
"plt.grid(True, which='both')\n",
"plt.axhline(y=0, color='k')\n",
"plt.axvline(x=0, color='k')\n",
"plt.plot(X2D[:, 0][y==0], X2D[:, 1][y==0], \"bs\")\n",
"plt.plot(X2D[:, 0][y==1], X2D[:, 1][y==1], \"g^\")\n",
"plt.xlabel(r\"$x_1$\")\n",
"plt.ylabel(r\"$x_2$  \", rotation=0)\n",
"plt.gca().get_yaxis().set_ticks([0, 4, 8, 12, 16])\n",
"plt.plot([-4.5, 4.5], [6.5, 6.5], \"r--\", linewidth=3)\n",
"plt.axis([-4.5, 4.5, -1, 17])\n",
"\n",
"plt.subplots_adjust(right=1)\n",
"\n",
"save_fig(\"higher_dimensions_plot\", tight_layout=False)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Here is the second code example in the chapter:**"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import make_moons\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"\n",
"X, y = make_moons(n_samples=100, noise=0.15, random_state=42)\n",
"\n",
"polynomial_svm_clf = make_pipeline(\n",
" PolynomialFeatures(degree=3),\n",
" StandardScaler(),\n",
" LinearSVC(C=10, max_iter=10_000, random_state=42)\n",
")\n",
"polynomial_svm_clf.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: this cell generates and saves Figure 5-6\n",
"\n",
"def plot_dataset(X, y, axes):\n",
" plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"bs\")\n",
" plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"g^\")\n",
" plt.axis(axes)\n",
" plt.grid(True, which='both')\n",
" plt.xlabel(r\"$x_1$\")\n",
" plt.ylabel(r\"$x_2$\", rotation=0)\n",
"\n",
"def plot_predictions(clf, axes):\n",
" x0s = np.linspace(axes[0], axes[1], 100)\n",
" x1s = np.linspace(axes[2], axes[3], 100)\n",
" x0, x1 = np.meshgrid(x0s, x1s)\n",
" X = np.c_[x0.ravel(), x1.ravel()]\n",
" y_pred = clf.predict(X).reshape(x0.shape)\n",
" y_decision = clf.decision_function(X).reshape(x0.shape)\n",
" plt.contourf(x0, x1, y_pred, cmap=plt.cm.brg, alpha=0.2)\n",
" plt.contourf(x0, x1, y_decision, cmap=plt.cm.brg, alpha=0.1)\n",
"\n",
"plot_predictions(polynomial_svm_clf, [-1.5, 2.5, -1, 1.5])\n",
"plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])\n",
"\n",
"save_fig(\"moons_polynomial_svc_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Polynomial Kernel"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Next code example:**"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.svm import SVC\n",
"\n",
"poly_kernel_svm_clf = make_pipeline(\n",
" StandardScaler(),\n",
" SVC(kernel=\"poly\", degree=3, coef0=1, C=5)\n",
")\n",
"poly_kernel_svm_clf.fit(X, y)"
]
},
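{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small illustration of the kernel trick, not in the book: with `kernel=\"poly\"`, Scikit-Learn computes $K(\\mathbf{a}, \\mathbf{b}) = (\\gamma\\,\\mathbf{a}^\\intercal \\mathbf{b} + r)^d$, where $r$ is `coef0` and $d$ is `degree`. For $d=2$, $\\gamma=1$ and $r=0$, this equals the dot product of the explicit second-degree mapping $\\phi(\\mathbf{x}) = (x_1^2, \\sqrt{2}\\,x_1 x_2, x_2^2)$, so the SVM gets the benefit of those features without ever computing them. The helper `phi()` and the two sample vectors below are hypothetical, chosen only for this check."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: check the kernel trick for a degree-2 polynomial kernel\n",
"from sklearn.metrics.pairwise import polynomial_kernel\n",
"\n",
"def phi(x):\n",
"    # hypothetical explicit 2nd-degree feature map for a 2D vector\n",
"    x1, x2 = x\n",
"    return np.array([x1 ** 2, np.sqrt(2) * x1 * x2, x2 ** 2])\n",
"\n",
"a_vec = np.array([2.0, 3.0])   # arbitrary sample vectors\n",
"b_vec = np.array([-1.0, 0.5])\n",
"explicit = phi(a_vec) @ phi(b_vec)  # dot product in the 3D feature space\n",
"kernelized = polynomial_kernel([a_vec], [b_vec], degree=2, gamma=1, coef0=0)[0, 0]\n",
"assert np.isclose(explicit, kernelized)\n",
"explicit, kernelized"
]
},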
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: this cell generates and saves Figure 5-7\n",
"\n",
"poly100_kernel_svm_clf = make_pipeline(\n",
" StandardScaler(),\n",
" SVC(kernel=\"poly\", degree=10, coef0=100, C=5)\n",
")\n",
"poly100_kernel_svm_clf.fit(X, y)\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(10.5, 4), sharey=True)\n",
"\n",
"plt.sca(axes[0])\n",
"plot_predictions(poly_kernel_svm_clf, [-1.5, 2.45, -1, 1.5])\n",
"plot_dataset(X, y, [-1.5, 2.4, -1, 1.5])\n",
"plt.title(r\"$d=3, r=1, C=5$\")\n",
"\n",
"plt.sca(axes[1])\n",
"plot_predictions(poly100_kernel_svm_clf, [-1.5, 2.45, -1, 1.5])\n",
"plot_dataset(X, y, [-1.5, 2.4, -1, 1.5])\n",
"plt.title(r\"$d=10, r=100, C=5$\")\n",
"plt.ylabel(\"\")\n",
"\n",
"save_fig(\"moons_kernelized_polynomial_svc_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Similarity Features"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# not in the book: this cell generates and saves Figure 5-8\n",
"\n",
"def gaussian_rbf(x, landmark, gamma):\n",
" return np.exp(-gamma * np.linalg.norm(x - landmark, axis=1)**2)\n",
"\n",
"gamma = 0.3\n",
"\n",
"x1s = np.linspace(-4.5, 4.5, 200).reshape(-1, 1)\n",
"x2s = gaussian_rbf(x1s, -2, gamma)\n",
"x3s = gaussian_rbf(x1s, 1, gamma)\n",
"\n",
"XK = np.c_[gaussian_rbf(X1D, -2, gamma), gaussian_rbf(X1D, 1, gamma)]\n",
"yk = np.array([0, 0, 1, 1, 1, 1, 1, 0, 0])\n",
"\n",
"plt.figure(figsize=(10.5, 4))\n",
"\n",
"plt.subplot(121)\n",
"plt.grid(True, which='both')\n",
"plt.axhline(y=0, color='k')\n",
"plt.scatter(x=[-2, 1], y=[0, 0], s=150, alpha=0.5, c=\"red\")\n",
"plt.plot(X1D[:, 0][yk==0], np.zeros(4), \"bs\")\n",
"plt.plot(X1D[:, 0][yk==1], np.zeros(5), \"g^\")\n",
"plt.plot(x1s, x2s, \"g--\")\n",
"plt.plot(x1s, x3s, \"b:\")\n",
"plt.gca().get_yaxis().set_ticks([0, 0.25, 0.5, 0.75, 1])\n",
"plt.xlabel(r\"$x_1$\")\n",
"plt.ylabel(r\"Similarity\")\n",
"plt.annotate(\n",
" r'$\\mathbf{x}$',\n",
" xy=(X1D[3, 0], 0),\n",
" xytext=(-0.5, 0.20),\n",
" ha=\"center\",\n",
" arrowprops=dict(facecolor='black', shrink=0.1),\n",
" fontsize=16,\n",
")\n",
"plt.text(-2, 0.9, \"$x_2$\", ha=\"center\", fontsize=15)\n",
"plt.text(1, 0.9, \"$x_3$\", ha=\"center\", fontsize=15)\n",
"plt.axis([-4.5, 4.5, -0.1, 1.1])\n",
"\n",
"plt.subplot(122)\n",
"plt.grid(True, which='both')\n",
"plt.axhline(y=0, color='k')\n",
"plt.axvline(x=0, color='k')\n",
"plt.plot(XK[:, 0][yk==0], XK[:, 1][yk==0], \"bs\")\n",
"plt.plot(XK[:, 0][yk==1], XK[:, 1][yk==1], \"g^\")\n",
"plt.xlabel(r\"$x_2$\")\n",
"plt.ylabel(r\"$x_3$  \", rotation=0)\n",
"plt.annotate(\n",
" r'$\\phi\\left(\\mathbf{x}\\right)$',\n",
" xy=(XK[3, 0], XK[3, 1]),\n",
" xytext=(0.65, 0.50),\n",
" ha=\"center\",\n",
" arrowprops=dict(facecolor='black', shrink=0.1),\n",
" fontsize=16,\n",
")\n",
"plt.plot([-0.1, 1.1], [0.57, -0.1], \"r--\", linewidth=3)\n",
"plt.axis([-0.1, 1.1, -0.1, 1.1])\n",
" \n",
"plt.subplots_adjust(right=1)\n",
"\n",
"save_fig(\"kernel_method_plot\")\n",
"plt.show()"
]
},
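{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick cross-check, not in the book: the `gaussian_rbf()` function above implements $\\phi_\\gamma(\\mathbf{x}, \\boldsymbol{\\ell}) = \\exp(-\\gamma \\lVert \\mathbf{x} - \\boldsymbol{\\ell} \\rVert^2)$, which is also what Scikit-Learn's `rbf_kernel()` computes. Assuming `X1D`, `XK` and `gamma` still hold the values set in the previous cells, both approaches should produce the same similarity features for the landmarks $-2$ and $1$. The name `XK_sklearn` is illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: compare gaussian_rbf() with Scikit-Learn's rbf_kernel()\n",
"from sklearn.metrics.pairwise import rbf_kernel\n",
"\n",
"landmarks = np.array([[-2.0], [1.0]])  # same landmarks as in the figure\n",
"XK_sklearn = rbf_kernel(X1D, landmarks, gamma=gamma)  # shape (9, 2)\n",
"assert np.allclose(XK, XK_sklearn)\n",
"XK_sklearn.round(3)"
]
},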
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Gaussian RBF Kernel"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Next code example:**"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"rbf_kernel_svm_clf = make_pipeline(\n",
" StandardScaler(),\n",
" SVC(kernel=\"rbf\", gamma=5, C=0.001)\n",
")\n",
"rbf_kernel_svm_clf.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# not in the book: this cell generates and saves Figure 5-9\n",
"\n",
"from sklearn.svm import SVC\n",
"\n",
"gamma1, gamma2 = 0.1, 5\n",
"C1, C2 = 0.001, 1000\n",
"hyperparams = (gamma1, C1), (gamma1, C2), (gamma2, C1), (gamma2, C2)\n",
"\n",
"svm_clfs = []\n",
"for gamma, C in hyperparams:\n",
" rbf_kernel_svm_clf = make_pipeline(\n",
" StandardScaler(),\n",
" SVC(kernel=\"rbf\", gamma=gamma, C=C)\n",
" )\n",
" rbf_kernel_svm_clf.fit(X, y)\n",
" svm_clfs.append(rbf_kernel_svm_clf)\n",
"\n",
"fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10.5, 7), sharex=True, sharey=True)\n",
"\n",
"for i, svm_clf in enumerate(svm_clfs):\n",
"    plt.sca(axes[i // 2, i % 2])\n",
"    plot_predictions(svm_clf, [-1.5, 2.45, -1, 1.5])\n",
"    plot_dataset(X, y, [-1.5, 2.45, -1, 1.5])\n",
"    gamma, C = hyperparams[i]\n",
"    plt.title(fr\"$\\gamma = {gamma}, C = {C}$\")\n",
"    if i in (0, 1):\n",
"        plt.xlabel(\"\")\n",
"    if i in (1, 3):\n",
"        plt.ylabel(\"\")\n",
"\n",
"save_fig(\"moons_rbf_svc_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SVM Regression"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: this code generates a simple linear dataset\n",
"np.random.seed(42)\n",
"m = 50\n",
"X = 2 * np.random.rand(m, 1)\n",
"y = (4 + 3 * X + np.random.randn(m, 1)).ravel()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Next code example:**"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.svm import LinearSVR\n",
"\n",
"svm_reg = make_pipeline(StandardScaler(),\n",
" LinearSVR(epsilon=0.5, random_state=42))\n",
"svm_reg.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: this cell generates and saves Figure 5-10\n",
"\n",
"def find_support_vectors(svm_reg, X, y):\n",
" y_pred = svm_reg.predict(X)\n",
" epsilon = svm_reg[-1].epsilon\n",
" off_margin = np.abs(y - y_pred) >= epsilon\n",
" return np.argwhere(off_margin)\n",
"\n",
"def plot_svm_regression(svm_reg, X, y, axes):\n",
" x1s = np.linspace(axes[0], axes[1], 100).reshape(100, 1)\n",
" y_pred = svm_reg.predict(x1s)\n",
" epsilon = svm_reg[-1].epsilon\n",
" plt.plot(x1s, y_pred, \"k-\", linewidth=2, label=r\"$\\hat{y}$\", zorder=-2)\n",
" plt.plot(x1s, y_pred + epsilon, \"k--\", zorder=-2)\n",
" plt.plot(x1s, y_pred - epsilon, \"k--\", zorder=-2)\n",
" plt.scatter(X[svm_reg._support], y[svm_reg._support], s=180,\n",
" facecolors='#AAA', zorder=-1)\n",
" plt.plot(X, y, \"bo\")\n",
" plt.xlabel(r\"$x_1$\")\n",
" plt.legend(loc=\"upper left\")\n",
" plt.axis(axes)\n",
"\n",
"svm_reg2 = make_pipeline(StandardScaler(),\n",
" LinearSVR(epsilon=1.2, random_state=42))\n",
"svm_reg2.fit(X, y)\n",
"\n",
"svm_reg._support = find_support_vectors(svm_reg, X, y)\n",
"svm_reg2._support = find_support_vectors(svm_reg2, X, y)\n",
"\n",
"eps_x1 = 1\n",
"eps_y_pred = svm_reg2.predict([[eps_x1]])\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(9, 4), sharey=True)\n",
"plt.sca(axes[0])\n",
"plot_svm_regression(svm_reg, X, y, [0, 2, 3, 11])\n",
"plt.title(fr\"$\\epsilon = {svm_reg[-1].epsilon}$\")\n",
"plt.ylabel(r\"$y$\", rotation=0)\n",
"plt.grid()\n",
"plt.sca(axes[1])\n",
"plot_svm_regression(svm_reg2, X, y, [0, 2, 3, 11])\n",
"plt.title(fr\"$\\epsilon = {svm_reg2[-1].epsilon}$\")\n",
"plt.annotate(\n",
" '', xy=(eps_x1, eps_y_pred), xycoords='data',\n",
" xytext=(eps_x1, eps_y_pred - svm_reg2[-1].epsilon),\n",
" textcoords='data', arrowprops={'arrowstyle': '<->', 'linewidth': 1.5}\n",
" )\n",
"plt.text(0.90, 5.4, r\"$\\epsilon$\", fontsize=16)\n",
"plt.grid()\n",
"save_fig(\"svm_regression_plot\")\n",
"plt.show()"
]
},
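{
"cell_type": "markdown",
"metadata": {},
"source": [
"A sketch of the loss behind this figure, not in the book: `LinearSVR` (with its default `loss=\"epsilon_insensitive\"`) ignores errors smaller than $\\epsilon$ and penalizes larger ones linearly, so each instance costs $\\max(0, |y - \\hat{y}| - \\epsilon)$. The helper `epsilon_insensitive_loss()` below is hypothetical, written only to illustrate this. Every instance with a nonzero loss lies outside the street, so it should also be among the instances flagged by `find_support_vectors()`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: epsilon-insensitive loss (hypothetical helper)\n",
"def epsilon_insensitive_loss(y_true, y_pred, epsilon):\n",
"    # errors within epsilon cost nothing; larger errors grow linearly\n",
"    return np.maximum(0, np.abs(y_true - y_pred) - epsilon)\n",
"\n",
"losses = epsilon_insensitive_loss(y, svm_reg.predict(X), svm_reg[-1].epsilon)\n",
"off_street = np.flatnonzero(losses > 0)  # indices of instances outside the street\n",
"assert set(off_street) <= set(svm_reg._support.ravel())\n",
"len(off_street), losses.mean()"
]
},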
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: this code generates a simple quadratic dataset\n",
"np.random.seed(42)\n",
"m = 50\n",
"X = 2 * np.random.rand(m, 1) - 1\n",
"y = (0.2 + 0.1 * X + 0.5 * X ** 2 + np.random.randn(m, 1) / 10).ravel()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Next code example:**"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.svm import SVR\n",
"\n",
"svm_poly_reg = make_pipeline(StandardScaler(),\n",
" SVR(kernel=\"poly\", degree=2, C=0.01, epsilon=0.1))\n",
"svm_poly_reg.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: this cell generates and saves Figure 5-11\n",
"\n",
"svm_poly_reg2 = make_pipeline(StandardScaler(),\n",
" SVR(kernel=\"poly\", degree=2, C=100))\n",
"svm_poly_reg2.fit(X, y)\n",
"\n",
"svm_poly_reg._support = find_support_vectors(svm_poly_reg, X, y)\n",
"svm_poly_reg2._support = find_support_vectors(svm_poly_reg2, X, y)\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(9, 4), sharey=True)\n",
"plt.sca(axes[0])\n",
"plot_svm_regression(svm_poly_reg, X, y, [-1, 1, 0, 1])\n",
"plt.title(f\"$degree={svm_poly_reg[-1].degree}, \"\n",
" f\"C={svm_poly_reg[-1].C}, \"\n",
" fr\"\\epsilon={svm_poly_reg[-1].epsilon}$\")\n",
"plt.ylabel(r\"$y$\", rotation=0)\n",
"plt.grid()\n",
"\n",
"plt.sca(axes[1])\n",
"plot_svm_regression(svm_poly_reg2, X, y, [-1, 1, 0, 1])\n",
"plt.title(f\"$degree={svm_poly_reg2[-1].degree}, \"\n",
" f\"C={svm_poly_reg2[-1].C}, \"\n",
" fr\"\\epsilon={svm_poly_reg2[-1].epsilon}$\")\n",
"plt.grid()\n",
"save_fig(\"svm_with_polynomial_kernel_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Under the hood"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: this cell generates and saves Figure 5-12\n",
"\n",
"import matplotlib.patches as patches\n",
"\n",
"def plot_2D_decision_function(w, b, ylabel=True, x1_lim=[-3, 3]):\n",
"    x1 = np.linspace(x1_lim[0], x1_lim[1], 200)\n",
"    y = w * x1 + b\n",
"    half_margin = 1 / w\n",
"\n",
"    plt.plot(x1, y, \"b-\", linewidth=2, label=r\"$s = w_1 x_1$\")\n",
"    plt.axhline(y=0, color='k', linewidth=1)\n",
"    plt.axvline(x=0, color='k', linewidth=1)\n",
"    rect = patches.Rectangle((-half_margin, -2), 2 * half_margin, 4,\n",
"                             edgecolor='none', facecolor='gray', alpha=0.2)\n",
"    plt.gca().add_patch(rect)\n",
"    plt.plot([-3, 3], [1, 1], \"k--\", linewidth=1)\n",
"    plt.plot([-3, 3], [-1, -1], \"k--\", linewidth=1)\n",
"    plt.plot(half_margin, 1, \"k.\")\n",
"    plt.plot(-half_margin, -1, \"k.\")\n",
"    plt.axis(x1_lim + [-2, 2])\n",
"    plt.xlabel(r\"$x_1$\")\n",
"    if ylabel:\n",
"        plt.ylabel(\"$s$\", rotation=0, labelpad=5)\n",
"    plt.legend()\n",
"    plt.text(1.02, -1.6, \"Margin\", ha=\"left\", va=\"center\",\n",
"             color=\"k\", fontsize=14)\n",
"    plt.annotate(\n",
"        '', xy=(-half_margin, -1.6), xytext=(half_margin, -1.6),\n",
"        arrowprops={'ec': 'k', 'arrowstyle': '<->', 'linewidth': 1.5}\n",
"    )\n",
"    plt.title(fr\"$w_1 = {w}$\")\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(9, 3.2), sharey=True)\n",
"plt.sca(axes[0])\n",
"plot_2D_decision_function(1, 0)\n",
"plt.grid()\n",
"plt.sca(axes[1])\n",
"plot_2D_decision_function(0.5, 0, ylabel=False)\n",
"plt.grid()\n",
"save_fig(\"small_w_large_margin_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: this cell generates and saves Figure 5-13\n",
"\n",
"s = np.linspace(-2.5, 2.5, 200)\n",
"hinge_pos = np.where(1 - s < 0, 0, 1 - s) # max(0, 1 - s)\n",
"hinge_neg = np.where(1 + s < 0, 0, 1 + s) # max(0, 1 + s)\n",
"\n",
"titles = (r\"Hinge loss = $max(0, 1 - s\\,t)$\", r\"Squared Hinge loss\")\n",
"\n",
"fig, axs = plt.subplots(1, 2, sharey=True, figsize=(8.2, 3))\n",
"\n",
"for ax, loss_pos, loss_neg, title in zip(\n",
" axs, (hinge_pos, hinge_pos ** 2), (hinge_neg, hinge_neg ** 2), titles):\n",
" ax.plot(s, loss_pos, \"g-\", linewidth=2, zorder=10, label=\"$t=1$\")\n",
" ax.plot(s, loss_neg, \"r--\", linewidth=2, zorder=10, label=\"$t=-1$\")\n",
" ax.grid(True, which='both')\n",
" ax.axhline(y=0, color='k')\n",
" ax.axvline(x=0, color='k')\n",
" ax.set_xlabel(r\"$s = \\mathbf{w}^\\intercal \\mathbf{x} + b$\")\n",
" ax.axis([-2.5, 2.5, -0.5, 2.5])\n",
" ax.legend(loc=\"center right\")\n",
" ax.set_title(title)\n",
" ax.set_yticks(np.arange(0, 2.5, 1))\n",
" ax.set_aspect(\"equal\")\n",
2016-09-27 23:31:21 +02:00
"\n",
"save_fig(\"hinge_plot\")\n",
"plt.show()"
]
},
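{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `np.where()` expressions above compute the hinge loss $\\max(0, 1 - s\\,t)$ for $t = 1$ and $t = -1$. As a minimal sketch (not from the book), the same loss can be computed for any $t \\in \\{-1, 1\\}$ with a single `np.maximum()` call, and squaring it gives the squared hinge loss shown in the right panel. The `hinge_loss()` function is a hypothetical helper introduced only for illustration:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def hinge_loss(s, t):\n",
"    # max(0, 1 - s * t), computed element-wise (t must be -1 or +1)\n",
"    return np.maximum(0, 1 - s * t)\n",
"\n",
"s_values = np.array([-2.0, 0.0, 0.5, 1.0, 2.0])\n",
"print(hinge_loss(s_values, 1))       # 3, 1, 0.5, 0, 0\n",
"print(hinge_loss(s_values, -1))      # 0, 1, 1.5, 2, 3\n",
"print(hinge_loss(s_values, 1) ** 2)  # squared hinge loss: 9, 1, 0.25, 0, 0\n",
"\n",
"# Same values as the np.where() expression used for the figure\n",
"s_grid = np.linspace(-2.5, 2.5, 200)\n",
"print(np.allclose(hinge_loss(s_grid, 1), np.where(1 - s_grid < 0, 0, 1 - s_grid)))  # True"
]
},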
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Extra Material"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Linear SVM classifier implementation using Batch Gradient Descent"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"X = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
"y = (iris.target == 2).values"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.base import BaseEstimator\n",
"\n",
"class MyLinearSVC(BaseEstimator):\n",
"    def __init__(self, C=1, eta0=1, eta_d=10000, n_epochs=1000,\n",
"                 random_state=None):\n",
"        self.C = C\n",
"        self.eta0 = eta0\n",
"        self.n_epochs = n_epochs\n",
"        self.random_state = random_state\n",
"        self.eta_d = eta_d\n",
"\n",
"    def eta(self, epoch):\n",
"        return self.eta0 / (epoch + self.eta_d)\n",
"\n",
"    def fit(self, X, y):\n",
"        # Random initialization (random_state=0 is a valid seed, so test against None)\n",
"        if self.random_state is not None:\n",
"            np.random.seed(self.random_state)\n",
"        w = np.random.randn(X.shape[1], 1)  # n feature weights\n",
"        b = 0\n",
"\n",
"        m = len(X)\n",
"        # Map the labels to t = -1 (negative class) or t = +1 (positive class)\n",
"        t = np.array(y, dtype=np.float64).reshape(-1, 1) * 2 - 1\n",
"        X_t = X * t\n",
"        self.Js = []  # cost history, one value per epoch\n",
"\n",
"        # Training\n",
"        for epoch in range(self.n_epochs):\n",
"            # Instances on the wrong side of their margin: t * (w.x + b) < 1\n",
"            support_vectors_idx = (X_t.dot(w) + t * b < 1).ravel()\n",
"            X_t_sv = X_t[support_vectors_idx]\n",
"            t_sv = t[support_vectors_idx]\n",
"\n",
"            # Cost: 1/2 ||w||^2 + C * sum of the hinge losses of these instances\n",
"            J = 1/2 * (w * w).sum() + self.C * ((1 - X_t_sv.dot(w)).sum() - b * t_sv.sum())\n",
"            self.Js.append(J)\n",
"\n",
"            # Subgradients of the cost with regard to w and b\n",
"            w_gradient_vector = w - self.C * X_t_sv.sum(axis=0).reshape(-1, 1)\n",
"            b_derivative = -self.C * t_sv.sum()\n",
"\n",
"            w = w - self.eta(epoch) * w_gradient_vector\n",
"            b = b - self.eta(epoch) * b_derivative\n",
"\n",
"        self.intercept_ = np.array([b])\n",
"        self.coef_ = np.array([w])\n",
"        support_vectors_idx = (X_t.dot(w) + t * b < 1).ravel()\n",
"        self.support_vectors_ = X[support_vectors_idx]\n",
"        return self\n",
"\n",
"    def decision_function(self, X):\n",
"        return X.dot(self.coef_[0]) + self.intercept_[0]\n",
"\n",
"    def predict(self, X):\n",
"        return self.decision_function(X) >= 0"
]
},
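{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `fit()` method above performs gradient descent on the soft margin cost $J(\\mathbf{w}, b) = \\frac{1}{2}\\mathbf{w}^\\intercal\\mathbf{w} + C\\sum_{i=1}^{m}\\max\\left(0, 1 - t^{(i)}(\\mathbf{w}^\\intercal\\mathbf{x}^{(i)} + b)\\right)$. It uses a subgradient in which only the instances with $t^{(i)}(\\mathbf{w}^\\intercal\\mathbf{x}^{(i)} + b) < 1$ contribute. As a sanity check (a sketch, not part of the original notebook), we can compare those formulas with a finite-difference estimate of the gradient on a small random dataset. The helper functions `svm_cost()` and `svm_subgradient()` are hypothetical names introduced only for this check:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def svm_cost(w, b, X, t, C):\n",
"    # J(w, b) = 1/2 ||w||^2 + C * sum of hinge losses (the cost tracked in MyLinearSVC.Js)\n",
"    margins = t * (X @ w + b)\n",
"    return 0.5 * w @ w + C * np.maximum(0, 1 - margins).sum()\n",
"\n",
"def svm_subgradient(w, b, X, t, C):\n",
"    # Same formulas as in MyLinearSVC.fit(): only instances with margin < 1 contribute\n",
"    sv = t * (X @ w + b) < 1\n",
"    grad_w = w - C * (X[sv] * t[sv][:, np.newaxis]).sum(axis=0)\n",
"    grad_b = -C * t[sv].sum()\n",
"    return grad_w, grad_b\n",
"\n",
"# Small random dataset with labels t in {-1, +1} (illustrative only)\n",
"rng = np.random.default_rng(42)\n",
"X_demo = rng.normal(size=(20, 2))\n",
"t_demo = np.where(X_demo[:, 0] + X_demo[:, 1] > 0, 1.0, -1.0)\n",
"w_demo, b_demo, C_demo = rng.normal(size=2), 0.1, 2.0\n",
"\n",
"grad_w, grad_b = svm_subgradient(w_demo, b_demo, X_demo, t_demo, C_demo)\n",
"\n",
"def cost(w, b):\n",
"    return svm_cost(w, b, X_demo, t_demo, C_demo)\n",
"\n",
"# Central finite differences (valid as long as no margin sits exactly at the kink, 1)\n",
"eps = 1e-6\n",
"num_grad_w = np.array([(cost(w_demo + eps * e, b_demo) - cost(w_demo - eps * e, b_demo)) / (2 * eps)\n",
"                       for e in np.eye(2)])\n",
"num_grad_b = (cost(w_demo, b_demo + eps) - cost(w_demo, b_demo - eps)) / (2 * eps)\n",
"\n",
"print(grad_w, num_grad_w)\n",
"print(grad_b, num_grad_b)"
]
},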
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"C = 2\n",
"svm_clf = MyLinearSVC(C=C, eta0=10, eta_d=1000, n_epochs=60000,\n",
"                      random_state=2)\n",
"svm_clf.fit(X, y)\n",
"svm_clf.predict(np.array([[5, 2], [4, 1]]))"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"plt.plot(range(svm_clf.n_epochs), svm_clf.Js)\n",
"plt.axis([0, svm_clf.n_epochs, 0, 100])\n",
"plt.xlabel(\"Epochs\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.grid()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"print(svm_clf.intercept_, svm_clf.coef_)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"svm_clf2 = SVC(kernel=\"linear\", C=C)\n",
"svm_clf2.fit(X, y.ravel())\n",
"print(svm_clf2.intercept_, svm_clf2.coef_)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"yr = y.ravel()\n",
"fig, axes = plt.subplots(ncols=2, figsize=(11, 3.2), sharey=True)\n",
"plt.sca(axes[0])\n",
"plt.plot(X[:, 0][yr==1], X[:, 1][yr==1], \"g^\", label=\"Iris virginica\")\n",
"plt.plot(X[:, 0][yr==0], X[:, 1][yr==0], \"bs\", label=\"Not Iris virginica\")\n",
"plot_svc_decision_boundary(svm_clf, 4, 6)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.title(\"MyLinearSVC\")\n",
"plt.axis([4, 6, 0.8, 2.8])\n",
"plt.legend(loc=\"upper left\")\n",
"plt.grid()\n",
"\n",
"plt.sca(axes[1])\n",
"plt.plot(X[:, 0][yr==1], X[:, 1][yr==1], \"g^\")\n",
"plt.plot(X[:, 0][yr==0], X[:, 1][yr==0], \"bs\")\n",
"plot_svc_decision_boundary(svm_clf2, 4, 6)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.title(\"SVC\")\n",
"plt.axis([4, 6, 0.8, 2.8])\n",
"plt.grid()\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"from sklearn.linear_model import SGDClassifier\n",
"\n",
"sgd_clf = SGDClassifier(loss=\"hinge\", alpha=0.017, max_iter=1000, tol=1e-3,\n",
"                        random_state=42)\n",
"sgd_clf.fit(X, y)\n",
"\n",
"m = len(X)\n",
"t = np.array(y).reshape(-1, 1) * 2 - 1  # -1 if y==0, +1 if y==1\n",
"X_b = np.c_[np.ones((m, 1)), X]  # Add bias input x0=1\n",
"X_b_t = X_b * t\n",
"sgd_theta = np.r_[sgd_clf.intercept_[0], sgd_clf.coef_[0]]\n",
"print(sgd_theta)\n",
"support_vectors_idx = (X_b_t.dot(sgd_theta) < 1).ravel()\n",
"sgd_clf.support_vectors_ = X[support_vectors_idx]\n",
"sgd_clf.C = C\n",
"\n",
"plt.figure(figsize=(5.5,3.2))\n",
"plt.plot(X[:, 0][yr==1], X[:, 1][yr==1], \"g^\")\n",
"plt.plot(X[:, 0][yr==0], X[:, 1][yr==0], \"bs\")\n",
"plot_svc_decision_boundary(sgd_clf, 4, 6)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.title(\"SGDClassifier\")\n",
"plt.axis([4, 6, 0.8, 2.8])\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Exercise solutions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. to 8."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"See appendix A."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 9."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_Exercise: Train a `LinearSVC` on a linearly separable dataset. Then train an `SVC` and a `SGDClassifier` on the same dataset. See if you can get them to produce roughly the same model._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's use the Iris dataset: the Iris Setosa and Iris Versicolor classes are linearly separable."
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import datasets\n",
"\n",
"iris = datasets.load_iris(as_frame=True)\n",
"X = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
"y = iris.target\n",
"\n",
"setosa_or_versicolor = (y == 0) | (y == 1)\n",
"X = X[setosa_or_versicolor]\n",
"y = y[setosa_or_versicolor]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's build and train 3 models:\n",
"* Remember that `LinearSVC` uses `loss=\"squared_hinge\"` by default, so if we want all 3 models to produce similar results, we need to set `loss=\"hinge\"`.\n",
"* Also, the `SVC` class uses an RBF kernel by default, so we need to set `kernel=\"linear\"` to get results similar to the other two models.\n",
"* Lastly, the `SGDClassifier` class does not have a `C` hyperparameter, but it has another regularization hyperparameter called `alpha`, which we can tweak to get results similar to the other two models (its default loss is already `\"hinge\"`)."
]
},
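{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough guide for choosing `alpha` (a sketch, not part of the original solution), we can compare the two objectives. This assumes `SGDClassifier`'s default `penalty=\"l2\"`, whose regularization term is $\\frac{1}{2}\\lVert\\mathbf{w}\\rVert^2$, together with the hinge loss. `SGDClassifier` minimizes the average loss plus `alpha` times the regularization term, whereas `SVC` minimizes the regularization term plus `C` times the summed loss. Here $s^{(i)} = \\mathbf{w}^\\intercal \\mathbf{x}^{(i)} + b$ and $t^{(i)} \\in \\{-1, 1\\}$:\n",
"\n",
"$$\n",
"\\underbrace{\\frac{1}{m}\\sum_{i=1}^{m} \\max\\left(0, 1 - t^{(i)} s^{(i)}\\right) + \\alpha \\cdot \\frac{1}{2}\\lVert\\mathbf{w}\\rVert^2}_{\\text{SGDClassifier}}\n",
"\\qquad\n",
"\\underbrace{\\frac{1}{2}\\lVert\\mathbf{w}\\rVert^2 + C \\sum_{i=1}^{m} \\max\\left(0, 1 - t^{(i)} s^{(i)}\\right)}_{\\text{SVC}}\n",
"$$\n",
"\n",
"Dividing the first objective by $\\alpha$ does not change its minimizer, and it turns the first objective into the second one with $C = \\frac{1}{\\alpha m}$, i.e., $\\alpha = \\frac{1}{C m}$. Here $m = 100$ (50 setosa and 50 versicolor instances), so $C = 5$ would correspond to $\\alpha = 0.002$. This is only a starting point: `SGDClassifier` only approximates the optimum, so some manual tweaking may still be needed."
]
},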
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.svm import SVC, LinearSVC\n",
"from sklearn.linear_model import SGDClassifier\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"C = 5\n",
"alpha = 0.05\n",
"\n",
"scaler = StandardScaler()\n",
"X_scaled = scaler.fit_transform(X)\n",
"\n",
"lin_clf = LinearSVC(loss=\"hinge\", C=C, random_state=42).fit(X_scaled, y)\n",
"svc_clf = SVC(kernel=\"linear\", C=C).fit(X_scaled, y)\n",
"sgd_clf = SGDClassifier(alpha=alpha, random_state=42).fit(X_scaled, y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's plot the decision boundaries of these three models:"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"def compute_decision_boundary(model):\n",
"    # The boundary satisfies coef[0] * x0 + coef[1] * x1 + intercept = 0,\n",
"    # i.e., x1 = w * x0 + b, with the slope w and offset b computed below\n",
"    w = -model.coef_[0, 0] / model.coef_[0, 1]\n",
"    b = -model.intercept_[0] / model.coef_[0, 1]\n",
"    # Two far-apart points on this line (in scaled space), mapped back to the original scale\n",
"    return scaler.inverse_transform([[-10, -10 * w + b], [10, 10 * w + b]])\n",
"\n",
"lin_line = compute_decision_boundary(lin_clf)\n",
"svc_line = compute_decision_boundary(svc_clf)\n",
"sgd_line = compute_decision_boundary(sgd_clf)\n",
"\n",
"# Plot all three decision boundaries\n",
"plt.figure(figsize=(11, 4))\n",
"plt.plot(lin_line[:, 0], lin_line[:, 1], \"k:\", label=\"LinearSVC\")\n",
"plt.plot(svc_line[:, 0], svc_line[:, 1], \"b--\", linewidth=2, label=\"SVC\")\n",
"plt.plot(sgd_line[:, 0], sgd_line[:, 1], \"r-\", label=\"SGDClassifier\")\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"bs\") # label=\"Iris versicolor\"\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"yo\") # label=\"Iris setosa\"\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.legend(loc=\"upper center\")\n",
"plt.axis([0, 5.5, 0, 2])\n",
"plt.grid()\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Close enough!"
]
},
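{
"cell_type": "markdown",
"metadata": {},
"source": [
"The plot compares the three decision boundaries visually. As a complementary sketch (not part of the original solution), we can also compare the models numerically. We check their training accuracy, plus the cosine similarity between the unit normal vectors of their decision boundaries, where 1.0 means the boundaries are parallel. This cell reloads and rescales the data so it can run on its own. The variable names ending in `_demo` are just for illustration:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn import datasets\n",
"from sklearn.linear_model import SGDClassifier\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.svm import SVC, LinearSVC\n",
"\n",
"# Iris setosa vs. Iris versicolor, petal length and width (same data as above)\n",
"iris_demo = datasets.load_iris()\n",
"X_demo = iris_demo.data[:, 2:4]\n",
"y_demo = iris_demo.target\n",
"mask = (y_demo == 0) | (y_demo == 1)\n",
"X_demo, y_demo = X_demo[mask], y_demo[mask]\n",
"X_demo_scaled = StandardScaler().fit_transform(X_demo)\n",
"\n",
"models = {\n",
"    \"LinearSVC\": LinearSVC(loss=\"hinge\", C=5, random_state=42),\n",
"    \"SVC\": SVC(kernel=\"linear\", C=5),\n",
"    \"SGDClassifier\": SGDClassifier(alpha=0.05, random_state=42),\n",
"}\n",
"unit_normals = {}\n",
"for name, model in models.items():\n",
"    model.fit(X_demo_scaled, y_demo)\n",
"    coef = model.coef_[0]\n",
"    unit_normals[name] = coef / np.linalg.norm(coef)  # boundary orientation\n",
"    print(f\"{name}: training accuracy = {model.score(X_demo_scaled, y_demo):.3f}\")\n",
"\n",
"# Cosine similarity with the SVC boundary's normal (1.0 means parallel boundaries)\n",
"for name, normal in unit_normals.items():\n",
"    print(f\"{name}: cosine with SVC = {normal @ unit_normals['SVC']:.4f}\")"
]
},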
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 10."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_Exercise: Train an SVM classifier on the Wine dataset, which you can load using `sklearn.datasets.load_wine()`. This dataset contains the chemical analysis of 178 wine samples produced by 3 different cultivators: the goal is to train a classification model capable of predicting the cultivator based on the wine's chemical analysis. Since SVM classifiers are binary classifiers, you will need to use one-versus-all to classify all 3 classes. What accuracy can you reach?_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, let's fetch the dataset, look at its description, then split it into a training set and a test set:"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import load_wine\n",
"\n",
"wine = load_wine(as_frame=True)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"print(wine.DESCR)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    wine.data, wine.target, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"X_train.head()"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"y_train.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's start simple, with a linear SVM classifier. It will automatically use the One-vs-All (also called One-vs-the-Rest, OvR) strategy, so there's nothing special we need to do to handle multiple classes. Easy, right?"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"lin_clf = LinearSVC(random_state=42)\n",
"lin_clf.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Oh no! It failed to converge. Can you guess why? Do you think we just need to increase the number of training iterations? Let's see:"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"lin_clf = LinearSVC(max_iter=1_000_000, random_state=42)\n",
"lin_clf.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Even with one million iterations, it still did not converge. There must be another problem.\n",
"\n",
"Let's still evaluate this model with `cross_val_score`; it will serve as a baseline:"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import cross_val_score\n",
"\n",
"cross_val_score(lin_clf, X_train, y_train).mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Well, 91% accuracy on this dataset is not great. So, did you guess what the problem is?\n",
"\n",
"That's right: we forgot to scale the features! Always remember to scale the features when using SVMs:"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"lin_clf = make_pipeline(StandardScaler(),\n",
"                        LinearSVC(random_state=42))\n",
"lin_clf.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now it converges without any problem. Let's measure its performance:"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import cross_val_score\n",
"\n",
"cross_val_score(lin_clf, X_train, y_train).mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Nice! We get 97.7% accuracy, which is much better."
]
},
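{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why did scaling make such a difference? As a quick check (a sketch, not part of the original solution), we can compare the standard deviations of the raw features. They differ by orders of magnitude; for example, `proline` values are in the hundreds or even thousands. Without scaling, the optimization problem is poorly conditioned, which slows down convergence. The regularization also penalizes the weights of small-scale features much more than those of large-scale features. After standardization, every feature has a standard deviation of 1:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_wine\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"wine_demo = load_wine(as_frame=True)\n",
"stds = wine_demo.data.std()\n",
"print(stds.sort_values().iloc[[0, -1]])  # smallest and largest standard deviations\n",
"print(\"ratio:\", stds.max() / stds.min())\n",
"\n",
"# After standardization, every feature has a standard deviation of 1\n",
"scaled = StandardScaler().fit_transform(wine_demo.data)\n",
"print(scaled.std(axis=0).round(3))"
]
},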
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's see if a kernelized SVM will do better. We will use a default `SVC` for now:"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"svm_clf = make_pipeline(StandardScaler(), SVC(random_state=42))\n",
"cross_val_score(svm_clf, X_train, y_train).mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That's not better, but perhaps we need to do a bit of hyperparameter tuning:"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import RandomizedSearchCV\n",
"from scipy.stats import reciprocal, uniform\n",
"\n",
"param_distrib = {\n",
"    \"svc__gamma\": reciprocal(0.001, 0.1),\n",
"    \"svc__C\": uniform(1, 10)\n",
"}\n",
"rnd_search_cv = RandomizedSearchCV(svm_clf, param_distrib, n_iter=100, cv=5,\n",
"                                   random_state=42)\n",
"rnd_search_cv.fit(X_train, y_train)\n",
"rnd_search_cv.best_estimator_"
]
},
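{
"cell_type": "markdown",
"metadata": {},
"source": [
"The search above samples `gamma` from `reciprocal(0.001, 0.1)` and `C` from `uniform(1, 10)`. As a quick illustration (a sketch, not part of the original solution), here is how these two SciPy distributions behave. `reciprocal` is log-uniform, so every decade between 0.001 and 0.1 is equally likely, which suits a scale-type hyperparameter like `gamma`. `uniform(loc, scale)` samples from the interval [loc, loc + scale], so `uniform(1, 10)` actually covers [1, 11]:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy.stats import reciprocal, uniform\n",
"\n",
"gamma_samples = reciprocal(0.001, 0.1).rvs(size=10_000, random_state=42)\n",
"C_samples = uniform(1, 10).rvs(size=10_000, random_state=42)\n",
"\n",
"# Log-uniform: about half of the samples fall below the geometric midpoint 0.01\n",
"print((gamma_samples < 0.01).mean())\n",
"# uniform(loc=1, scale=10) samples from [1, 11]\n",
"print(C_samples.min(), C_samples.max())"
]
},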
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"rnd_search_cv.best_score_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ah, this looks excellent! Let's select this model. Now we can test it on the test set:"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"rnd_search_cv.score(X_test, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This tuned kernelized SVM performs better than the `LinearSVC` model, but we get a lower score on the test set than we measured using cross-validation. This is quite common: since we did so much hyperparameter tuning, we ended up slightly overfitting the validation folds used during cross-validation. It's tempting to tweak the hyperparameters a bit more until we get a better result on the test set, but this would probably not help, as we would just start overfitting the test set. Anyway, this score is not bad at all, so let's stop here."
]
},
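{
"cell_type": "markdown",
"metadata": {},
"source": [
"Earlier in this exercise, we relied on `LinearSVC` to handle the 3 classes automatically using the One-vs-the-Rest strategy. As a sketch (not part of the original solution), we can check what this means in practice. The scaled `LinearSVC` pipeline outputs one decision score per class, each from a binary \"this class vs. the rest\" classifier, and it predicts the class with the highest score. This cell reloads and splits the data so it can run on its own; the `_tr`/`_te` variable names are just for illustration:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_wine\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.svm import LinearSVC\n",
"\n",
"wine_data = load_wine()\n",
"X_tr, X_te, y_tr, y_te = train_test_split(\n",
"    wine_data.data, wine_data.target, random_state=42)\n",
"\n",
"ovr_clf = make_pipeline(StandardScaler(), LinearSVC(random_state=42))\n",
"ovr_clf.fit(X_tr, y_tr)\n",
"\n",
"# One column of scores per class, each from a binary \"class k vs. the rest\" model\n",
"scores = ovr_clf.decision_function(X_te)\n",
"print(scores.shape)  # (number of test instances, 3)\n",
"\n",
"# The predicted class is the one whose binary classifier gives the highest score\n",
"pred_from_scores = ovr_clf.classes_[scores.argmax(axis=1)]\n",
"print((pred_from_scores == ovr_clf.predict(X_te)).all())"
]
},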
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 11."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_Exercise: Train and fine-tune an SVM regressor on the California housing dataset. You can use the original dataset rather than the tweaked version we used in Chapter 2. The original dataset can be fetched using `sklearn.datasets.fetch_california_housing()`. The labels represent hundreds of thousands of dollars. Since there are over 20,000 instances, SVMs can be slow, so for hyperparameter tuning you should use far fewer instances (e.g., 2,000) to test many more hyperparameter combinations. What is your best model's RMSE?_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's load the dataset:"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import fetch_california_housing\n",
"\n",
"housing = fetch_california_housing()\n",
"X = housing.data\n",
"y = housing.target"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Split it into a training set and a test set:"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,\n",
"                                                    random_state=42)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Don't forget to scale the data! We will do this by placing a `StandardScaler` at the start of each model's pipeline."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's train a simple `LinearSVR` first:"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.svm import LinearSVR\n",
"\n",
"lin_svr = make_pipeline(StandardScaler(), LinearSVR(random_state=42))\n",
"lin_svr.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It did not converge, so let's increase `max_iter`:"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
"lin_svr = make_pipeline(StandardScaler(),\n",
"                        LinearSVR(max_iter=5000, random_state=42))\n",
"lin_svr.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's see how it performs on the training set:"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import mean_squared_error\n",
"\n",
"y_pred = lin_svr.predict(X_train)\n",
"mse = mean_squared_error(y_train, y_pred)\n",
"mse"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's look at the RMSE:"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [],
"source": [
"np.sqrt(mse)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this dataset, the targets represent hundreds of thousands of dollars. The RMSE gives a rough idea of the kind of error you should expect (with a higher weight for large errors): so with this model we can expect errors close to $98,000! Not great. Let's see if we can do better with an RBF kernel. We will use randomized search with cross-validation to find appropriate hyperparameter values for `C` and `gamma`:"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.svm import SVR\n",
"from sklearn.model_selection import RandomizedSearchCV\n",
"from scipy.stats import reciprocal, uniform\n",
"\n",
"svm_reg = make_pipeline(StandardScaler(), SVR())\n",
"\n",
"param_distrib = {\n",
"    \"svr__gamma\": reciprocal(0.001, 0.1),\n",
"    \"svr__C\": uniform(1, 10)\n",
"}\n",
"rnd_search_cv = RandomizedSearchCV(svm_reg, param_distrib,\n",
"                                   n_iter=100, cv=3, random_state=42)\n",
"rnd_search_cv.fit(X_train[:2000], y_train[:2000])"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"rnd_search_cv.best_estimator_"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"-cross_val_score(rnd_search_cv.best_estimator_, X_train, y_train,\n",
"                 scoring=\"neg_root_mean_squared_error\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Looks much better than the linear model. Let's select this model and evaluate it on the test set:"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"y_pred = rnd_search_cv.best_estimator_.predict(X_test)\n",
"rmse = np.sqrt(mean_squared_error(y_test, y_pred))\n",
"rmse"
]
},
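{
"cell_type": "markdown",
"metadata": {},
"source": [
"The RMSE is the square root of the MSE, so it is expressed in the same unit as the targets (hundreds of thousands of dollars here). As a small sketch with made-up values (not from the dataset), taking `np.sqrt()` of `mean_squared_error()` gives the same result as the formula written out by hand. It also works with any Scikit-Learn version, whereas the `squared` argument of `mean_squared_error()` was deprecated in Scikit-Learn 1.4:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import mean_squared_error\n",
"\n",
"# Made-up targets and predictions, in units of $100,000 (illustrative only)\n",
"y_true_demo = np.array([1.0, 2.0, 3.0, 4.0])\n",
"y_pred_demo = np.array([1.5, 2.0, 2.0, 4.5])\n",
"\n",
"rmse_demo = np.sqrt(mean_squared_error(y_true_demo, y_pred_demo))\n",
"rmse_by_hand = np.sqrt(np.mean((y_true_demo - y_pred_demo) ** 2))\n",
"print(rmse_demo, rmse_by_hand)  # both equal sqrt(0.375), about 0.612\n",
"print(f\"Typical error: about ${rmse_demo * 100_000:,.0f}\")"
]
},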
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So SVMs worked very well on the Wine dataset, but not so much on the California Housing dataset. In Chapter 2, we found that Random Forests worked better for that dataset."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And that's all for today!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
},
"nav_menu": {},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}