From c1a4dc2f229179243acd00c2231e3070e92db2af Mon Sep 17 00:00:00 2001 From: Maciej Bieniek Date: Mon, 11 Jul 2022 16:00:13 +0200 Subject: [PATCH] Add NextDNS switch platform (#74512) * Add switch platform * Use lambda to get state * Use async with timeout * Add tests * Use correct type * Use Generic for coordinator * Use TCoordinatorData * Cleanup generic * Simplify coordinator data update methods * Use new entity naming style * Remove unnecessary code * Only the first word should be capitalised * Suggested change * improve typing in tests * Improve typing intests * Update tests/components/nextdns/__init__.py * black Co-authored-by: epenet <6771947+epenet@users.noreply.github.com> --- homeassistant/components/nextdns/__init__.py | 62 ++-- homeassistant/components/nextdns/const.py | 2 + homeassistant/components/nextdns/switch.py | 247 +++++++++++++++ tests/components/nextdns/__init__.py | 34 ++- tests/components/nextdns/test_switch.py | 306 +++++++++++++++++++ 5 files changed, 619 insertions(+), 32 deletions(-) create mode 100644 homeassistant/components/nextdns/switch.py create mode 100644 tests/components/nextdns/test_switch.py diff --git a/homeassistant/components/nextdns/__init__.py b/homeassistant/components/nextdns/__init__.py index 7e7f5ff2dd8..2f68abee847 100644 --- a/homeassistant/components/nextdns/__init__.py +++ b/homeassistant/components/nextdns/__init__.py @@ -17,6 +17,7 @@ from nextdns import ( ApiError, InvalidApiKeyError, NextDns, + Settings, ) from nextdns.model import NextDnsData @@ -34,10 +35,12 @@ from .const import ( ATTR_ENCRYPTION, ATTR_IP_VERSIONS, ATTR_PROTOCOLS, + ATTR_SETTINGS, ATTR_STATUS, CONF_PROFILE_ID, DOMAIN, UPDATE_INTERVAL_ANALYTICS, + UPDATE_INTERVAL_SETTINGS, ) TCoordinatorData = TypeVar("TCoordinatorData", bound=NextDnsData) @@ -68,6 +71,14 @@ class NextDnsUpdateCoordinator(DataUpdateCoordinator[TCoordinatorData]): super().__init__(hass, _LOGGER, name=DOMAIN, update_interval=update_interval) async def _async_update_data(self) -> TCoordinatorData: + """Update data via internal method.""" + try: + async with timeout(10): + return await self._async_update_data_internal() + except (ApiError, ClientConnectorError, InvalidApiKeyError) as err: + raise UpdateFailed(err) from err + + async def _async_update_data_internal(self) -> TCoordinatorData: """Update data via library.""" raise NotImplementedError("Update method not implemented") @@ -75,71 +86,60 @@ class NextDnsUpdateCoordinator(DataUpdateCoordinator[TCoordinatorData]): class NextDnsStatusUpdateCoordinator(NextDnsUpdateCoordinator[AnalyticsStatus]): """Class to manage fetching NextDNS analytics status data from API.""" - async def _async_update_data(self) -> AnalyticsStatus: + async def _async_update_data_internal(self) -> AnalyticsStatus: """Update data via library.""" - try: - async with timeout(10): - return await self.nextdns.get_analytics_status(self.profile_id) - except (ApiError, ClientConnectorError, InvalidApiKeyError) as err: - raise UpdateFailed(err) from err + return await self.nextdns.get_analytics_status(self.profile_id) class NextDnsDnssecUpdateCoordinator(NextDnsUpdateCoordinator[AnalyticsDnssec]): """Class to manage fetching NextDNS analytics Dnssec data from API.""" - async def _async_update_data(self) -> AnalyticsDnssec: + async def _async_update_data_internal(self) -> AnalyticsDnssec: """Update data via library.""" - try: - async with timeout(10): - return await self.nextdns.get_analytics_dnssec(self.profile_id) - except (ApiError, ClientConnectorError, InvalidApiKeyError) as err: - raise UpdateFailed(err) from err + return await self.nextdns.get_analytics_dnssec(self.profile_id) class NextDnsEncryptionUpdateCoordinator(NextDnsUpdateCoordinator[AnalyticsEncryption]): """Class to manage fetching NextDNS analytics encryption data from API.""" - async def _async_update_data(self) -> AnalyticsEncryption: + async def _async_update_data_internal(self) -> AnalyticsEncryption: """Update data via library.""" - try: - async with timeout(10): - return await self.nextdns.get_analytics_encryption(self.profile_id) - except (ApiError, ClientConnectorError, InvalidApiKeyError) as err: - raise UpdateFailed(err) from err + return await self.nextdns.get_analytics_encryption(self.profile_id) class NextDnsIpVersionsUpdateCoordinator(NextDnsUpdateCoordinator[AnalyticsIpVersions]): """Class to manage fetching NextDNS analytics IP versions data from API.""" - async def _async_update_data(self) -> AnalyticsIpVersions: + async def _async_update_data_internal(self) -> AnalyticsIpVersions: """Update data via library.""" - try: - async with timeout(10): - return await self.nextdns.get_analytics_ip_versions(self.profile_id) - except (ApiError, ClientConnectorError, InvalidApiKeyError) as err: - raise UpdateFailed(err) from err + return await self.nextdns.get_analytics_ip_versions(self.profile_id) class NextDnsProtocolsUpdateCoordinator(NextDnsUpdateCoordinator[AnalyticsProtocols]): """Class to manage fetching NextDNS analytics protocols data from API.""" - async def _async_update_data(self) -> AnalyticsProtocols: + async def _async_update_data_internal(self) -> AnalyticsProtocols: """Update data via library.""" - try: - async with timeout(10): - return await self.nextdns.get_analytics_protocols(self.profile_id) - except (ApiError, ClientConnectorError, InvalidApiKeyError) as err: - raise UpdateFailed(err) from err + return await self.nextdns.get_analytics_protocols(self.profile_id) + + +class NextDnsSettingsUpdateCoordinator(NextDnsUpdateCoordinator[Settings]): + """Class to manage fetching NextDNS connection data from API.""" + + async def _async_update_data_internal(self) -> Settings: + """Update data via library.""" + return await self.nextdns.get_settings(self.profile_id) _LOGGER = logging.getLogger(__name__) -PLATFORMS = [Platform.BUTTON, Platform.SENSOR] +PLATFORMS = [Platform.BUTTON, Platform.SENSOR, Platform.SWITCH] COORDINATORS = [ (ATTR_DNSSEC, NextDnsDnssecUpdateCoordinator, UPDATE_INTERVAL_ANALYTICS), (ATTR_ENCRYPTION, NextDnsEncryptionUpdateCoordinator, UPDATE_INTERVAL_ANALYTICS), (ATTR_IP_VERSIONS, NextDnsIpVersionsUpdateCoordinator, UPDATE_INTERVAL_ANALYTICS), (ATTR_PROTOCOLS, NextDnsProtocolsUpdateCoordinator, UPDATE_INTERVAL_ANALYTICS), + (ATTR_SETTINGS, NextDnsSettingsUpdateCoordinator, UPDATE_INTERVAL_SETTINGS), (ATTR_STATUS, NextDnsStatusUpdateCoordinator, UPDATE_INTERVAL_ANALYTICS), ] diff --git a/homeassistant/components/nextdns/const.py b/homeassistant/components/nextdns/const.py index 04bab44354b..d455dd79635 100644 --- a/homeassistant/components/nextdns/const.py +++ b/homeassistant/components/nextdns/const.py @@ -5,11 +5,13 @@ ATTR_DNSSEC = "dnssec" ATTR_ENCRYPTION = "encryption" ATTR_IP_VERSIONS = "ip_versions" ATTR_PROTOCOLS = "protocols" +ATTR_SETTINGS = "settings" ATTR_STATUS = "status" CONF_PROFILE_ID = "profile_id" CONF_PROFILE_NAME = "profile_name" UPDATE_INTERVAL_ANALYTICS = timedelta(minutes=10) +UPDATE_INTERVAL_SETTINGS = timedelta(minutes=1) DOMAIN = "nextdns" diff --git a/homeassistant/components/nextdns/switch.py b/homeassistant/components/nextdns/switch.py new file mode 100644 index 00000000000..4bd3c14c20f --- /dev/null +++ b/homeassistant/components/nextdns/switch.py @@ -0,0 +1,247 @@ +"""Support for the NextDNS service.""" +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, Generic + +from nextdns import Settings + +from homeassistant.components.switch import SwitchEntity, SwitchEntityDescription +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers.entity import EntityCategory +from homeassistant.helpers.entity_platform import AddEntitiesCallback +from homeassistant.helpers.update_coordinator import CoordinatorEntity + +from . import NextDnsSettingsUpdateCoordinator, TCoordinatorData +from .const import ATTR_SETTINGS, DOMAIN + +PARALLEL_UPDATES = 1 + + +@dataclass +class NextDnsSwitchRequiredKeysMixin(Generic[TCoordinatorData]): + """Class for NextDNS entity required keys.""" + + state: Callable[[TCoordinatorData], bool] + + +@dataclass +class NextDnsSwitchEntityDescription( + SwitchEntityDescription, NextDnsSwitchRequiredKeysMixin[TCoordinatorData] +): + """NextDNS switch entity description.""" + + +SWITCHES = ( + NextDnsSwitchEntityDescription[Settings]( + key="block_page", + name="Block page", + entity_category=EntityCategory.CONFIG, + icon="mdi:web-cancel", + state=lambda data: data.block_page, + ), + NextDnsSwitchEntityDescription[Settings]( + key="cache_boost", + name="Cache boost", + entity_category=EntityCategory.CONFIG, + icon="mdi:memory", + state=lambda data: data.cache_boost, + ), + NextDnsSwitchEntityDescription[Settings]( + key="cname_flattening", + name="CNAME flattening", + entity_category=EntityCategory.CONFIG, + icon="mdi:tournament", + state=lambda data: data.cname_flattening, + ), + NextDnsSwitchEntityDescription[Settings]( + key="anonymized_ecs", + name="Anonymized EDNS client subnet", + entity_category=EntityCategory.CONFIG, + icon="mdi:incognito", + state=lambda data: data.anonymized_ecs, + ), + NextDnsSwitchEntityDescription[Settings]( + key="logs", + name="Logs", + entity_category=EntityCategory.CONFIG, + icon="mdi:file-document-outline", + state=lambda data: data.logs, + ), + NextDnsSwitchEntityDescription[Settings]( + key="web3", + name="Web3", + entity_category=EntityCategory.CONFIG, + icon="mdi:web", + state=lambda data: data.web3, + ), + NextDnsSwitchEntityDescription[Settings]( + key="allow_affiliate", + name="Allow affiliate & tracking links", + entity_category=EntityCategory.CONFIG, + state=lambda data: data.allow_affiliate, + ), + NextDnsSwitchEntityDescription[Settings]( + key="block_disguised_trackers", + name="Block disguised third-party trackers", + entity_category=EntityCategory.CONFIG, + state=lambda data: data.block_disguised_trackers, + ), + NextDnsSwitchEntityDescription[Settings]( + key="ai_threat_detection", + name="AI-Driven threat detection", + entity_category=EntityCategory.CONFIG, + state=lambda data: data.ai_threat_detection, + ), + NextDnsSwitchEntityDescription[Settings]( + key="block_csam", + name="Block child sexual abuse material", + entity_category=EntityCategory.CONFIG, + state=lambda data: data.block_csam, + ), + NextDnsSwitchEntityDescription[Settings]( + key="block_ddns", + name="Block dynamic DNS hostnames", + entity_category=EntityCategory.CONFIG, + state=lambda data: data.block_ddns, + ), + NextDnsSwitchEntityDescription[Settings]( + key="block_nrd", + name="Block newly registered domains", + entity_category=EntityCategory.CONFIG, + state=lambda data: data.block_nrd, + ), + NextDnsSwitchEntityDescription[Settings]( + key="block_parked_domains", + name="Block parked domains", + entity_category=EntityCategory.CONFIG, + state=lambda data: data.block_parked_domains, + ), + NextDnsSwitchEntityDescription[Settings]( + key="cryptojacking_protection", + name="Cryptojacking protection", + entity_category=EntityCategory.CONFIG, + state=lambda data: data.cryptojacking_protection, + ), + NextDnsSwitchEntityDescription[Settings]( + key="dga_protection", + name="Domain generation algorithms protection", + entity_category=EntityCategory.CONFIG, + state=lambda data: data.dga_protection, + ), + NextDnsSwitchEntityDescription[Settings]( + key="dns_rebinding_protection", + name="DNS rebinding protection", + entity_category=EntityCategory.CONFIG, + icon="mdi:dns", + state=lambda data: data.dns_rebinding_protection, + ), + NextDnsSwitchEntityDescription[Settings]( + key="google_safe_browsing", + name="Google safe browsing", + entity_category=EntityCategory.CONFIG, + icon="mdi:google", + state=lambda data: data.google_safe_browsing, + ), + NextDnsSwitchEntityDescription[Settings]( + key="idn_homograph_attacks_protection", + name="IDN homograph attacks protection", + entity_category=EntityCategory.CONFIG, + state=lambda data: data.idn_homograph_attacks_protection, + ), + NextDnsSwitchEntityDescription[Settings]( + key="threat_intelligence_feeds", + name="Threat intelligence feeds", + entity_category=EntityCategory.CONFIG, + state=lambda data: data.threat_intelligence_feeds, + ), + NextDnsSwitchEntityDescription[Settings]( + key="typosquatting_protection", + name="Typosquatting protection", + entity_category=EntityCategory.CONFIG, + icon="mdi:keyboard-outline", + state=lambda data: data.typosquatting_protection, + ), + NextDnsSwitchEntityDescription[Settings]( + key="block_bypass_methods", + name="Block bypass methods", + entity_category=EntityCategory.CONFIG, + state=lambda data: data.block_bypass_methods, + ), + NextDnsSwitchEntityDescription[Settings]( + key="safesearch", + name="Force SafeSearch", + entity_category=EntityCategory.CONFIG, + icon="mdi:search-web", + state=lambda data: data.safesearch, + ), + NextDnsSwitchEntityDescription[Settings]( + key="youtube_restricted_mode", + name="Force YouTube restricted mode", + entity_category=EntityCategory.CONFIG, + icon="mdi:youtube", + state=lambda data: data.youtube_restricted_mode, + ), +) + + +async def async_setup_entry( + hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback +) -> None: + """Add NextDNS entities from a config_entry.""" + coordinator: NextDnsSettingsUpdateCoordinator = hass.data[DOMAIN][entry.entry_id][ + ATTR_SETTINGS + ] + + switches: list[NextDnsSwitch] = [] + for description in SWITCHES: + switches.append(NextDnsSwitch(coordinator, description)) + + async_add_entities(switches) + + +class NextDnsSwitch(CoordinatorEntity[NextDnsSettingsUpdateCoordinator], SwitchEntity): + """Define an NextDNS switch.""" + + _attr_has_entity_name = True + entity_description: NextDnsSwitchEntityDescription + + def __init__( + self, + coordinator: NextDnsSettingsUpdateCoordinator, + description: NextDnsSwitchEntityDescription, + ) -> None: + """Initialize.""" + super().__init__(coordinator) + self._attr_device_info = coordinator.device_info + self._attr_unique_id = f"{coordinator.profile_id}_{description.key}" + self._attr_is_on = description.state(coordinator.data) + self.entity_description = description + + @callback + def _handle_coordinator_update(self) -> None: + """Handle updated data from the coordinator.""" + self._attr_is_on = self.entity_description.state(self.coordinator.data) + self.async_write_ha_state() + + async def async_turn_on(self, **kwargs: Any) -> None: + """Turn on switch.""" + result = await self.coordinator.nextdns.set_setting( + self.coordinator.profile_id, self.entity_description.key, True + ) + + if result: + self._attr_is_on = True + self.async_write_ha_state() + + async def async_turn_off(self, **kwargs: Any) -> None: + """Turn off switch.""" + result = await self.coordinator.nextdns.set_setting( + self.coordinator.profile_id, self.entity_description.key, False + ) + + if result: + self._attr_is_on = False + self.async_write_ha_state() diff --git a/tests/components/nextdns/__init__.py b/tests/components/nextdns/__init__.py index 24063d957d4..82c55f56bbb 100644 --- a/tests/components/nextdns/__init__.py +++ b/tests/components/nextdns/__init__.py @@ -7,10 +7,12 @@ from nextdns import ( AnalyticsIpVersions, AnalyticsProtocols, AnalyticsStatus, + Settings, ) from homeassistant.components.nextdns.const import CONF_PROFILE_ID, DOMAIN from homeassistant.const import CONF_API_KEY +from homeassistant.core import HomeAssistant from tests.common import MockConfigEntry @@ -24,9 +26,36 @@ IP_VERSIONS = AnalyticsIpVersions(ipv4_queries=90, ipv6_queries=10) PROTOCOLS = AnalyticsProtocols( doh_queries=20, doq_queries=10, dot_queries=30, udp_queries=40 ) +SETTINGS = Settings( + ai_threat_detection=True, + allow_affiliate=True, + anonymized_ecs=True, + block_bypass_methods=True, + block_csam=True, + block_ddns=True, + block_disguised_trackers=True, + block_nrd=True, + block_page=False, + block_parked_domains=True, + cache_boost=True, + cname_flattening=True, + cryptojacking_protection=True, + dga_protection=True, + dns_rebinding_protection=True, + google_safe_browsing=False, + idn_homograph_attacks_protection=True, + logs=True, + safesearch=False, + threat_intelligence_feeds=True, + typosquatting_protection=True, + web3=True, + youtube_restricted_mode=False, +) -async def init_integration(hass, add_to_hass=True) -> MockConfigEntry: +async def init_integration( + hass: HomeAssistant, add_to_hass: bool = True +) -> MockConfigEntry: """Set up the NextDNS integration in Home Assistant.""" entry = MockConfigEntry( domain=DOMAIN, @@ -55,6 +84,9 @@ async def init_integration(hass, add_to_hass=True) -> MockConfigEntry: ), patch( "homeassistant.components.nextdns.NextDns.get_analytics_protocols", return_value=PROTOCOLS, + ), patch( + "homeassistant.components.nextdns.NextDns.get_settings", + return_value=SETTINGS, ): entry.add_to_hass(hass) await hass.config_entries.async_setup(entry.entry_id) diff --git a/tests/components/nextdns/test_switch.py b/tests/components/nextdns/test_switch.py new file mode 100644 index 00000000000..3e07a2633d1 --- /dev/null +++ b/tests/components/nextdns/test_switch.py @@ -0,0 +1,306 @@ +"""Test switch of NextDNS integration.""" +from datetime import timedelta +from unittest.mock import patch + +from nextdns import ApiError + +from homeassistant.components.switch import DOMAIN as SWITCH_DOMAIN +from homeassistant.const import ( + ATTR_ENTITY_ID, + SERVICE_TURN_OFF, + SERVICE_TURN_ON, + STATE_OFF, + STATE_ON, + STATE_UNAVAILABLE, +) +from homeassistant.core import HomeAssistant +from homeassistant.helpers import entity_registry as er +from homeassistant.util.dt import utcnow + +from . import SETTINGS, init_integration + +from tests.common import async_fire_time_changed + + +async def test_switch(hass: HomeAssistant) -> None: + """Test states of the switches.""" + registry = er.async_get(hass) + + await init_integration(hass) + + state = hass.states.get("switch.fake_profile_ai_driven_threat_detection") + assert state + assert state.state == STATE_ON + + entry = registry.async_get("switch.fake_profile_ai_driven_threat_detection") + assert entry + assert entry.unique_id == "xyz12_ai_threat_detection" + + state = hass.states.get("switch.fake_profile_allow_affiliate_tracking_links") + assert state + assert state.state == STATE_ON + + entry = registry.async_get("switch.fake_profile_allow_affiliate_tracking_links") + assert entry + assert entry.unique_id == "xyz12_allow_affiliate" + + state = hass.states.get("switch.fake_profile_anonymized_edns_client_subnet") + assert state + assert state.state == STATE_ON + + entry = registry.async_get("switch.fake_profile_anonymized_edns_client_subnet") + assert entry + assert entry.unique_id == "xyz12_anonymized_ecs" + + state = hass.states.get("switch.fake_profile_block_bypass_methods") + assert state + assert state.state == STATE_ON + + entry = registry.async_get("switch.fake_profile_block_bypass_methods") + assert entry + assert entry.unique_id == "xyz12_block_bypass_methods" + + state = hass.states.get("switch.fake_profile_block_child_sexual_abuse_material") + assert state + assert state.state == STATE_ON + + entry = registry.async_get("switch.fake_profile_block_child_sexual_abuse_material") + assert entry + assert entry.unique_id == "xyz12_block_csam" + + state = hass.states.get("switch.fake_profile_block_disguised_third_party_trackers") + assert state + assert state.state == STATE_ON + + entry = registry.async_get( + "switch.fake_profile_block_disguised_third_party_trackers" + ) + assert entry + assert entry.unique_id == "xyz12_block_disguised_trackers" + + state = hass.states.get("switch.fake_profile_block_dynamic_dns_hostnames") + assert state + assert state.state == STATE_ON + + entry = registry.async_get("switch.fake_profile_block_dynamic_dns_hostnames") + assert entry + assert entry.unique_id == "xyz12_block_ddns" + + state = hass.states.get("switch.fake_profile_block_newly_registered_domains") + assert state + assert state.state == STATE_ON + + entry = registry.async_get("switch.fake_profile_block_newly_registered_domains") + assert entry + assert entry.unique_id == "xyz12_block_nrd" + + state = hass.states.get("switch.fake_profile_block_page") + assert state + assert state.state == STATE_OFF + + entry = registry.async_get("switch.fake_profile_block_page") + assert entry + assert entry.unique_id == "xyz12_block_page" + + state = hass.states.get("switch.fake_profile_block_parked_domains") + assert state + assert state.state == STATE_ON + + entry = registry.async_get("switch.fake_profile_block_parked_domains") + assert entry + assert entry.unique_id == "xyz12_block_parked_domains" + + state = hass.states.get("switch.fake_profile_cname_flattening") + assert state + assert state.state == STATE_ON + + entry = registry.async_get("switch.fake_profile_cname_flattening") + assert entry + assert entry.unique_id == "xyz12_cname_flattening" + + state = hass.states.get("switch.fake_profile_cache_boost") + assert state + assert state.state == STATE_ON + + entry = registry.async_get("switch.fake_profile_cache_boost") + assert entry + assert entry.unique_id == "xyz12_cache_boost" + + state = hass.states.get("switch.fake_profile_cryptojacking_protection") + assert state + assert state.state == STATE_ON + + entry = registry.async_get("switch.fake_profile_cryptojacking_protection") + assert entry + assert entry.unique_id == "xyz12_cryptojacking_protection" + + state = hass.states.get("switch.fake_profile_dns_rebinding_protection") + assert state + assert state.state == STATE_ON + + entry = registry.async_get("switch.fake_profile_dns_rebinding_protection") + assert entry + assert entry.unique_id == "xyz12_dns_rebinding_protection" + + state = hass.states.get( + "switch.fake_profile_domain_generation_algorithms_protection" + ) + assert state + assert state.state == STATE_ON + + entry = registry.async_get( + "switch.fake_profile_domain_generation_algorithms_protection" + ) + assert entry + assert entry.unique_id == "xyz12_dga_protection" + + state = hass.states.get("switch.fake_profile_force_safesearch") + assert state + assert state.state == STATE_OFF + + entry = registry.async_get("switch.fake_profile_force_safesearch") + assert entry + assert entry.unique_id == "xyz12_safesearch" + + state = hass.states.get("switch.fake_profile_force_youtube_restricted_mode") + assert state + assert state.state == STATE_OFF + + entry = registry.async_get("switch.fake_profile_force_youtube_restricted_mode") + assert entry + assert entry.unique_id == "xyz12_youtube_restricted_mode" + + state = hass.states.get("switch.fake_profile_google_safe_browsing") + assert state + assert state.state == STATE_OFF + + entry = registry.async_get("switch.fake_profile_google_safe_browsing") + assert entry + assert entry.unique_id == "xyz12_google_safe_browsing" + + state = hass.states.get("switch.fake_profile_idn_homograph_attacks_protection") + assert state + assert state.state == STATE_ON + + entry = registry.async_get("switch.fake_profile_idn_homograph_attacks_protection") + assert entry + assert entry.unique_id == "xyz12_idn_homograph_attacks_protection" + + state = hass.states.get("switch.fake_profile_logs") + assert state + assert state.state == STATE_ON + + entry = registry.async_get("switch.fake_profile_logs") + assert entry + assert entry.unique_id == "xyz12_logs" + + state = hass.states.get("switch.fake_profile_threat_intelligence_feeds") + assert state + assert state.state == STATE_ON + + entry = registry.async_get("switch.fake_profile_threat_intelligence_feeds") + assert entry + assert entry.unique_id == "xyz12_threat_intelligence_feeds" + + state = hass.states.get("switch.fake_profile_typosquatting_protection") + assert state + assert state.state == STATE_ON + + entry = registry.async_get("switch.fake_profile_typosquatting_protection") + assert entry + assert entry.unique_id == "xyz12_typosquatting_protection" + + state = hass.states.get("switch.fake_profile_web3") + assert state + assert state.state == STATE_ON + + entry = registry.async_get("switch.fake_profile_web3") + assert entry + assert entry.unique_id == "xyz12_web3" + + +async def test_switch_on(hass: HomeAssistant) -> None: + """Test the switch can be turned on.""" + await init_integration(hass) + + state = hass.states.get("switch.fake_profile_block_page") + assert state + assert state.state == STATE_OFF + + with patch( + "homeassistant.components.nextdns.NextDns.set_setting", return_value=True + ) as mock_switch_on: + assert await hass.services.async_call( + SWITCH_DOMAIN, + SERVICE_TURN_ON, + {ATTR_ENTITY_ID: "switch.fake_profile_block_page"}, + blocking=True, + ) + await hass.async_block_till_done() + + state = hass.states.get("switch.fake_profile_block_page") + assert state + assert state.state == STATE_ON + + mock_switch_on.assert_called_once() + + +async def test_switch_off(hass: HomeAssistant) -> None: + """Test the switch can be turned on.""" + await init_integration(hass) + + state = hass.states.get("switch.fake_profile_web3") + assert state + assert state.state == STATE_ON + + with patch( + "homeassistant.components.nextdns.NextDns.set_setting", return_value=True + ) as mock_switch_on: + assert await hass.services.async_call( + SWITCH_DOMAIN, + SERVICE_TURN_OFF, + {ATTR_ENTITY_ID: "switch.fake_profile_web3"}, + blocking=True, + ) + await hass.async_block_till_done() + + state = hass.states.get("switch.fake_profile_web3") + assert state + assert state.state == STATE_OFF + + mock_switch_on.assert_called_once() + + +async def test_availability(hass: HomeAssistant) -> None: + """Ensure that we mark the entities unavailable correctly when service causes an error.""" + await init_integration(hass) + + state = hass.states.get("switch.fake_profile_web3") + assert state + assert state.state != STATE_UNAVAILABLE + assert state.state == STATE_ON + + future = utcnow() + timedelta(minutes=10) + with patch( + "homeassistant.components.nextdns.NextDns.get_settings", + side_effect=ApiError("API Error"), + ): + async_fire_time_changed(hass, future) + await hass.async_block_till_done() + + state = hass.states.get("switch.fake_profile_web3") + assert state + assert state.state == STATE_UNAVAILABLE + + future = utcnow() + timedelta(minutes=20) + with patch( + "homeassistant.components.nextdns.NextDns.get_settings", + return_value=SETTINGS, + ): + async_fire_time_changed(hass, future) + await hass.async_block_till_done() + + state = hass.states.get("switch.fake_profile_web3") + assert state + assert state.state != STATE_UNAVAILABLE + assert state.state == STATE_ON