"""Offer Z-Wave JS event listening automation trigger."""
from __future__ import annotations

import functools

from pydantic import ValidationError
import voluptuous as vol
from zwave_js_server.client import Client
from zwave_js_server.model.controller import CONTROLLER_EVENT_MODEL_MAP
from zwave_js_server.model.driver import DRIVER_EVENT_MODEL_MAP
from zwave_js_server.model.node import NODE_EVENT_MODEL_MAP

from homeassistant.components.automation import (
    AutomationActionType,
    AutomationTriggerInfo,
)
from homeassistant.components.zwave_js.const import (
    ATTR_CONFIG_ENTRY_ID,
    ATTR_EVENT,
    ATTR_EVENT_DATA,
    ATTR_EVENT_SOURCE,
    ATTR_NODE_ID,
    ATTR_PARTIAL_DICT_MATCH,
    DATA_CLIENT,
    DOMAIN,
)
from homeassistant.components.zwave_js.helpers import (
    async_get_nodes_from_targets,
    get_device_id,
    get_home_and_node_id_from_device_entry,
)
from homeassistant.const import ATTR_DEVICE_ID, ATTR_ENTITY_ID, CONF_PLATFORM
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, device_registry as dr
from homeassistant.helpers.typing import ConfigType

from .helpers import async_bypass_dynamic_config_validation

# Platform type should be <DOMAIN>.<SUBMODULE_NAME>
PLATFORM_TYPE = f"{DOMAIN}.{__name__.rsplit('.', maxsplit=1)[-1]}"


def validate_non_node_event_source(obj: dict) -> dict:
    """Validate that a trigger for a non node event source has a config entry."""
    if obj[ATTR_EVENT_SOURCE] != "node" and ATTR_CONFIG_ENTRY_ID in obj:
        return obj
    raise vol.Invalid(f"Non node event triggers must contain {ATTR_CONFIG_ENTRY_ID}.")


def validate_event_name(obj: dict) -> dict:
    """Validate that a trigger has a valid event name."""
    event_source = obj[ATTR_EVENT_SOURCE]
    event_name = obj[ATTR_EVENT]
    # the keys to the event source's model map are the event names
    if event_source == "controller":
        vol.In(CONTROLLER_EVENT_MODEL_MAP)(event_name)
    elif event_source == "driver":
        vol.In(DRIVER_EVENT_MODEL_MAP)(event_name)
    else:
        vol.In(NODE_EVENT_MODEL_MAP)(event_name)
    return obj


def validate_event_data(obj: dict) -> dict:
    """Validate that a trigger has a valid event data."""
    # Return if there's no event data to validate
    if ATTR_EVENT_DATA not in obj:
        return obj

    event_source: str = obj[ATTR_EVENT_SOURCE]
    event_name: str = obj[ATTR_EVENT]
    event_data: dict = obj[ATTR_EVENT_DATA]
    try:
        if event_source == "controller":
            CONTROLLER_EVENT_MODEL_MAP[event_name](**event_data)
        elif event_source == "driver":
            DRIVER_EVENT_MODEL_MAP[event_name](**event_data)
        else:
            NODE_EVENT_MODEL_MAP[event_name](**event_data)
    except ValidationError as exc:
        # Filter out required field errors if keys can be missing, and if there are
        # still errors, raise an exception
        if errors := [
            error for error in exc.errors() if error["type"] != "value_error.missing"
        ]:
            raise vol.MultipleInvalid(errors) from exc
    return obj


TRIGGER_SCHEMA = vol.All(
    cv.TRIGGER_BASE_SCHEMA.extend(
        {
            vol.Required(CONF_PLATFORM): PLATFORM_TYPE,
            vol.Optional(ATTR_CONFIG_ENTRY_ID): str,
            vol.Optional(ATTR_DEVICE_ID): vol.All(cv.ensure_list, [cv.string]),
            vol.Optional(ATTR_ENTITY_ID): cv.entity_ids,
            vol.Required(ATTR_EVENT_SOURCE): vol.In(["controller", "driver", "node"]),
            vol.Required(ATTR_EVENT): cv.string,
            vol.Optional(ATTR_EVENT_DATA): dict,
            vol.Optional(ATTR_PARTIAL_DICT_MATCH, default=False): bool,
        },
    ),
    validate_event_name,
    validate_event_data,
    vol.Any(
        validate_non_node_event_source,
        cv.has_at_least_one_key(ATTR_DEVICE_ID, ATTR_ENTITY_ID),
    ),
)


