{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Support Vector Machines**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_This notebook is an extra chapter on Support Vector Machines. It also includes exercises and their solutions at the end._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table align=\"left\">\n",
" <td>\n",
" <a href=\"https://colab.research.google.com/github/ageron/handson-ml3/blob/main/05_support_vector_machines.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/ageron/handson-ml3/blob/main/05_support_vector_machines.ipynb\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" /></a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This project requires Python 3.8 or above:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"assert sys.version_info >= (3, 8)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It also requires Scikit-Learn ≥ 1.0.1:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import sklearn\n",
"from packaging import version\n",
"\n",
"assert version.parse(sklearn.__version__) >= version.parse(\"1.0.1\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we did in previous chapters, let's define the default font sizes to make the figures prettier:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib as mpl\n",
"import matplotlib.pyplot as plt\n",
"\n",
"mpl.rc('font', size=12)\n",
"mpl.rc('axes', labelsize=14, titlesize=14)\n",
"mpl.rc('legend', fontsize=14)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And let's create the `images/svm` folder (if it doesn't already exist), and define the `save_fig()` function, which is used throughout this notebook to save the figures in high resolution for the book:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"IMAGES_PATH = Path() / \"images\" / \"svm\"\n",
"IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
"    path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
"    if tight_layout:\n",
"        plt.tight_layout()\n",
"    plt.savefig(path, format=fig_extension, dpi=resolution)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Linear SVM Classification"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The book starts with a few figures before the first code example, so the next three code cells generate and save these figures. You can skip them if you want."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: this cell generates and saves Figure 5-1\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from sklearn.svm import SVC\n",
"from sklearn import datasets\n",
"\n",
"iris = datasets.load_iris(as_frame=True)\n",
"X = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
"y = iris.target\n",
"\n",
"setosa_or_versicolor = (y == 0) | (y == 1)\n",
"X = X[setosa_or_versicolor]\n",
"y = y[setosa_or_versicolor]\n",
"\n",
"# SVM Classifier model\n",
"svm_clf = SVC(kernel=\"linear\", C=float(\"inf\"))\n",
"svm_clf.fit(X, y)\n",
"\n",
"# Bad models\n",
"x0 = np.linspace(0, 5.5, 200)\n",
"pred_1 = 5 * x0 - 20\n",
"pred_2 = x0 - 1.8\n",
"pred_3 = 0.1 * x0 + 0.5\n",
"\n",
"def plot_svc_decision_boundary(svm_clf, xmin, xmax):\n",
"    w = svm_clf.coef_[0]\n",
"    b = svm_clf.intercept_[0]\n",
"\n",
"    # At the decision boundary, w0*x0 + w1*x1 + b = 0\n",
"    # => x1 = -w0/w1 * x0 - b/w1\n",
"    x0 = np.linspace(xmin, xmax, 200)\n",
"    decision_boundary = -w[0] / w[1] * x0 - b / w[1]\n",
"\n",
"    # The margin edges (dashed lines) are where w0*x0 + w1*x1 + b = +1 or -1,\n",
"    # i.e., offset vertically by +1/w1 and -1/w1 from the decision boundary\n",
"    margin = 1/w[1]\n",
"    gutter_up = decision_boundary + margin\n",
"    gutter_down = decision_boundary - margin\n",
"    svs = svm_clf.support_vectors_\n",
"\n",
"    plt.plot(x0, decision_boundary, \"k-\", linewidth=2, zorder=-2)\n",
"    plt.plot(x0, gutter_up, \"k--\", linewidth=2, zorder=-2)\n",
"    plt.plot(x0, gutter_down, \"k--\", linewidth=2, zorder=-2)\n",
"    plt.scatter(svs[:, 0], svs[:, 1], s=180, facecolors='#AAA',\n",
"                zorder=-1)\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(10,2.7), sharey=True)\n",
"\n",
"plt.sca(axes[0])\n",
"plt.plot(x0, pred_1, \"g--\", linewidth=2)\n",
"plt.plot(x0, pred_2, \"m-\", linewidth=2)\n",
"plt.plot(x0, pred_3, \"r-\", linewidth=2)\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"bs\", label=\"Iris versicolor\")\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"yo\", label=\"Iris setosa\")\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.legend(loc=\"upper left\")\n",
"plt.axis([0, 5.5, 0, 2])\n",
"plt.gca().set_aspect(\"equal\")\n",
"plt.grid()\n",
"\n",
"plt.sca(axes[1])\n",
"plot_svc_decision_boundary(svm_clf, 0, 5.5)\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"bs\")\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"yo\")\n",
"plt.xlabel(\"Petal length\")\n",
"plt.axis([0, 5.5, 0, 2])\n",
"plt.gca().set_aspect(\"equal\")\n",
"plt.grid()\n",
"\n",
"save_fig(\"large_margin_classification_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: this cell generates and saves Figure 5-2\n",
"\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"Xs = np.array([[1, 50], [5, 20], [3, 80], [5, 60]]).astype(np.float64)\n",
"ys = np.array([0, 0, 1, 1])\n",
"svm_clf = SVC(kernel=\"linear\", C=100).fit(Xs, ys)\n",
"\n",
"scaler = StandardScaler()\n",
"X_scaled = scaler.fit_transform(Xs)\n",
"svm_clf_scaled = SVC(kernel=\"linear\", C=100).fit(X_scaled, ys)\n",
"\n",
"plt.figure(figsize=(9,2.7))\n",
"plt.subplot(121)\n",
"plt.plot(Xs[:, 0][ys==1], Xs[:, 1][ys==1], \"bo\")\n",
"plt.plot(Xs[:, 0][ys==0], Xs[:, 1][ys==0], \"ms\")\n",
"plot_svc_decision_boundary(svm_clf, 0, 6)\n",
"plt.xlabel(\"$x_0$\")\n",
"plt.ylabel(\"$x_1$    \", rotation=0)\n",
"plt.title(\"Unscaled\")\n",
"plt.axis([0, 6, 0, 90])\n",
"plt.grid()\n",
"\n",
"plt.subplot(122)\n",
"plt.plot(X_scaled[:, 0][ys==1], X_scaled[:, 1][ys==1], \"bo\")\n",
"plt.plot(X_scaled[:, 0][ys==0], X_scaled[:, 1][ys==0], \"ms\")\n",
"plot_svc_decision_boundary(svm_clf_scaled, -2, 2)\n",
"plt.xlabel(\"$x'_0$\")\n",
"plt.ylabel(\"$x'_1$ \", rotation=0)\n",
"plt.title(\"Scaled\")\n",
"plt.axis([-2, 2, -2, 2])\n",
"plt.grid()\n",
"\n",
"save_fig(\"sensitivity_to_feature_scales_plot\")\n",
"plt.show()"
]
},
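{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (not in the book), the next cell verifies that `StandardScaler` simply subtracts each feature's mean and divides by its standard deviation. It uses the population standard deviation, i.e., `ddof=0`. The cell reuses the four toy points `Xs` from the previous cell and redefines them so that it runs on its own."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# same toy data as in the previous cell, redefined to keep this cell self-contained\n",
"Xs = np.array([[1, 50], [5, 20], [3, 80], [5, 60]], dtype=np.float64)\n",
"\n",
"X_scaled = StandardScaler().fit_transform(Xs)\n",
"X_manual = (Xs - Xs.mean(axis=0)) / Xs.std(axis=0)  # np.std() uses ddof=0 by default\n",
"\n",
"assert np.allclose(X_scaled, X_manual)\n",
"X_manual.std(axis=0)  # each feature now has a unit standard deviation"
]
},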
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Soft Margin Classification"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: this cell generates and saves Figure 5-3\n",
"\n",
"X_outliers = np.array([[3.4, 1.3], [3.2, 0.8]])\n",
"y_outliers = np.array([0, 0])\n",
"Xo1 = np.concatenate([X, X_outliers[:1]], axis=0)\n",
"yo1 = np.concatenate([y, y_outliers[:1]], axis=0)\n",
"Xo2 = np.concatenate([X, X_outliers[1:]], axis=0)\n",
"yo2 = np.concatenate([y, y_outliers[1:]], axis=0)\n",
"\n",
"svm_clf2 = SVC(kernel=\"linear\", C=10**9)\n",
"svm_clf2.fit(Xo2, yo2)\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(10,2.7), sharey=True)\n",
"\n",
"plt.sca(axes[0])\n",
"plt.plot(Xo1[:, 0][yo1==1], Xo1[:, 1][yo1==1], \"bs\")\n",
"plt.plot(Xo1[:, 0][yo1==0], Xo1[:, 1][yo1==0], \"yo\")\n",
"plt.text(0.3, 1.0, \"Impossible!\", color=\"red\", fontsize=18)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.annotate(\"Outlier\",\n",
"             xy=(X_outliers[0][0], X_outliers[0][1]),\n",
"             xytext=(2.5, 1.7),\n",
"             ha=\"center\",\n",
"             arrowprops=dict(facecolor='black', shrink=0.1),\n",
"             fontsize=14,\n",
"            )\n",
"plt.axis([0, 5.5, 0, 2])\n",
"plt.grid()\n",
"\n",
"plt.sca(axes[1])\n",
"plt.plot(Xo2[:, 0][yo2==1], Xo2[:, 1][yo2==1], \"bs\")\n",
"plt.plot(Xo2[:, 0][yo2==0], Xo2[:, 1][yo2==0], \"yo\")\n",
"plot_svc_decision_boundary(svm_clf2, 0, 5.5)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.xlabel(\"Petal length\")\n",
"plt.annotate(\"Outlier\",\n",
"             xy=(X_outliers[1][0], X_outliers[1][1]),\n",
"             xytext=(3.2, 0.08),\n",
"             ha=\"center\",\n",
"             arrowprops=dict(facecolor='black', shrink=0.1),\n",
"             fontsize=14,\n",
"            )\n",
"plt.axis([0, 5.5, 0, 2])\n",
"plt.grid()\n",
"\n",
"save_fig(\"sensitivity_to_outliers_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.svm import LinearSVC\n",
"\n",
"iris = load_iris(as_frame=True)\n",
"X = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
"y = (iris.target == 2) # Iris virginica\n",
"\n",
"svm_clf = make_pipeline(StandardScaler(),\n",
"                        LinearSVC(C=1, random_state=42))\n",
"svm_clf.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"X_new = [[5.5, 1.7], [5.0, 1.5]]\n",
"svm_clf.predict(X_new)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"svm_clf.decision_function(X_new)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: this cell generates and saves Figure 5-4\n",
"\n",
"scaler = StandardScaler()\n",
"svm_clf1 = LinearSVC(C=1, max_iter=10_000, random_state=42)\n",
"svm_clf2 = LinearSVC(C=100, max_iter=10_000, random_state=42)\n",
"\n",
"scaled_svm_clf1 = make_pipeline(scaler, svm_clf1)\n",
"scaled_svm_clf2 = make_pipeline(scaler, svm_clf2)\n",
"\n",
"scaled_svm_clf1.fit(X, y)\n",
"scaled_svm_clf2.fit(X, y)\n",
"\n",
"# Convert to unscaled parameters\n",
"b1 = svm_clf1.decision_function([-scaler.mean_ / scaler.scale_])\n",
"b2 = svm_clf2.decision_function([-scaler.mean_ / scaler.scale_])\n",
"w1 = svm_clf1.coef_[0] / scaler.scale_\n",
"w2 = svm_clf2.coef_[0] / scaler.scale_\n",
"svm_clf1.intercept_ = np.array([b1])\n",
"svm_clf2.intercept_ = np.array([b2])\n",
"svm_clf1.coef_ = np.array([w1])\n",
"svm_clf2.coef_ = np.array([w2])\n",
"\n",
"# Find support vectors (LinearSVC does not do this automatically)\n",
"t = y * 2 - 1\n",
"support_vectors_idx1 = (t * (X.dot(w1) + b1) < 1).ravel()\n",
"support_vectors_idx2 = (t * (X.dot(w2) + b2) < 1).ravel()\n",
"svm_clf1.support_vectors_ = X[support_vectors_idx1]\n",
"svm_clf2.support_vectors_ = X[support_vectors_idx2]\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(10,2.7), sharey=True)\n",
"\n",
"plt.sca(axes[0])\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"g^\", label=\"Iris virginica\")\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"bs\", label=\"Iris versicolor\")\n",
"plot_svc_decision_boundary(svm_clf1, 4, 5.9)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.legend(loc=\"upper left\")\n",
"plt.title(f\"$C = {svm_clf1.C}$\")\n",
"plt.axis([4, 5.9, 0.8, 2.8])\n",
"plt.grid()\n",
"\n",
"plt.sca(axes[1])\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"g^\")\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"bs\")\n",
"plot_svc_decision_boundary(svm_clf2, 4, 5.99)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.title(f\"$C = {svm_clf2.C}$\")\n",
"plt.axis([4, 5.9, 0.8, 2.8])\n",
"plt.grid()\n",
"\n",
"save_fig(\"regularization_plot\")\n",
"plt.show()"
]
},
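{
"cell_type": "markdown",
"metadata": {},
"source": [
"The previous cell converts the parameters learned on the standardized features back to the original feature space. Here is why this works (not in the book). If $\\mathbf{x}' = (\\mathbf{x} - \\boldsymbol{\\mu}) / \\boldsymbol{\\sigma}$ (elementwise), then:\n",
"\n",
"$$\\mathbf{w}'^\\intercal \\mathbf{x}' + b' = \\left(\\frac{\\mathbf{w}'}{\\boldsymbol{\\sigma}}\\right)^\\intercal \\mathbf{x} + \\left(b' - \\mathbf{w}'^\\intercal \\frac{\\boldsymbol{\\mu}}{\\boldsymbol{\\sigma}}\\right)$$\n",
"\n",
"So the unscaled weights are $\\mathbf{w}' / \\boldsymbol{\\sigma}$. The unscaled bias is the scaled decision function evaluated at $\\mathbf{x}' = -\\boldsymbol{\\mu} / \\boldsymbol{\\sigma}$ (i.e., at $\\mathbf{x} = \\mathbf{0}$), which is what `decision_function([-scaler.mean_ / scaler.scale_])` computes. The next cell checks this numerically on a freshly trained `LinearSVC`. It uses new variable names (`X_iris`, `scaler_demo`, etc.) so that it doesn't overwrite the variables used elsewhere in this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.svm import LinearSVC\n",
"\n",
"iris_data = load_iris(as_frame=True)\n",
"X_iris = iris_data.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
"y_iris = (iris_data.target == 2).to_numpy()  # Iris virginica\n",
"\n",
"scaler_demo = StandardScaler()\n",
"svc_demo = LinearSVC(C=1, max_iter=10_000, random_state=42)\n",
"make_pipeline(scaler_demo, svc_demo).fit(X_iris, y_iris)  # fits both steps in place\n",
"\n",
"# parameters expressed in the original (unscaled) feature space\n",
"w_unscaled = svc_demo.coef_[0] / scaler_demo.scale_\n",
"b_unscaled = (svc_demo.intercept_[0]\n",
"              - svc_demo.coef_[0] @ (scaler_demo.mean_ / scaler_demo.scale_))\n",
"\n",
"# both formulations should give the same decision scores\n",
"scores_scaled = svc_demo.decision_function(scaler_demo.transform(X_iris))\n",
"scores_unscaled = X_iris @ w_unscaled + b_unscaled\n",
"assert np.allclose(scores_scaled, scores_unscaled)"
]
},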
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Nonlinear SVM Classification"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: this cell generates and saves Figure 5-5\n",
"\n",
"X1D = np.linspace(-4, 4, 9).reshape(-1, 1)\n",
"X2D = np.c_[X1D, X1D**2]\n",
"y = np.array([0, 0, 1, 1, 1, 1, 1, 0, 0])\n",
"\n",
"plt.figure(figsize=(10, 3))\n",
"\n",
"plt.subplot(121)\n",
"plt.grid(True, which='both')\n",
"plt.axhline(y=0, color='k')\n",
"plt.plot(X1D[:, 0][y==0], np.zeros(4), \"bs\")\n",
"plt.plot(X1D[:, 0][y==1], np.zeros(5), \"g^\")\n",
"plt.gca().get_yaxis().set_ticks([])\n",
"plt.xlabel(r\"$x_1$\")\n",
"plt.axis([-4.5, 4.5, -0.2, 0.2])\n",
"\n",
"plt.subplot(122)\n",
"plt.grid(True, which='both')\n",
"plt.axhline(y=0, color='k')\n",
"plt.axvline(x=0, color='k')\n",
"plt.plot(X2D[:, 0][y==0], X2D[:, 1][y==0], \"bs\")\n",
"plt.plot(X2D[:, 0][y==1], X2D[:, 1][y==1], \"g^\")\n",
"plt.xlabel(r\"$x_1$\")\n",
"plt.ylabel(r\"$x_2$  \", rotation=0)\n",
"plt.gca().get_yaxis().set_ticks([0, 4, 8, 12, 16])\n",
"plt.plot([-4.5, 4.5], [6.5, 6.5], \"r--\", linewidth=3)\n",
"plt.axis([-4.5, 4.5, -1, 17])\n",
"\n",
"plt.subplots_adjust(right=1)\n",
"\n",
"save_fig(\"higher_dimensions_plot\", tight_layout=False)\n",
"plt.show()"
]
},
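{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a quick numerical check of what Figure 5-5 illustrates (not in the book). On the 1D dataset, no linear classifier can reach 100% training accuracy, because the positive instances sit between two groups of negative instances. After adding the feature $x_2 = (x_1)^2$, however, a linear SVM separates the two classes perfectly. The large `C` value is an arbitrary choice that makes the soft margin SVM behave almost like a hard margin one."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.svm import SVC\n",
"\n",
"X1D = np.linspace(-4, 4, 9).reshape(-1, 1)\n",
"X2D = np.c_[X1D, X1D**2]  # add the squared feature x2 = x1**2\n",
"y = np.array([0, 0, 1, 1, 1, 1, 1, 0, 0])\n",
"\n",
"# a large C (arbitrary choice) approximates a hard margin SVM\n",
"clf_1d = SVC(kernel=\"linear\", C=1000).fit(X1D, y)\n",
"clf_2d = SVC(kernel=\"linear\", C=1000).fit(X2D, y)\n",
"\n",
"print(\"1D training accuracy:\", clf_1d.score(X1D, y))\n",
"print(\"2D training accuracy:\", clf_2d.score(X2D, y))\n",
"assert clf_1d.score(X1D, y) < 1.0\n",
"assert clf_2d.score(X2D, y) == 1.0"
]
},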
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import make_moons\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"\n",
"X, y = make_moons(n_samples=100, noise=0.15, random_state=42)\n",
"\n",
"polynomial_svm_clf = make_pipeline(\n",
"    PolynomialFeatures(degree=3),\n",
"    StandardScaler(),\n",
"    LinearSVC(C=10, max_iter=10_000, random_state=42)\n",
")\n",
"polynomial_svm_clf.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: this cell generates and saves Figure 5-6\n",
"\n",
"def plot_dataset(X, y, axes):\n",
"    plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"bs\")\n",
"    plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"g^\")\n",
"    plt.axis(axes)\n",
"    plt.grid(True, which='both')\n",
"    plt.xlabel(r\"$x_1$\")\n",
"    plt.ylabel(r\"$x_2$\", rotation=0)\n",
"\n",
"def plot_predictions(clf, axes):\n",
"    x0s = np.linspace(axes[0], axes[1], 100)\n",
"    x1s = np.linspace(axes[2], axes[3], 100)\n",
"    x0, x1 = np.meshgrid(x0s, x1s)\n",
"    X = np.c_[x0.ravel(), x1.ravel()]\n",
"    y_pred = clf.predict(X).reshape(x0.shape)\n",
"    y_decision = clf.decision_function(X).reshape(x0.shape)\n",
"    plt.contourf(x0, x1, y_pred, cmap=plt.cm.brg, alpha=0.2)\n",
"    plt.contourf(x0, x1, y_decision, cmap=plt.cm.brg, alpha=0.1)\n",
"\n",
"plot_predictions(polynomial_svm_clf, [-1.5, 2.5, -1, 1.5])\n",
"plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])\n",
"\n",
"save_fig(\"moons_polynomial_svc_plot\")\n",
"plt.show()"
]
},
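{
"cell_type": "markdown",
"metadata": {},
"source": [
"Adding polynomial features works well on this small dataset, but the number of features grows combinatorially with the degree. With $n$ input features and degree $d$, `PolynomialFeatures` (with its default `include_bias=True`) outputs $\\binom{n + d}{d}$ features. The next cell (not in the book) checks this formula for a few values of $n$ and $d$. This explosion is what makes the kernel trick, used in the next section, so attractive."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from math import comb\n",
"\n",
"import numpy as np\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"\n",
"for n_features, degree in [(2, 3), (2, 10), (10, 3), (50, 3)]:\n",
"    X_dummy = np.zeros((1, n_features))  # the values don't matter, only the shape\n",
"    n_outputs = PolynomialFeatures(degree=degree).fit_transform(X_dummy).shape[1]\n",
"    assert n_outputs == comb(n_features + degree, degree)\n",
"    print(f\"n={n_features}, d={degree}: {n_outputs} features\")"
]
},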
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Polynomial Kernel"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.svm import SVC\n",
"\n",
"poly_kernel_svm_clf = make_pipeline(StandardScaler(),\n",
"                                    SVC(kernel=\"poly\", degree=3, coef0=1, C=5))\n",
"poly_kernel_svm_clf.fit(X, y)"
]
},
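{
"cell_type": "markdown",
"metadata": {},
"source": [
"The polynomial kernel gives the same result as adding polynomial features, without actually adding them. This is the kernel trick. The next cell (not in the book) checks it for a 2nd-degree homogeneous polynomial kernel $K(\\mathbf{a}, \\mathbf{b}) = (\\mathbf{a}^\\intercal \\mathbf{b})^2$ on 2D vectors, using the explicit mapping $\\phi(\\mathbf{x}) = (x_1^2, \\sqrt{2}\\,x_1 x_2, x_2^2)$. Scikit-Learn's `polynomial_kernel()` computes $(\\gamma\\,\\mathbf{a}^\\intercal \\mathbf{b} + r)^d$, so we set `gamma=1` and `coef0=0` to match. Note that `SVC` uses `gamma=\"scale\"` by default, so in general its polynomial kernel values do not have this plain form."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics.pairwise import polynomial_kernel\n",
"\n",
"def phi(X):\n",
"    # explicit 2nd-degree feature map for 2D inputs (no bias term)\n",
"    x1, x2 = X[:, 0], X[:, 1]\n",
"    return np.c_[x1**2, np.sqrt(2) * x1 * x2, x2**2]\n",
"\n",
"rng = np.random.default_rng(42)\n",
"A = rng.normal(size=(5, 2))\n",
"B = rng.normal(size=(4, 2))\n",
"\n",
"K_explicit = phi(A) @ phi(B).T  # dot products in the 3D feature space\n",
"K_trick = (A @ B.T) ** 2         # computed directly in the original 2D space\n",
"K_sklearn = polynomial_kernel(A, B, degree=2, gamma=1, coef0=0)\n",
"\n",
"assert np.allclose(K_explicit, K_trick)\n",
"assert np.allclose(K_trick, K_sklearn)"
]
},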
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: this cell generates and saves Figure 5-7\n",
"\n",
"poly100_kernel_svm_clf = make_pipeline(\n",
"    StandardScaler(),\n",
"    SVC(kernel=\"poly\", degree=10, coef0=100, C=5)\n",
")\n",
"poly100_kernel_svm_clf.fit(X, y)\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(10.5, 4), sharey=True)\n",
"\n",
"plt.sca(axes[0])\n",
"plot_predictions(poly_kernel_svm_clf, [-1.5, 2.45, -1, 1.5])\n",
"plot_dataset(X, y, [-1.5, 2.4, -1, 1.5])\n",
"plt.title(r\"$d=3, r=1, C=5$\")\n",
"\n",
"plt.sca(axes[1])\n",
"plot_predictions(poly100_kernel_svm_clf, [-1.5, 2.45, -1, 1.5])\n",
"plot_dataset(X, y, [-1.5, 2.4, -1, 1.5])\n",
"plt.title(r\"$d=10, r=100, C=5$\")\n",
"plt.ylabel(\"\")\n",
"\n",
"save_fig(\"moons_kernelized_polynomial_svc_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Similarity Features"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# not in the book: this cell generates and saves Figure 5-8\n",
"\n",
"def gaussian_rbf(x, landmark, gamma):\n",
"    return np.exp(-gamma * np.linalg.norm(x - landmark, axis=1)**2)\n",
"\n",
"gamma = 0.3\n",
"\n",
"x1s = np.linspace(-4.5, 4.5, 200).reshape(-1, 1)\n",
"x2s = gaussian_rbf(x1s, -2, gamma)\n",
"x3s = gaussian_rbf(x1s, 1, gamma)\n",
"\n",
"XK = np.c_[gaussian_rbf(X1D, -2, gamma), gaussian_rbf(X1D, 1, gamma)]\n",
"yk = np.array([0, 0, 1, 1, 1, 1, 1, 0, 0])\n",
"\n",
"plt.figure(figsize=(10.5, 4))\n",
"\n",
"plt.subplot(121)\n",
"plt.grid(True, which='both')\n",
"plt.axhline(y=0, color='k')\n",
"plt.scatter(x=[-2, 1], y=[0, 0], s=150, alpha=0.5, c=\"red\")\n",
"plt.plot(X1D[:, 0][yk==0], np.zeros(4), \"bs\")\n",
"plt.plot(X1D[:, 0][yk==1], np.zeros(5), \"g^\")\n",
"plt.plot(x1s, x2s, \"g--\")\n",
"plt.plot(x1s, x3s, \"b:\")\n",
"plt.gca().get_yaxis().set_ticks([0, 0.25, 0.5, 0.75, 1])\n",
"plt.xlabel(r\"$x_1$\")\n",
"plt.ylabel(r\"Similarity\")\n",
"plt.annotate(\n",
"    r'$\\mathbf{x}$',\n",
"    xy=(X1D[3, 0], 0),\n",
"    xytext=(-0.5, 0.20),\n",
"    ha=\"center\",\n",
"    arrowprops=dict(facecolor='black', shrink=0.1),\n",
"    fontsize=16,\n",
")\n",
"plt.text(-2, 0.9, \"$x_2$\", ha=\"center\", fontsize=15)\n",
"plt.text(1, 0.9, \"$x_3$\", ha=\"center\", fontsize=15)\n",
"plt.axis([-4.5, 4.5, -0.1, 1.1])\n",
"\n",
"plt.subplot(122)\n",
"plt.grid(True, which='both')\n",
"plt.axhline(y=0, color='k')\n",
"plt.axvline(x=0, color='k')\n",
"plt.plot(XK[:, 0][yk==0], XK[:, 1][yk==0], \"bs\")\n",
"plt.plot(XK[:, 0][yk==1], XK[:, 1][yk==1], \"g^\")\n",
"plt.xlabel(r\"$x_2$\")\n",
"plt.ylabel(r\"$x_3$  \", rotation=0)\n",
"plt.annotate(\n",
"    r'$\\phi\\left(\\mathbf{x}\\right)$',\n",
"    xy=(XK[3, 0], XK[3, 1]),\n",
"    xytext=(0.65, 0.50),\n",
"    ha=\"center\",\n",
"    arrowprops=dict(facecolor='black', shrink=0.1),\n",
"    fontsize=16,\n",
")\n",
"plt.plot([-0.1, 1.1], [0.57, -0.1], \"r--\", linewidth=3)\n",
"plt.axis([-0.1, 1.1, -0.1, 1.1])\n",
"\n",
"plt.subplots_adjust(right=1)\n",
"\n",
"save_fig(\"kernel_method_plot\")\n",
"plt.show()"
]
},
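{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `gaussian_rbf()` function above computes the Gaussian RBF similarity $\\exp(-\\gamma \\lVert \\mathbf{x} - \\boldsymbol{\\ell} \\rVert^2)$ between each instance $\\mathbf{x}$ and a landmark $\\boldsymbol{\\ell}$. Scikit-Learn provides the same similarity measure as `sklearn.metrics.pairwise.rbf_kernel()`, which compares every row of its first argument with every row of its second. The next cell (not in the book) uses it to rebuild the two similarity features of Figure 5-8 (landmarks at $x_1 = -2$ and $x_1 = 1$, with $\\gamma = 0.3$) and checks them against a direct NumPy computation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics.pairwise import rbf_kernel\n",
"\n",
"gamma = 0.3\n",
"X1D = np.linspace(-4, 4, 9).reshape(-1, 1)\n",
"landmarks = np.array([[-2.0], [1.0]])\n",
"\n",
"# one column per landmark: similarity of each instance to that landmark\n",
"XK_sklearn = rbf_kernel(X1D, landmarks, gamma=gamma)\n",
"XK_manual = np.exp(-gamma * (X1D - landmarks.T) ** 2)  # (9, 1) - (1, 2) broadcasts to (9, 2)\n",
"\n",
"assert XK_sklearn.shape == (9, 2)\n",
"assert np.allclose(XK_sklearn, XK_manual)\n",
"XK_sklearn.round(3)"
]
},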
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Gaussian RBF Kernel"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"rbf_kernel_svm_clf = make_pipeline(StandardScaler(),\n",
"                                   SVC(kernel=\"rbf\", gamma=5, C=0.001))\n",
"rbf_kernel_svm_clf.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# not in the book: this cell generates and saves Figure 5-9\n",
"\n",
"from sklearn.svm import SVC\n",
"\n",
"gamma1, gamma2 = 0.1, 5\n",
"C1, C2 = 0.001, 1000\n",
"hyperparams = (gamma1, C1), (gamma1, C2), (gamma2, C1), (gamma2, C2)\n",
"\n",
"svm_clfs = []\n",
"for gamma, C in hyperparams:\n",
"    rbf_kernel_svm_clf = make_pipeline(\n",
"        StandardScaler(),\n",
"        SVC(kernel=\"rbf\", gamma=gamma, C=C)\n",
"    )\n",
"    rbf_kernel_svm_clf.fit(X, y)\n",
"    svm_clfs.append(rbf_kernel_svm_clf)\n",
"\n",
"fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10.5, 7), sharex=True, sharey=True)\n",
"\n",
"for i, svm_clf in enumerate(svm_clfs):\n",
"    plt.sca(axes[i // 2, i % 2])\n",
"    plot_predictions(svm_clf, [-1.5, 2.45, -1, 1.5])\n",
"    plot_dataset(X, y, [-1.5, 2.45, -1, 1.5])\n",
"    gamma, C = hyperparams[i]\n",
"    plt.title(fr\"$\\gamma = {gamma}, C = {C}$\")\n",
"    if i in (0, 1):\n",
"        plt.xlabel(\"\")\n",
"    if i in (1, 3):\n",
"        plt.ylabel(\"\")\n",
"\n",
"save_fig(\"moons_rbf_svc_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SVM Regression"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.svm import LinearSVR\n",
"\n",
"# not in the book: these 3 lines generate a simple linear dataset\n",
"np.random.seed(42)\n",
"X = 2 * np.random.rand(50, 1)\n",
"y = 4 + 3 * X[:, 0] + np.random.randn(50)\n",
"\n",
"svm_reg = make_pipeline(StandardScaler(),\n",
"                        LinearSVR(epsilon=0.5, random_state=42))\n",
"svm_reg.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: this cell generates and saves Figure 5-10\n",
"\n",
"def find_support_vectors(svm_reg, X, y):\n",
"    y_pred = svm_reg.predict(X)\n",
"    epsilon = svm_reg[-1].epsilon\n",
"    off_margin = np.abs(y - y_pred) >= epsilon\n",
"    return np.argwhere(off_margin)\n",
"\n",
"def plot_svm_regression(svm_reg, X, y, axes):\n",
"    x1s = np.linspace(axes[0], axes[1], 100).reshape(100, 1)\n",
"    y_pred = svm_reg.predict(x1s)\n",
"    epsilon = svm_reg[-1].epsilon\n",
"    plt.plot(x1s, y_pred, \"k-\", linewidth=2, label=r\"$\\hat{y}$\", zorder=-2)\n",
"    plt.plot(x1s, y_pred + epsilon, \"k--\", zorder=-2)\n",
"    plt.plot(x1s, y_pred - epsilon, \"k--\", zorder=-2)\n",
"    plt.scatter(X[svm_reg._support], y[svm_reg._support], s=180,\n",
"                facecolors='#AAA', zorder=-1)\n",
"    plt.plot(X, y, \"bo\")\n",
"    plt.xlabel(r\"$x_1$\")\n",
"    plt.legend(loc=\"upper left\")\n",
"    plt.axis(axes)\n",
"\n",
"svm_reg2 = make_pipeline(StandardScaler(),\n",
"                         LinearSVR(epsilon=1.2, random_state=42))\n",
"svm_reg2.fit(X, y)\n",
"\n",
"svm_reg._support = find_support_vectors(svm_reg, X, y)\n",
"svm_reg2._support = find_support_vectors(svm_reg2, X, y)\n",
"\n",
"eps_x1 = 1\n",
"eps_y_pred = svm_reg2.predict([[eps_x1]])\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(9, 4), sharey=True)\n",
"plt.sca(axes[0])\n",
"plot_svm_regression(svm_reg, X, y, [0, 2, 3, 11])\n",
"plt.title(fr\"$\\epsilon = {svm_reg[-1].epsilon}$\")\n",
"plt.ylabel(r\"$y$\", rotation=0)\n",
"plt.grid()\n",
"plt.sca(axes[1])\n",
"plot_svm_regression(svm_reg2, X, y, [0, 2, 3, 11])\n",
"plt.title(fr\"$\\epsilon = {svm_reg2[-1].epsilon}$\")\n",
"plt.annotate(\n",
"    '', xy=(eps_x1, eps_y_pred), xycoords='data',\n",
"    xytext=(eps_x1, eps_y_pred - svm_reg2[-1].epsilon),\n",
"    textcoords='data', arrowprops={'arrowstyle': '<->', 'linewidth': 1.5}\n",
")\n",
"plt.text(0.90, 5.4, r\"$\\epsilon$\", fontsize=16)\n",
"plt.grid()\n",
"save_fig(\"svm_regression_plot\")\n",
"plt.show()"
]
},
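{
"cell_type": "markdown",
"metadata": {},
"source": [
"In Figure 5-10, the instances located inside the $\\epsilon$-tube do not affect the model's predictions. That's because `LinearSVR` (with its default `loss=\"epsilon_insensitive\"`) only penalizes the part of each error that exceeds $\\epsilon$, i.e., $\\max(0, |y - \\hat{y}| - \\epsilon)$. The next cell (not in the book) is a minimal NumPy sketch of this loss. `epsilon_insensitive_loss()` is a hypothetical helper written for illustration, not a Scikit-Learn function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def epsilon_insensitive_loss(y_true, y_pred, epsilon):\n",
"    # hypothetical helper: zero inside the epsilon-tube, linear outside of it\n",
"    return np.maximum(0, np.abs(y_true - y_pred) - epsilon)\n",
"\n",
"y_true_demo = np.array([1.0, 2.0, 3.0, 4.0])\n",
"y_pred_demo = np.array([1.2, 2.0, 2.0, 5.5])  # errors: 0.2, 0.0, 1.0, 1.5\n",
"losses = epsilon_insensitive_loss(y_true_demo, y_pred_demo, epsilon=0.5)\n",
"\n",
"assert np.allclose(losses, [0.0, 0.0, 0.5, 1.0])  # only the last two errors exceed epsilon\n",
"losses"
]
},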
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.svm import SVR\n",
"\n",
"# not in the book: these 3 lines generate a simple quadratic dataset\n",
"np.random.seed(42)\n",
"X = 2 * np.random.rand(50, 1) - 1\n",
"y = 0.2 + 0.1 * X[:, 0] + 0.5 * X[:, 0] ** 2 + np.random.randn(50) / 10\n",
"\n",
"svm_poly_reg = make_pipeline(StandardScaler(),\n",
"                             SVR(kernel=\"poly\", degree=2, C=0.01, epsilon=0.1))\n",
"svm_poly_reg.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: this cell generates and saves Figure 5-11\n",
"\n",
"svm_poly_reg2 = make_pipeline(StandardScaler(),\n",
"                              SVR(kernel=\"poly\", degree=2, C=100))\n",
"svm_poly_reg2.fit(X, y)\n",
"\n",
"svm_poly_reg._support = find_support_vectors(svm_poly_reg, X, y)\n",
"svm_poly_reg2._support = find_support_vectors(svm_poly_reg2, X, y)\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(9, 4), sharey=True)\n",
"plt.sca(axes[0])\n",
"plot_svm_regression(svm_poly_reg, X, y, [-1, 1, 0, 1])\n",
"plt.title(f\"$degree={svm_poly_reg[-1].degree}, \"\n",
"          f\"C={svm_poly_reg[-1].C}, \"\n",
"          fr\"\\epsilon={svm_poly_reg[-1].epsilon}$\")\n",
"plt.ylabel(r\"$y$\", rotation=0)\n",
"plt.grid()\n",
"\n",
"plt.sca(axes[1])\n",
"plot_svm_regression(svm_poly_reg2, X, y, [-1, 1, 0, 1])\n",
"plt.title(f\"$degree={svm_poly_reg2[-1].degree}, \"\n",
"          f\"C={svm_poly_reg2[-1].C}, \"\n",
"          fr\"\\epsilon={svm_poly_reg2[-1].epsilon}$\")\n",
"plt.grid()\n",
"save_fig(\"svm_with_polynomial_kernel_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Under the hood"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: this cell generates and saves Figure 5-12\n",
"\n",
"import matplotlib.patches as patches\n",
"\n",
"def plot_2D_decision_function(w, b, ylabel=True, x1_lim=[-3, 3]):\n",
"    x1 = np.linspace(x1_lim[0], x1_lim[1], 200)\n",
"    y = w * x1 + b\n",
"    half_margin = 1 / w\n",
"\n",
"    plt.plot(x1, y, \"b-\", linewidth=2, label=r\"$s = w_1 x_1$\")\n",
"    plt.axhline(y=0, color='k', linewidth=1)\n",
"    plt.axvline(x=0, color='k', linewidth=1)\n",
"    rect = patches.Rectangle((-half_margin, -2), 2 * half_margin, 4,\n",
"                             edgecolor='none', facecolor='gray', alpha=0.2)\n",
"    plt.gca().add_patch(rect)\n",
"    plt.plot([-3, 3], [1, 1], \"k--\", linewidth=1)\n",
"    plt.plot([-3, 3], [-1, -1], \"k--\", linewidth=1)\n",
"    plt.plot(half_margin, 1, \"k.\")\n",
"    plt.plot(-half_margin, -1, \"k.\")\n",
"    plt.axis(x1_lim + [-2, 2])\n",
"    plt.xlabel(r\"$x_1$\")\n",
"    if ylabel:\n",
"        plt.ylabel(\"$s$\", rotation=0, labelpad=5)\n",
"        plt.legend()\n",
"        plt.text(1.02, -1.6, \"Margin\", ha=\"left\", va=\"center\",\n",
"                 color=\"k\", fontsize=14)\n",
"    plt.annotate(\n",
"        '', xy=(-half_margin, -1.6), xytext=(half_margin, -1.6),\n",
"        arrowprops={'ec': 'k', 'arrowstyle': '<->', 'linewidth': 1.5}\n",
"    )\n",
"    plt.title(fr\"$w_1 = {w}$\")\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(9, 3.2), sharey=True)\n",
"plt.sca(axes[0])\n",
"plot_2D_decision_function(1, 0)\n",
"plt.grid()\n",
"plt.sca(axes[1])\n",
"plot_2D_decision_function(0.5, 0, ylabel=False)\n",
"plt.grid()\n",
"save_fig(\"small_w_large_margin_plot\")\n",
"plt.show()"
]
},
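{
"cell_type": "markdown",
"metadata": {},
"source": [
"Figure 5-12 shows that the smaller the weight vector, the larger the margin. The margin boundaries are where the decision function equals $+1$ and $-1$, so they are $2 / \\lVert \\mathbf{w} \\rVert$ apart. The next cell (not in the book) checks this on a hypothetical toy dataset with just two instances, one per class, 2 units apart. The hard margin solution should have $\\lVert \\mathbf{w} \\rVert = 1$, so the margin should be exactly as wide as the gap between the two instances. A very large `C` is used to approximate a hard margin."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.svm import SVC\n",
"\n",
"# hypothetical toy dataset: one instance per class, 2 units apart along x1\n",
"X_toy = np.array([[0.0, 0.0], [2.0, 0.0]])\n",
"y_toy = np.array([0, 1])\n",
"\n",
"svc_toy = SVC(kernel=\"linear\", C=1e6).fit(X_toy, y_toy)  # huge C ~ hard margin\n",
"w_toy = svc_toy.coef_[0]\n",
"margin_width = 2 / np.linalg.norm(w_toy)\n",
"\n",
"print(\"w =\", w_toy.round(3), \" b =\", svc_toy.intercept_.round(3))\n",
"print(\"margin width =\", margin_width.round(3))\n",
"assert np.isclose(margin_width, 2.0, atol=1e-2)"
]
},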
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"# not in the book: this cell generates and saves Figure 5-13\n",
"\n",
"s = np.linspace(-2.5, 2.5, 200)\n",
"hinge_pos = np.where(1 - s < 0, 0, 1 - s)  # max(0, 1 - s)\n",
"hinge_neg = np.where(1 + s < 0, 0, 1 + s)  # max(0, 1 + s)\n",
"\n",
"titles = (r\"Hinge loss = $max(0, 1 - s\\,t)$\", r\"Squared Hinge loss\")\n",
"\n",
"fig, axs = plt.subplots(1, 2, sharey=True, figsize=(8.2, 3))\n",
"\n",
"for ax, loss_pos, loss_neg, title in zip(\n",
"        axs, (hinge_pos, hinge_pos ** 2), (hinge_neg, hinge_neg ** 2), titles):\n",
"    ax.plot(s, loss_pos, \"g-\", linewidth=2, zorder=10, label=\"$t=1$\")\n",
"    ax.plot(s, loss_neg, \"r--\", linewidth=2, zorder=10, label=\"$t=-1$\")\n",
"    ax.grid(True, which='both')\n",
"    ax.axhline(y=0, color='k')\n",
"    ax.axvline(x=0, color='k')\n",
"    ax.set_xlabel(r\"$s = \\mathbf{w}^\\intercal \\mathbf{x} + b$\")\n",
"    ax.axis([-2.5, 2.5, -0.5, 2.5])\n",
"    ax.legend(loc=\"center right\")\n",
"    ax.set_title(title)\n",
"    ax.set_yticks(np.arange(0, 2.5, 1))\n",
"    ax.set_aspect(\"equal\")\n",
"\n",
"save_fig(\"hinge_plot\")\n",
"plt.show()"
]
},
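{
"cell_type": "markdown",
"metadata": {},
"source": [
"The hinge loss of an instance with target $t \\in \\{-1, 1\\}$ and decision score $s$ is $\\max(0, 1 - s\\,t)$. It is zero only when the instance is on the correct side of the decision boundary and off the margin ($s\\,t \\ge 1$). The next cell (not in the book) computes the mean hinge loss by hand on a few made-up scores and compares it with `sklearn.metrics.hinge_loss()`. For binary problems, that function maps the two class labels to $-1$ and $+1$ internally."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import hinge_loss\n",
"\n",
"y_demo = np.array([1, 0, 1, 0])            # class labels\n",
"s_demo = np.array([2.0, -0.5, 0.3, 0.4])   # made-up decision scores\n",
"\n",
"t_demo = y_demo * 2 - 1  # map labels {0, 1} to targets {-1, 1}\n",
"manual_loss = np.maximum(0, 1 - s_demo * t_demo).mean()\n",
"\n",
"assert np.isclose(manual_loss, hinge_loss(y_demo, s_demo))\n",
"manual_loss"
]
},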
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Extra Material"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Linear SVM classifier implementation using Batch Gradient Descent"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"X = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
"y = (iris.target == 2)"
]
},
{
"cell_type": "code",
"execution_count": 27,
2017-12-19 22:40:17 +01:00
"metadata": {},
2016-09-27 23:31:21 +02:00
"outputs": [],
"source": [
"from sklearn.base import BaseEstimator\n",
"\n",
"class MyLinearSVC(BaseEstimator):\n",
" def __init__(self, C=1, eta0=1, eta_d=10000, n_epochs=1000,\n",
" random_state=None):\n",
2016-09-27 23:31:21 +02:00
" self.C = C\n",
" self.eta0 = eta0\n",
" self.n_epochs = n_epochs\n",
" self.random_state = random_state\n",
" self.eta_d = eta_d\n",
"\n",
" def eta(self, epoch):\n",
" return self.eta0 / (epoch + self.eta_d)\n",
" \n",
" def fit(self, X, y):\n",
" # Random initialization\n",
" if self.random_state is not None:\n",
" np.random.seed(self.random_state)\n",
" w = np.random.randn(X.shape[1], 1) # n feature weights\n",
" b = 0\n",
"\n",
" m = len(X)\n",
" t = np.array(y, dtype=np.float64).reshape(-1, 1) * 2 - 1\n",
" X_t = X * t\n",
" self.Js=[]\n",
"\n",
" # Training\n",
" for epoch in range(self.n_epochs):\n",
" support_vectors_idx = (X_t.dot(w) + t * b < 1).ravel()\n",
" X_t_sv = X_t[support_vectors_idx]\n",
" t_sv = t[support_vectors_idx]\n",
"\n",
" J = 1/2 * (w * w).sum() + self.C * ((1 - X_t_sv.dot(w)).sum() - b * t_sv.sum())\n",
" self.Js.append(J)\n",
"\n",
" w_gradient_vector = w - self.C * X_t_sv.sum(axis=0).reshape(-1, 1)\n",
" b_derivative = -self.C * t_sv.sum()\n",
" \n",
" w = w - self.eta(epoch) * w_gradient_vector\n",
" b = b - self.eta(epoch) * b_derivative\n",
" \n",
"\n",
" self.intercept_ = np.array([b])\n",
" self.coef_ = np.array([w])\n",
" support_vectors_idx = (X_t.dot(w) + t * b < 1).ravel()\n",
" self.support_vectors_ = X[support_vectors_idx]\n",
" return self\n",
"\n",
" def decision_function(self, X):\n",
" return X.dot(self.coef_[0]) + self.intercept_[0]\n",
"\n",
" def predict(self, X):\n",
" return self.decision_function(X) >= 0"
]
},
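`MyLinearSVC.fit()` computes the subgradient of the cost J(w, b) = ½wᵀw + C Σ max(0, 1 − tᵢ(wᵀxᵢ + b)). It uses only the instances with tᵢ(wᵀxᵢ + b) < 1. One way to sanity-check that formula is to compare it with central finite differences. The sketch below runs on random data, and the data, `w`, `b`, and `C` are arbitrary. The comparison is only meaningful where no instance lies exactly on the margin, because J is not differentiable there:

```python
import numpy as np

def svm_cost(w, b, X, t, C):
    """Soft-margin linear SVM cost: 0.5 * ||w||^2 + C * sum of hinge losses."""
    margins = t * (X @ w + b)
    return 0.5 * w @ w + C * np.maximum(0, 1 - margins).sum()

def svm_subgradient(w, b, X, t, C):
    """Same formulas as in MyLinearSVC.fit(): only instances with t * s < 1 contribute."""
    sv = t * (X @ w + b) < 1
    grad_w = w - C * (X[sv] * t[sv][:, None]).sum(axis=0)
    grad_b = -C * t[sv].sum()
    return grad_w, grad_b

rng = np.random.default_rng(42)
X = rng.normal(size=(20, 2))
t = np.where(rng.random(20) < 0.5, -1.0, 1.0)
w, b, C = rng.normal(size=2), 0.3, 2.0

grad_w, grad_b = svm_subgradient(w, b, X, t, C)

# Central finite differences along each coordinate of w, then along b
eps = 1e-6
num_grad_w = np.array([
    (svm_cost(w + eps * e, b, X, t, C) - svm_cost(w - eps * e, b, X, t, C)) / (2 * eps)
    for e in np.eye(2)
])
num_grad_b = (svm_cost(w, b + eps, X, t, C) - svm_cost(w, b - eps, X, t, C)) / (2 * eps)

print(grad_w, num_grad_w)  # should agree to several decimal places
print(grad_b, num_grad_b)
```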
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"C = 2\n",
"svm_clf = MyLinearSVC(C=C, eta0=10, eta_d=1000, n_epochs=60000,\n",
" random_state=2)\n",
"svm_clf.fit(X, y)\n",
"svm_clf.predict(np.array([[5, 2], [4, 1]]))"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"plt.plot(range(svm_clf.n_epochs), svm_clf.Js)\n",
"plt.axis([0, svm_clf.n_epochs, 0, 100])\n",
"plt.xlabel(\"Epochs\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.grid()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"print(svm_clf.intercept_, svm_clf.coef_)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"svm_clf2 = SVC(kernel=\"linear\", C=C)\n",
"svm_clf2.fit(X, y.ravel())\n",
"print(svm_clf2.intercept_, svm_clf2.coef_)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"yr = y.ravel()\n",
"fig, axes = plt.subplots(ncols=2, figsize=(11, 3.2), sharey=True)\n",
"plt.sca(axes[0])\n",
"plt.plot(X[:, 0][yr==1], X[:, 1][yr==1], \"g^\", label=\"Iris virginica\")\n",
"plt.plot(X[:, 0][yr==0], X[:, 1][yr==0], \"bs\", label=\"Not Iris virginica\")\n",
"plot_svc_decision_boundary(svm_clf, 4, 6)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.title(\"MyLinearSVC\")\n",
"plt.axis([4, 6, 0.8, 2.8])\n",
"plt.legend(loc=\"upper left\")\n",
"plt.grid()\n",
"\n",
"plt.sca(axes[1])\n",
"plt.plot(X[:, 0][yr==1], X[:, 1][yr==1], \"g^\")\n",
"plt.plot(X[:, 0][yr==0], X[:, 1][yr==0], \"bs\")\n",
"plot_svc_decision_boundary(svm_clf2, 4, 6)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.title(\"SVC\")\n",
"plt.axis([4, 6, 0.8, 2.8])\n",
"plt.grid()\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"from sklearn.linear_model import SGDClassifier\n",
"\n",
"sgd_clf = SGDClassifier(loss=\"hinge\", alpha=0.017, max_iter=1000, tol=1e-3,\n",
" random_state=42)\n",
"sgd_clf.fit(X, y)\n",
"\n",
"m = len(X)\n",
"t = np.array(y).reshape(-1, 1) * 2 - 1 # -1 if y==0, +1 if y==1\n",
"X_b = np.c_[np.ones((m, 1)), X] # Add bias input x0=1\n",
"X_b_t = X_b * t\n",
"sgd_theta = np.r_[sgd_clf.intercept_[0], sgd_clf.coef_[0]]\n",
"print(sgd_theta)\n",
"support_vectors_idx = (X_b_t.dot(sgd_theta) < 1).ravel()\n",
"sgd_clf.support_vectors_ = X[support_vectors_idx]\n",
"sgd_clf.C = C\n",
"\n",
"plt.figure(figsize=(5.5,3.2))\n",
"plt.plot(X[:, 0][yr==1], X[:, 1][yr==1], \"g^\")\n",
"plt.plot(X[:, 0][yr==0], X[:, 1][yr==0], \"bs\")\n",
"plot_svc_decision_boundary(sgd_clf, 4, 6)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.title(\"SGDClassifier\")\n",
"plt.axis([4, 6, 0.8, 2.8])\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Exercise solutions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. to 8."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. The fundamental idea behind Support Vector Machines is to fit the widest possible \"street\" between the classes. In other words, the goal is to have the largest possible margin between the decision boundary that separates the two classes and the training instances. When performing soft margin classification, the SVM searches for a compromise between perfectly separating the two classes and having the widest possible street (i.e., a few instances may end up on the street). Another key idea is to use kernels when training on nonlinear datasets. SVMs can also be tweaked to perform linear and nonlinear regression, as well as novelty detection.\n",
"2. After training an SVM, a _support vector_ is any instance located on the \"street\" (see the previous answer), including its border. The decision boundary is entirely determined by the support vectors. Any instance that is _not_ a support vector (i.e., is off the street) has no influence whatsoever; you could remove them, add more instances, or move them around, and as long as they stay off the street they won't affect the decision boundary. Computing the predictions with a kernelized SVM only involves the support vectors, not the whole training set.\n",
"3. SVMs try to fit the largest possible \"street\" between the classes (see the first answer), so if the training set is not scaled, the SVM will tend to neglect features with a small scale (see Figure 5-2).\n",
"4. You can use the `decision_function()` method to get confidence scores. These scores are proportional to the signed distance between the instance and the decision boundary. However, they cannot be directly converted into an estimation of the class probability. If you set `probability=True` when creating an `SVC`, then at the end of training it will use 5-fold cross-validation to generate out-of-sample scores for the training samples, and it will fit a logistic regression on these scores (a technique called Platt scaling) to map them to estimated probabilities. The `predict_proba()` and `predict_log_proba()` methods will then be available.\n",
"5. All three classes can be used for large-margin linear classification. The `SVC` class also supports the kernel trick, which makes it capable of handling nonlinear tasks. However, this comes at a cost: the `SVC` class does not scale well to datasets with many instances. It does scale well to a large number of features, though. The `LinearSVC` class implements an optimized algorithm for linear SVMs, while `SGDClassifier` uses Stochastic Gradient Descent. Depending on the dataset, `LinearSVC` may be a bit faster than `SGDClassifier`, but not always, and `SGDClassifier` is more flexible, plus it supports incremental learning.\n",
"6. If an SVM classifier trained with an RBF kernel underfits the training set, there might be too much regularization. To decrease it, you need to increase `gamma` or `C` (or both).\n",
"7. A Regression SVM model tries to fit as many instances as possible within a small margin around its predictions. If you add instances within this margin, the model will not be affected at all: it is said to be _ϵ-insensitive_.\n",
"8. The kernel trick is a mathematical technique that makes it possible to train a nonlinear SVM model. The resulting model is equivalent to mapping the inputs to another space using a nonlinear transformation, then training a linear SVM on the resulting high-dimensional inputs. The kernel trick gives the same result without having to transform the inputs at all."
]
},
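Answer 8 can be checked numerically. The sketch below uses NumPy and assumes a 2D input and the second-degree polynomial kernel K(a, b) = (aᵀb)². This kernel is an illustrative choice, and the vectors are arbitrary. The kernel gives the same value as explicitly mapping both vectors with φ(x) = (x₁², √2·x₁x₂, x₂²) and then taking their dot product. The kernel, however, never computes φ:

```python
import numpy as np

def phi(x):
    """Explicit second-degree polynomial mapping of a 2D vector to 3D."""
    x1, x2 = x
    return np.array([x1 ** 2, np.sqrt(2) * x1 * x2, x2 ** 2])

def poly2_kernel(a, b):
    """Second-degree polynomial kernel, computed in the original 2D space."""
    return (a @ b) ** 2

a = np.array([2.0, 3.0])
b = np.array([-1.0, 4.0])
print(phi(a) @ phi(b))     # dot product in the 3D feature space (about 100)
print(poly2_kernel(a, b))  # 100.0, without ever computing phi
```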
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 9."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_Exercise: Train a `LinearSVC` on a linearly separable dataset. Then train an `SVC` and an `SGDClassifier` on the same dataset. See if you can get them to produce roughly the same model._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's use the Iris dataset: the Iris Setosa and Iris Versicolor classes are linearly separable."
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import datasets\n",
"\n",
"iris = datasets.load_iris(as_frame=True)\n",
"X = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
"y = iris.target\n",
"\n",
"setosa_or_versicolor = (y == 0) | (y == 1)\n",
"X = X[setosa_or_versicolor]\n",
"y = y[setosa_or_versicolor]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's build and train 3 models:\n",
"* Remember that `LinearSVC` uses `loss=\"squared_hinge\"` by default, so if we want all 3 models to produce similar results, we need to set `loss=\"hinge\"`.\n",
"* Also, the `SVC` class uses an RBF kernel by default, so we need to set `kernel=\"linear\"` to get results similar to those of the other two models.\n",
"* Lastly, the `SGDClassifier` class (whose default loss is already `\"hinge\"`) does not have a `C` hyperparameter, but it has another regularization hyperparameter called `alpha`, so we can tweak it to get results similar to those of the other two models."
]
},
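Here is one heuristic for a starting value of `alpha`. With the L2 penalty, `SGDClassifier` minimizes the mean hinge loss plus α·½‖w‖². The soft-margin SVM objective is ½‖w‖² + C·Σ hinge. Dividing the SVM objective by C·m shows that both objectives have the same minimizer when α = 1/(C·m). This is only a sketch. It ignores how each class treats the bias term, and SGD only approximates the optimum, so `alpha` may still need some hand-tuning in practice. The check below uses arbitrary random data:

```python
import numpy as np

def svm_objective(w, b, X, t, C):
    """Soft-margin SVM objective: 0.5 * ||w||^2 + C * sum of hinge losses."""
    hinge = np.maximum(0, 1 - t * (X @ w + b))
    return 0.5 * w @ w + C * hinge.sum()

def sgd_style_objective(w, b, X, t, alpha):
    """Mean hinge loss + alpha * 0.5 * ||w||^2 (L2 penalty, bias not penalized)."""
    hinge = np.maximum(0, 1 - t * (X @ w + b))
    return hinge.mean() + alpha * 0.5 * w @ w

rng = np.random.default_rng(0)
m, C = 100, 5                 # 100 instances, as in the setosa/versicolor subset
X = rng.normal(size=(m, 2))   # arbitrary data: only the ratio below matters
t = np.where(rng.random(m) < 0.5, -1.0, 1.0)
alpha = 1 / (C * m)           # heuristic correspondence: 0.002 here

# For any (w, b), the SVM objective is C * m times the SGD-style one,
# so both are minimized by the same parameters.
for _ in range(3):
    w, b = rng.normal(size=2), rng.normal()
    print(svm_objective(w, b, X, t, C) / sgd_style_objective(w, b, X, t, alpha))
```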
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.svm import SVC, LinearSVC\n",
"from sklearn.linear_model import SGDClassifier\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"C = 5\n",
"alpha = 0.05\n",
"\n",
"scaler = StandardScaler()\n",
"X_scaled = scaler.fit_transform(X)\n",
"\n",
"lin_clf = LinearSVC(loss=\"hinge\", C=C, random_state=42).fit(X_scaled, y)\n",
"svc_clf = SVC(kernel=\"linear\", C=C).fit(X_scaled, y)\n",
"sgd_clf = SGDClassifier(alpha=alpha, random_state=42).fit(X_scaled, y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's plot the decision boundaries of these three models:"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"def compute_decision_boundary(model):\n",
" w = -model.coef_[0, 0] / model.coef_[0, 1]\n",
" b = -model.intercept_[0] / model.coef_[0, 1]\n",
" return scaler.inverse_transform([[-10, -10 * w + b], [10, 10 * w + b]])\n",
"\n",
"lin_line = compute_decision_boundary(lin_clf)\n",
"svc_line = compute_decision_boundary(svc_clf)\n",
"sgd_line = compute_decision_boundary(sgd_clf)\n",
"\n",
"# Plot all three decision boundaries\n",
"plt.figure(figsize=(11, 4))\n",
"plt.plot(lin_line[:, 0], lin_line[:, 1], \"k:\", label=\"LinearSVC\")\n",
"plt.plot(svc_line[:, 0], svc_line[:, 1], \"b--\", linewidth=2, label=\"SVC\")\n",
"plt.plot(sgd_line[:, 0], sgd_line[:, 1], \"r-\", label=\"SGDClassifier\")\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"bs\") # label=\"Iris versicolor\"\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"yo\") # label=\"Iris setosa\"\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.legend(loc=\"upper center\")\n",
"plt.axis([0, 5.5, 0, 2])\n",
"plt.grid()\n",
"\n",
"plt.show()"
]
},
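`compute_decision_boundary()` works in the scaled space, where the boundary is w₀z₀ + w₁z₁ + b = 0, that is, z₁ = −(w₀/w₁)z₀ − b/w₁. It then maps two points on that line back with `scaler.inverse_transform()`. Since z = (x − μ)/σ, the same boundary can also be written directly in original units, with weights w/σ and bias b − Σⱼ wⱼμⱼ/σⱼ. Here is a NumPy sketch that checks this. All the numbers are hypothetical:

```python
import numpy as np

# Hypothetical values: weights/bias learned on standardized features,
# and the scaler's per-feature means and standard deviations
w = np.array([1.5, -2.0])
b = 0.4
mu = np.array([2.9, 0.8])
sigma = np.array([1.4, 0.6])

# The same boundary expressed in original (unscaled) units
w_orig = w / sigma
b_orig = b - (w * mu / sigma).sum()

# Two points on the boundary in scaled space, as in compute_decision_boundary()
slope, intercept = -w[0] / w[1], -b / w[1]
z = np.array([[-10, -10 * slope + intercept], [10, 10 * slope + intercept]])
x = z * sigma + mu  # the inverse of (x - mu) / sigma

print(x @ w_orig + b_orig)  # both values are (numerically) zero
```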
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Close enough!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 10."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_Exercise: Train an SVM classifier on the Wine dataset, which you can load using `sklearn.datasets.load_wine()`. This dataset contains the chemical analysis of 178 wine samples produced by 3 different cultivators: the goal is to train a classification model capable of predicting the cultivator based on the wine's chemical analysis. Since SVM classifiers are binary classifiers, you will need to use one-versus-all to classify all 3 classes. What accuracy can you reach?_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, let's fetch the dataset, look at its description, then split it into a training set and a test set:"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import load_wine\n",
"\n",
"wine = load_wine(as_frame=True)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"print(wine.DESCR)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
" wine.data, wine.target, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"X_train.head()"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"y_train.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's start simple, with a linear SVM classifier. It will automatically use the One-vs-All (also called One-vs-the-Rest, OvR) strategy, so there's nothing special we need to do to handle multiple classes. Easy, right?"
]
},
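The One-vs-the-Rest strategy trains one binary classifier per class, separating that class from all the others. Each instance is then assigned the class whose classifier outputs the highest decision score. For a 3-class problem, `LinearSVC.decision_function()` returns one column per class. Here is a NumPy sketch of both steps. It uses made-up labels and scores instead of trained classifiers:

```python
import numpy as np

# Made-up labels for a 3-class problem
y = np.array([0, 2, 1, 1, 0, 2])
classes = np.unique(y)

# Training side: one binary target vector per class (+1 for that class, -1 otherwise)
binary_targets = np.where(y[:, None] == classes[None, :], 1, -1)
print(binary_targets[:, 0].tolist())  # [1, -1, -1, -1, 1, -1]

# Prediction side: made-up decision scores, one column per binary classifier
scores = np.array([
    [1.2, -0.8, -1.5],
    [-0.9, -0.2, 0.7],
    [-1.1, 0.3, -0.4],
])
y_pred = classes[np.argmax(scores, axis=1)]
print(y_pred.tolist())  # [0, 2, 1]
```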
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"lin_clf = LinearSVC(random_state=42)\n",
"lin_clf.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Oh no! It failed to converge. Can you guess why? Do you think we just need to increase the number of training iterations? Let's see:"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"lin_clf = LinearSVC(max_iter=1_000_000, random_state=42)\n",
"lin_clf.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Even with one million iterations, it still did not converge. There must be another problem.\n",
"\n",
"Let's still evaluate this model with `cross_val_score`; it will serve as a baseline:"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import cross_val_score\n",
"\n",
"cross_val_score(lin_clf, X_train, y_train).mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Well, 91% accuracy on this dataset is not great. So, did you guess what the problem is?\n",
"\n",
"That's right, we forgot to scale the features! Always remember to scale the features when using SVMs:"
]
},
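`StandardScaler` subtracts each feature's mean and divides by its standard deviation. Both statistics are learned from the training set during `fit()`, and the pipeline then applies the same statistics to any new data. This puts features with very different scales on an equal footing (compare the columns shown by `X_train.head()`). Here is a NumPy sketch of the same computation on made-up data:

```python
import numpy as np

rng = np.random.default_rng(42)
# Hypothetical data: one feature around 1, another around 750
X_train_demo = np.c_[rng.normal(1.0, 0.2, 100), rng.normal(750.0, 300.0, 100)]
X_test_demo = np.c_[rng.normal(1.0, 0.2, 20), rng.normal(750.0, 300.0, 20)]

# Statistics are estimated on the training set only
mean = X_train_demo.mean(axis=0)
std = X_train_demo.std(axis=0)  # population std (ddof=0), like StandardScaler

X_train_scaled = (X_train_demo - mean) / std
X_test_scaled = (X_test_demo - mean) / std  # reuse the training statistics

print(X_train_scaled.mean(axis=0))  # both close to 0
print(X_train_scaled.std(axis=0))   # both close to 1
```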
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"lin_clf = make_pipeline(StandardScaler(),\n",
" LinearSVC(random_state=42))\n",
"lin_clf.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now it converges without any problem. Let's measure its performance:"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import cross_val_score\n",
"\n",
"cross_val_score(lin_clf, X_train, y_train).mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Nice! We get 97.7% accuracy; that's much better."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's see if a kernelized SVM will do better. We will use a default `SVC` for now:"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"svm_clf = make_pipeline(StandardScaler(), SVC(random_state=42))\n",
"cross_val_score(svm_clf, X_train, y_train).mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That's not better, but perhaps we need to do a bit of hyperparameter tuning:"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import RandomizedSearchCV\n",
"from scipy.stats import reciprocal, uniform\n",
"\n",
"param_distrib = {\n",
" \"svc__gamma\": reciprocal(0.001, 0.1),\n",
" \"svc__C\": uniform(1, 10)\n",
"}\n",
"rnd_search_cv = RandomizedSearchCV(svm_clf, param_distrib, n_iter=100, cv=5,\n",
" random_state=42)\n",
"rnd_search_cv.fit(X_train, y_train)\n",
"rnd_search_cv.best_estimator_"
]
},
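The search samples `gamma` from `reciprocal(0.001, 0.1)`, a log-uniform distribution. Every order of magnitude between 0.001 and 0.1 is equally likely, which suits scale-like hyperparameters. Note also that SciPy's `uniform(loc, scale)` is uniform over [loc, loc + scale], so `uniform(1, 10)` draws `C` between 1 and 11. The NumPy-only sketch below mimics both samplers but does not call SciPy:

```python
import numpy as np

rng = np.random.default_rng(42)
n = 10_000

# Log-uniform ("reciprocal") sampling between a and b: uniform in log space
a, b = 0.001, 0.1
gamma_samples = np.exp(rng.uniform(np.log(a), np.log(b), n))

# SciPy-style uniform(loc, scale): uniform over [loc, loc + scale]
loc, scale = 1, 10
C_samples = loc + scale * rng.random(n)

# Roughly half the gamma samples fall below the geometric mean sqrt(a * b) = 0.01
print((gamma_samples < np.sqrt(a * b)).mean())
print(C_samples.min(), C_samples.max())  # spans roughly 1 to 11, not 1 to 10
```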
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"rnd_search_cv.best_score_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ah, this looks excellent! Let's select this model. Now we can test it on the test set:"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"rnd_search_cv.score(X_test, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This tuned kernelized SVM performs better than the `LinearSVC` model, but we get a lower score on the test set than we measured using cross-validation. This is quite common: since we did so much hyperparameter tuning, we ended up slightly overfitting the validation folds used during cross-validation. It's tempting to tweak the hyperparameters a bit more until we get a better result on the test set, but this would probably not help, as we would just start overfitting the test set. Anyway, this score is not bad at all, so let's stop here."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 11."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_Exercise: Train and fine-tune an SVM regressor on the California housing dataset. You can use the original dataset rather than the tweaked version we used in Chapter 2. The original dataset can be fetched using `sklearn.datasets.fetch_california_housing()`. The labels represent hundreds of thousands of dollars. Since there are over 20,000 instances, SVMs can be slow, so for hyperparameter tuning you should use far fewer instances (e.g., 2,000) to test many more hyperparameter combinations. What is your best model's RMSE?_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's load the dataset:"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import fetch_california_housing\n",
"\n",
"housing = fetch_california_housing()\n",
"X = housing.data\n",
"y = housing.target"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Split it into a training set and a test set:"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,\n",
" random_state=42)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Don't forget to scale the data. We will do this by putting a `StandardScaler` at the start of each pipeline."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's train a simple `LinearSVR` first:"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.svm import LinearSVR\n",
"\n",
"lin_svr = make_pipeline(StandardScaler(), LinearSVR(random_state=42))\n",
"lin_svr.fit(X_train, y_train)"
]
},
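By default, `LinearSVR` is trained with an ε-insensitive loss, max(0, |y − ŷ| − ε), controlled by its `epsilon` hyperparameter. Errors smaller than ε cost nothing, which is why instances inside the margin do not affect the model. Here is a minimal NumPy sketch of that loss. The data and ε = 0.5 are arbitrary:

```python
import numpy as np

def epsilon_insensitive_loss(y_true, y_pred, epsilon):
    """Per-instance loss max(0, |y_true - y_pred| - epsilon)."""
    return np.maximum(0, np.abs(y_true - y_pred) - epsilon)

y_pred = np.array([1.0, 2.0, 3.0, 4.0])
y_true = np.array([1.2, 1.5, 3.0, 5.0])
print(epsilon_insensitive_loss(y_true, y_pred, epsilon=0.5).tolist())  # [0.0, 0.0, 0.0, 0.5]

# An extra instance whose target falls inside the epsilon-tube adds zero loss
extra = epsilon_insensitive_loss(np.array([2.3]), np.array([2.0]), epsilon=0.5)
print(extra.tolist())  # [0.0]
```

`LinearSVR` uses `epsilon=0` by default, in which case this loss reduces to the absolute error.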
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It did not converge, so let's increase `max_iter`:"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"lin_svr = make_pipeline(StandardScaler(),\n",
" LinearSVR(max_iter=5000, random_state=42))\n",
"lin_svr.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's see how it performs on the training set:"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import mean_squared_error\n",
"\n",
"y_pred = lin_svr.predict(X_train)\n",
"mse = mean_squared_error(y_train, y_pred)\n",
"mse"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's look at the RMSE:"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
"np.sqrt(mse)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this dataset, the targets represent hundreds of thousands of dollars. The RMSE gives a rough idea of the kind of error you should expect (with a higher weight for large errors): so with this model we can expect errors close to $98,000! Not great. Let's see if we can do better with an RBF kernel. We will use randomized search with cross-validation to find appropriate values for the `C` and `gamma` hyperparameters:"
]
},
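The RMSE is the square root of the mean squared error, so it is expressed in the same unit as the targets. Here, that unit is hundreds of thousands of dollars. Errors are squared before averaging, so a few large errors raise the RMSE more than they raise the mean absolute error. Here is a small NumPy sketch with made-up predictions:

```python
import numpy as np

# Made-up targets and predictions (in hundreds of thousands of dollars)
y_true = np.array([1.5, 2.0, 3.0, 4.0])
y_pred = np.array([1.6, 1.9, 3.1, 5.0])  # three small errors, one large one

errors = y_pred - y_true
rmse = np.sqrt(np.mean(errors ** 2))
mae = np.mean(np.abs(errors))

print(f"RMSE = {rmse:.3f}, MAE = {mae:.3f}")  # the single large error dominates the RMSE
print(f"RMSE in dollars: about ${rmse * 100_000:,.0f}")
```

This is the same quantity that the RMSE cells in this notebook compute with Scikit-Learn.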
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.svm import SVR\n",
"from sklearn.model_selection import RandomizedSearchCV\n",
"from scipy.stats import reciprocal, uniform\n",
"\n",
"svm_reg = make_pipeline(StandardScaler(), SVR())\n",
"\n",
"param_distrib = {\n",
" \"svr__gamma\": reciprocal(0.001, 0.1),\n",
" \"svr__C\": uniform(1, 10)\n",
"}\n",
"rnd_search_cv = RandomizedSearchCV(svm_reg, param_distrib,\n",
" n_iter=100, cv=3, random_state=42)\n",
"rnd_search_cv.fit(X_train[:2000], y_train[:2000])"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [],
"source": [
"rnd_search_cv.best_estimator_"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"-cross_val_score(rnd_search_cv.best_estimator_, X_train, y_train,\n",
" scoring=\"neg_root_mean_squared_error\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Looks much better than the linear model. Let's select this model and evaluate it on the test set:"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"y_pred = rnd_search_cv.best_estimator_.predict(X_test)\n",
"rmse = np.sqrt(mean_squared_error(y_test, y_pred))\n",
"rmse"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So SVMs worked very well on the Wine dataset, but not so much on the California Housing dataset. In Chapter 2, we found that Random Forests worked better for that dataset."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And that's all for today!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
},
"nav_menu": {},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}