1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math4.legacy.linear;
19
20 import org.apache.commons.math4.legacy.core.Field;
21 import org.apache.commons.math4.legacy.core.FieldElement;
22 import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
23 import org.apache.commons.math4.legacy.core.MathArrays;
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53 public class FieldLUDecomposition<T extends FieldElement<T>> {
54
55
56 private final Field<T> field;
57
58
59 private T[][] lu;
60
61
62 private int[] pivot;
63
64
65 private boolean even;
66
67
68 private boolean singular;
69
70
71 private FieldMatrix<T> cachedL;
72
73
74 private FieldMatrix<T> cachedU;
75
76
77 private FieldMatrix<T> cachedP;
78
79
80
81
82
83
84 public FieldLUDecomposition(FieldMatrix<T> matrix) {
85 if (!matrix.isSquare()) {
86 throw new NonSquareMatrixException(matrix.getRowDimension(),
87 matrix.getColumnDimension());
88 }
89
90 final int m = matrix.getColumnDimension();
91 field = matrix.getField();
92 lu = matrix.getData();
93 pivot = new int[m];
94 cachedL = null;
95 cachedU = null;
96 cachedP = null;
97
98
99 for (int row = 0; row < m; row++) {
100 pivot[row] = row;
101 }
102 even = true;
103 singular = false;
104
105
106 for (int col = 0; col < m; col++) {
107
108 T sum = field.getZero();
109
110
111 for (int row = 0; row < col; row++) {
112 final T[] luRow = lu[row];
113 sum = luRow[col];
114 for (int i = 0; i < row; i++) {
115 sum = sum.subtract(luRow[i].multiply(lu[i][col]));
116 }
117 luRow[col] = sum;
118 }
119
120
121 int nonZero = col;
122 for (int row = col; row < m; row++) {
123 final T[] luRow = lu[row];
124 sum = luRow[col];
125 for (int i = 0; i < col; i++) {
126 sum = sum.subtract(luRow[i].multiply(lu[i][col]));
127 }
128 luRow[col] = sum;
129
130 if (lu[nonZero][col].equals(field.getZero())) {
131
132 ++nonZero;
133 }
134 }
135
136
137 if (nonZero >= m) {
138 singular = true;
139 return;
140 }
141
142
143 if (nonZero != col) {
144 T tmp = field.getZero();
145 for (int i = 0; i < m; i++) {
146 tmp = lu[nonZero][i];
147 lu[nonZero][i] = lu[col][i];
148 lu[col][i] = tmp;
149 }
150 int temp = pivot[nonZero];
151 pivot[nonZero] = pivot[col];
152 pivot[col] = temp;
153 even = !even;
154 }
155
156
157 final T luDiag = lu[col][col];
158 for (int row = col + 1; row < m; row++) {
159 final T[] luRow = lu[row];
160 luRow[col] = luRow[col].divide(luDiag);
161 }
162 }
163 }
164
165
166
167
168
169
170 public FieldMatrix<T> getL() {
171 if (cachedL == null && !singular) {
172 final int m = pivot.length;
173 cachedL = new Array2DRowFieldMatrix<>(field, m, m);
174 for (int i = 0; i < m; ++i) {
175 final T[] luI = lu[i];
176 for (int j = 0; j < i; ++j) {
177 cachedL.setEntry(i, j, luI[j]);
178 }
179 cachedL.setEntry(i, i, field.getOne());
180 }
181 }
182 return cachedL;
183 }
184
185
186
187
188
189
190 public FieldMatrix<T> getU() {
191 if (cachedU == null && !singular) {
192 final int m = pivot.length;
193 cachedU = new Array2DRowFieldMatrix<>(field, m, m);
194 for (int i = 0; i < m; ++i) {
195 final T[] luI = lu[i];
196 for (int j = i; j < m; ++j) {
197 cachedU.setEntry(i, j, luI[j]);
198 }
199 }
200 }
201 return cachedU;
202 }
203
204
205
206
207
208
209
210
211
212
213 public FieldMatrix<T> getP() {
214 if (cachedP == null && !singular) {
215 final int m = pivot.length;
216 cachedP = new Array2DRowFieldMatrix<>(field, m, m);
217 for (int i = 0; i < m; ++i) {
218 cachedP.setEntry(i, pivot[i], field.getOne());
219 }
220 }
221 return cachedP;
222 }
223
224
225
226
227
228
229 public int[] getPivot() {
230 return pivot.clone();
231 }
232
233
234
235
236
237 public T getDeterminant() {
238 if (singular) {
239 return field.getZero();
240 } else {
241 final int m = pivot.length;
242 T determinant = even ? field.getOne() : field.getZero().subtract(field.getOne());
243 for (int i = 0; i < m; i++) {
244 determinant = determinant.multiply(lu[i][i]);
245 }
246 return determinant;
247 }
248 }
249
250
251
252
253
254 public FieldDecompositionSolver<T> getSolver() {
255 return new Solver<>(field, lu, pivot, singular);
256 }
257
258
259
260
261 private static final class Solver<T extends FieldElement<T>> implements FieldDecompositionSolver<T> {
262
263
264 private final Field<T> field;
265
266
267 private final T[][] lu;
268
269
270 private final int[] pivot;
271
272
273 private final boolean singular;
274
275
276
277
278
279
280
281
282 private Solver(final Field<T> field, final T[][] lu,
283 final int[] pivot, final boolean singular) {
284 this.field = field;
285 this.lu = lu;
286 this.pivot = pivot;
287 this.singular = singular;
288 }
289
290
291 @Override
292 public boolean isNonSingular() {
293 return !singular;
294 }
295
296
297 @Override
298 public FieldVector<T> solve(FieldVector<T> b) {
299 if (b instanceof ArrayFieldVector) {
300 return solve((ArrayFieldVector<T>) b);
301 }
302
303 final int m = pivot.length;
304 if (b.getDimension() != m) {
305 throw new DimensionMismatchException(b.getDimension(), m);
306 }
307 if (singular) {
308 throw new SingularMatrixException();
309 }
310
311
312 final T[] bp = MathArrays.buildArray(field, m);
313 for (int row = 0; row < m; row++) {
314 bp[row] = b.getEntry(pivot[row]);
315 }
316
317
318 for (int col = 0; col < m; col++) {
319 final T bpCol = bp[col];
320 for (int i = col + 1; i < m; i++) {
321 bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
322 }
323 }
324
325
326 for (int col = m - 1; col >= 0; col--) {
327 bp[col] = bp[col].divide(lu[col][col]);
328 final T bpCol = bp[col];
329 for (int i = 0; i < col; i++) {
330 bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
331 }
332 }
333
334 return new ArrayFieldVector<>(field, bp, false);
335 }
336
337
338
339
340
341
342
343
344 public ArrayFieldVector<T> solve(ArrayFieldVector<T> b) {
345 final int m = pivot.length;
346 final int length = b.getDimension();
347 if (length != m) {
348 throw new DimensionMismatchException(length, m);
349 }
350 if (singular) {
351 throw new SingularMatrixException();
352 }
353
354
355 final T[] bp = MathArrays.buildArray(field, m);
356 for (int row = 0; row < m; row++) {
357 bp[row] = b.getEntry(pivot[row]);
358 }
359
360
361 for (int col = 0; col < m; col++) {
362 final T bpCol = bp[col];
363 for (int i = col + 1; i < m; i++) {
364 bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
365 }
366 }
367
368
369 for (int col = m - 1; col >= 0; col--) {
370 bp[col] = bp[col].divide(lu[col][col]);
371 final T bpCol = bp[col];
372 for (int i = 0; i < col; i++) {
373 bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
374 }
375 }
376
377 return new ArrayFieldVector<>(bp, false);
378 }
379
380
381 @Override
382 public FieldMatrix<T> solve(FieldMatrix<T> b) {
383 final int m = pivot.length;
384 if (b.getRowDimension() != m) {
385 throw new DimensionMismatchException(b.getRowDimension(), m);
386 }
387 if (singular) {
388 throw new SingularMatrixException();
389 }
390
391 final int nColB = b.getColumnDimension();
392
393
394 final T[][] bp = MathArrays.buildArray(field, m, nColB);
395 for (int row = 0; row < m; row++) {
396 final T[] bpRow = bp[row];
397 final int pRow = pivot[row];
398 for (int col = 0; col < nColB; col++) {
399 bpRow[col] = b.getEntry(pRow, col);
400 }
401 }
402
403
404 for (int col = 0; col < m; col++) {
405 final T[] bpCol = bp[col];
406 for (int i = col + 1; i < m; i++) {
407 final T[] bpI = bp[i];
408 final T luICol = lu[i][col];
409 for (int j = 0; j < nColB; j++) {
410 bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
411 }
412 }
413 }
414
415
416 for (int col = m - 1; col >= 0; col--) {
417 final T[] bpCol = bp[col];
418 final T luDiag = lu[col][col];
419 for (int j = 0; j < nColB; j++) {
420 bpCol[j] = bpCol[j].divide(luDiag);
421 }
422 for (int i = 0; i < col; i++) {
423 final T[] bpI = bp[i];
424 final T luICol = lu[i][col];
425 for (int j = 0; j < nColB; j++) {
426 bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
427 }
428 }
429 }
430
431 return new Array2DRowFieldMatrix<>(field, bp, false);
432 }
433
434
435 @Override
436 public FieldMatrix<T> getInverse() {
437 final int m = pivot.length;
438 final T one = field.getOne();
439 FieldMatrix<T> identity = new Array2DRowFieldMatrix<>(field, m, m);
440 for (int i = 0; i < m; ++i) {
441 identity.setEntry(i, i, one);
442 }
443 return solve(identity);
444 }
445 }
446 }