"""Support for MQTT message handling."""

from __future__ import annotations

import asyncio
from collections.abc import Callable
from datetime import datetime
import logging
from typing import TYPE_CHECKING, Any, cast

import voluptuous as vol

from homeassistant import config as conf_util
from homeassistant.components import websocket_api
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_DISCOVERY, CONF_PAYLOAD, SERVICE_RELOAD
from homeassistant.core import HomeAssistant, ServiceCall, callback
from homeassistant.exceptions import (
    ConfigValidationError,
    ServiceValidationError,
    Unauthorized,
)
from homeassistant.helpers import (
    config_validation as cv,
    entity_registry as er,
    event as ev,
    issue_registry as ir,
    template,
)
from homeassistant.helpers.device_registry import DeviceEntry
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity_platform import async_get_platforms
from homeassistant.helpers.reload import async_integration_yaml_config
from homeassistant.helpers.service import async_register_admin_service
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import async_get_integration, async_get_loaded_integration
from homeassistant.setup import SetupPhases, async_pause_setup
from homeassistant.util.async_ import create_eager_task

# Loading the config flow file will register the flow
from . import debug_info, discovery
from .client import (  # noqa: F401
    MQTT,
    async_publish,
    async_subscribe,
    async_subscribe_internal,
    publish,
    subscribe,
)
from .config import MQTT_BASE_SCHEMA, MQTT_RO_SCHEMA, MQTT_RW_SCHEMA  # noqa: F401
from .config_integration import CONFIG_SCHEMA_BASE
from .const import (  # noqa: F401
    ATTR_PAYLOAD,
    ATTR_QOS,
    ATTR_RETAIN,
    ATTR_TOPIC,
    CONF_BIRTH_MESSAGE,
    CONF_BROKER,
    CONF_CERTIFICATE,
    CONF_CLIENT_CERT,
    CONF_CLIENT_KEY,
    CONF_COMMAND_TOPIC,
    CONF_DISCOVERY_PREFIX,
    CONF_KEEPALIVE,
    CONF_QOS,
    CONF_STATE_TOPIC,
    CONF_TLS_INSECURE,
    CONF_TOPIC,
    CONF_TRANSPORT,
    CONF_WILL_MESSAGE,
    CONF_WS_HEADERS,
    CONF_WS_PATH,
    DEFAULT_DISCOVERY,
    DEFAULT_ENCODING,
    DEFAULT_PREFIX,
    DEFAULT_QOS,
    DEFAULT_RETAIN,
    DOMAIN,
    MQTT_CONNECTION_STATE,
    RELOADABLE_PLATFORMS,
    TEMPLATE_ERRORS,
)
from .models import (  # noqa: F401
    DATA_MQTT,
    DATA_MQTT_AVAILABLE,
    MqttCommandTemplate,
    MqttData,
    MqttValueTemplate,
    PayloadSentinel,
    PublishPayloadType,
    ReceiveMessage,
    ReceivePayloadType,
)
from .subscription import (  # noqa: F401
    EntitySubscription,
    async_prepare_subscribe_topics,
    async_subscribe_topics,
    async_unsubscribe_topics,
)
from .util import (  # noqa: F401
    async_create_certificate_temp_files,
    async_forward_entry_setup_and_setup_discovery,
    async_wait_for_mqtt_client,
    mqtt_config_entry_enabled,
    platforms_from_config,
    valid_publish_topic,
    valid_qos_schema,
    valid_subscribe_topic,
)

# Module-level logger named after this module
_LOGGER = logging.getLogger(__name__)

# Names of the services registered under DOMAIN in async_setup_entry
# and removed again in async_unload_entry
SERVICE_PUBLISH = "publish"
SERVICE_DUMP = "dump"

# Publish service fields carrying a template instead of a literal value;
# each one is exclusive with its plain counterpart in MQTT_PUBLISH_SCHEMA
ATTR_TOPIC_TEMPLATE = "topic_template"
ATTR_PAYLOAD_TEMPLATE = "payload_template"

