ZigguratNormalizedGaussianSampler.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.rng.sampling.distribution;
import org.apache.commons.rng.UniformRandomProvider;
/**
 * <a href="https://en.wikipedia.org/wiki/Ziggurat_algorithm">
 * Marsaglia and Tsang "Ziggurat" method</a> for sampling from a Gaussian
 * distribution with mean 0 and standard deviation 1.
 *
 * <p>The algorithm is explained in this
 * <a href="http://www.jstatsoft.org/article/view/v005i08/ziggurat.pdf">paper</a>
 * and this implementation has been adapted from the C code provided therein.</p>
 *
 * <p>Sampling uses:</p>
 *
 * <ul>
 *   <li>{@link UniformRandomProvider#nextLong()}</li>
 *   <li>{@link UniformRandomProvider#nextDouble()}</li>
 * </ul>
 *
 * @since 1.1
 */
public class ZigguratNormalizedGaussianSampler
    implements NormalizedGaussianSampler, SharedStateContinuousSampler {
    /** Start of tail. */
    private static final double R = 3.6541528853610088;
    /** Inverse of R. */
    private static final double ONE_OVER_R = 1 / R;
    /** Index of last entry in the tables (which have a size that is a power of 2). */
    private static final int LAST = 255;
    /**
     * Auxiliary table of fast-path thresholds: a random {@code long} {@code j} selecting
     * strip {@code i} with {@code |j| < K[i]} is accepted immediately as {@code j * W[i]}.
     */
    private static final long[] K;
    /** Auxiliary table of factors scaling a signed {@code long} to an x-coordinate in strip {@code i}. */
    private static final double[] W;
    /** Auxiliary table of (unnormalized) pdf values at the strip edges; {@code F[0] = pdf(0) = 1}. */
    private static final double[] F;

    /** Underlying source of randomness. */
    private final UniformRandomProvider rng;

    static {
        // Filling the tables.
        // Common area of each layer of the ziggurat (the base layer includes the tail).
        final double v = 0.00492867323399;
        // Direction support uses the sign bit so the maximum magnitude from the long is 2^63.
        final double max = Math.pow(2, 63);
        final double oneOverMax = 1d / max;

        K = new long[LAST + 1];
        W = new double[LAST + 1];
        F = new double[LAST + 1];

        // d: right edge of the layer being computed; t: right edge of the previous (wider) layer.
        double d = R;
        double t = d;
        double fd = pdf(d);
        // Width of the base layer when its area v (rectangle plus tail) is laid out
        // as a rectangle of height pdf(R).
        final double q = v / fd;

        K[0] = (long) ((d / q) * max);
        // Strip 1 is the top of the ziggurat: a zero threshold sends every candidate
        // in it to the wedge test in fix().
        K[1] = 0;

        W[0] = q * oneOverMax;
        W[LAST] = d * oneOverMax;

        F[0] = 1;
        F[LAST] = fd;

        // Walk up the ziggurat. Each new edge satisfies
        // d_prev * (pdf(d_new) - pdf(d_prev)) = v, so every rectangle has area v.
        for (int i = LAST - 1; i >= 1; i--) {
            d = Math.sqrt(-2 * Math.log(v / d + fd));
            fd = pdf(d);

            K[i + 1] = (long) ((d / t) * max);
            t = d;

            F[i] = fd;

            W[i] = d * oneOverMax;
        }
    }

    /**
     * Create an instance.
     *
     * @param rng Generator of uniformly distributed random numbers.
     */
    public ZigguratNormalizedGaussianSampler(UniformRandomProvider rng) {
        this.rng = rng;
    }

    /** {@inheritDoc} */
    @Override
    public double sample() {
        final long j = rng.nextLong();
        // The low 8 bits select the strip; the full signed value is the position within it.
        final int i = ((int) j) & LAST;
        // Math.abs(Long.MIN_VALUE) overflows to Long.MIN_VALUE (negative), which would
        // always pass the threshold test. That value always selects strip 0, where the
        // fast path would return -W[0] * 2^63 (about -3.91), a point beyond the tail
        // start R that must instead be drawn from the tail. Route it to fix().
        if (Math.abs(j) < K[i] && j != Long.MIN_VALUE) {
            // This branch is called about 0.985086 times per sample.
            return j * W[i];
        }
        return fix(j, i);
    }

    /** {@inheritDoc} */
    @Override
    public String toString() {
        return "Ziggurat normalized Gaussian deviate [" + rng.toString() + "]";
    }

    /**
     * Handles a candidate that fell outside the inner rectangle of its strip.
     *
     * <p>For the base strip ({@code iz == 0}) a value is drawn from the tail beyond
     * {@link #R}. For any other strip the candidate is tested against the pdf in the
     * wedge region; if rejected, a completely new sample is drawn.</p>
     *
     * @param hz Start random integer; its sign gives the sign of a tail sample.
     * @param iz Index of cell corresponding to {@code hz}.
     * @return the requested random value.
     */
    private double fix(long hz,
                       int iz) {
        if (iz == 0) {
            // Base strip.
            // This branch is called about 2.55224E-4 times per sample.
            double y;
            double x;
            do {
                // Avoid infinity by creating a non-zero double.
                // Note: The extreme value y from -Math.log(2^-53) is (to 4 sf):
                // y = 36.74
                // The largest value x where 2y < x^2 is false is sqrt(2*36.74):
                // x = 8.571
                // The extreme tail is:
                // out = +/- 12.01
                // To generate this requires longs of 0 and then (1377 << 11).
                y = -Math.log(InternalUtils.makeNonZeroDouble(rng.nextLong()));
                x = -Math.log(InternalUtils.makeNonZeroDouble(rng.nextLong())) * ONE_OVER_R;
            } while (y + y < x * x);

            final double out = R + x;
            return hz > 0 ? out : -out;
        }

        // Wedge of other strips.
        // This branch is called about 0.0146584 times per sample.
        final double x = hz * W[iz];
        if (F[iz] + rng.nextDouble() * (F[iz - 1] - F[iz]) < pdf(x)) {
            // This branch is called about 0.00797887 times per sample.
            return x;
        }
        // Try again.
        // This branch is called about 0.00667957 times per sample.
        return sample();
    }

    /**
     * Compute the unnormalized Gaussian probability density function
     * {@code f(x) = e^-0.5x^2}. The {@code 1 / sqrt(2 pi)} factor is omitted
     * because all the tables are built and compared on the same scale.
     *
     * @param x Argument.
     * @return \( e^{-\frac{x^2}{2}} \)
     */
    private static double pdf(double x) {
        return Math.exp(-0.5 * x * x);
    }

    /**
     * {@inheritDoc}
     *
     * @since 1.3
     */
    @Override
    public SharedStateContinuousSampler withUniformRandomProvider(UniformRandomProvider rng) {
        return new ZigguratNormalizedGaussianSampler(rng);
    }

    /**
     * Create a new normalized Gaussian sampler.
     *
     * @param <S> Sampler type.
     * @param rng Generator of uniformly distributed random numbers.
     * @return the sampler
     * @since 1.3
     */
    @SuppressWarnings("unchecked")
    public static <S extends NormalizedGaussianSampler & SharedStateContinuousSampler> S
            of(UniformRandomProvider rng) {
        return (S) new ZigguratNormalizedGaussianSampler(rng);
    }
}