Improve the solution to the Titanic exercise
parent
59c3a7f84d
commit
4488c80cf0
|
@ -1461,14 +1461,7 @@
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"First, login to [Kaggle](https://www.kaggle.com/) and go to the [Titanic challenge](https://www.kaggle.com/c/titanic) to download `train.csv` and `test.csv`. Save them to the `datasets/titanic` directory."
|
"Let's fetch the data and load it:"
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"Next, let's load the data:"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1478,8 +1471,21 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import os\n",
|
"import os\n",
|
||||||
|
"import urllib.request\n",
|
||||||
"\n",
|
"\n",
|
||||||
"TITANIC_PATH = os.path.join(\"datasets\", \"titanic\")"
|
"TITANIC_PATH = os.path.join(\"datasets\", \"titanic\")\n",
|
||||||
|
"DOWNLOAD_URL = \"https://raw.githubusercontent.com/ageron/handson-ml2/master/datasets/titanic/\"\n",
|
||||||
|
"\n",
|
||||||
|
"def fetch_titanic_data(url=DOWNLOAD_URL, path=TITANIC_PATH):\n",
|
||||||
|
" if not os.path.isdir(path):\n",
|
||||||
|
" os.makedirs(path)\n",
|
||||||
|
" for filename in (\"train.csv\", \"test.csv\"):\n",
|
||||||
|
" filepath = os.path.join(path, filename)\n",
|
||||||
|
" if not os.path.isfile(filepath):\n",
|
||||||
|
" print(\"Downloading\", filename)\n",
|
||||||
|
" urllib.request.urlretrieve(url + filename, filepath)\n",
|
||||||
|
"\n",
|
||||||
|
"fetch_titanic_data() "
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1533,6 +1539,7 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"The attributes have the following meaning:\n",
|
"The attributes have the following meaning:\n",
|
||||||
|
"* **PassengerId**: a unique identifier for each passenger\n",
|
||||||
"* **Survived**: that's the target, 0 means the passenger did not survive, while 1 means he/she survived.\n",
|
"* **Survived**: that's the target, 0 means the passenger did not survive, while 1 means he/she survived.\n",
|
||||||
"* **Pclass**: passenger class.\n",
|
"* **Pclass**: passenger class.\n",
|
||||||
"* **Name**, **Sex**, **Age**: self-explanatory\n",
|
"* **Name**, **Sex**, **Age**: self-explanatory\n",
|
||||||
|
@ -1548,7 +1555,7 @@
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"Let's get more info to see how much data is missing:"
|
"Let's explicitly set the `PassengerId` column as the index column:"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1557,14 +1564,40 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"train_data.info()"
|
"train_data = train_data.set_index(\"PassengerId\")\n",
|
||||||
|
"test_data = test_data.set_index(\"PassengerId\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"Okay, the **Age**, **Cabin** and **Embarked** attributes are sometimes null (less than 891 non-null), especially the **Cabin** (77% are null). We will ignore the **Cabin** for now and focus on the rest. The **Age** attribute has about 19% null values, so we will need to decide what to do with them. Replacing null values with the median age seems reasonable."
|
"Let's get more info to see how much data is missing:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 105,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"train_data.info()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 106,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"train_data[train_data[\"Sex\"]==\"female\"][\"Age\"].median()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Okay, the **Age**, **Cabin** and **Embarked** attributes are sometimes null (less than 891 non-null), especially the **Cabin** (77% are null). We will ignore the **Cabin** for now and focus on the rest. The **Age** attribute has about 19% null values, so we will need to decide what to do with them. Replacing null values with the median age seems reasonable. We could be a bit smarter by predicting the age based on the other columns (for example, the median age is 37 in 1st class, 29 in 2nd class and 24 in 3rd class), but we'll keep things simple and just use the overall median age."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1583,7 +1616,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 105,
|
"execution_count": 107,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1594,7 +1627,7 @@
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"* Yikes, only 38% **Survived**. :( That's close enough to 40%, so accuracy will be a reasonable metric to evaluate our model.\n",
|
"* Yikes, only 38% **Survived**! 😭 That's close enough to 40%, so accuracy will be a reasonable metric to evaluate our model.\n",
|
||||||
"* The mean **Fare** was £32.20, which does not seem so expensive (but it was probably a lot of money back then).\n",
|
"* The mean **Fare** was £32.20, which does not seem so expensive (but it was probably a lot of money back then).\n",
|
||||||
"* The mean **Age** was less than 30 years old."
|
"* The mean **Age** was less than 30 years old."
|
||||||
]
|
]
|
||||||
|
@ -1608,7 +1641,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 106,
|
"execution_count": 108,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1624,7 +1657,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 107,
|
"execution_count": 109,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1633,7 +1666,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 108,
|
"execution_count": 110,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1642,7 +1675,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 109,
|
"execution_count": 111,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1660,53 +1693,7 @@
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"**Note**: the code below uses a mix of `Pipeline`, `FeatureUnion` and a custom `DataFrameSelector` to preprocess some columns differently. Since Scikit-Learn 0.20, it is preferable to use a `ColumnTransformer`, like in the previous chapter."
|
"Now let's build our preprocessing pipelines, starting with the pipeline for numerical attributes:"
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"Now let's build our preprocessing pipelines. We will reuse the `DataframeSelector` we built in the previous chapter to select specific attributes from the `DataFrame`:"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 110,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from sklearn.base import BaseEstimator, TransformerMixin\n",
|
|
||||||
"\n",
|
|
||||||
"class DataFrameSelector(BaseEstimator, TransformerMixin):\n",
|
|
||||||
" def __init__(self, attribute_names):\n",
|
|
||||||
" self.attribute_names = attribute_names\n",
|
|
||||||
" def fit(self, X, y=None):\n",
|
|
||||||
" return self\n",
|
|
||||||
" def transform(self, X):\n",
|
|
||||||
" return X[self.attribute_names]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"Let's build the pipeline for the numerical attributes:"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 111,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from sklearn.pipeline import Pipeline\n",
|
|
||||||
"from sklearn.impute import SimpleImputer\n",
|
|
||||||
"\n",
|
|
||||||
"num_pipeline = Pipeline([\n",
|
|
||||||
" (\"select_numeric\", DataFrameSelector([\"Age\", \"SibSp\", \"Parch\", \"Fare\"])),\n",
|
|
||||||
" (\"imputer\", SimpleImputer(strategy=\"median\")),\n",
|
|
||||||
" ])"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1715,39 +1702,14 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"num_pipeline.fit_transform(train_data)"
|
"from sklearn.pipeline import Pipeline\n",
|
||||||
]
|
"from sklearn.impute import SimpleImputer\n",
|
||||||
},
|
"from sklearn.preprocessing import StandardScaler\n",
|
||||||
{
|
"\n",
|
||||||
"cell_type": "markdown",
|
"num_pipeline = Pipeline([\n",
|
||||||
"metadata": {},
|
" (\"imputer\", SimpleImputer(strategy=\"median\")),\n",
|
||||||
"source": [
|
" (\"scaler\", StandardScaler())\n",
|
||||||
"We will also need an imputer for the string categorical columns (the regular `SimpleImputer` does not work on those):"
|
" ])"
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 113,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# Inspired from stackoverflow.com/questions/25239958\n",
|
|
||||||
"class MostFrequentImputer(BaseEstimator, TransformerMixin):\n",
|
|
||||||
" def fit(self, X, y=None):\n",
|
|
||||||
" self.most_frequent_ = pd.Series([X[c].value_counts().index[0] for c in X],\n",
|
|
||||||
" index=X.columns)\n",
|
|
||||||
" return self\n",
|
|
||||||
" def transform(self, X, y=None):\n",
|
|
||||||
" return X.fillna(self.most_frequent_)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 114,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from sklearn.preprocessing import OneHotEncoder"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1759,24 +1721,24 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 115,
|
"execution_count": 113,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"cat_pipeline = Pipeline([\n",
|
"from sklearn.preprocessing import OrdinalEncoder, OneHotEncoder"
|
||||||
" (\"select_cat\", DataFrameSelector([\"Pclass\", \"Sex\", \"Embarked\"])),\n",
|
|
||||||
" (\"imputer\", MostFrequentImputer()),\n",
|
|
||||||
" (\"cat_encoder\", OneHotEncoder(sparse=False)),\n",
|
|
||||||
" ])"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 116,
|
"execution_count": 114,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"cat_pipeline.fit_transform(train_data)"
|
"cat_pipeline = Pipeline([\n",
|
||||||
|
" (\"ordinal_encoder\", OrdinalEncoder()), \n",
|
||||||
|
" (\"imputer\", SimpleImputer(strategy=\"most_frequent\")),\n",
|
||||||
|
" (\"cat_encoder\", OneHotEncoder(sparse=False)),\n",
|
||||||
|
" ])"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1788,14 +1750,18 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 117,
|
"execution_count": 115,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from sklearn.pipeline import FeatureUnion\n",
|
"from sklearn.compose import ColumnTransformer\n",
|
||||||
"preprocess_pipeline = FeatureUnion(transformer_list=[\n",
|
"\n",
|
||||||
" (\"num_pipeline\", num_pipeline),\n",
|
"num_attribs = [\"Age\", \"SibSp\", \"Parch\", \"Fare\"]\n",
|
||||||
" (\"cat_pipeline\", cat_pipeline),\n",
|
"cat_attribs = [\"Pclass\", \"Sex\", \"Embarked\"]\n",
|
||||||
|
"\n",
|
||||||
|
"preprocess_pipeline = ColumnTransformer([\n",
|
||||||
|
" (\"num\", num_pipeline, num_attribs),\n",
|
||||||
|
" (\"cat\", cat_pipeline, cat_attribs),\n",
|
||||||
" ])"
|
" ])"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -1808,7 +1774,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 118,
|
"execution_count": 116,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1825,7 +1791,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 119,
|
"execution_count": 117,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1836,19 +1802,19 @@
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"We are now ready to train a classifier. Let's start with an `SVC`:"
|
"We are now ready to train a classifier. Let's start with a `RandomForestClassifier`:"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 120,
|
"execution_count": 118,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from sklearn.svm import SVC\n",
|
"from sklearn.ensemble import RandomForestClassifier\n",
|
||||||
"\n",
|
"\n",
|
||||||
"svm_clf = SVC(gamma=\"auto\")\n",
|
"forest_clf = RandomForestClassifier(n_estimators=100, random_state=42)\n",
|
||||||
"svm_clf.fit(X_train, y_train)"
|
"forest_clf.fit(X_train, y_train)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1860,12 +1826,12 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 121,
|
"execution_count": 119,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"X_test = preprocess_pipeline.transform(test_data)\n",
|
"X_test = preprocess_pipeline.transform(test_data)\n",
|
||||||
"y_pred = svm_clf.predict(X_test)"
|
"y_pred = forest_clf.predict(X_test)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1877,39 +1843,12 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 122,
|
"execution_count": 120,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from sklearn.model_selection import cross_val_score\n",
|
"from sklearn.model_selection import cross_val_score\n",
|
||||||
"\n",
|
"\n",
|
||||||
"svm_scores = cross_val_score(svm_clf, X_train, y_train, cv=10)\n",
|
|
||||||
"svm_scores.mean()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"Okay, over 73% accuracy, clearly better than random chance, but it's not a great score. Looking at the [leaderboard](https://www.kaggle.com/c/titanic/leaderboard) for the Titanic competition on Kaggle, you can see that you need to reach above 80% accuracy to be within the top 10% Kagglers. Some reached 100%, but since you can easily find the [list of victims](https://www.encyclopedia-titanica.org/titanic-victims/) of the Titanic, it seems likely that there was little Machine Learning involved in their performance! ;-) So let's try to build a model that reaches 80% accuracy."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"Let's try a `RandomForestClassifier`:"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 123,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from sklearn.ensemble import RandomForestClassifier\n",
|
|
||||||
"\n",
|
|
||||||
"forest_clf = RandomForestClassifier(n_estimators=100, random_state=42)\n",
|
|
||||||
"forest_scores = cross_val_score(forest_clf, X_train, y_train, cv=10)\n",
|
"forest_scores = cross_val_score(forest_clf, X_train, y_train, cv=10)\n",
|
||||||
"forest_scores.mean()"
|
"forest_scores.mean()"
|
||||||
]
|
]
|
||||||
|
@ -1918,22 +1857,51 @@
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"That's much better!"
|
"Okay, not too bad! Looking at the [leaderboard](https://www.kaggle.com/c/titanic/leaderboard) for the Titanic competition on Kaggle, you can see that our score is in the top 2%, woohoo! Some Kagglers reached 100% accuracy, but since you can easily find the [list of victims](https://www.encyclopedia-titanica.org/titanic-victims/) of the Titanic, it seems likely that there was little Machine Learning involved in their performance! 😆"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"Instead of just looking at the mean accuracy across the 10 cross-validation folds, let's plot all 10 scores for each model, along with a box plot highlighting the lower and upper quartiles, and \"whiskers\" showing the extent of the scores (thanks to Nevin Yilmaz for suggesting this visualization). Note that the `boxplot()` function detects outliers (called \"fliers\") and does not include them within the whiskers. Specifically, if the lower quartile is $Q_1$ and the upper quartile is $Q_3$, then the interquartile range $IQR = Q_3 - Q_1$ (this is the box's height), and any score lower than $Q_1 - 1.5 \\times IQR$ is a flier, and so is any score greater than $Q3 + 1.5 \\times IQR$."
|
"Let's try an `SVC`:"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 124,
|
"execution_count": 121,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
"from sklearn.svm import SVC\n",
|
||||||
|
"\n",
|
||||||
|
"svm_clf = SVC(gamma=\"auto\")\n",
|
||||||
|
"svm_scores = cross_val_score(svm_clf, X_train, y_train, cv=10)\n",
|
||||||
|
"svm_scores.mean()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Great! This model looks better."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"But instead of just looking at the mean accuracy across the 10 cross-validation folds, let's plot all 10 scores for each model, along with a box plot highlighting the lower and upper quartiles, and \"whiskers\" showing the extent of the scores (thanks to Nevin Yilmaz for suggesting this visualization). Note that the `boxplot()` function detects outliers (called \"fliers\") and does not include them within the whiskers. Specifically, if the lower quartile is $Q_1$ and the upper quartile is $Q_3$, then the interquartile range $IQR = Q_3 - Q_1$ (this is the box's height), and any score lower than $Q_1 - 1.5 \\times IQR$ is a flier, and so is any score greater than $Q3 + 1.5 \\times IQR$."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 122,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"\n",
|
||||||
"plt.figure(figsize=(8, 4))\n",
|
"plt.figure(figsize=(8, 4))\n",
|
||||||
"plt.plot([1]*10, svm_scores, \".\")\n",
|
"plt.plot([1]*10, svm_scores, \".\")\n",
|
||||||
"plt.plot([2]*10, forest_scores, \".\")\n",
|
"plt.plot([2]*10, forest_scores, \".\")\n",
|
||||||
|
@ -1942,6 +1910,13 @@
|
||||||
"plt.show()"
|
"plt.show()"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"The random forest classifier got a very high score on one of the 10 folds, but overall it had a lower mean score, as well as a bigger spread, so it looks like the SVM classifier is more likely to generalize well."
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
@ -1949,14 +1924,15 @@
|
||||||
"To improve this result further, you could:\n",
|
"To improve this result further, you could:\n",
|
||||||
"* Compare many more models and tune hyperparameters using cross validation and grid search,\n",
|
"* Compare many more models and tune hyperparameters using cross validation and grid search,\n",
|
||||||
"* Do more feature engineering, for example:\n",
|
"* Do more feature engineering, for example:\n",
|
||||||
" * replace **SibSp** and **Parch** with their sum,\n",
|
" * Try to convert numerical attributes to categorical attributes: for example, different age groups had very different survival rates (see below), so it may help to create an age bucket category and use it instead of the age. Similarly, it may be useful to have a special category for people traveling alone since only 30% of them survived (see below).\n",
|
||||||
" * try to identify parts of names that correlate well with the **Survived** attribute (e.g. if the name contains \"Countess\", then survival seems more likely),\n",
|
" * Replace **SibSp** and **Parch** with their sum.\n",
|
||||||
"* try to convert numerical attributes to categorical attributes: for example, different age groups had very different survival rates (see below), so it may help to create an age bucket category and use it instead of the age. Similarly, it may be useful to have a special category for people traveling alone since only 30% of them survived (see below)."
|
" * Try to identify parts of names that correlate well with the **Survived** attribute.\n",
|
||||||
|
" * Use the **Cabin** column, for example take its first letter and treat it as a categorical attribute."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 125,
|
"execution_count": 123,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1966,7 +1942,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 126,
|
"execution_count": 124,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1990,7 +1966,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 127,
|
"execution_count": 125,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2017,7 +1993,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 128,
|
"execution_count": 126,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2033,7 +2009,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 129,
|
"execution_count": 127,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2045,7 +2021,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 130,
|
"execution_count": 128,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2054,7 +2030,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 131,
|
"execution_count": 129,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2070,7 +2046,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 132,
|
"execution_count": 130,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2085,7 +2061,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 133,
|
"execution_count": 131,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2102,7 +2078,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 134,
|
"execution_count": 132,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2111,7 +2087,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 135,
|
"execution_count": 133,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2127,7 +2103,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 136,
|
"execution_count": 134,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2146,7 +2122,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 137,
|
"execution_count": 135,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2162,7 +2138,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 138,
|
"execution_count": 136,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2171,7 +2147,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 139,
|
"execution_count": 137,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2194,7 +2170,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 140,
|
"execution_count": 138,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2211,7 +2187,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 141,
|
"execution_count": 139,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2227,7 +2203,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 142,
|
"execution_count": 140,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2249,7 +2225,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 143,
|
"execution_count": 141,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2273,7 +2249,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 144,
|
"execution_count": 142,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2292,7 +2268,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 145,
|
"execution_count": 143,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2308,7 +2284,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 146,
|
"execution_count": 144,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2332,7 +2308,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 147,
|
"execution_count": 145,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2350,7 +2326,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 148,
|
"execution_count": 146,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2376,7 +2352,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 149,
|
"execution_count": 147,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2394,7 +2370,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 150,
|
"execution_count": 148,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2417,7 +2393,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 151,
|
"execution_count": 149,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2469,7 +2445,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 152,
|
"execution_count": 150,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2494,7 +2470,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 153,
|
"execution_count": 151,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2525,7 +2501,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 154,
|
"execution_count": 152,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2536,7 +2512,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 155,
|
"execution_count": 153,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2552,7 +2528,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 156,
|
"execution_count": 154,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2568,7 +2544,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 157,
|
"execution_count": 155,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2591,7 +2567,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 158,
|
"execution_count": 156,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2614,7 +2590,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 159,
|
"execution_count": 157,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
|
Loading…
Reference in New Issue