import { TabManager, Typography } from '@clef/client-library';
import { Box, makeStyles } from '@material-ui/core';
import React, { useState } from 'react';
import { Provider } from 'jotai';
import ModelAnalysisSummary from '../ModelAnalysisSummary';
import ModelConfusionMatrix from './ModelConfusionMatrix';
import {
  useGetBatchModelMetricsQueryByModelId,
  useGetModelEvaluationReportsQuery,
} from '@/serverStore/modelAnalysis';
import { ModelEvaluationReportStatus, RegisteredModel } from '@clef/shared/types';
import { EvaluationSetItem } from '@/api/evaluation_set_api';
import { getTrainDevTestEvaluationSetName, getEvaluationSetName } from '../utils';
import ModelImageList from '../ModelImageList/ModelImageList';
import { isEmpty } from 'lodash';
import LoadingProgress from '../LoadingProgress';
import { BundleInfo } from '@/api/model_api';
import { TabsType } from '@clef/client-library/src/components/TabManager';
import { ConfusionMatrixTable } from './ConfusionMatrixTable';
import OpenInNew from '@material-ui/icons/OpenInNew';
import { useAtom } from 'jotai';
import { modelListFilterOptionsAtom } from '../atoms';

/**
 * JSS classes for the confusion matrix heading row and its inline
 * "Learn more" help link.
 */
const useStyles = makeStyles(theme => {
  return {
    // Full-width heading row with its content centered on both axes.
    title: {
      width: '100%',
      display: 'flex',
      alignItems: 'center',
      justifyContent: 'center',
      marginBottom: theme.spacing(2),
    },
    // Inline link after the heading text, drawn in the primary color.
    helpLink: {
      color: theme.palette.primary.main,
      textDecoration: 'none',
      display: 'flex',
      alignItems: 'center',
      marginLeft: theme.spacing(1),
      fontSize: '0.9rem',
      cursor: 'pointer',
    },
  };
});

/** Props accepted by `PerformanceReportPanel`. */
interface PerformanceReportPanelProps {
  /** The model whose performance is reported. */
  model: RegisteredModel & { bundles?: BundleInfo[] };
  /**
   * Preferred evaluation set. When omitted, the panel uses the first evaluation
   * set that has a completed report.
   */
  evaluationSet?: EvaluationSetItem;
  /**
   * Threshold the completed evaluation reports must match. It is also the
   * starting value of the panel's adjustable threshold.
   */
  threshold: number;
  /**
   * Train/dev/test evaluation sets. Those on the model's dataset version are
   * listed first, in train, dev, test order.
   */
  trainDevTestColumnEvaluationSets?: EvaluationSetItem[];
  /** Additional evaluation sets. Hidden ones are not listed. */
  otherEvaluationSets?: EvaluationSetItem[];
  /**
   * Forwarded to `ModelAnalysisSummary`. Presumably invoked when the user picks
   * another evaluation set there (TODO confirm).
   */
  onChangeEvaluationSet?: (e: EvaluationSetItem) => void;
}

/**
 * Centered hint shown in place of the confusion matrix tabs when no metrics are
 * available for the currently chosen threshold.
 */
const AfterAdjustThresholdTip = () => (
  <Box
    display="flex"
    flexDirection="column"
    alignItems="center"
    justifyContent="center"
    width="100%"
    height="100%"
  >
    <Typography variant="body_regular">
      {t('You need to save the threshold to evaluate the model')}
    </Typography>
  </Box>
);

/** Documentation page opened by the "Learn more" link next to the confusion matrix title. */
const PERFORMANCE_REPORT_DOCS_URL = 'https://support.landing.ai/docs/performance-report';

/**
 * Card tabs for analyzing a model's predictions on one evaluation set.
 *
 * - "Analyze by confusion matrix" shows a `ConfusionMatrixTable` and the
 *   `ModelConfusionMatrix` view. Clicking a table cell stores the
 *   `{ gtClassId, predClassId }` pair in `modelListFilterOptionsAtom`. An image
 *   list is also shown whenever that atom's value is truthy.
 * - "Analyze all images" shows the image list unconditionally.
 *
 * `evaluationSet` supplies the id for `ConfusionMatrixTable`.
 * `selectedEvaluationSet` is passed to `ModelConfusionMatrix` and `ModelImageList`.
 */
