// Copyright HS Analysis GmbH, 2023
// Author: Valentin Haas

// HSA imports
import { isInt } from "../utils/Utils";
import Structure from "./Structure";

// #region Option sets for the AI training

// Shared data structures for the AI training
// Keep in sync with:
// Source\HSA-KIT\Models\AITrainingSettings.cs
// Source\HSA-KIT\modules\hsa\core\ai\models.py
// Source\HSA-KIT\ClientApp\src\common\components\AITrainingSettings.jsx

/**
 * All principal dataset creation approaches.
 * Image approaches currently use values below 10, audio approaches values
 * from 10 upward. The numeric values must stay in sync with the files listed
 * in the "Keep in sync with" comment at the top of this region.
 */
export const DatasetApproach = Object.freeze({
  ImageFileBased: 0,
  ImageObjectBased: 1,
  ImageSlidingWindow: 2,
  AudioFileBased: 10,
  AudioSlidingWindow: 11,
});

/**
 * Readable names for the principal dataset creation approaches, keyed by
 * DatasetApproach value. Image and audio variants of the same approach share
 * a display name (e.g. both file based approaches are shown as "File Based").
 */
export const DatasetApproachNames = Object.freeze({
  [DatasetApproach.ImageFileBased]: "File Based",
  [DatasetApproach.ImageObjectBased]: "Object Based",
  [DatasetApproach.ImageSlidingWindow]: "Sliding Window",
  [DatasetApproach.AudioFileBased]: "File Based",
  [DatasetApproach.AudioSlidingWindow]: "Sliding Window",
});

/**
 * Abbreviations for the principal dataset creation approaches, keyed by
 * DatasetApproach value. The prefix "i-" marks image approaches, "a-" marks
 * audio approaches.
 */
export const DatasetApproachShort = Object.freeze({
  [DatasetApproach.ImageFileBased]: "i-fb",
  [DatasetApproach.ImageObjectBased]: "i-ob",
  [DatasetApproach.ImageSlidingWindow]: "i-sw",
  [DatasetApproach.AudioFileBased]: "a-fb",
  [DatasetApproach.AudioSlidingWindow]: "a-sw",
});

/**
 * All training model (dataset) types.
 * Image types currently use values below 10, audio types values from 10
 * upward. The numeric values must stay in sync with the files listed in the
 * "Keep in sync with" comment at the top of this region.
 */
export const DatasetType = Object.freeze({
  ImageClassification: 0,
  ImageObjectDetection: 1,
  ImageSegmentation: 2,
  ImageInstanceSegmentation: 3,
  AudioClassification: 10,
  AudioSequenceDetection: 11,
  AudioSequenceSegmentation: 12,
  AudioSequenceInstanceSegmentation: 13,
});

/**
 * Readable names for the training model types, keyed by DatasetType value.
 * Image and audio variants of the same task share a display name.
 */
export const DatasetTypeNames = Object.freeze({
  [DatasetType.ImageClassification]: "Classification",
  [DatasetType.ImageObjectDetection]: "Object Detection",
  [DatasetType.ImageSegmentation]: "Segmentation",
  [DatasetType.ImageInstanceSegmentation]: "Instance Segmentation",
  [DatasetType.AudioClassification]: "Classification",
  // NOTE(review): audio sequence detection is labelled "Object Detection"
  // while its abbreviation is "a-sd" (sequence detection) — confirm intended.
  [DatasetType.AudioSequenceDetection]: "Object Detection",
  [DatasetType.AudioSequenceSegmentation]: "Segmentation",
  [DatasetType.AudioSequenceInstanceSegmentation]: "Instance Segmentation",
});

/**
 * Abbreviations for the training model types, keyed by DatasetType value.
 * The prefix "i-" marks image types, "a-" marks audio types.
 */
export const DatasetTypeShort = Object.freeze({
  [DatasetType.ImageClassification]: "i-cl",
  [DatasetType.ImageObjectDetection]: "i-od",
  [DatasetType.ImageSegmentation]: "i-seg",
  [DatasetType.ImageInstanceSegmentation]: "i-iseg",
  [DatasetType.AudioClassification]: "a-cl",
  [DatasetType.AudioSequenceDetection]: "a-sd",
  [DatasetType.AudioSequenceSegmentation]: "a-seg",
  [DatasetType.AudioSequenceInstanceSegmentation]: "a-iseg",
});

