MapRanking.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math4.neuralnet;

import java.util.List;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;

import org.apache.commons.math4.neuralnet.internal.NeuralNetException;

/**
 * Utility for ranking the units (neurons) of a network.
 *
 * @since 4.0
 */
public class MapRanking {
    /**
     * Neurons to be ranked, in the iteration order of the {@code Iterable}
     * passed to the constructor.
     * The list itself is private, but the {@link Neuron} instances it
     * holds are shared with the caller (they are not copied).
     */
    private final List<Neuron> map = new ArrayList<>();
    /** Distance function for sorting. */
    private final DistanceMeasure distance;

    /**
     * @param neurons Neurons to be ranked.
     * Their references are stored in an internal list, but the
     * {@link Neuron} instances themselves are not copied.
     * The {@link #rank(double[],int) created list of units} will
     * be sorted in increasing order of the {@code distance}.
     * @param distance Distance function.
     */
    public MapRanking(Iterable<Neuron> neurons,
                      DistanceMeasure distance) {
        this.distance = distance;

        for (final Neuron n : neurons) {
            map.add(n); // Reference only: the neuron itself is not copied.
        }
    }

    /**
     * Creates a list of all the neurons, ranked according to how well their
     * features correspond to the given {@code features}.
     *
     * @param features Data.
     * @return the list of neurons sorted in increasing order of distance to
     * the given data (i.e. closest first); an empty list if there are no
     * neurons to rank.
     * @throws IllegalArgumentException if the size of the input is not
     * compatible with the neurons features size.
     */
    public List<Neuron> rank(double[] features) {
        if (map.isEmpty()) {
            // Delegating with "max == 0" would spuriously report a
            // "not strictly positive" error for an empty set of neurons.
            return new ArrayList<>();
        }
        return rank(features, map.size());
    }

    /**
     * Creates a list of the neurons whose features best correspond to the
     * given {@code features}.
     *
     * @param features Data.
     * @param max Maximum size of the returned list.
     * @return the list of (at most {@code max}) neurons sorted in increasing
     * order of distance to the given data (i.e. closest first).
     * Neurons at equal distance keep their iteration order.
     * @throws IllegalArgumentException if the size of the input is not
     * compatible with the neurons features size or {@code max <= 0}.
     */
    public List<Neuron> rank(double[] features,
                             int max) {
        if (max <= 0) {
            throw new NeuralNetException(NeuralNetException.NOT_STRICTLY_POSITIVE, max);
        }
        final int m = Math.min(max, map.size());
        // Best "m" candidates seen so far, kept sorted in increasing order of distance.
        final List<PairNeuronDouble> best = new ArrayList<>(m);

        for (final Neuron n : map) {
            final double d = distance.applyAsDouble(n.getFeatures(), features);
            final PairNeuronDouble p = new PairNeuronDouble(n, d);

            if (best.size() < m) {
                best.add(insertionIndex(best, p), p);
            } else if (PairNeuronDouble.COMPARATOR.compare(p, best.get(m - 1)) < 0) {
                // Strictly closer than the current worst candidate: evict the latter.
                best.remove(m - 1);
                best.add(insertionIndex(best, p), p);
            }
        }

        final List<Neuron> result = new ArrayList<>(m);
        for (final PairNeuronDouble p : best) {
            result.add(p.getNeuron());
        }

        return result;
    }

    /**
     * Finds the position at which {@code p} must be inserted into
     * {@code sorted} so that the list remains sorted and {@code p} is
     * placed after every entry that compares equal to it (so that ties
     * keep the order in which the neurons were visited).
     *
     * @param sorted List sorted according to {@link PairNeuronDouble#COMPARATOR}.
     * @param p Entry to be inserted.
     * @return the index of the first entry that compares strictly greater
     * than {@code p}, or {@code sorted.size()} if there is none.
     */
    private static int insertionIndex(List<PairNeuronDouble> sorted,
                                      PairNeuronDouble p) {
        int lo = 0;
        int hi = sorted.size();
        while (lo < hi) {
            final int mid = (lo + hi) >>> 1;
            if (PairNeuronDouble.COMPARATOR.compare(sorted.get(mid), p) <= 0) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        return lo;
    }

    /**
     * Helper data structure holding a (Neuron, double) pair.
     */
    private static class PairNeuronDouble {
        /** Comparator: orders pairs by increasing value ({@link Double#compare} semantics). */
        static final Comparator<PairNeuronDouble> COMPARATOR
            = new Comparator<PairNeuronDouble>() {
                /** {@inheritDoc} */
                @Override
                public int compare(PairNeuronDouble o1,
                                   PairNeuronDouble o2) {
                    return Double.compare(o1.value, o2.value);
                }
            };
        /** Key. */
        private final Neuron neuron;
        /** Value. */
        private final double value;

        /**
         * @param neuron Neuron.
         * @param value Value.
         */
        PairNeuronDouble(Neuron neuron, double value) {
            this.neuron = neuron;
            this.value = value;
        }

        /** @return the neuron. */
        public Neuron getNeuron() {
            return neuron;
        }
    }
}