diff --git a/03_classification.ipynb b/03_classification.ipynb index 0e7545d..0717be2 100644 --- a/03_classification.ipynb +++ b/03_classification.ipynb @@ -484,14 +484,20 @@ " plt.grid(True) # Not shown\n", " plt.axis([-50000, 50000, 0, 1]) # Not shown\n", "\n", - "plt.figure(figsize=(8, 4)) # Not shown\n", + "\n", + "\n", + "recall_90_precision = recalls[np.argmax(precisions >= 0.90)]\n", + "threshold_90_precision = thresholds[np.argmax(precisions >= 0.90)]\n", + "\n", + "\n", + "plt.figure(figsize=(8, 4)) # Not shown\n", "plot_precision_recall_vs_threshold(precisions, recalls, thresholds)\n", - "plt.plot([7813, 7813], [0., 0.9], \"r:\") # Not shown\n", - "plt.plot([-50000, 7813], [0.9, 0.9], \"r:\") # Not shown\n", - "plt.plot([-50000, 7813], [0.4368, 0.4368], \"r:\")# Not shown\n", - "plt.plot([7813], [0.9], \"ro\") # Not shown\n", - "plt.plot([7813], [0.4368], \"ro\") # Not shown\n", - "save_fig(\"precision_recall_vs_threshold_plot\") # Not shown\n", + "plt.plot([threshold_90_precision, threshold_90_precision], [0., 0.9], \"r:\") # Not shown\n", + "plt.plot([-50000, threshold_90_precision], [0.9, 0.9], \"r:\") # Not shown\n", + "plt.plot([-50000, threshold_90_precision], [recall_90_precision, recall_90_precision], \"r:\")# Not shown\n", + "plt.plot([threshold_90_precision], [0.9], \"ro\") # Not shown\n", + "plt.plot([threshold_90_precision], [recall_90_precision], \"ro\") # Not shown\n", + "save_fig(\"precision_recall_vs_threshold_plot\") # Not shown\n", "plt.show()" ] },