/*
 * Decompiled with CFR 0.152.
 */
package ai.djl;

import ai.djl.engine.Engine;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * A compute device identified by a type string (for example {@code "cpu"} or {@code "gpu"}) and an
 * integer id.
 *
 * <p>Instances are obtained through the static factories ({@link #of}, {@link #cpu()},
 * {@link #gpu()}, {@link #gpu(int)}, {@link #fromName(String)}). There is a single CPU instance.
 * Other devices are cached per type/id pair, so repeated lookups return the same object.
 */
public class Device {
    /** Shared instances for non-CPU devices, keyed by {@code "<type>-<id>"}. */
    private static final Map<String, Device> CACHE = new ConcurrentHashMap<String, Device>();
    // NOTE: CACHE must be initialized before GPU, because Device.of reads it.
    private static final Device CPU = new Device(Type.CPU, -1);
    private static final Device GPU = Device.of(Type.GPU, 0);
    /** A lowercase device type followed by an optional numeric id, e.g. {@code "cpu"} or {@code "gpu1"}. */
    private static final Pattern DEVICE_NAME = Pattern.compile("([a-z]+)([0-9]*)");
    protected String deviceType;
    protected int deviceId;

    private Device(String deviceType, int deviceId) {
        this.deviceType = deviceType;
        this.deviceId = deviceId;
    }

    /**
     * Returns the device with the given type and id.
     *
     * <p>A {@code "cpu"} type always yields the shared CPU instance, and {@code deviceId} is ignored.
     * Any other type/id pair is created once and then served from a cache.
     *
     * @param deviceType the device type, e.g. {@link Type#GPU}
     * @param deviceId the device id
     * @return the (possibly cached) device
     */
    public static Device of(String deviceType, int deviceId) {
        if (Type.CPU.equals(deviceType)) {
            return CPU;
        }
        String key = deviceType + '-' + deviceId;
        return CACHE.computeIfAbsent(key, k -> new Device(deviceType, deviceId));
    }

    /**
     * Parses a device name using the engine returned by {@code Engine.getInstance()}.
     *
     * @param deviceName the name to parse; see {@link #fromName(String, Engine)} for the syntax
     * @return the parsed device
     * @throws IllegalArgumentException if the name cannot be parsed
     */
    public static Device fromName(String deviceName) {
        return Device.fromName(deviceName, Engine.getInstance());
    }

    /**
     * Parses a device name.
     *
     * <p>Accepted forms:
     *
     * <ul>
     *   <li>{@code null} or empty: the engine's default device;
     *   <li>{@code "a+b+..."}: a {@link MultiDevice} made of each parsed part;
     *   <li>a lowercase type with an optional id, e.g. {@code "gpu1"} or {@code "cpu"} (a missing id
     *       becomes {@code -1});
     *   <li>a bare integer: the CPU when negative, otherwise the GPU with that id.
     * </ul>
     *
     * @param deviceName the name to parse
     * @param engine the engine supplying the default device for an empty name
     * @return the parsed device
     * @throws IllegalArgumentException if the name matches none of the forms above, or if a numeric
     *     id does not fit in an {@code int}
     */
    public static Device fromName(String deviceName, Engine engine) {
        if (deviceName == null || deviceName.isEmpty()) {
            return engine.defaultDevice();
        }
        if (deviceName.contains("+")) {
            // Each "+"-separated part is parsed recursively with the same engine.
            List<Device> subDevices = Arrays.stream(deviceName.split("\\+"))
                    .map(part -> Device.fromName(part, engine))
                    .collect(Collectors.toList());
            return new MultiDevice(subDevices);
        }
        Matcher matcher = DEVICE_NAME.matcher(deviceName);
        if (matcher.matches()) {
            String idText = matcher.group(2);
            int deviceId = idText.isEmpty() ? -1 : parseId(idText, deviceName);
            return Device.of(matcher.group(1), deviceId);
        }
        int deviceId = parseId(deviceName, deviceName);
        if (deviceId < 0) {
            return Device.cpu();
        }
        return Device.gpu(deviceId);
    }

    /**
     * Parses {@code text} as an {@code int}. On failure it reports the whole {@code deviceName} and
     * keeps the original exception as the cause.
     */
    private static int parseId(String text, String deviceName) {
        try {
            return Integer.parseInt(text);
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException("Failed to parse device name: " + deviceName, e);
        }
    }

    /**
     * Returns the device type string.
     *
     * @return the device type
     */
    public String getDeviceType() {
        return this.deviceType;
    }

    /**
     * Returns the device id ({@code -1} for the CPU).
     *
     * @return the device id
     */
    public int getDeviceId() {
        return this.deviceId;
    }

    /**
     * Returns whether this device's type is {@link Type#GPU}.
     *
     * @return {@code true} for a GPU device
     */
    public boolean isGpu() {
        return Type.GPU.equals(this.deviceType);
    }

    /**
     * Returns the devices this device is made of; a plain device contains only itself.
     *
     * @return an immutable single-element list holding {@code this}
     */
    public List<Device> getDevices() {
        return Collections.singletonList(this);
    }

    @Override
    public String toString() {
        if (Type.CPU.equals(this.deviceType)) {
            return this.deviceType + "()";
        }
        return this.deviceType + '(' + this.deviceId + ')';
    }

    /** CPU devices compare by type only. All other devices compare by type and id. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        Device device = (Device) o;
        if (Type.CPU.equals(this.deviceType)) {
            return Objects.equals(this.deviceType, device.deviceType);
        }
        return this.deviceId == device.deviceId && Objects.equals(this.deviceType, device.deviceType);
    }

    @Override
    public int hashCode() {
        // equals() ignores the id for CPU devices, so the hash must ignore it as well.
        if (Type.CPU.equals(this.deviceType)) {
            return Objects.hash(this.deviceType);
        }
        return Objects.hash(this.deviceType, this.deviceId);
    }

    /**
     * Returns the shared CPU device.
     *
     * @return the CPU device
     */
    public static Device cpu() {
        return CPU;
    }

    /**
     * Returns GPU 0.
     *
     * @return the GPU device with id 0
     */
    public static Device gpu() {
        return GPU;
    }

    /**
     * Returns the GPU with the given id.
     *
     * @param deviceId the GPU id
     * @return the GPU device
     */
    public static Device gpu(int deviceId) {
        return Device.of(Type.GPU, deviceId);
    }

    /** String constants for the built-in device types. */
    public static interface Type {
        public static final String CPU = "cpu";
        public static final String GPU = "gpu";
    }

    /**
     * A device made up of several sub-devices.
     *
     * <p>Its type is the sorted sub-devices joined with {@code "+"}, for example
     * {@code "gpu0+gpu1"}. Its id is always {@code -1}.
     */
    public static class MultiDevice
    extends Device {
        List<Device> devices;

        /**
         * Creates a multi-device over ids {@code [startInclusive, endExclusive)} of one type.
         *
         * @param deviceType the type of every sub-device
         * @param startInclusive the first id
         * @param endExclusive one past the last id
         */
        public MultiDevice(String deviceType, int startInclusive, int endExclusive) {
            this(IntStream.range(startInclusive, endExclusive)
                    .mapToObj(i -> Device.of(deviceType, i))
                    .collect(Collectors.toList()));
        }

        /**
         * Creates a multi-device from the given sub-devices. The argument array is not modified.
         *
         * @param devices the sub-devices
         */
        public MultiDevice(Device ... devices) {
            this(Arrays.asList(devices));
        }

        /**
         * Creates a multi-device from the given sub-devices.
         *
         * <p>The sub-devices are ordered by type (case-insensitively) and then by id. The ordering is
         * done on a private copy, so {@code devices} is never modified and may be immutable.
         *
         * @param devices the sub-devices
         */
        public MultiDevice(List<Device> devices) {
            super(null, -1);
            List<Device> sorted = devices.stream()
                    .sorted(Comparator.comparing(Device::getDeviceType, String.CASE_INSENSITIVE_ORDER)
                            .thenComparingInt(Device::getDeviceId))
                    .collect(Collectors.toList());
            this.deviceType = sorted.stream()
                    .map(d -> d.getDeviceType() + d.getDeviceId())
                    .collect(Collectors.joining("+"));
            // Unmodifiable, so that nobody can change the list after deviceType was derived from it.
            this.devices = Collections.unmodifiableList(sorted);
        }

        /**
         * Returns the sorted sub-devices.
         *
         * @return an unmodifiable list of the sub-devices
         */
        @Override
        public List<Device> getDevices() {
            return this.devices;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            if (!super.equals(o)) {
                return false;
            }
            MultiDevice that = (MultiDevice) o;
            return Objects.equals(this.devices, that.devices);
        }

        @Override
        public int hashCode() {
            return Objects.hash(super.hashCode(), this.devices);
        }

        @Override
        public String toString() {
            return this.deviceType + "()";
        }
    }
}

