{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "**Chapter 3 – Classification**\n", "\n", "_This notebook contains all the sample code and solutions to the exercises in chapter 3._" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, let's make sure this notebook works well in both Python 2 and 3, import a few common modules, ensure Matplotlib plots figures inline and prepare a function to save the figures:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# To support both python 2 and python 3\n", "from __future__ import division, print_function, unicode_literals\n", "\n", "# Common imports\n", "import numpy as np\n", "import os\n", "\n", "# to make this notebook's output stable across runs\n", "np.random.seed(42)\n", "\n", "# To plot pretty figures\n", "%matplotlib inline\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "plt.rcParams['axes.labelsize'] = 14\n", "plt.rcParams['xtick.labelsize'] = 12\n", "plt.rcParams['ytick.labelsize'] = 12\n", "\n", "# Where to save the figures\n", "PROJECT_ROOT_DIR = \".\"\n", "CHAPTER_ID = \"classification\"\n", "IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n", "\n", "def save_fig(fig_id, tight_layout=True):\n", "    \"\"\"Save the current figure as IMAGES_PATH/<fig_id>.png at 300 dpi.\n", "\n", "    The folder is created on first use because plt.savefig() does not create it.\n", "    If tight_layout is true, plt.tight_layout() is applied before saving.\n", "    \"\"\"\n", "    if not os.path.isdir(IMAGES_PATH):\n", "        os.makedirs(IMAGES_PATH)\n", "    path = os.path.join(IMAGES_PATH, fig_id + \".png\")\n", "    print(\"Saving figure\", fig_id)\n", "    if tight_layout:\n", "        plt.tight_layout()\n", "    plt.savefig(path, format='png', dpi=300)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# MNIST" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def sort_by_target(mnist):\n", "    \"\"\"Sort the training part (first 60,000 rows) and the test part (last 10,000 rows)\n", "    of `mnist` by target, in place, keeping the original order among equal targets.\n", "\n", "    fetch_openml() returns MNIST unsorted, whereas the old fetch_mldata() returned the\n", "    digits grouped by class; index-based cells below (e.g. X[36000] and the slices used\n", "    to plot example digits) rely on that grouping.\n", "    \"\"\"\n", "    reorder_train = np.argsort(mnist.target[:60000], kind=\"mergesort\")  # mergesort is stable\n", "    reorder_test = np.argsort(mnist.target[60000:], kind=\"mergesort\") + 60000\n", "    mnist.data[:60000] = mnist.data[reorder_train]\n", "    mnist.target[:60000] = mnist.target[reorder_train]\n", "    mnist.data[60000:] = mnist.data[reorder_test]\n", "    mnist.target[60000:] = mnist.target[reorder_test]\n", "\n", "try:\n", "    from sklearn.datasets import fetch_openml  # Scikit-Learn >= 0.20\n", "except ImportError:\n", "    fetch_openml = None\n", "\n", "if fetch_openml is not None:\n", "    # fetch_mldata() was removed in Scikit-Learn 0.22 and mldata.org has shut down\n", "    try:\n", "        mnist = fetch_openml('mnist_784', version=1, as_frame=False)\n", "    except TypeError:  # Scikit-Learn 0.20 and 0.21 have no as_frame argument\n", "        mnist = fetch_openml('mnist_784', version=1)\n", "    mnist.target = mnist.target.astype(np.int8)  # fetch_openml() returns the targets as strings\n", "    sort_by_target(mnist)\n", "else:\n", "    from sklearn.datasets import fetch_mldata\n", "    mnist = fetch_mldata('MNIST original')\n", "mnist" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "X, y = mnist[\"data\"], mnist[\"target\"]\n", "X.shape" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, 
"outputs": [], "source": [ "y.shape" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "28*28" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "\n", "some_digit = X[36000]\n", "some_digit_image = some_digit.reshape(28, 28)\n", "plt.imshow(some_digit_image, cmap = matplotlib.cm.binary,\n", " interpolation=\"nearest\")\n", "plt.axis(\"off\")\n", "\n", "save_fig(\"some_digit_plot\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def plot_digit(data):\n", " image = data.reshape(28, 28)\n", " plt.imshow(image, cmap = matplotlib.cm.binary,\n", " interpolation=\"nearest\")\n", " plt.axis(\"off\")" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# EXTRA\n", "def plot_digits(instances, images_per_row=10, **options):\n", " size = 28\n", " images_per_row = min(len(instances), images_per_row)\n", " images = [instance.reshape(size,size) for instance in instances]\n", " n_rows = (len(instances) - 1) // images_per_row + 1\n", " row_images = []\n", " n_empty = n_rows * images_per_row - len(instances)\n", " images.append(np.zeros((size, size * n_empty)))\n", " for row in range(n_rows):\n", " rimages = images[row * images_per_row : (row + 1) * images_per_row]\n", " row_images.append(np.concatenate(rimages, axis=1))\n", " image = np.concatenate(row_images, axis=0)\n", " plt.imshow(image, cmap = matplotlib.cm.binary, **options)\n", " plt.axis(\"off\")" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(9,9))\n", "example_images = np.r_[X[:12000:600], X[13000:30600:600], X[30600:60000:590]]\n", "plot_digits(example_images, images_per_row=10)\n", "save_fig(\"more_digits_plot\")\n", "plt.show()" ] }, { 
"cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "y[36000]" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": true }, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import numpy as np\n", "\n", "shuffle_index = np.random.permutation(60000)\n", "X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Binary classifier" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": true }, "outputs": [], "source": [ "y_train_5 = (y_train == 5)\n", "y_test_5 = (y_test == 5)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "from sklearn.linear_model import SGDClassifier\n", "\n", "sgd_clf = SGDClassifier(random_state=42)\n", "sgd_clf.fit(X_train, y_train_5)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "sgd_clf.predict([some_digit])" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import cross_val_score\n", "cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring=\"accuracy\")" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import StratifiedKFold\n", "from sklearn.base import clone\n", "\n", "# No shuffling: these are the same folds that cross_val_score(cv=3) uses for a classifier,\n", "# so the scores printed below should match the ones above. (random_state was ignored\n", "# without shuffle=True, and passing it raises a ValueError since Scikit-Learn 0.24.)\n", "skfolds = StratifiedKFold(n_splits=3)\n", "\n", "for train_index, test_index in skfolds.split(X_train, y_train_5):\n", "    clone_clf = clone(sgd_clf)\n", "    X_train_folds = X_train[train_index]\n", "    y_train_folds = (y_train_5[train_index])\n", "    X_test_fold = X_train[test_index]\n", "    y_test_fold = (y_train_5[test_index])\n", "\n", "    clone_clf.fit(X_train_folds, y_train_folds)\n", "    y_pred = 
clone_clf.predict(X_test_fold)\n", "    n_correct = sum(y_pred == y_test_fold)\n", "    print(n_correct / len(y_pred))" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.base import BaseEstimator\n", "class Never5Classifier(BaseEstimator):\n", "    \"\"\"Baseline classifier that predicts False (i.e. not a 5) for every instance.\"\"\"\n", "    def fit(self, X, y=None):\n", "        return self  # nothing to learn; Scikit-Learn's fit() convention is to return self\n", "    def predict(self, X):\n", "        return np.zeros(len(X), dtype=bool)  # shape (n_samples,), as expected from predict()" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "never_5_clf = Never5Classifier()\n", "cross_val_score(never_5_clf, X_train, y_train_5, cv=3, scoring=\"accuracy\")" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.model_selection import cross_val_predict\n", "\n", "y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import confusion_matrix\n", "\n", "confusion_matrix(y_train_5, y_train_pred)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "collapsed": true }, "outputs": [], "source": [ "y_train_perfect_predictions = y_train_5" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "confusion_matrix(y_train_5, y_train_perfect_predictions)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import precision_score, recall_score\n", "\n", "precision_score(y_train_5, y_train_pred)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "cm = confusion_matrix(y_train_5, y_train_pred)  # rows: actual class, columns: predicted class\n", "cm[1, 1] / (cm[0, 1] + cm[1, 1])  # TP / (FP + TP): should match precision_score() above" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "recall_score(y_train_5, y_train_pred)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "cm[1, 1] / (cm[1, 0] + cm[1, 1])  # TP / (FN + TP): should match recall_score() above" ] }, { "cell_type": "code", 
"execution_count": 28, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import f1_score\n", "f1_score(y_train_5, y_train_pred)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "4344 / (4344 + (1077 + 1307)/2)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "y_scores = sgd_clf.decision_function([some_digit])\n", "y_scores" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "collapsed": true }, "outputs": [], "source": [ "threshold = 0\n", "y_some_digit_pred = (y_scores > threshold)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "y_some_digit_pred" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "threshold = 200000\n", "y_some_digit_pred = (y_scores > threshold)\n", "y_some_digit_pred" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "collapsed": true }, "outputs": [], "source": [ "y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3,\n", " method=\"decision_function\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note: there is an [issue](https://github.com/scikit-learn/scikit-learn/issues/9589) introduced in Scikit-Learn 0.19.0 where the result of `cross_val_predict()` is incorrect in the binary classification case when using `method=\"decision_function\"`, as in the code above. The resulting array has an extra first dimension full of 0s. 
We need to add this small hack for now to work around this issue:" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "y_scores.shape" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# hack to work around issue #9589 introduced in Scikit-Learn 0.19.0\n", "if y_scores.ndim == 2:\n", " y_scores = y_scores[:, 1]" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.metrics import precision_recall_curve\n", "\n", "precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):\n", " plt.plot(thresholds, precisions[:-1], \"b--\", label=\"Precision\", linewidth=2)\n", " plt.plot(thresholds, recalls[:-1], \"g-\", label=\"Recall\", linewidth=2)\n", " plt.xlabel(\"Threshold\", fontsize=16)\n", " plt.legend(loc=\"upper left\", fontsize=16)\n", " plt.ylim([0, 1])\n", "\n", "plt.figure(figsize=(8, 4))\n", "plot_precision_recall_vs_threshold(precisions, recalls, thresholds)\n", "plt.xlim([-700000, 700000])\n", "save_fig(\"precision_recall_vs_threshold_plot\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "(y_train_pred == (y_scores > 0)).all()" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "collapsed": true }, "outputs": [], "source": [ "y_train_pred_90 = (y_scores > 70000)" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "precision_score(y_train_5, y_train_pred_90)" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [], "source": [ "recall_score(y_train_5, y_train_pred_90)" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": 
[ "def plot_precision_vs_recall(precisions, recalls):\n", " plt.plot(recalls, precisions, \"b-\", linewidth=2)\n", " plt.xlabel(\"Recall\", fontsize=16)\n", " plt.ylabel(\"Precision\", fontsize=16)\n", " plt.axis([0, 1, 0, 1])\n", "\n", "plt.figure(figsize=(8, 6))\n", "plot_precision_vs_recall(precisions, recalls)\n", "save_fig(\"precision_vs_recall_plot\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# ROC curves" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.metrics import roc_curve\n", "\n", "fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [], "source": [ "def plot_roc_curve(fpr, tpr, label=None):\n", " plt.plot(fpr, tpr, linewidth=2, label=label)\n", " plt.plot([0, 1], [0, 1], 'k--')\n", " plt.axis([0, 1, 0, 1])\n", " plt.xlabel('False Positive Rate', fontsize=16)\n", " plt.ylabel('True Positive Rate', fontsize=16)\n", "\n", "plt.figure(figsize=(8, 6))\n", "plot_roc_curve(fpr, tpr)\n", "save_fig(\"roc_curve_plot\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import roc_auc_score\n", "\n", "roc_auc_score(y_train_5, y_scores)" ] }, { "cell_type": "code", "execution_count": 47, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.ensemble import RandomForestClassifier\n", "forest_clf = RandomForestClassifier(random_state=42)\n", "y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3,\n", " method=\"predict_proba\")" ] }, { "cell_type": "code", "execution_count": 48, "metadata": { "collapsed": true }, "outputs": [], "source": [ "y_scores_forest = y_probas_forest[:, 1] # score = proba of positive class\n", "fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5,y_scores_forest)" ] }, { "cell_type": "code", "execution_count": 49, 
"metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(8, 6))\n", "plt.plot(fpr, tpr, \"b:\", linewidth=2, label=\"SGD\")\n", "plot_roc_curve(fpr_forest, tpr_forest, \"Random Forest\")\n", "plt.legend(loc=\"lower right\", fontsize=16)\n", "save_fig(\"roc_curve_comparison_plot\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [], "source": [ "roc_auc_score(y_train_5, y_scores_forest)" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [], "source": [ "y_train_pred_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3)\n", "precision_score(y_train_5, y_train_pred_forest)" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "recall_score(y_train_5, y_train_pred_forest)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Multiclass classification" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "sgd_clf.fit(X_train, y_train)\n", "sgd_clf.predict([some_digit])" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [], "source": [ "some_digit_scores = sgd_clf.decision_function([some_digit])\n", "some_digit_scores" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [], "source": [ "np.argmax(some_digit_scores)" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "sgd_clf.classes_" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [], "source": [ "sgd_clf.classes_[5]" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [], "source": [ "from sklearn.multiclass import OneVsOneClassifier\n", "ovo_clf = OneVsOneClassifier(SGDClassifier(random_state=42))\n", "ovo_clf.fit(X_train, y_train)\n", "ovo_clf.predict([some_digit])" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [], "source": [ "len(ovo_clf.estimators_)" 
] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [], "source": [ "forest_clf.fit(X_train, y_train)\n", "forest_clf.predict([some_digit])" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [], "source": [ "forest_clf.predict_proba([some_digit])" ] }, { "cell_type": "code", "execution_count": 62, "metadata": {}, "outputs": [], "source": [ "cross_val_score(sgd_clf, X_train, y_train, cv=3, scoring=\"accuracy\")" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [], "source": [ "from sklearn.preprocessing import StandardScaler\n", "scaler = StandardScaler()\n", "X_train_scaled = scaler.fit_transform(X_train.astype(np.float64))\n", "cross_val_score(sgd_clf, X_train_scaled, y_train, cv=3, scoring=\"accuracy\")" ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [], "source": [ "y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)\n", "conf_mx = confusion_matrix(y_train, y_train_pred)\n", "conf_mx" ] }, { "cell_type": "code", "execution_count": 65, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def plot_confusion_matrix(matrix):\n", " \"\"\"If you prefer color and a colorbar\"\"\"\n", " fig = plt.figure(figsize=(8,8))\n", " ax = fig.add_subplot(111)\n", " cax = ax.matshow(matrix)\n", " fig.colorbar(cax)" ] }, { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [], "source": [ "plt.matshow(conf_mx, cmap=plt.cm.gray)\n", "save_fig(\"confusion_matrix_plot\", tight_layout=False)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 67, "metadata": { "collapsed": true }, "outputs": [], "source": [ "row_sums = conf_mx.sum(axis=1, keepdims=True)\n", "norm_conf_mx = conf_mx / row_sums" ] }, { "cell_type": "code", "execution_count": 68, "metadata": {}, "outputs": [], "source": [ "np.fill_diagonal(norm_conf_mx, 0)\n", "plt.matshow(norm_conf_mx, cmap=plt.cm.gray)\n", 
"save_fig(\"confusion_matrix_errors_plot\", tight_layout=False)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 69, "metadata": {}, "outputs": [], "source": [ "cl_a, cl_b = 3, 5\n", "X_aa = X_train[(y_train == cl_a) & (y_train_pred == cl_a)]\n", "X_ab = X_train[(y_train == cl_a) & (y_train_pred == cl_b)]\n", "X_ba = X_train[(y_train == cl_b) & (y_train_pred == cl_a)]\n", "X_bb = X_train[(y_train == cl_b) & (y_train_pred == cl_b)]\n", "\n", "plt.figure(figsize=(8,8))\n", "plt.subplot(221); plot_digits(X_aa[:25], images_per_row=5)\n", "plt.subplot(222); plot_digits(X_ab[:25], images_per_row=5)\n", "plt.subplot(223); plot_digits(X_ba[:25], images_per_row=5)\n", "plt.subplot(224); plot_digits(X_bb[:25], images_per_row=5)\n", "save_fig(\"error_analysis_digits_plot\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Multilabel classification" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [], "source": [ "from sklearn.neighbors import KNeighborsClassifier\n", "\n", "y_train_large = (y_train >= 7)\n", "y_train_odd = (y_train % 2 == 1)\n", "y_multilabel = np.c_[y_train_large, y_train_odd]\n", "\n", "knn_clf = KNeighborsClassifier()\n", "knn_clf.fit(X_train, y_multilabel)" ] }, { "cell_type": "code", "execution_count": 71, "metadata": {}, "outputs": [], "source": [ "knn_clf.predict([some_digit])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Warning**: the following cell may take a very long time (possibly hours depending on your hardware)." 
] }, { "cell_type": "code", "execution_count": 72, "metadata": {}, "outputs": [], "source": [ "y_train_knn_pred = cross_val_predict(knn_clf, X_train, y_multilabel, cv=3)\n", "f1_score(y_multilabel, y_train_knn_pred, average=\"macro\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Multioutput classification" ] }, { "cell_type": "code", "execution_count": 73, "metadata": { "collapsed": true }, "outputs": [], "source": [ "noise = np.random.randint(0, 100, (len(X_train), 784))\n", "X_train_mod = X_train + noise\n", "noise = np.random.randint(0, 100, (len(X_test), 784))\n", "X_test_mod = X_test + noise\n", "y_train_mod = X_train\n", "y_test_mod = X_test" ] }, { "cell_type": "code", "execution_count": 74, "metadata": {}, "outputs": [], "source": [ "some_index = 5500\n", "plt.subplot(121); plot_digit(X_test_mod[some_index])\n", "plt.subplot(122); plot_digit(y_test_mod[some_index])\n", "save_fig(\"noisy_digit_example_plot\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 75, "metadata": {}, "outputs": [], "source": [ "knn_clf.fit(X_train_mod, y_train_mod)\n", "clean_digit = knn_clf.predict([X_test_mod[some_index]])\n", "plot_digit(clean_digit)\n", "save_fig(\"cleaned_digit_example_plot\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Extra material" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dummy (ie. 
random) classifier" ] }, { "cell_type": "code", "execution_count": 76, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.dummy import DummyClassifier\n", "# strategy is explicit because its default changed from \"stratified\" (random guesses,\n", "# as this section intends) to \"prior\" (a constant score) in Scikit-Learn 0.24\n", "dmy_clf = DummyClassifier(strategy=\"stratified\")\n", "y_probas_dmy = cross_val_predict(dmy_clf, X_train, y_train_5, cv=3, method=\"predict_proba\")\n", "y_scores_dmy = y_probas_dmy[:, 1]" ] }, { "cell_type": "code", "execution_count": 77, "metadata": { "scrolled": true }, "outputs": [], "source": [ "fprr, tprr, thresholdsr = roc_curve(y_train_5, y_scores_dmy)\n", "plot_roc_curve(fprr, tprr)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## KNN classifier" ] }, { "cell_type": "code", "execution_count": 78, "metadata": {}, "outputs": [], "source": [ "from sklearn.neighbors import KNeighborsClassifier\n", "knn_clf = KNeighborsClassifier(n_jobs=-1, weights='distance', n_neighbors=4)\n", "knn_clf.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 79, "metadata": { "collapsed": true }, "outputs": [], "source": [ "y_knn_pred = knn_clf.predict(X_test)" ] }, { "cell_type": "code", "execution_count": 80, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import accuracy_score\n", "accuracy_score(y_test, y_knn_pred)" ] }, { "cell_type": "code", "execution_count": 81, "metadata": {}, "outputs": [], "source": [ "from scipy.ndimage import shift  # scipy.ndimage.interpolation is deprecated\n", "def shift_digit(digit_array, dx, dy, new=0):\n", "    \"\"\"Shift a flattened 28x28 digit by dx columns and dy rows, filling the\n", "    uncovered pixels with `new`, and return the result flattened to 784 values.\"\"\"\n", "    return shift(digit_array.reshape(28, 28), [dy, dx], cval=new).reshape(784)\n", "\n", "plot_digit(shift_digit(some_digit, 5, 1, new=100))" ] }, { "cell_type": "code", "execution_count": 82, "metadata": {}, "outputs": [], "source": [ "X_train_expanded = [X_train]\n", "y_train_expanded = [y_train]\n", "for dx, dy in ((1, 0), (-1, 0), (0, 1), (0, -1)):\n", "    shifted_images = np.apply_along_axis(shift_digit, axis=1, arr=X_train, dx=dx, dy=dy)\n", "    X_train_expanded.append(shifted_images)\n", "    y_train_expanded.append(y_train)\n", "\n", "X_train_expanded = 
np.concatenate(X_train_expanded)\n", "y_train_expanded = np.concatenate(y_train_expanded)\n", "X_train_expanded.shape, y_train_expanded.shape" ] }, { "cell_type": "code", "execution_count": 83, "metadata": {}, "outputs": [], "source": [ "knn_clf.fit(X_train_expanded, y_train_expanded)" ] }, { "cell_type": "code", "execution_count": 84, "metadata": { "collapsed": true }, "outputs": [], "source": [ "y_knn_expanded_pred = knn_clf.predict(X_test)" ] }, { "cell_type": "code", "execution_count": 85, "metadata": {}, "outputs": [], "source": [ "accuracy_score(y_test, y_knn_expanded_pred)" ] }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [], "source": [ "ambiguous_digit = X_test[2589]\n", "knn_clf.predict_proba([ambiguous_digit])" ] }, { "cell_type": "code", "execution_count": 87, "metadata": {}, "outputs": [], "source": [ "plot_digit(ambiguous_digit)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "# Exercise solutions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Coming soon**" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" }, "nav_menu": {}, "toc": { "navigate_menu": true, "number_sections": true, "sideBar": true, "threshold": 6, "toc_cell": false, "toc_section_display": "block", "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 1 }