import {
  AnnotationInstance,
  BundleStatus,
  BundleStatusItem,
  ConfusionMatrixCountByMediaId,
  LabelType,
  ModelEvaluationReport,
  ModelEvaluationReportStatus,
  ModelMetricParam,
  ProjectId,
  RegisteredModelId,
} from '@clef/shared/types';
import {
  useQuery,
  useInfiniteQuery,
  InfiniteData,
  useQueries,
  UseQueryOptions,
} from '@tanstack/react-query';
import { useGetSelectedProjectQuery } from '../projects';
import model_analysis_api, {
  BatchModelMetricResponse,
  ConfusionMatrixResponse,
  FilterOptions,
  GetMediaListResponse,
  GetModelPerformanceSummaryResponse,
  ModelMetricsResponse,
  SortOrder,
} from '@/api/model_analysis_api';
import { queryClient } from '..';
import { isNumber } from 'lodash';
import { ApiError } from '@/utils/error';
import model_api, { RegisteredModelWithBundles } from '@/api/model_api';
import clp_api from '@/api/clp_api';
import { useAtom } from 'jotai';
import { batchModelMetricsParamAtom } from '@/pages/model_iteration/componentsV2/atoms';
import { EvaluationSetItem } from '@/api/evaluation_set_api';
import { useMemo } from 'react';

/**
 * Returns `[value]` when `value` is neither `null` nor `undefined`, and `[]` otherwise,
 * so an optional segment can be spread into a query key. A plain truthiness check
 * would drop meaningful falsy values such as a threshold or dataset version of `0`,
 * making those queries share a cache entry with the "not provided" case even though
 * the request sent to the API differs.
 */
function optionalKeyPart<T>(value: T | null | undefined): T[] {
  return value == null ? [] : [value];
}

/**
 * Query-key factory for the model-analysis hooks.
 *
 * Every key starts with the project id, then `'modelAnalysis'`, then a per-query tag
 * (e.g. `'metrics'`, `'paginatedMediaList'`), followed by the query's parameters.
 * NOTE(review): `paginatedMediaList` and `annotationInstancePairs` append optional
 * segments positionally, so two different argument combinations could in principle
 * produce the same key (e.g. a `version` vs. a `candidateThreshold` with the same
 * numeric value). Consider an object segment if that ever matters.
 */
export const modelAnalysisQueryKeys = {
  all: ['modelAnalysis'] as const,
  modelList: (projectId: ProjectId) =>
    [projectId, ...modelAnalysisQueryKeys.all, 'modelList'] as const,
  metrics: (projectId: ProjectId, modelId: string, evaluationSetId?: number, threshold?: number) =>
    [
      projectId,
      ...modelAnalysisQueryKeys.all,
      'metrics',
      modelId,
      { evaluationSetId, threshold },
    ] as const,
  batchMetrics: (
    projectId: ProjectId,
    modelList: ModelMetricParam[],
    modelMetricsMap?: Record<string, ModelMetricsResponse & ModelMetricParam>,
  ) =>
    [
      projectId,
      ...modelAnalysisQueryKeys.all,
      'batchMetrics',
      ...modelList,
      ...(modelMetricsMap ? [modelMetricsMap] : []),
    ] as const,
  mediaList: (
    projectId: ProjectId,
    modelId: string,
    evaluationSetId?: number,
    threshold?: number,
    candidateModelId?: string,
    candidateThreshold?: number,
    filterOptions?: FilterOptions,
    sortOrder?: SortOrder,
  ) => [
    projectId,
    ...modelAnalysisQueryKeys.all,
    'mediaList',
    modelId,
    { evaluationSetId, threshold },
    candidateModelId,
    candidateThreshold,
    filterOptions,
    sortOrder,
  ],
  paginatedMediaList: (
    projectId: ProjectId,
    modelId: string,
    evaluationSetId: number,
    labelType: LabelType,
    threshold?: number,
    candidateModelId?: string,
    candidateThreshold?: number,
    filterOptions?: FilterOptions,
    sortOrder?: SortOrder,
    limit?: number,
    version?: number,
  ) => [
    projectId,
    ...modelAnalysisQueryKeys.all,
    'paginatedMediaList',
    modelId,
    { evaluationSetId, threshold },
    labelType,
    // Nullish (not truthy) checks so that e.g. a candidate threshold or version of 0
    // gets its own cache entry.
    ...optionalKeyPart(version),
    ...optionalKeyPart(candidateModelId),
    ...optionalKeyPart(candidateThreshold),
    ...optionalKeyPart(filterOptions),
    ...optionalKeyPart(sortOrder),
    ...optionalKeyPart(limit),
  ],
  confusionMatrix: (
    projectId: ProjectId,
    modelId: string,
    evaluationSetId?: number,
    threshold?: number,
  ) =>
    [
      projectId,
      ...modelAnalysisQueryKeys.all,
      'confusionMatrix',
      modelId,
      { evaluationSetId, threshold },
    ] as const,
  confusionMatrixCounts: (
    projectId: ProjectId,
    modelId: string,
    evaluationSetId?: number,
    threshold?: number,
  ) =>
    [
      projectId,
      ...modelAnalysisQueryKeys.all,
      'confusionMatrixCounts',
      modelId,
      { evaluationSetId, threshold },
    ] as const,
  modelComparisonReports: (projectId: ProjectId) => [
    projectId,
    ...modelAnalysisQueryKeys.all,
    'modelComparisonReports',
  ],
  annotationInstancePairs: (
    projectId: ProjectId,
    modelId: string,
    mediaIds: number[],
    threshold?: number,
    version?: number,
  ) => [
    projectId,
    ...modelAnalysisQueryKeys.all,
    'annotationInstancePairs',
    { modelId },
    { mediaIds },
    // A threshold of 0 is sent to the API as-is, so it must not collapse into
    // the same key as "no threshold".
    ...optionalKeyPart(threshold),
    ...optionalKeyPart(version),
  ],
  bundleStatus: (projectId: ProjectId, modelId: RegisteredModelId, threshold: number) => [
    projectId,
    ...modelAnalysisQueryKeys.all,
    'bundleStatus',
    modelId,
    threshold,
  ],
  modelEvaluationReports: (projectId: ProjectId, modelId: string) => [
    projectId,
    ...modelAnalysisQueryKeys.all,
    'modelEvaluationReports',
    modelId,
  ],
  modelPerformanceSummary: (projectId: ProjectId, modelId: string) => [
    projectId,
    ...modelAnalysisQueryKeys.all,
    'modelPerformanceSummary',
    modelId,
  ],
};

