From c226f328d369a011e27d9f2e54ef22cc7bfad3e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Thu, 1 Jun 2017 09:52:10 +0200 Subject: [PATCH] Make notebook code match book examples more closely in chapter 3 --- 03_classification.ipynb | 382 +++++++++++++++++++++++----------------- 1 file changed, 220 insertions(+), 162 deletions(-) diff --git a/03_classification.ipynb b/03_classification.ipynb index 579c829..087766b 100644 --- a/03_classification.ipynb +++ b/03_classification.ipynb @@ -93,47 +93,14 @@ }, "outputs": [], "source": [ - "from six.moves import urllib\n", "from sklearn.datasets import fetch_mldata\n", - "try:\n", - " mnist = fetch_mldata('MNIST original')\n", - "except urllib.error.HTTPError as ex:\n", - " print(\"Could not download MNIST data from mldata.org, trying alternative...\")\n", - "\n", - " # Alternative method to load MNIST, if mldata.org is down\n", - " from scipy.io import loadmat\n", - " mnist_alternative_url = \"https://github.com/amplab/datascience-sp14/raw/master/lab7/mldata/mnist-original.mat\"\n", - " mnist_path = \"./mnist-original.mat\"\n", - " response = urllib.request.urlopen(mnist_alternative_url)\n", - " with open(mnist_path, \"wb\") as f:\n", - " content = response.read()\n", - " f.write(content)\n", - " mnist_raw = loadmat(mnist_path)\n", - " mnist = {\n", - " \"data\": mnist_raw[\"data\"].T,\n", - " \"target\": mnist_raw[\"label\"][0],\n", - " \"COL_NAMES\": [\"label\", \"data\"],\n", - " \"DESCR\": \"mldata.org dataset: mnist-original\",\n", - " }\n", - " print(\"Success!\")" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "collapsed": false, - "deletable": true, - "editable": true - }, - "outputs": [], - "source": [ + "mnist = fetch_mldata('MNIST original')\n", "mnist" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": { "collapsed": false, "deletable": true, @@ -147,7 +114,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": { "collapsed": false, "deletable": true, @@ -160,7 +127,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": { "collapsed": false, "deletable": true, @@ -173,25 +140,43 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [], + "source": [ + "%matplotlib inline\n", + "import matplotlib\n", + "import matplotlib.pyplot as plt\n", + "\n", + "some_digit = X[36000]\n", + "some_digit_image = some_digit.reshape(28, 28)\n", + "plt.imshow(some_digit_image, cmap = matplotlib.cm.binary,\n", + " interpolation=\"nearest\")\n", + "plt.axis(\"off\")\n", + "\n", + "save_fig(\"some_digit_plot\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": true, + "deletable": true, + "editable": true + }, + "outputs": [], "source": [ "def plot_digit(data):\n", " image = data.reshape(28, 28)\n", " plt.imshow(image, cmap = matplotlib.cm.binary,\n", " interpolation=\"nearest\")\n", - " plt.axis(\"off\")\n", - "\n", - "some_digit_index = 36000\n", - "some_digit = X[some_digit_index]\n", - "plot_digit(some_digit)\n", - "save_fig(\"some_digit_plot\")\n", - "plt.show()" + " plt.axis(\"off\")" ] }, { @@ -218,13 +203,7 @@ " row_images.append(np.concatenate(rimages, axis=1))\n", " image = np.concatenate(row_images, axis=0)\n", " plt.imshow(image, cmap = matplotlib.cm.binary, **options)\n", - " plt.axis(\"off\")\n", - "\n", - "plt.figure(figsize=(9,9))\n", - "example_images = np.r_[X[:12000:600], X[13000:30600:600], X[30600:60000:590]]\n", - "plot_digits(example_images, images_per_row=10)\n", - "save_fig(\"more_digits_plot\")\n", - "plt.show()" + " plt.axis(\"off\")" ] }, { @@ -237,20 +216,24 @@ }, "outputs": [], "source": [ - "y[some_digit_index]" + "plt.figure(figsize=(9,9))\n", + "example_images = np.r_[X[:12000:600], X[13000:30600:600], X[30600:60000:590]]\n", + "plot_digits(example_images, images_per_row=10)\n", + "save_fig(\"more_digits_plot\")\n", + "plt.show()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { - "collapsed": true, + "collapsed": false, "deletable": true, "editable": true }, "outputs": [], "source": [ - "X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]" + "y[36000]" ] }, { @@ -263,7 +246,22 @@ }, "outputs": [], "source": [ - "shuffle_index = rnd.permutation(60000)\n", + "X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "collapsed": true, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "shuffle_index = np.random.permutation(60000)\n", "X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]" ] }, @@ -279,7 +277,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "metadata": { "collapsed": true, "deletable": true, @@ -293,7 +291,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "metadata": { "collapsed": false, "deletable": true, @@ -309,7 +307,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "metadata": { "collapsed": false, "deletable": true, @@ -322,7 +320,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 16, "metadata": { "collapsed": false, "deletable": true, @@ -336,7 +334,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 17, "metadata": { "collapsed": false, "deletable": true, @@ -364,7 +362,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 18, "metadata": { "collapsed": true, "deletable": true, @@ -382,7 +380,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 19, "metadata": { "collapsed": false, "deletable": true, @@ -396,7 +394,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 20, "metadata": { "collapsed": false, "deletable": true, @@ -411,7 +409,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 21, "metadata": { "collapsed": false, "deletable": true, @@ -426,7 +424,33 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 22, + "metadata": { + "collapsed": true, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "y_train_perfect_predictions = y_train_5" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "confusion_matrix(y_train_5, y_train_perfect_predictions)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, "metadata": { "collapsed": false, "deletable": true, @@ -441,7 +465,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 25, "metadata": { "collapsed": false, "deletable": true, @@ -454,7 +478,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 26, "metadata": { "collapsed": false, "deletable": true, @@ -467,7 +491,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 27, "metadata": { "collapsed": false, "deletable": true, @@ -480,7 +504,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 28, "metadata": { "collapsed": false, "deletable": true, @@ -494,7 +518,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 29, "metadata": { "collapsed": false, "deletable": true, @@ -507,7 +531,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 30, "metadata": { "collapsed": false, "deletable": true, @@ -521,7 +545,21 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 31, + "metadata": { + "collapsed": true, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "threshold = 0\n", + "y_some_digit_pred = (y_scores > threshold)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, "metadata": { "collapsed": false, "deletable": true, @@ -529,14 +567,12 @@ }, "outputs": [], "source": [ - "threshold = 0\n", - "y_some_digit_pred = (y_scores > threshold)\n", "y_some_digit_pred" ] }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 33, "metadata": { "collapsed": false, "deletable": true, @@ -551,7 +587,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 34, "metadata": { "collapsed": true, "deletable": true, @@ -559,12 +595,13 @@ }, "outputs": [], "source": [ - "y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3, method=\"decision_function\")" + "y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3,\n", + " method=\"decision_function\")" ] }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 35, "metadata": { "collapsed": true, "deletable": true, @@ -579,7 +616,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 36, "metadata": { "collapsed": false, "deletable": true, @@ -591,7 +628,7 @@ " plt.plot(thresholds, precisions[:-1], \"b--\", label=\"Precision\", linewidth=2)\n", " plt.plot(thresholds, recalls[:-1], \"g-\", label=\"Recall\", linewidth=2)\n", " plt.xlabel(\"Threshold\", fontsize=16)\n", - " plt.legend(loc=\"center left\", fontsize=16)\n", + " plt.legend(loc=\"upper left\", fontsize=16)\n", " plt.ylim([0, 1])\n", "\n", "plt.figure(figsize=(8, 4))\n", @@ -603,7 +640,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 37, "metadata": { "collapsed": false, "deletable": true, @@ -616,7 +653,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 38, "metadata": { "collapsed": false, "deletable": true, @@ -629,7 +666,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 39, "metadata": { "collapsed": false, "deletable": true, @@ -642,7 +679,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 40, "metadata": { "collapsed": false, "deletable": true, @@ -655,7 +692,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 41, "metadata": { "collapsed": false, "deletable": true, @@ -687,7 +724,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 42, "metadata": { "collapsed": false, "deletable": true, @@ -702,7 +739,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 43, "metadata": { "collapsed": false, "deletable": true, @@ -710,8 +747,8 @@ }, "outputs": [], "source": [ - "def plot_roc_curve(fpr, tpr, **options):\n", - " plt.plot(fpr, tpr, linewidth=2, **options)\n", + "def plot_roc_curve(fpr, tpr, label=None):\n", + " plt.plot(fpr, tpr, linewidth=2, label=label)\n", " plt.plot([0, 1], [0, 1], 'k--')\n", " plt.axis([0, 1, 0, 1])\n", " plt.xlabel('False Positive Rate', fontsize=16)\n", @@ -725,7 +762,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 44, "metadata": { "collapsed": false, "deletable": true, @@ -740,7 +777,33 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 45, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "from sklearn.ensemble import RandomForestClassifier\n", + "forest_clf = RandomForestClassifier(random_state=42)\n", + "y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3,\n", + " method=\"predict_proba\")" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "y_scores_forest = y_probas_forest[:, 1] # score = proba of positive class\n", + "fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5,y_scores_forest)" + ] + }, + { + "cell_type": "code", + "execution_count": 51, "metadata": { "collapsed": false, "deletable": true, @@ -748,15 +811,9 @@ }, "outputs": [], "source": [ - "from sklearn.ensemble import RandomForestClassifier\n", - "forest_clf = RandomForestClassifier(random_state=42)\n", - "y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3, method=\"predict_proba\")\n", - "y_scores_forest = y_probas_forest[:, 1] # score = proba of positive class\n", - "fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5, y_scores_forest)\n", - "\n", "plt.figure(figsize=(8, 6))\n", "plt.plot(fpr, tpr, \"b:\", linewidth=2, label=\"SGD\")\n", - "plot_roc_curve(fpr_forest, tpr_forest, label=\"Random Forest\")\n", + "plot_roc_curve(fpr_forest, tpr_forest, \"Random Forest\")\n", "plt.legend(loc=\"lower right\", fontsize=16)\n", "save_fig(\"roc_curve_comparison_plot\")\n", "plt.show()" @@ -764,7 +821,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 53, "metadata": { "collapsed": false, "deletable": true, @@ -777,7 +834,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 54, "metadata": { "collapsed": false, "deletable": true, @@ -791,7 +848,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 55, "metadata": { "collapsed": false, "deletable": true, @@ -814,7 +871,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 56, "metadata": { "collapsed": false, "deletable": true, @@ -828,7 +885,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 57, "metadata": { "collapsed": false, "deletable": true, @@ -842,7 +899,7 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 58, "metadata": { "collapsed": false, "deletable": true, @@ -855,7 +912,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 59, "metadata": { "collapsed": false, "deletable": true, @@ -868,7 +925,18 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 60, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "sgd_clf.classes_[5]" + ] + }, + { + "cell_type": "code", + "execution_count": 61, "metadata": { "collapsed": false, "deletable": true, @@ -884,7 +952,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 62, "metadata": { "collapsed": false, "deletable": true, @@ -897,7 +965,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 63, "metadata": { "collapsed": false, "deletable": true, @@ -911,7 +979,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 64, "metadata": { "collapsed": false, "deletable": true, @@ -924,7 +992,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 65, "metadata": { "collapsed": false, "deletable": true, @@ -937,7 +1005,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 66, "metadata": { "collapsed": false, "deletable": true, @@ -953,7 +1021,7 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 67, "metadata": { "collapsed": false, "deletable": true, @@ -968,7 +1036,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 68, "metadata": { "collapsed": false, "deletable": true, @@ -981,8 +1049,17 @@ " fig = plt.figure(figsize=(8,8))\n", " ax = fig.add_subplot(111)\n", " cax = ax.matshow(conf_mx)\n", - " fig.colorbar(cax)\n", - "\n", + " fig.colorbar(cax)" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ "plt.matshow(conf_mx, cmap=plt.cm.gray)\n", "save_fig(\"confusion_matrix_plot\", tight_layout=False)\n", "plt.show()" @@ -990,7 +1067,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 70, "metadata": { "collapsed": false, "deletable": true, @@ -999,7 +1076,17 @@ "outputs": [], "source": [ "row_sums = conf_mx.sum(axis=1, keepdims=True)\n", - "norm_conf_mx = conf_mx / row_sums\n", + "norm_conf_mx = conf_mx / row_sums" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ "np.fill_diagonal(norm_conf_mx, 0)\n", "plt.matshow(norm_conf_mx, cmap=plt.cm.gray)\n", "save_fig(\"confusion_matrix_errors_plot\", tight_layout=False)\n", @@ -1008,7 +1095,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 72, "metadata": { "collapsed": false, "deletable": true, @@ -1023,14 +1110,10 @@ "X_bb = X_train[(y_train == cl_b) & (y_train_pred == cl_b)]\n", "\n", "plt.figure(figsize=(8,8))\n", - "plt.subplot(221)\n", - "plot_digits(X_aa[:25], images_per_row=5)\n", - "plt.subplot(222)\n", - "plot_digits(X_ab[:25], images_per_row=5)\n", - "plt.subplot(223)\n", - "plot_digits(X_ba[:25], images_per_row=5)\n", - "plt.subplot(224)\n", - "plot_digits(X_bb[:25], images_per_row=5)\n", + "plt.subplot(221); plot_digits(X_aa[:25], images_per_row=5)\n", + "plt.subplot(222); plot_digits(X_ab[:25], images_per_row=5)\n", + "plt.subplot(223); plot_digits(X_ba[:25], images_per_row=5)\n", + "plt.subplot(224); plot_digits(X_bb[:25], images_per_row=5)\n", "save_fig(\"error_analysis_digits_plot\")\n", "plt.show()" ] @@ -1047,7 +1130,7 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 73, "metadata": { "collapsed": false, "deletable": true, @@ -1067,7 +1150,7 @@ }, { "cell_type": "code", - "execution_count": 60, + "execution_count": 74, "metadata": { "collapsed": false, "deletable": true, @@ -1080,20 +1163,7 @@ }, { "cell_type": "code", - "execution_count": 61, - "metadata": { - "collapsed": true, - "deletable": true, - "editable": true - }, - "outputs": [], - "source": [ - "y_train_knn_pred = cross_val_predict(knn_clf, X_train, y_train, cv=3)" - ] - }, - { - "cell_type": "code", - "execution_count": 62, + "execution_count": 77, "metadata": { "collapsed": false, "deletable": true, @@ -1101,6 +1171,7 @@ }, "outputs": [], "source": [ + "y_train_knn_pred = cross_val_predict(knn_clf, X_train, y_train, cv=3)\n", "f1_score(y_train, y_train_knn_pred, average=\"macro\")" ] }, @@ -1116,7 +1187,7 @@ }, { "cell_type": "code", - "execution_count": 63, + "execution_count": 78, "metadata": { "collapsed": false, "deletable": true, @@ -1134,7 +1205,7 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": 79, "metadata": { "collapsed": false, "deletable": true, @@ -1151,20 +1222,7 @@ }, { "cell_type": "code", - "execution_count": 65, - "metadata": { - "collapsed": false, - "deletable": true, - "editable": true - }, - "outputs": [], - "source": [ - "knn_clf.fit(X_train_mod, y_train_mod)" - ] - }, - { - "cell_type": "code", - "execution_count": 66, + "execution_count": 82, "metadata": { "collapsed": false, "deletable": true, @@ -1172,10 +1230,10 @@ }, "outputs": [], "source": [ + "knn_clf.fit(X_train_mod, y_train_mod)\n", "clean_digit = knn_clf.predict([X_test_mod[some_index]])\n", "plot_digit(clean_digit)\n", - "save_fig(\"cleaned_digit_example_plot\")\n", - "plt.show()" + "save_fig(\"cleaned_digit_example_plot\")" ] }, { @@ -1435,7 +1493,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.5.2+" + "version": "3.5.3" }, "nav_menu": {}, "toc": {