# NOTE(review): not referenced in the code visible here;
# presumably consumed by other modules - confirm before changing
MAX_RECONNECT_WAIT = 300  # seconds

# NOTE(review): connection result identifiers, not referenced in the code
# visible here; presumably consumed by other modules - confirm
CONNECTION_SUCCESS = "connection_success"
CONNECTION_FAILED = "connection_failed"
CONNECTION_FAILED_RECOVERABLE = "connection_failed_recoverable"

# We accept two schemes for configuring manual MQTT items
#
# Preferred style:
#
# mqtt:
#   - {domain}:
#       name: ""
#       ...
#   - {domain}:
#       name: ""
#       ...
#
# Legacy supported style:
#
# mqtt:
#   {domain}:
#     - name: ""
#       ...
#     - name: ""
#       ...
# Schema applied to the full YAML configuration; only the DOMAIN key is
# validated here, any other top-level key is let through (ALLOW_EXTRA)
CONFIG_SCHEMA = vol.Schema(
    {
        # The validators run in order: cv.ensure_list, then cv.remove_falsy,
        # then every remaining item is validated against CONFIG_SCHEMA_BASE
        DOMAIN: vol.All(
            cv.ensure_list,
            cv.remove_falsy,
            [CONFIG_SCHEMA_BASE],
        )
    },
    extra=vol.ALLOW_EXTRA,
)


# Service call validation schema for the publish service
MQTT_PUBLISH_SCHEMA = vol.All(
    vol.Schema(
        {
            # A literal topic and a topic template share one exclusion group,
            # so a call may provide only one of the two
            vol.Exclusive(ATTR_TOPIC, CONF_TOPIC): valid_publish_topic,
            vol.Exclusive(ATTR_TOPIC_TEMPLATE, CONF_TOPIC): cv.string,
            # The same applies to a literal payload and a payload template
            vol.Exclusive(ATTR_PAYLOAD, CONF_PAYLOAD): cv.string,
            vol.Exclusive(ATTR_PAYLOAD_TEMPLATE, CONF_PAYLOAD): cv.string,
            # QoS and retain fall back to the integration defaults
            vol.Optional(ATTR_QOS, default=DEFAULT_QOS): valid_qos_schema,
            vol.Optional(ATTR_RETAIN, default=DEFAULT_RETAIN): cv.boolean,
        },
        required=True,
    ),
    # At least one of topic / topic_template must be present
    cv.has_at_least_one_key(ATTR_TOPIC, ATTR_TOPIC_TEMPLATE),
)


async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) -> None:
    """Reload the config entry after it has been updated.

    This coroutine is added as update listener of the MQTT config entry
    in async_setup_entry.
    """
    entry_id = entry.entry_id
    config_entries = hass.config_entries
    await config_entries.async_reload(entry_id)


@callback
def _async_remove_mqtt_issues(hass: HomeAssistant, mqtt_data: MqttData) -> None:
    """Delete the open MQTT issues about an invalid platform config.

    Only issues of the MQTT domain with translation key
    "invalid_platform_config" are deleted. The mqtt_data argument is
    not used by the current implementation.
    """
    registry = ir.async_get(hass)
    # Collect the ids first so the issues mapping is not being iterated
    # while the issues are deleted
    stale_issue_ids = []
    for (issue_domain, issue_id), registry_entry in registry.issues.items():
        if issue_domain != DOMAIN:
            continue
        if registry_entry.translation_key != "invalid_platform_config":
            continue
        stale_issue_ids.append(issue_id)

    for stale_issue_id in stale_issue_ids:
        ir.async_delete_issue(hass, DOMAIN, stale_issue_id)


