mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 16:57:53 +00:00
Add switch to pi_hole integration (#35605)
Co-authored-by: Ian <vividboarder@gmail.com>
This commit is contained in:
parent
1acdb28cdd
commit
394194d1e6
@ -1,11 +1,11 @@
|
||||
"""The pi_hole component."""
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from hole import Hole
|
||||
from hole.exceptions import HoleError
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN
|
||||
from homeassistant.config_entries import SOURCE_IMPORT
|
||||
from homeassistant.const import (
|
||||
CONF_API_KEY,
|
||||
@ -14,9 +14,11 @@ from homeassistant.const import (
|
||||
CONF_SSL,
|
||||
CONF_VERIFY_SSL,
|
||||
)
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.exceptions import ConfigEntryNotReady
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
|
||||
|
||||
from .const import (
|
||||
@ -29,11 +31,6 @@ from .const import (
|
||||
DEFAULT_VERIFY_SSL,
|
||||
DOMAIN,
|
||||
MIN_TIME_BETWEEN_UPDATES,
|
||||
SERVICE_DISABLE,
|
||||
SERVICE_DISABLE_ATTR_DURATION,
|
||||
SERVICE_DISABLE_ATTR_NAME,
|
||||
SERVICE_ENABLE,
|
||||
SERVICE_ENABLE_ATTR_NAME,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@ -58,20 +55,7 @@ CONFIG_SCHEMA = vol.Schema(
|
||||
|
||||
|
||||
async def async_setup(hass, config):
|
||||
"""Set up the Pi_hole integration."""
|
||||
|
||||
service_disable_schema = vol.Schema(
|
||||
vol.All(
|
||||
{
|
||||
vol.Required(SERVICE_DISABLE_ATTR_DURATION): vol.All(
|
||||
cv.time_period_str, cv.positive_timedelta
|
||||
),
|
||||
vol.Optional(SERVICE_DISABLE_ATTR_NAME): str,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
service_enable_schema = vol.Schema({vol.Optional(SERVICE_ENABLE_ATTR_NAME): str})
|
||||
"""Set up the Pi-hole integration."""
|
||||
|
||||
hass.data[DOMAIN] = {}
|
||||
|
||||
@ -84,71 +68,6 @@ async def async_setup(hass, config):
|
||||
)
|
||||
)
|
||||
|
||||
def get_api_from_name(name):
|
||||
"""Get Pi-hole API object from user configured name."""
|
||||
hole_data = hass.data[DOMAIN].get(name)
|
||||
if hole_data is None:
|
||||
_LOGGER.error("Unknown Pi-hole name %s", name)
|
||||
return None
|
||||
api = hole_data[DATA_KEY_API]
|
||||
if not api.api_token:
|
||||
_LOGGER.error(
|
||||
"Pi-hole %s must have an api_key provided in configuration to be enabled",
|
||||
name,
|
||||
)
|
||||
return None
|
||||
return api
|
||||
|
||||
async def disable_service_handler(call):
|
||||
"""Handle the service call to disable a single Pi-hole or all configured Pi-holes."""
|
||||
duration = call.data[SERVICE_DISABLE_ATTR_DURATION].total_seconds()
|
||||
name = call.data.get(SERVICE_DISABLE_ATTR_NAME)
|
||||
|
||||
async def do_disable(name):
|
||||
"""Disable the named Pi-hole."""
|
||||
api = get_api_from_name(name)
|
||||
if api is None:
|
||||
return
|
||||
|
||||
_LOGGER.debug(
|
||||
"Disabling Pi-hole '%s' (%s) for %d seconds", name, api.host, duration,
|
||||
)
|
||||
await api.disable(duration)
|
||||
|
||||
if name is not None:
|
||||
await do_disable(name)
|
||||
else:
|
||||
for name in hass.data[DOMAIN]:
|
||||
await do_disable(name)
|
||||
|
||||
async def enable_service_handler(call):
|
||||
"""Handle the service call to enable a single Pi-hole or all configured Pi-holes."""
|
||||
|
||||
name = call.data.get(SERVICE_ENABLE_ATTR_NAME)
|
||||
|
||||
async def do_enable(name):
|
||||
"""Enable the named Pi-hole."""
|
||||
api = get_api_from_name(name)
|
||||
if api is None:
|
||||
return
|
||||
|
||||
_LOGGER.debug("Enabling Pi-hole '%s' (%s)", name, api.host)
|
||||
await api.enable()
|
||||
|
||||
if name is not None:
|
||||
await do_enable(name)
|
||||
else:
|
||||
for name in hass.data[DOMAIN]:
|
||||
await do_enable(name)
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_DISABLE, disable_service_handler, schema=service_disable_schema
|
||||
)
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_ENABLE, enable_service_handler, schema=service_enable_schema
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@ -187,19 +106,85 @@ async def async_setup_entry(hass, entry):
|
||||
update_method=async_update_data,
|
||||
update_interval=MIN_TIME_BETWEEN_UPDATES,
|
||||
)
|
||||
hass.data[DOMAIN][name] = {
|
||||
hass.data[DOMAIN][entry.entry_id] = {
|
||||
DATA_KEY_API: api,
|
||||
DATA_KEY_COORDINATOR: coordinator,
|
||||
}
|
||||
|
||||
hass.async_create_task(
|
||||
hass.config_entries.async_forward_entry_setup(entry, SENSOR_DOMAIN)
|
||||
)
|
||||
for platform in _async_platforms(entry):
|
||||
hass.async_create_task(
|
||||
hass.config_entries.async_forward_entry_setup(entry, platform)
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def async_unload_entry(hass, entry):
|
||||
"""Unload pi-hole entry."""
|
||||
hass.data[DOMAIN].pop(entry.data[CONF_NAME])
|
||||
return await hass.config_entries.async_forward_entry_unload(entry, SENSOR_DOMAIN)
|
||||
"""Unload Pi-hole entry."""
|
||||
unload_ok = all(
|
||||
await asyncio.gather(
|
||||
*[
|
||||
hass.config_entries.async_forward_entry_unload(entry, platform)
|
||||
for platform in _async_platforms(entry)
|
||||
]
|
||||
)
|
||||
)
|
||||
if unload_ok:
|
||||
hass.data[DOMAIN].pop(entry.entry_id)
|
||||
return unload_ok
|
||||
|
||||
|
||||
@callback
|
||||
def _async_platforms(entry):
|
||||
"""Return platforms to be loaded / unloaded."""
|
||||
platforms = ["sensor"]
|
||||
if entry.data.get(CONF_API_KEY):
|
||||
platforms.append("switch")
|
||||
else:
|
||||
platforms.append("binary_sensor")
|
||||
return platforms
|
||||
|
||||
|
||||
class PiHoleEntity(Entity):
|
||||
"""Representation of a Pi-hole entity."""
|
||||
|
||||
def __init__(self, api, coordinator, name, server_unique_id):
|
||||
"""Initialize a Pi-hole entity."""
|
||||
self.api = api
|
||||
self.coordinator = coordinator
|
||||
self._name = name
|
||||
self._server_unique_id = server_unique_id
|
||||
|
||||
async def async_added_to_hass(self):
|
||||
"""When entity is added to hass."""
|
||||
self.async_on_remove(
|
||||
self.coordinator.async_add_listener(self.async_write_ha_state)
|
||||
)
|
||||
|
||||
@property
|
||||
def icon(self):
|
||||
"""Icon to use in the frontend, if any."""
|
||||
return "mdi:pi-hole"
|
||||
|
||||
@property
|
||||
def device_info(self):
|
||||
"""Return the device information of the entity."""
|
||||
return {
|
||||
"identifiers": {(DOMAIN, self._server_unique_id)},
|
||||
"name": self._name,
|
||||
"manufacturer": "Pi-hole",
|
||||
}
|
||||
|
||||
@property
|
||||
def available(self):
|
||||
"""Could the device be accessed during the last update call."""
|
||||
return self.coordinator.last_update_success
|
||||
|
||||
@property
|
||||
def should_poll(self):
|
||||
"""No need to poll. Coordinator notifies entity of updates."""
|
||||
return False
|
||||
|
||||
async def async_update(self):
|
||||
"""Get the latest data from the Pi-hole API."""
|
||||
await self.coordinator.async_request_refresh()
|
||||
|
44
homeassistant/components/pi_hole/binary_sensor.py
Normal file
44
homeassistant/components/pi_hole/binary_sensor.py
Normal file
@ -0,0 +1,44 @@
|
||||
"""Support for getting status from a Pi-hole system."""
|
||||
import logging
|
||||
|
||||
from homeassistant.components.binary_sensor import BinarySensorEntity
|
||||
from homeassistant.const import CONF_NAME
|
||||
|
||||
from . import PiHoleEntity
|
||||
from .const import DATA_KEY_API, DATA_KEY_COORDINATOR, DOMAIN as PIHOLE_DOMAIN
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_setup_entry(hass, entry, async_add_entities):
|
||||
"""Set up the Pi-hole binary sensor."""
|
||||
name = entry.data[CONF_NAME]
|
||||
hole_data = hass.data[PIHOLE_DOMAIN][entry.entry_id]
|
||||
binary_sensors = [
|
||||
PiHoleBinarySensor(
|
||||
hole_data[DATA_KEY_API],
|
||||
hole_data[DATA_KEY_COORDINATOR],
|
||||
name,
|
||||
entry.entry_id,
|
||||
)
|
||||
]
|
||||
async_add_entities(binary_sensors, True)
|
||||
|
||||
|
||||
class PiHoleBinarySensor(PiHoleEntity, BinarySensorEntity):
|
||||
"""Representation of a Pi-hole binary sensor."""
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
"""Return the name of the sensor."""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def unique_id(self):
|
||||
"""Return the unique id of the sensor."""
|
||||
return f"{self._server_unique_id}/Status"
|
||||
|
||||
@property
|
||||
def is_on(self):
|
||||
"""Return if the service is on."""
|
||||
return self.api.data.get("status") == "enabled"
|
@ -60,10 +60,6 @@ class PiHoleFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
|
||||
if await self._async_endpoint_existed(endpoint):
|
||||
return self.async_abort(reason="already_configured")
|
||||
if await self._async_name_existed(name):
|
||||
if is_import:
|
||||
_LOGGER.error("Failed to import: name %s already existed", name)
|
||||
return self.async_abort(reason="duplicated_name")
|
||||
|
||||
try:
|
||||
await self._async_try_connect(
|
||||
@ -127,12 +123,6 @@ class PiHoleFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
]
|
||||
return endpoint in existing_endpoints
|
||||
|
||||
async def _async_name_existed(self, name):
|
||||
existing_names = [
|
||||
entry.data.get(CONF_NAME) for entry in self._async_current_entries()
|
||||
]
|
||||
return name in existing_names
|
||||
|
||||
async def _async_try_connect(self, host, location, tls, verify_tls, api_token):
|
||||
session = async_get_clientsession(self.hass, verify_tls)
|
||||
pi_hole = Hole(
|
||||
|
@ -15,9 +15,6 @@ DEFAULT_VERIFY_SSL = True
|
||||
|
||||
SERVICE_DISABLE = "disable"
|
||||
SERVICE_DISABLE_ATTR_DURATION = "duration"
|
||||
SERVICE_DISABLE_ATTR_NAME = "name"
|
||||
SERVICE_ENABLE = "enable"
|
||||
SERVICE_ENABLE_ATTR_NAME = SERVICE_DISABLE_ATTR_NAME
|
||||
|
||||
ATTR_BLOCKED_DOMAINS = "domains_blocked"
|
||||
|
||||
|
@ -2,8 +2,8 @@
|
||||
import logging
|
||||
|
||||
from homeassistant.const import CONF_NAME
|
||||
from homeassistant.helpers.entity import Entity
|
||||
|
||||
from . import PiHoleEntity
|
||||
from .const import (
|
||||
ATTR_BLOCKED_DOMAINS,
|
||||
DATA_KEY_API,
|
||||
@ -19,7 +19,7 @@ LOGGER = logging.getLogger(__name__)
|
||||
async def async_setup_entry(hass, entry, async_add_entities):
|
||||
"""Set up the Pi-hole sensor."""
|
||||
name = entry.data[CONF_NAME]
|
||||
hole_data = hass.data[PIHOLE_DOMAIN][name]
|
||||
hole_data = hass.data[PIHOLE_DOMAIN][entry.entry_id]
|
||||
sensors = [
|
||||
PiHoleSensor(
|
||||
hole_data[DATA_KEY_API],
|
||||
@ -33,28 +33,20 @@ async def async_setup_entry(hass, entry, async_add_entities):
|
||||
async_add_entities(sensors, True)
|
||||
|
||||
|
||||
class PiHoleSensor(Entity):
|
||||
class PiHoleSensor(PiHoleEntity):
|
||||
"""Representation of a Pi-hole sensor."""
|
||||
|
||||
def __init__(self, api, coordinator, name, sensor_name, server_unique_id):
|
||||
"""Initialize a Pi-hole sensor."""
|
||||
self.api = api
|
||||
self.coordinator = coordinator
|
||||
self._name = name
|
||||
super().__init__(api, coordinator, name, server_unique_id)
|
||||
|
||||
self._condition = sensor_name
|
||||
self._server_unique_id = server_unique_id
|
||||
|
||||
variable_info = SENSOR_DICT[sensor_name]
|
||||
self._condition_name = variable_info[0]
|
||||
self._unit_of_measurement = variable_info[1]
|
||||
self._icon = variable_info[2]
|
||||
|
||||
async def async_added_to_hass(self):
|
||||
"""When entity is added to hass."""
|
||||
self.async_on_remove(
|
||||
self.coordinator.async_add_listener(self.async_write_ha_state)
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
"""Return the name of the sensor."""
|
||||
@ -65,15 +57,6 @@ class PiHoleSensor(Entity):
|
||||
"""Return the unique id of the sensor."""
|
||||
return f"{self._server_unique_id}/{self._condition_name}"
|
||||
|
||||
@property
|
||||
def device_info(self):
|
||||
"""Return the device information of the sensor."""
|
||||
return {
|
||||
"identifiers": {(PIHOLE_DOMAIN, self._server_unique_id)},
|
||||
"name": self._name,
|
||||
"manufacturer": "Pi-hole",
|
||||
}
|
||||
|
||||
@property
|
||||
def icon(self):
|
||||
"""Icon to use in the frontend, if any."""
|
||||
@ -96,17 +79,3 @@ class PiHoleSensor(Entity):
|
||||
def device_state_attributes(self):
|
||||
"""Return the state attributes of the Pi-hole."""
|
||||
return {ATTR_BLOCKED_DOMAINS: self.api.data["domains_being_blocked"]}
|
||||
|
||||
@property
|
||||
def available(self):
|
||||
"""Could the device be accessed during the last update call."""
|
||||
return self.coordinator.last_update_success
|
||||
|
||||
@property
|
||||
def should_poll(self):
|
||||
"""No need to poll. Coordinator notifies entity of updates."""
|
||||
return False
|
||||
|
||||
async def async_update(self):
|
||||
"""Get the latest data from the Pi-hole API."""
|
||||
await self.coordinator.async_request_refresh()
|
||||
|
@ -1,15 +1,9 @@
|
||||
disable:
|
||||
description: Disable configured Pi-hole(s) for an amount of time
|
||||
fields:
|
||||
entity_id:
|
||||
description: Target switch entity
|
||||
example: switch.pi_hole
|
||||
duration:
|
||||
description: Time that the Pi-hole should be disabled for
|
||||
example: "00:00:15"
|
||||
name:
|
||||
description: "[Optional] When multiple Pi-holes are configured, the name of the one to disable. If omitted, all configured Pi-holes will be disabled."
|
||||
example: "Pi-Hole"
|
||||
enable:
|
||||
description: Enable configured Pi-hole(s)
|
||||
fields:
|
||||
name:
|
||||
description: "[Optional] When multiple Pi-holes are configured, the name of the one to enable. If omitted, all configured Pi-holes will be enabled."
|
||||
example: "Pi-Hole"
|
||||
|
@ -6,7 +6,8 @@
|
||||
"host": "[%key:common::config_flow::data::host%]",
|
||||
"port": "[%key:common::config_flow::data::port%]",
|
||||
"name": "Name",
|
||||
"api_key": "API Key (Optional)",
|
||||
"location": "Location",
|
||||
"api_key": "[%key:common::config_flow::data::api_key%]",
|
||||
"ssl": "Use SSL",
|
||||
"verify_ssl": "Verify SSL certificate"
|
||||
}
|
||||
@ -16,8 +17,7 @@
|
||||
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]"
|
||||
},
|
||||
"abort": {
|
||||
"already_configured": "[%key:common::config_flow::abort::already_configured_service%]",
|
||||
"duplicated_name": "Name already existed"
|
||||
"already_configured": "[%key:common::config_flow::abort::already_configured_service%]"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
100
homeassistant/components/pi_hole/switch.py
Normal file
100
homeassistant/components/pi_hole/switch.py
Normal file
@ -0,0 +1,100 @@
|
||||
"""Support for turning on and off Pi-hole system."""
|
||||
import logging
|
||||
|
||||
from hole.exceptions import HoleError
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.switch import SwitchEntity
|
||||
from homeassistant.const import CONF_NAME
|
||||
from homeassistant.helpers import config_validation as cv, entity_platform
|
||||
|
||||
from . import PiHoleEntity
|
||||
from .const import (
|
||||
DATA_KEY_API,
|
||||
DATA_KEY_COORDINATOR,
|
||||
DOMAIN as PIHOLE_DOMAIN,
|
||||
SERVICE_DISABLE,
|
||||
SERVICE_DISABLE_ATTR_DURATION,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_setup_entry(hass, entry, async_add_entities):
|
||||
"""Set up the Pi-hole switch."""
|
||||
name = entry.data[CONF_NAME]
|
||||
hole_data = hass.data[PIHOLE_DOMAIN][entry.entry_id]
|
||||
switches = [
|
||||
PiHoleSwitch(
|
||||
hole_data[DATA_KEY_API],
|
||||
hole_data[DATA_KEY_COORDINATOR],
|
||||
name,
|
||||
entry.entry_id,
|
||||
)
|
||||
]
|
||||
async_add_entities(switches, True)
|
||||
|
||||
# register service
|
||||
platform = entity_platform.current_platform.get()
|
||||
platform.async_register_entity_service(
|
||||
SERVICE_DISABLE,
|
||||
{
|
||||
vol.Required(SERVICE_DISABLE_ATTR_DURATION): vol.All(
|
||||
cv.time_period_str, cv.positive_timedelta
|
||||
),
|
||||
},
|
||||
"async_disable",
|
||||
)
|
||||
|
||||
|
||||
class PiHoleSwitch(PiHoleEntity, SwitchEntity):
|
||||
"""Representation of a Pi-hole switch."""
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
"""Return the name of the switch."""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def unique_id(self):
|
||||
"""Return the unique id of the switch."""
|
||||
return f"{self._server_unique_id}/Switch"
|
||||
|
||||
@property
|
||||
def icon(self):
|
||||
"""Icon to use in the frontend, if any."""
|
||||
return "mdi:pi-hole"
|
||||
|
||||
@property
|
||||
def is_on(self):
|
||||
"""Return if the service is on."""
|
||||
return self.api.data.get("status") == "enabled"
|
||||
|
||||
async def async_turn_on(self, **kwargs):
|
||||
"""Turn on the service."""
|
||||
try:
|
||||
await self.api.enable()
|
||||
await self.async_update()
|
||||
except HoleError as err:
|
||||
_LOGGER.error("Unable to enable Pi-hole: %s", err)
|
||||
|
||||
async def async_turn_off(self, **kwargs):
|
||||
"""Turn off the service."""
|
||||
await self.async_disable()
|
||||
|
||||
async def async_disable(self, duration=None):
|
||||
"""Disable the service for a given duration."""
|
||||
duration_seconds = True # Disable infinitely by default
|
||||
if duration is not None:
|
||||
duration_seconds = duration.total_seconds()
|
||||
_LOGGER.debug(
|
||||
"Disabling Pi-hole '%s' (%s) for %d seconds",
|
||||
self.name,
|
||||
self.api.host,
|
||||
duration_seconds,
|
||||
)
|
||||
try:
|
||||
await self.api.disable(duration_seconds)
|
||||
await self.async_update()
|
||||
except HoleError as err:
|
||||
_LOGGER.error("Unable to disable Pi-hole: %s", err)
|
@ -21,7 +21,7 @@ ZERO_DATA = {
|
||||
"domains_being_blocked": 0,
|
||||
"queries_cached": 0,
|
||||
"queries_forwarded": 0,
|
||||
"status": 0,
|
||||
"status": "disabled",
|
||||
"unique_clients": 0,
|
||||
"unique_domains": 0,
|
||||
}
|
||||
@ -29,7 +29,7 @@ ZERO_DATA = {
|
||||
HOST = "1.2.3.4"
|
||||
PORT = 80
|
||||
LOCATION = "location"
|
||||
NAME = "name"
|
||||
NAME = "Pi hole"
|
||||
API_KEY = "apikey"
|
||||
SSL = False
|
||||
VERIFY_SSL = True
|
||||
@ -53,6 +53,8 @@ CONF_CONFIG_FLOW = {
|
||||
CONF_VERIFY_SSL: VERIFY_SSL,
|
||||
}
|
||||
|
||||
SWITCH_ENTITY_ID = "switch.pi_hole"
|
||||
|
||||
|
||||
def _create_mocked_hole(raise_exception=False):
|
||||
mocked_hole = MagicMock()
|
||||
@ -65,6 +67,10 @@ def _create_mocked_hole(raise_exception=False):
|
||||
return mocked_hole
|
||||
|
||||
|
||||
def _patch_init_hole(mocked_hole):
|
||||
return patch("homeassistant.components.pi_hole.Hole", return_value=mocked_hole)
|
||||
|
||||
|
||||
def _patch_config_flow_hole(mocked_hole):
|
||||
return patch(
|
||||
"homeassistant.components.pi_hole.config_flow.Hole", return_value=mocked_hole
|
||||
|
@ -1,5 +1,4 @@
|
||||
"""Test pi_hole config flow."""
|
||||
import copy
|
||||
import logging
|
||||
|
||||
from homeassistant.components.pi_hole.const import DOMAIN
|
||||
@ -13,7 +12,6 @@ from homeassistant.data_entry_flow import (
|
||||
from . import (
|
||||
CONF_CONFIG_FLOW,
|
||||
CONF_DATA,
|
||||
CONF_HOST,
|
||||
NAME,
|
||||
_create_mocked_hole,
|
||||
_patch_config_flow_hole,
|
||||
@ -54,16 +52,6 @@ async def test_flow_import(hass, caplog):
|
||||
assert result["type"] == RESULT_TYPE_ABORT
|
||||
assert result["reason"] == "already_configured"
|
||||
|
||||
# duplicated name
|
||||
conf_data = copy.deepcopy(CONF_DATA)
|
||||
conf_data[CONF_HOST] = "4.3.2.1"
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": SOURCE_IMPORT}, data=conf_data
|
||||
)
|
||||
assert result["type"] == RESULT_TYPE_ABORT
|
||||
assert result["reason"] == "duplicated_name"
|
||||
assert len([x for x in caplog.records if x.levelno == logging.ERROR]) == 1
|
||||
|
||||
|
||||
async def test_flow_import_invalid(hass, caplog):
|
||||
"""Test import flow with invalid server."""
|
||||
@ -103,15 +91,6 @@ async def test_flow_user(hass):
|
||||
assert result["type"] == RESULT_TYPE_ABORT
|
||||
assert result["reason"] == "already_configured"
|
||||
|
||||
# duplicated name
|
||||
conf_data = copy.deepcopy(CONF_CONFIG_FLOW)
|
||||
conf_data[CONF_HOST] = "4.3.2.1"
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": SOURCE_USER}, data=conf_data
|
||||
)
|
||||
assert result["type"] == RESULT_TYPE_ABORT
|
||||
assert result["reason"] == "duplicated_name"
|
||||
|
||||
|
||||
async def test_flow_user_invalid(hass):
|
||||
"""Test user initialized flow with invalid server."""
|
||||
|
@ -1,18 +1,36 @@
|
||||
"""Test pi_hole component."""
|
||||
import logging
|
||||
|
||||
from homeassistant.components import pi_hole
|
||||
from homeassistant.components.pi_hole.const import MIN_TIME_BETWEEN_UPDATES
|
||||
from hole.exceptions import HoleError
|
||||
|
||||
from homeassistant.components import pi_hole, switch
|
||||
from homeassistant.components.pi_hole.const import (
|
||||
CONF_LOCATION,
|
||||
DEFAULT_LOCATION,
|
||||
DEFAULT_NAME,
|
||||
DEFAULT_SSL,
|
||||
DEFAULT_VERIFY_SSL,
|
||||
SERVICE_DISABLE,
|
||||
SERVICE_DISABLE_ATTR_DURATION,
|
||||
)
|
||||
from homeassistant.const import (
|
||||
ATTR_ENTITY_ID,
|
||||
CONF_HOST,
|
||||
CONF_NAME,
|
||||
CONF_SSL,
|
||||
CONF_VERIFY_SSL,
|
||||
)
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from . import _create_mocked_hole, _patch_config_flow_hole
|
||||
from . import (
|
||||
SWITCH_ENTITY_ID,
|
||||
_create_mocked_hole,
|
||||
_patch_config_flow_hole,
|
||||
_patch_init_hole,
|
||||
)
|
||||
|
||||
from tests.async_mock import patch
|
||||
from tests.common import async_fire_time_changed
|
||||
|
||||
|
||||
def _patch_init_hole(mocked_hole):
|
||||
return patch("homeassistant.components.pi_hole.Hole", return_value=mocked_hole)
|
||||
from tests.async_mock import AsyncMock
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
async def test_setup_minimal_config(hass):
|
||||
@ -69,6 +87,9 @@ async def test_setup_minimal_config(hass):
|
||||
assert hass.states.get("sensor.pi_hole_domains_blocked").state == "0"
|
||||
assert hass.states.get("sensor.pi_hole_seen_clients").state == "0"
|
||||
|
||||
assert hass.states.get("binary_sensor.pi_hole").name == "Pi-Hole"
|
||||
assert hass.states.get("binary_sensor.pi_hole").state == "off"
|
||||
|
||||
|
||||
async def test_setup_name_config(hass):
|
||||
"""Tests component setup with a custom name."""
|
||||
@ -88,6 +109,54 @@ async def test_setup_name_config(hass):
|
||||
)
|
||||
|
||||
|
||||
async def test_switch(hass, caplog):
|
||||
"""Test Pi-hole switch."""
|
||||
mocked_hole = _create_mocked_hole()
|
||||
with _patch_config_flow_hole(mocked_hole), _patch_init_hole(mocked_hole):
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
pi_hole.DOMAIN,
|
||||
{pi_hole.DOMAIN: [{"host": "pi.hole1", "api_key": "1"}]},
|
||||
)
|
||||
|
||||
await hass.async_block_till_done()
|
||||
|
||||
await hass.services.async_call(
|
||||
switch.DOMAIN,
|
||||
switch.SERVICE_TURN_ON,
|
||||
{"entity_id": SWITCH_ENTITY_ID},
|
||||
blocking=True,
|
||||
)
|
||||
mocked_hole.enable.assert_called_once()
|
||||
|
||||
await hass.services.async_call(
|
||||
switch.DOMAIN,
|
||||
switch.SERVICE_TURN_OFF,
|
||||
{"entity_id": SWITCH_ENTITY_ID},
|
||||
blocking=True,
|
||||
)
|
||||
mocked_hole.disable.assert_called_once_with(True)
|
||||
|
||||
# Failed calls
|
||||
type(mocked_hole).enable = AsyncMock(side_effect=HoleError("Error1"))
|
||||
await hass.services.async_call(
|
||||
switch.DOMAIN,
|
||||
switch.SERVICE_TURN_ON,
|
||||
{"entity_id": SWITCH_ENTITY_ID},
|
||||
blocking=True,
|
||||
)
|
||||
type(mocked_hole).disable = AsyncMock(side_effect=HoleError("Error2"))
|
||||
await hass.services.async_call(
|
||||
switch.DOMAIN,
|
||||
switch.SERVICE_TURN_OFF,
|
||||
{"entity_id": SWITCH_ENTITY_ID},
|
||||
blocking=True,
|
||||
)
|
||||
errors = [x for x in caplog.records if x.levelno == logging.ERROR]
|
||||
assert errors[-2].message == "Unable to enable Pi-hole: Error1"
|
||||
assert errors[-1].message == "Unable to disable Pi-hole: Error2"
|
||||
|
||||
|
||||
async def test_disable_service_call(hass):
|
||||
"""Test disable service call with no Pi-hole named."""
|
||||
mocked_hole = _create_mocked_hole()
|
||||
@ -98,7 +167,7 @@ async def test_disable_service_call(hass):
|
||||
{
|
||||
pi_hole.DOMAIN: [
|
||||
{"host": "pi.hole1", "api_key": "1"},
|
||||
{"host": "pi.hole2", "name": "Custom", "api_key": "2"},
|
||||
{"host": "pi.hole2", "name": "Custom"},
|
||||
]
|
||||
},
|
||||
)
|
||||
@ -107,57 +176,35 @@ async def test_disable_service_call(hass):
|
||||
|
||||
await hass.services.async_call(
|
||||
pi_hole.DOMAIN,
|
||||
pi_hole.SERVICE_DISABLE,
|
||||
{pi_hole.SERVICE_DISABLE_ATTR_DURATION: "00:00:01"},
|
||||
SERVICE_DISABLE,
|
||||
{ATTR_ENTITY_ID: "all", SERVICE_DISABLE_ATTR_DURATION: "00:00:01"},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert mocked_hole.disable.call_count == 2
|
||||
mocked_hole.disable.assert_called_once_with(1)
|
||||
|
||||
|
||||
async def test_enable_service_call(hass):
|
||||
"""Test enable service call with no Pi-hole named."""
|
||||
async def test_unload(hass):
|
||||
"""Test unload entities."""
|
||||
entry = MockConfigEntry(
|
||||
domain=pi_hole.DOMAIN,
|
||||
data={
|
||||
CONF_NAME: DEFAULT_NAME,
|
||||
CONF_HOST: "pi.hole",
|
||||
CONF_LOCATION: DEFAULT_LOCATION,
|
||||
CONF_SSL: DEFAULT_SSL,
|
||||
CONF_VERIFY_SSL: DEFAULT_VERIFY_SSL,
|
||||
},
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
mocked_hole = _create_mocked_hole()
|
||||
with _patch_config_flow_hole(mocked_hole), _patch_init_hole(mocked_hole):
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
pi_hole.DOMAIN,
|
||||
{
|
||||
pi_hole.DOMAIN: [
|
||||
{"host": "pi.hole1", "api_key": "1"},
|
||||
{"host": "pi.hole2", "name": "Custom", "api_key": "2"},
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
await hass.config_entries.async_setup(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
assert entry.entry_id in hass.data[pi_hole.DOMAIN]
|
||||
|
||||
await hass.services.async_call(
|
||||
pi_hole.DOMAIN, pi_hole.SERVICE_ENABLE, {}, blocking=True
|
||||
)
|
||||
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert mocked_hole.enable.call_count == 2
|
||||
|
||||
|
||||
async def test_update_coordinator(hass):
|
||||
"""Test update coordinator."""
|
||||
mocked_hole = _create_mocked_hole()
|
||||
sensor_entity_id = "sensor.pi_hole_ads_blocked_today"
|
||||
with _patch_config_flow_hole(mocked_hole), _patch_init_hole(mocked_hole):
|
||||
assert await async_setup_component(
|
||||
hass, pi_hole.DOMAIN, {pi_hole.DOMAIN: [{"host": "pi.hole"}]}
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert mocked_hole.get_data.call_count == 3
|
||||
assert hass.states.get(sensor_entity_id).state == "0"
|
||||
|
||||
mocked_hole.data["ads_blocked_today"] = 1
|
||||
utcnow = dt_util.utcnow()
|
||||
async_fire_time_changed(hass, utcnow + MIN_TIME_BETWEEN_UPDATES)
|
||||
await hass.async_block_till_done()
|
||||
assert mocked_hole.get_data.call_count == 4
|
||||
assert hass.states.get(sensor_entity_id).state == "1"
|
||||
assert await hass.config_entries.async_unload(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
assert entry.entry_id not in hass.data[pi_hole.DOMAIN]
|
||||
|
Loading…
x
Reference in New Issue
Block a user