001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.rng.sampling;
019
020import java.util.List;
021import java.util.Map;
022import java.util.ArrayList;
023import org.apache.commons.rng.UniformRandomProvider;
024import org.apache.commons.rng.sampling.distribution.GuideTableDiscreteSampler;
025import org.apache.commons.rng.sampling.distribution.SharedStateDiscreteSampler;
026
027/**
028 * Sampling from a collection of items with user-defined
029 * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">
030 * probabilities</a>.
031 * Note that if all unique items are assigned the same probability,
032 * it is much more efficient to use {@link CollectionSampler}.
033 *
034 * <p>Sampling uses {@link UniformRandomProvider#nextDouble()}.</p>
035 *
036 * @param <T> Type of items in the collection.
037 *
038 * @since 1.1
039 */
040public class DiscreteProbabilityCollectionSampler<T> implements SharedStateObjectSampler<T> {
041    /** The error message for an empty collection. */
042    private static final String EMPTY_COLLECTION = "Empty collection";
043    /** Collection to be sampled from. */
044    private final List<T> items;
045    /** Sampler for the probabilities. */
046    private final SharedStateDiscreteSampler sampler;
047
048    /**
049     * Creates a sampler.
050     *
051     * @param rng Generator of uniformly distributed random numbers.
052     * @param collection Collection to be sampled, with the probabilities
053     * associated to each of its items.
054     * A (shallow) copy of the items will be stored in the created instance.
055     * The probabilities must be non-negative, but zero values are allowed
056     * and their sum does not have to equal one (input will be normalized
057     * to make the probabilities sum to one).
058     * @throws IllegalArgumentException if {@code collection} is empty, a
059     * probability is negative, infinite or {@code NaN}, or the sum of all
060     * probabilities is not strictly positive.
061     */
062    public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
063                                                Map<T, Double> collection) {
064        this(toList(collection),
065             createSampler(rng, toProbabilities(collection)));
066    }
067
068    /**
069     * Creates a sampler.
070     *
071     * @param rng Generator of uniformly distributed random numbers.
072     * @param collection Collection to be sampled.
073     * A (shallow) copy of the items will be stored in the created instance.
074     * @param probabilities Probability associated to each item of the
075     * {@code collection}.
076     * The probabilities must be non-negative, but zero values are allowed
077     * and their sum does not have to equal one (input will be normalized
078     * to make the probabilities sum to one).
079     * @throws IllegalArgumentException if {@code collection} is empty or
080     * a probability is negative, infinite or {@code NaN}, or if the number
081     * of items in the {@code collection} is not equal to the number of
082     * provided {@code probabilities}.
083     */
084    public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
085                                                List<T> collection,
086                                                double[] probabilities) {
087        this(copyList(collection),
088             createSampler(rng, collection, probabilities));
089    }
090
091    /**
092     * @param items Collection to be sampled.
093     * @param sampler Sampler for the probabilities.
094     */
095    private DiscreteProbabilityCollectionSampler(List<T> items,
096                                                 SharedStateDiscreteSampler sampler) {
097        this.items = items;
098        this.sampler = sampler;
099    }
100
101    /**
102     * Picks one of the items from the collection passed to the constructor.
103     *
104     * @return a random sample.
105     */
106    @Override
107    public T sample() {
108        return items.get(sampler.sample());
109    }
110
111    /**
112     * {@inheritDoc}
113     *
114     * @since 1.3
115     */
116    @Override
117    public DiscreteProbabilityCollectionSampler<T> withUniformRandomProvider(UniformRandomProvider rng) {
118        return new DiscreteProbabilityCollectionSampler<>(items, sampler.withUniformRandomProvider(rng));
119    }
120
121    /**
122     * Creates the sampler of the enumerated probability distribution.
123     *
124     * @param rng Generator of uniformly distributed random numbers.
125     * @param probabilities Probability associated to each item.
126     * @return the sampler
127     */
128    private static SharedStateDiscreteSampler createSampler(UniformRandomProvider rng,
129                                                            double[] probabilities) {
130        return GuideTableDiscreteSampler.of(rng, probabilities);
131    }
132
133    /**
134     * Creates the sampler of the enumerated probability distribution.
135     *
136     * @param <T> Type of items in the collection.
137     * @param rng Generator of uniformly distributed random numbers.
138     * @param collection Collection to be sampled.
139     * @param probabilities Probability associated to each item.
140     * @return the sampler
141     * @throws IllegalArgumentException if the number
142     * of items in the {@code collection} is not equal to the number of
143     * provided {@code probabilities}.
144     */
145    private static <T> SharedStateDiscreteSampler createSampler(UniformRandomProvider rng,
146                                                                List<T> collection,
147                                                                double[] probabilities) {
148        if (probabilities.length != collection.size()) {
149            throw new IllegalArgumentException("Size mismatch: " +
150                                               probabilities.length + " != " +
151                                               collection.size());
152        }
153        return GuideTableDiscreteSampler.of(rng, probabilities);
154    }
155
156    // Validation methods exist to raise an exception before invocation of the
157    // private constructor; this mitigates Finalizer attacks
158    // (see SpotBugs CT_CONSTRUCTOR_THROW).
159
160    /**
161     * Extract the items.
162     *
163     * @param <T> Type of items in the collection.
164     * @param collection Collection.
165     * @return the items
166     * @throws IllegalArgumentException if {@code collection} is empty.
167     */
168    private static <T> List<T> toList(Map<T, Double> collection) {
169        if (collection.isEmpty()) {
170            throw new IllegalArgumentException(EMPTY_COLLECTION);
171        }
172        return new ArrayList<>(collection.keySet());
173    }
174
175    /**
176     * Extract the probabilities.
177     *
178     * @param <T> Type of items in the collection.
179     * @param collection Collection.
180     * @return the probabilities
181     */
182    private static <T> double[] toProbabilities(Map<T, Double> collection) {
183        final int size = collection.size();
184        final double[] probabilities = new double[size];
185        int count = 0;
186        for (final Double e : collection.values()) {
187            final double probability = e;
188            if (probability < 0 ||
189                Double.isInfinite(probability) ||
190                Double.isNaN(probability)) {
191                throw new IllegalArgumentException("Invalid probability: " +
192                                                   probability);
193            }
194            probabilities[count++] = probability;
195        }
196        return probabilities;
197    }
198
199    /**
200     * Create a (shallow) copy of the collection.
201     *
202     * @param <T> Type of items in the collection.
203     * @param collection Collection.
204     * @return the copy
205     * @throws IllegalArgumentException if {@code collection} is empty.
206     */
207    private static <T> List<T> copyList(List<T> collection) {
208        if (collection.isEmpty()) {
209            throw new IllegalArgumentException(EMPTY_COLLECTION);
210        }
211        return new ArrayList<>(collection);
212    }
213}