LinearRegression is based on SVD, not the Normal Equation (fixes #184), also fixes #179 (mini-batch gradient descent), and updates matplotlib code to latest version.

main
Aurélien Geron 2018-03-15 18:38:58 +01:00
parent fb29c3b386
commit d9fbf7dd4c
1 changed files with 136 additions and 117 deletions

View File

@ -61,7 +61,11 @@
" print(\"Saving figure\", fig_id)\n",
" if tight_layout:\n",
" plt.tight_layout()\n",
" plt.savefig(path, format='png', dpi=300)\n"
" plt.savefig(path, format='png', dpi=300)\n",
"\n",
"# Ignore useless warnings (see SciPy issue #5998)\n",
"import warnings\n",
"warnings.filterwarnings(action=\"ignore\", module=\"scipy\", message=\"^internal gelsd\")"
]
},
{
@ -188,7 +192,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Linear regression using batch gradient descent"
"The `LinearRegression` class is based on the `scipy.linalg.lstsq()` function (the name stands for \"least squares\"), which you could call directly:"
]
},
{
@ -196,6 +200,46 @@
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"theta_best_svd, residuals, rank, s = np.linalg.lstsq(X_b, y, rcond=1e-6)\n",
"theta_best_svd"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This function computes $\\mathbf{X}^+\\mathbf{y}$, where $\\mathbf{X}^{+}$ is the _pseudoinverse_ of $\\mathbf{X}$ (specifically the Moore-Penrose inverse). You can use `np.linalg.pinv()` to compute the pseudoinverse directly:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"np.linalg.pinv(X_b).dot(y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Note**: the first releases of the book implied that the `LinearRegression` class was based on the Normal Equation. This was an error, my apologies: as explained above, it is based on the pseudoinverse, which ultimately relies on the SVD matrix decomposition of $\\mathbf{X}$ (see chapter 8 for details about the SVD decomposition). Its time complexity is $O(n^2)$ and it works even when $m < n$ or when some features are linear combinations of other features (in these cases, $\\mathbf{X}^T \\mathbf{X}$ is not invertible so the Normal Equation fails), see [issue #184](https://github.com/ageron/handson-ml/issues/184) for more details. However, this does not change the rest of the description of the `LinearRegression` class, in particular, it is based on an analytical solution, it does not scale well with the number of features, it scales linearly with the number of instances, all the data must fit in memory, it does not require feature scaling and the order of the instances in the training set does not matter."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Linear regression using batch gradient descent"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"eta = 0.1\n",
"n_iterations = 1000\n",
@ -209,7 +253,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
@ -218,7 +262,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
@ -227,7 +271,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
@ -253,7 +297,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
@ -279,7 +323,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
@ -290,7 +334,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
@ -326,7 +370,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
@ -335,7 +379,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
@ -346,7 +390,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
@ -362,7 +406,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
@ -374,7 +418,7 @@
"np.random.seed(42)\n",
"theta = np.random.randn(2,1) # random initialization\n",
"\n",
"t0, t1 = 10, 1000\n",
"t0, t1 = 200, 1000\n",
"def learning_schedule(t):\n",
" return t0 / (t + t1)\n",
"\n",
@ -387,7 +431,7 @@
" t += 1\n",
" xi = X_b_shuffled[i:i+minibatch_size]\n",
" yi = y_shuffled[i:i+minibatch_size]\n",
" gradients = 2 * xi.T.dot(xi.dot(theta) - yi)\n",
" gradients = 2/minibatch_size * xi.T.dot(xi.dot(theta) - yi)\n",
" eta = learning_schedule(t)\n",
" theta = theta - eta * gradients\n",
" theta_path_mgd.append(theta)"
@ -395,7 +439,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
@ -404,7 +448,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
@ -415,7 +459,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
@ -440,10 +484,8 @@
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"collapsed": true
},
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
@ -454,10 +496,8 @@
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"collapsed": true
},
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"m = 100\n",
@ -467,7 +507,7 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
@ -481,7 +521,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
@ -493,7 +533,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
@ -502,7 +542,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
@ -513,7 +553,7 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
@ -532,7 +572,7 @@
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
@ -563,7 +603,7 @@
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
@ -589,7 +629,7 @@
},
{
"cell_type": "code",
"execution_count": 34,
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
@ -602,7 +642,7 @@
},
{
"cell_type": "code",
"execution_count": 35,
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
@ -628,7 +668,7 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
@ -671,7 +711,7 @@
},
{
"cell_type": "code",
"execution_count": 37,
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
@ -683,7 +723,7 @@
},
{
"cell_type": "code",
"execution_count": 38,
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
@ -694,7 +734,7 @@
},
{
"cell_type": "code",
"execution_count": 39,
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
@ -705,7 +745,7 @@
},
{
"cell_type": "code",
"execution_count": 40,
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
@ -724,7 +764,7 @@
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
@ -736,7 +776,7 @@
},
{
"cell_type": "code",
"execution_count": 42,
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
@ -748,7 +788,7 @@
},
{
"cell_type": "code",
"execution_count": 43,
"execution_count": 45,
"metadata": {
"scrolled": true
},
@ -809,7 +849,7 @@
},
{
"cell_type": "code",
"execution_count": 44,
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
@ -832,7 +872,7 @@
},
{
"cell_type": "code",
"execution_count": 45,
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
@ -841,7 +881,7 @@
},
{
"cell_type": "code",
"execution_count": 46,
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
@ -852,7 +892,7 @@
},
{
"cell_type": "code",
"execution_count": 47,
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
@ -879,7 +919,7 @@
},
{
"cell_type": "code",
"execution_count": 48,
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
@ -918,6 +958,9 @@
" plt.plot(t1_min, t2_min, \"rs\")\n",
" plt.title(r\"$\\ell_{}$ penalty\".format(i + 1), fontsize=16)\n",
" plt.axis([t1a, t1b, t2a, t2b])\n",
" if i == 1:\n",
" plt.xlabel(r\"$\\theta_1$\", fontsize=20)\n",
" plt.ylabel(r\"$\\theta_2$\", fontsize=20, rotation=0)\n",
"\n",
" plt.subplot(222 + i * 2)\n",
" plt.grid(True)\n",
@ -928,14 +971,8 @@
" plt.plot(t1r_min, t2r_min, \"rs\")\n",
" plt.title(title, fontsize=16)\n",
" plt.axis([t1a, t1b, t2a, t2b])\n",
"\n",
"for subplot in (221, 223):\n",
" plt.subplot(subplot)\n",
" plt.ylabel(r\"$\\theta_2$\", fontsize=20, rotation=0)\n",
"\n",
"for subplot in (223, 224):\n",
" plt.subplot(subplot)\n",
" plt.xlabel(r\"$\\theta_1$\", fontsize=20)\n",
" if i == 1:\n",
" plt.xlabel(r\"$\\theta_1$\", fontsize=20)\n",
"\n",
"save_fig(\"lasso_vs_ridge_plot\")\n",
"plt.show()"
@ -950,7 +987,7 @@
},
{
"cell_type": "code",
"execution_count": 49,
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
@ -971,7 +1008,7 @@
},
{
"cell_type": "code",
"execution_count": 50,
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
@ -982,7 +1019,7 @@
},
{
"cell_type": "code",
"execution_count": 51,
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
@ -991,10 +1028,8 @@
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {
"collapsed": true
},
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"X = iris[\"data\"][:, 3:] # petal width\n",
@ -1003,7 +1038,7 @@
},
{
"cell_type": "code",
"execution_count": 53,
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [
@ -1014,7 +1049,7 @@
},
{
"cell_type": "code",
"execution_count": 54,
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
@ -1034,7 +1069,7 @@
},
{
"cell_type": "code",
"execution_count": 55,
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
@ -1061,7 +1096,7 @@
},
{
"cell_type": "code",
"execution_count": 56,
"execution_count": 58,
"metadata": {},
"outputs": [],
"source": [
@ -1070,7 +1105,7 @@
},
{
"cell_type": "code",
"execution_count": 57,
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
@ -1079,7 +1114,7 @@
},
{
"cell_type": "code",
"execution_count": 58,
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
@ -1123,7 +1158,7 @@
},
{
"cell_type": "code",
"execution_count": 59,
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
@ -1136,7 +1171,7 @@
},
{
"cell_type": "code",
"execution_count": 60,
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
@ -1161,7 +1196,7 @@
"from matplotlib.colors import ListedColormap\n",
"custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])\n",
"\n",
"plt.contourf(x0, x1, zz, cmap=custom_cmap, linewidth=5)\n",
"plt.contourf(x0, x1, zz, cmap=custom_cmap)\n",
"contour = plt.contour(x0, x1, zz1, cmap=plt.cm.brg)\n",
"plt.clabel(contour, inline=1, fontsize=12)\n",
"plt.xlabel(\"Petal length\", fontsize=14)\n",
@ -1174,7 +1209,7 @@
},
{
"cell_type": "code",
"execution_count": 61,
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
@ -1183,7 +1218,7 @@
},
{
"cell_type": "code",
"execution_count": 62,
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
@ -1228,10 +1263,8 @@
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {
"collapsed": true
},
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
"X = iris[\"data\"][:, (2, 3)] # petal length, petal width\n",
@ -1247,10 +1280,8 @@
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {
"collapsed": true
},
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
"X_with_bias = np.c_[np.ones([len(X), 1]), X]"
@ -1265,10 +1296,8 @@
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {
"collapsed": true
},
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(2042)"
@ -1283,7 +1312,7 @@
},
{
"cell_type": "code",
"execution_count": 66,
"execution_count": 68,
"metadata": {},
"outputs": [],
"source": [
@ -1314,10 +1343,8 @@
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {
"collapsed": true
},
"execution_count": 69,
"metadata": {},
"outputs": [],
"source": [
"def to_one_hot(y):\n",
@ -1337,7 +1364,7 @@
},
{
"cell_type": "code",
"execution_count": 68,
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
@ -1346,7 +1373,7 @@
},
{
"cell_type": "code",
"execution_count": 69,
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
@ -1362,10 +1389,8 @@
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {
"collapsed": true
},
"execution_count": 72,
"metadata": {},
"outputs": [],
"source": [
"Y_train_one_hot = to_one_hot(y_train)\n",
@ -1384,10 +1409,8 @@
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {
"collapsed": true
},
"execution_count": 73,
"metadata": {},
"outputs": [],
"source": [
"def softmax(logits):\n",
@ -1405,10 +1428,8 @@
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {
"collapsed": true
},
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
"n_inputs = X_train.shape[1] # == 3 (2 features plus the bias term)\n",
@ -1435,7 +1456,7 @@
},
{
"cell_type": "code",
"execution_count": 73,
"execution_count": 75,
"metadata": {},
"outputs": [],
"source": [
@ -1466,7 +1487,7 @@
},
{
"cell_type": "code",
"execution_count": 74,
"execution_count": 76,
"metadata": {},
"outputs": [],
"source": [
@ -1482,7 +1503,7 @@
},
{
"cell_type": "code",
"execution_count": 75,
"execution_count": 77,
"metadata": {},
"outputs": [],
"source": [
@ -1503,7 +1524,7 @@
},
{
"cell_type": "code",
"execution_count": 76,
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
@ -1537,7 +1558,7 @@
},
{
"cell_type": "code",
"execution_count": 77,
"execution_count": 79,
"metadata": {},
"outputs": [],
"source": [
@ -1565,7 +1586,7 @@
},
{
"cell_type": "code",
"execution_count": 78,
"execution_count": 80,
"metadata": {},
"outputs": [],
"source": [
@ -1605,7 +1626,7 @@
},
{
"cell_type": "code",
"execution_count": 79,
"execution_count": 81,
"metadata": {},
"outputs": [],
"source": [
@ -1633,7 +1654,7 @@
},
{
"cell_type": "code",
"execution_count": 80,
"execution_count": 82,
"metadata": {},
"outputs": [],
"source": [
@ -1659,7 +1680,7 @@
"from matplotlib.colors import ListedColormap\n",
"custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])\n",
"\n",
"plt.contourf(x0, x1, zz, cmap=custom_cmap, linewidth=5)\n",
"plt.contourf(x0, x1, zz, cmap=custom_cmap)\n",
"contour = plt.contour(x0, x1, zz1, cmap=plt.cm.brg)\n",
"plt.clabel(contour, inline=1, fontsize=12)\n",
"plt.xlabel(\"Petal length\", fontsize=14)\n",
@ -1678,7 +1699,7 @@
},
{
"cell_type": "code",
"execution_count": 81,
"execution_count": 83,
"metadata": {},
"outputs": [],
"source": [
@ -1700,9 +1721,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": []
}