001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.rng.simple;
018
019import java.util.EnumMap;
020import java.util.Map;
021import org.apache.commons.rng.UniformRandomProvider;
022
023/**
024 * This class provides a thread-local {@link UniformRandomProvider}.
025 *
026 * <p>The {@link UniformRandomProvider} is created once-per-thread using the default
027 * construction method {@link RandomSource#create()}.
028 *
029 * <p>Example:</p>
030 * <pre><code>
031 * import org.apache.commons.rng.simple.RandomSource;
032 * import org.apache.commons.rng.simple.ThreadLocalRandomSource;
033 * import org.apache.commons.rng.sampling.distribution.PoissonSampler;
034 *
035 * // Access a thread-safe random number generator
036 * UniformRandomProvider rng = ThreadLocalRandomSource.current(RandomSource.SPLIT_MIX_64);
037 *
038 * // One-time Poisson sample
039 * double mean = 12.3;
040 * int counts = PoissonSampler.of(rng, mean).sample();
041 * </code></pre>
042 *
043 * <p>Note if the {@link RandomSource} requires additional arguments then it is not
044 * supported. The same can be achieved using:</p>
045 *
046 * <pre><code>
047 * import org.apache.commons.rng.simple.RandomSource;
048 * import org.apache.commons.rng.sampling.distribution.PoissonSampler;
049 *
050 * // Provide a thread-safe random number generator with data arguments
051 * private static ThreadLocal&lt;UniformRandomProvider&gt; rng =
052 *     new ThreadLocal&lt;UniformRandomProvider&gt;() {
053 *         &#64;Override
054 *         protected UniformRandomProvider initialValue() {
055 *             return RandomSource.TWO_CMRES_SELECT.create(null, 3, 4);
056 *         }
057 *     };
058 *
059 * // One-time Poisson sample using a thread-safe random number generator
060 * double mean = 12.3;
061 * int counts = PoissonSampler.of(rng.get(), mean).sample();
062 * </code></pre>
063 *
064 * @since 1.3
065 */
066public final class ThreadLocalRandomSource {
067    /**
068     * A map containing the {@link ThreadLocal} instance for each {@link RandomSource}.
069     *
070     * <p>This should only be modified to create new instances in a synchronized block.
071     */
072    private static final Map<RandomSource, ThreadLocal<UniformRandomProvider>> SOURCES =
073        new EnumMap<>(RandomSource.class);
074
075    /** No public construction. */
076    private ThreadLocalRandomSource() {}
077
078    /**
079     * Extend the {@link ThreadLocal} to allow creation of the desired {@link RandomSource}.
080     */
081    private static class ThreadLocalRng extends ThreadLocal<UniformRandomProvider> {
082        /** The source. */
083        private final RandomSource source;
084
085        /**
086         * Create a new instance.
087         *
088         * @param source the source
089         */
090        ThreadLocalRng(RandomSource source) {
091            this.source = source;
092        }
093
094        @Override
095        protected UniformRandomProvider initialValue() {
096            // Create with the default seed generation method
097            return source.create();
098        }
099    }
100
101    /**
102     * Returns the current thread's copy of the given {@code source}. If there is no
103     * value for the current thread, it is first initialized to the value returned
104     * by {@link RandomSource#create()}.
105     *
106     * <p>Note if the {@code source} requires additional arguments then it is not
107     * supported.
108     *
109     * @param source the source
110     * @return the current thread's value of the {@code source}.
111     * @throws IllegalArgumentException if the source is null or the source requires arguments
112     */
113    public static UniformRandomProvider current(RandomSource source) {
114        ThreadLocal<UniformRandomProvider> rng = SOURCES.get(source);
115        // Implement double-checked locking:
116        // https://en.wikipedia.org/wiki/Double-checked_locking#Usage_in_Java
117        if (rng == null) {
118            // Do the checks on the source here since it is an edge case
119            // and the EnumMap handles null (returning null).
120            if (source == null) {
121                throw new IllegalArgumentException("Random source is null");
122            }
123
124            synchronized (SOURCES) {
125                rng = SOURCES.computeIfAbsent(source, ThreadLocalRng::new);
126            }
127        }
128        return rng.get();
129    }
130}