Replace SGD with SVC in OvA vs OvO section

main
Aurélien Geron 2019-05-28 15:21:49 +08:00
parent efa16f02ec
commit f680e49ea2
1 changed files with 17 additions and 14 deletions

View File

@ -124,7 +124,7 @@
"\n", "\n",
"some_digit = X[0]\n", "some_digit = X[0]\n",
"some_digit_image = some_digit.reshape(28, 28)\n", "some_digit_image = some_digit.reshape(28, 28)\n",
"plt.imshow(some_digit_image, cmap = mpl.cm.binary, interpolation=\"nearest\")\n", "plt.imshow(some_digit_image, cmap=mpl.cm.binary)\n",
"plt.axis(\"off\")\n", "plt.axis(\"off\")\n",
"\n", "\n",
"save_fig(\"some_digit_plot\")\n", "save_fig(\"some_digit_plot\")\n",
@ -713,8 +713,11 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"sgd_clf.fit(X_train, y_train) # y_train, not y_train_5\n", "from sklearn.svm import SVC\n",
"sgd_clf.predict([some_digit])" "\n",
"svm_clf = SVC(gamma=\"auto\", random_state=42)\n",
"svm_clf.fit(X_train[:1000], y_train[:1000]) # y_train, not y_train_5\n",
"svm_clf.predict([some_digit])"
] ]
}, },
{ {
@ -723,7 +726,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"some_digit_scores = sgd_clf.decision_function([some_digit])\n", "some_digit_scores = svm_clf.decision_function([some_digit])\n",
"some_digit_scores" "some_digit_scores"
] ]
}, },
@ -742,7 +745,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"sgd_clf.classes_" "svm_clf.classes_"
] ]
}, },
{ {
@ -751,7 +754,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"sgd_clf.classes_[5]" "svm_clf.classes_[5]"
] ]
}, },
{ {
@ -760,10 +763,10 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from sklearn.multiclass import OneVsOneClassifier\n", "from sklearn.multiclass import OneVsRestClassifier\n",
"ovo_clf = OneVsOneClassifier(SGDClassifier(max_iter=1000, tol=1e-3, random_state=42))\n", "ovr_clf = OneVsRestClassifier(SVC(gamma=\"auto\", random_state=42))\n",
"ovo_clf.fit(X_train, y_train)\n", "ovr_clf.fit(X_train[:1000], y_train[:1000])\n",
"ovo_clf.predict([some_digit])" "ovr_clf.predict([some_digit])"
] ]
}, },
{ {
@ -772,7 +775,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"len(ovo_clf.estimators_)" "len(ovr_clf.estimators_)"
] ]
}, },
{ {
@ -781,8 +784,8 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"forest_clf.fit(X_train, y_train)\n", "sgd_clf.fit(X_train, y_train)\n",
"forest_clf.predict([some_digit])" "sgd_clf.predict([some_digit])"
] ]
}, },
{ {
@ -791,7 +794,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"forest_clf.predict_proba([some_digit])" "sgd_clf.decision_function([some_digit])"
] ]
}, },
{ {