async def async_validate_trigger_config(
    hass: HomeAssistant, config: ConfigType
) -> ConfigType:
    """Validate config."""
    config = TRIGGER_SCHEMA(config)

    if ATTR_CONFIG_ENTRY_ID in config:
        entry_id = config[ATTR_CONFIG_ENTRY_ID]
        if hass.config_entries.async_get_entry(entry_id) is None:
            raise vol.Invalid(f"Config entry '{entry_id}' not found")

    if async_bypass_dynamic_config_validation(hass, config):
        return config

    if config[ATTR_EVENT_SOURCE] == "node" and not async_get_nodes_from_targets(
        hass, config
    ):
        raise vol.Invalid(
            f"No nodes found for given {ATTR_DEVICE_ID}s or {ATTR_ENTITY_ID}s."
        )

    return config


async def async_attach_trigger(
    hass: HomeAssistant,
    config: ConfigType,
    action: AutomationActionType,
    automation_info: AutomationTriggerInfo,
    *,
    platform_type: str = PLATFORM_TYPE,
) -> CALLBACK_TYPE:
    """Listen for state changes based on configuration."""
    dev_reg = dr.async_get(hass)
    nodes = async_get_nodes_from_targets(hass, config, dev_reg=dev_reg)
    if config[ATTR_EVENT_SOURCE] == "node" and not nodes:
        raise ValueError(
            f"No nodes found for given {ATTR_DEVICE_ID}s or {ATTR_ENTITY_ID}s."
        )

    event_source = config[ATTR_EVENT_SOURCE]
    event_name = config[ATTR_EVENT]
    event_data_filter = config.get(ATTR_EVENT_DATA, {})

    unsubs = []
    job = HassJob(action)

    trigger_data = automation_info["trigger_data"]

    @callback
    def async_on_event(event_data: dict, device: dr.DeviceEntry | None = None) -> None:
        """Handle event."""
        for key, val in event_data_filter.items():
            if key not in event_data:
                return
            if (
                config[ATTR_PARTIAL_DICT_MATCH]
                and isinstance(event_data[key], dict)
                and isinstance(event_data_filter[key], dict)
            ):
                for key2, val2 in event_data_filter[key].items():
                    if key2 not in event_data[key] or event_data[key][key2] != val2:
                        return
                continue
            if event_data[key] != val:
                return

        payload = {
            **trigger_data,
            CONF_PLATFORM: platform_type,
            ATTR_EVENT_SOURCE: event_source,
            ATTR_EVENT: event_name,
            ATTR_EVENT_DATA: event_data,
        }

        primary_desc = f"Z-Wave JS '{event_source}' event '{event_name}' was emitted"

        if device:
            device_name = device.name_by_user or device.name
            payload[ATTR_DEVICE_ID] = device.id
            home_and_node_id = get_home_and_node_id_from_device_entry(device)
            assert home_and_node_id
            payload[ATTR_NODE_ID] = home_and_node_id[1]
            payload["description"] = f"{primary_desc} on {device_name}"
        else:
            payload["description"] = primary_desc

        payload[
            "description"
        ] = f"{payload['description']} with event data: {event_data}"

        hass.async_run_hass_job(job, {"trigger": payload})

    if not nodes:
        entry_id = config[ATTR_CONFIG_ENTRY_ID]
        client: Client = hass.data[DOMAIN][entry_id][DATA_CLIENT]
        assert client.driver
        if event_source == "controller":
            unsubs.append(client.driver.controller.on(event_name, async_on_event))
        else:
            unsubs.append(client.driver.on(event_name, async_on_event))

    for node in nodes:
        driver = node.client.driver
        assert driver is not None  # The node comes from the driver.
        device_identifier = get_device_id(driver, node)
        device = dev_reg.async_get_device({device_identifier})
        assert device
        # We need to store the device for the callback
        unsubs.append(
            node.on(event_name, functools.partial(async_on_event, device=device))
        )

    @callback
    def async_remove() -> None:
        """Remove state listeners async."""
        for unsub in unsubs:
            unsub()
        unsubs.clear()

    return async_remove