"""Config flow for the Template integration."""
from __future__ import annotations

from collections.abc import Callable, Coroutine, Mapping
from typing import Any, cast

import voluptuous as vol

from homeassistant.components import websocket_api
from homeassistant.components.binary_sensor import BinarySensorDeviceClass
from homeassistant.components.sensor import (
    CONF_STATE_CLASS,
    DEVICE_CLASS_STATE_CLASSES,
    DEVICE_CLASS_UNITS,
    SensorDeviceClass,
    SensorStateClass,
)
from homeassistant.const import (
    CONF_DEVICE_CLASS,
    CONF_NAME,
    CONF_STATE,
    CONF_UNIT_OF_MEASUREMENT,
    Platform,
)
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity_registry as er, selector
from homeassistant.helpers.schema_config_entry_flow import (
    SchemaCommonFlowHandler,
    SchemaConfigFlowHandler,
    SchemaFlowFormStep,
    SchemaFlowMenuStep,
)

from .binary_sensor import async_create_preview_binary_sensor
from .const import DOMAIN
from .sensor import async_create_preview_sensor
from .template_entity import TemplateEntity

NONE_SENTINEL = "none"


def generate_schema(domain: str, flow_type: str) -> dict[vol.Marker, Any]:
    """Generate schema."""
    schema: dict[vol.Marker, Any] = {}

    if domain == Platform.BINARY_SENSOR and flow_type == "config":
        schema = {
            vol.Optional(CONF_DEVICE_CLASS): selector.SelectSelector(
                selector.SelectSelectorConfig(
                    options=[
                        NONE_SENTINEL,
                        *sorted(
                            [cls.value for cls in BinarySensorDeviceClass],
                            key=str.casefold,
                        ),
                    ],
                    mode=selector.SelectSelectorMode.DROPDOWN,
                    translation_key="binary_sensor_device_class",
                ),
            )
        }

    if domain == Platform.SENSOR:
        schema = {
            vol.Optional(
                CONF_UNIT_OF_MEASUREMENT, default=NONE_SENTINEL
            ): selector.SelectSelector(
                selector.SelectSelectorConfig(
                    options=[
                        NONE_SENTINEL,
                        *sorted(
                            {
                                str(unit)
                                for units in DEVICE_CLASS_UNITS.values()
                                for unit in units
                                if unit is not None
                            },
                            key=str.casefold,
                        ),
                    ],
                    mode=selector.SelectSelectorMode.DROPDOWN,
                    translation_key="sensor_unit_of_measurement",
                    custom_value=True,
                ),
            ),
            vol.Optional(
                CONF_DEVICE_CLASS, default=NONE_SENTINEL
            ): selector.SelectSelector(
                selector.SelectSelectorConfig(
                    options=[
                        NONE_SENTINEL,
                        *sorted(
                            [
                                cls.value
                                for cls in SensorDeviceClass
                                if cls != SensorDeviceClass.ENUM
                            ],
                            key=str.casefold,
                        ),
                    ],
                    mode=selector.SelectSelectorMode.DROPDOWN,
                    translation_key="sensor_device_class",
                ),
            ),
            vol.Optional(
                CONF_STATE_CLASS, default=NONE_SENTINEL
            ): selector.SelectSelector(
                selector.SelectSelectorConfig(
                    options=[
                        NONE_SENTINEL,
                        *sorted([cls.value for cls in SensorStateClass]),
                    ],
                    mode=selector.SelectSelectorMode.DROPDOWN,
                    translation_key="sensor_state_class",
                ),
            ),
        }

    return schema


def options_schema(domain: str) -> vol.Schema:
    """Generate options schema."""
    return vol.Schema(
        {vol.Required(CONF_STATE): selector.TemplateSelector()}
        | generate_schema(domain, "option"),
    )


def config_schema(domain: str) -> vol.Schema:
    """Generate config schema."""
    return vol.Schema(
        {
            vol.Required(CONF_NAME): selector.TextSelector(),
            vol.Required(CONF_STATE): selector.TemplateSelector(),
        }
        | generate_schema(domain, "config"),
    )


