diff --git a/homeassistant/components/zwave_js/__init__.py b/homeassistant/components/zwave_js/__init__.py index 57b0e2edc6f..b847b76ca17 100644 --- a/homeassistant/components/zwave_js/__init__.py +++ b/homeassistant/components/zwave_js/__init__.py @@ -215,6 +215,9 @@ async def start_client( LOGGER.info("Connection to Zwave JS Server initialized") assert client.driver + async_dispatcher_send( + hass, f"{DOMAIN}_{client.driver.controller.home_id}_connected_to_server" + ) await driver_events.setup(client.driver) diff --git a/homeassistant/components/zwave_js/triggers/event.py b/homeassistant/components/zwave_js/triggers/event.py index 12c9d267ca6..32bd3130e03 100644 --- a/homeassistant/components/zwave_js/triggers/event.py +++ b/homeassistant/components/zwave_js/triggers/event.py @@ -1,18 +1,20 @@ """Offer Z-Wave JS event listening automation trigger.""" from __future__ import annotations +from collections.abc import Callable import functools from pydantic import ValidationError import voluptuous as vol from zwave_js_server.client import Client from zwave_js_server.model.controller import CONTROLLER_EVENT_MODEL_MAP -from zwave_js_server.model.driver import DRIVER_EVENT_MODEL_MAP +from zwave_js_server.model.driver import DRIVER_EVENT_MODEL_MAP, Driver from zwave_js_server.model.node import NODE_EVENT_MODEL_MAP from homeassistant.const import ATTR_DEVICE_ID, ATTR_ENTITY_ID, CONF_PLATFORM from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback from homeassistant.helpers import config_validation as cv, device_registry as dr +from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo from homeassistant.helpers.typing import ConfigType @@ -150,7 +152,7 @@ async def async_attach_trigger( event_name = config[ATTR_EVENT] event_data_filter = config.get(ATTR_EVENT_DATA, {}) - unsubs = [] + unsubs: list[Callable] = [] job = HassJob(action) trigger_data = trigger_info["trigger_data"] @@ -199,26 +201,6 @@ async def async_attach_trigger( hass.async_run_hass_job(job, {"trigger": payload}) - if not nodes: - entry_id = config[ATTR_CONFIG_ENTRY_ID] - client: Client = hass.data[DOMAIN][entry_id][DATA_CLIENT] - assert client.driver - if event_source == "controller": - unsubs.append(client.driver.controller.on(event_name, async_on_event)) - else: - unsubs.append(client.driver.on(event_name, async_on_event)) - - for node in nodes: - driver = node.client.driver - assert driver is not None # The node comes from the driver. - device_identifier = get_device_id(driver, node) - device = dev_reg.async_get_device({device_identifier}) - assert device - # We need to store the device for the callback - unsubs.append( - node.on(event_name, functools.partial(async_on_event, device=device)) - ) - @callback def async_remove() -> None: """Remove state listeners async.""" @@ -226,4 +208,45 @@ async def async_attach_trigger( unsub() unsubs.clear() + @callback + def _create_zwave_listeners() -> None: + """Create Z-Wave JS listeners.""" + async_remove() + # Nodes list can come from different drivers and we will need to listen to + # server connections for all of them. + drivers: set[Driver] = set() + if not nodes: + entry_id = config[ATTR_CONFIG_ENTRY_ID] + client: Client = hass.data[DOMAIN][entry_id][DATA_CLIENT] + driver = client.driver + assert driver + drivers.add(driver) + if event_source == "controller": + unsubs.append(driver.controller.on(event_name, async_on_event)) + else: + unsubs.append(driver.on(event_name, async_on_event)) + + for node in nodes: + driver = node.client.driver + assert driver is not None # The node comes from the driver. + drivers.add(driver) + device_identifier = get_device_id(driver, node) + device = dev_reg.async_get_device({device_identifier}) + assert device + # We need to store the device for the callback + unsubs.append( + node.on(event_name, functools.partial(async_on_event, device=device)) + ) + + for driver in drivers: + unsubs.append( + async_dispatcher_connect( + hass, + f"{DOMAIN}_{driver.controller.home_id}_connected_to_server", + _create_zwave_listeners, + ) + ) + + _create_zwave_listeners() + return async_remove diff --git a/homeassistant/components/zwave_js/triggers/value_updated.py b/homeassistant/components/zwave_js/triggers/value_updated.py index 655d1f9070e..4e21774c98f 100644 --- a/homeassistant/components/zwave_js/triggers/value_updated.py +++ b/homeassistant/components/zwave_js/triggers/value_updated.py @@ -1,15 +1,18 @@ """Offer Z-Wave JS value updated listening automation trigger.""" from __future__ import annotations +from collections.abc import Callable import functools import voluptuous as vol from zwave_js_server.const import CommandClass +from zwave_js_server.model.driver import Driver from zwave_js_server.model.value import Value, get_value_id_str from homeassistant.const import ATTR_DEVICE_ID, ATTR_ENTITY_ID, CONF_PLATFORM, MATCH_ALL from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback from homeassistant.helpers import config_validation as cv, device_registry as dr +from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo from homeassistant.helpers.typing import ConfigType @@ -99,7 +102,7 @@ async def async_attach_trigger( property_ = config[ATTR_PROPERTY] endpoint = config.get(ATTR_ENDPOINT) property_key = config.get(ATTR_PROPERTY_KEY) - unsubs = [] + unsubs: list[Callable] = [] job = HassJob(action) trigger_data = trigger_info["trigger_data"] @@ -153,29 +156,11 @@ async def async_attach_trigger( ATTR_PREVIOUS_VALUE_RAW: prev_value_raw, ATTR_CURRENT_VALUE: curr_value, ATTR_CURRENT_VALUE_RAW: curr_value_raw, - "description": f"Z-Wave value {value_id} updated on {device_name}", + "description": f"Z-Wave value {value.value_id} updated on {device_name}", } hass.async_run_hass_job(job, {"trigger": payload}) - for node in nodes: - driver = node.client.driver - assert driver is not None # The node comes from the driver. - device_identifier = get_device_id(driver, node) - device = dev_reg.async_get_device({device_identifier}) - assert device - value_id = get_value_id_str( - node, command_class, property_, endpoint, property_key - ) - value = node.values[value_id] - # We need to store the current value and device for the callback - unsubs.append( - node.on( - "value updated", - functools.partial(async_on_value_updated, value, device), - ) - ) - @callback def async_remove() -> None: """Remove state listeners async.""" @@ -183,4 +168,40 @@ async def async_attach_trigger( unsub() unsubs.clear() + def _create_zwave_listeners() -> None: + """Create Z-Wave JS listeners.""" + async_remove() + # Nodes list can come from different drivers and we will need to listen to + # server connections for all of them. + drivers: set[Driver] = set() + for node in nodes: + driver = node.client.driver + assert driver is not None # The node comes from the driver. + drivers.add(driver) + device_identifier = get_device_id(driver, node) + device = dev_reg.async_get_device({device_identifier}) + assert device + value_id = get_value_id_str( + node, command_class, property_, endpoint, property_key + ) + value = node.values[value_id] + # We need to store the current value and device for the callback + unsubs.append( + node.on( + "value updated", + functools.partial(async_on_value_updated, value, device), + ) + ) + + for driver in drivers: + unsubs.append( + async_dispatcher_connect( + hass, + f"{DOMAIN}_{driver.controller.home_id}_connected_to_server", + _create_zwave_listeners, + ) + ) + + _create_zwave_listeners() + return async_remove diff --git a/tests/components/zwave_js/test_trigger.py b/tests/components/zwave_js/test_trigger.py index 9df8aa75f43..0fb3b829d9a 100644 --- a/tests/components/zwave_js/test_trigger.py +++ b/tests/components/zwave_js/test_trigger.py @@ -1109,3 +1109,101 @@ def test_get_trigger_platform_failure() -> None: """Test _get_trigger_platform.""" with pytest.raises(ValueError): _get_trigger_platform({CONF_PLATFORM: "zwave_js.invalid"}) + + +async def test_server_reconnect_event( + hass: HomeAssistant, client, lock_schlage_be469, integration +) -> None: + """Test that when we reconnect to server, event triggers reattach.""" + trigger_type = f"{DOMAIN}.event" + node: Node = lock_schlage_be469 + dev_reg = async_get_dev_reg(hass) + device = dev_reg.async_get_device( + {get_device_id(client.driver, lock_schlage_be469)} + ) + assert device + + event_name = "interview stage completed" + + original_len = len(node._listeners.get(event_name, [])) + + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: [ + { + "trigger": { + "platform": trigger_type, + "entity_id": SCHLAGE_BE469_LOCK_ENTITY, + "event_source": "node", + "event": event_name, + }, + "action": { + "event": "blah", + }, + }, + ] + }, + ) + + assert len(node._listeners.get(event_name, [])) == original_len + 1 + old_listener = node._listeners.get(event_name, [])[original_len] + + await hass.config_entries.async_reload(integration.entry_id) + await hass.async_block_till_done() + + # Make sure there is still a listener added for the trigger + assert len(node._listeners.get(event_name, [])) == original_len + 1 + + # Make sure the old listener was removed + assert old_listener not in node._listeners.get(event_name, []) + + +async def test_server_reconnect_value_updated( + hass: HomeAssistant, client, lock_schlage_be469, integration +) -> None: + """Test that when we reconnect to server, value_updated triggers reattach.""" + trigger_type = f"{DOMAIN}.value_updated" + node: Node = lock_schlage_be469 + dev_reg = async_get_dev_reg(hass) + device = dev_reg.async_get_device( + {get_device_id(client.driver, lock_schlage_be469)} + ) + assert device + + event_name = "value updated" + + original_len = len(node._listeners.get(event_name, [])) + + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: [ + { + "trigger": { + "platform": trigger_type, + "entity_id": SCHLAGE_BE469_LOCK_ENTITY, + "command_class": CommandClass.DOOR_LOCK.value, + "property": "latchStatus", + }, + "action": { + "event": "no_value_filter", + }, + }, + ] + }, + ) + + assert len(node._listeners.get(event_name, [])) == original_len + 1 + old_listener = node._listeners.get(event_name, [])[original_len] + + await hass.config_entries.async_reload(integration.entry_id) + await hass.async_block_till_done() + + # Make sure there is still a listener added for the trigger + assert len(node._listeners.get(event_name, [])) == original_len + 1 + + # Make sure the old listener was removed + assert old_listener not in node._listeners.get(event_name, [])