001/*- 002 ******************************************************************************* 003 * Copyright (c) 2011, 2016 Diamond Light Source Ltd. 004 * All rights reserved. This program and the accompanying materials 005 * are made available under the terms of the Eclipse Public License v1.0 006 * which accompanies this distribution, and is available at 007 * http://www.eclipse.org/legal/epl-v10.html 008 * 009 * Contributors: 010 * Peter Chang - initial API and implementation and/or initial documentation 011 *******************************************************************************/ 012 013package org.eclipse.january.dataset; 014 015import java.util.Arrays; 016import java.util.List; 017 018/** 019 * Class to run over a pair of datasets in parallel with NumPy broadcasting to promote shapes 020 * which have lower rank and outputs to a third dataset 021 */ 022public class BroadcastPairIterator extends BroadcastIterator { 023 private int[] aShape; 024 private int[] bShape; 025 private int[] aStride; 026 private int[] bStride; 027 private int[] oStride; 028 029 final private int endrank; 030 031 private final int[] aDelta, bDelta; 032 private final int[] oDelta; // this being non-null means output is different from inputs 033 private final int aStep, bStep, oStep; 034 private int aMax, bMax; 035 private int aStart, bStart, oStart; 036 037 /** 038 * 039 * @param a 040 * @param b 041 * @param o (can be null for new dataset, a or b) 042 * @param createIfNull 043 */ 044 public BroadcastPairIterator(Dataset a, Dataset b, Dataset o, boolean createIfNull) { 045 super(a, b, o); 046 List<int[]> fullShapes = BroadcastUtils.broadcastShapes(a.getShapeRef(), b.getShapeRef(), o == null ? null : o.getShapeRef()); 047 048 maxShape = fullShapes.remove(0); 049 050 oStride = null; 051 if (o != null && !Arrays.equals(maxShape, o.getShapeRef())) { 052 throw new IllegalArgumentException("Output does not match broadcasted shape"); 053 } 054 aShape = fullShapes.remove(0); 055 bShape = fullShapes.remove(0); 056 057 int rank = maxShape.length; 058 endrank = rank - 1; 059 060 aDataset = a.reshape(aShape); 061 bDataset = b.reshape(bShape); 062 aStride = BroadcastUtils.createBroadcastStrides(aDataset, maxShape); 063 bStride = BroadcastUtils.createBroadcastStrides(bDataset, maxShape); 064 if (outputA) { 065 oStride = aStride; 066 oDelta = null; 067 oStep = 0; 068 } else if (outputB) { 069 oStride = bStride; 070 oDelta = null; 071 oStep = 0; 072 } else if (o != null) { 073 oStride = BroadcastUtils.createBroadcastStrides(o, maxShape); 074 oDelta = new int[rank]; 075 oStep = o.getElementsPerItem(); 076 } else if (createIfNull) { 077 oDataset = BroadcastUtils.createDataset(a, b, maxShape); 078 oStride = BroadcastUtils.createBroadcastStrides(oDataset, maxShape); 079 oDelta = new int[rank]; 080 oStep = oDataset.getElementsPerItem(); 081 } else { 082 oDelta = null; 083 oStep = 0; 084 } 085 086 pos = new int[rank]; 087 aDelta = new int[rank]; 088 aStep = aDataset.getElementsPerItem(); 089 bDelta = new int[rank]; 090 bStep = bDataset.getElementsPerItem(); 091 for (int j = endrank; j >= 0; j--) { 092 aDelta[j] = aStride[j] * aShape[j]; 093 bDelta[j] = bStride[j] * bShape[j]; 094 if (oDelta != null) { 095 oDelta[j] = oStride[j] * maxShape[j]; 096 } 097 } 098 aStart = aDataset.getOffset(); 099 bStart = bDataset.getOffset(); 100 aMax = endrank < 0 ? aStep + aStart: Integer.MIN_VALUE; 101 bMax = endrank < 0 ? bStep + bStart: Integer.MIN_VALUE; 102 oStart = oDelta == null ? 0 : oDataset.getOffset(); 103 reset(); 104 } 105 106 @Override 107 public boolean hasNext() { 108 int j = endrank; 109 int oldA = aIndex; 110 int oldB = bIndex; 111 for (; j >= 0; j--) { 112 pos[j]++; 113 aIndex += aStride[j]; 114 bIndex += bStride[j]; 115 if (oDelta != null) { 116 oIndex += oStride[j]; 117 } 118 if (pos[j] >= maxShape[j]) { 119 pos[j] = 0; 120 aIndex -= aDelta[j]; // reset these dimensions 121 bIndex -= bDelta[j]; 122 if (oDelta != null) { 123 oIndex -= oDelta[j]; 124 } 125 } else { 126 break; 127 } 128 } 129 if (j == -1) { 130 if (endrank >= 0) { 131 return false; 132 } 133 aIndex += aStep; 134 bIndex += bStep; 135 if (oDelta != null) { 136 oIndex += oStep; 137 } 138 } 139 if (outputA) { 140 oIndex = aIndex; 141 } else if (outputB) { 142 oIndex = bIndex; 143 } 144 145 if (aIndex == aMax || bIndex == bMax) { 146 return false; 147 } 148 149 if (read) { 150 if (oldA != aIndex) { 151 if (asDouble) { 152 aDouble = aDataset.getElementDoubleAbs(aIndex); 153 } else { 154 aLong = aDataset.getElementLongAbs(aIndex); 155 } 156 } 157 if (oldB != bIndex) { 158 if (asDouble) { 159 bDouble = bDataset.getElementDoubleAbs(bIndex); 160 } else { 161 bLong = bDataset.getElementLongAbs(bIndex); 162 } 163 } 164 } 165 166 return true; 167 } 168 169 /** 170 * @return shape of first broadcasted dataset 171 */ 172 public int[] getFirstShape() { 173 return aShape; 174 } 175 176 /** 177 * @return shape of second broadcasted dataset 178 */ 179 public int[] getSecondShape() { 180 return bShape; 181 } 182 183 @Override 184 public void reset() { 185 for (int i = 0; i <= endrank; i++) 186 pos[i] = 0; 187 188 if (endrank >= 0) { 189 pos[endrank] = -1; 190 aIndex = aStart - aStride[endrank]; 191 bIndex = bStart - bStride[endrank]; 192 oIndex = oStart - (oStride == null ? 0 : oStride[endrank]); 193 } else { 194 aIndex = aStart - aStep; 195 bIndex = bStart - bStep; 196 oIndex = oStart - oStep; 197 } 198 199 if (aIndex == 0 || bIndex == 0 || aStride[endrank] == 0 || bStride[endrank] == 0) { // for zero-ranked datasets or extended shape 200 if (read) { 201 storeCurrentValues(); 202 } 203 } 204 } 205}