View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.numbers.combinatorics;
19  
20  /**
21   * Computation of <a href="https://en.wikipedia.org/wiki/Stirling_number">Stirling numbers</a>.
22   *
23   * @since 1.2
24   */
25  public final class Stirling {
26      /** Stirling S1 error message. */
27      private static final String S1_ERROR_FORMAT = "s(n=%d, k=%d)";
28      /** Stirling S2 error message. */
29      private static final String S2_ERROR_FORMAT = "S(n=%d, k=%d)";
30      /** Overflow threshold for n when computing s(n, 1). */
31      private static final int S1_OVERFLOW_K_EQUALS_1 = 21;
32      /** Overflow threshold for n when computing s(n, n-2). */
33      private static final int S1_OVERFLOW_K_EQUALS_NM2 = 92682;
34      /** Overflow threshold for n when computing s(n, n-3). */
35      private static final int S1_OVERFLOW_K_EQUALS_NM3 = 2761;
36      /** Overflow threshold for n when computing S(n, n-2). */
37      private static final int S2_OVERFLOW_K_EQUALS_NM2 = 92683;
38      /** Overflow threshold for n when computing S(n, n-3). */
39      private static final int S2_OVERFLOW_K_EQUALS_NM3 = 2762;
40  
41      /**
42       * Precomputed Stirling numbers of the first kind.
43       * Provides a thread-safe lazy initialization of the cache.
44       */
45      private static final class StirlingS1Cache {
46          /** Maximum n to compute (exclusive).
47           * As s(21,3) = 13803759753640704000 is larger than Long.MAX_VALUE
48           * we must stop computation at row 21. */
49          static final int MAX_N = 21;
50          /** Stirling numbers of the first kind. */
51          static final long[][] S1;
52  
53          static {
54              S1 = new long[MAX_N][];
55              // Initialise first two rows to allow s(2, 1) to use s(1, 1)
56              S1[0] = new long[] {1};
57              S1[1] = new long[] {0, 1};
58              for (int n = 2; n < S1.length; n++) {
59                  S1[n] = new long[n + 1];
60                  S1[n][0] = 0;
61                  S1[n][n] = 1;
62                  for (int k = 1; k < n; k++) {
63                      S1[n][k] = S1[n - 1][k - 1] - (n - 1) * S1[n - 1][k];
64                  }
65              }
66          }
67      }
68  
69      /**
70       * Precomputed Stirling numbers of the second kind.
71       * Provides a thread-safe lazy initialization of the cache.
72       */
73      private static final class StirlingS2Cache {
74          /** Maximum n to compute (exclusive).
75           * As S(26,9) = 11201516780955125625 is larger than Long.MAX_VALUE
76           * we must stop computation at row 26. */
77          static final int MAX_N = 26;
78          /** Stirling numbers of the second kind. */
79          static final long[][] S2;
80  
81          static {
82              S2 = new long[MAX_N][];
83              S2[0] = new long[] {1};
84              for (int n = 1; n < S2.length; n++) {
85                  S2[n] = new long[n + 1];
86                  S2[n][0] = 0;
87                  S2[n][1] = 1;
88                  S2[n][n] = 1;
89                  for (int k = 2; k < n; k++) {
90                      S2[n][k] = k * S2[n - 1][k] + S2[n - 1][k - 1];
91                  }
92              }
93          }
94      }
95  
96      /** Private constructor. */
97      private Stirling() {
98          // intentionally empty.
99      }
100 
101     /**
102      * Returns the <em>signed</em> <a
103      * href="https://mathworld.wolfram.com/StirlingNumberoftheFirstKind.html">
104      * Stirling number of the first kind</a>, "{@code s(n,k)}". The number of permutations of
105      * {@code n} elements which contain exactly {@code k} permutation cycles is the
106      * nonnegative number: {@code |s(n,k)| = (-1)^(n-k) s(n,k)}
107      *
108      * @param n Size of the set
109      * @param k Number of permutation cycles ({@code 0 <= k <= n})
110      * @return {@code s(n,k)}
111      * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0} or {@code k > n}.
112      * @throws ArithmeticException if some overflow happens, typically for n exceeding 20
113      * (s(n,n-1) is handled specifically and does not overflow)
114      */
115     public static long stirlingS1(int n, int k) {
116         checkArguments(n, k);
117 
118         if (n < StirlingS1Cache.MAX_N) {
119             // The number is in the small cache
120             return StirlingS1Cache.S1[n][k];
121         }
122 
123         // Simple cases
124         // https://en.wikipedia.org/wiki/Stirling_numbers_of_the_first_kind#Simple_identities
125         if (k == 0) {
126             return 0;
127         } else if (k == n) {
128             return 1;
129         } else if (k == 1) {
130             checkN(n, k, S1_OVERFLOW_K_EQUALS_1, S1_ERROR_FORMAT);
131             // Note: Only occurs for n=21 so avoid computing the sign with pow(-1, n-1) * (n-1)!
132             return Factorial.value(n - 1);
133         } else if (k == n - 1) {
134             return -BinomialCoefficient.value(n, 2);
135         } else if (k == n - 2) {
136             checkN(n, k, S1_OVERFLOW_K_EQUALS_NM2, S1_ERROR_FORMAT);
137             // (3n-1) * binom(n, 3) / 4
138             return productOver4(3L * n - 1, BinomialCoefficient.value(n, 3));
139         } else if (k == n - 3) {
140             checkN(n, k, S1_OVERFLOW_K_EQUALS_NM3, S1_ERROR_FORMAT);
141             return -BinomialCoefficient.value(n, 2) * BinomialCoefficient.value(n, 4);
142         }
143 
144         // Compute using:
145         // s(n + 1, k) = s(n, k - 1)     - n       * s(n, k)
146         // s(n, k)     = s(n - 1, k - 1) - (n - 1) * s(n - 1, k)
147 
148         // n >= 21 (MAX_N)
149         // 2 <= k <= n-4
150 
151         // Start at the largest easily computed value: n < MAX_N or k < 2
152         final int reduction = Math.min(n - StirlingS1Cache.MAX_N, k - 2) + 1;
153         int n0 = n - reduction;
154         int k0 = k - reduction;
155 
156         long sum = stirlingS1(n0, k0);
157         while (n0 < n) {
158             k0++;
159             sum = Math.subtractExact(
160                 sum,
161                 Math.multiplyExact(n0, stirlingS1(n0, k0))
162             );
163             n0++;
164         }
165 
166         return sum;
167     }
168 
169     /**
170      * Returns the <a
171      * href="https://mathworld.wolfram.com/StirlingNumberoftheSecondKind.html">
172      * Stirling number of the second kind</a>, "{@code S(n,k)}", the number of
173      * ways of partitioning an {@code n}-element set into {@code k} non-empty
174      * subsets.
175      *
176      * @param n Size of the set
177      * @param k Number of non-empty subsets ({@code 0 <= k <= n})
178      * @return {@code S(n,k)}
179      * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0} or {@code k > n}.
180      * @throws ArithmeticException if some overflow happens, typically for n exceeding 25 and
181      * k between 20 and n-2 (S(n,n-1) is handled specifically and does not overflow)
182      */
183     public static long stirlingS2(int n, int k) {
184         checkArguments(n, k);
185 
186         if (n < StirlingS2Cache.MAX_N) {
187             // The number is in the small cache
188             return StirlingS2Cache.S2[n][k];
189         }
190 
191         // Simple cases
192         if (k == 0) {
193             return 0;
194         } else if (k == 1 || k == n) {
195             return 1;
196         } else if (k == 2) {
197             checkN(n, k, 64, S2_ERROR_FORMAT);
198             return (1L << (n - 1)) - 1L;
199         } else if (k == n - 1) {
200             return BinomialCoefficient.value(n, 2);
201         } else if (k == n - 2) {
202             checkN(n, k, S2_OVERFLOW_K_EQUALS_NM2, S2_ERROR_FORMAT);
203             // (3n-5) * binom(n, 3) / 4
204             return productOver4(3L * n - 5, BinomialCoefficient.value(n, 3));
205         } else if (k == n - 3) {
206             checkN(n, k, S2_OVERFLOW_K_EQUALS_NM3, S2_ERROR_FORMAT);
207             return BinomialCoefficient.value(n - 2, 2) * BinomialCoefficient.value(n, 4);
208         }
209 
210         // Compute using:
211         // S(n, k) = k * S(n - 1, k) + S(n - 1, k - 1)
212 
213         // n >= 26 (MAX_N)
214         // 3 <= k <= n-3
215 
216         // Start at the largest easily computed value: n < MAX_N or k < 3
217         final int reduction = Math.min(n - StirlingS2Cache.MAX_N, k - 3) + 1;
218         int n0 = n - reduction;
219         int k0 = k - reduction;
220 
221         long sum = stirlingS2(n0, k0);
222         while (n0 < n) {
223             k0++;
224             sum = Math.addExact(
225                 Math.multiplyExact(k0, stirlingS2(n0, k0)),
226                 sum
227             );
228             n0++;
229         }
230 
231         return sum;
232     }
233 
234     /**
235      * Check {@code 0 <= k <= n}.
236      *
237      * @param n N
238      * @param k K
239      * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0} or {@code k > n}.
240      */
241     private static void checkArguments(int n, int k) {
242         // Combine all checks with a single branch:
243         // 0 <= n; 0 <= k <= n
244         // Note: If n >= 0 && k >= 0 && n - k < 0 then k > n.
245         // Bitwise or will detect a negative sign bit in any of the numbers
246         if ((n | k | (n - k)) < 0) {
247             // Raise the correct exception
248             if (n < 0) {
249                 throw new CombinatoricsException(CombinatoricsException.NEGATIVE, n);
250             }
251             throw new CombinatoricsException(CombinatoricsException.OUT_OF_RANGE, k, 0, n);
252         }
253     }
254 
255     /**
256      * Check {@code n <= threshold}, or else throw an {@link ArithmeticException}.
257      *
258      * @param n N
259      * @param k K
260      * @param threshold Threshold for {@code n}
261      * @param msgFormat Error message format
262      * @throws ArithmeticException if overflow is expected to happen
263      */
264     private static void checkN(int n, int k, int threshold, String msgFormat) {
265         if (n > threshold) {
266             throw new ArithmeticException(String.format(msgFormat, n, k));
267         }
268     }
269 
270     /**
271      * Return {@code a*b/4} without intermediate overflow.
272      * It is assumed that:
273      * <ul>
274      * <li>The coefficients a and b are positive
275      * <li>The product (a*b) is an exact multiple of 4
276      * <li>The result (a*b/4) is an exact integer that does not overflow a {@code long}
277      * </ul>
278      *
279      * <p>A conditional branch is performed on the odd/even property of {@code b}.
280      * The branch is predictable if {@code b} is typically the same parity.
281      *
282      * @param a Coefficient a
283      * @param b Coefficient b
284      * @return {@code a*b/4}
285      */
286     private static long productOver4(long a, long b) {
287         // Compute (a*b/4) without intermediate overflow.
288         // The product (a*b) must be an exact multiple of 4.
289         // If b is even: ((b/2) * a) / 2
290         // If b is odd then a must be even to make a*b even: ((a/2) * b) / 2
291         return (b & 1) == 0 ?
292             ((b >>> 1) * a) >>> 1 :
293             ((a >>> 1) * b) >>> 1;
294     }
295 }