import {
  AnnotationType,
  ClassificationConfusionMatrix,
  DatasetGroupOptions,
  Defect,
  DefectId,
  Dimensions,
  InstanceViewMode,
  MediaDetails,
  SelectMediaOption,
  BoxAnnotation,
} from '@clef/shared/types';
import { hexToRgb } from '@clef/shared/utils';
import React, { useMemo, useEffect, useState } from 'react';
import { useCurrentProjectModelInfoQuery } from '@/serverStore/projectModels';
import useImage from 'use-image';
import { runLengthDecode } from '@clef/shared/utils';
import { PureCanvasLabelingAnnotation } from '../../components/Labeling/labelingState';
import {
  useAnnotationInstancesCountQuery,
  useAnnotationInstancesQuery,
  useDatasetMediaCountQuery,
  useDatasetMediaDetailsQuery,
  useDatasetMediasQuery,
} from '@/serverStore/dataset';
import { useGetDatasetStatsQuery } from '@/serverStore/dataset';
import { useGetSelectedProjectQuery } from '@/serverStore/projects';
import {
  useCurrentDefectsWithArchivedAndIndexFromVersion,
  useDefectSelectorWithArchived,
} from '../../store/defectState/actions';
import { getDefectColor } from '../../utils';
import {
  appliedFilterMappingToFormattedFilterMapping,
  DataBrowserState,
  getColumnFilterMapWithModelId,
  useDataBrowserState,
} from './dataBrowserState';
import {
  imageDataToOffscreenCanvas,
  SegmentationAnnotation,
  useDebouncedEffect,
  useLoadImage,
  CanvasAnnotation,
  AnnotationSourceType,
} from '@clef/client-library';
import { useSubscriptionPlanSettings } from '../../hooks/api/useSubscriptionApi';
import { useModels } from '../../hooks/useModels';
import { includesSkipDialogById, SkipDialogKey } from '../../utils/train';
import { canvasToBitMap } from '@/components/Labeling/utils';
import { isEmpty } from 'lodash';
import LRUCache from 'lru-cache';
import SegPredictionParserWorker from './segPredictionParser.worker';
import { FilterOptions } from '@/api/model_analysis_api';
import { useAtom } from 'jotai';
import { thresholdForPredictAtom } from '@/uiStates/projectModels/pageUIStates';

/**
 * Rasterizes `image` onto an offscreen canvas and returns its raw RGBA bytes
 * (4 bytes per pixel, row-major, `width * height * 4` bytes in total).
 *
 * @param image image to read; when absent an empty array is returned
 * @param dimensions size of the canvas to draw into; defaults to the image's own `width`/`height`
 * @returns the pixel bytes, or an empty `Uint8ClampedArray` when there is no image or either
 *   dimension is not positive (`getImageData` throws on a zero-sized region)
 */
export const imageToImageData = (
  image?: HTMLImageElement,
  { width, height }: Dimensions = { width: image?.width ?? 0, height: image?.height ?? 0 },
) => {
  if (!image || width <= 0 || height <= 0) {
    return new Uint8ClampedArray();
  }
  const offscreen = new OffscreenCanvas(width, height);
  const context = offscreen.getContext('2d', {
    desynchronized: true,
  })!;
  // No interpolation: callers decode category / score values straight from these bytes.
  context.imageSmoothingEnabled = false;
  context.drawImage(image, 0, 0);
  return context.getImageData(0, 0, width, height).data;
};

/** Filtered overlays for one model: its ground truth and its prediction, each as url + canvas. */
interface FilteredGtAndPredInfo {
  /** Ground-truth overlay. */
  gt: { canvas: OffscreenCanvas; dataUrl: string };
  /** Prediction overlay. */
  prediction: { canvas: OffscreenCanvas; dataUrl: string };
}

/** Filtered prediction overlays of two compared models, each as url + canvas. */
interface FilteredComparisonGtAndPredInfo {
  /** Overlay of the baseline model. */
  baseline: { canvas: OffscreenCanvas; dataUrl: string };
  /** Overlay of the candidate model. */
  candidate: { canvas: OffscreenCanvas; dataUrl: string };
}

// A single shared worker does the CPU-heavy prediction parsing off the main thread, and the LRU
// caches below memoize its results to improve prediction parsing performance.
const worker = new SegPredictionParserWorker();
// Incremented for every posted request so each listener can pick out its own response.
let globalMessageId = 0;

/** Options shared by every parsing cache: at most 200 entries, `maxAge` of 15 minutes. */
const PARSE_CACHE_OPTIONS = {
  max: 200,
  maxAge: 15 * 60 * 1000,
};

const predictionDataUrlCache = new LRUCache<string, string>({ ...PARSE_CACHE_OPTIONS });
const offscreenCanvasCache = new LRUCache<string, OffscreenCanvas>({ ...PARSE_CACHE_OPTIONS });
const filteredGtAndPredInfoCache = new LRUCache<string, FilteredGtAndPredInfo>({
  ...PARSE_CACHE_OPTIONS,
});
const filteredComparisonGtAndPredInfoCache = new LRUCache<string, FilteredComparisonGtAndPredInfo>({
  ...PARSE_CACHE_OPTIONS,
});

/**
 * Parse a segmentation prediction image into a colored overlay that keeps only pixels whose
 * score is >= `scoreThreshold`. For each pixel in the prediction image:
 *   the first channel is category (defect index), and
 *   the third channel is Math.floor(score * 255)
 *
 * The pixel work runs in the shared worker; results are memoized in module-level LRU caches keyed
 * by every input that affects the output.
 *
 * @param imgSrc prediction image source url
 * @param scoreThreshold 0 ~ 1, only keep segmentation result with score >= scoreThreshold
 * @param defects defects used for colors and index -> id mapping; defaults to the defects from
 *   `useDefectSelectorWithArchived()`
 * @param opacity 0 ~ 1, the alpha value of the segmentation image
 * @param areaMaskSrc source url for a mask where each pixel encodes the corresponding connected component area as base-24
 * @param defectIdToAreaThreshold mapping from defect ID to the minimum allowed component area
 * @param filterOptions when set, this hook does no parsing and clears both outputs
 * @returns `[objectUrl, canvas]` of the overlay. Both start undefined; while a new result is being
 *   computed the previous values are kept.
 */
