Python scikit-learn: exporting trained classifier

Pickling/unpickling has the disadvantage that it only works reliably with matching Python versions (major and possibly also minor) and matching scikit-learn and joblib library versions.

There are alternative, descriptive output formats for machine learning models, such as those developed by the Data Mining Group (DMG): the Predictive Model Markup Language (PMML) and the Portable Format for Analytics (PFA). Of the two, PMML is much better supported.

So you have the option of saving a model from scikit-learn as PMML (for example using sklearn2pmml), and then deploying and running it in Java, Spark, or Hive using JPMML (of course, there are other choices as well).


First, install joblib (recent scikit-learn versions already pull it in as a dependency).

You can use:

>>> import joblib
>>> joblib.dump(clf, 'my_model.pkl', compress=9)

And then later, on the prediction server:

>>> import joblib
>>> model_clone = joblib.load('my_model.pkl')

This is basically a Python pickle with optimized handling for large NumPy arrays. It has the same limitations as the regular pickle with respect to code changes: if the class structure of the pickled object changes, you might no longer be able to unpickle the object with newer versions of nolearn or scikit-learn.
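To catch the version problem early, one option is to save the library version next to the model and compare it when loading. The following is a minimal sketch, assuming joblib and scikit-learn are installed; the bundle layout (a dict with "model" and "sklearn_version" keys) and the helper load_checked are hypothetical conventions for illustration, not part of joblib's API. Recent scikit-learn releases may also warn on their own when unpickling an estimator saved with a different version.

```python
import warnings

import joblib
import sklearn
from sklearn import datasets, svm

# Train a small example model (iris + SVC, as elsewhere on this page).
X, y = datasets.load_iris(return_X_y=True)
clf = svm.SVC().fit(X, y)

# Hypothetical convention: store the scikit-learn version next to the model
# so a mismatch can be detected at load time.
joblib.dump({"model": clf, "sklearn_version": sklearn.__version__},
            "my_model.pkl", compress=9)


def load_checked(path):
    """Load a bundle written above; warn if the scikit-learn version differs."""
    bundle = joblib.load(path)
    saved = bundle["sklearn_version"]
    if saved != sklearn.__version__:
        warnings.warn(f"model saved with scikit-learn {saved}, "
                      f"running {sklearn.__version__}")
    return bundle["model"]


# On the prediction server:
model_clone = load_checked("my_model.pkl")
print(model_clone.predict(X[:1]))  # expected: [0]
```

Note that the version check only detects a mismatch; it does not make an incompatible pickle loadable.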

If you want a robust, long-term way of storing your model parameters, you might need to write your own IO layer (e.g. using binary serialization tools such as Protocol Buffers or Avro, or an inefficient yet portable text/JSON/XML representation such as PMML).
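As an illustration of such a hand-written IO layer, the sketch below stores only the fitted parameters of a multiclass LogisticRegression as JSON and reproduces its predictions with NumPy, without unpickling anything. It assumes a linear model exposing coef_, intercept_ and classes_ with three or more classes (binary models store a single row of coefficients and would need a threshold instead of argmax); the file name and JSON keys are hypothetical, and other estimator types would need their own export logic.

```python
import json

import numpy as np
from sklearn import datasets
from sklearn.linear_model import LogisticRegression

X, y = datasets.load_iris(return_X_y=True)
clf = LogisticRegression(max_iter=1000).fit(X, y)

# Export only the fitted parameters as plain JSON (no pickle involved).
# The key names are a hypothetical schema chosen for this sketch.
params = {
    "classes": clf.classes_.tolist(),
    "coef": clf.coef_.tolist(),
    "intercept": clf.intercept_.tolist(),
}
with open("my_model.json", "w") as f:
    json.dump(params, f)


def predict_from_json(path, X):
    """Re-implement multiclass linear prediction from the stored parameters."""
    with open(path) as f:
        p = json.load(f)
    # Decision scores: X @ coef.T + intercept, one column per class.
    scores = np.asarray(X) @ np.asarray(p["coef"]).T + np.asarray(p["intercept"])
    # Pick the class with the highest score for each row.
    return np.asarray(p["classes"])[scores.argmax(axis=1)]


print((predict_from_json("my_model.json", X) == clf.predict(X)).all())  # expected: True
```

The trade-off is that the prediction step has to be re-implemented for every model type you export, but the JSON file itself stays readable regardless of the Python or scikit-learn version.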


Section 3.4, "Model persistence", of the scikit-learn documentation covers pretty much everything.

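In addition to the joblib approach ogrisel pointed to (via sklearn.externals.joblib, which has since been removed from scikit-learn in favor of importing joblib directly), it shows how to use the regular pickle package: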

>>> from sklearn import svm
>>> from sklearn import datasets
>>> clf = svm.SVC()
>>> iris = datasets.load_iris()
>>> X, y = iris.data, iris.target
>>> clf.fit(X, y)
SVC()

>>> import pickle
>>> s = pickle.dumps(clf)
>>> clf2 = pickle.loads(s)
>>> clf2.predict(X[0:1])  # predict expects a 2D array of samples
array([0])
>>> y[0]
0

The section also gives a few warnings, for example that models saved with one version of scikit-learn might not load in another version.
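The documentation example pickles to an in-memory bytes object; to actually export the model to disk, pickle.dump and pickle.load work with files opened in binary mode. A minimal sketch (the file name is arbitrary), subject to the same version caveats:

```python
import pickle

from sklearn import datasets, svm

X, y = datasets.load_iris(return_X_y=True)
clf = svm.SVC().fit(X, y)

# Files must be opened in binary mode for pickle.
with open("my_model.pickle", "wb") as f:
    pickle.dump(clf, f, protocol=pickle.HIGHEST_PROTOCOL)

# Only unpickle files from trusted sources: loading a pickle can execute arbitrary code.
with open("my_model.pickle", "rb") as f:
    clf2 = pickle.load(f)

print(clf2.predict(X[0:1]))  # expected: [0]
```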