import {
  Anchor,
  Box,
  Button,
  Card,
  Collapse,
  Grid,
  Group,
  NumberInput,
  Radio,
  SegmentedControl,
  Select,
  SimpleGrid,
  Text,
  rem
} from '@mantine/core'
import { isNotEmpty, useForm } from '@mantine/form'
import { FormattedMessage, useIntl } from 'react-intl'
import { LabelMultiSelect } from '@/components/labels/LabelMultiSelect/LabelMultiSelect'
import { LabelWithHint } from '@/components/ui-shared/LabelWithHint/LabelWithHint'
import { Label } from '@/types/label'
import { MLModelType } from '@/types/model'

/** Batch size pre-selected in the form. */
const DEFAULT_BATCH_SIZE = 16

/** Batch sizes offered by the batch-size segmented control. */
const BATCH_SIZES = [8, DEFAULT_BATCH_SIZE, 32]

/** Formats a square resolution such as `640x640` from its side length. */
const toSquareResolution = (side: number) => `${side}x${side}`

/** Inference resolutions selectable for primary models. */
const PRIMARY_MODEL_INFERENCE_RESOLUTIONS = [128, 256, 320, 448, 640, 1024].map(
  toSquareResolution
)

/** Inference resolutions selectable for secondary models (a smaller range). */
const SECONDARY_MODEL_INFERENCE_RESOLUTIONS = [128, 256, 320, 448].map(
  toSquareResolution
)

/** Number of frozen layers recommended when fine-tuning. */
const FROZEN_LAYERS_FINE_TUNING = 10

// Parameters that are identical for both training types.
const SHARED_DEFAULT_PARAMS = {
  patience: 15,
  epoch: 250,
  batchSize: DEFAULT_BATCH_SIZE
}

/**
 * Recommended hyperparameters per training type; only the frozen-layer
 * count differs between them.
 */
const DEFAULT_PARAMS = {
  full_retrain: { frozenLayers: 0, ...SHARED_DEFAULT_PARAMS },
  fine_tune: {
    frozenLayers: FROZEN_LAYERS_FINE_TUNING,
    ...SHARED_DEFAULT_PARAMS
  }
}

/** How the model is trained: from scratch, or by fine-tuning with frozen layers. */
type TrainingType = 'fine_tune' | 'full_retrain'

/** Values collected by {@link TrainingParamsForm}. */
export type FormValues = {
  /** Selected training strategy. */
  trainingType: TrainingType
  /** Number of frozen layers; the input is disabled for `full_retrain`. */
  frozenLayers: number
  patience: number
  epoch: number
  /** One of the values offered by the batch-size control. */
  batchSize: number
  /**
   * A `WxH` resolution string. While editing it may hold the `'$custom'`
   * sentinel; the value passed to `onSubmit` is always resolved to `WxH`.
   */
  inferenceResolution: string
  /** Width used when the custom resolution option is selected. */
  customWidth: number
  /** Height used when the custom resolution option is selected. */
  customHeight: number
  /** Ids of the labels (classes) the model is trained on. */
  labels: string[]
}

/** Props of {@link TrainingParamsForm}. */
type TrainingParamsFormProps = {
  /** Determines which inference resolutions are offered. */
  modelType: MLModelType
  /** Values that override the form's defaults (labels come from `labels`). */
  initialValues?: Partial<Omit<FormValues, 'labels'>>
  /** Labels shown (read-only) in the classes selector. */
  labels: Label[]
  /** Called with validated values once the form is submitted. */
  onSubmit: (values: FormValues) => void
  /** Called when the user presses the cancel button. */
  onCancel: () => void
}

/** Sentinel Select value meaning "use the custom width/height inputs". */
const CUSTOM_RESOLUTION = '$custom'

/**
 * Form for choosing training hyperparameters: training type, frozen layers,
 * patience, epochs, batch size and inference resolution.
 *
 * Defaults come from `DEFAULT_PARAMS` for the initial training type and are
 * overridden by `initialValues`. The classes selector is read-only and always
 * reflects `labels`. On submit, a custom resolution is converted to a
 * `WxH` string before `onSubmit` is called.
 */
