UnitBallSampler.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.rng.sampling.shape;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.SharedStateObjectSampler;
import org.apache.commons.rng.sampling.distribution.ContinuousSampler;
import org.apache.commons.rng.sampling.distribution.NormalizedGaussianSampler;
import org.apache.commons.rng.sampling.distribution.ZigguratSampler;
/**
* Generate coordinates <a href="http://mathworld.wolfram.com/BallPointPicking.html">
* uniformly distributed within the unit n-ball</a>.
*
* <p>Sampling uses:</p>
*
* <ul>
* <li>{@link UniformRandomProvider#nextLong()}
* <li>{@link UniformRandomProvider#nextDouble()} (only for dimensions above 2)
* </ul>
*
* @since 1.4
*/
public abstract class UnitBallSampler implements SharedStateObjectSampler<double[]> {
    /** Dimension served by the dedicated line sampler. */
    private static final int ONE_D = 1;
    /** Dimension served by the dedicated disk sampler. */
    private static final int TWO_D = 2;
    /** Dimension served by the dedicated array-free ball sampler. */
    private static final int THREE_D = 3;
    /**
     * Scale factor 2<sup>-53</sup> mapping a 53-bit integer magnitude onto a {@code double}
     * fraction. Same value as in o.a.c.rng.core.utils.NumberFactory.
     *
     * <p>Identical to {@code 1.0 / (1L << 53)}.
     */
    private static final double DOUBLE_MULTIPLIER = 0x1.0p-53d;

    /**
     * Uniform sampler for the 1D unit ball, i.e. the interval {@code [-1, 1)}.
     */
    private static final class UnitBallSampler1D extends UnitBallSampler {
        /** Generator supplying the random bits. */
        private final UniformRandomProvider source;

        /**
         * @param rng Source of randomness.
         */
        UnitBallSampler1D(UniformRandomProvider rng) {
            source = rng;
        }

        @Override
        public double[] sample() {
            final double coordinate = makeSignedDouble(source.nextLong());
            return new double[] {coordinate};
        }

        @Override
        public UnitBallSampler withUniformRandomProvider(UniformRandomProvider rng) {
            return new UnitBallSampler1D(rng);
        }
    }

    /**
     * Uniform sampler for the 2D unit disk, using rejection from the enclosing square.
     */
    private static final class UnitBallSampler2D extends UnitBallSampler {
        /** Generator supplying the random bits. */
        private final UniformRandomProvider source;

        /**
         * @param rng Source of randomness.
         */
        UnitBallSampler2D(UniformRandomProvider rng) {
            source = rng;
        }

        @Override
        public double[] sample() {
            // Draw candidate points from the square [-1, 1)^2 (edge length 2) and keep the
            // first one inside the unit circle. Expected candidates per sample: 4 / pi ~ 1.27.
            while (true) {
                final double x = makeSignedDouble(source.nextLong());
                final double y = makeSignedDouble(source.nextLong());
                if (x * x + y * y <= 1.0) {
                    return new double[] {x, y};
                }
            }
        }

        @Override
        public UnitBallSampler withUniformRandomProvider(UniformRandomProvider rng) {
            return new UnitBallSampler2D(rng);
        }
    }

    /**
     * Uniform sampler for the 3D unit ball. A specialisation of {@link UnitBallSamplerND}
     * that avoids array-based loops for performance.
     */
    private static final class UnitBallSampler3D extends UnitBallSampler {
        /** Standard normal deviates for the coordinates. */
        private final NormalizedGaussianSampler gaussian;
        /** Exponential deviates with mean 1 (doubled when used). */
        private final ContinuousSampler exponential;

        /**
         * @param rng Source of randomness.
         */
        UnitBallSampler3D(UniformRandomProvider rng) {
            gaussian = ZigguratSampler.NormalizedGaussian.of(rng);
            // The algorithm needs Exponential(mean=2); a mean-1 sampler is
            // created here and its output is doubled in sample().
            exponential = ZigguratSampler.Exponential.of(rng);
        }

        @Override
        public double[] sample() {
            while (true) {
                final double x = gaussian.sample();
                final double y = gaussian.sample();
                final double z = gaussian.sample();
                // Doubling the mean-1 exponential draw gives the Exponential(mean=2) term.
                final double sum = exponential.sample() * 2 + x * x + y * y + z * z;
                // A zero sum has no valid inverse square root: draw a fresh set instead.
                if (sum != 0) {
                    final double scale = 1.0 / Math.sqrt(sum);
                    return new double[] {x * scale, y * scale, z * scale};
                }
            }
        }