export const useSegmentationPredictionDataUrl = (
  imgSrc: string,
  scoreThreshold: number,
  defects?: Defect[],
  opacity: number = 0.3,
  areaMaskSrc?: string,
  defectIdToAreaThreshold?: Record<DefectId, number>,
  filterOptions?: FilterOptions,
): [string | undefined, OffscreenCanvas | undefined] => {
  const image = useLoadImage(imgSrc || '', 'use-credentials') || undefined;
  const { width, height } = image ?? { width: 0, height: 0 };
  const allDefects = useDefectSelectorWithArchived();
  const finalDefects = defects ?? allDefects;

  const [areaMaskImage] = useImage(areaMaskSrc || '', 'use-credentials');
  const { width: areaMaskWidth, height: areaMaskHeight } = areaMaskImage ?? { width: 0, height: 0 };

  const [dataurl, setDataUrl] = useState<string>();
  const [canvas, setCanvas] = useState<OffscreenCanvas>();

  useDebouncedEffect(
    () => {
      if (!image || !!filterOptions) {
        setDataUrl(undefined);
        setCanvas(undefined);
        return undefined;
      }

      // When imgSrc changes, `image` might still be the loaded image of the previous imgSrc.
      // If we don't skip when they are unmatched, we will wrongly cache the new key with
      // old prediction parsed from old imgSrc.
      // If imgSrc is a local path, the loaded image's src has the hostname prepended, so a
      // substring match is accepted as well.
      if (imgSrc !== image.src && !image.src.includes(imgSrc)) {
        return;
      }

      const cacheKey = JSON.stringify({
        imgSrc,
        scoreThreshold,
        finalDefects,
        opacity,
        areaMaskSrc,
        defectIdToAreaThreshold,
      });
      // Only short-circuit when BOTH the canvas and its data url are cached. Returning on a
      // data-url hit alone would leave `canvas` holding whatever an earlier input produced.
      const cachedCanvas = offscreenCanvasCache.get(cacheKey);
      if (cachedCanvas) {
        setCanvas(cachedCanvas);
        const cachedDataUrl = predictionDataUrlCache.get(cacheKey);
        if (cachedDataUrl) {
          setDataUrl(cachedDataUrl);
          return;
        }
      }

      // prepare data for parsing predictions
      const predictionArray = imageToImageData(image, { width, height });
      const areaMaskArray = imageToImageData(areaMaskImage, {
        width: areaMaskWidth,
        height: areaMaskHeight,
      });
      const offscreenAboveThreshold = new OffscreenCanvas(width, height);
      const contextAboveThreshold = offscreenAboveThreshold.getContext('2d', {
        desynchronized: true,
      })!;
      contextAboveThreshold.imageSmoothingEnabled = false;
      const dataAboveThreshold = contextAboveThreshold.getImageData(0, 0, width, height);
      const defectIndexToColor = finalDefects.reduce(
        (obj: Record<number, { r: number; g: number; b: number }>, defect) => {
          const color = getDefectColor(defect);
          const rgb = hexToRgb(color);
          if (defect.indexId == undefined) {
            return obj;
          }
          obj[defect.indexId] = rgb;
          return obj;
        },
        {},
      );
      const defectIndexToDefectId: Record<number, DefectId> = {};
      for (const defect of finalDefects) {
        defectIndexToDefectId[defect.indexId!] = defect.id;
      }

      // Set by the cleanup once this run is superseded, so late async results are not applied.
      let isStale = false;

      // Send to worker for CPU-heavy calculation; the buffers are transferred, not copied.
      const messageId = globalMessageId++;
      worker.postMessage(
        {
          messageId,
          type: 'generate-prediction-url',
          body: {
            areaMaskArray,
            dataAboveThreshold,
            defectIdToAreaThreshold,
            defectIndexToColor,
            defectIndexToDefectId,
            opacity,
            predictionArray,
            scoreThreshold,
          },
        },
        [areaMaskArray.buffer, dataAboveThreshold.data.buffer, predictionArray.buffer],
      );
      // Receive worker results
      const onMessage = (e: any) => {
        if (e.data.messageId !== messageId) {
          return;
        }
        const parsed = e.data.dataAboveThreshold as ImageData;
        if (!parsed) {
          return;
        }
        contextAboveThreshold.putImageData(parsed, 0, 0);
        setCanvas(offscreenAboveThreshold);
        offscreenCanvasCache.set(cacheKey, offscreenAboveThreshold);
        offscreenAboveThreshold.convertToBlob().then(blob => {
          const dataUrl = URL.createObjectURL(blob);
          // Still valid for this key even if the effect has moved on, so always cache it...
          predictionDataUrlCache.set(cacheKey, dataUrl);
          // ...but never let an older result overwrite the state of a newer run.
          if (!isStale) {
            setDataUrl(dataUrl);
          }
        });
      };
      worker.addEventListener('message', onMessage);

      return () => {
        isStale = true;
        worker.removeEventListener('message', onMessage);
      };
    },
    [
      finalDefects,
      opacity,
      filterOptions,
      height,
      image,
      scoreThreshold,
      width,
      defectIdToAreaThreshold,
      imgSrc,
      defects,
      areaMaskSrc,
      areaMaskImage,
      areaMaskWidth,
      areaMaskHeight,
    ],
    64, // debounced time
  );
  return [dataurl, canvas];
};

/**
 * Build the ground-truth lookup data the model-analysis worker consumes for one media.
 *
 * - `annotationsMap`: row-major `width * height` array holding, for each pixel covered by a
 *   segmentation annotation, that annotation's defect `indexId` (0 elsewhere; later annotations
 *   overwrite earlier ones). It is `[]` when defects, filter options or annotations are missing,
 *   or while the image size is still 0.
 * - `defectIndexToColor`: defect `indexId` -> RGB color.
 * - `defectIndexToDefectId`: defect `indexId` -> defect id.
 *
 * @throws Error when a segmentation annotation lacks a range box or bitmap, or its defect has no
 *   index id
 */
