BracketFinder.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.statistics.inference;

import java.util.function.DoubleUnaryOperator;

/**
 * Finds an interval that brackets a local minimum of a univariate function.
 * The algorithm is based on a Python implementation (from <em>SciPy</em>,
 * module {@code optimize.py} v0.5).
 *
 * <p>This class has been extracted from {@code o.a.c.math4.optim.univariate}
 * and changed as follows:
 * <ul>
 * <li>support for bracketing a maximum has been removed;
 * <li>the bracket can be constrained to bounds;
 * <li>the sign of the denominator is corrected when its magnitude is small;
 * <li>the search returns true/false if there is a minimum strictly inside the bounds.
 * </ul>
 *
 * <p>The result of the last search is stored in the instance and is read back
 * through the accessor methods.
 *
 * @since 1.1
 */
class BracketFinder {
    /** Smallest magnitude allowed for the denominator (avoids division by zero). */
    private static final double EPS_MIN = 1e-21;
    /** Golden section. */
    private static final double GOLD = 1.6180339887498948482;
    /** Factor limiting how far a single step may expand the interval. */
    private final double growLimit;
    /** Maximum number of function evaluations allowed in a search. */
    private final int maxEvaluations;
    /** Count of function evaluations used by the last search. */
    private int evaluations;
    /** Lower end of the bracket. */
    private double lo;
    /** Upper end of the bracket. */
    private double hi;
    /** Point inside the bracket. */
    private double mid;
    /** Function value at {@link #lo}. */
    private double fLo;
    /** Function value at {@link #hi}. */
    private double fHi;
    /** Function value at {@link #mid}. */
    private double fMid;

    /**
     * Create a finder using the default values {@code 100, 100000} (see the
     * {@link #BracketFinder(double,int) other constructor}).
     */
    BracketFinder() {
        this(100, 100000);
    }

    /**
     * Create a bracketing interval finder.
     *
     * @param growLimit Expanding factor.
     * @param maxEvaluations Maximum number of evaluations allowed for finding
     * a bracketing interval.
     * @throws IllegalArgumentException if the {@code growLimit} or {@code maxEvaluations}
     * are not strictly positive.
     */
    BracketFinder(double growLimit, int maxEvaluations) {
        Arguments.checkStrictlyPositive(growLimit);
        Arguments.checkStrictlyPositive(maxEvaluations);
        this.growLimit = growLimit;
        this.maxEvaluations = maxEvaluations;
    }

    /**
     * Walk downhill from the two initial points until three points are obtained that
     * bracket a local minimum of the function. The initial points are not required to
     * bracket a minimum. An exception is raised if a minimum cannot be found within the
     * configured number of function evaluations.
     *
     * <p>The bracket is limited to the provided bounds if they create a positive interval
     * {@code min < max}; otherwise the bounds are ignored. It is possible that the middle
     * of the bracket is at the bounds as the final bracket is
     * {@code f(mid) <= min(f(lo), f(hi))} and {@code lo <= mid <= hi}.
     *
     * <p>No exception is raised if the initial points are not within the bounds; the points
     * are updated to be within the bounds.
     *
     * <p>No exception is raised if the initial points are equal; the bracket will be returned
     * as a single point {@code lo == mid == hi}.
     *
     * @param func Function whose optimum should be bracketed.
     * @param a Initial point.
     * @param b Initial point.
     * @param min Minimum bound of the bracket (inclusive).
     * @param max Maximum bound of the bracket (inclusive).
     * @return true if the mid-point is strictly within the final bracket {@code [lo, hi]};
     * false if there is no local minimum.
     * @throws IllegalStateException if the maximum number of evaluations is exceeded.
     */
    boolean search(DoubleUnaryOperator func,
                   double a, double b,
                   double min, double max) {
        evaluations = 0;

        // Maps any candidate x into the allowed range
        final DoubleUnaryOperator clip = limiter(min, max);

        // The search tracks three points in order along the direction of travel:
        // tail -> centre -> head
        double tail = clip.applyAsDouble(a);
        double centre = clip.applyAsDouble(b);
        double fTail = value(func, tail);
        double fCentre = value(func, centre);
        if (fCentre > fTail) {
            // Exchange the points so that the direction tail -> centre is downhill
            final double x = tail;
            final double f = fTail;
            tail = centre;
            fTail = fCentre;
            centre = x;
            fCentre = f;
        }

        double head = stride(clip, tail, centre);
        double fHead = value(func, head);

        // Note: When a [min, max] interval is provided and there is no minimum then this
        // loop will terminate when centre == head and both are at the min/max bound.
        while (fHead < fCentre) {
            final double dTail = centre - tail;
            final double dHead = centre - head;
            final double p = dTail * (fCentre - fHead);
            final double q = dHead * (fCentre - fTail);

            // Denominator of the interpolation step: keep the sign of the difference
            // but do not let the magnitude drop below a small value
            final double diff = q - p;
            final double magnitude = Math.max(Math.abs(diff), EPS_MIN);
            final double denom = 2 * Math.copySign(magnitude, diff);

            double trial = clip.applyAsDouble(centre - (dHead * q - dTail * p) / denom);
            // Farthest point that a single step is allowed to reach
            final double farthest = clip.applyAsDouble(centre + growLimit * (head - centre));

            double fTrial;
            if ((trial - head) * (centre - trial) > 0) {
                // trial is strictly between centre and head
                fTrial = value(func, trial);
                if (fTrial < fHead) {
                    // minimum in [centre, head]
                    tail = centre;
                    fTail = fCentre;
                    centre = trial;
                    fCentre = fTrial;
                    break;
                }
                if (fTrial > fCentre) {
                    // minimum in [tail, trial]
                    head = trial;
                    fHead = fTrial;
                    break;
                }
                // continue downhill with a default step
                trial = stride(clip, centre, head);
                fTrial = value(func, trial);
            } else if ((trial - farthest) * (head - trial) > 0) {
                // trial is strictly between head and the step limit
                fTrial = value(func, trial);
                if (fTrial < fHead) {
                    // continue downhill: advance the points before taking a default step
                    centre = head;
                    fCentre = fHead;
                    head = trial;
                    fHead = fTrial;
                    trial = stride(clip, centre, head);
                    fTrial = value(func, trial);
                }
            } else if ((trial - farthest) * (farthest - head) >= 0) {
                // trial is at or beyond the step limit: use the limit
                trial = farthest;
                fTrial = value(func, trial);
            } else {
                // possibly trial == head; reject it and take a default step
                trial = stride(clip, centre, head);
                fTrial = value(func, trial);
            }

            // Advance all three points by one position
            tail = centre;
            fTail = fCentre;
            centre = head;
            fCentre = fHead;
            head = trial;
            fHead = fTrial;
        }

        mid = centre;
        fMid = fCentre;

        // Store the bracket ordered as: lo <= mid <= hi
        final boolean reversed = head < tail;
        lo = reversed ? head : tail;
        fLo = reversed ? fHead : fTail;
        hi = reversed ? tail : head;
        fHi = reversed ? fTail : fHead;

        return lo < mid && mid < hi;
    }

