What is the difference between a transformer and an estimator in sklearn?

A transformer is a type of estimator that implements the transform method.

Let me support that statement with examples I have come across in the sklearn source code.

  1. Class sklearn.preprocessing.FunctionTransformer:

This inherits from two other classes: TransformerMixin and BaseEstimator.

  2. Class sklearn.preprocessing.PowerTransformer:

This also inherits from TransformerMixin and BaseEstimator.
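You can check those inheritance relationships yourself with issubclass. This is a small sketch assuming a recent scikit-learn release; newer versions may add further base classes, so it only checks the two named above.

```python
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import FunctionTransformer, PowerTransformer

# For each class: (is it a TransformerMixin?, is it a BaseEstimator?)
results = {
    cls.__name__: (issubclass(cls, TransformerMixin), issubclass(cls, BaseEstimator))
    for cls in (FunctionTransformer, PowerTransformer)
}
print(results)
```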

From what I understand, estimators just take data, do some processing, and store data based on the logic implemented in their fit method.

Note: estimators aren't used to predict values directly. The BaseEstimator class doesn't even have a predict method.
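As a rough illustration of that idea, here is a sketch of an estimator whose fit only learns and stores something. MeanRecorder is a hypothetical class made up for this example, not part of sklearn.

```python
import numpy as np
from sklearn.base import BaseEstimator


class MeanRecorder(BaseEstimator):
    """Hypothetical estimator: fit() only learns and stores the column means."""

    def fit(self, X, y=None):
        # Learnt state goes into an attribute with a trailing underscore
        self.means_ = np.asarray(X, dtype=float).mean(axis=0)
        return self


est = MeanRecorder().fit([[1.0, 2.0], [3.0, 6.0]])
print(est.means_)                     # [2. 4.]
print(hasattr(BaseEstimator, "predict"))  # False: the base class has no predict
```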

Before I explain the above statement further, let me tell you about mixin classes.

Mixin class: a class that implements the mixin design pattern (Wikipedia has a very good explanation of it). To summarise, a mixin is a class whose methods can be reused by many different classes: you write those methods once, in one class, and then inherit from it in many different classes (a form of composition).
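A plain-Python sketch of the pattern, independent of sklearn. JsonMixin, Point and User are hypothetical names chosen for illustration: the serialisation method is written once in the mixin and reused by two unrelated classes.

```python
import json


class JsonMixin:
    """Reusable behaviour: serialise an object's attributes to JSON."""

    def to_json(self):
        return json.dumps(self.__dict__, sort_keys=True)


class Point(JsonMixin):  # gets to_json() without defining it
    def __init__(self, x, y):
        self.x, self.y = x, y


class User(JsonMixin):  # same method reused by an unrelated class
    def __init__(self, name):
        self.name = name


print(Point(1, 2).to_json())  # {"x": 1, "y": 2}
print(User("ana").to_json())  # {"name": "ana"}
```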

sklearn has many mixin classes, for example ClassifierMixin, RegressorMixin and TransformerMixin.

Here, TransformerMixin is the class inherited by every transformer used in sklearn. Its key reusable method is fit_transform: the default version works in every transformer because it simply calls that transformer's own fit and then its transform.
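A minimal sketch, assuming scikit-learn and NumPy are installed. CenterTransformer is a hypothetical class: it defines only fit and transform, and the fit_transform call works because the method is inherited from TransformerMixin.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin


class CenterTransformer(TransformerMixin, BaseEstimator):
    """Hypothetical transformer that subtracts the per-column mean."""

    def fit(self, X, y=None):
        # Learn the column means and store them on the object
        self.mean_ = np.asarray(X, dtype=float).mean(axis=0)
        return self

    def transform(self, X):
        # Apply what was learnt in fit()
        return np.asarray(X, dtype=float) - self.mean_


X = np.array([[1.0, 2.0], [3.0, 4.0]])
# fit_transform is not defined above; it comes from TransformerMixin
out = CenterTransformer().fit_transform(X)
print(out)  # [[-1. -1.]
            #  [ 1.  1.]]
```

By sklearn convention the mixin is listed before BaseEstimator, so the mixin comes first in the method resolution order.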

All transformers inherit from two classes: BaseEstimator (which provides get_params and set_params) and TransformerMixin (which provides fit_transform). Each transformer then implements its own fit and transform methods according to its functionality.
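To see which part comes from where, here is a sketch with a hypothetical ScaleBy transformer. get_params and set_params are inherited from BaseEstimator (it reads the constructor arguments), while fit and transform are written by the class itself.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin


class ScaleBy(TransformerMixin, BaseEstimator):
    """Hypothetical transformer that multiplies every value by `factor`."""

    def __init__(self, factor=2.0):
        self.factor = factor  # stored unchanged, as sklearn conventions require

    def fit(self, X, y=None):
        return self  # nothing to learn

    def transform(self, X):
        return np.asarray(X, dtype=float) * self.factor


t = ScaleBy()
print(t.get_params())                 # {'factor': 2.0}  (from BaseEstimator)
t.set_params(factor=3.0)              # also from BaseEstimator
print(t.fit_transform([[1.0, 2.0]]))  # [[3. 6.]]
print(hasattr(BaseEstimator, "fit"))  # False: fit is written by each subclass
```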

I guess that answers your question. Now, let me explain the statement I made about estimators and prediction.

Every model class has its own predict method that does the prediction.

Consider LinearRegression, KNeighborsClassifier, or any other model class. They all have a predict method, and it is that method, not the estimator base class, that is used for prediction.
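A short sketch of the two model classes mentioned, assuming scikit-learn and NumPy are installed. The toy data is made up for illustration.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.neighbors import KNeighborsClassifier

X = np.array([[0.0], [1.0], [10.0], [11.0]])

# Regressor: learns y = 2x from the training data, then predicts
reg = LinearRegression().fit(X, 2 * X.ravel())
print(reg.predict([[3.0]]))  # approximately [6.]

# Classifier: labels each query point like its nearest training point
clf = KNeighborsClassifier(n_neighbors=1).fit(X, [0, 0, 1, 1])
print(clf.predict([[0.4], [10.6]]))  # [0 1]
```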


The sklearn usage is perhaps a little unintuitive, but "estimator" doesn't mean anything very specific: basically everything is an estimator.

From the sklearn glossary:

estimator:

An object which manages the estimation and decoding of a model...

Estimators must provide a fit method, and should provide set_params and get_params, although these are usually provided by inheritance from base.BaseEstimator.

transformer:

An estimator supporting transform and/or fit_transform...

As in @VivekKumar's answer, I think there's a tendency to use the word estimator for what sklearn instead calls a "predictor":

An estimator supporting predict and/or fit_predict. This encompasses classifier, regressor, outlier detector and clusterer...
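Because these glossary roles are defined by which methods an object supports, you can inspect them by duck typing. roles is a hypothetical helper written for this sketch, not an sklearn function.

```python
from sklearn.cluster import KMeans
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler


def roles(estimator):
    """Hypothetical helper: which of the glossary's key methods an object exposes."""
    return [m for m in ("fit", "transform", "predict") if hasattr(estimator, m)]


for est in (StandardScaler(), LinearRegression(), KMeans(n_clusters=2)):
    print(type(est).__name__, roles(est))
# StandardScaler ['fit', 'transform']      -> transformer
# LinearRegression ['fit', 'predict']      -> predictor
# KMeans ['fit', 'transform', 'predict']   -> both
```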


The basic difference is:

  • A transformer transforms the input data (X) in some way.
  • An estimator predicts a new value (or values) (y) by using the input data (X).

Both the Transformer and Estimator should have a fit() method which can be used to train them (they learn some characteristics of the data). The signature is:

fit(X, y)

fit() stores the learnt data inside the object and returns the object itself (self), which is what allows chaining such as est.fit(X, y).predict(X).
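A quick check of that return value, sketched with StandardScaler and made-up data:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

X = np.array([[1.0], [2.0], [3.0]])
scaler = StandardScaler()

# fit() returns the fitted estimator itself...
returned = scaler.fit(X)
print(returned is scaler)  # True

# ...which is what makes one-line chaining possible
Z = StandardScaler().fit(X).transform(X)
print(Z.shape)  # (3, 1)
```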

Here X represents the samples (feature vectors) and y is the target vector (which may have a single value or multiple values per sample in X). Note that y can be optional in transformers where it's not needed, but it's mandatory for most estimators (supervised estimators). Look at StandardScaler, for example: it needs only the data X to find the mean and standard deviation of each feature (it learns the characteristics of X; y is not needed).
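A sketch of StandardScaler learning from X alone. The data is made up for illustration. mean_ and scale_ are the attributes where the fitted scaler stores what it learnt.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

X = np.array([[1.0, 10.0], [3.0, 30.0]])

# No y is passed: the scaler learns only from X
scaler = StandardScaler().fit(X)
print(scaler.mean_)   # per-feature mean: 2 and 20
print(scaler.scale_)  # per-feature (population) std: 1 and 10
```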

Each transformer should have a transform(X) method which takes the input X and returns a new, transformed version of X (which generally has the same number of samples but may or may not have the same number of features).
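For example, PolynomialFeatures keeps the number of samples but changes the number of features. This is a sketch with made-up data, using the default include_bias=True.

```python
import numpy as np
from sklearn.preprocessing import PolynomialFeatures

X = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# degree-2 expansion of (a, b): 1, a, b, a^2, a*b, b^2
X_poly = PolynomialFeatures(degree=2).fit_transform(X)
print(X.shape, "->", X_poly.shape)  # (3, 2) -> (3, 6)
print(X_poly[0])                    # [1. 1. 2. 1. 2. 4.]
```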

On the other hand, an estimator should have a predict(X) method which outputs the predicted value of y for the given X.

Some classes in scikit-learn implement both transform() and predict(), such as KMeans; in those cases, read the documentation carefully to see what each method returns.
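A sketch of KMeans playing both roles, on made-up data. n_init is passed explicitly because its default has changed between scikit-learn versions. Cluster numbering is arbitrary, so only the grouping matters.

```python
import numpy as np
from sklearn.cluster import KMeans

X = np.array([[0.0], [1.0], [10.0], [11.0]])
km = KMeans(n_clusters=2, n_init=10, random_state=0).fit(X)

labels = km.predict(X)       # predictor role: one cluster label per sample
distances = km.transform(X)  # transformer role: distance to each cluster centre
print(labels)
print(distances.shape)       # (4, 2)
```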
