/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.translate;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.Batchifier;
import ai.djl.translate.NDArraySupplier;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

/**
 * A {@link Batchifier} that pads selected arrays along a chosen dimension so that inputs whose
 * shapes differ in that dimension can be stacked into a single batch with {@link
 * Batchifier#STACK}.
 *
 * <p>Each padding rule (added through {@link Builder#addPad}) names:
 * <ul>
 *   <li>the position of an array inside every input {@link NDList};</li>
 *   <li>the dimension of that (unbatched) array to pad;</li>
 *   <li>a supplier of padding values;</li>
 *   <li>optionally, a fixed padded size.</li>
 * </ul>
 *
 * <p>When valid lengths are enabled, {@link #batchify} appends one extra array per padding rule
 * to its output, holding each element's pre-padding length. {@link #unbatchify} uses those arrays
 * to strip the padding again, and {@link #split} slices them alongside the data.
 */
public final class PaddingStackBatchifier implements Batchifier {

    private static final long serialVersionUID = 1L;

    /** Position of each array to pad inside every input {@link NDList}. */
    private final List<Integer> arraysToPad;

    /** Dimension of the unbatched array to pad; parallel to {@link #arraysToPad}. */
    private final List<Integer> dimsToPad;

    /**
     * Suppliers of padding values; parallel to {@link #arraysToPad}. This field is transient, so it
     * is null on an instance restored by Java deserialization. NOTE(review): batchify cannot work
     * on such an instance.
     */
    private final transient List<NDArraySupplier> paddingSuppliers;

    /** Fixed padded size per rule, or -1 to pad to the largest size found in the batch. */
    private final List<Integer> paddingSizes;

    /** Whether the valid (pre-padding) lengths are appended to the batched output. */
    private final boolean includeValidLengths;

    private PaddingStackBatchifier(Builder builder) {
        // Copy so that later addPad calls on the builder cannot mutate an already-built batchifier.
        arraysToPad = new ArrayList<>(builder.arraysToPad);
        dimsToPad = new ArrayList<>(builder.dimsToPad);
        paddingSuppliers = new ArrayList<>(builder.paddingSuppliers);
        paddingSizes = new ArrayList<>(builder.paddingSizes);
        includeValidLengths = builder.includeValidLengths;
    }

    /**
     * {@inheritDoc}
     *
     * <p>Pads every configured array of every input to a common size along its configured
     * dimension, then stacks the inputs. The input lists are modified in place: each padded array
     * replaces the original at its position. If valid lengths are enabled, one length array per
     * padding rule is appended after the stacked data, in rule order.
     *
     * @throws IllegalArgumentException if a fixed padding size is smaller than the largest size in
     *     the batch, or if a padding array has more dimensions than the array it pads
     */
    @Override
    public NDList batchify(NDList[] inputs) {
        NDManager manager = inputs[0].get(0).getManager();
        // Only allocate the valid-length arrays when they will actually be returned.
        NDList validLengths = includeValidLengths ? new NDList(arraysToPad.size()) : null;
        for (int i = 0; i < arraysToPad.size(); ++i) {
            int arrayIndex = arraysToPad.get(i);
            int dimIndex = dimsToPad.get(i);
            NDArray padding = paddingSuppliers.get(i).get(manager);
            long paddingSize = paddingSizes.get(i);
            long maxSize = findMaxSize(inputs, arrayIndex, dimIndex);
            if (paddingSize != -1L && maxSize > paddingSize) {
                throw new IllegalArgumentException(
                        "The batchifier padding size is too small " + maxSize + " " + paddingSize);
            }
            // With paddingSize == -1 this is just the batch maximum.
            long targetSize = Math.max(maxSize, paddingSize);
            long[] arrayValidLengths = padArrays(inputs, arrayIndex, dimIndex, padding, targetSize);
            if (validLengths != null) {
                validLengths.add(manager.create(arrayValidLengths));
            }
        }
        NDList result = Batchifier.STACK.batchify(inputs);
        if (validLengths != null) {
            result.addAll(validLengths);
        }
        return result;
    }

    /**
     * {@inheritDoc}
     *
     * <p>When valid lengths are enabled, the trailing length arrays (one per padding rule) are
     * removed from {@code inputs}. Each unbatched padded array is then trimmed back to its valid
     * length along its padded dimension.
     */
    @Override
    public NDList[] unbatchify(NDList inputs) {
        if (!includeValidLengths) {
            return Batchifier.STACK.unbatchify(inputs);
        }
        int dataCount = inputs.size() - arraysToPad.size();
        NDList validLengths = new NDList(inputs.subList(dataCount, inputs.size()));
        NDList[] items = Batchifier.STACK.unbatchify(new NDList(inputs.subList(0, dataCount)));
        for (int i = 0; i < items.length; ++i) {
            NDList arrays = items[i];
            for (int j = 0; j < arraysToPad.size(); ++j) {
                int arrayIndex = arraysToPad.get(j);
                long validLength = validLengths.get(j).getLong(i);
                // Unbatched arrays have the same rank as the arrays batchify padded, so the padded
                // axis is the configured dimension itself (not shifted by the removed batch axis).
                NDIndex keepValid = NDIndex.sliceAxis(dimsToPad.get(j), 0L, validLength);
                arrays.set(arrayIndex, arrays.get(arrayIndex).get(keepValid));
            }
        }
        return items;
    }