export const TrainingParamsForm = ({
  modelType,
  labels,
  initialValues,
  onSubmit,
  onCancel
}: TrainingParamsFormProps) => {
  const intl = useIntl()

  const enterValueMessage = intl.formatMessage({
    id: 'training.params.validation.enterValue'
  })

  // Custom dimensions only matter while the custom option is selected.
  // NumberInput reports a string (e.g. '') when the field is cleared, and 0
  // would produce a meaningless "0x0" resolution, so require a positive number.
  const validateCustomDimension = (
    value: number | string,
    values: FormValues
  ) => {
    if (values.inferenceResolution !== CUSTOM_RESOLUTION) {
      return null
    }

    return typeof value === 'number' && value > 0 ? null : enterValueMessage
  }

  // Take the defaults of the requested training type, so that an initial
  // `fine_tune` gets the fine-tuning frozen-layer count unless overridden.
  const initialTrainingType: TrainingType =
    initialValues?.trainingType ?? 'full_retrain'

  const form = useForm<FormValues>({
    initialValues: {
      ...DEFAULT_PARAMS[initialTrainingType],
      customWidth: 0,
      customHeight: 0,
      inferenceResolution:
        modelType === MLModelType.Primary
          ? PRIMARY_MODEL_INFERENCE_RESOLUTIONS[4]
          : SECONDARY_MODEL_INFERENCE_RESOLUTIONS[1],
      ...initialValues,
      // Set after the spread so an explicit `trainingType: undefined` cannot
      // leave the radio group without a value.
      trainingType: initialTrainingType,
      labels: labels.map((label) => label.id)
    },

    validate: {
      frozenLayers: isNotEmpty(enterValueMessage),
      patience: isNotEmpty(enterValueMessage),
      epoch: isNotEmpty(enterValueMessage),
      customWidth: validateCustomDimension,
      customHeight: validateCustomDimension
    }
  })

  const labelOptions = labels.map((label) => ({
    labelId: label.id,
    name: label.name,
    color: label.color
  }))

  const inferenceResolutions =
    modelType === MLModelType.Primary
      ? PRIMARY_MODEL_INFERENCE_RESOLUTIONS
      : SECONDARY_MODEL_INFERENCE_RESOLUTIONS

  const inferenceResolutionOptions = [
    ...inferenceResolutions.map((resolution) => ({
      value: resolution,
      label: resolution
    })),
    {
      value: CUSTOM_RESOLUTION,
      label: intl.formatMessage({
        id: 'training.params.inferenceResolution.custom'
      })
    }
  ]

  const handleTrainingTypeChange = (value: TrainingType) => {
    // Full retraining freezes no layers; fine-tuning starts from the
    // recommended frozen-layer count.
    form.setValues({
      trainingType: value,
      frozenLayers: value === 'fine_tune' ? FROZEN_LAYERS_FINE_TUNING : 0
    })
  }

  const handleUseDefaultParams = () => {
    form.setValues(DEFAULT_PARAMS[form.values.trainingType])
  }

  const handleSubmit = (values: FormValues) => {
    const inferenceResolution =
      values.inferenceResolution === CUSTOM_RESOLUTION
        ? `${values.customWidth}x${values.customHeight}`
        : values.inferenceResolution

    onSubmit({ ...values, inferenceResolution })
  }

  return (
    <form onSubmit={form.onSubmit(handleSubmit)}>
      <Group mb="lg" gap="xs">
        <Text fw="bold" size="sm">
          <FormattedMessage id="training.params.choose" />
        </Text>

        <Anchor
          component="button"
          type="button"
          size="xs"
          onClick={handleUseDefaultParams}
        >
          (<FormattedMessage id="training.params.useRecommended" />)
        </Anchor>
      </Group>

      <Card bg="blue.0" p="lg" mb="xl">
        <Radio.Group
          value={form.values.trainingType}
          styles={{
            label: {
              marginBottom: rem(16)
            }
          }}
          label={
            <LabelWithHint
              label={intl.formatMessage({
                id: 'training.params.trainingType'
              })}
              hint={intl.formatMessage({
                id: 'training.params.trainingTypeHint'
              })}
            />
          }
          onChange={(value) => handleTrainingTypeChange(value as TrainingType)}
        >
          <Group gap="xl">
            <Radio
              label={
                <FormattedMessage id="training.params.trainingType.fullRetraining" />
              }
              value="full_retrain"
            />

            <Radio
              label={
                <FormattedMessage id="training.params.trainingType.fineTuning" />
              }
              value="fine_tune"
            />
          </Group>
        </Radio.Group>
      </Card>

      <Box>
        <Box mb="xl">
          <LabelMultiSelect
            label={
              <LabelWithHint
                label={intl.formatMessage({
                  id: 'training.params.classes'
                })}
                hint={intl.formatMessage({
                  id: 'training.params.classesHint'
                })}
              />
            }
            values={form.values.labels}
            options={labelOptions}
            disabled
          />
        </Box>

        <Grid mb="xl">
          <Grid.Col span={{ base: 12, lg: 4 }}>
            <NumberInput
              label={
                <LabelWithHint
                  label={intl.formatMessage({
                    id: 'training.params.frozenLayers'
                  })}
                  hint={intl.formatMessage({
                    id: 'training.params.frozenLayersHint'
                  })}
                />
              }
              min={0}
              disabled={form.values.trainingType === 'full_retrain'}
              {...form.getInputProps('frozenLayers')}
            />
          </Grid.Col>

          <Grid.Col span={{ base: 12, lg: 4 }}>
            <NumberInput
              label={
                <LabelWithHint
                  label={intl.formatMessage({
                    id: 'training.params.patience'
                  })}
                  hint={intl.formatMessage({
                    id: 'training.params.patienceHint'
                  })}
                />
              }
              min={0}
              {...form.getInputProps('patience')}
            />
          </Grid.Col>

          <Grid.Col span={{ base: 12, lg: 4 }}>
            <NumberInput
              label={
                <LabelWithHint
                  label={intl.formatMessage({
                    id: 'training.params.epoch'
                  })}
                  hint={intl.formatMessage({
                    id: 'training.params.epochHint'
                  })}
                />
              }
              min={0}
              {...form.getInputProps('epoch')}
            />
          </Grid.Col>
        </Grid>

        <Box mb="xl">
          <LabelWithHint
            label={
              <Text size="sm" mb={2} fw={500}>
                {intl.formatMessage({ id: 'training.params.batchSize' })}
              </Text>
            }
            hint={intl.formatMessage({
              id: 'training.params.batchSizeHint'
            })}
          />

          <SegmentedControl
            size="xs"
            styles={{
              control: {
                minWidth: rem(60)
              }
            }}
            color="blue"
            value={form.values.batchSize.toString()}
            data={BATCH_SIZES.map((size) => ({
              value: size.toString(),
              label: size.toString()
            }))}
            onChange={(value) => form.setFieldValue('batchSize', Number(value))}
          />
        </Box>

        <Select
          value={form.values.inferenceResolution}
          label={
            <LabelWithHint
              label={intl.formatMessage({
                id: 'training.params.inferenceResolution'
              })}
              hint={intl.formatMessage({
                id: 'training.params.inferenceResolutionHint'
              })}
            />
          }
          data={inferenceResolutionOptions}
          onChange={(value) => {
            // Select emits null when the selected option is clicked again
            // (deselect); keep the current resolution instead of storing null.
            if (value !== null) {
              form.setFieldValue('inferenceResolution', value)
            }
          }}
        />

        <Collapse in={form.values.inferenceResolution === CUSTOM_RESOLUTION}>
          <SimpleGrid cols={2} mt="xl">
            <NumberInput
              label={
                <FormattedMessage id="training.params.inferenceResolution.customWidth" />
              }
              min={1}
              {...form.getInputProps('customWidth')}
            />

            <NumberInput
              label={
                <FormattedMessage id="training.params.inferenceResolution.customHeight" />
              }
              min={1}
              {...form.getInputProps('customHeight')}
            />
          </SimpleGrid>
        </Collapse>
      </Box>

      <Group mt={80} justify="center">
        <Button miw={200} variant="light" onClick={onCancel}>
          <FormattedMessage id="cancel" />
        </Button>

        <Button type="submit" miw={200}>
          <FormattedMessage id="models.trainModel" />
        </Button>
      </Group>
    </form>
  )
}
