Clarify the 'not in the book' comments

main
Aurélien Geron 2021-11-21 17:06:37 +13:00
parent c46123155d
commit 633436e8ae
1 changed files with 29 additions and 33 deletions

View File

@ -140,7 +140,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book\n", "# not in the book it's a bit too long\n",
"print(mnist.DESCR)" "print(mnist.DESCR)"
] ]
}, },
@ -150,8 +150,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book\n", "mnist.keys() # not in the book we only use data and target in this notebook"
"mnist.keys()"
] ]
}, },
{ {
@ -234,7 +233,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book\n", "# not in the book this code generates Figure 32\n",
"plt.figure(figsize=(9, 9))\n", "plt.figure(figsize=(9, 9))\n",
"for idx, image_data in enumerate(X[:100]):\n", "for idx, image_data in enumerate(X[:100]):\n",
" plt.subplot(10, 10, idx + 1)\n", " plt.subplot(10, 10, idx + 1)\n",
@ -388,7 +387,8 @@
"source": [ "source": [
"from sklearn.metrics import confusion_matrix\n", "from sklearn.metrics import confusion_matrix\n",
"\n", "\n",
"confusion_matrix(y_train_5, y_train_pred)" "cm = confusion_matrix(y_train_5, y_train_pred)\n",
"cm"
] ]
}, },
{ {
@ -425,11 +425,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book\n", "# not in the book this code also computes the precision: TP / (FP + TP)\n",
"\n",
"cm = confusion_matrix(y_train_5, y_train_pred)\n",
"\n",
"# Precision = TP / (FP + TP)\n",
"cm[1, 1] / (cm[0, 1] + cm[1, 1])" "cm[1, 1] / (cm[0, 1] + cm[1, 1])"
] ]
}, },
@ -448,9 +444,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book\n", "# not in the book this code also computes the recall: TP / (FN + TP)\n",
"\n",
"# Recall = TP / (FN + TP)\n",
"cm[1, 1] / (cm[1, 0] + cm[1, 1])" "cm[1, 1] / (cm[1, 0] + cm[1, 1])"
] ]
}, },
@ -471,6 +465,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book this code also computes the f1 score\n",
"cm[1, 1] / (cm[1, 1] + (cm[1, 0] + cm[0, 1]) / 2)" "cm[1, 1] / (cm[1, 1] + (cm[1, 0] + cm[0, 1]) / 2)"
] ]
}, },
@ -516,9 +511,9 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book\n", "# not in the book this code just shows that y_scores > 0 produces the same\n",
"# Using threshold 0, we get exactly the same predictions as with predict()\n", "# result as calling predict()\n",
"(y_train_pred == (y_scores > 0)).all()" "y_scores > 0"
] ]
}, },
{ {
@ -559,13 +554,12 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"plt.figure(figsize=(8, 4)) # not in the book\n", "plt.figure(figsize=(8, 4)) # not in the book it's not needed, just formatting\n",
"plt.plot(thresholds, precisions[:-1], \"b--\", label=\"Precision\", linewidth=2)\n", "plt.plot(thresholds, precisions[:-1], \"b--\", label=\"Precision\", linewidth=2)\n",
"plt.plot(thresholds, recalls[:-1], \"g-\", label=\"Recall\", linewidth=2)\n", "plt.plot(thresholds, recalls[:-1], \"g-\", label=\"Recall\", linewidth=2)\n",
"plt.vlines(threshold, 0, 1.0, \"k\", \"dotted\", label=\"threshold\")\n", "plt.vlines(threshold, 0, 1.0, \"k\", \"dotted\", label=\"threshold\")\n",
"\n", "\n",
"# not in the book\n", "# not in the book this section just beautifies and saves Figure 35\n",
"# beautify the figure\n",
"idx = (thresholds >= threshold).argmax() # first index ≥ threshold\n", "idx = (thresholds >= threshold).argmax() # first index ≥ threshold\n",
"plt.plot(thresholds[idx], precisions[idx], \"bo\")\n", "plt.plot(thresholds[idx], precisions[idx], \"bo\")\n",
"plt.plot(thresholds[idx], recalls[idx], \"go\")\n", "plt.plot(thresholds[idx], recalls[idx], \"go\")\n",
@ -584,13 +578,13 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import matplotlib.patches as patches # not in the book\n", "import matplotlib.patches as patches # not in the book for the curved arrow\n",
"\n", "\n",
"plt.figure(figsize=(6, 5)) # not in the book\n", "plt.figure(figsize=(6, 5)) # not in the book not needed, just formatting\n",
"\n", "\n",
"plt.plot(recalls, precisions, linewidth=2, label=\"Precision/Recall curve\")\n", "plt.plot(recalls, precisions, linewidth=2, label=\"Precision/Recall curve\")\n",
"\n", "\n",
"# not in the book (just beautifies the figure)\n", "# not in the book just beautifies and saves Figure 36\n",
"plt.plot([recalls[idx], recalls[idx]], [0., precisions[idx]], \"k:\")\n", "plt.plot([recalls[idx], recalls[idx]], [0., precisions[idx]], \"k:\")\n",
"plt.plot([0.0, recalls[idx]], [precisions[idx], precisions[idx]], \"k:\")\n", "plt.plot([0.0, recalls[idx]], [precisions[idx], precisions[idx]], \"k:\")\n",
"plt.plot([recalls[idx]], [precisions[idx]], \"ko\",\n", "plt.plot([recalls[idx]], [precisions[idx]], \"ko\",\n",
@ -677,12 +671,12 @@
"idx_for_threshold_at_90 = (thresholds <= threshold_for_90_precision).argmax()\n", "idx_for_threshold_at_90 = (thresholds <= threshold_for_90_precision).argmax()\n",
"tpr_90, fpr_90 = tpr[idx_for_threshold_at_90], fpr[idx_for_threshold_at_90]\n", "tpr_90, fpr_90 = tpr[idx_for_threshold_at_90], fpr[idx_for_threshold_at_90]\n",
"\n", "\n",
"plt.figure(figsize=(6, 5)) # not in the book\n", "plt.figure(figsize=(6, 5)) # not in the book not needed, just formatting\n",
"plt.plot(fpr, tpr, linewidth=2, label=\"ROC curve\")\n", "plt.plot(fpr, tpr, linewidth=2, label=\"ROC curve\")\n",
"plt.plot([0, 1], [0, 1], 'k:', label=\"Random classifier's ROC curve\")\n", "plt.plot([0, 1], [0, 1], 'k:', label=\"Random classifier's ROC curve\")\n",
"plt.plot([fpr_90], [tpr_90], \"ko\", label=\"Threshold for 90% precision\")\n", "plt.plot([fpr_90], [tpr_90], \"ko\", label=\"Threshold for 90% precision\")\n",
"\n", "\n",
"# not in the book (just beautifies the figure)\n", "# not in the book just beautifies and saves Figure 37\n",
"plt.gca().add_patch(patches.FancyArrowPatch(\n", "plt.gca().add_patch(patches.FancyArrowPatch(\n",
" (0.20, 0.89), (0.07, 0.70),\n", " (0.20, 0.89), (0.07, 0.70),\n",
" connectionstyle=\"arc3,rad=.4\",\n", " connectionstyle=\"arc3,rad=.4\",\n",
@ -782,13 +776,13 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"plt.figure(figsize=(6, 5)) # not in the book\n", "plt.figure(figsize=(6, 5)) # not in the book not needed, just formatting\n",
"\n", "\n",
"plt.plot(fpr_forest, tpr_forest, \"b-\", linewidth=2, label=\"Random Forest\")\n", "plt.plot(fpr_forest, tpr_forest, \"b-\", linewidth=2, label=\"Random Forest\")\n",
"plt.plot(fpr, tpr, \"--\", linewidth=2, label=\"SGD\")\n", "plt.plot(fpr, tpr, \"--\", linewidth=2, label=\"SGD\")\n",
"plt.plot([0, 1], [0, 1], 'k:', label=\"Random classifier\")\n", "plt.plot([0, 1], [0, 1], 'k:', label=\"Random classifier\")\n",
"\n", "\n",
"# not in the book\n", "# not in the book just beautifies and saves Figure 38\n",
"plt.xlabel('False Positive Rate (Fall-Out)')\n", "plt.xlabel('False Positive Rate (Fall-Out)')\n",
"plt.ylabel('True Positive Rate (Recall)')\n", "plt.ylabel('True Positive Rate (Recall)')\n",
"plt.grid()\n", "plt.grid()\n",
@ -920,7 +914,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book\n", "# not in the book this code shows how to get all 45 OvO scores if needed\n",
"svm_clf.decision_function_shape = \"ovo\"\n", "svm_clf.decision_function_shape = \"ovo\"\n",
"some_digit_scores_ovo = svm_clf.decision_function([some_digit])\n", "some_digit_scores_ovo = svm_clf.decision_function([some_digit])\n",
"some_digit_scores_ovo.round(2)" "some_digit_scores_ovo.round(2)"
@ -1069,7 +1063,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book\n", "# not in the book this code generates Figure 39\n",
"fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(9, 8))\n", "fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(9, 8))\n",
"ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred, ax=axs[0, 0])\n", "ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred, ax=axs[0, 0])\n",
"axs[0, 0].set_title(\"Confusion matrix\")\n", "axs[0, 0].set_title(\"Confusion matrix\")\n",
@ -1084,7 +1078,7 @@
" sample_weight=sample_weight,\n", " sample_weight=sample_weight,\n",
" normalize=\"pred\", values_format=\".0%\")\n", " normalize=\"pred\", values_format=\".0%\")\n",
"axs[1, 1].set_title(\"Errors normalized by column\")\n", "axs[1, 1].set_title(\"Errors normalized by column\")\n",
"save_fig(\"confusion_matrix_plot\") # not in the book\n", "save_fig(\"confusion_matrix_plot\")\n",
"plt.show()" "plt.show()"
] ]
}, },
@ -1107,7 +1101,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book\n", "# not in the book this code generates Figure 310\n",
"size = 5\n", "size = 5\n",
"pad = 0.2\n", "pad = 0.2\n",
"plt.figure(figsize=(size, size))\n", "plt.figure(figsize=(size, size))\n",
@ -1200,7 +1194,9 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book\n", "# not in the book this code shows that we get a negligible performance\n",
"# improvement when we set average=\"weighted\" because the\n",
"# classes are already pretty well balanced.\n",
"f1_score(y_multilabel, y_train_knn_pred, average=\"weighted\")" "f1_score(y_multilabel, y_train_knn_pred, average=\"weighted\")"
] ]
}, },
@ -1253,7 +1249,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book\n", "# not in the book this code generates Figure 311\n",
"plt.subplot(121); plot_digit(X_test_mod[0])\n", "plt.subplot(121); plot_digit(X_test_mod[0])\n",
"plt.subplot(122); plot_digit(y_test_mod[0])\n", "plt.subplot(122); plot_digit(y_test_mod[0])\n",
"save_fig(\"noisy_digit_example_plot\")\n", "save_fig(\"noisy_digit_example_plot\")\n",
@ -1270,7 +1266,7 @@
"knn_clf.fit(X_train_mod, y_train_mod)\n", "knn_clf.fit(X_train_mod, y_train_mod)\n",
"clean_digit = knn_clf.predict([X_test_mod[0]])\n", "clean_digit = knn_clf.predict([X_test_mod[0]])\n",
"plot_digit(clean_digit)\n", "plot_digit(clean_digit)\n",
"save_fig(\"cleaned_digit_example_plot\") # not in the book\n", "save_fig(\"cleaned_digit_example_plot\") # not in the book saves Figure 312\n",
"plt.show()" "plt.show()"
] ]
}, },