001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.statistics.distribution;
019
020import java.util.function.DoublePredicate;
021
022/**
023 * Implementation of the hypergeometric distribution.
024 *
025 * <p>The probability mass function of \( X \) is:
026 *
027 * <p>\[ f(k; N, K, n) = \frac{\binom{K}{k} \binom{N - K}{n-k}}{\binom{N}{n}} \]
028 *
 * <p>for \( N \in \{1, 2, \dots\} \) the population size,
030 * \( K \in \{0, 1, \dots, N\} \) the number of success states,
031 * \( n \in \{0, 1, \dots, N\} \) the number of samples,
032 * \( k \in \{\max(0, n+K-N), \dots, \min(n, K)\} \) the number of successes, and
033 *
034 * <p>\[ \binom{a}{b} = \frac{a!}{b! \, (a-b)!} \]
035 *
036 * <p>is the binomial coefficient.
037 *
038 * @see <a href="https://en.wikipedia.org/wiki/Hypergeometric_distribution">Hypergeometric distribution (Wikipedia)</a>
039 * @see <a href="https://mathworld.wolfram.com/HypergeometricDistribution.html">Hypergeometric distribution (MathWorld)</a>
040 */
041public final class HypergeometricDistribution extends AbstractDiscreteDistribution {
042    /** 1/2. */
043    private static final double HALF = 0.5;
044    /** The number of successes in the population. */
045    private final int numberOfSuccesses;
046    /** The population size. */
047    private final int populationSize;
048    /** The sample size. */
049    private final int sampleSize;
050    /** The lower bound of the support (inclusive). */
051    private final int lowerBound;
052    /** The upper bound of the support (inclusive). */
053    private final int upperBound;
054    /** Binomial probability of success (sampleSize / populationSize). */
055    private final double bp;
056    /** Binomial probability of failure ((populationSize - sampleSize) / populationSize). */
057    private final double bq;
058    /** Cached midpoint of the CDF/SF. The array holds [x, cdf(x)] for the midpoint x.
059     * Used for the cumulative probability functions. */
060    private double[] midpoint;
061
062    /**
063     * @param populationSize Population size.
064     * @param numberOfSuccesses Number of successes in the population.
065     * @param sampleSize Sample size.
066     */
067    private HypergeometricDistribution(int populationSize,
068                                       int numberOfSuccesses,
069                                       int sampleSize) {
070        this.numberOfSuccesses = numberOfSuccesses;
071        this.populationSize = populationSize;
072        this.sampleSize = sampleSize;
073        lowerBound = getLowerDomain(populationSize, numberOfSuccesses, sampleSize);
074        upperBound = getUpperDomain(numberOfSuccesses, sampleSize);
075        bp = (double) sampleSize / populationSize;
076        bq = (double) (populationSize - sampleSize) / populationSize;
077    }
078
079    /**
080     * Creates a hypergeometric distribution.
081     *
082     * @param populationSize Population size.
083     * @param numberOfSuccesses Number of successes in the population.
084     * @param sampleSize Sample size.
085     * @return the distribution
086     * @throws IllegalArgumentException if {@code numberOfSuccesses < 0}, or
087     * {@code populationSize <= 0} or {@code numberOfSuccesses > populationSize}, or
088     * {@code sampleSize > populationSize}.
089     */
090    public static HypergeometricDistribution of(int populationSize,
091                                                int numberOfSuccesses,
092                                                int sampleSize) {
093        if (populationSize <= 0) {
094            throw new DistributionException(DistributionException.NOT_STRICTLY_POSITIVE,
095                                            populationSize);
096        }
097        if (numberOfSuccesses < 0) {
098            throw new DistributionException(DistributionException.NEGATIVE,
099                                            numberOfSuccesses);
100        }
101        if (sampleSize < 0) {
102            throw new DistributionException(DistributionException.NEGATIVE,
103                                            sampleSize);
104        }
105
106        if (numberOfSuccesses > populationSize) {
107            throw new DistributionException(DistributionException.TOO_LARGE,
108                                            numberOfSuccesses, populationSize);
109        }
110        if (sampleSize > populationSize) {
111            throw new DistributionException(DistributionException.TOO_LARGE,
112                                            sampleSize, populationSize);
113        }
114        return new HypergeometricDistribution(populationSize, numberOfSuccesses, sampleSize);
115    }
116
117    /**
118     * Return the lowest domain value for the given hypergeometric distribution
119     * parameters.
120     *
121     * @param nn Population size.
122     * @param k Number of successes in the population.
123     * @param n Sample size.
124     * @return the lowest domain value of the hypergeometric distribution.
125     */
126    private static int getLowerDomain(int nn, int k, int n) {
127        // Avoid overflow given N > n:
128        // n + K - N == K - (N - n)
129        return Math.max(0, k - (nn - n));
130    }
131
132    /**
133     * Return the highest domain value for the given hypergeometric distribution
134     * parameters.
135     *
136     * @param k Number of successes in the population.
137     * @param n Sample size.
138     * @return the highest domain value of the hypergeometric distribution.
139     */
140    private static int getUpperDomain(int k, int n) {
141        return Math.min(n, k);
142    }
143
144    /**
145     * Gets the population size parameter of this distribution.
146     *
147     * @return the population size.
148     */
149    public int getPopulationSize() {
150        return populationSize;
151    }
152
153    /**
154     * Gets the number of successes parameter of this distribution.
155     *
156     * @return the number of successes.
157     */
158    public int getNumberOfSuccesses() {
159        return numberOfSuccesses;
160    }
161
162    /**
163     * Gets the sample size parameter of this distribution.
164     *
165     * @return the sample size.
166     */
167    public int getSampleSize() {
168        return sampleSize;
169    }
170
171    /** {@inheritDoc} */
172    @Override
173    public double probability(int x) {
174        return Math.exp(logProbability(x));
175    }
176
177    /** {@inheritDoc} */
178    @Override
179    public double probability(int x0, int x1) {
180        if (x0 > x1) {
181            throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH, x0, x1);
182        }
183        if (x0 == x1 || x1 < lowerBound) {
184            return 0;
185        }
186        // If the range is outside the bounds use the appropriate cumulative probability
187        if (x0 < lowerBound) {
188            return cumulativeProbability(x1);
189        }
190        if (x1 >= upperBound) {
191            // 1 - cdf(x0)
192            return survivalProbability(x0);
193        }
194        // Here: lower <= x0 < x1 < upper:
195        // sum(pdf(x)) for x in (x0, x1]
196        final int lo = x0 + 1;
197        // Sum small values first by starting at the point the greatest distance from the mode.
198        final int mode = (int) Math.floor((sampleSize + 1.0) * (numberOfSuccesses + 1.0) / (populationSize + 2.0));
199        return Math.abs(mode - lo) > Math.abs(mode - x1) ?
200            innerCumulativeProbability(lo, x1) :
201            innerCumulativeProbability(x1, lo);
202    }
203
204    /** {@inheritDoc} */
205    @Override
206    public double logProbability(int x) {
207        if (x < lowerBound || x > upperBound) {
208            return Double.NEGATIVE_INFINITY;
209        }
210        return computeLogProbability(x);
211    }
212
213    /**
214     * Compute the log probability.
215     *
216     * @param x Value.
217     * @return log(P(X = x))
218     */
219    private double computeLogProbability(int x) {
220        final double p1 =
221                SaddlePointExpansionUtils.logBinomialProbability(x, numberOfSuccesses, bp, bq);
222        final double p2 =
223                SaddlePointExpansionUtils.logBinomialProbability(sampleSize - x,
224                        populationSize - numberOfSuccesses, bp, bq);
225        final double p3 =
226                SaddlePointExpansionUtils.logBinomialProbability(sampleSize, populationSize, bp, bq);
227        return p1 + p2 - p3;
228    }
229
230    /** {@inheritDoc} */
231    @Override
232    public double cumulativeProbability(int x) {
233        if (x < lowerBound) {
234            return 0.0;
235        } else if (x >= upperBound) {
236            return 1.0;
237        }
238        final double[] mid = getMidPoint();
239        final int m = (int) mid[0];
240        if (x < m) {
241            return innerCumulativeProbability(lowerBound, x);
242        } else if (x > m) {
243            return 1 - innerCumulativeProbability(upperBound, x + 1);
244        }
245        // cdf(x)
246        return mid[1];
247    }
248
249    /** {@inheritDoc} */
250    @Override
251    public double survivalProbability(int x) {
252        if (x < lowerBound) {
253            return 1.0;
254        } else if (x >= upperBound) {
255            return 0.0;
256        }
257        final double[] mid = getMidPoint();
258        final int m = (int) mid[0];
259        if (x < m) {
260            return 1 - innerCumulativeProbability(lowerBound, x);
261        } else if (x > m) {
262            return innerCumulativeProbability(upperBound, x + 1);
263        }
264        // 1 - cdf(x)
265        return 1 - mid[1];
266    }
267
268    /**
269     * For this distribution, {@code X}, this method returns
270     * {@code P(x0 <= X <= x1)}.
271     * This probability is computed by summing the point probabilities for the
272     * values {@code x0, x0 + dx, x0 + 2 * dx, ..., x1}; the direction {@code dx} is determined
273     * using a comparison of the input bounds.
274     * This should be called by using {@code x0} as the domain limit and {@code x1}
275     * as the internal value. This will result in an initial sum of increasing larger magnitudes.
276     *
277     * @param x0 Inclusive domain bound.
278     * @param x1 Inclusive internal bound.
279     * @return {@code P(x0 <= X <= x1)}.
280     */
281    private double innerCumulativeProbability(int x0, int x1) {
282        // Assume the range is within the domain.
283        // Reuse the computation for probability(x) but avoid checking the domain for each call.
284        int x = x0;
285        double ret = Math.exp(computeLogProbability(x));
286        if (x0 < x1) {
287            while (x != x1) {
288                x++;
289                ret += Math.exp(computeLogProbability(x));
290            }
291        } else {
292            while (x != x1) {
293                x--;
294                ret += Math.exp(computeLogProbability(x));
295            }
296        }
297        return ret;
298    }
299
300    @Override
301    public int inverseCumulativeProbability(double p) {
302        ArgumentUtils.checkProbability(p);
303        return computeInverseProbability(p, 1 - p, false);
304    }
305
306    @Override
307    public int inverseSurvivalProbability(double p) {
308        ArgumentUtils.checkProbability(p);
309        return computeInverseProbability(1 - p, p, true);
310    }
311
312    /**
313     * Implementation for the inverse cumulative or survival probability.
314     *
315     * @param p Cumulative probability.
316     * @param q Survival probability.
317     * @param complement Set to true to compute the inverse survival probability.
318     * @return the value
319     */
320    private int computeInverseProbability(double p, double q, boolean complement) {
321        if (p == 0) {
322            return lowerBound;
323        }
324        if (q == 0) {
325            return upperBound;
326        }
327
328        // Sum the PDF(x) until the appropriate p-value is obtained
329        // CDF: require smallest x where P(X<=x) >= p
330        // SF:  require smallest x where P(X>x) <= q
331        // The choice of summation uses the mid-point.
332        // The test on the CDF or SF is based on the appropriate input p-value.
333
334        final double[] mid = getMidPoint();
335        final int m = (int) mid[0];
336        final double mp = mid[1];
337
338        final int midPointComparison = complement ?
339            Double.compare(1 - mp, q) :
340            Double.compare(p, mp);
341
342        if (midPointComparison < 0) {
343            return inverseLower(p, q, complement);
344        } else if (midPointComparison > 0) {
345            // Avoid floating-point summation error when the mid-point computed using the
346            // lower sum is different to the midpoint computed using the upper sum.
347            // Here we know the result must be above the midpoint so we can clip the result.
348            return Math.max(m + 1, inverseUpper(p, q, complement));
349        }
350        // Exact mid-point
351        return m;
352    }
353
354    /**
355     * Compute the inverse cumulative or survival probability using the lower sum.
356     *
357     * @param p Cumulative probability.
358     * @param q Survival probability.
359     * @param complement Set to true to compute the inverse survival probability.
360     * @return the value
361     */
362    private int inverseLower(double p, double q, boolean complement) {
363        // Sum from the lower bound (computing the cdf)
364        int x = lowerBound;
365        final DoublePredicate test = complement ?
366            i -> 1 - i > q :
367            i -> i < p;
368        double cdf = Math.exp(computeLogProbability(x));
369        while (test.test(cdf)) {
370            x++;
371            cdf += Math.exp(computeLogProbability(x));
372        }
373        return x;
374    }
375
376    /**
377     * Compute the inverse cumulative or survival probability using the upper sum.
378     *
379     * @param p Cumulative probability.
380     * @param q Survival probability.
381     * @param complement Set to true to compute the inverse survival probability.
382     * @return the value
383     */
384    private int inverseUpper(double p, double q, boolean complement) {
385        // Sum from the upper bound (computing the sf)
386        int x = upperBound;
387        final DoublePredicate test = complement ?
388            i -> i < q :
389            i -> 1 - i > p;
390        double sf = 0;
391        while (test.test(sf)) {
392            sf += Math.exp(computeLogProbability(x));
393            x--;
394        }
395        // Here either sf(x) >= q, or cdf(x) <= p
396        // Ensure sf(x) <= q, or cdf(x) >= p
397        if (complement && sf > q ||
398            !complement && 1 - sf < p) {
399            x++;
400        }
401        return x;
402    }
403
404    /**
405     * {@inheritDoc}
406     *
407     * <p>For population size \( N \), number of successes \( K \), and sample
408     * size \( n \), the mean is:
409     *
410     * <p>\[ n \frac{K}{N} \]
411     */
412    @Override
413    public double getMean() {
414        return getSampleSize() * (getNumberOfSuccesses() / (double) getPopulationSize());
415    }
416
417    /**
418     * {@inheritDoc}
419     *
420     * <p>For population size \( N \), number of successes \( K \), and sample
421     * size \( n \), the variance is:
422     *
423     * <p>\[ n \frac{K}{N} \frac{N-K}{N} \frac{N-n}{N-1} \]
424     */
425    @Override
426    public double getVariance() {
427        final double N = getPopulationSize();
428        final double K = getNumberOfSuccesses();
429        final double n = getSampleSize();
430        return (n * K * (N - K) * (N - n)) / (N * N * (N - 1));
431    }
432
433    /**
434     * {@inheritDoc}
435     *
436     * <p>For population size \( N \), number of successes \( K \), and sample
437     * size \( n \), the lower bound of the support is \( \max \{ 0, n + K - N \} \).
438     *
439     * @return lower bound of the support
440     */
441    @Override
442    public int getSupportLowerBound() {
443        return lowerBound;
444    }
445
446    /**
447     * {@inheritDoc}
448     *
449     * <p>For number of successes \( K \), and sample
450     * size \( n \), the upper bound of the support is \( \min \{ n, K \} \).
451     *
452     * @return upper bound of the support
453     */
454    @Override
455    public int getSupportUpperBound() {
456        return upperBound;
457    }
458
459    /**
460     * Return the mid-point {@code x} of the distribution, and the cdf(x).
461     *
462     * <p>This is not the true median. It is the value where the CDF(x) is closest to 0.5;
463     * as such the CDF may be below 0.5 if the next value of x is further from 0.5.
464     *
465     * @return the mid-point ([x, cdf(x)])
466     */
467    private double[] getMidPoint() {
468        double[] v = midpoint;
469        if (v == null) {
470            // Find the closest sum(PDF) to 0.5
471            int x = lowerBound;
472            double p0 = 0;
473            double p1 = Math.exp(computeLogProbability(x));
474            // No check of the upper bound required here as the CDF should sum to 1 and 0.5
475            // is exceeded before a bounds error.
476            while (p1 < HALF) {
477                x++;
478                p0 = p1;
479                p1 += Math.exp(computeLogProbability(x));
480            }
481            // p1 >= 0.5 > p0
482            // Pick closet
483            if (p1 - HALF >= HALF - p0) {
484                x--;
485                p1 = p0;
486            }
487            midpoint = v = new double[] {x, p1};
488        }
489        return v;
490    }
491}