UniformRandomProviderSupport.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.rng;
import java.util.Objects;
import java.util.Spliterator;
import java.util.function.Consumer;
import java.util.function.DoubleConsumer;
import java.util.function.IntConsumer;
import java.util.function.LongConsumer;
import java.util.function.ToDoubleFunction;
import java.util.function.ToIntFunction;
import java.util.function.ToLongFunction;
/**
 * Support for {@link UniformRandomProvider} default methods.
 *
 * <p>Contains argument validation helpers, shared implementations of the bounded and
 * ranged value generation methods, and {@link Spliterator} implementations that
 * generate stream values from a {@link SplittableUniformRandomProvider}.
 *
 * @since 1.5
 */
final class UniformRandomProviderSupport {
    /** Message for an invalid stream size. */
    private static final String INVALID_STREAM_SIZE = "Invalid stream size: ";
    /** Message for an invalid upper bound (must be above zero; floating-point bounds must also be finite). */
    private static final String INVALID_UPPER_BOUND = "Upper bound must be above zero: ";
    /** Message format for an invalid range for lower inclusive and upper exclusive. */
    private static final String INVALID_RANGE = "Invalid range: [%s, %s)";
    /** 2^32. */
    private static final long POW_32 = 1L << 32;
    /** Message when the consumer action is null. */
    private static final String NULL_ACTION = "action must not be null";

    /** No instances. */
    private UniformRandomProviderSupport() {}

    /**
     * Validate the stream size.
     *
     * @param size Stream size.
     * @throws IllegalArgumentException if {@code size} is negative.
     */
    static void validateStreamSize(long size) {
        if (size < 0) {
            throw new IllegalArgumentException(INVALID_STREAM_SIZE + size);
        }
    }

    /**
     * Validate the upper bound.
     *
     * @param bound Upper bound (exclusive) on the random number to be returned.
     * @throws IllegalArgumentException if {@code bound} is equal to or less than zero.
     */
    static void validateUpperBound(int bound) {
        if (bound <= 0) {
            throw new IllegalArgumentException(INVALID_UPPER_BOUND + bound);
        }
    }

    /**
     * Validate the upper bound.
     *
     * @param bound Upper bound (exclusive) on the random number to be returned.
     * @throws IllegalArgumentException if {@code bound} is equal to or less than zero.
     */
    static void validateUpperBound(long bound) {
        if (bound <= 0) {
            throw new IllegalArgumentException(INVALID_UPPER_BOUND + bound);
        }
    }

    /**
     * Validate the upper bound.
     *
     * @param bound Upper bound (exclusive) on the random number to be returned.
     * @throws IllegalArgumentException if {@code bound} is equal to or less than zero, or
     * is not finite
     */
    static void validateUpperBound(float bound) {
        // Negation of logic will detect NaN (all comparisons with NaN are false).
        // The upper limit of MAX_VALUE rejects positive infinity.
        if (!(bound > 0 && bound <= Float.MAX_VALUE)) {
            throw new IllegalArgumentException(INVALID_UPPER_BOUND + bound);
        }
    }

    /**
     * Validate the upper bound.
     *
     * @param bound Upper bound (exclusive) on the random number to be returned.
     * @throws IllegalArgumentException if {@code bound} is equal to or less than zero, or
     * is not finite
     */
    static void validateUpperBound(double bound) {
        // Negation of logic will detect NaN (all comparisons with NaN are false).
        // The upper limit of MAX_VALUE rejects positive infinity.
        if (!(bound > 0 && bound <= Double.MAX_VALUE)) {
            throw new IllegalArgumentException(INVALID_UPPER_BOUND + bound);
        }
    }

    /**
     * Validate the range between the specified {@code origin} (inclusive) and the
     * specified {@code bound} (exclusive).
     *
     * @param origin Lower bound on the random number to be returned.
     * @param bound Upper bound (exclusive) on the random number to be returned.
     * @throws IllegalArgumentException if {@code origin} is greater than or equal to
     * {@code bound}.
     */
    static void validateRange(int origin, int bound) {
        if (origin >= bound) {
            throw new IllegalArgumentException(String.format(INVALID_RANGE, origin, bound));
        }
    }

