From 633436e8ae4a36960cecff8051fe08d89dbe2afc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Sun, 21 Nov 2021 17:06:37 +1300 Subject: [PATCH] Clarify the 'not in the book' comments --- 03_classification.ipynb | 62 +++++++++++++++++++---------------------- 1 file changed, 29 insertions(+), 33 deletions(-) diff --git a/03_classification.ipynb b/03_classification.ipynb index 832dbde..f2dfa63 100644 --- a/03_classification.ipynb +++ b/03_classification.ipynb @@ -140,7 +140,7 @@ "metadata": {}, "outputs": [], "source": [ - "# not in the book\n", + "# not in the book – it's a bit too long\n", "print(mnist.DESCR)" ] }, @@ -150,8 +150,7 @@ "metadata": {}, "outputs": [], "source": [ - "# not in the book\n", - "mnist.keys()" + "mnist.keys() # not in the book – we only use data and target in this notebook" ] }, { @@ -234,7 +233,7 @@ "metadata": {}, "outputs": [], "source": [ - "# not in the book\n", + "# not in the book – this code generates Figure 3–2\n", "plt.figure(figsize=(9, 9))\n", "for idx, image_data in enumerate(X[:100]):\n", " plt.subplot(10, 10, idx + 1)\n", @@ -388,7 +387,8 @@ "source": [ "from sklearn.metrics import confusion_matrix\n", "\n", - "confusion_matrix(y_train_5, y_train_pred)" + "cm = confusion_matrix(y_train_5, y_train_pred)\n", + "cm" ] }, { @@ -425,11 +425,7 @@ "metadata": {}, "outputs": [], "source": [ - "# not in the book\n", - "\n", - "cm = confusion_matrix(y_train_5, y_train_pred)\n", - "\n", - "# Precision = TP / (FP + TP)\n", + "# not in the book – this code also computes the precision: TP / (FP + TP)\n", "cm[1, 1] / (cm[0, 1] + cm[1, 1])" ] }, @@ -448,9 +444,7 @@ "metadata": {}, "outputs": [], "source": [ - "# not in the book\n", - "\n", - "# Recall = TP / (FN + TP)\n", + "# not in the book – this code also computes the recall: TP / (FN + TP)\n", "cm[1, 1] / (cm[1, 0] + cm[1, 1])" ] }, @@ -471,6 +465,7 @@ "metadata": {}, "outputs": [], "source": [ + "# not in the book – this code also computes the f1 score\n", "cm[1, 1] / (cm[1, 1] + (cm[1, 0] + cm[0, 1]) / 2)" ] }, @@ -516,9 +511,9 @@ "metadata": {}, "outputs": [], "source": [ - "# not in the book\n", - "# Using threshold 0, we get exactly the same predictions as with predict()\n", - "(y_train_pred == (y_scores > 0)).all()" + "# not in the book – this code just shows that y_scores > 0 produces the same\n", + "# result as calling predict()\n", + "y_scores > 0" ] }, { @@ -559,13 +554,12 @@ "metadata": {}, "outputs": [], "source": [ - "plt.figure(figsize=(8, 4)) # not in the book\n", + "plt.figure(figsize=(8, 4)) # not in the book – it's not needed, just formatting\n", "plt.plot(thresholds, precisions[:-1], \"b--\", label=\"Precision\", linewidth=2)\n", "plt.plot(thresholds, recalls[:-1], \"g-\", label=\"Recall\", linewidth=2)\n", "plt.vlines(threshold, 0, 1.0, \"k\", \"dotted\", label=\"threshold\")\n", "\n", - "# not in the book\n", - "# beautify the figure\n", + "# not in the book – this section just beautifies and saves Figure 3–5\n", "idx = (thresholds >= threshold).argmax() # first index ≥ threshold\n", "plt.plot(thresholds[idx], precisions[idx], \"bo\")\n", "plt.plot(thresholds[idx], recalls[idx], \"go\")\n", @@ -584,13 +578,13 @@ "metadata": {}, "outputs": [], "source": [ - "import matplotlib.patches as patches # not in the book\n", + "import matplotlib.patches as patches # not in the book – for the curved arrow\n", "\n", - "plt.figure(figsize=(6, 5)) # not in the book\n", + "plt.figure(figsize=(6, 5)) # not in the book – not needed, just formatting\n", "\n", "plt.plot(recalls, precisions, linewidth=2, label=\"Precision/Recall curve\")\n", "\n", - "# not in the book (just beautifies the figure)\n", + "# not in the book – just beautifies and saves Figure 3–6\n", "plt.plot([recalls[idx], recalls[idx]], [0., precisions[idx]], \"k:\")\n", "plt.plot([0.0, recalls[idx]], [precisions[idx], precisions[idx]], \"k:\")\n", "plt.plot([recalls[idx]], [precisions[idx]], \"ko\",\n", @@ -677,12 +671,12 @@ "idx_for_threshold_at_90 = (thresholds <= threshold_for_90_precision).argmax()\n", "tpr_90, fpr_90 = tpr[idx_for_threshold_at_90], fpr[idx_for_threshold_at_90]\n", "\n", - "plt.figure(figsize=(6, 5)) # not in the book\n", + "plt.figure(figsize=(6, 5)) # not in the book – not needed, just formatting\n", "plt.plot(fpr, tpr, linewidth=2, label=\"ROC curve\")\n", "plt.plot([0, 1], [0, 1], 'k:', label=\"Random classifier's ROC curve\")\n", "plt.plot([fpr_90], [tpr_90], \"ko\", label=\"Threshold for 90% precision\")\n", "\n", - "# not in the book (just beautifies the figure)\n", + "# not in the book – just beautifies and saves Figure 3–7\n", "plt.gca().add_patch(patches.FancyArrowPatch(\n", " (0.20, 0.89), (0.07, 0.70),\n", " connectionstyle=\"arc3,rad=.4\",\n", @@ -782,13 +776,13 @@ "metadata": {}, "outputs": [], "source": [ - "plt.figure(figsize=(6, 5)) # not in the book\n", + "plt.figure(figsize=(6, 5)) # not in the book – not needed, just formatting\n", "\n", "plt.plot(fpr_forest, tpr_forest, \"b-\", linewidth=2, label=\"Random Forest\")\n", "plt.plot(fpr, tpr, \"--\", linewidth=2, label=\"SGD\")\n", "plt.plot([0, 1], [0, 1], 'k:', label=\"Random classifier\")\n", "\n", - "# not in the book\n", + "# not in the book – just beautifies and saves Figure 3–8\n", "plt.xlabel('False Positive Rate (Fall-Out)')\n", "plt.ylabel('True Positive Rate (Recall)')\n", "plt.grid()\n", @@ -920,7 +914,7 @@ "metadata": {}, "outputs": [], "source": [ - "# not in the book\n", + "# not in the book – this code shows how to get all 45 OvO scores if needed\n", "svm_clf.decision_function_shape = \"ovo\"\n", "some_digit_scores_ovo = svm_clf.decision_function([some_digit])\n", "some_digit_scores_ovo.round(2)" @@ -1069,7 +1063,7 @@ "metadata": {}, "outputs": [], "source": [ - "# not in the book\n", + "# not in the book – this code generates Figure 3–9\n", "fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(9, 8))\n", "ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred, ax=axs[0, 0])\n", "axs[0, 0].set_title(\"Confusion matrix\")\n", @@ -1084,7 +1078,7 @@ " sample_weight=sample_weight,\n", " normalize=\"pred\", values_format=\".0%\")\n", "axs[1, 1].set_title(\"Errors normalized by column\")\n", - "save_fig(\"confusion_matrix_plot\") # not in the book\n", + "save_fig(\"confusion_matrix_plot\")\n", "plt.show()" ] }, @@ -1107,7 +1101,7 @@ "metadata": {}, "outputs": [], "source": [ - "# not in the book\n", + "# not in the book – this code generates Figure 3–10\n", "size = 5\n", "pad = 0.2\n", "plt.figure(figsize=(size, size))\n", @@ -1200,7 +1194,9 @@ "metadata": {}, "outputs": [], "source": [ - "# not in the book\n", + "# not in the book – this code shows that we get a negligible performance\n", + "# improvement when we set average=\"weighted\" because the\n", + "# classes are already pretty well balanced.\n", "f1_score(y_multilabel, y_train_knn_pred, average=\"weighted\")" ] }, @@ -1253,7 +1249,7 @@ "metadata": {}, "outputs": [], "source": [ - "# not in the book\n", + "# not in the book – this code generates Figure 3–11\n", "plt.subplot(121); plot_digit(X_test_mod[0])\n", "plt.subplot(122); plot_digit(y_test_mod[0])\n", "save_fig(\"noisy_digit_example_plot\")\n", @@ -1270,7 +1266,7 @@ "knn_clf.fit(X_train_mod, y_train_mod)\n", "clean_digit = knn_clf.predict([X_test_mod[0]])\n", "plot_digit(clean_digit)\n", - "save_fig(\"cleaned_digit_example_plot\") # not in the book\n", + "save_fig(\"cleaned_digit_example_plot\") # not in the book – saves Figure 3–12\n", "plt.show()" ] },