// UnitBallSampler.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.rng.sampling.shape;

import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.SharedStateObjectSampler;
import org.apache.commons.rng.sampling.distribution.ContinuousSampler;
import org.apache.commons.rng.sampling.distribution.NormalizedGaussianSampler;
import org.apache.commons.rng.sampling.distribution.ZigguratSampler;

/**
 * Generate coordinates <a href="http://mathworld.wolfram.com/BallPointPicking.html">
 * uniformly distributed within the unit n-ball</a>.
 *
 * <p>Sampling uses:</p>
 *
 * <ul>
 *   <li>{@link UniformRandomProvider#nextLong()}
 *   <li>{@link UniformRandomProvider#nextDouble()} (only for dimensions above 2)
 * </ul>
 *
 * @since 1.4
 */
public abstract class UnitBallSampler implements SharedStateObjectSampler<double[]> {
    /** Dimension handled by the dedicated line (1D) sampler. */
    private static final int ONE_D = 1;
    /** Dimension handled by the dedicated disk (2D) sampler. */
    private static final int TWO_D = 2;
    /** Dimension handled by the dedicated ball (3D) sampler. */
    private static final int THREE_D = 3;
    /**
     * Scale factor mapping a 53-bit integer magnitude onto the unit interval.
     * Copied from o.a.c.rng.core.utils.NumberFactory.
     *
     * <p>The value is 2<sup>-53</sup>, i.e. {@code 1.0 / (1L << 53)}.
     */
    private static final double DOUBLE_MULTIPLIER = 0x1.0p-53d;

    /**
     * Sampler for the 1D unit ball: a single coordinate in {@code [-1, 1)}.
     */
    private static final class UnitBallSampler1D extends UnitBallSampler {
        /** Generator supplying the random bits. */
        private final UniformRandomProvider source;

        /**
         * @param source Source of randomness.
         */
        UnitBallSampler1D(UniformRandomProvider source) {
            this.source = source;
        }

        @Override
        public double[] sample() {
            // One 64-bit draw is converted directly to the signed coordinate.
            final double coordinate = makeSignedDouble(source.nextLong());
            return new double[] {coordinate};
        }

        @Override
        public UnitBallSampler withUniformRandomProvider(UniformRandomProvider rng) {
            return new UnitBallSampler1D(rng);
        }
    }

    /**
     * Sampler for the 2D unit ball (a disk) using rejection from the enclosing square.
     */
    private static final class UnitBallSampler2D extends UnitBallSampler {
        /** Generator supplying the random bits. */
        private final UniformRandomProvider source;

        /**
         * @param source Source of randomness.
         */
        UnitBallSampler2D(UniformRandomProvider source) {
            this.source = source;
        }

        @Override
        public double[] sample() {
            // Rejection sampling: draw a point in the square of edge length 2 and keep
            // the first one that lies inside the unit circle.
            // The expected number of candidate points per sample is 2^2 / pi = 1.27.
            for (;;) {
                final double x = makeSignedDouble(source.nextLong());
                final double y = makeSignedDouble(source.nextLong());
                // Both coordinates are finite so this is the exact complement of the
                // rejection condition (squared radius above 1).
                if (x * x + y * y <= 1.0) {
                    return new double[] {x, y};
                }
            }
        }

        @Override
        public UnitBallSampler withUniformRandomProvider(UniformRandomProvider rng) {
            return new UnitBallSampler2D(rng);
        }
    }

    /**
     * Sampler for the 3D unit ball. This is a non-array based specialisation of
     * {@link UnitBallSamplerND}, written out for performance.
     */
    private static final class UnitBallSampler3D extends UnitBallSampler {
        /** Sampler for the standard normal distribution. */
        private final NormalizedGaussianSampler gaussian;
        /** Sampler for the exponential distribution (mean=1). */
        private final ContinuousSampler exponential;

        /**
         * @param rng Source of randomness.
         */
        UnitBallSampler3D(UniformRandomProvider rng) {
            gaussian = ZigguratSampler.NormalizedGaussian.of(rng);
            // The algorithm needs an Exponential(mean=2) deviate.
            // A mean = 1 sampler is used and its output is doubled when sampling.
            exponential = ZigguratSampler.Exponential.of(rng);
        }

        @Override
        public double[] sample() {
            // Repeat until the squared length is non-zero so the inverse is valid.
            for (;;) {
                final double x = gaussian.sample();
                final double y = gaussian.sample();
                final double z = gaussian.sample();
                // The exponential deviate has mean 1; doubling it gives the required mean 2.
                final double sum = exponential.sample() * 2 + x * x + y * y + z * z;
                if (sum != 0) {
                    final double scale = 1.0 / Math.sqrt(sum);
                    return new double[] {x * scale, y * scale, z * scale};
                }
            }
        }

