import {
  Box,
  Button,
  Card,
  Collapse,
  Grid,
  Group,
  NumberInput,
  Radio,
  SegmentedControl,
  Select,
  SimpleGrid,
  Text,
  rem
} from '@mantine/core'
import { isNotEmpty, useForm } from '@mantine/form'
import { FormattedMessage, useIntl } from 'react-intl'
import { LabelMultiSelect } from '@/components/labels/LabelMultiSelect/LabelMultiSelect'
import { LabelWithHint } from '@/components/ui-shared/LabelWithHint/LabelWithHint'
import { Label } from '@/types/label'

/** Batch sizes offered by the batch-size selector. */
const BATCH_SIZES = [8, 16, 32, 64, 128]

/** Side lengths (in px) of the square preset inference resolutions. */
const INFERENCE_RESOLUTION_SIDES = [128, 256, 320, 448, 640, 1024]

/** Preset inference resolutions, formatted as "<width>x<height>". */
const INFERENCE_RESOLUTIONS = INFERENCE_RESOLUTION_SIDES.map(
  (side) => `${side}x${side}`
)

/** Number of frozen layers used by the fine-tuning preset. */
const FROZEN_LAYERS_FINE_TUNING = 10

/**
 * Builds one training-type preset. Presets share patience, batch size and
 * inference resolution; only the frozen-layer and epoch counts differ.
 */
const buildTrainingPreset = ({
  frozenLayers,
  epoch
}: {
  frozenLayers: number
  epoch: number
}) => ({
  frozenLayers,
  patience: 50,
  epoch,
  batchSize: BATCH_SIZES[4],
  inferenceResolution: INFERENCE_RESOLUTIONS[4]
})

/** Recommended hyperparameters, keyed by training type. */
const DEFAULT_PARAMS = {
  full_retrain: buildTrainingPreset({ frozenLayers: 0, epoch: 250 }),
  fine_tune: buildTrainingPreset({
    frozenLayers: FROZEN_LAYERS_FINE_TUNING,
    epoch: 150
  })
}

/** Training strategy selected in the form. */
type TrainingType = 'fine_tune' | 'full_retrain'

/** Values held by the training-parameters form and handed to `onSubmit`. */
export type FormValues = {
  // What kind of run, and which label ids to train on.
  trainingType: TrainingType
  labels: string[]

  // Numeric training hyperparameters.
  frozenLayers: number
  patience: number
  epoch: number
  batchSize: number

  // Inference resolution string plus the width/height used when the user
  // picks a custom resolution instead of a preset.
  inferenceResolution: string
  customWidth: number
  customHeight: number
}

/** Props accepted by the training-parameters form. */
type TrainingParamsFormProps = {
  /** Labels the user can choose to train on. */
  labels: Label[]
  /** Receives the validated form values on submit. */
  onSubmit: (values: FormValues) => void
}

/** Select value meaning "use customWidth x customHeight" instead of a preset. */
const CUSTOM_RESOLUTION = '$custom'

/**
 * Form for configuring a model training run.
 *
 * Lets the user choose a training type (full retrain or fine-tune), the
 * labels to train on and the training hyperparameters. On a valid submit it
 * calls `onSubmit` with the form values; when the custom resolution option is
 * selected, `inferenceResolution` is replaced by
 * `"<customWidth>x<customHeight>"`.
 *
 * @param labels - Labels offered in the label selector. When fine-tuning, all
 *   of them are selected and the selector is disabled.
 * @param onSubmit - Called with the validated (and resolution-resolved) values.
 */