async def choose_options_step(options: dict[str, Any]) -> str:
    """Return next step_id for options flow according to template_type."""
    return cast(str, options["template_type"])


def _strip_sentinel(options: dict[str, Any]) -> None:
    """Convert sentinel to None."""
    for key in (CONF_DEVICE_CLASS, CONF_STATE_CLASS, CONF_UNIT_OF_MEASUREMENT):
        if key not in options:
            continue
        if options[key] == NONE_SENTINEL:
            options.pop(key)


def _validate_unit(options: dict[str, Any]) -> None:
    """Validate unit of measurement."""
    if (
        (device_class := options.get(CONF_DEVICE_CLASS))
        and (units := DEVICE_CLASS_UNITS.get(device_class)) is not None
        and (unit := options.get(CONF_UNIT_OF_MEASUREMENT)) not in units
    ):
        sorted_units = sorted(
            [f"'{str(unit)}'" if unit else "no unit of measurement" for unit in units],
            key=str.casefold,
        )
        if len(sorted_units) == 1:
            units_string = sorted_units[0]
        else:
            units_string = f"one of {', '.join(sorted_units)}"

        raise vol.Invalid(
            f"'{unit}' is not a valid unit for device class '{device_class}'; "
            f"expected {units_string}"
        )


def _validate_state_class(options: dict[str, Any]) -> None:
    """Validate state class."""
    if (
        (state_class := options.get(CONF_STATE_CLASS))
        and (device_class := options.get(CONF_DEVICE_CLASS))
        and (state_classes := DEVICE_CLASS_STATE_CLASSES.get(device_class)) is not None
        and state_class not in state_classes
    ):
        sorted_state_classes = sorted(
            [f"'{str(state_class)}'" for state_class in state_classes],
            key=str.casefold,
        )
        if len(sorted_state_classes) == 0:
            state_classes_string = "no state class"
        elif len(sorted_state_classes) == 1:
            state_classes_string = sorted_state_classes[0]
        else:
            state_classes_string = f"one of {', '.join(sorted_state_classes)}"

        raise vol.Invalid(
            f"'{state_class}' is not a valid state class for device class "
            f"'{device_class}'; expected {state_classes_string}"
        )


def validate_user_input(
    template_type: str,
) -> Callable[
    [SchemaCommonFlowHandler, dict[str, Any]],
    Coroutine[Any, Any, dict[str, Any]],
]:
    """Do post validation of user input.

    For binary sensors: Strip none-sentinels.
    For sensors: Strip none-sentinels and validate unit of measurement.
    For all domaines: Set template type.
    """

    async def _validate_user_input(
        _: SchemaCommonFlowHandler,
        user_input: dict[str, Any],
    ) -> dict[str, Any]:
        """Add template type to user input."""
        if template_type in (Platform.BINARY_SENSOR, Platform.SENSOR):
            _strip_sentinel(user_input)
        if template_type == Platform.SENSOR:
            _validate_unit(user_input)
            _validate_state_class(user_input)
        return {"template_type": template_type} | user_input

    return _validate_user_input


TEMPLATE_TYPES = [
    "binary_sensor",
    "sensor",
]

CONFIG_FLOW = {
    "user": SchemaFlowMenuStep(TEMPLATE_TYPES),
    Platform.BINARY_SENSOR: SchemaFlowFormStep(
        config_schema(Platform.BINARY_SENSOR),
        preview="template",
        validate_user_input=validate_user_input(Platform.BINARY_SENSOR),
    ),
    Platform.SENSOR: SchemaFlowFormStep(
        config_schema(Platform.SENSOR),
        preview="template",
        validate_user_input=validate_user_input(Platform.SENSOR),
    ),
}


OPTIONS_FLOW = {
    "init": SchemaFlowFormStep(next_step=choose_options_step),
    Platform.BINARY_SENSOR: SchemaFlowFormStep(
        options_schema(Platform.BINARY_SENSOR),
        preview="template",
        validate_user_input=validate_user_input(Platform.BINARY_SENSOR),
    ),
    Platform.SENSOR: SchemaFlowFormStep(
        options_schema(Platform.SENSOR),
        preview="template",
        validate_user_input=validate_user_input(Platform.SENSOR),
    ),
}

