001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.commons.rng.sampling.distribution; 019 020import org.apache.commons.rng.UniformRandomProvider; 021 022/** 023 * <a href="https://en.wikipedia.org/wiki/Ziggurat_algorithm"> 024 * Marsaglia and Tsang "Ziggurat" method</a> for sampling from a Gaussian 025 * distribution with mean 0 and standard deviation 1. 026 * 027 * <p>The algorithm is explained in this 028 * <a href="http://www.jstatsoft.org/article/view/v005i08/ziggurat.pdf">paper</a> 029 * and this implementation has been adapted from the C code provided therein.</p> 030 * 031 * <p>Sampling uses:</p> 032 * 033 * <ul> 034 * <li>{@link UniformRandomProvider#nextLong()} 035 * <li>{@link UniformRandomProvider#nextDouble()} 036 * </ul> 037 * 038 * @since 1.1 039 */ 040public class ZigguratNormalizedGaussianSampler 041 implements NormalizedGaussianSampler, SharedStateContinuousSampler { 042 /** Start of tail. */ 043 private static final double R = 3.6541528853610088; 044 /** Inverse of R. */ 045 private static final double ONE_OVER_R = 1 / R; 046 /** Index of last entry in the tables (which have a size that is a power of 2). */ 047 private static final int LAST = 255; 048 /** Auxiliary table. */ 049 private static final long[] K; 050 /** Auxiliary table. */ 051 private static final double[] W; 052 /** Auxiliary table. */ 053 private static final double[] F; 054 055 /** Underlying source of randomness. */ 056 private final UniformRandomProvider rng; 057 058 static { 059 // Filling the tables. 060 // Rectangle area. 061 final double v = 0.00492867323399; 062 // Direction support uses the sign bit so the maximum magnitude from the long is 2^63 063 final double max = Math.pow(2, 63); 064 final double oneOverMax = 1d / max; 065 066 K = new long[LAST + 1]; 067 W = new double[LAST + 1]; 068 F = new double[LAST + 1]; 069 070 double d = R; 071 double t = d; 072 double fd = pdf(d); 073 final double q = v / fd; 074 075 K[0] = (long) ((d / q) * max); 076 K[1] = 0; 077 078 W[0] = q * oneOverMax; 079 W[LAST] = d * oneOverMax; 080 081 F[0] = 1; 082 F[LAST] = fd; 083 084 for (int i = LAST - 1; i >= 1; i--) { 085 d = Math.sqrt(-2 * Math.log(v / d + fd)); 086 fd = pdf(d); 087 088 K[i + 1] = (long) ((d / t) * max); 089 t = d; 090 091 F[i] = fd; 092 093 W[i] = d * oneOverMax; 094 } 095 } 096 097 /** 098 * Create an instance. 099 * 100 * @param rng Generator of uniformly distributed random numbers. 101 */ 102 public ZigguratNormalizedGaussianSampler(UniformRandomProvider rng) { 103 this.rng = rng; 104 } 105 106 /** {@inheritDoc} */ 107 @Override 108 public double sample() { 109 final long j = rng.nextLong(); 110 final int i = ((int) j) & LAST; 111 if (Math.abs(j) < K[i]) { 112 // This branch is called about 0.985086 times per sample. 113 return j * W[i]; 114 } 115 return fix(j, i); 116 } 117 118 /** {@inheritDoc} */ 119 @Override 120 public String toString() { 121 return "Ziggurat normalized Gaussian deviate [" + rng.toString() + "]"; 122 } 123 124 /** 125 * Gets the value from the tail of the distribution. 126 * 127 * @param hz Start random integer. 128 * @param iz Index of cell corresponding to {@code hz}. 129 * @return the requested random value. 130 */ 131 private double fix(long hz, 132 int iz) { 133 if (iz == 0) { 134 // Base strip. 135 // This branch is called about 2.55224E-4 times per sample. 136 double y; 137 double x; 138 do { 139 // Avoid infinity by creating a non-zero double. 140 // Note: The extreme value y from -Math.log(2^-53) is (to 4 sf): 141 // y = 36.74 142 // The largest value x where 2y < x^2 is false is sqrt(2*36.74): 143 // x = 8.571 144 // The extreme tail is: 145 // out = +/- 12.01 146 // To generate this requires longs of 0 and then (1377 << 11). 147 y = -Math.log(InternalUtils.makeNonZeroDouble(rng.nextLong())); 148 x = -Math.log(InternalUtils.makeNonZeroDouble(rng.nextLong())) * ONE_OVER_R; 149 } while (y + y < x * x); 150 151 final double out = R + x; 152 return hz > 0 ? out : -out; 153 } 154 // Wedge of other strips. 155 // This branch is called about 0.0146584 times per sample. 156 final double x = hz * W[iz]; 157 if (F[iz] + rng.nextDouble() * (F[iz - 1] - F[iz]) < pdf(x)) { 158 // This branch is called about 0.00797887 times per sample. 159 return x; 160 } 161 // Try again. 162 // This branch is called about 0.00667957 times per sample. 163 return sample(); 164 } 165 166 /** 167 * Compute the Gaussian probability density function {@code f(x) = e^-0.5x^2}. 168 * 169 * @param x Argument. 170 * @return \( e^{-\frac{x^2}{2}} \) 171 */ 172 private static double pdf(double x) { 173 return Math.exp(-0.5 * x * x); 174 } 175 176 /** 177 * {@inheritDoc} 178 * 179 * @since 1.3 180 */ 181 @Override 182 public SharedStateContinuousSampler withUniformRandomProvider(UniformRandomProvider rng) { 183 return new ZigguratNormalizedGaussianSampler(rng); 184 } 185 186 /** 187 * Create a new normalised Gaussian sampler. 188 * 189 * @param <S> Sampler type. 190 * @param rng Generator of uniformly distributed random numbers. 191 * @return the sampler 192 * @since 1.3 193 */ 194 @SuppressWarnings("unchecked") 195 public static <S extends NormalizedGaussianSampler & SharedStateContinuousSampler> S 196 of(UniformRandomProvider rng) { 197 return (S) new ZigguratNormalizedGaussianSampler(rng); 198 } 199}