    /**
     * {@inheritDoc}
     *
     * <p>When valid lengths are enabled, the trailing length arrays are split off and the data is
     * split with {@link Batchifier#STACK}. Each slice then gets the matching range of every
     * length array appended.
     */
    @Override
    public NDList[] split(NDList list, int numOfSlices, boolean evenSplit) {
        if (!includeValidLengths) {
            return Batchifier.STACK.split(list, numOfSlices, evenSplit);
        }
        int dataCount = list.size() - arraysToPad.size();
        NDList validLengths = new NDList(list.subList(dataCount, list.size()));
        NDList data = new NDList(list.subList(0, dataCount));
        NDList[] slices = Batchifier.STACK.split(data, numOfSlices, evenSplit);
        if (slices.length == 0) {
            return slices;
        }
        // Assumes STACK.split yields slices of the first slice's size, except possibly a shorter
        // last one; the end bound is clamped to the total batch size for that case.
        long sliceSize = slices[0].get(0).getShape().get(0);
        long totalSize = data.get(0).getShape().get(0);
        for (int i = 0; i < slices.length; ++i) {
            long start = i * sliceSize;
            long end = Math.min((i + 1) * sliceSize, totalSize);
            for (NDArray lengths : validLengths) {
                slices[i].add(lengths.get(NDIndex.sliceAxis(0, start, end)));
            }
        }
        return slices;
    }

    /**
     * Finds the largest size of dimension {@code dimIndex} of the array at {@code arrayIndex}
     * across all inputs.
     *
     * @param inputs the unbatched inputs
     * @param arrayIndex the position of the array within each input
     * @param dimIndex the dimension to measure
     * @return the largest size found, or -1 if {@code inputs} is empty
     */
    public static long findMaxSize(NDList[] inputs, int arrayIndex, int dimIndex) {
        long maxSize = -1L;
        for (NDList input : inputs) {
            NDArray array = input.get(arrayIndex);
            maxSize = Math.max(maxSize, array.getShape().get(dimIndex));
        }
        return maxSize;
    }

    /**
     * Pads the array at {@code arrayIndex} of every input along {@code dimIndex} up to {@code
     * maxSize}, replacing it in place in its {@link NDList} and keeping its name.
     *
     * <p>A padding array of the same rank as the target is expanded with {@code repeat(Shape)}.
     * A lower-rank padding array is expanded with {@code broadcast(Shape)}. The result is cast to
     * the target's data type before concatenation.
     *
     * @param inputs the unbatched inputs; modified in place
     * @param arrayIndex the position of the array to pad within each input
     * @param dimIndex the dimension to pad
     * @param padding the padding values
     * @param maxSize the size to pad to
     * @return the original (pre-padding) size of {@code dimIndex} for each input
     * @throws IllegalArgumentException if {@code padding} has more dimensions than the array
     */
    public static long[] padArrays(
            NDList[] inputs, int arrayIndex, int dimIndex, NDArray padding, long maxSize) {
        long[] arrayValidLengths = new long[inputs.length];
        for (int i = 0; i < inputs.length; ++i) {
            NDArray array = inputs[i].get(arrayIndex);
            String arrayName = array.getName();
            long validLength = array.getShape().get(dimIndex);
            if (validLength < maxSize) {
                int arrayRank = array.getShape().dimension();
                int paddingRank = padding.getShape().dimension();
                if (paddingRank > arrayRank) {
                    throw new IllegalArgumentException(
                            "The padding must be <="
                                    + arrayRank
                                    + " dimensions, but found "
                                    + paddingRank);
                }
                Shape paddingShape =
                        Shape.update(array.getShape(), dimIndex, maxSize - validLength);
                NDArray paddingArray =
                        paddingRank == arrayRank
                                ? padding.repeat(paddingShape)
                                : padding.broadcast(paddingShape);
                array = array.concat(paddingArray.toType(array.getDataType(), false), dimIndex);
            }
            array.setName(arrayName);
            inputs[i].set(arrayIndex, array);
            arrayValidLengths[i] = validLength;
        }
        return arrayValidLengths;
    }

    /**
     * Returns a new {@link Builder} for a {@link PaddingStackBatchifier}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Builder for {@link PaddingStackBatchifier}. */
    public static final class Builder {

        private final List<Integer> arraysToPad = new ArrayList<>();
        private final List<Integer> dimsToPad = new ArrayList<>();
        private final List<NDArraySupplier> paddingSuppliers = new ArrayList<>();
        private final List<Integer> paddingSizes = new ArrayList<>();
        private boolean includeValidLengths;

        private Builder() {}

        /**
         * Sets whether the valid (pre-padding) lengths are appended to the batched output.
         *
         * @param includeValidLengths true to append one length array per padding rule
         * @return this builder
         */
        public Builder optIncludeValidLengths(boolean includeValidLengths) {
            this.includeValidLengths = includeValidLengths;
            return this;
        }

        /**
         * Adds a padding rule that pads to the largest size found in each batch.
         *
         * @param array the position of the array to pad within each input
         * @param dim the dimension of the (unbatched) array to pad
         * @param supplier the supplier of padding values
         * @return this builder
         */
        public Builder addPad(int array, int dim, NDArraySupplier supplier) {
            return addPad(array, dim, supplier, -1);
        }

        /**
         * Adds a padding rule.
         *
         * @param array the position of the array to pad within each input
         * @param dim the dimension of the (unbatched) array to pad
         * @param supplier the supplier of padding values
         * @param paddingSize the fixed size to pad to, or -1 to pad to the batch maximum
         * @return this builder
         */
        public Builder addPad(int array, int dim, NDArraySupplier supplier, int paddingSize) {
            arraysToPad.add(array);
            dimsToPad.add(dim);
            paddingSuppliers.add(supplier);
            paddingSizes.add(paddingSize);
            return this;
        }

        /**
         * Builds a {@link PaddingStackBatchifier} from the rules added so far.
         *
         * @return a new batchifier
         */
        public PaddingStackBatchifier build() {
            return new PaddingStackBatchifier(this);
        }
    }
}

