/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.listener;

import ai.djl.Model;
import ai.djl.training.Trainer;
import ai.djl.training.listener.TrainingListenerAdapter;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.function.Consumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link TrainingListenerAdapter} that saves the trainer's model to an output directory.
 *
 * <p>When {@code checkpoint} is positive, the model is saved every {@code checkpoint} epochs and,
 * if the most recent epoch did not already trigger a save, once more when training ends. When
 * {@code checkpoint} is zero or negative, periodic saving is disabled and the model is saved only
 * when training ends.
 */
public class SaveModelTrainingListener extends TrainingListenerAdapter {
    private static final Logger logger = LoggerFactory.getLogger(SaveModelTrainingListener.class);

    /** Directory the model is written to; never {@code null} (checked in the constructor). */
    private final String outputDir;
    /** Name to save the model under instead of {@link Model#getName()}, or {@code null}. */
    private String overrideModelName;
    /** Optional hook invoked with the trainer just before each save, or {@code null}. */
    private Consumer<Trainer> onSaveModel;
    /** Save interval in epochs; values {@code <= 0} disable periodic saving. */
    private int checkpoint;
    /** Number of epochs observed through {@link #onEpoch(Trainer)}; never reset. */
    private int epoch;

    /**
     * Creates a listener that saves the model only when training ends.
     *
     * @param outputDir the directory to save the model to
     * @throws IllegalArgumentException if {@code outputDir} is {@code null}
     */
    public SaveModelTrainingListener(String outputDir) {
        this(outputDir, null, -1);
    }

    /**
     * Creates a listener that saves the model, under the given name, only when training ends.
     *
     * @param outputDir the directory to save the model to
     * @param overrideModelName the name to save the model under, or {@code null} to use the
     *     model's own name
     * @throws IllegalArgumentException if {@code outputDir} is {@code null}
     */
    public SaveModelTrainingListener(String outputDir, String overrideModelName) {
        this(outputDir, overrideModelName, -1);
    }

    /**
     * Creates a listener that saves the model periodically and at the end of training.
     *
     * @param outputDir the directory to save the model to
     * @param overrideModelName the name to save the model under, or {@code null} to use the
     *     model's own name
     * @param checkpoint save every {@code checkpoint} epochs; a value {@code <= 0} saves only
     *     when training ends
     * @throws IllegalArgumentException if {@code outputDir} is {@code null}
     */
    public SaveModelTrainingListener(String outputDir, String overrideModelName, int checkpoint) {
        if (outputDir == null) {
            throw new IllegalArgumentException(
                    "Can not save checkpoint without specifying an output directory");
        }
        this.outputDir = outputDir;
        this.overrideModelName = overrideModelName;
        this.checkpoint = checkpoint;
    }

    /**
     * Counts the finished epoch and saves the model if it falls on a checkpoint boundary.
     *
     * @param trainer the trainer whose model may be saved
     */
    @Override
    public void onEpoch(Trainer trainer) {
        ++epoch;
        if (isCheckpointEpoch()) {
            saveModel(trainer);
        }
    }

    /**
     * Saves the final model unless the last epoch already triggered a checkpoint save.
     *
     * @param trainer the trainer whose model is saved
     */
    @Override
    public void onTrainingEnd(Trainer trainer) {
        // The original code evaluated epoch % checkpoint for any checkpoint != -1, which threw
        // ArithmeticException for checkpoint == 0. It also skipped the final save for other
        // non-positive values, so the model was never saved at all.
        if (!isCheckpointEpoch()) {
            saveModel(trainer);
        }
    }

    /**
     * Returns the name the model is saved under, or {@code null} if the model's own name is used.
     *
     * @return the override model name, possibly {@code null}
     */
    public String getOverrideModelName() {
        return overrideModelName;
    }

    /**
     * Sets the name the model is saved under.
     *
     * @param overrideModelName the name to use, or {@code null} to use the model's own name
     */
    public void setOverrideModelName(String overrideModelName) {
        this.overrideModelName = overrideModelName;
    }

    /**
     * Returns the save interval in epochs.
     *
     * @return the checkpoint interval; values {@code <= 0} mean periodic saving is disabled
     */
    public int getCheckpoint() {
        return checkpoint;
    }

    /**
     * Sets the save interval in epochs.
     *
     * @param checkpoint save every {@code checkpoint} epochs; a value {@code <= 0} saves only
     *     when training ends
     */
    public void setCheckpoint(int checkpoint) {
        this.checkpoint = checkpoint;
    }

    /**
     * Sets a callback invoked with the trainer immediately before each save.
     *
     * @param onSaveModel the callback, or {@code null} to disable it
     */
    public void setSaveModelCallback(Consumer<Trainer> onSaveModel) {
        this.onSaveModel = onSaveModel;
    }

    /**
     * Saves the trainer's model to the output directory.
     *
     * <p>The method performs these steps in order:
     *
     * <ol>
     *   <li>Records the current epoch count as the model property {@code "Epoch"}.
     *   <li>Invokes the save callback, if one is set.
     *   <li>Saves the model under the override name when one is set, otherwise under the model's
     *       own name.
     * </ol>
     *
     * <p>An {@link IOException} raised while saving is logged rather than propagated to the
     * caller.
     *
     * @param trainer the trainer whose model is saved
     */
    protected void saveModel(Trainer trainer) {
        Model model = trainer.getModel();
        String modelName = overrideModelName != null ? overrideModelName : model.getName();
        try {
            model.setProperty("Epoch", String.valueOf(epoch));
            if (onSaveModel != null) {
                onSaveModel.accept(trainer);
            }
            model.save(Paths.get(outputDir), modelName);
        } catch (IOException e) {
            logger.error("Failed to save checkpoint", e);
        }
    }

    /**
     * Returns whether the most recently counted epoch triggered a periodic save.
     *
     * <p>This is false when periodic saving is disabled ({@code checkpoint <= 0}) or no epoch has
     * been counted yet. The check on {@code checkpoint} also guards the remainder against division
     * by zero.
     */
    private boolean isCheckpointEpoch() {
        return checkpoint > 0 && epoch > 0 && epoch % checkpoint == 0;
    }
}

