import React, { useMemo, useState } from 'react';
import cx from 'classnames';
import { Box, makeStyles, TextField } from '@material-ui/core';
import { Typography } from '@clef/client-library';
import Search from '@material-ui/icons/Search';
import ConfusionMatrixEmptyStateSVG from '@/images/empty-state/confusion-matrix.svg';
import { useGetConfusionMatrixQuery } from '@/serverStore/modelAnalysis';
import useGetDefectNameById from '@/hooks/defect/useGetDefectNameById';
import { PredictionMatrixData } from '@/pages/DataBrowser/ModelPerformance/ConfusionMatrix';
import { EvaluationSetItem } from '@/api/evaluation_set_api';
import { AggregatedConfusionMatrix, LabelType, RegisteredModel } from '@clef/shared/types';
import { useGetSelectedProjectQuery } from '@/serverStore/projects';
import LoadingProgress from '../LoadingProgress';
import { ConfusionMatrixDiffTable } from './ConfusionMatrixDiffTable';
import { useAtom } from 'jotai';
import { modelListFilterOptionsAtom } from '../atoms';

/**
 * JSS styles for ModelConfusionMatrix: the section header, the class search
 * field, the table-style header row, and the empty-state placeholder card.
 *
 * NOTE(review): `clearFilterButton`, `correctColor`, `incorrectColor` and
 * `infoTitle` are not referenced anywhere in this file — confirm they are
 * unused and remove them. `HeaderWithClearButton` is the only PascalCase key
 * (the rest are camelCase), and no clear button is rendered beside the header
 * in this file; consider renaming it.
 */
const useStyles = makeStyles(theme => ({
  // Flex row that pushes its children to opposite ends, vertically centered.
  HeaderWithClearButton: {
    display: 'flex',
    justifyContent: 'space-between',
    alignItems: 'center',
  },
  // Link-like clickable text; not referenced in this file (see note above).
  clearFilterButton: {
    color: theme.palette.blue[600],
    fontSize: 12,
    fontWeight: 700,
    lineHeight: '16px',
    cursor: 'pointer',
  },
  // Magnifier icon used as the search field's start adornment.
  searchIcon: {
    color: theme.palette.grey[500],
    marginRight: theme.spacing(1),
  },
  // Applied to the search field's input wrapper via InputProps.className.
  confusionMatrixSearchField: {
    borderRadius: 6,
    height: 36,
    background: theme.palette.greyModern[25],
  },
  // Not referenced in this file (see note above).
  correctColor: {
    color: theme.palette.green[500],
  },
  // Not referenced in this file (see note above).
  incorrectColor: {
    color: theme.palette.error.main,
  },
  // Fixed height for the Ground Truth / Prediction / Count header row.
  matrixRow: {
    height: 36,
  },
  // Extra inset for the first column; combined with matrixCell.
  matrixLeftCell: {
    paddingLeft: theme.spacing(2),
  },
  // Extra inset for the last column; combined with matrixCell.
  matrixRightCell: {
    paddingRight: theme.spacing(2),
  },
  // Shared by every header cell rendered with display="table-cell".
  matrixCell: {
    verticalAlign: 'middle',
  },
  // Not referenced in this file (see note above).
  infoTitle: {
    display: 'flex',
    alignItems: 'center',
    gap: theme.spacing(1),
  },
  // Centered column card shown instead of the table when there are no filter options.
  emptyState: {
    width: '100%',
    display: 'flex',
    gap: theme.spacing(4),
    alignItems: 'center',
    flexDirection: 'column',
    borderRadius: '10px',
    padding: theme.spacing(14, 0),
    background: theme.palette.grey[50],
  },
  // Square box for the empty-state illustration.
  emptyStateIcon: {
    width: theme.spacing(14),
    height: theme.spacing(14),
  },
}));

/**
 * Props for the ModelConfusionMatrix component.
 *
 * Every field is optional; the identifiers and threshold are handed to the
 * confusion-matrix query that selects which matrix is displayed.
 */
