diff --git a/03_classification.ipynb b/03_classification.ipynb index ca360c5..d67dbf9 100644 --- a/03_classification.ipynb +++ b/03_classification.ipynb @@ -766,8 +766,8 @@ "outputs": [], "source": [ "y_scores_forest = y_probas_forest[:, 1]\n", - "fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5,\n", - " y_scores_forest)" + "precisions_forest, recalls_forest, thresholds_forest = precision_recall_curve(\n", + " y_train_5, y_scores_forest)" ] }, { @@ -778,30 +778,21 @@ "source": [ "plt.figure(figsize=(6, 5)) # not in the book – not needed, just formatting\n", "\n", - "plt.plot(fpr_forest, tpr_forest, \"b-\", linewidth=2, label=\"Random Forest\")\n", - "plt.plot(fpr, tpr, \"--\", linewidth=2, label=\"SGD\")\n", - "plt.plot([0, 1], [0, 1], 'k:', label=\"Random classifier\")\n", + "plt.plot(recalls_forest, precisions_forest, \"b-\", linewidth=2,\n", + " label=\"Random Forest\")\n", + "plt.plot(recalls, precisions, \"--\", linewidth=2, label=\"SGD\")\n", "\n", - "# not in the book – just beautifies and saves Figure 3–8\n", - "plt.xlabel('False Positive Rate (Fall-Out)')\n", - "plt.ylabel('True Positive Rate (Recall)')\n", - "plt.grid()\n", + "# not in the book – just beautifies and saves Figure 3–8\n", + "plt.xlabel(\"Recall\")\n", + "plt.ylabel(\"Precision\")\n", "plt.axis([0, 1, 0, 1])\n", - "plt.legend(loc=\"lower right\")\n", - "save_fig(\"roc_curve_comparison_plot\")\n", + "plt.grid()\n", + "plt.legend(loc=\"lower left\")\n", + "save_fig(\"pr_curve_comparison_plot\")\n", "\n", "plt.show()" ] }, - { - "cell_type": "code", - "execution_count": 55, - "metadata": {}, - "outputs": [], - "source": [ - "roc_auc_score(y_train_5, y_scores_forest)" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -809,14 +800,23 @@ "We could use `cross_val_predict(forest_clf, X_train, y_train_5, cv=3)` to compute `y_train_pred_forest`, but since we already have the estimated probabilities, we can just use the default threshold of 50% probability to get the same predictions much faster:" ] }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [], + "source": [ + "y_train_pred_forest = y_probas_forest[:, 1] >= 0.5 # positive proba ≥ 50%\n", + "f1_score(y_train_5, y_train_pred_forest)" + ] + }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [ - "y_train_pred_forest = y_probas_forest[:, 1] > 0.5\n", - "precision_score(y_train_5, y_train_pred_forest)" + "roc_auc_score(y_train_5, y_scores_forest)" ] }, { @@ -824,6 +824,15 @@ "execution_count": 57, "metadata": {}, "outputs": [], + "source": [ + "precision_score(y_train_5, y_train_pred_forest)" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [], "source": [ "recall_score(y_train_5, y_train_pred_forest)" ] @@ -844,7 +853,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 59, "metadata": {}, "outputs": [], "source": [ @@ -856,7 +865,7 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 60, "metadata": {}, "outputs": [], "source": [ @@ -865,7 +874,7 @@ }, { "cell_type": "code", - "execution_count": 60, + "execution_count": 61, "metadata": {}, "outputs": [], "source": [ @@ -875,7 +884,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": 62, "metadata": {}, "outputs": [], "source": [ @@ -885,7 +894,7 @@ }, { "cell_type": "code", - "execution_count": 62, + "execution_count": 63, "metadata": {}, "outputs": [], "source": [ @@ -894,7 +903,7 @@ }, { "cell_type": "code", - "execution_count": 63, + "execution_count": 64, "metadata": {}, "outputs": [], "source": [ @@ -910,7 +919,7 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": 65, "metadata": {}, "outputs": [], "source": [ @@ -922,7 +931,7 @@ }, { "cell_type": "code", - "execution_count": 65, + "execution_count": 66, "metadata": {}, "outputs": [], "source": [ @@ -934,7 +943,7 @@ }, { "cell_type": "code", - "execution_count": 66, + "execution_count": 67, "metadata": {}, "outputs": [], "source": [ @@ -943,7 +952,7 @@ }, { "cell_type": "code", - "execution_count": 67, + "execution_count": 68, "metadata": {}, "outputs": [], "source": [ @@ -952,7 +961,7 @@ }, { "cell_type": "code", - "execution_count": 68, + "execution_count": 69, "metadata": {}, "outputs": [], "source": [ @@ -963,7 +972,7 @@ }, { "cell_type": "code", - "execution_count": 69, + "execution_count": 70, "metadata": {}, "outputs": [], "source": [ @@ -979,7 +988,7 @@ }, { "cell_type": "code", - "execution_count": 70, + "execution_count": 71, "metadata": {}, "outputs": [], "source": [ @@ -988,7 +997,7 @@ }, { "cell_type": "code", - "execution_count": 71, + "execution_count": 72, "metadata": {}, "outputs": [], "source": [ @@ -1015,7 +1024,7 @@ }, { "cell_type": "code", - "execution_count": 72, + "execution_count": 73, "metadata": {}, "outputs": [], "source": [ @@ -1028,7 +1037,7 @@ }, { "cell_type": "code", - "execution_count": 73, + "execution_count": 74, "metadata": {}, "outputs": [], "source": [ @@ -1039,7 +1048,7 @@ }, { "cell_type": "code", - "execution_count": 74, + "execution_count": 75, "metadata": {}, "outputs": [], "source": [ @@ -1054,32 +1063,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Let's put all plots in a single figure for the book:" - ] - }, - { - "cell_type": "code", - "execution_count": 75, - "metadata": {}, - "outputs": [], - "source": [ - "# not in the book – this code generates Figure 3–9\n", - "fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(9, 8))\n", - "ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred, ax=axs[0, 0])\n", - "axs[0, 0].set_title(\"Confusion matrix\")\n", - "ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred, ax=axs[0, 1],\n", - " normalize=\"true\", values_format=\".0%\")\n", - "axs[0, 1].set_title(\"CM normalized by row\")\n", - "ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred, ax=axs[1, 0],\n", - " sample_weight=sample_weight,\n", - " normalize=\"true\", values_format=\".0%\")\n", - "axs[1, 0].set_title(\"Errors normalized by row\")\n", - "ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred, ax=axs[1, 1],\n", - " sample_weight=sample_weight,\n", - " normalize=\"pred\", values_format=\".0%\")\n", - "axs[1, 1].set_title(\"Errors normalized by column\")\n", - "save_fig(\"confusion_matrix_plot\")\n", - "plt.show()" + "Let's put all plots in a couple of figures for the book:" ] }, { @@ -1087,6 +1071,43 @@ "execution_count": 76, "metadata": {}, "outputs": [], + "source": [ + "# not in the book – this code generates Figure 3–9\n", + "fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))\n", + "ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred, ax=axs[0])\n", + "axs[0].set_title(\"Confusion matrix\")\n", + "ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred, ax=axs[1],\n", + " normalize=\"true\", values_format=\".0%\")\n", + "axs[1].set_title(\"CM normalized by row\")\n", + "save_fig(\"confusion_matrix_plot_1\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "metadata": {}, + "outputs": [], + "source": [ + "# not in the book – this code generates Figure 3–10\n", + "fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))\n", + "ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred, ax=axs[0],\n", + " sample_weight=sample_weight,\n", + " normalize=\"true\", values_format=\".0%\")\n", + "axs[0].set_title(\"Errors normalized by row\")\n", + "ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred, ax=axs[1],\n", + " sample_weight=sample_weight,\n", + " normalize=\"pred\", values_format=\".0%\")\n", + "axs[1].set_title(\"Errors normalized by column\")\n", + "save_fig(\"confusion_matrix_plot_2\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "metadata": {}, + "outputs": [], "source": [ "cl_a, cl_b = '3', '5'\n", "X_aa = X_train[(y_train == cl_a) & (y_train_pred == cl_a)]\n", @@ -1097,11 +1118,11 @@ }, { "cell_type": "code", - "execution_count": 77, + "execution_count": 79, "metadata": {}, "outputs": [], "source": [ - "# not in the book – this code generates Figure 3–10\n", + "# not in the book – this code generates Figure 3–11\n", "size = 5\n", "pad = 0.2\n", "plt.figure(figsize=(size, size))\n", @@ -1145,7 +1166,7 @@ }, { "cell_type": "code", - "execution_count": 78, + "execution_count": 80, "metadata": { "tags": [] }, @@ -1164,7 +1185,7 @@ }, { "cell_type": "code", - "execution_count": 79, + "execution_count": 81, "metadata": {}, "outputs": [], "source": [ @@ -1180,7 +1201,7 @@ }, { "cell_type": "code", - "execution_count": 80, + "execution_count": 82, "metadata": {}, "outputs": [], "source": [ @@ -1190,7 +1211,7 @@ }, { "cell_type": "code", - "execution_count": 81, + "execution_count": 83, "metadata": {}, "outputs": [], "source": [ @@ -1202,7 +1223,7 @@ }, { "cell_type": "code", - "execution_count": 82, + "execution_count": 84, "metadata": {}, "outputs": [], "source": [ @@ -1214,7 +1235,7 @@ }, { "cell_type": "code", - "execution_count": 83, + "execution_count": 85, "metadata": {}, "outputs": [], "source": [ @@ -1230,7 +1251,7 @@ }, { "cell_type": "code", - "execution_count": 84, + "execution_count": 86, "metadata": {}, "outputs": [], "source": [ @@ -1245,11 +1266,11 @@ }, { "cell_type": "code", - "execution_count": 85, + "execution_count": 87, "metadata": {}, "outputs": [], "source": [ - "# not in the book – this code generates Figure 3–11\n", + "# not in the book – this code generates Figure 3–12\n", "plt.subplot(121); plot_digit(X_test_mod[0])\n", "plt.subplot(122); plot_digit(y_test_mod[0])\n", "save_fig(\"noisy_digit_example_plot\")\n", @@ -1258,7 +1279,7 @@ }, { "cell_type": "code", - "execution_count": 86, + "execution_count": 88, "metadata": {}, "outputs": [], "source": [ @@ -1266,7 +1287,7 @@ "knn_clf.fit(X_train_mod, y_train_mod)\n", "clean_digit = knn_clf.predict([X_test_mod[0]])\n", "plot_digit(clean_digit)\n", - "save_fig(\"cleaned_digit_example_plot\") # not in the book – saves Figure 3–12\n", + "save_fig(\"cleaned_digit_example_plot\") # not in the book – saves Figure 3–13\n", "plt.show()" ] }, @@ -1279,7 +1300,7 @@ }, { "cell_type": "code", - "execution_count": 87, + "execution_count": 89, "metadata": {}, "outputs": [], "source": [ @@ -1330,7 +1351,7 @@ }, { "cell_type": "code", - "execution_count": 88, + "execution_count": 90, "metadata": {}, "outputs": [], "source": [ @@ -1356,7 +1377,7 @@ }, { "cell_type": "code", - "execution_count": 89, + "execution_count": 91, "metadata": {}, "outputs": [], "source": [ @@ -1371,7 +1392,7 @@ }, { "cell_type": "code", - "execution_count": 90, + "execution_count": 92, "metadata": {}, "outputs": [], "source": [ @@ -1380,7 +1401,7 @@ }, { "cell_type": "code", - "execution_count": 91, + "execution_count": 93, "metadata": {}, "outputs": [], "source": [ @@ -1396,7 +1417,7 @@ }, { "cell_type": "code", - "execution_count": 92, + "execution_count": 94, "metadata": {}, "outputs": [], "source": [ @@ -1435,7 +1456,7 @@ }, { "cell_type": "code", - "execution_count": 93, + "execution_count": 95, "metadata": {}, "outputs": [], "source": [ @@ -1444,7 +1465,7 @@ }, { "cell_type": "code", - "execution_count": 94, + "execution_count": 96, "metadata": {}, "outputs": [], "source": [ @@ -1463,7 +1484,7 @@ }, { "cell_type": "code", - "execution_count": 95, + "execution_count": 97, "metadata": {}, "outputs": [], "source": [ @@ -1496,7 +1517,7 @@ }, { "cell_type": "code", - "execution_count": 96, + "execution_count": 98, "metadata": {}, "outputs": [], "source": [ @@ -1521,7 +1542,7 @@ }, { "cell_type": "code", - "execution_count": 97, + "execution_count": 99, "metadata": {}, "outputs": [], "source": [ @@ -1539,7 +1560,7 @@ }, { "cell_type": "code", - "execution_count": 98, + "execution_count": 100, "metadata": {}, "outputs": [], "source": [ @@ -1548,7 +1569,7 @@ }, { "cell_type": "code", - "execution_count": 99, + "execution_count": 101, "metadata": {}, "outputs": [], "source": [ @@ -1564,7 +1585,7 @@ }, { "cell_type": "code", - "execution_count": 100, + "execution_count": 102, "metadata": {}, "outputs": [], "source": [ @@ -1580,7 +1601,7 @@ }, { "cell_type": "code", - "execution_count": 101, + "execution_count": 103, "metadata": { "tags": [] }, @@ -1620,7 +1641,7 @@ }, { "cell_type": "code", - "execution_count": 102, + "execution_count": 104, "metadata": {}, "outputs": [], "source": [ @@ -1644,7 +1665,7 @@ }, { "cell_type": "code", - "execution_count": 103, + "execution_count": 105, "metadata": {}, "outputs": [], "source": [ @@ -1667,7 +1688,7 @@ }, { "cell_type": "code", - "execution_count": 104, + "execution_count": 106, "metadata": {}, "outputs": [], "source": [ @@ -1707,7 +1728,7 @@ }, { "cell_type": "code", - "execution_count": 105, + "execution_count": 107, "metadata": {}, "outputs": [], "source": [ @@ -1724,7 +1745,7 @@ }, { "cell_type": "code", - "execution_count": 106, + "execution_count": 108, "metadata": {}, "outputs": [], "source": [ @@ -1733,7 +1754,7 @@ }, { "cell_type": "code", - "execution_count": 107, + "execution_count": 109, "metadata": {}, "outputs": [], "source": [ @@ -1763,7 +1784,7 @@ }, { "cell_type": "code", - "execution_count": 108, + "execution_count": 110, "metadata": {}, "outputs": [], "source": [ @@ -1788,7 +1809,7 @@ }, { "cell_type": "code", - "execution_count": 109, + "execution_count": 111, "metadata": {}, "outputs": [], "source": [ @@ -1804,7 +1825,7 @@ }, { "cell_type": "code", - "execution_count": 110, + "execution_count": 112, "metadata": {}, "outputs": [], "source": [ @@ -1813,7 +1834,7 @@ }, { "cell_type": "code", - "execution_count": 111, + "execution_count": 113, "metadata": {}, "outputs": [], "source": [ @@ -1822,7 +1843,7 @@ }, { "cell_type": "code", - "execution_count": 112, + "execution_count": 114, "metadata": {}, "outputs": [], "source": [ @@ -1845,7 +1866,7 @@ }, { "cell_type": "code", - "execution_count": 113, + "execution_count": 115, "metadata": {}, "outputs": [], "source": [ @@ -1867,7 +1888,7 @@ }, { "cell_type": "code", - "execution_count": 114, + "execution_count": 116, "metadata": {}, "outputs": [], "source": [ @@ -1876,7 +1897,7 @@ }, { "cell_type": "code", - "execution_count": 115, + "execution_count": 117, "metadata": {}, "outputs": [], "source": [ @@ -1896,7 +1917,7 @@ }, { "cell_type": "code", - "execution_count": 116, + "execution_count": 118, "metadata": {}, "outputs": [], "source": [ @@ -1920,7 +1941,7 @@ }, { "cell_type": "code", - "execution_count": 117, + "execution_count": 119, "metadata": {}, "outputs": [], "source": [ @@ -1937,7 +1958,7 @@ }, { "cell_type": "code", - "execution_count": 118, + "execution_count": 120, "metadata": {}, "outputs": [], "source": [ @@ -1953,7 +1974,7 @@ }, { "cell_type": "code", - "execution_count": 119, + "execution_count": 121, "metadata": {}, "outputs": [], "source": [ @@ -1970,7 +1991,7 @@ }, { "cell_type": "code", - "execution_count": 120, + "execution_count": 122, "metadata": {}, "outputs": [], "source": [ @@ -1987,7 +2008,7 @@ }, { "cell_type": "code", - "execution_count": 121, + "execution_count": 123, "metadata": {}, "outputs": [], "source": [ @@ -2011,7 +2032,7 @@ }, { "cell_type": "code", - "execution_count": 122, + "execution_count": 124, "metadata": {}, "outputs": [], "source": [ @@ -2038,7 +2059,7 @@ }, { "cell_type": "code", - "execution_count": 123, + "execution_count": 125, "metadata": {}, "outputs": [], "source": [ @@ -2072,7 +2093,7 @@ }, { "cell_type": "code", - "execution_count": 124, + "execution_count": 126, "metadata": {}, "outputs": [], "source": [ @@ -2082,7 +2103,7 @@ }, { "cell_type": "code", - "execution_count": 125, + "execution_count": 127, "metadata": {}, "outputs": [], "source": [ @@ -2116,7 +2137,7 @@ }, { "cell_type": "code", - "execution_count": 126, + "execution_count": 128, "metadata": {}, "outputs": [], "source": [ @@ -2143,7 +2164,7 @@ }, { "cell_type": "code", - "execution_count": 127, + "execution_count": 129, "metadata": {}, "outputs": [], "source": [ @@ -2159,7 +2180,7 @@ }, { "cell_type": "code", - "execution_count": 128, + "execution_count": 130, "metadata": {}, "outputs": [], "source": [ @@ -2169,7 +2190,7 @@ }, { "cell_type": "code", - "execution_count": 129, + "execution_count": 131, "metadata": {}, "outputs": [], "source": [ @@ -2178,7 +2199,7 @@ }, { "cell_type": "code", - "execution_count": 130, + "execution_count": 132, "metadata": {}, "outputs": [], "source": [ @@ -2194,7 +2215,7 @@ }, { "cell_type": "code", - "execution_count": 131, + "execution_count": 133, "metadata": {}, "outputs": [], "source": [ @@ -2208,7 +2229,7 @@ }, { "cell_type": "code", - "execution_count": 132, + "execution_count": 134, "metadata": {}, "outputs": [], "source": [ @@ -2225,7 +2246,7 @@ }, { "cell_type": "code", - "execution_count": 133, + "execution_count": 135, "metadata": {}, "outputs": [], "source": [ @@ -2234,7 +2255,7 @@ }, { "cell_type": "code", - "execution_count": 134, + "execution_count": 136, "metadata": {}, "outputs": [], "source": [ @@ -2250,7 +2271,7 @@ }, { "cell_type": "code", - "execution_count": 135, + "execution_count": 137, "metadata": {}, "outputs": [], "source": [ @@ -2268,7 +2289,7 @@ }, { "cell_type": "code", - "execution_count": 136, + "execution_count": 138, "metadata": {}, "outputs": [], "source": [ @@ -2284,7 +2305,7 @@ }, { "cell_type": "code", - "execution_count": 137, + "execution_count": 139, "metadata": {}, "outputs": [], "source": [ @@ -2293,7 +2314,7 @@ }, { "cell_type": "code", - "execution_count": 138, + "execution_count": 140, "metadata": {}, "outputs": [], "source": [ @@ -2316,7 +2337,7 @@ }, { "cell_type": "code", - "execution_count": 139, + "execution_count": 141, "metadata": {}, "outputs": [], "source": [ @@ -2333,7 +2354,7 @@ }, { "cell_type": "code", - "execution_count": 140, + "execution_count": 142, "metadata": {}, "outputs": [], "source": [ @@ -2349,7 +2370,7 @@ }, { "cell_type": "code", - "execution_count": 141, + "execution_count": 143, "metadata": {}, "outputs": [], "source": [ @@ -2372,7 +2393,7 @@ }, { "cell_type": "code", - "execution_count": 142, + "execution_count": 144, "metadata": {}, "outputs": [], "source": [ @@ -2396,7 +2417,7 @@ }, { "cell_type": "code", - "execution_count": 143, + "execution_count": 145, "metadata": {}, "outputs": [], "source": [ @@ -2415,7 +2436,7 @@ }, { "cell_type": "code", - "execution_count": 144, + "execution_count": 146, "metadata": {}, "outputs": [], "source": [ @@ -2431,7 +2452,7 @@ }, { "cell_type": "code", - "execution_count": 145, + "execution_count": 147, "metadata": {}, "outputs": [], "source": [ @@ -2455,7 +2476,7 @@ }, { "cell_type": "code", - "execution_count": 146, + "execution_count": 148, "metadata": {}, "outputs": [], "source": [ @@ -2471,7 +2492,7 @@ }, { "cell_type": "code", - "execution_count": 147, + "execution_count": 149, "metadata": {}, "outputs": [], "source": [ @@ -2492,7 +2513,7 @@ }, { "cell_type": "code", - "execution_count": 148, + "execution_count": 150, "metadata": {}, "outputs": [], "source": [ @@ -2514,7 +2535,7 @@ }, { "cell_type": "code", - "execution_count": 149, + "execution_count": 151, "metadata": {}, "outputs": [], "source": [ @@ -2535,7 +2556,7 @@ }, { "cell_type": "code", - "execution_count": 150, + "execution_count": 152, "metadata": {}, "outputs": [], "source": [ @@ -2588,7 +2609,7 @@ }, { "cell_type": "code", - "execution_count": 151, + "execution_count": 153, "metadata": {}, "outputs": [], "source": [ @@ -2613,7 +2634,7 @@ }, { "cell_type": "code", - "execution_count": 152, + "execution_count": 154, "metadata": {}, "outputs": [], "source": [ @@ -2646,7 +2667,7 @@ }, { "cell_type": "code", - "execution_count": 153, + "execution_count": 155, "metadata": {}, "outputs": [], "source": [ @@ -2657,7 +2678,7 @@ }, { "cell_type": "code", - "execution_count": 154, + "execution_count": 156, "metadata": {}, "outputs": [], "source": [ @@ -2673,7 +2694,7 @@ }, { "cell_type": "code", - "execution_count": 155, + "execution_count": 157, "metadata": {}, "outputs": [], "source": [ @@ -2689,7 +2710,7 @@ }, { "cell_type": "code", - "execution_count": 156, + "execution_count": 158, "metadata": {}, "outputs": [], "source": [ @@ -2705,7 +2726,7 @@ }, { "cell_type": "code", - "execution_count": 157, + "execution_count": 159, "metadata": {}, "outputs": [], "source": [ @@ -2728,7 +2749,7 @@ }, { "cell_type": "code", - "execution_count": 158, + "execution_count": 160, "metadata": {}, "outputs": [], "source": [