From fa245e24f8a12f000fe6dacd4e414def6bdb077a Mon Sep 17 00:00:00 2001 From: Aaron Bach Date: Wed, 21 Sep 2022 11:34:04 -0600 Subject: [PATCH] Fix bug wherein RainMachine services use the wrong controller (#78780) --- .../components/rainmachine/__init__.py | 77 +++++++++++++------ 1 file changed, 52 insertions(+), 25 deletions(-) diff --git a/homeassistant/components/rainmachine/__init__.py b/homeassistant/components/rainmachine/__init__.py index 756dc9b958d..4d19dbc7bfc 100644 --- a/homeassistant/components/rainmachine/__init__.py +++ b/homeassistant/components/rainmachine/__init__.py @@ -2,9 +2,10 @@ from __future__ import annotations import asyncio +from collections.abc import Awaitable, Callable from dataclasses import dataclass from datetime import timedelta -from functools import partial +from functools import partial, wraps from typing import Any from regenmaschine import Client @@ -22,7 +23,7 @@ from homeassistant.const import ( Platform, ) from homeassistant.core import HomeAssistant, ServiceCall, callback -from homeassistant.exceptions import ConfigEntryNotReady +from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError from homeassistant.helpers import ( aiohttp_client, config_validation as cv, @@ -152,9 +153,9 @@ class RainMachineData: @callback -def async_get_controller_for_service_call( +def async_get_entry_for_service_call( hass: HomeAssistant, call: ServiceCall -) -> Controller: +) -> ConfigEntry: """Get the controller related to a service call (by device ID).""" device_id = call.data[CONF_DEVICE_ID] device_registry = dr.async_get(hass) @@ -166,8 +167,7 @@ def async_get_controller_for_service_call( if (entry := hass.config_entries.async_get_entry(entry_id)) is None: continue if entry.domain == DOMAIN: - data: RainMachineData = hass.data[DOMAIN][entry_id] - return data.controller + return entry raise ValueError(f"No controller for device ID: {device_id}") @@ -288,15 +288,42 @@ async def async_setup_entry( # noqa: C901 entry.async_on_unload(entry.add_update_listener(async_reload_entry)) - async def async_pause_watering(call: ServiceCall) -> None: - """Pause watering for a set number of seconds.""" - controller = async_get_controller_for_service_call(hass, call) - await controller.watering.pause_all(call.data[CONF_SECONDS]) - await async_update_programs_and_zones(hass, entry) + def call_with_controller(update_programs_and_zones: bool = True) -> Callable: + """Hydrate a service call with the appropriate controller.""" - async def async_push_weather_data(call: ServiceCall) -> None: + def decorator(func: Callable) -> Callable[..., Awaitable]: + """Define the decorator.""" + + @wraps(func) + async def wrapper(call: ServiceCall) -> None: + """Wrap the service function.""" + entry = async_get_entry_for_service_call(hass, call) + data: RainMachineData = hass.data[DOMAIN][entry.entry_id] + + try: + await func(call, data.controller) + except RainMachineError as err: + raise HomeAssistantError( + f"Error while executing {func.__name__}: {err}" + ) from err + + if update_programs_and_zones: + await async_update_programs_and_zones(hass, entry) + + return wrapper + + return decorator + + @call_with_controller() + async def async_pause_watering(call: ServiceCall, controller: Controller) -> None: + """Pause watering for a set number of seconds.""" + await controller.watering.pause_all(call.data[CONF_SECONDS]) + + @call_with_controller(update_programs_and_zones=False) + async def async_push_weather_data( + call: ServiceCall, controller: Controller + ) -> None: """Push weather data to the device.""" - controller = async_get_controller_for_service_call(hass, call) await controller.parsers.post_data( { CONF_WEATHER: [ @@ -309,9 +336,11 @@ async def async_setup_entry( # noqa: C901 } ) - async def async_restrict_watering(call: ServiceCall) -> None: + @call_with_controller() + async def async_restrict_watering( + call: ServiceCall, controller: Controller + ) -> None: """Restrict watering for a time period.""" - controller = async_get_controller_for_service_call(hass, call) duration = call.data[CONF_DURATION] await controller.restrictions.set_universal( { @@ -319,30 +348,28 @@ async def async_setup_entry( # noqa: C901 "rainDelayDuration": duration.total_seconds(), }, ) - await async_update_programs_and_zones(hass, entry) - async def async_stop_all(call: ServiceCall) -> None: + @call_with_controller() + async def async_stop_all(call: ServiceCall, controller: Controller) -> None: """Stop all watering.""" - controller = async_get_controller_for_service_call(hass, call) await controller.watering.stop_all() - await async_update_programs_and_zones(hass, entry) - async def async_unpause_watering(call: ServiceCall) -> None: + @call_with_controller() + async def async_unpause_watering(call: ServiceCall, controller: Controller) -> None: """Unpause watering.""" - controller = async_get_controller_for_service_call(hass, call) await controller.watering.unpause_all() - await async_update_programs_and_zones(hass, entry) - async def async_unrestrict_watering(call: ServiceCall) -> None: + @call_with_controller() + async def async_unrestrict_watering( + call: ServiceCall, controller: Controller + ) -> None: """Unrestrict watering.""" - controller = async_get_controller_for_service_call(hass, call) await controller.restrictions.set_universal( { "rainDelayStartTime": round(as_timestamp(utcnow())), "rainDelayDuration": 0, }, ) - await async_update_programs_and_zones(hass, entry) for service_name, schema, method in ( (