Add NextDNS switch platform (#74512)

* Add switch platform

* Use lambda to get state

* Use async with timeout

* Add tests

* Use correct type

* Use Generic for coordinator

* Use TCoordinatorData

* Cleanup generic

* Simplify coordinator data update methods

* Use new entity naming style

* Remove unnecessary code

* Only the first word should be capitalised

* Suggested change

* improve typing in tests

* Improve typing intests

* Update tests/components/nextdns/__init__.py

* black

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>
This commit is contained in:
Maciej Bieniek 2022-07-11 16:00:13 +02:00 committed by GitHub
parent 8820ce0bdd
commit c1a4dc2f22
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 619 additions and 32 deletions

View File

@ -17,6 +17,7 @@ from nextdns import (
ApiError, ApiError,
InvalidApiKeyError, InvalidApiKeyError,
NextDns, NextDns,
Settings,
) )
from nextdns.model import NextDnsData from nextdns.model import NextDnsData
@ -34,10 +35,12 @@ from .const import (
ATTR_ENCRYPTION, ATTR_ENCRYPTION,
ATTR_IP_VERSIONS, ATTR_IP_VERSIONS,
ATTR_PROTOCOLS, ATTR_PROTOCOLS,
ATTR_SETTINGS,
ATTR_STATUS, ATTR_STATUS,
CONF_PROFILE_ID, CONF_PROFILE_ID,
DOMAIN, DOMAIN,
UPDATE_INTERVAL_ANALYTICS, UPDATE_INTERVAL_ANALYTICS,
UPDATE_INTERVAL_SETTINGS,
) )
TCoordinatorData = TypeVar("TCoordinatorData", bound=NextDnsData) TCoordinatorData = TypeVar("TCoordinatorData", bound=NextDnsData)
@ -68,6 +71,14 @@ class NextDnsUpdateCoordinator(DataUpdateCoordinator[TCoordinatorData]):
super().__init__(hass, _LOGGER, name=DOMAIN, update_interval=update_interval) super().__init__(hass, _LOGGER, name=DOMAIN, update_interval=update_interval)
async def _async_update_data(self) -> TCoordinatorData: async def _async_update_data(self) -> TCoordinatorData:
"""Update data via internal method."""
try:
async with timeout(10):
return await self._async_update_data_internal()
except (ApiError, ClientConnectorError, InvalidApiKeyError) as err:
raise UpdateFailed(err) from err
async def _async_update_data_internal(self) -> TCoordinatorData:
"""Update data via library.""" """Update data via library."""
raise NotImplementedError("Update method not implemented") raise NotImplementedError("Update method not implemented")
@ -75,71 +86,60 @@ class NextDnsUpdateCoordinator(DataUpdateCoordinator[TCoordinatorData]):
class NextDnsStatusUpdateCoordinator(NextDnsUpdateCoordinator[AnalyticsStatus]): class NextDnsStatusUpdateCoordinator(NextDnsUpdateCoordinator[AnalyticsStatus]):
"""Class to manage fetching NextDNS analytics status data from API.""" """Class to manage fetching NextDNS analytics status data from API."""
async def _async_update_data(self) -> AnalyticsStatus: async def _async_update_data_internal(self) -> AnalyticsStatus:
"""Update data via library.""" """Update data via library."""
try:
async with timeout(10):
return await self.nextdns.get_analytics_status(self.profile_id) return await self.nextdns.get_analytics_status(self.profile_id)
except (ApiError, ClientConnectorError, InvalidApiKeyError) as err:
raise UpdateFailed(err) from err
class NextDnsDnssecUpdateCoordinator(NextDnsUpdateCoordinator[AnalyticsDnssec]): class NextDnsDnssecUpdateCoordinator(NextDnsUpdateCoordinator[AnalyticsDnssec]):
"""Class to manage fetching NextDNS analytics Dnssec data from API.""" """Class to manage fetching NextDNS analytics Dnssec data from API."""
async def _async_update_data(self) -> AnalyticsDnssec: async def _async_update_data_internal(self) -> AnalyticsDnssec:
"""Update data via library.""" """Update data via library."""
try:
async with timeout(10):
return await self.nextdns.get_analytics_dnssec(self.profile_id) return await self.nextdns.get_analytics_dnssec(self.profile_id)
except (ApiError, ClientConnectorError, InvalidApiKeyError) as err:
raise UpdateFailed(err) from err
class NextDnsEncryptionUpdateCoordinator(NextDnsUpdateCoordinator[AnalyticsEncryption]): class NextDnsEncryptionUpdateCoordinator(NextDnsUpdateCoordinator[AnalyticsEncryption]):
"""Class to manage fetching NextDNS analytics encryption data from API.""" """Class to manage fetching NextDNS analytics encryption data from API."""
async def _async_update_data(self) -> AnalyticsEncryption: async def _async_update_data_internal(self) -> AnalyticsEncryption:
"""Update data via library.""" """Update data via library."""
try:
async with timeout(10):
return await self.nextdns.get_analytics_encryption(self.profile_id) return await self.nextdns.get_analytics_encryption(self.profile_id)
except (ApiError, ClientConnectorError, InvalidApiKeyError) as err:
raise UpdateFailed(err) from err
class NextDnsIpVersionsUpdateCoordinator(NextDnsUpdateCoordinator[AnalyticsIpVersions]): class NextDnsIpVersionsUpdateCoordinator(NextDnsUpdateCoordinator[AnalyticsIpVersions]):
"""Class to manage fetching NextDNS analytics IP versions data from API.""" """Class to manage fetching NextDNS analytics IP versions data from API."""
async def _async_update_data(self) -> AnalyticsIpVersions: async def _async_update_data_internal(self) -> AnalyticsIpVersions:
"""Update data via library.""" """Update data via library."""
try:
async with timeout(10):
return await self.nextdns.get_analytics_ip_versions(self.profile_id) return await self.nextdns.get_analytics_ip_versions(self.profile_id)
except (ApiError, ClientConnectorError, InvalidApiKeyError) as err:
raise UpdateFailed(err) from err
class NextDnsProtocolsUpdateCoordinator(NextDnsUpdateCoordinator[AnalyticsProtocols]): class NextDnsProtocolsUpdateCoordinator(NextDnsUpdateCoordinator[AnalyticsProtocols]):
"""Class to manage fetching NextDNS analytics protocols data from API.""" """Class to manage fetching NextDNS analytics protocols data from API."""
async def _async_update_data(self) -> AnalyticsProtocols: async def _async_update_data_internal(self) -> AnalyticsProtocols:
"""Update data via library.""" """Update data via library."""
try:
async with timeout(10):
return await self.nextdns.get_analytics_protocols(self.profile_id) return await self.nextdns.get_analytics_protocols(self.profile_id)
except (ApiError, ClientConnectorError, InvalidApiKeyError) as err:
raise UpdateFailed(err) from err
class NextDnsSettingsUpdateCoordinator(NextDnsUpdateCoordinator[Settings]):
"""Class to manage fetching NextDNS connection data from API."""
async def _async_update_data_internal(self) -> Settings:
"""Update data via library."""
return await self.nextdns.get_settings(self.profile_id)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
PLATFORMS = [Platform.BUTTON, Platform.SENSOR] PLATFORMS = [Platform.BUTTON, Platform.SENSOR, Platform.SWITCH]
COORDINATORS = [ COORDINATORS = [
(ATTR_DNSSEC, NextDnsDnssecUpdateCoordinator, UPDATE_INTERVAL_ANALYTICS), (ATTR_DNSSEC, NextDnsDnssecUpdateCoordinator, UPDATE_INTERVAL_ANALYTICS),
(ATTR_ENCRYPTION, NextDnsEncryptionUpdateCoordinator, UPDATE_INTERVAL_ANALYTICS), (ATTR_ENCRYPTION, NextDnsEncryptionUpdateCoordinator, UPDATE_INTERVAL_ANALYTICS),
(ATTR_IP_VERSIONS, NextDnsIpVersionsUpdateCoordinator, UPDATE_INTERVAL_ANALYTICS), (ATTR_IP_VERSIONS, NextDnsIpVersionsUpdateCoordinator, UPDATE_INTERVAL_ANALYTICS),
(ATTR_PROTOCOLS, NextDnsProtocolsUpdateCoordinator, UPDATE_INTERVAL_ANALYTICS), (ATTR_PROTOCOLS, NextDnsProtocolsUpdateCoordinator, UPDATE_INTERVAL_ANALYTICS),
(ATTR_SETTINGS, NextDnsSettingsUpdateCoordinator, UPDATE_INTERVAL_SETTINGS),
(ATTR_STATUS, NextDnsStatusUpdateCoordinator, UPDATE_INTERVAL_ANALYTICS), (ATTR_STATUS, NextDnsStatusUpdateCoordinator, UPDATE_INTERVAL_ANALYTICS),
] ]