    /**
     * Validate the range between the specified {@code origin} (inclusive) and the
     * specified {@code bound} (exclusive).
     *
     * @param origin Lower bound on the random number to be returned.
     * @param bound Upper bound (exclusive) on the random number to be returned.
     * @throws IllegalArgumentException if {@code origin} is greater than or equal to
     * {@code bound}.
     */
    static void validateRange(long origin, long bound) {
        if (origin >= bound) {
            throw new IllegalArgumentException(String.format(INVALID_RANGE, origin, bound));
        }
    }

    /**
     * Validate the range between the specified {@code origin} (inclusive) and the
     * specified {@code bound} (exclusive).
     *
     * @param origin Lower bound on the random number to be returned.
     * @param bound Upper bound (exclusive) on the random number to be returned.
     * @throws IllegalArgumentException if {@code origin} is not finite, or {@code bound}
     * is not finite, or {@code origin} is greater than or equal to {@code bound}.
     */
    static void validateRange(double origin, double bound) {
        if (origin >= bound || !Double.isFinite(origin) || !Double.isFinite(bound)) {
            throw new IllegalArgumentException(String.format(INVALID_RANGE, origin, bound));
        }
    }

    /**
     * Checks if the sub-range from fromIndex (inclusive) to fromIndex + size (exclusive) is
     * within the bounds of range from 0 (inclusive) to length (exclusive).
     *
     * <p>This function provides the functionality of
     * {@code java.util.Objects.checkFromIndexSize} introduced in JDK 9. The
     * <a href="https://docs.oracle.com/en/java/javase/11/docs/api/java.base/java/util/Objects.html#checkFromIndexSize(int,int,int)">Objects</a>
     * javadoc has been reproduced for reference.
     *
     * <p>The sub-range is defined to be out of bounds if any of the following inequalities
     * is true:
     * <ul>
     * <li>{@code fromIndex < 0}
     * <li>{@code size < 0}
     * <li>{@code fromIndex + size > length}, taking into account integer overflow
     * <li>{@code length < 0}, which is implied from the former inequalities
     * </ul>
     *
     * <p>Note: This is not an exact implementation of the functionality of
     * {@code Objects.checkFromIndexSize}. The following changes have been made:
     * <ul>
     * <li>The method signature has been changed to avoid the return of {@code fromIndex};
     * this value is not used within this package.
     * <li>No checks are made for {@code length < 0} as this is assumed to be derived from
     * an array length.
     * </ul>
     *
     * @param fromIndex the lower-bound (inclusive) of the sub-interval
     * @param size the size of the sub-range
     * @param length the upper-bound (exclusive) of the range
     * @throws IndexOutOfBoundsException if the sub-range is out of bounds
     */
    static void validateFromIndexSize(int fromIndex, int size, int length) {
        // check for any negatives (assume 'length' is positive array length),
        // or overflow safe length check given the values are all positive
        // remaining = length - fromIndex
        if ((fromIndex | size) < 0 || size > length - fromIndex) {
            throw new IndexOutOfBoundsException(
                // Note: %<d is 'relative indexing' to re-use the last argument
                String.format("Range [%d, %<d + %d) out of bounds for length %d",
                    fromIndex, size, length));
        }
    }

    /**
     * Generates random bytes and places them into a user-supplied array.
     *
     * <p>The array is filled with bytes extracted from random {@code long} values. This
     * implies that the number of random bytes generated may be larger than the length of
     * the byte array.
     *
     * <p>No argument validation is performed by this method; the range can be checked
     * beforehand using {@link #validateFromIndexSize(int, int, int)}.
     *
     * @param source Source of randomness.
     * @param bytes Array in which to put the generated bytes. Cannot be null.
     * @param start Index at which to start inserting the generated bytes.
     * @param len Number of bytes to insert.
     */
    static void nextBytes(UniformRandomProvider source,
                          byte[] bytes, int start, int len) {
        // Index of first insertion plus multiple of 8 part of length
        // (i.e. length with 3 least significant bits unset).
        final int indexLoopLimit = start + (len & 0x7ffffff8);

        // Start filling in the byte array, 8 bytes at a time.
        // Bytes are written least significant first.
        int index = start;
        while (index < indexLoopLimit) {
            final long random = source.nextLong();
            bytes[index++] = (byte) random;
            bytes[index++] = (byte) (random >>> 8);
            bytes[index++] = (byte) (random >>> 16);
            bytes[index++] = (byte) (random >>> 24);
            bytes[index++] = (byte) (random >>> 32);
            bytes[index++] = (byte) (random >>> 40);
            bytes[index++] = (byte) (random >>> 48);
            bytes[index++] = (byte) (random >>> 56);
        }

        // Index of last insertion + 1
        final int indexLimit = start + len;

        // Fill in the remaining (fewer than 8) bytes from a single long.
        if (index < indexLimit) {
            long random = source.nextLong();
            for (;;) {
                bytes[index++] = (byte) random;
                if (index == indexLimit) {
                    break;
                }
                random >>>= 8;
            }
        }
    }

