import React, { useState, useCallback, useEffect, useRef, useMemo } from 'react';
import { Box, TableCell, makeStyles, CircularProgress, Tooltip } from '@material-ui/core';
import { useAtom, useSetAtom } from 'jotai';
import { cloneDeep, isEqual, isEmpty } from 'lodash';
import { useSnackbar } from 'notistack';

import {
  TransformType,
  Pipeline,
  TransformParams,
  Media,
  MediaStatusType,
} from '@clef/shared/types';
import { Typography } from '@clef/client-library';

import { CostIcon } from '@/images/custom_training/icons';
import { TRANSFORMS_UI_SCHEMA } from '@/constants/model_train';
import { defaultSelectOptions } from '@/constants/data_browser';
import { PreviewDimensions } from '@/types/client';
import PreviewModal from '@/pages/DataBrowser/TrainModelButtonGroup/PreviewModal';
import MediaAPI from '@/api/media_api';
import ExperimentReportApi from '@/api/experiment_report_api';
import { useTransformsApi } from '@/hooks/api/useTransformsApi';
import { ClientFeatures, useFeatureGateEnabled } from '@/hooks/useFeatureGate';
import { useGetSelectedProjectQuery } from '@/serverStore/projects';
import { useDatasetMediaCountQuery, useDatasetMediasQuery } from '@/serverStore/dataset';
import {
  modelsConfigListAtom,
  selectedDatasetVersionAtom,
  isInProgressAtom,
  configErrorsAtom,
  showLargeImagesWarningAtom,
} from '@/uiStates/customTraining/pageUIStates';
import {
  getTransformDimensions,
  extractNewTransform,
  getRandomMedia,
  updateSections,
  getMaxTransformPixels,
  getPreviewDimensions,
  getErrorsByTransformAndParamName,
  getErrorsByHyperParams,
} from '@/utils/job_train_utils';

import ConfigList from './ConfigList';

/**
 * JSS classes for the training-config table cells.
 *
 * Declarations inside each class are grouped by concern (positioning, flex
 * layout, then typography/colour). None of them overrides another, so the
 * order within a class carries no meaning.
 */
const useStyles = makeStyles(theme => ({
  // Keeps the content of every config cell pinned to the top of the row.
  configCell: {
    verticalAlign: 'top',
  },
  // Horizontally centred, absolutely positioned banner.
  // NOTE(review): not referenced by the component visible in this file.
  alerts: {
    // positioning
    position: 'absolute',
    top: theme.spacing(4),
    left: '50%',
    transform: 'translateX(-50%)',
    // flex layout
    display: 'flex',
    alignItems: 'center',
    alignSelf: 'center',
    justifyContent: 'center',
    // appearance
    borderRadius: 8,
    fontSize: '14px',
    fontWeight: 400,
  },
  // NOTE(review): not referenced by the component visible in this file.
  buttonProgress: {
    color: theme.palette.grey[100],
    marginLeft: theme.spacing(3),
  },
  // One line of the cost cell: label, icon, amount and unit laid out in a row.
  costItem: {
    // flex layout
    display: 'flex',
    alignItems: 'center',
    gap: theme.spacing(1),
    margin: theme.spacing(2, 0, 4, 0),
    // typography
    color: '#4D5761',
    whiteSpace: 'nowrap',
  },
}));

/** Props accepted by `ConfigCells`. */
interface ConfigCellsProps {
  /**
   * Index into the `modelsConfigListAtom` array; selects the training config
   * entry that this set of cells reads from and writes back to.
   */
  rowIndex: number;
}

/**
 * Table cells for one entry of the models config list (`modelsConfigListAtom`).
 *
 * Renders a preprocessing cell, an augmentation cell and, unless the
 * `DisableCreditReference` feature gate is enabled, a cost cell. It also:
 *  - seeds the row's `transformUiSchema` and its default transforms/augmentations,
 *  - keeps a local copy of the applied pipeline (`pipelineSections`) in sync with
 *    the row's config whenever `applyPipelineSections` is called,
 *  - requests train/inference credit costs when advanced pricing is enabled,
 *  - publishes whether this row currently has validation errors to `configErrorsAtom`.
 *
 * Every write to `modelsConfigListAtom` replaces the affected row with a new
 * object instead of mutating the one held by the previous state.
 *
 * @param rowIndex Index of the row inside `modelsConfigListAtom`.
 */
