handson-ml/03_classification.ipynb

1219 lines
28 KiB
Plaintext
Raw Normal View History

2016-05-22 17:40:18 +02:00
{
"cells": [
{
"cell_type": "markdown",
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"source": [
2016-09-27 23:31:21 +02:00
"**Chapter 3 Classification**\n",
"\n",
2017-08-19 17:01:55 +02:00
"_This notebook contains all the sample code and solutions to the exercises in chapter 3._"
2016-09-27 23:31:21 +02:00
]
},
{
"cell_type": "markdown",
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-09-27 23:31:21 +02:00
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-09-27 23:31:21 +02:00
"source": [
"First, let's make sure this notebook works well in both Python 2 and 3, import a few common modules, ensure Matplotlib plots figures inline and prepare a function to save the figures:"
2016-05-22 17:40:18 +02:00
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# To support both python 2 and python 3\n",
"from __future__ import division, print_function, unicode_literals\n",
"\n",
"# Common imports\n",
"import numpy as np\n",
"import os\n",
"\n",
"# to make this notebook's output stable across runs\n",
"np.random.seed(42)\n",
"\n",
"# To plot pretty figures\n",
"%matplotlib inline\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"plt.rcParams['axes.labelsize'] = 14\n",
"plt.rcParams['xtick.labelsize'] = 12\n",
"plt.rcParams['ytick.labelsize'] = 12\n",
"\n",
"# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n",
"CHAPTER_ID = \"classification\"\n",
"\n",
"def save_fig(fig_id, tight_layout=True):\n",
"    \"\"\"Save the current figure to <PROJECT_ROOT_DIR>/images/<CHAPTER_ID>/<fig_id>.png at 300 dpi.\n",
"\n",
"    The target directory is created if it does not exist yet, so the notebook\n",
"    also runs from a fresh checkout.\n",
"    \"\"\"\n",
"    images_dir = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"    path = os.path.join(images_dir, fig_id + \".png\")\n",
"    print(\"Saving figure\", fig_id)\n",
"    if not os.path.isdir(images_dir):\n",
"        # makedirs(..., exist_ok=True) is Python 3 only; this form also works on Python 2\n",
"        os.makedirs(images_dir)\n",
"    if tight_layout:\n",
"        plt.tight_layout()\n",
"    plt.savefig(path, format='png', dpi=300)"
]
},
{
"cell_type": "markdown",
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"source": [
"# MNIST"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def sort_by_target(mnist):\n",
"    \"\"\"Sort the training set (first 60,000 rows) and the test set (last 10,000 rows)\n",
"    by target, in place. Ties keep their original order (stable sort), so the\n",
"    result is deterministic and both splits are grouped by digit, as in the old\n",
"    'MNIST original' dataset.\n",
"    \"\"\"\n",
"    reorder_train = np.argsort(mnist.target[:60000], kind='mergesort')\n",
"    reorder_test = np.argsort(mnist.target[60000:], kind='mergesort') + 60000\n",
"    mnist.data[:60000] = mnist.data[reorder_train]\n",
"    mnist.target[:60000] = mnist.target[reorder_train]\n",
"    mnist.data[60000:] = mnist.data[reorder_test]\n",
"    mnist.target[60000:] = mnist.target[reorder_test]\n",
"\n",
"try:\n",
"    # fetch_mldata() was removed in Scikit-Learn 0.22 (mldata.org is no longer available),\n",
"    # so prefer fetch_openml() (added in Scikit-Learn 0.20) when it exists\n",
"    from sklearn.datasets import fetch_openml\n",
"except ImportError:\n",
"    fetch_openml = None\n",
"\n",
"if fetch_openml is not None:\n",
"    try:\n",
"        mnist = fetch_openml('mnist_784', version=1, cache=True, as_frame=False)\n",
"    except TypeError:  # the as_frame parameter only exists since Scikit-Learn 0.22\n",
"        mnist = fetch_openml('mnist_784', version=1, cache=True)\n",
"    mnist.target = mnist.target.astype(np.int8)  # fetch_openml() returns targets as strings\n",
"    sort_by_target(mnist)  # fetch_openml() does not return the dataset sorted by target\n",
"else:\n",
"    from sklearn.datasets import fetch_mldata\n",
"    mnist = fetch_mldata('MNIST original')\n",
"mnist"
]
},
{
"cell_type": "code",
"execution_count": 3,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"X, y = mnist[\"data\"], mnist[\"target\"]\n",
"X.shape"
2016-05-22 17:40:18 +02:00
]
},
{
"cell_type": "code",
"execution_count": 4,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"y.shape"
2016-05-22 17:40:18 +02:00
]
},
{
"cell_type": "code",
"execution_count": 5,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"28*28"
2016-05-22 17:40:18 +02:00
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"some_digit = X[36000]\n",
"some_digit_image = some_digit.reshape(28, 28)\n",
"plt.imshow(some_digit_image, cmap = matplotlib.cm.binary,\n",
"           interpolation=\"nearest\")\n",
"plt.axis(\"off\")\n",
"\n",
"save_fig(\"some_digit_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def plot_digit(data):\n",
"    \"\"\"Show a single flattened 28x28 image (784 values) with the binary colormap and no axes.\"\"\"\n",
"    plt.imshow(data.reshape(28, 28),\n",
"               cmap=matplotlib.cm.binary,\n",
"               interpolation=\"nearest\")\n",
"    plt.axis(\"off\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# EXTRA\n",
"def plot_digits(instances, images_per_row=10, **options):\n",
"    \"\"\"Plot flattened 28x28 images as a single grid with `images_per_row` columns.\n",
"\n",
"    The last row is padded with blank (zero) pixels if needed. Extra keyword\n",
"    arguments are passed to plt.imshow(). Raises ValueError if there is nothing\n",
"    to plot or if images_per_row < 1.\n",
"    \"\"\"\n",
"    size = 28\n",
"    if len(instances) == 0 or images_per_row < 1:\n",
"        raise ValueError(\"plot_digits() needs at least one image and images_per_row >= 1\")\n",
"    images_per_row = min(len(instances), images_per_row)\n",
"    images = [instance.reshape(size, size) for instance in instances]\n",
"    n_rows = (len(instances) - 1) // images_per_row + 1\n",
"    n_empty = n_rows * images_per_row - len(instances)\n",
"    if n_empty > 0:\n",
"        # one blank block fills the remainder of the last row\n",
"        images.append(np.zeros((size, size * n_empty)))\n",
"    row_images = []\n",
"    for row in range(n_rows):\n",
"        rimages = images[row * images_per_row : (row + 1) * images_per_row]\n",
"        row_images.append(np.concatenate(rimages, axis=1))\n",
"    image = np.concatenate(row_images, axis=0)\n",
"    plt.imshow(image, cmap = matplotlib.cm.binary, **options)\n",
"    plt.axis(\"off\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
2017-07-07 21:56:30 +02:00
"metadata": {},
"outputs": [],
"source": [
2016-05-22 17:40:18 +02:00
"plt.figure(figsize=(9,9))\n",
"example_images = np.r_[X[:12000:600], X[13000:30600:600], X[30600:60000:590]]\n",
"plot_digits(example_images, images_per_row=10)\n",
2016-09-27 23:31:21 +02:00
"save_fig(\"more_digits_plot\")\n",
2016-05-22 17:40:18 +02:00
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 10,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"y[36000]"
2016-05-22 17:40:18 +02:00
]
},
{
"cell_type": "code",
"execution_count": 11,
2016-05-22 17:40:18 +02:00
"metadata": {
2017-07-07 21:56:30 +02:00
"collapsed": true
2016-05-22 17:40:18 +02:00
},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# np was imported and seeded in the setup cell, so this shuffle is reproducible\n",
"shuffle_index = np.random.permutation(len(X_train))\n",
"X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]"
]
},
{
"cell_type": "markdown",
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"source": [
"# Binary classifier"
]
},
{
"cell_type": "code",
"execution_count": 13,
2016-05-22 17:40:18 +02:00
"metadata": {
2017-07-07 21:56:30 +02:00
"collapsed": true
2016-05-22 17:40:18 +02:00
},
"outputs": [],
"source": [
"y_train_5 = (y_train == 5)\n",
"y_test_5 = (y_test == 5)"
]
},
{
"cell_type": "code",
"execution_count": 14,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"from sklearn.linear_model import SGDClassifier\n",
"\n",
"sgd_clf = SGDClassifier(random_state=42)\n",
"sgd_clf.fit(X_train, y_train_5)"
]
},
{
"cell_type": "code",
"execution_count": 15,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"sgd_clf.predict([some_digit])"
]
},
{
"cell_type": "code",
"execution_count": 16,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
2016-11-05 18:13:54 +01:00
"from sklearn.model_selection import cross_val_score\n",
2016-05-22 17:40:18 +02:00
"cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring=\"accuracy\")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import StratifiedKFold\n",
"from sklearn.base import clone\n",
"\n",
"# random_state only has an effect when shuffle=True (newer Scikit-Learn versions\n",
"# raise a ValueError if random_state is set while shuffle=False)\n",
"skfolds = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)\n",
"\n",
"for train_index, test_index in skfolds.split(X_train, y_train_5):\n",
"    clone_clf = clone(sgd_clf)\n",
"    X_train_folds = X_train[train_index]\n",
"    y_train_folds = y_train_5[train_index]\n",
"    X_test_fold = X_train[test_index]\n",
"    y_test_fold = y_train_5[test_index]\n",
"\n",
"    clone_clf.fit(X_train_folds, y_train_folds)\n",
"    y_pred = clone_clf.predict(X_test_fold)\n",
"    n_correct = np.sum(y_pred == y_test_fold)\n",
"    print(n_correct / len(y_pred))"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.base import BaseEstimator\n",
"class Never5Classifier(BaseEstimator):\n",
"    \"\"\"Baseline that predicts False (not a 5) for every instance.\"\"\"\n",
"    def fit(self, X, y=None):\n",
"        # nothing to learn; returning self follows the Scikit-Learn estimator API\n",
"        return self\n",
"    def predict(self, X):\n",
"        # 1-D array of shape (n_samples,), like other Scikit-Learn classifiers\n",
"        return np.zeros(len(X), dtype=bool)"
]
},
{
"cell_type": "code",
"execution_count": 19,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"never_5_clf = Never5Classifier()\n",
"cross_val_score(never_5_clf, X_train, y_train_5, cv=3, scoring=\"accuracy\")"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": true
},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
2016-11-05 18:13:54 +01:00
"from sklearn.model_selection import cross_val_predict\n",
2016-05-22 17:40:18 +02:00
"\n",
"y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)"
]
},
{
"cell_type": "code",
"execution_count": 21,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"from sklearn.metrics import confusion_matrix\n",
"\n",
"confusion_matrix(y_train_5, y_train_pred)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
2017-07-07 21:56:30 +02:00
"collapsed": true
},
"outputs": [],
"source": [
"y_train_perfect_predictions = y_train_5"
]
},
{
"cell_type": "code",
"execution_count": 23,
2017-07-07 21:56:30 +02:00
"metadata": {},
"outputs": [],
"source": [
"confusion_matrix(y_train_5, y_train_perfect_predictions)"
]
},
{
"cell_type": "code",
"execution_count": 24,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"from sklearn.metrics import precision_score, recall_score\n",
"\n",
"precision_score(y_train_5, y_train_pred)"
]
},
{
"cell_type": "code",
"execution_count": 25,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"4344 / (4344 + 1307)"
]
},
{
"cell_type": "code",
"execution_count": 26,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"recall_score(y_train_5, y_train_pred)"
]
},
{
"cell_type": "code",
"execution_count": 27,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"4344 / (4344 + 1077)"
]
},
{
"cell_type": "code",
"execution_count": 28,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"from sklearn.metrics import f1_score\n",
"f1_score(y_train_5, y_train_pred)"
]
},
{
"cell_type": "code",
"execution_count": 29,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"4344 / (4344 + (1077 + 1307)/2)"
]
},
{
"cell_type": "code",
"execution_count": 30,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"y_scores = sgd_clf.decision_function([some_digit])\n",
"y_scores"
]
},
{
"cell_type": "code",
"execution_count": 31,
2016-05-22 17:40:18 +02:00
"metadata": {
2017-07-07 21:56:30 +02:00
"collapsed": true
2016-05-22 17:40:18 +02:00
},
"outputs": [],
"source": [
"threshold = 0\n",
"y_some_digit_pred = (y_scores > threshold)"
]
},
{
"cell_type": "code",
"execution_count": 32,
2017-07-07 21:56:30 +02:00
"metadata": {},
"outputs": [],
"source": [
2016-05-22 17:40:18 +02:00
"y_some_digit_pred"
]
},
{
"cell_type": "code",
"execution_count": 33,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"threshold = 200000\n",
"y_some_digit_pred = (y_scores > threshold)\n",
"y_some_digit_pred"
]
},
{
"cell_type": "code",
"execution_count": 34,
2016-05-22 17:40:18 +02:00
"metadata": {
2017-07-07 21:56:30 +02:00
"collapsed": true
2016-05-22 17:40:18 +02:00
},
"outputs": [],
"source": [
"y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3,\n",
" method=\"decision_function\")"
2016-05-22 17:40:18 +02:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: there is an [issue](https://github.com/scikit-learn/scikit-learn/issues/9589) introduced in Scikit-Learn 0.19.0 where the result of `cross_val_predict()` is incorrect in the binary classification case when using `method=\"decision_function\"`, as in the code above. The resulting array is 2-D, with an extra first column full of 0s. We need to add this small hack for now to work around this issue and keep only the second column:"
]
},
2016-05-22 17:40:18 +02:00
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"y_scores.shape"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# hack to work around issue #9589 introduced in Scikit-Learn 0.19.0\n",
"if y_scores.ndim == 2:\n",
" y_scores = y_scores[:, 1]"
]
},
{
"cell_type": "code",
"execution_count": 37,
2016-05-22 17:40:18 +02:00
"metadata": {
2017-07-07 21:56:30 +02:00
"collapsed": true
2016-05-22 17:40:18 +02:00
},
"outputs": [],
"source": [
"from sklearn.metrics import precision_recall_curve\n",
"\n",
"precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)"
]
},
{
"cell_type": "code",
"execution_count": 38,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):\n",
" plt.plot(thresholds, precisions[:-1], \"b--\", label=\"Precision\", linewidth=2)\n",
" plt.plot(thresholds, recalls[:-1], \"g-\", label=\"Recall\", linewidth=2)\n",
" plt.xlabel(\"Threshold\", fontsize=16)\n",
" plt.legend(loc=\"upper left\", fontsize=16)\n",
2016-05-22 17:40:18 +02:00
" plt.ylim([0, 1])\n",
"\n",
"plt.figure(figsize=(8, 4))\n",
"plot_precision_recall_vs_threshold(precisions, recalls, thresholds)\n",
"plt.xlim([-700000, 700000])\n",
"save_fig(\"precision_recall_vs_threshold_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 39,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"(y_train_pred == (y_scores > 0)).all()"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"collapsed": true
},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"y_train_pred_90 = (y_scores > 70000)"
]
},
{
"cell_type": "code",
"execution_count": 41,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"precision_score(y_train_5, y_train_pred_90)"
]
},
{
"cell_type": "code",
"execution_count": 42,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"recall_score(y_train_5, y_train_pred_90)"
]
},
{
"cell_type": "code",
"execution_count": 43,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"def plot_precision_vs_recall(precisions, recalls):\n",
" plt.plot(recalls, precisions, \"b-\", linewidth=2)\n",
" plt.xlabel(\"Recall\", fontsize=16)\n",
" plt.ylabel(\"Precision\", fontsize=16)\n",
" plt.axis([0, 1, 0, 1])\n",
"\n",
"plt.figure(figsize=(8, 6))\n",
"plot_precision_vs_recall(precisions, recalls)\n",
"save_fig(\"precision_vs_recall_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"source": [
"# ROC curves"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"collapsed": true
},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"from sklearn.metrics import roc_curve\n",
"\n",
"fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)"
]
},
{
"cell_type": "code",
"execution_count": 45,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"def plot_roc_curve(fpr, tpr, label=None):\n",
" plt.plot(fpr, tpr, linewidth=2, label=label)\n",
2016-05-22 17:40:18 +02:00
" plt.plot([0, 1], [0, 1], 'k--')\n",
" plt.axis([0, 1, 0, 1])\n",
" plt.xlabel('False Positive Rate', fontsize=16)\n",
" plt.ylabel('True Positive Rate', fontsize=16)\n",
"\n",
"plt.figure(figsize=(8, 6))\n",
"plot_roc_curve(fpr, tpr)\n",
"save_fig(\"roc_curve_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 46,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"from sklearn.metrics import roc_auc_score\n",
"\n",
"roc_auc_score(y_train_5, y_scores)"
]
},
{
"cell_type": "code",
"execution_count": 47,
2016-05-22 17:40:18 +02:00
"metadata": {
2017-07-07 21:56:30 +02:00
"collapsed": true
2016-05-22 17:40:18 +02:00
},
"outputs": [],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"forest_clf = RandomForestClassifier(random_state=42)\n",
"y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3,\n",
" method=\"predict_proba\")"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {
2017-07-07 21:56:30 +02:00
"collapsed": true
},
"outputs": [],
"source": [
2016-05-22 17:40:18 +02:00
"y_scores_forest = y_probas_forest[:, 1] # score = proba of positive class\n",
"fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5,y_scores_forest)"
]
},
{
"cell_type": "code",
"execution_count": 49,
2017-07-07 21:56:30 +02:00
"metadata": {},
"outputs": [],
"source": [
2016-05-22 17:40:18 +02:00
"plt.figure(figsize=(8, 6))\n",
"plt.plot(fpr, tpr, \"b:\", linewidth=2, label=\"SGD\")\n",
"plot_roc_curve(fpr_forest, tpr_forest, \"Random Forest\")\n",
2016-05-22 17:40:18 +02:00
"plt.legend(loc=\"lower right\", fontsize=16)\n",
"save_fig(\"roc_curve_comparison_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 50,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"roc_auc_score(y_train_5, y_scores_forest)"
]
},
{
"cell_type": "code",
"execution_count": 51,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"y_train_pred_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3)\n",
"precision_score(y_train_5, y_train_pred_forest)"
]
},
{
"cell_type": "code",
"execution_count": 52,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"recall_score(y_train_5, y_train_pred_forest)"
]
},
{
"cell_type": "markdown",
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"source": [
"# Multiclass classification"
]
},
{
"cell_type": "code",
"execution_count": 53,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"sgd_clf.fit(X_train, y_train)\n",
"sgd_clf.predict([some_digit])"
]
},
{
"cell_type": "code",
"execution_count": 54,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"some_digit_scores = sgd_clf.decision_function([some_digit])\n",
"some_digit_scores"
]
},
{
"cell_type": "code",
"execution_count": 55,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"np.argmax(some_digit_scores)"
]
},
{
"cell_type": "code",
"execution_count": 56,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"sgd_clf.classes_"
]
},
{
"cell_type": "code",
"execution_count": 57,
2017-07-07 21:56:30 +02:00
"metadata": {},
"outputs": [],
"source": [
"sgd_clf.classes_[5]"
]
},
{
"cell_type": "code",
"execution_count": 58,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"from sklearn.multiclass import OneVsOneClassifier\n",
"ovo_clf = OneVsOneClassifier(SGDClassifier(random_state=42))\n",
"ovo_clf.fit(X_train, y_train)\n",
"ovo_clf.predict([some_digit])"
]
},
{
"cell_type": "code",
"execution_count": 59,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"len(ovo_clf.estimators_)"
]
},
{
"cell_type": "code",
"execution_count": 60,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"forest_clf.fit(X_train, y_train)\n",
"forest_clf.predict([some_digit])"
]
},
{
"cell_type": "code",
"execution_count": 61,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"forest_clf.predict_proba([some_digit])"
]
},
{
"cell_type": "code",
"execution_count": 62,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"cross_val_score(sgd_clf, X_train, y_train, cv=3, scoring=\"accuracy\")"
]
},
{
"cell_type": "code",
"execution_count": 63,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"from sklearn.preprocessing import StandardScaler\n",
"scaler = StandardScaler()\n",
"X_train_scaled = scaler.fit_transform(X_train.astype(np.float64))\n",
"cross_val_score(sgd_clf, X_train_scaled, y_train, cv=3, scoring=\"accuracy\")"
]
},
{
"cell_type": "code",
"execution_count": 64,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)\n",
"conf_mx = confusion_matrix(y_train, y_train_pred)\n",
"conf_mx"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {
"collapsed": true
},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"def plot_confusion_matrix(matrix):\n",
" \"\"\"If you prefer color and a colorbar\"\"\"\n",
" fig = plt.figure(figsize=(8,8))\n",
" ax = fig.add_subplot(111)\n",
" cax = ax.matshow(matrix)\n",
" fig.colorbar(cax)"
]
},
{
"cell_type": "code",
"execution_count": 66,
2017-07-07 21:56:30 +02:00
"metadata": {},
"outputs": [],
"source": [
2016-05-22 17:40:18 +02:00
"plt.matshow(conf_mx, cmap=plt.cm.gray)\n",
"save_fig(\"confusion_matrix_plot\", tight_layout=False)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {
"collapsed": true
},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"row_sums = conf_mx.sum(axis=1, keepdims=True)\n",
"norm_conf_mx = conf_mx / row_sums"
]
},
{
"cell_type": "code",
"execution_count": 68,
2017-07-07 21:56:30 +02:00
"metadata": {},
"outputs": [],
"source": [
2016-05-22 17:40:18 +02:00
"np.fill_diagonal(norm_conf_mx, 0)\n",
"plt.matshow(norm_conf_mx, cmap=plt.cm.gray)\n",
"save_fig(\"confusion_matrix_errors_plot\", tight_layout=False)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 69,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"cl_a, cl_b = 3, 5\n",
"X_aa = X_train[(y_train == cl_a) & (y_train_pred == cl_a)]\n",
"X_ab = X_train[(y_train == cl_a) & (y_train_pred == cl_b)]\n",
"X_ba = X_train[(y_train == cl_b) & (y_train_pred == cl_a)]\n",
"X_bb = X_train[(y_train == cl_b) & (y_train_pred == cl_b)]\n",
"\n",
"plt.figure(figsize=(8,8))\n",
"plt.subplot(221); plot_digits(X_aa[:25], images_per_row=5)\n",
"plt.subplot(222); plot_digits(X_ab[:25], images_per_row=5)\n",
"plt.subplot(223); plot_digits(X_ba[:25], images_per_row=5)\n",
"plt.subplot(224); plot_digits(X_bb[:25], images_per_row=5)\n",
2016-05-22 17:40:18 +02:00
"save_fig(\"error_analysis_digits_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"source": [
"# Multilabel classification"
]
},
{
"cell_type": "code",
"execution_count": 70,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"from sklearn.neighbors import KNeighborsClassifier\n",
"\n",
"y_train_large = (y_train >= 7)\n",
"y_train_odd = (y_train % 2 == 1)\n",
"y_multilabel = np.c_[y_train_large, y_train_odd]\n",
"\n",
"knn_clf = KNeighborsClassifier()\n",
"knn_clf.fit(X_train, y_multilabel)"
]
},
{
"cell_type": "code",
"execution_count": 71,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"knn_clf.predict([some_digit])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Warning**: the following cell may take a very long time (possibly hours depending on your hardware)."
]
},
2016-05-22 17:40:18 +02:00
{
"cell_type": "code",
"execution_count": 72,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
2017-07-07 21:56:30 +02:00
"y_train_knn_pred = cross_val_predict(knn_clf, X_train, y_multilabel, cv=3)\n",
"f1_score(y_multilabel, y_train_knn_pred, average=\"macro\")"
2016-05-22 17:40:18 +02:00
]
},
{
"cell_type": "markdown",
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"source": [
"# Multioutput classification"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {
"collapsed": true
},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"noise = np.random.randint(0, 100, (len(X_train), 784))\n",
2016-05-22 17:40:18 +02:00
"X_train_mod = X_train + noise\n",
"noise = np.random.randint(0, 100, (len(X_test), 784))\n",
2016-05-22 17:40:18 +02:00
"X_test_mod = X_test + noise\n",
"y_train_mod = X_train\n",
"y_test_mod = X_test"
]
},
{
"cell_type": "code",
"execution_count": 74,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"some_index = 5500\n",
"plt.subplot(121); plot_digit(X_test_mod[some_index])\n",
"plt.subplot(122); plot_digit(y_test_mod[some_index])\n",
2016-09-27 23:31:21 +02:00
"save_fig(\"noisy_digit_example_plot\")\n",
2016-05-22 17:40:18 +02:00
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 75,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"knn_clf.fit(X_train_mod, y_train_mod)\n",
2016-05-22 17:40:18 +02:00
"clean_digit = knn_clf.predict([X_test_mod[some_index]])\n",
"plot_digit(clean_digit)\n",
"save_fig(\"cleaned_digit_example_plot\")"
2016-05-22 17:40:18 +02:00
]
},
{
"cell_type": "markdown",
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"source": [
"# Extra material"
]
},
{
"cell_type": "markdown",
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"source": [
"## Dummy (i.e. random) classifier"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {
"collapsed": true
},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"from sklearn.dummy import DummyClassifier\n",
"dmy_clf = DummyClassifier()\n",
2016-11-05 18:13:54 +01:00
"y_probas_dmy = cross_val_predict(dmy_clf, X_train, y_train_5, cv=3, method=\"predict_proba\")\n",
2016-05-22 17:40:18 +02:00
"y_scores_dmy = y_probas_dmy[:, 1]"
]
},
{
"cell_type": "code",
"execution_count": 77,
2016-05-22 17:40:18 +02:00
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"fprr, tprr, thresholdsr = roc_curve(y_train_5, y_scores_dmy)\n",
"plot_roc_curve(fprr, tprr)"
]
},
{
"cell_type": "markdown",
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"source": [
"## KNN classifier"
]
},
{
"cell_type": "code",
"execution_count": 78,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"from sklearn.neighbors import KNeighborsClassifier\n",
"knn_clf = KNeighborsClassifier(n_jobs=-1, weights='distance', n_neighbors=4)\n",
"knn_clf.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {
"collapsed": true
},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"y_knn_pred = knn_clf.predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 80,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"from sklearn.metrics import accuracy_score\n",
"accuracy_score(y_test, y_knn_pred)"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [],
"source": [
"# the scipy.ndimage.interpolation namespace is deprecated; shift lives in scipy.ndimage\n",
"from scipy.ndimage import shift\n",
"def shift_digit(digit_array, dx, dy, new=0):\n",
"    \"\"\"Shift a flattened 28x28 image by dx columns and dy rows, filling uncovered pixels with `new`.\"\"\"\n",
"    return shift(digit_array.reshape(28, 28), [dy, dx], cval=new).reshape(784)\n",
"\n",
"plot_digit(shift_digit(some_digit, 5, 1, new=100))"
]
},
{
"cell_type": "code",
"execution_count": 82,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"X_train_expanded = [X_train]\n",
"y_train_expanded = [y_train]\n",
"for dx, dy in ((1, 0), (-1, 0), (0, 1), (0, -1)):\n",
" shifted_images = np.apply_along_axis(shift_digit, axis=1, arr=X_train, dx=dx, dy=dy)\n",
" X_train_expanded.append(shifted_images)\n",
" y_train_expanded.append(y_train)\n",
"\n",
"X_train_expanded = np.concatenate(X_train_expanded)\n",
"y_train_expanded = np.concatenate(y_train_expanded)\n",
"X_train_expanded.shape, y_train_expanded.shape"
]
},
{
"cell_type": "code",
"execution_count": 83,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"knn_clf.fit(X_train_expanded, y_train_expanded)"
]
},
{
"cell_type": "code",
"execution_count": 84,
2016-05-22 17:40:18 +02:00
"metadata": {
2017-07-07 21:56:30 +02:00
"collapsed": true
2016-05-22 17:40:18 +02:00
},
"outputs": [],
"source": [
"y_knn_expanded_pred = knn_clf.predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 85,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"accuracy_score(y_test, y_knn_expanded_pred)"
]
},
{
"cell_type": "code",
"execution_count": 86,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"ambiguous_digit = X_test[2589]\n",
"knn_clf.predict_proba([ambiguous_digit])"
]
},
{
"cell_type": "code",
"execution_count": 87,
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-05-22 17:40:18 +02:00
"outputs": [],
"source": [
"plot_digit(ambiguous_digit)"
]
2016-09-27 23:31:21 +02:00
},
{
"cell_type": "markdown",
"metadata": {
2017-07-07 21:56:30 +02:00
"collapsed": true
2016-09-27 23:31:21 +02:00
},
"source": [
"# Exercise solutions"
]
},
{
"cell_type": "markdown",
2017-07-07 21:56:30 +02:00
"metadata": {},
2016-09-27 23:31:21 +02:00
"source": [
"**Coming soon**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
2017-07-07 21:56:30 +02:00
"collapsed": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": []
2016-05-22 17:40:18 +02:00
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
2016-05-22 17:40:18 +02:00
},
2016-09-27 23:31:21 +02:00
"nav_menu": {},
2016-05-22 17:40:18 +02:00
"toc": {
2016-09-27 23:31:21 +02:00
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
2016-05-22 17:40:18 +02:00
"toc_cell": false,
2016-09-27 23:31:21 +02:00
"toc_section_display": "block",
2016-05-22 17:40:18 +02:00
"toc_window_display": false
}
},
"nbformat": 4,
2017-07-07 21:56:30 +02:00
"nbformat_minor": 1
2016-05-22 17:40:18 +02:00
}