CREATE_PREVIEW_ENTITY: dict[
    str,
    Callable[[HomeAssistant, str, dict[str, Any]], TemplateEntity],
] = {
    "binary_sensor": async_create_preview_binary_sensor,
    "sensor": async_create_preview_sensor,
}


class TemplateConfigFlowHandler(SchemaConfigFlowHandler, domain=DOMAIN):
    """Handle config flow for template helper."""

    config_flow = CONFIG_FLOW
    options_flow = OPTIONS_FLOW

    @callback
    def async_config_entry_title(self, options: Mapping[str, Any]) -> str:
        """Return config entry title."""
        return cast(str, options["name"])

    @staticmethod
    async def async_setup_preview(hass: HomeAssistant) -> None:
        """Set up preview WS API."""
        websocket_api.async_register_command(hass, ws_start_preview)


@websocket_api.websocket_command(
    {
        vol.Required("type"): "template/start_preview",
        vol.Required("flow_id"): str,
        vol.Required("flow_type"): vol.Any("config_flow", "options_flow"),
        vol.Required("user_input"): dict,
    }
)
@callback
def ws_start_preview(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Generate a preview."""

    def _validate(schema: vol.Schema, domain: str, user_input: dict[str, Any]) -> Any:
        errors = {}
        key: vol.Marker
        for key, validator in schema.schema.items():
            if key.schema not in user_input:
                continue
            try:
                validator(user_input[key.schema])
            except vol.Invalid as ex:
                errors[key.schema] = str(ex.msg)

        if domain == Platform.SENSOR:
            _strip_sentinel(user_input)
            try:
                _validate_unit(user_input)
            except vol.Invalid as ex:
                errors[CONF_UNIT_OF_MEASUREMENT] = str(ex.msg)
            try:
                _validate_state_class(user_input)
            except vol.Invalid as ex:
                errors[CONF_STATE_CLASS] = str(ex.msg)

        return errors

    entity_registry_entry: er.RegistryEntry | None = None
    if msg["flow_type"] == "config_flow":
        flow_status = hass.config_entries.flow.async_get(msg["flow_id"])
        template_type = flow_status["step_id"]
        form_step = cast(SchemaFlowFormStep, CONFIG_FLOW[template_type])
        schema = cast(vol.Schema, form_step.schema)
        name = msg["user_input"]["name"]
    else:
        flow_status = hass.config_entries.options.async_get(msg["flow_id"])
        config_entry = hass.config_entries.async_get_entry(flow_status["handler"])
        if not config_entry:
            raise HomeAssistantError
        template_type = config_entry.options["template_type"]
        name = config_entry.options["name"]
        schema = cast(vol.Schema, OPTIONS_FLOW[template_type].schema)
        entity_registry = er.async_get(hass)
        entries = er.async_entries_for_config_entry(
            entity_registry, flow_status["handler"]
        )
        if entries:
            entity_registry_entry = entries[0]

    errors = _validate(schema, template_type, msg["user_input"])

    @callback
    def async_preview_updated(
        state: str | None,
        attributes: Mapping[str, Any] | None,
        listeners: dict[str, bool | set[str]] | None,
        error: str | None,
    ) -> None:
        """Forward config entry state events to websocket."""
        if error is not None:
            connection.send_message(
                websocket_api.event_message(
                    msg["id"],
                    {"error": error},
                )
            )
            return
        connection.send_message(
            websocket_api.event_message(
                msg["id"],
                {"attributes": attributes, "listeners": listeners, "state": state},
            )
        )

    if errors:
        connection.send_message(
            {
                "id": msg["id"],
                "type": websocket_api.const.TYPE_RESULT,
                "success": False,
                "error": {"code": "invalid_user_input", "message": errors},
            }
        )
        return

    _strip_sentinel(msg["user_input"])
    preview_entity = CREATE_PREVIEW_ENTITY[template_type](hass, name, msg["user_input"])
    preview_entity.hass = hass
    preview_entity.registry_entry = entity_registry_entry

    connection.send_result(msg["id"])
    connection.subscriptions[msg["id"]] = preview_entity.async_start_preview(
        async_preview_updated
    )