/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.cv.output;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.util.JsonSerializable;
import ai.djl.util.JsonUtils;
import ai.djl.util.RandomUtils;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.util.List;
import java.util.stream.Collectors;

/**
 * A per-pixel class-index mask together with the list of class names the indices refer to.
 *
 * <p>Each entry {@code mask[row][column]} is an index into {@link #getClasses()}. Index {@code 0}
 * is treated as the background class by {@link #getBackgroundImage(Image)} and by the random
 * coloring used in {@link #drawMask(Image, int, int)}.
 */
public class CategoryMask implements JsonSerializable {

    private static final long serialVersionUID = 1L;

    /** Opaque black in packed ARGB form; the default background color for {@link #drawMask}. */
    private static final int COLOR_BLACK = 0xFF000000;

    private final List<String> classes;
    private final int[][] mask;

    /**
     * Constructs a {@code CategoryMask}.
     *
     * @param classes the class names; a mask value {@code i} refers to {@code classes.get(i)}
     * @param mask the per-pixel class indices, indexed as {@code mask[row][column]}; the rows are
     *     assumed to all have the same length as {@code mask[0]}
     */
    public CategoryMask(List<String> classes, int[][] mask) {
        this.classes = classes;
        this.mask = mask;
    }

    /**
     * Returns the class names. The returned list is the instance's own reference, not a copy.
     *
     * @return the class names
     */
    public List<String> getClasses() {
        return this.classes;
    }

    /**
     * Returns the per-pixel class indices as {@code mask[row][column]}. The returned array is the
     * instance's own reference, not a copy.
     *
     * @return the class-index mask
     */
    public int[][] getMask() {
        return this.mask;
    }

    /**
     * Serializes this mask to a JSON object with the keys {@code "classes"} and {@code "mask"}.
     *
     * @return the JSON representation
     */
    @Override
    public JsonElement serialize() {
        JsonObject ret = new JsonObject();
        ret.add("classes", JsonUtils.GSON.toJsonTree(this.classes));
        ret.add("mask", JsonUtils.GSON.toJsonTree(this.mask));
        return ret;
    }

    /**
     * Returns a JSON-formatted description of the classes and the compact mask.
     *
     * <p>Class names are escaped as JSON string literals so names containing quotes, backslashes
     * or control characters still produce valid JSON.
     *
     * @return a JSON string
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(4096);
        String list = this.classes.stream()
                .map(CategoryMask::quote)
                .collect(Collectors.joining(", "));
        sb.append("{\n\t\"classes\": [").append(list).append("],\n\t\"mask\": ");
        sb.append(JsonUtils.GSON_COMPACT.toJson(this.mask));
        sb.append("\n}");
        return sb.toString();
    }

    /**
     * Applies the full class-index mask to {@code image} via {@link Image#getMask(int[][])}.
     *
     * @param image the source image
     * @return the image produced by {@code image.getMask(mask)}
     */
    public Image getMaskImage(Image image) {
        return image.getMask(this.mask);
    }

    /**
     * Applies a binary mask selecting only the pixels of class {@code classId} to {@code image}.
     *
     * <p>The binary mask holds {@code 1} where the class index equals {@code classId} and
     * {@code 0} elsewhere; it is passed to {@link Image#getMask(int[][])}.
     *
     * @param image the source image
     * @param classId the class index to select
     * @return the image produced by {@code image.getMask(selected)}
     */
    public Image getMaskImage(Image image, int classId) {
        int height = this.mask.length;
        int width = maskWidth();
        int[][] selected = new int[height][width];
        for (int row = 0; row < height; ++row) {
            for (int col = 0; col < width; ++col) {
                selected[row][col] = this.mask[row][col] == classId ? 1 : 0;
            }
        }
        return image.getMask(selected);
    }

    /**
     * Applies a binary mask selecting the background class (index {@code 0}) to {@code image}.
     *
     * @param image the source image
     * @return the image produced by {@link #getMaskImage(Image, int)} with class {@code 0}
     */
    public Image getBackgroundImage(Image image) {
        return this.getMaskImage(image, 0);
    }

    /**
     * Draws a randomly colored overlay of all classes onto {@code image}, with the background
     * class painted opaque black.
     *
     * @param image the image to draw onto
     * @param opacity the alpha value (0-255) applied to the non-background class colors
     */
    public void drawMask(Image image, int opacity) {
        this.drawMask(image, opacity, COLOR_BLACK);
    }

    /**
     * Draws a randomly colored overlay of all classes onto {@code image}.
     *
     * @param image the image to draw onto
     * @param opacity the alpha value (0-255) applied to the non-background class colors
     * @param background the packed ARGB color used for class {@code 0}; used as-is, i.e. its own
     *     alpha is not replaced by {@code opacity}
     */
    public void drawMask(Image image, int opacity, int background) {
        int[] colors = this.generateColors(background, opacity);
        Image maskImage = this.getColorOverlay(colors);
        // NOTE(review): meaning of the boolean flag is defined by Image#drawImage — confirm.
        image.drawImage(maskImage, true);
    }

    /**
     * Draws a single class onto {@code image} in the given color; every other class is left
     * with packed value {@code 0} (zero alpha).
     *
     * @param image the image to draw onto
     * @param classId the class index to draw
     * @param color the RGB color; any alpha bits in it are discarded
     * @param opacity the alpha value (0-255) for the drawn class
     * @throws IllegalArgumentException if {@code classId} is not a valid class index
     */
    public void drawMask(Image image, int classId, int color, int opacity) {
        int classCount = this.classes.size();
        if (classId < 0 || classId >= classCount) {
            throw new IllegalArgumentException(
                    "classId " + classId + " is out of range [0, " + classCount + ")");
        }
        int[] colors = new int[classCount];
        // Keep the RGB bits of color and replace its alpha byte with opacity.
        colors[classId] = (color & 0xFFFFFF) | (opacity << 24);
        Image colorOverlay = this.getColorOverlay(colors);
        // NOTE(review): meaning of the boolean flag is defined by Image#drawImage — confirm.
        image.drawImage(colorOverlay, true);
    }

    /**
     * Builds an image where each pixel is {@code colors[mask[row][column]]}.
     *
     * @param colors packed ARGB color per class index
     * @return the overlay image created by {@link ImageFactory#fromPixels}
     * @throws IllegalStateException if a mask value has no corresponding color
     */
    private Image getColorOverlay(int[] colors) {
        int height = this.mask.length;
        int width = maskWidth();
        int[] pixels = new int[width * height];
        for (int row = 0; row < height; ++row) {
            int[] maskRow = this.mask[row];
            int offset = row * width;
            for (int col = 0; col < width; ++col) {
                int classIndex = maskRow[col];
                if (classIndex < 0 || classIndex >= colors.length) {
                    throw new IllegalStateException(
                            "Mask value "
                                    + classIndex
                                    + " at ("
                                    + row
                                    + ", "
                                    + col
                                    + ") is out of range [0, "
                                    + colors.length
                                    + ")");
                }
                pixels[offset + col] = colors[classIndex];
            }
        }
        return ImageFactory.getInstance().fromPixels(pixels, width, height);
    }

    /**
     * Generates one packed ARGB color per class: {@code background} for class {@code 0} and a
     * random RGB color with alpha {@code opacity} for every other class.
     *
     * @param background the color for class {@code 0}, used as-is
     * @param opacity the alpha value for the random colors
     * @return the color table, one entry per class (empty if there are no classes)
     */
    private int[] generateColors(int background, int opacity) {
        int[] colors = new int[this.classes.size()];
        if (colors.length == 0) {
            return colors;
        }
        colors[0] = background;
        for (int i = 1; i < colors.length; ++i) {
            int red = RandomUtils.nextInt(256);
            int green = RandomUtils.nextInt(256);
            int blue = RandomUtils.nextInt(256);
            colors[i] = opacity << 24 | red << 16 | green << 8 | blue;
        }
        return colors;
    }

    /** Returns the mask width taken from the first row, or 0 when the mask has no rows. */
    private int maskWidth() {
        return this.mask.length == 0 ? 0 : this.mask[0].length;
    }

    /** Quotes {@code value} as a JSON string literal (RFC 8259 escaping). */
    private static String quote(String value) {
        StringBuilder sb = new StringBuilder(value.length() + 2).append('"');
        for (int i = 0; i < value.length(); ++i) {
            char c = value.charAt(i);
            switch (c) {
                case '"':
                    sb.append("\\\"");
                    break;
                case '\\':
                    sb.append("\\\\");
                    break;
                case '\n':
                    sb.append("\\n");
                    break;
                case '\r':
                    sb.append("\\r");
                    break;
                case '\t':
                    sb.append("\\t");
                    break;
                default:
                    if (c < 0x20) {
                        // Other control characters must be written as \\uXXXX escapes.
                        String hex = Integer.toHexString(c);
                        sb.append("\\u");
                        for (int pad = hex.length(); pad < 4; ++pad) {
                            sb.append('0');
                        }
                        sb.append(hex);
                    } else {
                        sb.append(c);
                    }
                    break;
            }
        }
        return sb.append('"').toString();
    }
}

