handson-ml/05_support_vector_machines....

1374 lines
40 KiB
Plaintext
Raw Normal View History

2016-09-27 23:31:21 +02:00
{
"cells": [
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"**Chapter 5 – Support Vector Machines**\n",
"\n",
"_This notebook contains all the sample code and solutions to the exercises in chapter 5._"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"First, let's make sure this notebook works well in both Python 2 and 3, import a few common modules, ensure Matplotlib plots figures inline, and prepare a function to save the figures:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": true,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"# To support both python 2 and python 3\n",
"from __future__ import division, print_function, unicode_literals\n",
"\n",
"# Common imports\n",
"import numpy as np\n",
"import numpy.random as rnd\n",
"import os\n",
"\n",
"# to make this notebook's output stable across runs\n",
"rnd.seed(42)\n",
"\n",
"# To plot pretty figures\n",
"%matplotlib inline\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"plt.rcParams['axes.labelsize'] = 14\n",
"plt.rcParams['xtick.labelsize'] = 12\n",
"plt.rcParams['ytick.labelsize'] = 12\n",
"\n",
"# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n",
"CHAPTER_ID = \"svm\"\n",
"\n",
"def save_fig(fig_id, tight_layout=True):\n",
"    \"\"\"Save the current matplotlib figure as a 300 dpi PNG.\n",
"\n",
"    The file is written to PROJECT_ROOT_DIR/images/CHAPTER_ID/<fig_id>.png.\n",
"    If tight_layout is true, plt.tight_layout() is applied before saving.\n",
"    \"\"\"\n",
"    fig_dir = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"    # savefig does not create missing directories; os.makedirs(exist_ok=True)\n",
"    # is avoided because this notebook also targets Python 2.\n",
"    if not os.path.isdir(fig_dir):\n",
"        os.makedirs(fig_dir)\n",
"    path = os.path.join(fig_dir, fig_id + \".png\")\n",
"    print(\"Saving figure\", fig_id)\n",
"    if tight_layout:\n",
"        plt.tight_layout()\n",
"    plt.savefig(path, format='png', dpi=300)"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"# Large margin classification"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"from sklearn.svm import SVC\n",
"from sklearn import datasets\n",
"\n",
"iris = datasets.load_iris()\n",
"X = iris[\"data\"][:, (2, 3)] # petal length, petal width\n",
"y = iris[\"target\"]\n",
"\n",
"setosa_or_versicolour = (y == 0) | (y == 1)\n",
"X = X[setosa_or_versicolour]\n",
"y = y[setosa_or_versicolour]\n",
"\n",
"# SVM Classifier model\n",
"svm_clf = SVC(kernel=\"linear\", C=float(\"inf\"))\n",
"svm_clf.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"# Bad models\n",
"x0 = np.linspace(0, 5.5, 200)\n",
"pred_1 = 5*x0 - 20\n",
"pred_2 = x0 - 1.8\n",
"pred_3 = 0.1 * x0 + 0.5\n",
"\n",
"def plot_svc_decision_boundary(svm_clf, xmin, xmax):\n",
"    \"\"\"Draw a fitted linear SVM's decision boundary, margins and support vectors.\n",
"\n",
"    Reads coef_[0], intercept_[0] and support_vectors_ from svm_clf and draws\n",
"    with pyplot for x0 ranging from xmin to xmax.\n",
"    \"\"\"\n",
"    weights = svm_clf.coef_[0]\n",
"    bias = svm_clf.intercept_[0]\n",
"    xs = np.linspace(xmin, xmax, 200)\n",
"\n",
"    # On the boundary w0*x0 + w1*x1 + b = 0, hence x1 = -w0/w1 * x0 - b/w1.\n",
"    # The two margin lines are the boundary shifted by 1/w1 along x1.\n",
"    boundary = -weights[0]/weights[1] * xs - bias/weights[1]\n",
"    offset = 1/weights[1]\n",
"\n",
"    sv = svm_clf.support_vectors_\n",
"    plt.scatter(sv[:, 0], sv[:, 1], s=180, facecolors='#FFAAAA')\n",
"    for line, style in ((boundary, \"k-\"), (boundary + offset, \"k--\"), (boundary - offset, \"k--\")):\n",
"        plt.plot(xs, line, style, linewidth=2)\n",
"\n",
"plt.figure(figsize=(12,2.7))\n",
"\n",
"plt.subplot(121)\n",
"plt.plot(x0, pred_1, \"g--\", linewidth=2)\n",
"plt.plot(x0, pred_2, \"m-\", linewidth=2)\n",
"plt.plot(x0, pred_3, \"r-\", linewidth=2)\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"bs\", label=\"Iris-Versicolour\")\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"yo\", label=\"Iris-Setosa\")\n",
"plt.xlabel(\"Petal length\", fontsize=14)\n",
"plt.ylabel(\"Petal width\", fontsize=14)\n",
"plt.legend(loc=\"upper left\", fontsize=14)\n",
"plt.axis([0, 5.5, 0, 2])\n",
"\n",
"plt.subplot(122)\n",
"plot_svc_decision_boundary(svm_clf, 0, 5.5)\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"bs\")\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"yo\")\n",
"plt.xlabel(\"Petal length\", fontsize=14)\n",
"plt.axis([0, 5.5, 0, 2])\n",
"\n",
"save_fig(\"large_margin_classification_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"# Sensitivity to feature scales"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"Xs = np.array([[1, 50], [5, 20], [3, 80], [5, 60]]).astype(np.float64)\n",
"ys = np.array([0, 0, 1, 1])\n",
"svm_clf = SVC(kernel=\"linear\", C=100)\n",
"svm_clf.fit(Xs, ys)\n",
"\n",
"plt.figure(figsize=(12,3.2))\n",
"plt.subplot(121)\n",
"plt.plot(Xs[:, 0][ys==1], Xs[:, 1][ys==1], \"bo\")\n",
"plt.plot(Xs[:, 0][ys==0], Xs[:, 1][ys==0], \"ms\")\n",
"plot_svc_decision_boundary(svm_clf, 0, 6)\n",
"plt.xlabel(\"$x_0$\", fontsize=20)\n",
"plt.ylabel(\"$x_1$ \", fontsize=20, rotation=0)\n",
"plt.title(\"Unscaled\", fontsize=16)\n",
"plt.axis([0, 6, 0, 90])\n",
"\n",
"from sklearn.preprocessing import StandardScaler\n",
"scaler = StandardScaler()\n",
"X_scaled = scaler.fit_transform(Xs)\n",
"svm_clf.fit(X_scaled, ys)\n",
"\n",
"plt.subplot(122)\n",
"plt.plot(X_scaled[:, 0][ys==1], X_scaled[:, 1][ys==1], \"bo\")\n",
"plt.plot(X_scaled[:, 0][ys==0], X_scaled[:, 1][ys==0], \"ms\")\n",
"plot_svc_decision_boundary(svm_clf, -2, 2)\n",
"plt.xlabel(\"$x_0$\", fontsize=20)\n",
"plt.title(\"Scaled\", fontsize=16)\n",
"plt.axis([-2, 2, -2, 2])\n",
"\n",
"save_fig(\"sensitivity_to_feature_scales_plot\")\n"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"# Sensitivity to outliers"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"X_outliers = np.array([[3.4, 1.3], [3.2, 0.8]])\n",
"y_outliers = np.array([0, 0])\n",
"Xo1 = np.concatenate([X, X_outliers[:1]], axis=0)\n",
"yo1 = np.concatenate([y, y_outliers[:1]], axis=0)\n",
"Xo2 = np.concatenate([X, X_outliers[1:]], axis=0)\n",
"yo2 = np.concatenate([y, y_outliers[1:]], axis=0)\n",
"\n",
"svm_clf2 = SVC(kernel=\"linear\", C=10**9)#float(\"inf\"))\n",
"svm_clf2.fit(Xo2, yo2)\n",
"\n",
"plt.figure(figsize=(12,2.7))\n",
"\n",
"plt.subplot(121)\n",
"plt.plot(Xo1[:, 0][yo1==1], Xo1[:, 1][yo1==1], \"bs\")\n",
"plt.plot(Xo1[:, 0][yo1==0], Xo1[:, 1][yo1==0], \"yo\")\n",
"plt.text(0.3, 1.0, \"Impossible!\", fontsize=24, color=\"red\")\n",
"plt.xlabel(\"Petal length\", fontsize=14)\n",
"plt.ylabel(\"Petal width\", fontsize=14)\n",
"plt.annotate(\"Outlier\",\n",
" xy=(X_outliers[0][0], X_outliers[0][1]),\n",
" xytext=(2.5, 1.7),\n",
" ha=\"center\",\n",
" arrowprops=dict(facecolor='black', shrink=0.1),\n",
" fontsize=16,\n",
" )\n",
"plt.axis([0, 5.5, 0, 2])\n",
"\n",
"plt.subplot(122)\n",
"plt.plot(Xo2[:, 0][yo2==1], Xo2[:, 1][yo2==1], \"bs\")\n",
"plt.plot(Xo2[:, 0][yo2==0], Xo2[:, 1][yo2==0], \"yo\")\n",
"plot_svc_decision_boundary(svm_clf2, 0, 5.5)\n",
"plt.xlabel(\"Petal length\", fontsize=14)\n",
"plt.annotate(\"Outlier\",\n",
" xy=(X_outliers[1][0], X_outliers[1][1]),\n",
" xytext=(3.2, 0.08),\n",
" ha=\"center\",\n",
" arrowprops=dict(facecolor='black', shrink=0.1),\n",
" fontsize=16,\n",
" )\n",
"plt.axis([0, 5.5, 0, 2])\n",
"\n",
"save_fig(\"sensitivity_to_outliers_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"# Large margin *vs* margin violations"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"from sklearn import datasets\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.svm import LinearSVC\n",
"\n",
"iris = datasets.load_iris()\n",
"X = iris[\"data\"][:, (2, 3)] # petal length, petal width\n",
"y = (iris[\"target\"] == 2).astype(np.float64) # Iris-Virginica\n",
"\n",
"scaler = StandardScaler()\n",
"svm_clf1 = LinearSVC(C=100, loss=\"hinge\")\n",
"svm_clf2 = LinearSVC(C=1, loss=\"hinge\")\n",
"\n",
"scaled_svm_clf1 = Pipeline((\n",
" (\"scaler\", scaler),\n",
" (\"linear_svc\", svm_clf1),\n",
" ))\n",
"scaled_svm_clf2 = Pipeline((\n",
" (\"scaler\", scaler),\n",
" (\"linear_svc\", svm_clf2),\n",
" ))\n",
"\n",
"scaled_svm_clf1.fit(X, y)\n",
"scaled_svm_clf2.fit(X, y)\n",
"\n",
"scaled_svm_clf2.predict([[5.5, 1.7]])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"# Convert to unscaled parameters\n",
"b1 = svm_clf1.decision_function([-scaler.mean_ / scaler.scale_])\n",
"b2 = svm_clf2.decision_function([-scaler.mean_ / scaler.scale_])\n",
"w1 = svm_clf1.coef_[0] / scaler.scale_\n",
"w2 = svm_clf2.coef_[0] / scaler.scale_\n",
"svm_clf1.intercept_ = np.array([b1])\n",
"svm_clf2.intercept_ = np.array([b2])\n",
"svm_clf1.coef_ = np.array([w1])\n",
"svm_clf2.coef_ = np.array([w2])\n",
"\n",
"# Find support vectors (LinearSVC does not do this automatically)\n",
"t = y * 2 - 1\n",
"support_vectors_idx1 = (t * (X.dot(w1) + b1) < 1).ravel()\n",
"support_vectors_idx2 = (t * (X.dot(w2) + b2) < 1).ravel()\n",
"svm_clf1.support_vectors_ = X[support_vectors_idx1]\n",
"svm_clf2.support_vectors_ = X[support_vectors_idx2]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"plt.figure(figsize=(12,3.2))\n",
"plt.subplot(121)\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"g^\", label=\"Iris-Virginica\")\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"bs\", label=\"Iris-Versicolour\")\n",
"plot_svc_decision_boundary(svm_clf1, 4, 6)\n",
"plt.xlabel(\"Petal length\", fontsize=14)\n",
"plt.ylabel(\"Petal width\", fontsize=14)\n",
"plt.legend(loc=\"upper left\", fontsize=14)\n",
"plt.title(\"$C = {}$\".format(svm_clf1.C), fontsize=16)\n",
"plt.axis([4, 6, 0.8, 2.8])\n",
"\n",
"plt.subplot(122)\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"g^\")\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"bs\")\n",
"plot_svc_decision_boundary(svm_clf2, 4, 6)\n",
"plt.xlabel(\"Petal length\", fontsize=14)\n",
"plt.title(\"$C = {}$\".format(svm_clf2.C), fontsize=16)\n",
"plt.axis([4, 6, 0.8, 2.8])\n",
"\n",
"save_fig(\"regularization_plot\")"
]
},
{
"cell_type": "markdown",
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": true,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"source": [
"# Non-linear classification"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"X1D = np.linspace(-4, 4, 9).reshape(-1, 1)\n",
"X2D = np.c_[X1D, X1D**2]\n",
"y = np.array([0, 0, 1, 1, 1, 1, 1, 0, 0])\n",
"\n",
"plt.figure(figsize=(11, 4))\n",
"\n",
"plt.subplot(121)\n",
"plt.grid(True, which='both')\n",
"plt.axhline(y=0, color='k')\n",
"plt.plot(X1D[:, 0][y==0], np.zeros(4), \"bs\")\n",
"plt.plot(X1D[:, 0][y==1], np.zeros(5), \"g^\")\n",
"plt.gca().get_yaxis().set_ticks([])\n",
"plt.xlabel(r\"$x_1$\", fontsize=20)\n",
"plt.axis([-4.5, 4.5, -0.2, 0.2])\n",
"\n",
"plt.subplot(122)\n",
"plt.grid(True, which='both')\n",
"plt.axhline(y=0, color='k')\n",
"plt.axvline(x=0, color='k')\n",
"plt.plot(X2D[:, 0][y==0], X2D[:, 1][y==0], \"bs\")\n",
"plt.plot(X2D[:, 0][y==1], X2D[:, 1][y==1], \"g^\")\n",
"plt.xlabel(r\"$x_1$\", fontsize=20)\n",
"plt.ylabel(r\"$x_2$\", fontsize=20, rotation=0)\n",
"plt.gca().get_yaxis().set_ticks([0, 4, 8, 12, 16])\n",
"plt.plot([-4.5, 4.5], [6.5, 6.5], \"r--\", linewidth=3)\n",
"plt.axis([-4.5, 4.5, -1, 17])\n",
"\n",
"plt.subplots_adjust(right=1)\n",
"\n",
"save_fig(\"higher_dimensions_plot\", tight_layout=False)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"from sklearn.datasets import make_moons\n",
"X, y = make_moons(n_samples=100, noise=0.15, random_state=42)\n",
"\n",
"def plot_dataset(X, y, axes):\n",
"    \"\"\"Plot a two-class 2D dataset: class 0 as blue squares, class 1 as green triangles.\n",
"\n",
"    axes is passed straight to plt.axis().\n",
"    \"\"\"\n",
"    for label, marker in ((0, \"bs\"), (1, \"g^\")):\n",
"        members = (y == label)\n",
"        plt.plot(X[:, 0][members], X[:, 1][members], marker)\n",
"    plt.axis(axes)\n",
"    plt.grid(True, which='both')\n",
"    plt.xlabel(r\"$x_1$\", fontsize=20)\n",
"    plt.ylabel(r\"$x_2$\", fontsize=20, rotation=0)\n",
"\n",
"plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"\n",
"polynomial_svm_clf = Pipeline((\n",
" (\"poly_features\", PolynomialFeatures(degree=3)),\n",
" (\"scaler\", StandardScaler()),\n",
" (\"svm_clf\", LinearSVC(C=10, loss=\"hinge\"))\n",
" ))\n",
"\n",
"polynomial_svm_clf.fit(X, y)\n",
"\n",
"def plot_predictions(clf, axes):\n",
"    \"\"\"Shade the region axes = [x0min, x0max, x1min, x1max] using clf.\n",
"\n",
"    A 100x100 grid is evaluated with clf.predict and clf.decision_function,\n",
"    and both results are drawn as translucent filled contours.\n",
"    \"\"\"\n",
"    grid_x0, grid_x1 = np.meshgrid(np.linspace(axes[0], axes[1], 100),\n",
"                                   np.linspace(axes[2], axes[3], 100))\n",
"    grid_points = np.c_[grid_x0.ravel(), grid_x1.ravel()]\n",
"    # Predicted classes are drawn first, decision values on top of them\n",
"    layers = ((clf.predict(grid_points), 0.2), (clf.decision_function(grid_points), 0.1))\n",
"    for values, alpha in layers:\n",
"        plt.contourf(grid_x0, grid_x1, values.reshape(grid_x0.shape), cmap=plt.cm.brg, alpha=alpha)\n",
"\n",
"plot_predictions(polynomial_svm_clf, [-1.5, 2.5, -1, 1.5])\n",
"plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])\n",
"\n",
"save_fig(\"moons_polynomial_svc_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"from sklearn.svm import SVC\n",
"poly_kernel_svm_clf = Pipeline((\n",
" (\"scaler\", StandardScaler()),\n",
" (\"svm_clf\", SVC(kernel=\"poly\", degree=3, coef0=1, C=5))\n",
" ))\n",
"poly100_kernel_svm_clf = Pipeline((\n",
" (\"scaler\", StandardScaler()),\n",
" (\"svm_clf\", SVC(kernel=\"poly\", degree=10, coef0=100, C=5))\n",
" ))\n",
"\n",
"poly_kernel_svm_clf.fit(X, y)\n",
"poly100_kernel_svm_clf.fit(X, y)\n",
"\n",
"plt.figure(figsize=(11, 4))\n",
"\n",
"plt.subplot(121)\n",
"plot_predictions(poly_kernel_svm_clf, [-1.5, 2.5, -1, 1.5])\n",
"plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])\n",
"plt.title(r\"$d=3, r=1, C=5$\", fontsize=18)\n",
"\n",
"plt.subplot(122)\n",
"plot_predictions(poly100_kernel_svm_clf, [-1.5, 2.5, -1, 1.5])\n",
"plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])\n",
"plt.title(r\"$d=10, r=100, C=5$\", fontsize=18)\n",
"\n",
"save_fig(\"moons_kernelized_polynomial_svc_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false,
2017-02-17 11:51:26 +01:00
"deletable": true,
"editable": true,
2016-09-27 23:31:21 +02:00
"scrolled": true
},
"outputs": [],
"source": [
"def gaussian_rbf(x, landmark, gamma):\n",
"    \"\"\"Gaussian RBF similarity exp(-gamma * ||x - landmark||^2), one value per row of x.\"\"\"\n",
"    distances = np.linalg.norm(x - landmark, axis=1)\n",
"    return np.exp(-gamma * distances**2)\n",
"\n",
"gamma = 0.3\n",
"\n",
"x1s = np.linspace(-4.5, 4.5, 200).reshape(-1, 1)\n",
"x2s = gaussian_rbf(x1s, -2, gamma)\n",
"x3s = gaussian_rbf(x1s, 1, gamma)\n",
"\n",
"XK = np.c_[gaussian_rbf(X1D, -2, gamma), gaussian_rbf(X1D, 1, gamma)]\n",
"yk = np.array([0, 0, 1, 1, 1, 1, 1, 0, 0])\n",
"\n",
"plt.figure(figsize=(11, 4))\n",
"\n",
"plt.subplot(121)\n",
"plt.grid(True, which='both')\n",
"plt.axhline(y=0, color='k')\n",
"plt.scatter(x=[-2, 1], y=[0, 0], s=150, alpha=0.5, c=\"red\")\n",
"plt.plot(X1D[:, 0][yk==0], np.zeros(4), \"bs\")\n",
"plt.plot(X1D[:, 0][yk==1], np.zeros(5), \"g^\")\n",
"plt.plot(x1s, x2s, \"g--\")\n",
"plt.plot(x1s, x3s, \"b:\")\n",
"plt.gca().get_yaxis().set_ticks([0, 0.25, 0.5, 0.75, 1])\n",
"plt.xlabel(r\"$x_1$\", fontsize=20)\n",
"plt.ylabel(r\"Similarity\", fontsize=14)\n",
"plt.annotate(r'$\\mathbf{x}$',\n",
" xy=(X1D[3, 0], 0),\n",
" xytext=(-0.5, 0.20),\n",
" ha=\"center\",\n",
" arrowprops=dict(facecolor='black', shrink=0.1),\n",
" fontsize=18,\n",
" )\n",
"plt.text(-2, 0.9, \"$x_2$\", ha=\"center\", fontsize=20)\n",
"plt.text(1, 0.9, \"$x_3$\", ha=\"center\", fontsize=20)\n",
"plt.axis([-4.5, 4.5, -0.1, 1.1])\n",
"\n",
"plt.subplot(122)\n",
"plt.grid(True, which='both')\n",
"plt.axhline(y=0, color='k')\n",
"plt.axvline(x=0, color='k')\n",
"plt.plot(XK[:, 0][yk==0], XK[:, 1][yk==0], \"bs\")\n",
"plt.plot(XK[:, 0][yk==1], XK[:, 1][yk==1], \"g^\")\n",
"plt.xlabel(r\"$x_2$\", fontsize=20)\n",
"plt.ylabel(r\"$x_3$ \", fontsize=20, rotation=0)\n",
"plt.annotate(r'$\\phi\\left(\\mathbf{x}\\right)$',\n",
" xy=(XK[3, 0], XK[3, 1]),\n",
" xytext=(0.65, 0.50),\n",
" ha=\"center\",\n",
" arrowprops=dict(facecolor='black', shrink=0.1),\n",
" fontsize=18,\n",
" )\n",
"plt.plot([-0.1, 1.1], [0.57, -0.1], \"r--\", linewidth=3)\n",
"plt.axis([-0.1, 1.1, -0.1, 1.1])\n",
" \n",
"plt.subplots_adjust(right=1)\n",
"\n",
"save_fig(\"kernel_method_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"x1_example = X1D[3, 0]\n",
"for landmark in (-2, 1):\n",
" k = gaussian_rbf(np.array([[x1_example]]), np.array([[landmark]]), gamma)\n",
" print(\"Phi({}, {}) = {}\".format(x1_example, landmark, k))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"rbf_kernel_svm_clf = Pipeline((\n",
" (\"scaler\", StandardScaler()),\n",
" (\"svm_clf\", SVC(kernel=\"rbf\", gamma=5, C=0.001))\n",
" ))\n",
"rbf_kernel_svm_clf.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": false,
2017-02-17 11:51:26 +01:00
"deletable": true,
"editable": true,
2016-09-27 23:31:21 +02:00
"scrolled": true
},
"outputs": [],
"source": [
"from sklearn.svm import SVC\n",
"\n",
"gamma1, gamma2 = 0.1, 5\n",
"C1, C2 = 0.001, 1000\n",
"hyperparams = (gamma1, C1), (gamma1, C2), (gamma2, C1), (gamma2, C2)\n",
"\n",
"svm_clfs = []\n",
"for gamma, C in hyperparams:\n",
" rbf_kernel_svm_clf = Pipeline((\n",
" (\"scaler\", StandardScaler()),\n",
" (\"svm_clf\", SVC(kernel=\"rbf\", gamma=gamma, C=C))\n",
" ))\n",
" rbf_kernel_svm_clf.fit(X, y)\n",
" svm_clfs.append(rbf_kernel_svm_clf)\n",
"\n",
"plt.figure(figsize=(11, 7))\n",
"\n",
"for i, svm_clf in enumerate(svm_clfs):\n",
" plt.subplot(221 + i)\n",
" plot_predictions(svm_clf, [-1.5, 2.5, -1, 1.5])\n",
" plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])\n",
" gamma, C = hyperparams[i]\n",
" plt.title(r\"$\\gamma = {}, C = {}$\".format(gamma, C), fontsize=16)\n",
"\n",
"save_fig(\"moons_rbf_svc_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"# Regression\n"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"from sklearn.svm import LinearSVR\n",
"\n",
"rnd.seed(42)\n",
"m = 50\n",
"X = 2 * rnd.rand(m, 1)\n",
"y = (4 + 3 * X + rnd.randn(m, 1)).ravel()\n",
"\n",
"svm_reg1 = LinearSVR(epsilon=1.5)\n",
"svm_reg2 = LinearSVR(epsilon=0.5)\n",
"svm_reg1.fit(X, y)\n",
"svm_reg2.fit(X, y)\n",
"\n",
"def find_support_vectors(svm_reg, X, y):\n",
"    \"\"\"Return the indices (np.argwhere layout) of instances whose absolute\n",
"    prediction error is at least svm_reg.epsilon.\"\"\"\n",
"    residuals = y - svm_reg.predict(X)\n",
"    outside_tube = np.abs(residuals) >= svm_reg.epsilon\n",
"    return np.argwhere(outside_tube)\n",
"\n",
"svm_reg1.support_ = find_support_vectors(svm_reg1, X, y)\n",
"svm_reg2.support_ = find_support_vectors(svm_reg2, X, y)\n",
"\n",
"eps_x1 = 1\n",
"eps_y_pred = svm_reg1.predict([[eps_x1]])"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"def plot_svm_regression(svm_reg, X, y, axes):\n",
"    \"\"\"Plot an SVM regressor's predictions, its +/-epsilon lines and the data.\n",
"\n",
"    Instances indexed by svm_reg.support_ are highlighted; axes is used both\n",
"    for the x range of the prediction line (axes[0]..axes[1]) and for plt.axis().\n",
"    \"\"\"\n",
"    xs = np.linspace(axes[0], axes[1], 100).reshape(100, 1)\n",
"    predictions = svm_reg.predict(xs)\n",
"    eps = svm_reg.epsilon\n",
"    plt.plot(xs, predictions, \"k-\", linewidth=2, label=r\"$\\hat{y}$\")\n",
"    for shift in (eps, -eps):\n",
"        plt.plot(xs, predictions + shift, \"k--\")\n",
"    highlighted = svm_reg.support_\n",
"    plt.scatter(X[highlighted], y[highlighted], s=180, facecolors='#FFAAAA')\n",
"    plt.plot(X, y, \"bo\")\n",
"    plt.xlabel(r\"$x_1$\", fontsize=18)\n",
"    plt.legend(loc=\"upper left\", fontsize=18)\n",
"    plt.axis(axes)\n",
"\n",
"plt.figure(figsize=(9, 4))\n",
"plt.subplot(121)\n",
"plot_svm_regression(svm_reg1, X, y, [0, 2, 3, 11])\n",
"plt.title(r\"$\\epsilon = {}$\".format(svm_reg1.epsilon), fontsize=18)\n",
"plt.ylabel(r\"$y$\", fontsize=18, rotation=0)\n",
"#plt.plot([eps_x1, eps_x1], [eps_y_pred, eps_y_pred - svm_reg1.epsilon], \"k-\", linewidth=2)\n",
"plt.annotate(\n",
" '', xy=(eps_x1, eps_y_pred), xycoords='data',\n",
" xytext=(eps_x1, eps_y_pred - svm_reg1.epsilon),\n",
" textcoords='data', arrowprops={'arrowstyle': '<->', 'linewidth': 1.5}\n",
" )\n",
"plt.text(0.91, 5.6, r\"$\\epsilon$\", fontsize=20)\n",
"plt.subplot(122)\n",
"plot_svm_regression(svm_reg2, X, y, [0, 2, 3, 11])\n",
"plt.title(r\"$\\epsilon = {}$\".format(svm_reg2.epsilon), fontsize=18)\n",
"save_fig(\"svm_regression_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"from sklearn.svm import SVR\n",
"\n",
"rnd.seed(42)\n",
"m = 100\n",
"X = 2 * rnd.rand(m, 1) - 1\n",
"y = (0.2 + 0.1 * X + 0.5 * X**2 + rnd.randn(m, 1)/10).ravel()\n",
"\n",
"svm_poly_reg1 = SVR(kernel=\"poly\", degree=2, C=100, epsilon=0.1)\n",
"svm_poly_reg2 = SVR(kernel=\"poly\", degree=2, C=0.01, epsilon=0.1)\n",
"svm_poly_reg1.fit(X, y)\n",
"svm_poly_reg2.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"plt.figure(figsize=(9, 4))\n",
"plt.subplot(121)\n",
"plot_svm_regression(svm_poly_reg1, X, y, [-1, 1, 0, 1])\n",
"plt.title(r\"$degree={}, C={}, \\epsilon = {}$\".format(svm_poly_reg1.degree, svm_poly_reg1.C, svm_poly_reg1.epsilon), fontsize=18)\n",
"plt.ylabel(r\"$y$\", fontsize=18, rotation=0)\n",
"plt.subplot(122)\n",
"plot_svm_regression(svm_poly_reg2, X, y, [-1, 1, 0, 1])\n",
"plt.title(r\"$degree={}, C={}, \\epsilon = {}$\".format(svm_poly_reg2.degree, svm_poly_reg2.C, svm_poly_reg2.epsilon), fontsize=18)\n",
"save_fig(\"svm_with_polynomial_kernel_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"# Under the hood"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": true,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"iris = datasets.load_iris()\n",
"X = iris[\"data\"][:, (2, 3)] # petal length, petal width\n",
"y = (iris[\"target\"] == 2).astype(np.float64) # Iris-Virginica"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"from mpl_toolkits.mplot3d import Axes3D\n",
"\n",
"def plot_3D_decision_function(ax, w, b, x1_lim=[4, 6], x2_lim=[0.8, 2.8]):\n",
"    \"\"\"Plot the decision function h = w.x + b of a two-feature linear model in 3D.\n",
"\n",
"    ax is a 3D axes, w holds the two feature weights and b the bias term.\n",
"    The instances come from the notebook-level X and y; only those whose first\n",
"    feature lies strictly inside x1_lim are drawn. The default limit lists are\n",
"    only read, never mutated.\n",
"    \"\"\"\n",
"    x1_in_bounds = (X[:, 0] > x1_lim[0]) & (X[:, 0] < x1_lim[1])\n",
"    X_crop = X[x1_in_bounds]\n",
"    y_crop = y[x1_in_bounds]\n",
"    x1s = np.linspace(x1_lim[0], x1_lim[1], 20)\n",
"    x2s = np.linspace(x2_lim[0], x2_lim[1], 20)\n",
"    x1, x2 = np.meshgrid(x1s, x2s)\n",
"    xs = np.c_[x1.ravel(), x2.ravel()]\n",
"    df = (xs.dot(w) + b).reshape(x1.shape)\n",
"    # Lines in the h=0 plane where the decision function equals 0, +1 and -1\n",
"    boundary_x2s = -x1s*(w[0]/w[1])-b/w[1]\n",
"    margin_x2s_1 = -x1s*(w[0]/w[1])-(b-1)/w[1]\n",
"    margin_x2s_2 = -x1s*(w[0]/w[1])-(b+1)/w[1]\n",
"    # plot_surface expects 2D X, Y and Z arrays, so the h=0 plane is given as\n",
"    # the meshgrid plus an explicit grid of zeros (a scalar Z is rejected by\n",
"    # recent matplotlib versions).\n",
"    ax.plot_surface(x1, x2, np.zeros_like(x1), color=\"b\", alpha=0.2, cstride=100, rstride=100)\n",
"    ax.plot(x1s, boundary_x2s, 0, \"k-\", linewidth=2, label=r\"$h=0$\")\n",
"    ax.plot(x1s, margin_x2s_1, 0, \"k--\", linewidth=2, label=r\"$h=\\pm 1$\")\n",
"    ax.plot(x1s, margin_x2s_2, 0, \"k--\", linewidth=2)\n",
"    ax.plot(X_crop[:, 0][y_crop==1], X_crop[:, 1][y_crop==1], 0, \"g^\")\n",
"    ax.plot_wireframe(x1, x2, df, alpha=0.3, color=\"k\")\n",
"    ax.plot(X_crop[:, 0][y_crop==0], X_crop[:, 1][y_crop==0], 0, \"bs\")\n",
"    ax.axis(x1_lim + x2_lim)\n",
"    ax.text(4.5, 2.5, 3.8, \"Decision function $h$\", fontsize=15)\n",
"    ax.set_xlabel(r\"Petal length\", fontsize=15)\n",
"    ax.set_ylabel(r\"Petal width\", fontsize=15)\n",
"    ax.set_zlabel(r\"$h = \\mathbf{w}^t \\cdot \\mathbf{x} + b$\", fontsize=18)\n",
"    ax.legend(loc=\"upper left\", fontsize=16)\n",
"\n",
"fig = plt.figure(figsize=(11, 6))\n",
"ax1 = fig.add_subplot(111, projection='3d')\n",
"plot_3D_decision_function(ax1, w=svm_clf2.coef_[0], b=svm_clf2.intercept_[0])\n",
"\n",
"save_fig(\"iris_3D_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"# Small weight vector results in a large margin"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"def plot_2D_decision_function(w, b, ylabel=True, x1_lim=[-3, 3]):\n",
"    \"\"\"Plot the 1D decision function w * x1 + b over x1_lim.\n",
"\n",
"    Dotted lines mark the levels +1 and -1, and the points x1 = +/-(1/w) are\n",
"    marked on the horizontal axis. The y label is drawn only if ylabel is true.\n",
"    \"\"\"\n",
"    xs = np.linspace(x1_lim[0], x1_lim[1], 200)\n",
"    half_margin = 1 / w\n",
"\n",
"    plt.plot(xs, w * xs + b)\n",
"    for level in (1, -1):\n",
"        plt.plot(x1_lim, [level, level], \"k:\")\n",
"    plt.axhline(y=0, color='k')\n",
"    plt.axvline(x=0, color='k')\n",
"    plt.plot([half_margin, half_margin], [0, 1], \"k--\")\n",
"    plt.plot([-half_margin, -half_margin], [0, -1], \"k--\")\n",
"    plt.plot([-half_margin, half_margin], [0, 0], \"k-o\", linewidth=3)\n",
"    plt.axis(x1_lim + [-2, 2])\n",
"    plt.xlabel(r\"$x_1$\", fontsize=16)\n",
"    if ylabel:\n",
"        plt.ylabel(r\"$w_1 x_1$ \", rotation=0, fontsize=16)\n",
"    plt.title(r\"$w_1 = {}$\".format(w), fontsize=16)\n",
"\n",
"plt.figure(figsize=(12, 3.2))\n",
"plt.subplot(121)\n",
"plot_2D_decision_function(1, 0)\n",
"plt.subplot(122)\n",
"plot_2D_decision_function(0.5, 0, ylabel=False)\n",
"save_fig(\"small_w_large_margin_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"from sklearn.svm import SVC\n",
"from sklearn import datasets\n",
"\n",
"iris = datasets.load_iris()\n",
"X = iris[\"data\"][:, (2, 3)] # petal length, petal width\n",
"y = (iris[\"target\"] == 2).astype(np.float64) # Iris-Virginica\n",
"\n",
"svm_clf = SVC(kernel=\"linear\", C=1)\n",
"svm_clf.fit(X, y)\n",
"svm_clf.predict([[5.3, 1.3]])"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"# Hinge loss"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"t = np.linspace(-2, 4, 200)\n",
"h = np.where(1 - t < 0, 0, 1 - t) # max(0, 1-t)\n",
"\n",
"plt.figure(figsize=(5,2.8))\n",
"plt.plot(t, h, \"b-\", linewidth=2, label=\"$max(0, 1 - t)$\")\n",
"plt.grid(True, which='both')\n",
"plt.axhline(y=0, color='k')\n",
"plt.axvline(x=0, color='k')\n",
"plt.yticks(np.arange(-1, 2.5, 1))\n",
"plt.xlabel(\"$t$\", fontsize=16)\n",
"plt.axis([-2, 4, -1, 2.5])\n",
"plt.legend(loc=\"upper right\", fontsize=16)\n",
"save_fig(\"hinge_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"# Extra material"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"## Training time"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"X, y = make_moons(n_samples=1000, noise=0.4)\n",
"plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"bs\")\n",
"plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"g^\")"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"import time\n",
"\n",
"tol = 0.1\n",
"tols = []\n",
"times = []\n",
"for i in range(10):\n",
" svm_clf = SVC(kernel=\"poly\", gamma=3, C=10, tol=tol, verbose=1)\n",
" t1 = time.time()\n",
" svm_clf.fit(X, y)\n",
" t2 = time.time()\n",
" times.append(t2-t1)\n",
" tols.append(tol)\n",
" print(i, tol, t2-t1)\n",
" tol /= 10\n",
"plt.semilogx(tols, times)"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"## Identical linear classifiers"
]
},
{
"cell_type": "code",
2017-02-17 11:51:26 +01:00
"execution_count": 28,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"from sklearn.svm import SVC, LinearSVC\n",
"from sklearn.linear_model import SGDClassifier\n",
"from sklearn.datasets import make_moons\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"X, y = make_moons(n_samples=100, noise=0.15, random_state=42)\n",
"\n",
"C = 5\n",
"alpha = 1 / (C * len(X))\n",
"\n",
"sgd_clf = SGDClassifier(loss=\"hinge\", learning_rate=\"constant\", eta0=0.001, alpha=alpha, n_iter=100000, random_state=42)\n",
"svm_clf = SVC(kernel=\"linear\", C=C)\n",
"lin_clf = LinearSVC(loss=\"hinge\", C=C)\n",
"\n",
"X_scaled = StandardScaler().fit_transform(X)\n",
"sgd_clf.fit(X_scaled, y)\n",
"svm_clf.fit(X_scaled, y)\n",
"lin_clf.fit(X_scaled, y)\n",
"\n",
2017-02-17 11:51:26 +01:00
"print(\"SGDClassifier(alpha={}): \".format(sgd_clf.alpha), sgd_clf.intercept_, sgd_clf.coef_)\n",
2016-09-27 23:31:21 +02:00
"print(\"SVC: \", svm_clf.intercept_, svm_clf.coef_)\n",
"print(\"LinearSVC: \", lin_clf.intercept_, lin_clf.coef_)"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"## Linear SVM classifier implementation using Batch Gradient Descent"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"# Training set: two petal features and a binary target stored as a column vector\n",
"petal_columns = (2, 3)  # petal length, petal width\n",
"X = iris[\"data\"][:, petal_columns]\n",
"is_virginica = (iris[\"target\"] == 2)  # Iris-Virginica\n",
"y = is_virginica.astype(np.float64).reshape(-1, 1)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"from sklearn.base import BaseEstimator\n",
"\n",
"class MyLinearSVC(BaseEstimator):\n",
"    \"\"\"Soft-margin linear SVM classifier trained with Batch Gradient Descent.\n",
"\n",
"    Minimizes J(w, b) = 1/2 * w.w + C * sum(max(0, 1 - t * (x.w + b))), where\n",
"    t is -1 for label 0 and +1 for label 1. The learning rate used at each\n",
"    epoch is eta0 / (epoch + eta_d).\n",
"\n",
"    After fit(), the model exposes intercept_, coef_, support_vectors_ and Js\n",
"    (the cost recorded at every epoch).\n",
"    \"\"\"\n",
"    def __init__(self, C=1, eta0=1, eta_d=10000, n_epochs=1000, random_state=None):\n",
"        self.C = C\n",
"        self.eta0 = eta0\n",
"        self.n_epochs = n_epochs\n",
"        self.random_state = random_state\n",
"        self.eta_d = eta_d\n",
"\n",
"    def eta(self, epoch):\n",
"        \"\"\"Return the learning rate to use at the given epoch.\"\"\"\n",
"        return self.eta0 / (epoch + self.eta_d)\n",
"\n",
"    def fit(self, X, y):\n",
"        \"\"\"Fit the model with Batch Gradient Descent and return self.\n",
"\n",
"        X is an (m, n) array of features; y holds the m labels (0 or 1), either\n",
"        as a 1-D array or as an (m, 1) column. The cost computed at the start\n",
"        of every epoch is appended to self.Js.\n",
"        \"\"\"\n",
"        # Random initialization (seed 0 is a valid seed, so compare to None)\n",
"        if self.random_state is not None:\n",
"            rnd.seed(self.random_state)\n",
"        w = rnd.randn(X.shape[1], 1) # n feature weights\n",
"        b = 0\n",
"\n",
"        # Work with a column vector so that X * t scales each row by its target\n",
"        y = np.asarray(y).reshape(-1, 1)\n",
"        t = y * 2 - 1 # -1 if y==0, +1 if y==1\n",
"        X_t = X * t\n",
"        self.Js = []\n",
"\n",
"        # Training\n",
"        for epoch in range(self.n_epochs):\n",
"            # Only the instances violating the margin, t * (x.w + b) < 1,\n",
"            # contribute to the hinge loss and to its gradient\n",
"            support_vectors_idx = (X_t.dot(w) + t * b < 1).ravel()\n",
"            X_t_sv = X_t[support_vectors_idx]\n",
"            t_sv = t[support_vectors_idx]\n",
"\n",
"            J = 0.5 * np.sum(w * w) + self.C * (np.sum(1 - X_t_sv.dot(w)) - b * np.sum(t_sv))\n",
"            self.Js.append(J)\n",
"\n",
"            w_gradient_vector = w - self.C * np.sum(X_t_sv, axis=0).reshape(-1, 1)\n",
"            # Use this estimator's own C, not whatever global C happens to exist\n",
"            b_derivative = -self.C * np.sum(t_sv)\n",
"\n",
"            learning_rate = self.eta(epoch)\n",
"            w = w - learning_rate * w_gradient_vector\n",
"            b = b - learning_rate * b_derivative\n",
"\n",
"        self.intercept_ = np.array([b])\n",
"        self.coef_ = np.array([w])\n",
"        # Same margin test as in the training loop: the bias must be multiplied\n",
"        # by t, otherwise negative instances are selected incorrectly\n",
"        support_vectors_idx = (X_t.dot(w) + t * b < 1).ravel()\n",
"        self.support_vectors_ = X[support_vectors_idx]\n",
"        return self\n",
"\n",
"    def decision_function(self, X):\n",
"        \"\"\"Return the scores x.w + b for each row of X, as an (m, 1) array.\"\"\"\n",
"        return X.dot(self.coef_[0]) + self.intercept_[0]\n",
"\n",
"    def predict(self, X):\n",
"        \"\"\"Return 1.0 where the decision function is >= 0 and 0.0 elsewhere.\"\"\"\n",
"        return (self.decision_function(X) >= 0).astype(np.float64)\n",
"\n",
"# Train the hand-written classifier, then classify two sample flowers\n",
"C = 2\n",
"svm_clf = MyLinearSVC(C=C, eta0=10, eta_d=1000, n_epochs=60000, random_state=2).fit(X, y)\n",
"X_new = np.array([[5, 2], [4, 1]])  # rows are [petal length, petal width]\n",
"svm_clf.predict(X_new)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"# Cost recorded by MyLinearSVC.fit() at each epoch\n",
"plt.plot(range(svm_clf.n_epochs), svm_clf.Js)\n",
"plt.xlabel(\"Epoch\", fontsize=14)\n",
"plt.ylabel(\"Cost J\", fontsize=14)\n",
"plt.axis([0, svm_clf.n_epochs, 0, 100])"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"print(svm_clf.intercept_, svm_clf.coef_)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"# Reference model: Scikit-Learn's SVC with the same C (it expects a 1-D target)\n",
"svm_clf2 = SVC(kernel=\"linear\", C=C).fit(X, y.ravel())\n",
"print(svm_clf2.intercept_, svm_clf2.coef_)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"# Side-by-side comparison of the two decision boundaries on the same data\n",
"yr = y.ravel()\n",
"virginica_mask = (yr == 1)\n",
"other_mask = (yr == 0)\n",
"axis_limits = [4, 6, 0.8, 2.8]\n",
"\n",
"plt.figure(figsize=(12, 3.2))\n",
"\n",
"# Left panel: the hand-written MyLinearSVC\n",
"plt.subplot(121)\n",
"plt.plot(X[virginica_mask, 0], X[virginica_mask, 1], \"g^\", label=\"Iris-Virginica\")\n",
"plt.plot(X[other_mask, 0], X[other_mask, 1], \"bs\", label=\"Not Iris-Virginica\")\n",
"plot_svc_decision_boundary(svm_clf, 4, 6)\n",
"plt.xlabel(\"Petal length\", fontsize=14)\n",
"plt.ylabel(\"Petal width\", fontsize=14)\n",
"plt.title(\"MyLinearSVC\", fontsize=14)\n",
"plt.axis(axis_limits)\n",
"\n",
"# Right panel: Scikit-Learn's SVC with a linear kernel\n",
"plt.subplot(122)\n",
"plt.plot(X[virginica_mask, 0], X[virginica_mask, 1], \"g^\")\n",
"plt.plot(X[other_mask, 0], X[other_mask, 1], \"bs\")\n",
"plot_svc_decision_boundary(svm_clf2, 4, 6)\n",
"plt.xlabel(\"Petal length\", fontsize=14)\n",
"plt.title(\"SVC\", fontsize=14)\n",
"plt.axis(axis_limits)\n"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"collapsed": false,
2017-02-17 11:51:26 +01:00
"deletable": true,
"editable": true,
2016-09-27 23:31:21 +02:00
"scrolled": true
},
"outputs": [],
"source": [
"from sklearn.linear_model import SGDClassifier\n",
"\n",
"sgd_clf = SGDClassifier(loss=\"hinge\", alpha=0.017, n_iter=50, random_state=42)\n",
"sgd_clf.fit(X, y.ravel())\n",
"\n",
"# Find the instances that violate the margin, i.e. t * (theta . x_b) < 1\n",
"m = len(X)\n",
"t = y * 2 - 1  # -1 if y==0, +1 if y==1\n",
"X_b = np.hstack([np.ones((m, 1)), X])  # Add bias input x0=1\n",
"X_b_t = X_b * t\n",
"sgd_theta = np.concatenate(([sgd_clf.intercept_[0]], sgd_clf.coef_[0]))\n",
"print(sgd_theta)\n",
"support_vectors_idx = (X_b_t.dot(sgd_theta) < 1).ravel()\n",
"\n",
"# SGDClassifier has no support_vectors_ or C attribute; they are attached here,\n",
"# presumably because plot_svc_decision_boundary reads them (defined in an earlier cell)\n",
"sgd_clf.support_vectors_ = X[support_vectors_idx]\n",
"sgd_clf.C = C\n",
"\n",
"plt.figure(figsize=(5.5, 3.2))\n",
"plt.plot(X[yr == 1, 0], X[yr == 1, 1], \"g^\")\n",
"plt.plot(X[yr == 0, 0], X[yr == 0, 1], \"bs\")\n",
"plot_svc_decision_boundary(sgd_clf, 4, 6)\n",
"plt.xlabel(\"Petal length\", fontsize=14)\n",
"plt.ylabel(\"Petal width\", fontsize=14)\n",
"plt.title(\"SGDClassifier\", fontsize=14)\n",
"plt.axis([4, 6, 0.8, 2.8])\n"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"# Exercise solutions"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"**Coming soon**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": true,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2017-02-17 11:51:26 +01:00
"version": "3.5.2+"
2016-09-27 23:31:21 +02:00
},
"nav_menu": {},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 0
}