diff --git a/homeassistant/components/bthome/device_trigger.py b/homeassistant/components/bthome/device_trigger.py index c50ffc05900..4eca110e581 100644 --- a/homeassistant/components/bthome/device_trigger.py +++ b/homeassistant/components/bthome/device_trigger.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any import voluptuous as vol @@ -34,7 +34,7 @@ from .const import ( EVENT_TYPE, ) -TRIGGERS_BY_EVENT_CLASS = { +EVENT_TYPES_BY_EVENT_CLASS = { EVENT_CLASS_BUTTON: { "press", "double_press", @@ -51,6 +51,38 @@ TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend( ) +def get_event_classes_by_device_id(hass: HomeAssistant, device_id: str) -> list[str]: + """Get the supported event classes for a device. + + Events for BTHome BLE devices are dynamically discovered + and stored in the device config entry when they are first seen. + """ + device_registry = dr.async_get(hass) + device = device_registry.async_get(device_id) + if TYPE_CHECKING: + assert device is not None + + config_entries = [ + hass.config_entries.async_get_entry(entry_id) + for entry_id in device.config_entries + ] + bthome_config_entry = next( + entry for entry in config_entries if entry and entry.domain == DOMAIN + ) + return bthome_config_entry.data.get(CONF_DISCOVERED_EVENT_CLASSES, []) + + +def get_event_types_by_event_class(event_class: str) -> set[str]: + """Get the supported event types for an event class. + + If the device has multiple buttons they will have + event classes like button_1 button_2, button_3, etc + but if there is only one button then it will be + button without a number postfix. + """ + return EVENT_TYPES_BY_EVENT_CLASS.get(event_class.split("_")[0], set()) + + async def async_validate_trigger_config( hass: HomeAssistant, config: ConfigType ) -> ConfigType: @@ -58,31 +90,17 @@ async def async_validate_trigger_config( config = TRIGGER_SCHEMA(config) event_class = config[CONF_TYPE] event_type = config[CONF_SUBTYPE] - - device_registry = dr.async_get(hass) - device = device_registry.async_get(config[CONF_DEVICE_ID]) - assert device is not None - config_entries = [ - hass.config_entries.async_get_entry(entry_id) - for entry_id in device.config_entries - ] - bthome_config_entry = next( - iter(entry for entry in config_entries if entry and entry.domain == DOMAIN) - ) - event_classes: list[str] = bthome_config_entry.data.get( - CONF_DISCOVERED_EVENT_CLASSES, [] - ) + device_id = config[CONF_DEVICE_ID] + event_classes = get_event_classes_by_device_id(hass, device_id) if event_class not in event_classes: raise InvalidDeviceAutomationConfig( - f"BTHome trigger {event_class} is not valid for device " - f"{device} ({config[CONF_DEVICE_ID]})" + f"BTHome trigger {event_class} is not valid for device_id '{device_id}'" ) - if event_type not in TRIGGERS_BY_EVENT_CLASS.get(event_class.split("_")[0], ()): + if event_type not in get_event_types_by_event_class(event_class): raise InvalidDeviceAutomationConfig( - f"BTHome trigger {event_type} is not valid for device " - f"{device} ({config[CONF_DEVICE_ID]})" + f"BTHome trigger {event_type} is not valid for device_id '{device_id}'" ) return config @@ -92,21 +110,7 @@ async def async_get_triggers( hass: HomeAssistant, device_id: str ) -> list[dict[str, Any]]: """Return a list of triggers for BTHome BLE devices.""" - device_registry = dr.async_get(hass) - device = device_registry.async_get(device_id) - assert device is not None - config_entries = [ - hass.config_entries.async_get_entry(entry_id) - for entry_id in device.config_entries - ] - bthome_config_entry = next( - iter(entry for entry in config_entries if entry and entry.domain == DOMAIN), - None, - ) - assert bthome_config_entry is not None - event_classes: list[str] = bthome_config_entry.data.get( - CONF_DISCOVERED_EVENT_CLASSES, [] - ) + event_classes = get_event_classes_by_device_id(hass, device_id) return [ { # Required fields of TRIGGER_BASE_SCHEMA @@ -118,14 +122,7 @@ async def async_get_triggers( CONF_SUBTYPE: event_type, } for event_class in event_classes - for event_type in TRIGGERS_BY_EVENT_CLASS.get( - event_class.split("_")[0], - # If the device has multiple buttons they will have - # event classes like button_1 button_2, button_3, etc - # but if there is only one button then it will be - # button without a number postfix. - (), - ) + for event_type in get_event_types_by_event_class(event_class) ]