        @Override
        public UnitBallSampler withUniformRandomProvider(UniformRandomProvider rng) {
            return new UnitBallSampler3D(rng);
        }
    }

    /**
     * Uniform sampler for the unit n-ball using ball point picking.
     *
     * <p>The sampler draws {@code n} standard normal deviates. Each is divided by the square
     * root of their squared norm plus an Exponential(mean=2) deviate.
     *
     * @see <a href="https://mathworld.wolfram.com/BallPointPicking.html">Ball point picking</a>
     */
    private static final class UnitBallSamplerND extends UnitBallSampler {
        /** The space dimension. */
        private final int dimension;
        /** Standard normal deviates for the coordinates. */
        private final NormalizedGaussianSampler gaussian;
        /** Exponential deviates with mean 1 (doubled when used). */
        private final ContinuousSampler exponential;

        /**
         * @param rng Source of randomness.
         * @param dimension Space dimension.
         */
        UnitBallSamplerND(UniformRandomProvider rng, int dimension) {
            this.dimension = dimension;
            gaussian = ZigguratSampler.NormalizedGaussian.of(rng);
            // The algorithm needs Exponential(mean=2); a mean-1 sampler is
            // created here and its output is doubled in sample().
            exponential = ZigguratSampler.Exponential.of(rng);
        }

        @Override
        public double[] sample() {
            final double[] point = new double[dimension];
            double sum;
            do {
                // The exponential term is drawn before the coordinates.
                sum = exponential.sample() * 2;
                for (int i = 0; i < point.length; i++) {
                    final double g = gaussian.sample();
                    point[i] = g;
                    sum += g * g;
                }
                // A zero sum has no valid inverse square root: redraw everything.
            } while (sum == 0);
            final double scale = 1.0 / Math.sqrt(sum);
            for (int i = 0; i < point.length; i++) {
                point[i] *= scale;
            }
            return point;
        }

        @Override
        public UnitBallSampler withUniformRandomProvider(UniformRandomProvider rng) {
            return new UnitBallSamplerND(rng, dimension);
        }
    }

    /**
     * Create an instance.
     */
    public UnitBallSampler() {}

    /**
     * @return a random Cartesian coordinate within the unit n-ball.
     */
    @Override
    public abstract double[] sample();

    /** {@inheritDoc} */
    // Redeclare the signature to return a UnitBallSampler not a SharedStateObjectSampler<double[]>
    @Override
    public abstract UnitBallSampler withUniformRandomProvider(UniformRandomProvider rng);

    /**
     * Create a unit n-ball sampler for the given dimension.
     * Sampled points are uniformly distributed within the unit n-ball.
     *
     * <p>Dimensions 1, 2 and 3 use dedicated implementations. Any larger dimension uses
     * the generic ball point picking sampler. Sampling is supported in dimensions of 1 or above.
     *
     * @param rng Source of randomness.
     * @param dimension Space dimension.
     * @return the sampler
     * @throws IllegalArgumentException If {@code dimension <= 0}
     */
    public static UnitBallSampler of(UniformRandomProvider rng,
                                     int dimension) {
        if (dimension <= 0) {
            throw new IllegalArgumentException("Dimension must be strictly positive");
        }
        switch (dimension) {
        case ONE_D:
            return new UnitBallSampler1D(rng);
        case TWO_D:
            return new UnitBallSampler2D(rng);
        case THREE_D:
            return new UnitBallSampler3D(rng);
        default:
            return new UnitBallSamplerND(rng, dimension);
        }
    }

    /**
     * Converts random bits to a signed double in {@code [-1, 1)}. The result is one of the
     * 2<sup>54</sup> evenly spaced dyadic rationals in that range.
     *
     * <p>Note: zero is only ever produced as 0.0, never -0.0.
     *
     * @param bits the bits
     * @return the double
     */
    private static double makeSignedDouble(long bits) {
        // Variant of o.a.c.rng.core.utils.NumberFactory.makeDouble(long). It keeps the top
        // 54 bits, on the assumption that they are more random. It uses an arithmetic shift
        // of 10 instead of a logical shift of 11, so the sign bit survives. The shifted
        // value lies in [-2^53, 2^53), and scaling by 2^-53 maps it onto [-1, 1).
        final long signedMagnitude = bits >> 10;
        return signedMagnitude * DOUBLE_MULTIPLIER;
    }
}