// SimpleCurveFitter.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math4.legacy.fitting;

import java.util.Collections;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.ArrayList;

import org.apache.commons.math4.legacy.exception.ZeroException;
import org.apache.commons.math4.legacy.exception.OutOfRangeException;
import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction;
import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresBuilder;
import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem;
import org.apache.commons.math4.legacy.linear.DiagonalMatrix;

/**
 * Fits points to a user-defined {@link ParametricUnivariateFunction function}.
 *
 * <p>Instances are obtained from the {@code create} factory methods and
 * customized with the {@code withXxx} methods, each of which returns a new
 * instance. The start point of the optimization is either given explicitly
 * or computed from the observations by a {@link ParameterGuesser}; an
 * explicit start point takes precedence over the guesser.</p>
 *
 * @since 3.4
 */
public class SimpleCurveFitter extends AbstractCurveFitter {
    /** Function to fit. */
    private final ParametricUnivariateFunction function;
    /** Initial guess for the parameters ({@code null} if a guesser must be used). */
    private final double[] initialGuess;
    /** Parameter guesser (only used when there is no initial guess). */
    private final ParameterGuesser guesser;
    /** Maximum number of iterations of the optimization algorithm. */
    private final int maxIter;

    /**
     * Constructor used by the factory methods.
     * The arguments are stored as passed (no defensive copy is made here).
     *
     * @param function Function to fit.
     * @param initialGuess Initial guess. May be {@code null} only when a
     * {@code guesser} is provided. When not {@code null}, it is used as the
     * start point and its length must be consistent with the number of
     * parameters of the {@code function} to fit.
     * @param guesser Method for providing an initial guess. It is only used
     * if {@code initialGuess} is {@code null}.
     * @param maxIter Maximum number of iterations of the optimization algorithm.
     */
    protected SimpleCurveFitter(ParametricUnivariateFunction function,
                                double[] initialGuess,
                                ParameterGuesser guesser,
                                int maxIter) {
        this.function = function;
        this.initialGuess = initialGuess;
        this.guesser = guesser;
        this.maxIter = maxIter;
    }

    /**
     * Creates a curve fitter.
     * The maximum number of iterations of the optimization algorithm is set
     * to {@link Integer#MAX_VALUE}.
     *
     * @param f Function to fit.
     * @param start Initial guess for the parameters.  Cannot be {@code null}.
     * Its length must be consistent with the number of parameters of the
     * function to fit. The array is copied.
     * @return a curve fitter.
     * @throws NullPointerException if {@code start} is {@code null}.
     *
     * @see #withStartPoint(double[])
     * @see #withMaxIterations(int)
     */
    public static SimpleCurveFitter create(ParametricUnivariateFunction f,
                                           double[] start) {
        // Defensive copy (as in "withStartPoint"): later changes to the
        // caller's array must not alter the configuration of the fitter.
        return new SimpleCurveFitter(f, start.clone(), null, Integer.MAX_VALUE);
    }

    /**
     * Creates a curve fitter.
     * The maximum number of iterations of the optimization algorithm is set
     * to {@link Integer#MAX_VALUE}.
     *
     * @param f Function to fit.
     * @param guesser Method for providing an initial guess. If it is
     * {@code null}, fitting will fail unless a start point is configured
     * with {@link #withStartPoint(double[])}.
     * @return a curve fitter.
     *
     * @see #withStartPoint(double[])
     * @see #withMaxIterations(int)
     */
    public static SimpleCurveFitter create(ParametricUnivariateFunction f,
                                           ParameterGuesser guesser) {
        return new SimpleCurveFitter(f, null, guesser, Integer.MAX_VALUE);
    }

    /**
     * Configure the start point (initial guess).
     * The returned instance uses a copy of the given array and no longer
     * refers to any previously configured guesser.
     *
     * @param newStart new start point (initial guess). Cannot be {@code null}.
     * @return a new instance.
     * @throws NullPointerException if {@code newStart} is {@code null}.
     */
    public SimpleCurveFitter withStartPoint(double[] newStart) {
        return new SimpleCurveFitter(function,
                                     newStart.clone(),
                                     null,
                                     maxIter);
    }