/**
 * A collection of all image dataset types.
 * Frozen like the other shared constants in this module so that consumers
 * cannot accidentally modify the shared list.
 */
export const ImageDatasetTypes = Object.freeze([
  DatasetType.ImageClassification,
  DatasetType.ImageObjectDetection,
  DatasetType.ImageSegmentation,
  DatasetType.ImageInstanceSegmentation,
]);

/**
 * A collection of all audio dataset types.
 * Frozen like the other shared constants in this module so that consumers
 * cannot accidentally modify the shared list.
 */
export const AudioDatasetTypes = Object.freeze([
  DatasetType.AudioClassification,
  DatasetType.AudioSequenceDetection,
  DatasetType.AudioSequenceSegmentation,
  DatasetType.AudioSequenceInstanceSegmentation,
]);

/**
 * A collection of all image dataset approaches.
 * Frozen like the other shared constants in this module so that consumers
 * cannot accidentally modify the shared list.
 */
export const ImageDatasetApproaches = Object.freeze([
  DatasetApproach.ImageFileBased,
  DatasetApproach.ImageObjectBased,
  DatasetApproach.ImageSlidingWindow,
]);

/**
 * A collection of all audio dataset approaches.
 * Frozen like the other shared constants in this module so that consumers
 * cannot accidentally modify the shared list.
 */
export const AudioDatasetApproaches = Object.freeze([
  DatasetApproach.AudioFileBased,
  DatasetApproach.AudioSlidingWindow,
]);

/**
 * All principal training data types.
 */
export const TrainingDataTypes = Object.freeze({
  Image: 0,
  Audio: 1,
});

/**
 * Mapping of module (project type) names to training data types.
 * Modules not listed here are treated as image modules by
 * trainingDataTypebyModuleName.
 */
const _trainingDataTypebyModuleName = Object.freeze({
  AudioAnnotator: TrainingDataTypes.Audio,
});

/**
 * Returns the training data type for a given module.
 * @param {string} projectType The name of the module currently loaded to get the training data type for.
 * @returns {TrainingDataTypes} The training data type for the given module.
 *   Defaults to TrainingDataTypes.Image for modules without an explicit mapping.
 */
export const trainingDataTypebyModuleName = (projectType = "") => {
  // Own-property lookup: never returns inherited Object.prototype members
  // (e.g. "constructor") and keeps mapped values that are falsy
  // (TrainingDataTypes.Image is 0), which a `||` fallback would discard.
  if (
    Object.prototype.hasOwnProperty.call(
      _trainingDataTypebyModuleName,
      projectType
    )
  ) {
    return _trainingDataTypebyModuleName[projectType];
  }
  return TrainingDataTypes.Image;
};

/**
 * All available metrics for training.
 */
export const Metrics = Object.freeze({
  Accuracy: 0,
});
/**
 * Readable names for the training metrics, keyed by Metrics value.
 */
export const MetricsNames = Object.freeze({
  [Metrics.Accuracy]: "Accuracy",
});

/**
 * All available loss functions for training.
 */
export const LossFunction = Object.freeze({
  CrossEntropy: 0,
  Dice: 1,
  CrossEntropyDice: 2,
});
/**
 * Readable names for the loss functions, keyed by LossFunction value.
 */
export const LossFunctionNames = Object.freeze({
  [LossFunction.CrossEntropy]: "Cross Entropy",
  [LossFunction.Dice]: "Dice",
  [LossFunction.CrossEntropyDice]: "Cross Entropy + Dice",
});

/**
 * All available optimizers for training.
 */
export const Optimizer = Object.freeze({
  Adam: 0,
  AdamW: 1,
  SGD: 2,
});
/**
 * Readable names for the optimizers, keyed by Optimizer value.
 */
export const OptimizerNames = Object.freeze({
  [Optimizer.Adam]: "Adam",
  [Optimizer.AdamW]: "AdamW",
  [Optimizer.SGD]: "SGD",
});

// #endregion

