NelderMeadTransform.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math4.legacy.optim.nonlinear.scalar.noderiv;
import java.util.Comparator;
import java.util.function.UnaryOperator;
import java.util.function.DoublePredicate;
import org.apache.commons.math4.legacy.analysis.MultivariateFunction;
import org.apache.commons.math4.legacy.optim.PointValuePair;
/**
 * Factory of simplex transforms that follow the
 * <a href="https://en.wikipedia.org/wiki/Nelder%E2%80%93Mead_method">Nelder-Mead method</a>.
 *
 * <p>Each application of the transform tries, in this order, a reflection,
 * an expansion, a contraction (outside or inside) and finally a shrink of
 * the simplex.</p>
 */
public class NelderMeadTransform
    implements Simplex.TransformFactory {
    /** Default value for {@link #alpha}: {@value}. */
    private static final double DEFAULT_ALPHA = 1;
    /** Default value for {@link #gamma}: {@value}. */
    private static final double DEFAULT_GAMMA = 2;
    /** Default value for {@link #rho}: {@value}. */
    private static final double DEFAULT_RHO = 0.5;
    /** Default value for {@link #sigma}: {@value}. */
    private static final double DEFAULT_SIGMA = 0.5;

    /** Reflection coefficient. */
    private final double alpha;
    /** Expansion coefficient. */
    private final double gamma;
    /** Contraction coefficient. */
    private final double rho;
    /** Shrinkage coefficient. */
    private final double sigma;

    /**
     * Creates a factory with explicit coefficients. The values are stored
     * as given; no validation is performed here.
     *
     * @param alpha Reflection coefficient.
     * @param gamma Expansion coefficient.
     * @param rho Contraction coefficient.
     * @param sigma Shrinkage coefficient.
     */
    public NelderMeadTransform(double alpha,
                               double gamma,
                               double rho,
                               double sigma) {
        this.alpha = alpha;
        this.gamma = gamma;
        this.rho = rho;
        this.sigma = sigma;
    }

    /**
     * Creates a factory that uses the default coefficients
     * ({@link #DEFAULT_ALPHA}, {@link #DEFAULT_GAMMA}, {@link #DEFAULT_RHO}
     * and {@link #DEFAULT_SIGMA}).
     */
    public NelderMeadTransform() {
        this(DEFAULT_ALPHA,
             DEFAULT_GAMMA,
             DEFAULT_RHO,
             DEFAULT_SIGMA);
    }

    /** {@inheritDoc} */
    @Override
    public UnaryOperator<Simplex> create(final MultivariateFunction evaluationFunction,
                                         final Comparator<PointValuePair> comparator,
                                         final DoublePredicate sa) {
        return simplex -> nelderMeadStep(simplex, evaluationFunction, comparator, sa);
    }

    /**
     * Performs one Nelder-Mead update of the given simplex.
     *
     * @param simplex Simplex to transform. The vertices at indices
     * {@code 0}, {@code n - 1} and {@code n} (with {@code n} the dimension)
     * are used as best, second worst and worst point, respectively.
     * @param function Function passed on to every point creation and to
     * the shrink operation.
     * @param order Comparator used to rank candidate points against the
     * simplex vertices.
     * @param acceptance Optional predicate (may be {@code null}). When
     * present, it is tested on the difference between the expanded and the
     * reflected values; a {@code true} result accepts the expanded point
     * even if the comparator did not rank it before the reflected one.
     * @return the simplex with its last vertex replaced by the accepted
     * candidate, or the shrunk simplex if no candidate was accepted.
     */
    private Simplex nelderMeadStep(final Simplex simplex,
                                   final MultivariateFunction function,
                                   final Comparator<PointValuePair> order,
                                   final DoublePredicate acceptance) {
        // A simplex of dimension "dim" holds "dim + 1" vertices.
        final int dim = simplex.getDimension();

        // Vertices that drive the decisions below.
        final PointValuePair first = simplex.get(0);
        final PointValuePair penultimate = simplex.get(dim - 1);
        final PointValuePair last = simplex.get(dim);
        final double[] lastPoint = last.getPoint();

        // Centroid of all vertices except the one at index "dim".
        final double[] center = Simplex.centroid(simplex.asList().subList(0, dim));

        // Reflection.
        final PointValuePair mirrored = Simplex.newPoint(center,
                                                         -alpha,
                                                         lastPoint,
                                                         function);
        final boolean plainReflection =
            order.compare(mirrored, penultimate) < 0 &&
            order.compare(first, mirrored) <= 0;
        if (plainReflection) {
            return simplex.replaceLast(mirrored);
        }

        if (order.compare(mirrored, first) < 0) {
            // Expansion.
            final PointValuePair stretched = Simplex.newPoint(center,
                                                              -gamma,
                                                              lastPoint,
                                                              function);
            // Keep the expanded point when it ranks before the reflected one,
            // or when the optional predicate accepts the value difference.
            final boolean keepStretched =
                order.compare(stretched, mirrored) < 0 ||
                (acceptance != null &&
                 acceptance.test(stretched.getValue() - mirrored.getValue()));
            return simplex.replaceLast(keepStretched ? stretched : mirrored);
        }

        // Contraction: "outside" (towards the reflected point) when the
        // reflected point ranks before the last vertex, "inside" (towards
        // the last vertex) otherwise.
        final boolean outside = order.compare(mirrored, last) < 0;
        final PointValuePair reference = outside ? mirrored : last;
        final double[] target = outside ? mirrored.getPoint() : lastPoint;
        final PointValuePair pulled = Simplex.newPoint(center,
                                                       rho,
                                                       target,
                                                       function);
        if (order.compare(pulled, reference) < 0) {
            // Accept contracted point.
            return simplex.replaceLast(pulled);
        }

        // Shrink.
        return simplex.shrink(sigma, function);
    }

    /** {@inheritDoc} */
    @Override
    public String toString() {
        return new StringBuilder("Nelder-Mead [a=")
            .append(alpha)
            .append(" g=")
            .append(gamma)
            .append(" r=")
            .append(rho)
            .append(" s=")
            .append(sigma)
            .append("]")
            .toString();
    }
}