const useSegDefectMap = (
  width: number,
  height: number,
  mediaDetails?: MediaDetails,
  allDefects?: Defect[],
  filterOptions?: FilterOptions,
) => {
  const annotations = mediaDetails?.label?.annotations;
  // The map depends only on whether filters are active, not on their contents. Keying the memo
  // on this boolean rebuilds the map when filters are switched on, without recomputing it on
  // every filter edit.
  const hasFilterOptions = !!filterOptions;

  const annotationsMap = useMemo(() => {
    if (!allDefects || !hasFilterOptions || !annotations?.length || width <= 0 || height <= 0) {
      return [];
    }
    const defectMap: number[] = new Array(width * height).fill(0);
    annotations.forEach(annotation => {
      if (annotation.annotationType !== AnnotationType.segmentation) {
        return;
      }
      const { rangeBox, defectId, segmentationBitmapEncoded } = annotation;
      if (!rangeBox) {
        throw new Error('no rangeBox for segmentation annotation');
      }
      if (!segmentationBitmapEncoded) {
        throw new Error('no bitmap for segmentation annotation');
      }
      const { xmin, ymin, xmax, ymax } = rangeBox;
      const bitMapDecoded = runLengthDecode(segmentationBitmapEncoded);
      const defectIndexId = allDefects.find(defect => defect.id === defectId)?.indexId;
      if (!defectIndexId) {
        throw new Error('cannot find defect index id');
      }
      // The decoded bitmap covers the inclusive range box, row by row.
      const rangeWidth = xmax - xmin + 1;
      const rangeHeight = ymax - ymin + 1;
      for (let y = 0; y < rangeHeight; y++) {
        const imageY = ymin + y;
        if (imageY < 0 || imageY >= height) {
          continue;
        }
        for (let x = 0; x < rangeWidth; x++) {
          const imageX = xmin + x;
          // Pixels outside the image would otherwise wrap into the next row or grow the array.
          if (imageX < 0 || imageX >= width) {
            continue;
          }
          if (bitMapDecoded[y * rangeWidth + x] === '1') {
            defectMap[imageY * width + imageX] = defectIndexId;
          }
        }
      }
    });
    return defectMap;
  }, [width, height, allDefects, annotations, hasFilterOptions]);

  const defectIndexToColor = useMemo(
    () =>
      allDefects?.reduce((obj: Record<number, { r: number; g: number; b: number }>, defect) => {
        const color = getDefectColor(defect);
        const rgb = hexToRgb(color);
        if (defect.indexId == undefined) {
          return obj;
        }
        obj[defect.indexId] = rgb;
        return obj;
      }, {}),
    [allDefects],
  );

  const defectIndexToDefectId = useMemo(() => {
    const res: Record<number, DefectId> = {};
    allDefects?.forEach(defect => (res[defect.indexId!] = defect.id));
    return res;
  }, [allDefects]);

  return {
    annotationsMap,
    defectIndexToColor,
    defectIndexToDefectId,
  };
};

/**
 * Render the filtered ground-truth and prediction overlays of one media for a single model.
 * The pixel work runs in the shared worker and results are memoized in
 * `filteredGtAndPredInfoCache`.
 *
 * @param mediaId media to render
 * @param modelId model whose prediction (`predictionLabel.segImgPath`) is rendered
 * @param scoreThreshold forwarded to the worker
 * @param versionId optional dataset version used for media details and defects
 * @param filterOptions forwarded to the worker; nothing is produced without them
 * @returns undefined while the image, prediction mask, filter options or defects are missing;
 *   otherwise the latest worker result (kept until a newer one arrives)
 */
export const useSegmentationInfoWithFilterOptions = (
  mediaId: number,
  modelId?: string,
  scoreThreshold?: number,
  versionId?: number,
  filterOptions?: FilterOptions,
): FilteredGtAndPredInfo | undefined => {
  const { datasetId } = useGetSelectedProjectQuery().data ?? {};
  const allDefects = useCurrentDefectsWithArchivedAndIndexFromVersion(versionId);
  const { data: mediaDetails } = useDatasetMediaDetailsQuery({
    datasetId,
    mediaId,
    modelId,
    ...(versionId && { versionId }),
  });
  const imgPath = mediaDetails?.url || '';
  const segImgPath = mediaDetails?.predictionLabel?.segImgPath || '';
  const image = useLoadImage(imgPath, 'use-credentials') || undefined;
  const segImage = useLoadImage(segImgPath, 'use-credentials') || undefined;
  const width = image?.width ?? 0;
  const height = image?.height ?? 0;

  const [filteredInfo, setFilteredInfo] = useState<FilteredGtAndPredInfo>();

  const { annotationsMap, defectIndexToColor, defectIndexToDefectId } = useSegDefectMap(
    width,
    height,
    mediaDetails,
    allDefects,
    filterOptions,
  );

  useDebouncedEffect(
    () => {
      if (!image || !segImage || !filterOptions || !allDefects) {
        setFilteredInfo(undefined);
        return;
      }

      // Reuse a previous render for the exact same inputs.
      const cacheKey = JSON.stringify({
        imgPath,
        segImgPath,
        allDefects,
        filterOptions,
        annotationsMap,
        scoreThreshold,
      });
      if (filteredGtAndPredInfoCache.has(cacheKey)) {
        setFilteredInfo(filteredGtAndPredInfoCache.get(cacheKey));
        return;
      }

      // Both rasters use the base image size.
      const dimensions = { width, height };
      const imageArray = imageToImageData(image, dimensions);
      const segMaskArray = imageToImageData(segImage, dimensions);

      // Hand the pixels to the worker; the two buffers are transferred rather than copied.
      const messageId = globalMessageId++;
      const body = {
        width,
        height,
        imageArray,
        segMaskArray,
        annotationsMap,
        defectIndexToColor,
        defectIndexToDefectId,
        scoreThreshold,
        filterOptions,
      };
      worker.postMessage({ messageId, type: 'generate-model-analysis-image-diff-view', body }, [
        imageArray.buffer,
        segMaskArray.buffer,
      ]);

      const handleWorkerMessage = (event: any) => {
        if (event.data.messageId !== messageId) {
          return;
        }
        const { info } = event.data;
        if (!info) {
          return;
        }
        const gtCanvas = imageDataToOffscreenCanvas(info.gt.dataAboveThreshold);
        const predictionCanvas = imageDataToOffscreenCanvas(info.prediction.dataAboveThreshold);
        const nextInfo = {
          gt: { dataUrl: info.gt.dataUrl, canvas: gtCanvas },
          prediction: { dataUrl: info.prediction.dataUrl, canvas: predictionCanvas },
        };
        setFilteredInfo(nextInfo);
        filteredGtAndPredInfoCache.set(cacheKey, nextInfo);
      };
      worker.addEventListener('message', handleWorkerMessage);

      return () => worker.removeEventListener('message', handleWorkerMessage);
    },
    [
      height,
      image,
      scoreThreshold,
      width,
      imgPath,
      segImgPath,
      segImage,
      filterOptions,
      allDefects,
      annotationsMap,
      defectIndexToColor,
      defectIndexToDefectId,
    ],
    64, // debounce delay
  );
  return filteredInfo;
};

