diff --git a/homeassistant/components/zha/__init__.py b/homeassistant/components/zha/__init__.py index a3c68ae3030..7303367d485 100644 --- a/homeassistant/components/zha/__init__.py +++ b/homeassistant/components/zha/__init__.py @@ -1,7 +1,4 @@ -"""Support for Zigbee Home Automation devices. - -isort:skip_file -""" +"""Support for Zigbee Home Automation devices.""" import logging @@ -11,7 +8,6 @@ from homeassistant import config_entries, const as ha_const import homeassistant.helpers.config_validation as cv from homeassistant.helpers.device_registry import CONNECTION_ZIGBEE -from . import config_flow # noqa: F401 pylint: disable=unused-import from . import api from .core import ZHAGateway from .core.const import ( @@ -147,5 +143,4 @@ async def async_unload_entry(hass, config_entry): for component in COMPONENTS: await hass.config_entries.async_forward_entry_unload(config_entry, component) - del hass.data[DATA_ZHA] return True diff --git a/homeassistant/components/zha/api.py b/homeassistant/components/zha/api.py index e796c48c3f3..462afd777b9 100644 --- a/homeassistant/components/zha/api.py +++ b/homeassistant/components/zha/api.py @@ -50,7 +50,11 @@ from .core.const import ( WARNING_DEVICE_STROBE_HIGH, WARNING_DEVICE_STROBE_YES, ) -from .core.helpers import async_is_bindable_target, get_matched_clusters +from .core.helpers import ( + async_get_device_info, + async_is_bindable_target, + get_matched_clusters, +) _LOGGER = logging.getLogger(__name__) @@ -423,31 +427,6 @@ async def websocket_remove_group_members(hass, connection, msg): connection.send_result(msg[ID], ret_group) -@callback -def async_get_device_info(hass, device, ha_device_registry=None): - """Get ZHA device.""" - zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - ret_device = {} - ret_device.update(device.device_info) - ret_device["entities"] = [ - { - "entity_id": entity_ref.reference_id, - ATTR_NAME: entity_ref.device_info[ATTR_NAME], - } - for entity_ref in zha_gateway.device_registry[device.ieee] - ] - - if ha_device_registry is not None: - reg_device = ha_device_registry.async_get_device( - {(DOMAIN, str(device.ieee))}, set() - ) - if reg_device is not None: - ret_device["user_given_name"] = reg_device.name_by_user - ret_device["device_reg_id"] = reg_device.id - ret_device["area_id"] = reg_device.area_id - return ret_device - - async def get_groups(hass,): """Get ZHA Groups.""" zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] diff --git a/homeassistant/components/zha/config_flow.py b/homeassistant/components/zha/config_flow.py index 474cb15b41a..5ee0d0ee9bb 100644 --- a/homeassistant/components/zha/config_flow.py +++ b/homeassistant/components/zha/config_flow.py @@ -1,4 +1,5 @@ """Config flow for ZHA.""" +import asyncio from collections import OrderedDict import os @@ -9,11 +10,14 @@ from homeassistant import config_entries from .core.const import ( CONF_RADIO_TYPE, CONF_USB_PATH, + CONTROLLER, + DEFAULT_BAUDRATE, DEFAULT_DATABASE_NAME, DOMAIN, + ZHA_GW_RADIO, RadioType, ) -from .core.helpers import check_zigpy_connection +from .core.registries import RADIO_TYPES @config_entries.HANDLERS.register(DOMAIN) @@ -57,3 +61,20 @@ class ZhaFlowHandler(config_entries.ConfigFlow): return self.async_create_entry( title=import_info[CONF_USB_PATH], data=import_info ) + + +async def check_zigpy_connection(usb_path, radio_type, database_path): + """Test zigpy radio connection.""" + try: + radio = RADIO_TYPES[radio_type][ZHA_GW_RADIO]() + controller_application = RADIO_TYPES[radio_type][CONTROLLER] + except KeyError: + return False + try: + await radio.connect(usb_path, DEFAULT_BAUDRATE) + controller = controller_application(radio, database_path) + await asyncio.wait_for(controller.startup(auto_form=True), timeout=30) + await controller.shutdown() + except Exception: # pylint: disable=broad-except + return False + return True diff --git a/homeassistant/components/zha/core/gateway.py b/homeassistant/components/zha/core/gateway.py index ef81705ce47..72931c665ee 100644 --- a/homeassistant/components/zha/core/gateway.py +++ b/homeassistant/components/zha/core/gateway.py @@ -20,7 +20,6 @@ from homeassistant.helpers.device_registry import ( ) from homeassistant.helpers.dispatcher import async_dispatcher_send -from ..api import async_get_device_info from .const import ( ATTR_IEEE, ATTR_MANUFACTURER, @@ -65,6 +64,7 @@ from .const import ( ) from .device import DeviceStatus, ZHADevice from .discovery import async_dispatch_discovery_info, async_process_endpoint +from .helpers import async_get_device_info from .patches import apply_application_controller_patch from .registries import RADIO_TYPES from .store import async_get_registry diff --git a/homeassistant/components/zha/core/helpers.py b/homeassistant/components/zha/core/helpers.py index d3f06090dae..981a03fe7b5 100644 --- a/homeassistant/components/zha/core/helpers.py +++ b/homeassistant/components/zha/core/helpers.py @@ -4,29 +4,20 @@ Helpers for Zigbee Home Automation. For more details about this component, please refer to the documentation at https://home-assistant.io/integrations/zha/ """ -import asyncio import collections import logging -import bellows.ezsp -import bellows.zigbee.application import zigpy.types -import zigpy_deconz.api -import zigpy_deconz.zigbee.application -import zigpy_xbee.api -import zigpy_xbee.zigbee.application -import zigpy_zigate.api -import zigpy_zigate.zigbee.application from homeassistant.core import callback from .const import ( + ATTR_NAME, CLUSTER_TYPE_IN, CLUSTER_TYPE_OUT, DATA_ZHA, DATA_ZHA_GATEWAY, - DEFAULT_BAUDRATE, - RadioType, + DOMAIN, ) from .registries import BINDABLE_CLUSTERS @@ -56,30 +47,6 @@ async def safe_read( return {} -async def check_zigpy_connection(usb_path, radio_type, database_path): - """Test zigpy radio connection.""" - if radio_type == RadioType.ezsp.name: - radio = bellows.ezsp.EZSP() - ControllerApplication = bellows.zigbee.application.ControllerApplication - elif radio_type == RadioType.xbee.name: - radio = zigpy_xbee.api.XBee() - ControllerApplication = zigpy_xbee.zigbee.application.ControllerApplication - elif radio_type == RadioType.deconz.name: - radio = zigpy_deconz.api.Deconz() - ControllerApplication = zigpy_deconz.zigbee.application.ControllerApplication - elif radio_type == RadioType.zigate.name: - radio = zigpy_zigate.api.ZiGate() - ControllerApplication = zigpy_zigate.zigbee.application.ControllerApplication - try: - await radio.connect(usb_path, DEFAULT_BAUDRATE) - controller = ControllerApplication(radio, database_path) - await asyncio.wait_for(controller.startup(auto_form=True), timeout=30) - await controller.shutdown() - except Exception: # pylint: disable=broad-except - return False - return True - - def get_attr_id_by_name(cluster, attr_name): """Get the attribute id for a cluster attribute by its name.""" return next( @@ -164,3 +131,28 @@ class LogMixin: def error(self, msg, *args): """Error level log.""" return self.log(logging.ERROR, msg, *args) + + +@callback +def async_get_device_info(hass, device, ha_device_registry=None): + """Get ZHA device.""" + zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + ret_device = {} + ret_device.update(device.device_info) + ret_device["entities"] = [ + { + "entity_id": entity_ref.reference_id, + ATTR_NAME: entity_ref.device_info[ATTR_NAME], + } + for entity_ref in zha_gateway.device_registry[device.ieee] + ] + + if ha_device_registry is not None: + reg_device = ha_device_registry.async_get_device( + {(DOMAIN, str(device.ieee))}, set() + ) + if reg_device is not None: + ret_device["user_given_name"] = reg_device.name_by_user + ret_device["device_reg_id"] = reg_device.id + ret_device["area_id"] = reg_device.area_id + return ret_device diff --git a/tests/components/zha/test_config_flow.py b/tests/components/zha/test_config_flow.py index 5e6bf51afd6..fdff064a1c5 100644 --- a/tests/components/zha/test_config_flow.py +++ b/tests/components/zha/test_config_flow.py @@ -1,8 +1,11 @@ """Tests for ZHA config flow.""" -from asynctest import patch +from unittest import mock + +import asynctest from homeassistant.components.zha import config_flow -from homeassistant.components.zha.core.const import DOMAIN +from homeassistant.components.zha.core.const import CONTROLLER, DOMAIN, ZHA_GW_RADIO +import homeassistant.components.zha.core.registries from tests.common import MockConfigEntry @@ -12,7 +15,7 @@ async def test_user_flow(hass): flow = config_flow.ZhaFlowHandler() flow.hass = hass - with patch( + with asynctest.patch( "homeassistant.components.zha.config_flow" ".check_zigpy_connection", return_value=False, ): @@ -22,7 +25,7 @@ async def test_user_flow(hass): assert result["errors"] == {"base": "cannot_connect"} - with patch( + with asynctest.patch( "homeassistant.components.zha.config_flow" ".check_zigpy_connection", return_value=True, ): @@ -71,3 +74,53 @@ async def test_import_flow_existing_config_entry(hass): ) assert result["type"] == "abort" + + +async def test_check_zigpy_connection(): + """Test config flow validator.""" + + mock_radio = asynctest.MagicMock() + mock_radio.connect = asynctest.CoroutineMock() + radio_cls = asynctest.MagicMock(return_value=mock_radio) + + bad_radio = asynctest.MagicMock() + bad_radio.connect = asynctest.CoroutineMock(side_effect=Exception) + bad_radio_cls = asynctest.MagicMock(return_value=bad_radio) + + mock_ctrl = asynctest.MagicMock() + mock_ctrl.startup = asynctest.CoroutineMock() + mock_ctrl.shutdown = asynctest.CoroutineMock() + ctrl_cls = asynctest.MagicMock(return_value=mock_ctrl) + new_radios = { + mock.sentinel.radio: {ZHA_GW_RADIO: radio_cls, CONTROLLER: ctrl_cls}, + mock.sentinel.bad_radio: {ZHA_GW_RADIO: bad_radio_cls, CONTROLLER: ctrl_cls}, + } + + with mock.patch.dict( + homeassistant.components.zha.core.registries.RADIO_TYPES, new_radios, clear=True + ): + assert not await config_flow.check_zigpy_connection( + mock.sentinel.usb_path, mock.sentinel.unk_radio, mock.sentinel.zigbee_db + ) + assert mock_radio.connect.call_count == 0 + assert bad_radio.connect.call_count == 0 + assert mock_ctrl.startup.call_count == 0 + assert mock_ctrl.shutdown.call_count == 0 + + # unsuccessful radio connect + assert not await config_flow.check_zigpy_connection( + mock.sentinel.usb_path, mock.sentinel.bad_radio, mock.sentinel.zigbee_db + ) + assert mock_radio.connect.call_count == 0 + assert bad_radio.connect.call_count == 1 + assert mock_ctrl.startup.call_count == 0 + assert mock_ctrl.shutdown.call_count == 0 + + # successful radio connect + assert await config_flow.check_zigpy_connection( + mock.sentinel.usb_path, mock.sentinel.radio, mock.sentinel.zigbee_db + ) + assert mock_radio.connect.call_count == 1 + assert bad_radio.connect.call_count == 1 + assert mock_ctrl.startup.call_count == 1 + assert mock_ctrl.shutdown.call_count == 1