001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.statistics.ranking;
018
019import java.util.Arrays;
020import java.util.Objects;
021import java.util.SplittableRandom;
022import java.util.function.DoubleUnaryOperator;
023import java.util.function.IntUnaryOperator;
024
025/**
026 * Ranking based on the natural ordering on floating-point values.
027 *
028 * <p>{@link Double#NaN NaNs} are treated according to the configured
029 * {@link NaNStrategy} and ties are handled using the selected
030 * {@link TiesStrategy}. Configuration settings are supplied in optional
031 * constructor arguments. Defaults are {@link NaNStrategy#FAILED} and
032 * {@link TiesStrategy#AVERAGE}, respectively.
033 *
034 * <p>When using {@link TiesStrategy#RANDOM}, a generator of random values in {@code [0, x)}
035 * can be supplied as a {@link IntUnaryOperator} argument; otherwise a default is created
036 * on-demand. The source of randomness can be supplied using a method reference.
037 * The following example creates a ranking with NaN values with the highest
038 * ranking and ties resolved randomly:
039 *
040 * <pre>
041 * NaturalRanking ranking = new NaturalRanking(NaNStrategy.MAXIMAL,
042 *                                             new SplittableRandom()::nextInt);
043 * </pre>
044 *
045 * <p>Note: Using {@link TiesStrategy#RANDOM} is not thread-safe due to the mutable
046 * generator of randomness. Instances not using random resolution of ties are
047 * thread-safe.
048 *
049 * <p>Examples:
050 *
051 * <table border="">
052 * <caption>Examples</caption>
053 * <tr><th colspan="3">
054 * Input data: [20, 17, 30, 42.3, 17, 50, Double.NaN, Double.NEGATIVE_INFINITY, 17]
055 * </th></tr>
 * <tr><th>NaNStrategy</th><th>TiesStrategy</th>
 * <th>{@code rank(data)}</th></tr>
058 * <tr>
059 * <td>MAXIMAL</td>
060 * <td>default (ties averaged)</td>
061 * <td>[5, 3, 6, 7, 3, 8, 9, 1, 3]</td></tr>
062 * <tr>
063 * <td>MAXIMAL</td>
064 * <td>MINIMUM</td>
065 * <td>[5, 2, 6, 7, 2, 8, 9, 1, 2]</td></tr>
066 * <tr>
067 * <td>MINIMAL</td>
 * <td>default (ties averaged)</td>
069 * <td>[6, 4, 7, 8, 4, 9, 1.5, 1.5, 4]</td></tr>
070 * <tr>
071 * <td>REMOVED</td>
072 * <td>SEQUENTIAL</td>
073 * <td>[5, 2, 6, 7, 3, 8, 1, 4]</td></tr>
074 * <tr>
075 * <td>MINIMAL</td>
076 * <td>MAXIMUM</td>
077 * <td>[6, 5, 7, 8, 5, 9, 2, 2, 5]</td></tr>
078 * <tr>
079 * <td>MINIMAL</td>
080 * <td>MAXIMUM</td>
081 * <td>[6, 5, 7, 8, 5, 9, 2, 2, 5]</td></tr>
082 * </table>
083 *
084 * @since 1.1
085 */
086public class NaturalRanking implements RankingAlgorithm {
087    /** Message for a null user-supplied {@link NaNStrategy}. */
088    private static final String NULL_NAN_STRATEGY = "nanStrategy";
089    /** Message for a null user-supplied {@link TiesStrategy}. */
090    private static final String NULL_TIES_STRATEGY = "tiesStrategy";
091    /** Message for a null user-supplied source of randomness. */
092    private static final String NULL_RANDOM_SOURCE = "randomIntFunction";
093    /** Default NaN strategy. */
094    private static final NaNStrategy DEFAULT_NAN_STRATEGY = NaNStrategy.FAILED;
095    /** Default ties strategy. */
096    private static final TiesStrategy DEFAULT_TIES_STRATEGY = TiesStrategy.AVERAGE;
097    /** Map values to positive infinity. */
098    private static final DoubleUnaryOperator ACTION_POS_INF = x -> Double.POSITIVE_INFINITY;
099    /** Map values to negative infinity. */
100    private static final DoubleUnaryOperator ACTION_NEG_INF = x -> Double.NEGATIVE_INFINITY;
101    /** Raise an exception for values. */
102    private static final DoubleUnaryOperator ACTION_ERROR = operand -> {
103        throw new IllegalArgumentException("Invalid data: " + operand);
104    };
105
106    /** NaN strategy. */
107    private final NaNStrategy nanStrategy;
108    /** Ties strategy. */
109    private final TiesStrategy tiesStrategy;
110    /** Source of randomness when ties strategy is RANDOM.
111     * Function maps positive x to {@code [0, x)}.
112     * Can be null to default to a JDK implementation. */
113    private IntUnaryOperator randomIntFunction;
114
115    /**
116     * Creates an instance with {@link NaNStrategy#FAILED} and
117     * {@link TiesStrategy#AVERAGE}.
118     */
119    public NaturalRanking() {
120        this(DEFAULT_NAN_STRATEGY, DEFAULT_TIES_STRATEGY, null);
121    }
122
123    /**
124     * Creates an instance with {@link NaNStrategy#FAILED} and the
125     * specified @{@code tiesStrategy}.
126     *
127     * <p>If the ties strategy is {@link TiesStrategy#RANDOM RANDOM} a default
128     * source of randomness is used to resolve ties.
129     *
130     * @param tiesStrategy TiesStrategy to use.
131     * @throws NullPointerException if the strategy is {@code null}
132     */
133    public NaturalRanking(TiesStrategy tiesStrategy) {
134        this(DEFAULT_NAN_STRATEGY,
135            Objects.requireNonNull(tiesStrategy, NULL_TIES_STRATEGY), null);
136    }
137
138    /**
139     * Creates an instance with the specified @{@code nanStrategy} and
140     * {@link TiesStrategy#AVERAGE}.
141     *
142     * @param nanStrategy NaNStrategy to use.
143     * @throws NullPointerException if the strategy is {@code null}
144     */
145    public NaturalRanking(NaNStrategy nanStrategy) {
146        this(Objects.requireNonNull(nanStrategy, NULL_NAN_STRATEGY),
147            DEFAULT_TIES_STRATEGY, null);
148    }
149
150    /**
151     * Creates an instance with the specified @{@code nanStrategy} and the
152     * specified @{@code tiesStrategy}.
153     *
154     * <p>If the ties strategy is {@link TiesStrategy#RANDOM RANDOM} a default
155     * source of randomness is used to resolve ties.
156     *
157     * @param nanStrategy NaNStrategy to use.
158     * @param tiesStrategy TiesStrategy to use.
159     * @throws NullPointerException if any strategy is {@code null}
160     */
161    public NaturalRanking(NaNStrategy nanStrategy,
162                          TiesStrategy tiesStrategy) {
163        this(Objects.requireNonNull(nanStrategy, NULL_NAN_STRATEGY),
164            Objects.requireNonNull(tiesStrategy, NULL_TIES_STRATEGY), null);
165    }
166
167    /**
168     * Creates an instance with {@link NaNStrategy#FAILED},
169     * {@link TiesStrategy#RANDOM} and the given the source of random index data.
170     *
171     * @param randomIntFunction Source of random index data.
172     * Function maps positive {@code x} randomly to {@code [0, x)}
173     * @throws NullPointerException if the source of randomness is {@code null}
174     */
175    public NaturalRanking(IntUnaryOperator randomIntFunction) {
176        this(DEFAULT_NAN_STRATEGY, TiesStrategy.RANDOM,
177            Objects.requireNonNull(randomIntFunction, NULL_RANDOM_SOURCE));
178    }
179
180    /**
181     * Creates an instance with the specified @{@code nanStrategy},
182     * {@link TiesStrategy#RANDOM} and the given the source of random index data.
183     *
184     * @param nanStrategy NaNStrategy to use.
185     * @param randomIntFunction Source of random index data.
186     * Function maps positive {@code x} randomly to {@code [0, x)}
187     * @throws NullPointerException if the strategy or source of randomness are {@code null}
188     */
189    public NaturalRanking(NaNStrategy nanStrategy,
190                          IntUnaryOperator randomIntFunction) {
191        this(Objects.requireNonNull(nanStrategy, NULL_NAN_STRATEGY), TiesStrategy.RANDOM,
192            Objects.requireNonNull(randomIntFunction, NULL_RANDOM_SOURCE));
193    }
194
195    /**
196     * @param nanStrategy NaNStrategy to use.
197     * @param tiesStrategy TiesStrategy to use.
198     * @param randomIntFunction Source of random index data.
199     */
200    private NaturalRanking(NaNStrategy nanStrategy,
201                           TiesStrategy tiesStrategy,
202                           IntUnaryOperator randomIntFunction) {
203        // User-supplied arguments are checked for non-null in the respective constructor
204        this.nanStrategy = nanStrategy;
205        this.tiesStrategy = tiesStrategy;
206        this.randomIntFunction = randomIntFunction;
207    }
208
209    /**
210     * Return the {@link NaNStrategy}.
211     *
212     * @return the strategy for handling NaN
213     */
214    public NaNStrategy getNanStrategy() {
215        return nanStrategy;
216    }
217
218    /**
219     * Return the {@link TiesStrategy}.
220     *
221     * @return the strategy for handling ties
222     */
223    public TiesStrategy getTiesStrategy() {
224        return tiesStrategy;
225    }
226
227    /**
228     * Rank {@code data} using the natural ordering on floating-point values, with
229     * NaN values handled according to {@code nanStrategy} and ties resolved using
230     * {@code tiesStrategy}.
231     *
232     * @throws IllegalArgumentException if the selected {@link NaNStrategy} is
233     * {@code FAILED} and a {@link Double#NaN} is encountered in the input data.
234     */
235    @Override
236    public double[] apply(double[] data) {
237        // Convert data for sorting.
238        // NaNs are counted for the FIXED strategy.
239        final int[] nanCount = {0};
240        final DataPosition[] ranks = createRankData(data, nanCount);
241
242        // Sorting will move NaNs to the end and we do not have to resolve ties in them.
243        final int nonNanSize = ranks.length - nanCount[0];
244
245        // Edge case for empty data
246        if (nonNanSize == 0) {
247            // Either NaN are left in-place or removed
248            return nanStrategy == NaNStrategy.FIXED ? data : new double[0];
249        }
250
251        Arrays.sort(ranks);
252
253        // Walk the sorted array, filling output array using sorted positions,
254        // resolving ties as we go.
255        int pos = 1;
256        final double[] out = new double[ranks.length];
257
258        DataPosition current = ranks[0];
259        out[current.getPosition()] = pos;
260
261        // Store all previous elements of a tie.
262        // Note this lags behind the length of the tie sequence by 1.
263        // In the event there are no ties this is not used.
264        final IntList tiesTrace = new IntList(ranks.length);
265
266        for (int i = 1; i < nonNanSize; i++) {
267            final DataPosition previous = current;
268            current = ranks[i];
269            if (current.compareTo(previous) > 0) {
270                // Check for a previous tie sequence
271                if (tiesTrace.size() != 0) {
272                    resolveTie(out, tiesTrace, previous.getPosition());
273                }
274                pos = i + 1;
275            } else {
276                // Tie sequence. Add the matching previous element.
277                tiesTrace.add(previous.getPosition());
278            }
279            out[current.getPosition()] = pos;
280        }
281        // Handle tie sequence at end
282        if (tiesTrace.size() != 0) {
283            resolveTie(out, tiesTrace, current.getPosition());
284        }
285        // For the FIXED strategy consume the remaining NaN elements
286        if (nanStrategy == NaNStrategy.FIXED) {
287            for (int i = nonNanSize; i < ranks.length; i++) {
288                out[ranks[i].getPosition()] = Double.NaN;
289            }
290        }
291        return out;
292    }
293
294    /**
295     * Creates the rank data. If using {@link NaNStrategy#REMOVED} then NaNs are
296     * filtered. Otherwise NaNs may be mapped to an infinite value, counted to allow
297     * subsequent processing, or cause an exception to be thrown.
298     *
299     * @param data Source data.
300     * @param nanCount Output counter for NaN values.
301     * @return the rank data
302     * @throws IllegalArgumentException if the data contains NaN values when using
303     * {@link NaNStrategy#FAILED}.
304     */
305    private DataPosition[] createRankData(double[] data, final int[] nanCount) {
306        return nanStrategy == NaNStrategy.REMOVED ?
307                createNonNaNRankData(data) :
308                createMappedRankData(data, createNaNAction(nanCount));
309    }
310
311    /**
312     * Creates the NaN action.
313     *
314     * @param nanCount Output counter for NaN values.
315     * @return the operator applied to NaN values
316     */
317    private DoubleUnaryOperator createNaNAction(int[] nanCount) {
318        switch (nanStrategy) {
319        case MAXIMAL: // Replace NaNs with +INFs
320            return ACTION_POS_INF;
321        case MINIMAL: // Replace NaNs with -INFs
322            return ACTION_NEG_INF;
323        case REMOVED: // NaNs are removed
324        case FIXED:   // NaNs are unchanged
325            // Count the NaNs in the data that must be handled
326            return x -> {
327                nanCount[0]++;
328                return x;
329            };
330        case FAILED:
331            return ACTION_ERROR;
332        default:
333            // this should not happen unless NaNStrategy enum is changed
334            throw new IllegalStateException();
335        }
336    }
337
338    /**
339     * Creates the rank data with NaNs removed.
340     *
341     * @param data Source data.
342     * @return the rank data
343     */
344    private static DataPosition[] createNonNaNRankData(double[] data) {
345        final DataPosition[] ranks = new DataPosition[data.length];
346        int size = 0;
347        for (final double v : data) {
348            if (!Double.isNaN(v)) {
349                ranks[size] = new DataPosition(v, size);
350                size++;
351            }
352        }
353        return size == data.length ? ranks : Arrays.copyOf(ranks, size);
354    }
355
356    /**
357     * Creates the rank data.
358     *
359     * @param data Source data.
360     * @param nanAction Mapping operator applied to NaN values.
361     * @return the rank data
362     */
363    private static DataPosition[] createMappedRankData(double[] data, DoubleUnaryOperator nanAction) {
364        final DataPosition[] ranks = new DataPosition[data.length];
365        for (int i = 0; i < data.length; i++) {
366            double v = data[i];
367            if (Double.isNaN(v)) {
368                v = nanAction.applyAsDouble(v);
369            }
370            ranks[i] = new DataPosition(v, i);
371        }
372        return ranks;
373    }
374
375    /**
376     * Resolve a sequence of ties, using the configured {@link TiesStrategy}. The
377     * input {@code ranks} array is expected to take the same value for all indices
378     * in {@code tiesTrace}. The common value is recoded according to the
379     * tiesStrategy. For example, if ranks = [5,8,2,6,2,7,1,2], tiesTrace = [2,4,7]
380     * and tiesStrategy is MINIMUM, ranks will be unchanged. The same array and
381     * trace with tiesStrategy AVERAGE will come out [5,8,3,6,3,7,1,3].
382     *
383     * <p>Note: For convenience the final index of the trace is passed as an argument;
384     * it is assumed the list is already non-empty. At the end of the method the
385     * list of indices is cleared.
386     *
387     * @param ranks Array of ranks.
388     * @param tiesTrace List of indices where {@code ranks} is constant, that is,
389     * for any i and j in {@code tiesTrace}: {@code ranks[i] == ranks[j]}.
390     * @param finalIndex The final index to add to the sequence of ties.
391     */
392    private void resolveTie(double[] ranks, IntList tiesTrace, int finalIndex) {
393        tiesTrace.add(finalIndex);
394
395        // Constant value of ranks over tiesTrace.
396        // Note: c is a rank counter starting from 1 so limited to an int.
397        final double c = ranks[tiesTrace.get(0)];
398
399        // length of sequence of tied ranks
400        final int length = tiesTrace.size();
401
402        switch (tiesStrategy) {
403        case  AVERAGE:   // Replace ranks with average: (lower + upper) / 2
404            fill(ranks, tiesTrace, (2 * c + length - 1) * 0.5);
405            break;
406        case MAXIMUM:    // Replace ranks with maximum values
407            fill(ranks, tiesTrace, c + length - 1);
408            break;
409        case MINIMUM:    // Replace ties with minimum
410            // Note that the tie sequence already has all values set to c so
411            // no requirement to fill again.
412            break;
413        case SEQUENTIAL: // Fill sequentially from c to c + length - 1
414        case RANDOM:     // Fill with randomized sequential values in [c, c + length - 1]
415            // This cast is safe as c is a counter.
416            int r = (int) c;
417            if (tiesStrategy == TiesStrategy.RANDOM) {
418                tiesTrace.shuffle(getRandomIntFunction());
419            }
420            final int size = tiesTrace.size();
421            for (int i = 0; i < size; i++) {
422                ranks[tiesTrace.get(i)] = r++;
423            }
424            break;
425        default: // this should not happen unless TiesStrategy enum is changed
426            throw new IllegalStateException();
427        }
428
429        tiesTrace.clear();
430    }
431
432    /**
433     * Sets {@code data[i] = value} for each i in {@code tiesTrace}.
434     *
435     * @param data Array to modify.
436     * @param tiesTrace List of index values to set.
437     * @param value Value to set.
438     */
439    private static void fill(double[] data, IntList tiesTrace, double value) {
440        final int size = tiesTrace.size();
441        for (int i = 0; i < size; i++) {
442            data[tiesTrace.get(i)] = value;
443        }
444    }
445
446    /**
447     * Gets the function to map positive {@code x} randomly to {@code [0, x)}.
448     * Defaults to a system provided generator if the constructor source of randomness is null.
449     *
450     * @return the RNG
451     */
452    private IntUnaryOperator getRandomIntFunction() {
453        IntUnaryOperator r = randomIntFunction;
454        if (r == null) {
455            // Default to a SplittableRandom
456            randomIntFunction = r = new SplittableRandom()::nextInt;
457        }
458        return r;
459    }
460
461    /**
462     * An expandable list of int values. This allows tracking array positions
463     * without using boxed values in a {@code List<Integer>}.
464     */
465    private static class IntList {
466        /** The maximum size of array to allocate. */
467        private final int max;
468
469        /** The size of the list. */
470        private int size;
471        /** The list data. Initialised with space to store a tie of 2 values. */
472        private int[] data = new int[2];
473
474        /**
475         * @param max Maximum size of array to allocate. Can use the length of the parent array
476         * for which this is used to track indices.
477         */
478        IntList(int max) {
479            this.max = max;
480        }
481
482        /**
483         * Adds the value to the list.
484         *
485         * @param value the value
486         */
487        void add(int value) {
488            if (size == data.length) {
489                // Overflow safe doubling of the current size.
490                data = Arrays.copyOf(data, (int) Math.min(max, size * 2L));
491            }
492            data[size++] = value;
493        }
494
495        /**
496         * Gets the element at the specified {@code index}.
497         *
498         * @param index Element index
499         * @return the element
500         */
501        int get(int index) {
502            return data[index];
503        }
504
505        /**
506         * Gets the number of elements in the list.
507         *
508         * @return the size
509         */
510        int size() {
511            return size;
512        }
513
514        /**
515         * Clear the list.
516         */
517        void clear() {
518            size = 0;
519        }
520
521        /**
522         * Shuffle the list.
523         *
524         * @param randomIntFunction Function maps positive {@code x} randomly to {@code [0, x)}.
525         */
526        void shuffle(IntUnaryOperator randomIntFunction) {
527            // Fisher-Yates shuffle
528            final int[] array = data;
529            for (int i = size; i > 1; i--) {
530                swap(array, i - 1, randomIntFunction.applyAsInt(i));
531            }
532        }
533
534        /**
535         * Swaps the two specified elements in the specified array.
536         *
537         * @param array Data array
538         * @param i     First index
539         * @param j     Second index
540         */
541        private static void swap(int[] array, int i, int j) {
542            final int tmp = array[i];
543            array[i] = array[j];
544            array[j] = tmp;
545        }
546    }
547
548    /**
549     * Represents the position of a {@code double} value in a data array. The
550     * Comparable interface is implemented so Arrays.sort can be used to sort an
551     * array of data positions by value. Note that the implicitly defined natural
552     * ordering is NOT consistent with equals.
553     */
554    private static class DataPosition implements Comparable<DataPosition>  {
555        /** Data value. */
556        private final double value;
557        /** Data position. */
558        private final int position;
559
560        /**
561         * Create an instance with the given value and position.
562         *
563         * @param value Data value.
564         * @param position Data position.
565         */
566        DataPosition(double value, int position) {
567            this.value = value;
568            this.position = position;
569        }
570
571        /**
572         * Compare this value to another.
573         * Only the <strong>values</strong> are compared.
574         *
575         * @param other the other pair to compare this to
576         * @return result of {@code Double.compare(value, other.value)}
577         */
578        @Override
579        public int compareTo(DataPosition other) {
580            return Double.compare(value, other.value);
581        }
582
583        // N.B. equals() and hashCode() are not implemented; see MATH-610 for discussion.
584
585        /**
586         * Returns the data position.
587         *
588         * @return position
589         */
590        int getPosition() {
591            return position;
592        }
593    }
594}