SSJ API Documentation
Stochastic Simulation in Java
Loading...
Searching...
No Matches
DirichletDist.java
1/*
2 * Class: DirichletDist
3 * Description: Dirichlet distribution
4 * Environment: Java
5 * Software: SSJ
6 * Copyright (C) 2001 Pierre L'Ecuyer and Universite de Montreal
7 * Organization: DIRO, Universite de Montreal
8 * @author
9 * @since
10 *
11 *
12 * Licensed under the Apache License, Version 2.0 (the "License");
13 * you may not use this file except in compliance with the License.
14 * You may obtain a copy of the License at
15 *
16 * http://www.apache.org/licenses/LICENSE-2.0
17 *
18 * Unless required by applicable law or agreed to in writing, software
19 * distributed under the License is distributed on an "AS IS" BASIS,
20 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21 * See the License for the specific language governing permissions and
22 * limitations under the License.
23 *
24 */
25package umontreal.ssj.probdistmulti;
26
27import umontreal.ssj.util.Num;
28import optimization.*;
29
46public class DirichletDist extends ContinuousDistributionMulti {
47 private static final double LOGMIN = -709.1; // Log(MIN_DOUBLE/2)
48 protected double[] alpha;
49
50 private static class Optim implements Uncmin_methods {
51 double[] logP;
52 int n;
53 int k;
54
55 public Optim(double[] logP, int n) {
56 this.n = n;
57 this.k = logP.length;
58 this.logP = new double[k];
59 System.arraycopy(logP, 0, this.logP, 0, k);
60 }
61
62 public double f_to_minimize(double[] alpha) {
63 double sumAlpha = 0.0;
64 double sumLnGammaAlpha = 0.0;
65 double sumAlphaLnP = 0.0;
66
67 for (int i = 1; i < alpha.length; i++) {
68 if (alpha[i] <= 0.0)
69 return 1.0e200;
70
71 sumAlpha += alpha[i];
72 sumLnGammaAlpha += Num.lnGamma(alpha[i]);
73 sumAlphaLnP += ((alpha[i] - 1.0) * logP[i - 1]);
74 }
75
76 return (-n * (Num.lnGamma(sumAlpha) - sumLnGammaAlpha + sumAlphaLnP));
77 }
78
79 public void gradient(double[] alpha, double[] g) {
80 }
81
82 public void hessian(double[] alpha, double[][] h) {
83 }
84 }
85
86 public DirichletDist(double[] alpha) {
87 setParams(alpha);
88 }
89
90 public double density(double[] x) {
91 return density_(alpha, x);
92 }
93
94 public double[] getMean() {
95 return getMean_(alpha);
96 }
97
98 public double[][] getCovariance() {
99 return getCovariance_(alpha);
100 }
101
102 public double[][] getCorrelation() {
103 return getCorrelation_(alpha);
104 }
105
106 private static void verifParam(double[] alpha) {
107
108 for (int i = 0; i < alpha.length; i++) {
109 if (alpha[i] <= 0)
110 throw new IllegalArgumentException("alpha[" + i + "] <= 0");
111 }
112 }
113
114 private static double density_(double[] alpha, double[] x) {
115 double alpha0 = 0.0;
116 double sumLnGamma = 0.0;
117 double sumAlphaLnXi = 0.0;
118
119 if (alpha.length != x.length)
120 throw new IllegalArgumentException("alpha and x must have the same dimension");
121
122 for (int i = 0; i < alpha.length; i++) {
123 alpha0 += alpha[i];
124 sumLnGamma += Num.lnGamma(alpha[i]);
125 if (x[i] <= 0.0 || x[i] >= 1.0)
126 return 0.0;
127 sumAlphaLnXi += (alpha[i] - 1.0) * Math.log(x[i]);
128 }
129
130 return Math.exp(Num.lnGamma(alpha0) - sumLnGamma + sumAlphaLnXi);
131 }
132
138 public static double density(double[] alpha, double[] x) {
139 verifParam(alpha);
140 return density_(alpha, x);
141 }
142
143 private static double[][] getCovariance_(double[] alpha) {
144 double[][] cov = new double[alpha.length][alpha.length];
145 double alpha0 = 0.0;
146
147 for (int i = 0; i < alpha.length; i++)
148 alpha0 += alpha[i];
149
150 for (int i = 0; i < alpha.length; i++) {
151 for (int j = 0; j < alpha.length; j++)
152 cov[i][j] = -(alpha[i] * alpha[j]) / (alpha0 * alpha0 * (alpha0 + 1.0));
153
154 cov[i][i] = (alpha[i] / alpha0) * (1.0 - alpha[i] / alpha0) / (alpha0 + 1.0);
155 }
156
157 return cov;
158 }
159
164 public static double[][] getCovariance(double[] alpha) {
165 verifParam(alpha);
166
167 return getCovariance_(alpha);
168 }
169
170 private static double[][] getCorrelation_(double[] alpha) {
171 double corr[][] = new double[alpha.length][alpha.length];
172 double alpha0 = 0.0;
173
174 for (int i = 0; i < alpha.length; i++)
175 alpha0 += alpha[i];
176
177 for (int i = 0; i < alpha.length; i++) {
178 for (int j = 0; j < alpha.length; j++)
179 corr[i][j] = -Math.sqrt((alpha[i] * alpha[j]) / ((alpha0 - alpha[i]) * (alpha0 - alpha[j])));
180 corr[i][i] = 1.0;
181 }
182 return corr;
183 }
184
189 public static double[][] getCorrelation(double[] alpha) {
190 verifParam(alpha);
191 return getCorrelation_(alpha);
192 }
193
214 public static double[] getMLE(double[][] x, int n, int d) {
215 if (n <= 0)
216 throw new IllegalArgumentException("n <= 0");
217 if (d <= 0)
218 throw new IllegalArgumentException("d <= 0");
219
220 double[] logP = new double[d];
221 double mean[] = new double[d];
222 double var[] = new double[d];
223 int i;
224 int j;
225 for (i = 0; i < d; i++) {
226 logP[i] = 0.0;
227 mean[i] = 0.0;
228 }
229
230 for (i = 0; i < n; i++) {
231 for (j = 0; j < d; j++) {
232 if (x[i][j] > 0.)
233 logP[j] += Math.log(x[i][j]);
234 else
235 logP[j] += LOGMIN;
236 mean[j] += x[i][j];
237 }
238 }
239
240 for (i = 0; i < d; i++) {
241 logP[i] /= (double) n;
242 mean[i] /= (double) n;
243 }
244
245 double sum = 0.0;
246 for (j = 0; j < d; j++) {
247 sum = 0.0;
248 for (i = 0; i < n; i++)
249 sum += (x[i][j] - mean[j]) * (x[i][j] - mean[j]);
250 var[j] = sum / (double) n;
251 }
252
253 double alpha0 = (mean[0] * (1.0 - mean[0])) / var[0] - 1.0;
254 Optim system = new Optim(logP, n);
255
256 double[] parameters = new double[d];
257 double[] xpls = new double[d + 1];
258 double[] alpha = new double[d + 1];
259 double[] fpls = new double[d + 1];
260 double[] gpls = new double[d + 1];
261 int[] itrcmd = new int[2];
262 double[][] a = new double[d + 1][d + 1];
263 double[] udiag = new double[d + 1];
264
265 for (i = 1; i <= d; i++)
266 alpha[i] = mean[i - 1] * alpha0;
267
268 Uncmin_f77.optif0_f77(d, alpha, system, xpls, fpls, gpls, itrcmd, a, udiag);
269
270 for (i = 0; i < d; i++)
271 parameters[i] = xpls[i + 1];
272
273 return parameters;
274 }
275
276 private static double[] getMean_(double[] alpha) {
277 double alpha0 = 0.0;
278 double[] mean = new double[alpha.length];
279
280 for (int i = 0; i < alpha.length; i++)
281 alpha0 += alpha[i];
282
283 for (int i = 0; i < alpha.length; i++)
284 mean[i] = alpha[i] / alpha0;
285
286 return mean;
287 }
288
295 public static double[] getMean(double[] alpha) {
296 verifParam(alpha);
297 return getMean_(alpha);
298 }
299
303 public double[] getAlpha() {
304 return alpha;
305 }
306
310 public double getAlpha(int i) {
311 return alpha[i];
312 }
313
317 public void setParams(double[] alpha) {
318 this.dimension = alpha.length;
319 this.alpha = new double[dimension];
320 for (int i = 0; i < dimension; i++) {
321 if (alpha[i] <= 0)
322 throw new IllegalArgumentException("alpha[" + i + "] <= 0");
323 this.alpha[i] = alpha[i];
324 }
325 }
326
327}
Classes implementing continuous multi-dimensional distributions should inherit from this class.
double[][] getCorrelation()
Returns the correlation matrix of the distribution, defined as $\rho_{ij} = \sigma_{ij}/\sqrt{\sigma_{ii}\,\sigma_{jj}}$.
static double[] getMean(double[] alpha)
Computes the mean $E[X_i] = \alpha_i/\alpha_0$ of the Dirichlet distribution with parameters $(\alpha_1, \ldots, \alpha_d)$, where $\alpha_0 = \sum_{i=1}^{d} \alpha_i$.
double density(double[] x)
Returns $f(\mathbf{x})$, the probability density of $\mathbf{X}$ evaluated at the point $\mathbf{x}$, where $\mathbf{x} = (x_1, \ldots, x_d)$.
static double[][] getCovariance(double[] alpha)
Computes the covariance matrix of the Dirichlet distribution with parameters $(\alpha_1, \ldots, \alpha_d)$.
void setParams(double[] alpha)
Sets the parameters $(\alpha_1, \ldots, \alpha_d)$ of this object.
static double density(double[] alpha, double[] x)
Computes the density (fDirichlet) of the Dirichlet distribution with parameters $(\alpha_1, \ldots, \alpha_d)$.
static double[] getMLE(double[][] x, int n, int d)
Estimates the parameters $[\hat\alpha_1, \ldots, \hat\alpha_d]$ of the Dirichlet distribution using the maximum likelihood method.
double[] getAlpha()
Returns the parameters $(\alpha_1, \ldots, \alpha_d)$ of this object.
double[][] getCovariance()
Returns the variance-covariance matrix of the distribution, defined as $\sigma_{ij} = E[(X_i - \mu_i)(X_j - \mu_j)]$.
double getAlpha(int i)
Returns the $i$-th component of the alpha vector.
double[] getMean()
Returns the mean vector of the distribution, defined as $\mu_i = E[X_i]$.
static double[][] getCorrelation(double[] alpha)
Computes the correlation matrix of the Dirichlet distribution with parameters $(\alpha_1, \ldots, \alpha_d)$.
This class provides various constants and methods to compute numerical quantities such as factorials,...
Definition Num.java:35
static double lnGamma(double x)
Returns the natural logarithm of the gamma function evaluated at x.
Definition Num.java:417