/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.cv.util;

import ai.djl.modality.cv.Image;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.Shape;
import ai.djl.util.RandomUtils;

/**
 * Static helpers for common image operations on {@link NDArray} instances.
 *
 * <p>Most methods delegate to the implementation returned by
 * {@code NDArray.getNDArrayInternal()}. Layout detection (CHW versus HWC) is done by
 * {@link #isCHW(Shape)}.
 */
public final class NDImageUtils {

    private NDImageUtils() {
        // Utility class; not meant to be instantiated.
    }

    /**
     * Resizes an image to a square of the given size using bilinear interpolation.
     *
     * @param image the image to resize
     * @param size the target width and height
     * @return the resized array
     */
    public static NDArray resize(NDArray image, int size) {
        return resize(image, size, size, Image.Interpolation.BILINEAR);
    }

    /**
     * Resizes an image to the given width and height using bilinear interpolation.
     *
     * @param image the image to resize
     * @param width the target width
     * @param height the target height
     * @return the resized array
     */
    public static NDArray resize(NDArray image, int width, int height) {
        return resize(image, width, height, Image.Interpolation.BILINEAR);
    }

    /**
     * Resizes an image to the given width and height with the given interpolation.
     *
     * @param image the image to resize
     * @param width the target width
     * @param height the target height
     * @param interpolation the interpolation method; its ordinal is what is handed to the
     *     array's internal implementation
     * @return the resized array
     */
    public static NDArray resize(
            NDArray image, int width, int height, Image.Interpolation interpolation) {
        return image.getNDArrayInternal().resize(width, height, interpolation.ordinal());
    }

    /**
     * Rotates an image in its height/width plane by calling {@code NDArray.rotate90}.
     *
     * <p>The two spatial axes are selected from the layout reported by {@link #isCHW(Shape)};
     * a 4-dimensional input is treated as having a leading batch axis.
     *
     * @param image the image (3 dimensions) or batch of images (4 dimensions)
     * @param times the rotation count passed through to {@code NDArray.rotate90}
     * @return the rotated array
     * @throws IllegalArgumentException if the shape is rejected by {@link #isCHW(Shape)}
     */
    public static NDArray rotate90(NDArray image, int times) {
        Shape shape = image.getShape();
        // A leading batch axis shifts the image axes by one.
        int batchDim = shape.dimension() == 4 ? 1 : 0;
        if (isCHW(shape)) {
            return image.rotate90(times, new int[] {1 + batchDim, 2 + batchDim});
        }
        return image.rotate90(times, new int[] {batchDim, 1 + batchDim});
    }

    /**
     * Normalizes an image using the same mean and standard deviation for three channels.
     *
     * @param input the image to normalize
     * @param mean the mean applied to each of the three channels
     * @param std the standard deviation applied to each of the three channels
     * @return the normalized array
     * @throws IllegalArgumentException under the same conditions as
     *     {@link #normalize(NDArray, float[], float[])}
     */
    public static NDArray normalize(NDArray input, float mean, float std) {
        return normalize(input, new float[] {mean, mean, mean}, new float[] {std, std, std});
    }

    /**
     * Normalizes an image with per-channel mean and standard deviation.
     *
     * <p>The layout must match the engine: HWC when the engine name is {@code "TensorFlow"},
     * CHW for every other engine.
     *
     * @param input the image to normalize
     * @param mean the per-channel means
     * @param std the per-channel standard deviations
     * @return the normalized array
     * @throws IllegalArgumentException if the layout does not match the engine, or if the
     *     shape is rejected by {@link #isCHW(Shape)}
     */
    public static NDArray normalize(NDArray input, float[] mean, float[] std) {
        boolean chw = isCHW(input.getShape());
        boolean tf = "TensorFlow".equals(input.getManager().getEngine().getEngineName());
        // Reject exactly the two mismatched combinations: CHW on TensorFlow, HWC elsewhere.
        if (chw == tf) {
            throw new IllegalArgumentException(
                    "normalize requires CHW format. TensorFlow requires HWC");
        }
        return input.getNDArrayInternal().normalize(mean, std);
    }

    /**
     * Converts an image by delegating to {@code toTensor()} on the array's internal
     * implementation.
     *
     * @param image the image to convert
     * @return the array returned by the internal {@code toTensor()} call
     */
    public static NDArray toTensor(NDArray image) {
        return image.getNDArrayInternal().toTensor();
    }

