import { BaseAPI } from './base_api';
import {
  AnnotationInstance,
  ApiResponse,
  Media,
  ModelEvaluationReport,
  ModelEvaluationReportStatus,
  ModelMetrics,
  ProjectId,
  MediaId,
  ConfusionMatrixCountByMediaId,
  SplitConfusionMatrices,
  RegisteredModelId,
  ModelEvaluationReportPerformance,
} from '@clef/shared/types';

/**
 * Evaluation status and metrics reported for a model on an evaluation set.
 *
 * Both fields are optional. NOTE(review): presumably `metrics` is absent until
 * the evaluation has produced results — confirm against the backend.
 */
export type ModelMetricsResponse = {
  evaluationStatus?: ModelEvaluationReportStatus;
  metrics?: ModelMetrics;
};

/** Payload resolved by `ModelAnalysisAPI.getMediaList`. */
export type GetMediaListResponse = {
  // NOTE(review): presumably the total number of matching media, independent of
  // the `offset`/`limit` window requested — confirm.
  total: number;
  // Returned media entries. `count` is an optional per-media number whose
  // meaning is not visible here (presumably a count for the active filter).
  mediaList: Array<Media & { count?: number }>;
  evaluationStatus?: ModelEvaluationReportStatus;
  // NOTE(review): looks like the value callers pass back as the request's
  // `lastUpdatedAt` — confirm.
  updatedAt?: string | null;
};

/** Payload resolved by `ModelAnalysisAPI.getConfusionMatrix`. */
export type ConfusionMatrixResponse = {
  // May be omitted or null. NOTE(review): presumably only populated once
  // `status` is 'completed' — confirm.
  splitConfusionMatrices?: SplitConfusionMatrices | null;
  // Computation state as reported by the backend.
  status?: 'started' | 'in_progress' | 'completed';
};

/** Payload resolved by `ModelAnalysisAPI.getModelEvaluationReports`: a list of reports. */
export type GetModelEvaluationReportsResponse = ModelEvaluationReport[];

/**
 * Payload resolved by `ModelAnalysisAPI.getModelPerformanceSummary`.
 *
 * Both maps are keyed by data split and any split may be missing. Only
 * `confusionMatrices` has an `all` entry (presumably aggregated across every
 * split — confirm).
 */
export type GetModelPerformanceSummaryResponse = {
  performance: {
    train?: ModelEvaluationReportPerformance;
    dev?: ModelEvaluationReportPerformance;
    test?: ModelEvaluationReportPerformance;
  };
  confusionMatrices: {
    all?: SplitConfusionMatrices;
    train?: SplitConfusionMatrices;
    dev?: SplitConfusionMatrices;
    test?: SplitConfusionMatrices;
  };
};

/**
 * Filter accepted by `ModelAnalysisAPI.getMediaList`, which sends it
 * JSON-encoded as a single value.
 *
 * NOTE(review): presumably selects one confusion-matrix cell by ground-truth
 * class id (`gtClassId`) and predicted class id (`predClassId`) — confirm.
 */
export type FilterOptions = {
  gtClassId: number;
  predClassId: number;
};

/**
 * Sort direction accepted by `ModelAnalysisAPI.getMediaList`.
 *
 * String-valued, so the member value itself ('ASC' / 'DESC') is what ends up
 * in the request parameters.
 */
export enum SortOrder {
  ASC = 'ASC',
  DESC = 'DESC',
}

/**
 * Payload of `ModelAnalysisAPI.getBatchModelMetrics`: one metrics entry per
 * requested evaluation set, tagged with its `evaluationSetId`.
 *
 * NOTE(review): `evalMediaIdsInTrainSet` presumably lists media of the
 * evaluation set that also occur in the training set — confirm.
 */
export type BatchModelMetricResponse = Array<
  ModelMetricsResponse & { evaluationSetId: number; evalMediaIdsInTrainSet?: MediaId[] }
>;

/**
 * Thin client for the model-analysis endpoints.
 *
 * Every method forwards its arguments to `BaseAPI.get` or `BaseAPI.postJSON`
 * under a fixed endpoint name. The only transformations are the encodings
 * noted on `getMediaList` and `getAnnotationInstancePairsByIds`; responses are
 * returned without post-processing.
 *
 * NOTE(review): methods that pass `true` as the third argument of `this.get`
 * declare the bare payload type as their result, while the other methods
 * declare `ApiResponse<...>`. Presumably that flag makes `BaseAPI` unwrap the
 * response envelope — confirm against `BaseAPI`.
 */
class ModelAnalysisAPI extends BaseAPI {
  /**
   * Fetches metrics for several evaluation sets of one model in a single
   * request (POST `batch_model_metrics`).
   *
   * @param projectId - Project id.
   * @param modelId - Model whose metrics are requested.
   * @param threshold - Forwarded to the backend unchanged.
   * @param evaluationSetIds - Evaluation sets to report on.
   * @returns The API response; each entry carries its `evaluationSetId`.
   */
  async getBatchModelMetrics(
    projectId: ProjectId,
    modelId: string,
    threshold: number,
    evaluationSetIds: number[],
  ): Promise<ApiResponse<BatchModelMetricResponse>> {
    return this.postJSON('batch_model_metrics', {
      projectId,
      modelId,
      threshold,
      evaluationSetIds,
    });
  }