View File

@ -5,11 +5,13 @@ ATTR_DNSSEC = "dnssec"
ATTR_ENCRYPTION = "encryption" ATTR_ENCRYPTION = "encryption"
ATTR_IP_VERSIONS = "ip_versions" ATTR_IP_VERSIONS = "ip_versions"
ATTR_PROTOCOLS = "protocols" ATTR_PROTOCOLS = "protocols"
ATTR_SETTINGS = "settings"
ATTR_STATUS = "status" ATTR_STATUS = "status"
CONF_PROFILE_ID = "profile_id" CONF_PROFILE_ID = "profile_id"
CONF_PROFILE_NAME = "profile_name" CONF_PROFILE_NAME = "profile_name"
UPDATE_INTERVAL_ANALYTICS = timedelta(minutes=10) UPDATE_INTERVAL_ANALYTICS = timedelta(minutes=10)
UPDATE_INTERVAL_SETTINGS = timedelta(minutes=1)
DOMAIN = "nextdns" DOMAIN = "nextdns"

View File

@ -0,0 +1,247 @@
"""Support for the NextDNS service."""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Generic
from nextdns import Settings
from homeassistant.components.switch import SwitchEntity, SwitchEntityDescription
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity import EntityCategory
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.update_coordinator import CoordinatorEntity
from . import NextDnsSettingsUpdateCoordinator, TCoordinatorData
from .const import ATTR_SETTINGS, DOMAIN
PARALLEL_UPDATES = 1
@dataclass
class NextDnsSwitchRequiredKeysMixin(Generic[TCoordinatorData]):
"""Class for NextDNS entity required keys."""
state: Callable[[TCoordinatorData], bool]
@dataclass
class NextDnsSwitchEntityDescription(
SwitchEntityDescription, NextDnsSwitchRequiredKeysMixin[TCoordinatorData]
):
"""NextDNS switch entity description."""
SWITCHES = (
NextDnsSwitchEntityDescription[Settings](
key="block_page",
name="Block page",
entity_category=EntityCategory.CONFIG,
icon="mdi:web-cancel",
state=lambda data: data.block_page,
),
NextDnsSwitchEntityDescription[Settings](
key="cache_boost",
name="Cache boost",
entity_category=EntityCategory.CONFIG,
icon="mdi:memory",
state=lambda data: data.cache_boost,
),
NextDnsSwitchEntityDescription[Settings](
key="cname_flattening",
name="CNAME flattening",
entity_category=EntityCategory.CONFIG,
icon="mdi:tournament",
state=lambda data: data.cname_flattening,
),
NextDnsSwitchEntityDescription[Settings](
key="anonymized_ecs",
name="Anonymized EDNS client subnet",
entity_category=EntityCategory.CONFIG,
icon="mdi:incognito",
state=lambda data: data.anonymized_ecs,
),
NextDnsSwitchEntityDescription[Settings](
key="logs",
name="Logs",
entity_category=EntityCategory.CONFIG,
icon="mdi:file-document-outline",
state=lambda data: data.logs,
),
NextDnsSwitchEntityDescription[Settings](
key="web3",
name="Web3",
entity_category=EntityCategory.CONFIG,
icon="mdi:web",
state=lambda data: data.web3,
),
NextDnsSwitchEntityDescription[Settings](
key="allow_affiliate",
name="Allow affiliate & tracking links",
entity_category=EntityCategory.CONFIG,
state=lambda data: data.allow_affiliate,
),
NextDnsSwitchEntityDescription[Settings](
key="block_disguised_trackers",
name="Block disguised third-party trackers",
entity_category=EntityCategory.CONFIG,
state=lambda data: data.block_disguised_trackers,
),
NextDnsSwitchEntityDescription[Settings](
key="ai_threat_detection",
name="AI-Driven threat detection",
entity_category=EntityCategory.CONFIG,
state=lambda data: data.ai_threat_detection,
),
NextDnsSwitchEntityDescription[Settings](
key="block_csam",
name="Block child sexual abuse material",
entity_category=EntityCategory.CONFIG,
state=lambda data: data.block_csam,
),
NextDnsSwitchEntityDescription[Settings](
key="block_ddns",
name="Block dynamic DNS hostnames",
entity_category=EntityCategory.CONFIG,
state=lambda data: data.block_ddns,
),
NextDnsSwitchEntityDescription[Settings](
key="block_nrd",
name="Block newly registered domains",
entity_category=EntityCategory.CONFIG,
state=lambda data: data.block_nrd,
),
NextDnsSwitchEntityDescription[Settings](
key="block_parked_domains",
name="Block parked domains",
entity_category=EntityCategory.CONFIG,
state=lambda data: data.block_parked_domains,
),
NextDnsSwitchEntityDescription[Settings](
key="cryptojacking_protection",
name="Cryptojacking protection",
entity_category=EntityCategory.CONFIG,
state=lambda data: data.cryptojacking_protection,
),
NextDnsSwitchEntityDescription[Settings](
key="dga_protection",
name="Domain generation algorithms protection",
entity_category=EntityCategory.CONFIG,
state=lambda data: data.dga_protection,
),
NextDnsSwitchEntityDescription[Settings](
key="dns_rebinding_protection",
name="DNS rebinding protection",
entity_category=EntityCategory.CONFIG,
icon="mdi:dns",
state=lambda data: data.dns_rebinding_protection,
),
NextDnsSwitchEntityDescription[Settings](
key="google_safe_browsing",
name="Google safe browsing",
entity_category=EntityCategory.CONFIG,
icon="mdi:google",
state=lambda data: data.google_safe_browsing,
),
NextDnsSwitchEntityDescription[Settings](
key="idn_homograph_attacks_protection",
name="IDN homograph attacks protection",
entity_category=EntityCategory.CONFIG,
state=lambda data: data.idn_homograph_attacks_protection,
),
NextDnsSwitchEntityDescription[Settings](
key="threat_intelligence_feeds",
name="Threat intelligence feeds",
entity_category=EntityCategory.CONFIG,
state=lambda data: data.threat_intelligence_feeds,
),
NextDnsSwitchEntityDescription[Settings](
key="typosquatting_protection",
name="Typosquatting protection",
entity_category=EntityCategory.CONFIG,
icon="mdi:keyboard-outline",
state=lambda data: data.typosquatting_protection,
),
NextDnsSwitchEntityDescription[Settings](
key="block_bypass_methods",
name="Block bypass methods",
entity_category=EntityCategory.CONFIG,
state=lambda data: data.block_bypass_methods,
),
NextDnsSwitchEntityDescription[Settings](
key="safesearch",
name="Force SafeSearch",
entity_category=EntityCategory.CONFIG,
icon="mdi:search-web",
state=lambda data: data.safesearch,
),
NextDnsSwitchEntityDescription[Settings](
key="youtube_restricted_mode",
name="Force YouTube restricted mode",
entity_category=EntityCategory.CONFIG,
icon="mdi:youtube",
state=lambda data: data.youtube_restricted_mode,
),
)
async def async_setup_entry(
hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback
) -> None:
"""Add NextDNS entities from a config_entry."""
coordinator: NextDnsSettingsUpdateCoordinator = hass.data[DOMAIN][entry.entry_id][
ATTR_SETTINGS
]
switches: list[NextDnsSwitch] = []
for description in SWITCHES:
switches.append(NextDnsSwitch(coordinator, description))
async_add_entities(switches)
class NextDnsSwitch(CoordinatorEntity[NextDnsSettingsUpdateCoordinator], SwitchEntity):
"""Define an NextDNS switch."""
_attr_has_entity_name = True
entity_description: NextDnsSwitchEntityDescription
def __init__(
self,
coordinator: NextDnsSettingsUpdateCoordinator,
description: NextDnsSwitchEntityDescription,
) -> None:
"""Initialize."""
super().__init__(coordinator)
self._attr_device_info = coordinator.device_info
self._attr_unique_id = f"{coordinator.profile_id}_{description.key}"
self._attr_is_on = description.state(coordinator.data)
self.entity_description = description
@callback
def _handle_coordinator_update(self) -> None:
"""Handle updated data from the coordinator."""
self._attr_is_on = self.entity_description.state(self.coordinator.data)
self.async_write_ha_state()
async def async_turn_on(self, **kwargs: Any) -> None:
"""Turn on switch."""
result = await self.coordinator.nextdns.set_setting(
self.coordinator.profile_id, self.entity_description.key, True
)
if result:
self._attr_is_on = True
self.async_write_ha_state()
async def async_turn_off(self, **kwargs: Any) -> None:
"""Turn off switch."""
result = await self.coordinator.nextdns.set_setting(
self.coordinator.profile_id, self.entity_description.key, False
)
if result:
self._attr_is_on = False
self.async_write_ha_state()

