Make notebook code match book examples more closely in chapter 3

main
Aurélien Geron 2017-06-01 09:52:10 +02:00
parent 291ae3e39d
commit c226f328d3
1 changed files with 220 additions and 162 deletions

View File

@ -93,47 +93,14 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"from six.moves import urllib\n",
"from sklearn.datasets import fetch_mldata\n", "from sklearn.datasets import fetch_mldata\n",
"try:\n", "mnist = fetch_mldata('MNIST original')\n",
" mnist = fetch_mldata('MNIST original')\n",
"except urllib.error.HTTPError as ex:\n",
" print(\"Could not download MNIST data from mldata.org, trying alternative...\")\n",
"\n",
" # Alternative method to load MNIST, if mldata.org is down\n",
" from scipy.io import loadmat\n",
" mnist_alternative_url = \"https://github.com/amplab/datascience-sp14/raw/master/lab7/mldata/mnist-original.mat\"\n",
" mnist_path = \"./mnist-original.mat\"\n",
" response = urllib.request.urlopen(mnist_alternative_url)\n",
" with open(mnist_path, \"wb\") as f:\n",
" content = response.read()\n",
" f.write(content)\n",
" mnist_raw = loadmat(mnist_path)\n",
" mnist = {\n",
" \"data\": mnist_raw[\"data\"].T,\n",
" \"target\": mnist_raw[\"label\"][0],\n",
" \"COL_NAMES\": [\"label\", \"data\"],\n",
" \"DESCR\": \"mldata.org dataset: mnist-original\",\n",
" }\n",
" print(\"Success!\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"mnist" "mnist"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 3,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -147,7 +114,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 4,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -160,7 +127,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 5,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -173,25 +140,43 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 6,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
"editable": true "editable": true
}, },
"outputs": [], "outputs": [],
"source": [
"%matplotlib inline\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"\n",
"some_digit = X[36000]\n",
"some_digit_image = some_digit.reshape(28, 28)\n",
"plt.imshow(some_digit_image, cmap = matplotlib.cm.binary,\n",
" interpolation=\"nearest\")\n",
"plt.axis(\"off\")\n",
"\n",
"save_fig(\"some_digit_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [ "source": [
"def plot_digit(data):\n", "def plot_digit(data):\n",
" image = data.reshape(28, 28)\n", " image = data.reshape(28, 28)\n",
" plt.imshow(image, cmap = matplotlib.cm.binary,\n", " plt.imshow(image, cmap = matplotlib.cm.binary,\n",
" interpolation=\"nearest\")\n", " interpolation=\"nearest\")\n",
" plt.axis(\"off\")\n", " plt.axis(\"off\")"
"\n",
"some_digit_index = 36000\n",
"some_digit = X[some_digit_index]\n",
"plot_digit(some_digit)\n",
"save_fig(\"some_digit_plot\")\n",
"plt.show()"
] ]
}, },
{ {
@ -218,13 +203,7 @@
" row_images.append(np.concatenate(rimages, axis=1))\n", " row_images.append(np.concatenate(rimages, axis=1))\n",
" image = np.concatenate(row_images, axis=0)\n", " image = np.concatenate(row_images, axis=0)\n",
" plt.imshow(image, cmap = matplotlib.cm.binary, **options)\n", " plt.imshow(image, cmap = matplotlib.cm.binary, **options)\n",
" plt.axis(\"off\")\n", " plt.axis(\"off\")"
"\n",
"plt.figure(figsize=(9,9))\n",
"example_images = np.r_[X[:12000:600], X[13000:30600:600], X[30600:60000:590]]\n",
"plot_digits(example_images, images_per_row=10)\n",
"save_fig(\"more_digits_plot\")\n",
"plt.show()"
] ]
}, },
{ {
@ -237,20 +216,24 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"y[some_digit_index]" "plt.figure(figsize=(9,9))\n",
"example_images = np.r_[X[:12000:600], X[13000:30600:600], X[30600:60000:590]]\n",
"plot_digits(example_images, images_per_row=10)\n",
"save_fig(\"more_digits_plot\")\n",
"plt.show()"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 10,
"metadata": { "metadata": {
"collapsed": true, "collapsed": false,
"deletable": true, "deletable": true,
"editable": true "editable": true
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]" "y[36000]"
] ]
}, },
{ {
@ -263,7 +246,22 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"shuffle_index = rnd.permutation(60000)\n", "X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"shuffle_index = np.random.permutation(60000)\n",
"X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]" "X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]"
] ]
}, },
@ -279,7 +277,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 13,
"metadata": { "metadata": {
"collapsed": true, "collapsed": true,
"deletable": true, "deletable": true,
@ -293,7 +291,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": 14,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -309,7 +307,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": 15,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -322,7 +320,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": 16,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -336,7 +334,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 16, "execution_count": 17,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -364,7 +362,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 17, "execution_count": 18,
"metadata": { "metadata": {
"collapsed": true, "collapsed": true,
"deletable": true, "deletable": true,
@ -382,7 +380,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 18, "execution_count": 19,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -396,7 +394,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 19, "execution_count": 20,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -411,7 +409,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 20, "execution_count": 21,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -426,7 +424,33 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 21, "execution_count": 22,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"y_train_perfect_predictions = y_train_5"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"confusion_matrix(y_train_5, y_train_perfect_predictions)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -441,7 +465,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 22, "execution_count": 25,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -454,7 +478,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 23, "execution_count": 26,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -467,7 +491,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 24, "execution_count": 27,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -480,7 +504,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 25, "execution_count": 28,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -494,7 +518,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 26, "execution_count": 29,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -507,7 +531,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 27, "execution_count": 30,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -521,7 +545,21 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 28, "execution_count": 31,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"threshold = 0\n",
"y_some_digit_pred = (y_scores > threshold)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -529,14 +567,12 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"threshold = 0\n",
"y_some_digit_pred = (y_scores > threshold)\n",
"y_some_digit_pred" "y_some_digit_pred"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 29, "execution_count": 33,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -551,7 +587,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 30, "execution_count": 34,
"metadata": { "metadata": {
"collapsed": true, "collapsed": true,
"deletable": true, "deletable": true,
@ -559,12 +595,13 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3, method=\"decision_function\")" "y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3,\n",
" method=\"decision_function\")"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 31, "execution_count": 35,
"metadata": { "metadata": {
"collapsed": true, "collapsed": true,
"deletable": true, "deletable": true,
@ -579,7 +616,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 32, "execution_count": 36,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -591,7 +628,7 @@
" plt.plot(thresholds, precisions[:-1], \"b--\", label=\"Precision\", linewidth=2)\n", " plt.plot(thresholds, precisions[:-1], \"b--\", label=\"Precision\", linewidth=2)\n",
" plt.plot(thresholds, recalls[:-1], \"g-\", label=\"Recall\", linewidth=2)\n", " plt.plot(thresholds, recalls[:-1], \"g-\", label=\"Recall\", linewidth=2)\n",
" plt.xlabel(\"Threshold\", fontsize=16)\n", " plt.xlabel(\"Threshold\", fontsize=16)\n",
" plt.legend(loc=\"center left\", fontsize=16)\n", " plt.legend(loc=\"upper left\", fontsize=16)\n",
" plt.ylim([0, 1])\n", " plt.ylim([0, 1])\n",
"\n", "\n",
"plt.figure(figsize=(8, 4))\n", "plt.figure(figsize=(8, 4))\n",
@ -603,7 +640,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 33, "execution_count": 37,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -616,7 +653,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 34, "execution_count": 38,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -629,7 +666,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 35, "execution_count": 39,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -642,7 +679,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 36, "execution_count": 40,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -655,7 +692,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 37, "execution_count": 41,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -687,7 +724,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 38, "execution_count": 42,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -702,7 +739,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 39, "execution_count": 43,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -710,8 +747,8 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"def plot_roc_curve(fpr, tpr, **options):\n", "def plot_roc_curve(fpr, tpr, label=None):\n",
" plt.plot(fpr, tpr, linewidth=2, **options)\n", " plt.plot(fpr, tpr, linewidth=2, label=label)\n",
" plt.plot([0, 1], [0, 1], 'k--')\n", " plt.plot([0, 1], [0, 1], 'k--')\n",
" plt.axis([0, 1, 0, 1])\n", " plt.axis([0, 1, 0, 1])\n",
" plt.xlabel('False Positive Rate', fontsize=16)\n", " plt.xlabel('False Positive Rate', fontsize=16)\n",
@ -725,7 +762,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 40, "execution_count": 44,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -740,7 +777,33 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 41, "execution_count": 45,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"forest_clf = RandomForestClassifier(random_state=42)\n",
"y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3,\n",
" method=\"predict_proba\")"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"y_scores_forest = y_probas_forest[:, 1] # score = proba of positive class\n",
"fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5,y_scores_forest)"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -748,15 +811,9 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"forest_clf = RandomForestClassifier(random_state=42)\n",
"y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3, method=\"predict_proba\")\n",
"y_scores_forest = y_probas_forest[:, 1] # score = proba of positive class\n",
"fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5, y_scores_forest)\n",
"\n",
"plt.figure(figsize=(8, 6))\n", "plt.figure(figsize=(8, 6))\n",
"plt.plot(fpr, tpr, \"b:\", linewidth=2, label=\"SGD\")\n", "plt.plot(fpr, tpr, \"b:\", linewidth=2, label=\"SGD\")\n",
"plot_roc_curve(fpr_forest, tpr_forest, label=\"Random Forest\")\n", "plot_roc_curve(fpr_forest, tpr_forest, \"Random Forest\")\n",
"plt.legend(loc=\"lower right\", fontsize=16)\n", "plt.legend(loc=\"lower right\", fontsize=16)\n",
"save_fig(\"roc_curve_comparison_plot\")\n", "save_fig(\"roc_curve_comparison_plot\")\n",
"plt.show()" "plt.show()"
@ -764,7 +821,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 42, "execution_count": 53,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -777,7 +834,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 43, "execution_count": 54,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -791,7 +848,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 44, "execution_count": 55,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -814,7 +871,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 45, "execution_count": 56,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -828,7 +885,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 46, "execution_count": 57,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -842,7 +899,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 47, "execution_count": 58,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -855,7 +912,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 48, "execution_count": 59,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -868,7 +925,18 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 49, "execution_count": 60,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"sgd_clf.classes_[5]"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -884,7 +952,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 50, "execution_count": 62,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -897,7 +965,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 51, "execution_count": 63,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -911,7 +979,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 52, "execution_count": 64,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -924,7 +992,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 53, "execution_count": 65,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -937,7 +1005,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 54, "execution_count": 66,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -953,7 +1021,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 55, "execution_count": 67,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -968,7 +1036,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 56, "execution_count": 68,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -981,8 +1049,17 @@
" fig = plt.figure(figsize=(8,8))\n", " fig = plt.figure(figsize=(8,8))\n",
" ax = fig.add_subplot(111)\n", " ax = fig.add_subplot(111)\n",
" cax = ax.matshow(conf_mx)\n", " cax = ax.matshow(conf_mx)\n",
" fig.colorbar(cax)\n", " fig.colorbar(cax)"
"\n", ]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"plt.matshow(conf_mx, cmap=plt.cm.gray)\n", "plt.matshow(conf_mx, cmap=plt.cm.gray)\n",
"save_fig(\"confusion_matrix_plot\", tight_layout=False)\n", "save_fig(\"confusion_matrix_plot\", tight_layout=False)\n",
"plt.show()" "plt.show()"
@ -990,7 +1067,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 57, "execution_count": 70,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -999,7 +1076,17 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"row_sums = conf_mx.sum(axis=1, keepdims=True)\n", "row_sums = conf_mx.sum(axis=1, keepdims=True)\n",
"norm_conf_mx = conf_mx / row_sums\n", "norm_conf_mx = conf_mx / row_sums"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"np.fill_diagonal(norm_conf_mx, 0)\n", "np.fill_diagonal(norm_conf_mx, 0)\n",
"plt.matshow(norm_conf_mx, cmap=plt.cm.gray)\n", "plt.matshow(norm_conf_mx, cmap=plt.cm.gray)\n",
"save_fig(\"confusion_matrix_errors_plot\", tight_layout=False)\n", "save_fig(\"confusion_matrix_errors_plot\", tight_layout=False)\n",
@ -1008,7 +1095,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 58, "execution_count": 72,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -1023,14 +1110,10 @@
"X_bb = X_train[(y_train == cl_b) & (y_train_pred == cl_b)]\n", "X_bb = X_train[(y_train == cl_b) & (y_train_pred == cl_b)]\n",
"\n", "\n",
"plt.figure(figsize=(8,8))\n", "plt.figure(figsize=(8,8))\n",
"plt.subplot(221)\n", "plt.subplot(221); plot_digits(X_aa[:25], images_per_row=5)\n",
"plot_digits(X_aa[:25], images_per_row=5)\n", "plt.subplot(222); plot_digits(X_ab[:25], images_per_row=5)\n",
"plt.subplot(222)\n", "plt.subplot(223); plot_digits(X_ba[:25], images_per_row=5)\n",
"plot_digits(X_ab[:25], images_per_row=5)\n", "plt.subplot(224); plot_digits(X_bb[:25], images_per_row=5)\n",
"plt.subplot(223)\n",
"plot_digits(X_ba[:25], images_per_row=5)\n",
"plt.subplot(224)\n",
"plot_digits(X_bb[:25], images_per_row=5)\n",
"save_fig(\"error_analysis_digits_plot\")\n", "save_fig(\"error_analysis_digits_plot\")\n",
"plt.show()" "plt.show()"
] ]
@ -1047,7 +1130,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 59, "execution_count": 73,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -1067,7 +1150,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 60, "execution_count": 74,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -1080,20 +1163,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 61, "execution_count": 77,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"y_train_knn_pred = cross_val_predict(knn_clf, X_train, y_train, cv=3)"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -1101,6 +1171,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"y_train_knn_pred = cross_val_predict(knn_clf, X_train, y_train, cv=3)\n",
"f1_score(y_train, y_train_knn_pred, average=\"macro\")" "f1_score(y_train, y_train_knn_pred, average=\"macro\")"
] ]
}, },
@ -1116,7 +1187,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 63, "execution_count": 78,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -1134,7 +1205,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 64, "execution_count": 79,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -1151,20 +1222,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 65, "execution_count": 82,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"knn_clf.fit(X_train_mod, y_train_mod)"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"deletable": true, "deletable": true,
@ -1172,10 +1230,10 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"knn_clf.fit(X_train_mod, y_train_mod)\n",
"clean_digit = knn_clf.predict([X_test_mod[some_index]])\n", "clean_digit = knn_clf.predict([X_test_mod[some_index]])\n",
"plot_digit(clean_digit)\n", "plot_digit(clean_digit)\n",
"save_fig(\"cleaned_digit_example_plot\")\n", "save_fig(\"cleaned_digit_example_plot\")"
"plt.show()"
] ]
}, },
{ {
@ -1435,7 +1493,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.5.2+" "version": "3.5.3"
}, },
"nav_menu": {}, "nav_menu": {},
"toc": { "toc": {