001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.commons.statistics.distribution; 019 020import java.util.function.DoublePredicate; 021 022/** 023 * Implementation of the hypergeometric distribution. 024 * 025 * <p>The probability mass function of \( X \) is: 026 * 027 * <p>\[ f(k; N, K, n) = \frac{\binom{K}{k} \binom{N - K}{n-k}}{\binom{N}{n}} \] 028 * 029 * <p>for \( N \in \{0, 1, 2, \dots\} \) the population size, 030 * \( K \in \{0, 1, \dots, N\} \) the number of success states, 031 * \( n \in \{0, 1, \dots, N\} \) the number of samples, 032 * \( k \in \{\max(0, n+K-N), \dots, \min(n, K)\} \) the number of successes, and 033 * 034 * <p>\[ \binom{a}{b} = \frac{a!}{b! \, (a-b)!} \] 035 * 036 * <p>is the binomial coefficient. 037 * 038 * @see <a href="https://en.wikipedia.org/wiki/Hypergeometric_distribution">Hypergeometric distribution (Wikipedia)</a> 039 * @see <a href="https://mathworld.wolfram.com/HypergeometricDistribution.html">Hypergeometric distribution (MathWorld)</a> 040 */ 041public final class HypergeometricDistribution extends AbstractDiscreteDistribution { 042 /** 1/2. */ 043 private static final double HALF = 0.5; 044 /** The number of successes in the population. */ 045 private final int numberOfSuccesses; 046 /** The population size. */ 047 private final int populationSize; 048 /** The sample size. */ 049 private final int sampleSize; 050 /** The lower bound of the support (inclusive). */ 051 private final int lowerBound; 052 /** The upper bound of the support (inclusive). */ 053 private final int upperBound; 054 /** Binomial probability of success (sampleSize / populationSize). */ 055 private final double bp; 056 /** Binomial probability of failure ((populationSize - sampleSize) / populationSize). */ 057 private final double bq; 058 /** Cached midpoint of the CDF/SF. The array holds [x, cdf(x)] for the midpoint x. 059 * Used for the cumulative probability functions. */ 060 private double[] midpoint; 061 062 /** 063 * @param populationSize Population size. 064 * @param numberOfSuccesses Number of successes in the population. 065 * @param sampleSize Sample size. 066 */ 067 private HypergeometricDistribution(int populationSize, 068 int numberOfSuccesses, 069 int sampleSize) { 070 this.numberOfSuccesses = numberOfSuccesses; 071 this.populationSize = populationSize; 072 this.sampleSize = sampleSize; 073 lowerBound = getLowerDomain(populationSize, numberOfSuccesses, sampleSize); 074 upperBound = getUpperDomain(numberOfSuccesses, sampleSize); 075 bp = (double) sampleSize / populationSize; 076 bq = (double) (populationSize - sampleSize) / populationSize; 077 } 078 079 /** 080 * Creates a hypergeometric distribution. 081 * 082 * @param populationSize Population size. 083 * @param numberOfSuccesses Number of successes in the population. 084 * @param sampleSize Sample size. 085 * @return the distribution 086 * @throws IllegalArgumentException if {@code numberOfSuccesses < 0}, or 087 * {@code populationSize <= 0} or {@code numberOfSuccesses > populationSize}, or 088 * {@code sampleSize > populationSize}. 089 */ 090 public static HypergeometricDistribution of(int populationSize, 091 int numberOfSuccesses, 092 int sampleSize) { 093 if (populationSize <= 0) { 094 throw new DistributionException(DistributionException.NOT_STRICTLY_POSITIVE, 095 populationSize); 096 } 097 if (numberOfSuccesses < 0) { 098 throw new DistributionException(DistributionException.NEGATIVE, 099 numberOfSuccesses); 100 } 101 if (sampleSize < 0) { 102 throw new DistributionException(DistributionException.NEGATIVE, 103 sampleSize); 104 } 105 106 if (numberOfSuccesses > populationSize) { 107 throw new DistributionException(DistributionException.TOO_LARGE, 108 numberOfSuccesses, populationSize); 109 } 110 if (sampleSize > populationSize) { 111 throw new DistributionException(DistributionException.TOO_LARGE, 112 sampleSize, populationSize); 113 } 114 return new HypergeometricDistribution(populationSize, numberOfSuccesses, sampleSize); 115 } 116 117 /** 118 * Return the lowest domain value for the given hypergeometric distribution 119 * parameters. 120 * 121 * @param nn Population size. 122 * @param k Number of successes in the population. 123 * @param n Sample size. 124 * @return the lowest domain value of the hypergeometric distribution. 125 */ 126 private static int getLowerDomain(int nn, int k, int n) { 127 // Avoid overflow given N > n: 128 // n + K - N == K - (N - n) 129 return Math.max(0, k - (nn - n)); 130 } 131 132 /** 133 * Return the highest domain value for the given hypergeometric distribution 134 * parameters. 135 * 136 * @param k Number of successes in the population. 137 * @param n Sample size. 138 * @return the highest domain value of the hypergeometric distribution. 139 */ 140 private static int getUpperDomain(int k, int n) { 141 return Math.min(n, k); 142 } 143 144 /** 145 * Gets the population size parameter of this distribution. 146 * 147 * @return the population size. 148 */ 149 public int getPopulationSize() { 150 return populationSize; 151 } 152 153 /** 154 * Gets the number of successes parameter of this distribution. 155 * 156 * @return the number of successes. 157 */ 158 public int getNumberOfSuccesses() { 159 return numberOfSuccesses; 160 } 161 162 /** 163 * Gets the sample size parameter of this distribution. 164 * 165 * @return the sample size. 166 */ 167 public int getSampleSize() { 168 return sampleSize; 169 } 170 171 /** {@inheritDoc} */ 172 @Override 173 public double probability(int x) { 174 return Math.exp(logProbability(x)); 175 } 176 177 /** {@inheritDoc} */ 178 @Override 179 public double probability(int x0, int x1) { 180 if (x0 > x1) { 181 throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH, x0, x1); 182 } 183 if (x0 == x1 || x1 < lowerBound) { 184 return 0; 185 } 186 // If the range is outside the bounds use the appropriate cumulative probability 187 if (x0 < lowerBound) { 188 return cumulativeProbability(x1); 189 } 190 if (x1 >= upperBound) { 191 // 1 - cdf(x0) 192 return survivalProbability(x0); 193 } 194 // Here: lower <= x0 < x1 < upper: 195 // sum(pdf(x)) for x in (x0, x1] 196 final int lo = x0 + 1; 197 // Sum small values first by starting at the point the greatest distance from the mode. 198 final int mode = (int) Math.floor((sampleSize + 1.0) * (numberOfSuccesses + 1.0) / (populationSize + 2.0)); 199 return Math.abs(mode - lo) > Math.abs(mode - x1) ? 200 innerCumulativeProbability(lo, x1) : 201 innerCumulativeProbability(x1, lo); 202 } 203 204 /** {@inheritDoc} */ 205 @Override 206 public double logProbability(int x) { 207 if (x < lowerBound || x > upperBound) { 208 return Double.NEGATIVE_INFINITY; 209 } 210 return computeLogProbability(x); 211 } 212 213 /** 214 * Compute the log probability. 215 * 216 * @param x Value. 217 * @return log(P(X = x)) 218 */ 219 private double computeLogProbability(int x) { 220 final double p1 = 221 SaddlePointExpansionUtils.logBinomialProbability(x, numberOfSuccesses, bp, bq); 222 final double p2 = 223 SaddlePointExpansionUtils.logBinomialProbability(sampleSize - x, 224 populationSize - numberOfSuccesses, bp, bq); 225 final double p3 = 226 SaddlePointExpansionUtils.logBinomialProbability(sampleSize, populationSize, bp, bq); 227 return p1 + p2 - p3; 228 } 229 230 /** {@inheritDoc} */ 231 @Override 232 public double cumulativeProbability(int x) { 233 if (x < lowerBound) { 234 return 0.0; 235 } else if (x >= upperBound) { 236 return 1.0; 237 } 238 final double[] mid = getMidPoint(); 239 final int m = (int) mid[0]; 240 if (x < m) { 241 return innerCumulativeProbability(lowerBound, x); 242 } else if (x > m) { 243 return 1 - innerCumulativeProbability(upperBound, x + 1); 244 } 245 // cdf(x) 246 return mid[1]; 247 } 248 249 /** {@inheritDoc} */ 250 @Override 251 public double survivalProbability(int x) { 252 if (x < lowerBound) { 253 return 1.0; 254 } else if (x >= upperBound) { 255 return 0.0; 256 } 257 final double[] mid = getMidPoint(); 258 final int m = (int) mid[0]; 259 if (x < m) { 260 return 1 - innerCumulativeProbability(lowerBound, x); 261 } else if (x > m) { 262 return innerCumulativeProbability(upperBound, x + 1); 263 } 264 // 1 - cdf(x) 265 return 1 - mid[1]; 266 } 267 268 /** 269 * For this distribution, {@code X}, this method returns 270 * {@code P(x0 <= X <= x1)}. 271 * This probability is computed by summing the point probabilities for the 272 * values {@code x0, x0 + dx, x0 + 2 * dx, ..., x1}; the direction {@code dx} is determined 273 * using a comparison of the input bounds. 274 * This should be called by using {@code x0} as the domain limit and {@code x1} 275 * as the internal value. This will result in an initial sum of increasing larger magnitudes. 276 * 277 * @param x0 Inclusive domain bound. 278 * @param x1 Inclusive internal bound. 279 * @return {@code P(x0 <= X <= x1)}. 280 */ 281 private double innerCumulativeProbability(int x0, int x1) { 282 // Assume the range is within the domain. 283 // Reuse the computation for probability(x) but avoid checking the domain for each call. 284 int x = x0; 285 double ret = Math.exp(computeLogProbability(x)); 286 if (x0 < x1) { 287 while (x != x1) { 288 x++; 289 ret += Math.exp(computeLogProbability(x)); 290 } 291 } else { 292 while (x != x1) { 293 x--; 294 ret += Math.exp(computeLogProbability(x)); 295 } 296 } 297 return ret; 298 } 299 300 @Override 301 public int inverseCumulativeProbability(double p) { 302 ArgumentUtils.checkProbability(p); 303 return computeInverseProbability(p, 1 - p, false); 304 } 305 306 @Override 307 public int inverseSurvivalProbability(double p) { 308 ArgumentUtils.checkProbability(p); 309 return computeInverseProbability(1 - p, p, true); 310 } 311 312 /** 313 * Implementation for the inverse cumulative or survival probability. 314 * 315 * @param p Cumulative probability. 316 * @param q Survival probability. 317 * @param complement Set to true to compute the inverse survival probability. 318 * @return the value 319 */ 320 private int computeInverseProbability(double p, double q, boolean complement) { 321 if (p == 0) { 322 return lowerBound; 323 } 324 if (q == 0) { 325 return upperBound; 326 } 327 328 // Sum the PDF(x) until the appropriate p-value is obtained 329 // CDF: require smallest x where P(X<=x) >= p 330 // SF: require smallest x where P(X>x) <= q 331 // The choice of summation uses the mid-point. 332 // The test on the CDF or SF is based on the appropriate input p-value. 333 334 final double[] mid = getMidPoint(); 335 final int m = (int) mid[0]; 336 final double mp = mid[1]; 337 338 final int midPointComparison = complement ? 339 Double.compare(1 - mp, q) : 340 Double.compare(p, mp); 341 342 if (midPointComparison < 0) { 343 return inverseLower(p, q, complement); 344 } else if (midPointComparison > 0) { 345 // Avoid floating-point summation error when the mid-point computed using the 346 // lower sum is different to the midpoint computed using the upper sum. 347 // Here we know the result must be above the midpoint so we can clip the result. 348 return Math.max(m + 1, inverseUpper(p, q, complement)); 349 } 350 // Exact mid-point 351 return m; 352 } 353 354 /** 355 * Compute the inverse cumulative or survival probability using the lower sum. 356 * 357 * @param p Cumulative probability. 358 * @param q Survival probability. 359 * @param complement Set to true to compute the inverse survival probability. 360 * @return the value 361 */ 362 private int inverseLower(double p, double q, boolean complement) { 363 // Sum from the lower bound (computing the cdf) 364 int x = lowerBound; 365 final DoublePredicate test = complement ? 366 i -> 1 - i > q : 367 i -> i < p; 368 double cdf = Math.exp(computeLogProbability(x)); 369 while (test.test(cdf)) { 370 x++; 371 cdf += Math.exp(computeLogProbability(x)); 372 } 373 return x; 374 } 375 376 /** 377 * Compute the inverse cumulative or survival probability using the upper sum. 378 * 379 * @param p Cumulative probability. 380 * @param q Survival probability. 381 * @param complement Set to true to compute the inverse survival probability. 382 * @return the value 383 */ 384 private int inverseUpper(double p, double q, boolean complement) { 385 // Sum from the upper bound (computing the sf) 386 int x = upperBound; 387 final DoublePredicate test = complement ? 388 i -> i < q : 389 i -> 1 - i > p; 390 double sf = 0; 391 while (test.test(sf)) { 392 sf += Math.exp(computeLogProbability(x)); 393 x--; 394 } 395 // Here either sf(x) >= q, or cdf(x) <= p 396 // Ensure sf(x) <= q, or cdf(x) >= p 397 if (complement && sf > q || 398 !complement && 1 - sf < p) { 399 x++; 400 } 401 return x; 402 } 403 404 /** 405 * {@inheritDoc} 406 * 407 * <p>For population size \( N \), number of successes \( K \), and sample 408 * size \( n \), the mean is: 409 * 410 * <p>\[ n \frac{K}{N} \] 411 */ 412 @Override 413 public double getMean() { 414 return getSampleSize() * (getNumberOfSuccesses() / (double) getPopulationSize()); 415 } 416 417 /** 418 * {@inheritDoc} 419 * 420 * <p>For population size \( N \), number of successes \( K \), and sample 421 * size \( n \), the variance is: 422 * 423 * <p>\[ n \frac{K}{N} \frac{N-K}{N} \frac{N-n}{N-1} \] 424 */ 425 @Override 426 public double getVariance() { 427 final double N = getPopulationSize(); 428 final double K = getNumberOfSuccesses(); 429 final double n = getSampleSize(); 430 return (n * K * (N - K) * (N - n)) / (N * N * (N - 1)); 431 } 432 433 /** 434 * {@inheritDoc} 435 * 436 * <p>For population size \( N \), number of successes \( K \), and sample 437 * size \( n \), the lower bound of the support is \( \max \{ 0, n + K - N \} \). 438 * 439 * @return lower bound of the support 440 */ 441 @Override 442 public int getSupportLowerBound() { 443 return lowerBound; 444 } 445 446 /** 447 * {@inheritDoc} 448 * 449 * <p>For number of successes \( K \), and sample 450 * size \( n \), the upper bound of the support is \( \min \{ n, K \} \). 451 * 452 * @return upper bound of the support 453 */ 454 @Override 455 public int getSupportUpperBound() { 456 return upperBound; 457 } 458 459 /** 460 * Return the mid-point {@code x} of the distribution, and the cdf(x). 461 * 462 * <p>This is not the true median. It is the value where the CDF(x) is closest to 0.5; 463 * as such the CDF may be below 0.5 if the next value of x is further from 0.5. 464 * 465 * @return the mid-point ([x, cdf(x)]) 466 */ 467 private double[] getMidPoint() { 468 double[] v = midpoint; 469 if (v == null) { 470 // Find the closest sum(PDF) to 0.5 471 int x = lowerBound; 472 double p0 = 0; 473 double p1 = Math.exp(computeLogProbability(x)); 474 // No check of the upper bound required here as the CDF should sum to 1 and 0.5 475 // is exceeded before a bounds error. 476 while (p1 < HALF) { 477 x++; 478 p0 = p1; 479 p1 += Math.exp(computeLogProbability(x)); 480 } 481 // p1 >= 0.5 > p0 482 // Pick closet 483 if (p1 - HALF >= HALF - p0) { 484 x--; 485 p1 = p0; 486 } 487 midpoint = v = new double[] {x, p1}; 488 } 489 return v; 490 } 491}