/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math4.neuralnet.sofm;

import java.util.Collection;
import java.util.HashSet;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.DoubleUnaryOperator;

import org.apache.commons.math4.neuralnet.DistanceMeasure;
import org.apache.commons.math4.neuralnet.MapRanking;
import org.apache.commons.math4.neuralnet.Network;
import org.apache.commons.math4.neuralnet.Neuron;
import org.apache.commons.math4.neuralnet.UpdateAction;

/**
 * Update formula for <a href="http://en.wikipedia.org/wiki/Kohonen">
 * Kohonen's Self-Organizing Map</a>.
 * <br>
 * The {@link #update(Network,double[]) update} method modifies the
 * features {@code w} of the "winning" neuron and its neighbours
 * according to the following rule:
 * <code>
 *  w<sub>new</sub> = w<sub>old</sub> + α e<sup>-d² / (2 σ²)</sup> * (sample - w<sub>old</sub>)
 * </code>
 * where
 * <ul>
 *  <li>α is the current <em>learning rate</em>,</li>
 *  <li>σ is the current <em>neighbourhood size</em>, and</li>
 *  <li>{@code d} is the number of links to traverse in order to reach
 *   the neuron from the winning neuron.</li>
 * </ul>
 * <br>
 * This class is thread-safe as long as the arguments passed to the
 * {@link #KohonenUpdateAction(DistanceMeasure,LearningFactorFunction,
 * NeighbourhoodSizeFunction) constructor} are instances of thread-safe
 * classes.
 * <br>
 * Each call to the {@link #update(Network,double[]) update} method
 * will increment the internal counter used to compute the current
 * values for
 * <ul>
 *  <li>the <em>learning rate</em>, and</li>
 *  <li>the <em>neighbourhood size</em>.</li>
 * </ul>
 * Consequently, the function instances that compute those values (passed
 * to the constructor of this class) must take into account whether an
 * instance of this class will be shared by multiple threads, as this
 * will impact the training process.
 *
 * @since 3.3
 */
public class KohonenUpdateAction implements UpdateAction {
    /** Distance function. */
    private final DistanceMeasure distance;
    /** Learning factor update function. */
    private final LearningFactorFunction learningFactor;
    /** Neighbourhood size update function. */
    private final NeighbourhoodSizeFunction neighbourhoodSize;
    /** Number of calls to {@link #update(Network,double[])}. */
    private final AtomicLong numberOfCalls = new AtomicLong(0);

    /**
     * Creates a new instance.
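     * <br>
     * A minimal usage sketch, assuming the {@code EuclideanDistance}
     * implementation and the {@code exponentialDecay} factory methods
     * provided elsewhere in this library ({@code network} and
     * {@code sample} stand for an existing {@link Network} and a
     * training vector):
     * <pre>{@code
     * UpdateAction action =
     *     new KohonenUpdateAction(new EuclideanDistance(),
     *                             LearningFactorFunctionFactory.exponentialDecay(0.9, 0.05, 1000),
     *                             NeighbourhoodSizeFunctionFactory.exponentialDecay(5, 1, 1000));
     *
     * // One training step with the given sample.
     * action.update(network, sample);
     * }</pre>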
     *
     * @param distance Distance function.
     * @param learningFactor Learning factor update function.
     * @param neighbourhoodSize Neighbourhood size update function.
     */
    public KohonenUpdateAction(DistanceMeasure distance,
                               LearningFactorFunction learningFactor,
                               NeighbourhoodSizeFunction neighbourhoodSize) {
        this.distance = distance;
        this.learningFactor = learningFactor;
        this.neighbourhoodSize = neighbourhoodSize;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public void update(Network net,
                       double[] features) {
        // Counter value before this call, used to look up the
        // time-dependent learning rate and neighbourhood size.
        final long numCalls = numberOfCalls.incrementAndGet() - 1;
        final double currentLearning = learningFactor.value(numCalls);
        final Neuron best = findAndUpdateBestNeuron(net,
                                                    features,
                                                    currentLearning);

        final int currentNeighbourhood = neighbourhoodSize.value(numCalls);
        // The farther away the neighbour is from the winning neuron, the
        // smaller the learning rate will become.
        final Gaussian neighbourhoodDecay
            = new Gaussian(currentLearning,
                           currentNeighbourhood);

        if (currentNeighbourhood > 0) {
            // Initial set of neurons only contains the winning neuron.
            Collection<Neuron> neighbours = new HashSet<>();
            neighbours.add(best);
            // Winning neuron must be excluded from the neighbours.
            final HashSet<Neuron> exclude = new HashSet<>();
            exclude.add(best);

            int radius = 1;
            do {
                // Retrieve immediate neighbours of the current set of neurons.
                neighbours = net.getNeighbours(neighbours, exclude);

                // Update all the neighbours.
                for (final Neuron n : neighbours) {
                    updateNeighbouringNeuron(n, features, neighbourhoodDecay.applyAsDouble(radius));
                }

                // Add the neighbours to the exclude list so that they will
                // not be updated more than once per training step.
                exclude.addAll(neighbours);
                ++radius;
            } while (radius <= currentNeighbourhood);
        }
    }

    /**
     * Retrieves the number of calls to the {@link #update(Network,double[]) update}
     * method.
     *
     * @return the current number of calls.
     */
    public long getNumberOfCalls() {
        return numberOfCalls.get();
    }

    /**
     * Tries to update a neuron.
     *
     * @param n Neuron to be updated.
     * @param features Training data.
     * @param learningRate Learning factor.
     * @return {@code true} if the update succeeded, {@code false} if a
     * concurrent update has been detected.
     */
    private boolean attemptNeuronUpdate(Neuron n,
                                        double[] features,
                                        double learningRate) {
        final double[] expect = n.getFeatures();
        final double[] update = computeFeatures(expect,
                                                features,
                                                learningRate);

        return n.compareAndSetFeatures(expect, update);
    }

    /**
     * Atomically updates the given neuron.
     *
     * @param n Neuron to be updated.
     * @param features Training data.
     * @param learningRate Learning factor.
     */
    private void updateNeighbouringNeuron(Neuron n,
                                          double[] features,
                                          double learningRate) {
        // Retry until the compare-and-set succeeds: a failure means that
        // another thread concurrently modified the neuron.
        while (!attemptNeuronUpdate(n, features, learningRate)) {
            // Try again.
        }
    }

    /**
     * Searches for the neuron whose features are closest to the given
     * sample, and atomically updates its features.
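     * <br>
     * If the compare-and-set update of the winning neuron fails because
     * of a concurrent modification, that neuron may no longer be the
     * best match for the sample; the ranking is then recomputed and the
     * update retried until it succeeds.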
     *
     * @param net Network.
     * @param features Sample data.
     * @param learningRate Current learning factor.
     * @return the winning neuron.
     */
    private Neuron findAndUpdateBestNeuron(Network net,
                                           double[] features,
                                           double learningRate) {
        final MapRanking rank = new MapRanking(net, distance);

        while (true) {
            final Neuron best = rank.rank(features, 1).get(0);

            if (attemptNeuronUpdate(best, features, learningRate)) {
                return best;
            }

            // If another thread modified the state of the winning neuron,
            // it may not be the best match anymore for the given training
            // sample; hence the winner search is performed again.
        }
    }

    /**
     * Computes the new value of the features set.
     *
     * @param current Current values of the features.
     * @param sample Training data.
     * @param learningRate Learning factor.
     * @return the new values for the features.
     */
    private double[] computeFeatures(double[] current,
                                     double[] sample,
                                     double learningRate) {
        final int len = current.length;
        final double[] r = new double[len];
        for (int i = 0; i < len; i++) {
            final double c = current[i];
            final double s = sample[i];
            // Move the feature towards the sample by a fraction given
            // by the learning rate.
            r[i] = c + learningRate * (s - c);
        }
        return r;
    }

    /**
     * Gaussian function with zero mean.
     * <br>
     * With {@code norm} set to the learning rate α and {@code sigma} set
     * to the neighbourhood size σ, {@code applyAsDouble(d)} returns
     * α e<sup>-d² / (2 σ²)</sup>, the decayed learning rate applied to a
     * neighbour at distance {@code d} from the winning neuron: for
     * example, with α = 0.5 and σ = 2, a neighbour one link away is
     * updated with rate 0.5 e<sup>-1/8</sup> ≈ 0.441.
     */
    private static class Gaussian implements DoubleUnaryOperator {
        /** Inverse of twice the square of the standard deviation. */
        private final double i2s2;
        /** Normalization factor. */
        private final double norm;

        /**
         * @param norm Normalization factor.
         * @param sigma Standard deviation.
         */
        Gaussian(double norm,
                 double sigma) {
            this.norm = norm;
            i2s2 = 1d / (2 * sigma * sigma);
        }

        /** {@inheritDoc} */
        @Override
        public double applyAsDouble(double x) {
            return norm * Math.exp(-x * x * i2s2);
        }
    }
}