View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math4.neuralnet;
19  
20  import java.util.List;
21  import java.util.ArrayList;
22  import java.util.Collections;
23  import java.util.Comparator;
24  
25  import org.apache.commons.math4.neuralnet.internal.NeuralNetException;
26  
27  /**
28   * Utility for ranking the units (neurons) of a network.
29   *
30   * @since 4.0
31   */
32  public class MapRanking {
33      /** List corresponding to the map passed to the constructor. */
34      private final List<Neuron> map = new ArrayList<>();
35      /** Distance function for sorting. */
36      private final DistanceMeasure distance;
37  
38      /**
39       * @param neurons List to be ranked.
40       * No defensive copy is performed.
41       * The {@link #rank(double[],int) created list of units} will
42       * be sorted in increasing order of the {@code distance}.
43       * @param distance Distance function.
44       */
45      public MapRanking(Iterable<Neuron> neurons,
46                        DistanceMeasure distance) {
47          this.distance = distance;
48  
49          for (final Neuron n : neurons) {
50              map.add(n); // No defensive copy.
51          }
52      }
53  
54      /**
55       * Creates a list of the neurons whose features best correspond to the
56       * given {@code features}.
57       *
58       * @param features Data.
59       * @return the list of neurons sorted in decreasing order of distance to
60       * the given data.
61       * @throws IllegalArgumentException if the size of the input is not
62       * compatible with the neurons features size.
63       */
64      public List<Neuron> rank(double[] features) {
65          return rank(features, map.size());
66      }
67  
68      /**
69       * Creates a list of the neurons whose features best correspond to the
70       * given {@code features}.
71       *
72       * @param features Data.
73       * @param max Maximum size of the returned list.
74       * @return the list of neurons sorted in decreasing order of distance to
75       * the given data.
76       * @throws IllegalArgumentException if the size of the input is not
77       * compatible with the neurons features size or {@code max <= 0}.
78       */
79      public List<Neuron> rank(double[] features,
80                               int max) {
81          if (max <= 0) {
82              throw new NeuralNetException(NeuralNetException.NOT_STRICTLY_POSITIVE, max);
83          }
84          final int m = max <= map.size() ?
85              max :
86              map.size();
87          final List<PairNeuronDouble> list = new ArrayList<>(m);
88  
89          for (final Neuron n : map) {
90              final double d = distance.applyAsDouble(n.getFeatures(), features);
91              final PairNeuronDouble p = new PairNeuronDouble(n, d);
92  
93              if (list.size() < m) {
94                  list.add(p);
95                  if (list.size() > 1) {
96                      // Sort if there is more than 1 element.
97                      Collections.sort(list, PairNeuronDouble.COMPARATOR);
98                  }
99              } else {
100                 final int last = list.size() - 1;
101                 if (PairNeuronDouble.COMPARATOR.compare(p, list.get(last)) < 0) {
102                     list.set(last, p); // Replace worst entry.
103                     if (last > 0) {
104                         // Sort if there is more than 1 element.
105                         Collections.sort(list, PairNeuronDouble.COMPARATOR);
106                     }
107                 }
108             }
109         }
110 
111         final List<Neuron> result = new ArrayList<>(m);
112         for (final PairNeuronDouble p : list) {
113             result.add(p.getNeuron());
114         }
115 
116         return result;
117     }
118 
119     /**
120      * Helper data structure holding a (Neuron, double) pair.
121      */
122     private static class PairNeuronDouble {
123         /** Comparator. */
124         static final Comparator<PairNeuronDouble> COMPARATOR
125             = new Comparator<PairNeuronDouble>() {
126                 /** {@inheritDoc} */
127                 @Override
128                 public int compare(PairNeuronDouble o1,
129                                    PairNeuronDouble o2) {
130                     return Double.compare(o1.value, o2.value);
131                 }
132             };
133         /** Key. */
134         private final Neuron neuron;
135         /** Value. */
136         private final double value;
137 
138         /**
139          * @param neuron Neuron.
140          * @param value Value.
141          */
142         PairNeuronDouble(Neuron neuron, double value) {
143             this.neuron = neuron;
144             this.value = value;
145         }
146 
147         /** @return the neuron. */
148         public Neuron getNeuron() {
149             return neuron;
150         }
151     }
152 }