    /**
     * Crops an image to a centered square whose side is the smaller of its two leading
     * dimensions.
     *
     * <p>Dimension 0 is read as the height and dimension 1 as the width. When both are equal
     * the input instance itself is returned without any layout validation; otherwise the work
     * is delegated to {@link #centerCrop(NDArray, int, int)}.
     *
     * @param image the image to crop
     * @return the cropped array, or {@code image} itself when it is already square
     * @throws IllegalArgumentException if a crop is needed and the image is rejected by
     *     {@link #centerCrop(NDArray, int, int)}
     */
    public static NDArray centerCrop(NDArray image) {
        Shape shape = image.getShape();
        int w = (int) shape.get(1);
        int h = (int) shape.get(0);
        if (w == h) {
            return image;
        }
        int side = Math.min(w, h);
        return centerCrop(image, side, side);
    }

    /**
     * Crops an HWC image around its center to at most {@code width} x {@code height}.
     *
     * <p>An axis that is larger than the requested size is cropped to exactly that size, with
     * the offset {@code (actual - requested) / 2}. An axis that is not larger than the
     * requested size is kept in full.
     *
     * @param image the 3-dimensional HWC image to crop
     * @param width the requested width
     * @param height the requested height
     * @return the cropped array
     * @throws IllegalArgumentException if the image is CHW or 4-dimensional, or if the shape
     *     is rejected by {@link #isCHW(Shape)}
     */
    public static NDArray centerCrop(NDArray image, int width, int height) {
        Shape shape = image.getShape();
        if (isCHW(shape) || shape.dimension() == 4) {
            throw new IllegalArgumentException("CenterCrop only support for HWC image format");
        }
        int w = (int) shape.get(1);
        int h = (int) shape.get(0);
        int x = 0;
        int y = 0;
        // Compare sizes directly: testing the halved difference would skip the crop when the
        // image is exactly one pixel larger than requested (the integer division yields 0).
        if (w > width) {
            x = (w - width) / 2;
            w = width;
        }
        if (h > height) {
            y = (h - height) / 2;
            h = height;
        }
        return crop(image, x, y, w, h);
    }

    /**
     * Crops an image by delegating to {@code crop} on the array's internal implementation.
     *
     * @param image the image to crop
     * @param x the x offset passed to the internal crop
     * @param y the y offset passed to the internal crop
     * @param width the crop width
     * @param height the crop height
     * @return the cropped array
     */
    public static NDArray crop(NDArray image, int x, int y, int width, int height) {
        return image.getNDArrayInternal().crop(x, y, width, height);
    }

    /**
     * Delegates to {@code randomFlipLeftRight()} on the array's internal implementation.
     *
     * @param image the image to process
     * @return the array returned by the internal call
     */
    public static NDArray randomFlipLeftRight(NDArray image) {
        return image.getNDArrayInternal().randomFlipLeftRight();
    }

    /**
     * Delegates to {@code randomFlipTopBottom()} on the array's internal implementation.
     *
     * @param image the image to process
     * @return the array returned by the internal call
     */
    public static NDArray randomFlipTopBottom(NDArray image) {
        return image.getNDArrayInternal().randomFlipTopBottom();
    }

    /**
     * Crops a randomly sized and positioned region from an HWC image and resizes it to
     * {@code width} x {@code height}.
     *
     * <p>A target area is drawn between {@code minAreaScale} and {@code maxAreaScale} times
     * the source area, and an aspect ratio is drawn from the part of
     * {@code [minAspectRatio, maxAspectRatio]} that such an area can achieve inside the image.
     * If that range is empty, the result of {@link #centerCrop(NDArray, int, int)} is returned
     * instead (without a subsequent resize).
     *
     * @param image the 3-dimensional HWC image to crop
     * @param width the output width
     * @param height the output height
     * @param minAreaScale lower bound of the crop area as a fraction of the source area
     * @param maxAreaScale upper bound of the crop area as a fraction of the source area
     * @param minAspectRatio lower bound of the crop's width/height ratio
     * @param maxAspectRatio upper bound of the crop's width/height ratio
     * @return the cropped and resized array
     * @throws IllegalArgumentException if the image is CHW or 4-dimensional, or if the shape
     *     is rejected by {@link #isCHW(Shape)}
     */
    public static NDArray randomResizedCrop(
            NDArray image,
            int width,
            int height,
            double minAreaScale,
            double maxAreaScale,
            double minAspectRatio,
            double maxAspectRatio) {
        Shape shape = image.getShape();
        if (isCHW(shape) || shape.dimension() == 4) {
            throw new IllegalArgumentException(
                    "randomResizedCrop only support for HWC image format");
        }
        int h = (int) shape.get(0);
        int w = (int) shape.get(1);
        // Multiply in double so that very large images cannot overflow an int product.
        double srcArea = (double) h * (double) w;
        double targetArea =
                minAreaScale * srcArea
                        + (maxAreaScale - minAreaScale)
                                * srcArea
                                * (double) RandomUtils.nextFloat();
        // Ratio limits implied by the largest height and width the image can provide.
        double minRatio = targetArea / (double) h / (double) h;
        double maxRatio = (double) w / (targetArea / (double) w);
        double lowRatio = Math.max(minRatio, minAspectRatio);
        double highRatio = Math.min(maxRatio, maxAspectRatio);
        if (highRatio < lowRatio) {
            return centerCrop(image, width, height);
        }
        float finalRatio = RandomUtils.nextFloat((float) lowRatio, (float) highRatio);
        // Rounding can push a side one pixel past the image (or down to zero); clamp so the
        // crop stays inside the image and the nextInt bounds below are never negative.
        int newWidth =
                Math.min(w, Math.max(1, (int) Math.round(Math.sqrt(targetArea * (double) finalRatio))));
        int newHeight = Math.min(h, Math.max(1, (int) ((float) newWidth / finalRatio)));
        // nextInt requires a strictly positive bound, so skip the draw when nothing is left over.
        int x = w == newWidth ? 0 : RandomUtils.nextInt(w - newWidth);
        int y = h == newHeight ? 0 : RandomUtils.nextInt(h - newHeight);
        // The intermediate crop is closed once the resized result has been produced.
        try (NDArray cropped = crop(image, x, y, newWidth, newHeight)) {
            return resize(cropped, width, height);
        }
    }

    /**
     * Delegates to {@code randomBrightness} on the array's internal implementation.
     *
     * @param image the image to process
     * @param brightness the brightness parameter passed through unchanged
     * @return the array returned by the internal call
     */
    public static NDArray randomBrightness(NDArray image, float brightness) {
        return image.getNDArrayInternal().randomBrightness(brightness);
    }

    /**
     * Delegates to {@code randomHue} on the array's internal implementation.
     *
     * @param image the image to process
     * @param hue the hue parameter passed through unchanged
     * @return the array returned by the internal call
     */
    public static NDArray randomHue(NDArray image, float hue) {
        return image.getNDArrayInternal().randomHue(hue);
    }

    /**
     * Delegates to {@code randomColorJitter} on the array's internal implementation.
     *
     * @param image the image to process
     * @param brightness the brightness parameter passed through unchanged
     * @param contrast the contrast parameter passed through unchanged
     * @param saturation the saturation parameter passed through unchanged
     * @param hue the hue parameter passed through unchanged
     * @return the array returned by the internal call
     */
    public static NDArray randomColorJitter(
            NDArray image, float brightness, float contrast, float saturation, float hue) {
        return image.getNDArrayInternal().randomColorJitter(brightness, contrast, saturation, hue);
    }

    /**
     * Reports whether a shape looks like a channel-first (CHW) image.
     *
     * <p>For a 4-dimensional shape the leading dimension is dropped first. The first remaining
     * dimension is then checked for a channel count of 1 or 3 before the third one is, so a
     * shape whose first and third dimensions both qualify is reported as CHW.
     *
     * @param shape the image shape, with at least three dimensions
     * @return {@code true} if the first image dimension is 1 or 3; {@code false} if only the
     *     third image dimension is 1 or 3
     * @throws IllegalArgumentException if the shape has fewer than three dimensions or neither
     *     candidate dimension is 1 or 3
     */
    public static boolean isCHW(Shape shape) {
        if (shape.dimension() < 3) {
            throw new IllegalArgumentException(
                    "Not a valid image shape, require at least three dimensions");
        }
        if (shape.dimension() == 4) {
            // Ignore the leading batch dimension.
            shape = shape.slice(1);
        }
        if (shape.get(0) == 1L || shape.get(0) == 3L) {
            return true;
        }
        if (shape.get(2) == 1L || shape.get(2) == 3L) {
            return false;
        }
        throw new IllegalArgumentException("Image is neither CHW nor HWC");
    }
}