export const TrainingParamsForm = ({
  labels,
  onSubmit
}: TrainingParamsFormProps) => {
  const intl = useIntl()

  // Shared message for every "value required" check.
  const enterValueMessage = intl.formatMessage({
    id: 'trainingParams.validation.enterValue'
  })

  // Custom dimensions are only checked when the custom resolution is
  // selected. NumberInput reports a cleared field as '', and the initial value
  // is 0, so require a positive integer rather than merely "not empty"
  // (otherwise "0x0" could be submitted).
  const validateCustomDimension = (
    value: number | '',
    values: FormValues
  ): string | null => {
    if (values.inferenceResolution !== CUSTOM_RESOLUTION) {
      return null
    }

    return typeof value === 'number' && Number.isInteger(value) && value > 0
      ? null
      : enterValueMessage
  }

  const form = useForm<FormValues>({
    initialValues: {
      ...DEFAULT_PARAMS.full_retrain,
      trainingType: 'full_retrain',
      customWidth: 0,
      customHeight: 0,
      labels: []
    },

    validate: {
      labels: (value) =>
        value.length === 0
          ? intl.formatMessage({
              id: 'labels.validation.selectAtLeastOneLabel'
            })
          : null,
      frozenLayers: isNotEmpty(enterValueMessage),
      patience: isNotEmpty(enterValueMessage),
      epoch: isNotEmpty(enterValueMessage),
      customWidth: validateCustomDimension,
      customHeight: validateCustomDimension
    }
  })

  const labelOptions = labels.map((label) => ({
    labelId: label.id,
    name: label.name,
    color: label.color
  }))

  const inferenceResolutionOptions = INFERENCE_RESOLUTIONS.map(
    (resolution) => ({
      value: resolution,
      label: resolution
    })
  ).concat({
    value: CUSTOM_RESOLUTION,
    label: intl.formatMessage({
      id: 'trainingParams.inferenceResolution.custom'
    })
  })

  const handleTrainingTypeChange = (value: TrainingType) => {
    // Fine-tuning uses every available label with FROZEN_LAYERS_FINE_TUNING
    // frozen layers; a full retrain resets the label selection and freezes
    // nothing. Applied as one update so the form never holds a mixed state.
    form.setValues(
      value === 'fine_tune'
        ? {
            trainingType: value,
            labels: labels.map((label) => label.id),
            frozenLayers: FROZEN_LAYERS_FINE_TUNING
          }
        : {
            trainingType: value,
            labels: [],
            frozenLayers: 0
          }
    )

    // The label list was just replaced programmatically; drop any stale
    // "select at least one label" error (the selector may now be disabled).
    form.clearFieldError('labels')
  }

  const handleUseDefaultParams = () => {
    form.setValues(DEFAULT_PARAMS[form.values.trainingType])
  }

  const handleSubmit = (values: FormValues) => {
    const resolution =
      values.inferenceResolution === CUSTOM_RESOLUTION
        ? `${values.customWidth}x${values.customHeight}`
        : values.inferenceResolution

    onSubmit({
      ...values,
      inferenceResolution: resolution
    })
  }

  return (
    <form onSubmit={form.onSubmit(handleSubmit)}>
      <Card bg="blue.0" p="lg" mb="xl">
        <Radio.Group
          value={form.values.trainingType}
          styles={{
            label: {
              marginBottom: rem(16)
            }
          }}
          label={
            <LabelWithHint
              label={intl.formatMessage({
                id: 'trainingParams.trainingType'
              })}
              hint={intl.formatMessage({
                id: 'trainingParams.trainingTypeHint'
              })}
            />
          }
          onChange={(value) => handleTrainingTypeChange(value as TrainingType)}
        >
          <Group gap="xl">
            <Radio
              label={
                <FormattedMessage id="trainingParams.trainingType.fullRetraining" />
              }
              value="full_retrain"
            />

            <Radio
              label={
                <FormattedMessage id="trainingParams.trainingType.fineTuning" />
              }
              value="fine_tune"
            />
          </Group>
        </Radio.Group>
      </Card>

      <Box>
        <Box mb="xl">
          <LabelMultiSelect
            label={
              <LabelWithHint
                label={intl.formatMessage({
                  id: 'trainingParams.classes'
                })}
                hint={intl.formatMessage({
                  id: 'trainingParams.classesHint'
                })}
              />
            }
            values={form.values.labels}
            options={labelOptions}
            disabled={form.values.trainingType === 'fine_tune'}
            error={form.errors.labels}
            onChange={(value) => form.setFieldValue('labels', value)}
          />
        </Box>

        <Grid mb="xl">
          <Grid.Col span={{ base: 12, lg: 4 }}>
            <NumberInput
              label={
                <LabelWithHint
                  label={intl.formatMessage({
                    id: 'trainingParams.frozenLayers'
                  })}
                  hint={intl.formatMessage({
                    id: 'trainingParams.frozenLayersHint'
                  })}
                />
              }
              min={0}
              allowDecimal={false}
              disabled={form.values.trainingType === 'full_retrain'}
              {...form.getInputProps('frozenLayers')}
            />
          </Grid.Col>

          <Grid.Col span={{ base: 12, lg: 4 }}>
            <NumberInput
              label={
                <LabelWithHint
                  label={intl.formatMessage({
                    id: 'trainingParams.patience'
                  })}
                  hint={intl.formatMessage({
                    id: 'trainingParams.patienceHint'
                  })}
                />
              }
              min={0}
              allowDecimal={false}
              {...form.getInputProps('patience')}
            />
          </Grid.Col>

          <Grid.Col span={{ base: 12, lg: 4 }}>
            <NumberInput
              label={
                <LabelWithHint
                  label={intl.formatMessage({
                    id: 'trainingParams.epoch'
                  })}
                  hint={intl.formatMessage({
                    id: 'trainingParams.epochHint'
                  })}
                />
              }
              min={0}
              allowDecimal={false}
              {...form.getInputProps('epoch')}
            />
          </Grid.Col>
        </Grid>

        <Box mb="xl">
          <LabelWithHint
            label={
              <Text size="sm" mb={2} fw={500}>
                {intl.formatMessage({ id: 'trainingParams.batchSize' })}
              </Text>
            }
            hint={intl.formatMessage({
              id: 'trainingParams.batchSizeHint'
            })}
          />

          <SegmentedControl
            size="xs"
            styles={{
              control: {
                minWidth: rem(60)
              }
            }}
            color="blue"
            value={form.values.batchSize.toString()}
            data={BATCH_SIZES.map((size) => ({
              value: size.toString(),
              label: size.toString()
            }))}
            onChange={(value) => form.setFieldValue('batchSize', Number(value))}
          />
        </Box>

        <Select
          value={form.values.inferenceResolution}
          label={
            <LabelWithHint
              label={intl.formatMessage({
                id: 'trainingParams.inferenceResolution'
              })}
              hint={intl.formatMessage({
                id: 'trainingParams.inferenceResolutionHint'
              })}
            />
          }
          data={inferenceResolutionOptions}
          onChange={(value) => {
            // Select emits null when the selected option is clicked again
            // (deselect); keep the current resolution instead of storing null.
            if (value !== null) {
              form.setFieldValue('inferenceResolution', value)
            }
          }}
        />

        <Collapse in={form.values.inferenceResolution === CUSTOM_RESOLUTION}>
          <SimpleGrid cols={2} mt="xl">
            {/* Pixel dimensions: positive integers only. */}
            <NumberInput
              label={
                <FormattedMessage id="trainingParams.inferenceResolution.customWidth" />
              }
              min={1}
              allowDecimal={false}
              {...form.getInputProps('customWidth')}
            />

            <NumberInput
              label={
                <FormattedMessage id="trainingParams.inferenceResolution.customHeight" />
              }
              min={1}
              allowDecimal={false}
              {...form.getInputProps('customHeight')}
            />
          </SimpleGrid>
        </Collapse>
      </Box>

      <Group mt="xl" justify="center">
        <Button
          miw={240}
          variant="light"
          radius="xl"
          onClick={handleUseDefaultParams}
        >
          <FormattedMessage id="trainingParams.useRecommended" />
        </Button>

        <Button type="submit" miw={240} radius="xl">
          <FormattedMessage id="models.trainModel" />
        </Button>
      </Group>
    </form>
  )
}
