/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.engine.rust;

import ai.djl.engine.rust.RsNDArray;
import ai.djl.engine.rust.RsNDArrayIndexer;
import ai.djl.engine.rust.RsNDManager;
import ai.djl.engine.rust.RustLibrary;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDScope;
import ai.djl.ndarray.NDUtils;
import ai.djl.ndarray.index.NDArrayIndexer;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.nn.recurrent.RNN;
import java.util.List;

/**
 * {@link NDArrayEx} implementation for the Rust engine.
 *
 * <p>Supported operators are forwarded to native functions in {@link RustLibrary} using the native
 * handle of the wrapped {@link RsNDArray}; results are wrapped in new {@link RsNDArray} instances
 * owned by the same manager. Operators without a native counterpart throw {@link
 * UnsupportedOperationException}.
 */
public class RsNDArrayEx implements NDArrayEx {

    private static final String NOT_IMPLEMENTED = "Not implemented";
    private static final String PADDING_UNSUPPORTED = "padding is not supported for Rust engine";

    private final RsNDArray array;

    /**
     * Creates the operator extension for {@code parent}.
     *
     * @param parent the array whose extended operators this instance implements
     */
    RsNDArrayEx(RsNDArray parent) {
        this.array = parent;
    }

    // ---------------------------------------------------------------- reverse in-place arithmetic

    @Override
    public RsNDArray rdivi(NDArray b) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public RsNDArray rmodi(NDArray b) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public RsNDArray rpowi(Number n) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    // ---------------------------------------------------------------- activations

    @Override
    public RsNDArray relu() {
        return wrap(RustLibrary.relu(array.getHandle()));
    }

    @Override
    public RsNDArray sigmoid() {
        return wrap(RustLibrary.sigmoid(array.getHandle()));
    }

    @Override
    public RsNDArray tanh() {
        return array.tanh();
    }

    @Override
    public RsNDArray softPlus() {
        return wrap(RustLibrary.softPlus(array.getHandle()));
    }

    @Override
    public RsNDArray softSign() {
        return wrap(RustLibrary.softSign(array.getHandle()));
    }

    @Override
    public RsNDArray leakyRelu(float alpha) {
        return wrap(RustLibrary.leakyRelu(array.getHandle(), alpha));
    }

    @Override
    public RsNDArray elu(float alpha) {
        return wrap(RustLibrary.elu(array.getHandle(), alpha));
    }

    @Override
    public RsNDArray selu() {
        return wrap(RustLibrary.selu(array.getHandle()));
    }

    @Override
    public RsNDArray gelu() {
        return wrap(RustLibrary.gelu(array.getHandle()));
    }

    // ---------------------------------------------------------------- pooling

    @Override
    public RsNDArray maxPool(Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
        return wrap(
                RustLibrary.maxPool(
                        array.getHandle(),
                        kernelShape.getShape(),
                        stride.getShape(),
                        padding.getShape(),
                        ceilMode));
    }

    /**
     * Adaptive max pooling to an output of size 1 in every spatial dimension, reshaped to the
     * first two dimensions of the input.
     *
     * @throws IllegalArgumentException if the input rank is not in [3, 5]
     */
    @Override
    public RsNDArray globalMaxPool() {
        Shape outputSize = getPoolShape(array);
        return reshapeToLeadingDims(
                RustLibrary.adaptiveMaxPool(array.getHandle(), outputSize.getShape()));
    }

    /**
     * Average pooling; only the 2-D, unpadded case is supported.
     *
     * <p>{@code ceilMode} and {@code countIncludePad} are not forwarded to the native call.
     *
     * @throws UnsupportedOperationException if {@code kernelShape} does not have exactly 2
     *     dimensions
     * @throws IllegalArgumentException if any {@code padding} value is non-zero
     */
    @Override
    public RsNDArray avgPool(
            Shape kernelShape,
            Shape stride,
            Shape padding,
            boolean ceilMode,
            boolean countIncludePad) {
        // Compare the rank: Shape.size() is the product of the dimensions (4 for a 2x2 kernel).
        if (kernelShape.dimension() != 2) {
            throw new UnsupportedOperationException("Only avgPool2d is supported");
        }
        // The native avgPool2d takes no padding argument; fail loudly instead of ignoring it.
        if (hasNonZeroDimension(padding)) {
            throw new IllegalArgumentException(PADDING_UNSUPPORTED);
        }
        return wrap(
                RustLibrary.avgPool2d(
                        array.getHandle(), kernelShape.getShape(), stride.getShape()));
    }

    /**
     * Adaptive average pooling to an output of size 1 in every spatial dimension, reshaped to the
     * first two dimensions of the input.
     *
     * @throws IllegalArgumentException if the input rank is not in [3, 5]
     */
    @Override
    public RsNDArray globalAvgPool() {
        Shape outputSize = getPoolShape(array);
        return reshapeToLeadingDims(
                RustLibrary.adaptiveAvgPool(array.getHandle(), outputSize.getShape()));
    }