  /**
   * Lists media of an evaluation set for a model (GET `media_list`).
   *
   * Every field of `params` except `filterOptions` is forwarded as-is;
   * `filterOptions` is JSON-encoded into a single string value.
   * NOTE(review): `candidateModelId`/`candidateThreshold` presumably select a
   * second model to compare against, and `offset`/`limit` window the result —
   * confirm.
   *
   * @param params - Request parameters.
   * @returns The media-list payload (see `GetMediaListResponse`).
   */
  async getMediaList(params: {
    projectId: ProjectId;
    modelId: string;
    evaluationSetId: number;
    threshold: number;
    candidateModelId?: string;
    candidateThreshold?: number;
    filterOptions?: FilterOptions;
    sortOrder?: SortOrder;
    lastUpdatedAt?: string | null;
    offset?: number;
    limit?: number;
  }): Promise<GetMediaListResponse> {
    const { filterOptions, ...forwarded } = params;
    return this.get(
      'media_list',
      {
        ...forwarded,
        // `JSON.stringify(undefined)` returns `undefined` at runtime despite its
        // `string` return type; spell that case out so the value's type is honest.
        filterOptions: filterOptions === undefined ? undefined : JSON.stringify(filterOptions),
      },
      true,
    );
  }

  /**
   * Fetches annotation instances for the given media (GET
   * `annotation_instance_pairs`).
   *
   * `mediaIds` is sent as one comma-separated string; all other fields of
   * `params` are forwarded as-is.
   *
   * @param params - Request parameters.
   * @returns The API response with the annotation instances.
   */
  async getAnnotationInstancePairsByIds(params: {
    projectId: ProjectId;
    modelId: string;
    mediaIds: number[];
    threshold?: number;
    version?: number;
  }): Promise<ApiResponse<AnnotationInstance[]>> {
    return this.get('annotation_instance_pairs', {
      ...params,
      mediaIds: params.mediaIds.join(','),
    });
  }

  /**
   * Fetches the confusion matrices of a model on an evaluation set (GET
   * `model_confusion_matrix`). `params` is forwarded unchanged.
   *
   * @param params - Request parameters.
   * @returns The confusion-matrix payload and its computation status.
   */
  async getConfusionMatrix(params: {
    projectId: ProjectId;
    modelId: string;
    evaluationSetId: number;
    threshold: number;
    lastUpdatedAt?: string | null;
  }): Promise<ConfusionMatrixResponse> {
    return this.get('model_confusion_matrix', params, true);
  }

  /**
   * Fetches confusion-matrix counts keyed by media id (GET
   * `model_confusion_matrix_counts`). `params` is forwarded unchanged.
   *
   * @param params - Request parameters.
   * @returns Counts keyed by media id (see `ConfusionMatrixCountByMediaId`).
   */
  async getConfusionMatrixCounts(params: {
    projectId: ProjectId;
    modelId: string;
    evaluationSetId: number;
    threshold: number;
  }): Promise<ConfusionMatrixCountByMediaId> {
    return this.get('model_confusion_matrix_counts', params, true);
  }

  /**
   * Requests an evaluation run (POST `run_evaluation`). `params` is sent as
   * the JSON body unchanged.
   *
   * NOTE(review): no explicit return type; the result is whatever
   * `postJSON` yields.
   *
   * @param params - Request body.
   */
  async runEvaluation(params: {
    projectId: ProjectId;
    modelId: string;
    evaluationSetId: number;
    threshold: number;
  }) {
    return this.postJSON('run_evaluation', params);
  }

  /**
   * Requests creation of a bundle (POST `create_bundle`). `params` is sent as
   * the JSON body unchanged.
   *
   * @param params - Request body.
   * @returns The API response (no payload is declared).
   */
  async createBundle(params: {
    projectId: ProjectId;
    modelId: string;
    evaluationSetId: number;
    threshold: number;
  }): Promise<ApiResponse<void>> {
    return this.postJSON('create_bundle', params);
  }

  /**
   * Lists the evaluation reports of a model (GET `model_evaluation_reports`).
   *
   * @param projectId - Project id.
   * @param modelId - Model whose reports are listed.
   * @returns The reports.
   */
  async getModelEvaluationReports(
    projectId: ProjectId,
    modelId: string,
  ): Promise<GetModelEvaluationReportsResponse> {
    return this.get('model_evaluation_reports', { projectId, modelId }, true);
  }

  /**
   * Fetches one evaluation report in CSV form (GET
   * `model_evaluation_report_csv`).
   *
   * NOTE(review): no explicit return type; the result is whatever `get`
   * yields.
   *
   * @param projectId - Project id.
   * @param modelEvaluationReportId - Report to export.
   */
  async getModelEvaluationReportCsv(projectId: ProjectId, modelEvaluationReportId: number) {
    return this.get('model_evaluation_report_csv', { projectId, modelEvaluationReportId }, true);
  }

  /**
   * Fetches per-split performance and confusion matrices of a registered
   * model (GET `model_performance_summary`).
   *
   * @param projectId - Project id.
   * @param modelId - Registered model id.
   * @returns The performance summary (see `GetModelPerformanceSummaryResponse`).
   */
  async getModelPerformanceSummary(
    projectId: ProjectId,
    modelId: RegisteredModelId,
  ): Promise<GetModelPerformanceSummaryResponse> {
    return this.get('model_performance_summary', { projectId, modelId }, true);
  }
}

// Shared client instance. The name is passed to the BaseAPI constructor
// (presumably used as the endpoint prefix — confirm against BaseAPI).
export default new ModelAnalysisAPI('model_analysis');