    /**
     * Generates an {@code int} value between 0 (inclusive) and the specified value
     * (exclusive).
     *
     * @param source Source of randomness.
     * @param n Bound on the random number to be returned. Must be strictly positive.
     * @return a random {@code int} value between 0 (inclusive) and {@code n} (exclusive).
     */
    static int nextInt(UniformRandomProvider source,
                       int n) {
        // Lemire (2019): Fast Random Integer Generation in an Interval
        // https://arxiv.org/abs/1805.10941
        // The unsigned 32-bit random value is multiplied by n. The upper 32 bits of the
        // 64-bit product are the candidate result; the lower 32 bits (l) decide whether
        // the sample must be rejected to avoid bias.
        long m = (source.nextInt() & 0xffffffffL) * n;
        long l = m & 0xffffffffL;
        if (l < n) {
            // The rejection threshold t = 2^32 % n is always below n, so it only needs
            // to be computed (with a costly modulus) when l < n.
            // 2^32 % n
            final long t = POW_32 % n;
            while (l < t) {
                m = (source.nextInt() & 0xffffffffL) * n;
                l = m & 0xffffffffL;
            }
        }
        return (int) (m >>> 32);
    }

    /**
     * Generates an {@code int} value between the specified {@code origin} (inclusive) and
     * the specified {@code bound} (exclusive).
     *
     * @param source Source of randomness.
     * @param origin Lower bound on the random number to be returned.
     * @param bound Upper bound (exclusive) on the random number to be returned. Must be
     * above {@code origin}.
     * @return a random {@code int} value between {@code origin} (inclusive) and
     * {@code bound} (exclusive).
     */
    static int nextInt(UniformRandomProvider source,
                       int origin, int bound) {
        final int n = bound - origin;
        if (n > 0) {
            return nextInt(source, n) + origin;
        }
        // Range too large to fit in a positive integer (the subtraction overflowed).
        // Use simple rejection.
        int v = source.nextInt();
        while (v < origin || v >= bound) {
            v = source.nextInt();
        }
        return v;
    }

    /**
     * Generates a {@code long} value between 0 (inclusive) and the specified value
     * (exclusive).
     *
     * @param source Source of randomness.
     * @param n Bound on the random number to be returned. Must be strictly positive.
     * @return a random {@code long} value between 0 (inclusive) and {@code n}
     * (exclusive).
     */
    static long nextLong(UniformRandomProvider source,
                         long n) {
        long bits;
        long val;
        do {
            // Non-negative 63-bit random value
            bits = source.nextLong() >>> 1;
            val = bits % n;
            // Reject values in the final incomplete block of n values to avoid modulo
            // bias. Such values are detected by overflow of bits - val + (n - 1).
        } while (bits - val + (n - 1) < 0);
        return val;
    }

    /**
     * Generates a {@code long} value between the specified {@code origin} (inclusive) and
     * the specified {@code bound} (exclusive).
     *
     * @param source Source of randomness.
     * @param origin Lower bound on the random number to be returned.
     * @param bound Upper bound (exclusive) on the random number to be returned. Must be
     * above {@code origin}.
     * @return a random {@code long} value between {@code origin} (inclusive) and
     * {@code bound} (exclusive).
     */
    static long nextLong(UniformRandomProvider source,
                         long origin, long bound) {
        final long n = bound - origin;
        if (n > 0) {
            return nextLong(source, n) + origin;
        }
        // Range too large to fit in a positive long (the subtraction overflowed).
        // Use simple rejection.
        long v = source.nextLong();
        while (v < origin || v >= bound) {
            v = source.nextLong();
        }
        return v;
    }

