Refactor zwave_js config entry setup (#107635)

* Refactor zwave_js config entry setup

* Fix blocking update test

* Address timeout comment

* Remove platform tasks

* Replace deprecated async_add_job

* Use ConfigEntry.async_on_state_change

* Use modern config entry methods

* Clarify exception message

* Test listen error after config entry setup

* Test listen failure during setup after forward entry

* Test not reloading when hass is stopping

* Test client disconnect is called on entry unload

* Fix and test client not connected during driver setup

* Fix and test driver ready timeout

* Stringify listen task exception when logging

* Use identity compare

* Guard for closed connection

* Consolidate listen task checking and tests
This commit is contained in:
Martin Hjelmare 2025-03-20 10:16:48 +01:00 committed by GitHub
parent 32f9c07254
commit 2674b02bfa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 362 additions and 129 deletions

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections import defaultdict from collections import defaultdict
from contextlib import suppress import contextlib
import logging import logging
from typing import Any from typing import Any
@ -12,7 +12,11 @@ from awesomeversion import AwesomeVersion
import voluptuous as vol import voluptuous as vol
from zwave_js_server.client import Client as ZwaveClient from zwave_js_server.client import Client as ZwaveClient
from zwave_js_server.const import CommandClass, RemoveNodeReason from zwave_js_server.const import CommandClass, RemoveNodeReason
from zwave_js_server.exceptions import BaseZwaveJSServerError, InvalidServerVersion from zwave_js_server.exceptions import (
BaseZwaveJSServerError,
InvalidServerVersion,
NotConnected,
)
from zwave_js_server.model.driver import Driver from zwave_js_server.model.driver import Driver
from zwave_js_server.model.node import Node as ZwaveNode from zwave_js_server.model.node import Node as ZwaveNode
from zwave_js_server.model.notification import ( from zwave_js_server.model.notification import (
@ -25,7 +29,7 @@ from zwave_js_server.model.value import Value, ValueNotification
from homeassistant.components.hassio import AddonError, AddonManager, AddonState from homeassistant.components.hassio import AddonError, AddonManager, AddonState
from homeassistant.components.persistent_notification import async_create from homeassistant.components.persistent_notification import async_create
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.const import ( from homeassistant.const import (
ATTR_DEVICE_ID, ATTR_DEVICE_ID,
ATTR_DOMAIN, ATTR_DOMAIN,
@ -36,7 +40,7 @@ from homeassistant.const import (
Platform, Platform,
) )
from homeassistant.core import Event, HomeAssistant, callback from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError
from homeassistant.helpers import ( from homeassistant.helpers import (
config_validation as cv, config_validation as cv,
device_registry as dr, device_registry as dr,
@ -130,9 +134,8 @@ from .migrate import async_migrate_discovered_value
from .services import ZWaveServices from .services import ZWaveServices
CONNECT_TIMEOUT = 10 CONNECT_TIMEOUT = 10
DATA_CLIENT_LISTEN_TASK = "client_listen_task"
DATA_DRIVER_EVENTS = "driver_events" DATA_DRIVER_EVENTS = "driver_events"
DATA_START_CLIENT_TASK = "start_client_task" DRIVER_READY_TIMEOUT = 60
CONFIG_SCHEMA = vol.Schema( CONFIG_SCHEMA = vol.Schema(
{ {
@ -145,6 +148,24 @@ CONFIG_SCHEMA = vol.Schema(
extra=vol.ALLOW_EXTRA, extra=vol.ALLOW_EXTRA,
) )
PLATFORMS = [
Platform.BINARY_SENSOR,
Platform.BUTTON,
Platform.CLIMATE,
Platform.COVER,
Platform.EVENT,
Platform.FAN,
Platform.HUMIDIFIER,
Platform.LIGHT,
Platform.LOCK,
Platform.NUMBER,
Platform.SELECT,
Platform.SENSOR,
Platform.SIREN,
Platform.SWITCH,
Platform.UPDATE,
]
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the Z-Wave JS component.""" """Set up the Z-Wave JS component."""
@ -196,53 +217,99 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
raise ConfigEntryNotReady(f"Failed to connect: {err}") from err raise ConfigEntryNotReady(f"Failed to connect: {err}") from err
async_delete_issue(hass, DOMAIN, "invalid_server_version") async_delete_issue(hass, DOMAIN, "invalid_server_version")
LOGGER.info("Connected to Zwave JS Server") LOGGER.debug("Connected to Zwave JS Server")
# Set up websocket API # Set up websocket API
async_register_api(hass) async_register_api(hass)
entry.runtime_data = {}
# Create a task to allow the config entry to be unloaded before the driver is ready. driver_ready = asyncio.Event()
# Unloading the config entry is needed if the client listen task errors. listen_task = entry.async_create_background_task(
start_client_task = hass.async_create_task(start_client(hass, entry, client)) hass,
entry.runtime_data[DATA_START_CLIENT_TASK] = start_client_task client_listen(hass, entry, client, driver_ready),
f"{DOMAIN}_{entry.title}_client_listen",
)
return True entry.async_on_unload(client.disconnect)
async def start_client(
hass: HomeAssistant, entry: ConfigEntry, client: ZwaveClient
) -> None:
"""Start listening with the client."""
entry.runtime_data[DATA_CLIENT] = client
driver_events = entry.runtime_data[DATA_DRIVER_EVENTS] = DriverEvents(hass, entry)
async def handle_ha_shutdown(event: Event) -> None: async def handle_ha_shutdown(event: Event) -> None:
"""Handle HA shutdown.""" """Handle HA shutdown."""
await disconnect_client(hass, entry) await client.disconnect()
listen_task = asyncio.create_task(
client_listen(hass, entry, client, driver_events.ready)
)
entry.runtime_data[DATA_CLIENT_LISTEN_TASK] = listen_task
entry.async_on_unload( entry.async_on_unload(
hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, handle_ha_shutdown) hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, handle_ha_shutdown)
) )
try: driver_ready_task = entry.async_create_task(
await driver_events.ready.wait() hass,
except asyncio.CancelledError: driver_ready.wait(),
LOGGER.debug("Cancelling start client") f"{DOMAIN}_{entry.title}_driver_ready",
return )
done, pending = await asyncio.wait(
LOGGER.info("Connection to Zwave JS Server initialized") (driver_ready_task, listen_task),
return_when=asyncio.FIRST_COMPLETED,
assert client.driver timeout=DRIVER_READY_TIMEOUT,
async_dispatcher_send(
hass, f"{DOMAIN}_{client.driver.controller.home_id}_connected_to_server"
) )
await driver_events.setup(client.driver) if driver_ready_task in pending or listen_task in done:
error_message = "Driver ready timed out"
listen_error: BaseException | None = None
if listen_task.done():
listen_error, error_message = _get_listen_task_error(listen_task)
else:
listen_task.cancel()
driver_ready_task.cancel()
raise ConfigEntryNotReady(error_message) from listen_error
LOGGER.debug("Connection to Zwave JS Server initialized")
entry_runtime_data = entry.runtime_data = {
DATA_CLIENT: client,
}
entry_runtime_data[DATA_DRIVER_EVENTS] = driver_events = DriverEvents(hass, entry)
driver = client.driver
# When the driver is ready we know it's set on the client.
assert driver is not None
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
with contextlib.suppress(NotConnected):
# If the client isn't connected the listen task may have an exception
# and we'll handle the clean up below.
await driver_events.setup(driver)
# If the listen task is already failed, we need to raise ConfigEntryNotReady
if listen_task.done():
listen_error, error_message = _get_listen_task_error(listen_task)
await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
raise ConfigEntryNotReady(error_message) from listen_error
# Re-attach trigger listeners.
# Schedule this call to make sure the config entry is loaded first.
@callback
def on_config_entry_loaded() -> None:
"""Signal that server connection and driver are ready."""
if entry.state is ConfigEntryState.LOADED:
async_dispatcher_send(
hass,
f"{DOMAIN}_{driver.controller.home_id}_connected_to_server",
)
entry.async_on_unload(entry.async_on_state_change(on_config_entry_loaded))
return True
def _get_listen_task_error(
listen_task: asyncio.Task,
) -> tuple[BaseException | None, str]:
"""Check the listen task for errors."""
if listen_error := listen_task.exception():
error_message = f"Client listen failed: {listen_error}"
else:
error_message = "Client connection was closed"
return listen_error, error_message
class DriverEvents: class DriverEvents:
@ -255,8 +322,6 @@ class DriverEvents:
self.config_entry = entry self.config_entry = entry
self.dev_reg = dr.async_get(hass) self.dev_reg = dr.async_get(hass)
self.hass = hass self.hass = hass
self.platform_setup_tasks: dict[str, asyncio.Task] = {}
self.ready = asyncio.Event()
# Make sure to not pass self to ControllerEvents until all attributes are set. # Make sure to not pass self to ControllerEvents until all attributes are set.
self.controller_events = ControllerEvents(hass, self) self.controller_events = ControllerEvents(hass, self)
@ -339,16 +404,6 @@ class DriverEvents:
controller.on("identify", self.controller_events.async_on_identify) controller.on("identify", self.controller_events.async_on_identify)
) )
async def async_setup_platform(self, platform: Platform) -> None:
"""Set up platform if needed."""
if platform not in self.platform_setup_tasks:
self.platform_setup_tasks[platform] = self.hass.async_create_task(
self.hass.config_entries.async_forward_entry_setups(
self.config_entry, [platform]
)
)
await self.platform_setup_tasks[platform]
class ControllerEvents: class ControllerEvents:
"""Represent controller events. """Represent controller events.
@ -380,9 +435,6 @@ class ControllerEvents:
async def async_on_node_added(self, node: ZwaveNode) -> None: async def async_on_node_added(self, node: ZwaveNode) -> None:
"""Handle node added event.""" """Handle node added event."""
# Every node including the controller will have at least one sensor
await self.driver_events.async_setup_platform(Platform.SENSOR)
# Remove stale entities that may exist from a previous interview when an # Remove stale entities that may exist from a previous interview when an
# interview is started. # interview is started.
base_unique_id = get_valueless_base_unique_id(self.driver_events.driver, node) base_unique_id = get_valueless_base_unique_id(self.driver_events.driver, node)
@ -411,7 +463,6 @@ class ControllerEvents:
) )
# Create a ping button for each device # Create a ping button for each device
await self.driver_events.async_setup_platform(Platform.BUTTON)
async_dispatcher_send( async_dispatcher_send(
self.hass, self.hass,
f"{DOMAIN}_{self.config_entry.entry_id}_add_ping_button_entity", f"{DOMAIN}_{self.config_entry.entry_id}_add_ping_button_entity",
@ -668,9 +719,6 @@ class NodeEvents:
cc.id == CommandClass.FIRMWARE_UPDATE_MD.value cc.id == CommandClass.FIRMWARE_UPDATE_MD.value
for cc in node.command_classes for cc in node.command_classes
): ):
await self.controller_events.driver_events.async_setup_platform(
Platform.UPDATE
)
async_dispatcher_send( async_dispatcher_send(
self.hass, self.hass,
f"{DOMAIN}_{self.config_entry.entry_id}_add_firmware_update_entity", f"{DOMAIN}_{self.config_entry.entry_id}_add_firmware_update_entity",
@ -701,21 +749,19 @@ class NodeEvents:
value_updates_disc_info: dict[str, ZwaveDiscoveryInfo], value_updates_disc_info: dict[str, ZwaveDiscoveryInfo],
) -> None: ) -> None:
"""Handle discovery info and all dependent tasks.""" """Handle discovery info and all dependent tasks."""
platform = disc_info.platform
# This migration logic was added in 2021.3 to handle a breaking change to # This migration logic was added in 2021.3 to handle a breaking change to
# the value_id format. Some time in the future, this call (as well as the # the value_id format. Some time in the future, this call (as well as the
# helper functions) can be removed. # helper functions) can be removed.
async_migrate_discovered_value( async_migrate_discovered_value(
self.hass, self.hass,
self.ent_reg, self.ent_reg,
self.controller_events.registered_unique_ids[device.id][disc_info.platform], self.controller_events.registered_unique_ids[device.id][platform],
device, device,
self.controller_events.driver_events.driver, self.controller_events.driver_events.driver,
disc_info, disc_info,
) )
platform = disc_info.platform
await self.controller_events.driver_events.async_setup_platform(platform)
LOGGER.debug("Discovered entity: %s", disc_info) LOGGER.debug("Discovered entity: %s", disc_info)
async_dispatcher_send( async_dispatcher_send(
self.hass, self.hass,
@ -930,63 +976,37 @@ async def client_listen(
driver_ready: asyncio.Event, driver_ready: asyncio.Event,
) -> None: ) -> None:
"""Listen with the client.""" """Listen with the client."""
should_reload = True
try: try:
await client.listen(driver_ready) await client.listen(driver_ready)
except asyncio.CancelledError:
should_reload = False
except BaseZwaveJSServerError as err: except BaseZwaveJSServerError as err:
LOGGER.error("Failed to listen: %s", err) if entry.state is not ConfigEntryState.LOADED:
except Exception as err: # noqa: BLE001 raise
LOGGER.error("Client listen failed: %s", err)
except Exception as err:
# We need to guard against unknown exceptions to not crash this task. # We need to guard against unknown exceptions to not crash this task.
LOGGER.exception("Unexpected exception: %s", err) LOGGER.exception("Unexpected exception: %s", err)
if entry.state is not ConfigEntryState.LOADED:
raise
# The entry needs to be reloaded since a new driver state # The entry needs to be reloaded since a new driver state
# will be acquired on reconnect. # will be acquired on reconnect.
# All model instances will be replaced when the new state is acquired. # All model instances will be replaced when the new state is acquired.
if should_reload: if not hass.is_stopping:
LOGGER.info("Disconnected from server. Reloading integration") if entry.state is not ConfigEntryState.LOADED:
hass.async_create_task(hass.config_entries.async_reload(entry.entry_id)) raise HomeAssistantError("Listen task ended unexpectedly")
LOGGER.debug("Disconnected from server. Reloading integration")
hass.config_entries.async_schedule_reload(entry.entry_id)
async def disconnect_client(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Disconnect client."""
client: ZwaveClient = entry.runtime_data[DATA_CLIENT]
listen_task: asyncio.Task = entry.runtime_data[DATA_CLIENT_LISTEN_TASK]
start_client_task: asyncio.Task = entry.runtime_data[DATA_START_CLIENT_TASK]
driver_events: DriverEvents = entry.runtime_data[DATA_DRIVER_EVENTS]
listen_task.cancel()
start_client_task.cancel()
platform_setup_tasks = driver_events.platform_setup_tasks.values()
for task in platform_setup_tasks:
task.cancel()
tasks = (listen_task, start_client_task, *platform_setup_tasks)
await asyncio.gather(*tasks, return_exceptions=True)
for task in tasks:
with suppress(asyncio.CancelledError):
await task
if client.connected:
await client.disconnect()
LOGGER.info("Disconnected from Zwave JS Server")
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry.""" """Unload a config entry."""
client: ZwaveClient = entry.runtime_data[DATA_CLIENT] unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
driver_events: DriverEvents = entry.runtime_data[DATA_DRIVER_EVENTS]
platforms = [
platform
for platform, task in driver_events.platform_setup_tasks.items()
if not task.cancel()
]
unload_ok = await hass.config_entries.async_unload_platforms(entry, platforms)
if client.connected and client.driver: entry_runtime_data = entry.runtime_data
await async_disable_server_logging_if_needed(hass, entry, client.driver) client: ZwaveClient = entry_runtime_data[DATA_CLIENT]
if DATA_CLIENT_LISTEN_TASK in entry.runtime_data:
await disconnect_client(hass, entry) if client.connected and (driver := client.driver):
await async_disable_server_logging_if_needed(hass, entry, driver)
if entry.data.get(CONF_USE_ADDON) and entry.disabled_by: if entry.data.get(CONF_USE_ADDON) and entry.disabled_by:
addon_manager: AddonManager = get_addon_manager(hass) addon_manager: AddonManager = get_addon_manager(hass)

View File

@ -42,7 +42,6 @@ from homeassistant.helpers.service_info.usb import UsbServiceInfo
from homeassistant.helpers.service_info.zeroconf import ZeroconfServiceInfo from homeassistant.helpers.service_info.zeroconf import ZeroconfServiceInfo
from homeassistant.helpers.typing import VolDictType from homeassistant.helpers.typing import VolDictType
from . import disconnect_client
from .addon import get_addon_manager from .addon import get_addon_manager
from .const import ( from .const import (
ADDON_SLUG, ADDON_SLUG,
@ -861,7 +860,7 @@ class OptionsFlowHandler(BaseZwaveJSFlow, OptionsFlow):
and self.config_entry.state == ConfigEntryState.LOADED and self.config_entry.state == ConfigEntryState.LOADED
): ):
# Disconnect integration before restarting add-on. # Disconnect integration before restarting add-on.
await disconnect_client(self.hass, self.config_entry) await self.hass.config_entries.async_unload(self.config_entry.entry_id)
return await self.async_step_start_addon() return await self.async_step_start_addon()

View File

@ -511,18 +511,25 @@ def aeotec_smart_switch_7_state_fixture() -> NodeDataType:
@pytest.fixture(name="listen_block") @pytest.fixture(name="listen_block")
def mock_listen_block_fixture(): def mock_listen_block_fixture() -> asyncio.Event:
"""Mock a listen block.""" """Mock a listen block."""
return asyncio.Event() return asyncio.Event()
@pytest.fixture(name="listen_result")
def listen_result_fixture() -> asyncio.Future[None]:
"""Mock a listen result."""
return asyncio.Future()
@pytest.fixture(name="client") @pytest.fixture(name="client")
def mock_client_fixture( def mock_client_fixture(
controller_state, controller_state: dict[str, Any],
controller_node_state, controller_node_state: dict[str, Any],
version_state, version_state: dict[str, Any],
log_config_state, log_config_state: dict[str, Any],
listen_block, listen_block: asyncio.Event,
listen_result: asyncio.Future[None],
): ):
"""Mock a client.""" """Mock a client."""
with patch( with patch(
@ -537,6 +544,7 @@ def mock_client_fixture(
async def listen(driver_ready: asyncio.Event) -> None: async def listen(driver_ready: asyncio.Event) -> None:
driver_ready.set() driver_ready.set()
await listen_block.wait() await listen_block.wait()
await listen_result
async def disconnect(): async def disconnect():
client.connected = False client.connected = False
@ -817,7 +825,10 @@ def nortek_thermostat_removed_event_fixture(client) -> Node:
@pytest.fixture(name="integration") @pytest.fixture(name="integration")
async def integration_fixture(hass: HomeAssistant, client) -> MockConfigEntry: async def integration_fixture(
hass: HomeAssistant,
client: MagicMock,
) -> MockConfigEntry:
"""Set up the zwave_js integration.""" """Set up the zwave_js integration."""
entry = MockConfigEntry(domain="zwave_js", data={"url": "ws://test.org"}) entry = MockConfigEntry(domain="zwave_js", data={"url": "ws://test.org"})
entry.add_to_hass(hass) entry.add_to_hass(hass)

View File

@ -3,14 +3,19 @@
import asyncio import asyncio
from copy import deepcopy from copy import deepcopy
import logging import logging
from unittest.mock import AsyncMock, call, patch from typing import Any
from unittest.mock import AsyncMock, MagicMock, call, patch
from aiohasupervisor import SupervisorError from aiohasupervisor import SupervisorError
from aiohasupervisor.models import AddonsOptions from aiohasupervisor.models import AddonsOptions
import pytest import pytest
from zwave_js_server.client import Client from zwave_js_server.client import Client
from zwave_js_server.event import Event from zwave_js_server.event import Event
from zwave_js_server.exceptions import BaseZwaveJSServerError, InvalidServerVersion from zwave_js_server.exceptions import (
BaseZwaveJSServerError,
InvalidServerVersion,
NotConnected,
)
from zwave_js_server.model.node import Node from zwave_js_server.model.node import Node
from zwave_js_server.model.version import VersionInfo from zwave_js_server.model.version import VersionInfo
@ -21,7 +26,7 @@ from homeassistant.components.zwave_js import DOMAIN
from homeassistant.components.zwave_js.helpers import get_device_id from homeassistant.components.zwave_js.helpers import get_device_id
from homeassistant.config_entries import ConfigEntryDisabler, ConfigEntryState from homeassistant.config_entries import ConfigEntryDisabler, ConfigEntryState
from homeassistant.const import STATE_UNAVAILABLE from homeassistant.const import STATE_UNAVAILABLE
from homeassistant.core import HomeAssistant from homeassistant.core import CoreState, HomeAssistant
from homeassistant.helpers import ( from homeassistant.helpers import (
area_registry as ar, area_registry as ar,
device_registry as dr, device_registry as dr,
@ -32,7 +37,11 @@ from homeassistant.setup import async_setup_component
from .common import AIR_TEMPERATURE_SENSOR, EATON_RF9640_ENTITY from .common import AIR_TEMPERATURE_SENSOR, EATON_RF9640_ENTITY
from tests.common import MockConfigEntry, async_get_persistent_notifications from tests.common import (
MockConfigEntry,
async_fire_time_changed,
async_get_persistent_notifications,
)
from tests.typing import WebSocketGenerator from tests.typing import WebSocketGenerator
@ -127,24 +136,215 @@ async def test_noop_statistics(hass: HomeAssistant, client) -> None:
assert not mock_cmd2.called assert not mock_cmd2.called
@pytest.mark.parametrize("error", [BaseZwaveJSServerError("Boom"), Exception("Boom")]) async def test_driver_ready_timeout_during_setup(
async def test_listen_failure(hass: HomeAssistant, client, error) -> None: hass: HomeAssistant,
"""Test we handle errors during client listen.""" client: MagicMock,
listen_block: asyncio.Event,
) -> None:
"""Test we handle driver ready timeout during setup."""
async def listen(driver_ready): async def listen(driver_ready: asyncio.Event) -> None:
"""Mock the client listen method.""" """Mock listen."""
# Set the connect side effect to stop an endless loop on reload. await listen_block.wait()
client.connect.side_effect = BaseZwaveJSServerError("Boom")
raise error
client.listen.side_effect = listen client.listen.side_effect = listen
entry = MockConfigEntry(
domain="zwave_js",
data={"url": "ws://test.org", "data_collection_opted_in": True},
)
entry.add_to_hass(hass)
assert client.disconnect.call_count == 0
with patch("homeassistant.components.zwave_js.DRIVER_READY_TIMEOUT", new=0):
await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
assert entry.state is ConfigEntryState.SETUP_RETRY
assert client.disconnect.call_count == 1
@pytest.mark.parametrize("core_state", [CoreState.running, CoreState.stopping])
@pytest.mark.parametrize(
("listen_future_result_method", "listen_future_result"),
[
("set_exception", BaseZwaveJSServerError("Boom")),
("set_exception", Exception("Boom")),
("set_result", None),
],
)
async def test_listen_done_during_setup_before_forward_entry(
hass: HomeAssistant,
client: MagicMock,
listen_block: asyncio.Event,
listen_result: asyncio.Future[None],
core_state: CoreState,
listen_future_result_method: str,
listen_future_result: Exception | None,
) -> None:
"""Test listen task finishing during setup before forward entry."""
assert hass.state is CoreState.running
async def listen(driver_ready: asyncio.Event) -> None:
await listen_block.wait()
await listen_result
async_fire_time_changed(hass, fire_all=True)
client.listen.side_effect = listen
hass.set_state(core_state)
listen_block.set()
getattr(listen_result, listen_future_result_method)(listen_future_result)
entry = MockConfigEntry(domain="zwave_js", data={"url": "ws://test.org"}) entry = MockConfigEntry(domain="zwave_js", data={"url": "ws://test.org"})
entry.add_to_hass(hass) entry.add_to_hass(hass)
assert client.disconnect.call_count == 0
await hass.config_entries.async_setup(entry.entry_id) await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
assert entry.state is ConfigEntryState.SETUP_RETRY assert entry.state is ConfigEntryState.SETUP_RETRY
assert client.disconnect.call_count == 1
async def test_not_connected_during_setup_after_forward_entry(
hass: HomeAssistant,
client: MagicMock,
listen_block: asyncio.Event,
listen_result: asyncio.Future[None],
) -> None:
"""Test we handle not connected client during setup after forward entry."""
async def send_command_side_effect(*args: Any, **kwargs: Any) -> None:
"""Mock send command."""
listen_block.set()
listen_result.set_result(None)
# Yield to allow the listen task to run
await asyncio.sleep(0)
raise NotConnected("Boom")
async def listen(driver_ready: asyncio.Event) -> None:
"""Mock listen."""
driver_ready.set()
client.async_send_command.side_effect = send_command_side_effect
await listen_block.wait()
await listen_result
client.listen.side_effect = listen
entry = MockConfigEntry(
domain="zwave_js",
data={"url": "ws://test.org", "data_collection_opted_in": True},
)
entry.add_to_hass(hass)
assert client.disconnect.call_count == 0
await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
assert entry.state is ConfigEntryState.SETUP_RETRY
assert client.disconnect.call_count == 1
@pytest.mark.parametrize("core_state", [CoreState.running, CoreState.stopping])
@pytest.mark.parametrize(
("listen_future_result_method", "listen_future_result"),
[
("set_exception", BaseZwaveJSServerError("Boom")),
("set_exception", Exception("Boom")),
("set_result", None),
],
)
async def test_listen_done_during_setup_after_forward_entry(
hass: HomeAssistant,
client: MagicMock,
listen_block: asyncio.Event,
listen_result: asyncio.Future[None],
core_state: CoreState,
listen_future_result_method: str,
listen_future_result: Exception | None,
) -> None:
"""Test listen task finishing during setup after forward entry."""
assert hass.state is CoreState.running
async def send_command_side_effect(*args: Any, **kwargs: Any) -> None:
"""Mock send command."""
listen_block.set()
getattr(listen_result, listen_future_result_method)(listen_future_result)
# Yield to allow the listen task to run
await asyncio.sleep(0)
async def listen(driver_ready: asyncio.Event) -> None:
"""Mock listen."""
driver_ready.set()
client.async_send_command.side_effect = send_command_side_effect
await listen_block.wait()
await listen_result
client.listen.side_effect = listen
hass.set_state(core_state)
entry = MockConfigEntry(
domain="zwave_js",
data={"url": "ws://test.org", "data_collection_opted_in": True},
)
entry.add_to_hass(hass)
assert client.disconnect.call_count == 0
await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
assert entry.state is ConfigEntryState.SETUP_RETRY
assert client.disconnect.call_count == 1
@pytest.mark.parametrize(
("core_state", "final_config_entry_state", "disconnect_call_count"),
[
(
CoreState.running,
ConfigEntryState.SETUP_RETRY,
2,
), # the reload will cause a disconnect call too
(
CoreState.stopping,
ConfigEntryState.LOADED,
0,
), # the home assistant stop event will handle the disconnect
],
)
@pytest.mark.parametrize(
("listen_future_result_method", "listen_future_result"),
[
("set_exception", BaseZwaveJSServerError("Boom")),
("set_exception", Exception("Boom")),
("set_result", None),
],
)
async def test_listen_done_after_setup(
hass: HomeAssistant,
client: MagicMock,
integration: MockConfigEntry,
listen_block: asyncio.Event,
listen_result: asyncio.Future[None],
core_state: CoreState,
listen_future_result_method: str,
listen_future_result: Exception | None,
final_config_entry_state: ConfigEntryState,
disconnect_call_count: int,
) -> None:
"""Test listen task finishing after setup."""
config_entry = integration
assert config_entry.state is ConfigEntryState.LOADED
assert hass.state is CoreState.running
assert client.disconnect.call_count == 0
hass.set_state(core_state)
listen_block.set()
getattr(listen_result, listen_future_result_method)(listen_future_result)
await hass.async_block_till_done()
assert config_entry.state is final_config_entry_state
assert client.disconnect.call_count == disconnect_call_count
async def test_new_entity_on_value_added( async def test_new_entity_on_value_added(

View File

@ -658,8 +658,10 @@ async def test_update_entity_delay(
assert len(client.async_send_command.call_args_list) == 2 assert len(client.async_send_command.call_args_list) == 2
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=5)) update_interval = timedelta(minutes=5)
await hass.async_block_till_done(wait_background_tasks=True) freezer.tick(update_interval)
async_fire_time_changed(hass)
await hass.async_block_till_done()
nodes: set[int] = set() nodes: set[int] = set()
@ -668,8 +670,9 @@ async def test_update_entity_delay(
assert args["command"] == "controller.get_available_firmware_updates" assert args["command"] == "controller.get_available_firmware_updates"
nodes.add(args["nodeId"]) nodes.add(args["nodeId"])
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=10)) freezer.tick(update_interval)
await hass.async_block_till_done(wait_background_tasks=True) async_fire_time_changed(hass)
await hass.async_block_till_done()
assert len(client.async_send_command.call_args_list) == 4 assert len(client.async_send_command.call_args_list) == 4
args = client.async_send_command.call_args_list[3][0][0] args = client.async_send_command.call_args_list[3][0][0]