// #region Default values
/**
 * Default dataset types for project types whose default differs from the
 * general fallback used by DefaultDatasetType
 * (DatasetType.ImageInstanceSegmentation), keyed by project type name.
 */
const _DefaultDatasetTypeMappings = Object.freeze({
  AudioAnnotator: DatasetType.AudioSequenceDetection,
});

/**
 * Get a default model type for a given project type.
 * @param {string} projectType The type of the project to get the default model type for.
 * @returns {DatasetType} The mapped default model type for the given project type,
 *   or DatasetType.ImageInstanceSegmentation if the project type has no mapping.
 */
export const DefaultDatasetType = (projectType = "") => {
  // Own-property lookup so a mapping to a falsy enum value
  // (DatasetType.ImageClassification is 0) is honoured, and inherited
  // Object.prototype members (e.g. "toString") are never returned.
  if (
    Object.prototype.hasOwnProperty.call(
      _DefaultDatasetTypeMappings,
      projectType
    )
  ) {
    return _DefaultDatasetTypeMappings[projectType];
  }
  return DatasetType.ImageInstanceSegmentation;
};

/**
 * Default dataset approaches for project types whose default differs from the
 * general fallback used by DefaultDatasetApproach
 * (DatasetApproach.ImageSlidingWindow), keyed by project type name.
 */
const _DefaultDatasetApproachMappings = Object.freeze({
  AudioAnnotator: DatasetApproach.AudioFileBased,
});

/**
 * Get a default dataset approach for a given project type.
 * @param {string} projectType The type of the project to get the default dataset approach for.
 * @returns {DatasetApproach} The mapped default dataset approach for the given project type,
 *   or DatasetApproach.ImageSlidingWindow if the project type has no mapping.
 */
export const DefaultDatasetApproach = (projectType = "") => {
  // Own-property lookup so a mapping to a falsy enum value
  // (DatasetApproach.ImageFileBased is 0) is honoured, and inherited
  // Object.prototype members (e.g. "hasOwnProperty") are never returned.
  if (
    Object.prototype.hasOwnProperty.call(
      _DefaultDatasetApproachMappings,
      projectType
    )
  ) {
    return _DefaultDatasetApproachMappings[projectType];
  }
  return DatasetApproach.ImageSlidingWindow;
};

// #endregion

// #region Data structures for each individual settings page
/**
 * All parameters of AI Training that are not model specific.
 * @param {number} epochs Number of epochs to train. Integer > 0. Default: 1.
 * @param {number} earlyStopping Number of epochs to wait for early stopping. Integer >= 0. Default: 500.
 * @param {number} batchSize Batch size for training. Integer >= 1. Default: 2.
 * @param {Metrics[]} metrics Metrics to use for training. Non-empty array of Metrics values. Default: [Metrics.Accuracy].
 * @param {LossFunction[]} lossFunctions Loss functions to use for training. Non-empty array of LossFunction values.
 *   Default: [LossFunction.CrossEntropy, LossFunction.Dice].
 * @param {Optimizer} optimizer Optimizer to use for training. One of the Optimizer values. Default: Optimizer.AdamW.
 * @param {number} learningRate Learning rate to use for training. Finite number > 0. Default: 1e-4.
 * @throws {TypeError} If any parameter has an invalid type or value.
 */
