diff --git a/03_classification.ipynb b/03_classification.ipynb index ba65a7e..832dbde 100644 --- a/03_classification.ipynb +++ b/03_classification.ipynb @@ -131,7 +131,7 @@ "source": [ "from sklearn.datasets import fetch_openml\n", "\n", - "mnist = fetch_openml('mnist_784', version=1, as_frame=False)" + "mnist = fetch_openml('mnist_784', as_frame=False)" ] }, { @@ -1254,9 +1254,8 @@ "outputs": [], "source": [ "# not in the book\n", - "some_index = 0\n", - "plt.subplot(121); plot_digit(X_test_mod[some_index])\n", - "plt.subplot(122); plot_digit(y_test_mod[some_index])\n", + "plt.subplot(121); plot_digit(X_test_mod[0])\n", + "plt.subplot(122); plot_digit(y_test_mod[0])\n", "save_fig(\"noisy_digit_example_plot\")\n", "plt.show()" ] @@ -1267,8 +1266,9 @@ "metadata": {}, "outputs": [], "source": [ + "knn_clf = KNeighborsClassifier()\n", "knn_clf.fit(X_train_mod, y_train_mod)\n", - "clean_digit = knn_clf.predict([X_test_mod[some_index]])\n", + "clean_digit = knn_clf.predict([X_test_mod[0]])\n", "plot_digit(clean_digit)\n", "save_fig(\"cleaned_digit_example_plot\") # not in the book\n", "plt.show()"