handson-ml/classification.ipynb

1266 lines
28 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Chapter 3 – Classification**\n",
"\n",
"_This notebook contains all the sample code and solutions to the exercises in chapter 3._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, let's make sure this notebook works well in both Python 2 and 3, import a few common modules, ensure Matplotlib plots figures inline, and prepare a function to save the figures:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# To support both python 2 and python 3\n",
"from __future__ import division, print_function, unicode_literals\n",
"\n",
"# Common imports\n",
"import numpy as np\n",
"import numpy.random as rnd\n",
"import os\n",
"\n",
"# to make this notebook's output stable across runs\n",
"rnd.seed(42)\n",
"\n",
"# To plot pretty figures\n",
"%matplotlib inline\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"plt.rcParams['axes.labelsize'] = 14\n",
"plt.rcParams['xtick.labelsize'] = 12\n",
"plt.rcParams['ytick.labelsize'] = 12\n",
"\n",
"# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n",
"CHAPTER_ID = \"classification\"\n",
"\n",
"def save_fig(fig_id, tight_layout=True):\n",
"    \"\"\"Save the current Matplotlib figure as a 300 dpi PNG.\n",
"\n",
"    The file is written to PROJECT_ROOT_DIR/images/CHAPTER_ID/fig_id.png;\n",
"    the target directory is created first if it does not exist yet.\n",
"\n",
"    fig_id: file name (without extension) of the saved figure.\n",
"    tight_layout: if True, call plt.tight_layout() before saving.\n",
"    \"\"\"\n",
"    image_dir = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"    # plt.savefig() fails when the directory is missing (e.g. a fresh checkout).\n",
"    # os.makedirs(..., exist_ok=True) is Python 3 only, hence the explicit check.\n",
"    if not os.path.isdir(image_dir):\n",
"        os.makedirs(image_dir)\n",
"    path = os.path.join(image_dir, fig_id + \".png\")\n",
"    print(\"Saving figure\", fig_id)\n",
"    if tight_layout:\n",
"        plt.tight_layout()\n",
"    plt.savefig(path, format='png', dpi=300)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# MNIST"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.datasets import fetch_mldata\n",
"mnist = fetch_mldata('MNIST original')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"mnist"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"X, y = mnist[\"data\"], mnist[\"target\"]\n",
"X.shape"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"y.shape"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"28*28"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def plot_digit(data):\n",
"    \"\"\"Show a flat 784-value array as a 28x28 image with the axes hidden.\"\"\"\n",
"    plt.imshow(\n",
"        data.reshape(28, 28),\n",
"        cmap=matplotlib.cm.binary,\n",
"        interpolation=\"nearest\",\n",
"    )\n",
"    plt.axis(\"off\")\n",
"\n",
"some_digit_index = 36000\n",
"some_digit = X[some_digit_index]\n",
"plot_digit(some_digit)\n",
"save_fig(\"some_digit_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# EXTRA\n",
"def plot_digits(instances, images_per_row=10, **options):\n",
"    \"\"\"Tile 28x28 digit images into a single grid image and display it.\n",
"\n",
"    instances: sequence of arrays, each reshapeable to (28, 28).\n",
"    images_per_row: maximum number of digits placed on one grid row.\n",
"    options: extra keyword arguments forwarded to plt.imshow().\n",
"    \"\"\"\n",
"    side = 28\n",
"    n_images = len(instances)\n",
"    per_row = min(n_images, images_per_row)\n",
"    tiles = [digit.reshape(side, side) for digit in instances]\n",
"    n_rows = (n_images - 1) // per_row + 1\n",
"    # A single blank tile, wide enough to fill what is left of the last row.\n",
"    n_blank = n_rows * per_row - n_images\n",
"    tiles.append(np.zeros((side, side * n_blank)))\n",
"    grid_rows = [\n",
"        np.concatenate(tiles[row * per_row:(row + 1) * per_row], axis=1)\n",
"        for row in range(n_rows)\n",
"    ]\n",
"    grid = np.concatenate(grid_rows, axis=0)\n",
"    plt.imshow(grid, cmap=matplotlib.cm.binary, **options)\n",
"    plt.axis(\"off\")\n",
"\n",
"plt.figure(figsize=(9,9))\n",
"example_images = np.r_[X[:12000:600], X[13000:30600:600], X[30600:60000:590]]\n",
"plot_digits(example_images, images_per_row=10)\n",
"save_fig(\"more_digits_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"y[some_digit_index]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"shuffle_index = rnd.permutation(60000)\n",
"X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Binary classifier"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"y_train_5 = (y_train == 5)\n",
"y_test_5 = (y_test == 5)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn.linear_model import SGDClassifier\n",
"\n",
"sgd_clf = SGDClassifier(random_state=42)\n",
"sgd_clf.fit(X_train, y_train_5)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"sgd_clf.predict([some_digit])"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn.cross_validation import cross_val_score\n",
"cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring=\"accuracy\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn.cross_validation import StratifiedKFold\n",
"from sklearn.base import clone\n",
"\n",
"skfolds = StratifiedKFold(y_train_5, n_folds=3, random_state=42)\n",
"\n",
"for train_index, test_index in skfolds:\n",
" clone_clf = clone(sgd_clf)\n",
" X_train_folds = X_train[train_index]\n",
" y_train_folds = (y_train_5[train_index])\n",
" X_test_fold = X_train[test_index]\n",
" y_test_fold = (y_train_5[test_index])\n",
" \n",
" clone_clf.fit(X_train_folds, y_train_folds)\n",
" y_pred = clone_clf.predict(X_test_fold)\n",
" n_correct = sum(y_pred == y_test_fold)\n",
" print(n_correct / len(y_pred))"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.base import BaseEstimator\n",
"class Never5Classifier(BaseEstimator):\n",
"    \"\"\"Baseline classifier that answers False (not a 5) for every instance.\"\"\"\n",
"    def fit(self, X, y=None):\n",
"        # Nothing to learn; return self, as scikit-learn estimators are expected to.\n",
"        return self\n",
"    def predict(self, X):\n",
"        # One False per instance, as a column of shape (len(X), 1).\n",
"        return np.zeros((len(X), 1), dtype=bool)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"never_5_clf = Never5Classifier()\n",
"cross_val_score(never_5_clf, X_train, y_train_5, cv=3, scoring=\"accuracy\")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn.cross_validation import cross_val_predict\n",
"\n",
"y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn.metrics import confusion_matrix\n",
"\n",
"confusion_matrix(y_train_5, y_train_pred)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn.metrics import precision_score, recall_score\n",
"\n",
"precision_score(y_train_5, y_train_pred)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"4344 / (4344 + 1307)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"recall_score(y_train_5, y_train_pred)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"4344 / (4344 + 1077)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn.metrics import f1_score\n",
"f1_score(y_train_5, y_train_pred)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"4344 / (4344 + (1077 + 1307)/2)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"y_scores = sgd_clf.decision_function([some_digit])\n",
"y_scores"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"threshold = 0\n",
"y_some_digit_pred = (y_scores > threshold)\n",
"y_some_digit_pred"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"threshold = 200000\n",
"y_some_digit_pred = (y_scores > threshold)\n",
"y_some_digit_pred"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Implemented in https://github.com/scikit-learn/scikit-learn/pull/6671\n",
"# Pushed to master but not yet in pip module.\n",
"from sklearn.cross_validation import StratifiedKFold\n",
"from sklearn.base import clone\n",
"\n",
"def cross_val_predict_future(clf, X, y, cv, method=None):\n",
"    \"\"\"Collect out-of-fold results of the estimator method named by method.\n",
"\n",
"    With method=None this delegates to cross_val_predict(). Otherwise a single\n",
"    clone of clf is re-fitted on each of the cv stratified training folds, the\n",
"    named method is called on the matching test fold, and the per-fold results\n",
"    are written back at their original row positions in the returned array.\n",
"    \"\"\"\n",
"    estimator = clone(clf) # keep original intact\n",
"    if method is None:\n",
"        return cross_val_predict(clf, X, y, cv=cv)\n",
"    predict_fn = getattr(estimator, method)\n",
"    fold_results = []\n",
"    for train_idx, test_idx in StratifiedKFold(y, n_folds=cv):\n",
"        estimator.fit(X[train_idx], y[train_idx])\n",
"        fold_results.append((test_idx, predict_fn(X[test_idx])))\n",
"    # Same trailing dimensions as one fold's output, but one row per row of X.\n",
"    out_shape = list(fold_results[0][1].shape)\n",
"    out_shape[0] = len(X)\n",
"    out = np.empty(tuple(out_shape))\n",
"    for test_idx, fold_out in fold_results:\n",
"        out[test_idx] = fold_out\n",
"    return out"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"y_scores = cross_val_predict_future(sgd_clf, X_train, y_train_5, cv=3, method=\"decision_function\")"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.metrics import precision_recall_curve\n",
"\n",
"precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):\n",
"    \"\"\"Plot precision (dashed blue) and recall (solid green) against the threshold.\"\"\"\n",
"    curves = (\n",
"        (precisions, \"b--\", \"Precision\"),\n",
"        (recalls, \"g-\", \"Recall\"),\n",
"    )\n",
"    for values, fmt, label in curves:\n",
"        # The last value is dropped so each curve lines up with thresholds.\n",
"        plt.plot(thresholds, values[:-1], fmt, label=label, linewidth=2)\n",
"    plt.xlabel(\"Threshold\", fontsize=16)\n",
"    plt.legend(loc=\"center left\", fontsize=16)\n",
"    plt.ylim([0, 1])\n",
"\n",
"plt.figure(figsize=(8, 4))\n",
"plot_precision_recall_vs_threshold(precisions, recalls, thresholds)\n",
"plt.xlim([-700000, 700000])\n",
"save_fig(\"precision_recall_vs_threshold_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"(y_train_pred == (y_scores > 0)).all()"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"y_train_pred_90 = (y_scores > 70000)"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"precision_score(y_train_5, y_train_pred_90)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"recall_score(y_train_5, y_train_pred_90)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def plot_precision_vs_recall(precisions, recalls):\n",
"    \"\"\"Plot precision (y axis) against recall (x axis), both limited to [0, 1].\"\"\"\n",
"    label_size = 16\n",
"    plt.plot(recalls, precisions, \"b-\", linewidth=2)\n",
"    plt.xlabel(\"Recall\", fontsize=label_size)\n",
"    plt.ylabel(\"Precision\", fontsize=label_size)\n",
"    unit_square = [0, 1, 0, 1]  # xmin, xmax, ymin, ymax\n",
"    plt.axis(unit_square)\n",
"\n",
"plt.figure(figsize=(8, 6))\n",
"plot_precision_vs_recall(precisions, recalls)\n",
"save_fig(\"precision_vs_recall_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ROC curves"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn.metrics import roc_curve\n",
"\n",
"fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def plot_roc_curve(fpr, tpr, **options):\n",
"    \"\"\"Plot tpr against fpr plus a dashed diagonal from (0, 0) to (1, 1).\n",
"\n",
"    options: extra keyword arguments forwarded to plt.plot() for the curve.\n",
"    \"\"\"\n",
"    label_size = 16\n",
"    diagonal = [0, 1]\n",
"    plt.plot(fpr, tpr, linewidth=2, **options)\n",
"    plt.plot(diagonal, diagonal, 'k--')\n",
"    plt.axis([0, 1, 0, 1])\n",
"    plt.xlabel('False Positive Rate', fontsize=label_size)\n",
"    plt.ylabel('True Positive Rate', fontsize=label_size)\n",
"\n",
"plt.figure(figsize=(8, 6))\n",
"plot_roc_curve(fpr, tpr)\n",
"save_fig(\"roc_curve_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn.metrics import roc_auc_score\n",
"\n",
"roc_auc_score(y_train_5, y_scores)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"forest_clf = RandomForestClassifier(random_state=42)\n",
"y_probas_forest = cross_val_predict_future(forest_clf, X_train, y_train_5, cv=3, method=\"predict_proba\")\n",
"y_scores_forest = y_probas_forest[:, 1] # score = proba of positive class\n",
"fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5, y_scores_forest)\n",
"\n",
"plt.figure(figsize=(8, 6))\n",
"plt.plot(fpr, tpr, \"b:\", linewidth=2, label=\"SGD\")\n",
"plot_roc_curve(fpr_forest, tpr_forest, label=\"Random Forest\")\n",
"plt.legend(loc=\"lower right\", fontsize=16)\n",
"save_fig(\"roc_curve_comparison_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"roc_auc_score(y_train_5, y_scores_forest)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"y_train_pred_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3)\n",
"precision_score(y_train_5, y_train_pred_forest)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"recall_score(y_train_5, y_train_pred_forest)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Multiclass classification"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"sgd_clf.fit(X_train, y_train)\n",
"sgd_clf.predict([some_digit])"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"some_digit_scores = sgd_clf.decision_function([some_digit])\n",
"some_digit_scores"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"np.argmax(some_digit_scores)"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"sgd_clf.classes_"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn.multiclass import OneVsOneClassifier\n",
"ovo_clf = OneVsOneClassifier(SGDClassifier(random_state=42))\n",
"ovo_clf.fit(X_train, y_train)\n",
"ovo_clf.predict([some_digit])"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"len(ovo_clf.estimators_)"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"forest_clf.fit(X_train, y_train)\n",
"forest_clf.predict([some_digit])"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"forest_clf.predict_proba([some_digit])"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"cross_val_score(sgd_clf, X_train, y_train, cv=3, scoring=\"accuracy\")"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn.preprocessing import StandardScaler\n",
"scaler = StandardScaler()\n",
"X_train_scaled = scaler.fit_transform(X_train.astype(np.float64))\n",
"cross_val_score(sgd_clf, X_train_scaled, y_train, cv=3, scoring=\"accuracy\")"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)\n",
"conf_mx = confusion_matrix(y_train, y_train_pred)\n",
"conf_mx"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def plot_confusion_matrix(matrix):\n",
"    \"\"\"If you prefer color and a colorbar: draw matrix on a new 8x8 figure.\"\"\"\n",
"    fig = plt.figure(figsize=(8,8))\n",
"    ax = fig.add_subplot(111)\n",
"    # Draw the matrix that was passed in, not the global conf_mx.\n",
"    cax = ax.matshow(matrix)\n",
"    fig.colorbar(cax)\n",
"\n",
"plt.matshow(conf_mx, cmap=plt.cm.gray)\n",
"save_fig(\"confusion_matrix_plot\", tight_layout=False)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"row_sums = conf_mx.sum(axis=1, keepdims=True)\n",
"norm_conf_mx = conf_mx / row_sums\n",
"np.fill_diagonal(norm_conf_mx, 0)\n",
"plt.matshow(norm_conf_mx, cmap=plt.cm.gray)\n",
"save_fig(\"confusion_matrix_errors_plot\", tight_layout=False)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"cl_a, cl_b = 3, 5\n",
"X_aa = X_train[(y_train == cl_a) & (y_train_pred == cl_a)]\n",
"X_ab = X_train[(y_train == cl_a) & (y_train_pred == cl_b)]\n",
"X_ba = X_train[(y_train == cl_b) & (y_train_pred == cl_a)]\n",
"X_bb = X_train[(y_train == cl_b) & (y_train_pred == cl_b)]\n",
"\n",
"plt.figure(figsize=(8,8))\n",
"plt.subplot(221)\n",
"plot_digits(X_aa[:25], images_per_row=5)\n",
"plt.subplot(222)\n",
"plot_digits(X_ab[:25], images_per_row=5)\n",
"plt.subplot(223)\n",
"plot_digits(X_ba[:25], images_per_row=5)\n",
"plt.subplot(224)\n",
"plot_digits(X_bb[:25], images_per_row=5)\n",
"save_fig(\"error_analysis_digits_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Multilabel classification"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn.neighbors import KNeighborsClassifier\n",
"\n",
"y_train_large = (y_train >= 7)\n",
"y_train_odd = (y_train % 2 == 1)\n",
"y_multilabel = np.c_[y_train_large, y_train_odd]\n",
"\n",
"knn_clf = KNeighborsClassifier()\n",
"knn_clf.fit(X_train, y_multilabel)"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"knn_clf.predict([some_digit])"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Cross-validate against the multilabel targets the KNN classifier is trained on\n",
"y_train_knn_pred = cross_val_predict(knn_clf, X_train, y_multilabel, cv=3)"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Score against the same multilabel targets (macro = unweighted mean over labels)\n",
"f1_score(y_multilabel, y_train_knn_pred, average=\"macro\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Multioutput classification"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"noise = rnd.randint(0, 100, (len(X_train), 784))\n",
"X_train_mod = X_train + noise\n",
"noise = rnd.randint(0, 100, (len(X_test), 784))\n",
"X_test_mod = X_test + noise\n",
"y_train_mod = X_train\n",
"y_test_mod = X_test"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"some_index = 5500\n",
"plt.subplot(121); plot_digit(X_test_mod[some_index])\n",
"plt.subplot(122); plot_digit(y_test_mod[some_index])\n",
"save_fig(\"noisy_digit_example_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"knn_clf.fit(X_train_mod, y_train_mod)"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"clean_digit = knn_clf.predict([X_test_mod[some_index]])\n",
"plot_digit(clean_digit)\n",
"save_fig(\"cleaned_digit_example_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Extra material"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dummy (i.e. random) classifier"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn.dummy import DummyClassifier\n",
"dmy_clf = DummyClassifier()\n",
"y_probas_dmy = cross_val_predict_future(dmy_clf, X_train, y_train_5, cv=3, method=\"predict_proba\")\n",
"y_scores_dmy = y_probas_dmy[:, 1]"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [],
"source": [
"fprr, tprr, thresholdsr = roc_curve(y_train_5, y_scores_dmy)\n",
"plot_roc_curve(fprr, tprr)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## KNN classifier"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn.neighbors import KNeighborsClassifier\n",
"knn_clf = KNeighborsClassifier(n_jobs=-1, weights='distance', n_neighbors=4)\n",
"knn_clf.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"y_knn_pred = knn_clf.predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn.metrics import accuracy_score\n",
"accuracy_score(y_test, y_knn_pred)"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from scipy.ndimage.interpolation import shift\n",
"def shift_digit(digit_array, dx, dy, new=0):\n",
" return shift(digit_array.reshape(28, 28), [dy, dx], cval=new).reshape(784)\n",
"\n",
"plot_digit(shift_digit(some_digit, 5, 1, new=100))"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"X_train_expanded = [X_train]\n",
"y_train_expanded = [y_train]\n",
"for dx, dy in ((1, 0), (-1, 0), (0, 1), (0, -1)):\n",
" shifted_images = np.apply_along_axis(shift_digit, axis=1, arr=X_train, dx=dx, dy=dy)\n",
" X_train_expanded.append(shifted_images)\n",
" y_train_expanded.append(y_train)\n",
"\n",
"X_train_expanded = np.concatenate(X_train_expanded)\n",
"y_train_expanded = np.concatenate(y_train_expanded)\n",
"X_train_expanded.shape, y_train_expanded.shape"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"knn_clf.fit(X_train_expanded, y_train_expanded)"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"y_knn_expanded_pred = knn_clf.predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"accuracy_score(y_test, y_knn_expanded_pred)"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"ambiguous_digit = X_test[2589]\n",
"knn_clf.predict_proba([ambiguous_digit])"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"plot_digit(ambiguous_digit)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"# Exercise solutions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Coming soon**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.1"
},
"nav_menu": {},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 0
}