import { Box, LoadingOverlay } from '@mantine/core'
import { useIntl } from 'react-intl'
import { getApiError } from '@/api/helpers/apiError'
import { useInvalidateModelDetails } from '@/queries/modelQueries'
import { useCreateTraining, useRetryTraining } from '@/queries/trainingQueries'
import { showToast } from '@/theme/notifications'
import { MLModel } from '@/types/model'
import { DatasetSplitType, Training, TrainingType } from '@/types/training'
import { useModelStepNavigation } from '../models/ModelDetails/useModelStepNavigation'
import {
  FormValues,
  TrainingParamsForm
} from './TraningParamsForm/TrainingParamsForm'

/** Props accepted by the `CreateTrainingHandler` component. */
type CreateTrainingHandlerProps = {
  /** Sent as `application_id` in the training payload and used for step navigation. */
  applicationId: string
  /** Model being trained; its id and its dataset version's labels are read. */
  model: MLModel
  /**
   * Existing training to retry. When present, the form is prefilled from it
   * and submission retries this training instead of creating a new one.
   */
  training?: Training

  /** Called when the user cancels the parameters form. */
  onCancel: () => void
  /** Optional hook called after a training was successfully created or retried. */
  onTrainingCreated?: () => void
}

/**
 * Renders the training-parameters form for `model` and submits it.
 *
 * - Without `training`: submitting creates a new training, then navigates to
 *   the model's training step.
 * - With `training`: the form is prefilled from that training and submitting
 *   retries it (no navigation).
 *
 * After a successful request a success toast is shown, `onTrainingCreated`
 * (if given) is called and the model details are invalidated. If the request
 * fails, an error toast shows the API error message, falling back to the
 * generic `training.error` text.
 */
export const CreateTrainingHandler = ({
  applicationId,
  model,
  training,
  onCancel,
  onTrainingCreated
}: CreateTrainingHandlerProps) => {
  const intl = useIntl()

  const { goToModelTrainingScreen } = useModelStepNavigation({
    appId: applicationId,
    modelId: model.id
  })

  const { mutateAsync: createTraining, isPending: isCreatePending } =
    useCreateTraining()
  const { mutateAsync: retryTraining, isPending: isRetryPending } =
    useRetryTraining()

  // Either request in flight shows the loading overlay over the form.
  const isPending = isCreatePending || isRetryPending

  const { invalidateModelDetails } = useInvalidateModelDetails(model.id)

  const handleTrainModel = async (values: FormValues) => {
    const payload = {
      application_id: applicationId,
      ml_model_id: model.id,
      // Splits are left empty; the backend splits the dataset automatically.
      training_set_ids: [],
      validation_set_ids: [],
      test_set_ids: [],
      training_type:
        values.trainingType === 'fine_tune'
          ? TrainingType.FineTuning
          : TrainingType.Full,
      dataset_split_type: DatasetSplitType.Automatic,
      frozen_layers: values.frozenLayers,
      patience: values.patience,
      epoch: values.epoch,
      batch_size: values.batchSize,
      resolution: values.inferenceResolution
    }

    // Only the create/retry request is guarded: a failure here is a genuine
    // training error that should be reported to the user.
    try {
      if (training) {
        await retryTraining({
          trainingId: training.id,
          data: payload
        })
      } else {
        await createTraining(payload)
      }
    } catch (err) {
      const { errorMessage } = getApiError(err)

      // `||` (not `??`) on purpose: an empty API message also falls back.
      const message =
        errorMessage || intl.formatMessage({ id: 'training.error' })

      showToast(message, 'error')
      return
    }

    // The request succeeded. Confirm it first, then run the follow-ups outside
    // the try so that a failure in them is never reported as a failed training
    // (which could lead the user to submit the same training again).
    showToast(intl.formatMessage({ id: 'training.started' }), 'success')

    onTrainingCreated?.()

    invalidateModelDetails()

    // A retry stays on the current screen; a new training moves to its step.
    if (!training) {
      goToModelTrainingScreen()
    }
  }

  return (
    <Box pos="relative">
      <LoadingOverlay visible={isPending} />

      <TrainingParamsForm
        initialValues={
          training
            ? {
                trainingType:
                  training.training_type === TrainingType.FineTuning
                    ? 'fine_tune'
                    : 'full_retrain',
                frozenLayers: training.frozen_layers,
                patience: training.patience,
                epoch: training.epoch,
                batchSize: training.batch_size,
                inferenceResolution: training.resolution
              }
            : undefined
        }
        labels={model.dataset_version?.labels || []}
        onSubmit={(values) => void handleTrainModel(values)}
        onCancel={onCancel}
      />
    </Box>
  )
}
