001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.commons.statistics.distribution; 019 020import java.util.function.DoubleSupplier; 021import org.apache.commons.numbers.gamma.Erf; 022import org.apache.commons.numbers.gamma.ErfDifference; 023import org.apache.commons.numbers.gamma.Erfcx; 024import org.apache.commons.rng.UniformRandomProvider; 025import org.apache.commons.rng.sampling.distribution.ZigguratSampler; 026 027/** 028 * Implementation of the truncated normal distribution. 029 * 030 * <p>The probability density function of \( X \) is: 031 * 032 * <p>\[ f(x;\mu,\sigma,a,b) = \frac{1}{\sigma}\,\frac{\phi(\frac{x - \mu}{\sigma})}{\Phi(\frac{b - \mu}{\sigma}) - \Phi(\frac{a - \mu}{\sigma}) } \] 033 * 034 * <p>for \( \mu \) mean of the parent normal distribution, 035 * \( \sigma \) standard deviation of the parent normal distribution, 036 * \( -\infty \le a \lt b \le \infty \) the truncation interval, and 037 * \( x \in [a, b] \), where \( \phi \) is the probability 038 * density function of the standard normal distribution and \( \Phi \) 039 * is its cumulative distribution function. 040 * 041 * @see <a href="https://en.wikipedia.org/wiki/Truncated_normal_distribution"> 042 * Truncated normal distribution (Wikipedia)</a> 043 */ 044public final class TruncatedNormalDistribution extends AbstractContinuousDistribution { 045 046 /** The max allowed value for x where (x*x) will not overflow. 047 * This is a limit on computation of the moments of the truncated normal 048 * as some calculations assume x*x is finite. Value is sqrt(MAX_VALUE). */ 049 private static final double MAX_X = 0x1.fffffffffffffp511; 050 051 /** The min allowed probability range of the parent normal distribution. 052 * Set to 0.0. This may be too low for accurate usage. It is a signal that 053 * the truncation is invalid. */ 054 private static final double MIN_P = 0.0; 055 056 /** sqrt(2). */ 057 private static final double ROOT2 = Constants.ROOT_TWO; 058 /** Normalisation constant 2 / sqrt(2 pi) = sqrt(2 / pi). */ 059 private static final double ROOT_2_PI = Constants.ROOT_TWO_DIV_PI; 060 /** Normalisation constant sqrt(2 pi) / 2 = sqrt(pi / 2). */ 061 private static final double ROOT_PI_2 = Constants.ROOT_PI_DIV_TWO; 062 063 /** 064 * The threshold to switch to a rejection sampler. When the truncated 065 * distribution covers more than this fraction of the CDF then rejection 066 * sampling will be more efficient than inverse CDF sampling. Performance 067 * benchmarks indicate that a normalized Gaussian sampler is up to 10 times 068 * faster than inverse transform sampling using a fast random generator. See 069 * STATISTICS-55. 070 */ 071 private static final double REJECTION_THRESHOLD = 0.2; 072 073 /** Parent normal distribution. */ 074 private final NormalDistribution parentNormal; 075 /** Lower bound of this distribution. */ 076 private final double lower; 077 /** Upper bound of this distribution. */ 078 private final double upper; 079 080 /** Stored value of {@code parentNormal.probability(lower, upper)}. This is used to 081 * normalise the probability computations. */ 082 private final double cdfDelta; 083 /** log(cdfDelta). */ 084 private final double logCdfDelta; 085 /** Stored value of {@code parentNormal.cumulativeProbability(lower)}. Used to map 086 * a probability into the range of the parent normal distribution. */ 087 private final double cdfAlpha; 088 /** Stored value of {@code parentNormal.survivalProbability(upper)}. Used to map 089 * a probability into the range of the parent normal distribution. */ 090 private final double sfBeta; 091 092 /** 093 * @param parent Parent distribution. 094 * @param z Probability of the parent distribution for {@code [lower, upper]}. 095 * @param lower Lower bound (inclusive) of the distribution, can be {@link Double#NEGATIVE_INFINITY}. 096 * @param upper Upper bound (inclusive) of the distribution, can be {@link Double#POSITIVE_INFINITY}. 097 */ 098 private TruncatedNormalDistribution(NormalDistribution parent, double z, double lower, double upper) { 099 this.parentNormal = parent; 100 this.lower = lower; 101 this.upper = upper; 102 103 cdfDelta = z; 104 logCdfDelta = Math.log(cdfDelta); 105 // Used to map the inverse probability. 106 cdfAlpha = parentNormal.cumulativeProbability(lower); 107 sfBeta = parentNormal.survivalProbability(upper); 108 } 109 110 /** 111 * Creates a truncated normal distribution. 112 * 113 * <p>Note that the {@code mean} and {@code sd} is of the parent normal distribution, 114 * and not the true mean and standard deviation of the truncated normal distribution. 115 * The {@code lower} and {@code upper} bounds define the truncation of the parent 116 * normal distribution. 117 * 118 * @param mean Mean for the parent distribution. 119 * @param sd Standard deviation for the parent distribution. 120 * @param lower Lower bound (inclusive) of the distribution, can be {@link Double#NEGATIVE_INFINITY}. 121 * @param upper Upper bound (inclusive) of the distribution, can be {@link Double#POSITIVE_INFINITY}. 122 * @return the distribution 123 * @throws IllegalArgumentException if {@code sd <= 0}; if {@code lower >= upper}; or if 124 * the truncation covers no probability range in the parent distribution. 125 */ 126 public static TruncatedNormalDistribution of(double mean, double sd, double lower, double upper) { 127 if (sd <= 0) { 128 throw new DistributionException(DistributionException.NOT_STRICTLY_POSITIVE, sd); 129 } 130 if (lower >= upper) { 131 throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GTE_HIGH, lower, upper); 132 } 133 134 // Use an instance for the parent normal distribution to maximise accuracy 135 // in range computations using the error function 136 final NormalDistribution parent = NormalDistribution.of(mean, sd); 137 138 // If there is no computable range then raise an exception. 139 final double z = parent.probability(lower, upper); 140 if (z <= MIN_P) { 141 // Map the bounds to a standard normal distribution for the message 142 final double a = (lower - mean) / sd; 143 final double b = (upper - mean) / sd; 144 throw new DistributionException( 145 "Excess truncation of standard normal : CDF(%s, %s) = %s", a, b, z); 146 } 147 148 // Here we have a meaningful truncation. Note that excess truncation may not be optimal. 149 // For example truncation close to zero where the PDF is constant can be approximated 150 // using a uniform distribution. 151 152 return new TruncatedNormalDistribution(parent, z, lower, upper); 153 } 154 155 /** {@inheritDoc} */ 156 @Override 157 public double density(double x) { 158 if (x < lower || x > upper) { 159 return 0; 160 } 161 return parentNormal.density(x) / cdfDelta; 162 } 163 164 /** {@inheritDoc} */ 165 @Override 166 public double probability(double x0, double x1) { 167 if (x0 > x1) { 168 throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH, 169 x0, x1); 170 } 171 return parentNormal.probability(clipToRange(x0), clipToRange(x1)) / cdfDelta; 172 } 173 174 /** {@inheritDoc} */ 175 @Override 176 public double logDensity(double x) { 177 if (x < lower || x > upper) { 178 return Double.NEGATIVE_INFINITY; 179 } 180 return parentNormal.logDensity(x) - logCdfDelta; 181 } 182 183 /** {@inheritDoc} */ 184 @Override 185 public double cumulativeProbability(double x) { 186 if (x <= lower) { 187 return 0; 188 } else if (x >= upper) { 189 return 1; 190 } 191 return parentNormal.probability(lower, x) / cdfDelta; 192 } 193 194 /** {@inheritDoc} */ 195 @Override 196 public double survivalProbability(double x) { 197 if (x <= lower) { 198 return 1; 199 } else if (x >= upper) { 200 return 0; 201 } 202 return parentNormal.probability(x, upper) / cdfDelta; 203 } 204 205 /** {@inheritDoc} */ 206 @Override 207 public double inverseCumulativeProbability(double p) { 208 ArgumentUtils.checkProbability(p); 209 // Exact bound 210 if (p == 0) { 211 return lower; 212 } else if (p == 1) { 213 return upper; 214 } 215 // Linearly map p to the range [lower, upper] 216 final double x = parentNormal.inverseCumulativeProbability(cdfAlpha + p * cdfDelta); 217 return clipToRange(x); 218 } 219 220 /** {@inheritDoc} */ 221 @Override 222 public double inverseSurvivalProbability(double p) { 223 ArgumentUtils.checkProbability(p); 224 // Exact bound 225 if (p == 1) { 226 return lower; 227 } else if (p == 0) { 228 return upper; 229 } 230 // Linearly map p to the range [lower, upper] 231 final double x = parentNormal.inverseSurvivalProbability(sfBeta + p * cdfDelta); 232 return clipToRange(x); 233 } 234 235 /** {@inheritDoc} */ 236 @Override 237 public Sampler createSampler(UniformRandomProvider rng) { 238 // If the truncation covers a reasonable amount of the normal distribution 239 // then a rejection sampler can be used. 240 double threshold = REJECTION_THRESHOLD; 241 // If the truncation is entirely in the upper or lower half then adjust the 242 // threshold as twice the samples can be used 243 if (lower >= 0 || upper <= 0) { 244 threshold *= 0.5; 245 } 246 247 if (cdfDelta > threshold) { 248 // Create the rejection sampler 249 final ZigguratSampler.NormalizedGaussian sampler = ZigguratSampler.NormalizedGaussian.of(rng); 250 final DoubleSupplier gen; 251 // Use mirroring if possible 252 if (lower >= 0) { 253 // Return the upper-half of the Gaussian 254 gen = () -> Math.abs(sampler.sample()); 255 } else if (upper <= 0) { 256 // Return the lower-half of the Gaussian 257 gen = () -> -Math.abs(sampler.sample()); 258 } else { 259 // Return the full range of the Gaussian 260 gen = sampler::sample; 261 } 262 // Map the bounds to a standard normal distribution 263 final double u = parentNormal.getMean(); 264 final double s = parentNormal.getStandardDeviation(); 265 final double a = (lower - u) / s; 266 final double b = (upper - u) / s; 267 // Sample in [a, b] using rejection 268 return () -> { 269 double x = gen.getAsDouble(); 270 while (x < a || x > b) { 271 x = gen.getAsDouble(); 272 } 273 // Avoid floating-point error when mapping back 274 return clipToRange(u + x * s); 275 }; 276 } 277 278 // Default to an inverse CDF sampler 279 return super.createSampler(rng); 280 } 281 282 /** 283 * {@inheritDoc} 284 * 285 * <p>Represents the true mean of the truncated normal distribution rather 286 * than the parent normal distribution mean. 287 * 288 * <p>For \( \mu \) mean of the parent normal distribution, 289 * \( \sigma \) standard deviation of the parent normal distribution, and 290 * \( a \lt b \) the truncation interval of the parent normal distribution, the mean is: 291 * 292 * <p>\[ \mu + \frac{\phi(a)-\phi(b)}{\Phi(b) - \Phi(a)}\sigma \] 293 * 294 * <p>where \( \phi \) is the probability density function of the standard normal distribution 295 * and \( \Phi \) is its cumulative distribution function. 296 */ 297 @Override 298 public double getMean() { 299 final double u = parentNormal.getMean(); 300 final double s = parentNormal.getStandardDeviation(); 301 final double a = (lower - u) / s; 302 final double b = (upper - u) / s; 303 return u + moment1(a, b) * s; 304 } 305 306 /** 307 * {@inheritDoc} 308 * 309 * <p>Represents the true variance of the truncated normal distribution rather 310 * than the parent normal distribution variance. 311 * 312 * <p>For \( \mu \) mean of the parent normal distribution, 313 * \( \sigma \) standard deviation of the parent normal distribution, and 314 * \( a \lt b \) the truncation interval of the parent normal distribution, the variance is: 315 * 316 * <p>\[ \sigma^2 \left[1 + \frac{a\phi(a)-b\phi(b)}{\Phi(b) - \Phi(a)} - 317 * \left( \frac{\phi(a)-\phi(b)}{\Phi(b) - \Phi(a)} \right)^2 \right] \] 318 * 319 * <p>where \( \phi \) is the probability density function of the standard normal distribution 320 * and \( \Phi \) is its cumulative distribution function. 321 */ 322 @Override 323 public double getVariance() { 324 final double u = parentNormal.getMean(); 325 final double s = parentNormal.getStandardDeviation(); 326 final double a = (lower - u) / s; 327 final double b = (upper - u) / s; 328 return variance(a, b) * s * s; 329 } 330 331 /** 332 * {@inheritDoc} 333 * 334 * <p>The lower bound of the support is equal to the lower bound parameter 335 * of the distribution. 336 */ 337 @Override 338 public double getSupportLowerBound() { 339 return lower; 340 } 341 342 /** 343 * {@inheritDoc} 344 * 345 * <p>The upper bound of the support is equal to the upper bound parameter 346 * of the distribution. 347 */ 348 @Override 349 public double getSupportUpperBound() { 350 return upper; 351 } 352 353 /** 354 * Clip the value to the range [lower, upper]. 355 * This is used to handle floating-point error at the support bound. 356 * 357 * @param x Value x 358 * @return x clipped to the range 359 */ 360 private double clipToRange(double x) { 361 return clip(x, lower, upper); 362 } 363 364 /** 365 * Clip the value to the range [lower, upper]. 366 * 367 * @param x Value x 368 * @param lower Lower bound (inclusive) 369 * @param upper Upper bound (inclusive) 370 * @return x clipped to the range 371 */ 372 private static double clip(double x, double lower, double upper) { 373 if (x <= lower) { 374 return lower; 375 } 376 return x < upper ? x : upper; 377 } 378 379 // Calculation of variance and mean can suffer from cancellation. 380 // 381 // Use formulas from Jorge Fernandez-de-Cossio-Diaz adapted under the 382 // terms of the MIT "Expat" License (see NOTICE and LICENSE). 383 // 384 // These formulas use the complementary error function 385 // erfcx(z) = erfc(z) * exp(z^2) 386 // This avoids computation of exp terms for the Gaussian PDF and then 387 // dividing by the error functions erf or erfc: 388 // exp(-0.5*x*x) / erfc(x / sqrt(2)) == 1 / erfcx(x / sqrt(2)) 389 // At large z the erfcx function is computable but exp(-0.5*z*z) and 390 // erfc(z) are zero. Use of these formulas allows computation of the 391 // mean and variance for the usable range of the truncated distribution 392 // (cdf(a, b) != 0). The variance is not accurate when it approaches 393 // machine epsilon (2^-52) at extremely narrow truncations and the 394 // computation -> 0. 395 // 396 // See: https://github.com/cossio/TruncatedNormal.jl 397 398 /** 399 * Compute the first moment (mean) of the truncated standard normal distribution. 400 * 401 * <p>Assumes {@code a <= b}. 402 * 403 * @param a Lower bound 404 * @param b Upper bound 405 * @return the first moment 406 */ 407 static double moment1(double a, double b) { 408 // Assume a <= b 409 if (a == b) { 410 return a; 411 } 412 if (Math.abs(a) > Math.abs(b)) { 413 // Subtract from zero to avoid generating -0.0 414 return 0 - moment1(-b, -a); 415 } 416 417 // Here: 418 // |a| <= |b| 419 // a < b 420 // 0 < b 421 422 if (a <= -MAX_X) { 423 // No truncation 424 return 0; 425 } 426 if (b >= MAX_X) { 427 // One-sided truncation 428 return ROOT_2_PI / Erfcx.value(a / ROOT2); 429 } 430 431 // pdf = exp(-0.5*x*x) / sqrt(2*pi) 432 // cdf = erfc(-x/sqrt(2)) / 2 433 // Compute: 434 // -(pdf(b) - pdf(a)) / cdf(b, a) 435 // Note: 436 // exp(-0.5*b*b) - exp(-0.5*a*a) 437 // Use cancellation of powers: 438 // exp(-0.5*(b*b-a*a)) * exp(-0.5*a*a) - exp(-0.5*a*a) 439 // expm1(-0.5*(b*b-a*a)) * exp(-0.5*a*a) 440 441 // dx = -0.5*(b*b-a*a) 442 final double dx = 0.5 * (b + a) * (b - a); 443 final double m; 444 if (a <= 0) { 445 // Opposite signs 446 m = ROOT_2_PI * -Math.expm1(-dx) * Math.exp(-0.5 * a * a) / ErfDifference.value(a / ROOT2, b / ROOT2); 447 } else { 448 final double z = Math.exp(-dx) * Erfcx.value(b / ROOT2) - Erfcx.value(a / ROOT2); 449 if (z == 0) { 450 // Occurs when a and b have large magnitudes and are very close 451 return (a + b) * 0.5; 452 } 453 m = ROOT_2_PI * Math.expm1(-dx) / z; 454 } 455 456 // Clip to the range 457 return clip(m, a, b); 458 } 459 460 /** 461 * Compute the second moment of the truncated standard normal distribution. 462 * 463 * <p>Assumes {@code a <= b}. 464 * 465 * @param a Lower bound 466 * @param b Upper bound 467 * @return the first moment 468 */ 469 private static double moment2(double a, double b) { 470 // Assume a < b. 471 // a == b is handled in the variance method 472 if (Math.abs(a) > Math.abs(b)) { 473 return moment2(-b, -a); 474 } 475 476 // Here: 477 // |a| <= |b| 478 // a < b 479 // 0 < b 480 481 if (a <= -MAX_X) { 482 // No truncation 483 return 1; 484 } 485 if (b >= MAX_X) { 486 // One-sided truncation. 487 // For a -> inf : moment2 -> a*a 488 // This occurs when erfcx(z) is approximated by (1/sqrt(pi)) / z and terms 489 // cancel. z > 6.71e7, a > 9.49e7 490 return 1 + ROOT_2_PI * a / Erfcx.value(a / ROOT2); 491 } 492 493 // pdf = exp(-0.5*x*x) / sqrt(2*pi) 494 // cdf = erfc(-x/sqrt(2)) / 2 495 // Compute: 496 // 1 - (b*pdf(b) - a*pdf(a)) / cdf(b, a) 497 // = (cdf(b, a) - b*pdf(b) -a*pdf(a)) / cdf(b, a) 498 499 // Note: 500 // For z -> 0: 501 // sqrt(pi / 2) * erf(z / sqrt(2)) -> z 502 // z * Math.exp(-0.5 * z * z) -> z 503 // Both computations below have cancellation as b -> 0 and the 504 // second moment is not computable as the fraction P/Q 505 // since P < ulp(Q). This always occurs when b < MIN_X 506 // if MIN_X is set at the point where 507 // exp(-0.5 * z * z) / sqrt(2 pi) == 1 / sqrt(2 pi). 508 // This is JDK dependent due to variations in Math.exp. 509 // For b < MIN_X the second moment can be approximated using 510 // a uniform distribution: (b^3 - a^3) / (3b - 3a). 511 // In practice it also occurs when b > MIN_X since any a < MIN_X 512 // is effectively zero for part of the computation. A 513 // threshold to transition to a uniform distribution 514 // approximation is a compromise. Also note it will not 515 // correct computation when (b-a) is small and is far from 0. 516 // Thus the second moment is left to be inaccurate for 517 // small ranges (b-a) and the variance -> 0 when the true 518 // variance is close to or below machine epsilon. 519 520 double m; 521 522 if (a <= 0) { 523 // Opposite signs 524 final double ea = ROOT_PI_2 * Erf.value(a / ROOT2); 525 final double eb = ROOT_PI_2 * Erf.value(b / ROOT2); 526 final double fa = ea - a * Math.exp(-0.5 * a * a); 527 final double fb = eb - b * Math.exp(-0.5 * b * b); 528 // Assume fb >= fa && eb >= ea 529 // If fb <= fa this is a tiny range around 0 530 m = (fb - fa) / (eb - ea); 531 // Clip to the range 532 m = clip(m, 0, 1); 533 } else { 534 final double dx = 0.5 * (b + a) * (b - a); 535 final double ex = Math.exp(-dx); 536 final double ea = ROOT_PI_2 * Erfcx.value(a / ROOT2); 537 final double eb = ROOT_PI_2 * Erfcx.value(b / ROOT2); 538 final double fa = ea + a; 539 final double fb = eb + b; 540 m = (fa - fb * ex) / (ea - eb * ex); 541 // Clip to the range 542 m = clip(m, a * a, b * b); 543 } 544 return m; 545 } 546 547 /** 548 * Compute the variance of the truncated standard normal distribution. 549 * 550 * <p>Assumes {@code a <= b}. 551 * 552 * @param a Lower bound 553 * @param b Upper bound 554 * @return the first moment 555 */ 556 static double variance(double a, double b) { 557 if (a == b) { 558 return 0; 559 } 560 561 final double m1 = moment1(a, b); 562 double m2 = moment2(a, b); 563 // variance = m2 - m1*m1 564 // rearrange x^2 - y^2 as (x-y)(x+y) 565 m2 = Math.sqrt(m2); 566 final double variance = (m2 - m1) * (m2 + m1); 567 568 // Detect floating-point error. 569 if (variance >= 1) { 570 // Note: 571 // Extreme truncations in the tails can compute a variance above 1, 572 // for example if m2 is infinite: m2 - m1*m1 > 1 573 // Detect no truncation as the terms a and b lie far either side of zero; 574 // otherwise return 0 to indicate very small unknown variance. 575 return a < -1 && b > 1 ? 1 : 0; 576 } else if (variance <= 0) { 577 // Floating-point error can create negative variance so return 0. 578 return 0; 579 } 580 581 return variance; 582 } 583}