Replace FeatureUnion + DataFrameSelector with new ColumnTransformer

main
Aurélien Geron 2018-07-31 20:08:33 +01:00
parent 060751a976
commit e2d450708a
1 changed files with 131 additions and 79 deletions

View File

@ -70,7 +70,7 @@
"\n",
"# Ignore useless warnings (see SciPy issue #5998)\n",
"import warnings\n",
"warnings.filterwarnings(action=\"ignore\", module=\"scipy\", message=\"^internal gelsd\")"
"warnings.filterwarnings(action=\"ignore\", module=\"scipy\", message=\"internal gelsd\")"
]
},
{
@ -980,12 +980,56 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"And a transformer to just select a subset of the Pandas DataFrame columns:"
"**Warning**: earlier versions of the book applied different transformations to different columns using a solution based on a `DataFrameSelector` transformer and a `FeatureUnion` (see below). It is now preferable to use the `ColumnTransformer` class that will be introduced in Scikit-Learn 0.20. For now we import it from `future_encoders.py`, but when Scikit-Learn 0.20 is released, you can import it from `sklearn.compose` instead:"
]
},
{
"cell_type": "code",
"execution_count": 71,
"execution_count": 72,
"metadata": {},
"outputs": [],
"source": [
"from future_encoders import ColumnTransformer\n",
"\n",
"num_attribs = list(housing_num)\n",
"cat_attribs = [\"ocean_proximity\"]\n",
"\n",
"full_pipeline = ColumnTransformer([\n",
" (\"num\", num_pipeline, num_attribs),\n",
" (\"cat\", OneHotEncoder(), cat_attribs),\n",
" ])\n",
"\n",
"housing_prepared = full_pipeline.fit_transform(housing)"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [],
"source": [
"housing_prepared"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
"housing_prepared.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, here is the old solution based on a `DataFrameSelector` transformer (to just select a subset of the Pandas `DataFrame` columns), and a `FeatureUnion`:"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [],
"source": [
@ -993,7 +1037,7 @@
"\n",
"# Create a class to select numerical or categorical columns \n",
"# since Scikit-Learn doesn't handle DataFrames yet\n",
"class DataFrameSelector(BaseEstimator, TransformerMixin):\n",
"class OldDataFrameSelector(BaseEstimator, TransformerMixin):\n",
" def __init__(self, attribute_names):\n",
" self.attribute_names = attribute_names\n",
" def fit(self, X, y=None):\n",
@ -1011,57 +1055,64 @@
},
{
"cell_type": "code",
"execution_count": 72,
"execution_count": 76,
"metadata": {},
"outputs": [],
"source": [
"num_attribs = list(housing_num)\n",
"cat_attribs = [\"ocean_proximity\"]\n",
"\n",
"num_pipeline = Pipeline([\n",
" ('selector', DataFrameSelector(num_attribs)),\n",
"old_num_pipeline = Pipeline([\n",
" ('selector', OldDataFrameSelector(num_attribs)),\n",
" ('imputer', Imputer(strategy=\"median\")),\n",
" ('attribs_adder', CombinedAttributesAdder()),\n",
" ('std_scaler', StandardScaler()),\n",
" ])\n",
"\n",
"cat_pipeline = Pipeline([\n",
" ('selector', DataFrameSelector(cat_attribs)),\n",
"old_cat_pipeline = Pipeline([\n",
" ('selector', OldDataFrameSelector(cat_attribs)),\n",
" ('cat_encoder', OneHotEncoder(sparse=False)),\n",
" ])"
]
},
{
"cell_type": "code",
"execution_count": 73,
"execution_count": 77,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import FeatureUnion\n",
"\n",
"full_pipeline = FeatureUnion(transformer_list=[\n",
" (\"num_pipeline\", num_pipeline),\n",
" (\"cat_pipeline\", cat_pipeline),\n",
"old_full_pipeline = FeatureUnion(transformer_list=[\n",
" (\"num_pipeline\", old_num_pipeline),\n",
" (\"cat_pipeline\", old_cat_pipeline),\n",
" ])"
]
},
{
"cell_type": "code",
"execution_count": 74,
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
"housing_prepared = full_pipeline.fit_transform(housing)\n",
"housing_prepared"
"old_housing_prepared = old_full_pipeline.fit_transform(housing)\n",
"old_housing_prepared"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The result is the same as with the `ColumnTransformer`:"
]
},
{
"cell_type": "code",
"execution_count": 75,
"execution_count": 79,
"metadata": {},
"outputs": [],
"source": [
"housing_prepared.shape"
"np.allclose(housing_prepared, old_housing_prepared)"
]
},
{
@ -1073,7 +1124,7 @@
},
{
"cell_type": "code",
"execution_count": 76,
"execution_count": 80,
"metadata": {},
"outputs": [],
"source": [
@ -1085,11 +1136,11 @@
},
{
"cell_type": "code",
"execution_count": 77,
"execution_count": 81,
"metadata": {},
"outputs": [],
"source": [
"# let's try the full pipeline on a few training instances\n",
"# let's try the full preprocessing pipeline on a few training instances\n",
"some_data = housing.iloc[:5]\n",
"some_labels = housing_labels.iloc[:5]\n",
"some_data_prepared = full_pipeline.transform(some_data)\n",
@ -1106,7 +1157,7 @@
},
{
"cell_type": "code",
"execution_count": 78,
"execution_count": 82,
"metadata": {},
"outputs": [],
"source": [
@ -1115,7 +1166,7 @@
},
{
"cell_type": "code",
"execution_count": 79,
"execution_count": 83,
"metadata": {},
"outputs": [],
"source": [
@ -1124,7 +1175,7 @@
},
{
"cell_type": "code",
"execution_count": 80,
"execution_count": 84,
"metadata": {},
"outputs": [],
"source": [
@ -1138,7 +1189,7 @@
},
{
"cell_type": "code",
"execution_count": 81,
"execution_count": 85,
"metadata": {},
"outputs": [],
"source": [
@ -1150,7 +1201,7 @@
},
{
"cell_type": "code",
"execution_count": 82,
"execution_count": 86,
"metadata": {},
"outputs": [],
"source": [
@ -1162,7 +1213,7 @@
},
{
"cell_type": "code",
"execution_count": 83,
"execution_count": 87,
"metadata": {},
"outputs": [],
"source": [
@ -1181,7 +1232,7 @@
},
{
"cell_type": "code",
"execution_count": 84,
"execution_count": 88,
"metadata": {},
"outputs": [],
"source": [
@ -1194,7 +1245,7 @@
},
{
"cell_type": "code",
"execution_count": 85,
"execution_count": 89,
"metadata": {},
"outputs": [],
"source": [
@ -1208,7 +1259,7 @@
},
{
"cell_type": "code",
"execution_count": 86,
"execution_count": 90,
"metadata": {},
"outputs": [],
"source": [
@ -1220,7 +1271,7 @@
},
{
"cell_type": "code",
"execution_count": 87,
"execution_count": 91,
"metadata": {},
"outputs": [],
"source": [
@ -1232,7 +1283,7 @@
},
{
"cell_type": "code",
"execution_count": 88,
"execution_count": 92,
"metadata": {},
"outputs": [],
"source": [
@ -1244,7 +1295,7 @@
},
{
"cell_type": "code",
"execution_count": 89,
"execution_count": 93,
"metadata": {},
"outputs": [],
"source": [
@ -1258,7 +1309,7 @@
},
{
"cell_type": "code",
"execution_count": 90,
"execution_count": 94,
"metadata": {},
"outputs": [],
"source": [
@ -1268,7 +1319,7 @@
},
{
"cell_type": "code",
"execution_count": 91,
"execution_count": 95,
"metadata": {},
"outputs": [],
"source": [
@ -1284,7 +1335,7 @@
},
{
"cell_type": "code",
"execution_count": 92,
"execution_count": 96,
"metadata": {},
"outputs": [],
"source": [
@ -1313,7 +1364,7 @@
},
{
"cell_type": "code",
"execution_count": 93,
"execution_count": 97,
"metadata": {},
"outputs": [],
"source": [
@ -1322,7 +1373,7 @@
},
{
"cell_type": "code",
"execution_count": 94,
"execution_count": 98,
"metadata": {},
"outputs": [],
"source": [
@ -1338,7 +1389,7 @@
},
{
"cell_type": "code",
"execution_count": 95,
"execution_count": 99,
"metadata": {},
"outputs": [],
"source": [
@ -1349,7 +1400,7 @@
},
{
"cell_type": "code",
"execution_count": 96,
"execution_count": 100,
"metadata": {},
"outputs": [],
"source": [
@ -1358,7 +1409,7 @@
},
{
"cell_type": "code",
"execution_count": 97,
"execution_count": 101,
"metadata": {},
"outputs": [],
"source": [
@ -1378,7 +1429,7 @@
},
{
"cell_type": "code",
"execution_count": 98,
"execution_count": 102,
"metadata": {},
"outputs": [],
"source": [
@ -1389,7 +1440,7 @@
},
{
"cell_type": "code",
"execution_count": 99,
"execution_count": 103,
"metadata": {},
"outputs": [],
"source": [
@ -1399,12 +1450,13 @@
},
{
"cell_type": "code",
"execution_count": 100,
"execution_count": 104,
"metadata": {},
"outputs": [],
"source": [
"extra_attribs = [\"rooms_per_hhold\", \"pop_per_hhold\", \"bedrooms_per_room\"]\n",
"cat_encoder = cat_pipeline.named_steps[\"cat_encoder\"]\n",
"#cat_encoder = cat_pipeline.named_steps[\"cat_encoder\"] # old solution\n",
"cat_encoder = full_pipeline.named_transformers_[\"cat\"]\n",
"cat_one_hot_attribs = list(cat_encoder.categories_[0])\n",
"attributes = num_attribs + extra_attribs + cat_one_hot_attribs\n",
"sorted(zip(feature_importances, attributes), reverse=True)"
@ -1412,7 +1464,7 @@
},
{
"cell_type": "code",
"execution_count": 101,
"execution_count": 105,
"metadata": {},
"outputs": [],
"source": [
@ -1430,7 +1482,7 @@
},
{
"cell_type": "code",
"execution_count": 102,
"execution_count": 106,
"metadata": {},
"outputs": [],
"source": [
@ -1446,7 +1498,7 @@
},
{
"cell_type": "code",
"execution_count": 103,
"execution_count": 107,
"metadata": {},
"outputs": [],
"source": [
@ -1455,7 +1507,7 @@
},
{
"cell_type": "code",
"execution_count": 104,
"execution_count": 108,
"metadata": {},
"outputs": [],
"source": [
@ -1478,7 +1530,7 @@
},
{
"cell_type": "code",
"execution_count": 105,
"execution_count": 109,
"metadata": {},
"outputs": [],
"source": [
@ -1496,7 +1548,7 @@
},
{
"cell_type": "code",
"execution_count": 106,
"execution_count": 110,
"metadata": {},
"outputs": [],
"source": [
@ -1521,7 +1573,7 @@
},
{
"cell_type": "code",
"execution_count": 107,
"execution_count": 111,
"metadata": {},
"outputs": [],
"source": [
@ -1543,7 +1595,7 @@
},
{
"cell_type": "code",
"execution_count": 108,
"execution_count": 112,
"metadata": {},
"outputs": [],
"source": [
@ -1552,7 +1604,7 @@
},
{
"cell_type": "code",
"execution_count": 109,
"execution_count": 113,
"metadata": {},
"outputs": [],
"source": [
@ -1571,7 +1623,7 @@
},
{
"cell_type": "code",
"execution_count": 110,
"execution_count": 114,
"metadata": {},
"outputs": [],
"source": [
@ -1609,7 +1661,7 @@
},
{
"cell_type": "code",
"execution_count": 111,
"execution_count": 115,
"metadata": {},
"outputs": [],
"source": [
@ -1635,7 +1687,7 @@
},
{
"cell_type": "code",
"execution_count": 112,
"execution_count": 116,
"metadata": {},
"outputs": [],
"source": [
@ -1653,7 +1705,7 @@
},
{
"cell_type": "code",
"execution_count": 113,
"execution_count": 117,
"metadata": {},
"outputs": [],
"source": [
@ -1683,7 +1735,7 @@
},
{
"cell_type": "code",
"execution_count": 114,
"execution_count": 118,
"metadata": {},
"outputs": [],
"source": [
@ -1716,7 +1768,7 @@
},
{
"cell_type": "code",
"execution_count": 115,
"execution_count": 119,
"metadata": {},
"outputs": [],
"source": [
@ -1734,7 +1786,7 @@
},
{
"cell_type": "code",
"execution_count": 116,
"execution_count": 120,
"metadata": {},
"outputs": [],
"source": [
@ -1757,7 +1809,7 @@
},
{
"cell_type": "code",
"execution_count": 117,
"execution_count": 121,
"metadata": {},
"outputs": [],
"source": [
@ -1782,7 +1834,7 @@
},
{
"cell_type": "code",
"execution_count": 118,
"execution_count": 122,
"metadata": {},
"outputs": [],
"source": [
@ -1821,7 +1873,7 @@
},
{
"cell_type": "code",
"execution_count": 119,
"execution_count": 123,
"metadata": {},
"outputs": [],
"source": [
@ -1857,7 +1909,7 @@
},
{
"cell_type": "code",
"execution_count": 120,
"execution_count": 124,
"metadata": {},
"outputs": [],
"source": [
@ -1873,7 +1925,7 @@
},
{
"cell_type": "code",
"execution_count": 121,
"execution_count": 125,
"metadata": {},
"outputs": [],
"source": [
@ -1883,7 +1935,7 @@
},
{
"cell_type": "code",
"execution_count": 122,
"execution_count": 126,
"metadata": {},
"outputs": [],
"source": [
@ -1899,7 +1951,7 @@
},
{
"cell_type": "code",
"execution_count": 123,
"execution_count": 127,
"metadata": {},
"outputs": [],
"source": [
@ -1915,7 +1967,7 @@
},
{
"cell_type": "code",
"execution_count": 124,
"execution_count": 128,
"metadata": {},
"outputs": [],
"source": [
@ -1927,7 +1979,7 @@
},
{
"cell_type": "code",
"execution_count": 125,
"execution_count": 129,
"metadata": {},
"outputs": [],
"source": [
@ -1943,7 +1995,7 @@
},
{
"cell_type": "code",
"execution_count": 126,
"execution_count": 130,
"metadata": {},
"outputs": [],
"source": [
@ -1959,7 +2011,7 @@
},
{
"cell_type": "code",
"execution_count": 127,
"execution_count": 131,
"metadata": {},
"outputs": [],
"source": [
@ -1989,7 +2041,7 @@
},
{
"cell_type": "code",
"execution_count": 128,
"execution_count": 132,
"metadata": {},
"outputs": [],
"source": [
@ -2002,7 +2054,7 @@
},
{
"cell_type": "code",
"execution_count": 129,
"execution_count": 133,
"metadata": {},
"outputs": [],
"source": [
@ -2018,7 +2070,7 @@
},
{
"cell_type": "code",
"execution_count": 130,
"execution_count": 134,
"metadata": {},
"outputs": [],
"source": [
@ -2052,12 +2104,12 @@
},
{
"cell_type": "code",
"execution_count": 131,
"execution_count": 136,
"metadata": {},
"outputs": [],
"source": [
"param_grid = [\n",
" {'preparation__num_pipeline__imputer__strategy': ['mean', 'median', 'most_frequent'],\n",
" {'preparation__num__imputer__strategy': ['mean', 'median', 'most_frequent'],\n",
" 'feature_selection__k': list(range(1, len(feature_importances) + 1))}\n",
"]\n",
"\n",
@ -2068,7 +2120,7 @@
},
{
"cell_type": "code",
"execution_count": 132,
"execution_count": 137,
"metadata": {},
"outputs": [],
"source": [
@ -2106,7 +2158,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
"version": "3.6.5"
},
"nav_menu": {
"height": "279px",