import React, { useEffect, useCallback, useMemo, useState } from 'react';
import { Box, CircularProgress, Tooltip } from '@material-ui/core';
import { useAtom, useSetAtom } from 'jotai';
import { cloneDeep } from 'lodash';
import { useSnackbar } from 'notistack';

import { Button } from '@clef/client-library';
import {
  TrainMode,
  SelectMediaOption,
  MediaStatusType,
  LabelType,
  DatasetLimits,
} from '@clef/shared/types';

import ProjectApi from '@/api/project_api';
import { MIN_LABELED_MEDIA_FOR_FAST_N_EASY_TRAIN } from '@/constants/model_train';
import { defaultSelectOptions } from '@/constants/data_browser';
import { useTypedSelector } from '@/hooks/useTypedSelector';
import { ClientFeatures, useFeatureGateEnabled } from '@/hooks/useFeatureGate';
import { getModelDefaultArchNameByLabelType } from '@/utils/job_train_utils';
import { pendoEntity } from '@/utils/pendo';
import { supportedAutoSplit } from '@/components/AutoSplitDialog/utils';
import { assignmentToDistributionToAssignSplitMapping } from '@clef/shared/utils/auto_split_core_algorithm';
import { useTrainModel } from '@/pages/DataBrowser/TrainModelButtonGroup/state';
import CustomTrainAlertDialog from '@/pages/DataBrowser/TrainModelButtonGroup/CustomTrainAlertDialog';
import { useGetSelectedProjectQuery } from '@/serverStore/projects';
import { useProjectModelInfoQuery } from '@/serverStore/projectModels';
import { useGetDatasetFilterOptionsQuery, useAutoSplitMutation } from '@/serverStore/dataset';
import {
  currentStepAtom,
  CustomTrainingStepMode,
  selectedDatasetVersionAtom,
  modelsConfigListAtom,
  defaultModelConfig,
  useResetAndBackToBuild,
  isInProgressAtom,
  enableAdjustSplitAtom,
  defectDistributionWithAssignmentAtom,
  autoFocusNameModelAtom,
  configErrorsAtom,
  DatasetTrainingIssueType,
  datasetTrainingIssuesAtom,
  duplicateNameIdsAtom,
  showLargeImagesWarningAtom,
} from '@/uiStates/customTraining/pageUIStates';
import { generateRandomId } from '@/utils';
import { useGetModelArchSchemas } from '@/serverStore/train';

import StepTitle from '../StepTitle';
import ModelsConfiguraionTable from './ModelsConfiguraionTable';

import useStyles from '../styles';
import { useGetModelListQuery } from '@/serverStore/modelAnalysis';

/**
 * Reshapes the flat `DatasetLimits` record of a model arch schema into the nested
 * `limits` object stored on a model config.
 *
 * Note the field mapping for `largeImage.minArea` comes from `minImageArea` (not a
 * `largeImage*`-prefixed field), mirroring the source record's naming.
 *
 * @param datasetLimits - Flat limits record taken from an arch schema.
 * @returns `{ maxLabeledMedia, minLabeledMedia, largeImage: { maxArea, minArea, thresholdArea } }`.
 */
export const transferDatasetLimitsFormat = (datasetLimits: DatasetLimits) => {
  const {
    maxLabeledMedia,
    minLabeledMedia,
    largeImageMaxArea,
    minImageArea,
    largeImageThresholdArea,
  } = datasetLimits;

  const largeImage = {
    maxArea: largeImageMaxArea,
    minArea: minImageArea,
    thresholdArea: largeImageThresholdArea,
  };

  return { maxLabeledMedia, minLabeledMedia, largeImage };
};

/**
 * Returns the largest integer N among `names` of the form `${marker}N` where N is all digits.
 * Names that start with `marker` but have a non-numeric remainder count as 0; empty or
 * nullish names are skipped. Returns 0 when nothing matches.
 *
 * @param names - Candidate names (entries may be null/undefined).
 * @param marker - Prefix including its separator, e.g. `"model_"`.
 */
const getMaxNumericNameSuffix = (
  names: ReadonlyArray<string | null | undefined>,
  marker: string,
): number =>
  names.reduce<number>((max, name) => {
    if (!name || !name.startsWith(marker)) return max;
    const suffix = name.replace(marker, '');
    return /^\d+$/.test(suffix) ? Math.max(max, parseInt(suffix, 10)) : max;
  }, 0);

/**
 * Step 2 of the custom-training flow: lets the user edit a list of model configs (up to 10)
 * and launches one training job per config in parallel.
 *
 * Renders nothing unless `currentStepAtom` equals `CustomTrainingStepMode.ModelsConfiguration`.
 */
