Replace FeatureUnion + DataFrameSelector with new ColumnTransformer

main
Aurélien Geron 2018-07-31 20:08:33 +01:00
parent 060751a976
commit e2d450708a
1 changed file with 131 additions and 79 deletions

View File

@ -70,7 +70,7 @@
"\n", "\n",
"# Ignore useless warnings (see SciPy issue #5998)\n", "# Ignore useless warnings (see SciPy issue #5998)\n",
"import warnings\n", "import warnings\n",
"warnings.filterwarnings(action=\"ignore\", module=\"scipy\", message=\"^internal gelsd\")" "warnings.filterwarnings(action=\"ignore\", module=\"scipy\", message=\"internal gelsd\")"
] ]
}, },
{ {
@ -980,12 +980,56 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"And a transformer to just select a subset of the Pandas DataFrame columns:" "**Warning**: earlier versions of the book applied different transformations to different columns using a solution based on a `DataFrameSelector` transformer and a `FeatureUnion` (see below). It is now preferable to use the `ColumnTransformer` class that will be introduced in Scikit-Learn 0.20. For now we import it from `future_encoders.py`, but when Scikit-Learn 0.20 is released, you can import it from `sklearn.compose` instead:"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 71, "execution_count": 72,
"metadata": {},
"outputs": [],
"source": [
"from future_encoders import ColumnTransformer\n",
"\n",
"num_attribs = list(housing_num)\n",
"cat_attribs = [\"ocean_proximity\"]\n",
"\n",
"full_pipeline = ColumnTransformer([\n",
" (\"num\", num_pipeline, num_attribs),\n",
" (\"cat\", OneHotEncoder(), cat_attribs),\n",
" ])\n",
"\n",
"housing_prepared = full_pipeline.fit_transform(housing)"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [],
"source": [
"housing_prepared"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
"housing_prepared.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, here is the old solution based on a `DataFrameSelector` transformer (renamed `OldDataFrameSelector` below; it just selects a subset of the Pandas `DataFrame` columns), and a `FeatureUnion`:"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -993,7 +1037,7 @@
"\n", "\n",
"# Create a class to select numerical or categorical columns \n", "# Create a class to select numerical or categorical columns \n",
"# since Scikit-Learn doesn't handle DataFrames yet\n", "# since Scikit-Learn doesn't handle DataFrames yet\n",
"class DataFrameSelector(BaseEstimator, TransformerMixin):\n", "class OldDataFrameSelector(BaseEstimator, TransformerMixin):\n",
" def __init__(self, attribute_names):\n", " def __init__(self, attribute_names):\n",
" self.attribute_names = attribute_names\n", " self.attribute_names = attribute_names\n",
" def fit(self, X, y=None):\n", " def fit(self, X, y=None):\n",
@ -1011,57 +1055,64 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 72, "execution_count": 76,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"num_attribs = list(housing_num)\n", "num_attribs = list(housing_num)\n",
"cat_attribs = [\"ocean_proximity\"]\n", "cat_attribs = [\"ocean_proximity\"]\n",
"\n", "\n",
"num_pipeline = Pipeline([\n", "old_num_pipeline = Pipeline([\n",
" ('selector', DataFrameSelector(num_attribs)),\n", " ('selector', OldDataFrameSelector(num_attribs)),\n",
" ('imputer', Imputer(strategy=\"median\")),\n", " ('imputer', Imputer(strategy=\"median\")),\n",
" ('attribs_adder', CombinedAttributesAdder()),\n", " ('attribs_adder', CombinedAttributesAdder()),\n",
" ('std_scaler', StandardScaler()),\n", " ('std_scaler', StandardScaler()),\n",
" ])\n", " ])\n",
"\n", "\n",
"cat_pipeline = Pipeline([\n", "old_cat_pipeline = Pipeline([\n",
" ('selector', DataFrameSelector(cat_attribs)),\n", " ('selector', OldDataFrameSelector(cat_attribs)),\n",
" ('cat_encoder', OneHotEncoder(sparse=False)),\n", " ('cat_encoder', OneHotEncoder(sparse=False)),\n",
" ])" " ])"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 73, "execution_count": 77,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from sklearn.pipeline import FeatureUnion\n", "from sklearn.pipeline import FeatureUnion\n",
"\n", "\n",
"full_pipeline = FeatureUnion(transformer_list=[\n", "old_full_pipeline = FeatureUnion(transformer_list=[\n",
" (\"num_pipeline\", num_pipeline),\n", " (\"num_pipeline\", old_num_pipeline),\n",
" (\"cat_pipeline\", cat_pipeline),\n", " (\"cat_pipeline\", old_cat_pipeline),\n",
" ])" " ])"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 74, "execution_count": 78,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"housing_prepared = full_pipeline.fit_transform(housing)\n", "old_housing_prepared = old_full_pipeline.fit_transform(housing)\n",
"housing_prepared" "old_housing_prepared"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The result is the same as with the `ColumnTransformer`:"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 75, "execution_count": 79,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"housing_prepared.shape" "np.allclose(housing_prepared, old_housing_prepared)"
] ]
}, },
{ {
@ -1073,7 +1124,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 76, "execution_count": 80,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1085,11 +1136,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 77, "execution_count": 81,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# let's try the full pipeline on a few training instances\n", "# let's try the full preprocessing pipeline on a few training instances\n",
"some_data = housing.iloc[:5]\n", "some_data = housing.iloc[:5]\n",
"some_labels = housing_labels.iloc[:5]\n", "some_labels = housing_labels.iloc[:5]\n",
"some_data_prepared = full_pipeline.transform(some_data)\n", "some_data_prepared = full_pipeline.transform(some_data)\n",
@ -1106,7 +1157,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 78, "execution_count": 82,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1115,7 +1166,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 79, "execution_count": 83,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1124,7 +1175,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 80, "execution_count": 84,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1138,7 +1189,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 81, "execution_count": 85,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1150,7 +1201,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 82, "execution_count": 86,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1162,7 +1213,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 83, "execution_count": 87,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1181,7 +1232,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 84, "execution_count": 88,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1194,7 +1245,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 85, "execution_count": 89,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1208,7 +1259,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 86, "execution_count": 90,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1220,7 +1271,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 87, "execution_count": 91,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1232,7 +1283,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 88, "execution_count": 92,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1244,7 +1295,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 89, "execution_count": 93,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1258,7 +1309,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 90, "execution_count": 94,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1268,7 +1319,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 91, "execution_count": 95,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1284,7 +1335,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 92, "execution_count": 96,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1313,7 +1364,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 93, "execution_count": 97,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1322,7 +1373,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 94, "execution_count": 98,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1338,7 +1389,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 95, "execution_count": 99,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1349,7 +1400,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 96, "execution_count": 100,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1358,7 +1409,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 97, "execution_count": 101,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1378,7 +1429,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 98, "execution_count": 102,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1389,7 +1440,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 99, "execution_count": 103,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1399,12 +1450,13 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 100, "execution_count": 104,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"extra_attribs = [\"rooms_per_hhold\", \"pop_per_hhold\", \"bedrooms_per_room\"]\n", "extra_attribs = [\"rooms_per_hhold\", \"pop_per_hhold\", \"bedrooms_per_room\"]\n",
"cat_encoder = cat_pipeline.named_steps[\"cat_encoder\"]\n", "#cat_encoder = old_cat_pipeline.named_steps[\"cat_encoder\"] # old solution\n",
"cat_encoder = full_pipeline.named_transformers_[\"cat\"]\n",
"cat_one_hot_attribs = list(cat_encoder.categories_[0])\n", "cat_one_hot_attribs = list(cat_encoder.categories_[0])\n",
"attributes = num_attribs + extra_attribs + cat_one_hot_attribs\n", "attributes = num_attribs + extra_attribs + cat_one_hot_attribs\n",
"sorted(zip(feature_importances, attributes), reverse=True)" "sorted(zip(feature_importances, attributes), reverse=True)"
@ -1412,7 +1464,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 101, "execution_count": 105,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1430,7 +1482,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 102, "execution_count": 106,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1446,7 +1498,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 103, "execution_count": 107,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1455,7 +1507,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 104, "execution_count": 108,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1478,7 +1530,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 105, "execution_count": 109,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1496,7 +1548,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 106, "execution_count": 110,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1521,7 +1573,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 107, "execution_count": 111,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1543,7 +1595,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 108, "execution_count": 112,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1552,7 +1604,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 109, "execution_count": 113,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1571,7 +1623,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 110, "execution_count": 114,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1609,7 +1661,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 111, "execution_count": 115,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1635,7 +1687,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 112, "execution_count": 116,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1653,7 +1705,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 113, "execution_count": 117,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1683,7 +1735,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 114, "execution_count": 118,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1716,7 +1768,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 115, "execution_count": 119,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1734,7 +1786,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 116, "execution_count": 120,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1757,7 +1809,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 117, "execution_count": 121,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1782,7 +1834,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 118, "execution_count": 122,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1821,7 +1873,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 119, "execution_count": 123,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1857,7 +1909,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 120, "execution_count": 124,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1873,7 +1925,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 121, "execution_count": 125,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1883,7 +1935,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 122, "execution_count": 126,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1899,7 +1951,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 123, "execution_count": 127,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1915,7 +1967,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 124, "execution_count": 128,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1927,7 +1979,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 125, "execution_count": 129,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1943,7 +1995,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 126, "execution_count": 130,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1959,7 +2011,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 127, "execution_count": 131,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1989,7 +2041,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 128, "execution_count": 132,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2002,7 +2054,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 129, "execution_count": 133,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2018,7 +2070,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 130, "execution_count": 134,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2052,12 +2104,12 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 131, "execution_count": 136,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"param_grid = [\n", "param_grid = [\n",
" {'preparation__num_pipeline__imputer__strategy': ['mean', 'median', 'most_frequent'],\n", " {'preparation__num__imputer__strategy': ['mean', 'median', 'most_frequent'],\n",
" 'feature_selection__k': list(range(1, len(feature_importances) + 1))}\n", " 'feature_selection__k': list(range(1, len(feature_importances) + 1))}\n",
"]\n", "]\n",
"\n", "\n",
@ -2068,7 +2120,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 132, "execution_count": 137,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2106,7 +2158,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.5.2" "version": "3.6.5"
}, },
"nav_menu": { "nav_menu": {
"height": "279px", "height": "279px",