001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.numbers.combinatorics;
019
020/**
021 * Computation of <a href="https://en.wikipedia.org/wiki/Stirling_number">Stirling numbers</a>.
022 *
023 * @since 1.2
024 */
025public final class Stirling {
026    /** Stirling S1 error message. */
027    private static final String S1_ERROR_FORMAT = "s(n=%d, k=%d)";
028    /** Stirling S2 error message. */
029    private static final String S2_ERROR_FORMAT = "S(n=%d, k=%d)";
030    /** Overflow threshold for n when computing s(n, 1). */
031    private static final int S1_OVERFLOW_K_EQUALS_1 = 21;
032    /** Overflow threshold for n when computing s(n, n-2). */
033    private static final int S1_OVERFLOW_K_EQUALS_NM2 = 92682;
034    /** Overflow threshold for n when computing s(n, n-3). */
035    private static final int S1_OVERFLOW_K_EQUALS_NM3 = 2761;
036    /** Overflow threshold for n when computing S(n, n-2). */
037    private static final int S2_OVERFLOW_K_EQUALS_NM2 = 92683;
038    /** Overflow threshold for n when computing S(n, n-3). */
039    private static final int S2_OVERFLOW_K_EQUALS_NM3 = 2762;
040
041    /**
042     * Precomputed Stirling numbers of the first kind.
043     * Provides a thread-safe lazy initialization of the cache.
044     */
045    private static final class StirlingS1Cache {
046        /** Maximum n to compute (exclusive).
047         * As s(21,3) = 13803759753640704000 is larger than Long.MAX_VALUE
048         * we must stop computation at row 21. */
049        static final int MAX_N = 21;
050        /** Stirling numbers of the first kind. */
051        static final long[][] S1;
052
053        static {
054            S1 = new long[MAX_N][];
055            // Initialise first two rows to allow s(2, 1) to use s(1, 1)
056            S1[0] = new long[] {1};
057            S1[1] = new long[] {0, 1};
058            for (int n = 2; n < S1.length; n++) {
059                S1[n] = new long[n + 1];
060                S1[n][0] = 0;
061                S1[n][n] = 1;
062                for (int k = 1; k < n; k++) {
063                    S1[n][k] = S1[n - 1][k - 1] - (n - 1) * S1[n - 1][k];
064                }
065            }
066        }
067    }
068
069    /**
070     * Precomputed Stirling numbers of the second kind.
071     * Provides a thread-safe lazy initialization of the cache.
072     */
073    private static final class StirlingS2Cache {
074        /** Maximum n to compute (exclusive).
075         * As S(26,9) = 11201516780955125625 is larger than Long.MAX_VALUE
076         * we must stop computation at row 26. */
077        static final int MAX_N = 26;
078        /** Stirling numbers of the second kind. */
079        static final long[][] S2;
080
081        static {
082            S2 = new long[MAX_N][];
083            S2[0] = new long[] {1};
084            for (int n = 1; n < S2.length; n++) {
085                S2[n] = new long[n + 1];
086                S2[n][0] = 0;
087                S2[n][1] = 1;
088                S2[n][n] = 1;
089                for (int k = 2; k < n; k++) {
090                    S2[n][k] = k * S2[n - 1][k] + S2[n - 1][k - 1];
091                }
092            }
093        }
094    }
095
096    /** Private constructor. */
097    private Stirling() {
098        // intentionally empty.
099    }
100
101    /**
102     * Returns the <em>signed</em> <a
103     * href="https://mathworld.wolfram.com/StirlingNumberoftheFirstKind.html">
104     * Stirling number of the first kind</a>, "{@code s(n,k)}". The number of permutations of
105     * {@code n} elements which contain exactly {@code k} permutation cycles is the
106     * nonnegative number: {@code |s(n,k)| = (-1)^(n-k) s(n,k)}
107     *
108     * @param n Size of the set
109     * @param k Number of permutation cycles ({@code 0 <= k <= n})
110     * @return {@code s(n,k)}
111     * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0} or {@code k > n}.
112     * @throws ArithmeticException if some overflow happens, typically for n exceeding 20
113     * (s(n,n-1) is handled specifically and does not overflow)
114     */
115    public static long stirlingS1(int n, int k) {
116        checkArguments(n, k);
117
118        if (n < StirlingS1Cache.MAX_N) {
119            // The number is in the small cache
120            return StirlingS1Cache.S1[n][k];
121        }
122
123        // Simple cases
124        // https://en.wikipedia.org/wiki/Stirling_numbers_of_the_first_kind#Simple_identities
125        if (k == 0) {
126            return 0;
127        } else if (k == n) {
128            return 1;
129        } else if (k == 1) {
130            checkN(n, k, S1_OVERFLOW_K_EQUALS_1, S1_ERROR_FORMAT);
131            // Note: Only occurs for n=21 so avoid computing the sign with pow(-1, n-1) * (n-1)!
132            return Factorial.value(n - 1);
133        } else if (k == n - 1) {
134            return -BinomialCoefficient.value(n, 2);
135        } else if (k == n - 2) {
136            checkN(n, k, S1_OVERFLOW_K_EQUALS_NM2, S1_ERROR_FORMAT);
137            // (3n-1) * binom(n, 3) / 4
138            return productOver4(3L * n - 1, BinomialCoefficient.value(n, 3));
139        } else if (k == n - 3) {
140            checkN(n, k, S1_OVERFLOW_K_EQUALS_NM3, S1_ERROR_FORMAT);
141            return -BinomialCoefficient.value(n, 2) * BinomialCoefficient.value(n, 4);
142        }
143
144        // Compute using:
145        // s(n + 1, k) = s(n, k - 1)     - n       * s(n, k)
146        // s(n, k)     = s(n - 1, k - 1) - (n - 1) * s(n - 1, k)
147
148        // n >= 21 (MAX_N)
149        // 2 <= k <= n-4
150
151        // Start at the largest easily computed value: n < MAX_N or k < 2
152        final int reduction = Math.min(n - StirlingS1Cache.MAX_N, k - 2) + 1;
153        int n0 = n - reduction;
154        int k0 = k - reduction;
155
156        long sum = stirlingS1(n0, k0);
157        while (n0 < n) {
158            k0++;
159            sum = Math.subtractExact(
160                sum,
161                Math.multiplyExact(n0, stirlingS1(n0, k0))
162            );
163            n0++;
164        }
165
166        return sum;
167    }
168
169    /**
170     * Returns the <a
171     * href="https://mathworld.wolfram.com/StirlingNumberoftheSecondKind.html">
172     * Stirling number of the second kind</a>, "{@code S(n,k)}", the number of
173     * ways of partitioning an {@code n}-element set into {@code k} non-empty
174     * subsets.
175     *
176     * @param n Size of the set
177     * @param k Number of non-empty subsets ({@code 0 <= k <= n})
178     * @return {@code S(n,k)}
179     * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0} or {@code k > n}.
180     * @throws ArithmeticException if some overflow happens, typically for n exceeding 25 and
181     * k between 20 and n-2 (S(n,n-1) is handled specifically and does not overflow)
182     */
183    public static long stirlingS2(int n, int k) {
184        checkArguments(n, k);
185
186        if (n < StirlingS2Cache.MAX_N) {
187            // The number is in the small cache
188            return StirlingS2Cache.S2[n][k];
189        }
190
191        // Simple cases
192        if (k == 0) {
193            return 0;
194        } else if (k == 1 || k == n) {
195            return 1;
196        } else if (k == 2) {
197            checkN(n, k, 64, S2_ERROR_FORMAT);
198            return (1L << (n - 1)) - 1L;
199        } else if (k == n - 1) {
200            return BinomialCoefficient.value(n, 2);
201        } else if (k == n - 2) {
202            checkN(n, k, S2_OVERFLOW_K_EQUALS_NM2, S2_ERROR_FORMAT);
203            // (3n-5) * binom(n, 3) / 4
204            return productOver4(3L * n - 5, BinomialCoefficient.value(n, 3));
205        } else if (k == n - 3) {
206            checkN(n, k, S2_OVERFLOW_K_EQUALS_NM3, S2_ERROR_FORMAT);
207            return BinomialCoefficient.value(n - 2, 2) * BinomialCoefficient.value(n, 4);
208        }
209
210        // Compute using:
211        // S(n, k) = k * S(n - 1, k) + S(n - 1, k - 1)
212
213        // n >= 26 (MAX_N)
214        // 3 <= k <= n-3
215
216        // Start at the largest easily computed value: n < MAX_N or k < 3
217        final int reduction = Math.min(n - StirlingS2Cache.MAX_N, k - 3) + 1;
218        int n0 = n - reduction;
219        int k0 = k - reduction;
220
221        long sum = stirlingS2(n0, k0);
222        while (n0 < n) {
223            k0++;
224            sum = Math.addExact(
225                Math.multiplyExact(k0, stirlingS2(n0, k0)),
226                sum
227            );
228            n0++;
229        }
230
231        return sum;
232    }
233
234    /**
235     * Check {@code 0 <= k <= n}.
236     *
237     * @param n N
238     * @param k K
239     * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0} or {@code k > n}.
240     */
241    private static void checkArguments(int n, int k) {
242        // Combine all checks with a single branch:
243        // 0 <= n; 0 <= k <= n
244        // Note: If n >= 0 && k >= 0 && n - k < 0 then k > n.
245        // Bitwise or will detect a negative sign bit in any of the numbers
246        if ((n | k | (n - k)) < 0) {
247            // Raise the correct exception
248            if (n < 0) {
249                throw new CombinatoricsException(CombinatoricsException.NEGATIVE, n);
250            }
251            throw new CombinatoricsException(CombinatoricsException.OUT_OF_RANGE, k, 0, n);
252        }
253    }
254
255    /**
256     * Check {@code n <= threshold}, or else throw an {@link ArithmeticException}.
257     *
258     * @param n N
259     * @param k K
260     * @param threshold Threshold for {@code n}
261     * @param msgFormat Error message format
262     * @throws ArithmeticException if overflow is expected to happen
263     */
264    private static void checkN(int n, int k, int threshold, String msgFormat) {
265        if (n > threshold) {
266            throw new ArithmeticException(String.format(msgFormat, n, k));
267        }
268    }
269
270    /**
271     * Return {@code a*b/4} without intermediate overflow.
272     * It is assumed that:
273     * <ul>
274     * <li>The coefficients a and b are positive
275     * <li>The product (a*b) is an exact multiple of 4
276     * <li>The result (a*b/4) is an exact integer that does not overflow a {@code long}
277     * </ul>
278     *
279     * <p>A conditional branch is performed on the odd/even property of {@code b}.
280     * The branch is predictable if {@code b} is typically the same parity.
281     *
282     * @param a Coefficient a
283     * @param b Coefficient b
284     * @return {@code a*b/4}
285     */
286    private static long productOver4(long a, long b) {
287        // Compute (a*b/4) without intermediate overflow.
288        // The product (a*b) must be an exact multiple of 4.
289        // If b is even: ((b/2) * a) / 2
290        // If b is odd then a must be even to make a*b even: ((a/2) * b) / 2
291        return (b & 1) == 0 ?
292            ((b >>> 1) * a) >>> 1 :
293            ((a >>> 1) * b) >>> 1;
294    }
295}