/*-
 *******************************************************************************
 * Copyright (c) 2011, 2016 Diamond Light Source Ltd.
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License v1.0
 * which accompanies this distribution, and is available at
 * http://www.eclipse.org/legal/epl-v10.html
 *
 * Contributors:
 *    Peter Chang - initial API and implementation and/or initial documentation
 *******************************************************************************/

package org.eclipse.january.dataset;

import java.util.Arrays;
import java.util.List;

import org.apache.commons.math3.complex.Complex;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.CholeskyDecomposition;
import org.apache.commons.math3.linear.ConjugateGradient;
import org.apache.commons.math3.linear.EigenDecomposition;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.linear.RealLinearOperator;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.SingularValueDecomposition;


/**
 * Linear algebra operations on datasets: tensor, dot, outer, cross and Kronecker products,
 * traces, norms, and matrix decompositions/solvers (the latter delegate to Apache Commons Math
 * on rank-2 datasets converted to double-valued matrices).
 */
public class LinearAlgebra {

	private static final int CROSSOVERPOINT = 16; // point at which using slice iterators for inner loop is faster

	/**
	 * Calculate the tensor dot product over given axes. This is the sum of products of elements selected
	 * from the given axes in each dataset
	 * @param a
	 * @param b
	 * @param axisa axis dimension in a to sum over (can be -ve)
	 * @param axisb axis dimension in b to sum over (can be -ve)
	 * @return tensor dot product
	 * @throws IllegalArgumentException if the summing axes do not have matching lengths
	 */
	public static Dataset tensorDotProduct(final Dataset a, final Dataset b, final int axisa, final int axisb) {
		// this is slower for summing lengths < ~15
		final int[] ashape = a.getShapeRef();
		final int[] bshape = b.getShapeRef();
		final int arank = ashape.length;
		final int brank = bshape.length;
		int aaxis = ShapeUtils.checkAxis(arank, axisa);

		if (ashape[aaxis] < CROSSOVERPOINT) { // faster to use position iteration
			return tensorDotProduct(a, b, new int[] {axisa}, new int[] {axisb});
		}
		int baxis = ShapeUtils.checkAxis(brank, axisb);
		// the slice-iterator summation below stops at the shorter axis, so a mismatch
		// would otherwise be silently truncated rather than reported
		if (ashape[aaxis] != bshape[baxis]) {
			throw new IllegalArgumentException("Summing axes do not have matching lengths");
		}

		final boolean[] achoice = new boolean[arank];
		final boolean[] bchoice = new boolean[brank];
		Arrays.fill(achoice, true);
		Arrays.fill(bchoice, true);
		achoice[aaxis] = false; // flag which axes not to iterate over
		bchoice[baxis] = false;

		final boolean[] notachoice = new boolean[arank];
		final boolean[] notbchoice = new boolean[brank];
		notachoice[aaxis] = true; // flag which axes to iterate over
		notbchoice[baxis] = true;

		int drank = arank + brank - 2;
		int[] dshape = new int[drank];
		int d = 0;
		for (int i = 0; i < arank; i++) {
			if (achoice[i])
				dshape[d++] = ashape[i];
		}
		for (int i = 0; i < brank; i++) {
			if (bchoice[i])
				dshape[d++] = bshape[i];
		}
		int dtype = DTypeUtils.getBestDType(a.getDType(), b.getDType());
		@SuppressWarnings("deprecation")
		Dataset data = DatasetFactory.zeros(dshape, dtype);

		SliceIterator ita = a.getSliceIteratorFromAxes(null, achoice);
		int l = 0;
		final int[] apos = ita.getPos();
		while (ita.hasNext()) {
			SliceIterator itb = b.getSliceIteratorFromAxes(null, bchoice);
			final int[] bpos = itb.getPos();
			while (itb.hasNext()) {
				SliceIterator itaa = a.getSliceIteratorFromAxes(apos, notachoice);
				SliceIterator itba = b.getSliceIteratorFromAxes(bpos, notbchoice);
				// compensated (Kahan) summation of the products along the summing axis
				double sum = 0.0;
				double com = 0.0;
				while (itaa.hasNext() && itba.hasNext()) {
					final double y = a.getElementDoubleAbs(itaa.index) * b.getElementDoubleAbs(itba.index) - com;
					final double t = sum + y;
					com = (t - sum) - y;
					sum = t;
				}
				data.setObjectAbs(l++, sum);
			}
		}

		return data;
	}

