diff --git a/homeassistant/components/light/reproduce_state.py b/homeassistant/components/light/reproduce_state.py new file mode 100644 index 00000000000..ae618f7a8ef --- /dev/null +++ b/homeassistant/components/light/reproduce_state.py @@ -0,0 +1,94 @@ +"""Reproduce an Light state.""" +import asyncio +import logging +from types import MappingProxyType +from typing import Iterable, Optional + +from homeassistant.const import ( + ATTR_ENTITY_ID, + STATE_ON, + STATE_OFF, + SERVICE_TURN_OFF, + SERVICE_TURN_ON, +) +from homeassistant.core import Context, State +from homeassistant.helpers.typing import HomeAssistantType + +from . import ( + DOMAIN, + ATTR_BRIGHTNESS, + ATTR_COLOR_TEMP, + ATTR_EFFECT, + ATTR_HS_COLOR, + ATTR_RGB_COLOR, + ATTR_WHITE_VALUE, + ATTR_XY_COLOR, +) + +_LOGGER = logging.getLogger(__name__) + +VALID_STATES = {STATE_ON, STATE_OFF} +ATTR_GROUP = [ATTR_BRIGHTNESS, ATTR_EFFECT, ATTR_WHITE_VALUE] +COLOR_GROUP = [ATTR_COLOR_TEMP, ATTR_HS_COLOR, ATTR_RGB_COLOR, ATTR_XY_COLOR] + + +async def _async_reproduce_state( + hass: HomeAssistantType, state: State, context: Optional[Context] = None +) -> None: + """Reproduce a single state.""" + cur_state = hass.states.get(state.entity_id) + + if cur_state is None: + _LOGGER.warning("Unable to find entity %s", state.entity_id) + return + + if state.state not in VALID_STATES: + _LOGGER.warning( + "Invalid state specified for %s: %s", state.entity_id, state.state + ) + return + + # Return if we are already at the right state. + if cur_state.state == state.state and all( + check_attr_equal(cur_state.attributes, state.attributes, attr) + for attr in ATTR_GROUP + COLOR_GROUP + ): + return + + service_data = {ATTR_ENTITY_ID: state.entity_id} + + if state.state == STATE_ON: + service = SERVICE_TURN_ON + for attr in ATTR_GROUP: + # All attributes that are not colors + if attr in state.attributes: + service_data[attr] = state.attributes[attr] + + for color_attr in COLOR_GROUP: + # Choose the first color that is specified + if color_attr in state.attributes: + service_data[color_attr] = state.attributes[color_attr] + break + + elif state.state == STATE_OFF: + service = SERVICE_TURN_OFF + + await hass.services.async_call( + DOMAIN, service, service_data, context=context, blocking=True + ) + + +async def async_reproduce_states( + hass: HomeAssistantType, states: Iterable[State], context: Optional[Context] = None +) -> None: + """Reproduce Light states.""" + await asyncio.gather( + *(_async_reproduce_state(hass, state, context) for state in states) + ) + + +def check_attr_equal( + attr1: MappingProxyType, attr2: MappingProxyType, attr_str: str +) -> bool: + """Return true if the given attributes are equal.""" + return attr1.get(attr_str) == attr2.get(attr_str) diff --git a/tests/components/light/test_reproduce_state.py b/tests/components/light/test_reproduce_state.py new file mode 100644 index 00000000000..92790890a4c --- /dev/null +++ b/tests/components/light/test_reproduce_state.py @@ -0,0 +1,117 @@ +"""Test reproduce state for Light.""" +from homeassistant.core import State + +from tests.common import async_mock_service + +VALID_BRIGHTNESS = {"brightness": 180} +VALID_WHITE_VALUE = {"white_value": 200} +VALID_EFFECT = {"effect": "random"} +VALID_COLOR_TEMP = {"color_temp": 240} +VALID_HS_COLOR = {"hs_color": (345, 75)} +VALID_RGB_COLOR = {"rgb_color": (255, 63, 111)} +VALID_XY_COLOR = {"xy_color": (0.59, 0.274)} + + +async def test_reproducing_states(hass, caplog): + """Test reproducing Light states.""" + hass.states.async_set("light.entity_off", "off", {}) + hass.states.async_set("light.entity_bright", "on", VALID_BRIGHTNESS) + hass.states.async_set("light.entity_white", "on", VALID_WHITE_VALUE) + hass.states.async_set("light.entity_effect", "on", VALID_EFFECT) + hass.states.async_set("light.entity_temp", "on", VALID_COLOR_TEMP) + hass.states.async_set("light.entity_hs", "on", VALID_HS_COLOR) + hass.states.async_set("light.entity_rgb", "on", VALID_RGB_COLOR) + hass.states.async_set("light.entity_xy", "on", VALID_XY_COLOR) + + turn_on_calls = async_mock_service(hass, "light", "turn_on") + turn_off_calls = async_mock_service(hass, "light", "turn_off") + + # These calls should do nothing as entities already in desired state + await hass.helpers.state.async_reproduce_state( + [ + State("light.entity_off", "off"), + State("light.entity_bright", "on", VALID_BRIGHTNESS), + State("light.entity_white", "on", VALID_WHITE_VALUE), + State("light.entity_effect", "on", VALID_EFFECT), + State("light.entity_temp", "on", VALID_COLOR_TEMP), + State("light.entity_hs", "on", VALID_HS_COLOR), + State("light.entity_rgb", "on", VALID_RGB_COLOR), + State("light.entity_xy", "on", VALID_XY_COLOR), + ], + blocking=True, + ) + + assert len(turn_on_calls) == 0 + assert len(turn_off_calls) == 0 + + # Test invalid state is handled + await hass.helpers.state.async_reproduce_state( + [State("light.entity_off", "not_supported")], blocking=True + ) + + assert "not_supported" in caplog.text + assert len(turn_on_calls) == 0 + assert len(turn_off_calls) == 0 + + # Make sure correct services are called + await hass.helpers.state.async_reproduce_state( + [ + State("light.entity_xy", "off"), + State("light.entity_off", "on", VALID_BRIGHTNESS), + State("light.entity_bright", "on", VALID_WHITE_VALUE), + State("light.entity_white", "on", VALID_EFFECT), + State("light.entity_effect", "on", VALID_COLOR_TEMP), + State("light.entity_temp", "on", VALID_HS_COLOR), + State("light.entity_hs", "on", VALID_RGB_COLOR), + State("light.entity_rgb", "on", VALID_XY_COLOR), + ], + blocking=True, + ) + + assert len(turn_on_calls) == 7 + + expected_calls = [] + + expected_off = VALID_BRIGHTNESS + expected_off["entity_id"] = "light.entity_off" + expected_calls.append(expected_off) + + expected_bright = VALID_WHITE_VALUE + expected_bright["entity_id"] = "light.entity_bright" + expected_calls.append(expected_bright) + + expected_white = VALID_EFFECT + expected_white["entity_id"] = "light.entity_white" + expected_calls.append(expected_white) + + expected_effect = VALID_COLOR_TEMP + expected_effect["entity_id"] = "light.entity_effect" + expected_calls.append(expected_effect) + + expected_temp = VALID_HS_COLOR + expected_temp["entity_id"] = "light.entity_temp" + expected_calls.append(expected_temp) + + expected_hs = VALID_RGB_COLOR + expected_hs["entity_id"] = "light.entity_hs" + expected_calls.append(expected_hs) + + expected_rgb = VALID_XY_COLOR + expected_rgb["entity_id"] = "light.entity_rgb" + expected_calls.append(expected_rgb) + + for call in turn_on_calls: + assert call.domain == "light" + found = False + for expected in expected_calls: + if call.data["entity_id"] == expected["entity_id"]: + # We found the matching entry + assert call.data == expected + found = True + break + # No entry found + assert found + + assert len(turn_off_calls) == 1 + assert turn_off_calls[0].domain == "light" + assert turn_off_calls[0].data == {"entity_id": "light.entity_xy"} diff --git a/tests/helpers/test_state.py b/tests/helpers/test_state.py index 7f428c0833d..14bcbde5094 100644 --- a/tests/helpers/test_state.py +++ b/tests/helpers/test_state.py @@ -129,7 +129,7 @@ async def test_reproduce_turn_on(hass): last_call = calls[-1] assert last_call.domain == "light" assert SERVICE_TURN_ON == last_call.service - assert ["light.test"] == last_call.data.get("entity_id") + assert "light.test" == last_call.data.get("entity_id") async def test_reproduce_turn_off(hass): @@ -146,7 +146,7 @@ async def test_reproduce_turn_off(hass): last_call = calls[-1] assert last_call.domain == "light" assert SERVICE_TURN_OFF == last_call.service - assert ["light.test"] == last_call.data.get("entity_id") + assert "light.test" == last_call.data.get("entity_id") async def test_reproduce_complex_data(hass): @@ -155,10 +155,10 @@ async def test_reproduce_complex_data(hass): hass.states.async_set("light.test", "off") - complex_data = ["hello", {"11": "22"}] + complex_data = [255, 100, 100] await state.async_reproduce_state( - hass, ha.State("light.test", "on", {"complex": complex_data}) + hass, ha.State("light.test", "on", {"rgb_color": complex_data}) ) await hass.async_block_till_done() @@ -167,7 +167,7 @@ async def test_reproduce_complex_data(hass): last_call = calls[-1] assert last_call.domain == "light" assert SERVICE_TURN_ON == last_call.service - assert complex_data == last_call.data.get("complex") + assert complex_data == last_call.data.get("rgb_color") async def test_reproduce_bad_state(hass):