Merge pull request #15 from vi3itor/ch5-errata

Chapter 5: Fix typos and remove unused args in plt.grid()
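
Why dropping `which='both'` is likely safe: a minimal sketch, assuming Matplotlib's default rcParams and linear axes (as in the notebook's plots). The names `ax_default` and `ax_both` are hypothetical and used only for illustration. On a linear axis, minor ticks are off by default, so asking for gridlines on "both" major and minor ticks should have nothing extra to draw.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig, (ax_default, ax_both) = plt.subplots(1, 2, figsize=(10, 3))

ax_default.plot([-4, 4], [0, 0])
ax_default.grid(True)             # gridlines on major ticks only (default which='major')

ax_both.plot([-4, 4], [0, 0])
ax_both.grid(True, which="both")  # additionally requests gridlines on minor ticks

# With default settings a linear axis has no minor tick locations,
# so the extra request in which="both" has nothing to act on.
minor_default = ax_default.xaxis.get_minorticklocs()
minor_both = ax_both.xaxis.get_minorticklocs()
print(len(minor_default), len(minor_both))

plt.close(fig)
```

If minor ticks were enabled (for example with `ax.minorticks_on()`) or a log scale were used, the two calls would no longer be equivalent, so the argument is only redundant under the assumptions above.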
main
Aurélien Geron 2022-05-31 22:50:55 +12:00 committed by GitHub
commit 4a06d1ab77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 10 additions and 11 deletions

@@ -11,7 +11,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"_This notebook is an extra chapter on Support Vector Machines. It also includes exercises and their solutions at the end._" "_This notebook contains all the sample code and solutions to the exercises in chapter 5._"
] ]
}, },
{ {
@@ -540,7 +540,7 @@
"plt.figure(figsize=(10, 3))\n", "plt.figure(figsize=(10, 3))\n",
"\n", "\n",
"plt.subplot(121)\n", "plt.subplot(121)\n",
"plt.grid(True, which='both')\n", "plt.grid(True)\n",
"plt.axhline(y=0, color='k')\n", "plt.axhline(y=0, color='k')\n",
"plt.plot(X1D[:, 0][y==0], np.zeros(4), \"bs\")\n", "plt.plot(X1D[:, 0][y==0], np.zeros(4), \"bs\")\n",
"plt.plot(X1D[:, 0][y==1], np.zeros(5), \"g^\")\n", "plt.plot(X1D[:, 0][y==1], np.zeros(5), \"g^\")\n",
@@ -549,7 +549,7 @@
"plt.axis([-4.5, 4.5, -0.2, 0.2])\n", "plt.axis([-4.5, 4.5, -0.2, 0.2])\n",
"\n", "\n",
"plt.subplot(122)\n", "plt.subplot(122)\n",
"plt.grid(True, which='both')\n", "plt.grid(True)\n",
"plt.axhline(y=0, color='k')\n", "plt.axhline(y=0, color='k')\n",
"plt.axvline(x=0, color='k')\n", "plt.axvline(x=0, color='k')\n",
"plt.plot(X2D[:, 0][y==0], X2D[:, 1][y==0], \"bs\")\n", "plt.plot(X2D[:, 0][y==0], X2D[:, 1][y==0], \"bs\")\n",
@@ -624,7 +624,7 @@
" plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"bs\")\n", " plt.plot(X[:, 0][y==0], X[:, 1][y==0], \"bs\")\n",
" plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"g^\")\n", " plt.plot(X[:, 0][y==1], X[:, 1][y==1], \"g^\")\n",
" plt.axis(axes)\n", " plt.axis(axes)\n",
" plt.grid(True, which='both')\n", " plt.grid(True)\n",
" plt.xlabel(\"$x_1$\")\n", " plt.xlabel(\"$x_1$\")\n",
" plt.ylabel(\"$x_2$\", rotation=0)\n", " plt.ylabel(\"$x_2$\", rotation=0)\n",
"\n", "\n",
@@ -766,7 +766,7 @@
"plt.figure(figsize=(10.5, 4))\n", "plt.figure(figsize=(10.5, 4))\n",
"\n", "\n",
"plt.subplot(121)\n", "plt.subplot(121)\n",
"plt.grid(True, which='both')\n", "plt.grid(True)\n",
"plt.axhline(y=0, color='k')\n", "plt.axhline(y=0, color='k')\n",
"plt.scatter(x=[-2, 1], y=[0, 0], s=150, alpha=0.5, c=\"red\")\n", "plt.scatter(x=[-2, 1], y=[0, 0], s=150, alpha=0.5, c=\"red\")\n",
"plt.plot(X1D[:, 0][yk==0], np.zeros(4), \"bs\")\n", "plt.plot(X1D[:, 0][yk==0], np.zeros(4), \"bs\")\n",
@@ -789,7 +789,7 @@
"plt.axis([-4.5, 4.5, -0.1, 1.1])\n", "plt.axis([-4.5, 4.5, -0.1, 1.1])\n",
"\n", "\n",
"plt.subplot(122)\n", "plt.subplot(122)\n",
"plt.grid(True, which='both')\n", "plt.grid(True)\n",
"plt.axhline(y=0, color='k')\n", "plt.axhline(y=0, color='k')\n",
"plt.axvline(x=0, color='k')\n", "plt.axvline(x=0, color='k')\n",
"plt.plot(XK[:, 0][yk==0], XK[:, 1][yk==0], \"bs\")\n", "plt.plot(XK[:, 0][yk==0], XK[:, 1][yk==0], \"bs\")\n",
@@ -1185,7 +1185,7 @@
" axs, (hinge_pos, hinge_pos ** 2), (hinge_neg, hinge_neg ** 2), titles):\n", " axs, (hinge_pos, hinge_pos ** 2), (hinge_neg, hinge_neg ** 2), titles):\n",
" ax.plot(s, loss_pos, \"g-\", linewidth=2, zorder=10, label=\"$t=1$\")\n", " ax.plot(s, loss_pos, \"g-\", linewidth=2, zorder=10, label=\"$t=1$\")\n",
" ax.plot(s, loss_neg, \"r--\", linewidth=2, zorder=10, label=\"$t=-1$\")\n", " ax.plot(s, loss_neg, \"r--\", linewidth=2, zorder=10, label=\"$t=-1$\")\n",
" ax.grid(True, which='both')\n", " ax.grid(True)\n",
" ax.axhline(y=0, color='k')\n", " ax.axhline(y=0, color='k')\n",
" ax.axvline(x=0, color='k')\n", " ax.axvline(x=0, color='k')\n",
" ax.set_xlabel(r\"$s = \\mathbf{w}^\\intercal \\mathbf{x} + b$\")\n", " ax.set_xlabel(r\"$s = \\mathbf{w}^\\intercal \\mathbf{x} + b$\")\n",
@@ -1250,10 +1250,9 @@
" w = np.random.randn(X.shape[1], 1) # n feature weights\n", " w = np.random.randn(X.shape[1], 1) # n feature weights\n",
" b = 0\n", " b = 0\n",
"\n", "\n",
" m = len(X)\n",
" t = np.array(y, dtype=np.float64).reshape(-1, 1) * 2 - 1\n", " t = np.array(y, dtype=np.float64).reshape(-1, 1) * 2 - 1\n",
" X_t = X * t\n", " X_t = X * t\n",
" self.Js=[]\n", " self.Js = []\n",
"\n", "\n",
" # Training\n", " # Training\n",
" for epoch in range(self.n_epochs):\n", " for epoch in range(self.n_epochs):\n",
@@ -2249,7 +2248,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"This tuned kernelized SVM performs better than the `LinearSVC` model, but we get a lower score on the test set than we measured using cross-validation. This is quite common: since we did so much hyperparameter tuning, we ended up slightly overfitting the cross-validation test sets. It's tempting to tweak the hyperparameters a bit more until we get a better result on the test set, but we this would probably not help, as we would just start overfitting the test set. Anyway, this score is not bad at all, so let's stop here." "This tuned kernelized SVM performs better than the `LinearSVC` model, but we get a lower score on the test set than we measured using cross-validation. This is quite common: since we did so much hyperparameter tuning, we ended up slightly overfitting the cross-validation test sets. It's tempting to tweak the hyperparameters a bit more until we get a better result on the test set, but this would probably not help, as we would just start overfitting the test set. Anyway, this score is not bad at all, so let's stop here."
] ]
}, },
{ {
@@ -2309,7 +2308,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"Don't forget to scale the data:" "Don't forget to scale the data!"
] ]
}, },
{ {