export class TrainingParameters {
  constructor({
    epochs = 1,
    earlyStopping = 500,
    batchSize = 2,
    metrics = [Metrics.Accuracy],
    lossFunctions = [LossFunction.CrossEntropy, LossFunction.Dice],
    optimizer = Optimizer.AdamW,
    learningRate = 1e-4,
  } = {}) {
    // Input validation
    if (!isInt(epochs) || epochs <= 0)
      throw new TypeError(
        `epochs must be of type integer > 0, received ${typeof epochs}: ${epochs}`
      );
    if (!isInt(earlyStopping) || earlyStopping < 0)
      throw new TypeError(
        `earlyStopping must be of type integer >= 0, received ${typeof earlyStopping}: ${earlyStopping}`
      );
    if (!isInt(batchSize) || batchSize < 1)
      throw new TypeError(
        `batchSize must be of type integer >= 1, received ${typeof batchSize}: ${batchSize}`
      );
    if (!Array.isArray(metrics) || metrics.length === 0)
      throw new TypeError(
        `metrics must be of type array with at least one element, received ${typeof metrics}: ${metrics}`
      );
    for (const metric of metrics) {
      if (!isInt(metric) || !Object.values(Metrics).includes(metric))
        throw new TypeError(
          `metrics must be of type array with elements of type Metrics with one of the following values: ${Object.values(
            Metrics
          )}, received ${typeof metric}: ${metric}`
        );
    }

    if (!Array.isArray(lossFunctions) || lossFunctions.length === 0)
      throw new TypeError(
        `lossFunctions must be of type array with at least one element, received ${typeof lossFunctions}: ${lossFunctions}`
      );
    for (const lossFunction of lossFunctions) {
      if (
        !isInt(lossFunction) ||
        !Object.values(LossFunction).includes(lossFunction)
      )
        throw new TypeError(
          `lossFunctions must be of type array with elements of type LossFunction with one of the following values: ${Object.values(
            LossFunction
          )}, received ${typeof lossFunction}: ${lossFunction}`
        );
    }

    if (!isInt(optimizer) || !Object.values(Optimizer).includes(optimizer))
      throw new TypeError(
        `optimizer must be of type Optimizer with one of the following values: ${Object.values(
          Optimizer
        )}, received ${typeof optimizer}: ${optimizer}`
      );
    // Number.isFinite rejects non-numbers as well as NaN and +/-Infinity.
    // NaN in particular slips through a `typeof` check plus `<= 0` comparison.
    if (!Number.isFinite(learningRate) || learningRate <= 0)
      throw new TypeError(
        `learningRate must be of type number > 0, received ${typeof learningRate}: ${learningRate}`
      );

    this.epochs = epochs;
    this.earlyStopping = earlyStopping;
    this.batchSize = batchSize;
    this.metrics = metrics;
    this.lossFunctions = lossFunctions;
    this.optimizer = optimizer;
    this.learningRate = learningRate;
  }
}

/**
 * All meta data of an AI Training.
 * @param {string} name Name of the AI Model. Default: "".
 * @param {string} version Version of the AI Model. Default: "".
 * @param {string} description Description of the AI Model. Default: "".
 * @param {boolean} isNewModel Whether the model is new or not. Default: true.
 * @throws {TypeError} If any parameter has the wrong type.
 */
export class ModelMetaData {
  constructor({
    name = "",
    version = "",
    description = "",
    isNewModel = true,
  } = {}) {
    // Error messages name the actual option keys so callers can locate the
    // offending field.
    if (typeof name !== "string")
      throw new TypeError(
        `name must be of type string, received ${typeof name}: ${name}`
      );
    if (typeof version !== "string")
      throw new TypeError(
        `version must be of type string, received ${typeof version}: ${version}`
      );
    if (typeof description !== "string")
      throw new TypeError(
        `description must be of type string, received ${typeof description}: ${description}`
      );
    if (typeof isNewModel !== "boolean")
      throw new TypeError(
        `isNewModel must be of type boolean, received ${typeof isNewModel}: ${isNewModel}`
      );
    this.name = name;
    this.version = version;
    this.description = description;
    this.isNewModel = isNewModel;
  }
}

/**
 * All dataset specific parameters of an AI Training.
 * @param {DatasetApproach} datasetApproach Approach to create the dataset. Must be one of the values of DatasetApproach. Default: ImageSlidingWindow.
 * @param {DatasetType} datasetType Type of the dataset used. Must be one of the values of DatasetType. Default: ImageClassification.
 * @param {boolean} datasetOnly Whether only the dataset should be created or not, without training. Default: false.
 * @param {boolean} useExistingDataset Whether an existing dataset should be used or not for training. Default: false.
 * @throws {TypeError} If any parameter has an invalid type or value.
 */
