/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.engine.rust;

import ai.djl.Device;
import ai.djl.engine.Engine;
import ai.djl.engine.rust.RsNDArray;
import ai.djl.engine.rust.RustLibrary;
import ai.djl.ndarray.BaseNDManager;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.Charset;

/**
 * {@code RsNDManager} is the Rust engine implementation of {@link NDManager}.
 *
 * <p>Each factory method forwards to a {@link RustLibrary} native call, passing this manager's
 * device type and device id, and wraps the returned native handle in an {@link RsNDArray} owned
 * by this manager.
 */
public class RsNDManager extends BaseNDManager {

    private static final RsNDManager SYSTEM_MANAGER = new SystemManager();

    private RsNDManager(NDManager parent, Device device) {
        super(parent, device);
    }

    /**
     * Returns the singleton system manager of the Rust engine.
     *
     * @return the shared system manager instance
     */
    static RsNDManager getSystemManager() {
        return SYSTEM_MANAGER;
    }

    /** {@inheritDoc} */
    @Override
    public ByteBuffer allocateDirect(int capacity) {
        // Native byte order so typed puts (putFloat, putLong, ...) produce the layout the native
        // side presumably reads directly — confirm against RustLibrary.
        return ByteBuffer.allocateDirect(capacity).order(ByteOrder.nativeOrder());
    }

    /**
     * {@inheritDoc}
     *
     * <p>{@code null} and arrays that already are {@link RsNDArray}s are returned unchanged; any
     * other array is copied through its byte buffer into a new array owned by this manager, and
     * its name is carried over.
     */
    @Override
    public RsNDArray from(NDArray array) {
        if (array == null || array instanceof RsNDArray) {
            return (RsNDArray) array;
        }
        RsNDArray result = create(array.toByteBuffer(), array.getShape(), array.getDataType());
        result.setName(array.getName());
        return result;
    }

    /**
     * {@inheritDoc}
     *
     * <p>The returned array is zero-filled (created through {@code RustLibrary.zeros}).
     */
    @Override
    public RsNDArray create(Shape shape, DataType dataType) {
        String deviceType = device.getDeviceType();
        int deviceId = device.getDeviceId();
        int dType = toRustDataType(dataType);
        long handle = RustLibrary.zeros(shape.getShape(), dType, deviceType, deviceId);
        return new RsNDArray(this, handle, dataType);
    }

