// KohonenUpdateAction.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math4.neuralnet.sofm;
import java.util.Collection;
import java.util.HashSet;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.DoubleUnaryOperator;
import org.apache.commons.math4.neuralnet.DistanceMeasure;
import org.apache.commons.math4.neuralnet.MapRanking;
import org.apache.commons.math4.neuralnet.Network;
import org.apache.commons.math4.neuralnet.Neuron;
import org.apache.commons.math4.neuralnet.UpdateAction;
/**
* Update formula for <a href="http://en.wikipedia.org/wiki/Kohonen">
* Kohonen's Self-Organizing Map</a>.
* <br>
* The {@link #update(Network,double[]) update} method modifies the
* features {@code w} of the "winning" neuron and its neighbours
* according to the following rule:
* <code>
* w<sub>new</sub> = w<sub>old</sub> + α e<sup>(-d<sup>2</sup> / (2 σ<sup>2</sup>))</sup> * (sample - w<sub>old</sub>)
* </code>
* where
* <ul>
* <li>α is the current <em>learning rate</em>, </li>
* <li>σ is the current <em>neighbourhood size</em>, and</li>
* <li>{@code d} is the number of links to traverse in order to reach
* the neuron from the winning neuron.</li>
* </ul>
* <br>
* This class is thread-safe as long as the arguments passed to the
* {@link #KohonenUpdateAction(DistanceMeasure,LearningFactorFunction,
* NeighbourhoodSizeFunction) constructor} are instances of thread-safe
* classes.
* <br>
* Each call to the {@link #update(Network,double[]) update} method
* will increment the internal counter used to compute the current
* values for
* <ul>
* <li>the <em>learning rate</em>, and</li>
* <li>the <em>neighbourhood size</em>.</li>
* </ul>
* Consequently, the function instances that compute those values (passed
* to the constructor of this class) must take into account whether this
* class's instance will be shared by multiple threads, as this will impact
* the training process.
*
* @since 3.3
*/
public class KohonenUpdateAction implements UpdateAction {
    /** Distance function. */
    private final DistanceMeasure distance;
    /** Learning factor update function. */
    private final LearningFactorFunction learningFactor;
    /** Neighbourhood size update function. */
    private final NeighbourhoodSizeFunction neighbourhoodSize;
    /** Number of calls to {@link #update(Network,double[])}. */
    private final AtomicLong numberOfCalls = new AtomicLong(0);

    /**
     * Creates an update action.
     * The arguments are stored as given; they are first used when
     * {@link #update(Network,double[]) update} is called.
     *
     * @param distance Distance function (used to rank the neurons
     * against a training sample).
     * @param learningFactor Learning factor update function.
     * @param neighbourhoodSize Neighbourhood size update function.
     */
    public KohonenUpdateAction(DistanceMeasure distance,
                               LearningFactorFunction learningFactor,
                               NeighbourhoodSizeFunction neighbourhoodSize) {
        this.distance = distance;
        this.learningFactor = learningFactor;
        this.neighbourhoodSize = neighbourhoodSize;
    }

    /**
     * {@inheritDoc}
     * <p>
     * The winning neuron is updated with the current learning rate.
     * If the current neighbourhood size is strictly positive, the
     * neurons located at {@code 1, 2, ..., size} links from the winner
     * are then updated, each at most once, with a learning rate that
     * decreases with the number of links according to a Gaussian
     * whose standard deviation is the neighbourhood size.
     * </p>
     */
    @Override
    public void update(Network net,
                       double[] features) {
        // Zero-based index of this call: the counter value before the increment.
        final long numCalls = numberOfCalls.getAndIncrement();
        final double currentLearning = learningFactor.value(numCalls);
        final Neuron best = findAndUpdateBestNeuron(net,
                                                    features,
                                                    currentLearning);

        final int currentNeighbourhood = neighbourhoodSize.value(numCalls);
        if (currentNeighbourhood <= 0) {
            // No neighbourhood: only the winning neuron is modified.
            // Returning here also avoids building a Gaussian with a
            // zero standard deviation.
            return;
        }

        // The farther away the neighbour is from the winning neuron, the
        // smaller the learning rate will become.
        final DoubleUnaryOperator neighbourhoodDecay
            = new Gaussian(currentLearning, currentNeighbourhood);

        // Initial set of neurons only contains the winning neuron.
        Collection<Neuron> neighbours = new HashSet<>();
        neighbours.add(best);
        // Winning neuron must be excluded from the neighbours.
        final Collection<Neuron> exclude = new HashSet<>();
        exclude.add(best);

        for (int radius = 1; radius <= currentNeighbourhood; radius++) {
            // Retrieve immediate neighbours of the current set of neurons.
            neighbours = net.getNeighbours(neighbours, exclude);

            // The decay only depends on the radius: compute it once per ring.
            final double ringLearning = neighbourhoodDecay.applyAsDouble(radius);

            // Update all the neighbours.
            for (final Neuron n : neighbours) {
                updateNeighbouringNeuron(n, features, ringLearning);
            }

            // Add the neighbours to the exclude list so that they will
            // not be updated more than once per training step.
            exclude.addAll(neighbours);
        }
    }