/**
 * Render the filtered prediction overlays of a baseline and a candidate model for the same media.
 * The pixel work runs in the shared worker and results are memoized in
 * `filteredComparisonGtAndPredInfoCache`.
 *
 * @param mediaId media to render
 * @param baselineModelId model whose prediction is the baseline
 * @param baselineThreshold score threshold for the baseline, forwarded to the worker
 * @param candidateModelId model whose prediction is the candidate
 * @param candidateThreshold score threshold for the candidate, forwarded to the worker
 * @param versionId optional dataset version used for media details and defects
 * @param filterOptions forwarded to the worker; nothing is produced without them
 * @returns undefined while the image, either prediction mask, filter options or defects are
 *   missing; otherwise the latest worker result (kept until a newer one arrives)
 */
export const useComparisonSegmentationInfoWithFilters = (
  mediaId: number,
  baselineModelId?: string,
  baselineThreshold?: number,
  candidateModelId?: string,
  candidateThreshold?: number,
  versionId?: number,
  filterOptions?: FilterOptions,
): FilteredComparisonGtAndPredInfo | undefined => {
  const { datasetId } = useGetSelectedProjectQuery().data ?? {};
  const allDefects = useCurrentDefectsWithArchivedAndIndexFromVersion(versionId);
  const { data: baselineMediaDetails } = useDatasetMediaDetailsQuery({
    datasetId,
    mediaId,
    modelId: baselineModelId,
    ...(versionId && { versionId }),
  });
  const { data: candidateMediaDetails } = useDatasetMediaDetailsQuery({
    datasetId,
    mediaId,
    modelId: candidateModelId,
    ...(versionId && { versionId }),
  });
  const [imgPath, baselineSegImgPath, candidateSegImgPath] = [
    baselineMediaDetails?.url || '',
    baselineMediaDetails?.predictionLabel?.segImgPath || '',
    candidateMediaDetails?.predictionLabel?.segImgPath || '',
  ];
  const image = useLoadImage(imgPath, 'use-credentials') || undefined;
  const baselineSegImg = useLoadImage(baselineSegImgPath, 'use-credentials') || undefined;
  const candidateSegImg = useLoadImage(candidateSegImgPath, 'use-credentials') || undefined;
  const { width, height } = image ?? { width: 0, height: 0 };

  const [filteredComparisonInfo, setFilteredComparisonInfo] =
    useState<FilteredComparisonGtAndPredInfo>();

  // Ground-truth labels are the same for both models, so baselineMediaDetails alone is enough
  // to build the annotation map.
  const { annotationsMap, defectIndexToColor, defectIndexToDefectId } = useSegDefectMap(
    width,
    height,
    baselineMediaDetails,
    allDefects,
    filterOptions,
  );

  useDebouncedEffect(
    () => {
      if (!image || !baselineSegImg || !candidateSegImg || !filterOptions || !allDefects) {
        setFilteredComparisonInfo(undefined);
        return;
      }
      // Reuse a previous render for the exact same inputs. Defects are part of the key because
      // their colors feed the rendering.
      const cacheKey = JSON.stringify({
        imgPath,
        baselineSegImgPath,
        candidateSegImgPath,
        allDefects,
        filterOptions,
        annotationsMap,
        baselineThreshold,
        candidateThreshold,
      });
      if (filteredComparisonGtAndPredInfoCache.has(cacheKey)) {
        setFilteredComparisonInfo(filteredComparisonGtAndPredInfoCache.get(cacheKey));
        return;
      }
      // prepare data for parsing predictions
      const imageArray = imageToImageData(image, { width, height });
      const baselineSegMaskArray = imageToImageData(baselineSegImg, {
        width,
        height,
      });
      const candidateSegMaskArray = imageToImageData(candidateSegImg, {
        width,
        height,
      });

      // Send to worker for CPU-heavy calculation; the buffers are transferred, not copied.
      const messageId = globalMessageId++;
      worker.postMessage(
        {
          messageId,
          type: 'generate-model-comparison-image-diff-view',
          body: {
            width,
            height,
            imageArray,
            baselineSegMaskArray,
            candidateSegMaskArray,
            annotationsMap,
            defectIndexToColor,
            defectIndexToDefectId,
            baselineThreshold,
            candidateThreshold,
            filterOptions,
          },
        },
        [imageArray.buffer, baselineSegMaskArray.buffer, candidateSegMaskArray.buffer],
      );
      // Receive worker results
      const onMessage = (e: any) => {
        if (e.data.messageId === messageId) {
          const { info } = e.data;
          if (info) {
            const baselineCanvas = imageDataToOffscreenCanvas(info.baseline.dataAboveThreshold);
            const candidateCanvas = imageDataToOffscreenCanvas(info.candidate.dataAboveThreshold);
            const value: FilteredComparisonGtAndPredInfo = {
              baseline: {
                dataUrl: info.baseline.dataUrl,
                canvas: baselineCanvas,
              },
              candidate: {
                dataUrl: info.candidate.dataUrl,
                canvas: candidateCanvas,
              },
            };
            setFilteredComparisonInfo(value);
            // Store the converted value in the comparison cache, the one read above. Before this
            // fix the raw worker payload went into the single-model cache, so comparisons were
            // never served from cache.
            filteredComparisonGtAndPredInfoCache.set(cacheKey, value);
          }
        }
      };
      worker.addEventListener('message', onMessage);

      return () => {
        worker.removeEventListener('message', onMessage);
      };
    },
    [
      height,
      image,
      width,
      imgPath,
      baselineSegImg,
      baselineSegImgPath,
      candidateSegImg,
      candidateSegImgPath,
      baselineThreshold,
      candidateThreshold,
      filterOptions,
      allDefects,
      annotationsMap,
      defectIndexToColor,
      defectIndexToDefectId,
    ],
    64, // debounced time
  );
  return filteredComparisonInfo;
};

