Replace SGD with SVC in OvA vs OvO section

main
Aurélien Geron 2019-05-28 15:21:49 +08:00
parent efa16f02ec
commit f680e49ea2
1 changed files with 17 additions and 14 deletions

View File

@ -124,7 +124,7 @@
"\n",
"some_digit = X[0]\n",
"some_digit_image = some_digit.reshape(28, 28)\n",
"plt.imshow(some_digit_image, cmap = mpl.cm.binary, interpolation=\"nearest\")\n",
"plt.imshow(some_digit_image, cmap=mpl.cm.binary)\n",
"plt.axis(\"off\")\n",
"\n",
"save_fig(\"some_digit_plot\")\n",
@ -713,8 +713,11 @@
"metadata": {},
"outputs": [],
"source": [
"sgd_clf.fit(X_train, y_train) # y_train, not y_train_5\n",
"sgd_clf.predict([some_digit])"
"from sklearn.svm import SVC\n",
"\n",
"svm_clf = SVC(gamma=\"auto\", random_state=42)\n",
"svm_clf.fit(X_train[:1000], y_train[:1000]) # y_train, not y_train_5\n",
"svm_clf.predict([some_digit])"
]
},
{
@ -723,7 +726,7 @@
"metadata": {},
"outputs": [],
"source": [
"some_digit_scores = sgd_clf.decision_function([some_digit])\n",
"some_digit_scores = svm_clf.decision_function([some_digit])\n",
"some_digit_scores"
]
},
@ -742,7 +745,7 @@
"metadata": {},
"outputs": [],
"source": [
"sgd_clf.classes_"
"svm_clf.classes_"
]
},
{
@ -751,7 +754,7 @@
"metadata": {},
"outputs": [],
"source": [
"sgd_clf.classes_[5]"
"svm_clf.classes_[5]"
]
},
{
@ -760,10 +763,10 @@
"metadata": {},
"outputs": [],
"source": [
"from sklearn.multiclass import OneVsOneClassifier\n",
"ovo_clf = OneVsOneClassifier(SGDClassifier(max_iter=1000, tol=1e-3, random_state=42))\n",
"ovo_clf.fit(X_train, y_train)\n",
"ovo_clf.predict([some_digit])"
"from sklearn.multiclass import OneVsRestClassifier\n",
"ovr_clf = OneVsRestClassifier(SVC(gamma=\"auto\", random_state=42))\n",
"ovr_clf.fit(X_train[:1000], y_train[:1000])\n",
"ovr_clf.predict([some_digit])"
]
},
{
@ -772,7 +775,7 @@
"metadata": {},
"outputs": [],
"source": [
"len(ovo_clf.estimators_)"
"len(ovr_clf.estimators_)"
]
},
{
@ -781,8 +784,8 @@
"metadata": {},
"outputs": [],
"source": [
"forest_clf.fit(X_train, y_train)\n",
"forest_clf.predict([some_digit])"
"sgd_clf.fit(X_train, y_train)\n",
"sgd_clf.predict([some_digit])"
]
},
{
@ -791,7 +794,7 @@
"metadata": {},
"outputs": [],
"source": [
"forest_clf.predict_proba([some_digit])"
"sgd_clf.decision_function([some_digit])"
]
},
{