1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math4.neuralnet;
19
20 import java.util.NoSuchElementException;
21 import java.util.List;
22 import java.util.ArrayList;
23 import java.util.Set;
24 import java.util.HashSet;
25 import java.util.Collection;
26 import java.util.Iterator;
27 import java.util.Collections;
28 import java.util.Map;
29 import java.util.concurrent.ConcurrentHashMap;
30 import java.util.concurrent.atomic.AtomicLong;
31 import java.util.stream.Collectors;
32
33 import org.apache.commons.math4.neuralnet.internal.NeuralNetException;
34
35
36
37
38
39
40
41
42
43
44 public class Network
45 implements Iterable<Neuron> {
46
47 private final ConcurrentHashMap<Long, Neuron> neuronMap
48 = new ConcurrentHashMap<>();
49
50 private final AtomicLong nextId;
51
52 private final int featureSize;
53
54 private final ConcurrentHashMap<Long, Set<Long>> linkMap
55 = new ConcurrentHashMap<>();
56
57
58
59
60
61
62 public Network(long firstId,
63 int featureSize) {
64 this.nextId = new AtomicLong(firstId);
65 this.featureSize = featureSize;
66 }
67
68
69
70
71
72
73
74
75
76
77
78
79 public static Network from(int featureSize,
80 long[] idList,
81 double[][] featureList,
82 long[][] neighbourIdList) {
83 final int numNeurons = idList.length;
84 if (idList.length != featureList.length) {
85 throw new NeuralNetException(NeuralNetException.SIZE_MISMATCH,
86 idList.length, featureList.length);
87 }
88 if (idList.length != neighbourIdList.length) {
89 throw new NeuralNetException(NeuralNetException.SIZE_MISMATCH,
90 idList.length, neighbourIdList.length);
91 }
92
93 final Network net = new Network(Long.MIN_VALUE, featureSize);
94
95 for (int i = 0; i < numNeurons; i++) {
96 final long id = idList[i];
97 net.createNeuron(id, featureList[i]);
98 }
99
100 for (int i = 0; i < numNeurons; i++) {
101 final Neuron a = net.getNeuron(idList[i]);
102 for (final long id : neighbourIdList[i]) {
103 final Neuron b = net.neuronMap.get(id);
104 if (b == null) {
105 throw new NeuralNetException(NeuralNetException.ID_NOT_FOUND, id);
106 }
107 net.addLink(a, b);
108 }
109 }
110
111 return net;
112 }
113
114
115
116
117
118
119
120
121
122 public synchronized Network copy() {
123 final Network copy = new Network(nextId.get(),
124 featureSize);
125
126
127 for (final Map.Entry<Long, Neuron> e : neuronMap.entrySet()) {
128 copy.neuronMap.put(e.getKey(), e.getValue().copy());
129 }
130
131 for (final Map.Entry<Long, Set<Long>> e : linkMap.entrySet()) {
132 copy.linkMap.put(e.getKey(), new HashSet<>(e.getValue()));
133 }
134
135 return copy;
136 }
137
138
139
140
141 @Override
142 public Iterator<Neuron> iterator() {
143 return neuronMap.values().iterator();
144 }
145
146
147
148
149 public Collection<Neuron> getNeurons() {
150 return Collections.unmodifiableCollection(neuronMap.values());
151 }
152
153
154
155
156
157
158
159
160
161
162 public long createNeuron(double[] features) {
163 return createNeuron(createNextId(), features);
164 }
165
166
167
168
169
170
171
172
173
174 private long createNeuron(long id,
175 double[] features) {
176 if (neuronMap.get(id) != null) {
177 throw new NeuralNetException(NeuralNetException.ID_IN_USE, id);
178 }
179
180 if (features.length != featureSize) {
181 throw new NeuralNetException(NeuralNetException.SIZE_MISMATCH,
182 features.length, featureSize);
183 }
184
185 neuronMap.put(id, new Neuron(id, features.clone()));
186 linkMap.put(id, new HashSet<>());
187
188 if (id > nextId.get()) {
189 nextId.set(id);
190 }
191
192 return id;
193 }
194
195
196
197
198
199
200
201
202
203
204 public void deleteNeuron(Neuron neuron) {
205
206 getNeighbours(neuron).forEach(neighbour -> deleteLink(neighbour, neuron));
207
208
209 neuronMap.remove(neuron.getIdentifier());
210 }
211
212
213
214
215
216
217 public int getFeaturesSize() {
218 return featureSize;
219 }
220
221
222
223
224
225
226
227
228
229
230
231
232 public void addLink(Neuron a,
233 Neuron b) {
234
235 final long aId = a.getIdentifier();
236 if (a != getNeuron(aId)) {
237 throw new NoSuchElementException(Long.toString(aId));
238 }
239 final long bId = b.getIdentifier();
240 if (b != getNeuron(bId)) {
241 throw new NoSuchElementException(Long.toString(bId));
242 }
243
244
245 addLinkToLinkSet(linkMap.get(aId), bId);
246 }
247
248
249
250
251
252
253
254
255
256 private void addLinkToLinkSet(Set<Long> linkSet,
257 long id) {
258 linkSet.add(id);
259 }
260
261
262
263
264
265
266
267
268
269 public void deleteLink(Neuron a,
270 Neuron b) {
271
272 final long aId = a.getIdentifier();
273 if (a != getNeuron(aId)) {
274 throw new NoSuchElementException(Long.toString(aId));
275 }
276 final long bId = b.getIdentifier();
277 if (b != getNeuron(bId)) {
278 throw new NoSuchElementException(Long.toString(bId));
279 }
280
281
282 deleteLinkFromLinkSet(linkMap.get(aId), bId);
283 }
284
285
286
287
288
289
290
291
292
293 private void deleteLinkFromLinkSet(Set<Long> linkSet,
294 long id) {
295 linkSet.remove(id);
296 }
297
298
299
300
301
302
303
304
305
306 public Neuron getNeuron(long id) {
307 final Neuron n = neuronMap.get(id);
308 if (n == null) {
309 throw new NoSuchElementException(Long.toString(id));
310 }
311 return n;
312 }
313
314
315
316
317
318
319
320
321 public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons) {
322 return getNeighbours(neurons, null);
323 }
324
325
326
327
328
329
330
331
332
333
334
335
336
337 public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons,
338 Iterable<Neuron> exclude) {
339 final Set<Long> idList = new HashSet<>();
340 neurons.forEach(n -> idList.addAll(linkMap.get(n.getIdentifier())));
341
342 if (exclude != null) {
343 exclude.forEach(n -> idList.remove(n.getIdentifier()));
344 }
345
346 return idList.stream().map(this::getNeuron).collect(Collectors.toList());
347 }
348
349
350
351
352
353
354
355
356 public Collection<Neuron> getNeighbours(Neuron neuron) {
357 return getNeighbours(neuron, null);
358 }
359
360
361
362
363
364
365
366
367
368 public Collection<Neuron> getNeighbours(Neuron neuron,
369 Iterable<Neuron> exclude) {
370 final Set<Long> idList = linkMap.get(neuron.getIdentifier());
371 if (exclude != null) {
372 for (final Neuron n : exclude) {
373 idList.remove(n.getIdentifier());
374 }
375 }
376
377 final List<Neuron> neuronList = new ArrayList<>();
378 for (final Long id : idList) {
379 neuronList.add(getNeuron(id));
380 }
381
382 return neuronList;
383 }
384
385
386
387
388
389
390 private Long createNextId() {
391 return nextId.getAndIncrement();
392 }
393 }