import React, { useEffect, useState } from 'react';
import { useAtom } from 'jotai';
import { AnnotationInstance, LabelType, Media, RegisteredModel } from '@clef/shared/types';
import { EvaluationSetItem } from '@/api/evaluation_set_api';
import { useGetModelMediaListInfiniteQuery } from '@/serverStore/modelAnalysis';
import { useGetSelectedProjectQuery } from '@/serverStore/projects';
import { ModelImageListSegmentationDiffView } from './ModelImageListSegmentation';
import { FilterOptions, SortOrder } from '@/api/model_analysis_api';
import { Box, makeStyles, MenuItem, Select } from '@material-ui/core';
import { Button, Typography } from '@clef/client-library';
import ModelComparisonImageListObjectDetection from './ModelComparisonImageListObjectDetection';
import ModelComparisonImageListSegmentation from './ModelComparisonImageListSegmentation';
import ModelComparisonImageListClassification from './ModelComparisonImageListClassification';
import ModelImageDetailDialog from '../ModelImageDetail/ModelImageDetailDialog';
import LoadingProgress from '../LoadingProgress';
import ModelImageListClassification from './ModelImageListClassification';
import { modelListFilterOptionsAtom } from '../atoms';
import ModelImageVirtualListWrapper from './ModelImageVirtualListWrapper';
import { ObjectDetectionImageDiffView } from '../ModelImageDetail/ObjectDetectionImageDiffView';

/**
 * JSS classes for the model image list: the list container, the total-images header row, the
 * "Sort by" label, and a few chip / column-title helpers.
 */
const useStyles = makeStyles(theme => {
  const greyModern = theme.palette.greyModern;
  return {
    mediaListRoot: { width: '100%' },
    chip: {
      backgroundColor: greyModern[200],
      padding: theme.spacing(1.5, 3),
      borderRadius: 10,
      display: 'inline-flex',
      alignItems: 'center',
    },
    removeChip: {
      marginLeft: theme.spacing(2),
      color: greyModern[400],
      cursor: 'pointer',
    },
    sortByText: { paddingRight: theme.spacing(3) },
    imageListColumnTitle: { paddingBottom: theme.spacing(2) },
    totalImages: {
      display: 'flex',
      alignItems: 'center',
      gap: theme.spacing(1),
    },
  };
});

/** Props for `ModelImageList`. Every field is optional. */
export type ModelImageListProps = {
  /** Model whose predictions are listed; its `id` is sent to the media query and detail dialog. */
  model?: RegisteredModel;
  /** Threshold for `model`, forwarded to the media query, list renderers and detail dialog. */
  threshold?: number;
  /** Evaluation set whose media are listed (its dataset version is passed to the renderers). */
  evaluationSet?: EvaluationSetItem;
  /**
   * Optional second model. When set the component switches to comparison mode: the comparison
   * renderers are used and the media query receives no sort order (sorting happens client-side).
   */
  candidate?: RegisteredModel;
  /** Threshold for `candidate`, forwarded alongside `candidate`'s id. */
  candidateThreshold?: number;
};

/**
 * Label for one entry of the sort-order dropdown.
 *
 * Without an active filter the label reads "Improvement (...)". With a filter it reads
 * "Error count (...)" for bounding-box and classification projects and "Error area (...)" for
 * every other label type. The parenthesised suffix reflects `order` (high to low / low to high).
 */
const getSortOptionText = (
  labelType: LabelType | undefined | null,
  order: SortOrder,
  filterOptions: FilterOptions | undefined,
) => {
  const selectTextSuffix = order === SortOrder.DESC ? t('high to low') : t('low to high');
  if (!filterOptions) {
    return t(`Improvement ({{selectTextSuffix}})`, { selectTextSuffix });
  }
  if (labelType === LabelType.BoundingBox || labelType === LabelType.Classification) {
    return t(`Error count ({{selectTextSuffix}})`, { selectTextSuffix });
  }
  return t(`Error area ({{selectTextSuffix}})`, { selectTextSuffix });
};