    /**
     * Generates a {@code float} value between 0 (inclusive) and the specified value
     * (exclusive).
     *
     * @param source Source of randomness.
     * @param bound Bound on the random number to be returned. Must be strictly positive.
     * @return a random {@code float} value between 0 (inclusive) and {@code bound}
     * (exclusive).
     */
    static float nextFloat(UniformRandomProvider source,
                           float bound) {
        float v = source.nextFloat() * bound;
        if (v >= bound) {
            // Correct rounding: the product may round up to the exclusive bound
            v = Math.nextDown(bound);
        }
        return v;
    }

    /**
     * Generates a {@code float} value between the specified {@code origin} (inclusive)
     * and the specified {@code bound} (exclusive).
     *
     * @param source Source of randomness.
     * @param origin Lower bound on the random number to be returned. Must be finite.
     * @param bound Upper bound (exclusive) on the random number to be returned. Must be
     * above {@code origin} and finite.
     * @return a random {@code float} value between {@code origin} (inclusive) and
     * {@code bound} (exclusive).
     */
    static float nextFloat(UniformRandomProvider source,
                           float origin, float bound) {
        float v = source.nextFloat();
        // This expression allows (bound - origin) to be infinite
        // origin + (bound - origin) * v
        // == origin - origin * v + bound * v
        v = (1f - v) * origin + v * bound;
        if (v >= bound) {
            // Correct rounding: the interpolation may round up to the exclusive bound
            v = Math.nextDown(bound);
        }
        return v;
    }

    /**
     * Generates a {@code double} value between 0 (inclusive) and the specified value
     * (exclusive).
     *
     * @param source Source of randomness.
     * @param bound Bound on the random number to be returned. Must be strictly positive.
     * @return a random {@code double} value between 0 (inclusive) and {@code bound}
     * (exclusive).
     */
    static double nextDouble(UniformRandomProvider source,
                             double bound) {
        double v = source.nextDouble() * bound;
        if (v >= bound) {
            // Correct rounding: the product may round up to the exclusive bound
            v = Math.nextDown(bound);
        }
        return v;
    }

    /**
     * Generates a {@code double} value between the specified {@code origin} (inclusive)
     * and the specified {@code bound} (exclusive).
     *
     * @param source Source of randomness.
     * @param origin Lower bound on the random number to be returned. Must be finite.
     * @param bound Upper bound (exclusive) on the random number to be returned. Must be
     * above {@code origin} and finite.
     * @return a random {@code double} value between {@code origin} (inclusive) and
     * {@code bound} (exclusive).
     */
    static double nextDouble(UniformRandomProvider source,
                             double origin, double bound) {
        double v = source.nextDouble();
        // This expression allows (bound - origin) to be infinite
        // origin + (bound - origin) * v
        // == origin - origin * v + bound * v
        // Note: use a double literal; this is double-precision arithmetic.
        v = (1.0 - v) * origin + v * bound;
        if (v >= bound) {
            // Correct rounding: the interpolation may round up to the exclusive bound
            v = Math.nextDown(bound);
        }
        return v;
    }

    // Spliterator support

    /**
     * Base class for spliterators for streams of values. Contains the range current position and
     * end position. Splitting is expected to divide the range in half and create instances
     * that span the two ranges.
     *
     * <p>This class does not itself implement {@link Spliterator}; the public methods below
     * satisfy the interface for the subclasses that do, hence they carry no {@code @Override}.
     */
    private static class ProviderSpliterator {
        /** The current position in the range. */
        protected long position;
        /** The upper limit of the range. */
        protected final long end;

        /**
         * @param start Start position of the stream (inclusive).
         * @param end Upper limit of the stream (exclusive).
         */
        ProviderSpliterator(long start, long end) {
            position = start;
            this.end = end;
        }

        // Methods required by all Spliterators

        // See Spliterator.estimateSize()
        public long estimateSize() {
            return end - position;
        }

        // See Spliterator.characteristics()
        public int characteristics() {
            return Spliterator.SIZED | Spliterator.SUBSIZED | Spliterator.NONNULL | Spliterator.IMMUTABLE;
        }
    }