    /**
     * Create the function used to restrict points to the bounds. If the bounds do not
     * satisfy {@code min < max} the returned function leaves the point unchanged.
     *
     * @param min Minimum bound (inclusive).
     * @param max Maximum bound (inclusive).
     * @return the limiting function
     */
    private static DoubleUnaryOperator limiter(double min, double max) {
        if (min < max) {
            // Limit: min <= x <= max
            return x -> x > min ? (x < max ? x : max) : min;
        }
        return DoubleUnaryOperator.identity();
    }

    /**
     * Take the default step: move past {@code to} by the golden section multiple of the
     * distance travelled from {@code from} to {@code to}, then restrict the result
     * using the limiter.
     *
     * @param limiter Function restricting the range of the point.
     * @param from Previous point.
     * @param to Current point.
     * @return the next point
     */
    private static double stride(DoubleUnaryOperator limiter, double from, double to) {
        return limiter.applyAsDouble(to + GOLD * (to - from));
    }

    /**
     * @return the number of function evaluations.
     */
    int getEvaluations() {
        return evaluations;
    }

    /**
     * @return the lower bound of the bracket.
     * @see #getFLo()
     */
    double getLo() {
        return lo;
    }

    /**
     * Get the function value at {@link #getLo()}.
     * @return function value at {@link #getLo()}
     */
    double getFLo() {
        return fLo;
    }

    /**
     * @return the higher bound of the bracket.
     * @see #getFHi()
     */
    double getHi() {
        return hi;
    }

    /**
     * Get the function value at {@link #getHi()}.
     * @return function value at {@link #getHi()}
     */
    double getFHi() {
        return fHi;
    }

    /**
     * @return a point in the middle of the bracket.
     * @see #getFMid()
     */
    double getMid() {
        return mid;
    }

    /**
     * Get the function value at {@link #getMid()}.
     * @return function value at {@link #getMid()}
     */
    double getFMid() {
        return fMid;
    }

    /**
     * Evaluate the function, counting the evaluation against the configured maximum.
     *
     * @param func Function.
     * @param x Point.
     * @return the value
     * @throws IllegalStateException if the maximal number of evaluations is exceeded.
     */
    private double value(DoubleUnaryOperator func, double x) {
        if (evaluations < maxEvaluations) {
            // Count the evaluation before calling the function
            evaluations++;
            return func.applyAsDouble(x);
        }
        throw new IllegalStateException("Too many evaluations: " + evaluations);
    }
}