001/******************************************************************************* 002 * Copyright (c) 2016 Diamond Light Source Ltd. and others. 003 * All rights reserved. This program and the accompanying materials 004 * are made available under the terms of the Eclipse Public License v1.0 005 * which accompanies this distribution, and is available at 006 * http://www.eclipse.org/legal/epl-v10.html 007 * 008 * Contributors: 009 * Diamond Light Source Ltd - initial API and implementation 010 *******************************************************************************/ 011package org.eclipse.january.dataset; 012 013import java.lang.reflect.Array; 014import java.util.ArrayList; 015import java.util.Collection; 016import java.util.List; 017import java.util.SortedSet; 018import java.util.TreeSet; 019 020public class ShapeUtils { 021 022 private ShapeUtils() { 023 } 024 025 /** 026 * Calculate total number of items in given shape 027 * @param shape 028 * @return size 029 */ 030 public static long calcLongSize(final int[] shape) { 031 if (shape == null) { // special case of null-shaped 032 return 0; 033 } 034 035 final int rank = shape.length; 036 if (rank == 0) { // special case of zero-rank shape 037 return 1; 038 } 039 040 double dsize = 1.0; 041 for (int i = 0; i < rank; i++) { 042 // make sure the indexes isn't zero or negative 043 if (shape[i] == 0) { 044 return 0; 045 } else if (shape[i] < 0) { 046 throw new IllegalArgumentException(String.format( 047 "The %d-th is %d which is not allowed as it is negative", i, shape[i])); 048 } 049 050 dsize *= shape[i]; 051 } 052 053 // check to see if the size is larger than an integer, i.e. we can't allocate it 054 if (dsize > Long.MAX_VALUE) { 055 throw new IllegalArgumentException("Size of the dataset is too large to allocate"); 056 } 057 return (long) dsize; 058 } 059 060 /** 061 * Calculate total number of items in given shape 062 * @param shape 063 * @return size 064 */ 065 public static int calcSize(final int[] shape) { 066 long lsize = calcLongSize(shape); 067 068 // check to see if the size is larger than an integer, i.e. we can't allocate it 069 if (lsize > Integer.MAX_VALUE) { 070 throw new IllegalArgumentException("Size of the dataset is too large to allocate"); 071 } 072 return (int) lsize; 073 } 074 075 /** 076 * Check if shapes are broadcast compatible 077 * 078 * @param ashape 079 * @param bshape 080 * @return true if they are compatible 081 */ 082 public static boolean areShapesBroadcastCompatible(final int[] ashape, final int[] bshape) { 083 if (ashape == null || bshape == null) { 084 return ashape == bshape; 085 } 086 087 if (ashape.length < bshape.length) { 088 return areShapesBroadcastCompatible(bshape, ashape); 089 } 090 091 for (int a = ashape.length - bshape.length, b = 0; a < ashape.length && b < bshape.length; a++, b++) { 092 if (ashape[a] != bshape[b] && ashape[a] != 1 && bshape[b] != 1) { 093 return false; 094 } 095 } 096 097 return true; 098 } 099 100 /** 101 * Check if shapes are compatible, ignoring extra axes of length 1 102 * 103 * @param ashape 104 * @param bshape 105 * @return true if they are compatible 106 */ 107 public static boolean areShapesCompatible(final int[] ashape, final int[] bshape) { 108 if (ashape == null || bshape == null) { 109 return ashape == bshape; 110 } 111 112 List<Integer> alist = new ArrayList<Integer>(); 113 114 for (int a : ashape) { 115 if (a > 1) alist.add(a); 116 } 117 118 final int imax = alist.size(); 119 int i = 0; 120 for (int b : bshape) { 121 if (b == 1) 122 continue; 123 if (i >= imax || b != alist.get(i++)) 124 return false; 125 } 126 127 return i == imax; 128 } 129 130 /** 131 * Check if shapes are compatible but skip axis 132 * 133 * @param ashape 134 * @param bshape 135 * @param axis 136 * @return true if they are compatible 137 */ 138 public static boolean areShapesCompatible(final int[] ashape, final int[] bshape, final int axis) { 139 if (ashape == null || bshape == null) { 140 return ashape == bshape; 141 } 142 143 if (ashape.length != bshape.length) { 144 return false; 145 } 146 147 final int rank = ashape.length; 148 for (int i = 0; i < rank; i++) { 149 if (i != axis && ashape[i] != bshape[i]) { 150 return false; 151 } 152 } 153 return true; 154 } 155 156 /** 157 * Remove dimensions of 1 in given shape - from both ends only, if true 158 * 159 * @param oshape 160 * @param onlyFromEnds 161 * @return newly squeezed shape (or original if unsqueezed) 162 */ 163 public static int[] squeezeShape(final int[] oshape, boolean onlyFromEnds) { 164 int unitDims = 0; 165 int rank = oshape.length; 166 int start = 0; 167 168 if (onlyFromEnds) { 169 int i = rank - 1; 170 for (; i >= 0; i--) { 171 if (oshape[i] == 1) { 172 unitDims++; 173 } else { 174 break; 175 } 176 } 177 for (int j = 0; j <= i; j++) { 178 if (oshape[j] == 1) { 179 unitDims++; 180 } else { 181 start = j; 182 break; 183 } 184 } 185 } else { 186 for (int i = 0; i < rank; i++) { 187 if (oshape[i] == 1) { 188 unitDims++; 189 } 190 } 191 } 192 193 if (unitDims == 0) { 194 return oshape; 195 } 196 197 int[] newDims = new int[rank - unitDims]; 198 if (unitDims == rank) 199 return newDims; // zero-rank dataset 200 201 if (onlyFromEnds) { 202 rank = newDims.length; 203 for (int i = 0; i < rank; i++) { 204 newDims[i] = oshape[i+start]; 205 } 206 } else { 207 int j = 0; 208 for (int i = 0; i < rank; i++) { 209 if (oshape[i] > 1) { 210 newDims[j++] = oshape[i]; 211 if (j >= newDims.length) 212 break; 213 } 214 } 215 } 216 217 return newDims; 218 } 219 220 /** 221 * Remove dimension of 1 in given shape 222 * 223 * @param oshape 224 * @param axis 225 * @return newly squeezed shape 226 */ 227 public static int[] squeezeShape(final int[] oshape, int axis) { 228 if (oshape == null) { 229 return null; 230 } 231 232 final int rank = oshape.length; 233 if (rank == 0) { 234 return new int[0]; 235 } 236 if (axis < 0) { 237 axis += rank; 238 } 239 if (axis < 0 || axis >= rank) { 240 throw new IllegalArgumentException("Axis argument is outside allowed range"); 241 } 242 int[] nshape = new int[rank-1]; 243 for (int i = 0; i < axis; i++) { 244 nshape[i] = oshape[i]; 245 } 246 for (int i = axis+1; i < rank; i++) { 247 nshape[i-1] = oshape[i]; 248 } 249 return nshape; 250 } 251 252 /** 253 * Get shape from object (array or list supported) 254 * @param obj 255 * @return shape can be null if obj is null 256 */ 257 public static int[] getShapeFromObject(final Object obj) { 258 if (obj == null) { 259 return null; 260 } 261 262 ArrayList<Integer> lshape = new ArrayList<Integer>(); 263 getShapeFromObj(lshape, obj, 0); 264 265 final int rank = lshape.size(); 266 final int[] shape = new int[rank]; 267 for (int i = 0; i < rank; i++) { 268 shape[i] = lshape.get(i); 269 } 270 271 return shape; 272 } 273 274 /** 275 * Get shape from object 276 * @param ldims 277 * @param obj 278 * @param depth 279 * @return true if there is a possibility of differing lengths 280 */ 281 private static boolean getShapeFromObj(final ArrayList<Integer> ldims, Object obj, int depth) { 282 if (obj == null) 283 return true; 284 285 if (obj instanceof List<?>) { 286 List<?> jl = (List<?>) obj; 287 int l = jl.size(); 288 updateShape(ldims, depth, l); 289 for (int i = 0; i < l; i++) { 290 Object lo = jl.get(i); 291 if (!getShapeFromObj(ldims, lo, depth + 1)) { 292 break; 293 } 294 } 295 return true; 296 } 297 Class<? extends Object> ca = obj.getClass().getComponentType(); 298 if (ca != null) { 299 final int l = Array.getLength(obj); 300 updateShape(ldims, depth, l); 301 if (DTypeUtils.isClassSupportedAsElement(ca)) { 302 return true; 303 } 304 for (int i = 0; i < l; i++) { 305 Object lo = Array.get(obj, i); 306 if (!getShapeFromObj(ldims, lo, depth + 1)) { 307 break; 308 } 309 } 310 return true; 311 } else if (obj instanceof IDataset) { 312 int[] s = ((IDataset) obj).getShape(); 313 for (int i = 0; i < s.length; i++) { 314 updateShape(ldims, depth++, s[i]); 315 } 316 return true; 317 } else { 318 return false; // not an array of any type 319 } 320 } 321 322 private static void updateShape(final ArrayList<Integer> ldims, final int depth, final int l) { 323 if (depth >= ldims.size()) { 324 ldims.add(l); 325 } else if (l > ldims.get(depth)) { 326 ldims.set(depth, l); 327 } 328 } 329 330 /** 331 * Get n-D position from given index 332 * @param n index 333 * @param shape 334 * @return n-D position 335 */ 336 public static int[] getNDPositionFromShape(int n, int[] shape) { 337 if (shape == null) { 338 return null; 339 } 340 341 int rank = shape.length; 342 if (rank == 0) { 343 return new int[0]; 344 } 345 346 if (rank == 1) { 347 return new int[] { n }; 348 } 349 350 int[] output = new int[rank]; 351 for (rank--; rank > 0; rank--) { 352 output[rank] = n % shape[rank]; 353 n /= shape[rank]; 354 } 355 output[0] = n; 356 357 return output; 358 } 359 360 /** 361 * Get flattened view index of given position 362 * @param shape 363 * @param pos 364 * the integer array specifying the n-D position 365 * @return the index on the flattened dataset 366 */ 367 public static int getFlat1DIndex(final int[] shape, final int[] pos) { 368 final int imax = pos.length; 369 if (imax == 0) { 370 return 0; 371 } 372 373 return AbstractDataset.get1DIndexFromShape(shape, pos); 374 } 375 376 /** 377 * This function takes a dataset and checks its shape against another dataset. If they are both of the same size, 378 * then this returns with no error, if there is a problem, then an error is thrown. 379 * 380 * @param g 381 * The first dataset to be compared 382 * @param h 383 * The second dataset to be compared 384 * @throws IllegalArgumentException 385 * This will be thrown if there is a problem with the compatibility 386 */ 387 public static void checkCompatibility(final ILazyDataset g, final ILazyDataset h) throws IllegalArgumentException { 388 if (!areShapesCompatible(g.getShape(), h.getShape())) { 389 throw new IllegalArgumentException("Shapes do not match"); 390 } 391 } 392 393 /** 394 * Check that axis is in range [-rank,rank) 395 * 396 * @param rank 397 * @param axis 398 * @return sanitized axis in range [0, rank) 399 * @since 2.1 400 */ 401 public static int checkAxis(int rank, int axis) { 402 if (axis < 0) { 403 axis += rank; 404 } 405 406 if (axis < 0 || axis >= rank) { 407 throw new IllegalArgumentException("Axis " + axis + " given is out of range [0, " + rank + ")"); 408 } 409 return axis; 410 } 411 412 private static int[] convert(Collection<Integer> list) { 413 int[] array = new int[list.size()]; 414 int i = 0; 415 for (Integer l : list) { 416 array[i++] = l; 417 } 418 return array; 419 } 420 421 /** 422 * Check that all axes are in range [-rank,rank) 423 * @param rank 424 * @param axes 425 * @return sanitized axes in range [0, rank) and sorted in increasing order 426 * @since 2.2 427 */ 428 public static int[] checkAxes(int rank, int... axes) { 429 return convert(sanitizeAxes(rank, axes)); 430 } 431 432 /** 433 * Check that all axes are in range [-rank,rank) 434 * @param rank 435 * @param axes 436 * @return sanitized axes in range [0, rank) and sorted in increasing order 437 * @since 2.2 438 */ 439 private static SortedSet<Integer> sanitizeAxes(int rank, int... axes) { 440 SortedSet<Integer> nAxes = new TreeSet<>(); 441 for (int i = 0; i < axes.length; i++) { 442 nAxes.add(checkAxis(rank, axes[i])); 443 } 444 445 return nAxes; 446 } 447 448 /** 449 * @param rank 450 * @param axes 451 * @return remaining axes not given by input 452 * @since 2.2 453 */ 454 public static int[] getRemainingAxes(int rank, int... axes) { 455 SortedSet<Integer> nAxes = sanitizeAxes(rank, axes); 456 457 int[] remains = new int[rank - axes.length]; 458 int j = 0; 459 for (int i = 0; i < rank; i++) { 460 if (!nAxes.contains(i)) { 461 remains[j++] = i; 462 } 463 } 464 return remains; 465 } 466 467 /** 468 * Remove axes from shape 469 * @param shape 470 * @param axes 471 * @return reduced shape 472 * @since 2.2 473 */ 474 public static int[] reduceShape(int[] shape, int... axes) { 475 int[] remain = getRemainingAxes(shape.length, axes); 476 for (int i = 0; i < remain.length; i++) { 477 int a = remain[i]; 478 remain[i] = shape[a]; 479 } 480 return remain; 481 } 482 483 /** 484 * Set reduced axes to 1 485 * @param shape 486 * @param axes 487 * @return shape with same rank 488 * @since 2.2 489 */ 490 public static int[] getReducedShapeKeepRank(int[] shape, int... axes) { 491 int[] keep = shape.clone(); 492 axes = checkAxes(shape.length, axes); 493 for (int i : axes) { 494 keep[i] = 1; 495 } 496 return keep; 497 } 498}