Fix Z-Wave unique id after controller reset (#144813)

This commit is contained in:
Martin Hjelmare 2025-05-13 13:12:00 +02:00 committed by Franck Nijhof
parent d82feb807f
commit 6c3a4f17f0
No known key found for this signature in database
GPG Key ID: AB33ADACE7101952
6 changed files with 133 additions and 63 deletions

View File

@ -71,6 +71,7 @@ from homeassistant.components.websocket_api import (
ActiveConnection,
)
from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.const import CONF_URL
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, device_registry as dr
from homeassistant.helpers.aiohttp_client import async_get_clientsession
@ -88,13 +89,16 @@ from .const import (
DATA_CLIENT,
DOMAIN,
EVENT_DEVICE_ADDED_TO_REGISTRY,
LOGGER,
RESTORE_NVM_DRIVER_READY_TIMEOUT,
USER_AGENT,
)
from .helpers import (
CannotConnect,
async_enable_statistics,
async_get_node_from_device_id,
async_get_provisioning_entry_from_device_id,
async_get_version_info,
get_device_id,
)
@ -2857,6 +2861,25 @@ async def websocket_hard_reset_controller(
async with asyncio.timeout(HARD_RESET_CONTROLLER_DRIVER_READY_TIMEOUT):
await wait_driver_ready.wait()
# When resetting the controller, the controller home id is also changed.
# The controller state in the client is stale after resetting the controller,
# so get the new home id with a new client using the helper function.
# The client state will be refreshed by reloading the config entry,
# after the unique id of the config entry has been updated.
try:
version_info = await async_get_version_info(hass, entry.data[CONF_URL])
except CannotConnect:
# Just log this error, as there's nothing to do about it here.
# The stale unique id needs to be handled by a repair flow,
# after the config entry has been reloaded.
LOGGER.error(
"Failed to get server version, cannot update config entry"
"unique id with new home id, after controller reset"
)
else:
hass.config_entries.async_update_entry(
entry, unique_id=str(version_info.home_id)
)
await hass.config_entries.async_reload(entry.entry_id)

View File

@ -9,14 +9,13 @@ import logging
from pathlib import Path
from typing import Any
import aiohttp
from awesomeversion import AwesomeVersion
from serial.tools import list_ports
import voluptuous as vol
from zwave_js_server.client import Client
from zwave_js_server.exceptions import FailedCommand
from zwave_js_server.model.driver import Driver
from zwave_js_server.version import VersionInfo, get_server_version
from zwave_js_server.version import VersionInfo
from homeassistant.components import usb
from homeassistant.components.hassio import (
@ -36,7 +35,6 @@ from homeassistant.const import CONF_NAME, CONF_URL
from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import AbortFlow
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.hassio import is_hassio
from homeassistant.helpers.service_info.hassio import HassioServiceInfo
from homeassistant.helpers.service_info.usb import UsbServiceInfo
@ -69,6 +67,7 @@ from .const import (
DOMAIN,
RESTORE_NVM_DRIVER_READY_TIMEOUT,
)
from .helpers import CannotConnect, async_get_version_info
_LOGGER = logging.getLogger(__name__)
@ -79,7 +78,6 @@ ADDON_SETUP_TIMEOUT = 5
ADDON_SETUP_TIMEOUT_ROUNDS = 40
CONF_EMULATE_HARDWARE = "emulate_hardware"
CONF_LOG_LEVEL = "log_level"
SERVER_VERSION_TIMEOUT = 10
ADDON_LOG_LEVELS = {
"error": "Error",
@ -130,22 +128,6 @@ async def validate_input(hass: HomeAssistant, user_input: dict) -> VersionInfo:
raise InvalidInput("cannot_connect") from err
async def async_get_version_info(hass: HomeAssistant, ws_address: str) -> VersionInfo:
"""Return Z-Wave JS version info."""
try:
async with asyncio.timeout(SERVER_VERSION_TIMEOUT):
version_info: VersionInfo = await get_server_version(
ws_address, async_get_clientsession(hass)
)
except (TimeoutError, aiohttp.ClientError) as err:
# We don't want to spam the log if the add-on isn't started
# or takes a long time to start.
_LOGGER.debug("Failed to connect to Z-Wave JS server: %s", err)
raise CannotConnect from err
return version_info
def get_usb_ports() -> dict[str, str]:
"""Return a dict of USB ports and their friendly names."""
ports = list_ports.comports()
@ -1357,10 +1339,6 @@ class ZWaveJSConfigFlow(ConfigFlow, domain=DOMAIN):
return client.driver
class CannotConnect(HomeAssistantError):
"""Indicate connection error."""
class InvalidInput(HomeAssistantError):
"""Error to indicate input data is invalid."""

View File

@ -2,11 +2,13 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable
from dataclasses import astuple, dataclass
import logging
from typing import Any, cast
import aiohttp
import voluptuous as vol
from zwave_js_server.client import Client as ZwaveClient
from zwave_js_server.const import (
@ -25,6 +27,7 @@ from zwave_js_server.model.value import (
ValueDataType,
get_value_id_str,
)
from zwave_js_server.version import VersionInfo, get_server_version
from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN
from homeassistant.config_entries import ConfigEntry, ConfigEntryState
@ -38,6 +41,7 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.device_registry import DeviceInfo
from homeassistant.helpers.group import expand_entity_ids
from homeassistant.helpers.typing import ConfigType, VolSchemaType
@ -54,6 +58,8 @@ from .const import (
LOGGER,
)
SERVER_VERSION_TIMEOUT = 10
@dataclass
class ZwaveValueID:
@ -568,3 +574,23 @@ def get_network_identifier_for_notification(
return f"`{config_entry.title}`, with the home ID `{home_id}`,"
return f"with the home ID `{home_id}`"
return ""
async def async_get_version_info(hass: HomeAssistant, ws_address: str) -> VersionInfo:
"""Return Z-Wave JS version info."""
try:
async with asyncio.timeout(SERVER_VERSION_TIMEOUT):
version_info: VersionInfo = await get_server_version(
ws_address, async_get_clientsession(hass)
)
except (TimeoutError, aiohttp.ClientError) as err:
# We don't want to spam the log if the add-on isn't started
# or takes a long time to start.
LOGGER.debug("Failed to connect to Z-Wave JS server: %s", err)
raise CannotConnect from err
return version_info
class CannotConnect(HomeAssistantError):
"""Indicate connection error."""

View File

@ -1,6 +1,7 @@
"""Provide common Z-Wave JS fixtures."""
import asyncio
from collections.abc import Generator
import copy
import io
from typing import Any, cast
@ -15,6 +16,7 @@ from zwave_js_server.version import VersionInfo
from homeassistant.components.zwave_js import PLATFORMS
from homeassistant.components.zwave_js.const import DOMAIN
from homeassistant.components.zwave_js.helpers import SERVER_VERSION_TIMEOUT
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant
from homeassistant.util.json import JsonArrayType
@ -587,6 +589,44 @@ def mock_client_fixture(
yield client
@pytest.fixture(name="server_version_side_effect")
def server_version_side_effect_fixture() -> Any | None:
"""Return the server version side effect."""
return None
@pytest.fixture(name="get_server_version", autouse=True)
def mock_get_server_version(
server_version_side_effect: Any | None, server_version_timeout: int
) -> Generator[AsyncMock]:
"""Mock server version."""
version_info = VersionInfo(
driver_version="mock-driver-version",
server_version="mock-server-version",
home_id=1234,
min_schema_version=0,
max_schema_version=1,
)
with (
patch(
"homeassistant.components.zwave_js.helpers.get_server_version",
side_effect=server_version_side_effect,
return_value=version_info,
) as mock_version,
patch(
"homeassistant.components.zwave_js.helpers.SERVER_VERSION_TIMEOUT",
new=server_version_timeout,
),
):
yield mock_version
@pytest.fixture(name="server_version_timeout")
def mock_server_version_timeout() -> int:
"""Patch the timeout for getting server version."""
return SERVER_VERSION_TIMEOUT
@pytest.fixture(name="multisensor_6")
def multisensor_6_fixture(client, multisensor_6_state) -> Node:
"""Mock a multisensor 6 node."""

View File

@ -7,6 +7,7 @@ import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock, PropertyMock, call, patch
from aiohttp import ClientError
import pytest
from zwave_js_server.const import (
ExclusionStrategy,
@ -5080,14 +5081,17 @@ async def test_subscribe_node_statistics(
async def test_hard_reset_controller(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
device_registry: dr.DeviceRegistry,
client: MagicMock,
get_server_version: AsyncMock,
integration: MockConfigEntry,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test that the hard_reset_controller WS API call works."""
entry = integration
ws_client = await hass_ws_client(hass)
assert entry.unique_id == "3245146787"
async def async_send_command_driver_ready(
message: dict[str, Any],
@ -5122,6 +5126,40 @@ async def test_hard_reset_controller(
assert client.async_send_command.call_args_list[0] == call(
{"command": "driver.hard_reset"}, 25
)
assert entry.unique_id == "1234"
client.async_send_command.reset_mock()
# Test client connect error when getting the server version.
get_server_version.side_effect = ClientError("Boom!")
await ws_client.send_json_auto_id(
{
TYPE: "zwave_js/hard_reset_controller",
ENTRY_ID: entry.entry_id,
}
)
msg = await ws_client.receive_json()
device = device_registry.async_get_device(
identifiers={get_device_id(client.driver, client.driver.controller.nodes[1])}
)
assert device is not None
assert msg["result"] == device.id
assert msg["success"]
assert client.async_send_command.call_count == 3
# The first call is the relevant hard reset command.
# 25 is the require_schema parameter.
assert client.async_send_command.call_args_list[0] == call(
{"command": "driver.hard_reset"}, 25
)
assert (
"Failed to get server version, cannot update config entry"
"unique id with new home id, after controller reset"
) in caplog.text
client.async_send_command.reset_mock()
@ -5162,6 +5200,8 @@ async def test_hard_reset_controller(
{"command": "driver.hard_reset"}, 25
)
client.async_send_command.reset_mock()
# Test FailedZWaveCommand is caught
with patch(
"zwave_js_server.model.driver.Driver.async_hard_reset",

View File

@ -17,8 +17,9 @@ from zwave_js_server.exceptions import FailedCommand
from zwave_js_server.version import VersionInfo
from homeassistant import config_entries, data_entry_flow
from homeassistant.components.zwave_js.config_flow import SERVER_VERSION_TIMEOUT, TITLE
from homeassistant.components.zwave_js.config_flow import TITLE
from homeassistant.components.zwave_js.const import ADDON_SLUG, CONF_USB_PATH, DOMAIN
from homeassistant.components.zwave_js.helpers import SERVER_VERSION_TIMEOUT
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
from homeassistant.helpers.service_info.hassio import HassioServiceInfo
@ -89,44 +90,6 @@ def mock_supervisor_fixture() -> Generator[None]:
yield
@pytest.fixture(name="server_version_side_effect")
def server_version_side_effect_fixture() -> Any | None:
"""Return the server version side effect."""
return None
@pytest.fixture(name="get_server_version", autouse=True)
def mock_get_server_version(
server_version_side_effect: Any | None, server_version_timeout: int
) -> Generator[AsyncMock]:
"""Mock server version."""
version_info = VersionInfo(
driver_version="mock-driver-version",
server_version="mock-server-version",
home_id=1234,
min_schema_version=0,
max_schema_version=1,
)
with (
patch(
"homeassistant.components.zwave_js.config_flow.get_server_version",
side_effect=server_version_side_effect,
return_value=version_info,
) as mock_version,
patch(
"homeassistant.components.zwave_js.config_flow.SERVER_VERSION_TIMEOUT",
new=server_version_timeout,
),
):
yield mock_version
@pytest.fixture(name="server_version_timeout")
def mock_server_version_timeout() -> int:
"""Patch the timeout for getting server version."""
return SERVER_VERSION_TIMEOUT
@pytest.fixture(name="addon_setup_time", autouse=True)
def mock_addon_setup_time() -> Generator[None]:
"""Mock add-on setup sleep time."""