	/**
	 * Calculate the tensor dot product over given axes. This is the sum of products of elements selected
	 * from the given axes in each dataset. When no summing axes are given, the result is the outer product
	 * @param a
	 * @param b
	 * @param axisa axis dimensions in a to sum over (can be -ve)
	 * @param axisb axis dimensions in b to sum over (can be -ve)
	 * @return tensor dot product
	 * @throws IllegalArgumentException if the numbers of axes differ or paired axes do not have matching lengths
	 */
	public static Dataset tensorDotProduct(final Dataset a, final Dataset b, final int[] axisa, final int[] axisb) {
		if (axisa.length != axisb.length) {
			throw new IllegalArgumentException("Numbers of summing axes must be same");
		}
		final int[] ashape = a.getShapeRef();
		final int[] bshape = b.getShapeRef();
		final int arank = ashape.length;
		final int brank = bshape.length;
		final int[] aaxes = new int[axisa.length];
		final int[] baxes = new int[axisa.length];
		for (int i = 0; i < axisa.length; i++) {
			aaxes[i] = ShapeUtils.checkAxis(arank, axisa[i]);
			int n = ShapeUtils.checkAxis(brank, axisb[i]);
			baxes[i] = n;

			if (ashape[aaxes[i]] != bshape[n]) {
				throw new IllegalArgumentException("Summing axes do not have matching lengths");
			}
		}

		if (aaxes.length == 0) {
			// nothing to sum over: every pairwise product is kept, i.e. the outer product
			// (the position-stepping loop below needs at least one summing axis)
			return outerProduct(a, b);
		}

		final boolean[] achoice = new boolean[arank];
		final boolean[] bchoice = new boolean[brank];
		Arrays.fill(achoice, true);
		Arrays.fill(bchoice, true);
		for (int i = 0; i < aaxes.length; i++) { // flag which axes to iterate over
			achoice[aaxes[i]] = false;
			bchoice[baxes[i]] = false;
		}

		int drank = arank + brank - 2*aaxes.length;
		int[] dshape = new int[drank];
		int d = 0;
		for (int i = 0; i < arank; i++) {
			if (achoice[i])
				dshape[d++] = ashape[i];
		}
		for (int i = 0; i < brank; i++) {
			if (bchoice[i])
				dshape[d++] = bshape[i];
		}
		int dtype = DTypeUtils.getBestDType(a.getDType(), b.getDType());
		@SuppressWarnings("deprecation")
		Dataset data = DatasetFactory.zeros(dshape, dtype);

		SliceIterator ita = a.getSliceIteratorFromAxes(null, achoice);
		int l = 0;
		final int[] apos = ita.getPos();
		while (ita.hasNext()) {
			SliceIterator itb = b.getSliceIteratorFromAxes(null, bchoice);
			final int[] bpos = itb.getPos();
			while (itb.hasNext()) {
				double sum = 0.0;
				double com = 0.0;
				// start the last summing axis one before its first index so the first step lands on 0
				apos[aaxes[aaxes.length - 1]] = -1;
				bpos[baxes[aaxes.length - 1]] = -1;
				while (true) { // step through summing axes like an odometer (last axis fastest)
					int e = aaxes.length - 1;
					for (; e >= 0; e--) {
						int ai = aaxes[e];
						int bi = baxes[e];

						apos[ai]++;
						bpos[bi]++;
						if (apos[ai] == ashape[ai]) {
							apos[ai] = 0;
							bpos[bi] = 0;
						} else
							break;
					}
					if (e == -1) break; // all summing axes have rolled over (and are back at 0)
					// compensated (Kahan) summation
					final double y = a.getDouble(apos) * b.getDouble(bpos) - com;
					final double t = sum + y;
					com = (t - sum) - y;
					sum = t;
				}
				data.setObjectAbs(l++, sum);
			}
		}

		return data;
	}

	/**
	 * Calculate the dot product of two datasets. When <b>b</b> is a 1D dataset, the sum product over
	 * the last axis of <b>a</b> and <b>b</b> is returned. Where <b>a</b> is also a 1D dataset, a zero-rank dataset
	 * is returned. If <b>b</b> is 2D or higher, its second-to-last axis is used
	 * @param a
	 * @param b
	 * @return dot product
	 */
	public static Dataset dotProduct(Dataset a, Dataset b) {
		if (b.getRank() < 2)
			return tensorDotProduct(a, b, -1, 0);
		return tensorDotProduct(a, b, -1, -2);
	}

	/**
	 * Calculate the outer product of two datasets. The result's shape is the shape of <b>a</b>
	 * followed by the shape of <b>b</b>
	 * @param a
	 * @param b
	 * @return outer product
	 * @throws UnsupportedOperationException if either dataset is compound
	 */
	public static Dataset outerProduct(Dataset a, Dataset b) {
		int[] as = a.getShapeRef();
		int[] bs = b.getShapeRef();
		int rank = as.length + bs.length;
		int[] shape = new int[rank];
		for (int i = 0; i < as.length; i++) {
			shape[i] = as[i];
		}
		for (int i = 0; i < bs.length; i++) {
			shape[as.length + i] = bs[i];
		}
		int isa = a.getElementsPerItem();
		int isb = b.getElementsPerItem();
		if (isa != 1 || isb != 1) {
			throw new UnsupportedOperationException("Compound datasets not supported");
		}
		@SuppressWarnings("deprecation")
		Dataset o = DatasetFactory.zeros(shape, DTypeUtils.getBestDType(a.getDType(), b.getDType()));

		IndexIterator ita = a.getIterator();
		IndexIterator itb = b.getIterator();
		int j = 0;
		while (ita.hasNext()) {
			double va = a.getElementDoubleAbs(ita.index);
			while (itb.hasNext()) {
				o.setObjectAbs(j++, va * b.getElementDoubleAbs(itb.index));
			}
			itb.reset();
		}
		return o;
	}

	/**
	 * Calculate the cross product of two datasets. Datasets must be broadcastable and
	 * possess last dimensions of length 2 or 3
	 * @param a
	 * @param b
	 * @return cross product
	 */
	public static Dataset crossProduct(Dataset a, Dataset b) {
		return crossProduct(a, b, -1, -1, -1);
	}

	/**
	 * Calculate the cross product of two datasets. Datasets must be broadcastable and
	 * possess dimensions of length 2 or 3. The axis parameters can be negative to indicate
	 * dimensions from the end of their shapes. When both chosen dimensions have length 2, the
	 * result is the scalar (z-component) cross product and has no vector axis
	 * @param a
	 * @param b
	 * @param axisA dimension to be used as a vector (must have length of 2 or 3)
	 * @param axisB dimension to be used as a vector (must have length of 2 or 3)
	 * @param axisC dimension to assign as cross-product
	 * @return cross product
	 * @throws IllegalArgumentException if either dataset is zero-rank or a chosen dimension is not of length 2 or 3
	 */
	public static Dataset crossProduct(Dataset a, Dataset b, int axisA, int axisB, int axisC) {
		final int rankA = a.getRank();
		final int rankB = b.getRank();
		if (rankA == 0 || rankB == 0) {
			throw new IllegalArgumentException("Datasets must have one or more dimensions");
		}
		axisA = a.checkAxis(axisA);
		axisB = b.checkAxis(axisB);

		final int[] shapeA = a.getShape();
		final int[] shapeB = b.getShape();
		int la = shapeA[axisA];
		int lb = shapeB[axisB];
		if (Math.min(la, lb) < 2 || Math.max(la, lb) > 3) {
			throw new IllegalArgumentException("Chosen dimension of A & B must be 2 or 3");
		}

		if (Math.max(la, lb) == 2) {
			return crossProduct2D(a, b, axisA, axisB);
		}

		return crossProduct3D(a, b, axisA, axisB, axisC);
	}

	/**
	 * @return copy of shape with the given axis removed
	 */
	private static int[] removeAxisFromShape(int[] shape, int axis) {
		int[] s = new int[shape.length - 1];
		int i = 0;
		int j = 0;
		while (i < axis) {
			s[j++] = shape[i++];
		}
		i++;
		while (i < shape.length) {
			s[j++] = shape[i++];
		}
		return s;
	}

	/**
	 * @return copy of shape with the given axes removed
	 */
	// assume axes is in increasing order
	private static int[] removeAxesFromShape(int[] shape, int... axes) {
		int n = axes.length;
		int[] s = new int[shape.length - n];
		int i = 0;
		int j = 0;
		for (int k = 0; k < n; k++) {
			int a = axes[k];
			while (i < a) {
				s[j++] = shape[i++];
			}
			i++;
		}
		while (i < shape.length) {
			s[j++] = shape[i++];
		}
		return s;
	}

	/**
	 * @return copy of shape with a new axis of the given length inserted at the given position
	 */
	private static int[] addAxisToShape(int[] shape, int axis, int length) {
		int[] s = new int[shape.length + 1];
		int i = 0;
		int j = 0;
		while (i < axis) {
			s[j++] = shape[i++];
		}
		s[j++] = length;
		while (i < shape.length) {
			s[j++] = shape[i++];
		}
		return s;
	}

	/**
	 * Advance the iterator, rewinding it to its first position once it is exhausted. This repeats
	 * the operand with fewer positions (a stop-gap for proper broadcasting).
	 */
	private static void advanceCyclically(IndexIterator it) {
		if (!it.hasNext()) {
			it.reset();
			// reset() leaves the iterator before its first position (see the reset-then-hasNext
			// pattern in outerProduct), so step onto the first position explicitly
			it.hasNext();
		}
	}

	/**
	 * Scalar cross product a0*b1 - a1*b0 of 2-vectors taken along the given axes
	 */
	private static Dataset crossProduct2D(Dataset a, Dataset b, int axisA, int axisB) {
		// need to broadcast and omit given axes
		int[] shapeA = removeAxisFromShape(a.getShapeRef(), axisA);
		int[] shapeB = removeAxisFromShape(b.getShapeRef(), axisB);

		List<int[]> fullShapes = BroadcastUtils.broadcastShapes(shapeA, shapeB);

		int[] maxShape = fullShapes.get(0);
		@SuppressWarnings("deprecation")
		Dataset c = DatasetFactory.zeros(maxShape, DTypeUtils.getBestDType(a.getDType(), b.getDType()));

		PositionIterator ita = a.getPositionIterator(axisA);
		PositionIterator itb = b.getPositionIterator(axisB);
		IndexIterator itc = c.getIterator();

		final int[] pa = ita.getPos();
		final int[] pb = itb.getPos();
		while (itc.hasNext()) {
			advanceCyclically(ita); // TODO use broadcasting...
			advanceCyclically(itb);
			pa[axisA] = 0;
			pb[axisB] = 1;
			double cv = a.getDouble(pa) * b.getDouble(pb);
			pa[axisA] = 1;
			pb[axisB] = 0;
			cv -= a.getDouble(pa) * b.getDouble(pb);

			c.setObjectAbs(itc.index, cv);
		}
		return c;
	}

	/**
	 * Vector cross product of 3-vectors taken along the given axes; an operand whose axis has
	 * length 2 is treated as having a zero third component
	 */
	private static Dataset crossProduct3D(Dataset a, Dataset b, int axisA, int axisB, int axisC) {
		int[] shapeA = removeAxisFromShape(a.getShapeRef(), axisA);
		int[] shapeB = removeAxisFromShape(b.getShapeRef(), axisB);

		List<int[]> fullShapes = BroadcastUtils.broadcastShapes(shapeA, shapeB);

		int[] maxShape = fullShapes.get(0);
		int rankC = maxShape.length + 1;
		axisC = ShapeUtils.checkAxis(rankC, axisC);
		maxShape = addAxisToShape(maxShape, axisC, 3);
		@SuppressWarnings("deprecation")
		Dataset c = DatasetFactory.zeros(maxShape, DTypeUtils.getBestDType(a.getDType(), b.getDType()));

		PositionIterator ita = a.getPositionIterator(axisA);
		PositionIterator itb = b.getPositionIterator(axisB);
		PositionIterator itc = c.getPositionIterator(axisC);

		final int[] pa = ita.getPos();
		final int[] pb = itb.getPos();
		final int[] pc = itc.getPos();
		final int la = a.getShapeRef()[axisA];
		final int lb = b.getShapeRef()[axisB];

		if (la == 2) { // a = (a0, a1, 0)
			while (itc.hasNext()) {
				advanceCyclically(ita); // TODO use broadcasting...
				advanceCyclically(itb);
				double cv;
				pa[axisA] = 1;
				pb[axisB] = 2;
				cv = a.getDouble(pa) * b.getDouble(pb);
				pc[axisC] = 0;
				c.set(cv, pc);

				pa[axisA] = 0;
				pb[axisB] = 2;
				cv = -a.getDouble(pa) * b.getDouble(pb);
				pc[axisC] = 1;
				c.set(cv, pc);

				pa[axisA] = 0;
				pb[axisB] = 1;
				cv = a.getDouble(pa) * b.getDouble(pb);
				pa[axisA] = 1;
				pb[axisB] = 0;
				cv -= a.getDouble(pa) * b.getDouble(pb);
				pc[axisC] = 2;
				c.set(cv, pc);
			}
		} else if (lb == 2) { // b = (b0, b1, 0)
			while (itc.hasNext()) {
				advanceCyclically(ita); // TODO use broadcasting...
				advanceCyclically(itb);
				double cv;
				pa[axisA] = 2;
				pb[axisB] = 1;
				cv = -a.getDouble(pa) * b.getDouble(pb);
				pc[axisC] = 0;
				c.set(cv, pc);

				pa[axisA] = 2;
				pb[axisB] = 0;
				cv = a.getDouble(pa) * b.getDouble(pb);
				pc[axisC] = 1;
				c.set(cv, pc);

				pa[axisA] = 0;
				pb[axisB] = 1;
				cv = a.getDouble(pa) * b.getDouble(pb);
				pa[axisA] = 1;
				pb[axisB] = 0;
				cv -= a.getDouble(pa) * b.getDouble(pb);
				pc[axisC] = 2;
				c.set(cv, pc);
			}

		} else {
			while (itc.hasNext()) {
				advanceCyclically(ita); // TODO use broadcasting...
				advanceCyclically(itb);
				double cv;
				pa[axisA] = 1;
				pb[axisB] = 2;
				cv = a.getDouble(pa) * b.getDouble(pb);
				pa[axisA] = 2;
				pb[axisB] = 1;
				cv -= a.getDouble(pa) * b.getDouble(pb);
				pc[axisC] = 0;
				c.set(cv, pc);

				pa[axisA] = 2;
				pb[axisB] = 0;
				cv = a.getDouble(pa) * b.getDouble(pb);
				pa[axisA] = 0;
				pb[axisB] = 2;
				cv -= a.getDouble(pa) * b.getDouble(pb);
				pc[axisC] = 1;
				c.set(cv, pc);

				pa[axisA] = 0;
				pb[axisB] = 1;
				cv = a.getDouble(pa) * b.getDouble(pb);
				pa[axisA] = 1;
				pb[axisB] = 0;
				cv -= a.getDouble(pa) * b.getDouble(pb);
				pc[axisC] = 2;
				c.set(cv, pc);
			}
		}
		return c;
	}

	/**
	 * Raise dataset to given power by matrix multiplication. A negative power raises the
	 * (LU-computed) inverse to the corresponding positive power. For a non-negative power, a
	 * dataset without floating point elements is cast back to its original type
	 * @param a
	 * @param n power
	 * @return a ** n
	 */
	public static Dataset power(Dataset a, int n) {
		if (n < 0) {
			LUDecomposition lud = new LUDecomposition(createRealMatrix(a));
			return createDataset(lud.getSolver().getInverse().power(-n));
		}
		Dataset p = createDataset(createRealMatrix(a).power(n));
		if (!a.hasFloatingPointElements())
			return p.cast(a.getDType());
		return p;
	}

	/**
	 * Create the Kronecker product as defined by
	 * kron[k0,...,kN] = a[i0,...,iN] * b[j0,...,jN]
	 * where kn = sn * in + jn for n = 0...N and s is shape of b.
	 * If the ranks differ, the shape of the lower-rank dataset is pre-padded with ones
	 * @param a
	 * @param b
	 * @return Kronecker product of a and b
	 * @throws UnsupportedOperationException if either dataset is compound
	 */
	public static Dataset kroneckerProduct(Dataset a, Dataset b) {
		if (a.getElementsPerItem() != 1 || b.getElementsPerItem() != 1) {
			throw new UnsupportedOperationException("Compound datasets (including complex ones) are not currently supported");
		}
		int ar = a.getRank();
		int br = b.getRank();
		int[] aShape;
		int[] bShape;
		aShape = a.getShapeRef();
		bShape = b.getShapeRef();
		int r = ar;
		// pre-pad if ranks are not same
		if (ar < br) {
			r = br;
			int[] shape = new int[br];
			int j = 0;
			for (int i = ar; i < br; i++) {
				shape[j++] = 1;
			}
			int i = 0;
			while (j < br) {
				shape[j++] = aShape[i++];
			}
			a = a.reshape(shape);
			aShape = shape;
		} else if (ar > br) {
			int[] shape = new int[ar];
			int j = 0;
			for (int i = br; i < ar; i++) {
				shape[j++] = 1;
			}
			int i = 0;
			while (j < ar) {
				shape[j++] = bShape[i++];
			}
			b = b.reshape(shape);
			bShape = shape;
		}

		int[] nShape = new int[r];
		for (int i = 0; i < r; i++) {
			nShape[i] = aShape[i] * bShape[i];
		}
		@SuppressWarnings("deprecation")
		Dataset kron = DatasetFactory.zeros(nShape, DTypeUtils.getBestDType(a.getDType(), b.getDType()));
		IndexIterator ita = a.getIterator(true);
		IndexIterator itb = b.getIterator(true);
		int[] pa = ita.getPos();
		int[] pb = itb.getPos();
		int[] off = new int[1];
		int[] stride = AbstractDataset.createStrides(1, nShape, null, 0, off);
		if (kron.getDType() == Dataset.INT64) {
			while (ita.hasNext()) {
				long av = a.getElementLongAbs(ita.index);

				int ka = 0;
				for (int i = 0; i < r; i++) {
					ka += stride[i] * bShape[i] * pa[i];
				}
				itb.reset();
				while (itb.hasNext()) {
					long bv = b.getElementLongAbs(itb.index);
					int kb = ka;
					for (int i = 0; i < r; i++) {
						kb += stride[i] * pb[i];
					}
					kron.setObjectAbs(kb, av * bv);
				}
			}
		} else {
			while (ita.hasNext()) {
				double av = a.getElementDoubleAbs(ita.index);

				int ka = 0;
				for (int i = 0; i < r; i++) {
					ka += stride[i] * bShape[i] * pa[i];
				}
				itb.reset();
				while (itb.hasNext()) {
					// read b as double: reading it as long would truncate fractional values
					double bv = b.getElementDoubleAbs(itb.index);
					int kb = ka;
					for (int i = 0; i < r; i++) {
						kb += stride[i] * pb[i];
					}
					kron.setObjectAbs(kb, av * bv);
				}
			}
		}

		return kron;
	}

	/**
	 * Calculate trace of dataset - sum of values over 1st axis and 2nd axis
	 * @param a
	 * @return trace of dataset
	 */
	public static Dataset trace(Dataset a) {
		return trace(a, 0, 0, 1);
	}

	/**
	 * Calculate trace of dataset - sum of values over axis1 and axis2 where axis2 is offset.
	 * The two axes are sorted first; a non-negative offset shifts the start index along the
	 * higher-numbered axis and a negative offset shifts the start along the lower-numbered axis.
	 * The result has the shape of the dataset with both axes removed
	 * @param a
	 * @param offset diagonal offset
	 * @param axis1
	 * @param axis2
	 * @return trace of dataset
	 */
	public static Dataset trace(Dataset a, int offset, int axis1, int axis2) {
		int[] shape = a.getShapeRef();
		int[] axes = new int[] { a.checkAxis(axis1), a.checkAxis(axis2) };
		Arrays.sort(axes);
		int is = a.getElementsPerItem();
		@SuppressWarnings("deprecation")
		Dataset trace = DatasetFactory.zeros(is, removeAxesFromShape(shape, axes), a.getDType());

		int am = axes[0];
		int mmax = shape[am];
		int an = axes[1];
		int nmax = shape[an];
		PositionIterator it = new PositionIterator(shape, axes);
		int[] pos = it.getPos();
		int i = 0;
		int mmin;
		int nmin;
		if (offset >= 0) {
			mmin = 0;
			nmin = offset;
		} else {
			mmin = -offset;
			nmin = 0;
		}
		if (is == 1) {
			if (a.getDType() == Dataset.INT64) {
				while (it.hasNext()) {
					int m = mmin;
					int n = nmin;
					long s = 0;
					while (m < mmax && n < nmax) {
						pos[am] = m++;
						pos[an] = n++;
						s += a.getLong(pos);
					}
					trace.setObjectAbs(i++, s);
				}
			} else {
				while (it.hasNext()) {
					int m = mmin;
					int n = nmin;
					double s = 0;
					while (m < mmax && n < nmax) {
						pos[am] = m++;
						pos[an] = n++;
						s += a.getDouble(pos);
					}
					trace.setObjectAbs(i++, s);
				}
			}
		} else {
			AbstractCompoundDataset ca = (AbstractCompoundDataset) a;
			if (ca instanceof CompoundLongDataset) {
				long[] t = new long[is];
				long[] s = new long[is];
				while (it.hasNext()) {
					int m = mmin;
					int n = nmin;
					Arrays.fill(s, 0);
					while (m < mmax && n < nmax) {
						pos[am] = m++;
						pos[an] = n++;
						((CompoundLongDataset)ca).getAbs(ca.get1DIndex(pos), t);
						for (int k = 0; k < is; k++) {
							s[k] += t[k];
						}
					}
					trace.setObjectAbs(i++, s);
				}
			} else {
				double[] t = new double[is];
				double[] s = new double[is];
				while (it.hasNext()) {
					int m = mmin;
					int n = nmin;
					Arrays.fill(s, 0);
					while (m < mmax && n < nmax) {
						pos[am] = m++;
						pos[an] = n++;
						ca.getDoubleArray(t, pos);
						for (int k = 0; k < is; k++) {
							s[k] += t[k];
						}
					}
					trace.setObjectAbs(i++, s);
				}
			}
		}

		return trace;
	}

	/**
	 * Order value for norm
	 */
	public enum NormOrder {
		/**
		 * 2-norm for vectors and Frobenius for matrices
		 */
		DEFAULT,
		/**
		 * Frobenius (not allowed for vectors)
		 */
		FROBENIUS,
		/**
		 * Zero-order (not allowed for matrices)
		 */
		ZERO,
		/**
		 * Positive infinity
		 */
		POS_INFINITY,
		/**
		 * Negative infinity
		 */
		NEG_INFINITY;
	}

	/**
	 * Calculate the default norm: 2-norm for a vector, Frobenius norm for a matrix
	 * @param a rank-1 or rank-2 dataset
	 * @return norm of dataset
	 */
	public static double norm(Dataset a) {
		return norm(a, NormOrder.DEFAULT);
	}

	/**
	 * Calculate a norm of the given order. For vectors: infinity orders give the maximum/minimum
	 * absolute value and zero order counts non-zero elements. For matrices: infinity orders give
	 * the maximum/minimum absolute row sum
	 * @param a rank-1 or rank-2 dataset
	 * @param order
	 * @return norm of dataset
	 * @throws IllegalArgumentException if the rank is not one or two, or the order is not allowed for that rank
	 */
	public static double norm(Dataset a, NormOrder order) {
		int r = a.getRank();
		if (r == 1) {
			return vectorNorm(a, order);
		} else if (r == 2) {
			return matrixNorm(a, order);
		}
		throw new IllegalArgumentException("Rank of dataset must be one or two");
	}

	private static double vectorNorm(Dataset a, NormOrder order) {
		double n;
		IndexIterator it;
		switch (order) {
		case FROBENIUS:
			throw new IllegalArgumentException("Not allowed for vectors");
		case NEG_INFINITY:
		case POS_INFINITY:
			it = a.getIterator();
			if (order == NormOrder.POS_INFINITY) {
				n = Double.NEGATIVE_INFINITY;
				if (a.isComplex()) {
					while (it.hasNext()) {
						double v = ((Complex) a.getObjectAbs(it.index)).abs();
						n = Math.max(n, v);
					}
				} else {
					while (it.hasNext()) {
						double v = Math.abs(a.getElementDoubleAbs(it.index));
						n = Math.max(n, v);
					}
				}
			} else {
				n = Double.POSITIVE_INFINITY;
				if (a.isComplex()) {
					while (it.hasNext()) {
						double v = ((Complex) a.getObjectAbs(it.index)).abs();
						n = Math.min(n, v);
					}
				} else {
					while (it.hasNext()) {
						double v = Math.abs(a.getElementDoubleAbs(it.index));
						n = Math.min(n, v);
					}
				}
			}
			break;
		case ZERO:
			it = a.getIterator();
			n = 0;
			if (a.isComplex()) {
				while (it.hasNext()) {
					if (!((Complex) a.getObjectAbs(it.index)).equals(Complex.ZERO))
						n++;
				}
			} else {
				while (it.hasNext()) {
					if (a.getElementBooleanAbs(it.index))
						n++;
				}
			}

			break;
		default:
			n = vectorNorm(a, 2);
			break;
		}
		return n;
	}

	private static double matrixNorm(Dataset a, NormOrder order) {
		double n;
		IndexIterator it;
		switch (order) {
		case NEG_INFINITY:
		case POS_INFINITY:
			n = maxMinMatrixNorm(a, 1, order == NormOrder.POS_INFINITY);
			break;
		case ZERO:
			throw new IllegalArgumentException("Not allowed for matrices");
		default:
		case FROBENIUS:
			it = a.getIterator();
			n = 0;
			if (a.isComplex()) {
				while (it.hasNext()) {
					double v = ((Complex) a.getObjectAbs(it.index)).abs();
					n += v*v;
				}
			} else {
				while (it.hasNext()) {
					double v = a.getElementDoubleAbs(it.index);
					n += v*v;
				}
			}
			n = Math.sqrt(n);
			break;
		}
		return n;
	}

	/**
	 * Calculate the p-norm. A zero p is the zero-order norm and an infinite p is the
	 * corresponding infinity-order norm (see {@link #norm(Dataset, NormOrder)}). For matrices,
	 * only p = +/-1 (max/min absolute column sum) and p = +/-2 (largest/smallest singular value)
	 * are otherwise allowed
	 * @param a rank-1 or rank-2 dataset
	 * @param p
	 * @return p-norm of dataset
	 * @throws IllegalArgumentException if the rank is not one or two, or the order is not allowed for that rank
	 */
	public static double norm(Dataset a, final double p) {
		if (p == 0) {
			return norm(a, NormOrder.ZERO);
		}
		if (Double.isInfinite(p)) {
			// the generic power formula degenerates for infinite p (pow(n, 1/p) is always 1)
			return norm(a, p > 0 ? NormOrder.POS_INFINITY : NormOrder.NEG_INFINITY);
		}
		int r = a.getRank();
		if (r == 1) {
			return vectorNorm(a, p);
		} else if (r == 2) {
			return matrixNorm(a, p);
		}
		throw new IllegalArgumentException("Rank of dataset must be one or two");
	}

	private static double vectorNorm(Dataset a, final double p) {
		IndexIterator it = a.getIterator();
		double n = 0;
		if (a.isComplex()) {
			while (it.hasNext()) {
				double v = ((Complex) a.getObjectAbs(it.index)).abs();
				if (p == 2) {
					v *= v;
				} else if (p != 1) {
					v = Math.pow(v, p);
				}
				n += v;
			}
		} else {
			while (it.hasNext()) {
				double v = a.getElementDoubleAbs(it.index);
				if (p == 1) {
					v = Math.abs(v);
				} else if (p == 2) {
					v *= v;
				} else {
					v = Math.pow(Math.abs(v), p);
				}
				n += v;
			}
		}
		return Math.pow(n, 1./p);
	}

	private static double matrixNorm(Dataset a, final double p) {
		double n;
		if (Math.abs(p) == 1) {
			n = maxMinMatrixNorm(a, 0, p > 0);
		} else if (Math.abs(p) == 2) {
			double[] s = calcSingularValues(a);
			n = p > 0 ? s[0] : s[s.length - 1];
		} else {
			throw new IllegalArgumentException("Order not allowed");
		}

		return n;
	}

