Clarify the 'not in the book' comments

main
Aurélien Geron 2021-11-21 17:36:22 +13:00
parent 633436e8ae
commit dc64daaf65
1 changed files with 116 additions and 146 deletions

View File

@ -152,7 +152,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book\n", "# not in the book generates and saves Figure 41\n",
"\n", "\n",
"import matplotlib.pyplot as plt\n", "import matplotlib.pyplot as plt\n",
"\n", "\n",
@ -207,11 +207,11 @@
"source": [ "source": [
"import matplotlib.pyplot as plt\n", "import matplotlib.pyplot as plt\n",
"\n", "\n",
"plt.figure(figsize=(6, 4)) # not in the book\n", "plt.figure(figsize=(6, 4)) # not in the book not needed, just formatting\n",
"plt.plot(X_new, y_predict, \"r-\", label=\"Predictions\")\n", "plt.plot(X_new, y_predict, \"r-\", label=\"Predictions\")\n",
"plt.plot(X, y, \"b.\")\n", "plt.plot(X, y, \"b.\")\n",
"\n", "\n",
"# not in the book\n", "# not in the book beautifies and saves Figure 42\n",
"plt.xlabel(\"$x_1$\")\n", "plt.xlabel(\"$x_1$\")\n",
"plt.ylabel(\"$y$\", rotation=0)\n", "plt.ylabel(\"$y$\", rotation=0)\n",
"plt.axis([0, 2, 0, 15])\n", "plt.axis([0, 2, 0, 15])\n",
@ -325,7 +325,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book\n", "# not in the book generates and saves Figure 48\n",
"\n", "\n",
"import matplotlib as mpl\n", "import matplotlib as mpl\n",
"\n", "\n",
@ -347,16 +347,7 @@
" plt.axis([0, 2, 0, 15])\n", " plt.axis([0, 2, 0, 15])\n",
" plt.grid()\n", " plt.grid()\n",
" plt.title(r\"$\\eta = {}$\".format(eta))\n", " plt.title(r\"$\\eta = {}$\".format(eta))\n",
" return theta_path" " return theta_path\n",
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# not in the book\n",
"\n", "\n",
"np.random.seed(42)\n", "np.random.seed(42)\n",
"theta = np.random.randn(2,1) # random initialization\n", "theta = np.random.randn(2,1) # random initialization\n",
@ -384,20 +375,17 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 19, "execution_count": 18,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book\n", "theta_path_sgd = [] # not in the book we need to store the path of theta in\n",
"\n", " # the parameter space to plot the next figure"
"# we need to store the path of theta in the parameter space to plot the next\n",
"# figure\n",
"theta_path_sgd = []"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 20, "execution_count": 19,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -410,14 +398,13 @@
"np.random.seed(42)\n", "np.random.seed(42)\n",
"theta = np.random.randn(2, 1) # random initialization\n", "theta = np.random.randn(2, 1) # random initialization\n",
"\n", "\n",
"# not in the book, this is for the next figure\n", "n_shown = 20 # not in the book just needed to generate the figure below\n",
"n_shown = 20\n", "plt.figure(figsize=(6, 4)) # not in the book not needed, just formatting\n",
"plt.figure(figsize=(6, 4))\n",
"\n", "\n",
"for epoch in range(n_epochs):\n", "for epoch in range(n_epochs):\n",
" for iteration in range(m):\n", " for iteration in range(m):\n",
"\n", "\n",
" # not in the book\n", " # not in the book these 4 lines are used to generate the figure\n",
" if epoch == 0 and iteration < n_shown:\n", " if epoch == 0 and iteration < n_shown:\n",
" y_predict = X_new_b @ theta\n", " y_predict = X_new_b @ theta\n",
" color = mpl.colors.rgb2hex(plt.cm.OrRd(iteration / n_shown + 0.15))\n", " color = mpl.colors.rgb2hex(plt.cm.OrRd(iteration / n_shown + 0.15))\n",
@ -429,9 +416,9 @@
" gradients = 2 / 1 * xi.T @ (xi @ theta - yi)\n", " gradients = 2 / 1 * xi.T @ (xi @ theta - yi)\n",
" eta = learning_schedule(epoch * m + iteration)\n", " eta = learning_schedule(epoch * m + iteration)\n",
" theta = theta - eta * gradients\n", " theta = theta - eta * gradients\n",
" theta_path_sgd.append(theta) # not in the book\n", " theta_path_sgd.append(theta) # not in the book to generate the figure\n",
"\n", "\n",
"# not in the book\n", "# not in the book this section beautifies and saves Figure 410\n",
"plt.plot(X, y, \"b.\")\n", "plt.plot(X, y, \"b.\")\n",
"plt.xlabel(\"$x_1$\")\n", "plt.xlabel(\"$x_1$\")\n",
"plt.ylabel(\"$y$\", rotation=0)\n", "plt.ylabel(\"$y$\", rotation=0)\n",
@ -443,7 +430,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 21, "execution_count": 20,
"metadata": { "metadata": {
"scrolled": true "scrolled": true
}, },
@ -454,7 +441,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 22, "execution_count": 21,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -467,7 +454,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 23, "execution_count": 22,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -490,10 +477,12 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 24, "execution_count": 23,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book this cell generates and saves Figure 411\n",
"\n",
"from math import ceil\n", "from math import ceil\n",
"\n", "\n",
"n_epochs = 50\n", "n_epochs = 50\n",
@ -520,35 +509,12 @@
" gradients = 2 / minibatch_size * xi.T @ (xi @ theta - yi)\n", " gradients = 2 / minibatch_size * xi.T @ (xi @ theta - yi)\n",
" eta = learning_schedule(iteration)\n", " eta = learning_schedule(iteration)\n",
" theta = theta - eta * gradients\n", " theta = theta - eta * gradients\n",
" theta_path_mgd.append(theta)" " theta_path_mgd.append(theta)\n",
] "\n",
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"theta"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"theta_path_bgd = np.array(theta_path_bgd)\n", "theta_path_bgd = np.array(theta_path_bgd)\n",
"theta_path_sgd = np.array(theta_path_sgd)\n", "theta_path_sgd = np.array(theta_path_sgd)\n",
"theta_path_mgd = np.array(theta_path_mgd)" "theta_path_mgd = np.array(theta_path_mgd)\n",
] "\n",
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(7, 4))\n", "plt.figure(figsize=(7, 4))\n",
"plt.plot(theta_path_sgd[:, 0], theta_path_sgd[:, 1], \"r-s\", linewidth=1,\n", "plt.plot(theta_path_sgd[:, 0], theta_path_sgd[:, 1], \"r-s\", linewidth=1,\n",
" label=\"Stochastic\")\n", " label=\"Stochastic\")\n",
@ -574,7 +540,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 28, "execution_count": 24,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -586,11 +552,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 29, "execution_count": 25,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book\n", "# not in the book this cell generates and saves Figure 412\n",
"plt.figure(figsize=(6, 4))\n", "plt.figure(figsize=(6, 4))\n",
"plt.plot(X, y, \"b.\")\n", "plt.plot(X, y, \"b.\")\n",
"plt.xlabel(\"$x_1$\")\n", "plt.xlabel(\"$x_1$\")\n",
@ -603,7 +569,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 30, "execution_count": 26,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -616,7 +582,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 31, "execution_count": 27,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -625,7 +591,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 32, "execution_count": 28,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -636,11 +602,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 33, "execution_count": 29,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book\n", "# not in the book this cell generates and saves Figure 413\n",
"\n", "\n",
"X_new = np.linspace(-3, 3, 100).reshape(100, 1)\n", "X_new = np.linspace(-3, 3, 100).reshape(100, 1)\n",
"X_new_poly = poly_features.transform(X_new)\n", "X_new_poly = poly_features.transform(X_new)\n",
@ -660,11 +626,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 34, "execution_count": 30,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book\n", "# not in the book this cell generates and saves Figure 414\n",
"\n", "\n",
"from sklearn.preprocessing import StandardScaler\n", "from sklearn.preprocessing import StandardScaler\n",
"from sklearn.pipeline import make_pipeline\n", "from sklearn.pipeline import make_pipeline\n",
@ -700,7 +666,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 35, "execution_count": 31,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -712,11 +678,11 @@
"train_errors = -train_scores.mean(axis=1)\n", "train_errors = -train_scores.mean(axis=1)\n",
"valid_errors = -valid_scores.mean(axis=1)\n", "valid_errors = -valid_scores.mean(axis=1)\n",
"\n", "\n",
"plt.figure(figsize=(6, 4)) # not in the book\n", "plt.figure(figsize=(6, 4)) # not in the book not need, just formatting\n",
"plt.plot(train_sizes, train_errors, \"r-+\", linewidth=2, label=\"train\")\n", "plt.plot(train_sizes, train_errors, \"r-+\", linewidth=2, label=\"train\")\n",
"plt.plot(train_sizes, valid_errors, \"b-\", linewidth=3, label=\"valid\")\n", "plt.plot(train_sizes, valid_errors, \"b-\", linewidth=3, label=\"valid\")\n",
"\n", "\n",
"# not in the book: beautifies and saves the figure\n", "# not in the book beautifies and saves Figure 415\n",
"plt.xlabel(\"Training set size\")\n", "plt.xlabel(\"Training set size\")\n",
"plt.ylabel(\"RMSE\")\n", "plt.ylabel(\"RMSE\")\n",
"plt.grid()\n", "plt.grid()\n",
@ -729,7 +695,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 36, "execution_count": 32,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -741,11 +707,20 @@
"\n", "\n",
"train_sizes, train_scores, valid_scores = learning_curve(\n", "train_sizes, train_scores, valid_scores = learning_curve(\n",
" polynomial_regression, X, y, train_sizes=np.linspace(0.01, 1.0, 40), cv=5,\n", " polynomial_regression, X, y, train_sizes=np.linspace(0.01, 1.0, 40), cv=5,\n",
" scoring=\"neg_root_mean_squared_error\")\n", " scoring=\"neg_root_mean_squared_error\")"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"# not in the book generates and saves Figure 416\n",
"\n", "\n",
"# not in the book (same as earlier)\n",
"train_errors = -train_scores.mean(axis=1)\n", "train_errors = -train_scores.mean(axis=1)\n",
"valid_errors = -valid_scores.mean(axis=1)\n", "valid_errors = -valid_scores.mean(axis=1)\n",
"\n",
"plt.figure(figsize=(6, 4))\n", "plt.figure(figsize=(6, 4))\n",
"plt.plot(train_sizes, train_errors, \"r-+\", linewidth=2, label=\"train\")\n", "plt.plot(train_sizes, train_errors, \"r-+\", linewidth=2, label=\"train\")\n",
"plt.plot(train_sizes, valid_errors, \"b-\", linewidth=3, label=\"valid\")\n", "plt.plot(train_sizes, valid_errors, \"b-\", linewidth=3, label=\"valid\")\n",
@ -781,11 +756,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 37, "execution_count": 34,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book\n", "# not in the book we've done this type of generation several times before\n",
"np.random.seed(42)\n", "np.random.seed(42)\n",
"m = 20\n", "m = 20\n",
"X = 3 * np.random.rand(m, 1)\n", "X = 3 * np.random.rand(m, 1)\n",
@ -795,11 +770,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 38, "execution_count": 35,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book\n", "# not in the book a quick peek at the dataset we just generated\n",
"plt.figure(figsize=(6, 4))\n", "plt.figure(figsize=(6, 4))\n",
"plt.plot(X, y, \".\")\n", "plt.plot(X, y, \".\")\n",
"plt.xlabel(\"$x_1$\")\n", "plt.xlabel(\"$x_1$\")\n",
@ -811,7 +786,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 39, "execution_count": 36,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -824,11 +799,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 40, "execution_count": 37,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book\n", "# not in the book this cell generates and saves Figure 417\n",
"\n", "\n",
"def plot_model(model_class, polynomial, alphas, **model_kargs):\n", "def plot_model(model_class, polynomial, alphas, **model_kargs):\n",
" plt.plot(X, y, \"b.\", linewidth=3)\n", " plt.plot(X, y, \"b.\", linewidth=3)\n",
@ -864,7 +839,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 41, "execution_count": 38,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -875,11 +850,12 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 42, "execution_count": 39,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book\n", "# not in the book show that we get roughly the same solution as earlier when\n",
"# we use Stochastic Average GD (solver=\"sag\")\n",
"ridge_reg = Ridge(alpha=1, solver=\"sag\", random_state=42)\n", "ridge_reg = Ridge(alpha=1, solver=\"sag\", random_state=42)\n",
"ridge_reg.fit(X, y)\n", "ridge_reg.fit(X, y)\n",
"ridge_reg.predict([[1.5]])" "ridge_reg.predict([[1.5]])"
@ -894,7 +870,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 43, "execution_count": 40,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -907,11 +883,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 44, "execution_count": 41,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book\n", "# not in the book this cell generates and saves Figure 418\n",
"plt.figure(figsize=(9, 3.5))\n", "plt.figure(figsize=(9, 3.5))\n",
"plt.subplot(121)\n", "plt.subplot(121)\n",
"plot_model(Lasso, polynomial=False, alphas=(0, 0.1, 1), random_state=42)\n", "plot_model(Lasso, polynomial=False, alphas=(0, 0.1, 1), random_state=42)\n",
@ -925,10 +901,12 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 45, "execution_count": 42,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book this BIG cell generates and saves Figure 419\n",
"\n",
"t1a, t1b, t2a, t2b = -1, 3, -1.5, 1.5\n", "t1a, t1b, t2a, t2b = -1, 3, -1.5, 1.5\n",
"\n", "\n",
"t1s = np.linspace(t1a, t1b, 500)\n", "t1s = np.linspace(t1a, t1b, 500)\n",
@ -946,15 +924,8 @@
"t_min_idx = np.unravel_index(J.argmin(), J.shape)\n", "t_min_idx = np.unravel_index(J.argmin(), J.shape)\n",
"t1_min, t2_min = t1[t_min_idx], t2[t_min_idx]\n", "t1_min, t2_min = t1[t_min_idx], t2[t_min_idx]\n",
"\n", "\n",
"t_init = np.array([[0.25], [-1]])" "t_init = np.array([[0.25], [-1]])\n",
] "\n",
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"def bgd_path(theta, X, y, l1, l2, core=1, eta=0.05, n_iterations=200):\n", "def bgd_path(theta, X, y, l1, l2, core=1, eta=0.05, n_iterations=200):\n",
" path = [theta]\n", " path = [theta]\n",
" for iteration in range(n_iterations):\n", " for iteration in range(n_iterations):\n",
@ -1023,7 +994,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 47, "execution_count": 43,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1050,12 +1021,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 48, "execution_count": 44,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book\n", "# not in the book this is the same code as earlier\n",
"\n",
"np.random.seed(42)\n", "np.random.seed(42)\n",
"m = 100\n", "m = 100\n",
"X = 6 * np.random.rand(m, 1) - 3\n", "X = 6 * np.random.rand(m, 1) - 3\n",
@ -1066,7 +1036,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 49, "execution_count": 45,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1081,7 +1051,7 @@
"sgd_reg = SGDRegressor(penalty=None, eta0=0.002, random_state=42)\n", "sgd_reg = SGDRegressor(penalty=None, eta0=0.002, random_state=42)\n",
"n_epochs = 500\n", "n_epochs = 500\n",
"best_valid_rmse = float('inf')\n", "best_valid_rmse = float('inf')\n",
"train_errors, val_errors = [], [] # not in the book, it's for the figure below\n", "train_errors, val_errors = [], [] # not in the book it's for the figure below\n",
"\n", "\n",
"for epoch in range(n_epochs):\n", "for epoch in range(n_epochs):\n",
" sgd_reg.partial_fit(X_train_prep, y_train)\n", " sgd_reg.partial_fit(X_train_prep, y_train)\n",
@ -1091,13 +1061,13 @@
" best_valid_rmse = val_error\n", " best_valid_rmse = val_error\n",
" best_model = deepcopy(sgd_reg)\n", " best_model = deepcopy(sgd_reg)\n",
"\n", "\n",
" # not in the book, we evaluate the train error and save it for the figure\n", " # not in the book we evaluate the train error and save it for the figure\n",
" y_train_predict = sgd_reg.predict(X_train_prep)\n", " y_train_predict = sgd_reg.predict(X_train_prep)\n",
" train_error = mean_squared_error(y_train, y_train_predict, squared=False)\n", " train_error = mean_squared_error(y_train, y_train_predict, squared=False)\n",
" val_errors.append(val_error)\n", " val_errors.append(val_error)\n",
" train_errors.append(train_error)\n", " train_errors.append(train_error)\n",
"\n", "\n",
"# not in the book, this code just creates the figure below\n", "# not in the book this section generates and saves Figure 420\n",
"best_epoch = np.argmin(val_errors)\n", "best_epoch = np.argmin(val_errors)\n",
"plt.figure(figsize=(6, 4))\n", "plt.figure(figsize=(6, 4))\n",
"plt.annotate('Best model',\n", "plt.annotate('Best model',\n",
@ -1134,11 +1104,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 50, "execution_count": 46,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book\n", "# not in the book generates and saves Figure 421\n",
"\n", "\n",
"lim = 6\n", "lim = 6\n",
"t = np.linspace(-lim, lim, 100)\n", "t = np.linspace(-lim, lim, 100)\n",
@ -1168,7 +1138,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 51, "execution_count": 47,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1180,18 +1150,16 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 52, "execution_count": 48,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book\n", "print(iris.DESCR) # not in the book it's a bit too long"
"\n",
"print(iris.DESCR)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 53, "execution_count": 49,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1200,7 +1168,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 54, "execution_count": 50,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1209,7 +1177,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 55, "execution_count": 51,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1218,7 +1186,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 56, "execution_count": 52,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1235,7 +1203,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 57, "execution_count": 53,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1243,14 +1211,14 @@
"y_proba = log_reg.predict_proba(X_new)\n", "y_proba = log_reg.predict_proba(X_new)\n",
"decision_boundary = X_new[y_proba[:, 1] >= 0.5][0, 0]\n", "decision_boundary = X_new[y_proba[:, 1] >= 0.5][0, 0]\n",
"\n", "\n",
"plt.figure(figsize=(8, 3)) # not in the book\n", "plt.figure(figsize=(8, 3)) # not in the book not needed, just formatting\n",
"plt.plot(X_new, y_proba[:, 0], \"b--\", linewidth=2,\n", "plt.plot(X_new, y_proba[:, 0], \"b--\", linewidth=2,\n",
" label=\"Not Iris virginica proba\")\n", " label=\"Not Iris virginica proba\")\n",
"plt.plot(X_new, y_proba[:, 1], \"g-\", linewidth=2, label=\"Iris virginica proba\")\n", "plt.plot(X_new, y_proba[:, 1], \"g-\", linewidth=2, label=\"Iris virginica proba\")\n",
"plt.plot([decision_boundary, decision_boundary], [0, 1], \"k:\", linewidth=2,\n", "plt.plot([decision_boundary, decision_boundary], [0, 1], \"k:\", linewidth=2,\n",
" label=\"Decision boundary\")\n", " label=\"Decision boundary\")\n",
"\n", "\n",
"# not in the book: beautifies the figure\n", "# not in the book this section beautifies and saves Figure 421\n",
"plt.arrow(x=decision_boundary, y=0.08, dx=-0.3, dy=0,\n", "plt.arrow(x=decision_boundary, y=0.08, dx=-0.3, dy=0,\n",
" head_width=0.05, head_length=0.1, fc=\"b\", ec=\"b\")\n", " head_width=0.05, head_length=0.1, fc=\"b\", ec=\"b\")\n",
"plt.arrow(x=decision_boundary, y=0.92, dx=0.3, dy=0,\n", "plt.arrow(x=decision_boundary, y=0.92, dx=0.3, dy=0,\n",
@ -1269,7 +1237,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 58, "execution_count": 54,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1278,7 +1246,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 59, "execution_count": 55,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1287,11 +1255,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 60, "execution_count": 56,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book\n", "# not in the book this cell generates and saves Figure 422\n",
"\n", "\n",
"X = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n", "X = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
"y = iris.target_names[iris.target] == 'virginica'\n", "y = iris.target_names[iris.target] == 'virginica'\n",
@ -1337,7 +1305,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 61, "execution_count": 57,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1351,7 +1319,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 62, "execution_count": 58,
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
@ -1362,7 +1330,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 63, "execution_count": 59,
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
@ -1373,10 +1341,12 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 64, "execution_count": 60,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# not in the book this cell generates and saves Figure 423\n",
"\n",
"from matplotlib.colors import ListedColormap\n", "from matplotlib.colors import ListedColormap\n",
"\n", "\n",
"custom_cmap = ListedColormap([\"#fafab0\", \"#9898ff\", \"#a0faa0\"])\n", "custom_cmap = ListedColormap([\"#fafab0\", \"#9898ff\", \"#a0faa0\"])\n",
@ -1446,7 +1416,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 65, "execution_count": 61,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1463,7 +1433,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 66, "execution_count": 62,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1479,7 +1449,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 67, "execution_count": 63,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1511,7 +1481,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 68, "execution_count": 64,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1528,7 +1498,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 69, "execution_count": 65,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1537,7 +1507,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 70, "execution_count": 66,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1553,7 +1523,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 71, "execution_count": 67,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1571,7 +1541,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 72, "execution_count": 68,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1593,7 +1563,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 73, "execution_count": 69,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1612,7 +1582,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 74, "execution_count": 70,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1640,7 +1610,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 75, "execution_count": 71,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1673,7 +1643,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 76, "execution_count": 72,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1689,7 +1659,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 77, "execution_count": 73,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1710,7 +1680,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 78, "execution_count": 74,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1747,7 +1717,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 79, "execution_count": 75,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1775,7 +1745,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 80, "execution_count": 76,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1812,7 +1782,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 81, "execution_count": 77,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1840,7 +1810,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 82, "execution_count": 78,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1884,7 +1854,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 83, "execution_count": 79,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [