Add switch to pi_hole integration (#35605)

Co-authored-by: Ian <vividboarder@gmail.com>
This commit is contained in:
Xiaonan Shen 2020-07-18 14:19:01 +08:00 committed by GitHub
parent 1acdb28cdd
commit 394194d1e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 341 additions and 230 deletions

View File

@ -1,11 +1,11 @@
"""The pi_hole component.""" """The pi_hole component."""
import asyncio
import logging import logging
from hole import Hole from hole import Hole
from hole.exceptions import HoleError from hole.exceptions import HoleError
import voluptuous as vol import voluptuous as vol
from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN
from homeassistant.config_entries import SOURCE_IMPORT from homeassistant.config_entries import SOURCE_IMPORT
from homeassistant.const import ( from homeassistant.const import (
CONF_API_KEY, CONF_API_KEY,
@ -14,9 +14,11 @@ from homeassistant.const import (
CONF_SSL, CONF_SSL,
CONF_VERIFY_SSL, CONF_VERIFY_SSL,
) )
from homeassistant.core import callback
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
from .const import ( from .const import (
@ -29,11 +31,6 @@ from .const import (
DEFAULT_VERIFY_SSL, DEFAULT_VERIFY_SSL,
DOMAIN, DOMAIN,
MIN_TIME_BETWEEN_UPDATES, MIN_TIME_BETWEEN_UPDATES,
SERVICE_DISABLE,
SERVICE_DISABLE_ATTR_DURATION,
SERVICE_DISABLE_ATTR_NAME,
SERVICE_ENABLE,
SERVICE_ENABLE_ATTR_NAME,
) )
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -58,20 +55,7 @@ CONFIG_SCHEMA = vol.Schema(
async def async_setup(hass, config): async def async_setup(hass, config):
"""Set up the Pi_hole integration.""" """Set up the Pi-hole integration."""
service_disable_schema = vol.Schema(
vol.All(
{
vol.Required(SERVICE_DISABLE_ATTR_DURATION): vol.All(
cv.time_period_str, cv.positive_timedelta
),
vol.Optional(SERVICE_DISABLE_ATTR_NAME): str,
},
)
)
service_enable_schema = vol.Schema({vol.Optional(SERVICE_ENABLE_ATTR_NAME): str})
hass.data[DOMAIN] = {} hass.data[DOMAIN] = {}
@ -84,71 +68,6 @@ async def async_setup(hass, config):
) )
) )
def get_api_from_name(name):
"""Get Pi-hole API object from user configured name."""
hole_data = hass.data[DOMAIN].get(name)
if hole_data is None:
_LOGGER.error("Unknown Pi-hole name %s", name)
return None
api = hole_data[DATA_KEY_API]
if not api.api_token:
_LOGGER.error(
"Pi-hole %s must have an api_key provided in configuration to be enabled",
name,
)
return None
return api
async def disable_service_handler(call):
"""Handle the service call to disable a single Pi-hole or all configured Pi-holes."""
duration = call.data[SERVICE_DISABLE_ATTR_DURATION].total_seconds()
name = call.data.get(SERVICE_DISABLE_ATTR_NAME)
async def do_disable(name):
"""Disable the named Pi-hole."""
api = get_api_from_name(name)
if api is None:
return
_LOGGER.debug(
"Disabling Pi-hole '%s' (%s) for %d seconds", name, api.host, duration,
)
await api.disable(duration)
if name is not None:
await do_disable(name)
else:
for name in hass.data[DOMAIN]:
await do_disable(name)
async def enable_service_handler(call):
"""Handle the service call to enable a single Pi-hole or all configured Pi-holes."""
name = call.data.get(SERVICE_ENABLE_ATTR_NAME)
async def do_enable(name):
"""Enable the named Pi-hole."""
api = get_api_from_name(name)
if api is None:
return
_LOGGER.debug("Enabling Pi-hole '%s' (%s)", name, api.host)
await api.enable()
if name is not None:
await do_enable(name)
else:
for name in hass.data[DOMAIN]:
await do_enable(name)
hass.services.async_register(
DOMAIN, SERVICE_DISABLE, disable_service_handler, schema=service_disable_schema
)
hass.services.async_register(
DOMAIN, SERVICE_ENABLE, enable_service_handler, schema=service_enable_schema
)
return True return True
@ -187,19 +106,85 @@ async def async_setup_entry(hass, entry):
update_method=async_update_data, update_method=async_update_data,
update_interval=MIN_TIME_BETWEEN_UPDATES, update_interval=MIN_TIME_BETWEEN_UPDATES,
) )
hass.data[DOMAIN][name] = { hass.data[DOMAIN][entry.entry_id] = {
DATA_KEY_API: api, DATA_KEY_API: api,
DATA_KEY_COORDINATOR: coordinator, DATA_KEY_COORDINATOR: coordinator,
} }
hass.async_create_task( for platform in _async_platforms(entry):
hass.config_entries.async_forward_entry_setup(entry, SENSOR_DOMAIN) hass.async_create_task(
) hass.config_entries.async_forward_entry_setup(entry, platform)
)
return True return True
async def async_unload_entry(hass, entry): async def async_unload_entry(hass, entry):
"""Unload pi-hole entry.""" """Unload Pi-hole entry."""
hass.data[DOMAIN].pop(entry.data[CONF_NAME]) unload_ok = all(
return await hass.config_entries.async_forward_entry_unload(entry, SENSOR_DOMAIN) await asyncio.gather(
*[
hass.config_entries.async_forward_entry_unload(entry, platform)
for platform in _async_platforms(entry)
]
)
)
if unload_ok:
hass.data[DOMAIN].pop(entry.entry_id)
return unload_ok
@callback
def _async_platforms(entry):
"""Return platforms to be loaded / unloaded."""
platforms = ["sensor"]
if entry.data.get(CONF_API_KEY):
platforms.append("switch")
else:
platforms.append("binary_sensor")
return platforms
class PiHoleEntity(Entity):
"""Representation of a Pi-hole entity."""
def __init__(self, api, coordinator, name, server_unique_id):
"""Initialize a Pi-hole entity."""
self.api = api
self.coordinator = coordinator
self._name = name
self._server_unique_id = server_unique_id
async def async_added_to_hass(self):
"""When entity is added to hass."""
self.async_on_remove(
self.coordinator.async_add_listener(self.async_write_ha_state)
)
@property
def icon(self):
"""Icon to use in the frontend, if any."""
return "mdi:pi-hole"
@property
def device_info(self):
"""Return the device information of the entity."""
return {
"identifiers": {(DOMAIN, self._server_unique_id)},
"name": self._name,
"manufacturer": "Pi-hole",
}
@property
def available(self):
"""Could the device be accessed during the last update call."""
return self.coordinator.last_update_success
@property
def should_poll(self):
"""No need to poll. Coordinator notifies entity of updates."""
return False
async def async_update(self):
"""Get the latest data from the Pi-hole API."""
await self.coordinator.async_request_refresh()

View File

@ -0,0 +1,44 @@
"""Support for getting status from a Pi-hole system."""
import logging
from homeassistant.components.binary_sensor import BinarySensorEntity
from homeassistant.const import CONF_NAME
from . import PiHoleEntity
from .const import DATA_KEY_API, DATA_KEY_COORDINATOR, DOMAIN as PIHOLE_DOMAIN
_LOGGER = logging.getLogger(__name__)
async def async_setup_entry(hass, entry, async_add_entities):
"""Set up the Pi-hole binary sensor."""
name = entry.data[CONF_NAME]
hole_data = hass.data[PIHOLE_DOMAIN][entry.entry_id]
binary_sensors = [
PiHoleBinarySensor(
hole_data[DATA_KEY_API],
hole_data[DATA_KEY_COORDINATOR],
name,
entry.entry_id,
)
]
async_add_entities(binary_sensors, True)
class PiHoleBinarySensor(PiHoleEntity, BinarySensorEntity):
"""Representation of a Pi-hole binary sensor."""
@property
def name(self):
"""Return the name of the sensor."""
return self._name
@property
def unique_id(self):
"""Return the unique id of the sensor."""
return f"{self._server_unique_id}/Status"
@property
def is_on(self):
"""Return if the service is on."""
return self.api.data.get("status") == "enabled"

View File

@ -60,10 +60,6 @@ class PiHoleFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
if await self._async_endpoint_existed(endpoint): if await self._async_endpoint_existed(endpoint):
return self.async_abort(reason="already_configured") return self.async_abort(reason="already_configured")
if await self._async_name_existed(name):
if is_import:
_LOGGER.error("Failed to import: name %s already existed", name)
return self.async_abort(reason="duplicated_name")
try: try:
await self._async_try_connect( await self._async_try_connect(
@ -127,12 +123,6 @@ class PiHoleFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
] ]
return endpoint in existing_endpoints return endpoint in existing_endpoints
async def _async_name_existed(self, name):
existing_names = [
entry.data.get(CONF_NAME) for entry in self._async_current_entries()
]
return name in existing_names
async def _async_try_connect(self, host, location, tls, verify_tls, api_token): async def _async_try_connect(self, host, location, tls, verify_tls, api_token):
session = async_get_clientsession(self.hass, verify_tls) session = async_get_clientsession(self.hass, verify_tls)
pi_hole = Hole( pi_hole = Hole(

View File

@ -15,9 +15,6 @@ DEFAULT_VERIFY_SSL = True
SERVICE_DISABLE = "disable" SERVICE_DISABLE = "disable"
SERVICE_DISABLE_ATTR_DURATION = "duration" SERVICE_DISABLE_ATTR_DURATION = "duration"
SERVICE_DISABLE_ATTR_NAME = "name"
SERVICE_ENABLE = "enable"
SERVICE_ENABLE_ATTR_NAME = SERVICE_DISABLE_ATTR_NAME
ATTR_BLOCKED_DOMAINS = "domains_blocked" ATTR_BLOCKED_DOMAINS = "domains_blocked"

View File

@ -2,8 +2,8 @@
import logging import logging
from homeassistant.const import CONF_NAME from homeassistant.const import CONF_NAME
from homeassistant.helpers.entity import Entity
from . import PiHoleEntity
from .const import ( from .const import (
ATTR_BLOCKED_DOMAINS, ATTR_BLOCKED_DOMAINS,
DATA_KEY_API, DATA_KEY_API,
@ -19,7 +19,7 @@ LOGGER = logging.getLogger(__name__)
async def async_setup_entry(hass, entry, async_add_entities): async def async_setup_entry(hass, entry, async_add_entities):
"""Set up the Pi-hole sensor.""" """Set up the Pi-hole sensor."""
name = entry.data[CONF_NAME] name = entry.data[CONF_NAME]
hole_data = hass.data[PIHOLE_DOMAIN][name] hole_data = hass.data[PIHOLE_DOMAIN][entry.entry_id]
sensors = [ sensors = [
PiHoleSensor( PiHoleSensor(
hole_data[DATA_KEY_API], hole_data[DATA_KEY_API],
@ -33,28 +33,20 @@ async def async_setup_entry(hass, entry, async_add_entities):
async_add_entities(sensors, True) async_add_entities(sensors, True)
class PiHoleSensor(Entity): class PiHoleSensor(PiHoleEntity):
"""Representation of a Pi-hole sensor.""" """Representation of a Pi-hole sensor."""
def __init__(self, api, coordinator, name, sensor_name, server_unique_id): def __init__(self, api, coordinator, name, sensor_name, server_unique_id):
"""Initialize a Pi-hole sensor.""" """Initialize a Pi-hole sensor."""
self.api = api super().__init__(api, coordinator, name, server_unique_id)
self.coordinator = coordinator
self._name = name
self._condition = sensor_name self._condition = sensor_name
self._server_unique_id = server_unique_id
variable_info = SENSOR_DICT[sensor_name] variable_info = SENSOR_DICT[sensor_name]
self._condition_name = variable_info[0] self._condition_name = variable_info[0]
self._unit_of_measurement = variable_info[1] self._unit_of_measurement = variable_info[1]
self._icon = variable_info[2] self._icon = variable_info[2]
async def async_added_to_hass(self):
"""When entity is added to hass."""
self.async_on_remove(
self.coordinator.async_add_listener(self.async_write_ha_state)
)
@property @property
def name(self): def name(self):
"""Return the name of the sensor.""" """Return the name of the sensor."""
@ -65,15 +57,6 @@ class PiHoleSensor(Entity):
"""Return the unique id of the sensor.""" """Return the unique id of the sensor."""
return f"{self._server_unique_id}/{self._condition_name}" return f"{self._server_unique_id}/{self._condition_name}"
@property
def device_info(self):
"""Return the device information of the sensor."""
return {
"identifiers": {(PIHOLE_DOMAIN, self._server_unique_id)},
"name": self._name,
"manufacturer": "Pi-hole",
}
@property @property
def icon(self): def icon(self):
"""Icon to use in the frontend, if any.""" """Icon to use in the frontend, if any."""
@ -96,17 +79,3 @@ class PiHoleSensor(Entity):
def device_state_attributes(self): def device_state_attributes(self):
"""Return the state attributes of the Pi-hole.""" """Return the state attributes of the Pi-hole."""
return {ATTR_BLOCKED_DOMAINS: self.api.data["domains_being_blocked"]} return {ATTR_BLOCKED_DOMAINS: self.api.data["domains_being_blocked"]}
@property
def available(self):
"""Could the device be accessed during the last update call."""
return self.coordinator.last_update_success
@property
def should_poll(self):
"""No need to poll. Coordinator notifies entity of updates."""
return False
async def async_update(self):
"""Get the latest data from the Pi-hole API."""
await self.coordinator.async_request_refresh()

View File

@ -1,15 +1,9 @@
disable: disable:
description: Disable configured Pi-hole(s) for an amount of time description: Disable configured Pi-hole(s) for an amount of time
fields: fields:
entity_id:
description: Target switch entity
example: switch.pi_hole
duration: duration:
description: Time that the Pi-hole should be disabled for description: Time that the Pi-hole should be disabled for
example: "00:00:15" example: "00:00:15"
name:
description: "[Optional] When multiple Pi-holes are configured, the name of the one to disable. If omitted, all configured Pi-holes will be disabled."
example: "Pi-Hole"
enable:
description: Enable configured Pi-hole(s)
fields:
name:
description: "[Optional] When multiple Pi-holes are configured, the name of the one to enable. If omitted, all configured Pi-holes will be enabled."
example: "Pi-Hole"

View File

@ -6,7 +6,8 @@
"host": "[%key:common::config_flow::data::host%]", "host": "[%key:common::config_flow::data::host%]",
"port": "[%key:common::config_flow::data::port%]", "port": "[%key:common::config_flow::data::port%]",
"name": "Name", "name": "Name",
"api_key": "API Key (Optional)", "location": "Location",
"api_key": "[%key:common::config_flow::data::api_key%]",
"ssl": "Use SSL", "ssl": "Use SSL",
"verify_ssl": "Verify SSL certificate" "verify_ssl": "Verify SSL certificate"
} }
@ -16,8 +17,7 @@
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]" "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]"
}, },
"abort": { "abort": {
"already_configured": "[%key:common::config_flow::abort::already_configured_service%]", "already_configured": "[%key:common::config_flow::abort::already_configured_service%]"
"duplicated_name": "Name already existed"
} }
} }
} }

View File

@ -0,0 +1,100 @@
"""Support for turning on and off Pi-hole system."""
import logging
from hole.exceptions import HoleError
import voluptuous as vol
from homeassistant.components.switch import SwitchEntity
from homeassistant.const import CONF_NAME
from homeassistant.helpers import config_validation as cv, entity_platform
from . import PiHoleEntity
from .const import (
DATA_KEY_API,
DATA_KEY_COORDINATOR,
DOMAIN as PIHOLE_DOMAIN,
SERVICE_DISABLE,
SERVICE_DISABLE_ATTR_DURATION,
)
_LOGGER = logging.getLogger(__name__)
async def async_setup_entry(hass, entry, async_add_entities):
"""Set up the Pi-hole switch."""
name = entry.data[CONF_NAME]
hole_data = hass.data[PIHOLE_DOMAIN][entry.entry_id]
switches = [
PiHoleSwitch(
hole_data[DATA_KEY_API],
hole_data[DATA_KEY_COORDINATOR],
name,
entry.entry_id,
)
]
async_add_entities(switches, True)
# register service
platform = entity_platform.current_platform.get()
platform.async_register_entity_service(
SERVICE_DISABLE,
{
vol.Required(SERVICE_DISABLE_ATTR_DURATION): vol.All(
cv.time_period_str, cv.positive_timedelta
),
},
"async_disable",
)
class PiHoleSwitch(PiHoleEntity, SwitchEntity):
"""Representation of a Pi-hole switch."""
@property
def name(self):
"""Return the name of the switch."""
return self._name
@property
def unique_id(self):
"""Return the unique id of the switch."""
return f"{self._server_unique_id}/Switch"
@property
def icon(self):
"""Icon to use in the frontend, if any."""
return "mdi:pi-hole"
@property
def is_on(self):
"""Return if the service is on."""
return self.api.data.get("status") == "enabled"
async def async_turn_on(self, **kwargs):
"""Turn on the service."""
try:
await self.api.enable()
await self.async_update()
except HoleError as err:
_LOGGER.error("Unable to enable Pi-hole: %s", err)
async def async_turn_off(self, **kwargs):
"""Turn off the service."""
await self.async_disable()
async def async_disable(self, duration=None):
"""Disable the service for a given duration."""
duration_seconds = True # Disable infinitely by default
if duration is not None:
duration_seconds = duration.total_seconds()
_LOGGER.debug(
"Disabling Pi-hole '%s' (%s) for %d seconds",
self.name,
self.api.host,
duration_seconds,
)
try:
await self.api.disable(duration_seconds)
await self.async_update()
except HoleError as err:
_LOGGER.error("Unable to disable Pi-hole: %s", err)