    /**
     * Configure the maximum number of iterations.
     * The function, start point and guesser of this instance are retained.
     *
     * @param newMaxIter maximum number of iterations
     * @return a new instance.
     */
    public SimpleCurveFitter withMaxIterations(int newMaxIter) {
        return new SimpleCurveFitter(function,
                                     initialGuess,
                                     guesser,
                                     newMaxIter);
    }

    /** {@inheritDoc} */
    @Override
    protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
        // Prepare least-squares problem: one target value and one weight
        // per observation, in iteration order.
        final int len = observations.size();
        final double[] target  = new double[len];
        final double[] weights = new double[len];

        int count = 0;
        for (WeightedObservedPoint obs : observations) {
            target[count]  = obs.getY();
            weights[count] = obs.getWeight();
            ++count;
        }

        final AbstractCurveFitter.TheoreticalValuesFunction model
            = new AbstractCurveFitter.TheoreticalValuesFunction(function,
                                                                observations);

        // An explicit start point takes precedence over the guesser.
        final double[] startPoint;
        if (initialGuess != null) {
            startPoint = initialGuess;
        } else {
            if (guesser == null) {
                // Same exception type as the implicit failure it replaces,
                // but with a message that explains the misconfiguration.
                throw new NullPointerException("Neither an initial guess nor a parameter guesser has been provided");
            }
            // Compute estimation.
            startPoint = guesser.guess(observations);
        }