/**
 * Image list for model evaluation.
 *
 * Fetches the evaluation-set media for `model` through a paged (infinite) query and renders them
 * with a label-type-specific list: segmentation, object detection (bounding box) or
 * classification. When `candidate` is set, the comparison variants are rendered instead and the
 * query is not asked to sort. A shared header shows the image total, a clear-filter / close
 * button while a filter is active, and a sort-order selector. Image clicks in the single-model
 * lists set the `mediaId` shown by `ModelImageDetailDialog`. Other label types render nothing.
 */
const ModelImageList: React.FC<ModelImageListProps> = props => {
  const styles = useStyles();
  const { model, threshold, candidate, candidateThreshold, evaluationSet } = props;
  const [filterOptions, setFilterOptions] = useAtom(modelListFilterOptionsAtom);
  // Filtered lists default to high-to-low, unfiltered lists to low-to-high.
  const [sortOrder, setSortOrder] = useState<SortOrder>(
    filterOptions ? SortOrder.DESC : SortOrder.ASC,
  );
  const { data: project } = useGetSelectedProjectQuery();
  const { labelType } = project ?? {};

  // Re-apply that default whenever the filter changes, discarding any manual sort choice.
  useEffect(() => {
    setSortOrder(filterOptions ? SortOrder.DESC : SortOrder.ASC);
  }, [filterOptions]);

  // Measured width of the list container, forwarded to the renderers; stays 0 until measured.
  const [containerWidth, setContainerWidth] = useState(0);
  // Media id shown in the detail dialog; undefined while no image is selected.
  const [selectedImageId, setSelectedImageId] = useState<number>();

  const { data: mediaListPages, isLoading } = useGetModelMediaListInfiniteQuery(
    model?.id,
    threshold,
    evaluationSet,
    filterOptions,
    candidate?.id,
    candidateThreshold,
    candidate ? undefined : sortOrder, // sort in frontend for model comparison
  );

  // Flatten every fetched page into one list, skipping pages without a media list and empty
  // entries. Undefined until the first page arrives.
  const allMedias = mediaListPages?.pages
    .flatMap(page => page?.mediaList ?? [])
    .filter(media => !!media) as Media[] | undefined;

  const total = mediaListPages?.pages[0]?.total;
  // Check for presence rather than truthiness. `0 && <Box/>` would render a stray "0", and hiding
  // the header on an empty result would also hide the button needed to clear the filter.
  const totalImagesComponent =
    total != null ? (
      <Box className={styles.totalImages}>
        <Typography variant="body_medium">{t('{{total}} images', { total })}</Typography>
        {filterOptions && (
          <Button
            id="model-comparison-report-clear-filter-button"
            variant="text"
            size="medium"
            color="primary"
            onClick={() => setFilterOptions(undefined)}
          >
            {candidate ? t('Close') : t('Clear filter')}
          </Button>
        )}
      </Box>
    ) : null;

  const handleSortOrderChange = (event: React.ChangeEvent<{ value: unknown }>) => {
    setSortOrder(event.target.value as SortOrder);
  };

  const sortOptions = (
    <Box display="flex" alignItems="center" marginLeft="auto">
      <Typography className={styles.sortByText}>{t('Sort by')}</Typography>
      <Select id="sort-order-select" value={sortOrder} onChange={handleSortOrderChange}>
        <MenuItem value={SortOrder.ASC}>
          {getSortOptionText(labelType, SortOrder.ASC, filterOptions)}
        </MenuItem>
        <MenuItem value={SortOrder.DESC}>
          {getSortOptionText(labelType, SortOrder.DESC, filterOptions)}
        </MenuItem>
      </Select>
    </Box>
  );

  // Callback ref: record the container's width whenever it is non-zero and has changed.
  const measureContainer = (element: HTMLDivElement | null) => {
    if (element?.clientWidth && element.clientWidth !== containerWidth) {
      setContainerWidth(element.clientWidth);
    }
  };

  // Shared layout for every label type: header row (total + sort) above the measured container.
  const renderListSection = (containerId: string, list: React.ReactNode) => (
    <Box flex={1} marginLeft={7}>
      <Box
        display="flex"
        flexDirection="row"
        alignItems="flex-start"
        justifyContent="space-between"
        marginBottom={5}
      >
        {totalImagesComponent}
        {sortOptions}
      </Box>
      <div id={containerId} className={styles.mediaListRoot} ref={measureContainer}>
        {list}
      </div>
    </Box>
  );

  // Render nothing (rather than a placeholder string) for label types without a list view.
  let imageListComponent: JSX.Element | null = null;
  if (isLoading || !allMedias) {
    imageListComponent = <LoadingProgress size={24} />;
  } else if (labelType === LabelType.Segmentation) {
    const list = candidate ? (
      <ModelComparisonImageListSegmentation
        containerWidth={containerWidth}
        modelId={model?.id}
        sortOrder={sortOrder}
        evaluationSet={evaluationSet}
        threshold={threshold}
        candidateModelId={candidate.id}
        candidateThreshold={candidateThreshold}
      />
    ) : (
      <ModelImageVirtualListWrapper
        titles={[t('Ground truth mask'), t('Prediction mask')]}
        evaluationSet={evaluationSet}
        modelId={model?.id}
        sortOrder={sortOrder}
        threshold={threshold}
        containerWidth={containerWidth}
        rowRender={(media: Media, rowWidth: number) => (
          <ModelImageListSegmentationDiffView
            key={media.id}
            media={media}
            rowWidth={rowWidth}
            modelId={model?.id}
            version={evaluationSet?.datasetVersion.version}
            threshold={threshold}
            onImageClick={() => setSelectedImageId(media.id)}
          />
        )}
      />
    );
    // Wait for a measured width; render nothing (not the number 0) until then.
    imageListComponent = renderListSection(
      'model-image-list-segmentation',
      containerWidth > 0 ? list : null,
    );
  } else if (labelType === LabelType.BoundingBox) {
    const list = candidate ? (
      <ModelComparisonImageListObjectDetection
        modelId={model?.id}
        evaluationSet={evaluationSet}
        threshold={threshold}
        sortOrder={sortOrder}
        containerWidth={containerWidth}
        candidateModelId={candidate.id}
        candidateThreshold={candidateThreshold}
      />
    ) : (
      <ModelImageVirtualListWrapper
        containerWidth={containerWidth}
        titles={[t('Ground truth'), t('Prediction')]}
        evaluationSet={evaluationSet}
        modelId={model?.id}
        threshold={threshold}
        sortOrder={sortOrder}
        onboardingTipsKey="odHightlightGtAndPredictionPairTips"
        rowRender={(media: Media, rowWidth: number, allInstances?: AnnotationInstance[]) => (
          <ObjectDetectionImageDiffView
            key={media.id}
            media={media}
            modelId={model?.id}
            version={evaluationSet?.datasetVersion.version}
            threshold={threshold}
            allInstances={allInstances}
            rowWidth={rowWidth}
            onImageClick={imageId => setSelectedImageId(imageId)}
          />
        )}
      />
    );
    // Wait for a measured width; render nothing (not the number 0) until then.
    imageListComponent = renderListSection(
      'model-image-list-object-detection',
      containerWidth > 0 ? list : null,
    );
  } else if (labelType === LabelType.Classification) {
    // Classification lists render immediately from the already-fetched media.
    imageListComponent = renderListSection(
      'model-image-list-classification',
      candidate ? (
        <ModelComparisonImageListClassification
          modelId={model?.id}
          mediaList={allMedias}
          version={evaluationSet?.datasetVersion.version}
          containerWidth={containerWidth}
          candidateModelId={candidate.id}
        />
      ) : (
        <ModelImageListClassification
          modelId={model?.id}
          mediaList={allMedias}
          version={evaluationSet?.datasetVersion.version}
          onImageClick={imageId => setSelectedImageId(imageId)}
        />
      ),
    );
  }

  return (
    <>
      {imageListComponent}
      <ModelImageDetailDialog
        modelId={model?.id}
        threshold={threshold}
        candidateModelId={candidate?.id}
        candidateThreshold={candidateThreshold}
        evaluationSet={evaluationSet}
        mediaId={selectedImageId}
        mediaList={allMedias}
        onClose={() => setSelectedImageId(undefined)}
      />
    </>
  );
};

// Default export consumed by importers of this module.
export { ModelImageList as default };
