001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.commons.math4.neuralnet; 019 020import java.util.NoSuchElementException; 021import java.util.List; 022import java.util.ArrayList; 023import java.util.Set; 024import java.util.HashSet; 025import java.util.Collection; 026import java.util.Iterator; 027import java.util.Collections; 028import java.util.Map; 029import java.util.concurrent.ConcurrentHashMap; 030import java.util.concurrent.atomic.AtomicLong; 031import java.util.stream.Collectors; 032 033import org.apache.commons.math4.neuralnet.internal.NeuralNetException; 034 035/** 036 * Neural network, composed of {@link Neuron} instances and the links 037 * between them. 038 * 039 * Although updating a neuron's state is thread-safe, modifying the 040 * network's topology (adding or removing links) is not. 041 * 042 * @since 3.3 043 */ 044public class Network 045 implements Iterable<Neuron> { 046 /** Neurons. */ 047 private final ConcurrentHashMap<Long, Neuron> neuronMap 048 = new ConcurrentHashMap<>(); 049 /** Next available neuron identifier. */ 050 private final AtomicLong nextId; 051 /** Neuron's features set size. */ 052 private final int featureSize; 053 /** Links. */ 054 private final ConcurrentHashMap<Long, Set<Long>> linkMap 055 = new ConcurrentHashMap<>(); 056 057 /** 058 * @param firstId Identifier of the first neuron that will be added 059 * to this network. 060 * @param featureSize Size of the neuron's features. 061 */ 062 public Network(long firstId, 063 int featureSize) { 064 this.nextId = new AtomicLong(firstId); 065 this.featureSize = featureSize; 066 } 067 068 /** 069 * Builds a network from a list of neurons and their neighbours. 070 * 071 * @param featureSize Number of features. 072 * @param idList List of neuron identifiers. 073 * @param featureList List of neuron features. 074 * @param neighbourIdList Links associated to each of the neurons in 075 * {@code idList}. 076 * @throws IllegalArgumentException if an inconsistency is detected. 077 * @return a new instance. 078 */ 079 public static Network from(int featureSize, 080 long[] idList, 081 double[][] featureList, 082 long[][] neighbourIdList) { 083 final int numNeurons = idList.length; 084 if (idList.length != featureList.length) { 085 throw new NeuralNetException(NeuralNetException.SIZE_MISMATCH, 086 idList.length, featureList.length); 087 } 088 if (idList.length != neighbourIdList.length) { 089 throw new NeuralNetException(NeuralNetException.SIZE_MISMATCH, 090 idList.length, neighbourIdList.length); 091 } 092 093 final Network net = new Network(Long.MIN_VALUE, featureSize); 094 095 for (int i = 0; i < numNeurons; i++) { 096 final long id = idList[i]; 097 net.createNeuron(id, featureList[i]); 098 } 099 100 for (int i = 0; i < numNeurons; i++) { 101 final Neuron a = net.getNeuron(idList[i]); 102 for (final long id : neighbourIdList[i]) { 103 final Neuron b = net.neuronMap.get(id); 104 if (b == null) { 105 throw new NeuralNetException(NeuralNetException.ID_NOT_FOUND, id); 106 } 107 net.addLink(a, b); 108 } 109 } 110 111 return net; 112 } 113 114 /** 115 * Performs a deep copy of this instance. 116 * Upon return, the copied and original instances will be independent: 117 * Updating one will not affect the other. 118 * 119 * @return a new instance with the same state as this instance. 120 * @since 3.6 121 */ 122 public synchronized Network copy() { 123 final Network copy = new Network(nextId.get(), 124 featureSize); 125 126 127 for (final Map.Entry<Long, Neuron> e : neuronMap.entrySet()) { 128 copy.neuronMap.put(e.getKey(), e.getValue().copy()); 129 } 130 131 for (final Map.Entry<Long, Set<Long>> e : linkMap.entrySet()) { 132 copy.linkMap.put(e.getKey(), new HashSet<>(e.getValue())); 133 } 134 135 return copy; 136 } 137 138 /** 139 * {@inheritDoc} 140 */ 141 @Override 142 public Iterator<Neuron> iterator() { 143 return neuronMap.values().iterator(); 144 } 145 146 /** 147 * @return a shallow copy of the network's neurons. 148 */ 149 public Collection<Neuron> getNeurons() { 150 return Collections.unmodifiableCollection(neuronMap.values()); 151 } 152 153 /** 154 * Creates a neuron and assigns it a unique identifier. 155 * 156 * @param features Initial values for the neuron's features. 157 * @return the neuron's identifier. 158 * @throws IllegalArgumentException if the length of {@code features} 159 * is different from the expected size (as set by the 160 * {@link #Network(long,int) constructor}). 161 */ 162 public long createNeuron(double[] features) { 163 return createNeuron(createNextId(), features); 164 } 165 166 /** 167 * @param id Identifier. 168 * @param features Features. 169 * @return {@¢ode id}. 170 * @throws IllegalArgumentException if the identifier is already used 171 * by a neuron that belongs to this network or the features size does 172 * not match the expected value. 173 */ 174 private long createNeuron(long id, 175 double[] features) { 176 if (neuronMap.get(id) != null) { 177 throw new NeuralNetException(NeuralNetException.ID_IN_USE, id); 178 } 179 180 if (features.length != featureSize) { 181 throw new NeuralNetException(NeuralNetException.SIZE_MISMATCH, 182 features.length, featureSize); 183 } 184 185 neuronMap.put(id, new Neuron(id, features.clone())); 186 linkMap.put(id, new HashSet<>()); 187 188 if (id > nextId.get()) { 189 nextId.set(id); 190 } 191 192 return id; 193 } 194 195 /** 196 * Deletes a neuron. 197 * Links from all neighbours to the removed neuron will also be 198 * {@link #deleteLink(Neuron,Neuron) deleted}. 199 * 200 * @param neuron Neuron to be removed from this network. 201 * @throws NoSuchElementException if {@code n} does not belong to 202 * this network. 203 */ 204 public void deleteNeuron(Neuron neuron) { 205 // Delete links to from neighbours. 206 getNeighbours(neuron).forEach(neighbour -> deleteLink(neighbour, neuron)); 207 208 // Remove neuron. 209 neuronMap.remove(neuron.getIdentifier()); 210 } 211 212 /** 213 * Gets the size of the neurons' features set. 214 * 215 * @return the size of the features set. 216 */ 217 public int getFeaturesSize() { 218 return featureSize; 219 } 220 221 /** 222 * Adds a link from neuron {@code a} to neuron {@code b}. 223 * Note: the link is not bi-directional; if a bi-directional link is 224 * required, an additional call must be made with {@code a} and 225 * {@code b} exchanged in the argument list. 226 * 227 * @param a Neuron. 228 * @param b Neuron. 229 * @throws NoSuchElementException if the neurons do not exist in the 230 * network. 231 */ 232 public void addLink(Neuron a, 233 Neuron b) { 234 // Check that the neurons belong to this network. 235 final long aId = a.getIdentifier(); 236 if (a != getNeuron(aId)) { 237 throw new NoSuchElementException(Long.toString(aId)); 238 } 239 final long bId = b.getIdentifier(); 240 if (b != getNeuron(bId)) { 241 throw new NoSuchElementException(Long.toString(bId)); 242 } 243 244 // Add link from "a" to "b". 245 addLinkToLinkSet(linkMap.get(aId), bId); 246 } 247 248 /** 249 * Adds a link to neuron {@code id} in given {@code linkSet}. 250 * Note: no check verifies that the identifier indeed belongs 251 * to this network. 252 * 253 * @param linkSet Neuron identifier. 254 * @param id Neuron identifier. 255 */ 256 private void addLinkToLinkSet(Set<Long> linkSet, 257 long id) { 258 linkSet.add(id); 259 } 260 261 /** 262 * Deletes the link between neurons {@code a} and {@code b}. 263 * 264 * @param a Neuron. 265 * @param b Neuron. 266 * @throws NoSuchElementException if the neurons do not exist in the 267 * network. 268 */ 269 public void deleteLink(Neuron a, 270 Neuron b) { 271 // Check that the neurons belong to this network. 272 final long aId = a.getIdentifier(); 273 if (a != getNeuron(aId)) { 274 throw new NoSuchElementException(Long.toString(aId)); 275 } 276 final long bId = b.getIdentifier(); 277 if (b != getNeuron(bId)) { 278 throw new NoSuchElementException(Long.toString(bId)); 279 } 280 281 // Delete link from "a" to "b". 282 deleteLinkFromLinkSet(linkMap.get(aId), bId); 283 } 284 285 /** 286 * Deletes a link to neuron {@code id} in given {@code linkSet}. 287 * Note: no check verifies that the identifier indeed belongs 288 * to this network. 289 * 290 * @param linkSet Neuron identifier. 291 * @param id Neuron identifier. 292 */ 293 private void deleteLinkFromLinkSet(Set<Long> linkSet, 294 long id) { 295 linkSet.remove(id); 296 } 297 298 /** 299 * Retrieves the neuron with the given (unique) {@code id}. 300 * 301 * @param id Identifier. 302 * @return the neuron associated with the given {@code id}. 303 * @throws NoSuchElementException if the neuron does not exist in the 304 * network. 305 */ 306 public Neuron getNeuron(long id) { 307 final Neuron n = neuronMap.get(id); 308 if (n == null) { 309 throw new NoSuchElementException(Long.toString(id)); 310 } 311 return n; 312 } 313 314 /** 315 * Retrieves the neurons in the neighbourhood of any neuron in the 316 * {@code neurons} list. 317 * @param neurons Neurons for which to retrieve the neighbours. 318 * @return the list of neighbours. 319 * @see #getNeighbours(Iterable,Iterable) 320 */ 321 public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons) { 322 return getNeighbours(neurons, null); 323 } 324 325 /** 326 * Retrieves the neurons in the neighbourhood of any neuron in the 327 * {@code neurons} list. 328 * The {@code exclude} list allows to retrieve the "concentric" 329 * neighbourhoods by removing the neurons that belong to the inner 330 * "circles". 331 * 332 * @param neurons Neurons for which to retrieve the neighbours. 333 * @param exclude Neurons to exclude from the returned list. 334 * Can be {@code null}. 335 * @return the list of neighbours. 336 */ 337 public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons, 338 Iterable<Neuron> exclude) { 339 final Set<Long> idList = new HashSet<>(); 340 neurons.forEach(n -> idList.addAll(linkMap.get(n.getIdentifier()))); 341 342 if (exclude != null) { 343 exclude.forEach(n -> idList.remove(n.getIdentifier())); 344 } 345 346 return idList.stream().map(this::getNeuron).collect(Collectors.toList()); 347 } 348 349 /** 350 * Retrieves the neighbours of the given neuron. 351 * 352 * @param neuron Neuron for which to retrieve the neighbours. 353 * @return the list of neighbours. 354 * @see #getNeighbours(Neuron,Iterable) 355 */ 356 public Collection<Neuron> getNeighbours(Neuron neuron) { 357 return getNeighbours(neuron, null); 358 } 359 360 /** 361 * Retrieves the neighbours of the given neuron. 362 * 363 * @param neuron Neuron for which to retrieve the neighbours. 364 * @param exclude Neurons to exclude from the returned list. 365 * Can be {@code null}. 366 * @return the list of neighbours. 367 */ 368 public Collection<Neuron> getNeighbours(Neuron neuron, 369 Iterable<Neuron> exclude) { 370 final Set<Long> idList = linkMap.get(neuron.getIdentifier()); 371 if (exclude != null) { 372 for (final Neuron n : exclude) { 373 idList.remove(n.getIdentifier()); 374 } 375 } 376 377 final List<Neuron> neuronList = new ArrayList<>(); 378 for (final Long id : idList) { 379 neuronList.add(getNeuron(id)); 380 } 381 382 return neuronList; 383 } 384 385 /** 386 * Creates a neuron identifier. 387 * 388 * @return a value that will serve as a unique identifier. 389 */ 390 private Long createNextId() { 391 return nextId.getAndIncrement(); 392 } 393}