diff --git a/homeassistant/components/pi_hole/__init__.py b/homeassistant/components/pi_hole/__init__.py index eba9053183b..9b51cc09b35 100644 --- a/homeassistant/components/pi_hole/__init__.py +++ b/homeassistant/components/pi_hole/__init__.py @@ -1,11 +1,11 @@ """The pi_hole component.""" +import asyncio import logging from hole import Hole from hole.exceptions import HoleError import voluptuous as vol -from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN from homeassistant.config_entries import SOURCE_IMPORT from homeassistant.const import ( CONF_API_KEY, @@ -14,9 +14,11 @@ from homeassistant.const import ( CONF_SSL, CONF_VERIFY_SSL, ) +from homeassistant.core import callback from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.helpers import config_validation as cv from homeassistant.helpers.aiohttp_client import async_get_clientsession +from homeassistant.helpers.entity import Entity from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed from .const import ( @@ -29,11 +31,6 @@ from .const import ( DEFAULT_VERIFY_SSL, DOMAIN, MIN_TIME_BETWEEN_UPDATES, - SERVICE_DISABLE, - SERVICE_DISABLE_ATTR_DURATION, - SERVICE_DISABLE_ATTR_NAME, - SERVICE_ENABLE, - SERVICE_ENABLE_ATTR_NAME, ) _LOGGER = logging.getLogger(__name__) @@ -58,20 +55,7 @@ CONFIG_SCHEMA = vol.Schema( async def async_setup(hass, config): - """Set up the Pi_hole integration.""" - - service_disable_schema = vol.Schema( - vol.All( - { - vol.Required(SERVICE_DISABLE_ATTR_DURATION): vol.All( - cv.time_period_str, cv.positive_timedelta - ), - vol.Optional(SERVICE_DISABLE_ATTR_NAME): str, - }, - ) - ) - - service_enable_schema = vol.Schema({vol.Optional(SERVICE_ENABLE_ATTR_NAME): str}) + """Set up the Pi-hole integration.""" hass.data[DOMAIN] = {} @@ -84,71 +68,6 @@ async def async_setup(hass, config): ) ) - def get_api_from_name(name): - """Get Pi-hole API object from user configured name.""" - hole_data = hass.data[DOMAIN].get(name) - if hole_data is None: - _LOGGER.error("Unknown Pi-hole name %s", name) - return None - api = hole_data[DATA_KEY_API] - if not api.api_token: - _LOGGER.error( - "Pi-hole %s must have an api_key provided in configuration to be enabled", - name, - ) - return None - return api - - async def disable_service_handler(call): - """Handle the service call to disable a single Pi-hole or all configured Pi-holes.""" - duration = call.data[SERVICE_DISABLE_ATTR_DURATION].total_seconds() - name = call.data.get(SERVICE_DISABLE_ATTR_NAME) - - async def do_disable(name): - """Disable the named Pi-hole.""" - api = get_api_from_name(name) - if api is None: - return - - _LOGGER.debug( - "Disabling Pi-hole '%s' (%s) for %d seconds", name, api.host, duration, - ) - await api.disable(duration) - - if name is not None: - await do_disable(name) - else: - for name in hass.data[DOMAIN]: - await do_disable(name) - - async def enable_service_handler(call): - """Handle the service call to enable a single Pi-hole or all configured Pi-holes.""" - - name = call.data.get(SERVICE_ENABLE_ATTR_NAME) - - async def do_enable(name): - """Enable the named Pi-hole.""" - api = get_api_from_name(name) - if api is None: - return - - _LOGGER.debug("Enabling Pi-hole '%s' (%s)", name, api.host) - await api.enable() - - if name is not None: - await do_enable(name) - else: - for name in hass.data[DOMAIN]: - await do_enable(name) - - hass.services.async_register( - DOMAIN, SERVICE_DISABLE, disable_service_handler, schema=service_disable_schema - ) - - hass.services.async_register( - DOMAIN, SERVICE_ENABLE, enable_service_handler, schema=service_enable_schema - ) - return True @@ -187,19 +106,85 @@ async def async_setup_entry(hass, entry): update_method=async_update_data, update_interval=MIN_TIME_BETWEEN_UPDATES, ) - hass.data[DOMAIN][name] = { + hass.data[DOMAIN][entry.entry_id] = { DATA_KEY_API: api, DATA_KEY_COORDINATOR: coordinator, } - hass.async_create_task( - hass.config_entries.async_forward_entry_setup(entry, SENSOR_DOMAIN) - ) + for platform in _async_platforms(entry): + hass.async_create_task( + hass.config_entries.async_forward_entry_setup(entry, platform) + ) return True async def async_unload_entry(hass, entry): - """Unload pi-hole entry.""" - hass.data[DOMAIN].pop(entry.data[CONF_NAME]) - return await hass.config_entries.async_forward_entry_unload(entry, SENSOR_DOMAIN) + """Unload Pi-hole entry.""" + unload_ok = all( + await asyncio.gather( + *[ + hass.config_entries.async_forward_entry_unload(entry, platform) + for platform in _async_platforms(entry) + ] + ) + ) + if unload_ok: + hass.data[DOMAIN].pop(entry.entry_id) + return unload_ok + + +@callback +def _async_platforms(entry): + """Return platforms to be loaded / unloaded.""" + platforms = ["sensor"] + if entry.data.get(CONF_API_KEY): + platforms.append("switch") + else: + platforms.append("binary_sensor") + return platforms + + +class PiHoleEntity(Entity): + """Representation of a Pi-hole entity.""" + + def __init__(self, api, coordinator, name, server_unique_id): + """Initialize a Pi-hole entity.""" + self.api = api + self.coordinator = coordinator + self._name = name + self._server_unique_id = server_unique_id + + async def async_added_to_hass(self): + """When entity is added to hass.""" + self.async_on_remove( + self.coordinator.async_add_listener(self.async_write_ha_state) + ) + + @property + def icon(self): + """Icon to use in the frontend, if any.""" + return "mdi:pi-hole" + + @property + def device_info(self): + """Return the device information of the entity.""" + return { + "identifiers": {(DOMAIN, self._server_unique_id)}, + "name": self._name, + "manufacturer": "Pi-hole", + } + + @property + def available(self): + """Could the device be accessed during the last update call.""" + return self.coordinator.last_update_success + + @property + def should_poll(self): + """No need to poll. Coordinator notifies entity of updates.""" + return False + + async def async_update(self): + """Get the latest data from the Pi-hole API.""" + await self.coordinator.async_request_refresh() diff --git a/homeassistant/components/pi_hole/binary_sensor.py b/homeassistant/components/pi_hole/binary_sensor.py new file mode 100644 index 00000000000..d572bb390e5 --- /dev/null +++ b/homeassistant/components/pi_hole/binary_sensor.py @@ -0,0 +1,44 @@ +"""Support for getting status from a Pi-hole system.""" +import logging + +from homeassistant.components.binary_sensor import BinarySensorEntity +from homeassistant.const import CONF_NAME + +from . import PiHoleEntity +from .const import DATA_KEY_API, DATA_KEY_COORDINATOR, DOMAIN as PIHOLE_DOMAIN + +_LOGGER = logging.getLogger(__name__) + + +async def async_setup_entry(hass, entry, async_add_entities): + """Set up the Pi-hole binary sensor.""" + name = entry.data[CONF_NAME] + hole_data = hass.data[PIHOLE_DOMAIN][entry.entry_id] + binary_sensors = [ + PiHoleBinarySensor( + hole_data[DATA_KEY_API], + hole_data[DATA_KEY_COORDINATOR], + name, + entry.entry_id, + ) + ] + async_add_entities(binary_sensors, True) + + +class PiHoleBinarySensor(PiHoleEntity, BinarySensorEntity): + """Representation of a Pi-hole binary sensor.""" + + @property + def name(self): + """Return the name of the sensor.""" + return self._name + + @property + def unique_id(self): + """Return the unique id of the sensor.""" + return f"{self._server_unique_id}/Status" + + @property + def is_on(self): + """Return if the service is on.""" + return self.api.data.get("status") == "enabled" diff --git a/homeassistant/components/pi_hole/config_flow.py b/homeassistant/components/pi_hole/config_flow.py index 2b0ebfb7c16..c7061b05caa 100644 --- a/homeassistant/components/pi_hole/config_flow.py +++ b/homeassistant/components/pi_hole/config_flow.py @@ -60,10 +60,6 @@ class PiHoleFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): if await self._async_endpoint_existed(endpoint): return self.async_abort(reason="already_configured") - if await self._async_name_existed(name): - if is_import: - _LOGGER.error("Failed to import: name %s already existed", name) - return self.async_abort(reason="duplicated_name") try: await self._async_try_connect( @@ -127,12 +123,6 @@ class PiHoleFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): ] return endpoint in existing_endpoints - async def _async_name_existed(self, name): - existing_names = [ - entry.data.get(CONF_NAME) for entry in self._async_current_entries() - ] - return name in existing_names - async def _async_try_connect(self, host, location, tls, verify_tls, api_token): session = async_get_clientsession(self.hass, verify_tls) pi_hole = Hole( diff --git a/homeassistant/components/pi_hole/const.py b/homeassistant/components/pi_hole/const.py index a5807de5575..cb8087fdbf0 100644 --- a/homeassistant/components/pi_hole/const.py +++ b/homeassistant/components/pi_hole/const.py @@ -15,9 +15,6 @@ DEFAULT_VERIFY_SSL = True SERVICE_DISABLE = "disable" SERVICE_DISABLE_ATTR_DURATION = "duration" -SERVICE_DISABLE_ATTR_NAME = "name" -SERVICE_ENABLE = "enable" -SERVICE_ENABLE_ATTR_NAME = SERVICE_DISABLE_ATTR_NAME ATTR_BLOCKED_DOMAINS = "domains_blocked" diff --git a/homeassistant/components/pi_hole/sensor.py b/homeassistant/components/pi_hole/sensor.py index d0009f1ebba..179e61a21cc 100644 --- a/homeassistant/components/pi_hole/sensor.py +++ b/homeassistant/components/pi_hole/sensor.py @@ -2,8 +2,8 @@ import logging from homeassistant.const import CONF_NAME -from homeassistant.helpers.entity import Entity +from . import PiHoleEntity from .const import ( ATTR_BLOCKED_DOMAINS, DATA_KEY_API, @@ -19,7 +19,7 @@ LOGGER = logging.getLogger(__name__) async def async_setup_entry(hass, entry, async_add_entities): """Set up the Pi-hole sensor.""" name = entry.data[CONF_NAME] - hole_data = hass.data[PIHOLE_DOMAIN][name] + hole_data = hass.data[PIHOLE_DOMAIN][entry.entry_id] sensors = [ PiHoleSensor( hole_data[DATA_KEY_API], @@ -33,28 +33,20 @@ async def async_setup_entry(hass, entry, async_add_entities): async_add_entities(sensors, True) -class PiHoleSensor(Entity): +class PiHoleSensor(PiHoleEntity): """Representation of a Pi-hole sensor.""" def __init__(self, api, coordinator, name, sensor_name, server_unique_id): """Initialize a Pi-hole sensor.""" - self.api = api - self.coordinator = coordinator - self._name = name + super().__init__(api, coordinator, name, server_unique_id) + self._condition = sensor_name - self._server_unique_id = server_unique_id variable_info = SENSOR_DICT[sensor_name] self._condition_name = variable_info[0] self._unit_of_measurement = variable_info[1] self._icon = variable_info[2] - async def async_added_to_hass(self): - """When entity is added to hass.""" - self.async_on_remove( - self.coordinator.async_add_listener(self.async_write_ha_state) - ) - @property def name(self): """Return the name of the sensor.""" @@ -65,15 +57,6 @@ class PiHoleSensor(Entity): """Return the unique id of the sensor.""" return f"{self._server_unique_id}/{self._condition_name}" - @property - def device_info(self): - """Return the device information of the sensor.""" - return { - "identifiers": {(PIHOLE_DOMAIN, self._server_unique_id)}, - "name": self._name, - "manufacturer": "Pi-hole", - } - @property def icon(self): """Icon to use in the frontend, if any.""" @@ -96,17 +79,3 @@ class PiHoleSensor(Entity): def device_state_attributes(self): """Return the state attributes of the Pi-hole.""" return {ATTR_BLOCKED_DOMAINS: self.api.data["domains_being_blocked"]} - - @property - def available(self): - """Could the device be accessed during the last update call.""" - return self.coordinator.last_update_success - - @property - def should_poll(self): - """No need to poll. Coordinator notifies entity of updates.""" - return False - - async def async_update(self): - """Get the latest data from the Pi-hole API.""" - await self.coordinator.async_request_refresh() diff --git a/homeassistant/components/pi_hole/services.yaml b/homeassistant/components/pi_hole/services.yaml index 9bb31b1723f..fb9a5c17a13 100644 --- a/homeassistant/components/pi_hole/services.yaml +++ b/homeassistant/components/pi_hole/services.yaml @@ -1,15 +1,9 @@ disable: description: Disable configured Pi-hole(s) for an amount of time fields: + entity_id: + description: Target switch entity + example: switch.pi_hole duration: description: Time that the Pi-hole should be disabled for example: "00:00:15" - name: - description: "[Optional] When multiple Pi-holes are configured, the name of the one to disable. If omitted, all configured Pi-holes will be disabled." - example: "Pi-Hole" -enable: - description: Enable configured Pi-hole(s) - fields: - name: - description: "[Optional] When multiple Pi-holes are configured, the name of the one to enable. If omitted, all configured Pi-holes will be enabled." - example: "Pi-Hole" diff --git a/homeassistant/components/pi_hole/strings.json b/homeassistant/components/pi_hole/strings.json index b155550844a..42faf5d5a46 100644 --- a/homeassistant/components/pi_hole/strings.json +++ b/homeassistant/components/pi_hole/strings.json @@ -6,7 +6,8 @@ "host": "[%key:common::config_flow::data::host%]", "port": "[%key:common::config_flow::data::port%]", "name": "Name", - "api_key": "API Key (Optional)", + "location": "Location", + "api_key": "[%key:common::config_flow::data::api_key%]", "ssl": "Use SSL", "verify_ssl": "Verify SSL certificate" } @@ -16,8 +17,7 @@ "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]" }, "abort": { - "already_configured": "[%key:common::config_flow::abort::already_configured_service%]", - "duplicated_name": "Name already existed" + "already_configured": "[%key:common::config_flow::abort::already_configured_service%]" } } } diff --git a/homeassistant/components/pi_hole/switch.py b/homeassistant/components/pi_hole/switch.py new file mode 100644 index 00000000000..015bab8fe60 --- /dev/null +++ b/homeassistant/components/pi_hole/switch.py @@ -0,0 +1,100 @@ +"""Support for turning on and off Pi-hole system.""" +import logging + +from hole.exceptions import HoleError +import voluptuous as vol + +from homeassistant.components.switch import SwitchEntity +from homeassistant.const import CONF_NAME +from homeassistant.helpers import config_validation as cv, entity_platform + +from . import PiHoleEntity +from .const import ( + DATA_KEY_API, + DATA_KEY_COORDINATOR, + DOMAIN as PIHOLE_DOMAIN, + SERVICE_DISABLE, + SERVICE_DISABLE_ATTR_DURATION, +) + +_LOGGER = logging.getLogger(__name__) + + +async def async_setup_entry(hass, entry, async_add_entities): + """Set up the Pi-hole switch.""" + name = entry.data[CONF_NAME] + hole_data = hass.data[PIHOLE_DOMAIN][entry.entry_id] + switches = [ + PiHoleSwitch( + hole_data[DATA_KEY_API], + hole_data[DATA_KEY_COORDINATOR], + name, + entry.entry_id, + ) + ] + async_add_entities(switches, True) + + # register service + platform = entity_platform.current_platform.get() + platform.async_register_entity_service( + SERVICE_DISABLE, + { + vol.Required(SERVICE_DISABLE_ATTR_DURATION): vol.All( + cv.time_period_str, cv.positive_timedelta + ), + }, + "async_disable", + ) + + +class PiHoleSwitch(PiHoleEntity, SwitchEntity): + """Representation of a Pi-hole switch.""" + + @property + def name(self): + """Return the name of the switch.""" + return self._name + + @property + def unique_id(self): + """Return the unique id of the switch.""" + return f"{self._server_unique_id}/Switch" + + @property + def icon(self): + """Icon to use in the frontend, if any.""" + return "mdi:pi-hole" + + @property + def is_on(self): + """Return if the service is on.""" + return self.api.data.get("status") == "enabled" + + async def async_turn_on(self, **kwargs): + """Turn on the service.""" + try: + await self.api.enable() + await self.async_update() + except HoleError as err: + _LOGGER.error("Unable to enable Pi-hole: %s", err) + + async def async_turn_off(self, **kwargs): + """Turn off the service.""" + await self.async_disable() + + async def async_disable(self, duration=None): + """Disable the service for a given duration.""" + duration_seconds = True # Disable infinitely by default + if duration is not None: + duration_seconds = duration.total_seconds() + _LOGGER.debug( + "Disabling Pi-hole '%s' (%s) for %d seconds", + self.name, + self.api.host, + duration_seconds, + ) + try: + await self.api.disable(duration_seconds) + await self.async_update() + except HoleError as err: + _LOGGER.error("Unable to disable Pi-hole: %s", err) diff --git a/tests/components/pi_hole/__init__.py b/tests/components/pi_hole/__init__.py index b39bfdced2a..f487f413363 100644 --- a/tests/components/pi_hole/__init__.py +++ b/tests/components/pi_hole/__init__.py @@ -21,7 +21,7 @@ ZERO_DATA = { "domains_being_blocked": 0, "queries_cached": 0, "queries_forwarded": 0, - "status": 0, + "status": "disabled", "unique_clients": 0, "unique_domains": 0, } @@ -29,7 +29,7 @@ ZERO_DATA = { HOST = "1.2.3.4" PORT = 80 LOCATION = "location" -NAME = "name" +NAME = "Pi hole" API_KEY = "apikey" SSL = False VERIFY_SSL = True @@ -53,6 +53,8 @@ CONF_CONFIG_FLOW = { CONF_VERIFY_SSL: VERIFY_SSL, } +SWITCH_ENTITY_ID = "switch.pi_hole" + def _create_mocked_hole(raise_exception=False): mocked_hole = MagicMock() @@ -65,6 +67,10 @@ def _create_mocked_hole(raise_exception=False): return mocked_hole +def _patch_init_hole(mocked_hole): + return patch("homeassistant.components.pi_hole.Hole", return_value=mocked_hole) + + def _patch_config_flow_hole(mocked_hole): return patch( "homeassistant.components.pi_hole.config_flow.Hole", return_value=mocked_hole diff --git a/tests/components/pi_hole/test_config_flow.py b/tests/components/pi_hole/test_config_flow.py index 32b5b1ca146..07a9e08313a 100644 --- a/tests/components/pi_hole/test_config_flow.py +++ b/tests/components/pi_hole/test_config_flow.py @@ -1,5 +1,4 @@ """Test pi_hole config flow.""" -import copy import logging from homeassistant.components.pi_hole.const import DOMAIN @@ -13,7 +12,6 @@ from homeassistant.data_entry_flow import ( from . import ( CONF_CONFIG_FLOW, CONF_DATA, - CONF_HOST, NAME, _create_mocked_hole, _patch_config_flow_hole, @@ -54,16 +52,6 @@ async def test_flow_import(hass, caplog): assert result["type"] == RESULT_TYPE_ABORT assert result["reason"] == "already_configured" - # duplicated name - conf_data = copy.deepcopy(CONF_DATA) - conf_data[CONF_HOST] = "4.3.2.1" - result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": SOURCE_IMPORT}, data=conf_data - ) - assert result["type"] == RESULT_TYPE_ABORT - assert result["reason"] == "duplicated_name" - assert len([x for x in caplog.records if x.levelno == logging.ERROR]) == 1 - async def test_flow_import_invalid(hass, caplog): """Test import flow with invalid server.""" @@ -103,15 +91,6 @@ async def test_flow_user(hass): assert result["type"] == RESULT_TYPE_ABORT assert result["reason"] == "already_configured" - # duplicated name - conf_data = copy.deepcopy(CONF_CONFIG_FLOW) - conf_data[CONF_HOST] = "4.3.2.1" - result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": SOURCE_USER}, data=conf_data - ) - assert result["type"] == RESULT_TYPE_ABORT - assert result["reason"] == "duplicated_name" - async def test_flow_user_invalid(hass): """Test user initialized flow with invalid server.""" diff --git a/tests/components/pi_hole/test_init.py b/tests/components/pi_hole/test_init.py index 088b56d75b9..1f3e2451895 100644 --- a/tests/components/pi_hole/test_init.py +++ b/tests/components/pi_hole/test_init.py @@ -1,18 +1,36 @@ """Test pi_hole component.""" +import logging -from homeassistant.components import pi_hole -from homeassistant.components.pi_hole.const import MIN_TIME_BETWEEN_UPDATES +from hole.exceptions import HoleError + +from homeassistant.components import pi_hole, switch +from homeassistant.components.pi_hole.const import ( + CONF_LOCATION, + DEFAULT_LOCATION, + DEFAULT_NAME, + DEFAULT_SSL, + DEFAULT_VERIFY_SSL, + SERVICE_DISABLE, + SERVICE_DISABLE_ATTR_DURATION, +) +from homeassistant.const import ( + ATTR_ENTITY_ID, + CONF_HOST, + CONF_NAME, + CONF_SSL, + CONF_VERIFY_SSL, +) from homeassistant.setup import async_setup_component -from homeassistant.util import dt as dt_util -from . import _create_mocked_hole, _patch_config_flow_hole +from . import ( + SWITCH_ENTITY_ID, + _create_mocked_hole, + _patch_config_flow_hole, + _patch_init_hole, +) -from tests.async_mock import patch -from tests.common import async_fire_time_changed - - -def _patch_init_hole(mocked_hole): - return patch("homeassistant.components.pi_hole.Hole", return_value=mocked_hole) +from tests.async_mock import AsyncMock +from tests.common import MockConfigEntry async def test_setup_minimal_config(hass): @@ -69,6 +87,9 @@ async def test_setup_minimal_config(hass): assert hass.states.get("sensor.pi_hole_domains_blocked").state == "0" assert hass.states.get("sensor.pi_hole_seen_clients").state == "0" + assert hass.states.get("binary_sensor.pi_hole").name == "Pi-Hole" + assert hass.states.get("binary_sensor.pi_hole").state == "off" + async def test_setup_name_config(hass): """Tests component setup with a custom name.""" @@ -88,6 +109,54 @@ async def test_setup_name_config(hass): ) +async def test_switch(hass, caplog): + """Test Pi-hole switch.""" + mocked_hole = _create_mocked_hole() + with _patch_config_flow_hole(mocked_hole), _patch_init_hole(mocked_hole): + assert await async_setup_component( + hass, + pi_hole.DOMAIN, + {pi_hole.DOMAIN: [{"host": "pi.hole1", "api_key": "1"}]}, + ) + + await hass.async_block_till_done() + + await hass.services.async_call( + switch.DOMAIN, + switch.SERVICE_TURN_ON, + {"entity_id": SWITCH_ENTITY_ID}, + blocking=True, + ) + mocked_hole.enable.assert_called_once() + + await hass.services.async_call( + switch.DOMAIN, + switch.SERVICE_TURN_OFF, + {"entity_id": SWITCH_ENTITY_ID}, + blocking=True, + ) + mocked_hole.disable.assert_called_once_with(True) + + # Failed calls + type(mocked_hole).enable = AsyncMock(side_effect=HoleError("Error1")) + await hass.services.async_call( + switch.DOMAIN, + switch.SERVICE_TURN_ON, + {"entity_id": SWITCH_ENTITY_ID}, + blocking=True, + ) + type(mocked_hole).disable = AsyncMock(side_effect=HoleError("Error2")) + await hass.services.async_call( + switch.DOMAIN, + switch.SERVICE_TURN_OFF, + {"entity_id": SWITCH_ENTITY_ID}, + blocking=True, + ) + errors = [x for x in caplog.records if x.levelno == logging.ERROR] + assert errors[-2].message == "Unable to enable Pi-hole: Error1" + assert errors[-1].message == "Unable to disable Pi-hole: Error2" + + async def test_disable_service_call(hass): """Test disable service call with no Pi-hole named.""" mocked_hole = _create_mocked_hole() @@ -98,7 +167,7 @@ async def test_disable_service_call(hass): { pi_hole.DOMAIN: [ {"host": "pi.hole1", "api_key": "1"}, - {"host": "pi.hole2", "name": "Custom", "api_key": "2"}, + {"host": "pi.hole2", "name": "Custom"}, ] }, ) @@ -107,57 +176,35 @@ async def test_disable_service_call(hass): await hass.services.async_call( pi_hole.DOMAIN, - pi_hole.SERVICE_DISABLE, - {pi_hole.SERVICE_DISABLE_ATTR_DURATION: "00:00:01"}, + SERVICE_DISABLE, + {ATTR_ENTITY_ID: "all", SERVICE_DISABLE_ATTR_DURATION: "00:00:01"}, blocking=True, ) await hass.async_block_till_done() - assert mocked_hole.disable.call_count == 2 + mocked_hole.disable.assert_called_once_with(1) -async def test_enable_service_call(hass): - """Test enable service call with no Pi-hole named.""" +async def test_unload(hass): + """Test unload entities.""" + entry = MockConfigEntry( + domain=pi_hole.DOMAIN, + data={ + CONF_NAME: DEFAULT_NAME, + CONF_HOST: "pi.hole", + CONF_LOCATION: DEFAULT_LOCATION, + CONF_SSL: DEFAULT_SSL, + CONF_VERIFY_SSL: DEFAULT_VERIFY_SSL, + }, + ) + entry.add_to_hass(hass) mocked_hole = _create_mocked_hole() with _patch_config_flow_hole(mocked_hole), _patch_init_hole(mocked_hole): - assert await async_setup_component( - hass, - pi_hole.DOMAIN, - { - pi_hole.DOMAIN: [ - {"host": "pi.hole1", "api_key": "1"}, - {"host": "pi.hole2", "name": "Custom", "api_key": "2"}, - ] - }, - ) - + await hass.config_entries.async_setup(entry.entry_id) await hass.async_block_till_done() + assert entry.entry_id in hass.data[pi_hole.DOMAIN] - await hass.services.async_call( - pi_hole.DOMAIN, pi_hole.SERVICE_ENABLE, {}, blocking=True - ) - - await hass.async_block_till_done() - - assert mocked_hole.enable.call_count == 2 - - -async def test_update_coordinator(hass): - """Test update coordinator.""" - mocked_hole = _create_mocked_hole() - sensor_entity_id = "sensor.pi_hole_ads_blocked_today" - with _patch_config_flow_hole(mocked_hole), _patch_init_hole(mocked_hole): - assert await async_setup_component( - hass, pi_hole.DOMAIN, {pi_hole.DOMAIN: [{"host": "pi.hole"}]} - ) - await hass.async_block_till_done() - assert mocked_hole.get_data.call_count == 3 - assert hass.states.get(sensor_entity_id).state == "0" - - mocked_hole.data["ads_blocked_today"] = 1 - utcnow = dt_util.utcnow() - async_fire_time_changed(hass, utcnow + MIN_TIME_BETWEEN_UPDATES) - await hass.async_block_till_done() - assert mocked_hole.get_data.call_count == 4 - assert hass.states.get(sensor_entity_id).state == "1" + assert await hass.config_entries.async_unload(entry.entry_id) + await hass.async_block_till_done() + assert entry.entry_id not in hass.data[pi_hole.DOMAIN]