/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.statistics.inference;

import java.util.Arrays;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.function.DoublePredicate;
import java.util.function.DoubleUnaryOperator;
import java.util.function.IntToDoubleFunction;
import org.apache.commons.numbers.combinatorics.LogBinomialCoefficient;
import org.apache.commons.statistics.inference.BrentOptimizer.PointValuePair;

/**
 * Implements an unconditioned exact test for a contingency table.
 *
 * <p>Performs an exact test for the statistical significance of the association (contingency)
 * between two kinds of categorical classification. A 2x2 contingency table is:
 *
 * <p>\[ \left[ {\begin{array}{cc}
 *         a & b \\
 *         c & d \\
 *       \end{array} } \right] \]
 *
 * <p>This test applies to the case of a 2x2 contingency table with one margin fixed. Note that
 * if both margins are fixed (the row sums and column sums are not random)
 * then Fisher's exact test can be applied.
 *
 * <p>This implementation fixes the column sums \( m = a + c \) and \( n = b + d \).
 * All possible tables can be created using \( 0 \le a \le m \) and \( 0 \le b \le n \).
 * The random values \( a \) and \( b \) follow a binomial distribution with probabilities
 * \( p_0 \) and \( p_1 \) such that \( a \sim B(m, p_0) \) and \( b \sim B(n, p_1) \).
 * The p-value of the 2x2 table is the product of two binomials:
 *
 * <p>\[ \begin{aligned}
 *       p &= Pr(a; m, p_0) \times Pr(b; n, p_1) \\
 *         &= \binom{m}{a} p_0^a (1-p_0)^{m-a} \times \binom{n}{b} p_1^b (1-p_1)^{n-b} \end{aligned} \]
 *
 * <p>For the binomial model, the null hypothesis is that the two nuisance parameters are equal
 * \( p_0 = p_1 = \pi \), with \( \pi \) the probability for equal proportions, and the probability
 * of any single table is:
 *
 * <p>\[ p = \binom{m}{a} \binom{n}{b} \pi^{a+b} (1-\pi)^{m+n-a-b} \]
 *
 * <p>The p-value of the observed table is calculated by maximising the sum of the as or more
 * extreme tables over the domain of the nuisance parameter \( 0 \lt \pi \lt 1 \):
 *
 * <p>\[ p(a, b) = \sum_{i,j} \binom{m}{i} \binom{n}{j} \pi^{i+j} (1-\pi)^{m+n-i-j} \]
 *
 * <p>where table \( (i,j) \) is as or more extreme than the observed table \( (a, b) \). The test
 * can be configured to select more extreme tables using various {@linkplain Method methods}.
 *
 * <p>Note that the sum of the joint binomial distribution is a univariate function for
 * the nuisance parameter \( \pi \). This function may have many local maxima and the
 * search enumerates the range with a configured {@linkplain #withInitialPoints(int)
 * number of points}. The best candidates are optionally used as the start point for an
 * {@linkplain #withOptimize(boolean) optimized} search for a local maximum.
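 *
 * <p>Example usage (an illustrative sketch; {@code getPValue()} is inherited from the
 * {@code SignificanceResult} interface):
 *
 * <pre>{@code
 * int[][] table = {{10, 21}, {15, 5}};
 * UnconditionedExactTest.Result result = UnconditionedExactTest.withDefaults()
 *     .with(AlternativeHypothesis.TWO_SIDED)
 *     .with(UnconditionedExactTest.Method.BOSCHLOO)
 *     .test(table);
 * double statistic = result.getStatistic();
 * double p = result.getPValue();
 * double pi = result.getNuisanceParameter();
 * }</pre>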
 *
 * <p>References:
 * <ol>
 * <li>
 * Barnard, G.A. (1947).
 * <a href="https://doi.org/10.1093/biomet/34.1-2.123">Significance tests for 2x2 tables.</a>
 * Biometrika, 34, Issue 1-2, 123–138.
 * <li>
 * Boschloo, R.D. (1970).
 * <a href="https://doi.org/10.1111/j.1467-9574.1970.tb00104.x">Raised conditional level of
 * significance for the 2 × 2-table when testing the equality of two probabilities.</a>
 * Statistica Neerlandica, 24(1), 1–9.
 * <li>
 * Suissa, S. and Shuster, J.J. (1985).
 * <a href="https://doi.org/10.2307/2981892">Exact Unconditional Sample Sizes
 * for the 2 × 2 Binomial Trial.</a>
 * Journal of the Royal Statistical Society. Series A (General), 148(4), 317-327.
 * </ol>
 *
 * @see FisherExactTest
 * @see <a href="https://en.wikipedia.org/wiki/Boschloo%27s_test">Boschloo's test (Wikipedia)</a>
 * @see <a href="https://en.wikipedia.org/wiki/Barnard%27s_test">Barnard's test (Wikipedia)</a>
 * @since 1.1
 */
public final class UnconditionedExactTest {
    /**
     * Default instance.
     *
     * <p>SciPy's boschloo_exact and barnard_exact tests use 32 points in the interval
     * [0, 1). The R Exact package uses 100 points in the interval [1e-5, 1-1e-5].
     * Barnard's 1947 paper describes the nuisance parameter in the open interval
     * {@code 0 < pi < 1}. Here we respect the open interval for the initial candidates
     * and ignore 0 and 1. The initial bounds used are the same as the R Exact package.
     * We closely match the inner 31 points from SciPy by using 33 points by default.
     */
    private static final UnconditionedExactTest DEFAULT = new UnconditionedExactTest(
        AlternativeHypothesis.TWO_SIDED, Method.BOSCHLOO, 33, true);
    /** Lower bound for the enumerated interval. The upper bound is {@code 1 - lower}. */
    private static final double LOWER_BOUND = 1e-5;
    /** Relative epsilon for the Brent solver. This is limited for a univariate function
     * to approximately sqrt(eps) with eps = 2^-52. */
    private static final double SOLVER_RELATIVE_EPS = 1.4901161193847656E-8;
    /** Fraction of the increment (interval between enumerated points) to initialise the bracket
     * for the minimum. Note the minimum should lie between x +/- increment. The bracket should
     * search within this range. Set to 1/8 and so the initial point of the bracket is
     * approximately 1.61 * 1/8 = 0.2 of the increment away from initial points a or b. */
    private static final double INC_FRACTION = 0.125;
    /** Maximum number of candidates to optimize. This is a safety limit to avoid excess
     * optimization. Only candidates within a relative tolerance of the best candidate are
     * stored. If the number of candidates exceeds this value then many candidates have a
     * very similar p-value and the top candidates will be optimized. Using a value of 3
     * allows at least one other candidate to be optimized when there is two-fold
     * symmetry in the energy function. */
    private static final int MAX_CANDIDATES = 3;
    /** Relative distance of candidate minima from the lowest candidate. Used to exclude
     * poor candidates from optimization. */
    private static final double MINIMA_EPS = 0.02;
    /** The maximum number of tables. This is limited by the maximum number of indices that
     * can be maintained in memory. Potentially up to this number of tables must be tracked
     * during computation of the p-value for as or more extreme tables. The limit is set
     * using the same limit for maximum capacity as java.util.ArrayList. In practice any
     * table anywhere near this limit can be computed using an alternative such as a chi-squared
     * or g test. */
    private static final int MAX_TABLES = Integer.MAX_VALUE - 8;
    /** Error message text for zero column sums. */
    private static final String COLUMN_SUM = "Column sum";

    /** Alternative hypothesis. */
    private final AlternativeHypothesis alternative;
    /** Method to identify more extreme tables. */
    private final Method method;
    /** Number of initial points. */
    private final int points;
    /** Option to optimize the best initial point(s). */
    private final boolean optimize;

    /**
     * Define the method to determine the more extreme tables.
     *
     * @since 1.1
     */
    public enum Method {
        /**
         * Uses the test statistic from a Z-test using a pooled variance.
         *
         * <p>\[ T(X) = \frac{\hat{p}_0 - \hat{p}_1}{\sqrt{\hat{p}(1 - \hat{p}) (\frac{1}{m} + \frac{1}{n})}} \]
         *
         * <p>where \( \hat{p}_0 = a / m \), \( \hat{p}_1 = b / n \), and
         * \( \hat{p} = (a+b) / (m+n) \) are the estimators of \( p_0 \), \( p_1 \) and the
         * pooled probability \( p \) assuming \( p_0 = p_1 \).
         *
         * <p>The more extreme tables are identified using the {@link AlternativeHypothesis}:
         * <ul>
         * <li>greater: \( T(X) \ge T(X_0) \)
         * <li>less: \( T(X) \le T(X_0) \)
         * <li>two-sided: \( | T(X) | \ge | T(X_0) | \)
         * </ul>
         *
         * <p>The use of the Z statistic was suggested by Suissa and Shuster (1985).
         * This method is uniformly more powerful than Fisher's test for balanced designs
         * (\( m = n \)).
         */
        Z_POOLED,

        /**
         * Uses the test statistic from a Z-test using an unpooled variance.
         *
         * <p>\[ T(X) = \frac{\hat{p}_0 - \hat{p}_1}
         *       {\sqrt{ \frac{\hat{p}_0(1 - \hat{p}_0)}{m} + \frac{\hat{p}_1(1 - \hat{p}_1)}{n}} } \]
         *
         * <p>where \( \hat{p}_0 = a / m \) and \( \hat{p}_1 = b / n \).
         *
         * <p>The more extreme tables are identified using the {@link AlternativeHypothesis} as
         * per the {@link #Z_POOLED} method.
         */
        Z_UNPOOLED,

        /**
         * Uses the p-value from Fisher's exact test. This is also known as Boschloo's test.
         *
         * <p>The p-value for Fisher's test is computed using the
         * {@link AlternativeHypothesis}. The more extreme tables are identified using
         * \( p(X) \le p(X_0) \).
         *
         * <p>This method is always uniformly more powerful than Fisher's test.
         *
         * @see FisherExactTest
         */
        BOSCHLOO;
    }

    /**
     * Result for the unconditioned exact test.
     *
     * <p>This class is immutable.
     *
     * @since 1.1
     */
    public static final class Result extends BaseSignificanceResult {
        /** Nuisance parameter. */
        private final double pi;

        /**
         * Create an instance where all tables are more extreme, i.e. the p-value
         * is 1.0.
         *
         * @param statistic Test statistic.
         */
        Result(double statistic) {
            super(statistic, 1);
            this.pi = 0.5;
        }

        /**
         * @param statistic Test statistic.
         * @param pi Nuisance parameter.
         * @param p Result p-value.
         */
        Result(double statistic, double pi, double p) {
            super(statistic, p);
            this.pi = pi;
        }

        /**
         * {@inheritDoc}
         *
         * <p>The value of the statistic is dependent on the {@linkplain Method method}
         * used to determine the more extreme tables.
         */
        @Override
        public double getStatistic() {
            // Note: This method is here for documentation
            return super.getStatistic();
        }

        /**
         * Gets the nuisance parameter that maximised the probability sum of the as or more
         * extreme tables.
         *
         * @return the nuisance parameter.
         */
        public double getNuisanceParameter() {
            return pi;
        }
    }

    /**
     * An expandable list of (x,y) values. This allows tracking 2D positions stored as
     * a single index.
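     *
     * <p>For example (an illustrative case), with {@code maxx = 3} the width is 4, so the
     * pair {@code (x=2, y=5)} is stored as the single index {@code 5 * 4 + 2 = 22} and
     * decoded using {@code x = 22 % 4} and {@code y = 22 / 4}.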
     */
    private static class XYList {
        /** The maximum size of array to allocate. */
        private final int max;
        /** Width, or maximum x value (exclusive). */
        private final int width;

        /** The size of the list. */
        private int size;
        /** The list data. */
        private int[] data = new int[10];

        /**
         * Create an instance. It is assumed that (maxx+1)*(maxy+1) does not exceed the
         * capacity of an array.
         *
         * @param maxx Maximum x-value (inclusive).
         * @param maxy Maximum y-value (inclusive).
         */
        XYList(int maxx, int maxy) {
            this.width = maxx + 1;
            this.max = width * (maxy + 1);
        }

        /**
         * Gets the width.
         * (x, y) values are stored using y * width + x.
         *
         * @return the width
         */
        int getWidth() {
            return width;
        }

        /**
         * Gets the maximum X value (inclusive).
         *
         * @return the max X
         */
        int getMaxX() {
            return width - 1;
        }

        /**
         * Gets the maximum Y value (inclusive).
         *
         * @return the max Y
         */
        int getMaxY() {
            return max / width - 1;
        }

        /**
         * Adds the value to the list.
         *
         * @param x X value.
         * @param y Y value.
         */
        void add(int x, int y) {
            if (size == data.length) {
                // Overflow safe doubling of the current size.
                data = Arrays.copyOf(data, (int) Math.min(max, size * 2L));
            }
            data[size++] = width * y + x;
        }

        /**
         * Gets the 2D index at the specified {@code index}.
         * The index is y * width + x:
         * <pre>
         * x = index % width
         * y = index / width
         * </pre>
         *
         * @param index Element index.
         * @return the 2D index
         */
        int get(int index) {
            return data[index];
        }

        /**
         * Gets the number of elements in the list.
         *
         * @return the size
         */
        int size() {
            return size;
        }

        /**
         * Checks if the list size is zero.
         *
         * @return true if empty
         */
        boolean isEmpty() {
            return size == 0;
        }

        /**
         * Checks if the list is at maximum capacity.
         *
         * @return true if full
         */
        boolean isFull() {
            return size == max;
        }
    }

    /**
     * A container of (key,value) pairs to store candidate minima. Encapsulates the
     * logic of storing multiple initial search points for optimization.
     *
     * <p>Stores all pairs within a relative tolerance of the lowest minimum up to a set
     * capacity. When at capacity the worst candidate is replaced by addition of a
     * better candidate.
     *
     * <p>Special handling is provided to store only a single NaN value if no non-NaN
     * values have been observed. This prevents storing a large number of NaN
     * candidates.
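     *
     * <p>Illustrative behaviour of the retention policy (a sketch using hypothetical
     * values):
     * <pre>{@code
     * Candidates c = new Candidates(3, 0.02);
     * c.add(0.3, -0.50);  // new minimum; threshold = -0.50 + 0.50 * 0.02 = -0.49
     * c.add(0.4, -0.495); // kept: -0.495 <= -0.49
     * c.add(0.5, -0.40);  // discarded: above the threshold
     * }</pre>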
     */
    static class Candidates {
        /** The maximum size of array to allocate. */
        private final int max;
        /** Relative distance from lowest candidate. */
        private final double eps;
        /** Candidate (key,value) pairs. */
        private double[][] data;
        /** Current size of the list. */
        private int size;
        /** Current minimum. */
        private double min = Double.POSITIVE_INFINITY;
        /** Current threshold for inclusion. */
        private double threshold = Double.POSITIVE_INFINITY;

        /**
         * Create an instance.
         *
         * @param max Maximum number of allowed candidates (limited to at least 1).
         * @param eps Relative distance of candidate minima from the lowest candidate
         * (assumed to be positive and finite).
         */
        Candidates(int max, double eps) {
            this.max = Math.max(1, max);
            this.eps = eps;
            // Create the initial storage
            data = new double[Math.min(this.max, 4)][];
        }

        /**
         * Adds the (key, value) pair.
         *
         * @param k Key.
         * @param v Value.
         */
        void add(double k, double v) {
            // Store only a single NaN
            if (Double.isNaN(v)) {
                if (size == 0) {
                    // No requirement to check capacity
                    data[size++] = new double[] {k, v};
                }
                return;
            }
            // Here values are non-NaN.
            // If higher then do not store.
            if (v > threshold) {
                return;
            }
            // Check if lower than the current minimum.
            if (v < min) {
                min = v;
                // Get new threshold
                threshold = v + Math.abs(v) * eps;
                // Remove existing entries above the threshold
                int s = 0;
                for (int i = 0; i < size; i++) {
                    // This will filter NaN values
                    if (data[i][1] <= threshold) {
                        data[s++] = data[i];
                    }
                }
                size = s;
                // Caution: This does not clear stale data
                // by setting all values in [newSize, oldSize) = null
            }
            addPair(k, v);
        }

        /**
         * Add the (key, value) pair to the data.
         * It is assumed the data satisfy the conditions for addition.
         *
         * @param k Key.
         * @param v Value.
         */
        private void addPair(double k, double v) {
            if (size == data.length) {
                if (size == max) {
                    // At capacity.
                    replaceWorst(k, v);
                    return;
                }
                // Expand
                data = Arrays.copyOfRange(data, 0, (int) Math.min(max, size * 2L));
            }
            data[size++] = new double[] {k, v};
        }

        /**
         * Replace the worst candidate.
         *
         * @param k Key.
         * @param v Value.
         */
        private void replaceWorst(double k, double v) {
            // Note: This only occurs when NaN values have been removed by addition
            // of non-NaN values.
            double[] worst = data[0];
            for (int i = 1; i < size; i++) {
                if (worst[1] < data[i][1]) {
                    worst = data[i];
                }
            }
            worst[0] = k;
            worst[1] = v;
        }

        /**
         * Return the minimum (key,value) pair.
         *
         * @return the minimum (or null)
         */
        double[] getMinimum() {
            // This will handle size=0 as data[0] will be null
            double[] best = data[0];
            for (int i = 1; i < size; i++) {
                if (best[1] > data[i][1]) {
                    best = data[i];
                }
            }
            return best;
        }

        /**
         * Perform the given action for each (key, value) pair.
         *
         * @param action Action.
         */
        void forEach(Consumer<double[]> action) {
            for (int i = 0; i < size; i++) {
                action.accept(data[i]);
            }
        }
    }

    /**
     * Compute the statistic for Boschloo's test.
     */
    private interface BoschlooStatistic {
        /**
         * Compute Fisher's p-value for the 2x2 contingency table with the observed
         * value {@code x} in position [0][0]. Note that the table margins are fixed
         * and are defined by the population size, number of successes and sample
         * size of the specified hypergeometric distribution.
         *
         * @param dist Hypergeometric distribution.
         * @param x Value.
         * @return Fisher's p-value
         */
        double value(Hypergeom dist, int x);
    }

    /**
     * @param alternative Alternative hypothesis.
     * @param method Method to identify more extreme tables.
     * @param points Number of initial points.
     * @param optimize Option to optimize the best initial point(s).
     */
    private UnconditionedExactTest(AlternativeHypothesis alternative,
                                   Method method,
                                   int points,
                                   boolean optimize) {
        this.alternative = alternative;
        this.method = method;
        this.points = points;
        this.optimize = optimize;
    }

    /**
     * Return an instance using the default options.
     *
     * <ul>
     * <li>{@link AlternativeHypothesis#TWO_SIDED}
     * <li>{@link Method#BOSCHLOO}
     * <li>{@linkplain #withInitialPoints(int) points = 33}
     * <li>{@linkplain #withOptimize(boolean) optimize = true}
     * </ul>
     *
     * @return default instance
     */
    public static UnconditionedExactTest withDefaults() {
        return DEFAULT;
    }

    /**
     * Return an instance with the configured alternative hypothesis.
     *
     * @param v Value.
     * @return an instance
     */
    public UnconditionedExactTest with(AlternativeHypothesis v) {
        return new UnconditionedExactTest(Objects.requireNonNull(v), method, points, optimize);
    }

    /**
     * Return an instance with the configured method.
     *
     * @param v Value.
     * @return an instance
     */
    public UnconditionedExactTest with(Method v) {
        return new UnconditionedExactTest(alternative, Objects.requireNonNull(v), points, optimize);
    }

    /**
     * Return an instance with the configured number of initial points.
     *
     * <p>The search for the nuisance parameter will use \( v \) points in the open interval
     * \( (0, 1) \). The interval is evaluated by including start and end points approximately
     * equal to 0 and 1. Additional internal points are enumerated using increments of
     * approximately \( \frac{1}{v-1} \). The minimum number of points is 2. Increasing the
     * number of points increases the precision of the search at the cost of performance.
     *
     * <p>To approximately double the number of points so that all existing points are included
     * and additional points half-way between them are sampled requires using {@code 2p - 1}
     * where {@code p} is the existing number of points.
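     *
     * <p>For example, to double the resolution of the default search (an illustrative
     * sketch):
     *
     * <pre>{@code
     * // 33 points -> 65 points; all 33 original points are sampled again
     * UnconditionedExactTest test = UnconditionedExactTest.withDefaults()
     *     .withInitialPoints(2 * 33 - 1);
     * }</pre>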
     *
     * @param v Value.
     * @return an instance
     * @throws IllegalArgumentException if the value is {@code < 2}.
     */
    public UnconditionedExactTest withInitialPoints(int v) {
        if (v <= 1) {
            throw new InferenceException(InferenceException.X_LT_Y, v, 2);
        }
        return new UnconditionedExactTest(alternative, method, v, optimize);
    }

    /**
     * Return an instance with the configured optimization of initial search points.
     *
     * <p>If enabled then the initial point(s) with the highest probability is/are used as the
     * start for an optimization to find a local maximum.
     *
     * @param v Value.
     * @return an instance
     * @see #withInitialPoints(int)
     */
    public UnconditionedExactTest withOptimize(boolean v) {
        return new UnconditionedExactTest(alternative, method, points, v);
    }

    /**
     * Compute the statistic for the unconditioned exact test. The statistic returned
     * depends on the configured {@linkplain Method method}.
     *
     * @param table 2-by-2 contingency table.
     * @return test statistic
     * @throws IllegalArgumentException if the {@code table} is not a 2-by-2 table; any
     * table entry is negative; any column sum is zero; the table sum is zero or not an
     * integer; or the number of possible tables exceeds the maximum array capacity.
     * @see #with(Method)
     * @see #test(int[][])
     */
    public double statistic(int[][] table) {
        checkTable(table);
        final int a = table[0][0];
        final int b = table[0][1];
        final int c = table[1][0];
        final int d = table[1][1];
        final int m = a + c;
        final int n = b + d;
        switch (method) {
        case Z_POOLED:
            return statisticZ(a, b, m, n, true);
        case Z_UNPOOLED:
            return statisticZ(a, b, m, n, false);
        case BOSCHLOO:
            return statisticBoschloo(a, b, m, n);
        default:
            throw new IllegalStateException(String.valueOf(method));
        }
    }

    /**
     * Performs an unconditioned exact test on the 2-by-2 contingency table. The statistic and
     * p-value returned depend on the configured {@linkplain Method method} and
     * {@linkplain AlternativeHypothesis alternative hypothesis}.
     *
     * <p>The search for the nuisance parameter that maximises the p-value can be configured to:
     * start with a number of {@linkplain #withInitialPoints(int) initial points}; and
     * {@linkplain #withOptimize(boolean) optimize} the best points.
     *
     * @param table 2-by-2 contingency table.
     * @return test result
     * @throws IllegalArgumentException if the {@code table} is not a 2-by-2 table; any
     * table entry is negative; any column sum is zero; the table sum is zero or not an
     * integer; or the number of possible tables exceeds the maximum array capacity.
     * @see #with(Method)
     * @see #with(AlternativeHypothesis)
     * @see #statistic(int[][])
     */
    public Result test(int[][] table) {
        checkTable(table);
        final int a = table[0][0];
        final int b = table[0][1];
        final int c = table[1][0];
        final int d = table[1][1];
        final int m = a + c;
        final int n = b + d;

        // Used to track more extreme tables
        final XYList tableList = new XYList(m, n);

        final double statistic = findExtremeTables(a, b, tableList);
        if (tableList.isEmpty() || tableList.isFull()) {
            // All possible tables are more extreme, e.g. a two-sided test where the
            // z-statistic is zero.
            return new Result(statistic);
        }
        final double[] opt = computePValue(tableList);

        return new Result(statistic, opt[0], opt[1]);
    }

    /**
     * Find all tables that are as or more extreme than the observed table.
     *
     * <p>If the list of tables is full then all tables are more extreme.
     * Some configurations can detect this without performing a search
     * and in this case the list of tables is returned as empty.
     *
     * @param a Observed value for a.
     * @param b Observed value for b.
     * @param tableList List to track more extreme tables.
     * @return the test statistic
     */
    private double findExtremeTables(int a, int b, XYList tableList) {
        final int m = tableList.getMaxX();
        final int n = tableList.getMaxY();
        switch (method) {
        case Z_POOLED:
            return findExtremeTablesZ(a, b, m, n, true, tableList);
        case Z_UNPOOLED:
            return findExtremeTablesZ(a, b, m, n, false, tableList);
        case BOSCHLOO:
            return findExtremeTablesBoschloo(a, b, m, n, tableList);
        default:
            throw new IllegalStateException(String.valueOf(method));
        }
    }

    /**
     * Compute the statistic from a Z-test.
     *
     * @param a Observed value for a.
     * @param b Observed value for b.
     * @param m Column sum m.
     * @param n Column sum n.
     * @param pooled true to use a pooled variance.
     * @return z
     */
    private static double statisticZ(int a, int b, int m, int n, boolean pooled) {
        final double p0 = (double) a / m;
        final double p1 = (double) b / n;
        // Avoid NaN generation 0 / 0 when the variance is 0
        if (p0 != p1) {
            final double variance;
            if (pooled) {
                // Integer sums will not overflow
                final double p = (double) (a + b) / (m + n);
                variance = p * (1 - p) * (1.0 / m + 1.0 / n);
            } else {
                variance = p0 * (1 - p0) / m + p1 * (1 - p1) / n;
            }
            return (p0 - p1) / Math.sqrt(variance);
        }
        return 0;
    }

    /**
     * Find all tables that are as or more extreme than the observed table using the Z statistic.
     *
     * @param a Observed value for a.
     * @param b Observed value for b.
     * @param m Column sum m.
     * @param n Column sum n.
     * @param pooled true to use a pooled variance.
     * @param tableList List to track more extreme tables.
     * @return observed z
     */
    private double findExtremeTablesZ(int a, int b, int m, int n, boolean pooled, XYList tableList) {
        final double statistic = statisticZ(a, b, m, n, pooled);
        // Identify more extreme tables using the alternative hypothesis
        final DoublePredicate test;
        if (alternative == AlternativeHypothesis.GREATER_THAN) {
            test = z -> z >= statistic;
        } else if (alternative == AlternativeHypothesis.LESS_THAN) {
            test = z -> z <= statistic;
        } else {
            // two-sided
            if (statistic == 0) {
                // Early exit: all tables are as extreme
                return 0;
            }
            final double za = Math.abs(statistic);
            test = z -> Math.abs(z) >= za;
        }
        // Precompute factors
        final double mn = (double) m + n;
        final double norm = 1.0 / m + 1.0 / n;
        double z;
        // Process all possible tables
        for (int i = 0; i <= m; i++) {
            final double p0 = (double) i / m;
            final double vp0 = p0 * (1 - p0) / m;
            for (int j = 0; j <= n; j++) {
                final double p1 = (double) j / n;
                // Avoid NaN generation 0 / 0 when the variance is 0
                if (p0 == p1) {
                    z = 0;
                } else {
                    final double variance;
                    if (pooled) {
                        // Integer sums will not overflow
                        final double p = (i + j) / mn;
                        variance = p * (1 - p) * norm;
                    } else {
                        variance = vp0 + p1 * (1 - p1) / n;
                    }
                    z = (p0 - p1) / Math.sqrt(variance);
                }
                if (test.test(z)) {
                    tableList.add(i, j);
                }
            }
        }
        return statistic;
    }

    /**
     * Compute the statistic using Fisher's p-value (also known as Boschloo's test).
     *
     * @param a Observed value for a.
     * @param b Observed value for b.
     * @param m Column sum m.
     * @param n Column sum n.
     * @return p-value
     */
    private double statisticBoschloo(int a, int b, int m, int n) {
        final int nn = m + n;
        final int k = a + b;
        // Re-use the cached Hypergeometric implementation to allow the value
        // to be identical for the statistic and test methods.
        final Hypergeom dist = new Hypergeom(nn, k, m);
        if (alternative == AlternativeHypothesis.GREATER_THAN) {
            return dist.sf(a - 1);
        } else if (alternative == AlternativeHypothesis.LESS_THAN) {
            return dist.cdf(a);
        }
        // two-sided: Find all i where Pr(X = i) <= Pr(X = a) and sum them.
        return statisticBoschlooTwoSided(dist, a);
    }

    /**
     * Compute the two-sided statistic using Fisher's p-value (also known as Boschloo's test).
     *
     * @param distribution Hypergeometric distribution.
     * @param k Observed value.
     * @return p-value
     */
    private static double statisticBoschlooTwoSided(Hypergeom distribution, int k) {
        // two-sided: Find all i where Pr(X = i) <= Pr(X = k) and sum them.
        // Logic is the same as FisherExactTest but using the probability (PMF), which
        // is cached, rather than the logProbability.
        final double pk = distribution.pmf(k);

        final int m1 = distribution.getLowerMode();
        final int m2 = distribution.getUpperMode();
        if (k < m1) {
            // Lower half = cdf(k)
            // Find upper half. As k < lower mode i should never
            // reach the lower mode based on the probability alone.
            // Bracket with the upper mode.
            final int i = Searches.searchDescending(m2, distribution.getSupportUpperBound(), pk,
                distribution::pmf);
            return distribution.cdf(k) +
                   distribution.sf(i - 1);
        } else if (k > m2) {
            // Upper half = sf(k - 1)
            // Find lower half. As k > upper mode i should never
            // reach the upper mode based on the probability alone.
            // Bracket with the lower mode.
            final int i = Searches.searchAscending(distribution.getSupportLowerBound(), m1, pk,
                distribution::pmf);
            return distribution.cdf(i) +
                   distribution.sf(k - 1);
        }
        // k == mode
        // Edge case where the sum of probabilities will be either
        // 1 or 1 - Pr(X = mode) where mode != k
        final double pm = distribution.pmf(k == m1 ? m2 : m1);
        return pm > pk ? 1 - pm : 1;
    }

    /**
     * Find all tables that are as or more extreme than the observed table using
     * Fisher's p-value as the statistic (also known as Boschloo's test).
     *
     * @param a Observed value for a.
     * @param b Observed value for b.
     * @param m Column sum m.
     * @param n Column sum n.
     * @param tableList List to track more extreme tables.
     * @return observed p-value
     */
    private double findExtremeTablesBoschloo(int a, int b, int m, int n, XYList tableList) {
        final double statistic = statisticBoschloo(a, b, m, n);

        // Function to compute the statistic
        final BoschlooStatistic func;
        if (alternative == AlternativeHypothesis.GREATER_THAN) {
            func = (dist, x) -> dist.sf(x - 1);
        } else if (alternative == AlternativeHypothesis.LESS_THAN) {
            func = Hypergeom::cdf;
        } else {
            func = UnconditionedExactTest::statisticBoschlooTwoSided;
        }

        // All tables are: 0 <= i <= m by 0 <= j <= n
        // Diagonal (upper-left to lower-right) strips of the possible
        // tables use the same hypergeometric distribution
        // (i.e. i+j == number of successes). To enumerate all requires
        // using the full range of all distributions: 0 <= i+j <= m+n.
        // Note the column sum m is fixed.
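        // For example (a hypothetical small case), with m = 2 and n = 2 the strip
        // k = i + j = 2 contains the tables (0,2), (1,1) and (2,0), and all three
        // are evaluated with the single distribution Hypergeom(4, 2, 2).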
        final int mn = m + n;
        for (int k = 0; k <= mn; k++) {
            final Hypergeom dist = new Hypergeom(mn, k, m);
            final int lo = dist.getSupportLowerBound();
            final int hi = dist.getSupportUpperBound();
            for (int i = lo; i <= hi; i++) {
                if (func.value(dist, i) <= statistic) {
                    // j = k - i
                    tableList.add(i, k - i);
                }
            }
        }
        return statistic;
    }

    /**
     * Compute the nuisance parameter and p-value for the binomial model given the list
     * of possible tables.
     *
     * <p>The current method enumerates an initial set of points and stores local
     * extrema as candidates. Any candidate within 2% of the best is optionally
     * optimized; this is limited to the top 3 candidates. These settings
     * could be exposed as configurable options. Currently only the choice to optimize
     * or not is exposed.
     *
     * @param tableList List of tables.
     * @return [nuisance parameter, p-value]
     */
    private double[] computePValue(XYList tableList) {
        final DoubleUnaryOperator func = createBinomialModel(tableList);

        // Enumerate the range [LOWER, 1-LOWER] and save the best points for optimization
        final Candidates minima = new Candidates(MAX_CANDIDATES, MINIMA_EPS);
        final int n = points - 1;
        final double inc = (1.0 - 2 * LOWER_BOUND) / n;
        // Moving window of 3 values to identify minima.
        // px holds the position of the previous evaluated point.
        double v2 = 0;
        double v3 = func.applyAsDouble(LOWER_BOUND);
        double px = LOWER_BOUND;
        for (int i = 1; i < n; i++) {
            final double x = LOWER_BOUND + i * inc;
            final double v1 = v2;
            v2 = v3;
            v3 = func.applyAsDouble(x);
            addCandidate(minima, v1, v2, v3, px);
            px = x;
        }
        // Add the upper bound
        final double x = 1 - LOWER_BOUND;
        final double vn = func.applyAsDouble(x);
        addCandidate(minima, v2, v3, vn, px);
        addCandidate(minima, v3, vn, 0, x);

        final double[] min = minima.getMinimum();

        // Optionally optimize the best point(s) (if not already optimal)
        if (optimize && min[1] > -1) {
            final BrentOptimizer opt = new BrentOptimizer(SOLVER_RELATIVE_EPS, Double.MIN_VALUE);
            final BracketFinder bf = new BracketFinder();
            minima.forEach(candidate -> {
                double a = candidate[0];
                final double fa;
                // Attempt to bracket the minimum. Use an initial second point placed relative to
                // the size of the interval: [x - increment, x + increment].
                // If a < 0.5 then add a small delta; otherwise subtract the delta.
                final double b = a - Math.copySign(inc * INC_FRACTION, a - 0.5);
                if (bf.search(func, a, b, 0, 1)) {
                    // The bracket a < b < c must have f(b) < min(f(a), f(c))
                    final PointValuePair p = opt.optimize(func, bf.getLo(), bf.getHi(), bf.getMid(), bf.getFMid());
                    a = p.getPoint();
                    fa = p.getValue();
                } else {
                    // Mid-point is at one of the bounds (i.e. is 0 or 1)
                    a = bf.getMid();
                    fa = bf.getFMid();
                }
                if (fa < min[1]) {
                    min[0] = a;
                    min[1] = fa;
                }
            });
        }
        // Reverse the sign of the p-value to create a maximum.
        // Note that due to the summation the p-value can be above 1 so we clip the final result.
        // Note: Apply max then reverse sign. This will pass through spurious NaN values if
        // the p-value computation produced all NaNs.
        min[1] = -Math.max(-1, min[1]);
        return min;
    }

    /**
     * Creates the binomial model p-value function for the nuisance parameter.
     * Note: This function computes the negative p-value so is suitable for
     * optimization by a search for a minimum.
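     *
     * <p>For each stored table \( (i,j) \) the returned operator sums (in negated form)
     * terms computed in log space:
     *
     * <p>\[ \exp \left( \ln \binom{m}{i} + \ln \binom{n}{j} + (i+j) \ln \pi + (m+n-i-j) \ln(1-\pi) \right) \]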
     *
     * @param tableList List of tables.
     * @return the function
     */
    private static DoubleUnaryOperator createBinomialModel(XYList tableList) {
        final int m = tableList.getMaxX();
        final int n = tableList.getMaxY();
        final int mn = m + n;
        // Compute the probability using logs
        final double[] c = new double[tableList.size()];
        final int[] ij = new int[tableList.size()];
        final int width = tableList.getWidth();

        // Compute the log binomial dynamically for a small number of values
        final IntToDoubleFunction binomM;
        final IntToDoubleFunction binomN;
        if (tableList.size() < mn) {
            binomM = k -> LogBinomialCoefficient.value(m, k);
            binomN = k -> LogBinomialCoefficient.value(n, k);
        } else {
            // Pre-compute all values
            binomM = createLogBinomialCoefficients(m);
            binomN = m == n ? binomM : createLogBinomialCoefficients(n);
        }

        // Handle special cases i+j == 0 and i+j == m+n.
        // These will occur only once, if at all. Mark if they occur.
        int flag = 0;
        int j = 0;
        for (int i = 0; i < c.length; i++) {
            final int index = tableList.get(i);
            final int x = index % width;
            final int y = index / width;
            final int xy = x + y;
            if (xy == 0) {
                flag |= 1;
            } else if (xy == mn) {
                flag |= 2;
            } else {
                ij[j] = xy;
                c[j] = binomM.applyAsDouble(x) + binomN.applyAsDouble(y);
                j++;
            }
        }

        final int size = j;
        final boolean ij0 = (flag & 1) != 0;
        final boolean ijmn = (flag & 2) != 0;
        return pi -> {
            final double logp = Math.log(pi);
            final double log1mp = Math.log1p(-pi);
            double sum = 0;
            for (int i = 0; i < size; i++) {
                // binom(m, i) * binom(n, j) * pi^(i+j) * (1-pi)^(m+n-i-j)
                sum += Math.exp(ij[i] * logp + (mn - ij[i]) * log1mp + c[i]);
            }
            // Add the simplified terms where the binomial is 1.0 and one power is x^0 == 1.0.
            // This avoids 0 * log(x) generating NaN when x is 0 in the case where pi was 0 or 1.
            // Reuse exp (not pow) to support pi approaching 0 or 1.
            if (ij0) {
                // pow(1-pi, mn)
                sum += Math.exp(mn * log1mp);
            }
            if (ijmn) {
                // pow(pi, mn)
                sum += Math.exp(mn * logp);
            }
            // The optimizer minimises the function so this returns -p.
            return -sum;
        };
    }

    /**
     * Create the natural logarithm of the binomial coefficient for all {@code k = [0, n]}.
     *
     * @param n Limit N.
     * @return ln binom(n, k)
     */
    private static IntToDoubleFunction createLogBinomialCoefficients(int n) {
        final double[] binom = new double[n + 1];
        // Exploit symmetry.
        // ignore: binom(n, 0) == binom(n, n) == 1
        int j = n - 1;
        for (int i = 1; i <= j; i++, j--) {
            binom[i] = binom[j] = LogBinomialCoefficient.value(n, i);
        }
        return k -> binom[k];
    }

    /**
     * Add point 2 to the list of minima if neither neighbour value is lower.
     * <pre>
     * !(v1 < v2 || v3 < v2)
     * </pre>
     *
     * @param minima Candidate minima.
     * @param v1 First point function value.
     * @param v2 Second point function value.
     * @param v3 Third point function value.
     * @param x2 Second point.
     */
    private void addCandidate(Candidates minima, double v1, double v2, double v3, double x2) {
        final double min = v1 < v3 ? v1 : v3;
        if (min < v2) {
            // Lower neighbour(s)
            return;
        }
        // Add the candidate. This could be NaN but the candidate list handles this by storing
        // NaN only when no non-NaN values have been observed.
        minima.add(x2, v2);
    }

    /**
     * Check the input is a 2-by-2 contingency table.
     *
     * @param table Contingency table.
     * @throws IllegalArgumentException if the {@code table} is not a 2-by-2 table; any
     * table entry is negative; any column sum is zero; the table sum is zero or not an
     * integer; or the number of possible tables exceeds the maximum array capacity.
     */
    private static void checkTable(int[][] table) {
        Arguments.checkTable(table);
        // Column sums must all be positive
        final int a = table[0][0];
        final int c = table[1][0];
        // checkTable has validated the total sum is < 2^31
        final int m = a + c;
        if (m == 0) {
            throw new InferenceException(InferenceException.ZERO_AT, COLUMN_SUM, 0);
        }
        final int b = table[0][1];
        final int d = table[1][1];
        final int n = b + d;
        if (n == 0) {
            throw new InferenceException(InferenceException.ZERO_AT, COLUMN_SUM, 1);
        }
        // Total possible tables must be a size we can track in an array (to compute the p-value)
        final long size = (m + 1L) * (n + 1L);
        if (size > MAX_TABLES) {
            throw new InferenceException(InferenceException.X_GT_Y, size, MAX_TABLES);
        }
    }
}