{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"**Chapter 5 – Support Vector Machines**\n",
"\n",
"_This notebook contains all the sample code and solutions to the exercises in chapter 5._"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"First, let's make sure this notebook works well in both Python 2 and 3, import a few common modules, ensure Matplotlib plots figures inline and prepare a function to save the figures:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# To support both python 2 and python 3\n",
"from __future__ import division, print_function, unicode_literals\n",
"\n",
"# Common imports\n",
"import numpy as np\n",
"import os\n",
"\n",
"# to make this notebook's output stable across runs\n",
"np.random.seed(42)\n",
"\n",
"# To plot pretty figures\n",
"%matplotlib inline\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"plt.rcParams['axes.labelsize'] = 14\n",
"plt.rcParams['xtick.labelsize'] = 12\n",
"plt.rcParams['ytick.labelsize'] = 12\n",
"\n",
"# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n",
"CHAPTER_ID = \"svm\"\n",
"\n",
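"def save_fig(fig_id, tight_layout=True):\n",
"    images_dir = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"    if not os.path.isdir(images_dir):\n",
"        os.makedirs(images_dir)  # create the output folder on first use\n",
"    path = os.path.join(images_dir, fig_id + \".png\")\n",
"    print(\"Saving figure\", fig_id)\n",
"    if tight_layout:\n",
"        plt.tight_layout()\n",
"    plt.savefig(path, format='png', dpi=300)"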
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Large margin classification"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"The next few code cells generate the first figures in chapter 5. The first actual code example comes after them:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.svm import SVC\n",
"from sklearn import datasets\n",
"\n",
"iris = datasets.load_iris()\n",
"X = iris[\"data\"][:, (2, 3)] # petal length, petal width\n",
"y = iris[\"target\"]\n",
"\n",
"setosa_or_versicolor = (y == 0) | (y == 1)\n",
"X = X[setosa_or_versicolor]\n",
"y = y[setosa_or_versicolor]\n",
"\n",
"# SVM Classifier model\n",
"svm_clf = SVC(kernel=\"linear\", C=float(\"inf\"))\n",
"svm_clf.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Bad models\n",
"x0 = np.linspace(0, 5.5, 200)\n",
"pred_1 = 5*x0 - 20\n",
"pred_2 = x0 - 1.8\n",
"pred_3 = 0.1 * x0 + 0.5\n",
"\n",
"def plot_svc_decision_boundary(svm_clf, xmin, xmax):\n",
" w = svm_clf.coef_[0]\n",
" b = svm_clf.intercept_[0]\n",
"\n",
" # At the decision boundary, w0*x0 + w1*x1 + b = 0\n",
" # => x1 = -w0/w1 * x0 - b/w1\n",
" x0 = np.linspace(xmin, xmax, 200)\n",
" decision_boundary = -w[0]/w[1] * x0 - b/w[1]\n",
"\n",
" margin = 1/w[1]\n",
" gutter_up = decision_boundary + margin\n",
" gutter_down = decision_boundary - margin\n",
"\n",
" svs = svm_clf.support_vectors_\n",
" plt.scatter(svs[:, 0], svs[:, 1], s=180, facecolors='#FFAAAA')\n",
" plt.plot(x0, decision_boundary, \"k-\", linewidth=2)\n",
" plt.plot(x0, gutter_up, \"k--\", linewidth=2)\n",
" plt.plot(x0, gutter_down, \"k--\", linewidth=2)\n",
"\n",
"plt.figure(figsize=(12,2.7))\n",
"\n",
"plt.subplot(121)\n",
"plt.plot(x0, pred_1, \"g--\", linewidth=2)\n",
"plt.plot(x0, pred_2, \"m-\", linewidth=2)\n",
"plt.plot(x0, pred_3, \"r-\", linewidth=2)\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"bs\", label=\"Iris-Versicolor\")\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"yo\", label=\"Iris-Setosa\")\n",
"plt.xlabel(\"Petal length\", fontsize=14)\n",
"plt.ylabel(\"Petal width\", fontsize=14)\n",
"plt.legend(loc=\"upper left\", fontsize=14)\n",
"plt.axis([0, 5.5, 0, 2])\n",
"\n",
"plt.subplot(122)\n",
"plot_svc_decision_boundary(svm_clf, 0, 5.5)\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"bs\")\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"yo\")\n",
"plt.xlabel(\"Petal length\", fontsize=14)\n",
"plt.axis([0, 5.5, 0, 2])\n",
"\n",
"save_fig(\"large_margin_classification_plot\")\n",
"plt.show()"
]
},
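{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"`plot_svc_decision_boundary()` draws each gutter by shifting the decision boundary vertically by `1/w[1]`, because the gutters are the lines where the decision function $h(\\mathbf{x}) = \\mathbf{w}^T \\mathbf{x} + b$ equals $\\pm 1$. Measured perpendicularly, the street is therefore $2 / \\|\\mathbf{w}\\|$ wide. The following sketch checks this geometry with NumPy on hypothetical parameters `w_demo = [3, 4]` and `b_demo = -5` (illustrative values, not learned from the iris data). The helpers `margin_geometry()` and `signed_distance()` are also written just for this illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def margin_geometry(w):\n",
"    # street width (distance between h = -1 and h = +1) and vertical gutter offset, 2D case\n",
"    w = np.asarray(w, dtype=np.float64)\n",
"    street_width = 2.0 / np.linalg.norm(w)\n",
"    vertical_offset = 1.0 / w[1]  # same shift as `margin = 1/w[1]` in plot_svc_decision_boundary()\n",
"    return street_width, vertical_offset\n",
"\n",
"def signed_distance(x, w, b):\n",
"    # signed Euclidean distance from point x to the hyperplane w.x + b = 0\n",
"    w = np.asarray(w, dtype=np.float64)\n",
"    return (np.dot(w, x) + b) / np.linalg.norm(w)\n",
"\n",
"w_demo, b_demo = np.array([3.0, 4.0]), -5.0  # hypothetical parameters, ||w|| = 5\n",
"width, offset = margin_geometry(w_demo)\n",
"# (2, 0) lies on the h = +1 gutter (3*2 + 4*0 - 5 = 1), so it is half a street away\n",
"print(width, offset, signed_distance(np.array([2.0, 0.0]), w_demo, b_demo))  # 0.4 0.25 0.2"
]
},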
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Sensitivity to feature scales"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"Xs = np.array([[1, 50], [5, 20], [3, 80], [5, 60]]).astype(np.float64)\n",
"ys = np.array([0, 0, 1, 1])\n",
"svm_clf = SVC(kernel=\"linear\", C=100)\n",
"svm_clf.fit(Xs, ys)\n",
"\n",
"plt.figure(figsize=(12,3.2))\n",
"plt.subplot(121)\n",
"plt.plot(Xs[:, 0][ys==1], Xs[:, 1][ys==1], \"bo\")\n",
"plt.plot(Xs[:, 0][ys==0], Xs[:, 1][ys==0], \"ms\")\n",
"plot_svc_decision_boundary(svm_clf, 0, 6)\n",
"plt.xlabel(\"$x_0$\", fontsize=20)\n",
"plt.ylabel(\"$x_1$ \", fontsize=20, rotation=0)\n",
"plt.title(\"Unscaled\", fontsize=16)\n",
"plt.axis([0, 6, 0, 90])\n",
"\n",
"from sklearn.preprocessing import StandardScaler\n",
"scaler = StandardScaler()\n",
"X_scaled = scaler.fit_transform(Xs)\n",
"svm_clf.fit(X_scaled, ys)\n",
"\n",
"plt.subplot(122)\n",
"plt.plot(X_scaled[:, 0][ys==1], X_scaled[:, 1][ys==1], \"bo\")\n",
"plt.plot(X_scaled[:, 0][ys==0], X_scaled[:, 1][ys==0], \"ms\")\n",
"plot_svc_decision_boundary(svm_clf, -2, 2)\n",
"plt.xlabel(\"$x_0$\", fontsize=20)\n",
"plt.title(\"Scaled\", fontsize=16)\n",
"plt.axis([-2, 2, -2, 2])\n",
"\n",
"save_fig(\"sensitivity_to_feature_scales_plot\")\n",
"plt.show()"
]
},
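{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"`StandardScaler` subtracts each feature's mean and divides by its standard deviation, so both features contribute on comparable scales when the SVM looks for the widest street. The sketch below reproduces this by hand with NumPy on the same `Xs` toy data. It uses the population standard deviation (`ddof=0`, NumPy's default), which is the estimator `StandardScaler` documents for `scale_`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"Xs = np.array([[1, 50], [5, 20], [3, 80], [5, 60]]).astype(np.float64)\n",
"\n",
"mean = Xs.mean(axis=0)  # per-feature mean\n",
"std = Xs.std(axis=0)    # per-feature population standard deviation (ddof=0)\n",
"X_manual = (Xs - mean) / std\n",
"\n",
"print(mean)\n",
"print(X_manual.mean(axis=0), X_manual.std(axis=0))  # approximately [0, 0] and [1, 1]"
]
},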
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Sensitivity to outliers"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"X_outliers = np.array([[3.4, 1.3], [3.2, 0.8]])\n",
"y_outliers = np.array([0, 0])\n",
"Xo1 = np.concatenate([X, X_outliers[:1]], axis=0)\n",
"yo1 = np.concatenate([y, y_outliers[:1]], axis=0)\n",
"Xo2 = np.concatenate([X, X_outliers[1:]], axis=0)\n",
"yo2 = np.concatenate([y, y_outliers[1:]], axis=0)\n",
"\n",
"svm_clf2 = SVC(kernel=\"linear\", C=10**9)\n",
"svm_clf2.fit(Xo2, yo2)\n",
"\n",
"plt.figure(figsize=(12,2.7))\n",
"\n",
"plt.subplot(121)\n",
"plt.plot(Xo1[:, 0][yo1==1], Xo1[:, 1][yo1==1], \"bs\")\n",
"plt.plot(Xo1[:, 0][yo1==0], Xo1[:, 1][yo1==0], \"yo\")\n",
"plt.text(0.3, 1.0, \"Impossible!\", fontsize=24, color=\"red\")\n",
"plt.xlabel(\"Petal length\", fontsize=14)\n",
"plt.ylabel(\"Petal width\", fontsize=14)\n",
"plt.annotate(\"Outlier\",\n",
" xy=(X_outliers[0][0], X_outliers[0][1]),\n",
" xytext=(2.5, 1.7),\n",
" ha=\"center\",\n",
" arrowprops=dict(facecolor='black', shrink=0.1),\n",
" fontsize=16,\n",
" )\n",
"plt.axis([0, 5.5, 0, 2])\n",
"\n",
"plt.subplot(122)\n",
"plt.plot(Xo2[:, 0][yo2==1], Xo2[:, 1][yo2==1], \"bs\")\n",
"plt.plot(Xo2[:, 0][yo2==0], Xo2[:, 1][yo2==0], \"yo\")\n",
"plot_svc_decision_boundary(svm_clf2, 0, 5.5)\n",
"plt.xlabel(\"Petal length\", fontsize=14)\n",
"plt.annotate(\"Outlier\",\n",
" xy=(X_outliers[1][0], X_outliers[1][1]),\n",
" xytext=(3.2, 0.08),\n",
" ha=\"center\",\n",
" arrowprops=dict(facecolor='black', shrink=0.1),\n",
" fontsize=16,\n",
" )\n",
"plt.axis([0, 5.5, 0, 2])\n",
"\n",
"save_fig(\"sensitivity_to_outliers_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Large margin *vs* margin violations"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"This is the first code example in chapter 5:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn import datasets\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.svm import LinearSVC\n",
"\n",
"iris = datasets.load_iris()\n",
"X = iris[\"data\"][:, (2, 3)] # petal length, petal width\n",
"y = (iris[\"target\"] == 2).astype(np.float64) # Iris-Virginica\n",
"\n",
"svm_clf = Pipeline((\n",
" (\"scaler\", StandardScaler()),\n",
" (\"linear_svc\", LinearSVC(C=1, loss=\"hinge\", random_state=42)),\n",
" ))\n",
"\n",
"svm_clf.fit(X, y)"
]
},
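{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"`LinearSVC(loss=\"hinge\")` fits a soft margin classifier: it trades a small $\\|\\mathbf{w}\\|$ (a wide street) against margin violations, weighted by `C`. The sketch below evaluates that trade-off, $\\frac{1}{2}\\mathbf{w}^T\\mathbf{w} + C\\sum_i \\max\\left(0, 1 - t^{(i)}(\\mathbf{w}^T\\mathbf{x}^{(i)} + b)\\right)$ with $t^{(i)} \\in \\{-1, 1\\}$, on a hypothetical 1D toy dataset. Treat it as an illustration rather than the exact function liblinear minimizes; for instance, `LinearSVC` also regularizes the intercept by default."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def soft_margin_objective(w, b, X, t, C):\n",
"    # t holds labels in {-1, +1}; t * h(x) >= 1 means the instance is off the street\n",
"    margins = t * (X.dot(w) + b)\n",
"    hinge = np.maximum(0.0, 1.0 - margins)  # 0 for instances outside the street\n",
"    return 0.5 * w.dot(w) + C * hinge.sum()\n",
"\n",
"# hypothetical toy data: one feature, two instances per class\n",
"X_toy = np.array([[-2.0], [-1.0], [1.0], [0.5]])\n",
"t_toy = np.array([-1.0, -1.0, 1.0, 1.0])\n",
"w_toy, b_toy = np.array([1.0]), 0.0\n",
"\n",
"# only x = 0.5 sits inside the street (t * h = 0.5), so a larger C penalizes it more\n",
"for C_value in (0.1, 10):\n",
"    print(C_value, soft_margin_objective(w_toy, b_toy, X_toy, t_toy, C_value))"
]
},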
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"svm_clf.predict([[5.5, 1.7]])"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Now let's generate the graph comparing different regularization settings:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"scaler = StandardScaler()\n",
"svm_clf1 = LinearSVC(C=1, loss=\"hinge\", random_state=42)\n",
"svm_clf2 = LinearSVC(C=100, loss=\"hinge\", random_state=42)\n",
"\n",
"scaled_svm_clf1 = Pipeline((\n",
" (\"scaler\", scaler),\n",
" (\"linear_svc\", svm_clf1),\n",
" ))\n",
"scaled_svm_clf2 = Pipeline((\n",
" (\"scaler\", scaler),\n",
" (\"linear_svc\", svm_clf2),\n",
" ))\n",
"\n",
"scaled_svm_clf1.fit(X, y)\n",
"scaled_svm_clf2.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Convert to unscaled parameters\n",
"b1 = svm_clf1.decision_function([-scaler.mean_ / scaler.scale_])\n",
"b2 = svm_clf2.decision_function([-scaler.mean_ / scaler.scale_])\n",
"w1 = svm_clf1.coef_[0] / scaler.scale_\n",
"w2 = svm_clf2.coef_[0] / scaler.scale_\n",
"svm_clf1.intercept_ = np.array([b1])\n",
"svm_clf2.intercept_ = np.array([b2])\n",
"svm_clf1.coef_ = np.array([w1])\n",
"svm_clf2.coef_ = np.array([w2])\n",
"\n",
"# Find support vectors (LinearSVC does not do this automatically)\n",
"t = y * 2 - 1\n",
"support_vectors_idx1 = (t * (X.dot(w1) + b1) < 1).ravel()\n",
"support_vectors_idx2 = (t * (X.dot(w2) + b2) < 1).ravel()\n",
"svm_clf1.support_vectors_ = X[support_vectors_idx1]\n",
"svm_clf2.support_vectors_ = X[support_vectors_idx2]"
]
},
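{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"The cell above maps parameters learned on standardized inputs back to the original feature space. If $\\mathbf{x}' = (\\mathbf{x} - \\boldsymbol{\\mu}) / \\boldsymbol{\\sigma}$ (element-wise), then $\\mathbf{w}'^T\\mathbf{x}' + b' = (\\mathbf{w}' / \\boldsymbol{\\sigma})^T\\mathbf{x} + \\left(b' - \\mathbf{w}'^T(\\boldsymbol{\\mu} / \\boldsymbol{\\sigma})\\right)$. The new intercept is simply the scaled model's decision function evaluated at $\\mathbf{x}' = -\\boldsymbol{\\mu} / \\boldsymbol{\\sigma}$, which is how `b1` and `b2` are computed. This NumPy check uses random, hypothetical values for $\\mathbf{w}'$, $b'$, $\\boldsymbol{\\mu}$ and $\\boldsymbol{\\sigma}$:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.RandomState(0)\n",
"w_scaled, b_scaled = rng.randn(2), rng.randn()  # hypothetical parameters in the scaled space\n",
"mu, sigma = rng.randn(2), rng.rand(2) + 0.5      # hypothetical feature means and (positive) stds\n",
"\n",
"w_unscaled = w_scaled / sigma\n",
"b_unscaled = b_scaled + w_scaled.dot(-mu / sigma)  # scaled decision function at x' = -mu/sigma\n",
"\n",
"X_check = rng.randn(5, 2)\n",
"h_scaled = ((X_check - mu) / sigma).dot(w_scaled) + b_scaled\n",
"h_unscaled = X_check.dot(w_unscaled) + b_unscaled\n",
"print(np.allclose(h_scaled, h_unscaled))  # True"
]
},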
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"plt.figure(figsize=(12,3.2))\n",
"plt.subplot(121)\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"g^\", label=\"Iris-Virginica\")\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"bs\", label=\"Iris-Versicolor\")\n",
"plot_svc_decision_boundary(svm_clf1, 4, 6)\n",
"plt.xlabel(\"Petal length\", fontsize=14)\n",
"plt.ylabel(\"Petal width\", fontsize=14)\n",
"plt.legend(loc=\"upper left\", fontsize=14)\n",
"plt.title(\"$C = {}$\".format(svm_clf1.C), fontsize=16)\n",
"plt.axis([4, 6, 0.8, 2.8])\n",
"\n",
"plt.subplot(122)\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"g^\")\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"bs\")\n",
"plot_svc_decision_boundary(svm_clf2, 4, 6)\n",
"plt.xlabel(\"Petal length\", fontsize=14)\n",
"plt.title(\"$C = {}$\".format(svm_clf2.C), fontsize=16)\n",
"plt.axis([4, 6, 0.8, 2.8])\n",
"\n",
"save_fig(\"regularization_plot\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"source": [
"# Non-linear classification"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"X1D = np.linspace(-4, 4, 9).reshape(-1, 1)\n",
"X2D = np.c_[X1D, X1D**2]\n",
"y = np.array([0, 0, 1, 1, 1, 1, 1, 0, 0])\n",
"\n",
"plt.figure(figsize=(11, 4))\n",
"\n",
"plt.subplot(121)\n",
"plt.grid(True, which='both')\n",
"plt.axhline(y=0, color='k')\n",
"plt.plot(X1D[:, 0][y==0], np.zeros(4), \"bs\")\n",
"plt.plot(X1D[:, 0][y==1], np.zeros(5), \"g^\")\n",
"plt.gca().get_yaxis().set_ticks([])\n",
"plt.xlabel(r\"$x_1$\", fontsize=20)\n",
"plt.axis([-4.5, 4.5, -0.2, 0.2])\n",
"\n",
"plt.subplot(122)\n",
"plt.grid(True, which='both')\n",
"plt.axhline(y=0, color='k')\n",
"plt.axvline(x=0, color='k')\n",
"plt.plot(X2D[:, 0][y==0], X2D[:, 1][y==0], \"bs\")\n",
"plt.plot(X2D[:, 0][y==1], X2D[:, 1][y==1], \"g^\")\n",
"plt.xlabel(r\"$x_1$\", fontsize=20)\n",
"plt.ylabel(r\"$x_2$\", fontsize=20, rotation=0)\n",
"plt.gca().get_yaxis().set_ticks([0, 4, 8, 12, 16])\n",
"plt.plot([-4.5, 4.5], [6.5, 6.5], \"r--\", linewidth=3)\n",
"plt.axis([-4.5, 4.5, -1, 17])\n",
"\n",
"plt.subplots_adjust(right=1)\n",
"\n",
"save_fig(\"higher_dimensions_plot\", tight_layout=False)\n",
"plt.show()"
]
},
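{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Adding the feature $x_2 = x_1^2$ is what makes this 1D dataset linearly separable. No single threshold on $x_1$ splits the classes, but in the $(x_1, x_2)$ plane the dashed line $x_2 = 6.5$ does. Here is a quick NumPy check on the same nine points. The 6.5 threshold is read off the plot; any value strictly between 4 and 9 would also work."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"X1D_check = np.linspace(-4, 4, 9).reshape(-1, 1)\n",
"X2D_check = np.c_[X1D_check, X1D_check**2]  # add x2 = x1**2\n",
"y_check = np.array([0, 0, 1, 1, 1, 1, 1, 0, 0])\n",
"\n",
"# 1D: try every threshold halfway between neighboring points, in both directions\n",
"x1_values = X1D_check[:, 0]\n",
"thresholds = (x1_values[:-1] + x1_values[1:]) / 2\n",
"one_threshold_works = any(\n",
"    ((x1_values > th).astype(int) == y_check).all() or ((x1_values < th).astype(int) == y_check).all()\n",
"    for th in thresholds)\n",
"print(one_threshold_works)  # False\n",
"\n",
"# 2D: the horizontal line x2 = 6.5 separates the classes\n",
"y_pred_2d = (X2D_check[:, 1] < 6.5).astype(int)\n",
"print((y_pred_2d == y_check).all())  # True"
]
},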
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.datasets import make_moons\n",
"X, y = make_moons(n_samples=100, noise=0.15, random_state=42)\n",
"\n",
"def plot_dataset(X, y, axes):\n",
" plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"bs\")\n",
" plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"g^\")\n",
" plt.axis(axes)\n",
" plt.grid(True, which='both')\n",
" plt.xlabel(r\"$x_1$\", fontsize=20)\n",
" plt.ylabel(r\"$x_2$\", fontsize=20, rotation=0)\n",
"\n",
"plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.datasets import make_moons\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"\n",
"polynomial_svm_clf = Pipeline((\n",
" (\"poly_features\", PolynomialFeatures(degree=3)),\n",
" (\"scaler\", StandardScaler()),\n",
" (\"svm_clf\", LinearSVC(C=10, loss=\"hinge\", random_state=42))\n",
" ))\n",
"\n",
"polynomial_svm_clf.fit(X, y)"
]
},
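{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Adding polynomial features works well at low degrees, but the number of features explodes as the degree grows. For $n$ input features and degree $d$, `PolynomialFeatures` (with its defaults, including `include_bias=True`) outputs one column per monomial of total degree at most $d$, which gives $\\binom{n+d}{d}$ columns. The sketch below uses `math.comb`, so it needs Python 3.8 or later. For the moons data ($n = 2$, $d = 3$) it gives the 10 columns that the pipeline above passes to the scaler."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from math import comb  # Python 3.8+\n",
"\n",
"def n_poly_features(n_features, degree):\n",
"    # number of monomials of total degree <= degree, bias term (degree 0) included\n",
"    return comb(n_features + degree, degree)\n",
"\n",
"for n, d in [(2, 3), (10, 3), (100, 3), (100, 10)]:\n",
"    print(n, d, n_poly_features(n, d))"
]
},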
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"def plot_predictions(clf, axes):\n",
" x0s = np.linspace(axes[0], axes[1], 100)\n",
" x1s = np.linspace(axes[2], axes[3], 100)\n",
" x0, x1 = np.meshgrid(x0s, x1s)\n",
" X = np.c_[x0.ravel(), x1.ravel()]\n",
" y_pred = clf.predict(X).reshape(x0.shape)\n",
" y_decision = clf.decision_function(X).reshape(x0.shape)\n",
" plt.contourf(x0, x1, y_pred, cmap=plt.cm.brg, alpha=0.2)\n",
" plt.contourf(x0, x1, y_decision, cmap=plt.cm.brg, alpha=0.1)\n",
"\n",
"plot_predictions(polynomial_svm_clf, [-1.5, 2.5, -1, 1.5])\n",
"plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])\n",
"\n",
"save_fig(\"moons_polynomial_svc_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.svm import SVC\n",
"\n",
"poly_kernel_svm_clf = Pipeline((\n",
" (\"scaler\", StandardScaler()),\n",
" (\"svm_clf\", SVC(kernel=\"poly\", degree=3, coef0=1, C=5))\n",
" ))\n",
"poly_kernel_svm_clf.fit(X, y)"
]
},
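{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"The polynomial kernel gives the same result as if you had added many polynomial features, without actually creating them. scikit-learn's polynomial kernel is $K(\\mathbf{a}, \\mathbf{b}) = (\\gamma\\,\\mathbf{a}^T\\mathbf{b} + r)^d$, where $r$ is `coef0`. For 2D inputs, with $\\gamma = 1$, $d = 2$ and $r \\geq 0$, the kernel value equals the dot product of an explicit 6-dimensional feature map. The map `phi()` below is one valid choice, written out for illustration only:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def phi(x, r):\n",
"    # explicit degree-2 map with phi(a).dot(phi(b)) == (a.dot(b) + r)**2 for 2D inputs, r >= 0\n",
"    x1, x2 = x\n",
"    return np.array([x1**2, x2**2, np.sqrt(2) * x1 * x2,\n",
"                     np.sqrt(2 * r) * x1, np.sqrt(2 * r) * x2, r])\n",
"\n",
"def poly_kernel(a, b, r, degree=2):\n",
"    # scikit-learn's polynomial kernel with gamma = 1 and coef0 = r\n",
"    return (np.dot(a, b) + r) ** degree\n",
"\n",
"a_vec, b_vec, r = np.array([1.0, 2.0]), np.array([3.0, -1.0]), 1.0\n",
"print(poly_kernel(a_vec, b_vec, r), phi(a_vec, r).dot(phi(b_vec, r)))  # both approximately 4"
]
},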
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"poly100_kernel_svm_clf = Pipeline((\n",
" (\"scaler\", StandardScaler()),\n",
" (\"svm_clf\", SVC(kernel=\"poly\", degree=10, coef0=100, C=5))\n",
" ))\n",
"poly100_kernel_svm_clf.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"plt.figure(figsize=(11, 4))\n",
"\n",
"plt.subplot(121)\n",
"plot_predictions(poly_kernel_svm_clf, [-1.5, 2.5, -1, 1.5])\n",
"plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])\n",
"plt.title(r\"$d=3, r=1, C=5$\", fontsize=18)\n",
"\n",
"plt.subplot(122)\n",
"plot_predictions(poly100_kernel_svm_clf, [-1.5, 2.5, -1, 1.5])\n",
"plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])\n",
"plt.title(r\"$d=10, r=100, C=5$\", fontsize=18)\n",
"\n",
"save_fig(\"moons_kernelized_polynomial_svc_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true,
"scrolled": true
},
"outputs": [],
"source": [
"def gaussian_rbf(x, landmark, gamma):\n",
" return np.exp(-gamma * np.linalg.norm(x - landmark, axis=1)**2)\n",
"\n",
"gamma = 0.3\n",
"\n",
"x1s = np.linspace(-4.5, 4.5, 200).reshape(-1, 1)\n",
"x2s = gaussian_rbf(x1s, -2, gamma)\n",
"x3s = gaussian_rbf(x1s, 1, gamma)\n",
"\n",
"XK = np.c_[gaussian_rbf(X1D, -2, gamma), gaussian_rbf(X1D, 1, gamma)]\n",
"yk = np.array([0, 0, 1, 1, 1, 1, 1, 0, 0])\n",
"\n",
"plt.figure(figsize=(11, 4))\n",
"\n",
"plt.subplot(121)\n",
"plt.grid(True, which='both')\n",
"plt.axhline(y=0, color='k')\n",
"plt.scatter(x=[-2, 1], y=[0, 0], s=150, alpha=0.5, c=\"red\")\n",
"plt.plot(X1D[:, 0][yk==0], np.zeros(4), \"bs\")\n",
"plt.plot(X1D[:, 0][yk==1], np.zeros(5), \"g^\")\n",
"plt.plot(x1s, x2s, \"g--\")\n",
"plt.plot(x1s, x3s, \"b:\")\n",
"plt.gca().get_yaxis().set_ticks([0, 0.25, 0.5, 0.75, 1])\n",
"plt.xlabel(r\"$x_1$\", fontsize=20)\n",
"plt.ylabel(r\"Similarity\", fontsize=14)\n",
"plt.annotate(r'$\\mathbf{x}$',\n",
" xy=(X1D[3, 0], 0),\n",
" xytext=(-0.5, 0.20),\n",
" ha=\"center\",\n",
" arrowprops=dict(facecolor='black', shrink=0.1),\n",
" fontsize=18,\n",
" )\n",
"plt.text(-2, 0.9, \"$x_2$\", ha=\"center\", fontsize=20)\n",
"plt.text(1, 0.9, \"$x_3$\", ha=\"center\", fontsize=20)\n",
"plt.axis([-4.5, 4.5, -0.1, 1.1])\n",
"\n",
"plt.subplot(122)\n",
"plt.grid(True, which='both')\n",
"plt.axhline(y=0, color='k')\n",
"plt.axvline(x=0, color='k')\n",
"plt.plot(XK[:, 0][yk==0], XK[:, 1][yk==0], \"bs\")\n",
"plt.plot(XK[:, 0][yk==1], XK[:, 1][yk==1], \"g^\")\n",
"plt.xlabel(r\"$x_2$\", fontsize=20)\n",
"plt.ylabel(r\"$x_3$ \", fontsize=20, rotation=0)\n",
"plt.annotate(r'$\\phi\\left(\\mathbf{x}\\right)$',\n",
" xy=(XK[3, 0], XK[3, 1]),\n",
" xytext=(0.65, 0.50),\n",
" ha=\"center\",\n",
" arrowprops=dict(facecolor='black', shrink=0.1),\n",
" fontsize=18,\n",
" )\n",
"plt.plot([-0.1, 1.1], [0.57, -0.1], \"r--\", linewidth=3)\n",
"plt.axis([-0.1, 1.1, -0.1, 1.1])\n",
"\n",
"plt.subplots_adjust(right=1)\n",
"\n",
"save_fig(\"kernel_method_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"x1_example = X1D[3, 0]\n",
"for landmark in (-2, 1):\n",
" k = gaussian_rbf(np.array([[x1_example]]), np.array([[landmark]]), gamma)\n",
" print(\"Phi({}, {}) = {}\".format(x1_example, landmark, k))"
]
},
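{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Computing one similarity feature per landmark can get expensive, especially when every training instance is used as a landmark. The Gaussian RBF kernel $K(\\mathbf{a}, \\mathbf{b}) = \\exp(-\\gamma \\|\\mathbf{a} - \\mathbf{b}\\|^2)$ lets `SVC` get a similar result implicitly. The sketch below computes a full kernel matrix with NumPy broadcasting for the same $\\gamma = 0.3$. `rbf_kernel_matrix()` is a hypothetical helper written for illustration; scikit-learn provides its own `sklearn.metrics.pairwise.rbf_kernel`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def rbf_kernel_matrix(A, B, gamma):\n",
"    # squared Euclidean distance between every row of A and every row of B\n",
"    sq_dists = ((A[:, np.newaxis, :] - B[np.newaxis, :, :]) ** 2).sum(axis=2)\n",
"    return np.exp(-gamma * sq_dists)\n",
"\n",
"points = np.array([[-1.0]])             # the instance x1 = -1 highlighted in the plot\n",
"landmarks = np.array([[-2.0], [1.0]])\n",
"K = rbf_kernel_matrix(points, landmarks, gamma=0.3)\n",
"print(K)  # roughly [[0.74 0.30]], the same values as the Phi(...) printout above"
]
},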
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"rbf_kernel_svm_clf = Pipeline((\n",
" (\"scaler\", StandardScaler()),\n",
" (\"svm_clf\", SVC(kernel=\"rbf\", gamma=5, C=0.001))\n",
" ))\n",
"rbf_kernel_svm_clf.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true,
"scrolled": true
},
"outputs": [],
"source": [
"from sklearn.svm import SVC\n",
"\n",
"gamma1, gamma2 = 0.1, 5\n",
"C1, C2 = 0.001, 1000\n",
"hyperparams = (gamma1, C1), (gamma1, C2), (gamma2, C1), (gamma2, C2)\n",
"\n",
"svm_clfs = []\n",
"for gamma, C in hyperparams:\n",
" rbf_kernel_svm_clf = Pipeline((\n",
" (\"scaler\", StandardScaler()),\n",
" (\"svm_clf\", SVC(kernel=\"rbf\", gamma=gamma, C=C))\n",
" ))\n",
" rbf_kernel_svm_clf.fit(X, y)\n",
" svm_clfs.append(rbf_kernel_svm_clf)\n",
"\n",
"plt.figure(figsize=(11, 7))\n",
"\n",
"for i, svm_clf in enumerate(svm_clfs):\n",
" plt.subplot(221 + i)\n",
" plot_predictions(svm_clf, [-1.5, 2.5, -1, 1.5])\n",
" plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])\n",
" gamma, C = hyperparams[i]\n",
" plt.title(r\"$\\gamma = {}, C = {}$\".format(gamma, C), fontsize=16)\n",
"\n",
"save_fig(\"moons_rbf_svc_plot\")\n",
"plt.show()"
]
},
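{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"$\\gamma$ acts like an inverse width for each instance's bell-shaped range of influence. The similarity of two instances at distance $d$, $\\exp(-\\gamma d^2)$, falls off much faster when $\\gamma$ is large, which is why $\\gamma = 5$ gives a more irregular boundary than $\\gamma = 0.1$. For both values, the sketch below prints the similarity at an arbitrary distance $d = 1$ and the distance $\\sqrt{\\ln 2 / \\gamma}$ at which the similarity drops to 0.5. Distances are measured in the standardized space produced by the pipeline's scaler."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"import math\n",
"\n",
"d_example = 1.0  # arbitrary distance between two standardized instances\n",
"for gamma_value in (0.1, 5):\n",
"    similarity = math.exp(-gamma_value * d_example**2)\n",
"    half_similarity_distance = math.sqrt(math.log(2) / gamma_value)  # solves exp(-gamma * d**2) = 0.5\n",
"    print(gamma_value, round(similarity, 4), round(half_similarity_distance, 3))"
]
},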
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Regression\n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"np.random.seed(42)\n",
"m = 50\n",
"X = 2 * np.random.rand(m, 1)\n",
"y = (4 + 3 * X + np.random.randn(m, 1)).ravel()"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.svm import LinearSVR\n",
"\n",
"svm_reg = LinearSVR(epsilon=1.5, random_state=42)\n",
"svm_reg.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"svm_reg1 = LinearSVR(epsilon=1.5, random_state=42)\n",
"svm_reg2 = LinearSVR(epsilon=0.5, random_state=42)\n",
"svm_reg1.fit(X, y)\n",
"svm_reg2.fit(X, y)\n",
"\n",
"def find_support_vectors(svm_reg, X, y):\n",
" y_pred = svm_reg.predict(X)\n",
" off_margin = (np.abs(y - y_pred) >= svm_reg.epsilon)\n",
" return np.argwhere(off_margin)\n",
"\n",
"svm_reg1.support_ = find_support_vectors(svm_reg1, X, y)\n",
"svm_reg2.support_ = find_support_vectors(svm_reg2, X, y)\n",
"\n",
"eps_x1 = 1\n",
"eps_y_pred = svm_reg1.predict([[eps_x1]])"
]
},
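{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"SVM regression is insensitive to errors smaller than `epsilon`. Instances inside the tube $|y - \\hat{y}| < \\epsilon$ around the predictions do not affect the loss, which is why `find_support_vectors()` above flags only instances with $|y - \\hat{y}| \\geq \\epsilon$. Here is a NumPy sketch of this $\\epsilon$-insensitive loss on hypothetical targets and predictions:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def epsilon_insensitive_loss(y_true, y_pred, epsilon):\n",
"    # zero inside the tube |y - y_hat| <= epsilon, growing linearly outside it\n",
"    return np.maximum(0.0, np.abs(y_true - y_pred) - epsilon)\n",
"\n",
"y_demo = np.array([1.0, 2.0, 3.0, 4.0])\n",
"y_hat_demo = np.array([1.2, 3.0, 3.0, 1.0])  # hypothetical predictions\n",
"print(epsilon_insensitive_loss(y_demo, y_hat_demo, epsilon=0.5))"
]
},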
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"def plot_svm_regression(svm_reg, X, y, axes):\n",
" x1s = np.linspace(axes[0], axes[1], 100).reshape(100, 1)\n",
" y_pred = svm_reg.predict(x1s)\n",
" plt.plot(x1s, y_pred, \"k-\", linewidth=2, label=r\"$\\hat{y}$\")\n",
" plt.plot(x1s, y_pred + svm_reg.epsilon, \"k--\")\n",
" plt.plot(x1s, y_pred - svm_reg.epsilon, \"k--\")\n",
" plt.scatter(X[svm_reg.support_], y[svm_reg.support_], s=180, facecolors='#FFAAAA')\n",
" plt.plot(X, y, \"bo\")\n",
" plt.xlabel(r\"$x_1$\", fontsize=18)\n",
" plt.legend(loc=\"upper left\", fontsize=18)\n",
" plt.axis(axes)\n",
"\n",
"plt.figure(figsize=(9, 4))\n",
"plt.subplot(121)\n",
"plot_svm_regression(svm_reg1, X, y, [0, 2, 3, 11])\n",
"plt.title(r\"$\\epsilon = {}$\".format(svm_reg1.epsilon), fontsize=18)\n",
"plt.ylabel(r\"$y$\", fontsize=18, rotation=0)\n",
"#plt.plot([eps_x1, eps_x1], [eps_y_pred, eps_y_pred - svm_reg1.epsilon], \"k-\", linewidth=2)\n",
"plt.annotate(\n",
" '', xy=(eps_x1, eps_y_pred), xycoords='data',\n",
" xytext=(eps_x1, eps_y_pred - svm_reg1.epsilon),\n",
" textcoords='data', arrowprops={'arrowstyle': '<->', 'linewidth': 1.5}\n",
" )\n",
"plt.text(0.91, 5.6, r\"$\\epsilon$\", fontsize=20)\n",
"plt.subplot(122)\n",
"plot_svm_regression(svm_reg2, X, y, [0, 2, 3, 11])\n",
"plt.title(r\"$\\epsilon = {}$\".format(svm_reg2.epsilon), fontsize=18)\n",
"save_fig(\"svm_regression_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"np.random.seed(42)\n",
"m = 100\n",
"X = 2 * np.random.rand(m, 1) - 1\n",
"y = (0.2 + 0.1 * X + 0.5 * X**2 + np.random.randn(m, 1)/10).ravel()"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.svm import SVR\n",
"\n",
"svm_poly_reg = SVR(kernel=\"poly\", degree=2, C=100, epsilon=0.1)\n",
"svm_poly_reg.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.svm import SVR\n",
"\n",
"svm_poly_reg1 = SVR(kernel=\"poly\", degree=2, C=100, epsilon=0.1)\n",
"svm_poly_reg2 = SVR(kernel=\"poly\", degree=2, C=0.01, epsilon=0.1)\n",
"svm_poly_reg1.fit(X, y)\n",
"svm_poly_reg2.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"plt.figure(figsize=(9, 4))\n",
"plt.subplot(121)\n",
"plot_svm_regression(svm_poly_reg1, X, y, [-1, 1, 0, 1])\n",
"plt.title(r\"$degree={}, C={}, \\epsilon = {}$\".format(svm_poly_reg1.degree, svm_poly_reg1.C, svm_poly_reg1.epsilon), fontsize=18)\n",
"plt.ylabel(r\"$y$\", fontsize=18, rotation=0)\n",
"plt.subplot(122)\n",
"plot_svm_regression(svm_poly_reg2, X, y, [-1, 1, 0, 1])\n",
"plt.title(r\"$degree={}, C={}, \\epsilon = {}$\".format(svm_poly_reg2.degree, svm_poly_reg2.C, svm_poly_reg2.epsilon), fontsize=18)\n",
"save_fig(\"svm_with_polynomial_kernel_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Under the hood"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"iris = datasets.load_iris()\n",
"X = iris[\"data\"][:, (2, 3)] # petal length, petal width\n",
"y = (iris[\"target\"] == 2).astype(np.float64) # Iris-Virginica"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from mpl_toolkits.mplot3d import Axes3D\n",
"\n",
"def plot_3D_decision_function(ax, w, b, x1_lim=[4, 6], x2_lim=[0.8, 2.8]):\n",
" x1_in_bounds = (X[:, 0] > x1_lim[0]) & (X[:, 0] < x1_lim[1])\n",
" X_crop = X[x1_in_bounds]\n",
" y_crop = y[x1_in_bounds]\n",
" x1s = np.linspace(x1_lim[0], x1_lim[1], 20)\n",
" x2s = np.linspace(x2_lim[0], x2_lim[1], 20)\n",
" x1, x2 = np.meshgrid(x1s, x2s)\n",
" xs = np.c_[x1.ravel(), x2.ravel()]\n",
" df = (xs.dot(w) + b).reshape(x1.shape)\n",
" m = 1 / np.linalg.norm(w)\n",
" boundary_x2s = -x1s*(w[0]/w[1])-b/w[1]\n",
" margin_x2s_1 = -x1s*(w[0]/w[1])-(b-1)/w[1]\n",
" margin_x2s_2 = -x1s*(w[0]/w[1])-(b+1)/w[1]\n",
" ax.plot_surface(x1, x2, np.zeros_like(x1), color=\"b\", alpha=0.2, cstride=100, rstride=100)\n",
" ax.plot(x1s, boundary_x2s, 0, \"k-\", linewidth=2, label=r\"$h=0$\")\n",
" ax.plot(x1s, margin_x2s_1, 0, \"k--\", linewidth=2, label=r\"$h=\\pm 1$\")\n",
" ax.plot(x1s, margin_x2s_2, 0, \"k--\", linewidth=2)\n",
" ax.plot(X_crop[:, 0][y_crop==1], X_crop[:, 1][y_crop==1], 0, \"g^\")\n",
" ax.plot_wireframe(x1, x2, df, alpha=0.3, color=\"k\")\n",
" ax.plot(X_crop[:, 0][y_crop==0], X_crop[:, 1][y_crop==0], 0, \"bs\")\n",
" ax.axis(x1_lim + x2_lim)\n",
" ax.text(4.5, 2.5, 3.8, \"Decision function $h$\", fontsize=15)\n",
" ax.set_xlabel(r\"Petal length\", fontsize=15)\n",
" ax.set_ylabel(r\"Petal width\", fontsize=15)\n",
" ax.set_zlabel(r\"$h = \\mathbf{w}^t \\cdot \\mathbf{x} + b$\", fontsize=18)\n",
" ax.legend(loc=\"upper left\", fontsize=16)\n",
"\n",
"fig = plt.figure(figsize=(11, 6))\n",
"ax1 = fig.add_subplot(111, projection='3d')\n",
"plot_3D_decision_function(ax1, w=svm_clf2.coef_[0], b=svm_clf2.intercept_[0])\n",
"\n",
"save_fig(\"iris_3D_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Small weight vector results in a large margin"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"def plot_2D_decision_function(w, b, ylabel=True, x1_lim=[-3, 3]):\n",
"    x1 = np.linspace(x1_lim[0], x1_lim[1], 200)\n",
"    y = w * x1 + b\n",
"    m = 1 / w  # with b = 0, the margin boundaries (y = 1 and y = -1) lie at x1 = m and x1 = -m\n",
"\n",
"    plt.plot(x1, y)\n",
"    plt.plot(x1_lim, [1, 1], \"k:\")\n",
"    plt.plot(x1_lim, [-1, -1], \"k:\")\n",
"    plt.axhline(y=0, color='k')\n",
"    plt.axvline(x=0, color='k')\n",
"    plt.plot([m, m], [0, 1], \"k--\")\n",
"    plt.plot([-m, -m], [0, -1], \"k--\")\n",
"    plt.plot([-m, m], [0, 0], \"k-o\", linewidth=3)\n",
"    plt.axis(x1_lim + [-2, 2])\n",
"    plt.xlabel(r\"$x_1$\", fontsize=16)\n",
"    if ylabel:\n",
"        plt.ylabel(r\"$w_1 x_1$  \", rotation=0, fontsize=16)\n",
"    plt.title(r\"$w_1 = {}$\".format(w), fontsize=16)\n",
"\n",
"plt.figure(figsize=(12, 3.2))\n",
"plt.subplot(121)\n",
"plot_2D_decision_function(1, 0)\n",
"plt.subplot(122)\n",
"plot_2D_decision_function(0.5, 0, ylabel=False)\n",
"save_fig(\"small_w_large_margin_plot\")\n",
"plt.show()"
]
},
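{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"A quick numeric check of the idea illustrated above: for a linear decision function $h(\\mathbf{x}) = \\mathbf{w}^T \\mathbf{x} + b$, the margin boundaries ($h = \\pm 1$) lie at a distance $1 / \\|\\mathbf{w}\\|$ from the decision boundary ($h = 0$), so dividing the weights by 2 doubles the margin. `margin_half_width()` is a hypothetical helper written only for this illustration:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def margin_half_width(w):\n",
"    # Distance from the decision boundary (h = 0) to either margin boundary (h = 1 or h = -1)\n",
"    return 1.0 / np.linalg.norm(w)\n",
"\n",
"for w1 in (1.0, 0.5):\n",
"    print('w1 =', w1, '-> margin half-width =', margin_half_width(np.array([w1])))"
]
},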
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.svm import SVC\n",
"from sklearn import datasets\n",
"\n",
"iris = datasets.load_iris()\n",
"X = iris[\"data\"][:, (2, 3)] # petal length, petal width\n",
"y = (iris[\"target\"] == 2).astype(np.float64) # Iris-Virginica\n",
"\n",
"svm_clf = SVC(kernel=\"linear\", C=1)\n",
"svm_clf.fit(X, y)\n",
"svm_clf.predict([[5.3, 1.3]])"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Hinge loss"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"t = np.linspace(-2, 4, 200)\n",
"h = np.where(1 - t < 0, 0, 1 - t) # max(0, 1-t)\n",
"\n",
"plt.figure(figsize=(5,2.8))\n",
"plt.plot(t, h, \"b-\", linewidth=2, label=r\"$\\max(0, 1 - t)$\")\n",
"plt.grid(True, which='both')\n",
"plt.axhline(y=0, color='k')\n",
"plt.axvline(x=0, color='k')\n",
"plt.yticks(np.arange(-1, 2.5, 1))\n",
"plt.xlabel(\"$t$\", fontsize=16)\n",
"plt.axis([-2, 4, -1, 2.5])\n",
"plt.legend(loc=\"upper right\", fontsize=16)\n",
"save_fig(\"hinge_plot\")\n",
"plt.show()"
]
},
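{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"For a linear SVM, the $t$ in the hinge loss plot stands for $t^{(i)} (\\mathbf{w}^T \\mathbf{x}^{(i)} + b)$, where $t^{(i)}$ is $-1$ for negative instances and $+1$ for positive ones. The sketch below uses a hypothetical `hinge_losses()` helper and made-up toy instances. It shows that an instance only adds to the cost when it lies inside the margin or on the wrong side of the decision boundary:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def hinge_losses(X, t, w, b):\n",
"    # t must contain -1/+1 labels; returns max(0, 1 - t * (X.w + b)) for each instance\n",
"    return np.maximum(0, 1 - t * (X.dot(w) + b))\n",
"\n",
"X_toy = np.array([[2.0], [0.5], [-0.5], [-3.0]])\n",
"t_toy = np.array([1, 1, 1, -1])\n",
"# Expected: 0 (outside the margin), 0.5 (inside the margin), 1.5 (misclassified), 0 (outside the margin)\n",
"print(hinge_losses(X_toy, t_toy, w=np.array([1.0]), b=0.0))"
]
},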
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Extra material"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"## Training time"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"X, y = make_moons(n_samples=1000, noise=0.4, random_state=42)\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"bs\")\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"g^\")"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"import time\n",
"\n",
"tol = 0.1\n",
"tols = []\n",
"times = []\n",
"for i in range(10):\n",
" svm_clf = SVC(kernel=\"poly\", gamma=3, C=10, tol=tol, verbose=1)\n",
" t1 = time.time()\n",
" svm_clf.fit(X, y)\n",
" t2 = time.time()\n",
" times.append(t2-t1)\n",
" tols.append(tol)\n",
" print(i, tol, t2-t1)\n",
" tol /= 10\n",
"plt.semilogx(tols, times)"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"## Linear SVM classifier implementation using Batch Gradient Descent"
]
},
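{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"The `MyLinearSVC` class below minimizes the soft margin linear SVM cost function using (sub)gradient descent. Let $t^{(i)} = 2 y^{(i)} - 1$, and let $\\mathcal{S}$ be the set of instances for which $t^{(i)} (\\mathbf{w}^T \\mathbf{x}^{(i)} + b) < 1$. These are the only instances that contribute to the hinge loss. The cost function and the subgradients used in `fit()` are then:\n",
"\n",
"$$J(\\mathbf{w}, b) = \\frac{1}{2} \\mathbf{w}^T \\mathbf{w} + C \\sum_{i=1}^{m} \\max\\left(0, 1 - t^{(i)} (\\mathbf{w}^T \\mathbf{x}^{(i)} + b)\\right)$$\n",
"\n",
"$$\\nabla_{\\mathbf{w}} J = \\mathbf{w} - C \\sum_{i \\in \\mathcal{S}} t^{(i)} \\mathbf{x}^{(i)} \\qquad \\frac{\\partial J}{\\partial b} = -C \\sum_{i \\in \\mathcal{S}} t^{(i)}$$\n",
"\n",
"At epoch $k$, the learning rate is $\\eta^{(k)} = \\eta_0 / (k + \\eta_d)$, which corresponds to `eta0 / (epoch + eta_d)` in the `eta()` method."
]
},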
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Training set\n",
"X = iris[\"data\"][:, (2, 3)] # petal length, petal width\n",
"y = (iris[\"target\"] == 2).astype(np.float64).reshape(-1, 1) # Iris-Virginica"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.base import BaseEstimator\n",
"\n",
"class MyLinearSVC(BaseEstimator):\n",
"    def __init__(self, C=1, eta0=1, eta_d=10000, n_epochs=1000, random_state=None):\n",
"        self.C = C\n",
"        self.eta0 = eta0\n",
"        self.n_epochs = n_epochs\n",
"        self.random_state = random_state\n",
"        self.eta_d = eta_d\n",
"\n",
"    def eta(self, epoch):\n",
"        # Learning schedule: the learning rate decreases at each epoch\n",
"        return self.eta0 / (epoch + self.eta_d)\n",
"\n",
"    def fit(self, X, y):\n",
"        # Random initialization\n",
"        if self.random_state is not None:\n",
"            np.random.seed(self.random_state)\n",
"        w = np.random.randn(X.shape[1], 1)  # n feature weights\n",
"        b = 0\n",
"\n",
"        m = len(X)\n",
"        t = y * 2 - 1  # -1 if y==0, +1 if y==1\n",
"        X_t = X * t\n",
"        self.Js = []\n",
"\n",
"        # Training\n",
"        for epoch in range(self.n_epochs):\n",
"            # Instances that violate the margin (the only ones that contribute to the hinge loss)\n",
"            support_vectors_idx = (X_t.dot(w) + t * b < 1).ravel()\n",
"            X_t_sv = X_t[support_vectors_idx]\n",
"            t_sv = t[support_vectors_idx]\n",
"\n",
"            J = 1/2 * np.sum(w * w) + self.C * (np.sum(1 - X_t_sv.dot(w)) - b * np.sum(t_sv))\n",
"            self.Js.append(J)\n",
"\n",
"            w_gradient_vector = w - self.C * np.sum(X_t_sv, axis=0).reshape(-1, 1)\n",
"            b_derivative = -self.C * np.sum(t_sv)\n",
"\n",
"            w = w - self.eta(epoch) * w_gradient_vector\n",
"            b = b - self.eta(epoch) * b_derivative\n",
"\n",
"        self.intercept_ = np.array([b])\n",
"        self.coef_ = np.array([w])\n",
"        # Same margin-violation test as during training (including t * b)\n",
"        support_vectors_idx = (X_t.dot(w) + t * b < 1).ravel()\n",
"        self.support_vectors_ = X[support_vectors_idx]\n",
"        return self\n",
"\n",
"    def decision_function(self, X):\n",
"        return X.dot(self.coef_[0]) + self.intercept_[0]\n",
"\n",
"    def predict(self, X):\n",
"        return (self.decision_function(X) >= 0).astype(np.float64)\n",
"\n",
"C = 2\n",
"svm_clf = MyLinearSVC(C=C, eta0=10, eta_d=1000, n_epochs=60000, random_state=2)\n",
"svm_clf.fit(X, y)\n",
"svm_clf.predict(np.array([[5, 2], [4, 1]]))"
]
},
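{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"One way to gain confidence in the update rules of `MyLinearSVC` is to compare them with finite-difference estimates of the cost function's gradient. The sketch below re-implements the same cost and subgradient formulas for a 1D weight vector. The function names are hypothetical, and the data is random toy data rather than the Iris dataset. The hinge loss is not differentiable when $t^{(i)} (\\mathbf{w}^T \\mathbf{x}^{(i)} + b) = 1$ exactly, but with random data no instance is expected to sit exactly on that kink:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def svm_cost(w, b, X, t, C):\n",
"    # Soft margin cost: 0.5 * ||w||^2 + C * (sum of hinge losses), with t in {-1, +1}\n",
"    return 0.5 * w.dot(w) + C * np.sum(np.maximum(0, 1 - t * (X.dot(w) + b)))\n",
"\n",
"def svm_subgradient(w, b, X, t, C):\n",
"    # Same formulas as in MyLinearSVC.fit(), but for 1D arrays w and t\n",
"    sv = t * (X.dot(w) + b) < 1  # instances that contribute to the hinge loss\n",
"    grad_w = w - C * (t[sv][:, None] * X[sv]).sum(axis=0)\n",
"    grad_b = -C * t[sv].sum()\n",
"    return grad_w, grad_b\n",
"\n",
"def check_svm_gradients(seed=0, eps=1e-6):\n",
"    rng = np.random.RandomState(seed)\n",
"    X_chk = rng.randn(20, 2)\n",
"    t_chk = np.where(rng.rand(20) < 0.5, -1.0, 1.0)\n",
"    w_chk, b_chk, C_chk = rng.randn(2), 0.1, 2.0\n",
"    grad_w, grad_b = svm_subgradient(w_chk, b_chk, X_chk, t_chk, C_chk)\n",
"    # Central finite differences, one coordinate of w at a time, then b\n",
"    num_grad_w = np.array([\n",
"        (svm_cost(w_chk + eps * e, b_chk, X_chk, t_chk, C_chk)\n",
"         - svm_cost(w_chk - eps * e, b_chk, X_chk, t_chk, C_chk)) / (2 * eps)\n",
"        for e in np.eye(len(w_chk))])\n",
"    num_grad_b = (svm_cost(w_chk, b_chk + eps, X_chk, t_chk, C_chk)\n",
"                  - svm_cost(w_chk, b_chk - eps, X_chk, t_chk, C_chk)) / (2 * eps)\n",
"    return grad_w, num_grad_w, grad_b, num_grad_b\n",
"\n",
"grad_w, num_grad_w, grad_b, num_grad_b = check_svm_gradients()\n",
"print(grad_w, num_grad_w)\n",
"print(grad_b, num_grad_b)"
]
},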
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"plt.plot(range(svm_clf.n_epochs), svm_clf.Js)\n",
"plt.axis([0, svm_clf.n_epochs, 0, 100])"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"print(svm_clf.intercept_, svm_clf.coef_)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"svm_clf2 = SVC(kernel=\"linear\", C=C)\n",
"svm_clf2.fit(X, y.ravel())\n",
"print(svm_clf2.intercept_, svm_clf2.coef_)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"yr = y.ravel()\n",
"plt.figure(figsize=(12,3.2))\n",
"plt.subplot(121)\n",
"plt.plot(X[:, 0][yr==1], X[:, 1][yr==1], \"g^\", label=\"Iris-Virginica\")\n",
"plt.plot(X[:, 0][yr==0], X[:, 1][yr==0], \"bs\", label=\"Not Iris-Virginica\")\n",
"plot_svc_decision_boundary(svm_clf, 4, 6)\n",
"plt.xlabel(\"Petal length\", fontsize=14)\n",
"plt.ylabel(\"Petal width\", fontsize=14)\n",
"plt.title(\"MyLinearSVC\", fontsize=14)\n",
"plt.axis([4, 6, 0.8, 2.8])\n",
"\n",
"plt.subplot(122)\n",
"plt.plot(X[:, 0][yr==1], X[:, 1][yr==1], \"g^\")\n",
"plt.plot(X[:, 0][yr==0], X[:, 1][yr==0], \"bs\")\n",
"plot_svc_decision_boundary(svm_clf2, 4, 6)\n",
"plt.xlabel(\"Petal length\", fontsize=14)\n",
"plt.title(\"SVC\", fontsize=14)\n",
"plt.axis([4, 6, 0.8, 2.8])\n"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true,
"scrolled": true
},
"outputs": [],
"source": [
"from sklearn.linear_model import SGDClassifier\n",
"\n",
"sgd_clf = SGDClassifier(loss=\"hinge\", alpha=0.017, max_iter=50, tol=None, random_state=42)\n",
"sgd_clf.fit(X, y.ravel())\n",
"\n",
"m = len(X)\n",
"t = y * 2 - 1  # -1 if y==0, +1 if y==1\n",
"X_b = np.c_[np.ones((m, 1)), X] # Add bias input x0=1\n",
"X_b_t = X_b * t\n",
"sgd_theta = np.r_[sgd_clf.intercept_[0], sgd_clf.coef_[0]]\n",
"print(sgd_theta)\n",
"support_vectors_idx = (X_b_t.dot(sgd_theta) < 1).ravel()\n",
"sgd_clf.support_vectors_ = X[support_vectors_idx]\n",
"sgd_clf.C = C\n",
"\n",
"plt.figure(figsize=(5.5,3.2))\n",
"plt.plot(X[:, 0][yr==1], X[:, 1][yr==1], \"g^\")\n",
"plt.plot(X[:, 0][yr==0], X[:, 1][yr==0], \"bs\")\n",
"plot_svc_decision_boundary(sgd_clf, 4, 6)\n",
"plt.xlabel(\"Petal length\", fontsize=14)\n",
"plt.ylabel(\"Petal width\", fontsize=14)\n",
"plt.title(\"SGDClassifier\", fontsize=14)\n",
"plt.axis([4, 6, 0.8, 2.8])\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Exercise solutions"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"## 1. to 7."
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"source": [
"See appendix A."
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"## 8."
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"_Exercise: train a `LinearSVC` on a linearly separable dataset. Then train an `SVC` and a `SGDClassifier` on the same dataset. See if you can get them to produce roughly the same model._"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Let's use the Iris dataset: the Iris Setosa and Iris Versicolor classes are linearly separable."
]
},
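{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"A quick way to check that claim: a linear `SVC` with a very large `C` approximates hard margin classification, so it should classify every Iris Setosa and Iris Versicolor training instance correctly. The variable names below (`iris_chk`, `X_chk`, `y_chk`, `mask`, `hard_clf`) are chosen for this illustration so that they do not overwrite `X` and `y`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn import datasets\n",
"from sklearn.svm import SVC\n",
"\n",
"iris_chk = datasets.load_iris()\n",
"X_chk = iris_chk['data'][:, (2, 3)]  # petal length, petal width\n",
"y_chk = iris_chk['target']\n",
"mask = y_chk < 2  # Iris-Setosa (0) or Iris-Versicolor (1)\n",
"\n",
"hard_clf = SVC(kernel='linear', C=1e6).fit(X_chk[mask], y_chk[mask])\n",
"print(hard_clf.score(X_chk[mask], y_chk[mask]))  # training accuracy"
]
},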
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn import datasets\n",
"\n",
"iris = datasets.load_iris()\n",
"X = iris[\"data\"][:, (2, 3)] # petal length, petal width\n",
"y = iris[\"target\"]\n",
"\n",
"setosa_or_versicolor = (y == 0) | (y == 1)\n",
"X = X[setosa_or_versicolor]\n",
"y = y[setosa_or_versicolor]"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.svm import SVC, LinearSVC\n",
"from sklearn.linear_model import SGDClassifier\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"C = 5\n",
"alpha = 1 / (C * len(X))\n",
"\n",
"lin_clf = LinearSVC(loss=\"hinge\", C=C, random_state=42)\n",
"svm_clf = SVC(kernel=\"linear\", C=C)\n",
"sgd_clf = SGDClassifier(loss=\"hinge\", learning_rate=\"constant\", eta0=0.001, alpha=alpha,\n",
" max_iter=100000, tol=None, random_state=42)\n",
"\n",
"scaler = StandardScaler()\n",
"X_scaled = scaler.fit_transform(X)\n",
"\n",
"lin_clf.fit(X_scaled, y)\n",
"svm_clf.fit(X_scaled, y)\n",
"sgd_clf.fit(X_scaled, y)\n",
"\n",
"print(\"LinearSVC: \", lin_clf.intercept_, lin_clf.coef_)\n",
"print(\"SVC: \", svm_clf.intercept_, svm_clf.coef_)\n",
"print(\"SGDClassifier(alpha={:.5f}):\".format(sgd_clf.alpha), sgd_clf.intercept_, sgd_clf.coef_)"
]
},
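{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Where does `alpha = 1 / (C * len(X))` come from? Scikit-Learn's SGD documentation describes `SGDClassifier(loss=\"hinge\")` with its default L2 penalty as minimizing the average hinge loss plus $\\frac{\\alpha}{2} \\|\\mathbf{w}\\|^2$. Assuming that objective, we can divide the SVM cost function by $Cm$, which does not change its minimizer. This gives exactly the `SGDClassifier` objective with $\\alpha = \\frac{1}{Cm}$:\n",
"\n",
"$$\\frac{1}{Cm} \\left[ \\frac{1}{2} \\|\\mathbf{w}\\|^2 + C \\sum_{i=1}^{m} \\max\\left(0, 1 - t^{(i)} (\\mathbf{w}^T \\mathbf{x}^{(i)} + b)\\right) \\right] = \\frac{1}{m} \\sum_{i=1}^{m} \\max\\left(0, 1 - t^{(i)} (\\mathbf{w}^T \\mathbf{x}^{(i)} + b)\\right) + \\frac{1}{2Cm} \\|\\mathbf{w}\\|^2$$"
]
},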
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Let's plot the decision boundaries of these three models:"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Compute the slope and bias of each decision boundary\n",
"w1 = -lin_clf.coef_[0, 0]/lin_clf.coef_[0, 1]\n",
"b1 = -lin_clf.intercept_[0]/lin_clf.coef_[0, 1]\n",
"w2 = -svm_clf.coef_[0, 0]/svm_clf.coef_[0, 1]\n",
"b2 = -svm_clf.intercept_[0]/svm_clf.coef_[0, 1]\n",
"w3 = -sgd_clf.coef_[0, 0]/sgd_clf.coef_[0, 1]\n",
"b3 = -sgd_clf.intercept_[0]/sgd_clf.coef_[0, 1]\n",
"\n",
"# Transform the decision boundary lines back to the original scale\n",
"line1 = scaler.inverse_transform([[-10, -10 * w1 + b1], [10, 10 * w1 + b1]])\n",
"line2 = scaler.inverse_transform([[-10, -10 * w2 + b2], [10, 10 * w2 + b2]])\n",
"line3 = scaler.inverse_transform([[-10, -10 * w3 + b3], [10, 10 * w3 + b3]])\n",
"\n",
"# Plot all three decision boundaries\n",
"plt.figure(figsize=(11, 4))\n",
"plt.plot(line1[:, 0], line1[:, 1], \"k:\", label=\"LinearSVC\")\n",
"plt.plot(line2[:, 0], line2[:, 1], \"b--\", linewidth=2, label=\"SVC\")\n",
"plt.plot(line3[:, 0], line3[:, 1], \"r-\", label=\"SGDClassifier\")\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"bs\") # label=\"Iris-Versicolor\"\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"yo\") # label=\"Iris-Setosa\"\n",
"plt.xlabel(\"Petal length\", fontsize=14)\n",
"plt.ylabel(\"Petal width\", fontsize=14)\n",
"plt.legend(loc=\"upper center\", fontsize=14)\n",
"plt.axis([0, 5.5, 0, 2])\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Close enough!"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"## 9."
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"_Exercise: train an SVM classifier on the MNIST dataset. Since SVM classifiers are binary classifiers, you will need to use one-versus-all to classify all 10 digits. You may want to tune the hyperparameters using small validation sets to speed up the process. What accuracy can you reach?_"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"First, let's load the dataset and split it into a training set and a test set. We could use `train_test_split()` but people usually just take the first 60,000 instances for the training set, and the last 10,000 instances for the test set (this makes it possible to compare your model's performance with others): "
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
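"from sklearn.datasets import fetch_openml\n",
"\n",
"# fetch_mldata() was removed from Scikit-Learn because mldata.org is offline; fetch_openml() replaces it\n",
"mnist = fetch_openml(\"mnist_784\", version=1, as_frame=False)\n",
"X = mnist[\"data\"]\n",
"y = mnist[\"target\"].astype(np.uint8)  # OpenML returns the labels as strings\n",
"\n",
"X_train = X[:60000]\n",
"y_train = y[:60000]\n",
"X_test = X[60000:]\n",
"y_test = y[60000:]"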
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Many training algorithms are sensitive to the order of the training instances, so it's generally good practice to shuffle them first:"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"np.random.seed(42)\n",
"rnd_idx = np.random.permutation(60000)\n",
"X_train = X_train[rnd_idx]\n",
"y_train = y_train[rnd_idx]"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Let's start simple, with a linear SVM classifier. It will automatically use the One-vs-All (also called One-vs-the-Rest, OvR) strategy, so there's nothing special we need to do. Easy!"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"lin_clf = LinearSVC(random_state=42)\n",
"lin_clf.fit(X_train, y_train)"
]
},
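{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Under the hood, One-vs-All (OvR) simply trains one binary classifier per class and predicts the class whose classifier outputs the highest score. `LinearSVC` already does this for us, since its `multi_class` hyperparameter defaults to `\"ovr\"`. Here is a minimal sketch of the strategy itself. It uses hypothetical helper functions and the small `load_digits()` dataset instead of MNIST, so it runs in seconds:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_digits\n",
"from sklearn.svm import LinearSVC\n",
"\n",
"def fit_one_vs_rest(X, y, make_clf):\n",
"    # One binary classifier per class: this class (1) vs all the other classes (0)\n",
"    classes = np.unique(y)\n",
"    return classes, [make_clf().fit(X, (y == c).astype(int)) for c in classes]\n",
"\n",
"def predict_one_vs_rest(classes, clfs, X):\n",
"    # Pick the class whose binary classifier outputs the highest decision score\n",
"    scores = np.column_stack([clf.decision_function(X) for clf in clfs])\n",
"    return classes[np.argmax(scores, axis=1)]\n",
"\n",
"digits = load_digits()\n",
"X_dig, y_dig = digits.data / 16.0, digits.target  # pixel intensities range from 0 to 16\n",
"classes, clfs = fit_one_vs_rest(X_dig, y_dig, lambda: LinearSVC(random_state=42))\n",
"print(np.mean(predict_one_vs_rest(classes, clfs, X_dig) == y_dig))  # training accuracy"
]
},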
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Let's make predictions on the training set and measure the accuracy (we don't want to measure it on the test set yet, since we have not selected and trained the final model):"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.metrics import accuracy_score\n",
"\n",
"y_pred = lin_clf.predict(X_train)\n",
"accuracy_score(y_train, y_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Wow, 82% accuracy on MNIST is really bad. This linear model is certainly too simple for MNIST, but perhaps we just needed to scale the data first:"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"scaler = StandardScaler()\n",
"X_train_scaled = scaler.fit_transform(X_train.astype(np.float32))\n",
"X_test_scaled = scaler.transform(X_test.astype(np.float32))"
]
},
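{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"`StandardScaler` learns each feature's mean and standard deviation on the training set only (`fit_transform()`), then reuses them on the test set (`transform()`). Some MNIST pixels, such as the image corners, are always 0, so their standard deviation is 0. Scikit-Learn handles this by using a scale of 1 for such features instead of dividing by zero. A tiny illustration on made-up toy data (not MNIST):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# The first feature is constant, like a corner pixel that is always 0\n",
"X_toy_train = np.array([[0.0, 1.0], [0.0, 3.0], [0.0, 5.0]])\n",
"X_toy_test = np.array([[0.0, 7.0]])\n",
"\n",
"toy_scaler = StandardScaler().fit(X_toy_train)\n",
"print(toy_scaler.mean_, toy_scaler.scale_)  # the scale of the constant feature is 1\n",
"print(toy_scaler.transform(X_toy_test))     # (x - mean_) / scale_, computed feature by feature"
]
},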
{
"cell_type": "code",
"execution_count": 52,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"lin_clf = LinearSVC(random_state=42)\n",
"lin_clf.fit(X_train_scaled, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"y_pred = lin_clf.predict(X_train_scaled)\n",
"accuracy_score(y_train, y_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"That's much better (we cut the error rate in half), but still not great at all for MNIST. If we want to use an SVM, we will have to use a kernel. Let's try an `SVC` with an RBF kernel (the default).\n",
"\n",
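"**Warning**: internally, the `SVC` class always trains one binary classifier per pair of classes (One-vs-One, OvO). The `decision_function_shape` hyperparameter only controls the shape of the output of `decision_function()`. Before Scikit-Learn 0.19 this output was OvO-shaped by default, so we explicitly set `decision_function_shape=\"ovr\"` to get one score per class, as with the OvR strategy (this has been the default since 0.19)."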
]
},
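{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"To see what `decision_function_shape` changes, the sketch below fits an RBF `SVC` on a subset of the small `load_digits()` dataset, which has 10 classes. It then compares the shape of the decision function's output. With `\"ovo\"` there is one column per pair of classes ($10 \\times 9 / 2 = 45$ columns); with `\"ovr\"` there is one column per class:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.datasets import load_digits\n",
"from sklearn.svm import SVC\n",
"\n",
"digits = load_digits()\n",
"X_small, y_small = digits.data[:500] / 16.0, digits.target[:500]\n",
"\n",
"shapes = {}\n",
"for shape in ('ovo', 'ovr'):\n",
"    clf = SVC(decision_function_shape=shape).fit(X_small, y_small)\n",
"    shapes[shape] = clf.decision_function(X_small[:3]).shape  # scores for 3 instances\n",
"print(shapes)"
]
},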
{
"cell_type": "code",
"execution_count": 54,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"svm_clf = SVC(decision_function_shape=\"ovr\")\n",
"svm_clf.fit(X_train_scaled[:10000], y_train[:10000])"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"y_pred = svm_clf.predict(X_train_scaled)\n",
"accuracy_score(y_train, y_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"That's promising: we get better performance even though we trained the model on six times fewer instances. Let's tune the hyperparameters by doing a randomized search with cross-validation. We will do this on a small dataset just to speed up the process:"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.model_selection import RandomizedSearchCV\n",
"from scipy.stats import reciprocal, uniform\n",
"\n",
"param_distributions = {\"gamma\": reciprocal(0.001, 0.1), \"C\": uniform(1, 10)}\n",
"rnd_search_cv = RandomizedSearchCV(svm_clf, param_distributions, n_iter=10, verbose=2, random_state=42)\n",
"rnd_search_cv.fit(X_train_scaled[:1000], y_train[:1000])"
]
},
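{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"A note on the two distributions used above. `reciprocal(0.001, 0.1)` is log-uniform, so every order of magnitude between 0.001 and 0.1 is equally likely, which suits a scale-like hyperparameter such as `gamma`. `uniform(1, 10)` follows SciPy's `(loc, scale)` convention, so it samples `C` between 1 and 11, not between 1 and 10. A quick empirical check:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy.stats import reciprocal, uniform\n",
"\n",
"gammas = reciprocal(0.001, 0.1).rvs(size=10000, random_state=42)\n",
"Cs = uniform(1, 10).rvs(size=10000, random_state=42)\n",
"\n",
"print(gammas.min(), gammas.max())  # within [0.001, 0.1]\n",
"print(Cs.min(), Cs.max())          # within [1, 11]\n",
"# Log-uniform: about half of the samples should fall below 0.01, the geometric midpoint\n",
"print(np.mean(gammas < 0.01))"
]
},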
{
"cell_type": "code",
"execution_count": 57,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"rnd_search_cv.best_estimator_"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"rnd_search_cv.best_score_"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"This looks pretty low, but remember that we only trained the model on 1,000 instances. Let's retrain the best estimator on the whole training set (run this overnight; it will take hours):"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"rnd_search_cv.best_estimator_.fit(X_train_scaled, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"y_pred = rnd_search_cv.best_estimator_.predict(X_train_scaled)\n",
"accuracy_score(y_train, y_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Ah, this looks good! Let's select this model. Now we can test it on the test set:"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"y_pred = rnd_search_cv.best_estimator_.predict(X_test_scaled)\n",
"accuracy_score(y_test, y_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Not too bad, but apparently the model is overfitting slightly. It's tempting to tweak the hyperparameters a bit more (e.g. decreasing `C` and/or `gamma`), but we would run the risk of overfitting the test set. Other people have found that the hyperparameters `C=5` and `gamma=0.005` yield even better performance (over 98% accuracy). By running the randomized search for longer and on a larger part of the training set, you may be able to find this as well."
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"## 10."
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"_Exercise: train an SVM regressor on the California housing dataset._"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Let's load the dataset using Scikit-Learn's `fetch_california_housing()` function:"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.datasets import fetch_california_housing\n",
"\n",
"housing = fetch_california_housing()\n",
"X = housing[\"data\"]\n",
"y = housing[\"target\"]"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Split it into a training set and a test set:"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Don't forget to scale the data:"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"scaler = StandardScaler()\n",
"X_train_scaled = scaler.fit_transform(X_train)\n",
"X_test_scaled = scaler.transform(X_test)"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Let's train a simple `LinearSVR` first:"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.svm import LinearSVR\n",
"\n",
2017-06-06 23:13:43 +02:00
"lin_svr = LinearSVR(random_state=42)\n",
2017-06-01 09:23:37 +02:00
"lin_svr.fit(X_train_scaled, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Let's see how it performs on the training set:"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.metrics import mean_squared_error\n",
"\n",
"y_pred = lin_svr.predict(X_train_scaled)\n",
"mse = mean_squared_error(y_train, y_pred)\n",
"mse"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Let's look at the RMSE:"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"np.sqrt(mse)"
]
},
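{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Compared with the mean absolute error (MAE), the RMSE gives a higher weight to large errors. To see this, compare both metrics on two made-up sets of predictions that have the same MAE: four errors of 1, versus a single error of 4. The RMSE doubles for the second set because the errors are squared before being averaged:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import mean_absolute_error, mean_squared_error\n",
"\n",
"targets = np.zeros(4)\n",
"predictions = {\n",
"    'four errors of 1': np.array([1.0, 1.0, 1.0, 1.0]),\n",
"    'one error of 4': np.array([0.0, 0.0, 0.0, 4.0]),\n",
"}\n",
"\n",
"for name, preds in predictions.items():\n",
"    mae = mean_absolute_error(targets, preds)\n",
"    rmse = np.sqrt(mean_squared_error(targets, preds))  # square, average, then take the square root\n",
"    print(name, 'MAE =', mae, 'RMSE =', rmse)"
]
},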
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"In this dataset, the targets are expressed in hundreds of thousands of dollars (a target value of 1 corresponds to $100,000). The RMSE gives a rough idea of the kind of error you should expect (with a higher weight for large errors), so an RMSE close to 1 means errors on the order of a hundred thousand dollars. Not great. Let's see if we can do better with an RBF kernel. We will use a randomized search with cross-validation to find appropriate hyperparameter values for `C` and `gamma`:"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from sklearn.svm import SVR\n",
"from sklearn.model_selection import RandomizedSearchCV\n",
"from scipy.stats import reciprocal, uniform\n",
"\n",
"param_distributions = {\"gamma\": reciprocal(0.001, 0.1), \"C\": uniform(1, 10)}\n",
"rnd_search_cv = RandomizedSearchCV(SVR(), param_distributions, n_iter=10, verbose=2, random_state=42)\n",
"rnd_search_cv.fit(X_train_scaled, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"rnd_search_cv.best_estimator_"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Now let's measure the RMSE on the training set:"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"y_pred = rnd_search_cv.best_estimator_.predict(X_train_scaled)\n",
"mse = mean_squared_error(y_train, y_pred)\n",
"np.sqrt(mse)"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Looks much better than the linear model. Let's select this model and evaluate it on the test set:"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"y_pred = rnd_search_cv.best_estimator_.predict(X_test_scaled)\n",
"mse = mean_squared_error(y_test, y_pred)\n",
"np.sqrt(mse)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.3"
},
"nav_menu": {},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 0
}