diff --git a/homeassistant/components/wemo/config_flow.py b/homeassistant/components/wemo/config_flow.py index 73f4303cfd6..e4626b4deaf 100644 --- a/homeassistant/components/wemo/config_flow.py +++ b/homeassistant/components/wemo/config_flow.py @@ -1,11 +1,20 @@ """Config flow for Wemo.""" -import pywemo +from __future__ import annotations -from homeassistant.core import HomeAssistant -from homeassistant.helpers import config_entry_flow +from dataclasses import fields +from typing import Any, get_type_hints + +import pywemo +import voluptuous as vol + +from homeassistant.config_entries import ConfigEntry, OptionsFlow +from homeassistant.core import HomeAssistant, callback +from homeassistant.data_entry_flow import FlowResult +from homeassistant.helpers.config_entry_flow import DiscoveryFlowHandler from .const import DOMAIN +from .wemo_device import Options, OptionsValidationError async def _async_has_devices(hass: HomeAssistant) -> bool: @@ -13,4 +22,58 @@ async def _async_has_devices(hass: HomeAssistant) -> bool: return bool(await hass.async_add_executor_job(pywemo.discover_devices)) -config_entry_flow.register_discovery_flow(DOMAIN, "Wemo", _async_has_devices) +class WemoFlow(DiscoveryFlowHandler, domain=DOMAIN): + """Discovery flow with options for Wemo.""" + + def __init__(self) -> None: + """Init discovery flow.""" + super().__init__(DOMAIN, "Wemo", _async_has_devices) + + @staticmethod + @callback + def async_get_options_flow(config_entry: ConfigEntry) -> OptionsFlow: + """Get the options flow for this handler.""" + return WemoOptionsFlow(config_entry) + + +class WemoOptionsFlow(OptionsFlow): + """Options flow for the WeMo component.""" + + def __init__(self, config_entry: ConfigEntry) -> None: + """Initialize options flow.""" + self.config_entry = config_entry + + async def async_step_init( + self, user_input: dict[str, Any] | None = None + ) -> FlowResult: + """Manage options for the WeMo component.""" + errors: dict[str, str] | None = None + if user_input is not None: + try: + Options(**user_input) + except OptionsValidationError as err: + errors = {err.field_key: err.error_key} + else: + return self.async_create_entry(title="", data=user_input) + + return self.async_show_form( + step_id="init", + data_schema=_schema_for_options(Options(**self.config_entry.options)), + errors=errors, + ) + + +def _schema_for_options(options: Options) -> vol.Schema: + """Return the Voluptuous schema for the Options instance. + + All values are optional. The default value is set to the current value and + the type hint is set to the value of the field type annotation. + """ + return vol.Schema( + { + vol.Optional( + field.name, default=getattr(options, field.name) + ): get_type_hints(options)[field.name] + for field in fields(options) + } + ) diff --git a/homeassistant/components/wemo/strings.json b/homeassistant/components/wemo/strings.json index 3419b2cb3d1..dfe5d94bb8a 100644 --- a/homeassistant/components/wemo/strings.json +++ b/homeassistant/components/wemo/strings.json @@ -10,6 +10,22 @@ "no_devices_found": "[%key:common::config_flow::abort::no_devices_found%]" } }, + "options": { + "step": { + "init": { + "data": { + "enable_subscription": "Subscribe to device local push updates", + "enable_long_press": "Register for device long-press events", + "polling_interval_seconds": "Seconds to wait between polling the device" + } + } + }, + "error": { + "long_press_requires_subscription": "Local push update subscriptions must be enabled to use long-press events", + "polling_interval_to_small": "Polling more frequently than 10 seconds is not supported", + "unknown": "[%key:common::config_flow::error::unknown%]" + } + }, "device_automation": { "trigger_type": { "long_press": "Wemo button was pressed for 2 seconds" diff --git a/homeassistant/components/wemo/wemo_device.py b/homeassistant/components/wemo/wemo_device.py index 9f2c72a1585..902b196fe7d 100644 --- a/homeassistant/components/wemo/wemo_device.py +++ b/homeassistant/components/wemo/wemo_device.py @@ -1,10 +1,14 @@ """Home Assistant wrapper for a pyWeMo device.""" +from __future__ import annotations + import asyncio +from dataclasses import dataclass, fields from datetime import timedelta import logging +from typing import Literal from pywemo import Insight, LongPressMixin, WeMoDevice -from pywemo.exceptions import ActionException +from pywemo.exceptions import ActionException, PyWeMoException from pywemo.subscribe import EVENT_TYPE_LONG_PRESS from homeassistant.config_entries import ConfigEntry @@ -29,23 +33,87 @@ from .const import DOMAIN, WEMO_SUBSCRIPTION_EVENT _LOGGER = logging.getLogger(__name__) +# Literal values must match options.error keys from strings.json. +ErrorStringKey = Literal[ + "long_press_requires_subscription", "polling_interval_to_small" +] +# Literal values must match options.step.init.data keys from strings.json. +OptionsFieldKey = Literal[ + "enable_subscription", "enable_long_press", "polling_interval_seconds" +] + + +class OptionsValidationError(Exception): + """Error validating options.""" + + def __init__( + self, field_key: OptionsFieldKey, error_key: ErrorStringKey, message: str + ) -> None: + """Store field and error_key so the exception handler can used them. + + The field_key and error_key strings must be the same as in strings.json. + + Args: + field_key: Name of the options.step.init.data key that corresponds to this error. + field_key must also match one of the field names inside the Options class. + error_key: Name of the options.error key that corresponds to this error. + message: Message for the Exception class. + """ + super().__init__(message) + self.field_key = field_key + self.error_key = error_key + + +@dataclass(frozen=True) +class Options: + """Configuration options for the DeviceCoordinator class. + + Note: The field names must match the keys (OptionsFieldKey) + from options.step.init.data in strings.json. + """ + + # Subscribe to device local push updates. + enable_subscription: bool = True + + # Register for device long-press events. + enable_long_press: bool = True + + # Polling interval for when subscriptions are not enabled or broken. + polling_interval_seconds: int = 30 + + def __post_init__(self) -> None: + """Validate parameters.""" + if not self.enable_subscription and self.enable_long_press: + raise OptionsValidationError( + "enable_subscription", + "long_press_requires_subscription", + "Local push update subscriptions must be enabled to use long-press events", + ) + if self.polling_interval_seconds < 10: + raise OptionsValidationError( + "polling_interval_seconds", + "polling_interval_to_small", + "Polling more frequently than 10 seconds is not supported", + ) + class DeviceCoordinator(DataUpdateCoordinator[None]): """Home Assistant wrapper for a pyWeMo device.""" + options: Options | None = None + def __init__(self, hass: HomeAssistant, wemo: WeMoDevice, device_id: str) -> None: """Initialize DeviceCoordinator.""" super().__init__( hass, _LOGGER, name=wemo.name, - update_interval=timedelta(seconds=30), ) self.hass = hass self.wemo = wemo self.device_id = device_id self.device_info = _create_device_info(wemo) - self.supports_long_press = wemo.supports_long_press() + self.supports_long_press = isinstance(wemo, LongPressMixin) self.update_lock = asyncio.Lock() def subscription_callback( @@ -68,6 +136,54 @@ class DeviceCoordinator(DataUpdateCoordinator[None]): updated = self.wemo.subscription_update(event_type, params) self.hass.create_task(self._async_subscription_callback(updated)) + async def _async_set_enable_subscription(self, enable_subscription: bool) -> None: + """Turn on/off push updates from the device.""" + registry = self.hass.data[DOMAIN]["registry"] + if enable_subscription: + registry.on(self.wemo, None, self.subscription_callback) + await self.hass.async_add_executor_job(registry.register, self.wemo) + elif self.options is not None: + await self.hass.async_add_executor_job(registry.unregister, self.wemo) + + async def _async_set_enable_long_press(self, enable_long_press: bool) -> None: + """Turn on/off long-press events from the device.""" + if not (isinstance(self.wemo, LongPressMixin) and self.supports_long_press): + return + try: + if enable_long_press: + await self.hass.async_add_executor_job( + self.wemo.ensure_long_press_virtual_device + ) + elif self.options is not None: + await self.hass.async_add_executor_job( + self.wemo.remove_long_press_virtual_device + ) + except PyWeMoException: + _LOGGER.exception( + "Failed to enable long press support for device: %s", self.wemo.name + ) + self.supports_long_press = False + + async def _async_set_polling_interval_seconds( + self, polling_interval_seconds: int + ) -> None: + self.update_interval = timedelta(seconds=polling_interval_seconds) + + async def async_set_options( + self, hass: HomeAssistant, config_entry: ConfigEntry + ) -> None: + """Update the configuration options for the device.""" + options = Options(**config_entry.options) + _LOGGER.debug( + "async_set_options old(%s) new(%s)", repr(self.options), repr(options) + ) + for field in fields(options): + new_value = getattr(options, field.name) + if self.options is None or getattr(self.options, field.name) != new_value: + # The value changed, call the _async_set_* method for the option. + await getattr(self, f"_async_set_{field.name}")(new_value) + self.options = options + async def _async_subscription_callback(self, updated: bool) -> None: """Update the state by the Wemo device.""" # If an update is in progress, we don't do anything. @@ -160,20 +276,11 @@ async def async_register_device( device = DeviceCoordinator(hass, wemo, entry.id) hass.data[DOMAIN].setdefault("devices", {})[entry.id] = device - registry = hass.data[DOMAIN]["registry"] - registry.on(wemo, None, device.subscription_callback) - await hass.async_add_executor_job(registry.register, wemo) - if isinstance(wemo, LongPressMixin): - try: - await hass.async_add_executor_job(wemo.ensure_long_press_virtual_device) - # Temporarily handling all exceptions for #52996 & pywemo/pywemo/issues/276 - # Replace this with `except: PyWeMoException` after upstream has been fixed. - except Exception: # pylint: disable=broad-except - _LOGGER.exception( - "Failed to enable long press support for device: %s", wemo.name - ) - device.supports_long_press = False + config_entry.async_on_unload( + config_entry.add_update_listener(device.async_set_options) + ) + await device.async_set_options(hass, config_entry) return device diff --git a/tests/components/wemo/test_config_flow.py b/tests/components/wemo/test_config_flow.py index 0ffcb9d7f5e..e84ae4f0205 100644 --- a/tests/components/wemo/test_config_flow.py +++ b/tests/components/wemo/test_config_flow.py @@ -1,11 +1,14 @@ """Tests for Wemo config flow.""" +from dataclasses import asdict + from homeassistant import data_entry_flow from homeassistant.components.wemo.const import DOMAIN +from homeassistant.components.wemo.wemo_device import Options from homeassistant.config_entries import SOURCE_USER from homeassistant.core import HomeAssistant -from tests.common import patch +from tests.common import MockConfigEntry, patch async def test_not_discovered(hass: HomeAssistant) -> None: @@ -20,3 +23,47 @@ async def test_not_discovered(hass: HomeAssistant) -> None: result = await hass.config_entries.flow.async_configure(result["flow_id"], {}) assert result["type"] == data_entry_flow.FlowResultType.ABORT assert result["reason"] == "no_devices_found" + + +async def test_options(hass: HomeAssistant) -> None: + """Test updating options.""" + options = Options(enable_subscription=False, enable_long_press=False) + entry = MockConfigEntry(domain=DOMAIN, title="Wemo") + entry.add_to_hass(hass) + + result = await hass.config_entries.options.async_init(entry.entry_id) + + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["step_id"] == "init" + + result = await hass.config_entries.options.async_configure( + result["flow_id"], user_input=asdict(options) + ) + + assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + assert Options(**result["data"]) == options + + +async def test_invalid_options(hass: HomeAssistant) -> None: + """Test invalid option combinations.""" + entry = MockConfigEntry(domain=DOMAIN, title="Wemo") + entry.add_to_hass(hass) + + result = await hass.config_entries.options.async_init(entry.entry_id) + + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["step_id"] == "init" + + # enable_subscription must be True if enable_long_press is True (default). + result = await hass.config_entries.options.async_configure( + result["flow_id"], user_input={"enable_subscription": False} + ) + assert result["errors"] == { + "enable_subscription": "long_press_requires_subscription" + } + + # polling_interval_seconds must be larger than 10. + result = await hass.config_entries.options.async_configure( + result["flow_id"], user_input={"polling_interval_seconds": 1} + ) + assert result["errors"] == {"polling_interval_seconds": "polling_interval_to_small"} diff --git a/tests/components/wemo/test_wemo_device.py b/tests/components/wemo/test_wemo_device.py index 40e06f0a698..2ace50f5543 100644 --- a/tests/components/wemo/test_wemo_device.py +++ b/tests/components/wemo/test_wemo_device.py @@ -1,5 +1,6 @@ """Tests for wemo_device.py.""" import asyncio +from dataclasses import asdict from datetime import timedelta from unittest.mock import call, patch @@ -188,6 +189,62 @@ async def test_dli_device_info( assert device_entries[0].identifiers == {(DOMAIN, "123456789")} +async def test_options_enable_subscription_false( + hass, pywemo_registry, pywemo_device, wemo_entity +): + """Test setting Options.enable_subscription = False.""" + config_entry = hass.config_entries.async_get_entry(wemo_entity.config_entry_id) + assert hass.config_entries.async_update_entry( + config_entry, + options=asdict( + wemo_device.Options(enable_subscription=False, enable_long_press=False) + ), + ) + await hass.async_block_till_done() + pywemo_registry.unregister.assert_called_once_with(pywemo_device) + + +async def test_options_enable_long_press_false(hass, pywemo_device, wemo_entity): + """Test setting Options.enable_long_press = False.""" + config_entry = hass.config_entries.async_get_entry(wemo_entity.config_entry_id) + assert hass.config_entries.async_update_entry( + config_entry, options=asdict(wemo_device.Options(enable_long_press=False)) + ) + await hass.async_block_till_done() + pywemo_device.remove_long_press_virtual_device.assert_called_once_with() + + +async def test_options_polling_interval_seconds(hass, pywemo_device, wemo_entity): + """Test setting Options.polling_interval_seconds = 45.""" + config_entry = hass.config_entries.async_get_entry(wemo_entity.config_entry_id) + assert hass.config_entries.async_update_entry( + config_entry, + options=asdict( + wemo_device.Options( + enable_subscription=False, + enable_long_press=False, + polling_interval_seconds=45, + ) + ), + ) + await hass.async_block_till_done() + + # Move time forward to capture the new interval. + async_fire_time_changed(hass, utcnow() + timedelta(seconds=31)) + await hass.async_block_till_done() + pywemo_device.get_state.reset_mock() + + # Make sure no polling occurs before 45 seconds. + async_fire_time_changed(hass, utcnow() + timedelta(seconds=31)) + await hass.async_block_till_done() + pywemo_device.get_state.assert_not_called() + + # Polling occurred after the interval. + async_fire_time_changed(hass, utcnow() + timedelta(seconds=46)) + await hass.async_block_till_done() + pywemo_device.get_state.assert_has_calls([call(True), call()]) + + class TestInsight: """Tests specific to the WeMo Insight device."""