001/*- 002 * Copyright 2017 Diamond Light Source Ltd. 003 * 004 * All rights reserved. This program and the accompanying materials 005 * are made available under the terms of the Eclipse Public License v1.0 006 * which accompanies this distribution, and is available at 007 * http://www.eclipse.org/legal/epl-v10.html 008 */ 009 010package org.eclipse.january.dataset; 011 012import java.util.Arrays; 013import java.util.List; 014 015/** 016 * Class to run over a single dataset with NumPy broadcasting to promote shapes 017 * which have lower rank and outputs to a second dataset 018 * @since 2.1 019 */ 020public class BooleanNullIterator extends BooleanIteratorBase { 021 022 /** 023 * @param a 024 * @param o (can be null for new dataset, or a) 025 */ 026 public BooleanNullIterator(Dataset a, Dataset o) { 027 this(a, o, false); 028 } 029 030 /** 031 * @param a 032 * @param o (can be null for new dataset, or a) 033 * @param createIfNull if true create the output dataset if that is null 034 * (by default, can create float or complex datasets) 035 */ 036 public BooleanNullIterator(Dataset a, Dataset o, boolean createIfNull) { 037 this(a, o, createIfNull, false, true); 038 } 039 040 /** 041 * @param a 042 * @param o (can be null for new dataset, or a) 043 * @param createIfNull if true create the output dataset if that is null 044 * @param allowInteger if true, can create integer datasets 045 * @param allowComplex if true, can create complex datasets 046 */ 047 @SuppressWarnings("deprecation") 048 public BooleanNullIterator(Dataset a, Dataset o, boolean createIfNull, boolean allowInteger, boolean allowComplex) { 049 super(true, a, null, o); 050 List<int[]> fullShapes = BroadcastUtils.broadcastShapes(a.getShapeRef(), o == null ? null : o.getShapeRef()); 051 052 BroadcastUtils.checkItemSize(a, o); 053 054 maxShape = fullShapes.remove(0); 055 056 oStride = null; 057 if (o != null && !Arrays.equals(maxShape, o.getShapeRef())) { 058 throw new IllegalArgumentException("Output does not match broadcasted shape"); 059 } 060 061 aShape = fullShapes.remove(0); 062 063 int rank = maxShape.length; 064 endrank = rank - 1; 065 066 aDataset = a.reshape(aShape); 067 aStride = BroadcastUtils.createBroadcastStrides(aDataset, maxShape); 068 if (outputA) { 069 oStride = aStride; 070 oDelta = null; 071 oStep = 0; 072 } else if (o != null) { 073 oStride = BroadcastUtils.createBroadcastStrides(o, maxShape); 074 oDelta = new int[rank]; 075 oStep = o.getElementsPerItem(); 076 } else if (createIfNull) { 077 int is = aDataset.getElementsPerItem(); 078 int dt = aDataset.getDType(); 079 if (aDataset.isComplex() && !allowComplex) { 080 is = 1; 081 dt = DTypeUtils.getBestFloatDType(dt); 082 } else if (!aDataset.hasFloatingPointElements() && !allowInteger) { 083 dt = DTypeUtils.getBestFloatDType(dt); 084 } 085 oDataset = DatasetFactory.zeros(is, maxShape, dt); 086 oStride = BroadcastUtils.createBroadcastStrides(oDataset, maxShape); 087 oDelta = new int[rank]; 088 oStep = is; 089 } else { 090 oDelta = null; 091 oStep = 0; 092 } 093 094 pos = new int[rank]; 095 aDelta = new int[rank]; 096 for (int j = endrank; j >= 0; j--) { 097 aDelta[j] = aStride[j] * aShape[j]; 098 if (oDelta != null) { 099 oDelta[j] = oStride[j] * maxShape[j]; 100 } 101 } 102 103 aStart = aDataset.getOffset(); 104 aMax = endrank < 0 ? aStep + aStart : Integer.MIN_VALUE; 105 oStart = oDelta == null ? 0 : oDataset.getOffset(); 106 reset(); 107 } 108 109 @Override 110 public boolean hasNext() { 111 int j = endrank; 112 for (; j >= 0; j--) { 113 pos[j]++; 114 index += aStride[j]; 115 if (oDelta != null) { 116 oIndex += oStride[j]; 117 } 118 if (pos[j] >= maxShape[j]) { 119 pos[j] = 0; 120 index -= aDelta[j]; // reset these dimensions 121 if (oDelta != null) { 122 oIndex -= oDelta[j]; 123 } 124 } else { 125 break; 126 } 127 } 128 if (j == -1) { 129 if (endrank >= 0) { 130 return false; 131 } 132 index += aStep; 133 if (oDelta != null) { 134 oIndex += oStep; 135 } 136 } 137 if (outputA) { 138 oIndex = index; 139 } 140 141 if (index == aMax) { 142 return false; 143 } 144 145 return true; 146 } 147 148 @Override 149 public void reset() { 150 for (int i = 0; i <= endrank; i++) { 151 pos[i] = 0; 152 } 153 154 if (endrank >= 0) { 155 pos[endrank] = -1; 156 index = aStart - aStride[endrank]; 157 oIndex = oStart - (oStride == null ? 0 : oStride[endrank]); 158 } else { 159 index = aStart - aStep; 160 oIndex = oStart - oStep; 161 } 162 } 163}