Clarify the 'not in the book' comments
parent
c46123155d
commit
633436e8ae
|
@ -140,7 +140,7 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# not in the book\n",
|
"# not in the book – it's a bit too long\n",
|
||||||
"print(mnist.DESCR)"
|
"print(mnist.DESCR)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -150,8 +150,7 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# not in the book\n",
|
"mnist.keys() # not in the book – we only use data and target in this notebook"
|
||||||
"mnist.keys()"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -234,7 +233,7 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# not in the book\n",
|
"# not in the book – this code generates Figure 3–2\n",
|
||||||
"plt.figure(figsize=(9, 9))\n",
|
"plt.figure(figsize=(9, 9))\n",
|
||||||
"for idx, image_data in enumerate(X[:100]):\n",
|
"for idx, image_data in enumerate(X[:100]):\n",
|
||||||
" plt.subplot(10, 10, idx + 1)\n",
|
" plt.subplot(10, 10, idx + 1)\n",
|
||||||
|
@ -388,7 +387,8 @@
|
||||||
"source": [
|
"source": [
|
||||||
"from sklearn.metrics import confusion_matrix\n",
|
"from sklearn.metrics import confusion_matrix\n",
|
||||||
"\n",
|
"\n",
|
||||||
"confusion_matrix(y_train_5, y_train_pred)"
|
"cm = confusion_matrix(y_train_5, y_train_pred)\n",
|
||||||
|
"cm"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -425,11 +425,7 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# not in the book\n",
|
"# not in the book – this code also computes the precision: TP / (FP + TP)\n",
|
||||||
"\n",
|
|
||||||
"cm = confusion_matrix(y_train_5, y_train_pred)\n",
|
|
||||||
"\n",
|
|
||||||
"# Precision = TP / (FP + TP)\n",
|
|
||||||
"cm[1, 1] / (cm[0, 1] + cm[1, 1])"
|
"cm[1, 1] / (cm[0, 1] + cm[1, 1])"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -448,9 +444,7 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# not in the book\n",
|
"# not in the book – this code also computes the recall: TP / (FN + TP)\n",
|
||||||
"\n",
|
|
||||||
"# Recall = TP / (FN + TP)\n",
|
|
||||||
"cm[1, 1] / (cm[1, 0] + cm[1, 1])"
|
"cm[1, 1] / (cm[1, 0] + cm[1, 1])"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -471,6 +465,7 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
"# not in the book – this code also computes the f1 score\n",
|
||||||
"cm[1, 1] / (cm[1, 1] + (cm[1, 0] + cm[0, 1]) / 2)"
|
"cm[1, 1] / (cm[1, 1] + (cm[1, 0] + cm[0, 1]) / 2)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -516,9 +511,9 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# not in the book\n",
|
"# not in the book – this code just shows that y_scores > 0 produces the same\n",
|
||||||
"# Using threshold 0, we get exactly the same predictions as with predict()\n",
|
"# result as calling predict()\n",
|
||||||
"(y_train_pred == (y_scores > 0)).all()"
|
"y_scores > 0"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -559,13 +554,12 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"plt.figure(figsize=(8, 4)) # not in the book\n",
|
"plt.figure(figsize=(8, 4)) # not in the book – it's not needed, just formatting\n",
|
||||||
"plt.plot(thresholds, precisions[:-1], \"b--\", label=\"Precision\", linewidth=2)\n",
|
"plt.plot(thresholds, precisions[:-1], \"b--\", label=\"Precision\", linewidth=2)\n",
|
||||||
"plt.plot(thresholds, recalls[:-1], \"g-\", label=\"Recall\", linewidth=2)\n",
|
"plt.plot(thresholds, recalls[:-1], \"g-\", label=\"Recall\", linewidth=2)\n",
|
||||||
"plt.vlines(threshold, 0, 1.0, \"k\", \"dotted\", label=\"threshold\")\n",
|
"plt.vlines(threshold, 0, 1.0, \"k\", \"dotted\", label=\"threshold\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# not in the book\n",
|
"# not in the book – this section just beautifies and saves Figure 3–5\n",
|
||||||
"# beautify the figure\n",
|
|
||||||
"idx = (thresholds >= threshold).argmax() # first index ≥ threshold\n",
|
"idx = (thresholds >= threshold).argmax() # first index ≥ threshold\n",
|
||||||
"plt.plot(thresholds[idx], precisions[idx], \"bo\")\n",
|
"plt.plot(thresholds[idx], precisions[idx], \"bo\")\n",
|
||||||
"plt.plot(thresholds[idx], recalls[idx], \"go\")\n",
|
"plt.plot(thresholds[idx], recalls[idx], \"go\")\n",
|
||||||
|
@ -584,13 +578,13 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import matplotlib.patches as patches # not in the book\n",
|
"import matplotlib.patches as patches # not in the book – for the curved arrow\n",
|
||||||
"\n",
|
"\n",
|
||||||
"plt.figure(figsize=(6, 5)) # not in the book\n",
|
"plt.figure(figsize=(6, 5)) # not in the book – not needed, just formatting\n",
|
||||||
"\n",
|
"\n",
|
||||||
"plt.plot(recalls, precisions, linewidth=2, label=\"Precision/Recall curve\")\n",
|
"plt.plot(recalls, precisions, linewidth=2, label=\"Precision/Recall curve\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# not in the book (just beautifies the figure)\n",
|
"# not in the book – just beautifies and saves Figure 3–6\n",
|
||||||
"plt.plot([recalls[idx], recalls[idx]], [0., precisions[idx]], \"k:\")\n",
|
"plt.plot([recalls[idx], recalls[idx]], [0., precisions[idx]], \"k:\")\n",
|
||||||
"plt.plot([0.0, recalls[idx]], [precisions[idx], precisions[idx]], \"k:\")\n",
|
"plt.plot([0.0, recalls[idx]], [precisions[idx], precisions[idx]], \"k:\")\n",
|
||||||
"plt.plot([recalls[idx]], [precisions[idx]], \"ko\",\n",
|
"plt.plot([recalls[idx]], [precisions[idx]], \"ko\",\n",
|
||||||
|
@ -677,12 +671,12 @@
|
||||||
"idx_for_threshold_at_90 = (thresholds <= threshold_for_90_precision).argmax()\n",
|
"idx_for_threshold_at_90 = (thresholds <= threshold_for_90_precision).argmax()\n",
|
||||||
"tpr_90, fpr_90 = tpr[idx_for_threshold_at_90], fpr[idx_for_threshold_at_90]\n",
|
"tpr_90, fpr_90 = tpr[idx_for_threshold_at_90], fpr[idx_for_threshold_at_90]\n",
|
||||||
"\n",
|
"\n",
|
||||||
"plt.figure(figsize=(6, 5)) # not in the book\n",
|
"plt.figure(figsize=(6, 5)) # not in the book – not needed, just formatting\n",
|
||||||
"plt.plot(fpr, tpr, linewidth=2, label=\"ROC curve\")\n",
|
"plt.plot(fpr, tpr, linewidth=2, label=\"ROC curve\")\n",
|
||||||
"plt.plot([0, 1], [0, 1], 'k:', label=\"Random classifier's ROC curve\")\n",
|
"plt.plot([0, 1], [0, 1], 'k:', label=\"Random classifier's ROC curve\")\n",
|
||||||
"plt.plot([fpr_90], [tpr_90], \"ko\", label=\"Threshold for 90% precision\")\n",
|
"plt.plot([fpr_90], [tpr_90], \"ko\", label=\"Threshold for 90% precision\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# not in the book (just beautifies the figure)\n",
|
"# not in the book – just beautifies and saves Figure 3–7\n",
|
||||||
"plt.gca().add_patch(patches.FancyArrowPatch(\n",
|
"plt.gca().add_patch(patches.FancyArrowPatch(\n",
|
||||||
" (0.20, 0.89), (0.07, 0.70),\n",
|
" (0.20, 0.89), (0.07, 0.70),\n",
|
||||||
" connectionstyle=\"arc3,rad=.4\",\n",
|
" connectionstyle=\"arc3,rad=.4\",\n",
|
||||||
|
@ -782,13 +776,13 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"plt.figure(figsize=(6, 5)) # not in the book\n",
|
"plt.figure(figsize=(6, 5)) # not in the book – not needed, just formatting\n",
|
||||||
"\n",
|
"\n",
|
||||||
"plt.plot(fpr_forest, tpr_forest, \"b-\", linewidth=2, label=\"Random Forest\")\n",
|
"plt.plot(fpr_forest, tpr_forest, \"b-\", linewidth=2, label=\"Random Forest\")\n",
|
||||||
"plt.plot(fpr, tpr, \"--\", linewidth=2, label=\"SGD\")\n",
|
"plt.plot(fpr, tpr, \"--\", linewidth=2, label=\"SGD\")\n",
|
||||||
"plt.plot([0, 1], [0, 1], 'k:', label=\"Random classifier\")\n",
|
"plt.plot([0, 1], [0, 1], 'k:', label=\"Random classifier\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# not in the book\n",
|
"# not in the book – just beautifies and saves Figure 3–8\n",
|
||||||
"plt.xlabel('False Positive Rate (Fall-Out)')\n",
|
"plt.xlabel('False Positive Rate (Fall-Out)')\n",
|
||||||
"plt.ylabel('True Positive Rate (Recall)')\n",
|
"plt.ylabel('True Positive Rate (Recall)')\n",
|
||||||
"plt.grid()\n",
|
"plt.grid()\n",
|
||||||
|
@ -920,7 +914,7 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# not in the book\n",
|
"# not in the book – this code shows how to get all 45 OvO scores if needed\n",
|
||||||
"svm_clf.decision_function_shape = \"ovo\"\n",
|
"svm_clf.decision_function_shape = \"ovo\"\n",
|
||||||
"some_digit_scores_ovo = svm_clf.decision_function([some_digit])\n",
|
"some_digit_scores_ovo = svm_clf.decision_function([some_digit])\n",
|
||||||
"some_digit_scores_ovo.round(2)"
|
"some_digit_scores_ovo.round(2)"
|
||||||
|
@ -1069,7 +1063,7 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# not in the book\n",
|
"# not in the book – this code generates Figure 3–9\n",
|
||||||
"fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(9, 8))\n",
|
"fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(9, 8))\n",
|
||||||
"ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred, ax=axs[0, 0])\n",
|
"ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred, ax=axs[0, 0])\n",
|
||||||
"axs[0, 0].set_title(\"Confusion matrix\")\n",
|
"axs[0, 0].set_title(\"Confusion matrix\")\n",
|
||||||
|
@ -1084,7 +1078,7 @@
|
||||||
" sample_weight=sample_weight,\n",
|
" sample_weight=sample_weight,\n",
|
||||||
" normalize=\"pred\", values_format=\".0%\")\n",
|
" normalize=\"pred\", values_format=\".0%\")\n",
|
||||||
"axs[1, 1].set_title(\"Errors normalized by column\")\n",
|
"axs[1, 1].set_title(\"Errors normalized by column\")\n",
|
||||||
"save_fig(\"confusion_matrix_plot\") # not in the book\n",
|
"save_fig(\"confusion_matrix_plot\")\n",
|
||||||
"plt.show()"
|
"plt.show()"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -1107,7 +1101,7 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# not in the book\n",
|
"# not in the book – this code generates Figure 3–10\n",
|
||||||
"size = 5\n",
|
"size = 5\n",
|
||||||
"pad = 0.2\n",
|
"pad = 0.2\n",
|
||||||
"plt.figure(figsize=(size, size))\n",
|
"plt.figure(figsize=(size, size))\n",
|
||||||
|
@ -1200,7 +1194,9 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# not in the book\n",
|
"# not in the book – this code shows that we get a negligible performance\n",
|
||||||
|
"# improvement when we set average=\"weighted\" because the\n",
|
||||||
|
"# classes are already pretty well balanced.\n",
|
||||||
"f1_score(y_multilabel, y_train_knn_pred, average=\"weighted\")"
|
"f1_score(y_multilabel, y_train_knn_pred, average=\"weighted\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -1253,7 +1249,7 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# not in the book\n",
|
"# not in the book – this code generates Figure 3–11\n",
|
||||||
"plt.subplot(121); plot_digit(X_test_mod[0])\n",
|
"plt.subplot(121); plot_digit(X_test_mod[0])\n",
|
||||||
"plt.subplot(122); plot_digit(y_test_mod[0])\n",
|
"plt.subplot(122); plot_digit(y_test_mod[0])\n",
|
||||||
"save_fig(\"noisy_digit_example_plot\")\n",
|
"save_fig(\"noisy_digit_example_plot\")\n",
|
||||||
|
@ -1270,7 +1266,7 @@
|
||||||
"knn_clf.fit(X_train_mod, y_train_mod)\n",
|
"knn_clf.fit(X_train_mod, y_train_mod)\n",
|
||||||
"clean_digit = knn_clf.predict([X_test_mod[0]])\n",
|
"clean_digit = knn_clf.predict([X_test_mod[0]])\n",
|
||||||
"plot_digit(clean_digit)\n",
|
"plot_digit(clean_digit)\n",
|
||||||
"save_fig(\"cleaned_digit_example_plot\") # not in the book\n",
|
"save_fig(\"cleaned_digit_example_plot\") # not in the book – saves Figure 3–12\n",
|
||||||
"plt.show()"
|
"plt.show()"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
|
Loading…
Reference in New Issue