/**
 * Render a segmentation mask image onto an offscreen canvas. In the mask, the first channel of
 * each pixel is a defect index (0 = unlabeled). Every labeled pixel is painted in its defect's
 * color at `opacity`.
 *
 * @param imgSrc mask image url
 * @param defects defects used for colors; defaults to `useDefectSelectorWithArchived()`
 * @param opacity 0 ~ 1 alpha of labeled pixels
 * @returns the canvas, or undefined while the image is not loaded or does not match `imgSrc`,
 *   or when the mask has no labeled pixel
 */
export const useSegmentationMaskCanvas = (
  imgSrc: string,
  defects?: Defect[],
  opacity: number = 0.3,
) => {
  const [image] = useImage(imgSrc || '', 'use-credentials');
  const { width, height } = image ?? { width: 0, height: 0 };
  const allDefects = useDefectSelectorWithArchived();
  const finalDefects = defects ?? allDefects;
  const maskPixels = useMemo(
    () => imageToImageData(image, { width, height }),
    [image, width, height],
  );

  return useMemo(() => {
    // Ignore an image still loaded for a previous `imgSrc`.
    if (imgSrc !== image?.src || !image) {
      return undefined;
    }
    const maskCanvas = new OffscreenCanvas(width, height);
    const maskContext = maskCanvas.getContext('2d', {
      desynchronized: true,
    })!;
    maskContext.imageSmoothingEnabled = false;
    const maskData = maskContext.getImageData(0, 0, width, height);
    const defectIndexToColor = finalDefects.reduce(
      (obj: Record<number, { r: number; g: number; b: number }>, defect) => {
        const color = getDefectColor(defect);
        const rgb = hexToRgb(color);
        if (defect.indexId == undefined) {
          return obj;
        }
        obj[defect.indexId] = rgb;
        return obj;
      },
      {},
    );

    let hasLabeledPixel = false;
    for (let i = 0; i < maskPixels.length; i += 4) {
      const defectIndex = maskPixels[i];
      if (defectIndex > 0) {
        const { r, g, b } = defectIndexToColor[defectIndex] ?? { r: 0, g: 0, b: 0 };
        maskData.data[i] = r;
        maskData.data[i + 1] = g;
        maskData.data[i + 2] = b;
        maskData.data[i + 3] = 255 * opacity;
        hasLabeledPixel = true;
      } else {
        maskData.data[i + 3] = 0;
      }
    }

    if (!hasLabeledPixel) {
      return undefined;
    }
    maskContext.putImageData(maskData, 0, 0);
    return maskCanvas;
  }, [imgSrc, image, width, height, finalDefects, maskPixels, opacity]);
};

/**
 * Encode an offscreen canvas into an object URL (`URL.createObjectURL`), refreshed whenever the
 * canvas changes.
 *
 * An encoding that finishes after the canvas has changed (or become undefined), or after unmount,
 * is ignored so it cannot overwrite the newer state.
 *
 * @param canvas canvas to encode; when undefined the result is null
 * @returns the object URL, or null when there is no canvas, before the first encoding finishes,
 *   or when encoding failed
 */
export const useOffscreenCanvasToDataUrl = (canvas?: OffscreenCanvas) => {
  const [dataUrl, setDataUrl] = React.useState<string | null>(null);
  useEffect(() => {
    if (!canvas) {
      setDataUrl(null);
      return undefined;
    }
    let cancelled = false;
    canvas
      .convertToBlob()
      .then(blob => {
        if (!cancelled) {
          setDataUrl(URL.createObjectURL(blob));
        }
      })
      .catch(() => {
        if (!cancelled) {
          setDataUrl(null);
        }
      });
    return () => {
      cancelled = true;
    };
  }, [canvas]);
  return dataUrl;
};

/**
 * Split a segmentation prediction image into one offscreen canvas per predicted defect class.
 * Prediction pixel layout: channel 0 = defect index (0 = none), channel 2 = Math.floor(score * 255).
 * Pixels with score >= `scoreThreshold` are written in their defect color at `opacity`, and
 * `hasData` marks layers that received at least one such pixel.
 *
 * Pixels are only written into each layer's `imageData`; callers must `putImageData` onto
 * `context` themselves.
 *
 * @returns layers keyed by `indexId - 1`, or undefined while the image is not loaded or does not
 *   match `imgSrc`, or when `scoreThreshold` is undefined
 */
export const useSegmentationPredictionOffscreenCanvas = (
  imgSrc: string,
  allDefects: Defect[],
  scoreThreshold?: number,
  opacity: number = 0.6,
) => {
  const [predictionImage] = useImage(imgSrc || '', 'use-credentials');
  const width = predictionImage?.width ?? 0;
  const height = predictionImage?.height ?? 0;
  const pixels = useMemo(
    () => imageToImageData(predictionImage, { width, height }),
    [predictionImage, width, height],
  );

  return useMemo(() => {
    if (scoreThreshold === undefined || !predictionImage || predictionImage.src !== imgSrc) {
      return;
    }

    const colorByClass = {} as Record<number, { r: number; g: number; b: number }>;
    allDefects.forEach(defect => {
      colorByClass[defect.indexId! - 1] = hexToRgb(getDefectColor(defect));
    });

    const layers: Record<
      number,
      {
        canvas: OffscreenCanvas;
        context: OffscreenCanvasRenderingContext2D;
        imageData: ImageData;
        hasData?: boolean;
      }
    > = {};
    const minScore = Math.floor(scoreThreshold * 255);

    for (let offset = 0; offset < pixels.length; offset += 4) {
      const classIndex = pixels[offset] - 1;
      if (classIndex < 0) {
        continue;
      }
      // Lazily allocate a canvas the first time a class shows up.
      let layer = layers[classIndex];
      if (layer === undefined) {
        const canvas = new OffscreenCanvas(width, height);
        const context = canvas.getContext('2d', {
          desynchronized: true,
        })!;
        context.imageSmoothingEnabled = false;
        layer = { canvas, context, imageData: context.getImageData(0, 0, width, height) };
        layers[classIndex] = layer;
      }
      const { data } = layer.imageData;
      if (pixels[offset + 2] < minScore) {
        data[offset + 3] = 0;
        continue;
      }
      const { r, g, b } = colorByClass[classIndex] ?? { r: 0, g: 0, b: 0 };
      data[offset] = r;
      data[offset + 1] = g;
      data[offset + 2] = b;
      data[offset + 3] = 255 * opacity;
      layer.hasData = true;
    }
    return layers;
  }, [allDefects, height, predictionImage, imgSrc, opacity, pixels, scoreThreshold, width]);
};