const ModelsConfiguration: React.FC = () => {
  const styles = useStyles();
  const { enqueueSnackbar } = useSnackbar();

  const [currentStep] = useAtom(currentStepAtom);
  const [selectedDatasetVersion] = useAtom(selectedDatasetVersionAtom);
  const [modelsConfigList, setModelsConfigList] = useAtom(modelsConfigListAtom);
  const [isInProgress] = useAtom(isInProgressAtom);
  const setAutoFocusNameModel = useSetAtom(autoFocusNameModelAtom);
  const [enableAdjustSplit] = useAtom(enableAdjustSplitAtom);
  const [defectDistributionWithAssignment] = useAtom(defectDistributionWithAssignmentAtom);
  const [configErrors] = useAtom(configErrorsAtom);
  const [duplicateNameIds] = useAtom(duplicateNameIdsAtom);
  const [datasetTrainingIssues] = useAtom(datasetTrainingIssuesAtom);
  const [showLargeImagesWarning] = useAtom(showLargeImagesWarningAtom);
  const resetAndBackToBuild = useResetAndBackToBuild();

  const currentUser = useTypedSelector(state => state.login.user);
  const { id: projectId, labelType } = useGetSelectedProjectQuery().data ?? {};
  const { data: modelInfo } = useProjectModelInfoQuery(projectId);
  const autoSplit = useAutoSplitMutation();
  const shouldUseAdvancedPricing = useFeatureGateEnabled(ClientFeatures.AdvancedUsageBasedPricing);
  const isCreditReferenceEnabled = !useFeatureGateEnabled(ClientFeatures.DisableCreditReference);
  const { data: modelList } = useGetModelListQuery();
  // Memoized on the query data so `addNewModelConfig` (and the seeding effect that depends
  // on it) keeps a stable identity instead of being recreated on every render.
  const modelNames = useMemo(() => modelList?.map(model => model.modelName) ?? [], [modelList]);

  const { data: modelArchSchemas } = useGetModelArchSchemas();

  const trainModel = useTrainModel();

  /**
   * Appends a new model config seeded from the default arch schema for the project's label
   * type and, when present, the project's saved `customTrainingConfig`. Shows an error
   * snackbar when model info, arch schemas, or the label type are not available yet.
   */
  const addNewModelConfig = useCallback(() => {
    if (modelInfo && modelArchSchemas?.length && labelType) {
      const tempId = generateRandomId();
      setAutoFocusNameModel(tempId);
      const maxModelNameSuffix = getMaxNumericNameSuffix(
        modelNames,
        `${defaultModelConfig.name}_`,
      );

      const newModel = {
        ...cloneDeep(defaultModelConfig),
        tempId,
        // Number past both the existing models' suffixes and the configs already on this page.
        name: `${defaultModelConfig.name}_${maxModelNameSuffix + 1 + modelsConfigList.length}`,
      };

      const currentArchName =
        getModelDefaultArchNameByLabelType(labelType) ?? modelArchSchemas[0].name;
      const currentArchSchema =
        modelArchSchemas.find(schema => schema.name === currentArchName) ?? modelArchSchemas[0];

      const hyperParamsModelConfig = {
        'learningParams.epochs':
          currentArchSchema?.schema.properties.learningParams?.default.epochs,
        archName: currentArchName,
        modelSize: currentArchSchema?.modelSize,
        'nmsParams.iou_threshold':
          currentArchSchema?.schema.properties.nmsParams?.default.iou_threshold,
      };
      // The IoU threshold is only kept for bounding-box projects.
      if (labelType !== LabelType.BoundingBox) {
        delete hyperParamsModelConfig['nmsParams.iou_threshold'];
      }

      if (modelInfo.customTrainingConfig) {
        newModel.config.transforms = modelInfo.customTrainingConfig.transforms ?? [];
        newModel.config.augmentations = modelInfo.customTrainingConfig.augmentations ?? [];
        newModel.config.trainingParams = {
          ...modelInfo.customTrainingConfig?.trainingParams!,
        };
        // Only override the model hyper-params when the saved config already carries them.
        if (modelInfo.customTrainingConfig?.trainingParams?.hyperParams?.model) {
          newModel.config.trainingParams.hyperParams = {
            ...modelInfo.customTrainingConfig?.trainingParams?.hyperParams,
            model: hyperParamsModelConfig,
          };
        }
      } else {
        newModel.config.trainingParams = {
          projectId,
          hyperParams: {
            dataset: {
              training_split_name: 'train',
              validation_split_name: 'dev',
            },
            model: hyperParamsModelConfig,
          },
        };
      }
      // `model size` equals `archName`
      newModel.config.availableModelSizes = modelArchSchemas.map(
        modelArchSchema => modelArchSchema.name,
      );
      newModel.config.defaultParameters = hyperParamsModelConfig;
      newModel.config.limits = transferDatasetLimitsFormat(currentArchSchema.datasetLimits);
      newModel.config.currentSchema = currentArchSchema;

      setModelsConfigList(prev => [...prev, newModel]);
    } else {
      enqueueSnackbar('Model configuration initialization failed', { variant: 'error' });
    }
  }, [
    modelArchSchemas,
    enqueueSnackbar,
    labelType,
    modelInfo,
    modelNames,
    modelsConfigList.length,
    projectId,
    setAutoFocusNameModel,
    setModelsConfigList,
  ]);

  // Seed the step with one config once model info and arch schemas have loaded.
  useEffect(() => {
    if (
      modelInfo &&
      modelArchSchemas?.length &&
      modelsConfigList.length === 0 &&
      currentStep === CustomTrainingStepMode.ModelsConfiguration
    ) {
      addNewModelConfig();
    }
  }, [
    modelArchSchemas,
    addNewModelConfig,
    modelArchSchemas?.length,
    currentStep,
    modelInfo,
    modelsConfigList.length,
  ]);

  const { data: allFilters } = useGetDatasetFilterOptionsQuery(
    selectedDatasetVersion?.version ?? undefined,
  );
  // Prefer the column-based `Split` filter; fall back to a filter named `split`.
  const splitFilterOption = allFilters
    ? allFilters.find(value => value.filterName === 'Split' && value.filterType === 'column') ||
      allFilters.find(value => value.filterName === 'split')
    : undefined;
  const splitFilter = useMemo(
    () => (splitFilterOption?.value ? Object.keys(splitFilterOption.value) : []),
    [splitFilterOption?.value],
  );
  const supportedAutoSplitId: number[] = useMemo(
    () =>
      splitFilterOption?.filterType === 'column'
        ? supportedAutoSplit.map(s => splitFilterOption.value![s] as number)
        : [],
    [splitFilterOption?.filterType, splitFilterOption?.value],
  );

  // Selection used by auto-split: approved media not already assigned to a split.
  const unassignedLabeledMediaSelectMediaOption: SelectMediaOption | undefined = useMemo(
    () =>
      splitFilterOption
        ? {
            ...defaultSelectOptions,
            fieldFilterMap:
              splitFilterOption.filterType === 'field' && splitFilter?.length
                ? {
                    [splitFilterOption.fieldId!]: { NOT_CONTAIN_ANY: splitFilter },
                  }
                : {},
            columnFilterMap: {
              datasetContent: {
                mediaStatus: { CONTAINS_ANY: [MediaStatusType.Approved] },
                ...(splitFilterOption.filterType === 'column' &&
                  supportedAutoSplitId.length && {
                    splitSet: { NOT_CONTAIN_ANY: supportedAutoSplitId },
                  }),
              },
            },
          }
        : undefined,
    [splitFilter, splitFilterOption, supportedAutoSplitId],
  );

  const [trainingTriggered, setTrainingTriggered] = useState(false);
  const [openTrainAlertDialog, setOpenTrainAlertDialog] = useState(false);

  /**
   * Optionally runs auto-split, saves the first config as the project's custom-training
   * config, then starts one training job per config. On full success it resets the flow;
   * on partial success it removes the configs whose jobs started, so the remaining ones
   * stay on the page (presumably for a retry).
   */
  const onLaunchTraining = useCallback(async () => {
    if (!projectId) return;
    setTrainingTriggered(true);
    try {
      if (enableAdjustSplit && defectDistributionWithAssignment.length > 0) {
        await autoSplit.mutateAsync({
          selectOptions: unassignedLabeledMediaSelectMediaOption!,
          splitByDefectDistribution: assignmentToDistributionToAssignSplitMapping(
            defectDistributionWithAssignment,
          ),
        });
      }
      await ProjectApi.upsertCustomTrainingConfig(projectId, modelsConfigList[0].config);
      const res = await Promise.allSettled(
        modelsConfigList.map(modelConfig =>
          trainModel(
            modelConfig.config,
            TrainMode.Advanced,
            false,
            selectedDatasetVersion?.version,
            modelConfig.name,
            modelConfig.description,
          ),
        ),
      );
      const allSuccess = res.every(result => result.status === 'fulfilled');
      if (allSuccess) {
        enqueueSnackbar(t('Start Training.'), { variant: 'success' });
        resetAndBackToBuild();
      } else {
        // Prune a copy and commit it through the atom setter. Splicing the atom's own array
        // in place bypasses jotai, so nothing re-renders and the UI goes stale.
        const remainingConfigs = [...modelsConfigList];
        res.forEach(result => {
          if (result.status === 'fulfilled') {
            const index = remainingConfigs.findIndex(e => e.name === result.value.modelName);
            if (index > -1) {
              remainingConfigs.splice(index, 1);
            }
          }
        });
        setModelsConfigList(remainingConfigs);
        enqueueSnackbar(t('Start Training part of the jobs.'), { variant: 'success' });
      }
      // Reports every config submitted in this launch (same list on full or partial success).
      pendoEntity?.track('parallel_training_model_count', {
        userName: currentUser ? `${currentUser.name} ${currentUser.lastName}` : undefined,
        orgName: currentUser ? currentUser.company : undefined,
        orgId: currentUser ? currentUser.orgId : undefined,
        modelCount: modelsConfigList.length,
        modelNames: modelsConfigList.map(model => model.name).join(', '),
      });
    } catch (e) {
      enqueueSnackbar(`${e.message ?? 'Error starting training.'}`, { variant: 'error' });
    } finally {
      setTrainingTriggered(false);
    }
  }, [
    projectId,
    enableAdjustSplit,
    defectDistributionWithAssignment,
    modelsConfigList,
    setModelsConfigList,
    currentUser,
    enqueueSnackbar,
    resetAndBackToBuild,
    autoSplit,
    unassignedLabeledMediaSelectMediaOption,
    trainModel,
    selectedDatasetVersion?.version,
  ]);

  // Route through the large-image confirmation dialog when the warning flag is set.
  const handleLaunchTraining = useCallback(async () => {
    if (showLargeImagesWarning) {
      setOpenTrainAlertDialog(true);
      return;
    }
    onLaunchTraining();
  }, [showLargeImagesWarning, onLaunchTraining]);

  const totalTrainCost = useMemo(() => {
    return modelsConfigList
      .reduce((total, model) => {
        return total + (model.cost?.trainCost ?? 0);
      }, 0)
      .toFixed(2);
  }, [modelsConfigList]);

  if (currentStep !== CustomTrainingStepMode.ModelsConfiguration) {
    return null;
  }

  return (
    <>
      <StepTitle step={2} title="Configure your model" />

      <Box className={styles.modelsConfiguration}>
        <ModelsConfiguraionTable />

        <Box display="flex" justifyContent={'space-between'} ml={-3} mt={5}>
          {modelsConfigList.length < 10 ? (
            <Button
              id="add-another-model"
              variant="text"
              color="primary"
              onClick={addNewModelConfig}
            >
              {t('Add Another Model')}
            </Button>
          ) : (
            <Box />
          )}

          <Box className={styles.pageControlButtons}>
            <Button
              id="cancel-customize-training-at-step2"
              variant="text"
              color="primary"
              onClick={resetAndBackToBuild}
            >
              {t('Cancel')}
            </Button>

            <Tooltip
              arrow
              placement="top"
              title={
                (datasetTrainingIssues.includes(DatasetTrainingIssueType.NotEnoughLabeledMedias) &&
                  t(`Need at least {{minCount}} labeled assigned images to train a model.`, {
                    // `?.[0]?.` so an empty schema list falls back instead of crashing render.
                    minCount:
                      modelArchSchemas?.[0]?.datasetLimits.minLabeledMedia ??
                      MIN_LABELED_MEDIA_FOR_FAST_N_EASY_TRAIN,
                  })) ||
                (datasetTrainingIssues.includes(
                  DatasetTrainingIssueType.NotEnoughMediasInTrainSplit,
                ) &&
                  t(`At least 2 images are required in the train split.`)) ||
                (datasetTrainingIssues.includes(
                  DatasetTrainingIssueType.InvalidClassificationTrain,
                ) &&
                  t('Classification projects should contain at least two classes created.')) ||
                (configErrors && t('Please check the form to correct parameters')) ||
                (isCreditReferenceEnabled &&
                  'Custom training cost is calculated based on model settings (image resize, epochs) and # of images. If the default settings are used, each image in the project equals 1 credit.') ||
                ''
              }
            >
              <div>
                <Button
                  id="customize-training-go-next-at-step1"
                  variant="contained"
                  color="primary"
                  onClick={handleLaunchTraining}
                  disabled={
                    duplicateNameIds.length > 0 ||
                    configErrors ||
                    datasetTrainingIssues.length > 0 ||
                    isInProgress ||
                    !modelsConfigList.length ||
                    trainingTriggered
                  }
                >
                  {(isInProgress || trainingTriggered) && (
                    <CircularProgress color="inherit" size={16} style={{ marginRight: 10 }} />
                  )}
                  {isCreditReferenceEnabled && shouldUseAdvancedPricing && !isInProgress
                    ? t(`Train (${totalTrainCost} Credits)`)
                    : t('Train')}
                </Button>
              </div>
            </Tooltip>
          </Box>
        </Box>
      </Box>

      <CustomTrainAlertDialog
        open={openTrainAlertDialog}
        onCancel={() => setOpenTrainAlertDialog(false)}
        onOk={async () => {
          await onLaunchTraining();
          setOpenTrainAlertDialog(false);
        }}
        isExecuting={trainingTriggered}
      />
    </>
  );
};

export default ModelsConfiguration;
