25package umontreal.ssj.functionfit;
27import java.io.Serializable;
28import cern.colt.matrix.DoubleMatrix1D;
29import cern.colt.matrix.DoubleMatrix2D;
30import cern.colt.matrix.impl.DenseDoubleMatrix2D;
31import cern.colt.matrix.linalg.QRDecomposition;
32import cern.colt.matrix.linalg.SingularValueDecomposition;
33import cern.colt.matrix.linalg.Algebra;
58 private static double[] solution(DoubleMatrix2D X, DoubleMatrix2D Y,
int k) {
63 QRDecomposition qr =
new QRDecomposition(X);
65 if (qr.hasFullRank()) {
66 DoubleMatrix2D B = qr.solve(Y);
67 return B.viewColumn(0).toArray();
70 DoubleMatrix1D Y0 = Y.viewColumn(0);
71 SingularValueDecomposition svd =
new SingularValueDecomposition(X);
72 DoubleMatrix2D S = svd.getS();
73 DoubleMatrix2D V = svd.getV();
74 DoubleMatrix2D U = svd.getU();
75 Algebra alg =
new Algebra();
76 DoubleMatrix2D Ut = alg.transpose(U);
77 DoubleMatrix1D g = alg.mult(Ut, Y0);
79 for (
int j = 0; j < k; j++) {
81 double x = S.getQuick(j, j);
83 x = g.getQuick(j) / x;
88 DoubleMatrix1D beta = alg.mult(V, g);
89 return beta.toArray();
// Fragment of calcCoefficients(double[] X, double[] Y). The method header and
// the rest of the body (the loop body and the return) are not part of this
// excerpt.
// Each X[i] is paired with Y[i], so both arrays must have the same length.
105 if (X.length != Y.length)
106 throw new IllegalArgumentException(
"Lengths of X and Y are not equal");
107 final int n = X.length;
// Xa is an n-by-1 matrix that holds the single regressor as one column.
// It is presumably filled with X[i] by the loop below and then handed to the
// double[][] overload. Neither step is visible here; confirm against the
// full file.
108 double[][] Xa =
new double[n][1];
109 for (
int i = 0; i < n; i++)
// Fragment of calcCoefficients(double[] X, double[] Y, int deg): a
// least-squares polynomial fit of degree deg via the normal equations.
// Several source lines are not part of this excerpt: the guard conditions,
// parts of the loop bodies, and the closing braces.
129 final int n = X.length;
131 throw new IllegalArgumentException(
"Lengths of X and Y are not equal");
// NOTE(review): the condition guarding this throw is not visible. It
// presumably rejects inputs with too few points for deg + 1 coefficients;
// confirm.
133 throw new IllegalArgumentException(
"Not enough points");
// xSums has one slot per power 0..2*deg, and xySums one slot per power
// 0..deg. Presumably xSums[d] accumulates sum_i X[i]^d and xySums[d]
// accumulates sum_i X[i]^d * Y[i].
135 final double[] xSums =
new double[2 * deg + 1];
136 final double[] xySums =
new double[deg + 1];
138 for (
int i = 0; i < n; i++) {
141 for (
int j = 1; j <= 2 * deg; j++) {
// NOTE(review): xv is defined on a line not shown here, presumably X[i]^j.
// xySums has only deg + 1 entries, but j runs up to 2 * deg. A hidden line
// presumably restricts this update to j <= deg; confirm. Otherwise it
// indexes out of bounds whenever deg >= 1.
144 xySums[j] += xv * Y[i];
// Build the (deg + 1) x (deg + 1) normal-equation system A * beta = B.
// NOTE(review): d is computed on a line not shown here, presumably i + j.
// Forming the normal equations squares the condition number of the
// underlying Vandermonde system, so high degrees may lose accuracy.
148 final DoubleMatrix2D A =
new DenseDoubleMatrix2D(deg + 1, deg + 1);
149 final DoubleMatrix2D B =
new DenseDoubleMatrix2D(deg + 1, 1);
150 for (
int i = 0; i <= deg; i++) {
151 for (
int j = 0; j <= deg; j++) {
153 A.setQuick(i, j, xSums[d]);
155 B.setQuick(i, 0, xySums[i]);
// Solve for the deg + 1 coefficients. They are presumably ordered from the
// constant term (power 0) upward, matching the indexing of xySums.
158 return solution(A, B, deg + 1);
// Fragment of calcCoefficients(double[][] X, double[] Y): multiple linear
// regression with an intercept. The method header and the closing braces are
// not part of this excerpt.
179 if (X.length != Y.length)
180 throw new IllegalArgumentException(
"Lengths of X and Y are not equal");
// Requires strictly more observations than the k + 1 fitted coefficients
// (intercept plus one per regressor).
// NOTE(review): when X (and hence Y) is empty, X[0] throws
// ArrayIndexOutOfBoundsException before this check can report
// "Not enough points".
181 if (Y.length <= X[0].length + 1)
182 throw new IllegalArgumentException(
"Not enough points");
184 final int n = Y.length;
185 final int k = X[0].length;
// Design matrix Xa (n x (k + 1)): column 0 is all ones for the intercept, and
// columns 1..k copy the regressors. Ya is Y as an n x 1 matrix.
// NOTE(review): every row of X is assumed to have length k, taken from X[0].
// A shorter row throws; extra entries in a longer row are ignored.
187 DoubleMatrix2D Xa =
new DenseDoubleMatrix2D(n, k + 1);
188 DoubleMatrix2D Ya =
new DenseDoubleMatrix2D(n, 1);
190 for (
int i = 0; i < n; i++) {
191 Xa.setQuick(i, 0, 1.);
192 for (
int j = 1; j <= k; j++) {
193 Xa.setQuick(i, j, X[i][j - 1]);
195 Ya.setQuick(i, 0, Y[i]);
// Returns k + 1 coefficients: the intercept first, then one per regressor.
198 return solution(Xa, Ya, k + 1);
// Fragment of calcCoefficients0(double[][] X, double[] Y): multiple linear
// regression without an intercept, so the fit passes through the origin.
// The method header and the closing braces are not part of this excerpt.
221 if (X.length != Y.length)
222 throw new IllegalArgumentException(
"Lengths of X and Y are not equal");
// NOTE(review): this is the same bound (n > k + 1) as the intercept version,
// although only k coefficients are fitted here. Confirm whether n == k + 1
// should be accepted. X[0] also throws ArrayIndexOutOfBoundsException when X
// is empty.
223 if (Y.length <= X[0].length + 1)
224 throw new IllegalArgumentException(
"Not enough points");
226 final int n = Y.length;
227 final int k = X[0].length;
// Copy X into an n x k design matrix (no column of ones) and Y into an
// n x 1 matrix. Rows of X are assumed to have length k, taken from X[0].
229 DoubleMatrix2D Xa =
new DenseDoubleMatrix2D(n, k);
230 DoubleMatrix2D Ya =
new DenseDoubleMatrix2D(n, 1);
232 for (
int i = 0; i < n; i++) {
233 for (
int j = 0; j < k; j++) {
234 Xa.setQuick(i, j, X[i][j]);
236 Ya.setQuick(i, 0, Y[i]);
// Returns k coefficients, one per regressor column.
239 return solution(Xa, Ya, k);
This class implements different linear regression models, using the least squares method to estimate the regression coefficients.
static double[] calcCoefficients(double[][] X, double[] Y)
Computes the regression coefficients of a multiple linear regression with an intercept, using the least squares method.
static double[] calcCoefficients(double[] X, double[] Y, int deg)
Computes, using the least squares method, the coefficients of the polynomial of degree deg that best fits the points (X[i], Y[i]).
static double[] calcCoefficients0(double[][] X, double[] Y)
Computes the regression coefficients of a multiple linear regression without an intercept (through the origin), using the least squares method.
static double[] calcCoefficients(double[] X, double[] Y)
Computes the regression coefficients for a single regressor X, using the least squares method.