From f680e49ea2146d056bbfc83b7b0704880ce0ee57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Tue, 28 May 2019 15:21:49 +0800 Subject: [PATCH] Replace SGD with SVC in OvA vs OvO section --- 03_classification.ipynb | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/03_classification.ipynb b/03_classification.ipynb index 852d0c1..0e7545d 100644 --- a/03_classification.ipynb +++ b/03_classification.ipynb @@ -124,7 +124,7 @@ "\n", "some_digit = X[0]\n", "some_digit_image = some_digit.reshape(28, 28)\n", - "plt.imshow(some_digit_image, cmap = mpl.cm.binary, interpolation=\"nearest\")\n", + "plt.imshow(some_digit_image, cmap=mpl.cm.binary)\n", "plt.axis(\"off\")\n", "\n", "save_fig(\"some_digit_plot\")\n", @@ -713,8 +713,11 @@ "metadata": {}, "outputs": [], "source": [ - "sgd_clf.fit(X_train, y_train) # y_train, not y_train_5\n", - "sgd_clf.predict([some_digit])" + "from sklearn.svm import SVC\n", + "\n", + "svm_clf = SVC(gamma=\"auto\", random_state=42)\n", + "svm_clf.fit(X_train[:1000], y_train[:1000]) # y_train, not y_train_5\n", + "svm_clf.predict([some_digit])" ] }, { @@ -723,7 +726,7 @@ "metadata": {}, "outputs": [], "source": [ - "some_digit_scores = sgd_clf.decision_function([some_digit])\n", + "some_digit_scores = svm_clf.decision_function([some_digit])\n", "some_digit_scores" ] }, @@ -742,7 +745,7 @@ "metadata": {}, "outputs": [], "source": [ - "sgd_clf.classes_" + "svm_clf.classes_" ] }, { @@ -751,7 +754,7 @@ "metadata": {}, "outputs": [], "source": [ - "sgd_clf.classes_[5]" + "svm_clf.classes_[5]" ] }, { @@ -760,10 +763,10 @@ "metadata": {}, "outputs": [], "source": [ - "from sklearn.multiclass import OneVsOneClassifier\n", - "ovo_clf = OneVsOneClassifier(SGDClassifier(max_iter=1000, tol=1e-3, random_state=42))\n", - "ovo_clf.fit(X_train, y_train)\n", - "ovo_clf.predict([some_digit])" + "from sklearn.multiclass import OneVsRestClassifier\n", + "ovr_clf = OneVsRestClassifier(SVC(gamma=\"auto\", random_state=42))\n", + "ovr_clf.fit(X_train[:1000], y_train[:1000])\n", + "ovr_clf.predict([some_digit])" ] }, { @@ -772,7 +775,7 @@ "metadata": {}, "outputs": [], "source": [ - "len(ovo_clf.estimators_)" + "len(ovr_clf.estimators_)" ] }, { @@ -781,8 +784,8 @@ "metadata": {}, "outputs": [], "source": [ - "forest_clf.fit(X_train, y_train)\n", - "forest_clf.predict([some_digit])" + "sgd_clf.fit(X_train, y_train)\n", + "sgd_clf.predict([some_digit])" ] }, { @@ -791,7 +794,7 @@ "metadata": {}, "outputs": [], "source": [ - "forest_clf.predict_proba([some_digit])" + "sgd_clf.decision_function([some_digit])" ] }, {