Weighted Linear Regression in Java

I was also searching for this but couldn't find anything. The reason might be that the problem reduces to standard (unweighted) least-squares regression, as follows:

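A weighted linear regression without intercept can be written as the ordinary regression problem diag(sqrt(w)) y = diag(sqrt(w)) X b, where w is the vector of weights and diag(sqrt(w)) T means multiplying row i of the matrix T by sqrt(w_i). Ordinary least squares on this rescaled system minimizes sum_i (sqrt(w_i) y_i - sqrt(w_i) x_i b)^2 = sum_i w_i (y_i - x_i b)^2, which is exactly the weighted objective. Therefore, the translation between weighted and unweighted regressions without intercept is trivial.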
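To check the identity numerically, here is a minimal pure-Java sketch (class and method names are made up for illustration) that applies diag(sqrt(w)) to X and y, then compares the weighted squared error with the ordinary squared error of the rescaled data for an arbitrary coefficient vector b:

```java
public class WeightedToOrdinary {
    // Returns diag(sqrt(w)) * m, i.e. row i of m multiplied by sqrt(w[i]).
    static double[][] scaleRows(double[][] m, double[] w) {
        double[][] out = new double[m.length][];
        for (int i = 0; i < m.length; i++) {
            double s = Math.sqrt(w[i]);
            out[i] = new double[m[i].length];
            for (int j = 0; j < m[i].length; j++) {
                out[i][j] = s * m[i][j];
            }
        }
        return out;
    }

    // Returns diag(sqrt(w)) * v for a vector v.
    static double[] scaleVector(double[] v, double[] w) {
        double[] out = new double[v.length];
        for (int i = 0; i < v.length; i++) {
            out[i] = Math.sqrt(w[i]) * v[i];
        }
        return out;
    }

    // Weighted sum of squared residuals: sum_i w_i * (y_i - x_i . b)^2.
    // Passing null for w gives the ordinary (unweighted) sum.
    static double sse(double[][] x, double[] y, double[] w, double[] b) {
        double total = 0.0;
        for (int i = 0; i < y.length; i++) {
            double fit = 0.0;
            for (int j = 0; j < b.length; j++) {
                fit += x[i][j] * b[j];
            }
            double r = y[i] - fit;
            total += (w == null ? 1.0 : w[i]) * r * r;
        }
        return total;
    }

    public static void main(String[] args) {
        double[][] x = {{1.0}, {2.0}, {3.0}};
        double[] y = {2.0, 3.0, 7.0};
        double[] w = {1.0, 4.0, 0.25};
        double[] b = {1.5}; // any candidate coefficient vector
        double weighted = sse(x, y, w, b);
        double ordinary = sse(scaleRows(x, w), scaleVector(y, w), null, b);
        System.out.println(weighted + " " + ordinary);
    }
}
```

The two sums agree for every b, so the b that minimizes one also minimizes the other.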

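To turn a regression with intercept, y = Xb + u, into a regression without intercept, y = Xb, add an extra column of ones to X; the coefficient of that column is u. After the weighting above, row i of that column becomes sqrt(w_i), which is why the code below sets x[i][1] = Math.sqrt(weights[i]).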
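A minimal sketch of that step in plain Java (hypothetical helper names, no library needed): the first method appends the column of ones, and the second builds the already-weighted two-column matrix for the single-regressor case, where the ones column turns into sqrt(w_i).

```java
import java.util.Arrays;

public class InterceptColumn {
    // Appends a column of ones: y = Xb + u becomes y = [X | 1] [b; u].
    static double[][] withOnesColumn(double[][] x) {
        double[][] out = new double[x.length][];
        for (int i = 0; i < x.length; i++) {
            out[i] = Arrays.copyOf(x[i], x[i].length + 1); // extra slot starts at 0.0
            out[i][x[i].length] = 1.0;
        }
        return out;
    }

    // Weighted design matrix for y = slope * x + intercept:
    // row i is sqrt(w_i) * [x_i, 1] = [sqrt(w_i) * x_i, sqrt(w_i)].
    static double[][] weightedDesign(double[] x, double[] w) {
        double[][] out = new double[x.length][2];
        for (int i = 0; i < x.length; i++) {
            double s = Math.sqrt(w[i]);
            out[i][0] = s * x[i];
            out[i][1] = s;
        }
        return out;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.deepToString(withOnesColumn(new double[][] {{2.0}, {5.0}})));
        System.out.println(Arrays.deepToString(weightedDesign(new double[] {2.0, 5.0}, new double[] {4.0, 1.0})));
    }
}
```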

Now that you know how to simplify the problem, you can use any library to solve the standard linear regression.

Here's an example using Apache Commons Math:

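import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;

// Weighted fit of y = slope*x + intercept via ordinary least squares on rescaled data.
void linearRegression(double[] xUnweighted, double[] yUnweighted, double[] weights) {
    double[] y = new double[yUnweighted.length];
    double[][] x = new double[xUnweighted.length][2];

    for (int i = 0; i < y.length; i++) {
        double s = Math.sqrt(weights[i]); // row i is scaled by sqrt(w_i)
        y[i] = s * yUnweighted[i];
        x[i][0] = s * xUnweighted[i];     // regressor column
        x[i][1] = s;                      // intercept column: sqrt(w_i) * 1
    }

    OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
    // The intercept is already modelled by column 1, so the library must not add its own.
    regression.setNoIntercept(true);
    regression.newSampleData(y, x);

    double[] regressionParameters = regression.estimateRegressionParameters();
    double slope = regressionParameters[0];     // coefficient of column 0
    double intercept = regressionParameters[1]; // coefficient of column 1

    System.out.println("y = " + slope + "*x + " + intercept);
}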

This can be explained intuitively: in a linear regression with u = 0, if you take any point (x, y) and scale it to (Cx, Cy), the residual for the new point is also multiplied by C, so its squared error is multiplied by C^2. In other words, scaling a point by C has the same effect as giving it weight C^2. Because we minimize the squared error, we scale each point by the square root of its weight.
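A few lines of Java make the scaling argument concrete for a line through the origin (u = 0); the numbers are arbitrary and the class name is hypothetical:

```java
public class ScalingIntuition {
    // Residual of the point (x, y) against the line y = b * x (no intercept).
    static double residual(double x, double y, double b) {
        return y - b * x;
    }

    public static void main(String[] args) {
        double x = 2.0, y = 5.0, b = 1.5, c = 3.0;
        double r = residual(x, y, b);               // 5 - 3 = 2
        double rScaled = residual(c * x, c * y, b); // 15 - 9 = 6, i.e. c * r
        // The squared error grows by c^2 = 9, the same effect as a weight of 9.
        System.out.println(r + " " + rScaled + " " + (rScaled * rScaled) / (r * r));
    }
}
```

So to give a point weight w, scale it by c = sqrt(w). With an intercept, the ones column has to be scaled too, which is what x[i][1] = Math.sqrt(weights[i]) does.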


Not a library, but the code is posted here: http://www.codeproject.com/KB/recipes/LinReg.aspx (it includes the mathematical explanation behind the code, which is a huge plus). There also seems to be another implementation of the same algorithm here: http://sin-memories.blogspot.com/2009/04/weighted-linear-regression-in-java-and.html
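For a single regressor the weighted fit also has a short closed form, which is handy for cross-checking a library result. The sketch below uses the standard weighted least-squares formulas (weighted means, then slope = sum w (x - xMean)(y - yMean) / sum w (x - xMean)^2); it is not taken from the linked articles, and the class name is hypothetical:

```java
public class WeightedLine {
    // Returns {slope, intercept} minimizing sum_i w_i * (y_i - slope * x_i - intercept)^2.
    static double[] fit(double[] x, double[] y, double[] w) {
        if (x.length != y.length || x.length != w.length) {
            throw new IllegalArgumentException("x, y and w must have the same length");
        }
        double sw = 0.0, swx = 0.0, swy = 0.0;
        for (int i = 0; i < x.length; i++) {
            sw += w[i];
            swx += w[i] * x[i];
            swy += w[i] * y[i];
        }
        double xMean = swx / sw; // weighted mean of x
        double yMean = swy / sw; // weighted mean of y

        double sxy = 0.0, sxx = 0.0;
        for (int i = 0; i < x.length; i++) {
            double dx = x[i] - xMean;
            sxy += w[i] * dx * (y[i] - yMean);
            sxx += w[i] * dx * dx;
        }
        double slope = sxy / sxx; // undefined if all x values are equal
        double intercept = yMean - slope * xMean;
        return new double[] {slope, intercept};
    }

    public static void main(String[] args) {
        double[] p = fit(new double[] {0, 1, 2}, new double[] {0, 1, 0}, new double[] {1, 1, 2});
        System.out.println("y = " + p[0] + "*x + " + p[1]);
    }
}
```

Giving the point (2, 0) weight 2 yields the same line as listing that point twice with weight 1, which is a quick sanity check for any weighted implementation.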

Finally, there is a library from a university in New Zealand that seems to implement it: http://www.cs.waikato.ac.nz/~ml/weka/ (Weka, with pretty decent Javadocs). The specific class is documented here: http://weka.sourceforge.net/doc/weka/classifiers/functions/LinearRegression.html


I personally used the org.apache.commons.math.stat.regression.SimpleRegression class from the Apache Commons Math library.
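SimpleRegression's addData(x, y) takes no weight argument, so it can't fit weighted data directly. If you want the same incremental style, a minimal sketch is an accumulator over weighted sums. WeightedSimpleRegression below is a hypothetical class, not part of Commons Math, and its raw-sum formulas are less numerically robust than centered ones when the x values are large and close together:

```java
public class WeightedSimpleRegression {
    // Running weighted sums; each call to addData updates them in O(1).
    private double sw, swx, swy, swxx, swxy;

    void addData(double x, double y, double w) {
        sw += w;
        swx += w * x;
        swy += w * y;
        swxx += w * x * x;
        swxy += w * x * y;
    }

    // Weighted least-squares slope from the 2x2 normal equations.
    double getSlope() {
        return (sw * swxy - swx * swy) / (sw * swxx - swx * swx);
    }

    double getIntercept() {
        return (swy - getSlope() * swx) / sw;
    }

    public static void main(String[] args) {
        WeightedSimpleRegression r = new WeightedSimpleRegression();
        r.addData(0, 0, 1);
        r.addData(1, 1, 1);
        r.addData(2, 0, 2);
        System.out.println("y = " + r.getSlope() + "*x + " + r.getIntercept());
    }
}
```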

I also found a more lightweight class from Princeton University, but I didn't test it:

http://introcs.cs.princeton.edu/java/97data/LinearRegression.java.html