    /**
     * Spliterator for streams of SplittableUniformRandomProvider.
     */
    static class ProviderSplitsSpliterator extends ProviderSpliterator
            implements Spliterator<SplittableUniformRandomProvider> {
        /** Source of randomness used to initialise the new instances. */
        private final SplittableUniformRandomProvider source;
        /** Generator to split to create new instances. */
        private final SplittableUniformRandomProvider rng;

        /**
         * @param start Start position of the stream (inclusive).
         * @param end Upper limit of the stream (exclusive).
         * @param source Source of randomness used to initialise the new instances.
         * @param rng Generator to split to create new instances.
         */
        ProviderSplitsSpliterator(long start, long end,
                                  SplittableUniformRandomProvider source,
                                  SplittableUniformRandomProvider rng) {
            super(start, end);
            this.source = source;
            this.rng = rng;
        }

        @Override
        public Spliterator<SplittableUniformRandomProvider> trySplit() {
            final long start = position;
            // Unsigned shift gives a correct midpoint even if (start + end) overflows
            final long middle = (start + end) >>> 1;
            if (middle <= start) {
                return null;
            }
            // This instance keeps the upper half; the new instance covers [start, middle)
            // and uses a split of the source of randomness.
            position = middle;
            return new ProviderSplitsSpliterator(start, middle, source.split(), rng);
        }

        @Override
        public boolean tryAdvance(Consumer<? super SplittableUniformRandomProvider> action) {
            Objects.requireNonNull(action, NULL_ACTION);
            final long pos = position;
            if (pos < end) {
                // Advance before exceptions from the action are relayed to the caller
                position = pos + 1;
                action.accept(rng.split(source));
                return true;
            }
            return false;
        }

        @Override
        public void forEachRemaining(Consumer<? super SplittableUniformRandomProvider> action) {
            Objects.requireNonNull(action, NULL_ACTION);
            long pos = position;
            final long last = end;
            if (pos < last) {
                // Ensure forEachRemaining is called only once
                position = last;
                final SplittableUniformRandomProvider s = source;
                final SplittableUniformRandomProvider r = rng;
                do {
                    action.accept(r.split(s));
                } while (++pos < last);
            }
        }
    }

    /**
     * Spliterator for streams of int values that may be recursively split.
     */
    static class ProviderIntsSpliterator extends ProviderSpliterator
            implements Spliterator.OfInt {
        /** Source of randomness. */
        private final SplittableUniformRandomProvider source;
        /** Value generator function. */
        private final ToIntFunction<SplittableUniformRandomProvider> gen;

        /**
         * @param start Start position of the stream (inclusive).
         * @param end Upper limit of the stream (exclusive).
         * @param source Source of randomness.
         * @param gen Value generator function.
         */
        ProviderIntsSpliterator(long start, long end,
                                SplittableUniformRandomProvider source,
                                ToIntFunction<SplittableUniformRandomProvider> gen) {
            super(start, end);
            this.source = source;
            this.gen = gen;
        }

        @Override
        public Spliterator.OfInt trySplit() {
            final long start = position;
            // Unsigned shift gives a correct midpoint even if (start + end) overflows
            final long middle = (start + end) >>> 1;
            if (middle <= start) {
                return null;
            }
            // This instance keeps the upper half; the new instance covers [start, middle)
            // and uses a split of the source of randomness.
            position = middle;
            return new ProviderIntsSpliterator(start, middle, source.split(), gen);
        }

        @Override
        public boolean tryAdvance(IntConsumer action) {
            Objects.requireNonNull(action, NULL_ACTION);
            final long pos = position;
            if (pos < end) {
                // Advance before exceptions from the action are relayed to the caller
                position = pos + 1;
                action.accept(gen.applyAsInt(source));
                return true;
            }
            return false;
        }

        @Override
        public void forEachRemaining(IntConsumer action) {
            Objects.requireNonNull(action, NULL_ACTION);
            long pos = position;
            final long last = end;
            if (pos < last) {
                // Ensure forEachRemaining is called only once
                position = last;
                final SplittableUniformRandomProvider s = source;
                final ToIntFunction<SplittableUniformRandomProvider> g = gen;
                do {
                    action.accept(g.applyAsInt(s));
                } while (++pos < last);
            }
        }
    }

