import React, { useMemo } from 'react';
import { Box, makeStyles } from '@material-ui/core';
import { Media } from '@clef/shared/types';

import { SuccessIcon, ErrorIcon } from '@/images/data-browser/icons';
import {
  useGetSelectedProjectQuery,
  useGetProjectVersionedDefectsQuery,
} from '@/serverStore/projects';
import MediaContainer from '@/pages/DataBrowser/MediaContainer';
import { calcOptimalRatio } from '@/pages/DataBrowser/MediaGrid/MediaGrid';
import { getClassifiedClass } from '@/utils';
import { useDatasetMediaDetailsQuery } from '@/serverStore/dataset';
import { Typography } from '@clef/client-library';

/**
 * JSS styles for this file. `diffViewDetailItem` lays out a model's match
 * icon, predicted class name and threshold in one vertically centered row.
 */
const useStyles = makeStyles(() => {
  return {
    diffViewDetailItem: {
      alignItems: 'center',
      display: 'flex',
      gap: 10,
    },
  };
});

/** Props for the per-media classification comparison view. */
type ClassificationDiffViewProps = {
  /** Media whose ground-truth class and model predictions are compared. */
  media: Media;
  /** Dataset version, passed as `versionId` to the media-details queries and to the versioned-defects query. */
  version?: number;
  /** Baseline model id; its media-details query also supplies the ground-truth label. */
  modelId?: string;
  /** Value displayed next to the baseline prediction. */
  threshold?: number;
  /** Candidate model id used for the second media-details query. */
  candidateModelId?: string;
  /** Value displayed next to the candidate prediction. */
  candidateThreshold?: number;
};

/**
 * Three-column comparison for a single media:
 * - its ground-truth class;
 * - the baseline model's predicted class;
 * - the candidate model's predicted class.
 *
 * Each prediction gets a success/error icon depending on whether its class name
 * matches the ground-truth class name. The icon is shown only once both names
 * have resolved. The threshold passed in for that model follows the icon.
 *
 * NOTE(review): `t` is not imported in this file; presumably it is a global
 * i18n helper — confirm.
 */
const ClassificationDiffView: React.FC<ClassificationDiffViewProps> = ({
  media,
  modelId,
  threshold,
  candidateModelId,
  candidateThreshold,
  version,
}) => {
  const styles = useStyles();
  const { datasetId } = useGetSelectedProjectQuery().data ?? {};

  // The baseline query supplies both the ground-truth label and the baseline
  // prediction; the candidate query is only used for its prediction.
  const { data: baselineMediaDetails } = useDatasetMediaDetailsQuery({
    datasetId,
    mediaId: media.id,
    modelId,
    versionId: version,
  });

  const { data: candidateMediaDetails } = useDatasetMediaDetailsQuery({
    datasetId,
    mediaId: media.id,
    modelId: candidateModelId,
    versionId: version,
  });

  const { data: versionedDefects } = useGetProjectVersionedDefectsQuery(version);

  const gtAnnotations = baselineMediaDetails?.label?.annotations;
  const baselineAnnotations = baselineMediaDetails?.predictionLabel?.annotations;
  const candidateAnnotations = candidateMediaDetails?.predictionLabel?.annotations;

  const gtClassifiedClass = useMemo(
    () => getClassifiedClass(gtAnnotations ?? [], versionedDefects ?? []),
    [gtAnnotations, versionedDefects],
  );

  const baselineClassifiedClass = useMemo(
    () => getClassifiedClass(baselineAnnotations ?? [], versionedDefects ?? []),
    [baselineAnnotations, versionedDefects],
  );

  const candidateClassifiedClass = useMemo(
    () => getClassifiedClass(candidateAnnotations ?? [], versionedDefects ?? []),
    [candidateAnnotations, versionedDefects],
  );

  const gtName = gtClassifiedClass?.name;
  const baselineName = baselineClassifiedClass?.name;
  const candidateName = candidateClassifiedClass?.name;

  // Show an indicator only when both the ground truth and the prediction have
  // resolved to a class. Comparing unresolved names would otherwise flash a
  // false result. Two `undefined`s compare equal and show SuccessIcon. One
  // side still loading shows ErrorIcon.
  const renderMatchIcon = (predictedName: typeof gtName) => {
    if (predictedName == null || gtName == null) {
      return null;
    }
    return predictedName === gtName ? <SuccessIcon /> : <ErrorIcon />;
  };

  return (
    <Box display="flex" alignItems="center">
      <Box flex={1}>
        <Typography variant="body_bold">{t('Ground truth')}</Typography>
        <Typography>{gtName}</Typography>
      </Box>
      <Box flex={1}>
        <Typography variant="body_bold">{t('Baseline model')}</Typography>
        <Box className={styles.diffViewDetailItem}>
          {renderMatchIcon(baselineName)}
          <Typography variant="body_medium">{baselineName}</Typography>
          <Typography>{threshold}</Typography>
        </Box>
      </Box>
      <Box flex={1}>
        <Typography variant="body_bold">{t('Candidate model')}</Typography>
        <Box className={styles.diffViewDetailItem}>
          {renderMatchIcon(candidateName)}
          <Typography variant="body_medium">{candidateName}</Typography>
          <Typography>{candidateThreshold}</Typography>
        </Box>
      </Box>
    </Box>
  );
};

