How can I call scikit-learn classifiers from Java?

You can use a porter. I have tested sklearn-porter (https://github.com/nok/sklearn-porter), and it works well for Java.

My code is the following:

import pandas as pd
from sklearn import tree
from sklearn_porter import Porter

# Load the CSV as a NumPy array (features in the first 8 columns, labels after them)
train_dataset = pd.read_csv('./result2.csv').values

# First 90 rows for training, the remaining rows for testing
X_train = train_dataset[:90, :8]
Y_train = train_dataset[:90, 8:]

X_test = train_dataset[90:, :8]
Y_test = train_dataset[90:, 8:]

print(X_train.shape)
print(Y_train.shape)

# Train the classifier
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X_train, Y_train)

# Transpile the fitted model to Java source code
porter = Porter(clf, language='java')
output = porter.export(embed_data=True)
print(output)

In my case, I'm using a DecisionTreeClassifier, and the output of

print(output)

is the following code as text in the console:

class DecisionTreeClassifier {

  private static int findMax(int[] nums) {
    int index = 0;
    for (int i = 0; i < nums.length; i++) {
        index = nums[i] > nums[index] ? i : index;
    }
    return index;
  }


  public static int predict(double[] features) {
    int[] classes = new int[2];

    if (features[5] <= 51.5) {
        if (features[6] <= 21.0) {

            // HUGE amount of ifs..........

        }
    }

    return findMax(classes);
  }

  public static void main(String[] args) {
    if (args.length == 8) {

        // Features:
        double[] features = new double[args.length];
        for (int i = 0, l = args.length; i < l; i++) {
            features[i] = Double.parseDouble(args[i]);
        }

        // Prediction:
        int prediction = DecisionTreeClassifier.predict(features);
        System.out.println(prediction);

    }
  }
}
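
To call the generated classifier from Java, the printed source has to end up in a .java file that can be compiled. Here is a minimal sketch of that step, assuming output holds the string returned by porter.export in the script above. The helper name save_java_source and the target directory are my own illustrative choices, not part of sklearn-porter:

```python
import os


def save_java_source(source, directory, class_name="DecisionTreeClassifier"):
    """Write transpiled Java source to <directory>/<class_name>.java.

    Hypothetical helper: `source` is expected to be the text produced by
    porter.export(...). The default class name matches the class declared
    in that text, so the file name and the class name line up for javac.
    """
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, class_name + ".java")
    with open(path, "w") as handle:
        handle.write(source)
    return path


# Example call (assumes `output` from the script above):
# save_java_source(output, "./java")
```

After that, running javac on DecisionTreeClassifier.java compiles the class, and running java DecisionTreeClassifier followed by the eight feature values prints the prediction, since the generated main method parses its arguments as features.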

There is the JPMML project for this purpose.

First, you can serialize the scikit-learn model to PMML (which is XML internally) using the sklearn2pmml library directly from Python, or dump it in Python first and convert it using jpmml-sklearn, either in Java or from the command line provided by that library. Next, you can load the PMML file, deserialize it, and execute the loaded model using jpmml-evaluator in your Java code.

This approach does not work with every scikit-learn model, but it works with many of them.


You cannot use Jython, because scikit-learn relies heavily on NumPy and SciPy, which have many compiled C and Fortran extensions and therefore cannot work in Jython.

The easiest ways to use scikit-learn in a Java environment would be to:

  • expose the classifier as an HTTP/JSON service, for instance using a microframework such as Flask, Bottle or Cornice, and call it from Java using an HTTP client library

  • write a command-line wrapper application in Python that reads data on stdin and outputs predictions on stdout using some format such as CSV or JSON (or some lower-level binary representation), and call the Python program from Java, for instance using Apache Commons Exec

  • make the Python program output the raw numerical parameters learnt at fit time (typically as an array of floating point values) and reimplement the predict function in Java (this is typically easy for predictive linear models, where the prediction is often just a thresholded dot product)

The last approach will be a lot more work if you need to re-implement feature extraction in Java as well.
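
The command-line wrapper from the second option can be sketched roughly as follows. This is only an illustration: predict_one is a stand-in with made-up weights (a thresholded dot product, as mentioned in the third option), and in a real wrapper it would call the fitted scikit-learn classifier instead. The function names run and predict_one and the one-CSV-row-per-sample format are assumptions of this sketch:

```python
import csv
import sys

# Hypothetical parameters; in practice these would come from the fitted model.
WEIGHTS = [0.5, -1.0]
BIAS = 0.25


def predict_one(features):
    """Stand-in model: thresholded dot product of features and WEIGHTS."""
    score = sum(w * x for w, x in zip(WEIGHTS, features)) + BIAS
    return 1 if score > 0 else 0


def run(in_stream, out_stream):
    """Read one CSV row of features per line, write one prediction per line."""
    for row in csv.reader(in_stream):
        if not row:  # skip blank lines
            continue
        features = [float(value) for value in row]
        out_stream.write(f"{predict_one(features)}\n")


if __name__ == "__main__":
    run(sys.stdin, sys.stdout)
```

From Java, the script can then be started with ProcessBuilder or Apache Commons Exec, writing one CSV row per line to its stdin and reading one prediction per line from its stdout.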

Finally, you can use a Java library such as Weka or Mahout that implements the algorithms you need instead of trying to use scikit-learn from Java.