        // Create an optimizer for fitting the curve to the observed points.
        return new LeastSquaresBuilder().
                maxEvaluations(Integer.MAX_VALUE).
                maxIterations(maxIter).
                start(startPoint).
                target(target).
                weight(new DiagonalMatrix(weights)).
                model(model.getModelFunction(), model.getModelFunctionJacobian()).
                build();
    }

    /**
     * Guesses the parameters.
     * Subclasses implement {@link #guess(Collection)}; the protected methods
     * are utilities (sorting, maximum search, linear interpolation) that can
     * be used by the implementations.
     */
    public abstract static class ParameterGuesser {
        /**
         * Comparator: {@code null} points come first, other points are ordered
         * by increasing X, then by increasing Y, then by increasing weight.
         */
        private static final Comparator<WeightedObservedPoint> CMP = new Comparator<WeightedObservedPoint>() {
                /** {@inheritDoc} */
                @Override
                public int compare(WeightedObservedPoint p1,
                                   WeightedObservedPoint p2) {
                    if (p1 == null && p2 == null) {
                        return 0;
                    }
                    if (p1 == null) {
                        return -1;
                    }
                    if (p2 == null) {
                        return 1;
                    }
                    int comp = Double.compare(p1.getX(), p2.getX());
                    if (comp != 0) {
                        return comp;
                    }
                    comp = Double.compare(p1.getY(), p2.getY());
                    if (comp != 0) {
                        return comp;
                    }
                    return Double.compare(p1.getWeight(), p2.getWeight());
                }
            };

        /**
         * Computes an estimation of the parameters.
         *
         * @param obs Observations.
         * @return the guessed parameters.
         */
        public abstract double[] guess(Collection<WeightedObservedPoint> obs);

        /**
         * Sort the observations.
         * The input collection is not modified: sorting is performed on a copy.
         *
         * @param unsorted Input observations.
         * @return a new list that contains the input observations, sorted by
         * increasing X (ties are resolved using Y, then the weight).
         */
        protected List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) {
            final List<WeightedObservedPoint> observations = new ArrayList<>(unsorted);
            Collections.sort(observations, CMP);
            return observations;
        }

        /**
         * Finds index of point in specified points with the largest Y.
         * If several points share the largest Y, the index of the first one
         * is returned.
         *
         * @param points Points to search. The array should not be empty:
         * for an empty array, the returned value (0) is not a valid index.
         * @return the index in specified points array.
         */
        protected int findMaxY(WeightedObservedPoint[] points) {
            int maxYIdx = 0;
            for (int i = 1; i < points.length; i++) {
                // Strict comparison: the first maximum is retained.
                if (points[i].getY() > points[maxYIdx].getY()) {
                    maxYIdx = i;
                }
            }
            return maxYIdx;
        }

        /**
         * Interpolates using the specified points to determine X at the
         * specified Y.
         * Pairs of points {@code (i, i + idxStep)} are examined, starting at
         * {@code startIdx}; the first pair whose Y values enclose {@code y}
         * is used for a linear interpolation.
         *
         * @param points Points to use for interpolation.
         * @param startIdx Index within points from which to start the search for
         * interpolation bounds points.
         * @param idxStep Index step for searching interpolation bounds points.
         * @param y Y value for which X should be determined.
         * @return the value of X for the specified Y.
         * @throws ZeroException if {@code idxStep} is 0.
         * @throws OutOfRangeException if specified {@code y} is not within the
         * range of the specified {@code points}.
         */
        protected double interpolateXAtY(WeightedObservedPoint[] points,
                                         int startIdx,
                                         int idxStep,
                                         double y) {
            if (idxStep == 0) {
                throw new ZeroException();
            }
            final WeightedObservedPoint[] twoPoints
                = getInterpolationPointsForY(points, startIdx, idxStep, y);
            final WeightedObservedPoint p1 = twoPoints[0];
            final WeightedObservedPoint p2 = twoPoints[1];
            // Exact matches are returned as is; this also avoids a division
            // by zero when both points have the same Y.
            if (p1.getY() == y) {
                return p1.getX();
            }
            if (p2.getY() == y) {
                return p2.getX();
            }
            return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
                                (p2.getY() - p1.getY()));
        }

        /**
         * Gets the two bounding interpolation points from the specified points
         * suitable for determining X at the specified Y.
         * The returned points are in the order of increasing index in
         * {@code points}, whatever the sign of {@code idxStep}.
         *
         * @param points Points to use for interpolation.
         * @param startIdx Index within points from which to start search for
         * interpolation bounds points.
         * @param idxStep Index step for search for interpolation bounds points.
         * @param y Y value for which X should be determined.
         * @return the array containing two points suitable for determining X at
         * the specified Y.
         * @throws ZeroException if {@code idxStep} is 0.
         * @throws OutOfRangeException if specified {@code y} is not within the
         * range of the specified {@code points}.
         */
        private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
                                                                   int startIdx,
                                                                   int idxStep,
                                                                   double y) {
            if (idxStep == 0) {
                // A zero step would make the loop below endless.
                throw new ZeroException();
            }
            // Continue as long as the second point of the pair ("i + idxStep")
            // is still inside the array in the direction of the search.
            for (int i = startIdx;
                 idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
                 i += idxStep) {
                final WeightedObservedPoint p1 = points[i];
                final WeightedObservedPoint p2 = points[i + idxStep];
                if (isBetween(y, p1.getY(), p2.getY())) {
                    if (idxStep < 0) {
                        return new WeightedObservedPoint[] { p2, p1 };
                    } else {
                        return new WeightedObservedPoint[] { p1, p2 };
                    }
                }
            }

            // Boundaries are replaced by dummy values because the raised
            // exception is caught and the message never displayed.
            // TODO: Exceptions should not be used for flow control.
            throw new OutOfRangeException(y,
                                          Double.NEGATIVE_INFINITY,
                                          Double.POSITIVE_INFINITY);
        }

        /**
         * Determines whether a value is between two other values.
         * The boundaries can be given in any order.
         *
         * @param value Value to test whether it is between {@code boundary1}
         * and {@code boundary2}.
         * @param boundary1 One end of the range.
         * @param boundary2 Other end of the range.
         * @return {@code true} if {@code value} is between {@code boundary1} and
         * {@code boundary2} (inclusive), {@code false} otherwise.
         */
        private boolean isBetween(double value,
                                  double boundary1,
                                  double boundary2) {
            return (value >= boundary1 && value <= boundary2) ||
                (value >= boundary2 && value <= boundary1);
        }
    }
}