From 06bc9c7b229c264c522d97046f9381d70965d90f Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Wed, 11 Jan 2023 15:28:31 -0500 Subject: [PATCH] Automatically fetch the encryption key from the ESPHome dashboard (#85709) * Automatically fetch the encryption key from the ESPHome dashboard * Also use encryption key during reauth * Typo * Clean up tests --- homeassistant/components/esphome/__init__.py | 8 + .../components/esphome/config_flow.py | 66 +++++- homeassistant/components/esphome/dashboard.py | 49 ++++- .../components/esphome/manifest.json | 2 +- requirements_all.txt | 3 + requirements_test_all.txt | 3 + tests/components/esphome/conftest.py | 10 +- tests/components/esphome/test_config_flow.py | 202 ++++++++++++++---- 8 files changed, 290 insertions(+), 53 deletions(-) diff --git a/homeassistant/components/esphome/__init__.py b/homeassistant/components/esphome/__init__.py index 979057a194c..47ef51087d2 100644 --- a/homeassistant/components/esphome/__init__.py +++ b/homeassistant/components/esphome/__init__.py @@ -62,6 +62,7 @@ from .domain_data import DOMAIN, DomainData # Import config flow so that it's added to the registry from .entry_data import RuntimeEntryData +CONF_DEVICE_NAME = "device_name" CONF_NOISE_PSK = "noise_psk" _LOGGER = logging.getLogger(__name__) _R = TypeVar("_R") @@ -268,6 +269,13 @@ async def async_setup_entry( # noqa: C901 entry, unique_id=format_mac(device_info.mac_address) ) + # Make sure we have the correct device name stored + # so we can map the device to ESPHome Dashboard config + if entry.data.get(CONF_DEVICE_NAME) != device_info.name: + hass.config_entries.async_update_entry( + entry, data={**entry.data, CONF_DEVICE_NAME: device_info.name} + ) + entry_data.device_info = device_info assert cli.api_version is not None entry_data.api_version = cli.api_version diff --git a/homeassistant/components/esphome/config_flow.py b/homeassistant/components/esphome/config_flow.py index 1c8f795c1a7..de9d3ebc624 100644 --- a/homeassistant/components/esphome/config_flow.py +++ b/homeassistant/components/esphome/config_flow.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections import OrderedDict from collections.abc import Mapping +import logging from typing import Any from aioesphomeapi import ( @@ -14,21 +15,24 @@ from aioesphomeapi import ( RequiresEncryptionAPIError, ResolveAPIError, ) +import aiohttp import voluptuous as vol from homeassistant.components import dhcp, zeroconf -from homeassistant.components.hassio.discovery import HassioServiceInfo +from homeassistant.components.hassio import HassioServiceInfo from homeassistant.config_entries import ConfigEntry, ConfigFlow from homeassistant.const import CONF_HOST, CONF_NAME, CONF_PASSWORD, CONF_PORT from homeassistant.core import callback from homeassistant.data_entry_flow import FlowResult from homeassistant.helpers.device_registry import format_mac -from . import CONF_NOISE_PSK, DOMAIN -from .dashboard import async_set_dashboard_info +from . import CONF_DEVICE_NAME, CONF_NOISE_PSK, DOMAIN +from .dashboard import async_get_dashboard, async_set_dashboard_info ERROR_REQUIRES_ENCRYPTION_KEY = "requires_encryption_key" +ERROR_INVALID_ENCRYPTION_KEY = "invalid_psk" ESPHOME_URL = "https://esphome.io/" +_LOGGER = logging.getLogger(__name__) class EsphomeFlowHandler(ConfigFlow, domain=DOMAIN): @@ -44,6 +48,8 @@ class EsphomeFlowHandler(ConfigFlow, domain=DOMAIN): self._noise_psk: str | None = None self._device_info: DeviceInfo | None = None self._reauth_entry: ConfigEntry | None = None + # The ESPHome name as per its config + self._device_name: str | None = None async def _async_step_user_base( self, user_input: dict[str, Any] | None = None, error: str | None = None @@ -83,6 +89,13 @@ class EsphomeFlowHandler(ConfigFlow, domain=DOMAIN): self._port = entry.data[CONF_PORT] self._password = entry.data[CONF_PASSWORD] self._name = entry.title + self._device_name = entry.data.get(CONF_DEVICE_NAME) + + if await self._retrieve_encryption_key_from_dashboard(): + error = await self.fetch_device_info() + if error is None: + return await self._async_authenticate_or_add() + return await self.async_step_reauth_confirm() async def async_step_reauth_confirm( @@ -116,6 +129,17 @@ class EsphomeFlowHandler(ConfigFlow, domain=DOMAIN): async def _async_try_fetch_device_info(self) -> FlowResult: error = await self.fetch_device_info() + + if ( + error == ERROR_REQUIRES_ENCRYPTION_KEY + and await self._retrieve_encryption_key_from_dashboard() + ): + error = await self.fetch_device_info() + # If the fetched key is invalid, unset it again. + if error == ERROR_INVALID_ENCRYPTION_KEY: + self._noise_psk = None + error = ERROR_REQUIRES_ENCRYPTION_KEY + if error == ERROR_REQUIRES_ENCRYPTION_KEY: return await self.async_step_encryption_key() if error is not None: @@ -156,6 +180,7 @@ class EsphomeFlowHandler(ConfigFlow, domain=DOMAIN): # Hostname is format: livingroom.local. self._name = discovery_info.hostname[: -len(".local.")] + self._device_name = self._name self._host = discovery_info.host self._port = discovery_info.port @@ -193,6 +218,7 @@ class EsphomeFlowHandler(ConfigFlow, domain=DOMAIN): # The API uses protobuf, so empty string denotes absence CONF_PASSWORD: self._password or "", CONF_NOISE_PSK: self._noise_psk or "", + CONF_DEVICE_NAME: self._device_name, } if self._reauth_entry: entry = self._reauth_entry @@ -272,7 +298,7 @@ class EsphomeFlowHandler(ConfigFlow, domain=DOMAIN): except RequiresEncryptionAPIError: return ERROR_REQUIRES_ENCRYPTION_KEY except InvalidEncryptionKeyAPIError: - return "invalid_psk" + return ERROR_INVALID_ENCRYPTION_KEY except ResolveAPIError: return "resolve_error" except APIConnectionError: @@ -280,7 +306,7 @@ class EsphomeFlowHandler(ConfigFlow, domain=DOMAIN): finally: await cli.disconnect(force=True) - self._name = self._device_info.name + self._name = self._device_name = self._device_info.name await self.async_set_unique_id( self._device_info.mac_address, raise_on_progress=False ) @@ -314,3 +340,33 @@ class EsphomeFlowHandler(ConfigFlow, domain=DOMAIN): await cli.disconnect(force=True) return None + + async def _retrieve_encryption_key_from_dashboard(self) -> bool: + """Try to retrieve the encryption key from the dashboard. + + Return boolean if a key was retrieved. + """ + if self._device_name is None: + return False + + if (dashboard := async_get_dashboard(self.hass)) is None: + return False + + await dashboard.async_request_refresh() + + if not dashboard.last_update_success: + return False + + device = dashboard.data.get(self._device_name) + + if device is None: + return False + + try: + noise_psk = await dashboard.api.get_encryption_key(device["configuration"]) + except aiohttp.ClientError as err: + _LOGGER.error("Error talking to the dashboard: %s", err) + return False + + self._noise_psk = noise_psk + return True diff --git a/homeassistant/components/esphome/dashboard.py b/homeassistant/components/esphome/dashboard.py index 4b95fa0d6fd..9e8911f7efe 100644 --- a/homeassistant/components/esphome/dashboard.py +++ b/homeassistant/components/esphome/dashboard.py @@ -1,9 +1,20 @@ """Files to interact with a the ESPHome dashboard.""" from __future__ import annotations -from dataclasses import dataclass +import asyncio +from datetime import timedelta +import logging +from typing import TYPE_CHECKING + +import aiohttp +from esphome_dashboard_api import ConfiguredDevice, ESPHomeDashboardAPI from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers.aiohttp_client import async_get_clientsession +from homeassistant.helpers.update_coordinator import DataUpdateCoordinator + +if TYPE_CHECKING: + pass KEY_DASHBOARD = "esphome_dashboard" @@ -15,14 +26,40 @@ def async_get_dashboard(hass: HomeAssistant) -> ESPHomeDashboard | None: def async_set_dashboard_info( - hass: HomeAssistant, addon_slug: str, _host: str, _port: int + hass: HomeAssistant, addon_slug: str, host: str, port: int ) -> None: """Set the dashboard info.""" - hass.data[KEY_DASHBOARD] = ESPHomeDashboard(addon_slug) + hass.data[KEY_DASHBOARD] = ESPHomeDashboard( + hass, + addon_slug, + f"http://{host}:{port}", + async_get_clientsession(hass), + ) -@dataclass -class ESPHomeDashboard: +class ESPHomeDashboard(DataUpdateCoordinator[dict[str, ConfiguredDevice]]): """Class to interact with the ESPHome dashboard.""" - addon_slug: str + _first_fetch_lock: asyncio.Lock | None = None + + def __init__( + self, + hass: HomeAssistant, + addon_slug: str, + url: str, + session: aiohttp.ClientSession, + ) -> None: + """Initialize.""" + super().__init__( + hass, + logging.getLogger(__name__), + name="ESPHome Dashboard", + update_interval=timedelta(minutes=5), + ) + self.addon_slug = addon_slug + self.api = ESPHomeDashboardAPI(url, session) + + async def _async_update_data(self) -> dict: + """Fetch device data.""" + devices = await self.api.get_devices() + return {dev["name"]: dev for dev in devices["configured"]} diff --git a/homeassistant/components/esphome/manifest.json b/homeassistant/components/esphome/manifest.json index 95b23befccc..014b6d6d6e0 100644 --- a/homeassistant/components/esphome/manifest.json +++ b/homeassistant/components/esphome/manifest.json @@ -3,7 +3,7 @@ "name": "ESPHome", "config_flow": true, "documentation": "https://www.home-assistant.io/integrations/esphome", - "requirements": ["aioesphomeapi==13.0.4"], + "requirements": ["aioesphomeapi==13.0.4", "esphome-dashboard-api==1.1"], "zeroconf": ["_esphomelib._tcp.local."], "dhcp": [{ "registered_devices": true }], "codeowners": ["@OttoWinter", "@jesserockz"], diff --git a/requirements_all.txt b/requirements_all.txt index 781effaa37d..14d9ac3dbe7 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -672,6 +672,9 @@ epson-projector==0.5.0 # homeassistant.components.epsonworkforce epsonprinter==0.0.9 +# homeassistant.components.esphome +esphome-dashboard-api==1.1 + # homeassistant.components.netgear_lte eternalegypt==0.0.12 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 5f8d36435df..f8764b21b0f 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -522,6 +522,9 @@ ephem==4.1.2 # homeassistant.components.epson epson-projector==0.5.0 +# homeassistant.components.esphome +esphome-dashboard-api==1.1 + # homeassistant.components.faa_delays faadelays==0.0.7 diff --git a/tests/components/esphome/conftest.py b/tests/components/esphome/conftest.py index 3382e978a19..c3f7fdd281c 100644 --- a/tests/components/esphome/conftest.py +++ b/tests/components/esphome/conftest.py @@ -3,7 +3,7 @@ from __future__ import annotations from unittest.mock import AsyncMock, Mock, patch -from aioesphomeapi import APIClient +from aioesphomeapi import APIClient, DeviceInfo import pytest from zeroconf import Zeroconf @@ -78,6 +78,14 @@ def mock_client(): return mock_client mock_client.side_effect = mock_constructor + mock_client.device_info = AsyncMock( + return_value=DeviceInfo( + uses_password=False, + name="test", + bluetooth_proxy_version=0, + mac_address="11:22:33:44:55:aa", + ) + ) mock_client.connect = AsyncMock() mock_client.disconnect = AsyncMock() diff --git a/tests/components/esphome/test_config_flow.py b/tests/components/esphome/test_config_flow.py index 8ac1b23eff0..9c49fe0f3f2 100644 --- a/tests/components/esphome/test_config_flow.py +++ b/tests/components/esphome/test_config_flow.py @@ -14,6 +14,7 @@ import pytest from homeassistant import config_entries, data_entry_flow from homeassistant.components import dhcp, zeroconf from homeassistant.components.esphome import ( + CONF_DEVICE_NAME, CONF_NOISE_PSK, DOMAIN, DomainData, @@ -47,12 +48,6 @@ async def test_user_connection_works(hass, mock_client, mock_zeroconf): assert result["type"] == FlowResultType.FORM assert result["step_id"] == "user" - mock_client.device_info = AsyncMock( - return_value=DeviceInfo( - uses_password=False, name="test", mac_address="mock-mac" - ) - ) - result = await hass.config_entries.flow.async_init( "esphome", context={"source": config_entries.SOURCE_USER}, @@ -65,9 +60,10 @@ async def test_user_connection_works(hass, mock_client, mock_zeroconf): CONF_PORT: 80, CONF_PASSWORD: "", CONF_NOISE_PSK: "", + CONF_DEVICE_NAME: "test", } assert result["title"] == "test" - assert result["result"].unique_id == "mock-mac" + assert result["result"].unique_id == "11:22:33:44:55:aa" assert len(mock_client.connect.mock_calls) == 1 assert len(mock_client.device_info.mock_calls) == 1 @@ -83,7 +79,7 @@ async def test_user_connection_updates_host(hass, mock_client, mock_zeroconf): entry = MockConfigEntry( domain=DOMAIN, data={CONF_HOST: "test.local", CONF_PORT: 6053, CONF_PASSWORD: ""}, - unique_id="mock-mac", + unique_id="11:22:33:44:55:aa", ) entry.add_to_hass(hass) result = await hass.config_entries.flow.async_init( @@ -95,12 +91,6 @@ async def test_user_connection_updates_host(hass, mock_client, mock_zeroconf): assert result["type"] == FlowResultType.FORM assert result["step_id"] == "user" - mock_client.device_info = AsyncMock( - return_value=DeviceInfo( - uses_password=False, name="test", mac_address="mock-mac" - ) - ) - result = await hass.config_entries.flow.async_init( "esphome", context={"source": config_entries.SOURCE_USER}, @@ -155,9 +145,7 @@ async def test_user_connection_error(hass, mock_client, mock_zeroconf): async def test_user_with_password(hass, mock_client, mock_zeroconf): """Test user step with password.""" - mock_client.device_info = AsyncMock( - return_value=DeviceInfo(uses_password=True, name="test") - ) + mock_client.device_info.return_value = DeviceInfo(uses_password=True, name="test") result = await hass.config_entries.flow.async_init( "esphome", @@ -178,15 +166,14 @@ async def test_user_with_password(hass, mock_client, mock_zeroconf): CONF_PORT: 6053, CONF_PASSWORD: "password1", CONF_NOISE_PSK: "", + CONF_DEVICE_NAME: "test", } assert mock_client.password == "password1" async def test_user_invalid_password(hass, mock_client, mock_zeroconf): """Test user step with invalid password.""" - mock_client.device_info = AsyncMock( - return_value=DeviceInfo(uses_password=True, name="test") - ) + mock_client.device_info.return_value = DeviceInfo(uses_password=True, name="test") result = await hass.config_entries.flow.async_init( "esphome", @@ -210,9 +197,7 @@ async def test_user_invalid_password(hass, mock_client, mock_zeroconf): async def test_login_connection_error(hass, mock_client, mock_zeroconf): """Test user step with connection error on login attempt.""" - mock_client.device_info = AsyncMock( - return_value=DeviceInfo(uses_password=True, name="test") - ) + mock_client.device_info.return_value = DeviceInfo(uses_password=True, name="test") result = await hass.config_entries.flow.async_init( "esphome", @@ -236,16 +221,10 @@ async def test_login_connection_error(hass, mock_client, mock_zeroconf): async def test_discovery_initiation(hass, mock_client, mock_zeroconf): """Test discovery importing works.""" - mock_client.device_info = AsyncMock( - return_value=DeviceInfo( - uses_password=False, name="test8266", mac_address="11:22:33:44:55:aa" - ) - ) - service_info = zeroconf.ZeroconfServiceInfo( host="192.168.43.183", addresses=["192.168.43.183"], - hostname="test8266.local.", + hostname="test.local.", name="mock_name", port=6053, properties={ @@ -262,7 +241,7 @@ async def test_discovery_initiation(hass, mock_client, mock_zeroconf): ) assert result["type"] == FlowResultType.CREATE_ENTRY - assert result["title"] == "test8266" + assert result["title"] == "test" assert result["data"][CONF_HOST] == "192.168.43.183" assert result["data"][CONF_PORT] == 6053 @@ -320,17 +299,13 @@ async def test_discovery_duplicate_data(hass, mock_client): service_info = zeroconf.ZeroconfServiceInfo( host="192.168.43.183", addresses=["192.168.43.183"], - hostname="test8266.local.", + hostname="test.local.", name="mock_name", port=6053, - properties={"address": "test8266.local", "mac": "1122334455aa"}, + properties={"address": "test.local", "mac": "1122334455aa"}, type="mock_type", ) - mock_client.device_info = AsyncMock( - return_value=DeviceInfo(uses_password=False, name="test8266") - ) - result = await hass.config_entries.flow.async_init( "esphome", data=service_info, context={"source": config_entries.SOURCE_ZEROCONF} ) @@ -419,6 +394,7 @@ async def test_encryption_key_valid_psk(hass, mock_client, mock_zeroconf): CONF_PORT: 6053, CONF_PASSWORD: "", CONF_NOISE_PSK: VALID_NOISE_PSK, + CONF_DEVICE_NAME: "test", } assert mock_client.noise_psk == VALID_NOISE_PSK @@ -485,9 +461,7 @@ async def test_reauth_confirm_valid(hass, mock_client, mock_zeroconf): }, ) - mock_client.device_info = AsyncMock( - return_value=DeviceInfo(uses_password=False, name="test") - ) + mock_client.device_info.return_value = DeviceInfo(uses_password=False, name="test") result = await hass.config_entries.flow.async_configure( result["flow_id"], user_input={CONF_NOISE_PSK: VALID_NOISE_PSK} ) @@ -497,6 +471,53 @@ async def test_reauth_confirm_valid(hass, mock_client, mock_zeroconf): assert entry.data[CONF_NOISE_PSK] == VALID_NOISE_PSK +async def test_reauth_fixed_via_dashboard(hass, mock_client, mock_zeroconf): + """Test reauth fixed automatically via dashboard.""" + dashboard.async_set_dashboard_info(hass, "mock-slug", "mock-host", 6052) + + entry = MockConfigEntry( + domain=DOMAIN, + data={ + CONF_HOST: "127.0.0.1", + CONF_PORT: 6053, + CONF_PASSWORD: "", + CONF_DEVICE_NAME: "test", + }, + ) + entry.add_to_hass(hass) + + mock_client.device_info.return_value = DeviceInfo(uses_password=False, name="test") + + with patch( + "homeassistant.components.esphome.dashboard.ESPHomeDashboardAPI.get_devices", + return_value={ + "configured": [ + { + "name": "test", + "configuration": "test.yaml", + } + ] + }, + ), patch( + "homeassistant.components.esphome.dashboard.ESPHomeDashboardAPI.get_encryption_key", + return_value=VALID_NOISE_PSK, + ) as mock_get_encryption_key: + result = await hass.config_entries.flow.async_init( + "esphome", + context={ + "source": config_entries.SOURCE_REAUTH, + "entry_id": entry.entry_id, + "unique_id": entry.unique_id, + }, + ) + + assert result["type"] == FlowResultType.ABORT + assert result["reason"] == "reauth_successful" + assert entry.data[CONF_NOISE_PSK] == VALID_NOISE_PSK + + assert len(mock_get_encryption_key.mock_calls) == 1 + + async def test_reauth_confirm_invalid(hass, mock_client, mock_zeroconf): """Test reauth initiation with invalid PSK.""" entry = MockConfigEntry( @@ -649,3 +670,104 @@ async def test_discovery_hassio(hass): dash = dashboard.async_get_dashboard(hass) assert dash is not None assert dash.addon_slug == "mock-slug" + + +async def test_zeroconf_encryption_key_via_dashboard(hass, mock_client, mock_zeroconf): + """Test encryption key retrieved from dashboard.""" + service_info = zeroconf.ZeroconfServiceInfo( + host="192.168.43.183", + addresses=["192.168.43.183"], + hostname="test8266.local.", + name="mock_name", + port=6053, + properties={ + "mac": "1122334455aa", + }, + type="mock_type", + ) + flow = await hass.config_entries.flow.async_init( + "esphome", context={"source": config_entries.SOURCE_ZEROCONF}, data=service_info + ) + + assert flow["type"] == FlowResultType.FORM + assert flow["step_id"] == "discovery_confirm" + + dashboard.async_set_dashboard_info(hass, "mock-slug", "mock-host", 6052) + + mock_client.device_info.side_effect = [ + RequiresEncryptionAPIError, + DeviceInfo( + uses_password=False, + name="test8266", + mac_address="11:22:33:44:55:aa", + ), + ] + + with patch( + "homeassistant.components.esphome.dashboard.ESPHomeDashboardAPI.get_devices", + return_value={ + "configured": [ + { + "name": "test8266", + "configuration": "test8266.yaml", + } + ] + }, + ), patch( + "homeassistant.components.esphome.dashboard.ESPHomeDashboardAPI.get_encryption_key", + return_value=VALID_NOISE_PSK, + ) as mock_get_encryption_key: + result = await hass.config_entries.flow.async_configure( + flow["flow_id"], user_input={} + ) + + assert len(mock_get_encryption_key.mock_calls) == 1 + + assert result["type"] == FlowResultType.CREATE_ENTRY + assert result["title"] == "test8266" + assert result["data"][CONF_HOST] == "192.168.43.183" + assert result["data"][CONF_PORT] == 6053 + assert result["data"][CONF_NOISE_PSK] == VALID_NOISE_PSK + + assert result["result"] + assert result["result"].unique_id == "11:22:33:44:55:aa" + + assert mock_client.noise_psk == VALID_NOISE_PSK + + +async def test_zeroconf_no_encryption_key_via_dashboard( + hass, mock_client, mock_zeroconf +): + """Test encryption key not retrieved from dashboard.""" + service_info = zeroconf.ZeroconfServiceInfo( + host="192.168.43.183", + addresses=["192.168.43.183"], + hostname="test8266.local.", + name="mock_name", + port=6053, + properties={ + "mac": "1122334455aa", + }, + type="mock_type", + ) + flow = await hass.config_entries.flow.async_init( + "esphome", context={"source": config_entries.SOURCE_ZEROCONF}, data=service_info + ) + + assert flow["type"] == FlowResultType.FORM + assert flow["step_id"] == "discovery_confirm" + + dashboard.async_set_dashboard_info(hass, "mock-slug", "mock-host", 6052) + + mock_client.device_info.side_effect = RequiresEncryptionAPIError + + with patch( + "homeassistant.components.esphome.dashboard.ESPHomeDashboardAPI.get_devices", + return_value={"configured": []}, + ): + result = await hass.config_entries.flow.async_configure( + flow["flow_id"], user_input={} + ) + + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "encryption_key"