001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.statistics.distribution;
019
020import org.apache.commons.numbers.gamma.ErfDifference;
021import org.apache.commons.numbers.gamma.Erfc;
022import org.apache.commons.numbers.gamma.InverseErfc;
023import org.apache.commons.rng.UniformRandomProvider;
024import org.apache.commons.rng.sampling.distribution.GaussianSampler;
025import org.apache.commons.rng.sampling.distribution.ZigguratSampler;
026
027/**
028 * Implementation of the normal (Gaussian) distribution.
029 *
030 * <p>The probability density function of \( X \) is:
031 *
032 * <p>\[ f(x; \mu, \sigma) = \frac 1 {\sigma\sqrt{2\pi}} e^{-{\frac 1 2}\left( \frac{x-\mu}{\sigma} \right)^2 } \]
033 *
034 * <p>for \( \mu \) the mean,
035 * \( \sigma &gt; 0 \) the standard deviation, and
036 * \( x \in (-\infty, \infty) \).
037 *
038 * @see <a href="https://en.wikipedia.org/wiki/Normal_distribution">Normal distribution (Wikipedia)</a>
039 * @see <a href="https://mathworld.wolfram.com/NormalDistribution.html">Normal distribution (MathWorld)</a>
040 */
041public final class NormalDistribution extends AbstractContinuousDistribution {
042    /** Mean of this distribution. */
043    private final double mean;
044    /** Standard deviation of this distribution. */
045    private final double standardDeviation;
046    /** The value of {@code log(sd) + 0.5*log(2*pi)} stored for faster computation. */
047    private final double logStandardDeviationPlusHalfLog2Pi;
048    /**
049     * Standard deviation multiplied by sqrt(2).
050     * This is used to avoid a double division when computing the value passed to the
051     * error function:
052     * <pre>
053     *  ((x - u) / sd) / sqrt(2) == (x - u) / (sd * sqrt(2)).
054     *  </pre>
055     * <p>Note: Implementations may first normalise x and then divide by sqrt(2) resulting
056     * in differences due to rounding error that show increasingly large relative
057     * differences as the error function computes close to 0 in the extreme tail.
058     */
059    private final double sdSqrt2;
060    /**
061     * Standard deviation multiplied by sqrt(2 pi). Computed to high precision.
062     */
063    private final double sdSqrt2pi;
064
065    /**
066     * @param mean Mean for this distribution.
067     * @param sd Standard deviation for this distribution.
068     */
069    private NormalDistribution(double mean,
070                               double sd) {
071        this.mean = mean;
072        standardDeviation = sd;
073        logStandardDeviationPlusHalfLog2Pi = Math.log(sd) + Constants.HALF_LOG_TWO_PI;
074        // Minimise rounding error by computing sqrt(2 * sd * sd) exactly.
075        // Compute using extended precision with care to avoid over/underflow.
076        sdSqrt2 = ExtendedPrecision.sqrt2xx(sd);
077        // Compute sd * sqrt(2 * pi)
078        sdSqrt2pi = ExtendedPrecision.xsqrt2pi(sd);
079    }
080
081    /**
082     * Creates a normal distribution.
083     *
084     * @param mean Mean for this distribution.
085     * @param sd Standard deviation for this distribution.
086     * @return the distribution
087     * @throws IllegalArgumentException if {@code sd <= 0}.
088     */
089    public static NormalDistribution of(double mean,
090                                        double sd) {
091        if (sd > 0) {
092            return new NormalDistribution(mean, sd);
093        }
094        // zero, negative or nan
095        throw new DistributionException(DistributionException.NOT_STRICTLY_POSITIVE, sd);
096    }
097
098    /**
099     * Gets the standard deviation parameter of this distribution.
100     *
101     * @return the standard deviation.
102     */
103    public double getStandardDeviation() {
104        return standardDeviation;
105    }
106
107    /** {@inheritDoc} */
108    @Override
109    public double density(double x) {
110        final double z = (x - mean) / standardDeviation;
111        return ExtendedPrecision.expmhxx(z) / sdSqrt2pi;
112    }
113
114    /** {@inheritDoc} */
115    @Override
116    public double probability(double x0,
117                              double x1) {
118        if (x0 > x1) {
119            throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH,
120                                            x0, x1);
121        }
122        final double v0 = (x0 - mean) / sdSqrt2;
123        final double v1 = (x1 - mean) / sdSqrt2;
124        return 0.5 * ErfDifference.value(v0, v1);
125    }
126
127    /** {@inheritDoc} */
128    @Override
129    public double logDensity(double x) {
130        final double z = (x - mean) / standardDeviation;
131        return -0.5 * z * z - logStandardDeviationPlusHalfLog2Pi;
132    }
133
134    /** {@inheritDoc} */
135    @Override
136    public double cumulativeProbability(double x)  {
137        final double dev = x - mean;
138        return 0.5 * Erfc.value(-dev / sdSqrt2);
139    }
140
141    /** {@inheritDoc} */
142    @Override
143    public double survivalProbability(double x) {
144        final double dev = x - mean;
145        return 0.5 * Erfc.value(dev / sdSqrt2);
146    }
147
148    /** {@inheritDoc} */
149    @Override
150    public double inverseCumulativeProbability(double p) {
151        ArgumentUtils.checkProbability(p);
152        return mean - sdSqrt2 * InverseErfc.value(2 * p);
153    }
154
155    /** {@inheritDoc} */
156    @Override
157    public double inverseSurvivalProbability(double p) {
158        ArgumentUtils.checkProbability(p);
159        return mean + sdSqrt2 * InverseErfc.value(2 * p);
160    }
161
162    /** {@inheritDoc} */
163    @Override
164    public double getMean() {
165        return mean;
166    }
167
168    /**
169     * {@inheritDoc}
170     *
171     * <p>For standard deviation parameter \( \sigma \), the variance is \( \sigma^2 \).
172     */
173    @Override
174    public double getVariance() {
175        final double s = getStandardDeviation();
176        return s * s;
177    }
178
179    /**
180     * {@inheritDoc}
181     *
182     * <p>The lower bound of the support is always negative infinity.
183     *
184     * @return {@linkplain Double#NEGATIVE_INFINITY negative infinity}.
185     */
186    @Override
187    public double getSupportLowerBound() {
188        return Double.NEGATIVE_INFINITY;
189    }
190
191    /**
192     * {@inheritDoc}
193     *
194     * <p>The upper bound of the support is always positive infinity.
195     *
196     * @return {@linkplain Double#POSITIVE_INFINITY positive infinity}.
197     */
198    @Override
199    public double getSupportUpperBound() {
200        return Double.POSITIVE_INFINITY;
201    }
202
203    /** {@inheritDoc} */
204    @Override
205    public ContinuousDistribution.Sampler createSampler(final UniformRandomProvider rng) {
206        // Gaussian distribution sampler.
207        return GaussianSampler.of(ZigguratSampler.NormalizedGaussian.of(rng),
208                                  mean, standardDeviation)::sample;
209    }
210}