diff --git a/homeassistant/components/pi_hole/__init__.py b/homeassistant/components/pi_hole/__init__.py index fcd4451bb0b..ac42410604f 100644 --- a/homeassistant/components/pi_hole/__init__.py +++ b/homeassistant/components/pi_hole/__init__.py @@ -16,7 +16,8 @@ from homeassistant.const import ( CONF_VERIFY_SSL, Platform, ) -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import ConfigEntryAuthFailed from homeassistant.helpers import config_validation as cv from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.entity import DeviceInfo @@ -38,6 +39,13 @@ _LOGGER = logging.getLogger(__name__) CONFIG_SCHEMA = cv.removed(DOMAIN, raise_if_present=False) +PLATFORMS = [ + Platform.BINARY_SENSOR, + Platform.SENSOR, + Platform.SWITCH, + Platform.UPDATE, +] + async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up Pi-hole entry.""" @@ -48,11 +56,15 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: location = entry.data[CONF_LOCATION] api_key = entry.data.get(CONF_API_KEY) - # For backward compatibility - if CONF_STATISTICS_ONLY not in entry.data: - hass.config_entries.async_update_entry( - entry, data={**entry.data, CONF_STATISTICS_ONLY: not api_key} - ) + # remove obsolet CONF_STATISTICS_ONLY from entry.data + if CONF_STATISTICS_ONLY in entry.data: + entry_data = entry.data.copy() + entry_data.pop(CONF_STATISTICS_ONLY) + hass.config_entries.async_update_entry(entry, data=entry_data) + + # start reauth to force api key is present + if CONF_API_KEY not in entry.data: + raise ConfigEntryAuthFailed _LOGGER.debug("Setting up %s integration with host %s", DOMAIN, host) @@ -72,6 +84,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: await api.get_versions() except HoleError as err: raise UpdateFailed(f"Failed to communicate with API: {err}") from err + if not isinstance(api.data, dict): + raise ConfigEntryAuthFailed coordinator = DataUpdateCoordinator( hass, @@ -89,30 +103,19 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: await coordinator.async_config_entry_first_refresh() - await hass.config_entries.async_forward_entry_setups(entry, _async_platforms(entry)) + await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) return True async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload Pi-hole entry.""" - unload_ok = await hass.config_entries.async_unload_platforms( - entry, _async_platforms(entry) - ) + unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) if unload_ok: hass.data[DOMAIN].pop(entry.entry_id) return unload_ok -@callback -def _async_platforms(entry: ConfigEntry) -> list[Platform]: - """Return platforms to be loaded / unloaded.""" - platforms = [Platform.BINARY_SENSOR, Platform.UPDATE, Platform.SENSOR] - if not entry.data[CONF_STATISTICS_ONLY]: - platforms.append(Platform.SWITCH) - return platforms - - class PiHoleEntity(CoordinatorEntity): """Representation of a Pi-hole entity.""" diff --git a/homeassistant/components/pi_hole/binary_sensor.py b/homeassistant/components/pi_hole/binary_sensor.py index e887f2ea12f..7d0d9034fad 100644 --- a/homeassistant/components/pi_hole/binary_sensor.py +++ b/homeassistant/components/pi_hole/binary_sensor.py @@ -15,8 +15,6 @@ from homeassistant.helpers.update_coordinator import DataUpdateCoordinator from . import PiHoleEntity from .const import ( BINARY_SENSOR_TYPES, - BINARY_SENSOR_TYPES_STATISTICS_ONLY, - CONF_STATISTICS_ONLY, DATA_KEY_API, DATA_KEY_COORDINATOR, DOMAIN as PIHOLE_DOMAIN, @@ -42,18 +40,6 @@ async def async_setup_entry( for description in BINARY_SENSOR_TYPES ] - if entry.data[CONF_STATISTICS_ONLY]: - binary_sensors += [ - PiHoleBinarySensor( - hole_data[DATA_KEY_API], - hole_data[DATA_KEY_COORDINATOR], - name, - entry.entry_id, - description, - ) - for description in BINARY_SENSOR_TYPES_STATISTICS_ONLY - ] - async_add_entities(binary_sensors, True) diff --git a/homeassistant/components/pi_hole/config_flow.py b/homeassistant/components/pi_hole/config_flow.py index 40a19cf416e..48cf93cbe33 100644 --- a/homeassistant/components/pi_hole/config_flow.py +++ b/homeassistant/components/pi_hole/config_flow.py @@ -1,6 +1,7 @@ """Config flow to configure the Pi-hole integration.""" from __future__ import annotations +from collections.abc import Mapping import logging from typing import Any @@ -22,11 +23,9 @@ from homeassistant.data_entry_flow import FlowResult from homeassistant.helpers.aiohttp_client import async_get_clientsession from .const import ( - CONF_STATISTICS_ONLY, DEFAULT_LOCATION, DEFAULT_NAME, DEFAULT_SSL, - DEFAULT_STATISTICS_ONLY, DEFAULT_VERIFY_SSL, DOMAIN, ) @@ -47,59 +46,29 @@ class PiHoleFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): self, user_input: dict[str, Any] | None = None ) -> FlowResult: """Handle a flow initiated by the user.""" - return await self.async_step_init(user_input) - - async def async_step_init( - self, user_input: dict[str, Any] | None, is_import: bool = False - ) -> FlowResult: - """Handle init step of a flow.""" errors = {} if user_input is not None: - host = ( - user_input[CONF_HOST] - if is_import - else f"{user_input[CONF_HOST]}:{user_input[CONF_PORT]}" - ) - name = user_input[CONF_NAME] - location = user_input[CONF_LOCATION] - tls = user_input[CONF_SSL] - verify_tls = user_input[CONF_VERIFY_SSL] - endpoint = f"{host}/{location}" + self._config = { + CONF_HOST: f"{user_input[CONF_HOST]}:{user_input[CONF_PORT]}", + CONF_NAME: user_input[CONF_NAME], + CONF_LOCATION: user_input[CONF_LOCATION], + CONF_SSL: user_input[CONF_SSL], + CONF_VERIFY_SSL: user_input[CONF_VERIFY_SSL], + CONF_API_KEY: user_input[CONF_API_KEY], + } - if await self._async_endpoint_existed(endpoint): - return self.async_abort(reason="already_configured") - - try: - await self._async_try_connect(host, location, tls, verify_tls) - except HoleError as ex: - _LOGGER.debug("Connection failed: %s", ex) - if is_import: - _LOGGER.error("Failed to import: %s", ex) - return self.async_abort(reason="cannot_connect") - errors["base"] = "cannot_connect" - else: - self._config = { - CONF_HOST: host, - CONF_NAME: name, - CONF_LOCATION: location, - CONF_SSL: tls, - CONF_VERIFY_SSL: verify_tls, + self._async_abort_entries_match( + { + CONF_HOST: f"{user_input[CONF_HOST]}:{user_input[CONF_PORT]}", + CONF_LOCATION: user_input[CONF_LOCATION], } - if is_import: - api_key = user_input.get(CONF_API_KEY) - return self.async_create_entry( - title=name, - data={ - **self._config, - CONF_STATISTICS_ONLY: api_key is None, - CONF_API_KEY: api_key, - }, - ) - self._config[CONF_STATISTICS_ONLY] = user_input[CONF_STATISTICS_ONLY] - if self._config[CONF_STATISTICS_ONLY]: - return self.async_create_entry(title=name, data=self._config) - return await self.async_step_api_key() + ) + + if not (errors := await self._async_try_connect()): + return self.async_create_entry( + title=user_input[CONF_NAME], data=self._config + ) user_input = user_input or {} return self.async_show_form( @@ -110,6 +79,7 @@ class PiHoleFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): vol.Required( CONF_PORT, default=user_input.get(CONF_PORT, 80) ): vol.Coerce(int), + vol.Required(CONF_API_KEY): str, vol.Required( CONF_NAME, default=user_input.get(CONF_NAME, DEFAULT_NAME) ): str, @@ -117,12 +87,6 @@ class PiHoleFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): CONF_LOCATION, default=user_input.get(CONF_LOCATION, DEFAULT_LOCATION), ): str, - vol.Required( - CONF_STATISTICS_ONLY, - default=user_input.get( - CONF_STATISTICS_ONLY, DEFAULT_STATISTICS_ONLY - ), - ): bool, vol.Required( CONF_SSL, default=user_input.get(CONF_SSL, DEFAULT_SSL), @@ -136,34 +100,54 @@ class PiHoleFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): errors=errors, ) - async def async_step_api_key( - self, user_input: dict[str, Any] | None = None + async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult: + """Perform reauth upon an API authentication error.""" + self._config = dict(entry_data) + return await self.async_step_reauth_confirm() + + async def async_step_reauth_confirm( + self, + user_input: dict[str, Any] | None = None, ) -> FlowResult: - """Handle step to setup API key.""" + """Perform reauth confirm upon an API authentication error.""" + errors = {} if user_input is not None: - return self.async_create_entry( - title=self._config[CONF_NAME], - data={ - **self._config, - CONF_API_KEY: user_input.get(CONF_API_KEY, ""), - }, - ) + self._config = {**self._config, CONF_API_KEY: user_input[CONF_API_KEY]} + if not (errors := await self._async_try_connect()): + entry = self.hass.config_entries.async_get_entry( + self.context["entry_id"] + ) + assert entry + self.hass.config_entries.async_update_entry(entry, data=self._config) + self.hass.async_create_task( + self.hass.config_entries.async_reload(self.context["entry_id"]) + ) + return self.async_abort(reason="reauth_successful") return self.async_show_form( - step_id="api_key", - data_schema=vol.Schema({vol.Optional(CONF_API_KEY): str}), + step_id="reauth_confirm", + description_placeholders={ + CONF_HOST: self._config[CONF_HOST], + CONF_LOCATION: self._config[CONF_LOCATION], + }, + data_schema=vol.Schema({vol.Required(CONF_API_KEY): str}), + errors=errors, ) - async def _async_endpoint_existed(self, endpoint: str) -> bool: - existing_endpoints = [ - f"{entry.data.get(CONF_HOST)}/{entry.data.get(CONF_LOCATION)}" - for entry in self._async_current_entries() - ] - return endpoint in existing_endpoints - - async def _async_try_connect( - self, host: str, location: str, tls: bool, verify_tls: bool - ) -> None: - session = async_get_clientsession(self.hass, verify_tls) - pi_hole = Hole(host, session, location=location, tls=tls) - await pi_hole.get_data() + async def _async_try_connect(self) -> dict[str, str]: + session = async_get_clientsession(self.hass, self._config[CONF_VERIFY_SSL]) + pi_hole = Hole( + self._config[CONF_HOST], + session, + location=self._config[CONF_LOCATION], + tls=self._config[CONF_SSL], + api_token=self._config[CONF_API_KEY], + ) + try: + await pi_hole.get_data() + except HoleError as ex: + _LOGGER.debug("Connection failed: %s", ex) + return {"base": "cannot_connect"} + if not isinstance(pi_hole.data, dict): + return {CONF_API_KEY: "invalid_auth"} + return {} diff --git a/homeassistant/components/pi_hole/const.py b/homeassistant/components/pi_hole/const.py index c73660faedb..a9bc5824ad9 100644 --- a/homeassistant/components/pi_hole/const.py +++ b/homeassistant/components/pi_hole/const.py @@ -154,9 +154,6 @@ BINARY_SENSOR_TYPES: tuple[PiHoleBinarySensorEntityDescription, ...] = ( }, state_value=lambda api: bool(api.versions["FTL_update"]), ), -) - -BINARY_SENSOR_TYPES_STATISTICS_ONLY: tuple[PiHoleBinarySensorEntityDescription, ...] = ( PiHoleBinarySensorEntityDescription( key="status", name="Status", diff --git a/homeassistant/components/pi_hole/strings.json b/homeassistant/components/pi_hole/strings.json index fbf3c5a627b..120ab8cb80a 100644 --- a/homeassistant/components/pi_hole/strings.json +++ b/homeassistant/components/pi_hole/strings.json @@ -8,22 +8,25 @@ "name": "[%key:common::config_flow::data::name%]", "location": "[%key:common::config_flow::data::location%]", "api_key": "[%key:common::config_flow::data::api_key%]", - "statistics_only": "Statistics Only", "ssl": "[%key:common::config_flow::data::ssl%]", "verify_ssl": "[%key:common::config_flow::data::verify_ssl%]" } }, - "api_key": { + "reauth_confirm": { + "title": "PI-Hole [%key:common::config_flow::title::reauth%]", + "description": "Please enter a new api key for PI-Hole at {host}/{location}", "data": { "api_key": "[%key:common::config_flow::data::api_key%]" } } }, "error": { - "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]" + "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", + "invalid_auth": "[%key:common::config_flow::error::invalid_auth%]" }, "abort": { - "already_configured": "[%key:common::config_flow::abort::already_configured_service%]" + "already_configured": "[%key:common::config_flow::abort::already_configured_service%]", + "reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]" } } } diff --git a/homeassistant/components/pi_hole/translations/en.json b/homeassistant/components/pi_hole/translations/en.json index 4333838ae64..815182731c2 100644 --- a/homeassistant/components/pi_hole/translations/en.json +++ b/homeassistant/components/pi_hole/translations/en.json @@ -1,16 +1,20 @@ { "config": { "abort": { - "already_configured": "Service is already configured" + "already_configured": "Service is already configured", + "reauth_successful": "Re-authentication was successful" }, "error": { - "cannot_connect": "Failed to connect" + "cannot_connect": "Failed to connect", + "invalid_auth": "Invalid authentication" }, "step": { - "api_key": { + "reauth_confirm": { "data": { "api_key": "API Key" - } + }, + "description": "Please enter a new api key for PI-Hole at {host}/{location}", + "title": "PI-Hole Reauthenticate Integration" }, "user": { "data": { @@ -20,16 +24,9 @@ "name": "Name", "port": "Port", "ssl": "Uses an SSL certificate", - "statistics_only": "Statistics Only", "verify_ssl": "Verify SSL certificate" } } } - }, - "issues": { - "deprecated_yaml": { - "description": "Configuring PI-Hole using YAML is being removed.\n\nYour existing YAML configuration has been imported into the UI automatically.\n\nRemove the PI-Hole YAML configuration from your configuration.yaml file and restart Home Assistant to fix this issue.", - "title": "The PI-Hole YAML configuration is being removed" - } } } \ No newline at end of file diff --git a/tests/components/pi_hole/__init__.py b/tests/components/pi_hole/__init__.py index 4752f98f98d..677a742726f 100644 --- a/tests/components/pi_hole/__init__.py +++ b/tests/components/pi_hole/__init__.py @@ -4,11 +4,9 @@ from unittest.mock import AsyncMock, MagicMock, patch from hole.exceptions import HoleError from homeassistant.components.pi_hole.const import ( - CONF_STATISTICS_ONLY, DEFAULT_LOCATION, DEFAULT_NAME, DEFAULT_SSL, - DEFAULT_STATISTICS_ONLY, DEFAULT_VERIFY_SSL, ) from homeassistant.const import ( @@ -54,16 +52,16 @@ API_KEY = "apikey" SSL = False VERIFY_SSL = True -CONF_DATA_DEFAULTS = { +CONFIG_DATA_DEFAULTS = { CONF_HOST: f"{HOST}:{PORT}", CONF_LOCATION: DEFAULT_LOCATION, CONF_NAME: DEFAULT_NAME, - CONF_STATISTICS_ONLY: DEFAULT_STATISTICS_ONLY, CONF_SSL: DEFAULT_SSL, CONF_VERIFY_SSL: DEFAULT_VERIFY_SSL, + CONF_API_KEY: API_KEY, } -CONF_DATA = { +CONFIG_DATA = { CONF_HOST: f"{HOST}:{PORT}", CONF_LOCATION: LOCATION, CONF_NAME: NAME, @@ -72,25 +70,20 @@ CONF_DATA = { CONF_VERIFY_SSL: VERIFY_SSL, } -CONF_CONFIG_FLOW_USER = { +CONFIG_FLOW_USER = { CONF_HOST: HOST, CONF_PORT: PORT, + CONF_API_KEY: API_KEY, CONF_LOCATION: LOCATION, CONF_NAME: NAME, - CONF_STATISTICS_ONLY: False, CONF_SSL: SSL, CONF_VERIFY_SSL: VERIFY_SSL, } -CONF_CONFIG_FLOW_API_KEY = { - CONF_API_KEY: API_KEY, -} - -CONF_CONFIG_ENTRY = { +CONFIG_ENTRY = { CONF_HOST: f"{HOST}:{PORT}", CONF_LOCATION: LOCATION, CONF_NAME: NAME, - CONF_STATISTICS_ONLY: False, CONF_API_KEY: API_KEY, CONF_SSL: SSL, CONF_VERIFY_SSL: VERIFY_SSL, @@ -99,7 +92,7 @@ CONF_CONFIG_ENTRY = { SWITCH_ENTITY_ID = "switch.pi_hole" -def _create_mocked_hole(raise_exception=False, has_versions=True): +def _create_mocked_hole(raise_exception=False, has_versions=True, has_data=True): mocked_hole = MagicMock() type(mocked_hole).get_data = AsyncMock( side_effect=HoleError("") if raise_exception else None @@ -109,7 +102,10 @@ def _create_mocked_hole(raise_exception=False, has_versions=True): ) type(mocked_hole).enable = AsyncMock() type(mocked_hole).disable = AsyncMock() - mocked_hole.data = ZERO_DATA + if has_data: + mocked_hole.data = ZERO_DATA + else: + mocked_hole.data = [] if has_versions: mocked_hole.versions = SAMPLE_VERSIONS else: diff --git a/tests/components/pi_hole/test_config_flow.py b/tests/components/pi_hole/test_config_flow.py index c5181d980ab..9cc818df60f 100644 --- a/tests/components/pi_hole/test_config_flow.py +++ b/tests/components/pi_hole/test_config_flow.py @@ -1,31 +1,28 @@ """Test pi_hole config flow.""" -from homeassistant.components.pi_hole.const import CONF_STATISTICS_ONLY, DOMAIN +from homeassistant.components import pi_hole +from homeassistant.components.pi_hole.const import DOMAIN from homeassistant.config_entries import SOURCE_USER from homeassistant.const import CONF_API_KEY from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType from . import ( - CONF_CONFIG_ENTRY, - CONF_CONFIG_FLOW_API_KEY, - CONF_CONFIG_FLOW_USER, + CONFIG_DATA_DEFAULTS, + CONFIG_ENTRY, + CONFIG_FLOW_USER, NAME, + ZERO_DATA, _create_mocked_hole, _patch_config_flow_hole, + _patch_init_hole, ) - -def _flow_next(hass: HomeAssistant, flow_id: str): - return next( - flow - for flow in hass.config_entries.flow.async_progress() - if flow["flow_id"] == flow_id - ) +from tests.common import MockConfigEntry async def test_flow_user(hass: HomeAssistant): """Test user initialized flow.""" - mocked_hole = _create_mocked_hole() + mocked_hole = _create_mocked_hole(has_data=False) with _patch_config_flow_hole(mocked_hole): result = await hass.config_entries.flow.async_init( DOMAIN, @@ -34,69 +31,68 @@ async def test_flow_user(hass: HomeAssistant): assert result["type"] == FlowResultType.FORM assert result["step_id"] == "user" assert result["errors"] == {} - _flow_next(hass, result["flow_id"]) result = await hass.config_entries.flow.async_configure( result["flow_id"], - user_input=CONF_CONFIG_FLOW_USER, + user_input=CONFIG_FLOW_USER, ) assert result["type"] == FlowResultType.FORM - assert result["step_id"] == "api_key" - assert result["errors"] is None - _flow_next(hass, result["flow_id"]) + assert result["step_id"] == "user" + assert result["errors"] == {CONF_API_KEY: "invalid_auth"} + mocked_hole.data = ZERO_DATA result = await hass.config_entries.flow.async_configure( result["flow_id"], - user_input=CONF_CONFIG_FLOW_API_KEY, + user_input=CONFIG_FLOW_USER, ) assert result["type"] == FlowResultType.CREATE_ENTRY assert result["title"] == NAME - assert result["data"] == CONF_CONFIG_ENTRY + assert result["data"] == CONFIG_ENTRY # duplicated server result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": SOURCE_USER}, - data=CONF_CONFIG_FLOW_USER, + data=CONFIG_FLOW_USER, ) assert result["type"] == FlowResultType.ABORT assert result["reason"] == "already_configured" -async def test_flow_statistics_only(hass: HomeAssistant): - """Test user initialized flow with statistics only.""" - mocked_hole = _create_mocked_hole() - with _patch_config_flow_hole(mocked_hole): - result = await hass.config_entries.flow.async_init( - DOMAIN, - context={"source": SOURCE_USER}, - ) - assert result["type"] == FlowResultType.FORM - assert result["step_id"] == "user" - assert result["errors"] == {} - _flow_next(hass, result["flow_id"]) - - user_input = {**CONF_CONFIG_FLOW_USER} - user_input[CONF_STATISTICS_ONLY] = True - config_entry_data = {**CONF_CONFIG_ENTRY} - config_entry_data[CONF_STATISTICS_ONLY] = True - config_entry_data.pop(CONF_API_KEY) - result = await hass.config_entries.flow.async_configure( - result["flow_id"], - user_input=user_input, - ) - assert result["type"] == FlowResultType.CREATE_ENTRY - assert result["title"] == NAME - assert result["data"] == config_entry_data - - async def test_flow_user_invalid(hass: HomeAssistant): """Test user initialized flow with invalid server.""" mocked_hole = _create_mocked_hole(True) with _patch_config_flow_hole(mocked_hole): result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": SOURCE_USER}, data=CONF_CONFIG_FLOW_USER + DOMAIN, context={"source": SOURCE_USER}, data=CONFIG_FLOW_USER ) assert result["type"] == FlowResultType.FORM assert result["step_id"] == "user" assert result["errors"] == {"base": "cannot_connect"} + + +async def test_flow_reauth(hass: HomeAssistant): + """Test reauth flow.""" + mocked_hole = _create_mocked_hole(has_data=False) + entry = MockConfigEntry(domain=pi_hole.DOMAIN, data=CONFIG_DATA_DEFAULTS) + entry.add_to_hass(hass) + with _patch_init_hole(mocked_hole), _patch_config_flow_hole(mocked_hole): + assert not await hass.config_entries.async_setup(entry.entry_id) + + flows = hass.config_entries.flow.async_progress() + + assert len(flows) == 1 + assert flows[0]["step_id"] == "reauth_confirm" + assert flows[0]["context"]["entry_id"] == entry.entry_id + + mocked_hole.data = ZERO_DATA + + result = await hass.config_entries.flow.async_configure( + flows[0]["flow_id"], + user_input={CONF_API_KEY: "newkey"}, + ) + + await hass.async_block_till_done() + assert result["type"] == FlowResultType.ABORT + assert result["reason"] == "reauth_successful" + assert entry.data[CONF_API_KEY] == "newkey" diff --git a/tests/components/pi_hole/test_init.py b/tests/components/pi_hole/test_init.py index 264a6662496..c739f286cb4 100644 --- a/tests/components/pi_hole/test_init.py +++ b/tests/components/pi_hole/test_init.py @@ -7,28 +7,16 @@ from hole.exceptions import HoleError from homeassistant.components import pi_hole, switch from homeassistant.components.pi_hole.const import ( CONF_STATISTICS_ONLY, - DEFAULT_LOCATION, - DEFAULT_NAME, - DEFAULT_SSL, - DEFAULT_VERIFY_SSL, SERVICE_DISABLE, SERVICE_DISABLE_ATTR_DURATION, ) -from homeassistant.const import ( - ATTR_ENTITY_ID, - CONF_API_KEY, - CONF_HOST, - CONF_LOCATION, - CONF_NAME, - CONF_SSL, - CONF_VERIFY_SSL, -) +from homeassistant.config_entries import ConfigEntryState +from homeassistant.const import ATTR_ENTITY_ID, CONF_API_KEY, CONF_HOST, CONF_NAME from homeassistant.core import HomeAssistant from . import ( - CONF_CONFIG_ENTRY, - CONF_DATA, - CONF_DATA_DEFAULTS, + CONFIG_DATA, + CONFIG_DATA_DEFAULTS, SWITCH_ENTITY_ID, _create_mocked_hole, _patch_init_hole, @@ -40,7 +28,9 @@ from tests.common import MockConfigEntry async def test_setup_with_defaults(hass: HomeAssistant): """Tests component setup with default config.""" mocked_hole = _create_mocked_hole() - entry = MockConfigEntry(domain=pi_hole.DOMAIN, data=CONF_DATA_DEFAULTS) + entry = MockConfigEntry( + domain=pi_hole.DOMAIN, data={**CONFIG_DATA_DEFAULTS, CONF_STATISTICS_ONLY: True} + ) entry.add_to_hass(hass) with _patch_init_hole(mocked_hole): assert await hass.config_entries.async_setup(entry.entry_id) @@ -90,7 +80,7 @@ async def test_setup_name_config(hass: HomeAssistant): """Tests component setup with a custom name.""" mocked_hole = _create_mocked_hole() entry = MockConfigEntry( - domain=pi_hole.DOMAIN, data={**CONF_DATA_DEFAULTS, CONF_NAME: "Custom"} + domain=pi_hole.DOMAIN, data={**CONFIG_DATA_DEFAULTS, CONF_NAME: "Custom"} ) entry.add_to_hass(hass) with _patch_init_hole(mocked_hole): @@ -107,7 +97,7 @@ async def test_setup_name_config(hass: HomeAssistant): async def test_switch(hass: HomeAssistant, caplog): """Test Pi-hole switch.""" mocked_hole = _create_mocked_hole() - entry = MockConfigEntry(domain=pi_hole.DOMAIN, data=CONF_DATA) + entry = MockConfigEntry(domain=pi_hole.DOMAIN, data=CONFIG_DATA) entry.add_to_hass(hass) with _patch_init_hole(mocked_hole): @@ -156,12 +146,12 @@ async def test_disable_service_call(hass: HomeAssistant): mocked_hole = _create_mocked_hole() with _patch_init_hole(mocked_hole): - entry = MockConfigEntry(domain=pi_hole.DOMAIN, data=CONF_DATA) + entry = MockConfigEntry(domain=pi_hole.DOMAIN, data=CONFIG_DATA) entry.add_to_hass(hass) assert await hass.config_entries.async_setup(entry.entry_id) entry = MockConfigEntry( - domain=pi_hole.DOMAIN, data={**CONF_DATA_DEFAULTS, CONF_NAME: "Custom"} + domain=pi_hole.DOMAIN, data={**CONFIG_DATA_DEFAULTS, CONF_NAME: "Custom"} ) entry.add_to_hass(hass) assert await hass.config_entries.async_setup(entry.entry_id) @@ -177,21 +167,14 @@ async def test_disable_service_call(hass: HomeAssistant): await hass.async_block_till_done() - mocked_hole.disable.assert_called_once_with(1) + mocked_hole.disable.assert_called_with(1) async def test_unload(hass: HomeAssistant): """Test unload entities.""" entry = MockConfigEntry( domain=pi_hole.DOMAIN, - data={ - CONF_NAME: DEFAULT_NAME, - CONF_HOST: "pi.hole", - CONF_LOCATION: DEFAULT_LOCATION, - CONF_SSL: DEFAULT_SSL, - CONF_VERIFY_SSL: DEFAULT_VERIFY_SSL, - CONF_STATISTICS_ONLY: True, - }, + data={**CONFIG_DATA_DEFAULTS, CONF_HOST: "pi.hole"}, ) entry.add_to_hass(hass) mocked_hole = _create_mocked_hole() @@ -199,38 +182,32 @@ async def test_unload(hass: HomeAssistant): await hass.config_entries.async_setup(entry.entry_id) await hass.async_block_till_done() assert entry.entry_id in hass.data[pi_hole.DOMAIN] - assert await hass.config_entries.async_unload(entry.entry_id) + await hass.async_block_till_done() + assert entry.entry_id not in hass.data[pi_hole.DOMAIN] -async def test_migrate(hass: HomeAssistant): - """Test migrate from old config entry.""" - entry = MockConfigEntry(domain=pi_hole.DOMAIN, data=CONF_DATA) - entry.add_to_hass(hass) - +async def test_remove_obsolete(hass: HomeAssistant): + """Test removing obsolete config entry parameters.""" mocked_hole = _create_mocked_hole() - with _patch_init_hole(mocked_hole): - await hass.config_entries.async_setup(entry.entry_id) - await hass.async_block_till_done() - - assert entry.data == CONF_CONFIG_ENTRY - - -async def test_migrate_statistics_only(hass: HomeAssistant): - """Test migrate from old config entry with statistics only.""" - conf_data = {**CONF_DATA} - conf_data[CONF_API_KEY] = "" - entry = MockConfigEntry(domain=pi_hole.DOMAIN, data=conf_data) + entry = MockConfigEntry( + domain=pi_hole.DOMAIN, data={**CONFIG_DATA_DEFAULTS, CONF_STATISTICS_ONLY: True} + ) entry.add_to_hass(hass) - - mocked_hole = _create_mocked_hole() with _patch_init_hole(mocked_hole): - await hass.config_entries.async_setup(entry.entry_id) - await hass.async_block_till_done() + assert await hass.config_entries.async_setup(entry.entry_id) + assert CONF_STATISTICS_ONLY not in entry.data - config_entry_data = {**CONF_CONFIG_ENTRY} - config_entry_data[CONF_STATISTICS_ONLY] = True - config_entry_data[CONF_API_KEY] = "" - assert entry.data == config_entry_data + +async def test_missing_api_key(hass: HomeAssistant): + """Tests start reauth flow if api key is missing.""" + mocked_hole = _create_mocked_hole() + data = CONFIG_DATA_DEFAULTS.copy() + data.pop(CONF_API_KEY) + entry = MockConfigEntry(domain=pi_hole.DOMAIN, data=data) + entry.add_to_hass(hass) + with _patch_init_hole(mocked_hole): + assert not await hass.config_entries.async_setup(entry.entry_id) + assert entry.state == ConfigEntryState.SETUP_ERROR diff --git a/tests/components/pi_hole/test_update.py b/tests/components/pi_hole/test_update.py index 9c37c68550c..62b7410544c 100644 --- a/tests/components/pi_hole/test_update.py +++ b/tests/components/pi_hole/test_update.py @@ -4,7 +4,7 @@ from homeassistant.components import pi_hole from homeassistant.const import STATE_ON, STATE_UNKNOWN from homeassistant.core import HomeAssistant -from . import CONF_DATA_DEFAULTS, _create_mocked_hole, _patch_init_hole +from . import CONFIG_DATA_DEFAULTS, _create_mocked_hole, _patch_init_hole from tests.common import MockConfigEntry @@ -12,7 +12,7 @@ from tests.common import MockConfigEntry async def test_update(hass: HomeAssistant): """Tests update entity.""" mocked_hole = _create_mocked_hole() - entry = MockConfigEntry(domain=pi_hole.DOMAIN, data=CONF_DATA_DEFAULTS) + entry = MockConfigEntry(domain=pi_hole.DOMAIN, data=CONFIG_DATA_DEFAULTS) entry.add_to_hass(hass) with _patch_init_hole(mocked_hole): assert await hass.config_entries.async_setup(entry.entry_id) @@ -53,7 +53,7 @@ async def test_update(hass: HomeAssistant): async def test_update_no_versions(hass: HomeAssistant): """Tests update entity when no version data available.""" mocked_hole = _create_mocked_hole(has_versions=False) - entry = MockConfigEntry(domain=pi_hole.DOMAIN, data=CONF_DATA_DEFAULTS) + entry = MockConfigEntry(domain=pi_hole.DOMAIN, data=CONFIG_DATA_DEFAULTS) entry.add_to_hass(hass) with _patch_init_hole(mocked_hole): assert await hass.config_entries.async_setup(entry.entry_id)