SSJ API Documentation
Stochastic Simulation in Java
Loading...
Searching...
No Matches
MultiNormalDist.java
1/*
2 * Class: MultiNormalDist
3 * Description: multinormal distribution
4 * Environment: Java
5 * Software: SSJ
6 * Copyright (C) 2001 Pierre L'Ecuyer and Universite de Montreal
7 * Organization: DIRO, Universite de Montreal
8 * @author
9 * @since
10 *
11 *
12 * Licensed under the Apache License, Version 2.0 (the "License");
13 * you may not use this file except in compliance with the License.
14 * You may obtain a copy of the License at
15 *
16 * http://www.apache.org/licenses/LICENSE-2.0
17 *
18 * Unless required by applicable law or agreed to in writing, software
19 * distributed under the License is distributed on an "AS IS" BASIS,
20 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21 * See the License for the specific language governing permissions and
22 * limitations under the License.
23 *
24 */
25package umontreal.ssj.probdistmulti;
26
27import cern.colt.matrix.DoubleMatrix2D;
28import cern.colt.matrix.impl.DenseDoubleMatrix2D;
29import cern.colt.matrix.linalg.Algebra;
30
46public class MultiNormalDist extends ContinuousDistributionMulti {
47 protected int dim;
48 protected double[] mu;
49 protected DoubleMatrix2D sigma;
50 protected DoubleMatrix2D invSigma;
51
52 protected static Algebra algebra = new Algebra();
53
54 public MultiNormalDist(double[] mu, double[][] sigma) {
55 setParams(mu, sigma);
56 }
57
58 public double density(double[] x) {
59 double sum = 0.0;
60
61 if (invSigma == null)
62 invSigma = algebra.inverse(sigma);
63
64 double[] temp = new double[mu.length];
65 for (int i = 0; i < mu.length; i++) {
66 sum = 0.0;
67 for (int j = 0; j < mu.length; j++)
68 sum += ((x[j] - mu[j]) * invSigma.getQuick(j, i));
69 temp[i] = sum;
70 }
71
72 sum = 0.0;
73 for (int i = 0; i < mu.length; i++)
74 sum += temp[i] * (x[i] - mu[i]);
75
76 return (Math.exp(-0.5 * sum) / Math.sqrt(Math.pow(2 * Math.PI, mu.length) * algebra.det(sigma)));
77 }
78
79 public double[] getMean() {
80 return mu;
81 }
82
83 public double[][] getCovariance() {
84 return sigma.toArray();
85 }
86
87 public double[][] getCorrelation() {
88 return getCorrelation_(mu, sigma.toArray());
89 }
90
99 public static double density(double[] mu, double[][] sigma, double[] x) {
100 double sum = 0.0;
101 DoubleMatrix2D sig;
102 DoubleMatrix2D inv;
103
104 if (sigma.length != sigma[0].length)
105 throw new IllegalArgumentException("sigma must be a square matrix");
106 if (mu.length != sigma.length)
107 throw new IllegalArgumentException("mu and sigma must have the same dimension");
108
109 sig = new DenseDoubleMatrix2D(sigma);
110 inv = algebra.inverse(sig);
111
112 double[] temp = new double[mu.length];
113 for (int i = 0; i < mu.length; i++) {
114 sum = 0.0;
115 for (int j = 0; j < mu.length; j++)
116 sum += ((x[j] - mu[j]) * inv.getQuick(j, i));
117 temp[i] = sum;
118 }
119
120 sum = 0.0;
121 for (int i = 0; i < mu.length; i++)
122 sum += temp[i] * (x[i] - mu[i]);
123
124 return (Math.exp(-0.5 * sum) / Math.sqrt(Math.pow(2 * Math.PI, mu.length) * algebra.det(sig)));
125 }
126
130 public int getDimension() {
131 return dim;
132 }
133
140 public static double[] getMean(double[] mu, double[][] sigma) {
141 if (sigma.length != sigma[0].length)
142 throw new IllegalArgumentException("sigma must be a square matrix");
143 if (mu.length != sigma.length)
144 throw new IllegalArgumentException("mu and sigma must have the same dimension");
145
146 return mu;
147 }
148
153 public static double[][] getCovariance(double[] mu, double[][] sigma) {
154 if (sigma.length != sigma[0].length)
155 throw new IllegalArgumentException("sigma must be a square matrix");
156 if (mu.length != sigma.length)
157 throw new IllegalArgumentException("mu and sigma must have the same dimension");
158
159 return sigma;
160 }
161
162 private static double[][] getCorrelation_(double[] mu, double[][] sigma) {
163 double corr[][] = new double[mu.length][mu.length];
164
165 for (int i = 0; i < mu.length; i++) {
166 for (int j = 0; j < mu.length; j++)
167 corr[i][j] = -sigma[i][j] / Math.sqrt(sigma[i][i] * sigma[j][j]);
168 corr[i][i] = 1.0;
169 }
170 return corr;
171 }
172
177 public static double[][] getCorrelation(double[] mu, double[][] sigma) {
178 if (sigma.length != sigma[0].length)
179 throw new IllegalArgumentException("sigma must be a square matrix");
180 if (mu.length != sigma.length)
181 throw new IllegalArgumentException("mu and sigma must have the same dimension");
182
183 return getCorrelation_(mu, sigma);
184 }
185
198 public static double[] getMLEMu(double[][] x, int n, int d) {
199 if (n <= 0)
200 throw new IllegalArgumentException("n <= 0");
201 if (d <= 0)
202 throw new IllegalArgumentException("d <= 0");
203
204 double[] parameters = new double[d];
205 for (int i = 0; i < parameters.length; i++)
206 parameters[i] = 0.0;
207
208 for (int i = 0; i < n; i++)
209 for (int j = 0; j < d; j++)
210 parameters[j] += x[i][j];
211
212 for (int i = 0; i < parameters.length; i++)
213 parameters[i] = parameters[i] / (double) n;
214
215 return parameters;
216 }
217
229 public static double[][] getMLESigma(double[][] x, int n, int d) {
230 double sum = 0.0;
231
232 if (n <= 0)
233 throw new IllegalArgumentException("n <= 0");
234 if (d <= 0)
235 throw new IllegalArgumentException("d <= 0");
236
237 double[] mean = getMLEMu(x, n, d);
238 double[][] parameters = new double[d][d];
239 for (int i = 0; i < parameters.length; i++)
240 for (int j = 0; j < parameters.length; j++)
241 parameters[i][j] = 0.0;
242
243 for (int i = 0; i < parameters.length; i++) {
244 for (int j = 0; j < parameters.length; j++) {
245 sum = 0.0;
246 for (int t = 0; t < n; t++)
247 sum += (x[t][i] - mean[i]) * (x[t][j] - mean[j]);
248 parameters[i][j] = sum / (double) n;
249 }
250 }
251
252 return parameters;
253 }
254
258 public double[] getMu() {
259 return mu;
260 }
261
267 public double getMu(int i) {
268 return mu[i];
269 }
270
274 public double[][] getSigma() {
275 return sigma.toArray();
276 }
277
283 public void setParams(double[] mu, double[][] sigma) {
284 if (sigma.length != sigma[0].length)
285 throw new IllegalArgumentException("sigma must be a square matrix");
286 if (mu.length != sigma.length)
287 throw new IllegalArgumentException("mu and sigma must have the same dimension");
288
289 this.mu = new double[mu.length];
290 this.dimension = mu.length;
291 System.arraycopy(mu, 0, this.mu, 0, mu.length);
292 this.sigma = new DenseDoubleMatrix2D(sigma);
293
294 invSigma = null;
295 }
296
297}
Classes implementing continuous multi-dimensional distributions should inherit from this class.
double[][] getCorrelation()
Returns the correlation matrix of the distribution, defined as $\rho_{ij} = \sigma_{ij} / \sqrt{\sigma_{ii}\,\sigma_{jj}}$.
void setParams(double[] mu, double[][] sigma)
Sets the parameters $\boldsymbol{\mu}$ and $\boldsymbol{\Sigma}$ of this object.
static double[][] getCorrelation(double[] mu, double[][] sigma)
Computes the correlation matrix of the multinormal distribution with parameters $\boldsymbol{\mu}$ and $\boldsymbol{\Sigma}$.
static double[][] getCovariance(double[] mu, double[][] sigma)
Computes the covariance matrix of the multinormal distribution with parameters $\boldsymbol{\mu}$ and $\boldsymbol{\Sigma}$.
double[][] getSigma()
Returns the parameter $\boldsymbol{\Sigma}$ of this object.
static double[][] getMLESigma(double[][] x, int n, int d)
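Estimates the parameter $\boldsymbol{\Sigma}$ of the multinormal distribution from the $n$ observations $\mathbf{x}_0, \dots, \mathbf{x}_{n-1}$ of dimension $d$, using the maximum likelihood method.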
double density(double[] x)
Returns $f(\mathbf{x})$, the probability density of $\mathbf{X}$ evaluated at the point $\mathbf{x}$, where $\mathbf{x} = (x_1, \dots, x_d)$.
static double density(double[] mu, double[][] sigma, double[] x)
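Computes the density $f(\mathbf{x}) = \dfrac{\exp\left(-\frac{1}{2}(\mathbf{x}-\boldsymbol{\mu})^{\mathsf{T}}\boldsymbol{\Sigma}^{-1}(\mathbf{x}-\boldsymbol{\mu})\right)}{\sqrt{(2\pi)^d \det(\boldsymbol{\Sigma})}}$ of the multinormal distribution with parameters $\boldsymbol{\mu}$ and $\boldsymbol{\Sigma}$, evaluated at $\mathbf{x}$.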
static double[] getMean(double[] mu, double[][] sigma)
Returns the mean $\boldsymbol{\mu}$ of the multinormal distribution with parameters $\boldsymbol{\mu}$ and $\boldsymbol{\Sigma}$.
double[] getMean()
Returns the mean vector of the distribution, defined as $\mu_i = E[X_i]$.
static double[] getMLEMu(double[][] x, int n, int d)
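Estimates the parameter $\boldsymbol{\mu}$ of the multinormal distribution from the $n$ observations $\mathbf{x}_0, \dots, \mathbf{x}_{n-1}$ of dimension $d$, using the maximum likelihood method.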
double getMu(int i)
Returns the $i$-th component of the parameter $\boldsymbol{\mu}$.
int getDimension()
Returns the dimension of the distribution.
double[] getMu()
Returns the parameter $\boldsymbol{\mu}$ of this object.
double[][] getCovariance()
Returns the variance-covariance matrix of the distribution, defined as $\sigma_{ij} = E[(X_i - \mu_i)(X_j - \mu_j)]$.