001/*-
002 * Copyright 2016 Diamond Light Source Ltd.
003 *
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the Eclipse Public License v1.0
006 * which accompanies this distribution, and is available at
007 * http://www.eclipse.org/legal/epl-v10.html
008 */
009
010package org.eclipse.january.dataset;
011
012import java.util.ArrayList;
013import java.util.Arrays;
014import java.util.List;
015
016public final class BroadcastUtils {
017
018        /**
019         * Calculate shapes for broadcasting
020         * @param oldShape
021         * @param size
022         * @param newShape
023         * @return broadcasted shape and full new shape or null if it cannot be done
024         */
025        public static int[][] calculateBroadcastShapes(int[] oldShape, int size, int... newShape) {
026                if (newShape == null)
027                        return null;
028        
029                int brank = newShape.length;
030                if (brank == 0) {
031                        if (size == 1)
032                                return new int[][] {oldShape, newShape};
033                        return null;
034                }
035        
036                if (Arrays.equals(oldShape, newShape))
037                        return new int[][] {oldShape, newShape};
038        
039                int offset = brank - oldShape.length;
040                if (offset < 0) { // when new shape is incomplete
041                        newShape = padShape(newShape, -offset);
042                        offset = 0;
043                }
044        
045                int[] bshape;
046                if (offset > 0) { // new shape has extra dimensions
047                        bshape = padShape(oldShape, offset);
048                } else {
049                        bshape = oldShape;
050                }
051        
052                for (int i = 0; i < brank; i++) {
053                        if (newShape[i] != bshape[i] && bshape[i] != 1 && newShape[i] != 1) {
054                                return null;
055                        }
056                }
057        
058                return new int[][] {bshape, newShape};
059        }
060
061        /**
062         * Pad shape by prefixing with ones
063         * @param shape
064         * @param padding
065         * @return new shape or old shape if padding is zero
066         */
067        public static int[] padShape(final int[] shape, final int padding) {
068                if (padding < 0)
069                        throw new IllegalArgumentException("Padding must be zero or greater");
070        
071                if (padding == 0)
072                        return shape;
073        
074                final int[] nshape = new int[shape.length + padding];
075                Arrays.fill(nshape, 1);
076                System.arraycopy(shape, 0, nshape, padding, shape.length);
077                return nshape;
078        }
079
080        /**
081         * Take in shapes and broadcast them to same rank
082         * @param shapes
083         * @return list of broadcasted shapes plus the first entry is the maximum shape
084         */
085        public static List<int[]> broadcastShapes(int[]... shapes) {
086                int maxRank = -1;
087                for (int[] s : shapes) {
088                        if (s == null)
089                                continue;
090        
091                        int r = s.length;
092                        if (r > maxRank) {
093                                maxRank = r;
094                        }
095                }
096        
097                List<int[]> newShapes = new ArrayList<int[]>();
098                for (int[] s : shapes) {
099                        if (s == null)
100                                continue;
101                        newShapes.add(padShape(s, maxRank - s.length));
102                }
103        
104                int[] maxShape = new int[maxRank];
105                for (int i = 0; i < maxRank; i++) {
106                        int m = -1;
107                        for (int[] s : newShapes) {
108                                int l = s[i];
109                                if (l > m) {
110                                        if (m > 1) {
111                                                throw new IllegalArgumentException("A shape's dimension was not one or equal to maximum");
112                                        }
113                                        m = l;
114                                }
115                        }
116                        maxShape[i] = m;
117                }
118
119                checkShapes(maxShape, newShapes);
120                newShapes.add(0, maxShape);
121                return newShapes;
122        }
123
124        /**
125         * Take in shapes and broadcast them to maximum shape
126         * @param maxShape
127         * @param shapes
128         * @return list of broadcasted shapes
129         */
130        public static List<int[]> broadcastShapesToMax(int[] maxShape, int[]... shapes) {
131                int maxRank = maxShape.length;
132                for (int[] s : shapes) {
133                        if (s == null)
134                                continue;
135        
136                        int r = s.length;
137                        if (r > maxRank) {
138                                throw new IllegalArgumentException("A shape exceeds given rank of maximum shape");
139                        }
140                }
141        
142                List<int[]> newShapes = new ArrayList<int[]>();
143                for (int[] s : shapes) {
144                        if (s == null)
145                                continue;
146                        newShapes.add(padShape(s, maxRank - s.length));
147                }
148
149                checkShapes(maxShape, newShapes);
150                return newShapes;
151        }
152
153        private static void checkShapes(int[] maxShape, List<int[]> newShapes) {
154                for (int i = 0; i < maxShape.length; i++) {
155                        int m = maxShape[i];
156                        for (int[] s : newShapes) {
157                                int l = s[i];
158                                if (l != 1 && l != m) {
159                                        throw new IllegalArgumentException("A shape's dimension was not one or equal to maximum");
160                                }
161                        }
162                }
163        }
164
165        @SuppressWarnings("deprecation")
166        static Dataset createDataset(final Dataset a, final Dataset b, final int[] shape) {
167                final int rt;
168                final int ar = a.getRank();
169                final int br = b.getRank();
170                final int tt = DTypeUtils.getBestDType(a.getDType(), b.getDType());
171                if (ar == 0 ^ br == 0) { // ignore type of zero-rank dataset unless it's floating point 
172                        if (ar == 0) {
173                                rt = a.hasFloatingPointElements() ? tt : b.getDType();
174                        } else {
175                                rt = b.hasFloatingPointElements() ? tt : a.getDType();
176                        }
177                } else {
178                        rt = tt;
179                }
180                final int ia = a.getElementsPerItem();
181                final int ib = b.getElementsPerItem();
182        
183                return DatasetFactory.zeros(ia > ib ? ia : ib, shape, rt);
184        }
185
186        static void checkItemSize(Dataset a, Dataset b, Dataset o) {
187                final int isa = a.getElementsPerItem();
188                final int isb = b.getElementsPerItem();
189                if (isa != isb && isa != 1 && isb != 1) {
190                        // exempt single-value dataset case too
191                        if ((isa == 1 || b.getSize() != 1) && (isb == 1 || a.getSize() != 1) ) {
192                                throw new IllegalArgumentException("Can not broadcast where number of elements per item mismatch and one does not equal another");
193                        }
194                }
195                if (o != null && o.getDType() != Dataset.BOOL) {
196                        final int ism = Math.max(isa, isb);
197                        final int iso = o.getElementsPerItem();
198                        if (iso != ism && ism != 1) {
199                                throw new IllegalArgumentException("Can not output to dataset whose number of elements per item mismatch inputs'");
200                        }
201                }
202        }
203
204        /**
205         * Create a stride array from a dataset to a broadcast shape
206         * @param a dataset
207         * @param broadcastShape
208         * @return stride array
209         */
210        public static int[] createBroadcastStrides(Dataset a, final int[] broadcastShape) {
211                return createBroadcastStrides(a.getElementsPerItem(), a.getShapeRef(), a.getStrides(), broadcastShape);
212        }
213
214        /**
215         * Create a stride array from a dataset to a broadcast shape
216         * @param isize
217         * @param oShape original shape
218         * @param oStride original stride
219         * @param broadcastShape
220         * @return stride array
221         */
222        public static int[] createBroadcastStrides(final int isize, final int[] oShape, final int[] oStride, final int[] broadcastShape) {
223                int rank = oShape.length;
224                if (broadcastShape.length != rank) {
225                        throw new IllegalArgumentException("Dataset must have same rank as broadcast shape");
226                }
227        
228                int[] stride = new int[rank];
229                if (oStride == null) {
230                        int s = isize;
231                        for (int j = rank - 1; j >= 0; j--) {
232                                if (broadcastShape[j] == oShape[j]) {
233                                        stride[j] = s;
234                                        s *= oShape[j];
235                                } else {
236                                        stride[j] = 0;
237                                }
238                        }
239                } else {
240                        for (int j = 0; j < rank; j++) {
241                                if (broadcastShape[j] == oShape[j]) {
242                                        stride[j] = oStride[j];
243                                } else {
244                                        stride[j] = 0;
245                                }
246                        }
247                }
248        
249                return stride;
250        }
251
252        /**
253         * Converts and broadcast all objects as datasets of same shape
254         * @param objects
255         * @return all as broadcasted to same shape
256         */
257        public static Dataset[] convertAndBroadcast(Object... objects) {
258                final int n = objects.length;
259
260                Dataset[] datasets = new Dataset[n];
261                int[][] shapes = new int[n][];
262                for (int i = 0; i < n; i++) {
263                        Dataset d = DatasetFactory.createFromObject(objects[i]);
264                        datasets[i] = d;
265                        shapes[i] = d.getShapeRef();
266                }
267
268                List<int[]> nShapes = BroadcastUtils.broadcastShapes(shapes);
269                int[] mshape = nShapes.get(0);
270                for (int i = 0; i < n; i++) {
271                        datasets[i] = datasets[i].getBroadcastView(mshape);
272                }
273
274                return datasets;
275        }
276}