From ce99fa8c02755bad814ba21ec1361577dcd9f6bd Mon Sep 17 00:00:00 2001 From: Xiaonan Shen Date: Wed, 13 May 2020 06:25:06 -0700 Subject: [PATCH] Add config flow to pi_hole integration (#35442) * Add config flow to pi-hole * Add config flow tests * Change PlatformNotReady to ConfigEntryNotReady * Improve config flow * Add @shenxn as codeowner * Use entity_id as unique id * Remove .get with [] for required fields * Remove unique id from config flow * Replace some strings with references * Fix api_key string * Fix service api_key check * Remove unused DuplicatedNameException --- CODEOWNERS | 2 +- homeassistant/components/pi_hole/__init__.py | 180 ++++++++---------- .../components/pi_hole/config_flow.py | 146 ++++++++++++++ homeassistant/components/pi_hole/const.py | 1 - .../components/pi_hole/manifest.json | 3 +- homeassistant/components/pi_hole/sensor.py | 35 ++-- homeassistant/components/pi_hole/strings.json | 23 +++ homeassistant/generated/config_flows.py | 1 + tests/components/pi_hole/__init__.py | 70 +++++++ tests/components/pi_hole/test_config_flow.py | 125 ++++++++++++ tests/components/pi_hole/test_init.py | 59 ++---- 11 files changed, 488 insertions(+), 157 deletions(-) create mode 100644 homeassistant/components/pi_hole/config_flow.py create mode 100644 homeassistant/components/pi_hole/strings.json create mode 100644 tests/components/pi_hole/test_config_flow.py diff --git a/CODEOWNERS b/CODEOWNERS index e388193ffb4..07c99d48b86 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -300,7 +300,7 @@ homeassistant/components/pcal9535a/* @Shulyaka homeassistant/components/persistent_notification/* @home-assistant/core homeassistant/components/philips_js/* @elupus homeassistant/components/pi4ioe5v9xxxx/* @antonverburg -homeassistant/components/pi_hole/* @fabaff @johnluetke +homeassistant/components/pi_hole/* @fabaff @johnluetke @shenxn homeassistant/components/pilight/* @trekky12 homeassistant/components/plaato/* @JohNan homeassistant/components/plant/* @ChristianKuehnel diff --git a/homeassistant/components/pi_hole/__init__.py b/homeassistant/components/pi_hole/__init__.py index 989841d1317..a0d6c5da6d1 100644 --- a/homeassistant/components/pi_hole/__init__.py +++ b/homeassistant/components/pi_hole/__init__.py @@ -6,6 +6,7 @@ from hole.exceptions import HoleError import voluptuous as vol from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN +from homeassistant.config_entries import SOURCE_IMPORT from homeassistant.const import ( CONF_API_KEY, CONF_HOST, @@ -13,14 +14,13 @@ from homeassistant.const import ( CONF_SSL, CONF_VERIFY_SSL, ) +from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.helpers import config_validation as cv from homeassistant.helpers.aiohttp_client import async_get_clientsession -from homeassistant.helpers.discovery import async_load_platform from homeassistant.util import Throttle from .const import ( CONF_LOCATION, - CONF_SLUG, DEFAULT_LOCATION, DEFAULT_NAME, DEFAULT_SSL, @@ -34,31 +34,6 @@ from .const import ( SERVICE_ENABLE_ATTR_NAME, ) - -def ensure_unique_names_and_slugs(config): - """Ensure that each configuration dict contains a unique `name` value.""" - names = {} - slugs = {} - for conf in config: - if conf[CONF_NAME] not in names and conf[CONF_SLUG] not in slugs: - names[conf[CONF_NAME]] = conf[CONF_HOST] - slugs[conf[CONF_SLUG]] = conf[CONF_HOST] - else: - raise vol.Invalid( - f"Duplicate name '{conf[CONF_NAME]}' (or slug '{conf[CONF_SLUG]}') " - f"for '{conf[CONF_HOST]}' (already in use by " - f"'{names.get(conf[CONF_NAME], slugs[conf[CONF_SLUG]])}'). " - "Each configured Pi-hole must have a unique name." - ) - return config - - -def coerce_slug(config): - """Coerce the name of the Pi-Hole into a slug.""" - config[CONF_SLUG] = cv.slugify(config[CONF_NAME]) - return config - - LOGGER = logging.getLogger(__name__) PI_HOLE_SCHEMA = vol.Schema( @@ -71,16 +46,11 @@ PI_HOLE_SCHEMA = vol.Schema( vol.Optional(CONF_LOCATION, default=DEFAULT_LOCATION): cv.string, vol.Optional(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL): cv.boolean, }, - coerce_slug, ) ) CONFIG_SCHEMA = vol.Schema( - { - DOMAIN: vol.Schema( - vol.All(cv.ensure_list, [PI_HOLE_SCHEMA], ensure_unique_names_and_slugs) - ) - }, + {DOMAIN: vol.Schema(vol.All(cv.ensure_list, [PI_HOLE_SCHEMA]))}, extra=vol.ALLOW_EXTRA, ) @@ -88,81 +58,42 @@ CONFIG_SCHEMA = vol.Schema( async def async_setup(hass, config): """Set up the pi_hole integration.""" - def get_data(): - """Retrieve component data.""" - return hass.data[DOMAIN] - - def ensure_api_token(call_data): - """Ensure the Pi-Hole to be enabled/disabled has a api_token configured.""" - data = get_data() - if SERVICE_DISABLE_ATTR_NAME not in call_data: - for slug in data: - call_data[SERVICE_DISABLE_ATTR_NAME] = data[slug].name - ensure_api_token(call_data) - - call_data[SERVICE_DISABLE_ATTR_NAME] = None - else: - slug = cv.slugify(call_data[SERVICE_DISABLE_ATTR_NAME]) - - if (data[slug]).api.api_token is None: - raise vol.Invalid( - f"Pi-hole '{pi_hole.name}' must have an api_key " - "provided in configuration to be enabled." - ) - - return call_data - service_disable_schema = vol.Schema( vol.All( { vol.Required(SERVICE_DISABLE_ATTR_DURATION): vol.All( cv.time_period_str, cv.positive_timedelta ), - vol.Optional(SERVICE_DISABLE_ATTR_NAME): vol.In( - [conf[CONF_NAME] for conf in config[DOMAIN]], msg="Unknown Pi-Hole" - ), + vol.Optional(SERVICE_DISABLE_ATTR_NAME): str, }, - ensure_api_token, ) ) - service_enable_schema = vol.Schema( - { - vol.Optional(SERVICE_ENABLE_ATTR_NAME): vol.In( - [conf[CONF_NAME] for conf in config[DOMAIN]], msg="Unknown Pi-Hole" - ) - } - ) + service_enable_schema = vol.Schema({vol.Optional(SERVICE_ENABLE_ATTR_NAME): str}) hass.data[DOMAIN] = {} - for conf in config[DOMAIN]: - name = conf[CONF_NAME] - slug = conf[CONF_SLUG] - host = conf[CONF_HOST] - use_tls = conf[CONF_SSL] - verify_tls = conf[CONF_VERIFY_SSL] - location = conf[CONF_LOCATION] - api_key = conf.get(CONF_API_KEY) + # import + if DOMAIN in config: + for conf in config[DOMAIN]: + hass.async_create_task( + hass.config_entries.flow.async_init( + DOMAIN, context={"source": SOURCE_IMPORT}, data=conf + ) + ) - LOGGER.debug("Setting up %s integration with host %s", DOMAIN, host) - - session = async_get_clientsession(hass, verify_tls) - pi_hole = PiHoleData( - Hole( - host, - hass.loop, - session, - location=location, - tls=use_tls, - api_token=api_key, - ), - name, - ) - - await pi_hole.async_update() - - hass.data[DOMAIN][slug] = pi_hole + def get_pi_hole_from_name(name): + pi_hole = hass.data[DOMAIN].get(name) + if pi_hole is None: + LOGGER.error("Unknown Pi-hole name %s", name) + return None + if not pi_hole.api.api_token: + LOGGER.error( + "Pi-hole %s must have an api_key provided in configuration to be enabled", + name, + ) + return None + return pi_hole async def disable_service_handler(call): """Handle the service call to disable a single Pi-Hole or all configured Pi-Holes.""" @@ -171,8 +102,9 @@ async def async_setup(hass, config): async def do_disable(name): """Disable the named Pi-Hole.""" - slug = cv.slugify(name) - pi_hole = hass.data[DOMAIN][slug] + pi_hole = get_pi_hole_from_name(name) + if pi_hole is None: + return LOGGER.debug( "Disabling Pi-hole '%s' (%s) for %d seconds", @@ -185,8 +117,8 @@ async def async_setup(hass, config): if name is not None: await do_disable(name) else: - for pi_hole in hass.data[DOMAIN].values(): - await do_disable(pi_hole.name) + for name in hass.data[DOMAIN]: + await do_disable(name) async def enable_service_handler(call): """Handle the service call to enable a single Pi-Hole or all configured Pi-Holes.""" @@ -195,8 +127,9 @@ async def async_setup(hass, config): async def do_enable(name): """Enable the named Pi-Hole.""" - slug = cv.slugify(name) - pi_hole = hass.data[DOMAIN][slug] + pi_hole = get_pi_hole_from_name(name) + if pi_hole is None: + return LOGGER.debug("Enabling Pi-hole '%s' (%s)", name, pi_hole.api.host) await pi_hole.api.enable() @@ -204,8 +137,8 @@ async def async_setup(hass, config): if name is not None: await do_enable(name) else: - for pi_hole in hass.data[DOMAIN].values(): - await do_enable(pi_hole.name) + for name in hass.data[DOMAIN]: + await do_enable(name) hass.services.async_register( DOMAIN, SERVICE_DISABLE, disable_service_handler, schema=service_disable_schema @@ -215,11 +148,52 @@ async def async_setup(hass, config): DOMAIN, SERVICE_ENABLE, enable_service_handler, schema=service_enable_schema ) - hass.async_create_task(async_load_platform(hass, SENSOR_DOMAIN, DOMAIN, {}, config)) + return True + + +async def async_setup_entry(hass, entry): + """Set up Pi-hole entry.""" + name = entry.data[CONF_NAME] + host = entry.data[CONF_HOST] + use_tls = entry.data[CONF_SSL] + verify_tls = entry.data[CONF_VERIFY_SSL] + location = entry.data[CONF_LOCATION] + api_key = entry.data.get(CONF_API_KEY) + + LOGGER.debug("Setting up %s integration with host %s", DOMAIN, host) + + try: + session = async_get_clientsession(hass, verify_tls) + pi_hole = PiHoleData( + Hole( + host, + hass.loop, + session, + location=location, + tls=use_tls, + api_token=api_key, + ), + name, + ) + await pi_hole.async_update() + hass.data[DOMAIN][name] = pi_hole + except HoleError as ex: + LOGGER.warning("Failed to connect: %s", ex) + raise ConfigEntryNotReady + + hass.async_create_task( + hass.config_entries.async_forward_entry_setup(entry, SENSOR_DOMAIN) + ) return True +async def async_unload_entry(hass, entry): + """Unload pi-hole entry.""" + hass.data[DOMAIN].pop(entry.data[CONF_NAME]) + return await hass.config_entries.async_forward_entry_unload(entry, SENSOR_DOMAIN) + + class PiHoleData: """Get the latest data and update the states.""" diff --git a/homeassistant/components/pi_hole/config_flow.py b/homeassistant/components/pi_hole/config_flow.py new file mode 100644 index 00000000000..2b0ebfb7c16 --- /dev/null +++ b/homeassistant/components/pi_hole/config_flow.py @@ -0,0 +1,146 @@ +"""Config flow to configure the Pi-hole integration.""" +import logging + +from hole import Hole +from hole.exceptions import HoleError +import voluptuous as vol + +from homeassistant import config_entries +from homeassistant.components.pi_hole.const import ( # pylint: disable=unused-import + CONF_LOCATION, + DEFAULT_LOCATION, + DEFAULT_NAME, + DEFAULT_SSL, + DEFAULT_VERIFY_SSL, + DOMAIN, +) +from homeassistant.const import ( + CONF_API_KEY, + CONF_HOST, + CONF_NAME, + CONF_PORT, + CONF_SSL, + CONF_VERIFY_SSL, +) +from homeassistant.helpers.aiohttp_client import async_get_clientsession + +_LOGGER = logging.getLogger(__name__) + + +class PiHoleFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): + """Handle a Pi-hole config flow.""" + + VERSION = 1 + CONNECTION_CLASS = config_entries.CONN_CLASS_LOCAL_POLL + + async def async_step_user(self, user_input=None): + """Handle a flow initiated by the user.""" + return await self.async_step_init(user_input) + + async def async_step_import(self, user_input=None): + """Handle a flow initiated by import.""" + return await self.async_step_init(user_input, is_import=True) + + async def async_step_init(self, user_input, is_import=False): + """Handle init step of a flow.""" + errors = {} + + if user_input is not None: + host = ( + user_input[CONF_HOST] + if is_import + else f"{user_input[CONF_HOST]}:{user_input[CONF_PORT]}" + ) + name = user_input[CONF_NAME] + location = user_input[CONF_LOCATION] + tls = user_input[CONF_SSL] + verify_tls = user_input[CONF_VERIFY_SSL] + api_token = user_input.get(CONF_API_KEY) + endpoint = f"{host}/{location}" + + if await self._async_endpoint_existed(endpoint): + return self.async_abort(reason="already_configured") + if await self._async_name_existed(name): + if is_import: + _LOGGER.error("Failed to import: name %s already existed", name) + return self.async_abort(reason="duplicated_name") + + try: + await self._async_try_connect( + host, location, tls, verify_tls, api_token + ) + return self.async_create_entry( + title=name, + data={ + CONF_HOST: host, + CONF_NAME: name, + CONF_LOCATION: location, + CONF_SSL: tls, + CONF_VERIFY_SSL: verify_tls, + CONF_API_KEY: api_token, + }, + ) + except HoleError as ex: + _LOGGER.debug("Connection failed: %s", ex) + if is_import: + _LOGGER.error("Failed to import: %s", ex) + return self.async_abort(reason="cannot_connect") + errors["base"] = "cannot_connect" + + user_input = user_input or {} + return self.async_show_form( + step_id="user", + data_schema=vol.Schema( + { + vol.Required( + CONF_HOST, default=user_input.get(CONF_HOST) or "" + ): str, + vol.Required( + CONF_PORT, default=user_input.get(CONF_PORT) or 80 + ): vol.Coerce(int), + vol.Required( + CONF_NAME, default=user_input.get(CONF_NAME) or DEFAULT_NAME + ): str, + vol.Required( + CONF_LOCATION, + default=user_input.get(CONF_LOCATION) or DEFAULT_LOCATION, + ): str, + vol.Optional( + CONF_API_KEY, default=user_input.get(CONF_API_KEY) or "" + ): str, + vol.Required( + CONF_SSL, default=user_input.get(CONF_SSL) or DEFAULT_SSL + ): bool, + vol.Required( + CONF_VERIFY_SSL, + default=user_input.get(CONF_VERIFY_SSL) or DEFAULT_VERIFY_SSL, + ): bool, + } + ), + errors=errors, + ) + + async def _async_endpoint_existed(self, endpoint): + existing_endpoints = [ + f"{entry.data.get(CONF_HOST)}/{entry.data.get(CONF_LOCATION)}" + for entry in self._async_current_entries() + ] + return endpoint in existing_endpoints + + async def _async_name_existed(self, name): + existing_names = [ + entry.data.get(CONF_NAME) for entry in self._async_current_entries() + ] + return name in existing_names + + async def _async_try_connect(self, host, location, tls, verify_tls, api_token): + session = async_get_clientsession(self.hass, verify_tls) + pi_hole = Hole( + host, + self.hass.loop, + session, + location=location, + tls=tls, + api_token=api_token, + ) + await pi_hole.get_data() diff --git a/homeassistant/components/pi_hole/const.py b/homeassistant/components/pi_hole/const.py index 94f687d9bfa..eec71ca441d 100644 --- a/homeassistant/components/pi_hole/const.py +++ b/homeassistant/components/pi_hole/const.py @@ -6,7 +6,6 @@ from homeassistant.const import UNIT_PERCENTAGE DOMAIN = "pi_hole" CONF_LOCATION = "location" -CONF_SLUG = "slug" DEFAULT_LOCATION = "admin" DEFAULT_METHOD = "GET" diff --git a/homeassistant/components/pi_hole/manifest.json b/homeassistant/components/pi_hole/manifest.json index 1f4b46cc0d4..efe90bbf7e8 100644 --- a/homeassistant/components/pi_hole/manifest.json +++ b/homeassistant/components/pi_hole/manifest.json @@ -3,5 +3,6 @@ "name": "Pi-hole", "documentation": "https://www.home-assistant.io/integrations/pi_hole", "requirements": ["hole==0.5.1"], - "codeowners": ["@fabaff", "@johnluetke"] + "codeowners": ["@fabaff", "@johnluetke", "@shenxn"], + "config_flow": true } diff --git a/homeassistant/components/pi_hole/sensor.py b/homeassistant/components/pi_hole/sensor.py index c01a0167e53..bbc42cdd8a5 100644 --- a/homeassistant/components/pi_hole/sensor.py +++ b/homeassistant/components/pi_hole/sensor.py @@ -1,6 +1,7 @@ """Support for getting statistical data from a Pi-hole system.""" import logging +from homeassistant.const import CONF_NAME from homeassistant.helpers.entity import Entity from .const import ( @@ -13,29 +14,25 @@ from .const import ( LOGGER = logging.getLogger(__name__) -async def async_setup_platform(hass, config, async_add_entities, discovery_info=None): +async def async_setup_entry(hass, entry, async_add_entities): """Set up the pi-hole sensor.""" - if discovery_info is None: - return - - sensors = [] - for pi_hole in hass.data[PIHOLE_DOMAIN].values(): - for sensor in [ - PiHoleSensor(pi_hole, sensor_name) for sensor_name in SENSOR_LIST - ]: - sensors.append(sensor) - + pi_hole = hass.data[PIHOLE_DOMAIN][entry.data[CONF_NAME]] + sensors = [ + PiHoleSensor(pi_hole, sensor_name, entry.entry_id) + for sensor_name in SENSOR_LIST + ] async_add_entities(sensors, True) class PiHoleSensor(Entity): """Representation of a Pi-hole sensor.""" - def __init__(self, pi_hole, sensor_name): + def __init__(self, pi_hole, sensor_name, server_unique_id): """Initialize a Pi-hole sensor.""" self.pi_hole = pi_hole self._name = pi_hole.name self._condition = sensor_name + self._server_unique_id = server_unique_id variable_info = SENSOR_DICT[sensor_name] self._condition_name = variable_info[0] @@ -48,6 +45,20 @@ class PiHoleSensor(Entity): """Return the name of the sensor.""" return f"{self._name} {self._condition_name}" + @property + def unique_id(self): + """Return the unique id of the sensor.""" + return f"{self._server_unique_id}/{self._condition_name}" + + @property + def device_info(self): + """Return the device information of the sensor.""" + return { + "identifiers": {(PIHOLE_DOMAIN, self._server_unique_id)}, + "name": self._name, + "manufacturer": "Pi-hole", + } + @property def icon(self): """Icon to use in the frontend, if any.""" diff --git a/homeassistant/components/pi_hole/strings.json b/homeassistant/components/pi_hole/strings.json new file mode 100644 index 00000000000..b155550844a --- /dev/null +++ b/homeassistant/components/pi_hole/strings.json @@ -0,0 +1,23 @@ +{ + "config": { + "step": { + "user": { + "data": { + "host": "[%key:common::config_flow::data::host%]", + "port": "[%key:common::config_flow::data::port%]", + "name": "Name", + "api_key": "API Key (Optional)", + "ssl": "Use SSL", + "verify_ssl": "Verify SSL certificate" + } + } + }, + "error": { + "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]" + }, + "abort": { + "already_configured": "[%key:common::config_flow::abort::already_configured_service%]", + "duplicated_name": "Name already existed" + } + } +} diff --git a/homeassistant/generated/config_flows.py b/homeassistant/generated/config_flows.py index 7027195a218..1f634d4ed47 100644 --- a/homeassistant/generated/config_flows.py +++ b/homeassistant/generated/config_flows.py @@ -106,6 +106,7 @@ FLOWS = [ "openuv", "owntracks", "panasonic_viera", + "pi_hole", "plaato", "plex", "point", diff --git a/tests/components/pi_hole/__init__.py b/tests/components/pi_hole/__init__.py index 7eea15b79c8..b39bfdced2a 100644 --- a/tests/components/pi_hole/__init__.py +++ b/tests/components/pi_hole/__init__.py @@ -1 +1,71 @@ """Tests for the pi_hole component.""" +from hole.exceptions import HoleError + +from homeassistant.components.pi_hole.const import CONF_LOCATION +from homeassistant.const import ( + CONF_API_KEY, + CONF_HOST, + CONF_NAME, + CONF_PORT, + CONF_SSL, + CONF_VERIFY_SSL, +) + +from tests.async_mock import AsyncMock, MagicMock, patch + +ZERO_DATA = { + "ads_blocked_today": 0, + "ads_percentage_today": 0, + "clients_ever_seen": 0, + "dns_queries_today": 0, + "domains_being_blocked": 0, + "queries_cached": 0, + "queries_forwarded": 0, + "status": 0, + "unique_clients": 0, + "unique_domains": 0, +} + +HOST = "1.2.3.4" +PORT = 80 +LOCATION = "location" +NAME = "name" +API_KEY = "apikey" +SSL = False +VERIFY_SSL = True + +CONF_DATA = { + CONF_HOST: f"{HOST}:{PORT}", + CONF_LOCATION: LOCATION, + CONF_NAME: NAME, + CONF_API_KEY: API_KEY, + CONF_SSL: SSL, + CONF_VERIFY_SSL: VERIFY_SSL, +} + +CONF_CONFIG_FLOW = { + CONF_HOST: HOST, + CONF_PORT: PORT, + CONF_LOCATION: LOCATION, + CONF_NAME: NAME, + CONF_API_KEY: API_KEY, + CONF_SSL: SSL, + CONF_VERIFY_SSL: VERIFY_SSL, +} + + +def _create_mocked_hole(raise_exception=False): + mocked_hole = MagicMock() + type(mocked_hole).get_data = AsyncMock( + side_effect=HoleError("") if raise_exception else None + ) + type(mocked_hole).enable = AsyncMock() + type(mocked_hole).disable = AsyncMock() + mocked_hole.data = ZERO_DATA + return mocked_hole + + +def _patch_config_flow_hole(mocked_hole): + return patch( + "homeassistant.components.pi_hole.config_flow.Hole", return_value=mocked_hole + ) diff --git a/tests/components/pi_hole/test_config_flow.py b/tests/components/pi_hole/test_config_flow.py new file mode 100644 index 00000000000..32b5b1ca146 --- /dev/null +++ b/tests/components/pi_hole/test_config_flow.py @@ -0,0 +1,125 @@ +"""Test pi_hole config flow.""" +import copy +import logging + +from homeassistant.components.pi_hole.const import DOMAIN +from homeassistant.config_entries import SOURCE_IMPORT, SOURCE_USER +from homeassistant.data_entry_flow import ( + RESULT_TYPE_ABORT, + RESULT_TYPE_CREATE_ENTRY, + RESULT_TYPE_FORM, +) + +from . import ( + CONF_CONFIG_FLOW, + CONF_DATA, + CONF_HOST, + NAME, + _create_mocked_hole, + _patch_config_flow_hole, +) + +from tests.async_mock import patch + + +def _flow_next(hass, flow_id): + return next( + flow + for flow in hass.config_entries.flow.async_progress() + if flow["flow_id"] == flow_id + ) + + +def _patch_setup(): + return patch( + "homeassistant.components.pi_hole.async_setup_entry", return_value=True, + ) + + +async def test_flow_import(hass, caplog): + """Test import flow.""" + mocked_hole = _create_mocked_hole() + with _patch_config_flow_hole(mocked_hole), _patch_setup(): + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": SOURCE_IMPORT}, data=CONF_DATA + ) + assert result["type"] == RESULT_TYPE_CREATE_ENTRY + assert result["title"] == NAME + assert result["data"] == CONF_DATA + + # duplicated server + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": SOURCE_IMPORT}, data=CONF_DATA + ) + assert result["type"] == RESULT_TYPE_ABORT + assert result["reason"] == "already_configured" + + # duplicated name + conf_data = copy.deepcopy(CONF_DATA) + conf_data[CONF_HOST] = "4.3.2.1" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": SOURCE_IMPORT}, data=conf_data + ) + assert result["type"] == RESULT_TYPE_ABORT + assert result["reason"] == "duplicated_name" + assert len([x for x in caplog.records if x.levelno == logging.ERROR]) == 1 + + +async def test_flow_import_invalid(hass, caplog): + """Test import flow with invalid server.""" + mocked_hole = _create_mocked_hole(True) + with _patch_config_flow_hole(mocked_hole), _patch_setup(): + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": SOURCE_IMPORT}, data=CONF_DATA + ) + assert result["type"] == RESULT_TYPE_ABORT + assert result["reason"] == "cannot_connect" + assert len([x for x in caplog.records if x.levelno == logging.ERROR]) == 1 + + +async def test_flow_user(hass): + """Test user initialized flow.""" + mocked_hole = _create_mocked_hole() + with _patch_config_flow_hole(mocked_hole), _patch_setup(): + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": SOURCE_USER}, + ) + assert result["type"] == RESULT_TYPE_FORM + assert result["step_id"] == "user" + assert result["errors"] == {} + _flow_next(hass, result["flow_id"]) + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input=CONF_CONFIG_FLOW, + ) + assert result["type"] == RESULT_TYPE_CREATE_ENTRY + assert result["title"] == NAME + assert result["data"] == CONF_DATA + + # duplicated server + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": SOURCE_USER}, data=CONF_CONFIG_FLOW, + ) + assert result["type"] == RESULT_TYPE_ABORT + assert result["reason"] == "already_configured" + + # duplicated name + conf_data = copy.deepcopy(CONF_CONFIG_FLOW) + conf_data[CONF_HOST] = "4.3.2.1" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": SOURCE_USER}, data=conf_data + ) + assert result["type"] == RESULT_TYPE_ABORT + assert result["reason"] == "duplicated_name" + + +async def test_flow_user_invalid(hass): + """Test user initialized flow with invalid server.""" + mocked_hole = _create_mocked_hole(True) + with _patch_config_flow_hole(mocked_hole), _patch_setup(): + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": SOURCE_USER}, data=CONF_CONFIG_FLOW + ) + assert result["type"] == RESULT_TYPE_FORM + assert result["step_id"] == "user" + assert result["errors"] == {"base": "cannot_connect"} diff --git a/tests/components/pi_hole/test_init.py b/tests/components/pi_hole/test_init.py index 3ff16001d86..73a501c74ce 100644 --- a/tests/components/pi_hole/test_init.py +++ b/tests/components/pi_hole/test_init.py @@ -2,29 +2,20 @@ from homeassistant.components import pi_hole -from tests.async_mock import AsyncMock, patch +from . import _create_mocked_hole, _patch_config_flow_hole + +from tests.async_mock import patch from tests.common import async_setup_component -ZERO_DATA = { - "ads_blocked_today": 0, - "ads_percentage_today": 0, - "clients_ever_seen": 0, - "dns_queries_today": 0, - "domains_being_blocked": 0, - "queries_cached": 0, - "queries_forwarded": 0, - "status": 0, - "unique_clients": 0, - "unique_domains": 0, -} + +def _patch_init_hole(mocked_hole): + return patch("homeassistant.components.pi_hole.Hole", return_value=mocked_hole) async def test_setup_minimal_config(hass): """Tests component setup with minimal config.""" - with patch("homeassistant.components.pi_hole.Hole") as _hole: - _hole.return_value.get_data = AsyncMock(return_value=None) - _hole.return_value.data = ZERO_DATA - + mocked_hole = _create_mocked_hole() + with _patch_config_flow_hole(mocked_hole), _patch_init_hole(mocked_hole): assert await async_setup_component( hass, pi_hole.DOMAIN, {pi_hole.DOMAIN: [{"host": "pi.hole"}]} ) @@ -78,10 +69,8 @@ async def test_setup_minimal_config(hass): async def test_setup_name_config(hass): """Tests component setup with a custom name.""" - with patch("homeassistant.components.pi_hole.Hole") as _hole: - _hole.return_value.get_data = AsyncMock(return_value=None) - _hole.return_value.data = ZERO_DATA - + mocked_hole = _create_mocked_hole() + with _patch_config_flow_hole(mocked_hole), _patch_init_hole(mocked_hole): assert await async_setup_component( hass, pi_hole.DOMAIN, @@ -98,19 +87,15 @@ async def test_setup_name_config(hass): async def test_disable_service_call(hass): """Test disable service call with no Pi-hole named.""" - with patch("homeassistant.components.pi_hole.Hole") as _hole: - mock_disable = AsyncMock(return_value=None) - _hole.return_value.disable = mock_disable - _hole.return_value.get_data = AsyncMock(return_value=None) - _hole.return_value.data = ZERO_DATA - + mocked_hole = _create_mocked_hole() + with _patch_config_flow_hole(mocked_hole), _patch_init_hole(mocked_hole): assert await async_setup_component( hass, pi_hole.DOMAIN, { pi_hole.DOMAIN: [ - {"host": "pi.hole", "api_key": "1"}, - {"host": "pi.hole", "name": "Custom", "api_key": "2"}, + {"host": "pi.hole1", "api_key": "1"}, + {"host": "pi.hole2", "name": "Custom", "api_key": "2"}, ] }, ) @@ -126,24 +111,20 @@ async def test_disable_service_call(hass): await hass.async_block_till_done() - assert mock_disable.call_count == 2 + assert mocked_hole.disable.call_count == 2 async def test_enable_service_call(hass): """Test enable service call with no Pi-hole named.""" - with patch("homeassistant.components.pi_hole.Hole") as _hole: - mock_enable = AsyncMock(return_value=None) - _hole.return_value.enable = mock_enable - _hole.return_value.get_data = AsyncMock(return_value=None) - _hole.return_value.data = ZERO_DATA - + mocked_hole = _create_mocked_hole() + with _patch_config_flow_hole(mocked_hole), _patch_init_hole(mocked_hole): assert await async_setup_component( hass, pi_hole.DOMAIN, { pi_hole.DOMAIN: [ - {"host": "pi.hole", "api_key": "1"}, - {"host": "pi.hole", "name": "Custom", "api_key": "2"}, + {"host": "pi.hole1", "api_key": "1"}, + {"host": "pi.hole2", "name": "Custom", "api_key": "2"}, ] }, ) @@ -156,4 +137,4 @@ async def test_enable_service_call(hass): await hass.async_block_till_done() - assert mock_enable.call_count == 2 + assert mocked_hole.enable.call_count == 2