001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.rng.sampling.distribution;
018
019import org.apache.commons.rng.UniformRandomProvider;
020import org.apache.commons.rng.sampling.SharedStateObjectSampler;
021
022/**
023 * Sampling from a <a href="https://en.wikipedia.org/wiki/Dirichlet_distribution">Dirichlet
024 * distribution</a>.
025 *
026 * <p>Sampling uses:</p>
027 *
028 * <ul>
029 *   <li>{@link UniformRandomProvider#nextLong()}
030 *   <li>{@link UniformRandomProvider#nextDouble()}
031 * </ul>
032 *
033 * @since 1.4
034 */
035public abstract class DirichletSampler implements SharedStateObjectSampler<double[]> {
036    /** The minimum number of categories. */
037    private static final int MIN_CATGEORIES = 2;
038
039    /** RNG (used for the toString() method). */
040    private final UniformRandomProvider rng;
041
042    /**
043     * Sample from a Dirichlet distribution with different concentration parameters
044     * for each category.
045     */
046    private static final class GeneralDirichletSampler extends DirichletSampler {
047        /** Samplers for each category. */
048        private final SharedStateContinuousSampler[] samplers;
049
050        /**
051         * @param rng Generator of uniformly distributed random numbers.
052         * @param samplers Samplers for each category.
053         */
054        GeneralDirichletSampler(UniformRandomProvider rng,
055                                SharedStateContinuousSampler[] samplers) {
056            super(rng);
057            // Array is stored directly as it is generated within the DirichletSampler class
058            this.samplers = samplers;
059        }
060
061        @Override
062        protected int getK() {
063            return samplers.length;
064        }
065
066        @Override
067        protected double nextGamma(int i) {
068            return samplers[i].sample();
069        }
070
071        @Override
072        public GeneralDirichletSampler withUniformRandomProvider(UniformRandomProvider rng) {
073            final SharedStateContinuousSampler[] newSamplers = new SharedStateContinuousSampler[samplers.length];
074            for (int i = 0; i < newSamplers.length; i++) {
075                newSamplers[i] = samplers[i].withUniformRandomProvider(rng);
076            }
077            return new GeneralDirichletSampler(rng, newSamplers);
078        }
079    }
080
081    /**
082     * Sample from a symmetric Dirichlet distribution with the same concentration parameter
083     * for each category.
084     */
085    private static final class SymmetricDirichletSampler extends DirichletSampler {
086        /** Number of categories. */
087        private final int k;
088        /** Sampler for the categories. */
089        private final SharedStateContinuousSampler sampler;
090
091        /**
092         * @param rng Generator of uniformly distributed random numbers.
093         * @param k Number of categories.
094         * @param sampler Sampler for the categories.
095         */
096        SymmetricDirichletSampler(UniformRandomProvider rng,
097                                  int k,
098                                  SharedStateContinuousSampler sampler) {
099            super(rng);
100            this.k = k;
101            this.sampler = sampler;
102        }
103
104        @Override
105        protected int getK() {
106            return k;
107        }
108
109        @Override
110        protected double nextGamma(int i) {
111            return sampler.sample();
112        }
113
114        @Override
115        public SymmetricDirichletSampler withUniformRandomProvider(UniformRandomProvider rng) {
116            return new SymmetricDirichletSampler(rng, k, sampler.withUniformRandomProvider(rng));
117        }
118    }
119
120    /**
121     * @param rng Generator of uniformly distributed random numbers.
122     */
123    DirichletSampler(UniformRandomProvider rng) {
124        this.rng = rng;
125    }
126
127    /** {@inheritDoc} */
128    @Override
129    public String toString() {
130        return "Dirichlet deviate [" + rng.toString() + "]";
131    }
132
133    /** {@inheritDoc} */
134    @Override
135    public double[] sample() {
136        // Create Gamma(alpha_i, 1) deviates for all alpha
137        final double[] y = new double[getK()];
138        double norm = 0;
139        for (int i = 0; i < y.length; i++) {
140            final double yi = nextGamma(i);
141            norm += yi;
142            y[i] = yi;
143        }
144        // Normalize by dividing by the sum of the samples
145        norm = 1.0 / norm;
146        // Detect an invalid normalization, e.g. cases of all zero samples
147        if (!isNonZeroPositiveFinite(norm)) {
148            // Sample again using recursion.
149            // A stack overflow due to a broken RNG will eventually occur
150            // rather than the alternative which is an infinite loop.
151            return sample();
152        }
153        // Normalise
154        for (int i = 0; i < y.length; i++) {
155            y[i] *= norm;
156        }
157        return y;
158    }
159
160    /**
161     * Gets the number of categories.
162     *
163     * @return k
164     */
165    protected abstract int getK();
166
167    /**
168     * Create a gamma sample for the given category.
169     *
170     * @param category Category.
171     * @return the sample
172     */
173    protected abstract double nextGamma(int category);
174
175    /** {@inheritDoc} */
176    // Redeclare the signature to return a DirichletSampler not a SharedStateObjectSampler<double[]>
177    @Override
178    public abstract DirichletSampler withUniformRandomProvider(UniformRandomProvider rng);
179
180    /**
181     * Creates a new Dirichlet distribution sampler.
182     *
183     * @param rng Generator of uniformly distributed random numbers.
184     * @param alpha Concentration parameters.
185     * @return the sampler
186     * @throws IllegalArgumentException if the number of concentration parameters
187     * is less than 2; or if any concentration parameter is not strictly positive.
188     */
189    public static DirichletSampler of(UniformRandomProvider rng,
190                                      double... alpha) {
191        validateNumberOfCategories(alpha.length);
192        final SharedStateContinuousSampler[] samplers = new SharedStateContinuousSampler[alpha.length];
193        for (int i = 0; i < samplers.length; i++) {
194            samplers[i] = createSampler(rng, alpha[i]);
195        }
196        return new GeneralDirichletSampler(rng, samplers);
197    }
198
199    /**
200     * Creates a new symmetric Dirichlet distribution sampler using the same concentration
201     * parameter for each category.
202     *
203     * @param rng Generator of uniformly distributed random numbers.
204     * @param k Number of categories.
205     * @param alpha Concentration parameter.
206     * @return the sampler
207     * @throws IllegalArgumentException if the number of categories is
208     * less than 2; or if the concentration parameter is not strictly positive.
209     */
210    public static DirichletSampler symmetric(UniformRandomProvider rng,
211                                             int k,
212                                             double alpha) {
213        validateNumberOfCategories(k);
214        final SharedStateContinuousSampler sampler = createSampler(rng, alpha);
215        return new SymmetricDirichletSampler(rng, k, sampler);
216    }
217
218    /**
219     * Validate the number of categories.
220     *
221     * @param k Number of categories.
222     * @throws IllegalArgumentException if the number of categories is
223     * less than 2.
224     */
225    private static void validateNumberOfCategories(int k) {
226        if (k < MIN_CATGEORIES) {
227            throw new IllegalArgumentException("Invalid number of categories: " + k);
228        }
229    }
230
231    /**
232     * Creates a gamma sampler for a category with the given concentration parameter.
233     *
234     * @param rng Generator of uniformly distributed random numbers.
235     * @param alpha Concentration parameter.
236     * @return the sampler
237     * @throws IllegalArgumentException if the concentration parameter is not strictly positive.
238     */
239    private static SharedStateContinuousSampler createSampler(UniformRandomProvider rng,
240                                                              double alpha) {
241        InternalUtils.requireStrictlyPositiveFinite(alpha, "alpha concentration");
242        // Create a Gamma(shape=alpha, scale=1) sampler.
243        if (alpha == 1) {
244            // Special case
245            // Gamma(shape=1, scale=1) == Exponential(mean=1)
246            return ZigguratSampler.Exponential.of(rng);
247        }
248        return AhrensDieterMarsagliaTsangGammaSampler.of(rng, alpha, 1);
249    }
250
251    /**
252     * Return true if the value is non-zero, positive and finite.
253     *
254     * @param x Value.
255     * @return true if non-zero positive finite
256     */
257    private static boolean isNonZeroPositiveFinite(double x) {
258        return x > 0 && x < Double.POSITIVE_INFINITY;
259    }
260}