Fix bug wherein RainMachine services use the wrong controller (#78780)

This commit is contained in:
Aaron Bach 2022-09-21 11:34:04 -06:00 committed by GitHub
parent 0d696b84b2
commit fa245e24f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,9 +2,10 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Awaitable, Callable
from dataclasses import dataclass from dataclasses import dataclass
from datetime import timedelta from datetime import timedelta
from functools import partial from functools import partial, wraps
from typing import Any from typing import Any
from regenmaschine import Client from regenmaschine import Client
@ -22,7 +23,7 @@ from homeassistant.const import (
Platform, Platform,
) )
from homeassistant.core import HomeAssistant, ServiceCall, callback from homeassistant.core import HomeAssistant, ServiceCall, callback
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError
from homeassistant.helpers import ( from homeassistant.helpers import (
aiohttp_client, aiohttp_client,
config_validation as cv, config_validation as cv,
@ -152,9 +153,9 @@ class RainMachineData:
@callback @callback
def async_get_controller_for_service_call( def async_get_entry_for_service_call(
hass: HomeAssistant, call: ServiceCall hass: HomeAssistant, call: ServiceCall
) -> Controller: ) -> ConfigEntry:
"""Get the controller related to a service call (by device ID).""" """Get the controller related to a service call (by device ID)."""
device_id = call.data[CONF_DEVICE_ID] device_id = call.data[CONF_DEVICE_ID]
device_registry = dr.async_get(hass) device_registry = dr.async_get(hass)
@ -166,8 +167,7 @@ def async_get_controller_for_service_call(
if (entry := hass.config_entries.async_get_entry(entry_id)) is None: if (entry := hass.config_entries.async_get_entry(entry_id)) is None:
continue continue
if entry.domain == DOMAIN: if entry.domain == DOMAIN:
data: RainMachineData = hass.data[DOMAIN][entry_id] return entry
return data.controller
raise ValueError(f"No controller for device ID: {device_id}") raise ValueError(f"No controller for device ID: {device_id}")
@ -288,15 +288,42 @@ async def async_setup_entry( # noqa: C901
entry.async_on_unload(entry.add_update_listener(async_reload_entry)) entry.async_on_unload(entry.add_update_listener(async_reload_entry))
async def async_pause_watering(call: ServiceCall) -> None: def call_with_controller(update_programs_and_zones: bool = True) -> Callable:
"""Pause watering for a set number of seconds.""" """Hydrate a service call with the appropriate controller."""
controller = async_get_controller_for_service_call(hass, call)
await controller.watering.pause_all(call.data[CONF_SECONDS])
await async_update_programs_and_zones(hass, entry)
async def async_push_weather_data(call: ServiceCall) -> None: def decorator(func: Callable) -> Callable[..., Awaitable]:
"""Define the decorator."""
@wraps(func)
async def wrapper(call: ServiceCall) -> None:
"""Wrap the service function."""
entry = async_get_entry_for_service_call(hass, call)
data: RainMachineData = hass.data[DOMAIN][entry.entry_id]
try:
await func(call, data.controller)
except RainMachineError as err:
raise HomeAssistantError(
f"Error while executing {func.__name__}: {err}"
) from err
if update_programs_and_zones:
await async_update_programs_and_zones(hass, entry)
return wrapper
return decorator
@call_with_controller()
async def async_pause_watering(call: ServiceCall, controller: Controller) -> None:
"""Pause watering for a set number of seconds."""
await controller.watering.pause_all(call.data[CONF_SECONDS])
@call_with_controller(update_programs_and_zones=False)
async def async_push_weather_data(
call: ServiceCall, controller: Controller
) -> None:
"""Push weather data to the device.""" """Push weather data to the device."""
controller = async_get_controller_for_service_call(hass, call)
await controller.parsers.post_data( await controller.parsers.post_data(
{ {
CONF_WEATHER: [ CONF_WEATHER: [
@ -309,9 +336,11 @@ async def async_setup_entry( # noqa: C901
} }
) )
async def async_restrict_watering(call: ServiceCall) -> None: @call_with_controller()
async def async_restrict_watering(
call: ServiceCall, controller: Controller
) -> None:
"""Restrict watering for a time period.""" """Restrict watering for a time period."""
controller = async_get_controller_for_service_call(hass, call)
duration = call.data[CONF_DURATION] duration = call.data[CONF_DURATION]
await controller.restrictions.set_universal( await controller.restrictions.set_universal(
{ {
@ -319,30 +348,28 @@ async def async_setup_entry( # noqa: C901
"rainDelayDuration": duration.total_seconds(), "rainDelayDuration": duration.total_seconds(),
}, },
) )
await async_update_programs_and_zones(hass, entry)
async def async_stop_all(call: ServiceCall) -> None: @call_with_controller()
async def async_stop_all(call: ServiceCall, controller: Controller) -> None:
"""Stop all watering.""" """Stop all watering."""
controller = async_get_controller_for_service_call(hass, call)
await controller.watering.stop_all() await controller.watering.stop_all()
await async_update_programs_and_zones(hass, entry)
async def async_unpause_watering(call: ServiceCall) -> None: @call_with_controller()
async def async_unpause_watering(call: ServiceCall, controller: Controller) -> None:
"""Unpause watering.""" """Unpause watering."""
controller = async_get_controller_for_service_call(hass, call)
await controller.watering.unpause_all() await controller.watering.unpause_all()
await async_update_programs_and_zones(hass, entry)
async def async_unrestrict_watering(call: ServiceCall) -> None: @call_with_controller()
async def async_unrestrict_watering(
call: ServiceCall, controller: Controller
) -> None:
"""Unrestrict watering.""" """Unrestrict watering."""
controller = async_get_controller_for_service_call(hass, call)
await controller.restrictions.set_universal( await controller.restrictions.set_universal(
{ {
"rainDelayStartTime": round(as_timestamp(utcnow())), "rainDelayStartTime": round(as_timestamp(utcnow())),
"rainDelayDuration": 0, "rainDelayDuration": 0,
}, },
) )
await async_update_programs_and_zones(hass, entry)
for service_name, schema, method in ( for service_name, schema, method in (
( (