export type ModelConfusionMatrixProps = {
  /** Model whose predictions are analysed (its `id` is passed to the query). */
  model?: RegisteredModel;
  /** Evaluation set the model was run against (its `id` is passed to the query). */
  evaluationSet?: EvaluationSetItem;
  /** Confidence threshold forwarded as-is to the confusion-matrix query. */
  threshold?: number;
};

/**
 * Per-class breakdown of a model's confusion matrix on an evaluation set.
 *
 * Fetches the split confusion matrices for `model` / `evaluationSet` at
 * `threshold`. It renders them as up to four sections, in this order:
 * False Positive, False Negative, Misclassified and Correct.
 * - Cells with a zero count are dropped.
 * - Sections left without rows are not rendered.
 * - A search field keeps only rows whose ground-truth or prediction class name
 *   contains the search text. Matching is case-insensitive and ignores
 *   surrounding whitespace.
 *
 * A spinner is shown while the query is loading. When
 * `modelListFilterOptionsAtom` holds a falsy value, an empty-state prompt
 * replaces the table.
 */
const ModelConfusionMatrix: React.FC<ModelConfusionMatrixProps> = props => {
  const styles = useStyles();
  const { model, evaluationSet, threshold } = props;
  const { data: confusionMatrixData, isLoading: isConfusionMatrixDataLoading } =
    useGetConfusionMatrixQuery(model?.id, evaluationSet?.id, threshold);
  const [searchText, setSearchText] = useState('');
  const [filterOptions] = useAtom(modelListFilterOptionsAtom);

  const getDefectNameById = useGetDefectNameById();
  const {
    correctConfusionMatrix,
    falsePositiveConfusionMatrix,
    falseNegativeConfusionMatrix,
    misClassificationConfusionMatrix,
  } = useMemo(() => {
    // Trimmed so stray spaces typed around a class name don't hide every row.
    const lowerSearchText = searchText.trim().toLowerCase();

    // Drops zero-count cells and resolves class ids to display names.
    // It then keeps only rows whose ground-truth or prediction name contains
    // the search text. A missing matrix is treated as empty.
    const addNameAndFilterSearchKeyToMatrix = (
      confusionMatrix: AggregatedConfusionMatrix[] | null | undefined,
    ): PredictionMatrixData[] => {
      const confusionMatrixWithNames = (confusionMatrix ?? [])
        .filter(m => m.count > 0)
        .map(
          item =>
            ({
              gtDefectId: item.gtClassId,
              predDefectId: item.predClassId,
              count: item.count,
              // NOTE(review): any falsy class id (including 0) is shown as
              // 'No label' — confirm 0 is never a real class id.
              gtCaption: item.gtClassId ? getDefectNameById(item.gtClassId) : 'No label',
              predictionCaption: item.predClassId
                ? getDefectNameById(item.predClassId)
                : 'No label',
            } as PredictionMatrixData),
        );
      if (!lowerSearchText) {
        return confusionMatrixWithNames;
      }
      return confusionMatrixWithNames.filter(
        c =>
          c.gtCaption.toLowerCase().includes(lowerSearchText) ||
          c.predictionCaption.toLowerCase().includes(lowerSearchText),
      );
    };

    const { correct, misClassification, falseNegative, falsePositive } =
      confusionMatrixData?.splitConfusionMatrices ?? {};
    return {
      correctConfusionMatrix: addNameAndFilterSearchKeyToMatrix(correct?.data),
      misClassificationConfusionMatrix: addNameAndFilterSearchKeyToMatrix(misClassification?.data),
      falseNegativeConfusionMatrix: addNameAndFilterSearchKeyToMatrix(falseNegative?.data),
      falsePositiveConfusionMatrix: addNameAndFilterSearchKeyToMatrix(falsePositive?.data),
    };
  }, [confusionMatrixData, searchText, getDefectNameById]);

  const { data: project } = useGetSelectedProjectQuery();
  const { labelType } = project ?? {};
  // Segmentation matrices count pixels rather than instances.
  const countTitle = labelType === LabelType.Segmentation ? t('Pixels') : t('Count');

  // All hooks are above this point, so the early returns below are safe.
  if (isConfusionMatrixDataLoading) {
    return <LoadingProgress size={24} />;
  }

  if (!filterOptions) {
    return (
      <Box className={styles.emptyState}>
        {/* Decorative illustration: the caption below carries the message, so alt is empty. */}
        <img className={styles.emptyStateIcon} src={ConfusionMatrixEmptyStateSVG} alt="" />
        <Typography variant="body2">
          {t(
            'Click on a cell above to view the specific ground truth and prediction pair of instances.',
          )}
        </Typography>
      </Box>
    );
  }

  // Section descriptors in render order. Each section's key, labels and rows
  // live together, replacing four index-aligned arrays. Those arrays could
  // silently drift out of sync when a section was added or reordered.
  const sections = [
    {
      key: 'false-positive',
      title: t('False Positive'),
      tooltip: t(
        'The model predicted that an object of interest was present, but the model was incorrect.',
      ),
      matrix: falsePositiveConfusionMatrix,
      isCorrectMapping: false,
    },
    {
      key: 'false-negative',
      title: t('False Negative'),
      tooltip: t(
        'The model predicted that an object of interest was not present, but the model was incorrect.',
      ),
      matrix: falseNegativeConfusionMatrix,
      isCorrectMapping: false,
    },
    {
      key: 'mis-classified',
      title: t('Misclassified'),
      tooltip: t(
        'The model correctly predicted that an object of interest was present, but it predicted the wrong class.',
      ),
      matrix: misClassificationConfusionMatrix,
      isCorrectMapping: false,
    },
    {
      key: 'correct',
      title: t('Correct'),
      tooltip: t('The model’s prediction was correct.'),
      matrix: correctConfusionMatrix,
      isCorrectMapping: true,
    },
  ];

  return (
    <Box id="confusion-matrix-section" flexShrink={0} flexGrow={0}>
      <Box className={styles.HeaderWithClearButton} marginBottom={4}>
        <Typography variant="body_bold">{t('Analyze')}</Typography>
      </Box>
      <Box marginBottom={5}>
        <TextField
          variant="outlined"
          placeholder={t('Search by class')}
          InputProps={{
            className: styles.confusionMatrixSearchField,
            startAdornment: <Search className={styles.searchIcon} />,
          }}
          value={searchText}
          onChange={e => setSearchText(e.target.value ?? '')}
        />
      </Box>
      <Box>
        <Box display="table" width="100%">
          <Box display="table-row" className={styles.matrixRow}>
            <Box
              display="table-cell"
              className={cx(styles.matrixCell, styles.matrixLeftCell)}
              width={150}
            >
              <Typography variant="body2">
                <strong>{t('Ground Truth')}</strong>
              </Typography>
            </Box>
            <Box display="table-cell" className={styles.matrixCell} width={190}>
              <Typography variant="body2">
                <strong>{t('Prediction')}</strong>
              </Typography>
            </Box>
            <Box display="table-cell" className={cx(styles.matrixCell, styles.matrixRightCell)}>
              <Typography variant="body2">
                <strong>{countTitle}</strong>
              </Typography>
            </Box>
          </Box>
        </Box>
        {sections
          .filter(section => section.matrix.length > 0)
          .map(section => (
            <ConfusionMatrixDiffTable
              key={section.key}
              title={section.title}
              titleTooltip={section.tooltip}
              isCorrectMapping={section.isCorrectMapping}
              confusionMatrixSum={section.matrix.reduce((sum, entry) => sum + entry.count, 0)}
              confusionMatrices={section.matrix}
            />
          ))}
      </Box>
    </Box>
  );
};

export default ModelConfusionMatrix;
