diff --git a/homeassistant/components/cover/reproduce_state.py b/homeassistant/components/cover/reproduce_state.py new file mode 100644 index 00000000000..64ea410ce93 --- /dev/null +++ b/homeassistant/components/cover/reproduce_state.py @@ -0,0 +1,117 @@ +"""Reproduce an Cover state.""" +import asyncio +import logging +from typing import Iterable, Optional + +from homeassistant.components.cover import ( + ATTR_CURRENT_POSITION, + ATTR_CURRENT_TILT_POSITION, + ATTR_POSITION, + ATTR_TILT_POSITION, +) +from homeassistant.const import ( + ATTR_ENTITY_ID, + SERVICE_CLOSE_COVER, + SERVICE_CLOSE_COVER_TILT, + SERVICE_OPEN_COVER, + SERVICE_OPEN_COVER_TILT, + SERVICE_SET_COVER_POSITION, + SERVICE_SET_COVER_TILT_POSITION, + STATE_CLOSED, + STATE_CLOSING, + STATE_OPEN, + STATE_OPENING, +) +from homeassistant.core import Context, State +from homeassistant.helpers.typing import HomeAssistantType + +from . import DOMAIN + +_LOGGER = logging.getLogger(__name__) + +VALID_STATES = {STATE_CLOSED, STATE_CLOSING, STATE_OPEN, STATE_OPENING} + + +async def _async_reproduce_state( + hass: HomeAssistantType, state: State, context: Optional[Context] = None +) -> None: + """Reproduce a single state.""" + cur_state = hass.states.get(state.entity_id) + + if cur_state is None: + _LOGGER.warning("Unable to find entity %s", state.entity_id) + return + + if state.state not in VALID_STATES: + _LOGGER.warning( + "Invalid state specified for %s: %s", state.entity_id, state.state + ) + return + + # Return if we are already at the right state. + if ( + cur_state.state == state.state + and cur_state.attributes.get(ATTR_CURRENT_POSITION) + == state.attributes.get(ATTR_CURRENT_POSITION) + and cur_state.attributes.get(ATTR_CURRENT_TILT_POSITION) + == state.attributes.get(ATTR_CURRENT_TILT_POSITION) + ): + return + + service_data = {ATTR_ENTITY_ID: state.entity_id} + service_data_tilting = {ATTR_ENTITY_ID: state.entity_id} + + if cur_state.state != state.state or cur_state.attributes.get( + ATTR_CURRENT_POSITION + ) != state.attributes.get(ATTR_CURRENT_POSITION): + # Open/Close + if state.state == STATE_CLOSED or state.state == STATE_CLOSING: + service = SERVICE_CLOSE_COVER + elif state.state == STATE_OPEN or state.state == STATE_OPENING: + if ( + ATTR_CURRENT_POSITION in cur_state.attributes + and ATTR_CURRENT_POSITION in state.attributes + ): + service = SERVICE_SET_COVER_POSITION + service_data[ATTR_POSITION] = state.attributes[ATTR_CURRENT_POSITION] + else: + service = SERVICE_OPEN_COVER + + await hass.services.async_call( + DOMAIN, service, service_data, context=context, blocking=True + ) + + if ( + ATTR_CURRENT_TILT_POSITION in state.attributes + and ATTR_CURRENT_TILT_POSITION in cur_state.attributes + and cur_state.attributes.get(ATTR_CURRENT_TILT_POSITION) + != state.attributes.get(ATTR_CURRENT_TILT_POSITION) + ): + # Tilt position + if state.attributes.get(ATTR_CURRENT_TILT_POSITION) == 100: + service_tilting = SERVICE_OPEN_COVER_TILT + elif state.attributes.get(ATTR_CURRENT_TILT_POSITION) == 0: + service_tilting = SERVICE_CLOSE_COVER_TILT + else: + service_tilting = SERVICE_SET_COVER_TILT_POSITION + service_data_tilting[ATTR_TILT_POSITION] = state.attributes[ + ATTR_CURRENT_TILT_POSITION + ] + + await hass.services.async_call( + DOMAIN, + service_tilting, + service_data_tilting, + context=context, + blocking=True, + ) + + +async def async_reproduce_states( + hass: HomeAssistantType, states: Iterable[State], context: Optional[Context] = None +) -> None: + """Reproduce Cover states.""" + # Reproduce states in parallel. + await asyncio.gather( + *(_async_reproduce_state(hass, state, context) for state in states) + ) diff --git a/tests/components/cover/test_reproduce_state.py b/tests/components/cover/test_reproduce_state.py new file mode 100644 index 00000000000..39fdf3d3992 --- /dev/null +++ b/tests/components/cover/test_reproduce_state.py @@ -0,0 +1,198 @@ +"""Test reproduce state for Cover.""" +from homeassistant.components.cover import ( + ATTR_CURRENT_POSITION, + ATTR_CURRENT_TILT_POSITION, + ATTR_POSITION, + ATTR_TILT_POSITION, +) +from homeassistant.const import ( + SERVICE_CLOSE_COVER, + SERVICE_CLOSE_COVER_TILT, + SERVICE_OPEN_COVER, + SERVICE_OPEN_COVER_TILT, + SERVICE_SET_COVER_POSITION, + SERVICE_SET_COVER_TILT_POSITION, + STATE_CLOSED, + STATE_OPEN, +) +from homeassistant.core import State +from tests.common import async_mock_service + + +async def test_reproducing_states(hass, caplog): + """Test reproducing Cover states.""" + hass.states.async_set("cover.entity_close", STATE_CLOSED, {}) + hass.states.async_set( + "cover.entity_close_attr", + STATE_CLOSED, + {ATTR_CURRENT_POSITION: 0, ATTR_CURRENT_TILT_POSITION: 0}, + ) + hass.states.async_set( + "cover.entity_close_tilt", STATE_CLOSED, {ATTR_CURRENT_TILT_POSITION: 50} + ) + hass.states.async_set("cover.entity_open", STATE_OPEN, {}) + hass.states.async_set( + "cover.entity_slightly_open", STATE_OPEN, {ATTR_CURRENT_POSITION: 50} + ) + hass.states.async_set( + "cover.entity_open_attr", + STATE_OPEN, + {ATTR_CURRENT_POSITION: 100, ATTR_CURRENT_TILT_POSITION: 0}, + ) + hass.states.async_set( + "cover.entity_open_tilt", + STATE_OPEN, + {ATTR_CURRENT_POSITION: 50, ATTR_CURRENT_TILT_POSITION: 50}, + ) + hass.states.async_set( + "cover.entity_entirely_open", + STATE_OPEN, + {ATTR_CURRENT_POSITION: 100, ATTR_CURRENT_TILT_POSITION: 100}, + ) + + close_calls = async_mock_service(hass, "cover", SERVICE_CLOSE_COVER) + open_calls = async_mock_service(hass, "cover", SERVICE_OPEN_COVER) + close_tilt_calls = async_mock_service(hass, "cover", SERVICE_CLOSE_COVER_TILT) + open_tilt_calls = async_mock_service(hass, "cover", SERVICE_OPEN_COVER_TILT) + position_calls = async_mock_service(hass, "cover", SERVICE_SET_COVER_POSITION) + position_tilt_calls = async_mock_service( + hass, "cover", SERVICE_SET_COVER_TILT_POSITION + ) + + # These calls should do nothing as entities already in desired state + await hass.helpers.state.async_reproduce_state( + [ + State("cover.entity_close", STATE_CLOSED), + State( + "cover.entity_close_attr", + STATE_CLOSED, + {ATTR_CURRENT_POSITION: 0, ATTR_CURRENT_TILT_POSITION: 0}, + ), + State( + "cover.entity_close_tilt", + STATE_CLOSED, + {ATTR_CURRENT_TILT_POSITION: 50}, + ), + State("cover.entity_open", STATE_OPEN), + State( + "cover.entity_slightly_open", STATE_OPEN, {ATTR_CURRENT_POSITION: 50} + ), + State( + "cover.entity_open_attr", + STATE_OPEN, + {ATTR_CURRENT_POSITION: 100, ATTR_CURRENT_TILT_POSITION: 0}, + ), + State( + "cover.entity_open_tilt", + STATE_OPEN, + {ATTR_CURRENT_POSITION: 50, ATTR_CURRENT_TILT_POSITION: 50}, + ), + State( + "cover.entity_entirely_open", + STATE_OPEN, + {ATTR_CURRENT_POSITION: 100, ATTR_CURRENT_TILT_POSITION: 100}, + ), + ], + blocking=True, + ) + + assert len(close_calls) == 0 + assert len(open_calls) == 0 + assert len(close_tilt_calls) == 0 + assert len(open_tilt_calls) == 0 + assert len(position_calls) == 0 + assert len(position_tilt_calls) == 0 + + # Test invalid state is handled + await hass.helpers.state.async_reproduce_state( + [State("cover.entity_close", "not_supported")], blocking=True + ) + + assert "not_supported" in caplog.text + assert len(close_calls) == 0 + assert len(open_calls) == 0 + assert len(close_tilt_calls) == 0 + assert len(open_tilt_calls) == 0 + assert len(position_calls) == 0 + assert len(position_tilt_calls) == 0 + + # Make sure correct services are called + await hass.helpers.state.async_reproduce_state( + [ + State("cover.entity_close", STATE_OPEN), + State( + "cover.entity_close_attr", + STATE_OPEN, + {ATTR_CURRENT_POSITION: 50, ATTR_CURRENT_TILT_POSITION: 50}, + ), + State( + "cover.entity_close_tilt", + STATE_CLOSED, + {ATTR_CURRENT_TILT_POSITION: 100}, + ), + State("cover.entity_open", STATE_CLOSED), + State("cover.entity_slightly_open", STATE_OPEN, {}), + State("cover.entity_open_attr", STATE_CLOSED, {}), + State( + "cover.entity_open_tilt", STATE_OPEN, {ATTR_CURRENT_TILT_POSITION: 0} + ), + State( + "cover.entity_entirely_open", + STATE_CLOSED, + {ATTR_CURRENT_POSITION: 0, ATTR_CURRENT_TILT_POSITION: 0}, + ), + # Should not raise + State("cover.non_existing", "on"), + ], + blocking=True, + ) + + valid_close_calls = [ + {"entity_id": "cover.entity_open"}, + {"entity_id": "cover.entity_open_attr"}, + {"entity_id": "cover.entity_entirely_open"}, + ] + assert len(close_calls) == 3 + for call in close_calls: + assert call.domain == "cover" + assert call.data in valid_close_calls + valid_close_calls.remove(call.data) + + valid_open_calls = [ + {"entity_id": "cover.entity_close"}, + {"entity_id": "cover.entity_slightly_open"}, + {"entity_id": "cover.entity_open_tilt"}, + ] + assert len(open_calls) == 3 + for call in open_calls: + assert call.domain == "cover" + assert call.data in valid_open_calls + valid_open_calls.remove(call.data) + + valid_close_tilt_calls = [ + {"entity_id": "cover.entity_open_tilt"}, + {"entity_id": "cover.entity_entirely_open"}, + ] + assert len(close_tilt_calls) == 2 + for call in close_tilt_calls: + assert call.domain == "cover" + assert call.data in valid_close_tilt_calls + valid_close_tilt_calls.remove(call.data) + + assert len(open_tilt_calls) == 1 + assert open_tilt_calls[0].domain == "cover" + assert open_tilt_calls[0].data == {"entity_id": "cover.entity_close_tilt"} + + assert len(position_calls) == 1 + assert position_calls[0].domain == "cover" + assert position_calls[0].data == { + "entity_id": "cover.entity_close_attr", + ATTR_POSITION: 50, + } + + assert len(position_tilt_calls) == 1 + assert position_tilt_calls[0].domain == "cover" + assert position_tilt_calls[0].data == { + "entity_id": "cover.entity_close_attr", + ATTR_TILT_POSITION: 50, + }