001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.statistics.distribution;
018
019import org.apache.commons.numbers.gamma.Erf;
020import org.apache.commons.numbers.gamma.ErfDifference;
021import org.apache.commons.numbers.gamma.Erfc;
022import org.apache.commons.numbers.gamma.InverseErf;
023import org.apache.commons.numbers.gamma.InverseErfc;
024import org.apache.commons.rng.UniformRandomProvider;
025import org.apache.commons.rng.sampling.distribution.GaussianSampler;
026import org.apache.commons.rng.sampling.distribution.SharedStateContinuousSampler;
027import org.apache.commons.rng.sampling.distribution.ZigguratSampler;
028
029/**
030 * Implementation of the folded normal distribution.
031 *
032 * <p>Given a normally distributed random variable \( X \) with mean \( \mu \) and variance
033 * \( \sigma^2 \), the random variable \( Y = |X| \) has a folded normal distribution. This is
034 * equivalent to not recording the sign from a normally distributed random variable.
035 *
036 * <p>The probability density function of \( X \) is:
037 *
038 * <p>\[ f(x; \mu, \sigma) = \frac 1 {\sigma\sqrt{2\pi}} e^{-{\frac 1 2}\left( \frac{x-\mu}{\sigma} \right)^2 } +
039 *                           \frac 1 {\sigma\sqrt{2\pi}} e^{-{\frac 1 2}\left( \frac{x+\mu}{\sigma} \right)^2 }\]
040 *
041 * <p>for \( \mu \) the location,
042 * \( \sigma &gt; 0 \) the scale, and
043 * \( x \in [0, \infty) \).
044 *
045 * <p>If the location \( \mu \) is 0 this reduces to the half-normal distribution.
046 *
047 * @see <a href="https://en.wikipedia.org/wiki/Folded_normal_distribution">Folded normal distribution (Wikipedia)</a>
048 * @see <a href="https://en.wikipedia.org/wiki/Half-normal_distribution">Half-normal distribution (Wikipedia)</a>
049 * @since 1.1
050 */
051public abstract class FoldedNormalDistribution extends AbstractContinuousDistribution {
052    /** The scale. */
053    final double sigma;
054    /**
055     * The scale multiplied by sqrt(2).
056     * This is used to avoid a double division when computing the value passed to the
057     * error function:
058     * <pre>
059     *  ((x - u) / sd) / sqrt(2) == (x - u) / (sd * sqrt(2)).
060     *  </pre>
061     * <p>Note: Implementations may first normalise x and then divide by sqrt(2) resulting
062     * in differences due to rounding error that show increasingly large relative
063     * differences as the error function computes close to 0 in the extreme tail.
064     */
065    final double sigmaSqrt2;
066    /**
067     * The scale multiplied by sqrt(2 pi). Computed to high precision.
068     */
069    final double sigmaSqrt2pi;
070
071    /**
072     * Regular implementation of the folded normal distribution.
073     */
074    private static class RegularFoldedNormalDistribution extends FoldedNormalDistribution {
075        /** The location. */
076        private final double mu;
077        /** Cached value for inverse probability function. */
078        private final double mean;
079        /** Cached value for inverse probability function. */
080        private final double variance;
081
082        /**
083         * @param mu Location parameter.
084         * @param sigma Scale parameter.
085         */
086        RegularFoldedNormalDistribution(double mu, double sigma) {
087            super(sigma);
088            this.mu = mu;
089
090            final double a = mu / sigmaSqrt2;
091            mean = sigma * Constants.ROOT_TWO_DIV_PI * Math.exp(-a * a) + mu * Erf.value(a);
092            this.variance = mu * mu + sigma * sigma - mean * mean;
093        }
094
095        @Override
096        public double getMu() {
097            return mu;
098        }
099
100        @Override
101        public double density(double x) {
102            if (x < 0) {
103                return 0;
104            }
105            final double vm = (x - mu) / sigma;
106            final double vp = (x + mu) / sigma;
107            return (ExtendedPrecision.expmhxx(vm) + ExtendedPrecision.expmhxx(vp)) / sigmaSqrt2pi;
108        }
109
110        @Override
111        public double probability(double x0,
112                                  double x1) {
113            if (x0 > x1) {
114                throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH,
115                                                x0, x1);
116            }
117            if (x0 <= 0) {
118                return cumulativeProbability(x1);
119            }
120            // Assumes x1 >= x0 && x0 > 0
121            final double v0m = (x0 - mu) / sigmaSqrt2;
122            final double v1m = (x1 - mu) / sigmaSqrt2;
123            final double v0p = (x0 + mu) / sigmaSqrt2;
124            final double v1p = (x1 + mu) / sigmaSqrt2;
125            return 0.5 * (ErfDifference.value(v0m, v1m) + ErfDifference.value(v0p, v1p));
126        }
127
128        @Override
129        public double cumulativeProbability(double x) {
130            if (x <= 0) {
131                return 0;
132            }
133            return 0.5 * (Erf.value((x - mu) / sigmaSqrt2) + Erf.value((x + mu) / sigmaSqrt2));
134        }
135
136        @Override
137        public double survivalProbability(double x) {
138            if (x <= 0) {
139                return 1;
140            }
141            return 0.5 * (Erfc.value((x - mu) / sigmaSqrt2) + Erfc.value((x + mu) / sigmaSqrt2));
142        }
143
144        @Override
145        public double getMean() {
146            return mean;
147        }
148
149        @Override
150        public double getVariance() {
151            return variance;
152        }
153
154        @Override
155        public Sampler createSampler(UniformRandomProvider rng) {
156            // Return the absolute of a Gaussian distribution sampler.
157            final SharedStateContinuousSampler s =
158                GaussianSampler.of(ZigguratSampler.NormalizedGaussian.of(rng), mu, sigma);
159            return () -> Math.abs(s.sample());
160        }
161    }
162
163    /**
164     * Specialisation for the half-normal distribution.
165     *
166     * <p>Elimination of the {@code mu} location parameter simplifies the probability
167     * functions and allows computation of the log density and inverse CDF/SF.
168     */
169    private static class HalfNormalDistribution extends FoldedNormalDistribution {
170        /** Variance constant (1 - 2/pi). Computed using Matlab's VPA to 30 digits. */
171        private static final double VAR = 0.36338022763241865692446494650994;
172        /** The value of {@code log(sigma) + 0.5 * log(2*PI)} stored for faster computation. */
173        private final double logSigmaPlusHalfLog2Pi;
174
175        /**
176         * @param sigma Scale parameter.
177         */
178        HalfNormalDistribution(double sigma) {
179            super(sigma);
180            logSigmaPlusHalfLog2Pi = Math.log(sigma) + Constants.HALF_LOG_TWO_PI;
181        }
182
183        @Override
184        public double getMu() {
185            return 0;
186        }
187
188        @Override
189        public double density(double x) {
190            if (x < 0) {
191                return 0;
192            }
193            return 2 * ExtendedPrecision.expmhxx(x / sigma) / sigmaSqrt2pi;
194        }
195
196        @Override
197        public double probability(double x0,
198                                  double x1) {
199            if (x0 > x1) {
200                throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH,
201                                                x0, x1);
202            }
203            if (x0 <= 0) {
204                return cumulativeProbability(x1);
205            }
206            // Assumes x1 >= x0 && x0 > 0
207            return ErfDifference.value(x0 / sigmaSqrt2, x1 / sigmaSqrt2);
208        }
209
210        @Override
211        public double logDensity(double x) {
212            if (x < 0) {
213                return Double.NEGATIVE_INFINITY;
214            }
215            final double z = x / sigma;
216            return Constants.LN_TWO - 0.5 * z * z - logSigmaPlusHalfLog2Pi;
217        }
218
219        @Override
220        public double cumulativeProbability(double x) {
221            if (x <= 0) {
222                return 0;
223            }
224            return Erf.value(x / sigmaSqrt2);
225        }
226
227        @Override
228        public double survivalProbability(double x) {
229            if (x <= 0) {
230                return 1;
231            }
232            return Erfc.value(x / sigmaSqrt2);
233        }
234
235        @Override
236        public double inverseCumulativeProbability(double p) {
237            ArgumentUtils.checkProbability(p);
238            // Addition of 0.0 ensures 0.0 is returned for p=-0.0
239            return 0.0 + sigmaSqrt2 * InverseErf.value(p);
240        }
241
242        /** {@inheritDoc} */
243        @Override
244        public double inverseSurvivalProbability(double p) {
245            ArgumentUtils.checkProbability(p);
246            return sigmaSqrt2 * InverseErfc.value(p);
247        }
248
249        @Override
250        public double getMean() {
251            return sigma * Constants.ROOT_TWO_DIV_PI;
252        }
253
254        @Override
255        public double getVariance() {
256            // sigma^2 - mean^2
257            // sigma^2 - (sigma^2 * 2/pi)
258            return sigma * sigma * VAR;
259        }
260
261        @Override
262        public Sampler createSampler(UniformRandomProvider rng) {
263            // Return the absolute of a Gaussian distribution sampler.
264            final SharedStateContinuousSampler s = ZigguratSampler.NormalizedGaussian.of(rng);
265            return () -> Math.abs(s.sample() * sigma);
266        }
267    }
268
269    /**
270     * @param sigma Scale parameter.
271     */
272    FoldedNormalDistribution(double sigma) {
273        this.sigma = sigma;
274        // Minimise rounding error by computing sqrt(2 * sigma * sigma) exactly.
275        // Compute using extended precision with care to avoid over/underflow.
276        sigmaSqrt2 = ExtendedPrecision.sqrt2xx(sigma);
277        // Compute sigma * sqrt(2 * pi)
278        sigmaSqrt2pi = ExtendedPrecision.xsqrt2pi(sigma);
279    }
280
281    /**
282     * Creates a folded normal distribution. If the location {@code mu} is zero this is
283     * the half-normal distribution.
284     *
285     * @param mu Location parameter.
286     * @param sigma Scale parameter.
287     * @return the distribution
288     * @throws IllegalArgumentException if {@code sigma <= 0}.
289     */
290    public static FoldedNormalDistribution of(double mu,
291                                              double sigma) {
292        if (sigma > 0) {
293            if (mu == 0) {
294                return new HalfNormalDistribution(sigma);
295            }
296            return new RegularFoldedNormalDistribution(mu, sigma);
297        }
298        // scale is zero, negative or nan
299        throw new DistributionException(DistributionException.NOT_STRICTLY_POSITIVE, sigma);
300    }
301
302    /**
303     * Gets the location parameter \( \mu \) of this distribution.
304     *
305     * @return the mu parameter.
306     */
307    public abstract double getMu();
308
309    /**
310     * Gets the scale parameter \( \sigma \) of this distribution.
311     *
312     * @return the sigma parameter.
313     */
314    public double getSigma() {
315        return sigma;
316    }
317
318    /**
319     * {@inheritDoc}
320     *
321     *
322     * <p>For location parameter \( \mu \) and scale parameter \( \sigma \), the mean is:
323     *
324     * <p>\[ \sigma \sqrt{ \frac 2 \pi } \exp \left( \frac{-\mu^2}{2\sigma^2} \right) +
325     *       \mu \operatorname{erf} \left( \frac \mu {\sqrt{2\sigma^2}} \right) \]
326     *
327     * <p>where \( \operatorname{erf} \) is the error function.
328     */
329    @Override
330    public abstract double getMean();
331
332    /**
333     * {@inheritDoc}
334     *
335     * <p>For location parameter \( \mu \), scale parameter \( \sigma \) and a distribution
336     * mean \( \mu_Y \), the variance is:
337     *
338     * <p>\[ \mu^2 + \sigma^2 - \mu_{Y}^2 \]
339     */
340    @Override
341    public abstract double getVariance();
342
343    /**
344     * {@inheritDoc}
345     *
346     * <p>The lower bound of the support is always 0.
347     *
348     * @return 0.
349     */
350    @Override
351    public double getSupportLowerBound() {
352        return 0.0;
353    }
354
355    /**
356     * {@inheritDoc}
357     *
358     * <p>The upper bound of the support is always positive infinity.
359     *
360     * @return {@linkplain Double#POSITIVE_INFINITY positive infinity}.
361     */
362    @Override
363    public double getSupportUpperBound() {
364        return Double.POSITIVE_INFINITY;
365    }
366}