async def async_check_config_schema(
    hass: HomeAssistant, config_yaml: ConfigType
) -> None:
    """Check the manually configured MQTT items against their schema.

    Every item found under the MQTT key of config_yaml is passed to the
    schema that is stored for its platform domain in the reload_schema
    mapping of the shared MQTT data.

    Raises ServiceValidationError, chained to the voluptuous error,
    for the first item that does not validate.
    """
    mqtt_data = hass.data[DATA_MQTT]
    yaml_entries: list[dict[str, list[ConfigType]]] = config_yaml.get(DOMAIN, {})
    for yaml_entry in yaml_entries:
        for platform_domain, platform_configs in yaml_entry.items():
            # Looked up per domain, also when the domain has no items
            validate = mqtt_data.reload_schema[platform_domain]
            for platform_config in platform_configs:
                try:
                    validate(platform_config)
                except vol.Invalid as err:
                    # The integration is only needed for its documentation
                    # link, which is part of the formatted error message
                    mqtt_integration = await async_get_integration(hass, DOMAIN)
                    error_message = conf_util.format_schema_error(
                        hass,
                        err,
                        platform_domain,
                        platform_config,
                        mqtt_integration.documentation,
                    )
                    placeholders = {"domain": platform_domain}
                    raise ServiceValidationError(
                        error_message,
                        translation_domain=DOMAIN,
                        translation_key="invalid_platform_config",
                        translation_placeholders=placeholders,
                    ) from err


