ChiragTutlani opened this issue 5 years ago
Hi @Chirag161198 ,

Thanks for your question. Scikit-Learn's API uses duck typing: if you want to write your own custom estimators (including transformers and predictors), you only need to implement the right methods; you don't have to inherit from any particular class. For example, all estimators must implement a `fit()` method, as well as `get_params()` and `set_params()` methods. All transformers must also implement `transform()` and `fit_transform()` methods. All predictors must implement a `predict()` method. And so on.
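To make the duck-typing point concrete, here is a minimal sketch of a transformer that inherits from nothing at all (the class name `MedianImputer` and its behavior are invented for illustration, not part of Scikit-Learn):

```python
import numpy as np

# Hypothetical duck-typed transformer: no Scikit-Learn base classes,
# just the methods the API expects.
class MedianImputer:
    def __init__(self, copy=True):
        self.copy = copy  # hyperparameter, stored as-is

    def fit(self, X, y=None):
        # Learned parameter: trailing underscore by convention.
        self.medians_ = np.nanmedian(X, axis=0)
        return self

    def transform(self, X):
        X = X.copy() if self.copy else X
        for j in range(X.shape[1]):
            col = X[:, j]
            col[np.isnan(col)] = self.medians_[j]
        return X

    def fit_transform(self, X, y=None):
        return self.fit(X, y).transform(X)

    # Hand-written equivalents of what BaseEstimator would provide:
    def get_params(self, deep=True):
        return {"copy": self.copy}

    def set_params(self, **params):
        for name, value in params.items():
            setattr(self, name, value)
        return self
```

Because it has the right methods, this class already works in most places Scikit-Learn expects a transformer, without inheriting from anything.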
The most basic implementation of the `fit_transform()` method is just this:

```python
def fit_transform(self, X, y=None):
    return self.fit(X, y).transform(X)
```
You don't have to inherit from the `TransformerMixin` class, but here's what you get if you do: if you implement the `fit()` method and the `transform()` method, it gives you the `fit_transform()` method for free, just like the one above.
Similarly, the `BaseEstimator` class just gives you the `get_params()` and `set_params()` methods for free. By default, `get_params()` does some introspection to get the parameters of the constructor `__init__()`, and it assumes that the class has corresponding instance variables. For example:
```python
from sklearn.base import BaseEstimator

class MyEstimator(BaseEstimator):
    def __init__(self, a, b=2):
        self.a = a
        self.b = b
```
You get the `get_params()` and `set_params()` methods for free:
```python
>>> m = MyEstimator(1, 2)
>>> m.get_params()
{'a': 1, 'b': 2}
>>> m.set_params(a=5, b=10)
MyEstimator(a=5, b=10)
>>> m.a
5
>>> m.b
10
```
Now if you want to build a transformer that will do, say, standardization (just like the `StandardScaler`), you can implement it like this:
```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin

class MyStandardScaler(BaseEstimator, TransformerMixin):
    def fit(self, X, y=None):
        self.mean_ = X.mean(axis=0, keepdims=True)
        self.scale_ = X.std(axis=0, keepdims=True)
        return self  # <= always remember to return self

    def transform(self, X, y=None):
        return (X - self.mean_) / self.scale_
```
Note that by convention all learned parameters are stored as instance variables with a name ending with an underscore.
You can then use it like a normal `StandardScaler`:
```python
>>> scaler = MyStandardScaler()
>>> X_train = np.random.rand(1000, 10)
>>> X_train_scaled = scaler.fit_transform(X_train)
>>> np.round(X_train_scaled.mean(axis=0), 5)
array([-0., 0., -0., 0., -0., -0., 0., -0., 0., 0.])
>>> X_train_scaled.std(axis=0)
array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
```
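Since `MyStandardScaler` follows this API, it also drops straight into a `Pipeline`. Here is a quick sketch (the `LinearRegression` step is just an arbitrary final estimator chosen for illustration):

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline

class MyStandardScaler(BaseEstimator, TransformerMixin):
    def fit(self, X, y=None):
        self.mean_ = X.mean(axis=0, keepdims=True)
        self.scale_ = X.std(axis=0, keepdims=True)
        return self

    def transform(self, X):
        return (X - self.mean_) / self.scale_

pipe = Pipeline([("scaler", MyStandardScaler()), ("reg", LinearRegression())])
X, y = np.random.rand(100, 3), np.random.rand(100)
pipe.fit(X, y)           # calls scaler.fit_transform(X, y), then reg.fit(...)
preds = pipe.predict(X)  # calls scaler.transform(X), then reg.predict(...)
```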
Hope this helps! Aurélien
Hi @ageron, thanks for sharing that `get_params()` looks at `__init__()`. I have a few questions that I found hard to find answers to, and I hope you can help with them.
You had a comment above, "always remember to return self", under `fit`. I have seen some articles not returning anything in `fit` (https://datamadness.github.io/Custom-Transformer). Is the purpose of `return self` to make an sklearn `Pipeline` object which contains some downstream transformers work? (I'm not sure how pipeline transformers pass fit/transform information down when doing `pipe.fit`/`pipe.transform`.)
Another issue bugging me is the function signature/input parameters of `fit` and `transform`; there are so many variations. Most code I see does `def fit(self, X, y=None)` and `def transform(self, X)`, and it's the first time I have seen your signature for transform, `transform(self, X, y=None)`. I also see `*args` or `**kwargs` in the mix too, e.g. `def fit(self, X, y=None, **fit_params)`. When would I include the `y` input parameter in `fit` or `transform`? I can't find any examples working on `y`.
Is there any flexibility in, or a logical way to, split tasks across `__init__`, `fit`, and `transform` when defining a custom transformer?

Finally, is there any documentation answering the questions above, or do I have to read the source to understand them?
Thanks a lot! Han
Hi @gitgithan ,
Great questions, and thanks for your kind words! 👍
The Scikit-Learn API specifies that the `fit()` method should return `self`, so it's good practice to do so. One advantage is that it allows you to chain methods like this: `model.fit(X).transform(X)`. If you don't return `self`, you won't necessarily run into issues; it depends on how the class is used. But you're asking for trouble, because one day you or some users of your class will expect the `fit()` method to return `self`, and they might write things like `model.fit(X).transform(X)`. Moreover, some components of Scikit-Learn, such as pipelines, may expect the `fit()` method to return `self` (you can give it a try, I'm not sure whether it'll break or not, but the fact that it could should be reason enough to always return `self`).
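To illustrate what can go wrong, here is a sketch with two hypothetical transformers (the class names are mine, not Scikit-Learn's): one returns `self` from `fit()`, the other forgets to, and `TransformerMixin`'s free `fit_transform()` breaks on the second one:

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin

class GoodCenterer(BaseEstimator, TransformerMixin):
    def fit(self, X, y=None):
        self.mean_ = X.mean(axis=0)
        return self                   # enables chaining

    def transform(self, X):
        return X - self.mean_

class BadCenterer(GoodCenterer):
    def fit(self, X, y=None):
        self.mean_ = X.mean(axis=0)   # oops, forgot to return self

X = np.random.rand(10, 2)
centered = GoodCenterer().fit(X).transform(X)  # chaining works

try:
    BadCenterer().fit_transform(X)  # the mixin does self.fit(X).transform(X)
    broke = False
except AttributeError:              # fit() returned None, so None.transform(X)
    broke = True
```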
Regarding the variations of signatures, I believe the `transform(X, y=None)` signature I used was necessary a long time ago to include the transformer in a pipeline, but it might just be an error. In any case, it's not needed anymore, so you can safely remove the `y=None` from the signature of the `transform()` method. However, the `y` argument is still needed in the `fit()` method, whether you actually need `y` or not, at least if you plan on using the estimator in a pipeline, as their implementation expects to be able to pass the labels to all estimators in the pipeline (this seems a bit unfortunate; perhaps this can be fixed one day).
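A sketch of that pipeline constraint (the class names are invented for illustration): a transformer whose `fit()` lacks the `y` argument fails inside a `Pipeline`, because the pipeline forwards the labels to every step's `fit()`:

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

class NoYScaler(BaseEstimator, TransformerMixin):
    def fit(self, X):                 # missing y=None
        self.mean_ = X.mean(axis=0)
        return self

    def transform(self, X):
        return X - self.mean_

class WithYScaler(NoYScaler):
    def fit(self, X, y=None):         # accepts (and ignores) y
        self.mean_ = X.mean(axis=0)
        return self

X = np.r_[np.zeros((10, 2)), np.ones((10, 2))]
y = np.r_[np.zeros(10), np.ones(10)]

try:
    Pipeline([("scale", NoYScaler()), ("clf", LogisticRegression())]).fit(X, y)
    fit_ok = True
except TypeError:                     # fit() got an unexpected y argument
    fit_ok = False

pipe = Pipeline([("scale", WithYScaler()), ("clf", LogisticRegression())])
pipe.fit(X, y)                        # works once fit() accepts y
```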
In some cases, you may want to pass extra arguments to the `fit()` method. That's typically for things like the number of epochs to train, the batch size, etc. Very often these are treated as hyperparameters (i.e., passed to the constructor and set as instance variables, not passed as arguments to the `fit()` method), and you can definitely handle them like that, but in many cases it's more convenient to be able to just pass some parameters directly to the `fit()` method. Of course, this means that if some software component (such as `GridSearchCV`) uses the estimator and calls its `fit()` method, it needs to provide a way for its user to set those extra arguments to the `fit()` method. The `GridSearchCV` class does that by having a `**fit_params` argument in its own `fit()` method. It just forwards these arguments through to the underlying estimator's `fit()` method. The use of keyword arguments is necessary in this case to allow `GridSearchCV` to forward any arguments the estimator may need.
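Here is a sketch of that forwarding, with a toy predictor (the name `ConstantishRegressor` is invented) whose `fit()` takes an extra `sample_weight` argument; `GridSearchCV` passes it through to each fold's `fit()` call:

```python
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.model_selection import GridSearchCV

class ConstantishRegressor(BaseEstimator):
    def __init__(self, shift=0.0):
        self.shift = shift                      # hyperparameter

    def fit(self, X, y, sample_weight=None):    # extra fit parameter
        w = np.ones(len(y)) if sample_weight is None else sample_weight
        self.mean_ = np.average(y, weights=w)
        return self

    def predict(self, X):
        return np.full(len(X), self.mean_ + self.shift)

    def score(self, X, y):
        return -np.mean((self.predict(X) - y) ** 2)  # negative MSE

X, y = np.random.rand(30, 2), np.random.rand(30)
search = GridSearchCV(ConstantishRegressor(), {"shift": [0.0, 0.1]}, cv=3)
# sample_weight is forwarded to ConstantishRegressor.fit() on each fold:
search.fit(X, y, sample_weight=np.ones(30))
```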
Regarding your question of when to include `y` in the `fit()` or `transform()` methods:

* The `fit()` method should always have `X` and `y`. If you don't really need `y`, just give it a default value of `None`.
* For the `transform()` method, I've never run into a case where `y` would be needed, but it's not impossible. For example, suppose you want to do image augmentation of digits, with slight rotations + horizontal flip. However, some digits can be flipped horizontally (like 0 or 8) while others cannot (like 3 or 5). The image augmentation transformer would need its `transform()` method to be passed both the images `X` and the labels `y` in order to know whether it can or cannot perform the horizontal flip.
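That augmentation idea can be sketched like this (the `FlipAugmenter` class is invented for illustration; note that because its `transform()` needs `y`, it would not fit the standard `Pipeline` contract, which only passes `X` to `transform()`):

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin

class FlipAugmenter(BaseEstimator, TransformerMixin):
    def __init__(self, flippable=(0, 8)):
        self.flippable = flippable   # digits symmetric under a horizontal flip

    def fit(self, X, y=None):
        return self                  # nothing to learn

    def transform(self, X, y):       # y is genuinely needed here
        X = X.copy()
        mask = np.isin(y, self.flippable)
        X[mask] = X[mask, :, ::-1]   # flip the width axis of (n, h, w) images
        return X

X = np.arange(2 * 2 * 3).reshape(2, 2, 3).astype(float)  # two tiny "images"
y = np.array([8, 3])
Xt = FlipAugmenter().fit(X, y).transform(X, y)           # only the 8 is flipped
```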
In general, you want the constructor to be passed all the hyperparameters and just set them as instance variables. You want the `fit()` method to be passed `X` and `y`, and optionally some extra arguments that are more convenient to pass directly rather than as hyperparameters. The `fit()` method should learn whatever it needs to learn from the data (e.g., a `StandardScaler` learns the mean and standard deviation of each input feature) and set them as instance variables whose names end with an underscore (e.g., `mean_`). Lastly, the `transform()` method should usually be passed only the input features `X`, but it may occasionally need more, such as the labels `y` (but very rarely).

Regarding documentation, I have used three sources:
I hope this helps.
@ageron Thanks for your detailed explanation. I have a much better understanding now of the flexibility and difficulty of designing an API.
I can't understand how we built the custom transformer class, or what the exact use of `BaseEstimator` and `TransformerMixin` is. Can anyone provide a detailed explanation of that, or any link for understanding it better? It would be a great help.