001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.rng.sampling.distribution;
018
019import org.apache.commons.rng.UniformRandomProvider;
020import org.apache.commons.rng.sampling.distribution.InternalUtils.FactorialLog;
021
022/**
023 * Sampler for the <a href="http://mathworld.wolfram.com/PoissonDistribution.html">Poisson distribution</a>.
024 *
025 * <ul>
026 *  <li>
027 *   For large means, we use the rejection algorithm described in
028 *   <blockquote>
029 *    Devroye, Luc. (1981).<i>The Computer Generation of Poisson Random Variables</i><br>
030 *    <strong>Computing</strong> vol. 26 pp. 197-207.
031 *   </blockquote>
032 *  </li>
033 * </ul>
034 *
035 * <p>This sampler is suitable for {@code mean >= 40}.</p>
036 *
037 * <p>Sampling uses:</p>
038 *
039 * <ul>
040 *   <li>{@link UniformRandomProvider#nextLong()}
041 *   <li>{@link UniformRandomProvider#nextDouble()}
042 * </ul>
043 *
044 * @since 1.1
045 */
046public class LargeMeanPoissonSampler
047    implements SharedStateDiscreteSampler {
048    /** Upper bound to avoid truncation. */
049    private static final double MAX_MEAN = 0.5 * Integer.MAX_VALUE;
050    /** Class to compute {@code log(n!)}. This has no cached values. */
051    private static final InternalUtils.FactorialLog NO_CACHE_FACTORIAL_LOG;
052    /** Used when there is no requirement for a small mean Poisson sampler. */
053    private static final SharedStateDiscreteSampler NO_SMALL_MEAN_POISSON_SAMPLER =
054        new SharedStateDiscreteSampler() {
055            @Override
056            public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
057                // No requirement for RNG
058                return this;
059            }
060
061            @Override
062            public int sample() {
063                // No Poisson sample
064                return 0;
065            }
066        };
067
068    static {
069        // Create without a cache.
070        NO_CACHE_FACTORIAL_LOG = FactorialLog.create();
071    }
072
073    /** Underlying source of randomness. */
074    private final UniformRandomProvider rng;
075    /** Exponential. */
076    private final SharedStateContinuousSampler exponential;
077    /** Gaussian. */
078    private final SharedStateContinuousSampler gaussian;
079    /** Local class to compute {@code log(n!)}. This may have cached values. */
080    private final InternalUtils.FactorialLog factorialLog;
081
082    // Working values
083
084    /** Algorithm constant: {@code Math.floor(mean)}. */
085    private final double lambda;
086    /** Algorithm constant: {@code Math.log(lambda)}. */
087    private final double logLambda;
088    /** Algorithm constant: {@code factorialLog((int) lambda)}. */
089    private final double logLambdaFactorial;
090    /** Algorithm constant: {@code Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1))}. */
091    private final double delta;
092    /** Algorithm constant: {@code delta / 2}. */
093    private final double halfDelta;
094    /** Algorithm constant: {@code Math.sqrt(lambda + halfDelta)}. */
095    private final double sqrtLambdaPlusHalfDelta;
096    /** Algorithm constant: {@code 2 * lambda + delta}. */
097    private final double twolpd;
098    /**
099     * Algorithm constant: {@code a1 / aSum}.
100     * <ul>
101     *  <li>{@code a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1)}</li>
102     *  <li>{@code aSum = a1 + a2 + 1}</li>
103     * </ul>
104     */
105    private final double p1;
106    /**
107     * Algorithm constant: {@code a2 / aSum}.
108     * <ul>
109     *  <li>{@code a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd)}</li>
110     *  <li>{@code aSum = a1 + a2 + 1}</li>
111     * </ul>
112     */
113    private final double p2;
114    /** Algorithm constant: {@code 1 / (8 * lambda)}. */
115    private final double c1;
116
117    /** The internal Poisson sampler for the lambda fraction. */
118    private final SharedStateDiscreteSampler smallMeanPoissonSampler;
119
120
121    /**
122     * Create an instance.
123     *
124     * @param rng Generator of uniformly distributed random numbers.
125     * @param mean Mean.
126     * @throws IllegalArgumentException if {@code mean < 1} or
127     * {@code mean > 0.5 *} {@link Integer#MAX_VALUE}.
128     */
129    public LargeMeanPoissonSampler(UniformRandomProvider rng,
130                                   double mean) {
131        // Validation before java.lang.Object constructor exits prevents partially initialized object
132        this(InternalUtils.requireRangeClosed(1, MAX_MEAN, mean, "mean"), rng);
133    }
134
135    /**
136     * Instantiates a sampler using a precomputed state.
137     *
138     * @param rng              Generator of uniformly distributed random numbers.
139     * @param state            The state for {@code lambda = (int)Math.floor(mean)}.
140     * @param lambdaFractional The lambda fractional value
141     *                         ({@code mean - (int)Math.floor(mean))}.
142     * @throws IllegalArgumentException
143     *                         if {@code lambdaFractional < 0 || lambdaFractional >= 1}.
144     */
145    LargeMeanPoissonSampler(UniformRandomProvider rng,
146                            LargeMeanPoissonSamplerState state,
147                            double lambdaFractional) {
148        // Validation before java.lang.Object constructor exits prevents partially initialized object
149        this(state, InternalUtils.requireRange(0, 1, lambdaFractional, "lambdaFractional"), rng);
150    }
151
152    /**
153     * @param mean Mean.
154     * @param rng Generator of uniformly distributed random numbers.
155     */
156    private LargeMeanPoissonSampler(double mean,
157                                    UniformRandomProvider rng) {
158        this.rng = rng;
159
160        gaussian = ZigguratSampler.NormalizedGaussian.of(rng);
161        exponential = ZigguratSampler.Exponential.of(rng);
162        // Plain constructor uses the uncached function.
163        factorialLog = NO_CACHE_FACTORIAL_LOG;
164
165        // Cache values used in the algorithm
166        lambda = Math.floor(mean);
167        logLambda = Math.log(lambda);
168        logLambdaFactorial = getFactorialLog((int) lambda);
169        delta = Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1));
170        halfDelta = delta / 2;
171        sqrtLambdaPlusHalfDelta = Math.sqrt(lambda + halfDelta);
172        twolpd = 2 * lambda + delta;
173        c1 = 1 / (8 * lambda);
174        final double a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1);
175        final double a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd);
176        final double aSum = a1 + a2 + 1;
177        p1 = a1 / aSum;
178        p2 = a2 / aSum;
179
180        // The algorithm requires a Poisson sample from the remaining lambda fraction.
181        final double lambdaFractional = mean - lambda;
182        smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ?
183            NO_SMALL_MEAN_POISSON_SAMPLER : // Not used.
184            KempSmallMeanPoissonSampler.of(rng, lambdaFractional);
185    }
186
187    /**
188     * Instantiates a sampler using a precomputed state.
189     *
190     * @param state            The state for {@code lambda = (int)Math.floor(mean)}.
191     * @param lambdaFractional The lambda fractional value
192     *                         ({@code mean - (int)Math.floor(mean))}.
193     * @param rng              Generator of uniformly distributed random numbers.
194     */
195    private LargeMeanPoissonSampler(LargeMeanPoissonSamplerState state,
196                                    double lambdaFractional,
197                                    UniformRandomProvider rng) {
198        this.rng = rng;
199
200        gaussian = ZigguratSampler.NormalizedGaussian.of(rng);
201        exponential = ZigguratSampler.Exponential.of(rng);
202        // Plain constructor uses the uncached function.
203        factorialLog = NO_CACHE_FACTORIAL_LOG;
204
205        // Use the state to initialize the algorithm
206        lambda = state.getLambdaRaw();
207        logLambda = state.getLogLambda();
208        logLambdaFactorial = state.getLogLambdaFactorial();
209        delta = state.getDelta();
210        halfDelta = state.getHalfDelta();
211        sqrtLambdaPlusHalfDelta = state.getSqrtLambdaPlusHalfDelta();
212        twolpd = state.getTwolpd();
213        p1 = state.getP1();
214        p2 = state.getP2();
215        c1 = state.getC1();
216
217        // The algorithm requires a Poisson sample from the remaining lambda fraction.
218        smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ?
219            NO_SMALL_MEAN_POISSON_SAMPLER : // Not used.
220            KempSmallMeanPoissonSampler.of(rng, lambdaFractional);
221    }
222
223    /**
224     * @param rng Generator of uniformly distributed random numbers.
225     * @param source Source to copy.
226     */
227    private LargeMeanPoissonSampler(UniformRandomProvider rng,
228                                    LargeMeanPoissonSampler source) {
229        this.rng = rng;
230
231        gaussian = source.gaussian.withUniformRandomProvider(rng);
232        exponential = source.exponential.withUniformRandomProvider(rng);
233        // Reuse the cache
234        factorialLog = source.factorialLog;
235
236        lambda = source.lambda;
237        logLambda = source.logLambda;
238        logLambdaFactorial = source.logLambdaFactorial;
239        delta = source.delta;
240        halfDelta = source.halfDelta;
241        sqrtLambdaPlusHalfDelta = source.sqrtLambdaPlusHalfDelta;
242        twolpd = source.twolpd;
243        p1 = source.p1;
244        p2 = source.p2;
245        c1 = source.c1;
246
247        // Share the state of the small sampler
248        smallMeanPoissonSampler = source.smallMeanPoissonSampler.withUniformRandomProvider(rng);
249    }
250
251    /** {@inheritDoc} */
252    @Override
253    public int sample() {
254        // This will never be null. It may be a no-op delegate that returns zero.
255        final int y2 = smallMeanPoissonSampler.sample();
256
257        double x;
258        double y;
259        double v;
260        int a;
261        double t;
262        double qr;
263        double qa;
264        while (true) {
265            // Step 1:
266            final double u = rng.nextDouble();
267            if (u <= p1) {
268                // Step 2:
269                final double n = gaussian.sample();
270                x = n * sqrtLambdaPlusHalfDelta - 0.5d;
271                if (x > delta || x < -lambda) {
272                    continue;
273                }
274                y = x < 0 ? Math.floor(x) : Math.ceil(x);
275                final double e = exponential.sample();
276                v = -e - 0.5 * n * n + c1;
277            } else {
278                // Step 3:
279                if (u > p1 + p2) {
280                    y = lambda;
281                    break;
282                }
283                x = delta + (twolpd / delta) * exponential.sample();
284                y = Math.ceil(x);
285                v = -exponential.sample() - delta * (x + 1) / twolpd;
286            }
287            // The Squeeze Principle
288            // Step 4.1:
289            a = x < 0 ? 1 : 0;
290            t = y * (y + 1) / (2 * lambda);
291            // Step 4.2
292            if (v < -t && a == 0) {
293                y = lambda + y;
294                break;
295            }
296            // Step 4.3:
297            qr = t * ((2 * y + 1) / (6 * lambda) - 1);
298            qa = qr - (t * t) / (3 * (lambda + a * (y + 1)));
299            // Step 4.4:
300            if (v < qa) {
301                y = lambda + y;
302                break;
303            }
304            // Step 4.5:
305            if (v > qr) {
306                continue;
307            }
308            // Step 4.6:
309            if (v < y * logLambda - getFactorialLog((int) (y + lambda)) + logLambdaFactorial) {
310                y = lambda + y;
311                break;
312            }
313        }
314
315        return (int) Math.min(y2 + (long) y, Integer.MAX_VALUE);
316    }
317
318    /**
319     * Compute the natural logarithm of the factorial of {@code n}.
320     *
321     * @param n Argument.
322     * @return {@code log(n!)}
323     * @throws IllegalArgumentException if {@code n < 0}.
324     */
325    private double getFactorialLog(int n) {
326        return factorialLog.value(n);
327    }
328
329    /** {@inheritDoc} */
330    @Override
331    public String toString() {
332        return "Large Mean Poisson deviate [" + rng.toString() + "]";
333    }
334
335    /**
336     * {@inheritDoc}
337     *
338     * @since 1.3
339     */
340    @Override
341    public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
342        return new LargeMeanPoissonSampler(rng, this);
343    }
344
345    /**
346     * Creates a new Poisson distribution sampler.
347     *
348     * @param rng Generator of uniformly distributed random numbers.
349     * @param mean Mean.
350     * @return the sampler
351     * @throws IllegalArgumentException if {@code mean < 1} or {@code mean > 0.5 *}
352     * {@link Integer#MAX_VALUE}.
353     * @since 1.3
354     */
355    public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
356                                                double mean) {
357        return new LargeMeanPoissonSampler(rng, mean);
358    }
359
360    /**
361     * Gets the initialisation state of the sampler.
362     *
363     * <p>The state is computed using an integer {@code lambda} value of
364     * {@code lambda = (int)Math.floor(mean)}.
365     *
366     * <p>The state will be suitable for reconstructing a new sampler with a mean
367     * in the range {@code lambda <= mean < lambda+1} using
368     * {@link #LargeMeanPoissonSampler(UniformRandomProvider, LargeMeanPoissonSamplerState, double)}.
369     *
370     * @return the state
371     */
372    LargeMeanPoissonSamplerState getState() {
373        return new LargeMeanPoissonSamplerState(lambda, logLambda, logLambdaFactorial,
374                delta, halfDelta, sqrtLambdaPlusHalfDelta, twolpd, p1, p2, c1);
375    }
376
377    /**
378     * Encapsulate the state of the sampler. The state is valid for construction of
379     * a sampler in the range {@code lambda <= mean < lambda+1}.
380     *
381     * <p>This class is immutable.
382     *
383     * @see #getLambda()
384     */
385    static final class LargeMeanPoissonSamplerState {
386        /** Algorithm constant {@code lambda}. */
387        private final double lambda;
388        /** Algorithm constant {@code logLambda}. */
389        private final double logLambda;
390        /** Algorithm constant {@code logLambdaFactorial}. */
391        private final double logLambdaFactorial;
392        /** Algorithm constant {@code delta}. */
393        private final double delta;
394        /** Algorithm constant {@code halfDelta}. */
395        private final double halfDelta;
396        /** Algorithm constant {@code sqrtLambdaPlusHalfDelta}. */
397        private final double sqrtLambdaPlusHalfDelta;
398        /** Algorithm constant {@code twolpd}. */
399        private final double twolpd;
400        /** Algorithm constant {@code p1}. */
401        private final double p1;
402        /** Algorithm constant {@code p2}. */
403        private final double p2;
404        /** Algorithm constant {@code c1}. */
405        private final double c1;
406
407        /**
408         * Creates the state.
409         *
410         * <p>The state is valid for construction of a sampler in the range
411         * {@code lambda <= mean < lambda+1} where {@code lambda} is an integer.
412         *
413         * @param lambda the lambda
414         * @param logLambda the log lambda
415         * @param logLambdaFactorial the log lambda factorial
416         * @param delta the delta
417         * @param halfDelta the half delta
418         * @param sqrtLambdaPlusHalfDelta the sqrt(lambda+half delta)
419         * @param twolpd the two lambda plus delta
420         * @param p1 the p1 constant
421         * @param p2 the p2 constant
422         * @param c1 the c1 constant
423         */
424        LargeMeanPoissonSamplerState(double lambda, double logLambda,
425                double logLambdaFactorial, double delta, double halfDelta,
426                double sqrtLambdaPlusHalfDelta, double twolpd,
427                double p1, double p2, double c1) {
428            this.lambda = lambda;
429            this.logLambda = logLambda;
430            this.logLambdaFactorial = logLambdaFactorial;
431            this.delta = delta;
432            this.halfDelta = halfDelta;
433            this.sqrtLambdaPlusHalfDelta = sqrtLambdaPlusHalfDelta;
434            this.twolpd = twolpd;
435            this.p1 = p1;
436            this.p2 = p2;
437            this.c1 = c1;
438        }
439
440        /**
441         * Get the lambda value for the state.
442         *
443         * <p>Equal to {@code floor(mean)} for a Poisson sampler.
444         * @return the lambda value
445         */
446        int getLambda() {
447            return (int) getLambdaRaw();
448        }
449
450        /**
451         * @return algorithm constant {@code lambda}
452         */
453        double getLambdaRaw() {
454            return lambda;
455        }
456
457        /**
458         * @return algorithm constant {@code logLambda}
459         */
460        double getLogLambda() {
461            return logLambda;
462        }
463
464        /**
465         * @return algorithm constant {@code logLambdaFactorial}
466         */
467        double getLogLambdaFactorial() {
468            return logLambdaFactorial;
469        }
470
471        /**
472         * @return algorithm constant {@code delta}
473         */
474        double getDelta() {
475            return delta;
476        }
477
478        /**
479         * @return algorithm constant {@code halfDelta}
480         */
481        double getHalfDelta() {
482            return halfDelta;
483        }
484
485        /**
486         * @return algorithm constant {@code sqrtLambdaPlusHalfDelta}
487         */
488        double getSqrtLambdaPlusHalfDelta() {
489            return sqrtLambdaPlusHalfDelta;
490        }
491
492        /**
493         * @return algorithm constant {@code twolpd}
494         */
495        double getTwolpd() {
496            return twolpd;
497        }
498
499        /**
500         * @return algorithm constant {@code p1}
501         */
502        double getP1() {
503            return p1;
504        }
505
506        /**
507         * @return algorithm constant {@code p2}
508         */
509        double getP2() {
510            return p2;
511        }
512
513        /**
514         * @return algorithm constant {@code c1}
515         */
516        double getC1() {
517            return c1;
518        }
519    }
520}