export class DatasetParameters {
  constructor({
    datasetApproach = DatasetApproach.ImageSlidingWindow,
    datasetType = DatasetType.ImageClassification,
    datasetOnly = false,
    useExistingDataset = false,
  } = {}) {
    if (
      !isInt(datasetApproach) ||
      !Object.values(DatasetApproach).includes(datasetApproach)
    )
      throw new TypeError(
        `datasetApproach must be of type DatasetApproach with one of the following values: ${Object.values(
          DatasetApproach
        )}, received ${typeof datasetApproach}: ${datasetApproach}`
      );
    if (
      !isInt(datasetType) ||
      !Object.values(DatasetType).includes(datasetType)
    )
      throw new TypeError(
        `datasetType must be of type DatasetType with one of the following values: ${Object.values(
          DatasetType
        )}, received ${typeof datasetType}: ${datasetType}`
      );
    if (typeof datasetOnly !== "boolean")
      throw new TypeError(
        `datasetOnly must be of type boolean, received ${typeof datasetOnly}: ${datasetOnly}`
      );
    if (typeof useExistingDataset !== "boolean")
      throw new TypeError(
        `useExistingDataset must be of type boolean, received ${typeof useExistingDataset}: ${useExistingDataset}`
      );
    this.datasetApproach = datasetApproach;
    this.datasetType = datasetType;
    this.datasetOnly = datasetOnly;
    this.useExistingDataset = useExistingDataset;
  }
}

/**
 * General base class for all dataset specific (model) parameters of an
 * AI Training. It carries no fields itself; ImageDatasetMeta and
 * AudioDatasetMeta extend it, and a plain instance is used when no specific
 * parameters apply.
 */
export class DatasetMeta {}

/**
 * All model specific parameters of an image model AI Training.
 * @param {DatasetType} datasetType Type of the dataset used. Must be one of the values of DatasetType. Default: ImageClassification.
 * @param {number} inputChannels Number of input channels. Can be 1 for grayscale images, 3 for RGB images, or more for multispectral images.
 *   Integer >= 1. Default: 3.
 * @param {Array} fluorescenceChannels List of fluorescence channels. Default: [].
 * @param {number} spatialDims Number of spatial dimensions. Can be 2 for 2D images or 3 for 3D images. Default: 2.
 * @param {number} imageWidth Width of a dataset image. Integer >= 1. Default: 512.
 * @param {number} imageHeight Height of a dataset image. Integer >= 1. Default: 512.
 * @param {number} numberOfClasses Number of classes to predict. Integer >= 2. Default: 2.
 * @param {number} pyramidLevel Image pyramid level to use. Integer >= -1. Default: -1
 *   (NOTE(review): -1 presumably means "not fixed" — confirm against the backend).
 * @throws {TypeError} If any parameter has an invalid type or value.
 */
export class ImageDatasetMeta extends DatasetMeta {
  constructor({
    datasetType = DatasetType.ImageClassification,
    inputChannels = 3,
    fluorescenceChannels = [],
    spatialDims = 2,
    imageWidth = 512,
    imageHeight = 512,
    numberOfClasses = 2,
    pyramidLevel = -1,
  } = {}) {
    super();

    // Input validation
    if (
      !isInt(datasetType) ||
      !Object.values(DatasetType).includes(datasetType)
    )
      throw new TypeError(
        `datasetType must be of type DatasetType with one of the following values: ${Object.values(
          DatasetType
        )}, received ${typeof datasetType}: ${datasetType}`
      );
    if (!isInt(inputChannels) || inputChannels < 1)
      throw new TypeError(
        `inputChannels must be of type integer >= 1, received ${typeof inputChannels}: ${inputChannels}`
      );
    if (!Array.isArray(fluorescenceChannels))
      throw new TypeError(
        `fluorescenceChannels must be of type array, received ${typeof fluorescenceChannels}: ${fluorescenceChannels}`
      );
    if (!isInt(spatialDims) || spatialDims < 2 || spatialDims > 3)
      throw new TypeError(
        `spatialDims must be of type integer >= 2 and <= 3, received ${typeof spatialDims}: ${spatialDims}`
      );
    if (!isInt(imageWidth) || imageWidth < 1)
      throw new TypeError(
        `imageWidth must be of type integer >= 1, received ${typeof imageWidth}: ${imageWidth}`
      );
    if (!isInt(imageHeight) || imageHeight < 1)
      throw new TypeError(
        `imageHeight must be of type integer >= 1, received ${typeof imageHeight}: ${imageHeight}`
      );
    if (!isInt(numberOfClasses) || numberOfClasses < 2)
      throw new TypeError(
        `numberOfClasses must be of type integer >= 2, received ${typeof numberOfClasses}: ${numberOfClasses}`
      );
    if (!isInt(pyramidLevel) || pyramidLevel < -1)
      throw new TypeError(
        `pyramidLevel must be of type integer >= -1, received ${typeof pyramidLevel}: ${pyramidLevel}`
      );

    // Set parameters
    this.datasetType = datasetType;
    this.inputChannels = inputChannels;
    this.fluorescenceChannels = fluorescenceChannels;
    this.spatialDims = spatialDims;
    this.imageWidth = imageWidth;
    this.imageHeight = imageHeight;
    this.numberOfClasses = numberOfClasses;
    this.pyramidLevel = pyramidLevel;
  }
}

