From 934e1a160387606731ff20a6e40c4efd71293dbf Mon Sep 17 00:00:00 2001 From: Raman Gupta <7243222+raman325@users.noreply.github.com> Date: Fri, 16 Jun 2023 03:35:29 -0400 Subject: [PATCH] Fix zwave_js trigger event reattach logic (#94702) --- .../components/zwave_js/triggers/event.py | 7 +- .../zwave_js/triggers/value_updated.py | 4 +- tests/components/zwave_js/test_trigger.py | 118 +++++++++++++----- 3 files changed, 96 insertions(+), 33 deletions(-) diff --git a/homeassistant/components/zwave_js/triggers/event.py b/homeassistant/components/zwave_js/triggers/event.py index 32bd3130e03..33cb59d8505 100644 --- a/homeassistant/components/zwave_js/triggers/event.py +++ b/homeassistant/components/zwave_js/triggers/event.py @@ -142,8 +142,9 @@ async def async_attach_trigger( ) -> CALLBACK_TYPE: """Listen for state changes based on configuration.""" dev_reg = dr.async_get(hass) - nodes = async_get_nodes_from_targets(hass, config, dev_reg=dev_reg) - if config[ATTR_EVENT_SOURCE] == "node" and not nodes: + if config[ATTR_EVENT_SOURCE] == "node" and not async_get_nodes_from_targets( + hass, config, dev_reg=dev_reg + ): raise ValueError( f"No nodes found for given {ATTR_DEVICE_ID}s or {ATTR_ENTITY_ID}s." ) @@ -215,7 +216,7 @@ async def async_attach_trigger( # Nodes list can come from different drivers and we will need to listen to # server connections for all of them. drivers: set[Driver] = set() - if not nodes: + if not (nodes := async_get_nodes_from_targets(hass, config, dev_reg=dev_reg)): entry_id = config[ATTR_CONFIG_ENTRY_ID] client: Client = hass.data[DOMAIN][entry_id][DATA_CLIENT] driver = client.driver diff --git a/homeassistant/components/zwave_js/triggers/value_updated.py b/homeassistant/components/zwave_js/triggers/value_updated.py index 4e21774c98f..52ecc0a7742 100644 --- a/homeassistant/components/zwave_js/triggers/value_updated.py +++ b/homeassistant/components/zwave_js/triggers/value_updated.py @@ -91,7 +91,7 @@ async def async_attach_trigger( ) -> CALLBACK_TYPE: """Listen for state changes based on configuration.""" dev_reg = dr.async_get(hass) - if not (nodes := async_get_nodes_from_targets(hass, config, dev_reg=dev_reg)): + if not async_get_nodes_from_targets(hass, config, dev_reg=dev_reg): raise ValueError( f"No nodes found for given {ATTR_DEVICE_ID}s or {ATTR_ENTITY_ID}s." ) @@ -174,7 +174,7 @@ async def async_attach_trigger( # Nodes list can come from different drivers and we will need to listen to # server connections for all of them. drivers: set[Driver] = set() - for node in nodes: + for node in async_get_nodes_from_targets(hass, config, dev_reg=dev_reg): driver = node.client.driver assert driver is not None # The node comes from the driver. drivers.add(driver) diff --git a/tests/components/zwave_js/test_trigger.py b/tests/components/zwave_js/test_trigger.py index 0fb3b829d9a..eae9d6f5416 100644 --- a/tests/components/zwave_js/test_trigger.py +++ b/tests/components/zwave_js/test_trigger.py @@ -1112,20 +1112,21 @@ def test_get_trigger_platform_failure() -> None: async def test_server_reconnect_event( - hass: HomeAssistant, client, lock_schlage_be469, integration + hass: HomeAssistant, + client, + lock_schlage_be469, + lock_schlage_be469_state, + integration, ) -> None: """Test that when we reconnect to server, event triggers reattach.""" trigger_type = f"{DOMAIN}.event" - node: Node = lock_schlage_be469 - dev_reg = async_get_dev_reg(hass) - device = dev_reg.async_get_device( - {get_device_id(client.driver, lock_schlage_be469)} - ) - assert device + old_node: Node = lock_schlage_be469 event_name = "interview stage completed" - original_len = len(node._listeners.get(event_name, [])) + old_node = client.driver.controller.nodes[20] + + original_len = len(old_node._listeners.get(event_name, [])) assert await async_setup_component( hass, @@ -1147,34 +1148,65 @@ async def test_server_reconnect_event( }, ) - assert len(node._listeners.get(event_name, [])) == original_len + 1 - old_listener = node._listeners.get(event_name, [])[original_len] + assert len(old_node._listeners.get(event_name, [])) == original_len + 1 + old_listener = old_node._listeners.get(event_name, [])[original_len] + # Remove node so that we can create a new node instance and make sure the listener + # attaches + node_removed_event = Event( + type="node removed", + data={ + "source": "controller", + "event": "node removed", + "replaced": False, + "node": lock_schlage_be469_state, + }, + ) + client.driver.controller.receive_event(node_removed_event) + assert 20 not in client.driver.controller.nodes + await hass.async_block_till_done() + + # Add node like new server connection would + node_added_event = Event( + type="node added", + data={ + "source": "controller", + "event": "node added", + "node": lock_schlage_be469_state, + "result": {}, + }, + ) + client.driver.controller.receive_event(node_added_event) + await hass.async_block_till_done() + + # Reload integration to trigger the dispatch signal await hass.config_entries.async_reload(integration.entry_id) await hass.async_block_till_done() - # Make sure there is still a listener added for the trigger - assert len(node._listeners.get(event_name, [])) == original_len + 1 + # Make sure there is a listener added for the trigger to the new node + new_node = client.driver.controller.nodes[20] + assert len(new_node._listeners.get(event_name, [])) == original_len + 1 - # Make sure the old listener was removed - assert old_listener not in node._listeners.get(event_name, []) + # Make sure the old listener is no longer referenced + assert old_listener not in new_node._listeners.get(event_name, []) async def test_server_reconnect_value_updated( - hass: HomeAssistant, client, lock_schlage_be469, integration + hass: HomeAssistant, + client, + lock_schlage_be469, + lock_schlage_be469_state, + integration, ) -> None: """Test that when we reconnect to server, value_updated triggers reattach.""" trigger_type = f"{DOMAIN}.value_updated" - node: Node = lock_schlage_be469 - dev_reg = async_get_dev_reg(hass) - device = dev_reg.async_get_device( - {get_device_id(client.driver, lock_schlage_be469)} - ) - assert device + old_node: Node = lock_schlage_be469 event_name = "value updated" - original_len = len(node._listeners.get(event_name, [])) + old_node = client.driver.controller.nodes[20] + + original_len = len(old_node._listeners.get(event_name, [])) assert await async_setup_component( hass, @@ -1196,14 +1228,44 @@ async def test_server_reconnect_value_updated( }, ) - assert len(node._listeners.get(event_name, [])) == original_len + 1 - old_listener = node._listeners.get(event_name, [])[original_len] + assert len(old_node._listeners.get(event_name, [])) == original_len + 1 + old_listener = old_node._listeners.get(event_name, [])[original_len] + # Remove node so that we can create a new node instance and make sure the listener + # attaches + node_removed_event = Event( + type="node removed", + data={ + "source": "controller", + "event": "node removed", + "replaced": False, + "node": lock_schlage_be469_state, + }, + ) + client.driver.controller.receive_event(node_removed_event) + assert 20 not in client.driver.controller.nodes + await hass.async_block_till_done() + + # Add node like new server connection would + node_added_event = Event( + type="node added", + data={ + "source": "controller", + "event": "node added", + "node": lock_schlage_be469_state, + "result": {}, + }, + ) + client.driver.controller.receive_event(node_added_event) + await hass.async_block_till_done() + + # Reload integration to trigger the dispatch signal await hass.config_entries.async_reload(integration.entry_id) await hass.async_block_till_done() - # Make sure there is still a listener added for the trigger - assert len(node._listeners.get(event_name, [])) == original_len + 1 + # Make sure there is a listener added for the trigger to the new node + new_node = client.driver.controller.nodes[20] + assert len(new_node._listeners.get(event_name, [])) == original_len + 1 - # Make sure the old listener was removed - assert old_listener not in node._listeners.get(event_name, []) + # Make sure the old listener is no longer referenced + assert old_listener not in new_node._listeners.get(event_name, [])