Replace SGD with SVC in OvA vs OvO section
parent
efa16f02ec
commit
f680e49ea2
|
@ -124,7 +124,7 @@
|
|||
"\n",
|
||||
"some_digit = X[0]\n",
|
||||
"some_digit_image = some_digit.reshape(28, 28)\n",
|
||||
"plt.imshow(some_digit_image, cmap = mpl.cm.binary, interpolation=\"nearest\")\n",
|
||||
"plt.imshow(some_digit_image, cmap=mpl.cm.binary)\n",
|
||||
"plt.axis(\"off\")\n",
|
||||
"\n",
|
||||
"save_fig(\"some_digit_plot\")\n",
|
||||
|
@ -713,8 +713,11 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sgd_clf.fit(X_train, y_train) # y_train, not y_train_5\n",
|
||||
"sgd_clf.predict([some_digit])"
|
||||
"from sklearn.svm import SVC\n",
|
||||
"\n",
|
||||
"svm_clf = SVC(gamma=\"auto\", random_state=42)\n",
|
||||
"svm_clf.fit(X_train[:1000], y_train[:1000]) # y_train, not y_train_5\n",
|
||||
"svm_clf.predict([some_digit])"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -723,7 +726,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"some_digit_scores = sgd_clf.decision_function([some_digit])\n",
|
||||
"some_digit_scores = svm_clf.decision_function([some_digit])\n",
|
||||
"some_digit_scores"
|
||||
]
|
||||
},
|
||||
|
@ -742,7 +745,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sgd_clf.classes_"
|
||||
"svm_clf.classes_"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -751,7 +754,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sgd_clf.classes_[5]"
|
||||
"svm_clf.classes_[5]"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -760,10 +763,10 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sklearn.multiclass import OneVsOneClassifier\n",
|
||||
"ovo_clf = OneVsOneClassifier(SGDClassifier(max_iter=1000, tol=1e-3, random_state=42))\n",
|
||||
"ovo_clf.fit(X_train, y_train)\n",
|
||||
"ovo_clf.predict([some_digit])"
|
||||
"from sklearn.multiclass import OneVsRestClassifier\n",
|
||||
"ovr_clf = OneVsRestClassifier(SVC(gamma=\"auto\", random_state=42))\n",
|
||||
"ovr_clf.fit(X_train[:1000], y_train[:1000])\n",
|
||||
"ovr_clf.predict([some_digit])"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -772,7 +775,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"len(ovo_clf.estimators_)"
|
||||
"len(ovr_clf.estimators_)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -781,8 +784,8 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"forest_clf.fit(X_train, y_train)\n",
|
||||
"forest_clf.predict([some_digit])"
|
||||
"sgd_clf.fit(X_train, y_train)\n",
|
||||
"sgd_clf.predict([some_digit])"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -791,7 +794,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"forest_clf.predict_proba([some_digit])"
|
||||
"sgd_clf.decision_function([some_digit])"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
Loading…
Reference in New Issue