/**
 * All model specific parameters of an audio model AI Training.
 * @param {DatasetType} datasetType Type of the dataset used. Must be one of the values of DatasetType. Default: AudioClassification.
 * @param {number} sequenceLengthSeconds Length of an audio sequence in seconds. 0 for full length. Finite number >= 0. Default: 0.
 * @param {number} sequenceOverlapSeconds Overlap of audio sequences in seconds. Negative values for gaps between sequences.
 *   Finite number. Default: 0.
 * @param {number} numberOfClasses Number of classes to predict. Integer >= 1. Default: 1.
 * @throws {TypeError} If any parameter has an invalid type or value.
 */
export class AudioDatasetMeta extends DatasetMeta {
  constructor({
    datasetType = DatasetType.AudioClassification,
    sequenceLengthSeconds = 0,
    sequenceOverlapSeconds = 0,
    numberOfClasses = 1,
  } = {}) {
    super();

    // Input validation
    if (
      !isInt(datasetType) ||
      !Object.values(DatasetType).includes(datasetType)
    )
      throw new TypeError(
        `datasetType must be of type DatasetType with one of the following values: ${Object.values(
          DatasetType
        )}, received ${typeof datasetType}: ${datasetType}`
      );
    // Number.isFinite rejects non-numbers as well as NaN and +/-Infinity,
    // which a plain `typeof` check plus range comparison lets through.
    if (!Number.isFinite(sequenceLengthSeconds) || sequenceLengthSeconds < 0)
      throw new TypeError(
        `sequenceLengthSeconds must be of type number >= 0, received ${typeof sequenceLengthSeconds}: ${sequenceLengthSeconds}`
      );
    if (!Number.isFinite(sequenceOverlapSeconds))
      throw new TypeError(
        `sequenceOverlapSeconds must be of type number, received ${typeof sequenceOverlapSeconds}: ${sequenceOverlapSeconds}`
      );

    if (!isInt(numberOfClasses) || numberOfClasses < 1)
      throw new TypeError(
        `numberOfClasses must be of type integer >= 1, received ${typeof numberOfClasses}: ${numberOfClasses}`
      );

    // Set parameters
    this.datasetType = datasetType;
    this.sequenceLengthSeconds = sequenceLengthSeconds;
    this.sequenceOverlapSeconds = sequenceOverlapSeconds;
    this.numberOfClasses = numberOfClasses;
  }
}

/**
 * Factories for the default model parameters of each dataset type.
 * The options object `{ datasetType }` must be passed explicitly: the
 * constructors destructure an options object, so passing the bare enum value
 * would silently fall back to their default dataset type.
 * Factories (instead of shared instances) give every caller its own object,
 * so modifying returned parameters cannot leak into later calls.
 */
const _modelParameterFactories = Object.freeze({
  [DatasetType.ImageClassification]: () =>
    new ImageDatasetMeta({ datasetType: DatasetType.ImageClassification }),
  [DatasetType.ImageObjectDetection]: () =>
    new ImageDatasetMeta({ datasetType: DatasetType.ImageObjectDetection }),
  [DatasetType.ImageSegmentation]: () =>
    new ImageDatasetMeta({ datasetType: DatasetType.ImageSegmentation }),
  [DatasetType.ImageInstanceSegmentation]: () =>
    new ImageDatasetMeta({
      datasetType: DatasetType.ImageInstanceSegmentation,
    }),
  [DatasetType.AudioClassification]: () =>
    new AudioDatasetMeta({ datasetType: DatasetType.AudioClassification }),
  [DatasetType.AudioSequenceDetection]: () =>
    new AudioDatasetMeta({ datasetType: DatasetType.AudioSequenceDetection }),
  [DatasetType.AudioSequenceSegmentation]: () =>
    new AudioDatasetMeta({
      datasetType: DatasetType.AudioSequenceSegmentation,
    }),
  [DatasetType.AudioSequenceInstanceSegmentation]: () =>
    new AudioDatasetMeta({
      datasetType: DatasetType.AudioSequenceInstanceSegmentation,
    }),
});