/** Props for the classification model-comparison image list. */
export type ModelImageListClassificationProps = {
  /** Media rows to render; nothing is rendered when undefined. */
  mediaList?: Media[];
  /** Width used together with the media aspect ratio to size each preview. */
  containerWidth: number;
  /** Forwarded to each media preview. */
  showHeatmap: boolean;
  /** Dataset version forwarded to the preview and the comparison view. */
  version?: number;
  /** Baseline model id, used by the preview and the comparison view. */
  modelId?: string;
  /** Baseline threshold, used by the preview and displayed in the comparison view. */
  threshold?: number;
  /** Candidate model id for the comparison view. */
  candidateModelId?: string;
  /** Candidate threshold displayed in the comparison view. */
  candidateThreshold?: number;
};

/**
 * Renders one row per media. Each row has:
 * - a preview rendered with the baseline model's id and threshold, in the
 *   first of three equal-width columns;
 * - a ClassificationDiffView comparing the ground truth against the baseline
 *   and candidate predictions.
 *
 * Renders nothing when `mediaList` is undefined.
 */
const ModelComparisonImageListClassification: React.FC<
  ModelImageListClassificationProps
> = ({
  showHeatmap,
  mediaList,
  modelId,
  threshold,
  candidateModelId,
  candidateThreshold,
  version,
  containerWidth,
}) => {
  if (!mediaList) {
    return null;
  }

  // The preview occupies one of `columnCount` equal slots. Its height follows
  // the list's optimal aspect ratio and is the same for every row.
  const columnCount = 3;
  const previewWidth = `${100 / columnCount}%`;
  const previewHeight = (containerWidth * calcOptimalRatio(mediaList)) / columnCount;

  return (
    <>
      {mediaList.map(media => (
        <Box key={media.id} display="flex">
          <Box width={previewWidth} height={previewHeight} maxHeight={300}>
            <MediaContainer
              media={media}
              showHeatmap={showHeatmap}
              showClassChip={false}
              modelId={modelId}
              threshold={threshold}
              versionId={version}
            />
          </Box>
          <Box flex={1} paddingLeft={7}>
            <ClassificationDiffView
              media={media}
              version={version}
              modelId={modelId}
              threshold={threshold}
              candidateModelId={candidateModelId}
              candidateThreshold={candidateThreshold}
            />
          </Box>
        </Box>
      ))}
    </>
  );
};

export default ModelComparisonImageListClassification;