/**
 * Whether a bundle has stopped changing, i.e. its status is one of
 * `'not_found'`, `'created'` or `'deployed'`. A missing status is not an end state.
 */
export const isBundleInEndState = (status: BundleStatus | undefined): boolean =>
  !!status && ['not_found', 'created', 'deployed'].includes(status);

/**
 * Whether a bundle can be deployed, i.e. its status is `'created'` or `'deployed'`.
 * A missing status is never deployable.
 */
export const isBundleInDeployableState = (status: BundleStatus | undefined): boolean =>
  !!status && ['created', 'deployed'].includes(status);

/**
 * Fetches the bundle status of `modelId` at `threshold` for the selected project.
 *
 * Polls every 2000 ms until the bundle reaches an end state (`isBundleInEndState`)
 * or the request errors. Disabled until a project is selected, `modelId` is truthy
 * and `threshold` is a number.
 */
export const useGetBundleStatusQuery = (modelId: RegisteredModelId, threshold: number) => {
  const { id: projectId = 0 } = useGetSelectedProjectQuery().data ?? {};
  return useQuery<BundleStatusItem, ApiError>({
    queryKey: modelAnalysisQueryKeys.bundleStatus(projectId, modelId, threshold),
    queryFn: async () => clp_api.getBundleStatus({ projectId, modelId, threshold }),
    refetchInterval: (data, query) => {
      // Polling stops on any error or once bundle creation has finished.
      const shouldStopPolling = Boolean(query.state.error) || isBundleInEndState(data?.status);
      return shouldStopPolling ? false : 2000;
    },
    enabled: Boolean(projectId) && Boolean(modelId) && isNumber(threshold),
  });
};

/**
 * Fetches the registered models (with their bundles) of the selected project.
 * Disabled until a project is selected.
 */
export const useGetModelListQuery = () => {
  const { id: projectId = 0 } = useGetSelectedProjectQuery().data ?? {};
  return useQuery<RegisteredModelWithBundles[], ApiError>({
    queryKey: modelAnalysisQueryKeys.modelList(projectId),
    queryFn: async () => {
      const { modelList } = await model_api.getModelList({ projectId });
      return modelList;
    },
    enabled: Boolean(projectId),
  });
};

/**
 * Fetches the metrics of `modelId` on `evaluationSetId` at `threshold`.
 *
 * The previously cached response's `updatedAt` is sent as `lastUpdatedAt`. The query
 * resolves to `undefined` if any parameter is missing. Polls every 3000 ms while the
 * response's `evaluationStatus` is `STARTED`, and stops on error.
 */
