LinearRegression is based on SVD, not the Normal Equation (fixes #184), also fixes #179 (mini-batch gradient descent), and updates matplotlib code to latest version.
parent
fb29c3b386
commit
d9fbf7dd4c
|
@ -61,7 +61,11 @@
|
|||
" print(\"Saving figure\", fig_id)\n",
|
||||
" if tight_layout:\n",
|
||||
" plt.tight_layout()\n",
|
||||
" plt.savefig(path, format='png', dpi=300)\n"
|
||||
" plt.savefig(path, format='png', dpi=300)\n",
|
||||
"\n",
|
||||
"# Ignore useless warnings (see SciPy issue #5998)\n",
|
||||
"import warnings\n",
|
||||
"warnings.filterwarnings(action=\"ignore\", module=\"scipy\", message=\"^internal gelsd\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -188,7 +192,7 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Linear regression using batch gradient descent"
|
||||
"The `LinearRegression` class is based on the `scipy.linalg.lstsq()` function (the name stands for \"least squares\"), which you could call directly:"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -196,6 +200,46 @@
|
|||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"theta_best_svd, residuals, rank, s = np.linalg.lstsq(X_b, y, rcond=1e-6)\n",
|
||||
"theta_best_svd"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This function computes $\\mathbf{X}^+\\mathbf{y}$, where $\\mathbf{X}^{+}$ is the _pseudoinverse_ of $\\mathbf{X}$ (specifically the Moore-Penrose inverse). You can use `np.linalg.pinv()` to compute the pseudoinverse directly:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"np.linalg.pinv(X_b).dot(y)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**Note**: the first releases of the book implied that the `LinearRegression` class was based on the Normal Equation. This was an error, my apologies: as explained above, it is based on the pseudoinverse, which ultimately relies on the SVD matrix decomposition of $\\mathbf{X}$ (see chapter 8 for details about the SVD decomposition). Its time complexity is $O(n^2)$ and it works even when $m < n$ or when some features are linear combinations of other features (in these cases, $\\mathbf{X}^T \\mathbf{X}$ is not invertible so the Normal Equation fails), see [issue #184](https://github.com/ageron/handson-ml/issues/184) for more details. However, this does not change the rest of the description of the `LinearRegression` class, in particular, it is based on an analytical solution, it does not scale well with the number of features, it scales linearly with the number of instances, all the data must fit in memory, it does not require feature scaling and the order of the instances in the training set does not matter."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Linear regression using batch gradient descent"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"eta = 0.1\n",
|
||||
"n_iterations = 1000\n",
|
||||
|
@ -209,7 +253,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -218,7 +262,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -227,7 +271,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -253,7 +297,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -279,7 +323,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -290,7 +334,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -326,7 +370,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -335,7 +379,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -346,7 +390,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -362,7 +406,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -374,7 +418,7 @@
|
|||
"np.random.seed(42)\n",
|
||||
"theta = np.random.randn(2,1) # random initialization\n",
|
||||
"\n",
|
||||
"t0, t1 = 10, 1000\n",
|
||||
"t0, t1 = 200, 1000\n",
|
||||
"def learning_schedule(t):\n",
|
||||
" return t0 / (t + t1)\n",
|
||||
"\n",
|
||||
|
@ -387,7 +431,7 @@
|
|||
" t += 1\n",
|
||||
" xi = X_b_shuffled[i:i+minibatch_size]\n",
|
||||
" yi = y_shuffled[i:i+minibatch_size]\n",
|
||||
" gradients = 2 * xi.T.dot(xi.dot(theta) - yi)\n",
|
||||
" gradients = 2/minibatch_size * xi.T.dot(xi.dot(theta) - yi)\n",
|
||||
" eta = learning_schedule(t)\n",
|
||||
" theta = theta - eta * gradients\n",
|
||||
" theta_path_mgd.append(theta)"
|
||||
|
@ -395,7 +439,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -404,7 +448,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -415,7 +459,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"execution_count": 26,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -440,10 +484,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 27,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
|
@ -454,10 +496,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 28,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"m = 100\n",
|
||||
|
@ -467,7 +507,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -481,7 +521,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"execution_count": 30,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -493,7 +533,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"execution_count": 31,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -502,7 +542,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"execution_count": 32,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -513,7 +553,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"execution_count": 33,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -532,7 +572,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"execution_count": 34,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -563,7 +603,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"execution_count": 35,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -589,7 +629,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"execution_count": 36,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -602,7 +642,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"execution_count": 37,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -628,7 +668,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"execution_count": 38,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -671,7 +711,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"execution_count": 39,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -683,7 +723,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 38,
|
||||
"execution_count": 40,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -694,7 +734,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"execution_count": 41,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -705,7 +745,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 40,
|
||||
"execution_count": 42,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -724,7 +764,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 41,
|
||||
"execution_count": 43,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -736,7 +776,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"execution_count": 44,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -748,7 +788,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 43,
|
||||
"execution_count": 45,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
|
@ -809,7 +849,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 44,
|
||||
"execution_count": 46,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -832,7 +872,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 45,
|
||||
"execution_count": 47,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -841,7 +881,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 46,
|
||||
"execution_count": 48,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -852,7 +892,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 47,
|
||||
"execution_count": 49,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -879,7 +919,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 48,
|
||||
"execution_count": 50,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -918,6 +958,9 @@
|
|||
" plt.plot(t1_min, t2_min, \"rs\")\n",
|
||||
" plt.title(r\"$\\ell_{}$ penalty\".format(i + 1), fontsize=16)\n",
|
||||
" plt.axis([t1a, t1b, t2a, t2b])\n",
|
||||
" if i == 1:\n",
|
||||
" plt.xlabel(r\"$\\theta_1$\", fontsize=20)\n",
|
||||
" plt.ylabel(r\"$\\theta_2$\", fontsize=20, rotation=0)\n",
|
||||
"\n",
|
||||
" plt.subplot(222 + i * 2)\n",
|
||||
" plt.grid(True)\n",
|
||||
|
@ -928,13 +971,7 @@
|
|||
" plt.plot(t1r_min, t2r_min, \"rs\")\n",
|
||||
" plt.title(title, fontsize=16)\n",
|
||||
" plt.axis([t1a, t1b, t2a, t2b])\n",
|
||||
"\n",
|
||||
"for subplot in (221, 223):\n",
|
||||
" plt.subplot(subplot)\n",
|
||||
" plt.ylabel(r\"$\\theta_2$\", fontsize=20, rotation=0)\n",
|
||||
"\n",
|
||||
"for subplot in (223, 224):\n",
|
||||
" plt.subplot(subplot)\n",
|
||||
" if i == 1:\n",
|
||||
" plt.xlabel(r\"$\\theta_1$\", fontsize=20)\n",
|
||||
"\n",
|
||||
"save_fig(\"lasso_vs_ridge_plot\")\n",
|
||||
|
@ -950,7 +987,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 49,
|
||||
"execution_count": 51,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -971,7 +1008,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 50,
|
||||
"execution_count": 52,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -982,7 +1019,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 51,
|
||||
"execution_count": 53,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -991,10 +1028,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 52,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 54,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X = iris[\"data\"][:, 3:] # petal width\n",
|
||||
|
@ -1003,7 +1038,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 53,
|
||||
"execution_count": 55,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1014,7 +1049,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 54,
|
||||
"execution_count": 56,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1034,7 +1069,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 55,
|
||||
"execution_count": 57,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1061,7 +1096,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 56,
|
||||
"execution_count": 58,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1070,7 +1105,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 57,
|
||||
"execution_count": 59,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1079,7 +1114,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 58,
|
||||
"execution_count": 60,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1123,7 +1158,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 59,
|
||||
"execution_count": 61,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1136,7 +1171,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 60,
|
||||
"execution_count": 62,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1161,7 +1196,7 @@
|
|||
"from matplotlib.colors import ListedColormap\n",
|
||||
"custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])\n",
|
||||
"\n",
|
||||
"plt.contourf(x0, x1, zz, cmap=custom_cmap, linewidth=5)\n",
|
||||
"plt.contourf(x0, x1, zz, cmap=custom_cmap)\n",
|
||||
"contour = plt.contour(x0, x1, zz1, cmap=plt.cm.brg)\n",
|
||||
"plt.clabel(contour, inline=1, fontsize=12)\n",
|
||||
"plt.xlabel(\"Petal length\", fontsize=14)\n",
|
||||
|
@ -1174,7 +1209,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 61,
|
||||
"execution_count": 63,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1183,7 +1218,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 62,
|
||||
"execution_count": 64,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1228,10 +1263,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 63,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 65,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X = iris[\"data\"][:, (2, 3)] # petal length, petal width\n",
|
||||
|
@ -1247,10 +1280,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 64,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 66,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X_with_bias = np.c_[np.ones([len(X), 1]), X]"
|
||||
|
@ -1265,10 +1296,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 65,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 67,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"np.random.seed(2042)"
|
||||
|
@ -1283,7 +1312,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 66,
|
||||
"execution_count": 68,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1314,10 +1343,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 67,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 69,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def to_one_hot(y):\n",
|
||||
|
@ -1337,7 +1364,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 68,
|
||||
"execution_count": 70,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1346,7 +1373,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 69,
|
||||
"execution_count": 71,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1362,10 +1389,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 70,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 72,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"Y_train_one_hot = to_one_hot(y_train)\n",
|
||||
|
@ -1384,10 +1409,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 71,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 73,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def softmax(logits):\n",
|
||||
|
@ -1405,10 +1428,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 72,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 74,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"n_inputs = X_train.shape[1] # == 3 (2 features plus the bias term)\n",
|
||||
|
@ -1435,7 +1456,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 73,
|
||||
"execution_count": 75,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1466,7 +1487,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 74,
|
||||
"execution_count": 76,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1482,7 +1503,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 75,
|
||||
"execution_count": 77,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1503,7 +1524,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 76,
|
||||
"execution_count": 78,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1537,7 +1558,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 77,
|
||||
"execution_count": 79,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1565,7 +1586,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 78,
|
||||
"execution_count": 80,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1605,7 +1626,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 79,
|
||||
"execution_count": 81,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1633,7 +1654,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 80,
|
||||
"execution_count": 82,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1659,7 +1680,7 @@
|
|||
"from matplotlib.colors import ListedColormap\n",
|
||||
"custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])\n",
|
||||
"\n",
|
||||
"plt.contourf(x0, x1, zz, cmap=custom_cmap, linewidth=5)\n",
|
||||
"plt.contourf(x0, x1, zz, cmap=custom_cmap)\n",
|
||||
"contour = plt.contour(x0, x1, zz1, cmap=plt.cm.brg)\n",
|
||||
"plt.clabel(contour, inline=1, fontsize=12)\n",
|
||||
"plt.xlabel(\"Petal length\", fontsize=14)\n",
|
||||
|
@ -1678,7 +1699,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 81,
|
||||
"execution_count": 83,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1700,9 +1721,7 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue