/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.transformer;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.Dropout;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;

/**
 * Multi-head scaled dot-product attention block.
 *
 * <p>Keys, queries and values are each passed through their own {@link Linear} projection with
 * {@code embeddingSize} output units, split into {@code headCount} heads, combined via
 * {@code softmax(Q * K^T / sqrt(headSize)) * V}, merged back into one embedding per position and
 * finally passed through a result projection.
 *
 * <p>Accepted inputs of the forward pass (see {@code forwardInternal}):
 * <ul>
 *   <li>1 input: self-attention, the single array is used as key, query and value</li>
 *   <li>2 inputs: self-attention plus an attention mask as last element</li>
 *   <li>3 inputs: key, query, value (in that order)</li>
 *   <li>4 inputs: key, query, value plus an attention mask as last element</li>
 * </ul>
 */
public final class ScaledDotProductAttentionBlock
extends AbstractBlock {
    private static final byte VERSION = 1;
    private final int embeddingSize;
    private final int headCount;
    private final Linear keyProjection;
    private final Linear queryProjection;
    private final Linear valueProjection;
    private final Linear resultProjection;
    private final Dropout attentionProbsDropout;

    /**
     * Creates the block from a builder and registers the four projections and the dropout as
     * child blocks.
     *
     * @param builder the builder carrying embedding size, head count and dropout rate
     */
    private ScaledDotProductAttentionBlock(Builder builder) {
        super(VERSION);
        this.embeddingSize = builder.embeddingSize;
        this.headCount = builder.headCount;
        // buildProjection() reads embeddingSize, so it must be assigned before these calls.
        this.keyProjection = this.addChildBlock("keyProjection", this.buildProjection());
        this.queryProjection = this.addChildBlock("queryProjection", this.buildProjection());
        this.valueProjection = this.addChildBlock("valueProjection", this.buildProjection());
        this.resultProjection = this.addChildBlock("resultProjection", this.buildProjection());
        this.attentionProbsDropout = this.addChildBlock("probabilityDropout", Dropout.builder().optRate(builder.attentionProbsDropoutProb).build());
    }

    /**
     * Builds a linear layer with bias and {@code embeddingSize} output units.
     *
     * @return a new, not yet initialized projection
     */
    private Linear buildProjection() {
        return Linear.builder().setUnits(this.embeddingSize).optBias(true).build();
    }

    /**
     * Returns the projection applied to the key input.
     *
     * @return the key projection
     */
    public Linear getKeyProjection() {
        return this.keyProjection;
    }

    /**
     * Returns the projection applied to the query input.
     *
     * @return the query projection
     */
    public Linear getQueryProjection() {
        return this.queryProjection;
    }

    /**
     * Returns the projection applied to the value input.
     *
     * @return the value projection
     */
    public Linear getValueProjection() {
        return this.valueProjection;
    }

    /**
     * Returns the projection applied to the merged attention result.
     *
     * @return the result projection
     */
    public Linear getResultProjection() {
        return this.resultProjection;
    }

    /**
     * Returns the output shape for the given input shapes.
     *
     * <p>With one or two inputs (self-attention, optionally masked) the output has the shape of
     * the first input. With three or four inputs the output has the shape of the second input,
     * which the forward pass uses as the query.
     *
     * @param inputShapes the shapes of the inputs, 1 to 4 entries
     * @return a single-element array containing the output shape
     * @throws IllegalArgumentException if the number of input shapes is not between 1 and 4
     */
    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        if (inputShapes.length == 1 || inputShapes.length == 2) {
            return new Shape[]{inputShapes[0]};
        }
        if (inputShapes.length == 3 || inputShapes.length == 4) {
            return new Shape[]{inputShapes[1]};
        }
        throw new IllegalArgumentException("Invalid number of input shapes: " + inputShapes.length + ", must be 1-4.");
    }

    /**
     * Initializes every registered child block with a {@code (-1, embeddingSize)} input shape.
     *
     * <p>Note that the {@code dataType} and {@code inputShapes} arguments are not used: children
     * are always initialized with {@link DataType#FLOAT32}.
     *
     * @param manager the manager used for initialization
     * @param dataType ignored by this implementation
     * @param inputShapes ignored by this implementation
     */
    @Override
    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape ... inputShapes) {
        // The first dimension is left open (-1); only the trailing embedding dimension is fixed.
        Shape projectionShape = new Shape(-1L, this.embeddingSize);
        for (Block child : this.children.values()) {
            child.initialize(manager, DataType.FLOAT32, projectionShape);
        }
    }

    /**
     * Splits projected embeddings into attention heads.
     *
     * @param projection the projected embeddings, reshapeable to (batch, sequence, heads, headSize)
     * @param batch the batch size
     * @param sequence the sequence length
     * @param heads the number of attention heads
     * @param headSize the size of one attention head
     * @return the input rearranged to (batch, heads, sequence, headSize)
     */
    private NDArray createAttentionHeadsFromEmbeddings(NDArray projection, long batch, long sequence, long heads, long headSize) {
        return projection.reshape(batch, sequence, heads, headSize).transpose(0, 2, 1, 3);
    }

    /**
     * Converts an attention mask into an additive offset for the attention scores.
     *
     * <p>A mask that already has four dimensions is returned unchanged. Any other mask is
     * reshaped to (batch, 1, toLength, fromLength), cast to float and mapped through
     * {@code (1 - mask) * -100000}, i.e. for a 0/1 mask a 1 becomes 0 and a 0 becomes -100000.
     *
     * @param attentionMask the mask supplied by the caller
     * @param batch the batch size
     * @param toLength the query sequence length
     * @param fromLength the key sequence length
     * @return the offset to add to the scaled attention scores
     */
    private static NDArray toMaskOffset(NDArray attentionMask, long batch, long toLength, long fromLength) {
        if (attentionMask.getShape().dimension() == 4) {
            return attentionMask;
        }
        // The singleton second dimension lets the mask broadcast over all heads.
        NDArray expandedMask = attentionMask.reshape(batch, 1L, toLength, fromLength);
        return expandedMask.toType(DataType.FLOAT32, false)
                .mul(expandedMask.getManager().create(-1.0f))
                .add(expandedMask.getManager().create(1.0f))
                .mul(expandedMask.getManager().create(-100000.0f));
    }

    /**
     * Applies multi-head scaled dot-product attention.
     *
     * @param parameterStore the parameter store
     * @param inputs 1 to 4 arrays: either a single array used as key, query and value, or key,
     *     query and value in that order; in both cases optionally followed by an attention mask.
     *     Batch size is read from dimension 0 and sequence length from dimension 1.
     * @param training true for a training forward pass
     * @param params optional parameters forwarded to the key, query and value projections
     * @return a list holding the output of the result projection
     * @throws IllegalArgumentException if the number of inputs is not between 1 and 4
     */
    @Override
    protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
        int inputCount = inputs.size();
        // Mirror the contract of getOutputShapes instead of failing obscurely on an empty list
        // or silently ignoring surplus inputs.
        if (inputCount < 1 || inputCount > 4) {
            throw new IllegalArgumentException("Invalid number of inputs: " + inputCount + ", must be 1-4.");
        }
        long embedding = this.embeddingSize;
        long batch = inputs.head().getShape().get(0);
        long heads = this.headCount;
        // The builder guarantees embeddingSize is divisible by headCount.
        long headSize = embedding / heads;

        long fromLength;
        long toLength;
        NDList keyInput;
        NDList queryInput;
        NDList valueInput;
        if (inputCount < 3) {
            // Self-attention: one array serves as key, query and value.
            fromLength = inputs.head().getShape().get(1);
            toLength = fromLength;
            keyInput = new NDList(inputs.head());
            queryInput = keyInput;
            valueInput = keyInput;
        } else {
            // Dedicated key, query and value arrays.
            fromLength = inputs.get(0).getShape().get(1);
            toLength = inputs.get(1).getShape().get(1);
            keyInput = new NDList(inputs.get(0));
            queryInput = new NDList(inputs.get(1));
            valueInput = new NDList(inputs.get(2));
        }
        // An even number of inputs means the last one is the attention mask.
        NDArray attentionMask = null;
        if (inputCount == 2 || inputCount == 4) {
            attentionMask = inputs.get(inputCount - 1);
        }

        NDList keys = this.keyProjection.forward(parameterStore, keyInput, training, params);
        NDList queries = this.queryProjection.forward(parameterStore, queryInput, training, params);
        NDList values = this.valueProjection.forward(parameterStore, valueInput, training, params);

        // Each of these is arranged as (batch, heads, sequence, headSize).
        NDArray keyHeads = this.createAttentionHeadsFromEmbeddings(keys.head(), batch, fromLength, heads, headSize);
        NDArray queryHeads = this.createAttentionHeadsFromEmbeddings(queries.head(), batch, toLength, heads, headSize);
        NDArray valueHeads = this.createAttentionHeadsFromEmbeddings(values.head(), batch, fromLength, heads, headSize);

        // Query times transposed key, then scaled by 1 / sqrt(headSize).
        NDArray attentionScores = queryHeads.matMul(keyHeads.transpose(0, 1, 3, 2));
        NDArray scaledScores = attentionScores.mul(attentionScores.getManager().create(1.0f / (float)Math.sqrt(headSize)));
        if (attentionMask != null) {
            scaledScores = scaledScores.add(toMaskOffset(attentionMask, batch, toLength, fromLength));
        }

        // Softmax over the last axis turns the scores into attention weights.
        NDArray attentionProbs = scaledScores.softmax(3);
        NDArray attentionProbsAfterDropout = this.attentionProbsDropout.forward(parameterStore, new NDList(attentionProbs), training).singletonOrThrow();
        NDArray attentionResult = attentionProbsAfterDropout.matMul(valueHeads);

        // Swap the head and sequence axes back and merge the heads into (batch, toLength, embedding).
        NDArray resultEmbeddings = attentionResult.transpose(0, 2, 1, 3).reshape(batch, toLength, embedding);
        NDList projectedEmbeddings = this.resultProjection.forward(parameterStore, new NDList(resultEmbeddings), training);
        return new NDList(projectedEmbeddings);
    }

    /**
     * Creates a new builder for this block.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Builder for {@link ScaledDotProductAttentionBlock}. Embedding size and head count are
     * required; the attention dropout probability defaults to 0.1.
     */
    public static final class Builder {
        private int embeddingSize;
        private int headCount;
        private float attentionProbsDropoutProb = 0.1f;

        private Builder() {
        }

        /**
         * Sets the embedding size, which is also the number of units of every projection.
         *
         * @param embeddingSize the embedding size, must be at least 1 and divisible by the head count
         * @return this builder
         */
        public Builder setEmbeddingSize(int embeddingSize) {
            this.embeddingSize = embeddingSize;
            return this;
        }

        /**
         * Sets the number of attention heads.
         *
         * @param headCount the head count, must be at least 1
         * @return this builder
         */
        public Builder setHeadCount(int headCount) {
            this.headCount = headCount;
            return this;
        }

        /**
         * Sets the rate of the dropout applied to the attention probabilities.
         *
         * @param attentionProbsDropoutProb the dropout rate
         * @return this builder
         */
        public Builder optAttentionProbsDropoutProb(float attentionProbsDropoutProb) {
            this.attentionProbsDropoutProb = attentionProbsDropoutProb;
            return this;
        }

        /**
         * Validates the configuration and creates the block.
         *
         * @return a new attention block
         * @throws IllegalStateException if the embedding size or head count is smaller than 1, or
         *     if the embedding size is not divisible by the head count
         */
        public ScaledDotProductAttentionBlock build() {
            if (this.embeddingSize < 1) {
                throw new IllegalStateException("Embedding size not initialized.");
            }
            if (this.headCount < 1) {
                throw new IllegalStateException("Head count not initialized.");
            }
            if (this.embeddingSize % this.headCount != 0) {
                throw new IllegalStateException("Embedding Size (" + this.embeddingSize + ") is not divisible by head count (" + this.headCount + ")");
            }
            return new ScaledDotProductAttentionBlock(this);
        }
    }
}