const ConfigCells: React.FC<ConfigCellsProps> = ({ rowIndex }) => {
  const styles = useStyles();

  const { labelType } = useGetSelectedProjectQuery().data ?? {};

  const shouldUseAdvancedPricing = useFeatureGateEnabled(ClientFeatures.AdvancedUsageBasedPricing);

  const [modelsConfigList, setModelsConfigList] = useAtom(modelsConfigListAtom);
  const currentTrainingConfig = modelsConfigList[rowIndex];
  const { trainingParams, transforms, augmentations, limits, currentSchema } =
    currentTrainingConfig.config;

  const [selectedDatasetVersion] = useAtom(selectedDatasetVersionAtom);
  const setisInProgress = useSetAtom(isInProgressAtom);
  const [configErrors, setConfigErrors] = useAtom(configErrorsAtom);
  const [showLargeImagesWarning] = useAtom(showLargeImagesWarningAtom);

  const { enqueueSnackbar } = useSnackbar();

  // Initialize transformUiSchema for this row from the static schema, minus AutoResize.
  useEffect(() => {
    if (currentTrainingConfig.transformUiSchema.length === 0) {
      const transformsUiSchemaOrigin = cloneDeep(TRANSFORMS_UI_SCHEMA);
      delete transformsUiSchemaOrigin[1].AutoResize;
      setModelsConfigList(prev =>
        prev.map((row, index) =>
          index === rowIndex ? { ...row, transformUiSchema: transformsUiSchemaOrigin } : row,
        ),
      );
    }
  }, [currentTrainingConfig, rowIndex, setModelsConfigList]);

  // The shared "in progress" flag stays on until the transform definitions have loaded.
  const [transformParams = []] = useTransformsApi({});
  useEffect(() => {
    setisInProgress(transformParams.length === 0);
  }, [setisInProgress, transformParams]);

  const { data: datasetMedias } = useDatasetMediasQuery({
    sortOptions: {
      offset: 0,
      limit: 50,
    },
    version: selectedDatasetVersion?.version,
  });
  const randomMediaIds = useMemo(() => {
    return datasetMedias ? getRandomMedia(datasetMedias.slice()).map(m => m.id) : [];
  }, [datasetMedias]);
  const [medias, setMedias] = useState<Media[]>([]);
  useEffect(() => {
    // Set by the cleanup so that a response belonging to an outdated id list
    // (or arriving after unmount) is dropped instead of overwriting newer state.
    let isStale = false;

    async function fetchMedias() {
      if (randomMediaIds.length > 0) {
        try {
          const response = await MediaAPI.getMediasByIds(randomMediaIds);
          if (!isStale) {
            setMedias(response.data);
          }
        } catch {
          // Best effort: previews are optional, so fall back to an empty list.
          if (!isStale) {
            setMedias([]);
          }
        }
      }
    }

    fetchMedias();

    return () => {
      isStale = true;
    };
  }, [randomMediaIds]);

  const { data: mediaCount } = useDatasetMediaCountQuery({
    version: selectedDatasetVersion?.version,
    selectOptions: defaultSelectOptions,
  });

  // Count restricted to medias whose status is `Approved`; used as the labeled count for pricing.
  const { data: labeledMediaCount } = useDatasetMediaCountQuery({
    version: selectedDatasetVersion?.version,
    selectOptions: {
      fieldFilterMap: {},
      columnFilterMap: {
        datasetContent: { mediaStatus: { CONTAINS_ANY: [MediaStatusType.Approved] } },
      },
      selectedMedia: [],
      unselectedMedia: [],
      isUnselectMode: true,
    },
  });

  const maxTransformPixels = getMaxTransformPixels(labelType, limits?.largeImage);

  const [pipelineSections, setPipelineSections] = useState<Pipeline['sections']>({
    train: [...transforms, ...augmentations],
    valid: [],
  });

  const [resizeId, setResizeId] = useState('');
  const [rescaleId, setRescaleId] = useState('');
  const [cropId, setCropId] = useState('');
  const [showSizeAlert, setShowSizeAlert] = useState<boolean>(false);

  // Resolve the ids of the transforms that get special treatment below.
  useEffect(() => {
    if (transformParams.length) {
      setResizeId(transformParams.find(param => param.name === 'Resize')?.id || '');
      setRescaleId(transformParams.find(param => param.name === 'RescaleWithPadding')?.id || '');
      setCropId(transformParams.find(param => param.name === 'Crop')?.id || '');
    }
  }, [transformParams]);

  // Toggle the size alert when the output area of the pipeline exceeds the allowed maximum.
  useEffect(() => {
    // Pixel area (width * height) of the LAST transform that reports both dimensions.
    let transformPixels: number | undefined;
    for (const transform of pipelineSections.train) {
      const { width: transformWidth, height: transformHeight } = getTransformDimensions(
        transform,
        transformParams,
      );
      if (transformWidth && transformHeight) {
        transformPixels = transformWidth * transformHeight;
      }
    }

    if (transformPixels && transformPixels > maxTransformPixels) {
      setShowSizeAlert(true);
    } else if (
      (transformPixels && transformPixels <= maxTransformPixels) ||
      !pipelineSections.train.length
    ) {
      if (showSizeAlert) {
        setShowSizeAlert(false);
      }
    }
  }, [maxTransformPixels, pipelineSections, showSizeAlert, transformParams]);

  const [previewDimensions, setPreviewDimensions] = useState<PreviewDimensions | null>(null);

  // Transform definitions split by UI schema section (1: preprocessing, 2: augmentation),
  // each decorated with the label declared in the schema.
  const preprocessingParams = transformParams
    .filter(param => Object.prototype.hasOwnProperty.call(TRANSFORMS_UI_SCHEMA[1], param.name))
    .map(param => {
      const matchedSchema = TRANSFORMS_UI_SCHEMA[1][param.name];
      return {
        ...param,
        label: matchedSchema ? matchedSchema.label : undefined,
      };
    });
  const augmentationParams = transformParams
    .filter(param => Object.prototype.hasOwnProperty.call(TRANSFORMS_UI_SCHEMA[2], param.name))
    .map(param => {
      const matchedSchema = TRANSFORMS_UI_SCHEMA[2][param.name];
      return {
        ...param,
        label: matchedSchema ? matchedSchema.label : undefined,
      };
    });

  // The applied pipeline entries, split the same way (used by the preview modal).
  const preprocessingTransformParams = pipelineSections.train.filter(section =>
    preprocessingParams.some(param => param.id === section.id),
  );
  const augmentationTransformParams = pipelineSections.train.filter(section =>
    augmentationParams.some(param => param.id === section.id),
  );

  // Initially set default transforms & augmentations. This only happens while the row
  // has neither, and only when the schema provides defaults for both.
  useEffect(() => {
    if (!transformParams.length) {
      return;
    }
    const defaultTransformNames = currentSchema?.preprocessing.default ?? [];
    const defaultTransforms = defaultTransformNames
      .map(name => {
        const transform = transformParams.find(param => param.name === name);
        const rule = currentSchema?.preprocessing.params?.[transform?.name ?? '']?.properties;
        return (transform ? extractNewTransform(transform, rule, true) : null) as TransformParams;
      })
      .filter(Boolean);
    const defaultAugmentationNames = currentSchema?.augmentations.default ?? [];
    const defaultAugmentations = defaultAugmentationNames
      .map(name => {
        const transform = transformParams.find(param => param.name === name);
        const rule = currentSchema?.augmentations.params?.[transform?.name ?? '']?.properties;
        return (transform ? extractNewTransform(transform, rule, true) : null) as TransformParams;
      })
      .filter(Boolean);

    if (
      !transforms.length &&
      defaultTransforms.length &&
      !augmentations.length &&
      defaultAugmentations.length
    ) {
      setModelsConfigList(prev =>
        prev.map((row, index) => {
          if (index !== rowIndex) {
            return row;
          }

          let transformUiSchema = row.transformUiSchema;
          if (transformUiSchema.length > 0) {
            transformUiSchema = cloneDeep(transformUiSchema);
            // A default may name a transform that has no UI schema entry (AutoResize is
            // removed from section 1 when the schema is initialized), so guard the lookup
            // instead of throwing inside the state updater.
            for (const transform of defaultTransforms) {
              const schemaEntry = transformUiSchema[1][transform.name!];
              if (schemaEntry) {
                schemaEntry.isAdded = true;
              }
            }
            for (const augmentation of defaultAugmentations) {
              const schemaEntry = transformUiSchema[2][augmentation.name!];
              if (schemaEntry) {
                schemaEntry.isAdded = true;
              }
            }
          }

          return {
            ...row,
            config: {
              ...row.config,
              transforms: defaultTransforms,
              augmentations: defaultAugmentations,
            },
            transformUiSchema,
          };
        }),
      );
      setPipelineSections({
        train: [...defaultTransforms, ...defaultAugmentations],
        valid: [],
      });
    }
  }, [
    augmentations.length,
    rowIndex,
    currentSchema,
    setModelsConfigList,
    transformParams,
    transforms.length,
  ]);

  // Last sections applied through `applyPipelineSections`; used to skip no-op updates.
  const prevPipelineSectionsRef = useRef<typeof pipelineSections | null>(null);

  /**
   * Merges edited `sections` of the given `sectionType` into the current pipeline,
   * then stores the result locally and in the row's config and refreshes the
   * preview dimensions. Does nothing when the result equals the last applied one.
   */
  const applyPipelineSections = useCallback(
    async (sections: Pipeline['sections'], sectionType: TransformType) => {
      const { upgradedTransforms, upgradedAugmentations } = updateSections(
        sections,
        pipelineSections,
        sectionType,
        transformParams,
      );

      // pull resize or rescale to the very front
      const resizeOrRescaleIndex = upgradedTransforms.findIndex(
        transform => transform.id === resizeId || transform.id === rescaleId,
      );
      if (resizeOrRescaleIndex > 0) {
        const [resizeOrRescale] = upgradedTransforms.splice(resizeOrRescaleIndex, 1);
        upgradedTransforms.unshift(resizeOrRescale);
      }

      const upgradedSections: Pipeline['sections'] = {
        train: [...upgradedTransforms, ...upgradedAugmentations],
        valid: [],
      };
      if (isEqual(upgradedSections, prevPipelineSectionsRef.current)) {
        return;
      }
      prevPipelineSectionsRef.current = upgradedSections;

      setPipelineSections(upgradedSections);
      setModelsConfigList(prev =>
        prev.map((row, index) =>
          index === rowIndex
            ? {
                ...row,
                config: {
                  ...row.config,
                  transforms: upgradedTransforms,
                  augmentations: upgradedAugmentations,
                },
              }
            : row,
        ),
      );

      if (upgradedSections.train.length > 0) {
        setPreviewDimensions(getPreviewDimensions(transformParams, upgradedSections));
      } else {
        setPreviewDimensions(null);
      }
    },
    [pipelineSections, transformParams, setModelsConfigList, resizeId, rescaleId, rowIndex],
  );

  // `width` / `height` params of the applied Resize or RescaleWithPadding transform
  // (both undefined when neither is applied).
  const { width, height } = useMemo(() => {
    const resizeOrRescale = pipelineSections.train.find(
      transform => transform.id === resizeId || transform.id === rescaleId,
    );
    const { width, height } =
      (resizeOrRescale?.params.reduce(
        (res, { name, value }) => ({ ...res, [name]: value }),
        {},
      ) as Record<string, number>) ?? {};
    return { width, height };
  }, [pipelineSections.train, rescaleId, resizeId]);

  // Dimensions sent to the pricing endpoints: the crop extent (x_max - x_min, y_max - y_min)
  // when a Crop transform is applied, otherwise the resize/rescale dimensions.
  const [resizedWidth, resizedHeight] = useMemo(() => {
    if (!shouldUseAdvancedPricing || transformParams.length === 0) {
      return [0, 0];
    }
    const cropTransformId = transformParams.find(param => param.name === 'Crop')?.id || '';
    const cropValue = pipelineSections.train.find(transform => transform.id === cropTransformId);
    if (cropValue) {
      return cropValue.params.reduce(
        ([width, height]: [number, number], { name, value }) => {
          if (name === 'x_min') {
            return [width - value, height];
          } else if (name === 'x_max') {
            return [width + value, height];
          } else if (name === 'y_min') {
            return [width, height - value];
          } else if (name === 'y_max') {
            return [width, height + value];
          }
          return [width, height];
        },
        [0, 0],
      );
    }
    return [width, height];
  }, [shouldUseAdvancedPricing, transformParams, pipelineSections.train, width, height]);

  const [isCostCalculating, setIsCostCalculating] = useState<boolean>(false);

  // Request train and inference costs whenever any pricing input changes.
  useEffect(() => {
    // `isStale` is set by the cleanup once a newer run supersedes this one (or the
    // component unmounts); `isPending` records that this run raised the progress flags.
    let isStale = false;
    let isPending = false;

    const calculate = async () => {
      if (
        shouldUseAdvancedPricing &&
        resizedWidth &&
        resizedHeight &&
        trainingParams?.hyperParams?.model['learningParams.epochs'] &&
        trainingParams?.hyperParams?.model.modelSize &&
        labelType &&
        mediaCount &&
        labeledMediaCount
      ) {
        isPending = true;
        setIsCostCalculating(true);
        setisInProgress(true);
        const [trainCostResponse, predictionCostResponse] = await Promise.allSettled([
          ExperimentReportApi.getCustomTrainingCreditCost({
            numEpochs: Number(trainingParams?.hyperParams?.model['learningParams.epochs']),
            numLabeled: labeledMediaCount,
            numUnlabeled: mediaCount - labeledMediaCount,
            width: resizedWidth,
            height: resizedHeight,
            labelType,
          }),
          ExperimentReportApi.getCustomTrainingInferenceCost({
            width: resizedWidth,
            height: resizedHeight,
            modelSize: trainingParams?.hyperParams?.model.modelSize as string,
            labelType,
          }),
        ]);

        if (isStale) {
          // Outdated inputs: drop the result. The cleanup already released the flags.
          return;
        }
        isPending = false;

        if (trainCostResponse.status === 'rejected') {
          enqueueSnackbar(t('Failed to get train cost.'), {
            variant: 'error',
            autoHideDuration: 12000,
          });
        }
        if (predictionCostResponse.status === 'rejected') {
          enqueueSnackbar(t('Failed to get inference cost.'), {
            variant: 'error',
            autoHideDuration: 12000,
          });
        }

        setModelsConfigList(prev =>
          prev.map((row, index) =>
            index === rowIndex
              ? {
                  ...row,
                  cost: {
                    trainCost:
                      trainCostResponse.status === 'fulfilled'
                        ? trainCostResponse.value.data.trainCredits
                        : undefined,
                    predictionCost:
                      predictionCostResponse.status === 'fulfilled'
                        ? predictionCostResponse.value.data.inferenceCredits
                        : undefined,
                  },
                }
              : row,
          ),
        );

        setIsCostCalculating(false);
        setisInProgress(false);
      }
    };
    calculate();

    return () => {
      isStale = true;
      if (isPending) {
        // The abandoned request will never reach its own reset, so release the flags
        // here; a follow-up run raises them again in the same effect pass if needed.
        setIsCostCalculating(false);
        setisInProgress(false);
      }
    };
  }, [
    labelType,
    labeledMediaCount,
    mediaCount,
    resizedHeight,
    resizedWidth,
    rowIndex,
    setisInProgress,
    setModelsConfigList,
    shouldUseAdvancedPricing,
    trainingParams?.hyperParams?.model,
    enqueueSnackbar,
  ]);

  const [previewModalOpen, setPreviewModalOpen] = useState<boolean>(false);
  const [activeStep, setActiveStep] = useState<string>();

  // Opens the preview modal on the given step.
  const handleModalOpen = useCallback((step: string) => {
    setPreviewModalOpen(true);
    setActiveStep(step);
  }, []);

  const errorsByHyperParams = useMemo(() => {
    return getErrorsByHyperParams(
      currentSchema?.schema,
      Number(trainingParams?.hyperParams?.model['learningParams.epochs']),
    );
  }, [currentSchema, trainingParams?.hyperParams?.model]);

  // True when either the transform params or the hyper params report at least one error.
  const hasError = !isEmpty({
    ...getErrorsByTransformAndParamName(
      pipelineSections.train,
      width,
      height,
      transformParams,
      limits,
      currentSchema,
    ),
    ...errorsByHyperParams,
  });
  // NOTE(review): `configErrors` is only a dependency here, so this row re-publishes its
  // own error state whenever the shared atom changes; confirm that is intended when
  // several rows are mounted.
  useEffect(() => {
    setConfigErrors(hasError);
  }, [configErrors, hasError, setConfigErrors]);

  const hasResizeOrRescale = useMemo(
    () =>
      pipelineSections.train.some(
        transform => transform.id === resizeId || transform.id === rescaleId,
      ),
    [pipelineSections.train, rescaleId, resizeId],
  );

  const isCreditReferenceEnabled = !useFeatureGateEnabled(ClientFeatures.DisableCreditReference);

  // ensure that the transformUiSchema is not empty when rendering ConfigList
  if (currentTrainingConfig.transformUiSchema.length === 0) {
    return null;
  }

  return (
    <>
      <TableCell className={styles.configCell}>
        {transforms.length > 0 ? (
          <ConfigList
            rowIndex={rowIndex}
            transformType={TransformType.TRANSFORM}
            medias={medias}
            transformParams={preprocessingParams}
            appliedPipeline={pipelineSections}
            resizeId={resizeId}
            rescaleId={rescaleId}
            cropId={cropId}
            applyPipelineSections={applyPipelineSections}
            onOpenPreview={handleModalOpen}
            showSizeAlert={showSizeAlert}
            hasResizeOrRescale={transformParams.length ? hasResizeOrRescale : undefined}
          />
        ) : (
          <CircularProgress color="primary" size={16} />
        )}
      </TableCell>
      <TableCell className={styles.configCell}>
        {transforms.length > 0 ? (
          <ConfigList
            rowIndex={rowIndex}
            transformType={TransformType.AUGMENTATION}
            medias={medias}
            transformParams={augmentationParams}
            appliedPipeline={pipelineSections}
            resizeId={resizeId}
            rescaleId={rescaleId}
            cropId={cropId}
            applyPipelineSections={applyPipelineSections}
            onOpenPreview={handleModalOpen}
          />
        ) : (
          <CircularProgress color="primary" size={16} />
        )}
      </TableCell>
      {isCreditReferenceEnabled && (
        <TableCell className={styles.configCell}>
          {shouldUseAdvancedPricing && (
            <>
              {isCostCalculating ? (
                <CircularProgress color="primary" size={16} />
              ) : (
                <>
                  <Tooltip
                    placement="top"
                    title={t(
                      'Train cost is calculated based on model settings (image resize, epochs) and # of images. If the default settings are used, each image in the project equals 1 credit.',
                    )}
                    arrow
                  >
                    <Box className={styles.costItem}>
                      <Typography>{t('Train')}</Typography>
                      <CostIcon />
                      <Typography variant="body_medium" style={{ color: '#121926' }}>
                        {currentTrainingConfig.cost?.trainCost ?? '--'}
                      </Typography>
                      <Typography>{t('total')}</Typography>
                    </Box>
                  </Tooltip>
                  {!showLargeImagesWarning && (
                    <Tooltip
                      placement="top"
                      title={t(
                        'Prediction cost is calculated based on training settings (image resize, model size) and project type. ',
                      )}
                      arrow
                    >
                      <Box className={styles.costItem}>
                        <Typography>{t('Prediction')}</Typography>
                        <CostIcon />
                        <Typography variant="body_medium" style={{ color: '#121926' }}>
                          {currentTrainingConfig.cost?.predictionCost ?? '--'}
                        </Typography>
                        <Typography style={{ whiteSpace: 'nowrap' }}>
                          {`/ ` + t('image')}
                        </Typography>
                      </Box>
                    </Tooltip>
                  )}
                </>
              )}
            </>
          )}
        </TableCell>
      )}

      {datasetMedias && previewModalOpen && (
        <PreviewModal
          activeStep={activeStep!}
          open
          onClose={() => setPreviewModalOpen(false)}
          loadingAllMedia={false}
          applyPipelineSections={applyPipelineSections}
          pipelineSections={
            hasError
              ? undefined
              : {
                  train:
                    activeStep === 'Augmentation'
                      ? augmentationTransformParams
                      : preprocessingTransformParams,
                  valid: [],
                }
          }
          previewDimensions={activeStep === 'Augmentation' ? undefined : previewDimensions}
          paginatedMedias={medias ?? []}
        />
      )}
    </>
  );
};

export default ConfigCells;