export const useGetModelMetricsQuery = (
  modelId?: string,
  evaluationSetId?: number,
  threshold?: number,
) => {
  const { id: projectId = 0 } = useGetSelectedProjectQuery().data ?? {};
  return useQuery<ModelMetricsResponse | undefined, ApiError>({
    queryKey: modelAnalysisQueryKeys.metrics(projectId, modelId!, evaluationSetId!, threshold!),
    queryFn: async ({ queryKey }) => {
      if (!modelId || !evaluationSetId || !isNumber(threshold)) return undefined;
      // NOTE(review): the cached value is a ModelMetricsResponse, but it is read with the
      // ModelEvaluationReport type here — confirm the types agree on `updatedAt`.
      const cached = queryClient.getQueryData<ModelEvaluationReport>(queryKey);
      return model_analysis_api.getModelMetrics({
        projectId,
        modelId,
        evaluationSetId,
        threshold,
        lastUpdatedAt: cached?.updatedAt,
      });
    },
    enabled: !!projectId && !!modelId && !!evaluationSetId && isNumber(threshold),
    refetchInterval: (data, query) =>
      !query.state.error && data?.evaluationStatus === ModelEvaluationReportStatus.STARTED
        ? 3000
        : false,
  });
};

/**
 * Runs one batch-metrics query per entry of `batchModelMetricsParamAtom`.
 *
 * Each query is enabled only when a project is selected and its parameter list is
 * non-empty. It polls every 3000 ms while any returned metric has evaluation status
 * `STARTED`, and stops on error.
 */
export const useGetBatchModelMetricsQueries = () => {
  const { id: projectId = 0 } = useGetSelectedProjectQuery().data ?? {};
  const [batchModelMetricsParam] = useAtom(batchModelMetricsParamAtom);
  return useQueries<UseQueryOptions<BatchModelMetricResponse, ApiError>[]>({
    queries: batchModelMetricsParam.map(bundleParams => ({
      queryKey: modelAnalysisQueryKeys.batchMetrics(projectId, bundleParams),
      queryFn: async () => {
        const { data } = await model_analysis_api.getBatchModelMetrics(projectId, bundleParams);
        return data;
      },
      enabled: !!projectId && bundleParams.length > 0,
      refetchInterval: (data, query) => {
        if (query.state.error) return false;
        const anyEvaluationRunning =
          data?.some(metrics => metrics.evaluationStatus === ModelEvaluationReportStatus.STARTED) ??
          false;
        return anyEvaluationRunning ? 3000 : false;
      },
    })),
  });
};

/**
 * Returns the batch-metrics query result whose parameter group contains an entry
 * with this `modelId` and `threshold`. Returns `undefined` when no group matches.
 */
export const useGetBatchModelMetricsQueryByModelId = (modelId: string, threshold: number) => {
  const queries = useGetBatchModelMetricsQueries();
  const [batchModelMetricsParam] = useAtom(batchModelMetricsParamAtom);
  return useMemo(() => {
    // Query results are positionally aligned with the atom's parameter groups.
    const groupIndex = batchModelMetricsParam.findIndex(group =>
      group.some(({ modelId: id, threshold: t }) => id === modelId && t === threshold),
    );
    if (groupIndex < 0) return undefined;
    return queries[groupIndex];
  }, [queries, modelId, threshold, batchModelMetricsParam]);
};

/**
 * Infinite query over the media list of `modelId` on `evaluationSet` at `threshold`,
 * optionally compared against `candidateModelId` at `candidateThreshold`.
 *
 * For bounding-box and segmentation projects the list is fetched in pages of `limit`
 * items. The next page's offset is the number of media loaded so far, and paging ends
 * when a page has fewer than `limit` items. Other label types get a single page with
 * no offset or limit.
 *
 * For bounding-box projects each page also carries the annotation instance pairs of
 * its media for the model (`allInstances`). When a candidate model and numeric
 * candidate threshold are given, it also carries them for the candidate
 * (`candidateAllInstances`).
 *
 * Polls every 3000 ms while the first page's `evaluationStatus` is `STARTED`, and
 * stops on error. Disabled until a project is selected, `modelId` and an evaluation
 * set id are present, and `threshold` is a number.
 */
