From 1d5ecdd4eae00f7e4e2d657fe3e2c3d74be8920c Mon Sep 17 00:00:00 2001 From: Michael <35783820+mib1185@users.noreply.github.com> Date: Tue, 17 Jan 2023 03:34:42 +0100 Subject: [PATCH] Make API key mandatory for PI-Hole (#85885) add reauth, so make api-key mandatory --- homeassistant/components/pi_hole/__init__.py | 42 +++-- .../components/pi_hole/binary_sensor.py | 14 -- .../components/pi_hole/config_flow.py | 177 ++++++++++-------- homeassistant/components/pi_hole/const.py | 3 - homeassistant/components/pi_hole/strings.json | 17 +- .../components/pi_hole/translations/en.json | 19 +- tests/components/pi_hole/__init__.py | 37 +++- tests/components/pi_hole/test_config_flow.py | 111 ++++++----- tests/components/pi_hole/test_init.py | 83 +++----- 9 files changed, 252 insertions(+), 251 deletions(-) diff --git a/homeassistant/components/pi_hole/__init__.py b/homeassistant/components/pi_hole/__init__.py index 714547ba961..ba7949c0c30 100644 --- a/homeassistant/components/pi_hole/__init__.py +++ b/homeassistant/components/pi_hole/__init__.py @@ -17,7 +17,8 @@ from homeassistant.const import ( CONF_VERIFY_SSL, Platform, ) -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import ConfigEntryAuthFailed from homeassistant.helpers import config_validation as cv from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.entity import DeviceInfo @@ -64,6 +65,13 @@ CONFIG_SCHEMA = vol.Schema( extra=vol.ALLOW_EXTRA, ) +PLATFORMS = [ + Platform.BINARY_SENSOR, + Platform.SENSOR, + Platform.SWITCH, + Platform.UPDATE, +] + async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up the Pi-hole integration.""" @@ -103,11 +111,15 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: location = entry.data[CONF_LOCATION] api_key = entry.data.get(CONF_API_KEY) - # For backward compatibility - if CONF_STATISTICS_ONLY not in entry.data: - hass.config_entries.async_update_entry( - entry, data={**entry.data, CONF_STATISTICS_ONLY: not api_key} - ) + # remove obsolet CONF_STATISTICS_ONLY from entry.data + if CONF_STATISTICS_ONLY in entry.data: + entry_data = entry.data.copy() + entry_data.pop(CONF_STATISTICS_ONLY) + hass.config_entries.async_update_entry(entry, data=entry_data) + + # start reauth to force api key is present + if CONF_API_KEY not in entry.data: + raise ConfigEntryAuthFailed _LOGGER.debug("Setting up %s integration with host %s", DOMAIN, host) @@ -125,8 +137,11 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: try: await api.get_data() await api.get_versions() + _LOGGER.debug("async_update_data() api.data: %s", api.data) except HoleError as err: raise UpdateFailed(f"Failed to communicate with API: {err}") from err + if not isinstance(api.data, dict): + raise ConfigEntryAuthFailed coordinator = DataUpdateCoordinator( hass, @@ -142,30 +157,19 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: await coordinator.async_config_entry_first_refresh() - await hass.config_entries.async_forward_entry_setups(entry, _async_platforms(entry)) + await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) return True async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload Pi-hole entry.""" - unload_ok = await hass.config_entries.async_unload_platforms( - entry, _async_platforms(entry) - ) + unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) if unload_ok: hass.data[DOMAIN].pop(entry.entry_id) return unload_ok -@callback -def _async_platforms(entry: ConfigEntry) -> list[Platform]: - """Return platforms to be loaded / unloaded.""" - platforms = [Platform.BINARY_SENSOR, Platform.UPDATE, Platform.SENSOR] - if not entry.data[CONF_STATISTICS_ONLY]: - platforms.append(Platform.SWITCH) - return platforms - - class PiHoleEntity(CoordinatorEntity): """Representation of a Pi-hole entity.""" diff --git a/homeassistant/components/pi_hole/binary_sensor.py b/homeassistant/components/pi_hole/binary_sensor.py index e887f2ea12f..7d0d9034fad 100644 --- a/homeassistant/components/pi_hole/binary_sensor.py +++ b/homeassistant/components/pi_hole/binary_sensor.py @@ -15,8 +15,6 @@ from homeassistant.helpers.update_coordinator import DataUpdateCoordinator from . import PiHoleEntity from .const import ( BINARY_SENSOR_TYPES, - BINARY_SENSOR_TYPES_STATISTICS_ONLY, - CONF_STATISTICS_ONLY, DATA_KEY_API, DATA_KEY_COORDINATOR, DOMAIN as PIHOLE_DOMAIN, @@ -42,18 +40,6 @@ async def async_setup_entry( for description in BINARY_SENSOR_TYPES ] - if entry.data[CONF_STATISTICS_ONLY]: - binary_sensors += [ - PiHoleBinarySensor( - hole_data[DATA_KEY_API], - hole_data[DATA_KEY_COORDINATOR], - name, - entry.entry_id, - description, - ) - for description in BINARY_SENSOR_TYPES_STATISTICS_ONLY - ] - async_add_entities(binary_sensors, True) diff --git a/homeassistant/components/pi_hole/config_flow.py b/homeassistant/components/pi_hole/config_flow.py index 40f4555e7d2..637f906b9ee 100644 --- a/homeassistant/components/pi_hole/config_flow.py +++ b/homeassistant/components/pi_hole/config_flow.py @@ -1,6 +1,7 @@ """Config flow to configure the Pi-hole integration.""" from __future__ import annotations +from collections.abc import Mapping import logging from typing import Any @@ -26,7 +27,6 @@ from .const import ( DEFAULT_LOCATION, DEFAULT_NAME, DEFAULT_SSL, - DEFAULT_STATISTICS_ONLY, DEFAULT_VERIFY_SSL, DOMAIN, ) @@ -47,65 +47,29 @@ class PiHoleFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): self, user_input: dict[str, Any] | None = None ) -> FlowResult: """Handle a flow initiated by the user.""" - return await self.async_step_init(user_input) - - async def async_step_import( - self, user_input: dict[str, Any] | None = None - ) -> FlowResult: - """Handle a flow initiated by import.""" - return await self.async_step_init(user_input, is_import=True) - - async def async_step_init( - self, user_input: dict[str, Any] | None, is_import: bool = False - ) -> FlowResult: - """Handle init step of a flow.""" errors = {} if user_input is not None: - host = ( - user_input[CONF_HOST] - if is_import - else f"{user_input[CONF_HOST]}:{user_input[CONF_PORT]}" - ) - name = user_input[CONF_NAME] - location = user_input[CONF_LOCATION] - tls = user_input[CONF_SSL] - verify_tls = user_input[CONF_VERIFY_SSL] - endpoint = f"{host}/{location}" + self._config = { + CONF_HOST: f"{user_input[CONF_HOST]}:{user_input[CONF_PORT]}", + CONF_NAME: user_input[CONF_NAME], + CONF_LOCATION: user_input[CONF_LOCATION], + CONF_SSL: user_input[CONF_SSL], + CONF_VERIFY_SSL: user_input[CONF_VERIFY_SSL], + CONF_API_KEY: user_input[CONF_API_KEY], + } - if await self._async_endpoint_existed(endpoint): - return self.async_abort(reason="already_configured") - - try: - await self._async_try_connect(host, location, tls, verify_tls) - except HoleError as ex: - _LOGGER.debug("Connection failed: %s", ex) - if is_import: - _LOGGER.error("Failed to import: %s", ex) - return self.async_abort(reason="cannot_connect") - errors["base"] = "cannot_connect" - else: - self._config = { - CONF_HOST: host, - CONF_NAME: name, - CONF_LOCATION: location, - CONF_SSL: tls, - CONF_VERIFY_SSL: verify_tls, + self._async_abort_entries_match( + { + CONF_HOST: f"{user_input[CONF_HOST]}:{user_input[CONF_PORT]}", + CONF_LOCATION: user_input[CONF_LOCATION], } - if is_import: - api_key = user_input.get(CONF_API_KEY) - return self.async_create_entry( - title=name, - data={ - **self._config, - CONF_STATISTICS_ONLY: api_key is None, - CONF_API_KEY: api_key, - }, - ) - self._config[CONF_STATISTICS_ONLY] = user_input[CONF_STATISTICS_ONLY] - if self._config[CONF_STATISTICS_ONLY]: - return self.async_create_entry(title=name, data=self._config) - return await self.async_step_api_key() + ) + + if not (errors := await self._async_try_connect()): + return self.async_create_entry( + title=user_input[CONF_NAME], data=self._config + ) user_input = user_input or {} return self.async_show_form( @@ -116,6 +80,7 @@ class PiHoleFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): vol.Required( CONF_PORT, default=user_input.get(CONF_PORT, 80) ): vol.Coerce(int), + vol.Required(CONF_API_KEY): str, vol.Required( CONF_NAME, default=user_input.get(CONF_NAME, DEFAULT_NAME) ): str, @@ -123,12 +88,6 @@ class PiHoleFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): CONF_LOCATION, default=user_input.get(CONF_LOCATION, DEFAULT_LOCATION), ): str, - vol.Required( - CONF_STATISTICS_ONLY, - default=user_input.get( - CONF_STATISTICS_ONLY, DEFAULT_STATISTICS_ONLY - ), - ): bool, vol.Required( CONF_SSL, default=user_input.get(CONF_SSL, DEFAULT_SSL), @@ -142,24 +101,94 @@ class PiHoleFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): errors=errors, ) - async def async_step_api_key( - self, user_input: dict[str, Any] | None = None + async def async_step_import(self, user_input: dict[str, Any]) -> FlowResult: + """Handle a flow initiated by import.""" + + host = user_input[CONF_HOST] + name = user_input[CONF_NAME] + location = user_input[CONF_LOCATION] + tls = user_input[CONF_SSL] + verify_tls = user_input[CONF_VERIFY_SSL] + endpoint = f"{host}/{location}" + + if await self._async_endpoint_existed(endpoint): + return self.async_abort(reason="already_configured") + + try: + await self._async_try_connect_legacy(host, location, tls, verify_tls) + except HoleError as ex: + _LOGGER.debug("Connection failed: %s", ex) + _LOGGER.error("Failed to import: %s", ex) + return self.async_abort(reason="cannot_connect") + self._config = { + CONF_HOST: host, + CONF_NAME: name, + CONF_LOCATION: location, + CONF_SSL: tls, + CONF_VERIFY_SSL: verify_tls, + } + api_key = user_input.get(CONF_API_KEY) + return self.async_create_entry( + title=name, + data={ + **self._config, + CONF_STATISTICS_ONLY: api_key is None, + CONF_API_KEY: api_key, + }, + ) + + async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult: + """Perform reauth upon an API authentication error.""" + self._config = dict(entry_data) + return await self.async_step_reauth_confirm() + + async def async_step_reauth_confirm( + self, + user_input: dict[str, Any] | None = None, ) -> FlowResult: - """Handle step to setup API key.""" + """Perform reauth confirm upon an API authentication error.""" + errors = {} if user_input is not None: - return self.async_create_entry( - title=self._config[CONF_NAME], - data={ - **self._config, - CONF_API_KEY: user_input.get(CONF_API_KEY, ""), - }, - ) + self._config = {**self._config, CONF_API_KEY: user_input[CONF_API_KEY]} + if not (errors := await self._async_try_connect()): + entry = self.hass.config_entries.async_get_entry( + self.context["entry_id"] + ) + assert entry + self.hass.config_entries.async_update_entry(entry, data=self._config) + self.hass.async_create_task( + self.hass.config_entries.async_reload(self.context["entry_id"]) + ) + return self.async_abort(reason="reauth_successful") return self.async_show_form( - step_id="api_key", - data_schema=vol.Schema({vol.Optional(CONF_API_KEY): str}), + step_id="reauth_confirm", + description_placeholders={ + CONF_HOST: self._config[CONF_HOST], + CONF_LOCATION: self._config[CONF_LOCATION], + }, + data_schema=vol.Schema({vol.Required(CONF_API_KEY): str}), + errors=errors, ) + async def _async_try_connect(self) -> dict[str, str]: + session = async_get_clientsession(self.hass, self._config[CONF_VERIFY_SSL]) + pi_hole = Hole( + self._config[CONF_HOST], + session, + location=self._config[CONF_LOCATION], + tls=self._config[CONF_SSL], + api_token=self._config[CONF_API_KEY], + ) + try: + await pi_hole.get_data() + except HoleError as ex: + _LOGGER.debug("Connection failed: %s", ex) + return {"base": "cannot_connect"} + if not isinstance(pi_hole.data, dict): + return {CONF_API_KEY: "invalid_auth"} + return {} + async def _async_endpoint_existed(self, endpoint: str) -> bool: existing_endpoints = [ f"{entry.data.get(CONF_HOST)}/{entry.data.get(CONF_LOCATION)}" @@ -167,7 +196,7 @@ class PiHoleFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): ] return endpoint in existing_endpoints - async def _async_try_connect( + async def _async_try_connect_legacy( self, host: str, location: str, tls: bool, verify_tls: bool ) -> None: session = async_get_clientsession(self.hass, verify_tls) diff --git a/homeassistant/components/pi_hole/const.py b/homeassistant/components/pi_hole/const.py index c73660faedb..a9bc5824ad9 100644 --- a/homeassistant/components/pi_hole/const.py +++ b/homeassistant/components/pi_hole/const.py @@ -154,9 +154,6 @@ BINARY_SENSOR_TYPES: tuple[PiHoleBinarySensorEntityDescription, ...] = ( }, state_value=lambda api: bool(api.versions["FTL_update"]), ), -) - -BINARY_SENSOR_TYPES_STATISTICS_ONLY: tuple[PiHoleBinarySensorEntityDescription, ...] = ( PiHoleBinarySensorEntityDescription( key="status", name="Status", diff --git a/homeassistant/components/pi_hole/strings.json b/homeassistant/components/pi_hole/strings.json index e911779d5d7..120ab8cb80a 100644 --- a/homeassistant/components/pi_hole/strings.json +++ b/homeassistant/components/pi_hole/strings.json @@ -8,28 +8,25 @@ "name": "[%key:common::config_flow::data::name%]", "location": "[%key:common::config_flow::data::location%]", "api_key": "[%key:common::config_flow::data::api_key%]", - "statistics_only": "Statistics Only", "ssl": "[%key:common::config_flow::data::ssl%]", "verify_ssl": "[%key:common::config_flow::data::verify_ssl%]" } }, - "api_key": { + "reauth_confirm": { + "title": "PI-Hole [%key:common::config_flow::title::reauth%]", + "description": "Please enter a new api key for PI-Hole at {host}/{location}", "data": { "api_key": "[%key:common::config_flow::data::api_key%]" } } }, "error": { - "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]" + "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", + "invalid_auth": "[%key:common::config_flow::error::invalid_auth%]" }, "abort": { - "already_configured": "[%key:common::config_flow::abort::already_configured_service%]" - } - }, - "issues": { - "deprecated_yaml": { - "title": "The PI-Hole YAML configuration is being removed", - "description": "Configuring PI-Hole using YAML is being removed.\n\nYour existing YAML configuration has been imported into the UI automatically.\n\nRemove the PI-Hole YAML configuration from your configuration.yaml file and restart Home Assistant to fix this issue." + "already_configured": "[%key:common::config_flow::abort::already_configured_service%]", + "reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]" } } } diff --git a/homeassistant/components/pi_hole/translations/en.json b/homeassistant/components/pi_hole/translations/en.json index 4333838ae64..815182731c2 100644 --- a/homeassistant/components/pi_hole/translations/en.json +++ b/homeassistant/components/pi_hole/translations/en.json @@ -1,16 +1,20 @@ { "config": { "abort": { - "already_configured": "Service is already configured" + "already_configured": "Service is already configured", + "reauth_successful": "Re-authentication was successful" }, "error": { - "cannot_connect": "Failed to connect" + "cannot_connect": "Failed to connect", + "invalid_auth": "Invalid authentication" }, "step": { - "api_key": { + "reauth_confirm": { "data": { "api_key": "API Key" - } + }, + "description": "Please enter a new api key for PI-Hole at {host}/{location}", + "title": "PI-Hole Reauthenticate Integration" }, "user": { "data": { @@ -20,16 +24,9 @@ "name": "Name", "port": "Port", "ssl": "Uses an SSL certificate", - "statistics_only": "Statistics Only", "verify_ssl": "Verify SSL certificate" } } } - }, - "issues": { - "deprecated_yaml": { - "description": "Configuring PI-Hole using YAML is being removed.\n\nYour existing YAML configuration has been imported into the UI automatically.\n\nRemove the PI-Hole YAML configuration from your configuration.yaml file and restart Home Assistant to fix this issue.", - "title": "The PI-Hole YAML configuration is being removed" - } } } \ No newline at end of file diff --git a/tests/components/pi_hole/__init__.py b/tests/components/pi_hole/__init__.py index 57ea89fc7e0..49e15391f8c 100644 --- a/tests/components/pi_hole/__init__.py +++ b/tests/components/pi_hole/__init__.py @@ -3,7 +3,13 @@ from unittest.mock import AsyncMock, MagicMock, patch from hole.exceptions import HoleError -from homeassistant.components.pi_hole.const import CONF_STATISTICS_ONLY +from homeassistant.components.pi_hole.const import ( + CONF_STATISTICS_ONLY, + DEFAULT_LOCATION, + DEFAULT_NAME, + DEFAULT_SSL, + DEFAULT_VERIFY_SSL, +) from homeassistant.const import ( CONF_API_KEY, CONF_HOST, @@ -47,7 +53,16 @@ API_KEY = "apikey" SSL = False VERIFY_SSL = True -CONF_DATA = { +CONFIG_DATA_DEFAULTS = { + CONF_HOST: f"{HOST}:{PORT}", + CONF_LOCATION: DEFAULT_LOCATION, + CONF_NAME: DEFAULT_NAME, + CONF_SSL: DEFAULT_SSL, + CONF_VERIFY_SSL: DEFAULT_VERIFY_SSL, + CONF_API_KEY: API_KEY, +} + +CONFIG_DATA = { CONF_HOST: f"{HOST}:{PORT}", CONF_LOCATION: LOCATION, CONF_NAME: NAME, @@ -56,34 +71,35 @@ CONF_DATA = { CONF_VERIFY_SSL: VERIFY_SSL, } -CONF_CONFIG_FLOW_USER = { +CONFIG_FLOW_USER = { CONF_HOST: HOST, CONF_PORT: PORT, + CONF_API_KEY: API_KEY, CONF_LOCATION: LOCATION, CONF_NAME: NAME, - CONF_STATISTICS_ONLY: False, CONF_SSL: SSL, CONF_VERIFY_SSL: VERIFY_SSL, } -CONF_CONFIG_FLOW_API_KEY = { +CONFIG_FLOW_API_KEY = { CONF_API_KEY: API_KEY, } -CONF_CONFIG_ENTRY = { +CONFIG_ENTRY = { CONF_HOST: f"{HOST}:{PORT}", CONF_LOCATION: LOCATION, CONF_NAME: NAME, - CONF_STATISTICS_ONLY: False, CONF_API_KEY: API_KEY, CONF_SSL: SSL, CONF_VERIFY_SSL: VERIFY_SSL, } +CONFIG_ENTRY_IMPORTED = {**CONFIG_ENTRY, CONF_STATISTICS_ONLY: False} + SWITCH_ENTITY_ID = "switch.pi_hole" -def _create_mocked_hole(raise_exception=False, has_versions=True): +def _create_mocked_hole(raise_exception=False, has_versions=True, has_data=True): mocked_hole = MagicMock() type(mocked_hole).get_data = AsyncMock( side_effect=HoleError("") if raise_exception else None @@ -93,7 +109,10 @@ def _create_mocked_hole(raise_exception=False, has_versions=True): ) type(mocked_hole).enable = AsyncMock() type(mocked_hole).disable = AsyncMock() - mocked_hole.data = ZERO_DATA + if has_data: + mocked_hole.data = ZERO_DATA + else: + mocked_hole.data = [] if has_versions: mocked_hole.versions = SAMPLE_VERSIONS else: diff --git a/tests/components/pi_hole/test_config_flow.py b/tests/components/pi_hole/test_config_flow.py index bc86922c89f..65f21418bad 100644 --- a/tests/components/pi_hole/test_config_flow.py +++ b/tests/components/pi_hole/test_config_flow.py @@ -2,28 +2,26 @@ import logging from unittest.mock import patch -from homeassistant.components.pi_hole.const import CONF_STATISTICS_ONLY, DOMAIN +from homeassistant.components.pi_hole.const import DOMAIN from homeassistant.config_entries import SOURCE_IMPORT, SOURCE_USER from homeassistant.const import CONF_API_KEY +from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType from . import ( - CONF_CONFIG_ENTRY, - CONF_CONFIG_FLOW_API_KEY, - CONF_CONFIG_FLOW_USER, - CONF_DATA, + CONFIG_DATA, + CONFIG_DATA_DEFAULTS, + CONFIG_ENTRY, + CONFIG_ENTRY_IMPORTED, + CONFIG_FLOW_USER, NAME, + ZERO_DATA, _create_mocked_hole, _patch_config_flow_hole, + _patch_init_hole, ) - -def _flow_next(hass, flow_id): - return next( - flow - for flow in hass.config_entries.flow.async_progress() - if flow["flow_id"] == flow_id - ) +from tests.common import MockConfigEntry def _patch_setup(): @@ -33,41 +31,41 @@ def _patch_setup(): ) -async def test_flow_import(hass, caplog): +async def test_flow_import(hass: HomeAssistant, caplog): """Test import flow.""" mocked_hole = _create_mocked_hole() with _patch_config_flow_hole(mocked_hole), _patch_setup(): result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": SOURCE_IMPORT}, data=CONF_DATA + DOMAIN, context={"source": SOURCE_IMPORT}, data=CONFIG_DATA ) assert result["type"] == FlowResultType.CREATE_ENTRY assert result["title"] == NAME - assert result["data"] == CONF_CONFIG_ENTRY + assert result["data"] == CONFIG_ENTRY_IMPORTED # duplicated server result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": SOURCE_IMPORT}, data=CONF_DATA + DOMAIN, context={"source": SOURCE_IMPORT}, data=CONFIG_DATA ) assert result["type"] == FlowResultType.ABORT assert result["reason"] == "already_configured" -async def test_flow_import_invalid(hass, caplog): +async def test_flow_import_invalid(hass: HomeAssistant, caplog): """Test import flow with invalid server.""" mocked_hole = _create_mocked_hole(True) with _patch_config_flow_hole(mocked_hole), _patch_setup(): result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": SOURCE_IMPORT}, data=CONF_DATA + DOMAIN, context={"source": SOURCE_IMPORT}, data=CONFIG_DATA ) assert result["type"] == FlowResultType.ABORT assert result["reason"] == "cannot_connect" assert len([x for x in caplog.records if x.levelno == logging.ERROR]) == 1 -async def test_flow_user(hass): +async def test_flow_user(hass: HomeAssistant): """Test user initialized flow.""" - mocked_hole = _create_mocked_hole() - with _patch_config_flow_hole(mocked_hole), _patch_setup(): + mocked_hole = _create_mocked_hole(has_data=False) + with _patch_config_flow_hole(mocked_hole), _patch_init_hole(mocked_hole): result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": SOURCE_USER}, @@ -75,69 +73,68 @@ async def test_flow_user(hass): assert result["type"] == FlowResultType.FORM assert result["step_id"] == "user" assert result["errors"] == {} - _flow_next(hass, result["flow_id"]) result = await hass.config_entries.flow.async_configure( result["flow_id"], - user_input=CONF_CONFIG_FLOW_USER, + user_input=CONFIG_FLOW_USER, ) assert result["type"] == FlowResultType.FORM - assert result["step_id"] == "api_key" - assert result["errors"] is None - _flow_next(hass, result["flow_id"]) + assert result["step_id"] == "user" + assert result["errors"] == {CONF_API_KEY: "invalid_auth"} + mocked_hole.data = ZERO_DATA result = await hass.config_entries.flow.async_configure( result["flow_id"], - user_input=CONF_CONFIG_FLOW_API_KEY, + user_input=CONFIG_FLOW_USER, ) assert result["type"] == FlowResultType.CREATE_ENTRY assert result["title"] == NAME - assert result["data"] == CONF_CONFIG_ENTRY + assert result["data"] == CONFIG_ENTRY # duplicated server result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": SOURCE_USER}, - data=CONF_CONFIG_FLOW_USER, + data=CONFIG_FLOW_USER, ) assert result["type"] == FlowResultType.ABORT assert result["reason"] == "already_configured" -async def test_flow_statistics_only(hass): - """Test user initialized flow with statistics only.""" - mocked_hole = _create_mocked_hole() - with _patch_config_flow_hole(mocked_hole), _patch_setup(): - result = await hass.config_entries.flow.async_init( - DOMAIN, - context={"source": SOURCE_USER}, - ) - assert result["type"] == FlowResultType.FORM - assert result["step_id"] == "user" - assert result["errors"] == {} - _flow_next(hass, result["flow_id"]) - - user_input = {**CONF_CONFIG_FLOW_USER} - user_input[CONF_STATISTICS_ONLY] = True - config_entry_data = {**CONF_CONFIG_ENTRY} - config_entry_data[CONF_STATISTICS_ONLY] = True - config_entry_data.pop(CONF_API_KEY) - result = await hass.config_entries.flow.async_configure( - result["flow_id"], - user_input=user_input, - ) - assert result["type"] == FlowResultType.CREATE_ENTRY - assert result["title"] == NAME - assert result["data"] == config_entry_data - - async def test_flow_user_invalid(hass): """Test user initialized flow with invalid server.""" mocked_hole = _create_mocked_hole(True) with _patch_config_flow_hole(mocked_hole), _patch_setup(): result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": SOURCE_USER}, data=CONF_CONFIG_FLOW_USER + DOMAIN, context={"source": SOURCE_USER}, data=CONFIG_FLOW_USER ) assert result["type"] == FlowResultType.FORM assert result["step_id"] == "user" assert result["errors"] == {"base": "cannot_connect"} + + +async def test_flow_reauth(hass: HomeAssistant): + """Test reauth flow.""" + mocked_hole = _create_mocked_hole(has_data=False) + entry = MockConfigEntry(domain=DOMAIN, data=CONFIG_DATA_DEFAULTS) + entry.add_to_hass(hass) + with _patch_init_hole(mocked_hole), _patch_config_flow_hole(mocked_hole): + assert not await hass.config_entries.async_setup(entry.entry_id) + + flows = hass.config_entries.flow.async_progress() + + assert len(flows) == 1 + assert flows[0]["step_id"] == "reauth_confirm" + assert flows[0]["context"]["entry_id"] == entry.entry_id + + mocked_hole.data = ZERO_DATA + + result = await hass.config_entries.flow.async_configure( + flows[0]["flow_id"], + user_input={CONF_API_KEY: "newkey"}, + ) + + await hass.async_block_till_done() + assert result["type"] == FlowResultType.ABORT + assert result["reason"] == "reauth_successful" + assert entry.data[CONF_API_KEY] == "newkey" diff --git a/tests/components/pi_hole/test_init.py b/tests/components/pi_hole/test_init.py index dce3773acdc..75d9dd27aee 100644 --- a/tests/components/pi_hole/test_init.py +++ b/tests/components/pi_hole/test_init.py @@ -7,27 +7,16 @@ from hole.exceptions import HoleError from homeassistant.components import pi_hole, switch from homeassistant.components.pi_hole.const import ( CONF_STATISTICS_ONLY, - DEFAULT_LOCATION, - DEFAULT_NAME, - DEFAULT_SSL, - DEFAULT_VERIFY_SSL, SERVICE_DISABLE, SERVICE_DISABLE_ATTR_DURATION, ) -from homeassistant.const import ( - ATTR_ENTITY_ID, - CONF_API_KEY, - CONF_HOST, - CONF_LOCATION, - CONF_NAME, - CONF_SSL, - CONF_VERIFY_SSL, -) +from homeassistant.config_entries import ConfigEntryState +from homeassistant.const import ATTR_ENTITY_ID, CONF_API_KEY, CONF_HOST +from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component from . import ( - CONF_CONFIG_ENTRY, - CONF_DATA, + CONFIG_DATA_DEFAULTS, SWITCH_ENTITY_ID, _create_mocked_hole, _patch_config_flow_hole, @@ -37,7 +26,7 @@ from . import ( from tests.common import MockConfigEntry -async def test_setup_minimal_config(hass): +async def test_setup_minimal_config(hass: HomeAssistant): """Tests component setup with minimal config.""" mocked_hole = _create_mocked_hole() with _patch_config_flow_hole(mocked_hole), _patch_init_hole(mocked_hole): @@ -88,7 +77,7 @@ async def test_setup_minimal_config(hass): assert state.state == "off" -async def test_setup_name_config(hass): +async def test_setup_name_config(hass: HomeAssistant): """Tests component setup with a custom name.""" mocked_hole = _create_mocked_hole() with _patch_config_flow_hole(mocked_hole), _patch_init_hole(mocked_hole): @@ -106,7 +95,7 @@ async def test_setup_name_config(hass): ) -async def test_switch(hass, caplog): +async def test_switch(hass: HomeAssistant, caplog): """Test Pi-hole switch.""" mocked_hole = _create_mocked_hole() with _patch_config_flow_hole(mocked_hole), _patch_init_hole(mocked_hole): @@ -154,7 +143,7 @@ async def test_switch(hass, caplog): assert errors[-1].message == "Unable to disable Pi-hole: Error2" -async def test_disable_service_call(hass): +async def test_disable_service_call(hass: HomeAssistant): """Test disable service call with no Pi-hole named.""" mocked_hole = _create_mocked_hole() with _patch_config_flow_hole(mocked_hole), _patch_init_hole(mocked_hole): @@ -180,21 +169,14 @@ async def test_disable_service_call(hass): await hass.async_block_till_done() - mocked_hole.disable.assert_called_once_with(1) + mocked_hole.disable.assert_called_with(1) -async def test_unload(hass): +async def test_unload(hass: HomeAssistant): """Test unload entities.""" entry = MockConfigEntry( domain=pi_hole.DOMAIN, - data={ - CONF_NAME: DEFAULT_NAME, - CONF_HOST: "pi.hole", - CONF_LOCATION: DEFAULT_LOCATION, - CONF_SSL: DEFAULT_SSL, - CONF_VERIFY_SSL: DEFAULT_VERIFY_SSL, - CONF_STATISTICS_ONLY: True, - }, + data={**CONFIG_DATA_DEFAULTS, CONF_HOST: "pi.hole"}, ) entry.add_to_hass(hass) mocked_hole = _create_mocked_hole() @@ -208,32 +190,25 @@ async def test_unload(hass): assert entry.entry_id not in hass.data[pi_hole.DOMAIN] -async def test_migrate(hass): - """Test migrate from old config entry.""" - entry = MockConfigEntry(domain=pi_hole.DOMAIN, data=CONF_DATA) - entry.add_to_hass(hass) - +async def test_remove_obsolete(hass: HomeAssistant): + """Test removing obsolete config entry parameters.""" mocked_hole = _create_mocked_hole() - with _patch_config_flow_hole(mocked_hole), _patch_init_hole(mocked_hole): - await hass.config_entries.async_setup(entry.entry_id) - await hass.async_block_till_done() - - assert entry.data == CONF_CONFIG_ENTRY - - -async def test_migrate_statistics_only(hass): - """Test migrate from old config entry with statistics only.""" - conf_data = {**CONF_DATA} - conf_data[CONF_API_KEY] = "" - entry = MockConfigEntry(domain=pi_hole.DOMAIN, data=conf_data) + entry = MockConfigEntry( + domain=pi_hole.DOMAIN, data={**CONFIG_DATA_DEFAULTS, CONF_STATISTICS_ONLY: True} + ) entry.add_to_hass(hass) + with _patch_init_hole(mocked_hole): + assert await hass.config_entries.async_setup(entry.entry_id) + assert CONF_STATISTICS_ONLY not in entry.data + +async def test_missing_api_key(hass: HomeAssistant): + """Tests start reauth flow if api key is missing.""" mocked_hole = _create_mocked_hole() - with _patch_config_flow_hole(mocked_hole), _patch_init_hole(mocked_hole): - await hass.config_entries.async_setup(entry.entry_id) - await hass.async_block_till_done() - - config_entry_data = {**CONF_CONFIG_ENTRY} - config_entry_data[CONF_STATISTICS_ONLY] = True - config_entry_data[CONF_API_KEY] = "" - assert entry.data == config_entry_data + data = CONFIG_DATA_DEFAULTS.copy() + data.pop(CONF_API_KEY) + entry = MockConfigEntry(domain=pi_hole.DOMAIN, data=data) + entry.add_to_hass(hass) + with _patch_init_hole(mocked_hole): + assert not await hass.config_entries.async_setup(entry.entry_id) + assert entry.state == ConfigEntryState.SETUP_ERROR