/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.cv.translator;

import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.translator.ObjectDetectionTranslator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.TranslatorContext;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;

public class YoloV5Translator
extends ObjectDetectionTranslator {
    private YoloOutputType yoloOutputLayerType;
    private float nmsThreshold;

    protected YoloV5Translator(Builder builder) {
        super(builder);
        this.yoloOutputLayerType = builder.outputType;
        this.nmsThreshold = builder.nmsThreshold;
    }

    public static Builder builder() {
        return new Builder();
    }

    public static Builder builder(Map<String, ?> arguments) {
        Builder builder = new Builder();
        builder.configPreProcess(arguments);
        builder.configPostProcess(arguments);
        return builder;
    }

    protected DetectedObjects nms(int imageWidth, int imageHeight, List<Rectangle> boxes, List<Integer> classIds, List<Float> scores) {
        ArrayList<String> retClasses = new ArrayList<String>();
        ArrayList<Double> retProbs = new ArrayList<Double>();
        ArrayList<BoundingBox> retBB = new ArrayList<BoundingBox>();
        for (int classId = 0; classId < this.classes.size(); ++classId) {
            ArrayList<Rectangle> r = new ArrayList<Rectangle>();
            ArrayList<Double> s2 = new ArrayList<Double>();
            ArrayList<Integer> map = new ArrayList<Integer>();
            for (int j = 0; j < classIds.size(); ++j) {
                if (classIds.get(j) != classId) continue;
                r.add(boxes.get(j));
                s2.add(scores.get(j).doubleValue());
                map.add(j);
            }
            if (r.isEmpty()) continue;
            List<Integer> nms = Rectangle.nms(r, s2, this.nmsThreshold);
            for (int index : nms) {
                int pos = (Integer)map.get(index);
                int id = classIds.get(pos);
                retClasses.add((String)this.classes.get(id));
                retProbs.add(scores.get(pos).doubleValue());
                Rectangle rect = boxes.get(pos);
                if (this.removePadding) {
                    int padW = (this.width - imageWidth) / 2;
                    int padH = (this.height - imageHeight) / 2;
                    rect = new Rectangle((rect.getX() - (double)padW) / (double)imageWidth, (rect.getY() - (double)padH) / (double)imageHeight, rect.getWidth() / (double)imageWidth, rect.getHeight() / (double)imageHeight);
                } else if (this.applyRatio) {
                    rect = new Rectangle(rect.getX() / (double)this.width, rect.getY() / (double)this.height, rect.getWidth() / (double)this.width, rect.getHeight() / (double)this.height);
                }
                retBB.add(rect);
            }
        }
        return new DetectedObjects(retClasses, retProbs, retBB);
    }

    protected DetectedObjects processFromBoxOutput(int imageWidth, int imageHeight, NDList list) {
        float[] flattened = ((NDArray)list.get(0)).toFloatArray();
        int sizeClasses = this.classes.size();
        int stride = 5 + sizeClasses;
        int size = flattened.length / stride;
        ArrayList<Rectangle> boxes = new ArrayList<Rectangle>();
        ArrayList<Float> scores = new ArrayList<Float>();
        ArrayList<Integer> classIds = new ArrayList<Integer>();
        for (int i = 0; i < size; ++i) {
            int indexBase = i * stride;
            float maxClass = 0.0f;
            int maxIndex = 0;
            for (int c = 0; c < sizeClasses; ++c) {
                if (!(flattened[indexBase + c + 5] > maxClass)) continue;
                maxClass = flattened[indexBase + c + 5];
                maxIndex = c;
            }
            float score = maxClass * flattened[indexBase + 4];
            if (!(score > this.threshold)) continue;
            float xPos = flattened[indexBase];
            float yPos = flattened[indexBase + 1];
            float w = flattened[indexBase + 2];
            float h2 = flattened[indexBase + 3];
            Rectangle rect = new Rectangle(Math.max(0.0f, xPos - w / 2.0f), Math.max(0.0f, yPos - h2 / 2.0f), w, h2);
            boxes.add(rect);
            scores.add(Float.valueOf(score));
            classIds.add(maxIndex);
        }
        return this.nms(imageWidth, imageHeight, boxes, classIds, scores);
    }

    private DetectedObjects processFromDetectOutput() {
        throw new UnsupportedOperationException("detect layer output is not supported yet, check correct YoloV5 export format");
    }

    @Override
    public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
        int imageWidth = (Integer)ctx.getAttachment("width");
        int imageHeight = (Integer)ctx.getAttachment("height");
        switch (this.yoloOutputLayerType) {
            case DETECT: {
                return this.processFromDetectOutput();
            }
            case AUTO: {
                if (((NDArray)list.get(0)).getShape().dimension() > 2) {
                    return this.processFromDetectOutput();
                }
                return this.processFromBoxOutput(imageWidth, imageHeight, list);
            }
        }
        return this.processFromBoxOutput(imageWidth, imageHeight, list);
    }

    public static class Builder
    extends ObjectDetectionTranslator.ObjectDetectionBuilder<Builder> {
        YoloOutputType outputType = YoloOutputType.AUTO;
        float nmsThreshold = 0.4f;

        public Builder optOutputType(YoloOutputType outputType) {
            this.outputType = outputType;
            return this;
        }

        public Builder optNmsThreshold(float nmsThreshold) {
            this.nmsThreshold = nmsThreshold;
            return this;
        }

        @Override
        protected Builder self() {
            return this;
        }

        @Override
        protected void configPostProcess(Map<String, ?> arguments) {
            super.configPostProcess(arguments);
            String type = ArgumentsUtil.stringValue(arguments, "outputType", "AUTO");
            this.outputType = YoloOutputType.valueOf(type.toUpperCase(Locale.ENGLISH));
            this.nmsThreshold = ArgumentsUtil.floatValue(arguments, "nmsThreshold", 0.4f);
        }

        public YoloV5Translator build() {
            if (this.pipeline == null) {
                this.addTransform(array -> array.transpose(2, 0, 1).toType(DataType.FLOAT32, false).div(255));
            }
            this.validate();
            return new YoloV5Translator(this);
        }
    }

    public static enum YoloOutputType {
        BOX,
        DETECT,
        AUTO;

    }
}