export const useGetModelMediaListInfiniteQuery = (
  modelId?: string,
  threshold?: number,
  evaluationSet?: EvaluationSetItem | null,
  filterOptions?: FilterOptions,
  candidateModelId?: string,
  candidateThreshold?: number,
  sortOrder: SortOrder = SortOrder.ASC,
  limit = 50,
) => {
  const { id: projectId = 0, labelType } = useGetSelectedProjectQuery().data ?? {};
  // Only these label types are fetched with offset/limit pagination.
  const isPaginated =
    labelType === LabelType.BoundingBox || labelType === LabelType.Segmentation;
  const datasetVersion = evaluationSet?.datasetVersion.version;
  return useInfiniteQuery<
    | (GetMediaListResponse & {
        allInstances: AnnotationInstance[];
        candidateAllInstances: AnnotationInstance[];
      })
    | undefined,
    ApiError
  >({
    queryKey: modelAnalysisQueryKeys.paginatedMediaList(
      projectId,
      modelId!,
      evaluationSet?.id!,
      labelType!,
      threshold!,
      candidateModelId,
      candidateThreshold,
      filterOptions,
      sortOrder,
      limit,
      datasetVersion,
    ),
    queryFn: async ({ pageParam = 0, queryKey }) => {
      const lastResponse = queryClient.getQueryData<InfiniteData<GetMediaListResponse>>(queryKey);
      const response = await model_analysis_api.getMediaList({
        projectId,
        modelId: modelId!,
        candidateModelId,
        candidateThreshold,
        evaluationSetId: evaluationSet?.id!,
        threshold: threshold!,
        ...(filterOptions && { filterOptions }),
        sortOrder,
        // `pages` may be empty in the cache; optional-chain instead of throwing.
        lastUpdatedAt: lastResponse?.pages[0]?.updatedAt,
        offset: isPaginated ? pageParam : undefined,
        limit: isPaginated ? limit : undefined,
      });

      const isBoundingBox = labelType === LabelType.BoundingBox;
      const mediaIds = response.mediaList.map(m => m.id);
      const fetchInstancePairs = async (pairModelId: string, pairThreshold?: number) =>
        (
          await model_analysis_api.getAnnotationInstancePairsByIds({
            projectId,
            modelId: pairModelId,
            mediaIds,
            threshold: pairThreshold,
            version: datasetVersion,
          })
        ).data;

      // The model and candidate lookups are independent, so run them concurrently.
      // isNumber (not truthiness) so that a candidate threshold of 0 is honored.
      const [allInstances, candidateAllInstances] = await Promise.all([
        isBoundingBox ? fetchInstancePairs(modelId!, threshold) : [],
        isBoundingBox && candidateModelId && isNumber(candidateThreshold)
          ? fetchInstancePairs(candidateModelId, candidateThreshold)
          : [],
      ]);

      return {
        ...response,
        allInstances,
        candidateAllInstances,
      };
    },
    getNextPageParam: (lastPage, pages) => {
      if (!isPaginated || !lastPage || lastPage.mediaList.length < limit) return undefined;
      // Offset of the next page = number of media already loaded.
      return pages.flatMap(page => page?.mediaList ?? []).length;
    },
    refetchInterval: (data, query) => {
      if (query.state.error) {
        return false;
      }

      const { evaluationStatus } = data?.pages[0] ?? {};
      if (evaluationStatus === ModelEvaluationReportStatus.STARTED) {
        return 3000;
      }
      return false;
    },
    enabled: !!projectId && !!modelId && !!evaluationSet?.id && isNumber(threshold),
  });
};

/**
 * Fetches the confusion matrix of `modelId` on `evaluationSetId` at `threshold`.
 *
 * The cached response's `updatedAt` is sent as `lastUpdatedAt`. The query resolves to
 * `undefined` if any parameter is missing. Polls every 3000 ms while
 * `evaluationStatus` is `STARTED`, and stops on error.
 */
export const useGetConfusionMatrixQuery = (
  modelId?: string,
  evaluationSetId?: number,
  threshold?: number,
) => {
  const { id: projectId = 0 } = useGetSelectedProjectQuery().data ?? {};
  return useQuery<ConfusionMatrixResponse | undefined, ApiError>({
    queryKey: modelAnalysisQueryKeys.confusionMatrix(
      projectId,
      modelId!,
      evaluationSetId!,
      threshold!,
    ),
    queryFn: async ({ queryKey }) => {
      if (!modelId || !evaluationSetId || !isNumber(threshold)) return undefined;
      const cached = queryClient.getQueryData<ConfusionMatrixResponse>(queryKey);
      return model_analysis_api.getConfusionMatrix({
        projectId,
        modelId,
        evaluationSetId,
        threshold,
        lastUpdatedAt: cached?.updatedAt,
      });
    },
    enabled: !!projectId && !!modelId && !!evaluationSetId && isNumber(threshold),
    refetchInterval: (data, query) =>
      !query.state.error && data?.evaluationStatus === ModelEvaluationReportStatus.STARTED
        ? 3000
        : false,
  });
};

