Clarify figure 4-19

main
Aurélien Geron 2019-05-09 16:24:17 +08:00
parent 107b90d9b3
commit d4d15a9c59
1 changed files with 32 additions and 32 deletions

View File

@ -935,12 +935,11 @@
"source": [
"t1a, t1b, t2a, t2b = -1, 3, -1.5, 1.5\n",
"\n",
"# ignoring bias term\n",
"t1s = np.linspace(t1a, t1b, 500)\n",
"t2s = np.linspace(t2a, t2b, 500)\n",
"t1, t2 = np.meshgrid(t1s, t2s)\n",
"T = np.c_[t1.ravel(), t2.ravel()]\n",
"Xr = np.array([[-1, 1], [-0.3, -1], [1, 0.1]])\n",
"Xr = np.array([[1, 1], [1, -1], [1, 0.5]])\n",
"yr = 2 * Xr[:, :1] + 0.5 * Xr[:, 1:]\n",
"\n",
"J = (1/len(Xr) * np.sum((T.dot(Xr.T) - yr.T)**2, axis=1)).reshape(t1.shape)\n",
@ -960,18 +959,17 @@
"metadata": {},
"outputs": [],
"source": [
"def bgd_path(theta, X, y, l1, l2, core = 1, eta = 0.1, n_iterations = 50):\n",
"def bgd_path(theta, X, y, l1, l2, core = 1, eta = 0.05, n_iterations = 200):\n",
" path = [theta]\n",
" for iteration in range(n_iterations):\n",
" gradients = core * 2/len(X) * X.T.dot(X.dot(theta) - y) + l1 * np.sign(theta) + 2 * l2 * theta\n",
"\n",
" gradients = core * 2/len(X) * X.T.dot(X.dot(theta) - y) + l1 * np.sign(theta) + l2 * theta\n",
" theta = theta - eta * gradients\n",
" path.append(theta)\n",
" return np.array(path)\n",
"\n",
"plt.figure(figsize=(12, 8))\n",
"for i, N, l1, l2, title in ((0, N1, 0.5, 0, \"Lasso\"), (1, N2, 0, 0.1, \"Ridge\")):\n",
" JR = J + l1 * N1 + l2 * N2**2\n",
"fig, axes = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(10.1, 8))\n",
"for i, N, l1, l2, title in ((0, N1, 2., 0, \"Lasso\"), (1, N2, 0, 2., \"Ridge\")):\n",
" JR = J + l1 * N1 + l2 * 0.5 * N2**2\n",
" \n",
" tr_min_idx = np.unravel_index(np.argmin(JR), JR.shape)\n",
" t1r_min, t2r_min = t1[tr_min_idx], t2[tr_min_idx]\n",
@ -982,34 +980,36 @@
" \n",
" path_J = bgd_path(t_init, Xr, yr, l1=0, l2=0)\n",
" path_JR = bgd_path(t_init, Xr, yr, l1, l2)\n",
" path_N = bgd_path(t_init, Xr, yr, np.sign(l1)/3, np.sign(l2), core=0)\n",
" path_N = bgd_path(np.array([[2.0], [0.5]]), Xr, yr, np.sign(l1)/3, np.sign(l2), core=0)\n",
"\n",
" plt.subplot(221 + i * 2)\n",
" plt.grid(True)\n",
" plt.axhline(y=0, color='k')\n",
" plt.axvline(x=0, color='k')\n",
" plt.contourf(t1, t2, J, levels=levelsJ, alpha=0.9)\n",
" plt.contour(t1, t2, N, levels=levelsN)\n",
" plt.plot(path_J[:, 0], path_J[:, 1], \"w-o\")\n",
" plt.plot(path_N[:, 0], path_N[:, 1], \"y-^\")\n",
" plt.plot(t1_min, t2_min, \"rs\")\n",
" plt.title(r\"$\\ell_{}$ penalty\".format(i + 1), fontsize=16)\n",
" plt.axis([t1a, t1b, t2a, t2b])\n",
" ax = axes[i, 0]\n",
" ax.grid(True)\n",
" ax.axhline(y=0, color='k')\n",
" ax.axvline(x=0, color='k')\n",
" ax.contourf(t1, t2, N / 2., levels=levelsN)\n",
" ax.plot(path_N[:, 0], path_N[:, 1], \"y--\")\n",
" ax.plot(0, 0, \"ys\")\n",
" ax.plot(t1_min, t2_min, \"ys\")\n",
" ax.set_title(r\"$\\ell_{}$ penalty\".format(i + 1), fontsize=16)\n",
" ax.axis([t1a, t1b, t2a, t2b])\n",
" if i == 1:\n",
" plt.xlabel(r\"$\\theta_1$\", fontsize=20)\n",
" plt.ylabel(r\"$\\theta_2$\", fontsize=20, rotation=0)\n",
" ax.set_xlabel(r\"$\\theta_1$\", fontsize=16)\n",
" ax.set_ylabel(r\"$\\theta_2$\", fontsize=16, rotation=0)\n",
"\n",
" plt.subplot(222 + i * 2)\n",
" plt.grid(True)\n",
" plt.axhline(y=0, color='k')\n",
" plt.axvline(x=0, color='k')\n",
" plt.contourf(t1, t2, JR, levels=levelsJR, alpha=0.9)\n",
" plt.plot(path_JR[:, 0], path_JR[:, 1], \"w-o\")\n",
" plt.plot(t1r_min, t2r_min, \"rs\")\n",
" plt.title(title, fontsize=16)\n",
" plt.axis([t1a, t1b, t2a, t2b])\n",
" ax = axes[i, 1]\n",
" ax.grid(True)\n",
" ax.axhline(y=0, color='k')\n",
" ax.axvline(x=0, color='k')\n",
" ax.contourf(t1, t2, JR, levels=levelsJR, alpha=0.9)\n",
" ax.plot(path_JR[:, 0], path_JR[:, 1], \"w-o\")\n",
" ax.plot(path_N[:, 0], path_N[:, 1], \"y--\")\n",
" ax.plot(0, 0, \"ys\")\n",
" ax.plot(t1_min, t2_min, \"ys\")\n",
" ax.plot(t1r_min, t2r_min, \"rs\")\n",
" ax.set_title(title, fontsize=16)\n",
" ax.axis([t1a, t1b, t2a, t2b])\n",
" if i == 1:\n",
" plt.xlabel(r\"$\\theta_1$\", fontsize=20)\n",
" ax.set_xlabel(r\"$\\theta_1$\", fontsize=16)\n",
"\n",
"save_fig(\"lasso_vs_ridge_plot\")\n",
"plt.show()"