001/*- 002 * Copyright 2016 Diamond Light Source Ltd. 003 * 004 * All rights reserved. This program and the accompanying materials 005 * are made available under the terms of the Eclipse Public License v1.0 006 * which accompanies this distribution, and is available at 007 * http://www.eclipse.org/legal/epl-v10.html 008 */ 009 010package org.eclipse.january.dataset; 011 012import java.util.List; 013 014/** 015 * Class to run over a pair of datasets in parallel with NumPy broadcasting of second dataset 016 */ 017public class BroadcastSingleIterator extends BroadcastSelfIterator { 018 private int[] bShape; 019 private int[] aStride; 020 private int[] bStride; 021 022 final private int endrank; 023 024 private final int[] aDelta, bDelta; 025 private final int aStep, bStep; 026 private int aMax, bMax; 027 private int aStart, bStart; 028 029 /** 030 * 031 * @param a 032 * @param b 033 */ 034 public BroadcastSingleIterator(Dataset a, Dataset b) { 035 super(a, b); 036 037 int[] aShape = a.getShapeRef(); 038 maxShape = aShape; 039 List<int[]> fullShapes = BroadcastUtils.broadcastShapesToMax(maxShape, b.getShapeRef()); 040 bShape = fullShapes.remove(0); 041 042 int rank = maxShape.length; 043 endrank = rank - 1; 044 045 bDataset = b.reshape(bShape); 046 int[] aOffset = new int[1]; 047 aStride = AbstractDataset.createStrides(aDataset, aOffset ); 048 bStride = BroadcastUtils.createBroadcastStrides(bDataset, maxShape); 049 050 pos = new int[rank]; 051 aDelta = new int[rank]; 052 aStep = aDataset.getElementsPerItem(); 053 bDelta = new int[rank]; 054 bStep = bDataset.getElementsPerItem(); 055 for (int j = endrank; j >= 0; j--) { 056 aDelta[j] = aStride[j] * aShape[j]; 057 bDelta[j] = bStride[j] * bShape[j]; 058 } 059 if (endrank < 0) { 060 aMax = aStep; 061 bMax = bStep; 062 } else { 063 aMax = Integer.MIN_VALUE; // use max delta 064 bMax = Integer.MIN_VALUE; 065 for (int j = endrank; j >= 0; j--) { 066 if (aDelta[j] > aMax) { 067 aMax = aDelta[j]; 068 } 069 if (bDelta[j] > bMax) { 070 bMax = bDelta[j]; 071 } 072 } 073 } 074 aStart = aOffset[0]; 075 aMax += aStart; 076 bStart = bDataset.getOffset(); 077 bMax += bStart; 078 reset(); 079 } 080 081 @Override 082 public boolean hasNext() { 083 int j = endrank; 084 int oldB = bIndex; 085 for (; j >= 0; j--) { 086 pos[j]++; 087 aIndex += aStride[j]; 088 bIndex += bStride[j]; 089 if (pos[j] >= maxShape[j]) { 090 pos[j] = 0; 091 aIndex -= aDelta[j]; // reset these dimensions 092 bIndex -= bDelta[j]; 093 } else { 094 break; 095 } 096 } 097 if (j == -1) { 098 if (endrank >= 0) { 099 aIndex = aMax; 100 bIndex = bMax; 101 return false; 102 } 103 aIndex += aStep; 104 bIndex += bStep; 105 } 106 107 if (aIndex == aMax || bIndex == bMax) 108 return false; 109 110 if (read) { 111 if (oldB != bIndex) { 112 if (asDouble) { 113 bDouble = bDataset.getElementDoubleAbs(bIndex); 114 } else { 115 bLong = bDataset.getElementLongAbs(bIndex); 116 } 117 } 118 } 119 120 return true; 121 } 122 123 /** 124 * @return shape of first broadcasted dataset 125 */ 126 public int[] getFirstShape() { 127 return maxShape; 128 } 129 130 /** 131 * @return shape of second broadcasted dataset 132 */ 133 public int[] getSecondShape() { 134 return bShape; 135 } 136 137 @Override 138 public void reset() { 139 for (int i = 0; i <= endrank; i++) 140 pos[i] = 0; 141 142 if (endrank >= 0) { 143 pos[endrank] = -1; 144 aIndex = aStart - aStride[endrank]; 145 bIndex = bStart - bStride[endrank]; 146 } else { 147 aIndex = aStart - aStep; 148 bIndex = bStart - bStep; 149 } 150 151 if (aIndex == 0 || bIndex == 0) { // for zero-ranked datasets 152 if (read) { 153 storeCurrentValues(); 154 } 155 if (aMax == aIndex) 156 aMax++; 157 if (bMax == bIndex) 158 bMax++; 159 } 160 } 161}