diff --git a/03_classification.ipynb b/03_classification.ipynb index 499d157..846af67 100644 --- a/03_classification.ipynb +++ b/03_classification.ipynb @@ -82,7 +82,7 @@ "outputs": [], "source": [ "# Not in the book\n", - "mnist.keys()" + "print(mnist.DESCR)" ] }, { @@ -91,7 +91,8 @@ "metadata": {}, "outputs": [], "source": [ - "print(mnist.DESCR)" + "# Not in the book\n", + "mnist.keys()" ] }, { @@ -100,7 +101,8 @@ "metadata": {}, "outputs": [], "source": [ - "mnist.data" + "X, y = mnist.data, mnist.target\n", + "X" ] }, { @@ -109,7 +111,7 @@ "metadata": {}, "outputs": [], "source": [ - "mnist.target" + "X.shape" ] }, { @@ -118,10 +120,7 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", - "\n", - "X, y = mnist.data, mnist.target\n", - "X.shape" + "mnist.target" ] }, { @@ -180,17 +179,15 @@ "metadata": {}, "outputs": [], "source": [ - "import matplotlib as mpl\n", "import matplotlib.pyplot as plt\n", "\n", "def plot_digit(image_data):\n", " image = image_data.reshape(28, 28)\n", - " plt.imshow(image, cmap=mpl.cm.binary, interpolation=\"nearest\")\n", + " plt.imshow(image, cmap=\"binary\")\n", " plt.axis(\"off\")\n", "\n", "some_digit = X[0]\n", "plot_digit(some_digit)\n", - "\n", "save_fig(\"some_digit_plot\")\n", "plt.show()" ] @@ -210,6 +207,7 @@ "metadata": {}, "outputs": [], "source": [ + "# not in the book\n", "plt.figure(figsize=(9, 9))\n", "for idx, image_data in enumerate(X[:100]):\n", " plt.subplot(10, 10, idx + 1)\n", @@ -241,7 +239,7 @@ "metadata": {}, "outputs": [], "source": [ - "y_train_5 = (y_train == '5')\n", + "y_train_5 = (y_train == '5') # True for all 5s, False for all other digits\n", "y_test_5 = (y_test == '5')" ] }, @@ -266,17 +264,6 @@ "sgd_clf.predict([some_digit])" ] }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [], - "source": [ - "from sklearn.model_selection import cross_val_score\n", - "\n", - "cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring=\"accuracy\")" - ] - 
}, { "cell_type": "markdown", "metadata": {}, @@ -291,6 +278,17 @@ "## Measuring Accuracy Using Cross-Validation" ] }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.model_selection import cross_val_score\n", + "\n", + "cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring=\"accuracy\")" + ] + }, { "cell_type": "code", "execution_count": 19, @@ -298,20 +296,21 @@ "outputs": [], "source": [ "from sklearn.model_selection import StratifiedKFold\n", + "from sklearn.base import clone\n", "\n", - "skfolds = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)\n", - "\n", + "skfolds = StratifiedKFold(n_splits=3) # add shuffle=True if the dataset is not\n", + " # already shuffled\n", "for train_index, test_index in skfolds.split(X_train, y_train_5):\n", + " clone_clf = clone(sgd_clf)\n", " X_train_folds = X_train[train_index]\n", " y_train_folds = y_train_5[train_index]\n", " X_test_fold = X_train[test_index]\n", " y_test_fold = y_train_5[test_index]\n", "\n", - " sgd_clf_cv = SGDClassifier(random_state=42)\n", - " sgd_clf_cv.fit(X_train_folds, y_train_folds)\n", - " y_pred = sgd_clf_cv.predict(X_test_fold)\n", + " clone_clf.fit(X_train_folds, y_train_folds)\n", + " y_pred = clone_clf.predict(X_test_fold)\n", " n_correct = sum(y_pred == y_test_fold)\n", - " print(n_correct / len(y_pred))" + " print(n_correct / len(y_pred)) # prints 0.95035, 0.96035 and 0.9604" ] }, { @@ -324,7 +323,7 @@ "\n", "dummy_clf = DummyClassifier()\n", "dummy_clf.fit(X_train, y_train_5)\n", - "np.any(dummy_clf.predict(X_train))" + "print(any(dummy_clf.predict(X_train))) # prints False: no 5s detected" ] }, { @@ -399,7 +398,10 @@ "metadata": {}, "outputs": [], "source": [ + "# Not in the book\n", "cm = confusion_matrix(y_train_5, y_train_pred)\n", + "\n", + "# Precision = TP / (FP + TP)\n", "cm[1, 1] / (cm[0, 1] + cm[1, 1])" ] }, @@ -418,6 +420,8 @@ "metadata": {}, "outputs": [], "source": [ + "# Not in the book\n", +
"# Recall = TP / (FN + TP)\n", "cm[1, 1] / (cm[1, 0] + cm[1, 1])" ] }, @@ -483,14 +487,25 @@ "metadata": {}, "outputs": [], "source": [ - "threshold = 8000\n", + "# Not in the book\n", + "# Using threshold 0, we get exactly the same predictions as with predict()\n", + "(y_train_pred == (y_scores > 0)).all()" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "threshold = 3000\n", "y_some_digit_pred = (y_scores > threshold)\n", "y_some_digit_pred" ] }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 36, "metadata": {}, "outputs": [], "source": [ @@ -500,7 +515,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 37, "metadata": {}, "outputs": [], "source": [ @@ -511,21 +526,20 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 38, "metadata": {}, "outputs": [], "source": [ - "recall_90_precision = recalls[np.argmax(precisions >= 0.90)]\n", - "threshold_90_precision = thresholds[np.argmax(precisions >= 0.90)]\n", - "\n", "plt.figure(figsize=(8, 4)) # not in the book\n", "plt.plot(thresholds, precisions[:-1], \"b--\", label=\"Precision\", linewidth=2)\n", "plt.plot(thresholds, recalls[:-1], \"g-\", label=\"Recall\", linewidth=2)\n", - "plt.vlines(threshold_90_precision, 0, 1.0, \"k\", \"dotted\", label=\"threshold\")\n", + "plt.vlines(threshold, 0, 1.0, \"k\", \"dotted\", label=\"threshold\")\n", "\n", - "# not in the book (just beautifies the figure)\n", - "plt.plot(threshold_90_precision, recall_90_precision, \"go\")\n", - "plt.plot(threshold_90_precision, 0.90, \"bo\")\n", + "# not in the book\n", + "# beautify the figure\n", + "idx = (thresholds >= threshold).argmax() # first index ≥ threshold\n", + "plt.plot(thresholds[idx], precisions[idx], \"bo\")\n", + "plt.plot(thresholds[idx], recalls[idx], \"go\")\n", "plt.axis([-50000, 50000, 0, 1])\n", "plt.grid(True)\n", "plt.xlabel(\"Threshold\")\n", @@ -535,33 +549,34 @@ "plt.show()" ] }, - 
{ - "cell_type": "code", - "execution_count": 38, - "metadata": {}, - "outputs": [], - "source": [ - "(y_train_pred == (y_scores > 0)).all()" - ] - }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ - "plt.figure(figsize=(8, 6)) # not in the book\n", + "import matplotlib.patches as patches # not in the book\n", "\n", - "plt.plot(recalls, precisions, linewidth=2)\n", - "plt.plot([recall_90_precision, recall_90_precision], [0., 0.9], \"k:\")\n", - "plt.plot([0.0, recall_90_precision], [0.9, 0.9], \"k:\")\n", - "plt.plot([recall_90_precision], [0.9], \"ko\")\n", + "plt.figure(figsize=(6, 5)) # not in the book\n", + "\n", + "plt.plot(recalls, precisions, linewidth=2, label=\"Precision/Recall curve\")\n", "\n", "# not in the book (just beautifies the figure)\n", + "plt.plot([recalls[idx], recalls[idx]], [0., precisions[idx]], \"k:\")\n", + "plt.plot([0.0, recalls[idx]], [precisions[idx], precisions[idx]], \"k:\")\n", + "plt.plot([recalls[idx]], [precisions[idx]], \"ko\",\n", + " label=\"Point at threshold 3,000\")\n", + "plt.gca().add_patch(patches.FancyArrowPatch(\n", + " (0.79, 0.60), (0.61, 0.78),\n", + " connectionstyle=\"arc3,rad=.2\",\n", + " arrowstyle=\"Simple, tail_width=1.5, head_width=8, head_length=10\",\n", + " color=\"#444444\"))\n", + "plt.text(0.56, 0.62, \"Higher\\nthreshold\", fontsize=14, color=\"#333333\")\n", "plt.xlabel(\"Recall\")\n", "plt.ylabel(\"Precision\")\n", "plt.axis([0, 1, 0, 1])\n", "plt.grid(True)\n", + "plt.legend(loc=\"lower left\")\n", "save_fig(\"precision_vs_recall_plot\")\n", "\n", "plt.show()" @@ -573,7 +588,9 @@ "metadata": {}, "outputs": [], "source": [ - "threshold_90_precision = thresholds[np.argmax(precisions >= 0.90)]" + "idx_for_90_precision = (precisions >= 0.90).argmax()\n", + "threshold_for_90_precision = thresholds[idx_for_90_precision]\n", + "threshold_for_90_precision" ] }, { @@ -582,7 +599,7 @@ "metadata": {}, "outputs": [], "source": [ - "threshold_90_precision" + 
"y_train_pred_90 = (y_scores >= threshold_for_90_precision)" ] }, { @@ -591,7 +608,7 @@ "metadata": {}, "outputs": [], "source": [ - "y_train_pred_90 = (y_scores >= threshold_90_precision)" + "precision_score(y_train_5, y_train_pred_90)" ] }, { @@ -600,16 +617,8 @@ "metadata": {}, "outputs": [], "source": [ - "precision_score(y_train_5, y_train_pred_90)" - ] - }, - { - "cell_type": "code", - "execution_count": 44, - "metadata": {}, - "outputs": [], - "source": [ - "recall_score(y_train_5, y_train_pred_90)" + "recall_at_90_precision = recall_score(y_train_5, y_train_pred_90)\n", + "recall_at_90_precision" ] }, { @@ -621,7 +630,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 44, "metadata": {}, "outputs": [], "source": [ @@ -632,29 +641,30 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 45, "metadata": {}, "outputs": [], "source": [ - "def plot_roc_curve(fpr, tpr, label=None):\n", - " plt.plot(fpr, tpr, linewidth=2, label=label) # ROC curve\n", - " plt.plot([0, 1], [0, 1], 'k--') # dashed diagonal\n", + "idx_for_threshold_at_90 = (thresholds <= threshold_for_90_precision).argmax()\n", + "tpr_90, fpr_90 = tpr[idx_for_threshold_at_90], fpr[idx_for_threshold_at_90]\n", "\n", - " # not in the book (just beautifies the figure)\n", - " plt.axis([0, 1, 0, 1])\n", - " plt.xlabel('False Positive Rate (Fall-Out)')\n", - " plt.ylabel('True Positive Rate (Recall)')\n", - " plt.grid(True)\n", - "\n", - "\n", - "plt.figure(figsize=(8, 6)) # Not in the book\n", - "plot_roc_curve(fpr, tpr)\n", + "plt.figure(figsize=(6, 5)) # Not in the book\n", + "plt.plot(fpr, tpr, linewidth=2, label=\"ROC curve\")\n", + "plt.plot([0, 1], [0, 1], 'k:', label=\"Random classifier's ROC curve\")\n", + "plt.plot([fpr_90], [tpr_90], \"ko\", label=\"Threshold for 90% precision\")\n", "\n", "# not in the book (just beautifies the figure)\n", - "fpr_90 = fpr[np.argmax(tpr >= recall_90_precision)]\n", - "plt.plot([fpr_90, fpr_90], [0., 
recall_90_precision], \"r:\")\n", - "plt.plot([0.0, fpr_90], [recall_90_precision, recall_90_precision], \"r:\")\n", - "plt.plot([fpr_90], [recall_90_precision], \"ro\")\n", + "plt.gca().add_patch(patches.FancyArrowPatch(\n", + " (0.20, 0.89), (0.07, 0.70),\n", + " connectionstyle=\"arc3,rad=.4\",\n", + " arrowstyle=\"Simple, tail_width=1.5, head_width=8, head_length=10\",\n", + " color=\"#444444\"))\n", + "plt.text(0.12, 0.71, \"Higher\\nthreshold\", fontsize=14, color=\"#333333\")\n", + "plt.xlabel('False Positive Rate (Fall-Out)')\n", + "plt.ylabel('True Positive Rate (Recall)')\n", + "plt.grid(True)\n", + "plt.axis([0, 1, 0, 1])\n", + "plt.legend(loc=\"lower right\")\n", "save_fig(\"roc_curve_plot\")\n", "\n", "plt.show()" @@ -662,7 +672,7 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 46, "metadata": {}, "outputs": [], "source": [ @@ -671,15 +681,30 @@ "roc_auc_score(y_train_5, y_scores)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Warning:** the following cell may take a few minutes to run." + ] + }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 47, "metadata": {}, "outputs": [], "source": [ "from sklearn.ensemble import RandomForestClassifier\n", "\n", - "forest_clf = RandomForestClassifier(random_state=42)\n", + "forest_clf = RandomForestClassifier(random_state=42)" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [], + "source": [ "y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3,\n", " method=\"predict_proba\")" ] @@ -690,7 +715,14 @@ "metadata": {}, "outputs": [], "source": [ - "y_probas_forest.shape" + "y_probas_forest[:2]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "These are _estimated probabilities_. 
Among the images that the model classified as positive with a probability between 50% and 60%, there are actually about 94% positive images:" ] }, { @@ -699,8 +731,9 @@ "metadata": {}, "outputs": [], "source": [ - "y_scores_forest = y_probas_forest[:, 1] # 2nd column = proba of positive class\n", - "fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5,y_scores_forest)" + "# Not in the book\n", + "idx_50_to_60 = (y_probas_forest[:, 1] > 0.50) & (y_probas_forest[:, 1] < 0.60)\n", + "print(f\"{(y_train_5[idx_50_to_60]).sum() / idx_50_to_60.sum():.1%}\")" ] }, { @@ -709,21 +742,9 @@ "metadata": {}, "outputs": [], "source": [ - "# not in the book\n", - "recall_for_forest = tpr_forest[np.argmax(fpr_forest >= fpr_90)]\n", - "\n", - "plt.figure(figsize=(8, 6))\n", - "plt.plot(fpr, tpr, \"b:\", linewidth=2, label=\"SGD\")\n", - "plot_roc_curve(fpr_forest, tpr_forest, \"Random Forest\")\n", - "plt.plot([fpr_90, fpr_90], [0., recall_90_precision], \"r:\")\n", - "plt.plot([0.0, fpr_90], [recall_90_precision, recall_90_precision], \"r:\")\n", - "plt.plot([fpr_90], [recall_90_precision], \"ro\")\n", - "plt.plot([fpr_90, fpr_90], [0., recall_for_forest], \"r:\")\n", - "plt.plot([fpr_90], [recall_for_forest], \"ro\")\n", - "plt.grid(True)\n", - "plt.legend(loc=\"lower right\")\n", - "save_fig(\"roc_curve_comparison_plot\")\n", - "plt.show()" + "y_scores_forest = y_probas_forest[:, 1]\n", + "fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5,\n", + " y_scores_forest)" ] }, { @@ -732,7 +753,21 @@ "metadata": {}, "outputs": [], "source": [ - "roc_auc_score(y_train_5, y_scores_forest)" + "plt.figure(figsize=(6, 5)) # not in the book\n", + "\n", + "plt.plot(fpr_forest, tpr_forest, \"b-\", linewidth=2, label=\"Random Forest\")\n", + "plt.plot(fpr, tpr, \"--\", linewidth=2, label=\"SGD\")\n", + "plt.plot([0, 1], [0, 1], 'k:', label=\"Random classifier\")\n", + "\n", + "# not in the book\n", + "plt.xlabel('False Positive Rate (Fall-Out)')\n", + "plt.ylabel('True
Positive Rate (Recall)')\n", + "plt.grid(True)\n", + "plt.axis([0, 1, 0, 1])\n", + "plt.legend(loc=\"lower right\")\n", + "save_fig(\"roc_curve_comparison_plot\")\n", + "\n", + "plt.show()" ] }, { @@ -741,8 +776,14 @@ "metadata": {}, "outputs": [], "source": [ - "y_train_pred_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3)\n", - "precision_score(y_train_5, y_train_pred_forest)" + "roc_auc_score(y_train_5, y_scores_forest)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We could use `cross_val_predict(forest_clf, X_train, y_train_5, cv=3)` to compute `y_train_pred_forest`, but since we already have the estimated probabilities, we can just use the default threshold of 50% probability to get the same predictions much faster:" ] }, { @@ -750,6 +791,16 @@ "execution_count": 54, "metadata": {}, "outputs": [], + "source": [ + "y_train_pred_forest = y_probas_forest[:, 1] > 0.5\n", + "precision_score(y_train_5, y_train_pred_forest)" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [], "source": [ "recall_score(y_train_5, y_train_pred_forest)" ] @@ -770,25 +821,14 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "from sklearn.svm import SVC\n", "\n", - "svm_clf = SVC(gamma=\"auto\", random_state=42)\n", - "svm_clf.fit(X_train[:2000], y_train[:2000]) # y_train, not y_train_5\n", - "svm_clf.predict([some_digit])" - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "metadata": {}, - "outputs": [], - "source": [ - "some_digit_scores = svm_clf.decision_function([some_digit])\n", - "np.round(some_digit_scores, 2)" + "svm_clf = SVC(random_state=42)\n", + "svm_clf.fit(X_train[:2000], y_train[:2000]) # y_train, not y_train_5" ] }, { @@ -797,10 +837,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Not in the book\n", - "svm_clf.decision_function_shape = \"ovo\"\n", - "some_digit_scores_ovo = 
svm_clf.decision_function([some_digit])\n", - "np.round(some_digit_scores_ovo, 2)" + "svm_clf.predict([some_digit])" ] }, { @@ -809,8 +846,8 @@ "metadata": {}, "outputs": [], "source": [ - "class_id = np.argmax(some_digit_scores)\n", - "class_id" + "some_digit_scores = svm_clf.decision_function([some_digit])\n", + "some_digit_scores.round(2)" ] }, { @@ -819,7 +856,8 @@ "metadata": {}, "outputs": [], "source": [ - "svm_clf.classes_" + "class_id = some_digit_scores.argmax()\n", + "class_id" ] }, { @@ -828,7 +866,7 @@ "metadata": {}, "outputs": [], "source": [ - "svm_clf.classes_[class_id]" + "svm_clf.classes_" ] }, { @@ -837,11 +875,14 @@ "metadata": {}, "outputs": [], "source": [ - "from sklearn.multiclass import OneVsRestClassifier\n", - "\n", - "ovr_clf = OneVsRestClassifier(SVC(gamma=\"auto\", random_state=42))\n", - "ovr_clf.fit(X_train[:2000], y_train[:2000])\n", - "ovr_clf.predict([some_digit])" + "svm_clf.classes_[class_id]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you want `decision_function()` to return all 45 scores, you can set the `decision_function_shape` hyperparameter to `\"ovo\"`. The default value is `\"ovr\"`, but don't let this confuse you: `SVC` always uses OvO for training. 
This hyperparameter only affects whether or not the 45 scores get aggregated:" ] }, { @@ -850,7 +891,10 @@ "metadata": {}, "outputs": [], "source": [ - "len(ovr_clf.estimators_)" + "# Not in the book\n", + "svm_clf.decision_function_shape = \"ovo\"\n", + "some_digit_scores_ovo = svm_clf.decision_function([some_digit])\n", + "some_digit_scores_ovo.round(2)" ] }, { @@ -859,9 +903,11 @@ "metadata": {}, "outputs": [], "source": [ - "sgd_clf = SGDClassifier(random_state=42)\n", - "sgd_clf.fit(X_train[:2000], y_train[:2000])\n", - "sgd_clf.predict([some_digit])" + "from sklearn.multiclass import OneVsRestClassifier\n", + "\n", + "ovr_clf = OneVsRestClassifier(SVC(random_state=42))\n", + "ovr_clf.fit(X_train[:2000], y_train[:2000])\n", + "ovr_clf.predict([some_digit])" ] }, { @@ -870,7 +916,7 @@ "metadata": {}, "outputs": [], "source": [ - "np.round(sgd_clf.decision_function([some_digit]))" + "len(ovr_clf.estimators_)" ] }, { @@ -879,8 +925,9 @@ "metadata": {}, "outputs": [], "source": [ - "cross_val_score(sgd_clf, X_train[:2000], y_train[:2000],\n", - " cv=3, scoring=\"accuracy\")" + "sgd_clf = SGDClassifier(random_state=42)\n", + "sgd_clf.fit(X_train, y_train)\n", + "sgd_clf.predict([some_digit])" ] }, { @@ -888,13 +935,37 @@ "execution_count": 66, "metadata": {}, "outputs": [], + "source": [ + "sgd_clf.decision_function([some_digit]).round()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Warning:** the following two cells may take a few minutes each to run:" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": {}, + "outputs": [], + "source": [ + "cross_val_score(sgd_clf, X_train, y_train, cv=3, scoring=\"accuracy\")" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": {}, + "outputs": [], "source": [ "from sklearn.preprocessing import StandardScaler\n", "\n", "scaler = StandardScaler()\n", - "X_train_scaled = scaler.fit_transform(X_train.astype(np.float64))\n", -
"cross_val_score(sgd_clf, X_train_scaled[:2000], y_train[:2000],\n", - " cv=3, scoring=\"accuracy\")" + "X_train_scaled = scaler.fit_transform(X_train.astype(\"float64\"))\n", + "cross_val_score(sgd_clf, X_train_scaled, y_train, cv=3, scoring=\"accuracy\")" ] }, { @@ -913,38 +984,14 @@ }, { "cell_type": "code", - "execution_count": 67, - "metadata": {}, - "outputs": [], - "source": [ - "y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)" - ] - }, - { - "cell_type": "code", - "execution_count": 68, + "execution_count": 69, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import ConfusionMatrixDisplay\n", "\n", + "y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)\n", "ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred)\n", - "save_fig(\"confusion_matrix_plot\")\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 69, - "metadata": {}, - "outputs": [], - "source": [ - "error_idx = y_train_pred != y_train\n", - "y_train_pred_error = y_train_pred[error_idx]\n", - "y_train_error = y_train[error_idx]\n", - "ConfusionMatrixDisplay.from_predictions(y_train_error, y_train_pred_error,\n", - " normalize=\"pred\", values_format=\".0%\")\n", - "save_fig(\"confusion_matrix_errors_plot\", tight_layout=False)\n", "plt.show()" ] }, @@ -953,6 +1000,62 @@ "execution_count": 70, "metadata": {}, "outputs": [], + "source": [ + "ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred,\n", + " normalize=\"true\", values_format=\".0%\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "metadata": {}, + "outputs": [], + "source": [ + "sample_weight = (y_train_pred != y_train)\n", + "ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred,\n", + " sample_weight=sample_weight,\n", + " normalize=\"true\", values_format=\".0%\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's put all plots in a single 
figure for the book:" + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "metadata": {}, + "outputs": [], + "source": [ + "# Not in the book\n", + "fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(9, 8))\n", + "ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred, ax=axs[0, 0])\n", + "axs[0, 0].set_title(\"Confusion matrix\")\n", + "ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred, ax=axs[0, 1],\n", + " normalize=\"true\", values_format=\".0%\")\n", + "axs[0, 1].set_title(\"CM normalized by row\")\n", + "ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred, ax=axs[1, 0],\n", + " sample_weight=sample_weight,\n", + " normalize=\"true\", values_format=\".0%\")\n", + "axs[1, 0].set_title(\"Errors normalized by row\")\n", + "ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred, ax=axs[1, 1],\n", + " sample_weight=sample_weight,\n", + " normalize=\"pred\", values_format=\".0%\")\n", + "axs[1, 1].set_title(\"Errors normalized by column\")\n", + "save_fig(\"confusion_matrix_plot\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 73, + "metadata": {}, + "outputs": [], "source": [ "cl_a, cl_b = '3', '5'\n", "X_aa = X_train[(y_train == cl_a) & (y_train_pred == cl_a)]\n", @@ -963,7 +1066,7 @@ }, { "cell_type": "code", - "execution_count": 71, + "execution_count": 74, "metadata": {}, "outputs": [], "source": [ @@ -1011,14 +1114,15 @@ }, { "cell_type": "code", - "execution_count": 72, + "execution_count": 75, "metadata": {}, "outputs": [], "source": [ + "import numpy as np\n", "from sklearn.neighbors import KNeighborsClassifier\n", "\n", "y_train_large = (y_train >= '7')\n", - "y_train_odd = (y_train.astype(np.uint8) % 2 == 1)\n", + "y_train_odd = (y_train.astype('int8') % 2 == 1)\n", "y_multilabel = np.c_[y_train_large, y_train_odd]\n", "\n", "knn_clf = KNeighborsClassifier()\n", @@ -1027,7 +1131,7 @@ }, { "cell_type": "code", - "execution_count": 73, + "execution_count": 76, "metadata": {}, 
"outputs": [], "source": [ @@ -1043,14 +1147,45 @@ }, { "cell_type": "code", - "execution_count": 74, + "execution_count": 77, + "metadata": {}, + "outputs": [], + "source": [ + "y_train_knn_pred = cross_val_predict(knn_clf, X_train, y_multilabel, cv=3)" + ] + }, + { + "cell_type": "code", + "execution_count": 78, "metadata": {}, "outputs": [], "source": [ - "y_train_knn_pred = cross_val_predict(knn_clf, X_train, y_multilabel, cv=3)\n", "f1_score(y_multilabel, y_train_knn_pred, average=\"macro\")" ] }, + { + "cell_type": "code", + "execution_count": 79, + "metadata": {}, + "outputs": [], + "source": [ + "# not in the book\n", + "f1_score(y_multilabel, y_train_knn_pred, average=\"weighted\")" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.multioutput import ClassifierChain\n", + "\n", + "chain_clf = ClassifierChain(SVC(), cv=3, random_state=42)\n", + "chain_clf.fit(X_train[:2000], y_multilabel[:2000])\n", + "chain_clf.predict([some_digit])" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -1060,11 +1195,11 @@ }, { "cell_type": "code", - "execution_count": 75, + "execution_count": 81, "metadata": {}, "outputs": [], "source": [ - "np.random.seed(42)\n", + "np.random.seed(42) # as always, this is to make this example reproducible\n", "noise = np.random.randint(0, 100, (len(X_train), 784))\n", "X_train_mod = X_train + noise\n", "noise = np.random.randint(0, 100, (len(X_test), 784))\n", @@ -1075,10 +1210,11 @@ }, { "cell_type": "code", - "execution_count": 76, + "execution_count": 82, "metadata": {}, "outputs": [], "source": [ + "# not in the book\n", "some_index = 0\n", "plt.subplot(121); plot_digit(X_test_mod[some_index])\n", "plt.subplot(122); plot_digit(y_test_mod[some_index])\n", @@ -1088,7 +1224,7 @@ }, { "cell_type": "code", - "execution_count": 77, + "execution_count": 83, "metadata": {}, "outputs": [], "source": [ @@ -1099,6 +1235,36 @@ "plt.show()" ] }, + { + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "# Extra Material — Calibrating Estimated Probabilities" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "markdown", "metadata": {}, @@ -1129,7 +1295,7 @@ }, { "cell_type": "code", - "execution_count": 78, + "execution_count": 85, "metadata": {}, "outputs": [], "source": [ @@ -1155,7 +1321,7 @@ }, { "cell_type": "code", - "execution_count": 79, + "execution_count": 86, "metadata": {}, "outputs": [], "source": [ @@ -1170,7 +1336,7 @@ }, { "cell_type": "code", - "execution_count": 80, + "execution_count": 87, "metadata": {}, "outputs": [], "source": [ @@ -1179,7 +1345,7 @@ }, { "cell_type": "code", - "execution_count": 81, + "execution_count": 88, "metadata": {}, "outputs": [], "source": [ @@ -1195,7 +1361,7 @@ }, { "cell_type": "code", - "execution_count": 82, + "execution_count": 89, "metadata": {}, "outputs": [], "source": [ @@ -1234,7 +1400,7 @@ }, { "cell_type": "code", - "execution_count": 83, + "execution_count": 90, "metadata": {}, "outputs": [], "source": [ @@ -1243,7 +1409,7 @@ }, { "cell_type": "code", - "execution_count": 84, + "execution_count": 91, "metadata": {}, "outputs": [], "source": [ @@ -1262,7 +1428,7 @@ }, { "cell_type": "code", - "execution_count": 85, + "execution_count": 92, "metadata": {}, "outputs": [], "source": [ @@ -1295,7 +1461,7 @@ }, { "cell_type": "code", - "execution_count": 86, + "execution_count": 93, "metadata": {}, "outputs": [], "source": [ @@ -1320,7 +1486,7 @@ }, { "cell_type": "code", - "execution_count": 87, + "execution_count": 94, "metadata": {}, "outputs": [], "source": [ @@ -1338,7 +1504,7 @@ }, { "cell_type": "code", - "execution_count": 88, + "execution_count": 
95, "metadata": {}, "outputs": [], "source": [ @@ -1347,7 +1513,7 @@ }, { "cell_type": "code", - "execution_count": 89, + "execution_count": 96, "metadata": {}, "outputs": [], "source": [ @@ -1363,7 +1529,7 @@ }, { "cell_type": "code", - "execution_count": 90, + "execution_count": 97, "metadata": {}, "outputs": [], "source": [ @@ -1379,7 +1545,7 @@ }, { "cell_type": "code", - "execution_count": 91, + "execution_count": 98, "metadata": { "tags": [] }, @@ -1407,14 +1573,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Exercise: _Tackle the Titanic dataset. A great place to start is on [Kaggle](https://www.kaggle.com/c/titanic)._" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The goal is to predict whether or not a passenger survived based on attributes such as their age, sex, passenger class, where they embarked and so on." + "Exercise: _Tackle the Titanic dataset. A great place to start is on [Kaggle](https://www.kaggle.com/c/titanic). Alternatively, you can download the data from https://homl.info/titanic.tgz and unzip this tarball like you did for the housing data in Chapter 2. This will give you two CSV files: _train.csv_ and _test.csv_ which you can load using `pandas.read_csv()`. 
The goal is to train a classifier that can predict the `Survived` column based on the other columns._" ] }, { @@ -1426,7 +1585,7 @@ }, { "cell_type": "code", - "execution_count": 92, + "execution_count": 99, "metadata": {}, "outputs": [], "source": [ @@ -1450,7 +1609,7 @@ }, { "cell_type": "code", - "execution_count": 93, + "execution_count": 100, "metadata": {}, "outputs": [], "source": [ @@ -1473,7 +1632,7 @@ }, { "cell_type": "code", - "execution_count": 94, + "execution_count": 101, "metadata": {}, "outputs": [], "source": [ @@ -1497,6 +1656,13 @@ "* **Embarked**: where the passenger embarked the Titanic" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The goal is to predict whether or not a passenger survived based on attributes such as their age, sex, passenger class, where they embarked and so on." + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -1506,7 +1672,7 @@ }, { "cell_type": "code", - "execution_count": 95, + "execution_count": 102, "metadata": {}, "outputs": [], "source": [ @@ -1523,7 +1689,7 @@ }, { "cell_type": "code", - "execution_count": 96, + "execution_count": 103, "metadata": {}, "outputs": [], "source": [ @@ -1532,7 +1698,7 @@ }, { "cell_type": "code", - "execution_count": 97, + "execution_count": 104, "metadata": {}, "outputs": [], "source": [ @@ -1562,7 +1728,7 @@ }, { "cell_type": "code", - "execution_count": 98, + "execution_count": 105, "metadata": {}, "outputs": [], "source": [ @@ -1587,7 +1753,7 @@ }, { "cell_type": "code", - "execution_count": 99, + "execution_count": 106, "metadata": {}, "outputs": [], "source": [ @@ -1603,7 +1769,7 @@ }, { "cell_type": "code", - "execution_count": 100, + "execution_count": 107, "metadata": {}, "outputs": [], "source": [ @@ -1612,7 +1778,7 @@ }, { "cell_type": "code", - "execution_count": 101, + "execution_count": 108, "metadata": {}, "outputs": [], "source": [ @@ -1621,7 +1787,7 @@ }, { "cell_type": "code", - "execution_count": 102, + "execution_count": 109, "metadata": 
{}, "outputs": [], "source": [ @@ -1644,7 +1810,7 @@ }, { "cell_type": "code", - "execution_count": 103, + "execution_count": 110, "metadata": {}, "outputs": [], "source": [ @@ -1666,7 +1832,7 @@ }, { "cell_type": "code", - "execution_count": 104, + "execution_count": 111, "metadata": {}, "outputs": [], "source": [ @@ -1675,7 +1841,7 @@ }, { "cell_type": "code", - "execution_count": 105, + "execution_count": 112, "metadata": {}, "outputs": [], "source": [ @@ -1695,7 +1861,7 @@ }, { "cell_type": "code", - "execution_count": 106, + "execution_count": 113, "metadata": {}, "outputs": [], "source": [ @@ -1719,7 +1885,7 @@ }, { "cell_type": "code", - "execution_count": 107, + "execution_count": 114, "metadata": {}, "outputs": [], "source": [ @@ -1736,7 +1902,7 @@ }, { "cell_type": "code", - "execution_count": 108, + "execution_count": 115, "metadata": {}, "outputs": [], "source": [ @@ -1752,7 +1918,7 @@ }, { "cell_type": "code", - "execution_count": 109, + "execution_count": 116, "metadata": {}, "outputs": [], "source": [ @@ -1769,7 +1935,7 @@ }, { "cell_type": "code", - "execution_count": 110, + "execution_count": 117, "metadata": {}, "outputs": [], "source": [ @@ -1786,7 +1952,7 @@ }, { "cell_type": "code", - "execution_count": 111, + "execution_count": 118, "metadata": {}, "outputs": [], "source": [ @@ -1810,7 +1976,7 @@ }, { "cell_type": "code", - "execution_count": 112, + "execution_count": 119, "metadata": {}, "outputs": [], "source": [ @@ -1837,7 +2003,7 @@ }, { "cell_type": "code", - "execution_count": 113, + "execution_count": 120, "metadata": {}, "outputs": [], "source": [ @@ -1871,7 +2037,7 @@ }, { "cell_type": "code", - "execution_count": 114, + "execution_count": 121, "metadata": {}, "outputs": [], "source": [ @@ -1881,7 +2047,7 @@ }, { "cell_type": "code", - "execution_count": 115, + "execution_count": 122, "metadata": {}, "outputs": [], "source": [ @@ -1915,7 +2081,7 @@ }, { "cell_type": "code", - "execution_count": 116, + "execution_count": 123, 
"metadata": {}, "outputs": [], "source": [ @@ -1942,7 +2108,7 @@ }, { "cell_type": "code", - "execution_count": 117, + "execution_count": 124, "metadata": {}, "outputs": [], "source": [ @@ -1958,7 +2124,7 @@ }, { "cell_type": "code", - "execution_count": 118, + "execution_count": 125, "metadata": {}, "outputs": [], "source": [ @@ -1968,7 +2134,7 @@ }, { "cell_type": "code", - "execution_count": 119, + "execution_count": 126, "metadata": {}, "outputs": [], "source": [ @@ -1977,7 +2143,7 @@ }, { "cell_type": "code", - "execution_count": 120, + "execution_count": 127, "metadata": {}, "outputs": [], "source": [ @@ -1993,7 +2159,7 @@ }, { "cell_type": "code", - "execution_count": 121, + "execution_count": 128, "metadata": {}, "outputs": [], "source": [ @@ -2007,7 +2173,7 @@ }, { "cell_type": "code", - "execution_count": 122, + "execution_count": 129, "metadata": {}, "outputs": [], "source": [ @@ -2024,7 +2190,7 @@ }, { "cell_type": "code", - "execution_count": 123, + "execution_count": 130, "metadata": {}, "outputs": [], "source": [ @@ -2033,7 +2199,7 @@ }, { "cell_type": "code", - "execution_count": 124, + "execution_count": 131, "metadata": {}, "outputs": [], "source": [ @@ -2049,7 +2215,7 @@ }, { "cell_type": "code", - "execution_count": 125, + "execution_count": 132, "metadata": {}, "outputs": [], "source": [ @@ -2068,7 +2234,7 @@ }, { "cell_type": "code", - "execution_count": 126, + "execution_count": 133, "metadata": {}, "outputs": [], "source": [ @@ -2084,7 +2250,7 @@ }, { "cell_type": "code", - "execution_count": 127, + "execution_count": 134, "metadata": {}, "outputs": [], "source": [ @@ -2093,7 +2259,7 @@ }, { "cell_type": "code", - "execution_count": 128, + "execution_count": 135, "metadata": {}, "outputs": [], "source": [ @@ -2116,7 +2282,7 @@ }, { "cell_type": "code", - "execution_count": 129, + "execution_count": 136, "metadata": {}, "outputs": [], "source": [ @@ -2133,7 +2299,7 @@ }, { "cell_type": "code", - "execution_count": 130, + "execution_count": 
137, "metadata": {}, "outputs": [], "source": [ @@ -2149,7 +2315,7 @@ }, { "cell_type": "code", - "execution_count": 131, + "execution_count": 138, "metadata": {}, "outputs": [], "source": [ @@ -2172,7 +2338,7 @@ }, { "cell_type": "code", - "execution_count": 132, + "execution_count": 139, "metadata": {}, "outputs": [], "source": [ @@ -2196,7 +2362,7 @@ }, { "cell_type": "code", - "execution_count": 133, + "execution_count": 140, "metadata": {}, "outputs": [], "source": [ @@ -2215,7 +2381,7 @@ }, { "cell_type": "code", - "execution_count": 134, + "execution_count": 141, "metadata": {}, "outputs": [], "source": [ @@ -2231,7 +2397,7 @@ }, { "cell_type": "code", - "execution_count": 135, + "execution_count": 142, "metadata": {}, "outputs": [], "source": [ @@ -2255,7 +2421,7 @@ }, { "cell_type": "code", - "execution_count": 136, + "execution_count": 143, "metadata": {}, "outputs": [], "source": [ @@ -2271,7 +2437,7 @@ }, { "cell_type": "code", - "execution_count": 137, + "execution_count": 144, "metadata": {}, "outputs": [], "source": [ @@ -2292,7 +2458,7 @@ }, { "cell_type": "code", - "execution_count": 138, + "execution_count": 145, "metadata": {}, "outputs": [], "source": [ @@ -2314,7 +2480,7 @@ }, { "cell_type": "code", - "execution_count": 139, + "execution_count": 146, "metadata": {}, "outputs": [], "source": [ @@ -2335,7 +2501,7 @@ }, { "cell_type": "code", - "execution_count": 140, + "execution_count": 147, "metadata": {}, "outputs": [], "source": [ @@ -2388,7 +2554,7 @@ }, { "cell_type": "code", - "execution_count": 141, + "execution_count": 148, "metadata": {}, "outputs": [], "source": [ @@ -2413,7 +2579,7 @@ }, { "cell_type": "code", - "execution_count": 142, + "execution_count": 149, "metadata": {}, "outputs": [], "source": [ @@ -2446,7 +2612,7 @@ }, { "cell_type": "code", - "execution_count": 143, + "execution_count": 150, "metadata": {}, "outputs": [], "source": [ @@ -2457,7 +2623,7 @@ }, { "cell_type": "code", - "execution_count": 144, + 
"execution_count": 151, "metadata": {}, "outputs": [], "source": [ @@ -2473,7 +2639,7 @@ }, { "cell_type": "code", - "execution_count": 145, + "execution_count": 152, "metadata": {}, "outputs": [], "source": [ @@ -2489,7 +2655,7 @@ }, { "cell_type": "code", - "execution_count": 146, + "execution_count": 153, "metadata": {}, "outputs": [], "source": [ @@ -2505,7 +2671,7 @@ }, { "cell_type": "code", - "execution_count": 147, + "execution_count": 154, "metadata": {}, "outputs": [], "source": [ @@ -2528,7 +2694,7 @@ }, { "cell_type": "code", - "execution_count": 148, + "execution_count": 155, "metadata": {}, "outputs": [], "source": [