FastLoadedDiceRollerDiscreteSampler.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.rng.sampling.distribution;

import java.math.BigInteger;
import java.util.Arrays;
import org.apache.commons.rng.UniformRandomProvider;

/**
 * Distribution sampler that uses the Fast Loaded Dice Roller (FLDR). It can be used to
 * sample from {@code n} values each with an associated relative weight. If all unique items
 * are assigned the same weight it is more efficient to use the {@link DiscreteUniformSampler}.
 *
 * <p>Given a list {@code L} of {@code n} positive numbers,
 * where {@code L[i]} represents the relative weight of the {@code i}th side, FLDR returns
 * integer {@code i} with relative probability {@code L[i]}.
 *
 * <p>FLDR produces <em>exact</em> samples from the specified probability distribution.
 * <ul>
 *   <li>For integer weights, the probability of returning {@code i} is precisely equal to the
 *   rational number {@code L[i] / m}, where {@code m} is the sum of {@code L}.
 *   <li>For floating-point weights, each weight {@code L[i]} is converted to the
 *   corresponding rational number {@code p[i] / q[i]} where {@code p[i]} is a positive integer and
 *   {@code q[i]} is a power of 2. The rational weights are then normalized (exactly) to sum to unity.
 * </ul>
 *
 * <p>Note that if <em>exact</em> samples are not required then an alternative sampler that
 * ignores very small relative weights may have improved sampling performance.
 *
 * <p>This implementation is based on the algorithm in:
 *
 * <blockquote>
 *  Feras A. Saad, Cameron E. Freer, Martin C. Rinard, and Vikash K. Mansinghka.
 *  The Fast Loaded Dice Roller: A Near-Optimal Exact Sampler for Discrete Probability
 *  Distributions. In AISTATS 2020: Proceedings of the 23rd International Conference on
 *  Artificial Intelligence and Statistics, Proceedings of Machine Learning Research 108,
 *  Palermo, Sicily, Italy, 2020.
 * </blockquote>
 *
 * <p>Sampling uses {@link UniformRandomProvider#nextInt()} as the source of random bits.
 *
 * @see <a href="https://arxiv.org/abs/2003.03830">Saad et al (2020)
 * Proceedings of the 23rd International Conference on Artificial Intelligence and Statistics,
 * PMLR 108:1036-1046.</a>
 * @since 1.5
 */