        @Override
        public UnitBallSampler withUniformRandomProvider(UniformRandomProvider rng) {
            return new UnitBallSampler3D(rng);
        }
    }

    /**
     * General n-dimension sampler based on ball point picking.
     * @see <a href="https://mathworld.wolfram.com/BallPointPicking.html">Ball point picking</a>
     */
    private static final class UnitBallSamplerND extends UnitBallSampler {
        /** Number of coordinates in each sample. */
        private final int dimension;
        /** Sampler for the standard normal distribution. */
        private final NormalizedGaussianSampler gaussian;
        /** Sampler for the exponential distribution (mean=1). */
        private final ContinuousSampler exponential;

        /**
         * @param rng Source of randomness.
         * @param dimension Space dimension.
         */
        UnitBallSamplerND(UniformRandomProvider rng, int dimension) {
            this.dimension = dimension;
            gaussian = ZigguratSampler.NormalizedGaussian.of(rng);
            // The algorithm needs an Exponential(mean=2) deviate.
            // A mean = 1 sampler is used and its output is doubled when sampling.
            exponential = ZigguratSampler.Exponential.of(rng);
        }

        @Override
        public double[] sample() {
            final double[] point = new double[dimension];
            double sum;
            // Repeat until the squared length is non-zero so the inverse is valid.
            // Every coordinate is overwritten on each attempt.
            do {
                // The exponential deviate has mean 1; doubling it gives the required mean 2.
                sum = exponential.sample() * 2;
                for (int i = 0; i < dimension; i++) {
                    final double g = gaussian.sample();
                    sum += g * g;
                    point[i] = g;
                }
            } while (sum == 0);
            // Scale the Gaussian vector by the inverse of the augmented length.
            final double scale = 1.0 / Math.sqrt(sum);
            for (int i = 0; i < dimension; i++) {
                point[i] *= scale;
            }
            return point;
        }

        @Override
        public UnitBallSampler withUniformRandomProvider(UniformRandomProvider rng) {
            return new UnitBallSamplerND(rng, dimension);
        }
    }

    /**
     * Create an instance.
     */
    public UnitBallSampler() {}

    /**
     * @return a random Cartesian coordinate within the unit n-ball.
     */
    @Override
    public abstract double[] sample();

    /** {@inheritDoc} */
    // The return type is narrowed from SharedStateObjectSampler<double[]> to UnitBallSampler.
    @Override
    public abstract UnitBallSampler withUniformRandomProvider(UniformRandomProvider rng);

    /**
     * Create a unit n-ball sampler for the given dimension.
     * Sampled points are uniformly distributed within the unit n-ball.
     *
     * <p>Sampling is supported in dimensions of 1 or above. Dimensions 1, 2 and 3
     * use dedicated implementations; any higher dimension uses the general sampler.
     *
     * @param rng Source of randomness.
     * @param dimension Space dimension.
     * @return the sampler
     * @throws IllegalArgumentException If {@code dimension <= 0}
     */
    public static UnitBallSampler of(UniformRandomProvider rng,
                                     int dimension) {
        if (dimension <= 0) {
            throw new IllegalArgumentException("Dimension must be strictly positive");
        }
        switch (dimension) {
        case ONE_D:
            return new UnitBallSampler1D(rng);
        case TWO_D:
            return new UnitBallSampler2D(rng);
        case THREE_D:
            return new UnitBallSampler3D(rng);
        default:
            return new UnitBallSamplerND(rng, dimension);
        }
    }

    /**
     * Converts random bits to a signed double in the range {@code [-1, 1)}. The value is
     * one of the 2<sup>54</sup> evenly spaced dyadic rationals in that range.
     *
     * <p>Note: This method will not return samples for both -0.0 and 0.0.
     *
     * @param bits the bits
     * @return the double
     */
    private static double makeSignedDouble(long bits) {
        // Follows o.a.c.rng.core.utils.NumberFactory.makeDouble(long) but replaces the
        // unsigned shift of 11 with a signed shift of 10.
        // The upper 54 bits are used on the assumption they are more random.
        // The arithmetic shift keeps the sign bit; the remaining 53 bits give a
        // magnitude in the range [0, 2^53) or [-2^53, 0).
        final long signedBits = bits >> 10;
        return signedBits * DOUBLE_MULTIPLIER;
    }
}