diff --git a/homeassistant/components/homekit_controller/device_trigger.py b/homeassistant/components/homekit_controller/device_trigger.py index ecffb902928..bc1434f4bd9 100644 --- a/homeassistant/components/homekit_controller/device_trigger.py +++ b/homeassistant/components/homekit_controller/device_trigger.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any from aiohomekit.model.characteristics import CharacteristicsTypes from aiohomekit.model.characteristics.const import InputEventValues -from aiohomekit.model.services import ServicesTypes +from aiohomekit.model.services import Service, ServicesTypes from aiohomekit.utils import clamp_enum_to_char import voluptuous as vol @@ -57,28 +57,41 @@ HK_TO_HA_INPUT_EVENT_VALUES = { class TriggerSource: """Represents a stateless source of event data from HomeKit.""" - def __init__( + def __init__(self, hass: HomeAssistant) -> None: + """Initialize a set of triggers for a device.""" + self._hass = hass + self._triggers: dict[tuple[str, str], dict[str, Any]] = {} + self._callbacks: dict[tuple[str, str], list[Callable[[Any], None]]] = {} + self._iid_trigger_keys: dict[int, set[tuple[str, str]]] = {} + + async def async_setup( self, connection: HKDevice, aid: int, triggers: list[dict[str, Any]] ) -> None: - """Initialize a set of triggers for a device.""" - self._hass = connection.hass - self._connection = connection - self._aid = aid - self._triggers: dict[tuple[str, str], dict[str, Any]] = {} - for trigger in triggers: - self._triggers[(trigger["type"], trigger["subtype"])] = trigger - self._callbacks: dict[int, list[Callable[[Any], None]]] = {} + """Set up a set of triggers for a device. - def fire(self, iid, value): + This function must be re-entrant since + it is called when the device is first added and + when the config entry is reloaded. + """ + for trigger_data in triggers: + trigger_key = (trigger_data[CONF_TYPE], trigger_data[CONF_SUBTYPE]) + self._triggers[trigger_key] = trigger_data + iid = trigger_data["characteristic"] + self._iid_trigger_keys.setdefault(iid, set()).add(trigger_key) + await connection.add_watchable_characteristics([(aid, iid)]) + + def fire(self, iid: int, value: dict[str, Any]) -> None: """Process events that have been received from a HomeKit accessory.""" - for event_handler in self._callbacks.get(iid, []): - event_handler(value) + for trigger_key in self._iid_trigger_keys.get(iid, set()): + for event_handler in self._callbacks.get(trigger_key, []): + event_handler(value) def async_get_triggers(self) -> Generator[tuple[str, str], None, None]: - """List device triggers for homekit devices.""" + """List device triggers for HomeKit devices.""" yield from self._triggers - async def async_attach_trigger( + @callback + def async_attach_trigger( self, config: ConfigType, action: TriggerActionType, @@ -86,28 +99,25 @@ class TriggerSource: ) -> CALLBACK_TYPE: """Attach a trigger.""" trigger_data = trigger_info["trigger_data"] + trigger_key = (config[CONF_TYPE], config[CONF_SUBTYPE]) job = HassJob(action) @callback - def event_handler(char): + def event_handler(char: dict[str, Any]) -> None: if config[CONF_SUBTYPE] != HK_TO_HA_INPUT_EVENT_VALUES[char["value"]]: return self._hass.async_run_hass_job(job, {"trigger": {**trigger_data, **config}}) - trigger = self._triggers[config[CONF_TYPE], config[CONF_SUBTYPE]] - iid = trigger["characteristic"] - - await self._connection.add_watchable_characteristics([(self._aid, iid)]) - self._callbacks.setdefault(iid, []).append(event_handler) + self._callbacks.setdefault(trigger_key, []).append(event_handler) def async_remove_handler(): - if iid in self._callbacks: - self._callbacks[iid].remove(event_handler) + if trigger_key in self._callbacks: + self._callbacks[trigger_key].remove(event_handler) return async_remove_handler -def enumerate_stateless_switch(service): +def enumerate_stateless_switch(service: Service) -> list[dict[str, Any]]: """Enumerate a stateless switch, like a single button.""" # A stateless switch that has a SERVICE_LABEL_INDEX is part of a group @@ -135,7 +145,7 @@ def enumerate_stateless_switch(service): ] -def enumerate_stateless_switch_group(service): +def enumerate_stateless_switch_group(service: Service) -> list[dict[str, Any]]: """Enumerate a group of stateless switches, like a remote control.""" switches = list( service.accessory.services.filter( @@ -165,7 +175,7 @@ def enumerate_stateless_switch_group(service): return results -def enumerate_doorbell(service): +def enumerate_doorbell(service: Service) -> list[dict[str, Any]]: """Enumerate doorbell buttons.""" input_event = service[CharacteristicsTypes.INPUT_EVENT] @@ -217,21 +227,32 @@ async def async_setup_triggers_for_entry( if device_id in hass.data[TRIGGERS]: return False - # Just because we recognise the service type doesn't mean we can actually + # Just because we recognize the service type doesn't mean we can actually # extract any triggers - so only proceed if we can triggers = TRIGGER_FINDERS[service_type](service) if len(triggers) == 0: return False - trigger = TriggerSource(conn, aid, triggers) - hass.data[TRIGGERS][device_id] = trigger + trigger = async_get_or_create_trigger_source(conn.hass, device_id) + hass.async_create_task(trigger.async_setup(conn, aid, triggers)) return True conn.add_listener(async_add_service) -def async_fire_triggers(conn: HKDevice, events: dict[tuple[int, int], Any]): +@callback +def async_get_or_create_trigger_source( + hass: HomeAssistant, device_id: str +) -> TriggerSource: + """Get or create a trigger source for a device id.""" + if not (source := hass.data[TRIGGERS].get(device_id)): + source = TriggerSource(hass) + hass.data[TRIGGERS][device_id] = source + return source + + +def async_fire_triggers(conn: HKDevice, events: dict[tuple[int, int], dict[str, Any]]): """Process events generated by a HomeKit accessory into automation triggers.""" trigger_sources: dict[str, TriggerSource] = conn.hass.data[TRIGGERS] for (aid, iid), ev in events.items(): @@ -271,5 +292,6 @@ async def async_attach_trigger( ) -> CALLBACK_TYPE: """Attach a trigger.""" device_id = config[CONF_DEVICE_ID] - device = hass.data[TRIGGERS][device_id] - return await device.async_attach_trigger(config, action, trigger_info) + return async_get_or_create_trigger_source(hass, device_id).async_attach_trigger( + config, action, trigger_info + ) diff --git a/tests/components/homekit_controller/test_device_trigger.py b/tests/components/homekit_controller/test_device_trigger.py index a09525d9dec..6f17f5db786 100644 --- a/tests/components/homekit_controller/test_device_trigger.py +++ b/tests/components/homekit_controller/test_device_trigger.py @@ -6,6 +6,7 @@ import pytest import homeassistant.components.automation as automation from homeassistant.components.device_automation import DeviceAutomationType from homeassistant.components.homekit_controller.const import DOMAIN +from homeassistant.config_entries import ConfigEntryState from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.setup import async_setup_component @@ -338,3 +339,129 @@ async def test_handle_events(hass, utcnow, calls): await hass.async_block_till_done() assert len(calls) == 2 + + +async def test_handle_events_late_setup(hass, utcnow, calls): + """Test that events are handled when setup happens after startup.""" + helper = await setup_test_component(hass, create_remote) + + entity_registry = er.async_get(hass) + entry = entity_registry.async_get("sensor.testdevice_battery") + + device_registry = dr.async_get(hass) + device = device_registry.async_get(entry.device_id) + + await hass.config_entries.async_unload(helper.config_entry.entry_id) + await hass.async_block_till_done() + assert helper.config_entry.state == ConfigEntryState.NOT_LOADED + + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: [ + { + "alias": "single_press", + "trigger": { + "platform": "device", + "domain": DOMAIN, + "device_id": device.id, + "type": "button1", + "subtype": "single_press", + }, + "action": { + "service": "test.automation", + "data_template": { + "some": ( + "{{ trigger.platform}} - " + "{{ trigger.type }} - {{ trigger.subtype }} - " + "{{ trigger.id }}" + ) + }, + }, + }, + { + "alias": "long_press", + "trigger": { + "platform": "device", + "domain": DOMAIN, + "device_id": device.id, + "type": "button2", + "subtype": "long_press", + }, + "action": { + "service": "test.automation", + "data_template": { + "some": ( + "{{ trigger.platform}} - " + "{{ trigger.type }} - {{ trigger.subtype }} - " + "{{ trigger.id }}" + ) + }, + }, + }, + ] + }, + ) + await hass.async_block_till_done() + + await hass.config_entries.async_setup(helper.config_entry.entry_id) + await hass.async_block_till_done() + assert helper.config_entry.state == ConfigEntryState.LOADED + + # Make sure first automation (only) fires for single press + helper.pairing.testing.update_named_service( + "Button 1", {CharacteristicsTypes.INPUT_EVENT: 0} + ) + + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[0].data["some"] == "device - button1 - single_press - 0" + + # Make sure automation doesn't trigger for long press + helper.pairing.testing.update_named_service( + "Button 1", {CharacteristicsTypes.INPUT_EVENT: 1} + ) + + await hass.async_block_till_done() + assert len(calls) == 1 + + # Make sure automation doesn't trigger for double press + helper.pairing.testing.update_named_service( + "Button 1", {CharacteristicsTypes.INPUT_EVENT: 2} + ) + + await hass.async_block_till_done() + assert len(calls) == 1 + + # Make sure second automation fires for long press + helper.pairing.testing.update_named_service( + "Button 2", {CharacteristicsTypes.INPUT_EVENT: 2} + ) + + await hass.async_block_till_done() + assert len(calls) == 2 + assert calls[1].data["some"] == "device - button2 - long_press - 0" + + # Turn the automations off + await hass.services.async_call( + "automation", + "turn_off", + {"entity_id": "automation.long_press"}, + blocking=True, + ) + + await hass.services.async_call( + "automation", + "turn_off", + {"entity_id": "automation.single_press"}, + blocking=True, + ) + + # Make sure event no longer fires + helper.pairing.testing.update_named_service( + "Button 2", {CharacteristicsTypes.INPUT_EVENT: 2} + ) + + await hass.async_block_till_done() + assert len(calls) == 2