/**
 * Returns freshly created default model parameters for a given dataset type.
 * @param {DatasetType} datasetType The type of the dataset to get the model parameters for.
 * @returns {DatasetMeta} A new ImageDatasetMeta for image dataset types, a new
 *   AudioDatasetMeta for audio dataset types (each with `datasetType` set to
 *   the requested type), or a plain DatasetMeta for unknown types.
 */
export const getModelParameters = (datasetType) => {
  // Own-property check so inherited Object.prototype members are never
  // mistaken for a factory.
  if (
    !Object.prototype.hasOwnProperty.call(_modelParameterFactories, datasetType)
  ) {
    return new DatasetMeta();
  }
  return _modelParameterFactories[datasetType]();
};

// #endregion

/**
 * All settings of an AI Training.
 * @param {string} projectId ID of the project to train on. Expected to be a UUID; currently not validated (see TODO below).
 * @param {string} projectType Type of the project (module name, e.g. "AudioAnnotator").
 * @param {ModelMetaData} [metaData] Meta data of the AI Model. Default: new ModelMetaData().
 * @param {DatasetParameters} [datasetParameters] Dataset specific parameters of the AI Training. Default: new DatasetParameters().
 * @param {TrainingParameters} [trainingParameters] Training specific parameters of the AI Training. Default: new TrainingParameters().
 * @param {DatasetMeta} [modelParameters] Model specific parameters of the AI Training. Default: new DatasetMeta().
 * @param {Structure[]} [structures] Structures to train on. Default: [].
 * @throws {TypeError} If projectType is not a string or any other argument has the wrong type.
 */
export default class AITrainingSettings {
  constructor(
    projectId,
    projectType,
    metaData = new ModelMetaData(),
    datasetParameters = new DatasetParameters(),
    trainingParameters = new TrainingParameters(),
    modelParameters = new DatasetMeta(),
    structures = []
  ) {
    // Input validation
    // TODO: validate projectId. A UUID check (via an `isUuid` helper) was
    // planned but no such helper is imported in this module.
    if (typeof projectType !== "string")
      throw new TypeError(
        `projectType must be of type string, received ${typeof projectType}: ${projectType}`
      );
    if (!(metaData instanceof ModelMetaData))
      throw new TypeError(
        `metaData must be of type ModelMetaData, received ${typeof metaData}: ${metaData}`
      );
    if (!(datasetParameters instanceof DatasetParameters))
      throw new TypeError(
        `datasetParameters must be of type DatasetParameters, received ${typeof datasetParameters}: ${datasetParameters}`
      );
    if (!(trainingParameters instanceof TrainingParameters))
      throw new TypeError(
        `trainingParameters must be of type TrainingParameters, received ${typeof trainingParameters}: ${trainingParameters}`
      );
    if (!(modelParameters instanceof DatasetMeta))
      throw new TypeError(
        `modelParameters must be of type DatasetMeta, received ${typeof modelParameters}: ${modelParameters}`
      );
    if (!Array.isArray(structures))
      throw new TypeError(
        `structures must be of type array, received ${typeof structures}: ${structures}`
      );

    for (const structure of structures) {
      if (!(structure instanceof Structure))
        throw new TypeError(
          `structures must be of type array with elements of type Structure, received ${typeof structure}: ${structure}`
        );
    }

    // Set parameters
    this.projectId = projectId;
    this.projectType = projectType;
    this.metaData = metaData;
    this.datasetParameters = datasetParameters;
    this.trainingParameters = trainingParameters;
    this.modelParameters = modelParameters;
    this.structures = structures;
  }
}