/**
 * Convert a segmentation prediction image into pure-canvas labeling annotations: one annotation
 * per defect class with at least one pixel whose score is >= `scoreThreshold`.
 * For each pixel in the prediction image:
 *   the first channel is category (defect index), and
 *   the third channel is Math.floor(score * 255)
 * Defect ids and colors come from `useCurrentDefectsWithArchivedAndIndexFromVersion()`.
 *
 * @param imgSrc image source url
 * @param scoreThreshold 0 ~ 1, only keep segmentation result with score >= scoreThreshold;
 *   when undefined, nothing is produced
 * @param opacity 0 ~ 1, the alpha value of the annotations
 * @returns pure canvas labeling annotations, or undefined while the prediction is unavailable
 */
export const useSegmentationPredictionLabelingAnnotations = (
  imgSrc: string,
  scoreThreshold?: number,
  opacity: number = 0.6,
) => {
  const currentDefects = useCurrentDefectsWithArchivedAndIndexFromVersion();
  const offscreenMap = useSegmentationPredictionOffscreenCanvas(
    imgSrc,
    currentDefects,
    scoreThreshold,
    opacity,
  );

  return useMemo(() => {
    if (!offscreenMap) {
      return;
    }
    // offscreenMap is keyed by `indexId - 1`.
    const defectIndexToDefect = {} as Record<number, Defect>;
    currentDefects.forEach(defect => {
      defectIndexToDefect[defect.indexId! - 1] = defect;
    });
    // convert to labeling annotation
    return Object.entries(offscreenMap)
      .filter(([_, { hasData }]) => hasData)
      .map(([defectIndex, { canvas, context, imageData }], index) => {
        // Flush the computed pixels onto the canvas before handing it out.
        context.putImageData(imageData, 0, 0);
        const currentDefect = defectIndexToDefect[Number(defectIndex)];
        const defectId = currentDefect?.id ?? 0;
        return {
          id: (Date.now() + index).toString(32), // does not matter
          defectId,
          data: canvas,
          color: getDefectColor(currentDefect),
        } as PureCanvasLabelingAnnotation;
      });
  }, [offscreenMap, currentDefects]);
};

/**
 * Convert a segmentation mask image into bitmap labeling annotations via `canvasToBitMap`.
 *
 * @param imgSrc mask image url
 * @param defectMap defects used for colors and ids; defaults to `useDefectSelectorWithArchived()`
 * @param opacity 0 ~ 1 alpha of the rendered mask
 * @returns the annotations, or undefined while no mask canvas is available
 */
export const useSegmentationMaskLabelingAnnotations = (
  imgSrc: string,
  defectMap?: Defect[],
  opacity: number = 0.6,
) => {
  const storeDefects = useDefectSelectorWithArchived();
  const defectsToUse = defectMap ?? storeDefects;
  const maskCanvas = useSegmentationMaskCanvas(imgSrc, defectsToUse, opacity);

  return useMemo(
    () => (!maskCanvas ? undefined : canvasToBitMap(maskCanvas, defectsToUse)),
    [defectsToUse, maskCanvas],
  );
};

/**
 * Convert a segmentation prediction image into segmentation annotations: one per defect class
 * with at least one pixel whose score is >= `scoreThreshold`. Each annotation carries both the
 * compressed bitmap/range box and the backing canvas.
 * For each pixel in the prediction image:
 *   the first channel is category (defect index), and
 *   the third channel is Math.floor(score * 255)
 *
 * @param imgSrc image source url
 * @param allDefects defects of the project; each prediction class is matched to the defect whose
 *   `indexId` equals the class index, which supplies the annotation's id, color and name
 * @param scoreThreshold 0 ~ 1, only keep segmentation result with score >= scoreThreshold;
 *   when undefined, nothing is produced
 * @param opacity 0 ~ 1, the alpha value of the annotations
 * @returns the annotations, or undefined when the prediction is unavailable or yields none
 */
export const useSegmentationPredictionAnnotations = (
  imgSrc: string,
  allDefects: Defect[],
  scoreThreshold?: number,
  opacity: number = 0.6,
) => {
  const offscreenMap = useSegmentationPredictionOffscreenCanvas(
    imgSrc,
    allDefects,
    scoreThreshold,
    opacity,
  );

  return useMemo(() => {
    if (!offscreenMap) {
      return;
    }

    // offscreenMap is keyed by `indexId - 1`.
    const defectIndexToDefect = {} as Record<number, Defect>;
    allDefects.forEach(defect => {
      defectIndexToDefect[defect.indexId! - 1] = defect;
    });
    // convert to labeling annotation
    const res = Object.entries(offscreenMap)
      .filter(([_, { hasData }]) => hasData)
      .map(([defectIndex, { canvas, context, imageData }], index) => {
        // Flush the computed pixels onto the canvas before encoding it.
        context.putImageData(imageData, 0, 0);
        const defect = defectIndexToDefect[Number(defectIndex)];
        const defectId = defect?.id ?? 0;
        const ann = canvasToBitMap(canvas, [defect])[0];
        return ann
          ? ({
              key: (Date.now() + index).toString(32), // does not matter
              defectId,
              color: getDefectColor(defect),
              description: defect?.name || undefined,
              compressedBitMap: ann.data.bitMap,
              xMin: ann.data.rangeBox.xmin,
              yMin: ann.data.rangeBox.ymin,
              xMax: ann.data.rangeBox.xmax,
              yMax: ann.data.rangeBox.ymax,
              data: canvas,
            } as SegmentationAnnotation & PureCanvasLabelingAnnotation)
          : undefined;
      })
      .filter(Boolean);
    return isEmpty(res)
      ? undefined
      : (res as (SegmentationAnnotation & PureCanvasLabelingAnnotation)[]);
  }, [offscreenMap, allDefects]);
};

