From 3c7005d4dcbfd30fc76e2872820f2e45b9f18fc6 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 18 Jan 2022 14:40:55 -1000 Subject: [PATCH] Use unifi direct connect w/ssl verify for unifiprotect when possible (#64395) --- .../components/unifiprotect/config_flow.py | 51 ++++-- .../components/unifiprotect/discovery.py | 8 +- .../components/unifiprotect/strings.json | 1 - .../unifiprotect/translations/en.json | 3 +- tests/components/unifiprotect/__init__.py | 8 + .../unifiprotect/test_config_flow.py | 152 ++++++++++++++++-- 6 files changed, 190 insertions(+), 33 deletions(-) diff --git a/homeassistant/components/unifiprotect/config_flow.py b/homeassistant/components/unifiprotect/config_flow.py index 7f06539c2fa..720ba9aa37f 100644 --- a/homeassistant/components/unifiprotect/config_flow.py +++ b/homeassistant/components/unifiprotect/config_flow.py @@ -40,6 +40,11 @@ from .utils import _async_short_mac, _async_unifi_mac_from_hass _LOGGER = logging.getLogger(__name__) +def _host_is_direct_connect(host: str) -> bool: + """Check if a host is a unifi direct connect domain.""" + return host.endswith(".ui.direct") + + class ProtectFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): """Handle a UniFi Protect config flow.""" @@ -74,10 +79,33 @@ class ProtectFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): ) -> FlowResult: """Handle discovery.""" self._discovered_device = discovery_info - mac = _async_unifi_mac_from_hass(discovery_info["mac"]) + mac = _async_unifi_mac_from_hass(discovery_info["hw_addr"]) await self.async_set_unique_id(mac) + for entry in self._async_current_entries(include_ignore=False): + if entry.unique_id != mac: + continue + new_host = None + if ( + _host_is_direct_connect(entry.data[CONF_HOST]) + and discovery_info["direct_connect_domain"] + and entry.data[CONF_HOST] != discovery_info["direct_connect_domain"] + ): + new_host = discovery_info["direct_connect_domain"] + elif ( + not _host_is_direct_connect(entry.data[CONF_HOST]) + and entry.data[CONF_HOST] != discovery_info["source_ip"] + ): + new_host = discovery_info["source_ip"] + if new_host: + self.hass.config_entries.async_update_entry( + entry, data={**entry.data, CONF_HOST: new_host} + ) + self.hass.async_create_task( + self.hass.config_entries.async_reload(entry.entry_id) + ) + return self.async_abort(reason="already_configured") self._abort_if_unique_id_configured( - updates={CONF_HOST: discovery_info["ip_address"]} + updates={CONF_HOST: discovery_info["source_ip"]} ) return await self.async_step_discovery_confirm() @@ -88,17 +116,24 @@ class ProtectFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): errors: dict[str, str] = {} discovery_info = self._discovered_device if user_input is not None: - user_input[CONF_HOST] = discovery_info["ip_address"] user_input[CONF_PORT] = DEFAULT_PORT - nvr_data, errors = await self._async_get_nvr_data(user_input) + nvr_data = None + if discovery_info["direct_connect_domain"]: + user_input[CONF_HOST] = discovery_info["direct_connect_domain"] + user_input[CONF_VERIFY_SSL] = True + nvr_data, errors = await self._async_get_nvr_data(user_input) + if not nvr_data or errors: + user_input[CONF_HOST] = discovery_info["source_ip"] + user_input[CONF_VERIFY_SSL] = False + nvr_data, errors = await self._async_get_nvr_data(user_input) if nvr_data and not errors: return self._async_create_entry(nvr_data.name, user_input) placeholders = { "name": discovery_info["hostname"] or discovery_info["platform"] - or f"NVR {_async_short_mac(discovery_info['mac'])}", - "ip_address": discovery_info["ip_address"], + or f"NVR {_async_short_mac(discovery_info['hw_addr'])}", + "ip_address": discovery_info["source_ip"], } self.context["title_placeholders"] = placeholders user_input = user_input or {} @@ -107,10 +142,6 @@ class ProtectFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): description_placeholders=placeholders, data_schema=vol.Schema( { - vol.Required( - CONF_VERIFY_SSL, - default=user_input.get(CONF_VERIFY_SSL, DEFAULT_VERIFY_SSL), - ): bool, vol.Required( CONF_USERNAME, default=user_input.get(CONF_USERNAME) ): str, diff --git a/homeassistant/components/unifiprotect/discovery.py b/homeassistant/components/unifiprotect/discovery.py index e2867e97a56..0efa6bc1fb1 100644 --- a/homeassistant/components/unifiprotect/discovery.py +++ b/homeassistant/components/unifiprotect/discovery.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio +from dataclasses import asdict from datetime import timedelta import logging from typing import Any @@ -56,11 +57,6 @@ def async_trigger_discovery( hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_DISCOVERY}, - data={ - "ip_address": device.source_ip, - "mac": device.hw_addr, - "hostname": device.hostname, # can be None - "platform": device.platform, # can be None - }, + data=asdict(device), ) ) diff --git a/homeassistant/components/unifiprotect/strings.json b/homeassistant/components/unifiprotect/strings.json index 57836dc45f4..99bffefbd35 100644 --- a/homeassistant/components/unifiprotect/strings.json +++ b/homeassistant/components/unifiprotect/strings.json @@ -24,7 +24,6 @@ "discovery_confirm": { "description": "Do you want to setup {name} ({ip_address})?", "data": { - "verify_ssl": "[%key:common::config_flow::data::verify_ssl%]", "username": "[%key:common::config_flow::data::username%]", "password": "[%key:common::config_flow::data::password%]" } diff --git a/homeassistant/components/unifiprotect/translations/en.json b/homeassistant/components/unifiprotect/translations/en.json index effa7881480..ab1868be7a1 100644 --- a/homeassistant/components/unifiprotect/translations/en.json +++ b/homeassistant/components/unifiprotect/translations/en.json @@ -15,8 +15,7 @@ "discovery_confirm": { "data": { "password": "Password", - "username": "Username", - "verify_ssl": "Verify SSL certificate" + "username": "Username" }, "description": "Do you want to setup {name} ({ip_address})?" }, diff --git a/tests/components/unifiprotect/__init__.py b/tests/components/unifiprotect/__init__.py index 1cdc6dfc9f7..5fd1b7cc909 100644 --- a/tests/components/unifiprotect/__init__.py +++ b/tests/components/unifiprotect/__init__.py @@ -16,6 +16,14 @@ UNIFI_DISCOVERY = UnifiDevice( platform=DEVICE_HOSTNAME, hostname=DEVICE_HOSTNAME, services={UnifiService.Protect: True}, + direct_connect_domain="x.ui.direct", +) + + +UNIFI_DISCOVERY_PARTIAL = UnifiDevice( + source_ip=DEVICE_IP_ADDRESS, + hw_addr=DEVICE_MAC_ADDRESS, + services={UnifiService.Protect: True}, ) diff --git a/tests/components/unifiprotect/test_config_flow.py b/tests/components/unifiprotect/test_config_flow.py index b162c7a9540..557eb3d5e79 100644 --- a/tests/components/unifiprotect/test_config_flow.py +++ b/tests/components/unifiprotect/test_config_flow.py @@ -1,6 +1,7 @@ """Test the UniFi Protect config flow.""" from __future__ import annotations +from dataclasses import asdict from unittest.mock import patch import pytest @@ -15,6 +16,7 @@ from homeassistant.components.unifiprotect.const import ( CONF_OVERRIDE_CHOST, DOMAIN, ) +from homeassistant.const import CONF_HOST from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import ( RESULT_TYPE_ABORT, @@ -23,7 +25,14 @@ from homeassistant.data_entry_flow import ( ) from homeassistant.helpers import device_registry as dr -from . import DEVICE_HOSTNAME, DEVICE_IP_ADDRESS, DEVICE_MAC_ADDRESS, _patch_discovery +from . import ( + DEVICE_HOSTNAME, + DEVICE_IP_ADDRESS, + DEVICE_MAC_ADDRESS, + UNIFI_DISCOVERY, + UNIFI_DISCOVERY_PARTIAL, + _patch_discovery, +) from .conftest import MAC_ADDR from tests.common import MockConfigEntry @@ -45,18 +54,9 @@ SSDP_DISCOVERY = ( }, ), ) -UNIFI_DISCOVERY_DICT = { - "ip_address": DEVICE_IP_ADDRESS, - "mac": DEVICE_MAC_ADDRESS, - "hostname": DEVICE_HOSTNAME, - "platform": DEVICE_HOSTNAME, -} -UNIFI_DISCOVERY_DICT_PARTIAL = { - "ip_address": DEVICE_IP_ADDRESS, - "mac": DEVICE_MAC_ADDRESS, - "hostname": None, - "platform": None, -} + +UNIFI_DISCOVERY_DICT = asdict(UNIFI_DISCOVERY) +UNIFI_DISCOVERY_DICT_PARTIAL = asdict(UNIFI_DISCOVERY_PARTIAL) async def test_form(hass: HomeAssistant, mock_nvr: NVR) -> None: @@ -292,7 +292,7 @@ async def test_discovered_by_ssdp_or_dhcp( assert result["reason"] == "discovery_started" -async def test_discovered_by_unifi_discovery( +async def test_discovered_by_unifi_discovery_direct_connect( hass: HomeAssistant, mock_nvr: NVR ) -> None: """Test a discovery from unifi-discovery.""" @@ -331,6 +331,130 @@ async def test_discovered_by_unifi_discovery( ) await hass.async_block_till_done() + assert result2["type"] == RESULT_TYPE_CREATE_ENTRY + assert result2["title"] == "UnifiProtect" + assert result2["data"] == { + "host": "x.ui.direct", + "username": "test-username", + "password": "test-password", + "id": "UnifiProtect", + "port": 443, + "verify_ssl": True, + } + assert len(mock_setup_entry.mock_calls) == 1 + + +async def test_discovered_by_unifi_discovery_direct_connect_updated( + hass: HomeAssistant, mock_nvr: NVR +) -> None: + """Test a discovery from unifi-discovery updates the direct connect host.""" + mock_config = MockConfigEntry( + domain=DOMAIN, + data={ + "host": "y.ui.direct", + "username": "test-username", + "password": "test-password", + "id": "UnifiProtect", + "port": 443, + "verify_ssl": True, + }, + version=2, + unique_id=DEVICE_MAC_ADDRESS.replace(":", "").upper(), + ) + mock_config.add_to_hass(hass) + + with _patch_discovery(), patch( + "homeassistant.components.unifiprotect.async_setup_entry", + return_value=True, + ) as mock_setup_entry: + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_DISCOVERY}, + data=UNIFI_DISCOVERY_DICT, + ) + await hass.async_block_till_done() + + assert result["type"] == RESULT_TYPE_ABORT + assert result["reason"] == "already_configured" + assert len(mock_setup_entry.mock_calls) == 1 + assert mock_config.data[CONF_HOST] == "x.ui.direct" + + +async def test_discovered_by_unifi_discovery_direct_connect_updated_but_not_using_direct_connect( + hass: HomeAssistant, mock_nvr: NVR +) -> None: + """Test a discovery from unifi-discovery updates the host but not direct connect if its not in use.""" + mock_config = MockConfigEntry( + domain=DOMAIN, + data={ + "host": "1.2.2.2", + "username": "test-username", + "password": "test-password", + "id": "UnifiProtect", + "port": 443, + "verify_ssl": False, + }, + version=2, + unique_id=DEVICE_MAC_ADDRESS.replace(":", "").upper(), + ) + mock_config.add_to_hass(hass) + + with _patch_discovery(), patch( + "homeassistant.components.unifiprotect.async_setup_entry", + return_value=True, + ) as mock_setup_entry: + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_DISCOVERY}, + data=UNIFI_DISCOVERY_DICT, + ) + await hass.async_block_till_done() + + assert result["type"] == RESULT_TYPE_ABORT + assert result["reason"] == "already_configured" + assert len(mock_setup_entry.mock_calls) == 1 + assert mock_config.data[CONF_HOST] == "127.0.0.1" + + +async def test_discovered_by_unifi_discovery( + hass: HomeAssistant, mock_nvr: NVR +) -> None: + """Test a discovery from unifi-discovery.""" + + with _patch_discovery(): + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_DISCOVERY}, + data=UNIFI_DISCOVERY_DICT, + ) + await hass.async_block_till_done() + + assert result["type"] == RESULT_TYPE_FORM + assert result["step_id"] == "discovery_confirm" + flows = hass.config_entries.flow.async_progress_by_handler(DOMAIN) + assert flows[0]["context"]["title_placeholders"] == { + "ip_address": DEVICE_IP_ADDRESS, + "name": DEVICE_HOSTNAME, + } + + assert not result["errors"] + + with patch( + "homeassistant.components.unifiprotect.config_flow.ProtectApiClient.get_nvr", + side_effect=[NotAuthorized, mock_nvr], + ), patch( + "homeassistant.components.unifiprotect.async_setup_entry", + return_value=True, + ) as mock_setup_entry: + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + "username": "test-username", + "password": "test-password", + }, + ) + await hass.async_block_till_done() + assert result2["type"] == RESULT_TYPE_CREATE_ENTRY assert result2["title"] == "UnifiProtect" assert result2["data"] == {