import {
  Box,
  LinearProgress,
  makeStyles,
  Step,
  StepContent,
  StepLabel,
  Stepper,
} from '@material-ui/core';
import { useSnackbar } from 'notistack';
import React, { useCallback, useEffect, useMemo, useState } from 'react';

import { Button, Typography } from '@clef/client-library';
import { LabelType, ModelStatus } from '@clef/shared/types';

import { useDatasetExportedWithVersionsQuery } from '@/serverStore/dataset';
import { useGetSelectedProjectQuery } from '@/serverStore/projects';
import maglev_api from '../../../../api/maglev_api';
import { useDialog } from '../../../../components/Layout/components/useDialog';
import { isModelTrainingHasLearningCurve } from '../../../../store/projectModelInfoState/utils';

import { TrainingStepType } from '@/constants/model_train';
import { ClientFeatures, useFeatureGateEnabled } from '@/hooks/useFeatureGate';
import { useModelStatusQuery } from '@/serverStore/projectModels';
import LearningCurveChart from './LearningCurveChart';
import ProvisionGpuTimer from './ProvisionGpuTimer';

/**
 * JSS classes for the training-progress stepper.
 * NOTE(review): `greyModern`, `green` and `indigoBlue` are not stock MUI v4 palette keys;
 * presumably they come from the app's theme augmentation.
 */
const useStyles = makeStyles(theme => {
  const { palette } = theme;

  return {
    modelName: { fontWeight: 500 },
    // The stepper sits flush inside its parent: no shadow and no padding.
    stepperRoot: {
      boxShadow: 'none!important',
      padding: 0,
    },
    stepLabel: {
      fontWeight: 'bold',
      fontSize: 14,
    },
    stepContent: {
      color: palette.grey[500],
      fontSize: 14,
      // Keep the first grandchild span of each child span on a single line.
      '& > span': {
        '& > span:first-child': {
          whiteSpace: 'nowrap !important',
        },
      },
    },
    snapshotName: {
      fontWeight: 500,
      textOverflow: 'ellipsis',
      wordBreak: 'break-word',
    },
    // Track (root) and filled bar of the snapshot LinearProgress.
    snapshotProgressRoot: {
      width: 256,
      backgroundColor: palette.greyModern[200],
      borderRadius: 2,
    },
    snapshotProgress: {
      backgroundColor: palette.green[400],
      borderRadius: 2,
    },
    manuallyEndedChip: {
      display: 'inline-block',
      padding: '2px 8px',
      borderRadius: 5,
      color: palette.indigoBlue[600],
      fontSize: 12,
      fontWeight: 700,
      backgroundColor: palette.indigoBlue[100],
    },
    infoBlock: {
      padding: theme.spacing(5, 10),
      borderRadius: 6,
      backgroundColor: palette.greyModern[50],
      marginRight: theme.spacing(4),
      // No trailing gap after the last block in a row.
      '&:last-child': {
        marginRight: 0,
      },
    },
    chartContainer: { backgroundColor: palette.common.white },
  };
});

/** Props accepted by the in-progress training stepper component. */
type ModelTrainingInProgressProps = {
  /**
   * When true, the learning-curve chart is wrapped in a titled "Loss Chart" info block;
   * otherwise the chart is rendered on its own. The component treats a missing value as `false`.
   */
  showLearningGraphBlock?: boolean;
  /** Id of the model whose training progress is shown; may be `undefined`. */
  modelId: string | undefined;
};

/** Delay between frames of the trailing "..." animation, in milliseconds. */
const ELLIPSIS_TICK_MS = 1000;
/** Number of dots the ellipsis grows to before wrapping back to none. */
const MAX_ELLIPSIS_DOTS = 3;

/**
 * Returns `baseText` followed by a cycling "", ".", "..", "..." suffix that advances once
 * every {@link ELLIPSIS_TICK_MS} while `active` is true.
 *
 * When `active` becomes false the interval is cleared and the suffix resets to zero dots,
 * so a later reactivation always starts cleanly.
 *
 * @param baseText - the (already translated) text to animate
 * @param active - whether the animation should currently run
 * @returns the text to render for the current animation frame
 */
const useAnimatedEllipsis = (baseText: string, active: boolean): string => {
  const [dotCount, setDotCount] = useState(0);

  useEffect(() => {
    if (!active) {
      setDotCount(0);
      return undefined;
    }
    const timer = setInterval(() => {
      setDotCount(prev => (prev >= MAX_ELLIPSIS_DOTS ? 0 : prev + 1));
    }, ELLIPSIS_TICK_MS);
    return () => clearInterval(timer);
  }, [active]);

  return baseText + '.'.repeat(dotCount);
};

/**
 * Vertical stepper showing the progress of an in-flight training job.
 *
 * The steps are: snapshot preparation → GPU provisioning → training & learning → metric
 * calculation. The active step is derived from the status returned by `useModelStatusQuery`.
 * While the model is `ModelStatus.Created`, the snapshot step shows a client-side, randomized
 * progress bar. Renders nothing when the status maps to no step.
 */