/**
 * Query the number of dataset media matching the data browser's applied filters. Column filters
 * are scoped to the current model. Pass `props` to use an explicit state instead of the
 * data browser context.
 */
export const useTotalMediaCount = (props?: DataBrowserState) => {
  const { state } = useDataBrowserState();
  const { appliedFilters } = props ?? state;
  const { id: modelId } = useCurrentProjectModelInfoQuery();
  const [columnFilterMap, metadataFilterMap] = useMemo(() => {
    const [columnFilters, metadataFilters] =
      appliedFilterMappingToFormattedFilterMapping(appliedFilters);
    return [getColumnFilterMapWithModelId(columnFilters, modelId), metadataFilters];
  }, [appliedFilters, modelId]);

  // No explicit (un)selection; presumably "unselect mode with nothing unselected" counts every
  // media that passes the filters.
  return useDatasetMediaCountQuery({
    selectOptions: {
      fieldFilterMap: metadataFilterMap,
      columnFilterMap,
      selectedMedia: [],
      unselectedMedia: [],
      isUnselectMode: true,
    },
  });
};

/**
 * Query the number of annotation instances (ground truth or prediction, per `instanceViewMode`)
 * matching the data browser's applied filters and the current prediction threshold.
 * Pass `props` to use an explicit state instead of the data browser context.
 */
export const useTotalAnnotationInstanceCount = (
  instanceViewMode: InstanceViewMode,
  props?: DataBrowserState,
) => {
  const { state } = useDataBrowserState();
  const { appliedFilters } = props ?? state;
  const { id: modelId } = useCurrentProjectModelInfoQuery();
  const [threshold] = useAtom(thresholdForPredictAtom);
  const [columnFilterMap, metadataFilterMap] = useMemo(() => {
    const [columnFilters, metadataFilters] =
      appliedFilterMappingToFormattedFilterMapping(appliedFilters);
    return [getColumnFilterMapWithModelId(columnFilters, modelId, true), metadataFilters];
  }, [appliedFilters, modelId]);

  // modelId is '' (empty string) when the project has no trained model and undefined while the
  // API has not returned yet. This flag is the query's second argument (presumably its
  // `enabled` switch).
  const isModelInfoResolved = modelId !== undefined;

  return useAnnotationInstancesCountQuery(
    {
      columnFilterMap,
      metadataFilterMap,
      instanceViewMode,
      threshold,
    },
    isModelInfoResolved,
  );
};

/**
 * Total shown by the data browser: the media count in `image` view mode, otherwise the annotation
 * instance count (predictions when `showPredictions`, else ground truth). Returns undefined while
 * the relevant count has not loaded.
 */
export const useTotalMediaOrInstanceCount = (props?: DataBrowserState) => {
  const { state } = useDataBrowserState();
  const { viewMode, showPredictions } = props ?? state;
  const { data: mediaCount } = useTotalMediaCount(props);

  const instanceViewMode: InstanceViewMode = showPredictions
    ? InstanceViewMode.Prediction
    : InstanceViewMode.GroundTruth;

  const { data: instanceCounts } = useTotalAnnotationInstanceCount(instanceViewMode, props);

  return useMemo(
    () => (viewMode === 'image' ? mediaCount : instanceCounts?.count),
    [viewMode, instanceCounts, mediaCount],
  );
};

/**
 * Fetch the current page of dataset media for the data browser, using its sort, pagination and
 * filters (column filters scoped to the current model, empty metadata filters dropped).
 * When called outside the context provider, the state must be passed as `props`.
 */
export const useDatasetMedias = (props?: DataBrowserState) => {
  const { state } = useDataBrowserState();
  const { appliedFilters, sortField, paginationLimit, pageIndex } = props ?? state;
  const { id: modelId } = useCurrentProjectModelInfoQuery();
  const [columnFilterMap, metadataFilterMap] = useMemo(() => {
    const [columnFilters, metadataFilters] =
      appliedFilterMappingToFormattedFilterMapping(appliedFilters);
    return [getColumnFilterMapWithModelId(columnFilters, modelId), metadataFilters];
  }, [appliedFilters, modelId]);

  // Keep a metadata filter only if at least one of its values is a non-empty array or truthy.
  const isActiveFilterValue = (subValue: unknown) =>
    Array.isArray(subValue) ? subValue.length > 0 : Boolean(subValue);
  const activeMetadataEntries = Object.entries(metadataFilterMap).filter(([, value]) =>
    Object.values(value).some(isActiveFilterValue),
  );
  const newMetadataFilterMap = Object.fromEntries(activeMetadataEntries);

  const offset = pageIndex * paginationLimit;
  return useDatasetMediasQuery({
    sortOptions: {
      ...sortField,
      offset,
      limit: paginationLimit,
    },
    columnFilterMap,
    metadataFilterMap: newMetadataFilterMap,
    includeMediaStatus: true, // required by task creation
  });
};

/**
 * Fetches one page of annotation instances matching the data browser's applied
 * filters.
 *
 * Predictions are listed when `showPredictions` is set, ground truth otherwise.
 * The query is only enabled in `'instance'` view mode. In prediction mode it also
 * waits for a truthy model id.
 *
 * @param props - data-browser state; pass it when calling outside the
 *   DataBrowserState context provider.
 * @returns the result of `useAnnotationInstancesQuery`.
 */
export const useAnnotationInstances = (props?: DataBrowserState) => {
  const { state } = useDataBrowserState();
  const [thresholdForPredict] = useAtom(thresholdForPredictAtom);
  const browserState = props ?? state;
  const { appliedFilters, sortField, paginationLimit, pageIndex, showPredictions, viewMode } =
    browserState;
  const { id: currentModelId } = useCurrentProjectModelInfoQuery();

  const [columnFilterMap, metadataFilterMap] = useMemo(() => {
    const [columnFilters, metadataFilters] =
      appliedFilterMappingToFormattedFilterMapping(appliedFilters);
    return [getColumnFilterMapWithModelId(columnFilters, currentModelId, true), metadataFilters];
  }, [appliedFilters, currentModelId]);

  const instanceViewMode = useMemo(
    () => (showPredictions ? InstanceViewMode.Prediction : InstanceViewMode.GroundTruth),
    [showPredictions],
  );

  // The GroundTruth view mode does not need a model id; Prediction and Mix modes must have one.
  const isQueryEnabled =
    viewMode === 'instance' &&
    (instanceViewMode === InstanceViewMode.GroundTruth || Boolean(currentModelId));

  const offset = pageIndex * paginationLimit;

  return useAnnotationInstancesQuery(
    {
      columnFilterMap,
      metadataFilterMap,
      mode: instanceViewMode,
      sortOptions: { ...sortField, offset, limit: paginationLimit },
      threshold: thresholdForPredict,
    },
    isQueryEnabled,
  );
};