View File

@ -21,7 +21,7 @@ ZERO_DATA = {
"domains_being_blocked": 0, "domains_being_blocked": 0,
"queries_cached": 0, "queries_cached": 0,
"queries_forwarded": 0, "queries_forwarded": 0,
"status": 0, "status": "disabled",
"unique_clients": 0, "unique_clients": 0,
"unique_domains": 0, "unique_domains": 0,
} }
@ -29,7 +29,7 @@ ZERO_DATA = {
HOST = "1.2.3.4" HOST = "1.2.3.4"
PORT = 80 PORT = 80
LOCATION = "location" LOCATION = "location"
NAME = "name" NAME = "Pi hole"
API_KEY = "apikey" API_KEY = "apikey"
SSL = False SSL = False
VERIFY_SSL = True VERIFY_SSL = True
@ -53,6 +53,8 @@ CONF_CONFIG_FLOW = {
CONF_VERIFY_SSL: VERIFY_SSL, CONF_VERIFY_SSL: VERIFY_SSL,
} }
SWITCH_ENTITY_ID = "switch.pi_hole"
def _create_mocked_hole(raise_exception=False): def _create_mocked_hole(raise_exception=False):
mocked_hole = MagicMock() mocked_hole = MagicMock()
@ -65,6 +67,10 @@ def _create_mocked_hole(raise_exception=False):
return mocked_hole return mocked_hole
def _patch_init_hole(mocked_hole):
return patch("homeassistant.components.pi_hole.Hole", return_value=mocked_hole)
def _patch_config_flow_hole(mocked_hole): def _patch_config_flow_hole(mocked_hole):
return patch( return patch(
"homeassistant.components.pi_hole.config_flow.Hole", return_value=mocked_hole "homeassistant.components.pi_hole.config_flow.Hole", return_value=mocked_hole

View File

@ -1,5 +1,4 @@
"""Test pi_hole config flow.""" """Test pi_hole config flow."""
import copy
import logging import logging
from homeassistant.components.pi_hole.const import DOMAIN from homeassistant.components.pi_hole.const import DOMAIN
@ -13,7 +12,6 @@ from homeassistant.data_entry_flow import (
from . import ( from . import (
CONF_CONFIG_FLOW, CONF_CONFIG_FLOW,
CONF_DATA, CONF_DATA,
CONF_HOST,
NAME, NAME,
_create_mocked_hole, _create_mocked_hole,
_patch_config_flow_hole, _patch_config_flow_hole,
@ -54,16 +52,6 @@ async def test_flow_import(hass, caplog):
assert result["type"] == RESULT_TYPE_ABORT assert result["type"] == RESULT_TYPE_ABORT
assert result["reason"] == "already_configured" assert result["reason"] == "already_configured"
# duplicated name
conf_data = copy.deepcopy(CONF_DATA)
conf_data[CONF_HOST] = "4.3.2.1"
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_IMPORT}, data=conf_data
)
assert result["type"] == RESULT_TYPE_ABORT
assert result["reason"] == "duplicated_name"
assert len([x for x in caplog.records if x.levelno == logging.ERROR]) == 1
async def test_flow_import_invalid(hass, caplog): async def test_flow_import_invalid(hass, caplog):
"""Test import flow with invalid server.""" """Test import flow with invalid server."""
@ -103,15 +91,6 @@ async def test_flow_user(hass):
assert result["type"] == RESULT_TYPE_ABORT assert result["type"] == RESULT_TYPE_ABORT
assert result["reason"] == "already_configured" assert result["reason"] == "already_configured"
# duplicated name
conf_data = copy.deepcopy(CONF_CONFIG_FLOW)
conf_data[CONF_HOST] = "4.3.2.1"
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_USER}, data=conf_data
)
assert result["type"] == RESULT_TYPE_ABORT
assert result["reason"] == "duplicated_name"
async def test_flow_user_invalid(hass): async def test_flow_user_invalid(hass):
"""Test user initialized flow with invalid server.""" """Test user initialized flow with invalid server."""

View File

@ -1,18 +1,36 @@
"""Test pi_hole component.""" """Test pi_hole component."""
import logging
from homeassistant.components import pi_hole from hole.exceptions import HoleError
from homeassistant.components.pi_hole.const import MIN_TIME_BETWEEN_UPDATES
from homeassistant.components import pi_hole, switch
from homeassistant.components.pi_hole.const import (
CONF_LOCATION,
DEFAULT_LOCATION,
DEFAULT_NAME,
DEFAULT_SSL,
DEFAULT_VERIFY_SSL,
SERVICE_DISABLE,
SERVICE_DISABLE_ATTR_DURATION,
)
from homeassistant.const import (
ATTR_ENTITY_ID,
CONF_HOST,
CONF_NAME,
CONF_SSL,
CONF_VERIFY_SSL,
)
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.util import dt as dt_util
from . import _create_mocked_hole, _patch_config_flow_hole from . import (
SWITCH_ENTITY_ID,
_create_mocked_hole,
_patch_config_flow_hole,
_patch_init_hole,
)
from tests.async_mock import patch from tests.async_mock import AsyncMock
from tests.common import async_fire_time_changed from tests.common import MockConfigEntry
def _patch_init_hole(mocked_hole):
return patch("homeassistant.components.pi_hole.Hole", return_value=mocked_hole)
async def test_setup_minimal_config(hass): async def test_setup_minimal_config(hass):
@ -69,6 +87,9 @@ async def test_setup_minimal_config(hass):
assert hass.states.get("sensor.pi_hole_domains_blocked").state == "0" assert hass.states.get("sensor.pi_hole_domains_blocked").state == "0"
assert hass.states.get("sensor.pi_hole_seen_clients").state == "0" assert hass.states.get("sensor.pi_hole_seen_clients").state == "0"
assert hass.states.get("binary_sensor.pi_hole").name == "Pi-Hole"
assert hass.states.get("binary_sensor.pi_hole").state == "off"
async def test_setup_name_config(hass): async def test_setup_name_config(hass):
"""Tests component setup with a custom name.""" """Tests component setup with a custom name."""
@ -88,6 +109,54 @@ async def test_setup_name_config(hass):
) )
async def test_switch(hass, caplog):
"""Test Pi-hole switch."""
mocked_hole = _create_mocked_hole()
with _patch_config_flow_hole(mocked_hole), _patch_init_hole(mocked_hole):
assert await async_setup_component(
hass,
pi_hole.DOMAIN,
{pi_hole.DOMAIN: [{"host": "pi.hole1", "api_key": "1"}]},
)
await hass.async_block_till_done()
await hass.services.async_call(
switch.DOMAIN,
switch.SERVICE_TURN_ON,
{"entity_id": SWITCH_ENTITY_ID},
blocking=True,
)
mocked_hole.enable.assert_called_once()
await hass.services.async_call(
switch.DOMAIN,
switch.SERVICE_TURN_OFF,
{"entity_id": SWITCH_ENTITY_ID},
blocking=True,
)
mocked_hole.disable.assert_called_once_with(True)
# Failed calls
type(mocked_hole).enable = AsyncMock(side_effect=HoleError("Error1"))
await hass.services.async_call(
switch.DOMAIN,
switch.SERVICE_TURN_ON,
{"entity_id": SWITCH_ENTITY_ID},
blocking=True,
)
type(mocked_hole).disable = AsyncMock(side_effect=HoleError("Error2"))
await hass.services.async_call(
switch.DOMAIN,
switch.SERVICE_TURN_OFF,
{"entity_id": SWITCH_ENTITY_ID},
blocking=True,
)
errors = [x for x in caplog.records if x.levelno == logging.ERROR]
assert errors[-2].message == "Unable to enable Pi-hole: Error1"
assert errors[-1].message == "Unable to disable Pi-hole: Error2"
async def test_disable_service_call(hass): async def test_disable_service_call(hass):
"""Test disable service call with no Pi-hole named.""" """Test disable service call with no Pi-hole named."""
mocked_hole = _create_mocked_hole() mocked_hole = _create_mocked_hole()
@ -98,7 +167,7 @@ async def test_disable_service_call(hass):
{ {
pi_hole.DOMAIN: [ pi_hole.DOMAIN: [
{"host": "pi.hole1", "api_key": "1"}, {"host": "pi.hole1", "api_key": "1"},
{"host": "pi.hole2", "name": "Custom", "api_key": "2"}, {"host": "pi.hole2", "name": "Custom"},
] ]
}, },
) )
@ -107,57 +176,35 @@ async def test_disable_service_call(hass):
await hass.services.async_call( await hass.services.async_call(
pi_hole.DOMAIN, pi_hole.DOMAIN,
pi_hole.SERVICE_DISABLE, SERVICE_DISABLE,
{pi_hole.SERVICE_DISABLE_ATTR_DURATION: "00:00:01"}, {ATTR_ENTITY_ID: "all", SERVICE_DISABLE_ATTR_DURATION: "00:00:01"},
blocking=True, blocking=True,
) )
await hass.async_block_till_done() await hass.async_block_till_done()
assert mocked_hole.disable.call_count == 2 mocked_hole.disable.assert_called_once_with(1)
async def test_enable_service_call(hass): async def test_unload(hass):
"""Test enable service call with no Pi-hole named.""" """Test unload entities."""
entry = MockConfigEntry(
domain=pi_hole.DOMAIN,
data={
CONF_NAME: DEFAULT_NAME,
CONF_HOST: "pi.hole",
CONF_LOCATION: DEFAULT_LOCATION,
CONF_SSL: DEFAULT_SSL,
CONF_VERIFY_SSL: DEFAULT_VERIFY_SSL,
},
)
entry.add_to_hass(hass)
mocked_hole = _create_mocked_hole() mocked_hole = _create_mocked_hole()
with _patch_config_flow_hole(mocked_hole), _patch_init_hole(mocked_hole): with _patch_config_flow_hole(mocked_hole), _patch_init_hole(mocked_hole):
assert await async_setup_component( await hass.config_entries.async_setup(entry.entry_id)
hass,
pi_hole.DOMAIN,
{
pi_hole.DOMAIN: [
{"host": "pi.hole1", "api_key": "1"},
{"host": "pi.hole2", "name": "Custom", "api_key": "2"},
]
},
)
await hass.async_block_till_done() await hass.async_block_till_done()
assert entry.entry_id in hass.data[pi_hole.DOMAIN]
await hass.services.async_call( assert await hass.config_entries.async_unload(entry.entry_id)
pi_hole.DOMAIN, pi_hole.SERVICE_ENABLE, {}, blocking=True await hass.async_block_till_done()
) assert entry.entry_id not in hass.data[pi_hole.DOMAIN]
await hass.async_block_till_done()
assert mocked_hole.enable.call_count == 2
async def test_update_coordinator(hass):
"""Test update coordinator."""
mocked_hole = _create_mocked_hole()
sensor_entity_id = "sensor.pi_hole_ads_blocked_today"
with _patch_config_flow_hole(mocked_hole), _patch_init_hole(mocked_hole):
assert await async_setup_component(
hass, pi_hole.DOMAIN, {pi_hole.DOMAIN: [{"host": "pi.hole"}]}
)
await hass.async_block_till_done()
assert mocked_hole.get_data.call_count == 3
assert hass.states.get(sensor_entity_id).state == "0"
mocked_hole.data["ads_blocked_today"] = 1
utcnow = dt_util.utcnow()
async_fire_time_changed(hass, utcnow + MIN_TIME_BETWEEN_UPDATES)
await hass.async_block_till_done()
assert mocked_hole.get_data.call_count == 4
assert hass.states.get(sensor_entity_id).state == "1"