View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.rng.sampling.distribution;
18  
19  import org.apache.commons.rng.UniformRandomProvider;
20  
21  /**
22   * Sampler for a discrete distribution using an optimised look-up table.
23   *
24   * <ul>
25   *  <li>
26   *   The method requires 30-bit integer probabilities that sum to 2<sup>30</sup> as described
27   *   in George Marsaglia, Wai Wan Tsang, Jingbo Wang (2004) Fast Generation of Discrete
28   *   Random Variables. Journal of Statistical Software. Vol. 11, Issue. 3, pp. 1-11.
29   *  </li>
30   * </ul>
31   *
32   * <p>Sampling uses 1 call to {@link UniformRandomProvider#nextInt()}.</p>
33   *
34   * <p>Memory requirements depend on the maximum number of possible sample values, {@code n},
35   * and the values for the probabilities. Storage is optimised for {@code n}. The worst case
36   * scenario is a uniform distribution of the maximum sample size. This is capped at 0.06MB for
37   * {@code n <= } 2<sup>8</sup>, 17.0MB for {@code n <= } 2<sup>16</sup>, and 4.3GB for
38   * {@code n <=} 2<sup>30</sup>. Realistic requirements will be in the kB range.</p>
39   *
40   * <p>The sampler supports the following distributions:</p>
41   *
42   * <ul>
43   *  <li>Enumerated distribution (probabilities must be provided for each sample)
44   *  <li>Poisson distribution up to {@code mean = 1024}
45   *  <li>Binomial distribution up to {@code trials = 65535}
46   * </ul>
47   *
 * @see <a href="http://dx.doi.org/10.18637/jss.v011.i03">Marsaglia, et al (2004) JSS Vol.
49   * 11, Issue 3</a>
50   * @since 1.3
51   */
52  public final class MarsagliaTsangWangDiscreteSampler {
53      /** The value 2<sup>8</sup> as an {@code int}. */
54      private static final int INT_8 = 1 << 8;
55      /** The value 2<sup>16</sup> as an {@code int}. */
56      private static final int INT_16 = 1 << 16;
57      /** The value 2<sup>30</sup> as an {@code int}. */
58      private static final int INT_30 = 1 << 30;
59      /** The value 2<sup>31</sup> as a {@code double}. */
60      private static final double DOUBLE_31 = 1L << 31;
61  
62      // =========================================================================
63      // Implementation note:
64      //
65      // This sampler uses prepared look-up tables that are searched using a single
66      // random int variate. The look-up tables contain the sample value. The tables
67      // are constructed using probabilities that sum to 2^30. The original paper
68      // by Marsaglia, et al (2004) describes the use of 5, 3, or 2 look-up tables
69      // indexed using digits of base 2^6, 2^10 or 2^15. Currently only base 64 (2^6)
70      // is supported using 5 look-up tables.
71      //
72      // The implementations use 8, 16 or 32 bit storage tables to support different
73      // distribution sizes with optimal storage. Separate class implementations of
74      // the same algorithm allow array storage to be accessed directly from 1D tables.
75      // This provides a performance gain over using: abstracted storage accessed via
76      // an interface; or a single 2D table.
77      //
78      // To allow the optimal implementation to be chosen the sampler is created
79      // using factory methods. The sampler supports any probability distribution
80      // when provided via an array of probabilities and the Poisson and Binomial
81      // distributions for a restricted set of parameters. The restrictions are
82      // imposed by the requirement to compute the entire probability distribution
    // from the controlling parameter(s) using a recursive method. The factory
    // methods return a SharedStateDiscreteSampler instance. Each distribution
85      // type is contained in an inner class.
86      // =========================================================================
87  
88      /**
89       * The base class for Marsaglia-Tsang-Wang samplers.
90       */
91      private abstract static class AbstractMarsagliaTsangWangDiscreteSampler
92              implements SharedStateDiscreteSampler {
93          /** Underlying source of randomness. */
94          protected final UniformRandomProvider rng;
95  
96          /** The name of the distribution. */
97          private final String distributionName;
98  
99          /**
100          * @param rng Generator of uniformly distributed random numbers.
101          * @param distributionName Distribution name.
102          */
103         AbstractMarsagliaTsangWangDiscreteSampler(UniformRandomProvider rng,
104                                                   String distributionName) {
105             this.rng = rng;
106             this.distributionName = distributionName;
107         }
108 
109         /**
110          * @param rng Generator of uniformly distributed random numbers.
111          * @param source Source to copy.
112          */
113         AbstractMarsagliaTsangWangDiscreteSampler(UniformRandomProvider rng,
114                                                   AbstractMarsagliaTsangWangDiscreteSampler source) {
115             this.rng = rng;
116             this.distributionName = source.distributionName;
117         }
118 
119         /** {@inheritDoc} */
120         @Override
121         public String toString() {
122             return "Marsaglia Tsang Wang " + distributionName + " deviate [" + rng.toString() + "]";
123         }
124     }
125 
126     /**
127      * An implementation for the sample algorithm based on the decomposition of the
128      * index in the range {@code [0,2^30)} into 5 base-64 digits with 8-bit backing storage.
129      */
130     private static final class MarsagliaTsangWangBase64Int8DiscreteSampler
131         extends AbstractMarsagliaTsangWangDiscreteSampler {
132         /** The mask to convert a {@code byte} to an unsigned 8-bit integer. */
133         private static final int MASK = 0xff;
134 
135         /** Limit for look-up table 1. */
136         private final int t1;
137         /** Limit for look-up table 2. */
138         private final int t2;
139         /** Limit for look-up table 3. */
140         private final int t3;
141         /** Limit for look-up table 4. */
142         private final int t4;
143 
144         /** Look-up table table1. */
145         private final byte[] table1;
146         /** Look-up table table2. */
147         private final byte[] table2;
148         /** Look-up table table3. */
149         private final byte[] table3;
150         /** Look-up table table4. */
151         private final byte[] table4;
152         /** Look-up table table5. */
153         private final byte[] table5;
154 
155         /**
156          * @param rng Generator of uniformly distributed random numbers.
157          * @param distributionName Distribution name.
158          * @param prob The probabilities.
159          * @param offset The offset (must be positive).
160          */
161         MarsagliaTsangWangBase64Int8DiscreteSampler(UniformRandomProvider rng,
162                                                     String distributionName,
163                                                     int[] prob,
164                                                     int offset) {
165             super(rng, distributionName);
166 
167             // Get table sizes for each base-64 digit
168             int n1 = 0;
169             int n2 = 0;
170             int n3 = 0;
171             int n4 = 0;
172             int n5 = 0;
173             for (final int m : prob) {
174                 n1 += getBase64Digit(m, 1);
175                 n2 += getBase64Digit(m, 2);
176                 n3 += getBase64Digit(m, 3);
177                 n4 += getBase64Digit(m, 4);
178                 n5 += getBase64Digit(m, 5);
179             }
180 
181             table1 = new byte[n1];
182             table2 = new byte[n2];
183             table3 = new byte[n3];
184             table4 = new byte[n4];
185             table5 = new byte[n5];
186 
187             // Compute offsets
188             t1 = n1 << 24;
189             t2 = t1 + (n2 << 18);
190             t3 = t2 + (n3 << 12);
191             t4 = t3 + (n4 << 6);
192             n1 = n2 = n3 = n4 = n5 = 0;
193 
194             // Fill tables
195             for (int i = 0; i < prob.length; i++) {
196                 final int m = prob[i];
197                 // Primitive type conversion will extract lower 8 bits
198                 final byte k = (byte) (i + offset);
199                 n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k);
200                 n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k);
201                 n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k);
202                 n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k);
203                 n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k);
204             }
205         }
206 
207         /**
208          * @param rng Generator of uniformly distributed random numbers.
209          * @param source Source to copy.
210          */
211         private MarsagliaTsangWangBase64Int8DiscreteSampler(UniformRandomProvider rng,
212                 MarsagliaTsangWangBase64Int8DiscreteSampler source) {
213             super(rng, source);
214             t1 = source.t1;
215             t2 = source.t2;
216             t3 = source.t3;
217             t4 = source.t4;
218             table1 = source.table1;
219             table2 = source.table2;
220             table3 = source.table3;
221             table4 = source.table4;
222             table5 = source.table5;
223         }
224 
225         /**
226          * Fill the table with the value.
227          *
228          * @param table Table.
229          * @param from Lower bound index (inclusive)
230          * @param to Upper bound index (exclusive)
231          * @param value Value.
232          * @return the upper bound index
233          */
234         private static int fill(byte[] table, int from, int to, byte value) {
235             for (int i = from; i < to; i++) {
236                 table[i] = value;
237             }
238             return to;
239         }
240 
241         @Override
242         public int sample() {
243             final int j = rng.nextInt() >>> 2;
244             if (j < t1) {
245                 return table1[j >>> 24] & MASK;
246             }
247             if (j < t2) {
248                 return table2[(j - t1) >>> 18] & MASK;
249             }
250             if (j < t3) {
251                 return table3[(j - t2) >>> 12] & MASK;
252             }
253             if (j < t4) {
254                 return table4[(j - t3) >>> 6] & MASK;
255             }
256             // Note the tables are filled on the assumption that the sum of the probabilities.
257             // is >=2^30. If this is not true then the final table table5 will be smaller by the
258             // difference. So the tables *must* be constructed correctly.
259             return table5[j - t4] & MASK;
260         }
261 
262         @Override
263         public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
264             return new MarsagliaTsangWangBase64Int8DiscreteSampler(rng, this);
265         }
266     }
267 
268     /**
269      * An implementation for the sample algorithm based on the decomposition of the
270      * index in the range {@code [0,2^30)} into 5 base-64 digits with 16-bit backing storage.
271      */
272     private static final class MarsagliaTsangWangBase64Int16DiscreteSampler
273         extends AbstractMarsagliaTsangWangDiscreteSampler {
274         /** The mask to convert a {@code byte} to an unsigned 16-bit integer. */
275         private static final int MASK = 0xffff;
276 
277         /** Limit for look-up table 1. */
278         private final int t1;
279         /** Limit for look-up table 2. */
280         private final int t2;
281         /** Limit for look-up table 3. */
282         private final int t3;
283         /** Limit for look-up table 4. */
284         private final int t4;
285 
286         /** Look-up table table1. */
287         private final short[] table1;
288         /** Look-up table table2. */
289         private final short[] table2;
290         /** Look-up table table3. */
291         private final short[] table3;
292         /** Look-up table table4. */
293         private final short[] table4;
294         /** Look-up table table5. */
295         private final short[] table5;
296 
297         /**
298          * @param rng Generator of uniformly distributed random numbers.
299          * @param distributionName Distribution name.
300          * @param prob The probabilities.
301          * @param offset The offset (must be positive).
302          */
303         MarsagliaTsangWangBase64Int16DiscreteSampler(UniformRandomProvider rng,
304                                                      String distributionName,
305                                                      int[] prob,
306                                                      int offset) {
307             super(rng, distributionName);
308 
309             // Get table sizes for each base-64 digit
310             int n1 = 0;
311             int n2 = 0;
312             int n3 = 0;
313             int n4 = 0;
314             int n5 = 0;
315             for (final int m : prob) {
316                 n1 += getBase64Digit(m, 1);
317                 n2 += getBase64Digit(m, 2);
318                 n3 += getBase64Digit(m, 3);
319                 n4 += getBase64Digit(m, 4);
320                 n5 += getBase64Digit(m, 5);
321             }
322 
323             table1 = new short[n1];
324             table2 = new short[n2];
325             table3 = new short[n3];
326             table4 = new short[n4];
327             table5 = new short[n5];
328 
329             // Compute offsets
330             t1 = n1 << 24;
331             t2 = t1 + (n2 << 18);
332             t3 = t2 + (n3 << 12);
333             t4 = t3 + (n4 << 6);
334             n1 = n2 = n3 = n4 = n5 = 0;
335 
336             // Fill tables
337             for (int i = 0; i < prob.length; i++) {
338                 final int m = prob[i];
339                 // Primitive type conversion will extract lower 16 bits
340                 final short k = (short) (i + offset);
341                 n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k);
342                 n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k);
343                 n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k);
344                 n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k);
345                 n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k);
346             }
347         }
348 
349         /**
350          * @param rng Generator of uniformly distributed random numbers.
351          * @param source Source to copy.
352          */
353         private MarsagliaTsangWangBase64Int16DiscreteSampler(UniformRandomProvider rng,
354                 MarsagliaTsangWangBase64Int16DiscreteSampler source) {
355             super(rng, source);
356             t1 = source.t1;
357             t2 = source.t2;
358             t3 = source.t3;
359             t4 = source.t4;
360             table1 = source.table1;
361             table2 = source.table2;
362             table3 = source.table3;
363             table4 = source.table4;
364             table5 = source.table5;
365         }
366 
367         /**
368          * Fill the table with the value.
369          *
370          * @param table Table.
371          * @param from Lower bound index (inclusive)
372          * @param to Upper bound index (exclusive)
373          * @param value Value.
374          * @return the upper bound index
375          */
376         private static int fill(short[] table, int from, int to, short value) {
377             for (int i = from; i < to; i++) {
378                 table[i] = value;
379             }
380             return to;
381         }
382 
383         @Override
384         public int sample() {
385             final int j = rng.nextInt() >>> 2;
386             if (j < t1) {
387                 return table1[j >>> 24] & MASK;
388             }
389             if (j < t2) {
390                 return table2[(j - t1) >>> 18] & MASK;
391             }
392             if (j < t3) {
393                 return table3[(j - t2) >>> 12] & MASK;
394             }
395             if (j < t4) {
396                 return table4[(j - t3) >>> 6] & MASK;
397             }
398             // Note the tables are filled on the assumption that the sum of the probabilities.
399             // is >=2^30. If this is not true then the final table table5 will be smaller by the
400             // difference. So the tables *must* be constructed correctly.
401             return table5[j - t4] & MASK;
402         }
403 
404         @Override
405         public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
406             return new MarsagliaTsangWangBase64Int16DiscreteSampler(rng, this);
407         }
408     }
409 
410     /**
411      * An implementation for the sample algorithm based on the decomposition of the
412      * index in the range {@code [0,2^30)} into 5 base-64 digits with 32-bit backing storage.
413      */
414     private static final class MarsagliaTsangWangBase64Int32DiscreteSampler
415         extends AbstractMarsagliaTsangWangDiscreteSampler {
416         /** Limit for look-up table 1. */
417         private final int t1;
418         /** Limit for look-up table 2. */
419         private final int t2;
420         /** Limit for look-up table 3. */
421         private final int t3;
422         /** Limit for look-up table 4. */
423         private final int t4;
424 
425         /** Look-up table table1. */
426         private final int[] table1;
427         /** Look-up table table2. */
428         private final int[] table2;
429         /** Look-up table table3. */
430         private final int[] table3;
431         /** Look-up table table4. */
432         private final int[] table4;
433         /** Look-up table table5. */
434         private final int[] table5;
435 
436         /**
437          * @param rng Generator of uniformly distributed random numbers.
438          * @param distributionName Distribution name.
439          * @param prob The probabilities.
440          * @param offset The offset (must be positive).
441          */
442         MarsagliaTsangWangBase64Int32DiscreteSampler(UniformRandomProvider rng,
443                                                      String distributionName,
444                                                      int[] prob,
445                                                      int offset) {
446             super(rng, distributionName);
447 
448             // Get table sizes for each base-64 digit
449             int n1 = 0;
450             int n2 = 0;
451             int n3 = 0;
452             int n4 = 0;
453             int n5 = 0;
454             for (final int m : prob) {
455                 n1 += getBase64Digit(m, 1);
456                 n2 += getBase64Digit(m, 2);
457                 n3 += getBase64Digit(m, 3);
458                 n4 += getBase64Digit(m, 4);
459                 n5 += getBase64Digit(m, 5);
460             }
461 
462             table1 = new int[n1];
463             table2 = new int[n2];
464             table3 = new int[n3];
465             table4 = new int[n4];
466             table5 = new int[n5];
467 
468             // Compute offsets
469             t1 = n1 << 24;
470             t2 = t1 + (n2 << 18);
471             t3 = t2 + (n3 << 12);
472             t4 = t3 + (n4 << 6);
473             n1 = n2 = n3 = n4 = n5 = 0;
474 
475             // Fill tables
476             for (int i = 0; i < prob.length; i++) {
477                 final int m = prob[i];
478                 final int k = i + offset;
479                 n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k);
480                 n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k);
481                 n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k);
482                 n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k);
483                 n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k);
484             }
485         }
486 
487         /**
488          * @param rng Generator of uniformly distributed random numbers.
489          * @param source Source to copy.
490          */
491         private MarsagliaTsangWangBase64Int32DiscreteSampler(UniformRandomProvider rng,
492                 MarsagliaTsangWangBase64Int32DiscreteSampler source) {
493             super(rng, source);
494             t1 = source.t1;
495             t2 = source.t2;
496             t3 = source.t3;
497             t4 = source.t4;
498             table1 = source.table1;
499             table2 = source.table2;
500             table3 = source.table3;
501             table4 = source.table4;
502             table5 = source.table5;
503         }
504 
505         /**
506          * Fill the table with the value.
507          *
508          * @param table Table.
509          * @param from Lower bound index (inclusive)
510          * @param to Upper bound index (exclusive)
511          * @param value Value.
512          * @return the upper bound index
513          */
514         private static int fill(int[] table, int from, int to, int value) {
515             for (int i = from; i < to; i++) {
516                 table[i] = value;
517             }
518             return to;
519         }
520 
521         @Override
522         public int sample() {
523             final int j = rng.nextInt() >>> 2;
524             if (j < t1) {
525                 return table1[j >>> 24];
526             }
527             if (j < t2) {
528                 return table2[(j - t1) >>> 18];
529             }
530             if (j < t3) {
531                 return table3[(j - t2) >>> 12];
532             }
533             if (j < t4) {
534                 return table4[(j - t3) >>> 6];
535             }
536             // Note the tables are filled on the assumption that the sum of the probabilities.
537             // is >=2^30. If this is not true then the final table table5 will be smaller by the
538             // difference. So the tables *must* be constructed correctly.
539             return table5[j - t4];
540         }
541 
542         @Override
543         public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
544             return new MarsagliaTsangWangBase64Int32DiscreteSampler(rng, this);
545         }
546     }
547 
548 
549 
550     /** Class contains only static methods. */
551     private MarsagliaTsangWangDiscreteSampler() {}
552 
553     /**
554      * Gets the k<sup>th</sup> base 64 digit of {@code m}.
555      *
556      * @param m the value m.
557      * @param k the digit.
558      * @return the base 64 digit
559      */
560     private static int getBase64Digit(int m, int k) {
561         return (m >>> (30 - 6 * k)) & 63;
562     }
563 
564     /**
565      * Convert the probability to an integer in the range [0,2^30]. This is the numerator of
566      * a fraction with assumed denominator 2<sup>30</sup>.
567      *
568      * @param p Probability.
569      * @return the fraction numerator
570      */
571     private static int toUnsignedInt30(double p) {
572         return (int) (p * INT_30 + 0.5);
573     }
574 
575     /**
576      * Create a new instance for probabilities {@code p(i)} where the sample value {@code x} is
577      * {@code i + offset}.
578      *
579      * <p>The sum of the probabilities must be {@code >=} 2<sup>30</sup>. Only the
580      * values for cumulative probability up to 2<sup>30</sup> will be sampled.</p>
581      *
582      * @param rng Generator of uniformly distributed random numbers.
583      * @param distributionName Distribution name.
584      * @param prob The probabilities.
585      * @param offset The offset (must be positive).
586      * @return Sampler.
587      */
588     private static SharedStateDiscreteSampler createSampler(UniformRandomProvider rng,
589                                                             String distributionName,
590                                                             int[] prob,
591                                                             int offset) {
592         // Note: No argument checks for private method.
593 
594         // Choose implementation based on the maximum index
595         final int maxIndex = prob.length + offset - 1;
596         if (maxIndex < INT_8) {
597             return new MarsagliaTsangWangBase64Int8DiscreteSampler(rng, distributionName, prob, offset);
598         }
599         if (maxIndex < INT_16) {
600             return new MarsagliaTsangWangBase64Int16DiscreteSampler(rng, distributionName, prob, offset);
601         }
602         return new MarsagliaTsangWangBase64Int32DiscreteSampler(rng, distributionName, prob, offset);
603     }
604 
605     // =========================================================================
606     // The following public classes provide factory methods to construct a sampler for:
607     // - Enumerated probability distribution (from provided double[] probabilities)
608     // - Poisson distribution for mean <= 1024
609     // - Binomial distribution for trials <= 65535
610     // =========================================================================
611 
612     /**
613      * Create a sampler for an enumerated distribution of {@code n} values each with an
614      * associated probability.
615      * The samples corresponding to each probability are assumed to be a natural sequence
616      * starting at zero.
617      */
618     public static final class Enumerated {
619         /** The name of the enumerated probability distribution. */
620         private static final String ENUMERATED_NAME = "Enumerated";
621 
622         /** Class contains only static methods. */
623         private Enumerated() {}
624 
625         /**
626          * Creates a sampler for an enumerated distribution of {@code n} values each with an
627          * associated probability.
628          *
629          * <p>The probabilities will be normalised using their sum. The only requirement
630          * is the sum is positive.</p>
631          *
632          * <p>The sum of the probabilities is normalised to 2<sup>30</sup>. Note that
633          * probabilities are adjusted to the nearest 2<sup>-30</sup> due to round-off during
634          * the normalisation conversion. Consequently any probability less than 2<sup>-31</sup>
635          * will not be observed in samples.</p>
636          *
637          * @param rng Generator of uniformly distributed random numbers.
638          * @param probabilities The list of probabilities.
639          * @return Sampler.
640          * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
641          * probability is negative, infinite or {@code NaN}, or the sum of all
642          * probabilities is not strictly positive.
643          */
644         public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
645                                                     double[] probabilities) {
646             return createSampler(rng, ENUMERATED_NAME, normaliseProbabilities(probabilities), 0);
647         }
648 
649         /**
650          * Normalise the probabilities to integers that sum to 2<sup>30</sup>.
651          *
652          * @param probabilities The list of probabilities.
653          * @return the normalised probabilities.
654          * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
655          * probability is negative, infinite or {@code NaN}, or the sum of all
656          * probabilities is not strictly positive.
657          */
658         private static int[] normaliseProbabilities(double[] probabilities) {
659             final double sumProb = InternalUtils.validateProbabilities(probabilities);
660 
661             // Compute the normalisation: 2^30 / sum
662             final double normalisation = INT_30 / sumProb;
663             final int[] prob = new int[probabilities.length];
664             int sum = 0;
665             int max = 0;
666             int mode = 0;
667             for (int i = 0; i < prob.length; i++) {
668                 // Add 0.5 for rounding
669                 final int p = (int) (probabilities[i] * normalisation + 0.5);
670                 sum += p;
671                 // Find the mode (maximum probability)
672                 if (max < p) {
673                     max = p;
674                     mode = i;
675                 }
676                 prob[i] = p;
677             }
678 
679             // The sum must be >= 2^30.
680             // Here just compensate the difference onto the highest probability.
681             prob[mode] += INT_30 - sum;
682 
683             return prob;
684         }
685     }
686 
687     /**
688      * Create a sampler for the Poisson distribution.
689      */
690     public static final class Poisson {
691         /** The name of the Poisson distribution. */
692         private static final String POISSON_NAME = "Poisson";
693 
694         /**
695          * Upper bound on the mean for the Poisson distribution.
696          *
697          * <p>The original source code provided in Marsaglia, et al (2004) has no explicit
698          * limit but the code fails at mean {@code >= 1941} as the transform to compute p(x=mode)
699          * produces infinity. Use a conservative limit of 1024.</p>
700          */
701 
702         private static final double MAX_MEAN = 1024;
703         /**
704          * The threshold for the mean of the Poisson distribution to switch the method used
705          * to compute the probabilities. This is taken from the example software provided by
706          * Marsaglia, et al (2004).
707          */
708         private static final double MEAN_THRESHOLD = 21.4;
709 
710         /** Class contains only static methods. */
711         private Poisson() {}
712 
713         /**
714          * Creates a sampler for the Poisson distribution.
715          *
716          * <p>Any probability less than 2<sup>-31</sup> will not be observed in samples.</p>
717          *
718          * <p>Storage requirements depend on the tabulated probability values. Example storage
719          * requirements are listed below.</p>
720          *
721          * <pre>
722          * mean      table size     kB
723          * 0.25      882            0.88
724          * 0.5       1135           1.14
725          * 1         1200           1.20
726          * 2         1451           1.45
727          * 4         1955           1.96
728          * 8         2961           2.96
729          * 16        4410           4.41
730          * 32        6115           6.11
731          * 64        8499           8.50
732          * 128       11528          11.53
733          * 256       15935          31.87
734          * 512       20912          41.82
735          * 1024      30614          61.23
736          * </pre>
737          *
738          * <p>Note: Storage changes to 2 bytes per index between {@code mean=128} and {@code mean=256}.</p>
739          *
740          * @param rng Generator of uniformly distributed random numbers.
741          * @param mean Mean.
742          * @return Sampler.
743          * @throws IllegalArgumentException if {@code mean <= 0} or {@code mean > 1024}.
744          */
745         public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
746                                                     double mean) {
747             validatePoissonDistributionParameters(mean);
748 
749             // Create the distribution either from X=0 or from X=mode when the mean is high.
750             return mean < MEAN_THRESHOLD ?
751                 createPoissonDistributionFromX0(rng, mean) :
752                 createPoissonDistributionFromXMode(rng, mean);
753         }
754 
755         /**
756          * Validate the Poisson distribution parameters.
757          *
758          * @param mean Mean.
759          * @throws IllegalArgumentException if {@code mean <= 0} or {@code mean > 1024}.
760          */
761         private static void validatePoissonDistributionParameters(double mean) {
762             InternalUtils.requireStrictlyPositive(mean, "mean");
763             if (mean > MAX_MEAN) {
764                 throw new IllegalArgumentException("mean " + mean + " > " + MAX_MEAN);
765             }
766         }
767 
768         /**
769          * Creates the Poisson distribution by computing probabilities recursively from {@code X=0}.
770          *
771          * @param rng Generator of uniformly distributed random numbers.
772          * @param mean Mean.
773          * @return Sampler.
774          */
775         private static SharedStateDiscreteSampler createPoissonDistributionFromX0(
776                 UniformRandomProvider rng, double mean) {
777             final double p0 = Math.exp(-mean);
778 
779             // Recursive update of Poisson probability until the value is too small
780             // p(x + 1) = p(x) * mean / (x + 1)
781             double p = p0;
782             int i = 1;
783             while (p * DOUBLE_31 >= 1) {
784                 p *= mean / i++;
785             }
786 
787             // Probabilities are 30-bit integers, assumed denominator 2^30
788             final int size = i - 1;
789             final int[] prob = new int[size];
790 
791             p = p0;
792             prob[0] = toUnsignedInt30(p);
793             // The sum must exceed 2^30. In edges cases this is false due to round-off.
794             int sum = prob[0];
795             for (i = 1; i < prob.length; i++) {
796                 p *= mean / i;
797                 prob[i] = toUnsignedInt30(p);
798                 sum += prob[i];
799             }
800 
801             // If the sum is < 2^30 add the remaining sum to the mode (floor(mean)).
802             prob[(int) mean] += Math.max(0, INT_30 - sum);
803 
804             // Note: offset = 0
805             return createSampler(rng, POISSON_NAME, prob, 0);
806         }
807 
808         /**
809          * Creates the Poisson distribution by computing probabilities recursively upward and downward
810          * from {@code X=mode}, the location of the largest p-value.
811          *
812          * @param rng Generator of uniformly distributed random numbers.
813          * @param mean Mean.
814          * @return Sampler.
815          */
816         private static SharedStateDiscreteSampler createPoissonDistributionFromXMode(
817                 UniformRandomProvider rng, double mean) {
818             // If mean >= 21.4, generate from largest p-value up, then largest down.
819             // The largest p-value will be at the mode (floor(mean)).
820 
821             // Find p(x=mode)
822             final int mode = (int) mean;
823             // This transform is stable until mean >= 1941 where p will result in Infinity
824             // before the divisor i is large enough to start reducing the product (i.e. i > c).
825             final double c = mean * Math.exp(-mean / mode);
826             double p = 1.0;
827             for (int i = 1; i <= mode; i++) {
828                 p *= c / i;
829             }
830             final double pMode = p;
831 
832             // Find the upper limit using recursive computation of the p-value.
833             // Note this will exit when i overflows to negative so no check on the range
834             int i = mode + 1;
835             while (p * DOUBLE_31 >= 1) {
836                 p *= mean / i++;
837             }
838             final int last = i - 2;
839 
840             // Find the lower limit using recursive computation of the p-value.
841             p = pMode;
842             int j = -1;
843             for (i = mode - 1; i >= 0; i--) {
844                 p *= (i + 1) / mean;
845                 if (p * DOUBLE_31 < 1) {
846                     j = i;
847                     break;
848                 }
849             }
850 
851             // Probabilities are 30-bit integers, assumed denominator 2^30.
852             // This is the minimum sample value: prob[x - offset] = p(x)
853             final int offset = j + 1;
854             final int size = last - offset + 1;
855             final int[] prob = new int[size];
856 
857             p = pMode;
858             prob[mode - offset] = toUnsignedInt30(p);
859             // The sum must exceed 2^30. In edges cases this is false due to round-off.
860             int sum = prob[mode - offset];
861             // From mode to upper limit
862             for (i = mode + 1; i <= last; i++) {
863                 p *= mean / i;
864                 prob[i - offset] = toUnsignedInt30(p);
865                 sum += prob[i - offset];
866             }
867             // From mode to lower limit
868             p = pMode;
869             for (i = mode - 1; i >= offset; i--) {
870                 p *= (i + 1) / mean;
871                 prob[i - offset] = toUnsignedInt30(p);
872                 sum += prob[i - offset];
873             }
874 
875             // If the sum is < 2^30 add the remaining sum to the mode.
876             // If above 2^30 then the effect is truncation of the long tail of the distribution.
877             prob[mode - offset] += Math.max(0, INT_30 - sum);
878 
879             return createSampler(rng, POISSON_NAME, prob, offset);
880         }
881     }
882 
883     /**
884      * Create a sampler for the Binomial distribution.
885      */
886     public static final class Binomial {
887         /** The name of the Binomial distribution. */
888         private static final String BINOMIAL_NAME = "Binomial";
889 
890         /**
891          * Return a fixed result for the Binomial distribution. This is a special class to handle
892          * an edge case of probability of success equal to 0 or 1.
893          */
894         private static final class MarsagliaTsangWangFixedResultBinomialSampler
895             extends AbstractMarsagliaTsangWangDiscreteSampler {
896             /** The result. */
897             private final int result;
898 
899             /**
900              * @param result Result.
901              */
902             MarsagliaTsangWangFixedResultBinomialSampler(int result) {
903                 super(null, BINOMIAL_NAME);
904                 this.result = result;
905             }
906 
907             @Override
908             public int sample() {
909                 return result;
910             }
911 
912             @Override
913             public String toString() {
914                 return BINOMIAL_NAME + " deviate";
915             }
916 
917             @Override
918             public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
919                 // No shared state
920                 return this;
921             }
922         }
923 
924         /**
925          * Return an inversion result for the Binomial distribution. This assumes the
926          * following:
927          *
928          * <pre>
929          * Binomial(n, p) = 1 - Binomial(n, 1 - p)
930          * </pre>
931          */
932         private static final class MarsagliaTsangWangInversionBinomialSampler
933             extends AbstractMarsagliaTsangWangDiscreteSampler {
934             /** The number of trials. */
935             private final int trials;
936             /** The Binomial distribution sampler. */
937             private final SharedStateDiscreteSampler sampler;
938 
939             /**
940              * @param trials Number of trials.
941              * @param sampler Binomial distribution sampler.
942              */
943             MarsagliaTsangWangInversionBinomialSampler(int trials,
944                                                        SharedStateDiscreteSampler sampler) {
945                 super(null, BINOMIAL_NAME);
946                 this.trials = trials;
947                 this.sampler = sampler;
948             }
949 
950             @Override
951             public int sample() {
952                 return trials - sampler.sample();
953             }
954 
955             @Override
956             public String toString() {
957                 return sampler.toString();
958             }
959 
960             @Override
961             public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
962                 return new MarsagliaTsangWangInversionBinomialSampler(this.trials,
963                     this.sampler.withUniformRandomProvider(rng));
964             }
965         }
966 
967         /** Class contains only static methods. */
968         private Binomial() {}
969 
970         /**
971          * Creates a sampler for the Binomial distribution.
972          *
973          * <p>Any probability less than 2<sup>-31</sup> will not be observed in samples.</p>
974          *
975          * <p>Storage requirements depend on the tabulated probability values. Example storage
976          * requirements are listed below (in kB).</p>
977          *
978          * <pre>
979          *          p
980          * trials   0.5    0.1   0.01  0.001
981          *    4    0.06   0.63   0.44   0.44
982          *   16    0.69   1.14   0.76   0.44
983          *   64    4.73   2.40   1.14   0.51
984          *  256    8.63   5.17   1.89   0.82
985          * 1024   31.12   9.45   3.34   0.89
986          * </pre>
987          *
988          * <p>The method requires that the Binomial distribution probability at {@code x=0} can be computed.
989          * This will fail when {@code (1 - p)^trials == 0} which requires {@code trials} to be large
990          * and/or {@code p} to be small. In this case an exception is raised.</p>
991          *
992          * @param rng Generator of uniformly distributed random numbers.
993          * @param trials Number of trials.
994          * @param probabilityOfSuccess Probability of success (p).
995          * @return Sampler.
996          * @throws IllegalArgumentException if {@code trials < 0} or {@code trials >= 2^16},
997          * {@code p} is not in the range {@code [0-1]}, or the probability distribution cannot
998          * be computed.
999          */
1000         public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
1001                                                     int trials,
1002                                                     double probabilityOfSuccess) {
1003             validateBinomialDistributionParameters(trials, probabilityOfSuccess);
1004 
1005             // Handle edge cases
1006             if (probabilityOfSuccess == 0) {
1007                 return new MarsagliaTsangWangFixedResultBinomialSampler(0);
1008             }
1009             if (probabilityOfSuccess == 1) {
1010                 return new MarsagliaTsangWangFixedResultBinomialSampler(trials);
1011             }
1012 
1013             // Check the supported size.
1014             if (trials >= INT_16) {
1015                 throw new IllegalArgumentException("Unsupported number of trials: " + trials);
1016             }
1017 
1018             return createBinomialDistributionSampler(rng, trials, probabilityOfSuccess);
1019         }
1020 
1021         /**
1022          * Validate the Binomial distribution parameters.
1023          *
1024          * @param trials Number of trials.
1025          * @param probabilityOfSuccess Probability of success (p).
1026          * @throws IllegalArgumentException if {@code trials < 0} or
1027          * {@code p} is not in the range {@code [0-1]}
1028          */
1029         private static void validateBinomialDistributionParameters(int trials, double probabilityOfSuccess) {
1030             if (trials < 0) {
1031                 throw new IllegalArgumentException("Trials is not positive: " + trials);
1032             }
1033             InternalUtils.requireRangeClosed(0, 1, probabilityOfSuccess, "probability of success");
1034         }
1035 
1036         /**
1037          * Creates the Binomial distribution sampler.
1038          *
1039          * <p>This assumes the parameters for the distribution are valid. The method
1040          * will only fail if the initial probability for {@code X=0} is zero.</p>
1041          *
1042          * @param rng Generator of uniformly distributed random numbers.
1043          * @param trials Number of trials.
1044          * @param probabilityOfSuccess Probability of success (p).
1045          * @return Sampler.
1046          * @throws IllegalArgumentException if the probability distribution cannot be
1047          * computed.
1048          */
1049         private static SharedStateDiscreteSampler createBinomialDistributionSampler(
1050                 UniformRandomProvider rng, int trials, double probabilityOfSuccess) {
1051 
1052             // The maximum supported value for Math.exp is approximately -744.
1053             // This occurs when trials is large and p is close to 1.
1054             // Handle this by using an inversion: generate j=Binomial(n,1-p), return n-j
1055             final boolean useInversion = probabilityOfSuccess > 0.5;
1056             final double p = useInversion ? 1 - probabilityOfSuccess : probabilityOfSuccess;
1057 
1058             // Check if the distribution can be computed
1059             final double p0 = Math.exp(trials * Math.log(1 - p));
1060             if (p0 < Double.MIN_VALUE) {
1061                 throw new IllegalArgumentException("Unable to compute distribution");
1062             }
1063 
1064             // First find size of probability array
1065             double t = p0;
1066             final double h = p / (1 - p);
1067             // Find first probability above the threshold of 2^-31
1068             int begin = 0;
1069             if (t * DOUBLE_31 < 1) {
1070                 // Somewhere after p(0)
1071                 // Note:
1072                 // If this loop is entered p(0) is < 2^-31.
1073                 // This has been tested at the extreme for p(0)=Double.MIN_VALUE and either
1074                 // p=0.5 or trials=2^16-1 and does not fail to find the beginning.
1075                 for (int i = 1; i <= trials; i++) {
1076                     t *= (trials + 1 - i) * h / i;
1077                     if (t * DOUBLE_31 >= 1) {
1078                         begin = i;
1079                         break;
1080                     }
1081                 }
1082             }
1083             // Find last probability
1084             int end = trials;
1085             for (int i = begin + 1; i <= trials; i++) {
1086                 t *= (trials + 1 - i) * h / i;
1087                 if (t * DOUBLE_31 < 1) {
1088                     end = i - 1;
1089                     break;
1090                 }
1091             }
1092 
1093             return createBinomialDistributionSamplerFromRange(rng, trials, p, useInversion,
1094                     p0, begin, end);
1095         }
1096 
1097         /**
1098          * Creates the Binomial distribution sampler using only the probability values for {@code X}
1099          * between the begin and the end (inclusive).
1100          *
1101          * @param rng Generator of uniformly distributed random numbers.
1102          * @param trials Number of trials.
1103          * @param p Probability of success (p).
1104          * @param useInversion Set to {@code true} if the probability was inverted.
1105          * @param p0 Probability at {@code X=0}
1106          * @param begin Begin value {@code X} for the distribution.
1107          * @param end End value {@code X} for the distribution.
1108          * @return Sampler.
1109          */
1110         private static SharedStateDiscreteSampler createBinomialDistributionSamplerFromRange(
1111                 UniformRandomProvider rng, int trials, double p,
1112                 boolean useInversion, double p0, int begin, int end) {
1113 
1114             // Assign probability values as 30-bit integers
1115             final int size = end - begin + 1;
1116             final int[] prob = new int[size];
1117             double t = p0;
1118             final double h = p / (1 - p);
1119             for (int i = 1; i <= begin; i++) {
1120                 t *= (trials + 1 - i) * h / i;
1121             }
1122             int sum = toUnsignedInt30(t);
1123             prob[0] = sum;
1124             for (int i = begin + 1; i <= end; i++) {
1125                 t *= (trials + 1 - i) * h / i;
1126                 prob[i - begin] = toUnsignedInt30(t);
1127                 sum += prob[i - begin];
1128             }
1129 
1130             // If the sum is < 2^30 add the remaining sum to the mode (floor((n+1)p))).
1131             // If above 2^30 then the effect is truncation of the long tail of the distribution.
1132             final int mode = (int) ((trials + 1) * p) - begin;
1133             prob[mode] += Math.max(0, INT_30 - sum);
1134 
1135             final SharedStateDiscreteSampler sampler = createSampler(rng, BINOMIAL_NAME, prob, begin);
1136 
1137             // Check if an inversion was made
1138             return useInversion ?
1139                    new MarsagliaTsangWangInversionBinomialSampler(trials, sampler) :
1140                    sampler;
1141         }
1142     }
1143 }