diff --git a/homeassistant/components/select/reproduce_state.py b/homeassistant/components/select/reproduce_state.py new file mode 100644 index 00000000000..8af4b94fd6f --- /dev/null +++ b/homeassistant/components/select/reproduce_state.py @@ -0,0 +1,66 @@ +"""Reproduce a Select entity state.""" +from __future__ import annotations + +import asyncio +from collections.abc import Iterable +import logging +from typing import Any + +from homeassistant.components.select.const import ATTR_OPTIONS +from homeassistant.const import ATTR_ENTITY_ID +from homeassistant.core import Context, HomeAssistant, State + +from . import ATTR_OPTION, DOMAIN, SERVICE_SELECT_OPTION + +_LOGGER = logging.getLogger(__name__) + + +async def _async_reproduce_state( + hass: HomeAssistant, + state: State, + *, + context: Context | None = None, + reproduce_options: dict[str, Any] | None = None, +) -> None: + """Reproduce a single state.""" + cur_state = hass.states.get(state.entity_id) + + if cur_state is None: + _LOGGER.warning("Unable to find entity %s", state.entity_id) + return + + if state.state not in cur_state.attributes.get(ATTR_OPTIONS, []): + _LOGGER.warning( + "Invalid state specified for %s: %s", state.entity_id, state.state + ) + return + + # Return if we are already at the right state. + if cur_state.state == state.state: + return + + await hass.services.async_call( + DOMAIN, + SERVICE_SELECT_OPTION, + {ATTR_ENTITY_ID: state.entity_id, ATTR_OPTION: state.state}, + context=context, + blocking=True, + ) + + +async def async_reproduce_states( + hass: HomeAssistant, + states: Iterable[State], + *, + context: Context | None = None, + reproduce_options: dict[str, Any] | None = None, +) -> None: + """Reproduce multiple select states.""" + await asyncio.gather( + *( + _async_reproduce_state( + hass, state, context=context, reproduce_options=reproduce_options + ) + for state in states + ) + ) diff --git a/tests/components/select/test_reproduce_state.py b/tests/components/select/test_reproduce_state.py new file mode 100644 index 00000000000..b1ab3a0a5aa --- /dev/null +++ b/tests/components/select/test_reproduce_state.py @@ -0,0 +1,57 @@ +"""Test reproduce state for select entities.""" +import pytest + +from homeassistant.components.select.const import ( + ATTR_OPTION, + ATTR_OPTIONS, + DOMAIN, + SERVICE_SELECT_OPTION, +) +from homeassistant.const import ATTR_ENTITY_ID +from homeassistant.core import HomeAssistant, State + +from tests.common import async_mock_service + + +async def test_reproducing_states( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture +) -> None: + """Test reproducing select states.""" + calls = async_mock_service(hass, DOMAIN, SERVICE_SELECT_OPTION) + hass.states.async_set( + "select.test", + "option_one", + {ATTR_OPTIONS: ["option_one", "option_two", "option_three"]}, + ) + + await hass.helpers.state.async_reproduce_state( + [ + State("select.test", "option_two"), + ], + ) + + assert len(calls) == 1 + assert calls[0].domain == DOMAIN + assert calls[0].data == {ATTR_ENTITY_ID: "select.test", ATTR_OPTION: "option_two"} + + # Calling it again should not do anything + await hass.helpers.state.async_reproduce_state( + [ + State("select.test", "option_one"), + ], + ) + assert len(calls) == 1 + + # Restoring an invalid state should not work either + await hass.helpers.state.async_reproduce_state( + [State("select.test", "option_four")] + ) + assert len(calls) == 1 + assert "Invalid state specified" in caplog.text + + # Restoring an state for an invalid entity ID logs a warning + await hass.helpers.state.async_reproduce_state( + [State("select.non_existing", "option_three")] + ) + assert len(calls) == 1 + assert "Unable to find entity" in caplog.text