View File

@ -7,10 +7,12 @@ from nextdns import (
AnalyticsIpVersions, AnalyticsIpVersions,
AnalyticsProtocols, AnalyticsProtocols,
AnalyticsStatus, AnalyticsStatus,
Settings,
) )
from homeassistant.components.nextdns.const import CONF_PROFILE_ID, DOMAIN from homeassistant.components.nextdns.const import CONF_PROFILE_ID, DOMAIN
from homeassistant.const import CONF_API_KEY from homeassistant.const import CONF_API_KEY
from homeassistant.core import HomeAssistant
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -24,9 +26,36 @@ IP_VERSIONS = AnalyticsIpVersions(ipv4_queries=90, ipv6_queries=10)
PROTOCOLS = AnalyticsProtocols( PROTOCOLS = AnalyticsProtocols(
doh_queries=20, doq_queries=10, dot_queries=30, udp_queries=40 doh_queries=20, doq_queries=10, dot_queries=30, udp_queries=40
) )
SETTINGS = Settings(
ai_threat_detection=True,
allow_affiliate=True,
anonymized_ecs=True,
block_bypass_methods=True,
block_csam=True,
block_ddns=True,
block_disguised_trackers=True,
block_nrd=True,
block_page=False,
block_parked_domains=True,
cache_boost=True,
cname_flattening=True,
cryptojacking_protection=True,
dga_protection=True,
dns_rebinding_protection=True,
google_safe_browsing=False,
idn_homograph_attacks_protection=True,
logs=True,
safesearch=False,
threat_intelligence_feeds=True,
typosquatting_protection=True,
web3=True,
youtube_restricted_mode=False,
)
async def init_integration(hass, add_to_hass=True) -> MockConfigEntry: async def init_integration(
hass: HomeAssistant, add_to_hass: bool = True
) -> MockConfigEntry:
"""Set up the NextDNS integration in Home Assistant.""" """Set up the NextDNS integration in Home Assistant."""
entry = MockConfigEntry( entry = MockConfigEntry(
domain=DOMAIN, domain=DOMAIN,
@ -55,6 +84,9 @@ async def init_integration(hass, add_to_hass=True) -> MockConfigEntry:
), patch( ), patch(
"homeassistant.components.nextdns.NextDns.get_analytics_protocols", "homeassistant.components.nextdns.NextDns.get_analytics_protocols",
return_value=PROTOCOLS, return_value=PROTOCOLS,
), patch(
"homeassistant.components.nextdns.NextDns.get_settings",
return_value=SETTINGS,
): ):
entry.add_to_hass(hass) entry.add_to_hass(hass)
await hass.config_entries.async_setup(entry.entry_id) await hass.config_entries.async_setup(entry.entry_id)

