Move StandardScalerClone inverse_transform and get_feature_names_out to exercise

main
Aurélien Geron 2021-11-15 17:45:26 +13:00
parent 93676a4f23
commit c658c2b07c
1 changed files with 267 additions and 122 deletions

View File

@ -1473,27 +1473,7 @@
" assert self.n_features_in_ == X.shape[1]\n", " assert self.n_features_in_ == X.shape[1]\n",
" if self.with_mean:\n", " if self.with_mean:\n",
" X = X - self.mean_\n", " X = X - self.mean_\n",
" return X / self.scale_\n", " return X / self.scale_"
" \n",
" # not in the book (left as an exercise):\n",
" def inverse_transform(self, X):\n",
" check_is_fitted(self)\n",
" X = check_array(X)\n",
" assert self.n_features_in_ == X.shape[1]\n",
" X = X * self.scale_\n",
" return X + self.mean_ if self.with_mean else X\n",
" \n",
" # not in the book (left as an exercise):\n",
" def get_feature_names_out(self, names=None):\n",
" return names or getattr(self, \"feature_names_in_\",\n",
" [f\"x{i}\" for i in range(self.n_features_in_)]) "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's test our custom transformer:"
] ]
}, },
{ {
@ -1501,30 +1481,6 @@
"execution_count": 100, "execution_count": 100,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [
"# Not in the book\n",
"from sklearn.utils.estimator_checks import check_estimator\n",
" \n",
"check_estimator(StandardScaler())\n",
"X = np.random.rand(1000, 3)\n",
"ss = StandardScaler()\n",
"ssc = StandardScalerClone()\n",
"X_scaled1 = ss.fit_transform(X)\n",
"X_scaled2 = ssc.fit_transform(X)\n",
"X_back1 = ss.inverse_transform(X_scaled1)\n",
"X_back2 = ssc.inverse_transform(X_scaled2)\n",
"assert np.allclose(X_scaled1, X_scaled2)\n",
"assert np.allclose(X_back1, X_back2)\n",
"assert ssc.n_features_in_ == 3\n",
"assert not hasattr(ssc, \"features_names_in_\")\n",
"assert ssc.get_feature_names_out() == [\"x0\", \"x1\", \"x2\"]"
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {},
"outputs": [],
"source": [ "source": [
"from sklearn.cluster import KMeans\n", "from sklearn.cluster import KMeans\n",
"\n", "\n",
@ -1548,7 +1504,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 102, "execution_count": 101,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1559,7 +1515,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 103, "execution_count": 102,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1568,7 +1524,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 104, "execution_count": 103,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1610,7 +1566,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 105, "execution_count": 104,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1624,7 +1580,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 106, "execution_count": 105,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1635,7 +1591,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 107, "execution_count": 106,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1648,7 +1604,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 108, "execution_count": 107,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1658,7 +1614,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 109, "execution_count": 108,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1707,7 +1663,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 110, "execution_count": 109,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1719,7 +1675,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 111, "execution_count": 110,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1728,7 +1684,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 112, "execution_count": 111,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1737,7 +1693,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 113, "execution_count": 112,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1746,7 +1702,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 114, "execution_count": 113,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1755,7 +1711,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 115, "execution_count": 114,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1764,7 +1720,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 116, "execution_count": 115,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1786,7 +1742,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 117, "execution_count": 116,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1800,7 +1756,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 118, "execution_count": 117,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1809,7 +1765,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 119, "execution_count": 118,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1822,7 +1778,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 120, "execution_count": 119,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1859,7 +1815,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 121, "execution_count": 120,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1869,7 +1825,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 122, "execution_count": 121,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1892,7 +1848,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 123, "execution_count": 122,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1911,7 +1867,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 124, "execution_count": 123,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1928,7 +1884,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 125, "execution_count": 124,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1937,7 +1893,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 126, "execution_count": 125,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1948,7 +1904,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 127, "execution_count": 126,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1961,7 +1917,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 128, "execution_count": 127,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1973,7 +1929,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 129, "execution_count": 128,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1992,7 +1948,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 130, "execution_count": 129,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2004,7 +1960,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 131, "execution_count": 130,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2013,7 +1969,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 132, "execution_count": 131,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2032,7 +1988,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 133, "execution_count": 132,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2046,7 +2002,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 134, "execution_count": 133,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2062,7 +2018,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 135, "execution_count": 134,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2103,7 +2059,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 136, "execution_count": 135,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2133,7 +2089,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 137, "execution_count": 136,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2150,7 +2106,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 138, "execution_count": 137,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2159,7 +2115,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 139, "execution_count": 138,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2175,7 +2131,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 140, "execution_count": 139,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2209,7 +2165,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 141, "execution_count": 140,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2226,7 +2182,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 142, "execution_count": 141,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2245,7 +2201,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 143, "execution_count": 142,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2282,7 +2238,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 144, "execution_count": 143,
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
@ -2343,7 +2299,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 145, "execution_count": 144,
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
@ -2406,7 +2362,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 146, "execution_count": 145,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2417,7 +2373,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 147, "execution_count": 146,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2435,7 +2391,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 148, "execution_count": 147,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2457,7 +2413,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 149, "execution_count": 148,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2479,7 +2435,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 150, "execution_count": 149,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2500,7 +2456,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 151, "execution_count": 150,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2526,7 +2482,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 152, "execution_count": 151,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2544,7 +2500,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 153, "execution_count": 152,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2569,7 +2525,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 154, "execution_count": 153,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2601,12 +2557,12 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"_Try a Support Vector Machine regressor (`sklearn.svm.SVR`) with various hyperparameters, such as `kernel=\"linear\"` (with various values for the `C` hyperparameter) or `kernel=\"rbf\"` (with various values for the `C` and `gamma` hyperparameters). Note that SVMs don't scale well to large datasets, so you should probably train your model on just the first 5,000 instances of the training set and use only 3-fold cross-validation, or else it will take hours. Don't worry about what the hyperparameters mean for now (see the SVM notebook if you're interested). How does the best `SVR` predictor perform?_" "Exercise: _Try a Support Vector Machine regressor (`sklearn.svm.SVR`) with various hyperparameters, such as `kernel=\"linear\"` (with various values for the `C` hyperparameter) or `kernel=\"rbf\"` (with various values for the `C` and `gamma` hyperparameters). Note that SVMs don't scale well to large datasets, so you should probably train your model on just the first 5,000 instances of the training set and use only 3-fold cross-validation, or else it will take hours. Don't worry about what the hyperparameters mean for now (see the SVM notebook if you're interested). How does the best `SVR` predictor perform?_"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 155, "execution_count": 154,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2636,7 +2592,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 156, "execution_count": 155,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2653,7 +2609,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 157, "execution_count": 156,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2678,7 +2634,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"_Try replacing the `GridSearchCV` with a `RandomizedSearchCV`._" "Exercise: _Try replacing the `GridSearchCV` with a `RandomizedSearchCV`._"
] ]
}, },
{ {
@ -2690,7 +2646,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 158, "execution_count": 157,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2724,7 +2680,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 159, "execution_count": 158,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2741,7 +2697,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 160, "execution_count": 159,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2764,7 +2720,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 161, "execution_count": 160,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2792,7 +2748,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"_Try adding a `SelectFromModel` transformer in the preparation pipeline to select only the most important attributes._" "Exercise: _Try adding a `SelectFromModel` transformer in the preparation pipeline to select only the most important attributes._"
] ]
}, },
{ {
@ -2804,7 +2760,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 162, "execution_count": 161,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2822,7 +2778,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 163, "execution_count": 162,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2852,7 +2808,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"_Try creating a custom transformer that trains a k-Nearest Neighbors regressor (`sklearn.neighbors.KNeighborsRegressor`) in its `fit()` method, and outputs the model's predictions in its `transform()` method. Then add this feature to the preprocessing pipeline, using latitude and longitude as the inputs to this transformer. This will add a feature in the model that corresponds to the housing median price of the nearest districts._" "Exercise: _Try creating a custom transformer that trains a k-Nearest Neighbors regressor (`sklearn.neighbors.KNeighborsRegressor`) in its `fit()` method, and outputs the model's predictions in its `transform()` method. Then add this feature to the preprocessing pipeline, using latitude and longitude as the inputs to this transformer. This will add a feature in the model that corresponds to the housing median price of the nearest districts._"
] ]
}, },
{ {
@ -2864,7 +2820,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 164, "execution_count": 163,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2909,7 +2865,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 165, "execution_count": 164,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2925,7 +2881,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 166, "execution_count": 165,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2944,7 +2900,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 167, "execution_count": 166,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2960,7 +2916,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 168, "execution_count": 167,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2976,7 +2932,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 169, "execution_count": 168,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2990,7 +2946,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 170, "execution_count": 169,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -3020,12 +2976,12 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"Question: Automatically explore some preparation options using `RandomSearchCV`." "Exercise: _Automatically explore some preparation options using `RandomSearchCV`._"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 171, "execution_count": 170,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -3047,7 +3003,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 172, "execution_count": 171,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -3066,7 +3022,196 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"That's all for today! 😀" "## 6."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Exercise: _Try to implement the `StandardScalerClone` class again from scratch, then add support for the `inverse_transform()` method: executing `scaler.inverse_transform(scaler.fit_transform(X))` should return an array very close to `X`. Then add support for feature names: set `feature_names_in_` in the `fit()` method if the input is a DataFrame. This attribute should be a NumPy array of column names. Lastly, implement the `get_feature_names_out()` method: it should have one optional `input_features=None` argument. If passed, the method should check that its length matches `n_features_in_`, and it should match `feature_names_in_` if it is defined, then `input_features` should be returned. If `input_features` is `None`, then the method should return `feature_names_in_` if it is defined or `np.array([\"x0\", \"x1\", ...])` with length `n_features_in_` otherwise._"
]
},
{
"cell_type": "code",
"execution_count": 172,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.base import BaseEstimator, TransformerMixin\n",
"from sklearn.utils.validation import check_array, check_is_fitted\n",
"\n",
"class StandardScalerClone(BaseEstimator, TransformerMixin):\n",
" def __init__(self, with_mean=True): # no *args or **kwargs!\n",
" self.with_mean = with_mean\n",
"\n",
" def fit(self, X, y=None): # y is required even though we don't use it\n",
" X = check_array(X) # checks that X is an array with finite float values\n",
" self.mean_ = X.mean(axis=0)\n",
" self.scale_ = X.std(axis=0)\n",
" self.n_features_in_ = X.shape[1] # every estimator stores this in fit()\n",
" if hasattr(X, \"columns\"):\n",
" self.feature_names_in_ = np.array(X.columns, np.object)\n",
" return self # always return self!\n",
"\n",
" def transform(self, X):\n",
" check_is_fitted(self) # looks for learned attributes (with trailing _)\n",
" X = check_array(X)\n",
" if self.n_features_in_ != X.shape[1]:\n",
" raise ValueError(\"Unexpected number of features\")\n",
" if self.with_mean:\n",
" X = X - self.mean_\n",
" return X / self.scale_\n",
" \n",
" def inverse_transform(self, X):\n",
" check_is_fitted(self)\n",
" X = check_array(X)\n",
" if self.n_features_in_ != X.shape[1]:\n",
" raise ValueError(\"Unexpected number of features\")\n",
" X = X * self.scale_\n",
" return X + self.mean_ if self.with_mean else X\n",
" \n",
" def get_feature_names_out(self, input_features=None):\n",
" if input_features is None:\n",
" return getattr(self, \"feature_names_in_\",\n",
" [f\"x{i}\" for i in range(self.n_features_in_)])\n",
" else:\n",
" if len(input_features) != self.n_features_in_:\n",
" raise ValueError(\"Invalid number of features\")\n",
" if hasattr(self, \"feature_names_in_\") and not np.all(\n",
" self.feature_names_in_ == input_features\n",
" ):\n",
" raise ValueError(\"input_features ≠ feature_names_in_\")\n",
" return input_features"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's test our custom transformer:"
]
},
{
"cell_type": "code",
"execution_count": 173,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.utils.estimator_checks import check_estimator\n",
" \n",
"check_estimator(StandardScalerClone())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"No errors, that's a great start, we respect the Scikit-Learn API."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's ensure we the transformation works as expected:"
]
},
{
"cell_type": "code",
"execution_count": 174,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(42)\n",
"X = np.random.rand(1000, 3)\n",
"\n",
"scaler = StandardScalerClone()\n",
"X_scaled = scaler.fit_transform(X)\n",
"\n",
"assert np.allclose(X_scaled, (X - X.mean(axis=0)) / X.std(axis=0))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"How about setting `with_mean=False`?"
]
},
{
"cell_type": "code",
"execution_count": 175,
"metadata": {},
"outputs": [],
"source": [
"scaler = StandardScalerClone(with_mean=False)\n",
"X_scaled_uncentered = scaler.fit_transform(X)\n",
"\n",
"assert np.allclose(X_scaled_uncentered, X / X.std(axis=0))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And does the inverse work?"
]
},
{
"cell_type": "code",
"execution_count": 176,
"metadata": {},
"outputs": [],
"source": [
"scaler = StandardScalerClone()\n",
"X_back = scaler.inverse_transform(scaler.fit_transform(X))\n",
"assert np.allclose(X, X_back)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"How about the feature names out?"
]
},
{
"cell_type": "code",
"execution_count": 177,
"metadata": {},
"outputs": [],
"source": [
"assert np.all(scaler.get_feature_names_out() == [\"x0\", \"x1\", \"x2\"])\n",
"assert np.all(scaler.get_feature_names_out([\"a\", \"b\", \"c\"]) == [\"a\", \"b\", \"c\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And if we fit a DataFrame, are the feature in and out ok?"
]
},
{
"cell_type": "code",
"execution_count": 178,
"metadata": {},
"outputs": [],
"source": [
"df = pd.DataFrame({\"a\": np.random.rand(100), \"b\": np.random.rand(100)})\n",
"scaler = StandardScalerClone()\n",
"X_scaled = scaler.fit_transform(df)\n",
"\n",
"assert np.all(ss.feature_names_in_ == [\"a\", \"b\"])\n",
"assert np.all(ss.get_feature_names_out() == [\"a\", \"b\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"All good! That's all for today! 😀"
] ]
}, },
{ {