/**
 * Flattens a nested classification confusion matrix into a list of pairings.
 *
 * Converts `{ [gtDefectId]: { [predDefectId]: count } }` into
 * `[{ gtDefectId, predDefectId, count }, ...]`. Both ids are converted with `Number`,
 * and cells whose count is not greater than zero are omitted. The output follows the
 * iteration order of the outer and then the inner object keys.
 */
export const classificationConfusionMatrixToParings = (
  confusionMatrix: ClassificationConfusionMatrix,
) =>
  Object.entries(confusionMatrix).flatMap(([gtDefectId, predictionCounts]) =>
    Object.entries(predictionCounts)
      .filter(([, count]) => count > 0)
      .map(([predDefectId, count]) => ({
        gtDefectId: Number(gtDefectId),
        predDefectId: Number(predDefectId),
        count,
      })),
  );

/**
 * Decides whether the upload-limit dialog should be shown.
 *
 * Returns true only when all of the following hold:
 * - at least one model exists;
 * - the selected project has an `orgId`;
 * - the user has not opted to skip the upload dialog for that org;
 * - the org plan explicitly sets `predictUponUploading` to `false`.
 */
export const useIsShowUploadLimitDialog = () => {
  const selectedProject = useGetSelectedProjectQuery().data;
  const orgId = selectedProject?.orgId;
  const [orgPlanSetting] = useSubscriptionPlanSettings('anystring');
  const { models } = useModels();

  if (!models?.length || !orgId) {
    return false;
  }
  if (includesSkipDialogById(SkipDialogKey.Upload, orgId)) {
    return false;
  }
  return orgPlanSetting?.predictUponUploading === false;
};

/**
 * Reports whether the media matching the data browser's applied filters carry any labels.
 *
 * Queries dataset stats grouped by `DatasetGroupOptions.DEFECT_DISTRIBUTION` for
 * every media matching the current filters, then sums the per-group `count`s.
 *
 * @returns `true` when the summed count is positive. Returns `false` while the stats
 *   are loading, when no stats are available, or when the sum is zero. Always a real
 *   boolean; previously it leaked the raw numeric sum, which renders as "0" in JSX
 *   `&&` expressions.
 */
export const useIsLabeled = () => {
  const { state } = useDataBrowserState();
  const { appliedFilters } = state;
  const { id: currentModelId } = useCurrentProjectModelInfoQuery();
  const [columnFilterMap, metadataFilterMap] = useMemo(() => {
    const [col, metadata] = appliedFilterMappingToFormattedFilterMapping(appliedFilters);
    const colWithModelId = getColumnFilterMapWithModelId(col, currentModelId);
    return [colWithModelId, metadata];
  }, [appliedFilters, currentModelId]);

  // Select every media matching the filters (unselect mode with nothing unselected).
  const selectedMediaOption = useMemo(
    () =>
      ({
        fieldFilterMap: metadataFilterMap,
        columnFilterMap,
        selectedMedia: [],
        unselectedMedia: [],
        isUnselectMode: true,
      } as SelectMediaOption),
    [columnFilterMap, metadataFilterMap],
  );
  const { data: distributionStats, isLoading: distributionStatsLoading } =
    useGetDatasetStatsQuery({
      selectOptions: selectedMediaOption,
      groupOptions: [DatasetGroupOptions.DEFECT_DISTRIBUTION],
    });

  return useMemo(() => {
    if (distributionStatsLoading || !distributionStats) return false;
    const totalCount = distributionStats.reduce((sum, { count }) => sum + count, 0);
    return totalCount > 0;
  }, [distributionStats, distributionStatsLoading]);
};

/**
 * Drops ground-truth annotations that duplicate an earlier ground-truth box.
 *
 * Only annotations whose `group` is `AnnotationSourceType.GroundTruth` are
 * de-duplicated; all others pass through untouched. Two ground-truth annotations
 * are duplicates when their `x`, `y`, `width`, `height` and `color` are all equal.
 * The first occurrence is kept. Annotations whose data lacks numeric box geometry
 * are never treated as duplicates. Without box geometry they would all share an
 * `undefined,...` key and collapse into one per color.
 *
 * @param annotations - canvas annotations to check.
 * @returns `duplicated`: whether at least one annotation was dropped.
 *   `filteredAnnotations`: the kept ground-truth annotations in input order,
 *   followed by all non-ground-truth annotations in input order.
 */
export const detectDuplicateAnnotations = (annotations: CanvasAnnotation[]) => {
  const gtAnnotations = annotations.filter(
    annotation => annotation.group === AnnotationSourceType.GroundTruth,
  );
  const annotationsWithoutGT = annotations.filter(
    annotation => annotation.group !== AnnotationSourceType.GroundTruth,
  );
  const seenKeys = new Set<string>();
  const filteredAnnotations: CanvasAnnotation[] = [];
  let hasDuplicate = false;

  for (const annotation of gtAnnotations) {
    const { x, y, width, height, color } = annotation.data as Partial<BoxAnnotation>;
    const hasBoxGeometry = [x, y, width, height].every(value => typeof value === 'number');

    // Non-box annotations cannot be compared by geometry; keep them as-is.
    if (!hasBoxGeometry) {
      filteredAnnotations.push(annotation);
      continue;
    }

    const key = `${x},${y},${width},${height},${color}`;
    if (seenKeys.has(key)) {
      hasDuplicate = true;
    } else {
      seenKeys.add(key);
      filteredAnnotations.push(annotation);
    }
  }

  return {
    duplicated: hasDuplicate,
    filteredAnnotations: [...filteredAnnotations, ...annotationsWithoutGT],
  };
};