View File

@ -0,0 +1,306 @@
"""Test switch of NextDNS integration."""
from datetime import timedelta
from unittest.mock import patch
from nextdns import ApiError
from homeassistant.components.switch import DOMAIN as SWITCH_DOMAIN
from homeassistant.const import (
ATTR_ENTITY_ID,
SERVICE_TURN_OFF,
SERVICE_TURN_ON,
STATE_OFF,
STATE_ON,
STATE_UNAVAILABLE,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
from homeassistant.util.dt import utcnow
from . import SETTINGS, init_integration
from tests.common import async_fire_time_changed
async def test_switch(hass: HomeAssistant) -> None:
"""Test states of the switches."""
registry = er.async_get(hass)
await init_integration(hass)
state = hass.states.get("switch.fake_profile_ai_driven_threat_detection")
assert state
assert state.state == STATE_ON
entry = registry.async_get("switch.fake_profile_ai_driven_threat_detection")
assert entry
assert entry.unique_id == "xyz12_ai_threat_detection"
state = hass.states.get("switch.fake_profile_allow_affiliate_tracking_links")
assert state
assert state.state == STATE_ON
entry = registry.async_get("switch.fake_profile_allow_affiliate_tracking_links")
assert entry
assert entry.unique_id == "xyz12_allow_affiliate"
state = hass.states.get("switch.fake_profile_anonymized_edns_client_subnet")
assert state
assert state.state == STATE_ON
entry = registry.async_get("switch.fake_profile_anonymized_edns_client_subnet")
assert entry
assert entry.unique_id == "xyz12_anonymized_ecs"
state = hass.states.get("switch.fake_profile_block_bypass_methods")
assert state
assert state.state == STATE_ON
entry = registry.async_get("switch.fake_profile_block_bypass_methods")
assert entry
assert entry.unique_id == "xyz12_block_bypass_methods"
state = hass.states.get("switch.fake_profile_block_child_sexual_abuse_material")
assert state
assert state.state == STATE_ON
entry = registry.async_get("switch.fake_profile_block_child_sexual_abuse_material")
assert entry
assert entry.unique_id == "xyz12_block_csam"
state = hass.states.get("switch.fake_profile_block_disguised_third_party_trackers")
assert state
assert state.state == STATE_ON
entry = registry.async_get(
"switch.fake_profile_block_disguised_third_party_trackers"
)
assert entry
assert entry.unique_id == "xyz12_block_disguised_trackers"
state = hass.states.get("switch.fake_profile_block_dynamic_dns_hostnames")
assert state
assert state.state == STATE_ON
entry = registry.async_get("switch.fake_profile_block_dynamic_dns_hostnames")
assert entry
assert entry.unique_id == "xyz12_block_ddns"
state = hass.states.get("switch.fake_profile_block_newly_registered_domains")
assert state
assert state.state == STATE_ON
entry = registry.async_get("switch.fake_profile_block_newly_registered_domains")
assert entry
assert entry.unique_id == "xyz12_block_nrd"
state = hass.states.get("switch.fake_profile_block_page")
assert state
assert state.state == STATE_OFF
entry = registry.async_get("switch.fake_profile_block_page")
assert entry
assert entry.unique_id == "xyz12_block_page"
state = hass.states.get("switch.fake_profile_block_parked_domains")
assert state
assert state.state == STATE_ON
entry = registry.async_get("switch.fake_profile_block_parked_domains")
assert entry
assert entry.unique_id == "xyz12_block_parked_domains"
state = hass.states.get("switch.fake_profile_cname_flattening")
assert state
assert state.state == STATE_ON
entry = registry.async_get("switch.fake_profile_cname_flattening")
assert entry
assert entry.unique_id == "xyz12_cname_flattening"
state = hass.states.get("switch.fake_profile_cache_boost")
assert state
assert state.state == STATE_ON
entry = registry.async_get("switch.fake_profile_cache_boost")
assert entry
assert entry.unique_id == "xyz12_cache_boost"
state = hass.states.get("switch.fake_profile_cryptojacking_protection")
assert state
assert state.state == STATE_ON
entry = registry.async_get("switch.fake_profile_cryptojacking_protection")
assert entry
assert entry.unique_id == "xyz12_cryptojacking_protection"
state = hass.states.get("switch.fake_profile_dns_rebinding_protection")
assert state
assert state.state == STATE_ON
entry = registry.async_get("switch.fake_profile_dns_rebinding_protection")
assert entry
assert entry.unique_id == "xyz12_dns_rebinding_protection"
state = hass.states.get(
"switch.fake_profile_domain_generation_algorithms_protection"
)
assert state
assert state.state == STATE_ON
entry = registry.async_get(
"switch.fake_profile_domain_generation_algorithms_protection"
)
assert entry
assert entry.unique_id == "xyz12_dga_protection"
state = hass.states.get("switch.fake_profile_force_safesearch")
assert state
assert state.state == STATE_OFF
entry = registry.async_get("switch.fake_profile_force_safesearch")
assert entry
assert entry.unique_id == "xyz12_safesearch"
state = hass.states.get("switch.fake_profile_force_youtube_restricted_mode")
assert state
assert state.state == STATE_OFF
entry = registry.async_get("switch.fake_profile_force_youtube_restricted_mode")
assert entry
assert entry.unique_id == "xyz12_youtube_restricted_mode"
state = hass.states.get("switch.fake_profile_google_safe_browsing")
assert state
assert state.state == STATE_OFF
entry = registry.async_get("switch.fake_profile_google_safe_browsing")
assert entry
assert entry.unique_id == "xyz12_google_safe_browsing"
state = hass.states.get("switch.fake_profile_idn_homograph_attacks_protection")
assert state
assert state.state == STATE_ON
entry = registry.async_get("switch.fake_profile_idn_homograph_attacks_protection")
assert entry
assert entry.unique_id == "xyz12_idn_homograph_attacks_protection"
state = hass.states.get("switch.fake_profile_logs")
assert state
assert state.state == STATE_ON
entry = registry.async_get("switch.fake_profile_logs")
assert entry
assert entry.unique_id == "xyz12_logs"
state = hass.states.get("switch.fake_profile_threat_intelligence_feeds")
assert state
assert state.state == STATE_ON
entry = registry.async_get("switch.fake_profile_threat_intelligence_feeds")
assert entry
assert entry.unique_id == "xyz12_threat_intelligence_feeds"
state = hass.states.get("switch.fake_profile_typosquatting_protection")
assert state
assert state.state == STATE_ON
entry = registry.async_get("switch.fake_profile_typosquatting_protection")
assert entry
assert entry.unique_id == "xyz12_typosquatting_protection"
state = hass.states.get("switch.fake_profile_web3")
assert state
assert state.state == STATE_ON
entry = registry.async_get("switch.fake_profile_web3")
assert entry
assert entry.unique_id == "xyz12_web3"
async def test_switch_on(hass: HomeAssistant) -> None:
"""Test the switch can be turned on."""
await init_integration(hass)
state = hass.states.get("switch.fake_profile_block_page")
assert state
assert state.state == STATE_OFF
with patch(
"homeassistant.components.nextdns.NextDns.set_setting", return_value=True
) as mock_switch_on:
assert await hass.services.async_call(
SWITCH_DOMAIN,
SERVICE_TURN_ON,
{ATTR_ENTITY_ID: "switch.fake_profile_block_page"},
blocking=True,
)
await hass.async_block_till_done()
state = hass.states.get("switch.fake_profile_block_page")
assert state
assert state.state == STATE_ON
mock_switch_on.assert_called_once()
async def test_switch_off(hass: HomeAssistant) -> None:
"""Test the switch can be turned on."""
await init_integration(hass)
state = hass.states.get("switch.fake_profile_web3")
assert state
assert state.state == STATE_ON
with patch(
"homeassistant.components.nextdns.NextDns.set_setting", return_value=True
) as mock_switch_on:
assert await hass.services.async_call(
SWITCH_DOMAIN,
SERVICE_TURN_OFF,
{ATTR_ENTITY_ID: "switch.fake_profile_web3"},
blocking=True,
)
await hass.async_block_till_done()
state = hass.states.get("switch.fake_profile_web3")
assert state
assert state.state == STATE_OFF
mock_switch_on.assert_called_once()
async def test_availability(hass: HomeAssistant) -> None:
"""Ensure that we mark the entities unavailable correctly when service causes an error."""
await init_integration(hass)
state = hass.states.get("switch.fake_profile_web3")
assert state
assert state.state != STATE_UNAVAILABLE
assert state.state == STATE_ON
future = utcnow() + timedelta(minutes=10)
with patch(
"homeassistant.components.nextdns.NextDns.get_settings",
side_effect=ApiError("API Error"),
):
async_fire_time_changed(hass, future)
await hass.async_block_till_done()
state = hass.states.get("switch.fake_profile_web3")
assert state
assert state.state == STATE_UNAVAILABLE
future = utcnow() + timedelta(minutes=20)
with patch(
"homeassistant.components.nextdns.NextDns.get_settings",
return_value=SETTINGS,
):
async_fire_time_changed(hass, future)
await hass.async_block_till_done()
state = hass.states.get("switch.fake_profile_web3")
assert state
assert state.state != STATE_UNAVAILABLE
assert state.state == STATE_ON