	/**
	 * Sum absolute values along axis d for every position of the other axis, then return the
	 * maximum (or minimum) of those sums
	 */
	private static double maxMinMatrixNorm(Dataset a, int d, boolean max) {
		double n;
		IndexIterator it;
		int[] pos;
		int l;
		it = a.getPositionIterator(d);
		pos = it.getPos();
		l = a.getShapeRef()[d];
		if (max) {
			n = Double.NEGATIVE_INFINITY;
			if (a.isComplex()) {
				while (it.hasNext()) {
					double v = ((Complex) a.getObject(pos)).abs();
					for (int i = 1; i < l; i++) {
						pos[d] = i;
						v += ((Complex) a.getObject(pos)).abs();
					}
					pos[d] = 0;
					n = Math.max(n, v);
				}
			} else {
				while (it.hasNext()) {
					double v = Math.abs(a.getDouble(pos));
					for (int i = 1; i < l; i++) {
						pos[d] = i;
						v += Math.abs(a.getDouble(pos));
					}
					pos[d] = 0;
					n = Math.max(n, v);
				}
			}
		} else {
			n = Double.POSITIVE_INFINITY;
			if (a.isComplex()) {
				while (it.hasNext()) {
					double v = ((Complex) a.getObject(pos)).abs();
					for (int i = 1; i < l; i++) {
						pos[d] = i;
						v += ((Complex) a.getObject(pos)).abs();
					}
					pos[d] = 0;
					n = Math.min(n, v);
				}
			} else {
				while (it.hasNext()) {
					double v = Math.abs(a.getDouble(pos));
					for (int i = 1; i < l; i++) {
						pos[d] = i;
						v += Math.abs(a.getDouble(pos));
					}
					pos[d] = 0;
					n = Math.min(n, v);
				}
			}
		}
		return n;
	}

	/**
	 * @param a
	 * @return array of singular values
	 */
	public static double[] calcSingularValues(Dataset a) {
		SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
		return svd.getSingularValues();
	}


	/**
	 * Calculate singular value decomposition A = U S V^T
	 * @param a
	 * @return array of U - orthogonal matrix, s - singular values vector, V - orthogonal matrix
	 */
	public static Dataset[] calcSingularValueDecomposition(Dataset a) {
		SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
		return new Dataset[] {createDataset(svd.getU()), DatasetFactory.createFromObject(svd.getSingularValues()),
				createDataset(svd.getV())};
	}

	/**
	 * Calculate (Moore-Penrose) pseudo-inverse
	 * @param a
	 * @return pseudo-inverse
	 */
	public static Dataset calcPseudoInverse(Dataset a) {
		SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
		return createDataset(svd.getSolver().getInverse());
	}

	/**
	 * Calculate matrix rank by singular value decomposition method
	 * @param a
	 * @return effective numerical rank of matrix
	 */
	public static int calcMatrixRank(Dataset a) {
		SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
		return svd.getRank();
	}

	/**
	 * Calculate condition number of matrix by singular value decomposition method
	 * @param a
	 * @return condition number
	 */
	public static double calcConditionNumber(Dataset a) {
		SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
		return svd.getConditionNumber();
	}

	/**
	 * Calculate the determinant of a square matrix by LU decomposition
	 * @param a
	 * @return determinant of dataset
	 */
	public static double calcDeterminant(Dataset a) {
		// EigenDecomposition.getDeterminant() multiplies only the real parts of the eigenvalues,
		// which is wrong when some are complex (e.g. a rotation matrix); LU handles any square matrix
		LUDecomposition lud = new LUDecomposition(createRealMatrix(a));
		return lud.getDeterminant();
	}

	/**
	 * @param a
	 * @return dataset of eigenvalues (can be double or complex double)
	 */
	public static Dataset calcEigenvalues(Dataset a) {
		EigenDecomposition evd = new EigenDecomposition(createRealMatrix(a));
		double[] rev = evd.getRealEigenvalues();

		if (evd.hasComplexEigenvalues()) {
			double[] iev = evd.getImagEigenvalues();
			return DatasetFactory.createComplexDataset(ComplexDoubleDataset.class, rev, iev);
		}
		return DatasetFactory.createFromObject(rev);
	}

	/**
	 * Calculate eigen-decomposition A = V D V^-1 (which is A = V D V^T when A is symmetric)
	 * @param a
	 * @return array of D eigenvalues (can be double or complex double) and V eigenvectors
	 */
	public static Dataset[] calcEigenDecomposition(Dataset a) {
		EigenDecomposition evd = new EigenDecomposition(createRealMatrix(a));
		Dataset[] results = new Dataset[2];

		double[] rev = evd.getRealEigenvalues();
		if (evd.hasComplexEigenvalues()) {
			double[] iev = evd.getImagEigenvalues();
			results[0] = DatasetFactory.createComplexDataset(ComplexDoubleDataset.class, rev, iev);
		} else {
			results[0] = DatasetFactory.createFromObject(rev);
		}
		results[1] = createDataset(evd.getV());
		return results;
	}

	/**
	 * Calculate QR decomposition A = Q R
	 * @param a
	 * @return array of Q and R
	 */
	public static Dataset[] calcQRDecomposition(Dataset a) {
		QRDecomposition qrd = new QRDecomposition(createRealMatrix(a));
		return new Dataset[] {createDataset(qrd.getQT()).getTransposedView(), createDataset(qrd.getR())};
	}

	/**
	 * Calculate LU decomposition A = P^-1 L U
	 * @param a
	 * @return array of L, U and P
	 */
	public static Dataset[] calcLUDecomposition(Dataset a) {
		LUDecomposition lud = new LUDecomposition(createRealMatrix(a));
		return new Dataset[] {createDataset(lud.getL()), createDataset(lud.getU()),
				createDataset(lud.getP())};
	}

