From e088119d6d1fe5063998f047807b8be43da5273d Mon Sep 17 00:00:00 2001 From: Santobert Date: Sat, 5 Oct 2019 21:42:37 +0200 Subject: [PATCH] fan_reproduce_state (#27227) --- .../components/fan/reproduce_state.py | 100 ++++++++++++++++++ tests/components/fan/test_reproduce_state.py | 89 ++++++++++++++++ 2 files changed, 189 insertions(+) create mode 100644 homeassistant/components/fan/reproduce_state.py create mode 100644 tests/components/fan/test_reproduce_state.py diff --git a/homeassistant/components/fan/reproduce_state.py b/homeassistant/components/fan/reproduce_state.py new file mode 100644 index 00000000000..1053861e2bf --- /dev/null +++ b/homeassistant/components/fan/reproduce_state.py @@ -0,0 +1,100 @@ +"""Reproduce an Fan state.""" +import asyncio +import logging +from types import MappingProxyType +from typing import Iterable, Optional + +from homeassistant.const import ( + ATTR_ENTITY_ID, + STATE_ON, + STATE_OFF, + SERVICE_TURN_OFF, + SERVICE_TURN_ON, +) +from homeassistant.core import Context, State +from homeassistant.helpers.typing import HomeAssistantType + +from . import ( + DOMAIN, + ATTR_DIRECTION, + ATTR_OSCILLATING, + ATTR_SPEED, + SERVICE_OSCILLATE, + SERVICE_SET_DIRECTION, + SERVICE_SET_SPEED, +) + +_LOGGER = logging.getLogger(__name__) + +VALID_STATES = {STATE_ON, STATE_OFF} +ATTRIBUTES = { # attribute: service + ATTR_DIRECTION: SERVICE_SET_DIRECTION, + ATTR_OSCILLATING: SERVICE_OSCILLATE, + ATTR_SPEED: SERVICE_SET_SPEED, +} + + +async def _async_reproduce_state( + hass: HomeAssistantType, state: State, context: Optional[Context] = None +) -> None: + """Reproduce a single state.""" + cur_state = hass.states.get(state.entity_id) + + if cur_state is None: + _LOGGER.warning("Unable to find entity %s", state.entity_id) + return + + if state.state not in VALID_STATES: + _LOGGER.warning( + "Invalid state specified for %s: %s", state.entity_id, state.state + ) + return + + # Return if we are already at the right state. + if cur_state.state == state.state and all( + check_attr_equal(cur_state.attributes, state.attributes, attr) + for attr in ATTRIBUTES + ): + return + + service_data = {ATTR_ENTITY_ID: state.entity_id} + service_calls = {} # service: service_data + + if state.state == STATE_ON: + # The fan should be on + if cur_state.state != STATE_ON: + # Turn on the fan at first + service_calls[SERVICE_TURN_ON] = service_data + + for attr, service in ATTRIBUTES.items(): + # Call services to adjust the attributes + if attr in state.attributes and not check_attr_equal( + state.attributes, cur_state.attributes, attr + ): + data = service_data.copy() + data[attr] = state.attributes[attr] + service_calls[service] = data + + elif state.state == STATE_OFF: + service_calls[SERVICE_TURN_OFF] = service_data + + for service, data in service_calls.items(): + await hass.services.async_call( + DOMAIN, service, data, context=context, blocking=True + ) + + +async def async_reproduce_states( + hass: HomeAssistantType, states: Iterable[State], context: Optional[Context] = None +) -> None: + """Reproduce Fan states.""" + await asyncio.gather( + *(_async_reproduce_state(hass, state, context) for state in states) + ) + + +def check_attr_equal( + attr1: MappingProxyType, attr2: MappingProxyType, attr_str: str +) -> bool: + """Return true if the given attributes are equal.""" + return attr1.get(attr_str) == attr2.get(attr_str) diff --git a/tests/components/fan/test_reproduce_state.py b/tests/components/fan/test_reproduce_state.py new file mode 100644 index 00000000000..0dcd38580b8 --- /dev/null +++ b/tests/components/fan/test_reproduce_state.py @@ -0,0 +1,89 @@ +"""Test reproduce state for Fan.""" +from homeassistant.core import State + +from tests.common import async_mock_service + + +async def test_reproducing_states(hass, caplog): + """Test reproducing Fan states.""" + hass.states.async_set("fan.entity_off", "off", {}) + hass.states.async_set("fan.entity_on", "on", {}) + hass.states.async_set("fan.entity_speed", "on", {"speed": "high"}) + hass.states.async_set("fan.entity_oscillating", "on", {"oscillating": True}) + hass.states.async_set("fan.entity_direction", "on", {"direction": "forward"}) + + turn_on_calls = async_mock_service(hass, "fan", "turn_on") + turn_off_calls = async_mock_service(hass, "fan", "turn_off") + set_direction_calls = async_mock_service(hass, "fan", "set_direction") + oscillate_calls = async_mock_service(hass, "fan", "oscillate") + set_speed_calls = async_mock_service(hass, "fan", "set_speed") + + # These calls should do nothing as entities already in desired state + await hass.helpers.state.async_reproduce_state( + [ + State("fan.entity_off", "off"), + State("fan.entity_on", "on"), + State("fan.entity_speed", "on", {"speed": "high"}), + State("fan.entity_oscillating", "on", {"oscillating": True}), + State("fan.entity_direction", "on", {"direction": "forward"}), + ], + blocking=True, + ) + + assert len(turn_on_calls) == 0 + assert len(turn_off_calls) == 0 + assert len(set_direction_calls) == 0 + assert len(oscillate_calls) == 0 + assert len(set_speed_calls) == 0 + + # Test invalid state is handled + await hass.helpers.state.async_reproduce_state( + [State("fan.entity_off", "not_supported")], blocking=True + ) + + assert "not_supported" in caplog.text + assert len(turn_on_calls) == 0 + assert len(turn_off_calls) == 0 + assert len(set_direction_calls) == 0 + assert len(oscillate_calls) == 0 + assert len(set_speed_calls) == 0 + + # Make sure correct services are called + await hass.helpers.state.async_reproduce_state( + [ + State("fan.entity_on", "off"), + State("fan.entity_off", "on"), + State("fan.entity_speed", "on", {"speed": "low"}), + State("fan.entity_oscillating", "on", {"oscillating": False}), + State("fan.entity_direction", "on", {"direction": "reverse"}), + # Should not raise + State("fan.non_existing", "on"), + ], + blocking=True, + ) + + assert len(turn_on_calls) == 1 + assert turn_on_calls[0].domain == "fan" + assert turn_on_calls[0].data == {"entity_id": "fan.entity_off"} + + assert len(set_direction_calls) == 1 + assert set_direction_calls[0].domain == "fan" + assert set_direction_calls[0].data == { + "entity_id": "fan.entity_direction", + "direction": "reverse", + } + + assert len(oscillate_calls) == 1 + assert oscillate_calls[0].domain == "fan" + assert oscillate_calls[0].data == { + "entity_id": "fan.entity_oscillating", + "oscillating": False, + } + + assert len(set_speed_calls) == 1 + assert set_speed_calls[0].domain == "fan" + assert set_speed_calls[0].data == {"entity_id": "fan.entity_speed", "speed": "low"} + + assert len(turn_off_calls) == 1 + assert turn_off_calls[0].domain == "fan" + assert turn_off_calls[0].data == {"entity_id": "fan.entity_on"}