diff --git a/homeassistant/components/rainmachine/switch.py b/homeassistant/components/rainmachine/switch.py index ee6ac670840..029c8c06771 100644 --- a/homeassistant/components/rainmachine/switch.py +++ b/homeassistant/components/rainmachine/switch.py @@ -2,12 +2,13 @@ from __future__ import annotations import asyncio -from collections.abc import Coroutine +from collections.abc import Awaitable, Callable, Coroutine from dataclasses import dataclass from datetime import datetime -from typing import Any +from typing import Any, TypeVar -from regenmaschine.errors import RequestError +from regenmaschine.errors import RainMachineError +from typing_extensions import Concatenate, ParamSpec import voluptuous as vol from homeassistant.components.switch import SwitchEntity, SwitchEntityDescription @@ -104,6 +105,27 @@ VEGETATION_MAP = { } +_T = TypeVar("_T", bound="RainMachineBaseSwitch") +_P = ParamSpec("_P") + + +def raise_on_request_error( + func: Callable[Concatenate[_T, _P], Awaitable[None]] +) -> Callable[Concatenate[_T, _P], Coroutine[Any, Any, None]]: + """Define a decorator to raise on a request error.""" + + async def decorator(self: _T, *args: _P.args, **kwargs: _P.kwargs) -> None: + """Decorate.""" + try: + await func(self, *args, **kwargs) + except RainMachineError as err: + raise HomeAssistantError( + f"Error while executing {func.__name__}: {err}", + ) from err + + return decorator + + @dataclass class RainMachineSwitchDescription( SwitchEntityDescription, @@ -197,22 +219,9 @@ class RainMachineBaseSwitch(RainMachineEntity, SwitchEntity): self._attr_is_on = False self._entry = entry - async def _async_run_api_coroutine(self, api_coro: Coroutine) -> None: - """Await an API coroutine, handle any errors, and update as appropriate.""" - try: - resp = await api_coro - except RequestError as err: - raise HomeAssistantError( - f'Error while executing {api_coro.__name__} on "{self.name}": {err}', - ) from err - - if resp["statusCode"] != 0: - raise HomeAssistantError( - f'Error while executing {api_coro.__name__} on "{self.name}": {resp["message"]}', - ) - - # Because of how inextricably linked programs and zones are, anytime one is - # toggled, we make sure to update the data of both coordinators: + @callback + def _update_activities(self) -> None: + """Update all activity data.""" self.hass.async_create_task( async_update_programs_and_zones(self.hass, self._entry) ) @@ -250,6 +259,7 @@ class RainMachineActivitySwitch(RainMachineBaseSwitch): await self.async_turn_off_when_active(**kwargs) + @raise_on_request_error async def async_turn_off_when_active(self, **kwargs: Any) -> None: """Turn the switch off when its associated activity is active.""" raise NotImplementedError @@ -265,6 +275,7 @@ class RainMachineActivitySwitch(RainMachineBaseSwitch): await self.async_turn_on_when_active(**kwargs) + @raise_on_request_error async def async_turn_on_when_active(self, **kwargs: Any) -> None: """Turn the switch on when its associated activity is active.""" raise NotImplementedError @@ -290,17 +301,17 @@ class RainMachineProgram(RainMachineActivitySwitch): """Stop the program.""" await self.async_turn_off() + @raise_on_request_error async def async_turn_off_when_active(self, **kwargs: Any) -> None: """Turn the switch off when its associated activity is active.""" - await self._async_run_api_coroutine( - self._data.controller.programs.stop(self.entity_description.uid) - ) + await self._data.controller.programs.stop(self.entity_description.uid) + self._update_activities() + @raise_on_request_error async def async_turn_on_when_active(self, **kwargs: Any) -> None: """Turn the switch on when its associated activity is active.""" - await self._async_run_api_coroutine( - self._data.controller.programs.start(self.entity_description.uid) - ) + await self._data.controller.programs.start(self.entity_description.uid) + self._update_activities() @callback def update_from_latest_data(self) -> None: @@ -332,24 +343,21 @@ class RainMachineProgram(RainMachineActivitySwitch): class RainMachineProgramEnabled(RainMachineEnabledSwitch): """Define a switch to enable/disable a RainMachine program.""" + @raise_on_request_error async def async_turn_off(self, **kwargs: Any) -> None: """Disable the program.""" tasks = [ - self._async_run_api_coroutine( - self._data.controller.programs.stop(self.entity_description.uid) - ), - self._async_run_api_coroutine( - self._data.controller.programs.disable(self.entity_description.uid) - ), + self._data.controller.programs.stop(self.entity_description.uid), + self._data.controller.programs.disable(self.entity_description.uid), ] - await asyncio.gather(*tasks) + self._update_activities() + @raise_on_request_error async def async_turn_on(self, **kwargs: Any) -> None: """Enable the program.""" - await self._async_run_api_coroutine( - self._data.controller.programs.enable(self.entity_description.uid) - ) + await self._data.controller.programs.enable(self.entity_description.uid) + self._update_activities() class RainMachineZone(RainMachineActivitySwitch): @@ -363,20 +371,20 @@ class RainMachineZone(RainMachineActivitySwitch): """Stop a zone.""" await self.async_turn_off() + @raise_on_request_error async def async_turn_off_when_active(self, **kwargs: Any) -> None: """Turn the switch off when its associated activity is active.""" - await self._async_run_api_coroutine( - self._data.controller.zones.stop(self.entity_description.uid) - ) + await self._data.controller.zones.stop(self.entity_description.uid) + self._update_activities() + @raise_on_request_error async def async_turn_on_when_active(self, **kwargs: Any) -> None: """Turn the switch on when its associated activity is active.""" - await self._async_run_api_coroutine( - self._data.controller.zones.start( - self.entity_description.uid, - kwargs.get("duration", self._entry.options[CONF_ZONE_RUN_TIME]), - ) + await self._data.controller.zones.start( + self.entity_description.uid, + kwargs.get("duration", self._entry.options[CONF_ZONE_RUN_TIME]), ) + self._update_activities() @callback def update_from_latest_data(self) -> None: @@ -416,21 +424,18 @@ class RainMachineZone(RainMachineActivitySwitch): class RainMachineZoneEnabled(RainMachineEnabledSwitch): """Define a switch to enable/disable a RainMachine zone.""" + @raise_on_request_error async def async_turn_off(self, **kwargs: Any) -> None: """Disable the zone.""" tasks = [ - self._async_run_api_coroutine( - self._data.controller.zones.stop(self.entity_description.uid) - ), - self._async_run_api_coroutine( - self._data.controller.zones.disable(self.entity_description.uid) - ), + self._data.controller.zones.stop(self.entity_description.uid), + self._data.controller.zones.disable(self.entity_description.uid), ] - await asyncio.gather(*tasks) + self._update_activities() + @raise_on_request_error async def async_turn_on(self, **kwargs: Any) -> None: """Enable the zone.""" - await self._async_run_api_coroutine( - self._data.controller.zones.enable(self.entity_description.uid) - ) + await self._data.controller.zones.enable(self.entity_description.uid) + self._update_activities()