/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.cv.translator;

import ai.djl.modality.cv.output.Joints;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.translator.BaseImageTranslator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.TranslatorContext;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * A {@link BaseImageTranslator} that decodes the raw output of a YOLO pose model into one
 * {@link Joints} instance per detection.
 *
 * <p>Post-processing steps:
 * <ol>
 *   <li>Filter candidates by a score threshold.</li>
 *   <li>De-duplicate them with non-maximum suppression.</li>
 *   <li>Cap the result at {@code MAX_DETECTION} detections.</li>
 *   <li>Divide keypoint coordinates by the translator's configured width/height.</li>
 * </ol>
 */
public class YoloPoseTranslator extends BaseImageTranslator<Joints[]> {

    /** Upper bound on the number of detections returned after NMS. */
    private static final int MAX_DETECTION = 300;

    /** Number of values stored per keypoint: x, y and confidence. */
    private static final int VALUES_PER_KEYPOINT = 3;

    private final float threshold;
    private final float nmsThreshold;

    /**
     * Constructs a translator from a configured {@link Builder}.
     *
     * @param builder the builder holding the pre/post-processing configuration
     */
    public YoloPoseTranslator(Builder builder) {
        super(builder);
        this.threshold = builder.threshold;
        this.nmsThreshold = builder.nmsThreshold;
    }

    /**
     * Converts the model output into detected poses.
     *
     * <p>NOTE(review): the indexing below assumes the single output array is laid out as
     * {@code (4 + 1 + 3 * K, N)}:
     * <ul>
     *   <li>rows 0-3 are a box as center-x, center-y, width, height;</li>
     *   <li>row 4 is the detection score;</li>
     *   <li>the remaining rows are {@code K} keypoints stored as (x, y, confidence) triples.</li>
     * </ul>
     * Confirm this against the exported model.
     *
     * @param ctx the translator context
     * @param list the model output; must contain exactly one array
     * @return one {@link Joints} per detection kept after thresholding and NMS (at most
     *     {@code MAX_DETECTION}), in the order returned by {@link Rectangle#nms}
     */
    @Override
    public Joints[] processOutput(TranslatorContext ctx, NDList list) {
        NDArray pred = list.singletonOrThrow();
        // Boolean mask over the candidates, computed on the score row before transposing.
        NDArray candidates = pred.get(4L).gt(threshold);
        pred = pred.transpose();
        NDArray corners = xywh2xyxy(pred.get("..., :4"));
        pred = corners.concat(pred.get("..., 4:"), -1);
        pred = pred.get(candidates);

        // Columns [0, 4) = box corners, [4, 5) = score, [5, end) = keypoint values.
        NDList split = pred.split(new long[] {4L, 5L}, 1);
        NDArray boxArray = split.get(0);
        NDArray keypointArray = split.get(2);
        int numBox = Math.toIntExact(boxArray.getShape().get(0));
        // Derive the keypoint layout from the tensor rather than hard-coding 17 keypoints
        // (51 values), so models with a different keypoint count decode correctly.
        int keypointStride = Math.toIntExact(keypointArray.getShape().get(1));
        int numKeypoints = keypointStride / VALUES_PER_KEYPOINT;

        float[] buf = boxArray.toFloatArray();
        float[] confidences = split.get(1).toFloatArray();
        float[] keypoints = keypointArray.toFloatArray();

        List<Rectangle> boxes = new ArrayList<>(numBox);
        List<Double> scores = new ArrayList<>(numBox);
        for (int i = 0; i < numBox; ++i) {
            float x1 = buf[i * 4];
            float y1 = buf[i * 4 + 1];
            float x2 = buf[i * 4 + 2];
            float y2 = buf[i * 4 + 3];
            boxes.add(new Rectangle(x1, y1, x2 - x1, y2 - y1));
            scores.add((double) confidences[i]);
        }

        List<Integer> kept = Rectangle.nms(boxes, scores, nmsThreshold);
        int count = Math.min(kept.size(), MAX_DETECTION);
        Joints[] ret = new Joints[count];
        for (int i = 0; i < count; ++i) {
            ret[i] = toJoints(keypoints, kept.get(i) * keypointStride, numKeypoints);
        }
        return ret;
    }

    /**
     * Builds one pose from the (x, y, confidence) triples that start at {@code offset}.
     *
     * @param keypoints flattened keypoint values for all candidates
     * @param offset index of the first value belonging to this detection
     * @param numKeypoints number of keypoint triples to read
     * @return the decoded pose, with x and y divided by the configured width and height
     */
    private Joints toJoints(float[] keypoints, int offset, int numKeypoints) {
        List<Joints.Joint> joints = new ArrayList<>(numKeypoints);
        for (int j = 0; j < numKeypoints; ++j) {
            int base = offset + j * VALUES_PER_KEYPOINT;
            joints.add(
                    new Joints.Joint(
                            keypoints[base] / (float) width,
                            keypoints[base + 1] / (float) height,
                            keypoints[base + 2]));
        }
        return new Joints(joints);
    }

    /**
     * Converts boxes from (center-x, center-y, width, height) to (x1, y1, x2, y2) along the last
     * axis.
     *
     * @param array boxes with 4 values in the last dimension
     * @return boxes as top-left and bottom-right corners
     */
    private NDArray xywh2xyxy(NDArray array) {
        NDArray xy = array.get("..., :2");
        NDArray halfWh = array.get("..., 2:").div(2);
        return xy.sub(halfWh).concat(xy.add(halfWh), -1);
    }

    /**
     * Creates a builder with default settings.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Creates a builder configured from the given arguments.
     *
     * <p>The builder's own post-processing reads the {@code "threshold"} and
     * {@code "nmsThreshold"} keys.
     *
     * @param arguments the translator arguments
     * @return a configured builder
     */
    public static Builder builder(Map<String, ?> arguments) {
        Builder builder = new Builder();
        builder.configPreProcess(arguments);
        builder.configPostProcess(arguments);
        return builder;
    }

    /** Builder for {@link YoloPoseTranslator}. */
    public static class Builder extends BaseImageTranslator.BaseBuilder<Builder> {

        float threshold = 0.25f;
        float nmsThreshold = 0.7f;

        Builder() {}

        /**
         * Sets the minimum score a candidate needs to be kept (default 0.25).
         *
         * @param threshold the score threshold
         * @return this builder
         */
        public Builder optThreshold(float threshold) {
            this.threshold = threshold;
            return self();
        }

        /**
         * Sets the IoU threshold used by non-maximum suppression (default 0.7).
         *
         * @param nmsThreshold the NMS threshold
         * @return this builder
         */
        public Builder optNmsThreshold(float nmsThreshold) {
            this.nmsThreshold = nmsThreshold;
            return self();
        }

        /** {@inheritDoc} */
        @Override
        protected Builder self() {
            return this;
        }

        /** {@inheritDoc} */
        @Override
        protected void configPostProcess(Map<String, ?> arguments) {
            optThreshold(ArgumentsUtil.floatValue(arguments, "threshold", threshold));
            optNmsThreshold(ArgumentsUtil.floatValue(arguments, "nmsThreshold", nmsThreshold));
        }

        /**
         * Validates the configuration and builds the translator.
         *
         * @return a new {@link YoloPoseTranslator}
         */
        public YoloPoseTranslator build() {
            validate();
            return new YoloPoseTranslator(this);
        }
    }
}