    /**
     * {@inheritDoc}
     *
     * <p>A direct {@link ByteBuffer} is passed to native code as-is; any other buffer is first
     * copied into a newly allocated direct buffer.
     *
     * @throws ArithmeticException if the element count or byte size does not fit in an {@code int}
     */
    @Override
    public RsNDArray create(Buffer data, Shape shape, DataType dataType) {
        int size = Math.toIntExact(shape.size());
        BaseNDManager.validateBuffer(data, dataType, size);

        ByteBuffer buf;
        if (data.isDirect() && data instanceof ByteBuffer) {
            buf = (ByteBuffer) data;
        } else {
            // multiplyExact: fail loudly rather than allocate a wrapped-around byte count.
            buf = allocateDirect(Math.multiplyExact(size, dataType.getNumOfBytes()));
            copyBuffer(data, buf);
        }

        String deviceType = device.getDeviceType();
        int deviceId = device.getDeviceId();
        int dType = toRustDataType(dataType);
        long handle = RustLibrary.tensorOf(buf, shape.getShape(), dType, deviceType, deviceId);
        // The buffer is also handed to the array, presumably to keep it reachable while the
        // native tensor may still reference it — confirm in RsNDArray.
        return new RsNDArray(this, handle, dataType, buf);
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException always; string arrays are not implemented
     */
    @Override
    public NDArray create(String[] data, Charset charset, Shape shape) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException always; sparse COO arrays are not implemented
     */
    @Override
    public NDArray createCoo(Buffer data, long[][] indices, Shape shape) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray zeros(Shape shape, DataType dataType) {
        return create(shape, dataType);
    }

    /** {@inheritDoc} */
    @Override
    public NDArray ones(Shape shape, DataType dataType) {
        String deviceType = device.getDeviceType();
        int deviceId = device.getDeviceId();
        int dType = toRustDataType(dataType);
        long handle = RustLibrary.ones(shape.getShape(), dType, deviceType, deviceId);
        return new RsNDArray(this, handle, dataType);
    }

    /** {@inheritDoc} */
    @Override
    public NDArray full(Shape shape, float value, DataType dataType) {
        String deviceType = device.getDeviceType();
        int deviceId = device.getDeviceId();
        int dType = toRustDataType(dataType);
        long handle = RustLibrary.full(value, shape.getShape(), dType, deviceType, deviceId);
        return new RsNDArray(this, handle, dataType);
    }

    /** {@inheritDoc} */
    @Override
    public NDArray arange(int start, int stop, int step, DataType dataType) {
        return arange(start, stop, step, dataType, device);
    }

    /**
     * {@inheritDoc}
     *
     * <p>When the sign of {@code stop - start} differs from the sign of {@code step}, an empty
     * array is returned without calling native code.
     */
    @Override
    public NDArray arange(float start, float stop, float step, DataType dataType) {
        if (Math.signum(stop - start) != Math.signum(step)) {
            return create(new Shape(0L), dataType, device);
        }
        String deviceType = device.getDeviceType();
        int deviceId = device.getDeviceId();
        int dType = toRustDataType(dataType);
        long handle = RustLibrary.arange(start, stop, step, dType, deviceType, deviceId);
        return new RsNDArray(this, handle, dataType);
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException if {@code k != 0} or {@code rows != cols}, neither of
     *     which this engine supports
     */
    @Override
    public NDArray eye(int rows, int cols, int k, DataType dataType) {
        if (k != 0) {
            throw new UnsupportedOperationException("index of the diagonal is not supported in Rust");
        }
        if (rows != cols) {
            throw new UnsupportedOperationException("rows must equals to columns in Rust");
        }
        String deviceType = device.getDeviceType();
        int deviceId = device.getDeviceId();
        int dType = toRustDataType(dataType);
        long handle = RustLibrary.eye(rows, cols, dType, deviceType, deviceId);
        return new RsNDArray(this, handle, dataType);
    }

    /**
     * {@inheritDoc}
     *
     * <p>The result is always {@link DataType#FLOAT32}.
     *
     * @throws UnsupportedOperationException if {@code endpoint} is {@code false}
     */
    @Override
    public NDArray linspace(float start, float stop, int num, boolean endpoint) {
        if (!endpoint) {
            throw new UnsupportedOperationException("endpoint only support true");
        }
        String deviceType = device.getDeviceType();
        int deviceId = device.getDeviceId();
        int dType = toRustDataType(DataType.FLOAT32);
        long handle = RustLibrary.linspace(start, stop, num, dType, deviceType, deviceId);
        return new RsNDArray(this, handle, DataType.FLOAT32);
    }

    /**
     * {@inheritDoc}
     *
     * <p>The requested {@code dataType} is forwarded to native code and used to tag the result,
     * matching {@link #randomUniform} and {@link #randomNormal}.
     */
    @Override
    public NDArray randomInteger(long low, long high, Shape shape, DataType dataType) {
        long[] sh = shape.getShape();
        String deviceType = device.getDeviceType();
        int deviceId = device.getDeviceId();
        // Previously hard-coded to FLOAT32, silently ignoring the caller's dataType.
        int dType = toRustDataType(dataType);
        long handle = RustLibrary.randint(low, high, sh, dType, deviceType, deviceId);
        return new RsNDArray(this, handle, dataType);
    }

    /** {@inheritDoc} */
    @Override
    public NDArray randomPermutation(long n) {
        String deviceType = device.getDeviceType();
        int deviceId = device.getDeviceId();
        long handle = RustLibrary.randomPermutation(n, deviceType, deviceId);
        return new RsNDArray(this, handle);
    }

    /** {@inheritDoc} */
    @Override
    public NDArray randomUniform(float low, float high, Shape shape, DataType dataType) {
        long[] sh = shape.getShape();
        String deviceType = device.getDeviceType();
        int deviceId = device.getDeviceId();
        int dType = toRustDataType(dataType);
        long handle = RustLibrary.uniform(low, high, sh, dType, deviceType, deviceId);
        return new RsNDArray(this, handle, dataType);
    }

    /** {@inheritDoc} */
    @Override
    public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataType) {
        long[] sh = shape.getShape();
        String deviceType = device.getDeviceType();
        int deviceId = device.getDeviceId();
        int dType = toRustDataType(dataType);
        long handle = RustLibrary.randomNormal(loc, scale, sh, dType, deviceType, deviceId);
        return new RsNDArray(this, handle, dataType);
    }

    /** {@inheritDoc} */
    @Override
    public NDArray hanningWindow(long numPoints) {
        String deviceType = device.getDeviceType();
        int deviceId = device.getDeviceId();
        long handle = RustLibrary.hannWindow(numPoints, deviceType, deviceId);
        return new RsNDArray(this, handle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>The new manager has this manager as its parent and is registered with it through
     * {@code attachUncappedInternal}.
     */
    @Override
    public RsNDManager newSubManager(Device device) {
        RsNDManager manager = new RsNDManager(this, device);
        attachUncappedInternal(manager.uid, manager);
        return manager;
    }

    /** {@inheritDoc} */
    @Override
    public final Engine getEngine() {
        return Engine.getEngine("Rust");
    }

    /**
     * Maps a DJL {@link DataType} to the integer code passed to {@link RustLibrary}.
     *
     * <p>The code is the ordinal of a DJL {@code DataType}. {@code BOOLEAN} and {@code INT8} are
     * sent as {@code UINT8}, and {@code INT32} as {@code UINT32} (presumably because the Rust
     * backend lacks those types — confirm). Callers still tag the resulting {@link RsNDArray} with
     * the originally requested type.
     *
     * @param dataType the DJL data type to map
     * @return the ordinal code understood by the native library
     * @throws UnsupportedOperationException if the data type has no Rust equivalent
     */
    int toRustDataType(DataType dataType) {
        switch (dataType) {
            case BOOLEAN:
            case INT8:
                return DataType.UINT8.ordinal();
            case INT32:
                return DataType.UINT32.ordinal();
            case FLOAT16:
            case BFLOAT16:
            case FLOAT32:
            case FLOAT64:
            case UINT8:
            case UINT32:
            case INT64:
                return dataType.ordinal();
            default:
                throw new UnsupportedOperationException("Unsupported data type: " + dataType);
        }
    }

    /**
     * The root manager: created with no parent and a {@code null} device (BaseNDManager
     * presumably substitutes the engine default — confirm), and marked as a
     * {@link NDManager.SystemNDManager}.
     */
    private static final class SystemManager extends RsNDManager
            implements NDManager.SystemNDManager {

        SystemManager() {
            super(null, null);
        }
    }
}