    /**
     * Spliterator for streams of long values that may be recursively split.
     */
    static class ProviderLongsSpliterator extends ProviderSpliterator
            implements Spliterator.OfLong {
        /** Source of randomness. */
        private final SplittableUniformRandomProvider source;
        /** Value generator function. */
        private final ToLongFunction<SplittableUniformRandomProvider> gen;

        /**
         * @param start Start position of the stream (inclusive).
         * @param end Upper limit of the stream (exclusive).
         * @param source Source of randomness.
         * @param gen Value generator function.
         */
        ProviderLongsSpliterator(long start, long end,
                                SplittableUniformRandomProvider source,
                                ToLongFunction<SplittableUniformRandomProvider> gen) {
            super(start, end);
            this.source = source;
            this.gen = gen;
        }

        @Override
        public Spliterator.OfLong trySplit() {
            final long start = position;
            // Unsigned shift gives a correct midpoint even if (start + end) overflows
            final long middle = (start + end) >>> 1;
            if (middle <= start) {
                return null;
            }
            // This instance keeps the upper half; the new instance covers [start, middle)
            // and uses a split of the source of randomness.
            position = middle;
            return new ProviderLongsSpliterator(start, middle, source.split(), gen);
        }

        @Override
        public boolean tryAdvance(LongConsumer action) {
            Objects.requireNonNull(action, NULL_ACTION);
            final long pos = position;
            if (pos < end) {
                // Advance before exceptions from the action are relayed to the caller
                position = pos + 1;
                action.accept(gen.applyAsLong(source));
                return true;
            }
            return false;
        }

        @Override
        public void forEachRemaining(LongConsumer action) {
            Objects.requireNonNull(action, NULL_ACTION);
            long pos = position;
            final long last = end;
            if (pos < last) {
                // Ensure forEachRemaining is called only once
                position = last;
                final SplittableUniformRandomProvider s = source;
                final ToLongFunction<SplittableUniformRandomProvider> g = gen;
                do {
                    action.accept(g.applyAsLong(s));
                } while (++pos < last);
            }
        }
    }

    /**
     * Spliterator for streams of double values that may be recursively split.
     */
    static class ProviderDoublesSpliterator extends ProviderSpliterator
            implements Spliterator.OfDouble {
        /** Source of randomness. */
        private final SplittableUniformRandomProvider source;
        /** Value generator function. */
        private final ToDoubleFunction<SplittableUniformRandomProvider> gen;

        /**
         * @param start Start position of the stream (inclusive).
         * @param end Upper limit of the stream (exclusive).
         * @param source Source of randomness.
         * @param gen Value generator function.
         */
        ProviderDoublesSpliterator(long start, long end,
                                   SplittableUniformRandomProvider source,
                                   ToDoubleFunction<SplittableUniformRandomProvider> gen) {
            super(start, end);
            this.source = source;
            this.gen = gen;
        }

        @Override
        public Spliterator.OfDouble trySplit() {
            final long start = position;
            // Unsigned shift gives a correct midpoint even if (start + end) overflows
            final long middle = (start + end) >>> 1;
            if (middle <= start) {
                return null;
            }
            // This instance keeps the upper half; the new instance covers [start, middle)
            // and uses a split of the source of randomness.
            position = middle;
            return new ProviderDoublesSpliterator(start, middle, source.split(), gen);
        }

        @Override
        public boolean tryAdvance(DoubleConsumer action) {
            Objects.requireNonNull(action, NULL_ACTION);
            final long pos = position;
            if (pos < end) {
                // Advance before exceptions from the action are relayed to the caller
                position = pos + 1;
                action.accept(gen.applyAsDouble(source));
                return true;
            }
            return false;
        }

        @Override
        public void forEachRemaining(DoubleConsumer action) {
            Objects.requireNonNull(action, NULL_ACTION);
            long pos = position;
            final long last = end;
            if (pos < last) {
                // Ensure forEachRemaining is called only once
                position = last;
                final SplittableUniformRandomProvider s = source;
                final ToDoubleFunction<SplittableUniformRandomProvider> g = gen;
                do {
                    action.accept(g.applyAsDouble(s));
                } while (++pos < last);
            }
        }
    }
}