From 4488c80cf0e2c1b821acb1e95d7a1774388ecb49 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Mon, 11 Oct 2021 17:46:10 +1300 Subject: [PATCH] Improve the solution to the Titanic exercise --- 03_classification.ipynb | 354 +++++++++++++++++++--------------------- 1 file changed, 165 insertions(+), 189 deletions(-) diff --git a/03_classification.ipynb b/03_classification.ipynb index 877ab61..93fcf5f 100644 --- a/03_classification.ipynb +++ b/03_classification.ipynb @@ -1461,14 +1461,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "First, login to [Kaggle](https://www.kaggle.com/) and go to the [Titanic challenge](https://www.kaggle.com/c/titanic) to download `train.csv` and `test.csv`. Save them to the `datasets/titanic` directory." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Next, let's load the data:" + "Let's fetch the data and load it:" ] }, { @@ -1478,8 +1471,21 @@ "outputs": [], "source": [ "import os\n", + "import urllib.request\n", "\n", - "TITANIC_PATH = os.path.join(\"datasets\", \"titanic\")" + "TITANIC_PATH = os.path.join(\"datasets\", \"titanic\")\n", + "DOWNLOAD_URL = \"https://raw.githubusercontent.com/ageron/handson-ml2/master/datasets/titanic/\"\n", + "\n", + "def fetch_titanic_data(url=DOWNLOAD_URL, path=TITANIC_PATH):\n", + " if not os.path.isdir(path):\n", + " os.makedirs(path)\n", + " for filename in (\"train.csv\", \"test.csv\"):\n", + " filepath = os.path.join(path, filename)\n", + " if not os.path.isfile(filepath):\n", + " print(\"Downloading\", filename)\n", + " urllib.request.urlretrieve(url + filename, filepath)\n", + "\n", + "fetch_titanic_data() " ] }, { @@ -1533,6 +1539,7 @@ "metadata": {}, "source": [ "The attributes have the following meaning:\n", + "* **PassengerId**: a unique identifier for each passenger\n", "* **Survived**: that's the target, 0 means the passenger did not survive, while 1 means he/she survived.\n", "* **Pclass**: passenger class.\n", "* **Name**, **Sex**, **Age**: self-explanatory\n", @@ -1548,7 +1555,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Let's get more info to see how much data is missing:" + "Let's explicitly set the `PassengerId` column as the index column:" ] }, { @@ -1557,14 +1564,40 @@ "metadata": {}, "outputs": [], "source": [ - "train_data.info()" + "train_data = train_data.set_index(\"PassengerId\")\n", + "test_data = test_data.set_index(\"PassengerId\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Okay, the **Age**, **Cabin** and **Embarked** attributes are sometimes null (less than 891 non-null), especially the **Cabin** (77% are null). We will ignore the **Cabin** for now and focus on the rest. The **Age** attribute has about 19% null values, so we will need to decide what to do with them. Replacing null values with the median age seems reasonable." + "Let's get more info to see how much data is missing:" + ] + }, + { + "cell_type": "code", + "execution_count": 105, + "metadata": {}, + "outputs": [], + "source": [ + "train_data.info()" + ] + }, + { + "cell_type": "code", + "execution_count": 106, + "metadata": {}, + "outputs": [], + "source": [ + "train_data[train_data[\"Sex\"]==\"female\"][\"Age\"].median()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Okay, the **Age**, **Cabin** and **Embarked** attributes are sometimes null (less than 891 non-null), especially the **Cabin** (77% are null). We will ignore the **Cabin** for now and focus on the rest. The **Age** attribute has about 19% null values, so we will need to decide what to do with them. Replacing null values with the median age seems reasonable. We could be a bit smarter by predicting the age based on the other columns (for example, the median age is 37 in 1st class, 29 in 2nd class and 24 in 3rd class), but we'll keep things simple and just use the overall median age." ] }, { @@ -1583,7 +1616,7 @@ }, { "cell_type": "code", - "execution_count": 105, + "execution_count": 107, "metadata": {}, "outputs": [], "source": [ @@ -1594,7 +1627,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "* Yikes, only 38% **Survived**. :( That's close enough to 40%, so accuracy will be a reasonable metric to evaluate our model.\n", + "* Yikes, only 38% **Survived**! 😭 That's close enough to 40%, so accuracy will be a reasonable metric to evaluate our model.\n", "* The mean **Fare** was £32.20, which does not seem so expensive (but it was probably a lot of money back then).\n", "* The mean **Age** was less than 30 years old." ] @@ -1608,7 +1641,7 @@ }, { "cell_type": "code", - "execution_count": 106, + "execution_count": 108, "metadata": {}, "outputs": [], "source": [ @@ -1624,7 +1657,7 @@ }, { "cell_type": "code", - "execution_count": 107, + "execution_count": 109, "metadata": {}, "outputs": [], "source": [ @@ -1633,7 +1666,7 @@ }, { "cell_type": "code", - "execution_count": 108, + "execution_count": 110, "metadata": {}, "outputs": [], "source": [ @@ -1642,7 +1675,7 @@ }, { "cell_type": "code", - "execution_count": 109, + "execution_count": 111, "metadata": {}, "outputs": [], "source": [ @@ -1660,53 +1693,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "**Note**: the code below uses a mix of `Pipeline`, `FeatureUnion` and a custom `DataFrameSelector` to preprocess some columns differently. Since Scikit-Learn 0.20, it is preferable to use a `ColumnTransformer`, like in the previous chapter." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now let's build our preprocessing pipelines. We will reuse the `DataframeSelector` we built in the previous chapter to select specific attributes from the `DataFrame`:" - ] - }, - { - "cell_type": "code", - "execution_count": 110, - "metadata": {}, - "outputs": [], - "source": [ - "from sklearn.base import BaseEstimator, TransformerMixin\n", - "\n", - "class DataFrameSelector(BaseEstimator, TransformerMixin):\n", - " def __init__(self, attribute_names):\n", - " self.attribute_names = attribute_names\n", - " def fit(self, X, y=None):\n", - " return self\n", - " def transform(self, X):\n", - " return X[self.attribute_names]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's build the pipeline for the numerical attributes:" - ] - }, - { - "cell_type": "code", - "execution_count": 111, - "metadata": {}, - "outputs": [], - "source": [ - "from sklearn.pipeline import Pipeline\n", - "from sklearn.impute import SimpleImputer\n", - "\n", - "num_pipeline = Pipeline([\n", - " (\"select_numeric\", DataFrameSelector([\"Age\", \"SibSp\", \"Parch\", \"Fare\"])),\n", - " (\"imputer\", SimpleImputer(strategy=\"median\")),\n", - " ])" + "Now let's build our preprocessing pipelines, starting with the pipeline for numerical attributes:" ] }, { @@ -1715,39 +1702,14 @@ "metadata": {}, "outputs": [], "source": [ - "num_pipeline.fit_transform(train_data)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We will also need an imputer for the string categorical columns (the regular `SimpleImputer` does not work on those):" - ] - }, - { - "cell_type": "code", - "execution_count": 113, - "metadata": {}, - "outputs": [], - "source": [ - "# Inspired from stackoverflow.com/questions/25239958\n", - "class MostFrequentImputer(BaseEstimator, TransformerMixin):\n", - " def fit(self, X, y=None):\n", - " self.most_frequent_ = pd.Series([X[c].value_counts().index[0] for c in X],\n", - " index=X.columns)\n", - " return self\n", - " def transform(self, X, y=None):\n", - " return X.fillna(self.most_frequent_)" - ] - }, - { - "cell_type": "code", - "execution_count": 114, - "metadata": {}, - "outputs": [], - "source": [ - "from sklearn.preprocessing import OneHotEncoder" + "from sklearn.pipeline import Pipeline\n", + "from sklearn.impute import SimpleImputer\n", + "from sklearn.preprocessing import StandardScaler\n", + "\n", + "num_pipeline = Pipeline([\n", + " (\"imputer\", SimpleImputer(strategy=\"median\")),\n", + " (\"scaler\", StandardScaler())\n", + " ])" ] }, { @@ -1759,24 +1721,24 @@ }, { "cell_type": "code", - "execution_count": 115, + "execution_count": 113, "metadata": {}, "outputs": [], "source": [ - "cat_pipeline = Pipeline([\n", - " (\"select_cat\", DataFrameSelector([\"Pclass\", \"Sex\", \"Embarked\"])),\n", - " (\"imputer\", MostFrequentImputer()),\n", - " (\"cat_encoder\", OneHotEncoder(sparse=False)),\n", - " ])" + "from sklearn.preprocessing import OrdinalEncoder, OneHotEncoder" ] }, { "cell_type": "code", - "execution_count": 116, + "execution_count": 114, "metadata": {}, "outputs": [], "source": [ - "cat_pipeline.fit_transform(train_data)" + "cat_pipeline = Pipeline([\n", + " (\"ordinal_encoder\", OrdinalEncoder()), \n", + " (\"imputer\", SimpleImputer(strategy=\"most_frequent\")),\n", + " (\"cat_encoder\", OneHotEncoder(sparse=False)),\n", + " ])" ] }, { @@ -1788,14 +1750,18 @@ }, { "cell_type": "code", - "execution_count": 117, + "execution_count": 115, "metadata": {}, "outputs": [], "source": [ - "from sklearn.pipeline import FeatureUnion\n", - "preprocess_pipeline = FeatureUnion(transformer_list=[\n", - " (\"num_pipeline\", num_pipeline),\n", - " (\"cat_pipeline\", cat_pipeline),\n", + "from sklearn.compose import ColumnTransformer\n", + "\n", + "num_attribs = [\"Age\", \"SibSp\", \"Parch\", \"Fare\"]\n", + "cat_attribs = [\"Pclass\", \"Sex\", \"Embarked\"]\n", + "\n", + "preprocess_pipeline = ColumnTransformer([\n", + " (\"num\", num_pipeline, num_attribs),\n", + " (\"cat\", cat_pipeline, cat_attribs),\n", " ])" ] }, @@ -1808,7 +1774,7 @@ }, { "cell_type": "code", - "execution_count": 118, + "execution_count": 116, "metadata": {}, "outputs": [], "source": [ @@ -1825,7 +1791,7 @@ }, { "cell_type": "code", - "execution_count": 119, + "execution_count": 117, "metadata": {}, "outputs": [], "source": [ @@ -1836,19 +1802,19 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We are now ready to train a classifier. Let's start with an `SVC`:" + "We are now ready to train a classifier. Let's start with a `RandomForestClassifier`:" ] }, { "cell_type": "code", - "execution_count": 120, + "execution_count": 118, "metadata": {}, "outputs": [], "source": [ - "from sklearn.svm import SVC\n", + "from sklearn.ensemble import RandomForestClassifier\n", "\n", - "svm_clf = SVC(gamma=\"auto\")\n", - "svm_clf.fit(X_train, y_train)" + "forest_clf = RandomForestClassifier(n_estimators=100, random_state=42)\n", + "forest_clf.fit(X_train, y_train)" ] }, { @@ -1860,12 +1826,12 @@ }, { "cell_type": "code", - "execution_count": 121, + "execution_count": 119, "metadata": {}, "outputs": [], "source": [ "X_test = preprocess_pipeline.transform(test_data)\n", - "y_pred = svm_clf.predict(X_test)" + "y_pred = forest_clf.predict(X_test)" ] }, { @@ -1877,39 +1843,12 @@ }, { "cell_type": "code", - "execution_count": 122, + "execution_count": 120, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import cross_val_score\n", "\n", - "svm_scores = cross_val_score(svm_clf, X_train, y_train, cv=10)\n", - "svm_scores.mean()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Okay, over 73% accuracy, clearly better than random chance, but it's not a great score. Looking at the [leaderboard](https://www.kaggle.com/c/titanic/leaderboard) for the Titanic competition on Kaggle, you can see that you need to reach above 80% accuracy to be within the top 10% Kagglers. Some reached 100%, but since you can easily find the [list of victims](https://www.encyclopedia-titanica.org/titanic-victims/) of the Titanic, it seems likely that there was little Machine Learning involved in their performance! ;-) So let's try to build a model that reaches 80% accuracy." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's try a `RandomForestClassifier`:" - ] - }, - { - "cell_type": "code", - "execution_count": 123, - "metadata": {}, - "outputs": [], - "source": [ - "from sklearn.ensemble import RandomForestClassifier\n", - "\n", - "forest_clf = RandomForestClassifier(n_estimators=100, random_state=42)\n", "forest_scores = cross_val_score(forest_clf, X_train, y_train, cv=10)\n", "forest_scores.mean()" ] @@ -1918,22 +1857,51 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "That's much better!" + "Okay, not too bad! Looking at the [leaderboard](https://www.kaggle.com/c/titanic/leaderboard) for the Titanic competition on Kaggle, you can see that our score is in the top 2%, woohoo! Some Kagglers reached 100% accuracy, but since you can easily find the [list of victims](https://www.encyclopedia-titanica.org/titanic-victims/) of the Titanic, it seems likely that there was little Machine Learning involved in their performance! 😆" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Instead of just looking at the mean accuracy across the 10 cross-validation folds, let's plot all 10 scores for each model, along with a box plot highlighting the lower and upper quartiles, and \"whiskers\" showing the extent of the scores (thanks to Nevin Yilmaz for suggesting this visualization). Note that the `boxplot()` function detects outliers (called \"fliers\") and does not include them within the whiskers. Specifically, if the lower quartile is $Q_1$ and the upper quartile is $Q_3$, then the interquartile range $IQR = Q_3 - Q_1$ (this is the box's height), and any score lower than $Q_1 - 1.5 \\times IQR$ is a flier, and so is any score greater than $Q3 + 1.5 \\times IQR$." + "Let's try an `SVC`:" ] }, { "cell_type": "code", - "execution_count": 124, + "execution_count": 121, "metadata": {}, "outputs": [], "source": [ + "from sklearn.svm import SVC\n", + "\n", + "svm_clf = SVC(gamma=\"auto\")\n", + "svm_scores = cross_val_score(svm_clf, X_train, y_train, cv=10)\n", + "svm_scores.mean()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Great! This model looks better." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "But instead of just looking at the mean accuracy across the 10 cross-validation folds, let's plot all 10 scores for each model, along with a box plot highlighting the lower and upper quartiles, and \"whiskers\" showing the extent of the scores (thanks to Nevin Yilmaz for suggesting this visualization). Note that the `boxplot()` function detects outliers (called \"fliers\") and does not include them within the whiskers. Specifically, if the lower quartile is $Q_1$ and the upper quartile is $Q_3$, then the interquartile range $IQR = Q_3 - Q_1$ (this is the box's height), and any score lower than $Q_1 - 1.5 \\times IQR$ is a flier, and so is any score greater than $Q3 + 1.5 \\times IQR$." + ] + }, + { + "cell_type": "code", + "execution_count": 122, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", "plt.figure(figsize=(8, 4))\n", "plt.plot([1]*10, svm_scores, \".\")\n", "plt.plot([2]*10, forest_scores, \".\")\n", @@ -1942,6 +1910,13 @@ "plt.show()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The random forest classifier got a very high score on one of the 10 folds, but overall it had a lower mean score, as well as a bigger spread, so it looks like the SVM classifier is more likely to generalize well." + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -1949,14 +1924,15 @@ "To improve this result further, you could:\n", "* Compare many more models and tune hyperparameters using cross validation and grid search,\n", "* Do more feature engineering, for example:\n", - " * replace **SibSp** and **Parch** with their sum,\n", - " * try to identify parts of names that correlate well with the **Survived** attribute (e.g. if the name contains \"Countess\", then survival seems more likely),\n", - "* try to convert numerical attributes to categorical attributes: for example, different age groups had very different survival rates (see below), so it may help to create an age bucket category and use it instead of the age. Similarly, it may be useful to have a special category for people traveling alone since only 30% of them survived (see below)." + " * Try to convert numerical attributes to categorical attributes: for example, different age groups had very different survival rates (see below), so it may help to create an age bucket category and use it instead of the age. Similarly, it may be useful to have a special category for people traveling alone since only 30% of them survived (see below).\n", + " * Replace **SibSp** and **Parch** with their sum.\n", + " * Try to identify parts of names that correlate well with the **Survived** attribute.\n", + " * Use the **Cabin** column, for example take its first letter and treat it as a categorical attribute." ] }, { "cell_type": "code", - "execution_count": 125, + "execution_count": 123, "metadata": {}, "outputs": [], "source": [ @@ -1966,7 +1942,7 @@ }, { "cell_type": "code", - "execution_count": 126, + "execution_count": 124, "metadata": {}, "outputs": [], "source": [ @@ -1990,7 +1966,7 @@ }, { "cell_type": "code", - "execution_count": 127, + "execution_count": 125, "metadata": {}, "outputs": [], "source": [ @@ -2017,7 +1993,7 @@ }, { "cell_type": "code", - "execution_count": 128, + "execution_count": 126, "metadata": {}, "outputs": [], "source": [ @@ -2033,7 +2009,7 @@ }, { "cell_type": "code", - "execution_count": 129, + "execution_count": 127, "metadata": {}, "outputs": [], "source": [ @@ -2045,7 +2021,7 @@ }, { "cell_type": "code", - "execution_count": 130, + "execution_count": 128, "metadata": {}, "outputs": [], "source": [ @@ -2054,7 +2030,7 @@ }, { "cell_type": "code", - "execution_count": 131, + "execution_count": 129, "metadata": {}, "outputs": [], "source": [ @@ -2070,7 +2046,7 @@ }, { "cell_type": "code", - "execution_count": 132, + "execution_count": 130, "metadata": {}, "outputs": [], "source": [ @@ -2085,7 +2061,7 @@ }, { "cell_type": "code", - "execution_count": 133, + "execution_count": 131, "metadata": {}, "outputs": [], "source": [ @@ -2102,7 +2078,7 @@ }, { "cell_type": "code", - "execution_count": 134, + "execution_count": 132, "metadata": {}, "outputs": [], "source": [ @@ -2111,7 +2087,7 @@ }, { "cell_type": "code", - "execution_count": 135, + "execution_count": 133, "metadata": {}, "outputs": [], "source": [ @@ -2127,7 +2103,7 @@ }, { "cell_type": "code", - "execution_count": 136, + "execution_count": 134, "metadata": {}, "outputs": [], "source": [ @@ -2146,7 +2122,7 @@ }, { "cell_type": "code", - "execution_count": 137, + "execution_count": 135, "metadata": {}, "outputs": [], "source": [ @@ -2162,7 +2138,7 @@ }, { "cell_type": "code", - "execution_count": 138, + "execution_count": 136, "metadata": {}, "outputs": [], "source": [ @@ -2171,7 +2147,7 @@ }, { "cell_type": "code", - "execution_count": 139, + "execution_count": 137, "metadata": {}, "outputs": [], "source": [ @@ -2194,7 +2170,7 @@ }, { "cell_type": "code", - "execution_count": 140, + "execution_count": 138, "metadata": {}, "outputs": [], "source": [ @@ -2211,7 +2187,7 @@ }, { "cell_type": "code", - "execution_count": 141, + "execution_count": 139, "metadata": {}, "outputs": [], "source": [ @@ -2227,7 +2203,7 @@ }, { "cell_type": "code", - "execution_count": 142, + "execution_count": 140, "metadata": {}, "outputs": [], "source": [ @@ -2249,7 +2225,7 @@ }, { "cell_type": "code", - "execution_count": 143, + "execution_count": 141, "metadata": {}, "outputs": [], "source": [ @@ -2273,7 +2249,7 @@ }, { "cell_type": "code", - "execution_count": 144, + "execution_count": 142, "metadata": {}, "outputs": [], "source": [ @@ -2292,7 +2268,7 @@ }, { "cell_type": "code", - "execution_count": 145, + "execution_count": 143, "metadata": {}, "outputs": [], "source": [ @@ -2308,7 +2284,7 @@ }, { "cell_type": "code", - "execution_count": 146, + "execution_count": 144, "metadata": {}, "outputs": [], "source": [ @@ -2332,7 +2308,7 @@ }, { "cell_type": "code", - "execution_count": 147, + "execution_count": 145, "metadata": {}, "outputs": [], "source": [ @@ -2350,7 +2326,7 @@ }, { "cell_type": "code", - "execution_count": 148, + "execution_count": 146, "metadata": {}, "outputs": [], "source": [ @@ -2376,7 +2352,7 @@ }, { "cell_type": "code", - "execution_count": 149, + "execution_count": 147, "metadata": {}, "outputs": [], "source": [ @@ -2394,7 +2370,7 @@ }, { "cell_type": "code", - "execution_count": 150, + "execution_count": 148, "metadata": {}, "outputs": [], "source": [ @@ -2417,7 +2393,7 @@ }, { "cell_type": "code", - "execution_count": 151, + "execution_count": 149, "metadata": {}, "outputs": [], "source": [ @@ -2469,7 +2445,7 @@ }, { "cell_type": "code", - "execution_count": 152, + "execution_count": 150, "metadata": {}, "outputs": [], "source": [ @@ -2494,7 +2470,7 @@ }, { "cell_type": "code", - "execution_count": 153, + "execution_count": 151, "metadata": {}, "outputs": [], "source": [ @@ -2525,7 +2501,7 @@ }, { "cell_type": "code", - "execution_count": 154, + "execution_count": 152, "metadata": {}, "outputs": [], "source": [ @@ -2536,7 +2512,7 @@ }, { "cell_type": "code", - "execution_count": 155, + "execution_count": 153, "metadata": {}, "outputs": [], "source": [ @@ -2552,7 +2528,7 @@ }, { "cell_type": "code", - "execution_count": 156, + "execution_count": 154, "metadata": {}, "outputs": [], "source": [ @@ -2568,7 +2544,7 @@ }, { "cell_type": "code", - "execution_count": 157, + "execution_count": 155, "metadata": {}, "outputs": [], "source": [ @@ -2591,7 +2567,7 @@ }, { "cell_type": "code", - "execution_count": 158, + "execution_count": 156, "metadata": {}, "outputs": [], "source": [ @@ -2614,7 +2590,7 @@ }, { "cell_type": "code", - "execution_count": 159, + "execution_count": 157, "metadata": {}, "outputs": [], "source": [