1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.rng.examples.sampling;
19
20 import java.io.PrintWriter;
21 import java.util.EnumSet;
22 import java.util.concurrent.Callable;
23 import java.io.IOException;
24 import org.apache.commons.rng.UniformRandomProvider;
25 import org.apache.commons.rng.simple.RandomSource;
26
27 import picocli.CommandLine.Command;
28 import picocli.CommandLine.Mixin;
29 import picocli.CommandLine.Option;
30
31 import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler;
32 import org.apache.commons.rng.sampling.distribution.ZigguratSampler;
33 import org.apache.commons.rng.sampling.distribution.MarsagliaNormalizedGaussianSampler;
34 import org.apache.commons.rng.sampling.distribution.StableSampler;
35 import org.apache.commons.rng.sampling.distribution.TSampler;
36 import org.apache.commons.rng.sampling.distribution.BoxMullerNormalizedGaussianSampler;
37 import org.apache.commons.rng.sampling.distribution.ChengBetaSampler;
38 import org.apache.commons.rng.sampling.distribution.AhrensDieterExponentialSampler;
39 import org.apache.commons.rng.sampling.distribution.AhrensDieterMarsagliaTsangGammaSampler;
40 import org.apache.commons.rng.sampling.distribution.InverseTransformParetoSampler;
41 import org.apache.commons.rng.sampling.distribution.LevySampler;
42 import org.apache.commons.rng.sampling.distribution.LogNormalSampler;
43 import org.apache.commons.rng.sampling.distribution.ContinuousUniformSampler;
44 import org.apache.commons.rng.sampling.distribution.GaussianSampler;
45 import org.apache.commons.rng.sampling.distribution.ContinuousSampler;
46
47
48
49
50 @Command(name = "density",
51 description = {"Approximate the probability density of samplers."})
52 class ProbabilityDensityApproximationCommand implements Callable<Void> {
53
54 @Mixin
55 private StandardOptions reusableOptions;
56
57
58 @Option(names = {"-b", "--bins"},
59 description = "The number of bins in the histogram (default: ${DEFAULT-VALUE}).")
60 private int numBins = 25_000;
61
62
63 @Option(names = {"-n", "--samples"},
64 description = "The number of samples in the histogram (default: ${DEFAULT-VALUE}).")
65 private long numSamples = 1_000_000_000;
66
67
68 @Option(names = {"-s", "--samplers"},
69 split = ",",
70 description = {"The samplers (comma-delimited for multiple options).",
71 "Valid values: ${COMPLETION-CANDIDATES}."})
72 private EnumSet<Sampler> samplers = EnumSet.noneOf(Sampler.class);
73
74
75 @Option(names = {"-r", "--rng"},
76 description = {"The source of randomness (default: ${DEFAULT-VALUE})."})
77 private RandomSource randomSource = RandomSource.XOR_SHIFT_1024_S_PHI;
78
79
80 @Option(names = {"-a", "--all"},
81 description = "Output all samplers")
82 private boolean allSamplers;
83
84
85
86
87 enum Sampler {
88
89 ZigguratGaussianSampler,
90
91 MarsagliaGaussianSampler,
92
93 BoxMullerGaussianSampler,
94
95 ModifiedZigguratGaussianSampler,
96
97 ChengBetaSamplerCase1,
98
99 ChengBetaSamplerCase2,
100
101 AhrensDieterExponentialSampler,
102
103 ModifiedZigguratExponentialSampler,
104
105 AhrensDieterMarsagliaTsangGammaSamplerCase1,
106
107 AhrensDieterMarsagliaTsangGammaSamplerCase2,
108
109 InverseTransformParetoSampler,
110
111 ContinuousUniformSampler,
112
113 LogNormalZigguratGaussianSampler,
114
115 LogNormalMarsagliaGaussianSampler,
116
117 LogNormalBoxMullerGaussianSampler,
118
119 LogNormalModifiedZigguratGaussianSampler,
120
121 LevySampler,
122
123 StableSampler,
124
125 TSampler,
126 }
127
128
129
130
131
132
133
134
135
136
137
138
139 private void createDensity(ContinuousSampler sampler,
140 double min,
141 double max,
142 String outputFile)
143 throws IOException {
144 final double binSize = (max - min) / numBins;
145 final long[] histogram = new long[numBins];
146
147 long belowMin = 0;
148 long aboveMax = 0;
149 for (long n = 0; n < numSamples; n++) {
150 final double r = sampler.sample();
151
152 if (r < min) {
153 ++belowMin;
154 continue;
155 }
156
157 if (r >= max) {
158 ++aboveMax;
159 continue;
160 }
161
162 final int binIndex = (int) ((r - min) / binSize);
163 ++histogram[binIndex];
164 }
165
166 final double binHalfSize = 0.5 * binSize;
167 final double norm = 1 / (binSize * numSamples);
168
169 try (PrintWriter out = new PrintWriter("pdf." + outputFile + ".txt", "UTF-8")) {
170
171 out.println("# Sampler: " + sampler);
172 out.println("# Number of bins: " + numBins);
173 out.println("# Min: " + min + " (fraction of samples below: " + (belowMin / (double) numSamples) + ")");
174 out.println("# Max: " + max + " (fraction of samples above: " + (aboveMax / (double) numSamples) + ")");
175 out.println("# Bin width: " + binSize);
176 out.println("# Histogram normalization factor: " + norm);
177 out.println("#");
178 out.println("# " + (min - binHalfSize) + " " + (belowMin * norm));
179 for (int i = 0; i < numBins; i++) {
180 out.println((min + (i + 1) * binSize - binHalfSize) + " " + (histogram[i] * norm));
181 }
182 out.println("# " + (max + binHalfSize) + " " + (aboveMax * norm));
183
184 }
185 }
186
187
188
189
190
191
192 @Override
193 public Void call() throws IOException {
194 if (allSamplers) {
195 samplers = EnumSet.allOf(Sampler.class);
196 } else if (samplers.isEmpty()) {
197
198 System.err.println("ERROR: No samplers specified");
199
200 System.exit(1);
201 }
202
203 final UniformRandomProvider rng = randomSource.create();
204
205 final double gaussMean = 1;
206 final double gaussSigma = 2;
207 final double gaussMin = -9;
208 final double gaussMax = 11;
209 if (samplers.contains(Sampler.ZigguratGaussianSampler)) {
210 createDensity(GaussianSampler.of(ZigguratNormalizedGaussianSampler.of(rng),
211 gaussMean, gaussSigma),
212 gaussMin, gaussMax, "gauss.ziggurat");
213 }
214 if (samplers.contains(Sampler.MarsagliaGaussianSampler)) {
215 createDensity(GaussianSampler.of(MarsagliaNormalizedGaussianSampler.of(rng),
216 gaussMean, gaussSigma),
217 gaussMin, gaussMax, "gauss.marsaglia");
218 }
219 if (samplers.contains(Sampler.BoxMullerGaussianSampler)) {
220 createDensity(GaussianSampler.of(BoxMullerNormalizedGaussianSampler.of(rng),
221 gaussMean, gaussSigma),
222 gaussMin, gaussMax, "gauss.boxmuller");
223 }
224 if (samplers.contains(Sampler.ModifiedZigguratGaussianSampler)) {
225 createDensity(GaussianSampler.of(ZigguratSampler.NormalizedGaussian.of(rng),
226 gaussMean, gaussSigma),
227 gaussMin, gaussMax, "gauss.modified.ziggurat");
228 }
229
230 final double betaMin = 0;
231 final double betaMax = 1;
232 if (samplers.contains(Sampler.ChengBetaSamplerCase1)) {
233 final double alphaBeta = 4.3;
234 final double betaBeta = 2.1;
235 createDensity(ChengBetaSampler.of(rng, alphaBeta, betaBeta),
236 betaMin, betaMax, "beta.case1");
237 }
238 if (samplers.contains(Sampler.ChengBetaSamplerCase2)) {
239 final double alphaBetaAlt = 0.5678;
240 final double betaBetaAlt = 0.1234;
241 createDensity(ChengBetaSampler.of(rng, alphaBetaAlt, betaBetaAlt),
242 betaMin, betaMax, "beta.case2");
243 }
244
245 final double meanExp = 3.45;
246 final double expMin = 0;
247 final double expMax = 60;
248 if (samplers.contains(Sampler.AhrensDieterExponentialSampler)) {
249 createDensity(AhrensDieterExponentialSampler.of(rng, meanExp),
250 expMin, expMax, "exp");
251 }
252 if (samplers.contains(Sampler.ModifiedZigguratExponentialSampler)) {
253 createDensity(ZigguratSampler.Exponential.of(rng, meanExp),
254 expMin, expMax, "exp.modified.ziggurat");
255 }
256
257 final double gammaMin = 0;
258 final double gammaMax1 = 40;
259 final double thetaGamma = 3.456;
260 if (samplers.contains(Sampler.AhrensDieterMarsagliaTsangGammaSamplerCase1)) {
261 final double alphaGammaSmallerThanOne = 0.1234;
262 createDensity(AhrensDieterMarsagliaTsangGammaSampler.of(rng, alphaGammaSmallerThanOne, thetaGamma),
263 gammaMin, gammaMax1, "gamma.case1");
264 }
265 if (samplers.contains(Sampler.AhrensDieterMarsagliaTsangGammaSamplerCase2)) {
266 final double alphaGammaLargerThanOne = 2.345;
267 final double gammaMax2 = 70;
268 createDensity(AhrensDieterMarsagliaTsangGammaSampler.of(rng, alphaGammaLargerThanOne, thetaGamma),
269 gammaMin, gammaMax2, "gamma.case2");
270 }
271
272 final double scalePareto = 23.45;
273 final double shapePareto = 0.789;
274 final double paretoMin = 23;
275 final double paretoMax = 400;
276 if (samplers.contains(Sampler.InverseTransformParetoSampler)) {
277 createDensity(InverseTransformParetoSampler.of(rng, scalePareto, shapePareto),
278 paretoMin, paretoMax, "pareto");
279 }
280
281 final double loUniform = -9.876;
282 final double hiUniform = 5.432;
283 if (samplers.contains(Sampler.ContinuousUniformSampler)) {
284 createDensity(ContinuousUniformSampler.of(rng, loUniform, hiUniform),
285 loUniform, hiUniform, "uniform");
286 }
287
288 final double scaleLogNormal = 2.345;
289 final double shapeLogNormal = 0.1234;
290 final double logNormalMin = 5;
291 final double logNormalMax = 25;
292 if (samplers.contains(Sampler.LogNormalZigguratGaussianSampler)) {
293 createDensity(LogNormalSampler.of(ZigguratNormalizedGaussianSampler.of(rng),
294 scaleLogNormal, shapeLogNormal),
295 logNormalMin, logNormalMax, "lognormal.ziggurat");
296 }
297 if (samplers.contains(Sampler.LogNormalMarsagliaGaussianSampler)) {
298 createDensity(LogNormalSampler.of(MarsagliaNormalizedGaussianSampler.of(rng),
299 scaleLogNormal, shapeLogNormal),
300 logNormalMin, logNormalMax, "lognormal.marsaglia");
301 }
302 if (samplers.contains(Sampler.LogNormalBoxMullerGaussianSampler)) {
303 createDensity(LogNormalSampler.of(BoxMullerNormalizedGaussianSampler.of(rng),
304 scaleLogNormal, shapeLogNormal),
305 logNormalMin, logNormalMax, "lognormal.boxmuller");
306 }
307 if (samplers.contains(Sampler.LogNormalModifiedZigguratGaussianSampler)) {
308 createDensity(LogNormalSampler.of(ZigguratSampler.NormalizedGaussian.of(rng),
309 scaleLogNormal, shapeLogNormal),
310 logNormalMin, logNormalMax, "lognormal.modified.ziggurat");
311 }
312
313 if (samplers.contains(Sampler.LevySampler)) {
314 final double levyLocation = 1.23;
315 final double levyscale = 0.75;
316 final double levyMin = levyLocation;
317
318 final double levyMax = 6.2815;
319 createDensity(LevySampler.of(rng, levyLocation, levyscale),
320 levyMin, levyMax, "levy");
321 }
322
323 if (samplers.contains(Sampler.StableSampler)) {
324 final double stableAlpha = 1.23;
325 final double stableBeta = 0.75;
326
327 final double stableMin = -1.7862;
328 final double stableMax = 4.0364;
329 createDensity(StableSampler.of(rng, stableAlpha, stableBeta),
330 stableMin, stableMax, "stable");
331 }
332
333 if (samplers.contains(Sampler.TSampler)) {
334 final double tDegreesOfFreedom = 1.23;
335
336 final double tMin = -9.9264;
337 final double tMax = 9.9264;
338 createDensity(TSampler.of(rng, tDegreesOfFreedom),
339 tMin, tMax, "t");
340 }
341
342 return null;
343 }
344 }