Replace SGD with SVC in OvA vs OvO section
parent
efa16f02ec
commit
f680e49ea2
|
@ -124,7 +124,7 @@
|
||||||
"\n",
|
"\n",
|
||||||
"some_digit = X[0]\n",
|
"some_digit = X[0]\n",
|
||||||
"some_digit_image = some_digit.reshape(28, 28)\n",
|
"some_digit_image = some_digit.reshape(28, 28)\n",
|
||||||
"plt.imshow(some_digit_image, cmap = mpl.cm.binary, interpolation=\"nearest\")\n",
|
"plt.imshow(some_digit_image, cmap=mpl.cm.binary)\n",
|
||||||
"plt.axis(\"off\")\n",
|
"plt.axis(\"off\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"save_fig(\"some_digit_plot\")\n",
|
"save_fig(\"some_digit_plot\")\n",
|
||||||
|
@ -713,8 +713,11 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"sgd_clf.fit(X_train, y_train) # y_train, not y_train_5\n",
|
"from sklearn.svm import SVC\n",
|
||||||
"sgd_clf.predict([some_digit])"
|
"\n",
|
||||||
|
"svm_clf = SVC(gamma=\"auto\", random_state=42)\n",
|
||||||
|
"svm_clf.fit(X_train[:1000], y_train[:1000]) # y_train, not y_train_5\n",
|
||||||
|
"svm_clf.predict([some_digit])"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -723,7 +726,7 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"some_digit_scores = sgd_clf.decision_function([some_digit])\n",
|
"some_digit_scores = svm_clf.decision_function([some_digit])\n",
|
||||||
"some_digit_scores"
|
"some_digit_scores"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -742,7 +745,7 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"sgd_clf.classes_"
|
"svm_clf.classes_"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -751,7 +754,7 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"sgd_clf.classes_[5]"
|
"svm_clf.classes_[5]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -760,10 +763,10 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from sklearn.multiclass import OneVsOneClassifier\n",
|
"from sklearn.multiclass import OneVsRestClassifier\n",
|
||||||
"ovo_clf = OneVsOneClassifier(SGDClassifier(max_iter=1000, tol=1e-3, random_state=42))\n",
|
"ovr_clf = OneVsRestClassifier(SVC(gamma=\"auto\", random_state=42))\n",
|
||||||
"ovo_clf.fit(X_train, y_train)\n",
|
"ovr_clf.fit(X_train[:1000], y_train[:1000])\n",
|
||||||
"ovo_clf.predict([some_digit])"
|
"ovr_clf.predict([some_digit])"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -772,7 +775,7 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"len(ovo_clf.estimators_)"
|
"len(ovr_clf.estimators_)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -781,8 +784,8 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"forest_clf.fit(X_train, y_train)\n",
|
"sgd_clf.fit(X_train, y_train)\n",
|
||||||
"forest_clf.predict([some_digit])"
|
"sgd_clf.predict([some_digit])"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -791,7 +794,7 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"forest_clf.predict_proba([some_digit])"
|
"sgd_clf.decision_function([some_digit])"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
Loading…
Reference in New Issue