    /**
     * Lp pooling without padding.
     *
     * @throws IllegalArgumentException if any {@code padding} value is non-zero
     */
    @Override
    public RsNDArray lpPool(
            float normType, Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
        // Check every padding value: a product-based check misses e.g. (1, 0) and rejects ().
        if (hasNonZeroDimension(padding)) {
            throw new IllegalArgumentException(PADDING_UNSUPPORTED);
        }
        return wrap(
                RustLibrary.lpPool(
                        array.getHandle(),
                        normType,
                        kernelShape.getShape(),
                        stride.getShape(),
                        ceilMode));
    }

    /**
     * Lp pooling with a kernel covering every dimension after the first two, reshaped to the first
     * two dimensions of the input.
     *
     * @throws IllegalArgumentException if the input rank is not in [3, 5]
     */
    @Override
    public RsNDArray globalLpPool(float normType) {
        // A kernel spanning all trailing dimensions produces a single window per leading position.
        long[] kernelShape = array.getShape().slice(2).getShape();
        long[] stride = getPoolShape(array).getShape();
        return reshapeToLeadingDims(
                RustLibrary.lpPool(array.getHandle(), normType, kernelShape, stride, false));
    }

    // ---------------------------------------------------------------- optimizers (unsupported)

    @Override
    public void adadeltaUpdate(
            NDList inputs,
            NDList weights,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float rho,
            float epsilon) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public void adagradUpdate(
            NDList inputs,
            NDList weights,
            float learningRate,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float epsilon) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public void adamUpdate(
            NDList inputs,
            NDList weights,
            float learningRate,
            float learningRateBiasCorrection,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float beta1,
            float beta2,
            float epsilon,
            boolean lazyUpdate,
            boolean adamw) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public void nagUpdate(
            NDList inputs,
            NDList weights,
            float learningRate,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float momentum) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public void rmspropUpdate(
            NDList inputs,
            NDList weights,
            float learningRate,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float rho,
            float momentum,
            float epsilon,
            boolean centered) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public void sgdUpdate(
            NDList inputs,
            NDList weights,
            float learningRate,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float momentum,
            boolean lazyUpdate) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    // ---------------------------------------------------------------- neural-network ops (unsupported)