const ConfusionMatrixTabs: React.FC<{
  model?: RegisteredModel & { bundles?: BundleInfo[] };
  evaluationSet?: EvaluationSetItem;
  threshold: number;
  selectedEvaluationSet: EvaluationSetItem;
}> = ({ model, evaluationSet, threshold, selectedEvaluationSet }) => {
  const classes = useStyles();
  const [filterOptions, setFilterOptions] = useAtom(modelListFilterOptionsAtom);

  // 'noopener,noreferrer' keeps the opened page from getting a `window.opener`
  // handle back to this app (reverse tabnabbing). window.open(url, '_blank')
  // does not imply noopener on its own.
  const openDocs = () => {
    window.open(PERFORMANCE_REPORT_DOCS_URL, '_blank', 'noopener,noreferrer');
  };

  return (
    <Box marginTop={6}>
      <TabManager
        type={TabsType.Card}
        defaultTab={'byConfusionMatrix'}
        disableSearchParam
        tabs={[
          {
            key: 'byConfusionMatrix',
            title: t('Analyze by confusion matrix'),
            component: (
              <>
                <Box display="flex" flexDirection="column" alignItems="center">
                  <Typography variant="h4" className={classes.title}>
                    {t('Confusion Matrix')}
                    <Box component="a" className={classes.helpLink} onClick={openDocs}>
                      {t('Learn more')}
                      <OpenInNew />
                    </Box>
                  </Typography>
                  <ConfusionMatrixTable
                    model={model}
                    evaluationSetId={evaluationSet?.id}
                    threshold={threshold}
                    onClick={(gtClassId, predClassId) =>
                      setFilterOptions({ gtClassId, predClassId })
                    }
                  />
                </Box>
                <Box display="flex">
                  <ModelConfusionMatrix
                    model={model}
                    evaluationSet={selectedEvaluationSet}
                    threshold={threshold}
                  />
                  {filterOptions && (
                    <ModelImageList
                      model={model}
                      evaluationSet={selectedEvaluationSet}
                      threshold={threshold}
                    />
                  )}
                </Box>
              </>
            ),
          },
          {
            key: 'allModelImages',
            title: t('Analyze all images'),
            component: (
              <Box display="flex">
                <ModelImageList
                  model={model}
                  evaluationSet={selectedEvaluationSet}
                  threshold={threshold}
                />
              </Box>
            ),
          },
        ]}
      />
    </Box>
  );
};

/**
 * Sort rank of the train/dev/test split names. `getSplitRank` ranks any name
 * not listed here last.
 */
const SPLIT_ORDER: Partial<Record<string, number>> = { train: 0, dev: 1, test: 2 };

/**
 * Returns the sort rank of an evaluation set's split name. Unknown or missing
 * names rank after train/dev/test.
 *
 * Without this fallback the comparator would return NaN for such names.
 * `Array.prototype.sort` treats NaN as 0, which makes the comparator
 * inconsistent and the resulting order implementation-defined.
 */
const getSplitRank = (evaluationSet: EvaluationSetItem): number =>
  SPLIT_ORDER[evaluationSet.split?.splitSetName ?? ''] ?? Number.MAX_SAFE_INTEGER;

/**
 * Performance report for one model.
 *
 * Builds the selectable evaluation sets in this order:
 * 1. The train/dev/test sets on the model's dataset version, ordered
 *    train, dev, test, with simplified names.
 * 2. Non-hidden train/dev/test sets on other dataset versions.
 * 3. Other non-hidden evaluation sets.
 *
 * Only sets with a COMPLETED evaluation report at the `threshold` prop are kept.
 * The panel renders a summary with an adjustable threshold. Once metrics exist
 * for the chosen threshold and evaluation set, it also renders the confusion
 * matrix / image tabs.
 */
