Fix get_feature_names_out for FunctionTransformer

main
Aurélien Geron 2022-09-23 10:06:07 +12:00
parent 0573deb5d3
commit 576cec95d9
1 changed file with 74 additions and 80 deletions

@@ -2302,6 +2302,13 @@
     "outlier_pred"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "If you wanted to drop outliers, you would run the following code:"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 61,
@@ -3467,20 +3474,11 @@
    "cell_type": "code",
    "execution_count": 104,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Monkey-patching SimpleImputer.get_feature_names_out()\n",
-      "Monkey-patching FunctionTransformer.get_feature_names_out()\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
    "def monkey_patch_get_signature_names_out():\n",
    "    \"\"\"Monkey patch some classes which did not handle get_feature_names_out()\n",
-    "    correctly in 1.0.0.\"\"\"\n",
+    "    correctly in Scikit-Learn 1.0.*.\"\"\"\n",
    "    from inspect import Signature, signature, Parameter\n",
    "    import pandas as pd\n",
    "    from sklearn.impute import SimpleImputer\n",
@@ -3508,12 +3506,10 @@
    "        Parameter(\"feature_names_out\", Parameter.KEYWORD_ONLY)])\n",
    "\n",
    "    def get_feature_names_out(self, names=None):\n",
-    "        if self.feature_names_out is None:\n",
-    "            return default_get_feature_names_out(self, names)\n",
-    "        elif callable(self.feature_names_out):\n",
-    "            return self.feature_names_out(names)\n",
-    "        else:\n",
-    "            return self.feature_names_out\n",
+    "        if callable(self.feature_names_out):\n",
+    "            return self.feature_names_out(self, names)\n",
+    "        assert self.feature_names_out == \"one-to-one\"\n",
+    "        return default_get_feature_names_out(self, names)\n",
    "\n",
    "    FunctionTransformer.__init__ = __init__\n",
    "    FunctionTransformer.get_feature_names_out = get_feature_names_out\n",
@@ -3896,28 +3892,28 @@
    "def column_ratio(X):\n",
    "    return X[:, [0]] / X[:, [1]]\n",
    "\n",
-    "def ratio_pipeline(name=None):\n",
+    "def ratio_name(function_transformer, feature_names_in):\n",
+    "    return [\"ratio\"]  # feature names out\n",
+    "\n",
+    "def ratio_pipeline():\n",
    "    return make_pipeline(\n",
    "        SimpleImputer(strategy=\"median\"),\n",
-    "        FunctionTransformer(column_ratio,\n",
-    "                            feature_names_out=[name]),\n",
+    "        FunctionTransformer(column_ratio, feature_names_out=ratio_name),\n",
    "        StandardScaler())\n",
    "\n",
-    "log_pipeline = make_pipeline(SimpleImputer(strategy=\"median\"),\n",
-    "                             FunctionTransformer(np.log),\n",
-    "                             StandardScaler())\n",
+    "log_pipeline = make_pipeline(\n",
+    "    SimpleImputer(strategy=\"median\"),\n",
+    "    FunctionTransformer(np.log, feature_names_out=\"one-to-one\"),\n",
+    "    StandardScaler())\n",
    "cluster_simil = ClusterSimilarity(n_clusters=10, gamma=1., random_state=42)\n",
    "default_num_pipeline = make_pipeline(SimpleImputer(strategy=\"median\"),\n",
    "                                     StandardScaler())\n",
    "preprocessing = ColumnTransformer([\n",
-    "        (\"bedrooms_ratio\", ratio_pipeline(\"bedrooms_ratio\"),\n",
-    "         [\"total_bedrooms\", \"total_rooms\"]),\n",
-    "        (\"rooms_per_house\", ratio_pipeline(\"rooms_per_house\"),\n",
-    "         [\"total_rooms\", \"households\"]),\n",
-    "        (\"people_per_house\", ratio_pipeline(\"people_per_house\"),\n",
-    "         [\"population\", \"households\"]),\n",
-    "        (\"log\", log_pipeline, [\"total_bedrooms\", \"total_rooms\",\n",
-    "                               \"population\", \"households\", \"median_income\"]),\n",
+    "        (\"bedrooms\", ratio_pipeline(), [\"total_bedrooms\", \"total_rooms\"]),\n",
+    "        (\"rooms_per_house\", ratio_pipeline(), [\"total_rooms\", \"households\"]),\n",
+    "        (\"people_per_house\", ratio_pipeline(), [\"population\", \"households\"]),\n",
+    "        (\"log\", log_pipeline, [\"total_bedrooms\", \"total_rooms\", \"population\",\n",
+    "                               \"households\", \"median_income\"]),\n",
    "        (\"geo\", cluster_simil, [\"latitude\", \"longitude\"]),\n",
    "        (\"cat\", cat_pipeline, make_column_selector(dtype_include=object)),\n",
    "    ],\n",
@@ -3953,9 +3949,8 @@
   {
    "data": {
     "text/plain": [
-      "array(['bedrooms_ratio__bedrooms_ratio',\n",
-      "       'rooms_per_house__rooms_per_house',\n",
-      "       'people_per_house__people_per_house', 'log__total_bedrooms',\n",
+      "array(['bedrooms__ratio', 'rooms_per_house__ratio',\n",
+      "       'people_per_house__ratio', 'log__total_bedrooms',\n",
      "       'log__total_rooms', 'log__population', 'log__households',\n",
      "       'log__median_income', 'geo__Cluster 0 similarity',\n",
      "       'geo__Cluster 1 similarity', 'geo__Cluster 2 similarity',\n",
@@ -4004,12 +3999,12 @@
    " SimpleImputer(strategy='median')),\n",
    " ('standardscaler',\n",
    " StandardScaler())]),\n",
-    " transformers=[('bedrooms_ratio',\n",
+    " transformers=[('bedrooms',\n",
    " Pipeline(steps=[('simpleimputer',\n",
    " SimpleImputer(strategy='median')),\n",
    " ('functiontransformer',\n",
-    " FunctionTransformer(feature_names_out=['bedrooms_ratio'],\n",
-    " f...\n",
+    " FunctionTransformer(feature_names_out=<function ratio_name at 0x1a5...\n",
+    " 'households',\n",
    " 'median_income']),\n",
    " ('geo',\n",
    " ClusterSimilarity(random_state=42),\n",
@@ -4019,7 +4014,7 @@
    " SimpleImputer(strategy='most_frequent')),\n",
    " ('onehotencoder',\n",
    " OneHotEncoder(handle_unknown='ignore'))]),\n",
-    " <sklearn.compose._column_transformer.make_column_selector object at 0x7f9b50613dc0>)])),\n",
+    " <sklearn.compose._column_transformer.make_column_selector object at 0x1a57e3a00>)])),\n",
    " ('linearregression', LinearRegression())])"
    ]
   },
@@ -4146,12 +4141,12 @@
    " SimpleImputer(strategy='median')),\n",
    " ('standardscaler',\n",
    " StandardScaler())]),\n",
-    " transformers=[('bedrooms_ratio',\n",
+    " transformers=[('bedrooms',\n",
    " Pipeline(steps=[('simpleimputer',\n",
    " SimpleImputer(strategy='median')),\n",
    " ('functiontransformer',\n",
-    " FunctionTransformer(feature_names_out=['bedrooms_ratio'],\n",
-    " f...\n",
+    " FunctionTransformer(feature_names_out=<function ratio_name at 0x1a5...\n",
+    " ('geo',\n",
    " ClusterSimilarity(random_state=42),\n",
    " ['latitude', 'longitude']),\n",
    " ('cat',\n",
@@ -4159,7 +4154,7 @@
    " SimpleImputer(strategy='most_frequent')),\n",
    " ('onehotencoder',\n",
    " OneHotEncoder(handle_unknown='ignore'))]),\n",
-    " <sklearn.compose._column_transformer.make_column_selector object at 0x7f9b50613dc0>)])),\n",
+    " <sklearn.compose._column_transformer.make_column_selector object at 0x1a57e3a00>)])),\n",
    " ('decisiontreeregressor',\n",
    " DecisionTreeRegressor(random_state=42))])"
    ]
@@ -4399,12 +4394,12 @@
    " SimpleImputer(strategy='median')),\n",
    " ('standardscaler',\n",
    " StandardScaler())]),\n",
-    " transformers=[('bedrooms_ratio',\n",
+    " transformers=[('bedrooms',\n",
    " Pipeline(steps=[('simpleimputer',\n",
    " SimpleImputer(strategy='median')),\n",
    " ('functiontransformer',\n",
-    " FunctionTransformer(feature_names_...\n",
-    " <sklearn.compose._column_transformer.make_column_selector object at 0x7f9b50613dc0>)])),\n",
+    " FunctionTransformer(feature_names_out=<f...\n",
+    " <sklearn.compose._column_transformer.make_column_selector object at 0x1a57e3a00>)])),\n",
    " ('random_forest',\n",
    " RandomForestRegressor(random_state=42))]),\n",
    " param_grid=[{'preprocessing__geo__n_clusters': [5, 8, 10],\n",
@@ -4502,12 +4497,11 @@
    " SimpleImputer(strategy='median')),\n",
    " ('standardscaler',\n",
    " StandardScaler())]),\n",
-    " transformers=[('bedrooms_ratio',\n",
+    " transformers=[('bedrooms',\n",
    " Pipeline(steps=[('simpleimputer',\n",
    " SimpleImputer(strategy='median')),\n",
    " ('functiontransformer',\n",
-    " FunctionTransformer(feature_names_out=['bedrooms_ratio'],\n",
-    " func=...\n",
+    " FunctionTransformer(feature_names_out=<function ratio_name at 0x1a5b6fd...\n",
    " ClusterSimilarity(n_clusters=15,\n",
    " random_state=42),\n",
    " ['latitude', 'longitude']),\n",
@@ -4516,7 +4510,7 @@
    " SimpleImputer(strategy='most_frequent')),\n",
    " ('onehotencoder',\n",
    " OneHotEncoder(handle_unknown='ignore'))]),\n",
-    " <sklearn.compose._column_transformer.make_column_selector object at 0x7f9b410ec490>)])),\n",
+    " <sklearn.compose._column_transformer.make_column_selector object at 0x1a5cdffd0>)])),\n",
    " ('random_forest',\n",
    " RandomForestRegressor(max_features=6, random_state=42))])"
@@ -4657,13 +4651,6 @@
    "## Randomized Search"
   ]
  },
- {
-  "cell_type": "markdown",
-  "metadata": {},
-  "source": [
-   "**Warning:** the following cell may take a few minutes to run:"
-  ]
- },
  {
   "cell_type": "code",
   "execution_count": 137,
@@ -4681,6 +4668,13 @@
    "Try 30 (`n_iter` × `cv`) random combinations of hyperparameters:"
   ]
  },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "**Warning:** the following cell may take a few minutes to run:"
+  ]
+ },
  {
   "cell_type": "code",
   "execution_count": 138,
@@ -4695,16 +4689,16 @@
    " SimpleImputer(strategy='median')),\n",
    " ('standardscaler',\n",
    " StandardScaler())]),\n",
-    " transformers=[('bedrooms_ratio',\n",
+    " transformers=[('bedrooms',\n",
    " Pipeline(steps=[('simpleimputer',\n",
    " SimpleImputer(strategy='median')),\n",
    " ('functiontransformer',\n",
-    " FunctionTransformer(feature_...\n",
-    " <sklearn.compose._column_transformer.make_column_selector object at 0x7f9b50613dc0>)])),\n",
+    " FunctionTransformer(feature_names_...\n",
+    " <sklearn.compose._column_transformer.make_column_selector object at 0x1a57e3a00>)])),\n",
    " ('random_forest',\n",
    " RandomForestRegressor(random_state=42))]),\n",
-    " param_distributions={'preprocessing__geo__n_clusters': <scipy.stats._distn_infrastructure.rv_frozen object at 0x7f9b103bb760>,\n",
-    " 'random_forest__max_features': <scipy.stats._distn_infrastructure.rv_frozen object at 0x7f9b410decd0>},\n",
+    " param_distributions={'preprocessing__geo__n_clusters': <scipy.stats._distn_infrastructure.rv_discrete_frozen object at 0x1a4bfcb20>,\n",
+    " 'random_forest__max_features': <scipy.stats._distn_infrastructure.rv_discrete_frozen object at 0x1a57c7bb0>},\n",
    " random_state=42, scoring='neg_root_mean_squared_error')"
    ]
   },
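
The rv_discrete_frozen objects in this repr are frozen scipy distributions. A sketch of how such a search is typically set up (the bounds and the `full_pipeline`/`housing` names below are illustrative assumptions, not read from this diff):

    from scipy.stats import randint
    from sklearn.model_selection import RandomizedSearchCV

    param_distribs = {"preprocessing__geo__n_clusters": randint(low=3, high=50),
                      "random_forest__max_features": randint(low=2, high=20)}
    rnd_search = RandomizedSearchCV(
        full_pipeline, param_distributions=param_distribs,
        n_iter=10, cv=3,  # 10 × 3 = 30 fits, matching the markdown above
        scoring="neg_root_mean_squared_error", random_state=42)
    rnd_search.fit(housing, housing_labels)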
@@ -5046,9 +5040,9 @@
     "text/plain": [
      "[(0.18694559869103852, 'log__median_income'),\n",
      " (0.0748194905715524, 'cat__ocean_proximity_INLAND'),\n",
-      " (0.06926417748515576, 'bedrooms_ratio__bedrooms_ratio'),\n",
-      " (0.05446998753775219, 'rooms_per_house__rooms_per_house'),\n",
-      " (0.05262301809680712, 'people_per_house__people_per_house'),\n",
+      " (0.06926417748515576, 'bedrooms__ratio'),\n",
+      " (0.05446998753775219, 'rooms_per_house__ratio'),\n",
+      " (0.05262301809680712, 'people_per_house__ratio'),\n",
      " (0.03819415873915732, 'geo__Cluster 0 similarity'),\n",
      " (0.02879263999929514, 'geo__Cluster 28 similarity'),\n",
      " (0.023530192521380392, 'geo__Cluster 24 similarity'),\n",
@@ -5333,7 +5327,7 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "Also works with pickle, but joblib is more efficient."
+   "You could use pickle instead, but joblib is more efficient."
   ]
  },
  {
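
For reference, saving and restoring the model with joblib looks like this (the filename is illustrative):

    import joblib

    joblib.dump(final_model, "my_california_housing_model.pkl")
    # ...later, possibly in another process:
    final_model_reloaded = joblib.load("my_california_housing_model.pkl")

Keep in mind that any custom classes and functions the pipeline references (e.g. ClusterSimilarity, column_ratio) must be importable when the model is loaded back.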
@@ -5371,12 +5365,12 @@
    " SimpleImputer(strategy='median')),\n",
    " ('standardscaler',\n",
    " StandardScaler())]),\n",
-    " transformers=[('bedrooms_ratio',\n",
+    " transformers=[('bedrooms',\n",
    " Pipeline(steps=[('simpleimputer',\n",
    " SimpleImputer(strategy='median')),\n",
    " ('functiontransformer',\n",
-    " FunctionTransformer(feature_names_...\n",
-    " <sklearn.compose._column_transformer.make_column_selector object at 0x7f9b50613dc0>)])),\n",
+    " FunctionTransformer(feature_names_out=<f...\n",
+    " <sklearn.compose._column_transformer.make_column_selector object at 0x1a57e3a00>)])),\n",
    " ('svr', SVR())]),\n",
    " param_grid=[{'svr__C': [10.0, 30.0, 100.0, 300.0, 1000.0, 3000.0,\n",
    " 10000.0, 30000.0],\n",
@@ -5508,16 +5502,16 @@
    " SimpleImputer(strategy='median')),\n",
    " ('standardscaler',\n",
    " StandardScaler())]),\n",
-    " transformers=[('bedrooms_ratio',\n",
+    " transformers=[('bedrooms',\n",
    " Pipeline(steps=[('simpleimputer',\n",
    " SimpleImputer(strategy='median')),\n",
    " ('functiontransformer',\n",
-    " FunctionTransformer(feature_...\n",
-    " <sklearn.compose._column_transformer.make_column_selector object at 0x7f9b50613dc0>)])),\n",
+    " FunctionTransformer(feature_names_...\n",
+    " <sklearn.compose._column_transformer.make_column_selector object at 0x1a57e3a00>)])),\n",
    " ('svr', SVR())]),\n",
    " n_iter=50,\n",
-    " param_distributions={'svr__C': <scipy.stats._distn_infrastructure.rv_frozen object at 0x7f9ae254b9d0>,\n",
-    " 'svr__gamma': <scipy.stats._distn_infrastructure.rv_frozen object at 0x7f9b734dbe50>,\n",
+    " param_distributions={'svr__C': <scipy.stats._distn_infrastructure.rv_continuous_frozen object at 0x1a5d4c3a0>,\n",
+    " 'svr__gamma': <scipy.stats._distn_infrastructure.rv_continuous_frozen object at 0x1a5d9ca00>,\n",
    " 'svr__kernel': ['linear', 'rbf']},\n",
    " random_state=42, scoring='neg_root_mean_squared_error')"
    ]
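
Note the rv_continuous_frozen reprs: unlike the integer-valued randint distributions used earlier, svr__C and svr__gamma are drawn from continuous distributions. One plausible setup (the ranges here are illustrative assumptions):

    from scipy.stats import expon, loguniform

    param_distribs = {
        "svr__kernel": ["linear", "rbf"],
        "svr__C": loguniform(20, 200_000),  # log-uniform across several orders of magnitude
        "svr__gamma": expon(scale=1.0),     # mostly small values, with a long tail
    }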
@@ -5970,20 +5964,19 @@
     "text/plain": [
      "RandomizedSearchCV(cv=3,\n",
      " estimator=Pipeline(steps=[('preprocessing',\n",
-      " ColumnTransformer(transformers=[('bedrooms_ratio',\n",
+      " ColumnTransformer(transformers=[('bedrooms',\n",
      " Pipeline(steps=[('simpleimputer',\n",
      " SimpleImputer(strategy='median')),\n",
      " ('functiontransformer',\n",
-      " FunctionTransformer(feature_names_out=['bedrooms_ratio'],\n",
-      " func=<function column_ratio at 0x7f9b505e5670>)),\n",
+      " FunctionTransformer(feature_names_out=<function ratio_name at 0x1a5b6fd90>,\n",
+      " func=<function column_ratio at 0x1a5695bd0>)),\n",
      " ('standardscaler',\n",
-      " StandardScaler())]),\n",
-      " ['...\n",
+      " StandardScaler()...\n",
      " param_distributions={'preprocessing__geo__estimator__n_neighbors': range(1, 30),\n",
      " 'preprocessing__geo__estimator__weights': ['distance',\n",
      " 'uniform'],\n",
-      " 'svr__C': <scipy.stats._distn_infrastructure.rv_frozen object at 0x7f9acb940bb0>,\n",
-      " 'svr__gamma': <scipy.stats._distn_infrastructure.rv_frozen object at 0x7f9acb940a30>},\n",
+      " 'svr__C': <scipy.stats._distn_infrastructure.rv_continuous_frozen object at 0x1a63fda80>,\n",
+      " 'svr__gamma': <scipy.stats._distn_infrastructure.rv_continuous_frozen object at 0x1a63fe410>},\n",
      " random_state=42, scoring='neg_root_mean_squared_error')"
     ]
    },
@@ -6186,6 +6179,7 @@
   "source": [
    "scaler = StandardScalerClone()\n",
    "X_back = scaler.inverse_transform(scaler.fit_transform(X))\n",
+    "\n",
    "assert np.allclose(X, X_back)"
   ]
  },
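
The StandardScalerClone used in this round-trip check is defined earlier in the exercise solutions. A minimal sketch of what such a clone looks like, enough for fit_transform() followed by inverse_transform() to reproduce X:

    import numpy as np
    from sklearn.base import BaseEstimator, TransformerMixin
    from sklearn.utils.validation import check_array, check_is_fitted

    class StandardScalerClone(BaseEstimator, TransformerMixin):
        def __init__(self, with_mean=True):  # no *args or **kwargs!
            self.with_mean = with_mean

        def fit(self, X, y=None):  # y is required even though it is unused
            X = check_array(X)  # checks X is an array of finite floats
            self.mean_ = X.mean(axis=0)
            self.scale_ = X.std(axis=0)
            self.n_features_in_ = X.shape[1]
            return self  # always return self!

        def transform(self, X):
            check_is_fitted(self)  # looks for learned attributes (trailing _)
            X = check_array(X)
            if self.with_mean:
                X = X - self.mean_
            return X / self.scale_

        def inverse_transform(self, X):
            check_is_fitted(self)
            X = check_array(X)
            X = X * self.scale_
            return X + self.mean_ if self.with_mean else X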