From d4d15a9c5901554cfcc864626f54a6df7aaa53db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Thu, 9 May 2019 16:24:17 +0800 Subject: [PATCH] Clarify figure 4-19 --- 04_training_linear_models.ipynb | 64 ++++++++++++++++----------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/04_training_linear_models.ipynb b/04_training_linear_models.ipynb index 8297501..5d84aa7 100644 --- a/04_training_linear_models.ipynb +++ b/04_training_linear_models.ipynb @@ -935,12 +935,11 @@ "source": [ "t1a, t1b, t2a, t2b = -1, 3, -1.5, 1.5\n", "\n", - "# ignoring bias term\n", "t1s = np.linspace(t1a, t1b, 500)\n", "t2s = np.linspace(t2a, t2b, 500)\n", "t1, t2 = np.meshgrid(t1s, t2s)\n", "T = np.c_[t1.ravel(), t2.ravel()]\n", - "Xr = np.array([[-1, 1], [-0.3, -1], [1, 0.1]])\n", + "Xr = np.array([[1, 1], [1, -1], [1, 0.5]])\n", "yr = 2 * Xr[:, :1] + 0.5 * Xr[:, 1:]\n", "\n", "J = (1/len(Xr) * np.sum((T.dot(Xr.T) - yr.T)**2, axis=1)).reshape(t1.shape)\n", @@ -960,18 +959,17 @@ "metadata": {}, "outputs": [], "source": [ - "def bgd_path(theta, X, y, l1, l2, core = 1, eta = 0.1, n_iterations = 50):\n", + "def bgd_path(theta, X, y, l1, l2, core = 1, eta = 0.05, n_iterations = 200):\n", " path = [theta]\n", " for iteration in range(n_iterations):\n", - " gradients = core * 2/len(X) * X.T.dot(X.dot(theta) - y) + l1 * np.sign(theta) + 2 * l2 * theta\n", - "\n", + " gradients = core * 2/len(X) * X.T.dot(X.dot(theta) - y) + l1 * np.sign(theta) + l2 * theta\n", " theta = theta - eta * gradients\n", " path.append(theta)\n", " return np.array(path)\n", "\n", - "plt.figure(figsize=(12, 8))\n", - "for i, N, l1, l2, title in ((0, N1, 0.5, 0, \"Lasso\"), (1, N2, 0, 0.1, \"Ridge\")):\n", - " JR = J + l1 * N1 + l2 * N2**2\n", + "fig, axes = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(10.1, 8))\n", + "for i, N, l1, l2, title in ((0, N1, 2., 0, \"Lasso\"), (1, N2, 0, 2., \"Ridge\")):\n", + " JR = J + l1 * N1 + l2 * 0.5 * N2**2\n", " \n", " tr_min_idx = np.unravel_index(np.argmin(JR), JR.shape)\n", " t1r_min, t2r_min = t1[tr_min_idx], t2[tr_min_idx]\n", @@ -982,34 +980,36 @@ " \n", " path_J = bgd_path(t_init, Xr, yr, l1=0, l2=0)\n", " path_JR = bgd_path(t_init, Xr, yr, l1, l2)\n", - " path_N = bgd_path(t_init, Xr, yr, np.sign(l1)/3, np.sign(l2), core=0)\n", + " path_N = bgd_path(np.array([[2.0], [0.5]]), Xr, yr, np.sign(l1)/3, np.sign(l2), core=0)\n", "\n", - " plt.subplot(221 + i * 2)\n", - " plt.grid(True)\n", - " plt.axhline(y=0, color='k')\n", - " plt.axvline(x=0, color='k')\n", - " plt.contourf(t1, t2, J, levels=levelsJ, alpha=0.9)\n", - " plt.contour(t1, t2, N, levels=levelsN)\n", - " plt.plot(path_J[:, 0], path_J[:, 1], \"w-o\")\n", - " plt.plot(path_N[:, 0], path_N[:, 1], \"y-^\")\n", - " plt.plot(t1_min, t2_min, \"rs\")\n", - " plt.title(r\"$\\ell_{}$ penalty\".format(i + 1), fontsize=16)\n", - " plt.axis([t1a, t1b, t2a, t2b])\n", + " ax = axes[i, 0]\n", + " ax.grid(True)\n", + " ax.axhline(y=0, color='k')\n", + " ax.axvline(x=0, color='k')\n", + " ax.contourf(t1, t2, N / 2., levels=levelsN)\n", + " ax.plot(path_N[:, 0], path_N[:, 1], \"y--\")\n", + " ax.plot(0, 0, \"ys\")\n", + " ax.plot(t1_min, t2_min, \"ys\")\n", + " ax.set_title(r\"$\\ell_{}$ penalty\".format(i + 1), fontsize=16)\n", + " ax.axis([t1a, t1b, t2a, t2b])\n", " if i == 1:\n", - " plt.xlabel(r\"$\\theta_1$\", fontsize=20)\n", - " plt.ylabel(r\"$\\theta_2$\", fontsize=20, rotation=0)\n", + " ax.set_xlabel(r\"$\\theta_1$\", fontsize=16)\n", + " ax.set_ylabel(r\"$\\theta_2$\", fontsize=16, rotation=0)\n", "\n", - " plt.subplot(222 + i * 2)\n", - " plt.grid(True)\n", - " plt.axhline(y=0, color='k')\n", - " plt.axvline(x=0, color='k')\n", - " plt.contourf(t1, t2, JR, levels=levelsJR, alpha=0.9)\n", - " plt.plot(path_JR[:, 0], path_JR[:, 1], \"w-o\")\n", - " plt.plot(t1r_min, t2r_min, \"rs\")\n", - " plt.title(title, fontsize=16)\n", - " plt.axis([t1a, t1b, t2a, t2b])\n", + " ax = axes[i, 1]\n", + " ax.grid(True)\n", + " ax.axhline(y=0, color='k')\n", + " ax.axvline(x=0, color='k')\n", + " ax.contourf(t1, t2, JR, levels=levelsJR, alpha=0.9)\n", + " ax.plot(path_JR[:, 0], path_JR[:, 1], \"w-o\")\n", + " ax.plot(path_N[:, 0], path_N[:, 1], \"y--\")\n", + " ax.plot(0, 0, \"ys\")\n", + " ax.plot(t1_min, t2_min, \"ys\")\n", + " ax.plot(t1r_min, t2r_min, \"rs\")\n", + " ax.set_title(title, fontsize=16)\n", + " ax.axis([t1a, t1b, t2a, t2b])\n", " if i == 1:\n", - " plt.xlabel(r\"$\\theta_1$\", fontsize=20)\n", + " ax.set_xlabel(r\"$\\theta_1$\", fontsize=16)\n", "\n", "save_fig(\"lasso_vs_ridge_plot\")\n", "plt.show()"