	/**
	 * Calculate inverse of square dataset
	 * @param a
	 * @return inverse
	 */
	public static Dataset calcInverse(Dataset a) {
		LUDecomposition lud = new LUDecomposition(createRealMatrix(a));
		return createDataset(lud.getSolver().getInverse());
	}

	/**
	 * Solve linear matrix equation A x = v
	 * @param a
	 * @param v rank-1 (vector) or rank-2 (matrix) right-hand side
	 * @return x
	 */
	public static Dataset solve(Dataset a, Dataset v) {
		LUDecomposition lud = new LUDecomposition(createRealMatrix(a));
		if (v.getRank() == 1) {
			RealVector x = createRealVector(v);
			return createDataset(lud.getSolver().solve(x));
		}
		RealMatrix x = createRealMatrix(v);
		return createDataset(lud.getSolver().solve(x));
	}


	/**
	 * Solve least squares matrix equation A x = v by SVD
	 * @param a
	 * @param v rank-1 (vector) or rank-2 (matrix) right-hand side
	 * @return x
	 */
	public static Dataset solveSVD(Dataset a, Dataset v) {
		SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
		if (v.getRank() == 1) {
			RealVector x = createRealVector(v);
			return createDataset(svd.getSolver().solve(x));
		}
		RealMatrix x = createRealMatrix(v);
		return createDataset(svd.getSolver().solve(x));
	}

	/**
	 * Calculate Cholesky decomposition A = L L^T
	 * @param a
	 * @return L
	 */
	public static Dataset calcCholeskyDecomposition(Dataset a) {
		CholeskyDecomposition cd = new CholeskyDecomposition(createRealMatrix(a));
		return createDataset(cd.getL());
	}

	/**
	 * Calculation A x = v by conjugate gradient method with the stopping criterion being
	 * that the estimated residual r = v - A x satisfies ||r|| < ||v|| with maximum of 100 iterations
	 * @param a
	 * @param v
	 * @return solution of A^-1 v by conjugate gradient method
	 */
	public static Dataset calcConjugateGradient(Dataset a, Dataset v) {
		return calcConjugateGradient(a, v, 100, 1);
	}

	/**
	 * Calculation A x = v by conjugate gradient method with the stopping criterion being
	 * that the estimated residual r = v - A x satisfies ||r|| < delta ||v||
	 * @param a
	 * @param v
	 * @param maxIterations
	 * @param delta parameter used by stopping criterion
	 * @return solution of A^-1 v by conjugate gradient method
	 */
	public static Dataset calcConjugateGradient(Dataset a, Dataset v, int maxIterations, double delta) {
		ConjugateGradient cg = new ConjugateGradient(maxIterations, delta, false);
		return createDataset(cg.solve((RealLinearOperator) createRealMatrix(a), createRealVector(v)));
	}

	/**
	 * Copy a rank-2 dataset into a new Commons Math matrix of doubles
	 * @throws IllegalArgumentException if the dataset is not rank 2
	 */
	private static RealMatrix createRealMatrix(Dataset a) {
		if (a.getRank() != 2) {
			throw new IllegalArgumentException("Dataset must be rank 2");
		}
		int[] shape = a.getShapeRef();
		IndexIterator it = a.getIterator(true);
		int[] pos = it.getPos();
		RealMatrix m = MatrixUtils.createRealMatrix(shape[0], shape[1]);
		while (it.hasNext()) {
			m.setEntry(pos[0], pos[1], a.getElementDoubleAbs(it.index));
		}
		return m;
	}

	/**
	 * Copy a rank-1 dataset into a new Commons Math vector of doubles
	 * @throws IllegalArgumentException if the dataset is not rank 1
	 */
	private static RealVector createRealVector(Dataset a) {
		if (a.getRank() != 1) {
			throw new IllegalArgumentException("Dataset must be rank 1");
		}
		int size = a.getSize();
		IndexIterator it = a.getIterator(true);
		int[] pos = it.getPos();
		RealVector m = new ArrayRealVector(size);
		while (it.hasNext()) {
			m.setEntry(pos[0], a.getElementDoubleAbs(it.index));
		}
		return m;
	}

	/**
	 * Copy a Commons Math vector into a new 1D double dataset
	 */
	private static Dataset createDataset(RealVector v) {
		DoubleDataset r = DatasetFactory.zeros(DoubleDataset.class, v.getDimension());
		int size = r.getSize();
		if (v instanceof ArrayRealVector) {
			double[] data = ((ArrayRealVector) v).getDataRef();
			for (int i = 0; i < size; i++) {
				r.setAbs(i, data[i]);
			}
		} else {
			for (int i = 0; i < size; i++) {
				r.setAbs(i, v.getEntry(i));
			}
		}
		return r;
	}

	/**
	 * Copy a Commons Math matrix into a new 2D double dataset
	 */
	private static Dataset createDataset(RealMatrix m) {
		DoubleDataset r = DatasetFactory.zeros(DoubleDataset.class, m.getRowDimension(), m.getColumnDimension());
		if (m instanceof Array2DRowRealMatrix) {
			double[][] data = ((Array2DRowRealMatrix) m).getDataRef();
			IndexIterator it = r.getIterator(true);
			int[] pos = it.getPos();
			while (it.hasNext()) {
				r.setAbs(it.index, data[pos[0]][pos[1]]);
			}
		} else {
			IndexIterator it = r.getIterator(true);
			int[] pos = it.getPos();
			while (it.hasNext()) {
				r.setAbs(it.index, m.getEntry(pos[0], pos[1]));
			}
		}
		return r;
	}
}