1266 lines
28 KiB
Plaintext
1266 lines
28 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Chapter 3 – Classification**\n",
|
||
"\n",
|
||
"_This notebook contains all the sample code and solutions to the exercices in chapter 3._"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Setup"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"First, let's make sure this notebook works well in both python 2 and 3, import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# To support both python 2 and python 3\n",
|
||
"from __future__ import division, print_function, unicode_literals\n",
|
||
"\n",
|
||
"# Common imports\n",
|
||
"import numpy as np\n",
|
||
"import numpy.random as rnd\n",
|
||
"import os\n",
|
||
"\n",
|
||
"# to make this notebook's output stable across runs\n",
|
||
"rnd.seed(42)\n",
|
||
"\n",
|
||
"# To plot pretty figures\n",
|
||
"%matplotlib inline\n",
|
||
"import matplotlib\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"plt.rcParams['axes.labelsize'] = 14\n",
|
||
"plt.rcParams['xtick.labelsize'] = 12\n",
|
||
"plt.rcParams['ytick.labelsize'] = 12\n",
|
||
"\n",
|
||
"# Where to save the figures\n",
|
||
"PROJECT_ROOT_DIR = \".\"\n",
|
||
"CHAPTER_ID = \"classification\"\n",
|
||
"\n",
|
||
"def save_fig(fig_id, tight_layout=True):\n",
|
||
" path = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID, fig_id + \".png\")\n",
|
||
" print(\"Saving figure\", fig_id)\n",
|
||
" if tight_layout:\n",
|
||
" plt.tight_layout()\n",
|
||
" plt.savefig(path, format='png', dpi=300)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# MNIST"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.datasets import fetch_mldata\n",
|
||
"mnist = fetch_mldata('MNIST original')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"mnist"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"X, y = mnist[\"data\"], mnist[\"target\"]\n",
|
||
"X.shape"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"y.shape"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"28*28"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 7,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"def plot_digit(data):\n",
|
||
" image = data.reshape(28, 28)\n",
|
||
" plt.imshow(image, cmap = matplotlib.cm.binary,\n",
|
||
" interpolation=\"nearest\")\n",
|
||
" plt.axis(\"off\")\n",
|
||
"\n",
|
||
"some_digit_index = 36000\n",
|
||
"some_digit = X[some_digit_index]\n",
|
||
"plot_digit(some_digit)\n",
|
||
"save_fig(\"some_digit_plot\")\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 8,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# EXTRA\n",
|
||
"def plot_digits(instances, images_per_row=10, **options):\n",
|
||
" size = 28\n",
|
||
" images_per_row = min(len(instances), images_per_row)\n",
|
||
" images = [instance.reshape(size,size) for instance in instances]\n",
|
||
" n_rows = (len(instances) - 1) // images_per_row + 1\n",
|
||
" row_images = []\n",
|
||
" n_empty = n_rows * images_per_row - len(instances)\n",
|
||
" images.append(np.zeros((size, size * n_empty)))\n",
|
||
" for row in range(n_rows):\n",
|
||
" rimages = images[row * images_per_row : (row + 1) * images_per_row]\n",
|
||
" row_images.append(np.concatenate(rimages, axis=1))\n",
|
||
" image = np.concatenate(row_images, axis=0)\n",
|
||
" plt.imshow(image, cmap = matplotlib.cm.binary, **options)\n",
|
||
" plt.axis(\"off\")\n",
|
||
"\n",
|
||
"plt.figure(figsize=(9,9))\n",
|
||
"example_images = np.r_[X[:12000:600], X[13000:30600:600], X[30600:60000:590]]\n",
|
||
"plot_digits(example_images, images_per_row=10)\n",
|
||
"save_fig(\"more_digits_plot\")\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"y[some_digit_index]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 11,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"shuffle_index = rnd.permutation(60000)\n",
|
||
"X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Binary classifier"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 12,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"y_train_5 = (y_train == 5)\n",
|
||
"y_test_5 = (y_test == 5)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 13,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.linear_model import SGDClassifier\n",
|
||
"\n",
|
||
"sgd_clf = SGDClassifier(random_state=42)\n",
|
||
"sgd_clf.fit(X_train, y_train_5)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 14,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"sgd_clf.predict([some_digit])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 15,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.cross_validation import cross_val_score\n",
|
||
"cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring=\"accuracy\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 16,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.cross_validation import StratifiedKFold\n",
|
||
"from sklearn.base import clone\n",
|
||
"\n",
|
||
"skfolds = StratifiedKFold(y_train_5, n_folds=3, random_state=42)\n",
|
||
"\n",
|
||
"for train_index, test_index in skfolds:\n",
|
||
" clone_clf = clone(sgd_clf)\n",
|
||
" X_train_folds = X_train[train_index]\n",
|
||
" y_train_folds = (y_train_5[train_index])\n",
|
||
" X_test_fold = X_train[test_index]\n",
|
||
" y_test_fold = (y_train_5[test_index])\n",
|
||
" \n",
|
||
" clone_clf.fit(X_train_folds, y_train_folds)\n",
|
||
" y_pred = clone_clf.predict(X_test_fold)\n",
|
||
" n_correct = sum(y_pred == y_test_fold)\n",
|
||
" print(n_correct / len(y_pred))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 17,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.base import BaseEstimator\n",
|
||
"class Never5Classifier(BaseEstimator):\n",
|
||
" def fit(self, X, y=None):\n",
|
||
" pass\n",
|
||
" def predict(self, X):\n",
|
||
" return np.zeros((len(X), 1), dtype=bool)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 18,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"never_5_clf = Never5Classifier()\n",
|
||
"cross_val_score(never_5_clf, X_train, y_train_5, cv=3, scoring=\"accuracy\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 19,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.cross_validation import cross_val_predict\n",
|
||
"\n",
|
||
"y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 20,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.metrics import confusion_matrix\n",
|
||
"\n",
|
||
"confusion_matrix(y_train_5, y_train_pred)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 21,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.metrics import precision_score, recall_score\n",
|
||
"\n",
|
||
"precision_score(y_train_5, y_train_pred)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 22,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"4344 / (4344 + 1307)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 23,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"recall_score(y_train_5, y_train_pred)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 24,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"4344 / (4344 + 1077)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 25,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.metrics import f1_score\n",
|
||
"f1_score(y_train_5, y_train_pred)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 26,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"4344 / (4344 + (1077 + 1307)/2)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 27,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"y_scores = sgd_clf.decision_function([some_digit])\n",
|
||
"y_scores"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 28,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"threshold = 0\n",
|
||
"y_some_digit_pred = (y_scores > threshold)\n",
|
||
"y_some_digit_pred"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 29,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"threshold = 200000\n",
|
||
"y_some_digit_pred = (y_scores > threshold)\n",
|
||
"y_some_digit_pred"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 30,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Implemented in https://github.com/scikit-learn/scikit-learn/pull/6671\n",
|
||
"# Pushed to master but not yet in pip module.\n",
|
||
"from sklearn.cross_validation import StratifiedKFold\n",
|
||
"from sklearn.base import clone\n",
|
||
"\n",
|
||
"def cross_val_predict_future(clf, X, y, cv, method=None):\n",
|
||
" clf_clone = clone(clf) # keep original intact\n",
|
||
" if method is None:\n",
|
||
" return cross_val_predict(clf, X, y, cv=cv)\n",
|
||
" else:\n",
|
||
" method_f = getattr(clf_clone, method)\n",
|
||
" scores = []\n",
|
||
" skfolds = StratifiedKFold(y, n_folds=cv)\n",
|
||
" for train_indices, test_indices in skfolds:\n",
|
||
" clf_clone.fit(X[train_indices], y[train_indices])\n",
|
||
" scores.append((method_f(X[test_indices]), test_indices))\n",
|
||
" res_shape = list(scores[0][0].shape)\n",
|
||
" res_shape[0] = len(X)\n",
|
||
" res = np.empty(tuple(res_shape))\n",
|
||
" for sc, test_indices in scores:\n",
|
||
" res[test_indices] = sc\n",
|
||
" return res"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 31,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"y_scores = cross_val_predict_future(sgd_clf, X_train, y_train_5, cv=3, method=\"decision_function\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 32,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.metrics import precision_recall_curve\n",
|
||
"\n",
|
||
"precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 33,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):\n",
|
||
" plt.plot(thresholds, precisions[:-1], \"b--\", label=\"Precision\", linewidth=2)\n",
|
||
" plt.plot(thresholds, recalls[:-1], \"g-\", label=\"Recall\", linewidth=2)\n",
|
||
" plt.xlabel(\"Threshold\", fontsize=16)\n",
|
||
" plt.legend(loc=\"center left\", fontsize=16)\n",
|
||
" plt.ylim([0, 1])\n",
|
||
"\n",
|
||
"plt.figure(figsize=(8, 4))\n",
|
||
"plot_precision_recall_vs_threshold(precisions, recalls, thresholds)\n",
|
||
"plt.xlim([-700000, 700000])\n",
|
||
"save_fig(\"precision_recall_vs_threshold_plot\")\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 34,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"(y_train_pred == (y_scores > 0)).all()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 35,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"y_train_pred_90 = (y_scores > 70000)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 36,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"precision_score(y_train_5, y_train_pred_90)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 37,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"recall_score(y_train_5, y_train_pred_90)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 38,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"def plot_precision_vs_recall(precisions, recalls):\n",
|
||
" plt.plot(recalls, precisions, \"b-\", linewidth=2)\n",
|
||
" plt.xlabel(\"Recall\", fontsize=16)\n",
|
||
" plt.ylabel(\"Precision\", fontsize=16)\n",
|
||
" plt.axis([0, 1, 0, 1])\n",
|
||
"\n",
|
||
"plt.figure(figsize=(8, 6))\n",
|
||
"plot_precision_vs_recall(precisions, recalls)\n",
|
||
"save_fig(\"precision_vs_recall_plot\")\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# ROC curves"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 39,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.metrics import roc_curve\n",
|
||
"\n",
|
||
"fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 40,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"def plot_roc_curve(fpr, tpr, **options):\n",
|
||
" plt.plot(fpr, tpr, linewidth=2, **options)\n",
|
||
" plt.plot([0, 1], [0, 1], 'k--')\n",
|
||
" plt.axis([0, 1, 0, 1])\n",
|
||
" plt.xlabel('False Positive Rate', fontsize=16)\n",
|
||
" plt.ylabel('True Positive Rate', fontsize=16)\n",
|
||
"\n",
|
||
"plt.figure(figsize=(8, 6))\n",
|
||
"plot_roc_curve(fpr, tpr)\n",
|
||
"save_fig(\"roc_curve_plot\")\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 41,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.metrics import roc_auc_score\n",
|
||
"\n",
|
||
"roc_auc_score(y_train_5, y_scores)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 42,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.ensemble import RandomForestClassifier\n",
|
||
"forest_clf = RandomForestClassifier(random_state=42)\n",
|
||
"y_probas_forest = cross_val_predict_future(forest_clf, X_train, y_train_5, cv=3, method=\"predict_proba\")\n",
|
||
"y_scores_forest = y_probas_forest[:, 1] # score = proba of positive class\n",
|
||
"fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5, y_scores_forest)\n",
|
||
"\n",
|
||
"plt.figure(figsize=(8, 6))\n",
|
||
"plt.plot(fpr, tpr, \"b:\", linewidth=2, label=\"SGD\")\n",
|
||
"plot_roc_curve(fpr_forest, tpr_forest, label=\"Random Forest\")\n",
|
||
"plt.legend(loc=\"lower right\", fontsize=16)\n",
|
||
"save_fig(\"roc_curve_comparison_plot\")\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 43,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"roc_auc_score(y_train_5, y_scores_forest)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 44,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"y_train_pred_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3)\n",
|
||
"precision_score(y_train_5, y_train_pred_forest)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 45,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"recall_score(y_train_5, y_train_pred_forest)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Multiclass classification"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 46,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"sgd_clf.fit(X_train, y_train)\n",
|
||
"sgd_clf.predict([some_digit])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 47,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"some_digit_scores = sgd_clf.decision_function([some_digit])\n",
|
||
"some_digit_scores"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 48,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"np.argmax(some_digit_scores)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 49,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"sgd_clf.classes_"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 50,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.multiclass import OneVsOneClassifier\n",
|
||
"ovo_clf = OneVsOneClassifier(SGDClassifier(random_state=42))\n",
|
||
"ovo_clf.fit(X_train, y_train)\n",
|
||
"ovo_clf.predict([some_digit])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 51,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"len(ovo_clf.estimators_)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 52,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"forest_clf.fit(X_train, y_train)\n",
|
||
"forest_clf.predict([some_digit])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 53,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"forest_clf.predict_proba([some_digit])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 54,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"cross_val_score(sgd_clf, X_train, y_train, cv=3, scoring=\"accuracy\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 55,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.preprocessing import StandardScaler\n",
|
||
"scaler = StandardScaler()\n",
|
||
"X_train_scaled = scaler.fit_transform(X_train.astype(np.float64))\n",
|
||
"cross_val_score(sgd_clf, X_train_scaled, y_train, cv=3, scoring=\"accuracy\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 56,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)\n",
|
||
"conf_mx = confusion_matrix(y_train, y_train_pred)\n",
|
||
"conf_mx"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 57,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"def plot_confusion_matrix(matrix):\n",
|
||
" \"\"\"If you prefer color and a colorbar\"\"\"\n",
|
||
" fig = plt.figure(figsize=(8,8))\n",
|
||
" ax = fig.add_subplot(111)\n",
|
||
" cax = ax.matshow(conf_mx)\n",
|
||
" fig.colorbar(cax)\n",
|
||
"\n",
|
||
"plt.matshow(conf_mx, cmap=plt.cm.gray)\n",
|
||
"save_fig(\"confusion_matrix_plot\", tight_layout=False)\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 58,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"row_sums = conf_mx.sum(axis=1, keepdims=True)\n",
|
||
"norm_conf_mx = conf_mx / row_sums\n",
|
||
"np.fill_diagonal(norm_conf_mx, 0)\n",
|
||
"plt.matshow(norm_conf_mx, cmap=plt.cm.gray)\n",
|
||
"save_fig(\"confusion_matrix_errors_plot\", tight_layout=False)\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 59,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"cl_a, cl_b = 3, 5\n",
|
||
"X_aa = X_train[(y_train == cl_a) & (y_train_pred == cl_a)]\n",
|
||
"X_ab = X_train[(y_train == cl_a) & (y_train_pred == cl_b)]\n",
|
||
"X_ba = X_train[(y_train == cl_b) & (y_train_pred == cl_a)]\n",
|
||
"X_bb = X_train[(y_train == cl_b) & (y_train_pred == cl_b)]\n",
|
||
"\n",
|
||
"plt.figure(figsize=(8,8))\n",
|
||
"plt.subplot(221)\n",
|
||
"plot_digits(X_aa[:25], images_per_row=5)\n",
|
||
"plt.subplot(222)\n",
|
||
"plot_digits(X_ab[:25], images_per_row=5)\n",
|
||
"plt.subplot(223)\n",
|
||
"plot_digits(X_ba[:25], images_per_row=5)\n",
|
||
"plt.subplot(224)\n",
|
||
"plot_digits(X_bb[:25], images_per_row=5)\n",
|
||
"save_fig(\"error_analysis_digits_plot\")\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Multilabel classification"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 60,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.neighbors import KNeighborsClassifier\n",
|
||
"\n",
|
||
"y_train_large = (y_train >= 7)\n",
|
||
"y_train_odd = (y_train % 2 == 1)\n",
|
||
"y_multilabel = np.c_[y_train_large, y_train_odd]\n",
|
||
"\n",
|
||
"knn_clf = KNeighborsClassifier()\n",
|
||
"knn_clf.fit(X_train, y_multilabel)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 61,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"knn_clf.predict([some_digit])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 62,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"y_train_knn_pred = cross_val_predict(knn_clf, X_train, y_train, cv=3)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 63,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"f1_score(y_train, y_train_knn_pred, average=\"macro\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Multioutput classification"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 64,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"noise = rnd.randint(0, 100, (len(X_train), 784))\n",
|
||
"X_train_mod = X_train + noise\n",
|
||
"noise = rnd.randint(0, 100, (len(X_test), 784))\n",
|
||
"X_test_mod = X_test + noise\n",
|
||
"y_train_mod = X_train\n",
|
||
"y_test_mod = X_test"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 65,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"some_index = 5500\n",
|
||
"plt.subplot(121); plot_digit(X_test_mod[some_index])\n",
|
||
"plt.subplot(122); plot_digit(y_test_mod[some_index])\n",
|
||
"save_fig(\"noisy_digit_example_plot\")\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 66,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"knn_clf.fit(X_train_mod, y_train_mod)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 67,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"clean_digit = knn_clf.predict([X_test_mod[some_index]])\n",
|
||
"plot_digit(clean_digit)\n",
|
||
"save_fig(\"cleaned_digit_example_plot\")\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Extra material"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Dummy (ie. random) classifier"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 68,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.dummy import DummyClassifier\n",
|
||
"dmy_clf = DummyClassifier()\n",
|
||
"y_probas_dmy = cross_val_predict_future(dmy_clf, X_train, y_train_5, cv=3, method=\"predict_proba\")\n",
|
||
"y_scores_dmy = y_probas_dmy[:, 1]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 69,
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"scrolled": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"fprr, tprr, thresholdsr = roc_curve(y_train_5, y_scores_dmy)\n",
|
||
"plot_roc_curve(fprr, tprr)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## KNN classifier"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 70,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.neighbors import KNeighborsClassifier\n",
|
||
"knn_clf = KNeighborsClassifier(n_jobs=-1, weights='distance', n_neighbors=4)\n",
|
||
"knn_clf.fit(X_train, y_train)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 71,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"y_knn_pred = knn_clf.predict(X_test)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 72,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.metrics import accuracy_score\n",
|
||
"accuracy_score(y_test, y_knn_pred)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 73,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from scipy.ndimage.interpolation import shift\n",
|
||
"def shift_digit(digit_array, dx, dy, new=0):\n",
|
||
" return shift(digit_array.reshape(28, 28), [dy, dx], cval=new).reshape(784)\n",
|
||
"\n",
|
||
"plot_digit(shift_digit(some_digit, 5, 1, new=100))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 74,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"X_train_expanded = [X_train]\n",
|
||
"y_train_expanded = [y_train]\n",
|
||
"for dx, dy in ((1, 0), (-1, 0), (0, 1), (0, -1)):\n",
|
||
" shifted_images = np.apply_along_axis(shift_digit, axis=1, arr=X_train, dx=dx, dy=dy)\n",
|
||
" X_train_expanded.append(shifted_images)\n",
|
||
" y_train_expanded.append(y_train)\n",
|
||
"\n",
|
||
"X_train_expanded = np.concatenate(X_train_expanded)\n",
|
||
"y_train_expanded = np.concatenate(y_train_expanded)\n",
|
||
"X_train_expanded.shape, y_train_expanded.shape"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 75,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"knn_clf.fit(X_train_expanded, y_train_expanded)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 76,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"y_knn_expanded_pred = knn_clf.predict(X_test)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 77,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"accuracy_score(y_test, y_knn_expanded_pred)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 78,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"ambiguous_digit = X_test[2589]\n",
|
||
"knn_clf.predict_proba([ambiguous_digit])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 79,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"plot_digit(ambiguous_digit)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"source": [
|
||
"# Exercise solutions"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Coming soon**"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.5.1"
|
||
},
|
||
"nav_menu": {},
|
||
"toc": {
|
||
"navigate_menu": true,
|
||
"number_sections": true,
|
||
"sideBar": true,
|
||
"threshold": 6,
|
||
"toc_cell": false,
|
||
"toc_section_display": "block",
|
||
"toc_window_display": false
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 0
|
||
}
|