/*-
 * Copyright 2016 Diamond Light Source Ltd.
 *
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License v1.0
 * which accompanies this distribution, and is available at
 * http://www.eclipse.org/legal/epl-v10.html
 */

package org.eclipse.january.dataset;

import java.util.List;

/**
 * Class to run over a pair of datasets in parallel with NumPy broadcasting of second dataset.
 * <p>
 * The shape of the first dataset defines the iteration shape; the second dataset is reshaped and
 * broadcast to it. While iterating, only the second dataset's value is read (into {@code bDouble}
 * or {@code bLong}, depending on {@code asDouble}), and only when its absolute index changes.
 */
public class BroadcastSingleIterator extends BroadcastSelfIterator {
	/** shape of the second dataset after broadcasting against the first dataset's shape */
	private final int[] bShape;
	/** per-dimension strides of the first dataset */
	private final int[] aStride;
	/** per-dimension strides of the second dataset (presumably zero along broadcast dimensions) */
	private final int[] bStride;

	/** index of the last (fastest-varying) dimension; -1 for a zero-rank iteration */
	private final int endrank;

	/** amounts subtracted from the indices when a dimension wraps back to zero */
	private final int[] aDelta, bDelta;
	/** elements per item; used to step a zero-rank iteration */
	private final int aStep, bStep;
	/**
	 * index that ends a zero-rank iteration; Integer.MIN_VALUE for rank >= 1, where termination is
	 * instead detected by every dimension wrapping
	 */
	private final int aMax, bMax;
	/** absolute index of the first element of each dataset */
	private final int aStart, bStart;

	/**
	 * Create an iterator over {@code a} with {@code b} broadcast to the shape of {@code a}.
	 *
	 * @param a first dataset; its shape defines the iteration shape
	 * @param b second dataset, to be broadcast to the shape of {@code a}
	 */
	public BroadcastSingleIterator(Dataset a, Dataset b) {
		super(a, b);

		int[] aShape = a.getShapeRef();
		maxShape = aShape;
		List<int[]> fullShapes = BroadcastUtils.broadcastShapesToMax(maxShape, b.getShapeRef());
		bShape = fullShapes.remove(0);

		int rank = maxShape.length;
		endrank = rank - 1;

		bDataset = b.reshape(bShape);
		int[] aOffset = new int[1];
		aStride = AbstractDataset.createStrides(aDataset, aOffset);
		bStride = BroadcastUtils.createBroadcastStrides(bDataset, maxShape);

		pos = new int[rank];
		aDelta = new int[rank];
		aStep = aDataset.getElementsPerItem();
		bDelta = new int[rank];
		bStep = bDataset.getElementsPerItem();
		for (int j = endrank; j >= 0; j--) {
			aDelta[j] = aStride[j] * aShape[j];
			bDelta[j] = bStride[j] * bShape[j];
		}
		aStart = aOffset[0];
		bStart = bDataset.getOffset();
		// a zero-rank iteration has no dimension to wrap, so it ends after a single item step
		aMax = endrank < 0 ? aStep + aStart : Integer.MIN_VALUE;
		bMax = endrank < 0 ? bStep + bStart : Integer.MIN_VALUE;
		reset();
	}

	@Override
	public boolean hasNext() {
		int j = endrank;
		int oldB = bIndex;
		// odometer-style increment: advance the last dimension, carrying into earlier ones on wrap
		for (; j >= 0; j--) {
			pos[j]++;
			aIndex += aStride[j];
			bIndex += bStride[j];
			if (pos[j] >= maxShape[j]) {
				pos[j] = 0;
				aIndex -= aDelta[j]; // reset these dimensions
				bIndex -= bDelta[j];
			} else {
				break;
			}
		}
		if (j == -1) {
			if (endrank >= 0) {
				return false; // every dimension wrapped: iteration is finished
			}
			// zero-rank iteration: a single item step
			aIndex += aStep;
			bIndex += bStep;
		}

		if (aIndex == aMax || bIndex == bMax) {
			return false;
		}

		// only re-read the second dataset when its index has actually moved
		if (read && oldB != bIndex) {
			if (asDouble) {
				bDouble = bDataset.getElementDoubleAbs(bIndex);
			} else {
				bLong = bDataset.getElementLongAbs(bIndex);
			}
		}

		return true;
	}

	/**
	 * @return shape of first broadcasted dataset (the iterator's own array, not a copy)
	 */
	public int[] getFirstShape() {
		return maxShape;
	}

	/**
	 * @return shape of second broadcasted dataset (the iterator's own array, not a copy)
	 */
	public int[] getSecondShape() {
		return bShape;
	}

	@Override
	public void reset() {
		for (int i = 0; i <= endrank; i++) {
			pos[i] = 0;
		}

		if (endrank >= 0) {
			// position just before the first element so the next hasNext() steps onto it
			pos[endrank] = -1;
			aIndex = aStart - aStride[endrank];
			bIndex = bStart - bStride[endrank];
		} else {
			aIndex = aStart - aStep;
			bIndex = bStart - bStep;
		}

		// hasNext() only reads the second dataset when its index moves. If the last dimension has
		// zero stride, the first step leaves bIndex at bStart, so the first value must be loaded
		// here. Comparing with bStart (not 0) keeps this correct for views with a non-zero offset.
		if (read && bIndex == bStart && !hasZeroLength(maxShape)) {
			storeCurrentValues();
		}
	}

	/**
	 * @param shape iteration shape
	 * @return true if any dimension has zero length, i.e. there are no elements to iterate over
	 */
	private static boolean hasZeroLength(int[] shape) {
		for (int s : shape) {
			if (s == 0) {
				return true;
			}
		}
		return false;
	}
}