SSJ API Documentation
Stochastic Simulation in Java
Loading...
Searching...
No Matches
LeastSquares.java
1/*
2 * Class: LeastSquares
3 * Description: General linear regression with the least squares method
4 * Environment: Java
5 * Software: SSJ
6 * Copyright (C) 2013 Pierre L'Ecuyer and Universite de Montreal
7 * Organization: DIRO, Universite de Montreal
8 * @author Richard Simard
9 * @since April 2013
10 *
11 *
12 * Licensed under the Apache License, Version 2.0 (the "License");
13 * you may not use this file except in compliance with the License.
14 * You may obtain a copy of the License at
15 *
16 * http://www.apache.org/licenses/LICENSE-2.0
17 *
18 * Unless required by applicable law or agreed to in writing, software
19 * distributed under the License is distributed on an "AS IS" BASIS,
20 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21 * See the License for the specific language governing permissions and
22 * limitations under the License.
23 *
24 */
25package umontreal.ssj.functionfit;
26
27import java.io.Serializable;
28import cern.colt.matrix.DoubleMatrix1D;
29import cern.colt.matrix.DoubleMatrix2D;
30import cern.colt.matrix.impl.DenseDoubleMatrix2D;
31import cern.colt.matrix.linalg.QRDecomposition;
32import cern.colt.matrix.linalg.SingularValueDecomposition;
33import cern.colt.matrix.linalg.Algebra;
34
56public class LeastSquares {
57
58 private static double[] solution(DoubleMatrix2D X, DoubleMatrix2D Y, int k) {
59 // Solve X * Beta = Y for Beta
60 // Only the first column of Y is used
61 // k is number of beta coefficients
62
63 QRDecomposition qr = new QRDecomposition(X);
64
65 if (qr.hasFullRank()) {
66 DoubleMatrix2D B = qr.solve(Y);
67 return B.viewColumn(0).toArray();
68
69 } else {
70 DoubleMatrix1D Y0 = Y.viewColumn(0); // first column of Y
71 SingularValueDecomposition svd = new SingularValueDecomposition(X);
72 DoubleMatrix2D S = svd.getS();
73 DoubleMatrix2D V = svd.getV();
74 DoubleMatrix2D U = svd.getU();
75 Algebra alg = new Algebra();
76 DoubleMatrix2D Ut = alg.transpose(U);
77 DoubleMatrix1D g = alg.mult(Ut, Y0); // Ut*Y0
78
79 for (int j = 0; j < k; j++) {
80 // solve S*p = g for p; S is a diagonal matrix
81 double x = S.getQuick(j, j);
82 if (x > 0.) {
83 x = g.getQuick(j) / x; // p[j] = g[j]/S[j]
84 g.setQuick(j, x); // overwrite g by p
85 } else
86 g.setQuick(j, 0.);
87 }
88 DoubleMatrix1D beta = alg.mult(V, g); // V*p
89 return beta.toArray();
90 }
91 }
92
104 public static double[] calcCoefficients(double[] X, double[] Y) {
105 if (X.length != Y.length)
106 throw new IllegalArgumentException("Lengths of X and Y are not equal");
107 final int n = X.length;
108 double[][] Xa = new double[n][1];
109 for (int i = 0; i < n; i++)
110 Xa[i][0] = X[i];
111
112 return calcCoefficients0(Xa, Y);
113 }
114
128 public static double[] calcCoefficients(double[] X, double[] Y, int deg) {
129 final int n = X.length;
130 if (n != Y.length)
131 throw new IllegalArgumentException("Lengths of X and Y are not equal");
132 if (n < deg + 1)
133 throw new IllegalArgumentException("Not enough points");
134
135 final double[] xSums = new double[2 * deg + 1];
136 final double[] xySums = new double[deg + 1];
137 xSums[0] = n;
138 for (int i = 0; i < n; i++) {
139 double xv = X[i];
140 xySums[0] += Y[i];
141 for (int j = 1; j <= 2 * deg; j++) {
142 xSums[j] += xv;
143 if (j <= deg)
144 xySums[j] += xv * Y[i];
145 xv *= X[i];
146 }
147 }
148 final DoubleMatrix2D A = new DenseDoubleMatrix2D(deg + 1, deg + 1);
149 final DoubleMatrix2D B = new DenseDoubleMatrix2D(deg + 1, 1);
150 for (int i = 0; i <= deg; i++) {
151 for (int j = 0; j <= deg; j++) {
152 final int d = i + j;
153 A.setQuick(i, j, xSums[d]);
154 }
155 B.setQuick(i, 0, xySums[i]);
156 }
157
158 return solution(A, B, deg + 1);
159 }
160
178 public static double[] calcCoefficients0(double[][] X, double[] Y) {
179 if (X.length != Y.length)
180 throw new IllegalArgumentException("Lengths of X and Y are not equal");
181 if (Y.length <= X[0].length + 1)
182 throw new IllegalArgumentException("Not enough points");
183
184 final int n = Y.length;
185 final int k = X[0].length;
186
187 DoubleMatrix2D Xa = new DenseDoubleMatrix2D(n, k + 1);
188 DoubleMatrix2D Ya = new DenseDoubleMatrix2D(n, 1);
189
190 for (int i = 0; i < n; i++) {
191 Xa.setQuick(i, 0, 1.);
192 for (int j = 1; j <= k; j++) {
193 Xa.setQuick(i, j, X[i][j - 1]);
194 }
195 Ya.setQuick(i, 0, Y[i]);
196 }
197
198 return solution(Xa, Ya, k + 1);
199 }
200
220 public static double[] calcCoefficients(double[][] X, double[] Y) {
221 if (X.length != Y.length)
222 throw new IllegalArgumentException("Lengths of X and Y are not equal");
223 if (Y.length <= X[0].length + 1)
224 throw new IllegalArgumentException("Not enough points");
225
226 final int n = Y.length;
227 final int k = X[0].length;
228
229 DoubleMatrix2D Xa = new DenseDoubleMatrix2D(n, k);
230 DoubleMatrix2D Ya = new DenseDoubleMatrix2D(n, 1);
231
232 for (int i = 0; i < n; i++) {
233 for (int j = 0; j < k; j++) {
234 Xa.setQuick(i, j, X[i][j]);
235 }
236 Ya.setQuick(i, 0, Y[i]);
237 }
238
239 return solution(Xa, Ya, k);
240 }
241
242}
This class implements different linear regression models, using the least squares method to estimate the regression coefficients.
static double[] calcCoefficients(double[][] X, double[] Y)
Computes the coefficients of a multiple linear regression model without intercept, using the least squares method.
static double[] calcCoefficients(double[] X, double[] Y, int deg)
Computes the coefficients of a polynomial regression model of degree deg, using the least squares method.
static double[] calcCoefficients0(double[][] X, double[] Y)
Computes the coefficients of a multiple linear regression model with an intercept term, using the least squares method.
static double[] calcCoefficients(double[] X, double[] Y)
Computes the intercept and slope of a simple linear regression model, using the least squares method.