    @Override
    public NDList convolution(
            NDArray input,
            NDArray weight,
            NDArray bias,
            Shape stride,
            Shape padding,
            Shape dilation,
            int groups) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public NDList deconvolution(
            NDArray input,
            NDArray weight,
            NDArray bias,
            Shape stride,
            Shape padding,
            Shape outPadding,
            Shape dilation,
            int groups) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public NDList linear(NDArray input, NDArray weight, NDArray bias) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public NDList embedding(NDArray input, NDArray weight, SparseFormat sparseFormat) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public NDList prelu(NDArray input, NDArray alpha) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public NDList dropout(NDArray input, float rate, boolean training) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public NDList layerNorm(
            NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public NDList batchNorm(
            NDArray input,
            NDArray runningMean,
            NDArray runningVar,
            NDArray gamma,
            NDArray beta,
            int axis,
            float momentum,
            float eps,
            boolean training) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public NDList rnn(
            NDArray input,
            NDArray state,
            NDList params,
            boolean hasBiases,
            int numLayers,
            RNN.Activation activation,
            double dropRate,
            boolean training,
            boolean bidirectional,
            boolean batchFirst) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public NDList gru(
            NDArray input,
            NDArray state,
            NDList params,
            boolean hasBiases,
            int numLayers,
            double dropRate,
            boolean training,
            boolean bidirectional,
            boolean batchFirst) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public NDList lstm(
            NDArray input,
            NDList states,
            NDList params,
            boolean hasBiases,
            int numLayers,
            double dropRate,
            boolean training,
            boolean bidirectional,
            boolean batchFirst) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    // ---------------------------------------------------------------- image ops

    @Override
    public NDArray interpolation(long[] size, int mode, boolean alignCorners) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    /**
     * Supports only the no-op resize: when dimensions 0 and 1 already equal {@code height} and
     * {@code width}, returns {@code array.toType(FLOAT32, false)}.
     *
     * @throws UnsupportedOperationException for any other input, including arrays with fewer than
     *     two dimensions
     */
    @Override
    public RsNDArray resize(int width, int height, int interpolation) {
        long[] shape = array.getShape().getShape();
        // Guard the rank first so low-rank inputs get the documented exception, not an AIOOBE.
        if (shape.length >= 2 && shape[0] == height && shape[1] == width) {
            return array.toType(DataType.FLOAT32, false);
        }
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public NDArray randomFlipLeftRight() {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public NDArray randomFlipTopBottom() {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public NDArray randomBrightness(float brightness) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public NDArray randomHue(float hue) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public NDArray randomColorJitter(
            float brightness, float contrast, float saturation, float hue) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    // ---------------------------------------------------------------- indexing and combination

    /**
     * Returns an indexer bound to {@code manager}.
     *
     * @throws ClassCastException if {@code manager} is not an {@link RsNDManager}
     */
    @Override
    public NDArrayIndexer getIndexer(NDManager manager) {
        return new RsNDArrayIndexer((RsNDManager) manager);
    }

    /**
     * Element-wise select between this array and {@code other} according to {@code condition}.
     *
     * @throws UnsupportedOperationException if {@code condition} and this array differ in shape
     */
    @Override
    public RsNDArray where(NDArray condition, NDArray other) {
        if (!condition.getShape().equals(array.getShape())) {
            throw new UnsupportedOperationException(
                    "condition and self shape mismatch, broadcast is not supported");
        }
        RsNDManager manager = array.getManager();
        // Arrays produced by manager.from(...) are presumably tracked by the scope and released on
        // close; the result is unregistered so it outlives the scope.
        try (NDScope ignored = new NDScope()) {
            long conditionHandle = manager.from(condition).getHandle();
            long otherHandle = manager.from(other).getHandle();
            RsNDArray ret =
                    wrap(RustLibrary.where(conditionHandle, array.getHandle(), otherHandle));
            NDScope.unregister(ret);
            return ret;
        }
    }

    /** Stacks this array followed by every element of {@code arrays} along {@code axis}. */
    @Override
    public RsNDArray stack(NDList arrays, int axis) {
        RsNDManager manager = array.getManager();
        try (NDScope ignored = new NDScope()) {
            long[] handles = collectHandles(manager, arrays);
            RsNDArray ret = wrap(RustLibrary.stack(handles, axis));
            NDScope.unregister(ret);
            return ret;
        }
    }

    /**
     * Concatenates this array followed by every element of {@code list} along {@code axis}, after
     * validating the inputs with {@link NDUtils#checkConcatInput}.
     */
    @Override
    public RsNDArray concat(NDList list, int axis) {
        NDUtils.checkConcatInput(list);
        RsNDManager manager = array.getManager();
        try (NDScope ignored = new NDScope()) {
            long[] handles = collectHandles(manager, list);
            RsNDArray ret = wrap(RustLibrary.concat(handles, axis));
            NDScope.unregister(ret);
            return ret;
        }
    }

    // ---------------------------------------------------------------- detection ops (unsupported)

    @Override
    public NDList multiBoxTarget(
            NDList inputs,
            float iouThreshold,
            float ignoreLabel,
            float negativeMiningRatio,
            float negativeMiningThreshold,
            int minNegativeSamples) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public NDList multiBoxPrior(
            List<Float> sizes,
            List<Float> ratios,
            List<Float> steps,
            List<Float> offsets,
            boolean clip) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public NDList multiBoxDetection(
            NDList inputs,
            boolean clip,
            float threshold,
            int backgroundId,
            float nmsThreshold,
            boolean forceSuppress,
            int nmsTopK) {
        throw new UnsupportedOperationException(NOT_IMPLEMENTED);
    }

    @Override
    public RsNDArray getArray() {
        return array;
    }

    // ---------------------------------------------------------------- helpers

    /** Wraps a native handle in a new {@link RsNDArray} owned by this array's manager. */
    private RsNDArray wrap(long handle) {
        return new RsNDArray(array.getManager(), handle);
    }

    /**
     * Wraps a global-pooling result and reshapes it to the first two dimensions of the input; the
     * intermediate pooled array is closed before returning.
     */
    private RsNDArray reshapeToLeadingDims(long pooledHandle) {
        try (RsNDArray pooled = wrap(pooledHandle)) {
            return (RsNDArray) pooled.reshape(array.getShape().slice(0, 2));
        }
    }

    /**
     * Builds the handle array for a multi-input native call: this array's handle first, then the
     * handle of each element of {@code others} as converted by {@code manager}. Callers invoke
     * this inside an {@link NDScope}.
     */
    private long[] collectHandles(RsNDManager manager, NDList others) {
        long[] handles = new long[others.size() + 1];
        handles[0] = array.getHandle();
        int next = 1;
        for (NDArray other : others) {
            handles[next++] = manager.from(other).getHandle();
        }
        return handles;
    }

    /** Returns {@code true} if any dimension of {@code shape} is non-zero. */
    private static boolean hasNonZeroDimension(Shape shape) {
        for (long dim : shape.getShape()) {
            if (dim != 0L) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns a shape of all ones with one entry per dimension after the first two of {@code
     * array}.
     *
     * @throws IllegalArgumentException if {@code array} does not have 3 to 5 dimensions
     */
    private static Shape getPoolShape(NDArray array) {
        int trailingDims = array.getShape().dimension() - 2;
        if (trailingDims < 1 || trailingDims > 3) {
            throw new IllegalArgumentException("the input dimension should be in [3, 5]");
        }
        long[] ones = new long[trailingDims];
        for (int i = 0; i < trailingDims; i++) {
            ones[i] = 1L;
        }
        return new Shape(ones);
    }
}