/**
 * Fetches the per-media confusion-matrix counts of `modelId` on `evaluationSetId` at
 * `threshold`. Resolves to `undefined` when any parameter is missing.
 */
export const useGetConfusionMatrixCountsQuery = (
  modelId?: string,
  evaluationSetId?: number,
  threshold?: number,
) => {
  // TODO: @Louie change to useInfiniteQuery or do pagination
  const { id: projectId = 0 } = useGetSelectedProjectQuery().data ?? {};
  const isReady = !!projectId && !!modelId && !!evaluationSetId && isNumber(threshold);
  return useQuery<ConfusionMatrixCountByMediaId | undefined, ApiError>({
    queryKey: modelAnalysisQueryKeys.confusionMatrixCounts(
      projectId,
      modelId!,
      evaluationSetId!,
      threshold!,
    ),
    queryFn: async () => {
      if (!modelId || !evaluationSetId || !isNumber(threshold)) return undefined;
      return model_analysis_api.getConfusionMatrixCounts({
        projectId,
        modelId,
        evaluationSetId,
        threshold,
      });
    },
    enabled: isReady,
  });
};

/**
 * Fetches the annotation instance pairs of `params.mediaIds` for `params.modelId`,
 * optionally at a given threshold and dataset version.
 *
 * @param params - model id, media ids, and optional threshold/version.
 * @param isEnabled - extra caller-controlled switch; defaults to `true`.
 *
 * The query runs only when `isEnabled` is true, a project is selected, a model id is
 * given, and at least one media id is requested.
 */
export const useMediasAnnotationInstancePairsQuery = (
  params: {
    modelId?: string;
    mediaIds: number[];
    threshold?: number;
    version?: number;
  },
  isEnabled = true,
) => {
  const { id: projectId = 0 } = useGetSelectedProjectQuery().data ?? {};
  const modelId = params.modelId ?? '';
  return useQuery({
    queryKey: modelAnalysisQueryKeys.annotationInstancePairs(
      projectId,
      modelId,
      params.mediaIds,
      params.threshold,
      params.version,
    ),
    queryFn: async () => {
      const res = await model_analysis_api.getAnnotationInstancePairsByIds({
        projectId,
        modelId,
        mediaIds: params.mediaIds,
        threshold: params.threshold,
        version: params.version,
      });
      return res.data;
    },
    // Like the other hooks in this module, wait for a selected project; otherwise the
    // request would be sent with the placeholder projectId 0.
    enabled: isEnabled && !!projectId && !!params.modelId && params.mediaIds.length > 0,
  });
};

/**
 * Fetches the evaluation reports of `modelId` in the selected project.
 *
 * Polls every 3000 ms while any report has status `STARTED`, and stops on error.
 * Disabled until both a project and a model id are available.
 */
export const useGetModelEvaluationReportsQuery = (modelId?: string) => {
  const { id: projectId = 0 } = useGetSelectedProjectQuery().data ?? {};
  return useQuery<ModelEvaluationReport[], ApiError>({
    queryKey: modelAnalysisQueryKeys.modelEvaluationReports(projectId, modelId!),
    queryFn: async () => model_analysis_api.getModelEvaluationReports(projectId, modelId!),
    enabled: !!projectId && !!modelId,
    refetchInterval: (data, query) => {
      if (query.state.error) return false;
      return data?.some(report => report.status === ModelEvaluationReportStatus.STARTED)
        ? 3000
        : false;
    },
  });
};

/**
 * Fetches the performance summary of `modelId` in the selected project.
 *
 * Enabled as soon as a project is selected. Without a model id it resolves to
 * `undefined` instead of calling the API.
 */
export const useGetModelPerformanceSummaryQuery = (modelId?: string) => {
  const { id: projectId = 0 } = useGetSelectedProjectQuery().data ?? {};
  return useQuery<GetModelPerformanceSummaryResponse | undefined, ApiError>({
    queryKey: modelAnalysisQueryKeys.modelPerformanceSummary(projectId, modelId!),
    queryFn: async () =>
      modelId ? model_analysis_api.getModelPerformanceSummary(projectId, modelId) : undefined,
    enabled: !!projectId,
  });
};