const PerformanceReportPanel = (props: PerformanceReportPanelProps) => {
  const {
    model,
    evaluationSet: initialEvaluateSet,
    threshold: initialThreshold,
    trainDevTestColumnEvaluationSets,
    otherEvaluationSets,
    onChangeEvaluationSet,
  } = props;

  const { data: evaluationReports, isLoading: isEvaluationReportsLoading } =
    useGetModelEvaluationReportsQuery(model?.id);
  // Only reports that finished at the threshold passed in by the caller count.
  const completedReports = (evaluationReports ?? []).filter(
    r => r.status === ModelEvaluationReportStatus.COMPLETED && r.threshold === initialThreshold,
  );
  const isCurrentModel = (evaluationSet: EvaluationSetItem) =>
    evaluationSet.datasetVersionId === model?.datasetVersionId;
  const evaluationOptions = [
    // train dev test for this model, displayed with simplified names
    ...(trainDevTestColumnEvaluationSets ?? [])
      .filter(isCurrentModel)
      .sort((s1, s2) => getSplitRank(s1) - getSplitRank(s2))
      .map(evaluationSet => ({
        evaluationSet,
        name: getTrainDevTestEvaluationSetName(evaluationSet),
      })),
    // train dev test for other models, display full names
    ...(trainDevTestColumnEvaluationSets ?? [])
      .filter(e => !isCurrentModel(e) && !e.hidden)
      .map(evaluationSet => ({
        evaluationSet,
        name: getEvaluationSetName(evaluationSet),
      })),
    // other evaluation sets, display full names
    ...(otherEvaluationSets ?? [])
      .filter(e => !e.hidden)
      .map(evaluationSet => ({
        evaluationSet,
        name: getEvaluationSetName(evaluationSet),
      })),
  ].filter(o => {
    // show only those eval sets that have completed evaluation
    return completedReports.some(r => r.evaluationSetId === o.evaluationSet.id);
  });
  const evaluationSet = initialEvaluateSet ?? evaluationOptions[0]?.evaluationSet;

  // `evaluationSet` may not be one of the options (e.g. its report has not
  // completed at this threshold). In that case the displayed data falls back to
  // the first option, so everything below the summary must use
  // `selectedEvaluationSet`.
  // NOTE(review): ModelAnalysisSummary still receives `evaluationSet`. Confirm
  // whether its selector should show the fallback set instead.
  const selectedOption =
    evaluationOptions.find(option => option.evaluationSet.id === evaluationSet?.id) ??
    (evaluationOptions.length ? evaluationOptions[0] : null);
  const selectedEvaluationSet = selectedOption?.evaluationSet ?? null;

  // The adjustable threshold starts at the prop value. Metrics are re-queried
  // for it.
  const [threshold, setThreshold] = useState(initialThreshold);
  const { data: batchModelMetrics, isLoading: isModelMetricsLoading } =
    useGetBatchModelMetricsQueryByModelId(model.id, threshold) ?? {};
  const modelMetrics = batchModelMetrics?.find(
    pred => selectedEvaluationSet?.id && pred.evaluationSetId === selectedEvaluationSet.id,
  );

  return isEvaluationReportsLoading ? (
    <LoadingProgress />
  ) : evaluationOptions.length === 0 ? (
    <Typography>
      {t(
        'There is no evaluation set for this model, please run evaluation first on models table page.',
      )}
    </Typography>
  ) : (
    <>
      {/* summary */}
      <ModelAnalysisSummary
        performance={modelMetrics?.metrics?.all}
        model={model}
        threshold={threshold}
        setThreshold={setThreshold}
        initialThreshold={initialThreshold}
        evaluationSet={evaluationSet}
        evaluationOptions={evaluationOptions}
        onChangeEvaluationSet={onChangeEvaluationSet}
      />
      {isModelMetricsLoading ? (
        <LoadingProgress />
      ) : isEmpty(modelMetrics) ? (
        <AfterAdjustThresholdTip />
      ) : (
        // Pass the same set that `modelMetrics` was picked for. Otherwise the
        // confusion matrix table could query a different set than the rest of
        // the tabs.
        selectedEvaluationSet && (
          <Provider>
            <ConfusionMatrixTabs
              model={model}
              evaluationSet={selectedEvaluationSet}
              threshold={threshold}
              selectedEvaluationSet={selectedEvaluationSet}
            />
          </Provider>
        )
      )}
    </>
  );
};

export default PerformanceReportPanel;
