001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.statistics.distribution;
019
020import org.apache.commons.rng.UniformRandomProvider;
021import org.apache.commons.rng.sampling.distribution.DiscreteUniformSampler;
022
023/**
024 * Implementation of the uniform discrete distribution.
025 *
026 * <p>The probability mass function of \( X \) is:
027 *
028 * <p>\[ f(k; a, b) = \frac{1}{b-a+1} \]
029 *
030 * <p>for integer \( a, b \) and \( a \le b \) and
031 * \( k \in [a, b] \).
032 *
033 * @see <a href="https://en.wikipedia.org/wiki/Uniform_distribution_(discrete)">
034 * Uniform distribution (discrete) (Wikipedia)</a>
035 * @see <a href="https://mathworld.wolfram.com/DiscreteUniformDistribution.html">
036 * Discrete uniform distribution (MathWorld)</a>
037 */
038public final class UniformDiscreteDistribution extends AbstractDiscreteDistribution {
039    /** Lower bound (inclusive) of this distribution. */
040    private final int lower;
041    /** Upper bound (inclusive) of this distribution. */
042    private final int upper;
043    /** "upper" - "lower" + 1 (as a double to avoid overflow). */
044    private final double upperMinusLowerPlus1;
045    /** Cache of the probability. */
046    private final double pmf;
047    /** Cache of the log probability. */
048    private final double logPmf;
049    /** Value of survival probability for x=0. Used in the inverse survival function. */
050    private final double sf0;
051
052    /**
053     * @param lower Lower bound (inclusive) of this distribution.
054     * @param upper Upper bound (inclusive) of this distribution.
055     */
056    private UniformDiscreteDistribution(int lower,
057                                        int upper) {
058        this.lower = lower;
059        this.upper = upper;
060        upperMinusLowerPlus1 = (double) upper - lower + 1;
061        pmf = 1.0 / upperMinusLowerPlus1;
062        logPmf = -Math.log(upperMinusLowerPlus1);
063        sf0 = (upperMinusLowerPlus1 - 1) / upperMinusLowerPlus1;
064    }
065
066    /**
067     * Creates a new uniform discrete distribution.
068     *
069     * @param lower Lower bound (inclusive) of this distribution.
070     * @param upper Upper bound (inclusive) of this distribution.
071     * @return the distribution
072     * @throws IllegalArgumentException if {@code lower > upper}.
073     */
074    public static UniformDiscreteDistribution of(int lower,
075                                                 int upper) {
076        if (lower > upper) {
077            throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH,
078                                            lower, upper);
079        }
080        return new UniformDiscreteDistribution(lower, upper);
081    }
082
083    /** {@inheritDoc} */
084    @Override
085    public double probability(int x) {
086        if (x < lower || x > upper) {
087            return 0;
088        }
089        return pmf;
090    }
091
092    /** {@inheritDoc} */
093    @Override
094    public double probability(int x0,
095                              int x1) {
096        if (x0 > x1) {
097            throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH, x0, x1);
098        }
099        if (x0 >= upper || x1 < lower) {
100            // (x0, x1] does not overlap [lower, upper]
101            return 0;
102        }
103
104        // x0 < upper
105        // x1 >= lower
106
107        // Find the range between x0 (exclusive) and x1 (inclusive) within [lower, upper].
108        // In the case of x0 < lower set l so that u - l == (u - lower) + 1
109        // long arithmetic prevents overflow
110        final long l = Math.max(lower - 1L, x0);
111        final long u = Math.min(upper, x1);
112
113        return (u - l) / upperMinusLowerPlus1;
114    }
115
116    /** {@inheritDoc} */
117    @Override
118    public double logProbability(int x) {
119        if (x < lower || x > upper) {
120            return Double.NEGATIVE_INFINITY;
121        }
122        return logPmf;
123    }
124
125    /** {@inheritDoc} */
126    @Override
127    public double cumulativeProbability(int x) {
128        if (x <= lower) {
129            // Note: CDF(x=0) = PDF(x=0)
130            return x == lower ? pmf : 0;
131        }
132        if (x >= upper) {
133            return 1;
134        }
135        return ((double) x - lower + 1) / upperMinusLowerPlus1;
136    }
137
138    /** {@inheritDoc} */
139    @Override
140    public double survivalProbability(int x) {
141        if (x <= lower) {
142            // Note: SF(x=0) = 1 - PDF(x=0)
143            // Use a pre-computed value to avoid cancellation when probabilityOfSuccess -> 0
144            return x == lower ? sf0 : 1;
145        }
146        if (x >= upper) {
147            return 0;
148        }
149        return ((double) upper - x) / upperMinusLowerPlus1;
150    }
151
152    /** {@inheritDoc} */
153    @Override
154    public int inverseCumulativeProbability(double p) {
155        ArgumentUtils.checkProbability(p);
156        if (p > sf0) {
157            return upper;
158        }
159        if (p <= pmf) {
160            return lower;
161        }
162        // p in ( pmf         , sf0             ]
163        // p in ( 1 / {u-l+1} , {u-l} / {u-l+1} ]
164        // x in ( l           , u-1             ]
165        int x = (int) (lower + Math.ceil(p * upperMinusLowerPlus1) - 1);
166
167        // Correct rounding errors.
168        // This ensures x == icdf(cdf(x))
169        // Note: Directly computing the CDF(x-1) avoids integer overflow if x=min_value
170
171        if (((double) x - lower) / upperMinusLowerPlus1 >= p) {
172            // No check for x > lower: cdf(x=lower) = 0 and thus is below p
173            // cdf(x-1) >= p
174            x--;
175        } else if (((double) x - lower + 1) / upperMinusLowerPlus1 < p) {
176            // No check for x < upper: cdf(x=upper) = 1 and thus is above p
177            // cdf(x) < p
178            x++;
179        }
180
181        return x;
182    }
183
184    /** {@inheritDoc} */
185    @Override
186    public int inverseSurvivalProbability(final double p) {
187        ArgumentUtils.checkProbability(p);
188        if (p < pmf) {
189            return upper;
190        }
191        if (p >= sf0) {
192            return lower;
193        }
194        // p in [ pmf         , sf0             )
195        // p in [ 1 / {u-l+1} , {u-l} / {u-l+1} )
196        // x in [ u-1         , l               )
197        int x = (int) (upper - Math.floor(p * upperMinusLowerPlus1));
198
199        // Correct rounding errors.
200        // This ensures x == isf(sf(x))
201        // Note: Directly computing the SF(x-1) avoids integer overflow if x=min_value
202
203        if (((double) upper - x + 1) / upperMinusLowerPlus1 <= p) {
204            // No check for x > lower: sf(x=lower) = 1 and thus is above p
205            // sf(x-1) <= p
206            x--;
207        } else if (((double) upper - x) / upperMinusLowerPlus1 > p) {
208            // No check for x < upper: sf(x=upper) = 0 and thus is below p
209            // sf(x) > p
210            x++;
211        }
212
213        return x;
214    }
215
216    /**
217     * {@inheritDoc}
218     *
219     * <p>For lower bound \( a \) and upper bound \( b \), the mean is \( \frac{1}{2} (a + b) \).
220     */
221    @Override
222    public double getMean() {
223        // Avoid overflow
224        return 0.5 * ((double) upper + lower);
225    }
226
227    /**
228     * {@inheritDoc}
229     *
230     * <p>For lower bound \( a \) and upper bound \( b \), the variance is:
231     *
232     * <p>\[ \frac{1}{12} (n^2 - 1) \]
233     *
234     * <p>where \( n = b - a + 1 \).
235     */
236    @Override
237    public double getVariance() {
238        return (upperMinusLowerPlus1 * upperMinusLowerPlus1 - 1) / 12;
239    }
240
241    /**
242     * {@inheritDoc}
243     *
244     * <p>The lower bound of the support is equal to the lower bound parameter
245     * of the distribution.
246     */
247    @Override
248    public int getSupportLowerBound() {
249        return lower;
250    }
251
252    /**
253     * {@inheritDoc}
254     *
255     * <p>The upper bound of the support is equal to the upper bound parameter
256     * of the distribution.
257     */
258    @Override
259    public int getSupportUpperBound() {
260        return upper;
261    }
262
263    /** {@inheritDoc} */
264    @Override
265    public DiscreteDistribution.Sampler createSampler(final UniformRandomProvider rng) {
266        // Discrete uniform distribution sampler.
267        return DiscreteUniformSampler.of(rng, lower, upper)::sample;
268    }
269}