diff --git a/homeassistant/components/zwave_js/__init__.py b/homeassistant/components/zwave_js/__init__.py index d12c7df6a79..0c716a39c75 100644 --- a/homeassistant/components/zwave_js/__init__.py +++ b/homeassistant/components/zwave_js/__init__.py @@ -9,6 +9,7 @@ from typing import Any from async_timeout import timeout from zwave_js_server.client import Client as ZwaveClient from zwave_js_server.exceptions import BaseZwaveJSServerError, InvalidServerVersion +from zwave_js_server.model.driver import Driver from zwave_js_server.model.node import Node as ZwaveNode from zwave_js_server.model.notification import ( EntryControlNotification, @@ -25,12 +26,6 @@ from homeassistant.const import ( ATTR_DEVICE_ID, ATTR_DOMAIN, ATTR_ENTITY_ID, - ATTR_IDENTIFIERS, - ATTR_MANUFACTURER, - ATTR_MODEL, - ATTR_NAME, - ATTR_SUGGESTED_AREA, - ATTR_SW_VERSION, CONF_URL, EVENT_HOMEASSISTANT_STOP, ) @@ -39,7 +34,7 @@ from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.helpers import device_registry, entity_registry from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.dispatcher import async_dispatcher_send -from homeassistant.helpers.typing import ConfigType +from homeassistant.helpers.typing import UNDEFINED, ConfigType from .addon import AddonError, AddonManager, AddonState, get_addon_manager from .api import async_register_api @@ -154,39 +149,105 @@ def register_node_in_dev_reg( else: ids = {device_id} - params = { - ATTR_IDENTIFIERS: ids, - ATTR_SW_VERSION: node.firmware_version, - ATTR_NAME: node.name - or node.device_config.description - or f"Node {node.node_id}", - ATTR_MODEL: node.device_config.label, - ATTR_MANUFACTURER: node.device_config.manufacturer, - } - if node.location: - params[ATTR_SUGGESTED_AREA] = node.location - device = dev_reg.async_get_or_create(config_entry_id=entry.entry_id, **params) + device = dev_reg.async_get_or_create( + config_entry_id=entry.entry_id, + identifiers=ids, + sw_version=node.firmware_version, + name=node.name or node.device_config.description or f"Node {node.node_id}", + model=node.device_config.label, + manufacturer=node.device_config.manufacturer, + suggested_area=node.location if node.location else UNDEFINED, + ) async_dispatcher_send(hass, EVENT_DEVICE_ADDED_TO_REGISTRY, device) return device -async def async_setup_entry( # noqa: C901 - hass: HomeAssistant, entry: ConfigEntry -) -> bool: +async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up Z-Wave JS from a config entry.""" if use_addon := entry.data.get(CONF_USE_ADDON): await async_ensure_addon_running(hass, entry) client = ZwaveClient(entry.data[CONF_URL], async_get_clientsession(hass)) + entry_hass_data: dict = hass.data[DOMAIN].setdefault(entry.entry_id, {}) + + # connect and throw error if connection failed + try: + async with timeout(CONNECT_TIMEOUT): + await client.connect() + except InvalidServerVersion as err: + if not entry_hass_data.get(DATA_INVALID_SERVER_VERSION_LOGGED): + LOGGER.error("Invalid server version: %s", err) + entry_hass_data[DATA_INVALID_SERVER_VERSION_LOGGED] = True + if use_addon: + async_ensure_addon_updated(hass) + raise ConfigEntryNotReady from err + except (asyncio.TimeoutError, BaseZwaveJSServerError) as err: + if not entry_hass_data.get(DATA_CONNECT_FAILED_LOGGED): + LOGGER.error("Failed to connect: %s", err) + entry_hass_data[DATA_CONNECT_FAILED_LOGGED] = True + raise ConfigEntryNotReady from err + else: + LOGGER.info("Connected to Zwave JS Server") + entry_hass_data[DATA_CONNECT_FAILED_LOGGED] = False + entry_hass_data[DATA_INVALID_SERVER_VERSION_LOGGED] = False + + dev_reg = device_registry.async_get(hass) + ent_reg = entity_registry.async_get(hass) + services = ZWaveServices(hass, ent_reg, dev_reg) + services.async_register() + + # Set up websocket API + async_register_api(hass) + + platform_task = hass.async_create_task(start_platforms(hass, entry, client)) + entry_hass_data[DATA_START_PLATFORM_TASK] = platform_task + + return True + + +async def start_platforms( + hass: HomeAssistant, entry: ConfigEntry, client: ZwaveClient +) -> None: + """Start platforms and perform discovery.""" + entry_hass_data: dict = hass.data[DOMAIN].setdefault(entry.entry_id, {}) + entry_hass_data[DATA_CLIENT] = client + entry_hass_data[DATA_PLATFORM_SETUP] = {} + driver_ready = asyncio.Event() + + async def handle_ha_shutdown(event: Event) -> None: + """Handle HA shutdown.""" + await disconnect_client(hass, entry) + + listen_task = asyncio.create_task(client_listen(hass, entry, client, driver_ready)) + entry_hass_data[DATA_CLIENT_LISTEN_TASK] = listen_task + entry.async_on_unload( + hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, handle_ha_shutdown) + ) + + try: + await driver_ready.wait() + except asyncio.CancelledError: + LOGGER.debug("Cancelling start platforms") + return + + LOGGER.info("Connection to Zwave JS Server initialized") + + if client.driver is None: + raise RuntimeError("Driver not ready.") + + await setup_driver(hass, entry, client, client.driver) + + +async def setup_driver( # noqa: C901 + hass: HomeAssistant, entry: ConfigEntry, client: ZwaveClient, driver: Driver +) -> None: + """Set up devices using the ready driver.""" dev_reg = device_registry.async_get(hass) ent_reg = entity_registry.async_get(hass) entry_hass_data: dict = hass.data[DOMAIN].setdefault(entry.entry_id, {}) - - entry_hass_data[DATA_CLIENT] = client - platform_setup_tasks = entry_hass_data[DATA_PLATFORM_SETUP] = {} - + platform_setup_tasks = entry_hass_data[DATA_PLATFORM_SETUP] registered_unique_ids: dict[str, dict[str, set[str]]] = defaultdict(dict) discovered_value_ids: dict[str, set[str]] = defaultdict(set) @@ -384,7 +445,7 @@ async def async_setup_entry( # noqa: C901 { ATTR_DOMAIN: DOMAIN, ATTR_NODE_ID: notification.node.node_id, - ATTR_HOME_ID: client.driver.controller.home_id, + ATTR_HOME_ID: driver.controller.home_id, ATTR_ENDPOINT: notification.endpoint, ATTR_DEVICE_ID: device.id, ATTR_COMMAND_CLASS: notification.command_class, @@ -414,7 +475,7 @@ async def async_setup_entry( # noqa: C901 event_data = { ATTR_DOMAIN: DOMAIN, ATTR_NODE_ID: notification.node.node_id, - ATTR_HOME_ID: client.driver.controller.home_id, + ATTR_HOME_ID: driver.controller.home_id, ATTR_DEVICE_ID: device.id, ATTR_COMMAND_CLASS: notification.command_class, } @@ -487,7 +548,7 @@ async def async_setup_entry( # noqa: C901 ZWAVE_JS_VALUE_UPDATED_EVENT, { ATTR_NODE_ID: value.node.node_id, - ATTR_HOME_ID: client.driver.controller.home_id, + ATTR_HOME_ID: driver.controller.home_id, ATTR_DEVICE_ID: device.id, ATTR_ENTITY_ID: entity_id, ATTR_COMMAND_CLASS: value.command_class, @@ -502,105 +563,42 @@ async def async_setup_entry( # noqa: C901 }, ) - # connect and throw error if connection failed - try: - async with timeout(CONNECT_TIMEOUT): - await client.connect() - except InvalidServerVersion as err: - if not entry_hass_data.get(DATA_INVALID_SERVER_VERSION_LOGGED): - LOGGER.error("Invalid server version: %s", err) - entry_hass_data[DATA_INVALID_SERVER_VERSION_LOGGED] = True - if use_addon: - async_ensure_addon_updated(hass) - raise ConfigEntryNotReady from err - except (asyncio.TimeoutError, BaseZwaveJSServerError) as err: - if not entry_hass_data.get(DATA_CONNECT_FAILED_LOGGED): - LOGGER.error("Failed to connect: %s", err) - entry_hass_data[DATA_CONNECT_FAILED_LOGGED] = True - raise ConfigEntryNotReady from err - else: - LOGGER.info("Connected to Zwave JS Server") - entry_hass_data[DATA_CONNECT_FAILED_LOGGED] = False - entry_hass_data[DATA_INVALID_SERVER_VERSION_LOGGED] = False + # If opt in preference hasn't been specified yet, we do nothing, otherwise + # we apply the preference + if opted_in := entry.data.get(CONF_DATA_COLLECTION_OPTED_IN): + await async_enable_statistics(client) + elif opted_in is False: + await driver.async_disable_statistics() - services = ZWaveServices(hass, ent_reg, dev_reg) - services.async_register() + # Check for nodes that no longer exist and remove them + stored_devices = device_registry.async_entries_for_config_entry( + dev_reg, entry.entry_id + ) + known_devices = [ + dev_reg.async_get_device({get_device_id(client, node)}) + for node in driver.controller.nodes.values() + ] - # Set up websocket API - async_register_api(hass) + # Devices that are in the device registry that are not known by the controller can be removed + for device in stored_devices: + if device not in known_devices: + dev_reg.async_remove_device(device.id) - async def start_platforms() -> None: - """Start platforms and perform discovery.""" - driver_ready = asyncio.Event() + # run discovery on all ready nodes + await asyncio.gather( + *(async_on_node_added(node) for node in driver.controller.nodes.values()) + ) - async def handle_ha_shutdown(event: Event) -> None: - """Handle HA shutdown.""" - await disconnect_client(hass, entry) - - listen_task = asyncio.create_task( - client_listen(hass, entry, client, driver_ready) + # listen for new nodes being added to the mesh + entry.async_on_unload( + driver.controller.on( + "node added", + lambda event: hass.async_create_task(async_on_node_added(event["node"])), ) - entry_hass_data[DATA_CLIENT_LISTEN_TASK] = listen_task - entry.async_on_unload( - hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, handle_ha_shutdown) - ) - - try: - await driver_ready.wait() - except asyncio.CancelledError: - LOGGER.debug("Cancelling start platforms") - return - - LOGGER.info("Connection to Zwave JS Server initialized") - - # If opt in preference hasn't been specified yet, we do nothing, otherwise - # we apply the preference - if opted_in := entry.data.get(CONF_DATA_COLLECTION_OPTED_IN): - await async_enable_statistics(client) - elif opted_in is False: - await client.driver.async_disable_statistics() - - # Check for nodes that no longer exist and remove them - stored_devices = device_registry.async_entries_for_config_entry( - dev_reg, entry.entry_id - ) - known_devices = [ - dev_reg.async_get_device({get_device_id(client, node)}) - for node in client.driver.controller.nodes.values() - ] - - # Devices that are in the device registry that are not known by the controller can be removed - for device in stored_devices: - if device not in known_devices: - dev_reg.async_remove_device(device.id) - - # run discovery on all ready nodes - await asyncio.gather( - *( - async_on_node_added(node) - for node in client.driver.controller.nodes.values() - ) - ) - - # listen for new nodes being added to the mesh - entry.async_on_unload( - client.driver.controller.on( - "node added", - lambda event: hass.async_create_task( - async_on_node_added(event["node"]) - ), - ) - ) - # listen for nodes being removed from the mesh - # NOTE: This will not remove nodes that were removed when HA was not running - entry.async_on_unload( - client.driver.controller.on("node removed", async_on_node_removed) - ) - - platform_task = hass.async_create_task(start_platforms()) - entry_hass_data[DATA_START_PLATFORM_TASK] = platform_task - - return True + ) + # listen for nodes being removed from the mesh + # NOTE: This will not remove nodes that were removed when HA was not running + entry.async_on_unload(driver.controller.on("node removed", async_on_node_removed)) async def client_listen(