const ModelTrainingInProgress: React.FC<ModelTrainingInProgressProps> = props => {
  const styles = useStyles();
  const { showLearningGraphBlock = false, modelId: selectedModelId } = props;
  const { id: projectId, labelType } = useGetSelectedProjectQuery().data ?? {};
  const { data: modelStatus } = useModelStatusQuery(projectId, selectedModelId);

  const { data: datasetExported } = useDatasetExportedWithVersionsQuery({
    withCount: true,
    includeNotCompleted: true,
    includeFastEasy: true,
  });

  // Dataset version this model is trained on, once both queries have data.
  const curSnapshot = useMemo(() => {
    if (!datasetExported?.datasetVersions) return undefined;
    return datasetExported.datasetVersions.find(
      datasetVersion => datasetVersion.id === modelStatus?.datasetVersionId,
    );
  }, [datasetExported?.datasetVersions, modelStatus?.datasetVersionId]);

  // Set after a stop-training request succeeds; tags the learning curve as manually ended.
  const [isTrainingStoppedEarly, setIsTrainingStoppedEarly] = useState<boolean>(false);

  const ManuallyEndedChip = useMemo(() => {
    if (!isTrainingStoppedEarly) return null;
    return <Box className={styles.manuallyEndedChip}>Manually Ended</Box>;
  }, [isTrainingStoppedEarly, styles.manuallyEndedChip]);

  // Simulated progress bar for the snapshot step. Its value is generated locally with
  // Math.random; nothing here reads real snapshot progress.
  const [showSnapshotProgress, setShowSnapshotProgress] = useState<boolean>(false);
  const [snapshotProgress, setSnapshotProgress] = useState(0);

  useEffect(() => {
    // Start the bar the first time the model is seen in `Created`; hide it once it is full.
    if (modelStatus?.status === ModelStatus.Created && snapshotProgress === 0) {
      setShowSnapshotProgress(true);
    }
    if (snapshotProgress === 100) {
      setShowSnapshotProgress(false);
    }
  }, [modelStatus?.status, snapshotProgress]);

  useEffect(() => {
    if (!showSnapshotProgress) return undefined;
    // Advance by a random 0-10 points every 100 ms. The updater stays side-effect free.
    // Reaching 100 turns `showSnapshotProgress` off (effect above). That re-runs this
    // effect, whose cleanup clears the interval.
    const timer = setInterval(() => {
      setSnapshotProgress(prev => Math.min(prev + Math.random() * 10, 100));
    }, 100);
    return () => clearInterval(timer);
  }, [showSnapshotProgress]);

  // Index into `steps` below (0 snapshot, 1 provision, 2 train, 3 calculation), or
  // `undefined` when the current status maps to no step.
  const activeStep = useMemo(() => {
    if (!modelStatus) {
      return undefined;
    }
    const { status, metricsReady } = modelStatus;
    if (status === ModelStatus.Created) {
      // Stay on the snapshot step until the simulated progress bar completes.
      return snapshotProgress < 100 ? 0 : 1;
    }
    if (status === ModelStatus.Starting) {
      return 1;
    }
    if (status === ModelStatus.Training || status === ModelStatus.Evaluating) {
      return 2;
    }
    if (status === ModelStatus.Publishing || !metricsReady) {
      return 3;
    }
    return undefined;
  }, [modelStatus, snapshotProgress]);

  // Animated trailing dots for the provisioning and calculation steps, only while active.
  const provisionStepContent = useAnimatedEllipsis(
    t('Warming up a virtual computer in the cloud'),
    activeStep === 1,
  );
  const calculationContent = useAnimatedEllipsis(t('Loading metrics'), activeStep === 3);

  const steps = useMemo(() => {
    return [
      {
        type: TrainingStepType.Snapshot,
        label: t('Preparing the snapshot of your data'),
        content: curSnapshot
          ? t(`snapshot {{name}} with {{count}} images.`, {
              name: <span className={styles.snapshotName}>{curSnapshot.name}</span>,
              count: <span style={{ fontWeight: 500 }}>{curSnapshot.count}</span>,
            })
          : null,
      },
      {
        type: TrainingStepType.Provision,
        label: t('Provisioning GPU'),
        content: provisionStepContent,
      },
      {
        type: TrainingStepType.Train,
        label: t('Training & learning'),
        content: undefined,
      },
      {
        type: TrainingStepType.Calculation,
        label: t('Calculating model performance'),
        content: calculationContent,
      },
    ];
  }, [curSnapshot, styles.snapshotName, provisionStepContent, calculationContent]);

  const { showConfirmationDialog } = useDialog();
  const { enqueueSnackbar } = useSnackbar();
  const enableGPUTimer = useFeatureGateEnabled(ClientFeatures.ShowGpuTimer);

  /**
   * Asks the backend to stop the current training job.
   * On success, marks training as manually ended and shows a success snackbar.
   * Every failure is reported to the user via a snackbar; none is rethrown.
   */
  const stopTraining = useCallback(async () => {
    // The API needs both ids; without them there is no job to stop.
    if (!projectId || !selectedModelId) return;
    try {
      await maglev_api.postStopTraining({
        projectId,
        jobId: selectedModelId,
      });
      setIsTrainingStoppedEarly(true);
      enqueueSnackbar(t('stopping training job'), {
        variant: 'success',
      });
    } catch (error) {
      // NOTE(review): assumes maglev_api rejects with an object carrying an HTTP `status`
      // and a server `message`; other rejections simply lack these fields.
      const { status, message } = (error ?? {}) as { status?: number; message?: string };
      if (status !== undefined && status >= 400 && status < 500 && message) {
        // Client errors carry a reason from the server; surface it as information.
        enqueueSnackbar(t(message), { variant: 'info' });
      } else {
        // Server errors and failures without a status get a generic error instead of
        // being dropped silently.
        enqueueSnackbar(t('Error while attempting to stop training job'), { variant: 'error' });
      }
    }
  }, [enqueueSnackbar, projectId, selectedModelId]);

  if (activeStep === undefined) {
    return null;
  }

  return (
    <>
      <Stepper
        activeStep={activeStep}
        orientation="vertical"
        elevation={0}
        classes={{ root: styles.stepperRoot }}
      >
        {steps.map(step => (
          <Step
            key={step.label}
            // Snapshot details stay visible throughout; the training chart stays visible
            // once training has moved on to the calculation step.
            expanded={
              step.type === TrainingStepType.Snapshot ||
              (step.type === TrainingStepType.Train && activeStep > 2)
            }
          >
            <StepLabel>
              <Box
                display={'flex'}
                flexDirection="row"
                justifyContent={'flex-start'}
                alignItems={'center'}
              >
                <Typography className={styles.stepLabel}>{step.label}</Typography>
                {enableGPUTimer &&
                  modelStatus?.createdAt &&
                  step.type === TrainingStepType.Provision &&
                  activeStep === 1 && (
                    <ProvisionGpuTimer
                      startTime={modelStatus?.createdAt}
                      modelId={selectedModelId}
                    />
                  )}
              </Box>
            </StepLabel>
            {/* {(step.type !== TrainingStepType.Train || !isTrainingStoppedEarly) && ( */}
            <StepContent>
              {step.type === TrainingStepType.Snapshot && showSnapshotProgress && (
                <LinearProgress
                  variant="determinate"
                  classes={{
                    root: styles.snapshotProgressRoot,
                    bar1Determinate: styles.snapshotProgress,
                  }}
                  value={snapshotProgress}
                />
              )}
              {!showSnapshotProgress && (
                <Typography className={styles.stepContent}>{step.content}</Typography>
              )}
              {step.type === TrainingStepType.Train &&
                activeStep >= 2 &&
                labelType !== LabelType.SegmentationInstantLearning &&
                isModelTrainingHasLearningCurve(modelStatus?.status) &&
                (showLearningGraphBlock ? (
                  <Box display="flex" flexDirection="column" className={styles.infoBlock}>
                    <Typography variant={'body_bold'} display="inline">
                      {t('Loss Chart')}
                    </Typography>
                    <Box className={styles.chartContainer}>
                      <LearningCurveChart
                        modelId={selectedModelId}
                        topRightTag={ManuallyEndedChip}
                        isTraining={activeStep === 2}
                        aspectRatio={2}
                        hideMarginBottom={true}
                        hideLegends={true}
                        hideValidationCurve={true}
                      />
                    </Box>
                  </Box>
                ) : (
                  <LearningCurveChart
                    hideLegends={true}
                    hideValidationCurve={true}
                    modelId={selectedModelId}
                    topRightTag={ManuallyEndedChip}
                    isTraining={activeStep === 2}
                  />
                ))}
              {/* The "End Training Now" button is currently disabled by the `&& false` guard. */}
              {step.type === TrainingStepType.Train && activeStep === 2 && false && (
                <Box marginY={3}>
                  <Button
                    id="stop-training"
                    variant="outlined"
                    size="small"
                    tooltipOnButton
                    tooltip={
                      'This will end the training process and provide you with the model trained to the current epoch.'
                    }
                    onClick={() => {
                      showConfirmationDialog({
                        title: t('Are you sure you want to end training'),
                        content: t(
                          'This will end the training process and provide you with the model trained to the current epoch. This can result in a lower performing model.',
                        ),
                        confirmText: t('End Training'),
                        color: 'primary',
                        onConfirm: async () => {
                          // Fire-and-forget: stopTraining reports its own errors via snackbars.
                          void stopTraining();
                        },
                      });
                    }}
                  >
                    {t('End Training Now')}
                  </Button>
                </Box>
              )}
            </StepContent>
            {/* )} */}
          </Step>
        ))}
      </Stepper>
      <div style={{ flex: 1 }} />
    </>
  );
};

export default ModelTrainingInProgress;