async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Load a config entry.

    The steps are performed in this order:
    - set up the MQTT client and the shared MQTT data
    - preload the platforms that are known to be used
    - connect the client
    - register the publish and dump services
    - forward the entry setup to the platforms
    - register the reload admin service if it does not exist yet
    - start discovery if it is enabled in the entry data

    Returns True once the setup has completed.
    """
    conf: dict[str, Any]
    mqtt_data: MqttData

    async def _setup_client(
        client_available: asyncio.Future[bool],
    ) -> tuple[MqttData, dict[str, Any]]:
        """Set up the MQTT client.

        Returns the shared MQTT data and a copy of the config entry data.
        The client_available argument is not used inside this helper;
        the caller passes the same future to the client when connecting.
        """
        # Fetch configuration
        conf = dict(entry.data)
        hass_config = await conf_util.async_hass_config_yaml(hass)
        mqtt_yaml = CONFIG_SCHEMA(hass_config).get(DOMAIN, [])
        await async_create_certificate_temp_files(hass, conf)
        client = MQTT(hass, entry, conf)
        # NOTE(review): membership is tested with DOMAIN while the data is read
        # with DATA_MQTT; presumably both address the same key - confirm
        if DOMAIN in hass.data:
            # The MQTT data already exists, reuse it with the new config and client
            mqtt_data = hass.data[DATA_MQTT]
            mqtt_data.config = mqtt_yaml
            mqtt_data.client = client
        else:
            # Initial setup
            websocket_api.async_register_command(hass, websocket_subscribe)
            websocket_api.async_register_command(hass, websocket_mqtt_info)
            hass.data[DATA_MQTT] = mqtt_data = MqttData(config=mqtt_yaml, client=client)
        await client.async_start(mqtt_data)

        # Restore saved subscriptions
        # (async_unload_entry stores them in subscriptions_to_restore)
        if mqtt_data.subscriptions_to_restore:
            mqtt_data.client.async_restore_tracked_subscriptions(
                mqtt_data.subscriptions_to_restore
            )
            mqtt_data.subscriptions_to_restore = []
        # Keep the remove callback of the update listener so that
        # async_unload_entry can call it
        mqtt_data.reload_dispatchers.append(
            entry.add_update_listener(_async_config_entry_updated)
        )

        return (mqtt_data, conf)

    # Create the future that is handed to the client on connect,
    # or reuse the one that is already stored in hass.data
    client_available: asyncio.Future[bool]
    if DATA_MQTT_AVAILABLE not in hass.data:
        client_available = hass.data[DATA_MQTT_AVAILABLE] = hass.loop.create_future()
    else:
        client_available = hass.data[DATA_MQTT_AVAILABLE]

    mqtt_data, conf = await _setup_client(client_available)
    platforms_used = platforms_from_config(mqtt_data.config)
    # Add the domains of the entities registered for this config entry.
    # The loop variable of the generator shadows the config entry, but
    # `entry.entry_id` still refers to the config entry: the iterable of
    # a generator expression is evaluated in the enclosing scope
    platforms_used.update(
        entry.domain
        for entry in er.async_entries_for_config_entry(
            er.async_get(hass), entry.entry_id
        )
    )
    integration = async_get_loaded_integration(hass, DOMAIN)
    # Preload platforms we know we are going to use so
    # discovery can setup each platform synchronously
    # and avoid creating a flood of tasks at startup
    # while waiting for the imports to complete
    if not integration.platforms_are_loaded(platforms_used):
        with async_pause_setup(hass, SetupPhases.WAIT_IMPORT_PLATFORMS):
            await integration.async_get_platforms(platforms_used)

    # Wait to connect until the platforms are loaded so
    # we can be sure discovery does not have to wait for
    # each platform to load when we get the flood of retained
    # messages on connect
    await mqtt_data.client.async_connect(client_available)

    async def async_publish_service(call: ServiceCall) -> None:
        """Handle MQTT publish service calls.

        Raises ServiceValidationError if a topic template
        renders to an invalid publish topic.
        """
        msg_topic: str | None = call.data.get(ATTR_TOPIC)
        msg_topic_template: str | None = call.data.get(ATTR_TOPIC_TEMPLATE)
        payload: PublishPayloadType = call.data.get(ATTR_PAYLOAD)
        payload_template: str | None = call.data.get(ATTR_PAYLOAD_TEMPLATE)
        # QoS and retain always exist, the service schema sets defaults
        qos: int = call.data[ATTR_QOS]
        retain: bool = call.data[ATTR_RETAIN]
        if msg_topic_template is not None:
            rendered_topic: Any = MqttCommandTemplate(
                template.Template(msg_topic_template),
                hass=hass,
            ).async_render()
            # A rendered topic is validated here, the schema only
            # validated the template as a string
            try:
                msg_topic = valid_publish_topic(rendered_topic)
            except vol.Invalid as err:
                err_str = str(err)
                raise ServiceValidationError(
                    translation_domain=DOMAIN,
                    translation_key="invalid_publish_topic",
                    translation_placeholders={
                        "error": err_str,
                        "topic": str(rendered_topic),
                        "topic_template": str(msg_topic_template),
                    },
                ) from err

        if payload_template is not None:
            payload = MqttCommandTemplate(
                template.Template(payload_template), hass=hass
            ).async_render()

        # The service schema requires a topic or a topic template,
        # the assert is only there to narrow the type for type checkers
        if TYPE_CHECKING:
            assert msg_topic is not None
        await mqtt_data.client.async_publish(msg_topic, payload, qos, retain)

    hass.services.async_register(
        DOMAIN, SERVICE_PUBLISH, async_publish_service, schema=MQTT_PUBLISH_SCHEMA
    )

    async def async_dump_service(call: ServiceCall) -> None:
        """Handle MQTT dump service calls.

        Collects the messages received on the requested topic and writes
        them to mqtt_dump.txt once the requested duration has passed.
        """
        messages: list[tuple[str, str]] = []

        # Store topic and payload, newlines are removed from the
        # payload so every message ends up on a single line
        @callback
        def collect_msg(msg: ReceiveMessage) -> None:
            messages.append((msg.topic, str(msg.payload).replace("\n", "")))

        unsub = async_subscribe_internal(hass, call.data["topic"], collect_msg)

        # Write one "topic,payload" line per collected message; the file
        # path is resolved with hass.config.path and the file is overwritten
        def write_dump() -> None:
            with open(hass.config.path("mqtt_dump.txt"), "w", encoding="utf8") as fp:
                for msg in messages:
                    fp.write(",".join(msg) + "\n")

        async def finish_dump(_: datetime) -> None:
            """Write dump to file."""
            unsub()
            # File I/O is done in the executor
            await hass.async_add_executor_job(write_dump)

        ev.async_call_later(hass, call.data["duration"], finish_dump)

    hass.services.async_register(
        DOMAIN,
        SERVICE_DUMP,
        async_dump_service,
        schema=vol.Schema(
            {
                vol.Required("topic"): valid_subscribe_topic,
                vol.Optional("duration", default=5): int,
            }
        ),
    )

    # Handler of the reload service, it is registered further below
    async def _reload_config(call: ServiceCall) -> None:
        """Reload the platforms.

        Raises ServiceValidationError if the YAML configuration
        cannot be loaded or an item does not validate.
        """
        # Fetch updated manually configured items and validate
        try:
            config_yaml = await async_integration_yaml_config(
                hass, DOMAIN, raise_on_failure=True
            )
        except ConfigValidationError as ex:
            raise ServiceValidationError(
                translation_domain=ex.translation_domain,
                translation_key=ex.translation_key,
                translation_placeholders=ex.translation_placeholders,
            ) from ex

        new_config: list[ConfigType] = config_yaml.get(DOMAIN, [])
        platforms_used = platforms_from_config(new_config)
        # Set up the platforms that are new in the reloaded config first
        new_platforms = platforms_used - mqtt_data.platforms_loaded
        await async_forward_entry_setup_and_setup_discovery(
            hass, entry, new_platforms, late=True
        )
        # Check the schema before continuing reload
        await async_check_config_schema(hass, config_yaml)

        # Remove repair issues
        _async_remove_mqtt_issues(hass, mqtt_data)

        mqtt_data.config = new_config

        # Reload the modern yaml platforms
        # Entities with discovery data are left alone, only entities without
        # it on reloadable platforms with a config entry are removed
        mqtt_platforms = async_get_platforms(hass, DOMAIN)
        tasks = [
            create_eager_task(entity.async_remove())
            for mqtt_platform in mqtt_platforms
            for entity in list(mqtt_platform.entities.values())
            if getattr(entity, "_discovery_data", None) is None
            and mqtt_platform.config_entry
            and mqtt_platform.domain in RELOADABLE_PLATFORMS
        ]
        await asyncio.gather(*tasks)

        # Call the registered reload handlers
        for component in mqtt_data.reload_handlers.values():
            component()

        # Fire event
        hass.bus.async_fire(f"event_{DOMAIN}_reloaded", context=call.context)

    # Setup platforms
    await async_forward_entry_setup_and_setup_discovery(hass, entry, platforms_used)
    # Setup reload service after all platforms have loaded
    if not hass.services.has_service(DOMAIN, SERVICE_RELOAD):
        async_register_admin_service(hass, DOMAIN, SERVICE_RELOAD, _reload_config)
    # Setup discovery
    if conf.get(CONF_DISCOVERY, DEFAULT_DISCOVERY):
        await discovery.async_start(
            hass, conf.get(CONF_DISCOVERY_PREFIX, DEFAULT_PREFIX), entry
        )

    return True


@websocket_api.websocket_command(
    {
        vol.Required("type"): "mqtt/device/debug_info",
        vol.Required("device_id"): str,
    }
)
@callback
def websocket_mqtt_info(
    hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
) -> None:
    """Send the MQTT debug info of the requested device as result."""
    device_debug_info = debug_info.info_for_device(hass, msg["device_id"])
    connection.send_result(msg["id"], device_debug_info)


@websocket_api.websocket_command(
    {
        vol.Required("type"): "mqtt/subscribe",
        vol.Required("topic"): valid_subscribe_topic,
        vol.Optional("qos"): valid_qos_schema,
    }
)
@websocket_api.async_response
async def websocket_subscribe(
    hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
) -> None:
    """Subscribe a websocket connection to a MQTT topic.

    Received messages are forwarded to the connection as events.
    Raises Unauthorized if the user of the connection is not an admin.
    """
    if not connection.user.is_admin:
        raise Unauthorized

    @callback
    def forward_messages(mqttmsg: ReceiveMessage) -> None:
        """Forward a received MQTT message to the websocket."""
        # The subscription is made with encoding=None, so the payload
        # is expected to be bytes here and is decoded in this callback
        try:
            text_payload = cast(bytes, mqttmsg.payload).decode(DEFAULT_ENCODING)
        except (AttributeError, UnicodeDecodeError):
            # Fall back to the string presentation if decoding is not possible
            text_payload = str(mqttmsg.payload)

        message_id = msg["id"]
        event_data = {
            "topic": mqttmsg.topic,
            "payload": text_payload,
            "qos": mqttmsg.qos,
            "retain": mqttmsg.retain,
        }
        connection.send_message(websocket_api.event_message(message_id, event_data))

    requested_qos: int = msg.get("qos", DEFAULT_QOS)
    unsubscribe = async_subscribe_internal(
        hass, msg["topic"], forward_messages, encoding=None, qos=requested_qos
    )
    # Stored on the connection under the id of the subscribe message
    connection.subscriptions[msg["id"]] = unsubscribe

    connection.send_message(websocket_api.result_message(msg["id"]))


# Signature of the listeners accepted by async_subscribe_connection_status:
# called with a single bool (presumably True when connected - confirm)
type ConnectionStatusCallback = Callable[[bool], None]


@callback
def async_subscribe_connection_status(
    hass: HomeAssistant, connection_status_callback: ConnectionStatusCallback
) -> Callable[[], None]:
    """Connect a callback to the MQTT connection state signal.

    Returns the callable that async_dispatcher_connect hands back.
    """
    remove_listener = async_dispatcher_connect(
        hass,
        MQTT_CONNECTION_STATE,
        connection_status_callback,
    )
    return remove_listener


def is_connected(hass: HomeAssistant) -> bool:
    """Return the connected state of the MQTT client."""
    return hass.data[DATA_MQTT].client.connected


async def async_remove_config_entry_device(
    hass: HomeAssistant, config_entry: ConfigEntry, device_entry: DeviceEntry
) -> bool:
    """Remove the MQTT config entry from a device.

    Notifies the device automation module about the removed device
    and always returns True.
    """
    # pylint: disable-next=import-outside-toplevel
    from . import device_automation

    removed_device_id = device_entry.id
    await device_automation.async_removed_from_device(hass, removed_device_id)
    return True


async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Unload the MQTT config entry.

    Removes the publish and dump services, stops discovery, unloads the
    platforms, runs the stored cleanup callbacks and disconnects the client.
    Subscriptions that remain on the client are stored so they can be
    restored when the entry is set up again. Always returns True.
    """
    mqtt_data = hass.data[DATA_MQTT]
    mqtt_client = mqtt_data.client

    # The publish and dump services are registered again on the next setup
    for service_name in (SERVICE_PUBLISH, SERVICE_DUMP):
        hass.services.async_remove(DOMAIN, service_name)

    # Discovery is stopped before the platforms are unloaded
    await discovery.async_stop(hass)
    await hass.config_entries.async_unload_platforms(entry, mqtt_data.platforms_loaded)
    mqtt_data.platforms_loaded = set()
    # Yield to the event loop once before the cleanup continues
    await asyncio.sleep(0)

    # Call and drop the reload dispatcher callbacks, last added first
    while mqtt_data.reload_dispatchers:
        remove_dispatcher = mqtt_data.reload_dispatchers.pop()
        remove_dispatcher()

    # Cleanup the listeners of the client
    mqtt_client.cleanup()

    # Call and drop the entity registry hooks
    registry_hooks = mqtt_data.discovery_registry_hooks
    while registry_hooks:
        _, remove_hook = registry_hooks.popitem()
        remove_hook()

    # Wait for all ACKs and stop the loop
    await mqtt_client.async_disconnect()

    # Drop the MQTT client availability future
    hass.data.pop(DATA_MQTT_AVAILABLE, None)

    # Keep the remaining subscriptions so the next setup can restore them
    remaining_subscriptions = mqtt_client.subscriptions
    if remaining_subscriptions:
        mqtt_data.subscriptions_to_restore = remaining_subscriptions

    # Open repair issues do not apply to an unloaded entry
    _async_remove_mqtt_issues(hass, mqtt_data)

    return True