    /**
     * Retrieves the number of calls to the {@link #update(Network,double[]) update}
     * method.
     *
     * @return the current number of calls.
     */
    public long getNumberOfCalls() {
        return numberOfCalls.get();
    }

    /**
     * Tries to update a neuron.
     * The new features are computed from the features read at the
     * start of the attempt, and are installed through
     * {@code compareAndSetFeatures} with that snapshot as the expected
     * value.
     *
     * @param n Neuron to be updated.
     * @param features Training data.
     * @param learningRate Learning factor.
     * @return {@code true} if the update succeeded, {@code false} if a
     * concurrent update has been detected.
     */
    private boolean attemptNeuronUpdate(Neuron n,
                                        double[] features,
                                        double learningRate) {
        final double[] expect = n.getFeatures();
        final double[] update = computeFeatures(expect,
                                                features,
                                                learningRate);

        return n.compareAndSetFeatures(expect, update);
    }

    /**
     * Atomically updates the given neuron.
     * The update is attempted repeatedly until it succeeds.
     *
     * @param n Neuron to be updated.
     * @param features Training data.
     * @param learningRate Learning factor.
     */
    private void updateNeighbouringNeuron(Neuron n,
                                          double[] features,
                                          double learningRate) {
        while (!attemptNeuronUpdate(n, features, learningRate)) {
            // Another update interfered: retry with the fresh features.
        }
    }

    /**
     * Searches for the neuron whose features are closest to the given
     * sample, and atomically updates its features.
     *
     * @param net Network.
     * @param features Sample data.
     * @param learningRate Current learning factor.
     * @return the winning neuron.
     */
    private Neuron findAndUpdateBestNeuron(Network net,
                                           double[] features,
                                           double learningRate) {
        final MapRanking rank = new MapRanking(net, distance);

        while (true) {
            final Neuron best = rank.rank(features, 1).get(0);

            if (attemptNeuronUpdate(best, features, learningRate)) {
                return best;
            }

            // If another thread modified the state of the winning neuron,
            // it may not be the best match anymore for the given training
            // sample: Hence, the winner search is performed again.
        }
    }

    /**
     * Computes the new value of the features set.
     * Each component is moved from its current value towards the
     * corresponding sample value by the fraction {@code learningRate}:
     * {@code r[i] = current[i] + learningRate * (sample[i] - current[i])}.
     * Neither input array is modified.
     *
     * @param current Current values of the features.
     * @param sample Training data. Must contain at least as many
     * elements as {@code current}.
     * @param learningRate Learning factor.
     * @return the new values for the features (a newly allocated array
     * of the same length as {@code current}).
     */
    private static double[] computeFeatures(double[] current,
                                            double[] sample,
                                            double learningRate) {
        final int len = current.length;
        final double[] r = new double[len];
        for (int i = 0; i < len; i++) {
            final double c = current[i];
            final double s = sample[i];
            r[i] = c + learningRate * (s - c);
        }

        return r;
    }

    /**
     * Gaussian function with zero mean:
     * {@code x -> norm * exp(-x^2 / (2 * sigma^2))}.
     */
    private static class Gaussian implements DoubleUnaryOperator {
        /** Inverse of twice the square of the standard deviation. */
        private final double i2s2;
        /** Normalization factor. */
        private final double norm;

        /**
         * @param norm Normalization factor (value of the function at 0).
         * @param sigma Standard deviation. A value of zero makes the
         * exponent's scale factor infinite, so callers are expected to
         * pass a non-zero value.
         */
        Gaussian(double norm,
                 double sigma) {
            this.norm = norm;
            i2s2 = 1d / (2 * sigma * sigma);
        }

        @Override
        public double applyAsDouble(double x) {
            return norm * Math.exp(-x * x * i2s2);
        }
    }
}