Update to latest library versions

main
Aurélien Geron 2020-11-21 12:22:42 +13:00
parent 1e81324573
commit f225f59780
3 changed files with 224 additions and 176 deletions

View File

@ -931,7 +931,7 @@
"rooms_ix, bedrooms_ix, population_ix, households_ix = 3, 4, 5, 6\n",
"\n",
"class CombinedAttributesAdder(BaseEstimator, TransformerMixin):\n",
" def __init__(self, add_bedrooms_per_room = True): # no *args or **kargs\n",
" def __init__(self, add_bedrooms_per_room=True): # no *args or **kargs\n",
" self.add_bedrooms_per_room = add_bedrooms_per_room\n",
" def fit(self, X, y=None):\n",
" return self # nothing else to do\n",
@ -949,11 +949,36 @@
"housing_extra_attribs = attr_adder.transform(housing.values)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that I hard coded the indices (3, 4, 5, 6) for concision and clarity in the book, but it would be much cleaner to get them dynamically, like this:"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
"col_names = \"total_rooms\", \"total_bedrooms\", \"population\", \"households\"\n",
"rooms_ix, bedrooms_ix, population_ix, households_ix = [\n",
" housing.columns.get_loc(c) for c in col_names] # get the column indices"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Also, `housing_extra_attribs` is a NumPy array, we've lost the column names (unfortunately, that's a problem with Scikit-Learn). To recover a `DataFrame`, you could run this:"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [],
"source": [
"housing_extra_attribs = pd.DataFrame(\n",
" housing_extra_attribs,\n",
@ -971,7 +996,7 @@
},
{
"cell_type": "code",
"execution_count": 72,
"execution_count": 73,
"metadata": {},
"outputs": [],
"source": [
@ -989,7 +1014,7 @@
},
{
"cell_type": "code",
"execution_count": 73,
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
@ -998,7 +1023,7 @@
},
{
"cell_type": "code",
"execution_count": 74,
"execution_count": 75,
"metadata": {},
"outputs": [],
"source": [
@ -1017,7 +1042,7 @@
},
{
"cell_type": "code",
"execution_count": 75,
"execution_count": 76,
"metadata": {},
"outputs": [],
"source": [
@ -1026,7 +1051,7 @@
},
{
"cell_type": "code",
"execution_count": 76,
"execution_count": 77,
"metadata": {},
"outputs": [],
"source": [
@ -1042,7 +1067,7 @@
},
{
"cell_type": "code",
"execution_count": 77,
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
@ -1067,7 +1092,7 @@
},
{
"cell_type": "code",
"execution_count": 78,
"execution_count": 79,
"metadata": {},
"outputs": [],
"source": [
@ -1089,7 +1114,7 @@
},
{
"cell_type": "code",
"execution_count": 79,
"execution_count": 80,
"metadata": {},
"outputs": [],
"source": [
@ -1103,7 +1128,7 @@
},
{
"cell_type": "code",
"execution_count": 80,
"execution_count": 81,
"metadata": {},
"outputs": [],
"source": [
@ -1120,7 +1145,7 @@
},
{
"cell_type": "code",
"execution_count": 81,
"execution_count": 82,
"metadata": {},
"outputs": [],
"source": [
@ -1136,7 +1161,7 @@
},
{
"cell_type": "code",
"execution_count": 82,
"execution_count": 83,
"metadata": {},
"outputs": [],
"source": [
@ -1148,7 +1173,7 @@
},
{
"cell_type": "code",
"execution_count": 83,
"execution_count": 84,
"metadata": {},
"outputs": [],
"source": [
@ -1169,7 +1194,7 @@
},
{
"cell_type": "code",
"execution_count": 84,
"execution_count": 85,
"metadata": {},
"outputs": [],
"source": [
@ -1178,7 +1203,7 @@
},
{
"cell_type": "code",
"execution_count": 85,
"execution_count": 86,
"metadata": {},
"outputs": [],
"source": [
@ -1187,7 +1212,7 @@
},
{
"cell_type": "code",
"execution_count": 86,
"execution_count": 87,
"metadata": {},
"outputs": [],
"source": [
@ -1201,7 +1226,7 @@
},
{
"cell_type": "code",
"execution_count": 87,
"execution_count": 88,
"metadata": {},
"outputs": [],
"source": [
@ -1213,7 +1238,7 @@
},
{
"cell_type": "code",
"execution_count": 88,
"execution_count": 89,
"metadata": {},
"outputs": [],
"source": [
@ -1225,7 +1250,7 @@
},
{
"cell_type": "code",
"execution_count": 89,
"execution_count": 90,
"metadata": {},
"outputs": [],
"source": [
@ -1244,7 +1269,7 @@
},
{
"cell_type": "code",
"execution_count": 90,
"execution_count": 91,
"metadata": {},
"outputs": [],
"source": [
@ -1257,7 +1282,7 @@
},
{
"cell_type": "code",
"execution_count": 91,
"execution_count": 92,
"metadata": {},
"outputs": [],
"source": [
@ -1271,7 +1296,7 @@
},
{
"cell_type": "code",
"execution_count": 92,
"execution_count": 93,
"metadata": {},
"outputs": [],
"source": [
@ -1290,7 +1315,7 @@
},
{
"cell_type": "code",
"execution_count": 93,
"execution_count": 94,
"metadata": {},
"outputs": [],
"source": [
@ -1302,7 +1327,7 @@
},
{
"cell_type": "code",
"execution_count": 94,
"execution_count": 95,
"metadata": {},
"outputs": [],
"source": [
@ -1314,7 +1339,7 @@
},
{
"cell_type": "code",
"execution_count": 95,
"execution_count": 96,
"metadata": {},
"outputs": [],
"source": [
@ -1328,7 +1353,7 @@
},
{
"cell_type": "code",
"execution_count": 96,
"execution_count": 97,
"metadata": {},
"outputs": [],
"source": [
@ -1338,7 +1363,7 @@
},
{
"cell_type": "code",
"execution_count": 97,
"execution_count": 98,
"metadata": {},
"outputs": [],
"source": [
@ -1354,7 +1379,7 @@
},
{
"cell_type": "code",
"execution_count": 98,
"execution_count": 99,
"metadata": {},
"outputs": [],
"source": [
@ -1384,7 +1409,7 @@
},
{
"cell_type": "code",
"execution_count": 99,
"execution_count": 100,
"metadata": {},
"outputs": [],
"source": [
@ -1393,7 +1418,7 @@
},
{
"cell_type": "code",
"execution_count": 100,
"execution_count": 101,
"metadata": {},
"outputs": [],
"source": [
@ -1409,7 +1434,7 @@
},
{
"cell_type": "code",
"execution_count": 101,
"execution_count": 102,
"metadata": {},
"outputs": [],
"source": [
@ -1420,7 +1445,7 @@
},
{
"cell_type": "code",
"execution_count": 102,
"execution_count": 103,
"metadata": {},
"outputs": [],
"source": [
@ -1429,7 +1454,7 @@
},
{
"cell_type": "code",
"execution_count": 103,
"execution_count": 104,
"metadata": {},
"outputs": [],
"source": [
@ -1449,7 +1474,7 @@
},
{
"cell_type": "code",
"execution_count": 104,
"execution_count": 105,
"metadata": {},
"outputs": [],
"source": [
@ -1460,7 +1485,7 @@
},
{
"cell_type": "code",
"execution_count": 105,
"execution_count": 106,
"metadata": {},
"outputs": [],
"source": [
@ -1470,7 +1495,7 @@
},
{
"cell_type": "code",
"execution_count": 106,
"execution_count": 107,
"metadata": {},
"outputs": [],
"source": [
@ -1484,7 +1509,7 @@
},
{
"cell_type": "code",
"execution_count": 107,
"execution_count": 108,
"metadata": {},
"outputs": [],
"source": [
@ -1502,7 +1527,7 @@
},
{
"cell_type": "code",
"execution_count": 108,
"execution_count": 109,
"metadata": {},
"outputs": [],
"source": [
@ -1518,7 +1543,7 @@
},
{
"cell_type": "code",
"execution_count": 109,
"execution_count": 110,
"metadata": {},
"outputs": [],
"source": [
@ -1540,7 +1565,7 @@
},
{
"cell_type": "code",
"execution_count": 110,
"execution_count": 111,
"metadata": {},
"outputs": [],
"source": [
@ -1560,7 +1585,7 @@
},
{
"cell_type": "code",
"execution_count": 111,
"execution_count": 112,
"metadata": {},
"outputs": [],
"source": [
@ -1585,7 +1610,7 @@
},
{
"cell_type": "code",
"execution_count": 112,
"execution_count": 113,
"metadata": {},
"outputs": [],
"source": [
@ -1607,7 +1632,7 @@
},
{
"cell_type": "code",
"execution_count": 113,
"execution_count": 114,
"metadata": {},
"outputs": [],
"source": [
@ -1616,7 +1641,7 @@
},
{
"cell_type": "code",
"execution_count": 114,
"execution_count": 115,
"metadata": {},
"outputs": [],
"source": [
@ -1635,7 +1660,7 @@
},
{
"cell_type": "code",
"execution_count": 115,
"execution_count": 116,
"metadata": {},
"outputs": [],
"source": [
@ -1671,7 +1696,7 @@
},
{
"cell_type": "code",
"execution_count": 116,
"execution_count": 117,
"metadata": {},
"outputs": [],
"source": [
@ -1697,7 +1722,7 @@
},
{
"cell_type": "code",
"execution_count": 117,
"execution_count": 118,
"metadata": {},
"outputs": [],
"source": [
@ -1715,7 +1740,7 @@
},
{
"cell_type": "code",
"execution_count": 118,
"execution_count": 119,
"metadata": {},
"outputs": [],
"source": [
@ -1745,7 +1770,7 @@
},
{
"cell_type": "code",
"execution_count": 119,
"execution_count": 120,
"metadata": {},
"outputs": [],
"source": [
@ -1778,7 +1803,7 @@
},
{
"cell_type": "code",
"execution_count": 120,
"execution_count": 121,
"metadata": {},
"outputs": [],
"source": [
@ -1796,7 +1821,7 @@
},
{
"cell_type": "code",
"execution_count": 121,
"execution_count": 122,
"metadata": {},
"outputs": [],
"source": [
@ -1819,7 +1844,7 @@
},
{
"cell_type": "code",
"execution_count": 122,
"execution_count": 123,
"metadata": {},
"outputs": [],
"source": [
@ -1844,7 +1869,7 @@
},
{
"cell_type": "code",
"execution_count": 123,
"execution_count": 124,
"metadata": {},
"outputs": [],
"source": [
@ -1883,7 +1908,7 @@
},
{
"cell_type": "code",
"execution_count": 124,
"execution_count": 125,
"metadata": {},
"outputs": [],
"source": [
@ -1919,7 +1944,7 @@
},
{
"cell_type": "code",
"execution_count": 125,
"execution_count": 126,
"metadata": {},
"outputs": [],
"source": [
@ -1935,7 +1960,7 @@
},
{
"cell_type": "code",
"execution_count": 126,
"execution_count": 127,
"metadata": {},
"outputs": [],
"source": [
@ -1945,7 +1970,7 @@
},
{
"cell_type": "code",
"execution_count": 127,
"execution_count": 128,
"metadata": {},
"outputs": [],
"source": [
@ -1961,7 +1986,7 @@
},
{
"cell_type": "code",
"execution_count": 128,
"execution_count": 129,
"metadata": {},
"outputs": [],
"source": [
@ -1977,7 +2002,7 @@
},
{
"cell_type": "code",
"execution_count": 129,
"execution_count": 130,
"metadata": {},
"outputs": [],
"source": [
@ -1989,7 +2014,7 @@
},
{
"cell_type": "code",
"execution_count": 130,
"execution_count": 131,
"metadata": {},
"outputs": [],
"source": [
@ -2005,7 +2030,7 @@
},
{
"cell_type": "code",
"execution_count": 131,
"execution_count": 132,
"metadata": {},
"outputs": [],
"source": [
@ -2021,7 +2046,7 @@
},
{
"cell_type": "code",
"execution_count": 132,
"execution_count": 133,
"metadata": {},
"outputs": [],
"source": [
@ -2051,7 +2076,7 @@
},
{
"cell_type": "code",
"execution_count": 133,
"execution_count": 134,
"metadata": {},
"outputs": [],
"source": [
@ -2064,7 +2089,7 @@
},
{
"cell_type": "code",
"execution_count": 134,
"execution_count": 135,
"metadata": {},
"outputs": [],
"source": [
@ -2080,7 +2105,7 @@
},
{
"cell_type": "code",
"execution_count": 135,
"execution_count": 136,
"metadata": {},
"outputs": [],
"source": [
@ -2114,7 +2139,7 @@
},
{
"cell_type": "code",
"execution_count": 136,
"execution_count": 137,
"metadata": {},
"outputs": [],
"source": [
@ -2130,7 +2155,7 @@
},
{
"cell_type": "code",
"execution_count": 137,
"execution_count": 138,
"metadata": {},
"outputs": [],
"source": [
@ -2168,7 +2193,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
"version": "3.7.8"
},
"nav_menu": {
"height": "279px",

View File

@ -291,7 +291,7 @@
"from sklearn.model_selection import StratifiedKFold\n",
"from sklearn.base import clone\n",
"\n",
"skfolds = StratifiedKFold(n_splits=3, random_state=42)\n",
"skfolds = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)\n",
"\n",
"for train_index, test_index in skfolds.split(X_train, y_train_5):\n",
" clone_clf = clone(sgd_clf)\n",
@ -306,6 +306,13 @@
" print(n_correct / len(y_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Note**: `shuffle=True` was omitted by mistake in previous releases of the book."
]
},
{
"cell_type": "code",
"execution_count": 19,
@ -330,6 +337,17 @@
"cross_val_score(never_5_clf, X_train, y_train_5, cv=3, scoring=\"accuracy\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Warning**: this output (and many others in this notebook and other notebooks) may differ slightly from those in the book. Don't worry, that's okay! There are several reasons for this:\n",
"* first, Scikit-Learn and other libraries evolve, and algorithms get tweaked a bit, which may change the exact result you get. If you use the latest Scikit-Learn version (and in general, you really should), you probably won't be using the exact same version I used when I wrote the book or this notebook, hence the difference. I try to keep this notebook reasonably up to date, but I can't change the numbers on the pages in your copy of the book.\n",
"* second, many training algorithms are stochastic, meaning they rely on randomness. In principle, it's possible to get consistent outputs from a random number generator by setting the seed from which it generates the pseudo-random numbers (which is why you will see `random_state=42` or `np.random.seed(42)` pretty often). However, sometimes this does not suffice due to the other factors listed here.\n",
"* third, if the training algorithm runs across multiple threads (as do some algorithms implemented in C) or across multiple processes (e.g., when using the `n_jobs` argument), then the precise order in which operations will run is not always guaranteed, and thus the exact result may vary slightly.\n",
"* lastly, other things may prevent perfect reproducibility, such as Python maps and sets whose order is not guaranteed to be stable across sessions, or the order of files in a directory which is also not guaranteed."
]
},
{
"cell_type": "code",
"execution_count": 21,
@ -375,11 +393,12 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"4096 / (4096 + 1522)"
"cm = confusion_matrix(y_train_5, y_train_pred)\n",
"cm[1, 1] / (cm[0, 1] + cm[1, 1])"
]
},
{
@ -393,11 +412,11 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"4096 / (4096 + 1325)"
"cm[1, 1] / (cm[1, 0] + cm[1, 1])"
]
},
{
@ -417,7 +436,7 @@
"metadata": {},
"outputs": [],
"source": [
"4096 / (4096 + (1522 + 1325) / 2)"
"cm[1, 1] / (cm[1, 1] + (cm[1, 0] + cm[0, 1]) / 2)"
]
},
{
@ -462,7 +481,7 @@
},
{
"cell_type": "code",
"execution_count": 34,
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
@ -472,7 +491,7 @@
},
{
"cell_type": "code",
"execution_count": 35,
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
@ -483,7 +502,7 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
@ -514,7 +533,7 @@
},
{
"cell_type": "code",
"execution_count": 37,
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
@ -523,7 +542,7 @@
},
{
"cell_type": "code",
"execution_count": 38,
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
@ -536,47 +555,20 @@
"\n",
"plt.figure(figsize=(8, 6))\n",
"plot_precision_vs_recall(precisions, recalls)\n",
"plt.plot([0.4368, 0.4368], [0., 0.9], \"r:\")\n",
"plt.plot([0.0, 0.4368], [0.9, 0.9], \"r:\")\n",
"plt.plot([0.4368], [0.9], \"ro\")\n",
"plt.plot([recall_90_precision, recall_90_precision], [0., 0.9], \"r:\")\n",
"plt.plot([0.0, recall_90_precision], [0.9, 0.9], \"r:\")\n",
"plt.plot([recall_90_precision], [0.9], \"ro\")\n",
"save_fig(\"precision_vs_recall_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"threshold_90_precision = thresholds[np.argmax(precisions >= 0.90)]"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"threshold_90_precision"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"y_train_pred_90 = (y_scores >= threshold_90_precision)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"precision_score(y_train_5, y_train_pred_90)"
"threshold_90_precision = thresholds[np.argmax(precisions >= 0.90)]"
]
},
{
@ -584,6 +576,33 @@
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"threshold_90_precision"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"y_train_pred_90 = (y_scores >= threshold_90_precision)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"precision_score(y_train_5, y_train_pred_90)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"recall_score(y_train_5, y_train_pred_90)"
]
@ -597,7 +616,7 @@
},
{
"cell_type": "code",
"execution_count": 44,
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
@ -608,7 +627,7 @@
},
{
"cell_type": "code",
"execution_count": 45,
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
@ -620,18 +639,19 @@
" plt.ylabel('True Positive Rate (Recall)', fontsize=16) # Not shown\n",
" plt.grid(True) # Not shown\n",
"\n",
"plt.figure(figsize=(8, 6)) # Not shown\n",
"plt.figure(figsize=(8, 6)) # Not shown\n",
"plot_roc_curve(fpr, tpr)\n",
"plt.plot([4.837e-3, 4.837e-3], [0., 0.4368], \"r:\") # Not shown\n",
"plt.plot([0.0, 4.837e-3], [0.4368, 0.4368], \"r:\") # Not shown\n",
"plt.plot([4.837e-3], [0.4368], \"ro\") # Not shown\n",
"save_fig(\"roc_curve_plot\") # Not shown\n",
"fpr_90 = fpr[np.argmax(tpr >= recall_90_precision)] # Not shown\n",
"plt.plot([fpr_90, fpr_90], [0., recall_90_precision], \"r:\") # Not shown\n",
"plt.plot([0.0, fpr_90], [recall_90_precision, recall_90_precision], \"r:\") # Not shown\n",
"plt.plot([fpr_90], [recall_90_precision], \"ro\") # Not shown\n",
"save_fig(\"roc_curve_plot\") # Not shown\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 46,
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
@ -649,7 +669,7 @@
},
{
"cell_type": "code",
"execution_count": 47,
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
@ -661,7 +681,7 @@
},
{
"cell_type": "code",
"execution_count": 48,
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [
@ -671,18 +691,20 @@
},
{
"cell_type": "code",
"execution_count": 49,
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"recall_for_forest = tpr_forest[np.argmax(fpr_forest >= fpr_90)]\n",
"\n",
"plt.figure(figsize=(8, 6))\n",
"plt.plot(fpr, tpr, \"b:\", linewidth=2, label=\"SGD\")\n",
"plot_roc_curve(fpr_forest, tpr_forest, \"Random Forest\")\n",
"plt.plot([4.837e-3, 4.837e-3], [0., 0.4368], \"r:\")\n",
"plt.plot([0.0, 4.837e-3], [0.4368, 0.4368], \"r:\")\n",
"plt.plot([4.837e-3], [0.4368], \"ro\")\n",
"plt.plot([4.837e-3, 4.837e-3], [0., 0.9487], \"r:\")\n",
"plt.plot([4.837e-3], [0.9487], \"ro\")\n",
"plt.plot([fpr_90, fpr_90], [0., recall_90_precision], \"r:\")\n",
"plt.plot([0.0, fpr_90], [recall_90_precision, recall_90_precision], \"r:\")\n",
"plt.plot([fpr_90], [recall_90_precision], \"ro\")\n",
"plt.plot([fpr_90, fpr_90], [0., recall_for_forest], \"r:\")\n",
"plt.plot([fpr_90], [recall_for_forest], \"ro\")\n",
"plt.grid(True)\n",
"plt.legend(loc=\"lower right\", fontsize=16)\n",
"save_fig(\"roc_curve_comparison_plot\")\n",
@ -691,7 +713,7 @@
},
{
"cell_type": "code",
"execution_count": 50,
"execution_count": 58,
"metadata": {},
"outputs": [],
"source": [
@ -700,7 +722,7 @@
},
{
"cell_type": "code",
"execution_count": 51,
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
@ -710,7 +732,7 @@
},
{
"cell_type": "code",
"execution_count": 52,
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
@ -1031,7 +1053,7 @@
"outputs": [],
"source": [
"from sklearn.dummy import DummyClassifier\n",
"dmy_clf = DummyClassifier()\n",
"dmy_clf = DummyClassifier(strategy=\"prior\")\n",
"y_probas_dmy = cross_val_predict(dmy_clf, X_train, y_train_5, cv=3, method=\"predict_proba\")\n",
"y_scores_dmy = y_probas_dmy[:, 1]"
]
@ -2127,14 +2149,14 @@
},
{
"cell_type": "code",
"execution_count": 142,
"execution_count": 185,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"X = np.array(ham_emails + spam_emails)\n",
"X = np.array(ham_emails + spam_emails, dtype=object)\n",
"y = np.array([0] * len(ham_emails) + [1] * len(spam_emails))\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)"
@ -2488,14 +2510,14 @@
},
{
"cell_type": "code",
"execution_count": 158,
"execution_count": 183,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.model_selection import cross_val_score\n",
"\n",
"log_clf = LogisticRegression(solver=\"lbfgs\", random_state=42)\n",
"log_clf = LogisticRegression(solver=\"lbfgs\", max_iter=1000, random_state=42)\n",
"score = cross_val_score(log_clf, X_train_transformed, y_train, cv=3, verbose=3)\n",
"score.mean()"
]
@ -2504,14 +2526,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Over 98.7%, not bad for a first try! :) However, remember that we are using the \"easy\" dataset. You can try with the harder datasets, the results won't be so amazing. You would have to try multiple models, select the best ones and fine-tune them using cross-validation, and so on.\n",
"Over 98.5%, not bad for a first try! :) However, remember that we are using the \"easy\" dataset. You can try with the harder datasets, the results won't be so amazing. You would have to try multiple models, select the best ones and fine-tune them using cross-validation, and so on.\n",
"\n",
"But you get the picture, so let's stop now, and just print out the precision/recall we get on the test set:"
]
},
{
"cell_type": "code",
"execution_count": 159,
"execution_count": 184,
"metadata": {},
"outputs": [],
"source": [
@ -2519,7 +2541,7 @@
"\n",
"X_test_transformed = preprocess_pipeline.transform(X_test)\n",
"\n",
"log_clf = LogisticRegression(solver=\"lbfgs\", random_state=42)\n",
"log_clf = LogisticRegression(solver=\"lbfgs\", max_iter=1000, random_state=42)\n",
"log_clf.fit(X_train_transformed, y_train)\n",
"\n",
"y_pred = log_clf.predict(X_test_transformed)\n",
@ -2552,7 +2574,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
"version": "3.7.8"
},
"nav_menu": {},
"toc": {

View File

@ -2,71 +2,64 @@
# on Windows or when using a GPU. Please see the installation
# instructions in INSTALL.md
##### Core scientific packages
jupyter==1.0.0
matplotlib==3.1.3
numpy==1.18.1
pandas==1.0.3
scipy==1.4.1
matplotlib==3.3.2
numpy==1.18.5
pandas==1.1.3
scipy==1.5.3
##### Machine Learning packages
scikit-learn==0.22
scikit-learn==0.23.2
# Optional: the XGBoost library is only used in chapter 7
xgboost==1.0.2
xgboost==1.2.1
# Optional: the transformers library is only using in chapter 16
transformers==2.8.0
transformers==3.3.1
##### TensorFlow-related packages
# If you have a TF-compatible GPU and you want to enable GPU support, then
# replace tensorflow with tensorflow-gpu, and replace tensorflow-serving-api
# with tensorflow-serving-api-gpu.
# replace tensorflow-serving-api with tensorflow-serving-api-gpu.
# Your GPU must have CUDA Compute Capability 3.5 or higher support, and
# you must install CUDA, cuDNN and more: see tensorflow.org for the detailed
# installation instructions.
tensorflow==2.1.0
tensorflow==2.3.1
# Optional: the TF Serving API library is just needed for chapter 19.
tensorflow-serving-api==2.1.0
#tensorflow-serving-api-gpu==2.1.0
tensorflow-serving-api==2.3.0 # or tensorflow-serving-api-gpu if gpu
tensorboard==2.1.1
tensorboard-plugin-profile==2.2.0
tensorflow-datasets==2.1.0
tensorflow-hub==0.7.0
tensorflow-probability==0.9.0
tensorboard==2.3.0
tensorboard-plugin-profile==2.3.0
tensorflow-datasets==4.0.1
tensorflow-hub==0.9.0
tensorflow-probability==0.11.1
# Optional: only used in chapter 13.
# NOT AVAILABLE ON WINDOWS
tfx==0.21.2
tfx==0.24.1
# Optional: only used in chapter 16.
# NOT AVAILABLE ON WINDOWS
tensorflow-addons==0.8.3
tensorflow-addons==0.11.2
##### Reinforcement Learning library (chapter 18)
# There are a few dependencies you need to install first, check out:
# https://github.com/openai/gym#installing-everything
gym[atari]==0.17.1
gym[atari]==0.17.3
# On Windows, install atari_py using:
# pip install --no-index -f https://github.com/Kojoley/atari-py/releases atari_py
tf-agents==0.3.0
tf-agents==0.6.0
##### Image manipulation
imageio==2.6.1
Pillow==7.0.0
scikit-image==0.16.2
graphviz==0.13.2
pydot==1.4.1
opencv-python==4.2.0.32
pyglet==1.5.0
Pillow==8.0.0
graphviz==0.14.2
opencv-python==4.4.0.44
pyglet==1.4.11
#pyvirtualdisplay # needed in chapter 16, if on a headless server
# (i.e., without screen, e.g., Colab or VM)
@ -78,10 +71,10 @@ pyglet==1.5.0
joblib==0.14.1
# Easy http requests
requests==2.23.0
requests==2.24.0
# Nice utility to diff Jupyter Notebooks.
nbdime==2.0.0
nbdime==2.1.0
# May be useful with Pandas for complex "where" clauses (e.g., Pandas
# tutorial).
@ -89,13 +82,21 @@ numexpr==2.7.1
# Optional: these libraries can be useful in the classification chapter,
# exercise 4.
nltk==3.4.5
urlextract==0.14.0
nltk==3.5
urlextract==1.1.0
# Optional: these libraries are only used in chapter 16
spacy==2.2.4
ftfy==5.7
ftfy==5.8
# Optional: tqdm displays nice progress bars, ipywidgets for tqdm's notebook support
tqdm==4.43.0
tqdm==4.50.2
ipywidgets==7.5.1
# Specific lib versions to avoid conflicts
attrs==19.3.0
cloudpickle==1.3.0
dill==0.3.1.1
gast==0.3.3
httplib2==0.17.4