public abstract class FastLoadedDiceRollerDiscreteSampler
    implements SharedStateDiscreteSampler {
    /**
     * The maximum size of an array.
     *
     * <p>This value is taken from the limit in Open JDK 8 {@code java.util.ArrayList}.
     * It allows VMs to reserve some header words in an array.
     */
    private static final int MAX_ARRAY_SIZE = Integer.MAX_VALUE - 8;
    /** The maximum biased exponent for a finite double.
     * This is offset by 1023 from {@code Math.getExponent(Double.MAX_VALUE)}. */
    private static final int MAX_BIASED_EXPONENT = 2046;
    /** Size of the mantissa of a double. Equal to 52 bits. */
    private static final int MANTISSA_SIZE = 52;
    /** Mask to extract the 52-bit mantissa from a long representation of a double. */
    private static final long MANTISSA_MASK = 0x000f_ffff_ffff_ffffL;
    /** BigInteger representation of {@link Long#MAX_VALUE}. */
    private static final BigInteger MAX_LONG = BigInteger.valueOf(Long.MAX_VALUE);
    /** The maximum offset that will avoid loss of bits for a left shift of a 53-bit value.
     * The value will remain positive for any shift {@code <=} this value. */
    private static final int MAX_OFFSET = 10;
    /** Initial value for no leaf node label. */
    private static final int NO_LABEL = Integer.MAX_VALUE;
    /** Name of the sampler. */
    private static final String SAMPLER_NAME = "Fast Loaded Dice Roller";

    /**
     * Class to handle the edge case of observations in only one category.
     */
    private static final class FixedValueDiscreteSampler extends FastLoadedDiceRollerDiscreteSampler {
        /** The sample value. */
        private final int sampleValue;

        /**
         * @param sampleValue Sample value.
         */
        FixedValueDiscreteSampler(int sampleValue) {
            this.sampleValue = sampleValue;
        }

        /** {@inheritDoc} */
        @Override
        public int sample() {
            return sampleValue;
        }

        /** {@inheritDoc} */
        @Override
        public FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
            // Stateless: no source of randomness is used so the instance can be shared.
            return this;
        }

        /** {@inheritDoc} */
        @Override
        public String toString() {
            return SAMPLER_NAME;
        }
    }

    /**
     * Class to implement the FLDR sample algorithm.
     */
    private static final class FLDRSampler extends FastLoadedDiceRollerDiscreteSampler {
        /** Empty boolean source. This is the location of the sign-bit after 31 right shifts on
         * the boolean source. */
        private static final int EMPTY_BOOL_SOURCE = 1;

        /** Underlying source of randomness. */
        private final UniformRandomProvider rng;
        /** Number of categories. */
        private final int n;
        /** Number of levels in the discrete distribution generating (DDG) tree.
         * Equal to {@code ceil(log2(m))} where {@code m} is the sum of observations. */
        private final int k;
        /** Number of leaf nodes at each level. */
        private final int[] h;
        /** Stores the leaf node labels in increasing order. Named {@code H} in the FLDR paper. */
        private final int[] lH;

        /**
         * Provides a bit source for booleans.
         *
         * <p>A cached value from a call to {@link UniformRandomProvider#nextInt()}.
         *
         * <p>Only stores 31-bits when full as 1 bit has already been consumed.
         * The sign bit is a flag that shifts down so the source eventually equals 1
         * when all bits are consumed and will trigger a refill.
         */
        private int booleanSource = EMPTY_BOOL_SOURCE;

        /**
         * Creates a sampler.
         *
         * <p>The input parameters are not validated and must be correctly computed tables.
         *
         * @param rng Generator of uniformly distributed random numbers.
         * @param n Number of categories.
         * @param k Number of levels in the discrete distribution generating (DDG) tree.
         * Equal to {@code ceil(log2(m))} where {@code m} is the sum of observations.
         * @param h Number of leaf nodes at each level.
         * @param lH Stores the leaf node labels in increasing order.
         */
        FLDRSampler(UniformRandomProvider rng,
                    int n,
                    int k,
                    int[] h,
                    int[] lH) {
            this.rng = rng;
            this.n = n;
            this.k = k;
            // Deliberate direct storage of input arrays
            this.h = h;
            this.lH = lH;
        }

        /**
         * Creates a copy with a new source of randomness.
         *
         * <p>The immutable tables are shared with the source; the boolean bit cache
         * starts empty.
         *
         * @param rng Generator of uniformly distributed random numbers.
         * @param source Source to copy.
         */
        private FLDRSampler(UniformRandomProvider rng,
                            FLDRSampler source) {
            this.rng = rng;
            this.n = source.n;
            this.k = source.k;
            this.h = source.h;
            this.lH = source.lH;
        }

        /** {@inheritDoc} */
        @Override
        public int sample() {
            // ALGORITHM 5: SAMPLE
            int c = 0;
            int d = 0;
            for (;;) {
                // b = flip()
                // d = 2 * d + (1 - b)
                // Note: using b directly instead of (1 - b) is equivalent as b is a fair bit.
                d = (d << 1) + flip();
                if (d < h[c]) {
                    // z = H[d][c]
                    final int z = lH[d * k + c];
                    // assert z != NO_LABEL
                    if (z < n) {
                        return z;
                    }
                    // Landed on the reject leaf (label n): restart from the root.
                    d = 0;
                    c = 0;
                } else {
                    d = d - h[c];
                    c++;
                }
            }
        }

        /**
         * Provides a source of boolean bits.
         *
         * <p>Note: This replicates the boolean cache functionality of
         * {@code o.a.c.rng.core.source32.IntProvider}. The method has been simplified to return
         * an {@code int} value rather than a {@code boolean}.
         *
         * @return the bit (0 or 1)
         */
        private int flip() {
            int bits = booleanSource;
            if (bits == EMPTY_BOOL_SOURCE) {
                // Refill
                bits = rng.nextInt();
                // Store a refill flag in the sign bit and the unused 31 bits, return lowest bit
                booleanSource = Integer.MIN_VALUE | (bits >>> 1);
                return bits & 0x1;
            }
            // Shift down eventually triggering refill, return current lowest bit
            booleanSource = bits >>> 1;
            return bits & 0x1;
        }

        /** {@inheritDoc} */
        @Override
        public String toString() {
            return SAMPLER_NAME + " [" + rng.toString() + "]";
        }

        /** {@inheritDoc} */
        @Override
        public FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
            return new FLDRSampler(rng, this);
        }
    }

    /** Package-private constructor. */
    FastLoadedDiceRollerDiscreteSampler() {
        // Intentionally empty
    }

    /** {@inheritDoc} */
    // Redeclare the signature to return a FastLoadedDiceRollerDiscreteSampler
    // not a SharedStateDiscreteSampler
    @Override
    public abstract FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng);

    /**
     * Creates a sampler.
     *
     * <p>Note: The discrete distribution generating (DDG) tree requires {@code (n + 1) * k} entries
     * where {@code n} is the number of categories, {@code k == ceil(log2(m))} and {@code m}
     * is the sum of the observed frequencies. An exception is raised if this cannot be allocated
     * as a single array.
     *
     * <p>For reference the sum is limited to {@link Long#MAX_VALUE} and the value {@code k} to 63.
     * The number of categories is limited to approximately {@code ((2^31 - 1) / k) = 34,087,042}
     * when the sum of frequencies is large enough to create k=63.
     *
     * @param rng Generator of uniformly distributed random numbers.
     * @param frequencies Observed frequencies of the discrete distribution.
     * @return the sampler
     * @throws IllegalArgumentException if {@code frequencies} is null or empty, a
     * frequency is negative, the sum of all frequencies is either zero or
     * above {@link Long#MAX_VALUE}, or the size of the discrete distribution generating tree
     * is too large.
     */
    public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng,
                                                         long[] frequencies) {
        final long m = sum(frequencies);

        // Obtain indices of non-zero frequencies
        final int[] indices = indicesOfNonZero(frequencies);

        // Edge case for 1 non-zero weight. This also handles edge case for 1 observation
        // (as log2(m) == 0 will break the computation of the DDG tree).
        if (indices.length == 1) {
            // indices[0] is the first (and only) non-zero frequency; no need to rescan.
            return new FixedValueDiscreteSampler(indices[0]);
        }

        return createSampler(rng, frequencies, indices, m);
    }

    /**
     * Creates a sampler.
     *
     * <p>Weights are converted to rational numbers {@code p / q} where {@code q} is a power of 2.
     * The numerators {@code p} are scaled to use a common denominator before summing.
     *
     * <p>All weights are used to create the sampler. Weights with a small magnitude relative
     * to the largest weight can be excluded using the factory method with the
     * relative magnitude parameter {@code alpha} (see {@link #of(UniformRandomProvider, double[], int)}).
     *
     * @param rng Generator of uniformly distributed random numbers.
     * @param weights Weights of the discrete distribution.
     * @return the sampler
     * @throws IllegalArgumentException if {@code weights} is null or empty, a
     * weight is negative, infinite or {@code NaN}, the sum of all weights is zero, or the size
     * of the discrete distribution generating tree is too large.
     * @see #of(UniformRandomProvider, double[], int)
     */
    public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng,
                                                         double[] weights) {
        return of(rng, weights, 0);
    }

    /**
     * Creates a sampler.
     *
     * <p>Weights are converted to rational numbers {@code p / q} where {@code q} is
     * a power of 2. The numerators {@code p} are scaled to use a common
     * denominator before summing.
     *
     * <p>Note: The discrete distribution generating (DDG) tree requires
     * {@code (n + 1) * k} entries where {@code n} is the number of categories,
     * {@code k == ceil(log2(m))} and {@code m} is the sum of the scaled weight
     * numerators {@code p}. An exception is raised if this cannot be allocated as a
     * single array.
     *
     * <p>For reference the value {@code k} is equal to or greater than the base-2
     * logarithm of the ratio of the largest to the smallest weight. For
     * {@code Double.MAX_VALUE / Double.MIN_VALUE} this is ~2098. The value
     * {@code k} increases with the sum of the weight numerators. A number of
     * weights in excess of 1,000,000 with values equal to {@link Double#MAX_VALUE}
     * would be required to raise an exception when the minimum weight is
     * {@link Double#MIN_VALUE}.
     *
     * <p>Weights with a small magnitude relative to the largest weight can be
     * excluded using the relative magnitude parameter {@code alpha}. This will set
     * any weight to zero if the magnitude is approximately 2<sup>alpha</sup>
     * <em>smaller</em> than the largest weight. This comparison is made using only
     * the exponent of the input weights. The {@code alpha} parameter is ignored if
     * not above zero. Note that a small {@code alpha} parameter will exclude more
     * weights than a large {@code alpha} parameter.
     *
     * <p>The alpha parameter can be used to exclude categories that
     * have a very low probability of occurrence and will improve the construction
     * performance of the sampler. The effect on sampling performance depends on
     * the relative weights of the excluded categories; typically a high {@code alpha}
     * is used to exclude categories that would be visited with a very low probability
     * and the sampling performance is unchanged.
     *
     * <p><b>Implementation Note</b>
     *
     * <p>This method creates a sampler with <em>exact</em> samples from the
     * specified probability distribution. It is recommended to use this method:
     * <ul>
     *  <li>if the weights are computed, for example from a probability mass function; or
     *  <li>if the weights sum to an infinite value.
     * </ul>
     *
     * <p>If the weights are computed from empirical observations then it is
     * recommended to use the factory method
     * {@link #of(UniformRandomProvider, long[]) accepting frequencies}. This
     * requires the total number of observations to be representable as a long
     * integer.
     *
     * <p>Note that if all weights are scaled by a power of 2 to be integers, and
     * each integer can be represented as a positive 64-bit long value, then the
     * sampler created using this method will match the output from a sampler
     * created with the scaled weights converted to long values for the factory
     * method {@link #of(UniformRandomProvider, long[]) accepting frequencies}. This
     * assumes the sum of the integer values does not overflow.
     *
     * <p>It should be noted that the conversion of weights to rational numbers has
     * a performance overhead during construction (sampling performance is not
     * affected). This may be avoided by first converting them to integer values
     * that can be summed without overflow. For example by scaling values by
     * {@code 2^62 / sum} and converting to long by casting or rounding.
     *
     * <p>This approach may increase the efficiency of construction. The resulting
     * sampler may no longer produce <em>exact</em> samples from the distribution.
     * In particular any weights with a converted frequency of zero cannot be
     * sampled.
     *
     * @param rng Generator of uniformly distributed random numbers.
     * @param weights Weights of the discrete distribution.
     * @param alpha Alpha parameter.
     * @return the sampler
     * @throws IllegalArgumentException if {@code weights} is null or empty, a
     * weight is negative, infinite or {@code NaN}, the sum of all weights is zero,
     * or the size of the discrete distribution generating tree is too large.
     * @see #of(UniformRandomProvider, long[])
     */
    public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng,
                                                         double[] weights,
                                                         int alpha) {
        final int n = checkWeightsNonZeroLength(weights);

        // Convert floating-point double to a relative weight
        // using a shifted integer representation
        final long[] frequencies = new long[n];
        final int[] offsets = new int[n];
        convertToIntegers(weights, frequencies, offsets, alpha);

        // Obtain indices of non-zero weights
        final int[] indices = indicesOfNonZero(frequencies);

        // Edge case for 1 non-zero weight.
        if (indices.length == 1) {
            // indices[0] is the first (and only) non-zero weight; no need to rescan.
            return new FixedValueDiscreteSampler(indices[0]);
        }

        final BigInteger m = sum(frequencies, offsets, indices);

        // Use long arithmetic if possible. This occurs when the weights are similar in magnitude.
        if (m.compareTo(MAX_LONG) <= 0) {
            // Apply the offset to the non-zero values only. Zero values may have a
            // negative offset after scaling; they are left unchanged as zero.
            // Since the sum fits in a long, each shifted value also fits in a long.
            for (final int i : indices) {
                frequencies[i] <<= offsets[i];
            }
            return createSampler(rng, frequencies, indices, m.longValue());
        }

        return createSampler(rng, frequencies, offsets, indices, m);
    }

    /**
     * Sum the frequencies.
     *
     * @param frequencies Frequencies.
     * @return the sum
     * @throws IllegalArgumentException if {@code frequencies} is null or empty, a
     * frequency is negative, or the sum of all frequencies is either zero or above
     * {@link Long#MAX_VALUE}
     */
    private static long sum(long[] frequencies) {
        // Validate
        if (frequencies == null || frequencies.length == 0) {
            throw new IllegalArgumentException("frequencies must contain at least 1 value");
        }

        // Sum the values.
        // Combine all the sign bits in the observations and the intermediate sum in a flag.
        long m = 0;
        long signFlag = 0;
        for (final long o : frequencies) {
            m += o;
            signFlag |= o | m;
        }

        // Check for a sign-bit.
        if (signFlag < 0) {
            // One or more observations were negative, or the sum overflowed.
            for (final long o : frequencies) {
                if (o < 0) {
                    throw new IllegalArgumentException("frequencies must contain positive values: " + o);
                }
            }
            throw new IllegalArgumentException("Overflow when summing frequencies");
        }
        if (m == 0) {
            throw new IllegalArgumentException("Sum of frequencies is zero");
        }
        return m;
    }

    /**
     * Convert the floating-point weights to relative weights represented as
     * integers {@code value * 2^exponent}. The relative weight as an integer is:
     *
     * <pre>
     * BigInteger.valueOf(value).shiftLeft(exponent)
     * </pre>
     *
     * <p>Note that the weights are created using a common power-of-2 scaling
     * operation so the minimum exponent is zero.
     *
     * <p>A positive {@code alpha} parameter is used to set any weight to zero if
     * the magnitude is approximately 2<sup>alpha</sup> <em>smaller</em> than the
     * largest weight. This comparison is made using only the exponent of the input
     * weights.
     *
     * @param weights Weights of the discrete distribution.
     * @param values Output floating-point mantissas converted to 53-bit integers.
     * @param exponents Output power of 2 exponent.
     * @param alpha Alpha parameter.
     * @throws IllegalArgumentException if a weight is negative, infinite or
     * {@code NaN}, or the sum of all weights is zero.
     */
    private static void convertToIntegers(double[] weights, long[] values, int[] exponents, int alpha) {
        int maxExponent = Integer.MIN_VALUE;
        for (int i = 0; i < weights.length; i++) {
            final double weight = weights[i];
            // Ignore zero (including -0.0).
            // When creating the integer value later using bit shifts the result will remain zero.
            if (weight == 0) {
                continue;
            }
            final long bits = Double.doubleToRawLongBits(weight);

            // For the IEEE 754 format see Double.longBitsToDouble(long).

            // Extract the exponent (with the sign bit)
            int exp = (int) (bits >>> MANTISSA_SIZE);
            // Detect negative, infinite or NaN.
            // Note: Negative values sign bit will cause the exponent to be too high.
            if (exp > MAX_BIASED_EXPONENT) {
                throw new IllegalArgumentException("Invalid weight: " + weight);
            }
            long mantissa;
            if (exp == 0) {
                // Sub-normal number:
                mantissa = (bits & MANTISSA_MASK) << 1;
                // Here we convert to a normalised number by counting the leading zeros
                // to obtain the number of shifts of the most significant bit in
                // the mantissa that is required to get a 1 at position 53 (i.e. as
                // if it were a normal number with assumed leading bit).
                final int shift = Long.numberOfLeadingZeros(mantissa << 11);
                mantissa <<= shift;
                exp -= shift;
            } else {
                // Normal number. Add the implicit leading 1-bit.
                mantissa = (bits & MANTISSA_MASK) | (1L << MANTISSA_SIZE);
            }

            // Here the floating-point number is equal to:
            // mantissa * 2^(exp-1075)

            values[i] = mantissa;
            exponents[i] = exp;
            maxExponent = Math.max(maxExponent, exp);
        }

        // No exponent indicates that all weights are zero
        if (maxExponent == Integer.MIN_VALUE) {
            throw new IllegalArgumentException("Sum of weights is zero");
        }

        filterWeights(values, exponents, alpha, maxExponent);
        scaleWeights(values, exponents);
    }

    /**
     * Filters small weights using the {@code alpha} parameter.
     * A positive {@code alpha} parameter is used to set any weight to zero if
     * the magnitude is approximately 2<sup>alpha</sup> <em>smaller</em> than the
     * largest weight. This comparison is made using only the exponent of the input
     * weights.
     *
     * @param values 53-bit values.
     * @param exponents Power of 2 exponent.
     * @param alpha Alpha parameter.
     * @param maxExponent Maximum exponent.
     */
    private static void filterWeights(long[] values, int[] exponents, int alpha, int maxExponent) {
        if (alpha > 0) {
            // Filter weights. This must be done before the values are shifted so
            // the exponent represents the approximate magnitude of the value.
            for (int i = 0; i < exponents.length; i++) {
                if (maxExponent - exponents[i] > alpha) {
                    values[i] = 0;
                }
            }
        }
    }

    /**
     * Scale the weights represented as integers {@code value * 2^exponent} to use a
     * minimum exponent of zero. The values are scaled to remove any common trailing zeros
     * in their representation. This ultimately reduces the size of the discrete distribution
     * generating (DDG) tree.
     *
     * <p>Note: The minimum is computed over non-zero values only; the exponent of a zero
     * value may become negative and must not be used to shift a non-zero value.
     *
     * @param values 53-bit values.
     * @param exponents Power of 2 exponent.
     */
    private static void scaleWeights(long[] values, int[] exponents) {
        // Find the minimum exponent and common trailing zeros.
        int minExponent = Integer.MAX_VALUE;
        for (int i = 0; i < exponents.length; i++) {
            if (values[i] != 0) {
                minExponent = Math.min(minExponent, exponents[i]);
            }
        }
        // Trailing zeros occur when the original weights have a representation with
        // less than 52 binary digits, e.g. {1.5, 0.5, 0.25}.
        // Zero values report 64 trailing zeros and so do not affect the minimum.
        int trailingZeros = Long.SIZE;
        for (int i = 0; i < values.length && trailingZeros != 0; i++) {
            trailingZeros = Math.min(trailingZeros, Long.numberOfTrailingZeros(values[i]));
        }
        // Scale by a power of 2 so the minimum exponent is zero.
        for (int i = 0; i < exponents.length; i++) {
            exponents[i] -= minExponent;
        }
        // Remove common trailing zeros.
        if (trailingZeros != 0) {
            for (int i = 0; i < values.length; i++) {
                values[i] >>>= trailingZeros;
            }
        }
    }

    /**
     * Sum the integers at the specified indices.
     * Integers are represented as {@code value * 2^exponent}.
     *
     * @param values 53-bit values.
     * @param exponents Power of 2 exponent.
     * @param indices Indices to sum.
     * @return the sum
     */
    private static BigInteger sum(long[] values, int[] exponents, int[] indices) {
        BigInteger m = BigInteger.ZERO;
        for (final int i : indices) {
            m = m.add(toBigInteger(values[i], exponents[i]));
        }
        return m;
    }

    /**
     * Convert the value and left shift offset to a BigInteger.
     * It is assumed the value is at most 53-bits. This allows optimising the left
     * shift if it is below 11 bits.
     *
     * @param value 53-bit value.
     * @param offset Left shift offset (must be non-negative).
     * @return the BigInteger
     */
    private static BigInteger toBigInteger(long value, int offset) {
        // Zeros are not expected here: the sum method uses indices of non-zero values.
        if (offset <= MAX_OFFSET) {
            // Assume (value << offset) <= Long.MAX_VALUE
            return BigInteger.valueOf(value << offset);
        }
        return BigInteger.valueOf(value).shiftLeft(offset);
    }

    /**
     * Creates the sampler.
     *
     * <p>It is assumed the frequencies are all positive and the sum does not
     * overflow.
     *
     * @param rng Generator of uniformly distributed random numbers.
     * @param frequencies Observed frequencies of the discrete distribution.
     * @param indices Indices of non-zero frequencies.
     * @param m Sum of the frequencies.
     * @return the sampler
     */
    private static FastLoadedDiceRollerDiscreteSampler createSampler(UniformRandomProvider rng,
                                                                     long[] frequencies,
                                                                     int[] indices,
                                                                     long m) {
        // ALGORITHM 5: PREPROCESS
        // a == frequencies
        // m = sum(a)
        // h = leaf node count
        // H = leaf node label (lH)

        final int n = frequencies.length;

        // k = ceil(log2(m))
        final int k = 64 - Long.numberOfLeadingZeros(m - 1);
        // r = a(n+1) = 2^k - m
        // Note: for k == 63 this relies on two's complement wrap-around of (1L << 63).
        final long r = (1L << k) - m;

        // Note:
        // A sparse matrix can often be used for H, as most of its entries are empty.
        // This implementation uses a 1D array for efficiency at the cost of memory.
        // This is limited to approximately ((2^31 - 1) / k), e.g. 34087042 when the sum of
        // observations is large enough to create k=63.
        // This could be handled using a 2D array. In practice a number of categories this
        // large is not expected and is currently not supported.
        final int[] h = new int[k];
        final int[] lH = new int[checkArraySize((n + 1L) * k)];
        Arrays.fill(lH, NO_LABEL);

        int d;
        for (int j = 0; j < k; j++) {
            final int shift = (k - 1) - j;
            final long bitMask = 1L << shift;

            d = 0;
            for (final int i : indices) {
                // bool w ← (a[i] >> (k − 1) − j)) & 1
                // h[j] = h[j] + w
                // if w then:
                if ((frequencies[i] & bitMask) != 0) {
                    h[j]++;
                    // H[d][j] = i
                    lH[d * k + j] = i;
                    d++;
                }
            }
            // process a(n+1) without extending the input frequencies array by 1
            if ((r & bitMask) != 0) {
                h[j]++;
                lH[d * k + j] = n;
            }
        }

        return new FLDRSampler(rng, n, k, h, lH);
    }

    /**
     * Creates the sampler. Frequencies are represented as a 53-bit value with a
     * left-shift offset.
     * <pre>
     * BigInteger.valueOf(value).shiftLeft(offset)
     * </pre>
     *
     * <p>It is assumed the frequencies are all positive.
     *
     * @param rng Generator of uniformly distributed random numbers.
     * @param frequencies Observed frequencies of the discrete distribution.
     * @param offsets Left shift offsets (must be non-negative for non-zero frequencies).
     * @param indices Indices of non-zero frequencies.
     * @param m Sum of the frequencies.
     * @return the sampler
     */
    private static FastLoadedDiceRollerDiscreteSampler createSampler(UniformRandomProvider rng,
                                                                     long[] frequencies,
                                                                     int[] offsets,
                                                                     int[] indices,
                                                                     BigInteger m) {
        // Repeat the logic from createSampler(...) using extended arithmetic to test the bits

        // ALGORITHM 5: PREPROCESS
        // a == frequencies
        // m = sum(a)
        // h = leaf node count
        // H = leaf node label (lH)

        final int n = frequencies.length;

        // k = ceil(log2(m))
        final int k = m.subtract(BigInteger.ONE).bitLength();
        // r = a(n+1) = 2^k - m
        final BigInteger r = BigInteger.ONE.shiftLeft(k).subtract(m);

        final int[] h = new int[k];
        final int[] lH = new int[checkArraySize((n + 1L) * k)];
        Arrays.fill(lH, NO_LABEL);

        int d;
        for (int j = 0; j < k; j++) {
            final int shift = (k - 1) - j;

            d = 0;
            for (final int i : indices) {
                // bool w ← (a[i] >> (k − 1) − j)) & 1
                // h[j] = h[j] + w
                // if w then:
                if (testBit(frequencies[i], offsets[i], shift)) {
                    h[j]++;
                    // H[d][j] = i
                    lH[d * k + j] = i;
                    d++;
                }
            }
            // process a(n+1) without extending the input frequencies array by 1
            if (r.testBit(shift)) {
                h[j]++;
                lH[d * k + j] = n;
            }
        }

        return new FLDRSampler(rng, n, k, h, lH);
    }

    /**
     * Test the logical bit of the shifted integer representation.
     * The value is assumed to have at most 53-bits of information. The offset
     * is assumed to be non-negative. This is functionally equivalent to:
     * <pre>
     * BigInteger.valueOf(value).shiftLeft(offset).testBit(n)
     * </pre>
     *
     * @param value 53-bit value.
     * @param offset Left shift offset.
     * @param n Index of bit to test.
     * @return true if the bit is 1
     */
    private static boolean testBit(long value, int offset, int n) {
        if (n < offset) {
            // All logical trailing bits are zero
            return false;
        }
        // Test if outside the 53-bit value (note that the implicit 1 bit
        // has been added to the 52-bit mantissas for 'normal' floating-point numbers).
        // The range guard also avoids the shift distance wrapping modulo 64.
        final int bit = n - offset;
        return bit <= MANTISSA_SIZE && (value & (1L << bit)) != 0;
    }

    /**
     * Check the weights have a non-zero length.
     *
     * @param weights Weights.
     * @return the length
     * @throws IllegalArgumentException if {@code weights} is null or empty.
     */
    private static int checkWeightsNonZeroLength(double[] weights) {
        if (weights == null || weights.length == 0) {
            throw new IllegalArgumentException("weights must contain at least 1 value");
        }
        return weights.length;
    }

    /**
     * Create the indices of non-zero values.
     *
     * @param values Values.
     * @return the indices, in increasing order
     */
    private static int[] indicesOfNonZero(long[] values) {
        int n = 0;
        final int[] indices = new int[values.length];
        for (int i = 0; i < values.length; i++) {
            if (values[i] != 0) {
                indices[n++] = i;
            }
        }
        return Arrays.copyOf(indices, n);
    }

    /**
     * Find the index of the first non-zero frequency.
     *
     * @param frequencies Frequencies.
     * @return the index
     * @throws IllegalStateException if all frequencies are zero.
     */
    static int indexOfNonZero(long[] frequencies) {
        for (int i = 0; i < frequencies.length; i++) {
            if (frequencies[i] != 0) {
                return i;
            }
        }
        throw new IllegalStateException("All frequencies are zero");
    }

    /**
     * Check the size is valid for a 1D array.
     *
     * @param size Size
     * @return the size as an {@code int}
     * @throws IllegalArgumentException if the size is too large for a 1D array.
     */
    static int checkArraySize(long size) {
        if (size > MAX_ARRAY_SIZE) {
            throw new IllegalArgumentException("Unable to allocate array of size: " + size);
        }
        return (int) size;
    }
}